NodePiece Models
NodePieceMLPForVertexClassification
This model trains a multi-layer perceptron (MLP) on batches produced by NodePiece dataloaders and transformed by the NodePieceMLPTransform.
The architecture is designed for vertex classification and assumes that the label of each vertex is stored in a batch attribute named "y", as produced by the NodePieceMLPTransform.
By default, this model collects ClassificationMetrics and uses cross-entropy as its loss function.
__init__()
__init__(num_layers: int, out_dim: int, hidden_dim: int, vocab_size: int, sequence_length: int, embedding_dim = 768, dropout = 0.0, class_weights = None)
Initialize a NodePieceMLPForVertexClassification model.
Parameters:
- num_layers (int): The total number of layers in the model.
- out_dim (int): The output dimension of the model, i.e. the number of classes in the classification task.
- hidden_dim (int): The hidden dimension of the model.
- vocab_size (int): The number of tokens produced by NodePiece. Can be read from the dataloader via loader.num_tokens.
- sequence_length (int): The number of tokens used to represent a single data instance. This is the sum of max_anchors and max_relational_context as defined in the dataloader.
- embedding_dim (int): The dimension of the token embeddings. Defaults to 768.
- dropout (float): The dropout probability (a fraction between 0.0 and 1.0) applied after every layer of the model except the output layer. Defaults to 0.0.
- class_weights (torch.Tensor): Per-class weights applied when computing the loss, which is helpful in imbalanced classification tasks. Defaults to None.
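The following sketch shows how the constructor arguments might fit together. It is a hypothetical, pure-Python illustration, not pyTigerGraph code: the helper names nodepiece_sequence_length and mlp_layer_shapes are made up for this example, and it assumes that each of the sequence_length tokens is embedded into embedding_dim values and that these embeddings are flattened into one vector before the first linear layer.

```python
# Hypothetical helpers, NOT part of pyTigerGraph. They only illustrate how the
# constructor arguments could relate, under two assumptions:
#   1. sequence_length = max_anchors + max_relational_context (as documented)
#   2. token embeddings are flattened, so the MLP input has
#      sequence_length * embedding_dim features.
from typing import List, Tuple


def nodepiece_sequence_length(max_anchors: int, max_relational_context: int) -> int:
    """Number of tokens that represent one vertex."""
    return max_anchors + max_relational_context


def mlp_layer_shapes(num_layers: int, in_dim: int, hidden_dim: int,
                     out_dim: int) -> List[Tuple[int, int]]:
    """Return (input_dim, output_dim) for each of the num_layers linear layers.

    num_layers is treated as the total layer count, including the output
    layer, so num_layers=1 maps the input straight to out_dim.
    """
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")
    dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
    return list(zip(dims[:-1], dims[1:]))


seq_len = nodepiece_sequence_length(max_anchors=10, max_relational_context=10)
in_dim = seq_len * 768  # assumption: flattened embeddings with embedding_dim=768
print(mlp_layer_shapes(num_layers=3, in_dim=in_dim, hidden_dim=128, out_dim=4))
```

Under these assumptions, num_layers=3 yields an input layer, one hidden-to-hidden layer, and an output layer producing out_dim logits, and the first layer dominates the parameter count because its input width grows with both sequence_length and embedding_dim.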
forward()
forward(batch, get_probs = False)
Make a forward pass.
Parameters:
- batch: The batch of data, in the same format as the data produced by NodePieceMLPTransform.
- get_probs (bool, optional): If True, return the softmax of the raw logits, which can be interpreted as class probabilities, instead of the logits themselves. Defaults to False.
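With get_probs=True, forward() returns softmax scores rather than raw logits. The pure-Python sketch below only illustrates what that transform does to one vertex's row of logits; the model itself operates on torch tensors, and the softmax function here is written for this example rather than taken from the library.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Turn one row of raw logits into probabilities that sum to 1."""
    # Subtract the largest logit first so math.exp cannot overflow;
    # the shift cancels out in the division and leaves the result unchanged.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# One row of logits per vertex; softmax is applied across the class dimension.
batch_logits = [[2.0, 1.0, 0.1], [0.0, 3.0, 0.5]]
batch_probs = [softmax(row) for row in batch_logits]
predicted = [row.index(max(row)) for row in batch_probs]
print(batch_probs)
print(predicted)  # predicted class index for each vertex
```

Because softmax is monotonic, the predicted class (the argmax) is the same whether it is taken over the logits or the probabilities; get_probs matters only when calibrated-looking scores are needed.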
compute_loss()
compute_loss(logits, batch, loss_fn = None)
Compute the loss.
Parameters:
- logits (torch.Tensor): The output of the model.
- batch: The batch of data, in the same format as the data produced by NodePieceMLPTransform.
- loss_fn: A PyTorch-compatible function that computes the loss of the model. It takes the logits, the labels, and optionally the class_weights. Defaults to cross-entropy.
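The default loss is cross-entropy, optionally weighted by class_weights. The list-based sketch below illustrates that computation and one common way to choose class weights; it is not the library's implementation. It assumes the weighted variant averages like PyTorch's torch.nn.functional.cross_entropy with a weight argument and mean reduction (each sample scaled by the weight of its true class, then divided by the sum of those weights). The function names are hypothetical, and a real loss_fn would receive torch.Tensor inputs; how the model passes class_weights to a custom loss_fn (positionally or by keyword) is not specified here.

```python
import math
from typing import List, Optional, Sequence


def weighted_cross_entropy(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    class_weights: Optional[Sequence[float]] = None,
) -> float:
    """Mean cross-entropy over a batch, optionally weighted per class."""
    num_classes = len(logits[0])
    weights = class_weights if class_weights is not None else [1.0] * num_classes
    total_loss = 0.0
    total_weight = 0.0
    for row, y in zip(logits, labels):
        m = max(row)
        # Stable log-sum-exp: log(sum(exp(z))) = m + log(sum(exp(z - m)))
        log_norm = m + math.log(sum(math.exp(z - m) for z in row))
        nll = log_norm - row[y]  # equals -log(softmax(row)[y])
        total_loss += weights[y] * nll
        total_weight += weights[y]
    # Weighted mean: divide by the summed weights of the true classes.
    return total_loss / total_weight


def inverse_frequency_weights(labels: Sequence[int], num_classes: int) -> List[float]:
    """n_samples / (num_classes * count): rarer classes get larger weights."""
    counts = [0] * num_classes
    for y in labels:
        counts[y] += 1
    n = len(labels)
    # A class absent from labels gets weight 0.0; it never contributes anyway.
    return [n / (num_classes * c) if c else 0.0 for c in counts]


labels = [0, 0, 0, 1]  # imbalanced toy labels
logits = [[2.0, 0.5], [1.5, 0.2], [0.3, 0.1], [0.2, 1.0]]
weights = inverse_frequency_weights(labels, num_classes=2)
print(weights)
print(weighted_cross_entropy(logits, labels))
print(weighted_cross_entropy(logits, labels, weights))
```

The inverse-frequency formula is the same heuristic scikit-learn calls "balanced"; it is only one option, and any per-class tensor passed as class_weights would play the same role.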