diff --git a/build/mvn b/build/mvn index cd6c0c796d1..290e68146dc 100755 --- a/build/mvn +++ b/build/mvn @@ -38,6 +38,7 @@ install_app() { local remote_tarball="$1/$2$4" local local_tarball="${_DIR}/$2" local binary="${_DIR}/$3" + local max_attempts=3 # setup `curl` and `wget` silent options if we're running on Jenkins local curl_opts="-L" @@ -46,23 +47,44 @@ install_app() { wget_opts="--progress=bar:force ${wget_opts}" if [ -z "$3" -o ! -f "$binary" ]; then - # check if we already have the tarball - # check if we have curl installed - # download application - [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \ - echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 && \ - curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" - # if the file still doesn't exist, lets try `wget` and cross our fingers - [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \ - echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 && \ - wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" - # if both were unsuccessful, exit - [ ! -f "${local_tarball}" ] && \ - echo -n "ERROR: Cannot download $2 with cURL or wget; " && \ - echo "please install manually and try again." && \ - exit 2 - cd "${_DIR}" && tar -xzf "$2" - rm -rf "$local_tarball" + local attempt=1 + while [ "${attempt}" -le "${max_attempts}" ]; do + # remove any partial/corrupt download left over from a previous attempt + rm -f "${local_tarball}" + + # download application with `curl`, falling back to `wget` + if command -v curl >/dev/null 2>&1; then + echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 + curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" + elif command -v wget >/dev/null 2>&1; then + echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 + wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" + else + echo "ERROR: Cannot download $2: neither cURL nor wget is installed." 1>&2 + exit 2 + fi + + # Validate the download before trusting it. A flaky Apache mirror can + # return an HTML page (mirror chooser / error) with HTTP 200, which is + # not a gzip tarball; extracting it later would fail with a confusing + # exit code. `tar -tzf` lists the archive without extracting and + # exits non-zero on a non-tarball body. + if [ -f "${local_tarball}" ] && tar -tzf "${local_tarball}" >/dev/null 2>&1; then + if cd "${_DIR}" && tar -xzf "$2"; then + rm -rf "${local_tarball}" + return 0 + fi + fi + + echo "WARN: Download of $2 from $1 was not a valid tarball" \ + "(attempt ${attempt}/${max_attempts}); retrying..." 1>&2 + attempt=$((attempt + 1)) + sleep 3 + done + + rm -f "${local_tarball}" + echo "WARN: Failed to download a valid $2 from $1 after ${max_attempts} attempts." 1>&2 + return 1 fi } @@ -77,25 +99,41 @@ install_mvn() { # See simple version normalization: http://stackoverflow.com/questions/16989598/bash-comparing-version-numbers function version { echo "$@" | awk -F. '{ printf("%03d%03d%03d\n", $1,$2,$3); }'; } if [ $(version $MVN_DETECTED_VERSION) -ne $(version $MVN_VERSION) ]; then - local APACHE_MIRROR=${APACHE_MIRROR:-'https://www.apache.org/dyn/closer.lua'} - local MIRROR_URL_QUERY="?action=download" + # Default to archive.apache.org: it serves the exact tarball + # deterministically, avoiding the closer.lua mirror redirector which + # intermittently routes to a mirror that returns an HTML page instead of + # the binary. Override with APACHE_MIRROR to use a closer mirror. + local APACHE_MIRROR=${APACHE_MIRROR:-'https://archive.apache.org/dist'} local MVN_TARBALL="apache-maven-${MVN_VERSION}-bin.tar.gz" local FILE_PATH="maven/maven-3/${MVN_VERSION}/binaries" - if [ $(command -v curl) ]; then - if ! curl -L --output /dev/null --silent --head --fail "${APACHE_MIRROR}/${FILE_PATH}/${MVN_TARBALL}${MIRROR_URL_QUERY}" ; then - # Fall back to archive.apache.org for older Maven - echo "Falling back to archive.apache.org to download Maven" - APACHE_MIRROR="https://archive.apache.org/dist" - MIRROR_URL_QUERY="" - fi - fi + # closer.lua needs the ?action=download query to redirect to a mirror; + # archive.apache.org and most plain mirrors serve the file directly. + local MIRROR_URL_QUERY="" + case "${APACHE_MIRROR}" in + *closer.lua*) MIRROR_URL_QUERY="?action=download" ;; + esac - install_app \ + if ! install_app \ "${APACHE_MIRROR}/${FILE_PATH}" \ "${MVN_TARBALL}" \ "apache-maven-${MVN_VERSION}/bin/mvn" \ - "${MIRROR_URL_QUERY}" + "${MIRROR_URL_QUERY}"; then + # Last resort: fall back to archive.apache.org, which serves the exact + # tarball deterministically. Skip if it was already the chosen mirror. + local ARCHIVE_MIRROR="https://archive.apache.org/dist" + if [ "${APACHE_MIRROR%/}" != "${ARCHIVE_MIRROR}" ]; then + echo "WARN: falling back to ${ARCHIVE_MIRROR} to download Maven" 1>&2 + install_app \ + "${ARCHIVE_MIRROR}/${FILE_PATH}" \ + "${MVN_TARBALL}" \ + "apache-maven-${MVN_VERSION}/bin/mvn" \ + "" || { echo "ERROR: Failed to download Maven; please install manually." 1>&2; exit 2; } + else + echo "ERROR: Failed to download Maven; please install manually." 1>&2 + exit 2 + fi + fi MVN_BIN="${_DIR}/apache-maven-${MVN_VERSION}/bin/mvn" fi diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 7035478ebf3..ee3bda3b20d 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; @@ -57,6 +58,7 @@ public abstract class ShuffleClient { private static Logger logger = LoggerFactory.getLogger(ShuffleClient.class); private static volatile ShuffleClient _instance; private static volatile boolean initialized = false; + private static volatile String _appUniqueId; private static volatile Map hadoopFs; private static LongAdder totalReadCounter = new LongAdder(); private static LongAdder localShuffleReadCounter = new LongAdder(); @@ -69,6 +71,7 @@ public abstract class ShuffleClient { public static void reset() { _instance = null; initialized = false; + _appUniqueId = null; hadoopFs = null; } @@ -90,7 +93,7 @@ public static ShuffleClient get( CelebornConf conf, UserIdentifier userIdentifier, byte[] extension) { - if (null == _instance || !initialized) { + if (null == _instance || !initialized || !Objects.equals(appUniqueId, _appUniqueId)) { synchronized (ShuffleClient.class) { if (null == _instance) { // During the execution of Spark tasks, each task may be interrupted due to speculative @@ -102,14 +105,33 @@ public static ShuffleClient get( _instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier); _instance.setupLifecycleManagerRef(driverHost, port); _instance.setExtension(extension); + _appUniqueId = appUniqueId; initialized = true; } else if (!initialized) { _instance.shutdown(); _instance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier); _instance.setupLifecycleManagerRef(driverHost, port); _instance.setExtension(extension); + _appUniqueId = appUniqueId; + initialized = true; + } else if (!Objects.equals(appUniqueId, _appUniqueId)) { + // Do NOT shutdown() the old _instance. Callers cache the reference returned by get(), + // and shutdown() is an immediate teardown that would terminate the RpcEnv/pools still in + // use, causing RejectedExecutionException. Teardown is owned by stop()->shutdown(). The + // orphan is bounded (one per appUniqueId) and unreachable in normal single-app JVMs. + // The spark-it suite runs multiple apps in one reused JVM with overlapping lifecycles, so + // a shutdown() here tears down an instance still in use by the previous app and fails. + ShuffleClientImpl newInstance = new ShuffleClientImpl(appUniqueId, conf, userIdentifier); + newInstance.setupLifecycleManagerRef(driverHost, port); + newInstance.setExtension(extension); + // Publish _instance before _appUniqueId. The outer guard reads both volatiles without + // holding the lock, so writing _appUniqueId first would let another thread observe the + // new id while _instance is still stale and return the old instance. + _instance = newInstance; + _appUniqueId = appUniqueId; initialized = true; } + return _instance; } } return _instance; diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 358bc227a16..403d6a3a458 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -921,8 +921,10 @@ Map reviveBatch( StatusCode statusCode = entry.getValue()._1(); if (entry.getValue()._2() != null) { PartitionLocation oldLoc = oldLocMap.get(partitionId); - // Currently, revive only check if main location available, here won't remove peer loc. - pushExcludedWorkers.remove(oldLoc.hostAndPushPort()); + if (oldLoc != null) { + // Currently, revive only check if main location available, here won't remove peer loc. + pushExcludedWorkers.remove(oldLoc.hostAndPushPort()); + } } if (StatusCode.SUCCESS == statusCode) { diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index 263a2a9a649..850740f5ffe 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -34,7 +34,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { protected val celebornConf: CelebornConf = new CelebornConf() - protected val APP = "app-1" + protected var APP: String = _ protected val userIdentifier: UserIdentifier = UserIdentifier("mock", "mock") private val numMappers = 8 private val mapId = 1 @@ -49,6 +49,11 @@ trait WithShuffleClientSuite extends CelebornFunSuite { _shuffleId } + override protected def beforeEach(): Unit = { + super.beforeEach() + APP = s"app-${java.util.UUID.randomUUID()}" + } + override protected def afterEach() { if (lifecycleManager != null) { lifecycleManager.stop() diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala index 5b2dd6a1097..b4296d2bafc 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala @@ -258,6 +258,10 @@ object Utils extends Logging { ScalaRandom.nextInt(until - 1 - from) + from } + val MAX_SELECTABLE_PORT = 32768 + + def selectRandomPort(): Int = selectRandomInt(1024, MAX_SELECTABLE_PORT) + def startServiceOnPort[T]( startPort: Int, startService: Int => (T, Int), diff --git a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java index 01a16a31eea..6038ee4e154 100644 --- a/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java +++ b/master/src/test/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/RatisMasterStatusSystemSuiteJ.java @@ -119,6 +119,19 @@ public static void resetRaftServer( while (!serversStarted) { try { + // Re-point each server to a fresh storage directory on retry. Ratis releases the storage + // directory lock asynchronously on close(), so a failed attempt (e.g. a random ratis port + // collision) can leave the previous directory locked. Reusing the same directory on retry + // then fails with "directory is already locked"; allocating a clean directory each time + // avoids contending for a lock that has not been released yet. Skip this on the first + // attempt: callers already configure a fresh directory when building conf1/2/3, so + // reconfiguring here would orphan that just-created (empty) directory. + if (retryCount > 0) { + configureServerConf(conf1, 1); + configureServerConf(conf2, 2); + configureServerConf(conf3, 3); + } + STATUSSYSTEM1 = new HAMasterMetaManager(mockRpcEnv, conf1); STATUSSYSTEM2 = new HAMasterMetaManager(mockRpcEnv, conf2); STATUSSYSTEM3 = new HAMasterMetaManager(mockRpcEnv, conf3); @@ -131,7 +144,8 @@ public static void resetRaftServer( String id2 = UUID.randomUUID().toString(); String id3 = UUID.randomUUID().toString(); - int ratisPort1 = Utils$.MODULE$.selectRandomInt(1024, 65535); + int ratisPort1 = + Utils$.MODULE$.selectRandomInt(1024, Utils$.MODULE$.MAX_SELECTABLE_PORT() - 2); int ratisPort2 = ratisPort1 + 1; int ratisPort3 = ratisPort2 + 1; diff --git a/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterClusterFeature.scala b/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterClusterFeature.scala index 65995bfc3db..5082a06fab7 100644 --- a/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterClusterFeature.scala +++ b/master/src/test/scala/org/apache/celeborn/service/deploy/master/MasterClusterFeature.scala @@ -50,7 +50,7 @@ trait MasterClusterFeature extends Logging { } } def selectRandomPort(): Int = synchronized { - val port = Utils.selectRandomInt(1024, 65535) + val port = Utils.selectRandomPort() val portUsed = usedPorts.contains(port) || portBounded(port) usedPorts.add(port) if (portUsed) { diff --git a/pom.xml b/pom.xml index b2e1a9cacda..9df892ed3af 100644 --- a/pom.xml +++ b/pom.xml @@ -1315,6 +1315,24 @@ prepare-agent + + + + io/netty/** + + report diff --git a/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HybridShuffleWordCountTest.scala b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HybridShuffleWordCountTest.scala index 6794f23dfcf..ef71860d8c2 100644 --- a/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HybridShuffleWordCountTest.scala +++ b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HybridShuffleWordCountTest.scala @@ -27,7 +27,9 @@ import org.apache.flink.runtime.jobgraph.JobType import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment import org.apache.flink.streaming.api.graph.StreamingJobGraphGenerator import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually._ import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.time.SpanSugar._ import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.internal.Logging @@ -186,12 +188,18 @@ class HybridShuffleWordCountTest extends AnyFunSuite with Logging with MiniClust } private def checkFlushingFileLength(): Unit = { - workers.map(worker => { - worker.storageManager.workingDirWriters.values().asScala.map(writers => { - writers.forEach((fileName, fileWriter) => { - assert(new File(fileName).length() == fileWriter.getDiskFileInfo.getFileLength) + // getDiskFileInfo.getFileLength is the logical byte count accounted as data is written, while + // the physical file is grown asynchronously by the LocalFlusher. Right after the job finishes + // the flusher may not have drained the last buffers yet, so the on-disk length can lag (briefly + // even 0). Wait for the flush to catch up before asserting equality instead of reading mid-flush. + eventually(timeout(30.seconds), interval(500.milliseconds)) { + workers.map(worker => { + worker.storageManager.workingDirWriters.values().asScala.map(writers => { + writers.forEach((fileName, fileWriter) => { + assert(new File(fileName).length() == fileWriter.getDiskFileInfo.getFileLength) + }) }) }) - }) + } } } diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala index 3c0303f4064..6e4b081230a 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ChangePartitionManagerUpdateWorkersSuite.scala @@ -42,6 +42,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite } override def beforeEach(): Unit = { + super.beforeEach() val testConf = Map( s"${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}" -> "3") val (master, _) = setupMiniClusterWithRandomPorts(testConf, testConf, workerNum = 1) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerUnregisterShuffleSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerUnregisterShuffleSuite.scala index d689ead7e0c..c5082bd7323 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerUnregisterShuffleSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerUnregisterShuffleSuite.scala @@ -36,6 +36,13 @@ class LifecycleManagerUnregisterShuffleSuite extends WithShuffleClientSuite celebornConf .set(CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key, "true") .set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE.key, "256K") + // The default expired-check interval is 60s. removeExpiredShuffle only + // unregisters a shuffle once `unregisterTime < now - checkInterval` and runs + // on a fixed-rate timer at that interval, so with 60s the master side cannot + // be cleared until the second tick (~120s) -- exactly the eventually() window + // below, leaving no margin and no retry if an RPC briefly fails under load. + // Use a short interval so the unregister runs promptly with ample retries. + .set(CelebornConf.SHUFFLE_EXPIRED_CHECK_INTERVAL.key, "5s") override def beforeAll(): Unit = { super.beforeAll() diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashCheckDiskSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashCheckDiskSuite.scala index 4f22982bb98..b4c8e615296 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashCheckDiskSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornHashCheckDiskSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.SparkSession import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar.convertIntToGrainOfTime -import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.protocol.ShuffleMode import org.apache.celeborn.service.deploy.worker.Worker @@ -43,10 +42,11 @@ class CelebornHashCheckDiskSuite extends SparkTestBase { } override def beforeEach(): Unit = { - ShuffleClient.reset() + stopActiveSparkSessions() } override def afterEach(): Unit = { + stopActiveSparkSessions() System.gc() } @@ -59,7 +59,7 @@ class CelebornHashCheckDiskSuite extends SparkTestBase { val combineResult = combine(sparkSession) val groupByResult = groupBy(sparkSession) val repartitionResult = repartition(sparkSession) - sparkSession.stop() + stopActiveSparkSessions() val sparkSessionEnableCeleborn = SparkSession.builder() .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ShuffleFallbackSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ShuffleFallbackSuite.scala index 059e7d03373..a1c6468c2f9 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ShuffleFallbackSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ShuffleFallbackSuite.scala @@ -51,7 +51,7 @@ class ShuffleFallbackSuite extends AnyFunSuite } test(s"celeborn spark integration test - fallback") { - setupMiniClusterWithRandomPorts(workerNum = 5) + setupMiniClusterWithRandomPorts(workerNum = 3) val sparkConf = new SparkConf().setAppName("celeborn-demo") .setMaster("local[2]") .set(s"spark.${CelebornConf.SPARK_SHUFFLE_FORCE_FALLBACK_ENABLED.key}", "true") diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala index c857bd67b07..58b94fef1f0 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.internal.SQLConf import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.scalatest.funsuite.AnyFunSuite +import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.CelebornConf._ import org.apache.celeborn.common.internal.Logging @@ -48,14 +49,35 @@ trait SparkTestBase extends AnyFunSuite override def beforeAll(): Unit = { logInfo("test initialized , setup Celeborn mini cluster") - setupMiniClusterWithRandomPorts(workerNum = 5) + // Use 3 workers (the MiniClusterFeature default) rather than 5. The spark-it suites run + // serially in a single JVM (scalatest forkMode=once), so every extra worker multiplies the + // long-lived thread/CPU footprint across the whole module. Under CPU contention on CI runners + // that surplus starves RPC handlers long enough to blow the 240s network timeout, which then + // amplifies through read retries and Spark stage reattempts into multi-minute hangs. 3 workers + // still exercises replication and slot spreading while cutting that footprint. + setupMiniClusterWithRandomPorts(workerNum = 3) } override def afterAll(): Unit = { logInfo("all test complete , stop Celeborn mini cluster") + // Tear down any SparkSession/SparkContext still alive in this JVM before the next suite runs. + // Spark integration suites run sequentially (parallelExecution = false), but a context that is + // not stopped keeps its LifecycleManager and the process-wide static ShuffleClient alive. A + // straggler task from such a leaked context can then bind to a later suite's LifecycleManager + // through the shared client and corrupt celebornShuffleId 0 (ArrayIndexOutOfBoundsException or + // CommitMetadata CRC mismatch). Stopping the context here triggers SparkShuffleManager.stop(), + // which shuts the client down and stops the LifecycleManager. + stopActiveSparkSessions() shutdownMiniCluster() } + protected def stopActiveSparkSessions(): Unit = { + SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession).foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + ShuffleClient.reset() + } + var workerDirs: Seq[String] = Seq.empty def getOneWorker(): Worker = { diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala index adac14242bd..09d9efe623b 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/fetch/failure/ShuffleReaderGetHooks.scala @@ -38,17 +38,15 @@ class ShuffleReaderGetHooks( val lock = new Object private def deleteDataFile(appUniqueId: String, celebornShuffleId: Int): Unit = { - val datafile = + val dataFiles = workerDirs.map(dir => { new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") }).filter(_.exists()) - .flatMap(_.listFiles().iterator).headOption - datafile match { - case Some(file) => { - file.delete() - } - case None => throw new RuntimeException("unexpected, there must be some data file") + .flatMap(_.listFiles().iterator) + if (dataFiles.isEmpty) { + throw new RuntimeException("unexpected, there must be some data file") } + dataFiles.foreach(_.delete()) } override def exec( diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySparkTestBase.scala index 6d4b1e71191..6770d21e2fd 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySparkTestBase.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySparkTestBase.scala @@ -35,11 +35,15 @@ trait MemorySparkTestBase extends AnyFunSuite override def beforeAll(): Unit = { logInfo("test initialized , setup Celeborn mini cluster") val workerConfs = Map("celeborn.worker.directMemoryRatioForMemoryFileStorage" -> "0.2") - setupMiniClusterWithRandomPorts(workerConf = workerConfs, workerNum = 5) + // 3 workers (the MiniClusterFeature default) instead of 5: these memory-storage suites run in + // the same shared, serial spark-it JVM, so trimming the per-suite worker footprint reduces the + // CPU contention that otherwise starves a worker's fetch handler past the 240s network timeout. + setupMiniClusterWithRandomPorts(workerConf = workerConfs, workerNum = 3) } override def afterAll(): Unit = { logInfo("all test complete , stop Celeborn mini cluster") + stopActiveSparkSessions() shutdownMiniCluster() } diff --git a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala index b86244636b1..24e604fd279 100644 --- a/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/spark/shuffle/celeborn/SparkUtilsSuite.scala @@ -70,12 +70,12 @@ class SparkUtilsSuite extends AnyFunSuite val jobThread = new Thread { override def run(): Unit = { try { - val value = Range(1, 10000).mkString(",") + val value = Range(1, 100).mkString(",") sc.parallelize(1 to 10000, 2) .map { i => (i, value) } - .groupByKey(10) + .groupByKey(2) .mapPartitions { iter => - Thread.sleep(3000) + Thread.sleep(500) iter }.collect() } catch { diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/monitor/JVMQuake.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/monitor/JVMQuake.scala index 477b1ad2f72..822ba5faa0d 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/monitor/JVMQuake.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/monitor/JVMQuake.scala @@ -91,9 +91,19 @@ class JVMQuake(conf: CelebornConf, uniqueId: String = UUID.randomUUID().toString val runTimeTicks = currentExitTime - lastExitTime - gcTimeTicks // JVMStat time monitors are reported in ticks. Convert deltas to nanos before comparing // them against JVMQuake thresholds, which are stored as nanos. - val gcTime = ticksToNanos(gcTimeTicks) - val runTime = ticksToNanos(runTimeTicks) + checkAndDump(ticksToNanos(gcTimeTicks), ticksToNanos(runTimeTicks)) + lastExitTime = currentExitTime + lastGCTime = currentGCTime + } + /** + * Updates the GC "deficit" bucket with the latest GC and execution time deltas (in nanos) and + * heap dumps or kills the JVM once the configured thresholds are crossed. Separated from the + * jvmstat counter reads in [[run]] so the threshold logic can be exercised deterministically + * without inducing real GC pressure. + */ + @VisibleForTesting + private[monitor] def checkAndDump(gcTime: Long, runTime: Long): Unit = { bucket = Math.max(0, bucket + gcTime - (BigDecimal(runTime) * BigDecimal(runtimeWeight)).toLong) logDebug(s"Time: (gc time: ${Utils.nanoDurationToString(gcTime)}, execution time: ${Utils.nanoDurationToString(runTime)})") logDebug( @@ -110,8 +120,6 @@ class JVMQuake(conf: CelebornConf, uniqueId: String = UUID.randomUUID().toString System.exit(exitCode) } } - lastExitTime = currentExitTime - lastGCTime = currentGCTime } def shouldHeapDump: Boolean = { diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala index 2fab21e7184..0b91a21e3b7 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala @@ -63,8 +63,9 @@ trait MiniClusterFeature extends Logging { socket.close() } } + // Ports are drawn below the ephemeral floor to avoid binding races. def selectRandomPort(): Int = synchronized { - val port = Utils.selectRandomInt(1024, 65535) + val port = Utils.selectRandomPort() val portUsed = usedPorts.contains(port) || portBounded(port) usedPorts.add(port) if (portUsed) { @@ -213,7 +214,7 @@ trait MiniClusterFeature extends Logging { val workers = new Array[Worker](workerNum) val flagUpdateLock = new ReentrantLock() val threads = (1 to workerNum).map { i => - val worker = createWorker(workerConf) + var worker = createWorker(workerConf) val workerThread = new RunnerWrap({ var workerStartRetry = 0 var workerStarted = false @@ -225,10 +226,15 @@ trait MiniClusterFeature extends Logging { workerStarted = true worker.initialize() } catch { + case ie: InterruptedException => + Utils.tryLogNonFatalError(worker.stop(CelebornExitKind.EXIT_IMMEDIATELY)) + Utils.tryLogNonFatalError(worker.rpcEnv.shutdown()) + Thread.currentThread().interrupt() + throw ie case ex: Exception => - if (workers(i - 1) != null) { - workers(i - 1).shutdownGracefully() - } + Utils.tryLogNonFatalError(worker.exitImmediately()) + Utils.tryLogNonFatalError(worker.stop(CelebornExitKind.EXIT_IMMEDIATELY)) + Utils.tryLogNonFatalError(worker.rpcEnv.shutdown()) workerStarted = false workerStartRetry += 1 logError(s"cannot start worker $i, retrying: ", ex) @@ -236,6 +242,14 @@ trait MiniClusterFeature extends Logging { logError(s"cannot start worker $i, reached to max retrying", ex) throw ex } + try { + TimeUnit.SECONDS.sleep(Math.pow(2, workerStartRetry).toLong) + } catch { + case ie: InterruptedException => + Thread.currentThread().interrupt() + throw ie + } + worker = createWorker(workerConf) } } }) diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/monitor/JVMQuakeSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/monitor/JVMQuakeSuite.scala index fa3b7f36bea..94b86811c0d 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/monitor/JVMQuakeSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/monitor/JVMQuakeSuite.scala @@ -18,8 +18,7 @@ package org.apache.celeborn.service.deploy.worker.monitor import java.io.File - -import scala.collection.mutable.ArrayBuffer +import java.util.concurrent.TimeUnit import org.junit.Assert.assertTrue @@ -30,13 +29,6 @@ import org.apache.celeborn.common.util.JavaUtils class JVMQuakeSuite extends CelebornFunSuite { - private val allocation = new ArrayBuffer[Array[Byte]]() - - override def afterEach(): Unit = { - allocation.clear() - System.gc() - } - test("Convert JVMStat timer ticks to nanoseconds") { assert(JVMQuake.ticksToNanos(1L, 1000000000L) === 1L) assert(JVMQuake.ticksToNanos(1000L, 1000L) === 1000000000L) @@ -53,9 +45,14 @@ class JVMQuakeSuite extends CelebornFunSuite { .set(WORKER_JVM_QUAKE_RUNTIME_WEIGHT.key, "1") .set(WORKER_JVM_QUAKE_DUMP_THRESHOLD.key, "1s") .set(WORKER_JVM_QUAKE_KILL_THRESHOLD.key, "2s")) - quake.start() - allocateMemory(quake) - quake.stop() + + // Drive the GC "deficit" bucket deterministically rather than inducing real GC pressure: + // feed a GC-time delta above the 1s dump threshold with no offsetting execution time, so the + // heap dump is triggered exactly once. The previous version spun until real GC happened to + // trip the threshold, which could (and did) hang indefinitely when the runner had enough + // headroom that GC pauses never dominated runtime. + assert(!quake.heapDumped) + quake.checkAndDump(TimeUnit.SECONDS.toNanos(2), 0L) assertTrue(quake.heapDumped) val heapDump = new File(s"${quake.getHeapDumpSavePath}/${quake.dumpFile}") @@ -64,17 +61,10 @@ class JVMQuakeSuite extends CelebornFunSuite { JavaUtils.deleteRecursively(new File(quake.getHeapDumpLinkPath)) } - def allocateMemory(quake: JVMQuake): Unit = { - val capacity = 1024 * 100 - while (allocation.size * capacity < Runtime.getRuntime.maxMemory / 4) { - val bytes = new Array[Byte](capacity) - allocation.append(bytes) - } - while (quake.shouldHeapDump) { - for (index <- allocation.indices) { - val bytes = new Array[Byte](capacity) - allocation(index) = bytes - } - } + test("start() schedules monitoring and stop() tears it down without dumping") { + val quake = new JVMQuake(new CelebornConf().set(WORKER_JVM_QUAKE_ENABLED.key, "true")) + quake.start() + quake.stop() + assert(!quake.heapDumped) } }