Skip to content

Commit

Permalink
Merge pull request #43330 from YongGoose
Browse files Browse the repository at this point in the history
* pr/43330:
  Polish "Wrap 'error' attribute for consistent JSON serialization"
  Wrap 'error' attribute for consistent JSON serialization

Closes gh-43330
  • Loading branch information
philwebb committed Jan 14, 2025
2 parents 1c991a7 + 977279b commit e2a62d6
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 60 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright 2012-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.boot.web.error;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import org.springframework.context.MessageSourceResolvable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
* A wrapper class for {@link MessageSourceResolvable} errors that is safe for JSON
* serialization.
*
* @author Yongjun Hong
* @author Phillip Webb
* @since 3.5.0
*/
public final class Error implements MessageSourceResolvable {

private final MessageSourceResolvable cause;

/**
* Create a new {@code Error} instance with the specified cause.
* @param cause the error cause (must not be {@code null})
*/
private Error(MessageSourceResolvable cause) {
Assert.notNull(cause, "'cause' must not be null");
this.cause = cause;
}

@Override
public String[] getCodes() {
return this.cause.getCodes();
}

@Override
public Object[] getArguments() {
return this.cause.getArguments();
}

@Override
public String getDefaultMessage() {
return this.cause.getDefaultMessage();
}

/**
* Return the original cause of the error.
* @return the error cause
*/
public MessageSourceResolvable getCause() {
return this.cause;
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
return Objects.equals(this.cause, ((Error) obj).cause);
}

@Override
public int hashCode() {
return Objects.hash(this.cause);
}

@Override
public String toString() {
return this.cause.toString();
}

/**
* Wrap the given errors.
* @param errors the errors to wrap
* @return a new Error list
*/
public static List<Error> wrap(List<? extends MessageSourceResolvable> errors) {
if (CollectionUtils.isEmpty(errors)) {
return Collections.emptyList();
}
List<Error> result = new ArrayList<>(errors.size());
for (MessageSourceResolvable error : errors) {
result.add(new Error(error));
}
return List.copyOf(result);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2024 the original author or authors.
* Copyright 2012-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,10 +20,10 @@
import java.io.StringWriter;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.springframework.boot.web.error.Error;
import org.springframework.boot.web.error.ErrorAttributeOptions;
import org.springframework.boot.web.error.ErrorAttributeOptions.Include;
import org.springframework.core.annotation.MergedAnnotation;
Expand All @@ -32,7 +32,6 @@
import org.springframework.http.HttpStatus;
import org.springframework.util.StringUtils;
import org.springframework.validation.BindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.method.MethodValidationResult;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.reactive.function.server.ServerRequest;
Expand All @@ -48,8 +47,8 @@
* <li>error - The error reason</li>
* <li>exception - The class name of the root exception (if configured)</li>
* <li>message - The exception message (if configured)</li>
* <li>errors - Any {@link ObjectError}s from a {@link BindingResult} or
* {@link MethodValidationResult} exception (if configured)</li>
* <li>errors - Any validation errors wrapped in {@link Error}, derived from a
* {@link BindingResult} or {@link MethodValidationResult} exception (if configured)</li>
* <li>trace - The exception stack trace (if configured)</li>
* <li>path - The URL path when the exception was raised</li>
* <li>requestId - Unique ID associated with the current request</li>
Expand All @@ -61,6 +60,7 @@
* @author Scott Frederick
* @author Moritz Halbritter
* @author Yanming Zhou
* @author Yongjun Hong
* @since 2.0.0
* @see ErrorAttributes
*/
Expand Down Expand Up @@ -112,19 +112,20 @@ private void handleException(Map<String, Object> errorAttributes, Throwable erro
MergedAnnotation<ResponseStatus> responseStatusAnnotation, boolean includeStackTrace) {
Throwable exception;
if (error instanceof BindingResult bindingResult) {
errorAttributes.put("message", error.getMessage());
errorAttributes.put("errors", bindingResult.getAllErrors());
exception = error;
errorAttributes.put("message", error.getMessage());
errorAttributes.put("errors", Error.wrap(bindingResult.getAllErrors()));
}
else if (error instanceof MethodValidationResult methodValidationResult) {
addMessageAndErrorsFromMethodValidationResult(errorAttributes, methodValidationResult);
exception = error;
errorAttributes.put("message", getErrorMessage(methodValidationResult));
errorAttributes.put("errors", Error.wrap(methodValidationResult.getAllErrors()));
}
else if (error instanceof ResponseStatusException responseStatusException) {
errorAttributes.put("message", responseStatusException.getReason());
exception = (responseStatusException.getCause() != null) ? responseStatusException.getCause() : error;
errorAttributes.put("message", responseStatusException.getReason());
if (exception instanceof BindingResult bindingResult) {
errorAttributes.put("errors", bindingResult.getAllErrors());
errorAttributes.put("errors", Error.wrap(bindingResult.getAllErrors()));
}
}
else {
Expand All @@ -139,16 +140,9 @@ else if (error instanceof ResponseStatusException responseStatusException) {
}
}

private void addMessageAndErrorsFromMethodValidationResult(Map<String, Object> errorAttributes,
MethodValidationResult result) {
List<ObjectError> errors = result.getAllErrors()
.stream()
.filter(ObjectError.class::isInstance)
.map(ObjectError.class::cast)
.toList();
errorAttributes.put("message",
"Validation failed for method='" + result.getMethod() + "'. Error count: " + errors.size());
errorAttributes.put("errors", errors);
private String getErrorMessage(MethodValidationResult methodValidationResult) {
return "Validation failed for method='%s'. Error count: %s".formatted(methodValidationResult.getMethod(),
methodValidationResult.getAllErrors().size());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2024 the original author or authors.
* Copyright 2012-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,14 +20,14 @@
import java.io.StringWriter;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import jakarta.servlet.RequestDispatcher;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

import org.springframework.boot.web.error.Error;
import org.springframework.boot.web.error.ErrorAttributeOptions;
import org.springframework.boot.web.error.ErrorAttributeOptions.Include;
import org.springframework.core.Ordered;
Expand All @@ -36,7 +36,6 @@
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
import org.springframework.validation.BindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.method.MethodValidationResult;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.WebRequest;
Expand All @@ -52,8 +51,8 @@
* <li>error - The error reason</li>
* <li>exception - The class name of the root exception (if configured)</li>
* <li>message - The exception message (if configured)</li>
* <li>errors - Any {@link ObjectError}s from a {@link BindingResult} or
* {@link MethodValidationResult} exception (if configured)</li>
* <li>errors - Any validation errors wrapped in {@link Error}, derived from a
* {@link BindingResult} or {@link MethodValidationResult} exception (if configured)</li>
* <li>trace - The exception stack trace (if configured)</li>
* <li>path - The URL path when the exception was raised</li>
* </ul>
Expand All @@ -65,6 +64,7 @@
* @author Scott Frederick
* @author Moritz Halbritter
* @author Yanming Zhou
* @author Yongjun Hong
* @since 2.0.0
* @see ErrorAttributes
*/
Expand Down Expand Up @@ -141,16 +141,27 @@ private void addErrorMessage(Map<String, Object> errorAttributes, WebRequest web
BindingResult bindingResult = extractBindingResult(error);
if (bindingResult != null) {
addMessageAndErrorsFromBindingResult(errorAttributes, bindingResult);
return;
}
else {
MethodValidationResult methodValidationResult = extractMethodValidationResult(error);
if (methodValidationResult != null) {
addMessageAndErrorsFromMethodValidationResult(errorAttributes, methodValidationResult);
}
else {
addExceptionErrorMessage(errorAttributes, webRequest, error);
}
MethodValidationResult methodValidationResult = extractMethodValidationResult(error);
if (methodValidationResult != null) {
addMessageAndErrorsFromMethodValidationResult(errorAttributes, methodValidationResult);
return;
}
addExceptionErrorMessage(errorAttributes, webRequest, error);
}

private void addMessageAndErrorsFromBindingResult(Map<String, Object> errorAttributes, BindingResult result) {
errorAttributes.put("message", "Validation failed for object='%s'. Error count: %s"
.formatted(result.getObjectName(), result.getAllErrors().size()));
errorAttributes.put("errors", Error.wrap(result.getAllErrors()));
}

private void addMessageAndErrorsFromMethodValidationResult(Map<String, Object> errorAttributes,
MethodValidationResult result) {
errorAttributes.put("message", "Validation failed for method='%s'. Error count: %s"
.formatted(result.getMethod(), result.getAllErrors().size()));
errorAttributes.put("errors", Error.wrap(result.getAllErrors()));
}

private void addExceptionErrorMessage(Map<String, Object> errorAttributes, WebRequest webRequest, Throwable error) {
Expand Down Expand Up @@ -182,27 +193,6 @@ protected String getMessage(WebRequest webRequest, Throwable error) {
return "No message available";
}

private void addMessageAndErrorsFromBindingResult(Map<String, Object> errorAttributes, BindingResult result) {
addMessageAndErrorsForValidationFailure(errorAttributes, "object='" + result.getObjectName() + "'",
result.getAllErrors());
}

private void addMessageAndErrorsFromMethodValidationResult(Map<String, Object> errorAttributes,
MethodValidationResult result) {
List<ObjectError> errors = result.getAllErrors()
.stream()
.filter(ObjectError.class::isInstance)
.map(ObjectError.class::cast)
.toList();
addMessageAndErrorsForValidationFailure(errorAttributes, "method='" + result.getMethod() + "'", errors);
}

private void addMessageAndErrorsForValidationFailure(Map<String, Object> errorAttributes, String validated,
List<ObjectError> errors) {
errorAttributes.put("message", "Validation failed for " + validated + ". Error count: " + errors.size());
errorAttributes.put("errors", errors);
}

private BindingResult extractBindingResult(Throwable error) {
if (error instanceof BindingResult bindingResult) {
return bindingResult;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2024 the original author or authors.
* Copyright 2012-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -55,6 +55,7 @@
* @author Scott Frederick
* @author Moritz Halbritter
* @author Yanming Zhou
* @author Yongjun Hong
*/
class DefaultErrorAttributesTests {

Expand Down Expand Up @@ -272,7 +273,8 @@ void extractBindingResultErrors() throws Exception {
.startsWith("Validation failed for argument at index 0 in method: "
+ "int org.springframework.boot.web.reactive.error.DefaultErrorAttributesTests"
+ ".method(java.lang.String), with 1 error(s)");
assertThat(attributes).containsEntry("errors", bindingResult.getAllErrors());
assertThat(attributes).containsEntry("errors",
org.springframework.boot.web.error.Error.wrap(bindingResult.getAllErrors()));
}

@Test
Expand All @@ -287,7 +289,8 @@ void extractBindingResultErrorsThatCausedAResponseStatusException() throws Excep
buildServerRequest(request, new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid", ex)),
ErrorAttributeOptions.of(Include.MESSAGE, Include.BINDING_ERRORS));
assertThat(attributes.get("message")).isEqualTo("Invalid");
assertThat(attributes).containsEntry("errors", bindingResult.getAllErrors());
assertThat(attributes).containsEntry("errors",
org.springframework.boot.web.error.Error.wrap(bindingResult.getAllErrors()));
}

@Test
Expand All @@ -309,7 +312,7 @@ void extractMethodValidationResultErrors() throws Exception {
.isEqualTo(
"Validation failed for method='public java.lang.String java.lang.String.substring(int)'. Error count: 1");
assertThat(attributes).containsEntry("errors",
methodValidationResult.getAllErrors().stream().filter(ObjectError.class::isInstance).toList());
org.springframework.boot.web.error.Error.wrap(methodValidationResult.getAllErrors()));
}

@Test
Expand All @@ -326,6 +329,29 @@ void extractBindingResultErrorsExcludeMessageAndErrors() throws Exception {
assertThat(attributes).doesNotContainKey("errors");
}

@Test
void extractParameterValidationResultErrors() throws Exception {
Object target = "test";
Method method = String.class.getMethod("substring", int.class);
MethodParameter parameter = new MethodParameter(method, 0);
ParameterValidationResult parameterValidationResult = new ParameterValidationResult(parameter, -1,
List.of(new ObjectError("beginIndex", "beginIndex is negative")), null, null, null,
(error, sourceType) -> {
throw new IllegalArgumentException("No source object of the given type");
});
MethodValidationResult methodValidationResult = MethodValidationResult.create(target, method,
List.of(parameterValidationResult));
HandlerMethodValidationException ex = new HandlerMethodValidationException(methodValidationResult);
MockServerHttpRequest request = MockServerHttpRequest.get("/test").build();
Map<String, Object> attributes = this.errorAttributes.getErrorAttributes(buildServerRequest(request, ex),
ErrorAttributeOptions.of(Include.MESSAGE, Include.BINDING_ERRORS));
assertThat(attributes.get("message")).asString()
.isEqualTo(
"Validation failed for method='public java.lang.String java.lang.String.substring(int)'. Error count: 1");
assertThat(attributes).containsEntry("errors",
org.springframework.boot.web.error.Error.wrap(methodValidationResult.getAllErrors()));
}

@Test
void excludeStatus() {
ResponseStatusException error = new ResponseStatusException(HttpStatus.NOT_ACCEPTABLE,
Expand Down
Loading

0 comments on commit e2a62d6

Please sign in to comment.