diff --git a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala index 8d12c227af6..3b219372113 100644 --- a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala +++ b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpoint.scala @@ -20,6 +20,7 @@ package org.apache.celeborn.common.rpc import org.apache.celeborn.common.exception.CelebornException import org.apache.celeborn.common.network.client.TransportClient import org.apache.celeborn.common.rpc.netty.RemoteNettyRpcCallContext +import org.apache.celeborn.common.util.Utils /** * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be @@ -138,6 +139,11 @@ trait RpcEndpoint { } def checkAuth(context: RpcCallContext, appId: String): Unit = { + // Validate the application id at the single auth chokepoint so every current + // and future RPC handler that calls checkAuth is covered, and so it runs even + // when auth is disabled (clientId == null). This guards the worker against + // path traversal via appId (e.g. "../foo") before any filesystem path is built. + Utils.validateAppId(appId) context match { case remoteContext: RemoteNettyRpcCallContext => checkAuth(remoteContext.transportClient, appId) diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala index 5b2dd6a1097..2c4f3e310e4 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala @@ -700,6 +700,18 @@ object Utils extends Logging { (appId, shuffleId) } + private val appIdPattern = "[A-Za-z0-9_-]+".r.pattern + + def validateAppId(applicationId: String): Unit = { + // matches() anchors the whole input, so a trailing newline (which `$` would + // otherwise tolerate) is rejected along with any other traversal character. + if (applicationId == null || !appIdPattern.matcher(applicationId).matches()) { + throw new IllegalArgumentException( + s"Invalid application id: '$applicationId'. " + + "Application id must be non-empty and match [A-Za-z0-9_-]+.") + } + } + def splitPartitionLocationUniqueId(uniqueId: String): (Int, Int) = { val splits = uniqueId.split("-") val partitionId = splits.dropRight(1).mkString("-").toInt diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala index 8be472b6447..e59296fc6b9 100644 --- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala @@ -289,4 +289,30 @@ class UtilsSuite extends CelebornFunSuite { celebornConf) assert(testInstance.isInstanceOf[DefaultIdentityProvider]) } + + test("validateAppId rejects path traversal and accepts valid ids") { + Seq( + "application_1234567890123_0001", + "local-1234567890123", + "app1", + "my-app-id", + "app_with_underscores").foreach { id => + Utils.validateAppId(id) + } + + Seq( + "../etc/passwd", + "app/../secret", + "app/id", + "app\\id", + "app id", + "app\n", + "valid_app\n", + "", + null).foreach { id => + intercept[IllegalArgumentException] { + Utils.validateAppId(id) + } + } + } } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala index 9a2d4a8a740..a8341659510 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/StorageManager.scala @@ -1161,6 +1161,9 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs throw new IOException(s"No available disks! suggested mountPoint $suggestedMountPoint") } + // NOTE: the DFS branches below (HDFS/S3/OSS) also build "$appId/$shuffleId" + // paths but rely solely on the upstream Utils.validateAppId guard at the RPC + // entry points if (storageType == Type.HDFS && location.getStorageInfo.HDFSAvailable()) { val shuffleDir = new Path(new Path(hdfsDir, conf.workerWorkingDir), s"$appId/$shuffleId") @@ -1216,8 +1219,15 @@ final private[worker] class StorageManager(conf: CelebornConf, workerSource: Abs val dir = dirs(getNextIndex % dirs.size) val mountPoint = DeviceInfo.getMountPoint(dir.getAbsolutePath, mountPoints) val shuffleDir = new File(dir, s"$appId/$shuffleId") - shuffleDir.mkdirs() val file = new File(shuffleDir, fileName) + // Defense in depth: ensure the resolved path stays under the working dir + // even if appId / shuffleId / fileName contained traversal characters. + val dirCanonical = dir.getCanonicalPath + File.separator + if (!file.getCanonicalPath.startsWith(dirCanonical)) { + throw new IOException( + s"Refusing to create shuffle file outside working dir: ${file.getCanonicalPath}") + } + shuffleDir.mkdirs() try { if (file.exists()) { throw new FileAlreadyExistsException(