通过注解的形式去实现,需要用到的jar是
# gradle
implementation("net.jodah:expiringmap:0.5.8")
或者
# maven
<dependency>
<groupId>net.jodah</groupId>
<artifactId>expiringmap</artifactId>
<version>0.5.8</version>
</dependency>
新建注解类
package com.yulisao.common;
import java.lang.annotation.*;
/**
* 请求次数限制
* author yulisao
* createDate 2023/5/5
*/
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface LimitRequest {
long time() default 60*1000; // 单位时间内 ,默认一分钟
int count() default 10; // 单位时间内限制请求次数, 默认10次
}
新建一个切面类
package com.yulisao.common;
import net.jodah.expiringmap.ExpirationPolicy;
import net.jodah.expiringmap.ExpiringMap;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* 请求次数限制切面
* author yulisao
* createDate 2023/5/5
*/
@Aspect
@Component
public class LimitRequestAspect {
private Logger log = LoggerFactory.getLogger(this.getClass());
private static ConcurrentHashMap<String, ExpiringMap<String, Integer>> book = new ConcurrentHashMap<>();
/**
* url上带参数的请求,如需限制请求次数,需在这里配置url前缀(也可以改成读取配置表)
* 比如下载文件的url是 '/file/upload?id=3' 或者 '/file/upload/3' ,应当配置成 ’/file/upload‘
* 因为每次上传,由于拼接了参数,其完整的url都不一样,后面是根据url+ip来累计请求次数的
*/
private List<String> spcUrlList = Arrays.asList(
"/file/upload",
"/file/dowm",
"/user/update"
);
// 定义切点 让所有有@LimitRequest注解的方法都执行切面方法
@Pointcut("@annotation(limitRequest)")
public void excudeService(LimitRequest limitRequest) {
}
@Around("excudeService(limitRequest)")
public Object doAround(ProceedingJoinPoint pjp, LimitRequest limitRequest) throws Throwable {
RequestAttributes ra = RequestContextHolder.getRequestAttributes();
ServletRequestAttributes sra = (ServletRequestAttributes) ra;
HttpServletRequest request = sra.getRequest();
String ip = getIpAddr(request);
String url = request.getServletPath();
log.info("request url is " + url);
log.info("request ip is " + ip);
// 带参数的url,取前面固定不变的部分作为url存map的key
String prefix = getPathPrefix(url);
if (StringUtils.isNotBlank(prefix)) {
url = prefix;
}
// 根据请求的url+用户真实ip作为key,记录单位时间内请求次数
ExpiringMap<String, Integer> uc = book.getOrDefault(url, ExpiringMap.builder().variableExpiration().build());
Integer uCount = uc.getOrDefault(ip, 0);
log.info("request uCount is " + uCount);
if (uCount >= limitRequest.count()) {
// 超过次数,不执行目标方法
throw new Exception("请求频繁,请稍后在试!");
} else if (uCount == 0){
// 第一次请求时,设置有效时间
uc.put(ip, uCount + 1, ExpirationPolicy.CREATED, limitRequest.time(), TimeUnit.MILLISECONDS);
} else {
// 未超过次数, 记录加一
uc.put(ip, uCount + 1);
}
book.put(url, uc);
// result的值就是被拦截方法的返回值
return pjp.proceed();
}
/**
* 获取请求IP
* @param request
* @return
*/
public static String getIpAddr(HttpServletRequest request) {
String ipAddress = request.getHeader("x-forwarded-for");
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getHeader("Proxy-Client-IP");
}
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getHeader("WL-Proxy-Client-IP");
}
if (ipAddress == null || ipAddress.length() == 0 || "unknown".equalsIgnoreCase(ipAddress)) {
ipAddress = request.getRemoteAddr();
if (ipAddress.equals("127.0.0.1") || ipAddress.equals("0:0:0:0:0:0:0:1")) {
InetAddress inet = null; //根据网卡取本机配置的IP
try {
inet = InetAddress.getLocalHost();
} catch (UnknownHostException e) {
e.printStackTrace();
}
ipAddress = inet.getHostAddress();
}
}
//对于通过多个代理的情况,第一个IP为客户端真实IP,多个IP按照逗号分割
if (ipAddress != null && ipAddress.length() > 15) {
// ***.***.***.***
if (ipAddress.indexOf(",") > 0) {
ipAddress = ipAddress.substring(0, ipAddress.indexOf(",")); // 截取第一个IP
}
}
return ipAddress;
}
private String getPathPrefix(String url) {
for(int i=0; i < spcUrlList.size(); i++){
Pattern pattern = Pattern.compile(spcUrlList.get(i));
Matcher matcher = pattern.matcher(url);
if(matcher.find()){
//matcher.find()-为模糊查询 matcher.matches()-为精确查询
return spcUrlList.get(i);
}
}
return null;
}
}
给需要限制请求次数的接口添加自定义注解
@ApiOperation("上传图片")
@GetMapping("/file/upload/{id}")
@LimitRequest(time = 60*1000, count = 10) // 两个参数,这里也可以重新赋值
public void workUpLoadPic(@PathVariable("id") String id){
// dosomething...
}
- 对于url固定不变的,给接口上直接加上LimitRequest注解即可。可以为不同的接口给不同的限制策略,比如获取验证码接口一分钟一次, 实名认证一天三次等等
- 而url上带参数的,需要获取url前面固定的前缀作为url的唯一标识,这样后面每请求一次才会被累加一次记录下来。不然每次请求都是一个新的key存入map,其val都是1,永远不会超限(除非是重复请求参数不变)。获取url前缀我用的是matcher.find,当然也可以换成indexof,startWith等思路,灵活应用就好。
除了注解的实现方式, 也可以通过拦截器+redis缓存之类的去实现。