Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.uniffle.common.rpc;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
Expand Down Expand Up @@ -56,7 +55,7 @@ public class GrpcServer implements ServerInterface {

private static volatile boolean poolExecutorHasExecuted;
private Server server;
private final int port;
private final int configuredPort;
private int listenPort;
private final GrpcThreadPoolExecutor pool;
private List<Pair<BindableService, List<ServerInterceptor>>> servicesWithInterceptors;
Expand All @@ -68,7 +67,7 @@ protected GrpcServer(
List<Pair<BindableService, List<ServerInterceptor>>> servicesWithInterceptors,
GRPCMetrics grpcMetrics) {
this.rssConf = conf;
this.port = rssConf.getInteger(RssBaseConf.RPC_SERVER_PORT);
this.configuredPort = rssConf.getInteger(RssBaseConf.RPC_SERVER_PORT);
this.servicesWithInterceptors = servicesWithInterceptors;
this.grpcMetrics = grpcMetrics;

Expand Down Expand Up @@ -211,26 +210,27 @@ void correctMetrics() {
}

@Override
public int start() throws IOException {
public int start() throws Exception {
try {
this.listenPort =
RssUtils.startServiceOnPort(this, Constants.GRPC_SERVICE_NAME, port, rssConf);
RssUtils.startServiceOnPortWithFallback(
this::startOnPort, Constants.GRPC_SERVICE_NAME, configuredPort);
} catch (Exception e) {
ExitUtils.terminate(1, "Fail to start grpc server on conf port:" + port, e, LOG);
ExitUtils.terminate(1, "Fail to start grpc server on conf port:" + configuredPort, e, LOG);
}
return listenPort;
}

@Override
public void startOnPort(int startPort) throws Exception {
private int startOnPort(int startPort) throws Exception {
this.server = buildGrpcServer(startPort);
try {
server.start();
listenPort = server.getPort();
} catch (Exception e) {
throw e;
}
LOG.info("Grpc server started, configured port: {}, listening on {}.", port, listenPort);
LOG.info("Grpc server started, start port: {}, listening on {}.", startPort, listenPort);
return listenPort;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,9 @@

package org.apache.uniffle.common.rpc;

import java.io.IOException;

public interface ServerInterface {

int start() throws IOException;

void startOnPort(int port) throws Exception;
int start() throws Exception;

void stop() throws InterruptedException;

Expand Down
49 changes: 25 additions & 24 deletions common/src/main/java/org/apache/uniffle/common/util/RssUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
import org.apache.uniffle.common.config.RssBaseConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.ServerInterface;

import static org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_EXTRA_JAVA_SYSTEM_PROPERTIES;

Expand All @@ -69,6 +68,11 @@ public class RssUtils {
private static final Logger LOGGER = LoggerFactory.getLogger(RssUtils.class);
public static final String RSS_LOCAL_DIR_KEY = "RSS_LOCAL_DIRS";

@FunctionalInterface
public interface ServiceStarter {
int startOnPort(int port) throws Exception;
}

private RssUtils() {}

/** Load properties present in the given file. */
Expand Down Expand Up @@ -184,37 +188,34 @@ public static String getHostIp() throws Exception {
return siteLocalAddress;
}

public static int startServiceOnPort(
ServerInterface service, String serviceName, int servicePort, RssBaseConf conf) {
public static int startServiceOnPortWithFallback(
ServiceStarter serviceStarter, String serviceName, int servicePort) {
if (servicePort < 0 || servicePort > 65535) {
throw new IllegalArgumentException(
String.format("Bad service %s on port (%s)", serviceName, servicePort));
}
int actualPort = servicePort;
int maxRetries = conf.get(RssBaseConf.SERVER_PORT_MAX_RETRIES);
for (int i = 0; i < maxRetries; i++) {
try {
if (servicePort == 0) {
actualPort = findRandomTcpPort(conf);
} else {
actualPort += i;
}
service.startOnPort(actualPort);
return actualPort;
} catch (Exception e) {
if (isServerPortBindCollision(e)) {
LOGGER.warn(
String.format(
"%s:Service %s failed after %s retries (on a random free port (%s))!",
e.getMessage(), serviceName, i + 1, actualPort));
} else {
try {
return serviceStarter.startOnPort(servicePort);
} catch (Exception e) {
if (servicePort > 0 && isServerPortBindCollision(e)) {
LOGGER.warn(
"{}: Service {} failed to bind on port {}, falling back to port 0.",
e.getMessage(),
serviceName,
servicePort);
try {
return serviceStarter.startOnPort(0);
} catch (Exception fallbackException) {
throw new RssException(
String.format("Failed to start service %s on port %s", serviceName, servicePort), e);
String.format(
"Failed to start service %s on port 0 after falling back from port %s",
serviceName, servicePort),
fallbackException);
}
}
throw new RssException(
String.format("Failed to start service %s on port %s", serviceName, servicePort), e);
}
throw new RssException(
String.format("Failed to start service %s on port %s", serviceName, servicePort));
}

/** check whether the exception is caused by an address-port collision when binding. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,16 @@ public void testGetHostIp() {
}

@Test
public void testStartServiceOnPort() throws InterruptedException {
RssBaseConf rssBaseConf = new RssBaseConf();
rssBaseConf.set(RssBaseConf.SERVER_PORT_MAX_RETRIES, 100);
rssBaseConf.set(RssBaseConf.RSS_RANDOM_PORT_MIN, 30000);
rssBaseConf.set(RssBaseConf.RSS_RANDOM_PORT_MAX, 39999);
// zero port to get random port
public void testStartServiceOnPortWithFallback() throws InterruptedException {
// zero port should be delegated to the service, so the OS assigns an available port.
MockServer mockServer = new MockServer();
int port = 0;
try {
int actualPort = RssUtils.startServiceOnPort(mockServer, "MockServer", port, rssBaseConf);
assertTrue(
actualPort >= 30000
&& actualPort < 39999 + rssBaseConf.get(RssBaseConf.SERVER_PORT_MAX_RETRIES));
int actualPort =
RssUtils.startServiceOnPortWithFallback(mockServer::startOnPort, "MockServer", port);
assertEquals(0, mockServer.startPort);
assertEquals(mockServer.serverSocket.getLocalPort(), actualPort);
assertTrue(actualPort > 0);
} finally {
if (mockServer != null) {
mockServer.stop();
Expand All @@ -123,19 +120,17 @@ public void testStartServiceOnPort() throws InterruptedException {
// error port test
try {
port = -1;
RssUtils.startServiceOnPort(mockServer, "MockServer", port, rssBaseConf);
RssUtils.startServiceOnPortWithFallback(mockServer::startOnPort, "MockServer", port);
} catch (RuntimeException e) {
assertTrue(e.toString().startsWith("java.lang.IllegalArgumentException: Bad service"));
}
// a specific port to start
try {
mockServer = new MockServer();
port = 10000;
rssBaseConf.set(RssBaseConf.SERVER_PORT_MAX_RETRIES, 100);
int actualPort = RssUtils.startServiceOnPort(mockServer, "MockServer", port, rssBaseConf);
assertTrue(
actualPort >= port
&& actualPort < port + rssBaseConf.get(RssBaseConf.SERVER_PORT_MAX_RETRIES));
int actualPort =
RssUtils.startServiceOnPortWithFallback(mockServer::startOnPort, "MockServer", port);
assertEquals(port, actualPort);
} finally {
if (mockServer != null) {
mockServer.stop();
Expand All @@ -147,17 +142,21 @@ public void testStartServiceOnPort() throws InterruptedException {
try {
mockServer = new MockServer();
port = 10000;
int actualPort1 = RssUtils.startServiceOnPort(mockServer, "MockServer", port, rssBaseConf);
rssBaseConf.set(RssBaseConf.SERVER_PORT_MAX_RETRIES, 10);
int actualPort1 =
RssUtils.startServiceOnPortWithFallback(mockServer::startOnPort, "MockServer", port);
int actualPort2 =
RssUtils.startServiceOnPort(toStartSockServer, "MockServer", actualPort1, rssBaseConf);
assertTrue(actualPort1 < actualPort2);
RssUtils.startServiceOnPortWithFallback(
toStartSockServer::startOnPort, "MockServer", actualPort1);
assertEquals(0, toStartSockServer.startPort);
assertEquals(toStartSockServer.serverSocket.getLocalPort(), actualPort2);
assertTrue(actualPort2 > 0);
assertTrue(actualPort1 != actualPort2);
toStartSockServer.stop();
rssBaseConf.set(RssBaseConf.SERVER_PORT_MAX_RETRIES, 0);
RssUtils.startServiceOnPort(toStartSockServer, "MockServer", actualPort1, rssBaseConf);
assertFalse(false);
toStartSockServer = new MockServer();
RssUtils.startServiceOnPortWithFallback(toStartSockServer::startOnPort, "MockServer", 0);
assertEquals(0, toStartSockServer.startPort);
} catch (RuntimeException e) {
assertTrue(e.getMessage().startsWith("Failed to start service"));
fail(e.getMessage());
} finally {
if (mockServer != null) {
mockServer.stop();
Expand Down Expand Up @@ -351,15 +350,15 @@ public String get() {
public static class MockServer implements ServerInterface {

ServerSocket serverSocket;
int startPort = -1;

@Override
public int start() throws IOException {
// not implement
return -1;
return startOnPort(0);
}

@Override
public void startOnPort(int port) throws IOException {
int startOnPort(int port) throws IOException {
startPort = port;
serverSocket =
ServerSocketFactory.getDefault()
.createServerSocket(port, 1, InetAddress.getByName("localhost"));
Expand All @@ -374,6 +373,7 @@ public void startOnPort(int port) throws IOException {
}
})
.start();
return serverSocket.getLocalPort();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.uniffle.server.netty;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -171,20 +171,19 @@ public void initChannel(final SocketChannel ch) {
}

@Override
public int start() throws IOException {
public int start() throws Exception {
int port = shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_PORT);
try {
port =
RssUtils.startServiceOnPort(
this, Constants.NETTY_STREAM_SERVICE_NAME, port, shuffleServerConf);
RssUtils.startServiceOnPortWithFallback(
this::startOnPort, Constants.NETTY_STREAM_SERVICE_NAME, port);
} catch (Exception e) {
ExitUtils.terminate(1, "Fail to start stream server", e, LOG);
}
return port;
}

@Override
public void startOnPort(int port) throws Exception {
private int startOnPort(int port) throws Exception {

ServerBootstrap serverBootstrap =
bootstrapChannel(
Expand All @@ -201,7 +200,9 @@ public void startOnPort(int port) throws Exception {
channelFuture = serverBootstrap.bind(port);
channelFuture.syncUninterruptibly();
LOG.info("bind localAddress is " + channelFuture.channel().localAddress());
LOG.info("Start stream server successfully with port " + port);
int actualPort = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
LOG.info("Start stream server successfully with port " + actualPort);
return actualPort;
} catch (Exception e) {
throw e;
}
Expand Down