class Net(torch.nn.Module):
    """Three-layer GCN followed by a two-layer linear head.

    Graph convolution widths are 2 -> 64 -> 256 -> 512; the head maps
    512 -> 512 -> 10.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(2, 64)
        self.conv2 = GCNConv(64, 256)
        self.conv3 = GCNConv(256, 512)
        self.linear = torch.nn.Linear(512, 512)
        self.linear2 = torch.nn.Linear(512, 10)

    def forward(self, data) -> torch.Tensor:
        """Run the network on ``data`` and return the output of ``linear2``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                attributes (presumably a torch_geometric mini-batch --
                confirm against the data loader).

        Returns:
            Tensor whose last dimension is 10 (the width of ``linear2``).
        """
        x, edge_index = data.x, data.edge_index

        # Every convolution is followed by ReLU and dropout; dropout is
        # active only while the module is in training mode.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, training=self.training)

        # Max-reduce along dim 0, grouping rows by ``data.batch``
        # (presumably yielding one row per graph -- confirm).
        x, _ = scatter_max(x, data.batch, dim=0)

        # No non-linearity is applied between the two linear layers.
        hidden = self.linear(x)
        return self.linear2(hidden)
Output
epoch:1 loss:1.311 Test Accuracy:76.31%% Test Loss:0.714
epoch:2 loss:0.592 Test Accuracy:87.32%% Test Loss:0.398
epoch:3 loss:0.414 Test Accuracy:88.37%% Test Loss:0.371
epoch:4 loss:0.350 Test Accuracy:89.79%% Test Loss:0.326
epoch:5 loss:0.308 Test Accuracy:92.77%% Test Loss:0.235
epoch:6 loss:0.282 Test Accuracy:92.79%% Test Loss:0.243
epoch:7 loss:0.262 Test Accuracy:93.78%% Test Loss:0.209
epoch:8 loss:0.250 Test Accuracy:93.91%% Test Loss:0.203
epoch:9 loss:0.240 Test Accuracy:94.03%% Test Loss:0.199
epoch:10 loss:0.226 Test Accuracy:93.88%% Test Loss:0.196
epoch:11 loss:0.227 Test Accuracy:93.72%% Test Loss:0.199
epoch:12 loss:0.218 Test Accuracy:94.05%% Test Loss:0.204
epoch:13 loss:0.200 Test Accuracy:95.10%% Test Loss:0.168
epoch:14 loss:0.206 Test Accuracy:94.79%% Test Loss:0.169
epoch:15 loss:0.196 Test Accuracy:93.98%% Test Loss:0.191
epoch:16 loss:0.190 Test Accuracy:95.16%% Test Loss:0.158
epoch:17 loss:0.194 Test Accuracy:95.03%% Test Loss:0.165
epoch:18 loss:0.189 Test Accuracy:93.57%% Test Loss:0.209
epoch:19 loss:0.181 Test Accuracy:95.27%% Test Loss:0.158
epoch:20 loss:0.183 Test Accuracy:95.68%% Test Loss:0.148
epoch:21 loss:0.183 Test Accuracy:94.40%% Test Loss:0.179
epoch:22 loss:0.181 Test Accuracy:94.76%% Test Loss:0.180
epoch:23 loss:0.175 Test Accuracy:94.42%% Test Loss:0.180
epoch:24 loss:0.175 Test Accuracy:95.25%% Test Loss:0.158
epoch:25 loss:0.165 Test Accuracy:94.72%% Test Loss:0.175
epoch:26 loss:0.166 Test Accuracy:94.91%% Test Loss:0.173
epoch:27 loss:0.164 Test Accuracy:94.91%% Test Loss:0.157
epoch:28 loss:0.164 Test Accuracy:95.60%% Test Loss:0.145
epoch:29 loss:0.169 Test Accuracy:93.82%% Test Loss:0.213
epoch:30 loss:0.163 Test Accuracy:95.81%% Test Loss:0.139
1.2 Change channels (narrower GCN widths: 2 → 16 → 128 → 512)
class Net(torch.nn.Module):
    """Three-layer GCN with narrower early layers and a two-layer linear head.

    Graph convolution widths are 2 -> 16 -> 128 -> 512; the head maps
    512 -> 512 -> 10.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(2, 16)
        self.conv2 = GCNConv(16, 128)
        self.conv3 = GCNConv(128, 512)
        self.linear = torch.nn.Linear(512, 512)
        self.linear2 = torch.nn.Linear(512, 10)

    def forward(self, data) -> torch.Tensor:
        """Run the network on ``data`` and return the output of ``linear2``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                attributes (presumably a torch_geometric mini-batch --
                confirm against the data loader).

        Returns:
            Tensor whose last dimension is 10 (the width of ``linear2``).
        """
        x, edge_index = data.x, data.edge_index

        # Every convolution is followed by ReLU and dropout; dropout is
        # active only while the module is in training mode.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, training=self.training)

        # Max-reduce along dim 0, grouping rows by ``data.batch``
        # (presumably yielding one row per graph -- confirm).
        x, _ = scatter_max(x, data.batch, dim=0)

        # No non-linearity is applied between the two linear layers.
        hidden = self.linear(x)
        return self.linear2(hidden)
Output
epoch:1 loss:1.472 Test Accuracy:65.24%% Test Loss:1.006
epoch:2 loss:0.881 Test Accuracy:78.00%% Test Loss:0.671
epoch:3 loss:0.671 Test Accuracy:82.65%% Test Loss:0.546
epoch:4 loss:0.555 Test Accuracy:85.86%% Test Loss:0.447
epoch:5 loss:0.486 Test Accuracy:86.88%% Test Loss:0.418
epoch:6 loss:0.434 Test Accuracy:87.71%% Test Loss:0.384
epoch:7 loss:0.415 Test Accuracy:87.27%% Test Loss:0.379
epoch:8 loss:0.394 Test Accuracy:89.62%% Test Loss:0.328
epoch:9 loss:0.374 Test Accuracy:90.03%% Test Loss:0.314
epoch:10 loss:0.360 Test Accuracy:89.25%% Test Loss:0.329
epoch:11 loss:0.341 Test Accuracy:90.24%% Test Loss:0.300
epoch:12 loss:0.339 Test Accuracy:91.68%% Test Loss:0.270
epoch:13 loss:0.310 Test Accuracy:91.38%% Test Loss:0.277
epoch:14 loss:0.308 Test Accuracy:88.32%% Test Loss:0.352
epoch:15 loss:0.299 Test Accuracy:91.20%% Test Loss:0.278
epoch:16 loss:0.297 Test Accuracy:90.05%% Test Loss:0.303
epoch:17 loss:0.280 Test Accuracy:92.57%% Test Loss:0.240
epoch:18 loss:0.281 Test Accuracy:92.48%% Test Loss:0.246
epoch:19 loss:0.271 Test Accuracy:92.20%% Test Loss:0.243
epoch:20 loss:0.271 Test Accuracy:93.02%% Test Loss:0.217
epoch:21 loss:0.264 Test Accuracy:92.10%% Test Loss:0.257
epoch:22 loss:0.262 Test Accuracy:92.76%% Test Loss:0.226
epoch:23 loss:0.264 Test Accuracy:92.85%% Test Loss:0.222
epoch:24 loss:0.259 Test Accuracy:93.21%% Test Loss:0.219
epoch:25 loss:0.249 Test Accuracy:92.30%% Test Loss:0.254
epoch:26 loss:0.250 Test Accuracy:92.40%% Test Loss:0.241
epoch:27 loss:0.246 Test Accuracy:93.26%% Test Loss:0.225
epoch:28 loss:0.242 Test Accuracy:93.48%% Test Loss:0.216
epoch:29 loss:0.242 Test Accuracy:93.39%% Test Loss:0.215
epoch:30 loss:0.244 Test Accuracy:92.83%% Test Loss:0.227
1.3 Change Linear (conv3 widened to 1024 channels; first linear layer is 1024 → 512)
class Net(torch.nn.Module):
    """Three-layer GCN with a 1024-wide last convolution and a linear head.

    Graph convolution widths are 2 -> 64 -> 256 -> 1024; the head maps
    1024 -> 512 -> 10.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(2, 64)
        self.conv2 = GCNConv(64, 256)
        self.conv3 = GCNConv(256, 1024)
        self.linear = torch.nn.Linear(1024, 512)
        self.linear2 = torch.nn.Linear(512, 10)

    def forward(self, data) -> torch.Tensor:
        """Run the network on ``data`` and return the output of ``linear2``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                attributes (presumably a torch_geometric mini-batch --
                confirm against the data loader).

        Returns:
            Tensor whose last dimension is 10 (the width of ``linear2``).
        """
        x, edge_index = data.x, data.edge_index

        # Every convolution is followed by ReLU and dropout; dropout is
        # active only while the module is in training mode.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, training=self.training)

        # Max-reduce along dim 0, grouping rows by ``data.batch``
        # (presumably yielding one row per graph -- confirm).
        x, _ = scatter_max(x, data.batch, dim=0)

        # No non-linearity is applied between the two linear layers.
        hidden = self.linear(x)
        return self.linear2(hidden)
Output
epoch:1 loss:1.250 Test Accuracy:78.13%% Test Loss:0.663
epoch:2 loss:0.586 Test Accuracy:86.19%% Test Loss:0.432
epoch:3 loss:0.414 Test Accuracy:90.19%% Test Loss:0.321
epoch:4 loss:0.337 Test Accuracy:92.22%% Test Loss:0.258
epoch:5 loss:0.295 Test Accuracy:91.20%% Test Loss:0.273
epoch:6 loss:0.271 Test Accuracy:93.45%% Test Loss:0.215
epoch:7 loss:0.256 Test Accuracy:92.12%% Test Loss:0.262
epoch:8 loss:0.242 Test Accuracy:93.78%% Test Loss:0.204
epoch:9 loss:0.233 Test Accuracy:94.60%% Test Loss:0.185
epoch:10 loss:0.225 Test Accuracy:93.80%% Test Loss:0.213
epoch:11 loss:0.226 Test Accuracy:91.60%% Test Loss:0.262
epoch:12 loss:0.217 Test Accuracy:94.44%% Test Loss:0.179
epoch:13 loss:0.212 Test Accuracy:94.99%% Test Loss:0.179
epoch:14 loss:0.206 Test Accuracy:94.28%% Test Loss:0.193
epoch:15 loss:0.205 Test Accuracy:93.70%% Test Loss:0.200
epoch:16 loss:0.192 Test Accuracy:94.47%% Test Loss:0.186
epoch:17 loss:0.192 Test Accuracy:95.10%% Test Loss:0.161
epoch:18 loss:0.187 Test Accuracy:93.66%% Test Loss:0.210
epoch:19 loss:0.186 Test Accuracy:94.01%% Test Loss:0.200
epoch:20 loss:0.184 Test Accuracy:95.57%% Test Loss:0.150
epoch:21 loss:0.190 Test Accuracy:95.52%% Test Loss:0.150
epoch:22 loss:0.173 Test Accuracy:94.39%% Test Loss:0.178
epoch:23 loss:0.178 Test Accuracy:94.41%% Test Loss:0.190
epoch:24 loss:0.170 Test Accuracy:94.87%% Test Loss:0.170
epoch:25 loss:0.177 Test Accuracy:95.66%% Test Loss:0.143
epoch:26 loss:0.167 Test Accuracy:94.76%% Test Loss:0.174
epoch:27 loss:0.168 Test Accuracy:95.22%% Test Loss:0.147
epoch:28 loss:0.168 Test Accuracy:96.14%% Test Loss:0.129
epoch:29 loss:0.165 Test Accuracy:95.70%% Test Loss:0.143
epoch:30 loss:0.160 Test Accuracy:95.86%% Test Loss:0.140
1.4 Add Layer (four GCN layers: 2 → 16 → 64 → 256 → 512)
class Net(torch.nn.Module):
    """Four-layer GCN followed by a two-layer linear head.

    Graph convolution widths are 2 -> 16 -> 64 -> 256 -> 512; the head
    maps 512 -> 512 -> 10.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(2, 16)
        self.conv2 = GCNConv(16, 64)
        self.conv3 = GCNConv(64, 256)
        self.conv4 = GCNConv(256, 512)
        self.linear = torch.nn.Linear(512, 512)
        self.linear2 = torch.nn.Linear(512, 10)

    def forward(self, data) -> torch.Tensor:
        """Run the network on ``data`` and return the output of ``linear2``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                attributes (presumably a torch_geometric mini-batch --
                confirm against the data loader).

        Returns:
            Tensor whose last dimension is 10 (the width of ``linear2``).
        """
        x, edge_index = data.x, data.edge_index

        # Every convolution is followed by ReLU and dropout; dropout is
        # active only while the module is in training mode.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, training=self.training)

        # Max-reduce along dim 0, grouping rows by ``data.batch``
        # (presumably yielding one row per graph -- confirm).
        x, _ = scatter_max(x, data.batch, dim=0)

        # No non-linearity is applied between the two linear layers.
        hidden = self.linear(x)
        return self.linear2(hidden)
Output
epoch:1 loss:1.427 Test Accuracy:71.44%% Test Loss:0.855
epoch:2 loss:0.723 Test Accuracy:80.81%% Test Loss:0.577
epoch:3 loss:0.506 Test Accuracy:86.76%% Test Loss:0.436
epoch:4 loss:0.427 Test Accuracy:89.07%% Test Loss:0.352
epoch:5 loss:0.389 Test Accuracy:90.06%% Test Loss:0.318
epoch:6 loss:0.361 Test Accuracy:89.33%% Test Loss:0.336
epoch:7 loss:0.336 Test Accuracy:90.64%% Test Loss:0.318
epoch:8 loss:0.321 Test Accuracy:90.05%% Test Loss:0.316
epoch:9 loss:0.313 Test Accuracy:91.52%% Test Loss:0.265
epoch:10 loss:0.299 Test Accuracy:90.85%% Test Loss:0.303
epoch:11 loss:0.287 Test Accuracy:92.03%% Test Loss:0.258
epoch:12 loss:0.279 Test Accuracy:91.43%% Test Loss:0.271
epoch:13 loss:0.276 Test Accuracy:92.56%% Test Loss:0.238
epoch:14 loss:0.268 Test Accuracy:92.25%% Test Loss:0.253
epoch:15 loss:0.261 Test Accuracy:92.45%% Test Loss:0.248
epoch:16 loss:0.254 Test Accuracy:92.98%% Test Loss:0.218
epoch:17 loss:0.246 Test Accuracy:93.73%% Test Loss:0.203
epoch:18 loss:0.244 Test Accuracy:92.39%% Test Loss:0.241
epoch:19 loss:0.243 Test Accuracy:93.30%% Test Loss:0.216
epoch:20 loss:0.236 Test Accuracy:93.71%% Test Loss:0.204
epoch:21 loss:0.235 Test Accuracy:93.94%% Test Loss:0.195
epoch:22 loss:0.228 Test Accuracy:93.81%% Test Loss:0.196
epoch:23 loss:0.229 Test Accuracy:93.58%% Test Loss:0.206
epoch:24 loss:0.225 Test Accuracy:93.66%% Test Loss:0.206
epoch:25 loss:0.222 Test Accuracy:94.21%% Test Loss:0.187
epoch:26 loss:0.219 Test Accuracy:92.57%% Test Loss:0.244
epoch:27 loss:0.223 Test Accuracy:94.35%% Test Loss:0.182
epoch:28 loss:0.210 Test Accuracy:93.73%% Test Loss:0.202
epoch:29 loss:0.212 Test Accuracy:94.18%% Test Loss:0.187
epoch:30 loss:0.208 Test Accuracy:94.16%% Test Loss:0.187
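1.5 Decrease Layer (note: the listing below still has four GCN layers, 2 → 16 → 64 → 256 → 512, the same as in 1.4)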
class Net(torch.nn.Module):
    """Four-layer GCN followed by a two-layer linear head.

    Graph convolution widths are 2 -> 16 -> 64 -> 256 -> 512; the head
    maps 512 -> 512 -> 10.

    NOTE(review): this listing belongs to section 1.5, yet its layers are
    the same as the four-convolution network listed in section 1.4 --
    confirm which architecture actually produced the section 1.5 results.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(2, 16)
        self.conv2 = GCNConv(16, 64)
        self.conv3 = GCNConv(64, 256)
        self.conv4 = GCNConv(256, 512)
        self.linear = torch.nn.Linear(512, 512)
        self.linear2 = torch.nn.Linear(512, 10)

    def forward(self, data) -> torch.Tensor:
        """Run the network on ``data`` and return the output of ``linear2``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``
                attributes (presumably a torch_geometric mini-batch --
                confirm against the data loader).

        Returns:
            Tensor whose last dimension is 10 (the width of ``linear2``).
        """
        x, edge_index = data.x, data.edge_index

        # Every convolution is followed by ReLU and dropout; dropout is
        # active only while the module is in training mode.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, training=self.training)

        # Max-reduce along dim 0, grouping rows by ``data.batch``
        # (presumably yielding one row per graph -- confirm).
        x, _ = scatter_max(x, data.batch, dim=0)

        # No non-linearity is applied between the two linear layers.
        hidden = self.linear(x)
        return self.linear2(hidden)
Output
epoch:1 loss:1.294 Test Accuracy:73.58%% Test Loss:0.827
epoch:2 loss:0.765 Test Accuracy:80.89%% Test Loss:0.593
epoch:3 loss:0.600 Test Accuracy:84.10%% Test Loss:0.498
epoch:4 loss:0.532 Test Accuracy:85.91%% Test Loss:0.436
epoch:5 loss:0.474 Test Accuracy:86.62%% Test Loss:0.429
epoch:6 loss:0.428 Test Accuracy:88.02%% Test Loss:0.378
epoch:7 loss:0.389 Test Accuracy:90.22%% Test Loss:0.316
epoch:8 loss:0.370 Test Accuracy:89.78%% Test Loss:0.332
epoch:9 loss:0.345 Test Accuracy:91.19%% Test Loss:0.281
epoch:10 loss:0.325 Test Accuracy:89.05%% Test Loss:0.340
epoch:11 loss:0.307 Test Accuracy:92.12%% Test Loss:0.261
epoch:12 loss:0.294 Test Accuracy:92.69%% Test Loss:0.231
epoch:13 loss:0.274 Test Accuracy:91.27%% Test Loss:0.297
epoch:14 loss:0.272 Test Accuracy:93.50%% Test Loss:0.215
epoch:15 loss:0.264 Test Accuracy:92.89%% Test Loss:0.238
epoch:16 loss:0.257 Test Accuracy:93.26%% Test Loss:0.215
epoch:17 loss:0.249 Test Accuracy:91.77%% Test Loss:0.269
epoch:18 loss:0.251 Test Accuracy:92.26%% Test Loss:0.248
epoch:19 loss:0.246 Test Accuracy:92.76%% Test Loss:0.231
epoch:20 loss:0.243 Test Accuracy:92.68%% Test Loss:0.230
epoch:21 loss:0.237 Test Accuracy:93.94%% Test Loss:0.194
epoch:22 loss:0.229 Test Accuracy:93.74%% Test Loss:0.205
epoch:23 loss:0.230 Test Accuracy:91.84%% Test Loss:0.262
epoch:24 loss:0.228 Test Accuracy:93.90%% Test Loss:0.190
epoch:25 loss:0.221 Test Accuracy:94.03%% Test Loss:0.195
epoch:26 loss:0.218 Test Accuracy:94.27%% Test Loss:0.194
epoch:27 loss:0.217 Test Accuracy:94.02%% Test Loss:0.189
epoch:28 loss:0.221 Test Accuracy:94.17%% Test Loss:0.187
epoch:29 loss:0.215 Test Accuracy:94.27%% Test Loss:0.188
epoch:30 loss:0.212 Test Accuracy:94.44%% Test Loss:0.190