├── bench.png
├── gcn.py
├── LICENSE
├── README.md
├── benchmark.py
├── viz.py
├── datasets.py
├── .gitignore
├── main.py
└── main_torch.py
/bench.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TristanBilot/mlx-GCN/HEAD/bench.png
--------------------------------------------------------------------------------
/gcn.py:
--------------------------------------------------------------------------------
1 | import mlx.nn as nn
2 |
3 |
4 | class GCNLayer(nn.Module):
5 | def __init__(self, in_features, out_features, bias=True):
6 | super(GCNLayer, self).__init__()
7 | self.linear = nn.Linear(in_features, out_features, bias)
8 |
9 | def __call__(self, x, adj):
10 |         x = self.linear(x)  # per-node linear transform
11 |         return adj @ x      # aggregate transformed features over neighbors
12 |
13 |
14 | class GCN(nn.Module):
15 | def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
16 | super(GCN, self).__init__()
17 |
18 | layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
19 | self.gcn_layers = [
20 | GCNLayer(in_dim, out_dim, bias)
21 | for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
22 | ]
23 | self.dropout = nn.Dropout(p=dropout)
24 |
25 | def __call__(self, x, adj):
26 | for layer in self.gcn_layers[:-1]:
27 | x = nn.relu(layer(x, adj))
28 | x = self.dropout(x)
29 |
30 | x = self.gcn_layers[-1](x, adj)
31 | return x
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Tristan Bilot
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Convolutional Network in MLX
2 |
3 | An example of a [GCN](https://arxiv.org/abs/1609.02907) implementation in MLX. Other examples are available in the [mlx-examples](https://github.com/ml-explore/mlx-examples) repository.
4 |
5 | The benchmark across **M1 Pro**, **M2 Ultra**, **M3 Max** and **Tesla V100** GPUs is discussed in an accompanying Medium article.
6 |
7 | ### Install env and requirements
8 |
9 | ```
10 | CONDA_SUBDIR=osx-arm64 conda create -n mlx python=3.10 numpy pytorch scipy requests -c conda-forge
11 |
12 | conda activate mlx
13 | pip install mlx
14 | ```
15 |
16 | ### Run
17 | To try the model, just run the `main.py` file. This downloads the Cora dataset and runs training and testing. The MLX code is located in `main.py` (with the model in `gcn.py`), whereas the PyTorch equivalent is in `main_torch.py`.
18 |
19 | ```
20 | python main.py
21 | ```
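22 | 
23 | All hyperparameters defined at the bottom of `main.py` can be overridden from the command line; for example (the values shown are the defaults):
24 | 
25 | ```
26 | python main.py --hidden_dim 16 --nb_layers 2 --dropout 0.5 --lr 0.01 --epochs 100
27 | ```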
28 | 
29 | ### Run benchmark
30 | To run the benchmark on a CUDA device, set up a new env without the `CONDA_SUBDIR=osx-arm64` prefix, so that the default x86_64 packages are installed instead of arm64 ones (see the sketch after the command below). For all other experiments on arm and Apple Silicon, just reuse the env created previously.
31 | ```
32 | python benchmark.py --experiment=[ mlx | torch_mps | torch_cpu | torch_cuda ]
33 | ```
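34 | 
35 | As a sketch, the CUDA-side environment could mirror the package set above (the `mlx-x86` env name is illustrative, and installing a CUDA-enabled PyTorch build may require your platform's usual channels):
36 | 
37 | ```
38 | conda create -n mlx-x86 python=3.10 numpy pytorch scipy requests -c conda-forge
39 | conda activate mlx-x86
40 | ```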
41 | 
42 | ### Process benchmark figure
43 | Generating the figure requires two additional packages, `matplotlib` and `scikit-learn` (`pip install matplotlib scikit-learn`).
44 | 
45 | ```
46 | python viz.py
47 | ```
48 | 
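49 | ### Example: using the GCN module directly
50 | 
51 | For a quick sanity check outside the training script, the `GCN` module from `gcn.py` can be called on toy inputs. A minimal sketch (the dimensions are illustrative, not Cora's; Cora has 1433 input features and 7 classes):
52 | 
53 | ```
54 | import mlx.core as mx
55 | from gcn import GCN
56 | 
57 | x = mx.random.normal((4, 8))  # 4 nodes, 8 features each
58 | adj = mx.eye(4)               # identity adjacency: no neighbor mixing
59 | 
60 | model = GCN(x_dim=8, h_dim=16, out_dim=3)
61 | model.eval()                  # disable dropout for a deterministic forward pass
62 | logits = model(x, adj)        # per-node class logits, shape (4, 3)
63 | print(logits.shape)
64 | ```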
--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser, BooleanOptionalAction
2 |
3 | import numpy as np
4 |
5 |
6 | def run(args):
7 | assert args.experiment in ["torch_cpu", "torch_mps", "torch_cuda", "mlx"], \
8 | "Invalid backend."
9 |
10 | if args.experiment == 'mlx':
11 | from main import main as mlx_main
12 |
13 | else:
14 | import torch
15 | from main_torch import main as torch_main
16 |
17 | if args.experiment == 'torch_mps':
18 | device = torch.device("mps")
19 | elif args.experiment == 'torch_cuda':
20 | device = torch.device("cuda")
21 | elif args.experiment == 'torch_cpu':
22 | device = torch.device("cpu")
23 |
24 | times = []
25 |     for _ in range(args.nb_experiment):
26 | if args.experiment == 'mlx':
27 | t = mlx_main(args)
28 | else:
29 | t = torch_main(args, device)
30 | times.append(t)
31 |
32 | mean = np.mean(times)
33 | std = np.std(times)
34 |
35 | print("")
36 | print(f"Mean training epoch time: {mean:.5f} seconds")
37 | print(f"Std training epoch time: {std:.5f} seconds")
38 |
39 |
40 | if __name__ == '__main__':
41 |
42 | parser = ArgumentParser()
43 | parser.add_argument("--experiment", type=str, default="mlx") # ["torch_cpu", "torch_mps", "torch_cuda", "mlx"]
44 | parser.add_argument("--nb_experiment", type=int, default=5)
45 |
46 | parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
47 | parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
48 | parser.add_argument("--hidden_dim", type=int, default=16)
49 | parser.add_argument("--dropout", type=float, default=0.5)
50 | parser.add_argument("--nb_layers", type=int, default=2)
51 | parser.add_argument("--nb_classes", type=int, default=7)
52 |     parser.add_argument("--bias", action=BooleanOptionalAction, default=True)  # --bias / --no-bias
53 | parser.add_argument("--lr", type=float, default=0.01)
54 | parser.add_argument("--weight_decay", type=float, default=5e-3)
55 | parser.add_argument("--patience", type=int, default=1000000) # we do not use patience in benchmark
56 | parser.add_argument("--epochs", type=int, default=50)
57 | args = parser.parse_args()
58 |
59 | run(args)
60 |
--------------------------------------------------------------------------------
/viz.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from sklearn.manifold import TSNE
4 |
5 | # Color map for each class
6 | cora_label_to_color_map = {0: "red", 1: "blue", 2: "green",
7 | 3: "orange", 4: "yellow", 5: "pink", 6: "gray"}
8 |
9 |
10 | def benchmark_plot():
11 | backends = [
12 | ' CPU', ' MPS', ' MLX',
13 | ' CPU', ' MPS', ' MLX',
14 | 'CPU', 'MPS', 'MLX',
15 | 'CUDA (PCIe)', 'CUDA (NVLINK)',
16 | ]
17 | times = [
18 | 45.58, 21.10, 9.02,
19 | 9.31, 7.19, 5.8,
20 | 7.07, 4.8, 4.72,
21 | 3.83, 3.51,
22 | ]
23 |
24 | # Setting colors
25 | colors = [
26 | '#FFCC99' for _ in range(3)] \
27 | + ['#9999FF' for _ in range(3)] \
28 | + ['#FF9999' for _ in range(3)] \
29 | + ['skyblue' for _ in range(2)]
30 |
31 | plt.figure(figsize=(10, 6))
32 | bars = plt.barh(backends, times, color=colors, edgecolor='black')
33 |
34 | plt.xlabel('Mean time per epoch (ms)')
35 | plt.title('Benchmarking MLX and PyTorch Backends')
36 | plt.gca().invert_yaxis()
37 |
38 | for index, value in enumerate(times):
39 | plt.text(value, index, f' {value} ms', va='center')
40 |
41 | # Adding a legend
42 | legend_elements = [
43 | plt.Line2D([0], [0], color='#FFCC99', lw=4, label='M1 Pro'),
44 | plt.Line2D([0], [0], color='#9999FF', lw=4, label='M2 Ultra'),
45 | plt.Line2D([0], [0], color='#FF9999', lw=4, label='M3 Max'),
46 | plt.Line2D([0], [0], color='skyblue', lw=4, label='Tesla V100')
47 | ]
48 | plt.legend(handles=legend_elements, loc='lower right')
49 |
50 |
51 | plt.savefig("bench.png")
52 |
53 |
54 | def visualize_embedding_tSNE(labels, out_features, num_classes):
55 | """ https://github.com/gordicaleksa/pytorch-GAT """
56 | t_sne_embeddings = TSNE(n_components=2, perplexity=30, method='barnes_hut').fit_transform(out_features)
57 |
58 | plt.figure()
59 | for class_id in range(num_classes):
60 |         plt.scatter(t_sne_embeddings[labels == class_id, 0],
61 |                     t_sne_embeddings[labels == class_id, 1], s=20,
62 | color=cora_label_to_color_map[class_id],
63 | edgecolors='black', linewidths=0.15)
64 |
65 | plt.axis("off")
66 | plt.title("t-SNE projection of the learned features")
67 | plt.show()
68 |
69 |
70 | def visualize_validation_performance(val_acc, val_loss):
71 | f, axs = plt.subplots(1, 2, figsize=(13, 5.5))
72 | axs[0].plot(val_loss, linewidth=2, color="red")
73 | axs[0].set_title("Validation loss")
74 | axs[0].set_ylabel("Cross Entropy Loss")
75 | axs[0].set_xlabel("Epoch")
76 | axs[0].grid()
77 |
78 | axs[1].plot(val_acc, linewidth=2, color="red")
79 | axs[1].set_title("Validation accuracy")
80 | axs[1].set_ylabel("Acc")
81 | axs[1].set_xlabel("Epoch")
82 | axs[1].grid()
83 |
84 | plt.show()
85 |
86 |
87 | if __name__ == '__main__':
88 | benchmark_plot()
89 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import tarfile
4 |
5 | import numpy as np
6 | import scipy.sparse as sparse
7 |
8 | """
9 | Preprocessing follows the same implementation as in:
10 | https://github.com/tkipf/gcn
11 | https://github.com/senadkurtisi/pytorch-GCN/tree/main
12 | """
13 |
14 |
15 | def download_cora():
16 | """Downloads the cora dataset into a local cora folder."""
17 |
18 | url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz"
19 | extract_to = "."
20 |
21 | if os.path.exists(os.path.join(extract_to, "cora")):
22 | return
23 |
24 | response = requests.get(url, stream=True)
25 | if response.status_code == 200:
26 | file_path = os.path.join(extract_to, url.split("/")[-1])
27 |
28 | # Write the file to local disk
29 | with open(file_path, "wb") as file:
30 |             file.write(response.content)
31 |
32 | # Extract the .tgz file
33 | with tarfile.open(file_path, "r:gz") as tar:
34 | tar.extractall(path=extract_to)
35 | print(f"Cora dataset extracted to {extract_to}")
36 |
37 | os.remove(file_path)
38 |
39 |
40 | def train_val_test_mask(labels, num_classes):
41 | """Splits the loaded dataset into train/validation/test sets."""
42 |
43 | train_set = list(range(140))
44 | validation_set = list(range(200, 500))
45 | test_set = list(range(500, 1500))
46 |
47 | return train_set, validation_set, test_set
48 |
49 |
50 | def enumerate_labels(labels):
51 | """Converts the labels from the original
52 | string form to the integer [0:MaxLabels-1]
53 | """
54 | unique = list(set(labels))
55 | labels = np.array([unique.index(label) for label in labels])
56 | return labels
57 |
58 |
59 | def normalize_adjacency(adj):
60 |     """Symmetrically normalizes the adjacency matrix as
61 |     A_hat = D^{-1/2} (A + I) D^{-1/2}, following Kipf & Welling:
62 |     https://arxiv.org/pdf/1609.02907.pdf
63 |     """
64 | adj = adj + sparse.eye(adj.shape[0])
65 |
66 | node_degrees = np.array(adj.sum(1))
67 | node_degrees = np.power(node_degrees, -0.5).flatten()
68 | node_degrees[np.isinf(node_degrees)] = 0.0
69 | node_degrees[np.isnan(node_degrees)] = 0.0
70 | degree_matrix = sparse.diags(node_degrees, dtype=np.float32)
71 |
72 | adj = degree_matrix @ adj @ degree_matrix
73 | return adj
74 |
75 |
76 | def load_data(config):
77 | """Loads the Cora graph data into MLX array format."""
78 | print("Loading Cora dataset...")
79 |
80 | # Graph nodes
81 | raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
82 | raw_node_ids = raw_nodes_data[:, 0].astype(
83 | "int32"
84 | ) # unique identifier of each node
85 | raw_node_labels = raw_nodes_data[:, -1]
86 | labels_enumerated = enumerate_labels(raw_node_labels) # target labels as integers
87 | node_features = sparse.csr_matrix(raw_nodes_data[:, 1:-1], dtype="float32")
88 |
89 | # Edges
90 | ids_ordered = {raw_id: order for order, raw_id in enumerate(raw_node_ids)}
91 | raw_edges_data = np.genfromtxt(config.edges_path, dtype="int32")
92 | edges_ordered = np.array(
93 | list(map(ids_ordered.get, raw_edges_data.flatten())), dtype="int32"
94 | ).reshape(raw_edges_data.shape)
95 |
96 | # Adjacency matrix
97 | adj = sparse.coo_matrix(
98 | (np.ones(edges_ordered.shape[0]), (edges_ordered[:, 0], edges_ordered[:, 1])),
99 | shape=(labels_enumerated.shape[0], labels_enumerated.shape[0]),
100 | dtype=np.float32,
101 | )
102 |
103 | # Make the adjacency matrix symmetric
104 | adj = adj + adj.T.multiply(adj.T > adj)
105 | adj = normalize_adjacency(adj)
106 |
107 | print("Dataset loaded.")
108 | return node_features.toarray(), labels_enumerated, adj.toarray()
109 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | cora/
3 |
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/#use-with-ide
114 | .pdm.toml
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser, BooleanOptionalAction
2 | from time import time
3 |
4 | import mlx.core as mx
5 | import mlx.nn as nn
6 | import mlx.optimizers as optim
7 | from mlx.nn.losses import cross_entropy
8 | from mlx.utils import tree_flatten
9 |
10 | from datasets import download_cora, load_data, train_val_test_mask
11 | from gcn import GCN
12 |
13 |
14 | def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
15 |     loss = mx.mean(cross_entropy(y_hat, y))
16 | 
17 |     if weight_decay != 0.0:
18 |         assert parameters is not None, "Model parameters missing for L2 reg."
19 | 
20 |         l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()  # L2 norm of all parameters
21 |         return loss + weight_decay * l2_reg
22 | 
23 |     return loss
24 |
25 |
26 | def eval_fn(x, y):
27 | return mx.mean(mx.argmax(x, axis=1) == y)
28 |
29 |
30 | def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
31 | y_hat = gcn(x, adj)
32 | loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
33 | return loss, y_hat
34 |
35 |
36 | def to_mlx(x, y, adj, train_mask, val_mask, test_mask):
37 | x = mx.array(x, mx.float32)
38 | y = mx.array(y, mx.int32)
39 | adj = mx.array(adj)
40 | train_mask = mx.array(train_mask)
41 | val_mask = mx.array(val_mask)
42 | test_mask = mx.array(test_mask)
43 | return x, y, adj, train_mask, val_mask, test_mask
44 |
45 |
46 | def main(args):
47 |
48 | # Data loading
49 | download_cora()
50 |
51 | x, y, adj = load_data(args)
52 | train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
53 |
54 | x, y, adj, train_mask, val_mask, test_mask = \
55 | to_mlx(x, y, adj, train_mask, val_mask, test_mask)
56 |
57 | gcn = GCN(
58 | x_dim=x.shape[-1],
59 | h_dim=args.hidden_dim,
60 | out_dim=args.nb_classes,
61 | nb_layers=args.nb_layers,
62 | dropout=args.dropout,
63 | bias=args.bias,
64 | )
65 | mx.eval(gcn.parameters())
66 |
67 | optimizer = optim.Adam(learning_rate=args.lr)
68 | loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
69 |
70 | best_val_loss = float("inf")
71 | cnt = 0
72 | times = []
73 |
74 | # Training loop
75 | for epoch in range(args.epochs):
76 | start = time()
77 |
78 | # Loss
79 | (loss, y_hat), grads = loss_and_grad_fn(
80 | gcn, x, adj, y, train_mask, args.weight_decay
81 | )
82 | optimizer.update(gcn, grads)
83 | mx.eval(gcn.parameters(), optimizer.state)
84 |
85 | # Validation
86 | val_loss = loss_fn(y_hat[val_mask], y[val_mask])
87 | val_acc = eval_fn(y_hat[val_mask], y[val_mask])
88 |
89 | times.append(time() - start)
90 |
91 | # Early stopping
92 | if val_loss < best_val_loss:
93 | best_val_loss = val_loss
94 | cnt = 0
95 | else:
96 | cnt += 1
97 | if cnt == args.patience:
98 | break
99 |
100 | print(
101 | " | ".join(
102 | [
103 | f"Epoch: {epoch:3d}",
104 | f"Train loss: {loss.item():.3f}",
105 | f"Val loss: {val_loss.item():.3f}",
106 | f"Val acc: {val_acc.item():.2f}",
107 | ]
108 | )
109 | )
110 |
111 | # Test
112 |     test_y_hat = gcn(x, adj)
113 |     test_loss = loss_fn(test_y_hat[test_mask], y[test_mask])
114 |     test_acc = eval_fn(test_y_hat[test_mask], y[test_mask])
115 | mean_time = sum(times) / len(times)
116 |
117 | print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}")
118 | print(f"Mean time: {mean_time:.5f}")
119 | return mean_time
120 |
121 |
122 | if __name__ == "__main__":
123 |
124 | parser = ArgumentParser()
125 | parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
126 | parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
127 | parser.add_argument("--hidden_dim", type=int, default=16)
128 | parser.add_argument("--dropout", type=float, default=0.5)
129 | parser.add_argument("--nb_layers", type=int, default=2)
130 | parser.add_argument("--nb_classes", type=int, default=7)
131 |     parser.add_argument("--bias", action=BooleanOptionalAction, default=True)  # --bias / --no-bias
132 | parser.add_argument("--lr", type=float, default=0.01)
133 | parser.add_argument("--weight_decay", type=float, default=5e-3)
134 | parser.add_argument("--patience", type=int, default=20)
135 | parser.add_argument("--epochs", type=int, default=100)
136 | args = parser.parse_args()
137 |
138 | main(args)
139 |
--------------------------------------------------------------------------------
/main_torch.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser, BooleanOptionalAction
2 | from time import time
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from datasets import download_cora, load_data, train_val_test_mask
8 |
9 |
10 | class GCNLayer(nn.Module):
11 | def __init__(self, x_dim, h_dim, bias=True):
12 | super(GCNLayer, self).__init__()
13 |         self.weight = nn.Parameter(torch.zeros(x_dim, h_dim))
14 |         if bias:
15 |             self.bias = nn.Parameter(torch.zeros(h_dim))
16 | else:
17 | self.register_parameter('bias', None)
18 |
19 | self.initialize_weights()
20 |
21 | def initialize_weights(self):
22 | nn.init.xavier_uniform_(self.weight)
23 | if self.bias is not None:
24 | nn.init.zeros_(self.bias)
25 |
26 | def forward(self, x, adj):
27 | x = x @ self.weight
28 | if self.bias is not None:
29 | x += self.bias
30 |
31 | return torch.mm(adj, x)
32 |
33 |
34 | class GCN(nn.Module):
35 | def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
36 | super(GCN, self).__init__()
37 |
38 | layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
39 |         self.gcn_layers = nn.ModuleList([
40 |             GCNLayer(in_dim, out_dim, bias)
41 |             for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
42 |         ])
43 | self.dropout = nn.Dropout(p=dropout)
44 |
45 |     def initialize_weights(self):
46 |         for layer in self.gcn_layers:
47 |             layer.initialize_weights()
48 |
49 | def forward(self, x, adj):
50 | for layer in self.gcn_layers[:-1]:
51 | x = torch.relu(layer(x, adj))
52 | x = self.dropout(x)
53 |
54 | x = self.gcn_layers[-1](x, adj)
55 | return x
56 |
57 |
58 | def to_torch(device, x, y, adj, train_mask, val_mask, test_mask):
59 | x = torch.tensor(x, dtype=torch.float32, device=device)
60 | y = torch.tensor(y, dtype=torch.long, device=device)
61 | adj = torch.tensor(adj, dtype=torch.float32, device=device)
62 | train_mask = torch.tensor(train_mask, device=device)
63 | val_mask = torch.tensor(val_mask, device=device)
64 | test_mask = torch.tensor(test_mask, device=device)
65 | return x, y, adj, train_mask, val_mask, test_mask
66 |
67 |
68 | def eval_fn(x, y):
69 |     return torch.mean((torch.argmax(x, dim=1) == y).float())
70 |
71 | def main(args, device):
72 |
73 | # Data loading
74 | download_cora()
75 |
76 | x, y, adj = load_data(args)
77 | train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
78 |
79 | x, y, adj, train_mask, val_mask, test_mask = \
80 | to_torch(device, x, y, adj, train_mask, val_mask, test_mask)
81 |
82 | gcn = GCN(
83 | x_dim=x.shape[-1],
84 | h_dim=args.hidden_dim,
85 | out_dim=args.nb_classes,
86 | nb_layers=args.nb_layers,
87 | dropout=args.dropout,
88 | bias=args.bias,
89 | ).to(device)
90 |
91 |
92 | optimizer = torch.optim.Adam(gcn.parameters(), lr=args.lr,
93 | weight_decay=args.weight_decay)
94 | loss_fn = nn.CrossEntropyLoss()
95 |
96 |
97 | best_val_loss = float("inf")
98 | cnt = 0
99 | times = []
100 |
101 | # Training loop
102 | for epoch in range(args.epochs):
103 | start = time()
104 |
105 | optimizer.zero_grad()
106 | gcn.train()
107 |
108 | y_hat = gcn(x, adj)
109 | loss = loss_fn(y_hat[train_mask], y[train_mask])
110 |
111 | loss.backward()
112 | optimizer.step()
113 |
114 | # Validation
115 | with torch.no_grad():
116 | gcn.eval()
117 | val_loss = loss_fn(y_hat[val_mask], y[val_mask])
118 | val_acc = eval_fn(y_hat[val_mask], y[val_mask])
119 |
120 | times.append(time() - start)
121 |
122 | # Early stopping
123 | if val_loss < best_val_loss:
124 | best_val_loss = val_loss
125 | cnt = 0
126 | else:
127 | cnt += 1
128 | if cnt == args.patience:
129 | break
130 |
131 | print(
132 | " | ".join(
133 | [
134 | f"Epoch: {epoch:3d}",
135 | f"Train loss: {loss.item():.3f}",
136 | f"Val loss: {val_loss.item():.3f}",
137 | f"Val acc: {val_acc.item():.2f}",
138 | ]
139 | )
140 | )
141 |
142 | # Test
143 |     test_y_hat = gcn(x, adj)
144 |     test_loss = loss_fn(test_y_hat[test_mask], y[test_mask])
145 |     test_acc = eval_fn(test_y_hat[test_mask], y[test_mask])
146 | mean_time = sum(times) / len(times)
147 |
148 | print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}")
149 | print(f"Mean time: {mean_time:.5f}")
150 | return mean_time
151 |
152 |
153 | if __name__ == "__main__":
154 |
155 | parser = ArgumentParser()
156 | parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
157 | parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
158 | parser.add_argument("--hidden_dim", type=int, default=16)
159 | parser.add_argument("--dropout", type=float, default=0.5)
160 | parser.add_argument("--nb_layers", type=int, default=2)
161 | parser.add_argument("--nb_classes", type=int, default=7)
162 |     parser.add_argument("--bias", action=BooleanOptionalAction, default=True)  # --bias / --no-bias
163 | parser.add_argument("--lr", type=float, default=0.01)
164 | parser.add_argument("--weight_decay", type=float, default=5e-3)
165 | parser.add_argument("--patience", type=int, default=20)
166 | parser.add_argument("--epochs", type=int, default=100)
167 | args = parser.parse_args()
168 |
169 | if torch.backends.mps.is_available():
170 | device = torch.device("mps")
171 | print("Using MPS.")
172 | else:
173 | device = torch.device("cpu")
174 |         print("MPS device not found; using CPU.")
175 |
176 | main(args, device)
177 |
--------------------------------------------------------------------------------