diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodProcessor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodProcessor.java index d00b084c0d49..b9ee9a1dce9a 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodProcessor.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/AbstractMessageConverterMethodProcessor.java @@ -19,13 +19,16 @@ import java.io.IOException; import java.lang.reflect.Type; import java.util.ArrayList; -import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Locale; +import java.util.Map; +import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; import jakarta.servlet.ServletRequest; import jakarta.servlet.http.HttpServletRequest; @@ -73,6 +76,7 @@ * @author Rossen Stoyanchev * @author Brian Clozel * @author Juergen Hoeller + * @author Ralf Ueberfuhr * @since 3.1 */ public abstract class AbstractMessageConverterMethodProcessor extends AbstractMessageConverterMethodArgumentResolver @@ -96,8 +100,11 @@ public abstract class AbstractMessageConverterMethodProcessor extends AbstractMe private final ContentNegotiationManager contentNegotiationManager; - private final List problemMediaTypes = - Arrays.asList(MediaType.APPLICATION_PROBLEM_JSON, MediaType.APPLICATION_PROBLEM_XML); + private final Map problemMediaTypesByNormalType = Map.of( + MediaType.APPLICATION_JSON_VALUE, MediaType.APPLICATION_PROBLEM_JSON, + MediaType.APPLICATION_XML_VALUE, MediaType.APPLICATION_PROBLEM_XML + ); + private final Set problemMediaTypes = new HashSet<>(this.problemMediaTypesByNormalType.values()); private final Set safeExtensions = new HashSet<>(); @@ -240,14 +247,28 @@ protected void writeWithMessageConverters(@Nullable T value, MethodParameter "No converter found for return value of type: " + valueType); } - List compatibleMediaTypes = new ArrayList<>(); - determineCompatibleMediaTypes(acceptableTypes, producibleTypes, compatibleMediaTypes); - // For ProblemDetail, fall back on RFC 7807 format - if (compatibleMediaTypes.isEmpty() && ProblemDetail.class.isAssignableFrom(valueType)) { - determineCompatibleMediaTypes(this.problemMediaTypes, producibleTypes, compatibleMediaTypes); + if(ProblemDetail.class.isAssignableFrom(valueType)) { + acceptableTypes = this.replaceNonProblemCompliantMediaTypes(acceptableTypes); + if(acceptableTypes.isEmpty()) { + acceptableTypes = List.of(MediaType.APPLICATION_PROBLEM_JSON); + } + producibleTypes = this.replaceNonProblemCompliantMediaTypes(producibleTypes); + } + else { + var statusCode = outputMessage.getServletResponse().getStatus(); + var status = HttpStatus.resolve(statusCode); + // For error status code, leave RFC 7807 media types, if requested + // otherwise, remove them + if(null == status || !status.isError()) { + acceptableTypes.removeAll(this.problemMediaTypes); + producibleTypes.removeAll(this.problemMediaTypes); + } } + List compatibleMediaTypes = new ArrayList<>(); + determineCompatibleMediaTypes(acceptableTypes, producibleTypes, compatibleMediaTypes); + if (compatibleMediaTypes.isEmpty()) { if (logger.isDebugEnabled()) { logger.debug("No match for " + acceptableTypes + ", supported: " + producibleTypes); @@ -323,6 +344,20 @@ else if (mediaType.isPresentIn(ALL_APPLICATION_MEDIA_TYPES)) { } } + private List replaceNonProblemCompliantMediaTypes(Collection types) { + return types.stream() + .map( + mediaType -> this.problemMediaTypes.stream() + .noneMatch(problemType -> problemType.isCompatibleWith(mediaType)) + ? + this.problemMediaTypesByNormalType.get(mediaType.toString()) + : + mediaType + ) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + } + /** * Return the type of the value to be written to the response. Typically this is * a simple check via getClass on the value but if the value is null, then the diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorTests.java index ebcbc87f7cdb..efa3781e0ba9 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessorTests.java @@ -21,27 +21,36 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.annotation.JsonTypeName; import jakarta.servlet.FilterChain; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import jakarta.xml.bind.annotation.XmlRootElement; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.springframework.core.MethodParameter; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.ProblemDetail; import org.springframework.http.ResponseEntity; import org.springframework.http.converter.ByteArrayHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.http.converter.xml.Jaxb2RootElementHttpMessageConverter; import org.springframework.lang.Nullable; import org.springframework.validation.beanvalidation.LocalValidatorFactoryBean; import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.bind.support.WebDataBinderFactory; @@ -206,6 +215,163 @@ public void handleReturnValueCharSequence() throws Exception { assertThat(servletResponse.getContentAsString()).isEqualTo("Foo"); } + @Nested + class ContentNegotiationTests { + private static class ContentNegotiationCase { + private final List givenAccept; + private final String expectedContentType; + + public ContentNegotiationCase(List givenAccept, String expectedContentType) { + this.givenAccept = givenAccept; + this.expectedContentType = expectedContentType; + } + + public List getGivenAccept() { + return givenAccept; + } + + public String getExpectedContentType() { + return expectedContentType; + } + + @Override + public String toString() { + return givenAccept + "->" + expectedContentType; + } + } + + private static Stream handleContentNegotiationForProblem() { + var result = Stream.of( + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_JSON_VALUE), + MediaType.APPLICATION_PROBLEM_JSON_VALUE), + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_PROBLEM_JSON_VALUE), + MediaType.APPLICATION_PROBLEM_JSON_VALUE), + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_PROBLEM_JSON_VALUE, MediaType.APPLICATION_JSON_VALUE), + MediaType.APPLICATION_PROBLEM_JSON_VALUE), + new ContentNegotiationCase( + List.of("*/*"), + MediaType.APPLICATION_PROBLEM_JSON_VALUE), + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_PDF_VALUE), + MediaType.APPLICATION_PROBLEM_JSON_VALUE), + new ContentNegotiationCase( + List.of("application/*+json"), + MediaType.APPLICATION_PROBLEM_JSON_VALUE), + new ContentNegotiationCase( + List.of(MediaType.TEXT_PLAIN_VALUE, MediaType.APPLICATION_JSON_VALUE), + MediaType.APPLICATION_PROBLEM_JSON_VALUE) + ); + // if ProblemDetail is supported by the XML message converter, add further test cases + if (new Jaxb2RootElementHttpMessageConverter() + .canWrite(ProblemDetail.class, MediaType.APPLICATION_XML)) { + result = Stream.concat(result, Stream.of( + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_XML_VALUE), + MediaType.APPLICATION_PROBLEM_XML_VALUE), + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_PROBLEM_XML_VALUE, MediaType.APPLICATION_XML_VALUE), + MediaType.APPLICATION_PROBLEM_XML_VALUE), + new ContentNegotiationCase( + List.of("application/*+xml"), + MediaType.APPLICATION_PROBLEM_XML_VALUE) + )); + } + return result; + } + + @ParameterizedTest // gh-29588 + @MethodSource + public void handleContentNegotiationForProblem(ContentNegotiationCase contentNegotiationCase) throws Exception { + List> converters = new ArrayList<>(); + converters.add(new MappingJackson2HttpMessageConverter()); + converters.add(new Jaxb2RootElementHttpMessageConverter()); + + Method method = JacksonController.class.getDeclaredMethod("handleException"); + MethodParameter returnType = new MethodParameter(method, -1); + + servletRequest.addHeader( + "Accept", + contentNegotiationCase.getGivenAccept() + .stream() + .collect(Collectors.joining(",")) + ); + + HttpEntityMethodProcessor processor = new HttpEntityMethodProcessor(converters); + processor.writeWithMessageConverters( + ProblemDetail.forStatus(400), + returnType, + webRequest); + + assertThat(servletResponse.getHeader("Content-Type")) + .isEqualTo(contentNegotiationCase.getExpectedContentType()); + } + + private static Stream handleContentNegotiationForProblemAcceptedButNoProblem() { + return Stream.of( + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_JSON_VALUE), + MediaType.APPLICATION_JSON_VALUE), + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_PROBLEM_JSON_VALUE, MediaType.APPLICATION_JSON_VALUE), + MediaType.APPLICATION_JSON_VALUE), + new ContentNegotiationCase( + List.of("*/*"), + MediaType.APPLICATION_JSON_VALUE), + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_PROBLEM_JSON_VALUE, "*/*"), + MediaType.APPLICATION_JSON_VALUE), + new ContentNegotiationCase( + List.of("application/*+json"), + MediaType.APPLICATION_JSON_VALUE), + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_XML_VALUE), + MediaType.APPLICATION_XML_VALUE), + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_PROBLEM_XML_VALUE, MediaType.APPLICATION_XML_VALUE), + MediaType.APPLICATION_XML_VALUE), + new ContentNegotiationCase( + List.of("application/*+xml"), + MediaType.APPLICATION_XML_VALUE), + new ContentNegotiationCase( + List.of(MediaType.APPLICATION_XML_VALUE, MediaType.APPLICATION_JSON_VALUE), + MediaType.APPLICATION_XML_VALUE) + ); + } + + @ParameterizedTest // gh-29588 + @MethodSource + public void handleContentNegotiationForProblemAcceptedButNoProblem(ContentNegotiationCase contentNegotiationCase) throws Exception { + List> converters = new ArrayList<>(); + converters.add(new MappingJackson2HttpMessageConverter()); + converters.add(new Jaxb2RootElementHttpMessageConverter()); + + Method method = HttpEntityMethodProcessorTests.this.getClass().getDeclaredMethod("handle"); + MethodParameter returnType = new MethodParameter(method, -1); + + servletRequest.addHeader( + "Accept", + contentNegotiationCase.getGivenAccept() + .stream() + .collect(Collectors.joining(",")) + ); + + HttpEntityMethodProcessor processor = new HttpEntityMethodProcessor(converters); + processor.writeWithMessageConverters( + new Foo("foo"), + returnType, + webRequest); + + assertThat(servletResponse.getStatus()) + .isEqualTo(200); + assertThat(servletResponse.getHeader("Content-Type")) + .isEqualTo(contentNegotiationCase.getExpectedContentType()); + } + + } + @Test // SPR-13423 public void handleReturnValueWithETagAndETagFilter() throws Exception { String eTagValue = "\"deadb33f8badf00d\""; @@ -354,6 +520,7 @@ public void setParentProperty(String parentProperty) { @JsonTypeName("foo") + @XmlRootElement private static class Foo extends ParentClass { public Foo() { @@ -387,6 +554,13 @@ public HttpEntity> handleList() { list.add(new Bar("bar")); return new HttpEntity<>(list); } + + @ExceptionHandler + public ResponseEntity handleException() { + return ResponseEntity.internalServerError() + .body(ProblemDetail.forStatus(500)); + } + } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestResponseBodyMethodProcessorTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestResponseBodyMethodProcessorTests.java index e7018c11e101..c048747b2126 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestResponseBodyMethodProcessorTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestResponseBodyMethodProcessorTests.java @@ -403,7 +403,7 @@ void problemDetailDefaultMediaType() throws Exception { @Test void problemDetailWhenJsonRequested() throws Exception { this.servletRequest.addHeader("Accept", MediaType.APPLICATION_JSON_VALUE); - testProblemDetailMediaType(MediaType.APPLICATION_JSON_VALUE); + testProblemDetailMediaType(MediaType.APPLICATION_PROBLEM_JSON_VALUE); } @Test