The function MnistConv, which trains the network using the back-propagation algorithm, takes the neural network's weights and training data and returns the trained weights.
[W1, W5, Wo] = MnistConv(W1, W5, Wo, X, D)
where W1, W5, and Wo are the convolution filter matrix, the pooling-hidden layer weight matrix, and the hidden-output layer weight matrix, respectively.
X and D are the input and correct output from the training data, respectively.
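For context, the following sketch shows one plausible way to drive MnistConv over several epochs. The weight initialization and the Images and Labels variables are assumptions for illustration, not part of the listing below; Images is taken to be a 28x28xN array of MNIST digits and Labels an Nx1 vector of class indices from 1 to 10.

W1 = 1e-2*randn([9 9 20]);           % 20 convolution filters, 9x9 each
W5 = (2*rand(100, 2000) - 1) * sqrt(6) / sqrt(360 + 2000);
Wo = (2*rand( 10,  100) - 1) * sqrt(6) / sqrt( 10 +  100);

for epoch = 1:3                      % MnistConv performs one epoch per call
  [W1, W5, Wo] = MnistConv(W1, W5, Wo, Images, Labels);
end

Each call runs a full pass through the training data and returns the improved weights, so repeated calls continue training from where the previous epoch left off.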
The following listing shows the MnistConv.m file, which implements the MnistConv function.
function [W1, W5, Wo] = MnistConv(W1, W5, Wo, X, D)

alpha = 0.01;                        % learning rate
beta  = 0.95;                        % momentum constant

momentum1 = zeros(size(W1));
momentum5 = zeros(size(W5));
momentumo = zeros(size(Wo));

N = length(D);

bsize = 100;                         % minibatch size
blist = 1:bsize:(N-bsize+1);         % starting index of each minibatch

% One epoch loop
for batch = 1:length(blist)
  dW1 = zeros(size(W1));
  dW5 = zeros(size(W5));
  dWo = zeros(size(Wo));

  % Mini-batch loop
  begin = blist(batch);
  for k = begin:begin+bsize-1
    % Forward pass = inference
    x  = X(:, :, k);                 % Input,           28x28
    y1 = Conv(x, W1);                % Convolution,  20x20x20
    y2 = ReLU(y1);                   %
    y3 = Pool(y2);                   % Pooling,      10x10x20
    y4 = reshape(y3, [], 1);         %                   2000
    v5 = W5*y4;                      % ReLU,              360
    y5 = ReLU(v5);                   %
    v  = Wo*y5;                      % Softmax,            10
    y  = Softmax(v);                 %

    % One-hot encoding of the correct output
    d = zeros(10, 1);
    d(sub2ind(size(d), D(k), 1)) = 1;

    % Backpropagation
    e      = d - y;                  % Output layer
    delta  = e;

    e5     = Wo' * delta;            % Hidden (ReLU) layer
    delta5 = (y5 > 0) .* e5;

    e4     = W5' * delta5;           % Pooling layer

    e3     = reshape(e4, size(y3));

    e2 = zeros(size(y2));
    W3 = ones(size(y2)) / (2*2);     % spread error evenly over each 2x2 block
    for c = 1:20
      e2(:, :, c) = kron(e3(:, :, c), ones([2 2])) .* W3(:, :, c);
    end

    delta2 = (y2 > 0) .* e2;         % ReLU layer

    delta1_x = zeros(size(W1));      % Convolutional layer
    for c = 1:20
      delta1_x(:, :, c) = conv2(x(:, :), rot90(delta2(:, :, c), 2), 'valid');
    end

    dW1 = dW1 + delta1_x;
    dW5 = dW5 + delta5*y4';
    dWo = dWo + delta *y5';
  end

  % Update weights with the averaged gradients and momentum
  dW1 = dW1 / bsize;
  dW5 = dW5 / bsize;
  dWo = dWo / bsize;

  momentum1 = alpha*dW1 + beta*momentum1;
  W1        = W1 + momentum1;

  momentum5 = alpha*dW5 + beta*momentum5;
  W5        = W5 + momentum5;

  momentumo = alpha*dWo + beta*momentumo;
  Wo        = Wo + momentumo;
end

end
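Note that the weight update at the bottom of the listing uses momentum: each averaged gradient is folded into a running term, momentum = alpha*dW + beta*momentum, and the weights move by that term, W = W + momentum. With alpha = 0.01 and beta = 0.95, a fraction of every previous update keeps contributing, which smooths the descent across minibatches.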
This code appears to be rather more complex than the previous examples.
Let's take a look at it part by part.
The function MnistConv trains the network via the minibatch method, while the previous examples employed the SGD and batch methods.
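To see the distinction in isolation, here is a minimal sketch of a single minibatch step for one generic weight matrix; GradOf is a hypothetical per-sample gradient function introduced only for illustration, not something defined in the book:

function W = MinibatchStep(W, X, D, begin, bsize, alpha)
  dW = zeros(size(W));
  for k = begin:begin + bsize - 1
    dW = dW + GradOf(W, X(:, :, k), D(k));  % accumulate per-sample gradients
  end
  W = W + alpha*(dW / bsize);               % one update per averaged gradient
end

SGD would update after every sample (bsize = 1), and the batch method would average over the entire training set before a single update; the minibatch approach in MnistConv sits between the two, updating once per 100 samples.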
Translated from MATLAB Deep Learning by Phil Kim.