Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker
@GuardedBy("metadataLock")
Comment thread
parveensania marked this conversation as resolved.
private long pendingMetadataVersion;

@GuardedBy("this")
private WindmillEndpoints.Type activeMetadataType;

@GuardedBy("metadataLock")
private WindmillEndpoints.Type pendingMetadataType;

@GuardedBy("this")
private boolean started;

Expand Down Expand Up @@ -141,6 +147,8 @@ private FanOutStreamingEngineWorkerHarness(
this.getWorkBudgetDistributor = getWorkBudgetDistributor;
this.totalGetWorkBudget = totalGetWorkBudget;
this.activeMetadataVersion = Long.MIN_VALUE;
this.activeMetadataType = WindmillEndpoints.Type.UNKNOWN;
this.pendingMetadataType = WindmillEndpoints.Type.UNKNOWN;
this.workCommitterFactory = workCommitterFactory;
}

Expand Down Expand Up @@ -271,8 +279,14 @@ private void consumeWorkerMetadata(WindmillEndpoints windmillEndpoints) {
synchronized (metadataLock) {
// Only process versions greater than what we currently have to prevent double processing of
// metadata. workerMetadataConsumer is single-threaded so we maintain ordering.
if (windmillEndpoints.version() > pendingMetadataVersion) {
// The endpoints are also consumed if the version is the same but the type of endpoints
// sent by the server has changed.
if (windmillEndpoints.version() > pendingMetadataVersion
|| (windmillEndpoints.version() == pendingMetadataVersion
Comment thread
parveensania marked this conversation as resolved.
Outdated
&& windmillEndpoints.type() != WindmillEndpoints.Type.UNKNOWN
&& windmillEndpoints.type() != pendingMetadataType)) {
pendingMetadataVersion = windmillEndpoints.version();
pendingMetadataType = windmillEndpoints.type();
workerMetadataConsumer.execute(() -> consumeWindmillWorkerEndpoints(windmillEndpoints));
}
}
Expand All @@ -283,16 +297,19 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi
// queued up while a previous version of the windmillEndpoints were being consumed. Only consume
// the endpoints if they are the most current version.
synchronized (metadataLock) {
if (newWindmillEndpoints.version() < pendingMetadataVersion) {
if (newWindmillEndpoints.version() < pendingMetadataVersion
|| newWindmillEndpoints.type() != pendingMetadataType) {
return;
}
}

LOG.debug(
"Consuming new endpoints: {}. previous metadata version: {}, current metadata version: {}",
LOG.info(
"Consuming new endpoints: {}. previous metadata version: {}, current metadata version: {}, previous endpoint type: {}, current endpoint type: {}",
newWindmillEndpoints,
activeMetadataVersion,
newWindmillEndpoints.version());
newWindmillEndpoints.version(),
activeMetadataType,
newWindmillEndpoints.type());
closeStreamsNotIn(newWindmillEndpoints).join();
ImmutableMap<Endpoint, WindmillStreamSender> newStreams =
createAndStartNewStreams(newWindmillEndpoints.windmillEndpoints()).join();
Expand All @@ -305,6 +322,7 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi
backends.set(newBackends);
getWorkBudgetDistributor.distributeBudget(newStreams.values(), totalGetWorkBudget);
activeMetadataVersion = newWindmillEndpoints.version();
activeMetadataType = newWindmillEndpoints.type();
}

/** Close the streams that are no longer valid asynchronously. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,31 @@
*/
@AutoValue
public abstract class WindmillEndpoints {
/** Transport type of a set of Windmill endpoints, derived from the worker-metadata proto. */
public enum Type {
  UNKNOWN,
  CLOUDPATH,
  DIRECTPATH;

  /**
   * Maps the proto {@code EndpointType} onto this enum. Any value other than CLOUDPATH or
   * DIRECTPATH (including unrecognized future values) maps to {@link #UNKNOWN}.
   */
  static Type fromProto(Windmill.WorkerMetadataResponse.EndpointType protoType) {
    if (protoType == Windmill.WorkerMetadataResponse.EndpointType.CLOUDPATH) {
      return CLOUDPATH;
    }
    if (protoType == Windmill.WorkerMetadataResponse.EndpointType.DIRECTPATH) {
      return DIRECTPATH;
    }
    return UNKNOWN;
  }
}

/** Default port used to reach the Windmill service when none is specified. */
public static final int DEFAULT_WINDMILL_SERVICE_PORT = 443;
private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class);
// Sentinel "empty" instance: no windmill or global-data endpoints, Type.UNKNOWN.
// NOTE(review): version is Long.MAX_VALUE — presumably a deliberate sentinel value; confirm how
// the version comparisons in the metadata consumers treat it.
private static final WindmillEndpoints NO_ENDPOINTS =
WindmillEndpoints.builder()
.setVersion(Long.MAX_VALUE)
.setWindmillEndpoints(ImmutableSet.of())
.setGlobalDataEndpoints(ImmutableMap.of())
.setType(Type.UNKNOWN)
.build();

public static WindmillEndpoints none() {
Expand Down Expand Up @@ -75,6 +93,7 @@ public static WindmillEndpoints from(
.setVersion(workerMetadataResponseProto.getMetadataVersion())
.setGlobalDataEndpoints(globalDataServers)
.setWindmillEndpoints(windmillServers)
.setType(Type.fromProto(workerMetadataResponseProto.getEndpointType()))
.build();
}

Expand Down Expand Up @@ -138,6 +157,8 @@ private static Optional<HostAndPort> tryParseDirectEndpointIntoIpV6Address(
/** Version of the endpoints which increases with every modification. */
public abstract long version();

/**
 * Transport type of these endpoints (CLOUDPATH or DIRECTPATH); {@link Type#UNKNOWN} when the
 * server did not specify a recognized type.
 */
public abstract Type type();

/**
* Used by GetData GlobalDataRequest(s) to support Beam side inputs. Returns a map where the key
* is a global data tag and the value is the endpoint where the data associated with the global
Expand Down Expand Up @@ -221,6 +242,8 @@ public abstract static class Builder {
public abstract static class Builder {
public abstract Builder setVersion(long version);

/** Sets the endpoint transport {@link Type} (CLOUDPATH/DIRECTPATH/UNKNOWN). */
public abstract Builder setType(Type type);

public abstract Builder setGlobalDataEndpoints(
ImmutableMap<String, WindmillEndpoints.Endpoint> globalDataServers);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,10 @@ public static GrpcGetWorkerMetadataStream create(
/**
 * Records {@code response} as the latest worker metadata seen on this stream and converts it to
 * {@link WindmillEndpoints}.
 *
 * <p>NOTE(review): the pasted diff interleaves the old implementation (which skipped responses
 * whose metadata version was not strictly newer) with the new one, leaving duplicated statements
 * and an unreachable {@code return Optional.empty();} (a Java compile error). Reconstructed here
 * as the post-change, unconditional-accept form; version/type staleness filtering is presumably
 * now handled by the downstream metadata consumer — confirm against the PR.
 */
private Optional<WindmillEndpoints> extractWindmillEndpointsFrom(
    WorkerMetadataResponse response) {
  synchronized (metadataLock) {
    this.latestResponse = response;
    this.latestResponseReceived = Instant.now();
    return Optional.of(WindmillEndpoints.from(response));
  }
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.beam.runners.dataflow.worker.streaming.harness;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
Expand Down Expand Up @@ -348,6 +349,143 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce
TimeUnit.SECONDS.sleep(WAIT_FOR_METADATA_INJECTIONS_SECONDS);
}

/**
 * Exercises the worker-metadata consumption rules across endpoint-type switches:
 *
 * <ul>
 *   <li>a greater metadata version is always consumed;
 *   <li>the same version with a different (non-UNKNOWN) endpoint type is consumed;
 *   <li>the same version with the same endpoint type is ignored;
 *   <li>endpoints belonging to the replaced type are removed from the active backends.
 * </ul>
 */
@Test
public void testOnNewWorkerMetadata_alternatesConnectivityTypesAndRemovesStaleEndpoints()
    throws InterruptedException {
  String workerToken = "workerToken1";

  WorkerMetadataResponse cloudPathMetadata =
      WorkerMetadataResponse.newBuilder()
          .setMetadataVersion(1)
          .setEndpointType(Windmill.WorkerMetadataResponse.EndpointType.CLOUDPATH)
          .addWorkEndpoints(
              WorkerMetadataResponse.Endpoint.newBuilder()
                  .setBackendWorkerToken(workerToken)
                  .build())
          .putAllGlobalDataEndpoints(DEFAULT)
          .build();
  WorkerMetadataResponse directPathMetadata =
      WorkerMetadataResponse.newBuilder()
          .setMetadataVersion(1)
          .setEndpointType(Windmill.WorkerMetadataResponse.EndpointType.DIRECTPATH)
          .addWorkEndpoints(
              WorkerMetadataResponse.Endpoint.newBuilder()
                  .setBackendWorkerToken(workerToken + "1")
                  .build())
          .addWorkEndpoints(
              WorkerMetadataResponse.Endpoint.newBuilder()
                  .setBackendWorkerToken(workerToken + "2")
                  .build())
          .putAllGlobalDataEndpoints(DEFAULT)
          .build();
  WorkerMetadataResponse directPathMetadata2 =
      WorkerMetadataResponse.newBuilder()
          .setMetadataVersion(1)
          .setEndpointType(Windmill.WorkerMetadataResponse.EndpointType.DIRECTPATH)
          .addWorkEndpoints(
              WorkerMetadataResponse.Endpoint.newBuilder()
                  .setBackendWorkerToken(workerToken + "3")
                  .build())
          .putAllGlobalDataEndpoints(DEFAULT)
          .build();

  TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor());
  fanOutStreamingEngineWorkProvider =
      newFanOutStreamingEngineWorkerHarness(
          GetWorkBudget.builder().setItems(1).setBytes(1).build(),
          getWorkBudgetDistributor,
          noOpProcessWorkItemFn());

  // Sequence: CLOUDPATH -> DIRECTPATH -> CLOUDPATH -> DIRECTPATH.
  // Start with CLOUDPATH (version 1, 1 endpoint).
  // Verifies: version > pendingMetadataVersion triggers consumption.
  fakeGetWorkerMetadataStub.injectWorkerMetadata(cloudPathMetadata);
  verify(getWorkBudgetDistributor, times(1)).distributeBudget(any(), any());
  TimeUnit.SECONDS.sleep(WAIT_FOR_METADATA_INJECTIONS_SECONDS);
  assertThat(fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams()).hasSize(1);
  assertThat(currentWorkerTokens()).contains(workerToken);

  // Switch to DIRECTPATH (same version 1, 2 endpoints, different type).
  // Verifies: a type change at the same version triggers consumption.
  fakeGetWorkerMetadataStub.injectWorkerMetadata(directPathMetadata);
  verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any());
  TimeUnit.SECONDS.sleep(WAIT_FOR_METADATA_INJECTIONS_SECONDS);
  assertThat(fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams().values())
      .hasSize(2);
  // Verifies: the stale CLOUDPATH endpoint is removed.
  Set<String> directPathTokens = currentWorkerTokens();
  assertThat(directPathTokens).contains(workerToken + "1");
  assertThat(directPathTokens).contains(workerToken + "2");
  assertThat(directPathTokens).doesNotContain(workerToken);

  // Switch back to CLOUDPATH (same version 1, 1 endpoint, different type).
  fakeGetWorkerMetadataStub.injectWorkerMetadata(cloudPathMetadata);
  verify(getWorkBudgetDistributor, times(3)).distributeBudget(any(), any());
  TimeUnit.SECONDS.sleep(WAIT_FOR_METADATA_INJECTIONS_SECONDS);
  assertThat(fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams().values())
      .hasSize(1);
  // Verifies: stale DIRECTPATH endpoints are removed.
  Set<String> cloudPathTokens = currentWorkerTokens();
  assertThat(cloudPathTokens).contains(workerToken);
  assertThat(cloudPathTokens).containsNoneOf(workerToken + "1", workerToken + "2");

  // Switch to DIRECTPATH again (same version 1, 2 endpoints, different type).
  // Verifies: the type-change trigger works in both directions.
  fakeGetWorkerMetadataStub.injectWorkerMetadata(directPathMetadata);
  verify(getWorkBudgetDistributor, times(4)).distributeBudget(any(), any());
  TimeUnit.SECONDS.sleep(WAIT_FOR_METADATA_INJECTIONS_SECONDS);
  assertThat(fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams()).hasSize(2);
  directPathTokens = currentWorkerTokens();
  assertThat(directPathTokens).contains(workerToken + "1");
  assertThat(directPathTokens).contains(workerToken + "2");
  assertThat(directPathTokens).doesNotContain(workerToken);

  // Inject DIRECTPATH metadata with the same version AND same type (1 endpoint).
  // Verifies: same version + same type does not trigger consumption; endpoints are unchanged.
  fakeGetWorkerMetadataStub.injectWorkerMetadata(directPathMetadata2);
  verify(getWorkBudgetDistributor, times(4)).distributeBudget(any(), any());
  TimeUnit.SECONDS.sleep(WAIT_FOR_METADATA_INJECTIONS_SECONDS);
  assertThat(fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams()).hasSize(2);
  directPathTokens = currentWorkerTokens();
  assertThat(directPathTokens).contains(workerToken + "1");
  assertThat(directPathTokens).contains(workerToken + "2");
  assertThat(directPathTokens).doesNotContain(workerToken + "3");

  directPathMetadata2 = directPathMetadata2.toBuilder().setMetadataVersion(2).build();

  // Final injection: DIRECTPATH with a newer version (2) and the same type.
  // Verifies: a version increase triggers consumption even when the type is unchanged.
  fakeGetWorkerMetadataStub.injectWorkerMetadata(directPathMetadata2);
  verify(getWorkBudgetDistributor, times(5)).distributeBudget(any(), any());
  TimeUnit.SECONDS.sleep(WAIT_FOR_METADATA_INJECTIONS_SECONDS);
  assertThat(fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams()).hasSize(1);
  directPathTokens = currentWorkerTokens();
  assertThat(directPathTokens).containsNoneOf(workerToken + "1", workerToken + "2");
  assertThat(directPathTokens).contains(workerToken + "3");
}

/** Returns the backend worker tokens of the currently active windmill streams. */
private Set<String> currentWorkerTokens() {
  return fanOutStreamingEngineWorkProvider.currentBackends().windmillStreams().keySet().stream()
      .map(endpoint -> endpoint.workerToken().orElse(""))
      .collect(Collectors.toSet());
}

private static class WindmillServiceFakeStub
extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;

import java.io.IOException;
import java.util.HashMap;
Expand Down Expand Up @@ -196,45 +195,6 @@ public void testGetWorkerMetadata_consumesSubsequentResponseMetadata() {
.collect(Collectors.toList()));
}

// NOTE(review): this test is shown as deleted in the surrounding diff — the stream-side
// stale-version filtering it verifies was removed in the PR. Documented as-is for review context.
@Test
public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() {
// Seed the stream with metadata at version 2.
WorkerMetadataResponse freshEndpoints =
WorkerMetadataResponse.newBuilder()
.setMetadataVersion(2)
.addAllWorkEndpoints(DIRECT_PATH_ENDPOINTS)
.putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS)
.setExternalEndpoint(AUTHENTICATING_SERVICE)
.build();

TestWindmillEndpointsConsumer testWindmillEndpointsConsumer =
Mockito.spy(new TestWindmillEndpointsConsumer());
GetWorkerMetadataTestStub testStub =
new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver());
stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer);
testStub.injectWorkerMetadata(freshEndpoints);

// Build a response at a LOWER metadata version (1), i.e. stale relative to version 2.
List<WorkerMetadataResponse.Endpoint> staleDirectPathEndpoints =
Lists.newArrayList(
WorkerMetadataResponse.Endpoint.newBuilder()
.setDirectEndpoint("staleWindmillEndpoint")
.build());
Map<String, WorkerMetadataResponse.Endpoint> staleGlobalDataEndpoints = new HashMap<>();
staleGlobalDataEndpoints.put(
"stale_global_data",
WorkerMetadataResponse.Endpoint.newBuilder().setDirectEndpoint("staleGlobalData").build());

// Inject the stale response; the stream should drop it without invoking the consumer.
testStub.injectWorkerMetadata(
WorkerMetadataResponse.newBuilder()
.setMetadataVersion(1)
.addAllWorkEndpoints(staleDirectPathEndpoints)
.putAllGlobalDataEndpoints(staleGlobalDataEndpoints)
.build());

// Should have ignored the stale update and only used initial.
verify(testWindmillEndpointsConsumer).accept(WindmillEndpoints.from(freshEndpoints));
verifyNoMoreInteractions(testWindmillEndpointsConsumer);
}

@Test
public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry()
throws InterruptedException {
Expand Down
Loading