# Source code for caldera.blocks.node_block
import torch
from torch import nn
from caldera.blocks.aggregator import Aggregator
from caldera.blocks.block import Block
from caldera.data import GraphBatch
class NodeBlock(Block):
    """Node update block that applies a single MLP to the node attributes.

    The wrapped module is registered with the base ``Block`` under the key
    ``"mlp"``. The block is constructed with ``independent=True``; it only
    ever reads node attributes (no edge or global inputs).
    """

    def __init__(self, mlp: nn.Module):
        modules = {"mlp": mlp}
        super().__init__(modules, independent=True)

    def forward(self, node_attr):
        """Return ``mlp(node_attr)``."""
        mlp = self.block_dict["mlp"]
        return mlp(node_attr)

    def forward_from_data(self, data: GraphBatch):
        """Apply the block to the node attributes ``data.x`` of a batch."""
        node_attr = data.x
        return self(node_attr)
class AggregatingNodeBlock(NodeBlock):
    """Node update block that also consumes aggregated edge (and optional global) data.

    For each node, the edge attributes are reduced by ``edge_aggregator`` over
    the edges whose ``edges[1]`` entry refers to that node. The result is
    concatenated (along dim 1) with the node attributes and, if given, the
    per-node gathered global attributes. That concatenation is then passed
    through the MLP.
    """

    def __init__(self, mlp: nn.Module, edge_aggregator: Aggregator):
        super().__init__(mlp)
        self.block_dict["edge_aggregator"] = edge_aggregator
        # NodeBlock.__init__ passes independent=True, but this block reads edge
        # data, so clear the flag (assumes Block stores it as `_independent`).
        self._independent = False

    # TODO: source_to_dest or dest_to_source (isn't this just reversing the graph?)
    def forward(
        self,
        *,
        node_attr,
        edge_attr,
        edges,
        global_attr: torch.Tensor = None,
        node_idx: torch.Tensor = None,
    ):
        """Update node attributes from node, aggregated-edge and optional global data.

        Args:
            node_attr: node attribute tensor; ``node_attr.size(0)`` is used as
                the number of nodes for the aggregation output.
            edge_attr: edge attribute tensor aggregated onto nodes.
            edges: edge index tensor; ``edges[1]`` supplies the node index each
                edge is aggregated into (presumably shape ``(2, n_edges)``
                with row 1 the destination -- confirm against GraphBatch).
            global_attr: optional global attribute tensor, gathered per node
                with ``global_attr[node_idx]``.
            node_idx: index mapping each node to its row of ``global_attr``;
                required when ``global_attr`` is given.

        Returns:
            The MLP output applied to the concatenated features.

        Raises:
            RuntimeError: if ``global_attr`` is given without ``node_idx``.
        """
        # Validate before doing any aggregation work.
        if global_attr is not None and node_idx is None:
            raise RuntimeError("Must provide `node_idx` if providing `global_attr`")

        aggregated_edges = self.block_dict["edge_aggregator"](
            edge_attr, edges[1], dim=0, dim_size=node_attr.size(0)
        )
        features = [node_attr, aggregated_edges]
        if global_attr is not None:
            # Broadcast each graph's global attributes onto its nodes.
            features.append(global_attr[node_idx])
        out = torch.cat(features, dim=1)
        return self.block_dict["mlp"](out)

    def forward_from_data(self, data: GraphBatch):
        """Apply the block to a batch's node (``x``), edge (``e``) and index (``edges``) data."""
        # `forward` takes keyword-only arguments, so they must be passed by name.
        return self(node_attr=data.x, edge_attr=data.e, edges=data.edges)