├── .gitignore ├── README.md ├── figures ├── all_path_cluster.png └── ws-flex_path_cluster.png ├── nx_ops ├── __init__.py ├── create.py ├── sample.py └── ws_flex.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | save 2 | *.onnx 3 | .*/ 4 | 5 | # Created by https://www.gitignore.io/api/vim,macosPeepOpenython,visualstudiocode 6 | # Edit at https://www.gitignore.io/?templates=vim,macosPeepOpenython,visualstudiocode 7 | 8 | ### Compressed ### 9 | *.7z 10 | *.deb 11 | *.gz 12 | *.pkg 13 | *.rar 14 | *.rpm 15 | *.sit 16 | *.sitx 17 | *.tar 18 | *.zip 19 | *.zipx 20 | *.tgz 21 | 22 | ### macOS ### 23 | # General 24 | .DS_Store 25 | .AppleDouble 26 | .LSOverride 27 | 28 | # Icon must end with two \r 29 | Icon 30 | 31 | # Thumbnails 32 | ._* 33 | 34 | # Files that might appear in the root of a volume 35 | .DocumentRevisions-V100 36 | .fseventsd 37 | .Spotlight-V100 38 | .TemporaryItems 39 | .Trashes 40 | .VolumeIcon.icns 41 | .com.apple.timemachine.donotpresent 42 | 43 | # Directories potentially created on remote AFP share 44 | .AppleDB 45 | .AppleDesktop 46 | Network Trash Folder 47 | Temporary Items 48 | .apdisk 49 | 50 | ### Python ### 51 | # Byte-compiled / optimized / DLL files 52 | __pycache__/ 53 | *.py[cod] 54 | *$py.class 55 | 56 | # C extensions 57 | *.so 58 | 59 | # Distribution / packaging 60 | .Python 61 | build/ 62 | develop-eggs/ 63 | dist/ 64 | downloads/ 65 | eggs/ 66 | .eggs/ 67 | lib/ 68 | lib64/ 69 | parts/ 70 | sdist/ 71 | var/ 72 | wheels/ 73 | pip-wheel-metadata/ 74 | share/python-wheels/ 75 | *.egg-info/ 76 | .installed.cfg 77 | *.egg 78 | MANIFEST 79 | 80 | # PyInstaller 81 | # Usually these files are written by a python script from a template 82 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
83 | *.manifest 84 | *.spec 85 | 86 | # Installer logs 87 | pip-log.txt 88 | pip-delete-this-directory.txt 89 | 90 | # Unit test / coverage reports 91 | htmlcov/ 92 | .tox/ 93 | .nox/ 94 | .coverage 95 | .coverage.* 96 | .cache 97 | nosetests.xml 98 | coverage.xml 99 | *.cover 100 | .hypothesis/ 101 | .pytest_cache/ 102 | 103 | # Translations 104 | *.mo 105 | *.pot 106 | 107 | # Scrapy stuff: 108 | .scrapy 109 | 110 | # Sphinx documentation 111 | docs/_build/ 112 | 113 | # PyBuilder 114 | target/ 115 | 116 | # pyenv 117 | .python-version 118 | 119 | # pipenv 120 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 121 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 122 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 123 | # install all needed dependencies. 124 | #Pipfile.lock 125 | 126 | # celery beat schedule file 127 | celerybeat-schedule 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # Mr Developer 140 | .mr.developer.cfg 141 | .project 142 | .pydevproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | ### Vim ### 156 | # Swap 157 | [._]*.s[a-v][a-z] 158 | [._]*.sw[a-p] 159 | [._]s[a-rt-v][a-z] 160 | [._]ss[a-gi-z] 161 | [._]sw[a-p] 162 | 163 | # Session 164 | Session.vim 165 | Sessionx.vim 166 | 167 | # Temporary 168 | .netrwhist 169 | *~ 170 | 171 | # Auto-generated tag files 172 | tags 173 | 174 | # Persistent undo 175 | [._]*.un~ 176 | 177 | # Coc configuration directory 178 | .vim 179 | 180 | ### VisualStudioCode ### 181 | .vscode/* 182 | !.vscode/settings.json 183 | !.vscode/tasks.json 184 | !.vscode/launch.json 185 | !.vscode/extensions.json 186 | 
187 | ### VisualStudioCode Patch ### 188 | # Ignore all local history of files 189 | .history 190 | 191 | # End of https://www.gitignore.io/api/vim,macosPeepOpenython,visualstudiocode 192 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph-Structure-of-Neural-Networks 2 | An unofficial re-implementation of Graph Structure of Neural Networks (Jiaxuan You · Kaiming He · Jure Leskovec · Saining Xie) ICML 2020 3 | https://arxiv.org/abs/2007.06559 4 | 5 | # TODO 6 | - [ ] Graph 7 | - [x] Graph Generator 8 | - [x] WS-flex 9 | - [x] sample 10 | - [x] analyse the **clustering coefficient** and **average path length** of graphs generated by this generator 11 | - [x] Sample Graphs 12 | - [ ] remove isomorphic graphs via graph hashing (not done in the paper, but I think it should be) 13 | - [ ] Graph to Neural Network Converter 14 | - [ ] Mask out Linear 15 | - [ ] Mask out Conv2d 16 | - [ ] Mask out separable Conv2d 17 | - [ ] Evaluate 18 | - [ ] Train/Eval on CIFAR-10 19 | - [ ] Train/Eval on ImageNet 20 | - [ ] Analyze 21 | 22 | ## Graph 23 | ### Graph Generator 24 | All | WS-flex 25 | :---:|:---: 26 | ![](./figures/all_path_cluster.png) | ![](./figures/ws-flex_path_cluster.png) 27 | 1. Generate Graphs 28 | ```bash 29 | python -m nx_ops.create create_all 30 | ``` 31 | 2. Compute the clustering coefficient and average path length of every sampled graph with `python -m nx_ops.create calculate_avg_cluster_path`.
32 | 33 | Use pandas and draw graphs by ipython, like: 34 | ```python 35 | import pandas 36 | from matplotlib import pyplot as plt 37 | df_ws = pd.read_csv("save/csv/ws-paper.csv.gz") 38 | df_ws.dropna().sample(4000).plot.scatter("cluster_coefficient", "avg_path_length",alpha=0.25, color="yellow", ax=ax, label="WS") 39 | ``` 40 | -------------------------------------------------------------------------------- /figures/all_path_cluster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CuriousCat-7/Graph-Structure-of-Neural-Networks/342bc4be2de730278e99c0a43cdf3708b0cbd025/figures/all_path_cluster.png -------------------------------------------------------------------------------- /figures/ws-flex_path_cluster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CuriousCat-7/Graph-Structure-of-Neural-Networks/342bc4be2de730278e99c0a43cdf3708b0cbd025/figures/ws-flex_path_cluster.png -------------------------------------------------------------------------------- /nx_ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CuriousCat-7/Graph-Structure-of-Neural-Networks/342bc4be2de730278e99c0a43cdf3708b0cbd025/nx_ops/__init__.py -------------------------------------------------------------------------------- /nx_ops/create.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .sample import * 3 | from loguru import logger 4 | from glob import glob 5 | import json 6 | 7 | 8 | def _dump(f, path, overwrite): 9 | if os.path.exists(path) and not overwrite: 10 | logger.warning("{} already exists", path) 11 | return 12 | logger.info(path) 13 | f().to_csv(path, compression="gzip") 14 | 15 | 16 | def create_all(save_root="save/csv/", overwrite=False): 17 | os.makedirs(save_root, exist_ok=True) 18 | 
_dump(sample_harary, os.path.join(save_root, "harary-paper.csv.gz"), overwrite) 19 | _dump(sample_ring, os.path.join(save_root, "ring-paper.csv.gz"), overwrite) 20 | _dump(sample_er, os.path.join(save_root, "er-paper.csv.gz"), overwrite) 21 | _dump(sample_ba, os.path.join(save_root, "ba-paper.csv.gz"), overwrite) 22 | _dump(sample_ws, os.path.join(save_root, "ws-paper.csv.gz"), overwrite) 23 | _dump(sample_ws_flex, os.path.join(save_root, "ws_flex-paper.csv.gz"), overwrite) 24 | 25 | 26 | def calculate_avg_cluster_path(save_root="save/csv/"): 27 | csv_paths = glob(f"{save_root}/*.csv.gz") 28 | for csv_path in csv_paths: 29 | df = pd.read_csv(csv_path) 30 | ccs = [] 31 | apls = [] 32 | logger.info(csv_path) 33 | for i in tqdm(range(len(df))): 34 | g = nx.node_link_graph(eval(df.iloc[i].graph)) 35 | d = get_avg_cluater_path(g) 36 | ccs.append(d["cluster_coefficient"]) 37 | apls.append(d["avg_path_length"]) 38 | 39 | df["cluster_coefficient"] = ccs 40 | df["avg_path_length"] = apls 41 | df.to_csv(csv_path, compression="gzip") 42 | 43 | 44 | if __name__ == "__main__": 45 | fire.Fire() 46 | -------------------------------------------------------------------------------- /nx_ops/sample.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from networkx.generators.harary_graph import hnm_harary_graph 3 | from .ws_flex import watts_strogatz_flexible_graph 4 | import numpy as np 5 | import pandas as pd 6 | from tqdm import tqdm 7 | import fire 8 | from loguru import logger 9 | 10 | 11 | def get_avg_cluater_path(g: nx.Graph) -> dict: 12 | try: 13 | cc = nx.average_clustering(g) 14 | except nx.NetworkXError as e: 15 | logger.warning(e) 16 | cc = np.nan 17 | try: 18 | apl = nx.average_shortest_path_length(g) 19 | except nx.NetworkXError as e: 20 | logger.warning(e) 21 | apl = np.nan 22 | return dict( 23 | cluster_coefficient=cc, 24 | avg_path_length=apl, 25 | ) 26 | 27 | 28 | def sample_ws(n=64, k_max=62, k_min=8, 
p_num=300, seed_num=30) -> pd.DataFrame: 29 | rows = [] 30 | assert k_min < k_max <= n 31 | pbar = tqdm(total=int((k_max-k_min)*p_num*seed_num)) 32 | for k in np.arange(k_min, k_max): 33 | for p in np.linspace(0, 1, p_num)**2: 34 | for seed in range(seed_num): 35 | g = nx.generators.watts_strogatz_graph(n, k, p, seed=seed) 36 | rows.append(dict( 37 | method="ws", n=n, k=k, p=p, 38 | graph=nx.node_link_data(g), 39 | )) 40 | pbar.update() 41 | return pd.DataFrame(rows) 42 | 43 | 44 | def sample_ba(n=64, m_max=30, m_min=4, seed_num=300): 45 | rows = [] 46 | pbar = tqdm(total=int(m_max-m_min)*seed_num) 47 | for m in np.arange(m_min, m_max): 48 | for seed in range(seed_num): 49 | g = nx.generators.barabasi_albert_graph(n, m, seed=seed) 50 | rows.append(dict( 51 | method="ba", n=n, m=m, 52 | graph=nx.node_link_data(g), 53 | )) 54 | pbar.update() 55 | return pd.DataFrame(rows) 56 | 57 | 58 | def sample_ws_flex(n=64, k_max=62, k_min=8, p_num=300, seed_num=30) -> pd.DataFrame: 59 | rows = [] 60 | assert k_min < k_max <= n 61 | pbar = tqdm(total=int((k_max-k_min)*p_num*seed_num)) 62 | for k in np.arange(k_min, k_max): 63 | for p in np.linspace(0, 1, p_num)**2: 64 | for seed in range(seed_num): 65 | g = watts_strogatz_flexible_graph(n, k, p, seed=seed) 66 | rows.append(dict( 67 | method="ws-flex", n=n, k=k, p=p, 68 | graph=nx.node_link_data(g), 69 | )) 70 | pbar.update() 71 | return pd.DataFrame(rows) 72 | 73 | 74 | def sample_er(n=64, m_max=int(64*63/2), m_min=int(64*4), seed_num=30): 75 | rows = [] 76 | pbar = tqdm(total=int(m_max-m_min)*seed_num) 77 | e = n*(n-1) 78 | for m in np.arange(m_min, m_max): 79 | for seed in range(seed_num): 80 | g = nx.generators.erdos_renyi_graph(n, m/e) 81 | rows.append(dict( 82 | method="er", n=n, m=m, 83 | graph=nx.node_link_data(g), 84 | )) 85 | pbar.update() 86 | return pd.DataFrame(rows) 87 | 88 | 89 | def sample_ring(n=64, k_max=62, k_min=8): 90 | rows = [] 91 | pbar = tqdm(total=k_max-k_min) 92 | for k in np.arange(k_min, k_max): 93 | g 
= nx.generators.watts_strogatz_graph(n, k, 0.0) 94 | rows.append(dict( 95 | method="ring", n=n, k=k, 96 | graph=nx.node_link_data(g), 97 | )) 98 | pbar.update() 99 | return pd.DataFrame(rows) 100 | 101 | 102 | def sample_harary(n=64, m_max=int(64*63/2), m_min=int(64*4)): 103 | rows = [] 104 | pbar = tqdm(total=m_max-m_min) 105 | for m in np.arange(m_min, m_max): 106 | g = hnm_harary_graph(n, m) 107 | rows.append(dict( 108 | method="harary", n=n, m=m, 109 | graph=nx.node_link_data(g), 110 | )) 111 | pbar.update() 112 | return pd.DataFrame(rows) 113 | 114 | 115 | if __name__ == "__main__": 116 | fire.Fire() 117 | -------------------------------------------------------------------------------- /nx_ops/ws_flex.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from networkx.utils import py_random_state 3 | 4 | 5 | @py_random_state(3) 6 | def watts_strogatz_flexible_graph(n, k, p, seed=None): 7 | """Returns a Watts–Strogatz flexible small-world graph defined by 8 | https://arxiv.org/abs/2007.06559 9 | 10 | Parameters 11 | ---------- 12 | n : int 13 | The number of nodes 14 | k : float or int 15 | Each node is joined with its `k` nearest neighbors in a ring 16 | topology. 17 | p : float 18 | The probability of rewiring each edge 19 | seed : integer, random_state, or None (default) 20 | Indicator of random number generation state. 21 | See :ref:`Randomness`. 22 | 23 | See Also 24 | -------- 25 | watts_strogatz_graph() 26 | 27 | References 28 | --------- 29 | .. [1] You, Jiaxuan et al. “Graph Structure of Neural Networks.” (2020). 
30 | """ 31 | if k > n: 32 | raise nx.NetworkXError("k>n, choose smaller k or larger n") 33 | 34 | # If k == n, the graph is complete not Watts-Strogatz 35 | if k == n: 36 | return nx.complete_graph(n) 37 | 38 | G = nx.Graph() 39 | nodes = list(range(n)) # nodes are labeled 0 to n-1 40 | # connect each node to k/2 neighbors 41 | for j in range(1, int(k//2) + 1): 42 | targets = nodes[j:] + nodes[0:j] # first j nodes are now last in list 43 | G.add_edges_from(zip(nodes, targets)) 44 | 45 | # picks e mod n nodes and connects each node to 46 | # one closest neighboring node 47 | potential_nodes = list(range(n)) 48 | seed.shuffle(potential_nodes) 49 | total_add_edges = int(n*k//2) % n 50 | add_edge_num = 0 51 | for j in potential_nodes: 52 | if add_edge_num == total_add_edges: 53 | break 54 | i = (j + int(k//2) + 1) % n 55 | if G.has_edge(j, i): 56 | continue # when n = 4, k = 3, additional edges may crash 57 | else: 58 | G.add_edge(j, i) 59 | add_edge_num += 1 60 | 61 | # rewire edges from each node 62 | # loop over all nodes in order (label) and neighbors in order (distance) 63 | # no self loops or multiple edges allowed 64 | for j in range(1, int(k//2) + 1): # outer loop is neighbors 65 | targets = nodes[j:] + nodes[0:j] # first j nodes are now last in list 66 | # inner loop in node order 67 | for u, v in zip(nodes, targets): 68 | if seed.random() < p: 69 | w = seed.choice(nodes) 70 | # Enforce no self-loops or multiple edges 71 | while w == u or G.has_edge(u, w): 72 | w = seed.choice(nodes) 73 | if G.degree(u) >= n - 1: 74 | break # skip this rewiring 75 | else: 76 | G.remove_edge(u, v) 77 | G.add_edge(u, w) 78 | return G 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | networkx==2.4 2 | pandas 3 | fire 4 | loguru 5 | numpy 6 | --------------------------------------------------------------------------------