├── .gitignore
├── README.md
├── archs
└── examples
│ └── sbms_cluster.yaml
├── data
├── COLLAB.py
├── QM9.py
├── SBMs.py
├── SBMs
│ ├── generate_SBM_CLUSTER.ipynb
│ └── generate_SBM_PATTERN.ipynb
├── TSP.py
├── TSP
│ ├── generate_TSP.py
│ └── prepare_TSP.ipynb
├── __init__.py
├── cora.py
├── molecules.py
├── molecules
│ └── prepare_molecules.ipynb
├── superpixels.py
└── superpixels
│ ├── prepare_superpixels_CIFAR.ipynb
│ └── prepare_superpixels_MNIST.ipynb
├── environment_gpu.yml
├── example_geno.yaml
├── models
├── architect.py
├── cell_search.py
├── cell_train.py
├── mixed.py
├── model_search.py
├── model_train.py
├── networks.py
└── operations.py
├── scripts
├── search_molecules_zinc.sh
├── search_sbms_cluster.sh
├── search_sbms_pattern.sh
├── search_superpixels_cifar10.sh
├── search_superpixels_mnist.sh
├── train_molecules_zinc.sh
├── train_sbms_cluster.sh
├── train_sbms_pattern.sh
├── train_superpixels_cifar10.sh
└── train_superpixels_mnist.sh
├── search.py
├── train.py
└── utils
├── record_utils.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | /dataset
132 | .vscode
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
137 | # pytype static type analyzer
138 | .pytype/
139 |
140 | # Cython debug symbols
141 | cython_debug/
142 |
143 | debug.ipynb
144 | *.png
145 | *.txt
146 | *.gz
147 | *.pt
148 | *.pkl
149 | *.json
150 | *.pickle
151 | *.pkl
152 | *.index
153 | *.zip
154 | *.pdf
155 | pics
156 | runs
157 | archs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## 🔥News🔥
3 | Hi, if you like this work, you may be interested in our new work, too.
4 |
5 | Welcome to our latest work [Automatic Relation-aware Graph Network Proliferation](https://github.com/phython96/ARGNP).
6 |
7 | **This work has been accepted by CVPR2022 and selected as an ORAL presentation.**
8 |
9 | In the latest work, we have achieved state-of-the-art results on SBM_CLUSTER (77.35% OA), ZINC_100k (0.136 MAE), CIFAR10 (73.90% OA), TSP (0.855 F1-score) datasets and so on.
10 |
11 | ### What's new?
12 |
13 | 1. **We devise a novel dual relation-aware graph search space that comprises both node and relation learning operations.**
14 | So, the ARGNP can leverage the edge attributes in some datasets, such as ZINC.
15 | It significantly improves the graph representation capability.
16 | Interestingly, we also observe a performance improvement on datasets that provide no edge attributes.
17 |
18 | 2. **We design a network proliferation search paradigm (NPSP) to progressively determine the GNN architectures by iteratively performing network division and differentiation.**
19 | The network proliferation search paradigm decomposes the training of the global supernet into a sequence of local supernet optimizations, which alleviates the interference among child graph neural architectures. It reduces the space and time complexity from quadratic to linear and frees the search entirely from the cell-sharing trick.
20 |
21 | 3. **Our framework is suitable for solving node-level, edge-level, and graph-level tasks. The code is easy to use.**
22 |
23 | ---
24 |
25 | # Rethinking Graph Neural Architecture Search from Message-passing
26 |
27 |
28 |
29 |
30 |
31 | ## Getting Started
32 |
33 | ### 0. Prerequisites
34 |
35 | + Linux
36 | + NVIDIA GPU + CUDA CuDNN
37 |
38 | ### 1. Setup Python Environment
39 |
40 | ```sh
41 | # clone Github repo
42 | conda install git
43 | git clone https://github.com/phython96/GNAS-MP.git
44 | cd GNAS-MP
45 |
46 | # Install python environment
47 | conda env create -f environment_gpu.yml
48 | conda activate gnasmp
49 | ```
50 |
51 | ### 2. Download datasets
52 |
53 | The datasets are provided by the [benchmarking-gnns](https://github.com/graphdeeplearning/benchmarking-gnns) project; click [here](https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/docs/02_download_datasets.md) to download all the required datasets.
54 |
55 | ### 3. Search Architectures
56 |
57 | ```sh
58 | sh scripts/search_molecules_zinc.sh [gpu_id]
59 | ```
60 |
61 | ### 4. Train & Test
62 |
63 | ```
64 | sh scripts/train_molecules_zinc.sh [gpu_id] '[path_to_genotypes]/example.yaml'
65 | ```
66 |
67 | ## Reference
68 | ```latex
69 | @inproceedings{cai2021rethinking,
70 | title={Rethinking Graph Neural Architecture Search from Message-passing},
71 | author={Cai, Shaofei and Li, Liang and Deng, Jincan and Zhang, Beichen and Zha, Zheng-Jun and Su, Li and Huang, Qingming},
72 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
73 | pages={6657--6666},
74 | year={2021}
75 | }
76 | ```
77 |
--------------------------------------------------------------------------------
/archs/examples/sbms_cluster.yaml:
--------------------------------------------------------------------------------
1 | Genotype:
2 | - id: 0
3 | topology:
4 | - dst: 1
5 | ops: V_Dense
6 | src: 0
7 | - dst: 2
8 | ops: V_Sparse
9 | src: 1
10 | - dst: 3
11 | ops: V_Sum
12 | src: 1
13 | - dst: 4
14 | ops: V_Sum
15 | src: 2
16 | - dst: 5
17 | ops: V_Dense
18 | src: 3
19 | - dst: 6
20 | ops: V_Sparse
21 | src: 3
22 | - id: 1
23 | topology:
24 | - dst: 1
25 | ops: V_Sparse
26 | src: 0
27 | - dst: 2
28 | ops: V_I
29 | src: 0
30 | - dst: 3
31 | ops: V_Sum
32 | src: 1
33 | - dst: 4
34 | ops: V_Sum
35 | src: 2
36 | - dst: 5
37 | ops: V_Sparse
38 | src: 3
39 | - dst: 6
40 | ops: V_Dense
41 | src: 4
42 | - id: 2
43 | topology:
44 | - dst: 1
45 | ops: V_Sparse
46 | src: 0
47 | - dst: 2
48 | ops: V_I
49 | src: 0
50 | - dst: 3
51 | ops: V_Sum
52 | src: 1
53 | - dst: 4
54 | ops: V_Sum
55 | src: 2
56 | - dst: 5
57 | ops: V_Dense
58 | src: 4
59 | - dst: 6
60 | ops: V_Sparse
61 | src: 3
62 | - id: 3
63 | topology:
64 | - dst: 1
65 | ops: V_Dense
66 | src: 0
67 | - dst: 2
68 | ops: V_Dense
69 | src: 0
70 | - dst: 3
71 | ops: V_Sum
72 | src: 1
73 | - dst: 4
74 | ops: V_Sum
75 | src: 2
76 | - dst: 5
77 | ops: V_Sparse
78 | src: 3
79 | - dst: 6
80 | ops: V_Dense
81 | src: 4
82 |
--------------------------------------------------------------------------------
/data/COLLAB.py:
--------------------------------------------------------------------------------
1 | import time
2 | import dgl
3 | import torch
4 | from torch.utils.data import Dataset
5 |
6 | from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
7 |
8 | from scipy import sparse as sp
9 | import numpy as np
10 |
11 |
def positional_encoding(g, pos_enc_dim):
    """
    Attach Laplacian-eigenvector positional encodings to a graph.

    Builds L = I - N * A * N, where A is the graph's scipy adjacency matrix
    and N is diagonal with (in-degree clipped below at 1) ** -0.5, requests
    ``pos_enc_dim + 1`` eigenpairs from scipy's sparse solver (``which='SR'``,
    loose tolerance), orders them by eigenvalue, and stores the real part of
    columns 1..pos_enc_dim in ``g.ndata['pos_enc']`` as a float tensor.

    Returns the same graph object ``g``.
    """
    num_nodes = g.number_of_nodes()

    # Degree normalisation; clip(1) keeps zero in-degree nodes from dividing by zero.
    in_degrees = dgl.backend.asnumpy(g.in_degrees())
    inv_sqrt_deg = in_degrees.clip(1) ** -0.5

    adjacency = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
    norm = sp.diags(inv_sqrt_deg, dtype=float)
    laplacian = sp.eye(num_nodes) - norm * adjacency * norm

    # Sparse eigensolver; tol=1e-2 trades accuracy for speed.
    eig_vals, eig_vecs = sp.linalg.eigs(
        laplacian, k=pos_enc_dim + 1, which='SR', tol=1e-2
    )
    ascending = eig_vals.argsort()
    eig_vecs = eig_vecs[:, ascending]

    # Skip column 0 (smallest eigenvalue) and keep the next pos_enc_dim columns.
    encoding = np.real(eig_vecs[:, 1:pos_enc_dim + 1])
    g.ndata['pos_enc'] = torch.from_numpy(encoding).float()

    return g
35 |
36 |
class COLLABDataset(Dataset):
    """
    Wrapper around the OGB ``ogbl-collab`` link-prediction dataset.

    Exposes the single DGL graph, the positive/negative edge splits and the
    OGB evaluator as attributes.
    """

    def __init__(self, name):
        """Load ``ogbl-collab``; ``name`` is only stored and printed."""
        t_begin = time.time()
        print("[I] Loading dataset %s..." % (name))
        self.name = name
        self.dataset = DglLinkPropPredDataset(name='ogbl-collab')

        # The dataset holds a single DGL graph.
        self.graph = self.dataset[0]

        # Edge feature = [weight, year] concatenated along dim 1.
        edge_weight = self.graph.edata['weight']
        edge_year = self.graph.edata['year']
        self.graph.edata['feat'] = torch.cat([edge_weight, edge_year], dim=1)

        self.split_edge = self.dataset.get_edge_split()
        train_split = self.split_edge['train']
        valid_split = self.split_edge['valid']
        test_split = self.split_edge['test']

        self.train_edges = train_split['edge']         # positive train edges
        self.val_edges = valid_split['edge']           # positive val edges
        self.val_edges_neg = valid_split['edge_neg']   # negative val edges
        self.test_edges = test_split['edge']           # positive test edges
        self.test_edges_neg = test_split['edge_neg']   # negative test edges

        self.evaluator = Evaluator(name='ogbl-collab')

        print("[I] Finished loading.")
        print("[I] Data load time: {:.4f}s".format(time.time()-t_begin))

    def _add_positional_encodings(self, pos_enc_dim):
        """Replace ``self.graph`` with the result of ``positional_encoding``."""
        encoded = positional_encoding(self.graph, pos_enc_dim)
        self.graph = encoded
--------------------------------------------------------------------------------
/data/QM9.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import dgl.data
4 |
class QM9DGL(torch.utils.data.Dataset):
    """Thin ``Dataset`` view over an in-memory list of samples."""

    def __init__(self, graph_list):
        """Keep a reference to ``graph_list`` (no copy is made)."""
        self.graph_list = graph_list

    def __len__(self):
        """Number of stored samples."""
        samples = self.graph_list
        return len(samples)

    def __getitem__(self, i):
        """Return the ``i``-th stored sample unchanged."""
        samples = self.graph_list
        return samples[i]
15 |
class QM9Dataset(torch.utils.data.Dataset):
    """
    QM9 dataset loaded through ``dgl.data.QM9EdgeDataset`` for one target.

    The samples are split by position into ``train`` (first 110000),
    ``val`` (next 10000) and ``test`` (the remainder).
    """

    def __init__(self, name, target):
        """
        Load QM9 for the regression target ``target``.

        ``name`` is only stored and printed.
        """
        t_begin = time.time()
        print("[I] Loading dataset %s..." % (name))
        self.name = name
        self.target = target
        self.data = dgl.data.QM9EdgeDataset([target])

        samples = []
        total = len(self.data)
        for index in range(total):
            graph, label = self.data[index]
            # Expose node/edge attributes under the generic 'feat' key.
            graph.ndata['feat'] = graph.ndata['attr']
            graph.edata['feat'] = graph.edata['edge_attr']
            samples.append((graph, label))

        self.train = QM9DGL(samples[:110000])
        self.val = QM9DGL(samples[110000:120000])
        self.test = QM9DGL(samples[120000:])

        print('train, test, val sizes :',len(self.train),len(self.test),len(self.val))
        print("[I] Finished loading.")
        print("[I] Data load time: {:.4f}s".format(time.time()-t_begin))

    def collate(self, samples):
        """
        Merge a list of ``(graph, label)`` pairs into one batch.

        Returns ``(batched_graph, labels)`` where the graphs are combined
        with ``dgl.batch`` and the labels concatenated along dim 0.
        """
        graphs, labels = (list(column) for column in zip(*samples))
        stacked_labels = torch.cat(labels, dim = 0)
        return dgl.batch(graphs), stacked_labels
50 |
51 |
if __name__ == '__main__':
    # Manual smoke test: load QM9 for target 'mu', then drop into the debugger.
    dataset = QM9Dataset(name='QM9', target='mu')
    import ipdb
    ipdb.set_trace()
--------------------------------------------------------------------------------
/data/SBMs.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import pickle
4 | import numpy as np
5 |
6 | import dgl
7 | import torch
8 |
9 | from scipy import sparse as sp
10 | import numpy as np
11 |
12 |
13 |
class load_SBMsDataSetDGL(torch.utils.data.Dataset):
    """
    One split of an SBM dataset, read from ``<data_dir>/<name>_<split>.pkl``
    and converted into DGL graphs with per-node labels.
    """

    def __init__(self,
                 data_dir,
                 name,
                 split):
        """Unpickle the requested split and build its DGL graphs."""
        self.split = split
        self.is_test = split.lower() in ['test', 'val']

        pickle_path = os.path.join(data_dir, name + '_%s.pkl' % self.split)
        with open(pickle_path, 'rb') as handle:
            self.dataset = pickle.load(handle)

        self.node_labels = []
        self.graph_lists = []
        self.n_samples = len(self.dataset)
        self._prepare()

    def _prepare(self):
        """Turn every raw record into a DGL graph and collect its node labels."""
        print("preparing %d graphs for the %s set..." % (self.n_samples, self.split.upper()))

        for record in self.dataset:
            feats = record.node_feat
            # Non-zero entries of the adjacency matrix W are the edges.
            edge_pairs = (record.W != 0).nonzero()

            graph = dgl.DGLGraph()
            graph.add_nodes(feats.size(0))
            graph.ndata['feat'] = feats.long()
            for u, v in edge_pairs:
                graph.add_edges(u.item(), v.item())

            # Constant 1-dimensional edge feature for models that require one.
            edge_feat_dim = 1
            num_edges = graph.number_of_edges()
            graph.edata['feat'] = torch.ones(num_edges, edge_feat_dim)

            self.graph_lists.append(graph)
            self.node_labels.append(record.node_label)

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """
        Return the ``idx``-th sample as ``(graph, node_label)``.

        The graph stores its node features in the ``feat`` field; the second
        element is the ``node_label`` entry of the same record.
        """
        graph = self.graph_lists[idx]
        labels = self.node_labels[idx]
        return graph, labels
74 |
75 |
class SBMsDatasetDGL(torch.utils.data.Dataset):
    """Container holding the train / test / val splits of an SBM dataset."""

    def __init__(self, name):
        """
        Load the three splits of dataset ``name`` from ``data/SBMs``.

        The splits are built in the order train, test, val and stored on the
        attributes of the same names.
        """
        t_begin = time.time()
        print("[I] Loading data ...")
        self.name = name
        data_dir = 'data/SBMs'
        for split in ('train', 'test', 'val'):
            setattr(self, split, load_SBMsDataSetDGL(data_dir, name, split=split))
        print("[I] Finished loading.")
        print("[I] Data load time: {:.4f}s".format(time.time()-t_begin))
91 |
92 |
93 |
94 |
def self_loop(g):
    """
    Return a new graph equal to ``g`` with exactly one self-loop per node.

    Used instead of DGL's own add-self-loop transform so that
    ``ndata['feat']`` is carried over. Existing self-loops are dropped
    first, all other edges are copied in edge-id order, then one (i, i) edge
    is appended for every node. ``edata['feat']`` of the result is an
    all-zero 1-D tensor with one entry per edge.
    """
    num_nodes = g.number_of_nodes()

    looped = dgl.DGLGraph()
    looped.add_nodes(num_nodes)
    looped.ndata['feat'] = g.ndata['feat']

    sources, targets = g.all_edges(order="eid")
    sources = dgl.backend.zerocopy_to_numpy(sources)
    targets = dgl.backend.zerocopy_to_numpy(targets)

    # Copy every edge that is not already a self-loop.
    keep = sources != targets
    looped.add_edges(sources[keep], targets[keep])

    # Then add exactly one self-loop per node.
    node_ids = np.arange(num_nodes)
    looped.add_edges(node_ids, node_ids)

    # Placeholder edge features so the generic ndata/edata contract holds.
    looped.edata['feat'] = torch.zeros(looped.number_of_edges())
    return looped
119 |
120 |
121 |
def positional_encoding(g, pos_enc_dim):
    """
    Attach Laplacian-eigenvector positional encodings to a graph.

    Builds L = I - N * A * N, where A is the graph's scipy adjacency matrix
    and N is diagonal with (in-degree clipped below at 1) ** -0.5, requests
    ``pos_enc_dim + 1`` eigenpairs from scipy's sparse solver (``which='SR'``,
    loose tolerance), orders them by eigenvalue, and stores the real part of
    columns 1..pos_enc_dim in ``g.ndata['pos_enc']`` as a float tensor.

    Returns the same graph object ``g``.
    """
    num_nodes = g.number_of_nodes()

    # Degree normalisation; clip(1) keeps zero in-degree nodes from dividing by zero.
    in_degrees = dgl.backend.asnumpy(g.in_degrees())
    inv_sqrt_deg = in_degrees.clip(1) ** -0.5

    adjacency = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
    norm = sp.diags(inv_sqrt_deg, dtype=float)
    laplacian = sp.eye(num_nodes) - norm * adjacency * norm

    # Sparse eigensolver; tol=1e-2 trades accuracy for speed.
    eig_vals, eig_vecs = sp.linalg.eigs(
        laplacian, k=pos_enc_dim + 1, which='SR', tol=1e-2
    )
    ascending = eig_vals.argsort()
    eig_vecs = eig_vecs[:, ascending]

    # Skip column 0 (smallest eigenvalue) and keep the next pos_enc_dim columns.
    encoding = np.real(eig_vecs[:, 1:pos_enc_dim + 1])
    g.ndata['pos_enc'] = torch.from_numpy(encoding).float()

    return g
145 |
146 |
147 |
class SBMsDataset(torch.utils.data.Dataset):
    """
    SBM dataset loaded from a single pre-built pickle ``data/SBMs/<name>.pkl``.

    The pickle is indexed as ``[train, val, test]``. The class also provides
    the collate functions used to form mini-batches.
    """

    # Number of node types per known dataset, used by collate_dense_gnn.
    _NUM_NODE_TYPES = {'SBM_CLUSTER': 7, 'SBM_PATTERN': 3}

    def __init__(self, name):
        """
        Load the SBM dataset called ``name``.

        NOTE(review): ``pickle.load`` executes arbitrary code from the file;
        only load pickles from a trusted source.
        """
        start = time.time()
        print("[I] Loading dataset %s..." % (name))
        self.name = name
        data_dir = 'data/SBMs/'
        with open(data_dir+name+'.pkl',"rb") as f:
            # Keep the payload under its own name instead of rebinding the
            # open file handle.
            splits = pickle.load(f)
        self.train = splits[0]
        self.val = splits[1]
        self.test = splits[2]
        print('train, test, val sizes :',len(self.train),len(self.test),len(self.val))
        print("[I] Finished loading.")
        print("[I] Data load time: {:.4f}s".format(time.time()-start))

    def collate(self, samples):
        """
        Form a mini-batch from a list of ``(graph, label)`` pairs.

        Returns ``(batched_graph, labels)``: the graphs merged with
        ``dgl.batch`` and all label tensors concatenated and cast to long.
        """
        graphs, labels = map(list, zip(*samples))
        labels = torch.cat(labels).long()
        batched_graph = dgl.batch(graphs)

        return batched_graph, labels

    def collate_dense_gnn(self, samples):
        """
        Build a dense input tensor for dense GNNs (e.g. RingGNN, 3WLGNN).

        Only the first graph in ``samples`` is converted, while the labels of
        all samples are concatenated.

        Adapted from https://github.com/leichen2018/Ring-GNN/ . The result is
        a tensor T of shape (1, 1 + num_node_type, n, n): T[0, 0] is the
        symmetrically normalised adjacency matrix, and T[0, 1 + c, i, i] is 1
        when node i has type c.

        Returns ``(x_node_feat, labels)``.

        Raises:
            ValueError: if ``self.name`` is not a known SBM dataset and
                ``self.num_node_type`` has not been set beforehand.
        """
        graphs, labels = map(list, zip(*samples))
        labels = torch.cat(labels).long()

        g = graphs[0]
        adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense())

        zero_adj = torch.zeros_like(adj)

        if self.name in self._NUM_NODE_TYPES:
            self.num_node_type = self._NUM_NODE_TYPES[self.name]
        elif not hasattr(self, 'num_node_type'):
            # Fail with a clear message instead of an AttributeError below.
            raise ValueError(
                "Unknown SBM dataset name %r; expected one of %s"
                % (self.name, sorted(self._NUM_NODE_TYPES))
            )

        # One zero-initialised channel per node type, preceded by the adjacency.
        adj_node_feat = torch.stack([zero_adj for _ in range(self.num_node_type)])
        adj_node_feat = torch.cat([adj.unsqueeze(0), adj_node_feat], dim=0)

        # One-hot encode each node's type on the diagonal of its channel.
        for node, node_label in enumerate(g.ndata['feat']):
            adj_node_feat[node_label.item()+1][node][node] = 1

        x_node_feat = adj_node_feat.unsqueeze(0)

        return x_node_feat, labels

    def _sym_normalize_adj(self, adj):
        """
        Return D^{-1/2} A D^{-1/2} for a dense square adjacency matrix.

        D holds the column sums of ``adj``; entries whose degree is not
        positive get a scaling factor of 0 rather than inf.
        """
        deg = torch.sum(adj, dim = 0)
        inv_sqrt = 1./torch.sqrt(deg)
        # zeros_like keeps the zero branch on the same dtype and device as
        # the other branch of torch.where.
        deg_inv = torch.where(deg>0, inv_sqrt, torch.zeros_like(inv_sqrt))
        deg_inv = torch.diag(deg_inv)
        return torch.mm(deg_inv, torch.mm(adj, deg_inv))

    def _add_self_loops(self):
        """
        Rebuild every graph of every split with self-loops via ``self_loop``.

        Intended to be called only when the user's self_loop flag is set.
        """
        self.train.graph_lists = [self_loop(g) for g in self.train.graph_lists]
        self.val.graph_lists = [self_loop(g) for g in self.val.graph_lists]
        self.test.graph_lists = [self_loop(g) for g in self.test.graph_lists]

    def _add_positional_encodings(self, pos_enc_dim):
        """Add Laplacian-eigenvector positional encodings to every graph of every split."""
        self.train.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.train.graph_lists]
        self.val.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.val.graph_lists]
        self.test.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.test.graph_lists]
--------------------------------------------------------------------------------
/data/TSP.py:
--------------------------------------------------------------------------------
1 | import time
2 | import pickle
3 | import numpy as np
4 | import itertools
5 | from scipy.spatial.distance import pdist, squareform
6 |
7 | import dgl
8 | import torch
9 | from torch.utils.data import Dataset
10 |
11 | from scipy import sparse as sp
12 |
class TSP(Dataset):
    """
    One split of the TSP edge-classification dataset.

    Reads ``<data_dir>/tsp50-500_<split>.txt`` and turns each line into a
    k-nearest-neighbour DGL graph whose edges are labelled 1 when they
    belong to the tour given on that line.
    """

    def __init__(self, data_dir, split="train", num_neighbors=25, max_samples=10000):
        """
        Parameters
        ----------
        data_dir : str
            Directory containing the ``tsp50-500_<split>.txt`` files.
        split : str
            Split name used in the file name.
        num_neighbors : int
            Number of nearest neighbours each node is connected to.
        max_samples : int
            Maximum number of lines (graphs) read from the file.
        """
        self.data_dir = data_dir
        self.split = split
        self.filename = f'{data_dir}/tsp50-500_{split}.txt'
        self.max_samples = max_samples
        self.num_neighbors = num_neighbors
        self.is_test = split.lower() in ['test', 'val']

        self.graph_lists = []
        self.edge_labels = []
        self._prepare()
        self.n_samples = len(self.edge_labels)

    @staticmethod
    def _tour_adjacency(tour_nodes, num_nodes):
        """Return the symmetric 0/1 adjacency matrix of the closed tour ``tour_nodes``."""
        edges_target = np.zeros((num_nodes, num_nodes))
        if len(tour_nodes) < 2:
            # Nothing to connect; avoids relying on an undefined loop variable.
            return edges_target
        # Pair each node with its successor; the last node wraps to the first,
        # which adds the closing edge of the cycle.
        for i, j in zip(tour_nodes, tour_nodes[1:] + tour_nodes[:1]):
            edges_target[i][j] = 1
            edges_target[j][i] = 1
        return edges_target

    def _prepare(self):
        """Parse the split file and populate ``graph_lists`` and ``edge_labels``."""
        print('preparing all graphs for the %s set...' % self.split.upper())

        # Context manager so the file handle is always closed.
        with open(self.filename, "r") as f:
            file_data = f.readlines()[:self.max_samples]

        for line in file_data:
            line = line.split(" ")  # Split into list
            output_pos = line.index('output')
            # Everything before 'output' is a flat list of (x, y) coordinates.
            num_nodes = output_pos // 2

            nodes_coord = [
                [float(line[idx]), float(line[idx + 1])]
                for idx in range(0, 2 * num_nodes, 2)
            ]

            # Pairwise euclidean distance matrix
            W_val = squareform(pdist(nodes_coord, metric='euclidean'))
            # k-nearest neighbours of each node (the slice still contains the node itself)
            knns = np.argpartition(W_val, kth=self.num_neighbors, axis=-1)[:, self.num_neighbors::-1]

            # Tour nodes are 1-based in the file. The last token of the line
            # is skipped, and so is the final tour entry, so the closing
            # connection is not duplicated here.
            tour_nodes = [int(node) - 1 for node in line[output_pos + 1:-1]][:-1]

            edges_target = self._tour_adjacency(tour_nodes, num_nodes)

            # Construct the DGL graph
            g = dgl.DGLGraph()
            g.add_nodes(num_nodes)
            g.ndata['feat'] = torch.Tensor(nodes_coord)

            edge_feats = []   # euclidean distance of each edge
            edge_labels = []  # 1 if the edge is part of the tour, else 0
            # Important: labels must be in the same order as the edges of g,
            # so edge, feature and label are appended together.
            for idx in range(num_nodes):
                for n_idx in knns[idx]:
                    if n_idx != idx:  # No self-connection
                        g.add_edge(idx, n_idx)
                        edge_feats.append(W_val[idx][n_idx])
                        edge_labels.append(int(edges_target[idx][n_idx]))

            # Sanity check
            assert len(edge_feats) == g.number_of_edges() == len(edge_labels)

            # Add edge features
            g.edata['feat'] = torch.Tensor(edge_feats).unsqueeze(-1)

            self.graph_lists.append(g)
            self.edge_labels.append(edge_labels)

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """
        Get the idx^th sample.

        Parameters
        ---------
        idx : int
            The sample index.

        Returns
        -------
        (dgl.DGLGraph, list)
            DGLGraph with node feature stored in `feat` field
            And a list of labels for each edge in the DGLGraph.
        """
        return self.graph_lists[idx], self.edge_labels[idx]
109 |
110 |
class TSPDatasetDGL(Dataset):
    """Container building the train / val / test TSP splits from ``./data/TSP``."""

    def __init__(self, name):
        """Build the splits in the order train, val, test; ``name`` is only stored."""
        self.name = name
        shared = dict(data_dir='./data/TSP', num_neighbors=25)
        self.train = TSP(split='train', max_samples=10000, **shared)
        self.val = TSP(split='val', max_samples=1000, **shared)
        self.test = TSP(split='test', max_samples=1000, **shared)
117 |
118 |
def positional_encoding(g, pos_enc_dim):
    """
    Attach Laplacian-eigenvector positional encodings to a graph.

    Builds L = I - N * A * N, where A is the graph's scipy adjacency matrix
    and N is diagonal with (in-degree clipped below at 1) ** -0.5, runs a
    dense numpy eigendecomposition, orders the eigenvectors by eigenvalue,
    and stores the real part of columns 1..pos_enc_dim in
    ``g.ndata['pos_enc']`` as a float tensor.

    Returns the same graph object ``g``.
    """
    num_nodes = g.number_of_nodes()

    # Degree normalisation; clip(1) keeps zero in-degree nodes from dividing by zero.
    in_degrees = dgl.backend.asnumpy(g.in_degrees())
    inv_sqrt_deg = in_degrees.clip(1) ** -0.5

    adjacency = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
    norm = sp.diags(inv_sqrt_deg, dtype=float)
    laplacian = sp.eye(num_nodes) - norm * adjacency * norm

    # Dense eigendecomposition with numpy.
    eig_vals, eig_vecs = np.linalg.eig(laplacian.toarray())
    ascending = eig_vals.argsort()
    eig_vals = eig_vals[ascending]
    eig_vecs = np.real(eig_vecs[:, ascending])

    # Skip column 0 (smallest eigenvalue) and keep the next pos_enc_dim columns.
    encoding = eig_vecs[:, 1:pos_enc_dim + 1]
    g.ndata['pos_enc'] = torch.from_numpy(encoding).float()

    return g
141 |
142 |
class TSPDataset(Dataset):
    """
    TSP edge-classification dataset loaded from ``data/TSP/<name>.pkl``.

    Also provides the collate functions used to form mini-batches.
    """

    def __init__(self, name):
        """
        Load the TSP dataset called ``name``.

        NOTE(review): ``pickle.load`` executes arbitrary code from the file;
        only load pickles from a trusted source.
        """
        start = time.time()
        print("[I] Loading dataset %s..." % (name))
        self.name = name
        data_dir = 'data/TSP/'
        with open(data_dir+name+'.pkl',"rb") as f:
            # Keep the payload under its own name instead of rebinding the
            # open file handle.
            splits = pickle.load(f)
        # NOTE(review): index 1 is read as test and index 2 as val here,
        # whereas SBMsDataset in data/SBMs.py reads index 1 as val and
        # index 2 as test. Confirm against the code that writes the pickle.
        self.train = splits[0]
        self.test = splits[1]
        self.val = splits[2]
        print('train, test, val sizes :',len(self.train),len(self.test),len(self.val))
        print("[I] Finished loading.")
        print("[I] Data load time: {:.4f}s".format(time.time()-start))

    def collate(self, samples):
        """
        Form a mini-batch from a list of ``(graph, edge_labels)`` pairs.

        Returns ``(batched_graph, labels)``: the graphs merged with
        ``dgl.batch`` and the per-graph edge labels flattened into one 1-D
        LongTensor.
        """
        graphs, labels = map(list, zip(*samples))
        # Edge classification labels need to be flattened to 1D lists
        labels = torch.LongTensor(np.array(list(itertools.chain(*labels))))
        batched_graph = dgl.batch(graphs)

        return batched_graph, labels

    def collate_dense_gnn(self, samples, edge_feat):
        """
        Build a dense input tensor for dense GNNs (e.g. RingGNN, 3WLGNN).

        Only the first graph in ``samples`` is converted, while the labels of
        all samples are flattened together.

        Adapted from https://github.com/leichen2018/Ring-GNN/ . With the
        adjacency matrix in R^{n x n}, node features in R^{d_n} and edge
        features in R^{d_e}, a zero-initialised tensor T in
        R^{(1 + d_n [+ d_e]) x n x n} is built: T[0] is the normalised
        adjacency matrix, the diagonal T[1:1+d_n, i, i] holds the features
        of node i and, when ``edge_feat`` is true, T[1+d_n:, i, j] holds the
        features of edge (i, j).

        Returns
        -------
        tuple
            ``(None, x_with_edge_feat, labels, g.edges())`` when ``edge_feat``
            is true, otherwise ``(x_no_edge_feat, None, labels, g.edges())``.
        """
        graphs, labels = map(list, zip(*samples))
        # Edge classification labels need to be flattened to 1D lists
        labels = torch.LongTensor(np.array(list(itertools.chain(*labels))))

        g = graphs[0]
        adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense())

        zero_adj = torch.zeros_like(adj)

        in_node_dim = g.ndata['feat'].shape[1]
        in_edge_dim = g.edata['feat'].shape[1]

        if edge_feat:
            # use edge feats also to prepare adj
            adj_with_edge_feat = torch.stack([zero_adj for _ in range(in_node_dim + in_edge_dim)])
            adj_with_edge_feat = torch.cat([adj.unsqueeze(0), adj_with_edge_feat], dim=0)

            us, vs = g.edges()
            # Loop variable named `feat` so it does not shadow the
            # `edge_feat` flag parameter.
            for idx, feat in enumerate(g.edata['feat']):
                adj_with_edge_feat[1+in_node_dim:, us[idx], vs[idx]] = feat

            for node, node_feat in enumerate(g.ndata['feat']):
                adj_with_edge_feat[1:1+in_node_dim, node, node] = node_feat

            x_with_edge_feat = adj_with_edge_feat.unsqueeze(0)

            return None, x_with_edge_feat, labels, g.edges()

        # use only node feats to prepare adj
        adj_no_edge_feat = torch.stack([zero_adj for _ in range(in_node_dim)])
        adj_no_edge_feat = torch.cat([adj.unsqueeze(0), adj_no_edge_feat], dim=0)

        for node, node_feat in enumerate(g.ndata['feat']):
            adj_no_edge_feat[1:1+in_node_dim, node, node] = node_feat

        x_no_edge_feat = adj_no_edge_feat.unsqueeze(0)

        return x_no_edge_feat, None, labels, g.edges()

    def _sym_normalize_adj(self, adj):
        """
        Return D^{-1/2} A D^{-1/2} for a dense square adjacency matrix.

        D holds the column sums of ``adj``; entries whose degree is not
        positive get a scaling factor of 0 rather than inf.
        """
        deg = torch.sum(adj, dim = 0)
        inv_sqrt = 1./torch.sqrt(deg)
        # zeros_like keeps the zero branch on the same dtype and device as
        # the other branch of torch.where.
        deg_inv = torch.where(deg>0, inv_sqrt, torch.zeros_like(inv_sqrt))
        deg_inv = torch.diag(deg_inv)
        return torch.mm(deg_inv, torch.mm(adj, deg_inv))

    def _add_self_loops(self):
        """
        No self-loop support since TSP edge classification dataset.
        """
        raise NotImplementedError

    def _add_positional_encodings(self, pos_enc_dim):
        """Add Laplacian-eigenvector positional encodings to every graph of every split."""
        self.train.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.train.graph_lists]
        self.val.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.val.graph_lists]
        self.test.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.test.graph_lists]
250 |
if __name__ == '__main__':
    # Manual smoke test: load the dataset, then drop into an interactive
    # debugger so the loaded object can be inspected.
    tsp = TSPDataset('TSP')
    import ipdb
    ipdb.set_trace()
--------------------------------------------------------------------------------
/data/TSP/generate_TSP.py:
--------------------------------------------------------------------------------
1 | import time
2 | import argparse
3 | import pprint as pp
4 | import os
5 |
6 | import numpy as np
7 | from concorde.tsp import TSPSolver # Install from https://github.com/jvkersch/pyconcorde
8 |
9 |
if __name__ == "__main__":
    # Command-line options: node-count range, number of instances,
    # output file, coordinate dimension and RNG seed.
    parser = argparse.ArgumentParser()
    for flag, flag_type, flag_default in (
        ("--min_nodes", int, 50),
        ("--max_nodes", int, 500),
        ("--num_samples", int, 10000),
        ("--filename", str, None),
        ("--node_dim", int, 2),
        ("--seed", int, 1234),
    ):
        parser.add_argument(flag, type=flag_type, default=flag_default)
    opts = parser.parse_args()

    # Derive the output name from the node range when none was given
    if opts.filename is None:
        opts.filename = f"tsp{opts.min_nodes}-{opts.max_nodes}.txt"

    # Pretty print the run args
    pp.pprint(vars(opts))

    np.random.seed(opts.seed)

    with open(opts.filename, "w") as f:
        start_time = time.time()
        idx = 0
        while idx < opts.num_samples:
            num_nodes = np.random.randint(low=opts.min_nodes, high=opts.max_nodes+1)

            nodes_coord = np.random.random([num_nodes, opts.node_dim])
            solver = TSPSolver.from_data(nodes_coord[:, 0], nodes_coord[:, 1], norm="GEO")
            solution = solver.solve()

            # Skip (and re-sample) instances whose tour does not visit every node exactly once
            if not (np.sort(solution.tour) == np.arange(num_nodes)).all():
                continue

            # One instance per line: "x0 y0 x1 y1 ... output t0 t1 ... t0 "
            # with 1-based node indices and the start node repeated at the end.
            coords_txt = " ".join(str(x) + " " + str(y) for x, y in nodes_coord)
            tour_txt = " ".join(str(node_idx + 1) for node_idx in solution.tour)
            closing_txt = " " + str(solution.tour[0] + 1) + " "
            f.write(coords_txt + " output " + tour_txt + closing_txt + "\n")
            idx += 1

        # Elapsed wall-clock time in seconds (despite the name)
        end_time = time.time() - start_time

    print(f"Completed generation of {opts.num_samples} samples of TSP{opts.min_nodes}-{opts.max_nodes}.")
    print(f"Total time: {end_time/60:.1f}m")
    print(f"Average time: {end_time/opts.num_samples:.1f}s")
52 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import dgl
2 | import torch
3 | import torch.nn as nn
4 | from data.molecules import MoleculeDataset
5 | from data.QM9 import QM9Dataset
6 | from data.SBMs import SBMsDataset
7 | from data.TSP import TSPDataset
8 | from data.superpixels import SuperPixDataset
9 | from data.cora import CoraDataset
10 | from models.networks import *
11 | from utils.utils import *
12 |
13 |
class TransInput(nn.Module):
    """
        Applies an optional transform to the 'V' entry of the input dict.
    """

    def __init__(self, trans_fn):
        """
            trans_fn : callable applied to input['V'], or a falsy value to
                       leave the input untouched.
        """
        super().__init__()
        self.trans = trans_fn

    def forward(self, input):
        """
            Overwrite input['V'] in place with self.trans(input['V']) when a
            transform is set, then return the same dict.
        """
        if not self.trans:
            return input
        input['V'] = self.trans(input['V'])
        return input
24 |
25 |
class TransOutput(nn.Module):
    """
        Task-specific output head: an MLP applied to node features
        ('node_level'), to concatenated source/destination node features of
        every edge ('link_level'), or to the per-graph mean of node features
        ('graph_level').
    """

    def __init__(self, args):
        """
            args must provide task, node_dim, nb_mlp_layer and nb_classes.
            Raises Exception('Unknown task!') for any other task value.
        """
        super().__init__()
        self.args = args
        if args.task == 'link_level':
            # edge inputs are [src ; dst] node features, hence twice the width
            in_dim = args.node_dim * 2
        elif args.task in ('node_level', 'graph_level'):
            in_dim = args.node_dim
        else:
            raise Exception('Unknown task!')
        self.trans = MLP((in_dim, ) * args.nb_mlp_layer + (args.nb_classes, ))

    def forward(self, input):
        """
            input is a dict with the graph under 'G' and node features under 'V'.
            Returns the MLP output for the configured task.
        """
        G, V = input['G'], input['V']
        task = self.args.task

        if task == 'node_level':
            return self.trans(V)

        if task == 'link_level':
            G.ndata['V'] = V
            # per-edge feature: source and destination node features side by side
            G.apply_edges(
                lambda edges: {'e': torch.cat([edges.src['V'], edges.dst['V']], dim=1)}
            )
            return self.trans(G.edata['e'])

        if task == 'graph_level':
            G.ndata['V'] = V
            return self.trans(dgl.mean_nodes(G, 'V'))

        raise Exception('Unknown task!')
62 |
63 |
def get_trans_input(args):
    """
        Build the input layer mapping raw node features to args.node_dim:
        an nn.Embedding for ZINC and the SBM datasets, an nn.Linear for
        TSP, CIFAR10, MNIST, Cora and QM9.

        Raises Exception('Unknown dataset!') for any other args.data.
    """
    embedding_datasets = ('ZINC', 'SBM_CLUSTER', 'SBM_PATTERN')
    linear_datasets = ('TSP', 'CIFAR10', 'MNIST', 'Cora', 'QM9')

    if args.data in embedding_datasets:
        return nn.Embedding(args.in_dim_V, args.node_dim)
    if args.data in linear_datasets:
        return nn.Linear(args.in_dim_V, args.node_dim)
    raise Exception('Unknown dataset!')
78 |
79 |
def get_loss_fn(args):
    """
        Return a freshly constructed criterion for the dataset named by
        args.data. Only the SBM criterion takes an argument (args.nb_classes).

        Raises Exception('Unknown dataset!') for any other args.data.
    """
    if args.data in ('ZINC', 'QM9'):
        return MoleculesCriterion()
    if args.data in ('TSP', ):
        return TSPCriterion()
    if args.data in ('SBM_CLUSTER', 'SBM_PATTERN'):
        return SBMsCriterion(args.nb_classes)
    if args.data in ('CIFAR10', 'MNIST'):
        return SuperPixCriterion()
    if args.data in ('Cora', ):
        return CiteCriterion()
    raise Exception('Unknown dataset!')
94 |
95 |
def load_data(args):
    """
        Construct the dataset object selected by args.data.
        QM9 additionally receives args.extra; every other dataset is built
        from its name alone.

        Raises Exception('Unknown dataset!') for any other args.data.
    """
    # (accepted names, lazy constructor) pairs; the lambdas defer both the
    # class lookup and the construction until a name actually matches.
    builders = (
        (('ZINC', ), lambda: MoleculeDataset(args.data)),
        (('QM9', ), lambda: QM9Dataset(args.data, args.extra)),
        (('TSP', ), lambda: TSPDataset(args.data)),
        (('MNIST', 'CIFAR10'), lambda: SuperPixDataset(args.data)),
        (('SBM_CLUSTER', 'SBM_PATTERN'), lambda: SBMsDataset(args.data)),
        (('Cora', ), lambda: CoraDataset(args.data)),
    )
    for names, build in builders:
        if args.data in names:
            return build()
    raise Exception('Unknown dataset!')
111 |
112 |
def load_metric(args):
    """
        Return the evaluation metric (the callable itself, not its result)
        associated with the dataset named by args.data.

        Raises Exception('Unknown dataset!') for any other args.data.
    """
    # (accepted names, lazy lookup) pairs; the lambdas defer resolving the
    # metric name until a dataset actually matches.
    metric_table = (
        (('ZINC', 'QM9'), lambda: MAE),
        (('TSP', ), lambda: binary_f1_score),
        (('MNIST', 'CIFAR10'), lambda: accuracy_MNIST_CIFAR),
        (('SBM_CLUSTER', 'SBM_PATTERN'), lambda: accuracy_SBM),
        (('Cora', ), lambda: CoraAccuracy),
    )
    for names, pick in metric_table:
        if args.data in names:
            return pick()
    raise Exception('Unknown dataset!')
126 |
--------------------------------------------------------------------------------
/data/cora.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import dgl.data
4 |
class CoraDGL(torch.utils.data.Dataset):
    """
        Single-sample dataset holding the first graph of
        dgl.data.CoraGraphDataset together with its node labels.
    """

    def __init__(self):
        self.graph = dgl.data.CoraGraphDataset()[0]
        # every edge gets the same constant feature, shape [num_edges, 1]
        self.graph.edata['feat'] = torch.ones([self.graph.num_edges(), 1]).float()

    def __getitem__(self, i):
        """
            Return (graph, node_labels) for index 0.

            Raises
            ------
            IndexError
                For any other index. IndexError (rather than an assert, which
                is stripped under `python -O`) also lets sequence-style
                iteration over the dataset terminate normally.
        """
        if i != 0:
            raise IndexError("CoraDGL holds a single graph; index must be 0, got %r" % (i,))
        return self.graph, self.graph.ndata['label']

    def __len__(self):
        """Return the number of samples, which is always 1."""
        return 1
17 |
class CoraDataset(torch.utils.data.Dataset):
    """
        Wrapper exposing one CoraDGL instance as the train, val and test
        split (all three attributes reference the same object).
    """

    def __init__(self, name):
        """
            Loading Cora Dataset
        """
        start = time.time()
        print("[I] Loading dataset %s..." % (name))
        self.name = name
        # one shared single-graph dataset for all three splits
        self.train = self.val = self.test = CoraDGL()

        print('train, test, val sizes :', len(self.train), len(self.test), len(self.val))
        print("[I] Finished loading.")
        print("[I] Data load time: {:.4f}s".format(time.time()-start))

    def collate(self, samples):
        """Return the first sample of the batch unchanged."""
        return samples[0]
39 |
if __name__ == '__main__':
    # Manual smoke test: load the dataset and print its only training sample.
    dataset = CoraDataset('Cora')
    first_sample = dataset.train[0]
    print(first_sample)
--------------------------------------------------------------------------------
/data/molecules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pickle
3 | import torch.utils.data
4 | import time
5 | import os
6 | import numpy as np
7 |
8 | import csv
9 |
10 | import dgl
11 |
12 | from scipy import sparse as sp
13 | import numpy as np
14 |
15 | # *NOTE
16 | # The dataset pickle and index files are in ./zinc_molecules/ dir
17 | # [.pickle and .index; for split 'train', 'val' and 'test']
18 |
19 |
class MoleculeDGL(torch.utils.data.Dataset):
    """
        One split of the molecule data converted to DGL graphs.

        The pickle at <data_dir>/<split>.pickle is a list of molecule dicts:

        molecule = data[idx]
        ; molecule['num_atom'] : nb of atoms, an integer (N)
        ; molecule['atom_type'] : tensor of size N, each element is an atom type, an integer between 0 and num_atom_type
        ; molecule['bond_type'] : tensor of size N x N, each element is a bond type, an integer between 0 and num_bond_type
        ; molecule['logP_SA_cycle_normalized'] : the chemical property to regress, a float variable
    """

    def __init__(self, data_dir, split, num_graphs=None):
        """
            Parameters
            ----------
            data_dir : str
                Directory containing <split>.pickle (and <split>.index).
            split : str
                Split name used to build the file names.
            num_graphs : int or None
                Expected number of graphs. For 10000 or 1000 the graphs are
                sub-sampled with the indices in the first row of
                <split>.index. None means "use every graph in the pickle".

            Raises
            ------
            AssertionError
                If the number of loaded graphs differs from num_graphs.
        """
        self.data_dir = data_dir
        self.split = split
        self.num_graphs = num_graphs

        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at trusted dataset files.
        with open(data_dir + "/%s.pickle" % self.split, "rb") as f:
            self.data = pickle.load(f)

        if self.num_graphs in [10000, 1000]:
            # loading the sampled indices from file ./zinc_molecules/<split>.index
            with open(data_dir + "/%s.index" % self.split, "r") as f:
                data_idx = [list(map(int, idx)) for idx in csv.reader(f)]
                self.data = [self.data[i] for i in data_idx[0]]

        if self.num_graphs is None:
            # Previously the default None always failed the length check
            # below; treat it as "take whatever the pickle contains".
            self.num_graphs = len(self.data)

        assert len(self.data) == self.num_graphs, "Sample num_graphs again; available idx: train/val/test => 10k/1k/1k"

        self.graph_lists = []
        self.graph_labels = []
        self.n_samples = len(self.data)
        self._prepare()

    def _prepare(self):
        """Convert every molecule dict into a DGL graph and collect its label."""
        print("preparing %d graphs for the %s set..." % (self.num_graphs, self.split.upper()))

        for molecule in self.data:
            node_features = molecule['atom_type'].long()

            adj = molecule['bond_type']
            edge_list = (adj != 0).nonzero()  # converting adj matrix to edge_list

            # bond type of every listed edge, in the same order as edge_list
            edge_idxs_in_adj = edge_list.split(1, dim=1)
            edge_features = adj[edge_idxs_in_adj].reshape(-1).long()

            # Create the DGL Graph
            g = dgl.DGLGraph()
            g.add_nodes(molecule['num_atom'])
            g.ndata['feat'] = node_features

            # edges are added in edge_list order so they line up with edge_features
            for src, dst in edge_list:
                g.add_edges(src.item(), dst.item())
            g.edata['feat'] = edge_features

            self.graph_lists.append(g)
            self.graph_labels.append(molecule['logP_SA_cycle_normalized'])

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """
            Get the idx^th sample.
            Parameters
            ---------
            idx : int
                The sample index.
            Returns
            -------
            (dgl.DGLGraph, label)
                DGLGraph with node feature stored in `feat` field
                And its label.
        """
        return self.graph_lists[idx], self.graph_labels[idx]
94 |
95 |
class MoleculeDatasetDGL(torch.utils.data.Dataset):
    """
        Builds the train/val/test MoleculeDGL splits from the raw pickles.
    """

    def __init__(self, name='Zinc'):
        """
            name : 'ZINC-full' selects the full dataset under
                   ./data/molecules/zinc_full; any other name selects the
                   10k/1k/1k subset under ./data/molecules.
        """
        t0 = time.time()
        self.name = name

        self.num_atom_type = 28  # known meta-info about the zinc dataset; can be calculated as well
        self.num_bond_type = 4  # known meta-info about the zinc dataset; can be calculated as well

        if self.name == 'ZINC-full':
            data_dir = './data/molecules/zinc_full'
            n_train, n_val, n_test = 220011, 24445, 5000
        else:
            data_dir = './data/molecules'
            n_train, n_val, n_test = 10000, 1000, 1000

        self.train = MoleculeDGL(data_dir, 'train', num_graphs=n_train)
        self.val = MoleculeDGL(data_dir, 'val', num_graphs=n_val)
        self.test = MoleculeDGL(data_dir, 'test', num_graphs=n_test)
        print("Time taken: {:.4f}s".format(time.time()-t0))
115 |
116 |
117 |
def self_loop(g):
    """
        Return a new graph with the nodes and node features of g, all of g's
        non-self edges, and exactly one self-loop per node.

        Utility function only, to be used only when necessary as per user self_loop flag
        : replaces dgl.transform.add_self_loop() so that ndata['feat'] and
          edata['feat'] are both present on the result.

        This function is called inside a function in MoleculeDataset class.
    """
    num_nodes = g.number_of_nodes()

    looped = dgl.DGLGraph()
    looped.add_nodes(num_nodes)
    looped.ndata['feat'] = g.ndata['feat']

    src, dst = (
        dgl.backend.zerocopy_to_numpy(endpoints)
        for endpoints in g.all_edges(order="eid")
    )
    # drop existing self-loops first so every node ends up with exactly one
    keep = src != dst
    looped.add_edges(src[keep], dst[keep])
    loop_ids = np.arange(num_nodes)
    looped.add_edges(loop_ids, loop_ids)

    # This new edata is not used since this function gets called only for GCN, GAT
    # However, we need this for the generic requirement of ndata and edata
    looped.edata['feat'] = torch.zeros(looped.number_of_edges())
    return looped
142 |
143 |
144 |
def positional_encoding(g, pos_enc_dim):
    """
        Graph positional encoding via Laplacian eigenvectors.

        Builds L = I - N A N, where A is the adjacency matrix of g and N is
        diagonal with (in-degree clipped to at least 1) ** -0.5, and stores
        the real parts of eigenvector columns 1..pos_enc_dim (eigenvalues in
        increasing order, column 0 skipped) in g.ndata['pos_enc'] as float32.

        Returns the same graph g.
    """
    # Laplacian
    adjacency = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
    in_degrees = dgl.backend.asnumpy(g.in_degrees())
    inv_sqrt_deg = sp.diags(in_degrees.clip(1) ** -0.5, dtype=float)
    laplacian = sp.eye(g.number_of_nodes()) - inv_sqrt_deg * adjacency * inv_sqrt_deg

    # Dense eigen-decomposition with numpy
    # (a sparse scipy solver asking for only the smallest pos_enc_dim+1
    # eigenpairs would be an alternative).
    eigvals, eigvecs = np.linalg.eig(laplacian.toarray())
    ascending = eigvals.argsort()  # increasing order
    eigvecs = np.real(eigvecs[:, ascending])
    g.ndata['pos_enc'] = torch.from_numpy(eigvecs[:, 1:pos_enc_dim+1]).float()

    return g
167 |
168 |
169 |
class MoleculeDataset(torch.utils.data.Dataset):
    """
        Pre-processed molecule dataset loaded from data/molecules/<name>.pkl,
        with collate functions for sparse (DGL batch) and dense-tensor GNNs.
    """

    def __init__(self, name):
        """
            Loading Molecules datasets

            The pickle is expected to hold, in order:
            train split, val split, test split, num_atom_type, num_bond_type.
        """
        start = time.time()
        print("[I] Loading dataset %s..." % (name))
        self.name = name
        data_dir = 'data/molecules/'
        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at trusted dataset files.
        with open(data_dir+name+'.pkl', "rb") as f:
            payload = pickle.load(f)
        self.train = payload[0]
        self.val = payload[1]
        self.test = payload[2]
        self.num_atom_type = payload[3]
        self.num_bond_type = payload[4]
        print('train, test, val sizes :', len(self.train), len(self.test), len(self.val))
        print("[I] Finished loading.")
        print("[I] Data load time: {:.4f}s".format(time.time()-start))

    # form a mini batch from a given list of samples = [(graph, label) pairs]
    def collate(self, samples):
        """
            Batch a list of (graph, label) pairs.

            Returns (batched_graph, labels) where labels has one row per
            sample (shape [len(samples), 1]).
        """
        graphs, labels = map(list, zip(*samples))
        labels = torch.tensor(np.array(labels)).unsqueeze(1)
        batched_graph = dgl.batch(graphs)

        return batched_graph, labels

    # prepare dense tensors for GNNs using them; such as RingGNN, 3WLGNN
    def collate_dense_gnn(self, samples, edge_feat):
        """
            Build the dense input tensor for the FIRST graph in samples
            (the remaining graphs are ignored; labels cover all samples).

            Adapted from https://github.com/leichen2018/Ring-GNN/
            Assigning node and edge feats::
            we have the adjacency matrix in R^{n x n}, the node features in R^{d_n} and edge features R^{d_e}.
            Then we build a zero-initialized tensor, say T, in R^{(1 + d_n + d_e) x n x n}. T[0, :, :] is the
            (symmetrically normalized) adjacency matrix.
            The diagonal T[1:1+d_n, i, i], i = 0 to n-1, store the one-hot node type of node i.
            The off diagonal T[1+d_n:, i, j] store the one-hot edge type of edge(i, j).
            Without edge features the d_e channels are omitted.

            Returns
            -------
            (None, x, labels) if edge_feat else (x, None, labels),
            where x is T with a leading batch dimension of size 1.
        """
        graphs, labels = map(list, zip(*samples))
        labels = torch.tensor(np.array(labels)).unsqueeze(1)

        g = graphs[0]
        adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense())

        n_feat_channels = self.num_atom_type
        if edge_feat:
            n_feat_channels += self.num_bond_type

        # channel 0: normalized adjacency; remaining channels start at zero
        # and share adj's dtype and device.
        dense = torch.cat(
            [adj.unsqueeze(0), adj.new_zeros((n_feat_channels, ) + tuple(adj.shape))],
            dim=0,
        )

        if edge_feat:
            us, vs = g.edges()
            for u, v, edge_label in zip(us.tolist(), vs.tolist(), g.edata['feat']):
                dense[edge_label.item() + 1 + self.num_atom_type, u, v] = 1

        for node, node_label in enumerate(g.ndata['feat']):
            dense[node_label.item() + 1, node, node] = 1

        x = dense.unsqueeze(0)
        if edge_feat:
            return None, x, labels
        return x, None, labels

    def _sym_normalize_adj(self, adj):
        """
            Symmetrically normalize a dense 2-D adjacency matrix:
            entry (i, j) becomes adj[i, j] / sqrt(deg[i] * deg[j]), with
            deg the column sums; zero-degree rows/columns stay zero.
        """
        deg = torch.sum(adj, dim=0)
        inv_sqrt = 1. / torch.sqrt(deg)
        # zeros_like keeps the fallback on the same device/dtype as the data;
        # torch.zeros(size) would always be a CPU tensor of the default dtype.
        inv_sqrt = torch.where(deg > 0, inv_sqrt, torch.zeros_like(inv_sqrt))
        # same result as diag(inv_sqrt) @ adj @ diag(inv_sqrt)
        return inv_sqrt.unsqueeze(1) * adj * inv_sqrt.unsqueeze(0)

    def _add_self_loops(self):
        """
            Replace every graph of the train/val/test splits with
            self_loop(g). Called only if the self_loop flag is True.
        """
        self.train.graph_lists = [self_loop(g) for g in self.train.graph_lists]
        self.val.graph_lists = [self_loop(g) for g in self.val.graph_lists]
        self.test.graph_lists = [self_loop(g) for g in self.test.graph_lists]

    def _add_positional_encodings(self, pos_enc_dim):
        """
            Replace every graph of the train/val/test splits with
            positional_encoding(g, pos_enc_dim).
        """
        # Graph positional encoding via Laplacian eigenvectors
        self.train.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.train.graph_lists]
        self.val.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.val.graph_lists]
        self.test.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.test.graph_lists]
280 |
281 |
--------------------------------------------------------------------------------
/data/molecules/prepare_molecules.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Notebook for preparing and saving MOLECULAR graphs"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "import torch\n",
18 | "import pickle\n",
19 | "import time\n",
20 | "import os\n",
21 | "%matplotlib inline\n",
22 | "import matplotlib.pyplot as plt\n"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {},
28 | "source": [
29 | "# Download ZINC dataset"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 2,
35 | "metadata": {},
36 | "outputs": [
37 | {
38 | "output_type": "stream",
39 | "name": "stdout",
40 | "text": [
41 | "File already downloaded\n"
42 | ]
43 | }
44 | ],
45 | "source": [
46 | "if not os.path.isfile('molecules.zip'):\n",
47 | " print('downloading..')\n",
48 | " # !curl https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1 -o molecules.zip -J -L -k\n",
49 | " !unzip molecules.zip -d ../\n",
50 | " # !tar -xvf molecules.zip -C ../\n",
51 | "else:\n",
52 | " print('File already downloaded')\n",
53 | " "
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {},
59 | "source": [
60 | "# Convert to DGL format and save with pickle"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 3,
66 | "metadata": {},
67 | "outputs": [
68 | {
69 | "output_type": "stream",
70 | "name": "stdout",
71 | "text": [
72 | "/data00/caishaofei/workspace/GNAS2\n"
73 | ]
74 | }
75 | ],
76 | "source": [
77 | "import os\n",
78 | "os.chdir('../../') # go to root folder of the project\n",
79 | "print(os.getcwd())\n"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 4,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "import pickle\n",
89 | "\n",
90 | "%load_ext autoreload\n",
91 | "%autoreload 2\n",
92 | "from data.molecules import MoleculeDatasetDGL \n",
93 | "\n",
94 | "#from data import LoadData\n",
95 | "from torch.utils.data import DataLoader\n",
96 | "\n",
97 | "from data.molecules import MoleculeDataset\n"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 7,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "output_type": "stream",
107 | "name": "stdout",
108 | "text": [
109 | "preparing 10000 graphs for the TRAIN set...\n",
110 | "/data00/caishaofei/miniconda3/envs/gnas2/lib/python3.6/site-packages/dgl/base.py:45: DGLWarning: Recommend creating graphs by `dgl.graph(data)` instead of `dgl.DGLGraph(data)`.\n",
111 | " return warnings.warn(message, category=category, stacklevel=1)\n",
112 | "preparing 1000 graphs for the VAL set...\n",
113 | "preparing 1000 graphs for the TEST set...\n",
114 | "Time taken: 378.4136s\n"
115 | ]
116 | }
117 | ],
118 | "source": [
119 | "DATASET_NAME = 'ZINC'\n",
120 | "dataset = MoleculeDatasetDGL(DATASET_NAME) \n"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 8,
126 | "metadata": {},
127 | "outputs": [
128 | {
129 | "output_type": "display_data",
130 | "data": {
131 | "text/plain": "",
132 | "image/svg+xml": "\n\n\n\n",
133 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEICAYAAACzliQjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAUDklEQVR4nO3df7DldX3f8eerq5JEQZZwQ9fdNYvOaouMWeIWyTRaWhNBSLKYZuhuU0VLXZjAjI52UkjbgdrSUiNx4sSuWcIWyCiEBA3bijWrcaSZFuSCG35KWXAJu7Pu3oiIRGcr8O4f53vLcb337r33nD333v08HzNn7ve8v78+H77wul8+3x83VYUkqQ1/a6EbIEkaHUNfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr50GEk+meTfLnQ7pGGI9+nraJdkN/AvquqLC92W6SyFNuro4Jm+mpbkJQvdBmmUDH0d1ZL8IfBq4L8leTbJbyapJBcm+Svgz7vl/jjJN5N8J8kdSd7Qt43rk/yHbvrMJHuSfCjJgST7kry3b9lzkjyU5LtJ9ib5l33zfinJziRPJ/lfSd44XRtH8g9HTTL0dVSrqncBfwX8clW9Arilm/UPgL8LnNV9/zywFvgp4F7gUzNs9m8DrwRWAhcCn0iyvJt3HXBRVR0LnMqLv1ROA7YBFwE/Cfw+sD3JMYe2sao+MnDHpWkY+mrVlVX1N1X1fYCq2lZV362qg8CVwM8keeU06/4A+HBV/aCqbgeeBV7fN++UJMdV1ber6t6uvhn4/aq6q6qer6obgIPAGUeof9KUDH216snJiSTLklyd5LEkzwC7u1knTrPut6rqub7v3wNe0U3/Y+Ac4IkkX0nyc139p4EPdUM7Tyd5GlgNvGo43ZFmx9BXC6a6Ra2/9k+BDcAv0Bu2WdPVM+cdVd1dVRvoDRP9KS8OJz0JXFVVx/d9fqKqbpqhjdLQGfpqwX7gNTPMP5beUMu3gJ8A/uN8dpLkZUl+Pckrq+oHwDPAC93sa4GLk7w5PS9Pcm6SY2fZRmkoDH214D8B/6YbUvm1KebfCDwB7AUeAu4cYF/vAnZ3w0QXA78OUFXjwPuA3wO+DewC3jNVG/vv+JGGzYezJKkhnulLUkMMfUlqiKEvSQ0x9CWpIYv+ZVMnnnhirVmzZqGbIUlLxj333PPXVTU21bxFH/pr1qxhfHx8oZshSUtGkiemm+fwjiQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNWTRP5ErLVZrLvvcvNfdffW5Q2yJNHue6UtSQwx9SWqIoS9JDTH0Jakhhw39JNuSHEjyQF/tj5Ls7D67k+zs6muSfL9v3if71nlTkvuT7Ery8SQ5Ij2SJE1rNnfvXA/8HnDjZKGq/snkdJJrgO/0Lf9YVa2bYjtbgPcBdwG3A2cDn59ziyVJ83bYM/2qugN4aqp53dn6+cBNM20jyQrguKq6s6qK3i+Q8+bcWknSQAYd038LsL+qHu2rnZzka0m+kuQtXW0lsKdvmT1dbUpJNicZTzI+MTExYBMlSZMGDf1N/PBZ/j7g1VV1GvBB4NNJjpvrRqtqa1Wtr6r1Y2NT/plHSdI8zPuJ3CQvAX4VeNNkraoOAge76XuSPAa8DtgLrOpbfVVXkySN0CBn+r8AfL2q/v+wTZKxJMu66dcAa4HHq2of8EySM7rrAO8Gbhtg35KkeZjNLZs3Af8beH2SPUku7GZt5Ecv4L4VuK+7hfNPgIuravIi8G8AfwDsAh7DO3ckaeQOO7xTVZumqb9nitqtwK3TLD8OnDrH9kmShsgnciWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGzOYPo29LciDJA321K5PsTbKz+5zTN+/yJLuS
PJLkrL762V1tV5LLht8VSdLhzOZM/3rg7CnqH6uqdd3ndoAkpwAbgTd06/yXJMuSLAM+AbwDOAXY1C0rSRqhlxxugaq6I8maWW5vA3BzVR0EvpFkF3B6N29XVT0OkOTmbtmH5t5kSdJ8DTKmf2mS+7rhn+VdbSXwZN8ye7radHVJ0gjNN/S3AK8F1gH7gGuG1SCAJJuTjCcZn5iYGOamJalp8wr9qtpfVc9X1QvAtbw4hLMXWN236KquNl19uu1vrar1VbV+bGxsPk2UJE1hXqGfZEXf13cCk3f2bAc2JjkmycnAWuCrwN3A2iQnJ3kZvYu92+ffbEnSfBz2Qm6Sm4AzgROT7AGuAM5Msg4oYDdwEUBVPZjkFnoXaJ8DLqmq57vtXAp8AVgGbKuqB4fdGUnSzGZz986mKcrXzbD8VcBVU9RvB26fU+skSUPlE7mS1BBDX5IaYuhLUkMMfUlqiKEvSQ057N07khaXNZd9bqD1d1997pBaoqXIM31JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDfE1DFpwvlZAGh3P9CWpIYa+JDXE0Jekhhw29JNsS3IgyQN9td9O8vUk9yX5bJLju/qaJN9PsrP7fLJvnTcluT/JriQfT5Ij0iNJ0rRmc6Z/PXD2IbUdwKlV9Ubg/wCX9817rKrWdZ+L++pbgPcBa7vPoduUJB1hhw39qroDeOqQ2p9V1XPd1zuBVTNtI8kK4LiqurOqCrgROG9eLZYkzdswxvT/OfD5vu8nJ/lakq8keUtXWwns6VtmT1ebUpLNScaTjE9MTAyhiZIkGDD0k/xr4DngU11pH/DqqjoN+CDw6STHzXW7VbW1qtZX1fqxsbFBmihJ6jPvh7OSvAf4JeBt3ZANVXUQONhN35PkMeB1wF5+eAhoVVeTJI3QvM70k5wN/CbwK1X1vb76WJJl3fRr6F2wfbyq9gHPJDmju2vn3cBtA7dekjQnhz3TT3ITcCZwYpI9wBX07tY5BtjR3Xl5Z3enzluBDyf5AfACcHFVTV4E/g16dwL9OL1rAP3XASRJI3DY0K+qTVOUr5tm2VuBW6eZNw6cOqfWSZKGyidyJakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIbMKvSTbEtyIMkDfbUTkuxI8mj3c3lXT5KPJ9mV5L4kP9u3zgXd8o8muWD43ZEkzWS2Z/rXA2cfUrsM+FJVrQW+1H0HeAewtvtsBrZA75cEcAXwZuB04IrJXxSSpNGYVehX1R3AU4eUNwA3dNM3AOf11W+snjuB45OsAM4CdlTVU1X1bWAHP/qLRJJ0BA0ypn9SVe3rpr8JnNRNrwSe7FtuT1ebrv4jkmxOMp5kfGJiYoAmSpL6DeVCblUVUMPYVre9rVW1vqrWj42NDWuzktS8QUJ/fzdsQ/fzQFffC6zuW25VV5uuLkkakUFCfzsweQfOBcBtffV3d3fxnAF8pxsG+gLw9iTLuwu4b+9qkqQReclsFkpyE3AmcGKSPfTuwrkauCXJhcATwPnd4rcD5wC7gO8B7wWoqqeS/Hvg7m65D1fVoReHJUlH0KxCv6o2TTPrbVMsW8Al02xnG7Bt1q2TJA2VT+RKUkMMfUlqiKEvSQ0x9CWpIbO6kCtpuNZc9rmFboIa5Zm+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEN+9oyVvkPfY7L763CG2RFr8PNOXpIYY+pLUEENfkhoy79BP8vokO/s+zyT5QJIrk+ztq5/Tt87lSXYleSTJWcPpgiRptuZ9IbeqHgHWASRZBuwFPgu8F/hYVX20f/kkpwAbgTcArwK+mOR1VfX8fNsgSZqbYQ3vvA14rKqemGGZDcDNVXWw
qr4B7AJOH9L+JUmzMKzQ3wjc1Pf90iT3JdmWZHlXWwk82bfMnq72I5JsTjKeZHxiYmJITZQkDRz6SV4G/Arwx11pC/BaekM/+4Br5rrNqtpaVeurav3Y2NigTZQkdYZxpv8O4N6q2g9QVfur6vmqegG4lheHcPYCq/vWW9XVJEkjMozQ30Tf0E6SFX3z3gk80E1vBzYmOSbJycBa4KtD2L8kaZYGeg1DkpcDvwhc1Ff+SJJ1QAG7J+dV1YNJbgEeAp4DLvHOHS20QV7hIC1FA4V+Vf0N8JOH1N41w/JXAVcNsk9J0vz5RK4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JashAr1bW0WWQd8vvvvrcIbZE0pHimb4kNcTQl6SGGPqS1BBDX5IaMnDoJ9md5P4kO5OMd7UTkuxI8mj3c3lXT5KPJ9mV5L4kPzvo/iVJszesM/1/WFXrqmp99/0y4EtVtRb4Uvcd4B3A2u6zGdgypP1LkmbhSA3vbABu6KZvAM7rq99YPXcCxydZcYTaIEk6xDBCv4A/S3JPks1d7aSq2tdNfxM4qZteCTzZt+6ervZDkmxOMp5kfGJiYghNlCTBcB7O+vmq2pvkp4AdSb7eP7OqKknNZYNVtRXYCrB+/fo5rStJmt7AZ/pVtbf7eQD4LHA6sH9y2Kb7eaBbfC+wum/1VV1NkjQCA4V+kpcnOXZyGng78ACwHbigW+wC4LZuejvw7u4unjOA7/QNA0mSjrBBh3dOAj6bZHJbn66q/5HkbuCWJBcCTwDnd8vfDpwD7AK+B7x3wP1rkRjkvT2SRmeg0K+qx4GfmaL+LeBtU9QLuGSQfUqS5s8nciWpIYa+JDXE9+lLjVmov5vg32tYHDzTl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kN8dXKkmbNP4u59HmmL0kNmXfoJ1md5MtJHkryYJL3d/Urk+xNsrP7nNO3zuVJdiV5JMlZw+iAJGn2BhneeQ74UFXdm+RY4J4kO7p5H6uqj/YvnOQUYCPwBuBVwBeTvK6qnh+gDZKkOZj3mX5V7auqe7vp7wIPAytnWGUDcHNVHayqbwC7gNPnu39J0twNZUw/yRrgNOCurnRpkvuSbEuyvKutBJ7sW20P0/ySSLI5yXiS8YmJiWE0UZLEEEI/ySuAW4EPVNUzwBbgtcA6YB9wzVy3WVVbq2p9Va0fGxsbtImSpM5AoZ/kpfQC/1NV9RmAqtpfVc9X1QvAtbw4hLMXWN23+qquJkkakUHu3glwHfBwVf1OX31F32LvBB7oprcDG5Mck+RkYC3w1fnuX5I0d4PcvfP3gXcB9yfZ2dV+C9iUZB1QwG7gIoCqejDJLcBD9O78ucQ7dyRptOYd+lX1F0CmmHX7DOtcBVw1331KkgbjE7mS1BBDX5Ia4gvXjjK+EEvSTDzTl6SGeKYv6ag3yP8B77763CG2ZOEZ+pIWPYcth8fhHUlqiKEvSQ0x9CWpIYa+JDXEC7lHgBedJC1WnulLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQ79OXpBkcbW/o9Exfkhoy8jP9JGcDvwssA/6gqq4edRtmw6dqJR2NRhr6SZYBnwB+EdgD3J1ke1U9dCT2Z3BLWkiLcWho1MM7pwO7qurxqvq/wM3AhhG3QZKaNerhnZXAk33f9wBvPnShJJuBzd3XZ5M8MoK2nQj89Qj2M2r2a+k5Wvtmv+Yg/3mg1X96uhmL8u6dqtoKbB3lPpOMV9X6Ue5zFOzX0nO09s1+LQ6jHt7ZC6zu+76qq0mSRmDUoX83sDbJyUleBmwEto+4DZLUrJEO71TVc0kuBb5A75bNbVX14CjbMIORDieNkP1aeo7WvtmvRSBVtdBtkCSNiE/kSlJDDH1J
akhzoZ9kW5IDSR7oq52QZEeSR7ufyxeyjfM1Td+uTLI3yc7uc85CtnE+kqxO8uUkDyV5MMn7u/qSPm4z9OtoOGY/luSrSf6y69u/6+onJ7krya4kf9Td0LFkzNCv65N8o++YrVvgpk6ruTH9JG8FngVurKpTu9pHgKeq6uoklwHLq+pfLWQ752Oavl0JPFtVH13Itg0iyQpgRVXdm+RY4B7gPOA9LOHjNkO/zmfpH7MAL6+qZ5O8FPgL4P3AB4HPVNXNST4J/GVVbVnIts7FDP26GPjvVfUnC9rAWWjuTL+q7gCeOqS8Abihm76B3n94S840fVvyqmpfVd3bTX8XeJje091L+rjN0K8lr3qe7b6+tPsU8I+AyWBcisdsun4tGc2F/jROqqp93fQ3gZMWsjFHwKVJ7uuGf5bUEMihkqwBTgPu4ig6bof0C46CY5ZkWZKdwAFgB/AY8HRVPdctsocl+Evu0H5V1eQxu6o7Zh9LcszCtXBmhv4hqjfetaR+cx/GFuC1wDpgH3DNgrZmAEleAdwKfKCqnumft5SP2xT9OiqOWVU9X1Xr6D15fzrwdxa2RcNxaL+SnApcTq9/fw84AVi0w4yGfs/+bnx1cpz1wAK3Z2iqan/3L+kLwLX0/uNbcrrx01uBT1XVZ7rykj9uU/XraDlmk6rqaeDLwM8BxyeZfCh0Sb+Gpa9fZ3dDdVVVB4H/yiI+ZoZ+z3bggm76AuC2BWzLUE2GYuedwAPTLbtYdRfPrgMerqrf6Zu1pI/bdP06So7ZWJLju+kfp/c3NB6mF5K/1i22FI/ZVP36et/JR+hdp1i0x6zFu3duAs6k9zrU/cAVwJ8CtwCvBp4Azq+qJXdBdJq+nUlvmKCA3cBFfePgS0KSnwf+J3A/8EJX/i16499L9rjN0K9NLP1j9kZ6F2qX0Tu5vKWqPpzkNfT+jsYJwNeAf9adHS8JM/Trz4ExIMBO4OK+C76LSnOhL0ktc3hHkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SG/D8Wxj4pbmUvowAAAABJRU5ErkJggg==\n"
134 | },
135 | "metadata": {
136 | "needs_background": "light"
137 | }
138 | },
139 | {
140 | "output_type": "stream",
141 | "name": "stdout",
142 | "text": [
143 | "min/max : 9 37\n"
144 | ]
145 | },
146 | {
147 | "output_type": "display_data",
148 | "data": {
149 | "text/plain": "",
150 | "image/svg+xml": "\n\n\n\n",
151 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAS2ElEQVR4nO3df6xfdX3H8edrFXFRTGHcNR1QCwZdmNnKcocu/lgn2wQ0QxbDrJuiMysksmlcMqtLJltCgk4kW3SYGhBI5NdEJplskzAmcxnOghWLldliiW1KexVRGI6t8N4f33O3L9d723vv+d7e3s99PpJvvuf7Oed8z/uT0/v6nn6+53xPqgpJUlt+YrELkCSNnuEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12aIsn6JLsXuw6pD8NdWgBJ1iapJM9Z7Fq0PBnuktQgw13NSvK+JJ+Z0vaXSf4qyTuSbE/yeJKHklx4iPfZ0y37YJIzu/afSLIpyc4k30tyc5LjutXu7p4fS/JEkl9emF5K0zPc1bIbgXOSHAOQZAVwPnA9sB94A/BC4B3AFUl+ceobJHkpcDHwS1V1DPA6YFc3+w+ANwK/AvwM8H3g492813TPK6vqBVX1b6PunHQwhruaVVUPA/cB53VNrwWerKp7qurzVbWzBr4IfAF49TRv8zRwNHBakqOqaldV7ezmXQT8SVXtrqqngEuANznOriOB4a7WXQ9s6Kbf0r0mydlJ7knyaJLHgHOA46euXFU7gPcwCO79SW5M8jPd7BcBtyZ5rHuP7Qw+DFYtWG+kWTLc1bq/AdYnOZHBEfz1SY4GbgE+AqyqqpXA7UCme4Oqur6qXsUgzAv4UDfrO8DZVbVy6PG8qtrTLSctGsNdTauqCeCfgU8B366q7cBzGQy1TAAHkpwN/MZ06yd5aZLXdh8I/wX8CHimm/0J4NIkL+qWHUtybjdvolvulAXpmHQIhruWg+uBX+ueqarHgT8EbmbwJehbgNtmWPdo4DLgu8AjwE8D7+/m/WW33heSPA7cA7y828aTwKXAv3bDNq8YfbekmcWbdUhSezxyl6QGGe6S1KBDhnuSk5LcleQbSR5I8u6u/bgkdyT5Vvd8bNee7grAHUnun+7CEEnSwprNkfsB4I+q6jTgFcC7kpwGbALurKpTgTu71wBnA6d2j43AlSOvWpJ0UIe8kq6q9gJ7u+nHk2wHTgDOBdZ3i13L4HSz93Xt19Xgm9p7kqxMsrp7n2kdf/zxtXbt2h7dkKTl59577/1uVY1NN29Ol0knWQucDnyZwcUfk4H9CP9/Vd4JDC7umLS7a3tWuCfZyODInjVr1rBly5a5lCJJy16Sh2eaN+svVJO8gMFVfe+pqh8Oz+uO0ud0TmVVba6q8aoaHxub9oNHkjRPswr3JEcxCPZPV9Vnu+Z9SVZ381cz+JU9gD3ASUOrn9i1SZIOk9mcLRPgKmB7VX10aNZtwAXd9AXA54ba39adNfMK4AcHG2+XJI3ebMbcXwm8Ffh6kq1d2wcYXJJ9c5J3Ag8z+J1sGPwA0znADuBJBr+VLUk6jGZztsyXmOHX8oAzp1m+gHf1rEuS1INXqEpSgwx3SWqQ4S5JDTLcJalB3shXOoS1mz4/73V3Xfb6EVYizZ5H7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ2azQ2yr06yP8m2obabkmztHrsm762aZG2SHw3N+8QC1i5JmsFsfvL3GuBjwHWTDVX125PTSS4HfjC0/M6qWjei+iRJ8zCbG2TfnWTtdPOSBDgfeO2I65Ik9dB3zP3VwL6q+tZQ28lJvprki0lePdOKSTYm2ZJky8TERM8yJEnD+ob7BuCGodd7gTVVdTrwXuD6JC+cbsWq2lxV41U1PjY21rMMSdKweYd7kucAvwXcNNlWVU9V1fe66XuBncBL
+hYpSZqbPkfuvwZ8s6p2TzYkGUuyops+BTgVeKhfiZKkuZrNqZA3AP8GvDTJ7iTv7Ga9mWcPyQC8Bri/OzXyM8BFVfXoCOuVJM3CbM6W2TBD+9unabsFuKV/WZKkPrxCVZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSg2ZzD9Wrk+xPsm2o7ZIke5Js7R7nDM17f5IdSR5M8rqFKlySNLPZHLlfA5w1TfsVVbWue9wOkOQ0BjfO/rlunb9OsmJUxUqSZueQ4V5VdwOPzvL9zgVurKqnqurbwA7gjB71SZLmoc+Y+8VJ7u+GbY7t2k4AvjO0zO6u7cck2ZhkS5ItExMTPcqQJE0133C/EngxsA7YC1w+1zeoqs1VNV5V42NjY/MsQ5I0nXmFe1Xtq6qnq+oZ4JP8/9DLHuCkoUVP7NokSYfRvMI9yeqhl+cBk2fS3Aa8OcnRSU4GTgX+vV+JkqS5es6hFkhyA7AeOD7JbuCDwPok64ACdgEXAlTVA0luBr4BHADeVVVPL0jlkqQZHTLcq2rDNM1XHWT5S4FL+xQlSerHK1QlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXokOGe5Ook+5NsG2r7iyTfTHJ/kluTrOza1yb5UZKt3eMTC1i7JGkGszlyvwY4a0rbHcDLqurngf8A3j80b2dVreseF42mTEnSXBwy3KvqbuDRKW1fqKoD3ct7gBMXoDZJ0jyNYsz994C/H3p9cpKvJvliklfPtFKSjUm2JNkyMTExgjIkSZN6hXuSPwEOAJ/umvYCa6rqdOC9wPVJXjjdulW1uarGq2p8bGysTxmSpCnmHe5J3g68AfidqiqAqnqqqr7XTd8L7AReMoI6JUlzMK9wT3IW8MfAb1bVk0PtY0lWdNOnAKcCD42iUEnS7D3nUAskuQFYDxyfZDfwQQZnxxwN3JEE4J7uzJjXAH+e5H+AZ4CLqurRad9YkrRgDhnuVbVhmuarZlj2FuCWvkVJkvrxClVJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ2aVbgnuTrJ/iTbhtqOS3JHkm91z8d27UnyV0l2JLk/yS8uVPGSpOkd8h6qnWuAjwHXDbVtAu6sqsuSbOpevw84Gzi1e7wcuLJ7VgPWbvr8vNfdddnrR1iJpIOZ1ZF7Vd0NPDql+Vzg2m76WuCNQ+3X1cA9wMokq0dQqyRplvqMua+qqr3d9CPAqm76BOA7Q8vt7tqeJcnGJFuSbJmYmOhRhiRpqtkOyxxUVVWSmuM6m4HNAOPj43NaV1oqHMbSYulz5L5vcrile97fte8BThpa7sSuTZJ0mPQJ99uAC7rpC4DPDbW/rTtr5hXAD4aGbyRJh8GshmWS3ACsB45Pshv4IHAZcHOSdwIPA+d3i98OnAPsAJ4E3jHimiVJhzCrcK+qDTPMOnOaZQt4V5+iJEn9eIWqJDXIcJekBhnuktQgw12SGmS4S1KDRnKFqqTR8+pW9WG4Sw3q88EAfji0wGEZSWqQ4S5JDTLcJalBjrlrSXAMWZobj9wlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWrQvM9zT/JS4KahplOAPwVWAr8PTHTtH6iq2+e7HUnS3M073KvqQWAdQJIVwB7gVgY3xL6iqj4yigIlSXM3qitUzwR2VtXDSUb0ltLo9L3CVVpqRjXm/mbghqHXFye5P8nVSY6dboUkG5NsSbJlYmJiukUkSfPUO9yTPBf4TeBvuqYrgRczGLLZC1w+3XpVtbmqxqtq
fGxsrG8ZkqQhozhyPxu4r6r2AVTVvqp6uqqeAT4JnDGCbUiS5mAU4b6BoSGZJKuH5p0HbBvBNiRJc9DrC9Ukzwd+HbhwqPnDSdYBBeyaMk+SdBj0Cveq+k/gp6a0vbVXRZKk3rxCVZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNGtVt9qRD8lZ30uHjkbskNchwl6QGGe6S1CDDXZIaZLhLUoN6ny2TZBfwOPA0cKCqxpMcB9wErGVwH9Xzq+r7fbclSZqdUR25/2pVrauq8e71JuDOqjoVuLN7LUk6TBZqWOZc4Npu+lrgjQu0HUnSNEYR7gV8Icm9STZ2bauqam83/QiwagTbkSTN0iiuUH1VVe1J8tPAHUm+OTyzqipJTV2p+yDYCLBmzZoRlCFJmtT7yL2q9nTP+4FbgTOAfUlWA3TP+6dZb3NVjVfV+NjYWN8yJElDeoV7kucnOWZyGvgNYBtwG3BBt9gFwOf6bEeSNDd9h2VWAbcmmXyv66vqH5J8Bbg5yTuBh4Hze25HkjQHvcK9qh4CfmGa9u8BZ/Z5b0nS/HmFqiQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lq0Chus6clZu2mzy92CZIWmEfuktQgw12SGmS4S1KDDHdJatC8wz3JSUnuSvKNJA8keXfXfkmSPUm2do9zRleuJGk2+pwtcwD4o6q6L8kxwL1J7ujmXVFVH+lfXrs8Y0X6cX3+LnZd9voRVrL0zTvcq2ovsLebfjzJduCEURUmafEYskvfSM5zT7IWOB34MvBK4OIkbwO2MDi6//4062wENgKsWbNmFGVIOgL4v9IjQ+8vVJO8ALgFeE9V/RC4EngxsI7Bkf3l061XVZuraryqxsfGxvqWIUka0ivckxzFINg/XVWfBaiqfVX1dFU9A3wSOKN/mZKkuehztkyAq4DtVfXRofbVQ4udB2ybf3mSpPnoM+b+SuCtwNeTbO3aPgBsSLIOKGAXcGGPbUiS5qHP2TJfAjLNrNvnX44kaRS8QlWSGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkho0kp/8laTF5m/QP5vh3oO/Wy3pSOWwjCQ1yCN3SerpSBwS8shdkhrkkbukZa/F78+Wfbi3uFMlqYlwN6Al6dkcc5ekBi1YuCc5K8mDSXYk2bRQ25Ek/bgFCfckK4CPA2cDpzG4afZpC7EtSdKPW6gj9zOAHVX1UFX9N3AjcO4CbUuSNMVCfaF6AvCdode7gZcPL5BkI7Cxe/lEkgd7bO944Ls91l8qlks/Yfn0dbn0E5ZPX+fUz3yo17ZeNNOMRTtbpqo2A5tH8V5JtlTV+Cje60i2XPoJy6evy6WfsHz6eqT0c6GGZfYAJw29PrFrkyQdBgsV7l8BTk1ycpLnAm8GblugbUmSpliQYZmqOpDkYuAfgRXA1VX1wEJsqzOS4Z0lYLn0E5ZPX5dLP2H59PWI6GeqarFrkCSNmFeoSlKDDHdJatCSCvckVyfZn2TbUNtxSe5I8q3u+djFrHFUZujrJUn2JNnaPc5ZzBpHIclJSe5K8o0kDyR5d9fe3H49SF+b2q9Jnpfk35N8revnn3XtJyf5cveTJDd1J1ssaQfp6zVJvj20T9cd9tqW0ph7ktcATwDXVdXLurYPA49W1WXdb9gcW1XvW8w6R2GGvl4CPFFVH1nM2kYpyWpgdVXdl+QY4F7gjcDbaWy/HqSv59PQfk0S4PlV9USSo4AvAe8G3gt8tqpuTPIJ4GtVdeVi1trXQfp6EfB3VfWZxaptSR25V9XdwKNTms8Fru2mr2Xwx7LkzdDX5lTV3qq6r5t+HNjO4Arn5vbrQfralBp4ont5VPco4LXAZNi1sk9n6uuiW1LhPoNVVbW3m34EWLWYxRwGFye5
vxu2WfJDFcOSrAVOB75M4/t1Sl+hsf2aZEWSrcB+4A5gJ/BYVR3oFtlNIx9sU/taVZP79NJun16R5OjDXVcL4f5/ajDGdER8ai6QK4EXA+uAvcDli1rNCCV5AXAL8J6q+uHwvNb26zR9bW6/VtXTVbWOwdXpZwA/u7gVLZypfU3yMuD9DPr8S8BxwGEfUmwh3Pd1Y5mTY5r7F7meBVNV+7p/SM8An2TwR7PkdWOVtwCfrqrPds1N7tfp+trqfgWoqseAu4BfBlYmmbxwsrmfJBnq61ndEFxV1VPAp1iEfdpCuN8GXNBNXwB8bhFrWVCTYdc5D9g207JLRfeF1FXA9qr66NCs5vbrTH1tbb8mGUuyspv+SeDXGXy/cBfwpm6xVvbpdH395tCBSRh8t3DY9+lSO1vmBmA9g5/U3Ad8EPhb4GZgDfAwcH5VLfkvImfo63oG/3UvYBdw4dC49JKU5FXAvwBfB57pmj/AYCy6qf16kL5uoKH9muTnGXxhuoLBAeTNVfXnSU5hcG+H44CvAr/bHdkuWQfp6z8BY0CArcBFQ1+8Hp7allK4S5Jmp4VhGUnSFIa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJatD/Artwr29Huy65AAAAAElFTkSuQmCC\n"
152 | },
153 | "metadata": {
154 | "needs_background": "light"
155 | }
156 | },
157 | {
158 | "output_type": "stream",
159 | "name": "stdout",
160 | "text": [
161 | "min/max : 10 36\n"
162 | ]
163 | },
164 | {
165 | "output_type": "display_data",
166 | "data": {
167 | "text/plain": "",
168 | "image/svg+xml": "\n\n\n\n",
169 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAR8klEQVR4nO3de5CddX3H8ffHKGoRC5SVSbk04KDWMjZ0VtQRLfVSEdRIL0jGC6g10pEZHW0t0E5BO0ypFZl22kJDoaBVLiNeaKEqRUd0rJcFIwaBChiGpCFZBQTUoQLf/rFP2uO6S3b3OZvd/e37NXNmn/N7bt9fnuSzv/zOc85JVSFJasvjFroASdLwGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7mpNkU5KX9TzGiUm+vFjqkWbLcJekBhnuakqSjwAHAv+a5MEk703y/CRfSXJfkm8lOXJg+xOT3JHkgSTfS/L6JL8KnAe8oDvGfd22Ryf5TrftliR/NHCcVyXZ0J3jK0meM109u+wPQ8ta/PgBtSbJJuAPquo/kuwH3Ai8EfgM8FLgUuBZwI+BrcBzq+rWJCuBvavqpiQndsc4YuC4W4HjqupLSfYCDqqqG5IcBnwWeDUwBrwBeB/wzKp6aLCeXdF/CRy5q31vAK6uqqur6tGquoaJAD66W/8ocGiSJ1fV1qq66TGO9VPg2UmeWlX3VtUNXfs64B+r6mtV9UhVXQw8BDx/nvok7ZThrtb9CvD73XTJfd0UyxHAyqr6EfA64CRga5KrkjzrMY71u0z8UrgzyReTvGDgHO+ZdI4DgF+epz5JO2W4q0WDc413AR+pqj0HHrtX1VkAVfXZqno5sBK4BTh/imPQbfuNqloDPA34FHD5wDnOnHSOX6iqS6Y7ljTfDHe1aBtwcLf8L8Crk7wiyYokT0pyZJL9k+ybZE2S3ZmYRnmQiWmaHcfYP8luAEl2615s/cWq+ilw/8C25wMnJXleJuye5Jgke0xRj7RLGO5q0V8Cf9ZNj7wOWAOcBowzMcr+Yyb+7j8OeDfw38A9wG8Cf9gd4/PATcDdSb7ftb0R2JTkfiamcl4PUFVjwNuAvwPuBW4DTpyqnsE7bKT55N0yktQgR+6S1CDDXZIaZLhLUoMMd0lq0OMXugCAffbZp1atWrXQZUjSknL99dd/v6pGplq3KMJ91apVjI2NLXQZkrSkJLlzunVOy0hSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMWxTtUpcVs1SlXzXnfTWcdM8RKpJnb6cg9yYVJtifZONB2WZIN3WNTkg1d+6okPxlYd9481i5JmsZMRu4XMfH1YR/e0VBVr9uxnORs4IcD299eVauHVJ8kaQ52Gu5VdV2SVVOtSxLgOOAlQ65LktRD3xdUXwRsq6rvDrQdlOSbSb6Y5EXT7ZhkXZKxJGPj4+M9y5AkDeob7muBSwaebwUOrKrDmPhW+Y8leepUO1bV+qoararRkZEpP45YkjRHcw73JI8Hfge4bEdbVT1UVT/olq8Hbgee0bdISdLs9Bm5vwy4pao272hIMpJkRbd8MHAIcEe/EiVJszWTWyEvAf4TeGaSzUne2q06np+dkgF4MXBjd2vkx4GTquqeIdYrSZqBmdwts3aa9hOnaLsCuKJ/WZKkPvz4AUlqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNWin4Z7kwiTbk2wcaDsjyZYkG7rH0QPrTk1yW5Jbk7xivgqXJE1vJiP3i4Cjpmg/p6pWd4+rAZI8Gzge+LVun39IsmJYxUqSZman4V5V1wH3zPB4a4BLq+qhqvoecBtweI/6JElz0GfO/eQkN3bTNnt1bfsBdw1ss7lr+zlJ1iUZ
SzI2Pj7eowxJ0mRzDfdzgacDq4GtwNmzPUBVra+q0aoaHRkZmWMZkqSpzCncq2pbVT1SVY8C5/P/Uy9bgAMGNt2/a5Mk7UJzCvckKweeHgvsuJPmSuD4JE9MchBwCPD1fiVKkmbr8TvbIMklwJHAPkk2A6cDRyZZDRSwCXg7QFXdlORy4DvAw8A7quqRealckjStnYZ7Va2dovmCx9j+TODMPkVJkvrxHaqS1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktSgnb5DVRq06pSr5rzvprOOGWIlkh6LI3dJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDdppuCe5MMn2JBsH2v46yS1JbkzyySR7du2rkvwkyYbucd481i5JmsZMRu4XAUdNarsGOLSqngP8F3DqwLrbq2p19zhpOGVKkmZjp+FeVdcB90xq+1xVPdw9/Sqw/zzUJkmao2HMub8F+PeB5wcl+WaSLyZ50XQ7JVmXZCzJ2Pj4+BDKkCTt0Cvck/wp8DDw0a5pK3BgVR0GvBv4WJKnTrVvVa2vqtGqGh0ZGelThiRpkjmHe5ITgVcBr6+qAqiqh6rqB93y9cDtwDOGUKckaRbmFO5JjgLeC7ymqn480D6SZEW3fDBwCHDHMAqVJM3cTr+JKcklwJHAPkk2A6czcXfME4FrkgB8tbsz5sXA+5P8FHgUOKmq7pnywJKkebPTcK+qtVM0XzDNtlcAV/QtSpLUj+9QlaQGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSg2YU7kkuTLI9ycaBtr2TXJPku93Pvbr2JPnbJLcluTHJb8xX8ZKkqc105H4RcNSktlOAa6vqEODa7jnAK4FDusc64Nz+ZUqSZmNG4V5V1wH3TGpeA1zcLV8MvHag/cM14avAnklWDqFWSdIM9Zlz37eqtnbLdwP7dsv7AXcNbLe5a/sZSdYlGUsyNj4+3qMMSdJkQ3lBtaoKqFnus76qRqtqdGRkZBhlSJI6fcJ9247plu7n9q59C3DAwHb7d22SpF2kT7hfCZzQLZ8AfHqg/U3dXTPPB344MH0jSdoFHj+TjZJcAhwJ7JNkM3A6cBZweZK3AncCx3WbXw0cDdwG/Bh485BrliTtxIzCvarWTrPqpVNsW8A7+hQlSerHd6hKUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1aEafCikNw6pTrprzvpvOOmaIlUjtM9y1JPT5xQD+ctDyY7hLi5T/01EfzrlLUoMMd0lqkOEuSQ0y3CWpQXN+QTXJM4HLBpoOBv4c2BN4GzDetZ9WVVfP9TySpNmbc7hX1a3AaoAkK4AtwCeBNwPnVNUHh1GgpNnz1lENa1rmpcDtVXXnkI4nSephWOF+PHDJwPOTk9yY5MIke021Q5J1ScaSjI2Pj0+1iSRpjnq/iSnJbsBrgFO7pnOBvwCq+3k28JbJ+1XVemA9wOjoaPWtQ3osfacppKVmGCP3VwI3VNU2gKraVlWPVNWjwPnA4UM4hyRpFoYR7msZmJJJsnJg3bHAxiGcQ5I0C72mZZLsDrwcePtA8weSrGZiWmbTpHWSpF2gV7hX1Y+AX5rU9sZeFUmSevMdqpLUID/yV5pH3qWjheLIXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSg3p/zV6STcADwCPAw1U1mmRv4DJgFbAJOK6q7u17LknSzAxr5P5bVbW6qka756cA11bVIcC13XNJ0i4yX9Mya4CLu+WLgdfO03kk
SVMYRrgX8Lkk1ydZ17XtW1Vbu+W7gX0n75RkXZKxJGPj4+NDKEOStEPvOXfgiKrakuRpwDVJbhlcWVWVpCbvVFXrgfUAo6OjP7dekjR3vUfuVbWl+7kd+CRwOLAtyUqA7uf2vueRJM1cr3BPsnuSPXYsA78NbASuBE7oNjsB+HSf80iSZqfvtMy+wCeT7DjWx6rqM0m+AVye5K3AncBxPc8jSZqFXuFeVXcAvz5F+w+Al/Y5tiRp7nyHqiQ1yHCXpAYZ7pLUIMNdkho0jDcxaQ5WnXLVnPfddNYxQ6xEUoscuUtSgxy5L0F9Rv3gyF9aDhy5S1KDDHdJapDhLkkNcs59Geo7Zy9p8XPkLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDZpzuCc5IMkXknwnyU1J3tm1n5FkS5IN3ePo4ZUrSZqJPh8/8DDwnqq6IckewPVJrunWnVNVH+xfniRpLuYc7lW1FdjaLT+Q5GZgv2EVJkmau6F8cFiSVcBhwNeAFwInJ3kTMMbE6P7eKfZZB6wDOPDAA4dRhqQh8Wsgl77eL6gmeQpwBfCuqrofOBd4OrCaiZH92VPtV1Xrq2q0qkZHRkb6liFJGtAr3JM8gYlg/2hVfQKgqrZV1SNV9ShwPnB4/zIlSbPR526ZABcAN1fVhwbaVw5sdiywce7lSZLmos+c+wuBNwLfTrKhazsNWJtkNVDAJuDtPc4hSZqDPnfLfBnIFKuunns5kqRh8B2qktQgv0NV0qLhLZjD48hdkhrkyF3SUPUZfWt4HLlLUoMMd0lqkOEuSQ0y3CWpQb6g2oMvHElarBy5S1KDDHdJapDhLkkNMtwlqUGGuyQ1qIm7ZfywIUn6WY7cJalBTYzc+/BedUktWvbhLkl9LcapYadlJKlBjtwlLXstTs/OW7gnOQr4G2AF8E9VddZ8nUuSWgzoPuZlWibJCuDvgVcCzwbWJnn2fJxLkvTz5mvO/XDgtqq6o6r+B7gUWDNP55IkTTJf0zL7AXcNPN8MPG9wgyTrgHXd0weT3Dqweh/g+/NU22KyXPoJy6evy6WfsHz6Oq/9zF/12v1XpluxYC+oVtV6YP1U65KMVdXoLi5pl1su/YTl09fl0k9YPn1dqv2cr2mZLcABA8/379okSbvAfIX7N4BDkhyUZDfgeODKeTqXJGmSeZmWqaqHk5wMfJaJWyEvrKqbZnGIKadrGrRc+gnLp6/LpZ+wfPq6JPuZqlroGiRJQ+bHD0hSgwx3SWrQgod7kguTbE+ycaBt7yTXJPlu93OvhaxxGKbp5xlJtiTZ0D2OXsgahyHJAUm+kOQ7SW5K8s6uvcVrOl1fm7quSZ6U5OtJvtX1831d+0FJvpbktiSXdTdPLGmP0deLknxv4JquXuBSd2rB59yTvBh4EPhwVR3atX0AuKeqzkpyCrBXVf3JQtbZ1zT9PAN4sKo+uJC1DVOSlcDKqrohyR7A9cBrgRNp75pO19fjaOi6Jgmwe1U9mOQJwJeBdwLvBj5RVZcmOQ/4VlWdu5C19vUYfT0J+Leq+viCFjgLCz5yr6rrgHsmNa8BLu6WL2biH8ySNk0/m1NVW6vqhm75AeBmJt6x3OI1na6vTakJD3ZPn9A9CngJsCPsWrmm0/V1yVnwcJ/GvlW1tVu+G9h3IYuZZycnubGbtlnyUxWDkqwCDgO+RuPXdFJfobHrmmRFkg3AduAa4Hbgvqp6uNtkM438Ypvc16racU3P7K7pOUmeuHAVzsxiDff/UxPzRkvyN+cMnAs8HVgNbAXOXtBqhijJU4ArgHdV1f2D61q7plP0tbnrWlWPVNVqJt5tfjjwrIWtaP5M7muSQ4FTmejzc4G9gUU/pbhYw31bN5+5Y15z+wLXMy+qalv3F+lR4Hwm/tEsed1c5RXAR6vqE11zk9d0qr62el0Bquo+4AvAC4A9k+x4I2RzHzEy0Nejuim4qqqHgH9mCVzTxRruVwIn
dMsnAJ9ewFrmzY6w6xwLbJxu26Wie0HqAuDmqvrQwKrmrul0fW3tuiYZSbJnt/xk4OVMvL7wBeD3us1auaZT9fWWgYFJmHhtYdFf08Vwt8wlwJFMfKzmNuB04FPA5cCBwJ3AcVW1pF+MnKafRzLxX/cCNgFvH5iXXpKSHAF8Cfg28GjXfBoTc9GtXdPp+rqWhq5rkucw8YLpCiYGhJdX1fuTHMzEdzXsDXwTeEM3sl2yHqOvnwdGgAAbgJMGXnhdlBY83CVJw7dYp2UkST0Y7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalB/wt5mHs8bR+hZwAAAABJRU5ErkJggg==\n"
170 | },
171 | "metadata": {
172 | "needs_background": "light"
173 | }
174 | },
175 | {
176 | "output_type": "stream",
177 | "name": "stdout",
178 | "text": [
179 | "min/max : 11 37\n"
180 | ]
181 | }
182 | ],
183 | "source": [
184 | "def plot_histo_graphs(dataset, title):\n",
185 | " # histogram of graph sizes\n",
186 | " graph_sizes = []\n",
187 | " for graph in dataset:\n",
188 | " graph_sizes.append(graph[0].number_of_nodes())\n",
189 | " plt.figure(1)\n",
190 | " plt.hist(graph_sizes, bins=20)\n",
191 | " plt.title(title)\n",
192 | " plt.show()\n",
193 | " graph_sizes = torch.Tensor(graph_sizes)\n",
194 | " print('min/max :',graph_sizes.min().long().item(),graph_sizes.max().long().item())\n",
195 | " \n",
196 | "plot_histo_graphs(dataset.train,'trainset')\n",
197 | "plot_histo_graphs(dataset.val,'valset')\n",
198 | "plot_histo_graphs(dataset.test,'testset')\n"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 9,
204 | "metadata": {},
205 | "outputs": [
206 | {
207 | "output_type": "stream",
208 | "name": "stdout",
209 | "text": [
210 | "10000\n1000\n1000\n(Graph(num_nodes=29, num_edges=64,\n ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}\n edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}), tensor([0.8350]))\n(Graph(num_nodes=35, num_edges=78,\n ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}\n edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}), tensor([0.6299]))\n(Graph(num_nodes=16, num_edges=34,\n ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}\n edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}), tensor([1.9973]))\n"
211 | ]
212 | }
213 | ],
214 | "source": [
215 | "print(len(dataset.train))\n",
216 | "print(len(dataset.val))\n",
217 | "print(len(dataset.test))\n",
218 | "\n",
219 | "print(dataset.train[0])\n",
220 | "print(dataset.val[0])\n",
221 | "print(dataset.test[0])\n"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 10,
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "num_atom_type = 28\n",
231 | "num_bond_type = 4\n"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": 11,
237 | "metadata": {},
238 | "outputs": [
239 | {
240 | "output_type": "stream",
241 | "name": "stderr",
242 | "text": [
243 | "/data00/caishaofei/miniconda3/envs/gnas2/lib/python3.6/site-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead\n",
244 | " warnings.warn(\"pickle support for Storage will be removed in 1.5. Use `torch.save` instead\", FutureWarning)\n",
245 | "Time (sec): 6.077600479125977\n"
246 | ]
247 | }
248 | ],
249 | "source": [
250 | "start = time.time()\n",
251 | "with open('data/molecules/ZINC.pkl','wb') as f:\n",
252 | " pickle.dump([dataset.train,dataset.val,dataset.test,num_atom_type,num_bond_type],f)\n",
253 | "print('Time (sec):',time.time() - start)\n"
254 | ]
255 | },
256 | {
257 | "cell_type": "markdown",
258 | "metadata": {},
259 | "source": [
260 | "# Test load function"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": 12,
266 | "metadata": {},
267 | "outputs": [
268 | {
269 | "output_type": "error",
270 | "ename": "NameError",
271 | "evalue": "name 'LoadData' is not defined",
272 | "traceback": [
273 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
274 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
275 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mDATASET_NAME\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'ZINC'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLoadData\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDATASET_NAME\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mtrainset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtestset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
276 | "\u001b[0;31mNameError\u001b[0m: name 'LoadData' is not defined"
277 | ]
278 | }
279 | ],
280 | "source": [
281 | "DATASET_NAME = 'ZINC'\n",
282 | "dataset = LoadData(DATASET_NAME)\n",
283 | "trainset, valset, testset = dataset.train, dataset.val, dataset.test\n"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": 11,
289 | "metadata": {},
290 | "outputs": [
291 | {
292 | "name": "stdout",
293 | "output_type": "stream",
294 | "text": [
295 | "\n"
296 | ]
297 | }
298 | ],
299 | "source": [
300 | "batch_size = 10\n",
301 | "collate = MoleculeDataset.collate\n",
302 | "print(MoleculeDataset)\n",
303 | "train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=collate)\n"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": null,
309 | "metadata": {},
310 | "outputs": [],
311 | "source": []
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": null,
316 | "metadata": {},
317 | "outputs": [],
318 | "source": []
319 | }
320 | ],
321 | "metadata": {
322 | "kernelspec": {
323 | "name": "python3613jvsc74a57bd0720c3f9f2262024fbbfd811ec9823dbfd16426a3ecbb08b5c838007ee1202dee",
324 | "display_name": "Python 3.6.13 64-bit ('gnas2': conda)"
325 | },
326 | "language_info": {
327 | "codemirror_mode": {
328 | "name": "ipython",
329 | "version": 3
330 | },
331 | "file_extension": ".py",
332 | "mimetype": "text/x-python",
333 | "name": "python",
334 | "nbconvert_exporter": "python",
335 | "pygments_lexer": "ipython3",
336 | "version": "3.6.13"
337 | },
338 | "metadata": {
339 | "interpreter": {
340 | "hash": "720c3f9f2262024fbbfd811ec9823dbfd16426a3ecbb08b5c838007ee1202dee"
341 | }
342 | }
343 | },
344 | "nbformat": 4,
345 | "nbformat_minor": 4
346 | }
--------------------------------------------------------------------------------
/data/superpixels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from scipy.spatial.distance import cdist
4 | import numpy as np
5 | import itertools
6 |
7 | import dgl
8 | import torch
9 | import torch.utils.data
10 |
11 | import time
12 |
13 | import csv
14 | from sklearn.model_selection import StratifiedShuffleSplit
15 |
16 | from scipy import sparse as sp
17 | import numpy as np
18 |
def sigma(dists, kth=8):
    """Return one kernel bandwidth per row of the distance matrix ``dists``.

    For every row, the ``kth + 1`` smallest entries are summed and divided by
    ``kth``. If ``np.partition`` rejects ``kth`` with a ``ValueError`` (rows
    shorter than ``kth``), a column of ones is used instead. A small epsilon
    is always added so the bandwidth is never exactly zero.

    Returns an array of shape ``(dists.shape[0], 1)``.
    """
    try:
        # After partitioning, positions 0..kth hold the kth+1 smallest values.
        nearest = np.partition(dists, kth, axis=-1)[:, kth::-1]
    except ValueError:
        # Too few columns for the requested kth: fall back to a constant.
        n_rows = dists.shape[0]
        bandwidth = np.array([1] * n_rows).reshape(n_rows, 1)
    else:
        bandwidth = nearest.sum(axis=1).reshape((nearest.shape[0], 1)) / kth

    return bandwidth + 1e-8  # epsilon keeps later divisions well defined
31 |
32 |
def compute_adjacency_matrix_images(coord, feat, use_feat=True, kth=8):
    """Build a dense, symmetric Gaussian-kernel adjacency matrix.

    Parameters
    ----------
    coord : array
        Node coordinates; reshaped to ``(-1, 2)`` before use.
    feat : array
        Node features; only read when ``use_feat`` is true.
    use_feat : bool
        When true, feature distances are combined with coordinate distances.
    kth : int
        Neighbour count forwarded to ``sigma`` for the per-row bandwidth.
        (Previously accepted but ignored; the default matches the value
        ``sigma`` used implicitly, so existing callers are unaffected.)

    Returns
    -------
    A square array with a zero diagonal, symmetrised as ``0.5 * (A + A.T)``.
    """
    coord = coord.reshape(-1, 2)

    # Coordinate term of the kernel exponent.
    c_dist = cdist(coord, coord)
    exponent = -(c_dist / sigma(c_dist, kth)) ** 2

    if use_feat:
        # Feature term, scaled by its own bandwidth.
        f_dist = cdist(feat, feat)
        exponent = exponent - (f_dist / sigma(f_dist, kth)) ** 2

    A = np.exp(exponent)

    # The bandwidth is per row, so the raw kernel is not symmetric.
    A = 0.5 * (A + A.T)
    A[np.diag_indices_from(A)] = 0  # no self-similarity
    return A
50 |
51 |
def compute_edges_list(A, kth=8+1):
    """Select neighbour indices and their similarity values for every node.

    Parameters
    ----------
    A : ndarray, shape (num_nodes, num_nodes)
        Similarity (adjacency) matrix.
    kth : int
        Neighbourhood size including one slot that is dropped, so each node
        keeps ``kth - 1`` neighbours when the graph has more than ``kth`` nodes.

    Returns
    -------
    (knns, knn_values)
        ``knns[i]`` holds the neighbour indices of node ``i`` and
        ``knn_values[i, j] == A[i, knns[i, j]]``. Graphs with at most ``kth``
        nodes become fully connected without self loops; a single-node graph
        keeps its one self entry.
    """
    num_nodes = A.shape[0]
    new_kth = num_nodes - kth

    # Compare against kth itself (not a literal 9) so a non-default kth
    # selects the right branch and never yields a negative partition index.
    if num_nodes > kth:
        knns = np.argpartition(A, new_kth - 1, axis=-1)[:, new_kth:-1]
        # Gather the values through the chosen indices so every value is
        # guaranteed to belong to the edge it is stored next to.
        knn_values = np.take_along_axis(A, knns, axis=-1)
    else:
        # Too few nodes for a k-NN graph: connect every pair of nodes.
        all_nodes = np.arange(num_nodes)
        knns = np.tile(all_nodes, num_nodes).reshape(num_nodes, num_nodes)
        knn_values = A

        # Remove self loops, except when the self loop is the only entry.
        if num_nodes != 1:
            off_diagonal = knns != all_nodes[:, None]
            knn_values = A[off_diagonal].reshape(num_nodes, -1)
            knns = knns[off_diagonal].reshape(num_nodes, -1)
    return knns, knn_values
72 |
73 |
class SuperPixDGL(torch.utils.data.Dataset):
    """One split of a superpixel dataset (MNIST or CIFAR10) as DGL graphs.

    The pickled superpixel data is read in the constructor and converted to
    graphs right away; ``graph_lists`` and ``graph_labels`` are index-aligned.
    """

    def __init__(self,
                 data_dir,
                 dataset,
                 split,
                 use_mean_px=True,
                 use_coord=True):
        """Load ``<data_dir>/<dataset file for split>`` and build the graphs.

        ``use_mean_px`` selects whether pixel features take part in the
        adjacency computation. ``use_coord`` is stored on the instance only.
        """
        self.split = split

        self.graph_lists = []

        # (image side length, pickle file name template) per known dataset
        if dataset == 'MNIST':
            source = (28, 'mnist_75sp_%s.pkl')
        elif dataset == 'CIFAR10':
            source = (32, 'cifar10_150sp_%s.pkl')
        else:
            source = None

        if source is not None:
            self.img_size, file_template = source
            pkl_path = os.path.join(data_dir, file_template % split)
            with open(pkl_path, 'rb') as f:
                self.labels, self.sp_data = pickle.load(f)
                self.graph_labels = torch.LongTensor(self.labels)

        self.use_mean_px = use_mean_px
        self.use_coord = use_coord
        self.n_samples = len(self.labels)

        self._prepare()

    def _prepare(self):
        """Compute per-sample arrays, then assemble one DGL graph per sample."""
        print("preparing %d graphs for the %s set..." % (self.n_samples, self.split.upper()))
        self.Adj_matrices, self.node_features, self.edges_lists, self.edge_features = [], [], [], []

        # Pass 1: adjacency, neighbour lists and feature arrays.
        for sample in self.sp_data:
            mean_px, coord = sample[:2]

            try:
                coord = coord / self.img_size
            except AttributeError:
                # No fixed image size is set: keep the coordinates unscaled.
                pass

            # Locations + features when use_mean_px is set, locations only otherwise.
            A = compute_adjacency_matrix_images(coord, mean_px, bool(self.use_mean_px))
            edges_list, edge_values_list = compute_edges_list(A)

            N_nodes = A.shape[0]
            node_feat = np.concatenate(
                (mean_px.reshape(N_nodes, -1), coord.reshape(N_nodes, 2)), axis=1)

            self.node_features.append(node_feat)
            # Flattened row by row, i.e. in the order the edges are added below.
            self.edge_features.append(edge_values_list.reshape(-1))
            self.Adj_matrices.append(A)
            self.edges_lists.append(edges_list)

        # Pass 2: build the graphs from the cached arrays.
        for node_feat, neighbours, edge_feat in zip(
                self.node_features, self.edges_lists, self.edge_features):
            n_nodes = node_feat.shape[0]
            g = dgl.DGLGraph()
            g.add_nodes(n_nodes)
            g.ndata['feat'] = torch.Tensor(node_feat).half()

            # With a single node the self loop is the only possible edge, so it is kept.
            single_node = n_nodes == 1
            for src, dsts in enumerate(neighbours):
                g.add_edges(src, dsts if single_node else dsts[dsts != src])

            # One similarity value per edge, stored as a column.
            g.edata['feat'] = torch.Tensor(edge_feat).unsqueeze(1).half()

            self.graph_lists.append(g)

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return self.n_samples

    def __getitem__(self, idx):
        """
            Get the idx^th sample.
            Parameters
            ---------
            idx : int
                The sample index.
            Returns
            -------
            (dgl.DGLGraph, int)
                DGLGraph with node feature stored in `feat` field
                And its label.
        """
        return self.graph_lists[idx], self.graph_labels[idx]
171 |
172 |
class DGLFormDataset(torch.utils.data.Dataset):
    """
        Dataset over several index-aligned sequences of equal length.

        *lists: the sequences; the first two are also exposed as
        ``graph_lists`` and ``graph_labels``.
    """
    def __init__(self, *lists):
        # Every sequence must be as long as the first one.
        assert not any(len(seq) != len(lists[0]) for seq in lists)
        self.lists = lists
        self.graph_lists = lists[0]
        self.graph_labels = lists[1]

    def __getitem__(self, index):
        """Return the ``index``-th element of every sequence as a tuple."""
        return tuple([seq[index] for seq in self.lists])

    def __len__(self):
        """Return the length of the first sequence."""
        first = self.lists[0]
        return len(first)
189 |
190 |
class SuperPixDatasetDGL(torch.utils.data.Dataset):
    def __init__(self, name, num_val=5000):
        """
            Build the superpixel graph splits for an image dataset name
            (MNIST/CIFAR10).

            The test and train splits are built with SuperPixDGL from the
            superpixel .pkl files provided by
            https://github.com/bknyaz/graph_attention_pool

            The first ``num_val`` training samples become the validation
            split; the rest stay as the training split.
        """
        t_data = time.time()
        self.name = name

        # The adjacency is currently defined from super-pixel locations only;
        # set this to True to use locations + features instead.
        use_mean_px = False
        use_coord = True
        if not use_mean_px:
            print('Adj matrix defined from super-pixel locations (only)')
        else:
            print('Adj matrix defined from super-pixel locations + features')

        data_root = "./data/superpixels"
        build_options = {'use_mean_px': use_mean_px, 'use_coord': use_coord}
        self.test = SuperPixDGL(data_root, dataset=self.name, split='test', **build_options)
        self.train_ = SuperPixDGL(data_root, dataset=self.name, split='train', **build_options)

        # Slicing SuperPixDGL yields a (graphs, labels) pair.
        held_out_graphs, held_out_labels = self.train_[:num_val]
        kept_graphs, kept_labels = self.train_[num_val:]

        self.val = DGLFormDataset(held_out_graphs, held_out_labels)
        self.train = DGLFormDataset(kept_graphs, kept_labels)

        print("[I] Data load time: {:.4f}s".format(time.time()-t_data))
229 |
230 |
231 |
def self_loop(g):
    """
        Return a copy of ``g`` with exactly one self loop on every node.

        Stands in for dgl.transform.add_self_loop() so that ndata['feat'] is
        carried over and edata['feat'] exists on the result. Existing self
        loops are dropped before the new ones are appended.

        This function is called inside a function in SuperPixDataset class.
    """
    n_nodes = g.number_of_nodes()

    looped = dgl.DGLGraph()
    looped.add_nodes(n_nodes)
    looped.ndata['feat'] = g.ndata['feat']

    # Edge endpoints in edge-id order, as numpy arrays.
    src, dst = (dgl.backend.zerocopy_to_numpy(ends) for ends in g.all_edges(order="eid"))
    keep = src != dst
    looped.add_edges(src[keep], dst[keep])

    node_ids = np.arange(n_nodes)
    looped.add_edges(node_ids, node_ids)

    # This new edata is not used since this function gets called only for GCN, GAT
    # However, we need this for the generic requirement of ndata and edata
    looped.edata['feat'] = torch.zeros(looped.number_of_edges())
    return looped
256 |
def positional_encoding(g, pos_enc_dim):
    """
        Graph positional encoding via Laplacian eigenvectors.

        Stores, in g.ndata['pos_enc'], the eigenvectors of the symmetrically
        normalised Laplacian at sorted positions 1..pos_enc_dim (the one with
        the smallest eigenvalue is skipped). Modifies ``g`` in place and
        returns it.
    """
    n_nodes = g.number_of_nodes()

    # Laplacian: L = I - D^{-1/2} A D^{-1/2}, degrees clipped to at least 1
    adjacency = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
    degrees = dgl.backend.asnumpy(g.in_degrees()).clip(1)
    inv_sqrt_deg = sp.diags(degrees ** -0.5, dtype=float)
    # `*` between scipy sparse matrices is a matrix product
    laplacian = sp.eye(n_nodes) - inv_sqrt_deg * adjacency * inv_sqrt_deg

    # Dense eigendecomposition with numpy, eigenvalues in increasing order
    eig_val, eig_vec = np.linalg.eig(laplacian.toarray())
    ascending = eig_val.argsort()
    eig_vec = np.real(eig_vec[:, ascending])
    g.ndata['pos_enc'] = torch.from_numpy(eig_vec[:, 1:pos_enc_dim + 1]).float()

    return g
279 |
class SuperPixDataset(torch.utils.data.Dataset):
    """Pre-built superpixel graph dataset read from ``data/superpixels/<name>.pkl``.

    The pickle is indexed as ``[train, val, test]``. The ``_add_*`` helpers
    read ``graph_lists`` / ``graph_labels`` from each split.
    """

    def __init__(self, name):
        """
            Loading Superpixels datasets
        """
        start = time.time()
        print("[I] Loading dataset %s..." % (name))
        self.name = name
        data_dir = 'data/superpixels/'
        # NOTE: pickle can execute arbitrary code while loading; only open
        # dataset files that come from a trusted source.
        with open(data_dir+name+'.pkl',"rb") as f:
            splits = pickle.load(f)  # kept separate from the file handle
        self.train = splits[0]
        self.val = splits[1]
        self.test = splits[2]
        print('train, test, val sizes :',len(self.train),len(self.test),len(self.val))
        print("[I] Finished loading.")
        print("[I] Data load time: {:.4f}s".format(time.time()-start))

    # form a mini batch from a given list of samples = [(graph, label) pairs]
    def collate(self, samples):
        """Batch ``(graph, label)`` pairs into ``(batched_graph, labels)``.

        Node and edge 'feat' tensors are cast to float32 in place on each
        graph before batching with ``dgl.batch``.
        """
        # The input samples is a list of pairs (graph, label).
        graphs, labels = map(list, zip(*samples))
        labels = torch.tensor(np.array(labels))
        for graph in graphs:
            graph.ndata['feat'] = graph.ndata['feat'].float()
            graph.edata['feat'] = graph.edata['feat'].float()
        batched_graph = dgl.batch(graphs)

        return batched_graph, labels

    # prepare dense tensors for GNNs using them; such as RingGNN, 3WLGNN
    def collate_dense_gnn(self, samples):
        """Build the dense input tensor for the first graph in ``samples``.

        Only ``graphs[0]`` is converted, while the labels of all samples are
        returned. The result has shape ``(1, 1 + d_n, n, n)``.
        """
        # The input samples is a list of pairs (graph, label).
        graphs, labels = map(list, zip(*samples))
        labels = torch.tensor(np.array(labels))

        g = graphs[0]
        adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense())
        # Adapted from https://github.com/leichen2018/Ring-GNN/
        # Assigning node feats:
        #   with the adjacency matrix in R^{n x n} and node features in R^{d_n},
        #   a tensor T in R^{(1 + d_n) x n x n} is built here.
        #   T[0, :, :] is the normalised adjacency matrix and the diagonal
        #   T[1:1+d_n, i, i], i = 0 to n-1, stores the node feature of node i.
        #   Edge features are not written into the tensor by this method.

        zero_adj = torch.zeros_like(adj)

        in_dim = g.ndata['feat'].shape[1]

        # use node feats to prepare adj
        adj_node_feat = torch.stack([zero_adj for j in range(in_dim)])
        adj_node_feat = torch.cat([adj.unsqueeze(0), adj_node_feat], dim=0)

        for node, node_feat in enumerate(g.ndata['feat']):
            adj_node_feat[1:, node, node] = node_feat

        x_node_feat = adj_node_feat.unsqueeze(0)

        return x_node_feat, labels

    def _sym_normalize_adj(self, adj):
        """Return ``D^{-1/2} adj D^{-1/2}``, with ``D`` the column sums of ``adj``.

        Nodes with a non-positive degree get a zero scaling factor instead of
        a division by zero.
        """
        deg = torch.sum(adj, dim = 0)
        inv_sqrt = 1./torch.sqrt(deg)
        # zeros_like keeps dtype and device equal to the other branch; a
        # default-dtype CPU tensor would not match a float64 or GPU adjacency.
        deg_inv = torch.where(deg>0, inv_sqrt, torch.zeros_like(inv_sqrt))
        deg_inv = torch.diag(deg_inv)
        return torch.mm(deg_inv, torch.mm(adj, deg_inv))

    def _add_self_loops(self):
        """Replace every graph by its ``self_loop`` version and rebuild the splits."""
        # function for adding self loops
        # this function will be called only if self_loop flag is True

        self.train.graph_lists = [self_loop(g) for g in self.train.graph_lists]
        self.val.graph_lists = [self_loop(g) for g in self.val.graph_lists]
        self.test.graph_lists = [self_loop(g) for g in self.test.graph_lists]

        # The graphs are new objects, so the split wrappers are rebuilt.
        self.train = DGLFormDataset(self.train.graph_lists, self.train.graph_labels)
        self.val = DGLFormDataset(self.val.graph_lists, self.val.graph_labels)
        self.test = DGLFormDataset(self.test.graph_lists, self.test.graph_labels)

    def _add_positional_encodings(self, pos_enc_dim):
        """Attach Laplacian positional encodings to every graph of every split."""
        # Graph positional encoding via Laplacian eigenvectors
        self.train.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.train.graph_lists]
        self.val.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.val.graph_lists]
        self.test.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.test.graph_lists]
375 |
376 |
377 | def _add_positional_encodings(self, pos_enc_dim):
378 |
379 | # Graph positional encoding v/ Laplacian eigenvectors
380 | self.train.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.train.graph_lists]
381 | self.val.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.val.graph_lists]
382 | self.test.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.test.graph_lists]
383 |
if __name__ == '__main__':
    # Smoke run: load each pre-built dataset and print its first training sample.
    for dataset_name in ('MNIST', 'CIFAR10'):
        ds = SuperPixDataset(dataset_name)
        print(ds.train[0])
388 | print(ds.train.__getitem__(0))
389 |
390 |
--------------------------------------------------------------------------------
/environment_gpu.yml:
--------------------------------------------------------------------------------
1 | name: gnasmp
2 | channels:
3 | - defaults
4 | - dglteam
5 | dependencies:
6 | - python=3.6
7 | - cudatoolkit=11.0
8 | - dgl-cuda11.0
9 | - pygraphviz
10 | - pip:
11 | - ipdb
12 | - tqdm
13 | - torch==1.5.0
14 | - scikit-learn
15 | - tensorboardX
16 | - rich
17 |
--------------------------------------------------------------------------------
/example_geno.yaml:
--------------------------------------------------------------------------------
1 | Genotype:
2 | -
3 | id: 1
4 | topology:
5 | -
6 | src: 0
7 | dst: 1
8 | ops: 'V_Max'
9 | -
10 | src: 1
11 | dst: 2
12 | ops: 'V_Sum'
13 | -
14 | src: 2
15 | dst: 3
16 | ops: 'V_Sum'
17 | -
18 | src: 3
19 | dst: 4
20 | ops: 'V_Max'
21 | -
22 | id: 2
23 | topology:
24 | -
25 | src: 0
26 | dst: 1
27 | ops: 'V_Max'
28 | -
29 | src: 1
30 | dst: 2
31 | ops: 'V_Sum'
32 | -
33 | src: 2
34 | dst: 3
35 | ops: 'V_Sum'
36 | -
37 | src: 3
38 | dst: 4
39 | ops: 'V_Max'
40 | -
41 | id: 3
42 | topology:
43 | -
44 | src: 0
45 | dst: 1
46 | ops: 'V_Max'
47 | -
48 | src: 1
49 | dst: 2
50 | ops: 'V_Sum'
51 | -
52 | src: 2
53 | dst: 3
54 | ops: 'V_Sum'
55 | -
56 | src: 3
57 | dst: 4
58 | ops: 'V_Max'
59 | -
60 | id: 4
61 | topology:
62 | -
63 | src: 0
64 | dst: 1
65 | ops: 'V_Max'
66 | -
67 | src: 1
68 | dst: 2
69 | ops: 'V_Sum'
70 | -
71 | src: 2
72 | dst: 3
73 | ops: 'V_Sum'
74 | -
75 | src: 3
76 | dst: 4
77 | ops: 'V_Max'
--------------------------------------------------------------------------------
/models/architect.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 |
6 |
7 | def _concat(xs):
8 | return torch.cat([x.view(-1) for x in xs])
9 |
10 |
class Architect(object):
    """Updates the architecture parameters of a search model.

    The wrapped ``model`` must provide ``arch_parameters()``, ``parameters()``,
    ``named_parameters()``, ``state_dict()``, ``new()`` and ``_loss(input, target)``.
    Architecture parameters are optimised with Adam, either from a plain
    backward pass or from an unrolled (one virtual weight step) approximation.
    """

    def __init__(self, model, args):
        """Store the model and build the Adam optimizer for its architecture parameters.

        Reads ``args.momentum``, ``args.weight_decay``, ``args.arch_lr`` and
        ``args.arch_weight_decay``.
        """
        self.args = args
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        self.optimizer = torch.optim.Adam(
            params=self.model.arch_parameters(),
            lr=args.arch_lr,
            betas=(0.5, 0.999),
            weight_decay=args.arch_weight_decay,
        )

    def _compute_unrolled_model(self, input, target, eta, network_optimizer):
        """Return a copy of the model after one virtual weight update.

        The virtual step is ``theta - eta * (momentum * buf + grad + wd * theta)``,
        where ``buf`` is the momentum buffer kept by ``network_optimizer`` (zeros
        when no buffer exists yet).
        """
        loss = self.model._loss(input, target)
        params = list(self.model.parameters())
        theta = _concat(params).data
        try:
            moment = _concat(
                network_optimizer.state[v]['momentum_buffer'] for v in params
            ).mul_(self.network_momentum)
        except (KeyError, AttributeError):
            # No momentum buffer has been created yet (or the optimizer keeps
            # none / stores None): fall back to zero momentum.
            moment = torch.zeros_like(theta)
        # allow_unused: a parameter that takes no part in the forward pass gets
        # ``None`` here; its gradient is zero by definition.
        grads = torch.autograd.grad(outputs=loss, inputs=params, allow_unused=True)
        dtheta = _concat(
            g if g is not None else torch.zeros_like(p)
            for g, p in zip(grads, params)
        ).data + self.network_weight_decay * theta
        unrolled_model = self._construct_model_from_theta(
            theta.sub(moment + dtheta, alpha=eta))
        return unrolled_model

    def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled):
        """Run one optimisation step on the architecture parameters.

        With ``unrolled`` true the unrolled gradient is used. Otherwise
        ``args.search_mode`` selects the batch: ``'darts_1'`` uses the validation
        batch and ``'train'`` the training batch. Any other mode computes no
        gradient before the optimizer step.
        """
        self.optimizer.zero_grad()
        if unrolled:
            self._backward_step_unrolled(
                input_train, target_train, input_valid, target_valid, eta, network_optimizer)
        else:
            if self.args.search_mode == 'darts_1':
                self._backward_step(input_valid, target_valid)
            elif self.args.search_mode == 'train':
                self._backward_step(input_train, target_train)
        self.optimizer.step()

    def _backward_step(self, input_valid, target_valid):
        """Back-propagate the model loss on the given batch."""
        loss = self.model._loss(input_valid, target_valid)
        loss.backward()

    def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer):
        """Write the unrolled architecture gradient into ``arch_parameters().grad``."""
        unrolled_model = self._compute_unrolled_model(
            input_train, target_train, eta, network_optimizer)
        unrolled_loss = unrolled_model._loss(input_valid, target_valid)

        unrolled_loss.backward()
        dalpha = [v.grad for v in unrolled_model.arch_parameters()]
        # Parameters untouched by the forward pass have no .grad; treat as zero.
        vector = [
            v.grad.data if v.grad is not None else torch.zeros_like(v.data)
            for v in unrolled_model.parameters()
        ]
        implicit_grads = self._hessian_vector_product(vector, input_train, target_train)

        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(ig.data, alpha=eta)

        # Copy the corrected gradients onto the real model's architecture parameters.
        for v, g in zip(self.model.arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = Variable(g.data)
            else:
                v.grad.data.copy_(g.data)

    def _construct_model_from_theta(self, theta):
        """Build a fresh model (via ``model.new()``) whose weights are read from ``theta``.

        ``theta`` is a flat tensor laid out in ``named_parameters()`` order.
        """
        model_new = self.model.new()
        model_dict = self.model.state_dict()

        params, offset = {}, 0
        for k, v in self.model.named_parameters():
            v_length = v.numel()
            params[k] = theta[offset: offset + v_length].view(v.size())
            offset += v_length

        assert offset == len(theta)
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        if not model_new.args.disable_cuda:
            # Rebind the local name only; assigning to ``model_new.cuda`` would
            # register the model as a submodule of itself.
            model_new = model_new.cuda()
        return model_new

    def _hessian_vector_product(self, vector, input, target, r=1e-2):
        """Finite-difference estimate of d/d(alpha) of (grad_w L . vector).

        Weights are perturbed by ``+R*vector`` and ``-R*vector`` (``R = r/||vector||``)
        and restored before returning. Returns one tensor per architecture parameter.
        """
        norm = _concat(vector).norm().item()
        if norm == 0:
            # A zero direction has a zero Hessian-vector product; avoid dividing by zero.
            return [torch.zeros_like(a) for a in self.model.arch_parameters()]
        R = r / norm

        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)
        loss = self.model._loss(input, target)
        grads_p = torch.autograd.grad(loss, self.model.arch_parameters())

        for p, v in zip(self.model.parameters(), vector):
            p.data.sub_(v, alpha=2 * R)
        loss = self.model._loss(input, target)
        grads_n = torch.autograd.grad(loss, self.model.arch_parameters())

        # Undo the perturbation so the weights are back at their original values.
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)

        return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
108 |
--------------------------------------------------------------------------------
/models/cell_search.py:
--------------------------------------------------------------------------------
1 | import dgl
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from models.operations import OPS
7 | from models.mixed import Mixed
8 |
9 | '''
10 | cell_arch :
11 | topology: list
12 | (src, dst, weights, ops)
13 | '''
14 |
class Cell(nn.Module):
    """Searchable cell: every edge of ``cell_arch`` carries a ``Mixed`` operation.

    ``cell_arch`` is an iterable of ``(src, dst, w, ops)`` tuples, where ``w``
    indexes the architecture weight used on that edge.
    """

    def __init__(self, args, cell_arch):
        super().__init__()
        self.args = args
        self.cell_arch = cell_arch
        self.nb_nodes = 3 * args.nb_nodes  #! warning
        self.trans_concat_V = nn.Linear(self.nb_nodes * args.node_dim, args.node_dim, bias=True)
        self.batchnorm_V = nn.BatchNorm1d(args.node_dim)
        self.activate = nn.LeakyReLU(args.leaky_slope)
        self.load_arch()

    def load_arch(self):
        """Index incoming edges by destination and create one ``Mixed`` per edge."""
        incoming = {}
        edge_ops = {}
        for src, dst, w, ops in self.cell_arch:
            incoming.setdefault(dst, []).append((src, w))
            edge_ops[str((src, dst))] = Mixed(self.args, ops)

        self.link_dict = incoming
        self.link_para = nn.ModuleDict(edge_ops)

    def forward(self, input, weight):
        """Evaluate the cell.

        ``input`` is a dict with keys ``'G'`` and ``'V'``; ``weight`` is indexed
        by each edge's ``w``. Returns a dict with the same ``'G'`` and the new ``'V'``.
        """
        G = input['G']
        V_in = input['V']
        states = [V_in]
        for dst in range(1, self.nb_nodes + 1):
            # A node state is the sum of the mixed outputs of all incoming edges.
            incoming = [
                self.link_para[str((src, dst))](
                    {'G': G, 'V': states[src], 'V_in': V_in}, weight[w])
                for src, w in self.link_dict[dst]
            ]
            states.append(sum(incoming))

        # Concatenate every intermediate state and project back to node_dim.
        V = torch.cat(states[1:], dim=1)
        V = self.trans_concat_V(V)

        if self.batchnorm_V:
            V = self.batchnorm_V(V)

        V = self.activate(V)
        V = F.dropout(V, self.args.dropout, training=self.training)
        # Residual connection with the cell input.
        return {'G': G, 'V': V + V_in}
62 |
63 |
--------------------------------------------------------------------------------
/models/cell_train.py:
--------------------------------------------------------------------------------
1 | import dgl
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from models.operations import V_Package, OPS
7 |
8 |
class Cell(nn.Module):
    """Fixed cell built from a genotype dict with a ``'topology'`` edge list.

    Each edge is a dict with keys ``'src'``, ``'dst'`` and ``'ops'`` (a key of ``OPS``).
    """

    def __init__(self, args, genotype):
        super().__init__()
        self.args = args
        self.genotype = genotype
        self.nb_nodes = args.nb_nodes
        self.trans_concat_V = nn.Linear(args.nb_nodes * args.node_dim, args.node_dim, bias=True)
        self.batchnorm_V = nn.BatchNorm1d(args.node_dim)
        self.activate = nn.LeakyReLU(args.leaky_slope)
        self.load_genotype()

    def load_genotype(self):
        """Group edges by destination node (string key) and build their operations."""
        sources = {}
        packaged = {}
        for edge in self.genotype['topology']:
            src, ops = edge['src'], edge['ops']
            key = f"{edge['dst']}"

            sources.setdefault(key, []).append(src)

            if key not in packaged:
                packaged[key] = nn.ModuleList([])
            packaged[key].append(V_Package(self.args, OPS[ops](self.args)))

        self.link_dict = sources
        self.module_dict = nn.ModuleDict(packaged)

    def forward(self, input):
        """Evaluate the cell on ``input`` (dict with ``'G'`` and ``'V'``)."""
        G = input['G']
        V_in = input['V']
        states = [V_in]
        for node in range(1, self.nb_nodes + 1):
            key = f'{node}'
            # The i-th source of a node is paired with the i-th module stored for it.
            agg = [
                op({'G': G, 'V': states[src], 'V_in': V_in})
                for src, op in zip(self.link_dict[key], self.module_dict[key])
            ]
            states.append(sum(agg))

        # Concatenate every intermediate state and project back to node_dim.
        V = torch.cat(states[1:], dim=1)
        V = self.trans_concat_V(V)

        if self.batchnorm_V:
            V = self.batchnorm_V(V)

        V = self.activate(V)
        V = F.dropout(V, self.args.dropout, training=self.training)
        # Residual connection with the cell input.
        return {'G': G, 'V': V + V_in}
64 |
65 |
if __name__ == '__main__':
    # Smoke test: build a Cell from the first genotype in example_geno.yaml.
    import yaml
    from easydict import EasyDict as edict

    # safe_load: plain-data YAML needs no arbitrary object construction, and
    # yaml.load() without an explicit Loader is rejected by recent PyYAML.
    # The context manager closes the file handle.
    with open('example_geno.yaml', 'r') as f:
        geno = yaml.safe_load(f)
    geno = geno['Genotype'][0]
    args = edict({
        'nb_nodes': 4,
        'node_dim': 50,
        'leaky_slope': 0.2,
        'batchnorm_op': True,
    })
    cell = Cell(args, geno)
78 |
--------------------------------------------------------------------------------
/models/mixed.py:
--------------------------------------------------------------------------------
1 | import dgl
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from models.operations import V_Package, OPS
7 |
8 |
class Mixed(nn.Module):
    """Softmax-weighted sum of the candidate operations named in ``ops``."""

    def __init__(self, args, ops):
        super().__init__()
        self.args = args
        # NOTE(review): this stores the builtin ``type``; looks like a leftover,
        # kept so the attribute set is unchanged.
        self.type = type
        self.ops = ops
        packaged = {}
        for name in self.ops:
            packaged[name] = V_Package(args, OPS[name](args))
        self.candidates = nn.ModuleDict(packaged)

    def forward(self, input, weight):
        '''
        weight: raw (pre-softmax) scores, indexed by the position of each
                operation in `self.ops`; softmax is taken over dim 0
        '''
        probs = weight.softmax(0)
        mixed = 0
        for idx, name in enumerate(self.ops):
            mixed = mixed + probs[idx] * self.candidates[name](input)
        return mixed
30 |
31 |
--------------------------------------------------------------------------------
/models/model_search.py:
--------------------------------------------------------------------------------
1 | import dgl
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from models.cell_search import Cell
7 | from models.operations import OPS, First_Stage, Second_Stage, Third_Stage
8 | from models.networks import MLP
9 | from data import TransInput, TransOutput, get_trans_input
10 |
11 |
class Model_Search(nn.Module):
    """Search network: a stack of searchable cells sharing one flat list of
    architecture parameters (``nb_cell_topo`` consecutive entries per layer).

    The architecture parameters are plain tensors created on the CUDA device;
    they are exposed through ``arch_parameters()``.
    """

    def __init__(self, args, trans_input_fn, loss_fn):
        super().__init__()
        self.args = args
        self.nb_layers = args.nb_layers
        # Topology first, then its parameters, then the cells that consume both.
        self.cell_arch_topo = self.load_cell_arch()
        self.cell_arch_para = self.init_cell_arch_para()
        self.cells = nn.ModuleList(
            [Cell(args, self.cell_arch_topo[i]) for i in range(self.nb_layers)])
        self.loss_fn = loss_fn
        self.trans_input_fn = trans_input_fn
        self.trans_input = TransInput(trans_input_fn)
        self.trans_output = TransOutput(args)
        if args.pos_encode > 0:
            self.position_encoding = nn.Linear(args.pos_encode, args.node_dim)

    def forward(self, input):
        """Run the input transform, every cell (with its layer's weights) and the output transform."""
        per_layer_paras = self.group_arch_parameters()
        input = self.trans_input(input)
        G = input['G']
        V = input['V']
        if self.args.pos_encode > 0:
            # Reads the 'pos_enc' node data stored on G and adds its projection to V.
            pos_enc = G.ndata['pos_enc'].float().cuda()
            V = V + self.position_encoding(pos_enc)
        output = {'G': G, 'V': V}
        for layer_idx, cell in enumerate(self.cells):
            output = cell(output, per_layer_paras[layer_idx])
        return self.trans_output(output)

    def load_cell_arch(self):
        """Return one topology (list of edges) per layer."""
        return [self.load_cell_arch_by_layer() for _ in range(self.nb_layers)]

    def load_cell_arch_by_layer(self):
        """Build the edge list ``(src, dst, w, ops)`` of a single cell.

        With ``n = args.nb_nodes``: nodes ``1..n`` receive an edge from every
        earlier node (First_Stage ops), node ``n+k`` receives one edge from node
        ``k`` (Second_Stage ops), and nodes ``2n+1..3n`` receive an edge from
        every node in ``n+1..2n`` (Third_Stage ops). ``w`` numbers the edges
        consecutively from 0.
        """
        n = self.args.nb_nodes
        edges = [(src, dst, First_Stage)
                 for dst in range(1, n + 1)
                 for src in range(dst)]
        edges += [(dst - n, dst, Second_Stage)
                  for dst in range(n + 1, 2 * n + 1)]
        edges += [(src, dst, Third_Stage)
                  for dst in range(2 * n + 1, 3 * n + 1)
                  for src in range(n + 1, 2 * n + 1)]
        return [(src, dst, w, ops) for w, (src, dst, ops) in enumerate(edges)]

    def init_cell_arch_para(self):
        """Create the flat list of architecture parameters for all layers.

        Also records ``self.nb_cell_topo``, the number of parameters per layer.
        """
        flat_paras = []
        for layer_idx in range(self.nb_layers):
            layer_paras = self.init_arch_para(self.cell_arch_topo[layer_idx])
            flat_paras.extend(layer_paras)
            self.nb_cell_topo = len(layer_paras)
        return flat_paras

    def init_arch_para(self, arch_topo):
        """Return one small random CUDA tensor (one entry per candidate op) per edge."""
        return [
            Variable(1e-3 * torch.rand(len(ops)).cuda(), requires_grad=True)
            for _, _, _, ops in arch_topo
        ]

    def group_arch_parameters(self):
        """Split the flat parameter list into one slice per layer."""
        per_layer = self.nb_cell_topo
        return [
            self.arch_parameters()[i * per_layer:(i + 1) * per_layer]
            for i in range(self.nb_layers)
        ]

    def new(self):
        """Return a freshly built CUDA model carrying a copy of the architecture parameters."""
        model_new = Model_Search(self.args, get_trans_input(self.args), self.loss_fn).cuda()
        for dst_para, src_para in zip(model_new.arch_parameters(), self.arch_parameters()):
            dst_para.data.copy_(src_para.data)
        return model_new

    def load_alpha(self, alphas):
        """Copy the values of ``alphas`` into this model's architecture parameters."""
        for own, given in zip(self.arch_parameters(), alphas):
            own.data.copy_(given.data)

    def arch_parameters(self):
        """Return the flat list of architecture parameters."""
        return self.cell_arch_para

    def _loss(self, input, targets):
        """Forward ``input`` and evaluate ``loss_fn`` against ``targets``."""
        return self.loss_fn(self.forward(input), targets)
--------------------------------------------------------------------------------
/models/model_train.py:
--------------------------------------------------------------------------------
1 | import dgl
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from models.cell_train import Cell
7 | from models.operations import OPS
8 | from models.networks import MLP
9 | from data import TransInput, TransOutput
10 |
11 |
class Model_Train(nn.Module):
    """Network with a fixed architecture: one ``Cell`` per entry of ``genotypes``."""

    def __init__(self, args, genotypes, trans_input_fn, loss_fn):
        super().__init__()
        self.args = args
        self.nb_layers = args.nb_layers
        self.genotypes = genotypes
        self.cells = nn.ModuleList(
            [Cell(args, genotypes[layer]) for layer in range(self.nb_layers)])
        self.loss_fn = loss_fn
        self.trans_input = TransInput(trans_input_fn)
        self.trans_output = TransOutput(args)
        if args.pos_encode > 0:
            self.position_encoding = nn.Linear(args.pos_encode, args.node_dim)

    def forward(self, input):
        """Run the input transform, every cell in order and the output transform."""
        input = self.trans_input(input)
        G = input['G']
        V = input['V']
        if self.args.pos_encode > 0:
            # Reads the 'pos_enc' node data stored on G and adds its projection to V.
            pos_enc = G.ndata['pos_enc'].float().to("cuda")
            V = V + self.position_encoding(pos_enc)
        output = {'G': G, 'V': V}
        for cell in self.cells:
            output = cell(output)
        return self.trans_output(output)

    def _loss(self, input, targets):
        """Forward ``input`` and evaluate ``loss_fn`` against ``targets``."""
        return self.loss_fn(self.forward(input), targets)
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
class MLP(nn.Module):
    """Stack of linear layers with a ReLU after every layer except the last.

    ``channel_sequence`` lists the layer widths, input first; consecutive
    entries define one ``nn.Linear`` each (named ``fc{i}`` / ``ReLU{i}``).
    """

    def __init__(self, channel_sequence):
        super().__init__()
        self.seq = nn.Sequential()
        last_idx = len(channel_sequence) - 2
        for i in range(last_idx + 1):
            width_in, width_out = channel_sequence[i], channel_sequence[i + 1]
            self.seq.add_module(f"fc{i}", nn.Linear(width_in, width_out))
            if i < last_idx:
                self.seq.add_module(f"ReLU{i}", nn.ReLU(inplace=True))

    def forward(self, x):
        return self.seq(x)
18 |
19 |
20 |
--------------------------------------------------------------------------------
/models/operations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dgl.function as fn
3 | import torch.nn as nn
4 | import numpy as np
5 | # from models.networks import *
6 |
7 |
# Registry: operation name -> factory building that operation module from `args`.
# The lambdas look the class up at call time, so the classes can be defined
# further down in this module.
OPS = {
    'V_None' : lambda args: V_None(args),
    'V_I' : lambda args: V_I(args),
    'V_Max' : lambda args: V_Max(args),
    'V_Mean' : lambda args: V_Mean(args),
    'V_Min' : lambda args: V_Min(args),
    'V_Sum' : lambda args: V_Sum(args),
    'V_Sparse': lambda args: V_Sparse(args),
    'V_Dense' : lambda args: V_Dense(args),
}


# Candidate operation names (keys of OPS) offered on the edges of each stage.
First_Stage = ['V_None', 'V_I', 'V_Sparse', 'V_Dense']
Second_Stage = ['V_I', 'V_Mean', 'V_Sum', 'V_Max']
Third_Stage = ['V_None', 'V_I', 'V_Sparse', 'V_Dense']
23 |
24 |
class V_Package(nn.Module):
    """Wraps an operation and, unless it is ``V_None`` or ``V_I``, follows it
    with Linear -> (optional BatchNorm1d) -> ReLU."""

    def __init__(self, args, operation):
        super().__init__()
        self.args = args
        self.operation = operation
        self.seq = None
        if type(operation) not in [V_None, V_I]:
            post = nn.Sequential()
            post.add_module('fc_bn', nn.Linear(args.node_dim, args.node_dim, bias=True))
            if args.batchnorm_op:
                post.add_module('bn', nn.BatchNorm1d(self.args.node_dim))
            post.add_module('act', nn.ReLU())
            self.seq = post

    def forward(self, input):
        """Apply the wrapped operation, then the post-processing stack if present."""
        out = self.operation(input)
        if self.seq:
            out = self.seq(out)
        return out
47 |
48 |
class NodePooling(nn.Module):
    """Linear map (node_dim -> node_dim) followed by ReLU."""

    def __init__(self, args):
        super().__init__()
        self.A = nn.Linear(args.node_dim, args.node_dim)
        self.activate = nn.ReLU()

    def forward(self, V):
        return self.activate(self.A(V))
62 |
63 |
class V_None(nn.Module):
    """'No connection' operation: returns the node features scaled by zero."""

    def __init__(self, args):
        super().__init__()

    def forward(self, input):
        features = input['V']
        return 0.0 * features
72 |
73 |
class V_I(nn.Module):
    """Identity operation: passes the node features through unchanged."""

    def __init__(self, args):
        super().__init__()

    def forward(self, input):
        return input['V']
82 |
83 |
class V_Max(nn.Module):
    """Max aggregation of pooled node features over incoming messages."""

    def __init__(self, args):
        super().__init__()
        self.pooling = NodePooling(args)

    def forward(self, input):
        graph = input['G']
        # Overwrites the graph's 'V' node data: pooled features go in,
        # the max-reduced messages come out under the same key.
        graph.ndata['V'] = self.pooling(input['V'])
        graph.update_all(fn.copy_u('V', 'M'), fn.max('M', 'V'))
        return graph.ndata['V']
96 |
97 |
class V_Mean(nn.Module):
    """Mean aggregation of pooled node features over incoming messages."""

    def __init__(self, args):
        super().__init__()
        self.pooling = NodePooling(args)

    def forward(self, input):
        graph = input['G']
        # Overwrites the graph's 'V' node data: pooled features go in,
        # the mean-reduced messages come out under the same key.
        graph.ndata['V'] = self.pooling(input['V'])
        graph.update_all(fn.copy_u('V', 'M'), fn.mean('M', 'V'))
        return graph.ndata['V']
110 |
111 |
class V_Sum(nn.Module):
    """Sum aggregation of the raw node features over incoming messages."""

    def __init__(self, args):
        super().__init__()
        # NOTE(review): `pooling` is constructed but not applied in forward();
        # it still contributes parameters to the module.
        self.pooling = NodePooling(args)

    def forward(self, input):
        graph = input['G']
        # Overwrites the graph's 'V' node data: the unpooled features go in,
        # the sum-reduced messages come out under the same key.
        graph.ndata['V'] = input['V']
        graph.update_all(fn.copy_u('V', 'M'), fn.sum('M', 'V'))
        return graph.ndata['V']
124 |
125 |
class V_Min(nn.Module):
    """Min aggregation of pooled node features over incoming messages."""

    def __init__(self, args):
        super().__init__()
        self.pooling = NodePooling(args)

    def forward(self, input):
        graph = input['G']
        # Overwrites the graph's 'V' node data: pooled features go in,
        # the min-reduced messages come out under the same key.
        graph.ndata['V'] = self.pooling(input['V'])
        graph.update_all(fn.copy_u('V', 'M'), fn.min('M', 'V'))
        return graph.ndata['V']
137 |
138 |
class V_Dense(nn.Module):
    """Per-feature gate: ``sigmoid(W [V, V_in]) * V``."""

    def __init__(self, args):
        super().__init__()
        self.W = nn.Linear(args.node_dim * 2, args.node_dim, bias=True)

    def forward(self, input):
        V = input['V']
        # The gate sees the current state concatenated with the cell input.
        stacked = torch.cat([V, input['V_in']], dim=1)
        return torch.sigmoid(self.W(stacked)) * V
150 |
151 |
class V_Sparse(nn.Module):
    """Single scalar gate per row: ``sigmoid(a(relu(W [V, V_in]))) * V``."""

    def __init__(self, args):
        super().__init__()
        self.W = nn.Linear(args.node_dim * 2, args.node_dim, bias=True)
        self.a = nn.Linear(args.node_dim, 1, bias=False)

    def forward(self, input):
        V = input['V']
        # The gate sees the current state concatenated with the cell input.
        stacked = torch.cat([V, input['V_in']], dim=1)
        hidden = torch.relu(self.W(stacked))
        score = self.a(hidden)
        return torch.sigmoid(score) * V
166 |
167 |
# Running this module directly only prints a marker; it performs no checks.
if __name__ == '__main__':
    print("test")
--------------------------------------------------------------------------------
/scripts/search_molecules_zinc.sh:
--------------------------------------------------------------------------------
1 | DEVICES=$1
2 |
3 | CUDA_VISIBLE_DEVICES=$DEVICES python search.py \
4 | --task 'graph_level' \
5 | --data 'ZINC' \
6 | --data_clip 1.0 \
7 | --in_dim_V 28 \
8 | --batch 64 \
9 | --epochs 40 \
10 | --node_dim 60 \
11 | --nb_layers 12 \
12 | --nb_nodes 3 \
13 | --portion 0.9 \
14 | --dropout 0.0 \
15 | --pos_encode 0 \
16 | --nb_workers 0 \
17 | --report_freq 1 \
18 | --arch_save 'archs/folder5' \
19 | --search_mode 'train' \
20 | --batchnorm_op
21 |
22 |
--------------------------------------------------------------------------------
/scripts/search_sbms_cluster.sh:
--------------------------------------------------------------------------------
1 | DEVICES=$1
2 | CUDA_VISIBLE_DEVICES=$DEVICES python search.py \
3 | --task 'node_level' \
4 | --data 'SBM_CLUSTER' \
5 | --nb_classes 6 \
6 | --data_clip 1.0 \
7 | --in_dim_V 7 \
8 | --batch 32 \
9 | --node_dim 70 \
10 | --pos_encode 0 \
11 | --nb_layers 4 \
12 | --nb_nodes 2 \
13 | --dropout 0.2 \
14 | --portion 0.5 \
15 | --search_mode 'train' \
16 | --nb_workers 0 \
17 | --report_freq 1 \
18 | --arch_save 'archs/folder5' \
19 | --batchnorm_op
20 |
--------------------------------------------------------------------------------
/scripts/search_sbms_pattern.sh:
--------------------------------------------------------------------------------
1 | DEVICES=$1
2 |
3 | CUDA_VISIBLE_DEVICES=$DEVICES python search.py \
4 | --task 'node_level' \
5 | --data 'SBM_PATTERN' \
6 | --nb_classes 2 \
7 | --in_dim_V 3 \
8 | --data_clip 0.5 \
9 | --batch 8 \
10 | --node_dim 50 \
11 | --pos_encode 0 \
12 | --nb_layers 4 \
13 | --nb_nodes 3 \
14 | --dropout 0.0 \
15 | --portion 0.5 \
16 | --nb_workers 0 \
17 | --report_freq 1 \
18 | --search_mode 'train' \
19 | --arch_save 'archs/folder5' \
20 | --batchnorm_op
21 |
--------------------------------------------------------------------------------
/scripts/search_superpixels_cifar10.sh:
--------------------------------------------------------------------------------
1 | DEVICES=$1
2 | # no batchnorm + dropout 0.2
3 | CUDA_VISIBLE_DEVICES=$DEVICES python search.py \
4 | --task 'graph_level' \
5 | --data 'CIFAR10' \
6 | --nb_classes 10 \
7 | --in_dim_V 5 \
8 | --batch 64 \
9 | --epochs 40 \
10 | --node_dim 50 \
11 | --portion 0.5 \
12 | --nb_layers 1 \
13 | --nb_nodes 4 \
14 | --dropout 0.2 \
15 | --nb_workers 0 \
16 | --report_freq 1 \
17 | --search_mode 'train' \
18 | --arch_save 'archs/folder5'
19 |
--------------------------------------------------------------------------------
/scripts/search_superpixels_mnist.sh:
--------------------------------------------------------------------------------
1 | DEVICES=$1
2 | # no batchnorm + dropout 0.2
3 | CUDA_VISIBLE_DEVICES=$DEVICES python search.py \
4 | --task 'graph_level' \
5 | --data 'MNIST' \
6 | --nb_classes 10 \
7 | --in_dim_V 3 \
8 | --batch 64 \
9 | --epochs 40 \
10 | --node_dim 50 \
11 | --portion 0.5 \
12 | --nb_layers 4 \
13 | --nb_nodes 3 \
14 | --dropout 0.2 \
15 | --nb_workers 0 \
16 | --report_freq 1 \
17 | --search_mode 'train' \
18 | --arch_save 'archs/folder5'
19 |
--------------------------------------------------------------------------------
/scripts/train_molecules_zinc.sh:
--------------------------------------------------------------------------------
1 | DEVICES=$1
2 | GENOTYPE=$2
3 | CUDA_VISIBLE_DEVICES=$DEVICES python train.py \
4 | --task 'graph_level' \
5 | --data 'ZINC' \
6 | --in_dim_V 28 \
7 | --batch 128 \
8 | --node_dim 60 \
9 | --dropout 0.0 \
10 | --pos_encode 0 \
11 | --batchnorm_op \
12 | --epochs 200 \
13 | --lr 1e-3 \
14 | --weight_decay 0.0 \
15 | --optimizer 'ADAM' \
16 | --load_genotypes $GENOTYPE
17 |
--------------------------------------------------------------------------------
/scripts/train_sbms_cluster.sh:
--------------------------------------------------------------------------------
1 | DEVICES=$1
2 | GENOTYPE=$2
3 |
4 | CUDA_VISIBLE_DEVICES=$DEVICES python train.py \
5 | --task 'node_level' \
6 | --data 'SBM_CLUSTER' \
7 | --nb_classes 6 \
8 | --in_dim_V 7 \
9 | --pos_encode 0 \
10 | --batch 64 \
11 | --node_dim 70 \
12 | --dropout 0.2 \
13 | --batchnorm_op \
14 | --epochs 200 \
15 | --lr 1e-3 \
16 | --weight_decay 0.0 \
17 | --optimizer 'ADAM' \
18 | --load_genotypes $GENOTYPE
--------------------------------------------------------------------------------
/scripts/train_sbms_pattern.sh:
--------------------------------------------------------------------------------
1 | DEVICES=$1
2 | GENOTYPE=$2
3 |
4 | CUDA_VISIBLE_DEVICES=$DEVICES python train.py \
5 | --task 'node_level' \
6 | --data 'SBM_PATTERN' \
7 | --nb_classes 2 \
8 | --in_dim_V 3 \
9 | --batch 64 \
10 | --node_dim 70 \
11 | --pos_encode 0 \
12 | --dropout 0.2 \
13 | --batchnorm_op \
14 | --epochs 200 \
15 | --lr 1e-3 \
16 | --weight_decay 0.0 \
17 | --optimizer 'ADAM' \
18 | --load_genotypes $GENOTYPE
--------------------------------------------------------------------------------
/scripts/train_superpixels_cifar10.sh:
--------------------------------------------------------------------------------
1 | DEVICES=$1
2 | GENOTYPE=$2
3 | # no batchnorm + dropout 0.2
4 | CUDA_VISIBLE_DEVICES=$DEVICES python train.py \
5 | --task 'graph_level' \
6 | --data 'CIFAR10' \
7 | --nb_classes 10 \
8 | --in_dim_V 5 \
9 | --batch 64 \
10 | --node_dim 50 \
11 | --pos_encode 0 \
12 | --dropout 0.2 \
13 | --epochs 200 \
14 | --lr 1e-3 \
15 | --weight_decay 0.0 \
16 | --optimizer 'ADAM' \
17 | --load_genotypes $GENOTYPE
--------------------------------------------------------------------------------
/scripts/train_superpixels_mnist.sh:
--------------------------------------------------------------------------------
1 | DEVICES=$1
2 | GENOTYPE=$2
3 | # no batchnorm + dropout 0.2
4 | CUDA_VISIBLE_DEVICES=$DEVICES python train.py \
5 | --task 'graph_level' \
6 | --data 'MNIST' \
7 | --nb_classes 10 \
8 | --in_dim_V 3 \
9 | --batch 64 \
10 | --node_dim 50 \
11 | --pos_encode 0 \
12 | --dropout 0.2 \
13 | --epochs 200 \
14 | --lr 1e-3 \
15 | --weight_decay 0.0 \
16 | --optimizer 'ADAM' \
17 | --load_genotypes $GENOTYPE
--------------------------------------------------------------------------------
/search.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import dgl
4 | import yaml
5 | import torch
6 | import argparse
7 | import numpy as np
8 | import torch.backends.cudnn as cudnn
9 | from tqdm import tqdm
10 | from data import *
11 | from models.model_search import *
12 | from utils.utils import *
13 | from models.architect import Architect
14 |
15 |
16 | class Searcher(object):
17 |
18 | def __init__(self, args):
19 |
20 | self.args = args
21 | self.console = Console()
22 |
23 | self.console.log('=> [1] Initial settings')
24 | np.random.seed(args.seed)
25 | torch.manual_seed(args.seed)
26 | torch.cuda.manual_seed(args.seed)
27 | cudnn.benchmark = True
28 | cudnn.enabled = True
29 |
30 | self.console.log('=> [2] Initial models')
31 | self.metric = load_metric(args)
32 | self.loss_fn = get_loss_fn(args).cuda()
33 | self.model = Model_Search(args, get_trans_input(args), self.loss_fn).cuda()
34 | self.console.log(f'=> Supernet Parameters: {count_parameters_in_MB(self.model)}', style = 'bold red')
35 |
36 | self.console.log(f'=> [3] Preparing dataset')
37 | self.dataset = load_data(args)
38 | if args.pos_encode > 0:
39 | #! add positional encoding
40 | self.console.log(f'==> [3.1] Adding positional encodings')
41 | self.dataset._add_positional_encodings(args.pos_encode)
42 | self.search_data = self.dataset.train
43 | self.val_data = self.dataset.val
44 | self.test_data = self.dataset.test
45 | self.load_dataloader()
46 |
47 | self.console.log(f'=> [4] Initial optimizer')
48 | self.optimizer = torch.optim.SGD(
49 | params = self.model.parameters(),
50 | lr = args.lr,
51 | momentum = args.momentum,
52 | weight_decay = args.weight_decay
53 | )
54 |
55 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
56 | optimizer = self.optimizer,
57 | T_max = float(args.epochs),
58 | eta_min = args.lr_min
59 | )
60 |
61 | self.architect = Architect(self.model, self.args)
62 |
63 |
64 | def load_dataloader(self):
65 |
66 | num_search = int(len(self.search_data) * self.args.data_clip)
67 | indices = list(range(num_search))
68 | split = int(np.floor(self.args.portion * num_search))
69 | self.console.log(f'=> Para set size: {split}, Arch set size: {num_search - split}')
70 |
71 | self.para_queue = torch.utils.data.DataLoader(
72 | dataset = self.search_data,
73 | batch_size = self.args.batch,
74 | sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
75 | pin_memory = True,
76 | num_workers = self.args.nb_workers,
77 | collate_fn = self.dataset.collate
78 | )
79 |
80 | self.arch_queue = torch.utils.data.DataLoader(
81 | dataset = self.search_data,
82 | batch_size = self.args.batch,
83 | sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
84 | pin_memory = True,
85 | num_workers = self.args.nb_workers,
86 | collate_fn = self.dataset.collate
87 | )
88 |
89 | num_valid = int(len(self.val_data) * self.args.data_clip)
90 | indices = list(range(num_valid))
91 |
92 | self.val_queue = torch.utils.data.DataLoader(
93 | dataset = self.val_data,
94 | batch_size = self.args.batch,
95 | sampler = torch.utils.data.sampler.SubsetRandomSampler(indices),
96 | pin_memory = True,
97 | num_workers = self.args.nb_workers,
98 | collate_fn = self.dataset.collate
99 | )
100 |
101 | num_test = int(len(self.test_data) * self.args.data_clip)
102 | indices = list(range(num_test))
103 |
104 | self.test_queue = torch.utils.data.DataLoader(
105 | dataset = self.test_data,
106 | batch_size = self.args.batch,
107 | sampler = torch.utils.data.sampler.SubsetRandomSampler(indices),
108 | pin_memory = True,
109 | num_workers = self.args.nb_workers,
110 | collate_fn = self.dataset.collate
111 | )
112 |
113 |
114 | def run(self):
115 |
116 | self.console.log(f'=> [4] Search & Train')
117 | for i_epoch in range(self.args.epochs):
118 | self.scheduler.step()
119 | self.lr = self.scheduler.get_lr()[0]
120 | if i_epoch % self.args.report_freq == 0:
121 | geno = genotypes(
122 | args = self.args,
123 | arch_paras = self.model.group_arch_parameters(),
124 | arch_topos = self.model.cell_arch_topo,
125 | )
126 | with open(f'{self.args.arch_save}/{self.args.data}/{i_epoch}.yaml', "w") as f:
127 | yaml.dump(geno, f)
128 |
129 | # => report genotype
130 | self.console.log( geno )
131 | for i in range(self.args.nb_layers):
132 | for p in self.model.group_arch_parameters()[i]:
133 | self.console.log(p.softmax(0).detach().cpu().numpy())
134 |
135 | search_result = self.search()
136 | self.console.log(f"[green]=> search result [{i_epoch}] - loss: {search_result['loss']:.4f} - metric : {search_result['metric']:.4f}",)
137 | # DecayScheduler().step(i_epoch)
138 |
139 | with torch.no_grad():
140 | val_result = self.infer(self.val_queue)
141 | self.console.log(f"[yellow]=> valid result [{i_epoch}] - loss: {val_result['loss']:.4f} - metric : {val_result['metric']:.4f}")
142 |
143 | test_result = self.infer(self.test_queue)
144 | self.console.log(f"[red]=> test result [{i_epoch}] - loss: {test_result['loss']:.4f} - metric : {test_result['metric']:.4f}")
145 |
146 |
def search(self):
    """Run one epoch of alternating architecture / weight optimisation.

    For every batch drawn from ``self.para_queue`` a batch is also drawn from
    ``self.arch_queue``; the architect takes a step using both, and then the
    model weights are updated on the ``para_queue`` batch.

    Returns:
        dict with the running mean ``'loss'`` and ``'metric'`` over the epoch.
    """
    self.model.train()
    cuda = torch.device('cuda')
    loss_total = 0
    metric_total = 0

    with tqdm(self.para_queue, desc = '=> searching', leave = False) as t:
        for i_step, (batch_graphs, batch_targets) in enumerate(t):
            #! 1. batch used for the weight update
            G = batch_graphs.to(cuda)
            V = batch_graphs.ndata['feat'].to(cuda)
            batch_targets = batch_targets.to(cuda)

            #! 2. batch used for the architecture update
            # NOTE: a fresh iterator is built on every step, so this always
            # takes the first batch of a newly started pass over arch_queue.
            arch_graphs, arch_targets = next(iter(self.arch_queue))
            GS = arch_graphs.to(cuda)
            VS = arch_graphs.ndata['feat'].to(cuda)
            arch_targets = arch_targets.to(cuda)

            #! 3. architecture parameters
            self.architect.step(
                input_train       = {'G': G, 'V': V},
                target_train      = batch_targets,
                input_valid       = {'G': GS, 'V': VS},
                target_valid      = arch_targets,
                eta               = self.lr,
                network_optimizer = self.optimizer,
                unrolled          = self.args.unrolled
            )

            #! 4. model parameters
            self.optimizer.zero_grad()
            batch_scores = self.model({'G': G, 'V': V})
            loss = self.loss_fn(batch_scores, batch_targets)
            loss.backward()
            self.optimizer.step()

            loss_total += loss.detach().item()
            metric_total += self.metric(batch_scores, batch_targets)
            seen = i_step + 1
            t.set_postfix(lr = self.lr,
                          loss = loss_total / seen,
                          metric = metric_total / seen)

    return {'loss'   : loss_total / (i_step + 1),
            'metric' : metric_total / (i_step + 1)}
193 |
194 |
def infer(self, dataloader):
    """Evaluate the model on every batch of ``dataloader``.

    The model is switched to eval mode; no optimiser step is taken.

    Args:
        dataloader: iterable yielding ``(batch_graphs, batch_targets)`` pairs.

    Returns:
        dict with the mean ``'loss'`` and ``'metric'`` over all batches.
    """
    self.model.eval()
    cuda = torch.device('cuda')
    loss_total = 0
    metric_total = 0

    with tqdm(dataloader, desc = '=> inferring', leave = False) as t:
        for i_step, (batch_graphs, batch_targets) in enumerate(t):
            G = batch_graphs.to(cuda)
            V = batch_graphs.ndata['feat'].to(cuda)
            batch_targets = batch_targets.to(cuda)

            batch_scores = self.model({'G': G, 'V': V})
            loss = self.loss_fn(batch_scores, batch_targets)

            loss_total += loss.detach().item()
            metric_total += self.metric(batch_scores, batch_targets)
            seen = i_step + 1
            t.set_postfix(loss = loss_total / seen,
                          metric = metric_total / seen)

    return {'loss'   : loss_total / (i_step + 1),
            'metric' : metric_total / (i_step + 1)}
219 |
220 |
if __name__ == '__main__':
    # Script entry point: parse the CLI, show and persist the configuration,
    # then launch the architecture search.

    import warnings
    from rich.console import Console
    from rich.table import Table
    from rich.panel import Panel
    from rich.syntax import Syntax
    warnings.filterwarnings('ignore')

    parser = argparse.ArgumentParser('Rethinking Graph Neural Architecture Search From Message Passing')
    # --- task / model ---
    parser.add_argument('--task',           type = str,     default = 'graph_level')
    parser.add_argument('--data',           type = str,     default = 'ZINC')
    parser.add_argument('--in_dim_V',       type = int,     default = 28)
    parser.add_argument('--node_dim',       type = int,     default = 70)
    parser.add_argument('--nb_classes',     type = int,     default = 1)
    parser.add_argument('--nb_layers',      type = int,     default = 4)
    parser.add_argument('--nb_nodes',       type = int,     default = 3)
    parser.add_argument('--leaky_slope',    type = float,   default = 0.1)
    parser.add_argument('--batchnorm_op',   action = 'store_true', default = False)
    parser.add_argument('--nb_mlp_layer',   type = int,     default = 4)
    parser.add_argument('--dropout',        type = float,   default = 0.0)
    parser.add_argument('--pos_encode',     type = int,     default = 0)

    # --- data / optimisation ---
    parser.add_argument('--portion',        type = float,   default = 0.5)
    parser.add_argument('--data_clip',      type = float,   default = 1.0)
    parser.add_argument('--nb_workers',     type = int,     default = 0)
    parser.add_argument('--seed',           type = int,     default = 41)
    parser.add_argument('--epochs',         type = int,     default = 50)
    parser.add_argument('--batch',          type = int,     default = 64)
    parser.add_argument('--lr',             type = float,   default = 0.025)
    parser.add_argument('--lr_min',         type = float,   default = 0.001)
    parser.add_argument('--momentum',       type = float,   default = 0.9)
    parser.add_argument('--weight_decay',   type = float,   default = 3e-4)
    parser.add_argument('--unrolled',       action = 'store_true', default = False)
    parser.add_argument('--search_mode',    type = str,     default = 'train')
    parser.add_argument('--arch_lr',        type = float,   default = 3e-4)
    parser.add_argument('--arch_weight_decay', type = float, default = 1e-3)
    parser.add_argument('--report_freq',    type = int,     default = 1)
    parser.add_argument('--arch_save',      type = str,     default = './save_arch')

    console = Console()
    args = parser.parse_args()

    # Pretty-print the parsed configuration as a YAML-highlighted panel.
    title = "[bold][red]Searching & Training"
    vis = '\n'.join(f"{key}: {val}" for key, val in vars(args).items())
    vis = Syntax(vis, "yaml", theme="monokai", line_numbers=True)
    richPanel = Panel.fit(vis, title = title)
    console.print(richPanel)

    # makedirs (rather than mkdir) so a missing `arch_save` parent directory
    # is created as well instead of raising FileNotFoundError.
    data_path = os.path.join(args.arch_save, args.data)
    os.makedirs(data_path, exist_ok = True)
    with open(os.path.join(data_path, "configs.yaml"), "w") as f:
        yaml.dump(vars(args), f)

    Searcher(args).run()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import dgl
4 | import yaml
5 | import torch
6 | import argparse
7 | import numpy as np
8 | import torch.backends.cudnn as cudnn
9 | from tqdm import tqdm
10 | from data import *
11 | from models.model_train import *
12 | from utils.utils import *
13 | from tensorboardX import SummaryWriter
14 | from utils.record_utils import record_run
15 |
16 |
class Trainer(object):
    """Train a fixed architecture described by a genotype YAML file.

    Construction loads the genotype, builds the model, the data loaders and
    the optimizer/scheduler pair selected by ``args.optimizer``
    (``'SGD'`` with cosine annealing or ``'ADAM'`` with reduce-on-plateau).
    ``run()`` then trains for ``args.epochs`` epochs, evaluating on the
    validation and test queues after every epoch.
    """

    def __init__(self, args):
        """Build every component needed for training.

        Args:
            args: parsed command-line namespace. ``args.nb_layers`` and
                ``args.nb_nodes`` are overwritten from the loaded genotype.

        Raises:
            Exception: if the genotype file does not exist or
                ``args.optimizer`` is neither ``'SGD'`` nor ``'ADAM'``.
        """
        # Imported locally so the class also works when this module is
        # imported instead of executed (the module-level import of Console
        # only happens inside the `__main__` guard).
        from rich.console import Console

        self.args = args
        self.console = Console()

        self.console.log('=> [0] Initial TensorboardX')
        self.writer = SummaryWriter(comment = f'Task: {args.task}, Data: {args.data}, Geno: {args.load_genotypes}')

        self.console.log('=> [1] Initial Settings')
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True

        self.console.log('=> [2] Initial Models')
        if not os.path.isfile(args.load_genotypes):
            raise Exception('Genotype file not found!')
        with open(args.load_genotypes, "r") as f:
            # safe_load: `yaml.load` without an explicit Loader is rejected
            # by PyYAML >= 6, and the genotype file only needs plain
            # mappings, lists, ints and strings.
            genotypes = yaml.safe_load(f)
        args.nb_layers = len(genotypes['Genotype'])
        # Number of distinct destination nodes in the first cell.
        args.nb_nodes = len({ edge['dst'] for edge in genotypes['Genotype'][0]['topology'] })
        self.metric = load_metric(args)
        self.loss_fn = get_loss_fn(args).cuda()
        trans_input_fn = get_trans_input(args)
        self.model = Model_Train(args, genotypes['Genotype'], trans_input_fn, self.loss_fn).to("cuda")
        self.console.log(f'[red]=> Subnet Parameters: {count_parameters_in_MB(self.model)}')

        self.console.log(f'=> [3] Preparing Dataset')
        self.dataset = load_data(args)
        if args.pos_encode > 0:
            #! load - position encoding
            self.console.log(f'==> [3.1] Adding positional encodings')
            self.dataset._add_positional_encodings(args.pos_encode)
        self.train_data = self.dataset.train
        self.val_data = self.dataset.val
        self.test_data = self.dataset.test
        self.load_dataloader()

        self.console.log(f'=> [4] Initial Optimizers')
        if args.optimizer == 'SGD':
            self.optimizer = torch.optim.SGD(
                params       = self.model.parameters(),
                lr           = args.lr,
                momentum     = args.momentum,
                weight_decay = args.weight_decay,
            )

            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer = self.optimizer,
                T_max     = float(args.epochs),
                eta_min   = args.lr_min
            )

        elif args.optimizer == 'ADAM':
            self.optimizer = torch.optim.Adam(
                params       = self.model.parameters(),
                lr           = args.lr,
                weight_decay = args.weight_decay,
            )

            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer = self.optimizer,
                mode      = 'min',
                factor    = 0.5,
                patience  = args.patience,
                verbose   = True
            )
        else:
            raise Exception('Unknown optimizer!')

    def _make_queue(self, data):
        """Return a DataLoader over the first ``data_clip`` fraction of ``data``."""
        nb_samples = int(len(data) * self.args.data_clip)
        indices = list(range(nb_samples))
        return torch.utils.data.DataLoader(
            dataset     = data,
            batch_size  = self.args.batch,
            pin_memory  = True,
            sampler     = torch.utils.data.sampler.SubsetRandomSampler(indices),
            num_workers = self.args.nb_workers,
            collate_fn  = self.dataset.collate,
        )

    def load_dataloader(self):
        """Create ``train_queue``, ``val_queue`` and ``test_queue``.

        Each queue samples randomly from the leading ``args.data_clip``
        fraction of the corresponding split.
        """
        self.train_queue = self._make_queue(self.train_data)
        self.val_queue = self._make_queue(self.val_data)
        self.test_queue = self._make_queue(self.test_data)

    def scheduler_step(self, valid_loss):
        """Advance the LR scheduler by one epoch and return the current LR.

        Args:
            valid_loss: validation loss of the epoch; only consumed by the
                reduce-on-plateau scheduler used with ``'ADAM'``.

        Returns:
            The learning rate of the optimizer's first parameter group after
            the scheduler step.
        """
        if self.args.optimizer == 'SGD':
            self.scheduler.step()
        elif self.args.optimizer == 'ADAM':
            self.scheduler.step(valid_loss)
        # Read the LR straight from the optimizer: it is what the next
        # optimisation step will use, for either scheduler type.
        lr = self.optimizer.param_groups[0]['lr']
        if lr < 1e-5:
            self.console.log('=> !! learning rate is smaller than threshold !!')
        return lr

    def run(self):
        """Train for ``args.epochs`` epochs with per-epoch validation and test."""
        self.console.log(f'=> [5] Train Genotypes')
        self.lr = self.args.lr
        for i_epoch in range(self.args.epochs):
            #! training
            train_result = self.train(i_epoch, 'train')
            self.console.log(f"[green]=> train result [{i_epoch}] - loss: {train_result['loss']:.4f} - metric : {train_result['metric']:.4f}")
            with torch.no_grad():
                #! validating
                val_result = self.infer(i_epoch, self.val_queue, 'val')
                self.console.log(f"[yellow]=> valid result [{i_epoch}] - loss: {val_result['loss']:.4f} - metric : {val_result['metric']:.4f}")
                #! testing
                test_result = self.infer(i_epoch, self.test_queue, 'test')
                self.console.log(f"[underline][red]=> test result [{i_epoch}] - loss: {test_result['loss']:.4f} - metric : {test_result['metric']:.4f}")
            self.lr = self.scheduler_step(val_result['loss'])

        self.console.log(f'=> Finished! Genotype = {self.args.load_genotypes}')

    @record_run('train')
    def train(self, i_epoch, stage = 'train'):
        """Run one training epoch over ``self.train_queue``.

        Args:
            i_epoch: epoch index (consumed by the ``record_run`` decorator).
            stage: stage name forwarded to the loss and metric functions.

        Returns:
            dict with the mean ``'loss'`` and ``'metric'`` over the epoch.
        """
        self.model.train()
        epoch_loss = 0
        epoch_metric = 0
        desc = '=> training'
        device = torch.device('cuda')

        with tqdm(self.train_queue, desc = desc, leave = False) as t:
            for i_step, (batch_graphs, batch_targets) in enumerate(t):
                #! 1. preparing datasets
                G = batch_graphs.to(device)
                V = batch_graphs.ndata['feat'].to(device)
                batch_targets = batch_targets.to(device)

                #! 2. optimizing model parameters
                self.optimizer.zero_grad()
                input = {'G': G, 'V': V}
                batch_scores = self.model(input)
                loss = self.loss_fn(batch_scores, batch_targets, graph = batch_graphs, stage = stage)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.detach().item()
                epoch_metric += self.metric(batch_scores, batch_targets, graph = batch_graphs, stage = stage)

                loss_avg = epoch_loss / (i_step + 1)
                metric_avg = epoch_metric / (i_step + 1)

                result = {'loss' : loss_avg, 'metric' : metric_avg}
                t.set_postfix(lr = self.lr, **result)

        return result

    @record_run('infer')
    def infer(self, i_epoch, dataloader, stage = 'infer'):
        """Evaluate the model on ``dataloader`` without updating weights.

        Args:
            i_epoch: epoch index (consumed by the ``record_run`` decorator).
            dataloader: iterable of ``(batch_graphs, batch_targets)`` pairs.
            stage: stage name forwarded to the loss and metric functions.

        Returns:
            dict with the mean ``'loss'`` and ``'metric'`` over all batches.
        """
        self.model.eval()
        epoch_loss = 0
        epoch_metric = 0
        desc = '=> inferring'
        device = torch.device('cuda')

        with tqdm(dataloader, desc = desc, leave = False) as t:
            for i_step, (batch_graphs, batch_targets) in enumerate(t):
                G = batch_graphs.to(device)
                V = batch_graphs.ndata['feat'].to(device)

                batch_targets = batch_targets.to(device)
                input = {'G': G, 'V': V}
                batch_scores = self.model(input)
                loss = self.loss_fn(batch_scores, batch_targets, graph = batch_graphs, stage = stage)

                epoch_loss += loss.detach().item()
                epoch_metric += self.metric(batch_scores, batch_targets, graph = batch_graphs, stage = stage)

                loss_avg = epoch_loss / (i_step + 1)
                metric_avg = epoch_metric / (i_step + 1)

                result = {'loss' : loss_avg, 'metric' : metric_avg}
                t.set_postfix(**result)

        return result
230 |
231 |
if __name__ == '__main__':
    # Script entry point: parse the CLI, display the configuration and train
    # the architecture described by --load_genotypes.

    import warnings
    from rich.console import Console
    from rich.table import Table
    from rich.panel import Panel
    from rich.syntax import Syntax

    warnings.filterwarnings('ignore')

    parser = argparse.ArgumentParser('Train_from_Genotype')

    # --- task / model ---
    parser.add_argument('--task', type = str, default = 'graph_level')
    parser.add_argument('--data', type = str, default = 'ZINC')
    parser.add_argument('--extra', type = str, default = '')
    parser.add_argument('--in_dim_V', type = int, default = 28)
    parser.add_argument('--node_dim', type = int, default = 70)
    parser.add_argument('--nb_layers', type = int, default = 4)
    parser.add_argument('--nb_nodes', type = int, default = 4)
    parser.add_argument('--nb_classes', type = int, default = 1)
    parser.add_argument('--leaky_slope', type = float, default = 1e-2)
    parser.add_argument('--batchnorm_op', default = False, action = 'store_true')
    parser.add_argument('--nb_mlp_layer', type = int, default = 4)
    parser.add_argument('--dropout', type = float, default = 0.0)
    parser.add_argument('--pos_encode', type = int, default = 0)

    # --- data / optimisation ---
    parser.add_argument('--data_clip', type = float, default = 1.0)
    parser.add_argument('--nb_workers', type = int, default = 0)
    parser.add_argument('--seed', type = int, default = 41)
    parser.add_argument('--epochs', type = int, default = 100)
    parser.add_argument('--batch', type = int, default = 64)
    parser.add_argument('--lr', type = float, default = 0.025)
    parser.add_argument('--lr_min', type = float, default = 0.001)
    parser.add_argument('--momentum', type = float, default = 0.9)
    parser.add_argument('--weight_decay', type = float, default = 3e-4)
    parser.add_argument('--optimizer', type = str, default = 'ADAM')
    parser.add_argument('--patience', type = int, default = 10)
    parser.add_argument('--load_genotypes', type = str, required = True)

    console = Console()
    args = parser.parse_args()

    # Render "key: value" lines as a YAML-highlighted panel.
    config_lines = []
    for key, val in vars(args).items():
        config_lines.append(f"{key}: {val}")
    highlighted = Syntax('\n'.join(config_lines), "yaml", theme="monokai", line_numbers=True)
    console.print(Panel.fit(highlighted, title = "[bold][red]Training from Genotype"))

    Trainer(args).run()
278 | # - end - #
279 |
--------------------------------------------------------------------------------
/utils/record_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tensorboardX import SummaryWriter
3 |
class record_run:
    """Decorator that logs a method's result dict to the owner's ``writer``.

    The decorated callable must be a method of an object exposing a
    ``writer`` attribute with an ``add_scalar(tag, value, global_step=...)``
    method, and must return a mapping containing ``'loss'`` and ``'metric'``.
    After each call both values are written under the tags
    ``'<stage>_loss'`` and ``'<stage>_metric'`` at step ``i_epoch``.
    """

    def __init__(self, comment = ''):
        # Free-form label; stored but not used when logging.
        self.comment = comment


    def __call__(self, func):
        """Wrap ``func`` so that its result is recorded after every call."""
        import functools
        import inspect

        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):
            writer = args[0].writer
            # geno_path = args[0].args.load_genotypes

            # Resolve `i_epoch` and `stage` by name so that keyword arguments
            # and default values are honoured; fall back to the historical
            # positional convention when the wrapped function does not
            # declare those parameters.
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            call_args = bound.arguments
            i_epoch = call_args['i_epoch'] if 'i_epoch' in call_args else args[1]
            stage = call_args['stage'] if 'stage' in call_args else args[-1]

            result = func(*args, **kwargs)
            writer.add_scalar(f'{stage}_loss', result['loss'], global_step = i_epoch)
            writer.add_scalar(f'{stage}_metric', result['metric'], global_step = i_epoch)

            return result

        return wrapped_func
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 | from sklearn.metrics import f1_score
8 | from sklearn.metrics import confusion_matrix
9 | from models.operations import OPS
10 |
class DotDict(dict):
    """Dictionary whose items can also be read and written as attributes."""

    def __init__(self, **kwds):
        super().__init__(**kwds)
        # Using the mapping itself as the instance ``__dict__`` makes
        # ``obj.key`` and ``obj['key']`` refer to the same storage.
        self.__dict__ = self
15 |
def load_alpha(genotype):
    """Return ``[alpha_cell, alpha_edge]`` of ``genotype`` as float tensors.

    Args:
        genotype: object exposing ``alpha_cell`` and ``alpha_edge`` attributes
            that ``torch.Tensor`` can convert.
    """
    return [
        torch.Tensor(getattr(genotype, field))
        for field in ('alpha_cell', 'alpha_edge')
    ]
20 |
class AvgrageMeter(object):
    """Running weighted average.

    Attributes:
        sum: weighted sum of all values seen since the last reset.
        cnt: total weight seen since the last reset.
        avg: ``sum / cnt`` after the most recent update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.avg = self.sum = self.cnt = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the average."""
        self.cnt += n
        self.sum += val * n
        self.avg = self.sum / self.cnt
34 |
def count_parameters_in_MB(model):
    """Return the number of scalar parameters in ``model``.

    Parameters whose qualified name contains ``"auxiliary"`` are excluded.

    NOTE: despite the function name, the value is a raw element count; it is
    not scaled to millions or megabytes.

    Args:
        model: a ``torch.nn.Module`` (anything exposing ``named_parameters()``).

    Returns:
        int: total number of parameter elements.
    """
    # Built-in sum over a generator; passing a generator to np.sum is
    # deprecated in NumPy.
    return sum(
        param.numel()
        for name, param in model.named_parameters()
        if "auxiliary" not in name
    )
37 |
def save_checkpoint(state, is_best, save):
    """Serialise ``state`` to ``<save>/checkpoint.pth.tar``.

    Args:
        state: object to serialise with ``torch.save``.
        is_best: when true, the checkpoint is additionally copied to
            ``<save>/model_best.pth.tar``.
        save: existing directory that receives the file(s).
    """
    # Imported locally: this module does not import shutil at file level, so
    # the `is_best` branch previously failed with NameError.
    import shutil

    filename = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, filename)
    if is_best:
        best_filename = os.path.join(save, 'model_best.pth.tar')
        shutil.copyfile(filename, best_filename)
44 |
def save(model, model_path):
    """Write the ``state_dict`` of ``model`` to ``model_path``."""
    weights = model.state_dict()
    torch.save(weights, model_path)
47 |
def load(model, model_path):
    """Load the state dict stored at ``model_path`` into ``model`` in place."""
    weights = torch.load(model_path)
    model.load_state_dict(weights)
50 |
def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` in place (drop-path).

    Each sample along dimension 0 is kept with probability
    ``1 - drop_prob``; kept samples are rescaled by ``1 / (1 - drop_prob)``.

    Args:
        x: tensor with the batch on dimension 0; modified in place.
        drop_prob: probability of dropping a sample. Values ``<= 0`` leave
            ``x`` untouched.

    Returns:
        The same tensor object ``x``.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        # Created from `x` so it follows the input's device and dtype instead
        # of requiring CUDA and a 4-D input.
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x
58 |
def create_exp_dir(path, scripts_to_save=None):
    """Create an experiment directory and optionally snapshot scripts into it.

    Args:
        path: directory to create; missing parent directories are created
            too and an existing directory is reused.
        scripts_to_save: optional iterable of file paths copied into
            ``<path>/scripts``.
    """
    # Imported locally: this module does not import shutil at file level, so
    # copying scripts previously failed with NameError.
    import shutil

    os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        scripts_dir = os.path.join(path, 'scripts')
        os.makedirs(scripts_dir, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(scripts_dir, os.path.basename(script))
            shutil.copyfile(script, dst_file)
69 |
70 | # ----------------------------------------------------------------
71 | # metrics
class mask:
    """Decorator that optionally restricts the last two positional arguments
    to the entries selected by a per-stage graph mask.

    The wrapped function is always called with positional arguments only;
    keyword arguments given to the wrapper (``graph``, ``stage``) are consumed
    here and never forwarded.

    Args:
        data_type: ``'V'`` reads the mask from ``graph.ndata``; any other
            value reads it from ``graph.edata``.
        mask: when false the wrapper forwards the positional arguments
            unchanged.
    """

    def __init__(self, data_type, mask):

        self.data_type = data_type
        self.mask = mask

    def __call__(self, func):

        def wrapped_func(*args, **kwargs):
            # Masking disabled: plain pass-through.
            if not self.mask:
                return func(*args)

            positional = list(args)
            graph = kwargs['graph']
            stage = kwargs['stage']

            store = graph.ndata if self.data_type == 'V' else graph.edata
            key = f'{stage}_mask'

            # Only filter when the graph actually carries a mask for this stage.
            if key in store:
                selector = store[key]
                positional[-1] = positional[-1][selector]
                positional[-2] = positional[-2][selector]

            return func(*positional)

        return wrapped_func
97 |
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each ``k`` in ``topk``.

    Args:
        output: score tensor indexed as ``[sample, class]``.
        target: tensor of ground-truth class indices, one per sample.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        list of 0-d tensors, one per ``k``, holding the percentage of samples
        whose target is among the ``k`` highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred: [maxk, batch] after the transpose, best class first.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape (not view): the sliced tensor may be non-contiguous, in
        # which case view(-1) raises on current PyTorch for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))

    return res
112 |
@mask('E', False)
def binary_f1_score(scores, targets):
    """Return the binary F1 score computed with scikit-learn.

    The predicted label of each row is the argmax of ``scores`` along
    dimension 1; the score is reported for the positive class (label ``1``).
    """
    predicted = scores.argmax(dim=1).cpu().numpy()
    expected = targets.cpu().numpy()
    return f1_score(expected, predicted, average='binary')
121 |
@mask('V', False)
def accuracy_SBM(scores, targets):
    """Return the class-averaged accuracy (in percent).

    For every row ``r`` of the confusion matrix the diagonal count is divided
    by the number of targets equal to ``r``; rows without any such target
    contribute 0. The result is the mean over all confusion-matrix rows,
    times 100.
    """
    true_labels = targets.cpu().detach().numpy()
    probs = torch.nn.Softmax(dim=1)(scores).cpu().detach().numpy()
    pred_labels = np.argmax(probs, axis=1)

    CM = confusion_matrix(true_labels, pred_labels).astype(np.float32)
    nb_classes = CM.shape[0]

    per_class = np.zeros(nb_classes)
    for r in range(nb_classes):
        support = np.where(true_labels == r)[0].shape[0]
        if support != 0:
            per_class[r] = CM[r, r] / float(support)

    return 100. * np.sum(per_class) / float(nb_classes)
142 |
@mask('G', False)
def accuracy_MNIST_CIFAR(scores, targets):
    """Return the fraction of rows whose argmax along dim 1 equals the target."""
    predicted = scores.detach().argmax(dim=1)
    hits = (predicted == targets).float()
    return hits.mean().item()
148 |
@mask('G', False)
def MAE(scores, targets):
    """Return the mean absolute error between ``scores`` and ``targets`` as a float."""
    return F.l1_loss(scores, targets).detach().item()
154 |
@mask('V', True)
def CoraAccuracy(scores, targets):
    """Return the fraction of rows whose argmax along dim 1 equals the target.

    The ``mask`` decorator restricts both arguments to the stage mask first
    when the graph provides one.
    """
    predicted = scores.argmax(1)
    hits = (predicted == targets).float()
    return hits.mean().item()
158 |
159 | # ----------------------------------------------------------------
160 | # loss functions
161 |
class MoleculesCriterion(nn.Module):
    """L1 (mean absolute error) loss between predictions and labels."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.L1Loss()

    @mask('G', False)
    def forward(self, pred, label):
        """Return ``L1Loss(pred, label)``."""
        loss = self.loss_fn(pred, label)
        return loss
171 |
class TSPCriterion(nn.Module):
    """Unweighted cross-entropy loss between predictions and labels."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss(weight=None)

    @mask('E', False)
    def forward(self, pred, label):
        """Return ``CrossEntropyLoss(pred, label)``."""
        loss = self.loss_fn(pred, label)
        return loss
181 |
class SuperPixCriterion(nn.Module):
    """Unweighted cross-entropy loss between predictions and labels."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss(weight=None)

    @mask('G', False)
    def forward(self, pred, label):
        """Return ``CrossEntropyLoss(pred, label)``."""
        loss = self.loss_fn(pred, label)
        return loss
191 |
class SBMsCriterion(nn.Module):
    """Class-balanced cross-entropy loss.

    The weight of class ``c`` in a batch is ``(V - n_c) / V`` where ``V`` is
    the number of labels in the batch and ``n_c`` the number of labels equal
    to ``c``; classes absent from the batch get weight 0.

    Args:
        num_classes: total number of classes.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.n_classes = num_classes

    @mask('V', False)
    def forward(self, pred, label):
        """Return the weighted cross-entropy of ``pred`` against ``label``.

        Args:
            pred: class scores, one row per label.
            label: 1-D tensor of non-negative integer class indices.
        """
        V = label.size(0)
        # Per-class counts padded to n_classes; computed on the label's own
        # device so the loss no longer requires CUDA.
        cluster_sizes = torch.bincount(label, minlength=self.n_classes)
        weight = (V - cluster_sizes).float() / V
        weight *= (cluster_sizes > 0).float()
        # weighted cross-entropy for unbalanced classes
        criterion = nn.CrossEntropyLoss(weight=weight)
        loss = criterion(pred, label)
        return loss
211 |
class CiteCriterion(nn.Module):
    """Unweighted cross-entropy loss evaluated on the stage-masked entries."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss(weight=None)

    @mask('V', True)
    def forward(self, pred, label):
        """Return ``CrossEntropyLoss(pred, label)``."""
        loss = self.loss_fn(pred, label)
        return loss
221 |
222 |
223 | # ----------------------------------------------------------------
def cell_genotype(args, id, arch_para, arch_topo):
    """Derive the discrete topology of one cell from its architecture weights.

    For every destination node ``1 .. args.nb_nodes * 3`` the incoming link
    with the largest operation weight (ignoring ``'V_None'``) is kept, and on
    that link the operation with the largest weight is selected.

    Args:
        args: namespace providing ``nb_nodes``.
        id: identifier stored in the result.
        arch_para: indexable collection; ``arch_para[w]`` is the weight
            tensor of the link with weight index ``w``.
        arch_topo: iterable of ``(src, dst, w, ops)`` tuples, ``ops`` being
            the list of candidate operation names of the link.

    Returns:
        ``{'id': id, 'topology': [{'src': ..., 'dst': ..., 'ops': ...}, ...]}``
    """
    nb_slots = args.nb_nodes * 3 + 1

    # incoming[dst] -> list of (src, op_names, weights)
    incoming = [[] for _ in range(nb_slots)]
    for src, dst, w, ops in arch_topo:
        incoming[dst].append((src, ops, arch_para[w]))

    def link_score(candidate):
        # Largest weight on the link, skipping the 'V_None' entry if present.
        _, op_names, weights = candidate
        skipped = op_names.index('V_None') if 'V_None' in op_names else -1
        return max(weights[j] for j in range(len(op_names)) if j != skipped)

    topology = []
    for dst in range(1, nb_slots):
        ranked = sorted(incoming[dst], key = lambda candidate: -link_score(candidate))
        for src, op_names, weights in ranked[:1]:  #! number of links kept per node
            if 'V_None' in op_names:
                # Choose among the entries after the first one.
                best_op = torch.argmax(weights[1:]).item() + 1
            else:
                best_op = torch.argmax(weights).item()
            topology.append({'src': src, 'dst': dst, 'ops': op_names[best_op]})

    return {'id': id, 'topology': topology}
250 |
def genotypes(args, arch_paras, arch_topos):
    """Collect the genotype of every cell into a single mapping.

    Args:
        args: namespace providing ``nb_layers`` (and whatever
            ``cell_genotype`` reads).
        arch_paras: per-layer architecture parameters, indexed by layer.
        arch_topos: per-layer topology descriptions, indexed by layer.

    Returns:
        ``{'Genotype': [cell_0, cell_1, ...]}`` with one entry per layer.
    """
    cells = [
        cell_genotype(args, layer, arch_paras[layer], arch_topos[layer])
        for layer in range(args.nb_layers)
    ]
    return {'Genotype': cells}
257 |
258 | # ----------------------------------------------------------------
259 | import math
def Singleton(cls):
    """Class decorator that makes ``cls`` a lazily created singleton.

    The first call constructs the instance with the given arguments; every
    later call returns that same instance and ignores its arguments.
    """
    created = {}

    def _singleton(*args, **kargs):
        if cls in created:
            return created[cls]
        instance = cls(*args, **kargs)
        created[cls] = instance
        return instance

    return _singleton
269 |
@Singleton
class DecayScheduler(object):
    """Process-wide decay-rate schedule.

    ``step(epoch)`` updates ``decay_rate``:

    * before ``T_start`` and for unknown ``decay_type`` it is ``base_lr``;
    * between ``T_start`` and ``T_stop`` it follows the selected curve
      (``'cosine'``, ``'slow_cosine'`` or ``'linear'``);
    * after ``T_stop`` the last value is kept.
    """

    def __init__(self, base_lr=1.0, last_iter=-1, T_max=50, T_start=0, T_stop=50, decay_type='cosine'):
        self.base_lr = base_lr
        self.T_max = T_max
        self.T_start = T_start
        self.T_stop = T_stop
        self.cnt = 0
        self.decay_type = decay_type
        self.decay_rate = 1.0

    def step(self, epoch):
        """Recompute ``decay_rate`` for ``epoch``."""
        # Warm-up phase or unrecognised schedule: plain base rate.
        if epoch < self.T_start or self.decay_type not in ("cosine", "slow_cosine", "linear"):
            self.decay_rate = self.base_lr
            return

        # Past the stop epoch the rate is frozen at its last value.
        if epoch > self.T_stop:
            return

        span = self.T_max - self.T_start
        if self.decay_type == "cosine":
            self.decay_rate = self.base_lr * (1 + math.cos(math.pi * epoch / span)) / 2.0
        elif self.decay_type == "slow_cosine":
            self.decay_rate = self.base_lr * math.cos((math.pi/2) * epoch / span)
        else:  # "linear"
            self.decay_rate = self.base_lr * (self.T_max - epoch) / span
293 |
294 |
--------------------------------------------------------------------------------