在遥感影像的目标检测中,我们通常希望将检测结果与原始影像叠加显示,以便查看和分析。最简单的方法是将检测结果输出为 shapefile 格式,下面给出一种基于 Python 的转换方法:
import os
import gdal
import geopandas as gpd
import ogr
import osr
import rasterio.features
import shapely
def box_list2shp(det_file, img_file, out_shapefile):
    """
    Convert a list of quadrilateral detection boxes into an ESRI Shapefile.

    :param det_file: detection result txt file; each line describes one box as
        "x1 y1 x2 y2 x3 y3 x4 y4 label" in image (pixel) coordinates
    :param img_file: path of the original image; its CRS is attached to the
        output so the polygons overlay the image
    :param out_shapefile: path of the output shapefile
    """
    # NOTE: detection results held in memory can be converted instead with
    # get_box_from_array(result_array, classnames, img_file).
    bbox_data, label_data = get_box_from_txt(det_file, img_file)

    # Take the coordinate reference system from the source image.
    with rasterio.open(img_file) as raster:
        crs = raster.crs

    polygon_list = [shapely.geometry.Polygon(box) for box in bbox_data]

    # Labels are carried as the series index. Recent geopandas versions write a
    # non-integer index as an attribute column -- confirm for the installed version.
    out_data = gpd.GeoSeries(polygon_list, index=label_data, crs=crs)
    out_data.to_file(out_shapefile, driver='ESRI Shapefile', encoding='utf-8')
    print("successfully convert box-list to shapefile")
其中,get_box_from_txt 函数从 DOTA 格式的 txt 文件中读取检测框的四个角点坐标及对应标签(每行格式为 x1 y1 x2 y2 x3 y3 x4 y4 label),代码如下:
import gdal
def get_box_from_txt(txt_file, img_file):
    """
    Read quadrilateral boxes and their labels from a DOTA-style txt file.

    Each usable line holds at least nine whitespace-separated fields:
    ``x1 y1 x2 y2 x3 y3 x4 y4 label`` in image (pixel) coordinates. Extra
    trailing fields (e.g. DOTA's difficulty flag) are ignored. Lines with
    fewer than nine fields are skipped; this covers blank lines and DOTA
    header lines such as ``imagesource:...`` / ``gsd:...``.

    If the image has a projection, corner points are converted to map
    coordinates with the image geotransform. Otherwise they are returned as
    (x, -y), i.e. the pixel y axis is flipped so that rows grow downward.

    :param txt_file: path of the detection/label txt file (UTF-8)
    :param img_file: path of the original image, used for its georeferencing
    :return: (bbox_data, label_data) -- bbox_data is a list of boxes, each a
        list of four (x, y) tuples; label_data is the matching list of labels
    :raises OSError: if GDAL cannot open img_file
    """
    dataset = gdal.Open(img_file)
    if dataset is None:
        raise OSError("cannot open image file: {}".format(img_file))
    # GDAL returns an empty string (not None) when no projection is set,
    # so test truthiness rather than identity with None.
    has_projection = bool(dataset.GetProjection())

    bbox_data = []
    label_data = []
    with open(txt_file, 'r', encoding='utf-8') as f:
        for line in f:
            # split() with no argument tolerates repeated spaces and tabs.
            fields = line.split()
            if len(fields) < 9:
                continue
            coords = [float(v) for v in fields[:8]]
            points = list(zip(coords[0::2], coords[1::2]))  # [(x1, y1), ..., (x4, y4)]
            if has_projection:
                # image x is the column and image y is the row
                box = [imagexy2geo(dataset, y, x) for x, y in points]
            else:
                box = [(x, -y) for x, y in points]
            bbox_data.append(box)
            label_data.append(fields[8])
    return bbox_data, label_data
get_box_from_array 函数从内存中的检测结果数组读取检测框及对应标签(数组按类别组织,每个检测框为四个角点坐标加置信度共 9 个值),代码如下:
def get_box_from_array(result_array, classnames, img_file):
    """
    Read quadrilateral boxes and their labels from in-memory detection results.

    ``result_array`` is indexed as ``result_array[class_idx][obj_idx]``. Each
    object entry holds at least eight values, x1 y1 x2 y2 x3 y3 x4 y4, in
    image (pixel) coordinates. Any further value (presumably the confidence
    score) is ignored.

    If the image has a projection, corner points are converted to map
    coordinates with the image geotransform. Otherwise they are returned as
    (x, -y), i.e. the pixel y axis is flipped so that rows grow downward.

    :param result_array: per-class detection results, shaped
        [class_num][obj_num][9]
    :param classnames: class names; classnames[i] labels result_array[i]
    :param img_file: path of the original image, used for its georeferencing
    :return: (bbox_data, label_data) -- bbox_data is a list of boxes, each a
        list of four (x, y) tuples; label_data is the matching list of labels
    :raises OSError: if GDAL cannot open img_file
    """
    dataset = gdal.Open(img_file)
    if dataset is None:
        raise OSError("cannot open image file: {}".format(img_file))
    # GDAL returns an empty string (not None) when no projection is set,
    # so test truthiness rather than identity with None.
    has_projection = bool(dataset.GetProjection())

    bbox_data = []
    label_data = []
    for idx, class_result in enumerate(result_array):
        for result in class_result:
            coords = [float(v) for v in result[:8]]
            points = list(zip(coords[0::2], coords[1::2]))  # [(x1, y1), ..., (x4, y4)]
            if has_projection:
                # image x is the column and image y is the row
                box = [imagexy2geo(dataset, y, x) for x, y in points]
            else:
                box = [(x, -y) for x, y in points]
            bbox_data.append(box)
            label_data.append(classnames[idx])
    return bbox_data, label_data
将图像坐标(行列号)转换为地理坐标的 imagexy2geo 函数代码如下:
import gdal
def imagexy2geo(dataset, row, col):
    """
    Map an image position (row, col) to projected or geographic coordinates.

    Uses GDAL's six-parameter affine geotransform. The result is expressed in
    the dataset's own coordinate system. Integer row/col values address the
    top-left corner of a pixel.

    :param dataset: an opened GDAL dataset, e.g. gdal.Open("xxx.tif")
    :param row: pixel row index (image y)
    :param col: pixel column index (image x)
    :return: the (x, y) coordinates corresponding to (row, col)
    """
    x_origin, x_res, x_skew, y_origin, y_skew, y_res = dataset.GetGeoTransform()
    geo_x = x_origin + col * x_res + row * x_skew
    geo_y = y_origin + col * y_skew + row * y_res
    return geo_x, geo_y
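注:使用前请在文件开头导入相关包(gdal、rasterio、geopandas、shapely 等)。新版 GDAL 中旧式的 import gdal 已被弃用,应改为 from osgeo import gdal, ogr, osr。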