diff --git a/spring-core/src/main/java/org/springframework/core/codec/InputStreamDecoder.java b/spring-core/src/main/java/org/springframework/core/codec/InputStreamDecoder.java new file mode 100644 index 000000000000..224a499af022 --- /dev/null +++ b/spring-core/src/main/java/org/springframework/core/codec/InputStreamDecoder.java @@ -0,0 +1,270 @@ +package org.springframework.core.codec; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +/** + * Decoder that translates data buffers to an {@link InputStream}. + */ +public class InputStreamDecoder extends AbstractDataBufferDecoder { + + public static final String FAIL_FAST = InputStreamDecoder.class.getName() + ".FAIL_FAST"; + + public InputStreamDecoder() { + super(MimeTypeUtils.ALL); + } + + @Override + public boolean canDecode(ResolvableType elementType, @Nullable MimeType mimeType) { + return (elementType.resolve() == InputStream.class && super.canDecode(elementType, mimeType)); + } + + @Override + public InputStream decode(DataBuffer dataBuffer, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + if (logger.isDebugEnabled()) { + logger.debug(Hints.getLogPrefix(hints) + "Reading " + dataBuffer.readableByteCount() + " bytes"); + } + return dataBuffer.asInputStream(true); + } + + @Override + public Mono decodeToMono(Publisher input, ResolvableType elementType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + boolean failFast = hints == null || Boolean.TRUE.equals(hints.getOrDefault(FAIL_FAST, Boolean.TRUE)); + FlowBufferInputStream inputStream = new FlowBufferInputStream(getMaxInMemorySize(), failFast); + Flux.from(input).subscribe(inputStream); + + return Mono.just(inputStream); + } + + static class FlowBufferInputStream extends InputStream implements Subscriber { + + private static final Object END = new Object(); + + private final AtomicBoolean closed = new AtomicBoolean(); + + private final BlockingQueue backlog; + + private final int maximumMemorySize; + + private final boolean failFast; + + private final AtomicInteger buffered = new AtomicInteger(); + + @Nullable + private InputStreamWithSize current = new InputStreamWithSize(0, InputStream.nullInputStream()); + + @Nullable + private Subscription subscription; + + FlowBufferInputStream(int maximumMemorySize, boolean failFast) { + this.backlog = new LinkedBlockingDeque<>(); + this.maximumMemorySize = maximumMemorySize; + this.failFast = failFast; + } + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + if (this.closed.get()) { + subscription.cancel(); + } else { + subscription.request(1); + } + } + + @Override + public void onNext(DataBuffer buffer) { + if (this.closed.get()) { + DataBufferUtils.release(buffer); + return; + } + int readableByteCount = buffer.readableByteCount(); + int current = this.buffered.addAndGet(readableByteCount); + if (current < this.maximumMemorySize) { + this.subscription.request(1); + } + InputStream stream = buffer.asInputStream(true); + this.backlog.add(new InputStreamWithSize(readableByteCount, stream)); + if (this.closed.get()) { + DataBufferUtils.release(buffer); + } + } + + @Override + public void onError(Throwable throwable) { + if (failFast) { + Object next; + while ((next = this.backlog.poll()) != null) { + if (next instanceof InputStreamWithSize) { + try { + ((InputStreamWithSize) next).inputStream.close(); + } catch (Throwable t) { + throwable.addSuppressed(t); + } + } + } + } + this.backlog.add(throwable); + } + + @Override + public void onComplete() { + this.backlog.add(END); + } + + private boolean forward() throws IOException { + this.current.inputStream.close(); + try { + Object next = this.backlog.take(); + if (next == END) { + this.current = null; + return true; + } else if (next instanceof RuntimeException) { + close(); + throw (RuntimeException) next; + } else if (next instanceof IOException) { + close(); + throw (IOException) next; + } else if (next instanceof Throwable) { + close(); + throw new IllegalStateException((Throwable) next); + } else { + int buffer = buffered.addAndGet(-this.current.size); + if (buffer < this.maximumMemorySize) { + this.subscription.request(1); + } + this.current = (InputStreamWithSize) next; + return false; + } + } catch (InterruptedException e) { + throw new IllegalStateException(e); + } + } + + @Override + public int read() throws IOException { + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return -1; + } + int read; + while ((read = this.current.inputStream.read()) == -1) { + if (forward()) { + return -1; + } + } + return read; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + Objects.checkFromIndexSize(off, len, b.length); + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return -1; + } else if (len == 0) { + return 0; + } + int sum = 0; + do { + int read = this.current.inputStream.read(b, off + sum, len - sum); + if (read == -1) { + if (sum > 0 && this.backlog.isEmpty()) { + return sum; + } else if (forward()) { + return sum == 0 ? -1 : sum; + } + } else { + sum += read; + } + } while (sum < len); + return sum; + } + + @Override + public int available() throws IOException { + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return 0; + } + int available = this.current.inputStream.available(); + for (Object value : this.backlog) { + if (value instanceof InputStreamWithSize) { + available += ((InputStreamWithSize) value).inputStream.available(); + } else { + break; + } + } + return available; + } + + @Override + public void close() throws IOException { + if (this.closed.compareAndSet(false, true)) { + if (this.subscription != null) { + this.subscription.cancel(); + } + IOException exception = null; + if (this.current != null) { + try { + this.current.inputStream.close(); + } catch (IOException e) { + exception = e; + } + } + for (Object value : this.backlog) { + if (value instanceof InputStreamWithSize) { + try { + ((InputStreamWithSize) value).inputStream.close(); + } catch (IOException e) { + if (exception == null) { + exception = e; + } else { + exception.addSuppressed(e); + } + } + } + } + if (exception != null) { + throw exception; + } + } + } + } + + static class InputStreamWithSize { + + final int size; + + final InputStream inputStream; + + InputStreamWithSize(int size, InputStream inputStream) { + this.size = size; + this.inputStream = inputStream; + } + } +} \ No newline at end of file diff --git a/spring-core/src/test/java/org/springframework/core/codec/InputStreamDecoderTests.java b/spring-core/src/test/java/org/springframework/core/codec/InputStreamDecoderTests.java new file mode 100644 index 000000000000..4b1f842abbe4 --- /dev/null +++ b/spring-core/src/test/java/org/springframework/core/codec/InputStreamDecoderTests.java @@ -0,0 +1,151 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.core.codec; + +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.testfixture.codec.AbstractDecoderTests; +import org.springframework.lang.Nullable; +import org.springframework.util.MimeType; +import org.springframework.util.MimeTypeUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Map; +import java.util.function.Consumer; + +import static org.assertj.core.api.Assertions.*; + +/** + * @author Vladislav Kisel + */ +class InputStreamDecoderTests extends AbstractDecoderTests { + + private final byte[] fooBytes = "foo".getBytes(StandardCharsets.UTF_8); + + private final byte[] barBytes = "bar".getBytes(StandardCharsets.UTF_8); + + + InputStreamDecoderTests() { + super(new InputStreamDecoder()); + } + + @Override + @Test + public void canDecode() { + assertThat(this.decoder.canDecode(ResolvableType.forClass(InputStream.class), + MimeTypeUtils.TEXT_PLAIN)).isTrue(); + assertThat(this.decoder.canDecode(ResolvableType.forClass(Integer.class), + MimeTypeUtils.TEXT_PLAIN)).isFalse(); + assertThat(this.decoder.canDecode(ResolvableType.forClass(InputStream.class), + MimeTypeUtils.APPLICATION_JSON)).isTrue(); + } + + @Override + @Test + public void decode() { + Flux input = Flux.just( + this.bufferFactory.wrap(this.fooBytes), + this.bufferFactory.wrap(this.barBytes)); + + testDecodeAll(input, InputStream.class, step -> step + .consumeNextWith(expectInputStream(this.fooBytes)) + .consumeNextWith(expectInputStream(this.barBytes)) + .verifyComplete()); + } + + @Override + @Test + public void decodeToMono() { + Flux input = Flux.concat( + dataBuffer(this.fooBytes), + dataBuffer(this.barBytes)); + + byte[] expected = new byte[this.fooBytes.length + this.barBytes.length]; + System.arraycopy(this.fooBytes, 0, expected, 0, this.fooBytes.length); + System.arraycopy(this.barBytes, 0, expected, this.fooBytes.length, this.barBytes.length); + + testDecodeToMonoAll(input, InputStream.class, step -> step + .consumeNextWith(expectInputStream(expected)) + .verifyComplete()); + testDecodeToMonoErrorFailLast(input, expected); + } + + @Override + protected void testDecodeToMonoError(Publisher input, ResolvableType outputType, + @Nullable MimeType mimeType, @Nullable Map hints) { + + input = Flux.from(input).concatWith(Flux.error(new InputException())); + try (InputStream result = this.decoder.decodeToMono(input, outputType, mimeType, hints).block()) { + assertThatThrownBy(() -> result.read()).isInstanceOf(InputException.class); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private void testDecodeToMonoErrorFailLast(Publisher input, byte[] expected) { + input = Flux.concatDelayError(Flux.from(input), Flux.error(new InputException())); + try (InputStream result = this.decoder.decodeToMono(input, + ResolvableType.forType(InputStream.class), + null, + Collections.singletonMap(InputStreamDecoder.FAIL_FAST, false)).block()) { + byte[] actual = new byte[expected.length]; + assertThat(result.read(actual)).isEqualTo(expected.length); + assertThat(actual).isEqualTo(expected); + assertThatThrownBy(() -> result.read()).isInstanceOf(InputException.class); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + protected void testDecodeToMonoCancel(Publisher input, ResolvableType outputType, + @Nullable MimeType mimeType, @Nullable Map hints) { } + + @Override + protected void testDecodeToMonoEmpty(ResolvableType outputType, @Nullable MimeType mimeType, + @Nullable Map hints) { + + try (InputStream result = this.decoder.decodeToMono(Flux.empty(), outputType, mimeType, hints).block()) { + assertThat(result.read()).isEqualTo(-1); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private Consumer expectInputStream(byte[] expected) { + return actual -> { + try (actual) { + byte[] actualBytes = actual.readAllBytes(); + assertThat(actualBytes).isEqualTo(expected); + } catch (IOException ignored) { + } + }; + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingClientHttpRequest.java index abedb2c051c6..9cba17d3303a 100644 --- a/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/AbstractBufferingClientHttpRequest.java @@ -29,7 +29,7 @@ * @author Arjen Poutsma * @since 3.0.6 */ -abstract class AbstractBufferingClientHttpRequest extends AbstractClientHttpRequest { +public abstract class AbstractBufferingClientHttpRequest extends AbstractClientHttpRequest { private ByteArrayOutputStream bufferedOutput = new ByteArrayOutputStream(1024); diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java new file mode 100644 index 000000000000..ef08f074c4b7 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; + +/** + * {@link ClientHttpRequest} implementation based on + * JDK HTTP client. + * + *

Created via the {@link JdkClientHttpRequestFactory}. + */ +final class JdkClientHttpRequest extends AbstractBufferingClientHttpRequest { + + private final HttpClient httpClient; + + private final HttpMethod method; + + private final URI uri; + + private final boolean expectContinue; + + @Nullable + private final Duration requestTimeout; + + JdkClientHttpRequest(HttpClient client, HttpMethod method, URI uri, + boolean expectContinue, @Nullable Duration requestTimeout) { + this.httpClient = client; + this.method = method; + this.uri = uri; + this.expectContinue = expectContinue; + this.requestTimeout = requestTimeout; + } + + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + @Deprecated + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException { + HttpRequest.Builder builder = HttpRequest.newBuilder(this.uri); + + addHeaders(builder, headers); + + builder.method(this.method.name(), bufferedOutput.length == 0 + ? HttpRequest.BodyPublishers.noBody() + : HttpRequest.BodyPublishers.ofByteArray(bufferedOutput)); + + if (expectContinue) { + builder.expectContinue(true); + } + if (requestTimeout != null) { + builder.timeout(requestTimeout); + } + + HttpResponse response; + try { + response = this.httpClient.send(builder.build(), HttpResponse.BodyHandlers.ofInputStream()); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new JdkClientHttpResponse(response); + } + + /** + * Add the given headers to the given HTTP request. + * @param builder the request builder to add the headers to + * @param headers the headers to add + */ + static void addHeaders(HttpRequest.Builder builder, HttpHeaders headers) { + headers.forEach((headerName, headerValues) -> { + if (HttpHeaders.COOKIE.equalsIgnoreCase(headerName)) { // RFC 6265 + String headerValue = StringUtils.collectionToDelimitedString(headerValues, "; "); + builder.header(headerName, headerValue); + } + else if (!HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(headerName) && + !HttpHeaders.TRANSFER_ENCODING.equalsIgnoreCase(headerName)) { + for (String headerValue : headerValues) { + builder.header(headerName, headerValue); + } + } + }); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java new file mode 100644 index 000000000000..7b8fc4b4ac59 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequestFactory.java @@ -0,0 +1,117 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.springframework.http.HttpMethod; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.time.Duration; + +/** + * {@link org.springframework.http.client.ClientHttpRequestFactory} implementation that + * uses the Java {@link HttpClient}. + */ +public class JdkClientHttpRequestFactory implements ClientHttpRequestFactory { + + private HttpClient httpClient; + + private boolean expectContinue; + + @Nullable + private Duration requestTimeout; + + private boolean bufferRequestBody = true; + + + /** + * Create a new instance of the {@code JdkClientHttpRequestFactory} + * with a default {@link HttpClient}. + */ + public JdkClientHttpRequestFactory() { + this.httpClient = HttpClient.newHttpClient(); + } + + /** + * Create a new instance of the {@code JdkClientHttpRequestFactory} + * with the given {@link HttpClient} instance. + * @param httpClient the HttpClient instance to use for this request factory + */ + public JdkClientHttpRequestFactory(HttpClient httpClient) { + this.httpClient = httpClient; + } + + /** + * Set the {@code HttpClient} used for + * {@linkplain #createRequest(URI, HttpMethod) synchronous execution}. + */ + public void setHttpClient(HttpClient httpClient) { + Assert.notNull(httpClient, "HttpClient must not be null"); + this.httpClient = httpClient; + } + + /** + * Return the {@code HttpClient} used for + * {@linkplain #createRequest(URI, HttpMethod) synchronous execution}. + */ + public HttpClient getHttpClient() { + return this.httpClient; + } + + /** + * If {@code true}, requests the server to acknowledge the request before sending the body. + * @param expectContinue {@code} if the server is requested to acknowledge the request + * @see HttpRequest#expectContinue() + */ + public void setExpectContinue(boolean expectContinue) { + this.expectContinue = expectContinue; + } + + /** + * Set the request timeout for a request. A {code null} of 0 specifies an infinite timeout. + * @param requestTimeout the timeout value or {@code null} to disable the timeout + * @see HttpRequest#timeout() + */ + public void setRequestTimeout(@Nullable Duration requestTimeout) { + this.requestTimeout = requestTimeout; + } + + /** + * Indicates whether this request factory should buffer the request body internally. + *

Default is {@code true}. When sending large amounts of data via POST or PUT, it is + * recommended to change this property to {@code false}, so as not to run out of memory. + */ + public void setBufferRequestBody(boolean bufferRequestBody) { + this.bufferRequestBody = bufferRequestBody; + } + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + HttpClient client = getHttpClient(); + + if (this.bufferRequestBody) { + return new JdkClientHttpRequest(client, httpMethod, uri, expectContinue, requestTimeout); + } + else { + return new JdkClientStreamingHttpRequest(client, httpMethod, uri, expectContinue, requestTimeout); + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java new file mode 100644 index 000000000000..bd21be8e3b6e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java @@ -0,0 +1,93 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatusCode; +import org.springframework.lang.Nullable; + +import java.io.IOException; +import java.io.InputStream; +import java.net.http.HttpResponse; + +/** + * {@link ClientHttpResponse} implementation based on + * JDK HTTP client. + * + *

Created via the {@link JdkClientHttpRequest}. + */ +final class JdkClientHttpResponse implements ClientHttpResponse { + + private final HttpResponse httpResponse; + + @Nullable + private HttpHeaders headers; + + + JdkClientHttpResponse(HttpResponse httpResponse) { + this.httpResponse = httpResponse; + } + + + @Override + public HttpStatusCode getStatusCode() throws IOException { + return HttpStatusCode.valueOf(this.httpResponse.statusCode()); + } + + @Override + @Deprecated + public int getRawStatusCode() throws IOException { + return this.httpResponse.statusCode(); + } + + @Override + public String getStatusText() throws IOException { + return ""; + } + + @Override + public HttpHeaders getHeaders() { + if (this.headers == null) { + this.headers = new HttpHeaders(); + this.httpResponse.headers().map().forEach((key, values) -> this.headers.addAll(key, values)); + } + return this.headers; + } + + @Override + public InputStream getBody() throws IOException { + return this.httpResponse.body(); + } + + @Override + public void close() { + // Release underlying connection back to the connection manager + try { + try { + // Attempt to keep connection alive by consuming its remaining content + this.httpResponse.body().readAllBytes(); + } + finally { + this.httpResponse.body().close(); + } + } + catch (IOException ex) { + // Ignore exception on close... + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientStreamingHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientStreamingHttpRequest.java new file mode 100644 index 000000000000..0a26b743440e --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientStreamingHttpRequest.java @@ -0,0 +1,216 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.lang.Nullable; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicReference; + +/** + * {@link ClientHttpRequest} implementation based on + * JDK HTTP client in streaming mode. + * + *

Created via the {@link JdkClientHttpRequestFactory}. + */ +final class JdkClientStreamingHttpRequest extends AbstractClientHttpRequest + implements StreamingHttpOutputMessage { + + private final HttpClient httpClient; + + private final HttpMethod method; + + private final URI uri; + + private final boolean expectContinue; + + @Nullable + private final Duration requestTimeout; + + @Nullable + private Body body; + + JdkClientStreamingHttpRequest(HttpClient client, HttpMethod method, URI uri, + boolean expectContinue, @Nullable Duration requestTimeout) { + this.httpClient = client; + this.method = method; + this.uri = uri; + this.expectContinue = expectContinue; + this.requestTimeout = requestTimeout; + } + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + @Deprecated + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public void setBody(Body body) { + assertNotExecuted(); + this.body = body; + } + + @Override + protected OutputStream getBodyInternal(HttpHeaders headers) throws IOException { + throw new UnsupportedOperationException("getBody not supported"); + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers) throws IOException { + HttpRequest.Builder builder = HttpRequest.newBuilder(this.uri); + + JdkClientHttpRequest.addHeaders(builder, headers); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference reference; + if (this.body != null) { + reference = new AtomicReference<>(); + builder.method(this.method.name(), HttpRequest.BodyPublishers.fromPublisher(subscriber -> { + SubscriptionOutputStream outputStream = new SubscriptionOutputStream(subscriber); + reference.set(outputStream); + latch.countDown(); + try { + subscriber.onSubscribe(outputStream); + } catch (Throwable t) { + outputStream.closed = true; + throw t; + } + })); + } else { + reference = null; + builder.method(this.method.name(), HttpRequest.BodyPublishers.noBody()); + } + + if (expectContinue) { + builder.expectContinue(true); + } + if (requestTimeout != null) { + builder.timeout(requestTimeout); + } + + HttpResponse response; + try { + if (this.body != null) { + CompletableFuture> future = this.httpClient.sendAsync(builder.build(), HttpResponse.BodyHandlers.ofInputStream()); + latch.await(); + SubscriptionOutputStream outputStream = reference.get(); + try (outputStream) { + this.body.writeTo(outputStream); + } catch (Throwable t) { + outputStream.cancel(); + outputStream.subscriber.onError(t); + } + response = future.join(); + } else { + response = this.httpClient.send(builder.build(), HttpResponse.BodyHandlers.ofInputStream());; + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new JdkClientHttpResponse(response); + } + + static class SubscriptionOutputStream extends OutputStream implements Flow.Subscription { + + private final Flow.Subscriber subscriber; + + private final Semaphore semaphore = new Semaphore(0); + + private volatile boolean closed; + + SubscriptionOutputStream(Flow.Subscriber subscriber) { + this.subscriber = subscriber; + } + + @Override + public void write(byte[] b) throws IOException {; + if (acquire()) { + subscriber.onNext(ByteBuffer.wrap(b)); + } + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if (acquire()) { + subscriber.onNext(ByteBuffer.wrap(b, off, len)); + } + } + + @Override + public void write(int b) throws IOException { + if (acquire()) { + subscriber.onNext(ByteBuffer.wrap(new byte[] {(byte) b})); + } + } + + @Override + public void close() throws IOException { + if (!closed) { + closed = true; + subscriber.onComplete(); + } + } + + private boolean acquire() throws IOException { + if (closed) { + throw new IOException("closed"); + } + try { + semaphore.acquire(); + return true; + } catch (InterruptedException e) { + closed = true; + subscriber.onError(e); + return false; + } + } + + @Override + public void request(long n) { + semaphore.release((int) n); + } + + @Override + public void cancel() { + closed = true; + semaphore.release(1); + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java index a999661e733e..7c828a3c98b4 100644 --- a/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java +++ b/spring-web/src/main/java/org/springframework/http/codec/support/BaseDefaultCodecs.java @@ -44,6 +44,7 @@ import org.springframework.http.codec.FormHttpMessageWriter; import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.HttpMessageWriter; +import org.springframework.core.codec.InputStreamDecoder; import org.springframework.http.codec.ResourceHttpMessageReader; import org.springframework.http.codec.ResourceHttpMessageWriter; import org.springframework.http.codec.ServerSentEventHttpMessageReader; @@ -341,6 +342,7 @@ protected void initTypedReaders() { addCodec(this.typedReaders, new DecoderHttpMessageReader<>(new ByteArrayDecoder())); addCodec(this.typedReaders, new DecoderHttpMessageReader<>(new ByteBufferDecoder())); addCodec(this.typedReaders, new DecoderHttpMessageReader<>(new DataBufferDecoder())); + addCodec(this.typedReaders, new DecoderHttpMessageReader<>(new InputStreamDecoder())); if (nettyByteBufPresent) { addCodec(this.typedReaders, new DecoderHttpMessageReader<>(new NettyByteBufDecoder())); } diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java index 31c7c4224cba..a2f70424b4e1 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -40,6 +40,7 @@ import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.client.JdkClientHttpRequestFactory; import org.springframework.http.client.support.InterceptingHttpAccessor; import org.springframework.http.converter.ByteArrayHttpMessageConverter; import org.springframework.http.converter.GenericHttpMessageConverter; @@ -203,6 +204,7 @@ else if (kotlinSerializationJsonPresent) { * @param requestFactory the HTTP request factory to use * @see org.springframework.http.client.SimpleClientHttpRequestFactory * @see org.springframework.http.client.HttpComponentsClientHttpRequestFactory + * @see JdkClientHttpRequestFactory */ public RestTemplate(ClientHttpRequestFactory requestFactory) { this(); diff --git a/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java new file mode 100644 index 000000000000..6c58ed53d17c --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/JdkClientHttpRequestFactoryTests.java @@ -0,0 +1,26 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +public class JdkClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTests { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + return new JdkClientHttpRequestFactory(); + } + +} diff --git a/spring-web/src/test/java/org/springframework/http/client/StreamingJdkClientHttpRequestFactoryTests.java b/spring-web/src/test/java/org/springframework/http/client/StreamingJdkClientHttpRequestFactoryTests.java new file mode 100644 index 000000000000..78089fd62625 --- /dev/null +++ b/spring-web/src/test/java/org/springframework/http/client/StreamingJdkClientHttpRequestFactoryTests.java @@ -0,0 +1,28 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.client; + +public class StreamingJdkClientHttpRequestFactoryTests extends AbstractHttpRequestFactoryTests { + + @Override + protected ClientHttpRequestFactory createRequestFactory() { + JdkClientHttpRequestFactory requestFactory = new JdkClientHttpRequestFactory(); + requestFactory.setBufferRequestBody(false); + return requestFactory; + } + +} diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java index b1076e844e20..bc0fb2005d25 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java @@ -46,10 +46,7 @@ import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; -import org.springframework.http.client.ClientHttpRequestFactory; -import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; -import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; -import org.springframework.http.client.SimpleClientHttpRequestFactory; +import org.springframework.http.client.*; import org.springframework.http.converter.FormHttpMessageConverter; import org.springframework.http.converter.json.MappingJacksonValue; import org.springframework.util.LinkedMultiValueMap; @@ -93,7 +90,8 @@ static Stream> clientHttpRequestFactories() { return Stream.of( named("JDK", new SimpleClientHttpRequestFactory()), named("HttpComponents", new HttpComponentsClientHttpRequestFactory()), - named("OkHttp", new OkHttp3ClientHttpRequestFactory()) + named("OkHttp", new OkHttp3ClientHttpRequestFactory()), + named("JDKClient", new JdkClientHttpRequestFactory()) ); } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java index 419ae2ceb417..8b308248e587 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java @@ -16,13 +16,25 @@ package org.springframework.web.reactive.function; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.util.List; +import java.util.Objects; import java.util.Optional; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -88,6 +100,101 @@ private static BodyExtractor, ReactiveHttpInputMessage> toMono(Resol skipBodyAsMono(inputMessage)); } + /** + * Variant of {@link BodyExtractors#toMono(InputStreamMapper, boolean, int)} with a + * default buffer size and fast failure. + * @param streamMapper the mapper that is reading the body + * @return {@code BodyExtractor} that reads the response body as input stream + * @param the type of the value that is resolved from the returned stream + */ + public static BodyExtractor, ReactiveHttpInputMessage> toMono( + InputStreamMapper streamMapper) { + return toMono(streamMapper, true); + } + + /** + * Variant of {@link BodyExtractors#toMono(InputStreamMapper, boolean, int)} with a + * default buffer size. + * @param streamMapper the mapper that is reading the body + * @param failFast {@code false} if previously read bytes are discarded upon an error + * @return {@code BodyExtractor} that reads the response body as input stream + * @param the type of the value that is resolved from the returned stream + */ + public static BodyExtractor, ReactiveHttpInputMessage> toMono( + InputStreamMapper streamMapper, + boolean failFast) { + return toMono(streamMapper, failFast, 256 * 1024, true); + } + + /** + * Extractor where the response body is processed by reading an input stream of the + * response body. + * @param streamMapper the mapper that is reading the body + * @param failFast {@code false} if previously read bytes are discarded upon an error + * @param maximumMemorySize the amount of memory that is buffered until reading is suspended + * @return {@code BodyExtractor} that reads the response body as input stream + * @param the type of the value that is resolved from the returned stream + */ + public static BodyExtractor, ReactiveHttpInputMessage> toMono( + InputStreamMapper streamMapper, + boolean failFast, + int maximumMemorySize) { + return toMono(streamMapper, failFast, maximumMemorySize, true); + } + + static BodyExtractor, ReactiveHttpInputMessage> toInputStream() { + return toMono(stream -> stream, true, 256 * 1024, false); + } + + private static BodyExtractor, ReactiveHttpInputMessage> toMono( + InputStreamMapper streamMapper, + boolean failFast, + int maximumMemorySize, + boolean close) { + + Assert.notNull(streamMapper, "'streamMapper' must not be null"); + Assert.isTrue(maximumMemorySize > 0, "'maximumMemorySize' must be positive"); + return (inputMessage, context) -> { + FlowBufferInputStream inputStream = new FlowBufferInputStream(maximumMemorySize, failFast); + try { + inputMessage.getBody().subscribe(inputStream); + T value = streamMapper.apply(inputStream); + if (close) { + inputStream.close(); + } + return Mono.just(value); + } catch (Throwable t) { + try { + inputStream.close(); + } catch (Throwable suppressed) { + t.addSuppressed(suppressed); + } + return Mono.error(t); + } + }; + } + + /** + * Variant of {@link BodyExtractors#toMono(InputStreamMapper, boolean, int, boolean)} with a + * default buffer size. + * @param streamSupplier the supplier of the output stream + * @return {@code BodyExtractor} that reads the response body as input stream + */ + public static BodyExtractor, ReactiveHttpInputMessage> toMono( + Supplier streamSupplier) { + + Assert.notNull(streamSupplier, "'streamSupplier' must not be null"); + return (inputMessage, context) -> { + try (OutputStream outputStream = streamSupplier.get()) { + Flux writeResult = DataBufferUtils.write(inputMessage.getBody(), outputStream); + writeResult.blockLast(); + return Mono.empty(); + } catch (Throwable t) { + return Mono.error(t); + } + }; + } + /** * Extractor to decode the input content into {@code Flux}. * @param elementClass the class of the element type to decode to @@ -277,4 +384,219 @@ private static Flux consumeAndCancel(ReactiveHttpInputMessage messag }); } + @FunctionalInterface + public interface InputStreamMapper { + + T apply(InputStream stream) throws IOException; + } + + static class FlowBufferInputStream extends InputStream implements Subscriber { + + private static final Object END = new Object(); + + private final AtomicBoolean closed = new AtomicBoolean(); + + private final BlockingQueue backlog; + + private final int maximumMemorySize; + + private final boolean failFast; + + private final AtomicInteger buffered = new AtomicInteger(); + + @Nullable + private InputStreamWithSize current = new InputStreamWithSize(0, InputStream.nullInputStream()); + + @Nullable + private Subscription subscription; + + FlowBufferInputStream(int maximumMemorySize, boolean failFast) { + this.backlog = new LinkedBlockingDeque<>(); + this.maximumMemorySize = maximumMemorySize; + this.failFast = failFast; + } + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + if (this.closed.get()) { + subscription.cancel(); + } else { + subscription.request(1); + } + } + + @Override + public void onNext(DataBuffer buffer) { + if (this.closed.get()) { + DataBufferUtils.release(buffer); + return; + } + int readableByteCount = buffer.readableByteCount(); + int current = this.buffered.addAndGet(readableByteCount); + if (current < this.maximumMemorySize) { + this.subscription.request(1); + } + InputStream stream = buffer.asInputStream(true); + this.backlog.add(new InputStreamWithSize(readableByteCount, stream)); + if (this.closed.get()) { + DataBufferUtils.release(buffer); + } + } + + @Override + public void onError(Throwable throwable) { + if (failFast) { + Object next; + while ((next = this.backlog.poll()) != null) { + if (next instanceof InputStreamWithSize) { + try { + ((InputStreamWithSize) next).inputStream.close(); + } catch (Throwable t) { + throwable.addSuppressed(t); + } + } + } + } + this.backlog.add(throwable); + } + + @Override + public void onComplete() { + this.backlog.add(END); + } + + private boolean forward() throws IOException { + this.current.inputStream.close(); + try { + Object next = this.backlog.take(); + if (next == END) { + this.current = null; + return true; + } else if (next instanceof RuntimeException) { + close(); + throw (RuntimeException) next; + } else if (next instanceof IOException) { + close(); + throw (IOException) next; + } else if (next instanceof Throwable) { + close(); + throw new IllegalStateException((Throwable) next); + } else { + int buffer = buffered.addAndGet(-this.current.size); + if (buffer < this.maximumMemorySize) { + this.subscription.request(1); + } + this.current = (InputStreamWithSize) next; + return false; + } + } catch (InterruptedException e) { + throw new IllegalStateException(e); + } + } + + @Override + public int read() throws IOException { + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return -1; + } + int read; + while ((read = this.current.inputStream.read()) == -1) { + if (forward()) { + return -1; + } + } + return read; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + Objects.checkFromIndexSize(off, len, b.length); + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return -1; + } + int sum = 0; + do { + int read = this.current.inputStream.read(b, off + sum, len - sum); + if (read == -1) { + if (this.backlog.isEmpty()) { + return sum; + } else if (forward()) { + return sum == 0 ? -1 : sum; + } + } else { + sum += read; + } + } while (sum < len); + return sum; + } + + @Override + public int available() throws IOException { + if (this.closed.get()) { + throw new IOException("closed"); + } else if (this.current == null) { + return 0; + } + int available = this.current.inputStream.available(); + for (Object value : this.backlog) { + if (value instanceof InputStreamWithSize) { + available += ((InputStreamWithSize) value).inputStream.available(); + } else { + break; + } + } + return available; + } + + @Override + public void close() throws IOException { + if (this.closed.compareAndSet(false, true)) { + if (this.subscription != null) { + this.subscription.cancel(); + } + IOException exception = null; + if (this.current != null) { + try { + this.current.inputStream.close(); + } catch (IOException e) { + exception = e; + } + } + for (Object value : this.backlog) { + if (value instanceof InputStreamWithSize) { + try { + ((InputStreamWithSize) value).inputStream.close(); + } catch (IOException e) { + if (exception == null) { + exception = e; + } else { + exception.addSuppressed(e); + } + } + } + } + if (exception != null) { + throw exception; + } + } + } + } + + static class InputStreamWithSize { + + final int size; + + final InputStream inputStream; + + InputStreamWithSize(int size, InputStream inputStream) { + this.size = size; + this.inputStream = inputStream; + } + } + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java index fe951e0f59bb..2f911bef89be 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyInserters.java @@ -16,10 +16,16 @@ package org.springframework.web.reactive.function; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import java.util.stream.Collectors; import org.reactivestreams.Publisher; +import org.springframework.core.io.buffer.DataBufferFactory; import reactor.core.publisher.Mono; import org.springframework.core.ParameterizedTypeReference; @@ -40,6 +46,7 @@ import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import reactor.core.publisher.Sinks; /** * Static factory methods for {@link BodyInserter} implementations. @@ -358,6 +365,90 @@ public static > BodyInserter outputMessage.writeWith(publisher); } + /** + * Inserter where the request body is written to an output stream. The stream is closed + * automatically if it is not closed manually. + * @param consumer the consumer that is writing to the output stream + * @return the inserter to write directly to the body via an output stream + */ + public static BodyInserter fromOutputStream( + FromOutputStream consumer) { + + Assert.notNull(consumer, "'publisher' must not be null"); + return (outputMessage, context) -> { + Sinks.Many sink = Sinks.many() + .unicast() + .onBackpressureBuffer(); + + Mono mono = outputMessage.writeWith(sink.asFlux()); + WriterOutputStream outputStream = new WriterOutputStream(outputMessage.bufferFactory(), sink); + try { + consumer.accept(outputStream); + } catch (Throwable t) { + sink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); + } + outputStream.close(); + + return mono; + }; + } + + /** + * Variant of {@link BodyInserters#fromInputStream(Supplier, int)} that uses + * a default chunk size. + * @param streamSupplier the supplier that is supplying the stream to write + * @return the inserter to write the request body from a supplied input stream + */ + public static BodyInserter fromInputStream( + Supplier streamSupplier) { + return fromInputStream(streamSupplier, 8192); + } + + /** + * Inserter where the request body is read from an input stream. The supplied + * input stream is closed once it is no longer consumed. + * @param streamSupplier the supplier that is supplying the stream to write + * @param chunkSize the size of each chunk that is buffered before sending + * @return the inserter to write the request body from a supplied input stream + */ + public static BodyInserter fromInputStream( + Supplier streamSupplier, int chunkSize) { + + Assert.notNull(streamSupplier, "'streamSupplier' must not be null"); + Assert.state(chunkSize > 0, "'chunkSize' must be a positive number"); + return (outputMessage, context) -> { + Sinks.Many sink = Sinks.many() + .unicast() + .onBackpressureBuffer(); + + DataBufferFactory factory = outputMessage.bufferFactory(); + Mono mono = outputMessage.writeWith(sink.asFlux()); + try { + InputStream inputStream = streamSupplier.get(); + if (inputStream == null) { + sink.emitError(new NullPointerException("inputStream"), Sinks.EmitFailureHandler.FAIL_FAST); + } else { + try (inputStream) { + int length; + byte[] buffer = new byte[chunkSize]; + while ((length = inputStream.read(buffer)) != -1) { + if (length == 0) { + continue; + } + byte[] wrapped = new byte[length]; + System.arraycopy(buffer, 0, wrapped, 0, length); + sink.emitNext(factory.wrap(wrapped), Sinks.EmitFailureHandler.FAIL_FAST); + } + sink.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST); + } + } + } catch (Throwable t) { + sink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); + } + + return mono; + }; + } private static Mono writeWithMessageWriters( M outputMessage, BodyInserter.Context context, Object body, ResolvableType bodyType, @Nullable ReactiveAdapter adapter) { @@ -477,6 +568,20 @@ > MultipartInserter withPublisher(String name, P publi } + /** + * A consumer for an output stream of which the content is written to the request body. + */ + @FunctionalInterface + public interface FromOutputStream { + + /** + * Accepts an output stream which content is written to the request body. + * @param outputStream the output stream that represents the request body + * @throws IOException if an I/O error occurs what aborts the request + */ + void accept(OutputStream outputStream) throws IOException; + + } private static class DefaultFormInserter implements FormInserter { @@ -556,4 +661,45 @@ public Mono insert(ClientHttpRequest outputMessage, Context context) { } } + private static class WriterOutputStream extends OutputStream { + + private final DataBufferFactory factory; + + private final Sinks.Many sink; + + private final AtomicBoolean closed = new AtomicBoolean(); + + private WriterOutputStream(DataBufferFactory factory, Sinks.Many sink) { + this.factory = factory; + this.sink = sink; + } + + @Override + public void write(int b) throws IOException { + if (closed.get()) { + throw new IOException("closed"); + } + DataBuffer buffer = factory.allocateBuffer(1); + buffer.write((byte) (b & 0xFF)); + sink.tryEmitNext(buffer).orThrow(); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if (closed.get()) { + throw new IOException("closed"); + } + DataBuffer buffer = factory.allocateBuffer(len); + buffer.write(b, off, len); + sink.tryEmitNext(buffer).orThrow(); + } + + @Override + public void close() { + if (closed.compareAndSet(false, true)) { + sink.tryEmitComplete().orThrow(); + } + } + } + } diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequest.java new file mode 100644 index 000000000000..56db0ac35b87 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.client.AbstractBufferingClientHttpRequest; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.net.URI; + +/** + * {@link ClientHttpRequest} implementation based on + * Spring's {@link WebClient}. + * + *

Created via the {@link WebClientHttpRequestFactory}. + */ +final class WebClientHttpRequest extends AbstractBufferingClientHttpRequest { + + private final WebClient webClient; + + private final HttpMethod method; + + private final URI uri; + + WebClientHttpRequest(WebClient client, HttpMethod method, URI uri) { + this.webClient = client; + this.method = method; + this.uri = uri; + } + + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + @Deprecated + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException { + WebClient.RequestHeadersSpec request = this.webClient.method(this.method) + .uri(this.uri) + .bodyValue(bufferedOutput.length == 0 + ? BodyInserters.empty() + : BodyInserters.fromValue(bufferedOutput)); + + request.headers(value -> value.addAll(headers)); + + @SuppressWarnings("deprecation") + ClientResponse response = request.exchange().block(); + return new WebClientHttpResponse(response); + } +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequestFactory.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequestFactory.java new file mode 100644 index 000000000000..1597cec8c2ca --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpRequestFactory.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import org.springframework.http.HttpMethod; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.util.Assert; +import org.springframework.web.reactive.function.client.WebClient; + +import java.io.IOException; +import java.net.URI; + +/** + * {@link ClientHttpRequestFactory} implementation that + * uses Spring's {@link WebClient}. + */ +public class WebClientHttpRequestFactory implements ClientHttpRequestFactory { + + private WebClient webClient; + + private boolean bufferRequestBody = true; + + + /** + * Create a new instance of the {@code WebClientHttpRequestFactory} + * with a default {@link WebClient} based on system properties. + */ + public WebClientHttpRequestFactory() { + this.webClient = WebClient.create(); + } + + /** + * Create a new instance of the {@code WebClientHttpRequestFactory} + * with the given {@link WebClient} instance. + * @param webClient the HttpClient instance to use for this request factory + */ + public WebClientHttpRequestFactory(WebClient webClient) { + this.webClient = webClient; + } + + /** + * Set the {@code HttpClient} used for + * {@linkplain #createRequest(URI, HttpMethod) synchronous execution}. + */ + public void setHttpClient(WebClient webClient) { + Assert.notNull(webClient, "WebClient must not be null"); + this.webClient = webClient; + } + + /** + * Return the {@code HttpClient} used for + * {@linkplain #createRequest(URI, HttpMethod) synchronous execution}. + */ + public WebClient getHttpClient() { + return this.webClient; + } + + /** + * Indicates whether this request factory should buffer the request body internally. + *

Default is {@code true}. When sending large amounts of data via POST or PUT, it is + * recommended to change this property to {@code false}, so as not to run out of memory. + * @since 4.0 + */ + public void setBufferRequestBody(boolean bufferRequestBody) { + this.bufferRequestBody = bufferRequestBody; + } + + @Override + public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException { + WebClient client = getHttpClient(); + + if (this.bufferRequestBody) { + return new WebClientHttpRequest(client, httpMethod, uri); + } + else { + return new WebClientStreamingHttpRequest(client, httpMethod, uri); + } + } +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpResponse.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpResponse.java new file mode 100644 index 000000000000..eaa9a1247804 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientHttpResponse.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.web.reactive.function.client.ClientResponse; + +import java.io.IOException; +import java.io.InputStream; +/** + * {@link ClientHttpResponse} implementation based on + * Spring's web client. + * + *

Created via the {@link WebClientHttpRequest}. + */ +final class WebClientHttpResponse implements ClientHttpResponse { + + private final ClientResponse response; + + @Nullable + private HttpHeaders headers; + + + WebClientHttpResponse(ClientResponse response) { + this.response = response; + } + + + @Override + public HttpStatusCode getStatusCode() throws IOException { + return this.response.statusCode(); + } + + @Override + @Deprecated + public int getRawStatusCode() throws IOException { + return this.response.statusCode().value(); + } + + @Override + public String getStatusText() throws IOException { + return ""; + } + + @Override + public HttpHeaders getHeaders() { + if (this.headers == null) { + this.headers = this.response.headers().asHttpHeaders(); + } + return this.headers; + } + + @Override + public InputStream getBody() throws IOException { + return this.response.body(BodyExtractors.toInputStream()).block(); + } + + @Override + public void close() { + try { + this.response.releaseBody().block(); + } + catch (Exception ex) { + // Ignore exception on close... + } + } + +} diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientStreamingHttpRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientStreamingHttpRequest.java new file mode 100644 index 000000000000..9b68ecb46b75 --- /dev/null +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/WebClientStreamingHttpRequest.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.StreamingHttpOutputMessage; +import org.springframework.http.client.AbstractClientHttpRequest; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.http.client.ClientHttpResponse; +import org.springframework.lang.Nullable; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.URI; + +/** + * {@link ClientHttpRequest} implementation based on + * Spring's web client in streaming mode. + * + *

Created via the {@link WebClientHttpRequestFactory}. + */ +final class WebClientStreamingHttpRequest extends AbstractClientHttpRequest + implements StreamingHttpOutputMessage { + + private final WebClient webClient; + + private final HttpMethod method; + + private final URI uri; + + @Nullable + private Body body; + + WebClientStreamingHttpRequest(WebClient client, HttpMethod method, URI uri) { + this.webClient = client; + this.method = method; + this.uri = uri; + } + + @Override + public HttpMethod getMethod() { + return this.method; + } + + @Override + @Deprecated + public String getMethodValue() { + return this.method.name(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public void setBody(Body body) { + assertNotExecuted(); + this.body = body; + } + + @Override + protected OutputStream getBodyInternal(HttpHeaders headers) throws IOException { + throw new UnsupportedOperationException("getBody not supported"); + } + + @Override + protected ClientHttpResponse executeInternal(HttpHeaders headers) throws IOException { + WebClient.RequestHeadersSpec request = this.webClient.method(this.method) + .uri(this.uri) + .bodyValue(this.body == null + ? BodyInserters.empty() + : BodyInserters.fromOutputStream(outputStream -> this.body.writeTo(outputStream))); + + request.headers(value -> value.addAll(headers)); + + @SuppressWarnings("deprecation") + ClientResponse response = request.exchange().block(); + return new WebClientHttpResponse(response); + } +} diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java index eea4ef3a8ce7..a460dc94aa74 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java @@ -16,6 +16,8 @@ package org.springframework.web.reactive.function; +import java.io.ByteArrayOutputStream; +import java.io.OutputStream; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -31,6 +33,7 @@ import io.netty.util.IllegalReferenceCountException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.web.reactive.function.client.WebClient; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -409,6 +412,44 @@ public void toDataBuffers() { .verify(); } + @Test + void toMonoInputStream() { + BodyExtractor, ReactiveHttpInputMessage> extractor = BodyExtractors.toMono( + stream -> new String(stream.readAllBytes(), StandardCharsets.UTF_8)); + + byte[] bytes = "foo".getBytes(StandardCharsets.UTF_8); + DefaultDataBuffer dataBuffer = DefaultDataBufferFactory.sharedInstance.wrap(ByteBuffer.wrap(bytes)); + Flux body = Flux.just(dataBuffer); + + MockServerHttpRequest request = MockServerHttpRequest.post("/").body(body); + Mono result = extractor.extract(request, this.context); + + StepVerifier.create(result) + .expectNext("foo") + .expectComplete() + .verify(); + } + + @Test + void toMonoOutputStream() { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + BodyExtractor, ReactiveHttpInputMessage> extractor = BodyExtractors.toMono( + () -> outputStream); + + byte[] bytes = "foo".getBytes(StandardCharsets.UTF_8); + DefaultDataBuffer dataBuffer = DefaultDataBufferFactory.sharedInstance.wrap(ByteBuffer.wrap(bytes)); + Flux body = Flux.just(dataBuffer); + + MockServerHttpRequest request = MockServerHttpRequest.post("/").body(body); + Mono result = extractor.extract(request, this.context); + + StepVerifier.create(result) + .expectComplete() + .verify(); + + assertThat(outputStream.toString(StandardCharsets.UTF_8)).isEqualTo("foo"); + } + @Test // SPR-17054 public void unsupportedMediaTypeShouldConsumeAndCancel() { NettyDataBufferFactory factory = new NettyDataBufferFactory(new PooledByteBufAllocator(true)); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java index 71124300d96e..a137ec560180 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyInsertersTests.java @@ -16,6 +16,7 @@ package org.springframework.web.reactive.function; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.net.URI; import java.nio.ByteBuffer; @@ -437,6 +438,49 @@ public void ofDataBuffers() { .verify(); } + @Test + void fromOutputStream() { + byte[] bytes = "foo".getBytes(UTF_8); + + BodyInserter inserter = BodyInserters.fromOutputStream( + outputStream -> outputStream.write(bytes)); + + MockServerHttpResponse response = new MockServerHttpResponse(); + Mono result = inserter.insert(response, this.context); + StepVerifier.create(result).expectComplete().verify(); + + StepVerifier.create(response.getBody()) + .consumeNextWith(dataBuffer -> { + byte[] resultBytes = new byte[dataBuffer.readableByteCount()]; + dataBuffer.read(resultBytes); + DataBufferUtils.release(dataBuffer); + assertThat(resultBytes).isEqualTo(bytes); + }) + .expectComplete() + .verify(); + } + + @Test + void fromInputStream() { + byte[] bytes = "foo".getBytes(UTF_8); + + BodyInserter inserter = BodyInserters.fromInputStream( + () -> new ByteArrayInputStream(bytes)); + + MockServerHttpResponse response = new MockServerHttpResponse(); + Mono result = inserter.insert(response, this.context); + StepVerifier.create(result).expectComplete().verify(); + + StepVerifier.create(response.getBody()) + .consumeNextWith(dataBuffer -> { + byte[] resultBytes = new byte[dataBuffer.readableByteCount()]; + dataBuffer.read(resultBytes); + DataBufferUtils.release(dataBuffer); + assertThat(resultBytes).isEqualTo(bytes); + }) + .expectComplete() + .verify(); + } interface SafeToSerialize {}