├── .gitignore ├── LICENSE ├── README.md ├── banner.png ├── bin.py ├── grafog ├── __init__.py └── transforms │ ├── .DS_Store │ ├── __init__.py │ └── transforms.py ├── models.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | MS5.md 9 | *.pdf 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Rishabh Anand 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # grafog 4 | Graph Data Augmentation Library for PyTorch Geometric. 5 | 6 | --- 7 | 8 | ## What is it? 9 | Data augmentations are heavily used in Computer Vision and Natural Language Processing to address data imbalance, data scarcity, and prevent models from overfitting. They have also proven to yield good results in both supervised and self-supervised (contrastive) settings. 10 | 11 | `grafog` (portmanteau of "graph" and "augmentation") provides a set of methods to perform data augmentation on graph-structured data, especially meant for self-supervised node classification. It is built on top of `torch_geometric` and is easily integrable with its [`Data`](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data) API. 12 | 13 | > Yannic Kilcher talks about it here: [https://youtu.be/smUHQndcmOY?t=961](https://youtu.be/smUHQndcmOY?t=961) 14 | 15 | --- 16 | 17 | ## Installation 18 | You can install the library via `pip`: 19 | 20 | ``` 21 | $ pip install grafog 22 | ``` 23 | 24 | You can also install the library from source: 25 | 26 | ``` 27 | $ git clone https://github.com/rish-16/grafog 28 | $ cd grafog 29 | $ pip install -e . 
30 | ```
31 | 
32 | #### Dependencies
33 | ```
34 | torch==1.10.2
35 | torch_geometric==2.0.3
36 | ```
37 | ---
38 | 
39 | ## Usage
40 | The library comes with the following data augmentations:
41 | 
42 | | Augmentation                 | Remarks                                                          | When to use             |
43 | |------------------------------|------------------------------------------------------------------|-------------------------|
44 | | `NodeDrop(p=0.05)`           | Randomly drops each node with probability `p`                    | before, during training |
45 | | `EdgeDrop(p=0.05)`           | Randomly drops each edge with probability `p`                    | before, during training |
46 | | `Normalize()`                | L2-normalizes the node features                                  | before training         |
47 | | `NodeMixUp(lamb, classes)`   | MixUp on node features and labels with mixing weight `lamb`      | during training         |
48 | | `NodeFeatureMasking(p=0.15)` | Randomly masks each node feature dimension with probability `p`  | during training         |
49 | | `EdgeFeatureMasking(p=0.15)` | Randomly masks each edge feature dimension with probability `p`  | during training         |
50 | 
51 | > There are many more features to be added over time, so stay tuned!
52 | 
53 | ```python
54 | from torch_geometric.datasets import Planetoid
55 | import grafog.transforms as T
56 | 
57 | node_aug = T.Compose([
58 |     T.NodeDrop(p=0.45),
59 |     T.NodeMixUp(lamb=0.5, classes=7),
60 |     ...
61 | ])
62 | 
63 | edge_aug = T.Compose([
64 |     T.EdgeDrop(p=0.15),
65 |     T.EdgeFeatureMasking()   # needs a graph with edge features (edge_attr)
66 | ])
67 | 
68 | data = Planetoid(root="data/", name="Cora")[0]  # Cora has 7 classes and the train/test masks the transforms expect; root path is illustrative
69 | model = ...
70 | 
71 | for epoch in range(10):            # begin training loop
72 |     new_data = node_aug(data)      # apply the node augmentation(s)
73 |     new_data = edge_aug(new_data)  # apply the edge augmentation(s)
74 | 
75 |     x, y = new_data.x, new_data.y
76 |     ...
77 | ```
78 | 
79 | ---
80 | 
81 | ## Remarks
82 | This library was built as a project for a class ([UIT2201](https://nusmods.com/modules/UIT2201/computer-science-the-i-t-revolution)) at NUS. I planned and built it over the span of 10 weeks. I thank _Prof. Mikhail Filippov_ for his guidance, feedback, and support!
83 | 
84 | If you spot any issues, feel free to open an Issue or raise a PR. All meaningful contributions are welcome!
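85 | 
86 | If you want to contribute a new augmentation, note that every transform in `grafog` is a plain `torch.nn.Module` whose `forward` takes a `torch_geometric.data.Data` object and returns a new one, so custom transforms compose freely with the built-in ones via `T.Compose`. The sketch below is purely illustrative; the `AddSelfLoops` transform is not part of the library:
87 | 
88 | ```python
89 | import torch.nn as nn
90 | import torch_geometric as tg
91 | from torch_geometric.utils import add_self_loops
92 | 
93 | import grafog.transforms as T
94 | 
95 | class AddSelfLoops(nn.Module):
96 |     """Illustrative custom transform (not part of grafog): adds a self-loop to every node."""
97 |     def forward(self, data):
98 |         edge_idx, _ = add_self_loops(data.edge_index, num_nodes=data.x.size(0))
99 |         return tg.data.Data(x=data.x, y=data.y, edge_index=edge_idx,
100 |                             train_mask=data.train_mask, test_mask=data.test_mask)
101 | 
102 | aug = T.Compose([
103 |     AddSelfLoops(),
104 |     T.NodeFeatureMasking(p=0.15),
105 | ])
106 | ```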
107 | 
108 | ---
109 | 
110 | ## License
111 | [MIT](https://github.com/rish-16/grafog/blob/main/LICENSE)
112 | 
--------------------------------------------------------------------------------
/banner.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/rish-16/grafog/95e78b3694b7d0fc92637b7d685425cf95ce8449/banner.png
--------------------------------------------------------------------------------
/bin.py:
--------------------------------------------------------------------------------
1 | # scratch script: builds Zachary's Karate Club as an InMemoryDataset and a small plotting helper
2 | import numpy as np
3 | import networkx as nx
4 | import community as C  # python-louvain
5 | import matplotlib.pyplot as plt
6 | 
7 | import torch
8 | from torch_geometric.data import Data, InMemoryDataset
9 | 
10 | import grafog.transforms as T
11 | 
12 | class KarateClub(InMemoryDataset):
13 |     def __init__(self, transform=None):
14 |         super(KarateClub, self).__init__('.', transform, None, None)
15 | 
16 |         G = nx.karate_club_graph()
17 | 
18 |         # one-hot node features
19 |         x = torch.eye(G.number_of_nodes(), dtype=torch.float)
20 | 
21 |         # adjacency in COO form -> edge_index (nx.to_scipy_sparse_matrix was removed in networkx 3.0)
22 |         adj = nx.to_scipy_sparse_matrix(G).tocoo()
23 |         row = torch.from_numpy(adj.row.astype(np.int64)).to(torch.long)
24 |         col = torch.from_numpy(adj.col.astype(np.int64)).to(torch.long)
25 |         edge_index = torch.stack([row, col], dim=0)
26 | 
27 |         # Louvain communities as node labels
28 |         partition = C.best_partition(G)
29 |         y = torch.tensor([partition[i] for i in range(G.number_of_nodes())])
30 | 
31 |         # train on one node per community
32 |         train_mask = torch.zeros(y.size(0), dtype=torch.bool)
33 |         for i in range(int(y.max()) + 1):
34 |             train_mask[(y == i).nonzero(as_tuple=False)[0]] = True
35 | 
36 |         data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask)
37 | 
38 |         self.data, self.slices = self.collate([data])
39 | 
40 | def visualize(h, color, epoch=None, loss=None):
41 |     plt.figure(figsize=(7, 7))
42 |     plt.xticks([])
43 |     plt.yticks([])
44 | 
45 |     if torch.is_tensor(h):
46 |         # h is a 2D embedding tensor
47 |         h = h.detach().cpu().numpy()
48 |         plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
49 |         if epoch is not None and loss is not None:
50 |             plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
51 |     else:
52 |         # h is a networkx graph
53 |         nx.draw_networkx(h, pos=nx.spring_layout(h, seed=42), with_labels=False,
54 |                          node_color=color, cmap="Set2")
55 |     plt.show()
56 | 
57 | node_transforms = T.Compose([
58 |     T.Normalize()       # grafog's Normalize takes no arguments
59 | ])
60 | 
61 | edge_transforms = T.Compose([
62 |     T.EdgeDrop(p=0.05)
63 | ])
64 | 
65 | dataset = KarateClub()
66 | data = dataset.data
67 | x = data.x
68 | edge_idx = data.edge_index
69 | y = data.y
--------------------------------------------------------------------------------
/grafog/__init__.py:
--------------------------------------------------------------------------------
1 | from grafog.transforms.transforms import *
--------------------------------------------------------------------------------
/grafog/transforms/.DS_Store:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/rish-16/grafog/95e78b3694b7d0fc92637b7d685425cf95ce8449/grafog/transforms/.DS_Store
--------------------------------------------------------------------------------
/grafog/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from grafog.transforms.transforms import *
--------------------------------------------------------------------------------
/grafog/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch_geometric as tg
5 | 
6 | class Compose(nn.Module):
7 |     def __init__(self, transforms):
8 |         super().__init__()
9 |         self.transforms = transforms
10 | 
11 |     def forward(self, data):
12 |         for aug in self.transforms:
13 |             data = aug(data)
14 |         return
data 15 | 16 | class NodeDrop(nn.Module): 17 | def __init__(self, p=0.05): 18 | super().__init__() 19 | self.p = p 20 | 21 | def forward(self, data): 22 | x = data.x 23 | y = data.y 24 | train_mask = data.train_mask 25 | test_mask = data.test_mask 26 | edge_idx = data.edge_index 27 | 28 | idx = torch.empty(x.size(0)).uniform_(0, 1) 29 | train_mask[torch.where(idx < self.p)] = 0 30 | test_mask[torch.where(idx < self.p)] = 0 31 | new_data = tg.data.Data(x=x, edge_index=edge_idx, y=y, train_mask=train_mask, test_mask=test_mask) 32 | 33 | return new_data 34 | 35 | class EdgeDrop(nn.Module): 36 | def __init__(self, p=0.05): 37 | super().__init__() 38 | self.p = p 39 | 40 | def forward(self, data): 41 | x = data.x 42 | y = data.y 43 | train_mask = data.train_mask 44 | test_mask = data.test_mask 45 | edge_idx = data.edge_index 46 | 47 | edge_idx = edge_idx.permute(1, 0) 48 | idx = torch.empty(edge_idx.size(0)).uniform_(0, 1) 49 | edge_idx = edge_idx[torch.where(idx >= self.p)].permute(1, 0) 50 | new_data = tg.data.Data(x=x, y=y, edge_index=edge_idx, train_mask=train_mask, test_mask=test_mask) 51 | return new_data 52 | 53 | class Normalize(nn.Module): 54 | def __init__(self): 55 | super().__init__() 56 | 57 | def forward(self, data): 58 | x = data.x 59 | y = data.y 60 | train_mask = data.train_mask 61 | test_mask = data.test_mask 62 | edge_idx = data.edge_index 63 | 64 | x = F.normalize(x) 65 | new_data = tg.data.Data(x=x, y=y, edge_index=edge_idx, train_mask=train_mask, test_mask=test_mask) 66 | return new_data 67 | 68 | class NodeMixUp(nn.Module): 69 | def __init__(self, lamb, classes): 70 | super().__init__() 71 | self.lamb = lamb 72 | self.classes = classes 73 | 74 | def forward(self, data): 75 | x = data.x 76 | y = data.y 77 | train_mask = data.train_mask 78 | test_mask = data.test_mask 79 | edge_idx = data.edge_index 80 | 81 | n, d = x.shape 82 | 83 | pair_idx = torch.randperm(n) 84 | x_b = x[pair_idx] 85 | y_b = y[pair_idx] 86 | y_a_oh = F.one_hot(y, self.classes) 87 | y_b_oh = F.one_hot(y_b, self.classes) 88 | 89 | x_mix = (self.lamb * x) + (1-self.lamb)*x_b 90 | y_mix = (self.lamb * y_a_oh) + (1 - self.lamb) * y_b_oh 91 | new_y = y_mix.argmax(1) 92 | 93 | # new_x = torch.vstack([x, x_mix]) 94 | # new_y = torch.vstack([y_a_oh, y_mix]) 95 | 96 | new_data = tg.data.Data(x=x_mix, y=new_y, edge_index=edge_idx, train_mask=train_mask, test_mask=test_mask) 97 | return new_data 98 | 99 | class NodeFeatureMasking(nn.Module): 100 | def __init__(self, p=0.15): 101 | super().__init__() 102 | self.p = p 103 | 104 | def forward(self, data): 105 | x = data.x 106 | y = data.y 107 | edge_attr = data.edge_attr 108 | train_mask = data.train_mask 109 | test_mask = data.test_mask 110 | edge_idx = data.edge_index 111 | 112 | n, d = x.shape 113 | 114 | idx = torch.empty((d,), dtype=torch.float32).uniform_(0, 1) < self.p 115 | x = x.clone() 116 | x[:, idx] = 0 117 | 118 | new_data = tg.data.Data(x=x, y=y, edge_index=edge_idx, train_mask=train_mask, test_mask=test_mask, edge_attr=edge_attr) 119 | return new_data 120 | 121 | class EdgeFeatureMasking(nn.Module): 122 | def __init__(self, p=0.15): 123 | super().__init__() 124 | self.p = p 125 | 126 | def forward(self, data): 127 | x = data.x 128 | y = data.y 129 | edge_attr = data.edge_attr 130 | train_mask = data.train_mask 131 | test_mask = data.test_mask 132 | edge_idx = data.edge_index 133 | 134 | n, d = edge_attr.shape 135 | 136 | idx = torch.empty((d,), dtype=torch.float32).uniform_(0, 1) < self.p 137 | edge_attr = edge_attr.clone() 138 | edge_attr[:, idx] = 0 
139 | 140 | new_data = tg.data.Data(x=x, y=y, edge_index=edge_idx, train_mask=train_mask, test_mask=test_mask, edge_attr=edge_attr) 141 | return new_data -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torch_geometric as tg 6 | import torch_geometric.nn as tgnn 7 | 8 | class GCN(nn.Module): 9 | def __init__(self, inchannels, hidden, outchannels): 10 | super().__init__() 11 | self.conv1 = tgnn.GCNConv(inchannels, hidden) 12 | self.conv2 = tgnn.GCNConv(hidden, hidden) 13 | self.conv3 = tgnn.GCNConv(hidden, hidden) 14 | self.conv4 = tgnn.GCNConv(hidden, outchannels) 15 | 16 | def forward(self, x, edge_idx): 17 | x = torch.relu(self.conv1(x, edge_idx)) 18 | x = torch.relu(self.conv2(x, edge_idx)) 19 | x = torch.relu(self.conv3(x, edge_idx)) 20 | out = torch.softmax(self.conv4(x, edge_idx), 1) 21 | return out 22 | 23 | class GAT(nn.Module): 24 | def __init__(self, inchannels, hidden, outchannels): 25 | super().__init__() 26 | self.conv1 = tgnn.GATConv(inchannels, hidden) 27 | self.conv2 = tgnn.GATConv(hidden, hidden) 28 | self.conv3 = tgnn.GATConv(hidden, hidden) 29 | self.conv4 = tgnn.GATConv(hidden, outchannels) 30 | 31 | def forward(self, x, edge_idx): 32 | x = torch.relu(self.conv1(x, edge_idx)) 33 | x = torch.relu(self.conv2(x, edge_idx)) 34 | x = torch.relu(self.conv3(x, edge_idx)) 35 | out = torch.softmax(self.conv4(x, edge_idx), 1) 36 | return out -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.2 2 | torch_geometric==2.0.3 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('README.md') as readme_file: 4 | README = readme_file.read() 5 | 6 | setup( 7 | name = 'grafog', 8 | packages = find_packages(exclude=[]), 9 | version = '0.1', 10 | license='MIT', 11 | description = 'Graph Data Augmentations for PyTorch Geometric', 12 | long_description_content_type="text/markdown", 13 | long_description=README, 14 | author = 'Rishabh Anand', 15 | author_email = 'mail.rishabh.anand@gmail.com', 16 | url = 'https://github.com/rish-16/grafog', 17 | keywords = [ 18 | 'machine learning', 19 | 'graph deep learning', 20 | 'data augmentations' 21 | ], 22 | install_requires=[ 23 | 'torch>=1.10', 24 | 'torch_geometric>=2.0' 25 | ], 26 | classifiers=[ 27 | 'Development Status :: 4 - Beta', 28 | 'Intended Audience :: Developers', 29 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 30 | 'License :: OSI Approved :: MIT License', 31 | 'Programming Language :: Python :: 3.6', 32 | ], 33 | ) 34 | --------------------------------------------------------------------------------
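
Taken together, the pieces above already form a small training pipeline: each transform in `grafog/transforms/transforms.py` returns a fresh `Data` object, and the models in `models.py` take `(x, edge_index)`. The sketch below is a rough, illustrative way to wire them up on Planetoid's Cora split; it is not part of the repository, it assumes it is run from the repo root so `models.py` is importable, and the dataset root, hidden size, learning rate, and epoch count are arbitrary choices.

```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid

import grafog.transforms as T
from models import GCN  # the 4-layer GCN defined in models.py (repo root)

dataset = Planetoid(root="data/", name="Cora")   # root path is illustrative
data = dataset[0]                                # has x, y, edge_index, train_mask, test_mask

aug = T.Compose([
    T.NodeDrop(p=0.05),       # zeroes train/test mask entries for ~5% of nodes
    T.EdgeDrop(p=0.05),       # removes ~5% of edges
    T.NodeFeatureMasking(),   # zeroes ~15% of feature columns (default p=0.15)
])

model = GCN(dataset.num_features, 64, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    model.train()
    # clone first: some transforms (e.g. NodeDrop) write into the masks in place
    new_data = aug(data.clone())

    out = model(new_data.x, new_data.edge_index)
    # models.GCN ends in a softmax, so take the log to get the
    # log-probabilities that nll_loss expects
    logp = torch.log(out + 1e-9)
    loss = F.nll_loss(logp[new_data.train_mask], new_data.y[new_data.train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"epoch {epoch:3d} | train loss {loss.item():.4f}")
```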