Skip to content

Commit 495b8ec

Browse files
authored
Add torch_geometric contrib module (#43)
1 parent ca2bd8a commit 495b8ec

File tree

4 files changed

+118
-0
lines changed

4 files changed

+118
-0
lines changed

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@
230230
intersphinx_mapping = {
231231
'https://docs.python.org/3/': None,
232232
'torch': ('https://pytorch.org/docs/stable', None),
233+
'torch_geometric': ('https://pytorch-geometric.readthedocs.io/en/latest', None),
233234
'numpy': ('https://numpy.org/doc/stable', None),
234235
'optuna': ('https://optuna.readthedocs.io/en/latest', None),
235236
'sklearn': ('https://scikit-learn.org/stable', None),

setup.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ ray =
7878
ray[tune]<2.0.0; python_version < "3.9"
7979
torch =
8080
torch
81+
torch-geometric =
82+
torch
83+
torch-sparse
84+
torch-geometric
8185
optuna =
8286
optuna
8387
numpy =
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
PyTorch Geometric is an extension to PyTorch for geometric learning on graphs,
5+
point clouds, meshes, and other non-standard objects.
6+
The ``class-resolver`` provides several class resolvers and function resolvers
7+
to make it possible to more easily parametrize models and training loops.
8+
""" # noqa:D205,D400
9+
10+
from torch_geometric.nn.aggr import Aggregation, MeanAggregation
11+
from torch_geometric.nn.conv import MessagePassing, SimpleConv
12+
13+
from ..api import ClassResolver
14+
15+
__all__ = [
16+
"message_passing_resolver",
17+
"aggregation_resolver",
18+
]
19+
20+
message_passing_resolver = ClassResolver.from_subclasses(
21+
base=MessagePassing, # type: ignore
22+
suffix="Conv",
23+
default=SimpleConv,
24+
)
25+
"""A resolver for message passing layers.
26+
27+
.. seealso:: https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#convolutional-layers
28+
"""
29+
30+
aggregation_resolver = ClassResolver.from_subclasses(
31+
base=Aggregation,
32+
default=MeanAggregation,
33+
)
34+
35+
"""A resolver for aggregation layers.
36+
37+
This includes the following:
38+
39+
- :class:`torch_geometric.nn.aggr.MeanAggregation`
40+
- :class:`torch_geometric.nn.aggr.MaxAggregation`
41+
- :class:`torch_geometric.nn.aggr.MinAggregation`
42+
- :class:`torch_geometric.nn.aggr.SumAggregation`
43+
- :class:`torch_geometric.nn.aggr.MedianAggregation`
44+
- :class:`torch_geometric.nn.aggr.SoftmaxAggregation` (learnable)
45+
- :class:`torch_geometric.nn.aggr.PowerMeanAggregation` (learnable)
46+
- :class:`torch_geometric.nn.aggr.LSTMAggregation` (learnable)
47+
- :class:`torch_geometric.nn.aggr.MLPAggregation` (learnable)
48+
- :class:`torch_geometric.nn.aggr.SetTransformerAggregation)` (learnable)
49+
- :class:`torch_geometric.nn.aggr.SortAggregation`
50+
51+
Some example usage (based on the torch-geometric docs):
52+
53+
.. code-block::
54+
55+
import torch
56+
57+
from class_resolver.contrib.torch_geometric import aggregation_resolver
58+
59+
mean_aggr = aggregation_resolver.make("mean")
60+
61+
# Feature matrix holding 1000 elements with 64 features each:
62+
x = torch.randn(1000, 64)
63+
64+
# Randomly assign elements to 100 sets:
65+
index = torch.randint(0, 100, (1000, ))
66+
67+
output = mean_aggr(x, index) # Output shape: [100, 64]
68+
69+
.. seealso:: https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#aggregation-operators
70+
"""
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""Tests for the torch-geometric contribution module."""
4+
5+
import unittest
6+
7+
try:
8+
import torch_geometric
9+
except ImportError: # pragma: no cover
10+
torch_geometric = None # pragma: no cover
11+
12+
13+
@unittest.skipUnless(
14+
torch_geometric, "Can not test torch_geometric contrib without ``pip install torch``."
15+
)
16+
class TestTorch(unittest.TestCase):
17+
"""Test for the torch-geometric contribution module."""
18+
19+
def test_message_passing(self):
20+
"""Tests for the message passing resolver."""
21+
from torch_geometric.nn.conv import SimpleConv
22+
23+
from class_resolver.contrib.torch_geometric import message_passing_resolver
24+
25+
self.assertEqual(SimpleConv, message_passing_resolver.lookup("simple"))
26+
self.assertEqual(SimpleConv, message_passing_resolver.lookup(None))
27+
28+
def test_aggregation(self):
29+
"""Test the aggregation resolver."""
30+
import torch
31+
32+
from class_resolver.contrib.torch_geometric import aggregation_resolver
33+
34+
# Feature matrix holding 1000 elements with 64 features each:
35+
x = torch.randn(1000, 64)
36+
37+
# Randomly assign elements to 100 sets:
38+
index = torch.randint(0, 100, (1000,))
39+
40+
for cls in aggregation_resolver:
41+
aggr = cls()
42+
output = aggr(x, index)
43+
self.assertEqual((100, 64), tuple(output.shape))

0 commit comments

Comments
 (0)