├── utils ├── re.txt ├── __init__.py ├── __pycache__ │ ├── logger.cpython-35.pyc │ ├── __init__.cpython-35.pyc │ ├── iotools.cpython-35.pyc │ ├── re_ranking.cpython-35.pyc │ └── reid_metric.cpython-35.pyc ├── logger.py ├── iotools.py ├── reid_metric.py └── re_ranking.py ├── fig ├── frame.jpg └── challenge.jpg ├── tests ├── __init__.py └── lr_scheduler_test.py ├── modeling ├── backbones │ ├── __init__.py │ ├── __pycache__ │ │ ├── resnet.cpython-35.pyc │ │ ├── senet.cpython-35.pyc │ │ ├── __init__.cpython-35.pyc │ │ ├── densenet.cpython-35.pyc │ │ ├── inception.cpython-35.pyc │ │ ├── mobilenet.cpython-35.pyc │ │ └── squeezenet.cpython-35.pyc │ ├── resnet.py │ ├── squeezenet.py │ ├── mobilenet.py │ ├── ecanet.py │ └── densenet.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ └── baseline.cpython-35.pyc ├── __init__.py ├── baseline1.py └── baseline2.py ├── data ├── __pycache__ │ ├── build.cpython-35.pyc │ ├── __init__.cpython-35.pyc │ └── collate_batch.cpython-35.pyc ├── __init__.py ├── datasets │ ├── __pycache__ │ │ ├── bases.cpython-35.pyc │ │ ├── cuhk03.cpython-35.pyc │ │ ├── irour.cpython-35.pyc │ │ ├── msmt17.cpython-35.pyc │ │ ├── rgbir.cpython-35.pyc │ │ ├── rgbour.cpython-35.pyc │ │ ├── veri.cpython-35.pyc │ │ ├── __init__.cpython-35.pyc │ │ ├── cityflow.cpython-35.pyc │ │ ├── eval_reid.cpython-35.pyc │ │ ├── market1501.cpython-35.pyc │ │ ├── vehicleid.cpython-35.pyc │ │ ├── dukemtmcreid.cpython-35.pyc │ │ └── dataset_loader.cpython-35.pyc │ ├── __init__.py │ ├── dataset_loader.py │ ├── eval_reid.py │ ├── market1501.py │ ├── msmt17.py │ ├── irour.py │ ├── rgbir.py │ ├── veri.py │ ├── cityflow.py │ ├── rgbour.py │ ├── VehicleIDDataset.py │ ├── bases.py │ ├── dukemtmcreid.py │ └── vehicleid.py ├── samplers │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ └── triplet_sampler.cpython-35.pyc │ ├── __init__.py │ └── triplet_sampler.py ├── transforms │ ├── __pycache__ │ │ ├── build.cpython-35.pyc │ │ ├── __init__.cpython-35.pyc │ │ └── transforms.cpython-35.pyc 
│ ├── __init__.py │ ├── build.py │ └── transforms.py ├── collate_batch.py └── build.py ├── solver ├── __pycache__ │ ├── build.cpython-35.pyc │ ├── __init__.cpython-35.pyc │ └── lr_scheduler.cpython-35.pyc ├── __init__.py ├── build.py └── lr_scheduler.py ├── config ├── __pycache__ │ ├── __init__.cpython-35.pyc │ └── defaults.cpython-35.pyc ├── __init__.py └── defaults.py ├── engine ├── __pycache__ │ ├── trainer.cpython-35.pyc │ └── inference.cpython-35.pyc └── inference.py ├── layers ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── range_loss.cpython-35.pyc │ ├── center_loss.cpython-35.pyc │ ├── cluster_loss.cpython-35.pyc │ └── triplet_loss.cpython-35.pyc ├── center_loss.py ├── triplet_loss.py ├── __init__.py ├── range_loss.py └── cluster_loss.py ├── softmax_triplet.yml ├── test.py ├── README.md └── train.py /utils/re.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fig/frame.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/fig/frame.jpg -------------------------------------------------------------------------------- /fig/challenge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/fig/challenge.jpg -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: 
sherlockliao01@gmail.com 5 | """ 6 | 7 | -------------------------------------------------------------------------------- /modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | -------------------------------------------------------------------------------- /data/__pycache__/build.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/__pycache__/build.cpython-35.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/solver/__pycache__/build.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/utils/__pycache__/logger.cpython-35.pyc -------------------------------------------------------------------------------- /config/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/config/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /config/__pycache__/defaults.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/config/__pycache__/defaults.cpython-35.pyc 
-------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /engine/__pycache__/trainer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/engine/__pycache__/trainer.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/layers/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/solver/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/utils/__pycache__/iotools.cpython-35.pyc 
-------------------------------------------------------------------------------- /engine/__pycache__/inference.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/engine/__pycache__/inference.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/range_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/layers/__pycache__/range_loss.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/modeling/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/__pycache__/baseline.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/modeling/__pycache__/baseline.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/re_ranking.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/utils/__pycache__/re_ranking.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/reid_metric.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/utils/__pycache__/reid_metric.cpython-35.pyc 
-------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import make_data_loader 8 | -------------------------------------------------------------------------------- /data/__pycache__/collate_batch.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/__pycache__/collate_batch.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/bases.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/bases.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/cuhk03.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/cuhk03.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/irour.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/irour.cpython-35.pyc 
-------------------------------------------------------------------------------- /data/datasets/__pycache__/msmt17.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/msmt17.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/rgbir.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/rgbir.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/rgbour.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/rgbour.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/veri.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/veri.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/center_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/layers/__pycache__/center_loss.cpython-35.pyc -------------------------------------------------------------------------------- /layers/__pycache__/cluster_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/layers/__pycache__/cluster_loss.cpython-35.pyc 
-------------------------------------------------------------------------------- /layers/__pycache__/triplet_loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/layers/__pycache__/triplet_loss.cpython-35.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/solver/__pycache__/lr_scheduler.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/cityflow.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/cityflow.cpython-35.pyc -------------------------------------------------------------------------------- /data/samplers/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/samplers/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/build.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/transforms/__pycache__/build.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/eval_reid.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/eval_reid.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/market1501.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/market1501.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/vehicleid.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/vehicleid.cpython-35.pyc -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import build_transforms 8 | -------------------------------------------------------------------------------- /data/transforms/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/transforms/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnet.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/modeling/backbones/__pycache__/resnet.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/senet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/modeling/backbones/__pycache__/senet.cpython-35.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dukemtmcreid.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/dukemtmcreid.cpython-35.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/transforms.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/transforms/__pycache__/transforms.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/modeling/backbones/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/densenet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/modeling/backbones/__pycache__/densenet.cpython-35.pyc 
-------------------------------------------------------------------------------- /data/datasets/__pycache__/dataset_loader.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/datasets/__pycache__/dataset_loader.cpython-35.pyc -------------------------------------------------------------------------------- /data/samplers/__pycache__/triplet_sampler.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/data/samplers/__pycache__/triplet_sampler.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/inception.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/modeling/backbones/__pycache__/inception.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/mobilenet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/modeling/backbones/__pycache__/mobilenet.cpython-35.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/squeezenet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ttaalle/multi-modal-vehicle-Re-ID/HEAD/modeling/backbones/__pycache__/squeezenet.cpython-35.pyc -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: 
sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import make_optimizer, make_optimizer_with_center 8 | from .lr_scheduler import WarmupMultiStepLR -------------------------------------------------------------------------------- /data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .triplet_sampler import RandomIdentitySampler, RandomIdentitySampler_alignedreid # new add by gu 8 | -------------------------------------------------------------------------------- /data/collate_batch.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | 9 | 10 | def train_collate_fn(batch): 11 | imgs1, imgs2, imgs3, pids, _, _, = zip(*batch) 12 | pids = torch.tensor(pids, dtype=torch.int64) 13 | return torch.stack(imgs1, dim=0),torch.stack(imgs2, dim=0),torch.stack(imgs3, dim=0), pids 14 | 15 | 16 | def val_collate_fn(batch): 17 | imgs1, imgs2, imgs3, pids, camids, _ = zip(*batch) 18 | return torch.stack(imgs1, dim=0),torch.stack(imgs2, dim=0),torch.stack(imgs3, dim=0), pids, camids 19 | -------------------------------------------------------------------------------- /tests/lr_scheduler_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import unittest 3 | 4 | import torch 5 | from torch import nn 6 | 7 | sys.path.append('.') 8 | from solver.lr_scheduler import WarmupMultiStepLR 9 | from solver.build import make_optimizer 10 | from config import cfg 11 | 12 | 13 | class MyTestCase(unittest.TestCase): 14 | def test_something(self): 15 | net = nn.Linear(10, 10) 16 | optimizer = make_optimizer(cfg, net) 17 | lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=10) 18 | for i in range(50): 19 | 
lr_scheduler.step() 20 | for j in range(3): 21 | print(i, lr_scheduler.get_lr()[0]) 22 | optimizer.step() 23 | 24 | 25 | if __name__ == '__main__': 26 | unittest.main() 27 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .baseline import Baseline, Baseline4, Baseline5, Baseline6 8 | 9 | 10 | def build_model(cfg, num_classes): 11 | # if cfg.MODEL.NAME == 'resnet50': 12 | # model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT) 13 | # model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE) 14 | model = Baseline6(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE) 15 | # model = Baseline5(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE) 16 | return model 17 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | 12 | def setup_logger(name, save_dir, distributed_rank): 13 | logger = logging.getLogger(name) 14 | logger.setLevel(logging.DEBUG) 15 | # don't log results for the non-master process 16 | if distributed_rank > 0: 17 | return logger 18 | ch = logging.StreamHandler(stream=sys.stdout) 19 | ch.setLevel(logging.DEBUG) 20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 21 
| ch.setFormatter(formatter) 22 | logger.addHandler(ch) 23 | 24 | if save_dir: 25 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w') 26 | fh.setLevel(logging.DEBUG) 27 | fh.setFormatter(formatter) 28 | logger.addHandler(fh) 29 | 30 | return logger 31 | -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import errno 8 | import json 9 | import os 10 | 11 | import os.path as osp 12 | 13 | 14 | def mkdir_if_missing(directory): 15 | if not osp.exists(directory): 16 | try: 17 | os.makedirs(directory) 18 | except OSError as e: 19 | if e.errno != errno.EEXIST: 20 | raise 21 | 22 | 23 | def check_isfile(path): 24 | isfile = osp.isfile(path) 25 | if not isfile: 26 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 27 | return isfile 28 | 29 | 30 | def read_json(fpath): 31 | with open(fpath, 'r') as f: 32 | obj = json.load(f) 33 | return obj 34 | 35 | 36 | def write_json(obj, fpath): 37 | mkdir_if_missing(osp.dirname(fpath)) 38 | with open(fpath, 'w') as f: 39 | json.dump(obj, f, indent=4, separators=(',', ': ')) 40 | -------------------------------------------------------------------------------- /data/transforms/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import torchvision.transforms as T 8 | 9 | from .transforms import RandomErasing 10 | 11 | 12 | def build_transforms(cfg, is_train=True): 13 | normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 14 | if is_train: 15 | transform = T.Compose([ 16 | T.Resize(cfg.INPUT.SIZE_TRAIN), 17 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 18 | T.Pad(cfg.INPUT.PADDING), 19 | 
T.RandomCrop(cfg.INPUT.SIZE_TRAIN), 20 | T.ToTensor(), 21 | normalize_transform, 22 | RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN) 23 | ]) 24 | else: 25 | transform = T.Compose([ 26 | T.Resize(cfg.INPUT.SIZE_TEST), 27 | T.ToTensor(), 28 | normalize_transform 29 | ]) 30 | 31 | return transform 32 | -------------------------------------------------------------------------------- /data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | #from .cuhk03 import CUHK03 7 | from .dukemtmcreid import DukeMTMCreID 8 | from .market1501 import Market1501 9 | from .msmt17 import MSMT17 10 | from .dataset_loader import ImageDataset 11 | from .veri import Veri 12 | from .vehicleid import Vehicleid 13 | from .cityflow import Cityflow 14 | from .irour import Irour 15 | from .rgbour import Rgbour 16 | from .rgbir import Rgbir 17 | 18 | __factory = { 19 | 'market1501': Market1501, 20 | # 'cuhk03': CUHK03, 21 | 'dukemtmc': DukeMTMCreID, 22 | 'msmt17': MSMT17, 23 | 'veri': Veri, 24 | 'vehicleid': Vehicleid, 25 | 'cityflow': Cityflow, 26 | 'irour':Irour, 27 | 'rgbour':Rgbour, 28 | 'rgbir':Rgbir, 29 | } 30 | 31 | 32 | def get_names(): 33 | return __factory.keys() 34 | 35 | 36 | def init_dataset(name, *args, **kwargs): 37 | if name not in __factory.keys(): 38 | raise KeyError("Unknown datasets: {}".format(name)) 39 | return __factory[name](*args, **kwargs) 40 | -------------------------------------------------------------------------------- /softmax_triplet.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | NAME: 'resnet50' 4 | PRETRAIN_PATH: '/home/lihongchao/.torch/models/resnet50-19c8e357.pth' 5 | METRIC_LOSS_TYPE: 'triplet' 6 | IF_LABELSMOOTH: 'yes' 7 | IF_WITH_CENTER: 'no' 8 | 9 | 10 | 11 | 12 | INPUT: 13 | SIZE_TRAIN: [128, 
256] 14 | SIZE_TEST: [128, 256] 15 | PROB: 0.5 # random horizontal flip 16 | RE_PROB: 0.5 # random erasing 17 | PADDING: 10 18 | 19 | DATASETS: 20 | NAMES: ('market1501') 21 | 22 | DATALOADER: 23 | SAMPLER: 'softmax_triplet' 24 | NUM_INSTANCE: 4 25 | NUM_WORKERS: 8 26 | 27 | SOLVER: 28 | OPTIMIZER_NAME: 'Adam' 29 | MAX_EPOCHS: 120 30 | BASE_LR: 0.00035 31 | 32 | CLUSTER_MARGIN: 0.3 33 | 34 | CENTER_LR: 0.5 35 | CENTER_LOSS_WEIGHT: 0.0005 36 | 37 | RANGE_K: 2 38 | RANGE_MARGIN: 0.3 39 | RANGE_ALPHA: 0 40 | RANGE_BETA: 1 41 | RANGE_LOSS_WEIGHT: 1 42 | 43 | BIAS_LR_FACTOR: 1 44 | WEIGHT_DECAY: 0.0005 45 | WEIGHT_DECAY_BIAS: 0.0005 46 | IMS_PER_BATCH: 32 47 | 48 | STEPS: [40, 70] 49 | GAMMA: 0.1 50 | 51 | WARMUP_FACTOR: 0.01 52 | WARMUP_ITERS: 10 53 | WARMUP_METHOD: 'linear' 54 | 55 | CHECKPOINT_PERIOD: 40 56 | LOG_PERIOD: 20 57 | EVAL_PERIOD: 40 58 | 59 | TEST: 60 | IMS_PER_BATCH: 32 61 | RE_RANKING: 'yes' 62 | WEIGHT: "path" 63 | NECK_FEAT: 'after' 64 | FEAT_NORM: 'yes' 65 | 66 | OUTPUT_DIR: "/home/lihongchao/reid-strong-baseline/lhc/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on" 67 | 68 | 69 | -------------------------------------------------------------------------------- /data/datasets/dataset_loader.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import os.path as osp 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | import pdb 11 | 12 | def read_image(img_path): 13 | """Keep reading image until succeed. 
14 | This can avoid IOError incurred by heavy IO process.""" 15 | got_img = False 16 | if not osp.exists(img_path): 17 | raise IOError("{} does not exist".format(img_path)) 18 | while not got_img: 19 | try: 20 | img = Image.open(img_path).convert('RGB') 21 | #img1 = img 22 | img1 = img.crop((0, 0, 256, 128)) 23 | img2 = img.crop((256, 0, 512, 128)) 24 | img3 = img.crop((512, 0, 768, 128)) 25 | #print(img1) 26 | #print(img2) 27 | #print(img3) 28 | #pdb.set_trace() 29 | got_img = True 30 | except IOError: 31 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 32 | pass 33 | return img1, img2, img3 34 | 35 | 36 | class ImageDataset(Dataset): 37 | """Image Person ReID Dataset""" 38 | 39 | def __init__(self, dataset, transform=None): 40 | self.dataset = dataset 41 | self.transform = transform 42 | 43 | def __len__(self): 44 | return len(self.dataset) 45 | 46 | def __getitem__(self, index): 47 | img_path, pid, camid = self.dataset[index] 48 | #pdb.set_trace() 49 | img1, img2, img3 = read_image(img_path) 50 | #print(img1) 51 | #T.functional.crop(img,0,0,640,360) 52 | #print(img) 53 | #pdb.set_trace() 54 | if self.transform is not None: 55 | img1 = self.transform(img1) 56 | img2 = self.transform(img2) 57 | img3 = self.transform(img3) 58 | #print(img1) 59 | #pdb.set_trace() 60 | return img1, img2, img3, pid, camid, img_path 61 | -------------------------------------------------------------------------------- /solver/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | from torchcontrib.optim import SWA 9 | 10 | def make_optimizer(cfg, model): 11 | params = [] 12 | for key, value in model.named_parameters(): 13 | if not value.requires_grad: 14 | continue 15 | lr = cfg.SOLVER.BASE_LR 16 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 17 | if "bias" in key: 18 | lr = 
cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 19 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 20 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 23 | # training loop 24 | optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05) 25 | 26 | 27 | else: 28 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 29 | return optimizer 30 | 31 | 32 | def make_optimizer_with_center(cfg, model, center_criterion): 33 | params = [] 34 | for key, value in model.named_parameters(): 35 | if not value.requires_grad: 36 | continue 37 | lr = cfg.SOLVER.BASE_LR 38 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 39 | if "bias" in key: 40 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 41 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 42 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 43 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 44 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 45 | else: 46 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 47 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 48 | return optimizer, optimizer_center 49 | -------------------------------------------------------------------------------- /solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with WarmupLR 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 
| optimizer, 18 | milestones, 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" " increasing integers. Got {}", 28 | milestones, 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted" 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from torch.utils.data import DataLoader 8 | 9 | from .collate_batch import train_collate_fn, val_collate_fn 10 | from .datasets import init_dataset, ImageDataset 11 | from .samplers import RandomIdentitySampler, RandomIdentitySampler_alignedreid # New add by gu 12 | from .transforms import build_transforms 13 | import pdb 14 | 15 | def make_data_loader(cfg): 16 | train_transforms = build_transforms(cfg, is_train=True) 17 | val_transforms = build_transforms(cfg, 
is_train=False) 18 | num_workers = cfg.DATALOADER.NUM_WORKERS 19 | if len(cfg.DATASETS.NAMES) == 1: 20 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR) 21 | else: 22 | # TODO: add multi dataset to train 23 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR) 24 | 25 | num_classes = dataset.num_train_pids 26 | train_set = ImageDataset(dataset.train, train_transforms) 27 | #pdb.set_trace() 28 | if cfg.DATALOADER.SAMPLER == 'softmax': 29 | train_loader = DataLoader( 30 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 31 | collate_fn=train_collate_fn 32 | ) 33 | else: 34 | train_loader = DataLoader( 35 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 36 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 37 | # sampler=RandomIdentitySampler_alignedreid(dataset.train, cfg.DATALOADER.NUM_INSTANCE), # new add by gu 38 | num_workers=num_workers, collate_fn=train_collate_fn 39 | ) 40 | 41 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms) 42 | val_loader = DataLoader( 43 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 44 | collate_fn=val_collate_fn 45 | ) 46 | #pdb.set_trace() 47 | return train_loader, val_loader, len(dataset.query), num_classes 48 | -------------------------------------------------------------------------------- /data/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import math 8 | import random 9 | 10 | 11 | class RandomErasing(object): 12 | """ Randomly selects a rectangle region in an image and erases its pixels. 13 | 'Random Erasing Data Augmentation' by Zhong et al. 
14 | See https://arxiv.org/pdf/1708.04896.pdf 15 | Args: 16 | probability: The probability that the Random Erasing operation will be performed. 17 | sl: Minimum proportion of erased area against input image. 18 | sh: Maximum proportion of erased area against input image. 19 | r1: Minimum aspect ratio of erased area. 20 | mean: Erasing value. 21 | """ 22 | 23 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 24 | self.probability = probability 25 | self.mean = mean 26 | self.sl = sl 27 | self.sh = sh 28 | self.r1 = r1 29 | 30 | def __call__(self, img): 31 | 32 | if random.uniform(0, 1) >= self.probability: 33 | return img 34 | 35 | for attempt in range(100): 36 | area = img.size()[1] * img.size()[2] 37 | 38 | target_area = random.uniform(self.sl, self.sh) * area 39 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 40 | 41 | h = int(round(math.sqrt(target_area * aspect_ratio))) 42 | w = int(round(math.sqrt(target_area / aspect_ratio))) 43 | 44 | if w < img.size()[2] and h < img.size()[1]: 45 | x1 = random.randint(0, img.size()[1] - h) 46 | y1 = random.randint(0, img.size()[2] - w) 47 | if img.size()[0] == 3: 48 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 49 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 50 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 51 | else: 52 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 53 | return img 54 | 55 | return img 56 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import argparse 8 | import os 9 | import sys 10 | from os import mkdir 11 | 12 | import torch 13 | import pdb 14 | from torch.backends import cudnn 15 | 16 | sys.path.append('.') 17 | from config import cfg 18 | from data import make_data_loader 19 | from engine.inference import inference 20 | 
from modeling import build_model 21 | from utils.logger import setup_logger 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser(description="ReID Baseline Inference") 26 | parser.add_argument( 27 | "--config_file", default="", help="path to config file", type=str 28 | ) 29 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 30 | nargs=argparse.REMAINDER) 31 | 32 | args = parser.parse_args() 33 | 34 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 35 | 36 | if args.config_file != "": 37 | cfg.merge_from_file(args.config_file) 38 | cfg.merge_from_list(args.opts) 39 | cfg.freeze() 40 | 41 | output_dir = cfg.OUTPUT_DIR 42 | if output_dir and not os.path.exists(output_dir): 43 | mkdir(output_dir) 44 | 45 | logger = setup_logger("reid_baseline", output_dir, 0) 46 | logger.info("Using {} GPUS".format(num_gpus)) 47 | logger.info(args) 48 | 49 | if args.config_file != "": 50 | logger.info("Loaded configuration file {}".format(args.config_file)) 51 | with open(args.config_file, 'r') as cf: 52 | config_str = "\n" + cf.read() 53 | logger.info(config_str) 54 | logger.info("Running with config:\n{}".format(cfg)) 55 | 56 | if cfg.MODEL.DEVICE == "cuda": 57 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 58 | cudnn.benchmark = True 59 | 60 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) 61 | model = build_model(cfg, num_classes) 62 | model.load_param(cfg.TEST.WEIGHT) 63 | #pdb.set_trace() 64 | 65 | inference(cfg, model, val_loader, num_query) 66 | 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /layers/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | """Center loss. 
9 | 10 | Reference: 11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | feat_dim (int): feature dimension. 16 | """ 17 | 18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes 21 | self.feat_dim = feat_dim 22 | self.use_gpu = use_gpu 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (num_classes). 34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | distmat.addmm_(1, -2, x, self.centers.t()) 41 | 42 | classes = torch.arange(self.num_classes).long() 43 | if self.use_gpu: classes = classes.cuda() 44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 45 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 46 | 47 | dist = [] 48 | for i in range(batch_size): 49 | value = distmat[i][mask[i]] 50 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 51 | dist.append(value) 52 | dist = torch.cat(dist) 53 | loss = dist.mean() 54 | return loss 55 | 56 | 57 | if __name__ == '__main__': 58 | use_gpu = False 59 | center_loss = CenterLoss(use_gpu=use_gpu) 60 | features = torch.rand(16, 2048) 61 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 62 | if use_gpu: 63 | features = torch.rand(16, 2048).cuda() 64 | targets = 
torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 65 | 66 | loss = center_loss(features, targets) 67 | print(loss) -------------------------------------------------------------------------------- /data/datasets/eval_reid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | import pdb 9 | 10 | 11 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 12 | """Evaluation with market1501 metric 13 | Key: for each query identity, its gallery images from the same camera view are discarded. 14 | """ 15 | num_q, num_g = distmat.shape 16 | if num_g < max_rank: 17 | max_rank = num_g 18 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 19 | indices = np.argsort(distmat, axis=1) 20 | 21 | 22 | query_arg = np.argsort(q_pids, axis=0) 23 | result = g_pids[indices] 24 | gall_re = result[query_arg] 25 | gall_re = gall_re.astype(np.str) 26 | #pdb.set_trace() 27 | 28 | result = gall_re[:,:100] 29 | 30 | with open("re.txt", 'w') as file_obj: 31 | for li in result: 32 | for j in range(len(li)): 33 | if j == len(li) - 1: 34 | file_obj.write(li[j] + "\n") 35 | else: 36 | file_obj.write(li[j] + " ") 37 | 38 | #pdb.set_trace() 39 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 40 | 41 | # compute cmc curve for each query 42 | all_cmc = [] 43 | all_AP = [] 44 | num_valid_q = 0. 
# number of valid query 45 | for q_idx in range(num_q): 46 | # get query pid and camid 47 | q_pid = q_pids[q_idx] 48 | q_camid = q_camids[q_idx] 49 | 50 | # remove gallery samples that have the same pid and camid with query 51 | order = indices[q_idx] 52 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 53 | keep = np.invert(remove) 54 | 55 | # compute cmc curve 56 | # binary vector, positions with value 1 are correct matches 57 | orig_cmc = matches[q_idx][keep] 58 | if not np.any(orig_cmc): 59 | # this condition is true when query identity does not appear in gallery 60 | continue 61 | 62 | cmc = orig_cmc.cumsum() 63 | cmc[cmc > 1] = 1 64 | 65 | all_cmc.append(cmc[:max_rank]) 66 | num_valid_q += 1. 67 | 68 | # compute average precision 69 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 70 | num_rel = orig_cmc.sum() 71 | tmp_cmc = orig_cmc.cumsum() 72 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 73 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 74 | AP = tmp_cmc.sum() / num_rel 75 | all_AP.append(AP) 76 | 77 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 78 | 79 | all_cmc = np.asarray(all_cmc).astype(np.float32) 80 | all_cmc = all_cmc.sum(0) / num_valid_q 81 | mAP = np.mean(all_AP) 82 | 83 | return all_cmc, mAP 84 | -------------------------------------------------------------------------------- /engine/inference.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import logging 7 | 8 | import torch 9 | import torch.nn as nn 10 | from ignite.engine import Engine 11 | 12 | from utils.reid_metric import R1_mAP, R1_mAP_reranking 13 | 14 | 15 | def create_supervised_evaluator(model, metrics, 16 | device=None): 17 | """ 18 | Factory function for creating an evaluator for supervised models 19 | 20 | Args: 21 | model 
(`torch.nn.Module`): the model to train 22 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics 23 | device (str, optional): device type specification (default: None). 24 | Applies to both model and batches. 25 | Returns: 26 | Engine: an evaluator engine with supervised inference function 27 | """ 28 | if device: 29 | if torch.cuda.device_count() > 1: 30 | model = nn.DataParallel(model) 31 | model.to(device) 32 | 33 | def _inference(engine, batch): 34 | model.eval() 35 | with torch.no_grad(): 36 | data1, data2, data3, pids, camids = batch 37 | data1 = data1.to(device) if torch.cuda.device_count() >= 1 else data1 38 | data2 = data2.to(device) if torch.cuda.device_count() >= 1 else data2 39 | data3 = data3.to(device) if torch.cuda.device_count() >= 1 else data3 40 | #feat = model(data1, data2, data3) 41 | feat = model(data1) 42 | return feat, pids, camids 43 | 44 | engine = Engine(_inference) 45 | 46 | for name, metric in metrics.items(): 47 | metric.attach(engine, name) 48 | 49 | return engine 50 | 51 | 52 | def inference( 53 | cfg, 54 | model, 55 | val_loader, 56 | num_query 57 | ): 58 | device = cfg.MODEL.DEVICE 59 | 60 | logger = logging.getLogger("reid_baseline.inference") 61 | logger.info("Enter inferencing") 62 | if cfg.TEST.RE_RANKING == 'no': 63 | print("Create evaluator") 64 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, 65 | device=device) 66 | elif cfg.TEST.RE_RANKING == 'yes': 67 | print("Create evaluator for reranking") 68 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, 69 | device=device) 70 | else: 71 | print("Unsupported re_ranking config. 
Only support for no or yes, but got {}.".format(cfg.TEST.RE_RANKING)) 72 | 73 | evaluator.run(val_loader) 74 | cmc, mAP = evaluator.state.metrics['r1_mAP'] 75 | logger.info('Validation Results') 76 | logger.info("mAP: {:.1%}".format(mAP)) 77 | for r in [1, 5, 10]: 78 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 79 | -------------------------------------------------------------------------------- /data/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import pdb 10 | import os.path as osp 11 | 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class Market1501(BaseImageDataset): 16 | """ 17 | Market1501 18 | Reference: 19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 20 | URL: http://www.liangzheng.org/Project/project_reid.html 21 | 22 | Dataset statistics: 23 | # identities: 1501 (+1 for background) 24 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 25 | """ 26 | dataset_dir = 'market1501' 27 | 28 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 29 | super(Market1501, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 32 | self.query_dir = osp.join(self.dataset_dir, 'query') 33 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 34 | 35 | self._check_before_run() 36 | 37 | train = self._process_dir(self.train_dir, relabel=True) 38 | query = self._process_dir(self.query_dir, relabel=False) 39 | gallery = self._process_dir(self.gallery_dir, relabel=False) 40 | 41 | if verbose: 42 | print("=> Market1501 loaded") 43 | self.print_dataset_statistics(train, query, gallery) 44 | 45 | self.train = train 46 | self.query = query 47 | self.gallery = gallery 48 | 49 | self.num_train_pids, 
self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 50 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 51 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 52 | 53 | def _check_before_run(self): 54 | """Check if all files are available before going deeper""" 55 | if not osp.exists(self.dataset_dir): 56 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 57 | if not osp.exists(self.train_dir): 58 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 59 | if not osp.exists(self.query_dir): 60 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 61 | if not osp.exists(self.gallery_dir): 62 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 63 | 64 | def _process_dir(self, dir_path, relabel=False): 65 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 66 | pattern = re.compile(r'([-\d]+)_c(\d)') 67 | 68 | pid_container = set() 69 | for img_path in img_paths: 70 | pid, _ = map(int, pattern.search(img_path).groups()) 71 | if pid == -1: continue # junk images are just ignored 72 | pid_container.add(pid) 73 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 74 | 75 | dataset = [] 76 | for img_path in img_paths: 77 | pid, camid = map(int, pattern.search(img_path).groups()) 78 | pdb.set_trace() 79 | if pid == -1: continue # junk images are just ignored 80 | assert 0 <= pid <= 1501 # pid == 0 means background 81 | assert 1 <= camid <= 6 82 | camid -= 1 # index starts from 0 83 | if relabel: pid = pid2label[pid] 84 | dataset.append((img_path, pid, camid)) 85 | return dataset 86 | -------------------------------------------------------------------------------- /data/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/1/17 15:00 4 
| # @Author : Hao Luo 5 | # @File : msmt17.py 6 | 7 | import glob 8 | import re 9 | 10 | import os.path as osp 11 | 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class MSMT17(BaseImageDataset): 16 | """ 17 | MSMT17 18 | 19 | Reference: 20 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 21 | 22 | URL: http://www.pkuvmc.com/publications/msmt17.html 23 | 24 | Dataset statistics: 25 | # identities: 4101 26 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 27 | # cameras: 15 28 | """ 29 | dataset_dir = 'msmt17' 30 | 31 | def __init__(self,root='/home/haoluo/data', verbose=True, **kwargs): 32 | super(MSMT17, self).__init__() 33 | self.dataset_dir = osp.join(root, self.dataset_dir) 34 | self.train_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_train_v2') 35 | self.test_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_test_v2') 36 | self.list_train_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_train.txt') 37 | self.list_val_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_val.txt') 38 | self.list_query_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_query.txt') 39 | self.list_gallery_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_gallery.txt') 40 | 41 | self._check_before_run() 42 | train = self._process_dir(self.train_dir, self.list_train_path) 43 | #val, num_val_pids, num_val_imgs = self._process_dir(self.train_dir, self.list_val_path) 44 | query = self._process_dir(self.test_dir, self.list_query_path) 45 | gallery = self._process_dir(self.test_dir, self.list_gallery_path) 46 | if verbose: 47 | print("=> MSMT17 loaded") 48 | self.print_dataset_statistics(train, query, gallery) 49 | 50 | self.train = train 51 | self.query = query 52 | self.gallery = gallery 53 | 54 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 55 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 56 | 
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 57 | 58 | def _check_before_run(self): 59 | """Check if all files are available before going deeper""" 60 | if not osp.exists(self.dataset_dir): 61 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 62 | if not osp.exists(self.train_dir): 63 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 64 | if not osp.exists(self.test_dir): 65 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 66 | 67 | def _process_dir(self, dir_path, list_path): 68 | with open(list_path, 'r') as txt: 69 | lines = txt.readlines() 70 | dataset = [] 71 | pid_container = set() 72 | for img_idx, img_info in enumerate(lines): 73 | img_path, pid = img_info.split(' ') 74 | pid = int(pid) # no need to relabel 75 | camid = int(img_path.split('_')[2]) 76 | img_path = osp.join(dir_path, img_path) 77 | dataset.append((img_path, pid, camid)) 78 | pid_container.add(pid) 79 | 80 | # check if pid starts from 0 and increments with 1 81 | for idx, pid in enumerate(pid_container): 82 | assert idx == pid, "See code comment for explanation" 83 | return dataset -------------------------------------------------------------------------------- /data/datasets/irour.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import pdb 10 | import os.path as osp 11 | import numpy as np 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class Irour(BaseImageDataset): 16 | """ 17 | Market1501 18 | Reference: 19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
20 | URL: http://www.liangzheng.org/Project/project_reid.html 21 | 22 | Dataset statistics: 23 | # identities: 1501 (+1 for background) 24 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 25 | """ 26 | dataset_dir = 'irour' 27 | 28 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 29 | super(Irour, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 32 | self.query_dir = osp.join(self.dataset_dir, 'query') 33 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 34 | 35 | self._check_before_run() 36 | 37 | train = self._process_dir(self.train_dir, relabel=True) 38 | query = self._process_dir(self.query_dir, relabel=False) 39 | gallery = self._process_dir(self.gallery_dir, relabel=False) 40 | #pdb.set_trace() 41 | if verbose: 42 | print("=> Irour loaded") 43 | self.print_dataset_statistics(train, query, gallery) 44 | 45 | self.train = train 46 | self.query = query 47 | self.gallery = gallery 48 | 49 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 50 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 51 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 52 | #pdb.set_trace() 53 | 54 | def _check_before_run(self): 55 | """Check if all files are available before going deeper""" 56 | if not osp.exists(self.dataset_dir): 57 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 58 | if not osp.exists(self.train_dir): 59 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 60 | if not osp.exists(self.query_dir): 61 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 62 | if not osp.exists(self.gallery_dir): 63 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 64 | 65 | def _process_dir(self, dir_path, 
relabel=False): 66 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 67 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 68 | 69 | pid_container = set() 70 | for img_path in img_paths: 71 | pid, _ = map(int, pattern.search(img_path).groups()) 72 | if pid == -1: continue # junk images are just ignored 73 | pid_container.add(pid) 74 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 75 | 76 | dataset = [] 77 | for img_path in img_paths: 78 | pid, camid = map(int, pattern.search(img_path).groups()) 79 | #pdb.set_trace() 80 | #if pid == -1: continue # junk images are just ignored 81 | assert 1 <= pid <= 20 # pid == 0 means background 82 | assert 1 <= camid <= 5 83 | camid -= 1 # index starts from 0 84 | if relabel: pid = pid2label[pid] 85 | dataset.append((img_path, pid, camid)) 86 | return dataset 87 | -------------------------------------------------------------------------------- /data/datasets/rgbir.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import pdb 10 | import os.path as osp 11 | import numpy as np 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class Rgbir(BaseImageDataset): 16 | """ 17 | Market1501 18 | Reference: 19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
20 | URL: http://www.liangzheng.org/Project/project_reid.html 21 | 22 | Dataset statistics: 23 | # identities: 1501 (+1 for background) 24 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 25 | """ 26 | dataset_dir = 'rgbir' 27 | 28 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 29 | super(Rgbir, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 32 | self.query_dir = osp.join(self.dataset_dir, 'query') 33 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 34 | 35 | self._check_before_run() 36 | 37 | train = self._process_dir(self.train_dir, relabel=True) 38 | query = self._process_dir(self.query_dir, relabel=False) 39 | gallery = self._process_dir(self.gallery_dir, relabel=False) 40 | #pdb.set_trace() 41 | if verbose: 42 | print("=> RGB_IR loaded") 43 | self.print_dataset_statistics(train, query, gallery) 44 | 45 | self.train = train 46 | self.query = query 47 | self.gallery = gallery 48 | 49 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 50 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 51 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 52 | #pdb.set_trace() 53 | 54 | def _check_before_run(self): 55 | """Check if all files are available before going deeper""" 56 | if not osp.exists(self.dataset_dir): 57 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 58 | if not osp.exists(self.train_dir): 59 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 60 | if not osp.exists(self.query_dir): 61 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 62 | if not osp.exists(self.gallery_dir): 63 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 64 | 65 | def _process_dir(self, dir_path, 
relabel=False): 66 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 67 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 68 | 69 | pid_container = set() 70 | for img_path in img_paths: 71 | pid, _ = map(int, pattern.search(img_path).groups()) 72 | if pid == -1: continue # junk images are just ignored 73 | pid_container.add(pid) 74 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 75 | 76 | dataset = [] 77 | for img_path in img_paths: 78 | pid, camid = map(int, pattern.search(img_path).groups()) 79 | #pdb.set_trace() 80 | #if pid == -1: continue # junk images are just ignored 81 | assert 1 <= pid <= 600 # pid == 0 means background 82 | assert 1 <= camid <= 8 83 | camid -= 1 # index starts from 0 84 | if relabel: pid = pid2label[pid] 85 | dataset.append((img_path, pid, camid)) 86 | return dataset 87 | -------------------------------------------------------------------------------- /data/datasets/veri.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import pdb 10 | import os.path as osp 11 | import numpy as np 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class Veri(BaseImageDataset): 16 | """ 17 | Market1501 18 | Reference: 19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
20 | URL: http://www.liangzheng.org/Project/project_reid.html 21 | 22 | Dataset statistics: 23 | # identities: 1501 (+1 for background) 24 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 25 | """ 26 | dataset_dir = 'veri' 27 | 28 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 29 | super(Veri, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 32 | self.query_dir = osp.join(self.dataset_dir, 'query') 33 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 34 | 35 | self._check_before_run() 36 | 37 | train = self._process_dir(self.train_dir, relabel=True) 38 | query = self._process_dir(self.query_dir, relabel=False) 39 | gallery = self._process_dir(self.gallery_dir, relabel=False) 40 | #pdb.set_trace() 41 | if verbose: 42 | print("=> Veri-776 loaded") 43 | self.print_dataset_statistics(train, query, gallery) 44 | 45 | self.train = train 46 | self.query = query 47 | self.gallery = gallery 48 | 49 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 50 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 51 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 52 | #pdb.set_trace() 53 | 54 | def _check_before_run(self): 55 | """Check if all files are available before going deeper""" 56 | if not osp.exists(self.dataset_dir): 57 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 58 | if not osp.exists(self.train_dir): 59 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 60 | if not osp.exists(self.query_dir): 61 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 62 | if not osp.exists(self.gallery_dir): 63 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 64 | 65 | def _process_dir(self, dir_path, 
relabel=False): 66 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 67 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 68 | 69 | pid_container = set() 70 | for img_path in img_paths: 71 | pid, _ = map(int, pattern.search(img_path).groups()) 72 | if pid == -1: continue # junk images are just ignored 73 | pid_container.add(pid) 74 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 75 | 76 | dataset = [] 77 | for img_path in img_paths: 78 | pid, camid = map(int, pattern.search(img_path).groups()) 79 | #pdb.set_trace() 80 | #if pid == -1: continue # junk images are just ignored 81 | assert 1 <= pid <= 776 # pid == 0 means background 82 | assert 1 <= camid <= 20 83 | camid -= 1 # index starts from 0 84 | if relabel: pid = pid2label[pid] 85 | dataset.append((img_path, pid, camid)) 86 | return dataset 87 | -------------------------------------------------------------------------------- /data/datasets/cityflow.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import pdb 10 | import os.path as osp 11 | import numpy as np 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class Cityflow(BaseImageDataset): 16 | """ 17 | Market1501 18 | Reference: 19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
20 | URL: http://www.liangzheng.org/Project/project_reid.html 21 | 22 | Dataset statistics: 23 | # identities: 1501 (+1 for background) 24 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 25 | """ 26 | dataset_dir = 'cityflow' 27 | 28 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 29 | super(Cityflow, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 32 | self.query_dir = osp.join(self.dataset_dir, 'query') 33 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 34 | 35 | self._check_before_run() 36 | 37 | train = self._process_dir(self.train_dir, relabel=True) 38 | query = self._process_dir(self.query_dir, relabel=False) 39 | gallery = self._process_dir(self.gallery_dir, relabel=False) 40 | #pdb.set_trace() 41 | if verbose: 42 | print("=> City-flow loaded") 43 | self.print_dataset_statistics(train, query, gallery) 44 | 45 | self.train = train 46 | self.query = query 47 | self.gallery = gallery 48 | 49 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 50 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 51 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 52 | pdb.set_trace() 53 | 54 | def _check_before_run(self): 55 | """Check if all files are available before going deeper""" 56 | if not osp.exists(self.dataset_dir): 57 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 58 | if not osp.exists(self.train_dir): 59 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 60 | if not osp.exists(self.query_dir): 61 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 62 | if not osp.exists(self.gallery_dir): 63 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 64 | 65 | def _process_dir(self, 
dir_path, relabel=False): 66 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 67 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 68 | 69 | pid_container = set() 70 | for img_path in img_paths: 71 | pid, _ = map(int, pattern.search(img_path).groups()) 72 | if pid == -1: continue # junk images are just ignored 73 | pid_container.add(pid) 74 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 75 | 76 | dataset = [] 77 | for img_path in img_paths: 78 | pid, camid = map(int, pattern.search(img_path).groups()) 79 | #pdb.set_trace() 80 | #if pid == -1: continue # junk images are just ignored 81 | assert 1 <= pid <= 21052 # pid == 0 means background 82 | assert 1 <= camid <= 40 83 | camid -= 1 # index starts from 0 84 | if relabel: pid = pid2label[pid] 85 | dataset.append((img_path, pid, camid)) 86 | return dataset -------------------------------------------------------------------------------- /data/datasets/rgbour.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import pdb 10 | import os.path as osp 11 | import numpy as np 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class Rgbour(BaseImageDataset): 16 | """ 17 | Market1501 18 | Reference: 19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
20 | URL: http://www.liangzheng.org/Project/project_reid.html 21 | 22 | Dataset statistics: 23 | # identities: 1501 (+1 for background) 24 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 25 | """ 26 | dataset_dir = 'rgbour' 27 | 28 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 29 | super(Rgbour, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 32 | self.query_dir = osp.join(self.dataset_dir, 'query') 33 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 34 | 35 | 36 | self._check_before_run() 37 | 38 | train = self._process_dir(self.train_dir, relabel=True) 39 | query = self._process_dir(self.query_dir, relabel=False) 40 | gallery = self._process_dir(self.gallery_dir, relabel=False) 41 | 42 | if verbose: 43 | print("=> Rgbour loaded") 44 | self.print_dataset_statistics(train, query, gallery) 45 | 46 | self.train = train 47 | self.query = query 48 | self.gallery = gallery 49 | 50 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 51 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 52 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 53 | #pdb.set_trace() 54 | 55 | def _check_before_run(self): 56 | """Check if all files are available before going deeper""" 57 | if not osp.exists(self.dataset_dir): 58 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 59 | if not osp.exists(self.train_dir): 60 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 61 | if not osp.exists(self.query_dir): 62 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 63 | if not osp.exists(self.gallery_dir): 64 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 65 | 66 | def _process_dir(self, dir_path, 
relabel=False): 67 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 68 | #for py in img_paths: 69 | # print(py) 70 | # pdb.set_trace() 71 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 72 | 73 | pid_container = set() 74 | for img_path in img_paths: 75 | pid, _ = map(int, pattern.search(img_path).groups()) 76 | if pid == -1: continue # junk images are just ignored 77 | pid_container.add(pid) 78 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 79 | 80 | dataset = [] 81 | for img_path in img_paths: 82 | pid, camid = map(int, pattern.search(img_path).groups()) 83 | #pdb.set_trace() 84 | #if pid == -1: continue # junk images are just ignored 85 | assert 1 <= pid <= 20 # pid == 0 means background 86 | assert 1 <= camid <= 5 87 | camid -= 1 # index starts from 0 88 | if relabel: pid = pid2label[pid] 89 | dataset.append((img_path, pid, camid)) 90 | return dataset 91 | -------------------------------------------------------------------------------- /utils/reid_metric.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import pdb 10 | from ignite.metrics import Metric 11 | 12 | from data.datasets.eval_reid import eval_func 13 | from .re_ranking import re_ranking 14 | 15 | 16 | 17 | 18 | class R1_mAP(Metric): 19 | def __init__(self, num_query, max_rank=50, feat_norm='yes'): 20 | super(R1_mAP, self).__init__() 21 | self.num_query = num_query 22 | self.max_rank = max_rank 23 | self.feat_norm = feat_norm 24 | 25 | def reset(self): 26 | self.feats = [] 27 | self.pids = [] 28 | self.camids = [] 29 | 30 | def update(self, output): 31 | feat, pid, camid = output 32 | self.feats.append(feat) 33 | self.pids.extend(np.asarray(pid)) 34 | self.camids.extend(np.asarray(camid)) 35 | 36 | def compute(self): 37 | feats = torch.cat(self.feats, dim=0) 38 | if self.feat_norm == 'yes': 39 | 
print("The test feature is normalized") 40 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 41 | # query 42 | qf = feats[:self.num_query] 43 | q_pids = np.asarray(self.pids[:self.num_query]) 44 | q_camids = np.asarray(self.camids[:self.num_query]) 45 | # gallery 46 | gf = feats[self.num_query:] 47 | g_pids = np.asarray(self.pids[self.num_query:]) 48 | g_camids = np.asarray(self.camids[self.num_query:]) 49 | m, n = qf.shape[0], gf.shape[0] 50 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 51 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 52 | distmat.addmm_(1, -2, qf, gf.t()) 53 | distmat = distmat.cpu().numpy() 54 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 55 | 56 | return cmc, mAP 57 | 58 | 59 | class R1_mAP_reranking(Metric): 60 | def __init__(self, num_query, max_rank=50, feat_norm='yes'): 61 | super(R1_mAP_reranking, self).__init__() 62 | self.num_query = num_query 63 | self.max_rank = max_rank 64 | self.feat_norm = feat_norm 65 | 66 | def reset(self): 67 | self.feats = [] 68 | self.pids = [] 69 | self.camids = [] 70 | 71 | def update(self, output): 72 | feat, pid, camid = output 73 | self.feats.append(feat) 74 | self.pids.extend(np.asarray(pid)) 75 | self.camids.extend(np.asarray(camid)) 76 | 77 | def compute(self): 78 | feats = torch.cat(self.feats, dim=0) 79 | if self.feat_norm == 'yes': 80 | print("The test feature is normalized") 81 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 82 | 83 | # query 84 | qf = feats[:self.num_query] 85 | q_pids = np.asarray(self.pids[:self.num_query]) 86 | q_camids = np.asarray(self.camids[:self.num_query]) 87 | # gallery 88 | gf = feats[self.num_query:] 89 | g_pids = np.asarray(self.pids[self.num_query:]) 90 | g_camids = np.asarray(self.camids[self.num_query:]) 91 | # m, n = qf.shape[0], gf.shape[0] 92 | # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 93 | # torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 94 
| # distmat.addmm_(1, -2, qf, gf.t()) 95 | # distmat = distmat.cpu().numpy() 96 | print("Enter reranking") 97 | distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3) 98 | #pdb.set_trace() 99 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) # modified 100 | 101 | return cmc, mAP -------------------------------------------------------------------------------- /data/datasets/VehicleIDDataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import ntpath 3 | import os 4 | import random 5 | 6 | from datasets.Dataset import Dataset 7 | 8 | 9 | class VehicleIDDataset(Dataset): 10 | #FILE_BY_PART = {'train': 'bounding_box_train', 'test': 'bounding_box_test', 'query': 'query', 'distractors': 'distractors500k'} 11 | FILE_BY_PART = {'train': 'train_VehicleID', 'test': 'test2400_gallery', 'query': 'test2400_query'} 12 | 13 | def __init__(self, data_directory, dataset_part, mean=None, std=None, num_classes=None, augment=True, png=True): 14 | if mean is None: 15 | mean = [99.100247, 103.980578, 104.326892] 16 | if std is None: 17 | std = [61.959524, 61.278575, 61.481962] 18 | if num_classes is None: 19 | num_classes = 13164 20 | 21 | super().__init__(mean=mean, std=std, num_classes=num_classes, data_directory=data_directory, dataset_part=dataset_part, augment=augment, png=png) 22 | 23 | def get_input_data(self, is_training): 24 | image_paths = self.get_images_from_folder() 25 | 26 | if is_training: 27 | random.shuffle(image_paths) 28 | 29 | file_names = [os.path.basename(file) for file in image_paths] 30 | 31 | actual_labels = [self.get_label_from_path(image_path) for image_path in image_paths] 32 | label_mapping = {label: index for index, label in enumerate(list(sorted(set(actual_labels))))} 33 | labels = [label_mapping[actual_label] for actual_label in actual_labels] 34 | views = [self.get_view_from_path(image_path) for image_path in image_paths] 35 | 36 | print('Read %d image paths for processing 
for dataset_part: %s' % (len(image_paths), self._dataset_part)) 37 | return image_paths, file_names, actual_labels, labels, views 38 | 39 | 40 | def get_number_of_samples(self): 41 | print(len(self.get_images_from_folder())) 42 | return len(self.get_images_from_folder()) 43 | 44 | def prepare_sliced_data_for_batching(self, sliced_input_data, image_size): 45 | image_path_tensor, file_name_tensor, actual_label_tensor, label_tensor, view_label = sliced_input_data 46 | image_tensor = self.read_and_distort_image(file_name_tensor, image_path_tensor, image_size) 47 | 48 | return self.get_dict_for_batching(actual_label_tensor=actual_label_tensor, file_name_tensor=file_name_tensor, image_path_tensor=image_path_tensor,image_tensor=image_tensor, label_tensor=label_tensor, view_label=view_label) 49 | 50 | def get_images_from_folder(self): 51 | data_file = self.get_data_file() 52 | return self.get_png_and_jpg(data_file) 53 | 54 | @staticmethod 55 | def get_png_and_jpg(data_file): 56 | all_images = glob.glob(os.path.join(data_file, '*.png')) 57 | all_images.extend(glob.glob(os.path.join(data_file, '*.jpg'))) 58 | return all_images 59 | 60 | def get_data_file(self): 61 | data_file = self.FILE_BY_PART[self._dataset_part] 62 | return os.path.join(self._data_directory, data_file) 63 | 64 | def get_input_function_dictionaries(self, batched_input_data): 65 | return {'paths': batched_input_data['path'], 'images': batched_input_data['image'], 'file_names': batched_input_data['file_name']}, \ 66 | {'labels': batched_input_data['label'], 'actual_labels': batched_input_data['actual_label'], 'views': batched_input_data['view']} 67 | 68 | @staticmethod 69 | def get_label_from_path(path): 70 | filename = ntpath.basename(path) 71 | label = filename.split('_')[0] 72 | return int(label) 73 | @staticmethod 74 | def get_view_from_path(path): 75 | filename = ntpath.basename(path) 76 | view = filename.split('_')[2][0] 77 | if (view=='1'): 78 | v=0 79 | else: 80 | v=1 81 | return v 82 | 
-------------------------------------------------------------------------------- /data/datasets/bases.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | 9 | 10 | class BaseDataset(object): 11 | """ 12 | Base class of reid dataset 13 | """ 14 | 15 | def get_imagedata_info(self, data): 16 | pids, cams = [], [] 17 | for _, pid, camid in data: 18 | pids += [pid] 19 | cams += [camid] 20 | pids = set(pids) 21 | cams = set(cams) 22 | num_pids = len(pids) 23 | num_cams = len(cams) 24 | num_imgs = len(data) 25 | return num_pids, num_imgs, num_cams 26 | 27 | def get_videodata_info(self, data, return_tracklet_stats=False): 28 | pids, cams, tracklet_stats = [], [], [] 29 | for img_paths, pid, camid in data: 30 | pids += [pid] 31 | cams += [camid] 32 | tracklet_stats += [len(img_paths)] 33 | pids = set(pids) 34 | cams = set(cams) 35 | num_pids = len(pids) 36 | num_cams = len(cams) 37 | num_tracklets = len(data) 38 | if return_tracklet_stats: 39 | return num_pids, num_tracklets, num_cams, tracklet_stats 40 | return num_pids, num_tracklets, num_cams 41 | 42 | def print_dataset_statistics(self): 43 | raise NotImplementedError 44 | 45 | 46 | class BaseImageDataset(BaseDataset): 47 | """ 48 | Base class of image reid dataset 49 | """ 50 | 51 | def print_dataset_statistics(self, train, query, gallery): 52 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 53 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 54 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 55 | 56 | print("Dataset statistics:") 57 | print(" ----------------------------------------") 58 | print(" subset | # ids | # images | # cameras") 59 | print(" ----------------------------------------") 60 | print(" train | {:5d} | {:8d} | 
{:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 61 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 62 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 63 | print(" ----------------------------------------") 64 | 65 | 66 | class BaseVideoDataset(BaseDataset): 67 | """ 68 | Base class of video reid dataset 69 | """ 70 | 71 | def print_dataset_statistics(self, train, query, gallery): 72 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ 73 | self.get_videodata_info(train, return_tracklet_stats=True) 74 | 75 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ 76 | self.get_videodata_info(query, return_tracklet_stats=True) 77 | 78 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ 79 | self.get_videodata_info(gallery, return_tracklet_stats=True) 80 | 81 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats 82 | min_num = np.min(tracklet_stats) 83 | max_num = np.max(tracklet_stats) 84 | avg_num = np.mean(tracklet_stats) 85 | 86 | print("Dataset statistics:") 87 | print(" -------------------------------------------") 88 | print(" subset | # ids | # tracklets | # cameras") 89 | print(" -------------------------------------------") 90 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) 91 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) 92 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) 93 | print(" -------------------------------------------") 94 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) 95 | print(" -------------------------------------------") 96 | 
-------------------------------------------------------------------------------- /data/samplers/triplet_sampler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import copy 8 | import random 9 | import torch 10 | from collections import defaultdict 11 | 12 | import numpy as np 13 | from torch.utils.data.sampler import Sampler 14 | 15 | 16 | class RandomIdentitySampler(Sampler): 17 | """ 18 | Randomly sample N identities, then for each identity, 19 | randomly sample K instances, therefore batch size is N*K. 20 | Args: 21 | - data_source (list): list of (img_path, pid, camid). 22 | - num_instances (int): number of instances per identity in a batch. 23 | - batch_size (int): number of examples in a batch. 24 | """ 25 | 26 | def __init__(self, data_source, batch_size, num_instances): 27 | self.data_source = data_source 28 | self.batch_size = batch_size 29 | self.num_instances = num_instances 30 | self.num_pids_per_batch = self.batch_size // self.num_instances 31 | self.index_dic = defaultdict(list) 32 | for index, (_, pid, _) in enumerate(self.data_source): 33 | self.index_dic[pid].append(index) 34 | self.pids = list(self.index_dic.keys()) 35 | 36 | # estimate number of examples in an epoch 37 | self.length = 0 38 | for pid in self.pids: 39 | idxs = self.index_dic[pid] 40 | num = len(idxs) 41 | if num < self.num_instances: 42 | num = self.num_instances 43 | self.length += num - num % self.num_instances 44 | 45 | def __iter__(self): 46 | batch_idxs_dict = defaultdict(list) 47 | 48 | for pid in self.pids: 49 | idxs = copy.deepcopy(self.index_dic[pid]) 50 | if len(idxs) < self.num_instances: 51 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 52 | random.shuffle(idxs) 53 | batch_idxs = [] 54 | for idx in idxs: 55 | batch_idxs.append(idx) 56 | if len(batch_idxs) == self.num_instances: 57 | 
batch_idxs_dict[pid].append(batch_idxs) 58 | batch_idxs = [] 59 | 60 | avai_pids = copy.deepcopy(self.pids) 61 | final_idxs = [] 62 | 63 | while len(avai_pids) >= self.num_pids_per_batch: 64 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 65 | for pid in selected_pids: 66 | batch_idxs = batch_idxs_dict[pid].pop(0) 67 | final_idxs.extend(batch_idxs) 68 | if len(batch_idxs_dict[pid]) == 0: 69 | avai_pids.remove(pid) 70 | 71 | return iter(final_idxs) 72 | 73 | def __len__(self): 74 | return self.length 75 | 76 | 77 | # New add by gu 78 | class RandomIdentitySampler_alignedreid(Sampler): 79 | """ 80 | Randomly sample N identities, then for each identity, 81 | randomly sample K instances, therefore batch size is N*K. 82 | 83 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 84 | 85 | Args: 86 | data_source (Dataset): dataset to sample from. 87 | num_instances (int): number of instances per identity. 88 | """ 89 | def __init__(self, data_source, num_instances): 90 | self.data_source = data_source 91 | self.num_instances = num_instances 92 | self.index_dic = defaultdict(list) 93 | for index, (_, pid, _) in enumerate(data_source): 94 | self.index_dic[pid].append(index) 95 | self.pids = list(self.index_dic.keys()) 96 | self.num_identities = len(self.pids) 97 | 98 | def __iter__(self): 99 | indices = torch.randperm(self.num_identities) 100 | ret = [] 101 | for i in indices: 102 | pid = self.pids[i] 103 | t = self.index_dic[pid] 104 | replace = False if len(t) >= self.num_instances else True 105 | t = np.random.choice(t, size=self.num_instances, replace=replace) 106 | ret.extend(t) 107 | return iter(ret) 108 | 109 | def __len__(self): 110 | return self.num_identities * self.num_instances 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-spectral Vehicle Re-identification: 
A Challenge 2 | ## Dataset 3 | In this work, we address the RGB and IR vehicle Re-ID problem and contribute a multi-spectral vehicle Re-ID benchmark named RGBN300, which contains RGB and NIR (Near Infrared) vehicle images of 300 identities from 8 camera views: 50125 RGB images and 50125 NIR images in total. In addition, we acquired TIR (Thermal Infrared) data for 100 of the RGBN300 vehicles, forming a second dataset, RGBNT100, for three-spectral vehicle Re-ID. 4 | ![RGB-NIR-TIR](fig/challenge.jpg) 5 | 6 | RGBN300 7 | Link: https://pan.baidu.com/s/1uiKcqiqdhd13nLSW8TUASg 8 | Extraction code: 11y8 9 | 10 | RGBNT100 11 | Link: https://pan.baidu.com/s/1xqqh7N4Lctm3RcUdskG0Ug 12 | Extraction code: rjin 13 | 14 | 15 | ## HAMNet 16 | ### Pipeline 17 | ![RGB-NIR-TIR](fig/frame.jpg) 18 | 19 | @InProceedings{Li_2020_AAAI, 20 | author = {Hongchao Li and Chenglong Li and Xianpeng Zhu and Aihua Zheng and Bin Luo}, 21 | title = {Multi-spectral Vehicle Re-identification: A Challenge}, 22 | booktitle = {AAAI}, 23 | month = {February}, 24 | year = {2020} 25 | } 26 | ### Results (Rank-1 (mAP)) 27 | |Modality|RGBN300|RGBNT100| 28 | |:---|:---|:---| 29 | |RGB_onestream|72.6(49.5)|58.5(41.0)| 30 | |NIR_onestream|61.9(42.1)|52.8(37.1)| 31 | |TIR_onestream|-|61.8(35.7)| 32 | |RGB-NIR_multistream|77.2(56.9)|65.4(43.1)| 33 | |RGB-NIR-TIR_multistream|-|82.6(60.5)| 34 | |**RGB-NIR_HAMNet**|**84.0(61.9)**|-| 35 | |**RGB-NIR-TIR_HAMNet**|-|**84.7(64.1)**| 36 | 37 | ### Get Started 38 | 39 | The project layout follows the PyTorch-Project-Template guide, which explains the purpose of each folder. The code is built on [ReID-baseline](https://github.com/L1aoXingyu/reid_baseline).
40 | 41 | 1. `cd` to the folder where you want to download this repo. 42 | 43 | 2. Run `git clone https://github.com/ttaalle/multi-modal-vehicle-Re-ID.git` 44 | 45 | 3. Install dependencies: 46 | * pytorch>=0.4 47 | * torchvision 48 | * ignite==0.1.2 49 | * yacs 50 | 51 | 52 | 4. Prepare an ImageNet-pretrained model 53 | 54 | for example /home/<user>/.torch/models/resnet50-19c8e357.pth 55 | 56 | 5. Prepare the dataset 57 | 58 | Create a directory to store ReID datasets, either inside or outside this repo. Then set the path to the dataset root in config/defaults.py (applies to all training and testing), in the individual config files under configs/, or directly on the command line. 59 | 60 | To create the directory inside this repo, run: 61 | 62 | cd multi-modal-vehicle-Re-ID 63 | mkdir data 64 | 65 | (1) RGBN300 dataset 66 | 67 | Download the dataset and copy only its `rgbir` folder into data/. 68 | 69 | The resulting structure should look like: 70 | 71 | data 72 | rgbir # this folder contains 3 subfolders. 73 | bounding_box_test/ 74 | bounding_box_train/ 75 | query 76 | 77 | (2) RGBNT100 dataset 78 | 79 | Download the dataset and copy only its `rgbir` folder into data/. (Both datasets are read through the same `rgbir` interface in our code, so to keep their data from getting mixed, only the dataset you are currently running should be placed in data/ as `rgbir`.) 80 | 81 | The resulting structure should look like: 82 | 83 | data 84 | rgbir # this folder contains 3 subfolders.
85 | bounding_box_test/ 86 | bounding_box_train/ 87 | query 88 | 89 | 90 | ### train 91 | 92 | python3 train.py --config_file='softmax_triplet.yml' MODEL.DEVICE_ID "('your device id')" DATASETS.NAMES "('rgbir')" OUTPUT_DIR "('your path to save checkpoints and logs')" 93 | 94 | ### test 95 | python3 test.py --config_file='softmax_tripletr.yml' MODEL.DEVICE_ID "('your device id')" DATASETS.NAMES "('rgbir')" MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('your path to trained checkpoints')" 96 | 97 | 98 | 99 | To propose a stronger baseline, this version has been added bag of tricks(Random erasing augmentation, Label smoothing and BNNeck) as [Strong ReID-baseline](https://github.com/michuanhaohao/reid-strong-baseline). 100 | 101 | -------------------------------------------------------------------------------- /data/datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import urllib 10 | import zipfile 11 | 12 | import os.path as osp 13 | 14 | from utils.iotools import mkdir_if_missing 15 | from .bases import BaseImageDataset 16 | 17 | 18 | class DukeMTMCreID(BaseImageDataset): 19 | """ 20 | DukeMTMC-reID 21 | Reference: 22 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 23 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 
24 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 25 | 26 | Dataset statistics: 27 | # identities: 1404 (train + query) 28 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 29 | # cameras: 8 30 | """ 31 | dataset_dir = 'dukemtmc-reid' 32 | 33 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 34 | super(DukeMTMCreID, self).__init__() 35 | self.dataset_dir = osp.join(root, self.dataset_dir) 36 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 37 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 38 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 39 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 40 | 41 | self._download_data() 42 | self._check_before_run() 43 | 44 | train = self._process_dir(self.train_dir, relabel=True) 45 | query = self._process_dir(self.query_dir, relabel=False) 46 | gallery = self._process_dir(self.gallery_dir, relabel=False) 47 | 48 | if verbose: 49 | print("=> DukeMTMC-reID loaded") 50 | self.print_dataset_statistics(train, query, gallery) 51 | 52 | self.train = train 53 | self.query = query 54 | self.gallery = gallery 55 | 56 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 57 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 58 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 59 | 60 | def _download_data(self): 61 | if osp.exists(self.dataset_dir): 62 | print("This dataset has been downloaded.") 63 | return 64 | 65 | print("Creating directory {}".format(self.dataset_dir)) 66 | mkdir_if_missing(self.dataset_dir) 67 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 68 | 69 | print("Downloading DukeMTMC-reID dataset") 70 | urllib.request.urlretrieve(self.dataset_url, fpath) 71 | 72 | print("Extracting 
files") 73 | zip_ref = zipfile.ZipFile(fpath, 'r') 74 | zip_ref.extractall(self.dataset_dir) 75 | zip_ref.close() 76 | 77 | def _check_before_run(self): 78 | """Check if all files are available before going deeper""" 79 | if not osp.exists(self.dataset_dir): 80 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 81 | if not osp.exists(self.train_dir): 82 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 83 | if not osp.exists(self.query_dir): 84 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 85 | if not osp.exists(self.gallery_dir): 86 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 87 | 88 | def _process_dir(self, dir_path, relabel=False): 89 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 90 | pattern = re.compile(r'([-\d]+)_c(\d)') 91 | 92 | pid_container = set() 93 | for img_path in img_paths: 94 | pid, _ = map(int, pattern.search(img_path).groups()) 95 | pid_container.add(pid) 96 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 97 | 98 | dataset = [] 99 | for img_path in img_paths: 100 | pid, camid = map(int, pattern.search(img_path).groups()) 101 | assert 1 <= camid <= 8 102 | camid -= 1 # index starts from 0 103 | if relabel: pid = pid2label[pid] 104 | dataset.append((img_path, pid, camid)) 105 | 106 | return dataset 107 | -------------------------------------------------------------------------------- /utils/re_ranking.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri, 25 May 2018 20:29:09 5 | 6 | @author: luohao 7 | """ 8 | 9 | """ 10 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
# url: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
# Matlab version: https://github.com/zhunzhong07/person-re-ranking

"""
API

probFea: all feature vectors of the query set (torch tensor)
galFea: all feature vectors of the gallery set (torch tensor)
k1,k2,lambda: parameters, the original paper is (k1=20,k2=6,lambda=0.3)
local_distmat: optional extra distance matrix over query+gallery samples
only_local: if True, use local_distmat alone as the original distance
"""

import numpy as np
import torch


def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
    """k-reciprocal re-ranking of query-to-gallery distances (Zhong et al., CVPR 2017).

    Args:
        probFea: query features, a 2-D torch tensor (num_query, dim).
        galFea: gallery features, a 2-D torch tensor (num_gallery, dim).
        k1: size of the k-reciprocal neighbourhood.
        k2: neighbourhood size for local query expansion (1 disables it).
        lambda_value: weight of the original distance in the final distance.
        local_distmat: optional array over all query+gallery samples
            (all_num x all_num) added to the feature distance.
        only_local: use ``local_distmat`` alone as the original distance.

    Returns:
        numpy array (num_query, num_gallery) of re-ranked distances.
    """
    # If the features are numpy arrays, convert them with torch.tensor first.
    query_num = probFea.size(0)
    all_num = query_num + galFea.size(0)
    if only_local:
        original_dist = local_distmat
    else:
        feat = torch.cat([probFea, galFea])
        print('using GPU to compute original distance')
        # Squared Euclidean distance: ||a||^2 + ||b||^2 - 2 * a.b
        sq_norm = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num)
        distmat = sq_norm + sq_norm.t()
        # Keyword beta/alpha: the positional addmm_(beta, alpha, m1, m2) overload is deprecated.
        distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)
        original_dist = distmat.cpu().numpy()
        del feat
        if local_distmat is not None:
            original_dist = original_dist + local_distmat
    # Number of rows of the joint matrix, i.e. query + gallery samples.
    gallery_num = original_dist.shape[0]
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank = np.argsort(original_dist).astype(np.int32)

    print('starting re_ranking')
    half_k1 = int(np.around(k1 / 2))
    for i in range(all_num):
        # k-reciprocal neighbors of sample i
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        # Add the (k1/2)-reciprocal set of each candidate if it overlaps enough with i's set.
        for candidate in k_reciprocal_index:
            candidate_forward_k_neigh_index = initial_rank[candidate, :half_k1 + 1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
                                                            :half_k1 + 1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
                    candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        # Local query expansion: average V over the k2 nearest neighbours.
        V_qe = np.zeros_like(V, dtype=np.float16)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    # Inverted index: for each column, the rows with a non-zero weight.
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:, i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)

    for i in range(query_num):
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
                                                                               V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
    del original_dist
    del V
    del jaccard_dist
    final_dist = final_dist[:query_num, query_num:]
    return final_dist


# ---- /modeling/backbones/resnet.py ----
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import math
import pdb
import torch
from torch import nn


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two 3x3 conv+BN layers with an identity (or ``downsample``) shortcut."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 conv block; output has ``planes * 4`` channels."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet trunk without pooling/classifier; ``forward`` returns the layer4 feature map.

    Args:
        last_stride: stride of layer4 (1 keeps a larger final feature map).
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in layer1..layer4.
    """

    def __init__(self, last_stride=2, block=Bottleneck, layers=(3, 4, 6, 3)):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # NOTE: unlike the standard ResNet stem there is no ReLU after bn1.
        # Kept as-is: adding it would change the features of already-trained models.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=last_stride)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; the first may downsample/project the shortcut."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def load_param(self, model_path):
        """Copy weights from the checkpoint at *model_path*, skipping keys containing 'fc'.

        The checkpoint is loaded onto the CPU first, so GPU-saved weights also
        load on CPU-only machines; ``copy_`` moves them to each parameter's device.
        """
        param_dict = torch.load(model_path, map_location='cpu')
        own_state = self.state_dict()
        for name in param_dict:
            if 'fc' in name:
                continue
            own_state[name].copy_(param_dict[name])

    def random_init(self):
        """He-normal init for convolutions; BatchNorm weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


# ---- /layers/triplet_loss.py ----
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn


def normalize(x, axis=-1):
    """Normalizing to unit length along the specified dimension.
    Args:
      x: pytorch Variable
    Returns:
      x: pytorch Variable, same shape as input
    """
    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
    return x


def euclidean_dist(x, y):
    """Pairwise Euclidean distances between the rows of *x* and *y*.
    Args:
      x: pytorch Variable, with shape [m, d]
      y: pytorch Variable, with shape [n, d]
    Returns:
      dist: pytorch Variable, with shape [m, n]
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # dist = 1 * dist - 2 * x @ y^T (keyword form; the positional overload is deprecated)
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist


def hard_example_mining(dist_mat, labels, return_inds=False):
    """For each anchor, find the hardest positive and negative sample.
    Args:
      dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
      labels: pytorch LongTensor, with shape [N]
      return_inds: whether to return the indices. Save time if `False`(?)
    Returns:
      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
      p_inds: pytorch LongTensor, with shape [N];
        indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
      n_inds: pytorch LongTensor, with shape [N];
        indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    NOTE: Only consider the case in which all labels have same num of samples,
      thus we can cope with all anchors in parallel.
    """

    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    # shape [N, N]
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())

    # Mask indexing flattens row by row; view(N, -1) is valid only because every
    # anchor has the same number of positives (and therefore of negatives).
    # both `dist_ap` and `relative_p_inds` with shape [N, 1]
    dist_ap, relative_p_inds = torch.max(
        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
    # both `dist_an` and `relative_n_inds` with shape [N, 1]
    dist_an, relative_n_inds = torch.min(
        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
    # shape [N]
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)

    if return_inds:
        # Column index of every entry, shape [N, N]
        ind = torch.arange(N, dtype=labels.dtype, device=labels.device).unsqueeze(0).expand(N, N)
        # shape [N, 1]
        p_inds = torch.gather(
            ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds)
        n_inds = torch.gather(
            ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds)
        # shape [N]
        p_inds = p_inds.squeeze(1)
        n_inds = n_inds.squeeze(1)
        return dist_ap, dist_an, p_inds, n_inds

    return dist_ap, dist_an


class TripletLoss(object):
    """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
    Loss for Person Re-Identification'.

    With ``margin=None`` a soft-margin loss on ``dist_an - dist_ap`` is used,
    otherwise a margin ranking loss.
    """

    def __init__(self, margin=None):
        self.margin = margin
        if margin is not None:
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        else:
            self.ranking_loss = nn.SoftMarginLoss()

    def __call__(self, global_feat, labels, normalize_feature=False):
        """Return ``(loss, dist_ap, dist_an)`` using batch-hard mining over *global_feat*."""
        if normalize_feature:
            global_feat = normalize(global_feat, axis=-1)
        dist_mat = euclidean_dist(global_feat, global_feat)
        dist_ap, dist_an = hard_example_mining(
            dist_mat, labels)
        # Target +1: dist_an should be larger than dist_ap.
        y = torch.ones_like(dist_an)
        if self.margin is not None:
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
            loss = self.ranking_loss(dist_an - dist_ap, y)
        return loss, dist_ap, dist_an


class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes.
        epsilon (float): weight.
        use_gpu (bool): kept for backward compatibility; the smoothed targets are
            always created on the same device as ``inputs``.
    """
    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth class indices with shape (batch_size)
        """
        log_probs = self.logsoftmax(inputs)
        # One-hot on log_probs' device/dtype, so CPU and GPU inputs both work
        # regardless of use_gpu (previously a mismatch raised a device error).
        index = targets.unsqueeze(1).to(log_probs.device)
        one_hot = torch.zeros_like(log_probs).scatter_(1, index, 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (- smoothed * log_probs).mean(0).sum()
        return loss


# ---- /config/defaults.py ----
from yacs.config import CfgNode as CN

# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by a _TRAIN for a training parameter,
# or _TEST for a test-specific parameter.
# For example, the number of images during training will be
# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
# IMAGES_PER_BATCH_TEST

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------
# NOTE: the order of the assignments below is the order in which keys appear
# when the config is printed or dumped.

_C = CN()

_C.MODEL = CN()
# Device used for training: "cuda" or "cpu"
_C.MODEL.DEVICE = "cuda"
# GPU id(s) to use, given as a string
_C.MODEL.DEVICE_ID = '0'
# Name of the backbone network
_C.MODEL.NAME = 'resnet50'
# Stride of the backbone's last stage
_C.MODEL.LAST_STRIDE = 1
# Path to the pretrained backbone weights
_C.MODEL.PRETRAIN_PATH = ''
# Initialize the backbone from an ImageNet-pretrained model ('imagenet') or
# the whole model from a self-trained checkpoint ('self')
_C.MODEL.PRETRAIN_CHOICE = 'imagenet'
# Whether to train with a BNNeck, options: 'bnneck' or 'no'
_C.MODEL.NECK = 'bnneck'
# Whether the training loss includes center loss, options: 'yes' or 'no'.
# A loss with center loss uses a different optimizer configuration.
_C.MODEL.IF_WITH_CENTER = 'no'
# Type of the metric loss
# options: 'triplet','cluster','triplet_cluster','center','range_center','triplet_center','triplet_range_center'
_C.MODEL.METRIC_LOSS_TYPE = 'triplet'
# For example, if the loss is cross entropy loss + triplet loss + center loss,
# set _C.MODEL.METRIC_LOSS_TYPE = 'triplet_center' and _C.MODEL.IF_WITH_CENTER = 'yes'

# Whether to train with label smoothing, options: 'on', 'off'
_C.MODEL.IF_LABELSMOOTH = 'on'


# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the image during training (presumably [height, width] - confirm with the transforms)
_C.INPUT.SIZE_TRAIN = [384, 128]
# Size of the image during test (same layout as SIZE_TRAIN)
_C.INPUT.SIZE_TEST = [384, 128]
# Probability of a random horizontal flip
_C.INPUT.PROB = 0.5
# Probability of random erasing
_C.INPUT.RE_PROB = 0.5
# Per-channel mean used for image normalization
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
# Per-channel standard deviation used for image normalization
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Padding size
_C.INPUT.PADDING = 10

# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# Name(s) of the training dataset(s).
# NOTE: ('market1501') has no trailing comma, so this is a plain str, not a
# tuple; consumers presumably expect a single name string - verify before changing.
_C.DATASETS.NAMES = ('market1501')
# Root directory of the datasets (also a plain str despite the parentheses)
_C.DATASETS.ROOT_DIR = ('./data')

# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Number of data loading worker processes
_C.DATALOADER.NUM_WORKERS = 8
# Sampler used for data loading
_C.DATALOADER.SAMPLER = 'softmax'
# Number of instances in one batch (presumably images sampled per identity - confirm with the sampler)
_C.DATALOADER.NUM_INSTANCE = 16

# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
_C.SOLVER = CN()
# Name of the optimizer
_C.SOLVER.OPTIMIZER_NAME = "Adam"
# Maximum number of epochs
_C.SOLVER.MAX_EPOCHS = 50
# Base learning rate
_C.SOLVER.BASE_LR = 3e-4
# Learning-rate factor for bias parameters (presumably a multiplier of BASE_LR)
_C.SOLVER.BIAS_LR_FACTOR = 2
# Momentum
_C.SOLVER.MOMENTUM = 0.9
# Margin of the triplet loss
_C.SOLVER.MARGIN = 0.3
# Margin of the cluster loss
_C.SOLVER.CLUSTER_MARGIN = 0.3
# Learning rate of the SGD optimizer that learns the center-loss centers
_C.SOLVER.CENTER_LR = 0.5
# Weight of the center loss in the total loss
_C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005
# Settings of the range loss
_C.SOLVER.RANGE_K = 2
_C.SOLVER.RANGE_MARGIN = 0.3
_C.SOLVER.RANGE_ALPHA = 0
_C.SOLVER.RANGE_BETA = 1
_C.SOLVER.RANGE_LOSS_WEIGHT = 1

# Settings of weight decay
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.
# decay rate of learning rate
_C.SOLVER.GAMMA = 0.1
# decay steps of learning rate
# NOTE(review): if steps are counted in epochs, 55 is never reached with MAX_EPOCHS = 50.
_C.SOLVER.STEPS = (30, 55)

# warm up factor
_C.SOLVER.WARMUP_FACTOR = 1.0 / 3
# iterations of warm up
_C.SOLVER.WARMUP_ITERS = 500
# method of warm up, option: 'constant','linear'
_C.SOLVER.WARMUP_METHOD = "linear"

# epoch period for saving checkpoints
_C.SOLVER.CHECKPOINT_PERIOD = 50
# iteration period for displaying the training log
_C.SOLVER.LOG_PERIOD = 100
# epoch period for validation
_C.SOLVER.EVAL_PERIOD = 50

# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
_C.SOLVER.IMS_PER_BATCH = 64

# ---------------------------------------------------------------------------- #
# Test
# ---------------------------------------------------------------------------- #
_C.TEST = CN()
# Number of images per batch during test
_C.TEST.IMS_PER_BATCH = 128
# Whether to test with re-ranking, options: 'yes','no'
_C.TEST.RE_RANKING = 'no'
# Path to trained model
_C.TEST.WEIGHT = ""
# Which BNNeck feature is used for test, before or after the BNNeck, options: 'before' or 'after'
_C.TEST.NECK_FEAT = 'after'
# Whether features are normalized before test; if yes, this is equivalent to cosine distance
_C.TEST.FEAT_NORM = 'yes'

# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# Path to checkpoint and saved log of trained model
_C.OUTPUT_DIR = ""


# ---- /modeling/backbones/squeezenet.py ----
import math

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.utils.model_zoo import load_url as load_state_dict_from_url

__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']

model_urls = {
    'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
    'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
}


class Fire(nn.Module):
    """Fire module: 1x1 squeeze conv, then parallel 1x1 and 3x3 expand convs concatenated on channels."""

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes):
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
                                   kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
                                   kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat([
            self.expand1x1_activation(self.expand1x1(x)),
            self.expand3x3_activation(self.expand3x3(x))
        ], 1)


class SqueezeNet(nn.Module):
    """SqueezeNet used as a backbone: ``forward`` returns the 512-channel ``features`` map.

    The ``classifier`` head is still built so ImageNet checkpoints load, but
    ``forward`` does not apply it.
    """

    def __init__(self, version='1_0', num_classes=1000):
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        if version == '1_0':
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == '1_1':
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            raise ValueError("Unsupported SqueezeNet version {version}: "
                             "1_0 or 1_1 expected".format(version=version))

        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        return x

    def load_param(self, model_path):
        """Copy weights from the checkpoint at *model_path*, skipping keys containing 'fc'.

        Loaded onto the CPU first; ``copy_`` moves values to each tensor's device.
        """
        param_dict = torch.load(model_path, map_location='cpu')
        own_state = self.state_dict()
        for name in param_dict:
            if 'fc' in name:
                continue
            own_state[name].copy_(param_dict[name])

    def random_init(self):
        """He-normal init for convolutions; BatchNorm weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def _squeezenet(version, pretrained, progress, **kwargs):
    """Build a SqueezeNet of *version*, optionally loading ImageNet weights."""
    model = SqueezeNet(version, **kwargs)
    if pretrained:
        arch = 'squeezenet' + version
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def squeezenet1_0(pretrained=False, progress=True, **kwargs):
    r"""SqueezeNet model architecture from the paper "SqueezeNet: AlexNet-level
    accuracy with 50x fewer parameters and <0.5MB model size" (arXiv:1602.07360).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _squeezenet('1_0', pretrained, progress, **kwargs)


def squeezenet1_1(pretrained=False, progress=True, **kwargs):
    r"""SqueezeNet 1.1 model from the official SqueezeNet repo.

    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
    than SqueezeNet 1.0, without sacrificing accuracy.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _squeezenet('1_1', pretrained, progress, **kwargs)


# These used to be (mis-indented) module-level functions taking ``self``;
# keep the module-level names for backward compatibility.
load_param = SqueezeNet.load_param
random_init = SqueezeNet.random_init


# ---- /data/datasets/vehicleid.py ----
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import glob
import re
import pdb
import os.path as osp
import numpy as np
import random

from .bases import BaseImageDataset
from collections import Counter


class Vehicleid(BaseImageDataset):
    """VehicleID dataset.

    Expected layout under ``<root>/vehicleid``::

        image/<name>.jpg
        train_test_split/train_list.txt      lines "<name> ... <vehicle id>"
        train_test_split/test_list_800.txt   same format, test split

    * train: every training image, vehicle ids relabelled to 0..K-1.
    * query: every test image (test ids relabelled the same way).
    * gallery: the first listed image of each of the first 800 test identities.

    Each image's "camera id" is its position in its list, so all camids are distinct.
    """
    dataset_dir = 'vehicleid'

    def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs):
        super(Vehicleid, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.all_dir = osp.join(self.dataset_dir, 'image')
        self.file_dir = osp.join(self.dataset_dir, 'train_test_split')

        train_names, train_vids = self._read_split('train_list.txt')
        train_pid2label = {pid: label for label, pid in enumerate(set(train_vids))}
        train = [(self._img_path(name), train_pid2label[vid], camid)
                 for camid, (name, vid) in enumerate(zip(train_names, train_vids))]

        test_names, test_vids = self._read_split('test_list_800.txt')
        test_pid2label = {pid: label for label, pid in enumerate(set(test_vids))}
        query = [(self._img_path(name), test_pid2label[vid], camid)
                 for camid, (name, vid) in enumerate(zip(test_names, test_vids))]

        # Gallery: first image of each identity. Jumping ahead by the identity's
        # image count assumes one identity's images are listed contiguously.
        vid_counts = Counter(test_vids)
        gallery = []
        idx = 0
        for _ in range(800):
            gallery.append(query[idx])
            idx += vid_counts[test_vids[idx]]

        if verbose:
            print("=> Vehicleid loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _read_split(self, list_name):
        """Return ``(names, vehicle_ids)`` read from ``train_test_split/<list_name>``.

        For each non-blank line, the first whitespace-separated token is the image
        name and the last one is the integer vehicle id.
        """
        names, vids = [], []
        with open(osp.join(self.file_dir, list_name)) as split_file:
            for line in split_file:
                tokens = line.split()
                if not tokens:
                    continue  # tolerate blank lines, e.g. a trailing newline
                names.append(tokens[0])
                vids.append(int(tokens[-1]))
        return names, vids

    def _img_path(self, name):
        """Path of image *name* inside the ``image`` directory."""
        return '%s/%s.jpg' % (self.all_dir, name)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        # Previously referenced train_dir/query_dir/gallery_dir, which this class never sets.
        for path in (self.dataset_dir, self.all_dir, self.file_dir):
            if not osp.exists(path):
                raise RuntimeError("'{}' is not available".format(path))

    def _process_dir(self, dir_path, relabel=False):
        # NOTE(review): unused by __init__; its id ranges (pid <= 776, camid <= 20)
        # look copied from a VeRi loader rather than VehicleID.
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c([-\d]+)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            assert 1 <= pid <= 776  # pid == 0 means background
            assert 1 <= camid <= 20
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]
            dataset.append((img_path, pid, camid))
        return dataset


# ---- /train.py ----
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import argparse
import os
import sys
import torch
from torch.backends import cudnn
import pdb
from torchcontrib.optim import SWA
| 16 | sys.path.append('.') 17 | from config import cfg 18 | from data import make_data_loader 19 | from engine.trainer import do_train, do_train_with_center 20 | from modeling import build_model 21 | from layers import make_loss, make_loss_with_center 22 | from solver import make_optimizer, make_optimizer_with_center, WarmupMultiStepLR 23 | 24 | from utils.logger import setup_logger 25 | 26 | def train(cfg): 27 | # prepare dataset 28 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) 29 | #pdb.set_trace() 30 | # prepare model 31 | model = build_model(cfg, num_classes) 32 | #pdb.set_trace() 33 | if cfg.MODEL.IF_WITH_CENTER == 'no': 34 | print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE) 35 | optimizer = make_optimizer(cfg, model) 36 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 37 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 38 | 39 | loss_func = make_loss(cfg, num_classes) # modified by gu 40 | #pdb.set_trace() 41 | # Add for using self trained model 42 | if cfg.MODEL.PRETRAIN_CHOICE == 'self': 43 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1]) 44 | print('Start epoch:', start_epoch) 45 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') 46 | print('Path to the checkpoint of optimizer:', path_to_optimizer) 47 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH)) 48 | optimizer.load_state_dict(torch.load(path_to_optimizer)) 49 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 50 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch) 51 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet': 52 | start_epoch = 0 53 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 54 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 55 | else: 56 | print('Only support 
pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE)) 57 | 58 | arguments = {} 59 | 60 | do_train( 61 | cfg, 62 | model, 63 | train_loader, 64 | val_loader, 65 | optimizer, 66 | scheduler, # modify for using self trained model 67 | loss_func, 68 | num_query, 69 | start_epoch # add for using self trained model 70 | ) 71 | #optimizer.swap_swa_sgd() 72 | elif cfg.MODEL.IF_WITH_CENTER == 'yes': 73 | print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE) 74 | loss_func, center_criterion = make_loss_with_center(cfg, num_classes) # modified by gu 75 | optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion) 76 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 77 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 78 | 79 | arguments = {} 80 | 81 | # Add for using self trained model 82 | if cfg.MODEL.PRETRAIN_CHOICE == 'self': 83 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1]) 84 | print('Start epoch:', start_epoch) 85 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') 86 | print('Path to the checkpoint of optimizer:', path_to_optimizer) 87 | path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center') 88 | print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center) 89 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH)) 90 | optimizer.load_state_dict(torch.load(path_to_optimizer)) 91 | optimizer_center.load_state_dict(torch.load(path_to_optimizer_center)) 92 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 93 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch) 94 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet': 95 | start_epoch = 0 96 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 97 | 
cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 98 | else: 99 | print('Only support pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE)) 100 | 101 | do_train_with_center( 102 | cfg, 103 | model, 104 | center_criterion, 105 | train_loader, 106 | val_loader, 107 | optimizer, 108 | optimizer_center, 109 | scheduler, # modify for using self trained model 110 | loss_func, 111 | num_query, 112 | start_epoch # add for using self trained model 113 | ) 114 | else: 115 | print("Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n".format(cfg.MODEL.IF_WITH_CENTER)) 116 | 117 | 118 | def main(): 119 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 120 | parser.add_argument( 121 | "--config_file", default="", help="path to config file", type=str 122 | ) 123 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 124 | nargs=argparse.REMAINDER) 125 | 126 | args = parser.parse_args() 127 | 128 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 129 | 130 | if args.config_file != "": 131 | cfg.merge_from_file(args.config_file) 132 | cfg.merge_from_list(args.opts) 133 | cfg.freeze() 134 | 135 | output_dir = cfg.OUTPUT_DIR 136 | if output_dir and not os.path.exists(output_dir): 137 | os.makedirs(output_dir) 138 | 139 | logger = setup_logger("reid_baseline", output_dir, 0) 140 | logger.info("Using {} GPUS".format(num_gpus)) 141 | logger.info(args) 142 | 143 | if args.config_file != "": 144 | logger.info("Loaded configuration file {}".format(args.config_file)) 145 | with open(args.config_file, 'r') as cf: 146 | config_str = "\n" + cf.read() 147 | logger.info(config_str) 148 | logger.info("Running with config:\n{}".format(cfg)) 149 | 150 | if cfg.MODEL.DEVICE == "cuda": 151 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID # new add by gu 152 | cudnn.benchmark = True 153 | train(cfg) 154 | 155 | 156 | if __name__ == 
'__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /modeling/backbones/mobilenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 4 | 5 | 6 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 7 | 8 | 9 | model_urls = { 10 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 11 | } 12 | 13 | 14 | def _make_divisible(v, divisor, min_value=None): 15 | """ 16 | This function is taken from the original tf repo. 17 | It ensures that all layers have a channel number that is divisible by 8 18 | It can be seen here: 19 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 20 | :param v: 21 | :param divisor: 22 | :param min_value: 23 | :return: 24 | """ 25 | if min_value is None: 26 | min_value = divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 
29 | if new_v < 0.9 * v: 30 | new_v += divisor 31 | return new_v 32 | 33 | 34 | class ConvBNReLU(nn.Sequential): 35 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 36 | padding = (kernel_size - 1) // 2 37 | super(ConvBNReLU, self).__init__( 38 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 39 | nn.BatchNorm2d(out_planes), 40 | nn.ReLU6(inplace=True) 41 | ) 42 | 43 | 44 | class InvertedResidual(nn.Module): 45 | def __init__(self, inp, oup, stride, expand_ratio): 46 | super(InvertedResidual, self).__init__() 47 | self.stride = stride 48 | assert stride in [1, 2] 49 | 50 | hidden_dim = int(round(inp * expand_ratio)) 51 | self.use_res_connect = self.stride == 1 and inp == oup 52 | 53 | layers = [] 54 | if expand_ratio != 1: 55 | # pw 56 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 57 | layers.extend([ 58 | # dw 59 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 60 | # pw-linear 61 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 62 | nn.BatchNorm2d(oup), 63 | ]) 64 | self.conv = nn.Sequential(*layers) 65 | 66 | def forward(self, x): 67 | if self.use_res_connect: 68 | return x + self.conv(x) 69 | else: 70 | return self.conv(x) 71 | 72 | 73 | class MobileNetV2(nn.Module): 74 | def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8): 75 | """ 76 | MobileNet V2 main class 77 | Args: 78 | num_classes (int): Number of classes 79 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 80 | inverted_residual_setting: Network structure 81 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 82 | Set to 1 to turn off rounding 83 | """ 84 | super(MobileNetV2, self).__init__() 85 | block = InvertedResidual 86 | input_channel = 32 87 | last_channel = 1280 88 | 89 | if inverted_residual_setting is None: 90 | inverted_residual_setting = 
[ 91 | # t, c, n, s 92 | [1, 16, 1, 1], 93 | [6, 24, 2, 2], 94 | [6, 32, 3, 2], 95 | [6, 64, 4, 2], 96 | [6, 96, 3, 1], 97 | [6, 160, 3, 2], 98 | [6, 320, 1, 1], 99 | ] 100 | 101 | # only check the first element, assuming user knows t,c,n,s are required 102 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 103 | raise ValueError("inverted_residual_setting should be non-empty " 104 | "or a 4-element list, got {}".format(inverted_residual_setting)) 105 | 106 | # building first layer 107 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 108 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 109 | features = [ConvBNReLU(3, input_channel, stride=2)] 110 | # building inverted residual blocks 111 | for t, c, n, s in inverted_residual_setting: 112 | output_channel = _make_divisible(c * width_mult, round_nearest) 113 | for i in range(n): 114 | stride = s if i == 0 else 1 115 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 116 | input_channel = output_channel 117 | # building last several layers 118 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 119 | # make it nn.Sequential 120 | self.features = nn.Sequential(*features) 121 | 122 | # building classifier 123 | self.classifier = nn.Sequential( 124 | nn.Dropout(0.2), 125 | nn.Linear(self.last_channel, num_classes), 126 | ) 127 | 128 | # weight initialization 129 | for m in self.modules(): 130 | if isinstance(m, nn.Conv2d): 131 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 132 | if m.bias is not None: 133 | nn.init.zeros_(m.bias) 134 | elif isinstance(m, nn.BatchNorm2d): 135 | nn.init.ones_(m.weight) 136 | nn.init.zeros_(m.bias) 137 | elif isinstance(m, nn.Linear): 138 | nn.init.normal_(m.weight, 0, 0.01) 139 | nn.init.zeros_(m.bias) 140 | 141 | def forward(self, x): 142 | x = self.features(x) 143 | # x = x.mean([2, 3]) 144 | # x = self.classifier(x) 145 | 
return x 146 | 147 | 148 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 149 | """ 150 | Constructs a MobileNetV2 architecture from 151 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 152 | Args: 153 | pretrained (bool): If True, returns a model pre-trained on ImageNet 154 | progress (bool): If True, displays a progress bar of the download to stderr 155 | """ 156 | model = MobileNetV2(**kwargs) 157 | if pretrained: 158 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 159 | progress=progress) 160 | model.load_state_dict(state_dict) 161 | return model 162 | 163 | def load_param(self, model_path): 164 | param_dict = torch.load(model_path) 165 | for i in param_dict: 166 | if 'fc' in i: 167 | continue 168 | self.state_dict()[i].copy_(param_dict[i]) 169 | 170 | def random_init(self): 171 | for m in self.modules(): 172 | if isinstance(m, nn.Conv2d): 173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 174 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 175 | elif isinstance(m, nn.BatchNorm2d): 176 | m.weight.data.fill_(1) 177 | m.bias.data.zero_() -------------------------------------------------------------------------------- /modeling/backbones/ecanet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | # import torch.utils.model_zoo as model_zoo 4 | from .eca_module import eca_layer 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class ECABasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None, k_size=3): 17 | super(ECABasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes, 1) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.eca = eca_layer(planes, k_size) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | out = self.eca(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class ECABottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None, k_size=3): 50 | super(ECABottleneck, self).__init__() 51 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 52 | self.bn1 = nn.BatchNorm2d(planes) 53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 54 | padding=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv3 = nn.Conv2d(planes, planes * 4, 
kernel_size=1, bias=False) 57 | self.bn3 = nn.BatchNorm2d(planes * 4) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.eca = eca_layer(planes * 4, k_size) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | out = self.eca(out) 77 | 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out += residual 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class ResNet(nn.Module): 88 | 89 | def __init__(self, block, layers, num_classes=1000, k_size=[3, 3, 3, 3]): 90 | self.inplanes = 64 91 | super(ResNet, self).__init__() 92 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 93 | bias=False) 94 | self.bn1 = nn.BatchNorm2d(64) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 97 | self.layer1 = self._make_layer(block, 64, layers[0], int(k_size[0])) 98 | self.layer2 = self._make_layer(block, 128, layers[1], int(k_size[1]), stride=2) 99 | self.layer3 = self._make_layer(block, 256, layers[2], int(k_size[2]), stride=2) 100 | self.layer4 = self._make_layer(block, 512, layers[3], int(k_size[3]), stride=2) 101 | self.avgpool = nn.AvgPool2d(7, stride=1) 102 | self.fc = nn.Linear(512 * block.expansion, num_classes) 103 | 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 107 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 108 | elif isinstance(m, nn.BatchNorm2d): 109 | m.weight.data.fill_(1) 110 | m.bias.data.zero_() 111 | 112 | def _make_layer(self, block, planes, blocks, k_size, stride=1): 113 | downsample = None 114 | if stride != 1 or self.inplanes != planes * block.expansion: 115 | downsample = nn.Sequential( 116 | nn.Conv2d(self.inplanes, planes * block.expansion, 117 | kernel_size=1, stride=stride, bias=False), 118 | nn.BatchNorm2d(planes * block.expansion), 119 | ) 120 | 121 | layers = [] 122 | layers.append(block(self.inplanes, planes, stride, downsample, k_size)) 123 | self.inplanes = planes * block.expansion 124 | for i in range(1, blocks): 125 | layers.append(block(self.inplanes, planes, k_size=k_size)) 126 | 127 | return nn.Sequential(*layers) 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | x = self.bn1(x) 132 | x = self.relu(x) 133 | x = self.maxpool(x) 134 | 135 | x = self.layer1(x) 136 | x = self.layer2(x) 137 | x = self.layer3(x) 138 | x = self.layer4(x) 139 | 140 | # x = self.avgpool(x) 141 | # x = x.view(x.size(0), -1) 142 | # x = self.fc(x) 143 | 144 | return x 145 | 146 | 147 | def eca_resnet18(k_size=[3, 3, 3, 3], num_classes=1_000, pretrained=False): 148 | """Constructs a ResNet-18 model. 149 | Args: 150 | k_size: Adaptive selection of kernel size 151 | pretrained (bool): If True, returns a model pre-trained on ImageNet 152 | num_classes:The classes of classification 153 | """ 154 | model = ResNet(ECABasicBlock, [2, 2, 2, 2], num_classes=num_classes, k_size=[3, 3, 3, 3]) 155 | model.avgpool = nn.AdaptiveAvgPool2d(1) 156 | return model 157 | 158 | 159 | def eca_resnet34(k_size=[3, 3, 3, 3], num_classes=1_000, pretrained=False): 160 | """Constructs a ResNet-34 model. 
161 | Args: 162 | k_size: Adaptive selection of kernel size 163 | pretrained (bool): If True, returns a model pre-trained on ImageNet 164 | num_classes:The classes of classification 165 | """ 166 | model = ResNet(ECABasicBlock, [3, 4, 6, 3], num_classes=num_classes, k_size=k_size) 167 | model.avgpool = nn.AdaptiveAvgPool2d(1) 168 | return model 169 | 170 | 171 | def eca_resnet50(k_size=[3, 3, 3, 3], num_classes=1000, pretrained=False): 172 | """Constructs a ResNet-50 model. 173 | Args: 174 | k_size: Adaptive selection of kernel size 175 | num_classes:The classes of classification 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | """ 178 | print("Constructing eca_resnet50......") 179 | model = ResNet(ECABottleneck, [3, 4, 6, 3], num_classes=num_classes, k_size=k_size) 180 | # model.avgpool = nn.AdaptiveAvgPool2d(1) 181 | return model 182 | 183 | 184 | def eca_resnet101(k_size=[3, 3, 3, 3], num_classes=1_000, pretrained=False): 185 | """Constructs a ResNet-101 model. 186 | Args: 187 | k_size: Adaptive selection of kernel size 188 | num_classes:The classes of classification 189 | pretrained (bool): If True, returns a model pre-trained on ImageNet 190 | """ 191 | model = ResNet(ECABottleneck, [3, 4, 23, 3], num_classes=num_classes, k_size=k_size) 192 | model.avgpool = nn.AdaptiveAvgPool2d(1) 193 | return model 194 | 195 | 196 | def eca_resnet152(k_size=[3, 3, 3, 3], num_classes=1_000, pretrained=False): 197 | """Constructs a ResNet-152 model. 
198 | Args: 199 | k_size: Adaptive selection of kernel size 200 | num_classes:The classes of classification 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(ECABottleneck, [3, 8, 36, 3], num_classes=num_classes, k_size=k_size) 204 | model.avgpool = nn.AdaptiveAvgPool2d(1) 205 | return model -------------------------------------------------------------------------------- /modeling/baseline1.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from .backbones.resnet import ResNet, BasicBlock, Bottleneck 11 | from .backbones.senet import SENet, SEResNetBottleneck, SEBottleneck, SEResNeXtBottleneck 12 | 13 | 14 | def weights_init_kaiming(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Linear') != -1: 17 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 18 | nn.init.constant_(m.bias, 0.0) 19 | elif classname.find('Conv') != -1: 20 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 21 | if m.bias is not None: 22 | nn.init.constant_(m.bias, 0.0) 23 | elif classname.find('BatchNorm') != -1: 24 | if m.affine: 25 | nn.init.constant_(m.weight, 1.0) 26 | nn.init.constant_(m.bias, 0.0) 27 | 28 | 29 | def weights_init_classifier(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('Linear') != -1: 32 | nn.init.normal_(m.weight, std=0.001) 33 | if m.bias: 34 | nn.init.constant_(m.bias, 0.0) 35 | 36 | 37 | class Baseline(nn.Module): 38 | in_planes = 2048 39 | 40 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice): 41 | super(Baseline, self).__init__() 42 | if model_name == 'resnet18': 43 | self.in_planes = 512 44 | self.base = ResNet(last_stride=last_stride, 45 | block=BasicBlock, 46 | layers=[2, 2, 2, 2]) 47 | elif model_name == 'resnet34': 48 | 
self.in_planes = 512 49 | self.base = ResNet(last_stride=last_stride, 50 | block=BasicBlock, 51 | layers=[3, 4, 6, 3]) 52 | elif model_name == 'resnet50': 53 | self.base = ResNet(last_stride=last_stride, 54 | block=Bottleneck, 55 | layers=[3, 4, 6, 3]) 56 | elif model_name == 'resnet101': 57 | self.base = ResNet(last_stride=last_stride, 58 | block=Bottleneck, 59 | layers=[3, 4, 23, 3]) 60 | elif model_name == 'resnet152': 61 | self.base = ResNet(last_stride=last_stride, 62 | block=Bottleneck, 63 | layers=[3, 8, 36, 3]) 64 | 65 | elif model_name == 'se_resnet50': 66 | self.base = SENet(block=SEResNetBottleneck, 67 | layers=[3, 4, 6, 3], 68 | groups=1, 69 | reduction=16, 70 | dropout_p=None, 71 | inplanes=64, 72 | input_3x3=False, 73 | downsample_kernel_size=1, 74 | downsample_padding=0, 75 | last_stride=last_stride) 76 | elif model_name == 'se_resnet101': 77 | self.base = SENet(block=SEResNetBottleneck, 78 | layers=[3, 4, 23, 3], 79 | groups=1, 80 | reduction=16, 81 | dropout_p=None, 82 | inplanes=64, 83 | input_3x3=False, 84 | downsample_kernel_size=1, 85 | downsample_padding=0, 86 | last_stride=last_stride) 87 | elif model_name == 'se_resnet152': 88 | self.base = SENet(block=SEResNetBottleneck, 89 | layers=[3, 8, 36, 3], 90 | groups=1, 91 | reduction=16, 92 | dropout_p=None, 93 | inplanes=64, 94 | input_3x3=False, 95 | downsample_kernel_size=1, 96 | downsample_padding=0, 97 | last_stride=last_stride) 98 | elif model_name == 'se_resnext50': 99 | self.base = SENet(block=SEResNeXtBottleneck, 100 | layers=[3, 4, 6, 3], 101 | groups=32, 102 | reduction=16, 103 | dropout_p=None, 104 | inplanes=64, 105 | input_3x3=False, 106 | downsample_kernel_size=1, 107 | downsample_padding=0, 108 | last_stride=last_stride) 109 | elif model_name == 'se_resnext101': 110 | self.base = SENet(block=SEResNeXtBottleneck, 111 | layers=[3, 4, 23, 3], 112 | groups=32, 113 | reduction=16, 114 | dropout_p=None, 115 | inplanes=64, 116 | input_3x3=False, 117 | downsample_kernel_size=1, 118 | 
downsample_padding=0, 119 | last_stride=last_stride) 120 | elif model_name == 'senet154': 121 | self.base = SENet(block=SEBottleneck, 122 | layers=[3, 8, 36, 3], 123 | groups=64, 124 | reduction=16, 125 | dropout_p=0.2, 126 | last_stride=last_stride) 127 | 128 | if pretrain_choice == 'imagenet': 129 | self.base.load_param(model_path) 130 | print('Loading pretrained ImageNet model......') 131 | 132 | self.gap = nn.AdaptiveAvgPool2d(1) 133 | # self.gap = nn.AdaptiveMaxPool2d(1) 134 | self.num_classes = num_classes 135 | self.neck = neck 136 | self.neck_feat = neck_feat 137 | 138 | if self.neck == 'no': 139 | self.classifier = nn.Linear(self.in_planes, self.num_classes) 140 | # self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) # new add by luo 141 | # self.classifier.apply(weights_init_classifier) # new add by luo 142 | elif self.neck == 'bnneck': 143 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 144 | self.bottleneck.bias.requires_grad_(False) # no shift 145 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) 146 | 147 | self.bottleneck.apply(weights_init_kaiming) 148 | self.classifier.apply(weights_init_classifier) 149 | 150 | def forward(self, x): 151 | 152 | global_feat = self.gap(self.base(x)) # (b, 2048, 1, 1) 153 | global_feat = global_feat.view(global_feat.shape[0], -1) # flatten to (bs, 2048) 154 | 155 | if self.neck == 'no': 156 | feat = global_feat 157 | elif self.neck == 'bnneck': 158 | feat = self.bottleneck(global_feat) # normalize for angular softmax 159 | 160 | if self.training: 161 | cls_score = self.classifier(feat) 162 | return cls_score, global_feat # global feature for triplet loss 163 | else: 164 | if self.neck_feat == 'after': 165 | # print("Test with feature after BN") 166 | return feat 167 | else: 168 | # print("Test with feature before BN") 169 | return global_feat 170 | 171 | def load_param(self, trained_path): 172 | param_dict = torch.load(trained_path) 173 | for i in param_dict: 174 | if 
'classifier' in i: 175 | continue 176 | self.state_dict()[i].copy_(param_dict[i]) 177 | -------------------------------------------------------------------------------- /modeling/baseline2.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from .backbones.resnet import ResNet, BasicBlock, Bottleneck 11 | from .backbones.senet import SENet, SEResNetBottleneck, SEBottleneck, SEResNeXtBottleneck 12 | 13 | 14 | def weights_init_kaiming(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Linear') != -1: 17 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 18 | nn.init.constant_(m.bias, 0.0) 19 | elif classname.find('Conv') != -1: 20 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 21 | if m.bias is not None: 22 | nn.init.constant_(m.bias, 0.0) 23 | elif classname.find('BatchNorm') != -1: 24 | if m.affine: 25 | nn.init.constant_(m.weight, 1.0) 26 | nn.init.constant_(m.bias, 0.0) 27 | 28 | 29 | def weights_init_classifier(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('Linear') != -1: 32 | nn.init.normal_(m.weight, std=0.001) 33 | if m.bias: 34 | nn.init.constant_(m.bias, 0.0) 35 | 36 | 37 | class Baseline(nn.Module): 38 | in_planes = 2048 39 | 40 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice): 41 | super(Baseline, self).__init__() 42 | if model_name == 'resnet18': 43 | self.in_planes = 512 44 | self.base = ResNet(last_stride=last_stride, 45 | block=BasicBlock, 46 | layers=[2, 2, 2, 2]) 47 | elif model_name == 'resnet34': 48 | self.in_planes = 512 49 | self.base = ResNet(last_stride=last_stride, 50 | block=BasicBlock, 51 | layers=[3, 4, 6, 3]) 52 | elif model_name == 'resnet50': 53 | self.base1 = ResNet(last_stride=last_stride, 54 | block=Bottleneck, 55 | layers=[3, 4, 6, 3]) 56 | 
self.base2 = ResNet(last_stride=last_stride, 57 | block=Bottleneck, 58 | layers=[3, 4, 6, 3]) 59 | elif model_name == 'resnet101': 60 | self.base = ResNet(last_stride=last_stride, 61 | block=Bottleneck, 62 | layers=[3, 4, 23, 3]) 63 | elif model_name == 'resnet152': 64 | self.base = ResNet(last_stride=last_stride, 65 | block=Bottleneck, 66 | layers=[3, 8, 36, 3]) 67 | 68 | elif model_name == 'se_resnet50': 69 | self.base = SENet(block=SEResNetBottleneck, 70 | layers=[3, 4, 6, 3], 71 | groups=1, 72 | reduction=16, 73 | dropout_p=None, 74 | inplanes=64, 75 | input_3x3=False, 76 | downsample_kernel_size=1, 77 | downsample_padding=0, 78 | last_stride=last_stride) 79 | elif model_name == 'se_resnet101': 80 | self.base = SENet(block=SEResNetBottleneck, 81 | layers=[3, 4, 23, 3], 82 | groups=1, 83 | reduction=16, 84 | dropout_p=None, 85 | inplanes=64, 86 | input_3x3=False, 87 | downsample_kernel_size=1, 88 | downsample_padding=0, 89 | last_stride=last_stride) 90 | elif model_name == 'se_resnet152': 91 | self.base = SENet(block=SEResNetBottleneck, 92 | layers=[3, 8, 36, 3], 93 | groups=1, 94 | reduction=16, 95 | dropout_p=None, 96 | inplanes=64, 97 | input_3x3=False, 98 | downsample_kernel_size=1, 99 | downsample_padding=0, 100 | last_stride=last_stride) 101 | elif model_name == 'se_resnext50': 102 | self.base = SENet(block=SEResNeXtBottleneck, 103 | layers=[3, 4, 6, 3], 104 | groups=32, 105 | reduction=16, 106 | dropout_p=None, 107 | inplanes=64, 108 | input_3x3=False, 109 | downsample_kernel_size=1, 110 | downsample_padding=0, 111 | last_stride=last_stride) 112 | elif model_name == 'se_resnext101': 113 | self.base = SENet(block=SEResNeXtBottleneck, 114 | layers=[3, 4, 23, 3], 115 | groups=32, 116 | reduction=16, 117 | dropout_p=None, 118 | inplanes=64, 119 | input_3x3=False, 120 | downsample_kernel_size=1, 121 | downsample_padding=0, 122 | last_stride=last_stride) 123 | elif model_name == 'senet154': 124 | self.base = SENet(block=SEBottleneck, 125 | layers=[3, 8, 36, 3], 
126 | groups=64, 127 | reduction=16, 128 | dropout_p=0.2, 129 | last_stride=last_stride) 130 | 131 | if pretrain_choice == 'imagenet': 132 | self.base.load_param(model_path) 133 | print('Loading pretrained ImageNet model......') 134 | 135 | self.gap = nn.AdaptiveAvgPool2d(1) 136 | # self.gap = nn.AdaptiveMaxPool2d(1) 137 | self.num_classes = num_classes 138 | self.neck = neck 139 | self.neck_feat = neck_feat 140 | 141 | if self.neck == 'no': 142 | self.classifier = nn.Linear(self.in_planes, self.num_classes) 143 | # self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) # new add by luo 144 | # self.classifier.apply(weights_init_classifier) # new add by luo 145 | elif self.neck == 'bnneck': 146 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 147 | self.bottleneck.bias.requires_grad_(False) # no shift 148 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) 149 | 150 | self.bottleneck.apply(weights_init_kaiming) 151 | self.classifier.apply(weights_init_classifier) 152 | 153 | def forward(self, x1, x2): 154 | 155 | global_feat1 = self.gap(self.base1(x1)) # (b, 2048, 1, 1) 156 | global_feat1 = global_feat1.view(global_feat1.shape[0], -1) # flatten to (bs, 2048) 157 | global_feat2 = self.gap(self.base2(x2)) # (b, 2048, 1, 1) 158 | global_feat2 = global_feat2.view(global_feat2.shape[0], -1) # flatten to (bs, 2048) 159 | 160 | if self.neck == 'no': 161 | feat1 = global_feat1 162 | feat2 = global_feat2 163 | elif self.neck == 'bnneck': 164 | feat1 = self.bottleneck(global_feat1) # normalize for angular softmax 165 | feat2 = self.bottleneck(global_feat2) # normalize for angular softmax 166 | 167 | if self.training: 168 | cls_score1 = self.classifier(feat1) 169 | cls_score2 = self.classifier(feat2) 170 | return cls_score1, global_feat1, cls_score2, global_feat2 # global feature for triplet loss 171 | else: 172 | if self.neck_feat == 'after': 173 | # print("Test with feature after BN") 174 | return feat1, feat2 175 | else: 176 | # 
print("Test with feature before BN") 177 | return global_feat1, global_feat2 178 | 179 | def load_param(self, trained_path): 180 | param_dict = torch.load(trained_path) 181 | for i in param_dict: 182 | if 'classifier' in i: 183 | continue 184 | self.state_dict()[i].copy_(param_dict[i]) 185 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch.nn.functional as F 8 | 9 | from .triplet_loss import TripletLoss, CrossEntropyLabelSmooth 10 | from .cluster_loss import ClusterLoss 11 | from .center_loss import CenterLoss 12 | from .range_loss import RangeLoss 13 | 14 | 15 | def make_loss(cfg, num_classes): # modified by gu 16 | sampler = cfg.DATALOADER.SAMPLER 17 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 18 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 19 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'cluster': 20 | cluster = ClusterLoss(cfg.SOLVER.CLUSTER_MARGIN, True, True, cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE, cfg.DATALOADER.NUM_INSTANCE) 21 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_cluster': 22 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 23 | cluster = ClusterLoss(cfg.SOLVER.CLUSTER_MARGIN, True, True, cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE, cfg.DATALOADER.NUM_INSTANCE) 24 | else: 25 | print('expected METRIC_LOSS_TYPE should be triplet, cluster, triplet_cluster' 26 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 27 | 28 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 29 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo 30 | print("label smooth on, numclasses:", num_classes) 31 | 32 | if sampler == 'softmax': 33 | def loss_func(score, feat, target): 34 | return F.cross_entropy(score, target) 35 | elif cfg.DATALOADER.SAMPLER == 'triplet': 36 | 
def loss_func(score, feat, target): 37 | return triplet(feat, target)[0] 38 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet': 39 | def loss_func(score, feat, target): 40 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 41 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 42 | return xent(score, target) + triplet(feat, target)[0] # new add by luo, open label smooth 43 | else: 44 | return F.cross_entropy(score, target) + triplet(feat, target)[0] # new add by luo, no label smooth 45 | 46 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'cluster': 47 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 48 | return xent(score, target) + cluster(feat, target)[0] # new add by luo, open label smooth 49 | else: 50 | return F.cross_entropy(score, target) + cluster(feat, target)[0] # new add by luo, no label smooth 51 | 52 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_cluster': 53 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 54 | return xent(score, target) + triplet(feat, target)[0] + cluster(feat, target)[0] # new add by luo, open label smooth 55 | else: 56 | return F.cross_entropy(score, target) + triplet(feat, target)[0] + cluster(feat, target)[0] # new add by luo, no label smooth 57 | else: 58 | print('expected METRIC_LOSS_TYPE should be triplet, cluster, triplet_cluster,' 59 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 60 | else: 61 | print('expected sampler should be softmax, triplet or softmax_triplet, ' 62 | 'but got {}'.format(cfg.DATALOADER.SAMPLER)) 63 | return loss_func 64 | 65 | 66 | def make_loss_with_center(cfg, num_classes): # modified by gu 67 | if cfg.MODEL.NAME == 'resnet18' or cfg.MODEL.NAME == 'resnet34': 68 | feat_dim = 512 69 | else: 70 | feat_dim = 2048 71 | 72 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center': 73 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 74 | 75 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'range_center': 76 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center_range loss 77 | 
range_criterion = RangeLoss(k=cfg.SOLVER.RANGE_K, margin=cfg.SOLVER.RANGE_MARGIN, alpha=cfg.SOLVER.RANGE_ALPHA, 78 | beta=cfg.SOLVER.RANGE_BETA, ordered=True, use_gpu=True, 79 | ids_per_batch=cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE, 80 | imgs_per_id=cfg.DATALOADER.NUM_INSTANCE) 81 | 82 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center': 83 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 84 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 85 | 86 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_range_center': 87 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 88 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center_range loss 89 | range_criterion = RangeLoss(k=cfg.SOLVER.RANGE_K, margin=cfg.SOLVER.RANGE_MARGIN, alpha=cfg.SOLVER.RANGE_ALPHA, 90 | beta=cfg.SOLVER.RANGE_BETA, ordered=True, use_gpu=True, 91 | ids_per_batch=cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE, 92 | imgs_per_id=cfg.DATALOADER.NUM_INSTANCE) 93 | else: 94 | print('expected METRIC_LOSS_TYPE with center should be center, ' 95 | 'range_center,triplet_center, triplet_range_center ' 96 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 97 | 98 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 99 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo 100 | print("label smooth on, numclasses:", num_classes) 101 | 102 | def loss_func(score, feat, target): 103 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center': 104 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 105 | return xent(score, target) + \ 106 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) # new add by luo, open label smooth 107 | else: 108 | return F.cross_entropy(score, target) + \ 109 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) # new add by luo, no label smooth 110 | 111 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'range_center': 112 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 113 
| return xent(score, target) + \ 114 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) + \ 115 | cfg.SOLVER.RANGE_LOSS_WEIGHT * range_criterion(feat, target)[0] # new add by luo, open label smooth 116 | else: 117 | return F.cross_entropy(score, target) + \ 118 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) + \ 119 | cfg.SOLVER.RANGE_LOSS_WEIGHT * range_criterion(feat, target)[0] # new add by luo, no label smooth 120 | 121 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center': 122 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 123 | return xent(score, target) + \ 124 | triplet(feat, target)[0] + \ 125 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) # new add by luo, open label smooth 126 | else: 127 | return F.cross_entropy(score, target) + \ 128 | triplet(feat, target)[0] + \ 129 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) # new add by luo, no label smooth 130 | 131 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_range_center': 132 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 133 | return xent(score, target) + \ 134 | triplet(feat, target)[0] + \ 135 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) + \ 136 | cfg.SOLVER.RANGE_LOSS_WEIGHT * range_criterion(feat, target)[0] # new add by luo, open label smooth 137 | else: 138 | return F.cross_entropy(score, target) + \ 139 | triplet(feat, target)[0] + \ 140 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) + \ 141 | cfg.SOLVER.RANGE_LOSS_WEIGHT * range_criterion(feat, target)[0] # new add by luo, no label smooth 142 | 143 | else: 144 | print('expected METRIC_LOSS_TYPE with center should be center,' 145 | ' range_center, triplet_center, triplet_range_center ' 146 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 147 | return loss_func, center_criterion -------------------------------------------------------------------------------- /layers/range_loss.py: -------------------------------------------------------------------------------- 
1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class RangeLoss(nn.Module): 8 | """ 9 | Range_loss = alpha * intra_class_loss + beta * inter_class_loss 10 | intra_class_loss is the harmonic mean value of the top_k largest distances beturn intra_class_pairs 11 | inter_class_loss is the shortest distance between different class centers 12 | """ 13 | def __init__(self, k=2, margin=0.1, alpha=0.5, beta=0.5, use_gpu=True, ordered=True, ids_per_batch=32, imgs_per_id=4): 14 | super(RangeLoss, self).__init__() 15 | self.use_gpu = use_gpu 16 | self.margin = margin 17 | self.k = k 18 | self.alpha = alpha 19 | self.beta = beta 20 | self.ordered = ordered 21 | self.ids_per_batch = ids_per_batch 22 | self.imgs_per_id = imgs_per_id 23 | 24 | def _pairwise_distance(self, features): 25 | """ 26 | Args: 27 | features: prediction matrix (before softmax) with shape (batch_size, feature_dim) 28 | Return: 29 | pairwise distance matrix with shape(batch_size, batch_size) 30 | """ 31 | n = features.size(0) 32 | dist = torch.pow(features, 2).sum(dim=1, keepdim=True).expand(n, n) 33 | dist = dist + dist.t() 34 | dist.addmm_(1, -2, features, features.t()) 35 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 36 | return dist 37 | 38 | def _compute_top_k(self, features): 39 | """ 40 | Args: 41 | features: prediction matrix (before softmax) with shape (batch_size, feature_dim) 42 | Return: 43 | top_k largest distances 44 | """ 45 | # reading the codes below can help understand better 46 | ''' 47 | dist_array_2 = self._pairwise_distance(features) 48 | n = features.size(0) 49 | mask = torch.zeros(n, n) 50 | if self.use_gpu: mask=mask.cuda() 51 | for i in range(0, n): 52 | for j in range(i+1, n): 53 | mask[i, j] += 1 54 | dist_array_2 = dist_array_2 * mask 55 | dist_array_2 = dist_array_2.view(1, -1) 56 | dist_array_2 = dist_array_2[torch.gt(dist_array_2, 0)] 57 | top_k_2 = dist_array_2.sort()[0][-self.k:] 58 | print(top_k_2) 59 | 
''' 60 | dist_array = self._pairwise_distance(features) 61 | dist_array = dist_array.view(1, -1) 62 | top_k = dist_array.sort()[0][0, -self.k * 2::2] # Because there are 2 same value of same feature pair in the dist_array 63 | # print('top k intra class dist:', top_k) 64 | return top_k 65 | 66 | def _compute_min_dist(self, center_features): 67 | """ 68 | Args: 69 | center_features: center matrix (before softmax) with shape (center_number, center_dim) 70 | Return: 71 | minimum center distance 72 | """ 73 | ''' 74 | # reading codes below can help understand better 75 | dist_array = self._pairwise_distance(center_features) 76 | n = center_features.size(0) 77 | mask = torch.zeros(n, n) 78 | if self.use_gpu: mask=mask.cuda() 79 | for i in range(0, n): 80 | for j in range(i + 1, n): 81 | mask[i, j] += 1 82 | dist_array *= mask 83 | dist_array = dist_array.view(1, -1) 84 | dist_array = dist_array[torch.gt(dist_array, 0)] 85 | min_inter_class_dist = dist_array.min() 86 | print(min_inter_class_dist) 87 | ''' 88 | n = center_features.size(0) 89 | dist_array2 = self._pairwise_distance(center_features) 90 | min_inter_class_dist2 = dist_array2.view(1, -1).sort()[0][0][n] # exclude self compare, the first one is the min_inter_class_dist 91 | return min_inter_class_dist2 92 | 93 | def _calculate_centers(self, features, targets, ordered, ids_per_batch, imgs_per_id): 94 | """ 95 | Args: 96 | features: prediction matrix (before softmax) with shape (batch_size, feature_dim) 97 | targets: ground truth labels with shape (batch_size) 98 | ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id. 
99 | ids_per_batch: num of different ids per batch 100 | imgs_per_id: num of images per id 101 | Return: 102 | center_features: center matrix (before softmax) with shape (center_number, center_dim) 103 | """ 104 | if self.use_gpu: 105 | if ordered: 106 | if targets.size(0) == ids_per_batch * imgs_per_id: 107 | unique_labels = targets[0:targets.size(0):imgs_per_id] 108 | else: 109 | unique_labels = targets.cpu().unique().cuda() 110 | else: 111 | unique_labels = targets.cpu().unique().cuda() 112 | else: 113 | if ordered: 114 | if targets.size(0) == ids_per_batch * imgs_per_id: 115 | unique_labels = targets[0:targets.size(0):imgs_per_id] 116 | else: 117 | unique_labels = targets.unique() 118 | else: 119 | unique_labels = targets.unique() 120 | 121 | center_features = torch.zeros(unique_labels.size(0), features.size(1)) 122 | if self.use_gpu: 123 | center_features = center_features.cuda() 124 | 125 | for i in range(unique_labels.size(0)): 126 | label = unique_labels[i] 127 | same_class_features = features[targets == label] 128 | center_features[i] = same_class_features.mean(dim=0) 129 | return center_features 130 | 131 | def _inter_class_loss(self, features, targets, ordered, ids_per_batch, imgs_per_id): 132 | """ 133 | Args: 134 | features: prediction matrix (before softmax) with shape (batch_size, feature_dim) 135 | targets: ground truth labels with shape (batch_size) 136 | margin: inter class ringe loss margin 137 | ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id. 
138 | ids_per_batch: num of different ids per batch 139 | imgs_per_id: num of images per id 140 | Return: 141 | inter_class_loss 142 | """ 143 | center_features = self._calculate_centers(features, targets, ordered, ids_per_batch, imgs_per_id) 144 | min_inter_class_center_distance = self._compute_min_dist(center_features) 145 | # print('min_inter_class_center_dist:', min_inter_class_center_distance) 146 | return torch.relu(self.margin - min_inter_class_center_distance) 147 | 148 | def _intra_class_loss(self, features, targets, ordered, ids_per_batch, imgs_per_id): 149 | """ 150 | Args: 151 | features: prediction matrix (before softmax) with shape (batch_size, feature_dim) 152 | targets: ground truth labels with shape (batch_size) 153 | ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id. 154 | ids_per_batch: num of different ids per batch 155 | imgs_per_id: num of images per id 156 | Return: 157 | intra_class_loss 158 | """ 159 | if self.use_gpu: 160 | if ordered: 161 | if targets.size(0) == ids_per_batch * imgs_per_id: 162 | unique_labels = targets[0:targets.size(0):imgs_per_id] 163 | else: 164 | unique_labels = targets.cpu().unique().cuda() 165 | else: 166 | unique_labels = targets.cpu().unique().cuda() 167 | else: 168 | if ordered: 169 | if targets.size(0) == ids_per_batch * imgs_per_id: 170 | unique_labels = targets[0:targets.size(0):imgs_per_id] 171 | else: 172 | unique_labels = targets.unique() 173 | else: 174 | unique_labels = targets.unique() 175 | 176 | intra_distance = torch.zeros(unique_labels.size(0)) 177 | if self.use_gpu: 178 | intra_distance = intra_distance.cuda() 179 | 180 | for i in range(unique_labels.size(0)): 181 | label = unique_labels[i] 182 | same_class_distances = 1.0 / self._compute_top_k(features[targets == label]) 183 | intra_distance[i] = self.k / torch.sum(same_class_distances) 184 | # print('intra_distace:', intra_distance) 185 | return 
torch.sum(intra_distance) 186 | 187 | def _range_loss(self, features, targets, ordered, ids_per_batch, imgs_per_id): 188 | """ 189 | Args: 190 | features: prediction matrix (before softmax) with shape (batch_size, feature_dim) 191 | targets: ground truth labels with shape (batch_size) 192 | ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id. 193 | ids_per_batch: num of different ids per batch 194 | imgs_per_id: num of images per id 195 | Return: 196 | range_loss 197 | """ 198 | inter_class_loss = self._inter_class_loss(features, targets, ordered, ids_per_batch, imgs_per_id) 199 | intra_class_loss = self._intra_class_loss(features, targets, ordered, ids_per_batch, imgs_per_id) 200 | range_loss = self.alpha * intra_class_loss + self.beta * inter_class_loss 201 | return range_loss, intra_class_loss, inter_class_loss 202 | 203 | def forward(self, features, targets): 204 | """ 205 | Args: 206 | features: prediction matrix (before softmax) with shape (batch_size, feature_dim) 207 | targets: ground truth labels with shape (batch_size) 208 | ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id. 
209 | ids_per_batch: num of different ids per batch 210 | imgs_per_id: num of images per id 211 | Return: 212 | range_loss 213 | """ 214 | assert features.size(0) == targets.size(0), "features.size(0) is not equal to targets.size(0)" 215 | if self.use_gpu: 216 | features = features.cuda() 217 | targets = targets.cuda() 218 | 219 | range_loss, intra_class_loss, inter_class_loss = self._range_loss(features, targets, self.ordered, self.ids_per_batch, self.imgs_per_id) 220 | return range_loss, intra_class_loss, inter_class_loss 221 | 222 | 223 | if __name__ == '__main__': 224 | use_gpu = False 225 | range_loss = RangeLoss(use_gpu=use_gpu, ids_per_batch=4, imgs_per_id=4) 226 | features = torch.rand(16, 2048) 227 | targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]) 228 | if use_gpu: 229 | features = torch.rand(16, 2048).cuda() 230 | targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).cuda() 231 | loss = range_loss(features, targets) 232 | print(loss) 233 | -------------------------------------------------------------------------------- /layers/cluster_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class ClusterLoss(nn.Module): 9 | def __init__(self, margin=10, use_gpu=True, ordered=True, ids_per_batch=16, imgs_per_id=4): 10 | super(ClusterLoss, self).__init__() 11 | self.use_gpu = use_gpu 12 | self.margin = margin 13 | self.ordered = ordered 14 | self.ids_per_batch = ids_per_batch 15 | self.imgs_per_id = imgs_per_id 16 | 17 | def _euclidean_dist(self, x, y): 18 | """ 19 | Args: 20 | x: pytorch Variable, with shape [m, d] 21 | y: pytorch Variable, with shape [n, d] 22 | Returns: 23 | dist: pytorch Variable, with shape [m, n] 24 | """ 25 | m, n = x.size(0), y.size(0) 26 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 27 | yy = torch.pow(y, 
2).sum(1, keepdim=True).expand(n, m).t() 28 | dist = xx + yy 29 | dist.addmm_(1, -2, x, y.t()) 30 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 31 | return dist 32 | 33 | def _cluster_loss(self, features, targets, ordered=True, ids_per_batch=16, imgs_per_id=4): 34 | """ 35 | Args: 36 | features: prediction matrix (before softmax) with shape (batch_size, feature_dim) 37 | targets: ground truth labels with shape (batch_size) 38 | ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id. 39 | ids_per_batch: num of different ids per batch 40 | imgs_per_id: num of images per id 41 | Return: 42 | cluster_loss 43 | """ 44 | if self.use_gpu: 45 | if ordered: 46 | if targets.size(0) == ids_per_batch * imgs_per_id: 47 | unique_labels = targets[0:targets.size(0):imgs_per_id] 48 | else: 49 | unique_labels = targets.cpu().unique().cuda() 50 | else: 51 | unique_labels = targets.cpu().unique().cuda() 52 | else: 53 | if ordered: 54 | if targets.size(0) == ids_per_batch * imgs_per_id: 55 | unique_labels = targets[0:targets.size(0):imgs_per_id] 56 | else: 57 | unique_labels = targets.unique() 58 | else: 59 | unique_labels = targets.unique() 60 | 61 | inter_min_distance = torch.zeros(unique_labels.size(0)) 62 | intra_max_distance = torch.zeros(unique_labels.size(0)) 63 | center_features = torch.zeros(unique_labels.size(0), features.size(1)) 64 | 65 | if self.use_gpu: 66 | inter_min_distance = inter_min_distance.cuda() 67 | intra_max_distance = intra_max_distance.cuda() 68 | center_features = center_features.cuda() 69 | 70 | index = torch.range(0, unique_labels.size(0) - 1) 71 | for i in range(unique_labels.size(0)): 72 | label = unique_labels[i] 73 | same_class_features = features[targets == label] 74 | center_features[i] = same_class_features.mean(dim=0) 75 | intra_class_distance = self._euclidean_dist(center_features[index==i], same_class_features) 76 | # print('intra_class_distance', 
intra_class_distance) 77 | intra_max_distance[i] = intra_class_distance.max() 78 | # print('intra_max_distance:', intra_max_distance) 79 | 80 | for i in range(unique_labels.size(0)): 81 | inter_class_distance = self._euclidean_dist(center_features[index==i], center_features[index != i]) 82 | # print('inter_class_distance', inter_class_distance) 83 | inter_min_distance[i] = inter_class_distance.min() 84 | # print('inter_min_distance:', inter_min_distance) 85 | cluster_loss = torch.mean(torch.relu(intra_max_distance - inter_min_distance + self.margin)) 86 | return cluster_loss, intra_max_distance, inter_min_distance 87 | 88 | def forward(self, features, targets): 89 | """ 90 | Args: 91 | features: prediction matrix (before softmax) with shape (batch_size, feature_dim) 92 | targets: ground truth labels with shape (batch_size) 93 | ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id. 94 | ids_per_batch: num of different ids per batch 95 | imgs_per_id: num of images per id 96 | Return: 97 | cluster_loss 98 | """ 99 | assert features.size(0) == targets.size(0), "features.size(0) is not equal to targets.size(0)" 100 | cluster_loss, cluster_dist_ap, cluster_dist_an = self._cluster_loss(features, targets, self.ordered, self.ids_per_batch, self.imgs_per_id) 101 | return cluster_loss, cluster_dist_ap, cluster_dist_an 102 | 103 | 104 | class ClusterLoss_local(nn.Module): 105 | def __init__(self, margin=10, use_gpu=True, ordered=True, ids_per_batch=32, imgs_per_id=4): 106 | super(ClusterLoss_local, self).__init__() 107 | self.use_gpu = use_gpu 108 | self.margin = margin 109 | self.ordered = ordered 110 | self.ids_per_batch = ids_per_batch 111 | self.imgs_per_id = imgs_per_id 112 | 113 | def _euclidean_dist(self, x, y): 114 | """ 115 | Args: 116 | x: pytorch Variable, with shape [m, d] 117 | y: pytorch Variable, with shape [n, d] 118 | Returns: 119 | dist: pytorch Variable, with shape [m, n] 120 
| """ 121 | m, n = x.size(0), y.size(0) 122 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 123 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 124 | dist = xx + yy 125 | dist.addmm_(1, -2, x, y.t()) 126 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 127 | return dist 128 | 129 | def _shortest_dist(self, dist_mat): 130 | """Parallel version. 131 | Args: 132 | dist_mat: pytorch Variable, available shape: 133 | 1) [m, n] 134 | 2) [m, n, N], N is batch size 135 | 3) [m, n, *], * can be arbitrary additional dimensions 136 | Returns: 137 | dist: three cases corresponding to `dist_mat`: 138 | 1) scalar 139 | 2) pytorch Variable, with shape [N] 140 | 3) pytorch Variable, with shape [*] 141 | """ 142 | m, n = dist_mat.size()[:2] 143 | # Just offering some reference for accessing intermediate distance. 144 | dist = [[0 for _ in range(n)] for _ in range(m)] 145 | for i in range(m): 146 | for j in range(n): 147 | if (i == 0) and (j == 0): 148 | dist[i][j] = dist_mat[i, j] 149 | elif (i == 0) and (j > 0): 150 | dist[i][j] = dist[i][j - 1] + dist_mat[i, j] 151 | elif (i > 0) and (j == 0): 152 | dist[i][j] = dist[i - 1][j] + dist_mat[i, j] 153 | else: 154 | dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j] 155 | dist = dist[-1][-1] 156 | return dist 157 | 158 | def _local_dist(self, x, y): 159 | """ 160 | Args: 161 | x: pytorch Variable, with shape [M, m, d] 162 | y: pytorch Variable, with shape [N, n, d] 163 | Returns: 164 | dist: pytorch Variable, with shape [M, N] 165 | """ 166 | M, m, d = x.size() 167 | N, n, d = y.size() 168 | x = x.contiguous().view(M * m, d) 169 | y = y.contiguous().view(N * n, d) 170 | # shape [M * m, N * n] 171 | dist_mat = self._euclidean_dist(x, y) 172 | dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.) 
173 | # shape [M * m, N * n] -> [M, m, N, n] -> [m, n, M, N] 174 | dist_mat = dist_mat.contiguous().view(M, m, N, n).permute(1, 3, 0, 2) 175 | # shape [M, N] 176 | dist_mat = self._shortest_dist(dist_mat) 177 | return dist_mat 178 | 179 | def _cluster_loss(self, features, targets,ordered=True, ids_per_batch=32, imgs_per_id=4): 180 | """ 181 | Args: 182 | features: prediction matrix (before softmax) with shape (batch_size, H, feature_dim) 183 | targets: ground truth labels with shape (batch_size) 184 | ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id. 185 | ids_per_batch: num of different ids per batch 186 | imgs_per_id: num of images per id 187 | Return: 188 | cluster_loss 189 | """ 190 | if self.use_gpu: 191 | if ordered: 192 | if targets.size(0) == ids_per_batch * imgs_per_id: 193 | unique_labels = targets[0:targets.size(0):imgs_per_id] 194 | else: 195 | unique_labels = targets.cpu().unique().cuda() 196 | else: 197 | unique_labels = targets.cpu().unique().cuda() 198 | else: 199 | if ordered: 200 | if targets.size(0) == ids_per_batch * imgs_per_id: 201 | unique_labels = targets[0:targets.size(0):imgs_per_id] 202 | else: 203 | unique_labels = targets.unique() 204 | else: 205 | unique_labels = targets.unique() 206 | 207 | inter_min_distance = torch.zeros(unique_labels.size(0)) 208 | intra_max_distance = torch.zeros(unique_labels.size(0)) 209 | center_features = torch.zeros(unique_labels.size(0), features.size(1), features.size(2)) 210 | 211 | if self.use_gpu: 212 | inter_min_distance = inter_min_distance.cuda() 213 | intra_max_distance = intra_max_distance.cuda() 214 | center_features = center_features.cuda() 215 | 216 | index = torch.range(0, unique_labels.size(0) - 1) 217 | for i in range(unique_labels.size(0)): 218 | label = unique_labels[i] 219 | same_class_features = features[targets == label] 220 | center_features[i] = same_class_features.mean(dim=0) 221 | 
intra_class_distance = self._local_dist(center_features[index==i], same_class_features) 222 | # print('intra_class_distance', intra_class_distance) 223 | intra_max_distance[i] = intra_class_distance.max() 224 | # print('intra_max_distance:', intra_max_distance) 225 | 226 | for i in range(unique_labels.size(0)): 227 | inter_class_distance = self._local_dist(center_features[index==i], center_features[index != i]) 228 | # print('inter_class_distance', inter_class_distance) 229 | inter_min_distance[i] = inter_class_distance.min() 230 | # print('inter_min_distance:', inter_min_distance) 231 | 232 | cluster_loss = torch.mean(torch.relu(intra_max_distance - inter_min_distance + self.margin)) 233 | return cluster_loss, intra_max_distance, inter_min_distance 234 | 235 | def forward(self, features, targets): 236 | """ 237 | Args: 238 | features: prediction matrix (before softmax) with shape (batch_size, H, feature_dim) 239 | targets: ground truth labels with shape (batch_size) 240 | ordered: bool type. If the train data per batch are formed as p*k, where p is the num of ids per batch and k is the num of images per id. 
241 | ids_per_batch: num of different ids per batch 242 | imgs_per_id: num of images per id 243 | Return: 244 | cluster_loss 245 | """ 246 | assert features.size(0) == targets.size(0), "features.size(0) is not equal to targets.size(0)" 247 | cluster_loss, cluster_dist_ap, cluster_dist_an = self._cluster_loss(features, targets, self.ordered, self.ids_per_batch, self.imgs_per_id) 248 | return cluster_loss, cluster_dist_ap, cluster_dist_an 249 | 250 | 251 | if __name__ == '__main__': 252 | use_gpu = True 253 | cluster_loss = ClusterLoss(use_gpu=use_gpu, ids_per_batch=4, imgs_per_id=4) 254 | features = torch.rand(16, 2048) 255 | targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]) 256 | if use_gpu: 257 | features = torch.rand(16, 2048).cuda() 258 | targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).cuda() 259 | loss = cluster_loss(features, targets) 260 | print(loss) 261 | 262 | cluster_loss_local = ClusterLoss_local(use_gpu=use_gpu, ids_per_batch=4, imgs_per_id=4) 263 | features = torch.rand(16, 8, 2048) 264 | targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]) 265 | if use_gpu: 266 | features = torch.rand(16, 8, 2048).cuda() 267 | targets = torch.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]).cuda() 268 | loss = cluster_loss_local(features, targets) 269 | print(loss) 270 | -------------------------------------------------------------------------------- /modeling/backbones/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as cp 6 | from collections import OrderedDict 7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 8 | 9 | 10 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 11 | 12 | model_urls = { 13 | 'densenet121': 
'https://download.pytorch.org/models/densenet121-a639ec97.pth', 14 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 15 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 16 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 17 | } 18 | 19 | 20 | def _bn_function_factory(norm, relu, conv): 21 | def bn_function(*inputs): 22 | concated_features = torch.cat(inputs, 1) 23 | bottleneck_output = conv(relu(norm(concated_features))) 24 | return bottleneck_output 25 | 26 | return bn_function 27 | 28 | 29 | class _DenseLayer(nn.Sequential): 30 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False): 31 | super(_DenseLayer, self).__init__() 32 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 33 | self.add_module('relu1', nn.ReLU(inplace=True)), 34 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 35 | growth_rate, kernel_size=1, stride=1, 36 | bias=False)), 37 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 38 | self.add_module('relu2', nn.ReLU(inplace=True)), 39 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 40 | kernel_size=3, stride=1, padding=1, 41 | bias=False)), 42 | self.drop_rate = drop_rate 43 | self.memory_efficient = memory_efficient 44 | 45 | def forward(self, *prev_features): 46 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 47 | if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 48 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 49 | else: 50 | bottleneck_output = bn_function(*prev_features) 51 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 52 | if self.drop_rate > 0: 53 | new_features = F.dropout(new_features, p=self.drop_rate, 54 | training=self.training) 55 | return new_features 56 | 57 | 58 | class _DenseBlock(nn.Module): 59 | def 
__init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False): 60 | super(_DenseBlock, self).__init__() 61 | for i in range(num_layers): 62 | layer = _DenseLayer( 63 | num_input_features + i * growth_rate, 64 | growth_rate=growth_rate, 65 | bn_size=bn_size, 66 | drop_rate=drop_rate, 67 | memory_efficient=memory_efficient, 68 | ) 69 | self.add_module('denselayer%d' % (i + 1), layer) 70 | 71 | def forward(self, init_features): 72 | features = [init_features] 73 | for name, layer in self.named_children(): 74 | new_features = layer(*features) 75 | features.append(new_features) 76 | return torch.cat(features, 1) 77 | 78 | 79 | class _Transition(nn.Sequential): 80 | def __init__(self, num_input_features, num_output_features): 81 | super(_Transition, self).__init__() 82 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 83 | self.add_module('relu', nn.ReLU(inplace=True)) 84 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 85 | kernel_size=1, stride=1, bias=False)) 86 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 87 | 88 | 89 | class DenseNet(nn.Module): 90 | r"""Densenet-BC model class, based on 91 | `"Densely Connected Convolutional Networks" `_ 92 | Args: 93 | growth_rate (int) - how many filters to add each layer (`k` in paper) 94 | block_config (list of 4 ints) - how many layers in each pooling block 95 | num_init_features (int) - the number of filters to learn in the first convolution layer 96 | bn_size (int) - multiplicative factor for number of bottle neck layers 97 | (i.e. bn_size * k features in the bottleneck layer) 98 | drop_rate (float) - dropout rate after each dense layer 99 | num_classes (int) - number of classification classes 100 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 101 | but slower. Default: *False*. 
See `"paper" `_ 102 | """ 103 | 104 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 105 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False): 106 | 107 | super(DenseNet, self).__init__() 108 | 109 | # First convolution 110 | self.features = nn.Sequential(OrderedDict([ 111 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, 112 | padding=3, bias=False)), 113 | ('norm0', nn.BatchNorm2d(num_init_features)), 114 | ('relu0', nn.ReLU(inplace=True)), 115 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 116 | ])) 117 | 118 | # Each denseblock 119 | num_features = num_init_features 120 | for i, num_layers in enumerate(block_config): 121 | block = _DenseBlock( 122 | num_layers=num_layers, 123 | num_input_features=num_features, 124 | bn_size=bn_size, 125 | growth_rate=growth_rate, 126 | drop_rate=drop_rate, 127 | memory_efficient=memory_efficient 128 | ) 129 | self.features.add_module('denseblock%d' % (i + 1), block) 130 | num_features = num_features + num_layers * growth_rate 131 | if i != len(block_config) - 1: 132 | trans = _Transition(num_input_features=num_features, 133 | num_output_features=num_features // 2) 134 | self.features.add_module('transition%d' % (i + 1), trans) 135 | num_features = num_features // 2 136 | 137 | # Final batch norm 138 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 139 | 140 | # Linear layer 141 | self.classifier = nn.Linear(num_features, num_classes) 142 | 143 | # Official init from torch repo. 
144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | nn.init.kaiming_normal_(m.weight) 147 | elif isinstance(m, nn.BatchNorm2d): 148 | nn.init.constant_(m.weight, 1) 149 | nn.init.constant_(m.bias, 0) 150 | elif isinstance(m, nn.Linear): 151 | nn.init.constant_(m.bias, 0) 152 | 153 | def forward(self, x): 154 | features = self.features(x) 155 | # out = F.relu(features, inplace=True) 156 | # out = F.adaptive_avg_pool2d(out, (1, 1)) 157 | # out = torch.flatten(out, 1) 158 | # out = self.classifier(out) 159 | return features 160 | 161 | 162 | def _load_state_dict(model, model_url, progress): 163 | # '.'s are no longer allowed in module names, but previous _DenseLayer 164 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 165 | # They are also in the checkpoints in model_urls. This pattern is used 166 | # to find such keys. 167 | pattern = re.compile( 168 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 169 | 170 | state_dict = load_state_dict_from_url(model_url, progress=progress) 171 | for key in list(state_dict.keys()): 172 | res = pattern.match(key) 173 | if res: 174 | new_key = res.group(1) + res.group(2) 175 | state_dict[new_key] = state_dict[key] 176 | del state_dict[key] 177 | model.load_state_dict(state_dict) 178 | 179 | 180 | def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, 181 | **kwargs): 182 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 183 | if pretrained: 184 | _load_state_dict(model, model_urls[arch], progress) 185 | return model 186 | 187 | 188 | def densenet121(pretrained=False, progress=True, **kwargs): 189 | r"""Densenet-121 model from 190 | `"Densely Connected Convolutional Networks" `_ 191 | Args: 192 | pretrained (bool): If True, returns a model pre-trained on ImageNet 193 | progress (bool): If True, displays a progress bar of the download to stderr 194 | memory_efficient (bool) - If True, 
uses checkpointing. Much more memory efficient, 195 | but slower. Default: *False*. See `"paper" `_ 196 | """ 197 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, 198 | **kwargs) 199 | 200 | 201 | def densenet161(pretrained=False, progress=True, **kwargs): 202 | r"""Densenet-161 model from 203 | `"Densely Connected Convolutional Networks" `_ 204 | Args: 205 | pretrained (bool): If True, returns a model pre-trained on ImageNet 206 | progress (bool): If True, displays a progress bar of the download to stderr 207 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 208 | but slower. Default: *False*. See `"paper" `_ 209 | """ 210 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, 211 | **kwargs) 212 | 213 | 214 | def densenet169(pretrained=False, progress=True, **kwargs): 215 | r"""Densenet-169 model from 216 | `"Densely Connected Convolutional Networks" `_ 217 | Args: 218 | pretrained (bool): If True, returns a model pre-trained on ImageNet 219 | progress (bool): If True, displays a progress bar of the download to stderr 220 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 221 | but slower. Default: *False*. See `"paper" `_ 222 | """ 223 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, 224 | **kwargs) 225 | 226 | 227 | def densenet201(pretrained=False, progress=True, **kwargs): 228 | r"""Densenet-201 model from 229 | `"Densely Connected Convolutional Networks" `_ 230 | Args: 231 | pretrained (bool): If True, returns a model pre-trained on ImageNet 232 | progress (bool): If True, displays a progress bar of the download to stderr 233 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 234 | but slower. Default: *False*. 
See `"paper" `_ 235 | """ 236 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, 237 | **kwargs) 238 | 239 | def load_param(self, model_path): 240 | pattern = re.compile( 241 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 242 | state_dict = torch.load(model_path) 243 | # state_dict = load_state_dict_from_url(model_url, progress=progress) 244 | for key in list(state_dict.keys()): 245 | res = pattern.match(key) 246 | if res: 247 | new_key = res.group(1) + res.group(2) 248 | state_dict[new_key] = state_dict[key] 249 | del state_dict[key] 250 | #model.load_state_dict(state_dict) 251 | #for i in param_dict: 252 | # if 'fc' in i: 253 | # continue 254 | # self.state_dict()[i].copy_(param_dict[i]) 255 | 256 | def random_init(self): 257 | for m in self.modules(): 258 | if isinstance(m, nn.Conv2d): 259 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 260 | m.weight.data.normal_(0, math.sqrt(2. / n)) 261 | elif isinstance(m, nn.BatchNorm2d): 262 | m.weight.data.fill_(1) 263 | m.bias.data.zero_() --------------------------------------------------------------------------------