From 7fc950ac5a7d932c234f0250289d9cfa24b16c3d Mon Sep 17 00:00:00 2001 From: Junfan Zhang Date: Wed, 24 Jun 2026 11:04:53 +0800 Subject: [PATCH] [#2758] fix(rpc): simplify server port fallback to improve success rate --- .../apache/uniffle/common/rpc/GrpcServer.java | 18 +++--- .../uniffle/common/rpc/ServerInterface.java | 6 +- .../apache/uniffle/common/util/RssUtils.java | 49 ++++++++-------- .../uniffle/common/util/RssUtilsTest.java | 56 +++++++++---------- .../uniffle/server/netty/StreamServer.java | 15 ++--- 5 files changed, 71 insertions(+), 73 deletions(-) diff --git a/common/src/main/java/org/apache/uniffle/common/rpc/GrpcServer.java b/common/src/main/java/org/apache/uniffle/common/rpc/GrpcServer.java index 70ff4a7a67..1e3a29565b 100644 --- a/common/src/main/java/org/apache/uniffle/common/rpc/GrpcServer.java +++ b/common/src/main/java/org/apache/uniffle/common/rpc/GrpcServer.java @@ -17,7 +17,6 @@ package org.apache.uniffle.common.rpc; -import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.concurrent.BlockingQueue; @@ -56,7 +55,7 @@ public class GrpcServer implements ServerInterface { private static volatile boolean poolExecutorHasExecuted; private Server server; - private final int port; + private final int configuredPort; private int listenPort; private final GrpcThreadPoolExecutor pool; private List>> servicesWithInterceptors; @@ -68,7 +67,7 @@ protected GrpcServer( List>> servicesWithInterceptors, GRPCMetrics grpcMetrics) { this.rssConf = conf; - this.port = rssConf.getInteger(RssBaseConf.RPC_SERVER_PORT); + this.configuredPort = rssConf.getInteger(RssBaseConf.RPC_SERVER_PORT); this.servicesWithInterceptors = servicesWithInterceptors; this.grpcMetrics = grpcMetrics; @@ -211,18 +210,18 @@ void correctMetrics() { } @Override - public int start() throws IOException { + public int start() throws Exception { try { this.listenPort = - RssUtils.startServiceOnPort(this, Constants.GRPC_SERVICE_NAME, port, rssConf); + RssUtils.startServiceOnPortWithFallback( + this::startOnPort, Constants.GRPC_SERVICE_NAME, configuredPort); } catch (Exception e) { - ExitUtils.terminate(1, "Fail to start grpc server on conf port:" + port, e, LOG); + ExitUtils.terminate(1, "Fail to start grpc server on conf port:" + configuredPort, e, LOG); } return listenPort; } - @Override - public void startOnPort(int startPort) throws Exception { + private int startOnPort(int startPort) throws Exception { this.server = buildGrpcServer(startPort); try { server.start(); @@ -230,7 +229,8 @@ public void startOnPort(int startPort) throws Exception { } catch (Exception e) { throw e; } - LOG.info("Grpc server started, configured port: {}, listening on {}.", port, listenPort); + LOG.info("Grpc server started, start port: {}, listening on {}.", startPort, listenPort); + return listenPort; } @Override diff --git a/common/src/main/java/org/apache/uniffle/common/rpc/ServerInterface.java b/common/src/main/java/org/apache/uniffle/common/rpc/ServerInterface.java index ab6a91b535..54dcccd4ce 100644 --- a/common/src/main/java/org/apache/uniffle/common/rpc/ServerInterface.java +++ b/common/src/main/java/org/apache/uniffle/common/rpc/ServerInterface.java @@ -17,13 +17,9 @@ package org.apache.uniffle.common.rpc; -import java.io.IOException; - public interface ServerInterface { - int start() throws IOException; - - void startOnPort(int port) throws Exception; + int start() throws Exception; void stop() throws InterruptedException; diff --git a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java index 925f4601c1..3e558c1af1 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java +++ b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java @@ -60,7 +60,6 @@ import org.apache.uniffle.common.config.RssBaseConf; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; -import org.apache.uniffle.common.rpc.ServerInterface; import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_EXTRA_JAVA_SYSTEM_PROPERTIES; @@ -69,6 +68,11 @@ public class RssUtils { private static final Logger LOGGER = LoggerFactory.getLogger(RssUtils.class); public static final String RSS_LOCAL_DIR_KEY = "RSS_LOCAL_DIRS"; + @FunctionalInterface + public interface ServiceStarter { + int startOnPort(int port) throws Exception; + } + private RssUtils() {} /** Load properties present in the given file. */ @@ -184,37 +188,34 @@ public static String getHostIp() throws Exception { return siteLocalAddress; } - public static int startServiceOnPort( - ServerInterface service, String serviceName, int servicePort, RssBaseConf conf) { + public static int startServiceOnPortWithFallback( + ServiceStarter serviceStarter, String serviceName, int servicePort) { if (servicePort < 0 || servicePort > 65535) { throw new IllegalArgumentException( String.format("Bad service %s on port (%s)", serviceName, servicePort)); } - int actualPort = servicePort; - int maxRetries = conf.get(RssBaseConf.SERVER_PORT_MAX_RETRIES); - for (int i = 0; i < maxRetries; i++) { - try { - if (servicePort == 0) { - actualPort = findRandomTcpPort(conf); - } else { - actualPort += i; - } - service.startOnPort(actualPort); - return actualPort; - } catch (Exception e) { - if (isServerPortBindCollision(e)) { - LOGGER.warn( - String.format( - "%s:Service %s failed after %s retries (on a random free port (%s))!", - e.getMessage(), serviceName, i + 1, actualPort)); - } else { + try { + return serviceStarter.startOnPort(servicePort); + } catch (Exception e) { + if (servicePort > 0 && isServerPortBindCollision(e)) { + LOGGER.warn( + "{}: Service {} failed to bind on port {}, falling back to port 0.", + e.getMessage(), + serviceName, + servicePort); + try { + return serviceStarter.startOnPort(0); + } catch (Exception fallbackException) { throw new RssException( - String.format("Failed to start service %s on port %s", serviceName, servicePort), e); + String.format( + "Failed to start service %s on port 0 after falling back from port %s", + serviceName, servicePort), + fallbackException); } } + throw new RssException( + String.format("Failed to start service %s on port %s", serviceName, servicePort), e); } - throw new RssException( - String.format("Failed to start service %s on port %s", serviceName, servicePort)); } /** check whether the exception is caused by an address-port collision when binding. */ diff --git a/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java index 1d2bed8c6d..9e632f80f2 100644 --- a/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java +++ b/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java @@ -102,19 +102,16 @@ public void testGetHostIp() { } @Test - public void testStartServiceOnPort() throws InterruptedException { - RssBaseConf rssBaseConf = new RssBaseConf(); - rssBaseConf.set(RssBaseConf.SERVER_PORT_MAX_RETRIES, 100); - rssBaseConf.set(RssBaseConf.RSS_RANDOM_PORT_MIN, 30000); - rssBaseConf.set(RssBaseConf.RSS_RANDOM_PORT_MAX, 39999); - // zero port to get random port + public void testStartServiceOnPortWithFallback() throws InterruptedException { + // zero port should be delegated to the service, so the OS assigns an available port. MockServer mockServer = new MockServer(); int port = 0; try { - int actualPort = RssUtils.startServiceOnPort(mockServer, "MockServer", port, rssBaseConf); - assertTrue( - actualPort >= 30000 - && actualPort < 39999 + rssBaseConf.get(RssBaseConf.SERVER_PORT_MAX_RETRIES)); + int actualPort = + RssUtils.startServiceOnPortWithFallback(mockServer::startOnPort, "MockServer", port); + assertEquals(0, mockServer.startPort); + assertEquals(mockServer.serverSocket.getLocalPort(), actualPort); + assertTrue(actualPort > 0); } finally { if (mockServer != null) { mockServer.stop(); @@ -123,7 +120,7 @@ public void testStartServiceOnPort() throws InterruptedException { // error port test try { port = -1; - RssUtils.startServiceOnPort(mockServer, "MockServer", port, rssBaseConf); + RssUtils.startServiceOnPortWithFallback(mockServer::startOnPort, "MockServer", port); } catch (RuntimeException e) { assertTrue(e.toString().startsWith("java.lang.IllegalArgumentException: Bad service")); } @@ -131,11 +128,9 @@ public void testStartServiceOnPort() throws InterruptedException { try { mockServer = new MockServer(); port = 10000; - rssBaseConf.set(RssBaseConf.SERVER_PORT_MAX_RETRIES, 100); - int actualPort = RssUtils.startServiceOnPort(mockServer, "MockServer", port, rssBaseConf); - assertTrue( - actualPort >= port - && actualPort < port + rssBaseConf.get(RssBaseConf.SERVER_PORT_MAX_RETRIES)); + int actualPort = + RssUtils.startServiceOnPortWithFallback(mockServer::startOnPort, "MockServer", port); + assertEquals(port, actualPort); } finally { if (mockServer != null) { mockServer.stop(); @@ -147,17 +142,21 @@ public void testStartServiceOnPort() throws InterruptedException { try { mockServer = new MockServer(); port = 10000; - int actualPort1 = RssUtils.startServiceOnPort(mockServer, "MockServer", port, rssBaseConf); - rssBaseConf.set(RssBaseConf.SERVER_PORT_MAX_RETRIES, 10); + int actualPort1 = + RssUtils.startServiceOnPortWithFallback(mockServer::startOnPort, "MockServer", port); int actualPort2 = - RssUtils.startServiceOnPort(toStartSockServer, "MockServer", actualPort1, rssBaseConf); - assertTrue(actualPort1 < actualPort2); + RssUtils.startServiceOnPortWithFallback( + toStartSockServer::startOnPort, "MockServer", actualPort1); + assertEquals(0, toStartSockServer.startPort); + assertEquals(toStartSockServer.serverSocket.getLocalPort(), actualPort2); + assertTrue(actualPort2 > 0); + assertTrue(actualPort1 != actualPort2); toStartSockServer.stop(); - rssBaseConf.set(RssBaseConf.SERVER_PORT_MAX_RETRIES, 0); - RssUtils.startServiceOnPort(toStartSockServer, "MockServer", actualPort1, rssBaseConf); - assertFalse(false); + toStartSockServer = new MockServer(); + RssUtils.startServiceOnPortWithFallback(toStartSockServer::startOnPort, "MockServer", 0); + assertEquals(0, toStartSockServer.startPort); } catch (RuntimeException e) { - assertTrue(e.getMessage().startsWith("Failed to start service")); + fail(e.getMessage()); } finally { if (mockServer != null) { mockServer.stop(); @@ -351,15 +350,15 @@ public String get() { public static class MockServer implements ServerInterface { ServerSocket serverSocket; + int startPort = -1; @Override public int start() throws IOException { - // not implement - return -1; + return startOnPort(0); } - @Override - public void startOnPort(int port) throws IOException { + int startOnPort(int port) throws IOException { + startPort = port; serverSocket = ServerSocketFactory.getDefault() .createServerSocket(port, 1, InetAddress.getByName("localhost")); @@ -374,6 +373,7 @@ public void startOnPort(int port) throws IOException { } }) .start(); + return serverSocket.getLocalPort(); } @Override diff --git a/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java b/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java index 23ccaff382..7271e7c9ca 100644 --- a/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java +++ b/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java @@ -17,7 +17,7 @@ package org.apache.uniffle.server.netty; -import java.io.IOException; +import java.net.InetSocketAddress; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -171,20 +171,19 @@ public void initChannel(final SocketChannel ch) { } @Override - public int start() throws IOException { + public int start() throws Exception { int port = shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_PORT); try { port = - RssUtils.startServiceOnPort( - this, Constants.NETTY_STREAM_SERVICE_NAME, port, shuffleServerConf); + RssUtils.startServiceOnPortWithFallback( + this::startOnPort, Constants.NETTY_STREAM_SERVICE_NAME, port); } catch (Exception e) { ExitUtils.terminate(1, "Fail to start stream server", e, LOG); } return port; } - @Override - public void startOnPort(int port) throws Exception { + private int startOnPort(int port) throws Exception { ServerBootstrap serverBootstrap = bootstrapChannel( @@ -201,7 +200,9 @@ public void startOnPort(int port) throws Exception { channelFuture = serverBootstrap.bind(port); channelFuture.syncUninterruptibly(); LOG.info("bind localAddress is " + channelFuture.channel().localAddress()); - LOG.info("Start stream server successfully with port " + port); + int actualPort = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); + LOG.info("Start stream server successfully with port " + actualPort); + return actualPort; } catch (Exception e) { throw e; }