美丽城市–垃圾分类识别
- 比赛网址:AI研习社分类竞赛
0. 使用 TensorFlow-Slim 训练自己的图像分类模型
见往期博客分享:使用tensorflow-slim训练自己的图像分类
1. 下载官方数据集并进行处理
- 数据集处理代码如下。所有路径均为绝对路径,其中 /* 代表省略的本地路径前缀,请替换为自己的目录
import csv
import os
import cv2
# Sort the training images into one sub-directory per class label, driven by
# the competition's train.csv. All paths are absolute; "/*" is a placeholder
# for the local path prefix.
filepath = '/*/classgabbage/data/train.csv'  # CSV rows: image file name, class label
file_pathname = '/*/classgabbage/data/train'  # directory holding the raw training images
class_pathname = '/*/classgabbage/data/class'  # output root: <class_pathname>/<label>/<image>


def read_path(file_pathname, special_filename, clas, out_root=class_pathname):
    """Copy ``special_filename`` from ``file_pathname`` into ``out_root/clas/``.

    The file is copied byte-for-byte (no decode/re-encode, so JPEGs are not
    recompressed and files OpenCV cannot decode are still copied), and the
    class directory is created if it does not exist yet.

    Args:
        file_pathname: directory containing the source image.
        special_filename: name of the image file to copy.
        clas: class label; used as the destination sub-directory name.
        out_root: root directory for the per-class folders.

    Returns:
        True if the image was found and copied, False if it does not exist
        (in which case nothing is created).
    """
    import os
    import shutil

    src = os.path.join(file_pathname, special_filename)
    # A direct lookup instead of listing the whole directory for every CSV row.
    if not os.path.isfile(src):
        return False
    dst_dir = os.path.join(out_root, clas)
    # Create the class folder on demand; writing into a missing folder fails.
    os.makedirs(dst_dir, exist_ok=True)
    shutil.copyfile(src, os.path.join(dst_dir, special_filename))
    return True


def split_by_label(csv_path=filepath, src_dir=file_pathname,
                   out_root=class_pathname, labels=('cardboard',)):
    """Copy every image whose CSV label is in ``labels`` into its class folder.

    Each CSV row is ``<image file name>,<label>``. Rows with fewer than two
    columns or whose label is not in ``labels`` are skipped; a header row is
    skipped as long as its label column is not one of ``labels``.

    Args:
        csv_path: path of the labels CSV.
        src_dir: directory holding the raw images named in the CSV.
        out_root: root directory for the per-class folders.
        labels: class labels to copy (default keeps the original
            'cardboard'-only behaviour).

    Returns:
        List of file names that were copied, in CSV order.
    """
    import csv

    wanted = set(labels)
    copied = []
    with open(csv_path, newline='') as f:
        for row in csv.reader(f):
            if len(row) < 2 or row[1] not in wanted:
                continue
            if read_path(src_dir, row[0], row[1], out_root):
                print("finish" + row[0])
                copied.append(row[0])
            else:
                print("missing " + row[0])
    return copied


if __name__ == '__main__':
    split_by_label()
2. 依照第0步骤,完成所有流程
- 在此选用的模型是Inception-ResNet-v2;
- 训练脚本为:
# Fine-tune Inception-ResNet-v2 on the garbage dataset, starting from the
# pretrained inception_resnet_v2_2016_08_30 checkpoint. Only the AuxLogits and
# Logits scopes are excluded from the checkpoint (re-initialised) and trained;
# all other variables keep their pretrained values.
# Two model clones are used, one per visible GPU (devices 0 and 2).
CUDA_VISIBLE_DEVICES=0,2 python3 train_image_classifier.py \
--dataset_name=gabbage \
--dataset_split_name=train \
--dataset_dir=/*/slim/class_gabbage/data \
--train_image_size=300 \
--model_name=inception_resnet_v2 \
--checkpoint_path=/*/models-master/research/slim/class_gabbage/inception_resnet_v2_2016_08_30.ckpt \
--checkpoint_exclude_scopes=InceptionResnetV2/AuxLogits,InceptionResnetV2/Logits \
--trainable_scopes=InceptionResnetV2/AuxLogits,InceptionResnetV2/Logits \
--optimizer=rmsprop \
--learning_rate=0.004 \
--num_epochs_per_decay=200 \
--batch_size=32 \
--max_number_of_steps=5000 \
--num_clones=2 \
--clone_on_cpu=False \
--train_dir=/*/slim/class_gabbage/train_eval/training \
--save_interval_secs=60 \
--save_summaries_secs=60 \
--log_every_n_steps=2
- validation脚本为:
# Evaluate the latest checkpoint found in the training directory on the
# validation split. --eval_image_size must match --train_image_size (300) used
# for training and the 300x300 resize in the inference script; without it,
# eval_image_classifier.py falls back to the network default (299 here).
# NOTE(review): the checkpoint_path prefix differs from train_dir in the
# training command — confirm both point at the same training directory.
CUDA_VISIBLE_DEVICES=1 python3 eval_image_classifier.py \
--checkpoint_path=/*/class_gabbage/train_eval/training \
--eval_dir=/*/slim/class_gabbage/train_eval/eval \
--dataset_name=gabbage \
--dataset_split_name=validation \
--dataset_dir=/*/research/slim/class_gabbage/data \
--model_name=inception_resnet_v2 \
--eval_image_size=300
其余步骤(如导出并冻结模型)请参考第0步中的链接。
3. 最后生成结果 CSV 文件,用于比赛提交
import tensorflow as tf
import numpy as np
import cv2
from datasets import dataset_utils
from IPython import display
import csv
import os
dataset_dir='/*/research/slim/class_gabbage/data'
model_dir ='/*/models-master/research/slim/class_gabbage/model/inception_resnet_v2.pb'
file_pathname = '/*/workspace/classgabbage/data/test1'
# from IPython import display
f= open('/*/models-master/research/slim/result1.csv','a+',newline='')
gd = tf.GraphDef.FromString(open(model_dir, 'rb').read())
inp, predictions = tf.import_graph_def(gd, return_elements = ['input:0','InceptionResnetV2/Logits/Predictions:0'])
with tf.Session(graph=inp.graph):
for filename in os.listdir(file_pathname):
print(filename)
img = cv2.imread(file_pathname+'/'+filename)
width = 300
height = 300
dim = (width, height)
# resize image to [-1,1] Maps pixel values to the range [-1, 1]
resized = (cv2.resize(img, dim)).astype(np.float) / 128 - 1
image_np_expanded = np.expand_dims(resized, axis=0)
x = predictions.eval(feed_dict={
inp: image_np_expanded})
label_map = dataset_utils.read_label_file(dataset_dir)
print("Top 1 Prediction: ",label_map[x.argmax()])
writer = csv.writer(f)
row=[filename[:-4],label_map[x.argmax()]]
writer.writerow(row)