# Source: https://blog.csdn.net/PC1022/article/details/80440913
# Install an "ignore" filter so that warnings are not reported.
warnings.filterwarnings("ignore")
class Logger(object):
    """File-like object that copies everything written to it into a log file.

    Each message is written both to the stream that was ``sys.stdout`` when
    the instance was created and to ``filename`` (opened in append mode), so
    assigning an instance to ``sys.stdout`` keeps console output visible
    while also saving it.

    Args:
        filename: Path of the log file to append to.
    """

    def __init__(self, filename="Default.log"):
        # Remember the current stdout so output still reaches it after
        # sys.stdout has been replaced by this object.
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        """Write ``message`` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams.

        Without this, text buffered for the log file only reaches disk when
        the file object is closed or garbage-collected, and explicit
        ``sys.stdout.flush()`` calls would silently do nothing.
        """
        self.terminal.flush()
        self.log.flush()
# Replace sys.stdout with a Logger so that everything printed from here on
# is written to both the previous stdout and the log file.
# NOTE(review): this path is relative ("path/...") while the other paths in
# this script start with "/path/..." - confirm which one is intended.
sys.stdout = Logger("path/cnn/log/alexnet_image_show.txt")
# Helper that displays a single image.
def imshow(inp, title=None):
    """Draw a normalized image tensor with matplotlib.

    The tensor is converted to a NumPy array, its axes are reordered with
    ``transpose((1, 2, 0))``, the normalization is undone with the mean/std
    constants below, and the values are clipped to [0, 1] before plotting.

    Args:
        inp: Tensor to display. Presumably channel-first (C, H, W), given
            the transpose - TODO confirm against callers.
        title: Optional title for the plot; skipped when ``None``.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    pixels = inp.numpy().transpose((1, 2, 0))
    # Invert (x - mean) / std, then clamp to the displayable range.
    pixels = np.clip(channel_std * pixels + channel_mean, 0, 1)

    plt.imshow(pixels)
    if title is not None:
        plt.title(title)
    # Short pause after drawing.
    plt.pause(0.001)
# Build the model: an AlexNet created without pretrained weights whose
# classifier is replaced by a new fully connected stack with 2 outputs.
model = models.alexnet(pretrained=False)
classifier_layers = [
    nn.Linear(9216, 4096),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(4096, 2),
]
model.classifier = nn.Sequential(*classifier_layers)
print("model", model)

# Load the saved weights into the model. The map_location callback returns
# each storage unchanged instead of moving it to the device it was saved from.
model.load_state_dict(
    torch.load(
        "/path/cnn/model/alexnet_model.pkl",
        map_location=lambda storage, loc: storage,
    )
)
# Data preprocessing.
# `transforms.Scale` was renamed to `transforms.Resize` and later removed from
# torchvision, so use whichever class this installation provides.
_resize_transform = getattr(transforms, "Resize", None) or transforms.Scale
data_transform = transforms.Compose([
    # Bilinear interpolation (PIL constant 2, which the original passed
    # explicitly) is the default of both classes, so it is not repeated here.
    _resize_transform((224, 224)),
    # NOTE(review): random flipping is normally a training-time augmentation;
    # confirm it is wanted when only displaying predictions.
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create the dataset and a loader that yields one shuffled image at a time.
test_dataset = torchvision.datasets.ImageFolder("/path/data/show", data_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)

# Class names as reported by the dataset.
class_names = test_dataset.classes
# Show the model's predictions for a few images.
def visualize_model(model, num_images):
    """Plot up to ``num_images`` images titled with their predicted class.

    Batches are read from the module-level ``test_loader`` and class labels
    are looked up in the module-level ``class_names``. Images are placed two
    per row with ``plt.subplot`` and drawn with ``imshow``.

    Args:
        model: Network used for prediction; it is switched to eval mode.
        num_images: Maximum number of images to draw. Nothing is drawn when
            this is zero or negative.
    """
    model.eval()
    if num_images <= 0:
        # The stop condition below would never be met; avoid walking the
        # whole loader and requesting an invalid subplot grid.
        return

    # Round up so an odd num_images still has a slot for its last image
    # (num_images // 2 rows would be one slot short, and zero rows for 1).
    rows = (num_images + 1) // 2
    images_so_far = 0
    for inputs, _labels in test_loader:
        outputs = model(Variable(inputs))
        _, predicted = torch.max(outputs.data, 1)
        for j in range(inputs.size(0)):
            images_so_far += 1
            ax = plt.subplot(rows, 2, images_so_far)
            ax.axis('off')
            # int() makes the list index valid whether predicted[j] is a
            # Python number or a single-element tensor.
            ax.set_title('predicted: {}'.format(class_names[int(predicted[j])]))
            imshow(inputs.cpu()[j])
            if images_so_far == num_images:
                return
visualize_model(model, 10)  # Show ten images.
# plt.ioff()  # Would turn interactive mode off.
plt.savefig("/path/cnn/log/pic/alexnet.png")  # Save the figure.
plt.show()