Data Transforms

There are currently two sub-modules: nodepiece_transforms and pyg_transforms. The NodePiece transforms help convert batches produced by NodePiece data loaders into a format used by PyTorch models. The PyG transforms help process batches of PyTorch Geometric objects, such as those produced by loaders like the NeighborLoader and HGTLoader.

All transforms are callable: instantiate the transform, then pass a batch to the resulting object. For example:

from pyTigerGraph.gds.transforms.pyg_transforms import TemporalPyGTransform

# Instantiate the connection (conn) here, and define the attribute
# lists and feature_transforms used below...
transform = TemporalPyGTransform(
    vertex_start_attrs=vertex_start_attrs,
    vertex_end_attrs=vertex_end_attrs,
    edge_start_attrs=edge_start_attrs,
    edge_end_attrs=edge_end_attrs,
    start_dt=0,
    end_dt=6,
    feature_transforms=feature_transforms,
)

# Pass the transform as the loader's callback so it is applied to every batch.
train_loader = conn.gds.neighborLoader(
    # other loader parameters here...
    callback_fn=lambda x: transform(x),
)
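
The call pattern above can be illustrated without a TigerGraph connection. The following is a minimal sketch, not the library's implementation: ScaleFeatureTransform and iterate_batches are hypothetical stand-ins for a transform class and a data loader, and the dictionary batch format is an assumption made purely for illustration. It shows only the general idea of configuring an object once and then applying it to each batch through a callback.

```python
from typing import Callable, Dict, Iterable, Iterator, List, Optional

# Assumed batch format for this sketch only: a dict of feature name -> values.
Batch = Dict[str, List[float]]


class ScaleFeatureTransform:
    """Hypothetical transform: configured at construction, applied by calling it."""

    def __init__(self, key: str, factor: float) -> None:
        self.key = key
        self.factor = factor

    def __call__(self, batch: Batch) -> Batch:
        # Return a new batch so the caller's input is left unmodified.
        out = dict(batch)
        out[self.key] = [value * self.factor for value in batch[self.key]]
        return out


def iterate_batches(
    batches: Iterable[Batch],
    callback_fn: Optional[Callable[[Batch], Batch]] = None,
) -> Iterator[Batch]:
    """Hypothetical stand-in for a loader that applies callback_fn to each batch."""
    for batch in batches:
        yield callback_fn(batch) if callback_fn is not None else batch


transform = ScaleFeatureTransform(key="x", factor=2.0)
raw_batches = [{"x": [1.0, 2.0]}, {"x": [3.0]}]

# The instantiated transform is itself the callback.
processed = list(iterate_batches(raw_batches, callback_fn=transform))
```

Because the object defines __call__, it can be handed to the loader directly; wrapping it in a lambda, as in the example above, behaves the same way.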