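Lately I couldn't get through papers, so I've been working through some ML courses instead. I'm currently on CS231n, and this post records what I completed, what I learned, and the pitfalls I hit.
Code: https://github.com/lightaime/cs231n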
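This assignment is mainly about implementing a KNN (k-nearest neighbors) classifier. The idea: to classify a test image, compute its distance to every training image, find the K training images closest to it, collect their labels, and let them vote. The label with the most votes becomes the predicted label of the test image.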
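As a minimal sketch of this idea (not the assignment's code), here is KNN for a single query point in NumPy. The function name `knn_predict_one` and the two-cluster toy data are made up for illustration.

```python
import numpy as np


def knn_predict_one(x, X_train, y_train, k):
    """Classify one sample x by majority vote among its k nearest training samples."""
    # Euclidean (L2) distance from x to every training sample
    dists = np.sqrt(np.sum((X_train - x) ** 2, axis=1))
    # Indices of the k smallest distances
    nearest = np.argsort(dists)[:k]
    # Vote: count each label among the neighbors and take the most frequent one
    votes = np.bincount(y_train[nearest])
    return int(np.argmax(votes))


# Toy 2-D data: two clusters labeled 0 and 1 (made-up values)
X_train = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
                    [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]])
y_train = np.array([0, 0, 0, 1, 1, 1])

print(knn_predict_one(np.array([0.5, 0.5]), X_train, y_train, k=3))  # 0
print(knn_predict_one(np.array([5.5, 5.5]), X_train, y_train, k=3))  # 1
```

The full classifier does the same thing for many test images at once; the expensive part is the distance computation, which is why the assignment asks for three versions of it.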
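Although code already exists, I still wrote my own version modeled on it, because many things only really sink in once you implement them yourself. Sure enough, I ran into some pitfalls, so never assume something is easy before you have actually done it.
The code, explained:
My own data-loading code: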
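```python
import os
import pickle

import numpy as np


def load_file(file, cifar10_dir):
    """Unpickle one CIFAR-10 batch file; the returned dict has bytes keys (b'data', b'labels')."""
    path = os.path.join(cifar10_dir, file)  # portable, instead of hard-coding '\\'
    with open(path, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch


def load_CIFAR10(cifar10_dir):
    # Pre-allocated float64 arrays: assigning the uint8 pixel data into them converts it to float
    X_train = np.zeros((50000, 3072))
    Y_train = np.zeros(50000, dtype=int)
    X_test = np.zeros((10000, 3072))
    Y_test = np.zeros(10000, dtype=int)
    # Load data_batch_1 ... data_batch_5 by name, so the other files in the folder
    # (batches.meta, readme.html) are never unpickled as training data
    for index in range(5):
        data = load_file('data_batch_%d' % (index + 1), cifar10_dir)
        X_train[index*10000:(index+1)*10000] = data[b'data']
        Y_train[index*10000:(index+1)*10000] = data[b'labels']
    # Big pitfall here: writing X_test = data[b'data'] directly keeps NumPy's uint8 dtype,
    # and the later subtraction/squaring then wraps around and produces wrong distances.
    # Python never spells out types the way C++ does, so keep a close eye on dtypes.
    data = load_file('test_batch', cifar10_dir)
    X_test[:] = data[b'data']
    Y_test[:] = data[b'labels']
    return X_train, Y_train, X_test, Y_test
```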
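Why the pre-allocated float arrays matter: CIFAR-10 stores pixels as uint8, and NumPy array arithmetic on uint8 wraps modulo 256 without raising an error. A small sketch with two made-up pixel rows:

```python
import numpy as np

# Two made-up pixel rows stored as uint8, like the raw CIFAR-10 data
a = np.array([10, 200], dtype=np.uint8)
b = np.array([20, 100], dtype=np.uint8)

# uint8: 10 - 20 wraps to 246, and squaring wraps again
diff_u8 = a - b
sq_u8 = np.square(diff_u8)          # values 100 and 16 instead of 100 and 10000

# Assigning into a pre-allocated float64 array (as load_CIFAR10 does) converts the values
X = np.zeros((2, 2))
X[0] = a
X[1] = b
diff_f = X[0] - X[1]                # -10.0 and 100.0
sq_f = np.square(diff_f)            # 100.0 and 10000.0

print(np.sum(sq_u8), np.sum(sq_f))  # 116 10100.0
```

An explicit `data[b'data'].astype(np.float64)` would achieve the same conversion as assigning into a float array.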
The KNN classifier:
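```python
import numpy as np


class KNearestNeighbor(object):
    def __init__(self):
        pass

    def train(self, X, y):
        """
        Train the classifier. For KNN this just memorizes the training data.
        :param X: training images, shape (num_train, D)
        :param y: labels of the training images, shape (num_train,)
        :return: None
        """
        self.X_train = X
        self.Y_train = y

    def predict(self, X, k=1, num_loops=0):
        """
        Predict labels for the test data with this classifier.
        :param X: test images, shape (num_test, D)
        :param k: number of nearest neighbors that vote
        :param num_loops: which implementation computes the distances (0, 1 or 2 loops)
        :return: predicted labels, shape (num_test,)
        """
        if num_loops == 0:
            dists = self.compute_distances_no_loops(X)
        elif num_loops == 1:
            dists = self.compute_distances_one_loop(X)
        elif num_loops == 2:
            dists = self.compute_distances_two_loops(X)
        else:
            raise ValueError('Invalid value %d for num_loops' % num_loops)
        return self.predict_labels(dists, k=k)

    def compute_distances_two_loops(self, X):
        """
        Compute the distance between each test sample in X and each sample in X_train.
        :param X: test data
        :return: distance matrix, shape (num_test, num_train)
        """
        num_test = X.shape[0]
        num_train = self.X_train.shape[0]
        dists = np.zeros((num_test, num_train))
        for i in range(num_test):
            for j in range(num_train):
                dists[i][j] = np.sqrt(np.sum(np.square(X[i] - self.X_train[j])))
        return dists

    def compute_distances_one_loop(self, X):
        num_test = X.shape[0]
        num_train = self.X_train.shape[0]
        dists = np.zeros((num_test, num_train))
        for i in range(num_test):
            # Broadcasting subtracts X[i] from every training row at once;
            # take the sqrt so the result matches the two-loop version
            dists[i] = np.sqrt(np.sum(np.square(self.X_train - X[i]), axis=1))
        return dists

    # The loop-free version seems to be the classic ML trick: use a matrix product (mark).
    # Expanding the square gives ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, and the middle
    # term for all pairs at once is exactly X multiplied by the transpose of X_train.
    def compute_distances_no_loops(self, X):
        dists = -2 * np.dot(X, self.X_train.T)              # (num_test, num_train)
        sq1 = np.sum(np.square(X), axis=1)[:, np.newaxis]   # (num_test, 1)
        sq2 = np.sum(np.square(self.X_train), axis=1)       # (num_train,)
        dists = dists + sq1 + sq2                           # broadcasts to (num_test, num_train)
        # Floating-point rounding can leave tiny negative values; clip them before sqrt
        return np.sqrt(np.maximum(dists, 0))

    def predict_labels(self, dists, k=1):
        """
        Predict a label for each test sample from the distance matrix and the training labels.
        :param dists: distance matrix, shape (num_test, num_train)
        :param k: number of nearest neighbors that vote
        :return: predicted labels, shape (num_test,)
        """
        num_test = dists.shape[0]
        y_pred = np.zeros(num_test)
        for i in range(num_test):
            # Labels of the k training samples with the smallest distances
            closest_y = self.Y_train[np.argsort(dists[i])[:k]].astype(int)
            # bincount counts each label; argmax picks the most frequent (smallest label on ties)
            y_pred[i] = np.argmax(np.bincount(closest_y))
        return y_pred
```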
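To convince yourself that the matrix-product expansion ||x - y||^2 = ||x||^2 - 2 x·y + ||y||^2 gives the same result as the explicit loops, here is a standalone sketch on random data. The function names `l2_two_loops` and `l2_no_loops` are hypothetical, chosen only for this check.

```python
import numpy as np


def l2_two_loops(X, Y):
    """Reference: pairwise Euclidean distances with explicit Python loops."""
    d = np.zeros((X.shape[0], Y.shape[0]))
    for i in range(X.shape[0]):
        for j in range(Y.shape[0]):
            d[i, j] = np.sqrt(np.sum((X[i] - Y[j]) ** 2))
    return d


def l2_no_loops(X, Y):
    """Pairwise distances via ||x||^2 - 2 x.y + ||y||^2, with no Python loops."""
    x_sq = np.sum(X ** 2, axis=1)[:, np.newaxis]  # shape (num_test, 1)
    y_sq = np.sum(Y ** 2, axis=1)[np.newaxis, :]  # shape (1, num_train)
    cross = X @ Y.T                               # shape (num_test, num_train)
    sq = x_sq - 2 * cross + y_sq                  # broadcasting fills the full matrix
    return np.sqrt(np.maximum(sq, 0))             # clip rounding negatives before sqrt


rng = np.random.default_rng(0)
X_test = rng.random((4, 6))   # 4 random "test" rows with 6 features
X_train = rng.random((7, 6))  # 7 random "training" rows
print(np.allclose(l2_two_loops(X_test, X_train), l2_no_loops(X_test, X_train)))  # True
```

The one-line sum `x_sq - 2 * cross + y_sq` is where broadcasting does the work: a (num_test, 1) column plus a (1, num_train) row expands to the full (num_test, num_train) matrix.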
Cross-validation
Using 5 folds:
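```python
import numpy as np

# Assumes X_train / Y_train have been subsampled to 5000 training images, as in the assignment notebook,
# and that KNearestNeighbor is the class defined above
num_folds = 5
k_choices = [1, 3, 5, 8, 10, 12, 15, 20, 50, 100]

# Split the data into 5 parts: a list of 5 arrays, each of shape (1000, 3072)
x_train_folds = np.array_split(X_train, num_folds)
y_train_folds = np.array_split(Y_train, num_folds)

k_to_accuracies = {}
classifier = KNearestNeighbor()
for k in k_choices:
    accuracies = np.zeros(num_folds)
    for fold in range(num_folds):
        temp_X = x_train_folds[:]
        temp_y = y_train_folds[:]
        # Hold out one fold as validation data
        x_validate_fold = temp_X.pop(fold)
        y_validate_fold = temp_y.pop(fold)
        # Merge the remaining 4 folds (4 arrays of shape (1000, 3072)) into one (4000, 3072)
        # training set by iterating over every row of every fold (mark)
        temp_X = np.array([row for fold_data in temp_X for row in fold_data])
        temp_y = np.array([label for fold_data in temp_y for label in fold_data])
        classifier.train(temp_X, temp_y)
        y_test_pred = classifier.predict(x_validate_fold, k=k)
        num_correct = np.sum(y_test_pred == y_validate_fold)
        # Accuracy is the fraction of the held-out fold classified correctly
        accuracy = float(num_correct) / len(y_validate_fold)
        accuracies[fold] = accuracy
    k_to_accuracies[k] = accuracies
```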
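The row-by-row list comprehension that merges the training folds can also be written as a single `np.concatenate` call, which avoids building a Python list of rows. A small sketch on a made-up 10 x 3 array standing in for the training set:

```python
import numpy as np

# Toy stand-in for the training set: 10 rows, 3 features (arbitrary values)
X = np.arange(30).reshape(10, 3)
folds = np.array_split(X, 5)        # a list of 5 arrays, each of shape (2, 3)

held_out = folds.pop(1)             # hold out fold 1 for validation
# Row-by-row flattening, as in the cross-validation loop
flat_loop = np.array([row for fold in folds for row in fold])
# The same result with one NumPy call
flat_concat = np.concatenate(folds)

print(flat_concat.shape)                        # (8, 3)
print(np.array_equal(flat_loop, flat_concat))   # True
```

`np.vstack(folds)` would work equally well here, since every fold is 2-D.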
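Takeaways:
1. Watch out for Python/NumPy data types: they can silently corrupt results, e.g. uint8 values wrapping around after arithmetic overflows their range.
2. Learned to parallelize computation with matrix operations (vectorization), very important.
3. Learned some APIs, e.g. argsort (returns the indices that would sort the array) and bincount (for an array of non-negative integers, counts how many times each value occurs). The Python (NumPy) APIs really are convenient.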
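A quick illustration of how `argsort` and `bincount` combine into the KNN vote, using made-up distances and labels. Note the tie case: `np.argmax` returns the first maximum, so on a tie the smallest label wins.

```python
import numpy as np

dists = np.array([0.9, 0.2, 0.5, 0.1])  # made-up distances to 4 training samples
labels = np.array([2, 0, 2, 1])         # made-up labels of those samples

order = np.argsort(dists)               # indices that sort dists ascending
print(order)                            # [3 1 2 0]

# k = 3: neighbor labels are 1, 0, 2, each counted once, so it is a tie
vote_k3 = np.argmax(np.bincount(labels[order[:3]]))
# k = 4: neighbor labels are 1, 0, 2, 2, so label 2 wins
vote_k4 = np.argmax(np.bincount(labels[order[:4]]))
print(vote_k3, vote_k4)                 # 0 2
```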