1.一段时间内ip连接数大于一定值则断开该ip所有连接且拒绝ip一定时间内连接
2.一段时间内ip连接所发送的数据大于一定值则断开该ip所有连接且拒绝ip一定时间内连接
其实是实现判断频率的一种算法。有一个陷阱是如何判断连续时间内的频率:如果把时间分成一段一段,按不同的起始时间来分的话,得到的频率会不一样。所以这个判断算法是有一定精度的,精度就是多久判断一次,相当于时间窗口每次移动的步长。
代码:
限制频率的类:
package com.cgs.iot.io.nio.handler;
/**
 * Sliding-window frequency limiter.
 *
 * <p>A window of {@code interval} milliseconds is split into {@code accuracy} slots. Every call to
 * {@link #limitIpFrequencyWithAccuracy()} shifts the slot counters by the number of slots that
 * elapsed since the previous call, adds {@code count} to the newest slot and compares the sum of
 * all slots with {@code max}. The outcome is available through {@link #getResult()}.
 *
 * <p>This class performs no synchronization; callers that share one instance between threads
 * must serialize access themselves.
 */
public class FrequencyLimitation {

    /** Result value meaning the frequency is still within the limit. */
    public static final Boolean WITHIN_LIMIT = true;
    /** Result value meaning the frequency exceeded the limit. */
    public static final Boolean EXCEED_LIMIT = false;

    protected long interval;    // window length, in milliseconds
    protected long max;         // largest allowed sum inside one window
    protected int accuracy;     // number of slots the window is divided into
    protected long[] countList; // per-slot counters; the last element is the newest slot
    protected int index = 0;    // slot number (relative to begin) seen by the previous call
    protected boolean flag;     // set to true once the limit has been exceeded; never cleared here
    protected long time;        // timestamp (ms) of the most recent call that exceeded the limit
    protected int count = 1;    // amount added to the newest slot on every call
    protected long begin = System.currentTimeMillis(); // start of the current counting period (ms)
    boolean result;             // outcome of the most recent call

    /**
     * Creates a limiter with an explicit initial {@code flag} and {@code time}.
     *
     * @param interval window length in milliseconds
     * @param max      largest allowed sum of counts inside one window
     * @param accuracy number of slots the window is divided into; must be positive
     * @param flag     initial value of the "limit exceeded" flag
     * @param time     initial value of the "limit exceeded at" timestamp
     * @throws IllegalArgumentException if {@code accuracy} is not positive
     */
    public FrequencyLimitation(long interval, long max, int accuracy, boolean flag, long time) {
        if (accuracy <= 0) {
            throw new IllegalArgumentException("accuracy must be positive: " + accuracy);
        }
        this.interval = interval;
        this.max = max;
        this.accuracy = accuracy;
        this.countList = new long[accuracy];
        this.flag = flag;
        this.time = time;
    }

    /**
     * Creates a limiter whose flag starts as {@code false} and whose timestamp starts as 0.
     *
     * @param interval window length in milliseconds
     * @param max      largest allowed sum of counts inside one window
     * @param accuracy number of slots the window is divided into; must be positive
     * @throws IllegalArgumentException if {@code accuracy} is not positive
     */
    public FrequencyLimitation(long interval, long max, int accuracy) {
        this(interval, max, accuracy, false, 0L);
    }

    /**
     * Records one occurrence (weighted by {@code count}) and re-evaluates the limit.
     *
     * <p>When the sum over the window is greater than {@code max}, {@code flag} is set, {@code time}
     * is updated and the result becomes {@link #EXCEED_LIMIT}; otherwise the result becomes
     * {@link #WITHIN_LIMIT}. Once more than {@code interval} milliseconds have passed since
     * {@code begin}, the counting period is restarted from the current time.
     */
    public void limitIpFrequencyWithAccuracy() {
        long current = System.currentTimeMillis();
        long delay = current - begin;
        int indexCurrent = slotIndexFor(delay, interval, accuracy);

        // Subtract as long so an externally set negative index cannot overflow the difference.
        countList = rollForward(countList, (long) indexCurrent - index);
        updateCountList(countList, count);

        if (outOfBoundary(countList, max)) {
            flag = true;
            time = System.currentTimeMillis();
            setResult(EXCEED_LIMIT);
        } else {
            setResult(WITHIN_LIMIT);
        }

        if (outOfTimeBoundary(delay, interval)) {
            index = 0;
            begin = System.currentTimeMillis();
        } else {
            index = indexCurrent;
        }
    }

    /** Returns the sum of all slot counters. */
    private long sum(long[] counters) {
        long total = 0;
        for (long value : counters) {
            total += value;
        }
        return total;
    }

    /**
     * Returns a copy of {@code counters} shifted towards the front, e.g. [0,1,2,3,4] shifted by 2
     * becomes [2,3,4,0,0]. A shift of the array length or more yields all zeros; a negative shift
     * (clock moved backwards, or index changed through the setter) is treated as no shift.
     */
    private long[] rollForward(long[] counters, long shift) {
        long[] rolled = new long[counters.length];
        long effectiveShift = Math.max(0L, shift);
        if (effectiveShift < counters.length) {
            int offset = (int) effectiveShift;
            System.arraycopy(counters, offset, rolled, 0, counters.length - offset);
        }
        return rolled;
    }

    /**
     * Maps an elapsed time to a slot number.
     *
     * <p>The slot width is {@code interval / accuracy} but never less than 1 ms, so an interval
     * shorter than the number of slots no longer divides by zero. Non-positive delays map to slot 0
     * and very large delays saturate at {@link Integer#MAX_VALUE} instead of overflowing the cast.
     */
    private int slotIndexFor(long delay, long interval, int accuracy) {
        if (delay <= 0) {
            return 0;
        }
        long slotWidth = Math.max(1L, interval / Math.max(1, accuracy));
        return (int) Math.min(delay / slotWidth, (long) Integer.MAX_VALUE);
    }

    /** Adds {@code amount} to the newest (last) slot and returns the same array. */
    private long[] updateCountList(long[] counters, int amount) {
        counters[counters.length - 1] += amount;
        return counters;
    }

    /** Returns whether the sum of all slots is greater than {@code limit}. */
    private boolean outOfBoundary(long[] counters, long limit) {
        return sum(counters) > limit;
    }

    /** Returns whether {@code delay} is greater than {@code limit}. */
    private boolean outOfTimeBoundary(long delay, long limit) {
        return delay > limit;
    }

    public long getBegin() {
        return begin;
    }

    public void setBegin(long begin) {
        this.begin = begin;
    }

    public long[] getCountList() {
        return countList;
    }

    public void setCountList(long[] countList) {
        this.countList = countList;
    }

    public int getIndex() {
        return index;
    }

    public void setIndex(int index) {
        this.index = index;
    }

    public boolean isFlag() {
        return flag;
    }

    public void setFlag(boolean flag) {
        this.flag = flag;
    }

    public long getTime() {
        return time;
    }

    public void setTime(long time) {
        this.time = time;
    }

    /** Returns the outcome of the most recent call to {@link #limitIpFrequencyWithAccuracy()}. */
    public boolean getResult() {
        return result;
    }

    public void setResult(boolean result) {
        this.result = result;
    }

    public long getMax() {
        return max;
    }

    public void setMax(long max) {
        this.max = max;
    }

    public long getInterval() {
        return interval;
    }

    public void setInterval(long interval) {
        this.interval = interval;
    }
}
ip限制的类:
package com.cgs.iot.io.nio.handler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
public class IpLimitHandler extends ChannelInboundHandlerAdapter {
final Logger logger = LoggerFactory.getLogger(IpLimitHandler.class);
FrequencyLimitation frequencyLimitation; //限制频率算法
public IpLimitHandlerNew(long interval, long max, int accuracy, boolean flag, long time) {
super();
frequencyLimitation = new FrequencyLimitation(interval, max, accuracy, flag, time);
}
/**
* 触发计数
* @param ctx
* @param msg
*/
protected void trigger(ChannelHandlerContext ctx, Object msg) {
logger.debug(" ip: "+ctx.channel().remoteAddress() +" channel: "+ctx.channel());
frequencyLimitation.limitIpFrequencyWithAccuracy();
ctxExcute(ctx, msg, frequencyLimitation.getResult(), Thread.currentThread() .getStackTrace()[1].getMethodName());
}
/**
* 根据结果字段和方法名字段对ctx进行处理
* @param ctx ChannelHandlerContext
* @param msg 信息
* @param result 结果字段
* @param method 方法名字段
*/
protected void ctxExcute(ChannelHandlerContext ctx, Object msg, boolean result, String method) {
if(FrequencyLimitation.EXCEED_LIMIT.equals(result)) {
logger.info("Close ip: " + ctx.channel().remoteAddress() + ", this ip exceeds the limitMax " + frequencyLimitation.max + " in limitInterval " + frequencyLimitation.interval + "ms");
ctx.close();
clearChannel(ctx.channel());
}else {
if("channelRead".equals(method)) {
ctx.fireChannelRead(msg);
}else if("channelActive".equals(method)){
ctx.fireChannelActive();
}
}
}
protected void clearChannel(Channel channel) {
}
/**
* 获得ByteBuf中的字节数
* @param buf
* @return
*/
protected int getByteCountFromByteBuf(ByteBuf buf) {
buf.markReaderIndex();
byte[] bs = new byte[buf.readableBytes()];
return bs.length;
}
}
限制连接的类:
package com.cgs.iot.io.nio.handler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelHandler.Sharable;
/**
 * Connection-rate handler: reports every {@code channelActive} event to the frequency check
 * inherited from {@link IpLimitHandler}.
 *
 * <p>NOTE(review): the class is marked {@code @Sharable} while the parent holds one mutable
 * FrequencyLimitation per handler instance, so a shared instance counts all channels together
 * and may be called from several threads — confirm this is intended.
 */
@Sharable
public class IpLimitConnectHandler extends IpLimitHandler {
/**
 * Passes all arguments unchanged to the parent constructor.
 *
 * @param interval window length in milliseconds
 * @param max      largest allowed count inside one window
 * @param accuracy number of slots the window is divided into
 * @param flag     initial limiter flag; auto-unboxed, so {@code null} throws NullPointerException
 * @param time     initial limiter timestamp; auto-unboxed, so {@code null} throws NullPointerException
 */
public IpLimitConnectHandler(long interval, long max, int accuracy, Boolean flag, Long time) {
super(interval, max, accuracy, flag, time);
}
/**
 * Reports the newly active channel to the limiter. The event is not forwarded here; what
 * happens next (close or propagate) is decided inside the parent's trigger/ctxExcute.
 *
 * @param ctx context of the channel that became active
 */
@Override
public void channelActive(ChannelHandlerContext ctx)
throws Exception {
// No message accompanies channelActive, hence null. The call must stay directly in this
// method: the parent reads a method name from the stack trace to choose the event to forward.
trigger(ctx, null);
}
}
限制数据的类:
package com.cgs.iot.io.nio.handler;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
/**
 * Data-rate handler: weighs every inbound {@link ByteBuf} by its readable byte count and
 * reports it to the frequency check inherited from {@link IpLimitHandler}.
 */
public class IpLimitDataHandler extends IpLimitHandler {

    /**
     * Passes all arguments unchanged to the parent constructor.
     *
     * @param interval window length in milliseconds
     * @param max      largest allowed number of bytes inside one window
     * @param accuracy number of slots the window is divided into
     * @param flag     initial limiter flag; auto-unboxed, so it must not be {@code null}
     * @param time     initial limiter timestamp; auto-unboxed, so it must not be {@code null}
     */
    public IpLimitDataHandler(long interval, long max, int accuracy, Boolean flag, Long time) {
        super(interval, max, accuracy, flag, time);
    }

    /**
     * Counts the bytes of an inbound buffer against the limit. Messages that are not a
     * {@link ByteBuf} cannot be measured and are forwarded untouched.
     *
     * @param ctx context of the channel the message arrived on
     * @param msg the inbound message
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        if (!(msg instanceof ByteBuf)) {
            ctx.fireChannelRead(msg);
            return;
        }
        // The per-call weight lives on the limiter, not on this handler.
        frequencyLimitation.count = getByteCountFromByteBuf((ByteBuf) msg);
        // Keep this call directly in channelRead: the parent reads a method name from the
        // stack trace to choose which event to forward.
        trigger(ctx, msg);
    }
}