'''
update_all: send messages along all edges, then update all nodes.
DGLGraph.update_all(message_func='default', reduce_func='default', apply_node_func='default')
message_func    -- message function applied on the edges
reduce_func     -- reduce function applied on the nodes
apply_node_func -- apply function applied on the nodes
'''
'''
DGLGraph.send(edges='__ALL__', message_func='default')
edges may be given as:
    int                                      -- one edge, by edge id
    pair of int                              -- one edge, by its endpoints
    int iterable / tensor                    -- multiple edges, by edge id
    pair of int iterables / pair of tensors  -- multiple edges, by their endpoints
Computes messages on the given edges; they can later be fetched from the
destination node's mailbox.
'''
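'''
DGLGraph.recv(v='__ALL__', reduce_func='default', apply_node_func='default', inplace=False)
Reduce the messages waiting in the mailboxes of nodes v with reduce_func
(then apply apply_node_func) to update those nodes.
'''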
import warnings
# Silence every warning; this runs before `import dgl`, so warnings raised
# while importing dgl are hidden too (presumably to keep the demo output clean).
# NOTE(review): this also hides any deprecation warnings for the DGL calls below.
warnings.filterwarnings("ignore")
import torch as th
import dgl
# Build a 3-node graph; node i carries a 1-dim feature "x" = 5., 6., 7.
g=dgl.DGLGraph()
g.add_nodes(3)
g.ndata["x"]=th.tensor([[5.],[6.],[7.]])
# Edges 0->1 and 1->2, given as Python lists (they become edge ids 0 and 1,
# matching the src/dst order in the recorded output further down) ...
g.add_edges([0,1],[1,2])
# ... plus edge 0->2, given as tensors (edge id 2).
src=th.tensor([0])
dst=th.tensor([2])
g.add_edges(src,dst)
print("ndata",g.ndata["x"])
def send_source(edges):
    """Message function: send each edge's source-node feature as message ``m``.

    For inspection it also prints the source- and destination-node features
    of the edge batch it receives.

    Args:
        edges: the edge batch passed in by DGL. ``edges.src["x"]`` and
            ``edges.dst["x"]`` hold feature ``x`` of each edge's source and
            destination node, one row per edge (the recorded output shows
            torch.Size([3, 1]) for the three-edge graph in this script).

    Returns:
        dict: ``{"m": edges.src["x"]}``, the message stored for each edge.
    """
    src_x = edges.src["x"]  # source-node features, one row per edge
    dst_x = edges.dst["x"]  # destination-node features, one row per edge
    print("src", src_x.shape, src_x)
    print("dst", dst_x.shape, dst_x)
    return {"m": src_x}
# Register send_source as the graph's default message function; the later
# g.send(...) call passes no message_func, so it runs send_source.
# NOTE(review): register_message_func/send/recv are from an older DGLGraph API;
# confirm they exist in the installed DGL version.
g.register_message_func(send_source)
# Recorded output: the "ndata" line comes from the print(ndata) call above;
# the "src"/"dst" lines are printed by send_source when g.send(...) runs below.
'''
ndata tensor([[5.],
[6.],
[7.]])
src torch.Size([3, 1]) tensor([[5.],
[6.],
[5.]])
dst torch.Size([3, 1]) tensor([[6.],
[7.],
[7.]])
'''
def simple_reduce(nodes):
    """Reduce function: new ``x`` of each node = sum of its incoming messages.

    Prints the node features, the mailbox and the per-node sum for inspection.

    Args:
        nodes: the node batch passed in by DGL. ``nodes.data["x"]`` is the
            current feature of each node. ``nodes.mailbox["m"]`` stacks the
            incoming messages along dim 1, giving shape
            (nodes in batch, messages per node, feature dim). The recorded
            output shows torch.Size([1, 1, 1]) and torch.Size([1, 2, 1]).

    Returns:
        dict: ``{"x": summed}``, where ``summed`` is the mailbox summed over
        dim 1 (over the messages, not over rows); it replaces feature ``x``.
    """
    print("data_nodes", nodes.data["x"])
    messages = nodes.mailbox["m"]
    print("mailbox", messages.shape, messages)
    # Sum over dim 1 = over every message a node received. It is computed
    # once and reused for both the printout and the return value.
    summed = messages.sum(1)
    print("sum", summed)
    return {"x": summed}
# Register simple_reduce as the default reduce function used by g.recv below.
g.register_reduce_func(simple_reduce)
# Compute messages on every edge (g.edges() presumably returns the
# (src, dst) endpoint tensors, one of the forms send accepts) ...
g.send(g.edges())
# ... then reduce the mailboxes of every node; simple_reduce's {"x": ...}
# result overwrites ndata["x"], as the final print shows.
g.recv(g.nodes())
print("ndata",g.ndata["x"])
# Recorded output: simple_reduce appears to run once per in-degree group
# (node 1 with one message, node 2 with two). Node 0 has no incoming edges and
# shows 0 in the final ndata.
# NOTE(review): presumably DGL zero-fills nodes with empty mailboxes; confirm.
'''
data_nodes tensor([[6.]])
mailbox torch.Size([1, 1, 1]) tensor([[[5.]]])
sum tensor([[5.]])
data_nodes tensor([[7.]])
mailbox torch.Size([1, 2, 1]) tensor([[[6.],
[5.]]])
sum tensor([[11.]])
ndata tensor([[ 0.],
[ 5.],
[11.]])
'''
# Topic: send_recv
# Reposted from www.cnblogs.com/hapyygril/p/11586319.html