├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── dataset
│   ├── dataset.py
│   └── transform.py
├── model
│   ├── arcface.py
│   ├── dolg.py
│   └── gem_pool.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
# Data
data
lightning_logs

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 DK

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# PyTorch Implementation of Deep Orthogonal Fusion of Local and Global Features (DOLG)

This is an unofficial PyTorch implementation of "DOLG: Single-Stage Image Retrieval with Deep Orthogonal Fusion of Local and Global Features".

Reference: https://arxiv.org/pdf/2108.02927.pdf

## Model Structure

![Image](https://github.com/tanzeyy/DOLG-instance-retrieval/raw/main/imgs/figure2.png)

## Prerequisites

+ PyTorch
+ PyTorch Lightning
+ timm
+ scikit-learn
+ pandas
+ jpeg4py
+ albumentations
+ python3
+ CUDA

## Data

You can get the GLDv2 dataset from [here](https://github.com/cvdfoundation/google-landmark).

If you only need the GLDv2-clean subset, check this [Kaggle competition dataset](https://www.kaggle.com/c/landmark-retrieval-2021).
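
The training list is read by `dataset/dataset.py`, which expects `train_clean.csv` to provide at least an `id` and a `landmark_id` column; landmark ids are re-encoded with scikit-learn's `LabelEncoder`, so they do not have to be contiguous. Once the data is in place (see the structure below), you can sanity-check the CSV and image paths with the repo's own helpers — a minimal sketch, assuming the default `data/` layout from `config.py`:

```python
import pandas as pd

from config import Config
from dataset.dataset import img_path_from_id

df = pd.read_csv(Config.CSV_PATH)
print(df.columns.tolist())                 # should include 'id' and 'landmark_id'
print(img_path_from_id(df['id'].iloc[0]))  # data/train/<c1>/<c2>/<c3>/<id>.jpg
```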

Place your data following the structure below (images are stored under nested directories named after the first three characters of the image id, as resolved in `dataset/dataset.py`):

```
data
├── train_clean.csv
└── train
    └── ###
        └── ###
            └── ###
                └── ###.jpg
```

## Citations

```bibtex
@misc{yang2021dolg,
      title={DOLG: Single-Stage Image Retrieval with Deep Orthogonal Fusion of Local and Global Features},
      author={Min Yang and Dongliang He and Miao Fan and Baorong Shi and Xuetong Xue and Fu Li and Errui Ding and Jizhou Huang},
      year={2021},
      eprint={2108.02927},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os


class Config:
    DATA_DIR = 'data'
    CSV_PATH = os.path.join(DATA_DIR, 'train_clean.csv')
    train_batch_size = 10
    val_batch_size = 10
    num_workers = 8
    image_size = 512
    output_dim = 512
    hidden_dim = 1024
    input_dim = 3
    epochs = 35
    lr = 1e-4
    num_of_classes = 88313
    pretrained = True
    model_name = 'resnet101'
    seed = 42
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
import os
import jpeg4py as jpeg
import pandas as pd
from sklearn import preprocessing
from torch.utils.data import Dataset
from dataset.transform import image_transform
from config import Config


def img_path_from_id(id):
    # GLDv2 layout: <DATA_DIR>/train/<c1>/<c2>/<c3>/<id>.jpg, where c1-c3 are
    # the first three characters of the image id.
    img_path = os.path.join(Config.DATA_DIR, 'train',
                            id[0], id[1], id[2], f'{id}.jpg')
    return img_path


class LmkRetrDataset(Dataset):
    def __init__(self):
        self.df = pd.read_csv(Config.CSV_PATH)
        # Re-encode the raw landmark ids into contiguous class indices [0, N).
        self.landmark_id_encoder = preprocessing.LabelEncoder()
        self.df['landmark_id'] = self.landmark_id_encoder.fit_transform(
            self.df['landmark_id'])
        self.df['path'] = self.df['id'].apply(img_path_from_id)
        self.paths = self.df['path'].values
        self.ids = self.df['id'].values
        self.landmark_ids = self.df['landmark_id'].values
        self.transform = image_transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        path, id, landmark_id = self.paths[idx], self.ids[idx], self.landmark_ids[idx]
        img = jpeg.JPEG(path).decode()
        if self.transform:
            img = self.transform(image=img)['image']
        return img, landmark_id, id
--------------------------------------------------------------------------------
/dataset/transform.py:
--------------------------------------------------------------------------------
import albumentations as A
import albumentations.pytorch
from config import Config


# Resize to a square of Config.image_size, normalize with ImageNet statistics,
# and convert the HWC uint8 image to a CHW float tensor.
image_transform = A.Compose([
    A.Resize(Config.image_size, Config.image_size),
    A.Normalize(),
    A.pytorch.transforms.ToTensorV2()
])
--------------------------------------------------------------------------------
/model/arcface.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcFace(nn.Module):
    def __init__(self, in_features, out_features, scale_factor=64.0, margin=0.50, criterion=None):
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        if criterion:
            self.criterion = criterion
        else:
            self.criterion = nn.CrossEntropyLoss()

        self.margin = margin
        self.scale_factor = scale_factor

        self.weight = nn.Parameter(
            torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, input, label):
        # input is not l2 normalized
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Clamp guards against |cos| slightly exceeding 1 due to numerical error.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))

        # cos(theta + m), with the standard fallback when theta + m exceeds pi.
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = phi.type(cosine.type())
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        one_hot = torch.zeros(cosine.size(), device=input.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)

        # Apply the additive angular margin to the target class only, then scale.
        logit = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        logit *= self.scale_factor

        loss = self.criterion(logit, label)

        return loss, logit
--------------------------------------------------------------------------------
/model/dolg.py:
--------------------------------------------------------------------------------
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule

from config import Config
from model.gem_pool import GeM
from model.arcface import ArcFace
from dataset.dataset import LmkRetrDataset


class MultiAtrous(nn.Module):
    def __init__(self, in_channel, out_channel, size, dilation_rates=[3, 6, 9]):
        super().__init__()
        # Three dilated 3x3 branches plus a global-average-pooling branch, each
        # producing out_channel/4 channels, concatenated along the channel dim.
        self.dilated_convs = [
            nn.Conv2d(in_channel, int(out_channel/4),
                      kernel_size=3, dilation=rate, padding=rate)
            for rate in dilation_rates
        ]
        self.gap_branch = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channel, int(out_channel/4), kernel_size=1),
            nn.ReLU(),
            nn.Upsample(size=(size, size), mode='bilinear')
        )
        self.dilated_convs.append(self.gap_branch)
        self.dilated_convs = nn.ModuleList(self.dilated_convs)

    def forward(self, x):
        local_feat = []
        for dilated_conv in self.dilated_convs:
            local_feat.append(dilated_conv(x))
        local_feat = torch.cat(local_feat, dim=1)
        return local_feat


class DolgLocalBranch(nn.Module):
    def __init__(self, in_channel, out_channel, hidden_channel=2048):
        super().__init__()
        self.multi_atrous = MultiAtrous(in_channel, hidden_channel, size=int(Config.image_size/8))
        self.conv1x1_1 = nn.Conv2d(hidden_channel, out_channel, kernel_size=1)
        self.conv1x1_2 = nn.Conv2d(
            out_channel, out_channel, kernel_size=1, bias=False)
        self.conv1x1_3 = nn.Conv2d(out_channel, out_channel, kernel_size=1)

        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channel)
        self.softplus = nn.Softplus()

    def forward(self, x):
        local_feat = self.multi_atrous(x)

        local_feat = self.conv1x1_1(local_feat)
        local_feat = self.relu(local_feat)
        local_feat = self.conv1x1_2(local_feat)
        local_feat = self.bn(local_feat)

        # Softplus self-attention scores modulate the L2-normalized local map.
        attention_map = self.relu(local_feat)
        attention_map = self.conv1x1_3(attention_map)
        attention_map = self.softplus(attention_map)

        local_feat = F.normalize(local_feat, p=2, dim=1)
        local_feat = local_feat * attention_map

        return local_feat


class OrthogonalFusion(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, local_feat, global_feat):
        # Project the local feature map onto the global descriptor, keep only the
        # component orthogonal to it, and concatenate it with the (broadcast)
        # global descriptor along the channel dimension.
        global_feat_norm = torch.norm(global_feat, p=2, dim=1)
        projection = torch.bmm(global_feat.unsqueeze(1), torch.flatten(
            local_feat, start_dim=2))
        projection = torch.bmm(global_feat.unsqueeze(
            2), projection).view(local_feat.size())
        projection = projection / \
            (global_feat_norm * global_feat_norm).view(-1, 1, 1, 1)
        orthogonal_comp = local_feat - projection
        global_feat = global_feat.unsqueeze(-1).unsqueeze(-1)
        return torch.cat([global_feat.expand(orthogonal_comp.size()), orthogonal_comp], dim=1)


class DolgNet(LightningModule):
    def __init__(self, input_dim, hidden_dim, output_dim, num_of_classes):
        super().__init__()
        self.cnn = timm.create_model(
            'tv_resnet101',
            pretrained=True,
            features_only=True,
            in_chans=input_dim,
            out_indices=(2, 3)
        )
        self.orthogonal_fusion = OrthogonalFusion()
        self.local_branch = DolgLocalBranch(512, hidden_dim)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.gem_pool = GeM()
        self.fc_1 = nn.Linear(1024, hidden_dim)
        self.fc_2 = nn.Linear(int(2*hidden_dim), output_dim)

        self.criterion = ArcFace(
            in_features=output_dim,
            out_features=num_of_classes,
            scale_factor=30,
            margin=0.15,
            criterion=nn.CrossEntropyLoss()
        )
        self.lr = Config.lr

    def forward(self, x):
        output = self.cnn(x)

        local_feat = self.local_branch(output[0])  # (B, hidden_dim, H/8, W/8)
        global_feat = self.fc_1(self.gem_pool(output[1]).squeeze(-1).squeeze(-1))  # (B, hidden_dim)

        feat = self.orthogonal_fusion(local_feat, global_feat)
        feat = self.gap(feat).squeeze(-1).squeeze(-1)
        feat = self.fc_2(feat)

        return feat

    def training_step(self, batch, batch_idx):
        img, label, _ = batch
        embd = self(img)
        loss, logits = self.criterion(embd, label)
        return loss

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=self.lr,
                              momentum=0.9, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=1000)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        dataset = LmkRetrDataset()
        return DataLoader(dataset, batch_size=Config.train_batch_size, num_workers=Config.num_workers,
                          shuffle=True, pin_memory=True, persistent_workers=True)
--------------------------------------------------------------------------------
/model/gem_pool.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6, requires_grad=False):
        super(GeM, self).__init__()
        # Generalized-mean pooling: p=1 is average pooling, larger p approaches
        # max pooling. Set requires_grad=True to learn p during training.
        self.p = nn.Parameter(torch.ones(1)*p, requires_grad=requires_grad)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
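
# --- Editor's usage sketch (illustrative, not part of the original file) -----
# GeM reduces a (B, C, H, W) feature map to (B, C, 1, 1); with p=1 it recovers
# plain global average pooling, and larger p moves the result toward max pooling.
if __name__ == "__main__":
    x = torch.rand(2, 1024, 16, 16)
    print(GeM(p=3)(x).shape)  # torch.Size([2, 1024, 1, 1])
    print(torch.allclose(GeM(p=1)(x), F.adaptive_avg_pool2d(x, 1), atol=1e-5))  # True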
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import Trainer

from model.dolg import DolgNet
from config import Config


seed_everything(Config.seed)

model = DolgNet(
    input_dim=Config.input_dim,
    hidden_dim=Config.hidden_dim,
    output_dim=Config.output_dim,
    num_of_classes=Config.num_of_classes
)

trainer = Trainer(gpus=1, max_epochs=Config.epochs)

trainer.fit(model)
--------------------------------------------------------------------------------
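
Training is launched with `python train.py`; hyper-parameters (batch size, image size, learning rate, number of classes) live in `config.py` rather than on the command line, and Lightning writes checkpoints under `lightning_logs/` by default. A minimal sketch of how a trained checkpoint could be used to extract retrieval descriptors — the checkpoint path below is hypothetical, and the keyword arguments are forwarded to `DolgNet.__init__` because the module does not call `save_hyperparameters()`:

```python
import torch
import torch.nn.functional as F

from config import Config
from model.dolg import DolgNet

# Hypothetical path; substitute the checkpoint produced by your own run.
ckpt = 'lightning_logs/version_0/checkpoints/last.ckpt'

model = DolgNet.load_from_checkpoint(
    ckpt,
    input_dim=Config.input_dim,
    hidden_dim=Config.hidden_dim,
    output_dim=Config.output_dim,
    num_of_classes=Config.num_of_classes,
).eval()

with torch.no_grad():
    images = torch.rand(4, 3, Config.image_size, Config.image_size)  # dummy batch
    descriptors = model(images)                 # (4, Config.output_dim) embeddings
    descriptors = F.normalize(descriptors, dim=1)  # L2-normalize for cosine retrieval
```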