diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java index 041cf24f6e10..8c9a4f9fb7ac 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/TestDispatcherServlet.java @@ -35,7 +35,6 @@ import org.springframework.web.context.request.async.DeferredResult; import org.springframework.web.context.request.async.DeferredResultProcessingInterceptor; import org.springframework.web.context.request.async.WebAsyncUtils; -import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest; import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.HandlerExecutionChain; import org.springframework.web.servlet.ModelAndView; @@ -68,10 +67,6 @@ public TestDispatcherServlet(WebApplicationContext webApplicationContext) { protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { - if (!request.getParts().isEmpty()) { - request = new StandardMultipartHttpServletRequest(request); - } - registerAsyncResultInterceptors(request); super.service(request, response); diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java index 75cb972331be..c0f15fcc81c8 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java @@ -17,7 +17,10 @@ package org.springframework.test.web.servlet.request; import java.io.IOException; +import java.io.InputStreamReader; import java.net.URI; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -25,17 +28,17 @@ import javax.servlet.ServletContext; import javax.servlet.http.Part; -import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.lang.Nullable; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockMultipartFile; import org.springframework.mock.web.MockMultipartHttpServletRequest; -import org.springframework.mock.web.MockPart; import org.springframework.util.Assert; +import org.springframework.util.FileCopyUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.multipart.MultipartFile; /** * Default builder for {@link MockMultipartHttpServletRequest}. @@ -141,26 +144,47 @@ public Object merge(@Nullable Object parent) { @Override protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) { MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext); - this.files.forEach(file -> request.addPart(toMockPart(file))); - this.parts.values().stream().flatMap(Collection::stream).forEach(request::addPart); - return request; - } - - private MockPart toMockPart(MockMultipartFile file) { - byte[] bytes = null; - if (!file.isEmpty()) { + this.files.forEach(request::addFile); + this.parts.values().stream().flatMap(Collection::stream).forEach(part -> { + request.addPart(part); try { - bytes = file.getBytes(); + MultipartFile file = asMultipartFile(part); + if (file != null) { + request.addFile(file); + return; + } + String value = toParameterValue(part); + if (value != null) { + request.addParameter(part.getName(), toParameterValue(part)); + } } catch (IOException ex) { - throw new IllegalStateException("Unexpected IOException", ex); + throw new IllegalStateException("Failed to read content for part " + part.getName(), ex); } + }); + return request; + } + + @Nullable + private MultipartFile asMultipartFile(Part part) throws IOException { + String name = part.getName(); + String filename = part.getSubmittedFileName(); + if (filename != null) { + return new MockMultipartFile(name, filename, part.getContentType(), part.getInputStream()); } - MockPart part = new MockPart(file.getName(), file.getOriginalFilename(), bytes); - if (file.getContentType() != null) { - part.getHeaders().set(HttpHeaders.CONTENT_TYPE, file.getContentType()); + return null; + } + + @Nullable + private String toParameterValue(Part part) throws IOException { + String rawType = part.getContentType(); + MediaType mediaType = (rawType != null ? MediaType.parseMediaType(rawType) : MediaType.TEXT_PLAIN); + if (!mediaType.isCompatibleWith(MediaType.TEXT_PLAIN)) { + return null; } - return part; + Charset charset = (mediaType.getCharset() != null ? mediaType.getCharset() : StandardCharsets.UTF_8); + InputStreamReader reader = new InputStreamReader(part.getInputStream(), charset); + return FileCopyUtils.copyToString(reader); } } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java index a7087f424fb0..b1414d2bf4c2 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java @@ -16,19 +16,19 @@ package org.springframework.test.web.servlet.request; -import java.nio.charset.StandardCharsets; - import javax.servlet.http.Part; import org.junit.jupiter.api.Test; import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockMultipartFile; +import org.springframework.mock.web.MockMultipartHttpServletRequest; import org.springframework.mock.web.MockPart; import org.springframework.mock.web.MockServletContext; -import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; /** @@ -38,17 +38,32 @@ public class MockMultipartHttpServletRequestBuilderTests { @Test // gh-26166 - void addFilesAndParts() throws Exception { - MockHttpServletRequest mockRequest = new MockMultipartHttpServletRequestBuilder("/upload") - .file(new MockMultipartFile("file", "test.txt", "text/plain", "Test".getBytes(StandardCharsets.UTF_8))) - .part(new MockPart("data", "{\"node\":\"node\"}".getBytes(StandardCharsets.UTF_8))) - .buildRequest(new MockServletContext()); + void addFileAndParts() throws Exception { + MockMultipartHttpServletRequest mockRequest = + (MockMultipartHttpServletRequest) new MockMultipartHttpServletRequestBuilder("/upload") + .file(new MockMultipartFile("file", "test.txt", "text/plain", "Test".getBytes(UTF_8))) + .part(new MockPart("name", "value".getBytes(UTF_8))) + .buildRequest(new MockServletContext()); + + assertThat(mockRequest.getFileMap()).containsOnlyKeys("file"); + assertThat(mockRequest.getParameterMap()).containsOnlyKeys("name"); + assertThat(mockRequest.getParts()).extracting(Part::getName).containsExactly("name"); + } + + @Test // gh-26261 + void addFileWithoutFilename() throws Exception { + MockPart jsonPart = new MockPart("data", "{\"node\":\"node\"}".getBytes(UTF_8)); + jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON); - StandardMultipartHttpServletRequest parsedRequest = new StandardMultipartHttpServletRequest(mockRequest); + MockMultipartHttpServletRequest mockRequest = + (MockMultipartHttpServletRequest) new MockMultipartHttpServletRequestBuilder("/upload") + .file(new MockMultipartFile("file", "Test".getBytes(UTF_8))) + .part(jsonPart) + .buildRequest(new MockServletContext()); - assertThat(parsedRequest.getParameterMap()).containsOnlyKeys("data"); - assertThat(parsedRequest.getFileMap()).containsOnlyKeys("file"); - assertThat(parsedRequest.getParts()).extracting(Part::getName).containsExactly("file", "data"); + assertThat(mockRequest.getFileMap()).containsOnlyKeys("file"); + assertThat(mockRequest.getParameterMap()).isEmpty(); + assertThat(mockRequest.getParts()).extracting(Part::getName).containsExactly("data"); } @Test