├── data_util ├── __init__.py ├── rome16k │ ├── __init__.py │ ├── __main__.py │ ├── scenes.py │ └── parse.py ├── __main__.py ├── datasets.py ├── tf_helpers.py ├── noisy_dataset.py ├── parent_dataset.py └── real_dataset.py ├── baselines ├── myspectral.m ├── mmatch_spectral.m ├── mmatch_QP_PG.m ├── print_errors.py ├── mmatch_CVX_ALS.m ├── roc_curves.py ├── run_tests.m └── PGDDS.m ├── model ├── __init__.py ├── model.py ├── networks.py ├── skip_networks.py └── layers.py ├── .gitignore ├── myutils.py ├── sim_graphs.py ├── README.md ├── tfutils.py ├── test.py ├── experiment.py ├── options.py └── train.py /data_util/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Dataset handling code 4 | """ 5 | 6 | from data_util.datasets import * 7 | -------------------------------------------------------------------------------- /baselines/myspectral.m: -------------------------------------------------------------------------------- 1 | function [X] = myspectral(W,k) 2 | 3 | [V, D] = eigs(W,k,'lm'); 4 | Y = V(:,1:k); 5 | Dy = D(1:k,1:k); 6 | X = abs(Y*Dy*Y'); 7 | 8 | 9 | end 10 | 11 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Network model functions and utilities 4 | """ 5 | 6 | from model.layers import * 7 | from model.networks import * 8 | from model.skip_networks import * 9 | from model.model import * 10 | 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore compiled files 2 | *.sh 3 | *.pyc 4 | stdout.log 5 | # Ignore exteral imports for experiments 6 | external/* 7 | # Ignore log folders 8 | save/* 9 | launch/*yaml 10 | __pycache__/ 11 | figs/* 12 | logs/* 13 | baselines/*txt 14 | 
baselines/*log 15 | -------------------------------------------------------------------------------- /data_util/rome16k/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Import Rome16K files 4 | """ 5 | 6 | # Dataset files 7 | from data_util.datasets import * 8 | from data_util.parent_dataset import * 9 | from data_util.noisy_dataset import * 10 | from data_util.real_dataset import * 11 | 12 | # Utility files 13 | from data_util import rome16k 14 | 15 | -------------------------------------------------------------------------------- /baselines/mmatch_spectral.m: -------------------------------------------------------------------------------- 1 | function [X,Y,run_time] = mmatch_spectral(W,dimGroup,k) 2 | 3 | t0 = tic; 4 | 5 | % display('Spectral method begins....') 6 | 7 | k = min(k,size(W,1)); 8 | 9 | [V,~] = eigs(W,k,'la'); 10 | % [V,~] = svd(W); 11 | % Y = rounding(V(:,1:k),dimGroup,0.5); 12 | Y = abs(V(:,1:k)); 13 | Y = Y(:,1:min(size(Y,2),k)); 14 | 15 | 16 | % csdimGroup = [0;cumsum(dimGroup(:))]; 17 | % for i=1:numel(dimGroup) 18 | % 19 | % idx = csdimGroup(i)+1: csdimGroup(i+1); 20 | % Y(idx,:)= matrix2perm(Y(idx,:)); 21 | % 22 | % end 23 | 24 | % X = single(Y)*single(Y)'>0; 25 | X = max(0,single(Y)*single(Y)'); 26 | 27 | run_time = toc(t0); 28 | 29 | % [display(sprintf('Spectral method terminated in %0.2f seconds...',run_time)) 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /data_util/__main__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Generate datasets 4 | """ 5 | import numpy as np 6 | import os 7 | 8 | import options 9 | from data_util import datasets 10 | from data_util import noisy_dataset 11 | from data_util import real_dataset 12 | 13 | opts = options.get_opts() 14 | print("Generating Pose Graphs") 15 | if not 
os.path.exists(opts.data_dir): 16 | os.makedirs(opts.data_dir) 17 | mydataset = datasets.get_dataset(opts) 18 | 19 | # Create train and test tfrecords 20 | types = [ 21 | 'train', 22 | 'test' 23 | ] 24 | for t in types: 25 | dname = os.path.join(opts.data_dir,t) 26 | if not os.path.exists(dname): 27 | os.makedirs(dname) 28 | mydataset.convert_dataset(dname, t) 29 | 30 | # Generate numpy test files 31 | out_dir = os.path.join(opts.data_dir,'np_test') 32 | if not os.path.exists(out_dir): 33 | os.makedirs(out_dir) 34 | mydataset.create_np_dataset(out_dir, opts.dataset_params.sizes['test']) 35 | 36 | 37 | -------------------------------------------------------------------------------- /baselines/mmatch_QP_PG.m: -------------------------------------------------------------------------------- 1 | function Y = mmatch_QP_PG(Q,alpha,beta,nP,Y) 2 | 3 | tol = 1e-3; 4 | n = size(Q,1); 5 | Q = - Q + beta; 6 | nP = cumsum(nP); 7 | nP = [0,nP]; 8 | concavify = false; 9 | 10 | if concavify 11 | if n < 1000 12 | d = eig(Q); 13 | w = max(d,0); 14 | d = d - w; % all <= 0 15 | Q = Q - diag(w); 16 | W = w/2*ones(1,size(Y,2)); 17 | mu = 1.1*max(abs(d)); 18 | else 19 | d = sum(abs(Q),2)-abs(diag(Q)); % off-diagnal sum 20 | w = d + diag(Q); 21 | Q = Q - diag(w); 22 | W = w/2*ones(1,size(Y,2)); 23 | mu = 1.1*norm(Q); 24 | end 25 | else 26 | mu = 1.1*norm(Q); 27 | W = 0; 28 | end 29 | 30 | for iter = 1:200 31 | 32 | Y0 = Y; 33 | Y = Y - (Q*Y+W+alpha*Y)/mu; 34 | 35 | for i = 1:length(nP)-1 36 | ind = nP(i)+1:nP(i+1); 37 | Y(ind,:) = proj2dpam(Y(ind,:),1e-2); 38 | end 39 | 40 | RelChg = norm(Y(:)-Y0(:))/n; 41 | fprintf('Iter = %d, Res = (%d), mu = %d\n',iter,RelChg,mu); 42 | 43 | if RelChg < tol 44 | break 45 | end 46 | 47 | end 48 | 49 | end 50 | 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /data_util/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 
import numpy as np 3 | import os 4 | 5 | from data_util import parent_dataset 6 | from data_util import noisy_dataset 7 | from data_util import real_dataset 8 | 9 | 10 | def get_dataset(opts): 11 | """Getting the dataset with all the correct attributes 12 | Input: opts (options) - object with all relevant options stored 13 | Output: dataset (data_util.GraphSimDataset) - dataset for training/testing 14 | """ 15 | if opts.dataset in [ 'synth_small', 'synth_3view', 'synth_4view', \ 16 | 'synth_5view', 'synth_6view' ]: 17 | return parent_dataset.GraphSimDataset(opts, opts.dataset_params) 18 | elif 'synth_pts' in opts.dataset: 19 | return parent_dataset.GraphSimDataset(opts, opts.dataset_params) 20 | elif opts.dataset in [ 'noise_3view' ]: 21 | return noisy_dataset.GraphSimNoisyDataset(opts, opts.dataset_params) 22 | elif opts.dataset in [ 'noise_gauss' ]: 23 | return noisy_dataset.GraphSimGaussDataset(opts, opts.dataset_params) 24 | elif opts.dataset in [ 'noise_symgauss' ]: 25 | return noisy_dataset.GraphSimSymGaussDataset(opts, opts.dataset_params) 26 | elif 'noise_largepairwise' in opts.dataset or \ 27 | 'noise_pairwise' in opts.dataset: 28 | return noisy_dataset.GraphSimPairwiseDataset(opts, opts.dataset_params) 29 | elif 'noise_outlier' in opts.dataset: 30 | return noisy_dataset.GraphSimOutlierDataset(opts, opts.dataset_params) 31 | elif opts.dataset in [ 'rome16kknn0' ]: 32 | return real_dataset.KNNRome16KDataset(opts, opts.dataset_params) 33 | elif opts.dataset in [ 'rome16kgeom0', 'rome16kgeom4view0' ]: 34 | return real_dataset.GeomKNNRome16KDataset(opts, opts.dataset_params) 35 | 36 | -------------------------------------------------------------------------------- /data_util/tf_helpers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # Tensorflow features 4 | def _bytes_feature(value): 5 | """Create arbitrary tensor Tensorflow feature.""" 6 | return 
tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 7 | 8 | def _int64_feature(value): 9 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 10 | 11 | class Int64Feature(object): 12 | """Custom class used for decoding serialized tensors.""" 13 | def __init__(self, key, description): 14 | super(Int64Feature, self).__init__() 15 | self._key = key 16 | self.description = description 17 | self.shape = [] 18 | self.dtype = 'int64' 19 | 20 | def get_placeholder(self): 21 | return tf.placeholder(tf.int64, shape=[None]) 22 | 23 | def get_feature_write(self, value): 24 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 25 | 26 | def get_feature_read(self): 27 | return tf.FixedLenFeature([], tf.int64) 28 | 29 | def tensors_to_item(self, keys_to_tensors): 30 | tensor = keys_to_tensors[self._key] 31 | return tf.cast(tensor, dtype=tf.int64) 32 | 33 | class TensorFeature(object): 34 | """Custom class used for decoding serialized tensors.""" 35 | def __init__(self, key, shape, dtype, description): 36 | super(TensorFeature, self).__init__() 37 | self._key = key 38 | self.shape = shape 39 | self.dtype = dtype 40 | self.description = description 41 | 42 | def get_placeholder(self): 43 | return tf.placeholder(self.dtype, shape=[None] + self.shape) 44 | 45 | def get_feature_write(self, value): 46 | v = value.astype(self.dtype).tobytes() 47 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[v])) 48 | 49 | def get_feature_read(self): 50 | return tf.FixedLenFeature([], tf.string) 51 | 52 | def tensors_to_item(self, keys_to_tensors): 53 | tensor = keys_to_tensors[self._key] 54 | tensor = tf.decode_raw(tensor, out_type=self.dtype) 55 | return tf.reshape(tensor, self.shape) 56 | 57 | -------------------------------------------------------------------------------- /baselines/print_errors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 
import re 5 | 6 | import matplotlib.pyplot as plt 7 | 8 | egstr = '000008 Errors: L1: 1.430e-02, L2: 5.834e-03, BCE: 5.971e-02, ' \ 9 | 'Same sim: 4.281e-01 +/- 1.905e-01, Diff sim: 7.239e-03 +/- 3.317e-02, ' \ 10 | 'Area under ROC: 9.161e-01, Area under P-R: 7.879e-01, ' \ 11 | 'Time: 2.072e+00' 12 | 13 | def stdagg(x): 14 | return np.sqrt(np.mean(np.array(x)**2)) 15 | 16 | def myformat2(x): 17 | return '{:.05e}'.format(x) 18 | 19 | def myformat(x): 20 | return '{:.03f}'.format(x) 21 | 22 | def myformat_old(x): 23 | y = "{:.03e}".format(x).split('e') 24 | return "{}e-{}".format(y[0], y[1][-1]) 25 | 26 | efmt = '[-+]?\d+\.\d*e[-+]\d+' 27 | disp_match = re.compile(efmt) 28 | names = [ 29 | 'l1', 'l2', 'bce', \ 30 | 'ssame_m', 'ssame_s', 'sdiff_m', 'sdiff_s', \ 31 | 'roc', 'pr', \ 32 | 'time'] 33 | def parse(line): 34 | return dict(zip(names, [ float(x) for x in disp_match.findall(line) ])) 35 | 36 | agg_names = [ 'l1', 'l2', 'bce', 'ssame', 'sdiff' ] 37 | def agg(vals): 38 | aggs = dict(zip(agg_names, [ None for nm in agg_names ])) 39 | for k in [ 'l1', 'l2', 'bce', 'roc', 'pr', 'time' ]: 40 | aggs[k] = (np.mean(vals[k]), np.std(vals[k])) 41 | for k in [ 'ssame', 'sdiff' ]: 42 | aggs[k] = ( np.mean(vals[k + '_m']), stdagg(vals[k + '_s']) ) 43 | return aggs 44 | 45 | def disp_val(aggs): 46 | # fstr = "{:40}, L1: {} +/- {} , L2: {} +/- {} , BCE: {} +/- {}" 47 | # print(fstr.format(fname, 48 | # myformat(aggs['l1'][0]), myformat(aggs['l1'][1]), 49 | # myformat(aggs['l2'][0]), myformat(aggs['l2'][1]), 50 | # myformat(aggs['bce'][0]), myformat(aggs['bce'][1]))) 51 | # return 52 | fstr = "{:40} & {} $\pm$ {} & {} $\pm$ {} & {} $\pm$ {} & {} $\pm$ {} & {} $\pm$ {} \\\\ \\hline" 53 | print(fstr.format(fname, 54 | myformat(aggs['l1'][0]), myformat(aggs['l1'][1]), 55 | myformat(aggs['l2'][0]), myformat(aggs['l2'][1]), 56 | myformat(aggs['roc'][0]), myformat(aggs['roc'][1]), 57 | myformat(aggs['pr'][0]), myformat(aggs['pr'][1]), 58 | myformat(aggs['time'][0]), 
myformat(aggs['time'][1]))) 59 | 60 | 61 | 62 | topstr = "Method &" \ 63 | " $L_1$ &" \ 64 | " $L_2$ &" \ 65 | " Area under ROC &" \ 66 | " Area Prec.-Recall &" \ 67 | " Time (sec) \\ \hline" 68 | print(topstr) 69 | for fname in sys.argv[1:]: 70 | vals = dict(zip(names, [ [] for nm in names ])) 71 | f = open(fname, 'r') 72 | for line in f: 73 | vals_ = parse(line) 74 | for k, v in vals_.items(): 75 | vals[k].append(v) 76 | 77 | f.close() 78 | 79 | # Latex 80 | aggs = agg(vals) 81 | disp_val(aggs) 82 | -------------------------------------------------------------------------------- /myutils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utility functions for os and numpy related stuff 3 | """ 4 | import os 5 | import sys 6 | import numpy as np 7 | 8 | # Numpy 9 | def mysincf(x): 10 | """Numerically stable sinc function (sin(x)/x) 11 | Input: x (float) 12 | Output: z (float) - sin(x)/x, numerically stable around 0 13 | """ 14 | z = x if x != 0 else 1e-16 15 | return np.sin(z) / z 16 | 17 | def mysinc(x): 18 | """Numerically stable sinc function (sin(x)/x) 19 | Input: x (numpy array) 20 | Output: z (numpy array) - sin(x)/x, numerically stable around 0 21 | """ 22 | z = np.select(x == 0, 1e-16, x) 23 | return np.sin(z) / z 24 | 25 | def normalize(x): 26 | """Return the unit vector in the direction of x""" 27 | return x / (1e-16 + np.linalg.norm(x)) 28 | 29 | def sph_rot(x): 30 | """Takes unit vector and create rotation matrix from it 31 | Input: x (3x1 or 1x3 matrix) 32 | Output: 33 | - R (3x3 matrix) - rotation matrix such that dot(R,x) = x (not 34 | deterministically made) 35 | """ 36 | x = x.reshape(-1) 37 | u = normalize(np.random.randn(3)) 38 | R = np.array([ 39 | normalize(np.cross(np.cross(u,x),x)), 40 | normalize(np.cross(u,x)), 41 | x, 42 | ]) 43 | return R 44 | 45 | def dim_norm(X): 46 | """Norms of the vectors along the last dimension of X 47 | Input: X (NxM numpy array) 48 | Output: X_norm (Nx1 numpy 
array) - norm of each row of X 49 | """ 50 | return np.expand_dims(np.sqrt(np.sum(X**2, axis=-1)), axis=-1) 51 | 52 | def dim_normalize(X): 53 | """Return X with vectors along last dimension normalized to unit length 54 | Input: X (NxM numpy array) 55 | Output: X_unit (NxM numpy array) - each row of X divided by its norm (an all-zero row yields NaN; no epsilon is added) 56 | """ 57 | return X / dim_norm(X) 58 | 59 | def planer_proj(X): 60 | """Return X divided by the last element of its dimension 61 | Input: X (NxM numpy array) 62 | Output: X_proj (NxM numpy array) - X with each row divided by its last element 63 | """ 64 | return X / np.expand_dims(X[...,-1], axis=-1) 65 | 66 | # Miscellaneous 67 | def str2bool(v): 68 | """Convert a string into a boolean 69 | Input: v (string) - string to convert to boolean 70 | Output: v_bool (boolean) - appropriate boolean matching the string 71 | """ 72 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 73 | return True 74 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 75 | return False 76 | else: 77 | import argparse 78 | raise argparse.ArgumentTypeError('Boolean value expected.') 79 | 80 | def next_file(directory, fname, suffix): 81 | """Return the path of the first not-yet-existing numbered file in directory. 82 | Input: 83 | - directory (string) - name of directory that file will be in 84 | - fname (string) - prefix to the number that will get incremented 85 | - suffix (string) - file suffix (e.g. .png, .jpg, .txt) 86 | Output: path (string) - os.path.join(directory, fname + 3-digit number + suffix) 87 | When this is called, it checks files in directory with prefix fname 88 | in order (i.e. fname001.png, fname002.png, etc.), starting at 001, and returns 89 | the first path that does not exist yet (so a gap in the numbering is reused).
90 | """ 91 | fidx = 1 92 | name = lambda i: os.path.join(directory,"{}{:03d}{}".format(fname,i,suffix)) 93 | while os.path.exists(name(fidx)): 94 | fidx += 1 95 | return name(fidx) 96 | 97 | 98 | -------------------------------------------------------------------------------- /sim_graphs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import os 4 | import sys 5 | import collections 6 | import scipy.linalg as la 7 | from tqdm import tqdm 8 | 9 | from myutils import * 10 | import options 11 | 12 | # Classes 13 | Points = collections.namedtuple("Points", ["p","d"]) # position and descriptor 14 | Pose = collections.namedtuple("Pose", ["R","T"]) 15 | PoseEdge = collections.namedtuple("PoseEdge", ["idx", "g_ij"]) 16 | class PoseGraph(object): 17 | """Generation for the synthetic training data, with some visualization aides 18 | """ 19 | def __init__(self, params, n_pts, n_views): 20 | """Create PoseGraph 21 | Inputs: 22 | - params - dataset_params from options.py 23 | - n_pts - number of points per view 24 | - n_views - number of views 25 | Outputs: PoseGraph 26 | """ 27 | self.params = params 28 | self.n_pts = n_pts 29 | self.n_views = n_views 30 | # Generate poses 31 | sph = dim_normalize(np.random.randn(self.n_views,3)) 32 | rot = [ sph_rot(-sph[i]) for i in range(self.n_views) ] 33 | trans = params.scale*sph 34 | # Create variables 35 | pts = params.points_scale*np.random.randn(self.n_pts,3) 36 | self.desc_dim = params.descriptor_dim 37 | self.desc_var = params.descriptor_var 38 | desc = self.desc_var*np.random.randn(self.n_pts, self.desc_dim) 39 | self.pts_w = Points(p=pts,d=desc) 40 | self.g_cw = [ Pose(R=rot[i],T=trans[i]) for i in range(self.n_views) ] 41 | # Create graph 42 | eye = np.eye(self.n_views) 43 | dist_mat = 2 - 2*np.dot(sph, sph.T) + 3*eye 44 | AdjList0 = [ dist_mat[i].argsort()[:params.knn].tolist() 45 | for i in range(self.n_views) ] 46 | A = np.array([ 
sum([ eye[j] for j in AdjList0[i] ]) 47 | for i in range(self.n_views) ]) 48 | self.adj_mat = np.minimum(1, A.T + A) 49 | get_adjs = lambda adj: np.argwhere(adj.reshape(-1) > 0).T.tolist()[0] 50 | self.adj_list = [] 51 | for i in range(self.n_views): 52 | pose_edges = [] 53 | for j in get_adjs(self.adj_mat[i]): 54 | Rij = np.dot(rot[i].T,rot[j]), 55 | Tij = normalize(np.dot(rot[i].T, trans[j] - trans[i])).reshape((3,1)) 56 | pose_edges.append(PoseEdge(idx=j, g_ij=Pose(R=Rij, T=Tij))) 57 | self.adj_list.append(pose_edges) 58 | 59 | def get_random_state(self, pts): 60 | """Get random state determined by 3d points pts""" 61 | seed = (np.sum(np.abs(pts**5))) 62 | return np.random.RandomState(int(seed)) 63 | 64 | def get_proj(self, i): 65 | """Get the 2d projection for view i""" 66 | pts_c = np.dot(self.pts_w.p - self.g_cw[i].T, self.g_cw[i].R.T) 67 | s = self.get_random_state(pts_c) 68 | perm = s.permutation(self.n_pts) 69 | proj_pos = planer_proj(pts_c)[perm,:2] 70 | var = self.params.descriptor_noise_var 71 | desc_noise = var*s.randn(self.n_pts, self.desc_dim) 72 | descs = self.pts_w.d[perm,:] + desc_noise 73 | return Points(p=proj_pos, d=descs) 74 | 75 | def get_perm(self, i): 76 | """Get the permutation of ground truth points for view i""" 77 | pts_c = np.dot(self.pts_w.p - self.g_cw[i].T, self.g_cw[i].R.T) 78 | s = self.get_random_state(pts_c) 79 | return s.permutation(self.n_pts) 80 | 81 | def get_all_permutations(self): 82 | """Get list of all permutations from all views""" 83 | return [ self.get_perm(i) for i in range(self.n_views) ] 84 | 85 | def get_feature_matching_mat(self): 86 | """Get matching matrix using the synthetic features""" 87 | n = self.n_pts 88 | m = self.n_views 89 | perms = [ self.get_perm(i) for i in range(m) ] 90 | sigma = 2 91 | total_graph = np.zeros((n*m, n*m)) 92 | for i in range(m): 93 | for j in ([ e.idx for e in self.adj_list[i] ]): 94 | s_ij = np.zeros((n, n)) 95 | descs_i = self.get_proj(i).d 96 | descs_j = self.get_proj(j).d 97 | 
for x in range(n): 98 | u = perms[i][x] 99 | for y in range(n): 100 | v = perms[j][y] 101 | s_ij[u,v] = np.exp(-np.linalg.norm(descs_i[u] - descs_j[v])/(sigma)) 102 | total_graph[i*n:(i+1)*n, j*n:(j+1)*n] = s_ij 103 | return total_graph # + np.eye(n*m) 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import os 4 | import sys 5 | import tensorflow as tf 6 | import sonnet as snt 7 | 8 | import tfutils 9 | import myutils 10 | import options 11 | 12 | from model import networks 13 | from model import skip_networks 14 | 15 | def get_regularizers(opts): 16 | """Get regularizers for the weights and biases for Sonnet networks 17 | Input: opts (options) - object with all relevant options 18 | Output: regularizers (dict) - regulraizing functions with right keys 19 | """ 20 | regularizer_fn = None 21 | bias_fn = tf.contrib.layers.l2_regularizer(0.0) 22 | if opts.weight_decay <= 0 and opts.weight_l1_decay <= 0: 23 | return None 24 | elif opts.weight_decay > 0 and opts.weight_l1_decay <= 0: 25 | regularizer_fn = \ 26 | lambda r_l2, r_l1: tf.contrib.layers.l2_regularizer(r_l2) 27 | elif opts.weight_decay <= 0 and opts.weight_l1_decay > 0: 28 | regularizer_fn = \ 29 | lambda r_l2, r_l1: tf.contrib.layers.l1_regularizer(r_l1) 30 | elif opts.weight_decay <= 0 and opts.weight_l1_decay > 0: 31 | regularizer_fn = \ 32 | lambda r_l2, r_l1: tf.contrib.layers.l1_l2_regularizer(r_l1/r_l2, 1.0) 33 | all_regs = { 34 | "w" : regularizer_fn(opts.weight_decay, opts.weight_l1_decay), 35 | "u" : regularizer_fn(opts.weight_decay, opts.weight_l1_decay), 36 | "f1" : regularizer_fn(opts.weight_decay, opts.weight_l1_decay), 37 | "f2" : regularizer_fn(opts.weight_decay, opts.weight_l1_decay), 38 | "b" : bias_fn, 39 | "c" : bias_fn, 40 | "d1" : bias_fn, 41 | "d2" : bias_fn, 42 | } 43 | if 
opts.architecture in ['vanilla', 'vanilla0', 'vanilla1']: 44 | return { k: all_regs[k] for k in [ "w", "b" ] } 45 | elif opts.architecture in ['skip', 'skip0', 'skip1', \ 46 | 'longskip0', 'longskip1', \ 47 | 'normedskip0', 'normedskip1', \ 48 | 'normedskip2', 'normedskip3', ]: 49 | return { k: all_regs[k] for k in [ "w", "u", "b", "c" ] } 50 | elif opts.architecture in ['attn0', 'attn1', 'attn2', \ 51 | 'spattn0', 'spattn1', 'spattn2']: 52 | return all_regs 53 | 54 | def get_network(opts, arch): 55 | """Get Sonnet networks for training and testing 56 | Input: 57 | - opts (options) - object with all relevant options 58 | - arch (ArchParams) - object with all relevant Architecture options 59 | Output: Network (snt.Module) - network to train 60 | """ 61 | regularizers = None 62 | if opts.architecture in ['vanilla', 'vanilla0', 'vanilla1']: 63 | network = networks.GraphConvLayerNetwork( 64 | opts, 65 | arch, 66 | regularizers=get_regularizers(opts)) 67 | elif opts.architecture in ['skip', 'skip0', 'skip1']: 68 | network = skip_networks.GraphSkipLayerNetwork( 69 | opts, 70 | arch, 71 | regularizers=get_regularizers(opts)) 72 | elif opts.architecture in ['longskip0', 'longskip1']: 73 | network = skip_networks.GraphLongSkipLayerNetwork( 74 | opts, 75 | arch, 76 | regularizers=get_regularizers(opts)) 77 | elif opts.architecture in ['normedskip0', 'normedskip1']: 78 | network = skip_networks.GraphLongSkipNormedNetwork( 79 | opts, 80 | arch, 81 | regularizers=get_regularizers(opts)) 82 | elif opts.architecture in ['normedskip2', 'normedskip3']: 83 | network = skip_networks.GraphSkipHopNormedNetwork( 84 | opts, 85 | arch, 86 | regularizers=get_regularizers(opts)) 87 | elif opts.architecture in ['attn0', 'attn1', 'attn2', \ 88 | 'spattn0', 'spattn1', 'spattn2']: 89 | network = networks.GraphAttentionLayerNetwork( 90 | opts, 91 | arch, 92 | regularizers=get_regularizers(opts)) 93 | return network 94 | 95 | if __name__ == "__main__": 96 | import data_util 97 | opts = 
options.get_opts() 98 | 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # All Graphs Lead To Rome 2 | Graph Convolutional Networks for multi-image matching. For more technical details, see 3 | the arXiv paper (https://arxiv.org/abs/1901.02078). 4 | 5 | # Dependencies 6 | All the code here is written in Python 3. You will need the following dependencies: 7 | * [TensorFlow GPU](https://www.tensorflow.org/install) 8 | * [Sonnet](https://github.com/deepmind/sonnet) 9 | * [NumPy](http://www.numpy.org/) 10 | * [Matplotlib](https://matplotlib.org/users/installing.html) 11 | * [SciPy](https://www.scipy.org/install.html) 12 | * [Scikit-learn](https://scikit-learn.org/stable/install.html) 13 | * [PyYaml](https://pyyaml.org/) 14 | * [Pickle](https://docs.python.org/3/library/pickle.html) 15 | * [Pillow](https://pillow.readthedocs.io/en/stable/) 16 | * [TQDM](https://github.com/tqdm/tqdm) 17 | * [Gzip](https://docs.python.org/3/library/gzip.html) 18 | * [argparse](https://docs.python.org/3/library/argparse.html) 19 | * [argcomplete](https://pypi.org/project/argcomplete/) 20 | 21 | # Basic Code Use 22 | 23 | ## Dataset Generation 24 | To generate the synthetic datasets, run the `data_util` module (the output is written to `data_dir`): 25 | ``` 26 | $ python3 -m data_util --dataset=noise_pairwise5 27 | ``` 28 | To generate the Rome16K datasets, you need to download and untar/unzip the [Rome16K dataset](http://www.cs.cornell.edu/projects/p2f/) into a directory of your choice, then specify that directory in the options as `rome16_dir`.
To do the initial generation, specify where you saved it with `top_dir`: 29 | ``` 30 | $ python3 -m data_util.rome16k --top_dir=/your/location/Rome16K 31 | $ python3 -m data_util --dataset=rome16kgeom0 32 | 33 | ``` 34 | The synthetic datasets take around 4-10 GB, whereas the Rome16K datasets take around 200 GB, so you will need plenty of disk space. If you don't have that much space, you can reduce it with some of the options in `data_util`, such as `max_num_tuples`. 35 | 36 | ## Training and Testing 37 | If the datasets are already generated, training can be done by running `train.py`. For example: 38 | ``` 39 | $ python3 train.py \ 40 | --save_dir=save/testing \ 41 | --dataset=noise_pairwise5view5 \ 42 | --architecture=longskip2 \ 43 | --loss_type=l1 \ 44 | --geometric_loss=2 \ 45 | --use_end_bias=false \ 46 | --use_abs_value=true \ 47 | --final_embedding_dim=80 \ 48 | --learning_rate=1e-4 \ 49 | --min_learning_rate=5e-7 \ 50 | --learning_rate_continuous=True \ 51 | --learning_rate_decay_type=exponential \ 52 | --learning_rate_decay_epochs=3e-1 \ 53 | --use_unsupervised_loss=true \ 54 | --optimizer_type=adam \ 55 | --load_data=true \ 56 | --batch_size=8 \ 57 | --train_time=55 \ 58 | --test_freq=5 \ 59 | --save_interval_secs=598 \ 60 | # End Args 61 | ``` 62 | You can find the list of options for datasets and models using `python3 train.py --help`. 63 | 64 | # Code Layout 65 | The code has 3 basic components: train/test, data utilities, and models. All of these depend on `options.py`, `myutils.py`, and `tfutils.py`, with `options.py` carrying around all the global parameters. 66 | 67 | The data utilities (`data_util`) handle the generation, saving, and loading of the datasets. The dataset classes are based on `GraphSimDataset` in `parent_dataset.py`. They rely heavily on tfrecords for fast loading at training/testing time.
The output of the dataset is generally a dictionary with all the necessary parts: in our case the graph Laplacian and initial embeddings, as well as auxiliary items such as the adjacency matrix of the graph and the ground-truth embeddings for test-time evaluation. Rome16K also includes geometric information between the views, which can be used during training. Unfortunately, sparsity is not exploited right now, so all the graph matrices are stored densely. Future work includes sparsifying all of these. 68 | 69 | The `model` folder takes in the output of the datasets and then puts it through a [Sonnet](https://github.com/deepmind/sonnet) module-based network. The modules all require as input the graph Laplacian and the initial node embeddings, the sizes of which should be known in advance. The exact nature of the modules can be safely abstracted, so this is fairly modular. 70 | 71 | Training and testing are fairly straightforward: once the dataset is generated and saved on disk and the model is chosen, you specify the training options (as shown in the example above) and run it. The example above is a typical run and can be used as a starting point. To run the baselines, you will need MATLAB; they are all in the `baselines` folder. 72 | 73 | 74 | # Questions 75 | If you have any questions, please ask me at stephi@seas.upenn.edu 76 | 77 | 78 | -------------------------------------------------------------------------------- /baselines/mmatch_CVX_ALS.m: -------------------------------------------------------------------------------- 1 | function [X_bin,A,info] = mmatch_CVX_ALS(W,dimGroup,varargin) 2 | % This function is to solve 3 | % min -<W,X> + alpha*||X||_* + beta*||X||_1, s.t. X \in C 4 | % See Equation (10) in the paper 5 | 6 | % ---- Output: 7 | % X_bin: the solution X itself (continuous; the 0.5-threshold binarization at the end is commented out) 8 | % A: AA^T = X; 9 | % info: other info.
10 | 11 | % ---- Required input: 12 | % W: sparse input matrix storing scores of pairwise matches 13 | % dimGroup: a vector storing the number of points on each object 14 | 15 | % ---- Other options: 16 | % maxRank: the restricted rank of X* (select it as large as possible) 17 | % alpha: the weight of nuclear norm 18 | % beta: the weight of l1 norm 19 | % pSelect: proportion of selected points, i.e., m'/m in section 5.4 in the paper 20 | % tol: tolerance of convergence 21 | % maxIter: maximal iteration 22 | % verbose: display info or not 23 | % eigenvalues: output eigenvalues or not 24 | 25 | % optional parameters 26 | alpha = 50; 27 | beta = 0.1; 28 | maxRank = max(dimGroup)*4; 29 | pSelect = 1; 30 | tol = 5e-4; 31 | maxIter = 200; 32 | verbose = true; 33 | eigenvalues = false; 34 | 35 | ivarargin = 1; 36 | while ivarargin <= length(varargin) 37 | switch lower(varargin{ivarargin}) 38 | case 'dimgroup' 39 | ivarargin = ivarargin+1; 40 | dimGroup = varargin{ivarargin}; 41 | case 'maxrank' 42 | ivarargin = ivarargin+1; 43 | maxRank = varargin{ivarargin}; 44 | case 'alpha' 45 | ivarargin = ivarargin+1; 46 | alpha = varargin{ivarargin}; 47 | case 'beta' 48 | ivarargin = ivarargin+1; 49 | beta = varargin{ivarargin}; 50 | case 'pselect' 51 | ivarargin = ivarargin+1; 52 | pSelect = varargin{ivarargin}; % numeric proportion: no lower() 53 | case 'tol' 54 | ivarargin = ivarargin+1; 55 | tol = varargin{ivarargin}; 56 | case 'maxiter' 57 | ivarargin = ivarargin+1; 58 | maxIter = varargin{ivarargin}; 59 | case 'verbose' 60 | ivarargin = ivarargin+1; 61 | verbose = varargin{ivarargin}; 62 | case 'eigenvalues' 63 | ivarargin = ivarargin+1; 64 | eigenvalues = varargin{ivarargin}; 65 | otherwise 66 | fprintf('Unknown option ''%s'' is ignored!\n',varargin{ivarargin}); 67 | end 68 | ivarargin = ivarargin+1; 69 | end 70 | 71 | % fprintf('Running MatchALS: alpha = %.2f, beta = %.2f, maxRank = %d, pSelect = %.2f \n',...
72 | % alpha,beta,maxRank,pSelect); 73 | 74 | W(1:size(W,1)+1:end) = 0; 75 | W = (W+W')/2; 76 | X = single(full(W)); 77 | Z = single(full(W)); 78 | Y = single(zeros(size(X))); 79 | mu = 64; 80 | 81 | n = size(X,1); 82 | maxRank = min(n,ceil(maxRank)); 83 | A = rand(n,maxRank); 84 | 85 | dimGroup = cumsum(dimGroup); 86 | dimGroup = [0;dimGroup(:)]; 87 | 88 | t0 = tic; 89 | for iter = 1:maxIter 90 | 91 | X0 = X; 92 | X = Z - (double(Y)-W+beta)/mu; 93 | B = ((A'*A+alpha/mu*eye(maxRank))\(A'*X))'; 94 | A = ((B'*B+alpha/mu*eye(maxRank))\(B'*X'))'; 95 | X = A*B'; 96 | 97 | Z = X + Y/mu; 98 | diagZ = diag(Z); 99 | % enforce the self-matching to be null 100 | for i = 1:length(dimGroup)-1 101 | ind1 = dimGroup(i)+1:dimGroup(i+1); 102 | Z(ind1,ind1) = zeros(length(ind1),length(ind1)); 103 | end 104 | % optimize for diagnal elements 105 | if pSelect == 1 106 | Z(1:size(Z,1)+1:end) = 1; 107 | else 108 | diagZ = proj2kav(diagZ,pSelect*length(diagZ)); 109 | Z(1:size(Z,1)+1:end) = diagZ; 110 | end 111 | % rounding all elements to [0,1] 112 | Z(Z<0) = 0; 113 | Z(Z>1) = 1; 114 | 115 | Y = Y + mu*(X-Z); 116 | 117 | pRes = norm(X(:)-Z(:))/n; 118 | dRes = mu*norm(X(:)-X0(:))/n; 119 | if verbose 120 | if ~mod(iter,100) 121 | % fprintf('Iter = %d, Res = (%d,%d), mu = %d \n',iter,pRes,dRes,mu); 122 | end 123 | end 124 | 125 | if pRes < tol && dRes < tol 126 | break 127 | end 128 | 129 | if pRes>10*dRes 130 | mu = 2*mu; 131 | elseif dRes>10*pRes 132 | mu = mu/2; 133 | else 134 | end 135 | 136 | end 137 | 138 | X = (X+X')/2; 139 | 140 | info.time = toc(t0); 141 | info.iter = iter; 142 | if eigenvalues 143 | info.eigenvalues = eig(X); 144 | end 145 | 146 | % X_bin = sparse(X>0.5); 147 | X_bin = X; 148 | 149 | %fprintf('Alg terminated. Time = %d, #Iter = %d, Res = (%d,%d), mu = %d \n',... 
150 | % info.time,info.iter,pRes,dRes,mu); 151 | -------------------------------------------------------------------------------- /tfutils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utility functions related to tensorflow 3 | """ 4 | import tensorflow as tf 5 | import math 6 | 7 | def matmul(x,y): 8 | """Multiplies batch x with a single matrix y 9 | Input: 10 | - x (tf.Tensor size BxNxM) - batch of matrices 11 | - y (tf.Tensor size MxP) - single matrix to multiply with 12 | Output: z (tf.Tensor size BxNxP) - satisfies z[i] = x[i] * y 13 | """ 14 | return tf.einsum('bik,kj->bij', x, y) 15 | 16 | def batch_matmul(x,y): 17 | """Multiplies batch x with with batch y 18 | Input: 19 | - x (tf.Tensor size BxNxM) - batch of matrices 20 | - y (tf.Tensor size BxMxP) - batch of matrices to multiply with 21 | Output: z (tf.Tensor size BxNxP) - satisfies z[i] = x[i] * y[i] 22 | """ 23 | return tf.einsum('bik,bkj->bij', x, y) 24 | 25 | def get_sim(x): 26 | """Get similarity matrices from batch of embeddings x 27 | Input: x (tf.Tensor size BxNxM) - batch of matrices, embeddings 28 | Output: y (tf.Tensor size BxNxN) - satisfies y[i] = x[i] * x[i].T 29 | """ 30 | x_T = tf.transpose(x, perm=[0, 2, 1]) 31 | return batch_matmul(x, x_T) 32 | 33 | def get_tf_activ(activ): 34 | """Get activation function based on activation string 35 | Input: activ (string) - string describing activation function (usually relu) 36 | Output: activ_fn (callable) - callable object that performs the activation 37 | """ 38 | if activ == 'relu': 39 | return tf.nn.relu 40 | elif activ == 'leakyrelu': 41 | return tf.nn.leaky_relu 42 | elif activ == 'tanh': 43 | return tf.nn.tanh 44 | elif activ == 'elu': 45 | return tf.nn.elu 46 | 47 | def create_linear_initializer(input_size, output_size, dtype=tf.float32): 48 | """Returns a default initializer for weights of a linear module 49 | Input: 50 | - input_size (int) - number of filters in the input matrix 
51 | - output_size (int) - number of filters in the output matrix 52 | - dtype (dtype, optional) - type of tensor output will be (default tf.float32) 53 | Output: initializer (callable) - function to initialize weights in network 54 | """ 55 | stddev = math.sqrt((1.3 * 2.0) / (input_size + output_size)) 56 | return tf.truncated_normal_initializer(stddev=stddev, dtype=dtype) 57 | 58 | def create_bias_initializer(unused_in, unused_out, dtype=tf.float32): 59 | """Returns a default initializer for the biases of a linear/AddBias module 60 | Input: 61 | - unused_in (int) - unused 62 | - unused_out (int) - unused 63 | - dtype (dtype, optional) - type of tensor output will be (default tf.float32) 64 | Output: initializer (callable) - function to initialize biases in network 65 | Bias initializer made to fit interface 66 | """ 67 | return tf.zeros_initializer(dtype=dtype) 68 | 69 | def bce_loss(labels, logits, add_loss=True): 70 | """Binary cross entropy loss funcion 71 | Inputs: 72 | - labels (tf.Tensor) - ground truth labels 73 | - logits (tf.Tensor) - output of the network in log space same size as labels 74 | - add_loss (boolean, optional) - add loss to tf.losses or not (default true) 75 | Output: bce (tf.Tensor) - scalar value of the mean BCE error 76 | """ 77 | bce_elements = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) 78 | bce_ = tf.reduce_mean(bce_elements) 79 | if add_loss: 80 | tf.losses.add_loss(bce_) 81 | return bce_ 82 | 83 | def l1_loss(x, y, add_loss=True): 84 | """L1 loss funcion 85 | Inputs: 86 | - x (tf.Tensor) - ground truth labels 87 | - y (tf.Tensor) - output of the network size as x 88 | - add_loss (boolean, optional) - add loss to tf.losses or not (default true) 89 | Output: l1 (tf.Tensor) - scalar value of the mean absolute error 90 | """ 91 | l1_ = tf.reduce_mean(tf.abs(x - y)) 92 | if add_loss: 93 | tf.losses.add_loss(l1_) 94 | return l1_ 95 | 96 | def l2_loss(x, y, add_loss=True): 97 | """L2 loss funcion 98 | Inputs: 99 | 
- x (tf.Tensor) - ground truth labels 100 | - y (tf.Tensor) - output of the network size as x 101 | - add_loss (boolean, optional) - add loss to tf.losses or not (default true) 102 | Output: l2 (tf.Tensor) - scalar value of the mean squared error 103 | """ 104 | l2_ = tf.reduce_mean(tf.square(x - y)) 105 | if add_loss: 106 | tf.losses.add_loss(l2_) 107 | return l2_ 108 | 109 | def l1_l2_loss(x, y, add_loss=True): 110 | """Addition of L1 and L2 loss funcions 111 | Inputs: 112 | - x (tf.Tensor) - ground truth labels 113 | - y (tf.Tensor) - output of the network size as x 114 | - add_loss (boolean, optional) - add loss to tf.losses or not (default true) 115 | Output: l1l2 (tf.Tensor) - scalar value of the loss 116 | """ 117 | l1_ = tf.reduce_mean(tf.abs(x - y)) 118 | l2_ = tf.reduce_mean(tf.square(x - y)) 119 | l1l2_ = l1_ + l2_ 120 | if add_loss: 121 | tf.losses.add_loss(l1l2_) 122 | return l1l2_ 123 | 124 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Run test functions and get list of all outputs 4 | """ 5 | import os 6 | import sys 7 | import glob 8 | import numpy as np 9 | import time 10 | 11 | import tensorflow as tf 12 | 13 | import data_util.datasets 14 | import model 15 | import myutils 16 | import tfutils 17 | import options 18 | 19 | loss_fns = { 20 | 'bce': tfutils.bce_loss, 21 | 'l1': tfutils.l1_loss, 22 | 'l2': tfutils.l2_loss, 23 | 'l1l2': tfutils.l1_l2_loss, 24 | } 25 | 26 | def get_test_losses(opts, sample, output, name='loss'): 27 | """Get testing loss funcion 28 | Input: 29 | - opts (options) - object with all relevant options stored 30 | - sample (dict) - sample from training dataset 31 | - output (tf.Tensor) - output from network on sample 32 | - name (string, optional) - name prefix for tensorflow scoping (default loss) 33 | Output: 34 | - gt_l1_loss (tf.Tensor) - L1 loss against 
ground truth 35 | - gt_l2_loss (tf.Tensor) - L2 loss against ground truth 36 | - gt_bce_loss (tf.Tensor) - BCE loss against ground truth 37 | - ssame_m (tf.Tensor) - Mean similarity of corresponding points 38 | - ssame_var (tf.Tensor) - Standard dev. of similarity of corresponding points 39 | - sdiff_m (tf.Tensor) - Mean similarity of non-corresponding points 40 | - sdiff_var (tf.Tensor) - Standard dev. of similarity of non-corresponding 41 | points 42 | """ 43 | emb = sample['TrueEmbedding'] 44 | output_sim = tfutils.get_sim(output) 45 | sim_true = tfutils.get_sim(emb) 46 | if opts.loss_type == 'bce': 47 | osim = tf.sigmoid(output_sim) 48 | osim_log = output_sim 49 | else: 50 | osim = output_sim 51 | osim_log = tf.log(tf.abs(output_sim) + 1e-9) 52 | gt_l1_loss = loss_fns['l1'](sim_true, osim, add_loss=False) 53 | gt_l2_loss = loss_fns['l2'](sim_true, osim, add_loss=False) 54 | gt_bce_loss = loss_fns['bce'](sim_true, osim, add_loss=False) 55 | num_same = tf.reduce_sum(sim_true) 56 | num_diff = tf.reduce_sum(1-sim_true) 57 | ssame_m, ssame_var = tf.nn.weighted_moments(osim, None, sim_true) 58 | sdiff_m, sdiff_var = tf.nn.weighted_moments(osim, None, 1-sim_true) 59 | 60 | return gt_l1_loss, gt_l2_loss, gt_bce_loss, ssame_m, ssame_var, sdiff_m, sdiff_var 61 | 62 | def build_test_session(opts): 63 | """Build tf.Session with relevant configuration for testing 64 | Input: opts (options) - object with all relevant options stored 65 | Output: session (tf.Session) 66 | """ 67 | config = tf.ConfigProto(device_count = {'GPU': 0}) 68 | # config.gpu_options.allow_growth = True 69 | return tf.Session(config=config) 70 | 71 | def test_values(opts): 72 | """Run testing on the network 73 | Input: opts (options) - object with all relevant options stored 74 | Output: None 75 | Saves all output in opts.save_dir, given by the user. It loads the saved 76 | configuration from the options.yaml file in opts.save_dir, so only the 77 | opts.save_dir needs to be specified. 
Will test and save out all test values 78 | in the test set into test_output.log in opts.save_dir 79 | """ 80 | # Get data and network 81 | dataset = data_util.datasets.get_dataset(opts) 82 | network = model.get_network(opts, opts.arch) 83 | # Sample 84 | sample = dataset.get_placeholders() 85 | print(sample) 86 | output = network(sample['Laplacian'], sample['InitEmbeddings']) 87 | losses = get_test_losses(opts, sample, output) 88 | 89 | # Tensorflow and logging operations 90 | disp_string = '{:06d} Errors: ' \ 91 | 'L1: {:.03e}, L2: {:.03e}, BCE: {:.03e}, ' \ 92 | 'Same sim: {:.03e} +/- {:.03e}, ' \ 93 | 'Diff sim: {:.03e} +/- {:.03e}, ' \ 94 | 'Time: {:.03e}, ' 95 | 96 | 97 | # Build session 98 | glob_str = os.path.join(opts.dataset_params.data_dir, 'np_test', '*npz') 99 | npz_files = sorted(glob.glob(glob_str)) 100 | vars_restore = [ v for v in tf.get_collection('weights') ] + \ 101 | [ v for v in tf.get_collection('biases') ] 102 | print(vars_restore) 103 | saver = tf.train.Saver(vars_restore) 104 | with open(os.path.join(opts.save_dir, 'test_output.log'), 'a') as log_file: 105 | with build_test_session(opts) as sess: 106 | saver.restore(sess, tf.train.latest_checkpoint(opts.save_dir)) 107 | for i, npz_file in enumerate(npz_files): 108 | sample_ = { k : np.expand_dims(v,0) for k, v in np.load(npz_file).items() } 109 | start_time = time.time() 110 | vals = sess.run(losses, { sample[k] : sample_[k] for k in sample.keys() }) 111 | end_time = time.time() 112 | dstr = disp_string.format(i, *vals, end_time - start_time) 113 | print(dstr) 114 | log_file.write(dstr) 115 | log_file.write('\n') 116 | 117 | if __name__ == "__main__": 118 | opts = options.get_opts() 119 | print("Getting options from run...") 120 | opts = options.parse_yaml_opts(opts) 121 | print("Done") 122 | test_values(opts) 123 | 124 | -------------------------------------------------------------------------------- /model/networks.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import os 4 | import sys 5 | import tensorflow as tf 6 | import sonnet as snt 7 | 8 | import tfutils 9 | import myutils 10 | import options 11 | 12 | from model import layers 13 | 14 | class GraphConvLayerNetwork(snt.AbstractModule): 15 | """Basic Graph Convolutional Net, no skip connections or group norms""" 16 | def __init__(self, 17 | opts, 18 | arch, 19 | use_bias=True, 20 | initializers=None, 21 | regularizers=None, 22 | custom_getter=None, 23 | name="graphnn"): 24 | """ 25 | Input: 26 | - opts (options) - object with all relevant options stored 27 | - arch (ArchParams) - object with all relevant Architecture options 28 | - use_bias (boolean, optional) - have biases in the network (default True) 29 | - intializers (dict, optional) - specify custom initializers 30 | - regularizers (dict, optional) - specify custom regularizers 31 | - custom_getter (dict, optional) - specify custom getters 32 | - name (string, optional) - name for module for scoping (default graphnn) 33 | """ 34 | super(GraphConvLayerNetwork, self).__init__(custom_getter=custom_getter, name=name) 35 | self._nlayers = len(arch.layer_lens) 36 | self._layers = [ 37 | layers.GraphConvLayer( 38 | output_size=layer_len, 39 | activation=arch.activ, 40 | initializers=initializers, 41 | regularizers=regularizers, 42 | name="{}/graph_conv".format(name)) 43 | for layer_len in arch.layer_lens 44 | ] + [ 45 | layers.EmbeddingLinearLayer( 46 | output_size=opts.final_embedding_dim, 47 | initializers=initializers, 48 | regularizers=regularizers, 49 | name="{}/embed_lin".format(name)) 50 | ] 51 | self.normalize_emb = arch.normalize_emb 52 | 53 | def _build(self, laplacian, init_embeddings): 54 | """Applying this graph network to sample 55 | Inputs: 56 | - laplacian (tf.Tensor) - laplacian for the input graph 57 | - init_embeddings (tf.Tensor) - Initial node embeddings of the 
graph 58 | Outputs: output (tf.Tensor) - the output of the network 59 | """ 60 | output = init_embeddings 61 | for layer in self._layers: 62 | output = layer(laplacian, output) 63 | if self.normalize_emb: 64 | output = tf.nn.l2_normalize(output, axis=2) 65 | return output 66 | 67 | class GraphAttentionLayerNetwork(snt.AbstractModule): 68 | """Graph Attention Net, derived from https://arxiv.org/abs/1710.10903""" 69 | def __init__(self, 70 | opts, 71 | arch, 72 | use_bias=True, 73 | initializers=None, 74 | regularizers=None, 75 | custom_getter=None, 76 | name="graphnn"): 77 | super(GraphAttentionLayerNetwork, self).__init__(custom_getter=custom_getter, 78 | name=name) 79 | """ 80 | Input: 81 | - opts (options) - object with all relevant options stored 82 | - arch (ArchParams) - object with all relevant Architecture options 83 | - use_bias (boolean, optional) - have biases in the network (default True) 84 | - intializers (dict, optional) - specify custom initializers 85 | - regularizers (dict, optional) - specify custom regularizers 86 | - custom_getter (dict, optional) - specify custom getters 87 | - name (string, optional) - name for module for scoping (default graphnn) 88 | """ 89 | self._nlayers = len(arch.layer_lens) 90 | final_regularizers = None 91 | if regularizers is not None: 92 | final_regularizers = { k:v 93 | for k, v in regularizers.items() 94 | if k in ["w", "b"] } 95 | self._layers = [ 96 | layers.GraphAttentionLayer( 97 | output_size=layer_len, 98 | activation=arch.activ, 99 | sparse=arch.sparse, 100 | initializers=initializers, 101 | regularizers=regularizers, 102 | name="{}/graph_attn".format(name)) 103 | for layer_len in arch.layer_lens 104 | ] + [ 105 | layers.EmbeddingLinearLayer( 106 | output_size=opts.final_embedding_dim, 107 | initializers=initializers, 108 | regularizers=final_regularizers, 109 | name="{}/embed_lin".format(name)) 110 | ] 111 | self.normalize_emb = arch.normalize_emb 112 | 113 | def _build(self, laplacian, init_embeddings): 
114 | """Applying this graph network to sample 115 | Inputs: 116 | - laplacian (tf.Tensor) - laplacian for the input graph 117 | - init_embeddings (tf.Tensor) - Initial node embeddings of the graph 118 | Outputs: output (tf.Tensor) - the output of the network 119 | """ 120 | output = init_embeddings 121 | for layer in self._layers: 122 | output = layer(laplacian, output) 123 | if self.normalize_emb: 124 | output = tf.nn.l2_normalize(output, axis=2) 125 | return output 126 | 127 | 128 | -------------------------------------------------------------------------------- /baselines/roc_curves.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import sys 4 | import glob 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import time 8 | import sklearn.metrics as metrics 9 | import tqdm 10 | import itertools 11 | 12 | def scatterplot_matrix(data, names, **kwargs): 13 | """Plots a scatterplot matrix of subplots. Each row of "data" is plotted 14 | against other rows, resulting in a nrows by nrows grid of subplots with the 15 | diagonal subplots labeled with "names". Additional keyword arguments are 16 | passed on to matplotlib's "plot" command. Returns the matplotlib figure 17 | object containg the subplot grid.""" 18 | numvars, numdata = data.shape 19 | fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8)) 20 | fig.subplots_adjust(hspace=0.05, wspace=0.05) 21 | 22 | for ax in axes.flat: 23 | # Hide all ticks and labels 24 | ax.xaxis.set_visible(False) 25 | ax.yaxis.set_visible(False) 26 | 27 | # Set up ticks only on one side for the "edge" subplots... 28 | if ax.is_first_col(): 29 | ax.yaxis.set_ticks_position('left') 30 | if ax.is_last_col(): 31 | ax.yaxis.set_ticks_position('right') 32 | if ax.is_first_row(): 33 | ax.xaxis.set_ticks_position('top') 34 | if ax.is_last_row(): 35 | ax.xaxis.set_ticks_position('bottom') 36 | 37 | # Plot the data. 
38 | for i, j in zip(*np.triu_indices_from(axes, k=1)): 39 | for x, y in [(i,j), (j,i)]: 40 | axes[x,y].scatter(data[x], data[y], **kwargs) 41 | 42 | # Label the diagonal subplots... 43 | for i, label in enumerate(names): 44 | axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction', 45 | ha='center', va='center') 46 | 47 | # Turn on the proper x or y axes ticks. 48 | for i, j in zip(range(numvars), itertools.cycle((-1, 0))): 49 | axes[j,i].xaxis.set_visible(True) 50 | axes[i,j].yaxis.set_visible(True) 51 | 52 | return fig 53 | 54 | def main(verbose): 55 | # Constants 56 | N_ = 2048 57 | # MATLAB Output Files 58 | opt_names = [ x[:-len('TestErrors.log')] for x in sorted(glob.glob('*.log')) ] 59 | all_names = opt_names + ['GCN', 'AlmostPerfect', 'Random' ] 60 | # all_names = [ 'MatchALS015Iter', 'MatchALS025Iter', 'MatchALS050Iter', 'MatchALS100Iter' ] 61 | num_outputs = len(os.listdir('{}Outputs'.format(opt_names[0]))) 62 | # Tensorflow output file 63 | fname = 'GCN12Layer.npz' 64 | with open(fname, 'rb') as f: 65 | ld = dict(np.load(fname)) 66 | temb = ld['trueemb'] 67 | outemb = ld['out'] 68 | assert len(temb) == num_outputs 69 | os.makedirs('ROC-Curves', exist_ok=True) 70 | os.makedirs('P-R-Curves', exist_ok=True) 71 | 72 | roc_ = { k : [] for k in all_names } 73 | p_r_ = { k : [] for k in all_names } 74 | for i in tqdm.tqdm(range(num_outputs), disable=(verbose != 1)): 75 | adjmat = np.dot(temb[i], temb[i].T).reshape(-1) # They are all the same 76 | # MATLAB Outputs 77 | fig_roc, ax_roc = plt.subplots() 78 | fig_p_r, ax_p_r = plt.subplots() 79 | ax_roc.scatter([0,0,1,1],[0,1,0,1]) 80 | ax_p_r.scatter([0,0,1,1],[0,1,0,1]) 81 | for k in all_names: 82 | # Compute things 83 | if k == 'GCN': 84 | output = np.abs(np.dot(outemb[i], outemb[i].T)).reshape(-1) 85 | elif k == 'AlmostPerfect': 86 | output = np.abs(adjmat + np.random.randn(*adjmat.shape)*0.05) 87 | elif k == 'Random': 88 | output = np.abs(np.random.randn(*adjmat.shape)*0.25) 89 | else: 90 | fname = 
'{}Outputs/{:04d}.npy'.format(k, i+1) 91 | # print(fname) 92 | with open(fname, 'rb') as f: 93 | o = np.load(f) 94 | output = o.reshape(-1) 95 | # Get areas 96 | roc_[k].append(metrics.roc_auc_score(adjmat, output)) 97 | p_r_[k].append(metrics.average_precision_score(adjmat, output)) 98 | if verbose > 1: 99 | print('{0:04d} {1:<20}: ROC: {2:.03e}, P-R: {3:.03e}'.format(i, k, 100 | roc_[k][-1], 101 | p_r_[k][-1])) 102 | # PLot lines 103 | FPR, TPR, _ = metrics.roc_curve(adjmat, output) 104 | precision, recall, _ = metrics.precision_recall_curve(adjmat, output) 105 | ax_roc.plot(FPR, TPR, label='{} ROC ({:.03e})'.format(k, roc_[k][-1])) 106 | ax_p_r.plot(precision, recall, label='{} P-R ({:.03e})'.format(k, p_r_[k][-1])) 107 | 108 | # Finish plots 109 | ax_roc.set_xlabel('False Positive Rate') 110 | ax_p_r.set_xlabel('Precision') 111 | ax_roc.set_ylabel('True Positive Rate') 112 | ax_p_r.set_ylabel('Recall') 113 | ax_roc.set_title('ROC Curves') 114 | ax_p_r.set_title('Precision Recall Curves') 115 | ax_roc.legend() 116 | ax_p_r.legend() 117 | fig_roc.savefig('ROC-Curves/{:04d}.png'.format(i)) 118 | fig_p_r.savefig('P-R-Curves/{:04d}.png'.format(i)) 119 | plt.close(fig_roc) 120 | plt.close(fig_p_r) 121 | dispstr = '{:<15}: ROC: {:.03e} +/- {:.03e} ; P-R: {:.03e} +/- {:.03e}' 122 | for k in all_names: 123 | roc_mean = np.mean(roc_[k]) 124 | roc_std = np.std(roc_[k]) 125 | p_r_mean = np.mean(p_r_[k]) 126 | p_r_std = np.std(p_r_[k]) 127 | print(dispstr.format(k, roc_mean, roc_std, p_r_mean, p_r_std)) 128 | 129 | # plot_names = [ 'MatchALS015Iter', 'PGDDS015Iter', 'Spectral', 'GCN' ] 130 | plot_names = all_names 131 | plot_vars = np.stack([ p_r_[k] for k in plot_names ], axis=0) 132 | # # scatter_fig = scatterplot_matrix(plot_vars, plot_names) 133 | # plt.scatter(p_r_['MatchALS100Iter'], p_r_['GCN']) 134 | # plt.scatter([0,0,1,1], [0,1,0,1]) 135 | # plt.plot([0,1], [0,1]) 136 | # plt.show() 137 | # print(np.corrcoef(plot_vars)) 138 | better_than = 
np.zeros((len(plot_names), len(plot_names))) 139 | for i in range(len(plot_names)): 140 | for j in range(len(plot_names)): 141 | better_than[i,j] = np.sum(np.array(p_r_[plot_names[i]]) < np.array(p_r_[plot_names[j]])) 142 | print(better_than) 143 | 144 | if __name__ == '__main__': 145 | main(1) 146 | 147 | -------------------------------------------------------------------------------- /data_util/rome16k/__main__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate the Rome16K dataset and the n-tuples associating frames. Need to specify 3 | a top_directory where the Rome16K files have been unzipped 4 | (http://s3.amazonaws.com/LocationRecognition/Datasets/Rome16K.tar.gz). 5 | """ 6 | import numpy as np 7 | import os 8 | import sys 9 | import argparse 10 | import pickle 11 | import tqdm 12 | import time 13 | import itertools as it 14 | 15 | from data_util.rome16k import scenes 16 | from data_util.rome16k import parse 17 | import myutils 18 | 19 | def get_build_scene_opts(): 20 | """Parse arguments from command line and get all options for training.""" 21 | parser = argparse.ArgumentParser(description='Train motion estimator') 22 | parser.add_argument('--build_tuples', 23 | type=myutils.str2bool, 24 | default=True, 25 | help='Name of pickle file to load - None if no loading') 26 | parser.add_argument('--save_imsizes', 27 | type=myutils.str2bool, 28 | default=False, 29 | help='Save files of images, which requires internet connectivity') 30 | parser.add_argument('--overwrite_tuples', 31 | type=myutils.str2bool, 32 | default=False, 33 | help='If tuple file exists, overwrite it') 34 | parser.add_argument('--top_dir', 35 | default='/NAS/data/stephen/Rome16K', 36 | help='Storage location for pickle files') 37 | parser.add_argument('--save', 38 | choices=parse.bundle_files + [ 'all' ], 39 | default='all', 40 | help='Save out bundle file to pickle file') 41 | parser.add_argument('--min_points', 42 | default=80, 43 | 
type=int, 44 | help='Minimum overlap of points for connection') 45 | parser.add_argument('--max_points', 46 | default=150, 47 | type=int, 48 | help='Maximum overlap of points for connection') 49 | parser.add_argument('--max_tuple_size', 50 | default=4, 51 | type=int, 52 | help='Maximum tuple size') 53 | parser.add_argument('--max_num_tuples', 54 | type=int, 55 | default=-1, 56 | help='Maximum number of tuples you want to generate') 57 | parser.add_argument('--verbose', 58 | default=1, 59 | type=int, 60 | help='Print out everything') 61 | 62 | opts = parser.parse_args() 63 | return opts 64 | 65 | def factorial(n, stop=0): 66 | o = 1 67 | while n > stop: 68 | o *= n 69 | n -= 1 70 | return o 71 | 72 | def choose(n, k): 73 | return factorial(n, stop=k) // factorial(n - k) 74 | 75 | def silent(x): 76 | pass 77 | 78 | def process_scene_bundle(opts, bundle_file, scene_fname, tuples_fname): 79 | if opts.verbose > 0: 80 | myprint = lambda x: print(x) 81 | else: 82 | myprint = lambda x: silent(x) 83 | ######### Build and save out scene file ########### 84 | scene = parse.parse_bundle(bundle_file, 85 | opts.top_dir, 86 | get_imsize=opts.save_imsizes, 87 | verbose=opts.verbose > 1) 88 | parse.save_scene(scene, scene_fname, opts.verbose > 0) 89 | 90 | if not opts.build_tuples: 91 | return 92 | 93 | ######### Build and save out k-tuples ########### 94 | n = len(scene.cams) 95 | cam_pts = lambda i: set([ f.point for f in scene.cams[i].features ]) 96 | tuples_full = [] 97 | tuples_sizes = [] 98 | # Length 2 is a special case 99 | myprint("Building pairs...") 100 | start_time = time.time() 101 | pairs, tsizes = [], [] 102 | for x in tqdm.tqdm(it.combinations(range(n),2), total=choose(n,2), disable=opts.verbose < 1): 103 | p = len(cam_pts(x[0]) & cam_pts(x[1])) 104 | if p >= opts.min_points: 105 | pairs.append(x) 106 | tsizes.append(p) 107 | if opts.max_num_tuples > 0 and len(pairs) > opts.max_num_tuples: 108 | break 109 | 110 | tuples_full.append(pairs) 111 | 
tuples_sizes.append(tsizes) 112 | end_time = time.time() 113 | myprint("Done with pairs ({} sec)".format(end_time-start_time)) 114 | # Length 3 and above 115 | for k in range(3,opts.max_tuple_size+1): 116 | myprint("Selecting {}-tuples...".format(k)) 117 | start_time = time.time() 118 | tlist, tsizes = [], [] 119 | tvals = tuples_full[-1] 120 | for (i, x) in tqdm.tqdm(enumerate(tvals), total=len(tvals), disable=opts.verbose < 1): 121 | xpts = cam_pts(x[0]) 122 | for xx in x[1:]: 123 | xpts = xpts & cam_pts(xx) 124 | for j in range(x[-1]+1,n): 125 | p = len(cam_pts(j) & xpts) 126 | if p >= opts.min_points: 127 | tlist.append(x + (j,)) 128 | tsizes.append(p) 129 | if opts.max_num_tuples > 0 and len(tlist) > opts.max_num_tuples: 130 | break 131 | if opts.max_num_tuples > 0 and len(tlist) > opts.max_num_tuples: 132 | break 133 | tuples_full.append(tlist) 134 | tuples_sizes.append(tsizes) 135 | end_time = time.time() 136 | myprint("Done with {}-tuples ({} sec)".format(k, end_time-start_time)) 137 | 138 | tuples = [ [ x for i, x in enumerate(tups) if tsizes[i] <= opts.max_points ] 139 | for tups, tsizes in zip(tuples_full, tuples_sizes) ] 140 | 141 | myprint("Saving tuples...") 142 | with open(tuples_fname,'wb') as f: 143 | pickle.dump(tuples, f, protocol=pickle.HIGHEST_PROTOCOL) 144 | # with open(triplets_name(bundle_file, lite=True),'wb') as f: 145 | # pickle.dump(triplets[:100].tolist(), f, protocol=pickle.HIGHEST_PROTOCOL) 146 | myprint("Done") 147 | 148 | 149 | 150 | opts = get_build_scene_opts() 151 | if opts.save == 'all': 152 | N = len(parse.bundle_files) 153 | for i, bundle_file in enumerate(parse.bundle_files): 154 | scene_fname=os.path.join(opts.top_dir,'scenes',parse.scene_fname(bundle_file)) 155 | tuples_fname=os.path.join(opts.top_dir,'scenes',parse.tuples_fname(bundle_file)) 156 | if opts.verbose > 0: 157 | print('Computing {} ({} of {})...'.format(bundle_file,i+1,N)) 158 | if not opts.overwrite_tuples and os.path.exists(tuples_fname): 159 | if 
opts.verbose > 0: 160 | print('Already computed tuples, skipping...') 161 | continue 162 | start_time = time.time() 163 | process_scene_bundle(opts, bundle_file, scene_fname, tuples_fname) 164 | end_time = time.time() 165 | if opts.verbose > 0: 166 | print('Finished {} ({:0.3f} sec)'.format(bundle_file,end_time-start_time)) 167 | else: 168 | scene_fname=os.path.join(opts.top_dir,'scenes',parse.scene_fname(opts.save)) 169 | tuples_fname=os.path.join(opts.top_dir,'scenes',parse.tuples_fname(opts.save)) 170 | process_scene_bundle(opts, opts.save, scene_fname, tuples_fname) 171 | 172 | -------------------------------------------------------------------------------- /baselines/run_tests.m: -------------------------------------------------------------------------------- 1 | function run_tests(npz_files, views, save_out) 2 | 3 | if nargin < 2 4 | views = 3; 5 | end 6 | if nargin < 3 7 | save_out = false; 8 | end 9 | 10 | npymatlab_path = 'npy-matlab/npy-matlab'; 11 | pathCell = regexp(path, pathsep, 'split'); 12 | onPath = any(strcmp(npymatlab_path, pathCell)); 13 | if ~onPath 14 | addpath(npymatlab_path); 15 | end 16 | [~,~,~] = mkdir('/tmp/unzip'); 17 | if save_out 18 | [~,~,~] = mkdir('Adjmats'); 19 | end 20 | 21 | v = views; 22 | p = 80; 23 | n = p*v; 24 | dimGroups = ones(v,1)*p; 25 | params015.maxiter = 15; 26 | params025.maxiter = 25; 27 | params050.maxiter = 50; 28 | params100.maxiter = 100; 29 | params200.maxiter = 200; 30 | 31 | metric_info = { ... 32 | { 'l1', 'L1: %.03e, ', @mean }, ... 33 | { 'l2', 'L2: %.03e, ', @mean }, ... 34 | { 'bce', 'BCE: %.03e, ' , @mean }, ... 35 | { 'ssame_m', 'Same sim: %.03e ' , @mean }, ... 36 | { 'ssame_s', '+/- %.03e, ' , @(x) sqrt(mean(x.^2)) }, ... 37 | { 'sdiff_m', 'Diff sim: %.03e ' , @mean }, ... 38 | { 'sdiff_s', '+/- %.03e, ' , @(x) sqrt(mean(x.^2)) }, ... 39 | { 'roc', 'Area under ROC: %.03e, ' , @mean }, ... 40 | { 'pr', 'Area under P-R: %.03e, ' , @mean }, ... 
41 | }; 42 | metrics = cell(length(metric_info),1); 43 | for i = 1:length(metric_info) 44 | metrics{i} = zeros(length(npz_files),1); 45 | end 46 | % 'MatchALS400Iter', @(W) mmatch_CVX_ALS(W, dimGroups, 'maxrank', p, 'maxiter', 400); ... 47 | % 'PGDDS200Iter', @(W) PGDDS(W, dimGroups, p, params200); ... 48 | % 'PGDDS100Iter', @(W) PGDDS(W, dimGroups, p, params100); ... 49 | test_fns = { ... 50 | 'Spectral', @(W) myspectral(W, p); ... 51 | 'MatchALS015Iter', @(W) mmatch_CVX_ALS(W, dimGroups, 'maxrank', p, 'maxiter', 15); ... 52 | 'MatchALS025Iter', @(W) mmatch_CVX_ALS(W, dimGroups, 'maxrank', p, 'maxiter', 25); ... 53 | 'MatchALS050Iter', @(W) mmatch_CVX_ALS(W, dimGroups, 'maxrank', p, 'maxiter', 50); ... 54 | 'MatchALS100Iter', @(W) mmatch_CVX_ALS(W, dimGroups, 'maxrank', p, 'maxiter', 100); ... 55 | 'PGDDS015Iter', @(W) PGDDS(W, dimGroups, p, params015); ... 56 | 'PGDDS025Iter', @(W) PGDDS(W, dimGroups, p, params025); ... 57 | 'PGDDS050Iter', @(W) PGDDS(W, dimGroups, p, params050); ... 58 | }; 59 | 60 | saveout_str = '%sOutputs/%04d.npy'; 61 | for test_fn_index = 1:size(test_fns,1) 62 | test_fn_tic = tic; 63 | test_fn = test_fns{test_fn_index,2}; 64 | fid = fopen(sprintf('%sTestErrors.log', test_fns{test_fn_index,1}), 'w'); 65 | if save_out 66 | [~,~,~] = mkdir(sprintf('%sOutputs', test_fns{test_fn_index,1})) 67 | end 68 | fprintf('%s Method:\n', test_fns{test_fn_index,1}) 69 | test_index = 0; 70 | for npz_index = 1:length(npz_files) 71 | fprintf('Matrix %03d of %03d\r', npz_index, length(npz_files)) 72 | [ W, Agt ] = load_npz(npz_files{npz_index}); 73 | tic; 74 | A_output = test_fn(W); 75 | Ah = max(0,min(1,A_output)); 76 | run_time = toc; 77 | values = evaluate_tests(Ah, Agt); 78 | for metric_idx = 1:length(metrics) 79 | metrics{metric_idx}(npz_index) = values(metric_idx); 80 | end 81 | disp_values(metric_info, fid, npz_index, values, run_time); 82 | test_index = test_index + 1; 83 | if save_out 84 | output_name = sprintf(saveout_str, test_fns{test_fn_index,1}, 
npz_index); 85 | writeNPY(single(Ah), output_name); 86 | adjmat_name = sprintf('Adjmats/%04d.npy', npz_index); 87 | if ~exist(adjmat_name) 88 | writeNPY(Agt, adjmat_name) 89 | end 90 | end 91 | end 92 | fprintf('\n') 93 | fclose(fid); 94 | means = zeros(length(metrics),1); 95 | for metric_idx = 1:length(metrics) 96 | means(metric_idx) = metric_info{metric_idx}{3}(metrics{metric_idx}); 97 | end 98 | disp_values(metric_info, 1, test_fn_index, means, run_time); 99 | fprintf(1, 'Total time: %.03f seconds\n', toc(test_fn_tic)); 100 | end 101 | 102 | disp('Finished'); 103 | 104 | end 105 | 106 | function disp_values(metric_info, fid, idx, values, time) 107 | fprintf(fid, '%06d Errors: ', idx); 108 | for i = 1:length(values) 109 | fprintf(fid, metric_info{i}{2}, values(i)); 110 | end 111 | fprintf(fid, 'Time: %.03e\n', time); 112 | end 113 | 114 | function [ means ] = get_metric_means(metrics) 115 | end 116 | 117 | function [ values ] = evaluate_tests(Ah, Agt) 118 | [l1, l2, bce] = testOutput_soft(Ah,Agt); 119 | [ssame, ssame_std, sdiff, sdiff_std] = testOutputhist(Ah,Agt); 120 | % [roc, pr] = testOutput_roc_pr(Ah,Agt); 121 | pr = 0; 122 | roc = 0; 123 | values = [ l1, l2, bce, ssame, ssame_std, sdiff, sdiff_std, roc, pr ]; 124 | end 125 | 126 | 127 | function [ssame, ssame_std, sdiff, sdiff_std] = testOutputhist(Ah,Agt) 128 | 129 | N = sum(sum(Agt)); 130 | M = sum(sum(1-Agt)); 131 | ssame = sum(sum(Ah.*Agt)) / N; 132 | ssame_std = sqrt(sum(sum((Ah.*Agt).^2)) / N - ssame^2); 133 | sdiff = sum(sum(Ah.*(1-Agt))) / M; 134 | sdiff_std = sqrt(sum(sum((Ah.*(1-Agt)).^2)) / M - sdiff^2); 135 | 136 | end 137 | 138 | function [l1, l2, bce] = testOutput_soft(Ah,Agt) 139 | 140 | l1 = mean2(abs(Ah-Agt)); 141 | l2 = mean2((Ah-Agt).^2); 142 | bce = -mean2(Agt.*log2(eps+Ah) + (1-Agt).*log2(eps+1-Ah)); 143 | 144 | end 145 | 146 | function [roc, pr] = testOutput_roc_pr(Ah,Agt) 147 | 148 | [TP, TN, FP, FN] = compute_thresh_errs(Ah, Agt); 149 | m = length(TP); 150 | FPR = (FP ./ max(1e-8, FP 
+ TN)); 151 | TPR = (TP ./ max(1e-8, TP + FN)); 152 | precision = (TP ./ max(1e-8, TP + FP)); 153 | recall = (TP ./ max(1e-8, TP + FN)); 154 | % disp(size(FPR)) 155 | % disp(size(TPR)) 156 | % disp(class(FPR)) 157 | % disp(class(TPR)) 158 | roc = abs(trapz(FPR, TPR)); 159 | pr = abs(trapz(precision, recall)); 160 | 161 | end 162 | 163 | function [TP, TN, FP, FN] = compute_thresh_errs(output, adjmat, N_cutoffs) 164 | if nargin < 3 165 | N_cutoffs = 2048; 166 | end 167 | a = int32(adjmat); 168 | M_T = sum(a); 169 | M_F = numel(a) - M_T; 170 | 171 | TP = zeros(N_cutoffs, 1); 172 | TN = zeros(N_cutoffs, 1); 173 | FP = zeros(N_cutoffs, 1); 174 | FN = zeros(N_cutoffs, 1); 175 | for idx = 1:N_cutoffs 176 | i = N_cutoffs - idx; 177 | thresh = (1.0*i) / (N_cutoffs-1); 178 | o = int32(output > thresh); 179 | [ TP_, TN_, FP_, FN_ ] = calc_classifications(o,a); 180 | TP(idx) = double(TP_); 181 | TN(idx) = double(TN_); 182 | FP(idx) = double(FP_); 183 | FN(idx) = double(FN_); 184 | end 185 | end 186 | 187 | function [ TP, TN, FP, FN ] = calc_classifications(o, a) 188 | TP = sum(sum(o.*a)); 189 | TN = sum(sum((1-o).*(1-a))); 190 | FP = sum(sum(o.*(1-a))); 191 | FN = sum(sum((1-o).*a)); 192 | end 193 | 194 | function [W, Agt] = load_npz(npz_file) 195 | unzip(npz_file, '/tmp/unzip'); 196 | AdjMat = readNPY('/tmp/unzip/AdjMat.npy'); 197 | TrueEmbedding = readNPY('/tmp/unzip/TrueEmbedding.npy'); 198 | W = squeeze(double(AdjMat)) + eye(size(AdjMat)); 199 | Xgt = squeeze(double(TrueEmbedding)); 200 | Agt = Xgt*Xgt'; 201 | end 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /baselines/PGDDS.m: -------------------------------------------------------------------------------- 1 | function [XfXft, Xf,run_time] = PGDDS(X_in,dimGroup,K,param) 2 | % Projected Gradient Descent Doubly Stochastic PGSS 3 | 4 | t0 = tic; 5 | % display(sprintf('Distributed gradient descent begins.....')) 6 | 7 | if nargin <3 8 | error('Not enough input 
arguments') 9 | end 10 | 11 | param.nObj = numel(dimGroup); 12 | param.K = K; 13 | param.dimGroup = dimGroup; 14 | 15 | 16 | 17 | if ~isfield(param,'Adj') 18 | param.Adj = ones(param.nObj); 19 | param.Adj(1:size(param.Adj,1)+1:end) = 0; 20 | end 21 | 22 | if ~all(all( param.Adj'== param.Adj)) 23 | error('Adjacency matrix not symmetric'); 24 | end 25 | 26 | if ~isfield(param,'flagDebug') 27 | param.flagDebug=0; 28 | end 29 | 30 | 31 | if isfield(param,'maxiter') 32 | maxIter = param.maxiter; 33 | else 34 | maxIter = 200; 35 | end 36 | 37 | 38 | if any(any( triu(1-param.Adj,1))) 39 | param.flagfullGraph = 0; 40 | else 41 | param.flagfullGraph = 1; 42 | end 43 | 44 | if ~isfield(param,'t') 45 | param.t= [1]; 46 | end 47 | 48 | param.Adj = param.Adj>0; 49 | 50 | %% 51 | %param.Akron = kron(param.Adj+speye(size(param.Adj)),speye(param.K)); % TODO: remove this one 52 | param.csdimGroup = [0;cumsum(param.dimGroup(:))]; % is this necessary? 53 | 54 | % random initialization 55 | X0 = rand(param.csdimGroup(end),param.K); 56 | X0 = X0 ./ sum(X0,2); 57 | 58 | idx_diag = 1:size(X_in,1)+1:size(X_in,1)*size(X_in,2); 59 | X_in(idx_diag) = 0; 60 | 61 | idx_diag = 1:size(param.Adj,1)+1:size(param.Adj,1)*size(param.Adj,2); 62 | param.Adj(idx_diag) = 0; 63 | dmax = max( sum(param.Adj)); 64 | %% MAIN LOOP 65 | for il = 1:numel(param.t) 66 | 67 | X = X0; 68 | gamma_t = param.t(il)/(1+param.t(il)); 69 | 70 | maxstep = 2/( 2*gamma_t + 3*(1-gamma_t)*dmax ); 71 | step = 0.99*maxstep; 72 | 73 | % max step size computation 74 | 75 | for iIter=1:maxIter 76 | 77 | % gradient computation 78 | gij = - X_in*X; 79 | 80 | XtXe2 = zeros( [param.K,param.K,param.nObj]); 81 | XXXtd = zeros(size(X)); 82 | 83 | for i=1:param.nObj 84 | idxr = param.csdimGroup(i)+1: param.csdimGroup(i+1); 85 | XtXe2(:,:,i) = X(idxr,:)'*X(idxr,:); 86 | XXXtd(idxr,:) = X(idxr,:)*XtXe2(:,:,i); 87 | end 88 | gi = XXXtd-X; 89 | 90 | XtXesum = zeros([param.K,param.K,param.nObj]); 91 | 92 | for i=1:param.nObj 93 | idxr = 
param.csdimGroup(i)+1: param.csdimGroup(i+1); 94 | XtXesum(:,:,i) = sum(XtXe2(:,:,param.Adj(i,:)),3) + XtXe2(:,:,i); 95 | gij(idxr,:) = gij(idxr,:) + X(idxr,:)*XtXesum(:,:,i); 96 | end 97 | 98 | G = gamma_t*gi + (1- gamma_t)*gij; % Euclidean gradient wrt to x parametrization 99 | 100 | 101 | 102 | % update step 103 | X = X - step*G; 104 | 105 | 106 | 107 | % projection step 108 | for iview =1:numel(param.dimGroup) 109 | idx = (param.csdimGroup(iview)+1):param.csdimGroup(iview+1); 110 | 111 | if size(X(idx,:),1) == size(X(idx,:),2) 112 | X(idx,:) = projectDADMM(X(idx,:)); 113 | else 114 | X(idx,:) = projectPDADMM(X(idx,:)); 115 | end 116 | end 117 | 118 | 119 | 120 | end 121 | 122 | 123 | 124 | 125 | 126 | % perturb a little bit to avoid undersirable stationary points 127 | for iview =1:numel(param.dimGroup) 128 | idx = (param.csdimGroup(iview)+1):param.csdimGroup(iview+1); 129 | X0(idx,:) = projectPDADMM(X(idx,:) +0.01*randn(size(X(idx,:)))); 130 | end 131 | 132 | end 133 | 134 | 135 | 136 | %% truncation to partial permutation matrix 137 | Xf = X; 138 | 139 | % for i=1:param.nObj 140 | % idxr = param.csdimGroup(i)+1: param.csdimGroup(i+1); 141 | % [ass,~] = munkres(-Xf(idxr,:)); 142 | % idx=1:param.dimGroup(i); 143 | % Xf(idxr,:) = sparse(idx(ass >0),ass(ass >0),1,param.dimGroup(i),param.K); 144 | % end 145 | 146 | 147 | XfXft = Xf*Xf'; 148 | 149 | run_time = toc(t0); 150 | 151 | % display(sprintf('Distributed gradient descent terminated in %0.2f seconds...',run_time)) 152 | end 153 | 154 | function X= projectDADMM(X0) 155 | 156 | 157 | if size(X0,1)~= size(X0,2) 158 | error('Matrix not square') 159 | end 160 | 161 | k = size(X0,1); 162 | rho = 1; 163 | tol = k*10^(-4); 164 | 165 | maxiter = 500; 166 | 167 | U= zeros(size(X0)); 168 | Z= zeros(size(X0)); 169 | 170 | bf1= ones(k,1); 171 | 172 | 173 | for i=1:maxiter 174 | 175 | % compute X 176 | B = (1/(1+rho))*X0 + (rho/(1+rho))*(Z-U); 177 | 178 | nu = -(1/k)*sum(B,1)'; 179 | mu = (1/k)*(-sum(B,2) + bf1 
-bf1*sum(nu) ); 180 | X = B+ mu*bf1' + bf1*nu'; 181 | 182 | 183 | 184 | % update Z 185 | XpU = X + U; 186 | 187 | Zprev = Z; 188 | Z = max( 0, XpU); 189 | 190 | % update X 191 | U = XpU-Z; 192 | 193 | % compute primal residual 194 | primal_res = X-Z; 195 | 196 | % compute dual residual 197 | dual_res = - rho*(Z-Zprev); 198 | 199 | % termination condition 200 | if norm(primal_res(:),2) < tol && norm(dual_res(:),2) < tol 201 | break; 202 | end 203 | 204 | end 205 | 206 | 207 | X = max(eps,X); 208 | X = X ./ sum(X,2); 209 | 210 | end 211 | 212 | function X= projectPDADMM(X0) 213 | 214 | 215 | if size(X0,1) > size(X0,2) 216 | error('Number of rows should not be greater than the number of columns') 217 | end 218 | 219 | 220 | 221 | k = size(X0,1); % # of rows 222 | m = size(X0,2); % # of columns 223 | 224 | 225 | rho = 1; 226 | tol = m*10^(-3); 227 | 228 | maxiter = 200; 229 | 230 | U= zeros(size(X0)); 231 | Z= zeros(size(X0)); 232 | t = zeros(m,1); 233 | w = zeros(m,1); 234 | 235 | 236 | lambda = k +((rho+1)/rho); 237 | 238 | %pinvAA2 =[ (1/m)*( eye(k) + (rho/(1+rho))*ones(k)) -(1/m)*(rho/(1+rho))*ones(k,m); ... 239 | % -(1/m)*(rho/(1+rho))*ones(m,k) (1/lambda)*eye(m)+(k*rho/(lambda*m*(1+rho)) )*ones(m)]; 240 | 241 | %AA = [ m*eye(k) ones(k,m);... 242 | % ones(m,k) lambda*eye(m)] ; 243 | %norm(inv(AA)-pinvAA2,'fro') 244 | %pause(2) 245 | 246 | b1k = ones(k,1); 247 | b1m = ones(m,1); 248 | 249 | for i=1:maxiter 250 | 251 | % compute X 252 | B = X0 + rho*(Z-U); 253 | b = t-w; 254 | 255 | if 0 256 | 257 | % bb = [ (1+rho)*b1k-sum(B,2); ... 258 | % (1+rho)*(b1m-b)-sum(B,1)' ]; 259 | 260 | 261 | %warning('This can be done more efficiently....') 262 | % yy = pinvAA2*bb; 263 | % mu = yy(1:k); 264 | % nu = yy(k+1:end); 265 | 266 | else 267 | 268 | bb1 = (1+rho)*b1k-sum(B,2); 269 | bb2 = (1+rho)*(b1m-b)-sum(B,1)'; 270 | sbb1 = sum(bb1); 271 | sbb2 = sum(bb2); 272 | 273 | mu = (1/m)*( bb1 + (rho/(1+rho))*repmat(sbb1,[k 1])) ... 
274 | -((1/m)*(rho/(1+rho)))*repmat(sbb2,[k 1]); 275 | 276 | nu = -((1/m)*(rho/(1+rho)))*repmat(sbb1,[m 1]) + ... 277 | +(1/lambda)*bb2+(k*rho/(lambda*m*(1+rho)) )*repmat(sbb2,[m 1]); 278 | 279 | end 280 | 281 | % solve for X,S 282 | X = (1/(1+rho))*(B+ repmat(mu,[1 m]) + repmat(nu',[k 1])); 283 | s =(1/rho)*nu + b; 284 | 285 | 286 | % update Z 287 | XpU = X + U; 288 | 289 | Zprev = Z; 290 | Z = max( 0, XpU); 291 | 292 | tprev = t; 293 | t = max(0,s+w); 294 | 295 | % update X 296 | U = XpU-Z; 297 | w = w+s-t; 298 | 299 | % compute primal residual 300 | primal_res = [X-Z;s'-t'] ; 301 | 302 | % compute dual residual 303 | dual_res = - rho*( [Z-Zprev;t'-tprev'] ); 304 | 305 | % termination condition 306 | if norm(primal_res(:),2) < tol && norm(dual_res(:),2) < tol 307 | break; 308 | end 309 | 310 | end 311 | 312 | 313 | X = max(0,X); 314 | X = X ./ sum(X,2); 315 | 316 | end 317 | 318 | 319 | -------------------------------------------------------------------------------- /data_util/rome16k/scenes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Saver(object): 4 | """Class to save out objects in pickle format via dictionaries""" 5 | def __init__(self, fields, id_num, values=None): 6 | """ 7 | Inputs: 8 | - fields (list of strings) - names of values that will be stored 9 | - id_num (int) - integer id of this object 10 | - values (dict, optional) - dictionary of values (keys subset of fields) 11 | """ 12 | for f in fields: 13 | setattr(self, f, None) 14 | self.fields = fields 15 | self.id = id_num 16 | if values is not None: 17 | for f in values.keys() & fields: 18 | self.__dict__[f] = values[f] 19 | 20 | def get_dict(self): 21 | """Get dictionary representation of this object 22 | Output: dict_ (dict) - dictionary of all fields of this object, flattened 23 | """ 24 | dict_ = { f: self.output(self.__dict__[f]) for f in self.fields } 25 | dict_['id'] = self.id 26 | return dict_ 27 | 28 | def 
output(self, obj): 29 | """Output flattened object 30 | Input: obj (object) - np.array, list, or Saver objects get flattened, 31 | otherwise it just returns the object itself 32 | Output: obj_flatted (object) - outputs flattened object 33 | """ 34 | if type(obj) is np.ndarray: 35 | return obj.tolist() 36 | elif type(obj) is list: 37 | return [ self.output(x) for x in obj ] 38 | elif isinstance(obj, Saver): 39 | return obj.id 40 | else: 41 | return obj 42 | 43 | 44 | class Feature(Saver): 45 | """Class representing Feature in Rome16K 46 | This class functions like a pointer between Camera and Point, but with an 47 | associated SIFT descriptor and 2d location/scale/orientation 48 | """ 49 | def __init__(self, id_num, values=None): 50 | """ 51 | Inputs: id_num (int) - id of this object 52 | The fields of this object are: 53 | 'pos' - 2d location of this point (calibrated, in meters) 54 | 'pos_uncal' - 2d location of this point (uncalibrated, in pixels) 55 | 'desc' - SIFT descriptor of this feature 56 | 'scale' - scale of this feature 57 | 'orien' - orientation of this feature 58 | 'point' - Point associated with this feature 59 | 'cam' - Camera associated with this feature 60 | """ 61 | fields = [ 'pos', 'pos_uncal', 62 | 'desc', 'scale', 'orien', 63 | 'point', 'cam'] 64 | Saver.__init__(self, fields, id_num, values) 65 | self.pos = np.array(self.pos or np.zeros(3)) 66 | self.pos_uncal = np.array(self.pos_uncal or np.zeros(2)) 67 | self.desc = np.array(self.desc or np.zeros(128)) 68 | self.scale = self.scale or 0 69 | self.orien = self.orien or 0 70 | 71 | def get_uncentered_calib(self): 72 | """Return uncalibrated feature coordinates in original camera""" 73 | return np.array([ self.pos_uncal[1], -self.pos_uncal[0] ]) 74 | 75 | def get_proj_pt(self): 76 | """Project 3d point associated with feature to 2d location""" 77 | P = np.dot(self.cam.rot, self.point.pos) + self.cam.trans 78 | p = -P / P[-1] 79 | return self.cam.focal*p[:2] 80 | 81 | class Camera(Saver): 82 | 
"""Class representing single view in Rome16K""" 83 | def __init__(self, id_num, values=None): 84 | """ 85 | Inputs: id_num (int) - id of this object 86 | The fields of this object are: 87 | 'rot' - rotation matrix for pose of camera 88 | 'trans' - translation vector for pose of camera 89 | 'focal' - focal length of camera 90 | 'k1' - radial distortion coefficient 1 91 | 'k2' - radial distortion coefficient 2 92 | 'imsize' - image size (optional) 93 | 'center' - point center 94 | 'features' - list of Features associated with this view 95 | """ 96 | fields = [ 'rot', 'trans', 97 | 'focal', 'k1', 'k2', 98 | 'imsize', 'center', 99 | 'features' ] 100 | Saver.__init__(self, fields, id_num, values) 101 | self.rot = np.array(self.rot or np.eye(3)) 102 | self.trans = np.array(self.trans or np.zeros(3)) 103 | self.focal = self.focal or 1.0 104 | self.k1 = self.k1 or 0.0 105 | self.k2 = self.k2 or 0.0 106 | self.imsize = self.imsize or (-1,-1) 107 | self.center = np.array(self.center or np.zeros(2)) 108 | 109 | def center_points(self): 110 | """Center points around zero i.e. 
calibrate this view""" 111 | p_uncal, p_proj = [], [] 112 | for f in self.features: 113 | p_uncal.append(f.get_uncentered_calib()) 114 | p_proj.append(f.get_proj_pt()) 115 | p_uncal, p_proj = np.array(p_uncal), np.array(p_proj) 116 | self.center = np.mean(p_uncal, 0) - np.mean(p_proj, 0) 117 | for f in self.features: 118 | f.pos = -(f.get_uncentered_calib() - self.center)/self.focal 119 | 120 | class Point(Saver): 121 | """Class representing single point in Rome16K""" 122 | def __init__(self, id_num, values=None): 123 | """ 124 | Inputs: id_num (int) - id of this object 125 | The fields of this object are: 126 | 'pos' - 3d location of this point 127 | 'color' - RGB value of this point 128 | 'features' - list of Features associated with this point 129 | """ 130 | fields = ['pos', 'color', 'features'] 131 | Saver.__init__(self, fields, id_num, values) 132 | self.pos = np.array(self.pos or np.zeros(3)) 133 | self.color = np.array(self.color or np.zeros(3, dtype='int32')) 134 | 135 | class Scene(Saver): 136 | """Class representing single point in Rome16K""" 137 | def __init__(self, id_num=0): 138 | """ 139 | Inputs: id_num (int, optional) - id of this scene (default 0) 140 | The fields of this object are: 141 | 'points' - Points associated with this scene 142 | 'cams' - Cameras associated with this scene 143 | 'features' - Features associated with this scene 144 | """ 145 | fields = ['points', 'cams', 'features'] 146 | Saver.__init__(self, fields, id_num) 147 | 148 | def get_dict(self): 149 | """Get dictionary representation of this scene""" 150 | # cams = [ c.get_dict() for c in self.cams ] 151 | # points = [ p.get_dict() for p in self.points ] 152 | # features = [ f.get_dict() for f in self.features ] 153 | cams = [] 154 | for c in self.cams: 155 | cams.append(c.get_dict()) 156 | points = [] 157 | for p in self.points: 158 | points.append(p.get_dict()) 159 | features = [] 160 | for f in self.features: 161 | features.append(f.get_dict()) 162 | return { 163 | 'cams' : 
cams, 'points': points, 'features': features 164 | } 165 | 166 | def save_out_dict(self): 167 | """Save dictionary representation of this scene""" 168 | points = [] 169 | for p in self.points: 170 | p_d = p.get_dict() 171 | p_d['features'] = None 172 | points.append(p_d) 173 | cams = [] 174 | for c in self.cams: 175 | c_d = c.get_dict() 176 | c_d['features'] = None 177 | cams.append(c_d) 178 | features = [] 179 | for f in self.features: 180 | features.append(f.get_dict()) 181 | return { 182 | 'cams' : cams, 'points': points, 'features': features 183 | } 184 | 185 | def load_dict(self, scene_dict): 186 | """Load dictionary representation into this Scene object""" 187 | self.cams = [ Camera(c['id'], values=c) for c in scene_dict['cams'] ] 188 | self.points = [ Point(p['id'], values=p) for p in scene_dict['points'] ] 189 | self.features = [] 190 | for f in scene_dict['features']: 191 | feature = Feature(f['id'], f) 192 | c_idx = feature.cam 193 | p_idx = feature.point 194 | feature.cam = self.cams[c_idx] 195 | feature.point = self.points[p_idx] 196 | self.features.append(feature) 197 | if self.cams[c_idx].features is None: 198 | self.cams[c_idx].features = [ feature ] 199 | else: 200 | self.cams[c_idx].features.append(feature) 201 | if self.points[p_idx].features is None: 202 | self.points[p_idx].features = [ feature ] 203 | else: 204 | self.points[p_idx].features.append(feature) 205 | for c in self.cams: 206 | c.center_points() 207 | 208 | -------------------------------------------------------------------------------- /data_util/noisy_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import os 4 | 5 | from data_util import parent_dataset # import GraphSimDataset 6 | 7 | 8 | class GraphSimNoisyDataset(parent_dataset.GraphSimDataset): 9 | """Dataset for syntehtic cycle consistency graphs 10 | Generates synthetic graphs from sim_graphs.py then stores/loads them to/from 11 
| tfrecords. Adds 'neighbor' noise to output of parent_dataset.GraphSimDataset, 12 | essentially adding a matrix to confound the factorization in between 13 | """ 14 | MAX_IDX=7000 15 | 16 | def __init__(self, opts, params): 17 | """ 18 | Inputs: 19 | - opts (options) - object with all relevant options stored 20 | - params (DatasetParams) - object with all dataset parameters stored 21 | Outputs: GraphSimNoisyDataset 22 | """ 23 | parent_dataset.GraphSimDataset.__init__(self, opts, params) 24 | 25 | def gen_sample(self): 26 | # Pose graph and related objects 27 | sample = parent_dataset.GraphSimDataset.gen_sample(self) 28 | 29 | # Graph objects 30 | p = self.n_pts 31 | noise = self.dataset_params.noise_level 32 | TEmb = sample['TrueEmbedding'] 33 | Noise = np.eye(p) + noise*(np.eye(p, k=-1) + np.eye(p, k=-1)) 34 | AdjMat = np.dot(np.dot(TEmb, Noise), TEmb.T) 35 | AdjMat = np.minimum(1, AdjMat) 36 | Degrees = np.diag(np.sum(AdjMat,0)) 37 | sample['AdjMat'] = AdjMat.astype(self.dtype) 38 | sample['Degrees'] = Degrees.astype(self.dtype) 39 | 40 | # Laplacian objects 41 | Ahat = AdjMat + np.eye(*AdjMat.shape) 42 | Dhat_invsqrt = np.diag(1/np.sqrt(np.sum(Ahat,0))) 43 | Laplacian = np.dot(Dhat_invsqrt, np.dot(Ahat, Dhat_invsqrt)) 44 | sample['Laplacian'] = Laplacian.astype(self.dtype) 45 | 46 | return sample 47 | 48 | class GraphSimGaussDataset(parent_dataset.GraphSimDataset): 49 | """Dataset for syntehtic cycle consistency graphs 50 | Generates synthetic graphs from sim_graphs.py then stores/loads them to/from 51 | tfrecords. 
Adds gaussian noise to output of parent_dataset.GraphSimDataset 52 | """ 53 | MAX_IDX=7000 54 | 55 | def __init__(self, opts, params): 56 | parent_dataset.GraphSimDataset.__init__(self, opts, params) 57 | 58 | def gen_sample(self): 59 | # Pose graph and related objects 60 | sample = parent_dataset.GraphSimDataset.gen_sample(self) 61 | 62 | # Graph objects 63 | p = self.n_pts 64 | n = self.n_views 65 | noise = self.dataset_params.noise_level 66 | TEmb = sample['TrueEmbedding'] 67 | Noise = np.abs(np.random.randn(p*n,p*n)*noise) 68 | AdjMat = np.dot(TEmb, TEmb.T) + Noise - np.eye(p*n) 69 | AdjMat = np.minimum(1, AdjMat) 70 | Degrees = np.diag(np.sum(AdjMat,0)) 71 | sample['AdjMat'] = AdjMat.astype(self.dtype) 72 | sample['Degrees'] = Degrees.astype(self.dtype) 73 | 74 | # Laplacian objects 75 | Ahat = AdjMat + np.eye(*AdjMat.shape) 76 | Dhat_invsqrt = np.diag(1/np.sqrt(np.sum(Ahat,0))) 77 | Laplacian = np.dot(Dhat_invsqrt, np.dot(Ahat, Dhat_invsqrt)) 78 | sample['Laplacian'] = Laplacian.astype(self.dtype) 79 | 80 | return sample 81 | 82 | class GraphSimSymGaussDataset(parent_dataset.GraphSimDataset): 83 | """Dataset for syntehtic cycle consistency graphs 84 | Generates synthetic graphs from sim_graphs.py then stores/loads them to/from 85 | tfrecords. 
Adds symmetric gaussian noise to output of 86 | parent_dataset.GraphSimDataset 87 | """ 88 | MAX_IDX=7000 89 | 90 | def __init__(self, opts, params): 91 | parent_dataset.GraphSimDataset.__init__(self, opts, params) 92 | 93 | def gen_sample(self): 94 | # Pose graph and related objects 95 | sample = parent_dataset.GraphSimDataset.gen_sample(self) 96 | 97 | # Graph objects 98 | p = self.n_pts 99 | n = self.n_views 100 | noise = self.dataset_params.noise_level 101 | TEmb = sample['TrueEmbedding'] 102 | Noise = np.abs(np.random.randn(p*n,p*n)*noise) 103 | Mask = np.kron(np.ones((n,n))-np.eye(3),np.ones((p,p))) 104 | AdjMat = np.dot(TEmb, TEmb.T) + ((Noise+Noise.T)/2.0)*Mask - np.eye(p*n) 105 | AdjMat = np.minimum(1, AdjMat) 106 | Degrees = np.diag(np.sum(AdjMat,0)) 107 | sample['AdjMat'] = AdjMat.astype(self.dtype) 108 | sample['Degrees'] = Degrees.astype(self.dtype) 109 | 110 | # Laplacian objects 111 | Ahat = AdjMat + np.eye(*AdjMat.shape) 112 | Dhat_invsqrt = np.diag(1/np.sqrt(np.sum(Ahat,0))) 113 | Laplacian = np.dot(Dhat_invsqrt, np.dot(Ahat, Dhat_invsqrt)) 114 | sample['Laplacian'] = Laplacian.astype(self.dtype) 115 | 116 | return sample 117 | 118 | class GraphSimPairwiseDataset(parent_dataset.GraphSimDataset): 119 | """Dataset for syntehtic cycle consistency graphs 120 | Generates synthetic graphs from sim_graphs.py then stores/loads them to/from 121 | tfrecords. Adds individual noise matricies between the image matching matrix 122 | multiplies. 
123 | """ 124 | MAX_IDX=7000 125 | 126 | def __init__(self, opts, params): 127 | parent_dataset.GraphSimDataset.__init__(self, opts, params) 128 | 129 | def gen_sample(self): 130 | # Pose graph and related objects 131 | sample = parent_dataset.GraphSimDataset.gen_sample(self) 132 | 133 | # Graph objects 134 | p = self.n_pts 135 | n = self.n_views 136 | r = self.dataset_params.num_repeats 137 | noise = self.dataset_params.noise_level 138 | perm = lambda p: np.eye(p)[np.random.permutation(p),:] 139 | TEmb = sample['TrueEmbedding'] 140 | AdjMat = np.zeros((p*n,p*n)) 141 | for i in range(n): 142 | TEmb_i = TEmb[p*i:p*i+p,:] 143 | for j in range(i+1, n): 144 | TEmb_j = TEmb[p*j:p*j+p,:] 145 | Noise = (1-noise)*np.eye(p) + noise*sum([ perm(p) for i in range(r) ]) 146 | Val_ij = np.dot(TEmb_i, np.dot(Noise, TEmb_j.T)) 147 | AdjMat[p*i:p*i+p, p*j:p*j+p] = Val_ij 148 | AdjMat[p*j:p*j+p, p*i:p*i+p] = Val_ij.T 149 | AdjMat = np.minimum(1, AdjMat) 150 | Degrees = np.diag(np.sum(AdjMat,0)) 151 | sample['AdjMat'] = AdjMat.astype(self.dtype) 152 | sample['Degrees'] = Degrees.astype(self.dtype) 153 | 154 | # Laplacian objects 155 | Ahat = AdjMat + np.eye(*AdjMat.shape) 156 | Dhat_invsqrt = np.diag(1/np.sqrt(np.sum(Ahat,0))) 157 | Laplacian = np.dot(Dhat_invsqrt, np.dot(Ahat, Dhat_invsqrt)) 158 | sample['Laplacian'] = Laplacian.astype(self.dtype) 159 | 160 | return sample 161 | 162 | class GraphSimOutlierDataset(parent_dataset.GraphSimDataset): 163 | """Dataset for syntehtic cycle consistency graphs 164 | Generates synthetic graphs from sim_graphs.py then stores/loads them to/from 165 | tfrecords. 
Adds in outlier connections to the adjacency matrix 166 | """ 167 | MAX_IDX=7000 168 | 169 | def __init__(self, opts, params): 170 | parent_dataset.GraphSimDataset.__init__(self, opts, params) 171 | 172 | def create_outlier_indeces(self, o, n): 173 | """Create the selection indeces for the outliers 174 | Inputs: 175 | - o (int) - number of outliers 176 | - n (int) - number of points total 177 | Outputs: outlier_sel (np.array) - list of indeces to be outliers 178 | """ 179 | ind_pairs = [ (x,y) for x in range(n) for y in range(x+1,n) ] 180 | probs = [ 1.0/len(ind_pairs) ] * len(ind_pairs) 181 | outlier_ind_pairs = np.random.multinomial(o, probs, size=1)[0] 182 | outlier_sel = np.zeros((n,n), dtype=np.int64) 183 | for i in range(len(outlier_ind_pairs)): 184 | outlier_sel[ind_pairs[i]] = int(outlier_ind_pairs[i]) 185 | outlier_sel[ind_pairs[i]] = (outlier_ind_pairs[i]) 186 | # for i in range(n): 187 | # for j in range(i+1,n): 188 | # outlier_sel[i,j] = outlier_ind_pairs[i*n + j] 189 | # outlier_sel[j,i] = outlier_ind_pairs[i*n + j] 190 | return outlier_sel 191 | 192 | def gen_sample(self): 193 | # Pose graph and related objects 194 | sample = parent_dataset.GraphSimDataset.gen_sample(self) 195 | 196 | # Graph objects 197 | p = self.n_pts 198 | n = self.n_views 199 | r = self.dataset_params.num_repeats 200 | o = self.dataset_params.num_outliers 201 | noise = self.dataset_params.noise_level 202 | perm = lambda p: np.eye(p)[np.random.permutation(p),:] 203 | TEmb = sample['TrueEmbedding'] 204 | AdjMat = np.zeros((p*n,p*n)) 205 | outlier_sel = self.create_outlier_indeces(o, n) 206 | # Generate matrix 207 | for i in range(n): 208 | TEmb_i = TEmb[p*i:p*i+p,:] 209 | for j in range(i+1, n): 210 | TEmb_j = TEmb[p*j:p*j+p,:] 211 | if outlier_sel[i,j] > 0: 212 | Noise = np.eye(p) 213 | # for _ in range(outlier_sel[i,j]): 214 | for _ in range(1): 215 | s0, s1 = np.random.choice(range(p), size=2, replace=False) 216 | tmp = Noise[s1,:].copy() 217 | Noise[s1,:] = Noise[s0,:] 218 | 
Noise[s0,:] = tmp 219 | Val_ij = np.dot(TEmb_i, np.dot(Noise, TEmb_j.T)) 220 | else: 221 | Val_ij = np.dot(TEmb_i, TEmb_j.T) 222 | AdjMat[p*i:p*i+p, p*j:p*j+p] = Val_ij 223 | AdjMat[p*j:p*j+p, p*i:p*i+p] = Val_ij.T 224 | AdjMat = np.minimum(1, AdjMat) 225 | Degrees = np.diag(np.sum(AdjMat,0)) 226 | sample['AdjMat'] = AdjMat.astype(self.dtype) 227 | sample['Degrees'] = Degrees.astype(self.dtype) 228 | 229 | # Laplacian objects 230 | Ahat = AdjMat + np.eye(*AdjMat.shape) 231 | Dhat_invsqrt = np.diag(1/np.sqrt(np.sum(Ahat,0))) 232 | Laplacian = np.dot(Dhat_invsqrt, np.dot(Ahat, Dhat_invsqrt)) 233 | sample['Laplacian'] = Laplacian.astype(self.dtype) 234 | 235 | return sample 236 | 237 | -------------------------------------------------------------------------------- /experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run and save out plots from a trained model 3 | """ 4 | import os 5 | import sys 6 | import numpy as np 7 | import argparse 8 | import argcomplete 9 | import matplotlib.pyplot as plt 10 | from matplotlib import cm 11 | import tqdm 12 | # import yaml 13 | 14 | import myutils 15 | # import options 16 | 17 | def str2bool(v): 18 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 19 | return True 20 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 21 | return False 22 | else: 23 | raise argparse.ArgumentTypeError('Boolean value expected.') 24 | 25 | def get_experiment_opts(): 26 | parser = argparse.ArgumentParser(description='Experiment with output') 27 | argcomplete.autocomplete(parser) 28 | parser.add_argument('--verbose', 29 | default=False, 30 | type=str2bool, 31 | help='Print everything or not') 32 | parser.add_argument('--index', 33 | default=1, 34 | type=int, 35 | help='Test data index to experiment with') 36 | parser.add_argument('--data_path', 37 | default='test001.npz', 38 | help='Path to test data to experiment with') 39 | parser.add_argument('--save_dir', 40 | default=None, 41 | 
help='Directory to save plot files in') 42 | plot_options = [ 'none', 'plot', 'unsorted', 'baseline', 'random', 'save_all' ] 43 | parser.add_argument('--plot_style', 44 | default=plot_options[0], 45 | choices=plot_options, 46 | help='Plot things in experiment') 47 | parser.add_argument('--viewer_size', 48 | default=8, 49 | type=int, 50 | help='Run in debug mode') 51 | 52 | opts = parser.parse_args() 53 | # Finished, return options 54 | return opts 55 | 56 | def npload(fdir,idx): 57 | return dict(np.load("{}/np_test-{:04d}.npz".format(fdir,idx))) 58 | 59 | def get_sorted(labels): 60 | idxs = np.argmax(labels, axis=1) 61 | sorted_idxs = np.argsort(idxs) 62 | slabels = labels[sorted_idxs] 63 | return slabels, sorted_idxs 64 | 65 | def plot_hist(save_dir, sim_mats, names, true_sim): 66 | fig, ax = plt.subplots(nrows=1, ncols=2) 67 | diags = [ np.reshape(v[true_sim==1],-1) for v in sim_mats ] 68 | off_diags = [ np.reshape(v[true_sim==0],-1) for v in sim_mats ] 69 | ax[0].hist(diags, bins=20, density=True, label=names) 70 | ax[0].legend() 71 | ax[0].set_title('Diagonal Similarity Rate') 72 | ax[1].hist(off_diags, bins=20, density=True, label=names) 73 | ax[1].set_title('Off Diagonal Similarity Rate') 74 | ax[1].legend() 75 | if save_dir: 76 | fig.savefig(os.path.join(save_dir, 'hist.png')) 77 | else: 78 | plt.show() 79 | 80 | def plot_baseline(save_dir, emb_init, emb_gt, emb_out): 81 | slabels, sorted_idxs = get_sorted(emb_gt) 82 | srand = myutils.dim_normalize(emb_init[sorted_idxs]) 83 | lsim = np.abs(np.dot(slabels, slabels.T)) 84 | rsim = np.abs(np.dot(srand, srand.T)) 85 | print('Sorted labels') 86 | fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2) 87 | im0 = ax0.imshow(slabels) 88 | im1 = ax1.imshow(srand) 89 | fig.colorbar(im0, ax=ax0) 90 | fig.colorbar(im1, ax=ax1) 91 | if save_dir: 92 | fig.savefig(os.path.join(save_dir, 'labels_sort.png')) 93 | else: 94 | plt.show() 95 | print('Sorted similarites') 96 | fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2) 97 | im0 = 
ax0.imshow(lsim) 98 | im1 = ax1.imshow(rsim) 99 | fig.colorbar(im0, ax=ax0) 100 | fig.colorbar(im1, ax=ax1) 101 | if save_dir: 102 | fig.savefig(os.path.join(save_dir, 'sim_sort.png')) 103 | else: 104 | plt.show() 105 | 106 | def plot_index(save_dir, emb_init, emb_gt, emb_out): 107 | # Embeddings 108 | slabels, sorted_idxs = get_sorted(emb_gt) 109 | soutput = emb_out[sorted_idxs] 110 | srand = myutils.dim_normalize(emb_init[sorted_idxs]) 111 | lsim = np.abs(np.dot(slabels, slabels.T)) 112 | osim = np.abs(np.dot(soutput, soutput.T)) 113 | rsim = np.abs(np.dot(srand, srand.T)) 114 | fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2) 115 | u, s, v = np.linalg.svd(np.dot(soutput.T, slabels)) 116 | o_ = np.ones_like(s) 117 | o_[-1] = np.linalg.det(np.dot(u,v)) 118 | Q = np.dot(u, np.dot(np.diag(o_), v)) 119 | im0 = ax0.imshow(np.abs(np.dot(soutput, Q))) 120 | im1 = ax1.imshow(osim) 121 | fig.colorbar(im0, ax=ax0) 122 | fig.colorbar(im1, ax=ax1) 123 | if save_dir: 124 | fig.savefig(os.path.join(save_dir, 'output.png')) 125 | else: 126 | plt.show() 127 | # Histogram 128 | diag = np.reshape(osim[lsim==1],-1) 129 | off_diag = np.reshape(osim[lsim==0],-1) 130 | baseline_diag = np.reshape(rsim[lsim==1],-1) 131 | baseline_off_diag = np.reshape(rsim[lsim==0],-1) 132 | fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2) 133 | ax0.hist([ diag, baseline_diag ], bins=20, density=True, 134 | label=[ 'diag', 'baseline_diag' ]) 135 | ax0.legend() 136 | ax0.set_title('Diagonal Similarity Rate') 137 | ax1.hist([ off_diag, baseline_off_diag ], bins=20, density=True, 138 | label=[ 'off_diag', 'baseline_off_diag' ]) 139 | ax1.set_title('Off Diagonal Similarity Rate') 140 | ax1.legend() 141 | if save_dir: 142 | fig.savefig(os.path.join(save_dir, 'sim_hist.png')) 143 | else: 144 | plt.show() 145 | 146 | def plot_index_unsorted(save_dir, emb_init, emb_gt, emb_out, adjmat): 147 | labels = emb_gt 148 | rand = myutils.dim_normalize(emb_init) 149 | lsim = np.abs(np.dot(labels, labels.T)) 150 | osim = 
np.abs(np.dot(emb_out, emb_out.T)) 151 | rsim = np.abs(np.dot(rand, rand.T)) 152 | fig, (ax0, ax1, ax2) = plt.subplots(nrows=1, ncols=3) 153 | im0 = ax0.imshow(output) 154 | im1 = ax1.imshow(osim) 155 | im2 = ax2.imshow(adjmat + np.eye(adjmat.shape[0])) 156 | fig.colorbar(im0, ax=ax0) 157 | fig.colorbar(im1, ax=ax1) 158 | if save_dir: 159 | fig.savefig(os.path.join(save_dir, 'unsorted_output.png')) 160 | else: 161 | plt.show() 162 | # diag = np.reshape(osim[lsim==1],-1) 163 | # off_diag = np.reshape(osim[lsim==0],-1) 164 | # baseline_diag = np.reshape(rsim[lsim==1],-1) 165 | # baseline_off_diag = np.reshape(rsim[lsim==0],-1) 166 | # fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2) 167 | # ax0.hist([ diag, baseline_diag ], bins=20, density=True, 168 | # label=[ 'diag', 'baseline_diag' ]) 169 | # ax0.legend() 170 | # ax0.set_title('Diagonal Similarity Rate') 171 | # ax1.hist([ off_diag, baseline_off_diag ], bins=20, density=True, 172 | # label=[ 'off_diag', 'baseline_off_diag' ]) 173 | # ax1.set_title('Off Diagonal Similarity Rate') 174 | # ax1.legend() 175 | # fig.savefig(os.path.join(save_dir, 'sim_hist_unsorted.png')) 176 | 177 | def plot_random(save_dir, emb_init, emb_gt, emb_out): 178 | slabels, sorted_idxs = get_sorted(emb_gt) 179 | soutput = emb_out[sorted_idxs] 180 | srand = myutils.dim_normalize(emb_init[sorted_idxs]) 181 | lsim = np.abs(np.dot(slabels, slabels.T)) 182 | osim = np.abs(np.dot(soutput, soutput.T)) 183 | rsim = np.abs(np.dot(srand, srand.T)) 184 | plots = [ rsim, osim, osim**9 ] 185 | names = [ 'rsim', 'osim', 'osim**9' ] 186 | fig, ax = plt.subplots(nrows=1, ncols=len(plots)) 187 | diags = [ np.reshape(v[lsim==1],-1) for v in plots ] 188 | off_diags = [ np.reshape(v[lsim==0],-1) for v in plots ] 189 | for i in range(len(plots)): 190 | ax[i].hist([ diags[i], off_diags[i] ], bins=20, density=True, label=['diag', 'off_diag']) 191 | ax[i].legend() 192 | ax[i].set_title(names[i]) 193 | print(np.min(diags[i])) 194 | print(np.max(off_diags[i])) 195 | 
print('--') 196 | if save_dir: 197 | fig.savefig(os.path.join(save_dir, 'random.png')) 198 | else: 199 | plt.show() 200 | 201 | def get_stats(emb_init, emb_gt, emb_out): 202 | slabels, sorted_idxs = get_sorted(emb_gt) 203 | soutput = emb_out[sorted_idxs] 204 | srand = myutils.dim_normalize(emb_init[sorted_idxs]) 205 | lsim = np.abs(np.dot(slabels, slabels.T)) 206 | osim = np.abs(np.dot(soutput, soutput.T)) 207 | rsim = np.abs(np.dot(srand, srand.T)) 208 | diag = np.reshape(osim[lsim==1],-1) 209 | off_diag = np.reshape(osim[lsim==0],-1) 210 | baseline_diag = np.reshape(rsim[lsim==1],-1) 211 | baseline_off_diag = np.reshape(rsim[lsim==0],-1) 212 | return (np.mean(diag), np.std(diag), \ 213 | np.mean(off_diag), np.std(off_diag), \ 214 | np.mean(baseline_diag), np.std(baseline_diag), \ 215 | np.mean(baseline_off_diag), np.std(baseline_off_diag)) 216 | 217 | 218 | if __name__ == "__main__": 219 | # Build options 220 | # opts = options.get_opts() 221 | opts = get_experiment_opts() 222 | # Run experiment 223 | ld = np.load(opts.data_path) 224 | emb_init = ld['input'] 225 | emb_gt = ld['gt'] 226 | emb_out = ld['output'] 227 | adjmat = ld['adjmat'] 228 | n = len(emb_gt) 229 | if opts.plot_style == 'none': 230 | stats = np.zeros((n,8)) 231 | if opts.verbose: 232 | for i in tqdm.tqdm(range(n)): 233 | stats[i] = get_stats(emb_init[i], emb_gt[i], emb_out[i]) 234 | else: 235 | for i in range(n): 236 | stats[i] = get_stats(emb_init[i], emb_gt[i], emb_out[i]) 237 | meanstats = np.mean(stats,0) 238 | print("Diag: {:.2e} +/- {:.2e}, Off Diag: {:.2e} +/- {:.2e}, " \ 239 | "Baseline Diag: {:.2e} +/- {:.2e}, " \ 240 | "Baseline Off Diag: {:.2e} +/- {:.2e}".format(*list(meanstats))) 241 | sys.exit() 242 | if opts.index > n: 243 | print("ERROR: index out of bounds") 244 | sys.exit() 245 | i = opts.index 246 | if opts.plot_style == 'plot': 247 | plot_index(None, emb_init[i], emb_gt[i], emb_out[i]) 248 | elif opts.plot_style == 'unsorted': 249 | plot_index_unsorted(None, emb_init[i], 
emb_gt[i], emb_out[i]) 250 | elif opts.plot_style == 'baseline': 251 | plot_baseline(None, emb_init[i], emb_gt[i], emb_out[i]) 252 | elif opts.plot_style == 'random': 253 | plot_random(None, emb_init[i], emb_gt[i], emb_out[i]) 254 | elif opts.plot_style == 'save_all': 255 | if opts.save_dir: 256 | save_dir = opts.save_dir 257 | else: 258 | save_dir = os.path.dirname(os.path.abspath(opts.data_path)) 259 | plot_index(save_dir, emb_init[i], emb_gt[i], emb_out[i]) 260 | plot_baseline(save_dir, emb_init[i], emb_gt[i], emb_out[i]) 261 | plot_random(save_dir, emb_init[i], emb_gt[i], emb_out[i]) 262 | 263 | 264 | 265 | -------------------------------------------------------------------------------- /model/skip_networks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import os 4 | import sys 5 | import tensorflow as tf 6 | import sonnet as snt 7 | 8 | import tfutils 9 | import myutils 10 | import options 11 | 12 | from model import layers 13 | 14 | class GraphSkipLayerNetwork(snt.AbstractModule): 15 | """Graph Convolutional Net with short skip connections 16 | 17 | Small skip connections are skip connections from one layer to the next. 
18 | """ 19 | def __init__(self, 20 | opts, 21 | arch, 22 | use_bias=True, 23 | initializers=None, 24 | regularizers=None, 25 | custom_getter=None, 26 | name="graphnn"): 27 | """ 28 | Input: 29 | - opts (options) - object with all relevant options stored 30 | - arch (ArchParams) - object with all relevant Architecture options 31 | - use_bias (boolean, optional) - have biases in the network (default True) 32 | - intializers (dict, optional) - specify custom initializers 33 | - regularizers (dict, optional) - specify custom regularizers 34 | - custom_getter (dict, optional) - specify custom getters 35 | - name (string, optional) - name for module for scoping (default graphnn) 36 | """ 37 | super(GraphSkipLayerNetwork, self).__init__(custom_getter=custom_getter, name=name) 38 | self._nlayers = len(arch.layer_lens) 39 | final_regularizers = None 40 | if regularizers is not None: 41 | final_regularizers = { k:v 42 | for k, v in regularizers.items() 43 | if k in ["w", "b"] } 44 | self._layers = [ 45 | layers.GraphSkipLayer( 46 | output_size=layer_len, 47 | activation=arch.activ, 48 | initializers=initializers, 49 | regularizers=regularizers, 50 | name="{}/graph_skip".format(name)) 51 | for layer_len in arch.layer_lens 52 | ] + [ 53 | layers.EmbeddingLinearLayer( 54 | output_size=opts.final_embedding_dim, 55 | initializers=initializers, 56 | regularizers=final_regularizers, 57 | name="{}/embed_lin".format(name)) 58 | ] 59 | self.normalize_emb = arch.normalize_emb 60 | 61 | def _build(self, laplacian, init_embeddings): 62 | """Applying this graph network to sample 63 | Inputs: 64 | - laplacian (tf.Tensor) - laplacian for the input graph 65 | - init_embeddings (tf.Tensor) - Initial node embeddings of the graph 66 | Outputs: output (tf.Tensor) - the output of the network 67 | """ 68 | output = init_embeddings 69 | for layer in self._layers: 70 | output = layer(laplacian, output) 71 | if self.normalize_emb: 72 | output = tf.nn.l2_normalize(output, axis=2) 73 | return output 

class GraphLongSkipLayerNetwork(snt.AbstractModule):
    """Graph Convolutional Net with short and long skip connections

    Long skip connections are skip connections from the start to an intermediate
    layer. This combined with short skip connections make training much smoother.
    """
    def __init__(self,
                 opts,
                 arch,
                 use_bias=True,
                 initializers=None,
                 regularizers=None,
                 custom_getter=None,
                 name="graphnn"):
        """
        Input:
        - opts (options) - object with all relevant options stored
        - arch (ArchParams) - object with all relevant Architecture options
        - use_bias (boolean, optional) - have biases in the network (default True)
          (accepted for interface compatibility; not read by this constructor)
        - initializers (dict, optional) - specify custom initializers
        - regularizers (dict, optional) - specify custom regularizers
        - custom_getter (dict, optional) - specify custom getters
        - name (string, optional) - name for module for scoping (default graphnn)
        """
        super(GraphLongSkipLayerNetwork, self).__init__(custom_getter=custom_getter,
                                                        name=name)
        self._nlayers = len(arch.layer_lens)
        # Linear layers only receive the 'w' / 'b' entries of `regularizers`.
        lin_regularizers = None
        if regularizers is not None:
            lin_regularizers = {k: v
                                for k, v in regularizers.items()
                                if k in ["w", "b"]}
        self._layers = [
            layers.GraphSkipLayer(
                output_size=layer_len,
                activation=arch.activ,
                initializers=initializers,
                regularizers=regularizers,
                name="{}/graph_skip".format(name))
            for layer_len in arch.layer_lens
        ] + [
            layers.EmbeddingLinearLayer(
                output_size=opts.final_embedding_dim,
                initializers=initializers,
                regularizers=lin_regularizers,
                name="{}/embed_lin".format(name))
        ]
        # Indices into self._layers whose output gets a skip connection added;
        # each skip projects to that layer's width.
        self._skip_layer_idx = arch.skip_layers
        self._skip_layers = [
            layers.EmbeddingLinearLayer(
                output_size=arch.layer_lens[skip_idx],
                initializers=initializers,
                regularizers=lin_regularizers,
                name="{}/skip".format(name))
            for skip_idx in self._skip_layer_idx
        ]
        self.normalize_emb = arch.normalize_emb

    def _build(self, laplacian, init_embeddings):
        """Applying this graph network to sample
        Inputs:
        - laplacian (tf.Tensor) - laplacian for the input graph
        - init_embeddings (tf.Tensor) - Initial node embeddings of the graph
        Outputs: output (tf.Tensor) - the output of the network
        """
        output = init_embeddings
        sk = 0
        for i, layer in enumerate(self._layers):
            if i in self._skip_layer_idx:
                # NOTE(review): the class docstring describes long skips from
                # the network input, and GraphLongSkipNormedNetwork feeds
                # init_embeddings here, but this class feeds the previous
                # layer's output. Left unchanged: switching would likely
                # change the skip layers' weight shapes (and so checkpoints).
                output = layer(laplacian, output) + self._skip_layers[sk](laplacian, output)
                sk += 1
            else:
                output = layer(laplacian, output)
        if self.normalize_emb:
            output = tf.nn.l2_normalize(output, axis=2)
        return output


class GraphLongSkipNormedNetwork(GraphLongSkipLayerNetwork):
    """Graph Convolutional Net with skip connections and group norm

    Group norm is an alternative to batch norm, defined here:
    https://arxiv.org/abs/1803.08494
    """
    def __init__(self,
                 opts,
                 arch,
                 use_bias=True,
                 initializers=None,
                 regularizers=None,
                 custom_getter=None,
                 name="graphnn"):
        """
        Input:
        - opts (options) - object with all relevant options stored
        - arch (ArchParams) - object with all relevant Architecture options
          (additionally reads arch.start_normed and arch.group_size)
        - use_bias (boolean, optional) - have biases in the network (default True)
        - initializers (dict, optional) - specify custom initializers
        - regularizers (dict, optional) - specify custom regularizers
        - custom_getter (dict, optional) - specify custom getters
        - name (string, optional) - name for module for scoping (default graphnn)
        """
        super(GraphLongSkipNormedNetwork, self).__init__(opts, arch,
                                                         use_bias=use_bias,
                                                         initializers=initializers,
                                                         regularizers=regularizers,
                                                         custom_getter=custom_getter,
                                                         name=name)
        self.start_normed = arch.start_normed
        self.group_size = arch.group_size
        # One group-norm layer per graph layer index in [start_normed, _nlayers).
        # Uses the configured arch.group_size (previously a hard-coded 32
        # silently ignored it).
        self._group_norm_layers = [
            layers.GraphGroupNorm(
                group_size=self.group_size,
                name="{}/group_norm".format(name))
            for _ in range(self.start_normed, self._nlayers)
        ]

    def _build(self, laplacian, init_embeddings):
        """Applying this graph network to sample
        Inputs:
        - laplacian (tf.Tensor) - laplacian for the input graph
        - init_embeddings (tf.Tensor) - Initial node embeddings of the graph
        Outputs: output (tf.Tensor) - the output of the network
        """
        output = init_embeddings
        sk = 0
        for i, layer in enumerate(self._layers):
            if i in self._skip_layer_idx:
                # Long skip: a linear map of the network input is added.
                output = layer(laplacian, output) + self._skip_layers[sk](laplacian, init_embeddings)
                sk += 1
            else:
                output = layer(laplacian, output)
            # The final linear layer (index _nlayers) is never normalized.
            if self.start_normed <= i < self._nlayers:
                output = self._group_norm_layers[i - self.start_normed](output)
        if self.normalize_emb:
            output = tf.nn.l2_normalize(output, axis=2)
        return output


class GraphSkipHopNormedNetwork(GraphLongSkipNormedNetwork):
    """Group-normed long-skip network with extra "hop" connections.

    At every skip layer after the first, a linear map of the output recorded
    at the previous skip layer is also added. That recorded output is taken
    before any group norm is applied.
    """
    def __init__(self,
                 opts,
                 arch,
                 use_bias=True,
                 initializers=None,
                 regularizers=None,
                 custom_getter=None,
                 name="graphnn"):
        """
        Input:
        - opts (options) - object with all relevant options stored
        - arch (ArchParams) - object with all relevant Architecture options
        - use_bias (boolean, optional) - have biases in the network (default True)
        - initializers (dict, optional) - specify custom initializers
        - regularizers (dict, optional) - specify custom regularizers
        - custom_getter (dict, optional) - specify custom getters
        - name (string, optional) - name for module for scoping (default graphnn)
        """
        super(GraphSkipHopNormedNetwork, self).__init__(opts, arch,
                                                        use_bias=use_bias,
                                                        initializers=initializers,
                                                        regularizers=regularizers,
                                                        custom_getter=custom_getter,
                                                        name=name)
        lin_regularizers = None
        if regularizers is not None:
            lin_regularizers = {k: v
                                for k, v in regularizers.items()
                                if k in ["w", "b"]}
        # One hop layer per skip layer except the first (which has no predecessor).
        self._hop_layers = [
            layers.EmbeddingLinearLayer(
                output_size=arch.layer_lens[skip_idx],
                initializers=initializers,
                regularizers=lin_regularizers,
                name="{}/hop".format(name))
            for skip_idx in self._skip_layer_idx[1:]
        ]

    def _build(self, laplacian, init_embeddings):
        """Applying this graph network to sample
        Inputs:
        - laplacian (tf.Tensor) - laplacian for the input graph
        - init_embeddings (tf.Tensor) - Initial node embeddings of the graph
        Outputs: output (tf.Tensor) - the output of the network
        """
        output = init_embeddings
        sk = 0
        last_skip = None
        for i, layer in enumerate(self._layers):
            if i in self._skip_layer_idx:
                output = layer(laplacian, output)
                skip_add = self._skip_layers[sk](laplacian, init_embeddings)
                output = output + skip_add
                if last_skip is not None:
                    hop_add = self._hop_layers[sk - 1](laplacian, last_skip)
                    output = output + hop_add
                # Remember this skip layer's output for the next hop.
                last_skip = output
                sk += 1
            else:
                output = layer(laplacian, output)
            if self.start_normed <= i < self._nlayers:
                output = self._group_norm_layers[i - self.start_normed](output)
        if self.normalize_emb:
            output = tf.nn.l2_normalize(output, axis=2)
        return output


# =============================================================================
# /data_util/rome16k/parse.py
# =============================================================================
import numpy as np
import os
import sys
import gzip
import pickle
import tqdm

import requests
from PIL import Image
import html.parser
11 | import io 12 | 13 | from data_util.rome16k import scenes 14 | 15 | # Format: 16 | # dict: {'train', 'test'} 17 | # -> dict: rome16k_name -> (ntriplets, ncams) 18 | bundle_file_info = { 19 | 'train' : { 20 | '5.1.0.0': (177, 507, 361, 127), 21 | '20.0.0.0': (158, 566, 608, 238), 22 | '55.0.0.0': (198, 644, 2314, 7033), 23 | '38.0.0.0': (115, 663, 1730, 2314), 24 | '26.1.0.0': (232, 744, 2278, 3222), 25 | '74.0.0.0': (126, 1050, 3491, 6227), 26 | '49.0.0.0': (88, 1053, 4130, 6848), 27 | '36.0.0.0': (153, 1204, 2327, 1732), 28 | '12.0.0.0': (368, 1511, 1646, 600), 29 | '60.0.0.0': (169, 2057, 9283, 25828), 30 | '54.0.0.0': (286, 2068, 3691, 3290), 31 | '57.0.0.0': (204, 2094, 2358, 461), 32 | '167.0.0.0': (94, 2119, 13777, 55213), 33 | '4.11.0.0': (758, 2714, 1615, 238), 34 | '38.3.0.0': (64, 3248, 20775, 38515), 35 | '135.0.0.0': (268, 3476, 6081, 3861), 36 | '4.8.0.0': (317, 3980, 22047, 51378), 37 | '110.0.0.0': (442, 4075, 16463, 34900), 38 | '4.3.0.0': (528, 4442, 15175, 31199), 39 | '29.0.0.0': (119, 4849, 93477, 859959), 40 | '97.0.0.0': (523, 4967, 13153, 14137), 41 | '4.6.0.0': (409, 5409, 28843, 51364), 42 | '84.0.0.0': (226, 5965, 57749, 315864), 43 | '9.1.0.0': (210, 6536, 67964, 340354), 44 | '33.0.0.0': (509, 6698, 37846, 157427), 45 | '15.0.0.0': (221, 9950, 103686, 444837), 46 | '26.5.0.0': (368, 12913, 118619, 584667), 47 | '122.0.0.0': (997, 15269, 99668, 339791), 48 | '10.0.0.0': (889, 16709, 89223, 240469), 49 | '11.0.0.0': (4222, 16871, 12571, 1983), 50 | '0.0.0.0': (2317, 22632, 50033, 36125), 51 | '17.0.0.0': (1470, 28333, 184655, 654706), 52 | '16.0.0.0': (947, 35180, 291222, 1025050), 53 | '4.1.0.0': (1320, 36460, 392329, 2626099), 54 | # These two comprize ~37% of the total training data 55 | # '26.2.0.1': (75225,), 56 | # '4.5.0.0': (79259,) 57 | }, 58 | 'test' : { 59 | '11.2.0.0': (101, 12, 0, 0), 60 | '125.0.0.0': (29, 21, 5, 0), 61 | '41.0.0.0': (75, 22, 0, 0), 62 | '37.0.0.0': (22, 25, 3, 1), 63 | '73.0.0.0': (131, 26, 0, 0), 64 | 
'33.0.0.1': (5, 31, 120, 356), 65 | '5.11.0.0': (105, 93, 17, 1), 66 | '0.3.0.0': (98, 170, 105, 26), 67 | '46.0.0.0': (174, 205, 125, 47), 68 | '26.4.0.0': (68, 239, 400, 558), 69 | '82.0.0.0': (30, 256, 1634, 8898), 70 | '65.0.0.0': (62, 298, 508, 367), 71 | '40.0.0.0': (154, 340, 217, 47), 72 | '56.0.0.0': (93, 477, 1256, 1862), 73 | '5.9.0.0': (309, 481, 186, 27), 74 | '34.1.0.0': (49, 487, 7810, 79231), 75 | } 76 | } 77 | bundle_files = sorted([ k for k in bundle_file_info['test'].keys() ] + \ 78 | [ k for k in bundle_file_info['train'].keys() ]) 79 | 80 | # Methods for getting image size 81 | URL_STR = 'http://www.flickr.com/photo_zoom.gne?id={}' 82 | class ImageSizeHTMLParser(html.parser.HTMLParser): 83 | """HTMLParser for getting the flikr image files (if needed)""" 84 | def __init__(self): 85 | super().__init__() 86 | self.image_url = None 87 | 88 | def handle_starttag(self, tag, attrs): 89 | attr_keys = [ at[0] for at in attrs ] 90 | attr_vals = [ at[1] for at in attrs ] 91 | if tag == 'a' and 'href' in attr_keys: 92 | link = attr_vals[attr_keys.index('href')] 93 | if '.jpg' in link: 94 | self.image_url = link 95 | PARSER = ImageSizeHTMLParser() 96 | 97 | def get_image_size(imname): 98 | """Find the image size of the image id imname 99 | Inputs: imname (string) - id to image file (ends with jpg) 100 | """ 101 | # photoid = imname[:-len('.jpg')].split('/')[-1].split('_')[-1] 102 | photoid = imname[:-len('.jpg')].split('_')[-1] 103 | url_ = URL_STR.format(photoid) 104 | ret = requests.get(url_) 105 | PARSER.feed(ret.text) 106 | im_ = requests.get(PARSER.image_url) 107 | image = Image.open(io.BytesIO(im_.content)) 108 | return image.size 109 | 110 | 111 | # Methods for getting filenames 112 | def check_valid_name(bundle_file): 113 | """Check that bundle_file is one we can load""" 114 | return bundle_file in bundle_files 115 | 116 | def scene_fname(bundle_file): 117 | """Scene file name based on bundle number 118 | Inputs: bundle_file (string) - bundle file 
number, from 119 | rome16k.parse.bundle_files 120 | Outputs: scene_fname (string) - path to the scene file 121 | """ 122 | if not check_valid_name(bundle_file): 123 | print("ERROR: Specified bundle file does not exist: {}".format(bundle_files)) 124 | sys.exit(1) 125 | return 'scene.{}.pkl'.format(bundle_file) 126 | 127 | def triplets_name(bundle_file, lite=False): 128 | """Triplets file name based on bundle number, where the valid connected images are 129 | WARNING: Depracated 130 | Inputs: bundle_file (string) - bundle file number, from 131 | rome16k.parse.bundle_files 132 | Outputs: tuple_fname (string) - path to the triplets 133 | """ 134 | if not check_valid_name(bundle_file): 135 | print("ERROR: Specified bundle file does not exist: {}".format(bundle_files)) 136 | sys.exit(1) 137 | if lite: 138 | return 'triplets_lite.{}.pkl'.format(bundle_file) 139 | else: 140 | return 'triplets.{}.pkl'.format(bundle_file) 141 | 142 | def tuples_fname(bundle_file): 143 | """Tuples file name based on bundle number, where the valid connected images are 144 | Inputs: bundle_file (string) - bundle file number, from 145 | rome16k.parse.bundle_files 146 | Outputs: tuple_fname (string) - path to the tuple file 147 | """ 148 | if not check_valid_name(bundle_file): 149 | print("ERROR: Specified bundle file does not exist: {}".format(bundle_files)) 150 | sys.exit(1) 151 | else: 152 | return 'tuples.{}.pkl'.format(bundle_file) 153 | 154 | # Main parsing functions 155 | def parse_sift_gzip(fname): 156 | """Parse the gzip file where the sift descriptors are in the Rome16K dataset 157 | Inputs: fname (string) - where the gzipped file is 158 | Outputs: flist (list of scenes.Feature) - Sift features from gzipped file 159 | """ 160 | with gzip.open(fname) as f: 161 | f_list = f.read().decode().split('\n')[:-1] 162 | n = (len(f_list)-1)//8 163 | meta = f_list[0] 164 | feature_list = [] 165 | for k in range(n): 166 | sift_ = [ [ float(z) for z in x.split(' ') if z != '' ] for x in 
f_list[(8*k+1):(8*k+9)] ] 167 | feature = scenes.Feature(0) # To fill in ID later 168 | feature.pos_uncal = np.array(sift_[0][:2]) 169 | feature.scale = np.array(sift_[0][2]) 170 | feature.orien = np.array(sift_[0][3]) 171 | feature.desc = np.array(sum(sift_[2:], sift_[1])) 172 | feature_list.append(feature) 173 | return feature_list 174 | 175 | def parse_bundle(bundle_file, top_dir, get_imsize=True, max_load=-1, verbose=False): 176 | """Parse bundle file from Rome16K dataset 177 | Inputs: 178 | - bundle_file (string) - bundle_file (string) - bundle file number, from 179 | rome16k.parse.bundle_files 180 | - top_dir (string) - location of Rome16K dataset (unzipped) 181 | - get_imsize (boolean, optional) - store image size (default True) (Make 182 | loading much slower) 183 | - max_load (int, optional) - maximum number of features to load (default -1) 184 | (If -1, load all of them) 185 | - verbose (boolean, optional) - print out everthing (default False) 186 | Outputs: scene (scenes.Scene) - loaded scene 187 | """ 188 | if verbose: 189 | myprint = lambda x: print(x) 190 | else: 191 | myprint = lambda x: 0 192 | bundle_dir = os.path.join(top_dir, 'bundle', 'components') 193 | txtname = os.path.join(bundle_dir, 'bundle.{}.txt'.format(bundle_file)) 194 | outname = os.path.join(bundle_dir, 'bundle.{}.out'.format(bundle_file)) 195 | # Load files 196 | with open(outname, 'r') as f: 197 | out_lines = [] 198 | for i, line in enumerate(f.readlines()): 199 | parsed_line = line[:-1].split(' ') 200 | if parsed_line[0] == '#': 201 | continue 202 | out_lines.append([ float(x) for x in parsed_line ]) 203 | 204 | with open(txtname, 'r') as list_file: 205 | txt_lines = list_file.readlines() 206 | # Load all SIFT features 207 | myprint("Getting feature lists...") 208 | feature_lists = [] 209 | imsize_list = [] 210 | for k, f in tqdm.tqdm(enumerate(txt_lines), total=len(txt_lines), disable=not verbose): 211 | if k == max_load: 212 | break 213 | parse = f[:-1].split(' ') 214 | fname = 
parse[0][len('images/'):-len('.jpg')] + ".key.gz" 215 | db_file = os.path.join(top_dir, 'db/{}'.format(fname)) 216 | if os.path.exists(db_file): 217 | feature_list = parse_sift_gzip(db_file) 218 | else: 219 | query_file = os.path.join(top_dir, 'query/{}'.format(fname)) 220 | feature_list = parse_sift_gzip(query_file) 221 | feature_lists.append(feature_list) 222 | if get_imsize: 223 | imsize_list.append(get_image_size(parse[0])) 224 | myprint("Done") 225 | 226 | meta = out_lines[0] 227 | num_cams = int(meta[0]) 228 | num_points = int(meta[1]) 229 | # Extract features 230 | myprint("Getting cameras...") 231 | cams = [] 232 | for i in range(num_cams): 233 | cam_lines = out_lines[(1+5*i):(1+5*(i+1))] 234 | cam = scenes.Camera(i) 235 | cam.focal = cam_lines[0][0] 236 | cam.k1 = cam_lines[0][1] 237 | cam.k2 = cam_lines[0][2] 238 | if get_imsize: 239 | cam.imsize = imsize_list[i] 240 | cam.rot = np.array(cam_lines[1:4]) 241 | cam.trans = np.array(cam_lines[4]) 242 | cam.features = [] 243 | err = (np.linalg.norm(np.dot(cam.rot.T, cam.rot)-np.eye(3))) 244 | if err > 1e-9: 245 | myprint((i,err)) 246 | cams.append(cam) 247 | myprint("Done") 248 | # Extract points/features 249 | myprint("Getting points and features...") 250 | points = [] 251 | features = [] 252 | start = 1+5*num_cams 253 | for i in range(num_points): 254 | lines = out_lines[(start+3*i):(start+3*(i+1))] 255 | # Construct point 256 | point = scenes.Point(i) 257 | point.pos = np.array(lines[0]) 258 | point.color = np.array(lines[1]) 259 | point.features = [] 260 | # Construct feature links 261 | cam_list = [ int(x) for x in lines[2][1::4] ] 262 | feat_list = [ int(x) for x in lines[2][2::4] ] 263 | for cam_id, feat_id in zip(cam_list, feat_list): 264 | # Create feature 265 | # # There was an recurring theme that came up in some of the files that 266 | # # They referened feature ids that simply didn't exist... 
I fixed this 267 | # # by just skipping them but I don't know why it happened and it is 268 | # # not documented online 269 | # if feat_id > len(feature_lists[cam_id]): 270 | # myprint('feat_id: {}'.format(feat_id)) 271 | # myprint('cam_id: {} (len: {})'.format(cam_id, len(feature_lists[cam_id]))) 272 | # continue 273 | feature = feature_lists[cam_id][feat_id] 274 | feature.cam = cams[cam_id] 275 | feature.point = point 276 | feature.id = len(features) 277 | # Connect feature to camera and point 278 | cams[cam_id].features.append(feature) 279 | point.features.append(feature) 280 | features.append(feature) 281 | points.append(point) 282 | myprint("Done") 283 | 284 | myprint("Centering points for each camera...") 285 | for c in cams: 286 | c.center_points() 287 | myprint("Done") 288 | 289 | # Create save 290 | scene = scenes.Scene() 291 | scene.cams = cams 292 | scene.points = points 293 | scene.features = features 294 | 295 | return scene 296 | 297 | # Main outward facing functions 298 | def save_scene(scene, filename, verbose=False): 299 | """Store scene object into pickle file 300 | Input: 301 | - scene (scenes.Scene) - scene object to save out 302 | - filename (string) - file location to save to 303 | - verbose (boolean, optional) - print everything out (default False) 304 | Output: None 305 | """ 306 | scene_dict = scene.save_out_dict() 307 | if verbose: 308 | print("Saving scene...") 309 | with open(filename, 'wb') as f: 310 | pickle.dump(scene_dict, f, protocol=pickle.HIGHEST_PROTOCOL) 311 | if verbose: 312 | print("Done") 313 | 314 | def load_scene(filename, verbose=False): 315 | """Load scene object using pickle into a Scene object 316 | Input: 317 | - filename (string) - file location to load from 318 | - verbose (boolean, optional) - print everything out (default False) 319 | Output: scene (scenes.Scene) - loaded scene object 320 | """ 321 | scene = scenes.Scene(0) 322 | if verbose: 323 | print("Loading pickle file...") 324 | with open(filename, 'rb') as 
f: 325 | scene_dict = pickle.load(f) 326 | if verbose: 327 | print("Done") 328 | print("Parsing pickle file...") 329 | scene.load_dict(scene_dict) 330 | if verbose: 331 | print("Done") 332 | return scene 333 | 334 | 335 | 336 | 337 | -------------------------------------------------------------------------------- /data_util/parent_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import os 4 | import glob 5 | import datetime 6 | import tqdm 7 | 8 | import tensorflow as tf 9 | 10 | import sim_graphs 11 | from data_util import tf_helpers 12 | 13 | class GraphSimDataset(object): 14 | """Dataset for syntehtic cycle consistency graphs 15 | Generates synthetic graphs from sim_graphs.py then stores/loads them to/from 16 | tfrecords. Generates the initial embedings using random ground truth features 17 | with added noise for each image. Parent to all other Dataset classes. 18 | """ 19 | MAX_IDX=7000 20 | 21 | def __init__(self, opts, params): 22 | """ 23 | Inputs: 24 | - opts (options) - object with all relevant options stored 25 | - params (DatasetParams) - object with all dataset parameters stored 26 | Outputs: GraphSimDataset 27 | """ 28 | self.opts = opts 29 | self.dataset_params = params 30 | self.data_dir = params.data_dir 31 | self.dtype = params.dtype 32 | self.n_views = np.random.randint(params.views[0], params.views[1]+1) 33 | self.n_pts = np.random.randint(params.points[0], params.points[1]+1) 34 | d = self.n_pts*self.n_views 35 | e = params.descriptor_dim 36 | p = params.points[-1] 37 | f = opts.final_embedding_dim 38 | self.features = { 39 | 'InitEmbeddings': 40 | tf_helpers.TensorFeature( 41 | key='InitEmbeddings', 42 | shape=[d, e], 43 | dtype=self.dtype, 44 | description='Initial embeddings for optimization'), 45 | 'AdjMat': 46 | tf_helpers.TensorFeature( 47 | key='AdjMat', 48 | shape=[d, d], 49 | dtype=self.dtype, 50 | description='Adjacency matrix for 
graph'), 51 | 'Degrees': 52 | tf_helpers.TensorFeature( 53 | key='Degrees', 54 | shape=[d, d], 55 | dtype=self.dtype, 56 | description='Degree matrix for graph'), 57 | 'Laplacian': 58 | tf_helpers.TensorFeature( 59 | key='Laplacian', 60 | shape=[d, d], 61 | dtype=self.dtype, 62 | description='Alternate Laplacian matrix for graph'), 63 | 'Mask': 64 | tf_helpers.TensorFeature( 65 | key='Mask', 66 | shape=[d, d], 67 | dtype=self.dtype, 68 | description='Mask for valid values of matrix'), 69 | 'MaskOffset': 70 | tf_helpers.TensorFeature( 71 | key='MaskOffset', 72 | shape=[d, d], 73 | dtype=self.dtype, 74 | description='Mask offset for loss'), 75 | 'TrueEmbedding': 76 | tf_helpers.TensorFeature( 77 | key='TrueEmbedding', 78 | shape=[d, p], 79 | dtype=self.dtype, 80 | description='True values for the low dimensional embedding'), 81 | 'NumViews': 82 | tf_helpers.Int64Feature( 83 | key='NumViews', 84 | description='Number of views used in this example'), 85 | 'NumPoints': 86 | tf_helpers.Int64Feature( 87 | key='NumPoints', 88 | description='Number of points used in this example'), 89 | } 90 | 91 | def process_features(self, loaded_features): 92 | """Augmentation after generation 93 | Input: keys (list of strings) - keys and actual values for the dataset 94 | Output: sample (dict) - sample for this dataset 95 | """ 96 | features = {} 97 | for k, feat in self.features.items(): 98 | features[k] = feat.get_feature_write(loaded_features[k]) 99 | return features 100 | 101 | def augment(self, keys, values): 102 | """Augmentation after generation 103 | Input: 104 | - keys (list of strings) - keys for the dataset values 105 | - values (list of np.array) - actual values for the dataset 106 | Output: 107 | - keys (list of strings) - keys for the dataset values, augmented 108 | - values (list of np.array) - actual values for the dataset, augmented 109 | """ 110 | return keys, values 111 | 112 | def gen_sample(self): 113 | """Return a single sample generated for this dataset 114 | 
Input: None 115 | Output: sample (dict) - sample for this dataset 116 | """ 117 | # Pose graph and related objects 118 | params = self.dataset_params 119 | pose_graph = sim_graphs.PoseGraph(self.dataset_params, 120 | n_pts=self.n_pts, 121 | n_views=self.n_views) 122 | sz = (pose_graph.n_pts, pose_graph.n_pts) 123 | sz2 = (pose_graph.n_views, pose_graph.n_views) 124 | if params.sparse: 125 | mask = np.kron(pose_graph.adj_mat,np.ones(sz)) 126 | else: 127 | mask = np.kron(np.ones(sz2)-np.eye(sz2[0]),np.ones(sz)) 128 | 129 | perms_ = [ np.eye(pose_graph.n_pts)[:,pose_graph.get_perm(i)] 130 | for i in range(pose_graph.n_views) ] 131 | # Embedding objects 132 | TrueEmbedding = np.concatenate(perms_, 0) 133 | InitEmbeddings = np.concatenate([ pose_graph.get_proj(i).d 134 | for i in range(pose_graph.n_views) ], 0) 135 | 136 | # Graph objects 137 | if not params.soft_edges: 138 | if params.descriptor_noise_var == 0: 139 | AdjMat = np.dot(TrueEmbedding,TrueEmbedding.T) 140 | if params.sparse: 141 | AdjMat = AdjMat * mask 142 | else: 143 | AdjMat = AdjMat - np.eye(len(AdjMat)) 144 | Degrees = np.diag(np.sum(AdjMat,0)) 145 | else: 146 | if params.sparse and params.descriptor_noise_var > 0: 147 | AdjMat = pose_graph.get_feature_matching_mat() 148 | Degrees = np.diag(np.sum(AdjMat,0)) 149 | 150 | # Laplacian objects 151 | Ahat = AdjMat + np.eye(*AdjMat.shape) 152 | Dhat_invsqrt = np.diag(1/np.sqrt(np.sum(Ahat,0))) 153 | Laplacian = np.dot(Dhat_invsqrt, np.dot(Ahat, Dhat_invsqrt)) 154 | 155 | # Mask objects 156 | neg_offset = np.kron(np.eye(sz2[0]),np.ones(sz)-np.eye(sz[0])) 157 | Mask = AdjMat - neg_offset 158 | MaskOffset = neg_offset 159 | return { 160 | 'InitEmbeddings': InitEmbeddings.astype(self.dtype), 161 | 'AdjMat': AdjMat.astype(self.dtype), 162 | 'Degrees': Degrees.astype(self.dtype), 163 | 'Laplacian': Laplacian.astype(self.dtype), 164 | 'Mask': Mask.astype(self.dtype), 165 | 'MaskOffset': MaskOffset.astype(self.dtype), 166 | 'TrueEmbedding': 
TrueEmbedding.astype(self.dtype), 167 | 'NumViews': pose_graph.n_views, 168 | 'NumPoints': pose_graph.n_pts, 169 | } 170 | 171 | def get_placeholders(self): 172 | """Writes data into a TF record file 173 | Input: None 174 | Output: sample (dict) - placeholders for all relevant fields 175 | """ 176 | return { k:v.get_placeholder() for k, v in self.features.items() } 177 | 178 | def convert_dataset(self, out_dir, mode): 179 | """Writes data into a TF record file 180 | Input: 181 | - out_dir (string) - directory to store tf files in 182 | - mode (string) - train or test to load appropriate dataset 183 | Output: None 184 | Calls gen_sample many times to generate a tfrecord file for tf.Dataset 185 | to load from 186 | """ 187 | params = self.dataset_params 188 | fname = '{}-{:02d}.tfrecords' 189 | outfile = lambda idx: os.path.join(out_dir, fname.format(mode, idx)) 190 | if not os.path.isdir(out_dir): 191 | os.makedirs(out_dir) 192 | 193 | print('Writing dataset to {}/{}'.format(out_dir, mode)) 194 | writer = None 195 | record_idx = 0 196 | file_idx = self.MAX_IDX + 1 197 | for index in tqdm.tqdm(range(params.sizes[mode])): 198 | if file_idx > self.MAX_IDX: 199 | file_idx = 0 200 | if writer: writer.close() 201 | writer = tf.python_io.TFRecordWriter(outfile(record_idx)) 202 | record_idx += 1 203 | loaded_features = self.gen_sample() 204 | features = self.process_features(loaded_features) 205 | example = tf.train.Example(features=tf.train.Features(feature=features)) 206 | writer.write(example.SerializeToString()) 207 | file_idx += 1 208 | 209 | if writer: writer.close() 210 | # And save out a file with the creation time for versioning 211 | timestamp_file = '{}_timestamp.txt'.format(mode) 212 | with open(os.path.join(out_dir, timestamp_file), 'w') as date_file: 213 | date_file.write('TFrecord created {}'.format(str(datetime.datetime.now()))) 214 | 215 | def create_np_dataset(self, out_dir, num_entries): 216 | """Create npz files to store dataset 217 | Input: 218 | - 
out_dir (string) - directory to store npz files in 219 | - num_entries (int) - number of entries to generate for npz files 220 | Output: None 221 | Save out npz files storing samples into out_dir 222 | """ 223 | fname = 'np_test-{:04d}.npz' 224 | outfile = lambda idx: os.path.join(out_dir, fname.format(idx)) 225 | print('Writing dataset to {}'.format(out_dir)) 226 | record_idx = 0 227 | for index in tqdm.tqdm(range(num_entries)): 228 | features = self.gen_sample() 229 | np.savez(outfile(index), **features) 230 | 231 | # And save out a file with the creation time for versioning 232 | timestamp_file = 'np_test_timestamp.txt' 233 | with open(os.path.join(out_dir, timestamp_file), 'w') as date_file: 234 | date_file.write('Numpy Dataset created {}'.format(str(datetime.datetime.now()))) 235 | 236 | def gen_batch(self, mode): 237 | """Return batch generated for this dataset 238 | Input: mode (string) - train or test to load appropriate dataset 239 | Output: sample (dict) - sample for this dataset, with a batch dimension 240 | """ 241 | params = self.dataset_params 242 | opts = self.opts 243 | assert mode in params.sizes, "Mode {} not supported".format(mode) 244 | batch_size = opts.batch_size 245 | keys = sorted(list(self.features.keys())) 246 | shapes = [ self.features[k].shape for k in keys ] 247 | types = [ self.features[k].dtype for k in keys ] 248 | tfshapes = [ tuple([batch_size] + s) for s in shapes ] 249 | tftypes = [ tf.as_dtype(t) for t in types ] 250 | def generator_fn(): 251 | while True: 252 | vals = [ np.zeros([batch_size] + s, types[i]) 253 | for i, s in enumerate(shapes) ] 254 | for b in range(batch_size): 255 | s = self.gen_sample() 256 | for i, k in enumerate(keys): 257 | vals[i][b] = s[k] 258 | yield tuple(vals) 259 | dataset = tf.data.Dataset.from_generator(generator_fn, 260 | tuple(tftypes), 261 | tuple(tfshapes)) 262 | batches = dataset.prefetch(2 * batch_size) 263 | 264 | iterator = batches.make_one_shot_iterator() 265 | values = iterator.get_next() 
266 | return dict(zip(keys, values)) 267 | 268 | def load_batch(self, mode): 269 | """Return batch loaded from this dataset 270 | Input: mode (string) - train or test to load appropriate dataset 271 | Output: iterator (tf.data.Iterator) - iterator for train/test data 272 | """ 273 | params = self.dataset_params 274 | opts = self.opts 275 | assert mode in params.sizes, "Mode {} not supported".format(mode) 276 | data_source_name = mode + '-[0-9][0-9].tfrecords' 277 | data_sources = glob.glob(os.path.join(self.data_dir, mode, data_source_name)) 278 | if opts.shuffle_data and mode != 'test': 279 | np.random.shuffle(data_sources) # Added to help the shuffle 280 | # Build dataset provider 281 | keys_to_features = { k: v.get_feature_read() 282 | for k, v in self.features.items() } 283 | items_to_descriptions = { k: v.description 284 | for k, v in self.features.items() } 285 | def parser_op(record): 286 | example = tf.parse_single_example(record, keys_to_features) 287 | return { k : v.tensors_to_item(example) for k, v in self.features.items() } 288 | dataset = tf.data.TFRecordDataset(data_sources) 289 | dataset = dataset.map(parser_op) 290 | dataset = dataset.repeat(None) 291 | if opts.shuffle_data and mode != 'test': 292 | dataset = dataset.shuffle(buffer_size=5*opts.batch_size) 293 | if opts.batch_size > 1: 294 | dataset = dataset.batch(opts.batch_size) 295 | dataset = dataset.prefetch(buffer_size=opts.batch_size) 296 | 297 | iterator = dataset.make_one_shot_iterator() 298 | sample = iterator.get_next() 299 | return sample 300 | 301 | 302 | -------------------------------------------------------------------------------- /data_util/real_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg as la 3 | import os 4 | import sys 5 | import glob 6 | import datetime 7 | import tqdm 8 | import pickle 9 | 10 | import tensorflow as tf 11 | 12 | import sim_graphs 13 | from data_util import 
parent_dataset  # completes the `from data_util import` statement that ends the previous line
from data_util import tf_helpers
from data_util.rome16k import parse


def _scene_seed(scene_name):
    """Return a deterministic 32-bit numpy RNG seed for a scene file name.

    The builtin ``hash`` of a ``str`` is salted per interpreter process
    (PYTHONHASHSEED), so seeding with ``hash(scene_name) % 2**32`` produced a
    different tuple subset and different permutations on every run. CRC32 is
    stable across runs and platforms.
    """
    import zlib
    return zlib.crc32(scene_name.encode('utf-8')) & 0xffffffff


def _select_view_features(scene, tupl, n):
    """Pick `n` scene points observed by every camera in `tupl`.

    Inputs: scene - scene object returned by rome16k.parse.load_scene
            tupl (sequence of ints) - camera ids
            n (int) - number of points to select
    Outputs: features (list of lists) - one list per camera id in `tupl`, each
             holding that camera's feature for the n selected points
    Consumes the global numpy RNG (one permutation draw).
    Raises: ValueError if the cameras share fewer than n points (previously
            this surfaced later as an IndexError while building the graph).
    """
    cam_pts = lambda i: set(f.point for f in scene.cams[i].features)
    point_set = set.intersection(*[cam_pts(t) for t in tupl])
    if len(point_set) < n:
        raise ValueError('Cameras {} share only {} points, need {}'.format(
            list(tupl), len(point_set), n))
    feat_perm = np.random.permutation(len(point_set))[:n]
    features = []
    for camid in tupl:
        fset = [[f for f in p.features if f.cam.id == camid][0] for p in point_set]
        # NOTE(review): sorting by feature id keeps index x referring to the
        # same 3D point in every camera only if rome16k.parse numbers features
        # in point order -- confirm against parse.py.
        fset = sorted(fset, key=lambda x: x.id)
        features.append([fset[x] for x in feat_perm])
    return features


def _view_permutations(per_view):
    """Draw an independent random permutation for each view.

    Inputs: per_view (list of sequences) - one sequence per view
    Outputs: perm (np.ndarray) - block-diagonal permutation matrix over all views
             perm_mats (list of np.ndarray) - the per-view permutation matrices
    """
    rids = [np.random.permutation(len(items)) for items in per_view]
    perm_mats = [np.eye(len(r))[r] for r in rids]
    return la.block_diag(*perm_mats), perm_mats


def _normalize_rows(mat):
    """Scale every row of `mat` to unit Euclidean norm."""
    norms = np.sqrt(np.sum(mat**2, 1).reshape(-1, 1))
    return mat / norms


def _knn_graph(ndescs, v, n, k):
    """Build the k-nearest-neighbour correspondence graph between views.

    Inputs: ndescs (np.ndarray) - (v*n, d) unit-norm descriptors, rows grouped
            view by view (n rows per view)
            v (int) - number of views
            n (int) - points per view
            k (int) - neighbours kept per point for each pair of views
    Outputs: AdjMat, Degrees, Laplacian (np.ndarray, each (v*n, v*n));
             AdjMat has no edges inside a view
    """
    sims = np.dot(ndescs, ndescs.T)
    # Min-max rescale the cosine similarities to [0, 1]
    smin, smax = sims.min(), sims.max()
    L = (sims - smin) / (smax - smin)
    # In every (view i, view j) block keep only the k largest entries per row.
    # max(n - k, 0) equals the former `[:-k]` for 1 <= k <= n, and unlike
    # `[:-0]` it also prunes correctly when k == 0.
    n_drop = max(n - k, 0)
    for i in range(v):
        for j in range(v):
            block = L[n*i:n*(i+1), n*j:n*(j+1)]  # basic slice: writes go into L
            for u in range(n):
                block[u, block[u].argsort()[:n_drop]] = 0
    mask = np.kron(np.ones((v, v)) - np.eye(v), np.ones((n, n)))
    AdjMat = np.maximum(L, L.T) * mask
    Degrees = np.diag(np.sum(AdjMat, 0))
    Ahat = AdjMat + np.eye(*AdjMat.shape)
    Dhat_invsqrt = np.diag(1 / np.sqrt(np.sum(Ahat, 0)))
    Laplacian = np.dot(Dhat_invsqrt, np.dot(Ahat, Dhat_invsqrt))
    return AdjMat, Degrees, Laplacian


class Rome16KTupleDataset(parent_dataset.GraphSimDataset):
    """Abstract base class for Rome16K cycle consistency graphs
    Generates graphs from Rome16K dataset and then stores/loads them to/from
    tfrecords. Subclasses implement gen_sample_from_tuple; samples cannot be
    generated on the fly (gen_sample exits).
    """

    def __init__(self, opts, params, tuple_size=3):
        parent_dataset.GraphSimDataset.__init__(self, opts, params)
        self.rome16k_dir = opts.rome16k_dir
        self.tuple_size = tuple_size
        # The parent's mask features are not produced for Rome16K samples
        del self.features['Mask']
        del self.features['MaskOffset']
        # Per bundle, count the tuples get_tuples will return. Presumably x[i]
        # is the number of tuples of size i + 2 in that bundle -- TODO confirm.
        for mode in ('train', 'test'):
            self.dataset_params.sizes[mode] = sum(
                min(int(x[1] * 1.5), x[tuple_size - 2])
                for x in parse.bundle_file_info[mode].values())

    def gen_sample(self):
        print("ERROR: Cannot generate sample - need to load data")
        sys.exit(1)

    def gen_sample_from_tuple(self, scene, tupl):
        print("ERROR: Not implemented in abstract base class")
        sys.exit(1)

    def scene_fname(self, bundle_file):
        """Scene file name based on bundle number
        Inputs: bundle_file (string) - bundle file number, from
                rome16k.parse.bundle_files
        Outputs: scene_fname (string) - path to the scene file
        """
        return os.path.join(self.rome16k_dir, 'scenes', parse.scene_fname(bundle_file))

    def tuples_fname(self, bundle_file):
        """Tuples file name based on bundle number
        Inputs: bundle_file (string) - bundle file number, from
                rome16k.parse.bundle_files
        Outputs: tuple_fname (string) - path to the tuple file
        A tuples file store the n-tuples of views in the scene that have overlapping views
        """
        return os.path.join(self.rome16k_dir, 'scenes', parse.tuples_fname(bundle_file))

    def get_tuples(self, bundle_file):
        """Load the view tuples of size self.tuple_size for a bundle
        Inputs: bundle_file (string) - bundle file number, from
                rome16k.parse.bundle_files
        Outputs: tuples (list of lists) - camera-id tuples. For tuple_size > 3
                 at most int(1.5 * #triplets) tuples are randomly sampled
                 (uses the global numpy RNG).
        """
        # pickle: only load tuple files produced by this project's own tools
        with open(self.tuples_fname(bundle_file), 'rb') as f:
            tuples_all = pickle.load(f)
        if self.tuple_size == 3:
            return tuples_all[1]
        tuples_sel = tuples_all[self.tuple_size - 2]
        n_select = int(1.5 * len(tuples_all[1]))
        if n_select > len(tuples_sel):
            return tuples_sel
        tuples = np.array(tuples_sel)
        tuples_idx = np.random.choice(np.arange(len(tuples)),
                                      size=n_select, replace=False)
        # np.sort on the 2-D selection also sorts camera ids inside each tuple
        return np.sort(tuples[np.sort(tuples_idx)]).tolist()

    def _iter_scene_tuples(self, mode):
        """Yield (scene, tuple) pairs for every bundle of `mode`.
        The global numpy RNG is reseeded from the scene file name before the
        scene's tuples are drawn, so generation is reproducible per scene.
        """
        for bundle_file in parse.bundle_file_info[mode]:
            scene_name = self.scene_fname(bundle_file)
            np.random.seed(_scene_seed(scene_name))
            scene = parse.load_scene(scene_name)
            for tupl in self.get_tuples(bundle_file):
                yield scene, tupl

    def convert_dataset(self, out_dir, mode):
        """Write all samples of `mode` to tfrecords files in out_dir
        Inputs: out_dir (string) - output directory (created if missing)
                mode (string) - 'train' or 'test'
        Output: None. Starts a new file every MAX_IDX + 1 records, then writes
                a timestamp file for versioning.
        """
        params = self.dataset_params
        fname = '{}-{:02d}.tfrecords'
        outfile = lambda idx: os.path.join(out_dir, fname.format(mode, idx))
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        print('Writing dataset to {}/{}'.format(out_dir, mode))
        writer = None
        record_idx = 0
        file_idx = self.MAX_IDX + 1  # forces a new file on the first sample
        pbar = tqdm.tqdm(total=params.sizes[mode])
        try:
            for scene, tupl in self._iter_scene_tuples(mode):
                if file_idx > self.MAX_IDX:
                    file_idx = 0
                    if writer:
                        writer.close()
                    writer = tf.python_io.TFRecordWriter(outfile(record_idx))
                    record_idx += 1
                loaded_features = self.gen_sample_from_tuple(scene, tupl)
                features = self.process_features(loaded_features)
                example = tf.train.Example(features=tf.train.Features(feature=features))
                writer.write(example.SerializeToString())
                file_idx += 1
                pbar.update()
        finally:
            # Close the current shard and the progress bar even on failure
            if writer:
                writer.close()
            pbar.close()

        # And save out a file with the creation time for versioning
        timestamp_file = '{}_timestamp.txt'.format(mode)
        with open(os.path.join(out_dir, timestamp_file), 'w') as date_file:
            date_file.write('TFrecord created {}'.format(str(datetime.datetime.now())))

    def create_np_dataset(self, out_dir, num_entries):
        """Write every test sample to out_dir as numbered npz files
        Inputs: out_dir (string) - directory to store npz files in
                num_entries (int) - ignored; all test tuples are written
        Output: None
        """
        del num_entries
        fname = 'np_test-{:04d}.npz'
        outfile = lambda idx: os.path.join(out_dir, fname.format(idx))
        print('Writing dataset to {}'.format(out_dir))
        pbar = tqdm.tqdm(total=self.dataset_params.sizes['test'])
        try:
            for index, (scene, tupl) in enumerate(self._iter_scene_tuples('test')):
                features = self.gen_sample_from_tuple(scene, tupl)
                np.savez(outfile(index), **features)
                pbar.update()
        finally:
            pbar.close()

        # And save out a file with the creation time for versioning
        timestamp_file = 'np_test_timestamp.txt'
        with open(os.path.join(out_dir, timestamp_file), 'w') as date_file:
            date_file.write('Numpy Dataset created {}'.format(str(datetime.datetime.now())))

    def gen_batch(self, mode):
        """Return a batch built from gen_sample (which exits for this class)
        Input: mode (string) - train or test
        Output: sample (dict) - batched tensors keyed by feature name
        """
        params = self.dataset_params
        opts = self.opts
        assert mode in params.sizes, "Mode {} not supported".format(mode)
        batch_size = opts.batch_size
        keys = sorted(list(self.features.keys()))
        shapes = [self.features[k].shape for k in keys]
        types = [self.features[k].dtype for k in keys]
        tfshapes = [tuple([batch_size] + s) for s in shapes]
        tftypes = [tf.as_dtype(t) for t in types]

        def generator_fn():
            while True:
                vals = [np.zeros([batch_size] + s, types[i])
                        for i, s in enumerate(shapes)]
                for b in range(batch_size):
                    s = self.gen_sample()
                    for i, k in enumerate(keys):
                        vals[i][b] = s[k]
                yield tuple(vals)

        dataset = tf.data.Dataset.from_generator(generator_fn,
                                                 tuple(tftypes),
                                                 tuple(tfshapes))
        batches = dataset.prefetch(2 * batch_size)

        iterator = batches.make_one_shot_iterator()
        values = iterator.get_next()
        return dict(zip(keys, values))


class KNNRome16KDataset(Rome16KTupleDataset):
    """Rome16K cycle consistency graphs over view triplets
    Generates graphs from Rome16K dataset and then stores/loads them to/from
    tfrecords. Builds graphs using a simple k-nearest neighbor scheme. Initial
    embeddings are the unit-normalized SIFT descriptors.
    """
    def __init__(self, opts, params):
        super(KNNRome16KDataset, self).__init__(opts, params, tuple_size=3)

    def gen_sample_from_tuple(self, scene, tupl):
        """Build one graph sample from a tuple of camera ids
        Inputs: scene - loaded Rome16K scene
                tupl (sequence of ints) - camera ids
        Outputs: sample (dict) - graph matrices, embeddings and sizes
        """
        k = self.dataset_params.knn
        n = self.dataset_params.points[-1]
        v = self.dataset_params.views[-1]
        features = _select_view_features(scene, tupl, n)
        descs_ = [np.array([f.desc for f in feats]) for feats in features]
        perm, perm_mats = _view_permutations(descs_)
        descs = np.dot(perm, np.concatenate(descs_))
        ndescs = _normalize_rows(descs)
        AdjMat, Degrees, Laplacian = _knn_graph(ndescs, v, n, k)
        # The normalized descriptors are the initial embeddings. The former
        # np.concatenate(ndescs, axis=1) iterated over 1-D rows and raised
        # an AxisError for every sample.
        InitEmbeddings = ndescs
        TrueEmbedding = np.concatenate(perm_mats, axis=0)

        return {
            'InitEmbeddings': InitEmbeddings.astype(self.dtype),
            'AdjMat': AdjMat.astype(self.dtype),
            'Degrees': Degrees.astype(self.dtype),
            'Laplacian': Laplacian.astype(self.dtype),
            'TrueEmbedding': TrueEmbedding.astype(self.dtype),
            'NumViews': v,
            'NumPoints': n,
        }


class GeomKNNRome16KDataset(Rome16KTupleDataset):
    """Rome16K cycle consistency graphs with camera poses
    Generates graphs from Rome16K dataset and then stores/loads them to/from
    tfrecords. Builds graphs using a simple k-nearest neighbor scheme, and also
    stores the pose information for each view. Initial embeddings are SIFT
    features as well as the x,y position, log-scale, and orientation.
    """
    def __init__(self, opts, params):
        super(GeomKNNRome16KDataset, self).__init__(opts, params, tuple_size=params.views[-1])
        d = self.n_pts * self.n_views
        e = params.descriptor_dim
        self.features.update({
            'InitEmbeddings':
                tf_helpers.TensorFeature(
                    key='InitEmbeddings',
                    shape=[d, e + 2 + 1 + 1],  # descriptor, x/y, log-scale, orientation
                    dtype=self.dtype,
                    description='Initial embeddings for optimization'),
            'Rotations':
                tf_helpers.TensorFeature(
                    key='Rotations',
                    shape=[self.tuple_size, 3, 3],
                    dtype=self.dtype,
                    description='Rotation matrix for each view'),
            'Translations':
                tf_helpers.TensorFeature(
                    key='Translations',
                    shape=[self.tuple_size, 3],
                    dtype=self.dtype,
                    description='Translation vector for each view'),
        })

    def build_mask(self):
        """Tensor that is 0 on within-view blocks and 1 elsewhere"""
        p = self.n_pts
        v = self.n_views
        return tf.convert_to_tensor(1 - np.kron(np.eye(v), np.ones((p, p))))

    def gen_sample_from_tuple(self, scene, tupl):
        """Build one graph sample, including camera poses, from a camera tuple
        Inputs: scene - loaded Rome16K scene
                tupl (sequence of ints) - camera ids
        Outputs: sample (dict) - graph matrices, embeddings, poses and sizes
        """
        k = self.dataset_params.knn
        n = self.dataset_params.points[-1]
        v = self.dataset_params.views[-1]
        features = _select_view_features(scene, tupl, n)
        xy_pos_ = [np.array([f.pos for f in feats]) for feats in features]
        scale_ = [np.array([f.scale for f in feats]) for feats in features]
        orien_ = [np.array([f.orien for f in feats]) for feats in features]
        descs_ = [np.array([f.desc for f in feats]) for feats in features]
        # Apply the same random per-view permutation to every attribute
        perm, perm_mats = _view_permutations(descs_)
        descs = np.dot(perm, np.concatenate(descs_))
        xy_pos = np.dot(perm, np.concatenate(xy_pos_))
        # We have to manually normalize these values as they are much larger than the others
        logscale = np.dot(perm, np.log(np.concatenate(scale_)) - 1.5).reshape(-1, 1)
        orien = np.dot(perm, np.concatenate(orien_)).reshape(-1, 1) / np.pi
        ndescs = _normalize_rows(descs)
        AdjMat, Degrees, Laplacian = _knn_graph(ndescs, v, n, k)

        InitEmbeddings = np.concatenate([ndescs, xy_pos, logscale, orien], axis=1)
        TrueEmbedding = np.concatenate(perm_mats, axis=0)
        # R^T and -R^T t per view; NOTE(review): under the x_cam = R x + t
        # convention these are the camera-to-world rotation and camera centre.
        Rotations = np.stack([scene.cams[i].rot.T for i in tupl], axis=0)
        Translations = np.stack([-np.dot(scene.cams[i].rot.T, scene.cams[i].trans)
                                 for i in tupl], axis=0)

        return {
            'InitEmbeddings': InitEmbeddings.astype(self.dtype),
            'AdjMat': AdjMat.astype(self.dtype),
            'Degrees': Degrees.astype(self.dtype),
            'Laplacian': Laplacian.astype(self.dtype),
            'TrueEmbedding': TrueEmbedding.astype(self.dtype),
            # Cast like every other feature; the TensorFeatures declare self.dtype
            'Rotations': Rotations.astype(self.dtype),
            'Translations': Translations.astype(self.dtype),
            'NumViews': v,
            'NumPoints': n,
        }
344 | 345 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Get all options for training network 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import os 10 | import sys 11 | import argparse 12 | import collections 13 | import types 14 | import yaml 15 | import re 16 | 17 | import myutils 18 | 19 | class DatasetParams(argparse.Namespace): 20 | """Stores information about dataset""" 21 | def __init__(self, opts): 22 | super(DatasetParams, self).__init__() 23 | self.data_dir='{}/{}'.format(opts.datasets_dir, opts.dataset) 24 | self.sizes={ 'train': 40000, 'test': 3000 } 25 | self.fixed_size=True 26 | self.views=[3] 27 | self.points=[25] 28 | self.points_scale=1 29 | self.knn=4 30 | self.scale=3 31 | self.sparse=False 32 | self.soft_edges=False 33 | self.descriptor_dim=12 34 | self.descriptor_var=1.0 35 | self.descriptor_noise_var=0 36 | self.noise_level=0.1 37 | self.num_repeats=1 38 | self.num_outliers=0 39 | self.dtype='float32' 40 | 41 | # All types of dataset we have considered 42 | dataset_choices = [ 43 | 'synth_3view', 'synth_small', 'synth_4view', 'synth_5view', 'synth_6view', 44 | 'noise_3view', 45 | 'noise_gauss', 'noise_symgauss', 46 | 'noise_pairwise', 'noise_pairwise3', 'noise_pairwise5', 47 | 'noise_pairwise3view5', 'noise_pairwise3view6', 'noise_pairwise5view5', 48 | 'noise_largepairwise3', 'noise_largepairwise5', 49 | 'synth_pts50', 'synth_pts100', 50 | 'noise_outlier1', 'noise_outlier2', 'noise_outlier4', 'noise_outlier8', 51 | 'rome16kknn0', 52 | 'rome16kgeom0', 'rome16kgeom4view0', 53 | ] 54 | 55 | class ArchParams(argparse.Namespace): 56 | """Stores information about network architecture""" 57 | def __init__(self, opts): 58 | super(ArchParams, self).__init__() 59 | self.layer_lens = [ 32, 64 ] 
60 | self.activ = opts.activation_type 61 | self.attn_lens = [] 62 | self.skip_layers = [] 63 | self.start_normed = 1 64 | self.group_size = 32 65 | self.normalize_emb = True 66 | self.sparse = False 67 | 68 | # All types of architectures we have considered 69 | arch_choices = [ 70 | 'vanilla', 'vanilla0', 'vanilla1', 71 | 'skip', 'skip0', 'skip1', 72 | 'longskip0', 'longskip1', 73 | 'normedskip0', 'normedskip1', 'normedskip2', 'normedskip3', 74 | 'attn0', 'attn1', 'attn2', 75 | 'spattn0', 'spattn1', 'spattn2', 76 | ] 77 | 78 | activation_types = ['relu','leakyrelu','tanh', 'elu'] 79 | loss_types = [ 'l2', 'bce', 'l1', 'l1l2' ] 80 | optimizer_types = ['sgd','adam','adadelta','momentum','adamw'] 81 | lr_decay_types = ['exponential','fixed','polynomial'] 82 | 83 | 84 | def get_opts(): 85 | """Parse arguments from command line and get all options for training. 86 | Inputs: None 87 | Outputs: opts (options) - object with all relevant options stored 88 | Also saves out options in yaml file in the save_dir directory 89 | """ 90 | parser = argparse.ArgumentParser(description='Train motion estimator') 91 | # Directory and dataset options 92 | parser.add_argument('--save_dir', 93 | default=None, 94 | help='Directory to save out logs and checkpoints') 95 | parser.add_argument('--checkpoint_start_dir', 96 | default=None, 97 | help='Place to load from if not loading from save_dir') 98 | parser.add_argument('--data_dir', 99 | default='/NAS/data/stephen/', 100 | help='Directory for saving/loading dataset') 101 | parser.add_argument('--rome16k_dir', 102 | default='/NAS/data/stephen/Rome16K', 103 | help='Directory for storing Rome16K dataset (Very specific)') 104 | # 'synth_noise1', 'synth_noise2' 105 | parser.add_argument('--dataset', 106 | default=dataset_choices[0], 107 | choices=dataset_choices, 108 | help='Choose which dataset to use') 109 | parser.add_argument('--datasets_dir', 110 | default='/NAS/data/stephen', 111 | help='Directory where all the datasets are') 112 | 
parser.add_argument('--load_data', 113 | default=True, 114 | type=myutils.str2bool, 115 | help='Load data or just generate it on the fly. ' 116 | 'Generating slower but you get infinite data.') 117 | parser.add_argument('--shuffle_data', 118 | default=True, 119 | type=myutils.str2bool, 120 | help='Shuffle the dataset or no?') 121 | 122 | # Architecture parameters 123 | parser.add_argument('--architecture', 124 | default=arch_choices[0], 125 | choices=arch_choices, 126 | help='Network architecture to use') 127 | parser.add_argument('--final_embedding_dim', 128 | default=12, 129 | type=int, 130 | help='Dimensionality of the output') 131 | parser.add_argument('--activation_type', 132 | default=activation_types[0], 133 | choices=activation_types, 134 | help='What type of activation to use') 135 | 136 | # Machine learning parameters 137 | parser.add_argument('--batch_size', 138 | default=32, 139 | type=int, 140 | help='Size for batches') 141 | parser.add_argument('--use_unsupervised_loss', 142 | default=False, 143 | type=myutils.str2bool, 144 | help='Use true adjacency or noisy one in loss') 145 | parser.add_argument('--use_clamping', 146 | default=False, 147 | type=myutils.str2bool, 148 | help='Use clamping to [0, 1] on the output similarities') 149 | parser.add_argument('--use_abs_value', 150 | default=False, 151 | type=myutils.str2bool, 152 | help='Use absolute value on the output similarities') 153 | parser.add_argument('--loss_type', 154 | default=loss_types[0], 155 | choices=loss_types, 156 | help='Loss function to use for training') 157 | parser.add_argument('--reconstruction_loss', 158 | default=1.0, 159 | type=float, 160 | help='Use true adjacency or noisy one in loss') 161 | parser.add_argument('--geometric_loss', 162 | default=-1, 163 | type=float, 164 | help='Weight to use on the geometric loss') 165 | parser.add_argument('--weight_decay', 166 | default=4e-5, 167 | type=float, 168 | help='Weight decay regularization') 169 | 
parser.add_argument('--weight_l1_decay', 170 | default=0, 171 | type=float, 172 | help='L1 weight decay regularization') 173 | parser.add_argument('--optimizer_type', 174 | default=optimizer_types[0], 175 | choices=optimizer_types, 176 | help='Optimizer type for adaptive learning methods') 177 | parser.add_argument('--learning_rate', 178 | default=1e-3, 179 | type=float, 180 | help='Learning rate for gradient descent') 181 | parser.add_argument('--momentum', 182 | default=0.6, 183 | type=float, 184 | help='Learning rate for gradient descent') 185 | parser.add_argument('--learning_rate_decay_type', 186 | default=lr_decay_types[0], 187 | choices=lr_decay_types, 188 | help='Learning rate decay policy') 189 | parser.add_argument('--min_learning_rate', 190 | default=1e-5, 191 | type=float, 192 | help='Minimum learning rate after decaying') 193 | parser.add_argument('--learning_rate_decay_rate', 194 | default=0.95, 195 | type=float, 196 | help='Learning rate decay rate') 197 | parser.add_argument('--learning_rate_continuous', 198 | default=False, 199 | type=myutils.str2bool, 200 | help='Number of epochs before learning rate decay') 201 | parser.add_argument('--learning_rate_decay_epochs', 202 | default=4, 203 | type=float, 204 | help='Number of epochs before learning rate decay') 205 | 206 | # Training options 207 | parser.add_argument('--train_time', 208 | default=-1, 209 | type=int, 210 | help='Time in minutes the training procedure runs') 211 | parser.add_argument('--num_epochs', 212 | default=-1, 213 | type=int, 214 | help='Number of epochs to run training') 215 | parser.add_argument('--test_freq', 216 | default=8, 217 | type=int, 218 | help='Minutes between running loss on test set') 219 | parser.add_argument('--test_freq_steps', 220 | default=0, 221 | type=int, 222 | help='Number of steps between running loss on test set') 223 | parser.add_argument('--num_runs', 224 | default=1, 225 | type=int, 226 | help='Number of times training runs (length determined ' 227 | 
'by run_time)') 228 | 229 | # Logging options 230 | parser.add_argument('--verbose', 231 | default=False, 232 | type=myutils.str2bool, 233 | help='Print out everything') 234 | parser.add_argument('--full_tensorboard', 235 | default=True, 236 | type=myutils.str2bool, 237 | help='Display everything on tensorboard?') 238 | parser.add_argument('--save_summaries_secs', 239 | default=120, 240 | type=int, 241 | help='How frequently in seconds we save training summaries') 242 | parser.add_argument('--save_interval_secs', 243 | default=600, 244 | type=int, 245 | help='Frequency in seconds to save model while training') 246 | parser.add_argument('--log_steps', 247 | default=5, 248 | type=int, 249 | help='How frequently we print training loss') 250 | 251 | # Debugging options 252 | parser.add_argument('--debug', 253 | default=False, 254 | type=myutils.str2bool, 255 | help='Run in debug mode') 256 | 257 | 258 | opts = parser.parse_args() 259 | 260 | # Get save directory default 261 | if opts.save_dir is None: 262 | save_idx = 0 263 | while os.path.exists('save/save-{:03d}'.format(save_idx)): 264 | save_idx += 1 265 | opts.save_dir = 'save/save-{:03d}'.format(save_idx) 266 | 267 | # Determine dataset 268 | if not opts.load_data and opts.dataset in [ 'rome16kknn0' ]: 269 | print("ERROR: Cannot generate samples on the fly for this dataset: {}".format(opts.dataset)) 270 | sys.exit(1) 271 | 272 | dataset_params = DatasetParams(opts) 273 | if opts.dataset == 'synth_3view': 274 | pass 275 | elif opts.dataset == 'noise_3view': 276 | dataset_params.noise_level = 0.2 277 | elif opts.dataset == 'synth_small': 278 | dataset_params.sizes={ 'train': 400, 'test': 300 } 279 | elif opts.dataset == 'synth_4view': 280 | dataset_params.views = [4] 281 | elif opts.dataset == 'synth_5view': 282 | dataset_params.views = [5] 283 | elif opts.dataset == 'synth_6view': 284 | dataset_params.views = [6] 285 | elif opts.dataset == 'noise_gauss': 286 | dataset_params.noise_level = 0.1 287 | elif 
opts.dataset == 'noise_symgauss': 288 | dataset_params.noise_level = 0.1 289 | dataset_params.num_repeats = 1 290 | elif 'noise_pairwise' in opts.dataset: 291 | dataset_params.noise_level = 0.1 292 | regex0 = re.compile('noise_pairwise([0-9]+)view([0-9]+)$') 293 | regex1 = re.compile('noise_pairwise([0-9]+)$') 294 | nums0 = regex0.findall(opts.dataset) 295 | nums1 = regex1.findall(opts.dataset) 296 | if len(nums0) > 0: 297 | nums = [ int(x) for x in nums0[0] ] 298 | dataset_params.num_repeats = nums[0] 299 | dataset_params.views = [nums[1]] 300 | elif len(nums1) > 0: 301 | nums = [ int(x) for x in nums1[0] ] 302 | dataset_params.num_repeats = nums[0] 303 | elif 'noise_largepairwise' in opts.dataset: 304 | dataset_params.noise_level = 0.1 305 | dataset_params.sizes['train'] = 400000 306 | num_rep = re.search(r'[0-9]+', opts.dataset) 307 | if num_rep: 308 | dataset_params.num_repeats = int(num_rep.group(0)) 309 | elif 'synth_pts' in opts.dataset: 310 | dataset_params.noise_level = 0.1 311 | num_pts = re.search(r'[0-9]+', opts.dataset) 312 | if num_pts: 313 | dataset_params.points = [ int(num_pts.group(0)) ] 314 | elif 'noise_outlier' in opts.dataset: 315 | num_out = re.search(r'[0-9]+', opts.dataset) 316 | if num_out: 317 | dataset_params.num_outliers = int(num_out.group(0)) 318 | elif opts.dataset == 'rome16kknn0': 319 | dataset_params.points=[80] 320 | dataset_params.descriptor_dim=128 321 | # The dataset size is undermined until loading 322 | dataset_params.sizes={ 'train': -1, 'test': -1 } 323 | elif opts.dataset == 'rome16kgeom0': 324 | dataset_params.points=[80] 325 | dataset_params.descriptor_dim=128 326 | # The dataset size is undermined until loading 327 | dataset_params.sizes={ 'train': -1, 'test': -1 } 328 | elif opts.dataset == 'rome16kgeom4view0': 329 | dataset_params.views = [4] 330 | dataset_params.points=[80] 331 | dataset_params.descriptor_dim=128 332 | # The dataset size is undermined until loading 333 | dataset_params.sizes={ 'train': -1, 'test': 
-1 } 334 | else: 335 | pass 336 | opts.data_dir = dataset_params.data_dir 337 | setattr(opts, 'dataset_params', dataset_params) 338 | 339 | # Set up architecture 340 | arch = ArchParams(opts) 341 | if opts.architecture in ['vanilla', 'skip', 'attn0', 'spattn0']: 342 | arch.layer_lens=[ 2**min(5+k,9) for k in range(5) ] 343 | elif opts.architecture in ['vanilla0', 'skip0', 'attn1', 'spattn1']: 344 | arch.layer_lens=[ 2**min(5+k,9) for k in range(5) ] 345 | elif opts.architecture in ['vanilla1', 'skip1', 'attn2', 'spattn2']: 346 | arch.layer_lens=[ 2**min(5+k,9) for k in range(5) ] 347 | elif opts.architecture in ['longskip0', 'normedskip0', 'normedskip2']: 348 | arch.layer_lens=[ 32, 64, 128, 256, 512, 512, 512, 349 | 512, 512, 512, 1024, 1024 ] 350 | arch.skip_layers = [ len(arch.layer_lens)//2, len(arch.layer_lens) - 1 ] 351 | elif opts.architecture in ['longskip1', 'normedskip1', 'normedskip3']: 352 | arch.layer_lens=[ 32, 64, 128, 256, 512, 512, 353 | 512, 512, 512, 1024, 1024, 354 | 512, 512, 512, 1024, 1024 ] 355 | arch.skip_layers = [ 5, 10, len(arch.layer_lens) - 1 ] 356 | if opts.architecture in [ 'spattn0', 'spattn1', 'spattn2' ]: 357 | arch.sparse = True 358 | if opts.loss_type == 'bce': 359 | arch.normalize_emb = False 360 | if opts.dataset not in [ 'rome16kgeom0', 'rome16kgeom4view0' ]: 361 | opts.geometric_loss = 0 362 | setattr(opts, 'arch', arch) 363 | 364 | # Post processing 365 | if arch.normalize_emb: 366 | setattr(opts, 'embedding_offset', 1) 367 | # Save out options 368 | if not os.path.exists(opts.save_dir): 369 | os.makedirs(opts.save_dir) 370 | if opts.checkpoint_start_dir and not os.path.exists(opts.checkpoint_start_dir): 371 | print("ERROR: Checkpoint Directory {} does not exist".format(opts.checkpoint_start_dir)) 372 | return 373 | 374 | # Save out yaml file with options stored in it 375 | yaml_fname = os.path.join(opts.save_dir, 'options.yaml') 376 | if not os.path.exists(yaml_fname): 377 | with open(yaml_fname, 'w') as yml: 378 | 
yml.write(yaml.dump(opts.__dict__)) 379 | 380 | # Finished, return options 381 | return opts 382 | 383 | def parse_yaml_opts(opts): 384 | """Parse the options.yaml to reload options as saved 385 | Input: opts (options) - object with all relevant options 386 | Output: opts (options) - object with all relevant options loaded 387 | """ 388 | with open(os.path.join(opts.save_dir, 'options.yaml'), 'r') as yml: 389 | yaml_opts = yaml.load(yml) 390 | opts.__dict__.update(yaml_opts) 391 | return opts 392 | 393 | 394 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Train the network. 4 | """ 5 | import os 6 | import sys 7 | import collections 8 | import signal 9 | import time 10 | import numpy as np 11 | 12 | import tensorflow as tf 13 | from tensorflow.core.util.event_pb2 import SessionLog 14 | 15 | import data_util.datasets 16 | import model 17 | import myutils 18 | import tfutils 19 | import options 20 | 21 | # Dictionary of loss functions for easy switching 22 | loss_fns = { 23 | 'bce': tfutils.bce_loss, 24 | 'l1': tfutils.l1_loss, 25 | 'l2': tfutils.l2_loss, 26 | 'l1l2': tfutils.l1_l2_loss, 27 | } 28 | 29 | log_file = None 30 | def log(string): 31 | """Log to both standard output and global log file 32 | Input: string (string) - value to log out 33 | Output: None 34 | """ 35 | tf.logging.info(string) 36 | log_file.write(string) 37 | log_file.write('\n') 38 | 39 | def get_geometric_loss(opts, sample, output_sim, name='geo_loss'): 40 | """Get geometric loss funcion (epipolar constraints) 41 | Input: 42 | - opts (options) - object with all relevant options stored 43 | - sample (dict) - sample output from dataset 44 | - output_sim (tf.Tensor) - output similarity from network 45 | - name (string, optional) - name prefix for tensorflow scoping (default geo_loss) 46 | Output: None 47 | """ 48 | b = 
opts.batch_size 49 | v = opts.dataset_params.views[-1] 50 | p = opts.dataset_params.points[-1] 51 | # Build rotation matrices 52 | batch_size = tf.shape(sample['Rotations'])[0] 53 | R = tf.reshape(tf.tile(sample['Rotations'], [ 1, 1, p, 1 ]), [-1, v*p, 3, 3]) 54 | T = tf.reshape(tf.tile(sample['Translations'], [ 1, 1, p ]), [-1, v*p, 3]) 55 | X = tf.concat([ sample['InitEmbeddings'][...,-4:-2], 56 | tf.tile(tf.ones((1,v*p,1)), [ batch_size, 1, 1 ]) ], axis=-1) 57 | # Compute absolute essential losses 58 | RX = tf.einsum('bvik,bvk->bvi',R,X) 59 | TcrossRX = tf.cross(T, RX) 60 | E_part = tfutils.batch_matmul(RX, tf.transpose(TcrossRX, perm=[0, 2, 1])) 61 | # Compute mask of essential losses to not include self loops 62 | npmask = np.kron(1-np.eye(v),np.ones((p,p))).reshape(1,v*p,v*p).astype(opts.dataset_params.dtype) 63 | mask = tf.convert_to_tensor(npmask, name='mask_{}'.format(name)) 64 | # Compute symmetric part 65 | E = tf.multiply(tf.abs(E_part + tf.transpose(E_part, [0, 2, 1])), mask) 66 | if opts.full_tensorboard: # Add to tensorboard 67 | tf.summary.image('Geometric matrix {}'.format(name), tf.expand_dims(E, -1)) 68 | tf.summary.histogram('Geometric matrix hist {}'.format(name), E) 69 | tf.summary.scalar('Geometric matrix norm {}'.format(name), tf.norm(E, ord=np.inf)) 70 | return tf.reduce_mean(tf.multiply(output_sim, E), name=name) 71 | 72 | def get_loss(opts, sample, output, return_gt=False, name='train'): 73 | """Get total loss funcion with main loss functions and regularizers 74 | Input: 75 | - opts (options) - object with all relevant options stored 76 | - sample (dict) - sample from training dataset 77 | - output (tf.Tensor) - output from network on sample 78 | - return_gt (boolean, optional) - return ground truth losses (default False) 79 | - name (string, optional) - name prefix for tensorflow scoping (default train) 80 | Output: 81 | - loss (tf.Tensor) - total summed loss value 82 | - gt_l1_loss (tf.Tensor) - L1 loss against ground truth (only 
returned 83 | if return_gt=True) 84 | - gt_l2_loss (tf.Tensor) - L2 loss against ground truth (only returned 85 | if return_gt=True) 86 | """ 87 | emb = sample['TrueEmbedding'] 88 | output_sim = tfutils.get_sim(output) 89 | if opts.use_abs_value: 90 | output_sim = tf.abs(output_sim) 91 | sim_true = tfutils.get_sim(emb) 92 | # Figure out if we are using unsupervised to know if we can use GT Adjcency 93 | if opts.use_unsupervised_loss: 94 | v = opts.dataset_params.views[-1] 95 | p = opts.dataset_params.points[-1] 96 | b = opts.batch_size 97 | sim = sample['AdjMat'] + tf.eye(num_rows=v*p, batch_shape=[b]) 98 | else: 99 | sim = sim_true 100 | if opts.full_tensorboard: # Add to tensorboard 101 | tf.summary.image('Output Similarity {}'.format(name), tf.expand_dims(output_sim, -1)) 102 | tf.summary.image('Embedding Similarity {}'.format(name), tf.expand_dims(sim, -1)) 103 | # Our main loss, the reconstruction loss 104 | reconstr_loss = loss_fns[opts.loss_type](sim, output_sim) 105 | if opts.full_tensorboard: # Add to tensorboard 106 | tf.summary.scalar('Reconstruction Loss {}'.format(name), reconstr_loss) 107 | # Geometric loss terms and tensorboard 108 | if opts.geometric_loss > 0: 109 | geo_loss = get_geometric_loss(opts, sample, output_sim, name='geom_loss_{}'.format(name)) 110 | if opts.full_tensorboard: # Add to tensorboard 111 | tf.summary.scalar('Geometric Loss {}'.format(name), geo_loss) 112 | geo_loss_gt = get_geometric_loss(opts, sample, sim_true) 113 | tf.summary.scalar('Geometric Loss GT {}'.format(name), geo_loss_gt) 114 | loss = opts.reconstruction_loss * reconstr_loss + opts.geometric_loss * geo_loss 115 | else: 116 | loss = reconstr_loss 117 | tf.summary.scalar('Total Loss {}'.format(name), loss) 118 | # Compare to loss vs ground truth (almost always do that) 119 | if return_gt: 120 | output_sim_gt = output_sim 121 | if opts.loss_type == 'bce': 122 | output_sim_gt = tf.sigmoid(output_sim) 123 | gt_l1_loss = loss_fns['l1'](sim_true, output_sim_gt, 
add_loss=False) 124 | gt_l2_loss = loss_fns['l2'](sim_true, output_sim_gt, add_loss=False) 125 | if opts.full_tensorboard and opts.use_unsupervised_loss: # Add to tensorboard 126 | tf.summary.scalar('GT L1 Loss {}'.format(name), gt_l1_loss) 127 | tf.summary.scalar('GT L2 Loss {}'.format(name), gt_l2_loss) 128 | return loss, gt_l1_loss, gt_l2_loss 129 | else: 130 | return loss 131 | 132 | def build_optimizer(opts, global_step): 133 | """Build optimizer for training using options in opts 134 | Input: 135 | - opts (options) - object with all relevant options stored 136 | - global_step (tf.Variable) - global_step variable from tensorflow 137 | Output: optimizer (tf.train.optimizer) - optimizer build based on opts 138 | """ 139 | # Learning parameters post-processing 140 | num_batches = 1.0 * opts.dataset_params.sizes['train'] / opts.batch_size 141 | decay_steps = int(num_batches * opts.learning_rate_decay_epochs) 142 | use_staircase = (not opts.learning_rate_continuous) 143 | if opts.learning_rate_decay_type == 'fixed': 144 | learning_rate = tf.constant(opts.learning_rate, name='fixed_learning_rate') 145 | elif opts.learning_rate_decay_type == 'exponential': 146 | learning_rate = tf.train.exponential_decay(opts.learning_rate, 147 | global_step, 148 | decay_steps, 149 | opts.learning_rate_decay_rate, 150 | staircase=use_staircase, 151 | name='learning_rate') 152 | elif opts.learning_rate_decay_type == 'polynomial': 153 | learning_rate = tf.train.polynomial_decay(opts.learning_rate, 154 | global_step, 155 | decay_steps, 156 | opts.min_learning_rate, 157 | power=1.0, 158 | cycle=False, 159 | name='learning_rate') 160 | 161 | if opts.full_tensorboard: # Add to tensorboard 162 | tf.summary.scalar('learning_rate', learning_rate) 163 | # Different descent algorithms 164 | if opts.optimizer_type == 'adam': 165 | optimizer = tf.train.AdamOptimizer(learning_rate) 166 | elif opts.optimizer_type == 'adadelta': 167 | optimizer = tf.train.AdadeltaOptimizer(learning_rate) 168 | elif 
opts.optimizer_type == 'momentum': 169 | optimizer = tf.train.MomentumOptimizer(learning_rate,opts.momentum) 170 | elif opts.optimizer_type == 'sgd': 171 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 172 | elif opts.optimizer_type == 'adamw': 173 | optimizer = tf.contrib.opt.AdamWOptimizer(learning_rate) 174 | 175 | return optimizer 176 | 177 | def get_train_op(opts, loss): 178 | """Build train_op for training using options in opts 179 | Input: 180 | - opts (options) - object with all relevant options stored 181 | - loss (tf.Tensor) - loss value to train on 182 | Output: train_op (tf.op) - Training operation to run to train network 183 | """ 184 | global_step = tf.train.get_or_create_global_step() 185 | optimizer = build_optimizer(opts, global_step) 186 | train_op = None 187 | if opts.weight_decay > 0 or opts.weight_l1_decay > 0: 188 | reg_loss = tf.losses.get_regularization_loss() 189 | reg_optimizer = tf.train.GradientDescentOptimizer( 190 | learning_rate=opts.weight_decay) 191 | reg_step = reg_optimizer.minimize(reg_loss, global_step=global_step) 192 | with tf.control_dependencies([reg_step]): 193 | gvs = optimizer.compute_gradients(loss) 194 | capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gvs] 195 | train_op = optimizer.apply_gradients(capped_gvs) 196 | else: 197 | gvs = optimizer.compute_gradients(loss) 198 | capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gvs] 199 | train_op = optimizer.apply_gradients(capped_gvs) 200 | return train_op 201 | 202 | def build_session(opts): 203 | """Build tf.train.SingularMonitoredSession with relevant hooks 204 | Input: opts (options) - object with all relevant options stored 205 | Output: session (tf.train.SingularMonitoredSession) 206 | """ 207 | checkpoint_dir = opts.save_dir 208 | if opts.checkpoint_start_dir is not None: 209 | checkpoint_dir = opts.checkpoint_start_dir 210 | saver_hook = tf.train.CheckpointSaverHook(opts.save_dir, 211 | 
save_secs=opts.save_interval_secs) 212 | merged = tf.summary.merge_all() 213 | summary_hook = tf.train.SummarySaverHook(output_dir=opts.save_dir, 214 | summary_op=merged, 215 | save_secs=opts.save_summaries_secs) 216 | all_hooks = [saver_hook, summary_hook] 217 | config = tf.ConfigProto() 218 | config.gpu_options.allow_growth = True 219 | return tf.train.SingularMonitoredSession( 220 | checkpoint_dir=checkpoint_dir, 221 | hooks=all_hooks, 222 | config=config) 223 | 224 | def get_intervals(opts): 225 | """Get relevant training and testing time/step intervals 226 | Input: opts (options) - object with all relevant options stored 227 | Output: 228 | - train_steps (int) - steps to continue training for 229 | - train_time (int) - wall clock time to keep training for 230 | - test_freq_steps (int) - how frequently in steps to run test loss 231 | - test_freq (int) - how frequently in wall clock time to run test loss 232 | """ 233 | if opts.num_epochs > 0: 234 | num_batches = 1.0 * opts.dataset_params.sizes['train'] / opts.batch_size 235 | train_steps = int(num_batches * opts.num_epochs) 236 | else: 237 | train_steps = None 238 | if opts.train_time > 0: 239 | train_time = opts.train_time * 60 240 | else: 241 | train_time = None 242 | if opts.test_freq > 0: 243 | test_freq = opts.test_freq * 60 244 | else: 245 | test_freq = None 246 | if opts.test_freq_steps > 0: 247 | test_freq_steps = opts.test_freq_steps 248 | else: 249 | test_freq_steps = None 250 | return train_steps, train_time, test_freq_steps, test_freq 251 | 252 | def get_test_dict(opts, dataset, network): 253 | """Build dictionary with relevant tensors for evaluating test loss 254 | Input: 255 | - opts (options) - object with all relevant options stored 256 | - dataset (data_util.datasets.MyDataset) - dataset to evaluate on 257 | - network (callable) - Sonnet network we are training/testing 258 | Output: 259 | - test_data (dict) - Dictionary with following items 260 | * sample (dict) - test sample from dataset 261 | 
* output (tf.Tensor) - network output from sample 262 | * loss (tf.Tensor) - test loss value 263 | * loss_gt_l1 (tf.Tensor) - loss against ground-truth (L1) 264 | * loss_gt_l2 (tf.Tensor) - loss against ground-truth (L2) 265 | * num_steps (int) - number of steps to go in training 266 | """ 267 | test_data = {} 268 | test_data['sample'] = dataset.load_batch('test') 269 | test_data['output'] = network(test_data['sample']['Laplacian'], 270 | test_data['sample']['InitEmbeddings']) 271 | if opts.use_unsupervised_loss: 272 | test_loss, test_gt_l1_loss, test_gt_l2_loss = \ 273 | get_loss(opts, 274 | test_data['sample'], 275 | test_data['output'], 276 | return_gt=True, 277 | name='test') 278 | test_data['loss'] = test_loss 279 | test_data['loss_gt_l1'] = test_gt_l1_loss 280 | test_data['loss_gt_l2'] = test_gt_l2_loss 281 | else: 282 | test_data['loss'] = get_loss(opts, 283 | test_data['sample'], 284 | test_data['output'], 285 | name='test') 286 | num_batches = 1.0 * opts.dataset_params.sizes['test'] / opts.batch_size 287 | test_data['nsteps'] = int(num_batches) 288 | return test_data 289 | 290 | def run_test(opts, sess, test_data, verbose=True): 291 | """Run test evaluation 292 | Input: 293 | - opts (options) - object with all relevant options stored 294 | - sess (tf.train.SingularMonitoredSession) - session from get_session 295 | - test_data (dict) - Dictionary from get_test_dict 296 | - verbose (boolean, optional) - display everything (default True) 297 | Output: None 298 | Prints out the final values of the test evaluation 299 | """ 300 | # Setup 301 | npsave = {} 302 | teststr = " ------------------- " 303 | teststr += " Test loss = {:.4e} " 304 | npsave_keys = [ 'output', 'input', 'adjmat', 'gt' ] 305 | test_data_vals = [ test_data['output'], test_data['sample']['InitEmbeddings'], 306 | test_data['sample']['AdjMat'], test_data['sample']['TrueEmbedding'] ] 307 | test_vals = [ test_data['loss'] ] 308 | start_time = time.time() 309 | if opts.use_unsupervised_loss: 310 | 
teststr += ", GT L1 Loss = {:4e} , GT L2 Loss = {:4e} " 311 | test_vals += [ test_data['loss_gt_l1'], test_data['loss_gt_l2'] ] 312 | teststr += "({:.03} sec)" 313 | summed_vals = [ 0 for x in range(len(test_vals)) ] 314 | # Run experiment 315 | run_outputs = sess.run(test_vals + test_data_vals) 316 | for t in range(len(test_vals)): 317 | summed_vals[t] += run_outputs[t] 318 | npsave = { k: v for k, v in zip(npsave_keys, run_outputs[len(test_vals):]) } 319 | for _ in range(test_data['nsteps']-1): 320 | run_outputs = sess.run(test_vals) 321 | for t in range(len(test_vals)): 322 | summed_vals[t] += run_outputs[t] 323 | strargs = (sv / test_data['nsteps'] for sv in summed_vals) 324 | np.savez(myutils.next_file(opts.save_dir, 'test', '.npz'), **npsave) 325 | ctime = time.time() 326 | log(teststr.format(*strargs, ctime-start_time)) 327 | 328 | def train(opts): 329 | """Train the network 330 | Input: opts (options) - object with all relevant options stored 331 | Output: None 332 | Saves all output in opts.save_dir, given by the user. 
For how to modify the 333 | training procedure, look at options.py 334 | """ 335 | # Get data and network 336 | dataset = data_util.datasets.get_dataset(opts) 337 | network = model.get_network(opts, opts.arch) 338 | # Training 339 | with tf.device('/cpu:0'): 340 | if opts.load_data: 341 | sample = dataset.load_batch('train') 342 | else: 343 | sample = dataset.gen_batch('train') 344 | output = network(sample['Laplacian'], sample['InitEmbeddings']) 345 | loss = get_loss(opts, sample, output, name='train') 346 | train_op = get_train_op(opts, loss) 347 | # Testing 348 | test_data = get_test_dict(opts, dataset, network) 349 | 350 | # Tensorflow and logging operations 351 | step = 0 352 | train_steps, train_time, test_freq_steps, test_freq = get_intervals(opts) 353 | trainstr = "global step {}: loss = {} ({:.04} sec/step, time {:.04})" 354 | tf.logging.set_verbosity(tf.logging.INFO) 355 | # Build session 356 | with build_session(opts) as sess: 357 | # Train loop 358 | for run in range(opts.num_runs): 359 | stime = time.time() 360 | ctime = stime 361 | ttime = stime 362 | # Main loop 363 | while step != train_steps and ctime - stime <= train_time: 364 | start_time = time.time() 365 | _, loss_ = sess.run([ train_op, loss ]) 366 | ctime = time.time() 367 | # Check for logging 368 | if (step % opts.log_steps) == 0: 369 | log(trainstr.format(step, 370 | loss_, 371 | ctime - start_time, 372 | ctime - stime)) 373 | # Check if time to evaluate test 374 | if ((test_freq_steps and step % test_freq_steps == 0) or \ 375 | (ctime - ttime > test_freq)): 376 | raw_sess = sess.raw_session() 377 | run_test(opts, raw_sess, test_data) 378 | ttime = time.time() 379 | step += 1 380 | 381 | if __name__ == "__main__": 382 | opts = options.get_opts() 383 | log_file = open(os.path.join(opts.save_dir, 'logfile.log'), 'a') 384 | train(opts) 385 | log_file.close() 386 | 387 | -------------------------------------------------------------------------------- /model/layers.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import os 4 | import sys 5 | import tensorflow as tf 6 | import sonnet as snt 7 | 8 | import tfutils 9 | import myutils 10 | import options 11 | 12 | 13 | class GraphGroupNorm(snt.AbstractModule): 14 | """Group Norm for Graph-based inputs module 15 | 16 | Implemented since the tensorflow one does not work with unknown batch size. 17 | Assumed input dimensions is [ Batch, Nodes, Features ] 18 | """ 19 | def __init__(self, group_size=32, name='group_norm'): 20 | super(GraphGroupNorm, self).__init__(custom_getter=None, 21 | name=name) 22 | self.group_size = 32 23 | self.possible_keys = self.get_possible_initializer_keys() 24 | self._initializers = { 25 | 'gamma' : tf.ones_initializer(), 26 | 'beta' : tf.zeros_initializer() 27 | } 28 | self._gamma = None 29 | self._beta = None 30 | self._input_shape = None 31 | 32 | def get_possible_initializer_keys(cls, use_bias=True): 33 | return {"gamma", "beta"} 34 | 35 | @property 36 | def gamma(self): 37 | """Returns the Variable containing the scale parameter, gamma. 38 | Output: gamma (tf.Tensor) - weights, from the most recent __call__. 39 | Raises: 40 | snt.NotConnectedError: If the module has not been connected to the 41 | graph yet, meaning the variables do not exist. 42 | """ 43 | self._ensure_is_connected() 44 | return self._gamma 45 | 46 | @property 47 | def beta(self): 48 | """Returns the Variable containing the center parameter, beta. 49 | Output: beta (tf.Tensor) - biases, from the most recent __call__. 50 | Raises: 51 | snt.NotConnectedError: If the module has not been connected to the 52 | graph yet, meaning the variables do not exist. 
53 | """ 54 | self._ensure_is_connected() 55 | return self._beta 56 | 57 | @property 58 | def initializers(self): 59 | """Returns the initializers dictionary.""" 60 | return self._initializers 61 | 62 | @property 63 | def partitioners(self): 64 | """Returns the partitioners dictionary.""" 65 | return self._partitioners 66 | 67 | def clone(self, name=None): 68 | """Returns a cloned `GraphGroupNorm` module. 69 | Input: 70 | - name (string, optional) - name of cloned module. The default name 71 | is constructed by appending "_clone" to `self.module_name`. 72 | Output: net (snt.Module) - Cloned `GraphGroupNorm` module. 73 | """ 74 | if name is None: 75 | name = self.module_name + "_clone" 76 | return GraphGroupNorm(group_size=self.group_size) 77 | 78 | def _build(self, inputs): 79 | # Based on: 80 | # https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/layers/python/layers/normalization.py 81 | input_shape = tuple(inputs.get_shape().as_list()) 82 | if len(input_shape) != 3: 83 | raise snt.IncompatibleShapeError( 84 | "{}: rank of shape must be 3 not: {}".format( 85 | self.scope_name, len(input_shape))) 86 | 87 | if input_shape[2] is None: 88 | raise snt.IncompatibleShapeError( 89 | "{}: Input size must be specified at module build time".format( 90 | self.scope_name)) 91 | self._input_shape = input_shape 92 | dtype = inputs.dtype 93 | group_sizes = [ self.group_size, self._input_shape[2] // self.group_size ] 94 | broadcast_shape = [ 1, 1 ] + group_sizes 95 | self._gamma = tf.get_variable("gamma", 96 | shape=(self._input_shape[2]), 97 | dtype=dtype, 98 | initializer=self._initializers["gamma"]) 99 | if self._gamma not in tf.get_collection('weights'): 100 | tf.add_to_collection('weights', self._gamma) 101 | self._gamma = tf.reshape(self._gamma, broadcast_shape) 102 | 103 | self._beta = tf.get_variable("beta", 104 | shape=(self._input_shape[2],), 105 | dtype=dtype, 106 | initializer=self._initializers["beta"]) 107 | if self._beta not in 
tf.get_collection('biases'): 108 | tf.add_to_collection('biases', self._beta) 109 | self._beta = tf.reshape(self._beta, broadcast_shape) 110 | 111 | ##### Actually perform operations 112 | # Reshape input 113 | original_shape = [ -1, self._input_shape[1], self._input_shape[2] ] 114 | inputs_shape = [ -1, self._input_shape[1] ] + group_sizes 115 | 116 | inputs = tf.reshape(inputs, inputs_shape) 117 | 118 | # Normalize 119 | mean, variance = tf.nn.moments(inputs, [1, 3], keep_dims=True) 120 | gain = tf.rsqrt(variance + 1e-7) * self._gamma 121 | offset = -mean * gain + self._beta 122 | outputs = inputs * gain + offset 123 | 124 | # Reshape back to output 125 | outputs = tf.reshape(outputs, original_shape) 126 | 127 | return outputs 128 | 129 | class AbstractGraphLayer(snt.AbstractModule): 130 | """Transformation on an graphe node embedding. 131 | 132 | This functions almost exactly like snt.Linear except it is for tensors of 133 | size batch_size x nodes x input_size. Acts by matrix multiplication on the 134 | left side of each nodes x input_size matrix. 
135 | """ 136 | def __init__(self, 137 | output_size, 138 | use_bias=True, 139 | initializers=None, 140 | partitioners=None, 141 | regularizers=None, 142 | custom_getter=None, 143 | name="embed_lin"): 144 | super(AbstractGraphLayer, self).__init__(custom_getter=None, name=name) 145 | self._output_size = output_size 146 | self._use_bias = use_bias 147 | self._input_shape = None 148 | self.possible_keys = self.get_possible_initializer_keys(use_bias=use_bias) 149 | self._initializers = snt.check_initializers( 150 | initializers, self.possible_keys) 151 | self._partitioners = snt.check_partitioners( 152 | partitioners, self.possible_keys) 153 | self._regularizers = snt.check_regularizers( 154 | regularizers, self.possible_keys) 155 | 156 | @classmethod 157 | def get_possible_initializer_keys(cls, use_bias=True): 158 | raise NotImplemented("Need to overwrite in subclass") 159 | 160 | def _build(self, laplacian, inputs): 161 | return inputs 162 | 163 | @property 164 | def output_size(self): 165 | """Returns the module output size.""" 166 | if callable(self._output_size): 167 | self._output_size = self._output_size() 168 | return self._output_size 169 | 170 | @property 171 | def has_bias(self): 172 | """Returns `True` if bias Variable is present in the module.""" 173 | return self._use_bias 174 | 175 | @property 176 | def initializers(self): 177 | """Returns the initializers dictionary.""" 178 | return self._initializers 179 | 180 | @property 181 | def partitioners(self): 182 | """Returns the partitioners dictionary.""" 183 | return self._partitioners 184 | 185 | @property 186 | def regularizers(self): 187 | """Returns the regularizers dictionary.""" 188 | return self._regularizers 189 | 190 | def clone(self, name=None): 191 | """Returns a cloned `AbstractGraphLayer` module. 192 | Input: 193 | - name (string, optional) - name of cloned module. The default name 194 | is constructed by appending "_clone" to `self.module_name`. 
195 | Output: net (snt.Module) - Cloned `AbstractGraphLayer` module. 196 | """ 197 | if name is None: 198 | name = self.module_name + "_clone" 199 | return AbstractGraphLayer(output_size=self.output_size, 200 | use_bias=self._use_bias, 201 | initializers=self._initializers, 202 | partitioners=self._partitioners, 203 | regularizers=self._regularizers, 204 | name=name) 205 | 206 | 207 | 208 | class EmbeddingLinearLayer(AbstractGraphLayer): 209 | """Linear transformation on an embedding, each independently. 210 | 211 | This functions almost exactly like snt.Linear except it is for tensors of 212 | size batch_size x nodes x input_size. Acts by matrix multiplication on the 213 | left side of each nodes x input_size matrix. 214 | """ 215 | def __init__(self, 216 | output_size, 217 | use_bias=True, 218 | initializers=None, 219 | partitioners=None, 220 | regularizers=None, 221 | custom_getter=None, 222 | name="embed_lin"): 223 | super(EmbeddingLinearLayer, self).__init__( 224 | output_size, 225 | use_bias=use_bias, 226 | initializers=initializers, 227 | partitioners=partitioners, 228 | regularizers=regularizers, 229 | custom_getter=custom_getter, 230 | name=name) 231 | self._w = None 232 | self._b = None 233 | self.possible_keys = self.get_possible_initializer_keys(use_bias=use_bias) 234 | 235 | @classmethod 236 | def get_possible_initializer_keys(cls, use_bias=True): 237 | return {"w", "b"} if use_bias else {"w"} 238 | 239 | def _build(self, laplacian, inputs): 240 | input_shape = tuple(inputs.get_shape().as_list()) 241 | if len(input_shape) != 3: 242 | raise snt.IncompatibleShapeError( 243 | "{}: rank of shape must be 3 not: {}".format( 244 | self.scope_name, len(input_shape))) 245 | 246 | if input_shape[2] is None: 247 | raise snt.IncompatibleShapeError( 248 | "{}: Input size must be specified at module build time".format( 249 | self.scope_name)) 250 | 251 | if self._input_shape is not None and input_shape[2] != self._input_shape[2]: 252 | raise 
snt.IncompatibleShapeError( 253 | "{}: Input shape must be [batch_size, {}, {}] not: [batch_size, {}, {}]" 254 | .format(self.scope_name, 255 | input_shape[2], 256 | self._input_shape[2], 257 | input_shape[1], 258 | input_shape[2])) 259 | 260 | self._input_shape = input_shape 261 | dtype = inputs.dtype 262 | 263 | if "w" not in self._initializers: 264 | self._initializers["w"] = tfutils.create_linear_initializer( 265 | self._input_shape[2], 266 | self._output_size, 267 | dtype) 268 | 269 | if "b" not in self._initializers and self._use_bias: 270 | self._initializers["b"] = tfutils.create_bias_initializer( 271 | self._input_shape[2], 272 | self._output_size, 273 | dtype) 274 | 275 | weight_shape = (self._input_shape[2], self.output_size) 276 | self._w = tf.get_variable("w", 277 | shape=weight_shape, 278 | dtype=dtype, 279 | initializer=self._initializers["w"], 280 | partitioner=self._partitioners.get("w", None), 281 | regularizer=self._regularizers.get("w", None)) 282 | if self._w not in tf.get_collection('weights'): 283 | tf.add_to_collection('weights', self._w) 284 | outputs = tfutils.matmul(inputs, self._w) 285 | 286 | if self._use_bias: 287 | bias_shape = (self.output_size,) 288 | self._b = tf.get_variable("b", 289 | shape=bias_shape, 290 | dtype=dtype, 291 | initializer=self._initializers["b"], 292 | partitioner=self._partitioners.get("b", None), 293 | regularizer=self._regularizers.get("b", None)) 294 | if self._b not in tf.get_collection('biases'): 295 | tf.add_to_collection('biases', self._b) 296 | outputs += self._b 297 | 298 | return outputs 299 | 300 | @property 301 | def w(self): 302 | """Returns the Variable containing the weight parameters. 303 | Output: w (tf.Tensor) - weights, from the most recent __call__. 304 | Raises: 305 | snt.NotConnectedError: If the module has not been connected to the 306 | graph yet, meaning the variables do not exist. 
307 | """ 308 | self._ensure_is_connected() 309 | return self._w 310 | 311 | @property 312 | def b(self): 313 | """Returns the Variable containing the bias parameters. 314 | Output: b (tf.Tensor) - biases, from the most recent __call__. 315 | Raises: 316 | snt.NotConnectedError: If the module has not been connected to the 317 | graph yet, meaning the variables do not exist. 318 | AttributeError: If the module does not use bias. 319 | """ 320 | self._ensure_is_connected() 321 | if not self._use_bias: 322 | raise AttributeError( 323 | "No bias Variable in Linear Module when `use_bias=False`.") 324 | return self._b 325 | 326 | def clone(self, name=None): 327 | """Returns a cloned `EmbeddingLinearLayer` module. 328 | Input: 329 | - name (string, optional) - name of cloned module. The default name 330 | is constructed by appending "_clone" to `self.module_name`. 331 | Output: net (snt.Module) - Cloned `EmbeddingLinearLayer` module. 332 | """ 333 | if name is None: 334 | name = self.module_name + "_clone" 335 | return EmbeddingRightLinear(output_size=self.output_size, 336 | use_bias=self._use_bias, 337 | initializers=self._initializers, 338 | partitioners=self._partitioners, 339 | regularizers=self._regularizers, 340 | name=name) 341 | 342 | class GraphConvLayer(AbstractGraphLayer): 343 | """Linear transformation on an embedding, each independently. 344 | 345 | This functions almost exactly like snt.Linear except it is for tensors of 346 | size batch_size x nodes x input_size. Acts by matrix multiplication on the 347 | left side of each nodes x input_size matrix. 
348 | """ 349 | def __init__(self, 350 | output_size, 351 | activation='relu', 352 | use_bias=True, 353 | initializers=None, 354 | partitioners=None, 355 | regularizers=None, 356 | custom_getter=None, 357 | name="graph_conv"): 358 | super(GraphConvLayer, self).__init__( 359 | output_size, 360 | use_bias=use_bias, 361 | initializers=initializers, 362 | partitioners=partitioners, 363 | regularizers=regularizers, 364 | custom_getter=custom_getter, 365 | name=name) 366 | self._activ = tfutils.get_tf_activ(activation) 367 | self._w = None 368 | self._b = None 369 | self.possible_keys = self.get_possible_initializer_keys(use_bias=use_bias) 370 | 371 | @classmethod 372 | def get_possible_initializer_keys(cls, use_bias=True): 373 | return {"w", "b"} if use_bias else {"w"} 374 | 375 | def _build(self, laplacian, inputs): 376 | input_shape = tuple(inputs.get_shape().as_list()) 377 | if len(input_shape) != 3: 378 | raise snt.IncompatibleShapeError( 379 | "{}: rank of shape must be 3 not: {}".format( 380 | self.scope_name, len(input_shape))) 381 | 382 | if input_shape[2] is None: 383 | raise snt.IncompatibleShapeError( 384 | "{}: Input size must be specified at module build time".format( 385 | self.scope_name)) 386 | 387 | if input_shape[1] is None: 388 | raise snt.IncompatibleShapeError( 389 | "{}: Number of nodes must be specified at module build time".format( 390 | self.scope_name)) 391 | 392 | if self._input_shape is not None and \ 393 | (input_shape[2] != self._input_shape[2] or \ 394 | input_shape[1] != self._input_shape[1]): 395 | raise snt.IncompatibleShapeError( 396 | "{}: Input shape must be [batch_size, {}, {}] not: [batch_size, {}, {}]" 397 | .format(self.scope_name, 398 | self._input_shape[1], 399 | self._input_shape[2], 400 | input_shape[1], 401 | input_shape[2])) 402 | 403 | 404 | self._input_shape = input_shape 405 | dtype = inputs.dtype 406 | 407 | if "w" not in self._initializers: 408 | self._initializers["w"] = tfutils.create_linear_initializer( 409 | 
self._input_shape[2], 410 | self._output_size, 411 | dtype) 412 | 413 | if "b" not in self._initializers and self._use_bias: 414 | self._initializers["b"] = tfutils.create_bias_initializer( 415 | self._input_shape[2], 416 | self._output_size, 417 | dtype) 418 | 419 | weight_shape = (self._input_shape[2], self.output_size) 420 | self._w = tf.get_variable("w", 421 | shape=weight_shape, 422 | dtype=dtype, 423 | initializer=self._initializers["w"], 424 | partitioner=self._partitioners.get("w", None), 425 | regularizer=self._regularizers.get("w", None)) 426 | if self._w not in tf.get_collection('weights'): 427 | tf.add_to_collection('weights', self._w) 428 | outputs_ = tfutils.matmul(inputs, self._w) 429 | outputs = tfutils.batch_matmul(laplacian, outputs_) 430 | 431 | if self._use_bias: 432 | bias_shape = (self.output_size,) 433 | self._b = tf.get_variable("b", 434 | shape=bias_shape, 435 | dtype=dtype, 436 | initializer=self._initializers["b"], 437 | partitioner=self._partitioners.get("b", None), 438 | regularizer=self._regularizers.get("b", None)) 439 | if self._b not in tf.get_collection('biases'): 440 | tf.add_to_collection('biases', self._b) 441 | outputs += self._b 442 | 443 | 444 | return self._activ(outputs) 445 | 446 | @property 447 | def w(self): 448 | """Returns the Variable containing the weight parameters. 449 | Output: w (tf.Tensor) - weights, from the most recent __call__. 450 | Raises: 451 | snt.NotConnectedError: If the module has not been connected to the 452 | graph yet, meaning the variables do not exist. 453 | """ 454 | self._ensure_is_connected() 455 | return self._w 456 | 457 | @property 458 | def b(self): 459 | """Returns the Variable containing the bias parameters. 460 | Output: b (tf.Tensor) - biases, from the most recent __call__. 461 | Raises: 462 | snt.NotConnectedError: If the module has not been connected to the 463 | graph yet, meaning the variables do not exist. 464 | AttributeError: If the module does not use bias. 
465 | """ 466 | self._ensure_is_connected() 467 | if not self._use_bias: 468 | raise AttributeError( 469 | "No bias Variable in Linear Module when `use_bias=False`.") 470 | return self._b 471 | 472 | def clone(self, name=None): 473 | """Returns a cloned `GraphConvLayer` module. 474 | Input: 475 | - name (string, optional) - name of cloned module. The default name 476 | is constructed by appending "_clone" to `self.module_name`. 477 | Output: net (snt.Module) - Cloned `GraphConvLayer` module. 478 | """ 479 | if name is None: 480 | name = self.module_name + "_clone" 481 | return GraphConvLayer(output_size=self.output_size, 482 | use_bias=self._use_bias, 483 | initializers=self._initializers, 484 | partitioners=self._partitioners, 485 | regularizers=self._regularizers, 486 | name=name) 487 | 488 | class GraphSkipLayer(AbstractGraphLayer): 489 | """Linear transformation on an embedding, each independently. 490 | 491 | This functions almost exactly like snt.Linear except it is for tensors of 492 | size batch_size x nodes x input_size. Acts by matrix multiplication on the 493 | left side of each nodes x input_size matrix. 
494 | """ 495 | def __init__(self, 496 | output_size, 497 | activation='relu', 498 | use_bias=True, 499 | initializers=None, 500 | partitioners=None, 501 | regularizers=None, 502 | custom_getter=None, 503 | name="graph_skip"): 504 | super(GraphSkipLayer, self).__init__( 505 | output_size, 506 | use_bias=use_bias, 507 | initializers=initializers, 508 | partitioners=partitioners, 509 | regularizers=regularizers, 510 | custom_getter=custom_getter, 511 | name=name) 512 | self._activ = tfutils.get_tf_activ(activation) 513 | self._w = None 514 | self._u = None 515 | self._b = None 516 | self._c = None 517 | self.possible_keys = self.get_possible_initializer_keys(use_bias=use_bias) 518 | 519 | @classmethod 520 | def get_possible_initializer_keys(cls, use_bias=True): 521 | return {"w", "u", "b", "c"} if use_bias else {"w", "u"} 522 | 523 | def _build(self, laplacian, inputs): 524 | input_shape = tuple(inputs.get_shape().as_list()) 525 | if len(input_shape) != 3: 526 | raise snt.IncompatibleShapeError( 527 | "{}: rank of shape must be 3 not: {}".format( 528 | self.scope_name, len(input_shape))) 529 | 530 | if input_shape[2] is None: 531 | raise snt.IncompatibleShapeError( 532 | "{}: Input size must be specified at module build time".format( 533 | self.scope_name)) 534 | 535 | if input_shape[1] is None: 536 | raise snt.IncompatibleShapeError( 537 | "{}: Number of nodes must be specified at module build time".format( 538 | self.scope_name)) 539 | 540 | if self._input_shape is not None and \ 541 | (input_shape[2] != self._input_shape[2] or \ 542 | input_shape[1] != self._input_shape[1]): 543 | raise snt.IncompatibleShapeError( 544 | "{}: Input shape must be [batch_size, {}, {}] not: [batch_size, {}, {}]" 545 | .format(self.scope_name, 546 | self._input_shape[1], 547 | self._input_shape[2], 548 | input_shape[1], 549 | input_shape[2])) 550 | 551 | 552 | self._input_shape = input_shape 553 | dtype = inputs.dtype 554 | 555 | if "w" not in self._initializers: 556 | 
self._initializers["w"] = tfutils.create_linear_initializer( 557 | self._input_shape[2], 558 | self._output_size, 559 | dtype) 560 | if "u" not in self._initializers: 561 | self._initializers["u"] = tfutils.create_linear_initializer( 562 | self._input_shape[2], 563 | self._output_size, 564 | dtype) 565 | 566 | if "b" not in self._initializers and self._use_bias: 567 | self._initializers["b"] = tfutils.create_bias_initializer( 568 | self._input_shape[2], 569 | self._output_size, 570 | dtype) 571 | if "c" not in self._initializers and self._use_bias: 572 | self._initializers["c"] = tfutils.create_bias_initializer( 573 | self._input_shape[2], 574 | self._output_size, 575 | dtype) 576 | 577 | weight_shape = (self._input_shape[2], self.output_size) 578 | self._w = tf.get_variable("w", 579 | shape=weight_shape, 580 | dtype=dtype, 581 | initializer=self._initializers["w"], 582 | partitioner=self._partitioners.get("w", None), 583 | regularizer=self._regularizers.get("w", None)) 584 | if self._w not in tf.get_collection('weights'): 585 | tf.add_to_collection('weights', self._w) 586 | self._u = tf.get_variable("u", 587 | shape=weight_shape, 588 | dtype=dtype, 589 | initializer=self._initializers["u"], 590 | partitioner=self._partitioners.get("u", None), 591 | regularizer=self._regularizers.get("u", None)) 592 | if self._u not in tf.get_collection('weights'): 593 | tf.add_to_collection('weights', self._u) 594 | preactiv_ = tfutils.matmul(inputs, self._w) 595 | preactiv = tfutils.batch_matmul(laplacian, preactiv_) 596 | skip = tfutils.matmul(inputs, self._u) 597 | 598 | if self._use_bias: 599 | bias_shape = (self.output_size,) 600 | self._b = tf.get_variable("b", 601 | shape=bias_shape, 602 | dtype=dtype, 603 | initializer=self._initializers["b"], 604 | partitioner=self._partitioners.get("b", None), 605 | regularizer=self._regularizers.get("b", None)) 606 | if self._b not in tf.get_collection('biases'): 607 | tf.add_to_collection('biases', self._b) 608 | self._c = 
tf.get_variable("c", 609 | shape=bias_shape, 610 | dtype=dtype, 611 | initializer=self._initializers["c"], 612 | partitioner=self._partitioners.get("c", None), 613 | regularizer=self._regularizers.get("c", None)) 614 | if self._c not in tf.get_collection('biases'): 615 | tf.add_to_collection('biases', self._c) 616 | preactiv += self._b 617 | skip += self._c 618 | 619 | activ = self._activ(preactiv) + skip 620 | 621 | return activ 622 | 623 | @property 624 | def w(self): 625 | """Returns the Variable containing the weight parameters. 626 | Output: w (tf.Tensor) - weights, from the most recent __call__. 627 | Raises: 628 | snt.NotConnectedError: If the module has not been connected to the 629 | graph yet, meaning the variables do not exist. 630 | """ 631 | self._ensure_is_connected() 632 | return self._w 633 | 634 | @property 635 | def u(self): 636 | """Returns the Variable containing the skip weights parameters. 637 | Output: u (tf.Tensor) - skip weights, from the most recent __call__. 638 | Raises: 639 | snt.NotConnectedError: If the module has not been connected to the 640 | graph yet, meaning the variables do not exist. 641 | """ 642 | self._ensure_is_connected() 643 | return self._u 644 | 645 | @property 646 | def b(self): 647 | """Returns the Variable containing the bias parameters. 648 | Output: b (tf.Tensor) - biases, from the most recent __call__. 649 | Raises: 650 | snt.NotConnectedError: If the module has not been connected to the 651 | graph yet, meaning the variables do not exist. 652 | AttributeError: If the module does not use bias. 653 | """ 654 | self._ensure_is_connected() 655 | if not self._use_bias: 656 | raise AttributeError( 657 | "No bias Variable in Linear Module when `use_bias=False`.") 658 | return self._b 659 | 660 | @property 661 | def c(self): 662 | """Returns the Variable containing the skip bias parameters. 663 | Output: c (tf.Tensor) - skip biases, from the most recent __call__. 
664 | Raises: 665 | snt.NotConnectedError: If the module has not been connected to the 666 | graph yet, meaning the variables do not exist. 667 | AttributeError: If the module does not use bias. 668 | """ 669 | self._ensure_is_connected() 670 | if not self._use_bias: 671 | raise AttributeError( 672 | "No bias Variable in Linear Module when `use_bias=False`.") 673 | return self._c 674 | 675 | def clone(self, name=None): 676 | """Returns a cloned `GraphSkipLayer` module. 677 | Input: 678 | - name (string, optional) - name of cloned module. The default name 679 | is constructed by appending "_clone" to `self.module_name`. 680 | Output: net (snt.Module) - Cloned `GraphSkipLayer` module. 681 | """ 682 | if name is None: 683 | name = self.module_name + "_clone" 684 | return GraphSkipLayer(output_size=self.output_size, 685 | use_bias=self._use_bias, 686 | initializers=self._initializers, 687 | partitioners=self._partitioners, 688 | regularizers=self._regularizers, 689 | name=name) 690 | 691 | class GraphAttentionLayer(AbstractGraphLayer): 692 | """Linear transformation on an embedding, each independently. 693 | 694 | This functions almost exactly like snt.Linear except it is for tensors of 695 | size batch_size x nodes x input_size. Acts by matrix multiplication on the 696 | left side of each nodes x input_size matrix. 
697 | """ 698 | def __init__(self, 699 | output_size, 700 | activation='relu', 701 | attn_activation='leakyrelu', 702 | use_bias=True, 703 | sparse=False, 704 | initializers=None, 705 | partitioners=None, 706 | regularizers=None, 707 | custom_getter=None, 708 | name="graph_attn"): 709 | super(GraphAttentionLayer, self).__init__( 710 | output_size, 711 | use_bias=use_bias, 712 | initializers=initializers, 713 | partitioners=partitioners, 714 | regularizers=regularizers, 715 | custom_getter=custom_getter, 716 | name=name) 717 | self._sparse = sparse 718 | self._activ = tfutils.get_tf_activ(activation) 719 | self._attn_activ = tfutils.get_tf_activ(attn_activation) 720 | self.weight_keys = { ("w", output_size), ("u", output_size), 721 | ("f1", 1), ("f2", 1) } 722 | self.bias_keys = set() 723 | self.weights = { x[0] : None for x in self.weight_keys } 724 | if use_bias: 725 | self.bias_keys = { ("b", output_size), ("c", output_size), 726 | ("d1", 1), ("d2", 1) } 727 | for x in self.bias_keys: 728 | self.weights[x[0]] = None 729 | self.possible_keys = self.get_possible_initializer_keys(use_bias=use_bias) 730 | 731 | @classmethod 732 | def get_possible_initializer_keys(cls, use_bias=True): 733 | if use_bias: 734 | return {"w", "u", "b", "c", "f1", "f2", "d1", "d2"} 735 | else: 736 | return {"w", "u", "f1", "f2"} 737 | 738 | def _build(self, laplacian, inputs): 739 | input_shape = tuple(inputs.get_shape().as_list()) 740 | if len(input_shape) != 3: 741 | raise snt.IncompatibleShapeError( 742 | "{}: rank of shape must be 3 not: {}".format( 743 | self.scope_name, len(input_shape))) 744 | 745 | if input_shape[2] is None: 746 | raise snt.IncompatibleShapeError( 747 | "{}: Input size must be specified at module build time".format( 748 | self.scope_name)) 749 | 750 | if input_shape[1] is None: 751 | raise snt.IncompatibleShapeError( 752 | "{}: Number of nodes must be specified at module build time".format( 753 | self.scope_name)) 754 | 755 | if self._input_shape is not None and \ 
756 | (input_shape[2] != self._input_shape[2] or \ 757 | input_shape[1] != self._input_shape[1]): 758 | raise snt.IncompatibleShapeError( 759 | "{}: Input shape must be [batch_size, {}, {}] not: [batch_size, {}, {}]" 760 | .format(self.scope_name, 761 | self._input_shape[1], 762 | self._input_shape[2], 763 | input_shape[1], 764 | input_shape[2])) 765 | 766 | 767 | self._input_shape = input_shape 768 | dtype = inputs.dtype 769 | 770 | for k, s in self.weight_keys: 771 | if k not in self._initializers: 772 | self._initializers[k] = tfutils.create_linear_initializer( 773 | self._input_shape[2], s, dtype) 774 | 775 | if self._use_bias: 776 | for k, s in self.bias_keys: 777 | if k not in self._initializers: 778 | self._initializers[k] = tfutils.create_bias_initializer( 779 | self._input_shape[2], s, dtype) 780 | 781 | for k, s in self.weight_keys: 782 | weight_shape = (self._input_shape[2], s) 783 | self.weights[k] = tf.get_variable( 784 | k, 785 | shape=weight_shape, 786 | dtype=dtype, 787 | initializer=self._initializers[k], 788 | partitioner=self._partitioners.get(k, None), 789 | regularizer=self._regularizers.get(k, None)) 790 | if self.weights[k] not in tf.get_collection('weights'): 791 | tf.add_to_collection('weights', self.weights[k]) 792 | 793 | if self._use_bias: 794 | for k, s in self.bias_keys: 795 | bias_shape = (s,) 796 | self.weights[k] = tf.get_variable( 797 | k, 798 | shape=bias_shape, 799 | dtype=dtype, 800 | initializer=self._initializers[k], 801 | partitioner=self._partitioners.get(k, None), 802 | regularizer=self._regularizers.get(k, None)) 803 | if self.weights[k] not in tf.get_collection('biases'): 804 | tf.add_to_collection('biases', self.weights[k]) 805 | 806 | preactiv_ = tfutils.matmul(inputs, self.weights["w"]) 807 | f1_ = tfutils.matmul(inputs, self.weights["f1"]) 808 | f2_ = tfutils.matmul(inputs, self.weights["f2"]) 809 | if self._use_bias: 810 | f1_ += self.weights["d1"] 811 | f2_ += self.weights["d2"] 812 | preattn_mat_ = f1_ + 
tf.transpose(f2_, [0, 2, 1]) 813 | if self._sparse: 814 | preattn_mat = self._attn_activ(preattn_mat_) * laplacian 815 | else: 816 | preattn_mat = self._attn_activ(preattn_mat_) + laplacian 817 | attn_mat = tf.nn.softmax(preattn_mat, axis=-1) 818 | preactiv = tfutils.batch_matmul(attn_mat, preactiv_) 819 | skip = tfutils.matmul(inputs, self.weights["u"]) 820 | 821 | if self._use_bias: 822 | preactiv += self.weights["b"] 823 | skip += self.weights["c"] 824 | 825 | activ = self._activ(preactiv) + skip 826 | 827 | return activ 828 | 829 | @property 830 | def w(self): 831 | """Returns the Variable containing the weight parameters. 832 | Output: w (tf.Tensor) - weights, from the most recent __call__. 833 | Raises: 834 | snt.NotConnectedError: If the module has not been connected to the 835 | graph yet, meaning the variables do not exist. 836 | """ 837 | self._ensure_is_connected() 838 | return self.weights["w"] 839 | 840 | @property 841 | def u(self): 842 | """Returns the Variable containing the skip weight parameters. 843 | Output: u (tf.Tensor) - skip weights, from the most recent __call__. 844 | Raises: 845 | snt.NotConnectedError: If the module has not been connected to the 846 | graph yet, meaning the variables do not exist. 847 | """ 848 | self._ensure_is_connected() 849 | return self.weights["u"] 850 | 851 | @property 852 | def f1(self): 853 | """Returns the Variable containing the first attention weight parameters. 854 | Output: f1 (tf.Tensor) - attention weights, from the most recent __call__. 855 | Raises: 856 | snt.NotConnectedError: If the module has not been connected to the 857 | graph yet, meaning the variables do not exist. 858 | """ 859 | self._ensure_is_connected() 860 | return self.weights["f1"] 861 | 862 | @property 863 | def f2(self): 864 | """Returns the Variable containing the second attention weight parameters. 865 | Output: f2 (tf.Tensor) - attention weights, from the most recent __call__. 
866 | Raises: 867 | snt.NotConnectedError: If the module has not been connected to the 868 | graph yet, meaning the variables do not exist. 869 | """ 870 | self._ensure_is_connected() 871 | return self.weights["f2"] 872 | 873 | @property 874 | def b(self): 875 | """Returns the Variable containing the bias parameters. 876 | Output: b (tf.Tensor) - biases, from the most recent __call__. 877 | Raises: 878 | snt.NotConnectedError: If the module has not been connected to the 879 | graph yet, meaning the variables do not exist. 880 | AttributeError: If the module does not use bias. 881 | """ 882 | self._ensure_is_connected() 883 | if not self._use_bias: 884 | raise AttributeError( 885 | "No bias Variable in Linear Module when `use_bias=False`.") 886 | return self.weights["b"] 887 | 888 | @property 889 | def c(self): 890 | """Returns the Variable containing the skip bias parameters. 891 | Output: b (tf.Tensor) - skip biases, from the most recent __call__. 892 | Raises: 893 | snt.NotConnectedError: If the module has not been connected to the 894 | graph yet, meaning the variables do not exist. 895 | AttributeError: If the module does not use bias. 896 | """ 897 | self._ensure_is_connected() 898 | if not self._use_bias: 899 | raise AttributeError( 900 | "No bias Variable in Linear Module when `use_bias=False`.") 901 | return self.weights["c"] 902 | 903 | @property 904 | def d1(self): 905 | """Returns the Variable containing first attention bias. 906 | Output: d1 (tf.Tensor) - attention biases, from the most recent __call__. 907 | Raises: 908 | snt.NotConnectedError: If the module has not been connected to the 909 | graph yet, meaning the variables do not exist. 910 | AttributeError: If the module does not use bias. 
911 | """ 912 | self._ensure_is_connected() 913 | if not self._use_bias: 914 | raise AttributeError( 915 | "No bias Variable in Linear Module when `use_bias=False`.") 916 | return self.weights["d1"] 917 | 918 | @property 919 | def d2(self): 920 | """Returns the Variable containing second attention bias. 921 | Output: d2 (tf.Tensor) - attention biases, from the most recent __call__. 922 | Raises: 923 | snt.NotConnectedError: If the module has not been connected to the 924 | graph yet, meaning the variables do not exist. 925 | AttributeError: If the module does not use bias. 926 | """ 927 | self._ensure_is_connected() 928 | if not self._use_bias: 929 | raise AttributeError( 930 | "No bias Variable in Linear Module when `use_bias=False`.") 931 | return self.weights["d2"] 932 | 933 | def clone(self, name=None): 934 | """Returns a cloned `GraphAttentionLayer` module. 935 | Input: 936 | - name (string, optional) - name of cloned module. The default name 937 | is constructed by appending "_clone" to `self.module_name`. 938 | Output: net (snt.Module) - Cloned `GraphAttentionLayer` module. 939 | """ 940 | if name is None: 941 | name = self.module_name + "_clone" 942 | return GraphAttentionLayer(output_size=self.output_size, 943 | use_bias=self._use_bias, 944 | initializers=self._initializers, 945 | partitioners=self._partitioners, 946 | regularizers=self._regularizers, 947 | name=name) 948 | 949 | 950 | if __name__ == "__main__": 951 | pass 952 | 953 | 954 | --------------------------------------------------------------------------------