from numpy import genfromtxt
import numpy as np
from sklearn import datasets, linear_model
datapath = r"C:\Users\meachine learn\delivery.csv"
# Load the delivery data set.
# skip_header=2 skips the first two lines of the file (presumably header
# rows -- TODO confirm against the CSV); usecols=(1, 2, 3) keeps only
# columns 1-3, so column 0 is dropped.
deliverydata = genfromtxt(datapath, delimiter=',', skip_header=2, usecols=(1, 2, 3))
print(deliverydata)
# Output:
# [[100.   4.   9.3]
#  [ 50.   3.   4.8]
#  [100.   4.   8.9]
#  [100.   2.   6.5]
#  [ 50.   2.   4.2]
#  [ 80.   2.   6.2]
#  [ 75.   3.   7.4]
#  [ 65.   4.   6. ]
#  [ 90.   3.   7.6]
#  [ 90.   2.   6.1]]
# Features are every column except the last; the last column is the target.
x = deliverydata[:, :-1]
y = deliverydata[:, -1]
regr = linear_model.LinearRegression()
regr.fit(x, y)
# -> LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)
print("coefficient: ", regr.coef_)
# Output: coefficient:  [0.0611346  0.92342537]
print("intercept: ", regr.intercept_)
# Output: intercept:  -0.868701466781709
# Make a prediction with the fitted model.
x_pred = [102, 6]
# predict() expects a 2-D array of shape (n_samples, n_features); with two
# features this turns the flat list into one row per sample.
x_pred = np.array(x_pred).reshape(-1, 2)
y_pred = regr.predict(x_pred)
print(y_pred)
# Output: [10.90757981]
# When the features include a categorical feature, that feature must be
# one-hot encoded before fitting the model.
path = r"C:\Users\meachine learn\deliverydummydone.csv"
data = genfromtxt(path, delimiter=',')
print(data)
# Output (column 2 holds the integer-coded categorical feature):
# [[100.   4.   1.   9.3]
#  [ 50.   3.   0.   4.8]
#  [100.   4.   1.   8.9]
#  [100.   2.   2.   6.5]
#  [ 50.   2.   2.   4.2]
#  [ 80.   2.   1.   6.2]
#  [ 75.   3.   1.   7.4]
#  [ 65.   4.   0.   6. ]
#  [ 90.   3.   0.   7.6]]
# One-hot encode the categorical feature in the third column (index 2).
from sklearn import preprocessing
enc = preprocessing.OneHotEncoder()
# OneHotEncoder expects a 2-D input, so turn the 1-D column into shape (n, 1).
data_coder = data[:, 2].reshape(-1, 1)
# Fit and transform in one step; the encoder returns a sparse matrix by
# default, so densify it before stacking with the numeric columns.
data_2 = enc.fit_transform(data_coder).toarray()
# The first two columns are already a 2-D (n, 2) slice; use them as-is.
data_1 = data[:, 0:2]
x = np.hstack((data_1, data_2))  # concatenate the feature matrices horizontally
print(x)
# Output:
# [[100.   4.   0.   1.   0.]
#  [ 50.   3.   1.   0.   0.]
#  [100.   4.   0.   1.   0.]
#  [100.   2.   0.   0.   1.]
#  [ 50.   2.   0.   0.   1.]
#  [ 80.   2.   0.   1.   0.]
#  [ 75.   3.   0.   1.   0.]
#  [ 65.   4.   1.   0.   0.]
#  [ 90.   3.   1.   0.   0.]]
# The last column of the one-hot data set is the target.
y = data[:, -1]
regr = linear_model.LinearRegression()
# NOTE: every one-hot row in x sums to 1 across its dummy columns, which is
# collinear with the intercept LinearRegression fits by default (dummy-variable
# trap). Predictions are unaffected, but the individual dummy coefficients are
# not uniquely determined; drop one dummy column if you need to interpret them.
regr.fit(x, y)
# -> LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)
print("coefficient: ", regr.coef_)
# Output: coefficient:  [ 0.05553544  0.69257631 -0.17013278  0.57040007 -0.40026729]
print("intercept: ", regr.intercept_)
# Output: intercept:  0.19995688911881349