python gdal完成arcgis分区统计功能(zonal)

参考链接:https://towardsdatascience.com/zonal-statistics-algorithm-with-python-in-4-steps-382a3b66648a
注意事项,栅格需要和矢量的坐标系保持一致,结果存在了csv文件中
结果
代码我稍微改了下,后续还会和矢量关联

import gdal
import ogr
import os
import numpy as np
import csv
import time

def boundingBoxToOffsets(bbox, geot):
    """Map an OGR envelope (xmin, xmax, ymin, ymax) onto raster pixel
    offsets [row_min, row_max, col_min, col_max] via the geotransform.

    geot is a GDAL geotransform: (x_origin, pixel_w, 0, y_origin, 0, pixel_h)
    where pixel_h is negative for north-up rasters.
    """
    x_origin, pixel_w = geot[0], geot[1]
    y_origin, pixel_h = geot[3], geot[5]
    first_col = int((bbox[0] - x_origin) / pixel_w)
    last_col = int((bbox[1] - x_origin) / pixel_w) + 1
    # ymax maps to the first (topmost) row because pixel_h < 0.
    first_row = int((bbox[3] - y_origin) / pixel_h)
    last_row = int((bbox[2] - y_origin) / pixel_h) + 1
    return [first_row, last_row, first_col, last_col]

def geotFromOffsets(row_offset, col_offset, geot):
    """Build the geotransform of a raster window whose top-left corner is
    shifted by (row_offset, col_offset) pixels inside the source raster."""
    shifted_x = geot[0] + col_offset * geot[1]
    shifted_y = geot[3] + row_offset * geot[5]
    # Rotation terms are fixed at 0.0 (axis-aligned raster assumed).
    return [shifted_x, geot[1], 0.0, shifted_y, 0.0, geot[5]]

def setFeatureStats(fid, min, max, mean, median, sd, sum, count, names=["min", "max", "mean", "median", "sd", "sum", "count", "id"]):
    """Bundle one feature's statistics into a dict keyed by *names*.

    NOTE(review): several parameters shadow builtins (min, max, sum) and the
    default *names* list is mutable; the signature is kept unchanged so
    existing positional callers keep working.
    """
    values = (min, max, mean, median, sd, sum, count, fid)
    return dict(zip(names, values))

def zonal(fn_raster, fn_zones, fn_csv):
    """Compute per-polygon zonal statistics of band 1 of *fn_raster* over the
    features of *fn_zones* and write one CSV row per feature to *fn_csv*.

    Assumes raster and vector share the same CRS (not verified here).
    Columns: min, max, mean, median, sd, sum, count, id (feature FID).
    """
    mem_driver = ogr.GetDriverByName("Memory")
    mem_driver_gdal = gdal.GetDriverByName("MEM")
    shp_name = "temp"

    r_ds = gdal.Open(fn_raster)
    p_ds = ogr.Open(fn_zones)

    lyr = p_ds.GetLayer()
    geot = r_ds.GetGeoTransform()
    nodata = r_ds.GetRasterBand(1).GetNoDataValue()

    zstats = []

    p_feat = lyr.GetNextFeature()

    while p_feat:
        if p_feat.GetGeometryRef() is not None:
            if os.path.exists(shp_name):
                mem_driver.DeleteDataSource(shp_name)
            # Rasterize this single polygon into an in-memory byte mask
            # covering just its bounding box.
            tp_ds = mem_driver.CreateDataSource(shp_name)
            tp_lyr = tp_ds.CreateLayer('polygons', None, ogr.wkbPolygon)
            tp_lyr.CreateFeature(p_feat.Clone())
            offsets = boundingBoxToOffsets(
                p_feat.GetGeometryRef().GetEnvelope(), geot)
            new_geot = geotFromOffsets(offsets[0], offsets[2], geot)

            tr_ds = mem_driver_gdal.Create(
                "",
                offsets[3] - offsets[2],
                offsets[1] - offsets[0],
                1,
                gdal.GDT_Byte)
            tr_ds.SetGeoTransform(new_geot)
            gdal.RasterizeLayer(tr_ds, [1], tp_lyr, burn_values=[1])
            tr_array = tr_ds.ReadAsArray()

            # Read the matching raster window; None if the window falls
            # outside the raster extent.
            r_array = r_ds.GetRasterBand(1).ReadAsArray(
                offsets[2],
                offsets[0],
                offsets[3] - offsets[2],
                offsets[1] - offsets[0])

            fid = p_feat.GetFID()

            if r_array is not None:
                # BUG FIX: original passed the mask through a nonexistent
                # 'maskarray=' keyword (TypeError); the MaskedArray keyword
                # is 'mask='. Mask out nodata cells and cells outside the
                # rasterized polygon.
                maskarray = np.ma.MaskedArray(
                    r_array,
                    mask=np.logical_or(r_array == nodata,
                                       np.logical_not(tr_array)))

                # BUG FIX: original tested 'maskarray is not None', which is
                # always True; guard on the number of unmasked cells instead
                # so fully-masked zones fall back to nodata values.
                if maskarray.count() > 0:
                    zstats.append(setFeatureStats(
                        fid,
                        maskarray.min(),
                        maskarray.max(),
                        maskarray.mean(),
                        np.ma.median(maskarray),
                        maskarray.std(),
                        maskarray.sum(),
                        maskarray.count()))
                else:
                    zstats.append(setFeatureStats(
                        fid, nodata, nodata, nodata, nodata,
                        nodata, nodata, nodata))
            else:
                zstats.append(setFeatureStats(
                    fid, nodata, nodata, nodata, nodata,
                    nodata, nodata, nodata))

            # Release the per-feature in-memory datasets.
            tp_ds = None
            tp_lyr = None
            tr_ds = None

        # BUG FIX: advance to the next feature unconditionally; the original
        # only advanced inside the geometry check, so a single NULL-geometry
        # feature caused an infinite loop.
        p_feat = lyr.GetNextFeature()

    # Guard against an empty layer (original crashed on zstats[0]).
    if not zstats:
        return

    col_names = zstats[0].keys()
    with open(fn_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, col_names)
        writer.writeheader()
        writer.writerows(zstats)

if __name__ == "__main__":
    # Time the zonal run; elapsed time is reported in hours.
    start = time.time()
    fn_raster = './data/t_dem.tif'
    fn_zones = './data/grid1.shp'
    fn_csv = './data/zonsl.csv'
    zonal(fn_raster, fn_zones, fn_csv)
    elapsed_hours = (time.time() - start) / 3600.0
    print(elapsed_hours)

如果只得到了csv而没有把值填进shp属性表那么上面操作的意义将大打折扣,下面是我完成以后的样子
对比图
统计结果
下面是更新后的代码:

import gdal
import ogr
import os
import numpy as np
import csv
import pandas as pd
import time

def boundingBoxToOffsets(bbox, geot):
    """Map an OGR envelope (xmin, xmax, ymin, ymax) onto raster pixel
    offsets [row_min, row_max, col_min, col_max] via the geotransform.

    geot is a GDAL geotransform: (x_origin, pixel_w, 0, y_origin, 0, pixel_h)
    where pixel_h is negative for north-up rasters.
    """
    x_origin, pixel_w = geot[0], geot[1]
    y_origin, pixel_h = geot[3], geot[5]
    first_col = int((bbox[0] - x_origin) / pixel_w)
    last_col = int((bbox[1] - x_origin) / pixel_w) + 1
    # ymax maps to the first (topmost) row because pixel_h < 0.
    first_row = int((bbox[3] - y_origin) / pixel_h)
    last_row = int((bbox[2] - y_origin) / pixel_h) + 1
    return [first_row, last_row, first_col, last_col]

def geotFromOffsets(row_offset, col_offset, geot):
    """Build the geotransform of a raster window whose top-left corner is
    shifted by (row_offset, col_offset) pixels inside the source raster."""
    shifted_x = geot[0] + col_offset * geot[1]
    shifted_y = geot[3] + row_offset * geot[5]
    # Rotation terms are fixed at 0.0 (axis-aligned raster assumed).
    return [shifted_x, geot[1], 0.0, shifted_y, 0.0, geot[5]]

def setFeatureStats(fid, min, max, mean, median, sd, sum, count, names=["min", "max", "mean", "median", "sd", "sum", "count", "id"]):
    """Bundle one feature's statistics into a dict keyed by *names*.

    NOTE(review): several parameters shadow builtins (min, max, sum) and the
    default *names* list is mutable; the signature is kept unchanged so
    existing positional callers keep working.
    """
    values = (min, max, mean, median, sd, sum, count, fid)
    return dict(zip(names, values))

def zonal(fn_raster, fn_zones, fn_csv):
    """Compute per-polygon zonal statistics of band 1 of *fn_raster* over the
    features of *fn_zones* and write one CSV row per feature to *fn_csv*.

    Assumes raster and vector share the same CRS (not verified here).
    Columns: min, max, mean, median, sd, sum, count, id (feature FID).
    """
    mem_driver = ogr.GetDriverByName("Memory")
    mem_driver_gdal = gdal.GetDriverByName("MEM")
    shp_name = "temp"

    r_ds = gdal.Open(fn_raster)
    p_ds = ogr.Open(fn_zones)

    lyr = p_ds.GetLayer()
    geot = r_ds.GetGeoTransform()
    nodata = r_ds.GetRasterBand(1).GetNoDataValue()

    zstats = []

    p_feat = lyr.GetNextFeature()

    while p_feat:
        if p_feat.GetGeometryRef() is not None:
            if os.path.exists(shp_name):
                mem_driver.DeleteDataSource(shp_name)
            # Rasterize this single polygon into an in-memory byte mask
            # covering just its bounding box.
            tp_ds = mem_driver.CreateDataSource(shp_name)
            tp_lyr = tp_ds.CreateLayer('polygons', None, ogr.wkbPolygon)
            tp_lyr.CreateFeature(p_feat.Clone())
            offsets = boundingBoxToOffsets(
                p_feat.GetGeometryRef().GetEnvelope(), geot)
            new_geot = geotFromOffsets(offsets[0], offsets[2], geot)

            tr_ds = mem_driver_gdal.Create(
                "",
                offsets[3] - offsets[2],
                offsets[1] - offsets[0],
                1,
                gdal.GDT_Byte)
            tr_ds.SetGeoTransform(new_geot)
            gdal.RasterizeLayer(tr_ds, [1], tp_lyr, burn_values=[1])
            tr_array = tr_ds.ReadAsArray()

            # Read the matching raster window; None if the window falls
            # outside the raster extent.
            r_array = r_ds.GetRasterBand(1).ReadAsArray(
                offsets[2],
                offsets[0],
                offsets[3] - offsets[2],
                offsets[1] - offsets[0])

            fid = p_feat.GetFID()

            if r_array is not None:
                # BUG FIX: original passed the mask through a nonexistent
                # 'maskarray=' keyword (TypeError); the MaskedArray keyword
                # is 'mask='. Mask out nodata cells and cells outside the
                # rasterized polygon.
                maskarray = np.ma.MaskedArray(
                    r_array,
                    mask=np.logical_or(r_array == nodata,
                                       np.logical_not(tr_array)))

                # BUG FIX: original tested 'maskarray is not None', which is
                # always True; guard on the number of unmasked cells instead
                # so fully-masked zones fall back to nodata values.
                if maskarray.count() > 0:
                    zstats.append(setFeatureStats(
                        fid,
                        maskarray.min(),
                        maskarray.max(),
                        maskarray.mean(),
                        np.ma.median(maskarray),
                        maskarray.std(),
                        maskarray.sum(),
                        maskarray.count()))
                else:
                    zstats.append(setFeatureStats(
                        fid, nodata, nodata, nodata, nodata,
                        nodata, nodata, nodata))
            else:
                zstats.append(setFeatureStats(
                    fid, nodata, nodata, nodata, nodata,
                    nodata, nodata, nodata))

            # Release the per-feature in-memory datasets.
            tp_ds = None
            tp_lyr = None
            tr_ds = None

        # BUG FIX: advance to the next feature unconditionally; the original
        # only advanced inside the geometry check, so a single NULL-geometry
        # feature caused an infinite loop.
        p_feat = lyr.GetNextFeature()

    # Guard against an empty layer (original crashed on zstats[0]).
    if not zstats:
        return

    col_names = zstats[0].keys()
    with open(fn_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, col_names)
        writer.writeheader()
        writer.writerows(zstats)

def shp_field_value(csv_file, shp):
    """Write the zonal-statistics CSV values into the shapefile's attribute
    table, creating one OFTReal field per statistic.

    Assumes CSV row order matches the layer's feature iteration order (they
    were both produced by zonal() in FID order) — TODO: join on the 'id'
    column instead to make this robust.
    """
    stat_fields = ('min', 'max', 'mean', 'median', 'sd', 'sum', 'count')

    # pd.read_csv already returns a DataFrame; the original wrapped it in a
    # redundant pd.DataFrame(...) call.
    data = pd.read_csv(csv_file)
    driver = ogr.GetDriverByName('ESRI Shapefile')
    layer_source = driver.Open(shp, 1)  # 1 = open for update
    lyr = layer_source.GetLayer()

    for field_name in stat_fields:
        lyr.CreateField(ogr.FieldDefn(field_name, ogr.OFTReal))

    row = 0
    feature = lyr.GetNextFeature()
    while feature is not None:
        # BUG FIX: the original nested these SetField calls inside
        # 'for i in range(GetFieldCount())', re-setting the same seven
        # fields once per field in the layer definition; set them once.
        for field_name in stat_fields:
            feature.SetField(field_name, float(data[field_name][row]))
        lyr.SetFeature(feature)
        row += 1
        feature = lyr.GetNextFeature()

    # BUG FIX: release the dataset handle so the edits are flushed to disk.
    layer_source = None

if __name__ == "__main__":
    # Run the zonal statistics and write them back to the shapefile,
    # reporting elapsed time in hours.
    start = time.time()
    fn_raster = './data/t_dem.tif'
    fn_zones = './data/grid1.shp'
    fn_csv = './data/zonsl.csv'
    zonal(fn_raster, fn_zones, fn_csv)
    shp_field_value(fn_csv, fn_zones)
    elapsed_hours = (time.time() - start) / 3600.0
    print(elapsed_hours)

用时:0.00384633739789327 (h)

下面和arcgis的分区统计做下比较

import os
import time
import arcpy
from arcpy import env

def zonal(raster, shp):
    """ArcGIS equivalent of the GDAL implementation above: compute zonal
    statistics of *raster* per zone of *shp* and join them onto *shp*'s
    attribute table (keyed on FID)."""
    out_table = "zonalstat"
    # Writes one statistics row per zone into out_table (in env.workspace).
    arcpy.gp.ZonalStatisticsAsTable_sa(shp, "FID", raster, out_table, "NODATA", "ALL")
    # Attach the statistics columns back onto the shapefile.
    arcpy.JoinField_management(shp, "FID", out_table, "FID")

if __name__ == "__main__":
    # Benchmark the ArcGIS zonal-statistics run; elapsed time in hours.
    start = time.time()
    dem_ras = './data/t_dem.tif'
    shp = './data/grid.shp'
    temp_path = './data/temp/'

    # Spatial Analyst license is required for ZonalStatisticsAsTable.
    arcpy.CheckOutExtension('Spatial')
    arcpy.env.overwriteOutput = True
    env.workspace = temp_path

    zonal(dem_ras, shp)
    elapsed_hours = (time.time() - start) / 3600.0
    print(elapsed_hours)

用时:0.196481666631 (h)
下面是arcgis统计以后的属性表,很明显arcgis多了几个字段包括AREA,MAJORITY,MINORITY等
arcgis
arcgis统计的字段比gdal实现的多,但是我感觉时间上还是有点太长了,就算gdal加上那些arcgis多出来的字段应该也不会这么长时间,感兴趣的朋友可以尝试一下。另外注意一下,两种方式得到的字段属性值也存在一定差异,但是大体上不会差很多,下面是mean值对应的视觉显示情况:
1.gdal
gdal

2.arcgis
arcgis
仔细观察下,可以发现gdal实现的效果还是有点瑕疵的,比如右下角的那几个异常值

测试数据链接:https://download.csdn.net/download/qq_20373723/13716488

猜你喜欢

转载自blog.csdn.net/qq_20373723/article/details/111350741