



# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
import sys
import re

class VQAEval:
    """Score predicted VQA answers against ground-truth annotations.

    Per-question accuracy follows the leave-one-out consensus rule used in
    ``evaluate``: for every ground-truth annotation, the prediction is
    compared with the *other* annotations and earns ``min(1, matches / 3)``;
    the question's accuracy is the mean over all annotations.

    After ``evaluate`` the results are available in:
      * ``accuracy``     - overall / per-question-type / per-answer-type (%)
      * ``evalQA``       - accuracy (%) per question id
      * ``evalQuesType`` - accuracy (%) per question id, grouped by question type
      * ``evalAnsType``  - accuracy (%) per question id, grouped by answer type
    """

    def __init__(self, vqa, vqaRes, n=2):
        """Store the datasets and build the answer-normalization tables.

        Args:
            vqa: ground-truth container; must provide ``getQuesIds()`` and a
                ``qa`` mapping from question id to an annotation dict with
                ``'answers'``, ``'question_type'`` and ``'answer_type'``.
            vqaRes: prediction container; must provide a ``qa`` mapping from
                question id to a dict with an ``'answer'`` string.
            n: number of decimal places kept when rounding percentages.
        """
        self.n = n
        self.accuracy = {}      # aggregated accuracies
        self.evalQA = {}        # per-question accuracy
        self.evalQuesType = {}  # per-question accuracy grouped by question type
        self.evalAnsType = {}   # per-question accuracy grouped by answer type
        self.vqa = vqa          # ground-truth dataset
        self.vqaRes = vqaRes    # predicted results
        self.params = {'question_id': vqa.getQuesIds()}  # default ids to evaluate
        # Maps apostrophe-less spellings to their proper contraction.
        self.contractions = {
            "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't",
            "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
            "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've",
            "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've",
            "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
            "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
            "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
            "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've",
            # Key and value were swapped in the original table ("somebody'd" -> "somebodyd").
            "somebodyd": "somebody'd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll",
            "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've",
            "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've",
            "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've",
            "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
            "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't",
            "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're",
            "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've",
            "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
            "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've",
            "youll": "you'll", "youre": "you're", "youve": "you've",
        }
        # Spelled-out numbers are normalized to digits.
        self.manualMap = {
            'none': '0',
            'zero': '0',
            'one': '1',
            'two': '2',
            'three': '3',
            'four': '4',
            'five': '5',
            'six': '6',
            'seven': '7',
            'eight': '8',
            'nine': '9',
            'ten': '10',
        }
        self.articles = ['a', 'an', 'the']

        # NOTE: "(?!<=\d)" is a negative look-ahead for the literal text "<="
        # followed by a digit (it always succeeds at a "."), not a look-behind.
        # The pattern text is kept unchanged so that results stay comparable;
        # in effect it removes every period that is not followed by a digit.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [';', r"/", '[', ']', '"', '{', '}',
                      '(', ')', '=', '+', '\\', '_', '-',
                      '>', '<', '@', '`', ',', '?', '!']

    def evaluate(self, quesIds=None):
        """Compute accuracies for the given questions and store them on ``self``.

        Args:
            quesIds: list of question ids to evaluate on. Useful when the
                predictions do not cover the whole dataset. Defaults to all
                ids returned by ``vqa.getQuesIds()`` at construction time.

        Note:
            When annotators disagree, the ground-truth ``'answer'`` strings in
            ``self.vqa.qa`` are replaced in place by their punctuation-stripped
            form.
        """
        if quesIds is None:
            quesIds = list(self.params['question_id'])
        gts = {}
        res = {}
        for quesId in quesIds:
            gts[quesId] = self.vqa.qa[quesId]
            res[quesId] = self.vqaRes.qa[quesId]

        # =================================================
        # Compute accuracy
        # =================================================
        accQA = []
        accQuesType = {}
        accAnsType = {}
        for step, quesId in enumerate(quesIds):
            # Normalize the predicted answer.
            resAns = res[quesId]['answer']
            resAns = resAns.replace('\n', ' ')
            resAns = resAns.replace('\t', ' ')
            resAns = resAns.strip()
            resAns = self.processPunctuation(resAns)
            resAns = self.processDigitArticle(resAns)

            gtAcc = []
            gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
            # Only normalize ground truth when annotators gave different answers.
            if len(set(gtAnswers)) > 1:
                for ansDic in gts[quesId]['answers']:
                    ansDic['answer'] = self.processPunctuation(ansDic['answer'])
            for gtAnsDatum in gts[quesId]['answers']:
                # Leave-one-out: compare against every annotation except this one.
                otherGTAns = [item for item in gts[quesId]['answers'] if item != gtAnsDatum]
                matchingAns = [item for item in otherGTAns if item['answer'] == resAns]
                # Full credit once three of the other annotators agree.
                acc = min(1, float(len(matchingAns)) / 3)
                gtAcc.append(acc)

            quesType = gts[quesId]['question_type']
            ansType = gts[quesId]['answer_type']
            avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
            accQA.append(avgGTAcc)
            if quesType not in accQuesType:
                accQuesType[quesType] = []
            accQuesType[quesType].append(avgGTAcc)
            if ansType not in accAnsType:
                accAnsType[ansType] = []
            # Without this append every per-answer-type list stays empty and
            # setAccuracy divides by zero.
            accAnsType[ansType].append(avgGTAcc)
            self.setEvalQA(quesId, avgGTAcc)
            self.setEvalQuesType(quesId, quesType, avgGTAcc)
            self.setEvalAnsType(quesId, ansType, avgGTAcc)
            if step % 100 == 0:
                self.updateProgress(step / float(len(quesIds)))

        self.setAccuracy(accQA, accQuesType, accAnsType)

    def processPunctuation(self, inText):
        """Return ``inText`` with punctuation removed or replaced by spaces.

        A punctuation mark is deleted outright when it touches a space in the
        input or when the input contains a digit-comma-digit group (e.g.
        ``1,000``); otherwise it is replaced by a space. Afterwards periods
        matched by ``self.periodStrip`` are removed.
        """
        outText = inText
        hasNumericComma = re.search(self.commaStrip, inText) is not None
        for p in self.punct:
            if (p + ' ' in inText or ' ' + p in inText) or hasNumericComma:
                outText = outText.replace(p, '')
            else:
                outText = outText.replace(p, ' ')
        # No count argument: every matching period is stripped.
        outText = self.periodStrip.sub("", outText)
        return outText

    def processDigitArticle(self, inText):
        """Lower-case ``inText``, map number words to digits, drop articles
        and restore apostrophes in known contractions.

        Returns the normalized words joined by single spaces.
        """
        outText = []
        tempText = inText.lower().split()
        for word in tempText:
            # .get (not .setdefault) so the lookup table is never modified.
            word = self.manualMap.get(word, word)
            if word not in self.articles:
                outText.append(word)
        for wordId, word in enumerate(outText):
            if word in self.contractions:
                outText[wordId] = self.contractions[word]
        outText = ' '.join(outText)
        return outText

    def setAccuracy(self, accQA, accQuesType, accAnsType):
        """Store overall, per-question-type and per-answer-type mean accuracy
        (as percentages rounded to ``self.n`` places) in ``self.accuracy``.
        """
        self.accuracy['overall'] = round(100 * float(sum(accQA)) / len(accQA), self.n)
        self.accuracy['perQuestionType'] = {
            quesType: round(100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]), self.n)
            for quesType in accQuesType
        }
        self.accuracy['perAnswerType'] = {
            ansType: round(100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n)
            for ansType in accAnsType
        }

    def setEvalQA(self, quesId, acc):
        """Record the accuracy (%) of a single question."""
        self.evalQA[quesId] = round(100 * acc, self.n)

    def setEvalQuesType(self, quesId, quesType, acc):
        """Record the accuracy (%) of a question under its question type."""
        if quesType not in self.evalQuesType:
            self.evalQuesType[quesType] = {}
        self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)

    def setEvalAnsType(self, quesId, ansType, acc):
        """Record the accuracy (%) of a question under its answer type."""
        if ansType not in self.evalAnsType:
            self.evalAnsType[ansType] = {}
        self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)

    def updateProgress(self, progress):
        """Format a textual progress bar for ``progress`` (a fraction in [0, 1]).

        The bar is built but not emitted: the stdout writes are disabled in
        this copy of the evaluator, so the method has no visible effect.
        """
        barLength = 20
        status = ""
        if isinstance(progress, int):
            progress = float(progress)
        if not isinstance(progress, float):
            progress = 0
            status = "error: progress var must be float\r\n"
        if progress < 0:
            progress = 0
            status = "Halt...\r\n"
        if progress >= 1:
            progress = 1
            status = "Done...\r\n"
        block = int(round(barLength * progress))
        text = "\rFinshed Percent: [{0}] {1}% {2}".format("#" * block + "-" * (barLength - block), int(progress * 100), status)
        # sys.stdout.write(text)
        # sys.stdout.flush()

2. 结果分析



问题类型 含义 答案类型 问题数
(‘what time’, 0.58), 几点了? other/number:1709 2914
(‘what number is’, 3.74), 数字是多少? other/number:1632 1668
(‘how’, 4.47), 怎么样? other/number:1596 4740
(‘where is the’, 5.12), 在哪里? other/number:5 6734
(‘what brand’, 6.09), 什么牌子? other/number:4 1600
(‘where are the’, 6.37), 在哪里? other 2161
(‘who is’, 6.91), 是谁? other/number:16 2154
(‘why is the’, 7.77), 为什么是这样? other/number:3 1544
(‘why’, 8.42), 为什么? other/number:5 1544
(‘what is on the’, 8.7), 那上面是什么? other/number:3 4254
(‘what is the name’, 9.59), 名字是什么? other/number:18 1618
(‘what is the’, 11.84), 这是什么? other/number:1459 24502
(‘what is’, 12.06), 是什么?做什么? other/number:85 13561
(‘what are the’, 12.32), 什么 other/number:145 7225
(‘what’, 13.9), 什么 other/number:1552 34608
(‘what is in the’, 14.98), 什么 other 3990
(‘what type of’, 15.12), 什么 other/number:8 7962
(‘what does the’, 15.67), 什么 other/number:175 4075
(‘what kind of’, 15.69), 什么种类? other/number:8 11192
(‘what is the woman’, 16.32), 这个女人在做什么? other 1706
(‘what are’, 19.09), 什么? other/number:145 3277


('what is the man', 21.24),
('which', 24), 
('what is this', 24.9), 
('what is the person', 25.09), 
('how many people are in', 32.2),  
('what color are the', 32.59), 
('what color', 33.25), 
('how many people are', 34.58),  
('how many', 35.86),  number
('what color is the', 36.14), 
('what color is', 38.62), 
('what is the color of the', 38.82), 
('none of the above', 40.6),  
('what animal is', 43.06), 
('is the person', 55.66), 
('is that a', 57.48), 
('is it', 59.16), 
('is the', 59.57), 
('is this person', 59.81), 
('what room is', 60.63), 
('is he', 61.01), 
('is this', 61.06), 
('are the', 61.46), 
('are there any', 61.47), 
('is the man', 61.48), 
('are they', 61.48), 
('is the woman', 61.55), 
('was', 61.88), 
('is this a', 62.08), 
('are there', 62.18), 
('is there a', 62.21), 
('can you', 62.48), 
('is this an', 62.67), 
('are', 63.01), 
('is there', 63.03), 
('is', 63.1), 
('do you', 63.14), 
('are these', 63.22), 
('does the', 63.68), 
('has', 63.88), 
('do', 64.63), 
('does this', 64.82), 
('what sport is', 67.96), 
('could', 71.98)

由结果可以看出:what time、what number is、how、where is the、what brand 这五类问题的准确率最低。



所以我个人的理解是:number 类型回答的平均准确率不高,原因并不在于 'how many' 这类计数问题,而在于对图像中出现的数字(例如时间、号码)的识别?

针对 ‘where is the’ 这类提问,主要是识别图像中对象的位置,所以在添加了关系之后,我们再来分析其结果。

添加了关系的 MuRel 结果分析

