本文以mnist以及lenet为例
1.将测试错误样本打印出来
当运行测试时,最后的输出层为AccuracyLayer层。AccuracyLayer对前一层全连接层ip2的10个神经元输出结果进行排序,然后将最大值所对应的神经元序号与标签label进行比较,相等则判定预测正确;否则判定预测错误。所以,首先对accuracy_layer函数进行功能添加,打开src/caffe/layers/accuracy_layer.cpp文件,添加如下代码段:
void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
...
// check if true label is in top k predictions
for (int k = 0; k < top_k_; k++) {
if (bottom_data_vector[k].second == label_value) {
// 预测正确
...
}
else
{
// 预测错误
// index为batch中的图片序号(0~99),label为标签值,output为预测值
LOG(INFO) << "index:" << i << " label:" << label_value << " output:" << bottom_data_vector[k].second;
}
}
}
这样我们就知道在一个batch中哪些图片被预测错误,以及它的标签值和预测值。测试样本总共有10000个,分为100个batch,每个batch大小为100个,所以我们还需要输出每个batch的序号。打开src/caffe/solver.cpp文件,跳转到Slover::Test()函数中,添加如下语句:
void Solver<Dtype>::Test(const int test_net_id) {
...
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
// 输出batch序号
LOG(INFO) << "batch:" << i;
}
}
做完上述改变之后发现运行训练程序时,对caffe进行make all,让它对修改过的层重新编译。
2.将日志输出至文件
编译完成后,设置训练输出日志文件
$./examples/mnist/train_lenet.sh 2>&1 | tee lenet.log
见《深度学习:21天实战caffe》第295页
3.用Matlab将错误样本可视化
下面我们来写段Matlab代码,用来读取上面的日志文件,以及将MNIST数据库可视化。
clear;clc;close all;
fid = fopen('caffe.exe.txt'); % 替换为日志文件名
tline = fgetl(fid);
C = []; % 定义空矩阵用来存放结果
while ischar(tline)
if ~isempty(strfind(tline, 'batch:')) % 查找字符串
indexline = fgetl(fid);
if ~isempty(strfind(indexline, 'batch:'))
tline = indexline;
elseif isempty(strfind(indexline, 'index:'))
tline = indexline;
else
% 在tline中解析batch
idx1 = strfind(tline, 'batch:');
batch = str2num(tline(idx1 + 6 : length(tline)));
% 在indexline中解析index,label,output
idx2 = strfind(indexline, 'index:');
idx3 = strfind(indexline, 'label:');
idx4 = strfind(indexline, 'output:');
index = str2num(indexline(idx2 + 6 : idx3 - 2));
label = str2num(indexline(idx3 + 6 : idx4 - 2));
output = str2num(indexline(idx4 + 7 : length(indexline)));
% 添加到数组中
C = [C; batch, index, label, output];
end
else
tline = fgetl(fid);
end
end
fclose(fid);
% 可视化部分
image_file_name = 't10k-images.idx3-ubyte';
fid = fopen(image_file_name);
images_data = fread(fid, 'uint8');
fclose(fid);
images_data = images_data(17:end);
image_buffer = zeros(28, 28);
for k = 1:1:size(C,1)
figure(size(C,1));
index = C(k,1) * 100 + C(k,2);
image_buffer = reshape(images_data((index) * 28 * 28 + 1 : (index + 1) * 28 * 28), 28, 28);
subplot(10, 10, k);
imshow(uint8(image_buffer)'); % 转置
title(sprintf('%d->%d', C(k,3), C(k,4))); % label -> output
end
4.可视化结果
结果如下所示,其中有些图片网络未能正确识别,还有些对于人眼来说都是模棱两可的,有点太难为机器了。。。