diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt index c426db2aabcb..e9ca1e5b9efe 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt @@ -34,6 +34,7 @@ import org.springframework.web.reactive.function.client.WebClient.RequestHeaders import reactor.core.publisher.Flux import reactor.core.publisher.Mono import reactor.util.context.Context +import reactor.util.retry.Retry import kotlin.coroutines.CoroutineContext /** @@ -226,7 +227,7 @@ inline fun WebClient.ResponseSpec.toEntityFlux(): Mono() {}) /** - * Extension for [WebClient.ResponseSpec.toEntity] providing a `toEntity()` variant + * Extension for [WebClient.ResponseSpec.toEntity] providing a `awaitEntity()` variant * leveraging Kotlin reified type parameters and allows [kotlin.coroutines.CoroutineContext] * propagation to the [CoExchangeFilterFunction]. This extension is not subject to type erasure * and retains actual generic type arguments. @@ -240,6 +241,22 @@ suspend inline fun WebClient.ResponseSpec.awaitEntity(): Respo } } +/** + * Extension for [WebClient.ResponseSpec.toEntity] providing a `awaitEntityWithRetry(Retry)` variant + * leveraging Kotlin reified type parameters and allows [kotlin.coroutines.CoroutineContext] + * propagation to the [CoExchangeFilterFunction]. This extension is not subject to type erasure + * and retains actual generic type arguments. + * + * @param retrySpec the [Retry] strategy passed to the [Mono.retryWhen] + * @param T the type of the body + */ +suspend inline fun WebClient.ResponseSpec.awaitEntityWithRetry(retrySpec: Retry): ResponseEntity { + val context = currentCoroutineContext().minusKey(Job.Key) + return withContext(context.toReactorContext()) { + toEntity().retryWhen(retrySpec).awaitSingle() + } +} + private val contextPropagationPresent = ClassUtils.isPresent("io.micrometer.context.ContextSnapshotFactory", WebClient::class.java.classLoader) diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt index 78943f1013c2..0533c3368099 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt @@ -41,8 +41,10 @@ import org.springframework.web.reactive.function.client.CoExchangeFilterFunction import reactor.core.publisher.Flux import reactor.core.publisher.Hooks import reactor.core.publisher.Mono +import reactor.util.retry.Retry import java.time.Duration import java.util.concurrent.CompletableFuture +import java.util.concurrent.atomic.AtomicInteger import java.util.function.Function import kotlin.coroutines.AbstractCoroutineContextElement import kotlin.coroutines.CoroutineContext @@ -247,6 +249,42 @@ class WebClientExtensionsTests { } } + @Test + fun `ResponseSpec#awaitEntityWithRetry with coroutine context propagation`() { + val exchangeFunction = mockk() + val mockResponse = mockk() + val mockClientHeaders = mockk() + val foo = mockk() + val slot = slot() + val atomicInteger = AtomicInteger(0) + every { exchangeFunction.exchange(capture(slot)) } answers { + if (atomicInteger.getAndIncrement() < 2) { + Mono.error(Exception()) + } else { + Mono.just(mockResponse) + } + } + every { mockResponse.statusCode() } returns HttpStatus.OK + every { mockResponse.headers() } returns mockClientHeaders + every { mockClientHeaders.asHttpHeaders() } returns HttpHeaders() + every { mockResponse.bodyToMono(object : ParameterizedTypeReference() {}) } returns Mono.just(foo) + runBlocking(FooContextElement(foo)) { + val responseEntity = WebClient.builder() + .exchangeFunction(exchangeFunction) + .filter(object : CoExchangeFilterFunction() { + override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse { + assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo) + return next.exchange(request) + } + }) + .build().get().uri("/path").retrieve().awaitEntityWithRetry(Retry.max(2)) + val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext + assertThat(atomicInteger.get()).isEqualTo(3) + assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo) + assertThat(responseEntity.body).isEqualTo(foo) + } + } + @Test fun `ResponseSpec#awaitEntity with coroutine context propagation to multiple CoExchangeFilterFunctions`() { val exchangeFunction = mockk()