NodePiece Models


This model trains a multi-layer perceptron (MLP) on batches produced by NodePiece dataloaders and transformed by NodePieceMLPTransform. The architecture targets vertex classification and assumes that the label of each vertex is stored in a batch attribute called "y", as produced by NodePieceMLPTransform. By default, the model collects ClassificationMetrics and uses cross entropy as its loss function.


init(num_layers: int, out_dim: int, hidden_dim: int, vocab_size: int, sequence_length: int, embedding_dim = 768, dropout = 0.0, class_weights = None)

Initializes a NodePieceMLPForVertexClassification model.


  • num_layers (int): The total number of layers in your model.

  • out_dim (int): The output dimension of the model, a.k.a. the number of classes in the classification task.

  • hidden_dim (int): The hidden dimension of your model.

  • vocab_size (int): The number of tokens produced by NodePiece. It can be accessed via the dataloader as loader.num_tokens.

  • sequence_length (int): The number of tokens used to represent a single data instance. It is the sum of max_anchors and max_relational_context as defined in the dataloader.

  • embedding_dim (int): The dimension of the token embeddings. Defaults to 768.

  • dropout (float): The dropout rate applied after every layer of the model except the output layer. Defaults to 0.0.

  • class_weights (torch.Tensor): Weight the importance of each class in the classification task when computing loss. Helpful in imbalanced classification tasks.
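To make the role of each constructor argument concrete, the following is a minimal sketch of an MLP over NodePiece token embeddings. It is not the library's implementation: the class name SketchNodePieceMLP, the ReLU activations, the concatenation of token embeddings, and the reading of num_layers as the number of linear layers are all assumptions made for illustration.

```python
import torch
from torch import nn


class SketchNodePieceMLP(nn.Module):
    """Hypothetical stand-in used only to illustrate the parameters."""

    def __init__(self, num_layers, out_dim, hidden_dim, vocab_size,
                 sequence_length, embedding_dim=768, dropout=0.0):
        super().__init__()
        # One embedding vector per NodePiece token.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Assumption: the token embeddings of one vertex are concatenated.
        in_dim = sequence_length * embedding_dim
        dims = [in_dim] + [hidden_dim] * (num_layers - 1)
        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            # Dropout follows every layer except the output layer.
            layers += [nn.Linear(d_in, d_out), nn.ReLU(), nn.Dropout(dropout)]
        # Output layer: no activation or dropout, returns raw logits.
        layers.append(nn.Linear(dims[-1], out_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, token_ids):
        # token_ids: integer tensor of shape (batch_size, sequence_length)
        x = self.embedding(token_ids)       # (batch, seq_len, emb_dim)
        x = torch.flatten(x, start_dim=1)   # (batch, seq_len * emb_dim)
        return self.mlp(x)                  # (batch, out_dim)


model = SketchNodePieceMLP(num_layers=3, out_dim=4, hidden_dim=16,
                           vocab_size=100, sequence_length=20,
                           embedding_dim=8)
token_ids = torch.randint(0, 100, (5, 20))  # 5 vertices, 20 tokens each
logits = model(token_ids)
```

In this sketch, sequence_length and embedding_dim together fix the input width of the first linear layer, which is why sequence_length must match the dataloader settings.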


forward(batch, get_probs = False)

Make a forward pass.


  • batch: The batch of data, in the same format as the data produced by NodePieceMLPTransform.

  • get_probs (bool, optional): Return the softmax scores of the raw logits, which can be interpreted as probabilities. Defaults to False.
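As an illustration of what get_probs changes, the sketch below applies a softmax to hand-written logits using plain PyTorch. The tensor values are made up, and the snippet does not call the model itself; it only shows the transformation from raw logits to per-class probabilities.

```python
import torch

# Made-up raw logits for two vertices and three classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.0, 0.0, 0.0]])

# Softmax over the class dimension turns each row into probabilities.
probs = torch.softmax(logits, dim=-1)

# The most likely class is the same whether taken from logits or probs,
# because softmax preserves the ordering within a row.
predicted = probs.argmax(dim=-1)
```

Each row of probs sums to 1, and identical logits (the second row) yield a uniform distribution.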


compute_loss(logits, batch, loss_fn = None)

Compute loss.


  • logits (torch.Tensor): The output of the model.

  • batch: The batch of data, in the same format as the data produced by NodePieceMLPTransform.

  • loss_fn: A PyTorch-compatible function that computes the loss of the model. It takes the logits, the labels, and optionally the class_weights. Defaults to cross entropy.
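A custom loss_fn could look like the sketch below, which wraps PyTorch's cross entropy and forwards optional class weights. The function name weighted_cross_entropy and the positional order (logits, labels, class_weights) are assumptions based on the description above; the exact calling convention used by compute_loss should be checked against the library source.

```python
import torch
import torch.nn.functional as F


def weighted_cross_entropy(logits, labels, class_weights=None):
    # Hypothetical loss_fn: cross entropy with optional per-class weights.
    return F.cross_entropy(logits, labels, weight=class_weights)


# Made-up logits and labels for three vertices and two classes.
logits = torch.tensor([[2.0, 0.0],
                       [0.0, 2.0],
                       [1.5, 0.0]])
labels = torch.tensor([0, 1, 1])

# Weight class 1 three times as heavily as class 0.
class_weights = torch.tensor([1.0, 3.0])

unweighted = weighted_cross_entropy(logits, labels)
weighted = weighted_cross_entropy(logits, labels, class_weights)
```

Here the third vertex belongs to class 1 but is misclassified, so raising the weight of class 1 increases the loss. This is the effect class_weights is meant to have on imbalanced tasks.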