# Source code for caldera.transforms.networkx._nx_attach_np_features

import functools
import itertools
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import Hashable
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union

import networkx as nx
import numpy as np
from networkx.classes.reportviews import EdgeView
from networkx.classes.reportviews import NodeView

from ._base import NetworkxTransformBase
from ._types import _G
from caldera.utils import dict_join
from caldera.utils import functional as fn
from caldera.utils.tensor import to_one_hot

# Generator type that yields `_G` values; it takes no sent values and has no
# return value.
GraphGen = Generator[_G, None, None]

# Generic placeholders used in the annotations of this module. The string
# passed to TypeVar must equal the name it is assigned to, otherwise static
# type checkers reject the declaration.
_T = TypeVar("_T")
_K = TypeVar("_K")
_S = TypeVar("_S")
_V = TypeVar("_V")


def values_to_one_hot(
    values: Iterable[_T],
    classes: Union[List[_T], Tuple[_T, ...]],
    num_classes: Optional[int] = None,
) -> np.ndarray:
    """Convert an iterable of values to a one-hot encoded `np.ndarray`.

    :param values: iterable of values. Every value must be one of ``classes``.
    :param classes: valid classes of values. Will appear in one-hot array in
        order they appear here. Must not contain duplicates.
    :param num_classes: Number of classes for one-hot encoder. If not
        provided, number of provided classes will be used.
    :return: one-hot encoded array of values.
    :raises ValueError: if ``classes`` contains duplicate entries.
    :raises KeyError: if a value is not one of ``classes``.
    """
    class_to_index = {k: i for i, k in enumerate(classes)}
    # Duplicates collapse in the mapping, so a size mismatch reveals them.
    # This is an explicit check (not an ``assert``) so it still runs when
    # Python is started with -O.
    if len(class_to_index) != len(classes):
        raise ValueError(
            "One-hot classes must be unique, but duplicates were found: {}".format(
                list(classes)
            )
        )

    indices = []
    for v in values:
        # Keep the ``try`` body limited to the lookup that can fail.
        try:
            index = class_to_index[v]
        except KeyError:
            # The message below already names the offending value, so the
            # bare lookup error is suppressed instead of implicitly chained.
            raise KeyError(
                "Value '{}' not found in list of available one-hot classes: {}".format(
                    v, class_to_index
                )
            ) from None
        indices.append(index)

    if num_classes is None:
        num_classes = len(class_to_index)
    return to_one_hot(np.array(indices), mx=num_classes)


# def _get_join_fn(join_func):
#     if join_func == "hstack":
#
#         def join_func(a, b):
#             return np.hstack([a, b])
#
#     elif join_func == "vstack":
#
#         def join_func(a, b):
#             return np.vstack([a, b])
#
#     return join_func


# def _get_processing_func(encoding):
#     if encoding is None:
#         processing_func = fn.compose(list, np.array)
#     elif encoding == "onehot":
#         processing_func = fn.compose(list, functools.partial(values_to_one_hot))
#     else:
#         processing_func = None
#     return processing_func


def _dispatch_nx_iterator(
    g: nx.DiGraph, x: str
) -> Union[NodeView, EdgeView, Generator[Tuple[Hashable, Any], None, None]]:
    """Return the data iterator of ``g`` selected by ``x``.

    :param g: graph to read from
    :param x: one of 'node', 'edge' or 'global'
    :return: result of calling ``g.nodes``, ``g.edges`` or ``g.globals``
        with ``data=True``
    :raises ValueError: if ``x`` is not one of the three supported names
    """
    # NOTE(review): ``globals`` is not part of plain networkx graphs; it is
    # presumably provided by the project's graph type -- confirm.
    accessors = (("node", "nodes"), ("edge", "edges"), ("global", "globals"))
    for kind, attr_name in accessors:
        if x == kind:
            # Resolve the accessor only for the matching kind.
            return getattr(g, attr_name)(data=True)
    raise ValueError("choose from {}".format([kind for kind, _ in accessors]))


def _update_left_inplace(
    data: Dict[_K, _T], new_data: Dict[_K, _S], join_fn: Callable[[_T, _S], _V]
) -> Dict[_K, Union[_T, _S, _V]]:
    """Update the left dictionary, joining values with the provided function.

    Delegates to ``dict_join`` with ``data`` passed both as the left operand
    and as the third positional argument, using ``mode="right"``.

    :param data: left dictionary (presumably mutated in place, as the
        function name suggests -- confirm against ``dict_join``)
    :param new_data: right dictionary providing the new values
    :param join_fn: forwarded to ``dict_join``; presumably combines the left
        and right value when a key is present in both -- confirm against
        ``dict_join``
    :return: whatever ``dict_join`` returns
    """
    return dict_join(data, new_data, data, mode="right", join_fn=join_fn)


def _merge_update(data, key, to_key, join_fn, process_fn, default=...):
    """Process the ``key`` feature of every item and merge it under ``to_key``.

    For example, the processing function may convert the collected values to
    ``np.ndarray`` while ``join_fn`` applies ``hstack`` to combine a new value
    with one already stored under ``to_key``.

    :param data: iterable of items whose last element is the attribute
        dictionary to read from and update (e.g. the result of
        ``g.nodes(data=True)``)
    :param key: key of the feature to read
    :param to_key: key under which the processed feature is merged
    :param join_fn: forwarded to ``_update_left_inplace`` to combine values
    :param process_fn: callable placed in the pipeline after the feature
        lookup; receives the looked-up values and produces the processed ones
    :param default: forwarded to ``fn.get_each`` as ``default``; ``...``
        (Ellipsis) means "no default supplied"
    :return: None
    """
    # Ellipsis is the sentinel, so ``None`` stays usable as a real default.
    if default is ...:
        lookup = fn.get_each(key)
    else:
        lookup = fn.get_each(key, default=default)

    # attribute dict -> raw feature -> processed feature -> {to_key: feature}
    pipeline = fn.compose(
        fn.index_each(-1),
        lookup,
        process_fn,
        fn.map_each(lambda value: {to_key: value}),
    )

    # Two independent passes over the same items: one for the dictionaries
    # that get updated, one feeding the processing pipeline.
    source_stream, pipeline_stream = itertools.tee(data)
    attr_dicts = fn.index_each(-1)(source_stream)
    # Run the whole pipeline before merging anything, so processing never
    # observes a dictionary that was already updated.
    new_entries = list(pipeline(pipeline_stream))

    for attrs, entry in zip(attr_dicts, new_entries):
        _update_left_inplace(attrs, entry, join_fn=join_fn)


class NetworkxAttachNumpyFeatures(NetworkxTransformBase):
    """Transform that encodes a node, edge or global feature and attaches the
    result to the same attribute dictionaries under a new key."""

    def __init__(
        self,
        x: str,
        from_key: str,
        to_key: str,
        *,
        default: Any = ...,
        encoding: Optional[Union[str, Callable[[Iterable[_T]], Iterable[_T]]]] = None,
        join_func: Union[
            str, Callable[[np.ndarray, np.ndarray], np.ndarray]
        ] = "hstack",
        global_key: Optional[str] = None,
        **processing_kwargs
    ):
        """Initialize transform that converts networkx features to a
        :class:`np.ndarray`

        :param x: 'edge', 'node', or 'global'
        :param from_key: dictionary key to find feature
        :param to_key: new key to attach feature
        :param default: value forwarded as the lookup default for
            ``from_key``; ``...`` (Ellipsis) means no default is supplied
        :param encoding: ``None`` (convert the collected values with
            ``np.array``), 'onehot', 'bool', or a callable that receives the
            iterable of feature values and returns the processed values
        :param join_func: select from 'hstack', 'vstack' or provide a new
            join function
        :param global_key: stored on the instance; not used by this class
        :param processing_kwargs: keyword arguments bound to the encoder for
            'onehot' (e.g. ``classes``, ``num_classes``) and for a callable
            ``encoding``; ignored by the other encodings
        :raises ValueError: if ``encoding`` is not a supported value
        """
        if encoding is None:
            processing_func = fn.compose(list, np.array)
        elif encoding == "onehot":
            processing_func = fn.compose(
                list, functools.partial(values_to_one_hot, **processing_kwargs)
            )
        elif encoding == "bool":
            processing_func = fn.compose(
                fn.map_each(lambda x: np.array([int(bool(x))]))
            )
        elif callable(encoding):
            # The signature advertises callable encodings, so honor them
            # instead of rejecting them as an invalid encoding.
            processing_func = functools.partial(encoding, **processing_kwargs)
        else:
            raise ValueError(
                "Encoding {} is not a valid encoding. Select from 'onehot', "
                "'bool', None, or provide a callable".format(encoding)
            )

        # Resolve the named join strategies; anything else is used as given.
        if join_func == "hstack":

            def join_fn(a, b):
                return np.hstack([a, b])

        elif join_func == "vstack":

            def join_fn(a, b):
                return np.vstack([a, b])

        else:
            join_fn = join_func

        self.x = x
        self.from_key = from_key
        self.to_key = to_key
        self.global_key = global_key
        self.default = default
        self.join_fn = join_fn
        self.processing_func = processing_func
        self.processing_kwargs = processing_kwargs

    def transform(self, g: _G) -> _G:
        """Attach the encoded feature to the selected elements of ``g``.

        :param g: graph whose attribute dictionaries are updated
        :return: the same graph object that was passed in
        """
        # Materialize the view once so it can be traversed more than once.
        items = list(_dispatch_nx_iterator(g, self.x))
        _merge_update(
            items,
            self.from_key,
            self.to_key,
            join_fn=self.join_fn,
            process_fn=self.processing_func,
            default=self.default,
        )
        return g
class NetworkxAttachNumpyOneHot(NetworkxAttachNumpyFeatures):
    """Transform that attaches a one-hot encoding of a networkx feature."""

    def __init__(
        self,
        x: str,
        from_key: str,
        to_key: str,
        *,
        default: Any = ...,
        join_func: Union[
            str, Callable[[np.ndarray, np.ndarray], np.ndarray]
        ] = "hstack",
        global_key: Optional[str] = None,
        classes: Optional[List[str]] = None,
        num_classes: Optional[int] = None
    ):
        """Initialize transform that encodes networkx features into one-hot
        encodings.

        :param x: 'edge', 'node', 'global'
        :param from_key: dictionary key to find feature
        :param to_key: new key to attach feature
        :param default: forwarded to the base class as the lookup default;
            ``...`` (Ellipsis) means no default is supplied
        :param join_func: forwarded to the base class: 'hstack', 'vstack' or
            a join function
        :param global_key: forwarded to the base class
        :param classes: valid classes, forwarded to the one-hot encoder
        :param num_classes: number of classes, forwarded to the one-hot
            encoder
        """
        super().__init__(
            x,
            from_key,
            to_key,
            default=default,
            encoding="onehot",
            join_func=join_func,
            global_key=global_key,
            classes=classes,
            num_classes=num_classes,
        )
class NetworkxAttachNumpyBool(NetworkxAttachNumpyFeatures):
    """Transform that attaches a boolean encoding of a networkx feature."""

    def __init__(
        self,
        x: str,
        from_key: str,
        to_key: str,
        *,
        default: Any = ...,
        join_func: Union[
            str, Callable[[np.ndarray, np.ndarray], np.ndarray]
        ] = "hstack",
        global_key: Optional[str] = None,
        classes: Optional[List[str]] = None,
        num_classes: Optional[int] = None
    ):
        """Initialize transform that encodes networkx features into a boolean
        represented as :class:`np.ndarray` of size 1.

        E.g. `True` gets encoded as `np.array([1])` and `False` as
        `np.array([0])` (integer values).

        :param x: 'edge', 'node', 'global'
        :param from_key: dictionary key to find feature
        :param to_key: new key to attach feature
        :param default: forwarded to the base class as the lookup default;
            ``...`` (Ellipsis) means no default is supplied
        :param join_func: forwarded to the base class: 'hstack', 'vstack' or
            a join function
        :param global_key: forwarded to the base class
        :param classes: accepted for signature compatibility and forwarded to
            the base class, whose 'bool' encoding does not use it
        :param num_classes: accepted for signature compatibility and
            forwarded to the base class, whose 'bool' encoding does not use it
        """
        super().__init__(
            x,
            from_key,
            to_key,
            default=default,
            encoding="bool",
            join_func=join_func,
            global_key=global_key,
            classes=classes,
            num_classes=num_classes,
        )