GraphSAGE Models

GraphSAGEForVertexClassification

Use a GraphSAGE model to classify vertices. By default, this model collects ClassificationMetrics, and uses cross entropy as its loss function.

__init__()

__init__(num_layers: int, out_dim: int, hidden_dim: int, dropout = 0.0, heterogeneous = None, class_weights = None)

Initialize the GraphSAGE Vertex Classification Model.

Parameters:

  • num_layers (int): The number of layers in the model. Typically corresponds to num_hops in the dataloader.

  • out_dim (int): The number of output dimensions. Corresponds to the number of classes in the classification task.

  • hidden_dim (int): The hidden dimension to use.

  • dropout (float, optional): The amount of dropout to apply between layers. Defaults to 0.0.

  • heterogeneous (tuple, optional): If set, use the graph metadata in the PyG heterogeneous metadata format. Can also retrieve this from the dataloader by calling loader.metadata(). Defaults to None.

  • class_weights (torch.Tensor, optional): If set, weight the different classes in the loss function. Useful for imbalanced classification tasks. Defaults to None.
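The class_weights parameter expects one weight per class. The source does not say how to choose the weights, so the following is only a minimal sketch of one common heuristic (inverse class frequency) in plain Python. The helper name inverse_frequency_weights is hypothetical and not part of the library; the resulting list would still need to be converted to a tensor, for example with torch.tensor(weights, dtype=torch.float), before being passed as class_weights.

```python
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """Return one weight per class, proportional to 1 / class frequency.

    Hypothetical helper for illustration; not part of the library.
    """
    counts = Counter(labels)
    total = len(labels)
    weights = []
    for cls in range(num_classes):
        n = counts.get(cls, 0)
        # Classes absent from the sample get weight 0.0 to avoid dividing by zero.
        weights.append(total / (num_classes * n) if n else 0.0)
    return weights


# Toy training labels: 6 vertices of class 0 and 2 vertices of class 1.
labels = [0, 0, 0, 0, 0, 0, 1, 1]
weights = inverse_frequency_weights(labels, num_classes=2)
print(weights)  # the rarer class 1 receives the larger weight
```

With this scheme, a perfectly balanced dataset yields a weight of 1.0 for every class, so the weighted loss reduces to the unweighted one.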

forward()

forward(batch, get_probs = False, target_type = None)

Make a forward pass.

Parameters:

  • batch (torch_geometric.Data or torch_geometric.HeteroData): The PyTorch Geometric data object to classify.

  • get_probs (bool, optional): If True, return the softmax of the raw logits, which can be interpreted as class probabilities. Defaults to False.

  • target_type (str, optional): Name of the vertex type to get the logits of. Defaults to None, and will return logits for all vertex types.
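To make the effect of get_probs concrete, here is a small plain-Python sketch of the softmax transformation applied to the logits of a single vertex. It illustrates the math only and is not the library's implementation; the logit values are made up.

```python
import math


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    # Subtracting the maximum keeps exp() from overflowing; the result is unchanged.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for one vertex in a three-class task.
logits = [2.0, 1.0, 0.1]
probs = softmax(logits)

# The predicted class is the index with the highest probability.
predicted_class = max(range(len(probs)), key=probs.__getitem__)
print(probs, predicted_class)
```

Softmax preserves the ordering of the logits, so the predicted class is the same whether it is taken from the logits or from the probabilities.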

compute_loss()

compute_loss(logits, batch, target_type = None, loss_fn = None)

Compute loss.

Parameters:

  • logits (torch.Tensor or dict of torch.Tensor): The output of the forward pass.

  • batch (torch_geometric.Data or torch_geometric.HeteroData): The PyTorch Geometric data object to classify. Assumes the target labels are stored in the "y" attribute of the data object.

  • target_type (str, optional): The string of the vertex type to compute the loss on.

  • loss_fn (callable, optional): The function to compute the loss with. Uses cross entropy loss if not defined.
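As an illustration of the default loss, the sketch below computes class-weighted cross entropy in plain Python. It assumes the weighted-mean convention used by PyTorch's CrossEntropyLoss with mean reduction (the sum of weighted per-sample losses divided by the sum of the weights of the target classes). Whether the library reduces in exactly this way is an assumption, and the function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log of the softmax probabilities."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_sum for x in logits]


def cross_entropy(batch_logits, targets, class_weights=None):
    """Weighted mean of the negative log-likelihood of each target class."""
    num_classes = len(batch_logits[0])
    w = class_weights if class_weights is not None else [1.0] * num_classes
    weighted_sum = 0.0
    weight_total = 0.0
    for logits, y in zip(batch_logits, targets):
        weighted_sum += -w[y] * log_softmax(logits)[y]
        weight_total += w[y]
    return weighted_sum / weight_total


# Two vertices, two classes. The second vertex (class 1) is predicted poorly.
batch_logits = [[2.0, 0.0], [0.0, 0.0]]
targets = [0, 1]

unweighted = cross_entropy(batch_logits, targets)
weighted = cross_entropy(batch_logits, targets, class_weights=[1.0, 3.0])
print(unweighted, weighted)
```

Up-weighting class 1 makes the poorly predicted class-1 vertex dominate the average, so the weighted loss is larger than the unweighted one in this example.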

GraphSAGEForVertexRegression

Use GraphSAGE for vertex regression tasks. By default, this model collects RegressionMetrics, and uses MSE as its loss function.

__init__()

__init__(num_layers: int, out_dim: int, hidden_dim: int, dropout = 0.0, heterogeneous = None)

Initialize the GraphSAGE Vertex Regression Model.

Parameters:

  • num_layers (int): The number of layers in the model. Typically corresponds to num_hops in the dataloader.

  • out_dim (int): The dimension of the output. Corresponds to the size of the vector being regressed.

  • hidden_dim (int): The hidden dimension to use.

  • dropout (float, optional): The amount of dropout to apply between layers. Defaults to 0.0.

  • heterogeneous (tuple, optional): If set, use the graph metadata in the PyG heterogeneous metadata format. Can also retrieve this from the dataloader by calling loader.metadata(). Defaults to None.

forward()

forward(batch, target_type = None)

Make a forward pass.

Parameters:

  • batch (torch_geometric.Data or torch_geometric.HeteroData): The PyTorch Geometric data object to run the regression on.

  • target_type (str, optional): Name of the vertex type to get the predictions of. Defaults to None, and will return predictions for all vertex types.

compute_loss()

compute_loss(logits, batch, target_type = None, loss_fn = None)

Compute loss.

Parameters:

  • logits (torch.Tensor or dict of torch.Tensor): The output of the forward pass.

  • batch (torch_geometric.Data or torch_geometric.HeteroData): The PyTorch Geometric data object to run the regression on. Assumes the target values are stored in the "y" attribute of the data object.

  • target_type (str, optional): The string of the vertex type to compute the loss on.

  • loss_fn (callable, optional): The function to compute the loss with. Uses MSE loss if not defined.
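The default MSE loss and a possible alternative can be sketched in plain Python as follows. This only illustrates the arithmetic on nested lists. A callable actually passed as loss_fn would need to operate on torch.Tensor inputs, and the exact arguments the library passes to it are not stated in this reference, so treat the function names and shapes here as assumptions.

```python
def mse_loss(predictions, targets):
    """Mean of squared differences over every element of every output vector."""
    squared = [
        (p - t) ** 2
        for pred_row, target_row in zip(predictions, targets)
        for p, t in zip(pred_row, target_row)
    ]
    return sum(squared) / len(squared)


def mae_loss(predictions, targets):
    """Mean absolute error: an alternative that is less sensitive to outliers."""
    absolute = [
        abs(p - t)
        for pred_row, target_row in zip(predictions, targets)
        for p, t in zip(pred_row, target_row)
    ]
    return sum(absolute) / len(absolute)


# Two vertices, each with a 2-dimensional regression target (out_dim = 2).
predictions = [[1.0, 2.0], [3.0, 4.0]]
targets = [[1.5, 2.0], [2.0, 4.0]]

mse = mse_loss(predictions, targets)
mae = mae_loss(predictions, targets)
print(mse, mae)
```

MSE penalizes the single large error (3.0 versus 2.0) quadratically, whereas MAE weights all errors linearly, which is one reason to supply a custom loss_fn.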

GraphSAGEForLinkPrediction

Use GraphSAGE for link prediction tasks. By default, this model collects LinkPredictionMetrics with k = 10, and uses binary cross entropy as its loss function.

__init__()

__init__(num_layers, embedding_dim, hidden_dim, dropout = 0.0, heterogeneous = None)

Initialize the GraphSAGE Link Prediction Model.

Parameters:

  • num_layers (int): The number of layers in the model. Typically corresponds to num_hops in the dataloader.

  • embedding_dim (int): The dimension of the embedding generated for each vertex. The cosine similarity between the embeddings of a pair of vertices is used as the prediction for the edge between them.

  • hidden_dim (int): The hidden dimension to use.

  • dropout (float, optional): The amount of dropout to apply between layers. Defaults to 0.0.

  • heterogeneous (tuple, optional): If set, use the graph metadata in the PyG heterogeneous metadata format. Can also retrieve this from the dataloader by calling loader.metadata(). Defaults to None.
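The embedding_dim description says that edges are scored by the cosine similarity of two vertex embeddings. The following plain-Python sketch shows that computation on made-up embeddings with embedding_dim = 3. It illustrates the formula only; it is not the library's code, and the zero-vector handling is an assumption made here for safety.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two embedding vectors, in [-1, 1]."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    # Assumption for this sketch: a zero vector has no direction, so return 0.0.
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


# Made-up embeddings for three vertices.
emb_a = [1.0, 2.0, 0.0]
emb_b = [2.0, 4.0, 0.0]  # same direction as emb_a
emb_c = [0.0, 0.0, 3.0]  # orthogonal to emb_a

score_ab = cosine_similarity(emb_a, emb_b)
score_ac = cosine_similarity(emb_a, emb_c)
print(score_ab, score_ac)
```

Because cosine similarity ignores vector length, only the direction of the embeddings matters: emb_a and emb_b score as highly similar even though their magnitudes differ.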

forward()

forward(batch, target_type = None)

Make a forward pass.

Parameters:

  • batch (torch_geometric.Data or torch_geometric.HeteroData): The PyTorch Geometric data object to run link prediction on.

  • target_type (str, optional): Name of the vertex type to get the logits of. Defaults to None, and will return logits for all vertex types.

compute_loss()

compute_loss(logits, batch, target_type = None, loss_fn = None)

Compute loss.

Parameters:

  • logits (torch.Tensor or dict of torch.Tensor): The output of the forward pass.

  • batch (torch_geometric.Data or torch_geometric.HeteroData): The PyTorch Geometric data object to run link prediction on. Assumes the target is stored in the "y" attribute of the data object.

  • target_type (str, optional): The string of the edge type to compute the loss on.

  • loss_fn (callable, optional): The function to compute the loss with. Uses binary cross entropy loss if not defined.
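For reference, binary cross entropy over a set of candidate edges can be sketched in plain Python as below. This version treats each edge score as a raw logit and applies the sigmoid internally, in a numerically stable form. Whether the library feeds logits or probabilities to its loss is not stated here, so this is an assumption, and the function name is hypothetical.

```python
import math


def bce_with_logits(scores, labels):
    """Mean binary cross entropy, treating each score as a raw logit.

    Uses the stable identity:
        loss = max(s, 0) - s * y + log(1 + exp(-|s|))
    which equals -[y * log(sigmoid(s)) + (1 - y) * log(1 - sigmoid(s))].
    """
    losses = [
        max(s, 0.0) - s * y + math.log1p(math.exp(-abs(s)))
        for s, y in zip(scores, labels)
    ]
    return sum(losses) / len(losses)


# Labels: 1 for an edge that exists, 0 for a negative (non-existent) edge.
labels = [1, 0]

good_loss = bce_with_logits([2.0, -2.0], labels)  # scores agree with labels
bad_loss = bce_with_logits([-2.0, 2.0], labels)   # scores contradict labels
print(good_loss, bad_loss)
```

Scores that agree with the labels produce a small loss, while confidently wrong scores are penalized heavily.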

get_embeddings()

get_embeddings(batch)

Get the vertex embeddings produced by the model.

Parameter:

  • batch (torch_geometric.Data or torch_geometric.HeteroData): The PyTorch Geometric data object to embed. Embeddings are returned for all vertices in the batch.