1. Import the library and load the data
import graphlab
# Limit number of worker processes. This preserves system memory, which prevents hosted notebooks from crashing.
graphlab.set_runtime_config('GRAPHLAB_DEFAULT_NUM_PYLAMBDA_WORKERS', 4)
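# Load the training and test sets (stored as SFrames on disk)
image_train = graphlab.SFrame('image_train_data/')
image_test = graphlab.SFrame('image_test_data/')
# Show Canvas visualizations in the browser, then browse the training images
graphlab.canvas.set_target('browser')
image_train['image'].show()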
2. Train a logistic regression classifier on raw pixel values
# Use each image's raw pixel array as its feature vector
raw_pixel_model = graphlab.logistic_classifier.create(image_train, target='label',
                                                      features=['image_array'])
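The sketch below gives a rough picture of what a logistic classifier computes when it has more than two labels. It is a minimal multinomial (softmax) logistic regression trained by plain batch gradient descent. It is not GraphLab's solver. The toy 4-pixel "images" and all function names are hypothetical. With raw pixels, each class simply learns one weight per pixel position.

```python
import math

def softmax(scores):
    # Subtract the max score for numerical stability before exponentiating
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def train_softmax(X, y, num_classes, lr=0.1, epochs=200):
    """Fit one weight row per class (last entry = bias) by batch gradient
    descent on the average cross-entropy loss."""
    dim = len(X[0])
    W = [[0.0] * (dim + 1) for _ in range(num_classes)]
    n = len(X)
    for _ in range(epochs):
        grad = [[0.0] * (dim + 1) for _ in range(num_classes)]
        for xi, yi in zip(X, y):
            xb = list(xi) + [1.0]  # append a constant 1 for the bias term
            scores = [sum(w * v for w, v in zip(W[k], xb)) for k in range(num_classes)]
            probs = softmax(scores)
            for k in range(num_classes):
                # d(loss)/d(score_k) = p_k - 1[y == k]
                err = probs[k] - (1.0 if yi == k else 0.0)
                for j in range(dim + 1):
                    grad[k][j] += err * xb[j]
        for k in range(num_classes):
            for j in range(dim + 1):
                W[k][j] -= lr * grad[k][j] / n
    return W

def predict(W, x):
    # Pick the class with the highest linear score
    xb = list(x) + [1.0]
    scores = [sum(w * v for w, v in zip(row, xb)) for row in W]
    return max(range(len(W)), key=lambda k: scores[k])

# Hypothetical 4-pixel images: class 0 is bright on the left, class 1 on the right
X = [[1.0, 0.9, 0.1, 0.0], [0.8, 1.0, 0.0, 0.2],
     [0.0, 0.1, 0.9, 1.0], [0.2, 0.0, 1.0, 0.8]]
y = [0, 0, 1, 1]
W = train_softmax(X, y, num_classes=2)
```

Every weight is tied to a fixed pixel position. A model like this can therefore only exploit simple, position-specific brightness patterns.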
All of the first 3 test images are misclassified. Evaluating on the full test set:
raw_pixel_model.evaluate(image_test)
Classification accuracy is only 46%.
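Accuracy is the fraction of test images whose predicted label matches the true label. A 46% accuracy therefore means the model gets roughly 46 of every 100 test images right. Below is a minimal sketch of that computation. The function name and labels are illustrative only.

```python
def accuracy(targets, predictions):
    """Fraction of examples whose predicted label equals the true label."""
    if len(targets) != len(predictions):
        raise ValueError("targets and predictions must have the same length")
    if not targets:
        raise ValueError("cannot compute accuracy of an empty set")
    correct = sum(1 for t, p in zip(targets, predictions) if t == p)
    return correct / len(targets)

# Hypothetical labels: 2 of the 4 predictions are correct
print(accuracy(['cat', 'dog', 'bird', 'cat'], ['dog', 'dog', 'bird', 'bird']))  # 0.5
```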
3. Train a logistic regression classifier on deep features (extracted by a model pre-trained on ImageNet)
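# Load a deep neural network pre-trained on ImageNet
deep_learning_model = graphlab.load_model('http://s3.amazonaws.com/GraphLab-Datasets/deeplearning/imagenet_model_iter45')
# Use the network's learned representation of each image as its feature vector
image_train['deep_features'] = deep_learning_model.extract_features(image_train)
# The test set needs the same feature column before predict/evaluate can use it
image_test['deep_features'] = deep_learning_model.extract_features(image_test)
deep_features_model = graphlab.logistic_classifier.create(image_train,
                                                          features=['deep_features'],
                                                          target='label')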
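The structure of this step is a fixed feature extractor followed by a trainable linear classifier. The pre-trained network is only used to map each image to a feature vector. Only the logistic regression on top of those vectors is fit to the labels. The sketch below mirrors that structure under loud assumptions. `FrozenExtractor` is a hypothetical stand-in that uses one random ReLU layer instead of learned ImageNet weights. The classifier is a simple binary logistic regression, and the data are toy values.

```python
import math
import random

class FrozenExtractor:
    """Hypothetical stand-in for a pre-trained network: one fixed random ReLU layer.
    A real extractor's weights were learned on a large dataset; these are random.
    In both cases they are never updated while the classifier trains."""

    def __init__(self, in_dim, out_dim, seed=0):
        rng = random.Random(seed)
        self.weights = [[rng.gauss(0.0, 1.0) for _ in range(in_dim)]
                        for _ in range(out_dim)]

    def __call__(self, pixels):
        # ReLU(W x): one non-negative activation per output unit
        return [max(0.0, sum(w * p for w, p in zip(row, pixels)))
                for row in self.weights]

def sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def train_logistic(features, labels, lr=0.5, epochs=300):
    """Binary logistic regression (labels 0/1) fit by batch gradient descent."""
    dim = len(features[0])
    w, b = [0.0] * dim, 0.0
    n = len(features)
    for _ in range(epochs):
        gw, gb = [0.0] * dim, 0.0
        for f, t in zip(features, labels):
            err = sigmoid(sum(wi * fi for wi, fi in zip(w, f)) + b) - t
            for j in range(dim):
                gw[j] += err * f[j]
            gb += err
        w = [wi - lr * g / n for wi, g in zip(w, gw)]
        b -= lr * gb / n
    return w, b

# Hypothetical 4-pixel "images" with binary labels
X = [[1.0, 0.9, 0.1, 0.0], [0.8, 1.0, 0.0, 0.2],
     [0.0, 0.1, 0.9, 1.0], [0.2, 0.0, 1.0, 0.8]]
y = [0, 0, 1, 1]

extractor = FrozenExtractor(in_dim=4, out_dim=8)
weights_before = [row[:] for row in extractor.weights]
features = [extractor(x) for x in X]   # plays the role of extract_features(...)
w, b = train_logistic(features, y)     # only the classifier is trained
```

The benefit of a real pre-trained extractor comes entirely from its learned weights. The random layer here only shows where those weights sit in the pipeline.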
Testing:
(1) The first 3 test images are now all classified correctly:
image_test[0:3]['image'].show()
deep_features_model.predict(image_test[0:3])
(2) Overall test accuracy improves to 78%:
deep_features_model.evaluate(image_test)
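The jump from 46% to 78% accuracy means the error rate falls from 54% to 22%. The deep-features model therefore makes fewer than half as many mistakes (22/54 ≈ 0.41). A single accuracy number does not show which classes are confused with which. The sketch below counts (true, predicted) label pairs, which is a confusion matrix in dictionary form, and computes per-class accuracy. The labels and function names are hypothetical.

```python
from collections import Counter

def confusion_counts(targets, predictions):
    """Count (true label, predicted label) pairs; pairs with t != p are errors."""
    return Counter(zip(targets, predictions))

def per_class_accuracy(targets, predictions):
    """Fraction of each true class that was predicted correctly."""
    totals = Counter(targets)
    correct = Counter(t for t, p in zip(targets, predictions) if t == p)
    return {label: correct[label] / totals[label] for label in totals}

# Hypothetical labels and predictions
targets = ['cat', 'cat', 'dog', 'dog', 'bird']
predictions = ['dog', 'cat', 'dog', 'cat', 'bird']
matrix = confusion_counts(targets, predictions)
per_class = per_class_accuracy(targets, predictions)
print(per_class)  # {'cat': 0.5, 'dog': 0.5, 'bird': 1.0}
```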