├── bench.png
├── gcn.py
├── LICENSE
├── README.md
├── benchmark.py
├── viz.py
├── datasets.py
├── .gitignore
├── main.py
└── main_torch.py

/bench.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TristanBilot/mlx-GCN/HEAD/bench.png
--------------------------------------------------------------------------------
/gcn.py:
--------------------------------------------------------------------------------
1 | import mlx.nn as nn
2 | 
3 | 
4 | class GCNLayer(nn.Module):
5 |     """A single graph convolution: a linear projection followed by
6 |     aggregation over the normalized adjacency matrix."""
7 | 
8 |     def __init__(self, in_features, out_features, bias=True):
9 |         super(GCNLayer, self).__init__()
10 |         self.linear = nn.Linear(in_features, out_features, bias)
11 | 
12 |     def __call__(self, x, adj):
13 |         x = self.linear(x)
14 |         return adj @ x  # aggregate the projected features of each node's neighbors
15 | 
16 | 
17 | class GCN(nn.Module):
18 |     def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
19 |         super(GCN, self).__init__()
20 | 
21 |         layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
22 |         self.gcn_layers = [
23 |             GCNLayer(in_dim, out_dim, bias)
24 |             for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
25 |         ]
26 |         self.dropout = nn.Dropout(p=dropout)
27 | 
28 |     def __call__(self, x, adj):
29 |         # ReLU + dropout after every layer except the output layer.
30 |         for layer in self.gcn_layers[:-1]:
31 |             x = nn.relu(layer(x, adj))
32 |             x = self.dropout(x)
33 | 
34 |         x = self.gcn_layers[-1](x, adj)
35 |         return x
36 | 
--------------------------------------------------------------------------------
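For orientation, a minimal sketch of how the model above is used. The toy shapes, random features, and identity adjacency here are made up for illustration; they are not the real Cora tensors built by `datasets.py`:

```python
import mlx.core as mx

from gcn import GCN

num_nodes, num_feats, num_classes = 8, 16, 7  # toy sizes, not Cora's
x = mx.random.normal((num_nodes, num_feats))  # fake node features
adj = mx.eye(num_nodes)                       # stand-in for the normalized adjacency

model = GCN(x_dim=num_feats, h_dim=4, out_dim=num_classes)
logits = model(x, adj)
print(logits.shape)  # (8, 7): one score per class for every node
```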
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Tristan Bilot
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Convolutional Network in MLX
2 | 
3 | An example implementation of [GCN](https://arxiv.org/pdf/1609.02907.pdf) with MLX. Other examples are available here.
4 | 
5 | The actual benchmark on **M1 Pro**, **M2 Ultra**, **M3 Max** and **Tesla V100**s is explained in this Medium article.
6 | 
7 | ### Install env and requirements
8 | 
9 | ```
10 | CONDA_SUBDIR=osx-arm64 conda create -n mlx python=3.10 numpy pytorch scipy requests -c conda-forge
11 | 
12 | conda activate mlx
13 | pip install mlx
14 | ```
15 | 
16 | ### Run
17 | To try the model, just run the `main.py` file. This will download the Cora dataset and run training and testing. The actual MLX code is located in `main.py`, whereas the PyTorch equivalent is in `main_torch.py`.
18 | 
19 | ```
20 | python main.py
21 | ```
22 | 
23 | ### Run benchmark
24 | To run the benchmark on a CUDA device, a new env has to be set up without the `CONDA_SUBDIR=osx-arm64` prefix, so that it runs natively on x86 rather than arm. For all other experiments on arm and Apple Silicon, just use the env created previously.
25 | ```
26 | python benchmark.py --experiment=[ mlx | torch_mps | torch_cpu | torch_cuda ]
27 | ```
28 | 
29 | ### Plot the benchmark figure
30 | This requires two additional packages: `matplotlib` and `scikit-learn`.
31 | 
32 | ```
33 | python viz.py
34 | ```
35 | 
36 | ![Benchmark of GCN on MLX, MPS, CPU, CUDA](bench.png)
37 | 
--------------------------------------------------------------------------------
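The harness in `benchmark.py` below runs the chosen backend `--nb_experiment` times and reports the mean and standard deviation of the per-epoch training time. A typical invocation might look like this (the flag values here are only an example):

```
python benchmark.py --experiment=torch_cuda --nb_experiment=10
```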
/benchmark.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | 
3 | import numpy as np
4 | 
5 | 
6 | def run(args):
7 |     assert args.experiment in ["torch_cpu", "torch_mps", "torch_cuda", "mlx"], \
8 |         "Invalid backend."
9 | 
10 |     if args.experiment == 'mlx':
11 |         from main import main as mlx_main
12 | 
13 |     else:
14 |         import torch
15 |         from main_torch import main as torch_main
16 | 
17 |         if args.experiment == 'torch_mps':
18 |             device = torch.device("mps")
19 |         elif args.experiment == 'torch_cuda':
20 |             device = torch.device("cuda")
21 |         elif args.experiment == 'torch_cpu':
22 |             device = torch.device("cpu")
23 | 
24 |     times = []
25 |     for _ in range(args.nb_experiment):
26 |         if args.experiment == 'mlx':
27 |             t = mlx_main(args)
28 |         else:
29 |             t = torch_main(args, device)
30 |         times.append(t)
31 | 
32 |     mean = np.mean(times)
33 |     std = np.std(times)
34 | 
35 |     print("")
36 |     print(f"Mean training epoch time: {mean:.5f} seconds")
37 |     print(f"Std training epoch time: {std:.5f} seconds")
38 | 
39 | 
40 | if __name__ == '__main__':
41 | 
42 |     parser = ArgumentParser()
43 |     parser.add_argument("--experiment", type=str, default="mlx")  # ["torch_cpu", "torch_mps", "torch_cuda", "mlx"]
44 |     parser.add_argument("--nb_experiment", type=int, default=5)
45 | 
46 |     parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
47 |     parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
48 |     parser.add_argument("--hidden_dim", type=int, default=16)
49 |     parser.add_argument("--dropout", type=float, default=0.5)
50 |     parser.add_argument("--nb_layers", type=int, default=2)
51 |     parser.add_argument("--nb_classes", type=int, default=7)
52 |     parser.add_argument("--bias", type=bool, default=True)  # note: argparse's type=bool treats any non-empty string as True
53 |     parser.add_argument("--lr", type=float, default=0.01)
54 |     parser.add_argument("--weight_decay", type=float, default=5e-3)
55 |     parser.add_argument("--patience", type=int, default=1000000)  # we do not use patience in benchmark
56 |     parser.add_argument("--epochs", type=int, default=50)
57 |     args = parser.parse_args()
58 | 
59 |     run(args)
60 | 
--------------------------------------------------------------------------------
/viz.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from sklearn.manifold import TSNE
4 | 
5 | # Color map for each class
6 | cora_label_to_color_map = {0: "red", 1: "blue", 2: "green",
7 |                            3: "orange", 4: "yellow", 5: "pink", 6: "gray"}
8 | 
9 | 
10 | def benchmark_plot():
11 |     # Leading spaces keep duplicate tick labels distinct per chip;
12 |     # matplotlib would otherwise merge identical category labels.
13 |     backends = [
14 |         '  CPU', '  MPS', '  MLX',
15 |         ' CPU', ' MPS', ' MLX',
16 |         'CPU', 'MPS', 'MLX',
17 |         'CUDA (PCIe)', 'CUDA (NVLINK)',
18 |     ]
19 |     times = [
20 |         45.58, 21.10, 9.02,
21 |         9.31, 7.19, 5.8,
22 |         7.07, 4.8, 4.72,
23 |         3.83, 3.51,
24 |     ]
25 | 
26 |     # Setting colors
27 |     colors = [
28 |         '#FFCC99' for _ in range(3)] \
29 |         + ['#9999FF' for _ in range(3)] \
30 |         + ['#FF9999' for _ in range(3)] \
31 |         + ['skyblue' for _ in range(2)]
32 | 
33 |     plt.figure(figsize=(10, 6))
34 |     bars = plt.barh(backends, times, color=colors, edgecolor='black')
35 | 
36 |     plt.xlabel('Mean time per epoch (ms)')
37 |     plt.title('Benchmarking MLX and PyTorch Backends')
38 |     plt.gca().invert_yaxis()
39 | 
40 |     for index, value in enumerate(times):
41 |         plt.text(value, index, f' {value} ms', va='center')
42 | 
43 |     # Adding a legend
44 |     legend_elements = [
45 |         plt.Line2D([0], [0], color='#FFCC99', lw=4, label='M1 Pro'),
46 |         plt.Line2D([0], [0], color='#9999FF', lw=4, label='M2 Ultra'),
47 |         plt.Line2D([0], [0], color='#FF9999', lw=4, label='M3 Max'),
48 |         plt.Line2D([0], [0], color='skyblue', lw=4, label='Tesla V100')
49 |     ]
50 |     plt.legend(handles=legend_elements, loc='lower right')
51 | 
52 |     plt.savefig("bench.png")
53 | 
54 | 
55 | def visualize_embedding_tSNE(labels, out_features, num_classes):
56 |     """ https://github.com/gordicaleksa/pytorch-GAT """
57 |     t_sne_embeddings = TSNE(n_components=2, perplexity=30, method='barnes_hut').fit_transform(out_features)
58 | 
59 |     plt.figure()
60 |     for class_id in range(num_classes):
61 |         plt.scatter(t_sne_embeddings[labels == class_id, 0],
62 |                     t_sne_embeddings[labels == class_id, 1], s=20,
63 |                     color=cora_label_to_color_map[class_id],
64 |                     edgecolors='black', linewidths=0.15)
65 | 
66 |     plt.axis("off")
67 |     plt.title("t-SNE projection of the learned features")
68 |     plt.show()
69 | 
70 | 
71 | def visualize_validation_performance(val_acc, val_loss):
72 |     f, axs = plt.subplots(1, 2, figsize=(13, 5.5))
73 |     axs[0].plot(val_loss, linewidth=2, color="red")
74 |     axs[0].set_title("Validation loss")
75 |     axs[0].set_ylabel("Cross Entropy Loss")
76 |     axs[0].set_xlabel("Epoch")
77 |     axs[0].grid()
78 | 
79 |     axs[1].plot(val_acc, linewidth=2, color="red")
80 |     axs[1].set_title("Validation accuracy")
81 |     axs[1].set_ylabel("Acc")
82 |     axs[1].set_xlabel("Epoch")
83 |     axs[1].grid()
84 | 
85 |     plt.show()
86 | 
87 | 
88 | if __name__ == '__main__':
89 |     benchmark_plot()
90 | 
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import tarfile
4 | 
5 | import numpy as np
6 | import scipy.sparse as sparse
7 | 
8 | """
9 | Preprocessing follows the same implementation as in:
10 | https://github.com/tkipf/gcn
11 | https://github.com/senadkurtisi/pytorch-GCN/tree/main
12 | """
13 | 
14 | 
15 | def download_cora():
16 |     """Downloads the cora dataset into a local cora folder."""
17 | 
18 |     url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz"
19 |     extract_to = "."
20 | 
21 |     if os.path.exists(os.path.join(extract_to, "cora")):
22 |         return
23 | 
24 |     response = requests.get(url, stream=True)
25 |     if response.status_code == 200:
26 |         file_path = os.path.join(extract_to, url.split("/")[-1])
27 | 
28 |         # Write the file to local disk
29 |         with open(file_path, "wb") as file:
30 |             file.write(response.raw.read())
31 | 
32 |         # Extract the .tgz file
33 |         with tarfile.open(file_path, "r:gz") as tar:
34 |             tar.extractall(path=extract_to)
35 |             print(f"Cora dataset extracted to {extract_to}")
36 | 
37 |         os.remove(file_path)
38 | 
39 | 
40 | def train_val_test_mask(labels, num_classes):
41 |     """Splits the loaded dataset into train/validation/test sets."""
42 | 
43 |     train_set = list(range(140))
44 |     validation_set = list(range(200, 500))
45 |     test_set = list(range(500, 1500))
46 | 
47 |     return train_set, validation_set, test_set
48 | 
49 | 
50 | def enumerate_labels(labels):
51 |     """Converts the labels from the original
52 |     string form to the integer [0:MaxLabels-1]
53 |     """
54 |     unique = list(set(labels))
55 |     labels = np.array([unique.index(label) for label in labels])
56 |     return labels
57 | 
58 | 
59 | def normalize_adjacency(adj):
60 |     """Normalizes the adjacency matrix according to the
61 |     paper by Kipf et al.
62 |     https://arxiv.org/pdf/1609.02907.pdf
63 |     """
64 |     adj = adj + sparse.eye(adj.shape[0])
65 | 
66 |     node_degrees = np.array(adj.sum(1))
67 |     node_degrees = np.power(node_degrees, -0.5).flatten()
68 |     node_degrees[np.isinf(node_degrees)] = 0.0
69 |     node_degrees[np.isnan(node_degrees)] = 0.0
70 |     degree_matrix = sparse.diags(node_degrees, dtype=np.float32)
71 | 
72 |     adj = degree_matrix @ adj @ degree_matrix
73 |     return adj
74 | 
75 | 
76 | def load_data(config):
77 |     """Loads the Cora graph data into MLX array format."""
78 |     print("Loading Cora dataset...")
79 | 
80 |     # Graph nodes
81 |     raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
82 |     raw_node_ids = raw_nodes_data[:, 0].astype(
83 |         "int32"
84 |     )  # unique identifier of each node
85 |     raw_node_labels = raw_nodes_data[:, -1]
86 |     labels_enumerated = enumerate_labels(raw_node_labels)  # target labels as integers
87 |     node_features = sparse.csr_matrix(raw_nodes_data[:, 1:-1], dtype="float32")
88 | 
89 |     # Edges
90 |     ids_ordered = {raw_id: order for order, raw_id in enumerate(raw_node_ids)}
91 |     raw_edges_data = np.genfromtxt(config.edges_path, dtype="int32")
92 |     edges_ordered = np.array(
93 |         list(map(ids_ordered.get, raw_edges_data.flatten())), dtype="int32"
94 |     ).reshape(raw_edges_data.shape)
95 | 
96 |     # Adjacency matrix
97 |     adj = sparse.coo_matrix(
98 |         (np.ones(edges_ordered.shape[0]), (edges_ordered[:, 0], edges_ordered[:, 1])),
99 |         shape=(labels_enumerated.shape[0], labels_enumerated.shape[0]),
100 |         dtype=np.float32,
101 |     )
102 | 
103 |     # Make the adjacency matrix symmetric
104 |     adj = adj + adj.T.multiply(adj.T > adj)
105 |     adj = normalize_adjacency(adj)
106 | 
107 |     print("Dataset loaded.")
108 |     return node_features.toarray(), labels_enumerated, adj.toarray()
109 | 
--------------------------------------------------------------------------------
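To make `normalize_adjacency` concrete, here is a minimal sketch on a made-up two-node graph (the matrix is illustrative, not Cora data). With self-loops added, both nodes have degree 2, so every entry of D^-1/2 (A + I) D^-1/2 becomes 0.5:

```python
import numpy as np
import scipy.sparse as sparse

from datasets import normalize_adjacency

# Made-up 2-node graph with a single undirected edge.
adj = sparse.coo_matrix(np.array([[0.0, 1.0], [1.0, 0.0]], dtype=np.float32))

print(normalize_adjacency(adj).toarray())
# [[0.5 0.5]
#  [0.5 0.5]]
```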
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | cora/
3 | 
4 | 
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 | 
10 | # C extensions
11 | *.so
12 | 
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | 
share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | #.idea/
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from time import time
3 | 
4 | import mlx.core as mx
5 | import mlx.nn as nn
6 | import mlx.optimizers as optim
7 | from mlx.utils import tree_flatten
8 | 
9 | from datasets import download_cora, load_data, train_val_test_mask
10 | from gcn import GCN
11 | 
12 | 
13 | def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
14 |     l = mx.mean(nn.losses.cross_entropy(y_hat, y))
15 | 
16 |     if weight_decay != 0.0:
17 |         assert parameters is not None, "Model parameters missing for L2 reg."
18 | 
19 |         # L2 regularization: the norm of all parameters, scaled by weight_decay.
20 |         l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
21 |         return l + weight_decay * l2_reg
22 | 
23 |     return l
24 | 
25 | 
26 | def eval_fn(x, y):
27 |     return mx.mean(mx.argmax(x, axis=1) == y)
28 | 
29 | 
30 | def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
31 |     y_hat = gcn(x, adj)
32 |     loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
33 |     return loss, y_hat
34 | 
35 | 
36 | def to_mlx(x, y, adj, train_mask, val_mask, test_mask):
37 |     x = mx.array(x, mx.float32)
38 |     y = mx.array(y, mx.int32)
39 |     adj = mx.array(adj)
40 |     train_mask = mx.array(train_mask)
41 |     val_mask = mx.array(val_mask)
42 |     test_mask = mx.array(test_mask)
43 |     return x, y, adj, train_mask, val_mask, test_mask
44 | 
45 | 
46 | def main(args):
47 | 
48 |     # Data loading
49 |     download_cora()
50 | 
51 |     x, y, adj = load_data(args)
52 |     train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
53 | 
54 |     x, y, adj, train_mask, val_mask, test_mask = \
55 |         to_mlx(x, y, adj, train_mask, val_mask, test_mask)
56 | 
57 |     gcn = GCN(
58 |         x_dim=x.shape[-1],
59 |         h_dim=args.hidden_dim,
60 |         out_dim=args.nb_classes,
61 |         nb_layers=args.nb_layers,
62 |         dropout=args.dropout,
63 |         bias=args.bias,
64 |     )
65 |     mx.eval(gcn.parameters())
66 | 
67 |     optimizer = optim.Adam(learning_rate=args.lr)
68 |     loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
69 | 
70 |     best_val_loss = float("inf")
71 |     cnt = 0
72 |     times = []
73 | 
74 |     # Training loop
75 |     for epoch in range(args.epochs):
76 |         start = time()
77 | 
78 |         # Loss
79 |         (loss, y_hat), grads = loss_and_grad_fn(
80 |             gcn, x, adj, y, train_mask, args.weight_decay
81 |         )
82 |         optimizer.update(gcn, grads)
83 |         # MLX is lazy: force evaluation of the updated parameters and state.
84 |         mx.eval(gcn.parameters(), optimizer.state)
85 | 
86 |         # Validation
87 |         val_loss = loss_fn(y_hat[val_mask], y[val_mask])
88 |         val_acc = eval_fn(y_hat[val_mask], y[val_mask])
89 | 
90 |         times.append(time() - start)
91 | 
92 |         # Early stopping
93 |         if val_loss < best_val_loss:
94 |             best_val_loss = val_loss
95 |             cnt = 0
96 |         else:
97 |             cnt += 1
98 |             if cnt == args.patience:
99 |                 break
100 | 
101 |         print(
102 |             " | ".join(
103 |                 [
104 |                     f"Epoch: {epoch:3d}",
105 |                     f"Train loss: {loss.item():.3f}",
106 |                     f"Val loss: {val_loss.item():.3f}",
107 |                     f"Val acc: {val_acc.item():.2f}",
108 |                 ]
109 |             )
110 |         )
111 | 
112 |     # Test
113 |     test_y_hat = gcn(x, adj)
114 |     test_loss = loss_fn(test_y_hat[test_mask], y[test_mask])
115 |     test_acc = eval_fn(test_y_hat[test_mask], y[test_mask])
116 |     mean_time = sum(times) / len(times)
117 | 
118 |     print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}")
119 |     print(f"Mean time: {mean_time:.5f}")
120 |     return mean_time
121 | 
122 | 
123 | if __name__ == "__main__":
124 | 
125 |     parser = ArgumentParser()
126 |     parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
127 |     parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
128 |     parser.add_argument("--hidden_dim", type=int, default=16)
129 |     parser.add_argument("--dropout", type=float, default=0.5)
130 |     parser.add_argument("--nb_layers", type=int, default=2)
131 |     parser.add_argument("--nb_classes", type=int, default=7)
132 |     parser.add_argument("--bias", type=bool, default=True)
133 |     parser.add_argument("--lr", type=float, default=0.01)
134 |     parser.add_argument("--weight_decay", type=float, default=5e-3)
135 |     parser.add_argument("--patience", type=int, default=20)
136 |     parser.add_argument("--epochs", type=int, default=100)
137 |     args = parser.parse_args()
138 | 
139 |     main(args)
140 | 
--------------------------------------------------------------------------------
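The training step in `main.py` follows the usual MLX recipe: `nn.value_and_grad` wraps a forward function whose first argument is the model, `optimizer.update` applies the gradients, and `mx.eval` forces the lazy computation graph to actually run. A stripped-down sketch of just that mechanism, using a placeholder linear model and fake data instead of the GCN and Cora:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(4, 2)                       # placeholder model
optimizer = optim.Adam(learning_rate=0.01)

def forward(model, x, y):
    return mx.mean(nn.losses.cross_entropy(model(x), y))

# Differentiates `forward` with respect to the model's trainable parameters.
step = nn.value_and_grad(model, forward)

x = mx.random.normal((8, 4))                  # fake batch
y = mx.array([0, 1, 0, 1, 0, 1, 0, 1])
loss, grads = step(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)  # nothing runs until this call
```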
type=str, default="cora/cora.content") 126 | parser.add_argument("--edges_path", type=str, default="cora/cora.cites") 127 | parser.add_argument("--hidden_dim", type=int, default=16) 128 | parser.add_argument("--dropout", type=float, default=0.5) 129 | parser.add_argument("--nb_layers", type=int, default=2) 130 | parser.add_argument("--nb_classes", type=int, default=7) 131 | parser.add_argument("--bias", type=bool, default=True) 132 | parser.add_argument("--lr", type=float, default=0.01) 133 | parser.add_argument("--weight_decay", type=float, default=5e-3) 134 | parser.add_argument("--patience", type=int, default=20) 135 | parser.add_argument("--epochs", type=int, default=100) 136 | args = parser.parse_args() 137 | 138 | main(args) 139 | -------------------------------------------------------------------------------- /main_torch.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from time import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from datasets import download_cora, load_data, train_val_test_mask 8 | 9 | 10 | class GCNLayer(nn.Module): 11 | def __init__(self, x_dim, h_dim, bias=True): 12 | super(GCNLayer, self).__init__() 13 | self.weight = nn.Parameter(torch.FloatTensor(torch.zeros(size=(x_dim, h_dim)))) 14 | if bias: 15 | self.bias = nn.Parameter(torch.FloatTensor(torch.zeros(size=(h_dim,)))) 16 | else: 17 | self.register_parameter('bias', None) 18 | 19 | self.initialize_weights() 20 | 21 | def initialize_weights(self): 22 | nn.init.xavier_uniform_(self.weight) 23 | if self.bias is not None: 24 | nn.init.zeros_(self.bias) 25 | 26 | def forward(self, x, adj): 27 | x = x @ self.weight 28 | if self.bias is not None: 29 | x += self.bias 30 | 31 | return torch.mm(adj, x) 32 | 33 | 34 | class GCN(nn.Module): 35 | def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True): 36 | super(GCN, self).__init__() 37 | 38 | layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim] 39 | self.gcn_layers = nn.Sequential(*[ 40 | GCNLayer(in_dim, out_dim, bias) 41 | for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:]) 42 | ]) 43 | self.dropout = nn.Dropout(p=dropout) 44 | 45 | def initialize_weights(self): 46 | self.gcn_1.initialize_weights() 47 | self.gcn_2.initialize_weights() 48 | 49 | def forward(self, x, adj): 50 | for layer in self.gcn_layers[:-1]: 51 | x = torch.relu(layer(x, adj)) 52 | x = self.dropout(x) 53 | 54 | x = self.gcn_layers[-1](x, adj) 55 | return x 56 | 57 | 58 | def to_torch(device, x, y, adj, train_mask, val_mask, test_mask): 59 | x = torch.tensor(x, dtype=torch.float32, device=device) 60 | y = torch.tensor(y, dtype=torch.long, device=device) 61 | adj = torch.tensor(adj, dtype=torch.float32, device=device) 62 | train_mask = torch.tensor(train_mask, device=device) 63 | val_mask = torch.tensor(val_mask, device=device) 64 | test_mask = torch.tensor(test_mask, device=device) 65 | return x, y, adj, train_mask, val_mask, test_mask 66 | 67 | 68 | def eval_fn(x, y): 69 | return torch.mean((torch.argmax(x, axis=1) == y).float()) 70 | 71 | def main(args, device): 72 | 73 | # Data loading 74 | download_cora() 75 | 76 | x, y, adj = load_data(args) 77 | train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes) 78 | 79 | x, y, adj, train_mask, val_mask, test_mask = \ 80 | to_torch(device, x, y, adj, train_mask, val_mask, test_mask) 81 | 82 | gcn = GCN( 83 | x_dim=x.shape[-1], 84 | h_dim=args.hidden_dim, 85 | out_dim=args.nb_classes, 86 | 
86 |         nb_layers=args.nb_layers,
87 |         dropout=args.dropout,
88 |         bias=args.bias,
89 |     ).to(device)
90 | 
91 | 
92 |     optimizer = torch.optim.Adam(gcn.parameters(), lr=args.lr,
93 |                                  weight_decay=args.weight_decay)
94 |     loss_fn = nn.CrossEntropyLoss()
95 | 
96 | 
97 |     best_val_loss = float("inf")
98 |     cnt = 0
99 |     times = []
100 | 
101 |     # Training loop
102 |     for epoch in range(args.epochs):
103 |         start = time()
104 | 
105 |         optimizer.zero_grad()
106 |         gcn.train()
107 | 
108 |         y_hat = gcn(x, adj)
109 |         loss = loss_fn(y_hat[train_mask], y[train_mask])
110 | 
111 |         loss.backward()
112 |         optimizer.step()
113 | 
114 |         # Validation
115 |         with torch.no_grad():
116 |             gcn.eval()
117 |             val_loss = loss_fn(y_hat[val_mask], y[val_mask])
118 |             val_acc = eval_fn(y_hat[val_mask], y[val_mask])
119 | 
120 |         times.append(time() - start)
121 | 
122 |         # Early stopping
123 |         if val_loss < best_val_loss:
124 |             best_val_loss = val_loss
125 |             cnt = 0
126 |         else:
127 |             cnt += 1
128 |             if cnt == args.patience:
129 |                 break
130 | 
131 |         print(
132 |             " | ".join(
133 |                 [
134 |                     f"Epoch: {epoch:3d}",
135 |                     f"Train loss: {loss.item():.3f}",
136 |                     f"Val loss: {val_loss.item():.3f}",
137 |                     f"Val acc: {val_acc.item():.2f}",
138 |                 ]
139 |             )
140 |         )
141 | 
142 |     # Test
143 |     test_y_hat = gcn(x, adj)
144 |     test_loss = loss_fn(test_y_hat[test_mask], y[test_mask])
145 |     test_acc = eval_fn(test_y_hat[test_mask], y[test_mask])
146 |     mean_time = sum(times) / len(times)
147 | 
148 |     print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}")
149 |     print(f"Mean time: {mean_time:.5f}")
150 |     return mean_time
151 | 
152 | 
153 | if __name__ == "__main__":
154 | 
155 |     parser = ArgumentParser()
156 |     parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
157 |     parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
158 |     parser.add_argument("--hidden_dim", type=int, default=16)
159 |     parser.add_argument("--dropout", type=float, default=0.5)
160 |     parser.add_argument("--nb_layers", type=int, default=2)
161 |     parser.add_argument("--nb_classes", type=int, default=7)
162 |     parser.add_argument("--bias", type=bool, default=True)
163 |     parser.add_argument("--lr", type=float, default=0.01)
164 |     parser.add_argument("--weight_decay", type=float, default=5e-3)
165 |     parser.add_argument("--patience", type=int, default=20)
166 |     parser.add_argument("--epochs", type=int, default=100)
167 |     args = parser.parse_args()
168 | 
169 |     if torch.backends.mps.is_available():
170 |         device = torch.device("mps")
171 |         print("Using MPS.")
172 |     else:
173 |         device = torch.device("cpu")
174 |         print("MPS device not found.")
175 | 
176 |     main(args, device)
177 | 
--------------------------------------------------------------------------------