"""k-nearest-neighbour (kNN) classifier for the dating data set.

``datingTestSet2.txt`` holds one sample per line: three tab-separated numeric
features followed by an integer class label (``persontest`` maps labels 1-3 to
its answers).  The script min-max normalises the features, classifies a sample
by majority vote among its ``k`` nearest neighbours (Euclidean distance), and
can brute-force per-feature weights (``train``) or classify values typed in
interactively (``persontest``).
"""
from numpy import *  # the code below uses bare numpy names (array, zeros, where, ...)
import itertools
import operator

try:
    import matplotlib.pyplot as plt
except ImportError:  # only matlab_analys() plots; the classifier works without it
    plt = None


def parse_file(filename='datingTestSet2.txt'):
    """Load the data file.

    Returns ``(features, labels)``: an ``(n, 3)`` float array built from the
    first three tab-separated columns, and a list of ``int`` labels taken from
    the last column.  Blank lines (e.g. an empty last line) are skipped
    instead of crashing on ``int('')``.
    """
    with open(filename) as fr:  # the file is closed even if parsing fails
        rows = [line.strip().split('\t') for line in fr if line.strip()]
    org_array = zeros((len(rows), 3))
    org_label = []
    for index, fields in enumerate(rows):
        org_array[index] = [float(value) for value in fields[0:3]]
        org_label.append(int(fields[-1]))
    return org_array, org_label


def matlab_analys(array_matlab, org_label, x_col=0, y_col=1):
    """Scatter-plot two feature columns, with marker size and colour scaled by label.

    Raises ImportError if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError('matplotlib is required for matlab_analys')
    scaled_labels = 5 * array(org_label)
    paint = plt.figure()
    ax = paint.add_subplot(111)  # ax is the first (only) subplot of a 1x1 grid
    # Positional args are (x, y, s=marker size, c=colour).
    ax.scatter(array_matlab[:, x_col], array_matlab[:, y_col], scaled_labels, scaled_labels)
    plt.show()


def normal(array_normal):
    """Min-max normalise each column of a 2-D array to the range [0, 1].

    Returns ``(normalised, ranges, minimums)`` so new samples can be scaled
    the same way: ``(x - minimums) / ranges``.  A constant column has zero
    range; its range is reported as 1 (normalised values 0) instead of
    producing NaN from 0 / 0.
    """
    min_value = array_normal.min(0)
    max_value = array_normal.max(0)
    d_value = max_value - min_value
    d_value = where(d_value == 0, 1, d_value)
    # Broadcasting applies the per-column min/range to every row.
    nol_array = (array_normal - min_value) / d_value
    return nol_array, d_value, min_value


def add_weight(array_add, i1, i2, i3):
    """Scale the three feature columns by weights i1, i2, i3 normalised to sum to 1.

    Raises ZeroDivisionError if the three weights sum to zero.
    """
    total = i1 + i2 + i3
    weights = array([i1 / total, i2 / total, i3 / total])
    return array_add * weights  # broadcast over every row


def distance(array_measured_tar_dist, array_dist, label_dist, k):
    """Classify one sample by majority vote among its k nearest rows of array_dist.

    Distances are Euclidean.  ``label_dist[j]`` is the label of
    ``array_dist[j]``.  If ``k`` exceeds the number of reference rows, all rows
    vote.  When several labels tie on votes, the one met first in
    nearest-first order wins (the sort below is stable).

    Raises ValueError if k < 1 or the reference set is empty.
    """
    if k < 1:
        raise ValueError('k must be at least 1')
    if array_dist.shape[0] == 0:
        raise ValueError('the reference set is empty')
    diff = array_measured_tar_dist - array_dist  # broadcast the query against every row
    distances = ((diff ** 2).sum(axis=1)) ** 0.5
    nearest_first = distances.argsort()
    class_count = {}
    for idx in nearest_first[:k]:  # slicing clamps k to the number of rows
        nor_label = label_dist[idx]
        class_count[nor_label] = class_count.get(nor_label, 0) + 1
    sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]


def test(array_test, label_test, rates=0.5, k=3):
    """Hold-out evaluation: classify the first ``rates`` fraction of rows against the rest.

    Prints the number of test rows, the number of errors and the error rate,
    and returns the error rate.

    Raises ValueError if that fraction contains no rows.
    """
    lens_test = array_test.shape[0]
    mid_lens_test = int(rates * lens_test)
    if mid_lens_test == 0:
        raise ValueError('test() needs at least one query row; increase rates or the data size')
    ref_array = array_test[mid_lens_test:lens_test]
    ref_labels = label_test[mid_lens_test:lens_test]
    error_count = 0.0
    for i in range(mid_lens_test):
        result_test = distance(array_test[i, :], ref_array, ref_labels, k)
        # print('The test result is %d,and the truth is %d' % (result_test, label_test[i]))
        if result_test != label_test[i]:
            error_count += 1
    error_rate = error_count / mid_lens_test
    print('the test num is %d error num is %d error rate is %f' % (mid_lens_test, error_count, error_rate))
    return error_rate


test.__test__ = False  # this helper is named "test"; stop pytest from collecting it as a test case


def mytest(array_mytest, label_mytest, k=3, max_errors=22, exclude_self=False):
    """Classify every row against the data set and report whether errors < max_errors.

    By default (``exclude_self=False``) each row is classified against a
    reference set that still contains the row itself.  It is therefore always
    its own nearest neighbour, which makes the error count optimistic.  The
    default threshold of 22 was chosen with that behaviour, so it is kept.
    Pass ``exclude_self=True`` for a true leave-one-out evaluation.

    Prints a summary and returns True when the error count is below
    ``max_errors``; otherwise returns False silently.

    Raises ValueError on an empty data set.
    """
    lens_test = array_mytest.shape[0]
    if lens_test == 0:
        raise ValueError('mytest needs at least one sample')
    labels = asarray(label_mytest)
    error_count = 0.0
    for i in range(lens_test):
        if exclude_self:
            ref_array = delete(array_mytest, i, axis=0)
            ref_labels = delete(labels, i)
        else:
            ref_array, ref_labels = array_mytest, labels
        result_test = distance(array_mytest[i], ref_array, ref_labels, k)
        if result_test != labels[i]:
            error_count += 1
    error_rate = error_count / lens_test
    if error_count < max_errors:
        print('the test num is %d error num is %d error rate is %f' % (lens_test, error_count, error_rate))
        return True
    return False


def persontest(min_value=None, d_value=None, normal_array=None, labels=None, k=3):
    """Read one person's three feature values from stdin and print the predicted answer.

    Omitted arguments fall back to the globals created by the ``__main__``
    section (``mix_value_main``, ``d_value_main``, ``array_normal_main``,
    ``label_org_main``), so ``persontest()`` behaves as before.  Pass them
    explicitly to use this function from another module.
    """
    if min_value is None:
        min_value = mix_value_main
    if d_value is None:
        d_value = d_value_main
    if normal_array is None:
        normal_array = array_normal_main
    if labels is None:
        labels = label_org_main
    result_list = ['not at all', 'in small doses', 'in large doses']
    fly_miles = float(input('Number of frequent flyer miles per year?'))
    game_times = float(input('Percentage of game time played ?'))
    ice_cream = float(input('How many ice cream a week do you eat?'))
    # Scale the raw input exactly like the training data was scaled by normal().
    normal_array_measured_tar_pt = (array([fly_miles, game_times, ice_cream]) - min_value) / d_value
    result_pt = distance(normal_array_measured_tar_pt, normal_array, labels, k)
    print('you will probably like this person : ', result_list[result_pt - 1])


def train(normal_array_tr, label_tr, weight_range=range(1, 11)):
    """Brute-force every (i1, i2, i3) weight triple and print those that pass mytest().

    Triples are tried in nested-loop order (i3 varies fastest).  Proportional
    triples such as (1, 1, 1) and (2, 2, 2) produce identical weighted data
    after add_weight() normalises them, so they print together.
    """
    for i1, i2, i3 in itertools.product(weight_range, repeat=3):
        added_array = add_weight(normal_array_tr, i1, i2, i3)
        if mytest(added_array, label_tr):
            print('i1 = %d i2 = %d i3 = %d' % (i1, i2, i3))


if __name__ == '__main__':
    array_org_main, label_org_main = parse_file()
    # mix_value_main holds the per-column minimum; persontest() reads it under this name.
    array_normal_main, d_value_main, mix_value_main = normal(array_org_main)
    # matlab_analys(array_normal_main, label_org_main)
    # array_added_main = add_weight(array_normal_main, i1, i2, i3)  # weights e.g. as printed by train()
    # matlab_analys(array_added_main, label_org_main)
    # measured_tar_main = [11145, 3.410627, 0.631838]
    # array_measured_tar_main = array(measured_tar_main)
    # normal_array_measured_tar_main = (array_measured_tar_main - mix_value_main) / d_value_main
    # result = distance(normal_array_measured_tar_main, array_normal_main, label_org_main, 3)
    # print(result)
    # test(array_normal_main, label_org_main)
    # mytest(array_normal_main, label_org_main)
    # persontest()
    train(array_normal_main, label_org_main)
以上是全部代码
结合前面的文章使用