按照模式分类课本写的代码,如有错误欢迎指正!
main.m
%程序运行可能会需要3-5分钟的时间,请耐心等待。
clear;
%已对lms.mat进行随机打乱,并将Y由标量化为[1,10]矩阵形成data.mat
load data;
%定义测试集和验证集并增加偏置,测试集与验证集比例为4:1
testX = X(1:4000,:);
testY = Y(1:4000,:);
verifyX = X(4001:end,:);
verifyY = Y(4001:end,:);
testX = [ones(4000,1),testX];
%定义权重Wij和权重Wjk,并增加偏置
%输入层有400个单元,隐藏层有25个单元,输出层有10个单元
Wij = (rand(401,25)*2 - 1)*0.1;
Wjk = (rand(26,10)*2 - 1)*0.1;
%设置学习率和epoch,batch为整个数据集
epoch = 1000;
eta = 0.0005;
count = 1;
%变量初始化完成
%前馈与反向传播运算
while(count <= epoch)
%采用批量梯度下降法
%输入层与隐藏层的静激活
netj = testX*Wij;
%激活函数使用sigmoid函数
yj = sigmoid(netj);
%对第二层加入偏置
yj = [ones(4000,1),yj];
%第二层的静激活
netk = yj*Wjk;
%第二层的激活sigmoid函数
zk = sigmoid(netk);
%批量学习算法,计算损失函数
J = (norm(zk-testY)^2)/(2*4000);
%损失可视化用
plot_J(1,count) = J;
%开始反向传播
%计算两层sigmoid激活函数的导数
dnetj = dsigmoid(netj);
dnetk = dsigmoid(netk);
%计算deltaWjk,在批量梯度下降算法中应累加
deltaWjk = zeros(10,26);
for i = 1:4000
%bug:这里计算的是内积——已修复
deltaWjk = deltaWjk + eta*((testY(i,:) - zk(i,:))'.*dnetk(i,:)')*yj(i,:);
end
Wjk = deltaWjk'+Wjk;
%计算deltaWij,在批量梯度下降算法中应累加
deltaWij = zeros(25,401);
for i = 1:4000
deltak = (testY(i,:) - zk(i,:)).*dnetk(i,:);
%去掉偏重,权重减少一维
tmp = zeros(25, 1);
for j = 1:10
tmp = tmp + eta*Wjk(2:end,j)*deltak(j);
end
deltaWij = deltaWij + (tmp.*dnetj(i,:)')*testX(i,:);
end
Wij = deltaWij'+Wij;
count = count+1;
end
pre_res = predict(verifyX,Wij,Wjk);
acc = accuracy(verifyY, pre_res);
fprintf('epoch = %d\n',epoch);
fprintf('learning_rate = %f\n',eta);
fprintf('第一次epoch的cost: %f\n', plot_J(1));
fprintf('最后一次epoch的cost: %f\n',plot_J(end));
fprintf('测试集的正确率为%f\n',acc);
%为了使cost的减少更直观,从第十次开始画
plot(plot_J(10:end));
title('损失函数变化');
xlabel('迭代次数');
ylabel('cost');
predict.m
function [res] = predict2(x,Wij,Wjk)
%采用批量梯度下降法
%对x加入偏置
x = [ones(1000,1), x];
%输入层与隐藏层的静激活
netj = x*Wij;
%激活函数使用ReLU函数
yj = sigmoid(netj);
%对第二层加入偏置
yj = [ones(1000,1),yj];
%第二层的静激活
netk = yj*Wjk;
%第二层的激活Relu函数
res = sigmoid(netk);
end
sigmoid.m
function [y] = sigmoid(x)
y = 1./(1+exp(-x));
end
dsigmoid.m
function [y] = dsigmoid(x)
tem = 1./(1+exp(-x));
y = tem.*(1 - tem);
end
实验结果: