diff --git a/impl/core/src/main/java/io/serverlessworkflow/impl/WorkflowApplication.java b/impl/core/src/main/java/io/serverlessworkflow/impl/WorkflowApplication.java index c70134cd3..cf6f83b75 100644 --- a/impl/core/src/main/java/io/serverlessworkflow/impl/WorkflowApplication.java +++ b/impl/core/src/main/java/io/serverlessworkflow/impl/WorkflowApplication.java @@ -22,6 +22,8 @@ import io.serverlessworkflow.api.types.Workflow; import io.serverlessworkflow.impl.additional.NamedWorkflowAdditionalObject; import io.serverlessworkflow.impl.additional.WorkflowAdditionalObject; +import io.serverlessworkflow.impl.auth.AuthProviderFactory; +import io.serverlessworkflow.impl.auth.DefaultAuthProviderFactory; import io.serverlessworkflow.impl.config.ConfigManager; import io.serverlessworkflow.impl.config.ConfigSecretManager; import io.serverlessworkflow.impl.config.SecretManager; @@ -89,6 +91,7 @@ public class WorkflowApplication implements AutoCloseable { private final WorkflowModelFactory contextFactory; private final WorkflowScheduler scheduler; private final Map> additionalObjects; + private final AuthProviderFactory authProviderFactory; private final ConfigManager configManager; private final SecretManager secretManager; private final SchedulerListener schedulerListener; @@ -120,6 +123,7 @@ private WorkflowApplication(Builder builder) { this.scheduler = builder.scheduler; this.schedulerListener = builder.schedulerListener; this.additionalObjects = builder.additionalObjects; + this.authProviderFactory = builder.authProviderFactory; this.configManager = builder.configManager; this.secretManager = builder.secretManager; this.templateResolver = builder.templateResolver; @@ -241,6 +245,7 @@ public SchemaValidator getValidator(SchemaInline inline) { private WorkflowModelFactory modelFactory; private WorkflowModelFactory contextFactory; private Map> additionalObjects = new HashMap<>(); + private AuthProviderFactory authProviderFactory; private SecretManager secretManager; private ConfigManager configManager; private SchedulerListener schedulerListener; @@ -357,6 +362,11 @@ public Builder withAdditionalObject( return this; } + public Builder withAuthProviderFactory(AuthProviderFactory authProviderFactory) { + this.authProviderFactory = authProviderFactory; + return this; + } + public Builder withModelFactory(WorkflowModelFactory modelFactory) { this.modelFactory = modelFactory; return this; @@ -470,6 +480,10 @@ public WorkflowApplication build() { if (id == null) { id = idFactory.get(); } + if (authProviderFactory == null) { + authProviderFactory = DefaultAuthProviderFactory.factory(); + } + return new WorkflowApplication(this); } @@ -586,6 +600,10 @@ public Optional additionalObject( .map(v -> (T) v.apply(workflowContext, taskContext)); } + public AuthProviderFactory authProviderFactory() { + return authProviderFactory; + } + public Collection callableProxyBuilders() { return callableProxyBuilders; } diff --git a/impl/core/src/main/java/io/serverlessworkflow/impl/auth/AuthProviderFactory.java b/impl/core/src/main/java/io/serverlessworkflow/impl/auth/AuthProviderFactory.java index 6accc3360..2630250f2 100644 --- a/impl/core/src/main/java/io/serverlessworkflow/impl/auth/AuthProviderFactory.java +++ b/impl/core/src/main/java/io/serverlessworkflow/impl/auth/AuthProviderFactory.java @@ -15,87 +15,17 @@ */ package io.serverlessworkflow.impl.auth; -import io.serverlessworkflow.api.types.AuthenticationPolicyUnion; import io.serverlessworkflow.api.types.EndpointConfiguration; import io.serverlessworkflow.api.types.ReferenceableAuthenticationPolicy; -import io.serverlessworkflow.api.types.Use; -import io.serverlessworkflow.api.types.UseAuthentications; -import io.serverlessworkflow.api.types.Workflow; -import io.serverlessworkflow.impl.WorkflowApplication; import io.serverlessworkflow.impl.WorkflowDefinition; import java.util.Optional; -public class AuthProviderFactory { +/** Resolves the {@link AuthProvider} to use for a given authentication policy. */ +public interface AuthProviderFactory { - private AuthProviderFactory() {} + Optional getAuth( + WorkflowDefinition definition, EndpointConfiguration configuration); - public static Optional getAuth( - WorkflowDefinition definition, EndpointConfiguration configuration) { - return configuration == null - ? Optional.empty() - : getAuth(definition, configuration.getAuthentication(), "GET"); - } - - public static Optional getAuth( - WorkflowDefinition definition, ReferenceableAuthenticationPolicy auth, String method) { - if (auth == null) { - return Optional.empty(); - } - if (auth.getAuthenticationPolicyReference() != null) { - return buildFromReference( - definition.application(), - definition.workflow(), - auth.getAuthenticationPolicyReference().getUse(), - method); - } else if (auth.getAuthenticationPolicy() != null) { - return buildFromPolicy( - definition.application(), definition.workflow(), auth.getAuthenticationPolicy(), method); - } - return Optional.empty(); - } - - private static Optional buildFromReference( - WorkflowApplication app, Workflow workflow, String use, String method) { - Use useInfo = workflow.getUse(); - if (useInfo == null) { - return Optional.empty(); - } - UseAuthentications authInfo = useInfo.getAuthentications(); - return authInfo == null - ? Optional.empty() - : authInfo.getAdditionalProperties().entrySet().stream() - .filter(s -> s.getKey().equals(use)) - .findAny() - .flatMap(e -> buildFromPolicy(app, workflow, e.getValue(), method)); - } - - private static Optional buildFromPolicy( - WorkflowApplication app, - Workflow workflow, - AuthenticationPolicyUnion authenticationPolicy, - String method) { - if (authenticationPolicy.getBasicAuthenticationPolicy() != null) { - return Optional.of( - new BasicAuthProvider( - app, workflow, authenticationPolicy.getBasicAuthenticationPolicy())); - } else if (authenticationPolicy.getBearerAuthenticationPolicy() != null) { - return Optional.of( - new BearerAuthProvider( - app, workflow, authenticationPolicy.getBearerAuthenticationPolicy())); - } else if (authenticationPolicy.getDigestAuthenticationPolicy() != null) { - return Optional.of( - new DigestAuthProvider( - app, workflow, authenticationPolicy.getDigestAuthenticationPolicy(), method)); - } else if (authenticationPolicy.getOAuth2AuthenticationPolicy() != null) { - return Optional.of( - new OAuth2AuthProvider( - app, workflow, authenticationPolicy.getOAuth2AuthenticationPolicy())); - } else if (authenticationPolicy.getOpenIdConnectAuthenticationPolicy() != null) { - return Optional.of( - new OpenIdAuthProvider( - app, workflow, authenticationPolicy.getOpenIdConnectAuthenticationPolicy())); - } - - return Optional.empty(); - } + Optional getAuth( + WorkflowDefinition definition, ReferenceableAuthenticationPolicy auth, String method); } diff --git a/impl/core/src/main/java/io/serverlessworkflow/impl/auth/DefaultAuthProviderFactory.java b/impl/core/src/main/java/io/serverlessworkflow/impl/auth/DefaultAuthProviderFactory.java new file mode 100644 index 000000000..59d8f193f --- /dev/null +++ b/impl/core/src/main/java/io/serverlessworkflow/impl/auth/DefaultAuthProviderFactory.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.serverlessworkflow.impl.auth; + +import io.serverlessworkflow.api.types.AuthenticationPolicyUnion; +import io.serverlessworkflow.api.types.EndpointConfiguration; +import io.serverlessworkflow.api.types.ReferenceableAuthenticationPolicy; +import io.serverlessworkflow.api.types.Use; +import io.serverlessworkflow.api.types.UseAuthentications; +import io.serverlessworkflow.api.types.Workflow; +import io.serverlessworkflow.impl.WorkflowApplication; +import io.serverlessworkflow.impl.WorkflowDefinition; +import java.util.Optional; + +public class DefaultAuthProviderFactory implements AuthProviderFactory { + + private static class DefaultAuthProviderFactoryHolder { + private static final DefaultAuthProviderFactory instance = new DefaultAuthProviderFactory(); + } + + public static DefaultAuthProviderFactory factory() { + return DefaultAuthProviderFactoryHolder.instance; + } + + @Override + public Optional getAuth( + WorkflowDefinition definition, EndpointConfiguration configuration) { + return configuration == null + ? Optional.empty() + : getAuth(definition, configuration.getAuthentication(), "GET"); + } + + @Override + public Optional getAuth( + WorkflowDefinition definition, ReferenceableAuthenticationPolicy auth, String method) { + if (auth == null) { + return Optional.empty(); + } + if (auth.getAuthenticationPolicyReference() != null) { + return buildFromReference( + definition.application(), + definition.workflow(), + auth.getAuthenticationPolicyReference().getUse(), + method); + } else if (auth.getAuthenticationPolicy() != null) { + return buildFromPolicy( + definition.application(), definition.workflow(), auth.getAuthenticationPolicy(), method); + } + return Optional.empty(); + } + + private Optional buildFromReference( + WorkflowApplication app, Workflow workflow, String use, String method) { + Use useInfo = workflow.getUse(); + if (useInfo == null) { + return Optional.empty(); + } + UseAuthentications authInfo = useInfo.getAuthentications(); + return authInfo == null + ? Optional.empty() + : authInfo.getAdditionalProperties().entrySet().stream() + .filter(s -> s.getKey().equals(use)) + .findAny() + .flatMap(e -> buildFromPolicy(app, workflow, e.getValue(), method)); + } + + private Optional buildFromPolicy( + WorkflowApplication app, + Workflow workflow, + AuthenticationPolicyUnion authenticationPolicy, + String method) { + if (authenticationPolicy.getBasicAuthenticationPolicy() != null) { + return Optional.of( + new BasicAuthProvider( + app, workflow, authenticationPolicy.getBasicAuthenticationPolicy())); + } else if (authenticationPolicy.getBearerAuthenticationPolicy() != null) { + return Optional.of( + new BearerAuthProvider( + app, workflow, authenticationPolicy.getBearerAuthenticationPolicy())); + } else if (authenticationPolicy.getDigestAuthenticationPolicy() != null) { + return Optional.of( + new DigestAuthProvider( + app, workflow, authenticationPolicy.getDigestAuthenticationPolicy(), method)); + } else if (authenticationPolicy.getOAuth2AuthenticationPolicy() != null) { + return Optional.of( + new OAuth2AuthProvider( + app, workflow, authenticationPolicy.getOAuth2AuthenticationPolicy())); + } else if (authenticationPolicy.getOpenIdConnectAuthenticationPolicy() != null) { + return Optional.of( + new OpenIdAuthProvider( + app, workflow, authenticationPolicy.getOpenIdConnectAuthenticationPolicy())); + } + + return Optional.empty(); + } +} diff --git a/impl/core/src/main/java/io/serverlessworkflow/impl/resources/ResourceLoader.java b/impl/core/src/main/java/io/serverlessworkflow/impl/resources/ResourceLoader.java index a9a914497..08e3d0725 100644 --- a/impl/core/src/main/java/io/serverlessworkflow/impl/resources/ResourceLoader.java +++ b/impl/core/src/main/java/io/serverlessworkflow/impl/resources/ResourceLoader.java @@ -26,7 +26,6 @@ import io.serverlessworkflow.impl.WorkflowContext; import io.serverlessworkflow.impl.WorkflowModel; import io.serverlessworkflow.impl.WorkflowValueResolver; -import io.serverlessworkflow.impl.auth.AuthProviderFactory; import io.serverlessworkflow.impl.auth.AuthUtils; import io.serverlessworkflow.impl.expressions.ExpressionDescriptor; import java.net.URI; @@ -108,8 +107,9 @@ public T load( return loadURI( uri, function, - AuthProviderFactory.getAuth( - workflowContext.definition(), endPoint.getEndpointConfiguration()) + application + .authProviderFactory() + .getAuth(workflowContext.definition(), endPoint.getEndpointConfiguration()) .map( auth -> AuthUtils.authHeaderValue( diff --git a/impl/http/src/main/java/io/serverlessworkflow/impl/executors/http/HttpExecutorBuilder.java b/impl/http/src/main/java/io/serverlessworkflow/impl/executors/http/HttpExecutorBuilder.java index 92349d84d..e6ec62530 100644 --- a/impl/http/src/main/java/io/serverlessworkflow/impl/executors/http/HttpExecutorBuilder.java +++ b/impl/http/src/main/java/io/serverlessworkflow/impl/executors/http/HttpExecutorBuilder.java @@ -19,7 +19,7 @@ import io.serverlessworkflow.impl.WorkflowDefinition; import io.serverlessworkflow.impl.WorkflowUtils; import io.serverlessworkflow.impl.WorkflowValueResolver; -import io.serverlessworkflow.impl.auth.AuthProviderFactory; +import io.serverlessworkflow.impl.auth.AuthProvider; import jakarta.ws.rs.HttpMethod; import java.net.URI; import java.util.Map; @@ -101,24 +101,21 @@ public static HttpExecutorBuilder builder(WorkflowDefinition definition) { } private RequestExecutor buildRequestExecutor() { - String theMethod = method.toUpperCase(); - switch (theMethod) { + String httpMethod = method.toUpperCase(); + Optional auth = + definition.application().authProviderFactory().getAuth(definition, policy, httpMethod); + switch (httpMethod) { case HttpMethod.POST: case HttpMethod.PUT: case HttpMethod.PATCH: return new WithBodyRequestExecutor( - theMethod, - redirect, - AuthProviderFactory.getAuth(definition, policy, method), - definition.application(), - body); + httpMethod, redirect, auth, definition.application(), body); case HttpMethod.DELETE: case HttpMethod.HEAD: case HttpMethod.OPTIONS: case HttpMethod.GET: default: - return new WithoutBodyRequestExecutor( - theMethod, redirect, AuthProviderFactory.getAuth(definition, policy, method)); + return new WithoutBodyRequestExecutor(httpMethod, redirect, auth); } } } diff --git a/impl/test/src/test/java/io/serverlessworkflow/impl/test/CustomAuthProviderFactoryOverrideTest.java b/impl/test/src/test/java/io/serverlessworkflow/impl/test/CustomAuthProviderFactoryOverrideTest.java new file mode 100644 index 000000000..46344a050 --- /dev/null +++ b/impl/test/src/test/java/io/serverlessworkflow/impl/test/CustomAuthProviderFactoryOverrideTest.java @@ -0,0 +1,149 @@ +/* + * Copyright 2020-Present The Serverless Workflow Specification Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.serverlessworkflow.impl.test; + +import static io.serverlessworkflow.api.WorkflowReader.readWorkflowFromClasspath; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.serverlessworkflow.api.types.EndpointConfiguration; +import io.serverlessworkflow.api.types.ReferenceableAuthenticationPolicy; +import io.serverlessworkflow.api.types.Workflow; +import io.serverlessworkflow.impl.TaskContext; +import io.serverlessworkflow.impl.WorkflowApplication; +import io.serverlessworkflow.impl.WorkflowContext; +import io.serverlessworkflow.impl.WorkflowDefinition; +import io.serverlessworkflow.impl.WorkflowModel; +import io.serverlessworkflow.impl.auth.AuthProvider; +import io.serverlessworkflow.impl.auth.AuthProviderFactory; +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class CustomAuthProviderFactoryOverrideTest { + + private static final String RESPONSE = + """ + { + "message": "Hello World" + } + """; + + private MockWebServer authServer; + private MockWebServer apiServer; + + @BeforeEach + void setUp() throws IOException { + authServer = new MockWebServer(); + authServer.start(8888); + + apiServer = new MockWebServer(); + apiServer.start(8081); + } + + @AfterEach + void tearDown() throws IOException { + authServer.shutdown(); + apiServer.shutdown(); + } + + @Test + public void frameworkOverrideSuppliesAuthHeaderWithoutTokenExchange() throws Exception { + final String frameworkToken = "framework-managed-token"; + final AtomicInteger factoryInvocations = new AtomicInteger(); + + AuthProviderFactory frameworkFactory = + new AuthProviderFactory() { + @Override + public Optional getAuth( + WorkflowDefinition definition, EndpointConfiguration configuration) { + return getAuth( + definition, + configuration == null ? null : configuration.getAuthentication(), + "GET"); + } + + @Override + public Optional getAuth( + WorkflowDefinition definition, + ReferenceableAuthenticationPolicy auth, + String method) { + factoryInvocations.incrementAndGet(); + return Optional.of( + new AuthProvider() { + @Override + public String scheme() { + return "Bearer"; + } + + @Override + public String content( + WorkflowContext workflow, TaskContext task, WorkflowModel model, URI uri) { + return frameworkToken; + } + }); + } + }; + + apiServer.enqueue( + new MockResponse() + .setBody(RESPONSE) + .setHeader("Content-Type", "application/json") + .setResponseCode(200)); + apiServer.enqueue( + new MockResponse() + .setBody(RESPONSE) + .setHeader("Content-Type", "application/json") + .setResponseCode(200)); + + try (WorkflowApplication app = + WorkflowApplication.builder().withAuthProviderFactory(frameworkFactory).build()) { + + Workflow workflow = + readWorkflowFromClasspath( + "workflows-samples/oauth2/oAuthClientSecretPostClientCredentialsHttpCall.yaml"); + WorkflowDefinition definition = app.workflowDefinition(workflow); + + // Run twice on the same definition to check if the AuthProvider is resolved once at build + // time + for (int i = 0; i < 2; i++) { + Map result = + definition.instance(Map.of()).start().get().asMap().orElseThrow(); + assertTrue(result.get("message").toString().contains("Hello World")); + } + } + + for (int i = 0; i < 2; i++) { + RecordedRequest apiRequest = apiServer.takeRequest(); + assertEquals("GET", apiRequest.getMethod()); + assertEquals("/hello", apiRequest.getPath()); + assertEquals("Bearer " + frameworkToken, apiRequest.getHeader("Authorization")); + } + + // The SDK never performed a token exchange + assertEquals(0, authServer.getRequestCount()); + + assertEquals(1, factoryInvocations.get()); + } +}