ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

组队学习-图神经网络(seventh)

2021-07-06 10:33:04  阅读:329  来源: 互联网

标签:index Data torch seventh 组队 神经网络 edge data 节点


超大规模数据集类的创建

前面我们只接触了数据可全部储存于内存的数据集,这些数据集对应的数据集类在创建对象时就将所有的数据加载到内存。然而如果数据集规模超级大,我们很难有足够大的内存完全存下所有数据。所以需要一个按需加载样本到内存的数据集类。

Dataset类

在PyG中,我们通过继承torch_geometric.data.Dataset基类来自定义一个按需加载样本到内存的数据集类。继承torch_geometric.data.InMemoryDataset基类要实现的方法,继承此基类同样要实现,此外还需实现以下方法:

  • len():返回数据集中的样本的数量
  • get():实现加载单个图的操作。在内部,getitem()返回通过调用get()来获取Data对象,并根据transform参数对它们进行选择性转换。

通过下面的方法我们可以不用定义一个Dataset类,而直接生成一个Dataloader对象,直接用于训练:

from torch_geometric.data import Data, DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)

我们也可以通过下面的方式将一个列表的Data对象组成一个batch:

from torch_geometric.data import Data, Batch
data_list = [Data(...), ..., Data(...)]
loader = Batch.from_data_list(data_list, batch_size=32)

图样本封装成批(BATCHING)与DataLoader类

合并小图组成大图

图可以有任意数量的节点和边,它不是规整的数据结构,因此对图数据封装成批的操作与对图像与序列等数据封装成批的操作不同。PyTorch Geometric中采用的将多个图封装成批的方式是,将小图作为连通组件的形式合并,构建一个大图。于是小图的邻接矩阵存储于大图邻接矩阵的对角线上。大图的邻接矩阵、属性矩阵、预测目标矩阵分别为:
在这里插入图片描述
此方法有以下关键的优势:

  • 依靠消息传递方案的GNN运算不需要被修改,因为消息仍然不能在属于不同图的两个节点之间交换。
  • 没有额外的计算或内存的开销。

小图的属性增值与拼接

将小图存储到大图中时需要对小图的属性做一些修改,一个最显著的例子就是要对节点序号增值。最一般的形式,PyTorch Geometric的DataLoader类会自动对edge_index张量增值,增加的值为当前被处理图的前面的图的累积节点数量。PyTorch Geometric允许我们通过覆盖torch_geometric.data.inc()和torch_geometric.data.cat_dim()函数来实现我们希望的行为。

图的匹配

如果你想在一个Data对象中存储多个图,例如用于图匹配等应用,我们需要确保所有这些图的正确封装成批行为。例如,考虑将两个图,一个源图Gs和一个目标图Gt ,存储在一个Data类中,即

class PairData(Data):
	def __init__(self, edge_index_s, x_s, edge_index_t,x_t):
		super(PairData, self).__init__()
		self.edge_index_s = edge_index_s
		self.x_s = x_s
		self.edge_index_t = edge_index_t
		self.x_t = x_t

这种情况下,edge_index_s应该根据源图Gs的节点数做增值,即x_s.size(0),而edge_index_t应该根据目标图Gt的节点数做增值,即x_t.size(0)。
我们通过一个例子来看一下节点增值:

edge_index_s = torch.tensor([
    [0, 0, 0, 0],
    [1, 2, 3, 4],
])
x_s = torch.randn(5, 16)  # 5 nodes.
edge_index_t = torch.tensor([
    [0, 0, 0],
    [1, 2, 3],
])
x_t = torch.randn(4, 16)  # 4 nodes.

data = PairData(edge_index_s, x_s, edge_index_t, x_t)
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
>>> Batch(edge_index_s=[2, 8], x_s=[10, 16],
          edge_index_t=[2, 6], x_t=[8, 16])

print(batch.edge_index_s)
>>> tensor([[0, 0, 0, 0, 5, 5, 5, 5],
            [1, 2, 3, 4, 6, 7, 8, 9]])

print(batch.edge_index_t)
>>> tensor([[0, 0, 0, 4, 4, 4],
            [1, 2, 3, 5, 6, 7]])

我们可以通过DataLoader中的follow_batch参数来维护batch属性。

二部图

二部图的邻接矩阵定义两种类型的节点之间的连接关系,不同类型节点的节点数量不需要一致,所以边的源节点与目标节点做的增值操作应是不同的。我们需要告诉PyTorch Geometric,它应该在edge_index中独立地为边的源节点和目标节点做增值操作。

def __inc__(self, key, value):
	if key == 'edge_index':
		return torch.tensor([[self.x_s.size(0)],[self.x_t.size(0)]])
	else:
		return super().__inc__(key, value)

其中,edge_index[0]根据x_s.size(0)(边的源节点)做增值运算,而edge_index[1](边的目标节点)根据x_t.size(0)做增值运算。

新的维度上拼接

有时, Data对象的属性需要在一个新的维度上做拼接(如经典的封装成
批),例如,图级别属性或预测目标。具体来说,形状[num_features]
的属性列表应该被返回为[num_examples, num_features],而不是
[num_examples * num_features]。PyTorch Geometric通过在
cat_dim() 中返回一个None的连接维度来实现这一点。

class MyData(Data):
	def __cat_dim__(self, key, item):
		if key == 'foo':
			return None
		else:
			return super().__cat_dim__(key, item)

图预测任务实践

运行情况:
(1)虚拟内存需要128G
(2)使用教程的参数需要运行49个epoch,16个num_workers,每个epoch运行时间大概为3~4分钟,整体运行需要至少5小时
(3)试验运行开始后,程序会在saves 目录下创建一个task_name 参数指定名称的文件夹用于记录试验过程,当saves目录下已经有一个同名的文件夹时,程序会在 task_name 参数末尾增加一个后缀作为文件夹名称。试验运行过程中,所有的print 输出都会写入到试验文件夹下的output 文件,tensorboard.SummaryWriter记录的信息也存储在试验文件夹下的文件中。

参考资料:
1.按需获取的数据集类的创建
2.图预测任务实践

标签:index,Data,torch,seventh,组队,神经网络,edge,data,节点
来源: https://blog.csdn.net/Etc_in_the_great/article/details/118511948

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有