novoland
1/25/2016 - 8:24 AM

JSR 349 (Bean Validation 1.1) helper

package com.meituan.show.sell.utils;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import javax.validation.Constraint;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.springframework.context.MessageSource;
import org.springframework.util.ReflectionUtils;
import org.springframework.validation.BindException;
import org.springframework.validation.FieldError;
import org.springframework.validation.SmartValidator;

import com.meituan.show.sell.web.exception.BeanInvalidException;

/**
 * Helpers around JSR 349 (Bean Validation 1.1) validation of beans with Spring.
 *
 * <p>Created by liujing21 on 15/7/19.
 */
public final class ValidationUtils {
    /**
     * Cache of "{@code beanClassName-groupClassName}" to the constraint-annotated fields of that
     * bean class belonging to that group. Cached sets are never handed out directly; callers
     * always receive a copy so they cannot corrupt the shared entry.
     */
    private static final Map<String, Set<Field>> VALIDATION_GROUP_FIELDS_CACHE = new ConcurrentHashMap<>();

    private ValidationUtils() {
    }

    /**
     * Validates {@code bean} with JSR 349 and throws if any field error is reported.
     *
     * @param bean
     *            the object to validate
     * @param validator
     *            the Spring validator performing the validation
     * @param messageSource
     *            passed to the thrown exception to resolve the error messages
     * @param hints
     *            validation hints forwarded to {@code validator} (e.g. validation groups,
     *            depending on the validator implementation)
     * @throws BeanInvalidException
     *             if validation produced at least one field error
     */
    public static void validate(Object bean, SmartValidator validator, MessageSource messageSource, Object... hints) {
        BindException result = new BindException(bean, "");
        validator.validate(bean, result, hints);
        if (!result.hasErrors()) {
            return;
        }
        // NOTE(review): only field errors are reported. Object-level (global) errors, e.g. from
        // class-level constraints, pass silently because BeanInvalidException accepts FieldErrors only.
        List<FieldError> errs = result.getFieldErrors();
        if (CollectionUtils.isNotEmpty(errs)) {
            throw new BeanInvalidException(messageSource, errs.toArray(new FieldError[errs.size()]));
        }
    }

    /**
     * Returns the fields of {@code beanClass} (including inherited ones) that belong to the given
     * validation group. A field matches when at least one of its constraint annotations lists, in
     * its {@code groups} attribute, either:
     * <ol>
     * <li>{@code validationGroupClass} itself, or</li>
     * <li>a group class whose declaring (outer) class is {@code validationGroupClass}.</li>
     * </ol>
     *
     * @param beanClass
     *            the bean class to inspect
     * @param validationGroupClass
     *            the validation group class
     * @return a new, caller-owned mutable set of matching fields, or {@code null} if none match
     */
    public static Set<Field> getFieldsByGroup(Class<?> beanClass, Class<?> validationGroupClass) {
        String key = beanClass.getName() + "-" + validationGroupClass.getName();
        Set<Field> cached = VALIDATION_GROUP_FIELDS_CACHE.get(key);
        if (cached == null) {
            cached = collectFieldsByGroup(beanClass, validationGroupClass);
            // Empty results are cached too, so repeated misses do not re-scan via reflection.
            Set<Field> raced = VALIDATION_GROUP_FIELDS_CACHE.putIfAbsent(key, cached);
            if (raced != null) {
                cached = raced;
            }
        }
        if (cached.isEmpty()) {
            return null;
        }
        // Defensive copy: callers (including clear()) may mutate the returned set.
        return new HashSet<>(cached);
    }

    /** Scans all fields of {@code beanClass} and collects those belonging to {@code validationGroupClass}. */
    private static Set<Field> collectFieldsByGroup(Class<?> beanClass, Class<?> validationGroupClass) {
        Set<Field> result = new HashSet<>();
        for (Field field : FieldUtils.getAllFieldsList(beanClass)) {
            if (belongsToGroup(field, validationGroupClass)) {
                result.add(field);
            }
        }
        return result;
    }

    /** Tells whether any constraint annotation on {@code field} targets {@code validationGroupClass}. */
    private static boolean belongsToGroup(Field field, Class<?> validationGroupClass) {
        for (Annotation annotation : field.getAnnotations()) {
            Class<? extends Annotation> type = annotation.annotationType();
            if (!type.isAnnotationPresent(Constraint.class)) {
                continue;
            }
            Method groupsMethod = MethodUtils.getAccessibleMethod(type, "groups");
            if (groupsMethod == null) {
                // groups() is not accessible: skip only this annotation, keep checking the others.
                continue;
            }
            Class<?>[] groups = (Class<?>[]) ReflectionUtils.invokeMethod(groupsMethod, annotation);
            for (Class<?> group : groups) {
                if (group == validationGroupClass || group.getDeclaringClass() == validationGroupClass) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Sets to {@code null} every field of {@code bean} that belongs to {@code validationGroup} but
     * not to any of {@code excludeValidationGroup}.
     *
     * <p>NOTE(review): writing {@code null} into a primitive field makes the reflective write throw,
     * so constraints on primitive fields are not supported here.
     *
     * @param bean
     *            the object whose fields are cleared
     * @param validationGroup
     *            the group whose fields are cleared
     * @param excludeValidationGroup
     *            groups whose fields are kept even if they also belong to {@code validationGroup}
     * @throws RuntimeException
     *             wrapping the {@link IllegalAccessException} if a field cannot be written
     */
    public static void clear(Object bean, Class<?> validationGroup, Class<?>... excludeValidationGroup) {
        Set<Field> fields = getFieldsByGroup(bean.getClass(), validationGroup);
        if (fields == null || fields.isEmpty()) {
            return;
        }
        if (excludeValidationGroup != null) {
            for (Class<?> group : excludeValidationGroup) {
                Set<Field> excluded = getFieldsByGroup(bean.getClass(), group);
                // A group with no fields yields null; there is nothing to exclude then.
                if (excluded != null) {
                    fields.removeAll(excluded);
                }
            }
        }
        for (Field field : fields) {
            try {
                FieldUtils.writeField(field, bean, null, true);
            } catch (IllegalAccessException e) {
                throw new RuntimeException("Failed to clear field " + field, e);
            }
        }
    }
}
package com.meituan.show.sell.web.exception;

import java.util.List;

import org.apache.commons.collections.ListUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.context.MessageSource;
import org.springframework.http.HttpStatus;
import org.springframework.validation.FieldError;

import com.google.common.collect.Lists;

/**
 * Thrown when a bean fails JSR 349 validation.
 *
 * <p>Passes the HTTP 400 (Bad Request) status code and a {@code "%s"} message pattern to
 * {@link APIException}. The field errors are handed over as the raw parameters. They are
 * resolved to text through the {@link MessageSource} each time {@link #getParams()} is called.
 *
 * <p>Created by liujing21 on 15/7/27.
 */
public class BeanInvalidException extends APIException {
    private static final long serialVersionUID = 9172883646594494205L;

    /**
     * Resolves each field error to a message.
     *
     * <p>Transient because exceptions are {@code Serializable} while MessageSource
     * implementations typically are not. After deserialization it is {@code null}, and the
     * errors' default messages are used instead.
     */
    private final transient MessageSource messageSource;

    /**
     * @param messageSource
     *            used to resolve the error messages
     * @param errors
     *            the field errors reported by validation
     */
    public BeanInvalidException(MessageSource messageSource, FieldError[] errors) {
        super("%s", HttpStatus.BAD_REQUEST.value(), (Object[]) errors);
        this.messageSource = messageSource;
    }

    /**
     * Replaces the raw {@link FieldError} parameters with a single parameter: the resolved error
     * messages joined by {@code ","}.
     *
     * @return an empty list when there are no parameters, otherwise a one-element list
     */
    @Override
    protected List<Object> getParams() {
        List<?> rawParams = super.getParams();
        if (rawParams == null) {
            @SuppressWarnings("unchecked")
            List<Object> empty = ListUtils.EMPTY_LIST;
            return empty;
        }
        List<String> messages = Lists.newArrayListWithCapacity(rawParams.size());
        for (Object param : rawParams) {
            messages.add(resolveMessage((FieldError) param));
        }
        List<Object> params = Lists.newArrayListWithCapacity(1);
        params.add(StringUtils.join(messages, ","));
        return params;
    }

    /** Resolves one error via the MessageSource, or uses its default message if the source is unavailable. */
    private String resolveMessage(FieldError err) {
        if (messageSource == null) {
            return err.getDefaultMessage();
        }
        return messageSource.getMessage(err, null);
    }
}