Model Trainer and Callbacks
Train Graph Machine Learning models (such as GraphSAGE and NodePiece) in a concise way. pyTigerGraph offers built-in models that can be used with the Trainer, consuming pyTigerGraph dataloaders.
Callbacks are classes that perform arbitrary operations at various stages of the
training process. Inherit from the BaseCallback
class to create compatible operations.
BaseCallback
The BaseCallback
class is the abstract class that all other trainer
callbacks inherit from. It defines a series of hook methods, each executed
at a specific point in the trainer’s execution, such as the beginning
or end of an epoch. Inherit from this class to implement a custom callback.
on_init_end()
on_init_end(trainer)
Run operations after the initialization of the trainer.
Parameter:
-
trainer (pyTigerGraph Trainer)
: Takes in the trainer in order to perform operations on it.
on_epoch_start()
on_epoch_start(trainer)
Run operations at the start of a training epoch.
Parameter:
-
trainer (pyTigerGraph Trainer)
: Takes in the trainer in order to perform operations on it.
on_train_step_start()
on_train_step_start(trainer)
Run operations at the start of a training step.
Parameter:
-
trainer (pyTigerGraph Trainer)
: Takes in the trainer in order to perform operations on it.
on_train_step_end()
on_train_step_end(trainer)
Run operations at the end of a training step.
Parameter:
-
trainer (pyTigerGraph Trainer)
: Takes in the trainer in order to perform operations on it.
on_epoch_end()
on_epoch_end(trainer)
Run operations at the end of an epoch.
Parameter:
-
trainer (pyTigerGraph Trainer)
: Takes in the trainer in order to perform operations on it.
on_eval_start()
on_eval_start(trainer)
Run operations at the start of the evaluation process.
Parameter:
-
trainer (pyTigerGraph Trainer)
: Takes in the trainer in order to perform operations on it.
on_eval_step_start()
on_eval_step_start(trainer)
Run operations at the start of an evaluation batch.
Parameter:
-
trainer (pyTigerGraph Trainer)
: Takes in the trainer in order to perform operations on it.
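As an illustration of the hooks listed above, the following minimal sketch defines a custom callback that counts epochs and training steps. To keep it runnable without pyTigerGraph installed, `BaseCallback` here is a local stand-in that only mirrors the hook names documented above, and `drive_hooks` is a hypothetical helper that imitates a trainer calling them. The order in which the real Trainer invokes the hooks is an assumption. In real use, inherit from `pyTigerGraph.gds.trainer.BaseCallback` instead.

```python
class BaseCallback:
    """Local stand-in mirroring the documented hook names (not the real class)."""

    def on_init_end(self, trainer): pass
    def on_epoch_start(self, trainer): pass
    def on_train_step_start(self, trainer): pass
    def on_train_step_end(self, trainer): pass
    def on_epoch_end(self, trainer): pass
    def on_eval_start(self, trainer): pass
    def on_eval_step_start(self, trainer): pass


class StepCounterCallback(BaseCallback):
    """Hypothetical custom callback: counts epochs and training steps."""

    def __init__(self):
        self.epochs = 0
        self.steps = 0

    def on_epoch_start(self, trainer):
        self.epochs += 1

    def on_train_step_end(self, trainer):
        self.steps += 1


def drive_hooks(callback, num_epochs, steps_per_epoch, trainer=None):
    """Hypothetical driver imitating a training loop; the hook order is assumed."""
    callback.on_init_end(trainer)
    for _ in range(num_epochs):
        callback.on_epoch_start(trainer)
        for _ in range(steps_per_epoch):
            callback.on_train_step_start(trainer)
            callback.on_train_step_end(trainer)
        callback.on_epoch_end(trainer)


counter = StepCounterCallback()
drive_hooks(counter, num_epochs=2, steps_per_epoch=3)
print(counter.epochs, counter.steps)
```

Only the hooks a callback cares about need to be overridden; the rest fall through to the no-op defaults of the base class.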
PrinterCallback
To use, import the class and pass it to the Trainer’s callback argument.
from pyTigerGraph.gds.trainer import Trainer, PrinterCallback
trainer = Trainer(model, training_dataloader, eval_dataloader, callbacks=[PrinterCallback])
DefaultCallback
The DefaultCallback
class logs metrics and updates progress bars during the training process.
By default, the Trainer’s callbacks
parameter is populated with this callback.
If you pass your own list of callbacks through that parameter, include DefaultCallback
in the list to keep the metric logging and progress bars.
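One way to avoid dropping the default behavior by accident is to build the callback list through a small helper. The sketch below is hypothetical and not part of pyTigerGraph: `with_default` is an illustrative name, and `DefaultCallback` and `MyCallback` are local stand-ins so the example runs on its own. It passes classes rather than instances, following the `callbacks=[PrinterCallback]` usage shown earlier.

```python
class DefaultCallback:
    """Local stand-in for the documented DefaultCallback (not the real class)."""


class MyCallback:
    """Hypothetical custom callback."""


def with_default(callbacks, default_cls=DefaultCallback):
    """Return a new list that starts with default_cls unless it is already present.

    An entry counts as present if it is the class itself or an instance of it.
    """
    present = any(
        cb is default_cls or isinstance(cb, default_cls) for cb in callbacks
    )
    return list(callbacks) if present else [default_cls, *callbacks]


print(with_default([MyCallback]))
```

The resulting list would then be passed as the Trainer’s callbacks argument.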
__init__()
__init__(output_dir = "./logs", use_tqdm = True)
Instantiate the Default Callback.
Parameters:
-
output_dir (str, optional)
: Path to output directory to log metrics to. Defaults to ./logs
-
use_tqdm (bool, optional)
: Whether to use tqdm for progress bars. Defaults to True. Install the tqdm
package if the progress bar is desired.
Trainer
Train graph machine learning models that comply with the BaseModel
object in pyTigerGraph.
Performs training and evaluation loops and automatically collects metrics for the given task.
PyTorch is required to use the Trainer.
__init__()
__init__(model, training_dataloader: BaseLoader, eval_dataloader: BaseLoader, callbacks, metrics = None, target_type = None, loss_fn = None, optimizer = None, optimizer_kwargs)
Instantiate a Trainer.
Create a Trainer object to train graph machine learning models.
Parameters:
-
model (pyTigerGraph.gds.models.base_model.BaseModel)
: A graph machine learning model that inherits from the BaseModel class.
-
training_dataloader (pyTigerGraph.gds.dataloaders.BaseLoader)
: A pyTigerGraph dataloader to iterate through training batches.
-
eval_dataloader (pyTigerGraph.gds.dataloaders.BaseLoader)
: A pyTigerGraph dataloader to iterate through evaluation batches.
-
callbacks (List[pyTigerGraph.gds.trainer.BaseCallback], optional)
: A list of BaseCallback objects. Defaults to [DefaultCallback].
-
metrics (List[pyTigerGraph.gds.metrics.BaseMetrics] or pyTigerGraph.gds.metrics.BaseMetrics, optional)
: A list of BaseMetrics objects or a single BaseMetrics object. If not specified, the metrics corresponding to the built-in model are used.
-
target_type (string or tuple, optional)
: If using heterogeneous graphs, specify the schema element to compute loss and metrics on. For a vertex type, specify it with a string. For an edge type, use the form ("src_vertex_type", "edge_type", "dest_vertex_type").
-
loss_fn (torch.nn._Loss, optional)
: A function that computes the loss of the model. If not specified, the default loss function of the model type will be used.
-
optimizer (torch.optim.Optimizer, optional)
: The optimizer to use during the training process. Defaults to Adam.
-
optimizer_kwargs (dict, optional)
: Dictionary of optimizer arguments, such as learning rate. Defaults to the optimizer’s default values.
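The two accepted shapes of target_type can be illustrated with a small check. This is a minimal sketch: `classify_target_type` is a hypothetical helper and not part of pyTigerGraph, and the schema names "Paper" and "Cites" are made up for the example.

```python
def classify_target_type(target_type):
    """Classify a target_type value as 'vertex', 'edge', or None (not set).

    Hypothetical helper following the convention described above:
    a string names a vertex type, a 3-tuple of strings names an edge type.
    """
    if target_type is None:
        return None
    if isinstance(target_type, str):
        return "vertex"
    if (
        isinstance(target_type, tuple)
        and len(target_type) == 3
        and all(isinstance(part, str) for part in target_type)
    ):
        return "edge"
    raise ValueError(
        "target_type must be a string or a "
        "(src_vertex_type, edge_type, dest_vertex_type) tuple"
    )


# Made-up schema names, for illustration only.
print(classify_target_type("Paper"))
print(classify_target_type(("Paper", "Cites", "Paper")))
```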
update_train_step_metrics()
update_train_step_metrics(metrics)
Update the metrics for a training step.
Parameter:
-
metrics (dict)
: Dictionary of calculated metrics.
get_train_step_metrics()
get_train_step_metrics()
Get the metrics for a training step.
Returns:
Dictionary of training metrics results.
update_eval_metrics()
update_eval_metrics(metrics)
Update the metrics of an evaluation loop.
Parameter:
-
metrics (dict)
: Dictionary of calculated metrics.
get_eval_metrics()
get_eval_metrics()
Get the metrics for an evaluation loop.
Returns:
Dictionary of evaluation loop metrics results.
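A callback can read these metrics through the trainer it receives. The following minimal sketch records the loss after each training step. `StandInTrainer` is a hypothetical stand-in that exposes only the two train-step methods documented above so the example runs without pyTigerGraph; that updates are merged into a stored dictionary and that a "loss" key is present are both assumptions.

```python
class StandInTrainer:
    """Hypothetical stand-in exposing the documented train-step metric methods."""

    def __init__(self):
        self._train_step_metrics = {}

    def update_train_step_metrics(self, metrics):
        # Assumption: new values are merged into the stored dictionary.
        self._train_step_metrics.update(metrics)

    def get_train_step_metrics(self):
        # Return a copy so callers cannot mutate the stored metrics.
        return dict(self._train_step_metrics)


class LossHistoryCallback:
    """Hypothetical callback that records the loss after every training step."""

    def __init__(self):
        self.history = []

    def on_train_step_end(self, trainer):
        metrics = trainer.get_train_step_metrics()
        if "loss" in metrics:  # the "loss" key is an assumption
            self.history.append(metrics["loss"])


trainer = StandInTrainer()
recorder = LossHistoryCallback()
for loss in (0.9, 0.7, 0.6):
    trainer.update_train_step_metrics({"loss": loss})
    recorder.on_train_step_end(trainer)
print(recorder.history)
```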
train()
train(num_epochs = None, max_num_steps = None)
Train a model.
Parameters:
-
num_epochs (int, optional)
: Number of epochs to train for. Defaults to 1 full iteration through the training_dataloader.
-
max_num_steps (int, optional)
: Number of training steps to perform. num_epochs takes priority over this parameter. Defaults to the length of the training_dataloader.
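The interaction between the two parameters can be sketched as a small function. `resolve_max_steps` is a hypothetical name and not the library’s implementation; it assumes that one epoch equals one full pass over the dataloader, so `num_batches` stands for the length of the training_dataloader.

```python
def resolve_max_steps(num_batches, num_epochs=None, max_num_steps=None):
    """Return the total number of training steps implied by the arguments.

    Hypothetical sketch of the documented precedence:
    num_epochs wins over max_num_steps, and with neither given
    training runs for one full pass over the dataloader.
    """
    if num_epochs is not None:
        return num_epochs * num_batches
    if max_num_steps is not None:
        return max_num_steps
    return num_batches


print(resolve_max_steps(10))                                  # neither given
print(resolve_max_steps(10, max_num_steps=25))                # step budget only
print(resolve_max_steps(10, num_epochs=3, max_num_steps=25))  # num_epochs wins
```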