Using Built-In Models
The built-in models of pyTigerGraph aim to enable developers to quickly get a graph machine learning model training using the data hosted in the TigerGraph database.
There are two ways to train models: using the model’s .fit() function, or using the Trainer.
Both will be shown in the examples below. The .fit() method is less flexible but easier to use, while the Trainer allows for more flexibility.
Homogeneous Vertex Classification using GraphSAGE and .fit()
In this example, we are going to use the GraphSAGEForVertexClassification model and its .fit() function.
The dataloader definition takes advantage of the factory function producing multiple dataloaders when multiple filters are defined.
After training for 5 epochs, we also get a batch for inference and use the .predict()
function of the model to perform the inference.
from pyTigerGraph.gds.models.GraphSAGE import GraphSAGEForVertexClassification
from pyTigerGraph import TigerGraphConnection
import torch
# Open a connection to the "Ethereum" graph and authenticate with a token
# generated from a freshly created secret.
conn = TigerGraphConnection(
    "https://your_domain.i.tgcloud.io",
    "Ethereum",
    username="user_1",
    password="MyPassword1!",
)
secret = conn.createSecret()
conn.getToken(secret)
print("Connected!")

# Vertex attributes passed to the loader as input features.
input_features = [
    "in_degree",
    "out_degree",
    "send_amount",
    "send_min",
    "recv_amount",
    "recv_min",
    "pagerank",
]

# Three filters are given, so the loader factory hands back three loaders.
# NOTE(review): the empty-string filter presumably means "no filtering" for
# the inference loader -- confirm against the loader documentation.
split_filters = ["is_training", "is_validation", ""]

train_loader, valid_loader, infer_loader = conn.gds.neighborLoader(
    v_in_feats=input_features,
    v_out_labels=["is_fraud"],
    v_extra_feats=["is_training", "is_validation"],
    output_format="PyG",
    batch_size=512,
    num_neighbors=10,
    num_hops=2,
    filter_by=split_filters,
    shuffle=True,
    timeout=600000,
)

# Per-class weights handed to the model; the second class is weighted 15x.
fraud_class_weights = torch.FloatTensor([1.0, 15.0])

gs = GraphSAGEForVertexClassification(
    num_layers=2,
    out_dim=2,
    dropout=.2,
    hidden_dim=128,
    class_weights=fraud_class_weights,
)

# Train for 5 epochs using the model's own fit() helper.
gs.fit(train_loader, valid_loader, number_epochs=5)

# Fetch a batch for two specific "Account" vertices and run inference on it.
inference_ids = [
    "0x4579ed7fefb118026040a6f94dcce4a0091c7199",
    "0xdf9e9c6f8a3d4bc2989e722d4ad526e98f56a429",
]
inference_batch = infer_loader.fetch(
    [{"type": "Account", "primary_id": vertex_id} for vertex_id in inference_ids]
)
predictions = gs.predict(inference_batch)
print(predictions)
Heterogeneous Vertex Classification using GraphSAGE and the Trainer
This time, we will use the Trainer
to train our model on a heterogeneous dataset.
from pyTigerGraph.gds.models.GraphSAGE import GraphSAGEForVertexClassification
from pyTigerGraph.gds import Trainer
from pyTigerGraph import TigerGraphConnection
# Open a connection to the "imdb" graph and authenticate with a token
# generated from a freshly created secret.
conn = TigerGraphConnection(
    "https://your_domain.i.tgcloud.io",
    "imdb",
    username="user_1",
    password="MyPassword1!",
)
secret = conn.createSecret()
conn.getToken(secret)

# Boolean mask attributes on "Movie" vertices; one loader is produced per mask.
movie_masks = ["train_mask", "val_mask", "test_mask"]

# Every vertex type contributes its "x" attribute as input features.
features_by_type = {vertex_type: ["x"] for vertex_type in ("Movie", "Actor", "Director")}

train_loader, valid_loader, test_loader = conn.gds.neighborLoader(
    v_in_feats=features_by_type,
    v_out_labels={"Movie": ["y"]},
    v_extra_feats={"Movie": list(movie_masks)},
    output_format="PyG",
    batch_size=256,
    num_neighbors=10,
    num_hops=2,
    shuffle=True,
    filter_by=[{"Movie": mask} for mask in movie_masks],
)

print(train_loader.metadata())

# NOTE(review): positional arguments presumably map to
# (num_layers, out_dim, hidden_dim, dropout, heterogeneous metadata) --
# confirm against the model's constructor signature.
gs = GraphSAGEForVertexClassification(
    2,
    3,
    256,
    .2,
    train_loader.metadata(),
)

# Drive training through the Trainer instead of the model's fit() helper.
trainer = Trainer(gs, train_loader, valid_loader)
trainer.train(5)  # presumably the number of epochs -- confirm
Link Prediction with GraphSAGE and the Trainer
Now, we are going to train a link prediction model using GraphSAGE and the Trainer.
from pyTigerGraph.gds.models.GraphSAGE import GraphSAGEForLinkPrediction
from pyTigerGraph.gds import Trainer
from pyTigerGraph import TigerGraphConnection
# Open a connection to the "Cora" graph and authenticate with a token
# generated from a freshly created secret.
conn = TigerGraphConnection(
    "https://your_domain.i.tgcloud.io",
    "Cora",
    username="user_1",
    password="MyPassword1!",
)
secret = conn.createSecret()
conn.getToken(secret)

# Edge attributes used both as extra features and as the per-loader filters;
# two filters are given, so two loaders come back.
edge_split_flags = ["is_train", "is_val"]

train_edge_neighbor_loader, val_edge_neighbor_loader = conn.gds.edgeNeighborLoader(
    v_in_feats=["x"],
    v_out_labels=["y"],
    num_batches=5,
    e_extra_feats=list(edge_split_flags),
    output_format="PyG",
    num_neighbors=10,
    num_hops=2,
    filter_by=list(edge_split_flags),
    shuffle=False,
)

# NOTE(review): positional arguments presumably map to
# (num_layers, embedding_dim, hidden_dim, dropout) -- confirm against the
# model's constructor signature.
gs = GraphSAGEForLinkPrediction(2, 256, 256, 0.2)

# The Trainer is pointed at the "Cite" type as its prediction target.
trainer = Trainer(
    gs,
    train_edge_neighbor_loader,
    val_edge_neighbor_loader,
    target_type="Cite",
)
trainer.train(5)  # presumably the number of epochs -- confirm
Vertex Classification with NodePiece and .fit()
We are going to train a NodePiece MLP model using the .fit()
function and the NodePieceMLP built-in model.
from pyTigerGraph import TigerGraphConnection
from pyTigerGraph.gds.models.NodePieceMLP import NodePieceMLPForVertexClassification
from pyTigerGraph.gds.transforms.nodepiece_transforms import NodePieceMLPTransform
import torch
# Open a connection to the "Ethereum" graph and authenticate with a token
# generated from a freshly created secret.
conn = TigerGraphConnection(
    "https://your_domain.i.tgcloud.io",
    "Ethereum",
    username="user_1",
    password="MyPassword1!",
)
secret = conn.createSecret()
conn.getToken(secret)
print("Connected!")

# Attributes the transform treats as features; "is_fraud" is the label.
transform_features = [
    "in_degree",
    "out_degree",
    "send_amount",
    "send_min",
    "recv_amount",
    "recv_min",
    "pagerank",
    "betweenness",
]
t = NodePieceMLPTransform(transform_features, label="is_fraud")

# Attributes the loader fetches: the feature columns plus the label column.
loader_features = [
    "in_degree",
    "out_degree",
    "send_amount",
    "send_min",
    "recv_amount",
    "recv_min",
    "pagerank",
    "is_fraud",
    "betweenness",
]

# Two filters are given, so two loaders (training, validation) come back.
train_loader, valid_loader = conn.gds.nodepieceLoader(
    v_feats=loader_features,
    target_vertex_types="Account",
    clear_cache=True,
    compute_anchors=True,
    filter_by=["is_training", "is_validation"],
    anchor_percentage=0.1,
    max_anchors=10,
    max_distance=10,
    batch_size=2048,
    use_cache=False,
    shuffle=False,
    reverse_edge=True,
    # Whatever the loader passes to the callback is run through the transform.
    callback_fn=lambda data: t(data),
    timeout=600_000,
)

# Per-class weights handed to the model; the second class is weighted 15x.
fraud_class_weights = torch.FloatTensor([1, 15])

model = NodePieceMLPForVertexClassification(
    num_layers=4,
    hidden_dim=128,
    out_dim=2,
    dropout=0.5,
    # Vocabulary size is taken from the training loader.
    vocab_size=train_loader.num_tokens,
    sequence_length=20,
    class_weights=fraud_class_weights,
)

# NOTE(review): the third positional argument is presumably the number of
# epochs -- confirm against the model's fit() signature.
model.fit(train_loader, valid_loader, 5)

evaluation_metrics = model.trainer.get_eval_metrics()
print(evaluation_metrics)