参数校验是个让人很烦的东西,大量重复而又无技术含量代码充斥在所有代码中. Validator就我在项目中的应用来说只能适用于bean对象属性的校验,无法对方法的入参参数做校验.
基于懒的原因自己尝试着写一套注解+反射+aop来实现方法入参校验的功能,最终成功并建议leader在项目中应用该工具.
废话不多说,下面先来讲思路,再来贴代码分析.
思路就是在方法上或者方法的参数上加注解,比如我的注解@NotNull,然后通过aop before来代理需要对方法做参数校验的方法.
然后在aop逻辑中写对应的校验逻辑.
校验逻辑如下:
首先拿到当前方法及方法上的注解
如果方法上有@NotNull注解则表明该方法所有的参数均不能为Null
然后拿到方法的入参参数名称和参数值,一个一个比对是否是Null,如果有一个为null直接抛出RuntimeException
若方法上没有@NotNull注解,则开始下一步.拿到参数上的注解,并分别判断当前参数上是存在@NotNull注解.
如果存在则判断当前参数值是否为null,如果是则直接抛出RuntimeException
贴代码
自定义注解
很简单,没啥说的,支持方法和方法参数上使用
package com.hikedu.api.validation;
import javax.validation.Payload;
import java.lang.annotation.*;
@Target({ElementType.METHOD,ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface NotNull {
String message() default "";
Class<?>[] groups() default {};
Class<? extends Payload>[] payload() default {};
}
校验逻辑
也很简单,没啥说的,不懂的自己google
package com.hikedu.api.aspect;
import com.hikedu.api.validation.NotNull;
import javassist.*;
import javassist.bytecode.CodeAttribute;
import javassist.bytecode.LocalVariableAttribute;
import javassist.bytecode.MethodInfo;
import org.apache.log4j.Logger;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.springframework.stereotype.Component;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
* @author himly [email protected]
*/
@Aspect
@Component
public class ArgumentVerifyAspect {
private static final Logger log = Logger.getLogger(ArgumentVerifyAspect.class);
@Before("execution(* com.hikedu.api.service.impl.*.*(..))")
public void doArgumentVerifyForService(JoinPoint joinPoint) throws Exception{
try{
doArgumentVerify(joinPoint);
}catch (Exception e) {
log.error("has an error,see == " + e.getMessage(),e);
throw new RuntimeException(e.getMessage());
}
}
private void doArgumentVerify(JoinPoint joinPoint) throws Exception{
String methodName = joinPoint.getSignature().getName();
Method method = getMethod(methodName,joinPoint);
Object[] args = joinPoint.getArgs();
String classType = joinPoint.getTarget().getClass().getTypeName();
Class<?> clz = Class.forName(classType);
String clzName = clz.getName();
List<Map<String,Object>> nameAndArgs = getFieldsNameAndValue(this.getClass(),clzName,methodName,args);
Annotation[] methodAnnotations = method.getAnnotations();
if (isAnnotationExists(methodAnnotations,NotNull.class)) {
int i = 0;
while (nameAndArgs.size() > i) {
if (nameAndArgs.get(i).get("value") == null) {
throw new RuntimeException(nameAndArgs.get(i).get("name") + " can not be null");
}
++i;
}
}
Annotation[][] argsAnnotations = method.getParameterAnnotations();
int i = 0;
for (Annotation[] annotation:argsAnnotations) {
boolean isExists = isAnnotationExists(annotation,NotNull.class);
if (isExists) {
if (nameAndArgs.get(i).get("value") == null) {
throw new RuntimeException(nameAndArgs.get(i).get("name") + " can not be null");
}
}
++i;
}
}
private Method getMethod(String name, JoinPoint joinPoint) {
Method[] methods = joinPoint.getTarget().getClass().getMethods();
Method realMethod = null;
for (Method method :methods) {
if (method.getName().equals(name)) {
realMethod = method;
break;
}
}
if (realMethod == null) {
throw new RuntimeException("method not found");
}
return realMethod;
}
public boolean isAnnotationExists(Annotation[] annotations,Class clz) {
for (Annotation annotation:annotations) {
Class<? extends Annotation> annotationType = annotation.annotationType();
if (annotationType.getName().equals(clz.getName())) {
return true;
}
}
return false;
}
/**
* 通过javassist反射机制 获取被切参数名以及参数值
*
* @param cls
* @param clazzName
* @param methodName
* @param args
* @return
* @throws NotFoundException
*/
private List<Map<String,Object>> getFieldsNameAndValue(Class cls, String clazzName, String methodName, Object[] args) throws NotFoundException {
List<Map<String,Object>> list = new ArrayList<>(16);
ClassPool pool = ClassPool.getDefault();
ClassClassPath classPath = new ClassClassPath(cls);
pool.insertClassPath(classPath);
CtClass cc = pool.get(clazzName);
CtMethod cm = cc.getDeclaredMethod(methodName);
MethodInfo methodInfo = cm.getMethodInfo();
CodeAttribute codeAttribute = methodInfo.getCodeAttribute();
LocalVariableAttribute attr = (LocalVariableAttribute) codeAttribute.getAttribute(LocalVariableAttribute.tag);
if (attr == null) {
throw new RuntimeException("attr not found");
}
int pos = Modifier.isStatic(cm.getModifiers()) ? 0 : 1;
for (int i = 0; i < cm.getParameterTypes().length; i++) {
LinkedHashMap<String, Object> map = new LinkedHashMap<>(2);
map.put("name",attr.variableName(i+pos));
map.put("value",args[i]);
list.add(map);
}
return list;
}
}
好了,至此思路代码都OK了,下面放个github源码链接,要研究的自己下载吧. github源码