I. Task Description
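When you look at an image, your brain can easily tell what it is about. Can a computer tell what an image depicts? Advances in deep learning, the availability of large datasets and increased computing power now let us build models that generate captions for images.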
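We will implement the caption generator with a CNN (convolutional neural network) and an LSTM (long short-term memory network). Image features are extracted with Xception, a CNN pretrained on the ImageNet dataset. These features are then combined with an LSTM language model, and the combined model generates the caption.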
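In the model defined later, the image vector is projected to 256 dimensions by a Dense layer. It is then added to the LSTM's output (the `add` layer in the model summary) before the final softmax over the vocabulary.

The NumPy sketch below shows only that decoder head. It makes several assumptions:
* the weights are random and untrained;
* the vocabulary is a toy one with 10 words;
* bias terms are omitted;
* random vectors stand in for the Xception features and the LSTM output.

It is an illustration of the data flow, not the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
feat_dim, hidden, vocab = 2048, 256, 10  # toy vocabulary; the real model uses 7577

def relu(x):
    return np.maximum(x, 0.0)

def softmax(x):
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()

# Random, untrained weights standing in for the three Dense layers (biases omitted)
W_img = rng.normal(scale=0.01, size=(feat_dim, hidden))
W_dec = rng.normal(scale=0.01, size=(hidden, hidden))
W_out = rng.normal(scale=0.01, size=(hidden, vocab))

image_features = rng.normal(size=feat_dim)  # stand-in for the 2048-d Xception vector
lstm_output = rng.normal(size=hidden)       # stand-in for the LSTM's final output

merged = relu(image_features @ W_img) + lstm_output  # the "add" merge step
probs = softmax(relu(merged @ W_dec) @ W_out)        # next-word distribution
print(probs.shape, probs.sum())
```

The image vector enters only at the merge step; the LSTM sees just the words.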
II. Dataset Description
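We use the Flickr_8K dataset. Larger datasets such as Flickr_30K and MSCOCO exist, but training a network on them alone can take weeks, so we use the smaller Flickr8k dataset. Larger datasets do, however, make it possible to build better models.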
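Flickr_8K was introduced as a new benchmark collection for sentence-based image description and search. It consists of 8,000 images, each paired with five different captions that clearly describe the salient entities and events.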
Dataset download:
Link: https://pan.baidu.com/s/1aG3CYioORpPdXC89_F_s3A
Extraction code: q9el
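In the dataset, Flickr8k_Dataset.zip is the archive of image files.
Flickr8k_text.zip contains the English captions and related files.
The flickr8kcn folder contains the corresponding Chinese captions and related files. The code below trains on the English captions, but you can modify it to train on the Chinese ones.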
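We assume each line of Flickr8k.token.txt has the form `<image file>#<caption index>\t<caption>`. The `all_img_captions` function in the training script relies on this when it strips the last two characters of the key.

Below is a minimal sketch of that grouping step under the same format assumption. The helper name `group_captions` and the sample lines are hypothetical, made up for illustration.

```python
from collections import defaultdict

def group_captions(lines):
    """Group 'image.jpg#n<TAB>caption' lines into {image: [captions]} (hypothetical helper)."""
    captions = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines, e.g. the one left by a trailing newline
        key, caption = line.split("\t", 1)
        image = key.rsplit("#", 1)[0]  # drop the '#n' caption index
        captions[image].append(caption)
    return dict(captions)

# Hypothetical sample lines in the assumed token-file format
sample = [
    "img_a.jpg#0\tA dog runs on the grass .",
    "img_a.jpg#1\tA brown dog is running .",
    "img_b.jpg#0\tTwo children play in the sand .",
    "",
]
grouped = group_captions(sample)
print(grouped["img_a.jpg"])
```

Splitting on the last `#` with `rsplit("#", 1)`, rather than slicing a fixed `[:-2]`, would also tolerate caption indices with more than one digit.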
III. Model Structure
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 input_2 (InputLayer)           [(None, 32)]         0           []
 input_1 (InputLayer)           [(None, 2048)]       0           []
 embedding (Embedding)          (None, 32, 256)      1939712     ['input_2[0][0]']
 dropout (Dropout)              (None, 2048)         0           ['input_1[0][0]']
 dropout_1 (Dropout)            (None, 32, 256)      0           ['embedding[0][0]']
 dense (Dense)                  (None, 256)          524544      ['dropout[0][0]']
 lstm (LSTM)                    (None, 256)          525312      ['dropout_1[0][0]']
 add (Add)                      (None, 256)          0           ['dense[0][0]',
                                                                  'lstm[0][0]']
 dense_1 (Dense)                (None, 256)          65792       ['add[0][0]']
 dense_2 (Dense)                (None, 7577)         1947289     ['dense_1[0][0]']
==================================================================================================
Total params: 5,002,649
Trainable params: 5,002,649
Non-trainable params: 0
__________________________________________________________________________________________________
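The parameter counts in the summary follow directly from the layer sizes. We assume the vocabulary size of 7577 and the 256-unit width read from the summary; Dropout and Add layers have no parameters.

The arithmetic, as a small check:

```python
def dense_params(n_in, n_out):
    # weight matrix plus one bias per output unit
    return n_in * n_out + n_out

def lstm_params(n_in, units):
    # 4 gates, each with input weights, recurrent weights and a bias vector
    return 4 * (n_in * units + units * units + units)

def embedding_params(vocab_size, dim):
    # one dim-sized vector per vocabulary entry
    return vocab_size * dim

vocab_size, feat_dim, units, embed_dim = 7577, 2048, 256, 256
layers = {
    "embedding": embedding_params(vocab_size, embed_dim),
    "dense": dense_params(feat_dim, units),
    "lstm": lstm_params(embed_dim, units),
    "dense_1": dense_params(units, units),
    "dense_2": dense_params(units, vocab_size),
}
total = sum(layers.values())
print(layers)
print(f"Total params: {total:,}")  # Total params: 5,002,649
```

The embedding layer and the output layer together account for roughly 78% of the parameters. Both scale with the vocabulary size, so pruning rare words would shrink the model the most.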
IV. Training the Model
1. Reference Code
During image preprocessing, Xception is used to extract features, which are saved to the file features.p.
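Both scripts scale pixels with `image / 127.5 - 1.0` instead of calling Xception's `preprocess_input`. As far as we know, the two are equivalent, because Keras' Xception uses the "tf" preprocessing mode, which maps [0, 255] to [-1, 1].

A NumPy-only sketch of that mapping, using a hypothetical 2x2 test image:

```python
import numpy as np

def scale_like_xception(pixels):
    """Map uint8 pixel values in [0, 255] to float32 values in [-1, 1]."""
    return pixels.astype("float32") / 127.5 - 1.0

# A tiny hypothetical RGB batch with shape (1, 2, 2, 3): batch, height, width, channels
batch = np.array([[[[0, 127, 255], [64, 128, 192]],
                   [[255, 255, 255], [0, 0, 0]]]], dtype=np.uint8)
scaled = scale_like_xception(batch)
print(scaled.shape, scaled.min(), scaled.max())
```

Whichever variant is used, it must be the same at training and at test time. Otherwise the 2048-d features fed to the caption model would come from a different input distribution.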
import os
import string
from pickle import dump, load
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras.applications.xception import Xception, preprocess_input
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Dropout, add
from tensorflow.keras.models import Model, load_model
# small library for showing the progress of loops (works in notebooks and terminals)
from tqdm.auto import tqdm
# Load a text file into memory
def load_doc(filename):
    # open the file read-only; the with-block closes it automatically
    with open(filename, 'r', encoding='utf-8') as file:
        text = file.read()
    return text
# get all images with their captions
def all_img_captions(filename):
    file = load_doc(filename)
    captions = file.split('\n')
    descriptions = {}
    for caption in captions[:-1]:
        img, caption = caption.split('\t')
        # the key looks like "xxx.jpg#0"; img[:-2] drops the "#n" caption index
        if img[:-2] not in descriptions:
            descriptions[img[:-2]] = [caption]
        else:
            descriptions[img[:-2]].append(caption)
    return descriptions
## Data cleaning: lower-casing, removing punctuation and words containing numbers
# This function takes all descriptions and cleans them. Cleaning is an important step when working with text data, and the kind of cleaning depends on the goal. Here we remove punctuation, convert all text to lower case, and remove words that contain numbers.
def cleaning_text(captions):
    table = str.maketrans('', '', string.punctuation)
    for img, caps in captions.items():
        for i, img_caption in enumerate(caps):
            # str.replace returns a new string, so the result must be assigned
            img_caption = img_caption.replace("-", " ")
            desc = img_caption.split()
            # convert to lower case
            desc = [word.lower() for word in desc]
            # remove punctuation from each token
            desc = [word.translate(table) for word in desc]
            # remove hanging 's and a (tokens of length 1)
            desc = [word for word in desc if len(word) > 1]
            # remove tokens with numbers in them
            desc = [word for word in desc if word.isalpha()]
            # convert back to string
            img_caption = ' '.join(desc)
            captions[img][i] = img_caption
    return captions
def text_vocabulary(descriptions):
    # build vocabulary of all unique words
    vocab = set()
    for key in descriptions.keys():
        for d in descriptions[key]:
            vocab.update(d.split())
    return vocab
# This function writes all preprocessed descriptions to a file, one "image<TAB>caption" per line.
def save_descriptions(descriptions, filename):
    lines = list()
    for key, desc_list in descriptions.items():
        for desc in desc_list:
            lines.append(key + '\t' + desc)
    data = "\n".join(lines)
    with open(filename, "w", encoding="utf-8") as file:
        file.write(data)
# all_train_captions = []
# for key, val in descriptions.items():
#     for cap in val:
#         all_train_captions.append(cap)
# # Consider only words which occur at least 8 times in the corpus
# word_count_threshold = 8
# word_counts = {}
# nsents = 0
# for sent in all_train_captions:
#     nsents += 1
#     for w in sent.split(' '):
#         word_counts[w] = word_counts.get(w, 0) + 1
# vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
# print('preprocessed words %d ' % len(vocab))
dataset_text = "Flickr8k_text"
dataset_images = "Flicker8k_Dataset"
#we prepare our text data
filename = dataset_text + "/" + "Flickr8k.token.txt"
#loading the file that contains all data
#mapping them into descriptions dictionary img to 5 captions
descriptions = all_img_captions(filename)
print("Length of descriptions =" ,len(descriptions))
#cleaning the descriptions
clean_descriptions = cleaning_text(descriptions)
#building vocabulary
vocabulary = text_vocabulary(clean_descriptions)
print("Length of vocabulary = ", len(vocabulary))
#saving each description to file
save_descriptions(clean_descriptions, "descriptions.txt")
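def extract_features(directory):
    # Xception without its classification head; global average pooling yields a 2048-d vector
    model = Xception(include_top=False, pooling='avg')
    features = {}
    for img in tqdm(os.listdir(directory)):
        filename = directory + "/" + img
        # convert to RGB so every image has exactly 3 channels
        image = Image.open(filename).convert('RGB')
        image = image.resize((299, 299))
        image = np.expand_dims(np.asarray(image, dtype='float32'), axis=0)
        # scale pixels to [-1, 1]; equivalent to Xception's preprocess_input(image)
        image = image / 127.5
        image = image - 1.0
        feature = model.predict(image, verbose=0)
        features[img] = feature
    return features
# Extract a 2048-dimensional feature vector for every image.
# If features.p already exists, the extraction and dump steps can be commented out.
features = extract_features(dataset_images)
with open("features.p", "wb") as f:
    dump(features, f)
with open("features.p", "rb") as f:
    features = load(f)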
# load the data
def load_photos(filename):
    file = load_doc(filename)
    photos = file.split("\n")[:-1]
    return photos
def load_clean_descriptions(filename, photos):
    # loading clean_descriptions
    file = load_doc(filename)
    photos = set(photos)  # set membership tests are much faster than list lookups
    descriptions = {}
    for line in file.split("\n"):
        words = line.split()
        if len(words) < 1:
            continue
        image, image_caption = words[0], words[1:]
        if image in photos:
            if image not in descriptions:
                descriptions[image] = []
            # the Tokenizer's default filters strip '<' and '>', so these become the tokens 'start' and 'end'
            desc = '<start> ' + " ".join(image_caption) + ' <end>'
            descriptions[image].append(desc)
    return descriptions
def load_features(photos):
    # loading all features
    with open("features.p", "rb") as f:
        all_features = load(f)
    # selecting only needed features
    features = {k: all_features[k] for k in photos}
    return features
filename = dataset_text + "/" + "Flickr_8k.trainImages.txt"
train_imgs = load_photos(filename)
train_descriptions = load_clean_descriptions("descriptions.txt", train_imgs)
train_features = load_features(train_imgs)
# convert the dictionary to a flat list of descriptions
def dict_to_list(descriptions):
    all_desc = []
    for key in descriptions.keys():
        all_desc.extend(descriptions[key])
    return all_desc
# create the tokenizer
# it vectorises the text corpus: each integer represents a token in the dictionary
def create_tokenizer(descriptions):
    desc_list = dict_to_list(descriptions)
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(desc_list)
    return tokenizer
# give each word an index, and store the tokenizer in the tokenizer.p pickle file
tokenizer = create_tokenizer(train_descriptions)
with open('tokenizer.p', 'wb') as f:
    dump(tokenizer, f)
# +1 because index 0 is reserved for padding
vocab_size = len(tokenizer.word_index) + 1
print(vocab_size)
# calculate the maximum length (in words) of the cleaned descriptions
# (renamed so the function is not shadowed by the max_length variable)
def get_max_length(descriptions):
    desc_list = dict_to_list(descriptions)
    return max(len(d.split()) for d in desc_list)
max_length = get_max_length(descriptions)
print(max_length)
print(features['1000268201_693b08cb0e.jpg'][0])
# Define the model
# 1 Photo feature extractor: features extracted with the pretrained Xception model.
# 2 Sequence processor: a word embedding layer that handles the text, followed by an LSTM.
# 3 Decoder: 1 and 2 both produce a fixed-length vector; they are merged and processed by dense layers to make the final prediction.
# Create input-output sequence pairs from the image descriptions.
# Data generator that feeds the model one image (with all of its captions) per step
def data_generator(descriptions, features, tokenizer, max_length):
    while True:
        for key, description_list in descriptions.items():
            # retrieve photo features
            feature = features[key][0]
            input_image, input_sequence, output_word = create_sequences(tokenizer, max_length, description_list, feature)
            # yield ((inputs), targets) as a tuple, the structure Keras expects from a generator
            yield (input_image, input_sequence), output_word
def create_sequences(tokenizer, max_length, desc_list, feature):
    X1, X2, y = list(), list(), list()
    # walk through each description for the image
    for desc in desc_list:
        # encode the sequence
        seq = tokenizer.texts_to_sequences([desc])[0]
        # split one sequence into multiple X, y pairs
        for i in range(1, len(seq)):
            # split into input and output pair
            in_seq, out_seq = seq[:i], seq[i]
            # pad input sequence
            in_seq = pad_sequences([in_seq], maxlen=max_length)[0]
            # encode output sequence (uses the global vocab_size defined above)
            out_seq = to_categorical([out_seq], num_classes=vocab_size)[0]
            # store
            X1.append(feature)
            X2.append(in_seq)
            y.append(out_seq)
    return np.array(X1), np.array(X2), np.array(y)
# inspect the shapes produced for one image
(a, b), c = next(data_generator(train_descriptions, features, tokenizer, max_length))
print(a.shape, b.shape, c.shape)
# Define the model
def define_model(vocab_size, max_length):
    # features from the CNN model squeezed from 2048 to 256 nodes
    inputs1 = Input(shape=(2048,))
    fe1 = Dropout(0.5)(inputs1)
    fe2 = Dense(256, activation='relu')(fe1)
    # LSTM sequence model
    inputs2 = Input(shape=(max_length,))
    se1 = Embedding(vocab_size, 256, mask_zero=True)(inputs2)
    se2 = Dropout(0.5)(se1)
    se3 = LSTM(256)(se2)
    # merging both models
    decoder1 = add([fe2, se3])
    decoder2 = Dense(256, activation='relu')(decoder1)
    outputs = Dense(vocab_size, activation='softmax')(decoder2)
    # tie it together [image, seq] [word]
    model = Model(inputs=[inputs1, inputs2], outputs=outputs)
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    # summarize the model (summary() prints by itself and returns None)
    model.summary()
    # plot_model(model, to_file='model.png', show_shapes=True)
    return model
# Create and train the model
print('Dataset: ', len(train_imgs))
print('Descriptions: train=', len(train_descriptions))
print('Photos: train=', len(train_features))
print('Vocabulary Size:', vocab_size)
print('Description Length: ', max_length)
model = define_model(vocab_size, max_length)
epochs = 10
# one step = one image with all of its captions
steps = len(train_descriptions)
# create the folder for saved models (no error if it already exists)
os.makedirs("models", exist_ok=True)
for i in range(epochs):
    generator = data_generator(train_descriptions, train_features, tokenizer, max_length)
    # fit_generator is deprecated; model.fit accepts generators directly
    model.fit(generator, epochs=1, steps_per_epoch=steps, verbose=1)
    model.save("models/model_" + str(i) + ".h5")
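`create_sequences` turns one encoded caption of n tokens into n − 1 training pairs. In each pair, the first i tokens, left-padded with zeros to `max_length`, are the input, and token i + 1, one-hot encoded, is the target.

The sketch below reproduces that expansion without Keras. It assumes a hypothetical hand-made word index, whereas the real script uses the fitted Tokenizer. The padding mimics `pad_sequences`, which by default pads and truncates at the front ("pre").

```python
import numpy as np

def expand_caption(token_ids, max_length, vocab_size):
    """Split one encoded caption into (padded prefix, one-hot next word) pairs."""
    inputs, targets = [], []
    for i in range(1, len(token_ids)):
        prefix = token_ids[:i][-max_length:]                 # keep at most max_length tokens
        padded = [0] * (max_length - len(prefix)) + prefix   # left ("pre") padding with 0
        one_hot = np.zeros(vocab_size, dtype="float32")
        one_hot[token_ids[i]] = 1.0                          # the word that should come next
        inputs.append(padded)
        targets.append(one_hot)
    return np.array(inputs), np.array(targets)

# Hypothetical word index; 0 is reserved for padding, as with the Keras Tokenizer
word_index = {"start": 1, "dog": 2, "runs": 3, "end": 4}
ids = [word_index[w] for w in "start dog runs end".split()]
X, y = expand_caption(ids, max_length=5, vocab_size=len(word_index) + 1)
print(X)
print(y.argmax(axis=1))
```

Each image contributes five captions, each expanded this way. This is why the training script builds the pairs lazily in a generator, one image per step, instead of materialising them all in memory.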
2. Testing the Model
import argparse
from pickle import load
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
# use tensorflow.keras throughout, matching the training script
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.applications.xception import Xception
from tensorflow.keras.models import load_model
ap = argparse.ArgumentParser()
ap.add_argument('-i', '--image', required=False, help="Image Path", default='test/3291255271_a185eba408.jpg')
args = vars(ap.parse_args())
img_path = args['image']
def extract_features(filename, model):
    try:
        # convert to RGB so grayscale and RGBA images also end up with 3 channels
        image = Image.open(filename).convert('RGB')
    except OSError:
        raise SystemExit("ERROR: Couldn't open image! Make sure the image path and extension are correct")
    image = image.resize((299, 299))
    image = np.expand_dims(np.asarray(image, dtype='float32'), axis=0)
    # same [-1, 1] scaling as in training
    image = image / 127.5
    image = image - 1.0
    feature = model.predict(image, verbose=0)
    return feature
def word_for_id(integer, tokenizer):
    # index_word is the reverse of word_index; returns None for unknown ids such as 0 (padding)
    return tokenizer.index_word.get(int(integer))
def generate_desc(model, tokenizer, photo, max_length):
    # greedy decoding: repeatedly append the most probable next word
    in_text = 'start'
    for i in range(max_length):
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        pred = model.predict([photo, sequence], verbose=0)
        pred = np.argmax(pred)
        word = word_for_id(pred, tokenizer)
        if word is None:
            break
        in_text += ' ' + word
        if word == 'end':
            break
    return in_text
# path = 'Flicker8k_Dataset/111537222_07e56d5a30.jpg'
# must equal the max_length used during training
max_length = 32
with open("tokenizer.p", "rb") as f:
    tokenizer = load(f)
model = load_model('models/model_9.h5')
xception_model = Xception(include_top=False, pooling="avg")
photo = extract_features(img_path, xception_model)
img = Image.open(img_path)
description = generate_desc(model, tokenizer, photo, max_length)
print("\n\n")
print(description)
plt.imshow(img)
plt.show()
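`generate_desc` performs greedy decoding. At each step it feeds the caption so far and takes the arg-max word. It stops when it produces "end", when it hits an id with no word, or after `max_length` steps. The loop starts from "start" because the Keras Tokenizer's default filters strip `<` and `>`, so `<start>` is stored as the token `start`.

The sketch below reproduces the control flow with a hypothetical stand-in for the trained model: a fixed next-word table in `fake_predict`. This lets the stopping logic be checked without TensorFlow; the vocabulary and helper names are illustrative only.

```python
import numpy as np

# Hypothetical vocabulary; id 0 is padding and has no word
index_word = {1: "start", 2: "dog", 3: "runs", 4: "end"}
word_index = {w: i for i, w in index_word.items()}

def fake_predict(sequence):
    """Stand-in for model.predict: one-hot scores for the next word, keyed on the last id."""
    next_id = {1: 2, 2: 3, 3: 4}.get(sequence[-1], 4)
    scores = np.zeros(len(index_word) + 1)
    scores[next_id] = 1.0
    return scores

def greedy_decode(predict, max_length):
    in_text = "start"
    for _ in range(max_length):
        sequence = [word_index[w] for w in in_text.split()]
        word = index_word.get(int(np.argmax(predict(sequence))))
        if word is None:
            break  # id with no word, e.g. 0 (padding)
        in_text += " " + word
        if word == "end":
            break
    return in_text

print(greedy_decode(fake_predict, max_length=32))  # start dog runs end
```

Greedy decoding commits to one word at a time and cannot revise earlier choices. Together with a short training run, this is one plausible reason for odd captions.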
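3. Test Results
Training is slow, so the model was not trained for many epochs, and the generated captions are somewhat odd. The code is nevertheless offered as a reference approach for similar work.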