GraphSAGE Models
GraphSAGEForVertexClassification
Use a GraphSAGE model to classify vertices. By default, this model collects ClassificationMetrics and uses cross entropy as its loss function.
__init__()
__init__(num_layers: int, out_dim: int, hidden_dim: int, dropout = 0.0, heterogeneous = None, class_weights = None)
Initialize the GraphSAGE Vertex Classification Model.
Parameters:
- num_layers (int): The number of layers in the model. Typically corresponds to num_hops in the dataloader.
- out_dim (int): The number of output dimensions. Corresponds to the number of classes in the classification task.
- hidden_dim (int): The hidden dimension to use.
- dropout (float, optional): The amount of dropout to apply between the layers. Defaults to 0.0.
- heterogeneous (tuple, optional): If set, use the graph metadata in the PyG heterogeneous metadata format. The metadata can also be retrieved from the dataloader by calling loader.metadata(). Defaults to None.
- class_weights (torch.Tensor, optional): If set, weight the different classes in the loss function. Useful for imbalanced classification tasks. Defaults to None.
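The reason num_layers usually matches num_hops is that each GraphSAGE layer aggregates information from a vertex's immediate neighbors, so after k layers a vertex's representation can only depend on vertices within k hops. If the dataloader samples fewer hops than the model has layers, the extra layers cannot see beyond the sampled subgraph. The pure-Python sketch below is illustrative only (not pyTigerGraph code); the function name receptive_field and the example graph are hypothetical.

```python
from collections import deque

def receptive_field(adj, start, num_layers):
    """Return the vertices whose features can reach `start` after
    `num_layers` rounds of neighbor aggregation (a BFS to depth num_layers)."""
    seen = {start}
    frontier = deque([(start, 0)])
    while frontier:
        vertex, depth = frontier.popleft()
        if depth == num_layers:
            continue  # reaching one hop further would need one more layer
        for neighbor in adj.get(vertex, []):
            if neighbor not in seen:
                seen.add(neighbor)
                frontier.append((neighbor, depth + 1))
    return seen

# Hypothetical undirected path graph: 0 - 1 - 2 - 3 - 4
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}

for layers in range(1, 4):
    print(layers, sorted(receptive_field(path, 0, layers)))
```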
forward()
forward(batch, get_probs = False, target_type = None)
Make a forward pass.
Parameters:
- batch (torch_geometric.data.Data or torch_geometric.data.HeteroData): The PyTorch Geometric data object to classify.
- get_probs (bool, optional): If True, return the softmax of the raw logits, which can be interpreted as class probabilities. Defaults to False.
- target_type (str, optional): Name of the vertex type whose logits to return. Defaults to None, which returns logits for all vertex types.
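Setting get_probs to True passes the raw logits through a softmax, so each vertex's scores become non-negative and sum to 1. A minimal pure-Python sketch of that transformation for one vertex, assuming three hypothetical classes; subtracting the largest logit first is a standard numerical-stability step that does not change the result. Because softmax is monotonic, the predicted class (the argmax) is the same whether you use logits or probabilities.

```python
import math

def softmax(logits):
    """Convert one vertex's raw logits into probabilities."""
    shift = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]  # hypothetical logits for one vertex, 3 classes
probs = softmax(logits)
print([round(p, 3) for p in probs])

# The class with the highest probability is also the one with the highest logit.
predicted_class = max(range(len(probs)), key=probs.__getitem__)
print(predicted_class)
```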
compute_loss()
compute_loss(logits, batch, target_type = None, loss_fn = None)
Compute loss.
Parameters:
- logits (torch.Tensor or dict of torch.Tensor): The output of the forward pass.
- batch (torch_geometric.data.Data or torch_geometric.data.HeteroData): The PyTorch Geometric data object to classify. Assumes the target is stored in the "y" attribute of the data object.
- target_type (str, optional): The vertex type to compute the loss on.
- loss_fn (callable, optional): The function to compute the loss with. Uses cross entropy loss if not defined.
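When class_weights is set, each vertex's loss term is scaled by the weight of its true class. The sketch below computes weighted cross entropy in pure Python using the weighted-mean reduction that PyTorch's torch.nn.functional.cross_entropy applies when given a weight tensor (the sum of weighted terms divided by the sum of the weights used). It assumes the default loss behaves this way; it is an illustration rather than pyTigerGraph's implementation, and the logits, labels, and weights are made up.

```python
import math

def log_softmax(logits):
    shift = max(logits)  # numerically stable log-sum-exp
    log_total = math.log(sum(math.exp(x - shift) for x in logits)) + shift
    return [x - log_total for x in logits]

def weighted_cross_entropy(batch_logits, labels, class_weights=None):
    """Mean cross entropy over vertices, optionally weighted by true class."""
    num_classes = len(batch_logits[0])
    weights = class_weights or [1.0] * num_classes
    loss_sum, weight_sum = 0.0, 0.0
    for logits, label in zip(batch_logits, labels):
        w = weights[label]
        loss_sum += -w * log_softmax(logits)[label]
        weight_sum += w
    return loss_sum / weight_sum  # weighted mean, as in PyTorch's "mean" reduction

# Hypothetical batch: 3 vertices, 2 classes; class 1 is the rare class.
logits = [[2.0, 0.0], [1.5, 0.5], [0.2, 0.8]]
labels = [0, 0, 1]
print(round(weighted_cross_entropy(logits, labels), 4))
print(round(weighted_cross_entropy(logits, labels, class_weights=[1.0, 4.0]), 4))
```

Up-weighting the rare class makes its (here, larger) loss term dominate the average, which pushes training to pay more attention to it.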
GraphSAGEForVertexRegression
Use GraphSAGE for vertex regression tasks. By default, this model collects RegressionMetrics and uses mean squared error (MSE) as its loss function.
__init__()
__init__(num_layers: int, out_dim: int, hidden_dim: int, dropout = 0.0, heterogeneous = None)
Initialize the GraphSAGE Vertex Regression Model.
Parameters:
- num_layers (int): The number of layers in the model. Typically corresponds to num_hops in the dataloader.
- out_dim (int): The dimension of the output. Corresponds to the size of the vector being regressed.
- hidden_dim (int): The hidden dimension to use.
- dropout (float, optional): The amount of dropout to apply between layers. Defaults to 0.0.
- heterogeneous (tuple, optional): If set, use the graph metadata in the PyG heterogeneous metadata format. The metadata can also be retrieved from the dataloader by calling loader.metadata(). Defaults to None.
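For reference, PyG's HeteroData.metadata() returns a two-element tuple: a list of node types and a list of edge types, where each edge type is a (source type, relation, destination type) triple. The sketch below writes such a tuple out by hand for a hypothetical schema and checks it for consistency. In practice you would pass loader.metadata() instead of building it yourself; the vertex and edge names and the check_metadata helper are invented for illustration.

```python
# Hypothetical schema: "Paper" and "Author" vertices with two edge types.
metadata = (
    ["Paper", "Author"],                # vertex (node) types
    [
        ("Author", "writes", "Paper"),  # (source type, edge type, destination type)
        ("Paper", "cites", "Paper"),
    ],
)

def check_metadata(metadata):
    """Sanity-check a PyG-style (node_types, edge_types) metadata tuple."""
    node_types, edge_types = metadata
    for src, rel, dst in edge_types:
        if src not in node_types or dst not in node_types:
            raise ValueError(f"edge type {rel!r} references an unknown vertex type")
    return len(node_types), len(edge_types)

print(check_metadata(metadata))
```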
forward()
forward(batch, target_type = None)
Make a forward pass.
Parameters:
- batch (torch_geometric.data.Data or torch_geometric.data.HeteroData): The PyTorch Geometric data object to make predictions on.
- target_type (str, optional): Name of the vertex type whose outputs to return. Defaults to None, which returns outputs for all vertex types.
compute_loss()
compute_loss(logits, batch, target_type = None, loss_fn = None)
Compute loss.
Parameters:
- logits (torch.Tensor or dict of torch.Tensor): The output of the forward pass.
- batch (torch_geometric.data.Data or torch_geometric.data.HeteroData): The PyTorch Geometric data object to make predictions on. Assumes the target is stored in the "y" attribute of the data object.
- target_type (str, optional): The vertex type to compute the loss on.
- loss_fn (callable, optional): The function to compute the loss with. Uses MSE loss if not defined.
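Mean squared error averages the squared differences over every element of the output, so an out_dim greater than 1 simply contributes more terms to the average. A pure-Python sketch with hypothetical two-dimensional targets, assuming the plain mean reduction that PyTorch's MSE loss uses by default (an illustration, not pyTigerGraph's code):

```python
def mse(predictions, targets):
    """Mean squared error averaged over all vertices and output dimensions."""
    squared_errors = [
        (p - t) ** 2
        for pred_row, target_row in zip(predictions, targets)
        for p, t in zip(pred_row, target_row)
    ]
    return sum(squared_errors) / len(squared_errors)

# Hypothetical batch: 3 vertices, out_dim = 2.
predictions = [[1.0, 2.0], [0.5, 0.0], [3.0, 1.0]]
targets     = [[1.5, 2.0], [0.0, 1.0], [3.0, 0.0]]
print(mse(predictions, targets))
```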
GraphSAGEForLinkPrediction
Use GraphSAGE for link prediction tasks. By default, this model collects LinkPredictionMetrics with k = 10 and uses binary cross entropy as its loss function.
__init__()
__init__(num_layers, embedding_dim, hidden_dim, dropout = 0.0, heterogeneous = None)
Initialize the GraphSAGE Link Prediction Model.
Parameters:
- num_layers (int): The number of layers in the model. Typically corresponds to num_hops in the dataloader.
- embedding_dim (int): The dimension of the generated embedding. The cosine similarity between the embeddings of a pair of vertices is used to generate the prediction for the edge between them.
- hidden_dim (int): The hidden dimension to use.
- dropout (float, optional): The amount of dropout to apply between layers. Defaults to 0.0.
- heterogeneous (tuple, optional): If set, use the graph metadata in the PyG heterogeneous metadata format. The metadata can also be retrieved from the dataloader by calling loader.metadata(). Defaults to None.
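Following the description of embedding_dim, a candidate edge (u, v) is scored by the cosine similarity of the two vertex embeddings, a value between -1 and 1: similar embeddings give scores near 1, dissimilar ones scores near or below 0. A pure-Python sketch with hypothetical 3-dimensional embeddings; the vertex names and values are invented, and how the library turns this score into a final prediction is not covered here.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

embeddings = {  # hypothetical output embeddings, embedding_dim = 3
    "alice": [0.9, 0.1, 0.0],
    "bob":   [0.8, 0.2, 0.1],
    "carol": [-0.7, 0.0, 0.6],
}
for u, v in [("alice", "bob"), ("alice", "carol")]:
    print(u, v, round(cosine_similarity(embeddings[u], embeddings[v]), 3))
```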
forward()
forward(batch, target_type = None)
Make a forward pass.
Parameters:
- batch (torch_geometric.data.Data or torch_geometric.data.HeteroData): The PyTorch Geometric data object to make predictions on.
- target_type (str, optional): Name of the vertex type to get the logits of. Defaults to None, which returns logits for all vertex types.
compute_loss()
compute_loss(logits, batch, target_type = None, loss_fn = None)
Compute loss.
Parameters:
- logits (torch.Tensor or dict of torch.Tensor): The output of the forward pass.
- batch (torch_geometric.data.Data or torch_geometric.data.HeteroData): The PyTorch Geometric data object to make predictions on. Assumes the target is stored in the "y" attribute of the data object.
- target_type (str, optional): The edge type to compute the loss on.
- loss_fn (callable, optional): The function to compute the loss with. Uses binary cross entropy loss if not defined.
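Binary cross entropy treats each candidate edge as a yes/no question: observed edges are labeled 1 and non-edges 0. The pure-Python sketch below computes the mean loss directly from raw (pre-sigmoid) scores using the numerically stable softplus form, equivalent to PyTorch's BCEWithLogitsLoss with mean reduction. The scores and labels are hypothetical, and this reference does not say how the model obtains negative (label 0) edges, so treat that part as an assumption.

```python
import math

def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bce_with_logits(scores, labels):
    """Mean binary cross entropy from raw scores.
    -log(sigmoid(s)) == softplus(-s);  -log(1 - sigmoid(s)) == softplus(s)."""
    losses = [
        y * softplus(-s) + (1 - y) * softplus(s)
        for s, y in zip(scores, labels)
    ]
    return sum(losses) / len(losses)

# Hypothetical scores: two observed edges (label 1), two non-edges (label 0).
scores = [3.0, 1.2, -2.0, 0.5]
labels = [1, 1, 0, 0]
print(round(bce_with_logits(scores, labels), 4))
```

The non-edge scored 0.5 contributes most of the loss, since the model wrongly leans toward predicting an edge there.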