
jraph's Introduction


Jraph - A library for graph neural networks in jax.

New! PMAP Examples and Data Loading.

We have added a pmap example.

Our friends at InstaDeep, Jama Hussein Mohamud and Tom Makkink, have put together a nice guide to using PyTorch data loading. Find it here.

New! Support For Large Distributed MPNNs

We have released a distributed graph network implementation that allows you to distribute a very large (millions of edges) graph network with explicit edge messages across multiple devices. Check it out!

New! Interactive Jraph Colabs

We have two new colabs to help you get to grips with Jraph.

The first is an educational colab with an amazing introduction to graph neural networks and graph theory, which shows you how to use Jraph to solve a number of problems. Check it out here.

The second is a fully working example with best practices of using Jraph with OGBG-MOLPCBA with some great visualizations. Check it out here.

Thank you to Lisa Wang, Nikola Jovanović & Ameya Daigavane.


Jraph (pronounced "giraffe") is a lightweight library for working with graph neural networks in jax. It provides a data structure for graphs, a set of utilities for working with graphs, and a 'zoo' of forkable graph neural network models.

Installation

pip install jraph

Alternatively, Jraph can be installed directly from GitHub using the following command:

pip install git+https://github.com/deepmind/jraph.git

The examples require additional dependencies. To install them, please run:

pip install "jraph[examples, ogb_examples] @ git+https://github.com/deepmind/jraph.git"

Overview

Jraph is designed to provide utilities for working with graphs in jax, but doesn't prescribe a way to write or develop graph neural networks.

  • graph.py provides a lightweight data structure, GraphsTuple, for working with graphs.
  • utils.py provides utilities for working with GraphsTuples in jax.
    • Utilities for batching datasets of GraphsTuples.
    • Utilities to support jit compilation of variable shaped graphs via padding and masking.
    • Utilities for defining losses on partitions of inputs.
  • models.py provides examples of different types of graph neural network message passing. These are designed to be lightweight, easy to fork and adapt. They do not manage parameters for you - for that, consider using haiku or flax. See the examples for more details.

Quick Start

Jraph takes inspiration from the TensorFlow graph_nets library in defining a GraphsTuple data structure, which is a namedtuple that contains one or more directed graphs.

Representing Graphs - The GraphsTuple

import jraph
import jax.numpy as jnp

# Define a three-node graph; each node has a single scalar feature.
node_features = jnp.array([[0.], [1.], [2.]])

# We will construct a graph for which there is a directed edge between each node
# and its successor. We define this with `senders` (source nodes) and `receivers`
# (destination nodes).
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])

# You can optionally add edge attributes.
edges = jnp.array([[5.], [6.], [7.]])

# We then save the number of nodes and the number of edges.
# This information is used to make running GNNs over multiple graphs
# in a GraphsTuple possible.
n_node = jnp.array([3])
n_edge = jnp.array([3])

# Optionally you can add `global` information, such as a graph label.
global_context = jnp.array([[1]])
graph = jraph.GraphsTuple(nodes=node_features, senders=senders, receivers=receivers,
                          edges=edges, n_node=n_node, n_edge=n_edge, globals=global_context)

A GraphsTuple can have more than one graph.

two_graph_graphstuple = jraph.batch([graph, graph])

The node and edge features are stacked on the leading axis.

jraph.batch([graph, graph]).nodes
>>> DeviceArray([[0.],
             [1.],
             [2.],
             [0.],
             [1.],
             [2.]], dtype=float32)

You can tell which nodes are from which graph by looking at n_node.

jraph.batch([graph, graph]).n_node
>>> DeviceArray([3, 3], dtype=int32)

You can store nests of features in nodes, edges and globals. This makes it possible to store multiple sets of features for each node, edge or graph, with potentially different types and semantically different meanings (for example 'training' and 'testing' nodes). The only requirement is that all arrays within each nest must have a common leading dimension size, matching the total number of nodes, edges or graphs within the GraphsTuple respectively.

node_targets = jnp.array([[True], [False], [True]])
graph = graph._replace(nodes={'inputs': graph.nodes, 'targets': node_targets})

Using the Model Zoo

Jraph provides a set of implemented reference models for you to use.

A Jraph model defines a message passing algorithm between the nodes, edges and global attributes of a graph. The user defines update functions that update graph features, which are typically neural networks but can be arbitrary jax functions.

Let's go through a GraphNetwork (paper) example. A GraphNet's first update function updates the edges using edge features, the node features of the sender and receiver and the global features.

# As one example, we just pass the edge features straight through.
def update_edge_fn(edge, sender, receiver, globals_):
  return edge

Often we use the concatenation of these features, and jraph provides an easy way of doing this with the concatenated_args decorator.

@jraph.concatenated_args
def update_edge_fn(concatenated_features):
  return concatenated_features

Typically, a learned model such as a Multi-Layer Perceptron is used within an update function.

The user similarly defines functions that update the nodes and globals. These are then used to configure a GraphNetwork. To see the arguments to the node and global update_fns please take a look at the model zoo.

net = jraph.GraphNetwork(update_edge_fn=update_edge_fn,
                         update_node_fn=update_node_fn,
                         update_global_fn=update_global_fn)

net is a function that sends messages according to the GraphNetwork algorithm and applies the update_fn. It takes a graph, and returns a graph.

updated_graph = net(graph)

Examples

For a deeper dive, the best place to start is the examples. In particular:

  • examples/basic.py provides an introduction to the features of the library.
  • ogb_examples/train.py provides an end-to-end example of training a GraphNet on the molhiv Open Graph Benchmark dataset. Please note, you need to have downloaded the dataset to run this example.

The rest of the examples are short scripts demonstrating how to use various models from our model zoo, how to make models fast with jax.jit, and how to deal with JAX's static shape requirement.

Citing Jraph

To cite this repository:

@software{jraph2020github,
  author = {Jonathan Godwin* and Thomas Keck* and Peter Battaglia and Victor Bapst and Thomas Kipf and Yujia Li and Kimberly Stachenfeld and Petar Veli\v{c}kovi\'{c} and Alvaro Sanchez-Gonzalez},
  title = {{J}raph: {A} library for graph neural networks in jax.},
  url = {http://github.com/deepmind/jraph},
  version = {0.0.1.dev},
  year = {2020},
}

jraph's People

Contributors

adrhill, alvarosg, brettkoonce, hawkinsp, jg8610, jheek, milescranmer, mplemay, oarriaga, rchen152, salfaris, sauravmaheshkar, sooheon, speckhard, thomaskeck


jraph's Issues

GAT layer: attention_query_fn & attention_logit_fn:

hi team, I'm trying to recreate the GAT on the Cora dataset as per the paper. Am I correct to assume I will need to create the attention_query_fn and attention_logit_fn? Or would you have them defined already somewhere? Any help/example will be very much appreciated!!

Suggestions/advice on improving efficiency in my own small JAX library that involves graphs?

Hello!

First, thank you for all the work put into jraph. I like the library a lot. I'm writing my own small JAX library that deals with DAGs, Connex, and I was wondering if I could ask for some suggestions/advice, since maybe the authors of this library have come across similar difficulties.

Basically, one of the primary goals of Connex is to convert any DAG into a trainable network. I'd like to incorporate maximal parallelism into the forward pass, which has been a bit of a challenge since neurons can have different numbers of inputs, and JAX/XLA does not currently support ragged arrays/tensors. I've written a more detailed description of the issue and the solution I have so far here.

If the authors of jraph have any input, I would greatly appreciate it :)

Can we have some benchmark test data?

Hi,

Have you run any GNN benchmarks on common datasets like OGB? Or have you compared jraph's performance with other libraries such as AWS DGL or PyTorch Geometric? What's the roadmap for jraph? Thanks for your attention.

Ideas on Message Passing

I) I'm looking into implementing the convolution operation specified in the Graph Isomorphism Network, which transforms the node features with a set of dense layers. Can a haiku.Module object be called inside the node update function? If not, how should that be done?

node_update_fn = haiku.Seq(node_feature) + haiku.Seq(incoming edge_feature)

Invalid constraint warning when resolving jraph dependency

Using Poetry to install a project which depends on jraph, I get the warning

Resolving dependencies... (3.5s) PackageInfo: Invalid constraint (python-version (>="3.6")) found in jraph-0.0.2.dev0 dependencies, skipping

I'm not very familiar with setuptools, but it looks to me like
https://github.com/deepmind/jraph/blob/36071d5e0794b1809f73f5e005eea2dfe42bbfb6/setup.py#L44-L45
can be safely removed from install_requires as python_requires is already defined in L56:
https://github.com/deepmind/jraph/blob/36071d5e0794b1809f73f5e005eea2dfe42bbfb6/setup.py#L56

Adjacency matrix to GraphsTuple

I couldn't find any function in the jraph API that allows me to convert an arbitrary adjacency matrix to a GraphsTuple object. I assume we are expected to do this manually?

Just curious if there is any potential future work on this feature? I would be happy to explore something down this road.

using pmap with GraphsTuple

Looking at some other code that uses pmap for multigpu, pmapped functions want an additional axis representing the gpus.

GraphsTuple is just an ordered tuple of feature arrays and metadata, where feature arrays are rank 2 and metadata is rank 1. Am I understanding correctly that the way to use multigpu learning here is by having rank 3 and 2 arrays inside GraphsTuple like the following?

  1. take BS * N_GPU graphs from data source
  2. make N_GPU batches via jraph.batch
  3. pad_graphs_to_nearest_power_of_two that handles multiple inputs => now you have equal shaped arrays
  4. create one GraphsTuple with pmappable (rank 3) features
  5. use graphnets with pmapped functions

Citation

Dear Authors,

How can I cite the library? What is the appropriate bib entry?

Bests,

Benedek

Add some examples

Please add some examples targeting the possible use cases. thank you!

Missing dependencies in examples

frozendict is used in example hamiltonian_graph_network.py but it's not listed in the setup.py. I could make a PR but maybe frozendict is not wanted as a dependency.

data shuffling during the training of OGB example with padding and JIT support

Hi,

Thanks for such a great library. I have a question about data shuffling during training in the ogb_example. Can we shuffle the data between training epochs, like DataLoader(shuffle=True) in PyTorch? Will this affect padding and JIT compilation? Sorry if this is already implemented in make_generator. If I understand correctly, during the first epoch each batch is padded to some size n; if we shuffle the indices and build new batches in the next epoch, does that entail new padding and a new JIT compilation?

Problem with ValueError

Hi everyone 👋,

First of all, thank you for the great work on this library!
I'm having some trouble understanding and creating my own GNN. I'm trying to do some sort of graph classification. I followed this tutorial, and now I'm trying to apply this network example to my own data.

This is an example of a graph:

GraphsTuple(nodes=DeviceArray([[0.0000000e+00, 1.5747571e+12, 6.0000000e+00],
             [1.0000000e+00, 1.5701138e+12, 2.0000000e+00],
             [2.0000000e+00, 1.5747571e+12, 2.0000000e+00],
             [3.0000000e+00, 1.5747555e+12, 3.0000000e+00],
             [4.0000000e+00, 1.5701127e+12, 7.0000000e+00],
             [5.0000000e+00, 0.0000000e+00, 1.0000000e+00],
             [6.0000000e+00, 0.0000000e+00, 1.0000000e+00]],            dtype=float32), edges=DeviceArray([1, 1, 1, 1, 2, 2], dtype=int32), receivers=DeviceArray([3, 1, 2, 0, 4, 3], dtype=int32), senders=DeviceArray([5, 6, 6, 6, 3, 0], dtype=int32), globals=None, n_node=DeviceArray([7], dtype=int32), n_edge=DeviceArray([6], dtype=int32))

When I try to initialize the network, it outputs a ValueError: ValueError: data type <class 'numpy.int32'> not inexact.
This error comes from the last line of this code block (net.init(jax.random.PRNGKey(42), graph)):

def train(dataset: List[Dict[str, Any]], num_train_steps: int) -> hk.Params:
  """Training loop."""

  # Transform impure `net_fn` to pure functions with hk.transform.
  net = hk.without_apply_rng(hk.transform(net_fn))
  # Get a candidate graph and label to initialize the network.
  graph = dataset[0]['input_graph']
  
  # Initialize the network.
  params = net.init(jax.random.PRNGKey(42), graph)

Graph dataset[0]['input_graph'] is the one shown above.

After reading the docs, some stackoverflow threads, and searching in Google, I haven't found anything to either understand or resolve this error.

I have some hesitation about the data types of the GraphsTuple. I tried to change the int32 type to int native type of python (as Nate says in this stackoverflow thread), and I couldn't change the types. Also, it may be the float32 type of the nodes field?

I'm submitting this issue because I haven't found any useful resource to help me debug this error. I hope that's no inconvenience, and that it helps others resolve this error faster.

Thank you!

[Potential Bug] Senders and Receivers can't be None as specified in docstrings

Hello,

I think there might be a bug in the node aggregation fn when using GCN. The docstrings state that we can set the senders and receivers to None for graphs that have no edges but due to using the senders and receivers in the node aggregation function, setting them to None causes it to break. I assume this is not intended unless I am missing something.

Thank you.

Cannot store other graph information in graphTuple

I thought we could store anything in graphTuple just like PyTorch geometric data or dgl, but that is not the case. I get the following error when storing node degrees:

TypeError: __new__() got an unexpected keyword argument 'in_degree'

is it the intended behavior, if so why, if not how can we store anything in there?

thanks

How to create custom batches?

Assume I have multiple graphs (nodes, senders, receivers and graph label). How would one create batches that would work with jax.vmap?

Reference code for NLNN with multi-head attention?

The graph networks paper points out that transformer-like architectures can be expressed within the graph net framework. Is there any example code showing multi-head attention working within this framework or the graph_nets library? I'm still getting up to speed on the segment_fn semantics.

requirements for examples

I missed the fact that these are in the setup file and ended up doing it manually

Maybe this is standard to have it as a feature - I've just never seen this.
Otherwise it might be helpful to point this out in the readme ?

How to handle heterogeneous graph features?

Say you want to differentiate edge/node types explicitly, and have differently parametrized functions operate on each type. This is different from types implicitly being encoded in the input embeddings, because it won't allow dispatching to different functions. There is no jraph native way to hold types, so the only method I can see is to keep features as dicts with say "type" and "feature" keys. Then most of the update/aggregate functions would need to first filter by the appropriate type key (the default GraphNetwork probably won't work with this straight away). Any plans to support this kind of thing in jraph?
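One way to dispatch on explicit types while keeping static shapes, sketched under the assumption that type ids are stored as an integer array next to the features (for example in a dict nest): evaluate every type-specific function on all rows and select each row's result by its type. typed_update is a hypothetical helper, not part of jraph. The trade-off is T times the compute for T types, in exchange for jit-friendly code.

```python
import jax.numpy as jnp

def typed_update(features, types, update_fns):
  """Hypothetical helper: apply update_fns[t] to every row whose type id is t.

  All functions run on all rows (static shapes under jit); each row then
  picks the result that matches its type. Functions must share output shapes.
  """
  candidates = jnp.stack([fn(features) for fn in update_fns])  # [T, N, D]
  return candidates[types, jnp.arange(features.shape[0])]       # [N, D]

features = jnp.array([[1.], [2.], [3.]])
types = jnp.array([0, 1, 0])
out = typed_update(features, types, [lambda x: 2 * x, lambda x: x + 10])
# -> [[2.], [12.], [6.]]
```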

What's the purpose of empty graph in padding?

Hi!

I came across the padding function (pad_with_graphs) and have a naive question: why do we use an empty graph there?
Also, we specify only the number of edges (and nodes). Which algorithm determines their incidence (i.e. which edge would connect which pair of nodes)? Thanks a lot.

Switch from `jax.tree_multimap` to `jax.tree_map`

Jraph is still using jax.tree_multimap, which gives a deprecation warning. This can be problematic for users, for instance for us (Flax), since our CI fails if we hit a deprecation warning. I created an exception for Jraph for now, but since tree_multimap can be replaced in place by tree_map, this seems like an easy fix!
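The replacement is mechanical: jax.tree_util.tree_map already accepts several trees with the same structure and calls the function leaf-wise across them, which is what tree_multimap did. A minimal before/after sketch with made-up values:

```python
import jax

params = {'w': 1.0, 'b': 2.0}
grads = {'w': 0.5, 'b': -1.0}

# Before (deprecated): jax.tree_multimap(lambda p, g: p - 0.1 * g, params, grads)
# After: tree_map takes any number of trees with matching structure.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
# new_params -> {'b': 2.1, 'w': 0.95}
```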

See google/flax#2037

Jraph Flax incorporation

Hi! I saw on the Haiku repo that they are urging people to move to flax, and was just wondering if jraph has any intention of moving to a flax framework?

Thanks!

How to do cross-graph attention?

Hi everyone,

I'm trying to re-implement the GMN model from deepmind in Jax. This model is designed to compute the similarity score between two graphs, and needs to compute the cross-graph node-wise attention weights between two graphs. Since different graphs can have different numbers of nodes, we need to do the cross-graph attention pair-by-pair.

The original implementation is written in TensorFlow, and given a node representation matrix, we can use tf.dynamic_partition to split it into a list of node representation matrices, one per graph, as follows:

def batch_block_pair_attention(data,
                               block_idx,
                               n_blocks,
                               similarity='dotproduct'):
  """Compute batched attention between pairs of blocks.

  This function partitions the batch data into blocks according to block_idx.
  For each pair of blocks, x = data[block_idx == 2i], and
  y = data[block_idx == 2i+1], we compute

  x_i attend to y_j:
  a_{i->j} = exp(sim(x_i, y_j)) / sum_j exp(sim(x_i, y_j))
  y_j attend to x_i:
  a_{j->i} = exp(sim(x_i, y_j)) / sum_i exp(sim(x_i, y_j))

  and

  attention_x = sum_j a_{i->j} y_j
  attention_y = sum_i a_{j->i} x_i.

  Args:
    data: NxD float tensor.
    block_idx: N-dim int tensor.
    n_blocks: integer.
    similarity: a string, the similarity metric.

  Returns:
    attention_output: NxD float tensor, each x_i replaced by attention_x_i.

  Raises:
    ValueError: if n_blocks is not an integer or not a multiple of 2.
  """
  if not isinstance(n_blocks, int):
    raise ValueError('n_blocks (%s) has to be an integer.' % str(n_blocks))

  if n_blocks % 2 != 0:
    raise ValueError('n_blocks (%d) must be a multiple of 2.' % n_blocks)

  sim = get_pairwise_similarity(similarity)

  results = []

  # This is probably better than doing boolean_mask for each i
  partitions = tf.dynamic_partition(data, block_idx, n_blocks)

  # It is rather complicated to allow n_blocks be a tf tensor and do this in a
  # dynamic loop, and probably unnecessary to do so.  Therefore we are
  # restricting n_blocks to be an integer constant here and using the plain for
  # loop.
  for i in range(0, n_blocks, 2):
    x = partitions[i]
    y = partitions[i + 1]
    attention_x, attention_y = compute_cross_attention(x, y, sim)
    results.append(attention_x)
    results.append(attention_y)

  results = tf.concat(results, axis=0)
  # the shape of the first dimension is lost after concat, reset it back
  results.set_shape(data.shape)
  return results

However, JAX does not have a function similar to tf.dynamic_partition.

@jg8610 Do you have any advice on how to do the cross-graph attention in Jax and Jraph, or do you have some similar cases internally? Thanks!
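One static-shape alternative to tf.dynamic_partition is to score all N x N node pairs at once, mask out pairs that are not in partner blocks (2k and 2k + 1), and softmax over each row. This is a sketch of that swapped-in technique, not the GMN authors' code: it assumes dot-product similarity and that every block has a non-empty partner, and it costs O(N^2) memory. Unlike the TensorFlow version, which concatenates results block by block, it returns rows in the original node order, so the two agree when data is already sorted by block_idx.

```python
import jax
import jax.numpy as jnp

def batch_block_pair_attention(data, block_idx):
  """Dense, masked sketch of cross-graph attention between block pairs.

  Nodes in block 2k attend over nodes in block 2k + 1 and vice versa.
  Assumes every block has a non-empty partner block; otherwise that row's
  logits are all -inf and the softmax returns NaN.
  """
  partner = block_idx ^ 1                        # pairs 2k <-> 2k + 1
  sim = data @ data.T                            # [N, N] dot-product similarities
  mask = partner[:, None] == block_idx[None, :]  # True if j is in i's partner block
  logits = jnp.where(mask, sim, -jnp.inf)
  weights = jax.nn.softmax(logits, axis=-1)      # normalise over partner nodes only
  return weights @ data                          # [N, D]

data = jnp.array([[1., 0.], [0., 1.], [1., 1.]])
block_idx = jnp.array([0, 0, 1])  # nodes 0, 1 form graph 0; node 2 forms graph 1
out = batch_block_pair_attention(data, block_idx)
# out[0] = out[1] = data[2]; out[2] = 0.5 * (data[0] + data[1]) because its
# two similarities are equal.
```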

Feature engineering methods

Hi, does this library provide utilities for feature engineering on graphs? For example I may want to calculate different node centrality measures or other metrics that are commonly used in traditional ML. Basically, the question would be - will we have utilities for all the graph related stuff + machine learning or just related to GNNs?

jraph provides a data structure for graphs, a set of utilities for working with graphs

How does jraph support large scale graph neural network training?

Hi,

We noticed that the DeepMind team used jraph to win the MAG240M-LSC track of the 2021 KDD OGB Large-Scale Challenge.
This OGB dataset contains more than 100 million nodes, so I'd like to ask: how does jraph support training GNN models at this scale? The jraph documentation is quite brief; could you please provide more details about the MAG240M-LSC track and large-scale GNN training? Thanks.

examples/higgs_detection.py doesn't learn

It's a super cool example (particle physicist here!), but I'm not sure the system actually learns anything. The loss just seems to fluctuate around random performance on the test set, e.g. over 11000 steps with default settings:

I1124 17:51:33.153964 4592291264 higgs_detection.py:204] step 0 loss train 0.5299999713897705 test 0.4699999988079071
I1124 17:51:36.350952 4592291264 higgs_detection.py:204] step 1000 loss train 0.4000000059604645 test 0.41999998688697815
I1124 17:51:39.533082 4592291264 higgs_detection.py:204] step 2000 loss train 0.5099999904632568 test 0.5400000214576721
I1124 17:51:42.694689 4592291264 higgs_detection.py:204] step 3000 loss train 0.4699999988079071 test 0.49000000953674316
I1124 17:51:45.841547 4592291264 higgs_detection.py:204] step 4000 loss train 0.5199999809265137 test 0.49000000953674316
I1124 17:51:48.940982 4592291264 higgs_detection.py:204] step 5000 loss train 0.550000011920929 test 0.41999998688697815
I1124 17:51:52.107565 4592291264 higgs_detection.py:204] step 6000 loss train 0.5099999904632568 test 0.47999998927116394
I1124 17:51:55.312087 4592291264 higgs_detection.py:204] step 7000 loss train 0.5299999713897705 test 0.4699999988079071
I1124 17:51:58.471485 4592291264 higgs_detection.py:204] step 8000 loss train 0.5600000023841858 test 0.550000011920929
I1124 17:52:01.599973 4592291264 higgs_detection.py:204] step 9000 loss train 0.49000000953674316 test 0.550000011920929
I1124 17:52:04.751693 4592291264 higgs_detection.py:204] step 10000 loss train 0.46000000834465027 test 0.47999998927116394
I1124 17:52:07.865066 4592291264 higgs_detection.py:204] step 11000 loss train 0.4699999988079071 test 0.5

Not sure if it was meant to work in practice, or just a nice example of a problem implementation (that part is done very nicely :) )

Replicating GAT with CORA dataset

Hello,

Thanks very much for such a wonderful product! I am trying to replicate GAT's paper with the CORA dataset, but I am finding some issues in using jraph . I started from your example notebook, implementing GAT, along with add_self_edges_fn:

def add_self_edges_fn(receivers: jnp.ndarray,
                      senders: jnp.ndarray,
                      total_num_nodes: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    r"""Adds self edges. Assumes self edges are not in the graph yet."""
    receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)), axis=0)
    senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)), axis=0)
    return receivers, senders
  
def GAT(attention_query_fn: Callable,
        attention_logit_fn: Callable,
        node_update_fn: Optional[Callable] = None,
        add_self_edges: bool = True) -> Callable:
    r""" Main GAT function"""
    # pylint: disable=g-long-lambda
    if node_update_fn is None:
        # By default, apply the leaky relu and then concatenate the heads on the
        # feature axis.
        node_update_fn = lambda x: jnp.reshape(jax.nn.leaky_relu(x), (x.shape[0], -1))

    def _ApplyGAT(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Applies a Graph Attention layer."""
        nodes, edges, receivers, senders, _, _, _ = graph
        
        try:
            sum_n_node = nodes.shape[0]
        except IndexError:
            raise IndexError('GAT requires node features')

        nodes = attention_query_fn(nodes)
        total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]

        if add_self_edges:
            receivers, senders = add_self_edges_fn(receivers, senders,
                                                    total_num_nodes)
        sent_attributes = nodes[senders]
        received_attributes = nodes[receivers]
        att_softmax_logits = attention_logit_fn(sent_attributes,
                                                received_attributes, edges)

        att_weights = jraph.segment_softmax(
            att_softmax_logits, segment_ids=receivers, num_segments=sum_n_node)

        messages = sent_attributes * att_weights

        nodes = jax.ops.segment_sum(messages, receivers, num_segments=sum_n_node)

        nodes = node_update_fn(nodes)

        return graph._replace(nodes=nodes)

    return _ApplyGAT


def gat_definition(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """ Define GAT algorithm to run 
    Parameters
    ----------
    graph: jraph.GraphsTupe, input network to be processed 
    
    Return 
    -------
    jraph.GraphsTuple updated node graph
    """

    def _attention_logit_fn(sender_attr: jnp.ndarray, receiver_attr: jnp.ndarray,
                            edges: jnp.ndarray) -> jnp.ndarray:
        del edges
        x = jnp.concatenate((sender_attr, receiver_attr), axis=-1)
        return jax.nn.leaky_relu(hk.Linear(1)(x))

    gn = GAT(
        attention_query_fn=lambda n: hk.Linear(8)(n),
        attention_logit_fn=_attention_logit_fn,
        node_update_fn=None,
        add_self_edges=True)
    graph = gn(graph)

    gn = GAT(
        attention_query_fn=lambda n: hk.Linear(8)(n),
        attention_logit_fn=_attention_logit_fn,
        node_update_fn=hk.Linear(2),
        add_self_edges=True)
    graph = gn(graph)
    return graph

Then, after defining the main GAT, I run the training as:


def run_cora(network: hk.Transformed, num_steps: int) -> jnp.ndarray:
  r""" Run training on CORA dataset """
  cora_graph = cora_ds[0]['input_graph']
  labels = cora_ds[0]['target']
  params = network.init(jax.random.PRNGKey(42), cora_graph)

  @jax.jit
  def predict(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    return jnp.argmax(decoded_graph.nodes, axis=1)

  @jax.jit
  def prediction_loss(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    preds = jnp.argmax(decoded_graph.nodes, axis=1)
    # We interpret the decoded nodes as a pair of logits for each node.
    loss = compute_bce_with_logits_loss(preds, labels)
    return loss#, preds

  opt_init, opt_update = optax.adam(5e-4)
  opt_state = opt_init(params)

  @jax.jit
  def update(params: hk.Params, opt_state) -> Tuple[hk.Params, Any]:
    """Returns updated params and state."""
    g = jax.grad(prediction_loss)(params)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state

  @jax.jit
  def accuracy(params: hk.Params) -> jnp.ndarray:
    decoded_graph = network.apply(params, cora_graph)
    return jnp.mean(jnp.argmax(decoded_graph.nodes, axis=1) == labels)

  for step in range(num_steps):
    if step%100==0:
        print(f"step {step} accuracy {accuracy(params).item():.2f}")
    params, opt_state = update(params, opt_state)

  return predict(params)

The problem is that accuracy stays at the same value for all the steps I run (e.g. after 1000 steps, accuracy = 0.13).
Could you give me some pointers on where I am going wrong?
Thank you
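A likely cause of the frozen accuracy: prediction_loss computes the loss from jnp.argmax(...) predictions. argmax is piecewise constant, so jax.grad returns all-zero gradients and Adam leaves params unchanged. The toy comparison below (made-up data and a linear model, not CORA) shows the effect; computing the loss on decoded_graph.nodes (the logits) instead, ideally masked to the training nodes, avoids it. If the goal is CORA's seven-class task, the final hk.Linear(2) would also need seven outputs.

```python
import jax
import jax.numpy as jnp

# Toy data and a toy linear "model" (illustrative, not CORA).
x = jnp.array([[1., 0.], [0., 1.], [1., 1.]])
labels = jnp.array([0, 1, 1])
w = jnp.array([[0.1, -0.2], [0.3, 0.05]])

def argmax_loss(w):
  # Mirrors the issue: the loss only sees argmax predictions.
  preds = jnp.argmax(x @ w, axis=1)
  return jnp.mean((preds != labels).astype(jnp.float32))

def cross_entropy_loss(w):
  # Softmax cross-entropy on the logits themselves, with integer labels.
  log_probs = jax.nn.log_softmax(x @ w, axis=-1)
  return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))

g_argmax = jax.grad(argmax_loss)(w)     # all zeros: no gradient flows through argmax
g_ce = jax.grad(cross_entropy_loss)(w)  # non-zero, so the optimizer can make progress
```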

Message Passing with edge updates

Hi there,

I was looking to implement a message passing network with edge updates as described in https://arxiv.org/abs/1806.03146.
Looking at the Jraph paper, it is explained that calculating the messages M_t for each edge should be done with the edge update function phi_e, in the GraphNetwork from the model zoo. However, as I understand it, this prevents me from implementing a function that just updates the edges, based on the edge feature, sending and receiving node.

Is there a workaround using the current model zoo to separate edge updates from edge-wise messages, or is this a known problem?

Thanks!
