Model Trainer and Callbacks

Train Graph Machine Learning models (such as GraphSAGE and NodePiece) in a concise way. pyTigerGraph offers built-in models that can be used with the Trainer, consuming pyTigerGraph dataloaders.

Callbacks are classes that perform arbitrary operations at various stages of the training process. Inherit from the BaseCallback class to create compatible operations.

BaseCallback

The BaseCallback class is an abstract class that all other trainer callbacks inherit from. It defines a series of hook methods, each executed at a specific point in the trainer’s execution, such as the beginning or end of an epoch. Inherit from this class to implement a custom callback.

on_init_end()

on_init_end(trainer)

Run operations after the initialization of the trainer.

Parameter:

  • trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it.

on_epoch_start()

on_epoch_start(trainer)

Run operations at the start of a training epoch.

Parameter:

  • trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it.

on_train_step_start()

on_train_step_start(trainer)

Run operations at the start of a training step.

Parameter:

  • trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it.

on_train_step_end()

on_train_step_end(trainer)

Run operations at the end of a training step.

Parameter:

  • trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it.

on_epoch_end()

on_epoch_end(trainer)

Run operations at the end of an epoch.

Parameter:

  • trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it.

on_eval_start()

on_eval_start(trainer)

Run operations at the start of the evaluation process.

Parameter:

  • trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it.

on_eval_step_start()

on_eval_step_start(trainer)

Run operations at the start of an evaluation batch.

Parameter:

  • trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it.

on_eval_step_end()

on_eval_step_end(trainer)

Run operations at the end of an evaluation batch.

Parameter:

  • trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it.

on_eval_end()

on_eval_end(trainer)

Run operations at the end of the evaluation process.

Parameter:

  • trainer (pyTigerGraph Trainer): Takes in the trainer in order to perform operations on it.
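
The hooks above can be sketched with a small custom callback. This is a minimal illustration, not the library's implementation: the BaseCallback class below is a local stand-in so the snippet runs without pyTigerGraph installed (in real use, import it from pyTigerGraph.gds.trainer), and StepCounterCallback and run_toy_loop are hypothetical names. The toy loop fires the hooks in the order their names suggest; the real Trainer loop may differ in detail.

```python
class BaseCallback:
    """Local stand-in for pyTigerGraph.gds.trainer.BaseCallback.
    Assumption: every hook is a no-op that receives the trainer."""
    def on_init_end(self, trainer): pass
    def on_epoch_start(self, trainer): pass
    def on_train_step_start(self, trainer): pass
    def on_train_step_end(self, trainer): pass
    def on_epoch_end(self, trainer): pass
    def on_eval_start(self, trainer): pass
    def on_eval_step_start(self, trainer): pass
    def on_eval_step_end(self, trainer): pass
    def on_eval_end(self, trainer): pass


class StepCounterCallback(BaseCallback):
    """Hypothetical callback that counts finished training steps and epochs."""
    def __init__(self):
        self.steps = 0
        self.epochs = 0

    def on_train_step_end(self, trainer):
        self.steps += 1

    def on_epoch_end(self, trainer):
        self.epochs += 1


def run_toy_loop(callback, num_epochs, steps_per_epoch):
    """Toy driver that calls the training hooks in order. Not the real Trainer."""
    trainer = object()  # placeholder for a Trainer instance
    callback.on_init_end(trainer)
    for _ in range(num_epochs):
        callback.on_epoch_start(trainer)
        for _ in range(steps_per_epoch):
            callback.on_train_step_start(trainer)
            callback.on_train_step_end(trainer)
        callback.on_epoch_end(trainer)


counter = StepCounterCallback()
run_toy_loop(counter, num_epochs=2, steps_per_epoch=3)
print(counter.epochs, counter.steps)
```

Only the hooks a callback cares about need to be overridden; the rest fall through to the base class.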

PrinterCallback

To use PrinterCallback, import the class and pass it in the Trainer’s callbacks argument.

from pyTigerGraph.gds.trainer import Trainer, PrinterCallback

trainer = Trainer(model, training_dataloader, eval_dataloader, callbacks=[PrinterCallback])

DefaultCallback

The DefaultCallback class logs metrics and updates progress bars during the training process. By default, the Trainer’s callbacks parameter is populated with this callback. If you pass your own list of callbacks through that parameter, the default is replaced, so include DefaultCallback in your list if you want to keep its behavior.
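
The replacement behavior described above is ordinary default-argument semantics. The sketch below uses empty stand-in classes and a hypothetical resolve_callbacks function in place of the Trainer constructor, so it runs without pyTigerGraph; it only illustrates why DefaultCallback has to be listed again.

```python
class DefaultCallback:
    """Stand-in for pyTigerGraph.gds.trainer.DefaultCallback."""


class PrinterCallback:
    """Stand-in for pyTigerGraph.gds.trainer.PrinterCallback."""


def resolve_callbacks(callbacks=(DefaultCallback,)):
    """Hypothetical stand-in for the Trainer's `callbacks` parameter:
    a list passed by the caller replaces the default entirely."""
    return list(callbacks)


only_default = resolve_callbacks()
only_printer = resolve_callbacks([PrinterCallback])  # DefaultCallback is gone
both = resolve_callbacks([DefaultCallback, PrinterCallback])
```

With the real classes, the last form corresponds to passing callbacks=[DefaultCallback, PrinterCallback] to the Trainer.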

__init__()

__init__(output_dir = "./logs", use_tqdm = True)

Instantiate the DefaultCallback.

Parameters:

  • output_dir (str, optional): Path to the output directory to log metrics to. Defaults to "./logs".

  • use_tqdm (bool, optional): Whether to use tqdm for progress bars. Defaults to True. Install the tqdm package if the progress bar is desired.

Trainer

Train graph machine learning models that comply with the BaseModel object in pyTigerGraph. Performs training and evaluation loops and automatically collects metrics for the given task.

PyTorch is required to use the Trainer.

__init__()

__init__(model, training_dataloader: BaseLoader, eval_dataloader: BaseLoader, callbacks, metrics = None, target_type = None, loss_fn = None, optimizer = None, optimizer_kwargs)

Instantiate a Trainer.

Create a Trainer object to train graph machine learning models.

Parameters:

  • model (pyTigerGraph.gds.models.base_model.BaseModel): A graph machine learning model that inherits from the BaseModel class.

  • training_dataloader (pyTigerGraph.gds.dataloaders.BaseLoader): A pyTigerGraph dataloader to iterate through training batches.

  • eval_dataloader (pyTigerGraph.gds.dataloaders.BaseLoader): A pyTigerGraph dataloader to iterate through evaluation batches.

  • callbacks (List[pyTigerGraph.gds.trainer.BaseCallback], optional): A list of BaseCallback objects. Defaults to [DefaultCallback]

  • metrics (List[pyTigerGraph.gds.metrics.BaseMetrics] or pyTigerGraph.gds.metrics.BaseMetrics, optional): A single BaseMetrics object or a list of them. If not specified, the metrics corresponding to the built-in model are used.

  • target_type (string or tuple, optional): If using heterogeneous graphs, specify the schema element to compute loss and metrics on. For a vertex type, specify it as a string. For an edge type, use the form ("src_vertex_type", "edge_type", "dest_vertex_type").

  • loss_fn (torch.nn.modules.loss._Loss, optional): A function that computes the loss of the model. If not specified, the default loss function of the model type will be used.

  • optimizer (torch.optim.Optimizer, optional): Specify the optimizer to be used during the training process. Defaults to Adam.

  • optimizer_kwargs (dict, optional): Dictionary of optimizer arguments, such as learning rate. Defaults to optimizer’s default values.
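
As a rough illustration of the two target_type forms and of optimizer_kwargs, the snippet below builds the keyword arguments as plain Python values. The schema names ("Paper", "Author", "WROTE") and the describe_target helper are hypothetical; lr and weight_decay are standard arguments of the Adam optimizer in PyTorch.

```python
# Hypothetical schema names; replace them with types from your own graph.
vertex_target = "Paper"                    # vertex type: a plain string
edge_target = ("Author", "WROTE", "Paper")  # edge type: (source, edge, destination)

trainer_kwargs = {
    "target_type": edge_target,
    # Forwarded to the optimizer; anything omitted keeps the optimizer's default.
    "optimizer_kwargs": {"lr": 0.01, "weight_decay": 5e-4},
}


def describe_target(target_type):
    """Hypothetical helper that reports which form a target_type value uses."""
    if isinstance(target_type, str):
        return f"vertex type {target_type}"
    src, edge, dest = target_type
    return f"edge type {edge} from {src} to {dest}"


print(describe_target(vertex_target))
print(describe_target(trainer_kwargs["target_type"]))

# With pyTigerGraph installed and a model and dataloaders defined, these would
# be passed along as:
# Trainer(model, training_dataloader, eval_dataloader, **trainer_kwargs)
```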

update_train_step_metrics()

update_train_step_metrics(metrics)

Update the metrics for a training step.

Parameter:

  • metrics (dict): Dictionary of calculated metrics.

get_train_step_metrics()

get_train_step_metrics()

Get the metrics for a training step.

Returns:

Dictionary of training metrics results.

reset_train_step_metrics()

reset_train_step_metrics()

Reset training step metrics.

update_eval_metrics()

update_eval_metrics(metrics)

Update the metrics of an evaluation loop.

Parameter:

  • metrics (dict): Dictionary of calculated metrics.

get_eval_metrics()

get_eval_metrics()

Get the metrics for an evaluation loop.

Returns:

Dictionary of evaluation loop metrics results.

reset_eval_metrics()

reset_eval_metrics()

Reset evaluation loop metrics.
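
The metric accessors are most useful from inside a callback. The sketch below is an assumption-laden illustration: FakeTrainer is a stand-in that exposes only the three training-metric methods documented above (its storage behavior is a guess, not the library's), LossHistoryCallback is a hypothetical callback, and the "loss" key is assumed rather than taken from the library.

```python
class FakeTrainer:
    """Stand-in exposing the documented training-metric methods.
    Assumption: update merges the given dict into the stored metrics."""
    def __init__(self):
        self._train_step_metrics = {}

    def update_train_step_metrics(self, metrics):
        self._train_step_metrics.update(metrics)

    def get_train_step_metrics(self):
        return dict(self._train_step_metrics)

    def reset_train_step_metrics(self):
        self._train_step_metrics = {}


class LossHistoryCallback:
    """Hypothetical callback that records the loss after every training step.
    In real use it would inherit from BaseCallback."""
    def __init__(self):
        self.history = []

    def on_train_step_end(self, trainer):
        metrics = trainer.get_train_step_metrics()
        if "loss" in metrics:  # assumed key name
            self.history.append(metrics["loss"])


trainer = FakeTrainer()
callback = LossHistoryCallback()
for loss in (0.9, 0.7, 0.5):
    trainer.update_train_step_metrics({"loss": loss})
    callback.on_train_step_end(trainer)
    trainer.reset_train_step_metrics()
```

The same pattern applies to the evaluation side with get_eval_metrics() in an on_eval_end hook.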

train()

train(num_epochs = None, max_num_steps = None)

Train a model.

Parameters:

  • num_epochs (int, optional): Number of epochs to train for. Defaults to 1 full iteration through the training_dataloader.

  • max_num_steps (int, optional): Number of training steps to perform. If num_epochs is also given, num_epochs takes priority over this parameter. Defaults to the length of the training_dataloader.
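
One way to read the interaction between the two parameters is sketched below. The resolve_max_steps function is hypothetical and only encodes the priority rule and defaults stated above; it assumes that an epoch equals one pass over all batches of the training dataloader, which is an interpretation rather than the library's code.

```python
def resolve_max_steps(num_batches, num_epochs=None, max_num_steps=None):
    """Hypothetical helper: how many training steps a train() call would run.

    num_batches stands for the length of the training dataloader.
    """
    if num_epochs is not None:
        # num_epochs takes priority, even if max_num_steps is also given.
        return num_epochs * num_batches
    if max_num_steps is not None:
        return max_num_steps
    # Neither given: one full pass through the training dataloader.
    return num_batches


default_steps = resolve_max_steps(10)
epoch_steps = resolve_max_steps(10, num_epochs=3)
capped_steps = resolve_max_steps(10, max_num_steps=4)
both_given = resolve_max_steps(10, num_epochs=2, max_num_steps=4)
```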

eval()

eval(loader = None)

Evaluate a model.

Parameter:

  • loader (pyTigerGraph.gds.dataloaders.BaseLoader, optional): A dataloader to iterate through. If not defined, defaults to the eval_dataloader specified in the Trainer instantiation.

predict()

predict(batch)

Predict a batch.

Parameter:

  • batch (any): The batch to make predictions on. It must be a data object that is compatible with the model being trained.

Returns:

A tuple of (model output, evaluation metrics).