LightGBM是一个框架,内置了:
1.gbdt, 传统的梯度提升决策树;
2.rf, Random Forest (随机森林);
3.dart, Dropouts meet Multiple Additive Regression Trees;
4.goss, Gradient-based One-Side Sampling (基于梯度的单侧采样)。
这些方法的选择由参数 boosting_type 来控制。
我之前的微博记录了随机森林和支持向量机的分类方法,这都是scikit-learn机器学习包里的机器学习方法,同样的方法和LightGBM相比就是无法GPU加速,这导致处理比较大的遥感影像比较慢,不好用于项目中,LightGBM框架虽然不包含SVM但是里面包含的方法不比SVM差
这里面的参数太多了,给大家几个链接好好研究吧,我也还在得到想要的结果的路上,建议有紧急任务的朋友先别入坑了,调参太浪费时间了,我也还没得到想要的结果
下面是示例:
结果:
说明:其实我分了4类,但只得到了3类,原因是有两个类很像,毕竟是像素级分类,仔细观察发现水体和林地确实太像了,任何一种像素级的分类方法要分开都比较难,但是比较重要的一点是分类结果的杂质比其他方法确实少很多,对于需要转到矢量的情况比较有利,另外提一嘴,这个方法陪上我前面的那篇crf优化分类结果的方式(https://blog.csdn.net/qq_20373723/article/details/109831725),结果的杂质会更少。
安装:https://www.jianshu.com/p/30555fd2bd50
1.https://blog.csdn.net/weixin_39807102/article/details/81912566?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-2.not_use_machine_learn_pai&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromBaidu-2.not_use_machine_learn_pai
2.https://blog.csdn.net/u011599639/article/details/98680280
3.https://www.jianshu.com/p/1100e333fcab
调参推荐看:
https://www.cnblogs.com/bjwu/p/9307344.html
下面是我用的代码,这个代码把数据分成了5份,训练了5个模型,预测的时候划框一个区域也预测了五次,速度被大大降低,建议用下一个版本,后面会提供,这个只是放这里给一个参考,毕竟这种方式有提高效果的可能,未尝不可一试,毕竟我也才接触这个
其中数据的处理用到了sklearn.model_selection中的KFold,需要了解的看下这个链接:https://blog.csdn.net/weixin_43685844/article/details/88635492
数据准备:
样本点矢量文件(学遥感测绘专业的应该都知道什么意思吧),和以前的博客里提到一样,每个类别对应一个矢量文件,点越多越好,下面是我选的4个类别的分布
数据存放:
注:以前的博客这里的文件名最好是0,1,2,3按顺序来,这样和最后的类别是对应的,我这里命名错了,其实还可以通过键值对的方式实现类别的对应,下面会通过配置文件的形式体现,你们也可以不用改名字。
配置文件内容:
文件名config_order.txt
注:那个result本来是要把栅格结果转矢量的,我去掉了,你们想要可以自己找代码转,最后一个路径是临时文件的存放位置,记得每次重新跑的时候删除上一次的,不然会跳过预测。
必须注意的是,你的影像和样本点文件必须是带有坐标系的,因为点样本是用来在影像上取数据的,是通过地理坐标取得,所以必须要有
版本1
# -*- coding: utf-8 -*-
from osgeo import ogr, osr
from osgeo import gdal
from gdalconst import *
import os, sys, time
import numpy as np
import lightgbm as lgb
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
import pickle
from sklearn.model_selection import KFold
import pandas as pd
def getPixels(shp, img):
    """Sample raster values in a 10x10 pixel window around each point in a shapefile.

    Parameters
    ----------
    shp : str
        Path to an ESRI Shapefile of point features.  Must share the
        raster's coordinate system: point coordinates are converted to
        pixel offsets through the raster's geotransform.
    img : str
        Path to the georeferenced raster image.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_points * 100, n_bands); one row per sampled
        window pixel, columns are the band values of that pixel.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    ds = driver.Open(shp, 0)
    if ds is None:
        print('Could not open ' + shp)
        sys.exit(1)
    # Collect the geographic coordinates of every point feature.
    layer = ds.GetLayer()
    xValues = []
    yValues = []
    feature = layer.GetNextFeature()
    while feature:
        geometry = feature.GetGeometryRef()
        xValues.append(geometry.GetX())
        yValues.append(geometry.GetY())
        feature = layer.GetNextFeature()
    gdal.AllRegister()
    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)
    bands = ds.RasterCount
    transform = ds.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = transform[5]
    values = []
    for x, y in zip(xValues, yValues):
        # Geographic coordinate -> integer pixel offset.
        xOffset = int((x - xOrigin) / pixelWidth)
        yOffset = int((y - yOrigin) / pixelHeight)
        pt = []
        for j in range(bands):
            band = ds.GetRasterBand(j + 1)
            # NOTE(review): a point closer than 5 pixels to the image edge
            # makes ReadAsArray return None and this crashes on flatten() --
            # TODO confirm sample points are always well inside the raster.
            data = band.ReadAsArray(xOffset - 5, yOffset - 5, 10, 10)
            pt.append(data.flatten())
        # Regroup from per-band windows to per-pixel band vectors.
        values.append(array_change(pt, []))
    all_values = np.asarray(array_change(values, []))
    # Flatten the per-point grouping into one flat list of samples.
    result_values = np.asarray(array_change2(all_values, []))
    return result_values
def lightlgb_classfiy(class_list, model_path):
    """Train a 5-fold LightGBM multiclass ensemble on the sampled pixels.

    Parameters
    ----------
    class_list : dict
        Maps class name -> (n_samples, n_bands) sample array (getPixels output).
    model_path : str
        Where the list of trained boosters is pickled.

    Returns
    -------
    tuple
        (number of classes, {class name: integer label}, list of trained
        boosters, the KFold splitter used -- its n_splits weights the
        probability average at prediction time).
    """
    array_num = len(class_list)
    # Stack all class sample arrays; the dummy first row is removed below.
    RGB_arr = np.array([[0, 0, 0]])
    label = np.array([])
    count = 0
    class_final = {}
    for i in sorted(class_list):
        RGB_arr = np.concatenate((RGB_arr, class_list[i]), axis=0)
        array_l = class_list[i].shape[0]
        label = np.append(label, count * np.ones(array_l))
        class_final[i] = count
        count += 1
    RGB_arr = np.delete(RGB_arr, 0, 0)
    params = {
        'boosting_type': 'gbdt',
        'objective': 'multiclass',
        'num_class': 4,
        # BUG FIX: the original dict listed 'metric' twice ('multi_error'
        # then 'multi_logloss'); Python keeps only the last, so the first
        # entry was silently discarded.  Keep the surviving value explicitly.
        'metric': 'multi_logloss',
        'num_leaves': 60,
        'min_data_in_leaf': 30,
        'min_sum_hessian_in_leaf': 6,
        'max_depth': -1,
        'learning_rate': 0.01,
        'feature_fraction': 0.9,
        'bagging_fraction': 0.8,
        'bagging_freq': 1,
        'lambda_l1': 0.1,
        'lambda_l2': 0.001,
        'min_gain_to_split': 0.2,
        'random_state': 2019,
        'verbosity': -1,
        'verbose': 5,
        'is_unbalance': True,
        'device': 'gpu'
    }
    folds = KFold(n_splits=5, shuffle=True, random_state=2019)
    feature_importance_df = pd.DataFrame()
    models = []
    for fold_, (train_idx, val_idx) in enumerate(folds.split(RGB_arr)):
        print("fold {}".format(fold_ + 1))
        train_data = lgb.Dataset(RGB_arr[train_idx], label=label[train_idx])
        val_data = lgb.Dataset(RGB_arr[val_idx], label=label[val_idx])
        clf = lgb.train(params,
                        train_data,
                        5000,
                        valid_sets=[train_data, val_data],
                        verbose_eval=200,
                        early_stopping_rounds=200)
        # Record per-fold feature importances for later inspection.
        fold_importance_df = pd.DataFrame()
        fold_importance_df["importance"] = clf.feature_importance()
        fold_importance_df["fold"] = fold_ + 1
        feature_importance_df = pd.concat([feature_importance_df, fold_importance_df], axis=0)
        models.append(clf)
    # BUG FIX: pickle the plain list of boosters -- wrapping them in a
    # NumPy object array (np.array(models)) added nothing and is fragile
    # across NumPy versions.
    with open(model_path, 'wb') as f:
        pickle.dump(models, f)
    return array_num, class_final, models, folds
def rf_train(class_list, img_arr, model_path):
    """Train (or reuse) a RandomForest classifier on the sampled pixels.

    Stacks the per-class sample arrays into one feature matrix, assigning
    integer labels in sorted-key order.  Training is skipped entirely when
    a pickled model already exists at *model_path*.  (*img_arr* is unused;
    kept for interface compatibility.)

    Returns (number of classes, {class name: integer label}).
    """
    class_final = {}
    stacked = np.array([[0, 0, 0]])  # dummy seed row, dropped after stacking
    labels = np.array([])
    for idx, name in enumerate(sorted(class_list)):
        samples = class_list[name]
        stacked = np.concatenate((stacked, samples), axis=0)
        labels = np.append(labels, idx * np.ones(samples.shape[0]))
        class_final[name] = idx
    stacked = np.delete(stacked, 0, 0)
    if not os.path.exists(model_path):
        forest = RandomForestClassifier(n_estimators=500, max_depth=10, n_jobs=14)
        forest.fit(stacked, labels)
        with open(model_path, 'wb') as f:
            pickle.dump(forest, f)
    return len(class_list), class_final
def get_model(model_path):
    """Load and return a previously pickled model from *model_path*."""
    with open(model_path, 'rb') as handle:
        return pickle.load(handle)
def clf_predict(clf, img_arr, array_num, folds, outPath):
    """Predict a class map for one image tile with an ensemble of boosters.

    Parameters
    ----------
    clf : iterable
        Trained LightGBM boosters (one per fold); their predicted class
        probabilities are averaged.
    img_arr : numpy.ndarray
        Tile pixels, shape (width, height, bands); overwritten in place
        with the predicted class id on every band.
    array_num : int
        Number of classes.
    folds : KFold
        Only folds.n_splits is used, to weight the probability average.
    outPath : str
        Unused here; kept for interface compatibility with the caller.

    Returns
    -------
    numpy.ndarray
        2-D class-id map (band 0 of the relabelled tile, transposed).
    """
    img_reshape = img_arr.reshape([img_arr.shape[0] * img_arr.shape[1], img_arr.shape[2]])
    # Average the class-probability predictions of all fold models.
    test_pred_prob = np.zeros((img_arr.shape[0] * img_arr.shape[1], 4))
    for cc in clf:
        test_pred_prob += cc.predict(img_reshape, num_iteration=cc.best_iteration) / folds.n_splits
    predict_r = np.argmax(test_pred_prob, axis=1)
    for j in range(array_num):
        # BUG FIX: np.float was removed in NumPy 1.20+; it was always just
        # an alias for the builtin float.
        lake_bool = predict_r == float(j)
        lake_bool = lake_bool[:, np.newaxis]
        # Broadcast the pixel mask across however many bands the tile has.
        # (The original hard-coded a 4-column attempt with a bare except
        # falling back to 3 columns; repeating to shape[2] covers both.)
        mask_cols = np.repeat(lake_bool, img_arr.shape[2], axis=1)
        mask = mask_cols.reshape((img_arr.shape[0], img_arr.shape[1], img_arr.shape[2]))
        img_arr[mask] = float(j)
    img_arr = img_arr.transpose((2, 1, 0))
    img_arr = img_arr[0]
    return img_arr
def read_img(filename):
    """Open a raster and return (projection, geotransform, width, height, pixel array)."""
    dataset = gdal.Open(filename)
    width = dataset.RasterXSize
    height = dataset.RasterYSize
    geotrans = dataset.GetGeoTransform()
    proj = dataset.GetProjection()
    data = dataset.ReadAsArray(0, 0, width, height)
    del dataset
    return proj, geotrans, width, height, data
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write *im_data* to a GeoTIFF, picking the GDAL type from the array dtype.

    Accepts either a (bands, height, width) cube or a single 2-D band.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for band_index in range(im_bands):
            dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    del dataset
def write_img_(filename, im_proj, im_geotrans, im_data):
    """Write *im_data* to an 8-bit GeoTIFF (class ids fit in a byte).

    Same layout handling as write_img(), but the output type is always
    GDT_Byte.  The original also computed a dtype-dependent GDAL type and
    then never used it; that dead code is removed here.
    """
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, gdal.GDT_Byte)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset
def array_change(inlist, outlist):
    """Transpose *inlist* (rows of equal length) column-by-column into *outlist*.

    Column i of *inlist* becomes element i of *outlist*; the mutated
    *outlist* is also returned.
    """
    outlist.extend([row[i] for row in inlist] for i in range(len(inlist[0])))
    return outlist
def array_change2(inlist, outlist):
    """Flatten *inlist* by one level into *outlist* and return it."""
    for row in inlist:
        outlist.extend(row)
    return outlist
def getTifSize(tif):
    """Return (width, height, band count, geotransform, projection) of a raster."""
    dataSet = gdal.Open(tif)
    return (dataSet.RasterXSize,
            dataSet.RasterYSize,
            dataSet.RasterCount,
            dataSet.GetGeoTransform(),
            dataSet.GetProjection())
def partDivisionForBoundary(model,array_num,tif1,folds,divisionSize,tempPath):
    """Split the image into divisionSize tiles, classify each and write it to tempPath.

    Tiles whose output file already exists are skipped (per the config
    note: clear tempPath to force re-prediction).  Returns 1 on completion.
    """
    width,height,bands,geoTrans,proj = getTifSize(tif1)
    partWidth = partHeight = divisionSize
    # Number of tile columns/rows, rounding up for the partial edge tiles.
    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum = width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight
    realName = os.path.split(tif1)[1].split(".")[0]
    tif1 = gdal.Open(tif1)
    for i in range(heightNum):
        for j in range(widthNum):
            startX = partWidth * j
            startY = partHeight * i
            # Clip the tile size at the right/bottom image edges.
            if startX+partWidth<= width and startY+partHeight<=height:
                realPartWidth = partWidth
                realPartHeight = partHeight
            elif startX + partWidth > width and startY+partHeight<=height:
                realPartWidth = width - startX
                realPartHeight = partHeight
            elif startX+partWidth <= width and startY+partHeight > height:
                realPartWidth = partWidth
                realPartHeight = height - startY
            elif startX + partWidth > width and startY+partHeight > height:
                realPartWidth = width - startX
                realPartHeight = height - startY
            outName = realName + str(i)+str(j)+".tif"
            outPath = os.path.join(tempPath,outName)
            if not os.path.exists(outPath):
                driver = gdal.GetDriverByName("GTiff")
                outTif = driver.Create(outPath,realPartWidth,realPartHeight,1,gdal.GDT_Float32)
                # NOTE(review): every tile inherits the full image's
                # geotransform unshifted -- presumably acceptable because the
                # tiles are temporary; confirm if reusing them elsewhere.
                outTif.SetGeoTransform(geoTrans)
                outTif.SetProjection(proj)
                data1 = tif1.ReadAsArray(startX,startY,realPartWidth,realPartHeight)
                # ReadAsArray returns (bands, rows, cols); transpose so
                # clf_predict sees (cols, rows, bands).
                data1 = data1.transpose((2,1,0))
                svmData = clf_predict(model, data1, array_num, folds, outPath)
                outTif.GetRasterBand(1).WriteArray(svmData)
    return 1
def partStretch(tif1,divisionSize,outStratchPath,tempPath):
    """Mosaic the per-tile class rasters in *tempPath* into one byte GeoTIFF.

    Re-derives the same tiling grid as partDivisionForBoundary() and pastes
    each tile back at its (startX, startY) offset in the full-size output.
    """
    width,height,bands,geoTrans,proj = getTifSize(tif1)
    # bands = 1
    partWidth = partHeight = divisionSize
    # Tile counts, rounding up for the partial edge tiles.
    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum = width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight
    realName = os.path.split(tif1)[1].split(".")[0]
    driver = gdal.GetDriverByName("GTiff")
    outTif = driver.Create(outStratchPath,width,height,1,gdal.GDT_Byte)
    if outTif!= None:
        outTif.SetGeoTransform(geoTrans)
        outTif.SetProjection(proj)
        for i in range(heightNum):
            for j in range(widthNum):
                startX = partWidth * j
                startY = partHeight * i
                # Clip the tile size at the right/bottom image edges.
                if startX+partWidth<= width and startY+partHeight<=height:
                    realPartWidth = partWidth
                    realPartHeight = partHeight
                elif startX + partWidth > width and startY+partHeight<=height:
                    realPartWidth = width - startX
                    realPartHeight = partHeight
                elif startX+partWidth <= width and startY+partHeight > height:
                    realPartWidth = partWidth
                    realPartHeight = height - startY
                elif startX + partWidth > width and startY+partHeight > height:
                    realPartWidth = width - startX
                    realPartHeight = height - startY
                partTifName = realName+str(i)+str(j)+".tif"
                partTifPath = os.path.join(tempPath,partTifName)
                divisionImg = gdal.Open(partTifPath)
                # Single-band paste; the loop is kept from multi-band code.
                for k in range(1):
                    data1 = divisionImg.GetRasterBand(k+1).ReadAsArray(0,0,realPartWidth,realPartHeight)
                    outPartBand = outTif.GetRasterBand(k+1)
                    outPartBand.WriteArray(data1,startX,startY)
if __name__ == '__main__':
    # Config file layout (one path per line): image, "class:shapefile"
    # mapping string, one shapefile per class, result path, temp path.
    config_file = 'config_order.txt'
    dirs = []
    with open(config_file) as cfg:
        for line in cfg:
            dirs.append(line.split()[0])
    dir_n = len(dirs)
    shp_n = dir_n - 4
    data_image = dirs[0].replace('\\', '/')
    point_dict = dirs[1]
    shp_paths = []
    for i in range(shp_n):
        shp_paths.append(dirs[i + 2].replace('\\', '/'))
    result_path = dirs[-2].replace('\\', '/')
    temp_path = dirs[-1].replace('\\', '/')
    time1 = time.time()
    print('Start ...')
    # Map each "class:filename" entry to the matching shapefile path.
    dd = {}
    point_dicts = point_dict.split(';')
    for p_dict in point_dicts:
        # BUG FIX: compare with None via "is", not "==".
        if p_dict == '' or p_dict is None:
            continue
        cls_n, d_name = p_dict.split(':')
        for p_path in shp_paths:
            path, p_name = os.path.split(p_path)
            if d_name == p_name:
                dd[cls_n] = p_path
    # Sample training pixels for every class.
    class_list = {}
    for k_cls, v_shp_path in dd.items():
        class_list[k_cls] = getPixels(v_shp_path, data_image)
    model_path = os.path.join(temp_path, 'model.pickle')
    print('Train model ...')
    num, class_final, models, folds = lightlgb_classfiy(class_list, model_path)
    # BUG FIX: the reloaded model list was bound to the name "lgb", which
    # shadowed the lightgbm module alias imported above; use a distinct name.
    ensemble = get_model(model_path)
    slice_path = os.path.join(temp_path, 'slice_temp')
    if not os.path.exists(slice_path):
        os.mkdir(slice_path)
    print('Predict task area ...')
    partDivisionForBoundary(ensemble, num, data_image, folds, 2000, slice_path)
    raster_path = os.path.join(temp_path, 'class_raster.tif')
    partStretch(data_image, 2000, raster_path, slice_path)
    time2 = time.time()
    print((time2 - time1) / 3600)
版本2
数据处理部分不用KFold而改用train_test_split ,这次的时间倒是很快了,但是结果不是很好,调参还是个麻烦事,前面给了链接,大家自己去调整吧
# -*- coding: utf-8 -*-
from osgeo import ogr, osr
from osgeo import gdal
from gdalconst import *
import os, sys, time
import numpy as np
import lightgbm as lgb
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
import pickle
# from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
import pandas as pd
def getPixels(shp, img):
    """Sample raster values in a 10x10 pixel window around each point in *shp*.

    Returns a (n_points * 100, n_bands) array with one row per sampled
    pixel.  The shapefile and the raster must share a coordinate system,
    since point coordinates are turned into pixel offsets through the
    raster's geotransform.
    """
    driver = ogr.GetDriverByName('ESRI Shapefile')
    ds = driver.Open(shp, 0)
    if ds is None:
        print('Could not open ' + shp)
        sys.exit(1)
    layer = ds.GetLayer()
    xValues = []
    yValues = []
    feature = layer.GetNextFeature()
    # Collect the geographic coordinates of all point features.
    while feature:
        geometry = feature.GetGeometryRef()
        x = geometry.GetX()
        y = geometry.GetY()
        xValues.append(x)
        yValues.append(y)
        feature = layer.GetNextFeature()
    gdal.AllRegister()
    ds = gdal.Open(img, GA_ReadOnly)
    if ds is None:
        print('Could not open image')
        sys.exit(1)
    rows = ds.RasterYSize
    cols = ds.RasterXSize
    bands = ds.RasterCount
    transform = ds.GetGeoTransform()
    xOrigin = transform[0]
    yOrigin = transform[3]
    pixelWidth = transform[1]
    pixelHeight = transform[5]
    values = []
    for i in range(len(xValues)):
        x = xValues[i]
        y = yValues[i]
        # Geographic coordinate -> integer pixel offset.
        xOffset = int((x - xOrigin) / pixelWidth)
        yOffset = int((y - yOrigin) / pixelHeight)
        s = str(int(x)) + ' ' + str(int(y)) + ' ' + str(xOffset) + ' ' + str(yOffset) + ' '
        pt = []
        for j in range(bands):
            band = ds.GetRasterBand(j + 1)
            # NOTE(review): a point within 5 pixels of the image edge makes
            # ReadAsArray return None and crash on flatten() -- TODO confirm
            # sample points are always well inside the raster.
            data = band.ReadAsArray(xOffset-5, yOffset-5, 10, 10)
            value = data
            value = value.flatten()
            pt.append(value)
        temp = []
        # Regroup from per-band windows to per-pixel band vectors.
        pt = array_change(pt, temp)
        values.append(pt)
    temp2 = []
    all_values = array_change(values, temp2)
    all_values = np.asarray(all_values)
    temp3 = []
    # Flatten the per-point grouping into one flat list of samples.
    result_values = array_change2(all_values, temp3)
    result_values = np.asarray(result_values)
    return result_values
def lightlgb_classfiy(class_list, model_path):
    """Train a single LightGBM multiclass model on a 95/5 train/val split.

    Parameters
    ----------
    class_list : dict
        Maps class name -> (n_samples, n_bands) sample array.
    model_path : str
        Where the trained booster is pickled.

    Returns
    -------
    tuple
        (number of classes, {class name: integer label}).
    """
    array_num = len(class_list)
    # Stack the per-class samples; dummy first row removed after the loop.
    RGB_arr = np.array([[0, 0, 0]])
    label = np.array([])
    count = 0
    class_final = {}
    for i in sorted(class_list):
        RGB_arr = np.concatenate((RGB_arr, class_list[i]), axis=0)
        array_l = class_list[i].shape[0]
        label = np.append(label, count * np.ones(array_l))
        class_final[i] = count
        count += 1
    RGB_arr = np.delete(RGB_arr, 0, 0)
    params = {
        'boosting_type': 'gbdt',
        'objective': 'multiclass',
        'num_class': 4,
        # BUG FIX: 'metric' appeared twice in the original dict
        # ('multi_error' then 'multi_logloss'); only the second survived,
        # so keep that value explicitly.
        'metric': 'multi_logloss',
        'num_leaves': 50,
        'min_data_in_leaf': 100,
        'min_sum_hessian_in_leaf': 6,
        'max_depth': 6,
        'learning_rate': 0.1,
        'feature_fraction': 0.8,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'min_gain_to_split': 0.2,
        'verbosity': -1,
        'verbose': 5,
        'is_unbalance': True,
        'device': 'gpu'
        # BUG FIX: removed 'nfold'/'stratified'/'shuffle'/'verbose_eval'/
        # 'show_stdv' -- those are lgb.cv() arguments, not training
        # parameters, and were silently ignored by lgb.train().
    }
    # Hold out 5% for early stopping; stratify keeps the class proportions
    # of the split identical to the full label distribution.
    X_train, X_test, y_train, y_test = train_test_split(
        RGB_arr,
        label,
        test_size=0.05,
        random_state=1,
        stratify=label
    )
    train_data = lgb.Dataset(X_train, label=y_train)
    val_data = lgb.Dataset(X_test, label=y_test)
    clf = lgb.train(params,
                    train_data,
                    num_boost_round=10000,
                    valid_sets=val_data,
                    early_stopping_rounds=500)
    with open(model_path, 'wb') as f:
        pickle.dump(clf, f)
    return array_num, class_final
def rf_train(class_list, model_path):
    """Train a RandomForest classifier on the per-class pixel samples.

    Stacks the class sample arrays into one feature matrix with integer
    labels assigned in sorted-key order.  Training is skipped when a
    pickled model already exists at *model_path*.

    Returns (number of classes, {class name: integer label}).
    """
    array_num = len(class_list)
    # Dummy first row so np.concatenate has a seed; removed below.
    RGB_arr = np.array([[0,0,0]])
    label= np.array([])
    count = 0
    class_final = {
    }
    for i in sorted(class_list):
        RGB_arr = np.concatenate((RGB_arr,class_list[i]),axis=0)
        array_l = class_list[i].shape[0]
        label = np.append(label, count * np.ones(array_l))
        class_final[i] = count
        count += 1
    RGB_arr = np.delete(RGB_arr,0,0)
    if os.path.exists(model_path):
        pass
    else:
        rf = RandomForestClassifier(n_estimators=500, max_depth=10, n_jobs=14)
        rf.fit(RGB_arr, label)
        # svc.fit(RGB_arr,label)
        with open(model_path, 'wb') as f:
            pickle.dump(rf, f)
    return array_num, class_final
def get_model(model_path):
    """Unpickle and return the model stored at *model_path*."""
    with open(model_path, mode='rb') as fp:
        model = pickle.load(fp)
    return model
def clf_predict(clf, img_arr, array_num, outPath):
    """Predict a class map for one image tile with a single booster.

    Parameters
    ----------
    clf : object
        Trained model exposing predict() returning per-class probabilities
        and a best_iteration attribute (a lightgbm Booster here).
    img_arr : numpy.ndarray
        Tile pixels, shape (width, height, bands).
    array_num : int
        Number of classes (unused here; kept for interface compatibility).
    outPath : str
        Unused here; kept for interface compatibility.

    Returns
    -------
    numpy.ndarray
        2-D class-id map transposed to (height, width).
    """
    img_reshape = img_arr.reshape([img_arr.shape[0] * img_arr.shape[1], img_arr.shape[2]])
    predict = clf.predict(img_reshape, num_iteration=clf.best_iteration)
    # PERF: vectorized argmax replaces the original per-row
    # list(x).index(max(x)) Python loop; both pick the first maximum.
    predict_r = np.argmax(predict, axis=1)
    class_map = predict_r.reshape([img_arr.shape[0], img_arr.shape[1]])
    return class_map.transpose((1, 0))
def read_img(filename):
    """Read a whole raster; return (projection, geotransform, width, height, data)."""
    dataset=gdal.Open(filename)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0,0,im_width,im_height)
    del dataset
    return im_proj,im_geotrans,im_width, im_height,im_data
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write *im_data* to a GeoTIFF, choosing the GDAL type from the dtype name.

    Accepts a (bands, height, width) cube or a single 2-D band.
    """
    # Substring test: 'int8' also matches 'uint8' (both map to Byte),
    # and 'int16' also matches 'uint16'.
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])
    del dataset
def write_img_(filename, im_proj, im_geotrans, im_data):
    """Write *im_data* to an 8-bit GeoTIFF (class ids fit in a byte).

    Same layout handling as write_img(), but the output is always GDT_Byte.
    """
    # NOTE(review): this datatype is computed but never used -- the Create()
    # call below hard-codes gdal.GDT_Byte.
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, gdal.GDT_Byte)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])
    del dataset
def array_change(inlist, outlist):
    """Transpose *inlist* (rows of equal length) into *outlist* and return it.

    Each column of *inlist* becomes one list appended to *outlist*.
    """
    for column in zip(*inlist):
        outlist.append(list(column))
    return outlist
def array_change2(inlist, outlist):
    """Flatten *inlist* by one level into *outlist*; return the mutated *outlist*."""
    outlist.extend(item for row in inlist for item in row)
    return outlist
def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch a 2-D band to the target range [img_min, img_max].

    The input range is taken from the lower/higher percentiles of *bands*
    (defaults 0/100, i.e. min/max); stretched values are clipped to the
    target range and returned as float32.

    Parameters
    ----------
    bands : numpy.ndarray
        2-D array of pixel values.
    img_min, img_max : number
        Target output range.
    lower_percent, higher_percent : number, optional
        Percentiles defining the input range to stretch from.

    Returns
    -------
    numpy.ndarray
        float32 array with the same shape as *bands*.
    """
    out = np.zeros_like(bands).astype(np.float32)
    a = img_min
    b = img_max
    c = np.percentile(bands[:, :], lower_percent)
    d = np.percentile(bands[:, :], higher_percent)
    # BUG FIX: a constant band (d == c) previously divided by zero (the
    # original left only a commented-out guard); map it to the lower bound.
    if d == c:
        out[:, :] = a
        return out
    t = a + (bands[:, :] - c) * (b - a) / (d - c)
    t[t < a] = a
    t[t > b] = b
    out[:, :] = t
    return out
def getTifSize(tif):
    """Return (width, height, bands, geotransform, projection) for *tif*."""
    handle = gdal.Open(tif)
    width = handle.RasterXSize
    height = handle.RasterYSize
    nbands = handle.RasterCount
    transform = handle.GetGeoTransform()
    projection = handle.GetProjection()
    return width, height, nbands, transform, projection
# @jit(nopython=True)
def partDivisionForBoundary(model,array_num,tif1,divisionSize,tempPath):
    """Split the image into divisionSize tiles, classify each and write it to tempPath.

    Tiles whose output file already exists are skipped (clear tempPath to
    force re-prediction).  Returns 1 on completion.
    """
    width,height,bands,geoTrans,proj = getTifSize(tif1)
    partWidth = partHeight = divisionSize
    # Number of tile columns/rows, rounding up for the partial edge tiles.
    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum = width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight
    realName = os.path.split(tif1)[1].split(".")[0]
    tif1 = gdal.Open(tif1)
    for i in range(heightNum):
        for j in range(widthNum):
            startX = partWidth * j
            startY = partHeight * i
            # Clip the tile size at the right/bottom image edges.
            if startX+partWidth<= width and startY+partHeight<=height:
                realPartWidth = partWidth
                realPartHeight = partHeight
            elif startX + partWidth > width and startY+partHeight<=height:
                realPartWidth = width - startX
                realPartHeight = partHeight
            elif startX+partWidth <= width and startY+partHeight > height:
                realPartWidth = partWidth
                realPartHeight = height - startY
            elif startX + partWidth > width and startY+partHeight > height:
                realPartWidth = width - startX
                realPartHeight = height - startY
            outName = realName + str(i)+str(j)+".tif"
            outPath = os.path.join(tempPath,outName)
            if not os.path.exists(outPath):
                driver = gdal.GetDriverByName("GTiff")
                outTif = driver.Create(outPath,realPartWidth,realPartHeight,1,gdal.GDT_Float32)
                # NOTE(review): every tile inherits the full image's
                # geotransform unshifted -- presumably acceptable because the
                # tiles are temporary; confirm if reusing them elsewhere.
                outTif.SetGeoTransform(geoTrans)
                outTif.SetProjection(proj)
                data1 = tif1.ReadAsArray(startX,startY,realPartWidth,realPartHeight)
                # ReadAsArray returns (bands, rows, cols); transpose so
                # clf_predict sees (cols, rows, bands).
                data1 = data1.transpose((2,1,0))
                svmData = clf_predict(model, data1, array_num, outPath)
                outTif.GetRasterBand(1).WriteArray(svmData)
    return 1
# @jit(nopython=True)
def partStretch(tif1,divisionSize,outStratchPath,tempPath):
    """Mosaic the per-tile class rasters in *tempPath* into one byte GeoTIFF.

    Re-derives the same tiling grid as partDivisionForBoundary() and pastes
    each tile back at its (startX, startY) offset in the full-size output.
    """
    width,height,bands,geoTrans,proj = getTifSize(tif1)
    # bands = 1
    partWidth = partHeight = divisionSize
    # Tile counts, rounding up for the partial edge tiles.
    if width % partWidth > 0 :
        widthNum = width // partWidth + 1
    else:
        widthNum = width // partWidth
    if height % partHeight > 0:
        heightNum = height // partHeight +1
    else:
        heightNum = height // partHeight
    realName = os.path.split(tif1)[1].split(".")[0]
    driver = gdal.GetDriverByName("GTiff")
    outTif = driver.Create(outStratchPath,width,height,1,gdal.GDT_Byte)
    if outTif!= None:
        outTif.SetGeoTransform(geoTrans)
        outTif.SetProjection(proj)
        for i in range(heightNum):
            for j in range(widthNum):
                startX = partWidth * j
                startY = partHeight * i
                # Clip the tile size at the right/bottom image edges.
                if startX+partWidth<= width and startY+partHeight<=height:
                    realPartWidth = partWidth
                    realPartHeight = partHeight
                elif startX + partWidth > width and startY+partHeight<=height:
                    realPartWidth = width - startX
                    realPartHeight = partHeight
                elif startX+partWidth <= width and startY+partHeight > height:
                    realPartWidth = partWidth
                    realPartHeight = height - startY
                elif startX + partWidth > width and startY+partHeight > height:
                    realPartWidth = width - startX
                    realPartHeight = height - startY
                partTifName = realName+str(i)+str(j)+".tif"
                partTifPath = os.path.join(tempPath,partTifName)
                divisionImg = gdal.Open(partTifPath)
                # Single-band paste; the loop is kept from multi-band code.
                for k in range(1):
                    data1 = divisionImg.GetRasterBand(k+1).ReadAsArray(0,0,realPartWidth,realPartHeight)
                    outPartBand = outTif.GetRasterBand(k+1)
                    outPartBand.WriteArray(data1,startX,startY)
if __name__ == '__main__':
    # Config file layout (one path per line): image, "class:shapefile"
    # mapping string, one shapefile per class, result path, temp path.
    config_file = 'config_order.txt'
    dirs = []
    with open(config_file) as cfg:
        for line in cfg:
            dirs.append(line.split()[0])
    dir_n = len(dirs)
    shp_n = dir_n - 4
    data_image = dirs[0].replace('\\', '/')
    point_dict = dirs[1]
    shp_paths = []
    for i in range(shp_n):
        shp_paths.append(dirs[i + 2].replace('\\', '/'))
    result_path = dirs[-2].replace('\\', '/')
    temp_path = dirs[-1].replace('\\', '/')
    time1 = time.time()
    print('Start ...')
    # Map each "class:filename" entry to the matching shapefile path.
    dd = {}
    point_dicts = point_dict.split(';')
    for p_dict in point_dicts:
        # BUG FIX: compare with None via "is", not "==".
        if p_dict == '' or p_dict is None:
            continue
        cls_n, d_name = p_dict.split(':')
        for p_path in shp_paths:
            path, p_name = os.path.split(p_path)
            if d_name == p_name:
                dd[cls_n] = p_path
    # Sample training pixels for every class.
    class_list = {}
    for k_cls, v_shp_path in dd.items():
        class_list[k_cls] = getPixels(v_shp_path, data_image)
    model_path = os.path.join(temp_path, 'model.pickle')
    print('Train model ...')
    num, class_final = lightlgb_classfiy(class_list, model_path)
    # BUG FIX: the reloaded model was bound to the name "lgb", which
    # shadowed the lightgbm module alias imported above; use a distinct name.
    booster = get_model(model_path)
    slice_path = os.path.join(temp_path, 'slice_temp')
    if not os.path.exists(slice_path):
        os.mkdir(slice_path)
    print('Predict task area ...')
    partDivisionForBoundary(booster, num, data_image, 2000, slice_path)
    raster_path = os.path.join(temp_path, 'class_raster.tif')
    partStretch(data_image, 2000, raster_path, slice_path)
    time2 = time.time()
    print((time2 - time1) / 3600)