参考链接:https://towardsdatascience.com/zonal-statistics-algorithm-with-python-in-4-steps-382a3b66648a
注意事项,栅格需要和矢量的坐标系保持一致,结果存在了csv文件中
代码我稍微改了下,后续还会和矢量关联
import gdal
import ogr
import os
import numpy as np
import csv
import time
def boundingBoxToOffsets(bbox, geot):
    """Convert an OGR envelope to pixel offsets within a raster.

    bbox: OGR envelope (min_x, max_x, min_y, max_y) from GetEnvelope().
    geot: GDAL geotransform (origin_x, pixel_w, 0, origin_y, 0, pixel_h),
          pixel_h typically negative for north-up rasters.
    Returns [row_start, row_stop, col_start, col_stop].
    """
    origin_x, pixel_w = geot[0], geot[1]
    origin_y, pixel_h = geot[3], geot[5]
    min_x, max_x, min_y, max_y = bbox
    first_col = int((min_x - origin_x) / pixel_w)
    last_col = int((max_x - origin_x) / pixel_w) + 1
    # y decreases with row index when pixel_h < 0, so max_y maps to the top row
    first_row = int((max_y - origin_y) / pixel_h)
    last_row = int((min_y - origin_y) / pixel_h) + 1
    return [first_row, last_row, first_col, last_col]
def geotFromOffsets(row_offset, col_offset, geot):
    """Derive the geotransform of a sub-window located at the given
    row/column pixel offsets inside the raster described by geot."""
    shifted_x = geot[0] + col_offset * geot[1]
    shifted_y = geot[3] + row_offset * geot[5]
    return [shifted_x, geot[1], 0.0, shifted_y, 0.0, geot[5]]
def setFeatureStats(fid, min, max, mean, median, sd, sum, count,
                    names=("min", "max", "mean", "median", "sd", "sum", "count", "id")):
    """Pack one feature's zonal statistics into a dict keyed by *names*.

    fid: feature id (stored under names[7]).
    min..count: the statistic values (parameter names shadow builtins;
        kept unchanged for backward compatibility with keyword callers).
    names: output dict keys, in the order (min, max, mean, median, sd,
        sum, count, id).

    BUG FIX: the default was a mutable list shared across calls; a tuple
    removes the mutable-default-argument hazard without changing behavior.
    """
    return dict(zip(names, (min, max, mean, median, sd, sum, count, fid)))
def zonal(fn_raster, fn_zones, fn_csv):
    """Compute per-polygon zonal statistics of a raster and write them to CSV.

    fn_raster: path to the raster (band 1 is used); must share the CRS of
        the zones vector.
    fn_zones:  path to a polygon vector dataset readable by OGR.
    fn_csv:    output CSV path; one row per feature with min/max/mean/
        median/sd/sum/count plus the feature id.
    """
    mem_driver = ogr.GetDriverByName("Memory")
    mem_driver_gdal = gdal.GetDriverByName("MEM")
    shp_name = "temp"
    r_ds = gdal.Open(fn_raster)
    p_ds = ogr.Open(fn_zones)
    lyr = p_ds.GetLayer()
    geot = r_ds.GetGeoTransform()
    # NOTE(review): may be None if the raster declares no nodata — the
    # r_array == nodata comparison below then masks nothing; confirm inputs.
    nodata = r_ds.GetRasterBand(1).GetNoDataValue()
    zstats = []
    p_feat = lyr.GetNextFeature()
    while p_feat:
        if p_feat.GetGeometryRef() is not None:
            if os.path.exists(shp_name):
                mem_driver.DeleteDataSource(shp_name)
            # Rasterize this single feature into a small in-memory byte grid
            # covering just its bounding box (1 inside polygon, 0 outside).
            tp_ds = mem_driver.CreateDataSource(shp_name)
            tp_lyr = tp_ds.CreateLayer('polygons', None, ogr.wkbPolygon)
            tp_lyr.CreateFeature(p_feat.Clone())
            offsets = boundingBoxToOffsets(p_feat.GetGeometryRef().GetEnvelope(),
                                           geot)
            new_geot = geotFromOffsets(offsets[0], offsets[2], geot)
            tr_ds = mem_driver_gdal.Create(
                "",
                offsets[3] - offsets[2],
                offsets[1] - offsets[0],
                1,
                gdal.GDT_Byte)
            tr_ds.SetGeoTransform(new_geot)
            gdal.RasterizeLayer(tr_ds, [1], tp_lyr, burn_values=[1])
            tr_array = tr_ds.ReadAsArray()
            # Read only the raster window covering this feature's bbox;
            # ReadAsArray returns None if the window lies outside the raster.
            r_array = r_ds.GetRasterBand(1).ReadAsArray(
                offsets[2],
                offsets[0],
                offsets[3] - offsets[2],
                offsets[1] - offsets[0])
            fid = p_feat.GetFID()
            if r_array is not None:
                # BUG FIX: the keyword is `mask`, not `maskarray` —
                # np.ma.MaskedArray rejects unknown keywords with TypeError.
                # Mask pixels that are nodata OR outside the polygon.
                maskarray = np.ma.MaskedArray(
                    r_array,
                    mask=np.logical_or(r_array == nodata,
                                       np.logical_not(tr_array)))
                # (The old `if maskarray is not None` branch was dead code:
                # the constructor never returns None.)
                zstats.append(setFeatureStats(
                    fid,
                    maskarray.min(),
                    maskarray.max(),
                    maskarray.mean(),
                    np.ma.median(maskarray),
                    maskarray.std(),
                    maskarray.sum(),
                    maskarray.count()))
            else:
                zstats.append(setFeatureStats(
                    fid, nodata, nodata, nodata, nodata,
                    nodata, nodata, nodata))
        # Release the per-feature in-memory datasets before the next feature.
        tp_ds = None
        tp_lyr = None
        tr_ds = None
        p_feat = lyr.GetNextFeature()
    if not zstats:
        # No features with geometry: nothing to write (avoids IndexError).
        return
    col_names = zstats[0].keys()
    with open(fn_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, col_names)
        writer.writeheader()
        writer.writerows(zstats)
if __name__ == "__main__":
    # Time the zonal-statistics run and report the duration in hours.
    start = time.time()
    fn_raster = './data/t_dem.tif'
    fn_zones = './data/grid1.shp'
    fn_csv = './data/zonsl.csv'
    zonal(fn_raster, fn_zones, fn_csv)
    elapsed_hours = (time.time() - start) / 3600.0
    print(elapsed_hours)
如果只得到了csv而没有把值填进shp属性表那么上面操作的意义将大打折扣,下面是我完成以后的样子
下面是更新后的代码:
import gdal
import ogr
import os
import numpy as np
import csv
import pandas as pd
import time
def boundingBoxToOffsets(bbox, geot):
    """Map an OGR envelope (min_x, max_x, min_y, max_y) onto the pixel
    offsets of the raster described by GDAL geotransform *geot*.

    Returns [row_start, row_stop, col_start, col_stop], where the stop
    indices extend one pixel past the envelope edge.
    """
    x0, px_w, _, y0, _, px_h = geot
    start_col = int((bbox[0] - x0) / px_w)
    stop_col = int((bbox[1] - x0) / px_w) + 1
    # max_y maps to the smaller row index for north-up rasters (px_h < 0)
    start_row = int((bbox[3] - y0) / px_h)
    stop_row = int((bbox[2] - y0) / px_h) + 1
    return [start_row, stop_row, start_col, stop_col]
def geotFromOffsets(row_offset, col_offset, geot):
    """Return the geotransform of the sub-window that starts at the given
    row/column offsets within the raster described by *geot*."""
    return [
        geot[0] + col_offset * geot[1],  # shifted upper-left x
        geot[1],
        0.0,
        geot[3] + row_offset * geot[5],  # shifted upper-left y
        0.0,
        geot[5],
    ]
def setFeatureStats(fid, min, max, mean, median, sd, sum, count,
                    names=("min", "max", "mean", "median", "sd", "sum", "count", "id")):
    """Bundle one feature's zonal statistics into a dict keyed by *names*.

    fid: feature id (stored under names[7]).
    min..count: statistic values (parameter names shadow builtins; kept
        unchanged so existing keyword callers keep working).
    names: output dict keys, ordered (min, max, mean, median, sd, sum,
        count, id).

    BUG FIX: the default was a mutable list shared across calls; using a
    tuple removes the mutable-default-argument hazard, same behavior.
    """
    values = (min, max, mean, median, sd, sum, count, fid)
    return dict(zip(names, values))
def zonal(fn_raster, fn_zones, fn_csv):
    """Compute per-polygon zonal statistics of a raster and write them to CSV.

    fn_raster: path to the raster (band 1 is used); must share the CRS of
        the zones vector.
    fn_zones:  path to a polygon vector dataset readable by OGR.
    fn_csv:    output CSV path; one row per feature with min/max/mean/
        median/sd/sum/count plus the feature id.
    """
    mem_driver = ogr.GetDriverByName("Memory")
    mem_driver_gdal = gdal.GetDriverByName("MEM")
    shp_name = "temp"
    r_ds = gdal.Open(fn_raster)
    p_ds = ogr.Open(fn_zones)
    lyr = p_ds.GetLayer()
    geot = r_ds.GetGeoTransform()
    # NOTE(review): may be None if the raster declares no nodata — the
    # r_array == nodata comparison below then masks nothing; confirm inputs.
    nodata = r_ds.GetRasterBand(1).GetNoDataValue()
    zstats = []
    p_feat = lyr.GetNextFeature()
    while p_feat:
        if p_feat.GetGeometryRef() is not None:
            if os.path.exists(shp_name):
                mem_driver.DeleteDataSource(shp_name)
            # Rasterize this single feature into a small in-memory byte grid
            # covering just its bounding box (1 inside polygon, 0 outside).
            tp_ds = mem_driver.CreateDataSource(shp_name)
            tp_lyr = tp_ds.CreateLayer('polygons', None, ogr.wkbPolygon)
            tp_lyr.CreateFeature(p_feat.Clone())
            offsets = boundingBoxToOffsets(p_feat.GetGeometryRef().GetEnvelope(),
                                           geot)
            new_geot = geotFromOffsets(offsets[0], offsets[2], geot)
            tr_ds = mem_driver_gdal.Create(
                "",
                offsets[3] - offsets[2],
                offsets[1] - offsets[0],
                1,
                gdal.GDT_Byte)
            tr_ds.SetGeoTransform(new_geot)
            gdal.RasterizeLayer(tr_ds, [1], tp_lyr, burn_values=[1])
            tr_array = tr_ds.ReadAsArray()
            # Read only the raster window covering this feature's bbox;
            # ReadAsArray returns None if the window lies outside the raster.
            r_array = r_ds.GetRasterBand(1).ReadAsArray(
                offsets[2],
                offsets[0],
                offsets[3] - offsets[2],
                offsets[1] - offsets[0])
            fid = p_feat.GetFID()
            if r_array is not None:
                # BUG FIX: the keyword is `mask`, not `maskarray` —
                # np.ma.MaskedArray rejects unknown keywords with TypeError.
                # Mask pixels that are nodata OR outside the polygon.
                maskarray = np.ma.MaskedArray(
                    r_array,
                    mask=np.logical_or(r_array == nodata,
                                       np.logical_not(tr_array)))
                # (The old `if maskarray is not None` branch was dead code:
                # the constructor never returns None.)
                zstats.append(setFeatureStats(
                    fid,
                    maskarray.min(),
                    maskarray.max(),
                    maskarray.mean(),
                    np.ma.median(maskarray),
                    maskarray.std(),
                    maskarray.sum(),
                    maskarray.count()))
            else:
                zstats.append(setFeatureStats(
                    fid, nodata, nodata, nodata, nodata,
                    nodata, nodata, nodata))
        # Release the per-feature in-memory datasets before the next feature.
        tp_ds = None
        tp_lyr = None
        tr_ds = None
        p_feat = lyr.GetNextFeature()
    if not zstats:
        # No features with geometry: nothing to write (avoids IndexError).
        return
    col_names = zstats[0].keys()
    with open(fn_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, col_names)
        writer.writeheader()
        writer.writerows(zstats)
def shp_field_value(csv_file, shp):
    """Copy zonal statistics from *csv_file* into new fields of *shp*.

    csv_file: CSV written by zonal() with columns min/max/mean/median/
        sd/sum/count (the 'id' column is not used here).
    shp: path to the shapefile to update in place.

    NOTE(review): rows are matched to features by iteration order, not by
    the CSV 'id' column — assumes zonal() wrote rows in layer order; verify.
    """
    stat_fields = ('min', 'max', 'mean', 'median', 'sd', 'sum', 'count')
    data = pd.read_csv(csv_file)  # read_csv already returns a DataFrame
    driver = ogr.GetDriverByName('ESRI Shapefile')
    layer_source = driver.Open(shp, 1)  # 1 = open for update
    lyr = layer_source.GetLayer()
    # Create one OFTReal field per statistic (was seven copy-pasted blocks).
    for fname in stat_fields:
        lyr.CreateField(ogr.FieldDefn(fname, ogr.OFTReal))
    row = 0
    feature = lyr.GetNextFeature()
    while feature is not None:
        # BUG FIX: the old code wrapped these SetField calls in
        # `for i in range(field_count)`, redundantly re-setting the same
        # seven values once per attribute field of the layer.
        for fname in stat_fields:
            feature.SetField(fname, data[fname][row])
        lyr.SetFeature(feature)
        row += 1
        feature = lyr.GetNextFeature()
    # BUG FIX: release the data source so the edits are flushed to disk.
    layer_source = None
if __name__ == "__main__":
    # Run the zonal statistics, write them back into the shapefile, and
    # report the total duration in hours.
    start = time.time()
    fn_raster = './data/t_dem.tif'
    fn_zones = './data/grid1.shp'
    fn_csv = './data/zonsl.csv'
    zonal(fn_raster, fn_zones, fn_csv)
    shp_field_value(fn_csv, fn_zones)
    elapsed_hours = (time.time() - start) / 3600.0
    print(elapsed_hours)
用时:0.00384633739789327 (h)
下面和arcgis的分区统计做下比较
import os
import time
import arcpy
from arcpy import env
def zonal(raster, shp):
    """ArcGIS equivalent: compute per-feature zonal statistics of *raster*
    into a table, then join that table back onto *shp*'s attribute table
    via the FID field."""
    stats_table = "zonalstat"
    arcpy.gp.ZonalStatisticsAsTable_sa(shp, "FID", raster, stats_table, "NODATA", "ALL")
    arcpy.JoinField_management(shp, "FID", stats_table, "FID")
if __name__ == "__main__":
    # Configure the arcpy environment, run the zonal statistics, and
    # report the duration in hours for comparison with the GDAL version.
    start = time.time()
    dem_ras = './data/t_dem.tif'
    shp = './data/grid.shp'
    temp_path = './data/temp/'
    arcpy.CheckOutExtension('Spatial')  # Spatial Analyst license required
    arcpy.env.overwriteOutput = True
    env.workspace = temp_path
    zonal(dem_ras, shp)
    print((time.time() - start) / 3600.0)
用时:0.196481666631 (h)
下面是arcgis统计以后的属性表,很明显arcgis多了几个字段包括AREA,MAJORITY,MINORITY等
arcgis统计的字段比gdal实现的多,但是我感觉时间上还是有点太长了,就算gdal加上那些arcgis多出来的字段应该也不会这么长时间,感兴趣的朋友可以尝试一下。另外注意一下,两种方式得到的字段属性值也存在一定差异,但是大体上不会差很多,下面是mean值对应的视觉显示情况:
1.gdal
2.arcgis
仔细观察下,可以发现gdal实现的效果还是有点瑕疵的,比如右下角的那几个异常值
测试数据链接:https://download.csdn.net/download/qq_20373723/13716488