人工智能课程的小组作业
写了不到两天吧，终于在前人的指引下，磕磕绊绊地完成了一个勉强还能看的结果。
用了DNN和CNN两种网络，效果都不是很好，可能是训练集（MNIST）和输入的图片不尽相似吧。
在代码后半部分，由于比较仓促，许多冗余的地方没有删减，等有时间了一定要好好打磨（polish）一下。
现在能够实现的功能：
- 对表格进行分割，将各单元格二值化并保存至本地
- 读取图片,识别数字(仅限于单一字符)
下面是本人简陋的代码，各位见笑了。
from numpy.lib.function_base import average
from openpyxl import Workbook
import cv2
import numpy as np
import io
import tensorflow as tf
import csv
import openpyxl
# Input table image. cv2.IMREAD_COLOR (== 1) loads it as 3-channel BGR;
# horizon()/vertical()/segment1() below read this module-level global.
img = cv2.imread('./4.png', cv2.IMREAD_COLOR)
def showing(img):
    """Display ``img`` in an OpenCV window and block until a key is pressed."""
    window_name = 'showing.jpg'
    cv2.imshow(window_name, img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def horizon(image=None, threshold=0.6):
    """Return the row indices at which horizontal table lines start.

    Walks down the image comparing the mean intensity of each pair of
    adjacent rows. When the two means differ by more than ``threshold * 255``,
    the upper row's index is recorded. The scan then jumps to the next row
    whose mean is brighter than ``threshold * 255``, so a line several pixels
    thick is reported only once.

    Args:
        image: array indexed as ``image[row]`` (grayscale or colour); defaults
            to the module-level ``img``.
        threshold: fraction of 255 used for both the jump test and the
            "bright row" test (0.6, as before).

    Returns:
        Ascending list of row indices.
    """
    if image is None:
        image = img
    height = image.shape[0]
    limit = threshold * 255
    # Mean of every row, computed once instead of re-averaged at each step.
    means = [np.average(image[r]) for r in range(height)]
    horizontal_lines = []
    i = 0
    # Both loops are bounded by the image height. The original only stopped
    # when i hit height-2 exactly, so a line touching the bottom edge made it
    # index past the image (IndexError). The last row pair is now compared too.
    while i < height - 1:
        if abs(means[i] - means[i + 1]) > limit:
            horizontal_lines.append(i)
            # Skip the dark rows that make up this line.
            i += 1
            while i < height and means[i] <= limit:
                i += 1
        else:
            i += 1
    return horizontal_lines
def vertical(image=None, threshold=0.6):
    """Return the column indices at which vertical table lines start.

    Walks across the image comparing the mean intensity of each pair of
    adjacent columns. When the two means differ by more than
    ``threshold * 255``, the left column's index is recorded. The scan then
    jumps to the next column whose mean is brighter than ``threshold * 255``,
    so a thick line is reported only once.

    Args:
        image: array indexed as ``image[:, col]`` (grayscale or colour);
            defaults to the module-level ``img``.
        threshold: fraction of 255 used for both the jump test and the
            "bright column" test (0.6, as before).

    Returns:
        Ascending list of column indices.
    """
    if image is None:
        image = img
    width = image.shape[1]
    limit = threshold * 255
    # Mean of every column, computed once instead of re-averaged at each step.
    means = [np.average(image[:, c]) for c in range(width)]
    vertical_lines = []
    i = 0
    # Both loops are bounded by the image width. The original only stopped
    # when i hit width-2 exactly, so a line touching the right edge made it
    # index past the image (IndexError). The last column pair is now compared too.
    while i < width - 1:
        if abs(means[i] - means[i + 1]) > limit:
            vertical_lines.append(i)
            # Skip the dark columns that make up this line.
            i += 1
            while i < width and means[i] <= limit:
                i += 1
        else:
            i += 1
    return vertical_lines
def sections():
    """Return the bounding box of every table cell.

    Each box is ``((x0, y0), (x1, y1))``: the grid-line crossing at the
    cell's top-left and the one at its bottom-right. Boxes are ordered column
    by column (all cells of the leftmost column first, top to bottom).
    save_result depends on this order when it reshapes predictions to
    (label_x, label_y) and transposes.

    Side effects: sets the globals ``label_y`` (number of cell rows),
    ``label_x`` (number of cell columns) and ``intersection`` (every (x, y)
    grid-line crossing, x-major).
    """
    global label_y
    global label_x
    global intersection
    # Scan the image once. The original called horizon()/vertical() inside
    # its loops, re-scanning the whole image for every grid point.
    rows = horizon()
    cols = vertical()
    label_y = len(rows) - 1
    label_x = len(cols) - 1
    intersection = [(x, y) for x in cols for y in rows]
    # Pair each crossing with the one diagonally down-right of it. This
    # replaces the original float-modulo index arithmetic and its bare
    # `except: pass`, which could only hide errors.
    return [
        ((cols[c], rows[r]), (cols[c + 1], rows[r + 1]))
        for c in range(label_x)
        for r in range(label_y)
    ]
def segment1():
    """Cut the global ``img`` into one sub-image per table cell.

    Stores the cell images (in ``sections()`` order) in the global
    ``list_segmented_img`` and returns that list.
    """
    global list_segmented_img
    list_segmented_img = []
    # Compute the cell boxes once. The original called sections(), which
    # re-scans the whole image, once more for every cell.
    for (x0, y0), (x1, y1) in sections():
        list_segmented_img.append(img[int(y0):int(y1), int(x0):int(x1)])
    return list_segmented_img
def segment2(top_margin=20, size=28):
    """Crop every cell in ``list_segmented_img`` to a centred square and resize it.

    The square's side is the cell's shorter edge, centred along the longer
    edge. The top ``top_margin`` rows of the square are then discarded
    (presumably to drop the cell's top border line; TODO confirm). The result
    is resized to ``size`` x ``size`` pixels for the classifiers.

    Args:
        top_margin: rows trimmed from the top of each square (20, as before).
        size: output edge length in pixels (28, the MNIST size, as before).

    Returns:
        The resized crops. They are also stored in the global ``list_cropped_img``.
    """
    global list_cropped_img
    list_cropped_img = []
    for segmented_img in list_segmented_img:
        height, width = segmented_img.shape[:2]
        side = min(height, width)
        offset = abs((height - width) // 2)
        if width > height:
            # Wide cell: keep the centred side x side block of columns.
            square = segmented_img[:, offset:offset + side]
        else:
            # Tall (or square) cell: keep the centred block of rows.
            # Bug fix: the original applied this row offset to the columns.
            square = segmented_img[offset:offset + side, :]
        # Skip the trim when it would leave an empty image (cv2.resize fails
        # on empty input).
        if square.shape[0] > top_margin:
            square = square[top_margin:]
        list_cropped_img.append(cv2.resize(square, (size, size)))
    return list_cropped_img
def blacknwhite(imge):
    """Binarise a BGR image with Otsu's threshold, then invert it.

    After inversion, pixels that were darker than the Otsu threshold (the ink)
    become 255 and the background becomes 0. This is presumably meant to match
    MNIST's light-digit-on-dark layout.
    """
    grey = cv2.cvtColor(imge, cv2.COLOR_BGR2GRAY)
    # With THRESH_OTSU the threshold argument (255) is ignored and computed
    # automatically; the returned threshold value is not needed.
    _, binary = cv2.threshold(grey, 255, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    return cv2.bitwise_not(binary)
def blacknwhite_all():
    """Binarise every image in ``list_cropped_img``.

    Stores the results in the global ``list_cropped_refined_img`` and returns it.
    """
    global list_cropped_refined_img
    list_cropped_refined_img = []
    list_cropped_refined_img.extend(blacknwhite(cell) for cell in list_cropped_img)
    return list_cropped_refined_img
def save():
    """Write each binarised cell to ``<n>.jpg`` in the working directory.

    Files are numbered from 1 in list order.
    """
    for number, cell in enumerate(list_cropped_refined_img, start=1):
        cv2.imwrite(f"{number}.jpg", cell)
# Segment the table image, binarise every cell and write the cells to
# 1.jpg, 2.jpg, ... . sections() is called here only to set label_x /
# label_y for the message below; segment1() calls it again itself.
sections()
# Kept on one line: a replacement field split across lines inside a
# single-quoted f-string is a SyntaxError before Python 3.12 (PEP 701).
print(f'The size of the label is {label_x} * {label_y}={label_x*label_y}')
segment1()
segment2()
blacknwhite_all()
save()
#########################上述过程为对输入图像的裁剪处理,接下来将使用DNN方法识别各个图像中心的数字##############################
def train(epochs=15):
    """Train the fully-connected (DNN) digit classifier on MNIST.

    Stores the fitted model in the global ``model`` (used by ``predicting``)
    and prints its loss and accuracy on the MNIST test split.

    Args:
        epochs: number of training epochs (15, as before).
    """
    global model
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    # Scale pixels to [0, 1], the same way read_img prepares images at
    # prediction time. The original used tf.keras.utils.normalize(axis=1),
    # an L2 normalisation, so the network was trained on a different input
    # scale than the /255 images it was later asked to classify.
    x_train = x_train / 255.0
    x_test = x_test / 255.0
    # DNN: flattened 28x28 input -> 128 ReLU -> 128 ReLU -> 10-way softmax
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
    # Integer labels 0-9, hence the sparse variant of cross-entropy.
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.fit(x_train, y_train, epochs=epochs)
    val_loss, val_acc = model.evaluate(x_test, y_test)
    print('loss=', val_loss, 'acc=', val_acc)
def read_img(img_path):
    """Load one saved cell image as a model-ready batch.

    Reads ``img_path`` in grayscale, scales pixels to [0, 1] by dividing by
    255, and reshapes the result to ``(-1, 28, 28, 1)``. This is a reshape,
    not a flatten.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. cv2.imread returns
            None instead of raising, which used to surface as a confusing
            TypeError on the division.
    """
    pixels = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if pixels is None:
        raise FileNotFoundError(f'cannot read image: {img_path}')
    pixels = pixels / 255
    return pixels.reshape(-1, 28, 28, 1)
def predicting():
    """Classify every saved cell image with the DNN trained by ``train()``.

    Reads ``./1.jpg`` .. ``./<label_x*label_y>.jpg`` through ``read_img``.
    Runs them through the global ``model`` in one batch and keeps the arg-max
    digit of each image.

    Returns:
        The predicted digits. They are also stored in the global ``result1``.
    """
    global result1  # recognition results, one per cell
    count = int(label_x * label_y)
    result1 = []
    if count <= 0:
        return result1
    # One batched predict() call instead of one call per image. It should
    # give the same predictions with far less per-call overhead.
    batch = np.concatenate([read_img(f'./{i + 1}.jpg') for i in range(count)])
    result1 = list(np.argmax(model.predict(batch), axis=1))
    return result1
# DNN pipeline: train on MNIST (sets global `model`), then classify the
# label_x * label_y saved cell images into global `result1`.
train()
predicting()
##################为改善识别效果 使用CNN方法进行识别############################
import re
import tensorflow as tf
# from tensorflow.python.keras.backend import reshape
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, Conv2D, Flatten, MaxPooling2D, Dropout
import numpy as np
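# Build the CNN: Conv(10 filters, 5x5) -> MaxPool -> Conv(20 filters, 5x5) -> MaxPool -> Dropout(0.25) -> Dense(100) -> Dense(10, softmax)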
def predict_cnn(epochs=15, count=None):
    """Train a small CNN on MNIST and classify the saved cell images with it.

    Args:
        epochs: number of training epochs (15, as before).
        count: how many images ``./1.jpg`` .. ``./<count>.jpg`` to classify.
            Defaults to ``label_x * label_y`` (set by ``sections()``), which
            is every table cell, the same set ``predicting()`` classifies.
            The original hard-coded 10, so any table without exactly 10 cells
            produced a result2 that save_result could not reshape.

    Returns:
        The predicted digits, one per image. They are also stored in the
        global ``result2``.
    """
    global result2
    # Uses the public tf.keras API throughout. The file-level imports come
    # from the private tensorflow.python.keras package, which is not
    # guaranteed to work with tf.keras losses/datasets in current TensorFlow.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(10, (5, 5), activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Conv2D(20, (5, 5), activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(100, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(optimizer='rmsprop',
                  loss=tf.keras.losses.categorical_crossentropy,
                  metrics=['accuracy'])

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    # Scale to [0, 1] the same way read_img does. The original used
    # tf.keras.utils.normalize (an L2 normalisation), so training and
    # inference inputs were on different scales.
    x_train = (x_train / 255.0).reshape(-1, 28, 28, 1)
    x_test = (x_test / 255.0).reshape(-1, 28, 28, 1)
    # categorical_crossentropy expects one-hot targets.
    model.fit(x_train, tf.one_hot(y_train, 10), epochs=epochs,
              validation_data=(x_test, tf.one_hot(y_test, 10)))

    if count is None:
        count = int(label_x * label_y)
    result2 = []
    if count > 0:
        # Reuse the module-level read_img (the original redefined an identical
        # copy here) and predict all images in one batch.
        batch = np.concatenate([read_img(f'./{k + 1}.jpg') for k in range(count)])
        result2 = list(np.argmax(model.predict(batch), axis=1))
    return result2
# CNN pipeline: train a fresh CNN on MNIST and store its predictions in global `result2`.
predict_cnn()
###################以上是将各图片进行识别的过程,下面将所得结果输出为excel文件#########
def save_result(result, i):
    """Lay the predictions out like the original table and save ``./<i>.xlsx``.

    ``result`` is in sections() order (column by column). It is reshaped to
    (label_x, label_y) and transposed, so each spreadsheet row corresponds
    to one table row. The grid is printed before saving.
    """
    grid = np.array(result, dtype=str).reshape(label_x, label_y).T
    print(grid)
    book = Workbook()
    sheet = book.active  # the workbook's default sheet
    sheet.title = "默认sheet"
    for row in grid:
        sheet.append(row.tolist())
    book.save(f'./{i}.xlsx')
# Export both prediction sets: DNN (result1) -> ./1.xlsx, CNN (result2) -> ./2.xlsx.
save_result(result1,1)
save_result(result2,2)