From fa4686940293ff532314390c99da68cf742352f5 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Fri, 15 May 2026 11:57:44 -0700 Subject: [PATCH 01/14] [CELEBORN-2329][CIP22] Encryption at Rest Spark Impl --- .../shuffle/celeborn/SparkCommonUtils.java | 21 +++ .../shuffle/celeborn/SparkCryptoHandler.java | 68 +++++++++ .../celeborn/SparkCryptoHandlerSuiteJ.java | 136 ++++++++++++++++++ client-spark/spark-2-shaded/pom.xml | 1 + .../shuffle/celeborn/SparkShuffleManager.java | 6 +- .../celeborn/CelebornShuffleReader.scala | 8 +- .../CelebornColumnarShuffleReader.scala | 9 +- .../CelebornColumnarShuffleReaderSuite.scala | 8 +- client-spark/spark-3-shaded/pom.xml | 1 + .../shuffle/celeborn/SparkShuffleManager.java | 9 +- .../spark/shuffle/celeborn/SparkUtils.java | 11 +- .../celeborn/CelebornShuffleReader.scala | 35 ++++- .../celeborn/CelebornShuffleReaderSuite.scala | 2 +- client/pom.xml | 4 + .../celeborn/client/DummyShuffleClient.java | 5 + .../apache/celeborn/client/ShuffleClient.java | 17 +++ .../celeborn/client/ShuffleClientImpl.java | 28 +++- .../client/read/CelebornInputStream.java | 48 +++++-- .../client/security/CryptoHandler.java | 26 ++++ 19 files changed, 414 insertions(+), 29 deletions(-) create mode 100644 client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java create mode 100644 client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java create mode 100644 client/src/main/java/org/apache/celeborn/client/security/CryptoHandler.java diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java index 84d74f8c145..697f86776ba 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java @@ -19,11 +19,17 @@ import java.util.Collections; import java.util.Map; +import java.util.Optional; + +import scala.Option; import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; +import org.apache.spark.internal.config.package$; import org.apache.spark.memory.SparkOutOfMemoryError; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.reflect.DynConstructors; import org.apache.celeborn.reflect.DynMethods; @@ -96,4 +102,19 @@ public static void throwSparkOutOfMemoryError() { } } } + + public static Optional getCryptoHandler(SparkConf conf) { + if (!(Boolean) conf.get(package$.MODULE$.IO_ENCRYPTION_ENABLED())) { + return Optional.empty(); + } + SparkEnv env = SparkEnv.get(); + if (env == null) { + return Optional.empty(); + } + Option key = env.securityManager().getIOEncryptionKey(); + if (!key.isDefined()) { + return Optional.empty(); + } + return Optional.of(new SparkCryptoHandler(conf, key.get())); + } } diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java new file mode 100644 index 00000000000..c7385947290 --- /dev/null +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import org.apache.spark.SparkConf; +import org.apache.spark.security.CryptoStreamUtils; + +import org.apache.celeborn.client.security.CryptoHandler; + +public class SparkCryptoHandler implements CryptoHandler { + private final SparkConf sparkConf; + private final byte[] key; + + public SparkCryptoHandler(SparkConf sparkConf, byte[] key) { + this.sparkConf = sparkConf; + this.key = key; + } + + @Override + public byte[] encrypt(byte[] input, int offset, int length) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + dos.writeInt(length); + try (OutputStream cos = CryptoStreamUtils.createCryptoOutputStream(dos, sparkConf, key)) { + cos.write(input, offset, length); + } + return baos.toByteArray(); + } + + @Override + public byte[] decrypt(byte[] input, int offset, int length) throws IOException { + ByteArrayInputStream bais = new ByteArrayInputStream(input, offset, length); + DataInputStream dis = new DataInputStream(bais); + int decryptedLength = dis.readInt(); + if (decryptedLength < 0) { + throw new IOException( + "Invalid decrypted length: " + decryptedLength + ", encrypted length: " + length); + } + try (DataInputStream cis = + new DataInputStream(CryptoStreamUtils.createCryptoInputStream(dis, sparkConf, key))) { + byte[] decrypted = new byte[decryptedLength]; + cis.readFully(decrypted); + return decrypted; + } + } +} diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java new file mode 100644 index 00000000000..898cefb3109 --- /dev/null +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn; + +import static org.junit.Assert.*; + +import java.io.IOException; +import java.security.SecureRandom; +import java.util.Arrays; + +import org.apache.spark.SparkConf; +import org.apache.spark.internal.config.package$; +import org.junit.Before; +import org.junit.Test; + +import org.apache.celeborn.client.security.CryptoHandler; + +public class SparkCryptoHandlerSuiteJ { + + private byte[] key; + private CryptoHandler handler; + + @Before + public void setUp() { + key = new byte[16]; + new SecureRandom().nextBytes(key); + SparkConf sparkConf = new SparkConf(false); + sparkConf.set(package$.MODULE$.IO_ENCRYPTION_ENABLED(), true); + handler = new SparkCryptoHandler(sparkConf, key); + } + + @Test + public void testRoundTrip() throws IOException { + byte[] plaintext = "hello world, this is a test of encryption".getBytes(); + + byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length); + assertFalse( + "Encrypted output should differ from plaintext", Arrays.equals(plaintext, encrypted)); + + byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length); + assertArrayEquals(plaintext, decrypted); + } + + @Test + public void testEncryptedDiffersFromPlaintext() throws IOException { + byte[] plaintext = "deterministic test data for comparison".getBytes(); + + byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length); + assertFalse( + "Encrypted output should differ from plaintext", Arrays.equals(plaintext, encrypted)); + } + + @Test + public void testSameDataEncryptsThenDecrypts() throws IOException { + byte[] plaintext = "same data encrypted twice".getBytes(); + + byte[] encrypted1 = handler.encrypt(plaintext, 0, plaintext.length); + byte[] encrypted2 = handler.encrypt(plaintext, 0, plaintext.length); + + // Both should decrypt to the same plaintext + byte[] decrypted1 = handler.decrypt(encrypted1, 0, encrypted1.length); + byte[] decrypted2 = handler.decrypt(encrypted2, 0, encrypted2.length); + + assertArrayEquals(plaintext, decrypted1); + assertArrayEquals(plaintext, decrypted2); + } + + @Test + public void testEncryptWithOffset() throws IOException { + byte[] actual = "offset test data".getBytes(); + byte[] padded = Arrays.copyOf(actual, actual.length + 20); + + byte[] encrypted = handler.encrypt(padded, 0, actual.length); + byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length); + + assertArrayEquals(actual, decrypted); + } + + @Test + public void testDecryptWithWrongKeyFails() throws IOException { + byte[] plaintext = "secret data".getBytes(); + byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length); + + byte[] wrongKey = new byte[16]; + new SecureRandom().nextBytes(wrongKey); + SparkConf sparkConf = new SparkConf(false); + sparkConf.set(package$.MODULE$.IO_ENCRYPTION_ENABLED(), true); + CryptoHandler wrongHandler = new SparkCryptoHandler(sparkConf, wrongKey); + + byte[] decrypted = null; + try { + decrypted = wrongHandler.decrypt(encrypted, 0, encrypted.length); + } catch (IOException e) { + // acceptable — some implementations throw on wrong key + return; + } + // CryptoStreamUtils may return garbage instead of throwing + assertFalse( + "Decryption with wrong key should not produce original plaintext", + Arrays.equals(plaintext, decrypted)); + } + + @Test + public void testLargeData() throws IOException { + byte[] plaintext = new byte[64 * 1024]; // 64KB + new SecureRandom().nextBytes(plaintext); + + byte[] encrypted = handler.encrypt(plaintext, 0, plaintext.length); + byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length); + + assertArrayEquals(plaintext, decrypted); + } + + @Test + public void testEmptyData() throws IOException { + byte[] encrypted = handler.encrypt(new byte[0], 0, 0); + + byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length); + assertEquals(0, decrypted.length); + } +} diff --git a/client-spark/spark-2-shaded/pom.xml b/client-spark/spark-2-shaded/pom.xml index 9db62b423e0..7c3e0e9b852 100644 --- a/client-spark/spark-2-shaded/pom.xml +++ b/client-spark/spark-2-shaded/pom.xml @@ -72,6 +72,7 @@ com.google.guava:failureaccess io.netty:* org.apache.commons:commons-lang3 + org.apache.commons:commons-crypto org.roaringbitmap:RoaringBitmap commons-io:commons-io diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 02d48a4fbbc..993800a026b 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -208,7 +208,8 @@ public ShuffleWriter getWriter( h.lifecycleManagerPort(), celebornConf, h.userIdentifier(), - h.extension()); + h.extension(), + SparkCommonUtils.getCryptoHandler(conf)); if (h.stageRerunEnabled()) { SparkUtils.addFailureListenerIfBarrierTask(client, context, h); } @@ -260,7 +261,8 @@ public ShuffleReader getReader( Int.MaxValue(), context, celebornConf, - shuffleIdTracker); + shuffleIdTracker, + SparkCommonUtils.getCryptoHandler(conf)); } checkUserClassPathFirst(handle); return _sortShuffleManager.getReader(handle, startPartition, endPartition, context); diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 192cc647ded..dc8908205bb 100644 --- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle.celeborn import java.io.IOException +import java.util.Optional import java.util.concurrent.{ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.function.BiFunction @@ -32,6 +33,7 @@ import org.apache.spark.util.collection.ExternalSorter import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.client.read.CelebornInputStream import org.apache.celeborn.client.read.MetricsCallback +import org.apache.celeborn.client.security.CryptoHandler import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.exception.{CelebornBroadcastException, CelebornIOException, CelebornRuntimeException, PartitionUnRetryAbleException} import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse @@ -45,7 +47,8 @@ class CelebornShuffleReader[K, C]( endMapIndex: Int = Int.MaxValue, context: TaskContext, conf: CelebornConf, - shuffleIdTracker: ExecutorShuffleIdTracker) + shuffleIdTracker: ExecutorShuffleIdTracker, + cryptoHandler: Optional[CryptoHandler] = Optional.empty()) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency @@ -55,7 +58,8 @@ class CelebornShuffleReader[K, C]( handle.lifecycleManagerPort, conf, handle.userIdentifier, - handle.extension) + handle.extension, + cryptoHandler) private val exceptionRef = new AtomicReference[IOException] private val encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(context) diff --git a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala index fd888fb9dc1..c32dfaf2eb9 100644 --- a/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala +++ b/client-spark/spark-3-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala @@ -17,12 +17,15 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, TaskContext} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.sql.execution.UnsafeRowSerializer import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer} +import org.apache.celeborn.client.security.CryptoHandler import org.apache.celeborn.common.CelebornConf class CelebornColumnarShuffleReader[K, C]( @@ -34,7 +37,8 @@ class CelebornColumnarShuffleReader[K, C]( context: TaskContext, conf: CelebornConf, metrics: ShuffleReadMetricsReporter, - shuffleIdTracker: ExecutorShuffleIdTracker) + shuffleIdTracker: ExecutorShuffleIdTracker, + cryptoHandler: Optional[CryptoHandler] = Optional.empty()) extends CelebornShuffleReader[K, C]( handle, startPartition, @@ -44,7 +48,8 @@ class CelebornColumnarShuffleReader[K, C]( context, conf, metrics, - shuffleIdTracker) { + shuffleIdTracker, + cryptoHandler) { override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { val schema = CustomShuffleDependencyUtils.getSchema(dep) diff --git a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala index d0f4462be3e..edc67d98335 100644 --- a/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala +++ b/client-spark/spark-3-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance} import org.apache.spark.sql.execution.UnsafeRowSerializer @@ -58,7 +60,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]]) } finally { if (shuffleClient != null) { @@ -92,7 +95,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]]) Mockito.when(shuffleDependency.shuffleId).thenReturn(0) Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer( diff --git a/client-spark/spark-3-shaded/pom.xml b/client-spark/spark-3-shaded/pom.xml index bc8c2065e2d..d1e0cf834af 100644 --- a/client-spark/spark-3-shaded/pom.xml +++ b/client-spark/spark-3-shaded/pom.xml @@ -76,6 +76,7 @@ com.google.guava:failureaccess io.netty:* org.apache.commons:commons-lang3 + org.apache.commons:commons-crypto org.roaringbitmap:RoaringBitmap commons-io:commons-io diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index ed7865e19ff..176b0f2051f 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -288,7 +288,8 @@ public ShuffleWriter getWriter( h.lifecycleManagerPort(), celebornConf, h.userIdentifier(), - h.extension()); + h.extension(), + SparkCommonUtils.getCryptoHandler(conf)); if (h.stageRerunEnabled()) { SparkUtils.addFailureListenerIfBarrierTask(shuffleClient, context, h); } @@ -445,7 +446,8 @@ public ShuffleReader getCelebornShuffleReader( context, celebornConf, metrics, - shuffleIdTracker); + shuffleIdTracker, + SparkCommonUtils.getCryptoHandler(conf)); } else { return new CelebornShuffleReader<>( h, @@ -456,7 +458,8 @@ public ShuffleReader getCelebornShuffleReader( context, celebornConf, metrics, - shuffleIdTracker); + shuffleIdTracker, + SparkCommonUtils.getCryptoHandler(conf)); } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 0e68a4b46f0..412d839434e 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -65,6 +66,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornRuntimeException; import org.apache.celeborn.common.network.protocol.TransportMessage; @@ -280,7 +282,8 @@ private static class ColumnarShuffleReaderConstructorHolder { TaskContext.class, CelebornConf.class, ShuffleReadMetricsReporter.class, - ExecutorShuffleIdTracker.class) + ExecutorShuffleIdTracker.class, + Optional.class) .build(); } @@ -293,7 +296,8 @@ public static CelebornShuffleReader createColumnarShuffleReader( TaskContext context, CelebornConf conf, ShuffleReadMetricsReporter metrics, - ExecutorShuffleIdTracker shuffleIdTracker) { + ExecutorShuffleIdTracker shuffleIdTracker, + Optional cryptoHandler) { return ColumnarShuffleReaderConstructorHolder.INSTANCE.invoke( null, handle, @@ -304,7 +308,8 @@ public static CelebornShuffleReader createColumnarShuffleReader( context, conf, metrics, - shuffleIdTracker); + shuffleIdTracker, + cryptoHandler); } // Added in SPARK-32920, for Spark 3.2 and above diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 55e036155b2..36d7115d6db 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -18,7 +18,7 @@ package org.apache.spark.shuffle.celeborn import java.io.IOException -import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet} +import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Optional, Set => JSet} import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.function.BiFunction @@ -39,6 +39,7 @@ import org.apache.spark.util.collection.ExternalSorter import org.apache.celeborn.client.{ClientUtils, ShuffleClient} import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback} +import org.apache.celeborn.client.security.CryptoHandler import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.exception.{CelebornBroadcastException, CelebornIOException, CelebornRuntimeException, PartitionUnRetryAbleException} import org.apache.celeborn.common.network.client.TransportClient @@ -58,7 +59,8 @@ class CelebornShuffleReader[K, C]( conf: CelebornConf, metrics: ShuffleReadMetricsReporter, shuffleIdTracker: ExecutorShuffleIdTracker, - needDecompress: Boolean) + needDecompress: Boolean, + cryptoHandler: Optional[CryptoHandler]) extends ShuffleReader[K, C] with Logging { def this( @@ -80,7 +82,31 @@ class CelebornShuffleReader[K, C]( conf, metrics, shuffleIdTracker, - true) + true, + Optional.empty()) + + def this( + handle: CelebornShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + startMapIndex: Int, + endMapIndex: Int, + context: TaskContext, + conf: CelebornConf, + metrics: ShuffleReadMetricsReporter, + shuffleIdTracker: ExecutorShuffleIdTracker, + cryptoHandler: Optional[CryptoHandler]) = this( + handle, + startPartition, + endPartition, + startMapIndex, + endMapIndex, + context, + conf, + metrics, + shuffleIdTracker, + true, + cryptoHandler) private val dep = handle.dependency @@ -91,7 +117,8 @@ class CelebornShuffleReader[K, C]( handle.lifecycleManagerPort, conf, handle.userIdentifier, - handle.extension) + handle.extension, + cryptoHandler) private val exceptionRef = new AtomicReference[IOException] private val stageRerunEnabled = handle.stageRerunEnabled diff --git a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala index 29878fd76c0..94dc7a2ba24 100644 --- a/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala +++ b/client-spark/spark-3/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala @@ -55,7 +55,7 @@ class CelebornShuffleReaderSuite extends AnyFunSuite { val tmpFile = Files.createTempFile("test", ".tmp").toFile mockStatic(classOf[ShuffleClient]).when(() => - ShuffleClient.get(any(), any(), any(), any(), any(), any())).thenReturn( + ShuffleClient.get(any(), any(), any(), any(), any(), any(), any())).thenReturn( new DummyShuffleClient(conf, tmpFile)) val shuffleReader = diff --git a/client/pom.xml b/client/pom.xml index 12241885914..0a05b22ef4e 100644 --- a/client/pom.xml +++ b/client/pom.xml @@ -82,6 +82,10 @@ org.apache.commons commons-lang3 + + org.apache.commons + commons-crypto + org.mockito mockito-core diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index 6ca3406be7f..ad618c3cfb5 100644 --- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -28,6 +28,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -39,6 +40,7 @@ import org.apache.celeborn.client.read.CelebornInputStream; import org.apache.celeborn.client.read.MetricsCallback; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.client.TransportClientFactory; @@ -67,6 +69,9 @@ public DummyShuffleClient(CelebornConf conf, File file) throws Exception { this.conf = conf; } + @Override + public void setupCryptoHandler(Optional cryptoHandler) {} + @Override public void setupLifecycleManagerRef(String host, int port) {} diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 7a89b051d4c..6062ac6560c 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -35,6 +35,7 @@ import org.apache.celeborn.client.read.CelebornInputStream; import org.apache.celeborn.client.read.MetricsCallback; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.identity.UserIdentifier; @@ -90,6 +91,18 @@ public static ShuffleClient get( CelebornConf conf, UserIdentifier userIdentifier, byte[] extension) { + return ShuffleClient.get( + appUniqueId, driverHost, port, conf, userIdentifier, extension, Optional.empty()); + } + + public static ShuffleClient get( + String appUniqueId, + String driverHost, + int port, + CelebornConf conf, + UserIdentifier userIdentifier, + byte[] extension, + Optional cryptoHandler) { if (null == _instance || !initialized) { synchronized (ShuffleClient.class) { if (null == _instance) { @@ -102,12 +115,14 @@ public static ShuffleClient get( _instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier); _instance.setupLifecycleManagerRef(driverHost, port); _instance.setExtension(extension); + _instance.setupCryptoHandler(cryptoHandler); initialized = true; } else if (!initialized) { _instance.shutdown(); _instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier); _instance.setupLifecycleManagerRef(driverHost, port); _instance.setExtension(extension); + _instance.setupCryptoHandler(cryptoHandler); initialized = true; } } @@ -150,6 +165,8 @@ public static void printReadStats(Logger logger) { String.format("%.2f", (localReadCount * 1.0d / totalReadCount) * 100)); } + public abstract void setupCryptoHandler(Optional cryptoHandler); + public abstract void setupLifecycleManagerRef(String host, int port); public abstract void setupLifecycleManagerRef(RpcEndpointRef endpointRef); diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index be2bdf87d11..88e479d0448 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -43,6 +43,7 @@ import org.apache.celeborn.client.compress.Compressor; import org.apache.celeborn.client.read.CelebornInputStream; import org.apache.celeborn.client.read.MetricsCallback; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornBroadcastException; import org.apache.celeborn.common.exception.CelebornIOException; @@ -101,6 +102,8 @@ public class ShuffleClientImpl extends ShuffleClient { protected byte[] extension; + private Optional cryptoHandler = Optional.empty(); + // key: appShuffleIdentifier, value: shuffleId protected Map> shuffleIdCache = JavaUtils.newConcurrentHashMap(); @@ -1060,6 +1063,20 @@ public int pushOrMergeData( length = compressor.getCompressedTotalSize(); } + if (cryptoHandler.isPresent()) { + byte[] encrypted = cryptoHandler.get().encrypt(data, offset, length); + logger.debug( + "Encrypted shuffle data for shuffle {} map {} partition {}: {} bytes -> {} bytes.", + shuffleId, + mapId, + partitionId, + length, + encrypted.length); + data = encrypted; + offset = 0; + length = encrypted.length; + } + final byte[] body = new byte[BATCH_HEADER_SIZE + length]; Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, mapId); Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, attemptId); @@ -2025,7 +2042,8 @@ public CelebornInputStream readPartition( partitionId, exceptionMaker, metricsCallback, - needDecompress); + needDecompress, + cryptoHandler); } } @@ -2090,6 +2108,14 @@ public void setExtension(byte[] extension) { this.extension = extension; } + @Override + public void setupCryptoHandler(Optional cryptoHandler) { + this.cryptoHandler = cryptoHandler; + if (cryptoHandler.isPresent()) { + logger.info("IO encryption enabled for shuffle data (encryption at rest)."); + } + } + boolean mapperEnded(int shuffleId, int mapId) { return (mapperEndMap.containsKey(shuffleId) && mapperEndMap.get(shuffleId).contains(mapId)) || isStageEnded(shuffleId); diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 37e0be3e375..969e92a1103 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -39,6 +39,7 @@ import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.client.compress.Decompressor; import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.CommitMetadata; import org.apache.celeborn.common.exception.CelebornIOException; @@ -74,7 +75,8 @@ public static CelebornInputStream create( int partitionId, ExceptionMaker exceptionMaker, MetricsCallback metricsCallback, - boolean needDecompress) + boolean needDecompress, + Optional cryptoHandler) throws IOException { if (locations == null || locations.isEmpty()) { return emptyInputStream; @@ -106,7 +108,8 @@ public static CelebornInputStream create( metricsCallback, needDecompress, startMapIndex, - endMapIndex); + endMapIndex, + cryptoHandler); } else { return new CelebornInputStreamImpl( conf, @@ -131,7 +134,8 @@ public static CelebornInputStream create( metricsCallback, needDecompress, -1, - -1); + -1, + cryptoHandler); } } } @@ -188,6 +192,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private final Map failedBatches; + private byte[] encryptedBuf; private byte[] compressedBuf; private byte[] rawDataBuf; private Decompressor decompressor; @@ -223,6 +228,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { private int shuffleId; private int partitionId; private ExceptionMaker exceptionMaker; + private Optional cryptoHandler; private boolean closed = false; private boolean integrityChecked = false; private final CommitMetadata aggregatedActualCommitMetadata = new CommitMetadata(); @@ -250,7 +256,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { MetricsCallback metricsCallback, boolean needDecompress, int numberOfSubPartitions, - int currentIndexOfSubPartition) + int currentIndexOfSubPartition, + Optional cryptoHandler) throws IOException { this( conf, @@ -275,7 +282,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { metricsCallback, needDecompress, numberOfSubPartitions, - currentIndexOfSubPartition); + currentIndexOfSubPartition, + cryptoHandler); } CelebornInputStreamImpl( @@ -301,7 +309,8 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { MetricsCallback metricsCallback, boolean needDecompress, int numberOfSubPartitions, - int currentIndexOfSubPartition) + int currentIndexOfSubPartition, + Optional cryptoHandler) throws IOException { this.conf = conf; this.clientFactory = clientFactory; @@ -337,6 +346,7 @@ private static final class CelebornInputStreamImpl extends CelebornInputStream { this.retryWaitMs = conf.networkIoRetryWaitMs(TransportModuleConstants.DATA_MODULE); this.callback = metricsCallback; this.exceptionMaker = exceptionMaker; + this.cryptoHandler = cryptoHandler; this.partitionId = partitionId; this.appShuffleId = appShuffleId; this.shuffleId = shuffleId; @@ -791,6 +801,9 @@ private boolean moveToNextChunk() throws IOException { private void init() { int bufferSize = conf.clientFetchBufferSize(); + if (cryptoHandler.isPresent()) { + encryptedBuf = new byte[bufferSize]; + } if (shouldDecompress) { int headerLen = Decompressor.getCompressionHeaderLength(conf); bufferSize += headerLen; @@ -823,17 +836,34 @@ private boolean fillBuffer() throws IOException { int batchId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8); int size = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 12); - if (shouldDecompress) { + // Read and optionally decrypt data into the appropriate buffer + if (cryptoHandler.isPresent()) { + if (size > encryptedBuf.length) { + encryptedBuf = new byte[size]; + } + currentChunk.readBytes(encryptedBuf, 0, size); + byte[] decrypted = cryptoHandler.get().decrypt(encryptedBuf, 0, size); + logger.debug( + "Decrypted shuffle data for shuffle {} partition {}: {} bytes -> {} bytes.", + shuffleId, + partitionId, + size, + decrypted.length); + size = decrypted.length; + if (shouldDecompress) { + compressedBuf = decrypted; + } else { + rawDataBuf = decrypted; + } + } else if (shouldDecompress) { if (size > compressedBuf.length) { compressedBuf = new byte[size]; } - currentChunk.readBytes(compressedBuf, 0, size); } else { if (size > rawDataBuf.length) { rawDataBuf = new byte[size]; } - currentChunk.readBytes(rawDataBuf, 0, size); } diff --git a/client/src/main/java/org/apache/celeborn/client/security/CryptoHandler.java b/client/src/main/java/org/apache/celeborn/client/security/CryptoHandler.java new file mode 100644 index 00000000000..38da34fa797 --- /dev/null +++ b/client/src/main/java/org/apache/celeborn/client/security/CryptoHandler.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client.security; + +import java.io.IOException; + +public interface CryptoHandler { + byte[] encrypt(byte[] input, int offset, int length) throws IOException; + + byte[] decrypt(byte[] input, int offset, int length) throws IOException; +} From f48960d189e00f1ee0234a943543a1cf1242f0b8 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Fri, 15 May 2026 12:32:08 -0700 Subject: [PATCH 02/14] fix errors --- .../client/read/CelebornInputStreamPeerFailoverTest.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java index a7dd7db3178..456486b9e8c 100644 --- a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java +++ b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamPeerFailoverTest.java @@ -32,6 +32,7 @@ import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; @@ -39,6 +40,7 @@ import org.junit.Test; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.client.TransportClient; @@ -174,7 +176,8 @@ public void testFailureWithoutPeer() throws Exception { 0, null, new TestMetricsCallback(), - false); + false, + Optional.empty()); } private void createInputStream(String primaryHost, String replicaHost) throws IOException { @@ -209,7 +212,8 @@ private void createInputStream(String primaryHost, String replicaHost) throws IO 0, null, new TestMetricsCallback(), - false); + false, + Optional.empty()); } /** From 0cf8e57090f0a64e0b91855d112cfbbd0aab92bd Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Fri, 15 May 2026 14:10:02 -0700 Subject: [PATCH 03/14] trigger build From 4e37cad6cb3682d92ce0f5d1edd680f11c767616 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Tue, 9 Jun 2026 14:19:54 -0700 Subject: [PATCH 04/14] [CELEBORN-2329] Address PR review comments for EAR Spark impl - Add upper-bound check (decryptedLength > length - 4) in SparkCryptoHandler.decrypt() to guard against OOM from corrupted or wrong-key input - Fix incBytesRead/incDuplicateBytesRead metric undercount in CelebornInputStream by tracking encryptedSize separately from the decrypted size - Fix testEncryptWithOffset to actually exercise a non-zero offset (offset=10) - Defensively handle null in ShuffleClientImpl.setupCryptoHandler to prevent NPE - Add fallback 9-arg constructor in SparkUtils.ColumnarShuffleReaderConstructorHolder for backward compatibility with older columnar-shuffle modules - Move commons-crypto dependency from client/pom.xml to client-spark/common/pom.xml so non-Spark engines (Flink, MR) no longer pull it in unnecessarily Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- client-spark/common/pom.xml | 4 ++++ .../spark/shuffle/celeborn/SparkCryptoHandler.java | 5 ++++- .../shuffle/celeborn/SparkCryptoHandlerSuiteJ.java | 6 ++++-- .../apache/spark/shuffle/celeborn/SparkUtils.java | 12 ++++++++++++ client/pom.xml | 4 ---- .../apache/celeborn/client/ShuffleClientImpl.java | 4 ++-- .../celeborn/client/read/CelebornInputStream.java | 10 +++++++--- 7 files changed, 33 insertions(+), 12 deletions(-) diff --git a/client-spark/common/pom.xml b/client-spark/common/pom.xml index 60e0d21eb42..b115d0cd4f8 100644 --- a/client-spark/common/pom.xml +++ b/client-spark/common/pom.xml @@ -75,6 +75,10 @@ spark-sql_${scala.binary.version} provided + + org.apache.commons + commons-crypto + org.mockito mockito-core diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java index c7385947290..84af2b34b92 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java @@ -54,7 +54,10 @@ public byte[] decrypt(byte[] input, int offset, int length) throws IOException { ByteArrayInputStream bais = new ByteArrayInputStream(input, offset, length); DataInputStream dis = new DataInputStream(bais); int decryptedLength = dis.readInt(); - if (decryptedLength < 0) { + // The encrypted payload format is: [4-byte plaintext length][ciphertext...]. + // So the maximum valid decrypted length is length - 4 (the ciphertext portion). + // A value outside this range indicates corruption or a wrong key. + if (decryptedLength < 0 || decryptedLength > length - 4) { throw new IOException( "Invalid decrypted length: " + decryptedLength + ", encrypted length: " + length); } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java index 898cefb3109..6dd16c36ac1 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java @@ -83,9 +83,11 @@ public void testSameDataEncryptsThenDecrypts() throws IOException { @Test public void testEncryptWithOffset() throws IOException { byte[] actual = "offset test data".getBytes(); - byte[] padded = Arrays.copyOf(actual, actual.length + 20); + int offset = 10; + byte[] padded = new byte[offset + actual.length + 20]; + System.arraycopy(actual, 0, padded, offset, actual.length); - byte[] encrypted = handler.encrypt(padded, 0, actual.length); + byte[] encrypted = handler.encrypt(padded, offset, actual.length); byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length); assertArrayEquals(actual, decrypted); diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 412d839434e..abb3fa5803e 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -284,6 +284,18 @@ private static class ColumnarShuffleReaderConstructorHolder { ShuffleReadMetricsReporter.class, ExecutorShuffleIdTracker.class, Optional.class) + // Fallback for older columnar-shuffle modules that don't have the cryptoHandler param + .impl( + COLUMNAR_SHUFFLE_READER_CLASS, + CelebornShuffleHandle.class, + int.class, + int.class, + int.class, + int.class, + TaskContext.class, + CelebornConf.class, + ShuffleReadMetricsReporter.class, + ExecutorShuffleIdTracker.class) .build(); } diff --git a/client/pom.xml b/client/pom.xml index 0a05b22ef4e..12241885914 100644 --- a/client/pom.xml +++ b/client/pom.xml @@ -82,10 +82,6 @@ org.apache.commons commons-lang3 - - org.apache.commons - commons-crypto - org.mockito mockito-core diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index b4bb766f1a4..6a50c65b454 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -2124,8 +2124,8 @@ public void setExtension(byte[] extension) { @Override public void setupCryptoHandler(Optional cryptoHandler) { - this.cryptoHandler = cryptoHandler; - if (cryptoHandler.isPresent()) { + this.cryptoHandler = cryptoHandler != null ? cryptoHandler : Optional.empty(); + if (this.cryptoHandler.isPresent()) { logger.info("IO encryption enabled for shuffle data (encryption at rest)."); } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 969e92a1103..e3d551572b5 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -836,7 +836,11 @@ private boolean fillBuffer() throws IOException { int batchId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8); int size = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 12); - // Read and optionally decrypt data into the appropriate buffer + // Read and optionally decrypt data into the appropriate buffer. + // encryptedSize tracks the on-wire (encrypted) byte count for metrics; size is + // reassigned to the decrypted length so downstream decompression and limit logic + // operate on the correct plaintext size. + int encryptedSize = size; if (cryptoHandler.isPresent()) { if (size > encryptedBuf.length) { encryptedBuf = new byte[size]; @@ -886,7 +890,7 @@ private boolean fillBuffer() throws IOException { Set batchSet = batchesRead.computeIfAbsent(mapId, k -> new HashSet<>()); if (!batchSet.contains(batchId)) { batchSet.add(batchId); - callback.incBytesRead(BATCH_HEADER_SIZE + size); + callback.incBytesRead(BATCH_HEADER_SIZE + encryptedSize); if (shouldDecompress) { // decompress data int originalLength = decompressor.getOriginalLen(compressedBuf); @@ -904,7 +908,7 @@ private boolean fillBuffer() throws IOException { hasData = true; break; } else { - callback.incDuplicateBytesRead(BATCH_HEADER_SIZE + size); + callback.incDuplicateBytesRead(BATCH_HEADER_SIZE + encryptedSize); logger.debug( "Skip duplicated batch: mapId {}, attemptId {}, batchId {}.", mapId, From f48ae5086d528684b1b7e116922ad314029fff58 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Tue, 9 Jun 2026 22:28:05 -0700 Subject: [PATCH 05/14] [CELEBORN-2329] Add backward-compatible overload for createColumnarShuffleReader spark-3.5-columnar-shuffle and spark-4-columnar-shuffle test suites call createColumnarShuffleReader without the Optional parameter, causing a compile error. Add a 9-arg overload that defaults to Optional.empty() so existing callers don't need to be updated. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../spark/shuffle/celeborn/SparkUtils.java | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index abb3fa5803e..f1997692c63 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -324,6 +324,30 @@ public static CelebornShuffleReader createColumnarShuffleReader( cryptoHandler); } + /** Overload for callers that do not use encryption at rest. */ + public static CelebornShuffleReader createColumnarShuffleReader( + CelebornShuffleHandle handle, + int startPartition, + int endPartition, + int startMapIndex, + int endMapIndex, + TaskContext context, + CelebornConf conf, + ShuffleReadMetricsReporter metrics, + ExecutorShuffleIdTracker shuffleIdTracker) { + return createColumnarShuffleReader( + handle, + startPartition, + endPartition, + startMapIndex, + endMapIndex, + context, + conf, + metrics, + shuffleIdTracker, + Optional.empty()); + } + // Added in SPARK-32920, for Spark 3.2 and above private static final DynMethods.UnboundMethod UnregisterAllMapAndMergeOutput_METHOD = DynMethods.builder("unregisterAllMapAndMergeOutput") From 9fb403158a2c16a446152d2bef533ee002c706f7 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Thu, 11 Jun 2026 15:32:59 -0700 Subject: [PATCH 06/14] Add getCryptoHandler caching and EAR round-trip integration test - Cache getCryptoHandler() result in SparkShuffleManager (Spark 2 + 3) to avoid re-reading and re-parsing the config on every shuffle read/write call; uses a volatile field for safe lazy initialization - Add CelebornInputStreamCryptoRoundTripSuiteJ: integration-style test that exercises the full encrypt-on-write / decrypt-on-read path in CelebornInputStream, including compress+encrypt ordering, integrity check compatibility, and large payloads Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../shuffle/celeborn/SparkShuffleManager.java | 17 +- .../shuffle/celeborn/SparkShuffleManager.java | 19 +- ...ebornInputStreamCryptoRoundTripSuiteJ.java | 298 ++++++++++++++++++ 3 files changed, 329 insertions(+), 5 deletions(-) create mode 100644 client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 993800a026b..6e929c215da 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import scala.Int; @@ -35,6 +36,7 @@ import org.apache.celeborn.client.LifecycleManager; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.ShuffleMode; import org.apache.celeborn.reflect.DynMethods; @@ -65,6 +67,17 @@ public class SparkShuffleManager implements ShuffleManager { private ExecutorShuffleIdTracker shuffleIdTracker = new ExecutorShuffleIdTracker(); + // The IO encryption key is fixed for the app lifetime. Lazily initialized on first + // writer/reader call (not in the constructor) to ensure SparkEnv is available. + private volatile Optional cryptoHandler = null; + + private Optional getCryptoHandler() { + if (cryptoHandler == null) { + cryptoHandler = SparkCommonUtils.getCryptoHandler(conf); + } + return cryptoHandler; + } + public SparkShuffleManager(SparkConf conf, boolean isDriver) { SparkCommonUtils.validateAttemptConfig(conf); this.conf = conf; @@ -209,7 +222,7 @@ public ShuffleWriter getWriter( celebornConf, h.userIdentifier(), h.extension(), - SparkCommonUtils.getCryptoHandler(conf)); + getCryptoHandler()); if (h.stageRerunEnabled()) { SparkUtils.addFailureListenerIfBarrierTask(client, context, h); } @@ -262,7 +275,7 @@ public ShuffleReader getReader( context, celebornConf, shuffleIdTracker, - SparkCommonUtils.getCryptoHandler(conf)); + getCryptoHandler()); } checkUserClassPathFirst(handle); return _sortShuffleManager.getReader(handle, startPartition, endPartition, context); diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 176b0f2051f..3e45ba2c8ec 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import org.apache.spark.*; @@ -33,6 +34,7 @@ import org.apache.celeborn.client.LifecycleManager; import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.protocol.ShuffleMode; import org.apache.celeborn.reflect.DynMethods; @@ -91,6 +93,17 @@ public class SparkShuffleManager implements ShuffleManager { private ExecutorShuffleIdTracker shuffleIdTracker = new ExecutorShuffleIdTracker(); + // The IO encryption key is fixed for the app lifetime. Lazily initialized on first + // writer/reader call (not in the constructor) to ensure SparkEnv is available. + private volatile Optional cryptoHandler = null; + + private Optional getCryptoHandler() { + if (cryptoHandler == null) { + cryptoHandler = SparkCommonUtils.getCryptoHandler(conf); + } + return cryptoHandler; + } + public SparkShuffleManager(SparkConf conf, boolean isDriver) { if (conf.getBoolean(SQLConf.LOCAL_SHUFFLE_READER_ENABLED().key(), true)) { logger.warn( @@ -289,7 +302,7 @@ public ShuffleWriter getWriter( celebornConf, h.userIdentifier(), h.extension(), - SparkCommonUtils.getCryptoHandler(conf)); + getCryptoHandler()); if (h.stageRerunEnabled()) { SparkUtils.addFailureListenerIfBarrierTask(shuffleClient, context, h); } @@ -447,7 +460,7 @@ public ShuffleReader getCelebornShuffleReader( celebornConf, metrics, shuffleIdTracker, - SparkCommonUtils.getCryptoHandler(conf)); + getCryptoHandler()); } else { return new CelebornShuffleReader<>( h, @@ -459,7 +472,7 @@ public ShuffleReader getCelebornShuffleReader( celebornConf, metrics, shuffleIdTracker, - SparkCommonUtils.getCryptoHandler(conf)); + getCryptoHandler()); } } diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java new file mode 100644 index 00000000000..8078534f121 --- /dev/null +++ b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client.read; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.Test; + +import org.apache.celeborn.client.ShuffleClient; +import org.apache.celeborn.client.compress.Compressor; +import org.apache.celeborn.client.security.CryptoHandler; +import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.network.buffer.NettyManagedBuffer; +import org.apache.celeborn.common.network.client.ChunkReceivedCallback; +import org.apache.celeborn.common.network.client.TransportClient; +import org.apache.celeborn.common.network.client.TransportClientFactory; +import org.apache.celeborn.common.network.protocol.TransportMessage; +import org.apache.celeborn.common.protocol.MessageType; +import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.common.protocol.PbStreamHandler; +import org.apache.celeborn.common.protocol.StorageInfo; +import org.apache.celeborn.common.unsafe.Platform; + +/** + * Integration-style round-trip tests for EAR (Encryption At Rest) wiring in + * {@link CelebornInputStream}. These tests verify that the encrypt-on-write / + * decrypt-on-read path works end-to-end, including interactions with compression + * and the shuffle integrity check. + */ +public class CelebornInputStreamCryptoRoundTripSuiteJ { + + private static final int BATCH_HEADER_SIZE = 16; + private static final String SHUFFLE_KEY = "app-1-1"; + + /** + * A minimal CryptoHandler for testing: the encrypted format is + * [4-byte plaintext length (int)][XOR-encrypted payload]. + * This matches the structural contract of SparkCryptoHandler so the + * bounds check (decryptedLength > length - 4) is also exercised. + */ + static class XorCryptoHandler implements CryptoHandler { + private final byte key; + + XorCryptoHandler(byte key) { + this.key = key; + } + + @Override + public byte[] encrypt(byte[] input, int offset, int length) throws IOException { + // Prefix with 4-byte plaintext length, then XOR-encrypt the payload + byte[] out = new byte[4 + length]; + Platform.putInt(out, Platform.BYTE_ARRAY_OFFSET, length); + for (int i = 0; i < length; i++) { + out[4 + i] = (byte) (input[offset + i] ^ key); + } + return out; + } + + @Override + public byte[] decrypt(byte[] input, int offset, int length) throws IOException { + // Read the plaintext length from the 4-byte prefix + int decryptedLength = Platform.getInt(input, Platform.BYTE_ARRAY_OFFSET + offset); + // Validate bounds: the 4-byte prefix must fit inside the encrypted buffer + if (decryptedLength < 0 || decryptedLength > length - 4) { + throw new IOException( + "Invalid decrypted length: " + decryptedLength + ", encrypted length: " + length); + } + byte[] out = new byte[decryptedLength]; + for (int i = 0; i < decryptedLength; i++) { + out[i] = (byte) (input[offset + 4 + i] ^ key); + } + return out; + } + } + + /** + * Build a single batch ByteBuf as ShuffleClientImpl.pushOrMergeData does: + * optionally compress, optionally encrypt, then prepend the 16-byte batch header. + */ + private ByteBuf buildBatch( + byte[] plaintext, + boolean compress, + CryptoHandler cryptoHandler, + CelebornConf conf) + throws IOException { + byte[] data = plaintext; + int offset = 0; + int length = plaintext.length; + + // Step 1: optionally compress (compress-then-encrypt ordering matches ShuffleClientImpl) + if (compress) { + Compressor compressor = Compressor.getCompressor(conf); + compressor.compress(data, offset, length); + data = compressor.getCompressedBuffer(); + offset = 0; + length = compressor.getCompressedTotalSize(); + } + + // Step 2: optionally encrypt the (possibly compressed) payload + if (cryptoHandler != null) { + data = cryptoHandler.encrypt(data, offset, length); + offset = 0; + length = data.length; + } + + // Step 3: prepend the 16-byte batch header [mapId|attemptId|batchId|payloadLen] + byte[] body = new byte[BATCH_HEADER_SIZE + length]; + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, 0); // mapId + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, 0); // attemptId + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 8, 0); // batchId + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, length); // payload length + System.arraycopy(data, offset, body, BATCH_HEADER_SIZE, length); + return Unpooled.wrappedBuffer(body); + } + + /** + * Create a CelebornInputStream backed by a mock TransportClient that serves + * the given batchBuf as a single chunk. + */ + private CelebornInputStream createStream( + ByteBuf batchBuf, + boolean needDecompress, + Optional cryptoHandler, + CelebornConf conf) + throws IOException, InterruptedException { + TransportClient client = mock(TransportClient.class); + PbStreamHandler pbHandler = + PbStreamHandler.newBuilder().setStreamId(1L).setNumChunks(1).build(); + // Encode the stream handler into an RPC response that CelebornInputStream expects + ByteBuffer rpcResponse = + new TransportMessage(MessageType.STREAM_HANDLER, pbHandler.toByteArray()).toByteBuffer(); + when(client.sendRpcSync(any(ByteBuffer.class), anyLong())).thenReturn(rpcResponse); + doNothing().when(client).sendRpc(any(ByteBuffer.class)); + doAnswer( + invocation -> { + ChunkReceivedCallback cb = invocation.getArgument(3); + // Serve the pre-built batch buffer immediately as chunk 0 + cb.onSuccess(0, new NettyManagedBuffer(batchBuf.duplicate().retain())); + return null; + }) + .when(client) + .fetchChunk(anyLong(), anyInt(), anyLong(), any(ChunkReceivedCallback.class)); + + TransportClientFactory clientFactory = mock(TransportClientFactory.class); + when(clientFactory.createClient(anyString(), anyInt())).thenReturn(client); + + ShuffleClient shuffleClient = mock(ShuffleClient.class); + + // PRIMARY location pointing to a single HDD partition + PartitionLocation location = + new PartitionLocation( + 0, 0, "host1", 9001, 9002, 9003, 9004, PartitionLocation.Mode.PRIMARY); + location.setStorageInfo(new StorageInfo(StorageInfo.Type.HDD, true, "/mnt/disk1")); + + ArrayList locations = new ArrayList<>(); + locations.add(location); + ArrayList handlers = new ArrayList<>(); + handlers.add(PbStreamHandler.newBuilder().setStreamId(1L).setNumChunks(1).build()); + + return CelebornInputStream.create( + conf, + clientFactory, + SHUFFLE_KEY, + locations, + handlers, + new int[] {0}, + new HashMap<>(), + new HashMap<>(), + 0, + 1L, + 0, + 100, + new ConcurrentHashMap<>(), + shuffleClient, + 1, + 1, + 0, + null, + new MetricsCallback() { + @Override + public void incBytesRead(long bytes) {} + + @Override + public void incReadTime(long time) {} + }, + needDecompress, + cryptoHandler); + } + + private byte[] readAll(CelebornInputStream stream) throws IOException { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + byte[] buf = new byte[4096]; + int n; + while ((n = stream.read(buf)) != -1) { + baos.write(buf, 0, n); + } + return baos.toByteArray(); + } + + @Test + public void testEncryptDecryptRoundTrip() throws IOException, InterruptedException { + byte[] plaintext = "hello, EAR round-trip without compression".getBytes(); + CelebornConf conf = new CelebornConf(); + XorCryptoHandler handler = new XorCryptoHandler((byte) 0x5A); + + // Build an encrypted batch and read it back through CelebornInputStream + ByteBuf batchBuf = buildBatch(plaintext, false, handler, conf); + try (CelebornInputStream stream = + createStream(batchBuf, false, Optional.of(handler), conf)) { + assertArrayEquals(plaintext, readAll(stream)); + } + } + + @Test + public void testNoEncryptionRoundTrip() throws IOException, InterruptedException { + byte[] plaintext = "unencrypted shuffle data sanity check".getBytes(); + CelebornConf conf = new CelebornConf(); + + // Baseline: with no CryptoHandler the data flows through unchanged + ByteBuf batchBuf = buildBatch(plaintext, false, null, conf); + try (CelebornInputStream stream = createStream(batchBuf, false, Optional.empty(), conf)) { + assertArrayEquals(plaintext, readAll(stream)); + } + } + + @Test + public void testCompressThenEncryptRoundTrip() throws IOException, InterruptedException { + // Reproduce the compress-then-encrypt ordering used in ShuffleClientImpl. + byte[] plaintext = "shuffle data with compression and encryption enabled for EAR".getBytes(); + CelebornConf conf = new CelebornConf(); + // Use LZ4 (default) + conf.set(CelebornConf.SHUFFLE_COMPRESSION_CODEC().key(), "lz4"); + XorCryptoHandler handler = new XorCryptoHandler((byte) 0x3C); + + // Writer: LZ4-compress then XOR-encrypt; Reader: decrypt then decompress + ByteBuf batchBuf = buildBatch(plaintext, true, handler, conf); + try (CelebornInputStream stream = createStream(batchBuf, true, Optional.of(handler), conf)) { + assertArrayEquals(plaintext, readAll(stream)); + } + } + + @Test + public void testEncryptWithIntegrityCheckEnabled() throws IOException, InterruptedException { + // Verify that EAR + shuffle integrity check (celeborn.client.shuffle.integrityCheck.enabled) + // work together: the checksum is computed over plaintext, so decrypt-then-verify must hold. + byte[] plaintext = "integrity check should pass after decryption".getBytes(); + CelebornConf conf = new CelebornConf(); + conf.set(CelebornConf.CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED().key(), "true"); + XorCryptoHandler handler = new XorCryptoHandler((byte) 0x7F); + + // The integrity metadata (checksum) is added by CelebornInputStream over the decrypted data + ByteBuf batchBuf = buildBatch(plaintext, false, handler, conf); + try (CelebornInputStream stream = createStream(batchBuf, false, Optional.of(handler), conf)) { + assertArrayEquals(plaintext, readAll(stream)); + } + } + + @Test + public void testLargePayloadEncryptDecrypt() throws IOException, InterruptedException { + // 128 KB payload exercises buffer-boundary handling in fillBuffer() + byte[] plaintext = new byte[128 * 1024]; + for (int i = 0; i < plaintext.length; i++) plaintext[i] = (byte) (i % 251); + CelebornConf conf = new CelebornConf(); + XorCryptoHandler handler = new XorCryptoHandler((byte) 0xAB); + + ByteBuf batchBuf = buildBatch(plaintext, false, handler, conf); + try (CelebornInputStream stream = createStream(batchBuf, false, Optional.of(handler), conf)) { + assertArrayEquals(plaintext, readAll(stream)); + } + } +} From 9321d70440ce10de210158d25bc77038e0863ffe Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Thu, 11 Jun 2026 15:38:15 -0700 Subject: [PATCH 07/14] spotless --- ...ebornInputStreamCryptoRoundTripSuiteJ.java | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java index 8078534f121..109b1cd990f 100644 --- a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java @@ -49,10 +49,9 @@ import org.apache.celeborn.common.unsafe.Platform; /** - * Integration-style round-trip tests for EAR (Encryption At Rest) wiring in - * {@link CelebornInputStream}. These tests verify that the encrypt-on-write / - * decrypt-on-read path works end-to-end, including interactions with compression - * and the shuffle integrity check. + * Integration-style round-trip tests for EAR (Encryption At Rest) wiring in {@link + * CelebornInputStream}. These tests verify that the encrypt-on-write / decrypt-on-read path works + * end-to-end, including interactions with compression and the shuffle integrity check. */ public class CelebornInputStreamCryptoRoundTripSuiteJ { @@ -60,10 +59,9 @@ public class CelebornInputStreamCryptoRoundTripSuiteJ { private static final String SHUFFLE_KEY = "app-1-1"; /** - * A minimal CryptoHandler for testing: the encrypted format is - * [4-byte plaintext length (int)][XOR-encrypted payload]. - * This matches the structural contract of SparkCryptoHandler so the - * bounds check (decryptedLength > length - 4) is also exercised. + * A minimal CryptoHandler for testing: the encrypted format is [4-byte plaintext length + * (int)][XOR-encrypted payload]. This matches the structural contract of SparkCryptoHandler so + * the bounds check (decryptedLength > length - 4) is also exercised. */ static class XorCryptoHandler implements CryptoHandler { private final byte key; @@ -101,14 +99,11 @@ public byte[] decrypt(byte[] input, int offset, int length) throws IOException { } /** - * Build a single batch ByteBuf as ShuffleClientImpl.pushOrMergeData does: - * optionally compress, optionally encrypt, then prepend the 16-byte batch header. + * Build a single batch ByteBuf as ShuffleClientImpl.pushOrMergeData does: optionally compress, + * optionally encrypt, then prepend the 16-byte batch header. */ private ByteBuf buildBatch( - byte[] plaintext, - boolean compress, - CryptoHandler cryptoHandler, - CelebornConf conf) + byte[] plaintext, boolean compress, CryptoHandler cryptoHandler, CelebornConf conf) throws IOException { byte[] data = plaintext; int offset = 0; @@ -132,17 +127,17 @@ private ByteBuf buildBatch( // Step 3: prepend the 16-byte batch header [mapId|attemptId|batchId|payloadLen] byte[] body = new byte[BATCH_HEADER_SIZE + length]; - Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, 0); // mapId - Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, 0); // attemptId - Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 8, 0); // batchId + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, 0); // mapId + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, 0); // attemptId + Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 8, 0); // batchId Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, length); // payload length System.arraycopy(data, offset, body, BATCH_HEADER_SIZE, length); return Unpooled.wrappedBuffer(body); } /** - * Create a CelebornInputStream backed by a mock TransportClient that serves - * the given batchBuf as a single chunk. + * Create a CelebornInputStream backed by a mock TransportClient that serves the given batchBuf as + * a single chunk. */ private CelebornInputStream createStream( ByteBuf batchBuf, @@ -232,8 +227,7 @@ public void testEncryptDecryptRoundTrip() throws IOException, InterruptedExcepti // Build an encrypted batch and read it back through CelebornInputStream ByteBuf batchBuf = buildBatch(plaintext, false, handler, conf); - try (CelebornInputStream stream = - createStream(batchBuf, false, Optional.of(handler), conf)) { + try (CelebornInputStream stream = createStream(batchBuf, false, Optional.of(handler), conf)) { assertArrayEquals(plaintext, readAll(stream)); } } From faadb25cd356f3910f7315af9c65c1132f6f7bef Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Mon, 15 Jun 2026 13:57:42 -0700 Subject: [PATCH 08/14] [CELEBORN-2329] Address follow-up EAR review comments - Add cryptoHandler param to spark-3.5 and spark-4 columnar readers so encryption is no longer silently dropped on read for those profiles. Remove the now-unnecessary 9-arg backward-compat overload of createColumnarShuffleReader and fallback .impl in SparkUtils. - Move dedup/stale-attempt checks before decryption in CelebornInputStream.fillBuffer() to avoid paying the AES cost for batches that will be discarded. Use skipBytes() for skipped batches. - Fix encryptedBuf initial sizing to include headerLen headroom (same as compressedBuf/rawDataBuf) to prevent per-batch reallocation when compression is enabled. - Null out encryptedBuf in close() for consistency with other buffers. - Mark cryptoHandler volatile in ShuffleClientImpl for safe publication. - Guard decrypt debug log with isDebugEnabled() check. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../CelebornColumnarShuffleReader.scala | 9 +- .../CelebornColumnarShuffleReaderSuite.scala | 8 +- .../spark/shuffle/celeborn/SparkUtils.java | 36 ------ .../CelebornColumnarShuffleReader.scala | 9 +- .../CelebornColumnarShuffleReaderSuite.scala | 8 +- .../celeborn/client/ShuffleClientImpl.java | 2 +- .../client/read/CelebornInputStream.java | 121 ++++++++++-------- 7 files changed, 93 insertions(+), 100 deletions(-) diff --git a/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala index fd888fb9dc1..c32dfaf2eb9 100644 --- a/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala +++ b/client-spark/spark-3.5-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala @@ -17,12 +17,15 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, TaskContext} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.sql.execution.UnsafeRowSerializer import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer} +import org.apache.celeborn.client.security.CryptoHandler import org.apache.celeborn.common.CelebornConf class CelebornColumnarShuffleReader[K, C]( @@ -34,7 +37,8 @@ class CelebornColumnarShuffleReader[K, C]( context: TaskContext, conf: CelebornConf, metrics: ShuffleReadMetricsReporter, - shuffleIdTracker: ExecutorShuffleIdTracker) + shuffleIdTracker: ExecutorShuffleIdTracker, + cryptoHandler: Optional[CryptoHandler] = Optional.empty()) extends CelebornShuffleReader[K, C]( handle, startPartition, @@ -44,7 +48,8 @@ class CelebornColumnarShuffleReader[K, C]( context, conf, metrics, - shuffleIdTracker) { + shuffleIdTracker, + cryptoHandler) { override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { val schema = CustomShuffleDependencyUtils.getSchema(dep) diff --git a/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala index d0f4462be3e..edc67d98335 100644 --- a/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala +++ b/client-spark/spark-3.5-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance} import org.apache.spark.sql.execution.UnsafeRowSerializer @@ -58,7 +60,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]]) } finally { if (shuffleClient != null) { @@ -92,7 +95,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]]) Mockito.when(shuffleDependency.shuffleId).thenReturn(0) Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer( diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index f1997692c63..412d839434e 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -284,18 +284,6 @@ private static class ColumnarShuffleReaderConstructorHolder { ShuffleReadMetricsReporter.class, ExecutorShuffleIdTracker.class, Optional.class) - // Fallback for older columnar-shuffle modules that don't have the cryptoHandler param - .impl( - COLUMNAR_SHUFFLE_READER_CLASS, - CelebornShuffleHandle.class, - int.class, - int.class, - int.class, - int.class, - TaskContext.class, - CelebornConf.class, - ShuffleReadMetricsReporter.class, - ExecutorShuffleIdTracker.class) .build(); } @@ -324,30 +312,6 @@ public static CelebornShuffleReader createColumnarShuffleReader( cryptoHandler); } - /** Overload for callers that do not use encryption at rest. */ - public static CelebornShuffleReader createColumnarShuffleReader( - CelebornShuffleHandle handle, - int startPartition, - int endPartition, - int startMapIndex, - int endMapIndex, - TaskContext context, - CelebornConf conf, - ShuffleReadMetricsReporter metrics, - ExecutorShuffleIdTracker shuffleIdTracker) { - return createColumnarShuffleReader( - handle, - startPartition, - endPartition, - startMapIndex, - endMapIndex, - context, - conf, - metrics, - shuffleIdTracker, - Optional.empty()); - } - // Added in SPARK-32920, for Spark 3.2 and above private static final DynMethods.UnboundMethod UnregisterAllMapAndMergeOutput_METHOD = DynMethods.builder("unregisterAllMapAndMergeOutput") diff --git a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala index fd888fb9dc1..c32dfaf2eb9 100644 --- a/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala +++ b/client-spark/spark-4-columnar-shuffle/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReader.scala @@ -17,12 +17,15 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, TaskContext} import org.apache.spark.serializer.SerializerInstance import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.sql.execution.UnsafeRowSerializer import org.apache.spark.sql.execution.columnar.{CelebornBatchBuilder, CelebornColumnarBatchSerializer} +import org.apache.celeborn.client.security.CryptoHandler import org.apache.celeborn.common.CelebornConf class CelebornColumnarShuffleReader[K, C]( @@ -34,7 +37,8 @@ class CelebornColumnarShuffleReader[K, C]( context: TaskContext, conf: CelebornConf, metrics: ShuffleReadMetricsReporter, - shuffleIdTracker: ExecutorShuffleIdTracker) + shuffleIdTracker: ExecutorShuffleIdTracker, + cryptoHandler: Optional[CryptoHandler] = Optional.empty()) extends CelebornShuffleReader[K, C]( handle, startPartition, @@ -44,7 +48,8 @@ class CelebornColumnarShuffleReader[K, C]( context, conf, metrics, - shuffleIdTracker) { + shuffleIdTracker, + cryptoHandler) { override def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = { val schema = CustomShuffleDependencyUtils.getSchema(dep) diff --git a/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala b/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala index d0f4462be3e..edc67d98335 100644 --- a/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala +++ b/client-spark/spark-4-columnar-shuffle/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornColumnarShuffleReaderSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.celeborn +import java.util.Optional + import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance} import org.apache.spark.sql.execution.UnsafeRowSerializer @@ -58,7 +60,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) assert(shuffleReader.getClass == classOf[CelebornColumnarShuffleReader[Int, String]]) } finally { if (shuffleClient != null) { @@ -92,7 +95,8 @@ class CelebornColumnarShuffleReaderSuite { taskContext, new CelebornConf(), null, - new ExecutorShuffleIdTracker()) + new ExecutorShuffleIdTracker(), + Optional.empty()) val shuffleDependency = Mockito.mock(classOf[ShuffleDependency[Int, String, String]]) Mockito.when(shuffleDependency.shuffleId).thenReturn(0) Mockito.when(shuffleDependency.serializer).thenReturn(new KryoSerializer( diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 6a50c65b454..e1f4b6aab9d 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -102,7 +102,7 @@ public class ShuffleClientImpl extends ShuffleClient { protected byte[] extension; - private Optional cryptoHandler = Optional.empty(); + private volatile Optional cryptoHandler = Optional.empty(); // key: appShuffleIdentifier, value: shuffleId protected Map> shuffleIdCache = JavaUtils.newConcurrentHashMap(); diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index e3d551572b5..c22eb5a5ac9 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -740,6 +740,7 @@ public synchronized void close() { compressedBuf = null; rawDataBuf = null; + encryptedBuf = null; batchesRead = null; locations = null; attempts = null; @@ -800,12 +801,14 @@ private boolean moveToNextChunk() throws IOException { private void init() { int bufferSize = conf.clientFetchBufferSize(); + int headerLen = shouldDecompress ? Decompressor.getCompressionHeaderLength(conf) : 0; if (cryptoHandler.isPresent()) { - encryptedBuf = new byte[bufferSize]; + // The encrypted payload is: IV(16) + ciphertext(compressedSize), where + // compressedSize can reach bufferSize + headerLen, so match the same headroom. + encryptedBuf = new byte[bufferSize + headerLen]; } if (shouldDecompress) { - int headerLen = Decompressor.getCompressionHeaderLength(conf); bufferSize += headerLen; compressedBuf = new byte[bufferSize]; decompressor = Decompressor.getDecompressor(conf); @@ -836,23 +839,58 @@ private boolean fillBuffer() throws IOException { int batchId = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8); int size = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 12); - // Read and optionally decrypt data into the appropriate buffer. - // encryptedSize tracks the on-wire (encrypted) byte count for metrics; size is - // reassigned to the decrypted length so downstream decompression and limit logic - // operate on the correct plaintext size. + // encryptedSize is the on-wire byte count (used for metrics); size will be + // reassigned to the decrypted length after decryption. int encryptedSize = size; + + // Perform dedup/stale-attempt checks before decrypting to avoid paying the + // crypto cost for batches that will be discarded anyway. + if (attemptId != attempts[mapId]) { + currentChunk.skipBytes(size); + continue; + } + if (readSkewPartitionWithoutMapRange) { + LocationPushFailedBatches locationPushFailedBatches = + this.failedBatches.get(currentReader.getLocation().getUniqueId()); + if (null != locationPushFailedBatches) { + if (locationPushFailedBatches.contains(mapId, attemptId, batchId)) { + logger.warn( + "Skip duplicated batch: mapId={}, attemptId={}, batchId={}", + mapId, + attemptId, + batchId); + currentChunk.skipBytes(size); + continue; + } + } + } + Set batchSet = batchesRead.computeIfAbsent(mapId, k -> new HashSet<>()); + if (batchSet.contains(batchId)) { + callback.incDuplicateBytesRead(BATCH_HEADER_SIZE + encryptedSize); + logger.debug( + "Skip duplicated batch: mapId {}, attemptId {}, batchId {}.", + mapId, + attemptId, + batchId); + currentChunk.skipBytes(size); + continue; + } + + // Batch is unique and from the correct attempt — now read and optionally decrypt. if (cryptoHandler.isPresent()) { if (size > encryptedBuf.length) { encryptedBuf = new byte[size]; } currentChunk.readBytes(encryptedBuf, 0, size); byte[] decrypted = cryptoHandler.get().decrypt(encryptedBuf, 0, size); - logger.debug( - "Decrypted shuffle data for shuffle {} partition {}: {} bytes -> {} bytes.", - shuffleId, - partitionId, - size, - decrypted.length); + if (logger.isDebugEnabled()) { + logger.debug( + "Decrypted shuffle data for shuffle {} partition {}: {} bytes -> {} bytes.", + shuffleId, + partitionId, + size, + decrypted.length); + } size = decrypted.length; if (shouldDecompress) { compressedBuf = decrypted; @@ -871,51 +909,24 @@ private boolean fillBuffer() throws IOException { currentChunk.readBytes(rawDataBuf, 0, size); } - // de-duplicate - if (attemptId == attempts[mapId]) { - if (readSkewPartitionWithoutMapRange) { - LocationPushFailedBatches locationPushFailedBatches = - this.failedBatches.get(currentReader.getLocation().getUniqueId()); - if (null != locationPushFailedBatches) { - if (locationPushFailedBatches.contains(mapId, attemptId, batchId)) { - logger.warn( - "Skip duplicated batch: mapId={}, attemptId={}, batchId={}", - mapId, - attemptId, - batchId); - continue; - } - } - } - Set batchSet = batchesRead.computeIfAbsent(mapId, k -> new HashSet<>()); - if (!batchSet.contains(batchId)) { - batchSet.add(batchId); - callback.incBytesRead(BATCH_HEADER_SIZE + encryptedSize); - if (shouldDecompress) { - // decompress data - int originalLength = decompressor.getOriginalLen(compressedBuf); - if (rawDataBuf.length < originalLength) { - rawDataBuf = new byte[originalLength]; - } - limit = decompressor.decompress(compressedBuf, rawDataBuf, 0); - } else { - limit = size; - } - if (shuffleIntegrityCheckEnabled) { - aggregatedActualCommitMetadata.addDataWithOffsetAndLength(rawDataBuf, 0, limit); - } - position = 0; - hasData = true; - break; - } else { - callback.incDuplicateBytesRead(BATCH_HEADER_SIZE + encryptedSize); - logger.debug( - "Skip duplicated batch: mapId {}, attemptId {}, batchId {}.", - mapId, - attemptId, - batchId); + batchSet.add(batchId); + callback.incBytesRead(BATCH_HEADER_SIZE + encryptedSize); + if (shouldDecompress) { + // decompress data + int originalLength = decompressor.getOriginalLen(compressedBuf); + if (rawDataBuf.length < originalLength) { + rawDataBuf = new byte[originalLength]; } + limit = decompressor.decompress(compressedBuf, rawDataBuf, 0); + } else { + limit = size; + } + if (shuffleIntegrityCheckEnabled) { + aggregatedActualCommitMetadata.addDataWithOffsetAndLength(rawDataBuf, 0, limit); } + position = 0; + hasData = true; + break; } if (!hasData) { From 032303636d757e3433dd44c125c32e04edc789e9 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Tue, 16 Jun 2026 23:07:31 -0700 Subject: [PATCH 09/14] [CELEBORN-2329] Address Copilot follow-up review comments - Snapshot volatile cryptoHandler to a local var in pushOrMergeData to avoid a TOCTOU race between isPresent() and get() on the volatile field - Fix inaccurate comment in CelebornInputStream.init() that described SparkCryptoHandler-specific layout instead of the generic contract - Fix ByteBuf double-retain in CelebornInputStreamCryptoRoundTripSuiteJ: duplicate() shares the buffer without bumping the ref count, so the stream's single release correctly frees it - Add length < 4 guard in XorCryptoHandler.decrypt before the unsafe Platform.getInt read Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../org/apache/celeborn/client/ShuffleClientImpl.java | 6 ++++-- .../celeborn/client/read/CelebornInputStream.java | 4 ++-- .../read/CelebornInputStreamCryptoRoundTripSuiteJ.java | 10 ++++++++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index e1f4b6aab9d..a5970fa0654 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -1057,8 +1057,10 @@ public int pushOrMergeData( length = compressor.getCompressedTotalSize(); } - if (cryptoHandler.isPresent()) { - byte[] encrypted = cryptoHandler.get().encrypt(data, offset, length); + // Snapshot volatile field once to avoid a TOCTOU race between isPresent() and get(). + Optional handler = cryptoHandler; + if (handler.isPresent()) { + byte[] encrypted = handler.get().encrypt(data, offset, length); logger.debug( "Encrypted shuffle data for shuffle {} map {} partition {}: {} bytes -> {} bytes.", shuffleId, diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index c22eb5a5ac9..bbdaf2056ee 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -804,8 +804,8 @@ private void init() { int headerLen = shouldDecompress ? Decompressor.getCompressionHeaderLength(conf) : 0; if (cryptoHandler.isPresent()) { - // The encrypted payload is: IV(16) + ciphertext(compressedSize), where - // compressedSize can reach bufferSize + headerLen, so match the same headroom. + // Size to match compressedBuf/rawDataBuf headroom; exact overhead depends on the + // CryptoHandler implementation (e.g. SparkCryptoHandler prepends a 4-byte length). encryptedBuf = new byte[bufferSize + headerLen]; } if (shouldDecompress) { diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java index 109b1cd990f..394898a7a24 100644 --- a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java @@ -83,6 +83,10 @@ public byte[] encrypt(byte[] input, int offset, int length) throws IOException { @Override public byte[] decrypt(byte[] input, int offset, int length) throws IOException { + // Validate the buffer is large enough to hold the 4-byte length prefix + if (length < 4) { + throw new IOException("Encrypted buffer too short: " + length); + } // Read the plaintext length from the 4-byte prefix int decryptedLength = Platform.getInt(input, Platform.BYTE_ARRAY_OFFSET + offset); // Validate bounds: the 4-byte prefix must fit inside the encrypted buffer @@ -156,8 +160,10 @@ private CelebornInputStream createStream( doAnswer( invocation -> { ChunkReceivedCallback cb = invocation.getArgument(3); - // Serve the pre-built batch buffer immediately as chunk 0 - cb.onSuccess(0, new NettyManagedBuffer(batchBuf.duplicate().retain())); + // Serve the pre-built batch buffer immediately as chunk 0; duplicate() shares + // the underlying data without incrementing the ref count, so the stream's + // single release correctly frees the buffer. + cb.onSuccess(0, new NettyManagedBuffer(batchBuf.duplicate())); return null; }) .when(client) From 0d13368596d3dc76c3b8c4889e2025fd160a40b2 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Sun, 28 Jun 2026 00:21:20 -0700 Subject: [PATCH 10/14] [CELEBORN-2329] Address SteNicholas follow-up review comments (round 3) - Fix testEncryptWithIntegrityCheckEnabled to actually verify the checksum is computed over decrypted plaintext: use ArgumentCaptor on readReducerPartitionEnd and assert against independently-computed CommitMetadata over the plaintext bytes - Add commons-crypto to spark-4-shaded includes (already present in spark-2/3-shaded) to prevent NoClassDefFoundError on Spark 4 - Fix getCryptoHandler() caching in SparkShuffleManager (Spark 2 + 3): do not cache Optional.empty() when SparkEnv is transiently null, so a later call retries once the executor env is ready - Add compression-header length guard after decrypt in CelebornInputStream.fillBuffer() to give a clear error instead of an opaque ArrayIndexOutOfBounds when a corrupt batch decodes to a payload too short for the decompressor header - Guard encrypt-side logger.debug with isDebugEnabled() for consistency with the decrypt-side guard Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../shuffle/celeborn/SparkShuffleManager.java | 8 +++- .../shuffle/celeborn/SparkShuffleManager.java | 8 +++- client-spark/spark-4-shaded/pom.xml | 1 + .../celeborn/client/ShuffleClientImpl.java | 16 ++++--- .../client/read/CelebornInputStream.java | 5 +++ ...ebornInputStreamCryptoRoundTripSuiteJ.java | 44 ++++++++++++++++--- 6 files changed, 67 insertions(+), 15 deletions(-) diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 6e929c215da..5263c54fe88 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -73,7 +73,13 @@ public class SparkShuffleManager implements ShuffleManager { private Optional getCryptoHandler() { if (cryptoHandler == null) { - cryptoHandler = SparkCommonUtils.getCryptoHandler(conf); + // Only cache when SparkEnv is ready. If it is transiently null (e.g. called before + // the executor env is initialized), return empty without caching so the next call retries. + if (SparkEnv.get() != null) { + cryptoHandler = SparkCommonUtils.getCryptoHandler(conf); + } else { + return Optional.empty(); + } } return cryptoHandler; } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 3e45ba2c8ec..f528934b397 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -99,7 +99,13 @@ public class SparkShuffleManager implements ShuffleManager { private Optional getCryptoHandler() { if (cryptoHandler == null) { - cryptoHandler = SparkCommonUtils.getCryptoHandler(conf); + // Only cache when SparkEnv is ready. If it is transiently null (e.g. called before + // the executor env is initialized), return empty without caching so the next call retries. + if (SparkEnv.get() != null) { + cryptoHandler = SparkCommonUtils.getCryptoHandler(conf); + } else { + return Optional.empty(); + } } return cryptoHandler; } diff --git a/client-spark/spark-4-shaded/pom.xml b/client-spark/spark-4-shaded/pom.xml index 5e741d60179..d41419f9979 100644 --- a/client-spark/spark-4-shaded/pom.xml +++ b/client-spark/spark-4-shaded/pom.xml @@ -78,6 +78,7 @@ org.apache.commons:commons-lang3 org.roaringbitmap:RoaringBitmap commons-io:commons-io + org.apache.commons:commons-crypto diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index a5970fa0654..46c1189ca40 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -1061,13 +1061,15 @@ public int pushOrMergeData( Optional handler = cryptoHandler; if (handler.isPresent()) { byte[] encrypted = handler.get().encrypt(data, offset, length); - logger.debug( - "Encrypted shuffle data for shuffle {} map {} partition {}: {} bytes -> {} bytes.", - shuffleId, - mapId, - partitionId, - length, - encrypted.length); + if (logger.isDebugEnabled()) { + logger.debug( + "Encrypted shuffle data for shuffle {} map {} partition {}: {} bytes -> {} bytes.", + shuffleId, + mapId, + partitionId, + length, + encrypted.length); + } data = encrypted; offset = 0; length = encrypted.length; diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index bbdaf2056ee..7759bea8da9 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -893,6 +893,11 @@ private boolean fillBuffer() throws IOException { } size = decrypted.length; if (shouldDecompress) { + if (decrypted.length < Decompressor.getCompressionHeaderLength(conf)) { + throw new IOException( + "Decrypted batch too short to contain compression header: " + decrypted.length + + " bytes (shuffleId=" + shuffleId + ", partitionId=" + partitionId + ")"); + } compressedBuf = decrypted; } else { rawDataBuf = decrypted; diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java index 394898a7a24..73c855fcc3f 100644 --- a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java @@ -32,11 +32,13 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.client.compress.Compressor; import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.CommitMetadata; import org.apache.celeborn.common.network.buffer.NettyManagedBuffer; import org.apache.celeborn.common.network.client.ChunkReceivedCallback; import org.apache.celeborn.common.network.client.TransportClient; @@ -149,6 +151,21 @@ private CelebornInputStream createStream( Optional cryptoHandler, CelebornConf conf) throws IOException, InterruptedException { + return createStreamWithClient(batchBuf, needDecompress, cryptoHandler, conf, + mock(ShuffleClient.class)); + } + + /** + * Like {@link #createStream} but with a caller-supplied ShuffleClient mock, so tests can + * verify interactions such as {@code readReducerPartitionEnd}. + */ + private CelebornInputStream createStreamWithClient( + ByteBuf batchBuf, + boolean needDecompress, + Optional cryptoHandler, + CelebornConf conf, + ShuffleClient shuffleClient) + throws IOException, InterruptedException { TransportClient client = mock(TransportClient.class); PbStreamHandler pbHandler = PbStreamHandler.newBuilder().setStreamId(1L).setNumChunks(1).build(); @@ -172,8 +189,6 @@ private CelebornInputStream createStream( TransportClientFactory clientFactory = mock(TransportClientFactory.class); when(clientFactory.createClient(anyString(), anyInt())).thenReturn(client); - ShuffleClient shuffleClient = mock(ShuffleClient.class); - // PRIMARY location pointing to a single HDD partition PartitionLocation location = new PartitionLocation( @@ -268,18 +283,35 @@ public void testCompressThenEncryptRoundTrip() throws IOException, InterruptedEx @Test public void testEncryptWithIntegrityCheckEnabled() throws IOException, InterruptedException { - // Verify that EAR + shuffle integrity check (celeborn.client.shuffle.integrityCheck.enabled) - // work together: the checksum is computed over plaintext, so decrypt-then-verify must hold. + // Verify that EAR + shuffle integrity check work together: the checksum must be computed + // over the *decrypted* plaintext, not the ciphertext. We capture the crc32/bytes passed to + // readReducerPartitionEnd and assert they match an independently-computed plaintext checksum. byte[] plaintext = "integrity check should pass after decryption".getBytes(); CelebornConf conf = new CelebornConf(); conf.set(CelebornConf.CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED().key(), "true"); XorCryptoHandler handler = new XorCryptoHandler((byte) 0x7F); - // The integrity metadata (checksum) is added by CelebornInputStream over the decrypted data + // Independently compute the expected checksum over the plaintext bytes. + CommitMetadata expected = new CommitMetadata(); + expected.addDataWithOffsetAndLength(plaintext, 0, plaintext.length); + ByteBuf batchBuf = buildBatch(plaintext, false, handler, conf); - try (CelebornInputStream stream = createStream(batchBuf, false, Optional.of(handler), conf)) { + + // createStream passes shuffleId=1, partitionId=0, startMapIndex=0, endMapIndex=100 + ShuffleClient shuffleClient = mock(ShuffleClient.class); + try (CelebornInputStream stream = + createStreamWithClient(batchBuf, false, Optional.of(handler), conf, shuffleClient)) { assertArrayEquals(plaintext, readAll(stream)); } + + // Verify readReducerPartitionEnd was called with the checksum over plaintext, not ciphertext. + ArgumentCaptor crcCaptor = ArgumentCaptor.forClass(Integer.class); + ArgumentCaptor bytesCaptor = ArgumentCaptor.forClass(Long.class); + verify(shuffleClient) + .readReducerPartitionEnd(anyInt(), anyInt(), anyInt(), anyInt(), + crcCaptor.capture(), bytesCaptor.capture()); + assertEquals("checksum must be over plaintext", expected.getChecksum(), (int) crcCaptor.getValue()); + assertEquals("byte count must match plaintext length", expected.getBytes(), (long) bytesCaptor.getValue()); } @Test From 5bb27c157cf7ca2f63989fe6ee631b377151d767 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Sun, 28 Jun 2026 00:22:30 -0700 Subject: [PATCH 11/14] Spotless --- .../client/read/CelebornInputStream.java | 9 +++++++-- ...ebornInputStreamCryptoRoundTripSuiteJ.java | 20 +++++++++++-------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 7759bea8da9..0493b48a8fb 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -895,8 +895,13 @@ private boolean fillBuffer() throws IOException { if (shouldDecompress) { if (decrypted.length < Decompressor.getCompressionHeaderLength(conf)) { throw new IOException( - "Decrypted batch too short to contain compression header: " + decrypted.length - + " bytes (shuffleId=" + shuffleId + ", partitionId=" + partitionId + ")"); + "Decrypted batch too short to contain compression header: " + + decrypted.length + + " bytes (shuffleId=" + + shuffleId + + ", partitionId=" + + partitionId + + ")"); } compressedBuf = decrypted; } else { diff --git a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java index 73c855fcc3f..06655b887b1 100644 --- a/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/read/CelebornInputStreamCryptoRoundTripSuiteJ.java @@ -151,13 +151,13 @@ private CelebornInputStream createStream( Optional cryptoHandler, CelebornConf conf) throws IOException, InterruptedException { - return createStreamWithClient(batchBuf, needDecompress, cryptoHandler, conf, - mock(ShuffleClient.class)); + return createStreamWithClient( + batchBuf, needDecompress, cryptoHandler, conf, mock(ShuffleClient.class)); } /** - * Like {@link #createStream} but with a caller-supplied ShuffleClient mock, so tests can - * verify interactions such as {@code readReducerPartitionEnd}. + * Like {@link #createStream} but with a caller-supplied ShuffleClient mock, so tests can verify + * interactions such as {@code readReducerPartitionEnd}. */ private CelebornInputStream createStreamWithClient( ByteBuf batchBuf, @@ -308,10 +308,14 @@ public void testEncryptWithIntegrityCheckEnabled() throws IOException, Interrupt ArgumentCaptor crcCaptor = ArgumentCaptor.forClass(Integer.class); ArgumentCaptor bytesCaptor = ArgumentCaptor.forClass(Long.class); verify(shuffleClient) - .readReducerPartitionEnd(anyInt(), anyInt(), anyInt(), anyInt(), - crcCaptor.capture(), bytesCaptor.capture()); - assertEquals("checksum must be over plaintext", expected.getChecksum(), (int) crcCaptor.getValue()); - assertEquals("byte count must match plaintext length", expected.getBytes(), (long) bytesCaptor.getValue()); + .readReducerPartitionEnd( + anyInt(), anyInt(), anyInt(), anyInt(), crcCaptor.capture(), bytesCaptor.capture()); + assertEquals( + "checksum must be over plaintext", expected.getChecksum(), (int) crcCaptor.getValue()); + assertEquals( + "byte count must match plaintext length", + expected.getBytes(), + (long) bytesCaptor.getValue()); } @Test From eddd7c1fd8e13eafb39d726a2f269b980f415939 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Tue, 30 Jun 2026 15:53:26 -0700 Subject: [PATCH 12/14] [CELEBORN-2329] Address PR review comments (round 4) - SparkCryptoHandler: cache minimal SparkConf in constructor to amortize per-batch toCryptoConf() scan cost (mridulm) - SparkCryptoHandler: fix decrypt bounds check from length-4 to length-20; real format is [4-byte len][16-byte IV][ciphertext] (SteNicholas) - SparkCommonUtils: log warn when IO encryption is enabled but key is unavailable, instead of silently returning empty (SteNicholas) - ShuffleClient: apply crypto handler even on already-initialized singleton so transient SparkEnv-null on first init does not permanently disable encryption (SteNicholas) - Remove dead commons-crypto from client-spark/common and shaded jars; SparkCryptoHandler delegates to Spark's CryptoStreamUtils which uses Spark's own classpath commons-crypto at runtime (SteNicholas) - CelebornInputStream: pre-size encryptedBuf with crypto overhead headroom to avoid reallocation on first encrypted batch (SteNicholas) - SparkCryptoHandlerSuiteJ: add test verifying corrected length-20 bound rejects crafted payloads that passed the old length-4 guard (SteNicholas) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- client-spark/common/pom.xml | 4 --- .../shuffle/celeborn/SparkCommonUtils.java | 8 ++++++ .../shuffle/celeborn/SparkCryptoHandler.java | 25 +++++++++++++---- .../celeborn/SparkCryptoHandlerSuiteJ.java | 27 +++++++++++++++++++ client-spark/spark-3-shaded/pom.xml | 1 - client-spark/spark-4-shaded/pom.xml | 1 - .../apache/celeborn/client/ShuffleClient.java | 7 +++++ .../client/read/CelebornInputStream.java | 7 ++--- 8 files changed, 66 insertions(+), 14 deletions(-) diff --git a/client-spark/common/pom.xml b/client-spark/common/pom.xml index b115d0cd4f8..60e0d21eb42 100644 --- a/client-spark/common/pom.xml +++ b/client-spark/common/pom.xml @@ -75,10 +75,6 @@ spark-sql_${scala.binary.version} provided - - org.apache.commons - commons-crypto - org.mockito mockito-core diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java index 697f86776ba..97704010f44 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java @@ -23,6 +23,9 @@ import scala.Option; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; @@ -34,6 +37,7 @@ import org.apache.celeborn.reflect.DynMethods; public class SparkCommonUtils { + private static final Logger logger = LoggerFactory.getLogger(SparkCommonUtils.class); public static void validateAttemptConfig(SparkConf conf) throws IllegalArgumentException { int DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4; int maxStageAttempts = @@ -113,6 +117,10 @@ public static Optional getCryptoHandler(SparkConf conf) { } Option key = env.securityManager().getIOEncryptionKey(); if (!key.isDefined()) { + logger.warn( + "IO encryption is enabled (spark.io.encryption.enabled=true) but the IO encryption key " + + "is not available from the SecurityManager. Shuffle data will be written as " + + "plaintext. Ensure the SecurityManager provides an IO encryption key."); return Optional.empty(); } return Optional.of(new SparkCryptoHandler(conf, key.get())); diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java index 84af2b34b92..e4b1ea06f65 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java @@ -23,6 +23,7 @@ import java.io.DataOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.util.Properties; import org.apache.spark.SparkConf; import org.apache.spark.security.CryptoStreamUtils; @@ -30,11 +31,24 @@ import org.apache.celeborn.client.security.CryptoHandler; public class SparkCryptoHandler implements CryptoHandler { + // On-wire format: [4-byte plaintext length][16-byte IV][ciphertext]. + // The minimum overhead (length prefix + IV) added by the crypto stream. + private static final int CRYPTO_OVERHEAD_BYTES = + Integer.BYTES + CryptoStreamUtils.IV_LENGTH_IN_BYTES(); + private final SparkConf sparkConf; private final byte[] key; public SparkCryptoHandler(SparkConf sparkConf, byte[] key) { - this.sparkConf = sparkConf; + // Pre-filter sparkConf to only crypto-relevant keys so that + // CryptoStreamUtils.toCryptoConf() does not scan the full SparkConf on every batch. + Properties cryptoProps = CryptoStreamUtils.toCryptoConf(sparkConf); + SparkConf minimalConf = new SparkConf(false); + String prefix = CryptoStreamUtils.SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX(); + for (String propKey : cryptoProps.stringPropertyNames()) { + minimalConf.set(prefix + propKey, cryptoProps.getProperty(propKey)); + } + this.sparkConf = minimalConf; this.key = key; } @@ -54,10 +68,11 @@ public byte[] decrypt(byte[] input, int offset, int length) throws IOException { ByteArrayInputStream bais = new ByteArrayInputStream(input, offset, length); DataInputStream dis = new DataInputStream(bais); int decryptedLength = dis.readInt(); - // The encrypted payload format is: [4-byte plaintext length][ciphertext...]. - // So the maximum valid decrypted length is length - 4 (the ciphertext portion). - // A value outside this range indicates corruption or a wrong key. - if (decryptedLength < 0 || decryptedLength > length - 4) { + // The encrypted payload format is: [4-byte plaintext length][16-byte IV][ciphertext]. + // The minimum on-wire overhead is CRYPTO_OVERHEAD_BYTES (4 + 16 = 20), so the maximum + // valid plaintext length is length - 20. A value outside this range indicates corruption + // or a wrong key. + if (decryptedLength < 0 || decryptedLength > length - CRYPTO_OVERHEAD_BYTES) { throw new IOException( "Invalid decrypted length: " + decryptedLength + ", encrypted length: " + length); } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java index 6dd16c36ac1..6baa62626fd 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandlerSuiteJ.java @@ -20,6 +20,8 @@ import static org.junit.Assert.*; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.security.SecureRandom; import java.util.Arrays; @@ -135,4 +137,29 @@ public void testEmptyData() throws IOException { byte[] decrypted = handler.decrypt(encrypted, 0, encrypted.length); assertEquals(0, decrypted.length); } + + /** + * Verifies that the decrypt bounds check uses {@code length - 20} (4-byte length prefix + 16-byte + * IV), not the previous {@code length - 4}. A crafted payload whose embedded length value is + * between {@code length - 19} and {@code length - 5} (inclusive) must be rejected. + */ + @Test + public void testDecryptRejectsCraftedLengthBetweenLengthMinus4AndLengthMinus20() + throws IOException { + // Construct a minimal on-wire buffer: [4-byte length][16-byte IV][0-byte ciphertext]. + // Total = 20 bytes. Embed a plaintext length of 1 — valid under the old (length-4) + // guard (1 <= 20-4=16) but invalid under the corrected (length-20) guard (1 > 20-20=0). + int totalLen = 20; // 4 (length prefix) + 16 (IV) + 0 (ciphertext) + byte[] crafted = new byte[totalLen]; + ByteBuffer.wrap(crafted).order(ByteOrder.BIG_ENDIAN).putInt(1); // claim 1 byte of plaintext + + try { + handler.decrypt(crafted, 0, crafted.length); + fail("Expected IOException for crafted length > length - 20"); + } catch (IOException e) { + assertTrue( + "Exception message should mention decrypted length", + e.getMessage().contains("decrypted length") || e.getMessage().contains("Invalid")); + } + } } diff --git a/client-spark/spark-3-shaded/pom.xml b/client-spark/spark-3-shaded/pom.xml index d1e0cf834af..bc8c2065e2d 100644 --- a/client-spark/spark-3-shaded/pom.xml +++ b/client-spark/spark-3-shaded/pom.xml @@ -76,7 +76,6 @@ com.google.guava:failureaccess io.netty:* org.apache.commons:commons-lang3 - org.apache.commons:commons-crypto org.roaringbitmap:RoaringBitmap commons-io:commons-io diff --git a/client-spark/spark-4-shaded/pom.xml b/client-spark/spark-4-shaded/pom.xml index d41419f9979..5e741d60179 100644 --- a/client-spark/spark-4-shaded/pom.xml +++ b/client-spark/spark-4-shaded/pom.xml @@ -78,7 +78,6 @@ org.apache.commons:commons-lang3 org.roaringbitmap:RoaringBitmap commons-io:commons-io - org.apache.commons:commons-crypto diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 6324505e952..8bc1a8911cf 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -127,6 +127,13 @@ public static ShuffleClient get( } } } + // Apply the crypto handler even when the singleton is already initialized. This handles + // the case where SparkEnv was transiently unavailable during the first init call (causing + // an empty handler to be stored), so that encryption is correctly applied on retry. + // setupCryptoHandler is a volatile write and safe to call without the lock. + if (cryptoHandler != null && cryptoHandler.isPresent()) { + _instance.setupCryptoHandler(cryptoHandler); + } return _instance; } diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index 0493b48a8fb..d76b1eb9d3e 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -804,9 +804,10 @@ private void init() { int headerLen = shouldDecompress ? Decompressor.getCompressionHeaderLength(conf) : 0; if (cryptoHandler.isPresent()) { - // Size to match compressedBuf/rawDataBuf headroom; exact overhead depends on the - // CryptoHandler implementation (e.g. SparkCryptoHandler prepends a 4-byte length). - encryptedBuf = new byte[bufferSize + headerLen]; + // Pre-size to include the crypto overhead (e.g. SparkCryptoHandler adds a 4-byte length + // prefix and a 16-byte IV = 20 bytes) so the buffer is large enough for the first batch + // without an immediate reallocation. + encryptedBuf = new byte[bufferSize + headerLen + 64]; } if (shouldDecompress) { bufferSize += headerLen; From 509c1fca3fac9f2a0df8c7716d9faabb4de673f9 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Tue, 30 Jun 2026 16:02:16 -0700 Subject: [PATCH 13/14] Spotless Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../org/apache/spark/shuffle/celeborn/SparkCommonUtils.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java index 97704010f44..08e217390c4 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCommonUtils.java @@ -23,14 +23,13 @@ import scala.Option; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; import org.apache.spark.internal.config.package$; import org.apache.spark.memory.SparkOutOfMemoryError; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.apache.celeborn.client.security.CryptoHandler; import org.apache.celeborn.reflect.DynConstructors; @@ -38,6 +37,7 @@ public class SparkCommonUtils { private static final Logger logger = LoggerFactory.getLogger(SparkCommonUtils.class); + public static void validateAttemptConfig(SparkConf conf) throws IllegalArgumentException { int DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4; int maxStageAttempts = From 0d0c406385b29432089aebb088b5183029ca52b4 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Wed, 1 Jul 2026 13:44:42 -0700 Subject: [PATCH 14/14] [CELEBORN-2329] Cache reusable ByteArrayOutputStream per thread in SparkCryptoHandler Use a ThreadLocal for the encrypt path so each push thread reuses its own buffer (grows to the high-water-mark once, then reset() per batch) instead of allocating a new ByteArrayOutputStream on every batch. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../apache/spark/shuffle/celeborn/SparkCryptoHandler.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java index e4b1ea06f65..bf3922d484d 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SparkCryptoHandler.java @@ -38,6 +38,10 @@ public class SparkCryptoHandler implements CryptoHandler { private final SparkConf sparkConf; private final byte[] key; + // Each push thread reuses its own ByteArrayOutputStream to avoid per-batch allocation. + // The internal buffer grows to the high-water-mark once per thread and is reset() each call. + private final ThreadLocal encryptBaos = + ThreadLocal.withInitial(ByteArrayOutputStream::new); public SparkCryptoHandler(SparkConf sparkConf, byte[] key) { // Pre-filter sparkConf to only crypto-relevant keys so that @@ -54,7 +58,8 @@ public SparkCryptoHandler(SparkConf sparkConf, byte[] key) { @Override public byte[] encrypt(byte[] input, int offset, int length) throws IOException { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ByteArrayOutputStream baos = encryptBaos.get(); + baos.reset(); DataOutputStream dos = new DataOutputStream(baos); dos.writeInt(length); try (OutputStream cos = CryptoStreamUtils.createCryptoOutputStream(dos, sparkConf, key)) {