NodePiece Models
NodePieceMLPForVertexClassification
This model trains a multi-layer perceptron (MLP) on batches produced by NodePiece dataloaders and transformed by the NodePieceMLPTransform.
The architecture is for a vertex classification task, and assumes the label of each vertex is in a batch attribute called "y", such as what is produced by the NodePieceMLPTransform.
By default, this model collects ClassificationMetrics and uses cross entropy as its loss function.
__init__()
__init__(num_layers: int, out_dim: int, hidden_dim: int, vocab_size: int, sequence_length: int, embedding_dim = 768, dropout = 0.0, class_weights = None)
Initializes a NodePieceMLPForVertexClassification model.
Parameters:
- num_layers (int): The total number of layers in your model.
- out_dim (int): The output dimension of the model, a.k.a. the number of classes in the classification task.
- hidden_dim (int): The hidden dimension of your model.
- vocab_size (int): The number of tokens produced by NodePiece. Can be accessed via the dataloader using loader.num_tokens.
- sequence_length (int): The number of tokens used to represent a single data instance. It is the sum of max_anchors and max_relational_context defined in the dataloader.
- embedding_dim (int): The dimension to embed the tokens in. Defaults to 768.
- dropout (float): The dropout rate applied after every layer of the model (excluding the output layer). Defaults to 0.0.
- class_weights (torch.Tensor): Weight the importance of each class in the classification task when computing loss. Helpful in imbalanced classification tasks. Defaults to None.
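As a minimal sketch of how these arguments fit together, the snippet below derives sequence_length from the two dataloader settings and collects the constructor arguments in a dictionary. All numeric values, and the num_tokens stand-in for loader.num_tokens, are hypothetical placeholders; the constructor call is left commented out because it needs pyTigerGraph and PyTorch installed and the class imported.

```python
# Hypothetical dataloader settings; use the values your NodePiece
# dataloader was actually created with.
max_anchors = 10
max_relational_context = 5
num_tokens = 1000  # stand-in for loader.num_tokens

model_kwargs = {
    "num_layers": 3,
    "out_dim": 7,          # number of classes in the task
    "hidden_dim": 128,
    "vocab_size": num_tokens,
    # sequence_length is the sum of the two dataloader settings
    "sequence_length": max_anchors + max_relational_context,
    "embedding_dim": 768,  # documented default
    "dropout": 0.1,
    "class_weights": None,
}

# With pyTigerGraph and PyTorch available and the class imported:
# model = NodePieceMLPForVertexClassification(**model_kwargs)
```

Keeping the arguments in one dictionary makes it easy to log the configuration alongside training results.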
forward()
forward(batch, get_probs = False)
Make a forward pass.
Parameters:
- batch: The batch of data, in the same format as the data produced by NodePieceMLPTransform.
- get_probs (bool, optional): Return the softmax scores of the raw logits, which can be interpreted as probabilities. Defaults to False.
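To illustrate what get_probs changes, here is a plain-Python sketch of the softmax step applied to one row of raw logits. It is not the model's own code (which presumably operates on PyTorch tensors); the softmax helper and the example logits are hypothetical and only show the arithmetic.

```python
import math

def softmax(logits):
    """Convert one row of raw logits into scores that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

raw_logits = [2.0, 0.5, -1.0]  # made-up logits for a 3-class task
probs = softmax(raw_logits)

# The largest logit always maps to the largest probability, so the
# predicted class is the same with or without the softmax.
predicted_class = probs.index(max(probs))
```

Because softmax preserves the ordering of the logits, get_probs affects how the scores are interpreted, not which class ranks highest.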
compute_loss()
compute_loss(logits, batch, loss_fn = None)
Compute loss.
Parameters:
- logits (torch.Tensor): The output of the model.
- batch: The batch of data, in the same format as the data produced by NodePieceMLPTransform.
- loss_fn: A PyTorch-compatible function to produce the loss of the model, which takes in the logits, the labels, and optionally the class_weights. Defaults to cross entropy.
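The sketch below shows, in plain Python, the shape of a loss function that takes logits, labels, and optional class weights, and how class weights shift a cross-entropy loss. The weighted_cross_entropy helper is hypothetical and works on lists rather than tensors; it normalizes by the sum of the applied weights, which mirrors PyTorch's default mean reduction for weighted cross entropy. Whether the model's built-in loss matches it exactly is an assumption.

```python
import math

def weighted_cross_entropy(logits, labels, class_weights=None):
    """Mean cross entropy over a batch, optionally weighted per class.

    logits: list of rows, one list of raw scores per vertex.
    labels: list of integer class indices, one per vertex.
    class_weights: optional list with one weight per class.
    """
    num_classes = len(logits[0])
    if class_weights is None:
        class_weights = [1.0] * num_classes
    weighted_sum = 0.0
    weight_total = 0.0
    for row, label in zip(logits, labels):
        shift = max(row)  # stabilize the log-sum-exp
        log_norm = shift + math.log(sum(math.exp(x - shift) for x in row))
        nll = log_norm - row[label]  # negative log-likelihood of the true class
        weighted_sum += class_weights[label] * nll
        weight_total += class_weights[label]
    return weighted_sum / weight_total

# Made-up batch: the third vertex (true class 1) is misclassified.
logits = [[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]]
labels = [0, 1, 1]
plain = weighted_cross_entropy(logits, labels)
upweighted = weighted_cross_entropy(logits, labels, class_weights=[1.0, 3.0])
```

In this example, giving class 1 a larger weight raises the loss, because the misclassified vertex belongs to that class. This is the effect class_weights is meant to have on imbalanced tasks.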