前言
这一节笔记中主要针对继承InMemoryDataset
,一次性加载所有的数据到内存,这种数据集一般不是很大,所以直接一次性加载完毕
构建数据集
1、Dataset
pytorch geometric
构建数据集分两种:
1、继承InMemoryDataset
,一次性加载所有的数据到内存
2、继承Dataset
,分次加载到内存
在自定义的Dataset
的初始化方法种传入数据存放的路径,然后pytorch geometric
会在这个路径下再划分2个文件夹:
1、raw_dir:存放原始数据的路径(一般是csv、mat等格式)
2、processed_dir:存放处理后的数据(一般pt格式,由process方法实现)
但是pytorch中,实际上是没有这两个文件夹的
来看官方文件:
https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets
在示例代码第二行就引入了InMemoryDataset
函数,首先我们去看下,这个函数的使用以及参数
2、InMemoryDataset解读
root
是数据集存储的根目录tansform
与pre_transform
有相同有不同,相同点是,都是一个用于接受数据并返回转换后版本的数据;不同点是,tansform
在每次访问前转化,pre_transform
是保存到磁盘之前进行转化pre_filter
是一个用于接受数据并返回布尔值的函数,用于指示数据对象是否应该保存在最终数据集中
3、官方文档例子
再回来继续看代码,我把说明整合到代码的注释了,另外,这里有一些地方视频中解释的不是很清楚,我结合文章 Hands-on Graph Neural Networks with PyTorch & PyTorch Geometric 增加了一些注释以及自己的理解。
# 官方代码 https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets
import torch
from torch_geometric.data import InMemoryDataset # https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html CLASS InMemoryDataset
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None): # 初始化函数
super(MyOwnDataset, self).__init__(root, transform, pre_transform) # super用于说明MyOwnDataset继承InMemoryDataset初始化结果
self.data, self.slices = torch.load(self.processed_paths[0]) # 详见说明1
@property # 修饰方法,使方法可以像属性一样访问(保护变量/只读函数转变)详见说明2
def raw_file_names(self): # 返回一个包含没有处理的数据的名字的list
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self): # 返回一个包含所有处理过的数据的list
return ['data.pt']
def download(self): # 下载数据集函数,不需要的话直接填充pass
# Download to `self.raw_dir`.
# 整合你的数据成一个包含data的list,然后调用 self.collate()去计算将用于 DataLodadr 的片段
def process(self):
# Read data into huge `Data` list.
data_list = [...] # 创建并读取了数据的列表
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)] # 判断数据对象是否应该保存
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list] # 保存到磁盘之前进行转化
data, slices = self.collate(data_list)# 将数据对象的python列表整理为内部存储格式 torch_geometric.data.InMemoryDataset
torch.save((data, slices), self.processed_paths[0])
1、说明1:
这部分参考了pytorch_geometric自制数据集
制作数据集需要定义data
与slices
:
data
指的是以pytorch_geometric
定义的数据类型Data
构建的图数据集;slices
指的是切片,即数据集中不同graph
的划分,如slices[‘x’]=[0,75,150]
指的是数据集中按照75个节点划分,共三个图,slices[‘y’]
,slices['edge_index ']
以此类推。slices
用于区分不同的graph
与实现shuffle
等功能。值得注意slices
需要int
的tensor
类型,否则DataLoader
不支持切片操作。
2、说明2:
这部分参考了python @property的介绍与使用
我觉得这个文章说明的比up主找的解析更容易理解一些,已经很简洁了,我就不摘到我的文章中了,大家仔细前往观看即可
亚马逊代码例子
# https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/datasets/amazon.py
import torch
from torch_geometric.data import InMemoryDataset, download_url # download_url为了下载数据
from torch_geometric.io import read_npz
class Amazon(InMemoryDataset):
r"""The Amazon Computers and Amazon Photo networks from the
`"Pitfalls of Graph Neural Network Evaluation"
<https://arxiv.org/abs/1811.05868>`_ paper.
Nodes represent goods and edges represent that two goods are frequently
bought together.
Given product reviews as bag-of-words node features, the task is to
map goods to their respective product category.
Args:
root (string): Root directory where the dataset should be saved.
name (string): The name of the dataset (:obj:`"Computers"`,
:obj:`"Photo"`).
transform (callable, optional): A function/transform that takes in an
:obj:`torch_geometric.data.Data` object and returns a transformed
version. The data object will be transformed before every access.
(default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in
an :obj:`torch_geometric.data.Data` object and returns a
transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
"""
url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/'
def __init__(self, root, name, transform=None, pre_transform=None):
self.name = name.lower() # lower将字符串所有大小转小写
assert self.name in ['computers', 'photo'] # 利用断言判断 name 值的范围是不是在 computers/photo 范围内
super(Amazon, self).__init__(root, transform, pre_transform) # 继承初始化值
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return 'amazon_electronics_{}.npz'.format(self.name)
@property
def processed_file_names(self):
return 'data.pt'
def download(self):
download_url(self.url + self.raw_file_names, self.raw_dir)
def process(self):
data = read_npz(self.raw_paths[0]) # 读取 npz 格式数据集
data = data if self.pre_transform is None else self.pre_transform(data)
data, slices = self.collate([data])
torch.save((data, slices), self.processed_paths[0])
def __repr__(self):
return '{}{}()'.format(self.__class__.__name__, self.name.capitalize())
基本上和官网文档给的结构一致,部分处理细节稍微有点不同。其实看到这个地方的时候,我已经有点懵了,因为比如说代码中self.processed_paths[0]
并没有被定义赋值,为什么可以直接调用。
这部分疑惑等后面得到解答之后再回来继续改笔记