# Source code for caldera.blocks.node_block

import torch
from torch import nn

from caldera.blocks.aggregator import Aggregator
from caldera.blocks.block import Block
from caldera.data import GraphBatch


class NodeBlock(Block):
    """Block that maps node attributes through a single ``mlp`` module.

    The wrapped module is stored under the ``"mlp"`` key of ``block_dict``
    and the block is registered with the base ``Block`` using
    ``independent=True``.
    """

    def __init__(self, mlp: nn.Module):
        """Register ``mlp`` as this block's only sub-module.

        :param mlp: module applied to node attributes in ``forward``.
        """
        modules = {"mlp": mlp}
        super().__init__(modules, independent=True)

    def forward(self, node_attr):
        """Return ``mlp(node_attr)``.

        :param node_attr: node attribute tensor; presumably one row per
            node (TODO confirm against ``GraphBatch.x``).
        """
        mlp = self.block_dict["mlp"]
        return mlp(node_attr)

    def forward_from_data(self, data: GraphBatch):
        """Run ``forward`` on the node attributes ``data.x`` of a batch."""
        node_attr = data.x
        return self(node_attr)
class AggregatingNodeBlock(NodeBlock):
    """Node block whose MLP also sees aggregated edge (and optional global) attributes.

    For every node, the edge attributes are reduced with ``edge_aggregator``
    (grouped by ``edges[1]``). The result is concatenated with the node's own
    attributes and, optionally, the per-node global attributes
    ``global_attr[node_idx]``. The concatenation is then passed through ``mlp``.
    """

    def __init__(self, mlp: nn.Module, edge_aggregator: Aggregator):
        """Register ``mlp`` and the ``edge_aggregator`` used in ``forward``.

        :param mlp: module applied to the concatenated per-node features.
        :param edge_aggregator: callable invoked as
            ``edge_aggregator(edge_attr, index, dim=0, dim_size=num_nodes)``.
        """
        super().__init__(mlp)
        self.block_dict["edge_aggregator"] = edge_aggregator
        # This block reads edge (and optionally global) data, so it is not
        # independent of the rest of the graph, unlike the parent NodeBlock.
        self._independent = False

    # TODO: source_to_dest or dest_to_source (isn't this just reversing the graph?)
    def forward(
        self,
        *,
        node_attr,
        edge_attr,
        edges,
        global_attr: torch.Tensor = None,
        node_idx: torch.Tensor = None,
    ):
        """Aggregate edges onto nodes, concatenate, and apply ``mlp``.

        :param node_attr: node attributes; ``node_attr.size(0)`` is used as
            the number of nodes and they are concatenated along ``dim=1``.
        :param edge_attr: edge attributes passed to the edge aggregator.
        :param edges: edge index; ``edges[1]`` (presumably the receiving
            node of each edge, TODO confirm) selects the node that each
            edge is aggregated into.
        :param global_attr: optional global attributes, indexed by
            ``node_idx`` to give one row per node.
        :param node_idx: index into ``global_attr`` for each node; required
            whenever ``global_attr`` is given.
        :returns: ``mlp(cat([node_attr, aggregated_edges, (global_attr[node_idx])], dim=1))``.
        :raises RuntimeError: if ``global_attr`` is given without ``node_idx``.
        """
        aggregated = (
            self.block_dict["edge_aggregator"](
                edge_attr, edges[1], dim=0, dim_size=node_attr.size(0)
            ),
        )
        if global_attr is not None:
            if node_idx is None:
                raise RuntimeError(
                    "Must provide `node_idx` if providing `global_attr`"
                )
            # Broadcast each node's graph-level attributes onto that node.
            aggregated += (global_attr[node_idx],)
        out = torch.cat([node_attr, *aggregated], dim=1)
        return self.block_dict["mlp"](out)

    def forward_from_data(self, data: GraphBatch):
        """Run ``forward`` on a batch's node attributes, edge attributes and edges.

        ``forward`` takes keyword-only arguments, so they must be passed by
        name here. A positional call always raises ``TypeError``.
        """
        return self(node_attr=data.x, edge_attr=data.e, edges=data.edges)