diff --git a/python/gigl/distributed/dist_ablp_neighborloader.py b/python/gigl/distributed/dist_ablp_neighborloader.py index 4674b798..189d8393 100644 --- a/python/gigl/distributed/dist_ablp_neighborloader.py +++ b/python/gigl/distributed/dist_ablp_neighborloader.py @@ -21,8 +21,10 @@ from gigl.distributed.distributed_neighborloader import DEFAULT_NUM_CPU_THREADS from gigl.distributed.sampler import ABLPNodeSamplerInput from gigl.distributed.utils.neighborloader import ( + NodeSamplerInput, labeled_to_homogeneous, patch_fanout_for_sampling, + resolve_node_sampler_input_from_user_input, shard_nodes_by_process, strip_label_edges, ) @@ -45,12 +47,7 @@ def __init__( self, dataset: DistLinkPredictionDataset, num_neighbors: Union[list[int], dict[EdgeType, list[int]]], - input_nodes: Optional[ - Union[ - torch.Tensor, - tuple[NodeType, torch.Tensor], - ] - ] = None, + input_nodes: NodeSamplerInput = None, # TODO(kmonte): Support multiple supervision edge types. supervision_edge_type: Optional[EdgeType] = None, num_workers: int = 1, @@ -118,11 +115,9 @@ def __init__( context (DistributedContext): Distributed context information of the current process. local_process_rank (int): The local rank of the current process within a node. local_process_world_size (int): The total number of processes within a node. - input_nodes (Optional[torch.Tensor, tuple[NodeType, torch.Tensor]]): + input_nodes (NodeSamplerInput): Indices of seed nodes to start sampling from. - If set to `None` for homogeneous settings, all nodes will be considered. - In heterogeneous graphs, this flag must be passed in as a tuple that holds - the node type and node indices. (default: `None`) + See documentation for `gigl.distributed.utils.neighborloader.NodeSamplerInput` for more details. num_workers (int): How many workers to use (subprocesses to spwan) for distributed neighbor sampling of the current process. (default: ``1``). 
batch_size (int, optional): how many samples per batch to load @@ -235,50 +230,41 @@ def __init__( f"The dataset must be heterogeneous for ABLP. Recieved dataset with graph of type: {type(dataset.graph)}" ) self._is_input_heterogeneous: bool = False - if isinstance(input_nodes, tuple): - if supervision_edge_type is None: - raise ValueError( - "When using heterogeneous ABLP, you must provide supervision_edge_types." - ) - self._is_input_heterogeneous = True - anchor_node_type, anchor_node_ids = input_nodes - # TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if - # this assumption is no longer valid and/or is too opinionated - assert ( - supervision_edge_type[0] == anchor_node_type - ), f"Label EdgeType are currently expected to be provided in outward edge direction as tuple (`anchor_node_type`,`relation`,`supervision_node_type`), \ - got supervision edge type {supervision_edge_type} with anchor node type {anchor_node_type}" - supervision_node_type = supervision_edge_type[2] - if dataset.edge_dir == "in": - supervision_edge_type = reverse_edge_type(supervision_edge_type) + ( + anchor_node_type, + anchor_node_ids, + self._is_labeled_homogeneous, + ) = resolve_node_sampler_input_from_user_input( + input_nodes=input_nodes, + dataset_nodes=dataset.node_ids, + ) - elif isinstance(input_nodes, torch.Tensor): + if ( + anchor_node_type is None + or anchor_node_type == DEFAULT_HOMOGENEOUS_NODE_TYPE + ): if supervision_edge_type is not None: raise ValueError( f"Expected supervision edge type to be None for homogeneous input nodes, got {supervision_edge_type}" ) - anchor_node_ids = input_nodes - anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE supervision_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE - elif input_nodes is None: - if dataset.node_ids is None: - raise ValueError( - "Dataset must have node ids if input_nodes are not provided." 
- ) - if isinstance(dataset.node_ids, abc.Mapping): + else: + if supervision_edge_type is None: raise ValueError( - f"input_nodes must be provided for heterogeneous datasets, received node_ids of type: {dataset.node_ids.keys()}" + "When using heterogeneous ABLP, you must provide supervision_edge_type." ) - if supervision_edge_type is not None: + self._is_input_heterogeneous = True + # TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if + # this assumption is no longer valid and/or is too opinionated + if supervision_edge_type[0] != anchor_node_type: raise ValueError( - f"Expected supervision edge type to be None for homogeneous input nodes, got {supervision_edge_type}" + f"Label EdgeType are currently expected to be provided in outward edge direction as tuple (`anchor_node_type`,`relation`,`supervision_node_type`), \ + got supervision edge type {supervision_edge_type} with anchor node type {anchor_node_type}" ) - - anchor_node_ids = dataset.node_ids - anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE - supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE - supervision_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE + supervision_node_type = supervision_edge_type[2] + if dataset.edge_dir == "in": + supervision_edge_type = reverse_edge_type(supervision_edge_type) missing_edge_types = set([supervision_edge_type]) - set(dataset.graph.keys()) if missing_edge_types: @@ -542,6 +528,10 @@ def _set_labels( local_node_to_global_node: torch.Tensor # shape [N], where N is the number of nodes in the subgraph, and local_node_to_global_node[i] gives the global node id for local node id `i` if isinstance(data, HeteroData): + if self._supervision_edge_type is None: + raise ValueError( + "When using heterogeneous ABLP, you must provide supervision_edge_type." 
+ ) supervision_node_type = ( self._supervision_edge_type[0] if self.edge_dir == "in" diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index caecd749..5599bf86 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -1,10 +1,11 @@ from collections import Counter, abc -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import torch from graphlearn_torch.channel import SampleMessage from graphlearn_torch.distributed import DistLoader, MpDistSamplingWorkerOptions -from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType +from graphlearn_torch.sampler import NodeSamplerInput as GLTNodeSamplerInput +from graphlearn_torch.sampler import SamplingConfig, SamplingType from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType @@ -14,18 +15,14 @@ from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_link_prediction_dataset import DistLinkPredictionDataset from gigl.distributed.utils.neighborloader import ( + NodeSamplerInput, labeled_to_homogeneous, patch_fanout_for_sampling, + resolve_node_sampler_input_from_user_input, shard_nodes_by_process, strip_label_edges, ) -from gigl.src.common.types.graph_data import ( - NodeType, # TODO (mkolodner-sc): Change to use torch_geometric.typing -) -from gigl.types.graph import ( - DEFAULT_HOMOGENEOUS_EDGE_TYPE, - DEFAULT_HOMOGENEOUS_NODE_TYPE, -) +from gigl.types.graph import DEFAULT_HOMOGENEOUS_EDGE_TYPE logger = Logger() @@ -38,9 +35,7 @@ def __init__( self, dataset: DistLinkPredictionDataset, num_neighbors: Union[List[int], Dict[EdgeType, List[int]]], - input_nodes: Optional[ - Union[torch.Tensor, Tuple[NodeType, torch.Tensor]] - ] = None, + input_nodes: NodeSamplerInput = None, num_workers: int = 1, batch_size: int = 1, context: 
Optional[DistributedContext] = None, # TODO: (svij) Deprecate this @@ -70,12 +65,9 @@ def __init__( context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon)(int): Required if context provided. The total number of processes within a node. - input_nodes (torch.Tensor or Tuple[str, torch.Tensor]): The - indices of seed nodes to start sampling from. - It is of type `torch.LongTensor` for homogeneous graphs. - If set to `None` for homogeneous settings, all nodes will be considered. - In heterogeneous graphs, this flag must be passed in as a tuple that holds - the node type and node indices. (default: `None`) + input_nodes (NodeSamplerInput): + Indices of seed nodes to start sampling from. + See documentation for `gigl.distributed.utils.neighborloader.NodeSamplerInput` for more details. num_workers (int): How many workers to use (subprocesses to spwan) for distributed neighbor sampling of the current process. (default: ``1``). batch_size (int, optional): how many samples per batch to load @@ -222,38 +214,28 @@ def __init__( ) # Determines if the node ids passed in are heterogeneous or homogeneous. - self._is_labeled_heterogeneous = False - if isinstance(input_nodes, torch.Tensor): - node_ids = input_nodes - - # If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting, - # if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE. 
- if isinstance(dataset.node_ids, abc.Mapping): - if ( - len(dataset.node_ids) == 1 - and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids - ): - node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE - self._is_labeled_heterogeneous = True - num_neighbors = patch_fanout_for_sampling( - dataset.get_edge_types(), num_neighbors - ) - else: - raise ValueError( - f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}" - ) - else: - node_type = None - else: - node_type, node_ids = input_nodes - + ( + anchor_node_type, + anchor_node_ids, + self._is_labeled_homogeneous, + ) = resolve_node_sampler_input_from_user_input( + input_nodes=input_nodes, + dataset_nodes=dataset.node_ids, + ) + if self._is_labeled_homogeneous: + # If the dataset is labeled homogeneous, we need to patch the fanout for sampling. + num_neighbors = patch_fanout_for_sampling( + dataset.get_edge_types(), num_neighbors + ) curr_process_nodes = shard_nodes_by_process( - input_nodes=node_ids, + input_nodes=anchor_node_ids, local_process_rank=local_rank, local_process_world_size=local_world_size, ) - input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type) + input_data = GLTNodeSamplerInput( + node=curr_process_nodes, input_type=anchor_node_type + ) # Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize # the memory overhead and CPU contention. 
@@ -343,6 +325,6 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: if isinstance(data, HeteroData): data = strip_label_edges(data) - if self._is_labeled_heterogeneous: + if self._is_labeled_homogeneous: data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data) return data diff --git a/python/gigl/distributed/utils/neighborloader.py b/python/gigl/distributed/utils/neighborloader.py index 66013076..3d98a6a2 100644 --- a/python/gigl/distributed/utils/neighborloader.py +++ b/python/gigl/distributed/utils/neighborloader.py @@ -1,13 +1,15 @@ """Utils for Neighbor loaders.""" +from collections import abc from copy import deepcopy -from typing import Union +from typing import Optional, Union import torch from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType from gigl.common.logger import Logger -from gigl.types.graph import is_label_edge_type +from gigl.src.common.types.graph_data import NodeType +from gigl.types.graph import DEFAULT_HOMOGENEOUS_NODE_TYPE, is_label_edge_type logger = Logger() @@ -118,3 +120,84 @@ def strip_label_edges(data: HeteroData) -> HeteroData: del data.num_sampled_edges[edge_type] return data + + +# Allowed inputs for node samplers. +# If None is provided, then all nodes in the graph will be sampled. +# And the graph must be homogeneous. +# If a single tensor is provided, it is assumed to be a tensor of node IDs. +# And the graph must be homogeneous, or labeled homogeneous. +# If a tuple is provided, the first element is the node type and the second element is the tensor of node IDs. +# If a dict is provided, the keys are node types and the values are tensors of node IDs. +# If a dict is provided, the graph must be heterogeneous, and there must be only one key/value pair in the dict. +# We allow dicts to be passed in as a convenience for users who have a heterogeneous graph with only one supervision edge type.
NodeSamplerInput = Optional[
    Union[
        torch.Tensor, tuple[NodeType, torch.Tensor], abc.Mapping[NodeType, torch.Tensor]
    ]
]


def resolve_node_sampler_input_from_user_input(
    input_nodes: NodeSamplerInput,
    dataset_nodes: Optional[Union[torch.Tensor, abc.Mapping[NodeType, torch.Tensor]]],
) -> tuple[Optional[NodeType], torch.Tensor, bool]:
    """Resolve user-provided seed nodes into a single, consistent format.

    Accepted forms of ``input_nodes``:

    * ``None``: every node in ``dataset_nodes`` is used; ``dataset_nodes`` must be a tensor.
    * ``torch.Tensor``: node ids. ``dataset_nodes`` must either not be a mapping, or be a
      mapping whose only key is ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` ("labeled homogeneous").
    * ``(node_type, node_ids)`` tuple: used as given.
    * Mapping with exactly one ``node_type -> node_ids`` entry: that entry is used.

    Args:
        input_nodes (NodeSamplerInput): The input nodes provided by the user.
        dataset_nodes (Optional[Union[torch.Tensor, Mapping[NodeType, torch.Tensor]]]):
            The nodes in the dataset.

    Returns:
        tuple[Optional[NodeType], torch.Tensor, bool]: A tuple containing:
            - node_type: The type of the nodes, or ``None`` when no node type applies
              (tensor or ``None`` input on a dataset whose nodes are not a mapping).
            - node_ids: The tensor of node IDs.
            - is_labeled_homogeneous: ``True`` when the resolved node type is
              ``DEFAULT_HOMOGENEOUS_NODE_TYPE`` and the input was a tensor or a mapping.
              Tuple input always yields ``False``.

    Raises:
        ValueError: If ``input_nodes`` is ``None`` and ``dataset_nodes`` is ``None`` or a
            mapping; if ``input_nodes`` is a tensor and ``dataset_nodes`` is a mapping that
            is not labeled homogeneous; or if ``input_nodes`` is a mapping that does not
            contain exactly one entry.
        TypeError: If ``input_nodes`` (or, when ``input_nodes`` is ``None``,
            ``dataset_nodes``) is of an unsupported type.
    """
    if input_nodes is None:
        # No seeds given: fall back to all nodes, which is only well-defined when the
        # dataset exposes a single tensor of node ids.
        if dataset_nodes is None:
            raise ValueError("If input_nodes is None, the dataset must have node ids.")
        if isinstance(dataset_nodes, abc.Mapping):
            # Report only the node types; formatting the tensors themselves can produce
            # an enormous error message.
            raise ValueError(
                f"Input nodes must be provided for a heterogeneous graph. Received node types: {list(dataset_nodes.keys())}"
            )
        if isinstance(dataset_nodes, torch.Tensor):
            return None, dataset_nodes, False
        raise TypeError(
            f"Expected dataset nodes to be a torch.Tensor or a mapping of node type to torch.Tensor, got {type(dataset_nodes)}"
        )

    if isinstance(input_nodes, torch.Tensor):
        if not isinstance(dataset_nodes, abc.Mapping):
            return None, input_nodes, False
        # If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting,
        # if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE.
        if len(dataset_nodes) == 1 and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset_nodes:
            return DEFAULT_HOMOGENEOUS_NODE_TYPE, input_nodes, True
        raise ValueError(
            f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset_nodes.keys()}"
        )

    if isinstance(input_nodes, abc.Mapping):
        if len(input_nodes) != 1:
            raise ValueError(
                f"If input_nodes is provided as a mapping, it must contain exactly one key/value pair. Received node types: {list(input_nodes.keys())}. This may happen if you call Loader(input_nodes=dataset.node_ids) with a heterogeneous dataset."
            )
        node_type, node_ids = next(iter(input_nodes.items()))
        return node_type, node_ids, node_type == DEFAULT_HOMOGENEOUS_NODE_TYPE

    if isinstance(input_nodes, tuple):
        node_type, node_ids = input_nodes
        return node_type, node_ids, False

    # Previously an unsupported type fell through every branch and surfaced as an
    # UnboundLocalError at the return statement; fail with an actionable message instead.
    raise TypeError(
        f"input_nodes must be None, a torch.Tensor, a (node_type, node_ids) tuple, or a mapping with a single node type, got {type(input_nodes)}"
    )
labeled_to_homogeneous, patch_fanout_for_sampling, + resolve_node_sampler_input_from_user_input, shard_nodes_by_process, strip_label_edges, ) -from gigl.types.graph import message_passing_to_positive_label +from gigl.src.common.types.graph_data import NodeType +from gigl.types.graph import ( + DEFAULT_HOMOGENEOUS_NODE_TYPE, + message_passing_to_positive_label, +) from tests.test_assets.distributed.utils import assert_tensor_equality _U2I_EDGE_TYPE = ("user", "to", "item") @@ -145,3 +150,114 @@ def test_strip_label_edges(self): self.assertFalse(_LABELED_EDGE_TYPE in stripped_data.num_sampled_edges) self.assertTrue(_U2I_EDGE_TYPE in stripped_data.num_sampled_edges) self.assertTrue(_I2U_EDGE_TYPE in stripped_data.num_sampled_edges) + + @parameterized.expand( + [ + param( + "homogeneous_tensor_input", + input_nodes=torch.tensor([1, 2, 3]), + dataset_nodes=torch.tensor([1, 2, 3, 4]), + expected_node_type=None, + expected_node_ids=torch.tensor([1, 2, 3]), + expected_is_labeled_homogeneous=False, + ), + param( + "labeled_homogeneous_tensor_input", + input_nodes=torch.tensor([1, 2, 3]), + dataset_nodes={ + DEFAULT_HOMOGENEOUS_NODE_TYPE: torch.tensor([1, 2, 3, 4]) + }, + expected_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE, + expected_node_ids=torch.tensor([1, 2, 3]), + expected_is_labeled_homogeneous=True, + ), + param( + "heterogeneous_mapping_input", + input_nodes={"user": torch.tensor([1, 2])}, + dataset_nodes={ + "user": torch.tensor([1, 2, 3]), + "item": torch.tensor([4, 5]), + }, + expected_node_type=NodeType("user"), + expected_node_ids=torch.tensor([1, 2]), + expected_is_labeled_homogeneous=False, + ), + param( + "tuple_input", + input_nodes=("user", torch.tensor([1, 2])), + dataset_nodes=None, + expected_node_type=NodeType("user"), + expected_node_ids=torch.tensor([1, 2]), + expected_is_labeled_homogeneous=False, + ), + param( + "none_input_homogeneous_dataset", + input_nodes=None, + dataset_nodes=torch.tensor([1, 2, 3]), + expected_node_type=None, + 
expected_node_ids=torch.tensor([1, 2, 3]), + expected_is_labeled_homogeneous=False, + ), + ] + ) + def test_resolve_node_sampler_input_valid( + self, + _, + input_nodes, + dataset_nodes, + expected_node_type, + expected_node_ids, + expected_is_labeled_homogeneous, + ): + ( + node_type, + node_ids, + is_labeled_homogeneous, + ) = resolve_node_sampler_input_from_user_input(input_nodes, dataset_nodes) + self.assertEqual(node_type, expected_node_type) + assert_tensor_equality(node_ids, expected_node_ids) + self.assertEqual(is_labeled_homogeneous, expected_is_labeled_homogeneous) + + @parameterized.expand( + [ + param( + "heterogeneous_tensor_input_raises", + input_nodes=torch.tensor([1, 2, 3]), + dataset_nodes={ + "user": torch.tensor([1, 2]), + "item": torch.tensor([3, 4]), + }, + expected_exception=ValueError, + ), + param( + "mapping_with_multiple_keys_raises", + input_nodes={"user": torch.tensor([1]), "item": torch.tensor([2])}, + dataset_nodes={ + "user": torch.tensor([1, 2]), + "item": torch.tensor([3, 4]), + }, + expected_exception=ValueError, + ), + param( + "none_input_heterogeneous_dataset_raises", + input_nodes=None, + dataset_nodes={"user": torch.tensor([1, 2])}, + expected_exception=ValueError, + ), + param( + "none_input_no_dataset_raises", + input_nodes=None, + dataset_nodes=None, + expected_exception=ValueError, + ), + ] + ) + def test_resolve_node_sampler_input_invalid( + self, _, input_nodes, dataset_nodes, expected_exception + ): + with self.assertRaises(expected_exception): + resolve_node_sampler_input_from_user_input(input_nodes, dataset_nodes) + + +if __name__ == "__main__": + unittest.main()