# import os
# from PIL import Image
# import cv2
# import numpy as np
# rgb_path='/mnt/sdb1/fenghaixia/jcrhsix/rgb/'
# h_path='/mnt/sdb1/fenghaixia/jcrhsix/dsm/'
# hrgb_path='/mnt/sdb1/fenghaixia/jcrhsix/hhhrgb/'
# savepath='/mnt/sdb1/fenghaixia/jcrhsix/aver/'
# filelist = os.listdir(rgb_path)
# for f in filelist:
# rgb_name=rgb_path+f.strip()
# h_name=h_path+f.strip()
# hrgb_name=hrgb_path+f.strip()
# # print(os.path.exists(rgb_name))
# # print(os.path.exists(h_name))
# # print(os.path.exists(hrgb_name))
# if os.path.exists(rgb_name) and os.path.exists(h_name) and os.path.exists(hrgb_name):
# # im = Image.open(path + item) #打开图片
# rgb_im = cv2.imread(rgb_name)
# h_im=cv2.imread(h_name)
# hrgb_im=cv2.imread(hrgb_name)
# aver_im=(rgb_im+h_im+hrgb_im)/3.0
# aver_im[aver_im>4.0]=255
# aver_im[aver_im<=4.0]=0
# cv2.imwrite(savepath + f, aver_im)
# print(f)
# print('finish')
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import random
import shutil
from matplotlib.pyplot import MultipleLocator
#从pyplot导入MultipleLocator类,这个类用于设置刻度间隔
from time import time
from PIL import Image
from utils.utils_metrics import compute_mIoU
from utils.utils_metrics import compute_IoU
from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool
# Images per GPU used to size the test-time-augmentation batch: TTAFrame multiplies
# this by the visible CUDA device count and only runs its 8-view TTA path when the
# product is >= 8.
BATCHSIZE_PER_CARD = 16
class TTAFrame():
    """Wrap a segmentation network for 8-view test-time augmentation (TTA).

    The network is moved to CUDA and wrapped in ``torch.nn.DataParallel``
    across every visible GPU.  Prediction runs the model on eight
    rotated/flipped copies of an image, undoes each transform on the
    outputs, and sums them.
    """

    def __init__(self, net):
        """Instantiate ``net`` on the GPU(s).

        Args:
            net: zero-argument callable returning a ``torch.nn.Module``
                (presumably one of the model classes from ``networks.*``).
        """
        self.net = net().cuda()
        self.net = torch.nn.DataParallel(
            self.net, device_ids=range(torch.cuda.device_count()))

    def test_one_img_from_path(self, path, evalmode=True):
        """Predict the TTA-summed mask for the image stored at ``path``.

        Args:
            path: image file readable by ``cv2.imread``.
            evalmode: if true, switch the network to eval mode first.

        Returns:
            numpy array: sum of the 8 de-augmented network outputs
            (presumably ``(H, W)`` when the network emits one channel).

        Raises:
            RuntimeError: if fewer than 8 images fit in one batch
                (``device_count() * BATCHSIZE_PER_CARD < 8``), instead of
                silently returning ``None``.
            FileNotFoundError: if the image cannot be read.
        """
        if evalmode:
            self.net.eval()
        batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
        if batchsize >= 8:
            return self.test_one_img_from_path_1(path)
        raise RuntimeError(
            "8-view TTA needs a batch of at least 8 images, got {} "
            "({} GPU(s) x {} per card)".format(
                batchsize, torch.cuda.device_count(), BATCHSIZE_PER_CARD))

    @staticmethod
    def _augment(img):
        """Build the normalized 8-view TTA batch from a 3-D ``(H, W, C)`` image.

        View order: [orig, rot90], then those two flipped vertically, then all
        four flipped horizontally.  ``H`` must equal ``W`` because the image and
        its 90-degree rotation are stacked into one array.

        Returns:
            C-contiguous float32 array of shape ``(8, C, H, W)`` holding
            ``v / 255 * 3.2 - 1.6`` for each pixel value ``v``.
        """
        img90 = np.array(np.rot90(img))
        views = np.concatenate([img[None], img90[None]])
        views = np.concatenate([views, views[:, ::-1]])     # + vertical flips (axis H)
        views = np.concatenate([views, views[:, :, ::-1]])  # + horizontal flips (axis W)
        batch = np.array(views.transpose(0, 3, 1, 2), np.float32) / 255.0 * 3.2 - 1.6
        return np.ascontiguousarray(batch)

    @staticmethod
    def _merge_tta(mask):
        """Undo the ``_augment`` transforms on an ``(8, H, W)`` prediction and sum the views."""
        mask = mask[:4] + mask[4:, :, ::-1]  # undo the horizontal flips
        mask = mask[:2] + mask[2:, ::-1]     # undo the vertical flips
        # rot90 followed by a 180-degree flip is the inverse of the earlier np.rot90.
        return mask[0] + np.rot90(mask[1])[::-1, ::-1]

    def test_one_img_from_path_1(self, path):
        """Run 8-view TTA on one image in a single forward pass.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read ``path``.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError("cannot read image: {}".format(path))
        batch = torch.from_numpy(self._augment(img)).cuda()
        # Inference only: skip autograd bookkeeping to save GPU memory.
        with torch.no_grad():
            mask = self.net(batch).squeeze().cpu().numpy()
        return self._merge_tta(mask)

    def load(self, path):
        """Load a state dict from ``path`` into the DataParallel-wrapped network.

        The checkpoint's keys must match ``self.net``, i.e. carry the
        ``module.`` prefix that DataParallel adds.  ``torch.load`` unpickles
        the file, so only load trusted checkpoints.
        """
        self.net.load_state_dict(torch.load(path))
# source = 'dataset/test/'
def saveList(pathName, out_path="./dataset/gt.txt"):
    """Append the id of each name in ``pathName`` to ``out_path``, one per line.

    The id is the text before the first ``.`` (``"a.b.png"`` -> ``"a"``).
    The file is opened once in append mode.  It is not opened at all when
    ``pathName`` is empty, so an empty input never creates or touches it.

    Args:
        pathName: iterable of file names, e.g. the result of ``os.listdir``.
        out_path: list file to append to; defaults to the evaluation id list.
    """
    lines = [file_name.split(".")[0] + "\n" for file_name in pathName]
    if not lines:
        return
    with open(out_path, "a") as f:
        f.writelines(lines)
def savetrainList(pathName, out_path="./dataset/gt_train.txt"):
    """Append the id of each name in ``pathName`` to the training id list.

    The id is the text before the first ``.`` (``"a.b.png"`` -> ``"a"``).
    The file is opened once in append mode.  It is not opened at all when
    ``pathName`` is empty, so an empty input never creates or touches it.

    Args:
        pathName: iterable of file names, e.g. the result of ``os.listdir``.
        out_path: list file to append to; defaults to the training id list.
    """
    lines = [file_name.split(".")[0] + "\n" for file_name in pathName]
    if not lines:
        return
    with open(out_path, "a") as f:
        f.writelines(lines)
def dirList(gt_dir, path_list):
    """Record the contents of every sub-directory named in ``path_list``.

    For each entry that is a directory under ``gt_dir``, the names it
    contains are passed to ``saveList``, which appends their ids to the
    evaluation id list.  Entries that are not directories are ignored here.

    Args:
        gt_dir: base directory the entries are relative to.
        path_list: entry names, e.g. from ``os.listdir(gt_dir)``.
    """
    for entry in path_list:
        path = os.path.join(gt_dir, entry)
        if os.path.isdir(path):
            saveList(os.listdir(path))
print("开始运行!")
# Evaluation log. It is opened up front, so a missing submits/ directory fails
# before any work is done, and the with-block guarantees it is closed.
with open('submits/count_low_pic.log', 'w') as mylog:
    # Evaluation switch: 0 or 2 runs the test-set mIoU computation below.
    miou_mode = 2
    # Class count passed to compute_mIoU: nonwater + water.
    num_classes = 2
    # Class names in class-index order, passed to compute_mIoU.
    name_classes = ["nonwater", "water"]
    # Ground-truth mask directory and predicted mask directory to compare.
    gt_dir = '/mnt/sdb1/fenghaixia/dddrgb/dataset/real/'
    pred_dir = '/mnt/sdb1/fenghaixia/jcrhsix/aver/'
    # Image-id list filled by saveList/dirList and read back below.
    gt_list_path = os.path.join('./dataset/', "gt.txt")

    # Truncate the id list (saveList/dirList only append to it) and close the
    # handle right away instead of leaving it open for the rest of the run.
    with open(gt_list_path, 'w'):
        pass
    path_list = sorted(os.listdir(pred_dir))
    dirList(pred_dir, path_list)
    saveList(path_list)
    with open(gt_list_path, 'r') as id_file:
        image_ids = id_file.read().splitlines()

    train_mIou = []
    train_mPA = []
    test_mIou = []
    test_mPA = []
    if miou_mode == 0 or miou_mode == 2:
        print('计算测试miou')
        # compute_mIoU is project-local (utils.utils_metrics); the unpacking order
        # of its four return values is kept from the original call.
        test_mIou, test_mPA, test_miou, test_mpa = compute_mIoU(
            gt_dir, pred_dir, image_ids, num_classes, name_classes)
        # One entry per line so the log is readable.
        mylog.write(' test_mIoU: ' + str(test_miou) + '\n')
        mylog.write(' test_mPA: ' + str(test_mpa) + '\n')
        print(' test_mIoU: ' + str(test_miou))
    mylog.write('Finish!\n')
    print('Finish!')
# 3选2 ("2 out of 3"): per-pixel majority vote over three binary prediction masks.
import os
from PIL import Image
import cv2
import numpy as np

# Input mask directories (one per model) and the output directory. Inputs are
# matched by identical file name across the three input directories.
rgb_path = '/mnt/sdb1/fenghaixia/jcrhsix/rgb/'
h_path = '/mnt/sdb1/fenghaixia/jcrhsix/dsm/'
hrgb_path = '/mnt/sdb1/fenghaixia/jcrhsix/hhhrgb/'
savepath = '/mnt/sdb1/fenghaixia/jcrhsix/aver/'
# A pixel value >= VOTE_THRESHOLD is a positive vote. At least MIN_VOTES of the
# three inputs must agree for the output pixel to be 255; otherwise it is 0.
VOTE_THRESHOLD = 4
MIN_VOTES = 2

# NOTE(review): savepath is the same directory the mIoU evaluation earlier in this
# file reads as pred_dir. Running the file top to bottom therefore evaluates
# whatever was already there before this loop rewrites it. Confirm the order.
# cv2.imwrite fails silently (returns False) when the directory is missing.
os.makedirs(savepath, exist_ok=True)
filelist = os.listdir(rgb_path)
for f in filelist:
    name = f.strip()
    in_paths = [rgb_path + name, h_path + name, hrgb_path + name]
    if not all(os.path.exists(p) for p in in_paths):
        continue
    masks = [cv2.imread(p) for p in in_paths]
    # cv2.imread returns None for files it cannot decode.
    if any(m is None for m in masks):
        print('skip (unreadable): ' + f)
        continue
    if len({m.shape for m in masks}) != 1:
        print('skip (shape mismatch): ' + f)
        continue
    # Per-element vote count in 0..3, then a 0/255 uint8 output mask.
    votes = sum((m >= VOTE_THRESHOLD).astype(np.uint8) for m in masks)
    voted = np.where(votes >= MIN_VOTES, 255, 0).astype(np.uint8)
    if not cv2.imwrite(savepath + f, voted):
        print('failed to write: ' + savepath + f)
        continue
    print(f)
print('finish')