LightGBM用于遥感影像多分类

LightGBM是一个框架,内置了:
1.gbdt, 传统的梯度提升决策树;
2.rf, Random Forest (随机森林);
3.dart, Dropouts meet Multiple Additive Regression Trees;
4.goss, Gradient-based One-Side Sampling (基于梯度的单侧采样)。
这些方法的选择由 参数boosting_type来控制。
我之前的博客记录了随机森林和支持向量机的分类方法,这些都是scikit-learn机器学习包里的机器学习方法。同样的方法和LightGBM相比,缺点是无法GPU加速,这导致处理比较大的遥感影像比较慢,不便用于项目中。LightGBM框架虽然不包含SVM,但是里面包含的方法效果不比SVM差。
这里面的参数太多了,给大家几个链接好好研究吧,我也还在得到想要的结果的路上,建议有紧急任务的朋友先别入坑了,调参太浪费时间了,我也还没得到想要的结果
下面是示例:
原图
结果:
分类结果
说明:其实我分了4类,但只得到了3类,原因是有两个类很像,毕竟是像素级分类,仔细观察发现水体和林地确实太像了,任何一种像素级的分类方法要分开都比较难,但是比较重要的一点是分类结果的杂质比其他方法确实少很多,对于需要转到矢量的情况比较有利。另外提一嘴,这个方法配上我前面那篇CRF优化分类结果的方式(https://blog.csdn.net/qq_20373723/article/details/109831725),结果的杂质会更少。

安装:https://www.jianshu.com/p/30555fd2bd50

1.https://blog.csdn.net/weixin_39807102/article/details/81912566?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-2.not_use_machine_learn_pai&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-2.not_use_machine_learn_pai

2.https://blog.csdn.net/u011599639/article/details/98680280

3.https://www.jianshu.com/p/1100e333fcab
调参推荐看:
https://www.cnblogs.com/bjwu/p/9307344.html

下面是我用的代码,这个代码把数据分成了5份,训练了5个模型,预测的时候划框一个区域也预测了五次,速度被大大降低,建议用下一个版本,后面会提供,这个只是放这里给一个参考,毕竟这种方式有提高效果的可能,未尝不可一试,毕竟我也才接触这个
其中数据的处理用到了sklearn.model_selection中的KFold,需要了解的看下这个链接:https://blog.csdn.net/weixin_43685844/article/details/88635492
数据准备:
样本点矢量文件(学遥感测绘专业的应该都知道什么意思),和以前的博客里提到的一样,每个类别对应一个矢量文件,点越多越好,下面是我选的4个类别的分布
点分布
数据存放:
数据存放
注:以前的博客这里的文件名最好是0,1,2,3按顺序来,这样和最后的类别是对应的,我这里命名错了,其实还可以通过键值对的方式实现类别的对应,下面会通过配置文件的形式体现,你们也可以不用改名字。
配置文件内容:
文件名config_order.txt
配置文件
注:那个result本来是要把栅格结果转矢量的,我去掉了,你们想要可以自己找代码转。最后一个路径是临时文件的存放位置,记得每次重新跑的时候删除上一次的临时文件,不然程序会因为切片文件已存在而跳过预测,直接执行后面的拼接步骤。
必须注意的是,你的影像和样本点文件必须是带有坐标系的,因为点样本是用来在影像上取数据的,是通过地理坐标取得,所以必须要有

版本1

# -*- coding: utf-8 -*-
from osgeo import ogr, osr
from osgeo import gdal
from gdalconst import *
import os, sys, time
import numpy as np
import lightgbm as lgb
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
import pickle
from sklearn.model_selection import KFold
import pandas as pd

def getPixels(shp, img):
    """Sample raster pixel values around every point of a point shapefile.

    For each point in `shp`, a 10x10 pixel window (anchored 5 px up-left of
    the point's pixel) is read from every band of `img` and flattened, so
    each point contributes 100 samples per band.

    Returns an (n_points * 100, bands) numpy array of pixel values.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    ds = driver.Open(shp, 0)  # 0 = read-only
    if ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = ds.GetLayer()

    # Collect the geographic coordinates of every point feature.
    xValues = []
    yValues = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        x = geometry.GetX()
        y = geometry.GetY()
        xValues.append(x)
        yValues.append(y)
        feature = layer.GetNextFeature()

    gdal.AllRegister()

    # Re-bind ds to the raster dataset (drops the shapefile handle).
    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    rows = ds.RasterYSize
    cols = ds.RasterXSize
    bands = ds.RasterCount

    # The geotransform maps map coordinates to pixel/line offsets.
    transform = ds.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = transform[5]  # typically negative for north-up imagery

    values = []
    for i in range(len(xValues)):
        x = xValues[i]
        y = yValues[i]

        # Convert the point's map coordinates to pixel (column/row) offsets.
        xOffset = int((x - xOrigin) / pixelWidth)
        yOffset = int((y - yOrigin) / pixelHeight)

        s = str(int(x)) + ' ' + str(int(y)) + ' ' + str(xOffset) + ' ' + str(yOffset) + ' '  # debug string, unused

        # Read a 10x10 window around the point from every band.
        # NOTE(review): points within 5 px of the image edge make ReadAsArray
        # return None and .flatten() would crash — confirm training points
        # stay inside the image interior.
        pt = []
        for j in range(bands):
            band = ds.GetRasterBand(j + 1)
            data = band.ReadAsArray(xOffset-5, yOffset-5, 10, 10)
            value = data
            value = value.flatten()
            pt.append(value)
        
        # Transpose (band, pixel) to (pixel, band) for this point.
        temp = []
        pt = array_change(pt, temp)
        values.append(pt)
    
    # Transpose the per-point groups, then flatten one nesting level so the
    # result is a flat (n_points * 100, bands) sample matrix.
    temp2 = []
    all_values = array_change(values, temp2)
    all_values = np.asarray(all_values)

    temp3 = []
    result_values = array_change2(all_values, temp3)
    result_values = np.asarray(result_values)
    return result_values


def lightlgb_classfiy(class_list, model_path):
    """Train 5-fold LightGBM multiclass boosters on labelled pixel samples.

    Parameters
    ----------
    class_list : dict
        Maps class key -> (n_samples, n_features) array of pixel values
        extracted around the training points (see getPixels).
    model_path : str
        Destination pickle file for the list of trained boosters.

    Returns
    -------
    (array_num, class_final, models, folds)
        Number of classes, {class key: integer label}, the trained boosters
        (one per fold), and the KFold splitter (clf_predict needs n_splits).
    """
    array_num = len(class_list)
    # Stack every class' samples into one feature matrix and build the
    # matching integer label vector. The [[0,0,0]] seed row only exists so
    # np.concatenate has a target; it is deleted again below.
    RGB_arr = np.array([[0, 0, 0]])
    label = np.array([])
    count = 0
    class_final = {}
    for key in sorted(class_list):
        RGB_arr = np.concatenate((RGB_arr, class_list[key]), axis=0)
        array_l = class_list[key].shape[0]
        label = np.append(label, count * np.ones(array_l))
        class_final[key] = count
        count += 1
    RGB_arr = np.delete(RGB_arr, 0, 0)

    params = {
        'boosting_type': 'gbdt',
        'objective': 'multiclass',
        # Derive the class count from the data instead of hard-coding 4 so
        # the function also works with a different number of sample sets.
        'num_class': array_num,
        # The dict previously listed 'metric' twice ('multi_error' then
        # 'multi_logloss'); only the last key survived, so keep that one.
        'metric': 'multi_logloss',
        'num_leaves': 60,
        'min_data_in_leaf': 30,
        'min_sum_hessian_in_leaf': 6,
        'max_depth': -1,
        'learning_rate': 0.01,
        'feature_fraction': 0.9,
        'bagging_fraction': 0.8,
        'bagging_freq': 1,
        'lambda_l1': 0.1,
        'lambda_l2': 0.001,
        'min_gain_to_split': 0.2,
        'random_state': 2019,
        'verbosity': -1,
        'verbose': 5,
        'is_unbalance': True,
        'device': 'gpu',  # requires the GPU build of LightGBM
        }

    # One booster per fold; clf_predict averages their class probabilities.
    folds = KFold(n_splits=5, shuffle=True, random_state=2019)

    models = []
    for fold_, (train_idx, val_idx) in enumerate(folds.split(RGB_arr)):
        print("fold {}".format(fold_ + 1))
        train_data = lgb.Dataset(RGB_arr[train_idx], label=label[train_idx])
        val_data = lgb.Dataset(RGB_arr[val_idx], label=label[val_idx])

        clf = lgb.train(params,
            train_data,
            5000,
            valid_sets=[train_data, val_data],
            verbose_eval=200,
            early_stopping_rounds=200)

        models.append(clf)

    # Pickle a plain list: the previous np.array(models) wrap only produced
    # an object-dtype array with no benefit, and a list unpickles the same
    # way for the `for cc in clf` loop in clf_predict.
    with open(model_path, 'wb') as f:
        pickle.dump(models, f)

    return array_num, class_final, models, folds

def rf_train(class_list, img_arr, model_path):
    """Train a RandomForest on the class samples unless a model already exists.

    class_list maps class key -> (n_samples, n_features) array; img_arr is
    accepted for interface compatibility but not used. If model_path exists
    training is skipped; otherwise a 500-tree forest is fitted and pickled.

    Returns (number of classes, {class key: integer label}).
    """
    array_num = len(class_list)
    seed = np.array([[0, 0, 0]])  # placeholder row, stripped after stacking
    pieces = [seed]
    label = np.array([])
    class_final = {}
    for count, key in enumerate(sorted(class_list)):
        samples = class_list[key]
        pieces.append(samples)
        label = np.append(label, count * np.ones(samples.shape[0]))
        class_final[key] = count
    RGB_arr = np.concatenate(pieces, axis=0)[1:]  # drop the seed row

    if not os.path.exists(model_path):
        rf = RandomForestClassifier(n_estimators=500, max_depth=10, n_jobs=14)
        rf.fit(RGB_arr, label)
        with open(model_path, 'wb') as f:
            pickle.dump(rf, f)

    return array_num, class_final

def get_model(model_path):
    """Load and return a model previously pickled to ``model_path``."""
    with open(model_path, 'rb') as handle:
        return pickle.load(handle)

def clf_predict(clf, img_arr, array_num, folds, outPath):
    """Classify every pixel of an image tile by averaging the fold models.

    Parameters
    ----------
    clf : iterable of trained LightGBM boosters (one per CV fold).
    img_arr : (height, width, bands) array; overwritten in place with the
        predicted class index.
    array_num : number of classes.
    folds : KFold splitter; only n_splits is used (probability averaging).
    outPath : unused here; kept for interface compatibility.

    Returns
    -------
    (width, height) array of class indices (note the final transpose).
    """
    img_reshape = img_arr.reshape([img_arr.shape[0] * img_arr.shape[1], img_arr.shape[2]])

    # Average the class-probability predictions of all fold models, then
    # take the most probable class per pixel. The buffer is sized from
    # array_num instead of the previous hard-coded 4 classes.
    test_pred_prob = np.zeros((img_arr.shape[0] * img_arr.shape[1], array_num))
    for cc in clf:
        test_pred_prob += cc.predict(img_reshape, num_iteration=cc.best_iteration) / folds.n_splits
    predict_r = np.argmax(test_pred_prob, axis=1)

    for j in range(array_num):
        # np.float was removed in NumPy 1.24; the builtin float is equivalent.
        lake_bool = predict_r == float(j)
        lake_bool = lake_bool[:, np.newaxis]
        try:
            # 4-band imagery: replicate the pixel mask across all four bands.
            lake_bool_4col = np.concatenate((lake_bool, lake_bool, lake_bool, lake_bool), axis=1)
            lake_bool_4d = lake_bool_4col.reshape((img_arr.shape[0], img_arr.shape[1], img_arr.shape[2]))
            img_arr[lake_bool_4d] = float(j)
        except Exception:
            # 3-band imagery: the reshape above fails, retry with three copies.
            lake_bool_4col = np.concatenate((lake_bool, lake_bool, lake_bool), axis=1)
            lake_bool_4d = lake_bool_4col.reshape((img_arr.shape[0], img_arr.shape[1], img_arr.shape[2]))
            img_arr[lake_bool_4d] = float(j)

    # Keep one plane of the class-filled cube, transposed to (width, height).
    img_arr = img_arr.transpose((2, 1, 0))
    img_arr = img_arr[0]
    return img_arr


def read_img(filename):
    """Read a whole raster; return (proj, geotransform, width, height, data)."""
    ds = gdal.Open(filename)

    width = ds.RasterXSize
    height = ds.RasterYSize
    geotrans = ds.GetGeoTransform()
    proj = ds.GetProjection()
    data = ds.ReadAsArray(0, 0, width, height)

    del ds  # release the GDAL dataset handle
    return proj, geotrans, width, height, data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write an array to a GeoTIFF, picking the GDAL type from the dtype.

    im_data may be (bands, height, width) or a single (height, width) band.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape

    dataset = gdal.GetDriverByName("GTiff").Create(
        filename, im_width, im_height, im_bands, datatype)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for band_idx in range(im_bands):
            dataset.GetRasterBand(band_idx + 1).WriteArray(im_data[band_idx])

    del dataset  # flush and close the output file


def write_img_(filename, im_proj, im_geotrans, im_data):
    """Write im_data to `filename` as an 8-bit (GDT_Byte) GeoTIFF.

    Variant of write_img used for byte output.
    NOTE(review): `datatype` is computed below but never used — Create()
    always receives gdal.GDT_Byte; presumably intentional for 8-bit class
    rasters, so the dead computation is left untouched.
    """
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    # Accept (bands, height, width) or a single (height, width) band.
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, gdal.GDT_Byte)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    # Deleting the dataset flushes and closes the file.
    del dataset

def array_change(inlist, outlist):
    """Append the transpose of `inlist` (rows -> columns) to `outlist`.

    `inlist` is a list of equal-length rows; for each column index the list
    of that column's values across all rows is appended. Returns `outlist`.
    """
    n_cols = len(inlist[0])
    for col in range(n_cols):
        column = []
        for row in inlist:
            column.append(row[col])
        outlist.append(column)
    return outlist

def array_change2(inlist, outlist):
    """Flatten one nesting level of `inlist` into `outlist`; returns `outlist`."""
    outlist.extend(item for group in inlist for item in group)
    return outlist

def getTifSize(tif):
    """Return (width, height, band count, geotransform, projection) of a raster path."""
    ds = gdal.Open(tif)
    return (ds.RasterXSize,
            ds.RasterYSize,
            ds.RasterCount,
            ds.GetGeoTransform(),
            ds.GetProjection())

def partDivisionForBoundary(model,array_num,tif1,folds,divisionSize,tempPath):
    """Split a large raster into divisionSize x divisionSize tiles and classify each.

    Each tile is run through clf_predict and written as a single-band
    Float32 GeoTIFF into tempPath. Tiles whose output file already exists
    are skipped — stale temp files from a previous run suppress
    re-prediction, hence the advice to clear tempPath between runs.
    Returns 1 on completion.
    """
    width,height,bands,geoTrans,proj = getTifSize(tif1)
    partWidth = partHeight = divisionSize

    # Number of tile columns/rows, rounding up for partial edge tiles.
    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum =  width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight

    realName = os.path.split(tif1)[1].split(".")[0]

    tif1 = gdal.Open(tif1)
    for i in range(heightNum):
        for j in range(widthNum):

            startX = partWidth * j
            startY = partHeight * i

            # Clamp the tile size at the right/bottom image edges.
            if startX+partWidth<= width and startY+partHeight<=height:
                realPartWidth = partWidth
                realPartHeight = partHeight
            elif startX + partWidth > width and startY+partHeight<=height:
                realPartWidth = width - startX
                realPartHeight = partHeight
            elif startX+partWidth <= width  and startY+partHeight > height:
                realPartWidth = partWidth
                realPartHeight = height - startY
            elif startX + partWidth > width and startY+partHeight > height:
                realPartWidth = width - startX
                realPartHeight = height - startY

            outName = realName + str(i)+str(j)+".tif"
            outPath = os.path.join(tempPath,outName)

            if not os.path.exists(outPath):

                driver = gdal.GetDriverByName("GTiff")
                outTif = driver.Create(outPath,realPartWidth,realPartHeight,1,gdal.GDT_Float32)
                # NOTE(review): each tile is given the FULL image's
                # geotransform — tile origins are not offset. Harmless while
                # the tiles are only temporary inputs to partStretch, but
                # confirm if the tiles are ever reused standalone.
                outTif.SetGeoTransform(geoTrans)
                outTif.SetProjection(proj)

                # (bands, rows, cols) -> (cols, rows, bands) before predicting.
                data1 = tif1.ReadAsArray(startX,startY,realPartWidth,realPartHeight)
                data1 = data1.transpose((2,1,0))
                svmData = clf_predict(model, data1, array_num, folds, outPath)
                outTif.GetRasterBand(1).WriteArray(svmData)
    return 1

def partStretch(tif1,divisionSize,outStratchPath,tempPath):
    """Mosaic the per-tile prediction rasters back into one byte GeoTIFF.

    Re-derives the same tile grid partDivisionForBoundary used over `tif1`,
    reads each temp tile from tempPath, and writes it into the matching
    window of outStratchPath.
    """
    width,height,bands,geoTrans,proj = getTifSize(tif1)
    # bands = 1
    partWidth = partHeight = divisionSize

    # Tile-grid dimensions, rounding up for partial edge tiles.
    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum =  width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight

    realName = os.path.split(tif1)[1].split(".")[0]

    driver = gdal.GetDriverByName("GTiff")
    outTif = driver.Create(outStratchPath,width,height,1,gdal.GDT_Byte)
    if outTif!= None:
        outTif.SetGeoTransform(geoTrans)
        outTif.SetProjection(proj)
    for i in range(heightNum):
        for j in range(widthNum):

            startX = partWidth * j
            startY = partHeight * i

            # Clamp tile size at the right/bottom edges of the image.
            if startX+partWidth<= width and startY+partHeight<=height:
                realPartWidth = partWidth
                realPartHeight = partHeight
            elif startX + partWidth > width and startY+partHeight<=height:
                realPartWidth = width - startX
                realPartHeight = partHeight
            elif startX+partWidth <= width  and startY+partHeight > height:
                realPartWidth = partWidth
                realPartHeight = height - startY
            elif startX + partWidth > width and startY+partHeight > height:
                realPartWidth = width - startX
                realPartHeight = height - startY

            # Tile names follow the realName + row + col convention used by
            # partDivisionForBoundary.
            partTifName = realName+str(i)+str(j)+".tif"
            partTifPath = os.path.join(tempPath,partTifName)
            divisionImg = gdal.Open(partTifPath)
            # Single-band copy into the mosaic at the tile's pixel offset.
            for k in range(1):
                data1 = divisionImg.GetRasterBand(k+1).ReadAsArray(0,0,realPartWidth,realPartHeight)
                outPartBand = outTif.GetRasterBand(k+1)
                outPartBand.WriteArray(data1,startX,startY)

if __name__ == '__main__':
    # Each non-empty config line holds one entry; take the first
    # whitespace-separated token per line.
    config_file = 'config_order.txt'
    dirs = []
    with open(config_file) as cfg:  # close the handle instead of leaking it
        for line in cfg:
            dirs.append(line.split()[0])

    dir_n = len(dirs)
    shp_n = dir_n - 4  # lines 3..n-2 hold the per-class shapefile paths
    data_image = dirs[0].replace('\\', '/')

    # Second line maps class ids to shapefile names, e.g. "0:water.shp;1:trees.shp;"
    point_dict = dirs[1]
    shp_paths = []
    for i in range(shp_n):
        shp_paths.append(dirs[i + 2].replace('\\', '/'))

    result_path = dirs[-2].replace('\\', '/')
    temp_path = dirs[-1].replace('\\', '/')

    time1 = time.time()
    print('Start ...')
    # Resolve every "class id:file name" pair to its full shapefile path.
    dd = {}
    point_dicts = point_dict.split(';')
    for p_dict in point_dicts:
        if not p_dict:  # a trailing ';' leaves an empty fragment
            continue
        cls_n, d_name = p_dict.split(':')
        for p_path in shp_paths:
            path, p_name = os.path.split(p_path)
            if d_name == p_name:
                dd[cls_n] = p_path

    # Sample pixel values around every training point of every class.
    class_list = {}
    for k_cls, v_shp_path in dd.items():
        class_list[k_cls] = getPixels(v_shp_path, data_image)

    model_path = os.path.join(temp_path, 'model.pickle')

    print('Train model ...')
    num, class_final, models, folds = lightlgb_classfiy(class_list, model_path)

    # Renamed from `lgb`, which shadowed the imported lightgbm module.
    clf_models = get_model(model_path)

    slice_path = os.path.join(temp_path, 'slice_temp')
    if not os.path.exists(slice_path):
        os.mkdir(slice_path)

    print('Predict task area ...')
    partDivisionForBoundary(clf_models, num, data_image, folds, 2000, slice_path)
    raster_path = os.path.join(temp_path, 'class_raster.tif')
    partStretch(data_image, 2000, raster_path, slice_path)

    time2 = time.time()
    print((time2 - time1) / 3600)  # elapsed hours

版本2
数据处理部分不用KFold而改用train_test_split ,这次的时间倒是很快了,但是结果不是很好,调参还是个麻烦事,前面给了链接,大家自己去调整吧

# -*- coding: utf-8 -*-
from osgeo import ogr, osr
from osgeo import gdal
from gdalconst import *
import os, sys, time
import numpy as np
import lightgbm as lgb
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
import pickle
# from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split  
import pandas as pd

def getPixels(shp, img):
    """Sample raster pixel values around every point of a point shapefile.

    For each point in `shp`, a 10x10 pixel window (anchored 5 px up-left of
    the point's pixel) is read from every band of `img` and flattened, so
    each point contributes 100 samples per band.

    Returns an (n_points * 100, bands) numpy array of pixel values.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    ds = driver.Open(shp, 0)  # 0 = read-only
    if ds is None:
        print('Could not open ' + shp)
        sys.exit(1)

    layer = ds.GetLayer()

    # Collect the geographic coordinates of every point feature.
    xValues = []
    yValues = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        x = geometry.GetX()
        y = geometry.GetY()
        xValues.append(x)
        yValues.append(y)
        feature = layer.GetNextFeature()

    gdal.AllRegister()

    # Re-bind ds to the raster dataset (drops the shapefile handle).
    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)

    rows = ds.RasterYSize
    cols = ds.RasterXSize
    bands = ds.RasterCount

    # The geotransform maps map coordinates to pixel/line offsets.
    transform = ds.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = transform[5]  # typically negative for north-up imagery

    values = []
    for i in range(len(xValues)):
        x = xValues[i]
        y = yValues[i]

        # Convert the point's map coordinates to pixel (column/row) offsets.
        xOffset = int((x - xOrigin) / pixelWidth)
        yOffset = int((y - yOrigin) / pixelHeight)

        s = str(int(x)) + ' ' + str(int(y)) + ' ' + str(xOffset) + ' ' + str(yOffset) + ' '  # debug string, unused

        # Read a 10x10 window around the point from every band.
        # NOTE(review): points within 5 px of the image edge make ReadAsArray
        # return None and .flatten() would crash — confirm training points
        # stay inside the image interior.
        pt = []
        for j in range(bands):
            band = ds.GetRasterBand(j + 1)
            data = band.ReadAsArray(xOffset-5, yOffset-5, 10, 10)
            value = data
            value = value.flatten()
            pt.append(value)
        
        # Transpose (band, pixel) to (pixel, band) for this point.
        temp = []
        pt = array_change(pt, temp)
        values.append(pt)
    
    # Transpose the per-point groups, then flatten one nesting level so the
    # result is a flat (n_points * 100, bands) sample matrix.
    temp2 = []
    all_values = array_change(values, temp2)
    all_values = np.asarray(all_values)

    temp3 = []
    result_values = array_change2(all_values, temp3)
    result_values = np.asarray(result_values)
    return result_values


def lightlgb_classfiy(class_list, model_path):
    """Train a single LightGBM multiclass model on a train/validation split.

    class_list maps class key -> (n_samples, n_features) array; a 95/5
    stratified split is made and one booster is trained with early stopping
    and pickled to model_path.

    Returns (number of classes, {class key: integer label}).
    """
    array_num = len(class_list)
    # Stack all class samples into one matrix plus integer label vector.
    # The [[0,0,0]] seed row exists only so concatenate has a target; it is
    # removed again below.
    RGB_arr = np.array([[0, 0, 0]])
    label = np.array([])
    count = 0
    class_final = {}
    for key in sorted(class_list):
        RGB_arr = np.concatenate((RGB_arr, class_list[key]), axis=0)
        array_l = class_list[key].shape[0]
        label = np.append(label, count * np.ones(array_l))
        class_final[key] = count
        count += 1
    RGB_arr = np.delete(RGB_arr, 0, 0)

    params = {
        'boosting_type': 'gbdt',
        'objective': 'multiclass',
        # Use the actual class count instead of a hard-coded 4.
        'num_class': array_num,
        # The dict previously listed 'metric' twice ('multi_error' then
        # 'multi_logloss'); only the last key survived, so keep that one.
        'metric': 'multi_logloss',
        'num_leaves': 50,
        'min_data_in_leaf': 100,
        'min_sum_hessian_in_leaf': 6,
        'max_depth': 6,
        'learning_rate': 0.1,
        'feature_fraction': 0.8,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'min_gain_to_split': 0.2,
        'verbosity': -1,
        'verbose': 5,
        'is_unbalance': True,
        'device': 'gpu',  # requires the GPU build of LightGBM
        }
    # NOTE(review): 'nfold', 'stratified', 'shuffle', 'verbose_eval' and
    # 'show_stdv' are lgb.cv() arguments, not training parameters; lgb.train()
    # ignores them, so they were dropped from the dict above.

    # Hold out 5% of the samples; stratify keeps the class balance intact.
    X_train, X_test, y_train, y_test = train_test_split(
        RGB_arr,
        label,
        test_size=0.05,
        random_state=1,
        stratify=label,
    )

    train_data = lgb.Dataset(X_train, label=y_train)
    val_data = lgb.Dataset(X_test, label=y_test)

    clf = lgb.train(params,
        train_data,
        num_boost_round=10000,
        valid_sets=val_data,
        early_stopping_rounds=500)

    with open(model_path, 'wb') as f:
        pickle.dump(clf, f)

    return array_num, class_final

def rf_train(class_list, model_path):
    """Train (or reuse) a RandomForest classifier on the class samples.

    class_list maps class key -> (n_samples, n_features) array. If
    model_path already exists training is skipped; otherwise a 500-tree
    forest is fitted and pickled there.

    Returns (number of classes, {class key: integer label}).
    """
    array_num = len(class_list)
    placeholder = np.array([[0, 0, 0]])  # seed row, stripped after stacking
    stacked = [placeholder]
    label = np.array([])
    class_final = {}
    for count, key in enumerate(sorted(class_list)):
        samples = class_list[key]
        stacked.append(samples)
        label = np.append(label, count * np.ones(samples.shape[0]))
        class_final[key] = count
    RGB_arr = np.concatenate(stacked, axis=0)[1:]  # drop the seed row

    if not os.path.exists(model_path):
        rf = RandomForestClassifier(n_estimators=500, max_depth=10, n_jobs=14)
        rf.fit(RGB_arr, label)
        with open(model_path, 'wb') as f:
            pickle.dump(rf, f)

    return array_num, class_final

def get_model(model_path):
    """Load the pickled model saved at ``model_path`` and return it."""
    with open(model_path, 'rb') as model_file:
        loaded = pickle.load(model_file)
    return loaded

def clf_predict(clf, img_arr, array_num, outPath):
    """Predict a class index for every pixel of an image tile.

    clf is a trained LightGBM booster; img_arr is (height, width, bands).
    array_num and outPath are unused, kept for interface compatibility.
    Returns a (width, height) array of class indices.
    """
    tile_h, tile_w, n_bands = img_arr.shape[0], img_arr.shape[1], img_arr.shape[2]
    flat_pixels = img_arr.reshape([tile_h * tile_w, n_bands])

    # Per-pixel class probabilities -> index of the most probable class.
    probs = clf.predict(flat_pixels, num_iteration=clf.best_iteration)
    class_idx = np.array([list(row).index(max(row)) for row in probs])

    labels = class_idx.reshape([tile_h, tile_w])
    return labels.transpose((1, 0))


def read_img(filename):
    """Read a whole raster into memory.

    Returns (projection, geotransform, width, height, data) where data is
    the full image from ReadAsArray.
    """
    dataset=gdal.Open(filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0,0,im_width,im_height)

    # Deleting the dataset releases the GDAL handle.
    del dataset 
    return im_proj,im_geotrans,im_width, im_height,im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write an array to a GeoTIFF, picking the GDAL type from the dtype.

    im_data may be (bands, height, width) or a single (height, width) band.
    """
    # Map the numpy dtype to a GDAL raster type; everything that is not an
    # 8/16-bit integer falls back to Float32.
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    # Deleting the dataset flushes and closes the output file.
    del dataset


def write_img_(filename, im_proj, im_geotrans, im_data):
    """Write im_data to `filename` as an 8-bit (GDT_Byte) GeoTIFF.

    Variant of write_img used for byte output.
    NOTE(review): `datatype` is computed below but never used — Create()
    always receives gdal.GDT_Byte; presumably intentional for 8-bit class
    rasters, so the dead computation is left untouched.
    """
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    # Accept (bands, height, width) or a single (height, width) band.
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, gdal.GDT_Byte)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    # Deleting the dataset flushes and closes the file.
    del dataset

def array_change(inlist, outlist):
    """Append the transpose of `inlist` (rows -> columns) to `outlist`.

    `inlist` is a list of equal-length rows; for each column index a list of
    that column's values across all rows is appended. Returns `outlist`.
    """
    width = len(inlist[0])
    outlist.extend([row[idx] for row in inlist] for idx in range(width))
    return outlist

def array_change2(inlist, outlist):
    """Flatten one nesting level of `inlist`, appending items to `outlist`."""
    for group in inlist:
        outlist += list(group)
    return outlist

def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch a 2-D band into the range [img_min, img_max].

    The pixel values at the lower/higher percentiles are mapped to img_min
    and img_max respectively; values outside are clipped. Returns a float32
    array the same shape as `bands`.
    """
    out = np.zeros_like(bands).astype(np.float32)
    a = img_min
    b = img_max
    c = np.percentile(bands[:, :], lower_percent)
    d = np.percentile(bands[:, :], higher_percent)
    if d == c:
        # Constant band: the original code divided by zero here (the
        # commented-out guard showed the author hit it). Map everything to
        # the lower bound instead.
        out[:, :] = a
        return out
    t = a + (bands[:, :] - c) * (b - a) / (d - c)
    # Clip to the target range.
    t[t < a] = a
    t[t > b] = b
    out[:, :] = t
    return out

def getTifSize(tif):
    """Return (width, height, band count, geotransform, projection) of a raster path."""
    dataSet = gdal.Open(tif)
    width = dataSet.RasterXSize
    height = dataSet.RasterYSize
    bands = dataSet.RasterCount
    geoTrans = dataSet.GetGeoTransform()
    proj = dataSet.GetProjection()
    return width,height,bands,geoTrans,proj


# @jit(nopython=True)
def partDivisionForBoundary(model,array_num,tif1,divisionSize,tempPath):
    """Split a large raster into divisionSize x divisionSize tiles and classify each.

    Each tile is run through clf_predict and written as a single-band
    Float32 GeoTIFF into tempPath. Tiles whose output file already exists
    are skipped — stale temp files from a previous run suppress
    re-prediction, hence the advice to clear tempPath between runs.
    Returns 1 on completion.
    """
    width,height,bands,geoTrans,proj = getTifSize(tif1)
    partWidth = partHeight = divisionSize

    # Number of tile columns/rows, rounding up for partial edge tiles.
    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum =  width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight

    realName = os.path.split(tif1)[1].split(".")[0]

    tif1 = gdal.Open(tif1)
    for i in range(heightNum):
        for j in range(widthNum):

            startX = partWidth * j
            startY = partHeight * i

            # Clamp the tile size at the right/bottom image edges.
            if startX+partWidth<= width and startY+partHeight<=height:
                realPartWidth = partWidth
                realPartHeight = partHeight
            elif startX + partWidth > width and startY+partHeight<=height:
                realPartWidth = width - startX
                realPartHeight = partHeight
            elif startX+partWidth <= width  and startY+partHeight > height:
                realPartWidth = partWidth
                realPartHeight = height - startY
            elif startX + partWidth > width and startY+partHeight > height:
                realPartWidth = width - startX
                realPartHeight = height - startY

            outName = realName + str(i)+str(j)+".tif"
            outPath = os.path.join(tempPath,outName)

            if not os.path.exists(outPath):

                driver = gdal.GetDriverByName("GTiff")
                outTif = driver.Create(outPath,realPartWidth,realPartHeight,1,gdal.GDT_Float32)
                # NOTE(review): each tile is given the FULL image's
                # geotransform — tile origins are not offset. Harmless while
                # the tiles are only temporary inputs to partStretch, but
                # confirm if the tiles are ever reused standalone.
                outTif.SetGeoTransform(geoTrans)
                outTif.SetProjection(proj)

                # (bands, rows, cols) -> (cols, rows, bands) before predicting.
                data1 = tif1.ReadAsArray(startX,startY,realPartWidth,realPartHeight)
                data1 = data1.transpose((2,1,0))
                svmData = clf_predict(model, data1, array_num, outPath)
                outTif.GetRasterBand(1).WriteArray(svmData)
    return 1

# @jit(nopython=True)
def partStretch(tif1,divisionSize,outStratchPath,tempPath):
    """Mosaic the per-tile prediction rasters back into one byte GeoTIFF.

    Re-derives the same tile grid partDivisionForBoundary used over `tif1`,
    reads each temp tile from tempPath, and writes it into the matching
    window of outStratchPath.
    """

    width,height,bands,geoTrans,proj = getTifSize(tif1)
    # bands = 1
    partWidth = partHeight = divisionSize

    # Tile-grid dimensions, rounding up for partial edge tiles.
    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum =  width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight

    realName = os.path.split(tif1)[1].split(".")[0]

    driver = gdal.GetDriverByName("GTiff")
    outTif = driver.Create(outStratchPath,width,height,1,gdal.GDT_Byte)
    if outTif!= None:
        outTif.SetGeoTransform(geoTrans)
        outTif.SetProjection(proj)
    for i in range(heightNum):
        for j in range(widthNum):

            startX = partWidth * j
            startY = partHeight * i

            # Clamp tile size at the right/bottom edges of the image.
            if startX+partWidth<= width and startY+partHeight<=height:
                realPartWidth = partWidth
                realPartHeight = partHeight
            elif startX + partWidth > width and startY+partHeight<=height:
                realPartWidth = width - startX
                realPartHeight = partHeight
            elif startX+partWidth <= width  and startY+partHeight > height:
                realPartWidth = partWidth
                realPartHeight = height - startY
            elif startX + partWidth > width and startY+partHeight > height:
                realPartWidth = width - startX
                realPartHeight = height - startY

            # Tile names follow the realName + row + col convention used by
            # partDivisionForBoundary.
            partTifName = realName+str(i)+str(j)+".tif"
            partTifPath = os.path.join(tempPath,partTifName)
            divisionImg = gdal.Open(partTifPath)
            # Single-band copy into the mosaic at the tile's pixel offset.
            for k in range(1):
                data1 = divisionImg.GetRasterBand(k+1).ReadAsArray(0,0,realPartWidth,realPartHeight)
                outPartBand = outTif.GetRasterBand(k+1)
                outPartBand.WriteArray(data1,startX,startY)


if __name__ == '__main__':
    # Each non-empty config line holds one entry; take the first
    # whitespace-separated token per line.
    config_file = 'config_order.txt'
    dirs = []
    with open(config_file) as cfg:  # close the handle instead of leaking it
        for line in cfg:
            dirs.append(line.split()[0])

    dir_n = len(dirs)
    shp_n = dir_n - 4  # lines 3..n-2 hold the per-class shapefile paths
    data_image = dirs[0].replace('\\', '/')

    # Second line maps class ids to shapefile names, e.g. "0:water.shp;1:trees.shp;"
    point_dict = dirs[1]

    shp_paths = []
    for i in range(shp_n):
        shp_paths.append(dirs[i + 2].replace('\\', '/'))

    result_path = dirs[-2].replace('\\', '/')
    temp_path = dirs[-1].replace('\\', '/')

    time1 = time.time()
    print('Start ...')
    # Resolve every "class id:file name" pair to its full shapefile path.
    dd = {}
    point_dicts = point_dict.split(';')
    for p_dict in point_dicts:
        if not p_dict:  # a trailing ';' leaves an empty fragment
            continue
        cls_n, d_name = p_dict.split(':')
        for p_path in shp_paths:
            path, p_name = os.path.split(p_path)
            if d_name == p_name:
                dd[cls_n] = p_path

    # Sample pixel values around every training point of every class.
    class_list = {}
    for k_cls, v_shp_path in dd.items():
        class_list[k_cls] = getPixels(v_shp_path, data_image)

    model_path = os.path.join(temp_path, 'model.pickle')

    print('Train model ...')
    num, class_final = lightlgb_classfiy(class_list, model_path)

    # Renamed from `lgb`, which shadowed the imported lightgbm module.
    booster = get_model(model_path)

    slice_path = os.path.join(temp_path, 'slice_temp')
    if not os.path.exists(slice_path):
        os.mkdir(slice_path)

    print('Predict task area ...')
    partDivisionForBoundary(booster, num, data_image, 2000, slice_path)
    raster_path = os.path.join(temp_path, 'class_raster.tif')
    partStretch(data_image, 2000, raster_path, slice_path)

    time2 = time.time()
    print((time2 - time1) / 3600)  # elapsed hours

猜你喜欢

转载自blog.csdn.net/qq_20373723/article/details/109908606