diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 37e0be3e375..eec4657b79e 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -27,8 +27,10 @@ import scala.Tuple2; import com.github.luben.zstd.ZstdException; +import com.github.luben.zstd.ZstdInputStream; import com.google.common.util.concurrent.Uninterruptibles; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; import net.jpountz.lz4.LZ4Exception; import org.apache.commons.lang3.tuple.Pair; import org.roaringbitmap.RoaringBitmap; @@ -193,6 +195,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private Decompressor decompressor; private ByteBuf currentChunk; + private boolean currentChunkCompressed = false; private boolean firstChunk = true; private PartitionReader currentReader; private final int fetchChunkMaxRetry; @@ -213,6 +216,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private final String localHostAddress; private boolean shouldDecompress; + private InputStream currentStream; private boolean shuffleIntegrityCheckEnabled; private long fetchExcludedWorkerExpireTimeout; private ConcurrentHashMap fetchExcludedWorkers; @@ -526,7 +530,9 @@ private ByteBuf getNextChunk() throws IOException { if (!currentReader.hasNext()) { return null; } - return currentReader.next(); + Pair result = currentReader.next(); + currentChunkCompressed = result.getRight(); + return result.getLeft(); } catch (Exception e) { shuffleClient.excludeFailedFetchLocation( currentReader.getLocation().hostAndFetchPort(), e); @@ -730,6 +736,7 @@ public synchronized void close() { compressedBuf = null; rawDataBuf = null; + closeCurrentStream(); batchesRead = null; locations = null; attempts = null; @@ -800,6 +807,34 @@ private void init() { rawDataBuf = new byte[bufferSize]; } + private void closeCurrentStream() { + if (currentStream != null) { + try { + currentStream.close(); + } catch (IOException ignored) { + } + currentStream = null; + } + } + + private void setupCurrentStream() throws IOException { + closeCurrentStream(); + if (currentChunk == null) return; + InputStream base = new ByteBufInputStream(currentChunk); + currentStream = currentChunkCompressed ? new ZstdInputStream(base) : base; + } + + /** Reads exactly len bytes; returns total read (< len only on EOF). */ + private static int readFully(InputStream in, byte[] buf, int off, int len) throws IOException { + int total = 0; + while (total < len) { + int n = in.read(buf, off + total, len - total); + if (n == -1) break; + total += n; + } + return total; + } + private boolean fillBuffer() throws IOException { try { if (firstChunk && currentReader != null) { @@ -814,10 +849,23 @@ private boolean fillBuffer() throws IOException { return false; } + if (currentStream == null) { + setupCurrentStream(); + } + LocationPushFailedBatches failedBatch = new LocationPushFailedBatches(); boolean hasData = false; - while (currentChunk.isReadable() || moveToNextChunk()) { - currentChunk.readBytes(sizeBuf); + while (true) { + int headerRead = readFully(currentStream, sizeBuf, 0, BATCH_HEADER_SIZE); + if (headerRead == 0) { + closeCurrentStream(); + if (!moveToNextChunk()) break; + setupCurrentStream(); + continue; + } else if (headerRead != BATCH_HEADER_SIZE) { + throw new IOException("Invalid EOF detected"); + } + int mapId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET); int attemptId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 4); int batchId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8); @@ -827,14 +875,16 @@ private boolean fillBuffer() throws IOException { if (size > compressedBuf.length) { compressedBuf = new byte[size]; } - - currentChunk.readBytes(compressedBuf, 0, size); + if (readFully(currentStream, compressedBuf, 0, size) != size) { + throw new IOException("Invalid EOF detected"); + } } else { if (size > rawDataBuf.length) { rawDataBuf = new byte[size]; } - - currentChunk.readBytes(rawDataBuf, 0, size); + if (readFully(currentStream, rawDataBuf, 0, size) != size) { + throw new IOException("Invalid EOF detected"); + } } // de-duplicate diff --git a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java index 735a532321a..367913a1fff 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java @@ -238,7 +238,7 @@ private void checkpoint() { } @Override - public ByteBuf next() throws Exception { + public Pair next() throws Exception { Pair chunk = null; checkpoint(); if (!fetchThreadStarted) { @@ -328,7 +328,7 @@ public ByteBuf next() throws Exception { } returnedChunks++; lastReturnedChunkId = chunk.getLeft(); - return chunk.getRight(); + return Pair.of(chunk.getRight(), false); } private void checkException() throws Exception { diff --git a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java index 0e795037614..b0e665fc6aa 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java @@ -31,6 +31,7 @@ import io.netty.buffer.Unpooled; import io.netty.util.ReferenceCounted; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -214,7 +215,7 @@ public boolean hasNext() { } @Override - public ByteBuf next() throws IOException, InterruptedException { + public Pair next() throws Exception { checkException(); if (chunkIndex <= endChunkIndex) { fetchChunks(); @@ -254,8 +255,12 @@ public ByteBuf next() throws IOException, InterruptedException { logger.error("PartitionReader thread interrupted while fetching data."); throw e; } + int chunkIdx = startChunkIndex + returnedChunks; returnedChunks++; - return chunk; + boolean compressed = + streamHandler.getChunkCompressedCount() > chunkIdx + && streamHandler.getChunkCompressed(chunkIdx); + return Pair.of(chunk, compressed); } private void checkException() throws IOException { diff --git a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java index 247eacff5f7..a835e0d2408 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/PartitionReader.java @@ -20,6 +20,7 @@ import java.util.Optional; import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.tuple.Pair; import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.protocol.PartitionLocation; @@ -27,7 +28,7 @@ public interface PartitionReader { boolean hasNext(); - ByteBuf next() throws Exception; + Pair next() throws Exception; void close(); diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 7a066720364..27d25f6140f 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -185,7 +185,7 @@ private void checkpoint() { } @Override - public ByteBuf next() throws IOException, InterruptedException { + public Pair next() throws Exception { checkpoint(); checkException(); if (chunkIndex <= endChunkIndex) { @@ -229,7 +229,11 @@ public ByteBuf next() throws IOException, InterruptedException { returnedChunks++; inflightRequestCount--; lastReturnedChunkId = chunk.getLeft(); - return chunk.getRight(); + int chunkIdx = chunk.getLeft(); + boolean compressed = + streamHandler.getChunkCompressedCount() > chunkIdx + && streamHandler.getChunkCompressed(chunkIdx); + return Pair.of(chunk.getRight(), compressed); } @Override diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index a37513a236f..d64a0a01dcd 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -43,6 +43,7 @@ import org.apache.celeborn.client.listener.WorkerStatusListener import org.apache.celeborn.common.{CelebornConf, CommitMetadata} import org.apache.celeborn.common.CelebornConf.ACTIVE_STORAGE_TYPES import org.apache.celeborn.common.client.{ApplicationInfoProvider, MasterClient} +import org.apache.celeborn.common.compression.ChunkCompressionContext import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier} import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{ApplicationMeta, ShufflePartitionLocationInfo, WorkerInfo} @@ -1324,7 +1325,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends userIdentifier, conf.pushDataTimeoutMs, partitionSplitEnabled = true, - isSegmentGranularityVisible = isSegmentGranularityVisible)) + isSegmentGranularityVisible = isSegmentGranularityVisible, + chunkCompressionContext = new ChunkCompressionContext( + conf.isChunkCompressionEnabled, + conf.chunkCompressionLevel))) futures.add((future, workerInfo)) }(ec) } diff --git a/common/src/main/java/org/apache/celeborn/common/compression/ChunkCompressionContext.java b/common/src/main/java/org/apache/celeborn/common/compression/ChunkCompressionContext.java new file mode 100644 index 00000000000..c11ac60ab82 --- /dev/null +++ b/common/src/main/java/org/apache/celeborn/common/compression/ChunkCompressionContext.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.compression; + +/** + * Carries chunk-level compression settings from the client through to the worker's {@code + * ChunkCompressedFileChannelWriter}. Using a context object instead of a bare boolean keeps the + * call chain stable as new compression knobs are added. + */ +public final class ChunkCompressionContext { + + /** ZSTD default compression level (mirrors {@code Zstd.defaultCompressionLevel()}). */ + public static final int DEFAULT_COMPRESSION_LEVEL = 3; + + private static final ChunkCompressionContext DISABLED = + new ChunkCompressionContext(false, DEFAULT_COMPRESSION_LEVEL); + + private final boolean enabled; + private final int compressionLevel; + + public ChunkCompressionContext(boolean enabled, int compressionLevel) { + this.enabled = enabled; + this.compressionLevel = compressionLevel; + } + + /** Returns a context with compression disabled and the default compression level. */ + public static ChunkCompressionContext disabled() { + return DISABLED; + } + + public boolean isEnabled() { + return enabled; + } + + public int getCompressionLevel() { + return compressionLevel; + } +} diff --git a/common/src/main/java/org/apache/celeborn/common/meta/DiskFileInfo.java b/common/src/main/java/org/apache/celeborn/common/meta/DiskFileInfo.java index d4571fa4bbe..dcc3135dad9 100644 --- a/common/src/main/java/org/apache/celeborn/common/meta/DiskFileInfo.java +++ b/common/src/main/java/org/apache/celeborn/common/meta/DiskFileInfo.java @@ -28,6 +28,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.compression.ChunkCompressionContext; import org.apache.celeborn.common.identity.UserIdentifier; import org.apache.celeborn.common.protocol.StorageInfo; import org.apache.celeborn.common.util.Utils; @@ -39,16 +40,19 @@ public class DiskFileInfo extends FileInfo { private static final Logger logger = LoggerFactory.getLogger(DiskFileInfo.class); private final String filePath; private final StorageInfo.Type storageType; + private final ChunkCompressionContext chunkCompressionContext; public DiskFileInfo( UserIdentifier userIdentifier, boolean partitionSplitEnabled, FileMeta fileMeta, String filePath, - StorageInfo.Type storageType) { + StorageInfo.Type storageType, + ChunkCompressionContext chunkCompressionContext) { super(userIdentifier, partitionSplitEnabled, fileMeta); this.filePath = filePath; this.storageType = storageType; + this.chunkCompressionContext = chunkCompressionContext; } // only called when restore from pb or in UT @@ -58,9 +62,11 @@ public DiskFileInfo( FileMeta fileMeta, String filePath, StorageInfo.Type storageType, - long bytesFlushed) { + long bytesFlushed, + ChunkCompressionContext chunkCompressionContext) { super(userIdentifier, partitionSplitEnabled, fileMeta); this.filePath = filePath; + this.chunkCompressionContext = chunkCompressionContext; if (storageType != null) { this.storageType = storageType; } else { @@ -76,13 +82,16 @@ public DiskFileInfo(File file, UserIdentifier userIdentifier, CelebornConf conf) true, new ReduceFileMeta(new ArrayList<>(Arrays.asList(0L)), conf.shuffleChunkSize()), file.getAbsolutePath(), - StorageInfo.Type.HDD); + StorageInfo.Type.HDD, + ChunkCompressionContext.disabled()); } + // User only by the sorted public DiskFileInfo(UserIdentifier userIdentifier, FileMeta fileMeta, String filePath) { super(userIdentifier, true, fileMeta); this.filePath = filePath; this.storageType = StorageInfo.Type.HDD; + this.chunkCompressionContext = ChunkCompressionContext.disabled(); } public File getFile() { @@ -175,4 +184,16 @@ public boolean isDFS() { public StorageInfo.Type getStorageType() { return storageType; } + + public boolean isChunkCompressionEnabled() { + return chunkCompressionContext.isEnabled(); + } + + public int getChunkCompressionLevel() { + return chunkCompressionContext.getCompressionLevel(); + } + + public ChunkCompressionContext getChunkCompressionContext() { + return chunkCompressionContext; + } } diff --git a/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java b/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java index e8511f1bff1..b8db69297da 100644 --- a/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java +++ b/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java @@ -63,6 +63,10 @@ public synchronized void updateBytesFlushed(long bytes) { } } + public synchronized void setBytesFlushed(long bytesFlushed) { + this.bytesFlushed = bytesFlushed; + } + public UserIdentifier getUserIdentifier() { return userIdentifier; } diff --git a/common/src/main/java/org/apache/celeborn/common/meta/ReduceFileMeta.java b/common/src/main/java/org/apache/celeborn/common/meta/ReduceFileMeta.java index abc4498814d..c2e8cdfc0cf 100644 --- a/common/src/main/java/org/apache/celeborn/common/meta/ReduceFileMeta.java +++ b/common/src/main/java/org/apache/celeborn/common/meta/ReduceFileMeta.java @@ -24,6 +24,7 @@ public class ReduceFileMeta implements FileMeta { private final AtomicBoolean sorted = new AtomicBoolean(false); private final List chunkOffsets; + private List chunkCompressed; private long chunkSize; private long nextBoundary; @@ -43,6 +44,21 @@ public ReduceFileMeta(List chunkOffsets, long chunkSize) { this.chunkSize = chunkSize; } + public ReduceFileMeta(List chunkOffsets, List chunkCompressed) { + this.chunkOffsets = chunkOffsets; + this.chunkCompressed = chunkCompressed; + } + + public ReduceFileMeta(List chunkOffsets, List chunkCompressed, long chunkSize) { + this.chunkOffsets = chunkOffsets; + this.chunkCompressed = chunkCompressed; + this.chunkSize = chunkSize; + nextBoundary = chunkSize; + if (!chunkOffsets.isEmpty()) { + nextBoundary += chunkOffsets.get(chunkOffsets.size() - 1); + } + } + public ReduceFileMeta(List chunkOffsets) { this.chunkOffsets = chunkOffsets; } @@ -51,6 +67,10 @@ public synchronized List getChunkOffsets() { return chunkOffsets; } + public synchronized List getChunkCompressed() { + return chunkCompressed; + } + public synchronized void addChunkOffset(long offset) { nextBoundary = offset + chunkSize; if (chunkOffsets.isEmpty() || chunkOffsets.get(chunkOffsets.size() - 1) != offset) { diff --git a/common/src/main/java/org/apache/celeborn/common/network/buffer/FileChunkBuffers.java b/common/src/main/java/org/apache/celeborn/common/network/buffer/FileChunkBuffers.java index ed4969d3ad0..0962b93d9e1 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/buffer/FileChunkBuffers.java +++ b/common/src/main/java/org/apache/celeborn/common/network/buffer/FileChunkBuffers.java @@ -27,15 +27,25 @@ public class FileChunkBuffers extends ChunkBuffers { private final File file; private final TransportConf conf; + private final boolean isChunkCompressed; public FileChunkBuffers(DiskFileInfo fileInfo, TransportConf conf) { super(fileInfo.getReduceFileMeta()); + isChunkCompressed = fileInfo.isChunkCompressionEnabled(); file = fileInfo.getFile(); this.conf = conf; } @Override public ManagedBuffer chunk(int chunkIndex, int offset, int len) { + if (isChunkCompressed && (offset != 0 || len != Integer.MAX_VALUE)) { + throw new IllegalArgumentException( + "Sliced reads (offset=" + + offset + + ", len=" + + len + + ") are not supported for chunk-compressed files"); + } Tuple2 offsetLen = getChunkOffsetLength(chunkIndex, offset, len); return new FileSegmentManagedBuffer(conf, file, offsetLen._1, offsetLen._2); } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index a813a9e5015..9434b37f288 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -538,6 +538,11 @@ message PbRegisterWorkerResponse { string message = 2; } +message PbChunkCompressionConfig { + bool enabled = 1; + int32 level = 2; +} + message PbReserveSlots { string applicationId = 1; int32 shuffleId = 2; @@ -553,6 +558,7 @@ message PbReserveSlots { int32 availableStorageTypes = 12; PbPackedPartitionLocationsPair partitionLocationsPair = 13; bool isSegmentGranularityVisible = 14; + PbChunkCompressionConfig chunkCompressionConfig = 15; } message PbReserveSlotsResponse { @@ -662,6 +668,8 @@ message PbFileInfo { map partitionWritingSegment = 10; repeated PbSegmentIndex segmentIndex = 11; int32 storageType = 12; + PbChunkCompressionConfig chunkCompressionConfig = 13; + repeated bool chunkCompressed = 14; } message PbSegmentIndex { @@ -757,6 +765,7 @@ message PbStreamHandler { int32 numChunks = 2; repeated int64 chunkOffsets = 3; string fullPath = 4; + repeated bool chunkCompressed = 5; } message PbOpenStreamList { diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 40d06617fea..d3d4910e00a 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -991,10 +991,9 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def shuffleDecompressionLz4XXHashInstance: Option[String] = get(SHUFFLE_DECOMPRESSION_LZ4_XXHASH_INSTANCE) def shuffleCompressionZstdCompressLevel: Int = get(SHUFFLE_COMPRESSION_ZSTD_LEVEL) - - // ////////////////////////////////////////////////////// - // Shuffle Client RPC // - // ////////////////////////////////////////////////////// + def isChunkCompressionEnabled: Boolean = get(CHUNK_COMPRESSION_ENABLED) + def chunkCompressionLevel: Int = get(CHUNK_COMPRESSION_LEVEL) + def chunkCompressionMmapTmpDir: String = get(CHUNK_COMPRESSION_MMAP_TMPDIR) def clientRpcCacheSize: Int = get(CLIENT_RPC_CACHE_SIZE) def clientRpcCacheConcurrencyLevel: Int = get(CLIENT_RPC_CACHE_CONCURRENCY_LEVEL) def clientRpcReserveSlotsRpcTimeout: RpcTimeout = @@ -5016,6 +5015,15 @@ object CelebornConf extends Logging { .checkValue(_ > 0, "Value must be positive!") .createWithDefaultString("120s") + val CHUNK_COMPRESSION_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.chunk.compression.enabled") + .categories("client") + .version("0.6.4") + .doc("Whether to enable chunk compression for shuffle data. If true, shuffle data will be compressed at a" + + " chunk level worker side and decompressed client side.") + .booleanConf + .createWithDefault(false) + val TEST_CLIENT_PUSH_PRIMARY_DATA_TIMEOUT: ConfigEntry[Boolean] = buildConf("celeborn.test.worker.pushPrimaryDataTimeout") .withAlternative("celeborn.test.pushMasterDataTimeout") @@ -5281,6 +5289,32 @@ object CelebornConf extends Logging { .checkValues(Set(PartitionSplitMode.SOFT.name, PartitionSplitMode.HARD.name)) .createWithDefault(PartitionSplitMode.SOFT.name) + val CHUNK_COMPRESSION_LEVEL: ConfigEntry[Int] = + buildConf("celeborn.chunk.compression.level") + .categories("client") + .doc( + "ZSTD compression level to use for chunk-level compression " + + "(celeborn.chunk.compression.enabled must be true). " + + "Valid range is between -5 and 22; the default (3) matches the ZSTD library default.") + .version("0.6.4") + .intConf + .checkValue( + value => value >= -5 && value <= 22, + s"Compression level for Zstd compression codec should be an integer between -5 and 22.") + .createWithDefault(3) + + val CHUNK_COMPRESSION_MMAP_TMPDIR: ConfigEntry[String] = + buildConf("celeborn.chunk.compression.mmap.tmpDir") + .categories("worker") + .doc( + "Directory used to create memory-mapped backing files for the mmap memory manager " + + "used by chunk-level compression. Defaults to a subdirectory of the JVM temporary " + + "directory (/celeborn-mmap-memory-manager).") + .version("0.6.4") + .stringConf + .transform(_.replace("TMP_DIR", System.getProperty("java.io.tmpdir"))) + .createWithDefault("TMP_DIR/celeborn-mmap-memory-manager") + val SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = buildConf("celeborn.client.shuffle.compression.codec") .withAlternative("celeborn.shuffle.compression.codec") diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 36f164d697e..078d9ebacfc 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -26,6 +26,7 @@ import com.google.common.base.Preconditions.checkState import com.google.protobuf.ByteString import org.roaringbitmap.RoaringBitmap +import org.apache.celeborn.common.compression.ChunkCompressionContext import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{DiskInfo, WorkerInfo, WorkerStatus} @@ -482,7 +483,8 @@ object ControlMessages extends Logging { userIdentifier: UserIdentifier, pushDataTimeout: Long, partitionSplitEnabled: Boolean = false, - isSegmentGranularityVisible: Boolean = false) + isSegmentGranularityVisible: Boolean = false, + chunkCompressionContext: ChunkCompressionContext = ChunkCompressionContext.disabled()) extends WorkerMessage case class ReserveSlotsResponse( @@ -961,7 +963,8 @@ object ControlMessages extends Logging { userIdentifier, pushDataTimeout, partitionSplitEnabled, - isSegmentGranularityVisible) => + isSegmentGranularityVisible, + chunkCompressionContext) => val payload = PbReserveSlots.newBuilder() .setApplicationId(applicationId) .setShuffleId(shuffleId) @@ -975,6 +978,10 @@ object ControlMessages extends Logging { .setPushDataTimeout(pushDataTimeout) .setPartitionSplitEnabled(partitionSplitEnabled) .setIsSegmentGranularityVisible(isSegmentGranularityVisible) + .setChunkCompressionConfig(PbChunkCompressionConfig.newBuilder() + .setEnabled(chunkCompressionContext.isEnabled) + .setLevel(chunkCompressionContext.getCompressionLevel) + .build()) .build().toByteArray new TransportMessage(MessageType.RESERVE_SLOTS, payload) @@ -1439,7 +1446,10 @@ object ControlMessages extends Logging { userIdentifier, pbReserveSlots.getPushDataTimeout, pbReserveSlots.getPartitionSplitEnabled, - pbReserveSlots.getIsSegmentGranularityVisible) + pbReserveSlots.getIsSegmentGranularityVisible, + new ChunkCompressionContext( + pbReserveSlots.getChunkCompressionConfig.getEnabled, + pbReserveSlots.getChunkCompressionConfig.getLevel)) case RESERVE_SLOTS_RESPONSE_VALUE => val pbReserveSlotsResponse = PbReserveSlotsResponse.parseFrom(message.getPayload) diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala index e9c407ce80e..816b67ad148 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import com.google.protobuf.InvalidProtocolBufferException +import org.apache.celeborn.common.compression.ChunkCompressionContext import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.meta.{ApplicationInfo, ApplicationMeta, DeviceInfo, DiskFileInfo, DiskInfo, MapFileMeta, ReduceFileMeta, WorkerEventInfo, WorkerInfo, WorkerStatus} import org.apache.celeborn.common.meta.MapFileMeta.SegmentIndex @@ -102,7 +103,13 @@ object PbSerDeUtils { def fromPbFileInfo(pbFileInfo: PbFileInfo, userIdentifier: UserIdentifier) = { val meta = Utils.toPartitionType(pbFileInfo.getPartitionType) match { case PartitionType.REDUCE => - new ReduceFileMeta(pbFileInfo.getChunkOffsetsList) + val chunkOffsets = pbFileInfo.getChunkOffsetsList + val chunkCompressed = pbFileInfo.getChunkCompressedList + if (!chunkCompressed.isEmpty) { + new ReduceFileMeta(chunkOffsets, chunkCompressed) + } else { + new ReduceFileMeta(chunkOffsets) + } case PartitionType.MAP => val fileMeta = new MapFileMeta( pbFileInfo.getBufferSize, @@ -132,7 +139,10 @@ object PbSerDeUtils { meta, pbFileInfo.getFilePath, storageType, - pbFileInfo.getBytesFlushed) + pbFileInfo.getBytesFlushed, + new ChunkCompressionContext( + pbFileInfo.getChunkCompressionConfig.getEnabled, + pbFileInfo.getChunkCompressionConfig.getLevel)) } private def fromPbSegmentIndexList( @@ -155,6 +165,10 @@ object PbSerDeUtils { .setBytesFlushed(fileInfo.getFileLength) .setPartitionSplitEnabled(fileInfo.isPartitionSplitEnabled) .setStorageType(fileInfo.getStorageType.getValue) + .setChunkCompressionConfig( + PbChunkCompressionConfig.newBuilder() + .setEnabled(fileInfo.isChunkCompressionEnabled) + .setLevel(fileInfo.getChunkCompressionLevel)) if (fileInfo.getFileMeta.isInstanceOf[MapFileMeta]) { val mapFileMeta = fileInfo.getFileMeta.asInstanceOf[MapFileMeta] builder.setPartitionType(PartitionType.MAP.getValue) @@ -168,6 +182,9 @@ object PbSerDeUtils { val reduceFileMeta = fileInfo.getFileMeta.asInstanceOf[ReduceFileMeta] builder.setPartitionType(PartitionType.REDUCE.getValue) builder.addAllChunkOffsets(reduceFileMeta.getChunkOffsets) + if (reduceFileMeta.getChunkCompressed != null && !reduceFileMeta.getChunkCompressed.isEmpty) { + builder.addAllChunkCompressed(reduceFileMeta.getChunkCompressed) + } } builder.build } diff --git a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala index 5b8fe9979a1..a039655d5aa 100644 --- a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala +++ b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala @@ -29,6 +29,7 @@ import com.google.common.collect.Lists import org.apache.hadoop.shaded.org.apache.commons.lang3.RandomStringUtils import org.apache.celeborn.CelebornFunSuite +import org.apache.celeborn.common.compression.ChunkCompressionContext import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.meta._ import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType, PbFileInfo, PbPackedWorkerResource, PbWorkerResource, StorageInfo} @@ -81,42 +82,48 @@ class PbSerDeUtilsTest extends CelebornFunSuite { new ReduceFileMeta(chunkOffsets1, 123), file1.getAbsolutePath, StorageInfo.Type.HDD, - 3000L) + 3000L, + ChunkCompressionContext.disabled()) val fileInfo2 = new DiskFileInfo( userIdentifier2, true, new ReduceFileMeta(chunkOffsets2, 123), file2.getAbsolutePath, StorageInfo.Type.SSD, - 6000L) + 6000L, + ChunkCompressionContext.disabled()) val fileInfo3 = new DiskFileInfo( userIdentifier3, true, new ReduceFileMeta(chunkOffsets3, 123), file3, StorageInfo.Type.HDFS, - 6000L) + 6000L, + ChunkCompressionContext.disabled()) val fileInfo4 = new DiskFileInfo( userIdentifier3, true, new ReduceFileMeta(chunkOffsets3, 123), file4, StorageInfo.Type.OSS, - 6000L) + 6000L, + ChunkCompressionContext.disabled()) val fileInfo5 = new DiskFileInfo( userIdentifier3, true, new ReduceFileMeta(chunkOffsets3, 123), file5, StorageInfo.Type.S3, - 6000L) + 6000L, + ChunkCompressionContext.disabled()) val fileInfo6 = new DiskFileInfo( userIdentifier3, true, new ReduceFileMeta(chunkOffsets3, 123), file6, StorageInfo.Type.S3, - 6000L) + 6000L, + ChunkCompressionContext.disabled()) val mapFileInfo1 = new DiskFileInfo( userIdentifier1, @@ -124,14 +131,16 @@ class PbSerDeUtilsTest extends CelebornFunSuite { new MapFileMeta(1024, 10), file1.getAbsolutePath, StorageInfo.Type.HDD, - 6000L) + 6000L, + ChunkCompressionContext.disabled()) val mapFileInfo2 = new DiskFileInfo( userIdentifier2, true, new MapFileMeta(1024, 10), file2.getAbsolutePath, StorageInfo.Type.SSD, - 6000L) + 6000L, + ChunkCompressionContext.disabled()) val fileInfoMap = JavaUtils.newConcurrentHashMap[String, DiskFileInfo]() mapFileInfo1.setMountPoint("/mnt") mapFileInfo2.setMountPoint("/mnt") diff --git a/docs/configuration/client.md b/docs/configuration/client.md index ece9503bd19..76e7963e06d 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -19,6 +19,8 @@ license: | | Key | Default | isDynamic | Description | Since | Deprecated | | --- | ------- | --------- | ----------- | ----- | ---------- | +| celeborn.chunk.compression.enabled | false | false | Whether to enable chunk compression for shuffle data. If true, shuffle data will be compressed at a chunk level worker side and decompressed client side. | 0.6.4 | | +| celeborn.chunk.compression.level | 3 | false | ZSTD compression level to use for chunk-level compression (celeborn.chunk.compression.enabled must be true). Valid range is between -5 and 22; the default (3) matches the ZSTD library default. | 0.6.4 | | | celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled | false | false | If this is true, Celeborn will adaptively split skewed partitions instead of reading them by Spark map range. Please note that this feature requires the `Celeborn-Optimize-Skew-Partitions-spark3_3.patch`. | 0.6.0 | | | celeborn.client.application.heartbeatInterval | 10s | false | Interval for client to send heartbeat message to master. | 0.3.0 | celeborn.application.heartbeatInterval | | celeborn.client.application.info.provider | org.apache.celeborn.common.client.DefaultApplicationInfoProvider | false | ApplicationInfoProvider class name. Default class is `org.apache.celeborn.common.client.DefaultApplicationInfoProvider`. Optional values: org.apache.celeborn.common.identity.DefaultIdentityProvider user name and tenant id are default values or user-specific values. | 0.6.1 | | diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md index bb2cec89bc3..0bb9784fb40 100644 --- a/docs/configuration/worker.md +++ b/docs/configuration/worker.md @@ -19,6 +19,7 @@ license: | | Key | Default | isDynamic | Description | Since | Deprecated | | --- | ------- | --------- | ----------- | ----- | ---------- | +| celeborn.chunk.compression.mmap.tmpDir | TMP_DIR/celeborn-mmap-memory-manager | false | Directory used to create memory-mapped backing files for the mmap memory manager used by chunk-level compression. Defaults to a subdirectory of the JVM temporary directory (/celeborn-mmap-memory-manager). | 0.6.4 | | | celeborn.cluster.name | default | false | Celeborn cluster name. | 0.5.0 | | | celeborn.container.info.provider | org.apache.celeborn.server.common.container.DefaultContainerInfoProvider | false | ContainerInfoProvider class name. Default class is `org.apache.celeborn.server.common.container.DefaultContainerInfoProvider`. | 0.6.0 | | | celeborn.dynamicConfig.refresh.interval | 120s | false | Interval for refreshing the corresponding dynamic config periodically. | 0.4.0 | | diff --git a/project/CelebornBuild.scala b/project/CelebornBuild.scala index d492f308215..58983e04461 100644 --- a/project/CelebornBuild.scala +++ b/project/CelebornBuild.scala @@ -856,6 +856,7 @@ object CelebornWorker { Dependencies.log4jSlf4jImpl, Dependencies.disruptor, Dependencies.leveldbJniAll, + Dependencies.zstdJni, Dependencies.roaringBitmap, Dependencies.rocksdbJni, Dependencies.scalatestMockito % "test", diff --git a/worker/pom.xml b/worker/pom.xml index fe7fa9e5b75..d5427f6700d 100644 --- a/worker/pom.xml +++ b/worker/pom.xml @@ -78,6 +78,10 @@ org.roaringbitmap RoaringBitmap + + com.github.luben + zstd-jni + org.apache.logging.log4j log4j-slf4j-impl diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/BypassFileChannelWriter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/BypassFileChannelWriter.java new file mode 100644 index 00000000000..ac240cda6d0 --- /dev/null +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/BypassFileChannelWriter.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.worker.file; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; + +import io.netty.buffer.CompositeByteBuf; + +import org.apache.celeborn.common.meta.DiskFileInfo; +import org.apache.celeborn.common.util.FileChannelUtils; + +public class BypassFileChannelWriter extends FileChannelWriter { + private final FileChannel channel; + + public BypassFileChannelWriter(DiskFileInfo diskFileInfo) throws IOException { + channel = FileChannelUtils.createWritableFileChannel(diskFileInfo.getFilePath()); + } + + @Override + public void write(CompositeByteBuf buffer, boolean gatherApiEnabled) throws IOException { + ByteBuffer[] buffers = buffer.nioBuffers(); + if (gatherApiEnabled) { + int readableBytes = buffer.readableBytes(); + long written = 0L; + do { + written = channel.write(buffers) + written; + } while (written != readableBytes); + } else { + for (ByteBuffer byteBuffer : buffers) { + while (byteBuffer.hasRemaining()) { + channel.write(byteBuffer); + } + } + } + } + + @Override + public void close(boolean commitFilesFsync) throws IOException { + try { + if (commitFilesFsync) { + channel.force(false); + } + } finally { + channel.close(); + } + } +} diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/FileChannelWriter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/FileChannelWriter.java new file mode 100644 index 00000000000..ba50a909f23 --- /dev/null +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/FileChannelWriter.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.worker.file; + +import java.io.IOException; + +import io.netty.buffer.CompositeByteBuf; + +public abstract class FileChannelWriter { + public abstract void write(CompositeByteBuf buffer, boolean gatherApiEnabled) throws IOException; + + public abstract void close(boolean commitFilesFsync) throws IOException; +} diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/FileChannelWriterFactory.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/FileChannelWriterFactory.java new file mode 100644 index 00000000000..05f50257c72 --- /dev/null +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/FileChannelWriterFactory.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.worker.file; + +import java.io.IOException; + +import org.apache.celeborn.common.meta.DiskFileInfo; +import org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkBufferPool; +import org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter; + +public class FileChannelWriterFactory { + public static FileChannelWriter getFileChannelWriter( + DiskFileInfo diskFileInfo, long chunkSize, ChunkBufferPool chunkBufferPool) + throws IOException { + if (diskFileInfo.isChunkCompressionEnabled()) { + return new ChunkCompressedFileChannelWriter( + diskFileInfo, chunkSize, diskFileInfo.getChunkCompressionLevel(), chunkBufferPool); + } else { + return new BypassFileChannelWriter(diskFileInfo); + } + } +} diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/FileWriterType.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/FileWriterType.java new file mode 100644 index 00000000000..50fd3fada8e --- /dev/null +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/FileWriterType.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.worker.file; + +public enum FileWriterType { + CHUNK_COMPRESSED, + BYPASS +} diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/chunk/compressed/ChunkBufferPool.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/chunk/compressed/ChunkBufferPool.java new file mode 100644 index 00000000000..c903fcd2230 --- /dev/null +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/chunk/compressed/ChunkBufferPool.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.worker.file.chunk.compressed; + +import java.nio.ByteBuffer; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; + +import com.github.luben.zstd.Zstd; + +import org.apache.celeborn.common.CelebornConf; + +/** + * Pool of reusable (chunkBuffer, compressedBuffer) pairs for ChunkCompressedFileChannelWriter, + * bucketed by chunkSize so every acquired pair is exactly the right capacity. + * + *

Owns and manages the lifecycle of its internal {@link MmapMemoryManager}. Call {@link #close} + * when the pool is no longer needed to release the mmap backing files. + */ +public class ChunkBufferPool { + + public static class BufferPair { + public final ByteBuffer chunkBuffer; + public final ByteBuffer compressedBuffer; + public final long chunkSize; + + public BufferPair(ByteBuffer chunkBuffer, ByteBuffer compressedBuffer, long chunkSize) { + this.chunkBuffer = chunkBuffer; + this.compressedBuffer = compressedBuffer; + this.chunkSize = chunkSize; + } + } + + private final MmapMemoryManager mmapMemoryManager; + private final ConcurrentHashMap> poolMap = + new ConcurrentHashMap<>(); + + public ChunkBufferPool(CelebornConf conf) { + this.mmapMemoryManager = new MmapMemoryManager(conf.chunkCompressionMmapTmpDir()); + } + + public BufferPair acquire(long chunkSize) { + ConcurrentLinkedDeque bucket = + poolMap.computeIfAbsent(chunkSize, k -> new ConcurrentLinkedDeque<>()); + BufferPair pair = bucket.pollFirst(); + if (pair != null) { + pair.chunkBuffer.clear(); + pair.compressedBuffer.clear(); + return pair; + } + int chunkBufSize = Math.toIntExact(chunkSize); + int compressedBufSize = Math.toIntExact(Zstd.compressBound(chunkSize)); + ByteBuffer chunkBuf = mmapMemoryManager.allocateBuffer(chunkBufSize); + ByteBuffer compressedBuf = mmapMemoryManager.allocateBuffer(compressedBufSize); + return new BufferPair(chunkBuf, compressedBuf, chunkSize); + } + + /** Returns the pair to the bucket matching its chunkSize. */ + public void release(BufferPair pair) { + pair.chunkBuffer.clear(); + pair.compressedBuffer.clear(); + poolMap.computeIfAbsent(pair.chunkSize, k -> new ConcurrentLinkedDeque<>()).offerFirst(pair); + } + + public void close() { + poolMap.clear(); + mmapMemoryManager.close(); + } +} diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/chunk/compressed/ChunkCompressedFileChannelWriter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/chunk/compressed/ChunkCompressedFileChannelWriter.java new file mode 100644 index 00000000000..1ae0d424528 --- /dev/null +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/chunk/compressed/ChunkCompressedFileChannelWriter.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.worker.file.chunk.compressed; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.List; + +import com.github.luben.zstd.Zstd; +import com.google.common.annotations.VisibleForTesting; +import io.netty.buffer.CompositeByteBuf; + +import org.apache.celeborn.common.meta.DiskFileInfo; +import org.apache.celeborn.common.meta.ReduceFileMeta; +import org.apache.celeborn.common.util.FileChannelUtils; +import org.apache.celeborn.service.deploy.worker.file.FileChannelWriter; + +public class ChunkCompressedFileChannelWriter extends FileChannelWriter { + private final FileChannel channel; + private final DiskFileInfo diskFileInfo; + private final int compressionLevel; + private final ChunkBufferPool chunkBufferPool; + private final ChunkBufferPool.BufferPair bufferPair; + private ByteBuffer chunkBuffer; + private ByteBuffer compressedChunkBuffer; + private final List chunkOffsets; + private final List chunkCompressed; + private final long chunkSize; + private boolean closed = false; + + public ChunkCompressedFileChannelWriter( + DiskFileInfo diskFileInfo, + long chunkSize, + int compressionLevel, + ChunkBufferPool chunkBufferPool) + throws IOException { + this.diskFileInfo = diskFileInfo; + this.chunkSize = chunkSize; + channel = FileChannelUtils.createWritableFileChannel(diskFileInfo.getFilePath()); + this.compressionLevel = compressionLevel; + this.chunkBufferPool = chunkBufferPool; + bufferPair = chunkBufferPool.acquire(chunkSize); + chunkBuffer = bufferPair.chunkBuffer; + compressedChunkBuffer = bufferPair.compressedBuffer; + chunkOffsets = new ArrayList<>(); + chunkOffsets.add(0L); + chunkCompressed = new ArrayList<>(); + } + + @Override + public void write(CompositeByteBuf buffer, boolean gatherApiEnabled) throws IOException { + if (buffer.readableBytes() > chunkSize) { + // Flush any pending accumulated data before writing the large record so file offsets + // remain consistent. + compressAndFlush(); + flushLargeRecord(buffer); + return; + } + + if (buffer.readableBytes() > chunkBuffer.remaining()) { + compressAndFlush(); + } + + ByteBuffer[] buffers = buffer.nioBuffers(); + for (ByteBuffer byteBuffer : buffers) { + while (byteBuffer.hasRemaining()) { + chunkBuffer.put(byteBuffer); + } + } + } + + /** + * Writes the large record directly to the channel without compression. Large records span a full + * chunk on their own, so the decompression overhead would be paid all at once anyway; skipping + * compression avoids the ZstdOutputStream frame overhead and simplifies the write path. + */ + private void flushLargeRecord(CompositeByteBuf buffer) throws IOException { + ByteBuffer[] buffers = buffer.nioBuffers(); + for (ByteBuffer buf : buffers) { + while (buf.hasRemaining()) { + channel.write(buf); + } + } + chunkCompressed.add(false); + chunkOffsets.add(channel.position()); + } + + @VisibleForTesting + public void compressAndFlush() throws IOException { + int size = chunkBuffer.position(); + if (size == 0) return; + chunkBuffer.position(0); + chunkBuffer.limit(size); + compressedChunkBuffer.clear(); + int compressedSize; + try { + compressedSize = + (int) + Zstd.compressDirectByteBuffer( + compressedChunkBuffer, + 0, + compressedChunkBuffer.capacity(), + chunkBuffer, + 0, + size, + compressionLevel); + } catch (RuntimeException e) { + throw new IOException("Failed to compress chunk with ZSTD.", e); + } + if (Zstd.isError(compressedSize)) { + throw new IOException("ZSTD compression failed: " + Zstd.getErrorName(compressedSize)); + } + compressedChunkBuffer.position(0); + compressedChunkBuffer.limit(compressedSize); + + long written = 0L; + while (written < compressedSize) { + written += channel.write(compressedChunkBuffer); + } + chunkCompressed.add(true); + chunkOffsets.add((chunkOffsets.get(chunkOffsets.size() - 1) + written)); + chunkBuffer.clear(); + } + + @Override + public void close(boolean commitFilesFsync) throws IOException { + if (closed) { + return; + } + closed = true; + IOException failure = null; + try { + compressAndFlush(); + if (commitFilesFsync) { + channel.force(false); + } + } catch (IOException e) { + failure = e; + } finally { + chunkBufferPool.release(bufferPair); + try { + channel.close(); + } catch (IOException e) { + if (failure == null) { + failure = e; + } + } + } + + if (failure != null) { + throw failure; + } + diskFileInfo.setBytesFlushed(chunkOffsets.get(chunkOffsets.size() - 1)); + diskFileInfo.replaceFileMeta(new ReduceFileMeta(chunkOffsets, chunkCompressed, chunkSize)); + } +} diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/chunk/compressed/MmapMemoryManager.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/chunk/compressed/MmapMemoryManager.java new file mode 100644 index 00000000000..500fd186310 --- /dev/null +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/file/chunk/compressed/MmapMemoryManager.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.worker.file.chunk.compressed; + +import java.io.File; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.util.LinkedList; +import java.util.List; +import java.util.UUID; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class MmapMemoryManager { + private static final Logger LOG = LoggerFactory.getLogger(MmapMemoryManager.class); + private static final int DEFAULT_FILE_LENGTH = 512 * 1024 * 1024; + private final String _dirPathName; + // _availableOffset has the starting offset for the next allocation in _currentBuffer. When + // _currentBuffer + // is created, it is 0. After we allocate a buffer of size x, it is x. And if we allocate another + // buffer of size + // y, then it becomes x+y, etc. We try to fulfil as many allocate() calls as possible on the same + // _currentBuffer + // until the _currentBuffer cannot hold the new object anymore, and then we create a new + // _currentBuffer. + private int _availableOffset = DEFAULT_FILE_LENGTH; // Available offset in this file. + private int _curFileLen = -1; + private final List _paths = new LinkedList<>(); + private final List _memMappedBuffers = new LinkedList<>(); + ByteBuffer _currentBuffer; + + public MmapMemoryManager(String dirPathName) { + File dirFile = new File(dirPathName); + if (!dirFile.exists()) { + if (!dirFile.mkdirs()) { + throw new RuntimeException("Unable to create directory: " + dirFile); + } + } + _dirPathName = dirPathName; + } + + private String getFilePrefix() { + return UUID.randomUUID() + "."; + } + + private void addFileIfNecessary(int len) { + if (len + _availableOffset <= _curFileLen) { + return; + } + String filePath = _dirPathName + "/" + getFilePrefix(); + final File file = new File(filePath); + if (file.exists()) { + throw new RuntimeException("File " + filePath + " already exists"); + } + file.deleteOnExit(); + int fileLen = Math.max(DEFAULT_FILE_LENGTH, len); + try (RandomAccessFile raf = new RandomAccessFile(filePath, "rw"); + FileChannel fileChannel = raf.getChannel()) { + raf.setLength(fileLen); + _currentBuffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0, fileLen); + _memMappedBuffers.add(_currentBuffer); + } catch (IOException e) { + throw new RuntimeException(e); + } + _paths.add(filePath); + _availableOffset = 0; + _curFileLen = fileLen; + } + + public synchronized ByteBuffer allocateBuffer(int size) { + addFileIfNecessary(size); + ByteBuffer buffer = _currentBuffer.duplicate(); + buffer.position(_availableOffset); + buffer.limit(_availableOffset + size); + _availableOffset += size; + return buffer.slice(); + } + + public synchronized void close() { + // MappedByteBuffers cannot be explicitly unmapped in Java; GC handles the unmap. + // We clear the internal state and delete the backing files so disk space is reclaimed. + _memMappedBuffers.clear(); + for (String path : _paths) { + File file = new File(path); + if (!file.delete()) { + LOG.warn("Unable to delete mmap backing file: {}", file); + } + } + _paths.clear(); + _curFileLen = -1; + _availableOffset = DEFAULT_FILE_LENGTH; + } +} diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionDataWriterContext.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionDataWriterContext.java index 708176c48dd..e31ee94919e 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionDataWriterContext.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionDataWriterContext.java @@ -19,6 +19,7 @@ import java.io.File; +import org.apache.celeborn.common.compression.ChunkCompressionContext; import org.apache.celeborn.common.identity.UserIdentifier; import org.apache.celeborn.common.protocol.PartitionLocation; import org.apache.celeborn.common.protocol.PartitionSplitMode; @@ -37,6 +38,7 @@ public class PartitionDataWriterContext { private final String shuffleKey; private final PartitionType partitionType; private final boolean isSegmentGranularityVisible; + private final ChunkCompressionContext chunkCompressionContext; private File workingDir; private PartitionDataWriter partitionDataWriter; @@ -52,7 +54,8 @@ public PartitionDataWriterContext( UserIdentifier userIdentifier, PartitionType partitionType, boolean partitionSplitEnabled, - boolean isSegmentGranularityVisible) { + boolean isSegmentGranularityVisible, + ChunkCompressionContext chunkCompressionContext) { this.splitThreshold = splitThreshold; this.partitionSplitMode = partitionSplitMode; this.rangeReadFilter = rangeReadFilter; @@ -64,6 +67,7 @@ public PartitionDataWriterContext( this.partitionType = partitionType; this.shuffleKey = Utils.makeShuffleKey(appId, shuffleId); this.isSegmentGranularityVisible = isSegmentGranularityVisible; + this.chunkCompressionContext = chunkCompressionContext; } public long getSplitThreshold() { @@ -98,6 +102,14 @@ public boolean isPartitionSplitEnabled() { return partitionSplitEnabled; } + public boolean isChunkCompressionEnabled() { + return chunkCompressionContext.isEnabled(); + } + + public ChunkCompressionContext getChunkCompressionContext() { + return chunkCompressionContext; + } + public String getShuffleKey() { return shuffleKey; } diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java index 5979a8434c2..35556d5c09f 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java @@ -234,6 +234,15 @@ public FileInfo getSortedFileInfo( targetBuffer); } else { DiskFileInfo diskFileInfo = ((DiskFileInfo) fileInfo); + if (diskFileInfo.isChunkCompressionEnabled()) { + // TODO this is yet to be implemented + // We can read the file one chunk at a time and store chunkid + uncompressed offsets before + // writing + throw new IOException( + "Chunk compressed shuffle file is not supported for sorting, file path: " + + diskFileInfo.getFilePath() + + ". Set celeborn.chunk.compression.enabled=false or disable range reads"); + } String fileId = shuffleKey + "-" + fileName; UserIdentifier userIdentifier = diskFileInfo.getUserIdentifier(); Set sorted = diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala index 565acb44182..f78f120e402 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Controller.scala @@ -30,6 +30,7 @@ import io.netty.util.{HashedWheelTimer, Timeout, TimerTask} import org.roaringbitmap.RoaringBitmap import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.compression.ChunkCompressionContext import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{WorkerInfo, WorkerPartitionLocationInfo} @@ -114,7 +115,8 @@ private[deploy] class Controller( userIdentifier, pushDataTimeout, partitionSplitEnabled, - isSegmentGranularityVisible) => + isSegmentGranularityVisible, + chunkCompressionContext) => checkAuth(context, applicationId) val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId) workerSource.sample(WorkerSource.RESERVE_SLOTS_TIME, shuffleKey) { @@ -134,7 +136,8 @@ private[deploy] class Controller( userIdentifier, pushDataTimeout, partitionSplitEnabled, - isSegmentGranularityVisible) + isSegmentGranularityVisible, + chunkCompressionContext) logDebug(s"ReserveSlots for $shuffleKey finished.") } @@ -181,7 +184,8 @@ private[deploy] class Controller( userIdentifier: UserIdentifier, pushDataTimeout: Long, partitionSplitEnabled: Boolean, - isSegmentGranularityVisible: Boolean): Unit = { + isSegmentGranularityVisible: Boolean, + chunkCompressionContext: ChunkCompressionContext): Unit = { val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId) if (shutdown.get()) { val msg = "Current worker is shutting down!" @@ -213,7 +217,8 @@ private[deploy] class Controller( userIdentifier, partitionSplitEnabled, isSegmentGranularityVisible, - isPrimary = true) + isPrimary = true, + chunkCompressionContext) if (primaryLocs.size() < requestPrimaryLocs.size()) { val msg = s"Not all primary partition satisfied for $shuffleKey" logWarning(s"[handleReserveSlots] $msg, will destroy writers.") @@ -234,7 +239,8 @@ private[deploy] class Controller( userIdentifier, partitionSplitEnabled, isSegmentGranularityVisible, - isPrimary = false) + isPrimary = false, + chunkCompressionContext) if (replicaLocs.size() < requestReplicaLocs.size()) { val msg = s"Not all replica partition satisfied for $shuffleKey" logWarning(s"[handleReserveSlots] $msg, destroy writers.") @@ -277,7 +283,8 @@ private[deploy] class Controller( userIdentifier: UserIdentifier, partitionSplitEnabled: Boolean, isSegmentGranularityVisible: Boolean, - isPrimary: Boolean): jList[PartitionLocation] = { + isPrimary: Boolean, + chunkCompressionContext: ChunkCompressionContext): jList[PartitionLocation] = { val partitionLocations = new jArrayList[PartitionLocation]() try { def createWriter(partitionLocation: PartitionLocation): PartitionLocation = { @@ -293,7 +300,8 @@ private[deploy] class Controller( userIdentifier, partitionSplitEnabled, isSegmentGranularityVisible, - isPrimary) + isPrimary, + chunkCompressionContext) } if (createWriterThreadPool == null) { partitionLocations.addAll(requestLocs.asScala.map(createWriter).asJava) @@ -323,7 +331,8 @@ private[deploy] class Controller( userIdentifier: UserIdentifier, partitionSplitEnabled: Boolean, isSegmentGranularityVisible: Boolean, - isPrimary: Boolean): PartitionLocation = { + isPrimary: Boolean, + chunkCompressionContext: ChunkCompressionContext): PartitionLocation = { try { var location = if (isPrimary) { @@ -347,7 +356,8 @@ private[deploy] class Controller( rangeReadFilter, userIdentifier, partitionSplitEnabled, - isSegmentGranularityVisible) + isSegmentGranularityVisible, + chunkCompressionContext) new WorkingPartition(location, writer) } else { location diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index 7ad990e2bf4..342f014043f 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -287,7 +287,8 @@ class FetchHandler( streamId, meta.getNumChunks, meta.getChunkOffsets, - fileInfo.asInstanceOf[DiskFileInfo].getFilePath) + fileInfo.asInstanceOf[DiskFileInfo].getFilePath, + meta.getChunkCompressed) } else fileInfo match { case info: DiskFileInfo if info.isHdfs => chunkStreamManager.registerStream( @@ -337,7 +338,8 @@ class FetchHandler( s"${NettyUtils.getRemoteAddress(client.getChannel)}") makeStreamHandler( streamId, - meta.getNumChunks) + meta.getNumChunks, + chunkCompressed = meta.getChunkCompressed) } workerSource.incCounter(WorkerSource.OPEN_STREAM_SUCCESS_COUNT) PbStreamHandlerOpt.newBuilder().setStreamHandler(streamHandler) @@ -423,7 +425,8 @@ class FetchHandler( streamId: Long, numChunks: Int, offsets: util.List[java.lang.Long] = null, - filepath: String = ""): PbStreamHandler = { + filepath: String = "", + chunkCompressed: util.List[java.lang.Boolean] = null): PbStreamHandler = { val pbStreamHandlerBuilder = PbStreamHandler.newBuilder.setStreamId(streamId).setNumChunks( numChunks) if (offsets != null) { @@ -432,6 +435,9 @@ class FetchHandler( if (filepath.nonEmpty) { pbStreamHandlerBuilder.setFullPath(filepath) } + if (chunkCompressed != null) { + pbStreamHandlerBuilder.addAllChunkCompressed(chunkCompressed) + } pbStreamHandlerBuilder.build() } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala index 67f45e28b4c..1fc30a0c5e5 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala @@ -18,7 +18,6 @@ package org.apache.celeborn.service.deploy.worker.storage import java.io.{ByteArrayInputStream, Closeable, IOException} -import java.nio.channels.FileChannel import io.netty.buffer.{ByteBufUtil, CompositeByteBuf} import org.apache.hadoop.fs.{FSDataOutputStream, Path} @@ -28,6 +27,7 @@ import org.apache.celeborn.common.metrics.source.AbstractSource import org.apache.celeborn.common.protocol.StorageInfo.Type import org.apache.celeborn.server.common.service.mpu.MultipartUploadHandler import org.apache.celeborn.service.deploy.worker.WorkerSource +import org.apache.celeborn.service.deploy.worker.file.FileChannelWriter abstract private[worker] class FlushTask( val buffer: CompositeByteBuf, @@ -51,27 +51,14 @@ abstract private[worker] class FlushTask( private[worker] class LocalFlushTask( buffer: CompositeByteBuf, - fileChannel: FileChannel, + fileChannelWriter: FileChannelWriter, notifier: FlushNotifier, keepBuffer: Boolean, source: AbstractSource, gatherApiEnabled: Boolean) extends FlushTask(buffer, notifier, keepBuffer, source) { override def flush(copyBytes: Array[Byte]): Unit = { val readableBytes = buffer.readableBytes() - val buffers = buffer.nioBuffers() - if (gatherApiEnabled) { - val readableBytes = buffer.readableBytes() - var written = 0L - do { - written = fileChannel.write(buffers) + written - } while (written != readableBytes) - } else { - for (buffer <- buffers) { - while (buffer.hasRemaining) { - fileChannel.write(buffer) - } - } - } + fileChannelWriter.write(buffer, gatherApiEnabled) source.incCounter(WorkerSource.LOCAL_FLUSH_COUNT) source.incCounter(WorkerSource.LOCAL_FLUSH_SIZE, readableBytes) // TODO: force flush file channel in scenarios where the upstream task writes and the downstream task reads simultaneously, such as flink hybrid shuffle. diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala index 9a2d4a8a740..30498bbc588 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala @@ -36,6 +36,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.fs.permission.FsPermission import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.compression.ChunkCompressionContext import org.apache.celeborn.common.exception.CelebornException import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.internal.Logging @@ -47,6 +48,7 @@ import org.apache.celeborn.common.protocol.StorageInfo.Type import org.apache.celeborn.common.quota.ResourceConsumption import org.apache.celeborn.common.util.{CelebornExitKind, CelebornHadoopUtils, CollectionUtils, DiskUtils, JavaUtils, PbSerDeUtils, ThreadUtils, Utils} import org.apache.celeborn.service.deploy.worker._ +import org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkBufferPool import org.apache.celeborn.service.deploy.worker.memory.MemoryManager import org.apache.celeborn.service.deploy.worker.memory.MemoryManager.MemoryPressureListener import org.apache.celeborn.service.deploy.worker.shuffledb.{DB, DBBackend, DBProvider} @@ -81,6 +83,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs val diskReserveRatio = conf.workerDiskReserveRatio var s3MultipartUploadHandlerSharedState: AutoCloseable = _ + val chunkBufferPool: ChunkBufferPool = new ChunkBufferPool(conf) + // (deviceName -> deviceInfo) and (mount point -> diskInfo) val (deviceInfos, diskInfos) = { val workingDirInfos = @@ -428,7 +432,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs splitMode: PartitionSplitMode, partitionType: PartitionType, rangeReadFilter: Boolean, - userIdentifier: UserIdentifier): PartitionDataWriter = { + userIdentifier: UserIdentifier, + chunkCompressionContext: ChunkCompressionContext): PartitionDataWriter = { createPartitionDataWriter( appId, shuffleId, @@ -439,7 +444,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs rangeReadFilter, userIdentifier, true, - isSegmentGranularityVisible = false) + isSegmentGranularityVisible = false, + chunkCompressionContext) } def ensureS3MultipartUploaderSharedState(): Unit = this.synchronized { @@ -492,7 +498,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs rangeReadFilter: Boolean, userIdentifier: UserIdentifier, partitionSplitEnabled: Boolean, - isSegmentGranularityVisible: Boolean): PartitionDataWriter = { + isSegmentGranularityVisible: Boolean, + chunkCompressionContext: ChunkCompressionContext): PartitionDataWriter = { if (healthyLocalWorkingDirs().isEmpty && remoteStorageDirs.isEmpty) { throw new IOException("No available working dirs!") } @@ -506,7 +513,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs userIdentifier, partitionType, partitionSplitEnabled, - isSegmentGranularityVisible) + isSegmentGranularityVisible, + chunkCompressionContext) val writer = try { @@ -898,6 +906,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs if (s3MultipartUploadHandlerSharedState != null) s3MultipartUploadHandlerSharedState.close() + + chunkBufferPool.close() } private def flushFileWriters(): Unit = { @@ -1085,7 +1095,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs location.getFileName, partitionDataWriterContext.getUserIdentifier, partitionDataWriterContext.getPartitionType, - partitionDataWriterContext.isPartitionSplitEnabled) + partitionDataWriterContext.isPartitionSplitEnabled, + partitionDataWriterContext.getChunkCompressionContext) (null, createDiskFileResult._1, createDiskFileResult._2, createDiskFileResult._3) } else { (null, null, null, null) @@ -1129,6 +1140,7 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs userIdentifier: UserIdentifier, partitionType: PartitionType, partitionSplitEnabled: Boolean, + chunkCompressionContext: ChunkCompressionContext, overrideStorageType: StorageInfo.Type = null): (Flusher, DiskFileInfo, File) = { val suggestedMountPoint = location.getStorageInfo.getMountPoint @@ -1174,7 +1186,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs partitionSplitEnabled, getFileMeta(partitionType, s"hdfs", conf.shuffleChunkSize), hdfsFilePath, - StorageInfo.Type.HDFS) + StorageInfo.Type.HDFS, + ChunkCompressionContext.disabled()) diskFileInfos.computeIfAbsent(shuffleKey, diskFileInfoMapFunc).put( fileName, hdfsFileInfo) @@ -1189,7 +1202,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs partitionSplitEnabled, new ReduceFileMeta(conf.shuffleChunkSize), s3FilePath, - StorageInfo.Type.S3) + StorageInfo.Type.S3, + ChunkCompressionContext.disabled()) diskFileInfos.computeIfAbsent(shuffleKey, diskFileInfoMapFunc).put( fileName, s3FileInfo) @@ -1207,7 +1221,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs partitionSplitEnabled, new ReduceFileMeta(conf.shuffleChunkSize), ossFilePath, - StorageInfo.Type.OSS) + StorageInfo.Type.OSS, + ChunkCompressionContext.disabled()) diskFileInfos.computeIfAbsent(shuffleKey, diskFileInfoMapFunc).put( fileName, ossFileInfo) @@ -1237,7 +1252,8 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs partitionSplitEnabled, fileMeta, filePath, - storageType) + storageType, + chunkCompressionContext) logInfo(s"created file at $filePath") diskFileInfos.computeIfAbsent(shuffleKey, diskFileInfoMapFunc).put( fileName, diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StoragePolicy.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StoragePolicy.scala index 7168a809dc7..ccbd0e7cb9b 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StoragePolicy.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StoragePolicy.scala @@ -135,6 +135,7 @@ class StoragePolicy(conf: CelebornConf, storageManager: StorageManager, source: partitionDataWriterContext.getUserIdentifier, partitionDataWriterContext.getPartitionType, partitionDataWriterContext.isPartitionSplitEnabled, + partitionDataWriterContext.getChunkCompressionContext, overrideType // this is different from location type, in case of eviction ) partitionDataWriterContext.setWorkingDir(workingDir) diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala index a04f1d67613..6e09a046e48 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala @@ -40,6 +40,7 @@ import org.apache.celeborn.common.util.Utils import org.apache.celeborn.server.common.service.mpu.MultipartUploadHandler import org.apache.celeborn.service.deploy.worker.WorkerSource import org.apache.celeborn.service.deploy.worker.congestcontrol.{CongestionController, UserCongestionControlContext} +import org.apache.celeborn.service.deploy.worker.file.{FileChannelWriter, FileChannelWriterFactory, FileWriterType} import org.apache.celeborn.service.deploy.worker.memory.MemoryManager abstract class TierWriterBase( @@ -115,6 +116,7 @@ abstract class TierWriterBase( } catch { case e: IOException => logWarning(s"close file writer $this failed", e) + throw e } } notifyFileCommitted() @@ -414,8 +416,11 @@ class LocalTierWriter( partitionDataWriterContext.getWorkingDir, fileInfo.asInstanceOf[DiskFileInfo]) - private lazy val channel: FileChannel = - FileChannelUtils.createWritableFileChannel(diskFileInfo.getFilePath) + private lazy val fileChannelWriter: FileChannelWriter = + FileChannelWriterFactory.getFileChannelWriter( + diskFileInfo, + conf.shuffleChunkSize, + storageManager.chunkBufferPool) val gatherApiEnabled: Boolean = conf.workerFlusherLocalGatherAPIEnabled val commitFilesFsync: Boolean = conf.workerCommitFilesFsync @@ -426,7 +431,7 @@ class LocalTierWriter( override def genFlushTask(finalFlush: Boolean, keepBuffer: Boolean): FlushTask = { notifier.numPendingFlushes.incrementAndGet() - new LocalFlushTask(flushBuffer, channel, notifier, true, source, gatherApiEnabled) + new LocalFlushTask(flushBuffer, fileChannelWriter, notifier, true, source, gatherApiEnabled) } override def writeInternal(buf: ByteBuf): Unit = { @@ -459,14 +464,9 @@ class LocalTierWriter( } override def closeStreams(): Unit = { - if (channel != null) { - try { - if (commitFilesFsync) { - channel.force(false) - } - } finally { - channel.close() - } + if (fileChannelWriter != null) { + // Closing with / without sync here + fileChannelWriter.close(commitFilesFsync) } } @@ -474,7 +474,7 @@ class LocalTierWriter( storageManager.notifyFileInfoCommitted(shuffleKey, filename, diskFileInfo) override def closeResource(): Unit = { - try if (channel != null) channel.close() + try if (fileChannelWriter != null) fileChannelWriter.close(false) catch { case e: IOException => logWarning( diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/file/chunk/compressed/ChunkBufferPoolSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/file/chunk/compressed/ChunkBufferPoolSuiteJ.java new file mode 100644 index 00000000000..4b05bae93b2 --- /dev/null +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/file/chunk/compressed/ChunkBufferPoolSuiteJ.java @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.worker.storage.file.chunk.compressed; + +import static org.junit.Assert.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicInteger; + +import com.github.luben.zstd.Zstd; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkBufferPool; + +public class ChunkBufferPoolSuiteJ { + + // Use distinct prime-ish sizes per test so different tests never share a bucket. + // The pool instance is shared across tests; unique sizes prevent cross-test contamination. + private static final long SIZE_1 = 1009; + private static final long SIZE_2 = 2003; + private static final long SIZE_3 = 4001; + private static final long SIZE_4 = 8009; + private static final long SIZE_5 = 16007; + private static final long SIZE_6 = 32003; + private static final long SIZE_7 = 64007; + private static final long SIZE_8 = 128021; + + private static ChunkBufferPool POOL; + + @BeforeClass + public static void setUpClass() { + POOL = new ChunkBufferPool(new CelebornConf()); + } + + @AfterClass + public static void tearDownClass() { + if (POOL != null) { + POOL.close(); + } + } + + private ChunkBufferPool pool() { + return POOL; + } + + // ── Test 2: fresh acquire allocates buffers with correct capacities ───────── + + @Test + public void testFreshAcquireAllocatesCorrectCapacities() { + ChunkBufferPool.BufferPair pair = pool().acquire(SIZE_1); + try { + assertNotNull(pair.chunkBuffer); + assertNotNull(pair.compressedBuffer); + assertEquals(SIZE_1, pair.chunkBuffer.capacity()); + assertEquals((int) Zstd.compressBound(SIZE_1), pair.compressedBuffer.capacity()); + assertEquals(SIZE_1, pair.chunkSize); + } finally { + pool().release(pair); + } + } + + // ── Test 3: freshly acquired buffers start at position=0, limit=capacity ─── + + @Test + public void testFreshAcquireBuffersAreInClearState() { + ChunkBufferPool.BufferPair pair = pool().acquire(SIZE_2); + try { + assertEquals(0, pair.chunkBuffer.position()); + assertEquals((int) SIZE_2, pair.chunkBuffer.limit()); + assertEquals(0, pair.compressedBuffer.position()); + assertEquals((int) Zstd.compressBound(SIZE_2), pair.compressedBuffer.limit()); + } finally { + pool().release(pair); + } + } + + // ── Test 4: release then acquire returns the exact same BufferPair object ─── + + @Test + public void testReleaseAndAcquireReturnsSameObject() { + ChunkBufferPool.BufferPair pair = pool().acquire(SIZE_3); + pool().release(pair); + ChunkBufferPool.BufferPair reacquired = pool().acquire(SIZE_3); + try { + assertSame(pair, reacquired); + } finally { + pool().release(reacquired); + } + } + + // ── Test 5: reacquired buffers have position reset to 0 even if dirty ─────── + + @Test + public void testReacquiredBuffersAreClearedAfterDirtyUse() { + ChunkBufferPool.BufferPair pair = pool().acquire(SIZE_4); + + // Simulate dirty use: advance positions on both buffers. + pair.chunkBuffer.position(10); + pair.compressedBuffer.position(20); + + pool().release(pair); + + ChunkBufferPool.BufferPair reacquired = pool().acquire(SIZE_4); + try { + assertEquals( + "chunkBuffer position should be 0 after reacquire", 0, reacquired.chunkBuffer.position()); + assertEquals( + "compressedBuffer position should be 0 after reacquire", + 0, + reacquired.compressedBuffer.position()); + assertEquals((int) SIZE_4, reacquired.chunkBuffer.limit()); + assertEquals((int) Zstd.compressBound(SIZE_4), reacquired.compressedBuffer.limit()); + } finally { + pool().release(reacquired); + } + } + + // ── Test 6: different chunk sizes use independent buckets ────────────────── + + @Test + public void testDifferentSizesUseIndependentBuckets() { + ChunkBufferPool.BufferPair pairA = pool().acquire(SIZE_5); + ChunkBufferPool.BufferPair pairB = pool().acquire(SIZE_6); + + // Release A and B in separate buckets. + pool().release(pairA); + pool().release(pairB); + + // Reacquiring size A should give back pairA, not pairB. + ChunkBufferPool.BufferPair reacquiredA = pool().acquire(SIZE_5); + ChunkBufferPool.BufferPair reacquiredB = pool().acquire(SIZE_6); + try { + assertSame(pairA, reacquiredA); + assertSame(pairB, reacquiredB); + assertEquals(SIZE_5, reacquiredA.chunkSize); + assertEquals(SIZE_6, reacquiredB.chunkSize); + } finally { + pool().release(reacquiredA); + pool().release(reacquiredB); + } + } + + // ── Test 7: two acquires without intervening release allocate distinct pairs ─ + + @Test + public void testTwoConsecutiveAcquiresReturnDistinctPairs() { + ChunkBufferPool.BufferPair pair1 = pool().acquire(SIZE_7); + ChunkBufferPool.BufferPair pair2 = pool().acquire(SIZE_7); + try { + assertNotSame(pair1, pair2); + assertNotSame(pair1.chunkBuffer, pair2.chunkBuffer); + assertNotSame(pair1.compressedBuffer, pair2.compressedBuffer); + } finally { + pool().release(pair1); + pool().release(pair2); + } + } + + // ── Test 8: pool is LIFO — last released is first reacquired ───────────── + + @Test + public void testPoolIsLifo() { + ChunkBufferPool.BufferPair first = pool().acquire(SIZE_8); + ChunkBufferPool.BufferPair second = pool().acquire(SIZE_8); + + // Release first, then second — second is now at the head of the deque. + pool().release(first); + pool().release(second); + + ChunkBufferPool.BufferPair got1 = pool().acquire(SIZE_8); + ChunkBufferPool.BufferPair got2 = pool().acquire(SIZE_8); + try { + assertSame("LIFO: second released should be first reacquired", second, got1); + assertSame("LIFO: first released should be second reacquired", first, got2); + } finally { + pool().release(got1); + pool().release(got2); + } + } + + // ── Test 9: buffers are direct ByteBuffers ──────────────────────────────── + + @Test + public void testAcquiredBuffersAreDirect() { + ChunkBufferPool.BufferPair pair = pool().acquire(1024); + try { + assertTrue("chunkBuffer should be direct", pair.chunkBuffer.isDirect()); + assertTrue("compressedBuffer should be direct", pair.compressedBuffer.isDirect()); + } finally { + pool().release(pair); + } + } + + // ── Test 10: released pair's chunkSize matches the bucket it was acquired from + + @Test + public void testChunkSizeFieldIsPreserved() { + long size = 3072L; + ChunkBufferPool.BufferPair pair = pool().acquire(size); + assertEquals(size, pair.chunkSize); + pool().release(pair); + + ChunkBufferPool.BufferPair reacquired = pool().acquire(size); + assertEquals(size, reacquired.chunkSize); + pool().release(reacquired); + } + + // ── Test 11: data written before release is invisible after reacquire ──────── + + @Test + public void testWrittenDataNotVisibleAfterReacquire() { + ChunkBufferPool.BufferPair pair = pool().acquire(512); + // Write a known byte pattern into chunkBuffer. + pair.chunkBuffer.put((byte) 0xDE); + pair.chunkBuffer.put((byte) 0xAD); + pool().release(pair); + + ChunkBufferPool.BufferPair reacquired = pool().acquire(512); + try { + // position is 0 after clear — the buffer is logically empty regardless of stale bytes. + assertEquals(0, reacquired.chunkBuffer.position()); + assertEquals(512, reacquired.chunkBuffer.limit()); + // Writing from position 0 again must succeed without IndexOutOfBoundsException. + reacquired.chunkBuffer.put((byte) 0xFF); + assertEquals(1, reacquired.chunkBuffer.position()); + } finally { + pool().release(reacquired); + } + } + + // ── Test 12: concurrent acquire/release from multiple threads ───────────── + + @Test + public void testConcurrentAcquireRelease() throws Exception { + final long size = 256L; + final int threads = 8; + final int iterationsPerThread = 500; + final AtomicInteger errors = new AtomicInteger(0); + + ExecutorService executor = Executors.newFixedThreadPool(threads); + List> futures = new ArrayList<>(threads); + + for (int t = 0; t < threads; t++) { + futures.add( + executor.submit( + () -> { + for (int i = 0; i < iterationsPerThread; i++) { + ChunkBufferPool.BufferPair pair = null; + try { + pair = pool().acquire(size); + // Verify invariants under concurrent load. + if (pair.chunkBuffer.position() != 0) errors.incrementAndGet(); + if (pair.compressedBuffer.position() != 0) errors.incrementAndGet(); + if (pair.chunkBuffer.capacity() != (int) size) errors.incrementAndGet(); + // Simulate work: advance position. + pair.chunkBuffer.put((byte) i); + } finally { + if (pair != null) pool().release(pair); + } + } + })); + } + + executor.shutdown(); + assertTrue(executor.awaitTermination(30, TimeUnit.SECONDS)); + + for (Future f : futures) { + f.get(); // rethrow any exception from worker threads + } + assertEquals("No invariant violations expected under concurrent load", 0, errors.get()); + } + + // ── Test 13: pool depth grows as more pairs are released ────────────────── + + @Test + public void testPoolDepthGrowsWithMultipleReleases() { + final long size = 128L; + final int count = 5; + List pairs = new ArrayList<>(count); + + // Acquire 5 distinct pairs. + for (int i = 0; i < count; i++) { + pairs.add(pool().acquire(size)); + } + // Verify all are distinct. + for (int i = 0; i < count; i++) { + for (int j = i + 1; j < count; j++) { + assertNotSame(pairs.get(i), pairs.get(j)); + } + } + + // Release all 5 back. + for (ChunkBufferPool.BufferPair p : pairs) pool().release(p); + + // Acquire all 5 again — they should all come from the pool (no fresh allocations). + List reacquired = new ArrayList<>(count); + for (int i = 0; i < count; i++) reacquired.add(pool().acquire(size)); + try { + for (ChunkBufferPool.BufferPair r : reacquired) { + assertTrue( + "Reacquired pair should be one of the originally released pairs", pairs.contains(r)); + } + } finally { + for (ChunkBufferPool.BufferPair r : reacquired) pool().release(r); + } + } +} diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/file/chunk/compressed/ChunkCompressedFileChannelWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/file/chunk/compressed/ChunkCompressedFileChannelWriterSuiteJ.java new file mode 100644 index 00000000000..fce0421712e --- /dev/null +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/file/chunk/compressed/ChunkCompressedFileChannelWriterSuiteJ.java @@ -0,0 +1,598 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.worker.storage.file.chunk.compressed; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; + +import java.io.*; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.*; + +import com.github.luben.zstd.ZstdInputStream; +import io.netty.buffer.*; +import org.junit.*; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.compression.ChunkCompressionContext; +import org.apache.celeborn.common.identity.UserIdentifier; +import org.apache.celeborn.common.meta.DiskFileInfo; +import org.apache.celeborn.common.meta.ReduceFileMeta; +import org.apache.celeborn.common.network.buffer.FileChunkBuffers; +import org.apache.celeborn.common.network.util.TransportConf; +import org.apache.celeborn.common.protocol.StorageInfo; +import org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkBufferPool; +import org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter; + +public class ChunkCompressedFileChannelWriterSuiteJ { + + private static ChunkBufferPool POOL; + + @BeforeClass + public static void setUpPool() { + POOL = new ChunkBufferPool(new CelebornConf()); + } + + @AfterClass + public static void tearDownPool() { + if (POOL != null) { + POOL.close(); + } + } + + // Small chunk size so tests can easily hit multi-chunk and large-record paths. + private static final int CHUNK_SIZE = 1024; + + private File tempFile; + private DiskFileInfo diskFileInfo; + private TransportConf transportConf; + + @Before + public void setup() throws Exception { + tempFile = File.createTempFile("chunk_writer_test", ".tmp"); + tempFile.deleteOnExit(); + diskFileInfo = makeDiskFileInfo(tempFile); + transportConf = mock(TransportConf.class); + } + + @After + public void teardown() { + tempFile.delete(); + } + + // ── Helpers ──────────────────────────────────────────────────────────────── + + private DiskFileInfo makeDiskFileInfo(File file) { + return new DiskFileInfo( + new UserIdentifier("tenant", "user"), + true, + new ReduceFileMeta(new ArrayList<>(Collections.singletonList(0L)), CHUNK_SIZE), + file.getAbsolutePath(), + StorageInfo.Type.HDD, + new ChunkCompressionContext(true, 1)); + } + + /** Wraps one or more strings as a CompositeByteBuf (one component per string). */ + private CompositeByteBuf composite(String... parts) { + CompositeByteBuf buf = Unpooled.compositeBuffer(); + for (String part : parts) { + buf.addComponent(true, Unpooled.wrappedBuffer(part.getBytes(StandardCharsets.UTF_8))); + } + return buf; + } + + /** Wraps a raw byte array as a single-component CompositeByteBuf. */ + private CompositeByteBuf compositeOf(byte[] data) { + CompositeByteBuf buf = Unpooled.compositeBuffer(); + buf.addComponent(true, Unpooled.wrappedBuffer(data)); + return buf; + } + + /** Returns a byte array of {@code count} repetitions of {@code s}. */ + private byte[] repeat(String s, int count) { + StringBuilder sb = new StringBuilder(s.length() * count); + for (int i = 0; i < count; i++) sb.append(s); + return sb.toString().getBytes(StandardCharsets.UTF_8); + } + + /** Decompresses one chunk's raw compressed bytes via ZstdInputStream. */ + private byte[] decompress(byte[] compressed) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (InputStream in = new ZstdInputStream(new ByteArrayInputStream(compressed))) { + byte[] tmp = new byte[4096]; + int n; + while ((n = in.read(tmp)) != -1) out.write(tmp, 0, n); + } + return out.toByteArray(); + } + + /** + * Reads every chunk from the file (using the updated ReduceFileMeta written by close()), + * decompresses compressed chunks and returns raw bytes for uncompressed ones (large records). + */ + private List readChunks() throws Exception { + FileChunkBuffers buffers = new FileChunkBuffers(diskFileInfo, transportConf); + int numChunks = buffers.numChunks(); + List chunkCompressed = diskFileInfo.getReduceFileMeta().getChunkCompressed(); + List result = new ArrayList<>(numChunks); + for (int i = 0; i < numChunks; i++) { + ByteBuffer buf = buffers.chunk(i, 0, Integer.MAX_VALUE).nioByteBuffer(); + byte[] data = new byte[buf.remaining()]; + buf.get(data); + boolean isCompressed = chunkCompressed != null && chunkCompressed.get(i); + result.add(isCompressed ? decompress(data) : data); + } + return result; + } + + /** Concatenates all decompressed chunks into one byte array. */ + private byte[] readAll() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + for (byte[] chunk : readChunks()) out.write(chunk); + return out.toByteArray(); + } + + // ── Test 1: multiple small buffers — all fit in one chunk ────────────────── + + @Test + public void testMultipleSmallBuffersProduceOneChunk() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + writer.write(composite("hello", " ", "world"), true); + writer.write(composite("foo", "bar"), true); + writer.write(composite("!"), true); + writer.close(true); + + assertEquals(1, diskFileInfo.getReduceFileMeta().getNumChunks()); + assertArrayEquals("hello worldfoobar!".getBytes(StandardCharsets.UTF_8), readAll()); + } + + // ── Test 2: many small buffers accumulate until overflow forces a new chunk ─ + + @Test + public void testSmallBuffersOverflowIntoSecondChunk() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + // First write nearly fills the chunk buffer (CHUNK_SIZE - 10 bytes). + byte[] first = repeat("A", CHUNK_SIZE - 10); + // Second write (50 bytes) overflows → first is flushed as chunk 1, second becomes chunk 2. + byte[] second = repeat("B", 50); + + writer.write(compositeOf(first), true); + writer.write(compositeOf(second), true); + writer.close(true); + + assertEquals(2, diskFileInfo.getReduceFileMeta().getNumChunks()); + List chunks = readChunks(); + assertArrayEquals(first, chunks.get(0)); + assertArrayEquals(second, chunks.get(1)); + } + + // ── Test 3: three sequential small writes spanning three chunks ───────────── + + @Test + public void testThreeSmallWritesThreeChunks() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + byte[] a = repeat("A", CHUNK_SIZE - 5); // nearly fills chunk 1 + byte[] b = repeat("B", CHUNK_SIZE - 5); // overflows → chunk 1 = a, b nearly fills chunk 2 + byte[] c = repeat("C", 20); // overflows chunk 2 → chunk 2 = b, c is chunk 3 + + writer.write(compositeOf(a), true); + writer.write(compositeOf(b), true); + writer.write(compositeOf(c), true); + writer.close(true); + + assertEquals(3, diskFileInfo.getReduceFileMeta().getNumChunks()); + List chunks = readChunks(); + assertArrayEquals(a, chunks.get(0)); + assertArrayEquals(b, chunks.get(1)); + assertArrayEquals(c, chunks.get(2)); + } + + // ── Test 4: write that exactly fills chunkBuffer triggers flush on next write ─ + + @Test + public void testWriteExactlyChunkSizeThenMore() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + byte[] exact = repeat("E", CHUNK_SIZE); // fills chunkBuffer to the brim + byte[] more = "trailing".getBytes(StandardCharsets.UTF_8); + + writer.write(compositeOf(exact), true); // no flush yet — buffer is full but not overflowed + writer.write(compositeOf(more), true); // triggers flush of exact; more accumulates + writer.close(true); // flushes more + + assertEquals(2, diskFileInfo.getReduceFileMeta().getNumChunks()); + List chunks = readChunks(); + assertArrayEquals(exact, chunks.get(0)); + assertArrayEquals(more, chunks.get(1)); + } + + // ── Test 5: large record with no preceding data ───────────────────────────── + + @Test + public void testLargeRecordAlone() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + // 3× chunkSize — well over the large-record threshold. + byte[] large = repeat("X", CHUNK_SIZE * 3); + writer.write(compositeOf(large), true); + writer.close(true); + + assertEquals(1, diskFileInfo.getReduceFileMeta().getNumChunks()); + assertArrayEquals(large, readAll()); + } + + // ── Test 6: large record just one byte over the threshold ────────────────── + + @Test + public void testLargeRecordBoundary() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + byte[] boundary = repeat("B", CHUNK_SIZE + 1); + writer.write(compositeOf(boundary), true); + writer.close(true); + + assertEquals(1, diskFileInfo.getReduceFileMeta().getNumChunks()); + assertArrayEquals(boundary, readAll()); + } + + // ── Test 7: small write pending, then large record → 2 chunks ────────────── + + @Test + public void testPendingSmallFlushedBeforeLargeRecord() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + byte[] small = "pending".getBytes(StandardCharsets.UTF_8); + byte[] large = repeat("L", CHUNK_SIZE * 2); + + writer.write(compositeOf(small), true); // accumulates in chunkBuffer + writer.write(compositeOf(large), true); // flushes pending small → chunk 1; large → chunk 2 + writer.close(true); + + assertEquals(2, diskFileInfo.getReduceFileMeta().getNumChunks()); + List chunks = readChunks(); + assertArrayEquals(small, chunks.get(0)); + assertArrayEquals(large, chunks.get(1)); + } + + // ── Test 8: two consecutive large records → 2 chunks ────────────────────── + + @Test + public void testTwoLargeRecords() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + byte[] large1 = repeat("P", CHUNK_SIZE * 2); + byte[] large2 = repeat("Q", CHUNK_SIZE * 3); + + writer.write(compositeOf(large1), true); + writer.write(compositeOf(large2), true); + writer.close(true); + + assertEquals(2, diskFileInfo.getReduceFileMeta().getNumChunks()); + List chunks = readChunks(); + assertArrayEquals(large1, chunks.get(0)); + assertArrayEquals(large2, chunks.get(1)); + } + + // ── Test 9: interleaved small / large / small → 3 chunks ────────────────── + + @Test + public void testSmallLargeSmallProducesThreeChunks() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + byte[] small1 = "before".getBytes(StandardCharsets.UTF_8); + byte[] large = repeat("M", CHUNK_SIZE * 2); + byte[] small2 = "after".getBytes(StandardCharsets.UTF_8); + + writer.write(compositeOf(small1), true); // accumulates → pending + writer.write(compositeOf(large), true); // flushes small1 as chunk 1; large → chunk 2 + writer.write(compositeOf(small2), true); // accumulates + writer.close(true); // flushes small2 as chunk 3 + + assertEquals(3, diskFileInfo.getReduceFileMeta().getNumChunks()); + List chunks = readChunks(); + assertArrayEquals(small1, chunks.get(0)); + assertArrayEquals(large, chunks.get(1)); + assertArrayEquals(small2, chunks.get(2)); + } + + // ── Test 10: large record followed by small writes ───────────────────────── + + @Test + public void testLargeRecordThenSmallWrites() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + byte[] large = repeat("R", CHUNK_SIZE * 2); + byte[] small = "tail".getBytes(StandardCharsets.UTF_8); + + writer.write(compositeOf(large), true); // large → chunk 1 + writer.write(compositeOf(small), true); // accumulates + writer.close(true); // flushes small → chunk 2 + + assertEquals(2, diskFileInfo.getReduceFileMeta().getNumChunks()); + List chunks = readChunks(); + assertArrayEquals(large, chunks.get(0)); + assertArrayEquals(small, chunks.get(1)); + } + + // ── Test 11: no writes at all → 0 chunks ────────────────────────────────── + + @Test + public void testNoWritesProducesZeroChunks() throws IOException { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + writer.close(true); + + assertEquals(0, diskFileInfo.getReduceFileMeta().getNumChunks()); + assertEquals(0L, diskFileInfo.getFileLength()); + } + + // ── Test 12: explicit compressAndFlush mid-stream splits chunks ───────────── + + @Test + public void testExplicitCompressAndFlushSplitsChunks() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + byte[] part1 = "first part".getBytes(StandardCharsets.UTF_8); + byte[] part2 = "second part".getBytes(StandardCharsets.UTF_8); + + writer.write(compositeOf(part1), true); + writer.compressAndFlush(); // explicitly close chunk 1 + writer.write(compositeOf(part2), true); + writer.close(true); // closes chunk 2 + + assertEquals(2, diskFileInfo.getReduceFileMeta().getNumChunks()); + List chunks = readChunks(); + assertArrayEquals(part1, chunks.get(0)); + assertArrayEquals(part2, chunks.get(1)); + } + + // ── Test 13: compressAndFlush on empty buffer is a no-op ────────────────── + + @Test + public void testCompressAndFlushOnEmptyBufferIsNoop() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + writer.compressAndFlush(); // empty — should not add a chunk + writer.compressAndFlush(); // again + writer.write(composite("data"), true); + writer.compressAndFlush(); // flushes "data" as chunk 1 + writer.compressAndFlush(); // empty again — should not add a chunk + writer.close(true); + + assertEquals(1, diskFileInfo.getReduceFileMeta().getNumChunks()); + assertArrayEquals("data".getBytes(StandardCharsets.UTF_8), readAll()); + } + + // ── Test 14: fileLength (bytesFlushed) reflects compressed file size ──────── + + @Test + public void testFileLengthMatchesActualFileSize() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + writer.write(composite("hello", " ", "world"), true); + writer.write(compositeOf(repeat("Z", CHUNK_SIZE * 2)), true); + writer.close(true); + + assertEquals(tempFile.length(), diskFileInfo.getFileLength()); + assertTrue("File should be non-empty", tempFile.length() > 0); + } + + // ── Test 15: composite buffer with many small components ────────────────── + + @Test + public void testCompositeBufferWithManyComponents() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + String[] words = {"alpha", " ", "beta", " ", "gamma", " ", "delta", " ", "epsilon"}; + writer.write(composite(words), true); + writer.close(true); + + String expected = String.join("", words); + assertEquals(1, diskFileInfo.getReduceFileMeta().getNumChunks()); + assertEquals(expected, new String(readAll(), StandardCharsets.UTF_8)); + } + + // ── Test 16: chunk offsets are strictly increasing ───────────────────────── + + @Test + public void testChunkOffsetsAreStrictlyIncreasing() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + writer.write(compositeOf(repeat("A", CHUNK_SIZE - 10)), true); + writer.write(compositeOf(repeat("B", 50)), true); // triggers chunk 1 flush + writer.write(compositeOf(repeat("C", CHUNK_SIZE * 2)), true); // large → chunk 3 + writer.close(true); + + List offsets = diskFileInfo.getReduceFileMeta().getChunkOffsets(); + assertEquals(4, offsets.size()); // [0, end1, end2, end3] + assertEquals(0L, (long) offsets.get(0)); + for (int i = 1; i < offsets.size(); i++) { + assertTrue( + "offset[" + i + "] must be > offset[" + (i - 1) + "]", + offsets.get(i) > offsets.get(i - 1)); + } + // Last offset must equal the actual file size. + assertEquals(tempFile.length(), (long) offsets.get(offsets.size() - 1)); + } + + // ── Test 17: large record with high-entropy data compresses and round-trips ─ + + @Test + public void testLargeRecordHighEntropyData() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = + new org.apache.celeborn.service.deploy.worker.file.chunk.compressed + .ChunkCompressedFileChannelWriter( + diskFileInfo, CHUNK_SIZE, ChunkCompressionContext.DEFAULT_COMPRESSION_LEVEL, POOL); + + // Pseudo-random high-entropy payload: harder to compress, exercises ZSTD's full path. + byte[] highEntropy = new byte[CHUNK_SIZE * 4]; + new java.util.Random(42).nextBytes(highEntropy); + + writer.write(compositeOf(highEntropy), true); + writer.close(true); + + assertEquals(1, diskFileInfo.getReduceFileMeta().getNumChunks()); + assertArrayEquals(highEntropy, readAll()); + } + + // ── Test 18: multiple small writes, one large record, more small writes ───── + // Exercises the three-phase pattern: + // chunk 1 = accumulated smalls flushed before the large record + // chunk 2 = the large record as its own ZSTD frame + // chunk 3 = trailing smalls flushed on close + // This is the canonical regression test for the "Unknown frame descriptor" bug + // where ZstdInputStream was recreated mid-frame on each fillBuffer() call. + + @Test + public void testMultipleSmallsLargeMultipleSmallsRoundTrip() throws Exception { + org.apache.celeborn.service.deploy.worker.file.chunk.compressed.ChunkCompressedFileChannelWriter + writer = new ChunkCompressedFileChannelWriter(diskFileInfo, CHUNK_SIZE, 3, POOL); + + // Phase 1: several small writes that accumulate together into chunk 1. + // Total = 6+6+1011 = 1023 bytes — just under CHUNK_SIZE (1024). + byte[] s1 = "alpha-".getBytes(StandardCharsets.UTF_8); // 6 bytes + byte[] s2 = "beta--".getBytes(StandardCharsets.UTF_8); // 6 bytes + byte[] s3 = repeat("C", CHUNK_SIZE - 13); // 1011 bytes + + // Phase 2: large record (3× chunkSize). + // Arriving here triggers compressAndFlush() for the pending smalls (chunk 1), + // then flushLargeRecord() writes the large data as chunk 2. + byte[] large = repeat("L", CHUNK_SIZE * 3); + + // Phase 3: a few more small writes that accumulate into chunk 3. + byte[] s4 = "delta-".getBytes(StandardCharsets.UTF_8); // 6 bytes + byte[] s5 = repeat("E", CHUNK_SIZE / 2); // 512 bytes + byte[] s6 = "zeta--".getBytes(StandardCharsets.UTF_8); // 6 bytes + + writer.write(compositeOf(s1), true); + writer.write(compositeOf(s2), true); + writer.write(compositeOf(s3), true); + writer.write(compositeOf(large), true); + writer.write(compositeOf(s4), true); + writer.write(compositeOf(s5), true); + writer.write(compositeOf(s6), true); + writer.close(true); + + assertEquals(3, diskFileInfo.getReduceFileMeta().getNumChunks()); + + List chunks = readChunks(); + + // Verify per-chunk content. + ByteArrayOutputStream expectedChunk1 = new ByteArrayOutputStream(); + expectedChunk1.write(s1); + expectedChunk1.write(s2); + expectedChunk1.write(s3); + assertArrayEquals( + "chunk 1 must contain all leading small writes", + expectedChunk1.toByteArray(), + chunks.get(0)); + + assertArrayEquals("chunk 2 must contain the large record verbatim", large, chunks.get(1)); + + ByteArrayOutputStream expectedChunk3 = new ByteArrayOutputStream(); + expectedChunk3.write(s4); + expectedChunk3.write(s5); + expectedChunk3.write(s6); + assertArrayEquals( + "chunk 3 must contain all trailing small writes", + expectedChunk3.toByteArray(), + chunks.get(2)); + + // Verify the flat concatenation across all chunks matches the original write order. + ByteArrayOutputStream all = new ByteArrayOutputStream(); + all.write(s1); + all.write(s2); + all.write(s3); + all.write(large); + all.write(s4); + all.write(s5); + all.write(s6); + assertArrayEquals( + "readAll() must reproduce all data in write order", all.toByteArray(), readAll()); + } +} diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/file/chunk/compressed/MmapMemoryManagerSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/file/chunk/compressed/MmapMemoryManagerSuiteJ.java new file mode 100644 index 00000000000..71b9296fb7d --- /dev/null +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/file/chunk/compressed/MmapMemoryManagerSuiteJ.java @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.worker.storage.file.chunk.compressed; + +import static org.junit.Assert.*; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.celeborn.service.deploy.worker.file.chunk.compressed.MmapMemoryManager; + +public class MmapMemoryManagerSuiteJ { + + private static MmapMemoryManager MANAGER; + + @BeforeClass + public static void setUpClass() { + MANAGER = new MmapMemoryManager(System.getProperty("java.io.tmpdir") + "/celeborn-mmap-test"); + } + + @AfterClass + public static void tearDownClass() { + if (MANAGER != null) { + MANAGER.close(); + } + } + + private MmapMemoryManager manager() { + return MANAGER; + } + + // ── Test 1: returned buffer is a direct ByteBuffer ───────────────────────── + + @Test + public void testAllocatedBufferIsDirect() { + assertTrue(manager().allocateBuffer(128).isDirect()); + } + + // ── Test 2: capacity equals the requested size ───────────────────────────── + + @Test + public void testAllocatedBufferCapacityMatchesRequestedSize() { + int[] sizes = {1, 7, 64, 256, 1024, 8192, 65536}; + for (int size : sizes) { + ByteBuffer buf = manager().allocateBuffer(size); + assertEquals("capacity for size " + size, size, buf.capacity()); + } + } + + // ── Test 3: slice starts at position=0, limit=capacity ───────────────────── + + @Test + public void testAllocatedBufferIsInClearState() { + int size = 512; + ByteBuffer buf = manager().allocateBuffer(size); + assertEquals(0, buf.position()); + assertEquals(size, buf.limit()); + assertEquals(size, buf.remaining()); + } + + // ── Test 4: buffer is writable — put advances position ────────────────────── + + @Test + public void testAllocatedBufferIsWritable() { + ByteBuffer buf = manager().allocateBuffer(64); + buf.put((byte) 0xAB); + buf.put((byte) 0xCD); + assertEquals(2, buf.position()); + } + + // ── Test 5: data round-trips correctly through the buffer ────────────────── + + @Test + public void testDataRoundTrips() { + int size = 1024; + byte[] data = new byte[size]; + new Random(42).nextBytes(data); + + ByteBuffer buf = manager().allocateBuffer(size); + buf.put(data); + assertEquals(0, buf.remaining()); + + buf.flip(); + byte[] readBack = new byte[size]; + buf.get(readBack); + assertArrayEquals(data, readBack); + } + + // ── Test 6: consecutive allocations do not overlap ───────────────────────── + // Write distinct patterns to two buffers and verify neither corrupts the other. + + @Test + public void testConsecutiveAllocationsDoNotOverlap() { + int size = 200; + ByteBuffer buf1 = manager().allocateBuffer(size); + ByteBuffer buf2 = manager().allocateBuffer(size); + + for (int i = 0; i < size; i++) buf1.put((byte) 0xAA); + for (int i = 0; i < size; i++) buf2.put((byte) 0xBB); + + buf1.flip(); + while (buf1.hasRemaining()) assertEquals((byte) 0xAA, buf1.get()); + + buf2.flip(); + while (buf2.hasRemaining()) assertEquals((byte) 0xBB, buf2.get()); + } + + // ── Test 7: adjacent writes don't spill into the neighboring allocation ───── + + @Test + public void testWriteToOneBufferDoesNotSpillIntoAdjacentBuffer() { + int size = 32; + ByteBuffer a = manager().allocateBuffer(size); + ByteBuffer b = manager().allocateBuffer(size); + + // Write 0xFF into every byte of a. + for (int i = 0; i < size; i++) a.put((byte) 0xFF); + + // Overwrite all of b with 0x00. + for (int i = 0; i < size; i++) b.put((byte) 0x00); + + // a must still contain 0xFF — b's writes must not have reached a. + a.flip(); + for (int i = 0; i < size; i++) { + assertEquals("byte " + i + " in a should be 0xFF after b was written", (byte) 0xFF, a.get()); + } + } + + // ── Test 8: buffer can be filled to exactly its capacity without overflow ─── + + @Test + public void testBufferCanBeFilledToCapacity() { + int size = 256; + ByteBuffer buf = manager().allocateBuffer(size); + byte[] full = new byte[size]; + new Random(7).nextBytes(full); + + buf.put(full); // must not throw + assertEquals(0, buf.remaining()); // buffer is exactly full + } + + // ── Test 9: many allocations of varying sizes all have correct properties ── + + @Test + public void testManyAllocationsOfVariousSizes() { + int[] sizes = {1, 3, 17, 100, 512, 4096, 32768}; + for (int size : sizes) { + ByteBuffer buf = manager().allocateBuffer(size); + assertEquals("capacity=" + size, size, buf.capacity()); + assertEquals("position=" + size, 0, buf.position()); + assertEquals("limit=" + size, size, buf.limit()); + assertTrue("direct=" + size, buf.isDirect()); + } + } + + // ── Test 10: sequential pattern survives put/get round-trip ──────────────── + + @Test + public void testSequentialPatternSurvivesRoundTrip() { + int size = 512; + ByteBuffer buf = manager().allocateBuffer(size); + + for (int i = 0; i < size; i++) buf.put((byte) (i & 0xFF)); + + buf.flip(); + for (int i = 0; i < size; i++) { + assertEquals("byte " + i, (byte) (i & 0xFF), buf.get()); + } + } + + // ── Test 11: concurrent allocations are thread-safe ──────────────────────── + + @Test + public void testConcurrentAllocationsAreSafe() throws Exception { + int threads = 8; + int perThread = 200; + int bufSize = 128; + AtomicInteger violations = new AtomicInteger(0); + + ExecutorService executor = Executors.newFixedThreadPool(threads); + List> futures = new ArrayList<>(threads); + + for (int t = 0; t < threads; t++) { + futures.add( + executor.submit( + () -> { + for (int i = 0; i < perThread; i++) { + ByteBuffer buf = manager().allocateBuffer(bufSize); + if (!buf.isDirect()) violations.incrementAndGet(); + if (buf.capacity() != bufSize) violations.incrementAndGet(); + if (buf.position() != 0) violations.incrementAndGet(); + if (buf.limit() != bufSize) violations.incrementAndGet(); + // Write and read back a sentinel byte to exercise the mapping. + buf.put((byte) 0x5A); + buf.flip(); + if (buf.get() != (byte) 0x5A) violations.incrementAndGet(); + } + })); + } + + executor.shutdown(); + assertTrue(executor.awaitTermination(30, TimeUnit.SECONDS)); + for (Future f : futures) f.get(); // surface any thread-level exception + assertEquals("no invariant violations under concurrent load", 0, violations.get()); + } + + // ── Test 12: concurrent writes to different buffers don't corrupt each other ─ + + @Test + public void testConcurrentWritesToDistinctBuffersAreIsolated() throws Exception { + int threads = 4; + int size = 256; + + // Pre-allocate one buffer per thread. + List bufs = new ArrayList<>(threads); + for (int i = 0; i < threads; i++) bufs.add(manager().allocateBuffer(size)); + + ExecutorService executor = Executors.newFixedThreadPool(threads); + List> futures = new ArrayList<>(threads); + + for (int t = 0; t < threads; t++) { + final ByteBuffer buf = bufs.get(t); + final byte marker = (byte) (t + 1); + futures.add( + executor.submit( + () -> { + for (int i = 0; i < size; i++) buf.put(marker); + buf.flip(); + for (int i = 0; i < size; i++) { + if (buf.get() != marker) return false; + } + return true; + })); + } + + executor.shutdown(); + assertTrue(executor.awaitTermination(30, TimeUnit.SECONDS)); + for (Future f : futures) { + assertTrue("each thread's buffer should contain only its own marker", f.get()); + } + } + + // ── Test 13: close() resets state; subsequent allocations succeed ──────────── + + @Test + public void testCloseResetsStateAndNewAllocationsSucceed() { + MmapMemoryManager local = + new MmapMemoryManager(System.getProperty("java.io.tmpdir") + "/celeborn-mmap-test-close"); + try { + ByteBuffer before = local.allocateBuffer(64); + assertNotNull(before); + + local.close(); + + // After close, the next allocation must create a new backing file and succeed. + ByteBuffer after = local.allocateBuffer(256); + assertNotNull(after); + assertEquals(256, after.capacity()); + assertEquals(0, after.position()); + assertEquals(256, after.limit()); + assertTrue(after.isDirect()); + + after.put((byte) 0x42); + after.flip(); + assertEquals((byte) 0x42, after.get()); + } finally { + local.close(); + } + } +} diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskMapPartitionDataWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskMapPartitionDataWriterSuiteJ.java index 5f750f89540..0dd5ee06984 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskMapPartitionDataWriterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskMapPartitionDataWriterSuiteJ.java @@ -39,6 +39,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.compression.ChunkCompressionContext; import org.apache.celeborn.common.identity.UserIdentifier; import org.apache.celeborn.common.network.util.NettyUtils; import org.apache.celeborn.common.network.util.TransportConf; @@ -133,7 +134,8 @@ public void testMultiThreadWrite() throws IOException { userIdentifier, PartitionType.MAP, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter fileWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java index 536b4aab364..6999ced4376 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/local/DiskReducePartitionDataWriterSuiteJ.java @@ -50,6 +50,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.compression.ChunkCompressionContext; import org.apache.celeborn.common.identity.UserIdentifier; import org.apache.celeborn.common.meta.DiskFileInfo; import org.apache.celeborn.common.meta.FileInfo; @@ -280,7 +281,8 @@ public void testMultiThreadWrite() throws IOException, ExecutionException, Inter userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -334,7 +336,8 @@ public void testMultiThreadWriteDuringClose() userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -389,7 +392,8 @@ public void testAfterStressfulWriteWillReadCorrect() userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -459,7 +463,8 @@ public void testWriteAndChunkRead() throws Exception { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -578,7 +583,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -610,7 +616,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -642,7 +649,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -673,7 +681,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -706,7 +715,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -738,7 +748,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -772,7 +783,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -805,7 +817,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( @@ -839,7 +852,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareDiskFileTestEnvironment( diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java index 92d0fd5d416..f6d32c6e269 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/memory/MemoryReducePartitionDataWriterSuiteJ.java @@ -41,6 +41,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.compression.ChunkCompressionContext; import org.apache.celeborn.common.identity.UserIdentifier; import org.apache.celeborn.common.meta.*; import org.apache.celeborn.common.metrics.source.AbstractSource; @@ -295,7 +296,8 @@ public void testMultiThreadWrite() throws IOException, ExecutionException, Inter userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter partitionDataWriter = new PartitionDataWriter( @@ -350,7 +352,8 @@ public void testMultiThreadWriteDuringClose() userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryFileTestEnvironment( @@ -406,7 +409,8 @@ public void testAfterStressfulWriteWillReadCorrect() userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryEvictEnvironment( @@ -467,7 +471,8 @@ public void testWriteAndChunkRead() throws Exception { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryEvictEnvironment( @@ -555,7 +560,8 @@ public void testEvictAndChunkRead() throws Exception { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryEvictEnvironment( @@ -685,7 +691,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); PartitionDataWriter partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryFileTestEnvironment( @@ -717,7 +724,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryFileTestEnvironment( @@ -749,7 +757,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryFileTestEnvironment( @@ -780,7 +789,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryFileTestEnvironment( @@ -813,7 +823,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryFileTestEnvironment( @@ -845,7 +856,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryFileTestEnvironment( @@ -879,7 +891,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryFileTestEnvironment( @@ -912,7 +925,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryFileTestEnvironment( @@ -946,7 +960,8 @@ public void testChunkSize() throws IOException { userIdentifier, PartitionType.REDUCE, false, - false); + false, + ChunkCompressionContext.disabled()); partitionDataWriter = new PartitionDataWriter( PartitionDataWriterSuiteUtils.prepareMemoryFileTestEnvironment( diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ChunkCompressedReadWriteTest.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ChunkCompressedReadWriteTest.scala new file mode 100644 index 00000000000..04dca22f0da --- /dev/null +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ChunkCompressedReadWriteTest.scala @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.service.deploy.cluster + +import java.io.ByteArrayOutputStream +import java.nio.charset.StandardCharsets +import java.util.UUID + +import scala.collection.mutable + +import org.apache.commons.lang3.RandomStringUtils +import org.junit.Assert +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl} +import org.apache.celeborn.client.read.MetricsCallback +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.identity.UserIdentifier +import org.apache.celeborn.common.internal.Logging +import org.apache.celeborn.common.protocol.CompressionCodec +import org.apache.celeborn.service.deploy.MiniClusterFeature + +/** + * End-to-end read/write tests with chunk-level compression enabled + * (celeborn.chunk.compression.enabled = true). + * + * Each test runs against a live mini-cluster, pushes several batches of data, + * commits the shuffle, then reads back and verifies byte-for-byte correctness. + * Scenarios cover: + * - Different batch-level codecs (NONE, LZ4, ZSTD) layered under chunk ZSTD + * - Small chunk size to exercise multi-chunk boundary handling + * - Local-read path (LocalPartitionReader) with chunk compression + */ +class ChunkCompressedReadWriteTest extends AnyFunSuite + with Logging with MiniClusterFeature with BeforeAndAfterAll { + + var masterPort = 0 + + override def beforeAll(): Unit = { + logInfo("ChunkCompressedReadWriteTest: starting mini-cluster") + val (m, _) = setupMiniClusterWithRandomPorts() + masterPort = m.conf.masterPort + } + + override def afterAll(): Unit = { + logInfo("ChunkCompressedReadWriteTest: stopping mini-cluster") + shutdownMiniCluster() + } + + // ── Core helper ───────────────────────────────────────────────────────────── + + /** + * Pushes four variable-length data blobs to the cluster (two via pushData, + * two via mergeData), commits, then reads back all bytes from partition 0 of + * shuffle 1 and asserts that both the total length and per-blob content match. + * + * @param codec batch-level compression codec (may be NONE) + * @param readLocal whether to use the local-read short-circuit path + * @param shuffleChunkSz chunk size for the chunk-compressed writer (e.g. "8k", "1m") + */ + private def doReadWriteWithChunkCompression( + codec: CompressionCodec, + readLocal: Boolean = false, + shuffleChunkSz: String = "8m"): Unit = { + + val APP = s"app-chunk-${codec.name}-local$readLocal" + UUID.randomUUID() + + val clientConf = new CelebornConf() + .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$masterPort") + // Enable chunk-level ZSTD compression on the worker writer side and + // the ZSTD decompression in CelebornInputStream on the reader side. + .set(CelebornConf.CHUNK_COMPRESSION_ENABLED.key, "true") + // Batch-level codec is independent — NONE means raw batches inside the + // ZSTD chunk; LZ4/ZSTD means batch-compressed payloads inside the chunk. + .set(CelebornConf.SHUFFLE_COMPRESSION_CODEC.key, codec.name) + .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "true") + .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K") + .set(CelebornConf.READ_LOCAL_SHUFFLE_FILE, readLocal) + // Controls the accumulation buffer in ChunkCompressedFileChannelWriter. + .set(CelebornConf.SHUFFLE_CHUNK_SIZE.key, shuffleChunkSz) + .set("celeborn.data.io.numConnectionsPerPeer", "1") + + val lifecycleManager = new LifecycleManager(APP, clientConf) + val shuffleClient = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) + shuffleClient.setupLifecycleManagerRef(lifecycleManager.self) + + try { + // ── Write phase ────────────────────────────────────────────────────── + // Each string is prefixed with a 6-char sentinel so we can identify + // and verify individual blobs in the combined read output. + val dataPrefix = Array("000000", "111111", "222222", "333333") + val dataPrefixMap = new mutable.HashMap[String, String] + + val STR1 = dataPrefix(0) + RandomStringUtils.random(1024) + dataPrefixMap.put(dataPrefix(0), STR1) + val DATA1 = STR1.getBytes(StandardCharsets.UTF_8) + val dataSize1 = shuffleClient.pushData(1, 0, 0, 0, DATA1, 0, DATA1.length, 1, 1) + logInfo(s"pushData #1 size=$dataSize1") + + val STR2 = dataPrefix(1) + RandomStringUtils.random(32 * 1024) + dataPrefixMap.put(dataPrefix(1), STR2) + val DATA2 = STR2.getBytes(StandardCharsets.UTF_8) + val dataSize2 = shuffleClient.pushData(1, 0, 0, 0, DATA2, 0, DATA2.length, 1, 1) + logInfo(s"pushData #2 size=$dataSize2") + + val STR3 = dataPrefix(2) + RandomStringUtils.random(32 * 1024) + dataPrefixMap.put(dataPrefix(2), STR3) + val DATA3 = STR3.getBytes(StandardCharsets.UTF_8) + shuffleClient.mergeData(1, 0, 0, 0, DATA3, 0, DATA3.length, 1, 1) + + val STR4 = dataPrefix(3) + RandomStringUtils.random(16 * 1024) + dataPrefixMap.put(dataPrefix(3), STR4) + val DATA4 = STR4.getBytes(StandardCharsets.UTF_8) + shuffleClient.mergeData(1, 0, 0, 0, DATA4, 0, DATA4.length, 1, 1) + + shuffleClient.pushMergedData(1, 0, 0) + Thread.sleep(1000) + shuffleClient.mapperEnd(1, 0, 0, 1, 1) + + // ── Read phase ────────────────────────────────────────────────────── + val metricsCallback = new MetricsCallback { + override def incBytesRead(bytesWritten: Long): Unit = {} + override def incReadTime(time: Long): Unit = {} + } + + val inputStream = shuffleClient.readPartition( + 1, + 1, + 0, + 0, + 0, + 0, + Integer.MAX_VALUE, + null, + null, + null, + null, + null, + null, + metricsCallback, + true) + + val outputStream = new ByteArrayOutputStream() + var b = inputStream.read() + while (b != -1) { + outputStream.write(b) + b = inputStream.read() + } + + val readBytes = outputStream.toByteArray + val expectedTotal = DATA1.length + DATA2.length + DATA3.length + DATA4.length + + // ── Assertions ─────────────────────────────────────────────────────── + Assert.assertEquals( + s"Total byte count mismatch (codec=$codec, readLocal=$readLocal, chunkSz=$shuffleChunkSz)", + expectedTotal, + readBytes.length) + + val readStringMap = extractBlobs(readBytes, dataPrefix, dataPrefixMap) + for ((prefix, actual) <- readStringMap) { + Assert.assertEquals( + s"Content mismatch for blob '$prefix'", + dataPrefixMap(prefix), + actual) + } + + } finally { + Thread.sleep(3000L) + shuffleClient.shutdown() + lifecycleManager.rpcEnv.shutdown() + } + } + + /** + * Rebuilds the per-blob strings from the flat read output by scanning for + * known 6-char prefixes and extracting the expected number of characters. + */ + private def extractBlobs( + readBytes: Array[Byte], + prefixes: Array[String], + prefixMap: mutable.HashMap[String, String]): mutable.HashMap[String, String] = { + var remaining = new String(readBytes, StandardCharsets.UTF_8) + val result = new mutable.HashMap[String, String] + while (remaining.nonEmpty) { + prefixes.find(remaining.startsWith) match { + case Some(prefix) => + val len = prefixMap(prefix).length + result.put(prefix, remaining.substring(0, len)) + remaining = remaining.substring(len) + case None => + remaining = "" + } + } + result + } + + /** + * Pushes data in three phases with a 2 KB chunk size: + * Phase 1 — 3 small batches (500 B each; 516 B on disk with header → 1548 B total) + * These accumulate in the chunk buffer (< 2048 B) without flushing. + * Phase 2 — 1 large batch (3000 B; 3016 B on disk > 2048 B chunk size) + * Arrival flushes phase-1 data as chunk 1 via compressAndFlush(), then + * writes the large batch as its own chunk 2 via flushLargeRecord(). + * Phase 3 — 3 more small batches (same size; 1548 B total) + * Accumulate and are flushed as chunk 3 on close(). + * + * This exercises: + * - Multiple batches compressed together in a single ZSTD chunk (chunks 1 and 3), + * which requires ZstdInputStream to be kept alive across fillBuffer() calls. + * - The large-record path where one batch is larger than the chunk size. + */ + private def doSmallLargeSmallReadWrite(): Unit = { + val APP = "app-chunk-small-large-small" + + val clientConf = new CelebornConf() + .set(CelebornConf.MASTER_ENDPOINTS.key, s"localhost:$masterPort") + .set(CelebornConf.CHUNK_COMPRESSION_ENABLED.key, "true") + .set(CelebornConf.SHUFFLE_COMPRESSION_CODEC.key, CompressionCodec.NONE.name) + .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "true") + .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K") + // 2 KB chunk size: small batches (516 B each) accumulate; large batch (3016 B) overflows. + .set(CelebornConf.SHUFFLE_CHUNK_SIZE.key, "2k") + .set("celeborn.data.io.numConnectionsPerPeer", "1") + + val lifecycleManager = new LifecycleManager(APP, clientConf) + val shuffleClient = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) + shuffleClient.setupLifecycleManagerRef(lifecycleManager.self) + + try { + // 6-char alphanumeric prefixes — unique, non-overlapping. + // RandomStringUtils.random(N, true, true) → exactly N ASCII bytes. + val dataPrefix = Array("SMLL1-", "SMLL2-", "SMLL3-", "LARGE-", "SMLL4-", "SMLL5-", "SMLL6-") + val dataPrefixMap = new mutable.HashMap[String, String] + + // Phase 1: three small batches (500 B each → 516 B on disk with 16-B header). + // Combined 1548 B < 2048 B chunk size — all sit in the chunk buffer together. + val STR1 = dataPrefix(0) + RandomStringUtils.random(494, true, true) + dataPrefixMap.put(dataPrefix(0), STR1) + val DATA1 = STR1.getBytes(StandardCharsets.UTF_8) + shuffleClient.pushData(1, 0, 0, 0, DATA1, 0, DATA1.length, 1, 1) + + val STR2 = dataPrefix(1) + RandomStringUtils.random(494, true, true) + dataPrefixMap.put(dataPrefix(1), STR2) + val DATA2 = STR2.getBytes(StandardCharsets.UTF_8) + shuffleClient.pushData(1, 0, 0, 0, DATA2, 0, DATA2.length, 1, 1) + + val STR3 = dataPrefix(2) + RandomStringUtils.random(494, true, true) + dataPrefixMap.put(dataPrefix(2), STR3) + val DATA3 = STR3.getBytes(StandardCharsets.UTF_8) + shuffleClient.pushData(1, 0, 0, 0, DATA3, 0, DATA3.length, 1, 1) + + // Phase 2: one large batch (3000 B → 3016 B on disk > 2048 B chunk size). + // Triggers compressAndFlush() of the phase-1 smalls as chunk 1, + // then flushLargeRecord() writes this batch alone as chunk 2. + val STR4 = dataPrefix(3) + RandomStringUtils.random(2994, true, true) + dataPrefixMap.put(dataPrefix(3), STR4) + val DATA4 = STR4.getBytes(StandardCharsets.UTF_8) + shuffleClient.pushData(1, 0, 0, 0, DATA4, 0, DATA4.length, 1, 1) + + // Phase 3: three more small batches that accumulate as chunk 3 and are flushed on close(). + val STR5 = dataPrefix(4) + RandomStringUtils.random(494, true, true) + dataPrefixMap.put(dataPrefix(4), STR5) + val DATA5 = STR5.getBytes(StandardCharsets.UTF_8) + shuffleClient.pushData(1, 0, 0, 0, DATA5, 0, DATA5.length, 1, 1) + + val STR6 = dataPrefix(5) + RandomStringUtils.random(494, true, true) + dataPrefixMap.put(dataPrefix(5), STR6) + val DATA6 = STR6.getBytes(StandardCharsets.UTF_8) + shuffleClient.pushData(1, 0, 0, 0, DATA6, 0, DATA6.length, 1, 1) + + val STR7 = dataPrefix(6) + RandomStringUtils.random(494, true, true) + dataPrefixMap.put(dataPrefix(6), STR7) + val DATA7 = STR7.getBytes(StandardCharsets.UTF_8) + shuffleClient.pushData(1, 0, 0, 0, DATA7, 0, DATA7.length, 1, 1) + + Thread.sleep(1000) + shuffleClient.mapperEnd(1, 0, 0, 1, 1) + + val metricsCallback = new MetricsCallback { + override def incBytesRead(bytesWritten: Long): Unit = {} + override def incReadTime(time: Long): Unit = {} + } + + val inputStream = shuffleClient.readPartition( + 1, + 1, + 0, + 0, + 0, + 0, + Integer.MAX_VALUE, + null, + null, + null, + null, + null, + null, + metricsCallback, + true) + + val outputStream = new ByteArrayOutputStream() + var b = inputStream.read() + while (b != -1) { + outputStream.write(b) + b = inputStream.read() + } + + val readBytes = outputStream.toByteArray + val expectedTotal = + DATA1.length + DATA2.length + DATA3.length + DATA4.length + + DATA5.length + DATA6.length + DATA7.length + + Assert.assertEquals( + "Total byte count mismatch (small-large-small interleave)", + expectedTotal, + readBytes.length) + + val readStringMap = extractBlobs(readBytes, dataPrefix, dataPrefixMap) + for ((prefix, actual) <- readStringMap) { + Assert.assertEquals( + s"Content mismatch for blob '$prefix'", + dataPrefixMap(prefix), + actual) + } + + } finally { + Thread.sleep(3000L) + shuffleClient.shutdown() + lifecycleManager.rpcEnv.shutdown() + } + } + + // ── Test cases ─────────────────────────────────────────────────────────────── + + // 1. Pure chunk ZSTD — no batch-level compression. + // Simplest configuration: the chunk writer compresses raw batches. + test("chunk compression with NONE batch codec") { + doReadWriteWithChunkCompression(CompressionCodec.NONE) + } + + // 2. Chunk ZSTD wrapping LZ4-compressed batches. + // Verifies that CelebornInputStream correctly decompresses the chunk first + // then hands each batch to the LZ4 Decompressor. + test("chunk compression with LZ4 batch codec") { + doReadWriteWithChunkCompression(CompressionCodec.LZ4) + } + + // 3. Chunk ZSTD wrapping ZSTD-compressed batches (two layers of ZSTD). + // Both chunkCompressed and shouldDecompress paths are active simultaneously. + test("chunk compression with ZSTD batch codec") { + doReadWriteWithChunkCompression(CompressionCodec.ZSTD) + } + + // 4. Small chunk size (8 KB) forces many chunk flushes across the data set, + // exercising the multi-chunk offset tracking and boundary handling. + test("chunk compression with small chunk size produces multiple chunks") { + doReadWriteWithChunkCompression(CompressionCodec.NONE, shuffleChunkSz = "8k") + } + +// // 5. Same small-chunk scenario with LZ4 batches. +// test("chunk compression + LZ4 batch codec with small chunk size") { +// doReadWriteWithChunkCompression(CompressionCodec.LZ4, shuffleChunkSz = "8k") +// } + + // 6. Local-read path (LocalPartitionReader) with chunk compression. + // Verifies that the chunk-compressed file is correctly decompressed when + // read directly from disk rather than through the network fetch path. + test("chunk compression with local shuffle read") { + doReadWriteWithChunkCompression(CompressionCodec.NONE, readLocal = true) + } + + // 7. Local read + LZ4 batch codec. + test("chunk compression with local shuffle read and LZ4 batch codec") { + doReadWriteWithChunkCompression(CompressionCodec.LZ4, readLocal = true) + } + + // 8. Small batches → large record → more small batches. + // Validates that ZstdInputStream is kept alive across multiple fillBuffer() calls + // within a single chunk (chunks 1 and 3 each hold 3 batches), and that the + // large-record ZSTD frame in chunk 2 round-trips without corruption. + test("chunk compression: multiple small batches, one large record, then more small batches") { + doSmallLargeSmallReadWrite() + } +} diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala index 26a1cb1b6d2..6e2dcf113f9 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/WorkerSuite.scala @@ -34,6 +34,7 @@ import org.scalatest.funsuite.AnyFunSuite import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.client.MasterClient +import org.apache.celeborn.common.compression.ChunkCompressionContext import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.protocol._ import org.apache.celeborn.common.protocol.message.ControlMessages.CommitFilesResponse @@ -79,7 +80,8 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach with MiniClusterFe PartitionSplitMode.SOFT, PartitionType.REDUCE, true, - new UserIdentifier("1", "2")) + new UserIdentifier("1", "2"), + ChunkCompressionContext.disabled()) worker.storageManager.createPartitionDataWriter( "2", 2, @@ -88,7 +90,8 @@ class WorkerSuite extends AnyFunSuite with BeforeAndAfterEach with MiniClusterFe PartitionSplitMode.SOFT, PartitionType.REDUCE, true, - new UserIdentifier("1", "2")) + new UserIdentifier("1", "2"), + ChunkCompressionContext.disabled()) Assert.assertEquals(1, worker.storageManager.workingDirWriters.values().size()) val expiredShuffleKeys = new JHashSet[String]() diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/PartitionMetaHandlerSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/PartitionMetaHandlerSuite.scala index ff62cb9d553..428d7d30fd5 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/PartitionMetaHandlerSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/PartitionMetaHandlerSuite.scala @@ -24,6 +24,7 @@ import java.nio.file.Files import io.netty.buffer.{ByteBuf, UnpooledByteBufAllocator} import org.apache.celeborn.CelebornFunSuite +import org.apache.celeborn.common.compression.ChunkCompressionContext import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.meta.{DiskFileInfo, MapFileMeta, ReduceFileMeta} import org.apache.celeborn.common.protocol._ @@ -43,7 +44,8 @@ class PartitionMetaHandlerSuite extends CelebornFunSuite with MockitoHelper { true, fileMeta, tmpFilePath.toString, - StorageInfo.Type.HDD) + StorageInfo.Type.HDD, + ChunkCompressionContext.disabled()) val mapMetaHandler = new MapPartitionMetaHandler(diskFileInfo, notifier) val pbPushDataHandShake = @@ -108,7 +110,8 @@ class PartitionMetaHandlerSuite extends CelebornFunSuite with MockitoHelper { true, fileMeta, tmpFilePath.toString, - StorageInfo.Type.HDD) + StorageInfo.Type.HDD, + ChunkCompressionContext.disabled()) val handler1 = new ReducePartitionMetaHandler(true, diskFileInfo) handler1.beforeWrite(generateSparkFormatData(byteBufAllocator, 0)) @@ -153,7 +156,8 @@ class PartitionMetaHandlerSuite extends CelebornFunSuite with MockitoHelper { true, fileMeta, tmpFilePath.toString, - StorageInfo.Type.HDD) + StorageInfo.Type.HDD, + ChunkCompressionContext.disabled()) val mapMetaHandler = new SegmentMapPartitionMetaHandler(diskFileInfo, notifier) val pbPushDataHandShake = diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManagerSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManagerSuite.scala index 6107faf986a..0e5e84628dd 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManagerSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManagerSuite.scala @@ -27,6 +27,7 @@ import org.mockito.stubbing.Stubber import org.apache.celeborn.CelebornFunSuite import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.CelebornConf.{WORKER_DISK_RESERVE_SIZE, WORKER_GRACEFUL_SHUTDOWN_ENABLED, WORKER_GRACEFUL_SHUTDOWN_RECOVER_PATH, WORKER_STORAGE_DIRS} +import org.apache.celeborn.common.compression.ChunkCompressionContext import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.meta.{DiskInfo, DiskStatus} import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType, StorageInfo} @@ -136,7 +137,8 @@ class StorageManagerSuite extends CelebornFunSuite with MockitoHelper { "myFile", new UserIdentifier("t1", "u1"), PartitionType.REDUCE, - partitionSplitEnabled = false) + partitionSplitEnabled = false, + chunkCompressionContext = ChunkCompressionContext.disabled()) fail("Should throw IOException when disks are full") } catch { case e: IOException => diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriterSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriterSuite.scala index ee6903ddf66..3443a513145 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriterSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriterSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.compression.ChunkCompressionContext import org.apache.celeborn.common.exception.AlreadyClosedException import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.meta.{DiskFileInfo, MemoryFileInfo, ReduceFileMeta} @@ -69,7 +70,8 @@ class TierWriterSuite extends AnyFunSuite with BeforeAndAfterEach { userIdentifier, PartitionType.REDUCE, false, - false) + false, + ChunkCompressionContext.disabled()) val source = new WorkerSource(celebornConf) @@ -184,7 +186,13 @@ class TierWriterSuite extends AnyFunSuite with BeforeAndAfterEach { val userIdentifier = UserIdentifier("`aa`.`bb`") val tmpFile = Files.createTempFile("celeborn", "local-test").toString val diskFileInfo = - new DiskFileInfo(userIdentifier, false, reduceFileMeta, tmpFile, StorageInfo.Type.HDD) + new DiskFileInfo( + userIdentifier, + false, + reduceFileMeta, + tmpFile, + StorageInfo.Type.HDD, + ChunkCompressionContext.disabled()) val numPendingWriters = new AtomicInteger() val flushNotifier = new FlushNotifier() val source = new WorkerSource(celebornConf) @@ -208,7 +216,8 @@ class TierWriterSuite extends AnyFunSuite with BeforeAndAfterEach { userIdentifier, PartitionType.REDUCE, false, - false) + false, + ChunkCompressionContext.disabled()) val flusher = new LocalFlusher( source, diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase1.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase1.scala index bdf9ce7cc45..abf088b8426 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase1.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase1.scala @@ -67,6 +67,7 @@ class StoragePolicyCase1 extends CelebornFunSuite { any(), any(), any(), + any(), any())).thenAnswer((mockedFlusher, mockedDiskFile, mockedFile)) val memoryHintPartitionLocation = diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase2.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase2.scala index 9dcec7e524b..fbfc41d50fa 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase2.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase2.scala @@ -67,6 +67,7 @@ class StoragePolicyCase2 extends CelebornFunSuite { any(), any(), any(), + any(), any())).thenAnswer((mockedFlusher, mockedDiskFile, mockedFile)) val memoryHintPartitionLocation = diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase3.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase3.scala index 8f21a7f4cbc..b84d7e07c4d 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase3.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase3.scala @@ -67,6 +67,7 @@ class StoragePolicyCase3 extends CelebornFunSuite { any(), any(), any(), + any(), any())).thenAnswer((mockedFlusher, mockedDiskFile, mockedFile)) val memoryHintPartitionLocation = diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase4.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase4.scala index dc321738e74..6ee1fa5f721 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase4.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/storagePolicy/StoragePolicyCase4.scala @@ -67,6 +67,7 @@ class StoragePolicyCase4 extends CelebornFunSuite { any(), any(), any(), + any(), any())).thenAnswer((mockedFlusher, mockedDiskFile, mockedFile)) val memoryHintPartitionLocation =