├── .gitignore ├── LICENSE ├── README.md ├── config └── train.cfg ├── environment.yml ├── plot.sh ├── src ├── Logger │ ├── LogMetric.py │ └── __init__.py ├── data │ ├── HistoGraph.py │ ├── Iam.py │ ├── __init__.py │ ├── data_utils.py │ └── load_data.py ├── loss │ └── contrastive.py ├── models │ ├── __init__.py │ ├── distance.py │ ├── layers.py │ ├── models.py │ └── realdistance.py ├── options.py ├── plot.py ├── test.py ├── testHED.py ├── test_iam.py ├── train.py ├── train_iam.py └── utils.py ├── test.sh ├── testHED.sh ├── train.sh └── train_iam.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Pau Riba 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Metric Learning
2 |
3 | Graph Metric Learning in PyTorch.
4 |
5 | ## Install
6 |
7 | - Create the conda environment with all the requirements.
8 |
9 | ```
10 | $ conda env create -f environment.yml
11 | ```
12 |
13 | ## Usage
14 |
15 | ```
16 | $ conda activate graphmetric
17 | ```
18 |
19 | ### Train
20 |
21 | * Write a configuration file. Follow the example [here](./config/).
22 | * Run the training script with the corresponding configuration file: `./train.sh config/train.cfg`
23 |
24 | ### Test
25 |
26 | * Write a configuration file. Follow the example [here](./config/), providing a load path (`--load`) and the test flag (`-t`); an example test configuration is sketched after `plot.sh` below.
27 | * Run the test script with the corresponding configuration file: `./test.sh config/test.cfg`
28 |
29 | ## Author
30 |
31 | * [Pau Riba](http://www.cvc.uab.es/people/priba/) ([@priba](https://github.com/priba))
32 |
--------------------------------------------------------------------------------
/config/train.cfg:
--------------------------------------------------------------------------------
1 | # Example of a configuration file
2 |
3 | dataset=dataset_id # e.g. dataset='histograph-gw'
4 | data_path=/path/to/the/data/
5 | set_partition="--set_partition cv1"
6 | bz="--batch_size 64"
7 | out_size="--out_size 32"
8 | hidden="--hidden 32"
9 | dropout="--dropout 0.3"
10 | loss="--loss triplet"
11 | swap="--swap"
12 | margin="--margin 10"
13 | epochs="-e 1000"
14 | lr="-lr 1e-3"
15 | momentum="-m 0.9" # note: src/options.py defines no momentum flag (Adam is used)
16 | decay="-d 0.0005"
17 | schedule="--schedule 50"
18 | gamma="--gamma 0.1"
19 | save= # save="-s /path/to/save"
20 | load= # load="-l /path/to/load"
21 | test= # test="-t"
22 | early_stop="-es 50"
23 | ngpu="--ngpu 1"
24 | prefetch="--prefetch 4"
25 | log= # log="--log /path/to/log"
26 | log_interval="--log-interval 256"
27 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: graphmetric
2 | channels:
3 |   - conda-forge
4 |   - pytorch
5 |   - dglteam
6 |   - defaults
7 | dependencies:
8 |   - python=3.7
9 |   - numpy
10 |   - pytorch=1.2
11 |   - torchvision
12 |   - tensorflow
13 |   - tensorboard
14 |   - tensorboardX
15 |   - dgl-cuda10.0
16 |   - scipy
17 |   - scikit-learn
18 |   - joblib
19 |   - tqdm
20 |   - matplotlib
--------------------------------------------------------------------------------
/plot.sh:
--------------------------------------------------------------------------------
1 | . $1
2 |
3 | python src/plot.py $dataset $data_path $bz $out_size $hidden $dropout $loss $swap $margin $epochs $lr $decay $schedule $gamma $save $load $test $early_stop $prefetch $ngpu $log $log_interval $set_partition
4 |
--------------------------------------------------------------------------------
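The README's Test section references a `config/test.cfg` that is not part of this dump. A minimal sketch in the same key=value style as `config/train.cfg`; the file itself is hypothetical and every path is a placeholder:

```
# Hypothetical config/test.cfg: evaluate a trained checkpoint
dataset='histograph-gw'
data_path=/path/to/the/data/
set_partition="--set_partition cv1"
bz="--batch_size 64"
out_size="--out_size 32"
hidden="--hidden 32"
load="-l /path/to/saved/checkpoint/"
test="-t"
ngpu="--ngpu 1"
prefetch="--prefetch 4"
```

--------------------------------------------------------------------------------
/src/Logger/LogMetric.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | """
5 | Log Metric.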
6 | """ 7 | 8 | import torch 9 | from tensorboardX import SummaryWriter 10 | import os 11 | 12 | __author__ = "Pau Riba" 13 | __email__ = "priba@cvc.uab.cat" 14 | 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | def __init__(self): 19 | self.reset() 20 | 21 | def reset(self): 22 | self.val = 0 23 | self.avg = 0 24 | self.sum = 0 25 | self.count = 0 26 | 27 | def update(self, val, n=1): 28 | self.val = val 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / self.count 32 | 33 | 34 | class Logger(object): 35 | def __init__(self, log_dir, force=False): 36 | # clean previous logged data under the same directory name 37 | self._remove(log_dir, force) 38 | 39 | # create the summary writer object 40 | self._writer = SummaryWriter(log_dir) 41 | 42 | self.global_step = 0 43 | 44 | def __del__(self): 45 | self._writer.close() 46 | 47 | def add_scalar(self, name, scalar_value): 48 | assert isinstance(scalar_value, float), type(scalar_value) 49 | self._writer.add_scalar(name, scalar_value, self.global_step) 50 | 51 | def add_image(self, name, img_tensor): 52 | assert isinstance(img_tensor, torch.Tensor), type(img_tensor) 53 | self._writer.add_image(name, img_tensor, self.global_step) 54 | 55 | def step(self): 56 | self.global_step += 1 57 | 58 | @staticmethod 59 | def _remove(path, force): 60 | """ param could either be relative or absolute. """ 61 | if not os.path.exists(path): 62 | return 63 | elif os.path.isfile(path) and force: 64 | os.remove(path) # remove the file 65 | elif os.path.isdir(path) and force: 66 | import shutil 67 | shutil.rmtree(path) # remove dir and all contains 68 | else: 69 | print('Logdir contains data. Please, set `force` flag to overwrite it.') 70 | import sys 71 | sys.exit(0) 72 | 73 | -------------------------------------------------------------------------------- /src/Logger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/priba/graph_metric.pytorch/68930d7bbc6b2b3ff12e39d7e9260f7bbe6a2e80/src/Logger/__init__.py -------------------------------------------------------------------------------- /src/data/HistoGraph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import xml.etree.ElementTree as ET 4 | import numpy as np 5 | from . 
import data_utils as du 6 | import os 7 | import itertools 8 | import pickle 9 | import dgl 10 | 11 | __author__ = "Pau Riba" 12 | __email__ = "priba@cvc.uab.cat" 13 | 14 | 15 | class HistoGraph_train(data.Dataset): 16 | def __init__(self, root_path, file_list, triplet=False): 17 | self.root = root_path 18 | self.file_list = file_list 19 | self.triplet = triplet 20 | 21 | self.graphs, self.labels = getFileList(self.file_list) 22 | 23 | # To pickle 24 | self.graphs = [g+'.p' for g in self.graphs] 25 | self.labels = np.array(self.labels) 26 | self.unique_labels = np.unique(self.labels) 27 | if self.triplet: 28 | # Triplet (anchor, positive, negative) 29 | self.groups = [ (i, j) for i in range(len(self.labels)) for j in np.where(self.labels[i] == self.labels)[0] if i != j ] 30 | self.labels_len = np.array(list(map(len, self.labels))) 31 | self.labels_counts = np.array([(l==self.labels).sum() for l in self.labels]) 32 | else: 33 | # Siamese all pairs 34 | self.groups = list(itertools.permutations(range(len(self.labels)), 2)) 35 | 36 | def __getitem__(self, index): 37 | ind = self.groups[index] 38 | 39 | # Graph 1 40 | g1 = self._loadgraph(ind[0]) 41 | target1 = self.labels[ind[0]] 42 | 43 | # Graph 2 44 | g2 = self._loadgraph(ind[1]) 45 | target2 = self.labels[ind[1]] 46 | 47 | if self.triplet: 48 | # Random negative choice where it would be of similar size 49 | possible_ind = np.where(self.labels!=target1)[0] 50 | labels_counts = self.labels_counts[possible_ind] 51 | labels_len = np.abs(self.labels_len[possible_ind] - self.labels_len[ind[0]]) + 1.0 52 | labels_probs = 1/(labels_counts*labels_len) 53 | labels_probs = labels_probs/labels_probs.sum() 54 | neg_ind = np.random.choice(possible_ind, 1, p=labels_probs) 55 | 56 | # Graph 3 57 | g3 = self._loadgraph(neg_ind[0]) 58 | target_neg = self.labels[neg_ind[0]] 59 | 60 | return g1, g2, g3, torch.Tensor([]) 61 | 62 | target = torch.FloatTensor([0.0]) if target1 == target2 else torch.FloatTensor([1.0]) 63 | return g1, g2, torch.Tensor([]), target 64 | 65 | def __len__(self): 66 | return len(self.groups) 67 | 68 | def _loadgraph(self, i): 69 | graph_dict = pickle.load( open(os.path.join(self.root, self.graphs[i]), "rb") ) 70 | 71 | g = dgl.DGLGraph() 72 | 73 | g.gdata = {} 74 | g.gdata['std'] = torch.tensor(graph_dict['graph_properties']).float() 75 | 76 | g.add_nodes(graph_dict['node_labels'].shape[0]) 77 | g.ndata['pos'] = torch.tensor(graph_dict['node_labels']).float() 78 | if g.number_of_nodes() == 0: 79 | g.add_nodes(1, {'pos': torch.zeros(1,2)}) 80 | g.gdata['std'] = torch.zeros(2) 81 | 82 | 83 | g.add_edges(graph_dict['am'][0], graph_dict['am'][1]) 84 | 85 | # Add self connections 86 | g.add_edges(g.nodes(), g.nodes()) 87 | 88 | return g 89 | 90 | 91 | class HistoGraph(data.Dataset): 92 | def __init__(self, root_path, file_list, keywords_file=None, subset='valid'): 93 | self.root = root_path 94 | self.file_list = file_list 95 | 96 | self.graphs, self.labels = getFileList(self.file_list) 97 | 98 | # To pickle 99 | self.graphs = [g+'.p' for g in self.graphs] 100 | 101 | self.subset = subset 102 | 103 | if keywords_file is not None: 104 | with open(keywords_file, 'r') as f: 105 | queries = f.read().splitlines() 106 | queries = [ q.split(' ')[-1] for q in queries ] 107 | idx_del = [i for i, label in enumerate(self.labels) if label not in queries] 108 | 109 | for index in sorted(idx_del, reverse=True): 110 | del self.labels[index] 111 | del self.graphs[index] 112 | 113 | if ('PAR' in self.root) and ('valid' in subset): 114 | u_labels = 
np.unique(self.labels) 115 | for lab in u_labels: 116 | idx_del = np.where(lab == np.array(self.labels))[0] 117 | for index in reversed(range(1, len(idx_del))): 118 | del self.labels[idx_del[index]] 119 | del self.graphs[idx_del[index]] 120 | 121 | 122 | def __getitem__(self, index): 123 | # Graph 124 | g = self._loadgraph(index) 125 | target = self.labels[index] 126 | 127 | return g, target 128 | 129 | def __len__(self): 130 | return len(self.labels) 131 | 132 | def _loadgraph(self, i): 133 | graph_dict = pickle.load( open(os.path.join(self.root, self.graphs[i]), "rb") ) 134 | 135 | g = dgl.DGLGraph() 136 | 137 | g.gdata = {} 138 | g.gdata['std'] = torch.tensor(graph_dict['graph_properties']).float() 139 | 140 | g.add_nodes(graph_dict['node_labels'].shape[0]) 141 | g.ndata['pos'] = torch.tensor(graph_dict['node_labels']).float() 142 | if g.number_of_nodes() == 0: 143 | g.add_nodes(1, {'pos': torch.zeros(1,2)}) 144 | g.gdata['std'] = torch.zeros(2) 145 | 146 | g.add_edges(graph_dict['am'][0], graph_dict['am'][1]) 147 | 148 | return g 149 | 150 | def getlabels(self): 151 | return np.unique(self.labels) 152 | 153 | def setlabelsdict(self, lab_dict): 154 | self.labels_dict = lab_dict 155 | 156 | 157 | def getFileList(file_path): 158 | with open(file_path, 'r') as f: 159 | lines = f.read().splitlines() 160 | 161 | classes = [] 162 | elements = [] 163 | for line in lines: 164 | f, c = line.split(' ')[:2] 165 | classes += [c] 166 | elements += [f] 167 | return elements, classes 168 | 169 | 170 | def create_graph_histograph(file, representation='adj'): 171 | 172 | tree_gxl = ET.parse(file) 173 | root_gxl = tree_gxl.getroot() 174 | graph_properties = [] 175 | node_label = [] 176 | node_id = [] 177 | 178 | for x_std in root_gxl.iter('attr'): 179 | if x_std.get('name') == 'x_std' or x_std.get('name') == 'y_std': 180 | graph_properties.append(float(x_std.find('float').text)) 181 | 182 | for node in root_gxl.iter('node'): 183 | node_id += [node.get('id')] 184 | for attr in node.iter('attr'): 185 | if (attr.get('name') == 'x'): 186 | x = float(attr.find('float').text) 187 | elif (attr.get('name') == 'y'): 188 | y = float(attr.find('float').text) 189 | node_label += [[x, y]] 190 | 191 | node_label = np.array(node_label) 192 | node_id = np.array(node_id) 193 | 194 | row, col = np.array([]), np.array([]) 195 | for edge in root_gxl.iter('edge'): 196 | s = np.where(np.array(node_id)==edge.get('from'))[0][0] 197 | t = np.where(np.array(node_id)==edge.get('to'))[0][0] 198 | 199 | row = np.append(row, s) 200 | col = np.append(col,t) 201 | 202 | row = np.append(row, t) 203 | col = np.append(col,s) 204 | 205 | data = np.ones(row.shape) 206 | 207 | am = row, col, data 208 | 209 | return graph_properties, node_label, am 210 | 211 | -------------------------------------------------------------------------------- /src/data/Iam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.utils.data as data 4 | import xml.etree.ElementTree as ET 5 | import numpy as np 6 | from . 
import data_utils as du 7 | import os 8 | import itertools 9 | import pickle 10 | import dgl 11 | 12 | __author__ = "Pau Riba" 13 | __email__ = "priba@cvc.uab.cat" 14 | 15 | 16 | class Iam_train(data.Dataset): 17 | def __init__(self, root_path, file_list, triplet, num_samples=None): 18 | self.root = root_path 19 | self.file_list = file_list 20 | self.triplet = triplet 21 | self.graphs, self.labels = getFileList(self.file_list) 22 | # To pickle 23 | self.graphs = [os.path.splitext(g)[0]+'.p' for g in self.graphs] 24 | 25 | self.unique_labels = np.unique(self.labels) 26 | self.labels = [np.where(target == self.unique_labels)[0][0] for target in self.labels] 27 | self.unique_labels = np.unique(self.labels) 28 | if self.triplet: 29 | # Triplet (anchor, positive, negative) 30 | self.groups = [ (i, j) for i in range(len(self.labels)) for j in np.where(self.labels[i] == self.labels)[0] if i != j ] 31 | else: 32 | # Siamese all pairs 33 | self.groups = list(itertools.permutations(range(len(self.labels)), 2)) 34 | 35 | if num_samples is not None: 36 | np.random.shuffle(self.groups) 37 | num_labels = len(self.unique_labels) 38 | # Balance positive samples 39 | pos_samples = num_samples//2 40 | pos_samples_class = pos_samples//num_labels 41 | pos_count = np.zeros(self.unique_labels.shape) 42 | 43 | neg_samples = num_samples//2 44 | neg_samples_class = 2*neg_samples//num_labels 45 | neg_count = np.zeros(self.unique_labels.shape) 46 | 47 | group = [] 48 | 49 | if self.triplet: 50 | pos_samples_class *= 2 51 | for gr in self.groups: 52 | if pos_count[self.labels[gr[0]] == self.unique_labels] < pos_samples_class: 53 | pos_count[self.labels[gr[0]] == self.unique_labels] += 1 54 | possible_ind = np.where(self.labels!=self.labels[gr[0]])[0] 55 | neg_ind = np.random.choice(possible_ind, 1)[0] 56 | while neg_count[self.labels[neg_ind]] >= neg_samples_class: 57 | possible_ind = np.where(self.labels!=self.labels[gr[0]])[0] 58 | neg_ind = np.random.choice(possible_ind, 1)[0] 59 | neg_count[self.labels[neg_ind]] += 1 60 | group.append((gr[0], gr[1], neg_ind)) 61 | self.group = group 62 | else: 63 | for gr in self.groups: 64 | pair_label = self.labels[gr[0]] == self.labels[gr[1]] 65 | if pair_label: 66 | if pos_count[self.labels[gr[0]] == self.unique_labels] < pos_samples_class: 67 | pos_count[self.labels[gr[0]] == self.unique_labels] += 1 68 | group.append(gr) 69 | else: 70 | if (neg_count[self.labels[gr[0]] == self.unique_labels] < neg_samples_class) and (neg_count[self.labels[gr[1]] == self.unique_labels] < neg_samples_class): 71 | neg_count[self.labels[gr[0]] == self.unique_labels] += 1 72 | neg_count[self.labels[gr[1]] == self.unique_labels] += 1 73 | group.append(gr) 74 | if len(group) 1: 86 | for i, g in enumerate(g1_list): 87 | g.gdata = {} 88 | g.gdata['std'] = g1.gdata['std'][i] 89 | 90 | g2_list = dgl.unbatch(g2) 91 | for i, g in enumerate(g2_list): 92 | g.gdata = {} 93 | g.gdata['std'] = g2.gdata['std'][i] 94 | 95 | d = [] 96 | for i in range(len(g2_list)): 97 | if mode == 'pairs': 98 | d_aux = self.soft_hausdorff(g1_list[i], g2_list[i]) 99 | elif mode == 'retrieval': 100 | query = g1_list[0] 101 | d_aux = self.soft_hausdorff(query, g2_list[i]) 102 | else: 103 | raise NameError(mode + ' not implemented!') 104 | d.append(d_aux) 105 | d = torch.stack(d) 106 | return d 107 | 108 | -------------------------------------------------------------------------------- /src/models/layers.py: -------------------------------------------------------------------------------- 1 | """Torch Module for Gated Graph 
Convolution layer""" 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | from dgl import function as fn 7 | 8 | class EdgeConv(nn.Module): 9 | def __init__(self, in_feat, out_feat, residual=False, activation=torch.sigmoid): 10 | super(EdgeConv, self).__init__() 11 | 12 | self.residual = residual 13 | 14 | h_feat = 64 15 | self.mlp = nn.Sequential( nn.Linear(in_feat,h_feat), 16 | nn.ReLU(True), 17 | nn.Linear(h_feat, out_feat)) 18 | 19 | self.activation = activation 20 | 21 | def message(self, edges): 22 | e_out = self.mlp((edges.dst['x']-edges.src['x']).abs()) 23 | 24 | if self.residual: 25 | e_out = edges.data['x'] + e_out 26 | 27 | if self.activation is not None: 28 | e_out = self.activation(e_out) 29 | 30 | return {'e': e_out} 31 | 32 | def forward(self, g, h, he=None): 33 | with g.local_scope(): 34 | 35 | g.ndata['x'] = h 36 | if he is not None: 37 | g.edata['x'] = he 38 | 39 | g.apply_edges(self.message) 40 | 41 | return g.edata['e'] 42 | 43 | 44 | class GatedGraphConv(nn.Module): 45 | r"""Gated Graph Convolution layer from paper `Gated Graph Sequence 46 | Neural Networks `__. 47 | 48 | .. math:: 49 | h_{i}^{0} & = [ x_i \| \mathbf{0} ] 50 | 51 | a_{i}^{t} & = \sum_{j\in\mathcal{N}(i)} W_{e_{ij}} h_{j}^{t} 52 | 53 | h_{i}^{t+1} & = \mathrm{GRU}(a_{i}^{t}, h_{i}^{t}) 54 | 55 | Parameters 56 | ---------- 57 | in_feats : int 58 | Input feature size. 59 | out_feats : int 60 | Output feature size. 61 | n_steps : int 62 | Number of recurrent steps. 63 | edge_func : callable activation function/layer 64 | Maps each edge feature to a vector of shape 65 | ``(in_feats * out_feats)`` as weight to compute 66 | messages. 67 | bias : bool 68 | If True, adds a learnable bias to the output. Default: ``True``. 69 | """ 70 | def __init__(self, 71 | in_feats, 72 | out_feats, 73 | n_steps, 74 | edge_func, 75 | # edge_embedding, 76 | bias=True, 77 | dropout = 0.3, 78 | aggregator_type='sum'): 79 | 80 | super(GatedGraphConv, self).__init__() 81 | self._in_feats = in_feats 82 | self._out_feats = out_feats 83 | self._n_steps = n_steps 84 | # self.edge_embedding = edge_embedding 85 | 86 | self.edge_nn = edge_func 87 | if aggregator_type == 'sum': 88 | self.reducer = fn.sum 89 | elif aggregator_type == 'mean': 90 | self.reducer = fn.mean 91 | elif aggregator_type == 'max': 92 | self.reducer = fn.max 93 | else: 94 | raise KeyError('Aggregator type {} not recognized: '.format(aggregator_type)) 95 | 96 | self.aggre_type = aggregator_type 97 | self.gru = nn.GRUCell(in_feats, out_feats, bias=bias) 98 | self.dropout = nn.Dropout(dropout) 99 | self.reset_parameters() 100 | 101 | def reset_parameters(self): 102 | """Reinitialize learnable parameters.""" 103 | gain = init.calculate_gain('relu') 104 | self.gru.reset_parameters() 105 | 106 | def forward(self, graph, feat, efeat): 107 | # def forward(self, graph, feat): 108 | """Compute Gated Graph Convolution layer. 109 | 110 | Parameters 111 | ---------- 112 | graph : DGLGraph 113 | The graph. 114 | feat : torch.Tensor 115 | The input feature of shape :math:`(N, D_{in})` where :math:`N` 116 | is the number of nodes of the graph and :math:`D_{in}` is the 117 | input feature size. 118 | etypes : torch.LongTensor 119 | The edge type tensor of shape :math:`(E,)` where :math:`E` is 120 | the number of edges of the graph. 121 | 122 | Returns 123 | ------- 124 | torch.Tensor 125 | The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}` 126 | is the output feature size. 
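
        Notes
        -----
        In this variant the third positional argument is ``efeat`` (precomputed
        edge features) rather than the ``etypes`` tensor described above;
        ``edge_nn`` maps it to one :math:`D_{in} \times D_{out}` message-weight
        matrix per edge.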
127 | """ 128 | graph = graph.local_var() 129 | 130 | for step in range(self._n_steps): 131 | # (n, d_in, 1) 132 | graph.ndata['h'] = feat.unsqueeze(-1) 133 | # (n, d_in, d_out) 134 | # efeat = self.edge_embedding(graph, feat) 135 | graph.edata['w'] = self.edge_nn(efeat).view(-1, self._in_feats, self._out_feats) 136 | graph.update_all(fn.u_mul_e('h', 'w', 'm'), self.reducer('m', 'neigh')) 137 | 138 | rst = graph.ndata.pop('neigh').sum(dim=1) # (N, D) 139 | 140 | feat = self.gru(rst, feat) 141 | if step < self._n_steps-1: 142 | feat = self.dropout(feat) 143 | return feat 144 | -------------------------------------------------------------------------------- /src/models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from dgl.nn.pytorch.conv import GATConv 7 | from .layers import GatedGraphConv, EdgeConv 8 | 9 | 10 | class GNN_GAT(nn.Module): 11 | def __init__(self, in_dim, hidden_dim, out_dim, heads=4, dropout=0.3): 12 | super(GNN_GAT, self).__init__() 13 | 14 | self.layers = nn.ModuleList([ 15 | GATConv(in_dim, hidden_dim, heads, residual=True, activation=F.leaky_relu), 16 | GATConv(heads*hidden_dim, hidden_dim, heads, feat_drop=dropout, residual=True, activation=F.leaky_relu),]) 17 | 18 | self.bn = nn.ModuleList([ 19 | nn.BatchNorm1d(heads*hidden_dim), 20 | nn.BatchNorm1d(heads*hidden_dim),]) 21 | 22 | self.last_layer = GATConv(heads*hidden_dim, out_dim, heads, residual=True) 23 | 24 | def forward(self, g): 25 | h = g.ndata['pos'] 26 | 27 | for i, conv in enumerate(self.layers): 28 | h = conv(g, h) 29 | h = h.view(h.shape[0], -1) 30 | h = self.bn[i](h) 31 | 32 | h = self.last_layer(g, h) 33 | h = h.mean(1) 34 | 35 | g.ndata['h'] = h 36 | return g 37 | 38 | 39 | class GNN_GRU(nn.Module): 40 | def __init__(self, in_dim, hidden_dim, out_dim, heads=4, dropout=0.3): 41 | super(GNN_GRU, self).__init__() 42 | 43 | self.embedding = nn.Linear(in_dim, out_dim) 44 | self.edge_embedding = EdgeConv(hidden_dim, hidden_dim, activation=torch.relu) 45 | self.edge_func = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU(True), nn.Linear(64, hidden_dim*hidden_dim)) 46 | # self.layers = GatedGraphConv(hidden_dim, hidden_dim, 3, self.edge_func, dropout=dropout) 47 | self.layers = GatedGraphConv(hidden_dim, out_dim, 3, self.edge_func, 1, dropout=dropout) 48 | # self.last_layer = nn.Linear(hidden_dim, out_dim) 49 | # self.last_edge_layer = EdgeConv(hidden_dim, out_dim, activation=None) 50 | 51 | def forward(self, g): 52 | h = g.ndata['pos'] 53 | 54 | h = self.embedding(h) 55 | he = self.edge_embedding(g, h) 56 | h = self.layers(g, h, he) 57 | # h = self.layers(g, h, torch.zeros(g.edges()[0].shape[0])) 58 | # he = self.last_edge_layer(g, h) 59 | # h = self.last_layer(h) 60 | 61 | g.ndata['h'] = h 62 | # g.edata['h'] = he 63 | 64 | return g 65 | 66 | -------------------------------------------------------------------------------- /src/models/realdistance.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ Graph Distance 5 | """ 6 | 7 | import torch 8 | import dgl 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | __author__ = "Pau Riba" 13 | __email__ = "priba@cvc.uab.cat" 14 | 15 | class HausdorffEditDistance(nn.Module): 16 | def __init__(self, alpha=0.5, beta=0.1, tau_n=4., tau_e=16.): 17 | super(HausdorffEditDistance, self).__init__() 18 | 
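        # Cost hyper-parameters of the Hausdorff Edit Distance: alpha balances
        # node vs. edge costs (tau_n is pre-scaled by alpha, tau_e by 1-alpha)
        # and beta weights the x/y coordinate channels as [beta, 1-beta].
        # Registered as buffers so they move with the module to the GPU.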
self.register_buffer('alpha', torch.tensor([alpha])) 19 | self.register_buffer('beta', torch.tensor([beta, 1-beta]).unsqueeze(0).unsqueeze(0)) 20 | self.register_buffer('tau_n', torch.tensor([alpha*tau_n])) 21 | self.register_buffer('tau_e', torch.tensor([(1-alpha)*tau_e])) 22 | self.p = 2 23 | 24 | def cdist(self, set1, set2): 25 | ''' Pairwise Distance between two matrices 26 | Input: x is a Nxd matrix 27 | y is an optional Mxd matirx 28 | Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:] 29 | Source: https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2 30 | ''' 31 | dist = set1.unsqueeze(1) - set2.unsqueeze(0) 32 | return dist.abs() 33 | 34 | 35 | def soft_hausdorff(self, g1, g2, train=True): 36 | if g1.number_of_nodes() < g2.number_of_nodes(): 37 | tmp = g2 38 | g2 = g1 39 | g1 = tmp 40 | 41 | p1 = g1.ndata['pos'] 42 | p2 = g2.ndata['pos'] 43 | device = p1.device 44 | dtype = p1.dtype 45 | 46 | # Deletion 47 | d1_edges = g1.in_degrees().to(device).to(dtype) 48 | d1 = self.tau_n + d1_edges*self.tau_e/2. 49 | 50 | # Insertion 51 | d2_edges = g2.in_degrees().to(device).to(dtype) 52 | d2 = self.tau_n + d2_edges*self.tau_e/2. 53 | 54 | # Substitution 55 | dist_matrix = self.cdist(p1, p2) 56 | dist_matrix = g1.gdata['std']*dist_matrix 57 | dist_matrix = self.beta*dist_matrix.pow(2.) 58 | dist_matrix = self.alpha*dist_matrix.sum(-1).sqrt() 59 | 60 | # Edges HED 61 | edges_hed = g1.in_degrees().unsqueeze(1)-g2.in_degrees().unsqueeze(0) 62 | edges_hed = self.tau_e*edges_hed.to(device).to(dtype).abs() 63 | 64 | dist_matrix = dist_matrix + edges_hed/2 65 | dist_matrix = dist_matrix/2. 66 | 67 | # \sum_{a\in set1} \inf_{b_\in set2} d(a,b) 68 | a, indA = dist_matrix.min(0) 69 | a = torch.min(a, d2) 70 | 71 | # \sum_{b\in set2} \inf_{a_\in set1} d(a,b) 72 | b, indB = dist_matrix.min(1) 73 | b = torch.min(b, d1) 74 | 75 | #d = a.mean() + b.mean() 76 | d = a.sum() + b.sum() 77 | 78 | upper_bound = (g1.number_of_nodes() - g2.number_of_nodes())*self.tau_n 79 | upper_bound = upper_bound.abs() 80 | if d < upper_bound: 81 | d = upper_bound.squeeze() 82 | 83 | normalization = (self.tau_n*(g1.number_of_nodes() + g2.number_of_nodes()) + self.tau_e*(g1.number_of_edges() + g2.number_of_edges())/2) 84 | d = d/normalization.squeeze() 85 | 86 | if train: 87 | return d 88 | 89 | indA[a==d2] = dist_matrix.shape[0] 90 | indB[b==d1] = dist_matrix.shape[1] 91 | return d, indB, indA 92 | 93 | 94 | def forward(self, g1, g2, mode='pairs'): 95 | ''' mode: 'pairs' expect paired graphs, same for g1 and g2. 
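                    (after dgl.unbatch, 'pairs' matches g1_list[i] with g2_list[i], one distance each)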
96 | 'retrieval' g1 is just one graph and computes the distance against all graphs in g2 97 | ''' 98 | 99 | g1_list = dgl.unbatch(g1) 100 | for i, g in enumerate(g1_list): 101 | g.gdata = {} 102 | g.gdata['std'] = g1.gdata['std'][i] 103 | 104 | g2_list = dgl.unbatch(g2) 105 | for i, g in enumerate(g2_list): 106 | g.gdata = {} 107 | g.gdata['std'] = g2.gdata['std'][i] 108 | 109 | d = [] 110 | for i in range(len(g2_list)): 111 | if mode == 'pairs': 112 | d_aux = self.soft_hausdorff(g1_list[i], g2_list[i]) 113 | elif mode == 'retrieval': 114 | query = g1_list[0] 115 | d_aux = self.soft_hausdorff(query, g2_list[i]) 116 | else: 117 | raise NameError(mode + ' not implemented!') 118 | d.append(d_aux) 119 | d = torch.stack(d) 120 | return d 121 | 122 | -------------------------------------------------------------------------------- /src/options.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Parse input arguments 5 | """ 6 | 7 | import argparse 8 | 9 | __author__ = 'Pau Riba' 10 | __email__ = 'priba@cvc.uab.cat' 11 | 12 | 13 | class Options(): 14 | 15 | def __init__(self): 16 | # MODEL SETTINGS 17 | parser = argparse.ArgumentParser(description='Train a Metric Learning Graph Neural Network', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | # Positional arguments 19 | parser.add_argument('dataset', type=str, choices=['iam', 'histograph-gw', 'histograph-ak'], help='Dataset.') 20 | parser.add_argument('data_path', type=str, help='Dataset root path.') 21 | # Model parameters 22 | parser.add_argument('--model', type=str, default='GAT', help='Model to use GAT or GRU', choices=['GAT', 'GRU']) 23 | parser.add_argument('--set_partition', type=str, help='Dataset set partition (Only histograph-gw otherwise it is ignored).', default='cv1') 24 | parser.add_argument('--batch_size', '-bz', type=int, default=64, help='Batch Size.') 25 | parser.add_argument('--out_size', type=int, default=32, help='Node Embedding size.') 26 | parser.add_argument('--hidden', type=int, default=32, help='Hidden size.') 27 | parser.add_argument('--dropout', type=float, default=0.3, help='Dropout.') 28 | parser.add_argument('--loss', type=str, default='siamese', choices=['siamese', 'triplet', 'triplet_distance'], help='Loss used for training.') 29 | parser.add_argument('--swap', action='store_true', help='Swap in the triplet loss.') 30 | parser.add_argument('--margin', type=float, default=10, help='Margin in the loss function.') 31 | # Optimization options 32 | parser.add_argument('--prefetch', type=int, default=4, help='Number of workers to load data.') 33 | parser.add_argument('--epochs', '-e', type=int, default=1000, help='Number of epochs to train.') 34 | parser.add_argument('--learning_rate', '-lr', type=float, default=1e-3, help='The Learning Rate.') 35 | parser.add_argument('--decay', '-d', type=float, default=0.0005, help='Weight decay (L2 penalty).') 36 | parser.add_argument('--schedule', type=int, nargs='+', default=[], 37 | help='Decrease learning rate at these epochs.') 38 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 39 | parser.add_argument('--seed', type=int, default=42, help='Random seed.') 40 | # Checkpoints 41 | parser.add_argument('--save', '-s', type=str, default=None, help='Folder to save checkpoints.') 42 | parser.add_argument('--load', '-l', type=str, default=None, help='Checkpoint path to resume / test.') 43 | parser.add_argument('--test', '-t', 
action='store_true', help='Test only flag.') 44 | parser.add_argument('--early_stop', '-es', type=int, default=50, help='Early stopping epochs.') 45 | # Acceleration 46 | parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU, 1 = CUDA, 1 < DataParallel') 47 | # i/o 48 | parser.add_argument('--log', type=str, default=None, help='Log folder.') 49 | parser.add_argument('--log-interval', type=int, default=256, metavar='N', 50 | help='How many batches to wait before logging training status') 51 | self.parser = parser 52 | 53 | def parse(self): 54 | return self.parser.parse_args() 55 | 56 | -------------------------------------------------------------------------------- /src/plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function, division 4 | 5 | """ 6 | Graph metric learning 7 | """ 8 | 9 | # Python modules 10 | import torch 11 | import networkx as nx 12 | import glob 13 | import numpy as np 14 | import time 15 | import os 16 | import sys 17 | import copy 18 | import matplotlib.pyplot as plt 19 | 20 | # Own modules 21 | from options import Options 22 | from utils import load_checkpoint 23 | from models import models, distance 24 | from data.HistoGraph import HistoGraph 25 | __author__ = "Pau Riba" 26 | __email__ = "priba@cvc.uab.cat" 27 | 28 | 29 | def plot(g1, g2, ind1, ind2, name1, name2, distance): 30 | 31 | # pos1 = dict(g1.nodes(data='pos')) 32 | # pos1 = {k: v.numpy() for k, v in pos1.items()} 33 | 34 | # pos2 = dict(g2.nodes(data='pos')) 35 | # pos2 = {k: v.numpy() for k, v in pos2.items()} 36 | 37 | # plt.subplot(221) 38 | # nx.draw(g1, pos=pos1, node_size=50) 39 | # plt.title(f"{name1}") 40 | # plt.gca().invert_yaxis() 41 | # 42 | # plt.subplot(222) 43 | # nx.draw(g2, pos=pos2, node_size=50) 44 | # plt.title(f"{name2}") 45 | # plt.gca().invert_yaxis() 46 | 47 | plt.subplot(111) 48 | 49 | h = nx.disjoint_union(g1,g2) 50 | 51 | id_g1 = range(len(g1)) 52 | id_g2 = range(len(g1), len(g1)+len(g2)) 53 | 54 | pos = dict(h.nodes(data='pos')) 55 | pos = {k: v.cpu().numpy() for k, v in pos.items()} 56 | for i in id_g2: 57 | pos[i] += np.array([0,5]) 58 | 59 | nx.draw_networkx_nodes(h, pos=pos, nodelist=id_g1, node_size=50,node_color='b') 60 | nx.draw_networkx_nodes(h, pos=pos, nodelist=id_g2, node_size=50,node_color='c') 61 | nx.draw_networkx_edges(h, pos=pos) 62 | 63 | # Insertions 64 | insert_list = [i for i, v in enumerate(ind1) if v==len(g2)] 65 | # Delitions 66 | delitions_list = [i+len(g1) for i, v in enumerate(ind2) if v==len(g1)] 67 | 68 | nx.draw_networkx_nodes(h, pos=pos, nodelist=insert_list, node_size=50,node_color='b',edgecolors='lime', linewidths=2.0) 69 | nx.draw_networkx_nodes(h, pos=pos, nodelist=delitions_list, node_size=50,node_color='c', edgecolors='orange', linewidths=2.0) 70 | 71 | h.remove_edges_from(h.edges()) 72 | h = h.to_directed() 73 | 74 | for i, v in enumerate(ind1): 75 | if v!=len(g2): 76 | h.add_edge(i, v+len(g1)) 77 | for i, v in enumerate(ind2): 78 | if v!=len(g1): 79 | h.add_edge(i+len(g1), v) 80 | 81 | nx.draw_networkx_edges(h, pos=pos, edge_color='r') 82 | plt.title(f"{name1}-{name2} D: {distance}") 83 | plt.gca().invert_yaxis() 84 | plt.show() 85 | 86 | 87 | def main(query, query_name, target, target_name): 88 | 89 | print('Prepare data') 90 | split = os.path.normpath(args.data_path).split(os.sep) 91 | split[-2] = split[-2] + '-pickled' 92 | pickle_dir = os.path.join(*split) 93 | if split[0]=='': 94 | pickle_dir = 
os.sep + pickle_dir 95 | # gt_path = os.path.join(args.data_path, os.pardir, '00_GroundTruth', args.set_partition) 96 | gt_path = os.path.join(args.data_path, os.pardir, '00_GroundTruth') 97 | # data = HistoGraph(pickle_dir, os.path.join(gt_path, 'test.txt')) 98 | data = HistoGraph(os.path.join(pickle_dir, '02_Test'), os.path.join(gt_path, '02_Test', 'words.txt')) 99 | data_query = HistoGraph(os.path.join(pickle_dir, '02_Test'), os.path.join(gt_path, '02_Test', 'queries.txt')) 100 | 101 | # data_query = copy.deepcopy(data) 102 | data_query.graphs = [query] 103 | data_query.labels = [query_name] 104 | g1, l1 = data_query[0] 105 | 106 | data_target = data 107 | data_target.graphs = [target] 108 | data_target.labels = [target_name] 109 | g2, l2 = data_target[0] 110 | del data_target, data_query, data 111 | 112 | print('Create model') 113 | net = models.GNN(2, args.hidden, args.out_size, dropout=args.dropout) 114 | distNet = distance.SoftHd(args.out_size) 115 | 116 | print('Check CUDA') 117 | if args.cuda: 118 | print('\t* CUDA') 119 | net, distNet = net.cuda(), distNet.cuda() 120 | g1.ndata['pos'] = g1.ndata['pos'].cuda() 121 | g2.ndata['pos'] = g2.ndata['pos'].cuda() 122 | 123 | 124 | if args.load is not None: 125 | print('Loading model') 126 | checkpoint = load_checkpoint(args.load) 127 | net.load_state_dict(checkpoint['state_dict']) 128 | distNet.load_state_dict(checkpoint['state_dict_dist']) 129 | 130 | print('***PLOT***') 131 | g1 = net(g1) 132 | g2 = net(g2) 133 | dist, indB, indA = distNet.soft_hausdorff(g1, g2, train=False) 134 | plot(g1.to_networkx(node_attrs=['pos']).to_undirected(), g2.to_networkx(node_attrs=['pos']).to_undirected(), indB.tolist(), indA.tolist(), query_name, target_name, dist.item()) 135 | 136 | if __name__ == '__main__': 137 | # Parse options 138 | args = Options().parse() 139 | print('Parameters:\t' + str(args)) 140 | 141 | # Check cuda & Set random seed 142 | args.cuda = args.ngpu > 0 and torch.cuda.is_available() 143 | 144 | if args.seed > 1: 145 | np.random.seed(args.seed) 146 | torch.manual_seed(args.seed) 147 | if args.cuda: 148 | torch.cuda.manual_seed(args.seed) 149 | 150 | # Check Test and Load 151 | if args.load is None: 152 | raise Exception('Cannot plot without loading a model.') 153 | 154 | main(query='kq0089.p', query_name='HERRN', target='kw000544.p', target_name='HERRM') 155 | main(query='kq0089.p', query_name='HERRN', target='kw001777.p', target_name='HERRM') 156 | 157 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | """ 5 | Graph classification 6 | """ 7 | 8 | # Python modules 9 | import torch 10 | import glob 11 | import numpy as np 12 | import time 13 | import os 14 | 15 | # Own modules 16 | from options import Options 17 | from Logger import LogMetric 18 | from utils import load_checkpoint, knn_accuracy, mean_average_precision 19 | from models import models, distance 20 | from data.load_data import load_data 21 | from loss.contrastive import ContrastiveLoss, TripletLoss 22 | import dgl 23 | 24 | __author__ = "Pau Riba" 25 | __email__ = "priba@cvc.uab.cat" 26 | 27 | def test(data_loader, gallery_loader, nets, cuda, validation=False): 28 | batch_time = LogMetric.AverageMeter() 29 | acc = LogMetric.AverageMeter() 30 | meanap = LogMetric.AverageMeter() 31 | 32 | net, distance = nets 33 | 34 | # switch to test mode 35 | net.eval() 36 | 
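        # the distance module carries trained state of its own
        # (state_dict_dist is loaded below), so keep it in eval mode as well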
distance.eval() 37 | 38 | end = time.time() 39 | 40 | dist_matrix = [] 41 | start = time.time() 42 | with torch.no_grad(): 43 | g_gallery = [] 44 | target_gallery = [] 45 | for j, (g, target) in enumerate(gallery_loader): 46 | if cuda: 47 | g.to(torch.device('cuda')) 48 | g.gdata['std'] = g.gdata['std'].cuda() 49 | 50 | # Output 51 | g = net(g) 52 | 53 | target_gallery.append(target) 54 | g_gallery.append(g) 55 | 56 | target_gallery = np.array(np.concatenate(target_gallery)) 57 | gdata = list(map(lambda g: g.gdata['std'], g_gallery)) 58 | g_gallery = dgl.batch(g_gallery) 59 | g_gallery.gdata = {'std': torch.cat(gdata)} 60 | 61 | target_query = [] 62 | for i, (g, target) in enumerate(data_loader): 63 | # Prepare input data 64 | if cuda: 65 | g.to(torch.device('cuda')) 66 | g.gdata['std'] = g.gdata['std'].cuda() 67 | 68 | # Output 69 | g = net(g) 70 | d = distance(g, g_gallery, mode='retrieval') 71 | 72 | dist_matrix.append(d) 73 | target_query.append(target) 74 | 75 | dist_matrix = torch.stack(dist_matrix) 76 | target_query = np.array(np.concatenate(target_query)) 77 | 78 | if validation: 79 | target_combined_query = target_query 80 | combined_dist_matrix = dist_matrix 81 | else: 82 | print('* Test No combine mAP {}'.format(mean_average_precision(dist_matrix, target_gallery, target_query))) 83 | target_combined_query = np.unique(target_query) 84 | combined_dist_matrix = torch.zeros(target_combined_query.shape[0], dist_matrix.shape[1]) 85 | 86 | for i, kw in enumerate(target_combined_query): 87 | ind = kw == target_query 88 | combined_dist_matrix[i] = dist_matrix[ind].min(0).values 89 | 90 | # K-NN classifier 91 | acc.update(knn_accuracy(combined_dist_matrix, target_gallery, target_combined_query, k=5)) 92 | 93 | # mAP retrieval 94 | meanap.update(mean_average_precision(combined_dist_matrix, target_gallery, target_combined_query)) 95 | batch_time.update(time.time()-start) 96 | print('* Test Acc {acc.avg:.3f}; mAP {meanap.avg: .5f} Time x Test {b_time.avg:.3f}' 97 | .format(acc=acc, meanap=meanap, b_time=batch_time)) 98 | return acc, meanap 99 | 100 | 101 | def main(): 102 | print('Loss & Optimizer') 103 | if args.loss=='triplet': 104 | args.triplet=True 105 | criterion = TripletLoss(margin=args.margin, swap=args.swap) 106 | elif args.loss=='triplet_distance': 107 | args.triplet=True 108 | criterion = TripletLoss(margin=args.margin, swap=args.swap, dist=True) 109 | else: 110 | args.triplet=False 111 | criterion = ContrastiveLoss(margin=args.margin) 112 | 113 | print('Prepare data') 114 | train_loader, valid_loader, valid_gallery_loader, test_loader, test_gallery_loader, in_size = load_data(args.dataset, args.data_path, triplet=args.triplet, batch_size=args.batch_size, prefetch=args.prefetch) 115 | 116 | print('Create model') 117 | net = models.GNN(in_size, args.out_size, nlayers=args.nlayers, hid=args.hidden, J=args.pow) 118 | distNet = distance.SoftHd() 119 | 120 | print('Check CUDA') 121 | if args.cuda and args.ngpu > 1: 122 | print('\t* Data Parallel **NOT TESTED**') 123 | net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu))) 124 | 125 | if args.cuda: 126 | print('\t* CUDA') 127 | net, distNet = net.cuda(), distNet.cuda() 128 | criterion = criterion.cuda() 129 | 130 | start_epoch = 0 131 | best_map = 0 132 | early_stop_counter = 0 133 | if args.load is not None: 134 | print('Loading model') 135 | checkpoint = load_checkpoint(args.load) 136 | net.load_state_dict(checkpoint['state_dict']) 137 | distNet.load_state_dict(checkpoint['state_dict_dist']) 138 | start_epoch = 
checkpoint['epoch'] 139 | best_map = checkpoint['best_map'] 140 | print('Loaded model at epoch {epoch} and mAP {meanap}%'.format(epoch=checkpoint['epoch'],meanap=checkpoint['best_map'])) 141 | 142 | print('***Test***') 143 | test(test_loader, test_gallery_loader, [net, distNet], args.cuda) 144 | 145 | if __name__ == '__main__': 146 | # Parse options 147 | args = Options().parse() 148 | print('Parameters:\t' + str(args)) 149 | 150 | # Check cuda & Set random seed 151 | args.cuda = args.ngpu > 0 and torch.cuda.is_available() 152 | 153 | np.random.seed(args.seed) 154 | torch.manual_seed(args.seed) 155 | if args.cuda: 156 | torch.cuda.manual_seed(args.seed) 157 | 158 | # Check Test and Load 159 | if args.load is None: 160 | raise Exception('Cannot test without loading a model.') 161 | 162 | main() 163 | 164 | -------------------------------------------------------------------------------- /src/testHED.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | """ 5 | Graph classification 6 | """ 7 | 8 | # Python modules 9 | import torch 10 | import glob 11 | import numpy as np 12 | import time 13 | import os 14 | import argparse 15 | 16 | # Own modules 17 | from options import Options 18 | from Logger import LogMetric 19 | from utils import load_checkpoint, knn_accuracy, mean_average_precision 20 | from models import models, realdistance 21 | from data.load_data import load_data 22 | import dgl 23 | 24 | __author__ = "Pau Riba" 25 | __email__ = "priba@cvc.uab.cat" 26 | 27 | def test(data_loader, gallery_loader, distance, cuda): 28 | batch_time = LogMetric.AverageMeter() 29 | acc = LogMetric.AverageMeter() 30 | meanap = LogMetric.AverageMeter() 31 | 32 | end = time.time() 33 | distance.eval() 34 | 35 | dist_matrix = [] 36 | start = time.time() 37 | with torch.no_grad(): 38 | g_gallery = [] 39 | target_gallery = [] 40 | for j, (g, target) in enumerate(gallery_loader): 41 | if cuda: 42 | g.to(torch.device('cuda')) 43 | g.gdata['std'] = g.gdata['std'].cuda() 44 | 45 | target_gallery.append(target) 46 | g_gallery.append(g) 47 | 48 | target_gallery = np.array(np.concatenate(target_gallery)) 49 | gdata = list(map(lambda g: g.gdata['std'], g_gallery)) 50 | g_gallery = dgl.batch(g_gallery) 51 | g_gallery.gdata = {'std': torch.cat(gdata)} 52 | 53 | target_query = [] 54 | for i, (g, target) in enumerate(data_loader): 55 | 56 | # Prepare input data 57 | if cuda: 58 | g.to(torch.device('cuda')) 59 | g.gdata['std'] = g.gdata['std'].cuda() 60 | 61 | # Output 62 | d = distance(g, g_gallery, mode='retrieval') 63 | 64 | dist_matrix.append(d) 65 | target_query.append(target) 66 | 67 | dist_matrix = torch.stack(dist_matrix) 68 | target_query = np.array(np.concatenate(target_query)) 69 | 70 | target_combined_query = np.unique(target_query) 71 | combined_dist_matrix = torch.zeros(target_combined_query.shape[0], dist_matrix.shape[1]) 72 | 73 | for i, kw in enumerate(target_combined_query): 74 | ind = kw == target_query 75 | combined_dist_matrix[i] = dist_matrix[ind].min(0).values 76 | 77 | # K-NN classifier 78 | acc.update(knn_accuracy(combined_dist_matrix, target_gallery, target_combined_query, k=5, dataset=data_loader.dataset.dataset)) 79 | 80 | # mAP retrieval 81 | meanap.update(mean_average_precision(combined_dist_matrix, target_gallery, target_combined_query)) 82 | batch_time.update(time.time()-start) 83 | print('* Test Acc {acc.avg:.3f}; mAP {meanap.avg: .3f}; Time x Test {b_time.avg:.3f}' 84 | 
.format(acc=acc, meanap=meanap, b_time=batch_time))
85 |     return acc, meanap
86 |
87 |
88 | def main():
89 |     print('Prepare data')
90 |     train_loader, valid_loader, valid_gallery_loader, test_loader, test_gallery_loader, in_size = load_data(args.dataset, args.data_path, batch_size=args.batch_size, prefetch=args.prefetch, set_partition=args.set_partition)
91 |
92 |     distance = realdistance.HausdorffEditDistance(alpha=args.alpha, beta=args.beta, tau_n=args.tau_n, tau_e=args.tau_e)
93 |
94 |     print('Check CUDA')
95 |     if args.cuda and args.ngpu > 1:
96 |         distance = torch.nn.DataParallel(distance, device_ids=list(range(args.ngpu)))
97 |
98 |     if args.cuda:
99 |         distance = distance.cuda()
100 |
101 |     print('***Test***')
102 |     test(test_loader, test_gallery_loader, distance, args.cuda)
103 |
104 | if __name__ == '__main__':
105 |     # Parse options
106 |     parser = argparse.ArgumentParser(description='Evaluate the Hausdorff Edit Distance for graph retrieval', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
107 |     # Positional arguments
108 |     parser.add_argument('dataset', type=str, choices=['iam', 'histograph-gw', 'histograph-ak'], help='Dataset.')
109 |     parser.add_argument('data_path', type=str, help='Dataset root path.')
110 |     # Model parameters
111 |     parser.add_argument('--set_partition', type=str, help='Dataset set partition (Only histograph-gw otherwise it is ignored).', default='cv1')
112 |     parser.add_argument('--batch_size', '-bz', type=int, default=64, help='Batch Size.')
113 |     parser.add_argument('--tau_n', type=float, default=4, help='Node insertion/deletion cost.')
114 |     parser.add_argument('--tau_e', type=float, default=16, help='Edge insertion/deletion cost.')
115 |     parser.add_argument('--alpha', type=float, default=0.5, help='Trade-off between node and edge costs.')
116 |     parser.add_argument('--beta', type=float, default=0.1, help='Trade-off between the x and y coordinate distances.')
117 |     # Optimization options
118 |     parser.add_argument('--prefetch', type=int, default=4, help='Number of workers to load data.')
119 |     parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU, 1 = CUDA, 1 < DataParallel')
120 |     # i/o
121 |     args = parser.parse_args()
122 |
123 |     # Check cuda & Set random seed
124 |     args.cuda = args.ngpu > 0 and torch.cuda.is_available()
125 |
126 |     main()
127 |
--------------------------------------------------------------------------------
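`testHED.sh` appears in the repository tree but its body is not included in this dump; judging from the argument parser above, a direct invocation would look roughly like this (dataset name and data path are placeholders):

```
$ python src/testHED.py histograph-gw /path/to/the/data/ \
      --set_partition cv1 --batch_size 64 \
      --alpha 0.5 --beta 0.1 --tau_n 4 --tau_e 16
```

--------------------------------------------------------------------------------
/src/test_iam.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 |
4 | """
5 | Graph classification
6 | """
7 |
8 | # Python modules
9 | import torch
10 | import glob
11 | import numpy as np
12 | import time
13 | import os
14 |
15 | # Own modules
16 | from options import Options
17 | from Logger import LogMetric
18 | from utils import load_checkpoint, knn_accuracy, mean_average_precision
19 | from models import models, distance
20 | from data.load_data import load_data
21 | from loss.contrastive import ContrastiveLoss, TripletLoss
22 | import dgl
23 | from sklearn import metrics
24 |
25 | __author__ = "Pau Riba"
26 | __email__ = "priba@cvc.uab.cat"
27 |
28 | def test(data_triplet_loader, nets, cuda, data_pair_loader=None):
29 |     batch_time = LogMetric.AverageMeter()
30 |     acc = LogMetric.AverageMeter()
31 |     auc = LogMetric.AverageMeter()
32 |
33 |     net, distance = nets
34 |
35 |     # switch to test mode
36 |     net.eval()
37 |     distance.eval()
38 |
39 |     end = time.time()
40 |
41 |     dist_matrix = []
42 |     start = time.time()
43 |     with torch.no_grad():
44 |         total, correct = 0,0
45 |         for j, (g1, g2, g3, target) in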
enumerate(data_triplet_loader): 46 | if cuda: 47 | g1.to(torch.device('cuda')) 48 | g2.to(torch.device('cuda')) 49 | g3.to(torch.device('cuda')) 50 | 51 | # Output 52 | g1 = net(g1) 53 | g2 = net(g2) 54 | g3 = net(g3) 55 | 56 | d_pos = distance(g1, g2, mode='pairs') 57 | d_neg = distance(g1, g3, mode='pairs') 58 | total += d_pos.shape[0] 59 | correct += (d_pos < d_neg).float().sum() 60 | 61 | acc.update(correct/total) 62 | 63 | if data_pair_loader is not None: 64 | distances, labels = [], [] 65 | for j, (g1, g2, _, target) in enumerate(data_pair_loader): 66 | if cuda: 67 | g1.to(torch.device('cuda')) 68 | g2.to(torch.device('cuda')) 69 | 70 | # Output 71 | g1 = net(g1) 72 | g2 = net(g2) 73 | 74 | d = distance(g1, g2, mode='pairs') 75 | distances.append(d) 76 | labels.append(target) 77 | similarity = -torch.cat(distances, 0) 78 | similarity = (similarity-similarity.min()) / (similarity.max() - similarity.min() + 1e-8) 79 | labels = torch.cat(labels, 0) 80 | auc.update(metrics.roc_auc_score(labels.cpu(), similarity.cpu())) 81 | 82 | # mAP retrieval 83 | batch_time.update(time.time()-start) 84 | print('* Test Acc {acc.avg:.5f}; AUC {auc.avg: .5f} Time x Test {b_time.avg:.3f}' 85 | .format(acc=acc, auc=auc, b_time=batch_time)) 86 | return acc, auc 87 | 88 | 89 | def main(): 90 | print('Loss & Optimizer') 91 | if args.loss=='triplet': 92 | args.triplet=True 93 | criterion = TripletLoss(margin=args.margin, swap=args.swap) 94 | elif args.loss=='triplet_distance': 95 | args.triplet=True 96 | criterion = TripletLoss(margin=args.margin, swap=args.swap, dist=True) 97 | else: 98 | args.triplet=False 99 | criterion = ContrastiveLoss(margin=args.margin) 100 | 101 | print('Prepare data') 102 | train_loader, valid_loader, valid_gallery_loader, test_loader, test_gallery_loader, in_size = load_data(args.dataset, args.data_path, triplet=args.triplet, batch_size=args.batch_size, prefetch=args.prefetch) 103 | 104 | print('Create model') 105 | net = models.GNN(in_size, args.out_size, nlayers=args.nlayers, hid=args.hidden, J=args.pow) 106 | distNet = distance.SoftHd() 107 | 108 | print('Check CUDA') 109 | if args.cuda and args.ngpu > 1: 110 | print('\t* Data Parallel **NOT TESTED**') 111 | net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu))) 112 | 113 | if args.cuda: 114 | print('\t* CUDA') 115 | net, distNet = net.cuda(), distNet.cuda() 116 | criterion = criterion.cuda() 117 | 118 | start_epoch = 0 119 | best_map = 0 120 | early_stop_counter = 0 121 | if args.load is not None: 122 | print('Loading model') 123 | checkpoint = load_checkpoint(args.load) 124 | net.load_state_dict(checkpoint['state_dict']) 125 | distNet.load_state_dict(checkpoint['state_dict_dist']) 126 | start_epoch = checkpoint['epoch'] 127 | best_map = checkpoint['best_map'] 128 | print('Loaded model at epoch {epoch} and mAP {meanap}%'.format(epoch=checkpoint['epoch'],meanap=checkpoint['best_map'])) 129 | 130 | print('***Test***') 131 | test(test_loader, test_gallery_loader, [net, distNet], args.cuda) 132 | 133 | if __name__ == '__main__': 134 | # Parse options 135 | args = Options().parse() 136 | print('Parameters:\t' + str(args)) 137 | 138 | # Check cuda & Set random seed 139 | args.cuda = args.ngpu > 0 and torch.cuda.is_available() 140 | 141 | np.random.seed(args.seed) 142 | torch.manual_seed(args.seed) 143 | if args.cuda: 144 | torch.cuda.manual_seed(args.seed) 145 | 146 | # Check Test and Load 147 | if args.load is None: 148 | raise Exception('Cannot test without loading a model.') 149 | 150 | main() 151 | 152 | 
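For reference, the two numbers reported by `test()` above can be reproduced in isolation. A minimal sketch, assuming plain distance tensors in place of the repo's graph loaders; the helper names `triplet_ranking_accuracy` and `pair_auc` are hypothetical, not repo functions:

```
import torch
from sklearn import metrics

def triplet_ranking_accuracy(d_pos, d_neg):
    # Fraction of triplets where the anchor-positive distance beats the
    # anchor-negative one; mirrors the correct/total count in test().
    return (d_pos < d_neg).float().mean().item()

def pair_auc(distances, labels):
    # Mirrors test(): negate distances and min-max normalize them into
    # similarities before scoring with ROC-AUC. The label convention is
    # whatever the pair loader produces; this toy uses 1 = similar pair.
    sim = -distances
    sim = (sim - sim.min()) / (sim.max() - sim.min() + 1e-8)
    return metrics.roc_auc_score(labels.numpy(), sim.numpy())

# Toy usage with random stand-ins for distance(g1, g2) / distance(g1, g3).
d_pos, d_neg = torch.rand(8), torch.rand(8) + 0.5
print(triplet_ranking_accuracy(d_pos, d_neg))
print(pair_auc(torch.cat([d_pos, d_neg]), torch.cat([torch.ones(8), torch.zeros(8)])))
```
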
-------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function, division 4 | 5 | """ 6 | Graph metric learning 7 | """ 8 | 9 | # Python modules 10 | import torch 11 | from torch.optim.lr_scheduler import StepLR 12 | import glob 13 | import numpy as np 14 | import time 15 | import os 16 | import sys 17 | 18 | # Own modules 19 | from options import Options 20 | from Logger import LogMetric 21 | from utils import save_checkpoint, load_checkpoint 22 | from models import models, distance 23 | from test import test 24 | from data.load_data import load_data 25 | from loss.contrastive import ContrastiveLoss, TripletLoss 26 | 27 | __author__ = "Pau Riba" 28 | __email__ = "priba@cvc.uab.cat" 29 | 30 | 31 | def train(data_loader, nets, optimizer, cuda, criterion, epoch): 32 | batch_time = LogMetric.AverageMeter() 33 | batch_load_time = LogMetric.AverageMeter() 34 | losses = LogMetric.AverageMeter() 35 | 36 | net, distNet = nets 37 | # switch to train mode 38 | net.train() 39 | distNet.train() 40 | 41 | end = time.time() 42 | for i, (g1, g2, g3, target) in enumerate(data_loader): 43 | # Prepare input data 44 | if cuda: 45 | g1.to(torch.device('cuda')) 46 | g2.to(torch.device('cuda')) 47 | g1.gdata['std'], g2.gdata['std'] = g1.gdata['std'].cuda(), g2.gdata['std'].cuda() 48 | if args.triplet: 49 | g3.to(torch.device('cuda')) 50 | g3.gdata['std'] = g3.gdata['std'].cuda() 51 | else: 52 | target = target.cuda() 53 | 54 | batch_load_time.update(time.time() - end) 55 | optimizer.zero_grad() 56 | 57 | # Output 58 | g1 = net(g1) 59 | g2 = net(g2) 60 | 61 | if args.triplet: 62 | g3 = net(g3) 63 | loss = criterion(g1, g2, g3, distNet) 64 | else: 65 | loss = criterion(g1, g2, target, distNet) 66 | 67 | # Gradiensts and update 68 | loss.backward() 69 | optimizer.step() 70 | 71 | # Save values 72 | losses.update(loss.item(), g1.batch_size) 73 | batch_time.update(time.time() - end) 74 | end = time.time() 75 | 76 | if i > 0 and i%args.log_interval == 0: 77 | print('Epoch: [{0}]({1}/{2}) Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f} Avg Load Time x Batch {b_load_time.avg:.3f}' 78 | .format(epoch, i, len(data_loader), loss=losses, b_time=batch_time, b_load_time=batch_load_time)) 79 | print('Epoch: [{0}] Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f} Avg Time x Batch {b_load_time.avg:.3f}' 80 | .format(epoch, loss=losses, b_time=batch_time, b_load_time=batch_load_time)) 81 | return losses 82 | 83 | 84 | def main(): 85 | print('Loss & Optimizer') 86 | if args.loss=='triplet': 87 | args.triplet=True 88 | criterion = TripletLoss(margin=args.margin, swap=args.swap) 89 | elif args.loss=='triplet_distance': 90 | args.triplet=True 91 | criterion = TripletLoss(margin=args.margin, swap=args.swap, dist=True) 92 | else: 93 | args.triplet=False 94 | criterion = ContrastiveLoss(margin=args.margin) 95 | 96 | print('Prepare data') 97 | train_loader, valid_loader, valid_gallery_loader, test_loader, test_gallery_loader, in_size = load_data(args.dataset, args.data_path, triplet=args.triplet, batch_size=args.batch_size, prefetch=args.prefetch, set_partition=args.set_partition) 98 | 99 | print('Create model') 100 | if args.model == 'GAT': 101 | net = models.GNN_GAT(in_size, args.hidden, args.out_size, dropout=args.dropout) 102 | elif args.model == 'GRU': 103 | net = models.GNN_GRU(in_size, 
args.hidden, args.out_size, dropout=args.dropout)
104 | 
105 |     distNet = distance.SoftHd(args.out_size)
106 | 
107 |     optimizer = torch.optim.Adam(list(net.parameters())+list(distNet.parameters()), args.learning_rate, weight_decay=args.decay)
108 |     scheduler = StepLR(optimizer, args.schedule, gamma=args.gamma)  # step size comes from the --schedule option
109 | 
110 |     print('Check CUDA')
111 |     if args.cuda and args.ngpu > 1:
112 |         print('\t* Data Parallel **NOT TESTED**')
113 |         net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
114 | 
115 |     if args.cuda:
116 |         print('\t* CUDA')
117 |         net, distNet = net.cuda(), distNet.cuda()
118 |         criterion = criterion.cuda()
119 | 
120 |     start_epoch = 0
121 |     best_perf = 0
122 |     early_stop_counter = 0
123 |     if args.load is not None:
124 |         print('Loading model')
125 |         checkpoint = load_checkpoint(args.load)
126 |         net.load_state_dict(checkpoint['state_dict'])
127 |         distNet.load_state_dict(checkpoint['state_dict_dist'])
128 |         start_epoch = checkpoint['epoch']
129 |         best_perf = checkpoint['best_perf']
130 | 
131 |     if not args.test:
132 |         print('***Train***')
133 | 
134 |         for epoch in range(start_epoch, args.epochs):
135 | 
136 |             loss_train = train(train_loader, [net, distNet], optimizer, args.cuda, criterion, epoch)
137 |             acc_valid, map_valid = test(valid_loader, valid_gallery_loader, [net, distNet], args.cuda, validation=True)
138 | 
139 |             # Early-Stop + Save model
140 |             if map_valid.avg > best_perf:
141 |                 best_perf = map_valid.avg
142 |                 early_stop_counter = 0
143 |                 if args.save is not None:
144 |                     save_checkpoint({'epoch': epoch + 1, 'state_dict': net.state_dict(), 'state_dict_dist': distNet.state_dict(), 'best_perf': best_perf}, directory=args.save, file_name='checkpoint')
145 |             else:
146 |                 if early_stop_counter >= args.early_stop:
147 |                     print('Early Stop epoch {}'.format(epoch))
148 |                     break
149 |                 early_stop_counter += 1
150 | 
151 |             # Logger
152 |             if args.log:
153 |                 # Scalars
154 |                 logger.add_scalar('loss_train', loss_train.avg)
155 |                 logger.add_scalar('acc_valid', acc_valid.avg)
156 |                 logger.add_scalar('map_valid', map_valid.avg)
157 |                 logger.add_scalar('learning_rate', scheduler.get_lr()[0])
158 |                 logger.step()
159 | 
160 |             scheduler.step()
161 |     # Load the best model if it was saved during training
162 |     if args.save is not None:
163 |         print('Loading best model')
164 |         best_model_file = os.path.join(args.save, 'checkpoint.pth')
165 |         checkpoint = load_checkpoint(best_model_file)
166 |         net.load_state_dict(checkpoint['state_dict'])
167 |         distNet.load_state_dict(checkpoint['state_dict_dist'])
168 |         print('Best model at epoch {epoch} with mAP {perf}'.format(epoch=checkpoint['epoch'], perf=checkpoint['best_perf']))
169 | 
170 |     print('***Valid***')
171 |     test(valid_loader, valid_gallery_loader, [net, distNet], args.cuda)
172 |     print('***Test***')
173 |     test(test_loader, test_gallery_loader, [net, distNet], args.cuda)
174 |     sys.exit()
175 | 
176 | if __name__ == '__main__':
177 |     torch.autograd.set_detect_anomaly(True)
178 |     # Parse options
179 |     args = Options().parse()
180 |     print('Parameters:\t' + str(args))
181 | 
182 |     # Check cuda & Set random seed
183 |     args.cuda = args.ngpu > 0 and torch.cuda.is_available()
184 | 
185 |     if args.seed > 1:
186 |         np.random.seed(args.seed)
187 |         torch.manual_seed(args.seed)
188 |         if args.cuda:
189 |             torch.cuda.manual_seed(args.seed)
190 | 
191 |     # Check Test and Load
192 |     if args.test and args.load is None:
193 |         raise Exception('Cannot test without loading a model.')
194 | 
195 |     if not args.test and args.log is not None:
196 |         print('Initialize logger')
197 |         ind = len(glob.glob(args.log +
'*_run-batchSize_{}'.format(args.batch_size)))
198 |         log_dir = args.log + '{}_run-batchSize_{}/' \
199 |             .format(ind, args.batch_size)
200 |         args.save = args.save + '{}_run-batchSize_{}/' \
201 |             .format(ind, args.batch_size)
202 |         # Create logger
203 |         print('Log dir:\t' + log_dir)
204 |         logger = LogMetric.Logger(log_dir, force=True)
205 | 
206 |     main()
207 |     sys.exit()
208 | 
209 | 
--------------------------------------------------------------------------------
/src/train_iam.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from __future__ import print_function, division
4 | 
5 | """
6 | Graph metric learning
7 | """
8 | 
9 | # Python modules
10 | import torch
11 | from torch.optim.lr_scheduler import StepLR
12 | import glob
13 | import numpy as np
14 | import time
15 | import os
16 | import sys
17 | 
18 | # Own modules
19 | from options import Options
20 | from Logger import LogMetric
21 | from utils import save_checkpoint, load_checkpoint
22 | from models import models, distance
23 | from test_iam import test
24 | from data.load_data import load_data
25 | from loss.contrastive import ContrastiveLoss, TripletLoss
26 | 
27 | __author__ = "Pau Riba"
28 | __email__ = "priba@cvc.uab.cat"
29 | 
30 | 
31 | def train(data_loader, nets, optimizer, cuda, criterion, epoch):
32 |     batch_time = LogMetric.AverageMeter()
33 |     batch_load_time = LogMetric.AverageMeter()
34 |     losses = LogMetric.AverageMeter()
35 | 
36 |     net, distNet = nets
37 |     # switch to train mode
38 |     net.train()
39 |     distNet.train()
40 | 
41 |     end = time.time()
42 |     for i, (g1, g2, g3, target) in enumerate(data_loader):
43 |         # Prepare input data
44 |         if cuda:
45 |             g1.to(torch.device('cuda'))
46 |             g2.to(torch.device('cuda'))
47 |             g1.gdata['std'], g2.gdata['std'] = g1.gdata['std'].cuda(), g2.gdata['std'].cuda()
48 |             if args.triplet:
49 |                 g3.to(torch.device('cuda'))
50 |                 g3.gdata['std'] = g3.gdata['std'].cuda()
51 |             else:
52 |                 target = target.cuda()
53 | 
54 |         batch_load_time.update(time.time() - end)
55 |         optimizer.zero_grad()
56 | 
57 |         # Output
58 |         g1 = net(g1)
59 |         g2 = net(g2)
60 | 
61 |         if args.triplet:
62 |             g3 = net(g3)
63 |             loss = criterion(g1, g2, g3, distNet)
64 |         else:
65 |             loss = criterion(g1, g2, target, distNet)
66 | 
67 |         # Gradients and update
68 |         loss.backward()
69 |         optimizer.step()
70 | 
71 |         # Save values
72 |         losses.update(loss.item(), g1.batch_size)
73 |         batch_time.update(time.time() - end)
74 |         end = time.time()
75 | 
76 |         if i > 0 and i%args.log_interval == 0:
77 |             print('Epoch: [{0}]({1}/{2}) Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f} Avg Load Time x Batch {b_load_time.avg:.3f}'
78 |                   .format(epoch, i, len(data_loader), loss=losses, b_time=batch_time, b_load_time=batch_load_time))
79 |     print('Epoch: [{0}] Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f} Avg Load Time x Batch {b_load_time.avg:.3f}'
80 |           .format(epoch, loss=losses, b_time=batch_time, b_load_time=batch_load_time))
81 |     return losses
82 | 
83 | 
84 | def main():
85 |     print('Loss & Optimizer')
86 |     if args.loss=='triplet':
87 |         args.triplet=True
88 |         criterion = TripletLoss(margin=args.margin, swap=args.swap)
89 |     elif args.loss=='triplet_distance':
90 |         args.triplet=True
91 |         criterion = TripletLoss(margin=args.margin, swap=args.swap, dist=True)
92 |     else:
93 |         args.triplet=False
94 |         criterion = ContrastiveLoss(margin=args.margin)
95 | 
96 |     print('Prepare data')
97 |     train_loader, valid_loader, test_pair_loader, test_triplet_loader, in_size =
load_data(args.dataset, args.data_path, triplet=args.triplet, batch_size=args.batch_size, prefetch=args.prefetch, set_partition=args.set_partition)
98 | 
99 |     print('Create model')
100 |     net = models.GNN(in_size, args.hidden, args.out_size, dropout=args.dropout)
101 |     distNet = distance.SoftHd(args.out_size)
102 | 
103 |     optimizer = torch.optim.Adam(list(net.parameters())+list(distNet.parameters()), args.learning_rate, weight_decay=args.decay)
104 |     scheduler = StepLR(optimizer, args.schedule, gamma=args.gamma)  # step size comes from the --schedule option
105 | 
106 |     print('Check CUDA')
107 |     if args.cuda and args.ngpu > 1:
108 |         print('\t* Data Parallel **NOT TESTED**')
109 |         net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
110 | 
111 |     if args.cuda:
112 |         print('\t* CUDA')
113 |         net, distNet = net.cuda(), distNet.cuda()
114 |         criterion = criterion.cuda()
115 | 
116 |     start_epoch = 0
117 |     best_perf = 0
118 |     early_stop_counter = 0
119 |     if args.load is not None:
120 |         print('Loading model')
121 |         checkpoint = load_checkpoint(args.load)
122 |         net.load_state_dict(checkpoint['state_dict'])
123 |         distNet.load_state_dict(checkpoint['state_dict_dist'])
124 |         start_epoch = checkpoint['epoch']
125 |         best_perf = checkpoint['best_perf']
126 | 
127 |     if not args.test:
128 |         print('***Train***')
129 | 
130 |         for epoch in range(start_epoch, args.epochs):
131 | 
132 |             loss_train = train(train_loader, [net, distNet], optimizer, args.cuda, criterion, epoch)
133 |             acc_valid, auc_valid = test(valid_loader, [net, distNet], args.cuda)
134 | 
135 |             # Early-Stop + Save model
136 |             if acc_valid.avg > best_perf:
137 |                 best_perf = acc_valid.avg
138 |                 early_stop_counter = 0
139 |                 if args.save is not None:
140 |                     save_checkpoint({'epoch': epoch + 1, 'state_dict': net.state_dict(), 'state_dict_dist': distNet.state_dict(), 'best_perf': best_perf}, directory=args.save, file_name='checkpoint')
141 |             else:
142 |                 if early_stop_counter >= args.early_stop:
143 |                     print('Early Stop epoch {}'.format(epoch))
144 |                     break
145 |                 early_stop_counter += 1
146 | 
147 |             # Logger
148 |             if args.log:
149 |                 # Scalars
150 |                 logger.add_scalar('loss_train', loss_train.avg)
151 |                 logger.add_scalar('acc_valid', acc_valid.avg.item())
152 |                 logger.add_scalar('learning_rate', scheduler.get_lr()[0])
153 |                 logger.step()
154 | 
155 |             scheduler.step()
156 |     # Load the best model if it was saved during training
157 |     if args.save is not None:
158 |         print('Loading best model')
159 |         best_model_file = os.path.join(args.save, 'checkpoint.pth')
160 |         checkpoint = load_checkpoint(best_model_file)
161 |         net.load_state_dict(checkpoint['state_dict'])
162 |         distNet.load_state_dict(checkpoint['state_dict_dist'])
163 |         print('Best model at epoch {epoch} with accuracy {acc}'.format(epoch=checkpoint['epoch'], acc=checkpoint['best_perf']))
164 | 
165 |     print('***Valid***')
166 |     test(valid_loader, [net, distNet], args.cuda)
167 |     print('***Test***')
168 |     test(test_triplet_loader, [net, distNet], args.cuda, data_pair_loader=test_pair_loader)
169 |     sys.exit()
170 | 
171 | if __name__ == '__main__':
172 |     torch.autograd.set_detect_anomaly(True)
173 |     # Parse options
174 |     args = Options().parse()
175 |     print('Parameters:\t' + str(args))
176 | 
177 |     # Check cuda & Set random seed
178 |     args.cuda = args.ngpu > 0 and torch.cuda.is_available()
179 | 
180 |     if args.seed > 1:
181 |         np.random.seed(args.seed)
182 |         torch.manual_seed(args.seed)
183 |         if args.cuda:
184 |             torch.cuda.manual_seed(args.seed)
185 | 
186 |     # Check Test and Load
187 |     if args.test and args.load is None:
188 |         raise Exception('Cannot test without loading a model.')
189 | 
190 |     if not args.test and args.log is not None:
191 |         print('Initialize logger')
192 |         ind = len(glob.glob(args.log + '*_run-batchSize_{}'.format(args.batch_size)))
193 |         log_dir = args.log + '{}_run-batchSize_{}/' \
194 |             .format(ind, args.batch_size)
195 |         args.save = args.save + '{}_run-batchSize_{}/' \
196 |             .format(ind, args.batch_size)
197 |         # Create logger
198 |         print('Log dir:\t' + log_dir)
199 |         logger = LogMetric.Logger(log_dir, force=True)
200 | 
201 |     main()
202 |     sys.exit()
203 | 
204 | 
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | from __future__ import print_function, division
3 | 
4 | """
5 | Pytorch useful tools.
6 | """
7 | 
8 | import torch
9 | import os
10 | import errno
11 | import numpy as np
12 | from sklearn.metrics import average_precision_score
13 | from joblib import Parallel, delayed
14 | import multiprocessing
15 | 
16 | __author__ = 'Pau Riba'
17 | __email__ = 'priba@cvc.uab.cat'
18 | 
19 | 
20 | # Evaluation
21 | def knn_accuracy(dist_matrix, target_gallery, target_query, k=5):
22 |     # Predict
23 |     _, ind = dist_matrix.sort(1)
24 |     sort_target = target_gallery[ind.cpu()]
25 |     sort_target = sort_target[:,:k]
26 | 
27 |     # Counts
28 |     counts = np.zeros(sort_target.shape)
29 |     for i in range(k):
30 |         counts[:,i] = (np.expand_dims(sort_target[:, i], axis=1) == sort_target).sum(1)
31 | 
32 |     predict_ind = counts.argmax(1)
33 |     predict = [sort_target[i, pi] for i, pi in enumerate(predict_ind)]
34 |     predict = np.stack(predict)
35 | 
36 |     # Accuracy
37 |     acc = (predict == target_query).astype(float).sum()
38 |     acc = 100.0*acc/predict.shape[0]
39 |     return acc
40 | 
41 | 
42 | def mean_average_precision(dist_matrix, target_gallery, target_query):
43 |     # Number of queries
44 |     nq = target_query.shape[0]
45 | 
46 |     aps = []
47 |     for q in range(nq):
48 |         _, indices = dist_matrix[q].sort()
49 |         rel = np.array(target_query[q] == target_gallery[indices.cpu()])
50 |         if rel.any():
51 |             x = np.float32(np.cumsum(rel))/range(1, len(rel)+1)  # precision at each rank
52 |             aps.append(np.sum(x[rel])/(len(x[rel])+10**-7))
53 | 
54 |     return np.mean(aps)
55 | 
56 | 
57 | # def mean_average_precision(dist_matrix, target_gallery, target_query):
58 | #     # Number of queries
59 | #     nq = target_query.shape[0]
60 | #
61 | #     interpolation_points = np.linspace(0,1,11)
62 | #
63 | #     aps = []
64 | #     for q in range(nq):
65 | #         _, indices = dist_matrix[q].sort()
66 | #         rel = np.array(target_query[q] == target_gallery[indices.cpu()])
67 | #
68 | #         recall = np.float32(np.cumsum(rel))/rel.sum()
69 | #         precision = np.float32(np.cumsum(rel))/range(1,len(rel)+1)
70 | #
71 | #         prec = [precision[recall>=i].max() for i in interpolation_points]
72 | #         aps.append( np.mean(prec))
73 | #
74 | #     return np.mean(aps)
75 | 
76 | # Checkpoints
77 | def save_checkpoint(state, directory, file_name):
78 | 
79 |     if not os.path.isdir(directory):
80 |         os.makedirs(directory)
81 |     checkpoint_file = os.path.join(directory, file_name + '.pth')
82 |     torch.save(state, checkpoint_file)
83 | 
84 | 
85 | def load_checkpoint(model_file):
86 |     if os.path.isfile(model_file):
87 |         print("=> loading model '{}'".format(model_file))
88 |         checkpoint = torch.load(model_file)
89 |         print("=> loaded model '{}' (epoch {}, Best Performance {})".format(model_file, checkpoint['epoch'], checkpoint['best_perf']))
90 |         return checkpoint
91 |     else:
92 |         print("=> no model found at '{}'".format(model_file))
93 |         raise OSError(errno.ENOENT,
os.strerror(errno.ENOENT), model_file) 94 | 95 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | source env/bin/activate 2 | 3 | . $1 4 | 5 | python src/test.py $dataset $data_path $bz $out_size $nlayers $hidden $dropout $loss $swap $margin $pow $epochs $lr $momentum $decay $schedule $gamma $seed $save $load $test $early_stop $prefetch $ngpu $log $log_interval 6 | 7 | deactivate 8 | -------------------------------------------------------------------------------- /testHED.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dataset='histograph-ak' 3 | # dataset='histograph-gw' 4 | # data_path=../../Datasets/histograph/01_GW/01_Keypoint/ 5 | data_path=../../Datasets/histograph/03_AK/01_Keypoint/ 6 | bz="--batch_size 64" 7 | ngpu="--ngpu 0" 8 | prefetch="--prefetch 4" 9 | set_partition="--set_partition cv1" 10 | 11 | tau_n=4 12 | tau_e=16 13 | alpha=0.5 14 | beta=0.1 15 | 16 | echo "EXPERIMENT Tn=$tau_n; Te=$tau_e; Alpha=$alpha; Beta=$beta" 17 | python src/testHED.py $dataset $data_path $set_partition $bz $prefetch $ngpu --tau_n $tau_n --tau_e $tau_e --alpha $alpha --beta $beta 18 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | . $1 2 | 3 | python src/train.py $dataset $data_path $bz $out_size $hidden $dropout $loss $swap $margin $epochs $lr $decay $schedule $gamma $save $load $test $early_stop $prefetch $ngpu $log $log_interval $set_partition $model 4 | 5 | -------------------------------------------------------------------------------- /train_iam.sh: -------------------------------------------------------------------------------- 1 | . $1 2 | 3 | python src/train_iam.py $dataset $data_path $bz $out_size $hidden $dropout $loss $swap $margin $epochs $lr $decay $schedule $gamma $save $load $test $early_stop $prefetch $ngpu $log $log_interval $set_partition 4 | 5 | --------------------------------------------------------------------------------