├── cirtorch
│   ├── layers
│   │   ├── __init__.py
│   │   ├── normalization.py
│   │   ├── pooling.py
│   │   ├── loss.py
│   │   └── functional.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── general.py
│   │   ├── whiten.py
│   │   ├── evaluate.py
│   │   └── download.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── testdataset.py
│   │   ├── datahelpers.py
│   │   ├── genericdataset.py
│   │   └── traindataset.py
│   ├── examples
│   │   ├── __init__.py
│   │   └── attack
│   │       ├── myutil
│   │       │   ├── __init__.py
│   │       │   ├── baseline.py
│   │       │   ├── utils.py
│   │       │   ├── sfm_dataset.py
│   │       │   ├── distillation_dataset.py
│   │       │   ├── triplet_dataset.py
│   │       │   └── mi_sgd.py
│   │       ├── classifier.py
│   │       ├── extract_rank.py
│   │       ├── distillation.py
│   │       └── attack.py
│   ├── networks
│   │   ├── __init__.py
│   │   └── imageretrievalnet.py
│   └── __init__.py
└── README.md
/cirtorch/layers/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/cirtorch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/cirtorch/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/cirtorch/examples/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/cirtorch/networks/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/cirtorch/examples/attack/myutil/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/cirtorch/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import datasets, examples, layers, networks, utils 2 | 3 | from .datasets import datahelpers, genericdataset, testdataset, traindataset 4 | from .layers import functional, loss, normalization, pooling 5 | from .networks import imageretrievalnet 6 | from .utils import general, download, evaluate, whiten -------------------------------------------------------------------------------- /cirtorch/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import cirtorch.layers.functional as LF 5 | 6 | # -------------------------------------- 7 | # Normalization layers 8 | # -------------------------------------- 9 | 10 | class L2N(nn.Module): 11 | 12 | def __init__(self, eps=1e-6): 13 | super(L2N,self).__init__() 14 | self.eps = eps 15 | 16 | def forward(self, x): 17 | return LF.l2n(x, eps=self.eps) 18 | 19 | def __repr__(self): 20 | return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')' 21 | -------------------------------------------------------------------------------- /cirtorch/utils/general.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | 4 | def get_root(): 5 | return os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))) 6 | 7 | 8 | def get_data_root(): 9 | return os.path.join(get_root(), 'data') 10 | 11 | 12 | def htime(c): 13 | c = round(c) 14 | 15 | days = c // 86400 16 | hours = c // 3600 % 24 17 | minutes = c // 60 % 60 18 | seconds = c % 60 19 | 20 | if days > 0: 21 | return '{:d}d {:d}h {:d}m {:d}s'.format(days, hours, minutes, seconds) 22 | if hours > 0: 23 | return '{:d}h {:d}m {:d}s'.format(hours, minutes, seconds) 24 | if minutes > 0: 25 | return '{:d}m {:d}s'.format(minutes, seconds) 26 | return '{:d}s'.format(seconds) 27 | 28 | 29 | def sha256_hash(filename, block_size=65536, length=8): 30 | sha256 = hashlib.sha256() 31 | with open(filename, 'rb') as f: 32 | for block in iter(lambda: f.read(block_size), b''): 33 | sha256.update(block) 34 | return sha256.hexdigest()[:length-1] -------------------------------------------------------------------------------- /cirtorch/datasets/testdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | DATASETS = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k'] 5 | 6 | def configdataset(dataset, dir_main): 7 | 8 | dataset = dataset.lower() 9 | 10 | if dataset not in DATASETS: 11 | raise ValueError('Unknown dataset: {}!'.format(dataset)) 12 | 13 | # loading imlist, qimlist, and gnd, in cfg as a dict 14 | gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset)) 15 | with open(gnd_fname, 'rb') as f: 16 | cfg = pickle.load(f) 17 | cfg['gnd_fname'] = gnd_fname 18 | 19 | cfg['ext'] = '.jpg' 20 | cfg['qext'] = '.jpg' 21 | cfg['dir_data'] = os.path.join(dir_main, dataset) 22 | cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg') 23 | 24 | cfg['n'] = len(cfg['imlist']) 25 | cfg['nq'] = len(cfg['qimlist']) 26 | 27 | cfg['im_fname'] = config_imname 28 | cfg['qim_fname'] = config_qimname 29 | 30 | cfg['dataset'] = dataset 31 | 32 | return cfg 33 | 34 | def config_imname(cfg, i): 35 | return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext']) 36 | 37 | def config_qimname(cfg, i): 38 | return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext']) 39 | 
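A minimal usage sketch for the `configdataset()` helper above (an illustration only, assuming the ground-truth file `gnd_<dataset>.pkl` has already been downloaded under `data/test/<dataset>/`, e.g. via `download_test()` in `cirtorch/utils/download.py` further below):

```
import os

from cirtorch.datasets.testdataset import configdataset
from cirtorch.utils.general import get_data_root

# load the configuration dict of a test dataset
cfg = configdataset('roxford5k', os.path.join(get_data_root(), 'test'))

# cfg['im_fname'] / cfg['qim_fname'] resolve an index to a full image path
images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])]
qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]

print('{}: {} database images, {} queries'.format(cfg['dataset'], len(images), len(qimages)))
```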
-------------------------------------------------------------------------------- /cirtorch/utils/whiten.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def whitenapply(X, m, P, dimensions=None): 5 | 6 | if not dimensions: 7 | dimensions = P.shape[0] 8 | 9 | X = np.dot(P[:dimensions, :], X - m) 10 | X = X / (np.linalg.norm(X, ord=2, axis=0, keepdims=True) + 1e-6) 11 | 12 | return X 13 | 14 | 15 | def pcawhitenlearn(X): 16 | 17 | N = X.shape[1] 18 | 19 | # Learning PCA w/o annotations 20 | m = X.mean(axis=1, keepdims=True) 21 | Xc = X - m 22 | Xcov = np.dot(Xc, Xc.T) 23 | Xcov = (Xcov + Xcov.T) / (2 * N) 24 | eigval, eigvec = np.linalg.eig(Xcov) 25 | order = eigval.argsort()[::-1] 26 | eigval = eigval[order] 27 | eigvec = eigvec[:, order] 28 | 29 | P = np.dot(np.linalg.inv(np.sqrt(np.diag(eigval))), eigvec.T) 30 | 31 | return m, P 32 | 33 | 34 | def whitenlearn(X, qidxs, pidxs): 35 | 36 | # Learning Lw w annotations 37 | m = X[:, qidxs].mean(axis=1, keepdims=True) 38 | df = X[:, qidxs] - X[:, pidxs] 39 | 40 | S = np.dot(df, df.T) / df.shape[1] 41 | P = np.linalg.inv(np.linalg.cholesky(S)) 42 | df = np.dot(P, X - m) 43 | D = np.dot(df, df.T) 44 | eigval, eigvec = np.linalg.eig(D) 45 | order = eigval.argsort()[::-1] 46 | eigval = eigval[order] 47 | eigvec = eigvec[:, order] 48 | 49 | P = np.dot(eigvec.T, P) 50 | 51 | return m, P 52 | -------------------------------------------------------------------------------- /cirtorch/layers/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | import cirtorch.layers.functional as LF 6 | 7 | # -------------------------------------- 8 | # Pooling layers 9 | # -------------------------------------- 10 | 11 | class MAC(nn.Module): 12 | 13 | def __init__(self): 14 | super(MAC,self).__init__() 15 | 16 | def forward(self, x): 17 | return LF.mac(x) 18 | 19 | def __repr__(self): 20 | return self.__class__.__name__ + '()' 21 | 22 | class SPoC(nn.Module): 23 | 24 | def __init__(self): 25 | super(SPoC,self).__init__() 26 | 27 | def forward(self, x): 28 | return LF.spoc(x) 29 | 30 | def __repr__(self): 31 | return self.__class__.__name__ + '()' 32 | 33 | class GeM(nn.Module): 34 | 35 | def __init__(self, p=3, eps=1e-6): 36 | super(GeM,self).__init__() 37 | self.p = Parameter(torch.ones(1)*p) 38 | self.eps = eps 39 | 40 | def forward(self, x): 41 | return LF.gem(x, p=self.p, eps=self.eps) 42 | 43 | def __repr__(self): 44 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 45 | 46 | class RMAC(nn.Module): 47 | 48 | def __init__(self, L=3, eps=1e-6): 49 | super(RMAC,self).__init__() 50 | self.L = L 51 | self.eps = eps 52 | 53 | def forward(self, x): 54 | return LF.rmac(x, L=self.L, eps=self.eps) 55 | 56 | def __repr__(self): 57 | return self.__class__.__name__ + '(' + 'L=' + '{}'.format(self.L) + ')' 58 | 59 | -------------------------------------------------------------------------------- /cirtorch/examples/attack/myutil/baseline.py: -------------------------------------------------------------------------------- 1 | result = { 2 | 'alexnet': { 3 | 'mac': {'oxford5k': [57.11], 4 | 'paris6k': [65.64], 5 | 'roxford5k': [45.23, 32.96, 10.43, 57.25, 55.43, 15.36], 6 | 'rparis6k': [63.99, 46.93, 20.06, 88, 91.29, 58.29]}, 7 | 'gem': {'oxford5k': [59.86], 8 | 
'paris6k': [73.66], 9 | 'roxford5k': [50.21, 36.72, 14.29, 58.1, 53.6, 23.32], 10 | 'rparis6k': [70.65, 51.89, 22.8, 87.71, 88.86, 57.86]} 11 | }, 12 | 'vgg16': { 13 | 'mac': {'oxford5k': [81.45], 14 | 'paris6k': [88.31], 15 | 'roxford5k': [75.07, 57.15, 29.96, 78.6, 78.33, 45.57], 16 | 'rparis6k': [86.39, 69.6, 44.97, 93.57, 96.86, 84.71]}, 17 | 'gem': {'oxford5k': [85.24], 18 | 'paris6k': [86.28], 19 | 'roxford5k': [76.43, 59.17, 32.26, 80.52, 81.29, 49.71], 20 | 'rparis6k': [84.66, 67.06, 42.4, 95.14, 97.57, 83]} 21 | }, 22 | 'resnet101': { 23 | 'mac': {'oxford5k': [81.69], 24 | 'paris6k': [83.55], 25 | 'roxford5k': [73.85, 56.14, 29.8, 78.33, 79.86, 46.57], 26 | 'rparis6k': [81.56, 63.91, 39.06, 93.52, 96.71, 79.57]}, 27 | 'gem': {'oxford5k': [86.24], 28 | 'paris6k': [90.66], 29 | 'roxford5k': [80.63, 63.13, 38.51, 82.72, 83.14, 54.57], 30 | 'rparis6k': [90.33, 74.06, 51.69, 94.96, 98.29, 88.29]} 31 | }, 32 | } 33 | -------------------------------------------------------------------------------- /cirtorch/datasets/datahelpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | import torch 5 | 6 | def cid2filename(cid, prefix): 7 | """ 8 | Creates a training image path out of its CID name 9 | 10 | Arguments 11 | --------- 12 | cid : name of the image 13 | prefix : root directory where images are saved 14 | 15 | Returns 16 | ------- 17 | filename : full image filename 18 | """ 19 | return os.path.join(prefix, cid[-2:], cid[-4:-2], cid[-6:-4], cid) 20 | 21 | def pil_loader(path): 22 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 23 | with open(path, 'rb') as f: 24 | img = Image.open(f) 25 | return img.convert('RGB') 26 | 27 | def accimage_loader(path): 28 | import accimage 29 | try: 30 | return accimage.Image(path) 31 | except IOError: 32 | # Potentially a decoding problem, fall back to PIL.Image 33 | return pil_loader(path) 34 | 35 | def default_loader(path): 36 | from torchvision import get_image_backend 37 | if get_image_backend() == 'accimage': 38 | return accimage_loader(path) 39 | else: 40 | return pil_loader(path) 41 | 42 | def imresize(img, imsize): 43 | if isinstance(imsize, int): 44 | img.thumbnail((imsize, imsize), Image.ANTIALIAS) 45 | else: 46 | img = img.resize(imsize) 47 | return img 48 | 49 | def flip(x, dim): 50 | xsize = x.size() 51 | dim = x.dim() + dim if dim < 0 else dim 52 | x = x.view(-1, *xsize[dim:]) 53 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] 54 | return x.view(xsize) 55 | 56 | def collate_tuples(batch): 57 | if len(batch) == 1: 58 | return [batch[0][0]], [batch[0][1]] 59 | return [batch[i][0] for i in range(len(batch))], [batch[i][1] for i in range(len(batch))] 60 | -------------------------------------------------------------------------------- /cirtorch/layers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import cirtorch.layers.functional as LF 5 | 6 | # -------------------------------------- 7 | # Loss/Error layers 8 | # -------------------------------------- 9 | 10 | class ContrastiveLoss(nn.Module): 11 | r"""Creates a criterion that measures the triplet loss given an input 12 | tensors x1, x2, x3 and a margin with a value greater than 0. 13 | This is used for measuring a relative similarity between samples. 
A triplet 14 | is composed by `a`, `p` and `n`: anchor, positive examples and negative 15 | example respectively. The shape of all input variables should be 16 | :math:`(N, D)`. 17 | 18 | The distance swap is described in detail in the paper `Learning shallow 19 | convolutional feature descriptors with triplet losses`_ by 20 | V. Balntas, E. Riba et al. 21 | 22 | .. math:: 23 | L(a, p, n) = \frac{1}{N} \left( \sum_{i=1}^N \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} \right) 24 | 25 | where :math:`d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p`. 26 | 27 | Args: 28 | anchor: anchor input tensor 29 | positive: positive input tensor 30 | negative: negative input tensor 31 | p: the norm degree. Default: 2 32 | 33 | Shape: 34 | - Input: :math:`(N, D)` where `D = vector dimension` 35 | - Output: :math:`(N, 1)` 36 | 37 | >>> contrastive_loss = ContrastiveLoss(margin=0.7) 38 | >>> input = autograd.Variable(torch.randn(128, 35)) 39 | >>> label = autograd.Variable(torch.Tensor([-1, 1, 0, 0, 0, 0, 0] * 5)) 40 | >>> output = contrastive_loss(input, label) 41 | >>> output.backward() 42 | 43 | .. _Learning shallow convolutional feature descriptors with triplet losses: 44 | http://www.iis.ee.ic.ac.uk/%7Evbalnt/shallow_descr/TFeat_paper.pdf 45 | """ 46 | 47 | def __init__(self, margin=0.7, eps=1e-6): 48 | super(ContrastiveLoss, self).__init__() 49 | self.margin = margin 50 | self.eps = eps 51 | 52 | def forward(self, x, label): 53 | return LF.contrastive_loss(x, label, margin=self.margin, eps=self.eps) 54 | 55 | def __repr__(self): 56 | return self.__class__.__name__ + '(' + 'margin=' + str(self.margin) + ')' 57 | -------------------------------------------------------------------------------- /cirtorch/examples/attack/myutil/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | 6 | SMOOTHER = 0.999999 7 | 8 | 9 | def get_random_size(imsize, w, h): 10 | scale = w / h 11 | DELTA = 8 12 | L = imsize[0] 13 | H = imsize[1] 14 | if scale >= 1: 15 | nw = random.randint(L, H) 16 | nh = int(nw / scale) 17 | nh += random.randint(-DELTA, DELTA) 18 | else: 19 | nh = random.randint(L, H) 20 | nw = int(nh * scale) 21 | nw += random.randint(-DELTA, DELTA) 22 | return (nw, nh) 23 | 24 | 25 | def do_whiten(b, m, p): 26 | b = p @ (b - m) 27 | b = b / (b.norm(dim=0) + 1e-6) 28 | return b 29 | 30 | 31 | def one_hot(size, index): 32 | mask = torch.LongTensor(index.size(0), size).fill_(0).to(index.device) 33 | ret = mask.scatter_(1, index.view(-1, 1), 1) 34 | return ret 35 | 36 | 37 | # def w2img(w, eps): 38 | # return w 39 | # return torch.tanh((w)/SMOOTHER) * eps 40 | 41 | 42 | def rescale_check(check, sat, sat_change, sat_min): 43 | return sat_change < check and sat > sat_min 44 | 45 | 46 | def inv_gfr(attack, baseline): 47 | s = 0 48 | for k in attack.keys(): 49 | s += sum( 50 | [ 51 | abs(attack[k][i] - baseline[k][i]) / baseline[k][i] 52 | for i in range(len(attack[k])) 53 | ] 54 | ) 55 | return 1 - s / 14 56 | 57 | 58 | class MultiLoss(nn.Module): 59 | def __init__(self): 60 | super(MultiLoss, self).__init__() 61 | self.weight = nn.Parameter(torch.zeros(3)) 62 | 63 | def forward(self, losses): 64 | w = (-self.weight).exp() 65 | return torch.dot(w, losses) + self.weight.sum() 66 | 67 | 68 | class bcolors: 69 | HEADER = "\033[95m" 70 | OKBLUE = "\033[94m" 71 | OKGREEN = "\033[92m" 72 | WARNING = "\033[93m" 73 | FAIL = "\033[91m" 74 | ENDC = "\033[0m" 75 | BOLD = "\033[1m" 76 
| UNDERLINE = "\033[4m" 77 | 78 | def str(s, color): 79 | return color + s + bcolors.ENDC 80 | 81 | 82 | def idcg(n_rel): 83 | """idcg 84 | 85 | :n_rel: number of relevant documents 86 | :returns: ideal DCG value 87 | 88 | """ 89 | nums = np.ones(n_rel) 90 | denoms = np.log2(np.arange(n_rel) + 2) 91 | return (nums / denoms).sum() 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the project page of our paper: 2 | 3 | **Universal Perturbation Attack Against Image Retrieval**, 4 | Li, J., Ji, R., Liu, H., Hong, X., Gao, Y., & Tian, Q. 5 | ICCV 2019. 6 | [[PDF]](http://openaccess.thecvf.com/content_ICCV_2019/papers/Li_Universal_Perturbation_Attack_Against_Image_Retrieval_ICCV_2019_paper.pdf) 7 | 8 | ## Code 9 | 10 | Our code is based on [filipradenovic/cnnimageretrieval-pytorch](https://github.com/filipradenovic/cnnimageretrieval-pytorch) (commit `c4fca89`). 11 | Please refer to their repository for details. 12 | 13 | The attack code is located in `cirtorch/examples/attack`. 14 | 15 | ### Prepare Features 16 | 17 | 1. Follow the steps in [filipradenovic/cnnimageretrieval-pytorch](https://github.com/filipradenovic/cnnimageretrieval-pytorch) to download the datasets and train the retrieval models. (Our pretrained models are available at [Google Drive](https://drive.google.com/file/d/1iItJHZb2NHh-EyF-A0358a3VmlKg0OMo/view).) 18 | 2. Refer to the function `cluster()` in `cirtorch/examples/attack/myutil/triplet_dataset.py` for details on extracting features and clustering. 19 | 20 | ### Train Classifiers (Optional) 21 | 22 | ``` 23 | python -m cirtorch.examples.attack.classifier PATH 24 | ``` 25 | 26 | ### Generate UAP 27 | 28 | Refer to the arguments in `cirtorch/examples/attack/attack.py` for details. 29 | 30 | ### Ranking Distillation 31 | 32 | 1. Refer to `cirtorch/examples/attack/extract_rank.py` for extracting the ranking lists. 33 | 2. Refer to `cirtorch/examples/attack/distillation.py` for the distillation step. 34 | 35 | ## Typos in Paper 36 | 37 | 1. Eq. 6 should be ![](http://latex.codecogs.com/gif.latex?\frac{\partial%20d(f,f_j)}{\partial\delta}-\frac{\partial%20d(f,f_k)}{\partial\delta}) 38 | 2. Eq.
7 should be ![](http://latex.codecogs.com/gif.latex?m W: 41 | Hd = idx.tolist()[0] 42 | 43 | v = F.max_pool2d(x, (x.size(-2), x.size(-1))) 44 | v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v) 45 | 46 | for l in range(1, L+1): 47 | wl = math.floor(2*w/(l+1)) 48 | wl2 = math.floor(wl/2 - 1) 49 | 50 | if l+Wd == 1: 51 | b = 0 52 | else: 53 | b = (W-wl)/(l+Wd-1) 54 | cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b) - wl2 # center coordinates 55 | if l+Hd == 1: 56 | b = 0 57 | else: 58 | b = (H-wl)/(l+Hd-1) 59 | cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b) - wl2 # center coordinates 60 | 61 | for i_ in cenH.tolist(): 62 | for j_ in cenW.tolist(): 63 | if wl == 0: 64 | continue 65 | R = x[:,:,(int(i_)+torch.Tensor(range(wl)).long()).tolist(),:] 66 | R = R[:,:,:,(int(j_)+torch.Tensor(range(wl)).long()).tolist()] 67 | vt = F.max_pool2d(R, (R.size(-2), R.size(-1))) 68 | vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt) 69 | v += vt 70 | 71 | return v 72 | 73 | 74 | # -------------------------------------- 75 | # normalization 76 | # -------------------------------------- 77 | 78 | def l2n(x, eps=1e-6): 79 | return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x) 80 | 81 | 82 | # -------------------------------------- 83 | # loss 84 | # -------------------------------------- 85 | 86 | def contrastive_loss(x, label, margin=0.7, eps=1e-6): 87 | # x is D x N 88 | dim = x.size(0) # D 89 | nq = torch.sum(label.data==-1) # number of tuples 90 | S = x.size(1) // nq # number of images per tuple including query: 1+1+n 91 | 92 | x1 = x[:, ::S].permute(1,0).repeat(1,S-1).view((S-1)*nq,dim).permute(1,0) 93 | idx = [i for i in range(len(label)) if label.data[i] != -1] 94 | x2 = x[:, idx] 95 | lbl = label[label!=-1] 96 | 97 | dif = x1 - x2 98 | D = torch.pow(dif+eps, 2).sum(dim=0).sqrt() 99 | 100 | y = 0.5*lbl*torch.pow(D,2) + 0.5*(1-lbl)*torch.pow(torch.clamp(margin-D, min=0),2) 101 | y = torch.sum(y) 102 | return y -------------------------------------------------------------------------------- /cirtorch/examples/attack/classifier.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import shutil 5 | import time 6 | import math 7 | import pickle 8 | import pdb 9 | from glob import glob 10 | from pprint import pprint 11 | 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.optim 17 | import torch.utils.data 18 | import torchvision.transforms as transforms 19 | import torchvision.models as models 20 | import torchvision.datasets as datasets 21 | 22 | N_CLASS = 512 23 | 24 | class AverageMeter(object): 25 | """Computes and stores the average and current value""" 26 | 27 | def __init__(self): 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = 0 32 | self.avg = 0 33 | self.sum = 0 34 | self.count = 0 35 | 36 | def update(self, val, n=1): 37 | self.val = val 38 | self.sum += val * n 39 | self.count += n 40 | self.avg = self.sum / self.count 41 | 42 | def group_list(x, y, group_size): 43 | for i in range(0, len(x), group_size): 44 | yield x[i: i+group_size], y[i: i+group_size] 45 | 46 | 47 | def train(dataset_name): 48 | print(dataset_name) 49 | dataset = pickle.load(open(dataset_name, 'rb')) 50 | targets = [] 51 | for i in range(N_CLASS): 52 | l = dataset['clustered_pool'][i].shape[0] 53 | targets += [i for _ in range(l)] 54 | X = np.concatenate(dataset['clustered_pool']) 55 | Y = np.array(targets) 56 | 
# cls = nn.Linear(X.shape[1], N_CLASS).cuda() 57 | cls = nn.Sequential( 58 | nn.Linear(X.shape[1], 512), 59 | nn.ReLU(True), 60 | nn.Linear(512, N_CLASS), 61 | ).cuda() 62 | # define optimizer 63 | optimizer = torch.optim.Adam(cls.parameters(), 1e-3, weight_decay=5e-4) 64 | criteria = nn.CrossEntropyLoss() 65 | min_loss = float('inf') 66 | 67 | for epoch in range(50): 68 | print(epoch) 69 | 70 | # set manual seeds per epoch 71 | np.random.seed(epoch) 72 | torch.manual_seed(epoch) 73 | torch.cuda.manual_seed_all(epoch) 74 | 75 | 76 | index = np.random.permutation(X.shape[0]) 77 | x = X[index] 78 | y = Y[index] 79 | 80 | index = 0 81 | acces = AverageMeter() 82 | losses = AverageMeter() 83 | for (data, target) in group_list(x, y, 128): 84 | data = torch.from_numpy(data).cuda() 85 | target = torch.from_numpy(target).cuda() 86 | 87 | optimizer.zero_grad() 88 | output = cls(data) 89 | loss = criteria(output, target) 90 | loss.backward() 91 | optimizer.step() 92 | 93 | pred = output.data.max(1, keepdim=True)[1] 94 | correct = pred.eq(target.data.view_as(pred)).cpu().sum().item() 95 | 96 | acces.update(correct / data.size(0)) 97 | losses.update(loss.item()) 98 | 99 | print(f'Epoch :{epoch}\tLoss: {losses.val}({losses.avg})\t' 100 | f'Acc: {acces.val}({acces.avg})') 101 | if losses.avg < min_loss: 102 | # torch.save(cls.state_dict(), dataset_name + '_cls.pth') 103 | torch.save(cls, dataset_name + '_cls.pth') 104 | 105 | 106 | def main(): 107 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 108 | 109 | # set random seeds (maybe pass as argument) 110 | torch.manual_seed(0) 111 | torch.cuda.manual_seed_all(0) 112 | np.random.seed(0) 113 | 114 | path = sys.argv[1] 115 | datasets = glob(path + '/*.KMeans') 116 | pprint(datasets) 117 | 118 | for dataset in datasets: 119 | train(dataset) 120 | 121 | main() 122 | 123 | -------------------------------------------------------------------------------- /cirtorch/examples/attack/myutil/distillation_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import pdb 4 | from random import sample 5 | import random 6 | import time 7 | 8 | import numpy as np 9 | import torch 10 | import torch.utils.data as data 11 | from torch.autograd import Variable 12 | 13 | from cirtorch.networks.imageretrievalnet import extract_vectors 14 | from cirtorch.datasets.datahelpers import default_loader, imresize, cid2filename 15 | from cirtorch.datasets.genericdataset import ImagesFromList 16 | from cirtorch.utils.general import get_data_root 17 | from sklearn.cluster import KMeans 18 | from cirtorch.examples.attack.myutil.utils import get_random_size 19 | from cirtorch.examples.attack.myutil.utils import do_whiten 20 | 21 | import matplotlib 22 | 23 | matplotlib.use("Agg") 24 | import matplotlib.pyplot as plt 25 | 26 | 27 | class Distillation_dataset(data.Dataset): 28 | def __init__( 29 | self, 30 | imsize=None, 31 | nnum=5, 32 | qsize=2000, 33 | poolsize=20000, 34 | transform=None, 35 | loader=default_loader, 36 | filename=None, 37 | q_percent=1, 38 | ): 39 | 40 | # setting up paths 41 | data_root = get_data_root() 42 | name = "retrieval-SfM-120k" 43 | db_root = os.path.join(data_root, "train", name) 44 | ims_root = os.path.join(db_root, "ims") 45 | 46 | # loading db 47 | db_fn = os.path.join(db_root, "{}.pkl".format(name)) 48 | with open(db_fn, "rb") as f: 49 | db = pickle.load(f)["val"] 50 | 51 | # initializing tuples dataset 52 | self.imsize = imsize 53 | self.images = [ 54 | cid2filename(db["cids"][i], 
ims_root) for i in range(len(db["cids"])) 55 | ] 56 | self.clusters = db["cluster"] 57 | self.qpool = db["qidxs"] 58 | # self.ppool = db['pidxs'] 59 | 60 | # size of training subset for an epoch 61 | self.nnum = nnum 62 | self.qsize = min(qsize, len(self.qpool)) 63 | self.poolsize = min(poolsize, len(self.images)) 64 | self.qidxs = self.qpool 65 | self.index = np.arange(len(self.qidxs)) 66 | 67 | if q_percent < 1: 68 | number = int(len(self.qidxs) * q_percent) 69 | self.index = np.random.permutation(self.index) 70 | self.index = self.index[:number] 71 | 72 | self.pidxs = [] 73 | self.nidxs = [] 74 | 75 | self.poolvecs = None 76 | 77 | self.transform = transform 78 | self.loader = loader 79 | self.filename = filename 80 | self.phase = 1 81 | 82 | self.ranks = torch.load(f"{filename}/ranks_362") 83 | if os.path.isfile(f"{filename}/pool_vecs"): 84 | self.pool_vecs = pickle.load(open(f"{filename}/pool_vecs", "rb")) 85 | print(len(self.images)) 86 | self.loaded_images = [] 87 | if os.path.exists("./images"): 88 | self.loaded_images = pickle.load(open("./images", "rb")) 89 | else: 90 | for i in range(len(self.images)): 91 | try: 92 | img = self.loader(self.images[i]) 93 | if self.imsize is not None: 94 | img = imresize(img, self.imsize) 95 | if self.transform is not None: 96 | img_tensor = self.transform(img).unsqueeze(0) 97 | img.close() 98 | self.loaded_images.append(img_tensor) 99 | except: 100 | self.loaded_images.append(None) 101 | pickle.dump(self.loaded_images, open("./images", "wb")) 102 | 103 | def __getitem__(self, index): 104 | # Not used 105 | # coding in ../distillation.py 106 | pass 107 | 108 | def __len__(self): 109 | return len(self.index) 110 | if not self.qidxs: 111 | return 0 112 | return len(self.qidxs) 113 | -------------------------------------------------------------------------------- /cirtorch/datasets/genericdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.utils.data as data 5 | 6 | from cirtorch.datasets.datahelpers import default_loader, imresize 7 | from cirtorch.examples.attack.myutil.utils import get_random_size 8 | 9 | 10 | class ImagesFromList(data.Dataset): 11 | """A generic data loader that loads images from a list 12 | (Based on ImageFolder from pytorch) 13 | 14 | Args: 15 | root (string): Root directory path. 16 | images (list): Relative image paths as strings. 17 | imsize (int, Default: None): Defines the maximum size of longer image side 18 | bbxs (list): List of (x1,y1,x2,y2) tuples to crop the query images 19 | transform (callable, optional): A function/transform that takes in an PIL image 20 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 21 | loader (callable, optional): A function to load an image given its path. 
22 | 23 | Attributes: 24 | images_fn (list): List of full image filename 25 | """ 26 | 27 | def __init__( 28 | self, 29 | root, 30 | images, 31 | imsize=None, 32 | bbxs=None, 33 | transform=None, 34 | loader=default_loader, 35 | random=False, 36 | ): 37 | 38 | images_fn = [os.path.join(root, images[i]) for i in range(len(images))] 39 | 40 | if len(images_fn) == 0: 41 | raise (RuntimeError("Dataset contains 0 images!")) 42 | 43 | self.root = root 44 | self.images = images 45 | self.imsize = imsize 46 | self.images_fn = images_fn 47 | self.bbxs = bbxs 48 | self.transform = transform 49 | self.loader = loader 50 | self.random_size = random 51 | 52 | def __getitem__(self, index): 53 | """ 54 | Args: 55 | index (int): Index 56 | 57 | Returns: 58 | image (PIL): Loaded image 59 | """ 60 | path = self.images_fn[index] 61 | img = self.loader(path) 62 | if self.bbxs: 63 | img = img.crop(self.bbxs[index]) 64 | if self.imsize is not None: 65 | if self.random_size: 66 | w, h = img.size 67 | imsize = get_random_size(self.imsize, w, h) 68 | img = imresize(img, imsize) 69 | else: 70 | img = imresize(img, self.imsize) 71 | if self.transform is not None: 72 | img = self.transform(img) 73 | 74 | return img 75 | 76 | def __len__(self): 77 | return len(self.images_fn) 78 | 79 | def __repr__(self): 80 | fmt_str = "Dataset " + self.__class__.__name__ + "\n" 81 | fmt_str += " Number of images: {}\n".format(self.__len__()) 82 | fmt_str += " Root Location: {}\n".format(self.root) 83 | tmp = " Transforms (if any): " 84 | fmt_str += "{0}{1}\n".format( 85 | tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 86 | ) 87 | return fmt_str 88 | 89 | 90 | class ImagesFromDataList(data.Dataset): 91 | """A generic data loader that loads images given as an array of pytorch tensors 92 | (Based on ImageFolder from pytorch) 93 | 94 | Args: 95 | images (list): Images as tensors. 96 | transform (callable, optional): A function/transform that image as a tensors 97 | and returns a transformed version. E.g, ``normalize`` with mean and std 98 | """ 99 | 100 | def __init__(self, images, transform=None): 101 | 102 | if len(images) == 0: 103 | raise (RuntimeError("Dataset contains 0 images!")) 104 | 105 | self.images = images 106 | self.transform = transform 107 | 108 | def __getitem__(self, index): 109 | """ 110 | Args: 111 | index (int): Index 112 | 113 | Returns: 114 | image (Tensor): Loaded image 115 | """ 116 | img = self.images[index] 117 | if self.transform is not None: 118 | img = self.transform(img) 119 | 120 | if len(img.size()): 121 | # img = img.repeat(3,1,1) 122 | img = img.unsqueeze(0) 123 | 124 | return img 125 | 126 | def __len__(self): 127 | return len(self.images) 128 | 129 | def __repr__(self): 130 | fmt_str = "Dataset " + self.__class__.__name__ + "\n" 131 | fmt_str += " Number of images: {}\n".format(self.__len__()) 132 | tmp = " Transforms (if any): " 133 | fmt_str += "{0}{1}\n".format( 134 | tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 135 | ) 136 | return fmt_str 137 | -------------------------------------------------------------------------------- /cirtorch/utils/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def compute_ap(ranks, nres): 4 | """ 5 | Computes average precision for given ranked indexes. 
6 | 7 | Arguments 8 | --------- 9 | ranks : zerro-based ranks of positive images 10 | nres : number of positive images 11 | 12 | Returns 13 | ------- 14 | ap : average precision 15 | """ 16 | 17 | # number of images ranked by the system 18 | nimgranks = len(ranks) 19 | 20 | # accumulate trapezoids in PR-plot 21 | ap = 0 22 | 23 | recall_step = 1. / nres 24 | 25 | for j in np.arange(nimgranks): 26 | rank = ranks[j] 27 | 28 | if rank == 0: 29 | precision_0 = 1. 30 | else: 31 | precision_0 = float(j) / rank 32 | 33 | precision_1 = float(j + 1) / (rank + 1) 34 | 35 | ap += (precision_0 + precision_1) * recall_step / 2. 36 | 37 | return ap 38 | 39 | def compute_map(ranks, gnd, kappas=[]): 40 | """ 41 | Computes the mAP for a given set of returned results. 42 | 43 | Usage: 44 | map = compute_map (ranks, gnd) 45 | computes mean average precsion (map) only 46 | 47 | map, aps, pr, prs = compute_map (ranks, gnd, kappas) 48 | computes mean average precision (map), average precision (aps) for each query 49 | computes mean precision at kappas (pr), precision at kappas (prs) for each query 50 | 51 | Notes: 52 | 1) ranks starts from 0, ranks.shape = db_size X #queries 53 | 2) The junk results (e.g., the query itself) should be declared in the gnd stuct array 54 | 3) If there are no positive images for some query, that query is excluded from the evaluation 55 | """ 56 | 57 | map = 0. 58 | nq = len(gnd) # number of queries 59 | aps = np.zeros(nq) 60 | pr = np.zeros(len(kappas)) 61 | prs = np.zeros((nq, len(kappas))) 62 | nempty = 0 63 | 64 | for i in np.arange(nq): 65 | qgnd = np.array(gnd[i]['ok']) 66 | 67 | # no positive images, skip from the average 68 | if qgnd.shape[0] == 0: 69 | aps[i] = float('nan') 70 | prs[i, :] = float('nan') 71 | nempty += 1 72 | continue 73 | 74 | try: 75 | qgndj = np.array(gnd[i]['junk']) 76 | except: 77 | qgndj = np.empty(0) 78 | 79 | # sorted positions of positive and junk images (0 based) 80 | pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)] 81 | junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)] 82 | 83 | k = 0; 84 | ij = 0; 85 | if len(junk): 86 | # decrease positions of positives based on the number of 87 | # junk images appearing before them 88 | ip = 0 89 | while (ip < len(pos)): 90 | while (ij < len(junk) and pos[ip] > junk[ij]): 91 | k += 1 92 | ij += 1 93 | pos[ip] = pos[ip] - k 94 | ip += 1 95 | 96 | # compute ap 97 | ap = compute_ap(pos, len(qgnd)) 98 | map = map + ap 99 | aps[i] = ap 100 | 101 | # compute precision @ k 102 | pos += 1 # get it to 1-based 103 | for j in np.arange(len(kappas)): 104 | kq = min(max(pos), kappas[j]); 105 | prs[i, j] = (pos <= kq).sum() / kq 106 | pr = pr + prs[i, :] 107 | 108 | map = map / (nq - nempty) 109 | pr = pr / (nq - nempty) 110 | 111 | return map, aps, pr, prs 112 | 113 | 114 | def compute_map_and_print(dataset, ranks, gnd, kappas=[1, 5, 10]): 115 | 116 | # old evaluation protocol 117 | if dataset.startswith('oxford5k') or dataset.startswith('paris6k'): 118 | map, aps, _, _ = compute_map(ranks, gnd) 119 | print('>> {}: mAP {:.2f}'.format(dataset, np.around(map*100, decimals=2))) 120 | return [np.around(map*100, decimals=2)] 121 | 122 | # new evaluation protocol 123 | elif dataset.startswith('roxford5k') or dataset.startswith('rparis6k'): 124 | 125 | gnd_t = [] 126 | for i in range(len(gnd)): 127 | g = {} 128 | g['ok'] = np.concatenate([gnd[i]['easy']]) 129 | g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['hard']]) 130 | gnd_t.append(g) 131 | mapE, apsE, mprE, prsE = compute_map(ranks, gnd_t, kappas) 132 
| 133 | gnd_t = [] 134 | for i in range(len(gnd)): 135 | g = {} 136 | g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']]) 137 | g['junk'] = np.concatenate([gnd[i]['junk']]) 138 | gnd_t.append(g) 139 | mapM, apsM, mprM, prsM = compute_map(ranks, gnd_t, kappas) 140 | 141 | gnd_t = [] 142 | for i in range(len(gnd)): 143 | g = {} 144 | g['ok'] = np.concatenate([gnd[i]['hard']]) 145 | g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']]) 146 | gnd_t.append(g) 147 | mapH, apsH, mprH, prsH = compute_map(ranks, gnd_t, kappas) 148 | 149 | print('>> {}: mAP E: {}, M: {}, H: {}'.format(dataset, np.around(mapE*100, decimals=2), np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2))) 150 | print('>> {}: mP@k{} E: {}, M: {}, H: {}'.format(dataset, kappas, np.around(mprE*100, decimals=2), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2))) 151 | return [np.around(v*100, decimals=2) for v in [mapE, mapM, mapH, mprE[-1], mprM[-1], mprH[-1]]] 152 | -------------------------------------------------------------------------------- /cirtorch/examples/attack/myutil/triplet_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import pdb 4 | from random import sample 5 | import random 6 | import time 7 | 8 | import numpy as np 9 | import torch 10 | import torch.utils.data as data 11 | from torch.autograd import Variable 12 | 13 | from cirtorch.networks.imageretrievalnet import extract_vectors 14 | from cirtorch.datasets.datahelpers import default_loader, imresize, cid2filename 15 | from cirtorch.datasets.genericdataset import ImagesFromList 16 | from cirtorch.utils.general import get_data_root 17 | from sklearn.cluster import KMeans 18 | from cirtorch.examples.attack.myutil.utils import get_random_size 19 | from cirtorch.examples.attack.myutil.utils import do_whiten 20 | 21 | import matplotlib 22 | 23 | matplotlib.use("Agg") 24 | import matplotlib.pyplot as plt 25 | 26 | FNAME = "base/pool_vgg_gem" 27 | TIMES = 1 28 | 29 | 30 | class MyTripletDataset(data.Dataset): 31 | def __init__( 32 | self, 33 | imsize=None, 34 | nnum=5, 35 | qsize=2000, 36 | poolsize=20000, 37 | transform=None, 38 | loader=default_loader, 39 | norm=None, 40 | filename=None, 41 | random=True, 42 | ): 43 | 44 | # setting up paths 45 | data_root = get_data_root() 46 | name = "retrieval-SfM-120k" 47 | db_root = os.path.join(data_root, "train", name) 48 | ims_root = os.path.join(db_root, "ims") 49 | 50 | # loading db 51 | db_fn = os.path.join(db_root, "{}.pkl".format(name)) 52 | with open(db_fn, "rb") as f: 53 | db = pickle.load(f)["val"] 54 | 55 | # initializing tuples dataset 56 | self.imsize = imsize 57 | self.images = [ 58 | cid2filename(db["cids"][i], ims_root) for i in range(len(db["cids"])) 59 | ] 60 | self.clusters = db["cluster"] 61 | self.qpool = db["qidxs"] 62 | # self.ppool = db['pidxs'] 63 | 64 | # size of training subset for an epoch 65 | self.nnum = nnum 66 | self.qsize = min(qsize, len(self.qpool)) 67 | self.poolsize = min(poolsize, len(self.images)) 68 | self.qidxs = None 69 | self.pidxs = None 70 | self.nidxs = None 71 | 72 | self.poolvecs = None 73 | 74 | self.transform = transform 75 | self.loader = loader 76 | self.pool_clusters_centers = None 77 | self.clustered_pool = [] 78 | self.norm = norm 79 | self.kmeans_ = None 80 | if filename is None: 81 | self.filename = FNAME 82 | else: 83 | self.filename = filename 84 | 85 | self.loaded_imgs = [] 86 | self.random = random 87 | 88 | def __getitem__(self, index): 89 | 
# output = self.loader(self.images[self.qidxs[index]]) 90 | output = self.loaded_imgs[index] 91 | 92 | if self.imsize is not None: 93 | if self.random: 94 | w, h = output.size 95 | imsize = get_random_size(self.imsize, w, h) 96 | output = imresize(output, imsize) 97 | else: 98 | output = imresize(output, self.imsize) 99 | 100 | if self.transform is not None: 101 | output = self.transform(output) 102 | return output 103 | 104 | def __len__(self): 105 | if not self.qidxs: 106 | return 0 107 | return len(self.qidxs) 108 | 109 | def create_epoch_tuples(self, net): 110 | 111 | print(">> Creating tuples...") 112 | if not os.path.exists(self.filename) or not os.path.exists( 113 | self.filename + ".KMeans" 114 | ): 115 | self.cluster(net) 116 | 117 | self.qidxs = self.qpool 118 | self.pidxs = [] 119 | self.nidxs = [] 120 | 121 | # prepare network 122 | net.cuda() 123 | net.eval() 124 | 125 | print(">> Extracting descriptors for pool...") 126 | fname = self.filename 127 | print("") 128 | 129 | print("cluster...") 130 | fname = fname + ".KMeans" 131 | with open(fname, "rb") as f: 132 | p = pickle.load(f) 133 | self.kmeans_ = p["kmeans"] 134 | self.clustered_pool = p["clustered_pool"] 135 | self.pool_clusters_centers = torch.from_numpy(self.kmeans_.cluster_centers_) 136 | print("") 137 | 138 | if not os.path.isfile("data/train_imgs.pkl"): 139 | for idx in self.qidxs: 140 | self.loaded_imgs.append(self.loader(self.images[idx])) 141 | pickle.dump(self.loaded_imgs, open("data/train_imgs.pkl", "wb")) 142 | self.loaded_imgs = pickle.load(open("data/train_imgs.pkl", "rb")) 143 | 144 | def cluster(self, net): 145 | self.pidxs = [] 146 | self.nidxs = [] 147 | 148 | # draw poolsize random images for pool of negatives images 149 | idxs2images = torch.randperm(len(self.images))[: self.poolsize] 150 | 151 | # prepare network 152 | self.net = net 153 | net.cuda() 154 | net.eval() 155 | 156 | Lw = net.meta["Lw"]["retrieval-SfM-120k"]["ss"] 157 | Lw_m = torch.from_numpy(Lw["m"]).cuda().float() 158 | Lw_p = torch.from_numpy(Lw["P"]).cuda().float() 159 | 160 | print(">> Extracting descriptors for pool...") 161 | loader = torch.utils.data.DataLoader( 162 | ImagesFromList( 163 | root="", 164 | images=[self.images[i] for i in idxs2images], 165 | imsize=self.imsize, 166 | transform=self.transform, 167 | random=True, 168 | ), 169 | batch_size=1, 170 | shuffle=False, 171 | num_workers=8, 172 | pin_memory=True, 173 | ) 174 | fname = self.filename 175 | if os.path.exists(fname): 176 | self.poolvecs = torch.load(fname).cuda() 177 | else: 178 | self.poolvecs = torch.Tensor( 179 | net.meta["outputdim"], len(idxs2images) * TIMES 180 | ).cuda() 181 | with torch.no_grad(): 182 | for _ in range(TIMES): 183 | print(_) 184 | for i, input in enumerate(loader): 185 | print( 186 | "\r>>>> {}/{} done...".format(i + 1, len(idxs2images)), 187 | end="", 188 | ) 189 | input = (input - self.norm[0]) / self.norm[1] 190 | b = net(Variable(input.cuda())) 191 | b = do_whiten(b, Lw_m, Lw_p) 192 | self.poolvecs[:, i + _ * len(idxs2images)] = b.squeeze() 193 | torch.save(self.poolvecs.cpu(), fname) 194 | print("") 195 | 196 | print(">> KMeans...") 197 | poolvecs = self.poolvecs.cpu().numpy().T 198 | fname = fname + ".KMeans" 199 | kmeans = KMeans(n_clusters=512, n_jobs=-1) 200 | kmeans.fit(poolvecs) 201 | clustered_pool = [] 202 | self.pool_clusters_centers = torch.from_numpy(kmeans.cluster_centers_) 203 | for i in range(kmeans.cluster_centers_.shape[0]): 204 | clustered_pool.append(poolvecs[kmeans.labels_ == i, :]) 205 | with open(fname, "wb") as 
f: 206 | pickle.dump({"kmeans": kmeans, "clustered_pool": clustered_pool}, f) 207 | -------------------------------------------------------------------------------- /cirtorch/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def download_test(data_dir): 4 | """ 5 | DOWNLOAD_TEST Checks, and, if required, downloads the necessary datasets for the testing. 6 | 7 | download_test(DATA_ROOT) checks if the data necessary for running the example script exist. 8 | If not it downloads it in the folder structure: 9 | DATA_ROOT/test/oxford5k/ : folder with Oxford images and ground truth file 10 | DATA_ROOT/test/paris6k/ : folder with Paris images and ground truth file 11 | DATA_ROOT/test/roxford5k/ : folder with Oxford images and revisited ground truth file 12 | DATA_ROOT/test/rparis6k/ : folder with Paris images and revisited ground truth file 13 | """ 14 | 15 | # Create data folder if it does not exist 16 | if not os.path.isdir(data_dir): 17 | os.mkdir(data_dir) 18 | 19 | # Create datasets folder if it does not exist 20 | datasets_dir = os.path.join(data_dir, 'test') 21 | if not os.path.isdir(datasets_dir): 22 | os.mkdir(datasets_dir) 23 | 24 | # Download datasets folders test/DATASETNAME/ 25 | datasets = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k'] 26 | for di in range(len(datasets)): 27 | dataset = datasets[di] 28 | 29 | if dataset == 'oxford5k': 30 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings' 31 | dl_files = ['oxbuild_images.tgz'] 32 | elif dataset == 'paris6k': 33 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings' 34 | dl_files = ['paris_1.tgz', 'paris_2.tgz'] 35 | elif dataset == 'roxford5k': 36 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings' 37 | dl_files = ['oxbuild_images.tgz'] 38 | elif dataset == 'rparis6k': 39 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings' 40 | dl_files = ['paris_1.tgz', 'paris_2.tgz'] 41 | else: 42 | raise ValueError('Unknown dataset: {}!'.format(dataset)) 43 | 44 | dst_dir = os.path.join(datasets_dir, dataset, 'jpg') 45 | if not os.path.isdir(dst_dir): 46 | 47 | # for oxford and paris download images 48 | if dataset == 'oxford5k' or dataset == 'paris6k': 49 | print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir)) 50 | os.makedirs(dst_dir) 51 | for dli in range(len(dl_files)): 52 | dl_file = dl_files[dli] 53 | src_file = os.path.join(src_dir, dl_file) 54 | dst_file = os.path.join(dst_dir, dl_file) 55 | print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file)) 56 | os.system('wget {} -O {}'.format(src_file, dst_file)) 57 | print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file)) 58 | # create tmp folder 59 | dst_dir_tmp = os.path.join(dst_dir, 'tmp') 60 | os.system('mkdir {}'.format(dst_dir_tmp)) 61 | # extract in tmp folder 62 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp)) 63 | # remove all (possible) subfolders by moving only files in dst_dir 64 | os.system('find {} -type f -exec mv -i {{}} {} \\;'.format(dst_dir_tmp, dst_dir)) 65 | # remove tmp folder 66 | os.system('rm -rf {}'.format(dst_dir_tmp)) 67 | print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file)) 68 | os.system('rm {}'.format(dst_file)) 69 | 70 | # for roxford and rparis just make sym links 71 | elif dataset == 'roxford5k' or dataset == 'rparis6k': 72 | print('>> Dataset {} directory does not exist. 
Creating: {}'.format(dataset, dst_dir)) 73 | dataset_old = dataset[1:] 74 | dst_dir_old = os.path.join(datasets_dir, dataset_old, 'jpg') 75 | os.mkdir(os.path.join(datasets_dir, dataset)) 76 | os.system('ln -s {} {}'.format(dst_dir_old, dst_dir)) 77 | print('>> Created symbolic link from {} jpg to {} jpg'.format(dataset_old, dataset)) 78 | 79 | 80 | gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'test', dataset) 81 | gnd_dst_dir = os.path.join(datasets_dir, dataset) 82 | gnd_dl_file = 'gnd_{}.pkl'.format(dataset) 83 | gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file) 84 | gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file) 85 | if not os.path.exists(gnd_dst_file): 86 | print('>> Downloading dataset {} ground truth file...'.format(dataset)) 87 | os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file)) 88 | 89 | 90 | def download_train(data_dir): 91 | """ 92 | DOWNLOAD_TRAIN Checks, and, if required, downloads the necessary datasets for the training. 93 | 94 | download_train(DATA_ROOT) checks if the data necessary for running the example script exist. 95 | If not it downloads it in the folder structure: 96 | DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db files 97 | DATA_ROOT/train/retrieval-SfM-30k/ : folder with rsfm30k images and db files 98 | """ 99 | 100 | # Create data folder if it does not exist 101 | if not os.path.isdir(data_dir): 102 | os.mkdir(data_dir) 103 | 104 | # Create datasets folder if it does not exist 105 | datasets_dir = os.path.join(data_dir, 'train') 106 | if not os.path.isdir(datasets_dir): 107 | os.mkdir(datasets_dir) 108 | 109 | # Download folder train/retrieval-SfM-120k/ 110 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'ims') 111 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims') 112 | dl_file = 'ims.tar.gz' 113 | if not os.path.isdir(dst_dir): 114 | src_file = os.path.join(src_dir, dl_file) 115 | dst_file = os.path.join(dst_dir, dl_file) 116 | print('>> Image directory does not exist. Creating: {}'.format(dst_dir)) 117 | os.makedirs(dst_dir) 118 | print('>> Downloading ims.tar.gz...') 119 | os.system('wget {} -O {}'.format(src_file, dst_file)) 120 | print('>> Extracting {}...'.format(dst_file)) 121 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir)) 122 | print('>> Extracted, deleting {}...'.format(dst_file)) 123 | os.system('rm {}'.format(dst_file)) 124 | 125 | # Create symlink for train/retrieval-SfM-30k/ 126 | dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims') 127 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims') 128 | if not os.path.isdir(dst_dir): 129 | os.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k')) 130 | os.system('ln -s {} {}'.format(dst_dir_old, dst_dir)) 131 | print('>> Created symbolic link from retrieval-SfM-120k/ims to retrieval-SfM-30k/ims') 132 | 133 | # Download db files 134 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'dbs') 135 | datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k'] 136 | for dataset in datasets: 137 | dst_dir = os.path.join(datasets_dir, dataset) 138 | if dataset == 'retrieval-SfM-120k': 139 | dl_files = ['{}.pkl'.format(dataset), '{}-whiten.pkl'.format(dataset)] 140 | elif dataset == 'retrieval-SfM-30k': 141 | dl_files = ['{}-whiten.pkl'.format(dataset)] 142 | 143 | if not os.path.isdir(dst_dir): 144 | print('>> Dataset directory does not exist. 
Creating: {}'.format(dst_dir)) 145 | os.mkdir(dst_dir) 146 | 147 | for i in range(len(dl_files)): 148 | src_file = os.path.join(src_dir, dl_files[i]) 149 | dst_file = os.path.join(dst_dir, dl_files[i]) 150 | if not os.path.isfile(dst_file): 151 | print('>> DB file {} does not exist. Downloading...'.format(dl_files[i])) 152 | os.system('wget {} -O {}'.format(src_file, dst_file)) 153 | -------------------------------------------------------------------------------- /cirtorch/networks/imageretrievalnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.model_zoo as model_zoo 7 | from torch.autograd import Variable 8 | 9 | import torchvision 10 | 11 | from cirtorch.layers.pooling import MAC, SPoC, GeM, RMAC 12 | from cirtorch.layers.normalization import L2N 13 | from cirtorch.datasets.genericdataset import ImagesFromList 14 | from cirtorch.utils.general import get_data_root 15 | 16 | # for some models, we have imported features (convolutions) from caffe because the image retrieval performance is higher for them 17 | FEATURES = { 18 | 'vgg16' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-vgg16-features-d369c8e.pth', 19 | 'resnet50' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet50-features-ac468af.pth', 20 | 'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet101-features-10a101d.pth', 21 | 'resnet152' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet152-features-1011020.pth', 22 | } 23 | 24 | POOLING = { 25 | 'mac' : MAC, 26 | 'spoc' : SPoC, 27 | 'gem' : GeM, 28 | 'rmac' : RMAC, 29 | } 30 | 31 | WHITENING = { 32 | 'alexnet-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-whiten-454ad53.pth', 33 | 'vgg16-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-whiten-eaa6695.pth', 34 | 'resnet101-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-whiten-22ab0c1.pth', 35 | } 36 | 37 | OUTPUT_DIM = { 38 | 'alexnet' : 256, 39 | 'vgg11' : 512, 40 | 'vgg13' : 512, 41 | 'vgg16' : 512, 42 | 'vgg19' : 512, 43 | 'resnet18' : 512, 44 | 'resnet34' : 512, 45 | 'resnet50' : 2048, 46 | 'resnet101' : 2048, 47 | 'resnet152' : 2048, 48 | 'densenet121' : 1024, 49 | 'densenet161' : 2208, 50 | 'densenet169' : 1664, 51 | 'densenet201' : 1920, 52 | 'squeezenet1_0' : 512, 53 | 'squeezenet1_1' : 512, 54 | } 55 | 56 | 57 | class ImageRetrievalNet(nn.Module): 58 | 59 | def __init__(self, features, pool, whiten, meta): 60 | super(ImageRetrievalNet, self).__init__() 61 | self.features = nn.Sequential(*features) 62 | self.pool = pool 63 | self.whiten = whiten 64 | self.norm = L2N() 65 | self.meta = meta 66 | # self.classifier = nn.Linear(meta['outputdim'], 300) 67 | 68 | 69 | # def classify(self, x): 70 | # o = self.pool(self.features(x)) 71 | # o = o.view(o.size(0), -1) 72 | # o = self.classifier(o) 73 | # return o 74 | 75 | 76 | def forward(self, x): 77 | # features -> pool -> norm 78 | o = self.norm(self.pool(self.features(x))).squeeze(-1).squeeze(-1) 79 | # if whiten exist: whiten -> norm 80 | if self.whiten is not None: 81 | o = self.norm(self.whiten(o)) 82 | # permute so that it is Dx1 column vector per image (DxN if many images) 83 | return 
o.permute(1,0) 84 | 85 | def __repr__(self): 86 | tmpstr = super(ImageRetrievalNet, self).__repr__()[:-1] 87 | tmpstr += self.meta_repr() 88 | tmpstr = tmpstr + ')' 89 | return tmpstr 90 | 91 | def meta_repr(self): 92 | tmpstr = ' (' + 'meta' + '): dict( \n' # + self.meta.__repr__() + '\n' 93 | tmpstr += ' architecture: {}\n'.format(self.meta['architecture']) 94 | tmpstr += ' pooling: {}\n'.format(self.meta['pooling']) 95 | tmpstr += ' whitening: {}\n'.format(self.meta['whitening']) 96 | tmpstr += ' outputdim: {}\n'.format(self.meta['outputdim']) 97 | tmpstr += ' mean: {}\n'.format(self.meta['mean']) 98 | tmpstr += ' std: {}\n'.format(self.meta['std']) 99 | tmpstr = tmpstr + ' )\n' 100 | return tmpstr 101 | 102 | 103 | def init_network(model='resnet101', pooling='gem', whitening=False, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], pretrained=True): 104 | 105 | # loading network from torchvision 106 | if pretrained: 107 | if model not in FEATURES: 108 | # initialize with network pretrained on imagenet in pytorch 109 | net_in = getattr(torchvision.models, model)(pretrained=True) 110 | else: 111 | # initialize with random weights, later on we will fill features with custom pretrained network 112 | net_in = getattr(torchvision.models, model)(pretrained=False) 113 | else: 114 | # initialize with random weights 115 | net_in = getattr(torchvision.models, model)(pretrained=False) 116 | 117 | # initialize features 118 | # take only convolutions for features, 119 | # always ends with ReLU to make last activations non-negative 120 | if model.startswith('alexnet'): 121 | features = list(net_in.features.children())[:-1] 122 | elif model.startswith('vgg'): 123 | features = list(net_in.features.children())[:-1] 124 | elif model.startswith('resnet'): 125 | features = list(net_in.children())[:-2] 126 | elif model.startswith('densenet'): 127 | features = list(net_in.features.children()) 128 | features.append(nn.ReLU(inplace=True)) 129 | elif model.startswith('squeezenet'): 130 | features = list(net_in.features.children()) 131 | else: 132 | raise ValueError('Unsupported or unknown model: {}!'.format(model)) 133 | 134 | # initialize pooling 135 | pool = POOLING[pooling]() 136 | 137 | # get output dimensionality size 138 | dim = OUTPUT_DIM[model] 139 | 140 | # initialize whitening 141 | if whitening: 142 | w = '{}-{}'.format(model, pooling) 143 | whiten = nn.Linear(dim, dim, bias=True) 144 | if w in WHITENING: 145 | print(">> {}: for '{}' custom computed whitening '{}' is used" 146 | .format(os.path.basename(__file__), w, os.path.basename(WHITENING[w]))) 147 | whiten_dir = os.path.join(get_data_root(), 'whiten') 148 | whiten.load_state_dict(model_zoo.load_url(WHITENING[w], model_dir=whiten_dir)) 149 | else: 150 | print(">> {}: for '{}' there is no whitening computed, random weights are used" 151 | .format(os.path.basename(__file__), w)) 152 | else: 153 | whiten = None 154 | 155 | # create meta information to be stored in the network 156 | meta = {'architecture':model, 'pooling':pooling, 'whitening':whitening, 'outputdim':dim, 'mean':mean, 'std':std} 157 | 158 | # create a generic image retrieval network 159 | net = ImageRetrievalNet(features, pool, whiten, meta) 160 | 161 | # initialize features with custom pretrained network if needed 162 | if pretrained and model in FEATURES: 163 | print(">> {}: for '{}' custom pretrained features '{}' are used" 164 | .format(os.path.basename(__file__), model, os.path.basename(FEATURES[model]))) 165 | model_dir = os.path.join(get_data_root(), 'networks') 166 | 
net.features.load_state_dict(model_zoo.load_url(FEATURES[model], model_dir=model_dir)) 167 | 168 | return net 169 | 170 | 171 | def extract_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10): 172 | # moving network to gpu and eval mode 173 | net.cuda() 174 | net.eval() 175 | 176 | # creating dataset loader 177 | loader = torch.utils.data.DataLoader( 178 | ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform), 179 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True 180 | ) 181 | 182 | # extracting vectors 183 | vecs = torch.zeros(net.meta['outputdim'], len(images)) 184 | for i, input in enumerate(loader): 185 | input_var = Variable(input.cuda()) 186 | 187 | if len(ms) == 1: 188 | vecs[:, i] = extract_ss(net, input_var) 189 | else: 190 | vecs[:, i] = extract_ms(net, input_var, ms, msp) 191 | 192 | if (i+1) % print_freq == 0 or (i+1) == len(images): 193 | print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='') 194 | print('') 195 | return vecs 196 | 197 | 198 | def extract_ss(net, input_var): 199 | return net(input_var).cpu().data.squeeze() 200 | 201 | 202 | def extract_ms(net, input_var, ms, msp): 203 | 204 | v = torch.zeros(net.meta['outputdim']) 205 | 206 | for s in ms: 207 | if s == 1: 208 | input_var_t = input_var.clone() 209 | else: 210 | size = (int(input_var.size(-2) * s), int(input_var.size(-1) * s)) 211 | input_var_t = nn.functional.upsample(input_var, size=size, mode='bilinear') 212 | v += net(input_var_t).pow(msp).cpu().data.squeeze() 213 | 214 | v /= len(ms) 215 | v = v.pow(1./msp) 216 | v /= v.norm() 217 | 218 | return v 219 | -------------------------------------------------------------------------------- /cirtorch/datasets/traindataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import torch 5 | import torch.utils.data as data 6 | from torch.autograd import Variable 7 | 8 | from cirtorch.networks.imageretrievalnet import extract_vectors 9 | from cirtorch.datasets.datahelpers import default_loader, imresize, cid2filename 10 | from cirtorch.datasets.genericdataset import ImagesFromList 11 | from cirtorch.utils.general import get_data_root 12 | 13 | class TuplesDataset(data.Dataset): 14 | """Data loader that loads training and validation tuples of 15 | Radenovic etal ECCV16: CNN image retrieval learns from BoW 16 | 17 | Args: 18 | name (string): dataset name: 'retrieval-sfm-120k' 19 | mode (string): 'train' or 'val' for training and validation parts of dataset 20 | imsize (int, Default: None): Defines the maximum size of longer image side 21 | transform (callable, optional): A function/transform that takes in an PIL image 22 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 23 | loader (callable, optional): A function to load an image given its path. 
24 | nnum (int, Default:5): Number of negatives for a query image in a training tuple 25 | qsize (int, Default:1000): Number of query images, ie number of (q,p,n1,...nN) tuples, to be processed in one epoch 26 | poolsize (int, Default:10000): Pool size for negative images re-mining 27 | 28 | Attributes: 29 | images (list): List of full filenames for each image 30 | clusters (list): List of clusterID per image 31 | qpool (list): List of all query image indexes 32 | ppool (list): List of positive image indexes, each corresponding to query at the same position in qpool 33 | 34 | qidxs (list): List of qsize query image indexes to be processed in an epoch 35 | pidxs (list): List of qsize positive image indexes, each corresponding to query at the same position in qidxs 36 | nidxs (list): List of qsize tuples of negative images 37 | Each nidxs tuple contains nnum images corresponding to query image at the same position in qidxs 38 | 39 | Lists qidxs, pidxs, nidxs are refreshed by calling the ``create_epoch_tuples()`` method, 40 | ie new q-p pairs are picked and negative images are remined 41 | """ 42 | 43 | def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, transform=None, loader=default_loader): 44 | 45 | if not (mode == 'train' or mode == 'val'): 46 | raise(RuntimeError("MODE should be either train or val, passed as string")) 47 | 48 | # setting up paths 49 | data_root = get_data_root() 50 | db_root = os.path.join(data_root, 'train', name) 51 | ims_root = os.path.join(db_root, 'ims') 52 | 53 | # loading db 54 | db_fn = os.path.join(db_root, '{}.pkl'.format(name)) 55 | with open(db_fn, 'rb') as f: 56 | db = pickle.load(f)[mode] 57 | 58 | # initializing tuples dataset 59 | self.name = name 60 | self.mode = mode 61 | self.imsize = imsize 62 | self.images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))] 63 | self.clusters = db['cluster'] 64 | self.qpool = db['qidxs'] 65 | self.ppool = db['pidxs'] 66 | 67 | ## If we want to keep only unique q-p pairs 68 | ## However, ordering of pairs will change, although that is not important 69 | # qpidxs = list(set([(self.qidxs[i], self.pidxs[i]) for i in range(len(self.qidxs))])) 70 | # self.qidxs = [qpidxs[i][0] for i in range(len(qpidxs))] 71 | # self.pidxs = [qpidxs[i][1] for i in range(len(qpidxs))] 72 | 73 | # size of training subset for an epoch 74 | self.nnum = nnum 75 | self.qsize = min(qsize, len(self.qpool)) 76 | self.poolsize = min(poolsize, len(self.images)) 77 | self.qidxs = None 78 | self.pidxs = None 79 | self.nidxs = None 80 | 81 | self.transform = transform 82 | self.loader = loader 83 | 84 | def __getitem__(self, index): 85 | """ 86 | Args: 87 | index (int): Index 88 | 89 | Returns: 90 | images tuple (q,p,n1,...,nN): Loaded train/val tuple at index of self.qidxs 91 | """ 92 | if self.__len__() == 0: 93 | raise(RuntimeError("List qidxs is empty. 
Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!")) 94 | 95 | output = [] 96 | # query image 97 | output.append(self.loader(self.images[self.qidxs[index]])) 98 | # positive image 99 | output.append(self.loader(self.images[self.pidxs[index]])) 100 | # negative images 101 | for i in range(len(self.nidxs[index])): 102 | output.append(self.loader(self.images[self.nidxs[index][i]])) 103 | 104 | if self.imsize is not None: 105 | output = [imresize(img, self.imsize) for img in output] 106 | 107 | if self.transform is not None: 108 | output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))] 109 | 110 | target = torch.Tensor([-1, 1] + [0]*len(self.nidxs[index])) 111 | 112 | return output, target 113 | 114 | def __len__(self): 115 | if not self.qidxs: 116 | return 0 117 | return len(self.qidxs) 118 | 119 | def __repr__(self): 120 | fmt_str = self.__class__.__name__ + '\n' 121 | fmt_str += ' Name and mode: {} {}\n'.format(self.name, self.mode) 122 | fmt_str += ' Number of images: {}\n'.format(len(self.images)) 123 | fmt_str += ' Number of training tuples: {}\n'.format(len(self.qpool)) 124 | fmt_str += ' Number of negatives per tuple: {}\n'.format(self.nnum) 125 | fmt_str += ' Number of tuples processed in an epoch: {}\n'.format(self.qsize) 126 | fmt_str += ' Pool size for negative remining: {}\n'.format(self.poolsize) 127 | tmp = ' Transforms (if any): ' 128 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 129 | return fmt_str 130 | 131 | def create_epoch_tuples(self, net): 132 | 133 | print('>> Creating tuples for an epoch of {}-{}...'.format(self.name, self.mode)) 134 | 135 | ## ------------------------ 136 | ## SELECTING POSITIVE PAIRS 137 | ## ------------------------ 138 | 139 | # draw qsize random queries for tuples 140 | idxs2qpool = torch.randperm(len(self.qpool))[:self.qsize] 141 | self.qidxs = [self.qpool[i] for i in idxs2qpool] 142 | self.pidxs = [self.ppool[i] for i in idxs2qpool] 143 | 144 | ## ------------------------ 145 | ## SELECTING NEGATIVE PAIRS 146 | ## ------------------------ 147 | 148 | # if nnum = 0 create dummy nidxs 149 | # useful when only positives used for training 150 | if self.nnum == 0: 151 | self.nidxs = [[] for _ in range(len(self.qidxs))] 152 | return 153 | 154 | # draw poolsize random images for pool of negatives images 155 | idxs2images = torch.randperm(len(self.images))[:self.poolsize] 156 | 157 | # prepare network 158 | net.cuda() 159 | net.eval() 160 | 161 | print('>> Extracting descriptors for query images...') 162 | loader = torch.utils.data.DataLoader( 163 | ImagesFromList(root='', images=[self.images[i] for i in self.qidxs], imsize=self.imsize, transform=self.transform), 164 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True 165 | ) 166 | qvecs = torch.Tensor(net.meta['outputdim'], len(self.qidxs)).cuda() 167 | for i, input in enumerate(loader): 168 | print('\r>>>> {}/{} done...'.format(i+1, len(self.qidxs)), end='') 169 | qvecs[:, i] = net(Variable(input.cuda())).data.squeeze() 170 | print('') 171 | 172 | print('>> Extracting descriptors for negative pool...') 173 | loader = torch.utils.data.DataLoader( 174 | ImagesFromList(root='', images=[self.images[i] for i in idxs2images], imsize=self.imsize, transform=self.transform), 175 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True 176 | ) 177 | poolvecs = torch.Tensor(net.meta['outputdim'], len(idxs2images)).cuda() 178 | for i, input in enumerate(loader): 179 | print('\r>>>> {}/{} 
done...'.format(i+1, len(idxs2images)), end='') 180 | poolvecs[:, i] = net(Variable(input.cuda())).data.squeeze() 181 | print('') 182 | 183 | print('>> Searching for hard negatives...') 184 | scores = torch.mm(poolvecs.t(), qvecs) 185 | scores, ranks = torch.sort(scores, dim=0, descending=True) 186 | self.nidxs = [] 187 | for q in range(len(self.qidxs)): 188 | qcluster = self.clusters[self.qidxs[q]] 189 | clusters = [qcluster] 190 | nidxs = [] 191 | r = 0 192 | avg_ndist = torch.Tensor([0]).cuda() 193 | n_ndist = torch.Tensor([0]).cuda() 194 | while len(nidxs) < self.nnum: 195 | potential = idxs2images[ranks[r, q]] 196 | # take at most one image from the same cluster 197 | if not self.clusters[potential] in clusters: 198 | nidxs.append(potential) 199 | clusters.append(self.clusters[potential]) 200 | avg_ndist += torch.pow(qvecs[:,q]-poolvecs[:,ranks[r, q]]+1e-6, 2).sum(dim=0).sqrt() 201 | n_ndist += 1 202 | r += 1 203 | self.nidxs.append(nidxs) 204 | print('>>>> Average negative distance: {:.2f}'.format(list((avg_ndist/n_ndist).cpu())[0])) 205 | print('>>>> Done') 206 | -------------------------------------------------------------------------------- /cirtorch/examples/attack/extract_rank.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import math 6 | import pickle 7 | import pdb 8 | import random 9 | 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim 15 | import torch.utils.data 16 | import torchvision.transforms as transforms 17 | import torchvision.models as models 18 | import torchvision.datasets as datasets 19 | 20 | from cirtorch.networks.imageretrievalnet import init_network, extract_vectors 21 | from cirtorch.layers.loss import ContrastiveLoss 22 | from cirtorch.datasets.datahelpers import collate_tuples, cid2filename 23 | from cirtorch.datasets.traindataset import TuplesDataset 24 | from cirtorch.datasets.testdataset import configdataset 25 | from cirtorch.utils.download import download_train, download_test 26 | from cirtorch.utils.whiten import whitenlearn, whitenapply 27 | from cirtorch.utils.evaluate import compute_map_and_print 28 | from cirtorch.utils.general import get_data_root, htime 29 | from cirtorch.examples.attack.myutil.utils import do_whiten 30 | from cirtorch.examples.attack.myutil.triplet_dataset import MyTripletDataset 31 | from cirtorch.datasets.genericdataset import ImagesFromList 32 | 33 | training_dataset_names = ["retrieval-SfM-120k", "Landmarks"] 34 | test_datasets_names = [ 35 | "oxford5k,paris6k", 36 | "roxford5k,rparis6k", 37 | "oxford5k,paris6k,roxford5k,rparis6k", 38 | ] 39 | test_whiten_names = ["retrieval-SfM-30k", "retrieval-SfM-120k"] 40 | 41 | model_names = sorted( 42 | name 43 | for name in models.__dict__ 44 | if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) 45 | ) 46 | pool_names = ["mac", "spoc", "gem", "rmac"] 47 | 48 | parser = argparse.ArgumentParser(description="PyTorch CNN Image Retrieval Training") 49 | 50 | parser.add_argument( 51 | "--network-path", help="network path, destination where network is saved" 52 | ) 53 | parser.add_argument( 54 | "--image-size", 55 | default=362, 56 | type=int, 57 | metavar="N", 58 | help="maximum size of longer image side used for training (default: 1024)", 59 | ) 60 | 61 | # standard train/val options 62 | parser.add_argument( 63 | "--gpu-id", 64 | "-g", 65 | default="0", 66 | metavar="N", 67 | help="gpu id used for 
training (default: 0)", 68 | ) 69 | parser.add_argument( 70 | "--workers", 71 | "-j", 72 | default=8, 73 | type=int, 74 | metavar="N", 75 | help="number of data loading workers (default: 8)", 76 | ) 77 | parser.add_argument( 78 | "--epochs", 79 | default=100, 80 | type=int, 81 | metavar="N", 82 | help="number of total epochs to run (default: 100)", 83 | ) 84 | min_loss = float("inf") 85 | 86 | 87 | def main(): 88 | global args, min_loss 89 | args = parser.parse_args() 90 | 91 | # set cuda visible device 92 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 93 | 94 | # check if test dataset are downloaded 95 | # and download if they are not 96 | download_train(get_data_root()) 97 | download_test(get_data_root()) 98 | 99 | # set random seeds (maybe pass as argument) 100 | torch.manual_seed(1234) 101 | torch.cuda.manual_seed_all(1234) 102 | np.random.seed(1234) 103 | random.seed(1234) 104 | torch.backends.cudnn.deterministic = True 105 | 106 | state = torch.load(args.network_path) 107 | model = init_network( 108 | model=state["meta"]["architecture"], 109 | pooling=state["meta"]["pooling"], 110 | whitening=state["meta"]["whitening"], 111 | mean=state["meta"]["mean"], 112 | std=state["meta"]["std"], 113 | pretrained=False, 114 | ) 115 | model.load_state_dict(state["state_dict"]) 116 | model.meta["Lw"] = state["meta"]["Lw"] 117 | model.cuda() 118 | 119 | # whitening 120 | Lw = model.meta["Lw"]["retrieval-SfM-120k"]["ss"] 121 | Lw_m = torch.from_numpy(Lw["m"]).cuda().float() 122 | Lw_p = torch.from_numpy(Lw["P"]).cuda().float() 123 | 124 | # Data loading code 125 | normalize = transforms.Normalize(mean=model.meta["mean"], std=model.meta["std"]) 126 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 127 | query_dataset = MyTripletDataset( 128 | # imsize=args.image_size, 129 | imsize=(362, 362), 130 | random=False, 131 | transform=transform, 132 | norm=(0, 0), 133 | filename="base/" + args.network_path.replace("/", "_") + "_triplet", 134 | ) 135 | # val_dataset.test_cluster(model) 136 | # return 137 | query_dataset.create_epoch_tuples(model) 138 | query_loader = torch.utils.data.DataLoader( 139 | query_dataset, 140 | batch_size=1, 141 | shuffle=False, 142 | num_workers=args.workers, 143 | pin_memory=True, 144 | worker_init_fn=lambda _: random.seed(1234), 145 | ) 146 | 147 | base_dataset = ImagesFromList( 148 | root="", 149 | images=query_dataset.images, 150 | # imsize=query_dataset.imsize, 151 | imsize=(362, 362), 152 | transform=query_dataset.transform, 153 | random=False, 154 | ) 155 | base_loader = torch.utils.data.DataLoader(base_dataset, batch_size=1, shuffle=False) 156 | 157 | # test(["oxford5k"], model) 158 | extract(query_loader, base_loader, model, Lw_m, Lw_p) 159 | 160 | 161 | def extract(query_loader, base_loader, model, Lw_m, Lw_p): 162 | # create tuples for validation 163 | query_loader.dataset.create_epoch_tuples(model) 164 | 165 | # switch to evaluate mode 166 | model.eval() 167 | 168 | nq = len(query_loader.dataset) 169 | nb = len(base_loader.dataset) 170 | base_features = torch.Tensor(model.meta["outputdim"], nb).cuda() 171 | query_features = torch.Tensor(model.meta["outputdim"], nq).cuda() 172 | ranks = torch.Tensor(nq, nb) 173 | network_path = args.network_path.replace("/", "_") 174 | os.makedirs(f"ranks/{network_path}", exist_ok=True) 175 | with torch.no_grad(): 176 | print(">>> base") 177 | for i, input in enumerate(base_loader): 178 | feature = model(input.cuda()) 179 | feature = do_whiten(feature, Lw_m, Lw_p) 180 | base_features[:, i] = feature.squeeze() 181 
| print(">>> base over") 182 | torch.save(base_features, f"ranks/{network_path}/base_362") 183 | 184 | print(">>> query") 185 | for i, input in enumerate(query_loader): 186 | feature = model(input.cuda()) 187 | feature = do_whiten(feature, Lw_m, Lw_p) 188 | query_features[:, i] = feature.squeeze() 189 | score = base_features.t() @ feature 190 | _, rank = torch.sort(score, dim=0, descending=True) 191 | ranks[i, :] = rank.squeeze() 192 | torch.save(query_features, f"ranks/{network_path}/query_362") 193 | print(">>> query over") 194 | torch.save(ranks, f"ranks/{network_path}/ranks_362") 195 | 196 | 197 | def test(datasets, net): 198 | print(">> Evaluating network on test datasets...") 199 | image_size = 1024 200 | 201 | # moving network to gpu and eval mode 202 | net.cuda() 203 | net.eval() 204 | # set up the transform 205 | normalize = transforms.Normalize(mean=net.meta["mean"], std=net.meta["std"]) 206 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 207 | 208 | # compute whitening 209 | # Lw = None 210 | Lw = net.meta["Lw"]["retrieval-SfM-120k"]["ss"] 211 | 212 | # evaluate on test datasets 213 | # datasets = args.test_datasets.split(",") 214 | for dataset in datasets: 215 | start = time.time() 216 | 217 | print(">> {}: Extracting...".format(dataset)) 218 | 219 | # prepare config structure for the test dataset 220 | cfg = configdataset(dataset, os.path.join(get_data_root(), "test")) 221 | images = [cfg["im_fname"](cfg, i) for i in range(cfg["n"])] 222 | qimages = [cfg["qim_fname"](cfg, i) for i in range(cfg["nq"])] 223 | bbxs = [tuple(cfg["gnd"][i]["bbx"]) for i in range(cfg["nq"])] 224 | 225 | # extract database and query vectors 226 | print(">> {}: database images...".format(dataset)) 227 | vecs = extract_vectors(net, images, image_size, transform) 228 | print(">> {}: query images...".format(dataset)) 229 | qvecs = extract_vectors(net, qimages, image_size, transform, bbxs) 230 | 231 | print(">> {}: Evaluating...".format(dataset)) 232 | 233 | # convert to numpy 234 | vecs = vecs.numpy() 235 | qvecs = qvecs.numpy() 236 | 237 | # search, rank, and print 238 | scores = np.dot(vecs.T, qvecs) 239 | ranks = np.argsort(-scores, axis=0) 240 | compute_map_and_print(dataset, ranks, cfg["gnd"]) 241 | 242 | if Lw is not None: 243 | # whiten the vectors 244 | vecs_lw = whitenapply(vecs, Lw["m"], Lw["P"]) 245 | qvecs_lw = whitenapply(qvecs, Lw["m"], Lw["P"]) 246 | 247 | # search, rank, and print 248 | scores = np.dot(vecs_lw.T, qvecs_lw) 249 | ranks = np.argsort(-scores, axis=0) 250 | compute_map_and_print(dataset + " + whiten", ranks, cfg["gnd"]) 251 | 252 | print(">> {}: elapsed time: {}".format(dataset, htime(time.time() - start))) 253 | 254 | 255 | def save_checkpoint(state, is_best, directory): 256 | filename = os.path.join(directory, "model_epoch%d.pth.tar" % state["epoch"]) 257 | torch.save(state, filename) 258 | if is_best: 259 | filename_best = os.path.join(directory, "model_best.pth.tar") 260 | shutil.copyfile(filename, filename_best) 261 | 262 | 263 | class AverageMeter(object): 264 | """Computes and stores the average and current value""" 265 | 266 | def __init__(self): 267 | self.reset() 268 | 269 | def reset(self): 270 | self.val = 0 271 | self.avg = 0 272 | self.sum = 0 273 | self.count = 0 274 | 275 | def update(self, val, n=1): 276 | self.val = val 277 | self.sum += val * n 278 | self.count += n 279 | self.avg = self.sum / self.count 280 | 281 | 282 | if __name__ == "__main__": 283 | main() 284 | 
-------------------------------------------------------------------------------- /cirtorch/examples/attack/myutil/mi_sgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer, required 3 | from cirtorch.examples.attack.myutil.utils import rescale_check 4 | import math 5 | 6 | 7 | CHECK = 1e-5 8 | # CHECK = 1e-3 9 | SAT_MIN = 0.5 10 | 11 | 12 | class MI_SGD(Optimizer): 13 | r"""Implements stochastic gradient descent (optionally with momentum). 14 | 15 | Nesterov momentum is based on the formula from 16 | `On the importance of initialization and momentum in deep learning`__. 17 | 18 | Args: 19 | params (iterable): iterable of parameters to optimize or dicts defining 20 | parameter groups 21 | lr (float): learning rate 22 | momentum (float, optional): momentum factor (default: 0) 23 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 24 | dampening (float, optional): dampening for momentum (default: 0) 25 | nesterov (bool, optional): enables Nesterov momentum (default: False) 26 | 27 | Example: 28 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 29 | >>> optimizer.zero_grad() 30 | >>> loss_fn(model(input), target).backward() 31 | >>> optimizer.step() 32 | 33 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 34 | 35 | .. note:: 36 | The implementation of SGD with Momentum/Nesterov subtly differs from 37 | Sutskever et. al. and implementations in some other frameworks. 38 | 39 | Considering the specific case of Momentum, the update can be written as 40 | 41 | .. math:: 42 | v = \rho * v + g \\ 43 | p = p - lr * v 44 | 45 | where p, g, v and :math:`\rho` denote the parameters, gradient, 46 | velocity, and momentum respectively. 47 | 48 | This is in contrast to Sutskever et. al. and 49 | other frameworks which employ an update of the form 50 | 51 | .. math:: 52 | v = \rho * v + lr * g \\ 53 | p = p - v 54 | 55 | The Nesterov version is analogously modified. 
56 | """ 57 | 58 | def __init__( 59 | self, 60 | params, 61 | lr=required, 62 | momentum=0, 63 | dampening=0, 64 | weight_decay=0, 65 | nesterov=False, 66 | max_eps=10 / 255, 67 | ): 68 | if lr is not required and lr < 0.0: 69 | raise ValueError("Invalid learning rate: {}".format(lr)) 70 | if momentum < 0.0: 71 | raise ValueError("Invalid momentum value: {}".format(momentum)) 72 | if weight_decay < 0.0: 73 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 74 | 75 | defaults = dict( 76 | lr=lr, 77 | momentum=momentum, 78 | dampening=dampening, 79 | weight_decay=weight_decay, 80 | nesterov=nesterov, 81 | sign=False, 82 | ) 83 | if nesterov and (momentum <= 0 or dampening != 0): 84 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 85 | super(MI_SGD, self).__init__(params, defaults) 86 | self.sat = 0 87 | self.sat_prev = 0 88 | self.max_eps = max_eps 89 | 90 | def __setstate__(self, state): 91 | super(MI_SGD, self).__setstate__(state) 92 | for group in self.param_groups: 93 | group.setdefault("nesterov", False) 94 | 95 | def rescale(self,): 96 | for group in self.param_groups: 97 | if not group["sign"]: 98 | continue 99 | for p in group["params"]: 100 | self.sat_prev = self.sat 101 | self.sat = (p.data.abs() >= self.max_eps).sum().item() / p.data.numel() 102 | sat_change = abs(self.sat - self.sat_prev) 103 | if rescale_check(CHECK, self.sat, sat_change, SAT_MIN): 104 | print('rescaled') 105 | p.data = p.data / 2 106 | 107 | def step(self, closure=None): 108 | """Performs a single optimization step. 109 | 110 | Arguments: 111 | closure (callable, optional): A closure that reevaluates the model 112 | and returns the loss. 113 | """ 114 | loss = None 115 | if closure is not None: 116 | loss = closure() 117 | 118 | for group in self.param_groups: 119 | weight_decay = group["weight_decay"] 120 | momentum = group["momentum"] 121 | dampening = group["dampening"] 122 | nesterov = group["nesterov"] 123 | 124 | for p in group["params"]: 125 | if p.grad is None: 126 | continue 127 | d_p = p.grad.data 128 | if group["sign"]: 129 | d_p = d_p / (d_p.norm(1) + 1e-12) 130 | if weight_decay != 0: 131 | d_p.add_(weight_decay, p.data) 132 | if momentum != 0: 133 | param_state = self.state[p] 134 | if "momentum_buffer" not in param_state: 135 | buf = param_state["momentum_buffer"] = torch.zeros_like(p.data) 136 | buf.mul_(momentum).add_(d_p) 137 | else: 138 | buf = param_state["momentum_buffer"] 139 | buf.mul_(momentum).add_(1 - dampening, d_p) 140 | if nesterov: 141 | d_p = d_p.add(momentum, buf) 142 | else: 143 | d_p = buf 144 | 145 | if group["sign"]: 146 | p.data.add_(-group["lr"], d_p.sign()) 147 | p.data = torch.clamp(p.data, -self.max_eps, self.max_eps) 148 | else: 149 | p.data.add_(-group["lr"], d_p) 150 | 151 | return loss 152 | 153 | 154 | class SIGN_AdaBound(Optimizer): 155 | """Implements AdaBound algorithm. 156 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_. 
157 | Arguments: 158 | params (iterable): iterable of parameters to optimize or dicts defining 159 | parameter groups 160 | lr (float, optional): Adam learning rate (default: 1e-3) 161 | betas (Tuple[float, float], optional): coefficients used for computing 162 | running averages of gradient and its square (default: (0.9, 0.999)) 163 | final_lr (float, optional): final (SGD) learning rate (default: 0.1) 164 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3) 165 | eps (float, optional): term added to the denominator to improve 166 | numerical stability (default: 1e-8) 167 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 168 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm 169 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate: 170 | https://openreview.net/forum?id=Bkg3g2R9FX 171 | """ 172 | 173 | def __init__( 174 | self, 175 | params, 176 | lr=1e-3, 177 | betas=(0.9, 0.999), 178 | final_lr=0.1, 179 | gamma=1e-3, 180 | eps=1e-8, 181 | weight_decay=0, 182 | amsbound=False, 183 | max_eps=10 / 255, 184 | ): 185 | if not 0.0 <= lr: 186 | raise ValueError("Invalid learning rate: {}".format(lr)) 187 | if not 0.0 <= eps: 188 | raise ValueError("Invalid epsilon value: {}".format(eps)) 189 | if not 0.0 <= betas[0] < 1.0: 190 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 191 | if not 0.0 <= betas[1] < 1.0: 192 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 193 | if not 0.0 <= final_lr: 194 | raise ValueError("Invalid final learning rate: {}".format(final_lr)) 195 | if not 0.0 <= gamma < 1.0: 196 | raise ValueError("Invalid gamma parameter: {}".format(gamma)) 197 | defaults = dict( 198 | lr=lr, 199 | betas=betas, 200 | final_lr=final_lr, 201 | gamma=gamma, 202 | eps=eps, 203 | weight_decay=weight_decay, 204 | amsbound=amsbound, 205 | ) 206 | super(SIGN_AdaBound, self).__init__(params, defaults) 207 | 208 | self.base_lrs = list(map(lambda group: group["lr"], self.param_groups)) 209 | self.max_eps = max_eps 210 | self.sat = 0 211 | self.sat_prev = 0 212 | 213 | def rescale(self,): 214 | for group in self.param_groups: 215 | if not group["sign"]: 216 | continue 217 | for p in group["params"]: 218 | self.sat_prev = self.sat 219 | self.sat = (p.data.abs() == self.max_eps).sum().item() / p.data.numel() 220 | sat_change = abs(self.sat - self.sat_prev) 221 | if rescale_check(CHECK, self.sat, sat_change, SAT_MIN): 222 | p.data = p.data / 2 223 | 224 | def __setstate__(self, state): 225 | super(SIGN_AdaBound, self).__setstate__(state) 226 | for group in self.param_groups: 227 | group.setdefault("amsbound", False) 228 | 229 | def step(self, closure=None): 230 | """Performs a single optimization step. 231 | Arguments: 232 | closure (callable, optional): A closure that reevaluates the model 233 | and returns the loss. 
234 | """ 235 | loss = None 236 | if closure is not None: 237 | loss = closure() 238 | 239 | for group, base_lr in zip(self.param_groups, self.base_lrs): 240 | for p in group["params"]: 241 | if p.grad is None: 242 | continue 243 | grad = p.grad.data 244 | # if group["sign"]: 245 | # grad = grad / (grad.norm(1) + 1e-12) 246 | if grad.is_sparse: 247 | raise RuntimeError( 248 | "Adam does not support sparse gradients, please consider SparseAdam instead" 249 | ) 250 | amsbound = group["amsbound"] 251 | 252 | state = self.state[p] 253 | 254 | # State initialization 255 | if len(state) == 0: 256 | state["step"] = 0 257 | # Exponential moving average of gradient values 258 | state["exp_avg"] = torch.zeros_like(p.data) 259 | # Exponential moving average of squared gradient values 260 | state["exp_avg_sq"] = torch.zeros_like(p.data) 261 | if amsbound: 262 | # Maintains max of all exp. moving avg. of sq. grad. values 263 | state["max_exp_avg_sq"] = torch.zeros_like(p.data) 264 | 265 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 266 | if amsbound: 267 | max_exp_avg_sq = state["max_exp_avg_sq"] 268 | beta1, beta2 = group["betas"] 269 | 270 | state["step"] += 1 271 | 272 | if group["weight_decay"] != 0: 273 | grad = grad.add(group["weight_decay"], p.data) 274 | 275 | # Decay the first and second moment running average coefficient 276 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 277 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 278 | if amsbound: 279 | # Maintains the maximum of all 2nd moment running avg. till now 280 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 281 | # Use the max. for normalizing running avg. of gradient 282 | denom = max_exp_avg_sq.sqrt().add_(group["eps"]) 283 | else: 284 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 285 | 286 | bias_correction1 = 1 - beta1 ** state["step"] 287 | bias_correction2 = 1 - beta2 ** state["step"] 288 | step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 289 | 290 | # Applies bounds on actual learning rate 291 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay 292 | final_lr = group["final_lr"] * group["lr"] / base_lr 293 | lower_bound = final_lr * (1 - 1 / (group["gamma"] * state["step"] + 1)) 294 | upper_bound = final_lr * (1 + 1 / (group["gamma"] * state["step"])) 295 | step_size = torch.full_like(denom, step_size) 296 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg) 297 | 298 | p.data.add_(-step_size) 299 | __import__("pdb").set_trace() 300 | # if group["sign"]: 301 | # p.data = torch.clamp(p.data, -self.max_eps, self.max_eps) 302 | 303 | return loss 304 | 305 | 306 | class SIGN_Adam(Optimizer): 307 | """Implements Adam algorithm. 308 | 309 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 310 | 311 | Arguments: 312 | params (iterable): iterable of parameters to optimize or dicts defining 313 | parameter groups 314 | lr (float, optional): learning rate (default: 1e-3) 315 | betas (Tuple[float, float], optional): coefficients used for computing 316 | running averages of gradient and its square (default: (0.9, 0.999)) 317 | eps (float, optional): term added to the denominator to improve 318 | numerical stability (default: 1e-8) 319 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 320 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 321 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 322 | 323 | .. 
_Adam\: A Method for Stochastic Optimization: 324 | https://arxiv.org/abs/1412.6980 325 | .. _On the Convergence of Adam and Beyond: 326 | https://openreview.net/forum?id=ryQu7f-RZ 327 | """ 328 | 329 | def __init__( 330 | self, 331 | params, 332 | lr=1e-3, 333 | betas=(0.9, 0.999), 334 | eps=1e-8, 335 | weight_decay=0, 336 | amsgrad=False, 337 | max_eps=10 / 255, 338 | ): 339 | if not 0.0 <= lr: 340 | raise ValueError("Invalid learning rate: {}".format(lr)) 341 | if not 0.0 <= eps: 342 | raise ValueError("Invalid epsilon value: {}".format(eps)) 343 | if not 0.0 <= betas[0] < 1.0: 344 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 345 | if not 0.0 <= betas[1] < 1.0: 346 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 347 | defaults = dict( 348 | lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad 349 | ) 350 | super(SIGN_Adam, self).__init__(params, defaults) 351 | self.max_eps = max_eps 352 | self.sat = 0 353 | self.sat_prev = 0 354 | 355 | def __setstate__(self, state): 356 | super(SIGN_Adam, self).__setstate__(state) 357 | for group in self.param_groups: 358 | group.setdefault("amsgrad", False) 359 | 360 | def step(self, closure=None): 361 | """Performs a single optimization step. 362 | 363 | Arguments: 364 | closure (callable, optional): A closure that reevaluates the model 365 | and returns the loss. 366 | """ 367 | loss = None 368 | if closure is not None: 369 | loss = closure() 370 | 371 | for group in self.param_groups: 372 | for p in group["params"]: 373 | if p.grad is None: 374 | continue 375 | grad = p.grad.data 376 | # if group["sign"]: 377 | # grad = grad / (grad.norm(1) + 1e-12) 378 | # grad = grad / (grad.norm(1) + 1e-12) 379 | if grad.is_sparse: 380 | raise RuntimeError( 381 | "Adam does not support sparse gradients, please consider SparseAdam instead" 382 | ) 383 | amsgrad = group["amsgrad"] 384 | 385 | state = self.state[p] 386 | 387 | # State initialization 388 | if len(state) == 0: 389 | state["step"] = 0 390 | # Exponential moving average of gradient values 391 | state["exp_avg"] = torch.zeros_like(p.data) 392 | # Exponential moving average of squared gradient values 393 | state["exp_avg_sq"] = torch.zeros_like(p.data) 394 | if amsgrad: 395 | # Maintains max of all exp. moving avg. of sq. grad. values 396 | state["max_exp_avg_sq"] = torch.zeros_like(p.data) 397 | 398 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 399 | if amsgrad: 400 | max_exp_avg_sq = state["max_exp_avg_sq"] 401 | beta1, beta2 = group["betas"] 402 | 403 | state["step"] += 1 404 | 405 | if group["weight_decay"] != 0: 406 | grad = grad.add(group["weight_decay"], p.data) 407 | 408 | # Decay the first and second moment running average coefficient 409 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 410 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 411 | if amsgrad: 412 | # Maintains the maximum of all 2nd moment running avg. till now 413 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 414 | # Use the max. for normalizing running avg. 
of gradient 415 | denom = max_exp_avg_sq.sqrt().add_(group["eps"]) 416 | else: 417 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 418 | 419 | bias_correction1 = 1 - beta1 ** state["step"] 420 | bias_correction2 = 1 - beta2 ** state["step"] 421 | step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 422 | 423 | # p.data.addcdiv_(-step_size, exp_avg, denom.sign()) 424 | p.data.addcdiv_(-step_size, exp_avg, denom) 425 | if group["sign"]: 426 | p.data = torch.clamp(p.data, -self.max_eps, self.max_eps) 427 | 428 | return loss 429 | 430 | def rescale(self,): 431 | for group in self.param_groups: 432 | if not group["sign"]: 433 | continue 434 | for p in group["params"]: 435 | self.sat_prev = self.sat 436 | self.sat = (p.data.abs() == self.max_eps).sum().item() / p.data.numel() 437 | sat_change = abs(self.sat - self.sat_prev) 438 | if rescale_check(CHECK, self.sat, sat_change, SAT_MIN): 439 | p.data = p.data / 2 440 | -------------------------------------------------------------------------------- /cirtorch/examples/attack/distillation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import math 6 | import pickle 7 | import pdb 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.multiprocessing 13 | 14 | # torch.multiprocessing.set_sharing_strategy("file_system") 15 | import torch.nn as nn 16 | import torch.optim 17 | import torch.utils.data 18 | import torchvision.transforms as transforms 19 | import torchvision.models as models 20 | import torchvision.datasets as datasets 21 | 22 | from cirtorch.networks.imageretrievalnet import init_network, extract_vectors 23 | from cirtorch.layers.loss import ContrastiveLoss 24 | from cirtorch.datasets.datahelpers import collate_tuples, cid2filename 25 | from cirtorch.datasets.traindataset import TuplesDataset 26 | from cirtorch.datasets.testdataset import configdataset 27 | from cirtorch.utils.download import download_train, download_test 28 | from cirtorch.utils.whiten import whitenlearn, whitenapply 29 | from cirtorch.utils.evaluate import compute_map_and_print 30 | from cirtorch.utils.general import get_data_root, htime 31 | from cirtorch.examples.attack.myutil.distillation_dataset import Distillation_dataset 32 | from cirtorch.examples.attack.myutil.utils import bcolors 33 | from cirtorch.examples.attack.myutil.utils import do_whiten 34 | from cirtorch.layers.normalization import L2N 35 | 36 | f = os.path.realpath(__file__) 37 | f = open(f, "r") 38 | print("".join(f.readlines())) 39 | f.close() 40 | 41 | training_dataset_names = ["retrieval-SfM-120k", "Landmarks"] 42 | test_datasets_names = [ 43 | "oxford5k,paris6k", 44 | "roxford5k,rparis6k", 45 | "oxford5k,paris6k,roxford5k,rparis6k", 46 | ] 47 | test_whiten_names = ["retrieval-SfM-30k", "retrieval-SfM-120k"] 48 | 49 | model_names = sorted( 50 | name 51 | for name in models.__dict__ 52 | if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) 53 | ) 54 | pool_names = ["mac", "spoc", "gem", "rmac"] 55 | loss_names = ["contrastive", "cross_entropy"] 56 | optimizer_names = ["sgd", "adam"] 57 | 58 | parser = argparse.ArgumentParser(description="PyTorch CNN Image Retrieval Training") 59 | 60 | # export directory, training and val datasets, test datasets 61 | parser.add_argument( 62 | "directory", metavar="DIR", help="destination where trained network should be saved" 63 | ) 64 | parser.add_argument( 65 | "--target", help="destination where 
trained network should be saved" 66 | ) 67 | parser.add_argument( 68 | "--test-datasets", 69 | "-td", 70 | metavar="DATASETS", 71 | default="oxford5k", 72 | choices=test_datasets_names, 73 | help="comma separated list of test datasets: " 74 | + " | ".join(test_datasets_names) 75 | + " (default: oxford5k,paris6k)", 76 | ) 77 | parser.add_argument( 78 | "--test-whiten", 79 | metavar="DATASET", 80 | default="", 81 | choices=test_whiten_names, 82 | help="dataset used to learn whitening for testing: " 83 | + " | ".join(test_whiten_names) 84 | + " (default: None)", 85 | ) 86 | 87 | # network architecture and initialization options 88 | parser.add_argument( 89 | "--arch", 90 | "-a", 91 | metavar="ARCH", 92 | default="resnet101", 93 | choices=model_names, 94 | help="model architecture: " + " | ".join(model_names) + " (default: resnet101)", 95 | ) 96 | parser.add_argument( 97 | "--pool", 98 | "-p", 99 | metavar="POOL", 100 | default="gem", 101 | choices=pool_names, 102 | help="pooling options: " + " | ".join(pool_names) + " (default: gem)", 103 | ) 104 | parser.add_argument( 105 | "--whitening", 106 | "-w", 107 | dest="whitening", 108 | action="store_true", 109 | help="train model with end-to-end whitening", 110 | ) 111 | parser.add_argument( 112 | "--not-pretrained", 113 | dest="pretrained", 114 | action="store_false", 115 | help="use model with random weights (default: pretrained on imagenet)", 116 | ) 117 | parser.add_argument( 118 | "--loss", 119 | "-l", 120 | metavar="LOSS", 121 | default="contrastive", 122 | choices=loss_names, 123 | help="training loss options: " + " | ".join(loss_names) + " (default: contrastive)", 124 | ) 125 | parser.add_argument( 126 | "--loss-margin", 127 | "-lm", 128 | metavar="LM", 129 | default=0.7, 130 | type=float, 131 | help="loss margin: (default: 0.7)", 132 | ) 133 | 134 | # train/val options specific for image retrieval learning 135 | parser.add_argument( 136 | "--image-size", 137 | default=362, 138 | type=int, 139 | metavar="N", 140 | help="maximum size of longer image side used for training (default: 1024)", 141 | ) 142 | parser.add_argument( 143 | "--query-size", 144 | "-qs", 145 | default=2000, 146 | type=int, 147 | metavar="N", 148 | help="number of queries randomly drawn per one train epoch (default: 2000)", 149 | ) 150 | parser.add_argument( 151 | "--pool-size", 152 | "-ps", 153 | default=20000, 154 | type=int, 155 | metavar="N", 156 | help="size of the pool for hard negative mining (default: 20000)", 157 | ) 158 | 159 | # standard train/val options 160 | parser.add_argument( 161 | "--gpu-id", 162 | "-g", 163 | default="0", 164 | metavar="N", 165 | help="gpu id used for training (default: 0)", 166 | ) 167 | parser.add_argument( 168 | "--workers", 169 | "-j", 170 | default=8, 171 | type=int, 172 | metavar="N", 173 | help="number of data loading workers (default: 8)", 174 | ) 175 | parser.add_argument( 176 | "--epochs", 177 | default=500, 178 | type=int, 179 | metavar="N", 180 | help="number of total epochs to run (default: 100)", 181 | ) 182 | parser.add_argument( 183 | "--batch-size", 184 | "-b", 185 | default=10, 186 | type=int, 187 | metavar="N", 188 | help="number of (q,p,n1,...,nN) tuples in a mini-batch (default: 5)", 189 | ) 190 | parser.add_argument( 191 | "--optimizer", 192 | "-o", 193 | metavar="OPTIMIZER", 194 | default="adam", 195 | choices=optimizer_names, 196 | help="optimizer options: " + " | ".join(optimizer_names) + " (default: adam)", 197 | ) 198 | parser.add_argument( 199 | "--lr", 200 | "--learning-rate", 201 | default=5e-7, 202 | 
type=float, 203 | metavar="LR", 204 | help="initial learning rate (default: 1e-6)", 205 | ) 206 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") 207 | parser.add_argument( 208 | "--weight-decay", 209 | "--wd", 210 | default=1e-4, 211 | type=float, 212 | metavar="W", 213 | help="weight decay (default: 1e-4)", 214 | ) 215 | parser.add_argument( 216 | "--print-freq", 217 | default=50, 218 | type=int, 219 | metavar="N", 220 | help="print frequency (default: 10)", 221 | ) 222 | parser.add_argument( 223 | "--is_pretrained", dest="is_pretrained", action="store_true", help="is_pretrained" 224 | ) 225 | parser.add_argument( 226 | "--is_random", dest="is_random", action="store_true", help="is_random" 227 | ) 228 | parser.add_argument("--notion", help="notion") 229 | parser.add_argument("--q_percent", default=1, type=float) 230 | 231 | min_loss = float("inf") 232 | 233 | 234 | class Whiten_layer(nn.Module): 235 | def __init__(self, d_in, d_out): 236 | super(Whiten_layer, self).__init__() 237 | self.w = nn.Linear(d_in, d_out) 238 | self.norm = L2N() 239 | 240 | def forward(self, x): 241 | return self.norm(self.w(x)) 242 | 243 | 244 | def main(): 245 | global args, min_loss 246 | args = parser.parse_args() 247 | print(args) 248 | 249 | # set cuda visible device 250 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 251 | 252 | # check if test dataset are downloaded 253 | # and download if they are not 254 | download_train(get_data_root()) 255 | download_test(get_data_root()) 256 | 257 | # create export dir if it doesnt exist 258 | directory = "{}".format(args.target.replace("/", "_")) 259 | directory += "_{}".format(args.arch) 260 | directory += "_{}".format(args.pool) 261 | if args.whitening: 262 | directory += "_whiten" 263 | if not args.pretrained: 264 | directory += "_notpretrained" 265 | # directory += "_bsize{}_imsize{}".format(args.batch_size, args.image_size) 266 | directory += "_pretrained" if args.is_pretrained else "" 267 | directory += "_random" if args.is_random else "" 268 | directory += args.notion 269 | 270 | target_net = args.target[6:].replace("_", "/") 271 | print(target_net) 272 | state = torch.load(target_net) 273 | lw = state["meta"]["Lw"]["retrieval-SfM-120k"]["ss"] 274 | 275 | args.directory = os.path.join(args.directory, directory) 276 | print(">> Creating directory if it does not exist:\n>> '{}'".format(args.directory)) 277 | if not os.path.exists(args.directory): 278 | os.makedirs(args.directory) 279 | 280 | # set random seeds (maybe pass as argument) 281 | torch.manual_seed(0) 282 | torch.cuda.manual_seed_all(0) 283 | np.random.seed(0) 284 | 285 | # create model 286 | print(">> Using pre-trained model '{}'".format(args.arch)) 287 | model = init_network( 288 | model=args.arch, 289 | pooling=args.pool, 290 | whitening=args.whitening, 291 | pretrained=not args.is_random, 292 | ) 293 | model.cuda() 294 | 295 | target_model = init_network( 296 | model=args.arch, 297 | pooling=args.pool, 298 | whitening=args.whitening, 299 | pretrained=not args.is_random, 300 | ) 301 | target_model.load_state_dict(state["state_dict"]) 302 | lw_m = lw["m"].copy() 303 | lw_p = lw["P"].copy() 304 | target_model.lw_m = nn.Parameter(torch.from_numpy(lw_m).float()) 305 | target_model.lw_p = nn.Parameter(torch.from_numpy(lw_p).float()) 306 | target_model.cuda() 307 | 308 | lw_m = lw["m"].copy() 309 | lw_p = lw["P"].copy() 310 | model.lw_m = nn.Parameter(torch.from_numpy(lw_m).float()) 311 | model.lw_p = nn.Parameter(torch.from_numpy(lw_p).float()) 312 | 
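# Added note: lw_m / lw_p hold the teacher's whitening learned on
# retrieval-SfM-120k (mean vector m and projection P); both networks get a copy
# so that do_whiten() maps their raw descriptors into the same whitened space.
# The student additionally receives a trainable Whiten_layer (linear + L2N,
# defined above); its output is what enters the ranking loss in train().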
whiten_layer = Whiten_layer(lw["P"].shape[1], lw["P"].shape[0]) 313 | model.white_layer = whiten_layer 314 | model.cuda() 315 | 316 | # parameters split into features and pool (no weight decay for pooling layer) 317 | parameters = [ 318 | {"params": model.features.parameters()}, 319 | {"params": model.pool.parameters(), "lr": args.lr * 10, "weight_decay": 0}, 320 | {"params": model.white_layer.parameters(), "lr": 1e-2, "weight_decay": 5e-1}, 321 | ] 322 | if model.whiten is not None: 323 | parameters.append({"params": model.whiten.parameters()}) 324 | 325 | # define optimizer 326 | if args.optimizer == "sgd": 327 | optimizer = torch.optim.SGD( 328 | parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay 329 | ) 330 | elif args.optimizer == "adam": 331 | optimizer = torch.optim.Adam( 332 | parameters, args.lr, weight_decay=args.weight_decay 333 | ) 334 | 335 | # define learning rate decay schedule 336 | exp_decay = math.exp(-0.01) 337 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=exp_decay) 338 | 339 | # optionally resume from a checkpoint 340 | start_epoch = 0 341 | # Data loading code 342 | normalize = transforms.Normalize(mean=model.meta["mean"], std=model.meta["std"]) 343 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 344 | val_dataset = Distillation_dataset( 345 | imsize=(args.image_size, args.image_size), 346 | nnum=1, 347 | qsize=float("Inf"), 348 | poolsize=float("Inf"), 349 | transform=transform, 350 | filename=args.target, 351 | q_percent=args.q_percent, 352 | ) 353 | val_loader = torch.utils.data.DataLoader( 354 | val_dataset, 355 | batch_size=10, 356 | shuffle=False, 357 | num_workers=args.workers, 358 | pin_memory=False, 359 | collate_fn=collate_tuples, 360 | ) 361 | 362 | min_epoch = -1 363 | for epoch in range(start_epoch, args.epochs): 364 | if args.is_pretrained or args.is_random: 365 | break 366 | 367 | # set manual seeds per epoch 368 | np.random.seed(epoch) 369 | torch.manual_seed(epoch) 370 | torch.cuda.manual_seed_all(epoch) 371 | 372 | # adjust learning rate for each epoch 373 | scheduler.step() 374 | 375 | loss = train(val_loader, model, optimizer, epoch, target_model) 376 | print(loss) 377 | 378 | # evaluate on test datasets 379 | if (epoch + 1) % 1 == 0: 380 | with torch.no_grad(): 381 | test(args.test_datasets, model, lw) 382 | 383 | # remember best loss and save checkpoint 384 | is_best = loss < min_loss 385 | min_loss = min(loss, min_loss) 386 | 387 | save_checkpoint( 388 | { 389 | "epoch": epoch + 1, 390 | "meta": model.meta, 391 | "state_dict": model.state_dict(), 392 | "min_loss": min_loss, 393 | "optimizer": optimizer.state_dict(), 394 | }, 395 | is_best, 396 | args.directory, 397 | ) 398 | if is_best: 399 | min_epoch = epoch 400 | # if epoch - min_epoch > 5: 401 | # # break 402 | # if val_dataset.phase == 1: 403 | # print(bcolors.str(">>> phase 2", bcolors.OKGREEN)) 404 | # val_dataset.phase = 2 405 | # min_epoch = epoch 406 | # for group in optimizer.param_groups: 407 | # group["lr"] /= 10 408 | # else: 409 | # break 410 | 411 | if args.is_pretrained or args.is_random: 412 | save_checkpoint( 413 | { 414 | "epoch": 0 + 1, 415 | "meta": model.meta, 416 | "state_dict": model.state_dict(), 417 | "min_loss": min_loss, 418 | "optimizer": optimizer.state_dict(), 419 | }, 420 | True, 421 | args.directory, 422 | ) 423 | # print("calculate whiten") 424 | # lw = learning_lw(model) 425 | # filename = os.path.join(args.directory, "lw") 426 | # pickle.dump(lw, open(filename, "wb")) 427 | 428 | 429 | def 
train(train_loader, model, optimizer, epoch, target_model): 430 | batch_time = AverageMeter() 431 | data_time = AverageMeter() 432 | losses = AverageMeter() 433 | 434 | # switch to train mode 435 | model.train() 436 | 437 | end = time.time() 438 | dataset = train_loader.dataset 439 | l = np.arange(len(dataset)) 440 | print(len(l)) 441 | # np.random.shuffle(l) 442 | optimizer.zero_grad() 443 | end = time.time() 444 | for batch_i, index in enumerate(l): 445 | data_time.update(time.time() - end) 446 | end = time.time() 447 | r = dataset.ranks[index, :].long().numpy() 448 | BIN = 128 449 | if dataset.phase == 1: 450 | size = len(r) // BIN 451 | bid = [] 452 | for i in range(BIN): 453 | bi = np.random.choice(r[i * size : i * size + size], 1)[0] 454 | while dataset.loaded_images[bi] is None: 455 | bi = np.random.choice(r[i * size : i * size + size], 1)[0] 456 | bid.append(bi) 457 | elif dataset.phase == 2: 458 | # For convenience 459 | bid = r[:BIN] 460 | 461 | output = [] 462 | output.append(dataset.loaded_images[dataset.qidxs[index]]) 463 | for bi in bid: 464 | output.append(dataset.loaded_images[bi].detach()) 465 | output = torch.cat(output).cuda() 466 | 467 | target_output = target_model(output) 468 | target_output = do_whiten(target_output, target_model.lw_m, target_model.lw_p) 469 | 470 | while True: 471 | my_output = model(output) 472 | my_output = model.white_layer(my_output.t()).t() 473 | similarity = my_output[:, 0].view(1, -1) @ my_output[:, 1:] 474 | diff = similarity.t() - similarity - (1e-1 if dataset.phase == 1 else 1e-6) 475 | diff = -diff.triu() - torch.eye(BIN).cuda() 476 | coff = [i for i in range(diff.size(0) - 1, 0, -1)] + [0] 477 | diff = diff * torch.Tensor(coff).view(-1, 1).cuda() 478 | loss = nn.functional.relu(diff) 479 | loss = loss.sum() - 10000 * similarity.std() 480 | 481 | losses.update(loss.item(), 1) 482 | loss.backward() 483 | batch_time.update(time.time() - end) 484 | end = time.time() 485 | if (batch_i + 0) % args.batch_size == 0: 486 | optimizer.step() 487 | 488 | if batch_i % args.print_freq == 0: 489 | print( 490 | ">> Train: [{0}][{1}/{2}]\t" 491 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 492 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 493 | "Loss {loss.val:.4f} ({loss.avg:.4f})".format( 494 | epoch + 1, 495 | batch_i, 496 | len(dataset), 497 | batch_time=batch_time, 498 | data_time=data_time, 499 | loss=losses, 500 | ) 501 | ) 502 | break 503 | optimizer.step() 504 | 505 | return losses.avg 506 | 507 | 508 | def learning_lw(net): 509 | net.cuda() 510 | net.eval() 511 | # set up the transform 512 | normalize = transforms.Normalize(mean=net.meta["mean"], std=net.meta["std"]) 513 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 514 | 515 | test_whiten = "retrieval-SfM-30k" 516 | print(">> {}: Learning whitening...".format(test_whiten)) 517 | 518 | # loading db 519 | db_root = os.path.join(get_data_root(), "train", test_whiten) 520 | ims_root = os.path.join(db_root, "ims") 521 | db_fn = os.path.join(db_root, "{}-whiten.pkl".format(test_whiten)) 522 | with open(db_fn, "rb") as f: 523 | db = pickle.load(f) 524 | images = [cid2filename(db["cids"][i], ims_root) for i in range(len(db["cids"]))] 525 | 526 | # extract whitening vectors 527 | print(">> {}: Extracting...".format(args.test_whiten)) 528 | wvecs = extract_vectors(net, images, 1024, transform) 529 | 530 | # learning whitening 531 | print(">> {}: Learning...".format(args.test_whiten)) 532 | wvecs = wvecs.numpy() 533 | m, P = whitenlearn(wvecs, db["qidxs"], 
db["pidxs"]) 534 | Lw = {"m": m, "P": P} 535 | return Lw 536 | 537 | 538 | def test(datasets, net, lw): 539 | 540 | print(">> Evaluating network on test datasets...") 541 | 542 | # for testing we use image size of max 1024 543 | image_size = 1024 544 | 545 | # moving network to gpu and eval mode 546 | net.cuda() 547 | net.eval() 548 | # set up the transform 549 | normalize = transforms.Normalize(mean=net.meta["mean"], std=net.meta["std"]) 550 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 551 | 552 | Lw = lw 553 | Lw = None 554 | 555 | # evaluate on test datasets 556 | datasets = args.test_datasets.split(",") 557 | for dataset in datasets: 558 | start = time.time() 559 | 560 | print(">> {}: Extracting...".format(dataset)) 561 | 562 | # prepare config structure for the test dataset 563 | cfg = configdataset(dataset, os.path.join(get_data_root(), "test")) 564 | images = [cfg["im_fname"](cfg, i) for i in range(cfg["n"])] 565 | qimages = [cfg["qim_fname"](cfg, i) for i in range(cfg["nq"])] 566 | bbxs = [tuple(cfg["gnd"][i]["bbx"]) for i in range(cfg["nq"])] 567 | 568 | # extract database and query vectors 569 | print(">> {}: database images...".format(dataset)) 570 | vecs = extract_vectors(net, images, image_size, transform) 571 | print(">> {}: query images...".format(dataset)) 572 | qvecs = extract_vectors(net, qimages, image_size, transform, bbxs) 573 | 574 | print(">> {}: Evaluating...".format(dataset)) 575 | 576 | vecs = do_whiten(vecs.cuda(), net.lw_m, net.lw_p).cpu() 577 | qvecs = do_whiten(qvecs.cuda(), net.lw_m, net.lw_p).cpu() 578 | 579 | # convert to numpy 580 | vecs = vecs.numpy() 581 | qvecs = qvecs.numpy() 582 | 583 | # search, rank, and print 584 | scores = np.dot(vecs.T, qvecs) 585 | ranks = np.argsort(-scores, axis=0) 586 | compute_map_and_print(dataset, ranks, cfg["gnd"]) 587 | 588 | if Lw is not None: 589 | # whiten the vectors 590 | vecs_lw = whitenapply(vecs, Lw["m"], Lw["P"]) 591 | qvecs_lw = whitenapply(qvecs, Lw["m"], Lw["P"]) 592 | 593 | # search, rank, and print 594 | scores = np.dot(vecs_lw.T, qvecs_lw) 595 | ranks = np.argsort(-scores, axis=0) 596 | compute_map_and_print(dataset + " + whiten", ranks, cfg["gnd"]) 597 | 598 | print(">> {}: elapsed time: {}".format(dataset, htime(time.time() - start))) 599 | 600 | 601 | def save_checkpoint(state, is_best, directory): 602 | filename = os.path.join(directory, "model_epoch%d.pth.tar" % state["epoch"]) 603 | torch.save(state, filename) 604 | if is_best: 605 | filename_best = os.path.join(directory, "model_best.pth.tar") 606 | shutil.copyfile(filename, filename_best) 607 | 608 | 609 | class AverageMeter(object): 610 | """Computes and stores the average and current value""" 611 | 612 | def __init__(self): 613 | self.reset() 614 | 615 | def reset(self): 616 | self.val = 0 617 | self.avg = 0 618 | self.sum = 0 619 | self.count = 0 620 | 621 | def update(self, val, n=1): 622 | self.val = val 623 | self.sum += val * n 624 | self.count += n 625 | self.avg = self.sum / self.count 626 | 627 | 628 | def set_batchnorm_eval(m): 629 | classname = m.__class__.__name__ 630 | if classname.find("BatchNorm") != -1: 631 | # freeze running mean and std 632 | m.eval() 633 | # freeze parameters 634 | # for p in m.parameters(): 635 | # p.requires_grad = False 636 | 637 | 638 | if __name__ == "__main__": 639 | main() 640 | -------------------------------------------------------------------------------- /cirtorch/examples/attack/attack.py: -------------------------------------------------------------------------------- 1 | 
import argparse 2 | import math 3 | import os 4 | import pdb 5 | import pickle 6 | import random 7 | import shutil 8 | import time 9 | from pprint import pprint 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | import torch.utils.data 17 | import torchvision 18 | import torchvision.models as models 19 | import torchvision.transforms as transforms 20 | 21 | from cirtorch.datasets.datahelpers import cid2filename, collate_tuples 22 | from cirtorch.datasets.testdataset import configdataset 23 | from cirtorch.datasets.traindataset import TuplesDataset 24 | from cirtorch.examples.attack.myutil.baseline import result as baseline_result 25 | from cirtorch.examples.attack.myutil.mi_sgd import (MI_SGD, SIGN_AdaBound, 26 | SIGN_Adam) 27 | from cirtorch.examples.attack.myutil.sfm_dataset import SfMDataset 28 | from cirtorch.examples.attack.myutil.triplet_dataset import MyTripletDataset 29 | from cirtorch.examples.attack.myutil.utils import (MultiLoss, bcolors, 30 | do_whiten, idcg, inv_gfr, 31 | one_hot, rescale_check) 32 | from cirtorch.layers.loss import ContrastiveLoss 33 | from cirtorch.networks.imageretrievalnet import extract_vectors, init_network 34 | from cirtorch.utils.download import download_test, download_train 35 | from cirtorch.utils.evaluate import compute_map_and_print 36 | from cirtorch.utils.general import get_data_root, htime 37 | from cirtorch.utils.whiten import whitenapply, whitenlearn 38 | 39 | f = os.path.realpath(__file__) 40 | f = open(f, "r") 41 | print("".join(f.readlines())) 42 | f.close() 43 | 44 | training_dataset_names = ["retrieval-SfM-120k", "Landmarks"] 45 | test_datasets_names = [ 46 | "oxford5k,paris6k", 47 | "roxford5k,rparis6k", 48 | "oxford5k,paris6k,roxford5k,rparis6k", 49 | ] 50 | test_whiten_names = ["retrieval-SfM-30k", "retrieval-SfM-120k"] 51 | 52 | model_names = sorted( 53 | name 54 | for name in models.__dict__ 55 | if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) 56 | ) 57 | 58 | base = {} # storing the feature of base 59 | MAX_EPS = 10.0 / 255 # max eps of perturbation 60 | MODE = "bilinear" # mode of resize 61 | 62 | parser = argparse.ArgumentParser(description="PyTorch CNN Image Retrieval Training") 63 | 64 | # export directory, training and val datasets, test datasets 65 | parser.add_argument( 66 | "--test-datasets", 67 | "-td", 68 | metavar="DATASETS", 69 | default="oxford5k,paris6k,roxford5k,rparis6k", 70 | choices=test_datasets_names, 71 | help="comma separated list of test datasets: " 72 | + " | ".join(test_datasets_names) 73 | + " (default: oxford5k,paris6k)", 74 | ) 75 | parser.add_argument( 76 | "--network-path", help="network path, destination where network is saved" 77 | ) 78 | parser.add_argument( 79 | "--image-size", 80 | default=1024, 81 | type=int, 82 | metavar="N", 83 | help="maximum size of longer image side used for training (default: 1024)", 84 | ) 85 | 86 | # standard train/val options 87 | parser.add_argument( 88 | "--gpu-id", 89 | "-g", 90 | default="0", 91 | metavar="N", 92 | help="gpu id used for training (default: 0)", 93 | ) 94 | parser.add_argument( 95 | "--workers", 96 | "-j", 97 | default=1, 98 | type=int, 99 | metavar="N", 100 | help="number of data loading workers (default: 8)", 101 | ) 102 | parser.add_argument( 103 | "--epochs", 104 | default=100, 105 | type=int, 106 | metavar="N", 107 | help="number of total epochs to run (default: 100)", 108 | ) 109 | parser.add_argument( 110 | 
"--batch-size", 111 | "-b", 112 | default=1, 113 | type=int, 114 | metavar="N", 115 | help="number of (q,p,n1,...,nN) tuples in a mini-batch (default: 5)", 116 | ) 117 | parser.add_argument( 118 | "--print-freq", 119 | default=500, 120 | type=int, 121 | metavar="N", 122 | help="print frequency (default: 10)", 123 | ) 124 | parser.add_argument("--noise-path", type=str, help="noise path") 125 | parser.add_argument( 126 | "--loss-margin", 127 | "-lm", 128 | metavar="LM", 129 | default=0.8, 130 | type=float, 131 | help="loss margin: (default: 0.7)", 132 | ) 133 | parser.add_argument( 134 | "--image-size-L", default=256, type=int, help="min of image size for random" 135 | ) 136 | parser.add_argument( 137 | "--image-size-H", default=1024, type=int, help="max of image size for random" 138 | ) 139 | parser.add_argument("--noise-size", default=1024, type=int, help="noise-size") 140 | 141 | # loss 142 | parser.add_argument( 143 | "--point_wise", dest="point_wise", action="store_true", help="point-wise loss" 144 | ) 145 | parser.add_argument( 146 | "--label_wise", dest="label_wise", action="store_true", help="label-wise loss" 147 | ) 148 | parser.add_argument( 149 | "--pair_wise", dest="pair_wise", action="store_true", help="pair-wise loss" 150 | ) 151 | parser.add_argument( 152 | "--list_wise", dest="list_wise", action="store_true", help="list_wise loss" 153 | ) 154 | 155 | parser.add_argument("--max-eps", default=10, type=int, help="max eps") 156 | 157 | args = parser.parse_args() 158 | pprint(args) 159 | 160 | 161 | def main(): 162 | global base 163 | global MAX_EPS 164 | MAX_EPS = args.max_eps / 255.0 165 | 166 | # load base 167 | fname = args.network_path.replace("/", "_") + ".pkl" 168 | if os.path.exists(f"base/{fname}"): 169 | with open(f"base/{fname}", "rb") as f: 170 | base = pickle.load(f) 171 | 172 | # for saving noise 173 | os.makedirs(args.noise_path, exist_ok=True) 174 | 175 | # set cuda visible device 176 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 177 | 178 | torch.manual_seed(1234) 179 | torch.cuda.manual_seed_all(1234) 180 | np.random.seed(1234) 181 | random.seed(1234) 182 | torch.backends.cudnn.deterministic = True 183 | 184 | # load retrieval model 185 | state = torch.load(args.network_path) 186 | model = init_network( 187 | model=state["meta"]["architecture"], 188 | pooling=state["meta"]["pooling"], 189 | whitening=state["meta"]["whitening"], 190 | mean=state["meta"]["mean"], 191 | std=state["meta"]["std"], 192 | pretrained=False, 193 | ) 194 | model.load_state_dict(state["state_dict"]) 195 | model.meta["Lw"] = state["meta"]["Lw"] 196 | model.cuda() 197 | 198 | # perturbation for training 199 | noise = torch.zeros((3, args.noise_size, args.noise_size)).cuda() 200 | 201 | print(state["meta"]["architecture"]) 202 | print(state["meta"]["pooling"]) 203 | noise.requires_grad = True 204 | 205 | optimizer = MI_SGD( 206 | [ 207 | {"params": [noise], "lr": MAX_EPS / 10, "momentum": 1, "sign": True}, 208 | # {"params": [noise], "lr": 1e-2, "momentum": 1, "sign": True}, 209 | ], 210 | max_eps=MAX_EPS, 211 | ) 212 | print(optimizer) 213 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=math.exp(-0.01)) 214 | 215 | # Data loading code 216 | normalize = transforms.Normalize(mean=model.meta["mean"], std=model.meta["std"]) 217 | transform = transforms.Compose( 218 | [ 219 | transforms.ToTensor(), 220 | # normalize, 221 | ] 222 | ) 223 | mean = torch.Tensor(normalize.mean).view(3, 1, 1) 224 | std = torch.Tensor(normalize.std).view(3, 1, 1) 225 | 226 | # dataloader 227 | 
227 |     val_dataset = MyTripletDataset(
228 |         imsize=(args.image_size_L, args.image_size_H),
229 |         transform=transform,
230 |         norm=(mean, std),
231 |         filename="base/" + args.network_path.replace("/", "_") + "_triplet",
232 |     )
233 |     val_dataset.create_epoch_tuples(model)
234 |     val_loader = torch.utils.data.DataLoader(
235 |         val_dataset,
236 |         batch_size=args.batch_size,
237 |         shuffle=True,
238 |         num_workers=args.workers,
239 |         pin_memory=True,
240 |         worker_init_fn=lambda _: random.seed(1234),
241 |     )
242 | 
243 |     # load classifier model
244 |     if args.label_wise:
245 |         classification_model = torch.load(
246 |             "base/" + args.network_path.replace("/", "_") + "_triplet.KMeans_cls.pth"
247 |         )
248 |     else:
249 |         classification_model = None
250 | 
251 |     noise_best = None
252 |     min_loss = float("inf")
253 |     min_epoch = -1
254 |     for epoch in range(args.epochs):
255 |         # set manual seeds per epoch
256 |         np.random.seed(epoch + 1234)
257 |         torch.manual_seed(epoch + 1234)
258 |         torch.cuda.manual_seed_all(epoch + 1234)
259 |         random.seed(epoch + 1234)
260 | 
261 |         # train for one epoch on train set
262 |         scheduler.step()
263 |         begin_time = time.time()
264 |         loss, noise = train(
265 |             val_loader,
266 |             model,
267 |             noise,
268 |             epoch,
269 |             normalize,
270 |             classification_model,
271 |             optimizer,
272 |             None,
273 |         )
274 |         print("epoch time", time.time() - begin_time)
275 | 
276 |         # evaluate on test datasets
277 |         loss = test(args.test_datasets, model, noise.cpu(), 1024)
278 |         print(bcolors.str(f"test fgr: {1-loss}", bcolors.OKGREEN))
279 | 
280 |         # remember best loss and save checkpoint
281 |         is_best = loss < min_loss
282 |         min_loss = min(loss, min_loss)
283 |         save_noise(noise, is_best, epoch)
284 |         if is_best:
285 |             min_epoch = epoch
286 |             noise_best = noise.clone().detach()
287 |         if epoch - min_epoch > 5:
288 |             break
289 | 
290 |     print("Best")
291 |     loss = test(args.test_datasets, model, noise_best.cpu(), 1024)
292 |     print(bcolors.str(f"test fgr: {1-loss}", bcolors.OKGREEN))
293 | 
294 | 
295 | def train(train_loader, model, noise, epoch, normalize, cls, optimizer, multiLoss):
296 |     """Train the universal perturbation for one epoch.
297 |     train_loader: data loader providing the training images
298 |     model: victim retrieval model
299 |     noise: perturbation to be optimized
300 |     epoch: current epoch
301 |     normalize: normalization parameters (mean/std) of the model
302 |     cls: classification model used by the label-wise loss
303 |     optimizer: optimizer that updates the perturbation
304 |     multiLoss: multi-loss weighting (currently unused; main() passes None)
305 |     """
306 | 
307 |     global args
308 |     noise.requires_grad = True
309 |     batch_time = AverageMeter()
310 |     data_time = AverageMeter()
311 |     losses = AverageMeter()
312 | 
313 |     model.eval()
314 | 
315 |     # normalize
316 |     mean = normalize.mean
317 |     std = normalize.std
318 |     mean = torch.Tensor(mean).view(1, 3, 1, 1).cuda()
319 |     std = torch.Tensor(std).view(1, 3, 1, 1).cuda()
320 | 
321 |     # whitening
322 |     Lw = model.meta["Lw"]["retrieval-SfM-120k"]["ss"]
323 |     Lw_m = torch.from_numpy(Lw["m"]).cuda().float()
324 |     Lw_p = torch.from_numpy(Lw["P"]).cuda().float()
325 | 
326 |     # cluster center and base cluster id
327 |     pool_clusters_centers = train_loader.dataset.pool_clusters_centers.cuda().float()
328 |     clustered_pool = train_loader.dataset.clustered_pool
329 | 
330 |     end = time.time()
331 |     optimizer.zero_grad()
332 |     optimizer.rescale()
333 |     for i, (input) in enumerate(train_loader):
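        # Attack step for this batch: embed the clean image, add the universal
        # noise (resized to the input resolution with F.interpolate), re-embed
        # the perturbed image, and apply whichever of the point-/pair-/label-/
        # list-wise losses are enabled so the perturbed descriptor diverges
        # from the clean one.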
334 |         # measure data loading time.
335 |         data_time.update(time.time() - end)
336 |         model.zero_grad()
337 | 
338 |         input = input.cuda()
339 |         with torch.no_grad():
340 |             norm_output = (input - mean) / std
341 |             feature = model(norm_output)
342 |             feature = do_whiten(feature, Lw_m, Lw_p).detach()
343 | 
344 |         optimizer.zero_grad()
345 |         current_noise = noise
346 |         current_noise = F.interpolate(
347 |             current_noise.unsqueeze(0),
348 |             mode=MODE,
349 |             size=tuple(input.shape[-2:]),
350 |             align_corners=True,
351 |         ).squeeze()
352 |         perturbed_input = torch.clamp(input + current_noise, 0, 1)
353 | 
354 |         perturbed_input = (perturbed_input - mean) / std
355 |         perturbed_feature = model(perturbed_input)
356 |         perturbed_feature = do_whiten(perturbed_feature, Lw_m, Lw_p)
357 | 
358 |         # pair-wise
359 |         if args.pair_wise:
360 |             with torch.no_grad():
361 |                 scores = torch.mm((pool_clusters_centers), feature)
362 |                 scores, ranks = torch.sort(scores, dim=0, descending=True)
363 | 
364 |                 pos_i = ranks[0, 0].item()
365 |                 neg_i = ranks[-1, 0].item()
366 |                 # neg_feature = torch.from_numpy(
367 |                 #     np.concatenate(
368 |                 #         (
369 |                 #             clustered_pool[neg_i][
370 |                 #                 np.random.choice(clustered_pool[neg_i].shape[0]), :
371 |                 #             ].reshape(1, -1),
372 |                 #         )
373 |                 #     )
374 |                 # ).cuda()
375 |                 # pos_feature = (
376 |                 #     torch.from_numpy(
377 |                 #         clustered_pool[pos_i][
378 |                 #             np.random.choice(clustered_pool[pos_i].shape[0]), :
379 |                 #         ]
380 |                 #     )
381 |                 #     .cuda()
382 |                 #     .unsqueeze(0)
383 |                 # )
384 |                 neg_feature = pool_clusters_centers[neg_i, :].view(1, -1)
385 |                 pos_feature = pool_clusters_centers[pos_i, :].view(1, -1)
386 |             perturbed_feature = perturbed_feature.t()
387 |             # neg_feature = torch.cat((neg_feature, -feature.t()))
388 |             # pos_feature = torch.cat((pos_feature, feature.t()))
389 |             # perturbed_feature = torch.cat(
390 |             #     (perturbed_feature.t(), perturbed_feature.t())
391 |             # )
392 |             neg_feature = neg_feature * 10
393 |             pos_feature = pos_feature * 10
394 |             perturbed_feature = perturbed_feature * 10
395 | 
396 |             pair_loss = F.triplet_margin_loss(
397 |                 perturbed_feature, neg_feature, pos_feature, args.loss_margin
398 |             )
399 |         else:
400 |             pair_loss = torch.zeros(1).cuda()
401 | 
402 |         # point-wise
403 |         if args.point_wise:
404 |             point_loss = (
405 |                 torch.dot(perturbed_feature.squeeze(), feature.squeeze()) + 1
406 |             ) / 2
407 |         else:
408 |             point_loss = torch.zeros(1).cuda()
409 | 
410 |         # label-wise
411 |         if args.label_wise:
412 |             actual_pred = cls(feature.t())
413 |             perturbed_pred = cls(perturbed_feature.t())
414 |             actual_label = actual_pred.max(1, keepdim=True)[1].item()
415 |             one_hot_actual_label = one_hot(
416 |                 perturbed_pred.size(1), torch.LongTensor([actual_label]).cuda()
417 |             ).float()
418 |             label_loss = F.relu(
419 |                 (perturbed_pred * one_hot_actual_label).sum()
420 |                 - (perturbed_pred * (1 - one_hot_actual_label)).max()
421 |             )
422 |         else:
423 |             label_loss = torch.zeros(1).cuda()
424 | 
425 |         if args.list_wise:
426 |             clean_scores = torch.mm((pool_clusters_centers), feature)
427 |             _, clean_ranks = torch.sort(clean_scores, dim=0, descending=True)
428 | 
429 |             # pos_i = clean_ranks[:256, :].squeeze()
430 |             # neg_i = clean_ranks[256:, :].squeeze()
431 |             pos_i = clean_ranks[:, :].squeeze()
432 |             neg_i = torch.flip(pos_i, (0,))
433 | 
434 |             scores = -torch.mm((pool_clusters_centers), perturbed_feature)
435 |             _, ranks = torch.sort(scores, dim=0, descending=True)
436 | 
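            # LambdaRank-style update: rank every cluster center against the
            # perturbed query, weight each (pos, neg) pair by the NDCG change it
            # would cause, and backpropagate the resulting lambdas directly
            # through `scores` instead of building a scalar list-wise loss.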
437 |             doc_ranks = torch.zeros(pool_clusters_centers.size(0)).to(feature.device)
438 |             doc_ranks[ranks] = 1 + torch.arange(pool_clusters_centers.size(0)).to(
439 |                 feature.device
440 |             ).float().view((-1, 1))
441 |             doc_ranks = doc_ranks.view((-1, 1))
442 | 
443 |             score_diffs = scores[pos_i] - scores[neg_i].view(neg_i.size(0))
444 |             exped = score_diffs.exp()
445 |             N = 1 / idcg(pos_i.size(0))
446 |             ndcg_diffs = (1 / (1 + doc_ranks[pos_i])).log2() - (
447 |                 1 / (1 + doc_ranks[neg_i])
448 |             ).log2().view(neg_i.size(0))
449 | 
450 |             lamb_updates = -1 / (1 + exped) * N * ndcg_diffs.abs()
451 |             lambs = torch.zeros((pool_clusters_centers.shape[0], 1)).to(feature.device)
452 |             lambs[pos_i] += lamb_updates.sum(dim=1, keepdim=True)
453 |             lambs[neg_i] -= lamb_updates.sum(dim=0, keepdim=True).t()
454 |             scores.backward(lambs)
455 |             list_loss = torch.zeros(1).cuda()
456 |         else:
457 |             list_loss = torch.zeros(1).cuda()
458 | 
459 |         label_loss = label_loss.view(1)
460 |         point_loss = point_loss.view(1)
461 |         pair_loss = pair_loss.view(1)
462 |         list_loss = list_loss.view(1)
463 | 
464 |         loss = label_loss + point_loss + pair_loss
465 | 
466 |         if not args.list_wise:
467 |             loss.backward()
468 | 
469 |         losses.update(loss.item())
470 |         optimizer.step()
471 | 
472 |         # measure elapsed time
473 |         batch_time.update(time.time() - end)
474 |         end = time.time()
475 | 
476 |         if i % args.print_freq == 0:
477 |             # optimizer.rescale()
478 |             print(
479 |                 ">> Train: [{0}][{1}/{2}]\t"
480 |                 "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
481 |                 "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
482 |                 "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
483 |                 "Noise l2: {noise:.4f}".format(
484 |                     epoch + 1,
485 |                     i,
486 |                     len(train_loader),
487 |                     batch_time=batch_time,
488 |                     data_time=data_time,
489 |                     loss=losses,
490 |                     noise=noise.norm(),
491 |                 )
492 |             )
493 | 
494 |     noise.requires_grad = False
495 |     print(bcolors.str(f"Train {epoch}: Loss: {losses.avg}", bcolors.OKGREEN))
496 |     return losses.avg, noise
497 | 
498 | 
499 | def test(datasets, net, noise, image_size):
500 |     global base
501 |     print(">> Evaluating network on test datasets...")
502 | 
503 |     net.cuda()
504 |     net.eval()
505 |     normalize = transforms.Normalize(mean=net.meta["mean"], std=net.meta["std"])
506 | 
507 |     def add_noise(img):
508 |         n = noise
509 |         n = F.interpolate(
510 |             n.unsqueeze(0), mode=MODE, size=tuple(img.shape[-2:]), align_corners=True
511 |         ).squeeze()
512 |         return torch.clamp(img + n, 0, 1)
513 | 
514 |     transform_base = transforms.Compose([transforms.ToTensor(), normalize])
515 |     transform_query = transforms.Compose(
516 |         [transforms.ToTensor(), transforms.Lambda(add_noise), normalize]
517 |     )
518 | 
519 |     if "Lw" in net.meta:
520 |         Lw = net.meta["Lw"]["retrieval-SfM-120k"]["ss"]
521 |     else:
522 |         Lw = None
523 | 
524 |     # evaluate on test datasets
525 |     datasets = args.test_datasets.split(",")
526 |     attack_result = {}
527 |     for dataset in datasets:
528 |         start = time.time()
529 | 
530 |         print(">> {}: Extracting...".format(dataset))
531 | 
532 |         cfg = configdataset(dataset, os.path.join(get_data_root(), "test"))
533 |         images = [cfg["im_fname"](cfg, i) for i in range(cfg["n"])]
534 |         qimages = [cfg["qim_fname"](cfg, i) for i in range(cfg["nq"])]
535 |         bbxs = [tuple(cfg["gnd"][i]["bbx"]) for i in range(cfg["nq"])]
536 | 
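        # Only the query transform adds the perturbation; database images stay
        # clean, and their descriptors are cached in `base` (pickled to disk)
        # so they are computed at most once per dataset and image size.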
537 |         # extract database and query vectors
538 |         print(">> {}: database images...".format(dataset))
539 |         with torch.no_grad():
540 |             if dataset in base and str(image_size) in base[dataset]:
541 |                 vecs = base[dataset][str(image_size)]
542 |             else:
543 |                 vecs = extract_vectors(net, images, image_size, transform_base)
544 |                 if dataset not in base:
545 |                     base[dataset] = {}
546 |                 base[dataset][str(image_size)] = vecs
547 |                 fname = args.network_path.replace("/", "_") + ".pkl"
548 |                 with open(f"base/{fname}", "wb") as f:
549 |                     pickle.dump(base, f)
550 |         print(">> {}: query images...".format(dataset))
551 |         qvecs = extract_vectors(net, qimages, image_size, transform_query, bbxs)
552 | 
553 |         print(">> {}: Evaluating...".format(dataset))
554 | 
555 |         # convert to numpy
556 |         vecs = vecs.numpy()
557 |         qvecs = qvecs.numpy()
558 | 
559 |         # whiten the vectors
560 |         vecs_lw = whitenapply(vecs, Lw["m"], Lw["P"])
561 |         qvecs_lw = whitenapply(qvecs, Lw["m"], Lw["P"])
562 | 
563 |         # search, rank, and print
564 |         scores = np.dot(vecs_lw.T, qvecs_lw)
565 |         ranks = np.argsort(-scores, axis=0)
566 |         r = compute_map_and_print(dataset + " + whiten", ranks, cfg["gnd"])
567 |         attack_result[dataset] = r
568 | 
569 |         print(">> {}: elapsed time: {}".format(dataset, htime(time.time() - start)))
570 |     return inv_gfr(
571 |         attack_result, baseline_result[net.meta["architecture"]][net.meta["pooling"]]
572 |     )
573 | 
574 | 
575 | def save_noise(noise, is_best, epoch):
576 |     filename = os.path.join(args.noise_path, "noise_%d" % epoch)
577 |     np.save(filename, noise.cpu().numpy())
578 |     torchvision.utils.save_image(noise, filename + ".png", normalize=True)
579 |     if is_best:
580 |         filename_best = os.path.join(args.noise_path, "noise_best")
581 |         shutil.copyfile(filename + ".npy", filename_best + ".npy")
582 |         shutil.copyfile(filename + ".png", filename_best + ".png")
583 | 
584 | 
585 | class AverageMeter(object):
586 |     """Computes and stores the average and current value"""
587 | 
588 |     def __init__(self):
589 |         self.reset()
590 | 
591 |     def reset(self):
592 |         self.val = 0
593 |         self.avg = 0
594 |         self.sum = 0
595 |         self.count = 0
596 | 
597 |     def update(self, val, n=1):
598 |         self.val = val
599 |         self.sum += val * n
600 |         self.count += n
601 |         self.avg = self.sum / self.count
602 | 
603 | 
604 | if __name__ == "__main__":
605 |     main()
606 | 
--------------------------------------------------------------------------------