├── .gitignore
├── README.md
├── core
│   ├── __init__.py
│   ├── data.py
│   ├── model.py
│   ├── nn.py
│   ├── train_engine.py
│   ├── utils.py
│   └── utils_ipynb.py
├── results.ipynb
├── train.py
└── write_exp_cfgs_file.ipynb

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | runs
3 | results
4 | .vscode
5 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repository contains the code for our work **Graph Filtration Learning**, which was accepted at ICML'20.
2 | 
3 | 
4 | # Installation
5 | 
6 | In the following, `<root_dir>` will be the directory in which you have chosen to do the installation.
7 | 
8 | 1. Install Anaconda from [here](https://repo.anaconda.com/archive/Anaconda3-2020.07-Linux-x86_64.sh) into `<root_dir>/anaconda3`, i.e., set the prefix accordingly in the installer.
9 | 
10 | 2. Activate the Anaconda installation:
11 | 
12 | ```
13 | source <root_dir>/anaconda3/bin/activate
14 | ```
15 | 
16 | 
17 | 3. Install PyTorch via conda:
18 | 
19 | ```
20 | conda install pytorch=1.4.0 torchvision cudatoolkit=<cuda_version> -c pytorch
21 | ```
22 | 
23 | 
24 | 
25 | 4. Install `pytorch-geometric` and its dependencies following the instructions on its [gh-page](https://github.com/rusty1s/pytorch_geometric).
26 | 
27 | 5. Install `torchph` via
28 | 
29 | ```
30 | cd <root_dir>
31 | git clone -b 'submission_icml2020' --single-branch --depth 1 https://github.com/c-hofer/torchph.git
32 | conda develop torchph
33 | ```
34 | 6. Clone this repository into `<root_dir>`.
35 | 
36 | # Application
37 | 
38 | 1. Generate the experiment configurations you want using the `write_exp_cfgs_file.ipynb` notebook (an example configuration is sketched below). It is assumed that the notebook server is started in `<root_dir>/graph_filtration_learning`.
39 | 
40 | 2. Use the `train.py` script to run the experiments, e.g.,
41 | ```
42 | python train.py --cfg_file <cfg_file> --output_dir <output_dir> --devices 0,1 --max_process_on_device 2
43 | ```
44 | to use CUDA devices 0 and 1 with at most 2 experiments on each.
45 | 
46 | Each experiment gets a unique id and its output is written to `<output_dir>` as a pickle file. Additionally, for each CV run, the corresponding trained model is dumped.
47 | 
48 | 3. The notebook `results.ipynb` contains some code to browse the results.
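For reference, the file passed via `--cfg_file` holds a list of experiment configurations; their schema is documented by `__exp_cfg_meta` in `core/train_engine.py`. A minimal sketch of a single configuration (the concrete values below are illustrative placeholders, not the settings used in the paper):

```
exp_cfg = {
    'dataset_name': 'MUTAG',            # any name from core.data.TU_DORTMUND_DATASET_NAMES
    'training': {
        'lr': 0.01,
        'lr_drop_fact': 0.5,            # lr is multiplied by this every epoch_step epochs
        'num_epochs': 100,
        'epoch_step': 20,
        'batch_size': 32,
        'weight_decay': 1e-06,
        'validation_ratio': 0.1,        # 0.0 disables the validation split
    },
    'model': {
        'model_type': 'PershomLearnedFilt',   # or 'PershomRigidDegreeFilt', 'GIN', 'SimpleNNBaseline'
        'use_super_level_set_filtration': False,
        'use_node_degree': True,
        'set_node_degree_uninformative': False,
        'pooling_strategy': 'sum',      # 'sum' or 'sort'; used by the GIN/SimpleNNBaseline models
        'use_node_label': True,
        'gin_number': 1,
        'gin_dimension': 64,
        'gin_mlp_type': 'lin_bn_lrelu_lin',   # or 'lin', 'lin_lrelu_lin'
        'num_struct_elements': 100,
        'cls_hidden_dimension': 64,
        'drop_out': 0.1,
    },
    'tag': 'my_experiment',
}
```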
49 | 50 | 51 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c-hofer/graph_filtration_learning/cda10fa138d26c4de15b881541d0da3246650f59/core/__init__.py -------------------------------------------------------------------------------- /core/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch_geometric 4 | import torch_geometric.data 5 | 6 | from torch_geometric.datasets import TUDataset 7 | from collections import Counter 8 | from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit 9 | 10 | POWERFUL_GNN_DATASET_NAMES = ["PTC_PGNN"] 11 | TU_DORTMUND_DATASET_NAMES = [ 12 | "NCI1", "PTC_MR", 'PTC_FM', 13 | 'PTC_FR', 'PTC_MM', "PROTEINS", 14 | "REDDIT-BINARY", "REDDIT-MULTI-5K", 15 | "ENZYMES", "DD", "IMDB-BINARY", "IMDB-MULTI", "MUTAG", "COLLAB" 16 | ] 17 | 18 | 19 | class SimpleDataset(torch.utils.data.Dataset): 20 | def __init__(self, X): 21 | assert isinstance(X, list) 22 | self.data = X 23 | 24 | def __len__(self): 25 | return len(self.data) 26 | 27 | def __getitem__(self, idx): 28 | return self.data[idx] 29 | 30 | def __iter__(self): 31 | for i in range(len(self.data)): 32 | yield self.data[i] 33 | 34 | 35 | class Subset(torch.utils.data.Dataset): 36 | def __init__(self, dataset, indices): 37 | assert isinstance(indices, (list, tuple)) 38 | self.ds = dataset 39 | self.indices = tuple(indices) 40 | 41 | assert len(indices) <= len(dataset) 42 | 43 | def __len__(self): 44 | return len(self.indices) 45 | 46 | def __getitem__(self, idx): 47 | return self.ds[self.indices[idx]] 48 | 49 | 50 | def load_powerfull_gnn_dataset_PTC(): 51 | has_node_features = False 52 | 53 | dataset_name = "PTC" 54 | path = "/home/pma/chofer/repositories/powerful-gnns/dataset/{}/{}.txt".format(dataset_name, dataset_name) 55 | 56 | with open(path, 'r') as f: 57 | num_graphs = int(f.readline().strip()) 58 | 59 | data = [] 60 | 61 | graph_label_map = {} 62 | node_lab_map = {} 63 | 64 | for i in range(num_graphs): 65 | row = f.readline().strip().split() 66 | num_nodes, graph_label = [int(w) for w in row] 67 | 68 | if graph_label not in graph_label_map: 69 | graph_label_map[graph_label] = len(graph_label_map) 70 | 71 | graph_label = graph_label_map[graph_label] 72 | 73 | 74 | nodes = [] 75 | node_labs = [] 76 | edges = [] 77 | node_features = [] 78 | 79 | for node_id in range(num_nodes): 80 | nodes.append(node_id) 81 | 82 | row = f.readline().strip().split() 83 | 84 | node_lab = int(row[0]) 85 | 86 | if node_lab not in node_lab_map: 87 | node_lab_map[node_lab] = len(node_lab_map) 88 | 89 | node_labs.append(node_lab_map[node_lab]) 90 | 91 | num_neighbors = int(row[1]) 92 | neighbors = [int(i) for i in row[2:num_neighbors+2]] 93 | assert num_neighbors == len(neighbors) 94 | 95 | edges += [(node_id, neighbor_id) for neighbor_id in neighbors] 96 | 97 | if has_node_features: 98 | node_features = [float(i) for i in row[(2 + num_neighbors):]] 99 | assert len(node_features) != 0 100 | 101 | # x = torch.tensor(node_features) if has_node_features else None 102 | x = torch.tensor(node_labs, dtype=torch.long) 103 | 104 | edge_index = torch.tensor(edges, dtype=torch.long) 105 | edge_index = edge_index.permute(1, 0) 106 | tmp = edge_index.index_select(0, torch.tensor([1, 0])) 107 | edge_index = torch.cat([edge_index, tmp], dim=1) 108 | 109 | y = 
torch.tensor([graph_label])
110 | 
111 |         d = torch_geometric.data.Data(
112 |             x=x,
113 |             edge_index=edge_index,
114 |             y=y
115 |         )
116 | 
117 |         d.num_nodes = num_nodes
118 | 
119 |         data.append(d)
120 | 
121 |     max_node_lab = max([d.x.max().item() for d in data]) + 1
122 |     eye = torch.eye(max_node_lab, dtype=torch.long)
123 |     for d in data:
124 |         node_lab = eye.index_select(0, d.x)
125 |         d.x = node_lab
126 | 
127 |     ds = SimpleDataset(data)
128 |     ds.name = dataset_name
129 | 
130 |     return ds
131 | 
132 | 
133 | def get_boundary_info(g):
134 | 
135 |     e = g.edge_index.permute(1, 0).sort(1)[0].tolist()
136 |     e = set([tuple(ee) for ee in e])
137 |     return torch.tensor([ee for ee in e], dtype=torch.long)
138 | 
139 | 
140 | def enhance_TUDataset(ds):
141 | 
142 |     X = []
143 |     targets = []
144 | 
145 |     max_degree_by_graph = []
146 |     num_nodes = []
147 |     num_edges = []
148 | 
149 |     for d in ds:
150 | 
151 |         targets.append(d.y.item())
152 | 
153 |         boundary_info = get_boundary_info(d)
154 |         d.boundary_info = boundary_info
155 | 
156 |         num_nodes.append(d.num_nodes)
157 |         num_edges.append(boundary_info.size(0))
158 | 
159 |         degree = torch.zeros(d.num_nodes, dtype=torch.long)
160 | 
161 |         for k, v in Counter(d.boundary_info.flatten().tolist()).items():
162 |             degree[k] = v
163 |         max_degree_by_graph.append(degree.max().item())
164 | 
165 |         d.node_deg = degree
166 |         X.append(d)
167 | 
168 |     max_node_deg = max(max_degree_by_graph)
169 | 
170 |     num_node_lab = None
171 |     if hasattr(X[0], 'x') and X[0].x is not None:
172 | 
173 |         all_node_lab = []
174 |         for d in X:
175 |             assert d.x.sum() == d.x.size(0)  # really one hot encoded?
176 |             node_lab = d.x.argmax(1).tolist()
177 |             d.node_lab = node_lab
178 |             all_node_lab += node_lab
179 | 
180 |         all_node_lab = set(all_node_lab)
181 |         num_node_lab = len(all_node_lab)
182 |         label_map = {k: i for i, k in enumerate(sorted(all_node_lab))}
183 | 
184 |         for d in X:
185 |             d.node_lab = [label_map[f] for f in d.node_lab]
186 |             d.node_lab = torch.tensor(d.node_lab, dtype=torch.long)
187 |     else:
188 |         for d in X:
189 |             d.node_lab = None
190 | 
191 |     new_ds = SimpleDataset(X)
192 | 
193 |     new_ds.max_node_deg = max_node_deg
194 |     new_ds.avg_num_nodes = np.mean(num_nodes)
195 |     new_ds.avg_num_edges = np.mean(num_edges)
196 |     new_ds.num_classes = len(set(targets))
197 |     new_ds.num_node_lab = num_node_lab
198 | 
199 |     return new_ds
200 | 
201 | 
202 | def dataset_factory(dataset_name, verbose=True):
203 |     if dataset_name in TU_DORTMUND_DATASET_NAMES:
204 | 
205 |         path = 'data/{}/'.format(dataset_name)
206 |         dataset = TUDataset(path, name=dataset_name)
207 | 
208 |     elif dataset_name in POWERFUL_GNN_DATASET_NAMES:
209 |         if dataset_name == "PTC_PGNN":
210 |             dataset = load_powerfull_gnn_dataset_PTC()
211 | 
212 |     else:
213 |         raise ValueError("dataset_name not in {}".format(TU_DORTMUND_DATASET_NAMES + POWERFUL_GNN_DATASET_NAMES))
214 |     ds_name = dataset.name
215 |     dataset = enhance_TUDataset(dataset)
216 | 
217 |     if verbose:
218 |         print("# Dataset: ", ds_name)
219 |         print('# num samples: ', len(dataset))
220 |         print('# num classes: ', dataset.num_classes)
221 |         print('#')
222 |         print('# max node degree: ', dataset.max_node_deg)
223 |         print('# num node label: ', dataset.num_node_lab)
224 |         print('#')
225 |         print('# avg number of nodes: ', dataset.avg_num_nodes)
226 |         print('# avg number of edges: ', dataset.avg_num_edges)
227 | 
228 | 
229 |     return dataset
230 | 
231 | 
232 | def train_test_val_split(
233 |         dataset,
234 |         seed=0,
235 |         n_splits=10,
236 |         verbose=True,
237 |         validation_ratio=0.0):
238 | 
239 |     skf = 
StratifiedKFold( 240 | n_splits=n_splits, 241 | shuffle = True, 242 | random_state = seed, 243 | ) 244 | 245 | targets = [x.y.item() for x in dataset] 246 | split_idx = list(skf.split(np.zeros(len(dataset)), targets)) 247 | 248 | if verbose: 249 | print('# num splits: ', len(split_idx)) 250 | print('# validation ratio: ', validation_ratio) 251 | 252 | split_ds = [] 253 | split_i = [] 254 | for train_i, test_i in split_idx: 255 | not_test_i, test_i = train_i.tolist(), test_i.tolist() 256 | 257 | if validation_ratio == 0.0: 258 | validation_i = [] 259 | train_i = not_test_i 260 | 261 | else: 262 | skf = StratifiedShuffleSplit( 263 | n_splits=1, 264 | random_state = seed, 265 | test_size=validation_ratio 266 | ) 267 | 268 | targets = [dataset[i].y.item() for i in not_test_i] 269 | train_i, validation_i = list(skf.split(np.zeros(len(not_test_i)), targets))[0] 270 | train_i, validation_i = train_i.tolist(), validation_i.tolist() 271 | 272 | # We need the indices w.r.t. the original dataset 273 | # not w.r.t. the current train fold ... 274 | train_i = [not_test_i[j] for j in train_i] 275 | validation_i = [not_test_i[j] for j in validation_i] 276 | 277 | assert len(set(train_i).intersection(set(validation_i))) == 0 278 | 279 | train = Subset(dataset, train_i) 280 | test = Subset(dataset, test_i) 281 | validation = Subset(dataset, validation_i) 282 | 283 | assert sum([len(train), len(test), len(validation)]) == len(dataset) 284 | 285 | split_ds.append((train, test, validation)) 286 | split_i.append((train_i, test_i, validation_i)) 287 | 288 | return split_ds, split_i 289 | -------------------------------------------------------------------------------- /core/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch_geometric 5 | import torch_geometric.nn as geonn 6 | 7 | import functools 8 | 9 | from chofer_torchex.nn import SLayerRationalHat 10 | from torch_geometric.nn import GINConv, global_add_pool, global_sort_pool 11 | 12 | 13 | from chofer_torchex import pershom 14 | ph = pershom.pershom_backend.__C.VertFiltCompCuda__vert_filt_persistence_batch 15 | 16 | 17 | def gin_mlp_factory(gin_mlp_type: str, dim_in: int, dim_out: int): 18 | if gin_mlp_type == 'lin': 19 | return nn.Linear(dim_in, dim_out) 20 | 21 | elif gin_mlp_type == 'lin_lrelu_lin': 22 | return nn.Sequential( 23 | nn.Linear(dim_in, dim_in), 24 | nn.LeakyReLU(), 25 | nn.Linear(dim_in, dim_out) 26 | ) 27 | 28 | elif gin_mlp_type == 'lin_bn_lrelu_lin': 29 | return nn.Sequential( 30 | nn.Linear(dim_in, dim_in), 31 | nn.BatchNorm1d(dim_in), 32 | nn.LeakyReLU(), 33 | nn.Linear(dim_in, dim_out) 34 | ) 35 | else: 36 | raise ValueError("Unknown gin_mlp_type!") 37 | 38 | 39 | def ClassifierHead( 40 | dataset, 41 | dim_in: int=None, 42 | hidden_dim: int=None, 43 | drop_out: float=None): 44 | 45 | assert (0.0 <= drop_out) and (drop_out < 1.0) 46 | assert dim_in is not None 47 | assert drop_out is not None 48 | assert hidden_dim is not None 49 | 50 | tmp = [ 51 | nn.Linear(dim_in, hidden_dim), 52 | nn.LeakyReLU(), 53 | ] 54 | 55 | if drop_out > 0: 56 | tmp += [nn.Dropout(p=drop_out)] 57 | 58 | tmp += [nn.Linear(hidden_dim, dataset.num_classes)] 59 | 60 | return nn.Sequential(*tmp) 61 | 62 | 63 | class DegreeOnlyFiltration(torch.nn.Module): 64 | def __init__(self): 65 | super().__init__() 66 | 67 | def forward(self, batch): 68 | tmp = [] 69 | for i, j in zip(batch.sample_pos[:-1], batch.sample_pos[1:]): 70 | max_deg = 
batch.node_deg[i:j].max() 71 | 72 | t = torch.ones(j - i, dtype=torch.float, device=batch.node_deg.device) 73 | t = t * max_deg 74 | tmp.append(t) 75 | 76 | max_deg = torch.cat(tmp, dim=0) 77 | 78 | normalized_node_deg = batch.node_deg.float() / max_deg 79 | 80 | return normalized_node_deg 81 | 82 | 83 | class Filtration(torch.nn.Module): 84 | def __init__(self, 85 | dataset, 86 | use_node_degree=None, 87 | set_node_degree_uninformative=None, 88 | use_node_label=None, 89 | gin_number=None, 90 | gin_dimension=None, 91 | gin_mlp_type = None, 92 | **kwargs 93 | ): 94 | super().__init__() 95 | 96 | dim = gin_dimension 97 | 98 | max_node_deg = dataset.max_node_deg 99 | num_node_lab = dataset.num_node_lab 100 | 101 | if set_node_degree_uninformative and use_node_degree: 102 | self.embed_deg = UniformativeDummyEmbedding(gin_dimension) 103 | elif use_node_degree: 104 | self.embed_deg = nn.Embedding(max_node_deg+1, dim) 105 | else: 106 | self.embed_deg = None 107 | 108 | self.embed_lab = nn.Embedding(num_node_lab, dim) if use_node_label else None 109 | 110 | dim_input = dim*((self.embed_deg is not None) + (self.embed_lab is not None)) 111 | 112 | dims = [dim_input] + (gin_number)*[dim] 113 | 114 | self.convs = nn.ModuleList() 115 | self.bns = nn.ModuleList() 116 | self.act = torch.nn.functional.leaky_relu 117 | 118 | for n_1, n_2 in zip(dims[:-1], dims[1:]): 119 | l = gin_mlp_factory(gin_mlp_type, n_1, n_2) 120 | self.convs.append(GINConv(l, train_eps=True)) 121 | self.bns.append(nn.BatchNorm1d(n_2)) 122 | 123 | self.fc = nn.Sequential( 124 | nn.Linear(sum(dims), dim), 125 | nn.BatchNorm1d(dim), 126 | nn.LeakyReLU(), 127 | nn.Linear(dim, 1), 128 | nn.Sigmoid() 129 | ) 130 | 131 | def forward(self, batch): 132 | 133 | node_deg = batch.node_deg 134 | node_lab = batch.node_lab 135 | 136 | edge_index = batch.edge_index 137 | 138 | tmp = [e(x) for e, x in 139 | zip([self.embed_deg, self.embed_lab], [node_deg, node_lab]) 140 | if e is not None] 141 | 142 | tmp = torch.cat(tmp, dim=1) 143 | 144 | z = [tmp] 145 | 146 | for conv, bn in zip(self.convs, self.bns): 147 | x = conv(z[-1], edge_index) 148 | x = bn(x) 149 | x = self.act(x) 150 | z.append(x) 151 | 152 | x = torch.cat(z, dim=1) 153 | ret = self.fc(x).squeeze() 154 | return ret 155 | 156 | 157 | class PershomClassifier(nn.Module): 158 | def __init__(self, 159 | dataset, 160 | num_struct_elements=None, 161 | cls_hidden_dimension=None, 162 | drop_out=None, 163 | ): 164 | 165 | super().__init__() 166 | assert isinstance(num_struct_elements, int) 167 | self.use_as_feature_extractor = False 168 | 169 | self.ldgm_0 = SLayerRationalHat(num_struct_elements, 2, radius_init=0.1) 170 | self.ldgm_0_ess = SLayerRationalHat(num_struct_elements, 1, radius_init=0.1) 171 | self.ldgm_1_ess = SLayerRationalHat(num_struct_elements, 1, radius_init=0.1) 172 | fc_in_feat = 3*num_struct_elements 173 | 174 | self.cls_head = ClassifierHead( 175 | dataset, 176 | dim_in=fc_in_feat, 177 | hidden_dim=cls_hidden_dimension, 178 | drop_out=drop_out 179 | ) 180 | 181 | def forward(self, h_0, h_0_ess, h_1_ess): 182 | tmp = [] 183 | 184 | tmp.append(self.ldgm_0(h_0)) 185 | tmp.append(self.ldgm_0_ess(h_0_ess)) 186 | tmp.append(self.ldgm_1_ess(h_1_ess)) 187 | 188 | x = torch.cat(tmp, dim=1) 189 | 190 | if not self.use_as_feature_extractor: 191 | x = self.cls_head(x) 192 | 193 | return x 194 | 195 | 196 | class PershomBase(nn.Module): 197 | def __init__(self): 198 | super().__init__() 199 | 200 | self.use_super_level_set_filtration = None 201 | self.use_as_feature_extractor = False 202 | 
self.fil = None
203 |         self.cls = None
204 | 
205 |     def forward(self, batch):
206 |         assert self.use_super_level_set_filtration is not None
207 | 
208 |         node_filt = self.fil(batch)
209 | 
210 |         ph_input = []
211 |         for i, j, e in zip(batch.sample_pos[:-1], batch.sample_pos[1:], batch.boundary_info):
212 |             v = node_filt[i:j]
213 |             ph_input.append((v, [e]))
214 | 
215 |         pers = ph(ph_input)
216 | 
217 |         if not self.use_super_level_set_filtration:
218 |             h_0 = [x[0][0] for x in pers]
219 |             h_0_ess = [x[1][0].unsqueeze(1) for x in pers]
220 |             h_1_ess = [x[1][1].unsqueeze(1) for x in pers]
221 | 
222 |         else:
223 |             ph_sup_input = [(-v, e) for v, e in ph_input]
224 |             pers_sup = ph(ph_sup_input)
225 | 
226 |             h_0 = [torch.cat([x[0][0], -(y[0][0])], dim=0) for x, y in zip(pers, pers_sup)]
227 |             h_0_ess = [torch.cat([x[1][0], -(y[1][0])], dim=0).unsqueeze(1) for x, y in zip(pers, pers_sup)]
228 |             h_1_ess = [torch.cat([x[1][1], -(y[1][1])], dim=0).unsqueeze(1) for x, y in zip(pers, pers_sup)]
229 | 
230 |         y_hat = self.cls(h_0, h_0_ess, h_1_ess)
231 | 
232 |         return y_hat
233 | 
234 |     @property
235 |     def feature_dimension(self):
236 |         return self.cls.cls_head[0].in_features
237 | 
238 |     @property
239 |     def use_as_feature_extractor(self):
240 |         return self.cls.use_as_feature_extractor  # returning self.use_as_feature_extractor would recurse forever
241 | 
242 |     @use_as_feature_extractor.setter
243 |     def use_as_feature_extractor(self, val):
244 |         if hasattr(self, 'cls'):
245 |             self.cls.use_as_feature_extractor = val
246 | 
247 | 
248 |     def init_weights(self):
249 |         def init(m):
250 |             if isinstance(m, nn.Linear):
251 |                 torch.nn.init.xavier_uniform_(m.weight)
252 |                 m.bias.data.fill_(0.01)
253 | 
254 |         self.apply(init)
255 | 
256 | 
257 | class PershomLearnedFilt(PershomBase):
258 |     def __init__(self,
259 |                  dataset,
260 |                  use_super_level_set_filtration: bool=None,
261 |                  use_node_degree: bool=None,
262 |                  set_node_degree_uninformative: bool=None,
263 |                  use_node_label: bool=None,
264 |                  gin_number: int=None,
265 |                  gin_dimension: int=None,
266 |                  gin_mlp_type: str=None,
267 |                  num_struct_elements: int=None,
268 |                  cls_hidden_dimension: int=None,
269 |                  drop_out: float=None,
270 |                  **kwargs,
271 |                  ):
272 |         super().__init__()
273 | 
274 | 
275 |         self.use_super_level_set_filtration = use_super_level_set_filtration
276 | 
277 |         self.fil = Filtration(
278 |             dataset,
279 |             use_node_degree=use_node_degree,
280 |             set_node_degree_uninformative=set_node_degree_uninformative,
281 |             use_node_label=use_node_label,
282 |             gin_number=gin_number,
283 |             gin_dimension=gin_dimension,
284 |             gin_mlp_type=gin_mlp_type,
285 |         )
286 | 
287 |         self.cls = PershomClassifier(
288 |             dataset,
289 |             num_struct_elements=num_struct_elements,
290 |             cls_hidden_dimension=cls_hidden_dimension,
291 |             drop_out=drop_out
292 |         )
293 | 
294 |         self.init_weights()
295 | 
296 | 
297 | class PershomRigidDegreeFilt(PershomBase):
298 |     def __init__(self,
299 |                  dataset,
300 |                  use_super_level_set_filtration: bool=None,
301 |                  num_struct_elements: int=None,
302 |                  cls_hidden_dimension: int=None,
303 |                  drop_out: float=None,
304 |                  **kwargs,
305 |                  ):
306 |         super().__init__()
307 | 
308 | 
309 |         self.use_super_level_set_filtration = use_super_level_set_filtration
310 | 
311 |         self.fil = DegreeOnlyFiltration()
312 | 
313 |         self.cls = PershomClassifier(
314 |             dataset,
315 |             num_struct_elements=num_struct_elements,
316 |             drop_out=drop_out,
317 |             cls_hidden_dimension=cls_hidden_dimension
318 |         )
319 | 
320 |         self.init_weights()
321 | 
322 | 
323 | class OneHotEmbedding(nn.Module):
324 |     def __init__(self, dim):
325 |         super().__init__()
326 |         eye = torch.eye(dim, dtype=torch.float)
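        # 'eye' is registered as a buffer (not a parameter): it is never trained, but it
        # moves with the module on .to(device) and is included in the state_dict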
327 | 328 | self.register_buffer('eye', eye) 329 | 330 | def forward(self, batch): 331 | assert batch.dtype == torch.long 332 | 333 | return self.eye.index_select(0, batch) 334 | 335 | @property 336 | def dim(self): 337 | return self.eye.size(1) 338 | 339 | 340 | class UniformativeDummyEmbedding(nn.Module): 341 | def __init__(self, dim): 342 | super().__init__() 343 | b = torch.ones(1, dim, dtype=torch.float) 344 | self.register_buffer('ones', b) 345 | 346 | def forward(self, batch): 347 | assert batch.dtype == torch.long 348 | return self.ones.expand(batch.size(0), -1) 349 | 350 | @property 351 | def dim(self): 352 | return self.ones.size(1) 353 | 354 | 355 | class GIN(nn.Module): 356 | def __init__(self, 357 | dataset, 358 | use_node_degree: bool=None, 359 | use_node_label: bool=None, 360 | gin_number: int=None, 361 | gin_dimension: int=None, 362 | gin_mlp_type: str=None, 363 | cls_hidden_dimension: int=None, 364 | drop_out: float=None, 365 | set_node_degree_uninformative: bool=None, 366 | pooling_strategy: str=None, 367 | **kwargs, 368 | ): 369 | super().__init__() 370 | self.use_as_feature_extractor = False 371 | self.pooling_strategy = pooling_strategy 372 | self.gin_dimension = gin_dimension 373 | 374 | dim = gin_dimension 375 | 376 | max_node_deg = dataset.max_node_deg 377 | num_node_lab = dataset.num_node_lab 378 | 379 | if set_node_degree_uninformative and use_node_degree: 380 | self.embed_deg = UniformativeDummyEmbedding(gin_dimension) 381 | elif use_node_degree: 382 | self.embed_deg = OneHotEmbedding(max_node_deg+1) 383 | else: 384 | self.embed_deg = None 385 | 386 | self.embed_lab = OneHotEmbedding(num_node_lab) if use_node_label else None 387 | 388 | dim_input = 0 389 | dim_input += self.embed_deg.dim if use_node_degree else 0 390 | dim_input += self.embed_lab.dim if use_node_label else 0 391 | assert dim_input > 0 392 | 393 | dims = [dim_input] + (gin_number)*[dim] 394 | self.convs = nn.ModuleList() 395 | self.bns = nn.ModuleList() 396 | self.act = torch.nn.functional.leaky_relu 397 | 398 | for n_1, n_2 in zip(dims[:-1], dims[1:]): 399 | l = gin_mlp_factory(gin_mlp_type, n_1, n_2) 400 | self.convs.append(GINConv(l, train_eps=True)) 401 | self.bns.append(nn.BatchNorm1d(n_2)) 402 | 403 | if pooling_strategy == 'sum': 404 | self.global_pool_fn = global_add_pool 405 | elif pooling_strategy == 'sort': 406 | self.k = int(np.percentile([d.num_nodes for d in dataset], 10)) 407 | self.global_pool_fn = functools.partial(global_sort_pool, k=self.k) 408 | self.sort_pool_nn = nn.Linear(self.k * gin_dimension, gin_dimension) 409 | #nn.Conv1d( 410 | # in_channels=gin_dimension, 411 | # out_channels=gin_dimension, 412 | # kernel_size=self.k 413 | #) 414 | else: 415 | raise ValueError 416 | 417 | self.cls = ClassifierHead( 418 | dataset, 419 | dim_in=gin_dimension, 420 | hidden_dim=cls_hidden_dimension, 421 | drop_out=drop_out 422 | ) 423 | 424 | @property 425 | def feature_dimension(self): 426 | return self.cls.cls_head[0].in_features 427 | 428 | def forward(self, batch): 429 | 430 | node_deg = batch.node_deg 431 | node_lab = batch.node_lab 432 | 433 | edge_index = batch.edge_index 434 | 435 | tmp = [e(x) for e, x in 436 | zip([self.embed_deg, self.embed_lab], [node_deg, node_lab]) 437 | if e is not None] 438 | 439 | tmp = torch.cat(tmp, dim=1) 440 | 441 | z = [tmp] 442 | 443 | for conv, bn in zip(self.convs, self.bns): 444 | x = conv(z[-1], edge_index) 445 | x = bn(x) 446 | x = self.act(x) 447 | z.append(x) 448 | 449 | # x = torch.cat(z, dim=1) 450 | x = z[-1] 451 | x = 
self.global_pool_fn(x, batch.batch) 452 | 453 | if self.pooling_strategy == 'sort': 454 | #x = x.view(x.size(0), self.gin_dimension * self.k) 455 | x = self.sort_pool_nn(x) 456 | x = x.squeeze() 457 | 458 | if not self.use_as_feature_extractor: 459 | x = self.cls(x) 460 | 461 | return x 462 | 463 | 464 | class SimpleNNBaseline(nn.Module): 465 | def __init__(self, 466 | dataset, 467 | use_node_degree: bool=None, 468 | use_node_label: bool=None, 469 | set_node_degree_uninformative: bool=None, 470 | gin_dimension: int=None, 471 | gin_mlp_type: str=None, 472 | cls_hidden_dimension: int=None, 473 | drop_out: float=None, 474 | pooling_strategy: str=None, 475 | **kwargs, 476 | ): 477 | super().__init__() 478 | self.use_as_feature_extractor = False 479 | self.pooling_strategy = pooling_strategy 480 | self.gin_dimension = gin_dimension 481 | 482 | dim = gin_dimension 483 | 484 | max_node_deg = dataset.max_node_deg 485 | num_node_lab = dataset.num_node_lab 486 | 487 | if set_node_degree_uninformative and use_node_degree: 488 | self.embed_deg = UniformativeDummyEmbedding(gin_dimension) 489 | elif use_node_degree: 490 | self.embed_deg = OneHotEmbedding(max_node_deg+1) 491 | else: 492 | self.embed_deg = None 493 | 494 | self.embed_lab = OneHotEmbedding(num_node_lab) if use_node_label else None 495 | 496 | dim_input = 0 497 | dim_input += self.embed_deg.dim if use_node_degree else 0 498 | dim_input += self.embed_lab.dim if use_node_label else 0 499 | assert dim_input > 0 500 | 501 | self.mlp = gin_mlp_factory(gin_mlp_type, dim_input, dim) 502 | 503 | if pooling_strategy == 'sum': 504 | self.global_pool_fn = global_add_pool 505 | elif pooling_strategy == 'sort': 506 | self.k = int(np.percentile([d.num_nodes for d in dataset], 10)) 507 | self.global_pool_fn = functools.partial(global_sort_pool, k=self.k) 508 | self.sort_pool_nn = nn.Linear(self.k * gin_dimension, gin_dimension) 509 | else: 510 | raise ValueError 511 | 512 | self.cls = ClassifierHead( 513 | dataset, 514 | dim_in=gin_dimension, 515 | hidden_dim=cls_hidden_dimension, 516 | drop_out=drop_out 517 | ) 518 | 519 | @property 520 | def feature_dimension(self): 521 | return self.cls.cls_head[0].in_features 522 | 523 | def forward(self, batch): 524 | 525 | node_deg = batch.node_deg 526 | node_lab = batch.node_lab 527 | 528 | edge_index = batch.edge_index 529 | 530 | tmp = [e(x) for e, x in 531 | zip([self.embed_deg, self.embed_lab], [node_deg, node_lab]) 532 | if e is not None] 533 | 534 | x = torch.cat(tmp, dim=1) 535 | 536 | x = self.mlp(x) 537 | x = self.global_pool_fn(x, batch.batch) 538 | 539 | if self.pooling_strategy == 'sort': 540 | x = self.sort_pool_nn(x) 541 | x = x.squeeze() 542 | 543 | if not self.use_as_feature_extractor: 544 | x = self.cls(x) 545 | 546 | return x 547 | 548 | -------------------------------------------------------------------------------- /core/nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DeepSets(nn.Module): 6 | aggregation_f = { 7 | 'add': torch.sum, 8 | 'mean': torch.mean, 9 | 'max' : lambda *args, **kwargs: torch.max(*args, **kwargs)[0], 10 | 'min' : lambda *args, **kwargs: torch.min(*args, **kwargs)[0] 11 | } 12 | 13 | def __init__(self, feat_out, point_dim, aggr='add'): 14 | super().__init__() 15 | 16 | self.aggregate = self.aggregation_f[aggr] 17 | 18 | self.phi = nn.Linear(point_dim, feat_out) 19 | 20 | def forward(self, batch): 21 | slice_i = [0] + [t.size(0) for t in batch] 22 | slice_i = 
torch.tensor(slice_i).cumsum(0) 23 | 24 | x = torch.cat(batch, dim=0) 25 | x = self.phi(x) 26 | 27 | tmp = [] 28 | for i, j in zip(slice_i[:-1], slice_i[1:]): 29 | tmp.append(self.aggregate(x[i:j], dim=0)) 30 | 31 | x = torch.stack(tmp, 0) 32 | 33 | return x 34 | -------------------------------------------------------------------------------- /core/train_engine.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import itertools 3 | import copy 4 | import uuid 5 | import pickle 6 | import datetime 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torch_geometric 11 | 12 | import numpy as np 13 | 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | 17 | import chofer_torchex.pershom as pershom 18 | 19 | from torch.nn import Sequential, Linear, ReLU 20 | from torch.optim.lr_scheduler import MultiStepLR 21 | 22 | from torch_geometric.datasets import TUDataset 23 | from torch_geometric.nn import GINConv, global_add_pool 24 | 25 | from chofer_torchex import pershom 26 | ph = pershom.pershom_backend.__C.VertFiltCompCuda__vert_filt_persistence_batch 27 | 28 | from chofer_torchex.nn import SLayerRationalHat 29 | from collections import defaultdict, Counter 30 | 31 | from .model import PershomLearnedFilt, PershomRigidDegreeFilt, GIN, SimpleNNBaseline 32 | from .data import dataset_factory, train_test_val_split 33 | from .utils import my_collate, evaluate 34 | 35 | 36 | import torch.multiprocessing as mp 37 | # try: 38 | # mp.set_start_method('spawn') 39 | # except RuntimeError: 40 | # pass 41 | 42 | 43 | __training_cfg = { 44 | 'lr': float, 45 | 'lr_drop_fact': float, 46 | 'num_epochs': int, 47 | 'epoch_step': int, 48 | 'batch_size': int, 49 | 'weight_decay': float, 50 | 'validation_ratio': float, 51 | } 52 | 53 | 54 | __model_cfg_meta = { 55 | 'model_type': str, 56 | 'use_super_level_set_filtration': bool, 57 | 'use_node_degree': bool, 58 | 'set_node_degree_uninformative': bool, 59 | 'pooling_strategy': str, 60 | 'use_node_label': bool, 61 | 'gin_number': int, 62 | 'gin_dimension': int, 63 | 'gin_mlp_type': str, 64 | 'num_struct_elements': int, 65 | 'cls_hidden_dimension': int, 66 | 'drop_out': float, 67 | } 68 | 69 | 70 | __exp_cfg_meta = { 71 | 'dataset_name': str, 72 | 'training': __training_cfg, 73 | 'model': __model_cfg_meta, 74 | 'tag': str 75 | } 76 | 77 | 78 | __exp_res_meta = { 79 | 'exp_cfg': __exp_cfg_meta, 80 | 'cv_test_acc': list, 81 | 'cv_val_acc': list, 82 | 'cv_indices_trn_tst_val': list, 83 | 'cv_epoch_loss': list, 84 | 'start_time': list, 85 | 'id': str 86 | } 87 | 88 | 89 | def model_factory(model_cfg: dict, dataset): 90 | str_2_type = { 91 | 'PershomRigidDegreeFilt': PershomRigidDegreeFilt, 92 | 'PershomLearnedFilt': PershomLearnedFilt, 93 | 'GIN': GIN, 94 | 'SimpleNNBaseline': SimpleNNBaseline 95 | } 96 | 97 | model_type = model_cfg['model_type'] 98 | Model = str_2_type[model_type] 99 | return Model(dataset, **model_cfg) 100 | 101 | 102 | def experiment(exp_cfg, device, output_dir=None, verbose=True, output_cache=None): 103 | 104 | training_cfg = exp_cfg['training'] 105 | 106 | model_cfg = exp_cfg['model'] 107 | 108 | torch.manual_seed(0) 109 | np.random.seed(0) 110 | torch.cuda.manual_seed_all(0) 111 | 112 | dataset = dataset_factory(exp_cfg['dataset_name'], verbose=verbose) 113 | split_ds, split_i = train_test_val_split( 114 | dataset, 115 | validation_ratio=training_cfg['validation_ratio'], 116 | verbose=verbose) 117 | 118 | cv_test_acc = [[] for _ in range(len(split_ds))] 119 | 
cv_val_acc = [[] for _ in range(len(split_ds))]
120 |     cv_epoch_loss = [[] for _ in range(len(split_ds))]
121 | 
122 |     uiid = str(uuid.uuid4())
123 | 
124 |     if output_dir is not None:
125 |         output_path = osp.join(output_dir, uiid + '.pickle')
126 | 
127 |     ret = {} if output_cache is None else output_cache
128 | 
129 |     ret['exp_cfg'] = exp_cfg
130 |     ret['cv_test_acc'] = cv_test_acc
131 |     ret['cv_val_acc'] = cv_val_acc
132 |     ret['cv_indices_trn_tst_val'] = split_i
133 |     ret['cv_epoch_loss'] = cv_epoch_loss
134 |     ret['start_time'] = str(datetime.datetime.now())
135 |     ret['id'] = uiid
136 |     ret['finished_training'] = False
137 | 
138 |     for fold_i, (train_split, test_split, validation_split) in enumerate(split_ds):
139 | 
140 |         model = model_factory(model_cfg, dataset).to(device)
141 | 
142 |         if verbose and fold_i == 0:
143 |             print(model)
144 | 
145 |         opt = optim.Adam(
146 |             model.parameters(),
147 |             lr=training_cfg['lr'],
148 |             weight_decay=training_cfg['weight_decay']
149 |         )
150 | 
151 |         scheduler = MultiStepLR(opt,
152 |                                 milestones=list(range(0,
153 |                                                       training_cfg['num_epochs'],
154 |                                                       training_cfg['epoch_step'])
155 |                                                 )[1:],
156 |                                 gamma=training_cfg['lr_drop_fact'])
157 | 
158 |         dl_train = torch.utils.data.DataLoader(
159 |             train_split,
160 |             collate_fn=my_collate,
161 |             batch_size=training_cfg['batch_size'],
162 |             shuffle=True,
163 |             # if last batch would have size 1 we have to drop it ...
164 |             drop_last=(len(train_split) % training_cfg['batch_size'] == 1)
165 |         )
166 | 
167 |         dl_test = torch.utils.data.DataLoader(
168 |             test_split,
169 |             collate_fn=my_collate,
170 |             batch_size=64,
171 |             shuffle=False
172 |         )
173 | 
174 |         dl_val = None
175 |         if training_cfg['validation_ratio'] > 0:
176 |             dl_val = torch.utils.data.DataLoader(
177 |                 validation_split,
178 |                 collate_fn=my_collate,
179 |                 batch_size=64,
180 |                 shuffle=False
181 |             )
182 | 
183 |         for epoch_i in range(1, training_cfg['num_epochs']+1):
184 | 
185 |             model.train()
186 |             scheduler.step()
187 |             epoch_loss = 0
188 | 
189 |             for batch_i, batch in enumerate(dl_train, start=1):
190 | 
191 |                 batch = batch.to(device)
192 |                 if not hasattr(batch, 'node_lab'): batch.node_lab = None
193 |                 batch.boundary_info = [e.to(device) for e in batch.boundary_info]
194 | 
195 |                 y_hat = model(batch)
196 | 
197 |                 loss = torch.nn.functional.cross_entropy(y_hat, batch.y)
198 |                 opt.zero_grad()
199 |                 loss.backward()
200 |                 epoch_loss += loss.item()
201 |                 opt.step()
202 | 
203 |                 if verbose:
204 |                     print("Epoch {}/{}, Batch {}/{}".format(
205 |                         epoch_i,
206 |                         training_cfg['num_epochs'],
207 |                         batch_i,
208 |                         len(dl_train)),
209 |                         end='\r')
210 | 
211 |                 # break # todo remove!!!
212 | 
213 |             if verbose: print('')
214 | 
215 |             test_acc = evaluate(dl_test, model, device)
216 |             cv_test_acc[fold_i].append(test_acc*100.0)
217 |             cv_epoch_loss[fold_i].append(epoch_loss)
218 | 
219 |             val_acc = None
220 |             if training_cfg['validation_ratio'] > 0.0:
221 |                 val_acc = evaluate(dl_val, model, device)
222 |                 cv_val_acc[fold_i].append(val_acc*100.0)
223 | 
224 |             if verbose: print("loss {:.2f} | test_acc {:.2f} | val_acc {}".format(epoch_loss, test_acc*100.0, 'n/a' if val_acc is None else '{:.2f}'.format(val_acc*100.0)))  # val_acc is None if validation_ratio == 0
225 | 
226 |         # break #todo remove!!!
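        # after each fold: checkpoint the trained model and re-dump the (partially filled)
        # result dict, so completed folds survive a crash in a later fold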
227 | 228 | if output_dir is not None: 229 | model_file = osp.join(output_dir, uiid + '_model_{}.pht'.format(fold_i)) 230 | torch.save(model.to('cpu'), model_file) 231 | 232 | with open(output_path, 'bw') as fid: 233 | pickle.dump(file=fid, obj=ret) 234 | 235 | ret['finished_training'] = True 236 | if output_dir is not None: 237 | with open(output_path, 'bw') as fid: 238 | pickle.dump(file=fid, obj=ret) 239 | 240 | return ret 241 | 242 | 243 | def experiment_task(args): 244 | 245 | exp_cfg, output_dir, device_counter, lock, max_process_on_device = args 246 | 247 | with lock: 248 | device = None 249 | for k, v in device_counter.items(): 250 | if v < max_process_on_device: 251 | device_id = k 252 | device = 'cuda:{}'.format(device_id) 253 | 254 | break 255 | device_counter[device_id] += 1 256 | 257 | assert device is not None 258 | 259 | try: 260 | print(exp_cfg['dataset_name']) 261 | experiment(exp_cfg, device, output_dir=output_dir, verbose=False) 262 | device_counter[device_id] -= 1 263 | 264 | except Exception as ex: 265 | ex.exp_cfg = exp_cfg 266 | device_counter[device_id] -= 1 267 | 268 | return ex 269 | 270 | 271 | def experiment_multi_device(exp_cfgs, output_dir, visible_devices, max_process_on_device): 272 | assert isinstance(exp_cfgs, list) 273 | assert isinstance(visible_devices, list) 274 | assert osp.isdir(output_dir) 275 | assert all((i < torch.cuda.device_count() for i in visible_devices)) 276 | 277 | num_device = len(visible_devices) 278 | 279 | manager = mp.Manager() 280 | device_counter = manager.dict({t: 0 for t in visible_devices}) 281 | lock = manager.Lock() 282 | 283 | task_args = [(exp_cfg, output_dir, device_counter, lock, max_process_on_device) for exp_cfg in exp_cfgs] 284 | 285 | ret = [] 286 | with mp.Pool(num_device*max_process_on_device, maxtasksperchild=1) as pool: 287 | 288 | for i, r in enumerate(pool.imap_unordered(experiment_task, task_args)): 289 | ret.append(r) 290 | 291 | if r is None: 292 | print("# Finished job {}/{}".format(i + 1, len(task_args))) 293 | 294 | else: 295 | print("#") 296 | print("# Error in job {}/{}".format(i, len(task_args))) 297 | print("#") 298 | print("# Error:") 299 | print(r) 300 | print("# experiment configuration:") 301 | print(r.exp_cfg) 302 | 303 | ret = [r for r in ret if r is not None] 304 | if len(ret) > 0: 305 | with open(osp.join(output_dir, 'errors.pickle'), 'bw') as fid: 306 | pickle.dump(obj=ret, file=fid) 307 | -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import torch_geometric 2 | import torch 3 | 4 | 5 | def my_collate(data_list): 6 | ret = torch_geometric.data.Batch().from_data_list(data_list) 7 | 8 | boundary_info = [] 9 | sample_pos = [0] 10 | for d in data_list: 11 | boundary_info.append(d.boundary_info) 12 | sample_pos.append(d.num_nodes) 13 | 14 | ret.sample_pos = torch.tensor(sample_pos).cumsum(0) 15 | ret.boundary_info = boundary_info 16 | 17 | return ret 18 | 19 | 20 | def evaluate(dataloader, model, device): 21 | num_samples = 0 22 | correct = 0 23 | 24 | model = model.eval().to(device) 25 | 26 | with torch.no_grad(): 27 | for batch in dataloader: 28 | batch = batch.to(device) 29 | if not hasattr(batch, 'node_lab'): batch.node_lab = None 30 | batch.boundary_info = [e.to(device) for e in batch.boundary_info] 31 | 32 | y_hat = model(batch) 33 | 34 | y_pred = y_hat.max(dim=1)[1] 35 | 36 | correct += (y_pred == batch.y).sum().item() 37 | num_samples += batch.y.size(0) 38 | 
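    # accuracy = fraction of correctly classified graphs over the whole dataloader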
39 |     return float(correct) / float(num_samples)
--------------------------------------------------------------------------------
/core/utils_ipynb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import glob
4 | 
5 | 
6 | def read_exp_result_files(path):
7 |     files = glob.glob(os.path.join(path, "*.pickle"))
8 |     res = []
9 |     for f in files:
10 |         if os.path.basename(f) == 'errors.pickle':
11 |             continue
12 | 
13 |         r = pickle.load(open(f, 'rb'))
14 | 
15 |         # older cfgs have no 'set_node_degree_uninformative' ...
16 |         if 'set_node_degree_uninformative' not in r['exp_cfg']['model']:
17 |             r['exp_cfg']['model']['set_node_degree_uninformative'] = False
18 | 
19 |         res.append(r)
20 |     return res
--------------------------------------------------------------------------------
/results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Browse experiment results ..."
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "code",
12 |    "execution_count": 9,
13 |    "metadata": {},
14 |    "outputs": [],
15 |    "source": [
16 |     "import torch\n",
17 |     "import pickle\n",
18 |     "import glob\n",
19 |     "import os.path\n",
20 |     "import pandas as pd\n",
21 |     "import core.train_engine\n",
22 |     "import numpy as np\n",
23 |     "from core.utils_ipynb import read_exp_result_files\n",
24 |     "\n",
25 |     "\n",
26 |     "def get_keychain_value_iter(d, key_chain=None):\n",
27 |     "    key_chain = [] if key_chain is None else list(key_chain).copy() \n",
28 |     "    \n",
29 |     "    if not isinstance(d, dict):\n",
30 |     "        \n",
31 |     "        yield tuple(key_chain), d\n",
32 |     "    else:\n",
33 |     "        for k, v in d.items():\n",
34 |     "            yield from get_keychain_value_iter(v, key_chain + [k])\n",
35 |     "    \n",
36 |     "def get_keychain_value(d, key_chain):\n",
37 |     "    \n",
38 |     "    try:\n",
39 |     "        for k in key_chain:\n",
40 |     "            d = d[k]\n",
41 |     "    \n",
42 |     "    except Exception as ex:\n",
43 |     "        raise KeyError() from ex\n",
44 |     "    \n",
45 |     "    return d"
46 |    ]
47 |   },
48 |   {
49 |    "cell_type": "code",
50 |    "execution_count": 11,
51 |    "metadata": {},
52 |    "outputs": [
53 |     {
54 |      "data": {
55 |       "text/plain": [
56 |        "{('exp_cfg', 'dataset_name'): 'dataset_name',\n",
57 |        " ('exp_cfg', 'training', 'lr'): 'lr',\n",
58 |        " ('exp_cfg', 'training', 'lr_drop_fact'): 'lr_drop_fact',\n",
59 |        " ('exp_cfg', 'training', 'num_epochs'): 'num_epochs',\n",
60 |        " ('exp_cfg', 'training', 'epoch_step'): 'epoch_step',\n",
61 |        " ('exp_cfg', 'training', 'batch_size'): 'batch_size',\n",
62 |        " ('exp_cfg', 'training', 'weight_decay'): 'weight_decay',\n",
63 |        " ('exp_cfg', 'training', 'validation_ratio'): 'validation_ratio',\n",
64 |        " ('exp_cfg', 'model', 'model_type'): 'model_type',\n",
65 |        " ('exp_cfg',\n",
66 |        "  'model',\n",
67 |        "  'use_super_level_set_filtration'): 'use_super_level_set_filtration',\n",
68 |        " ('exp_cfg', 'model', 'use_node_degree'): 'use_node_degree',\n",
69 |        " ('exp_cfg',\n",
70 |        "  'model',\n",
71 |        "  'set_node_degree_uninformative'): 'set_node_degree_uninformative',\n",
72 |        " ('exp_cfg', 'model', 'pooling_strategy'): 'pooling_strategy',\n",
73 |        " ('exp_cfg', 'model', 'use_node_label'): 'use_node_label',\n",
74 |        " ('exp_cfg', 'model', 'gin_number'): 'gin_number',\n",
75 |        " ('exp_cfg', 'model', 'gin_dimension'): 'gin_dimension',\n",
76 |        " ('exp_cfg', 'model', 'gin_mlp_type'): 'gin_mlp_type',\n",
77 |        " ('exp_cfg', 'model', 'num_struct_elements'): 'num_struct_elements',\n",
78 |        " ('exp_cfg', 'model', 
'cls_hidden_dimension'): 'cls_hidden_dimension',\n", 79 | " ('exp_cfg', 'model', 'drop_out'): 'drop_out',\n", 80 | " ('exp_cfg', 'tag'): 'tag',\n", 81 | " ('cv_test_acc',): 'cv_test_acc',\n", 82 | " ('cv_val_acc',): 'cv_val_acc',\n", 83 | " ('cv_indices_trn_tst_val',): 'cv_indices_trn_tst_val',\n", 84 | " ('cv_epoch_loss',): 'cv_epoch_loss',\n", 85 | " ('start_time',): 'start_time',\n", 86 | " ('id',): 'id'}" 87 | ] 88 | }, 89 | "execution_count": 11, 90 | "metadata": {}, 91 | "output_type": "execute_result" 92 | } 93 | ], 94 | "source": [ 95 | "kc = {k: k[-1] for k, v in list(get_keychain_value_iter(core.train_engine.__exp_res_meta))}\n", 96 | "kc" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 12, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "COL_NAMES = {\n", 106 | " ('exp_cfg', 'dataset_name'): 'dataset_name',\n", 107 | " #('exp_cfg', 'tag'): 'tag', \n", 108 | "# ('exp_cfg', 'training', 'lr'): 'lr',\n", 109 | "# ('exp_cfg', 'training', 'lr_drop_fact'): 'lr_drop_fact',\n", 110 | "# ('exp_cfg', 'training', 'num_epochs'): 'num_epochs',\n", 111 | "# ('exp_cfg', 'training', 'epoch_step'): 'epoch_step',\n", 112 | " ('exp_cfg', 'training', 'batch_size'): 'batch_size',\n", 113 | "# ('exp_cfg', 'training', 'weight_decay'): 'weight_decay',\n", 114 | "# ('exp_cfg', 'training', 'validation_ratio'): 'validation_ratio',\n", 115 | " ('exp_cfg', 'model', 'model_type'): 'model_type',\n", 116 | " ('exp_cfg', 'model', 'use_super_level_set_filtration'): 'use_super_level_set_filtration',\n", 117 | " ('exp_cfg', 'model', 'use_node_degree'): 'use_node_degree',\n", 118 | " ('exp_cfg', 'model', 'use_node_label'): 'use_node_label',\n", 119 | " ('exp_cfg', 'model', 'gin_number'): 'gin_number',\n", 120 | " ('exp_cfg', 'model', 'gin_dimension'): 'gin_dimension',\n", 121 | " #('exp_cfg', 'model', 'gin_mlp_type'): 'gin_mlp_type',\n", 122 | " ('exp_cfg', 'model', 'set_node_degree_uninformative'): 'set_node_degree_uninformative',\n", 123 | " ('exp_cfg', 'model', 'num_struct_elements'): 'num_struct_elements',\n", 124 | " ('exp_cfg', 'model', 'drop_out'): 'drop_out',\n", 125 | " ('exp_cfg', 'model', 'pooling_strategy'): 'pooling_strategy',\n", 126 | "# ('cv_test_acc',): 'cv_test_acc',\n", 127 | "# ('cv_val_acc',): 'cv_val_acc',\n", 128 | "# ('cv_indices_trn_tst_val',): 'cv_indices_trn_tst_val',\n", 129 | "# ('cv_epoch_loss',): 'cv_epoch_loss',\n", 130 | "# ('start_time',): 'start_time',\n", 131 | "# ('id',): 'id',\n", 132 | " ('finished_training',): 'finished_training'\n", 133 | "}" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 13, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "def pd_frame(path):\n", 143 | " \n", 144 | " f = read_exp_result_files(path)\n", 145 | " \n", 146 | " data_frames = []\n", 147 | " for i, res in enumerate(f):\n", 148 | " row = {}\n", 149 | " \n", 150 | " cv_acc_last = [x[-1] for x in res['cv_test_acc'] if len(x) > 0]\n", 151 | " \n", 152 | " row['acc_last_mean'] = np.mean(cv_acc_last)\n", 153 | " row['acc_last_std'] = np.std(cv_acc_last)\n", 154 | " \n", 155 | " cv_acc_validated = []\n", 156 | " for test, val in zip(res['cv_test_acc'], res['cv_val_acc']):\n", 157 | " if not len(test) == res['exp_cfg']['training']['num_epochs']:\n", 158 | " continue\n", 159 | " n = len(test)//2\n", 160 | " test = torch.tensor(test[n:])\n", 161 | " val = torch.tensor(val[n:])\n", 162 | " #test = torch.tensor(test)\n", 163 | " #val = torch.tensor(val)\n", 164 | " \n", 165 | " \n", 166 | " _, i_max = val.max(0)\n", 
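    "        # i_max = epoch (within the kept window) with the best validation accuracy;\n",
    "        # report the test accuracy at that epoch\n",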
167 | " cv_acc_validated.append(test[i_max].item())\n", 168 | " \n", 169 | " row['acc_val_mean'] = np.mean(cv_acc_validated)\n", 170 | " row['acc_val_std'] = np.std(cv_acc_validated)\n", 171 | " \n", 172 | " \n", 173 | " cv_folds_available = sum([1 for cv in res['cv_test_acc'] if len(cv) == res['exp_cfg']['training']['num_epochs']])\n", 174 | " row['cv_folds_available'] = cv_folds_available\n", 175 | " \n", 176 | " \n", 177 | " for k, v in COL_NAMES.items():\n", 178 | " try:\n", 179 | " row[v] = get_keychain_value(res, k)\n", 180 | " except KeyError:\n", 181 | " pass\n", 182 | "\n", 183 | " f = pd.DataFrame(row, index=[i])\n", 184 | " \n", 185 | " data_frames.append(f)\n", 186 | " \n", 187 | " \n", 188 | " return pd.concat(data_frames, sort=True)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 63, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "data": { 198 | "text/html": [ 199 | "
       "[HTML output truncated: a rendered pandas DataFrame, one row per experiment, with the columns acc_last_mean, acc_last_std, acc_val_mean, acc_val_std, batch_size, cv_folds_available, dataset_name, drop_out, finished_training, gin_dimension, gin_number, model_type, num_struct_elements, pooling_strategy, set_node_degree_uninformative, use_node_degree, use_node_label, use_super_level_set_filtration]"
073.3132241.43377272.7710752.7186536410PROTEINS0.0True64.01.0PershomLearnedFilt100.0NaNFalseTrueFalseTrue
150.0666673.21731450.6666672.6161896410IMDB-MULTI0.5True64.01.0GINNaNNaNFalseTrueFalseNaN
249.3333334.38178050.6000003.0324186410IMDB-MULTI0.5True64.05.0GINNaNNaNFalseTrueFalseNaN
372.6000003.23109973.6000003.2619016410IMDB-BINARY0.5True64.01.0GINNaNNaNFalseTrueFalseNaN
474.9000004.43734273.1000003.1764766410IMDB-BINARY0.5True64.05.0GINNaNNaNFalseTrueFalseNaN
581.8384871.64084682.0588491.7318146410COLLAB0.5True64.01.0GINNaNNaNFalseTrueFalseNaN
682.3197691.66985882.1394091.8947226410COLLAB0.5True64.05.0GINNaNNaNFalseTrueFalseNaN
738.3333336.14636337.1666674.6577306410ENZYMES0.5True64.01.0GINNaNNaNFalseTrueTrueNaN
836.3333333.71184334.3333336.5912406410ENZYMES0.5True64.05.0GINNaNNaNFalseTrueTrueNaN
930.8333336.46572230.3333336.6583286410ENZYMES0.5True64.01.0GINNaNNaNFalseFalseTrueNaN
1035.1666678.48036935.6666675.1207646410ENZYMES0.5True64.05.0GINNaNNaNFalseFalseTrueNaN
1127.0000005.71547627.3333336.6332506410ENZYMES0.5True64.01.0GINNaNNaNFalseTrueFalseNaN
1226.1666676.41396025.3333334.8189446410ENZYMES0.5True64.05.0GINNaNNaNFalseTrueFalseNaN
1356.0620929.29233055.49019510.1123076410PTC_PGNN0.5True64.01.0GINNaNNaNFalseTrueTrueNaN
1454.0522887.44954753.1535958.5293646410PTC_PGNN0.5True64.05.0GINNaNNaNFalseTrueTrueNaN
1557.2549026.30778258.1535955.4027676410PTC_PGNN0.5True64.01.0GINNaNNaNFalseFalseTrueNaN
1659.8692815.42336057.2875824.3503346410PTC_PGNN0.5True64.05.0GINNaNNaNFalseFalseTrueNaN
1753.1862756.75342852.8921577.6578226410PTC_PGNN0.5True64.01.0GINNaNNaNFalseTrueFalseNaN
1858.1045758.90881155.84967310.1981726410PTC_PGNN0.5True64.05.0GINNaNNaNFalseTrueFalseNaN
1973.9422463.01972874.3926962.4926736410PROTEINS0.5True64.01.0GINNaNNaNFalseTrueTrueNaN
2070.3547303.50981071.0706254.3919806410PROTEINS0.5True64.05.0GINNaNNaNFalseTrueTrueNaN
2149.7333333.10125448.8000002.6465916410IMDB-MULTI0.5True64.01.0PershomLearnedFilt100.0NaNFalseTrueFalseTrue
2274.7474262.86984675.1994853.4270106410PROTEINS0.5True64.01.0GINNaNNaNFalseFalseTrueNaN
2372.2281215.77580172.8587515.9290556410PROTEINS0.5True64.05.0GINNaNNaNFalseFalseTrueNaN
2472.4123234.51380171.6956244.1803406410PROTEINS0.5True64.01.0GINNaNNaNFalseTrueFalseNaN
2571.5210752.90310970.0780244.6230636410PROTEINS0.5True64.05.0GINNaNNaNFalseTrueFalseNaN
2673.4249694.30704374.4426192.9277676410DD0.5True64.01.0GINNaNNaNFalseTrueTrueNaN
2772.2283884.62937273.3358533.4100576410DD0.5True64.05.0GINNaNNaNFalseTrueTrueNaN
2871.8894053.64220576.4744293.0766226410DD0.5True64.01.0GINNaNNaNFalseFalseTrueNaN
2970.9514564.87158871.3809783.6158706410DD0.5True64.05.0GINNaNNaNFalseFalseTrueNaN
.........................................................
8775.1085912.84673074.2084944.0077763210PROTEINS0.5True64.01.0GINNaNsumFalseTrueTrueNaN
8871.8709783.78162472.0543762.3590813210PROTEINS0.5True64.05.0GINNaNsumFalseTrueTrueNaN
8975.1093953.26479174.5696593.3462163210PROTEINS0.5True64.01.0GINNaNsumFalseFalseTrueNaN
9073.3148334.38129473.4017064.7487033210PROTEINS0.5True64.05.0GINNaNsumFalseFalseTrueNaN
9172.0543763.42290672.5040213.5290703210PROTEINS0.5True64.01.0GINNaNsumFalseTrueFalseNaN
9271.1534753.73358171.8774134.1353543210PROTEINS0.5True64.05.0GINNaNsumFalseTrueFalseNaN
9377.4445832.05800177.2745611.9388173210NCI10.5True64.01.0GINNaNsumFalseTrueTrueNaN
9480.8502101.66411180.8259372.0470403210NCI10.5True64.05.0GINNaNsumFalseTrueTrueNaN
9575.3757391.87580674.8654372.0666273210NCI10.5True64.01.0GINNaNsumFalseFalseTrueNaN
9681.6526581.71001681.4095851.4846633210NCI10.5True64.05.0GINNaNsumFalseFalseTrueNaN
9767.3955302.98565467.3947003.0547843210NCI10.5True64.01.0GINNaNsumFalseTrueFalseNaN
9875.8626552.67155775.7421842.5020573210NCI10.5True64.05.0GINNaNsumFalseTrueFalseNaN
9956.9771247.45930758.7745106.9884413210PTC_MR0.0True64.01.0PershomLearnedFilt100.0NaNFalseTrueTrueTrue
10059.95098011.25779058.2026139.5124983210PTC_MR0.0True64.01.0PershomLearnedFilt100.0NaNFalseFalseTrueTrue
10157.5490205.52703255.2287576.1212383210PTC_MR0.0True64.01.0PershomLearnedFilt100.0NaNFalseTrueFalseTrue
10268.9000003.47706869.3000003.4073453210IMDB-BINARY0.0TrueNaNNaNPershomRigidDegreeFilt100.0NaNFalseNaNNaNTrue
10346.0666674.22111145.9333334.7370873210IMDB-MULTI0.0TrueNaNNaNPershomRigidDegreeFilt100.0NaNFalseNaNNaNTrue
10490.3000002.56125090.5000002.5980763210REDDIT-BINARY0.0TrueNaNNaNPershomRigidDegreeFilt100.0NaNFalseNaNNaNTrue
10555.6706612.05450055.2108221.7014933210REDDIT-MULTI-5K0.0TrueNaNNaNPershomRigidDegreeFilt100.0NaNFalseNaNNaNTrue
10673.4901872.96869673.1314352.8602163210PROTEINS0.0TrueNaNNaNPershomRigidDegreeFilt100.0NaNFalseNaNNaNTrue
10767.7856552.24487267.9557932.4534253210NCI10.0TrueNaNNaNPershomRigidDegreeFilt100.0NaNFalseNaNNaNTrue
10877.2483932.60097576.9077592.6333303210NCI10.0True64.01.0PershomLearnedFilt100.0NaNFalseTrueTrueTrue
10972.4000004.60868771.9000004.0853403210IMDB-BINARY0.0True16.01.0PershomLearnedFilt100.0NaNFalseTrueFalseTrue
11048.7000001.48660750.0499991.657559324REDDIT-MULTI-5K0.0False64.01.0PershomLearnedFilt100.0NaNFalseTrueFalseTrue
11171.6000004.22374271.3000003.9000003210IMDB-BINARY0.0True16.03.0PershomLearnedFilt100.0NaNFalseTrueFalseTrue
11272.7000004.42831873.0000004.6260133210IMDB-BINARY0.0True16.05.0PershomLearnedFilt100.0NaNFalseTrueFalseTrue
11388.8571432.18295889.6428571.940440327REDDIT-BINARY0.0False64.01.0PershomLearnedFilt100.0NaNTrueTrueFalseTrue
11486.6666672.09496885.3333332.778889323REDDIT-BINARY0.5False64.01.0GINNaNsortFalseTrueFalseNaN
11584.7500000.25000084.2500000.750000322REDDIT-BINARY0.5False64.05.0GINNaNsortFalseTrueFalseNaN
11674.2718450.00000072.3300930.000000321NCI10.0False64.01.0PershomLearnedFilt100.0NaNFalseFalseTrueTrue
\n", 1521 | "

117 rows × 18 columns

\n", 1522 | "
" 1523 | ], 1524 | "text/plain": [ 1525 | " acc_last_mean acc_last_std acc_val_mean acc_val_std batch_size \\\n", 1526 | "0 73.313224 1.433772 72.771075 2.718653 64 \n", 1527 | "1 50.066667 3.217314 50.666667 2.616189 64 \n", 1528 | "2 49.333333 4.381780 50.600000 3.032418 64 \n", 1529 | "3 72.600000 3.231099 73.600000 3.261901 64 \n", 1530 | "4 74.900000 4.437342 73.100000 3.176476 64 \n", 1531 | "5 81.838487 1.640846 82.058849 1.731814 64 \n", 1532 | "6 82.319769 1.669858 82.139409 1.894722 64 \n", 1533 | "7 38.333333 6.146363 37.166667 4.657730 64 \n", 1534 | "8 36.333333 3.711843 34.333333 6.591240 64 \n", 1535 | "9 30.833333 6.465722 30.333333 6.658328 64 \n", 1536 | "10 35.166667 8.480369 35.666667 5.120764 64 \n", 1537 | "11 27.000000 5.715476 27.333333 6.633250 64 \n", 1538 | "12 26.166667 6.413960 25.333333 4.818944 64 \n", 1539 | "13 56.062092 9.292330 55.490195 10.112307 64 \n", 1540 | "14 54.052288 7.449547 53.153595 8.529364 64 \n", 1541 | "15 57.254902 6.307782 58.153595 5.402767 64 \n", 1542 | "16 59.869281 5.423360 57.287582 4.350334 64 \n", 1543 | "17 53.186275 6.753428 52.892157 7.657822 64 \n", 1544 | "18 58.104575 8.908811 55.849673 10.198172 64 \n", 1545 | "19 73.942246 3.019728 74.392696 2.492673 64 \n", 1546 | "20 70.354730 3.509810 71.070625 4.391980 64 \n", 1547 | "21 49.733333 3.101254 48.800000 2.646591 64 \n", 1548 | "22 74.747426 2.869846 75.199485 3.427010 64 \n", 1549 | "23 72.228121 5.775801 72.858751 5.929055 64 \n", 1550 | "24 72.412323 4.513801 71.695624 4.180340 64 \n", 1551 | "25 71.521075 2.903109 70.078024 4.623063 64 \n", 1552 | "26 73.424969 4.307043 74.442619 2.927767 64 \n", 1553 | "27 72.228388 4.629372 73.335853 3.410057 64 \n", 1554 | "28 71.889405 3.642205 76.474429 3.076622 64 \n", 1555 | "29 70.951456 4.871588 71.380978 3.615870 64 \n", 1556 | ".. ... ... ... ... ... 
\n", 1557 | "87 75.108591 2.846730 74.208494 4.007776 32 \n", 1558 | "88 71.870978 3.781624 72.054376 2.359081 32 \n", 1559 | "89 75.109395 3.264791 74.569659 3.346216 32 \n", 1560 | "90 73.314833 4.381294 73.401706 4.748703 32 \n", 1561 | "91 72.054376 3.422906 72.504021 3.529070 32 \n", 1562 | "92 71.153475 3.733581 71.877413 4.135354 32 \n", 1563 | "93 77.444583 2.058001 77.274561 1.938817 32 \n", 1564 | "94 80.850210 1.664111 80.825937 2.047040 32 \n", 1565 | "95 75.375739 1.875806 74.865437 2.066627 32 \n", 1566 | "96 81.652658 1.710016 81.409585 1.484663 32 \n", 1567 | "97 67.395530 2.985654 67.394700 3.054784 32 \n", 1568 | "98 75.862655 2.671557 75.742184 2.502057 32 \n", 1569 | "99 56.977124 7.459307 58.774510 6.988441 32 \n", 1570 | "100 59.950980 11.257790 58.202613 9.512498 32 \n", 1571 | "101 57.549020 5.527032 55.228757 6.121238 32 \n", 1572 | "102 68.900000 3.477068 69.300000 3.407345 32 \n", 1573 | "103 46.066667 4.221111 45.933333 4.737087 32 \n", 1574 | "104 90.300000 2.561250 90.500000 2.598076 32 \n", 1575 | "105 55.670661 2.054500 55.210822 1.701493 32 \n", 1576 | "106 73.490187 2.968696 73.131435 2.860216 32 \n", 1577 | "107 67.785655 2.244872 67.955793 2.453425 32 \n", 1578 | "108 77.248393 2.600975 76.907759 2.633330 32 \n", 1579 | "109 72.400000 4.608687 71.900000 4.085340 32 \n", 1580 | "110 48.700000 1.486607 50.049999 1.657559 32 \n", 1581 | "111 71.600000 4.223742 71.300000 3.900000 32 \n", 1582 | "112 72.700000 4.428318 73.000000 4.626013 32 \n", 1583 | "113 88.857143 2.182958 89.642857 1.940440 32 \n", 1584 | "114 86.666667 2.094968 85.333333 2.778889 32 \n", 1585 | "115 84.750000 0.250000 84.250000 0.750000 32 \n", 1586 | "116 74.271845 0.000000 72.330093 0.000000 32 \n", 1587 | "\n", 1588 | " cv_folds_available dataset_name drop_out finished_training \\\n", 1589 | "0 10 PROTEINS 0.0 True \n", 1590 | "1 10 IMDB-MULTI 0.5 True \n", 1591 | "2 10 IMDB-MULTI 0.5 True \n", 1592 | "3 10 IMDB-BINARY 0.5 True \n", 1593 | "4 10 IMDB-BINARY 0.5 True \n", 1594 | "5 10 COLLAB 0.5 True \n", 1595 | "6 10 COLLAB 0.5 True \n", 1596 | "7 10 ENZYMES 0.5 True \n", 1597 | "8 10 ENZYMES 0.5 True \n", 1598 | "9 10 ENZYMES 0.5 True \n", 1599 | "10 10 ENZYMES 0.5 True \n", 1600 | "11 10 ENZYMES 0.5 True \n", 1601 | "12 10 ENZYMES 0.5 True \n", 1602 | "13 10 PTC_PGNN 0.5 True \n", 1603 | "14 10 PTC_PGNN 0.5 True \n", 1604 | "15 10 PTC_PGNN 0.5 True \n", 1605 | "16 10 PTC_PGNN 0.5 True \n", 1606 | "17 10 PTC_PGNN 0.5 True \n", 1607 | "18 10 PTC_PGNN 0.5 True \n", 1608 | "19 10 PROTEINS 0.5 True \n", 1609 | "20 10 PROTEINS 0.5 True \n", 1610 | "21 10 IMDB-MULTI 0.5 True \n", 1611 | "22 10 PROTEINS 0.5 True \n", 1612 | "23 10 PROTEINS 0.5 True \n", 1613 | "24 10 PROTEINS 0.5 True \n", 1614 | "25 10 PROTEINS 0.5 True \n", 1615 | "26 10 DD 0.5 True \n", 1616 | "27 10 DD 0.5 True \n", 1617 | "28 10 DD 0.5 True \n", 1618 | "29 10 DD 0.5 True \n", 1619 | ".. ... ... ... ... 
\n", 1620 | "87 10 PROTEINS 0.5 True \n", 1621 | "88 10 PROTEINS 0.5 True \n", 1622 | "89 10 PROTEINS 0.5 True \n", 1623 | "90 10 PROTEINS 0.5 True \n", 1624 | "91 10 PROTEINS 0.5 True \n", 1625 | "92 10 PROTEINS 0.5 True \n", 1626 | "93 10 NCI1 0.5 True \n", 1627 | "94 10 NCI1 0.5 True \n", 1628 | "95 10 NCI1 0.5 True \n", 1629 | "96 10 NCI1 0.5 True \n", 1630 | "97 10 NCI1 0.5 True \n", 1631 | "98 10 NCI1 0.5 True \n", 1632 | "99 10 PTC_MR 0.0 True \n", 1633 | "100 10 PTC_MR 0.0 True \n", 1634 | "101 10 PTC_MR 0.0 True \n", 1635 | "102 10 IMDB-BINARY 0.0 True \n", 1636 | "103 10 IMDB-MULTI 0.0 True \n", 1637 | "104 10 REDDIT-BINARY 0.0 True \n", 1638 | "105 10 REDDIT-MULTI-5K 0.0 True \n", 1639 | "106 10 PROTEINS 0.0 True \n", 1640 | "107 10 NCI1 0.0 True \n", 1641 | "108 10 NCI1 0.0 True \n", 1642 | "109 10 IMDB-BINARY 0.0 True \n", 1643 | "110 4 REDDIT-MULTI-5K 0.0 False \n", 1644 | "111 10 IMDB-BINARY 0.0 True \n", 1645 | "112 10 IMDB-BINARY 0.0 True \n", 1646 | "113 7 REDDIT-BINARY 0.0 False \n", 1647 | "114 3 REDDIT-BINARY 0.5 False \n", 1648 | "115 2 REDDIT-BINARY 0.5 False \n", 1649 | "116 1 NCI1 0.0 False \n", 1650 | "\n", 1651 | " gin_dimension gin_number model_type num_struct_elements \\\n", 1652 | "0 64.0 1.0 PershomLearnedFilt 100.0 \n", 1653 | "1 64.0 1.0 GIN NaN \n", 1654 | "2 64.0 5.0 GIN NaN \n", 1655 | "3 64.0 1.0 GIN NaN \n", 1656 | "4 64.0 5.0 GIN NaN \n", 1657 | "5 64.0 1.0 GIN NaN \n", 1658 | "6 64.0 5.0 GIN NaN \n", 1659 | "7 64.0 1.0 GIN NaN \n", 1660 | "8 64.0 5.0 GIN NaN \n", 1661 | "9 64.0 1.0 GIN NaN \n", 1662 | "10 64.0 5.0 GIN NaN \n", 1663 | "11 64.0 1.0 GIN NaN \n", 1664 | "12 64.0 5.0 GIN NaN \n", 1665 | "13 64.0 1.0 GIN NaN \n", 1666 | "14 64.0 5.0 GIN NaN \n", 1667 | "15 64.0 1.0 GIN NaN \n", 1668 | "16 64.0 5.0 GIN NaN \n", 1669 | "17 64.0 1.0 GIN NaN \n", 1670 | "18 64.0 5.0 GIN NaN \n", 1671 | "19 64.0 1.0 GIN NaN \n", 1672 | "20 64.0 5.0 GIN NaN \n", 1673 | "21 64.0 1.0 PershomLearnedFilt 100.0 \n", 1674 | "22 64.0 1.0 GIN NaN \n", 1675 | "23 64.0 5.0 GIN NaN \n", 1676 | "24 64.0 1.0 GIN NaN \n", 1677 | "25 64.0 5.0 GIN NaN \n", 1678 | "26 64.0 1.0 GIN NaN \n", 1679 | "27 64.0 5.0 GIN NaN \n", 1680 | "28 64.0 1.0 GIN NaN \n", 1681 | "29 64.0 5.0 GIN NaN \n", 1682 | ".. ... ... ... ... 
\n", 1683 | "87 64.0 1.0 GIN NaN \n", 1684 | "88 64.0 5.0 GIN NaN \n", 1685 | "89 64.0 1.0 GIN NaN \n", 1686 | "90 64.0 5.0 GIN NaN \n", 1687 | "91 64.0 1.0 GIN NaN \n", 1688 | "92 64.0 5.0 GIN NaN \n", 1689 | "93 64.0 1.0 GIN NaN \n", 1690 | "94 64.0 5.0 GIN NaN \n", 1691 | "95 64.0 1.0 GIN NaN \n", 1692 | "96 64.0 5.0 GIN NaN \n", 1693 | "97 64.0 1.0 GIN NaN \n", 1694 | "98 64.0 5.0 GIN NaN \n", 1695 | "99 64.0 1.0 PershomLearnedFilt 100.0 \n", 1696 | "100 64.0 1.0 PershomLearnedFilt 100.0 \n", 1697 | "101 64.0 1.0 PershomLearnedFilt 100.0 \n", 1698 | "102 NaN NaN PershomRigidDegreeFilt 100.0 \n", 1699 | "103 NaN NaN PershomRigidDegreeFilt 100.0 \n", 1700 | "104 NaN NaN PershomRigidDegreeFilt 100.0 \n", 1701 | "105 NaN NaN PershomRigidDegreeFilt 100.0 \n", 1702 | "106 NaN NaN PershomRigidDegreeFilt 100.0 \n", 1703 | "107 NaN NaN PershomRigidDegreeFilt 100.0 \n", 1704 | "108 64.0 1.0 PershomLearnedFilt 100.0 \n", 1705 | "109 16.0 1.0 PershomLearnedFilt 100.0 \n", 1706 | "110 64.0 1.0 PershomLearnedFilt 100.0 \n", 1707 | "111 16.0 3.0 PershomLearnedFilt 100.0 \n", 1708 | "112 16.0 5.0 PershomLearnedFilt 100.0 \n", 1709 | "113 64.0 1.0 PershomLearnedFilt 100.0 \n", 1710 | "114 64.0 1.0 GIN NaN \n", 1711 | "115 64.0 5.0 GIN NaN \n", 1712 | "116 64.0 1.0 PershomLearnedFilt 100.0 \n", 1713 | "\n", 1714 | " pooling_strategy set_node_degree_uninformative use_node_degree \\\n", 1715 | "0 NaN False True \n", 1716 | "1 NaN False True \n", 1717 | "2 NaN False True \n", 1718 | "3 NaN False True \n", 1719 | "4 NaN False True \n", 1720 | "5 NaN False True \n", 1721 | "6 NaN False True \n", 1722 | "7 NaN False True \n", 1723 | "8 NaN False True \n", 1724 | "9 NaN False False \n", 1725 | "10 NaN False False \n", 1726 | "11 NaN False True \n", 1727 | "12 NaN False True \n", 1728 | "13 NaN False True \n", 1729 | "14 NaN False True \n", 1730 | "15 NaN False False \n", 1731 | "16 NaN False False \n", 1732 | "17 NaN False True \n", 1733 | "18 NaN False True \n", 1734 | "19 NaN False True \n", 1735 | "20 NaN False True \n", 1736 | "21 NaN False True \n", 1737 | "22 NaN False False \n", 1738 | "23 NaN False False \n", 1739 | "24 NaN False True \n", 1740 | "25 NaN False True \n", 1741 | "26 NaN False True \n", 1742 | "27 NaN False True \n", 1743 | "28 NaN False False \n", 1744 | "29 NaN False False \n", 1745 | ".. ... ... ... 
\n", 1746 | "87 sum False True \n", 1747 | "88 sum False True \n", 1748 | "89 sum False False \n", 1749 | "90 sum False False \n", 1750 | "91 sum False True \n", 1751 | "92 sum False True \n", 1752 | "93 sum False True \n", 1753 | "94 sum False True \n", 1754 | "95 sum False False \n", 1755 | "96 sum False False \n", 1756 | "97 sum False True \n", 1757 | "98 sum False True \n", 1758 | "99 NaN False True \n", 1759 | "100 NaN False False \n", 1760 | "101 NaN False True \n", 1761 | "102 NaN False NaN \n", 1762 | "103 NaN False NaN \n", 1763 | "104 NaN False NaN \n", 1764 | "105 NaN False NaN \n", 1765 | "106 NaN False NaN \n", 1766 | "107 NaN False NaN \n", 1767 | "108 NaN False True \n", 1768 | "109 NaN False True \n", 1769 | "110 NaN False True \n", 1770 | "111 NaN False True \n", 1771 | "112 NaN False True \n", 1772 | "113 NaN True True \n", 1773 | "114 sort False True \n", 1774 | "115 sort False True \n", 1775 | "116 NaN False False \n", 1776 | "\n", 1777 | " use_node_label use_super_level_set_filtration \n", 1778 | "0 False True \n", 1779 | "1 False NaN \n", 1780 | "2 False NaN \n", 1781 | "3 False NaN \n", 1782 | "4 False NaN \n", 1783 | "5 False NaN \n", 1784 | "6 False NaN \n", 1785 | "7 True NaN \n", 1786 | "8 True NaN \n", 1787 | "9 True NaN \n", 1788 | "10 True NaN \n", 1789 | "11 False NaN \n", 1790 | "12 False NaN \n", 1791 | "13 True NaN \n", 1792 | "14 True NaN \n", 1793 | "15 True NaN \n", 1794 | "16 True NaN \n", 1795 | "17 False NaN \n", 1796 | "18 False NaN \n", 1797 | "19 True NaN \n", 1798 | "20 True NaN \n", 1799 | "21 False True \n", 1800 | "22 True NaN \n", 1801 | "23 True NaN \n", 1802 | "24 False NaN \n", 1803 | "25 False NaN \n", 1804 | "26 True NaN \n", 1805 | "27 True NaN \n", 1806 | "28 True NaN \n", 1807 | "29 True NaN \n", 1808 | ".. ... ... \n", 1809 | "87 True NaN \n", 1810 | "88 True NaN \n", 1811 | "89 True NaN \n", 1812 | "90 True NaN \n", 1813 | "91 False NaN \n", 1814 | "92 False NaN \n", 1815 | "93 True NaN \n", 1816 | "94 True NaN \n", 1817 | "95 True NaN \n", 1818 | "96 True NaN \n", 1819 | "97 False NaN \n", 1820 | "98 False NaN \n", 1821 | "99 True True \n", 1822 | "100 True True \n", 1823 | "101 False True \n", 1824 | "102 NaN True \n", 1825 | "103 NaN True \n", 1826 | "104 NaN True \n", 1827 | "105 NaN True \n", 1828 | "106 NaN True \n", 1829 | "107 NaN True \n", 1830 | "108 True True \n", 1831 | "109 False True \n", 1832 | "110 False True \n", 1833 | "111 False True \n", 1834 | "112 False True \n", 1835 | "113 False True \n", 1836 | "114 False NaN \n", 1837 | "115 False NaN \n", 1838 | "116 True True \n", 1839 | "\n", 1840 | "[117 rows x 18 columns]" 1841 | ] 1842 | }, 1843 | "execution_count": 63, 1844 | "metadata": {}, 1845 | "output_type": "execute_result" 1846 | } 1847 | ], 1848 | "source": [ 1849 | "path = './results/'\n", 1850 | "RES = pd_frame(path)\n", 1851 | "RES" 1852 | ] 1853 | }, 1854 | { 1855 | "cell_type": "markdown", 1856 | "metadata": {}, 1857 | "source": [ 1858 | "The following cells contain some utility for messing around with results, i.e., deleting etc. 
" 1859 | ] 1860 | } 1861 | ], 1862 | "metadata": { 1863 | "kernelspec": { 1864 | "display_name": "Python 3", 1865 | "language": "python", 1866 | "name": "python3" 1867 | }, 1868 | "language_info": { 1869 | "codemirror_mode": { 1870 | "name": "ipython", 1871 | "version": 3 1872 | }, 1873 | "file_extension": ".py", 1874 | "mimetype": "text/x-python", 1875 | "name": "python", 1876 | "nbconvert_exporter": "python", 1877 | "pygments_lexer": "ipython3", 1878 | "version": "3.8.3" 1879 | } 1880 | }, 1881 | "nbformat": 4, 1882 | "nbformat_minor": 2 1883 | } 1884 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | from core.train_engine import experiment_multi_device 5 | 6 | 7 | if __name__ == "__main__": 8 | import torch.multiprocessing as mp 9 | mp.set_start_method('spawn', force=True) 10 | 11 | 12 | parser = argparse.ArgumentParser(description='') 13 | 14 | parser.add_argument('--cfg_file', type=str, 15 | help='') 16 | parser.add_argument('--output_dir', type=str, 17 | help='') 18 | parser.add_argument('--devices', nargs='+') 19 | 20 | parser.add_argument('--max_process_on_device', type=int) 21 | 22 | 23 | args = parser.parse_args() 24 | 25 | devices = [int(d) for d in args.devices] 26 | 27 | with open(args.cfg_file, 'r') as fid: 28 | exp_cfgs = json.load(fid) 29 | 30 | experiment_multi_device(exp_cfgs, args.output_dir, devices, args.max_process_on_device) 31 | 32 | 33 | -------------------------------------------------------------------------------- /write_exp_cfgs_file.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Generate experiment configuration files\n", 8 | "\n", 9 | "The purpos of this notebook is to generate configurations files. Those files are the input to the \n", 10 | "train script (`train.py`) " 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 5, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os.path as osp\n", 20 | "import itertools\n", 21 | "import copy\n", 22 | "import json\n", 23 | "\n", 24 | "from core.data import dataset_factory" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 6, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "data": { 34 | "text/plain": [ 35 | "{'REDDIT-BINARY': False}" 36 | ] 37 | }, 38 | "execution_count": 6, 39 | "metadata": {}, 40 | "output_type": "execute_result" 41 | } 42 | ], 43 | "source": [ 44 | "# Check which datasets have nodelabels attached. This may take a while. \n", 45 | "dataset_names = [\n", 46 | " 'REDDIT-BINARY',\n", 47 | "# 'REDDIT-MULTI-5K',\n", 48 | "# 'COLLAB',\n", 49 | " #'IMDB-MULTI',\n", 50 | " #'IMDB-BINARY',\n", 51 | " #'ENZYMES',\n", 52 | " #'PTC_PGNN',\n", 53 | " #'PTC_FM',\n", 54 | " #'PTC_FR',\n", 55 | " #'PTC_MM',\n", 56 | " #'PTC_MR',\n", 57 | " #'PROTEINS',\n", 58 | " #'DD',\n", 59 | " #'NCI1',\n", 60 | " #'MUTAG'\n", 61 | "]\n", 62 | "\n", 63 | "dataset_has_node_lab = {n: dataset_factory(n, verbose=False).num_node_lab is not None for n in dataset_names}\n", 64 | "dataset_has_node_lab" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "### The optimization related part of the configuration..." 
72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 7, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "training_cfg = {\n", 81 | " 'lr': 0.01, \n", 82 | " 'lr_drop_fact': 0.5, \n", 83 | " 'num_epochs': 100,\n", 84 | " 'epoch_step': 20,\n", 85 | " 'batch_size': 32,\n", 86 | " 'weight_decay': 10e-06,\n", 87 | " 'validation_ratio': 0.1\n", 88 | "}\n", 89 | "training_cfgs = [training_cfg]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "### The model related part of the configuration..." 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 8, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "text/plain": [ 107 | "2" 108 | ] 109 | }, 110 | "execution_count": 8, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "# Pershom rigid filtration ...\n", 117 | "proto = {\n", 118 | " 'model_type': 'PershomRigidDegreeFilt',\n", 119 | " 'use_super_level_set_filtration': None, \n", 120 | " 'num_struct_elements': 100, \n", 121 | " 'cls_hidden_dimension': 64, \n", 122 | " 'drop_out': 0.0\n", 123 | "}\n", 124 | "model_cfgs_PershomRigidDegreeFilt = []\n", 125 | "for b in [False, True]:\n", 126 | " tmp = copy.deepcopy(proto)\n", 127 | " \n", 128 | " tmp['use_super_level_set_filtration'] = b\n", 129 | " \n", 130 | " model_cfgs_PershomRigidDegreeFilt.append(tmp)\n", 131 | " \n", 132 | "len(model_cfgs_PershomRigidDegreeFilt)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 9, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "data": { 142 | "text/plain": [ 143 | "3" 144 | ] 145 | }, 146 | "execution_count": 9, 147 | "metadata": {}, 148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "# Pershom learnt filtration ...\n", 153 | "proto = {\n", 154 | " 'model_type': 'PershomLearnedFilt',\n", 155 | " 'use_super_level_set_filtration': None, \n", 156 | " 'use_node_degree': None, \n", 157 | " 'set_node_degree_uninformative': True, \n", 158 | " 'use_node_label': None, \n", 159 | " 'gin_number': 1, \n", 160 | " 'gin_dimension': 64,\n", 161 | " 'gin_mlp_type': 'lin_bn_lrelu_lin', \n", 162 | " 'num_struct_elements': 100, \n", 163 | " 'cls_hidden_dimension': 64, \n", 164 | " 'drop_out': 0.0 \n", 165 | "}\n", 166 | "model_cfgs_PershomLearnedFilt = []\n", 167 | "\n", 168 | "B = [(True, True), (False, True), (True, False)]\n", 169 | "\n", 170 | "for (a, b), c, d, e in itertools.product(B, [True], [64], [1]):\n", 171 | " tmp = copy.deepcopy(proto)\n", 172 | "\n", 173 | " tmp['use_node_degree'] = a\n", 174 | " tmp['use_node_label'] = b\n", 175 | " tmp['use_super_level_set_filtration'] = c \n", 176 | "\n", 177 | " tmp['gin_dimension'] = d\n", 178 | " tmp['gin_number'] = e\n", 179 | "\n", 180 | " model_cfgs_PershomLearnedFilt.append(tmp)\n", 181 | " \n", 182 | "len(model_cfgs_PershomLearnedFilt)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 10, 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "data": { 192 | "text/plain": [ 193 | "3" 194 | ] 195 | }, 196 | "execution_count": 10, 197 | "metadata": {}, 198 | "output_type": "execute_result" 199 | } 200 | ], 201 | "source": [ 202 | "# GIN ... 
\n", 203 | "proto = {\n", 204 | " 'model_type': 'GIN',\n", 205 | " 'use_node_degree': None, \n", 206 | " 'use_node_label': None, \n", 207 | " 'gin_number': None, \n", 208 | " 'gin_dimension': 64,\n", 209 | " 'gin_mlp_type': 'lin_bn_lrelu_lin', \n", 210 | " 'cls_hidden_dimension': 64, \n", 211 | " 'set_node_degree_uninformative': None,\n", 212 | " 'pooling_strategy': 'sort',\n", 213 | " 'drop_out': 0.5 \n", 214 | "}\n", 215 | "model_cfgs_GIN = []\n", 216 | "\n", 217 | "B = [(True, True), (False, True), (True, False)]\n", 218 | "\n", 219 | "for (a, b), c, d in itertools.product(B, [1], [True]):\n", 220 | " tmp = copy.deepcopy(proto)\n", 221 | "\n", 222 | " tmp['use_node_degree'] = a\n", 223 | " tmp['use_node_label'] = b\n", 224 | " tmp['gin_number'] = c\n", 225 | " tmp['set_node_degree_uninformative'] = d\n", 226 | "\n", 227 | " model_cfgs_GIN.append(tmp)\n", 228 | " \n", 229 | "len(model_cfgs_GIN)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 11, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "data": { 239 | "text/plain": [ 240 | "6" 241 | ] 242 | }, 243 | "execution_count": 11, 244 | "metadata": {}, 245 | "output_type": "execute_result" 246 | } 247 | ], 248 | "source": [ 249 | "# SimpleNNBaseline ... \n", 250 | "proto = {\n", 251 | " 'model_type': 'SimpleNNBaseline',\n", 252 | " 'use_node_degree': None, \n", 253 | " 'use_node_label': None, \n", 254 | " 'gin_dimension': 64,\n", 255 | " 'gin_mlp_type': 'lin_bn_lrelu_lin', \n", 256 | " 'cls_hidden_dimension': 64, \n", 257 | " 'set_node_degree_uninformative': None,\n", 258 | " 'pooling_strategy': 'sum',\n", 259 | " 'drop_out': None \n", 260 | "}\n", 261 | "model_cfgs_SimpleNNBaseline = []\n", 262 | "\n", 263 | "B = [(True, True), (False, True), (True, False)]\n", 264 | "\n", 265 | "for (a, b), c, d in itertools.product(B, [False], [0.0, 0.5]):\n", 266 | " tmp = copy.deepcopy(proto)\n", 267 | "\n", 268 | " tmp['use_node_degree'] = a\n", 269 | " tmp['use_node_label'] = b\n", 270 | " tmp['set_node_degree_uninformative'] = c\n", 271 | " tmp['drop_out'] = d\n", 272 | "\n", 273 | " model_cfgs_SimpleNNBaseline.append(tmp)\n", 274 | " \n", 275 | "len(model_cfgs_SimpleNNBaseline)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "### Now we combine those parts and write the cfg file ..." 
283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 12, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "def combine(dataset_names, training_cfgs, model_cfgs, tag=\"\"):\n", 292 | " exp_cfgs = []\n", 293 | "\n", 294 | " for a, b, c in itertools.product(dataset_names, training_cfgs, model_cfgs):\n", 295 | "\n", 296 | " # filter out configurations which use node labels on datasets that have none\n", 297 | " ds_has_node_lab = dataset_has_node_lab[a]\n", 298 | "\n", 299 | " if 'use_node_label' in c:\n", 300 | " use_node_lab = c['use_node_label']\n", 301 | "\n", 302 | " if (not ds_has_node_lab) and use_node_lab:\n", 303 | "# print(a, c['model_type'])\n", 304 | " continue\n", 305 | "\n", 306 | " tmp = {\n", 307 | " 'dataset_name': a, \n", 308 | " 'training': b, \n", 309 | " 'model': c, \n", 310 | " 'tag': tag\n", 311 | " }\n", 312 | " exp_cfgs.append(tmp)\n", 313 | " \n", 314 | " return exp_cfgs\n", 315 | "\n", 316 | "def write_file(dataset_names, training_cfgs, model_cfgs, output_dir, tag=\"\", file_name=None):\n", 317 | " exp_cfgs = combine(dataset_names, training_cfgs, model_cfgs, tag=tag)\n", 318 | " if file_name is None:\n", 319 | " file_name = \"exp_cfgs__\" + \"_\".join(dataset_names) + \".json\"\n", 320 | " os.makedirs(output_dir, exist_ok=True) # write the cfg file into output_dir\n", 321 | " with open(osp.join(output_dir, file_name), 'w') as fid:\n", 322 | " json.dump(exp_cfgs, fid)\n", 323 | " \n", 324 | " print('Num cfgs: ', len(exp_cfgs))" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 13, 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "output_dir = 'results'" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 14, 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "name": "stdout", 343 | "output_type": "stream", 344 | "text": [ 345 | "Num cfgs: 1\n" 346 | ] 347 | } 348 | ], 349 | "source": [ 350 | "# Write cfg file for, e.g., learned filtration setup...\n", 351 | "write_file(dataset_names, \n", 352 | " training_cfgs, \n", 353 | " model_cfgs_PershomLearnedFilt, \n", 354 | " output_dir, \n", 355 | " file_name='my_config.json', \n", 356 | " tag=\"\")" 357 | ] 358 | } 359 | ], 360 | "metadata": { 361 | "kernelspec": { 362 | "display_name": "Python 3", 363 | "language": "python", 364 | "name": "python3" 365 | }, 366 | "language_info": { 367 | "codemirror_mode": { 368 | "name": "ipython", 369 | "version": 3 370 | }, 371 | "file_extension": ".py", 372 | "mimetype": "text/x-python", 373 | "name": "python", 374 | "nbconvert_exporter": "python", 375 | "pygments_lexer": "ipython3", 376 | "version": "3.8.3" 377 | } 378 | }, 379 | "nbformat": 4, 380 | "nbformat_minor": 2 381 | } 382 | --------------------------------------------------------------------------------
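The cfg files produced by `write_exp_cfgs_file.ipynb` are plain JSON, so they can be inspected before launching `train.py`. A minimal sketch, assuming the list-of-dicts layout that `combine` above produces (keys `dataset_name`, `training`, `model`, `tag`) and the `results/my_config.json` path from the example cell:

```
# Inspect a generated cfg file; assumes the layout written by combine():
# a list of dicts with 'dataset_name', 'training', 'model' and 'tag' keys.
import json

with open('results/my_config.json', 'r') as fid:
    exp_cfgs = json.load(fid)

for i, cfg in enumerate(exp_cfgs):
    print(i, cfg['dataset_name'], cfg['model']['model_type'])
```

Similarly, the per-experiment pickles written by `train.py` can be summarized via the DataFrame shown in `results.ipynb`. A minimal sketch, assuming `pd_frame` is importable from `core.utils_ipynb` (its import is not visible in the dumped notebook) and that `./results/` contains finished runs:

```
# Best mean validation accuracy per dataset and model type.
from core.utils_ipynb import pd_frame   # assumed location of pd_frame

RES = pd_frame('./results/')

best = (RES[RES['finished_training']]   # keep only runs that finished training
        .groupby(['dataset_name', 'model_type'])['acc_val_mean']
        .max())
print(best)
```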