Skip to content

Parse parts in MockMultipartHttpServletRequestBuilder #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import org.springframework.web.context.request.async.DeferredResult;
import org.springframework.web.context.request.async.DeferredResultProcessingInterceptor;
import org.springframework.web.context.request.async.WebAsyncUtils;
import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.ModelAndView;
Expand Down Expand Up @@ -68,10 +67,6 @@ public TestDispatcherServlet(WebApplicationContext webApplicationContext) {
protected void service(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {

if (!request.getParts().isEmpty()) {
request = new StandardMultipartHttpServletRequest(request);
}

registerAsyncResultInterceptors(request);

super.service(request, response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,28 @@
package org.springframework.test.web.servlet.request;

import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import javax.servlet.ServletContext;
import javax.servlet.http.Part;

import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.mock.web.MockMultipartHttpServletRequest;
import org.springframework.mock.web.MockPart;
import org.springframework.util.Assert;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.multipart.MultipartFile;

/**
* Default builder for {@link MockMultipartHttpServletRequest}.
Expand Down Expand Up @@ -141,26 +144,47 @@ public Object merge(@Nullable Object parent) {
@Override
protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) {
MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext);
this.files.forEach(file -> request.addPart(toMockPart(file)));
this.parts.values().stream().flatMap(Collection::stream).forEach(request::addPart);
return request;
}

private MockPart toMockPart(MockMultipartFile file) {
byte[] bytes = null;
if (!file.isEmpty()) {
this.files.forEach(request::addFile);
this.parts.values().stream().flatMap(Collection::stream).forEach(part -> {
request.addPart(part);
try {
bytes = file.getBytes();
MultipartFile file = asMultipartFile(part);
if (file != null) {
request.addFile(file);
return;
}
String value = toParameterValue(part);
if (value != null) {
request.addParameter(part.getName(), toParameterValue(part));
}
}
catch (IOException ex) {
throw new IllegalStateException("Unexpected IOException", ex);
throw new IllegalStateException("Failed to read content for part " + part.getName(), ex);
}
});
return request;
}

@Nullable
private MultipartFile asMultipartFile(Part part) throws IOException {
String name = part.getName();
String filename = part.getSubmittedFileName();
if (filename != null) {
return new MockMultipartFile(name, filename, part.getContentType(), part.getInputStream());
}
MockPart part = new MockPart(file.getName(), file.getOriginalFilename(), bytes);
if (file.getContentType() != null) {
part.getHeaders().set(HttpHeaders.CONTENT_TYPE, file.getContentType());
return null;
}

@Nullable
private String toParameterValue(Part part) throws IOException {
String rawType = part.getContentType();
MediaType mediaType = (rawType != null ? MediaType.parseMediaType(rawType) : MediaType.TEXT_PLAIN);
if (!mediaType.isCompatibleWith(MediaType.TEXT_PLAIN)) {
return null;
}
return part;
Charset charset = (mediaType.getCharset() != null ? mediaType.getCharset() : StandardCharsets.UTF_8);
InputStreamReader reader = new InputStreamReader(part.getInputStream(), charset);
return FileCopyUtils.copyToString(reader);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@

package org.springframework.test.web.servlet.request;

import java.nio.charset.StandardCharsets;

import javax.servlet.http.Part;

import org.junit.jupiter.api.Test;

import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.mock.web.MockMultipartHttpServletRequest;
import org.springframework.mock.web.MockPart;
import org.springframework.mock.web.MockServletContext;
import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;

/**
Expand All @@ -38,17 +38,32 @@
public class MockMultipartHttpServletRequestBuilderTests {

@Test // gh-26166
void addFilesAndParts() throws Exception {
MockHttpServletRequest mockRequest = new MockMultipartHttpServletRequestBuilder("/upload")
.file(new MockMultipartFile("file", "test.txt", "text/plain", "Test".getBytes(StandardCharsets.UTF_8)))
.part(new MockPart("data", "{\"node\":\"node\"}".getBytes(StandardCharsets.UTF_8)))
.buildRequest(new MockServletContext());
void addFileAndParts() throws Exception {
MockMultipartHttpServletRequest mockRequest =
(MockMultipartHttpServletRequest) new MockMultipartHttpServletRequestBuilder("/upload")
.file(new MockMultipartFile("file", "test.txt", "text/plain", "Test".getBytes(UTF_8)))
.part(new MockPart("name", "value".getBytes(UTF_8)))
.buildRequest(new MockServletContext());

assertThat(mockRequest.getFileMap()).containsOnlyKeys("file");
assertThat(mockRequest.getParameterMap()).containsOnlyKeys("name");
assertThat(mockRequest.getParts()).extracting(Part::getName).containsExactly("name");
}

@Test // gh-26261
void addFileWithoutFilename() throws Exception {
MockPart jsonPart = new MockPart("data", "{\"node\":\"node\"}".getBytes(UTF_8));
jsonPart.getHeaders().setContentType(MediaType.APPLICATION_JSON);

StandardMultipartHttpServletRequest parsedRequest = new StandardMultipartHttpServletRequest(mockRequest);
MockMultipartHttpServletRequest mockRequest =
(MockMultipartHttpServletRequest) new MockMultipartHttpServletRequestBuilder("/upload")
.file(new MockMultipartFile("file", "Test".getBytes(UTF_8)))
.part(jsonPart)
.buildRequest(new MockServletContext());

assertThat(parsedRequest.getParameterMap()).containsOnlyKeys("data");
assertThat(parsedRequest.getFileMap()).containsOnlyKeys("file");
assertThat(parsedRequest.getParts()).extracting(Part::getName).containsExactly("file", "data");
assertThat(mockRequest.getFileMap()).containsOnlyKeys("file");
assertThat(mockRequest.getParameterMap()).isEmpty();
assertThat(mockRequest.getParts()).extracting(Part::getName).containsExactly("data");
}

@Test
Expand Down