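Introduction to Traffic-Net
Traffic-Net is a dataset of traffic images, collected so that machine learning systems can be trained to detect traffic conditions and provide real-time monitoring, analytics, and alerts.
Dataset download: https://github.com/OlafenwaMoses/Traffic-Net/releases/download/1.0/trafficnet_dataset_v1.zip
This is the first release of the Traffic-Net dataset. It contains 4,400 images in 4 classes. The classes in this release are:
- Accident
- Dense traffic
- Fire
- Sparse traffic
Each class has 1,100 images: 900 for training and 200 for testing.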
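To check the per-class split after unpacking the archive, a small counting helper can be used. This is only a sketch: it assumes the dataset is laid out as trafficnet_dataset_v1/train/<class>/ and trafficnet_dataset_v1/test/<class>/, and the helper name count_images and the list of image extensions are my own choices, not part of Traffic-Net or ImageAI.

```python
import os

# Extensions treated as images. This set is an assumption, not taken from the dataset.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images(dataset_dir, splits=("train", "test")):
    """Return {split: {class_name: image_count}} for a split/class/image layout."""
    counts = {}
    for split in splits:
        split_path = os.path.join(dataset_dir, split)
        if not os.path.isdir(split_path):
            continue  # skip splits that are missing on disk
        counts[split] = {}
        for class_name in sorted(os.listdir(split_path)):
            class_path = os.path.join(split_path, class_name)
            if not os.path.isdir(class_path):
                continue  # ignore stray files next to the class folders
            counts[split][class_name] = sum(
                1
                for name in os.listdir(class_path)
                if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
            )
    return counts


if __name__ == "__main__":
    dataset_dir = "trafficnet_dataset_v1"  # assumed extraction folder
    if os.path.isdir(dataset_dir):
        for split, per_class in count_images(dataset_dir).items():
            for class_name, n in per_class.items():
                print(split, class_name, n)
```

If the layout matches the assumption, each class should report 900 images under train and 200 under test.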
Code reference
The code needs the imageai package. Running it under TensorFlow 2 produces some errors; read the log and edit the imageai package files it points to:
1. Change from tensorflow.python.keras.utils to from tensorflow.keras.utils
2. Change optimizer = Adam(lr=self.__initial_learning_rate, decay=1e-4) to optimizer = tf.optimizers.Adam(lr=self.__initial_learning_rate, decay=1e-4)
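The two edits above are plain text substitutions, so they can be sketched as a small helper that works on the text of one source file. The helper name patch_source is hypothetical, and which imageai file to patch is not stated here; it has to be read from the error log.

```python
# The two replacements described above, as (old, new) pairs.
REPLACEMENTS = [
    ("from tensorflow.python.keras.utils", "from tensorflow.keras.utils"),
    ("optimizer = Adam(", "optimizer = tf.optimizers.Adam("),
]


def patch_source(text):
    """Apply both TensorFlow 2 fixes to the text of one imageai source file."""
    for old, new in REPLACEMENTS:
        text = text.replace(old, new)
    return text
```

To use it, read the file named in the traceback, pass its contents through patch_source, and write the result back (keeping a backup copy is a good idea). The second replacement refers to tf, so the patched file also needs import tensorflow as tf if it does not already have it.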
from io import open
import os
import shutil
from zipfile import ZipFile

import requests
from imageai.Prediction.Custom import ModelTraining, CustomImagePrediction

execution_path = os.getcwd()

SOURCE_PATH = "https://github.com/OlafenwaMoses/Traffic-Net/releases/download/1.0/trafficnet_dataset_v1.zip"
# Local path of the downloaded archive.
FILE_DIR = os.path.join(execution_path, "trafficnet_dataset_v1.zip")
# Folder the archive is expected to extract to.
DATASET_DIR = os.path.join(execution_path, "trafficnet_dataset_v1")


def download_traffic_net():
    # Download and unpack the dataset only if the archive is not already present.
    if not os.path.exists(FILE_DIR):
        print("Downloading trafficnet_dataset_v1.zip")
        data = requests.get(SOURCE_PATH, stream=True)
        with open(FILE_DIR, "wb") as file:
            shutil.copyfileobj(data.raw, file)
        del data
        extract = ZipFile(FILE_DIR)
        extract.extractall(execution_path)
        extract.close()


def train_traffic_net():
    download_traffic_net()
    trainer = ModelTraining()
    trainer.setModelTypeAsResNet()
    trainer.setDataDirectory("trafficnet_dataset_v1")
    # 4 classes, 200 epochs, batches of 32 images, with data augmentation enabled.
    trainer.trainModel(num_objects=4, num_experiments=200, batch_size=32,
                       save_full_model=True, enhance_data=True)


def run_predict():
    predictor = CustomImagePrediction()
    # Model file and class mapping produced by an earlier training run.
    predictor.setModelPath(model_path="trafficnet_resnet_model_ex-055_acc-0.913750.h5")
    predictor.setJsonPath(model_json="model_class.json")
    predictor.loadFullModel(num_objects=4)
    predictions, probabilities = predictor.predictImage(image_input="images/1.jpg", result_count=4)
    # Print every class with its predicted probability.
    for prediction, probability in zip(predictions, probabilities):
        print(prediction, " : ", probability)


# Uncomment the line below to train your model
# train_traffic_net()

# Uncomment the line below to run predictions
run_predict()
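
The model file name used in run_predict() carries the epoch and the accuracy. When training leaves several such files behind, a helper along these lines can pick the one with the highest accuracy. This is a sketch under assumptions: the name pattern is inferred only from the single file name above, and best_model is a hypothetical helper, not an ImageAI function.

```python
import re

# Pattern assumed from the file name used above: ..._ex-<epoch>_acc-<accuracy>.h5
MODEL_NAME_PATTERN = re.compile(r"_ex-(\d+)_acc-(\d+\.\d+)\.h5$")


def best_model(file_names):
    """Return the file name with the highest accuracy in its name, or None."""
    best_name, best_acc = None, -1.0
    for name in file_names:
        match = MODEL_NAME_PATTERN.search(name)
        if match is None:
            continue  # not a model file in the assumed naming scheme
        acc = float(match.group(2))
        if acc > best_acc:
            best_name, best_acc = name, acc
    return best_name
```

A typical call would be best_model(os.listdir(models_dir)), where models_dir is wherever the training run wrote its .h5 files; the result can then be passed to setModelPath.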