问题:控制访问频率,在访问的时候加上一定的次数限制
基本实现
views.py
class VisitThrottle(object):
def allow_request(self, request, view):
return True # 可以继续访问
# return False # 访问频率太高, 被限制
def wait(self):
pass
可以进一步的升级,限制 10s 内只能访问3次
import time
VISIT_RECORD = {}
class VisitThrottle(object):
'''
10s内只能访问3次
'''
def allow_request(self, request, view):
# 1. 获取用户IP
remote_addr = request.META.get('REMOTE_ADDR')
ctime = time.time()
if remote_addr not in VISIT_RECORD:
VISIT_RECORD[remote_addr] = [ctime, ]
return True
history = VISIT_RECORD.get(remote_addr)
while history and history[-1] < ctime - 10:
history.pop()
if len(history) < 3:
history.insert(0, ctime)
return True
# return True # 可以继续访问
# return False # 访问频率太高, 被限制
def wait(self):
'''
还需要等待的时间
'''
ctime = time.time()
return 60 - (ctime - self.history[-1])
源码流程
和前面一样,也是从 dispatch
开始,到 initial
def initial(self, request, *args, **kwargs):
# Ensure that the incoming request is permitted
self.perform_authentication(request)
self.check_permissions(request)
# 控制访问频率
self.check_throttles(request)
def check_throttles(self, request):
# get_throttles 里面是一个列表生成式
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait())
def get_throttles(self):
"""
Instantiates and returns the list of throttles that this view uses.
"""
return [throttle() for throttle in self.throttle_classes]
throttle_classes
默认使用配置文件
class APIView(View):
...
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
...
可以添加到全局使用,首先在 utils 下新建 throttle.py,将视图文件中的类移至 throttle.py,这里修改了 60s内能访问3次
# throttle.py
import time
VISIT_RECORD = {}
class VisitThrottle(object):
'''
60s内只能访问3次
'''
def __init__(self):
self.history = None
def allow_request(self, request, view):
# 1. 获取用户IP
remote_addr = request.META.get('REMOTE_ADDR')
ctime = time.time()
if remote_addr not in VISIT_RECORD:
VISIT_RECORD[remote_addr] = [ctime, ]
return True
history = VISIT_RECORD.get(remote_addr)
self.history = history
while history and history[-1] < ctime - 60:
history.pop()
if len(history) < 3:
history.insert(0, ctime)
return True
# return True # 可以继续访问
# return False # 访问频率太高, 被限制
def wait(self):
'''
还需要等待的时间
'''
ctime = time.time()
return 60 - (ctime - self.history[-1])
然后在配置文件 settings.py 中添加路径
REST_FRAMEWORK = {
...
'DEFAULT_THROTTLE_CLASSES': ['api.utils.throttle.VisitThrottle']
}
最后将视图中的局部配置删除即可。
回到 check_throttles
def check_throttles(self, request):
for throttle in self.get_throttles():
# throttle.allow_request 为 False,走下一步,throttled 抛出异常,表示访问频率过多
if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait())
def throttled(self, request, wait):
"""
If request is throttled, determine what kind of exception to raise.
"""
raise exceptions.Throttled(wait)
频率的内置类
在自定义频率的时候,为了更加规范,需要继承,并且父类有获取 IP 的方法(可以在 BaseThrottle
中查看),因此这里直接调用父类的方法即可
from rest_framework.throttling import BaseThrottle
import time
VISIT_RECORD = {}
class VisitThrottle(BaseThrottle):
'''
60s内只能访问3次
'''
def __init__(self):
self.history = None
def allow_request(self, request, view):
# 1. 获取用户IP,调用父类的方法
remote_addr = self.get_ident(request)
ctime = time.time()
if remote_addr not in VISIT_RECORD:
VISIT_RECORD[remote_addr] = [ctime, ]
return True
history = VISIT_RECORD.get(remote_addr)
self.history = history
while history and history[-1] < ctime - 60:
history.pop()
if len(history) < 3:
history.insert(0, ctime)
return True
# return True # 可以继续访问
# return False # 访问频率太高, 被限制
def wait(self):
'''
还需要等待的时间
'''
ctime = time.time()
return 60 - (ctime - self.history[-1])
进入 BaseThrottle
,发现在其下方有个 SimpleRateThrottle
,也是继承 BaseThrottle
。首先看 SimpleRateThrottle
的 __init__
方法
class SimpleRateThrottle(BaseThrottle):
... # 省略的内容
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
if not getattr(self, 'rate', None):
# 这里执行了 get_rate 方法
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
def get_rate(self):
"""
Determine the string representation of the allowed request rate.
"""
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
# scope实际上是一个字典的 key,这里在 THROTTLE_RATES 中取值
# 在上面的代码中看到 THROTTLE_RATES 是一个配置项,获取用户自定义的配置
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
至此,就可以在配置文件中写一个 60s内能访问3次 的程序,让它自动完成,无需自定义写
throttle.py
class VisitThrottle(SimpleRateThrottle):
scope = "xi" # scope作为key使用
settings.py
REST_FRAMEWORK = {
... # 省略
'DEFAULT_THROTTLE_CLASSES': ['api.utils.throttle.VisitThrottle'],
'DEFAULT_THROTTLE_RATES' : {
'xi': '3/m' # m是分钟,每分钟访问3次
}
}
这时,配置了访问次数,就会在 return self.THROTTLE_RATES[self.scope]
中获取到,返回给 get_rate
方法,然后 __init__
中的 rate
拿到的就是 3/m
class SimpleRateThrottle(BaseThrottle):
... # 省略的内容
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
if not getattr(self, 'rate', None):
# '3/m'
self.rate = self.get_rate()
# 将字符串 '3/m' 当做参数传递给 parse_rate
# 走完 parse_rate,num_requests代表3次,duration代表60s
self.num_requests, self.duration = self.parse_rate(self.rate)
.... # 省略
def parse_rate(self, rate):
"""
Given the request rate string, return a two tuple of:
<allowed number of requests>, <period of time in seconds>
"""
# rate就是 '3/m'
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)
此时,构造函数走完,接着查看 allow_request
def allow_request(self, request, view):
if self.rate is None:
return True
# 内置提供的访问记录放在了缓存中,通过 get_cache_key 实现
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
# 来到 get_cache_key,源码上并没有写什么,这表示是让我们自己写的
def get_cache_key(self, request, view):
raise NotImplementedError('.get_cache_key() must be overridden')
# get_cache_key 实际上是表示能够唯一标识的方法,所以返回值可以是获取IP,用来表示谁的访问记录
# throttle.py
class VisitThrottle(SimpleRateThrottle):
scope = "xi"
def get_cache_key(self, request, view):
return self.get_ident(request) # 获取IP
回到 allow_request
def allow_request(self, request, view):
if self.rate is None:
return True
# 内置提供的访问记录放在了缓存中,通过 get_cache_key 实现
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
# 去缓存中取出所有记录
# cache = default_cache,是django内置的缓存
self.history = self.cache.get(self.key, [])
self.now = self.timer() # timer() = time.time(),获取当前时间
# Drop any requests from the history which have now passed the
# throttle duration
# 这里与上面自定义的相同
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
def throttle_success(self):
"""
Inserts the current request's timestamp along with the key
into the cache.
"""
# 如果成功,加到历史记录中
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
"""
Called when a request to the API has failed due to throttling.
"""
return False
def wait(self):
"""
Returns the recommended next request time in seconds.
"""
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
if available_requests <= 0:
return None
return remaining_duration / float(available_requests)
照样是前三次可以访问,后面再访问需要等一分钟,这是对匿名用户的控制
也可以对登录的用户进行控制,但在全局的设置中,不能既有匿名的,还有登录的。这时,就可以将登录用户的访问控制设为全局,匿名用户使用局部的设置。
settings.py
REST_FRAMEWORK = {
'DEFAULT_AUTHENTICATION_CLASSES': ['api.utils.auth.FirstAuthentication', 'api.utils.auth.Authentication'],
# 'DEFAULT_AUTHENTICATION_CLASSES': ['api.utils.auth.FirstAuthentication', ],
'UNAUTHENTICATED_USER': None,
'UNAUTHENTICATED_TOKEN': None,
'DEFAULT_PERMISSION_CLASSES': ['api.utils.permission.SVIPPermission'],
'DEFAULT_THROTTLE_CLASSES': ['api.utils.throttle.UserThrottle'], # 登录用户
'DEFAULT_THROTTLE_RATES' : {
'xi': '3/m',
'xiUser': '10/m'
}
}
throttle.py
# 匿名用户
class VisitThrottle(SimpleRateThrottle):
scope = "xi"
def get_cache_key(self, request, view):
return self.get_ident(request)
# 登录用户
class UserThrottle(SimpleRateThrottle):
scope = "xiUser"
def get_cache_key(self, request, view):
return request.user.username
views.py
from django.shortcuts import render, HttpResponse
from django.http import JsonResponse
from rest_framework.views import APIView
from api import models
from api.utils.permission import SVIPPermission, MyPermission
from api.utils.throttle import VisitThrottle
ORDER_DICT = {
1: {
'name': 'qiu',
'age': 18,
'gender': '男',
'content': '...'
},
2: {
'name': 'xi',
'age': 19,
'gender': '男',
'content': '.....'
}
}
def md5(user):
import hashlib
import time
ctime = str(time.time())
m = hashlib.md5(bytes(user, encoding='utf-8'))
m.update(bytes(ctime, encoding='utf-8'))
return m.hexdigest()
class AuthView(APIView):
authentication_classes = []
permission_classes = []
throttle_classes = [VisitThrottle] # 为匿名用户设置频率控制
def post(self, request, *args, **kwargs):
ret = {'code': 1000, 'msg': None}
try:
user = request._request.POST.get('username')
pwd = request._request.POST.get('password')
obj = models.UerInfo.objects.filter(username=user, password=pwd).first()
if not obj:
ret['code'] = 1001
ret['msg'] = '用户名或密码错误'
# 为登录用户创建token
else:
token = md5(user)
# 存在就更新, 不存在就创建
models.UserToken.objects.update_or_create(user=obj, defaults={'token': token})
ret['token'] = token
except Exception as e:
ret['code'] = 1002
ret['msg'] = '请求异常'
return JsonResponse(ret)
class OrderView(APIView):
'''
订单相关业务(只有SVIP用户有权限)
'''
def get(self, request, *args, **kwargs):
ret = {'code': 1000, 'msg': None, 'data': None}
try:
ret['data'] = ORDER_DICT
except Exception as e:
pass
return JsonResponse(ret)
class UserInfoView(APIView):
'''
用户中心(普通用户、VIP有权限)
'''
permission_classes = [MyPermission]
def get(self, request, *args, **kwargs):
return HttpResponse('用户信息')
总结
使用
类,继承
BaseThrottle
,实现allow_request
、wait
类,继承
SimpleRateThrottle
,实现get_cache_key
、scope = "xi"(配置文件中的key)
局部:
throttle_classes = [VisitThrottle]
全局:配置
settings.py