Single-variable linear regression problem:
ex1.m
% ex1.m -- single-variable linear regression fitted by batch gradient descent.
% Loads ex1data1.txt, plots column 1 ('populations') against column 2
% ('profits'), fits theta with gradientDescent, then plots the fitted line,
% the cost history and the cost surface/contour around the fitted theta.
clear;
clc;
close all;

% Input data, variables and hyper-parameters
data = load('ex1data1.txt');
X = data(:, 1);
y = data(:, 2);
m = length(y);
alpha = 0.01;
theta = [0; 0];
iterations = 1500;

%---------- plot data ----------
plot(X, y, 'rx', 'MarkerSize', 10);
xlabel('populations');
ylabel('profits');

%---------- gradient descent ----------
[theta, J_history] = gradientDescent(X, y, theta, iterations, alpha);
x = [ones(m, 1), X];    % design matrix with intercept column, for the fit line
hold on;
plot(X, x * theta);
legend('Training data', 'Linear regression');
hold off;

% Cost after each step; the x-range follows the actual iteration count
% instead of a separately hard-coded 1500.
figure;
plot(1:iterations, J_history, 'b');
legend('gradientDescent');

% Cost surface and contour plot, with the fitted theta marked
J_vals = visualize(X, y, iterations, theta);
cost.m
function J = cost(X, y, theta)
% COST  Squared-error cost for linear regression.
%   J = cost(X, y, theta) returns
%       J = 1/(2m) * sum_i ([1, X(i,:)] * theta - y(i))^2,   m = length(y).
%   X     - m-by-n feature matrix WITHOUT the intercept column; a column of
%           ones is prepended here.
%   y     - vector of m targets (row or column).
%   theta - (n+1)-by-1 parameter vector, intercept first.
%   With m = 0 the result is NaN (0/0), same as the original loop form.
    m = length(y);
    X = [ones(m, 1), X];
    % y(:) forces a column so a row-vector y cannot broadcast against the
    % m-by-1 prediction into an m-by-m matrix.
    residual = X * theta - y(:);
    % No local named 'sum' here, so the builtin stays usable.
    J = sum(residual .^ 2) / (2 * m);
end
gradientDescent.m
function [theta, J_vals] = gradientDescent(X, y, theta, iterations, alpha)
% GRADIENTDESCENT  Fit linear-regression parameters by batch gradient descent.
%   [theta, J_vals] = gradientDescent(X, y, theta, iterations, alpha)
%   X          - m-by-n feature matrix WITHOUT the intercept column; a
%                column of ones is prepended here.
%   y          - vector of m targets (row or column).
%   theta      - (n+1)-by-1 initial parameters, intercept first. Any n >= 1
%                is supported (the old loop only ever updated 2 entries).
%   iterations - number of gradient-descent steps.
%   alpha      - learning rate.
%   Returns the final theta and J_vals (iterations-by-1), where J_vals(k) is
%   the cost 1/(2m) * sum((X*theta - y).^2) after step k.
%   Prints the cost after every step, then the final theta and cost.
    m = length(y);
    y = y(:);
    X = [ones(m, 1), X];
    J_vals = zeros(iterations, 1);
    residual = X * theta - y;
    for iter = 1:iterations
        % Simultaneous update of all parameters: grad = X' * residual / m.
        theta = theta - alpha / m * (X' * residual);
        % This residual serves both as this step's cost and the next step's gradient.
        residual = X * theta - y;
        J_vals(iter) = sum(residual .^ 2) / (2 * m);
        fprintf('-----%f-----\n', J_vals(iter));
    end
    fprintf('\n');
    fprintf('我们得到theta为:\n');
    fprintf('%f\n', theta);
    % With zero iterations there is no final cost to report (J_vals is empty).
    if iterations > 0
        fprintf('对应代价:\n');
        fprintf('%f\n', J_vals(iterations));
    end
end
visualize.m
function J_vals = visualize(X, y, iterations, theta)
% VISUALIZE  Plot the cost J(theta0, theta1) as a surface and a contour map.
%   J_vals = visualize(X, y, iterations, theta) evaluates cost(X, y, t) on a
%   100-by-100 grid, theta0 in [-10, 10] and theta1 in [-1, 4], draws the
%   surface in a new figure, then a contour plot in another new figure with
%   theta marked by a red 'x'.
%   X          - feature data, passed unchanged to cost (no intercept column).
%   y          - target vector, passed unchanged to cost.
%   iterations - not used; kept so existing callers keep working.
%   theta      - parameters to mark; theta(1) and theta(2) are plotted.
%   Returns J_vals transposed, i.e. J_vals(j, i) = cost at
%   [theta0_vals(i); theta1_vals(j)], the orientation surf/contour expect.
    % linspace(a, b) defaults to 100 points.
    % A wider alternative grid: linspace(-10, 10, 100), linspace(-2, 5, 100).
    theta0_vals = linspace(-10, 10);
    theta1_vals = linspace(-1, 4);
    J_vals = zeros(length(theta0_vals), length(theta1_vals));
    for i = 1:length(theta0_vals)
        for j = 1:length(theta1_vals)
            J_vals(i, j) = cost(X, y, [theta0_vals(i); theta1_vals(j)]);
        end
    end
    % surf/contour with vector x/y expect Z(row = y index, col = x index),
    % so transpose to put theta1 along the rows.
    J_vals = J_vals';

    %------------- surface plot -------------
    figure;
    surf(theta0_vals, theta1_vals, J_vals);
    xlabel('theta0');
    ylabel('theta1');
    zlabel('J');

    %------------- contour plot -------------
    % 20 logarithmically spaced levels from 1e-2 to 1e3.
    figure;
    contour(theta0_vals, theta1_vals, J_vals, logspace(-2, 3, 20));
    xlabel('\theta_0');
    ylabel('\theta_1');
    hold on;
    plot(theta(1), theta(2), 'rx');
    % Release the hold so later plotting into these axes replaces rather
    % than overlays.
    hold off;
end
对应数据集ex1data1.txt