StellarGraph Machine Learning Library
StellarGraph is a Python library for machine learning on graphs and networks.
Introduction
The StellarGraph library offers state-of-the-art algorithms for graph machine learning, making it easy to discover patterns and answer questions about graph-structured data. It can solve many machine learning tasks:
- Representation learning for nodes and edges, to be used for visualisation and various downstream machine learning tasks;
- Classification and attribute inference of nodes or edges;
- Classification of whole graphs;
- Link prediction;
- Interpretation of node classification [8].
Graph-structured data represent entities as nodes (or vertices) and relationships between them as edges (or links), and can include data associated with either as attributes. For example, a graph can contain people as nodes and friendships between them as links, with data like a person's age and the date a friendship was established. StellarGraph supports analysis of many kinds of graphs:
- homogeneous (with nodes and links of one type);
- heterogeneous (with more than one type of nodes and/or links);
- knowledge graphs (extremely heterogeneous graphs with thousands of types of edges);
- graphs with or without data associated with nodes;
- graphs with edge weights.
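As a minimal sketch of the people-and-friendships example above, graph-structured data can be laid out as two Pandas DataFrames: one row per node (indexed by node ID) and one row per edge. The names, ages and dates are made up for illustration, and the `source`/`target` column names are an assumption about how edge endpoints are labelled.

```python
import pandas as pd

# Hypothetical node table: one row per person, indexed by node ID,
# with a node attribute (age).
nodes = pd.DataFrame(
    {"age": [34, 27, 45]},
    index=["alice", "bob", "carol"],
)

# Hypothetical edge table: one row per friendship, with the two endpoints
# and an edge attribute (the date the friendship was established).
edges = pd.DataFrame(
    {
        "source": ["alice", "bob"],
        "target": ["bob", "carol"],
        "established": pd.to_datetime(["2015-06-01", "2019-11-23"]),
    }
)

# Sanity check: every edge endpoint must refer to a known node.
endpoints = set(edges["source"]) | set(edges["target"])
unknown_endpoints = endpoints - set(nodes.index)
print(f"{len(nodes)} nodes, {len(edges)} edges, unknown endpoints: {sorted(unknown_endpoints)}")
```

This graph is homogeneous: there is a single node type and a single edge type. A heterogeneous graph would need one such table per node or edge type.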
StellarGraph is built on TensorFlow 2 and its Keras high-level API, as well as Pandas and NumPy, which makes it user-friendly, modular and extensible. It interoperates smoothly with code that builds on these libraries, such as the standard Keras layers and scikit-learn, so it is easy to augment the core graph machine learning algorithms provided by StellarGraph. It is also easy to install with pip or Anaconda.
Getting Started
The numerous detailed and narrated examples are a good way to get started with StellarGraph. There is likely to be one that is similar to your data or your problem (if not, let us know).
You can start working with the examples immediately in Google Colab or Binder by clicking the corresponding badge within each Jupyter notebook.
Alternatively, you can download a local copy of the demos and run them using Jupyter. The demos can be downloaded by cloning the master branch of this repository, or by using the curl command below:
curl -L https://github.com/stellargraph/stellargraph/archive/master.zip | tar -xz --strip=1 stellargraph-master/demos
The dependencies required to run most of our demo notebooks locally can be installed using one of the following:
- Using pip: pip install stellargraph[demos]
- Using conda: conda install -c stellargraph stellargraph
(See Installation section for more details and more options.)
Getting Help
If you get stuck or have a problem, there are many ways to make progress and get help or support:
- Read the documentation
- Consult the examples
- Contact us:
Example: GCN
One of the earliest deep machine learning algorithms for graphs is the Graph Convolutional Network (GCN) [6]. The following example uses it for node classification: predicting the class to which each node belongs. It shows how little code is needed with StellarGraph, and how StellarGraph integrates with Pandas, TensorFlow and libraries built on them.
Data preparation
Data for StellarGraph can be prepared using common libraries like Pandas and scikit-learn.
import pandas as pd
from sklearn import model_selection
def load_my_data():
    # your own code to load data into Pandas DataFrames, e.g. from CSV files or a database
    ...
nodes, edges, targets = load_my_data()
# Use scikit-learn to compute training and test sets
train_targets, test_targets = model_selection.train_test_split(targets, train_size=0.5)
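The example relies on `targets` holding one column per class (a one-hot encoding), because the model in the next section sizes its output layer from the number of columns. A minimal sketch of producing that shape from a single label column, using made-up node IDs and labels:

```python
import pandas as pd

# Hypothetical raw labels: one class label per node ID.
labels = pd.Series(
    ["physics", "biology", "physics", "chemistry"],
    index=["n1", "n2", "n3", "n4"],
    name="subject",
)

# One-hot encode: one column per class, one row per node.
targets = pd.get_dummies(labels).astype("float32")

print(list(targets.columns))
print(targets.loc["n2"].tolist())
```

Keeping the node IDs as the index matters here: the training code later passes both the index and the encoded values to the data generator.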
Graph machine learning model
This is the only part that is specific to StellarGraph. The machine learning model consists of some graph convolution layers followed by a layer to compute the actual predictions as a TensorFlow tensor. StellarGraph makes it easy to construct all of these layers via the GCN model class. It also makes it easy to get input data in the right format via the StellarGraph graph data type and a data generator.
import stellargraph as sg
import tensorflow as tf
# convert the raw data into StellarGraph's graph format for faster operations
graph = sg.StellarGraph(nodes, edges)
generator = sg.mapper.FullBatchNodeGenerator(graph, method="gcn")
# two layers of GCN, each with hidden dimension 16
gcn = sg.layer.GCN(layer_sizes=[16, 16], generator=generator)
x_inp, x_out = gcn.in_out_tensors() # create the input and output TensorFlow tensors
# use TensorFlow Keras to add a layer to compute the (one-hot) predictions
predictions = tf.keras.layers.Dense(units=len(train_targets.columns), activation="softmax")(x_out)
# use the input and output tensors to create a TensorFlow Keras model
model = tf.keras.Model(inputs=x_inp, outputs=predictions)
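To give a feel for what a single graph convolution layer computes, here is a NumPy sketch of the commonly cited GCN propagation rule: add self-loops to the adjacency matrix, normalise it symmetrically by node degree, then mix each node's features with those of its neighbours. This illustrates the idea only and is not StellarGraph's implementation; the toy graph, features and weights are all hypothetical.

```python
import numpy as np

def gcn_layer(adjacency, features, weights):
    """One GCN propagation step: relu(A_hat @ features @ weights)."""
    n = adjacency.shape[0]
    a_tilde = adjacency + np.eye(n)  # add self-loops
    degree = a_tilde.sum(axis=1)
    d_inv_sqrt = np.diag(1.0 / np.sqrt(degree))
    a_hat = d_inv_sqrt @ a_tilde @ d_inv_sqrt  # symmetric normalisation
    return np.maximum(a_hat @ features @ weights, 0.0)

# Toy path graph 0 - 1 - 2, with 2 features per node.
adjacency = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
features = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
rng = np.random.default_rng(0)
weights = rng.normal(size=(2, 16))  # hypothetical, untrained weights

hidden = gcn_layer(adjacency, features, weights)
print(hidden.shape)
```

Stacking two such steps lets each node's representation depend on nodes up to two hops away, which is what the two-layer model above does with hidden dimension 16.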
Training and evaluation
The model is a conventional TensorFlow Keras model, and so tasks such as training and evaluation can use the functions offered by Keras. StellarGraph's data generators make it simple to construct the required Keras Sequences for input data.
# prepare the model for training with the Adam optimiser and an appropriate loss function
model.compile("adam", loss="categorical_crossentropy", metrics=["accuracy"])
# train the model on the train set
model.fit(generator.flow(train_targets.index, train_targets), epochs=5)
# check model generalisation on the test set
(loss, accuracy) = model.evaluate(generator.flow(test_targets.index, test_targets))
print(f"Test set: loss = {loss}, accuracy = {accuracy}")
This algorithm is spelled out in more detail in its extended narrated notebook. We provide many more algorithms, each with a detailed example.
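The two numbers reported by the evaluation step can be reproduced by hand. The NumPy sketch below computes categorical cross-entropy and accuracy for made-up one-hot targets and predicted probabilities; Keras's own computation may differ in details such as clipping and reduction, so treat this as an illustration of the definitions.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean cross-entropy between one-hot targets and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0)  # avoid log(0)
    return float(-np.mean(np.sum(y_true * np.log(y_pred), axis=1)))

def accuracy(y_true, y_pred):
    """Fraction of rows whose highest-probability class matches the target."""
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))

# Hypothetical targets and predictions for three nodes and three classes.
y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)
y_pred = np.array([[0.7, 0.2, 0.1], [0.3, 0.4, 0.3], [0.5, 0.3, 0.2]])

loss_value = categorical_crossentropy(y_true, y_pred)
accuracy_value = accuracy(y_true, y_pred)
print(f"loss = {loss_value:.3f}, accuracy = {accuracy_value:.3f}")
```

Two of the three nodes are assigned their correct class here, while the loss also penalises how little probability the third node's true class received.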
Algorithms
The StellarGraph library currently includes the following algorithms for graph machine learning:
| Algorithm | Description |
| --- | --- |
| GraphSAGE [1] | Supports supervised as well as unsupervised representation learning, node classification/regression, and link prediction for homogeneous networks. The current implementation supports multiple aggregation methods, including mean, maxpool, meanpool, and attentional aggregators. |
| HinSAGE | Extension of GraphSAGE algorithm to heterogeneous network
