Getting Started =============== Installation ------------ .. code-block:: py class Network(torch.nn.Module): def __init__( self, latent_sizes=(16, 16, 1), depths=(1, 1, 1), dropout: float = None, pass_global_to_edge: bool = True, pass_global_to_node: bool = True, edge_to_node_aggregators=tuple(["add", "max", "mean", "min"]), edge_to_global_aggregators=tuple(["add", "max", "mean", "min"]), node_to_global_aggregators=tuple(["add", "max", "mean", "min"]), aggregator_activation=defaults.activation, ): super().__init__() self.config = { "latent_size": { "node": latent_sizes[1], "edge": latent_sizes[0], "global": latent_sizes[2], "core_node_block_depth": depths[0], "core_edge_block_depth": depths[1], "core_global_block_depth": depths[2], }, "node_block_aggregator": edge_to_node_aggregators, "global_block_to_node_aggregator": node_to_global_aggregators, "global_block_to_edge_aggregator": edge_to_global_aggregators, "aggregator_activation": aggregator_activation, "pass_global_to_edge": pass_global_to_edge, "pass_global_to_node": pass_global_to_node, } self.encoder = GraphEncoder( EdgeBlock(Flex(MLP)(Flex.d(), latent_sizes[0], dropout=dropout)), NodeBlock(Flex(MLP)(Flex.d(), latent_sizes[1], dropout=dropout)), GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2], dropout=dropout)), ) edge_layers = [self.config["latent_size"]["edge"]] * self.config["latent_size"][ "core_edge_block_depth" ] node_layers = [self.config["latent_size"]["node"]] * self.config["latent_size"][ "core_node_block_depth" ] global_layers = [self.config["latent_size"]["global"]] * self.config[ "latent_size" ]["core_global_block_depth"] self.core = GraphCore( AggregatingEdgeBlock( torch.nn.Sequential( Flex(MLP)(Flex.d(), *edge_layers, dropout=dropout, layer_norm=True), # Flex(torch.nn.Linear)(Flex.d(), edge_layers[-1]) ) ), AggregatingNodeBlock( torch.nn.Sequential( Flex(MLP)(Flex.d(), *node_layers, dropout=dropout, layer_norm=True), # Flex(torch.nn.Linear)(Flex.d(), node_layers[-1]) ), Flex(MultiAggregator)( Flex.d(), self.config["node_block_aggregator"], activation=self.config["aggregator_activation"], ), ), AggregatingGlobalBlock( torch.nn.Sequential( Flex(MLP)( Flex.d(), *global_layers, dropout=dropout, layer_norm=True ), # Flex(torch.nn.Linear)(Flex.d(), global_layers[-1]) ), edge_aggregator=Flex(MultiAggregator)( Flex.d(), self.config["global_block_to_edge_aggregator"], activation=self.config["aggregator_activation"], ), node_aggregator=Flex(MultiAggregator)( Flex.d(), self.config["global_block_to_node_aggregator"], activation=self.config["aggregator_activation"], ), ), pass_global_to_edge=self.config["pass_global_to_edge"], pass_global_to_node=self.config["pass_global_to_node"], ) self.decoder = GraphEncoder( EdgeBlock( Flex(MLP)(Flex.d(), latent_sizes[0], latent_sizes[0], dropout=dropout) ), NodeBlock( Flex(MLP)(Flex.d(), latent_sizes[1], latent_sizes[1], dropout=dropout) ), GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2])), ) self.output_transform = GraphEncoder( EdgeBlock( torch.nn.Sequential( Flex(torch.nn.Linear)(Flex.d(), 1), torch.nn.Sigmoid() ) ), NodeBlock( torch.nn.Sequential( Flex(torch.nn.Linear)(Flex.d(), 1), torch.nn.Sigmoid() ) ), GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), 1)), ) def forward(self, data, steps, save_all: bool = False): # encoded e, x, g = self.encoder(data) data = GraphBatch(x, e, g, data.edges, data.node_idx, data.edge_idx) # graph topography data edges = data.edges node_idx = data.node_idx edge_idx = data.edge_idx latent0 = data meta = (edges, node_idx, edge_idx) outputs = [] for _ in range(steps): # core processing step e = torch.cat([latent0.e, e], dim=1) x = torch.cat([latent0.x, x], dim=1) g = torch.cat([latent0.g, g], dim=1) data = GraphBatch(x, e, g, *meta) e, x, g = self.core(data) # decode data = GraphBatch(x, e, g, *meta) _e, _x, _g = self.decoder(data) decoded = GraphBatch(_x, _e, _g, *meta) # transform _e, _x, _g = self.output_transform(decoded) gt = GraphBatch(_x, _e, _g, edges, node_idx, edge_idx) if save_all: outputs.append(gt) else: outputs = [gt] return outputs