模型交叉验证概念及代码总结
文章目录
几种常见交叉验证概念总结
交叉验证(Cross Validation)是验证分类器性能的一种统计分析方法,其基本思想是在某种意义下将原始数据进行分组,一部分作为训练集,另一部分作为验证集。首先用训练集对分类器进行训练,再利用验证集来测试训练得到的模型,以此来作为评价分类器的性能指标。通常的交叉验证方法包括简单交叉验证、K 折交叉验证、留一法交叉验证和留 P 法交叉验证。
1. 简单交叉验证概念
- 简单交叉验证(Cross Validation))是将原始数据随机分为两组,一组作为训练集,另一组作为验证集,利用训练集训练分类器,然后利用验证集验证模型,将最后的分类准确率作为此分类器的性能指标,通常划分30%的数据作为验证数据集。
- 调用方法:
from sklearn.model_selection import train_test_split
train_data, test_data, train_target, test_target = train_test_split(iris.data, iris.target, test_size=0.4, random_state=0) # random_stage 为随机种子
2. K 折交叉验证概念
- K 折交叉验证(K-Fold Cross Validation),是将原始数据分成 K 组(一般是均分),然后将每个子集数据分别做一次验证集,其余的 K-1 组子集数据作为训练集,这样就会得到 K 个模型,将 K 个模型最终的验证集的分类准确率取平均值,作为 K 折交叉验证分类器的性能指标。通常设置 K 大于或等于3。
- 调用方法:
from sklearn.model_selection import KFold
kf = KFold(n_splits=5) # 5 折交叉验证实例化
3. 留一法交叉验证概念
- 留一法交叉验证(Leave-One-Out Cross Validation,LOO-CV),是指每个训练集由除一个样本之外的其余样本组成,留下的一个样本组成验证集。这样,对于 N 个样本的数据集,可以组成 N 个不同的训练集和 N 个不同的验证集,因此 LOO-CV 会得到 N 个模型,用 N 个模型最终的验证集的分类准确率的平均数作为分类器的性能指标。
- 调用方法:
from sklearn.model_selection import LeaveOneOut
loo = LeaveOneOut()
4. 留 P 法交叉验证概念
- 留 P 法交叉验证(Leave-P-Out Cross Validation,LPO-CV),与留一法交叉验证类似,是从完整的数据集中删除 p 个样本,产生所有可能的训练集,对于 N 个样本,能产生 (N, p) 个训练-验证对。
- 调用方法:
from sklearn.model_selection import LeavePOut
lpo = LeavePOut(p=5)
几种常见交叉验证在 sklearn 中的调用方法示例代码
1. 简单交叉验证代码
- 使用简单交叉验证方法对模型进行交叉验证并切分数据,其中训练数据为80%,验证数据为20%,代码如下
简单交叉验证(1折交叉验证)用到了 train_test_split() 函数
:
## 简单交叉验证
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split # 导入切分数据包
# 切分数据,训练数据为80%,验证数据为20%
train_data, test_data, train_target, test_target = train_test_split(train, target, target_size=0.8, random_state=0)
# 定义分回归器,拟合预测数据
clf = SGDRegressor(max_iter=1000, tol=1e-3)
clf.fit(train_data, train_target)
score_train = mean_squared_error(train_target, clf.predict(train_data))
score_test = mean_squared_error(test_target, clf.predict(test_data))
print("SGDRegressor train MSE: ", score_train)
print("SGDRegressor test MSE: ", score_test, "\n")
2. K 折交叉验证代码
- 使用 K 折交叉验证方法对模型进行交叉验证,K=5,代码如下:
## 5折交叉验证
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold
kf = KFold(n_splits=5)
for k, (train_index, test_index) in enumerate(kf.split(train)):
train_data, test_data, train_target, test_target = (
train.values[train_index],
train.values[test_index],
target[train_index],
target[test_index])
clf = SGDRegressor(max_iter=1000, tol=1e-3)
clf.fit(train_data, train_target)
score_train = mean_squared_error(train_target, clf.predict(train_data))
score_test = mean_squared_error(test_target, clf.predict(test_data))
print(k, "折", "SGDRegressor train MSE: ", score_train
print(k, "折", "SGDRegressor test MSE: ", score_test, '\n')
3. 留一法交叉验证代码
- 使用留一法交叉验证对模型进行交叉验证,代码如下:
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import LeaveOneOut
loo = LeaveOneOut()
# num = 100
for k, (train_index, test_index) in enumerate(loo.split(train)):
train_data, test_data, train_target, test_target = train.values[train_index], train.values[test_index], target[train_index], target[test_index]
clf = SGDRegressor(max_iter=1000, tol=1e-3)
clf.fit(train_data, train_target)
score_train = mean_squared_error(train_target, clf.predict(train_data))
score_test = mean_squared_error(test_target, clf.predict(test_data))
print(k, "个", "SGDRegressor train MSE: ", score_train)
print(k, "个", "SGDRegressor test MSE: ", score_test)
if k >= 9:
break
4. 留 P 法交叉验证代码
- 使用留 P 法交叉验证对模型进行交叉验证,代码如下:
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import LeavePOut
lpo = LeavePOut(p=10)
num = 100
for k, (train_index, test_index) in enumerate(lpo.split(train)):
train_data, test_data, train_target, test_target = train.values[train_index], train.values[test_index], target[train_index], target[test_index]
clf = SGDRegressor(max_iter=1000, tol=1e-3)
clf.fit(train_data, train_target)
score_train = mean_squared_error(train_target, clf.predict(train_data))
score_test = mean_squared_error(test_target, clf.predict(test_data))
print(k, "10个", "SGDRegressor train MSE: ", score_train)
print(k, "10个", "SGDRegressor test MSE: ", score_test, "\n")