diff --git a/service/src/main/scala/org/apache/celeborn/server/common/HttpService.scala b/service/src/main/scala/org/apache/celeborn/server/common/HttpService.scala index a66f53305dd..958a7fb73d0 100644 --- a/service/src/main/scala/org/apache/celeborn/server/common/HttpService.scala +++ b/service/src/main/scala/org/apache/celeborn/server/common/HttpService.scala @@ -195,6 +195,11 @@ abstract class HttpService extends Service with Logging { def updateInterruptionNotice(workerInterruptionNotices: Map[String, Long]): HandleResponse = throw new UnsupportedOperationException() + def getServingState(): String = throw new UnsupportedOperationException() + + def setServingState(state: String, timeoutMs: String): String = + throw new UnsupportedOperationException() + def startHttpServer(): Unit = { httpServer = HttpServer( serviceName, diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java index 0bdf20a7569..87042fdf619 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java @@ -79,6 +79,8 @@ public class MemoryManager { private long pausePushDataAndReplicateTime = 0L; private int trimCounter = 0; private volatile boolean isPaused = false; + private volatile ServingState forcedServingState = null; + private volatile long forcedServingStateExpireTime = -1L; // -1 means no expiry // For credit stream private final AtomicLong readBufferCounter = new AtomicLong(0); private long readBufferThreshold; @@ -307,6 +309,15 @@ public boolean shouldEvict(boolean aggressiveMemoryFileEvictEnabled, double evic } public ServingState currentServingState() { + if (forcedServingState != null) { + if (forcedServingStateExpireTime > 0 + && System.currentTimeMillis() > forcedServingStateExpireTime) { + this.clearForcedServingState(); + } else { + return forcedServingState; + } + } + long memoryUsage = getMemoryUsage(); // pause replicate threshold always greater than pause push data threshold // so when trigger pause replicate, pause both push and replicate @@ -587,6 +598,30 @@ public void releaseMemoryFileStorage(int bytes) { memoryFileStorageCounter.add(-1 * bytes); } + public ServingState getServingState() { + return servingState; + } + + public ServingState getForcedServingState() { + return forcedServingState; + } + + public void forceServingState(ServingState state, Long timeoutMs) { + this.forcedServingState = state; + this.forcedServingStateExpireTime = + timeoutMs > 0 ? System.currentTimeMillis() + timeoutMs : -1L; + logger.info( + "Serving state manually forced to {} with forcedServingStateExpireTime {}", + state, + timeoutMs); + } + + public void clearForcedServingState() { + this.forcedServingState = null; + this.forcedServingStateExpireTime = -1L; + logger.info("Forced serving state override cleared"); + } + public void close() { checkService.shutdown(); reportService.shutdown(); diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala index da2cab1c349..0494619fc30 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala @@ -947,6 +947,51 @@ private[celeborn] class Worker( sb.toString() } + override def getServingState(): String = { + val sb = new StringBuilder + sb.append("====================== Worker Serving State ==========================\n") + val current = memoryManager.getServingState + sb.append(s"Current state: $current.\n") + + val forced = memoryManager.getForcedServingState + if (forced != null) { + sb.append(s"Manual override active.\n") + } + sb.toString() + } + + override def setServingState(state: String, timeoutStr: String): String = { + val sb = new StringBuilder + sb.append("====================== Set Serving State ============================\n") + if (state.isEmpty) { + memoryManager.clearForcedServingState() + sb.append("Manual servingState override cleared.\n") + return sb.toString() + } + val servingState = + try { + ServingState.valueOf(state.toUpperCase(Locale.ROOT)) + } catch { + case _: IllegalArgumentException => + return s"Invalid state '$state'. " + + s"Legal values: PUSH_AND_REPLICATE_PAUSED, PUSH_PAUSED, NONE_PAUSED\n" + } + val timeout = + if (timeoutStr.isEmpty) 0L + else + try { + JavaUtils.timeStringAsMs(timeoutStr) + } catch { + case e: NumberFormatException => + return s"Invalid timeout '$timeoutStr'. $e\n" + } + memoryManager.forceServingState(servingState, timeout) + sb.append(s"Serving state forced to: $servingState\n") + if (timeout > 0) sb.append(s"Override will auto-clear after $timeoutStr.\n") + else sb.append("Override will persist until explicitly cleared.\n") + sb.toString() + } + override def exit(exitType: String): String = { exitType.toUpperCase(Locale.ROOT) match { case "DECOMMISSION" => diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/http/api/ApiWorkerResource.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/http/api/ApiWorkerResource.scala index 087bc89e327..07597d0b470 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/http/api/ApiWorkerResource.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/http/api/ApiWorkerResource.scala @@ -88,4 +88,32 @@ class ApiWorkerResource extends ApiRequestContext { def exit(@FormParam("type") exitType: String): String = { httpService.exit(normalizeParam(exitType)) } + + @Path("/servingState") + @ApiResponse( + responseCode = "200", + content = Array(new Content( + mediaType = MediaType.TEXT_PLAIN)), + description = + "Show the current serving state and whether a manual override is active.") + @GET + def getServingState(): String = httpService.getServingState() + + @Path("/servingState") + @ApiResponse( + responseCode = "200", + content = Array(new Content( + mediaType = MediaType.APPLICATION_FORM_URLENCODED)), + description = + "Force the worker serving state. " + + "Legal values for 'state' are 'PUSH_AND_REPLICATE_PAUSED', 'PUSH_PAUSED' and 'NONE_PAUSED'," + + " or empty to clear the override. " + + "Optional 'timeoutMs' auto-clears the override after the given duration; omit to hold indefinitely.") + @POST + def setServingState( + @FormParam("state") state: String, + @FormParam("timeout") timeoutStr: String): String = { + httpService.setServingState(normalizeParam(state), normalizeParam(timeoutStr)) + } + }