├── .gitattributes ├── LICENSE ├── README.md ├── config ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── defaults.cpython-37.pyc └── defaults.py ├── configs ├── default.yml └── msmt.yml ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── build.cpython-37.pyc │ └── collate_batch.cpython-37.pyc ├── build.py ├── collate_batch.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── bases.cpython-36.pyc │ │ ├── bases.cpython-37.pyc │ │ ├── cuhk03.cpython-36.pyc │ │ ├── cuhk03.cpython-37.pyc │ │ ├── dataset_loader.cpython-36.pyc │ │ ├── dataset_loader.cpython-37.pyc │ │ ├── dukemtmcreid.cpython-36.pyc │ │ ├── dukemtmcreid.cpython-37.pyc │ │ ├── eval_reid.cpython-36.pyc │ │ ├── eval_reid.cpython-37.pyc │ │ ├── market1501.cpython-36.pyc │ │ ├── market1501.cpython-37.pyc │ │ ├── msmt.cpython-36.pyc │ │ └── msmt.cpython-37.pyc │ ├── bases.py │ ├── dataset_loader.py │ ├── dukemtmcreid.py │ ├── eval_reid.py │ ├── market1501.py │ └── msmt.py ├── samplers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── triplet_sampler.cpython-37.pyc │ └── triplet_sampler.py └── transforms │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── autoaugment.cpython-37.pyc │ ├── build.cpython-37.pyc │ └── transforms.cpython-37.pyc │ ├── autoaugment.py │ ├── build.py │ └── transforms.py ├── engine ├── __pycache__ │ └── trainer.cpython-37.pyc ├── inference.py └── trainer.py ├── images ├── github_main_graph.png ├── github_vis.png └── test.txt ├── layers ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── cross_entropy_loss.cpython-37.pyc │ └── triplet_loss.cpython-37.pyc ├── center_loss.py ├── cross_entropy_loss.py └── triplet_loss.py ├── modeling ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── baseline.cpython-37.pyc ├── backbones │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── inception.cpython-37.pyc │ │ ├── resnest.cpython-37.pyc │ │ ├── resnet.cpython-37.pyc │ │ └── se_module.cpython-37.pyc │ ├── densenet.py │ ├── inception.py │ ├── resnest.py │ ├── resnet.py │ ├── se_module.py │ └── se_resnet.py └── baseline.py ├── requirements.txt ├── solver ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── build.cpython-36.pyc │ ├── build.cpython-37.pyc │ ├── lr_scheduler.cpython-36.pyc │ └── lr_scheduler.cpython-37.pyc ├── build.py └── lr_scheduler.py ├── test.sh ├── tests ├── __init__.py └── lr_scheduler_test.py ├── tools ├── __init__.py ├── test.py └── train.py ├── train.sh └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc ├── iotools.cpython-37.pyc ├── logger.cpython-37.pyc ├── re_ranking.cpython-37.pyc └── reid_metric.cpython-37.pyc ├── iotools.py ├── logger.py ├── re_ranking.py └── reid_metric.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Gutianpei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, 
copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Person Re-identification via Attention Pyramid (APNet)
2 | The official PyTorch implementation of the TIP'20 submission "Person Re-identification via Attention Pyramid".
3 |
4 | This repo only contains the channel-wise attention (SE-Layer) implementation. To reproduce the spatial-attention results in our paper, please refer to [RGA-S](https://github.com/microsoft/Relation-Aware-Global-Attention-Networks) by Microsoft and simply change the attention agent. We also want to thank FastReID, which is the codebase of our implementation.
5 |
6 | ## Introduction
7 | Recently, attention mechanisms have been widely used in ReID systems to facilitate high-performance identification, demonstrating powerful representation ability by discovering discriminative regions and mitigating misalignment. However, detecting salient regions with an attention model faces the dilemma of jointly capturing both coarse and fine-grained clues, since the focus varies as the image scale changes. To address this issue, we propose an effective attention pyramid network (APNet) that jointly learns attentions under different scales. Our attention pyramid imitates human visual perception, which tends to notice the foreground person over the cluttered background and then, with a closer observation, focuses on specific details such as the color of the shirt. Please see Figure 1 below and our paper for the details of the method.
8 |
9 | We validate our method on the Market-1501, DukeMTMC-reID and MSMT17 datasets, and it shows superior performance on all of them. Please check the Result section for detailed quantitative and qualitative results.
10 |
11 | ![image](https://github.com/CHENGY12/APNet/blob/main/images/github_main_graph.png)
12 | Figure 1: The architecture of Attention Pyramid Networks (APNet). Our APNet adopts the "split-attend-merge-stack" principle: it first splits the feature maps into multiple parts, obtains the attention map of each part, and constructs the attention map for the current pyramid level by merging the per-part attention maps. Then, at deeper pyramid levels, we split the features into more fine-grained parts and learn fine-grained attentions guided by the coarse attentions. Finally, attentions of different granularities are stacked into an attention pyramid and applied to the original input feature by element-wise product.
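For intuition, the "split-attend-merge-stack" computation can be summarized by the minimal PyTorch sketch below. It is illustrative only: the class names (`ChannelAttention`, `AttentionPyramid`), the horizontal splitting and the SE-style attention agent are assumptions made for this sketch, while the actual model used in this repo is defined in `modeling/baseline.py` with the backbones under `modeling/backbones/`.
```
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    # SE-style channel attention, used here as the attention agent of one part.
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        return self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)


class AttentionPyramid(nn.Module):
    # Level l splits the feature map into 2**l horizontal parts ("split"),
    # attends to each part ("attend"), concatenates the per-part attention
    # maps ("merge"), and multiplies them onto the feature, level after
    # level ("stack").
    def __init__(self, channels, levels=2, reduction=16):
        super().__init__()
        self.agents = nn.ModuleList(
            nn.ModuleList(ChannelAttention(channels, reduction) for _ in range(2 ** l))
            for l in range(levels)
        )

    def forward(self, x):
        out = x
        for level, agents in enumerate(self.agents):
            parts = torch.chunk(out, 2 ** level, dim=2)  # split along height
            att = torch.cat(
                [agent(p).expand_as(p) for agent, p in zip(agents, parts)], dim=2
            )
            out = out * att  # apply this level's merged attention
        return out


feat = torch.randn(2, 256, 24, 8)                    # e.g. one ResNet stage output
print(AttentionPyramid(256, levels=2)(feat).shape)   # torch.Size([2, 256, 24, 8])
```
In this sketch, coarse levels constrain the finer ones, because each deeper level attends to features that were already re-weighted by the coarser attention.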
13 |
14 |
15 | ## Requirements
16 | - Python 3.6+
17 | - PyTorch 1.5+
18 | - CUDA 10.0+
19 |
20 | Configurations other than the above are untested, and we recommend following our setting.
21 |
22 | To install all the dependencies, please follow the instructions below.
23 | ```
24 | conda create -n apnet python=3.7 -y
25 | conda activate apnet
26 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
27 | conda install ignite -c pytorch
28 | git clone https://github.com/CHENGY12/APNet.git
29 | pip install -r requirements.txt
30 | ```
31 |
32 | To download the pretrained ResNet-50 model, please run the following commands in your Python console:
33 | ```
34 | from torchvision.models import resnet50
35 | resnet50(pretrained=True)
36 | ```
37 | The model should then be located at RESNET_PATH=```/home/YOURNAME/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth``` or ```/home/YOURNAME/.cache/torch/checkpoints/resnet50-19c8e357.pth```.
38 |
39 | ### Downloading
40 | - Market-1501
41 | - DukeMTMC-reID
42 | - MSMT17
43 | ### Preparation
44 | After downloading the datasets above, move them to the `Datasets/` folder in the project root directory and rename the dataset folders to `market1501`, `duke` and `msmt17` respectively. The MSMT17 folder keeps its original `train`/`test` split and list files, as expected by `data/datasets/msmt.py`. The `Datasets/` folder should be organized as:
45 | ```
46 | |-- market1501
47 |     |-- bounding_box_train
48 |     |-- bounding_box_test
49 |     |-- ...
50 | |-- duke
51 |     |-- bounding_box_train
52 |     |-- bounding_box_test
53 |     |-- ...
54 | |-- msmt17
55 |     |-- train
56 |     |-- test
57 |     |-- list_train.txt, list_val.txt, list_query.txt, list_gallery.txt
58 | ```
59 |
60 | ## Usage
61 | ### Training
62 | Change the PRETRAIN_PATH parameter in configs/default.yml to your RESNET_PATH.
63 | To train with a different pyramid level, edit the LEVEL parameter in configs/default.yml, then run:
64 | ```
65 | sh train.sh
66 | ```
67 | ### Evaluation
68 | ```
69 | sh test.sh
70 | ```
71 |
72 | ## Result
73 | | Dataset | Top-1 (%) | mAP (%) |
74 | | :------------: | :---: | :---: |
75 | | Market-1501 | 96.2 | 90.5 |
76 | | DukeMTMC-reID | 90.4 | 81.5 |
77 | | MSMT17 | 83.7 | 63.5 |
78 |
79 | ![image](https://github.com/CHENGY12/APNet/blob/main/images/github_vis.png)
80 | Figure 2: Visualizations of the attention maps at different pyramid levels. We adopt Grad-CAM to visualize the learned attention maps of our attention pyramid. For each sample, from left to right, we show the input image, the attention of the first pyramid level, and the attention of the second pyramid level. We can observe that attentions at different pyramid levels capture salient clues of different scales.
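Before launching `sh train.sh`, you can optionally sanity-check the `Datasets/` layout with a short script along the following lines. The helper itself is hypothetical (it is not shipped with this repo); the expected entries are taken from the loaders in `data/datasets/market1501.py`, `dukemtmcreid.py` and `msmt.py`.
```
import os

# Expected entries per dataset folder, mirroring data/datasets/*.py
expected = {
    "market1501": ["bounding_box_train", "query", "bounding_box_test"],
    "duke": ["bounding_box_train", "query", "bounding_box_test"],
    "msmt17": ["train", "test", "list_train.txt", "list_val.txt",
               "list_query.txt", "list_gallery.txt"],
}

for name, entries in expected.items():
    root = os.path.join("Datasets", name)
    missing = [e for e in entries if not os.path.exists(os.path.join(root, e))]
    print(f"{name}: {'ok' if not missing else 'missing ' + ', '.join(missing)}")
```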
81 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg 2 | -------------------------------------------------------------------------------- /config/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/config/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /config/__pycache__/defaults.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/config/__pycache__/defaults.cpython-37.pyc -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Convention about Training / Test specific parameters 5 | # ----------------------------------------------------------------------------- 6 | # Whenever an argument can be either used for training or for testing, the 7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter, 8 | # or _TEST for a test-specific parameter. 9 | # For example, the number of images during training will be 10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be 11 | # IMAGES_PER_BATCH_TEST 12 | 13 | # ----------------------------------------------------------------------------- 14 | # Config definition 15 | # ----------------------------------------------------------------------------- 16 | 17 | _C = CN() 18 | 19 | _C.MODEL = CN() 20 | _C.MODEL.DEVICE = "cuda:0" 21 | _C.MODEL.NAME = 'seresnet50' # 'resnet50' 'seresnet50' 'densenet196' 22 | _C.MODEL.LAST_STRIDE = 1 23 | _C.MODEL.PRETRAIN_PATH = '' 24 | _C.MODEL.PRETRAIN_PATH_SE = '' 25 | _C.MODEL.PRETRAIN_PATH_DENSE = '' 26 | 27 | # ----------------------------------------------------------------------------- 28 | # APNet 29 | # ----------------------------------------------------------------------------- 30 | _C.APNET = CN() 31 | _C.APNET.LEVEL = 0 32 | _C.APNET.MSMT = False 33 | # ----------------------------------------------------------------------------- 34 | # INPUT 35 | # ----------------------------------------------------------------------------- 36 | _C.INPUT = CN() 37 | # Size of the image during training 38 | _C.INPUT.SIZE_TRAIN = [384, 128] 39 | # Size of the image during test 40 | _C.INPUT.SIZE_TEST = [384, 128] 41 | # Random probability for image horizontal flip 42 | _C.INPUT.PROB = 0.5 43 | # Values to be used for image normalization 44 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 45 | # Values to be used for image normalization 46 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 47 | # Value of padding size 48 | _C.INPUT.PADDING = 10 49 | 50 | # ----------------------------------------------------------------------------- 51 | # Dataset 52 | # ----------------------------------------------------------------------------- 53 | _C.DATASETS = CN() 54 | # List of the dataset names for training, as present in paths_catalog.py 55 | _C.DATASETS.NAMES = ('market1501') 56 | 57 | # 
----------------------------------------------------------------------------- 58 | # DataLoader 59 | # ----------------------------------------------------------------------------- 60 | _C.DATALOADER = CN() 61 | # Number of data loading threads 62 | _C.DATALOADER.NUM_WORKERS = 8 63 | # Sampler for data loading 64 | _C.DATALOADER.SAMPLER = 'softmax' 65 | # Number of instance for one batch 66 | _C.DATALOADER.NUM_INSTANCE = 16 67 | 68 | # ---------------------------------------------------------------------------- # 69 | # Solver 70 | # ---------------------------------------------------------------------------- # 71 | _C.SOLVER = CN() 72 | _C.SOLVER.OPTIMIZER_NAME = "Adam" 73 | 74 | _C.SOLVER.MAX_EPOCHS = 50 75 | 76 | _C.SOLVER.BASE_LR = 3e-4 77 | _C.SOLVER.BIAS_LR_FACTOR = 2 78 | 79 | _C.SOLVER.MOMENTUM = 0.9 80 | 81 | _C.SOLVER.MARGIN = 0.3 82 | 83 | _C.SOLVER.SMOOTH = 0.1 84 | _C.SOLVER.CLASSNUM = 751 85 | 86 | # Learning rate of SGD to learn the centers of center loss 87 | _C.SOLVER.CENTER_LR = 0.5 88 | # Balanced weight of center loss 89 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005 90 | 91 | _C.SOLVER.WEIGHT_DECAY = 0.0005 92 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0. 93 | 94 | _C.SOLVER.GAMMA = 0.1 95 | _C.SOLVER.STEPS = (30, 55) 96 | 97 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3 98 | _C.SOLVER.WARMUP_ITERS = 500 99 | _C.SOLVER.WARMUP_METHOD = "linear" 100 | 101 | _C.SOLVER.CHECKPOINT_PERIOD = 50 102 | _C.SOLVER.LOG_PERIOD = 100 103 | _C.SOLVER.EVAL_PERIOD = 50 104 | _C.SOLVER.FINETUNE = False 105 | 106 | # Number of images per batch 107 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will 108 | # see 2 images per batch 109 | _C.SOLVER.IMS_PER_BATCH = 64 110 | 111 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will 112 | # see 2 images per batch 113 | _C.TEST = CN() 114 | _C.TEST.IMS_PER_BATCH = 128 115 | _C.TEST.RE_RANK = False 116 | _C.TEST.WEIGHT = "" 117 | 118 | # ---------------------------------------------------------------------------- # 119 | # Misc options 120 | # ---------------------------------------------------------------------------- # 121 | _C.OUTPUT_DIR = "" 122 | -------------------------------------------------------------------------------- /configs/default.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: 'resnet50' 3 | PRETRAIN_PATH: 'RESNET_PATH' 4 | #ResNet50 Pretrained Model Path, eg "/home/gutianpei/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth" 5 | 6 | APNET: 7 | LEVEL: 2 8 | 9 | INPUT: 10 | SIZE_TRAIN: [384, 192] 11 | SIZE_TEST: [384, 192] 12 | PROB: 0.5 # random horizontal flip 13 | PADDING: 10 14 | 15 | DATASETS: 16 | NAMES: ('market1501') #select from "dukemtmc", "market1501" and "msmt17" 17 | 18 | DATALOADER: 19 | SAMPLER: 'softmax_triplet' 20 | NUM_INSTANCE: 4 21 | NUM_WORKERS: 8 22 | 23 | SOLVER: 24 | OPTIMIZER_NAME: 'Adam' 25 | MAX_EPOCHS: 160 26 | BASE_LR: 0.0004 27 | BIAS_LR_FACTOR: 1 28 | WEIGHT_DECAY: 0.001 29 | WEIGHT_DECAY_BIAS: 0.001 30 | SMOOTH: 0.1 31 | IMS_PER_BATCH: 80 32 | 33 | STEPS: [40, 80, 120, 160] 34 | GAMMA: 0.1 35 | 36 | WARMUP_FACTOR: 0.01 37 | WARMUP_ITERS: 10 38 | WARMUP_METHOD: 'linear' 39 | # CLASSNUM: 1019 40 | 41 | CHECKPOINT_PERIOD: 10 42 | LOG_PERIOD: 100 43 | EVAL_PERIOD: 40 44 | 45 | 46 | TEST: 47 | IMS_PER_BATCH: 512 48 | RE_RANK: False 49 | WEIGHT: "path" 50 | 51 | OUTPUT_DIR: "/home/gutianpei/ivg/github_temp/att/Market/" 52 | # /home/gtp_cgy/ivg/dataset/Occluded-DukeMTMC-Dataset/Occluded_Duke 53 | 
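The YAML files under `configs/` only override fields that are already declared in `config/defaults.py`, and `config/__init__.py` exposes those defaults as `cfg`. As a rough illustration of how a yacs config like the one above is consumed (a sketch only — the actual merging is presumably done by the training entry point `tools/train.py`):
```
from config import cfg  # the _C node defined in config/defaults.py

cfg.merge_from_file("configs/default.yml")   # YAML values override the defaults
cfg.merge_from_list(["APNET.LEVEL", 3])      # optional key/value overrides
cfg.freeze()                                 # make the merged config read-only

print(cfg.MODEL.NAME, cfg.APNET.LEVEL, cfg.SOLVER.BASE_LR)
```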
-------------------------------------------------------------------------------- /configs/msmt.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: 'resnet50' 3 | PRETRAIN_PATH: '/home/gutianpei/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth' 4 | #ResNet50 Pretrained Model Path, eg "/home/gutianpei/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth" 5 | 6 | APNET: 7 | LEVEL: 2 8 | MSMT: True 9 | 10 | INPUT: 11 | SIZE_TRAIN: [384, 192] 12 | SIZE_TEST: [384, 192] 13 | PROB: 0.5 # random horizontal flip 14 | PADDING: 10 15 | 16 | DATASETS: 17 | NAMES: ('msmt17') #select from "dukemtmc", "market1501" and "msmt17" 18 | 19 | DATALOADER: 20 | SAMPLER: 'softmax_triplet' 21 | NUM_INSTANCE: 4 22 | NUM_WORKERS: 8 23 | 24 | SOLVER: 25 | OPTIMIZER_NAME: 'Adam' 26 | MAX_EPOCHS: 160 27 | BASE_LR: 0.0004 28 | BIAS_LR_FACTOR: 1 29 | WEIGHT_DECAY: 0.001 30 | WEIGHT_DECAY_BIAS: 0.001 31 | SMOOTH: 0.1 32 | IMS_PER_BATCH: 80 33 | 34 | STEPS: [40, 80, 120, 160] 35 | GAMMA: 0.1 36 | 37 | WARMUP_FACTOR: 0.01 38 | WARMUP_ITERS: 10 39 | WARMUP_METHOD: 'linear' 40 | # CLASSNUM: 1019 41 | 42 | CHECKPOINT_PERIOD: 10 43 | LOG_PERIOD: 100 44 | EVAL_PERIOD: 40 45 | 46 | 47 | TEST: 48 | IMS_PER_BATCH: 512 49 | RE_RANK: False 50 | WEIGHT: "path" 51 | 52 | OUTPUT_DIR: "/home/gutianpei/ivg/github_temp/att/Market/" 53 | # /home/gtp_cgy/ivg/dataset/Occluded-DukeMTMC-Dataset/Occluded_Duke 54 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import make_data_loader 2 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/collate_batch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/__pycache__/collate_batch.cpython-37.pyc -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from .collate_batch import train_collate_fn, val_collate_fn 4 | from .datasets import init_dataset, ImageDataset 5 | from .samplers import RandomIdentitySampler 6 | from .transforms import build_transforms 7 | import pdb 8 | import numpy as np 9 | import torch 10 | import matplotlib.pyplot as plt 11 | 12 | def make_data_loader(cfg): 13 | train_transforms = build_transforms(cfg, is_train=True) 14 | val_transforms = build_transforms(cfg, is_train=False) 15 | num_workers = cfg.DATALOADER.NUM_WORKERS 16 | if len(cfg.DATASETS.NAMES) == 1: 17 | dataset = init_dataset(cfg.DATASETS.NAMES) 18 | else: 19 | # TODO: add multi dataset to train 20 | dataset 
= init_dataset(cfg.DATASETS.NAMES) 21 | 22 | num_classes = dataset.num_train_pids 23 | train_set = ImageDataset(dataset.train, 'train', train_transforms) 24 | 25 | if cfg.DATALOADER.SAMPLER == 'softmax': 26 | train_loader = DataLoader( 27 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 28 | collate_fn=train_collate_fn 29 | ) 30 | else: 31 | train_loader = DataLoader( 32 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 33 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 34 | num_workers=num_workers, collate_fn=train_collate_fn 35 | ) 36 | 37 | val_set = ImageDataset(dataset.query + dataset.gallery, 'test', val_transforms) 38 | 39 | val_loader = DataLoader( 40 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 41 | collate_fn=val_collate_fn 42 | ) 43 | 44 | return train_loader, val_loader, len(dataset.query), num_classes 45 | -------------------------------------------------------------------------------- /data/collate_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def train_collate_fn(batch): 4 | imgs, pids, _, _, = zip(*batch) 5 | pids = torch.tensor(pids, dtype=torch.int64) 6 | return torch.stack(imgs, dim=0), pids 7 | 8 | 9 | def val_collate_fn(batch): 10 | imgs, pids, camids, _ = zip(*batch) 11 | return torch.stack(imgs, dim=0), pids, camids 12 | -------------------------------------------------------------------------------- /data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dukemtmcreid import DukeMTMCreID 3 | from .market1501 import Market1501 4 | from .msmt import MSMT17_V1 5 | from .dataset_loader import ImageDataset 6 | 7 | __factory = { 8 | 'market1501': Market1501, 9 | 'msmt17': MSMT17_V1, 10 | 'dukemtmc': DukeMTMCreID 11 | } 12 | 13 | 14 | def get_names(): 15 | return __factory.keys() 16 | 17 | 18 | def init_dataset(name, *args, **kwargs): 19 | if name not in __factory.keys(): 20 | raise KeyError("Unknown datasets: {}".format(name)) 21 | return __factory[name](*args, **kwargs) 22 | -------------------------------------------------------------------------------- /data/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/bases.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/bases.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/bases.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/bases.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/cuhk03.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/cuhk03.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/cuhk03.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/cuhk03.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dataset_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/dataset_loader.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dataset_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/dataset_loader.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dukemtmcreid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/dukemtmcreid.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dukemtmcreid.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/dukemtmcreid.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/eval_reid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/eval_reid.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/eval_reid.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/eval_reid.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/market1501.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/market1501.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/market1501.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/market1501.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/msmt.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/msmt.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/msmt.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/datasets/__pycache__/msmt.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/bases.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(object): 6 | """ 7 | Base class of reid dataset 8 | """ 9 | 10 | def get_imagedata_info(self, data): 11 | pids, imgids, cams = [], [], [] 12 | for _, pid, camid in data: 13 | pids += [pid] 14 | cams += [camid] 15 | pids = set(pids) 16 | cams = set(cams) 17 | num_pids = len(pids) 18 | num_cams = len(cams) 19 | num_imgs = len(data) 20 | return num_pids, num_imgs, num_cams 21 | 22 | def get_videodata_info(self, data, return_tracklet_stats=False): 23 | pids, cams, tracklet_stats = [], [], [] 24 | for img_paths, pid, camid in data: 25 | pids += [pid] 26 | cams += [camid] 27 | tracklet_stats += [len(img_paths)] 28 | pids = set(pids) 29 | cams = set(cams) 30 | num_pids = len(pids) 31 | num_cams = len(cams) 32 | num_tracklets = len(data) 33 | if return_tracklet_stats: 34 | return num_pids, num_tracklets, num_cams, tracklet_stats 35 | return num_pids, num_tracklets, num_cams 36 | 37 | def print_dataset_statistics(self): 38 | raise NotImplementedError 39 | 40 | 41 | class BaseImageDataset(BaseDataset): 42 | """ 43 | Base class of image reid dataset 44 | """ 45 | 46 | def print_dataset_statistics(self, train, query, gallery): 47 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 48 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 49 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 50 | 51 | print("Dataset statistics:") 52 | print(" ----------------------------------------") 53 | print(" subset | # ids | # images | # cameras") 54 | print(" ----------------------------------------") 55 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 56 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 57 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 58 | print(" ----------------------------------------") 59 | 60 | 61 | class BaseVideoDataset(BaseDataset): 62 | """ 63 | Base class of video reid dataset 64 | """ 65 | 66 | def print_dataset_statistics(self, train, query, gallery): 67 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ 68 | self.get_videodata_info(train, return_tracklet_stats=True) 69 | 70 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ 71 | 
self.get_videodata_info(query, return_tracklet_stats=True) 72 | 73 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ 74 | self.get_videodata_info(gallery, return_tracklet_stats=True) 75 | 76 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats 77 | min_num = np.min(tracklet_stats) 78 | max_num = np.max(tracklet_stats) 79 | avg_num = np.mean(tracklet_stats) 80 | 81 | print("Dataset statistics:") 82 | print(" -------------------------------------------") 83 | print(" subset | # ids | # tracklets | # cameras") 84 | print(" -------------------------------------------") 85 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) 86 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) 87 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) 88 | print(" -------------------------------------------") 89 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) 90 | print(" -------------------------------------------") 91 | -------------------------------------------------------------------------------- /data/datasets/dataset_loader.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os.path as osp 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | import mat4py 9 | import pdb 10 | 11 | 12 | 13 | def read_image(img_path): 14 | """Keep reading image until succeed. 15 | This can avoid IOError incurred by heavy IO process.""" 16 | got_img = False 17 | if not osp.exists(img_path): 18 | raise IOError("{} does not exist".format(img_path)) 19 | while not got_img: 20 | try: 21 | img = Image.open(img_path).convert('RGB') 22 | got_img = True 23 | except IOError: 24 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 25 | pass 26 | return img 27 | 28 | class ImageDataset(Dataset): 29 | """Image Person ReID Dataset""" 30 | 31 | def __init__(self, dataset, mode, transform=None): 32 | self.dataset = dataset 33 | self.transform = transform 34 | self.mode = mode 35 | 36 | def __len__(self): 37 | return len(self.dataset) 38 | 39 | def __getitem__(self, index): 40 | img_path, pid, camid = self.dataset[index] 41 | img = read_image(img_path) 42 | 43 | if self.transform is not None: 44 | img = self.transform(img) 45 | 46 | return img, pid, camid, img_path 47 | -------------------------------------------------------------------------------- /data/datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | import os.path as osp 9 | 10 | from utils.iotools import mkdir_if_missing 11 | from .bases import BaseImageDataset 12 | 13 | 14 | class DukeMTMCreID(BaseImageDataset): 15 | """ 16 | DukeMTMC-reID 17 | Reference: 18 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 
20 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 21 | 22 | Dataset statistics: 23 | # identities: 1404 (train + query) 24 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 25 | # cameras: 8 26 | """ 27 | dataset_dir = '' 28 | 29 | # def __init__(self, root='/home/gtp_cgy/ivg/dataset/Occluded-DukeMTMC-Dataset/Occluded_Duke', verbose=True, **kwargs): 30 | def __init__(self, root='./Datasets/', verbose=True, **kwargs): 31 | super(DukeMTMCreID, self).__init__() 32 | self.dataset_dir = osp.join(root, self.dataset_dir) 33 | self.train_dir = osp.join(self.dataset_dir, 'duke/bounding_box_train') 34 | self.query_dir = osp.join(self.dataset_dir, 'duke/query') 35 | self.gallery_dir = osp.join(self.dataset_dir, 'duke/bounding_box_test') 36 | # self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 37 | # self.query_dir = osp.join(self.dataset_dir, 'query') 38 | # self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 39 | 40 | self._check_before_run() 41 | 42 | train = self._process_dir(self.train_dir, relabel=True) 43 | query = self._process_dir(self.query_dir, relabel=False) 44 | gallery = self._process_dir(self.gallery_dir, relabel=False) 45 | 46 | if verbose: 47 | print("=> DukeMTMC-reID loaded") 48 | self.print_dataset_statistics(train, query, gallery) 49 | 50 | self.train = train 51 | self.query = query 52 | self.gallery = gallery 53 | 54 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 55 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 56 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 57 | 58 | def _check_before_run(self): 59 | """Check if all files are available before going deeper""" 60 | if not osp.exists(self.dataset_dir): 61 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 62 | if not osp.exists(self.train_dir): 63 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 64 | if not osp.exists(self.query_dir): 65 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 66 | if not osp.exists(self.gallery_dir): 67 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 68 | 69 | def _process_dir(self, dir_path, relabel=False): 70 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 71 | pattern = re.compile(r'([-\d]+)_c(\d)') 72 | 73 | pid_container = set() 74 | for img_path in img_paths: 75 | pid, _ = map(int, pattern.search(img_path).groups()) 76 | pid_container.add(pid) 77 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 78 | 79 | dataset = [] 80 | for img_path in img_paths: 81 | pid, camid = map(int, pattern.search(img_path).groups()) 82 | assert 1 <= camid <= 8 83 | camid -= 1 # index starts from 0 84 | if relabel: pid = pid2label[pid] 85 | dataset.append((img_path, pid, camid)) 86 | 87 | return dataset 88 | -------------------------------------------------------------------------------- /data/datasets/eval_reid.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | 5 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 6 | """Evaluation with market1501 metric 7 | Key: for each query identity, its gallery images from the same camera view are discarded. 
8 | """ 9 | num_q, num_g = distmat.shape 10 | if num_g < max_rank: 11 | max_rank = num_g 12 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 13 | indices = np.argsort(distmat, axis=1) 14 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 15 | 16 | # compute cmc curve for each query 17 | all_cmc = [] 18 | all_AP = [] 19 | num_valid_q = 0. # number of valid query 20 | for q_idx in range(num_q): 21 | # get query pid and camid 22 | q_pid = q_pids[q_idx] 23 | q_camid = q_camids[q_idx] 24 | 25 | # remove gallery samples that have the same pid and camid with query 26 | order = indices[q_idx] 27 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 28 | keep = np.invert(remove) 29 | 30 | # compute cmc curve 31 | # binary vector, positions with value 1 are correct matches 32 | orig_cmc = matches[q_idx][keep] 33 | if not np.any(orig_cmc): 34 | # this condition is true when query identity does not appear in gallery 35 | continue 36 | 37 | cmc = orig_cmc.cumsum() 38 | cmc[cmc > 1] = 1 39 | 40 | all_cmc.append(cmc[:max_rank]) 41 | num_valid_q += 1. 42 | 43 | # compute average precision 44 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 45 | num_rel = orig_cmc.sum() 46 | tmp_cmc = orig_cmc.cumsum() 47 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 48 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 49 | AP = tmp_cmc.sum() / num_rel 50 | all_AP.append(AP) 51 | 52 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 53 | 54 | all_cmc = np.asarray(all_cmc).astype(np.float32) 55 | all_cmc = all_cmc.sum(0) / num_valid_q 56 | mAP = np.mean(all_AP) 57 | 58 | return all_cmc, mAP 59 | -------------------------------------------------------------------------------- /data/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import re 4 | 5 | import os.path as osp 6 | import pdb 7 | 8 | 9 | 10 | from .bases import BaseImageDataset 11 | 12 | 13 | class Market1501(BaseImageDataset): 14 | """ 15 | Market1501 16 | Reference: 17 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
18 | URL: http://www.liangzheng.org/Project/project_reid.html 19 | 20 | Dataset statistics: 21 | # identities: 1501 (+1 for background) 22 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 23 | """ 24 | dataset_dir = 'market1501' 25 | 26 | def __init__(self, root='./Datasets/', verbose=True, **kwargs): 27 | super(Market1501, self).__init__() 28 | self.dataset_dir = osp.join(root, self.dataset_dir) 29 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 30 | self.query_dir = osp.join(self.dataset_dir, 'query') 31 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 32 | 33 | self._check_before_run() 34 | 35 | train = self._process_dir(self.train_dir, relabel=True) 36 | query = self._process_dir(self.query_dir, relabel=False) 37 | gallery = self._process_dir(self.gallery_dir, relabel=False) 38 | 39 | if verbose: 40 | print("=> Market1501 loaded") 41 | self.print_dataset_statistics(train, query, gallery) 42 | 43 | self.train = train 44 | self.query = query 45 | self.gallery = gallery 46 | 47 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 48 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 49 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 50 | 51 | def _check_before_run(self): 52 | """Check if all files are available before going deeper""" 53 | if not osp.exists(self.dataset_dir): 54 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 55 | if not osp.exists(self.train_dir): 56 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 57 | if not osp.exists(self.query_dir): 58 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 59 | if not osp.exists(self.gallery_dir): 60 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 61 | 62 | def _process_dir(self, dir_path, relabel=False): 63 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 64 | pattern = re.compile(r'([-\d]+)_c(\d)') 65 | 66 | pid_container = set() 67 | for img_path in img_paths: 68 | pid, _ = map(int, pattern.search(img_path).groups()) 69 | if pid == -1: continue # junk images are just ignored 70 | pid_container.add(pid) 71 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 72 | 73 | dataset = [] 74 | 75 | for img_path in img_paths: 76 | pid, camid = map(int, pattern.search(img_path).groups()) 77 | if pid == -1: continue # junk images are just ignored 78 | assert 0 <= pid <= 1501 # pid == 0 means background 79 | assert 1 <= camid <= 6 80 | camid -= 1 # index starts from 0 81 | if relabel: pid = pid2label[pid] 82 | dataset.append((img_path, pid, camid)) 83 | 84 | return dataset 85 | -------------------------------------------------------------------------------- /data/datasets/msmt.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | import pdb 8 | import os.path as osp 9 | 10 | from utils.iotools import mkdir_if_missing 11 | from .bases import BaseImageDataset 12 | 13 | 14 | class MSMT17_V1(BaseImageDataset): 15 | """ 16 | DukeMTMC-reID 17 | Reference: 18 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 
20 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 21 | 22 | Dataset statistics: 23 | # identities: 1404 (train + query) 24 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 25 | # cameras: 8 26 | """ 27 | dataset_dir = 'msmt17' 28 | 29 | def __init__(self, root='./Datasets/', verbose=True, **kwargs): 30 | super(MSMT17_V1, self).__init__() 31 | self.dataset_dir = osp.join(root, self.dataset_dir) 32 | 33 | self.train_dir = osp.join(self.dataset_dir, 'train') 34 | self.gallery_dir = osp.join(self.dataset_dir, 'test') 35 | self.list_train_path = osp.join(self.dataset_dir, 'list_train.txt') 36 | self.list_val_path = osp.join(self.dataset_dir, 'list_val.txt') 37 | self.list_query_path = osp.join(self.dataset_dir, 'list_query.txt') 38 | self.list_gallery_path = osp.join(self.dataset_dir, 'list_gallery.txt') 39 | 40 | self._check_before_run() 41 | 42 | train = self._process_dir(self.train_dir, self.list_train_path, relabel=False) 43 | #train = self._process_dir(self.train_dir, self.list_train_path, relabel=True) 44 | query = self._process_dir(self.gallery_dir, self.list_query_path, relabel=False) 45 | gallery = self._process_dir(self.gallery_dir, self.list_gallery_path, relabel=False) 46 | #val = self._process_dir(self.list_val_path, relabel=False) 47 | 48 | if verbose: 49 | print("=> MSMT17_V1 loaded") 50 | self.print_dataset_statistics(train, query, gallery) 51 | 52 | self.train = train 53 | self.query = query 54 | #self.val = val 55 | self.gallery = gallery 56 | 57 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 58 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 59 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 60 | 61 | 62 | def _check_before_run(self): 63 | """Check if all files are available before going deeper""" 64 | if not osp.exists(self.dataset_dir): 65 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 66 | if not osp.exists(self.train_dir): 67 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 68 | if not osp.exists(self.list_query_path): 69 | raise RuntimeError("'{}' is not available".format(self.list_query_path)) 70 | if not osp.exists(self.list_gallery_path): 71 | raise RuntimeError("'{}' is not available".format(self.list_gallery_path)) 72 | 73 | def _process_dir(self, dir_path, txt_path, relabel=False): 74 | with open(txt_path, 'r') as f: 75 | img_paths = f.read() 76 | img_paths = img_paths.split() 77 | pid_container = set() 78 | 79 | for pid in img_paths[1::2]: 80 | pid = int(pid) 81 | 82 | pid_container.add(pid) 83 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 84 | dataset = [] 85 | for img_path in img_paths[::2]: 86 | pid = int(img_path.split('/')[0]) 87 | #print(pid) 88 | camid = int(img_path.split('_')[2]) 89 | assert 1 <= camid <= 15 90 | camid -= 1 # index starts from 0 91 | if relabel: pid = pid2label[pid] 92 | img_path = osp.join(dir_path,img_path) 93 | dataset.append((img_path, pid, camid)) 94 | 95 | 96 | return dataset 97 | -------------------------------------------------------------------------------- /data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .triplet_sampler import RandomIdentitySampler 4 | -------------------------------------------------------------------------------- /data/samplers/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/samplers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/samplers/__pycache__/triplet_sampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/samplers/__pycache__/triplet_sampler.cpython-37.pyc -------------------------------------------------------------------------------- /data/samplers/triplet_sampler.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import copy 4 | import random 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | class RandomIdentitySampler(Sampler): 12 | """ 13 | Randomly sample N identities, then for each identity, 14 | randomly sample K instances, therefore batch size is N*K. 15 | Args: 16 | - data_source (list): list of (img_path, pid, camid). 17 | - num_instances (int): number of instances per identity in a batch. 18 | - batch_size (int): number of examples in a batch. 19 | """ 20 | 21 | def __init__(self, data_source, batch_size, num_instances): 22 | self.data_source = data_source 23 | self.batch_size = batch_size 24 | self.num_instances = num_instances 25 | self.num_pids_per_batch = self.batch_size // self.num_instances 26 | self.index_dic = defaultdict(list) 27 | for index, (_, pid, _) in enumerate(self.data_source): 28 | self.index_dic[pid].append(index) 29 | self.pids = list(self.index_dic.keys()) 30 | 31 | # estimate number of examples in an epoch 32 | self.length = 0 33 | for pid in self.pids: 34 | idxs = self.index_dic[pid] 35 | num = len(idxs) 36 | if num < self.num_instances: 37 | num = self.num_instances 38 | self.length += num - num % self.num_instances 39 | 40 | def __iter__(self): 41 | batch_idxs_dict = defaultdict(list) 42 | 43 | for pid in self.pids: 44 | idxs = copy.deepcopy(self.index_dic[pid]) 45 | if len(idxs) < self.num_instances: 46 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 47 | random.shuffle(idxs) 48 | batch_idxs = [] 49 | for idx in idxs: 50 | batch_idxs.append(idx) 51 | if len(batch_idxs) == self.num_instances: 52 | batch_idxs_dict[pid].append(batch_idxs) 53 | batch_idxs = [] 54 | 55 | avai_pids = copy.deepcopy(self.pids) 56 | final_idxs = [] 57 | 58 | while len(avai_pids) >= self.num_pids_per_batch: 59 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 60 | for pid in selected_pids: 61 | batch_idxs = batch_idxs_dict[pid].pop(0) 62 | final_idxs.extend(batch_idxs) 63 | if len(batch_idxs_dict[pid]) == 0: 64 | avai_pids.remove(pid) 65 | 66 | return iter(final_idxs) 67 | 68 | def __len__(self): 69 | return self.length 70 | -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_transforms 2 | -------------------------------------------------------------------------------- /data/transforms/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/transforms/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/autoaugment.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/transforms/__pycache__/autoaugment.cpython-37.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/transforms/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/data/transforms/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /data/transforms/autoaugment.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageEnhance, ImageOps 2 | import numpy as np 3 | import random 4 | 5 | 6 | class ImageNetPolicy(object): 7 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 8 | 9 | Example: 10 | >>> policy = ImageNetPolicy() 11 | >>> transformed = policy(image) 12 | 13 | Example as a PyTorch Transform: 14 | >>> transform=transforms.Compose([ 15 | >>> transforms.Resize(256), 16 | >>> ImageNetPolicy(), 17 | >>> transforms.ToTensor()]) 18 | """ 19 | def __init__(self, fillcolor=(128, 128, 128)): 20 | self.policies = [ 21 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 22 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 23 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 24 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 25 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 26 | 27 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 28 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 29 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 30 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 31 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 32 | 33 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 34 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 35 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 36 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 37 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 38 | 39 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 40 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 41 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 42 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 43 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 44 | 45 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 46 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 47 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 48 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 
8, fillcolor), 49 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 50 | ] 51 | 52 | 53 | def __call__(self, img): 54 | policy_idx = random.randint(0, len(self.policies) - 1) 55 | return self.policies[policy_idx](img) 56 | 57 | def __repr__(self): 58 | return "AutoAugment ImageNet Policy" 59 | 60 | 61 | class CIFAR10Policy(object): 62 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 63 | 64 | Example: 65 | >>> policy = CIFAR10Policy() 66 | >>> transformed = policy(image) 67 | 68 | Example as a PyTorch Transform: 69 | >>> transform=transforms.Compose([ 70 | >>> transforms.Resize(256), 71 | >>> CIFAR10Policy(), 72 | >>> transforms.ToTensor()]) 73 | """ 74 | def __init__(self, fillcolor=(128, 128, 128)): 75 | self.policies = [ 76 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 77 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 78 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 79 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 80 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 81 | 82 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 83 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 84 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 85 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 86 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 87 | 88 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 89 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 90 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 91 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 92 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 93 | 94 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 95 | SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor), 96 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 97 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 98 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 99 | 100 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 101 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 102 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 103 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 104 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 105 | ] 106 | 107 | 108 | def __call__(self, img): 109 | policy_idx = random.randint(0, len(self.policies) - 1) 110 | return self.policies[policy_idx](img) 111 | 112 | def __repr__(self): 113 | return "AutoAugment CIFAR10 Policy" 114 | 115 | 116 | class SVHNPolicy(object): 117 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
118 | 119 | Example: 120 | >>> policy = SVHNPolicy() 121 | >>> transformed = policy(image) 122 | 123 | Example as a PyTorch Transform: 124 | >>> transform=transforms.Compose([ 125 | >>> transforms.Resize(256), 126 | >>> SVHNPolicy(), 127 | >>> transforms.ToTensor()]) 128 | """ 129 | def __init__(self, fillcolor=(128, 128, 128)): 130 | self.policies = [ 131 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 132 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 133 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 134 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 135 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 136 | 137 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 138 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 139 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 140 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 141 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 142 | 143 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 144 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 145 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 146 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 147 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 148 | 149 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 150 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 151 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 152 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 153 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 154 | 155 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 156 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 157 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 158 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 159 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 160 | ] 161 | 162 | 163 | def __call__(self, img): 164 | policy_idx = random.randint(0, len(self.policies) - 1) 165 | return self.policies[policy_idx](img) 166 | 167 | def __repr__(self): 168 | return "AutoAugment SVHN Policy" 169 | 170 | 171 | class SubPolicy(object): 172 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 173 | ranges = { 174 | "shearX": np.linspace(0, 0.3, 10), 175 | "shearY": np.linspace(0, 0.3, 10), 176 | "translateX": np.linspace(0, 150 / 331, 10), 177 | "translateY": np.linspace(0, 150 / 331, 10), 178 | "rotate": np.linspace(0, 30, 10), 179 | "color": np.linspace(0.0, 0.9, 10), 180 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 181 | "solarize": np.linspace(256, 0, 10), 182 | "contrast": np.linspace(0.0, 0.9, 10), 183 | "sharpness": np.linspace(0.0, 0.9, 10), 184 | "brightness": np.linspace(0.0, 0.9, 10), 185 | "autocontrast": [0] * 10, 186 | "equalize": [0] * 10, 187 | "invert": [0] * 10 188 | } 189 | 190 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 191 | def rotate_with_fill(img, magnitude): 192 | rot = img.convert("RGBA").rotate(magnitude) 193 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 194 | 195 | func = { 196 | "shearX": lambda img, magnitude: img.transform( 197 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 198 | 
Image.BICUBIC, fillcolor=fillcolor), 199 | "shearY": lambda img, magnitude: img.transform( 200 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 201 | Image.BICUBIC, fillcolor=fillcolor), 202 | "translateX": lambda img, magnitude: img.transform( 203 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 204 | fillcolor=fillcolor), 205 | "translateY": lambda img, magnitude: img.transform( 206 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 207 | fillcolor=fillcolor), 208 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 209 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 210 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 211 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 212 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 213 | 1 + magnitude * random.choice([-1, 1])), 214 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 215 | 1 + magnitude * random.choice([-1, 1])), 216 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 217 | 1 + magnitude * random.choice([-1, 1])), 218 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 219 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 220 | "invert": lambda img, magnitude: ImageOps.invert(img) 221 | } 222 | 223 | self.p1 = p1 224 | self.operation1 = func[operation1] 225 | self.magnitude1 = ranges[operation1][magnitude_idx1] 226 | self.p2 = p2 227 | self.operation2 = func[operation2] 228 | self.magnitude2 = ranges[operation2][magnitude_idx2] 229 | 230 | 231 | def __call__(self, img): 232 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 233 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 234 | return img -------------------------------------------------------------------------------- /data/transforms/build.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | 3 | 4 | from .transforms import RandomErasing 5 | from .autoaugment import ImageNetPolicy 6 | 7 | 8 | def build_transforms(cfg, is_train=True): 9 | normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 10 | if is_train: 11 | transform = T.Compose([ 12 | T.Resize(cfg.INPUT.SIZE_TRAIN), 13 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 14 | T.Pad(cfg.INPUT.PADDING), 15 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN), 16 | T.ToTensor(), 17 | normalize_transform, 18 | RandomErasing(probability=cfg.INPUT.PROB, mean=cfg.INPUT.PIXEL_MEAN) 19 | ]) 20 | else: 21 | transform = T.Compose([ 22 | T.Resize(cfg.INPUT.SIZE_TEST), 23 | T.ToTensor(), 24 | normalize_transform 25 | ]) 26 | 27 | return transform 28 | -------------------------------------------------------------------------------- /data/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | 5 | class RandomErasing(object): 6 | """ Randomly selects a rectangle region in an image and erases its pixels. 7 | 'Random Erasing Data Augmentation' by Zhong et al. 8 | See https://arxiv.org/pdf/1708.04896.pdf 9 | Args: 10 | probability: The probability that the Random Erasing operation will be performed. 11 | sl: Minimum proportion of erased area against input image. 
12 | sh: Maximum proportion of erased area against input image. 13 | r1: Minimum aspect ratio of erased area. 14 | mean: Erasing value. 15 | """ 16 | 17 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 18 | self.probability = probability 19 | self.mean = mean 20 | self.sl = sl 21 | self.sh = sh 22 | self.r1 = r1 23 | 24 | def __call__(self, img): 25 | 26 | if random.uniform(0, 1) > self.probability: 27 | return img 28 | 29 | for attempt in range(100): 30 | area = img.size()[1] * img.size()[2] 31 | 32 | target_area = random.uniform(self.sl, self.sh) * area 33 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 34 | 35 | h = int(round(math.sqrt(target_area * aspect_ratio))) 36 | w = int(round(math.sqrt(target_area / aspect_ratio))) 37 | 38 | if w < img.size()[2] and h < img.size()[1]: 39 | x1 = random.randint(0, img.size()[1] - h) 40 | y1 = random.randint(0, img.size()[2] - w) 41 | if img.size()[0] == 3: 42 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 43 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 44 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 45 | else: 46 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 47 | return img 48 | 49 | return img 50 | -------------------------------------------------------------------------------- /engine/__pycache__/trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/engine/__pycache__/trainer.cpython-37.pyc -------------------------------------------------------------------------------- /engine/inference.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import pdb 4 | import torch 5 | import numpy as np 6 | from ignite.engine import Engine 7 | 8 | from utils.reid_metric import R1_mAP 9 | 10 | 11 | def create_supervised_evaluator(model, metrics, 12 | device=None): 13 | """ 14 | Factory function for creating an evaluator for supervised models 15 | 16 | Args: 17 | model (`torch.nn.Module`): the model to train 18 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics 19 | device (str, optional): device type specification (default: None). 20 | Applies to both model and batches. 
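    Example (a minimal sketch; `model`, `val_loader` and `num_query` are assumed to be
    built by tools/test.py, mirroring the `inference()` helper further below):
        >>> from utils.reid_metric import R1_mAP
        >>> evaluator = create_supervised_evaluator(
        ...     model, metrics={'r1_mAP': R1_mAP(num_query)}, device='cuda')
        >>> evaluator.run(val_loader)
        >>> cmc, mAP = evaluator.state.metrics['r1_mAP']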
21 | Returns: 22 | Engine: an evaluator engine with supervised inference function 23 | """ 24 | if device: 25 | model.to(device) 26 | 27 | def fliplr(img): 28 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long().cuda() # N x C x H x W 29 | img_flip = img.index_select(3,inv_idx) 30 | return img_flip 31 | 32 | def _inference(engine, batch): 33 | model.eval() 34 | with torch.no_grad(): 35 | data, pids, camids = batch 36 | data = data.cuda() 37 | 38 | # feat = model(data) 39 | 40 | data_f = fliplr(data) 41 | feat = model(data) 42 | feat_f = model(data_f) 43 | feat = feat + feat_f 44 | 45 | return feat, pids, camids 46 | 47 | engine = Engine(_inference) 48 | 49 | for name, metric in metrics.items(): 50 | metric.attach(engine, name) 51 | 52 | return engine 53 | 54 | 55 | def inference( 56 | cfg, 57 | model, 58 | val_loader, 59 | num_query 60 | ): 61 | device = cfg.MODEL.DEVICE 62 | 63 | logger = logging.getLogger("reid_baseline.inference") 64 | logger.info("Start inferencing") 65 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query,re_rank=cfg.TEST.RE_RANK)}, 66 | device=device) 67 | 68 | evaluator.run(val_loader) 69 | cmc, mAP = evaluator.state.metrics['r1_mAP'] 70 | logger.info('Validation Results') 71 | logger.info("mAP: {:.1%}".format(mAP)) 72 | for r in [1, 5, 10]: 73 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 74 | # logger.info("Attributes Accuracy: {:.2%}".format(att_acc)) 75 | -------------------------------------------------------------------------------- /engine/trainer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | import pdb 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from ignite.engine import Engine, Events 13 | from ignite.handlers import ModelCheckpoint, Timer 14 | from ignite.metrics import RunningAverage 15 | 16 | from utils.reid_metric import R1_mAP 17 | 18 | 19 | def create_supervised_trainer(model, optimizer, loss_fn, 20 | device=None): 21 | """ 22 | Factory function for creating a trainer for supervised models 23 | 24 | Args: 25 | model (`torch.nn.Module`): the model to train 26 | optimizer (`torch.optim.Optimizer`): the optimizer to use 27 | loss_fn (torch.nn loss function): the loss function to use 28 | device (str, optional): device type specification (default: None). 29 | Applies to both model and batches. 30 | 31 | Returns: 32 | Engine: a trainer engine with supervised update function 33 | """ 34 | if device: 35 | model.to(device) 36 | #pdb.set_trace() 37 | 38 | def _update(engine, batch): 39 | model.train() 40 | optimizer.zero_grad() 41 | img, target = batch 42 | 43 | img = img.cuda() 44 | target = target.cuda() 45 | score, feat = model(img) 46 | 47 | loss = loss_fn(score, feat, target) 48 | 49 | loss.backward() 50 | optimizer.step() 51 | # compute acc 52 | acc = (score.max(1)[1] == target).float().mean() 53 | 54 | return loss.item(), acc.item() 55 | 56 | return Engine(_update) 57 | 58 | 59 | def create_supervised_evaluator(model, metrics, 60 | device=None): 61 | """ 62 | Factory function for creating an evaluator for supervised models 63 | 64 | Args: 65 | model (`torch.nn.Module`): the model to train 66 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics 67 | device (str, optional): device type specification (default: None). 68 | Applies to both model and batches. 
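    Note: as in engine/inference.py, the inference step below sums the features of the
    original and the horizontally flipped image, which is roughly equivalent to
        >>> feat = model(img) + model(torch.flip(img, dims=[3]))
    (an illustrative shorthand for the `fliplr` helper below; dim 3 is the image width).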
69 | Returns: 70 | Engine: an evaluator engine with supervised inference function 71 | """ 72 | if device: 73 | model.to(device) 74 | 75 | def fliplr(img): 76 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long().cuda() # N x C x H x W 77 | img_flip = img.index_select(3,inv_idx) 78 | return img_flip 79 | 80 | def _inference(engine, batch): 81 | model.eval() 82 | with torch.no_grad(): 83 | data, pids, camids = batch 84 | data = data.cuda() 85 | # feat = model(data) 86 | # ######### fliplr #### 87 | data_f = fliplr(data) 88 | feat = model(data) 89 | feat_f = model(data_f) 90 | feat = feat + feat_f 91 | 92 | 93 | 94 | return feat, pids, camids 95 | 96 | engine = Engine(_inference) 97 | for name, metric in metrics.items(): 98 | metric.attach(engine, name) 99 | return engine 100 | 101 | 102 | def do_train( 103 | cfg, 104 | model, 105 | train_loader, 106 | val_loader, 107 | optimizer, 108 | scheduler, 109 | loss_fn, 110 | num_query 111 | ): 112 | log_period = cfg.SOLVER.LOG_PERIOD 113 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 114 | eval_period = cfg.SOLVER.EVAL_PERIOD 115 | output_dir = cfg.OUTPUT_DIR 116 | device = cfg.MODEL.DEVICE 117 | epochs = cfg.SOLVER.MAX_EPOCHS 118 | 119 | logger = logging.getLogger("reid_baseline.train") 120 | logger.info("Start training") 121 | 122 | trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) 123 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query)}, device=device) 124 | checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, n_saved=5, require_empty=False) 125 | timer = Timer(average=True) 126 | 127 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model, 'optimizer': optimizer}) 128 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, 129 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) 130 | 131 | # average metric to attach on trainer 132 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss') 133 | RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc') 134 | 135 | @trainer.on(Events.EPOCH_COMPLETED) 136 | def adjust_learning_rate(engine): 137 | scheduler.step() 138 | 139 | @trainer.on(Events.ITERATION_COMPLETED) 140 | def log_training_loss(engine): 141 | iter = (engine.state.iteration - 1) % len(train_loader) + 1 142 | 143 | if iter % log_period == 0: 144 | logger.info("Epoch[{}] Iteration[{}/{}] \nLoss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}" 145 | .format(engine.state.epoch, iter, len(train_loader), engine.state.metrics['avg_loss'], 146 | engine.state.metrics['avg_acc'], scheduler.get_lr()[0])) 147 | 148 | # adding handlers using `trainer.on` decorator API 149 | @trainer.on(Events.EPOCH_COMPLETED) 150 | def print_times(engine): 151 | logger.info('Epoch {} done. 
Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]' 152 | .format(engine.state.epoch, timer.value() * timer.step_count, 153 | train_loader.batch_size / timer.value())) 154 | logger.info('-' * 10) 155 | timer.reset() 156 | 157 | @trainer.on(Events.EPOCH_COMPLETED) 158 | def log_validation_results(engine): 159 | if ((engine.state.epoch % eval_period == 0) or (engine.state.epoch == epochs)) and (engine.state.epoch > 0.5*epochs): 160 | evaluator.run(val_loader) 161 | cmc, mAP = evaluator.state.metrics['r1_mAP'] 162 | logger.info("Validation Results - Epoch: {}".format(engine.state.epoch)) 163 | logger.info("mAP: {:.1%}".format(mAP)) 164 | 165 | for r in [1, 5, 10]: 166 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 167 | 168 | 169 | trainer.run(train_loader, max_epochs=epochs) 170 | -------------------------------------------------------------------------------- /images/github_main_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/images/github_main_graph.png -------------------------------------------------------------------------------- /images/github_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/images/github_vis.png -------------------------------------------------------------------------------- /images/test.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import pdb 4 | from .triplet_loss import TripletLoss 5 | from .cross_entropy_loss import CrossEntropyLoss 6 | # from .center_loss import CenterLoss 7 | 8 | 9 | 10 | 11 | def make_loss(cfg): 12 | sampler = cfg.DATALOADER.SAMPLER 13 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 14 | cross_entropy = CrossEntropyLoss(num_classes=cfg.SOLVER.CLASSNUM,epsilon=cfg.SOLVER.SMOOTH) 15 | 16 | if sampler == 'softmax': 17 | def loss_func(score, feat, target): 18 | return F.cross_entropy(score, target) 19 | elif cfg.DATALOADER.SAMPLER == 'triplet': 20 | def loss_func(score, feat, target): 21 | return triplet(feat, target)[0] 22 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet': 23 | def loss_func(score, feat, target): 24 | loss_id = cross_entropy(score, target) + triplet(feat, target)[0] 25 | # cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) #+PairwiseConfusion(score)/100.0 26 | return loss_id 27 | else: 28 | print('expected sampler should be softmax, triplet or softmax_triplet, ' 29 | 'but got {}'.format(cfg.DATALOADER.SAMPLER)) 30 | return loss_func 31 | -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/layers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/cross_entropy_loss.cpython-37.pyc: -------------------------------------------------------------------------------- 
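# A minimal runnable sketch (not part of this package) of what the 'softmax_triplet'
# branch of make_loss() above computes. The batch layout (4 identities x 4 images each),
# the 751 classes and the margin/epsilon values are illustrative assumptions, and the
# snippet targets the PyTorch version pinned in requirements.txt.
import torch
from layers.cross_entropy_loss import CrossEntropyLoss
from layers.triplet_loss import TripletLoss

score = torch.randn(16, 751)                   # ID logits from the classifier head
feat = torch.randn(16, 2048)                   # global features from the backbone
target = torch.arange(4).repeat_interleave(4)  # 4 identities with 4 images each

xent = CrossEntropyLoss(num_classes=751, epsilon=0.1, use_gpu=False)
triplet = TripletLoss(margin=0.3)
loss_id = xent(score, target) + triplet(feat, target)[0]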
https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/layers/__pycache__/cross_entropy_loss.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/triplet_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/layers/__pycache__/triplet_loss.cpython-37.pyc -------------------------------------------------------------------------------- /layers/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | """Center loss. 9 | Reference: 10 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 11 | Args: 12 | num_classes (int): number of classes. 13 | feat_dim (int): feature dimension. 14 | """ 15 | 16 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 17 | super(CenterLoss, self).__init__() 18 | self.num_classes = num_classes 19 | self.feat_dim = feat_dim 20 | self.use_gpu = use_gpu 21 | 22 | if self.use_gpu: 23 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 24 | else: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 26 | 27 | def forward(self, x, labels): 28 | """ 29 | Args: 30 | x: feature matrix with shape (batch_size, feat_dim). 31 | labels: ground truth labels with shape (num_classes). 32 | """ 33 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 34 | 35 | batch_size = x.size(0) 36 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 37 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 38 | distmat.addmm_(1, -2, x, self.centers.t()) 39 | 40 | classes = torch.arange(self.num_classes).long() 41 | if self.use_gpu: classes = classes.cuda() 42 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 43 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 44 | 45 | dist = distmat * mask.float() 46 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size 47 | #dist = [] 48 | #for i in range(batch_size): 49 | # value = distmat[i][mask[i]] 50 | # value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 51 | # dist.append(value) 52 | #dist = torch.cat(dist) 53 | #loss = dist.mean() 54 | return loss 55 | 56 | 57 | if __name__ == '__main__': 58 | use_gpu = False 59 | center_loss = CenterLoss(use_gpu=use_gpu) 60 | features = torch.rand(16, 2048) 61 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 62 | if use_gpu: 63 | features = torch.rand(16, 2048).cuda() 64 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 65 | 66 | loss = center_loss(features, targets) 67 | print(loss) -------------------------------------------------------------------------------- /layers/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CrossEntropyLoss(nn.Module): 6 | """Cross entropy loss with label smoothing regularizer. 7 | 8 | Reference: 9 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 
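    Worked example (illustrative numbers, see the equation below): with num_classes K = 5
    and epsilon = 0.1, a one-hot target [0, 0, 1, 0, 0] becomes
    [0.02, 0.02, 0.92, 0.02, 0.02], since 0.9 * 1 + 0.1 / 5 = 0.92 and 0.1 / 5 = 0.02.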
10 | 11 | Equation: y = (1 - epsilon) * y + epsilon / K. 12 | 13 | Args: 14 | - num_classes (int): number of classes 15 | - epsilon (float): weight 16 | - use_gpu (bool): whether to use gpu devices 17 | - label_smooth (bool): whether to apply label smoothing, if False, epsilon = 0 18 | """ 19 | def __init__(self, num_classes, epsilon=0.05, use_gpu=True, label_smooth=True): 20 | super(CrossEntropyLoss, self).__init__() 21 | self.num_classes = num_classes 22 | self.epsilon = epsilon if label_smooth else 0 23 | self.use_gpu = use_gpu 24 | self.logsoftmax = nn.LogSoftmax(dim=1) 25 | 26 | def forward(self, inputs, targets): 27 | """ 28 | Args: 29 | - inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 30 | - targets: ground truth labels with shape (num_classes) 31 | """ 32 | log_probs = self.logsoftmax(inputs) 33 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 34 | if self.use_gpu: targets = targets.cuda() 35 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 36 | loss = (- targets * log_probs).mean(0).sum() 37 | return loss 38 | -------------------------------------------------------------------------------- /layers/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def normalize(x, axis=-1): 6 | """Normalizing to unit length along the specified dimension. 7 | Args: 8 | x: pytorch Variable 9 | Returns: 10 | x: pytorch Variable, same shape as input 11 | """ 12 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 13 | return x 14 | 15 | 16 | def euclidean_dist(x, y): 17 | """ 18 | Args: 19 | x: pytorch Variable, with shape [m, d] 20 | y: pytorch Variable, with shape [n, d] 21 | Returns: 22 | dist: pytorch Variable, with shape [m, n] 23 | """ 24 | m, n = x.size(0), y.size(0) 25 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 26 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 27 | dist = xx + yy 28 | dist.addmm_(1, -2, x, y.t()) 29 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 30 | return dist 31 | 32 | 33 | def hard_example_mining(dist_mat, labels, return_inds=False): 34 | """For each anchor, find the hardest positive and negative sample. 35 | Args: 36 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 37 | labels: pytorch LongTensor, with shape [N] 38 | return_inds: whether to return the indices. Save time if `False`(?) 39 | Returns: 40 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 41 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 42 | p_inds: pytorch LongTensor, with shape [N]; 43 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 44 | n_inds: pytorch LongTensor, with shape [N]; 45 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 46 | NOTE: Only consider the case in which all labels have same num of samples, 47 | thus we can cope with all anchors in parallel. 
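    Example (an illustrative toy case, not drawn from any dataset: a batch of
    2 identities x 2 images with a hand-made distance matrix):
        >>> import torch
        >>> labels = torch.tensor([0, 0, 1, 1])
        >>> dist_mat = torch.tensor([[0.0, 0.3, 1.2, 0.9],
        ...                          [0.3, 0.0, 0.8, 1.1],
        ...                          [1.2, 0.8, 0.0, 0.4],
        ...                          [0.9, 1.1, 0.4, 0.0]])
        >>> dist_ap, dist_an = hard_example_mining(dist_mat, labels)
        >>> dist_ap  # hardest (farthest) positive per anchor
        tensor([0.3000, 0.3000, 0.4000, 0.4000])
        >>> dist_an  # hardest (closest) negative per anchor
        tensor([0.9000, 0.8000, 0.8000, 0.9000])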
48 | """ 49 | 50 | assert len(dist_mat.size()) == 2 51 | assert dist_mat.size(0) == dist_mat.size(1) 52 | N = dist_mat.size(0) 53 | 54 | # shape [N, N] 55 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 56 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 57 | 58 | # `dist_ap` means distance(anchor, positive) 59 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 60 | dist_ap, relative_p_inds = torch.max( 61 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 62 | # `dist_an` means distance(anchor, negative) 63 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 64 | dist_an, relative_n_inds = torch.min( 65 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 66 | # shape [N] 67 | dist_ap = dist_ap.squeeze(1) 68 | dist_an = dist_an.squeeze(1) 69 | 70 | if return_inds: 71 | # shape [N, N] 72 | ind = (labels.new().resize_as_(labels) 73 | .copy_(torch.arange(0, N).long()) 74 | .unsqueeze(0).expand(N, N)) 75 | # shape [N, 1] 76 | p_inds = torch.gather( 77 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 78 | n_inds = torch.gather( 79 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 80 | # shape [N] 81 | p_inds = p_inds.squeeze(1) 82 | n_inds = n_inds.squeeze(1) 83 | return dist_ap, dist_an, p_inds, n_inds 84 | 85 | return dist_ap, dist_an 86 | 87 | 88 | class TripletLoss(object): 89 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 90 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 91 | Loss for Person Re-Identification'.""" 92 | 93 | def __init__(self, margin=None): 94 | self.margin = margin 95 | if margin is not None: 96 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 97 | else: 98 | self.ranking_loss = nn.SoftMarginLoss() 99 | 100 | def __call__(self, global_feat, labels, normalize_feature=False): 101 | if normalize_feature: 102 | global_feat = normalize(global_feat, axis=-1) 103 | dist_mat = euclidean_dist(global_feat, global_feat) 104 | dist_ap, dist_an = hard_example_mining( 105 | dist_mat, labels) 106 | y = dist_an.new().resize_as_(dist_an).fill_(1) 107 | if self.margin is not None: 108 | loss = self.ranking_loss(dist_an, dist_ap, y) 109 | else: 110 | loss = self.ranking_loss(dist_an - dist_ap, y) 111 | return loss, dist_ap, dist_an 112 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .baseline import Baseline#,Baseline_SE,Baseline_DENSE 2 | import torch.nn as nn 3 | import torch 4 | 5 | def build_model(cfg, num_classes): 6 | if cfg.MODEL.NAME == 'resnet50': 7 | model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.APNET.LEVEL,cfg.APNET.MSMT) 8 | return model 9 | else: 10 | raise RuntimeError("'{}' is not available".format(cfg.MODEL.NAME)) 11 | -------------------------------------------------------------------------------- /modeling/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/modeling/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/baseline.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/modeling/__pycache__/baseline.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/modeling/backbones/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/inception.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/modeling/backbones/__pycache__/inception.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnest.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/modeling/backbones/__pycache__/resnest.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/modeling/backbones/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/se_module.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/modeling/backbones/__pycache__/se_module.cpython-37.pyc -------------------------------------------------------------------------------- /modeling/backbones/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from collections import OrderedDict 7 | 8 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 9 | 10 | 11 | model_urls = { 12 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 13 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 14 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 15 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 16 | } 17 | 18 | 19 | def densenet121(pretrained=False, **kwargs): 20 | r"""Densenet-121 model from 21 | `"Densely Connected Convolutional Networks" `_ 22 | 23 | Args: 24 | pretrained (bool): If True, returns a model pre-trained on ImageNet 25 | """ 26 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 27 | **kwargs) 28 | if pretrained: 29 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 30 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 
31 | # They are also in the checkpoints in model_urls. This pattern is used 32 | # to find such keys. 33 | pattern = re.compile( 34 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 35 | state_dict = model_zoo.load_url(model_urls['densenet121']) 36 | for key in list(state_dict.keys()): 37 | res = pattern.match(key) 38 | if res: 39 | new_key = res.group(1) + res.group(2) 40 | state_dict[new_key] = state_dict[key] 41 | del state_dict[key] 42 | model.load_state_dict(state_dict) 43 | return model 44 | 45 | 46 | def densenet169(pretrained=False, **kwargs): 47 | r"""Densenet-169 model from 48 | `"Densely Connected Convolutional Networks" `_ 49 | 50 | Args: 51 | pretrained (bool): If True, returns a model pre-trained on ImageNet 52 | """ 53 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 54 | **kwargs) 55 | if pretrained: 56 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 57 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 58 | # They are also in the checkpoints in model_urls. This pattern is used 59 | # to find such keys. 60 | pattern = re.compile( 61 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 62 | state_dict = model_zoo.load_url(model_urls['densenet169']) 63 | for key in list(state_dict.keys()): 64 | res = pattern.match(key) 65 | if res: 66 | new_key = res.group(1) + res.group(2) 67 | state_dict[new_key] = state_dict[key] 68 | del state_dict[key] 69 | model.load_state_dict(state_dict) 70 | return model 71 | 72 | 73 | def densenet201(pretrained=False, **kwargs): 74 | r"""Densenet-201 model from 75 | `"Densely Connected Convolutional Networks" `_ 76 | 77 | Args: 78 | pretrained (bool): If True, returns a model pre-trained on ImageNet 79 | """ 80 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 81 | **kwargs) 82 | if pretrained: 83 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 84 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 85 | # They are also in the checkpoints in model_urls. This pattern is used 86 | # to find such keys. 87 | pattern = re.compile( 88 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 89 | state_dict = model_zoo.load_url(model_urls['densenet201']) 90 | for key in list(state_dict.keys()): 91 | res = pattern.match(key) 92 | if res: 93 | new_key = res.group(1) + res.group(2) 94 | state_dict[new_key] = state_dict[key] 95 | del state_dict[key] 96 | model.load_state_dict(state_dict) 97 | return model 98 | 99 | 100 | def densenet161(pretrained=False, **kwargs): 101 | r"""Densenet-161 model from 102 | `"Densely Connected Convolutional Networks" `_ 103 | 104 | Args: 105 | pretrained (bool): If True, returns a model pre-trained on ImageNet 106 | """ 107 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 108 | **kwargs) 109 | if pretrained: 110 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 111 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 112 | # They are also in the checkpoints in model_urls. This pattern is used 113 | # to find such keys. 
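        # to find such keys.
        # Illustrative example (an assumed key name, following the old naming scheme):
        # a checkpoint entry such as
        #     'features.denseblock1.denselayer1.norm.1.weight'
        # is matched by the pattern below and rewritten to
        #     'features.denseblock1.denselayer1.norm1.weight'
        # so that it lines up with the module names used by _DenseLayer.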
114 | pattern = re.compile( 115 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 116 | state_dict = model_zoo.load_url(model_urls['densenet161']) 117 | for key in list(state_dict.keys()): 118 | res = pattern.match(key) 119 | if res: 120 | new_key = res.group(1) + res.group(2) 121 | state_dict[new_key] = state_dict[key] 122 | del state_dict[key] 123 | model.load_state_dict(state_dict) 124 | return model 125 | 126 | 127 | class _DenseLayer(nn.Sequential): 128 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 129 | super(_DenseLayer, self).__init__() 130 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 131 | self.add_module('relu1', nn.ReLU(inplace=True)), 132 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 133 | growth_rate, kernel_size=1, stride=1, bias=False)), 134 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 135 | self.add_module('relu2', nn.ReLU(inplace=True)), 136 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 137 | kernel_size=3, stride=1, padding=1, bias=False)), 138 | self.drop_rate = drop_rate 139 | 140 | def forward(self, x): 141 | new_features = super(_DenseLayer, self).forward(x) 142 | if self.drop_rate > 0: 143 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 144 | return torch.cat([x, new_features], 1) 145 | 146 | 147 | class _DenseBlock(nn.Sequential): 148 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 149 | super(_DenseBlock, self).__init__() 150 | for i in range(num_layers): 151 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 152 | self.add_module('denselayer%d' % (i + 1), layer) 153 | 154 | 155 | class _Transition(nn.Sequential): 156 | def __init__(self, num_input_features, num_output_features): 157 | super(_Transition, self).__init__() 158 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 159 | self.add_module('relu', nn.ReLU(inplace=True)) 160 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 161 | kernel_size=1, stride=1, bias=False)) 162 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 163 | 164 | 165 | class DenseNet(nn.Module): 166 | r"""Densenet-BC model class, based on 167 | `"Densely Connected Convolutional Networks" `_ 168 | 169 | Args: 170 | growth_rate (int) - how many filters to add each layer (`k` in paper) 171 | block_config (list of 4 ints) - how many layers in each pooling block 172 | num_init_features (int) - the number of filters to learn in the first convolution layer 173 | bn_size (int) - multiplicative factor for number of bottle neck layers 174 | (i.e. 
bn_size * k features in the bottleneck layer) 175 | drop_rate (float) - dropout rate after each dense layer 176 | num_classes (int) - number of classification classes 177 | """ 178 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 179 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 180 | 181 | super(DenseNet, self).__init__() 182 | 183 | # First convolution 184 | self.features = nn.Sequential(OrderedDict([ 185 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 186 | ('norm0', nn.BatchNorm2d(num_init_features)), 187 | ('relu0', nn.ReLU(inplace=True)), 188 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 189 | ])) 190 | 191 | # Each denseblock 192 | num_features = num_init_features 193 | for i, num_layers in enumerate(block_config): 194 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 195 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 196 | self.features.add_module('denseblock%d' % (i + 1), block) 197 | num_features = num_features + num_layers * growth_rate 198 | if i != len(block_config) - 1: 199 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 200 | self.features.add_module('transition%d' % (i + 1), trans) 201 | num_features = num_features // 2 202 | 203 | # Final batch norm 204 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 205 | 206 | # Linear layer 207 | self.classifier = nn.Linear(num_features, num_classes) 208 | 209 | # Official init from torch repo. 210 | for m in self.modules(): 211 | if isinstance(m, nn.Conv2d): 212 | nn.init.kaiming_normal(m.weight.data) 213 | elif isinstance(m, nn.BatchNorm2d): 214 | m.weight.data.fill_(1) 215 | m.bias.data.zero_() 216 | elif isinstance(m, nn.Linear): 217 | m.bias.data.zero_() 218 | 219 | def forward(self, x): 220 | features = self.features(x) 221 | out = F.relu(features, inplace=True) 222 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 223 | out = self.classifier(out) 224 | return out 225 | -------------------------------------------------------------------------------- /modeling/backbones/inception.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | 8 | __all__ = ['Inception3', 'inception_v3'] 9 | 10 | 11 | model_urls = { 12 | # Inception v3 ported from TensorFlow 13 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', 14 | } 15 | 16 | 17 | def inception_v3(pretrained=False, **kwargs): 18 | r"""Inception v3 model architecture from 19 | `"Rethinking the Inception Architecture for Computer Vision" `_. 
20 | 21 | Args: 22 | pretrained (bool): If True, returns a model pre-trained on ImageNet 23 | """ 24 | if pretrained: 25 | if 'transform_input' not in kwargs: 26 | kwargs['transform_input'] = True 27 | model = Inception3(**kwargs) 28 | model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google'])) 29 | return model 30 | 31 | return Inception3(**kwargs) 32 | 33 | 34 | class Inception3(nn.Module): 35 | 36 | def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): 37 | super(Inception3, self).__init__() 38 | self.aux_logits = aux_logits 39 | self.transform_input = transform_input 40 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) 41 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 42 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 43 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 44 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 45 | self.Mixed_5b = InceptionA(192, pool_features=32) 46 | self.Mixed_5c = InceptionA(256, pool_features=64) 47 | self.Mixed_5d = InceptionA(288, pool_features=64) 48 | self.Mixed_6a = InceptionB(288) 49 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 50 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 51 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 52 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 53 | if aux_logits: 54 | self.AuxLogits = InceptionAux(768, num_classes) 55 | self.Mixed_7a = InceptionD(768) 56 | self.Mixed_7b = InceptionE(1280) 57 | self.Mixed_7c = InceptionE(2048) 58 | self.fc = nn.Linear(2048, num_classes) 59 | 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 62 | import scipy.stats as stats 63 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1 64 | X = stats.truncnorm(-2, 2, scale=stddev) 65 | values = torch.Tensor(X.rvs(m.weight.data.numel())) 66 | values = values.view(m.weight.data.size()) 67 | m.weight.data.copy_(values) 68 | elif isinstance(m, nn.BatchNorm2d): 69 | m.weight.data.fill_(1) 70 | m.bias.data.zero_() 71 | 72 | def forward(self, x): 73 | if self.transform_input: 74 | x = x.clone() 75 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 76 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 77 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 78 | # 299 x 299 x 3 79 | x = self.Conv2d_1a_3x3(x) 80 | # 149 x 149 x 32 81 | x = self.Conv2d_2a_3x3(x) 82 | # 147 x 147 x 32 83 | x = self.Conv2d_2b_3x3(x) 84 | # 147 x 147 x 64 85 | x = F.max_pool2d(x, kernel_size=3, stride=2) 86 | # 73 x 73 x 64 87 | x = self.Conv2d_3b_1x1(x) 88 | # 73 x 73 x 80 89 | x = self.Conv2d_4a_3x3(x) 90 | # 71 x 71 x 192 91 | x = F.max_pool2d(x, kernel_size=3, stride=2) 92 | # 35 x 35 x 192 93 | x = self.Mixed_5b(x) 94 | # 35 x 35 x 256 95 | x = self.Mixed_5c(x) 96 | # 35 x 35 x 288 97 | x = self.Mixed_5d(x) 98 | # 35 x 35 x 288 99 | x = self.Mixed_6a(x) 100 | # 17 x 17 x 768 101 | x = self.Mixed_6b(x) 102 | # 17 x 17 x 768 103 | x = self.Mixed_6c(x) 104 | # 17 x 17 x 768 105 | x = self.Mixed_6d(x) 106 | # 17 x 17 x 768 107 | x = self.Mixed_6e(x) 108 | # 17 x 17 x 768 109 | if self.training and self.aux_logits: 110 | aux = self.AuxLogits(x) 111 | # 17 x 17 x 768 112 | x = self.Mixed_7a(x) 113 | # 8 x 8 x 1280 114 | x = self.Mixed_7b(x) 115 | # 8 x 8 x 2048 116 | x = self.Mixed_7c(x) 117 | # 8 x 8 x 2048 118 | x = F.avg_pool2d(x, kernel_size=8) 119 | # 1 x 1 x 2048 120 | x = F.dropout(x, training=self.training) 121 | # 1 x 1 x 2048 122 | x = x.view(x.size(0), -1) 123 | # 2048 124 | 
x = self.fc(x) 125 | # 1000 (num_classes) 126 | if self.training and self.aux_logits: 127 | return x, aux 128 | return x 129 | 130 | def get_features_mixed_6e(self): 131 | return nn.Sequential( 132 | self.Conv2d_1a_3x3, 133 | self.Conv2d_2a_3x3, 134 | self.Conv2d_2b_3x3, 135 | nn.MaxPool2d(kernel_size=3, stride=2), 136 | self.Conv2d_3b_1x1, 137 | self.Conv2d_4a_3x3, 138 | nn.MaxPool2d(kernel_size=3, stride=2), 139 | self.Mixed_5b, 140 | self.Mixed_5c, 141 | self.Mixed_5d, 142 | self.Mixed_6a, 143 | self.Mixed_6b, 144 | self.Mixed_6c, 145 | self.Mixed_6d, 146 | self.Mixed_6e, 147 | ) 148 | 149 | def get_features_mixed_7c(self): 150 | return nn.Sequential( 151 | self.Conv2d_1a_3x3, 152 | self.Conv2d_2a_3x3, 153 | self.Conv2d_2b_3x3, 154 | nn.MaxPool2d(kernel_size=3, stride=2), 155 | self.Conv2d_3b_1x1, 156 | self.Conv2d_4a_3x3, 157 | nn.MaxPool2d(kernel_size=3, stride=2), 158 | self.Mixed_5b, 159 | self.Mixed_5c, 160 | self.Mixed_5d, 161 | self.Mixed_6a, 162 | self.Mixed_6b, 163 | self.Mixed_6c, 164 | self.Mixed_6d, 165 | self.Mixed_6e, 166 | self.Mixed_7a, 167 | self.Mixed_7b, 168 | self.Mixed_7c, 169 | ) 170 | 171 | def load_state_dict(self, state_dict, strict=True): 172 | model_dict = self.state_dict() 173 | pretrained_dict = {k: v for k, v in state_dict.items() 174 | if k in model_dict and model_dict[k].size() == v.size()} 175 | 176 | if len(pretrained_dict) == len(state_dict): 177 | logging.info('%s: All params loaded' % type(self).__name__) 178 | else: 179 | logging.info('%s: Some params were not loaded:' % type(self).__name__) 180 | not_loaded_keys = [k for k in state_dict.keys() if k not in pretrained_dict.keys()] 181 | logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') % tuple(not_loaded_keys)) 182 | 183 | model_dict.update(pretrained_dict) 184 | super(Inception3, self).load_state_dict(model_dict) 185 | 186 | 187 | class InceptionA(nn.Module): 188 | 189 | def __init__(self, in_channels, pool_features): 190 | super(InceptionA, self).__init__() 191 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) 192 | 193 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) 194 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 195 | 196 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 197 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 198 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) 199 | 200 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) 201 | 202 | def forward(self, x): 203 | branch1x1 = self.branch1x1(x) 204 | 205 | branch5x5 = self.branch5x5_1(x) 206 | branch5x5 = self.branch5x5_2(branch5x5) 207 | 208 | branch3x3dbl = self.branch3x3dbl_1(x) 209 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 210 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 211 | 212 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 213 | branch_pool = self.branch_pool(branch_pool) 214 | 215 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 216 | return torch.cat(outputs, 1) 217 | 218 | 219 | class InceptionB(nn.Module): 220 | 221 | def __init__(self, in_channels): 222 | super(InceptionB, self).__init__() 223 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) 224 | 225 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 226 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 227 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2) 228 | 229 | def forward(self, x): 230 | 
branch3x3 = self.branch3x3(x) 231 | 232 | branch3x3dbl = self.branch3x3dbl_1(x) 233 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 234 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 235 | 236 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 237 | 238 | outputs = [branch3x3, branch3x3dbl, branch_pool] 239 | return torch.cat(outputs, 1) 240 | 241 | 242 | class InceptionC(nn.Module): 243 | 244 | def __init__(self, in_channels, channels_7x7): 245 | super(InceptionC, self).__init__() 246 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) 247 | 248 | c7 = channels_7x7 249 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) 250 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 251 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 252 | 253 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) 254 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 255 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 256 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 257 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 258 | 259 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 260 | 261 | def forward(self, x): 262 | branch1x1 = self.branch1x1(x) 263 | 264 | branch7x7 = self.branch7x7_1(x) 265 | branch7x7 = self.branch7x7_2(branch7x7) 266 | branch7x7 = self.branch7x7_3(branch7x7) 267 | 268 | branch7x7dbl = self.branch7x7dbl_1(x) 269 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 270 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 271 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 272 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 273 | 274 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 275 | branch_pool = self.branch_pool(branch_pool) 276 | 277 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 278 | return torch.cat(outputs, 1) 279 | 280 | 281 | class InceptionD(nn.Module): 282 | 283 | def __init__(self, in_channels): 284 | super(InceptionD, self).__init__() 285 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 286 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) 287 | 288 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 289 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) 290 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) 291 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) 292 | 293 | def forward(self, x): 294 | branch3x3 = self.branch3x3_1(x) 295 | branch3x3 = self.branch3x3_2(branch3x3) 296 | 297 | branch7x7x3 = self.branch7x7x3_1(x) 298 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 299 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 300 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 301 | 302 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 303 | outputs = [branch3x3, branch7x7x3, branch_pool] 304 | return torch.cat(outputs, 1) 305 | 306 | 307 | class InceptionE(nn.Module): 308 | 309 | def __init__(self, in_channels): 310 | super(InceptionE, self).__init__() 311 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 312 | 313 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 314 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 315 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 
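        # The 3x3 branch above is factorized into parallel 1x3 and 3x1 convolutions;
        # forward() runs both on the same input and concatenates their outputs along
        # the channel dimension.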
316 | 317 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 318 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 319 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 320 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 321 | 322 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 323 | 324 | def forward(self, x): 325 | branch1x1 = self.branch1x1(x) 326 | 327 | branch3x3 = self.branch3x3_1(x) 328 | branch3x3 = [ 329 | self.branch3x3_2a(branch3x3), 330 | self.branch3x3_2b(branch3x3), 331 | ] 332 | branch3x3 = torch.cat(branch3x3, 1) 333 | 334 | branch3x3dbl = self.branch3x3dbl_1(x) 335 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 336 | branch3x3dbl = [ 337 | self.branch3x3dbl_3a(branch3x3dbl), 338 | self.branch3x3dbl_3b(branch3x3dbl), 339 | ] 340 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 341 | 342 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 343 | branch_pool = self.branch_pool(branch_pool) 344 | 345 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 346 | return torch.cat(outputs, 1) 347 | 348 | 349 | class InceptionAux(nn.Module): 350 | 351 | def __init__(self, in_channels, num_classes): 352 | super(InceptionAux, self).__init__() 353 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) 354 | self.conv1 = BasicConv2d(128, 768, kernel_size=5) 355 | self.conv1.stddev = 0.01 356 | self.fc = nn.Linear(768, num_classes) 357 | self.fc.stddev = 0.001 358 | 359 | def forward(self, x): 360 | # 17 x 17 x 768 361 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 362 | # 5 x 5 x 768 363 | x = self.conv0(x) 364 | # 5 x 5 x 128 365 | x = self.conv1(x) 366 | # 1 x 1 x 768 367 | x = x.view(x.size(0), -1) 368 | # 768 369 | x = self.fc(x) 370 | # 1000 371 | return x 372 | 373 | 374 | class BasicConv2d(nn.Module): 375 | 376 | def __init__(self, in_channels, out_channels, **kwargs): 377 | super(BasicConv2d, self).__init__() 378 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 379 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 380 | 381 | def forward(self, x): 382 | x = self.conv(x) 383 | x = self.bn(x) 384 | return F.relu(x, inplace=True) 385 | -------------------------------------------------------------------------------- /modeling/backbones/resnest.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## Email: zhanghang0704@gmail.com 4 | ## Copyright (c) 2020 5 | ## 6 | ## LICENSE file in the root directory of this source tree 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | """ResNeSt models""" 9 | 10 | import torch 11 | 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | # from .utils import drop_connect 16 | 17 | """Split-Attention""" 18 | 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU 23 | from torch.nn.modules.utils import _pair 24 | 25 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] 26 | 27 | _url_format = 'https://hangzh.s3.amazonaws.com/encoding/models/{}-{}.pth' 28 | 29 | _model_sha256 = {name: checksum for checksum, name in [ 30 | ('528c19ca', 'resnest50'), 31 | ('22405ba7', 'resnest101'), 32 | ('75117900', 'resnest200'), 33 | ('0cc87c48', 'resnest269'), 34 | ]} 35 | 36 | def 
short_hash(name): 37 | if name not in _model_sha256: 38 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 39 | return _model_sha256[name][:8] 40 | 41 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for 42 | name in _model_sha256.keys() 43 | } 44 | 45 | 46 | 47 | class DropBlock2D(object): 48 | def __init__(self, *args, **kwargs): 49 | raise NotImplementedError 50 | 51 | class SplAtConv2d(Module): 52 | """Split-Attention Conv2d 53 | """ 54 | def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0), 55 | dilation=(1, 1), groups=1, bias=True, 56 | radix=2, reduction_factor=4, 57 | rectify=False, rectify_avg=False, norm_layer=None, 58 | dropblock_prob=0.0, **kwargs): 59 | super(SplAtConv2d, self).__init__() 60 | padding = _pair(padding) 61 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 62 | self.rectify_avg = rectify_avg 63 | inter_channels = max(in_channels*radix//reduction_factor, 32) 64 | self.radix = radix 65 | self.cardinality = groups 66 | self.channels = channels 67 | self.dropblock_prob = dropblock_prob 68 | if self.rectify: 69 | from rfconv import RFConv2d 70 | self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 71 | groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs) 72 | else: 73 | self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 74 | groups=groups*radix, bias=bias, **kwargs) 75 | self.use_bn = norm_layer is not None 76 | self.bn0 = norm_layer(channels*radix) 77 | self.relu = ReLU(inplace=True) 78 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 79 | self.bn1 = norm_layer(inter_channels) 80 | self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality) 81 | if dropblock_prob > 0.0: 82 | self.dropblock = DropBlock2D(dropblock_prob, 3) 83 | 84 | def forward(self, x): 85 | x = self.conv(x) 86 | if self.use_bn: 87 | x = self.bn0(x) 88 | if self.dropblock_prob > 0.0: 89 | x = self.dropblock(x) 90 | x = self.relu(x) 91 | 92 | batch, channel = x.shape[:2] 93 | if self.radix > 1: 94 | splited = torch.split(x, channel//self.radix, dim=1) 95 | gap = sum(splited) 96 | else: 97 | gap = x 98 | gap = F.adaptive_avg_pool2d(gap, 1) 99 | gap = self.fc1(gap) 100 | 101 | if self.use_bn: 102 | gap = self.bn1(gap) 103 | gap = self.relu(gap) 104 | 105 | atten = self.fc2(gap).view((batch, self.radix, self.channels)) 106 | if self.radix > 1: 107 | atten = F.softmax(atten, dim=1).view(batch, -1, 1, 1) 108 | else: 109 | atten = F.sigmoid(atten, dim=1).view(batch, -1, 1, 1) 110 | 111 | if self.radix > 1: 112 | atten = torch.split(atten, channel//self.radix, dim=1) 113 | out = sum([att*split for (att, split) in zip(atten, splited)]) 114 | else: 115 | out = atten * x 116 | return out.contiguous() 117 | 118 | class GlobalAvgPool2d(nn.Module): 119 | def __init__(self): 120 | """Global average pooling over the input's spatial dimensions""" 121 | super(GlobalAvgPool2d, self).__init__() 122 | 123 | def forward(self, inputs): 124 | return nn.functional.adaptive_avg_pool2d(inputs, 1).view(inputs.size(0), -1) 125 | 126 | class Bottleneck(nn.Module): 127 | """ResNet Bottleneck 128 | """ 129 | # pylint: disable=unused-argument 130 | expansion = 4 131 | def __init__(self, inplanes, planes, stride=1, downsample=None, 132 | radix=1, cardinality=1, bottleneck_width=64, 133 | avd=False, avd_first=False, dilation=1, is_first=False, 134 | rectified_conv=False, rectify_avg=False, 135 | 
norm_layer=None, dropblock_prob=0.0, last_gamma=False, drop_connection_rate=0.0): 136 | super(Bottleneck, self).__init__() 137 | group_width = int(planes * (bottleneck_width / 64.)) * cardinality 138 | self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) 139 | self.bn1 = norm_layer(group_width) 140 | self.dropblock_prob = dropblock_prob 141 | self.radix = radix 142 | self.avd = avd and (stride > 1 or is_first) 143 | self.avd_first = avd_first 144 | 145 | self.drop_connection_rate = drop_connection_rate 146 | 147 | if self.avd: 148 | self.avd_layer = nn.AvgPool2d(3, stride, padding=1) 149 | stride = 1 150 | 151 | if dropblock_prob > 0.0: 152 | self.dropblock1 = DropBlock2D(dropblock_prob, 3) 153 | if radix == 1: 154 | self.dropblock2 = DropBlock2D(dropblock_prob, 3) 155 | self.dropblock3 = DropBlock2D(dropblock_prob, 3) 156 | 157 | if radix > 1: 158 | self.conv2 = SplAtConv2d( 159 | group_width, group_width, kernel_size=3, 160 | stride=stride, padding=dilation, 161 | dilation=dilation, groups=cardinality, bias=False, 162 | radix=radix, rectify=rectified_conv, 163 | rectify_avg=rectify_avg, 164 | norm_layer=norm_layer, 165 | dropblock_prob=dropblock_prob) 166 | elif rectified_conv: 167 | from rfconv import RFConv2d 168 | self.conv2 = RFConv2d( 169 | group_width, group_width, kernel_size=3, stride=stride, 170 | padding=dilation, dilation=dilation, 171 | groups=cardinality, bias=False, 172 | average_mode=rectify_avg) 173 | self.bn2 = norm_layer(group_width) 174 | else: 175 | self.conv2 = nn.Conv2d( 176 | group_width, group_width, kernel_size=3, stride=stride, 177 | padding=dilation, dilation=dilation, 178 | groups=cardinality, bias=False) 179 | self.bn2 = norm_layer(group_width) 180 | 181 | self.conv3 = nn.Conv2d( 182 | group_width, planes * 4, kernel_size=1, bias=False) 183 | self.bn3 = norm_layer(planes*4) 184 | 185 | if last_gamma: 186 | from torch.nn.init import zeros_ 187 | zeros_(self.bn3.weight) 188 | self.relu = nn.ReLU(inplace=True) 189 | self.downsample = downsample 190 | self.dilation = dilation 191 | self.stride = stride 192 | 193 | def forward(self, x): 194 | residual = x 195 | 196 | out = self.conv1(x) 197 | out = self.bn1(out) 198 | if self.dropblock_prob > 0.0: 199 | out = self.dropblock1(out) 200 | out = self.relu(out) 201 | 202 | if self.avd and self.avd_first: 203 | out = self.avd_layer(out) 204 | 205 | out = self.conv2(out) 206 | if self.radix == 1: 207 | out = self.bn2(out) 208 | if self.dropblock_prob > 0.0: 209 | out = self.dropblock2(out) 210 | out = self.relu(out) 211 | 212 | if self.avd and not self.avd_first: 213 | out = self.avd_layer(out) 214 | 215 | out = self.conv3(out) 216 | out = self.bn3(out) 217 | if self.dropblock_prob > 0.0: 218 | out = self.dropblock3(out) 219 | 220 | if self.downsample is not None: 221 | residual = self.downsample(x) 222 | 223 | if self.drop_connection_rate > 0: 224 | out = drop_connect(out, p=self.drop_connection_rate, training=self.training) + residual 225 | else: 226 | out += residual 227 | out = self.relu(out) 228 | 229 | return out 230 | 231 | class ResNet(nn.Module): 232 | """ResNet Variants 233 | 234 | Parameters 235 | ---------- 236 | block : Block 237 | Class for the residual block. Options are BasicBlockV1, BottleneckV1. 238 | layers : list of int 239 | Numbers of layers in each block 240 | classes : int, default 1000 241 | Number of classification classes. 
242 | dilated : bool, default False 243 | Applying dilation strategy to pretrained ResNet yielding a stride-8 model, 244 | typically used in Semantic Segmentation. 245 | norm_layer : object 246 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; 247 | for Synchronized Cross-GPU BachNormalization). 248 | 249 | Reference: 250 | 251 | - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. 252 | 253 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 254 | """ 255 | # pylint: disable=unused-variable 256 | def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64, 257 | num_classes=1000, dilated=False, dilation=1, 258 | deep_stem=False, stem_width=64, avg_down=False, 259 | rectified_conv=False, rectify_avg=False, 260 | avd=False, avd_first=False, 261 | final_drop=0.0, dropblock_prob=0, 262 | last_gamma=False, norm_layer=nn.BatchNorm2d): 263 | self.cardinality = groups 264 | self.bottleneck_width = bottleneck_width 265 | # ResNet-D params 266 | self.inplanes = stem_width*2 if deep_stem else 64 267 | self.avg_down = avg_down 268 | self.last_gamma = last_gamma 269 | # ResNeSt params 270 | self.radix = radix 271 | self.avd = avd 272 | self.avd_first = avd_first 273 | 274 | self.global_drop_connect_rate = 0.0 275 | 276 | super(ResNet, self).__init__() 277 | self.rectified_conv = rectified_conv 278 | self.rectify_avg = rectify_avg 279 | if rectified_conv: 280 | from rfconv import RFConv2d 281 | conv_layer = RFConv2d 282 | else: 283 | conv_layer = nn.Conv2d 284 | conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {} 285 | if deep_stem: 286 | self.conv1 = nn.Sequential( 287 | conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs), 288 | norm_layer(stem_width), 289 | nn.ReLU(inplace=True), 290 | conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs), 291 | norm_layer(stem_width), 292 | nn.ReLU(inplace=True), 293 | conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs), 294 | ) 295 | else: 296 | self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3, 297 | bias=False, **conv_kwargs) 298 | self.bn1 = norm_layer(self.inplanes) 299 | self.relu = nn.ReLU(inplace=True) 300 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 301 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False) 302 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 303 | if dilated or dilation == 4: 304 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 305 | dilation=2, norm_layer=norm_layer, 306 | dropblock_prob=dropblock_prob, drop_connection_rate=self.global_drop_connect_rate * 0.5) 307 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 308 | dilation=4, norm_layer=norm_layer, 309 | dropblock_prob=dropblock_prob, drop_connection_rate=self.global_drop_connect_rate) 310 | elif dilation==2: 311 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 312 | dilation=1, norm_layer=norm_layer, 313 | dropblock_prob=dropblock_prob, drop_connection_rate=self.global_drop_connect_rate * 0.5) 314 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 315 | dilation=2, norm_layer=norm_layer, 316 | dropblock_prob=dropblock_prob, 
drop_connection_rate=self.global_drop_connect_rate) 317 | else: 318 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 319 | norm_layer=norm_layer, 320 | dropblock_prob=dropblock_prob, drop_connection_rate=self.global_drop_connect_rate * 0.5) 321 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 322 | norm_layer=norm_layer, 323 | dropblock_prob=dropblock_prob, drop_connection_rate=self.global_drop_connect_rate) 324 | self.avgpool = GlobalAvgPool2d() 325 | self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None 326 | self.fc = nn.Linear(512 * block.expansion, num_classes) 327 | 328 | 329 | for m in self.modules(): 330 | if isinstance(m, nn.Conv2d): 331 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 332 | m.weight.data.normal_(0, math.sqrt(2. / n)) 333 | elif isinstance(m, norm_layer): 334 | m.weight.data.fill_(1) 335 | m.bias.data.zero_() 336 | 337 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, 338 | dropblock_prob=0.0, is_first=True, drop_connection_rate=0.0): 339 | downsample = None 340 | if stride != 1 or self.inplanes != planes * block.expansion: 341 | down_layers = [] 342 | if self.avg_down: 343 | if dilation == 1: 344 | down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride, 345 | ceil_mode=True, count_include_pad=False)) 346 | else: 347 | down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1, 348 | ceil_mode=True, count_include_pad=False)) 349 | down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, 350 | kernel_size=1, stride=1, bias=False)) 351 | else: 352 | down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, 353 | kernel_size=1, stride=stride, bias=False)) 354 | down_layers.append(norm_layer(planes * block.expansion)) 355 | downsample = nn.Sequential(*down_layers) 356 | 357 | layers = [] 358 | if dilation == 1 or dilation == 2: 359 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 360 | radix=self.radix, cardinality=self.cardinality, 361 | bottleneck_width=self.bottleneck_width, 362 | avd=self.avd, avd_first=self.avd_first, 363 | dilation=1, is_first=is_first, rectified_conv=self.rectified_conv, 364 | rectify_avg=self.rectify_avg, 365 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 366 | last_gamma=self.last_gamma)) 367 | elif dilation == 4: 368 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 369 | radix=self.radix, cardinality=self.cardinality, 370 | bottleneck_width=self.bottleneck_width, 371 | avd=self.avd, avd_first=self.avd_first, 372 | dilation=2, is_first=is_first, rectified_conv=self.rectified_conv, 373 | rectify_avg=self.rectify_avg, 374 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 375 | last_gamma=self.last_gamma)) 376 | else: 377 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 378 | 379 | self.inplanes = planes * block.expansion 380 | for i in range(1, blocks): 381 | layers.append(block(self.inplanes, planes, 382 | radix=self.radix, cardinality=self.cardinality, 383 | bottleneck_width=self.bottleneck_width, 384 | avd=self.avd, avd_first=self.avd_first, 385 | dilation=dilation, rectified_conv=self.rectified_conv, 386 | rectify_avg=self.rectify_avg, 387 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 388 | last_gamma=self.last_gamma, drop_connection_rate=drop_connection_rate)) 389 | 390 | return nn.Sequential(*layers) 391 | 392 | def forward(self, x): 393 | x = self.conv1(x) 394 | x = self.bn1(x) 395 | x = 
self.relu(x) 396 | x = self.maxpool(x) 397 | 398 | x = self.layer1(x) 399 | x = self.layer2(x) 400 | x = self.layer3(x) 401 | x = self.layer4(x) 402 | 403 | x = self.avgpool(x) 404 | #x = x.view(x.size(0), -1) 405 | x = torch.flatten(x, 1) 406 | if self.drop: 407 | x = self.drop(x) 408 | x = self.fc(x) 409 | 410 | return x 411 | 412 | def extract_features(self, x): 413 | x = self.conv1(x) 414 | x = self.bn1(x) 415 | x = self.relu(x) 416 | x = self.maxpool(x) 417 | 418 | x = self.layer1(x) 419 | x = self.layer2(x) 420 | x = self.layer3(x) 421 | x = self.layer4(x) 422 | return x 423 | 424 | def get_all_features(self, x): 425 | x = self.conv1(x) 426 | x = self.bn1(x) 427 | x = self.relu(x) 428 | x = self.maxpool(x) 429 | 430 | x1 = self.layer1(x) 431 | x2 = self.layer2(x1) 432 | x3 = self.layer3(x2) 433 | x4 = self.layer4(x3) 434 | return x4, x3, x2 435 | 436 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 437 | model = ResNet(Bottleneck, [3, 4, 6, 3], 438 | radix=2, groups=1, bottleneck_width=64, 439 | deep_stem=True, stem_width=32, avg_down=True, 440 | avd=True, avd_first=False, **kwargs) 441 | if pretrained: 442 | model.load_state_dict(torch.hub.load_state_dict_from_url( 443 | resnest_model_urls['resnest50'], progress=True, check_hash=True)) 444 | return model 445 | 446 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): 447 | model = ResNet(Bottleneck, [3, 4, 23, 3], 448 | radix=2, groups=1, bottleneck_width=64, 449 | deep_stem=True, stem_width=64, avg_down=True, 450 | avd=True, avd_first=False, **kwargs) 451 | if pretrained: 452 | model.load_state_dict(torch.hub.load_state_dict_from_url( 453 | resnest_model_urls['resnest101'], progress=True, check_hash=True)) 454 | return model 455 | 456 | def resnest200(pretrained=False, from_moco=False, root='~/.encoding/models', **kwargs): 457 | model = ResNet(Bottleneck, [3, 24, 36, 3], 458 | radix=2, groups=1, bottleneck_width=64, 459 | deep_stem=True, stem_width=64, avg_down=True, 460 | avd=True, avd_first=False, **kwargs) 461 | if pretrained: 462 | model.load_state_dict(torch.hub.load_state_dict_from_url( 463 | resnest_model_urls['resnest200'], progress=True, check_hash=True)) 464 | 465 | if from_moco: 466 | # path = '/home/rym/workspace/fgvc7/moco_pretrained_resnest200.pth.tar' 467 | path = '/home/cgy/Works/fgvc7-master/logs/checkpoint_0059.pth.tar' 468 | sd = torch.load(path,map_location='cpu')['state_dict'] 469 | new_sd = {} 470 | model.fc = nn.Sequential( 471 | nn.Linear(2048, 2048), 472 | nn.ReLU(), 473 | nn.Linear(2048, 128) 474 | ) 475 | for key in sd.keys(): 476 | if 'module.encoder_k' in key: 477 | new_sd[key[17:]] = sd[key] 478 | model.load_state_dict(new_sd) 479 | model.fc = nn.Sequential() 480 | print('===> Successfully loaded ResNeSt-200 MOCO pretrained model') 481 | 482 | return model 483 | 484 | def resnest269(pretrained=False,from_moco=False, root='~/.encoding/models', **kwargs): 485 | model = ResNet(Bottleneck, [3, 30, 48, 8], 486 | radix=2, groups=1, bottleneck_width=64, 487 | deep_stem=True, stem_width=64, avg_down=True, 488 | avd=True, avd_first=False, **kwargs) 489 | if pretrained: 490 | model.load_state_dict(torch.hub.load_state_dict_from_url( 491 | resnest_model_urls['resnest269'], progress=True, check_hash=True)) 492 | 493 | if from_moco: 494 | # path = '/home/rym/workspace/fgvc7/moco_pretrained_resnest200.pth.tar' 495 | path = '/home/cgy/Works/fgvc7-master/logs/moco_pretrained_resnest269.pth.tar' 496 | sd = torch.load(path,map_location='cpu')['state_dict'] 497 | new_sd = {} 498 
| model.fc = nn.Sequential( 499 | nn.Linear(2048, 2048), 500 | nn.ReLU(), 501 | nn.Linear(2048, 128) 502 | ) 503 | for key in sd.keys(): 504 | if 'module.encoder_k' in key: 505 | new_sd[key[17:]] = sd[key] 506 | model.load_state_dict(new_sd) 507 | model.fc = nn.Sequential() 508 | print('===> Successfully loaded ResNeSt-269 MOCO pretrained model') 509 | return model -------------------------------------------------------------------------------- /modeling/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class Bottleneck(nn.Module): 9 | expansion = 4 10 | 11 | def __init__(self, inplanes, planes, stride=1, downsample=None): 12 | super(Bottleneck, self).__init__() 13 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | self.bn2 = nn.BatchNorm2d(planes) 18 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 19 | self.bn3 = nn.BatchNorm2d(planes * 4) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.downsample = downsample 22 | self.stride = stride 23 | 24 | def forward(self, x): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv2(out) 32 | out = self.bn2(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv3(out) 36 | out = self.bn3(out) 37 | 38 | if self.downsample is not None: 39 | residual = self.downsample(x) 40 | 41 | out += residual 42 | out = self.relu(out) 43 | 44 | return out 45 | 46 | 47 | class ResNet(nn.Module): 48 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]): 49 | self.inplanes = 64 50 | super().__init__() 51 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 52 | bias=False) 53 | self.bn1 = nn.BatchNorm2d(64) 54 | #self.relu = nn.ReLU(inplace=True) 55 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 56 | self.layer1 = self._make_layer(block, 64, layers[0]) 57 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 58 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 59 | self.layer4 = self._make_layer( 60 | block, 512, layers[3], stride=last_stride) 61 | 62 | def _make_layer(self, block, planes, blocks, stride=1): 63 | downsample = None 64 | if stride != 1 or self.inplanes != planes * block.expansion: 65 | downsample = nn.Sequential( 66 | nn.Conv2d(self.inplanes, planes * block.expansion, 67 | kernel_size=1, stride=stride, bias=False), 68 | nn.BatchNorm2d(planes * block.expansion), 69 | ) 70 | 71 | layers = [] 72 | layers.append(block(self.inplanes, planes, stride, downsample)) 73 | self.inplanes = planes * block.expansion 74 | for i in range(1, blocks): 75 | layers.append(block(self.inplanes, planes)) 76 | 77 | return nn.Sequential(*layers) 78 | 79 | def forward(self, x): 80 | x = self.conv1(x) 81 | x = self.bn1(x) 82 | #x = self.relu(x) 83 | x = self.maxpool(x) 84 | 85 | x = self.layer1(x) 86 | x = self.layer2(x) 87 | x = self.layer3(x) 88 | x = self.layer4(x) 89 | 90 | return x 91 | 92 | def load_param(self, model_path): 93 | param_dict = torch.load(model_path) 94 | for i in param_dict: 95 | if 'fc' in i: 96 | continue 97 | self.state_dict()[i].copy_(param_dict[i]) 98 | 99 | def random_init(self): 100 | for m in self.modules(): 101 | if isinstance(m, nn.Conv2d): 102 | n = m.kernel_size[0] * 
m.kernel_size[1] * m.out_channels 103 | m.weight.data.normal_(0, math.sqrt(2. / n)) 104 | elif isinstance(m, nn.BatchNorm2d): 105 | m.weight.data.fill_(1) 106 | m.bias.data.zero_() 107 | -------------------------------------------------------------------------------- /modeling/backbones/se_module.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SELayer(nn.Module): 5 | def __init__(self, channel, reduction=16): 6 | super(SELayer, self).__init__() 7 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 8 | self.fc = nn.Sequential( 9 | nn.Linear(channel, channel // reduction, bias=False), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(channel // reduction, channel, bias=False), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def forward(self, x): 16 | b, c, _, _ = x.size() 17 | y = self.avg_pool(x).view(b, c) 18 | y = self.fc(y).view(b, c, 1, 1) 19 | return x * y.expand_as(x) 20 | 21 | 22 | -------------------------------------------------------------------------------- /modeling/backbones/se_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from torchvision.models import ResNet 5 | from .se_module import SELayer 6 | 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 11 | 12 | 13 | class SEBasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16): 17 | super(SEBasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes, 1) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.se = SELayer(planes, reduction) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | out = self.se(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class SEBottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16): 50 | super(SEBottleneck, self).__init__() 51 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 52 | self.bn1 = nn.BatchNorm2d(planes) 53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 54 | padding=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 57 | self.bn3 = nn.BatchNorm2d(planes * 4) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.se = SELayer(planes * 4, reduction) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | out = self.se(out) 77 | 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out += residual 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | class SEResNet(nn.Module): 87 | def 
__init__(self, last_stride=2, block=SEBottleneck, layers=[3, 4, 6, 3]): 88 | self.inplanes = 64 89 | super().__init__() 90 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 91 | bias=False) 92 | self.bn1 = nn.BatchNorm2d(64) 93 | #self.relu = nn.ReLU(inplace=True) 94 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 95 | self.layer1 = self._make_layer(block, 64, layers[0]) 96 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 97 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 98 | self.layer4 = self._make_layer( 99 | block, 512, layers[3], stride=last_stride) 100 | 101 | def _make_layer(self, block, planes, blocks, stride=1): 102 | downsample = None 103 | if stride != 1 or self.inplanes != planes * block.expansion: 104 | downsample = nn.Sequential( 105 | nn.Conv2d(self.inplanes, planes * block.expansion, 106 | kernel_size=1, stride=stride, bias=False), 107 | nn.BatchNorm2d(planes * block.expansion), 108 | ) 109 | 110 | layers = [] 111 | layers.append(block(self.inplanes, planes, stride, downsample)) 112 | self.inplanes = planes * block.expansion 113 | for i in range(1, blocks): 114 | layers.append(block(self.inplanes, planes)) 115 | 116 | return nn.Sequential(*layers) 117 | 118 | def forward(self, x): 119 | x = self.conv1(x) 120 | x = self.bn1(x) 121 | #x = self.relu(x) 122 | x = self.maxpool(x) 123 | 124 | x = self.layer1(x) 125 | x = self.layer2(x) 126 | x = self.layer3(x) 127 | x = self.layer4(x) 128 | 129 | return x 130 | 131 | def load_param(self, model_path): 132 | param_dict = torch.load(model_path) 133 | for i in param_dict: 134 | if 'fc' in i: 135 | continue 136 | self.state_dict()[i].copy_(param_dict[i]) 137 | 138 | def random_init(self): 139 | for m in self.modules(): 140 | if isinstance(m, nn.Conv2d): 141 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 142 | m.weight.data.normal_(0, math.sqrt(2. / n)) 143 | elif isinstance(m, nn.BatchNorm2d): 144 | m.weight.data.fill_(1) 145 | m.bias.data.zero_() 146 | 147 | 148 | 149 | 150 | 151 | def se_resnet18(num_classes=1_000): 152 | """Constructs a ResNet-18 model. 153 | 154 | Args: 155 | pretrained (bool): If True, returns a model pre-trained on ImageNet 156 | """ 157 | model = ResNet(SEBasicBlock, [2, 2, 2, 2], num_classes=num_classes) 158 | model.avgpool = nn.AdaptiveAvgPool2d(1) 159 | return model 160 | 161 | 162 | def se_resnet34(num_classes=1_000): 163 | """Constructs a ResNet-34 model. 164 | 165 | Args: 166 | pretrained (bool): If True, returns a model pre-trained on ImageNet 167 | """ 168 | model = ResNet(SEBasicBlock, [3, 4, 6, 3], num_classes=num_classes) 169 | model.avgpool = nn.AdaptiveAvgPool2d(1) 170 | return model 171 | 172 | 173 | def se_resnet50(num_classes=1_000, pretrained=False): 174 | """Constructs a ResNet-50 model. 175 | 176 | Args: 177 | pretrained (bool): If True, returns a model pre-trained on ImageNet 178 | """ 179 | model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes) 180 | model.avgpool = nn.AdaptiveAvgPool2d(1) 181 | if pretrained: 182 | model.load_state_dict(model_zoo.load_url("https://www.dropbox.com/s/xpq8ne7rwa4kg4c/seresnet50-60a8950a85b2b.pkl")) 183 | return model 184 | 185 | 186 | def se_resnet101(num_classes=1_000): 187 | """Constructs a ResNet-101 model. 
188 | 189 | Args: 190 | pretrained (bool): If True, returns a model pre-trained on ImageNet 191 | """ 192 | model = ResNet(SEBottleneck, [3, 4, 23, 3], num_classes=num_classes) 193 | model.avgpool = nn.AdaptiveAvgPool2d(1) 194 | return model 195 | 196 | 197 | def se_resnet152(num_classes=1_000): 198 | """Constructs a ResNet-152 model. 199 | 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(SEBottleneck, [3, 8, 36, 3], num_classes=num_classes) 204 | model.avgpool = nn.AdaptiveAvgPool2d(1) 205 | return model 206 | 207 | 208 | class CifarSEBasicBlock(nn.Module): 209 | def __init__(self, inplanes, planes, stride=1, reduction=16): 210 | super(CifarSEBasicBlock, self).__init__() 211 | self.conv1 = conv3x3(inplanes, planes, stride) 212 | self.bn1 = nn.BatchNorm2d(planes) 213 | self.relu = nn.ReLU(inplace=True) 214 | self.conv2 = conv3x3(planes, planes) 215 | self.bn2 = nn.BatchNorm2d(planes) 216 | self.se = SELayer(planes, reduction) 217 | if inplanes != planes: 218 | self.downsample = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False), 219 | nn.BatchNorm2d(planes)) 220 | else: 221 | self.downsample = lambda x: x 222 | self.stride = stride 223 | 224 | def forward(self, x): 225 | residual = self.downsample(x) 226 | out = self.conv1(x) 227 | out = self.bn1(out) 228 | out = self.relu(out) 229 | 230 | out = self.conv2(out) 231 | out = self.bn2(out) 232 | out = self.se(out) 233 | 234 | out += residual 235 | out = self.relu(out) 236 | 237 | return out 238 | 239 | 240 | class CifarSEResNet(nn.Module): 241 | def __init__(self, block, n_size, num_classes=10, reduction=16): 242 | super(CifarSEResNet, self).__init__() 243 | self.inplane = 16 244 | self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=3, stride=1, padding=1, bias=False) 245 | self.bn1 = nn.BatchNorm2d(self.inplane) 246 | self.relu = nn.ReLU(inplace=True) 247 | self.layer1 = self._make_layer(block, 16, blocks=n_size, stride=1, reduction=reduction) 248 | self.layer2 = self._make_layer(block, 32, blocks=n_size, stride=2, reduction=reduction) 249 | self.layer3 = self._make_layer(block, 64, blocks=n_size, stride=2, reduction=reduction) 250 | self.avgpool = nn.AdaptiveAvgPool2d(1) 251 | self.fc = nn.Linear(64, num_classes) 252 | self.initialize() 253 | 254 | def initialize(self): 255 | for m in self.modules(): 256 | if isinstance(m, nn.Conv2d): 257 | nn.init.kaiming_normal_(m.weight) 258 | elif isinstance(m, nn.BatchNorm2d): 259 | nn.init.constant_(m.weight, 1) 260 | nn.init.constant_(m.bias, 0) 261 | 262 | def _make_layer(self, block, planes, blocks, stride, reduction): 263 | strides = [stride] + [1] * (blocks - 1) 264 | layers = [] 265 | for stride in strides: 266 | layers.append(block(self.inplane, planes, stride, reduction)) 267 | self.inplane = planes 268 | 269 | return nn.Sequential(*layers) 270 | 271 | def forward(self, x): 272 | x = self.conv1(x) 273 | x = self.bn1(x) 274 | x = self.relu(x) 275 | 276 | x = self.layer1(x) 277 | x = self.layer2(x) 278 | x = self.layer3(x) 279 | 280 | x = self.avgpool(x) 281 | x = x.view(x.size(0), -1) 282 | x = self.fc(x) 283 | 284 | return x 285 | 286 | 287 | class CifarSEPreActResNet(CifarSEResNet): 288 | def __init__(self, block, n_size, num_classes=10, reduction=16): 289 | super(CifarSEPreActResNet, self).__init__(block, n_size, num_classes, reduction) 290 | self.bn1 = nn.BatchNorm2d(self.inplane) 291 | self.initialize() 292 | 293 | def forward(self, x): 294 | x = self.conv1(x) 295 | x = self.layer1(x) 296 | 
x = self.layer2(x) 297 | x = self.layer3(x) 298 | 299 | x = self.bn1(x) 300 | x = self.relu(x) 301 | 302 | x = self.avgpool(x) 303 | x = x.view(x.size(0), -1) 304 | x = self.fc(x) 305 | return x 306 | 307 | def se_resnet20(**kwargs): 308 | """Constructs a SE-ResNet-20 model for CIFAR. 309 | 310 | """ 311 | model = CifarSEResNet(CifarSEBasicBlock, 3, **kwargs) 312 | return model 313 | 314 | 315 | def se_resnet32(**kwargs): 316 | """Constructs a SE-ResNet-32 model for CIFAR. 317 | 318 | """ 319 | model = CifarSEResNet(CifarSEBasicBlock, 5, **kwargs) 320 | return model 321 | 322 | 323 | def se_resnet56(**kwargs): 324 | """Constructs a SE-ResNet-56 model for CIFAR. 325 | 326 | """ 327 | model = CifarSEResNet(CifarSEBasicBlock, 9, **kwargs) 328 | return model 329 | 330 | 331 | def se_preactresnet20(**kwargs): 332 | """Constructs a SE-PreAct-ResNet-20 model for CIFAR. 333 | 334 | """ 335 | model = CifarSEPreActResNet(CifarSEBasicBlock, 3, **kwargs) 336 | return model 337 | 338 | 339 | def se_preactresnet32(**kwargs): 340 | """Constructs a SE-PreAct-ResNet-32 model for CIFAR. 341 | 342 | """ 343 | model = CifarSEPreActResNet(CifarSEBasicBlock, 5, **kwargs) 344 | return model 345 | 346 | 347 | def se_preactresnet56(**kwargs): 348 | """Constructs a SE-PreAct-ResNet-56 model for CIFAR. 349 | 350 | """ 351 | model = CifarSEPreActResNet(CifarSEBasicBlock, 9, **kwargs) 352 | return model 353 | -------------------------------------------------------------------------------- /modeling/baseline.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import sys 6 | import pdb 7 | 8 | from .backbones.se_module import SELayer 9 | from .backbones.inception import BasicConv2d 10 | from .backbones.resnet import ResNet 11 | from .backbones.resnest import resnest50 12 | sys.path.append('.') 13 | 14 | 15 | EPSILON = 1e-12 16 | 17 | 18 | def weights_init_kaiming(m): 19 | classname = m.__class__.__name__ 20 | if classname.find('Linear') != -1: 21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 22 | nn.init.constant_(m.bias, 0.0) 23 | elif classname.find('Conv') != -1: 24 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 25 | if m.bias is not None: 26 | nn.init.constant_(m.bias, 0.0) 27 | elif classname.find('BatchNorm') != -1: 28 | if m.affine: 29 | nn.init.constant_(m.weight, 1.0) 30 | nn.init.constant_(m.bias, 0.0) 31 | 32 | 33 | def weights_init_classifier(m): 34 | classname = m.__class__.__name__ 35 | if classname.find('Linear') != -1: 36 | nn.init.normal_(m.weight, std=0.001) 37 | if m.bias: 38 | nn.init.constant_(m.bias, 0.0) 39 | 40 | 41 | class SAMS(nn.Module): 42 | """ 43 | Split-Attend-Merge-Stack agent 44 | Given an input feature map with shape H*W*C, we first split the feature maps into 45 | multiple parts, obtain the attention map of each part, and the attention map 46 | for the current pyramid level is constructed by merging each attention map.
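A minimal usage sketch (the concrete sizes below are illustrative assumptions that mirror how Baseline builds this block at pyramid level 2, i.e. SAMS(256, int(256/2), radix=2); they are not prescribed by the module itself):
    >>> sams = SAMS(in_channels=256, channels=128, radix=2)
    >>> x = torch.randn(8, 256, 16, 8)   # (batch, channels, height, width) feature map
    >>> out = sams(x)                    # split into `radix` parts, attend, merge
    >>> out.shape                        # channel layout is preserved
    torch.Size([8, 256, 16, 8])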
47 | """ 48 | def __init__(self, in_channels, channels, 49 | radix=4, reduction_factor=4, 50 | norm_layer=nn.BatchNorm2d): 51 | super(SAMS, self).__init__() 52 | inter_channels = max(in_channels*radix//reduction_factor, 32) 53 | self.radix = radix 54 | self.channels = channels 55 | self.relu = nn.ReLU(inplace=True) 56 | self.fc1 = nn.Conv2d(channels, inter_channels, 1, groups=1) 57 | self.bn1 = norm_layer(inter_channels) 58 | self.fc2 = nn.Conv2d(inter_channels, channels*radix, 1, groups=1) 59 | 60 | 61 | def forward(self, x): 62 | 63 | batch, channel = x.shape[:2] 64 | splited = torch.split(x, channel//self.radix, dim=1) 65 | 66 | gap = sum(splited) 67 | gap = F.adaptive_avg_pool2d(gap, 1) 68 | gap = self.fc1(gap) 69 | gap = self.bn1(gap) 70 | gap = self.relu(gap) 71 | 72 | atten = self.fc2(gap).view((batch, self.radix, self.channels)) 73 | atten = F.softmax(atten, dim=1).view(batch, -1, 1, 1) 74 | atten = torch.split(atten, channel//self.radix, dim=1) 75 | 76 | out= torch.cat([att*split for (att, split) in zip(atten, splited)],1) 77 | return out.contiguous() 78 | 79 | 80 | class SELayer(nn.Module): 81 | def __init__(self, channel, reduction=16): 82 | super(SELayer, self).__init__() 83 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 84 | self.fc = nn.Sequential( 85 | nn.Linear(channel, channel // reduction, bias=False), 86 | nn.ReLU(inplace=True), 87 | nn.Linear(channel // reduction, channel, bias=False), 88 | nn.Sigmoid() 89 | ) 90 | 91 | def forward(self, x): 92 | b, c, _, _ = x.size() 93 | y = self.avg_pool(x).view(b, c) 94 | y = self.fc(y).view(b, c, 1, 1) 95 | return y 96 | 97 | class BN2d(nn.Module): 98 | def __init__(self, planes): 99 | super(BN2d, self).__init__() 100 | self.bottleneck2 = nn.BatchNorm2d(planes) 101 | self.bottleneck2.bias.requires_grad_(False) # no shift 102 | self.bottleneck2.apply(weights_init_kaiming) 103 | 104 | def forward(self, x): 105 | return self.bottleneck2(x) 106 | 107 | 108 | class Baseline(nn.Module): 109 | in_planes = 2048 110 | 111 | def __init__(self, num_classes, last_stride, model_path,level,msmt): 112 | super(Baseline, self).__init__() 113 | print(f"Training with pyramid level {level}") 114 | self.level = level 115 | self.is_msmt = msmt 116 | self.base = ResNet(last_stride= last_stride) 117 | 118 | 119 | self.base.load_param(model_path) 120 | self.base_1 = nn.Sequential(*list(self.base.children())[0:3]) 121 | self.base_2 = nn.Sequential(*list(self.base.children())[3:4]) 122 | self.base_3 = nn.Sequential(*list(self.base.children())[4:5]) 123 | self.base_4 = nn.Sequential(*list(self.base.children())[5:6]) 124 | self.base_5 = nn.Sequential(*list(self.base.children())[6:]) 125 | 126 | 127 | if self.level > 0: 128 | self.att1 = SELayer(64,8) 129 | self.att2 = SELayer(256,32) 130 | self.att3 = SELayer(512,64) 131 | self.att4 = SELayer(1024,128) 132 | self.att5 = SELayer(2048,256) 133 | if self.level > 1: # second pyramid level 134 | self.att_s1=SAMS(64,int(64/self.level),radix=self.level) 135 | self.att_s2=SAMS(256,int(256/self.level),radix=self.level) 136 | self.att_s3=SAMS(512,int(512/self.level),radix=self.level) 137 | self.att_s4=SAMS(1024,int(1024/self.level),radix=self.level) 138 | self.att_s5=SAMS(2048,int(2048/self.level),radix=self.level) 139 | self.BN1 = BN2d(64) 140 | self.BN2 = BN2d(256) 141 | self.BN3 = BN2d(512) 142 | self.BN4 = BN2d(1024) 143 | self.BN5 = BN2d(2048) 144 | 145 | if self.level > 2: 146 | self.att_ss1=SAMS(64,int(64/self.level),radix=self.level) 147 | self.att_ss2=SAMS(256,int(256/self.level),radix=self.level) 148 | 
self.att_ss3=SAMS(512,int(512/self.level),radix=self.level) 149 | self.att_ss4=SAMS(1024,int(1024/self.level),radix=self.level) 150 | self.att_ss5=SAMS(2048,int(2048/self.level),radix=self.level) 151 | self.BN_1 = BN2d(64) 152 | self.BN_2 = BN2d(256) 153 | self.BN_3 = BN2d(512) 154 | self.BN_4 = BN2d(1024) 155 | self.BN_5 = BN2d(2048) 156 | if self.level > 3: 157 | raise RuntimeError("We do not support pyramid level greater than three.") 158 | 159 | self.gap = nn.AdaptiveAvgPool2d(1) 160 | self.num_classes = num_classes 161 | 162 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 163 | self.bottleneck.bias.requires_grad_(False) # no shift 164 | self.bottleneck.apply(weights_init_kaiming) 165 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) 166 | self.classifier.apply(weights_init_classifier) 167 | 168 | 169 | def forward(self, x): 170 | 171 | 172 | #pdb.set_trace() 173 | x = self.base_1(x) 174 | if self.level > 2: 175 | x = self.att_ss1(x) 176 | x = self.BN_1(x) 177 | if self.level > 1: 178 | x = self.att_s1(x) 179 | x = self.BN1(x) 180 | if self.level > 0: 181 | y = self.att1(x) 182 | x=x*y.expand_as(x) 183 | 184 | 185 | x = self.base_2(x) 186 | if self.level > 2: 187 | x = self.att_ss2(x) 188 | x = self.BN_2(x) 189 | if self.level > 1: 190 | x = self.att_s2(x) 191 | x = self.BN2(x) 192 | if self.level > 0: 193 | y = self.att2(x) 194 | x=x*y.expand_as(x) 195 | 196 | 197 | x = self.base_3(x) 198 | if self.level > 2: 199 | x = self.att_ss3(x) 200 | x = self.BN_3(x) 201 | if self.level > 1: 202 | x = self.att_s3(x) 203 | x = self.BN3(x) 204 | if self.level > 0: 205 | y = self.att3(x) 206 | x=x*y.expand_as(x) 207 | 208 | x = self.base_4(x) 209 | if self.level > 2: 210 | x = self.att_ss4(x) 211 | x = self.BN_4(x) 212 | if self.level > 1: 213 | x = self.att_s4(x) 214 | x = self.BN4(x) 215 | if self.level > 0: 216 | y = self.att4(x) 217 | x=x*y.expand_as(x) 218 | 219 | 220 | x = self.base_5(x) 221 | if self.level > 2: 222 | x = self.att_ss5(x) 223 | x = self.BN_5(x) 224 | if self.level > 1: 225 | x = self.att_s5(x) 226 | x = self.BN5(x) 227 | if self.level > 0: 228 | y = self.att5(x) 229 | x=x*y.expand_as(x) 230 | 231 | 232 | global_feat = self.gap(x) # (b, 2048, 1, 1) 233 | global_feat = global_feat.view(global_feat.shape[0], -1) # flatten to (bs, 2048) 234 | 235 | feat = self.bottleneck(global_feat) # normalize for angular softmax 236 | 237 | if self.training: 238 | cls_score = self.classifier(feat) 239 | 240 | return cls_score, global_feat # global feature for triplet loss 241 | else: 242 | if self.is_msmt: 243 | return self.classifier(feat) 244 | else: 245 | return feat 246 | 247 | # return self.classifier(feat) 248 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-ignite 2 | yacs 3 | mat4py 4 | matplotlib 5 | 6 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .build import make_optimizer 3 | from .lr_scheduler import WarmupMultiStepLR,WarmupStepLR 4 | -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/solver/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/solver/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/solver/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/solver/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/solver/__pycache__/lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/solver/__pycache__/lr_scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /solver/build.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | 5 | def make_optimizer(cfg, model): 6 | params = [] 7 | for key, value in model.named_parameters(): 8 | if not value.requires_grad: 9 | continue 10 | lr = cfg.SOLVER.BASE_LR 11 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 12 | if "bias" in key: 13 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 14 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 15 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 16 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 17 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 18 | else: 19 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 20 | return optimizer 21 | -------------------------------------------------------------------------------- /solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | 2 | from bisect import bisect_right 3 | import torch 4 | 5 | 6 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 7 | # separating MultiStepLR with WarmupLR 8 | # but the current LRScheduler design doesn't allow it 9 | 10 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 11 | def __init__( 12 | self, 13 | optimizer, 14 | milestones, 15 | gamma=0.1, 16 | warmup_factor=1.0 / 3, 17 | warmup_iters=500, 18 | warmup_method="linear", 19 | last_epoch=-1, 20 | ): 21 | if not list(milestones) == sorted(milestones): 22 | raise ValueError( 23 | "Milestones should be a list of" " increasing integers. 
Got {}", 24 | milestones, 25 | ) 26 | 27 | if warmup_method not in ("constant", "linear"): 28 | raise ValueError( 29 | "Only 'constant' or 'linear' warmup_method accepted" 30 | "got {}".format(warmup_method) 31 | ) 32 | self.milestones = milestones 33 | self.gamma = gamma 34 | self.warmup_factor = warmup_factor 35 | self.warmup_iters = warmup_iters 36 | self.warmup_method = warmup_method 37 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 38 | 39 | def get_lr(self): 40 | warmup_factor = 1 41 | if self.last_epoch < self.warmup_iters: 42 | if self.warmup_method == "constant": 43 | warmup_factor = self.warmup_factor 44 | elif self.warmup_method == "linear": 45 | alpha = self.last_epoch / self.warmup_iters 46 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 47 | return [ 48 | base_lr 49 | * warmup_factor 50 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 51 | for base_lr in self.base_lrs 52 | ] 53 | 54 | 55 | class WarmupStepLR(torch.optim.lr_scheduler._LRScheduler): 56 | def __init__( 57 | self, 58 | optimizer, 59 | step_size=2, 60 | gamma=0.9, 61 | warmup_factor=1.0 / 3, 62 | warmup_iters=500, 63 | warmup_method="linear", 64 | last_epoch=-1, 65 | ): 66 | # if not list(milestones) == sorted(milestones): 67 | # raise ValueError( 68 | # "Milestones should be a list of" " increasing integers. Got {}", 69 | # milestones, 70 | # ) 71 | 72 | if warmup_method not in ("constant", "linear"): 73 | raise ValueError( 74 | "Only 'constant' or 'linear' warmup_method accepted" 75 | "got {}".format(warmup_method) 76 | ) 77 | self.step_size = step_size 78 | self.gamma = gamma 79 | self.warmup_factor = warmup_factor 80 | self.warmup_iters = warmup_iters 81 | self.warmup_method = warmup_method 82 | super(WarmupStepLR, self).__init__(optimizer, last_epoch) 83 | 84 | def get_lr(self): 85 | warmup_factor = 1 86 | if self.last_epoch < self.warmup_iters: 87 | if self.warmup_method == "constant": 88 | warmup_factor = self.warmup_factor 89 | elif self.warmup_method == "linear": 90 | alpha = self.last_epoch / self.warmup_iters 91 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 92 | return [ 93 | base_lr 94 | * warmup_factor 95 | * self.gamma ** (self.last_epoch // self.step_size) 96 | for base_lr in self.base_lrs 97 | ] 98 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1 python3 tools/test.py --config_file='configs/default.yml' TEST.WEIGHT 'WEIGHT_PATH' 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | -------------------------------------------------------------------------------- /tests/lr_scheduler_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import unittest 3 | 4 | import torch 5 | from torch import nn 6 | 7 | sys.path.append('.') 8 | from solver.lr_scheduler import WarmupMultiStepLR 9 | from solver.build import make_optimizer 10 | from config import cfg 11 | 12 | 13 | class MyTestCase(unittest.TestCase): 14 | def test_something(self): 15 | net = nn.Linear(10, 10) 16 | optimizer = make_optimizer(cfg, net) 17 | lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=10) 18 | for i in range(50): 19 | lr_scheduler.step() 20 | for j in range(3): 21 | print(i, 
lr_scheduler.get_lr()[0]) 22 | optimizer.step() 23 | 24 | 25 | if __name__ == '__main__': 26 | unittest.main() 27 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import os 4 | import sys 5 | from os import mkdir 6 | 7 | import torch 8 | from torch import nn 9 | from torch.backends import cudnn 10 | 11 | sys.path.append('.') 12 | from config import cfg 13 | from data import make_data_loader 14 | from engine.inference import inference 15 | from modeling import build_model 16 | from utils.logger import setup_logger 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser(description="ReID Baseline Inference") 21 | parser.add_argument( 22 | "--config_file", default="", help="path to config file", type=str 23 | ) 24 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 25 | nargs=argparse.REMAINDER) 26 | 27 | args = parser.parse_args() 28 | 29 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 30 | 31 | if args.config_file != "": 32 | cfg.merge_from_file(args.config_file) 33 | cfg.merge_from_list(args.opts) 34 | cfg.freeze() 35 | 36 | output_dir = cfg.OUTPUT_DIR 37 | if output_dir and not os.path.exists(output_dir): 38 | mkdir(output_dir) 39 | 40 | logger = setup_logger("reid_baseline", output_dir, 0) 41 | logger.info("Using {} GPUS".format(num_gpus)) 42 | logger.info(args) 43 | 44 | if args.config_file != "": 45 | logger.info("Loaded configuration file {}".format(args.config_file)) 46 | with open(args.config_file, 'r') as cf: 47 | config_str = "\n" + cf.read() 48 | logger.info(config_str) 49 | logger.info("Running with config:\n{}".format(cfg)) 50 | 51 | cudnn.benchmark = True 52 | 53 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) 54 | 55 | model = build_model(cfg, num_classes) # 56 | # import pdb 57 | # pdb.set_trace() 58 | model.load_state_dict(torch.load(cfg.TEST.WEIGHT)["model"]) 59 | model = nn.DataParallel(model) 60 | # model.load_state_dict(torch.load(cfg.TEST.WEIGHT)) 61 | 62 | #total_loader = torch.utils.data.ConcatDataset([train_loader, val_loader]) 63 | 64 | inference(cfg, model, val_loader, num_query) 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import numpy as np 7 | import random 8 | import torch 9 | 10 | from torch.backends import cudnn 11 | from torch import nn 12 | sys.path.append('.') 13 | from config import cfg 14 | import pdb 15 | from data import make_data_loader 16 | from engine.trainer import do_train 17 | from modeling import build_model 18 | from layers import make_loss 19 | from solver import make_optimizer, WarmupMultiStepLR,WarmupStepLR 20 | 21 | from utils.logger import setup_logger 22 | 23 | def setup_seed(seed): 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | np.random.seed(seed) 28 | random.seed(seed) 29 | 30 | 31 | 32 | 33 | def train(cfg): 34 | 35 | # prepare dataset 36 | 37 | 
train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) 38 | 39 | # prepare model 40 | model = build_model(cfg, num_classes) 41 | if cfg.SOLVER.FINETUNE: 42 | model.load_state_dict(torch.load(cfg.TEST.WEIGHT).module.state_dict()) 43 | model = nn.DataParallel(model) 44 | 45 | 46 | optimizer = make_optimizer(cfg, model) 47 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 48 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 49 | # scheduler = WarmupStepLR(optimizer,3, 9, cfg.SOLVER.WARMUP_FACTOR, 50 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 51 | 52 | loss_func = make_loss(cfg) 53 | 54 | arguments = {} 55 | 56 | do_train( 57 | cfg, 58 | model, 59 | train_loader, 60 | val_loader, 61 | optimizer, 62 | scheduler, 63 | loss_func, 64 | num_query 65 | ) 66 | 67 | 68 | def main(): 69 | setup_seed(1) 70 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 71 | parser.add_argument( 72 | "--config_file", default="", help="path to config file", type=str 73 | ) 74 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 75 | nargs=argparse.REMAINDER) 76 | 77 | args = parser.parse_args() 78 | 79 | # num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 80 | 81 | if args.config_file != "": 82 | cfg.merge_from_file(args.config_file) 83 | cfg.merge_from_list(args.opts) 84 | cfg.freeze() 85 | 86 | output_dir = cfg.OUTPUT_DIR 87 | if output_dir and not os.path.exists(output_dir): 88 | os.makedirs(output_dir) 89 | 90 | logger = setup_logger("reid_baseline", output_dir, 0) 91 | logger.info(args) 92 | 93 | if args.config_file != "": 94 | logger.info("Loaded configuration file {}".format(args.config_file)) 95 | with open(args.config_file, 'r') as cf: 96 | config_str = "\n" + cf.read() 97 | logger.info(config_str) 98 | logger.info("Running with config:\n{}".format(cfg)) 99 | 100 | cudnn.benchmark = True 101 | train(cfg) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1 python3 tools/train.py --config_file='configs/msmt.yml' 2 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/utils/__pycache__/iotools.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/utils/__pycache__/logger.cpython-37.pyc 
-------------------------------------------------------------------------------- /utils/__pycache__/re_ranking.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/utils/__pycache__/re_ranking.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/reid_metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CHENGY12/APNet/95fc5b5893d562fe57600f9a1df589cf3711ee7b/utils/__pycache__/reid_metric.cpython-37.pyc -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import errno 4 | import json 5 | import os 6 | 7 | import os.path as osp 8 | 9 | 10 | def mkdir_if_missing(directory): 11 | if not osp.exists(directory): 12 | try: 13 | os.makedirs(directory) 14 | except OSError as e: 15 | if e.errno != errno.EEXIST: 16 | raise 17 | 18 | 19 | def check_isfile(path): 20 | isfile = osp.isfile(path) 21 | if not isfile: 22 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 23 | return isfile 24 | 25 | 26 | def read_json(fpath): 27 | with open(fpath, 'r') as f: 28 | obj = json.load(f) 29 | return obj 30 | 31 | 32 | def write_json(obj, fpath): 33 | mkdir_if_missing(osp.dirname(fpath)) 34 | with open(fpath, 'w') as f: 35 | json.dump(obj, f, indent=4, separators=(',', ': ')) 36 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import logging 4 | import os 5 | import sys 6 | 7 | 8 | def setup_logger(name, save_dir, distributed_rank): 9 | logger = logging.getLogger(name) 10 | logger.setLevel(logging.DEBUG) 11 | # don't log results for the non-master process 12 | if distributed_rank > 0: 13 | return logger 14 | ch = logging.StreamHandler(stream=sys.stdout) 15 | ch.setLevel(logging.DEBUG) 16 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 17 | ch.setFormatter(formatter) 18 | logger.addHandler(ch) 19 | 20 | if save_dir: 21 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w') 22 | fh.setLevel(logging.DEBUG) 23 | fh.setFormatter(formatter) 24 | logger.addHandler(fh) 25 | 26 | return logger 27 | -------------------------------------------------------------------------------- /utils/re_ranking.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 5 | 6 | # The following naming, e.g. gallery_num, is different from outer scope. 7 | # Don't care about it. 8 | 9 | original_dist = np.concatenate( 10 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 11 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 12 | axis=0) 13 | original_dist = np.power(original_dist, 2).astype(np.float32) 14 | original_dist = np.transpose(1. 
* original_dist/np.max(original_dist,axis = 0)) 15 | V = np.zeros_like(original_dist).astype(np.float32) 16 | initial_rank = np.argsort(original_dist).astype(np.int32) 17 | 18 | query_num = q_g_dist.shape[0] 19 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 20 | all_num = gallery_num 21 | 22 | for i in range(all_num): 23 | # k-reciprocal neighbors 24 | forward_k_neigh_index = initial_rank[i,:k1+1] 25 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 26 | fi = np.where(backward_k_neigh_index==i)[0] 27 | k_reciprocal_index = forward_k_neigh_index[fi] 28 | k_reciprocal_expansion_index = k_reciprocal_index 29 | for j in range(len(k_reciprocal_index)): 30 | candidate = k_reciprocal_index[j] 31 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1] 32 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1] 33 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 34 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 35 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index): 36 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 37 | 38 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 39 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 40 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) 41 | original_dist = original_dist[:query_num,] 42 | if k2 != 1: 43 | V_qe = np.zeros_like(V,dtype=np.float32) 44 | for i in range(all_num): 45 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 46 | V = V_qe 47 | del V_qe 48 | del initial_rank 49 | invIndex = [] 50 | for i in range(gallery_num): 51 | invIndex.append(np.where(V[:,i] != 0)[0]) 52 | 53 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32) 54 | 55 | 56 | for i in range(query_num): 57 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32) 58 | indNonZero = np.where(V[i,:] != 0)[0] 59 | indImages = [] 60 | indImages = [invIndex[ind] for ind in indNonZero] 61 | for j in range(len(indNonZero)): 62 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 63 | jaccard_dist[i] = 1-temp_min/(2.-temp_min) 64 | 65 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 66 | del original_dist 67 | del V 68 | del jaccard_dist 69 | final_dist = final_dist[:query_num,query_num:] 70 | return final_dist 71 | -------------------------------------------------------------------------------- /utils/reid_metric.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | 4 | import numpy as np 5 | import torch 6 | from ignite.metrics import Metric 7 | import pickle 8 | import time 9 | import pdb 10 | 11 | from data.datasets.eval_reid import eval_func 12 | from utils.re_ranking import re_ranking 13 | 14 | 15 | def euclidean_dist(x, y): 16 | """ 17 | Args: 18 | x: pytorch Variable, with shape [m, d] 19 | y: pytorch Variable, with shape [n, d] 20 | Returns: 21 | dist: pytorch Variable, with shape [m, n] 22 | """ 23 | m, n = x.size(0), y.size(0) 24 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 25 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 26 | dist = xx + yy 27 | dist.addmm_(1, -2, x, y.t()) 28 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 29 | return dist 30 | 31 
| def euclidean_dist_cpu(x, y): 32 | m, n = x.shape[0], y.shape[0] 33 | xx = np.power(x, 2).sum(1) 34 | xx = np.reshape(xx, [xx.shape[0], 1]) 35 | xx = xx.repeat(n, axis=1) 36 | yy = np.power(y, 2).sum(1) 37 | yy = np.reshape(yy, [yy.shape[0], 1]) 38 | yy = yy.repeat(m, axis=1).T 39 | dist = xx + yy 40 | dist -= 2 * np.dot(x, y.T) 41 | dist = np.sqrt(np.clip(dist, 1e-12, dist.max())) 42 | return dist 43 | 44 | def cos_dist(x, y): 45 | """ 46 | Args: 47 | x: pytorch Variable, with shape [m, d] 48 | y: pytorch Variable, with shape [n, d] 49 | Returns: 50 | dist: pytorch Variable, with shape [m, n] 51 | """ 52 | xx = x/x.norm(dim=1)[:,None] 53 | yy = y/y.norm(dim=1)[:,None] 54 | dist = torch.mm(xx,yy.t()) 55 | # dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 56 | return 1-dist 57 | 58 | class R1_mAP(Metric): 59 | def __init__(self, num_query, max_rank=50, re_rank = False): 60 | super(R1_mAP, self).__init__() 61 | self.num_query = num_query 62 | self.max_rank = max_rank 63 | self.re_rank = re_rank 64 | self.count = 0 65 | 66 | 67 | 68 | def reset(self): 69 | self.feats = [] 70 | self.pids = [] 71 | self.camids = [] 72 | 73 | def update(self, output): 74 | feat, pid, camid = output 75 | self.feats.append(feat) 76 | self.pids.extend(np.asarray(pid)) 77 | self.camids.extend(np.asarray(camid)) 78 | 79 | def compute(self): 80 | # f = open(self.pkl_path, "rb") 81 | # feats = pickle.load(f) 82 | 83 | 84 | feats = torch.cat(self.feats, dim=0) 85 | fnorm = torch.norm(feats,p=2,dim=1,keepdim=True) 86 | feats = feats.div(fnorm.expand_as(feats)) 87 | # # query 88 | qf = feats[:self.num_query] 89 | q_pids = np.asarray(self.pids[:self.num_query]) 90 | q_camids = np.asarray(self.camids[:self.num_query]) 91 | # gallery 92 | gf = feats[self.num_query:] 93 | g_pids = np.asarray(self.pids[self.num_query:]) 94 | g_camids = np.asarray(self.camids[self.num_query:]) 95 | 96 | m, n = qf.shape[0], gf.shape[0] 97 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 98 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 99 | distmat.addmm_(1, -2, qf, gf.t()) 100 | 101 | qf = qf.cpu().numpy() 102 | gf = gf.cpu().numpy() 103 | 104 | 105 | 106 | distmat_q_g = euclidean_dist_cpu(qf,gf) 107 | 108 | 109 | # raw_data = {"qf":qf, 110 | # "gf":gf, 111 | # #"distmat_q_g":distmat_q_g, 112 | # "q_pids":q_pids, 113 | # "g_pids":g_pids, 114 | # "q_camids":q_camids, 115 | # "g_camids":g_camids 116 | # } 117 | # #save distmat 118 | # f = open('/home/gtp_cgy/ivg/dataset/LRR/msmt_train.pkl','wb+') 119 | # pickle.dump(raw_data,f) 120 | # f.close() 121 | 122 | pids = np.asarray(self.pids) 123 | camids = np.asarray(self.camids) 124 | 125 | #print(len(pids)) 126 | 127 | 128 | # raw_data = { 129 | # "feats": feats, 130 | # "pids": pids, 131 | # "camids": camids 132 | # } 133 | # 134 | # 135 | 136 | 137 | # exit() 138 | 139 | 140 | start = time.time() 141 | if self.re_rank: 142 | #distmat_cos = cos_dist(qf,gf) 143 | distmat_q_q = euclidean_dist_cpu(qf,qf) 144 | distmat_g_g = euclidean_dist_cpu(gf,gf) 145 | # distmat_q_q = distmat_q_q.cpu().numpy() 146 | # distmat_g_g = distmat_g_g.cpu().numpy() 147 | distmat = re_ranking(distmat_q_g,distmat_q_q,distmat_g_g ) 148 | duration = time.time()-start 149 | print(f"Re-ranking runing in {duration}") 150 | else: 151 | distmat = distmat_q_g 152 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 153 | 154 | return cmc, mAP 155 | --------------------------------------------------------------------------------
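The short script below is a minimal end-to-end sketch (not a file from this repository) of how the evaluation utilities above fit together; it mirrors the distance-matrix / optional re-ranking / eval_func path taken inside R1_mAP.compute. The feature matrices and labels are random placeholders whose only purpose is to show the call signatures, and it assumes the script is launched from the repository root (as test.sh does) so that the data and utils packages are importable.

import numpy as np

from data.datasets.eval_reid import eval_func
from utils.re_ranking import re_ranking
from utils.reid_metric import euclidean_dist_cpu

# Placeholder query/gallery features and labels; in practice they come from the trained model.
qf = np.random.rand(10, 2048).astype(np.float32)    # 10 query feature vectors
gf = np.random.rand(100, 2048).astype(np.float32)   # 100 gallery feature vectors
q_pids = np.random.randint(0, 5, 10)                # person identities
g_pids = np.random.randint(0, 5, 100)
q_camids = np.zeros(10, dtype=int)                  # queries from camera 0 ...
g_camids = np.ones(100, dtype=int)                  # ... gallery from camera 1

# L2-normalise the features, as R1_mAP.compute does before computing distances.
qf = qf / np.linalg.norm(qf, axis=1, keepdims=True)
gf = gf / np.linalg.norm(gf, axis=1, keepdims=True)

distmat = euclidean_dist_cpu(qf, gf)                # plain query-gallery Euclidean distances
# Optional k-reciprocal re-ranking, mirroring the re_rank=True branch of R1_mAP.
distmat = re_ranking(distmat, euclidean_dist_cpu(qf, qf), euclidean_dist_cpu(gf, gf))

cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
print("Rank-1: {:.3f}, mAP: {:.3f}".format(cmc[0], mAP))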