Graph Neural Networks (3) GCN Variants and Frameworks (6): Complete Code for the GraphSAGE Example


Complete Code

Code Description

1. Create a new folder and put the four code files listed below into it (data.py, sampling.py, net.py, main.py).

2. Download the raw Cora data files (ind.cora.x, ind.cora.tx, ind.cora.allx, ind.cora.y, ind.cora.ty, ind.cora.ally, ind.cora.graph, ind.cora.test.index).

3. Create a data/cora folder and place the downloaded files in it. Note that CoraData defaults to data_root="../data/cora", i.e. the data folder sits one level above the code folder; either follow the layout sketched below or pass data_root explicitly.

4. Run main.py directly.
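For reference, a plausible directory layout under the default data_root="../data/cora" (the folder names project/ and code/ are only illustrative; the ind.cora.* file names come from data.py):

project/
├── code/
│   ├── data.py
│   ├── sampling.py
│   ├── net.py
│   └── main.py
└── data/
    └── cora/
        ├── ind.cora.x
        ├── ind.cora.tx
        ├── ind.cora.allx
        ├── ind.cora.y
        ├── ind.cora.ty
        ├── ind.cora.ally
        ├── ind.cora.graph
        └── ind.cora.test.index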

data.py

import os
import os.path as osp
import pickle
import numpy as np
import itertools
import scipy.sparse as sp
from collections import namedtuple


Data = namedtuple('Data', ['x', 'y', 'adjacency_dict',
                           'train_mask', 'val_mask', 'test_mask'])


class CoraData(object):
    filenames = ["ind.cora.{}".format(name) for name in
                 ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']]

    def __init__(self, data_root="../data/cora", rebuild=False):
        """Cora数据,包括数据下载,处理,加载等功能
        当数据的缓存文件存在时,将使用缓存文件,否则将下载、进行处理,并缓存到磁盘

        处理之后的数据可以通过属性 .data 获得,它将返回一个数据对象,包括如下几部分:
            * x: 节点的特征,维度为 2708 * 1433,类型为 np.ndarray
            * y: 节点的标签,总共包括7个类别,类型为 np.ndarray
            * adjacency_dict: 邻接信息,,类型为 dict
            * train_mask: 训练集掩码向量,维度为 2708,当节点属于训练集时,相应位置为True,否则False
            * val_mask: 验证集掩码向量,维度为 2708,当节点属于验证集时,相应位置为True,否则False
            * test_mask: 测试集掩码向量,维度为 2708,当节点属于测试集时,相应位置为True,否则False

        Args:
        -------
            data_root: string, optional
                存放数据的目录,原始数据路径: ../data/cora
                缓存数据路径: {data_root}/ch7_cached.pkl
            rebuild: boolean, optional
                是否需要重新构建数据集,当设为True时,如果存在缓存数据也会重建数据

        """
        self.data_root = data_root
        save_file = osp.join(self.data_root, "ch7_cached.pkl")
        if osp.exists(save_file) and not rebuild:
            print("Using Cached file: {}".format(save_file))
            self._data = pickle.load(open(save_file, "rb"))
        else:
            self._data = self.process_data()
            with open(save_file, "wb") as f:
                pickle.dump(self.data, f)
            print("Cached file: {}".format(save_file))

    @property
    def data(self):
        """返回Data数据对象,包括x, y, adjacency, train_mask, val_mask, test_mask"""
        return self._data

    def process_data(self):
        """
        处理数据,得到节点特征和标签,邻接矩阵,训练集、验证集以及测试集
        引用自:https://github.com/rusty1s/pytorch_geometric
        """
        print("Process data ...")
        _, tx, allx, y, ty, ally, graph, test_index = [self.read_data(
            osp.join(self.data_root, name)) for name in self.filenames]
        train_index = np.arange(y.shape[0])
        val_index = np.arange(y.shape[0], y.shape[0] + 500)
        sorted_test_index = sorted(test_index)

        x = np.concatenate((allx, tx), axis=0)
        y = np.concatenate((ally, ty), axis=0).argmax(axis=1)

        # In the raw files the test rows are stored in sorted-index order;
        # permute them so that row i of x and y corresponds to node id i.
        x[test_index] = x[sorted_test_index]
        y[test_index] = y[sorted_test_index]
        num_nodes = x.shape[0]

        train_mask = np.zeros(num_nodes, dtype=bool)
        val_mask = np.zeros(num_nodes, dtype=bool)
        test_mask = np.zeros(num_nodes, dtype=bool)
        train_mask[train_index] = True
        val_mask[val_index] = True
        test_mask[test_index] = True
        adjacency_dict = graph
        print("Node's feature shape: ", x.shape)
        print("Node's label shape: ", y.shape)
        print("Adjacency's shape: ", len(adjacency_dict))
        print("Number of training nodes: ", train_mask.sum())
        print("Number of validation nodes: ", val_mask.sum())
        print("Number of test nodes: ", test_mask.sum())

        return Data(x=x, y=y, adjacency_dict=adjacency_dict,
                    train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

    @staticmethod
    def build_adjacency(adj_dict):
        """根据邻接表创建邻接矩阵"""
        edge_index = []
        num_nodes = len(adj_dict)
        for src, dst in adj_dict.items():
            edge_index.extend([src, v] for v in dst)
            edge_index.extend([v, src] for v in dst)
        # Remove duplicate edges
        edge_index = list(k for k, _ in itertools.groupby(sorted(edge_index)))
        edge_index = np.asarray(edge_index)
        adjacency = sp.coo_matrix((np.ones(len(edge_index)),
                                   (edge_index[:, 0], edge_index[:, 1])),
                                  shape=(num_nodes, num_nodes), dtype="float32")
        return adjacency

    @staticmethod
    def read_data(path):
        """使用不同的方式读取原始数据以进一步处理"""
        name = osp.basename(path)
        if name == "ind.cora.test.index":
            out = np.genfromtxt(path, dtype="int64")
            return out
        else:
            out = pickle.load(open(path, "rb"), encoding="latin1")
            out = out.toarray() if hasattr(out, "toarray") else out
            return out
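
A quick way to sanity-check the loader once the files are in place (an illustrative snippet, run from the code folder; the expected shapes follow from the docstring above):

from data import CoraData

dataset = CoraData().data
print(dataset.x.shape)              # (2708, 1433)
print(dataset.y.shape)              # (2708,)
print(len(dataset.adjacency_dict))  # 2708 adjacency lists, one per node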



sampling.py

import numpy as np


def sampling(src_nodes, sample_num, neighbor_table):
    """根据源节点采样指定数量的邻居节点,注意使用的是有放回的采样;
    某个节点的邻居节点数量少于采样数量时,采样结果出现重复的节点
    
    Arguments:
        src_nodes {list, ndarray} -- 源节点列表
        sample_num {int} -- 需要采样的节点数
        neighbor_table {dict} -- 节点到其邻居节点的映射表
    
    Returns:
        np.ndarray -- 采样结果构成的列表
    """
    results = []
    for sid in src_nodes:
        # Sample with replacement from the node's neighbors
        res = np.random.choice(neighbor_table[sid], size=(sample_num, ))
        results.append(res)
    return np.asarray(results).flatten()


def multihop_sampling(src_nodes, sample_nums, neighbor_table):
    """根据源节点进行多阶采样
    
    Arguments:
        src_nodes {list, np.ndarray} -- 源节点id
        sample_nums {list of int} -- 每一阶需要采样的个数
        neighbor_table {dict} -- 节点到其邻居节点的映射
    
    Returns:
        [list of ndarray] -- 每一阶采样的结果
    """
    sampling_result = [src_nodes]
    for k, hopk_num in enumerate(sample_nums):
        hopk_result = sampling(sampling_result[k], hopk_num, neighbor_table)
        sampling_result.append(hopk_result)
    return sampling_result
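
A tiny sanity check of the sampler on a toy graph (the adjacency dict below is made up purely for illustration):

import numpy as np
from sampling import multihop_sampling

# Hypothetical 4-node path graph: 0 - 1 - 2 - 3
toy_graph = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
result = multihop_sampling(np.array([0, 2]), [2, 2], toy_graph)
# result[0] holds the 2 source nodes,
# result[1] the 2*2 = 4 first-hop samples (with replacement),
# result[2] the 4*2 = 8 second-hop samples.
print([len(r) for r in result])  # [2, 4, 8]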



net.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class NeighborAggregator(nn.Module):
    def __init__(self, input_dim, output_dim, 
                 use_bias=False, aggr_method="mean"):
        """聚合节点邻居

        Args:
            input_dim: 输入特征的维度
            output_dim: 输出特征的维度
            use_bias: 是否使用偏置 (default: {False})
            aggr_method: 邻居聚合方式 (default: {mean})
        """
        super(NeighborAggregator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_bias = use_bias
        self.aggr_method = aggr_method
        self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
        if self.use_bias:
            self.bias = nn.Parameter(torch.Tensor(self.output_dim))
        self.reset_parameters()
    
    def reset_parameters(self):
        init.kaiming_uniform_(self.weight)
        if self.use_bias:
            init.zeros_(self.bias)

    def forward(self, neighbor_feature):
        if self.aggr_method == "mean":
            aggr_neighbor = neighbor_feature.mean(dim=1)
        elif self.aggr_method == "sum":
            aggr_neighbor = neighbor_feature.sum(dim=1)
        elif self.aggr_method == "max":
            aggr_neighbor = neighbor_feature.max(dim=1)[0]  # .max returns (values, indices)
        else:
            raise ValueError("Unknown aggr type, expected sum, max, or mean, but got {}"
                             .format(self.aggr_method))
        
        neighbor_hidden = torch.matmul(aggr_neighbor, self.weight)
        if self.use_bias:
            neighbor_hidden += self.bias

        return neighbor_hidden

    def extra_repr(self):
        return 'in_features={}, out_features={}, aggr_method={}'.format(
            self.input_dim, self.output_dim, self.aggr_method)
    

class SageGCN(nn.Module):
    def __init__(self, input_dim, hidden_dim,
                 activation=F.relu,
                 aggr_neighbor_method="mean",
                 aggr_hidden_method="sum"):
        """SageGCN层定义

        Args:
            input_dim: 输入特征的维度
            hidden_dim: 隐层特征的维度,
                当aggr_hidden_method=sum, 输出维度为hidden_dim
                当aggr_hidden_method=concat, 输出维度为hidden_dim*2
            activation: 激活函数
            aggr_neighbor_method: 邻居特征聚合方法,["mean", "sum", "max"]
            aggr_hidden_method: 节点特征的更新方法,["sum", "concat"]
        """
        super(SageGCN, self).__init__()
        assert aggr_neighbor_method in ["mean", "sum", "max"]
        assert aggr_hidden_method in ["sum", "concat"]
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.aggr_neighbor_method = aggr_neighbor_method
        self.aggr_hidden_method = aggr_hidden_method
        self.activation = activation
        self.aggregator = NeighborAggregator(input_dim, hidden_dim,
                                             aggr_method=aggr_neighbor_method)
        self.weight = nn.Parameter(torch.Tensor(input_dim, hidden_dim))
        self.reset_parameters()
    
    def reset_parameters(self):
        init.kaiming_uniform_(self.weight)

    def forward(self, src_node_features, neighbor_node_features):
        neighbor_hidden = self.aggregator(neighbor_node_features)
        self_hidden = torch.matmul(src_node_features, self.weight)
        
        if self.aggr_hidden_method == "sum":
            hidden = self_hidden + neighbor_hidden
        elif self.aggr_hidden_method == "concat":
            hidden = torch.cat([self_hidden, neighbor_hidden], dim=1)
        else:
            raise ValueError("Expected sum or concat, got {}"
                             .format(self.aggr_hidden))
        if self.activation:
            return self.activation(hidden)
        else:
            return hidden

    def extra_repr(self):
        output_dim = self.hidden_dim if self.aggr_hidden_method == "sum" else self.hidden_dim * 2
        return 'in_features={}, out_features={}, aggr_hidden_method={}'.format(
            self.input_dim, output_dim, self.aggr_hidden_method)


class GraphSage(nn.Module):
    def __init__(self, input_dim, hidden_dim,
                 num_neighbors_list):
        super(GraphSage, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_neighbors_list = num_neighbors_list
        self.num_layers = len(num_neighbors_list)
        self.gcn = nn.ModuleList()
        self.gcn.append(SageGCN(input_dim, hidden_dim[0]))
        for index in range(0, len(hidden_dim) - 2):
            self.gcn.append(SageGCN(hidden_dim[index], hidden_dim[index+1]))
        self.gcn.append(SageGCN(hidden_dim[-2], hidden_dim[-1], activation=None))

    def forward(self, node_features_list):
        # hidden[hop] holds the features of the hop-th sampled layer.
        # Each GNN layer consumes one hop of neighbors, so the list
        # shrinks by one entry per layer; hidden[0] is the final output.
        hidden = node_features_list
        for l in range(self.num_layers):
            next_hidden = []
            gcn = self.gcn[l]
            for hop in range(self.num_layers - l):
                src_node_features = hidden[hop]
                src_node_num = len(src_node_features)
                neighbor_node_features = hidden[hop + 1] \
                    .view((src_node_num, self.num_neighbors_list[hop], -1))
                h = gcn(src_node_features, neighbor_node_features)
                next_hidden.append(h)
            hidden = next_hidden
        return hidden[0]

    def extra_repr(self):
        return 'in_features={}, num_neighbors_list={}'.format(
            self.input_dim, self.num_neighbors_list
        )
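
A quick shape check with random tensors (the numbers are illustrative; they mirror the hyperparameters used in main.py below). The forward pass expects one flattened feature tensor per hop, with sizes determined by the batch size and num_neighbors_list:

import torch
from net import GraphSage

batch = 4
model = GraphSage(input_dim=1433, hidden_dim=[128, 7], num_neighbors_list=[10, 10])
feats = [
    torch.randn(batch, 1433),            # hop 0: the source nodes themselves
    torch.randn(batch * 10, 1433),       # hop 1: 10 sampled neighbors per source node
    torch.randn(batch * 10 * 10, 1433),  # hop 2: 10 neighbors per hop-1 node
]
out = model(feats)
print(out.shape)  # torch.Size([4, 7])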



main.py

#coding: utf-8
"""
GraphSage example on the Cora dataset
"""
import torch

import numpy as np
import torch.nn as nn
import torch.optim as optim
from net import GraphSage
from data import CoraData
from sampling import multihop_sampling

INPUT_DIM = 1433    # input feature dimension
# Note: the number of sampling hops must match the number of GCN layers
HIDDEN_DIM = [128, 7]   # hidden-layer sizes; the last entry is the number of classes
NUM_NEIGHBORS_LIST = [10, 10]   # number of neighbors sampled at each hop
assert len(HIDDEN_DIM) == len(NUM_NEIGHBORS_LIST)
BATCH_SIZE = 16     # batch size
EPOCHS = 20
NUM_BATCH_PER_EPOCH = 20    # number of batches per epoch
LEARNING_RATE = 0.01    # learning rate
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


data = CoraData().data
x = data.x / data.x.sum(1, keepdims=True)  # normalize features so each row sums to 1

train_index = np.where(data.train_mask)[0]
train_label = data.y
test_index = np.where(data.test_mask)[0]
model = GraphSage(input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM,
                  num_neighbors_list=NUM_NEIGHBORS_LIST).to(DEVICE)
print(model)
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=5e-4)


def train():
    model.train()
    for e in range(EPOCHS):
        for batch in range(NUM_BATCH_PER_EPOCH):
            batch_src_index = np.random.choice(train_index, size=(BATCH_SIZE,))
            batch_src_label = torch.from_numpy(train_label[batch_src_index]).long().to(DEVICE)
            batch_sampling_result = multihop_sampling(batch_src_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
            batch_sampling_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in batch_sampling_result]
            batch_train_logits = model(batch_sampling_x)
            loss = criterion(batch_train_logits, batch_src_label)
            optimizer.zero_grad()
            loss.backward()  # backpropagate to compute the gradients
            optimizer.step()  # update the parameters
            print("Epoch {:03d} Batch {:03d} Loss: {:.4f}".format(e, batch, loss.item()))
        test()


def test():
    model.eval()
    with torch.no_grad():
        test_sampling_result = multihop_sampling(test_index, NUM_NEIGHBORS_LIST, data.adjacency_dict)
        test_x = [torch.from_numpy(x[idx]).float().to(DEVICE) for idx in test_sampling_result]
        test_logits = model(test_x)
        test_label = torch.from_numpy(data.y[test_index]).long().to(DEVICE)
        predict_y = test_logits.max(1)[1]
        accuracy = torch.eq(predict_y, test_label).float().mean().item()
        print("Test Accuracy: ", accuracy)


if __name__ == '__main__':
    train()



Results

The final test accuracy is 76.3% (the value after the last epoch in the log below).

Epoch 000 Batch 000 Loss: 1.9297
Epoch 000 Batch 001 Loss: 1.7703
Epoch 000 Batch 002 Loss: 1.7832
Epoch 000 Batch 003 Loss: 1.4397
Epoch 000 Batch 004 Loss: 1.3198
Epoch 000 Batch 005 Loss: 1.2239
Epoch 000 Batch 006 Loss: 1.2981
Epoch 000 Batch 007 Loss: 1.0694
Epoch 000 Batch 008 Loss: 0.5994
Epoch 000 Batch 009 Loss: 0.9742
Epoch 000 Batch 010 Loss: 1.0316
Epoch 000 Batch 011 Loss: 0.8984
Epoch 000 Batch 012 Loss: 1.1576
Epoch 000 Batch 013 Loss: 0.7491
Epoch 000 Batch 014 Loss: 0.5728
Epoch 000 Batch 015 Loss: 0.5756
Epoch 000 Batch 016 Loss: 0.3538
Epoch 000 Batch 017 Loss: 0.6618
Epoch 000 Batch 018 Loss: 0.2654
Epoch 000 Batch 019 Loss: 0.5493
Test Accuracy:  0.7210000157356262
Epoch 001 Batch 000 Loss: 0.3912
Epoch 001 Batch 001 Loss: 0.2587
Epoch 001 Batch 002 Loss: 0.2159
Epoch 001 Batch 003 Loss: 0.1297
Epoch 001 Batch 004 Loss: 0.1339
Epoch 001 Batch 005 Loss: 0.2038
Epoch 001 Batch 006 Loss: 0.1012
Epoch 001 Batch 007 Loss: 0.1285
Epoch 001 Batch 008 Loss: 0.1147
Epoch 001 Batch 009 Loss: 0.1407
Epoch 001 Batch 010 Loss: 0.1206
Epoch 001 Batch 011 Loss: 0.1955
Epoch 001 Batch 012 Loss: 0.0705
Epoch 001 Batch 013 Loss: 0.0626
Epoch 001 Batch 014 Loss: 0.2620
Epoch 001 Batch 015 Loss: 0.0830
Epoch 001 Batch 016 Loss: 0.0239
Epoch 001 Batch 017 Loss: 0.0721
Epoch 001 Batch 018 Loss: 0.0447
Epoch 001 Batch 019 Loss: 0.0936
Test Accuracy:  0.777999997138977
Epoch 002 Batch 000 Loss: 0.0198
Epoch 002 Batch 001 Loss: 0.0379
Epoch 002 Batch 002 Loss: 0.0442
Epoch 002 Batch 003 Loss: 0.0293
Epoch 002 Batch 004 Loss: 0.0395
Epoch 002 Batch 005 Loss: 0.0735
Epoch 002 Batch 006 Loss: 0.0920
Epoch 002 Batch 007 Loss: 0.0250
Epoch 002 Batch 008 Loss: 0.0818
Epoch 002 Batch 009 Loss: 0.0584
Epoch 002 Batch 010 Loss: 0.0673
Epoch 002 Batch 011 Loss: 0.0441
Epoch 002 Batch 012 Loss: 0.0353
Epoch 002 Batch 013 Loss: 0.0423
Epoch 002 Batch 014 Loss: 0.0392
Epoch 002 Batch 015 Loss: 0.0571
Epoch 002 Batch 016 Loss: 0.0413
Epoch 002 Batch 017 Loss: 0.0457
Epoch 002 Batch 018 Loss: 0.0389
Epoch 002 Batch 019 Loss: 0.0349
Test Accuracy:  0.7429999709129333
Epoch 003 Batch 000 Loss: 0.0390
Epoch 003 Batch 001 Loss: 0.0398
Epoch 003 Batch 002 Loss: 0.0427
Epoch 003 Batch 003 Loss: 0.0505
Epoch 003 Batch 004 Loss: 0.0429
Epoch 003 Batch 005 Loss: 0.0486
Epoch 003 Batch 006 Loss: 0.0382
Epoch 003 Batch 007 Loss: 0.0308
Epoch 003 Batch 008 Loss: 0.0745
Epoch 003 Batch 009 Loss: 0.0795
Epoch 003 Batch 010 Loss: 0.0330
Epoch 003 Batch 011 Loss: 0.0441
Epoch 003 Batch 012 Loss: 0.0480
Epoch 003 Batch 013 Loss: 0.0552
Epoch 003 Batch 014 Loss: 0.0404
Epoch 003 Batch 015 Loss: 0.0587
Epoch 003 Batch 016 Loss: 0.0653
Epoch 003 Batch 017 Loss: 0.0492
Epoch 003 Batch 018 Loss: 0.0592
Epoch 003 Batch 019 Loss: 0.0544
Test Accuracy:  0.7940000295639038
Epoch 004 Batch 000 Loss: 0.0572
Epoch 004 Batch 001 Loss: 0.0549
Epoch 004 Batch 002 Loss: 0.0433
Epoch 004 Batch 003 Loss: 0.0715
Epoch 004 Batch 004 Loss: 0.0527
Epoch 004 Batch 005 Loss: 0.0559
Epoch 004 Batch 006 Loss: 0.0504
Epoch 004 Batch 007 Loss: 0.0501
Epoch 004 Batch 008 Loss: 0.0745
Epoch 004 Batch 009 Loss: 0.0684
Epoch 004 Batch 010 Loss: 0.0415
Epoch 004 Batch 011 Loss: 0.0508
Epoch 004 Batch 012 Loss: 0.0440
Epoch 004 Batch 013 Loss: 0.0366
Epoch 004 Batch 014 Loss: 0.0565
Epoch 004 Batch 015 Loss: 0.0580
Epoch 004 Batch 016 Loss: 0.0566
Epoch 004 Batch 017 Loss: 0.0573
Epoch 004 Batch 018 Loss: 0.0391
Epoch 004 Batch 019 Loss: 0.0562
Test Accuracy:  0.7820000052452087
Epoch 005 Batch 000 Loss: 0.0552
Epoch 005 Batch 001 Loss: 0.0613
Epoch 005 Batch 002 Loss: 0.0380
Epoch 005 Batch 003 Loss: 0.0256
Epoch 005 Batch 004 Loss: 0.0467
Epoch 005 Batch 005 Loss: 0.0641
Epoch 005 Batch 006 Loss: 0.0667
Epoch 005 Batch 007 Loss: 0.0626
Epoch 005 Batch 008 Loss: 0.0533
Epoch 005 Batch 009 Loss: 0.0616
Epoch 005 Batch 010 Loss: 0.0732
Epoch 005 Batch 011 Loss: 0.0507
Epoch 005 Batch 012 Loss: 0.0404
Epoch 005 Batch 013 Loss: 0.0391
Epoch 005 Batch 014 Loss: 0.0590
Epoch 005 Batch 015 Loss: 0.0432
Epoch 005 Batch 016 Loss: 0.0471
Epoch 005 Batch 017 Loss: 0.0567
Epoch 005 Batch 018 Loss: 0.0440
Epoch 005 Batch 019 Loss: 0.0601
Test Accuracy:  0.7749999761581421
Epoch 006 Batch 000 Loss: 0.0662
Epoch 006 Batch 001 Loss: 0.0667
Epoch 006 Batch 002 Loss: 0.0438
Epoch 006 Batch 003 Loss: 0.0512
Epoch 006 Batch 004 Loss: 0.0389
Epoch 006 Batch 005 Loss: 0.0486
Epoch 006 Batch 006 Loss: 0.0644
Epoch 006 Batch 007 Loss: 0.0522
Epoch 006 Batch 008 Loss: 0.0652
Epoch 006 Batch 009 Loss: 0.0422
Epoch 006 Batch 010 Loss: 0.0756
Epoch 006 Batch 011 Loss: 0.0555
Epoch 006 Batch 012 Loss: 0.0554
Epoch 006 Batch 013 Loss: 0.0427
Epoch 006 Batch 014 Loss: 0.0596
Epoch 006 Batch 015 Loss: 0.0459
Epoch 006 Batch 016 Loss: 0.0423
Epoch 006 Batch 017 Loss: 0.0644
Epoch 006 Batch 018 Loss: 0.0328
Epoch 006 Batch 019 Loss: 0.0495
Test Accuracy:  0.7929999828338623
Epoch 007 Batch 000 Loss: 0.0593
Epoch 007 Batch 001 Loss: 0.0323
Epoch 007 Batch 002 Loss: 0.0613
Epoch 007 Batch 003 Loss: 0.0900
Epoch 007 Batch 004 Loss: 0.0417
Epoch 007 Batch 005 Loss: 0.0426
Epoch 007 Batch 006 Loss: 0.0399
Epoch 007 Batch 007 Loss: 0.0599
Epoch 007 Batch 008 Loss: 0.0556
Epoch 007 Batch 009 Loss: 0.0557
Epoch 007 Batch 010 Loss: 0.0472
Epoch 007 Batch 011 Loss: 0.0520
Epoch 007 Batch 012 Loss: 0.0506
Epoch 007 Batch 013 Loss: 0.0399
Epoch 007 Batch 014 Loss: 0.0498
Epoch 007 Batch 015 Loss: 0.0668
Epoch 007 Batch 016 Loss: 0.0679
Epoch 007 Batch 017 Loss: 0.0453
Epoch 007 Batch 018 Loss: 0.0363
Epoch 007 Batch 019 Loss: 0.0857
Test Accuracy:  0.7670000195503235
Epoch 008 Batch 000 Loss: 0.0470
Epoch 008 Batch 001 Loss: 0.0369
Epoch 008 Batch 002 Loss: 0.0407
Epoch 008 Batch 003 Loss: 0.0554
Epoch 008 Batch 004 Loss: 0.0341
Epoch 008 Batch 005 Loss: 0.0243
Epoch 008 Batch 006 Loss: 0.0624
Epoch 008 Batch 007 Loss: 0.0523
Epoch 008 Batch 008 Loss: 0.0445
Epoch 008 Batch 009 Loss: 0.0354
Epoch 008 Batch 010 Loss: 0.0339
Epoch 008 Batch 011 Loss: 0.0414
Epoch 008 Batch 012 Loss: 0.0679
Epoch 008 Batch 013 Loss: 0.0264
Epoch 008 Batch 014 Loss: 0.0316
Epoch 008 Batch 015 Loss: 0.0837
Epoch 008 Batch 016 Loss: 0.0405
Epoch 008 Batch 017 Loss: 0.0381
Epoch 008 Batch 018 Loss: 0.0508
Epoch 008 Batch 019 Loss: 0.0288
Test Accuracy:  0.7860000133514404
Epoch 009 Batch 000 Loss: 0.0314
Epoch 009 Batch 001 Loss: 0.0651
Epoch 009 Batch 002 Loss: 0.0315
Epoch 009 Batch 003 Loss: 0.0464
Epoch 009 Batch 004 Loss: 0.0721
Epoch 009 Batch 005 Loss: 0.0806
Epoch 009 Batch 006 Loss: 0.0556
Epoch 009 Batch 007 Loss: 0.0562
Epoch 009 Batch 008 Loss: 0.0417
Epoch 009 Batch 009 Loss: 0.0525
Epoch 009 Batch 010 Loss: 0.0538
Epoch 009 Batch 011 Loss: 0.0551
Epoch 009 Batch 012 Loss: 0.0518
Epoch 009 Batch 013 Loss: 0.2122
Epoch 009 Batch 014 Loss: 0.0734
Epoch 009 Batch 015 Loss: 0.0563
Epoch 009 Batch 016 Loss: 0.0507
Epoch 009 Batch 017 Loss: 0.0595
Epoch 009 Batch 018 Loss: 0.0561
Epoch 009 Batch 019 Loss: 0.0656
Test Accuracy:  0.7720000147819519
Epoch 010 Batch 000 Loss: 0.0576
Epoch 010 Batch 001 Loss: 0.0674
Epoch 010 Batch 002 Loss: 0.0563
Epoch 010 Batch 003 Loss: 0.0459
Epoch 010 Batch 004 Loss: 0.0663
Epoch 010 Batch 005 Loss: 0.0383
Epoch 010 Batch 006 Loss: 0.0371
Epoch 010 Batch 007 Loss: 0.0304
Epoch 010 Batch 008 Loss: 0.0345
Epoch 010 Batch 009 Loss: 0.0369
Epoch 010 Batch 010 Loss: 0.0554
Epoch 010 Batch 011 Loss: 0.0314
Epoch 010 Batch 012 Loss: 0.0504
Epoch 010 Batch 013 Loss: 0.0332
Epoch 010 Batch 014 Loss: 0.0413
Epoch 010 Batch 015 Loss: 0.0450
Epoch 010 Batch 016 Loss: 0.0405
Epoch 010 Batch 017 Loss: 0.0400
Epoch 010 Batch 018 Loss: 0.0333
Epoch 010 Batch 019 Loss: 0.0430
Test Accuracy:  0.7590000033378601
Epoch 011 Batch 000 Loss: 0.0617
Epoch 011 Batch 001 Loss: 0.0429
Epoch 011 Batch 002 Loss: 0.0237
Epoch 011 Batch 003 Loss: 0.0354
Epoch 011 Batch 004 Loss: 0.0579
Epoch 011 Batch 005 Loss: 0.0368
Epoch 011 Batch 006 Loss: 0.0501
Epoch 011 Batch 007 Loss: 0.0988
Epoch 011 Batch 008 Loss: 0.0508
Epoch 011 Batch 009 Loss: 0.0548
Epoch 011 Batch 010 Loss: 0.0421
Epoch 011 Batch 011 Loss: 0.0398
Epoch 011 Batch 012 Loss: 0.0422
Epoch 011 Batch 013 Loss: 0.0387
Epoch 011 Batch 014 Loss: 0.0384
Epoch 011 Batch 015 Loss: 0.0349
Epoch 011 Batch 016 Loss: 0.0794
Epoch 011 Batch 017 Loss: 0.0403
Epoch 011 Batch 018 Loss: 0.0443
Epoch 011 Batch 019 Loss: 0.0461
Test Accuracy:  0.7730000019073486
Epoch 012 Batch 000 Loss: 0.0496
Epoch 012 Batch 001 Loss: 0.0414
Epoch 012 Batch 002 Loss: 0.1526
Epoch 012 Batch 003 Loss: 0.0722
Epoch 012 Batch 004 Loss: 0.0410
Epoch 012 Batch 005 Loss: 0.0426
Epoch 012 Batch 006 Loss: 0.0522
Epoch 012 Batch 007 Loss: 0.0458
Epoch 012 Batch 008 Loss: 0.0510
Epoch 012 Batch 009 Loss: 0.0516
Epoch 012 Batch 010 Loss: 0.0549
Epoch 012 Batch 011 Loss: 0.1203
Epoch 012 Batch 012 Loss: 0.0871
Epoch 012 Batch 013 Loss: 0.0532
Epoch 012 Batch 014 Loss: 0.0786
Epoch 012 Batch 015 Loss: 0.0399
Epoch 012 Batch 016 Loss: 0.0699
Epoch 012 Batch 017 Loss: 0.0534
Epoch 012 Batch 018 Loss: 0.0372
Epoch 012 Batch 019 Loss: 0.0349
Test Accuracy:  0.7609999775886536
Epoch 013 Batch 000 Loss: 0.0508
Epoch 013 Batch 001 Loss: 0.0458
Epoch 013 Batch 002 Loss: 0.0582
Epoch 013 Batch 003 Loss: 0.0498
Epoch 013 Batch 004 Loss: 0.0449
Epoch 013 Batch 005 Loss: 0.0831
Epoch 013 Batch 006 Loss: 0.0478
Epoch 013 Batch 007 Loss: 0.0432
Epoch 013 Batch 008 Loss: 0.0556
Epoch 013 Batch 009 Loss: 0.0454
Epoch 013 Batch 010 Loss: 0.0454
Epoch 013 Batch 011 Loss: 0.0653
Epoch 013 Batch 012 Loss: 0.0605
Epoch 013 Batch 013 Loss: 0.0594
Epoch 013 Batch 014 Loss: 0.0288
Epoch 013 Batch 015 Loss: 0.0368
Epoch 013 Batch 016 Loss: 0.0438
Epoch 013 Batch 017 Loss: 0.0524
Epoch 013 Batch 018 Loss: 0.0304
Epoch 013 Batch 019 Loss: 0.0399
Test Accuracy:  0.7820000052452087
Epoch 014 Batch 000 Loss: 0.0414
Epoch 014 Batch 001 Loss: 0.0525
Epoch 014 Batch 002 Loss: 0.0579
Epoch 014 Batch 003 Loss: 0.0387
Epoch 014 Batch 004 Loss: 0.0389
Epoch 014 Batch 005 Loss: 0.0307
Epoch 014 Batch 006 Loss: 0.0396
Epoch 014 Batch 007 Loss: 0.0454
Epoch 014 Batch 008 Loss: 0.0286
Epoch 014 Batch 009 Loss: 0.0446
Epoch 014 Batch 010 Loss: 0.0260
Epoch 014 Batch 011 Loss: 0.0558
Epoch 014 Batch 012 Loss: 0.0364
Epoch 014 Batch 013 Loss: 0.0367
Epoch 014 Batch 014 Loss: 0.0347
Epoch 014 Batch 015 Loss: 0.0384
Epoch 014 Batch 016 Loss: 0.0487
Epoch 014 Batch 017 Loss: 0.0460
Epoch 014 Batch 018 Loss: 0.0396
Epoch 014 Batch 019 Loss: 0.0514
Test Accuracy:  0.7710000276565552
Epoch 015 Batch 000 Loss: 0.0426
Epoch 015 Batch 001 Loss: 0.0443
Epoch 015 Batch 002 Loss: 0.0371
Epoch 015 Batch 003 Loss: 0.0529
Epoch 015 Batch 004 Loss: 0.0540
Epoch 015 Batch 005 Loss: 0.0399
Epoch 015 Batch 006 Loss: 0.0455
Epoch 015 Batch 007 Loss: 0.0646
Epoch 015 Batch 008 Loss: 0.0736
Epoch 015 Batch 009 Loss: 0.0797
Epoch 015 Batch 010 Loss: 0.0492
Epoch 015 Batch 011 Loss: 0.0238
Epoch 015 Batch 012 Loss: 0.0509
Epoch 015 Batch 013 Loss: 0.0584
Epoch 015 Batch 014 Loss: 0.0371
Epoch 015 Batch 015 Loss: 0.0302
Epoch 015 Batch 016 Loss: 0.0464
Epoch 015 Batch 017 Loss: 0.0342
Epoch 015 Batch 018 Loss: 0.0347
Epoch 015 Batch 019 Loss: 0.0658
Test Accuracy:  0.7699999809265137
Epoch 016 Batch 000 Loss: 0.0401
Epoch 016 Batch 001 Loss: 0.1259
Epoch 016 Batch 002 Loss: 0.0284
Epoch 016 Batch 003 Loss: 0.0648
Epoch 016 Batch 004 Loss: 0.0458
Epoch 016 Batch 005 Loss: 0.0562
Epoch 016 Batch 006 Loss: 0.0393
Epoch 016 Batch 007 Loss: 0.0746
Epoch 016 Batch 008 Loss: 0.0599
Epoch 016 Batch 009 Loss: 0.0744
Epoch 016 Batch 010 Loss: 0.0399
Epoch 016 Batch 011 Loss: 0.0401
Epoch 016 Batch 012 Loss: 0.0475
Epoch 016 Batch 013 Loss: 0.0412
Epoch 016 Batch 014 Loss: 0.0283
Epoch 016 Batch 015 Loss: 0.0399
Epoch 016 Batch 016 Loss: 0.0479
Epoch 016 Batch 017 Loss: 0.0613
Epoch 016 Batch 018 Loss: 0.0453
Epoch 016 Batch 019 Loss: 0.0523
Test Accuracy:  0.7580000162124634
Epoch 017 Batch 000 Loss: 0.0567
Epoch 017 Batch 001 Loss: 0.0576
Epoch 017 Batch 002 Loss: 0.0439
Epoch 017 Batch 003 Loss: 0.0464
Epoch 017 Batch 004 Loss: 0.0406
Epoch 017 Batch 005 Loss: 0.0396
Epoch 017 Batch 006 Loss: 0.0527
Epoch 017 Batch 007 Loss: 0.1051
Epoch 017 Batch 008 Loss: 0.0313
Epoch 017 Batch 009 Loss: 0.0664
Epoch 017 Batch 010 Loss: 0.0714
Epoch 017 Batch 011 Loss: 0.0281
Epoch 017 Batch 012 Loss: 0.0499
Epoch 017 Batch 013 Loss: 0.0415
Epoch 017 Batch 014 Loss: 0.0366
Epoch 017 Batch 015 Loss: 0.0344
Epoch 017 Batch 016 Loss: 0.0508
Epoch 017 Batch 017 Loss: 0.0382
Epoch 017 Batch 018 Loss: 0.0520
Epoch 017 Batch 019 Loss: 0.0315
Test Accuracy:  0.7689999938011169
Epoch 018 Batch 000 Loss: 0.0394
Epoch 018 Batch 001 Loss: 0.0316
Epoch 018 Batch 002 Loss: 0.0460
Epoch 018 Batch 003 Loss: 0.0389
Epoch 018 Batch 004 Loss: 0.1528
Epoch 018 Batch 005 Loss: 0.0462
Epoch 018 Batch 006 Loss: 0.0501
Epoch 018 Batch 007 Loss: 0.0524
Epoch 018 Batch 008 Loss: 0.0335
Epoch 018 Batch 009 Loss: 0.0444
Epoch 018 Batch 010 Loss: 0.0334
Epoch 018 Batch 011 Loss: 0.0454
Epoch 018 Batch 012 Loss: 0.0299
Epoch 018 Batch 013 Loss: 0.0693
Epoch 018 Batch 014 Loss: 0.0376
Epoch 018 Batch 015 Loss: 0.0308
Epoch 018 Batch 016 Loss: 0.0619
Epoch 018 Batch 017 Loss: 0.0425
Epoch 018 Batch 018 Loss: 0.0423
Epoch 018 Batch 019 Loss: 0.0355
Test Accuracy:  0.7739999890327454
Epoch 019 Batch 000 Loss: 0.0387
Epoch 019 Batch 001 Loss: 0.0385
Epoch 019 Batch 002 Loss: 0.0833
Epoch 019 Batch 003 Loss: 0.0438
Epoch 019 Batch 004 Loss: 0.0434
Epoch 019 Batch 005 Loss: 0.0340
Epoch 019 Batch 006 Loss: 0.0378
Epoch 019 Batch 007 Loss: 0.0579
Epoch 019 Batch 008 Loss: 0.0227
Epoch 019 Batch 009 Loss: 0.0831
Epoch 019 Batch 010 Loss: 0.0328
Epoch 019 Batch 011 Loss: 0.0585
Epoch 019 Batch 012 Loss: 0.0485
Epoch 019 Batch 013 Loss: 0.0332
Epoch 019 Batch 014 Loss: 0.0536
Epoch 019 Batch 015 Loss: 0.0326
Epoch 019 Batch 016 Loss: 0.0385
Epoch 019 Batch 017 Loss: 0.0723
Epoch 019 Batch 018 Loss: 0.0614
Epoch 019 Batch 019 Loss: 0.0474
Test Accuracy:  0.7630000114440918


Reprinted from blog.csdn.net/weixin_43360025/article/details/124470107