自定义节流
有些时候为了对用户的访问频率进行限制和防止爬虫,需要在规定的时间中对用户访问的次数进行限制
下面自定义一个用户每分钟只能访问3次,代码如下:
from rest_framework.throttling import BaseThrottle import time VISIT_RECORD = {} #保存访问记录 class VisitThrottle(BaseThrottle): '''60s内只能访问3次''' def __init__(self): self.history = None #初始化访问记录 def allow_request(self,request,view): #获取用户ip (get_ident) remote_addr = self.get_ident(request) ctime = time.time() #如果当前IP不在访问记录里面,就添加到记录 if remote_addr not in VISIT_RECORD: VISIT_RECORD[remote_addr] = [ctime,] #键值对的形式保存 return True #True表示可以访问 #获取当前ip的历史访问记录 history = VISIT_RECORD.get(remote_addr) #初始化访问记录 self.history = history #如果有历史访问记录,并且最早一次的访问记录离当前时间超过60s,就删除最早的那个访问记录, #只要为True,就一直循环删除最早的一次访问记录 while history and history[-1] < ctime - 60: history.pop() #如果访问记录不超过三次,就把当前的访问记录插到第一个位置(pop删除最后一个) if len(history) < 3: history.insert(0,ctime) return True def wait(self): '''还需要等多久才能访问''' ctime = time.time() return 60 - (ctime - self.history[-1])
上面的大码就是当用户第一次访问的时候把它的IP地址和当前访问的时间添加到字典 VISIT_RECORD 中, 循环取出最先添加的时间判断其时间有没有过了60s,如果过了则将其删掉,然后在看列表的长度是否小于3次,如果小于3次证明,是可以访问的的。wait表示的是还需要等待多久可以继续访问
我们在使用的时候,只需在类视图中简单在类属性 throttle_classes = [ 节流控制的类] 即可
下面我们在用户登陆的时候来简单的测试下,代码代码如下:
class AuthView(APIView): """ 用于用户登录认证 """ authentication_classes = [] permission_classes = [] throttle_classes = [VisitThrottle,] # 进行访问次数的限制 def post(self,request,*args,**kwargs): ret = {'code':1000,'msg':None} try: user = request._request.POST.get('username') pwd = request._request.POST.get('password') obj = models.UserInfo.objects.filter(username=user,password=pwd).first() if not obj: ret['code'] = 1001 ret['msg'] = "用户名或密码错误" # 为登录用户创建token token = md5(user) # 存在就更新,不存在就创建 models.UserToken.objects.update_or_create(user=obj,defaults={'token':token}) ret['token'] = token except Exception as e: print(e) ret['code'] = 1002 ret['msg'] = '请求异常' return JsonResponse(ret)
我们连续访问3次登陆测试的结果如下
节流源码分析
源码入口 dispatch 代码如下
def dispatch(self, request, *args, **kwargs): """ `.dispatch()` is pretty much the same as Django's regular dispatch, but with extra hooks for startup, finalize, and exception handling. """ self.args = args self.kwargs = kwargs # 对原生的request进行封装 request = self.initialize_request(request, *args, **kwargs) self.request = request self.headers = self.default_response_headers # deprecate? try: self.initial(request, *args, **kwargs) # Get the appropriate handler method if request.method.lower() in self.http_method_names: handler = getattr(self, request.method.lower(), self.http_method_not_allowed) else: handler = self.http_method_not_allowed response = handler(request, *args, **kwargs) except Exception as exc: response = self.handle_exception(exc) self.response = self.finalize_response(request, response, *args, **kwargs) return self.response
执行认证 self.initial 代码如下
def initial(self, request, *args, **kwargs): """ Runs anything that needs to occur prior to calling the method handler. """ self.format_kwarg = self.get_format_suffix(**kwargs) # Perform content negotiation and store the accepted info on the request neg = self.perform_content_negotiation(request) request.accepted_renderer, request.accepted_media_type = neg # Determine the API version, if versioning is in use. version, scheme = self.determine_version(request, *args, **kwargs) request.version, request.versioning_scheme = version, scheme # Ensure that the incoming request is permitted # 实现认证 self.perform_authentication(request) # 权限判断 self.check_permissions(request) # 访问频率控制 self.check_throttles(request)
访问频率控制self.check_throttles(request)代码如下:
def check_throttles(self, request): """ Check if request should be throttled. Raises an appropriate exception if the request is throttled. """ for throttle in self.get_throttles(): if not throttle.allow_request(request, self): self.throttled(request, throttle.wait())
在上面的源码中我们可以知道,如果没有通过会返回 false 执行 self.throttled(request, throttle.wait()),抛出异常,代码如下
def throttled(self, request, wait): """ If request is throttled, determine what kind of exception to raise. """ raise exceptions.Throttled(wait)
如果 返回True表示可以继续访问,让我们来继续追踪 self.get_throttles() 代码如下:
def get_throttles(self): """ Instantiates and returns the list of throttles that this view uses. """ # 返回访问频率控制的列表 return [throttle() for throttle in self.throttle_classes]
通过上面的源码我们知道为什么我们在需要进行节流限制的接口中设置类属性 throttle_class = [ 节流限制的类 ] ,还有就是在我们节流限制的类中为什么要重写
allow_request(request, self) 和 wait 方法
进行全局节流的配置
上面是我们自己定义的节流控制,走我们自己定义的类属性 throttle_class = [ 节流限制的类 ] 如果我们想使用父类的从而实现全局的配置又该如何去实现呢
父类中的 thtottle_classes 设置如下
所以我们只需在配置文件中的 REST_FRAMEWORK 添加 DEFAULT_THROTTLE_CLASSES 的路径即可,配置如下
REST_FRAMEWORK = { # 全局使用的认证类 "DEFAULT_AUTHENTICATION_CLASSES":['api.utils.auth.FirstAuthtication','api.utils.auth.Authtication', ], "UNAUTHENTICATED_USER":lambda :"匿名用户", "UNAUTHENTICATED_TOKEN":None, "DEFAULT_PERMISSION_CLASSES":['api.utils.permission.SVIPPermission'],# 默认的权限认证 "DEFAULT_THROTTLE_CLASSES":["api.utils.throttle.VisitThrottle"], # 进行节流的限制 }
根据上面节流限制类路径的定义,我们在应用 api 下的utils目录下创建 throttle.py 把代码如下
from rest_framework.throttling import BaseThrottle,SimpleRateThrottle import time VISIT_RECORD = {} class VisitThrottle(BaseThrottle): def __init__(self): self.history = None def allow_request(self,request,view): # 1. 获取用户IP remote_addr = self.get_ident(request) ctime = time.time() if remote_addr not in VISIT_RECORD: VISIT_RECORD[remote_addr] = [ctime,] return True history = VISIT_RECORD.get(remote_addr) self.history = history while history and history[-1] < ctime - 60: history.pop() if len(history) < 3: history.insert(0,ctime) return True # return True # 表示可以继续访问 # return False # 表示访问频率太高,被限制 def wait(self): # 还需要等多少秒才能访问 ctime = time.time() return 60 - (ctime - self.history[-1])
我们的视图类中就不需要再添加节流的类属性的配置了,就可以实现节流的控制。简单的代码如下
class AuthView(APIView): """ 用于用户登录认证 """ authentication_classes = [] permission_classes = [] def post(self,request,*args,**kwargs): pass