├── README.md ├── configs └── aicity20.yml ├── images ├── change_background.png ├── framework.png ├── illustrated.png ├── image_translation.png ├── paper.pdf ├── results.png └── veri.png ├── lib ├── config │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── defaults.cpython-37.pyc │ └── defaults.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── build.cpython-37.pyc │ │ └── collate_batch.cpython-37.pyc │ ├── build.py │ ├── collate_batch.py │ ├── datasets │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── aicity20.cpython-37.pyc │ │ │ ├── aicity20_ReCam.cpython-37.pyc │ │ │ ├── aicity20_ReColor.cpython-37.pyc │ │ │ ├── aicity20_ReOri.cpython-37.pyc │ │ │ ├── aicity20_ReType.cpython-37.pyc │ │ │ ├── aicity20_sim.cpython-37.pyc │ │ │ ├── aicity20_split.cpython-37.pyc │ │ │ ├── aicity20_trainval.cpython-37.pyc │ │ │ ├── bases.cpython-37.pyc │ │ │ ├── cuhk03.cpython-37.pyc │ │ │ ├── dataset_loader.cpython-37.pyc │ │ │ ├── dukemtmcreid.cpython-37.pyc │ │ │ ├── market1501.cpython-37.pyc │ │ │ ├── msmt17.cpython-37.pyc │ │ │ └── veri.cpython-37.pyc │ │ ├── aicity20.py │ │ ├── aicity20_ReCam.py │ │ ├── aicity20_ReColor.py │ │ ├── aicity20_ReOri.py │ │ ├── aicity20_ReType.py │ │ ├── aicity20_sim.py │ │ ├── aicity20_split.py │ │ ├── aicity20_trainval.py │ │ ├── bases.py │ │ ├── cuhk03.py │ │ ├── dataset_loader.py │ │ ├── dukemtmcreid.py │ │ ├── market1501.py │ │ ├── msmt17.py │ │ ├── track3_h5.py │ │ └── veri.py │ ├── samplers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── triplet_sampler.cpython-37.pyc │ │ └── triplet_sampler.py │ └── transforms │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── augmix.cpython-37.pyc │ │ ├── build.cpython-37.pyc │ │ └── transforms.cpython-37.pyc │ │ ├── augmix.py │ │ ├── build.py │ │ ├── fmix.py │ │ ├── transforms.py │ │ └── vis_transform.py ├── engine │ ├── __pycache__ │ │ └── 
train_net.cpython-37.pyc │ ├── inference.py │ └── train_net.py ├── layers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── build.cpython-37.pyc │ │ ├── metric_learning.cpython-37.pyc │ │ ├── pooling.cpython-37.pyc │ │ └── triplet_loss.cpython-37.pyc │ ├── build.py │ ├── metric_learning.py │ ├── pooling.py │ └── triplet_loss.py ├── modeling │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── baseline.cpython-37.pyc │ ├── backbones │ │ ├── STNModule.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── STNModule.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── mixstyle.cpython-37.pyc │ │ │ ├── resnest.cpython-37.pyc │ │ │ ├── resnet.cpython-37.pyc │ │ │ ├── resnet_ibn_a.cpython-37.pyc │ │ │ ├── resnext_ibn_a.cpython-37.pyc │ │ │ ├── resnext_ibn_a_2_head.cpython-37.pyc │ │ │ └── resnext_ibn_a_attention.cpython-37.pyc │ │ ├── densenet.py │ │ ├── mixstyle.py │ │ ├── nfnet.py │ │ ├── osnet.py │ │ ├── osnet_ain.py │ │ ├── regnet │ │ │ ├── RegNetY-1.6GF_dds_8gpu.yaml │ │ │ ├── RegNetY-3.2GF_dds_8gpu.yaml │ │ │ ├── RegNetY-800MF_dds_8gpu.yaml │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ ├── config.cpython-37.pyc │ │ │ │ └── regnet.cpython-37.pyc │ │ │ ├── config.py │ │ │ └── regnet.py │ │ ├── res2net.py │ │ ├── resnest.py │ │ ├── resnet.py │ │ ├── resnet_ibn_a.py │ │ ├── resnet_ibn_b.py │ │ ├── resnext_ibn_a.py │ │ ├── resnext_ibn_a_2_head.py │ │ └── resnext_ibn_a_attention.py │ ├── baseline.py │ └── multiheads_baseline.py ├── solver │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── build.cpython-37.pyc │ │ ├── lr_scheduler.cpython-37.pyc │ │ ├── ranger.cpython-37.pyc │ │ └── swa.cpython-37.pyc │ ├── build.py │ ├── lr_scheduler.py │ ├── ranger.py │ └── swa.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── iotools.cpython-37.pyc │ ├── logger.cpython-37.pyc │ ├── post_process.cpython-37.pyc │ └── 
reid_eval.cpython-37.pyc │ ├── actmap.py │ ├── bbox_utils.py │ ├── iotools.py │ ├── logger.py │ ├── post_process.py │ ├── reid_eval.py │ └── vis.py ├── scripts ├── ReCamID.sh ├── ReOriID.sh ├── submit.sh ├── test.sh └── train.sh ├── setup.py └── tools ├── aicity20 ├── compute_distmat_from_feats.py ├── eval_by_distmat.py ├── fix_track.py ├── multi_model_ensemble.py ├── submit.py ├── vis_result.py └── weakly_supervised_crop_aug.py ├── gen_vis.py ├── test.py ├── train.py └── vis_actmap.py /README.md: -------------------------------------------------------------------------------- 1 | # A STRONG BASELINE FOR VEHICLE RE-IDENTIFICATION 2 | **This paper has been accepted to the IEEE Conference on Computer Vision and Pattern Recognition Workshops (CVPRW) 2021** 3 | 4 | ![](./images/framework.png) 5 | ![](images/illustrated.png) 6 | 7 | This repo is the official implementation of the paper [**A Strong Baseline For Vehicle Re-Identification**](./images/paper.pdf) for [Track 2, 2021 AI CITY CHALLENGE](https://www.aicitychallenge.org/). 8 | 9 | 10 | ## I. INTRODUCTION 11 | Our proposed method sheds light on the three main factors that contribute most to the performance: 12 | + Minimizing the gap between real and synthetic data 13 | + Network modification by stacking multiple heads with an attention mechanism onto the backbone 14 | + Adaptive loss weight adjustment. 15 | 16 | Our method achieves 61.34% mAP on the private CityFlow test set without using external datasets or pseudo-labeling, and outperforms all previous works with 87.1% mAP on the [VeRi](https://vehiclereid.github.io/VeRi/) benchmark. 17 | 18 | ## II. INSTALLATION 19 | 1. pytorch>=1.2.0 20 | 2. yacs 21 | 3.
[apex](https://github.com/NVIDIA/apex) (optional, for FP16 training; if you don't have apex installed, please turn off FP16 training by setting SOLVER.FP16=False) 22 | ```` 23 | $ git clone https://github.com/NVIDIA/apex 24 | $ cd apex 25 | $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 26 | ```` 27 | 4. python>=3.7 28 | 5. cv2 29 | ## III. REPRODUCE THE RESULT ON AICITY 2020 CHALLENGE 30 | Download the ImageNet-pretrained checkpoints [resnext101_ibn](http://118.69.233.170:60001/open/AICity/Imagenet_pretrained/resnext101_ibn_a.pth.tar), [resnet50_ibn](http://118.69.233.170:60001/open/AICity/Imagenet_pretrained/resnet50_ibn_a.pth.tar), [resnet152](http://118.69.233.170:60001/open/AICity/Imagenet_pretrained/resnet152-b121ed2d.pth) 31 | 32 | ### 1. Train 33 | 34 | + **Prepare training data** 35 | - Convert the original synthetic images into more realistic ones using the [UNIT](https://github.com/mingyuliutw/UNIT) repository 36 | ![](images/image_translation.png) 37 | 38 | - Use Mask R-CNN (pre-trained on COCO) to extract the foreground (car) and background, then swap the foregrounds and backgrounds between training images. 39 | ![](images/change_background.png) 40 | 41 | 42 | + **Vehicle ReID** 43 | Train multiple models using 3 different backbones: ResNext101_ibn, Resnet50_ibn, Resnet152 44 | ```bash 45 | ./scripts/train.sh 46 | ``` 47 | 48 | + **Orientation ReID** 49 | ```bash 50 | ./scripts/ReOriID.sh 51 | ``` 52 | 53 | + **Camera ReID** 54 | ```bash 55 | ./scripts/ReCamID.sh 56 | ``` 57 | 58 | ### 2. Test 59 | ```bash 60 | ./scripts/test.sh 61 | ``` 62 | 63 | 64 | ## IV. PERFORMANCE 65 | 66 | ### 1.
Comparison with state-of-the art methods on VeRi776 67 | 68 | ![](images/veri.png) 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /configs/aicity20.yml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | NUM_INSTANCE: 8 3 | NUM_WORKERS: 4 4 | SAMPLER: 'softmax_triplet' 5 | DATASETS: 6 | COMBINEALL: False 7 | ROOT_DIR: '/media/data/ai-city/Track2' 8 | TEST: ('aicity20_split',) 9 | TRAIN: ('aicity20_split',) 10 | INPUT: 11 | AUGMIX_PROB: 0.25 12 | COLORJIT_PROB: 0.25 13 | COLOR_SPACE: 'rgb' 14 | CUTOFF_LONGTAILED: False 15 | LONGTAILED_THR: 2 16 | PADDING: 10 17 | PIXEL_MEAN: [0.485, 0.456, 0.406] 18 | PIXEL_STD: [0.229, 0.224, 0.225] 19 | PROB: 0.5 20 | RANDOM_AFFINE_PROB: 0.25 21 | RANDOM_BLUR_PROB: 0.25 22 | RANDOM_PATCH_PROB: 0.25 23 | RE_PROB: 0.5 24 | RE_SH: 0.4 25 | SIZE_TEST: [320, 320] 26 | SIZE_TRAIN: [320, 320] 27 | VERTICAL_FLIP_PROB: 0.0 28 | MODEL: 29 | DEVICE: 'cuda' 30 | DEVICE_ID: "'0'" 31 | DROPOUT_PROB: 0.0 32 | EMBEDDING_DIM: 512 33 | EMBEDDING_HEAD: 'fc' 34 | FC_WEIGHT_NORM: False 35 | FROZEN_FEATURE_EPOCH: 0 36 | GLOBAL_DIM: 2048 37 | ID_LOSS_TYPE: 'circle' 38 | ID_LOSS_WEIGHT: 1.0 39 | IF_LABELSMOOTH: 'on' 40 | LAST_STRIDE: 1 41 | LOCAL_DIM: 512 42 | METRIC_LOSS_TYPE: 'triplet' 43 | MODEL_TYPE: 'baseline_multiheads' 44 | NAME: 'resnext101_ibn_a' 45 | NECK: 'bnneck' 46 | POOLING_METHOD: 'GeM' 47 | PRETRAIN_CHOICE: 'imagenet' 48 | PRETRAIN_PATH: 'pretrained_ckpts/resnext101_ibn_a.pth.tar' 49 | TRIPLET_LOSS_WEIGHT: 1.0 50 | OUTPUT_DIR: 'output/' 51 | SOLVER: 52 | BASE_LR: 0.00035 53 | BIAS_LR_FACTOR: 1 54 | CENTER_LOSS_WEIGHT: 0.0005 55 | CENTER_LR: 0.5 56 | CHECKPOINT_PERIOD: 50 57 | CLUSTER_MARGIN: 0.3 58 | COSINE_MARGIN: 0.35 59 | COSINE_SCALE: 64 60 | CYCLE_EPOCH: 30 61 | EVAL_PERIOD: 1 62 | FC_LR_FACTOR: 1 63 | FP16: True 64 | FREEZE_BASE_EPOCHS: 0 65 | GAMMA: 0.1 66 | HARD_EXAMPLE_MINING_METHOD: batch_hard 67 | IMS_PER_BATCH: 32 68 | 
LOG_PERIOD: 50 69 | LR_SCHEDULER: 'cosine_step' 70 | MARGIN: 0.3 71 | MAX_EPOCHS: 12 72 | MOMENTUM: 0.9 73 | NO_BIAS_DECAY: False 74 | OPTIMIZER_NAME: 'Adam' 75 | RANGE_ALPHA: 0 76 | RANGE_BETA: 1 77 | RANGE_K: 2 78 | RANGE_LOSS_WEIGHT: 1 79 | RANGE_MARGIN: 0.3 80 | STEPS: (12, 20) 81 | WARMUP_FACTOR: 0.01 82 | WARMUP_ITERS: 500 83 | WARMUP_METHOD: 'linear' 84 | WEIGHT_DECAY: 1e-06 85 | WEIGHT_DECAY_BIAS: 1e-06 86 | XBM_SIZE: 4 87 | TEST: 88 | ATTRIBUTES_RERANK: False 89 | CAM_DIST_PATH: 'dists_test/cam_feat_distmat.npy' 90 | DO_DBA: False 91 | DO_RERANK: True 92 | FEAT_NORM: 'yes' 93 | FLIP_TEST: False 94 | IMS_PER_BATCH: 32 95 | NECK_FEAT: 'after' 96 | ORI_DIST_PATH: 'dists_test/ori_feat_distmat.npy' 97 | QUERY_EXPANSION: False 98 | RERANK_PARAM: [50, 15, 0.5] 99 | TRACK_AUG: False 100 | TRACK_RERANK: False 101 | USE_VOC: False 102 | WEIGHT: 'output/best.pth' 103 | WRITE_RESULT: True 104 | -------------------------------------------------------------------------------- /images/change_background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/images/change_background.png -------------------------------------------------------------------------------- /images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/images/framework.png -------------------------------------------------------------------------------- /images/illustrated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/images/illustrated.png -------------------------------------------------------------------------------- /images/image_translation.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/images/image_translation.png -------------------------------------------------------------------------------- /images/paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/images/paper.pdf -------------------------------------------------------------------------------- /images/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/images/results.png -------------------------------------------------------------------------------- /images/veri.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/images/veri.png -------------------------------------------------------------------------------- /lib/config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | -------------------------------------------------------------------------------- /lib/config/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/config/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/defaults.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/config/__pycache__/defaults.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import make_data_loader 8 | -------------------------------------------------------------------------------- /lib/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/__pycache__/collate_batch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/__pycache__/collate_batch.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from torch.utils.data import 
DataLoader 8 | 9 | from .collate_batch import train_collate_fn, val_collate_fn 10 | from .datasets import init_dataset, ImageDataset, BaseImageDataset, apply_id_bias 11 | from .samplers import RandomIdentitySampler, MPerClassSampler, RandomIdentityCrossDomainSampler # New add by gu 12 | from .transforms import build_transforms 13 | 14 | 15 | def make_data_loader(cfg, shuffle_train=True): 16 | train_transforms = build_transforms(cfg, is_train=shuffle_train) 17 | val_transforms = build_transforms(cfg, is_train=False) 18 | num_workers = cfg.DATALOADER.NUM_WORKERS 19 | # import ipdb; ipdb.set_trace() 20 | 21 | dataset = BaseImageDataset() 22 | # LOAD TRAIN 23 | print(cfg.DATASETS.TRAIN) 24 | # import ipdb; ipdb.set_trace() 25 | if isinstance(cfg.DATASETS.TRAIN, str): 26 | cur_dataset = init_dataset(cfg.DATASETS.TRAIN, root=cfg.DATASETS.ROOT_DIR) 27 | dataset = cur_dataset 28 | else: 29 | for i, dataset_name in enumerate(cfg.DATASETS.TRAIN): 30 | cur_dataset = init_dataset(dataset_name, root=cfg.DATASETS.ROOT_DIR) 31 | min_id, max_id = dataset.get_id_range(dataset.train) 32 | dataset.train.extend(apply_id_bias(cur_dataset.train, id_bias=max_id + 1)) 33 | dataset.train_tracks += cur_dataset.train_tracks 34 | if cfg.DATASETS.COMBINEALL: 35 | min_id, max_id = dataset.get_id_range(dataset.train) 36 | to_merge_train = dataset.relabel(cur_dataset.query + cur_dataset.gallery) 37 | dataset.train.extend(apply_id_bias(to_merge_train, id_bias=max_id + 1)) 38 | dataset.train_tracks += cur_dataset.test_tracks 39 | dataset.train = dataset.relabel(dataset.train) # in case of inconsistent ids 40 | # dataset.train.extend(dataset.train) 41 | # dataset.train.extend(dataset.train) 42 | # dataset.train.extend(dataset.train) 43 | 44 | # cutoff long tailed data 45 | if cfg.INPUT.CUTOFF_LONGTAILED:Cybercoreess(dataset.train, 46 | NUM_INSTANCE_PER_CLS=cfg.INPUT.LONGTAILED_THR) 47 | 48 | # LOAD VALIDATE 49 | if isinstance(cfg.DATASETS.TEST, str): 50 | cur_dataset = 
init_dataset(cfg.DATASETS.TEST, root=cfg.DATASETS.ROOT_DIR) 51 | dataset.query, dataset.gallery = cur_dataset.query, cur_dataset.gallery 52 | dataset.test_tracks = cur_dataset.test_tracks 53 | dataset.query_orientation = cur_dataset.query_orientation 54 | dataset.gallery_orientation = cur_dataset.gallery_orientation 55 | else: 56 | dataset.query, dataset.gallery = [], [] 57 | for i, dataset_name in enumerate(cfg.DATASETS.TEST): 58 | cur_dataset = init_dataset(dataset_name, root=cfg.DATASETS.ROOT_DIR) 59 | dataset.query.extend(apply_id_bias(cur_dataset.query, id_bias=i * 10000)) 60 | dataset.gallery.extend(apply_id_bias(cur_dataset.gallery, id_bias=i * 10000)) 61 | dataset.test_tracks += cur_dataset.test_tracks 62 | dataset.query_orientation = cur_dataset.query_orientation 63 | dataset.gallery_orientation = cur_dataset.gallery_orientation 64 | dataset.print_dataset_statistics(dataset.train, dataset.query, dataset.gallery) 65 | num_train_pids, num_train_imgs, num_train_cams = dataset.get_imagedata_info(dataset.train) 66 | 67 | # import ipdb; ipdb.set_trace() 68 | num_classes = num_train_pids 69 | train_set = ImageDataset(dataset.train, train_transforms, True) 70 | # import ipdb; ipdb.set_trace() 71 | if cfg.DATALOADER.SAMPLER == 'softmax': 72 | train_loader = DataLoader( 73 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=shuffle_train, num_workers=num_workers, 74 | collate_fn=train_collate_fn 75 | ) 76 | elif cfg.DATALOADER.SAMPLER == 'm_per_class': 77 | train_loader = DataLoader( 78 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 79 | sampler=MPerClassSampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 80 | num_workers=num_workers, collate_fn=train_collate_fn 81 | ) 82 | else: 83 | train_loader = DataLoader( 84 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 85 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 86 | num_workers=num_workers, collate_fn=train_collate_fn 87 | ) 
88 | 89 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms, False) 90 | val_loader = DataLoader( 91 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 92 | collate_fn=val_collate_fn 93 | ) 94 | return train_loader, val_loader, len(dataset.query), num_classes, dataset 95 | 96 | -------------------------------------------------------------------------------- /lib/data/collate_batch.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | 9 | 10 | def train_collate_fn(batch): 11 | imgs, pids, camids, domain, img_paths = zip(*batch) 12 | pids = torch.tensor(pids, dtype=torch.int64) 13 | return torch.stack(imgs, dim=0), pids, camids, domain, img_paths 14 | 15 | 16 | def val_collate_fn(batch): 17 | imgs, pids, camids, domain, img_paths = zip(*batch) 18 | return torch.stack(imgs, dim=0), pids, camids, img_paths 19 | -------------------------------------------------------------------------------- /lib/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from .cuhk03 import CUHK03 3 | from .dukemtmcreid import DukeMTMCreID 4 | from .market1501 import Market1501 5 | from .msmt17 import MSMT17 6 | from .veri import VeRi 7 | from .aicity20 import AICity20 8 | from .aicity20_sim import AICity20Sim 9 | from .aicity20_trainval import AICity20Trainval 10 | from .aicity20_ReOri import AICity20ReOri 11 | from .aicity20_ReCam import AICity20ReCam 12 | from .aicity20_ReColor import AICity20ReColor 13 | from .aicity20_ReType import AICity20ReType 14 | from .dataset_loader import ImageDataset 15 | from .bases import BaseImageDataset, apply_id_bias 16 | from .aicity20_split import AICity20_Split 17 | 18 | __factory = { 19 | 'market1501': Market1501, 20 | 'cuhk03': CUHK03, 21 | 'dukemtmc-reid': 
DukeMTMCreID, 22 | 'msmt17': MSMT17, 23 | 'veri': VeRi, 24 | 'aicity20': AICity20, 25 | 'aicity20-sim': AICity20Sim, 26 | 'aicity20-trainval': AICity20Trainval, 27 | 'aicity20-ReOri': AICity20ReOri, 28 | 'aicity20-ReCam': AICity20ReCam, 29 | 'aicity20-ReColor': AICity20ReColor, 30 | 'aicity20-ReType': AICity20ReType, 31 | 'aicity20_split': AICity20_Split 32 | } 33 | 34 | 35 | def get_names(): 36 | return __factory.keys() 37 | 38 | 39 | def init_dataset(name, *args, **kwargs): 40 | if name not in __factory.keys(): 41 | raise KeyError("Unknown datasets: {}".format(name)) 42 | return __factory[name](*args, **kwargs) 43 | -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/aicity20.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/aicity20.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/aicity20_ReCam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/aicity20_ReCam.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/aicity20_ReColor.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/aicity20_ReColor.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/aicity20_ReOri.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/aicity20_ReOri.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/aicity20_ReType.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/aicity20_ReType.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/aicity20_sim.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/aicity20_sim.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/aicity20_split.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/aicity20_split.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/aicity20_trainval.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/aicity20_trainval.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/bases.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/bases.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/cuhk03.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/cuhk03.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/dataset_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/dataset_loader.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/dukemtmcreid.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/dukemtmcreid.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/market1501.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/market1501.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/msmt17.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/msmt17.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/__pycache__/veri.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/datasets/__pycache__/veri.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/datasets/aicity20.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import glob 4 | import re 5 | import os 6 | import os.path as osp 7 | import xml.etree.ElementTree as ET 8 | 9 | 10 | from .bases import BaseImageDataset 11 | 12 | 13 | class AICity20(BaseImageDataset): 14 | """AI City 2021 Track 2 vehicle ReID dataset located at <root>/AIC21_Track2_ReID. 15 | ---------------------------------------- 16 | subset | # ids | # images | # cameras 17 | ---------------------------------------- 18 | train | 333 | 36935 | 36 19 | query | 333 | 1052 | ? 20 | gallery | 333 | 18290 | ?
21 | ---------------------------------------- 22 | Train samples come from train_label.xml (pid = vehicleID, camid = int(cameraID[1:])) and are relabelled; query/gallery are read from name_query.txt / name_test.txt without labels, so every query/gallery entry has pid=0, camid=0. 23 | """ 24 | dataset_dir = 'AIC21_Track2_ReID' 25 | dataset_aug_dir = 'AIC20_ReID_Cropped' 26 | def __init__(self, root='', verbose=True, **kwargs): 27 | super(AICity20, self).__init__() 28 | # import ipdb; ipdb.set_trace() 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | # self.dataset_aug_dir = osp.join(root, self.dataset_aug_dir) 31 | 32 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 33 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 34 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 35 | # self.train_aug_dir = osp.join(self.dataset_aug_dir, 'image_train') 36 | 37 | self.list_train_path = osp.join(self.dataset_dir, 'name_train.txt') 38 | self.list_query_path = osp.join(self.dataset_dir, 'name_query.txt') 39 | self.list_gallery_path = osp.join(self.dataset_dir, 'name_test.txt') 40 | 41 | self.train_label_path = osp.join(self.dataset_dir, 'train_label.xml') 42 | 43 | 44 | self._check_before_run() 45 | 46 | train = self._process_dir(self.train_dir, self.list_train_path, self.train_label_path, relabel=False) 47 | query = self._process_dir(self.query_dir, self.list_query_path, None) 48 | gallery = self._process_dir(self.gallery_dir, self.list_gallery_path, None) 49 | # import ipdb; ipdb.set_trace() 50 | # train += self._process_dir(self.train_aug_dir, self.list_train_path, self.train_label_path, relabel=False) 51 | train = self.relabel(train) 52 | if verbose: 53 | print("=> AI CITY 2021 data loaded") 54 | #self.print_dataset_statistics(train, query, gallery) 55 | 56 | self.train = train 57 | self.query = query 58 | self.gallery = gallery 59 | 60 | self.train_tracks = self._read_tracks(os.path.join(self.dataset_dir, 'train_track.txt')) 61 | self.test_tracks = self._read_tracks(os.path.join(self.dataset_dir, 'test_track.txt')) 62 | 63 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 64 | self.num_query_pids,
self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 65 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 66 | 67 | 68 | def _check_before_run(self): 69 | """Check if all files are available before going deeper""" 70 | if not osp.exists(self.dataset_dir): 71 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 72 | if not osp.exists(self.train_dir): 73 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 74 | # Build (img_path, pid, camid, domain) samples: from the label_path XML (pid = vehicleID, camid = int(cameraID[1:])) when given, otherwise one pid=0 / camid=0 entry per name listed in list_path. 75 | def _process_dir(self, img_dir, list_path, label_path, relabel=False, domain='real'): 76 | dataset = [] 77 | if label_path: 78 | tree = ET.parse(label_path, parser=ET.XMLParser(encoding='utf-8')) 79 | objs = tree.find('Items') 80 | for obj in objs: 81 | image_name = obj.attrib['imageName'] 82 | img_path = osp.join(img_dir, image_name) 83 | pid = int(obj.attrib['vehicleID']) 84 | camid = int(obj.attrib['cameraID'][1:]) 85 | dataset.append((img_path, pid, camid, domain)) 86 | #dataset.append((img_path, camid, pid)) 87 | if relabel: dataset = self.relabel(dataset) 88 | else: 89 | with open(list_path, 'r') as f: 90 | lines = f.readlines() 91 | for line in lines: 92 | line = line.strip() 93 | img_path = osp.join(img_dir, line) 94 | pid = 0 95 | camid = 0 96 | dataset.append((img_path, pid, camid, domain)) 97 | return dataset 98 | # NOTE(review): the relative import of .bases means this module cannot run directly as a script, and root below already ends in AIC21_Track2_ReID, which __init__ joins again — confirm intended. 99 | if __name__ == '__main__': 100 | dataset = AICity20(root='/media/data/ai-city/Track2/AIC21_Track2_ReID') 101 | -------------------------------------------------------------------------------- /lib/data/datasets/aicity20_ReCam.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import glob 4 | import re 5 | import os 6 | import os.path as osp 7 | import xml.etree.ElementTree as ET 8 | 9 | 10 | from .bases import BaseImageDataset 11 | 12 | 13 | class AICity20ReCam(BaseImageDataset): 14 | """ 15 | Of the 333 IDs in the AI City train set, IDs 1-95 are used as the test set and 241-478 as the training set. 16 |
500 images randomly sampled from the test set are used as the query set. NOTE(review): this description does not match the code below, which labels every image by its camera ID and merges train+query+gallery into the training split — confirm. 17 | """ 18 | dataset_dir = 'AIC21_Track2_ReID/AIC21_Track2_ReID' 19 | dataset_aug_dir = 'AIC20_ReID_Cropped/' 20 | dataset_blend_dir = 'AIC20_ReID_blend/' 21 | # dataset_aug_dir / dataset_blend_dir are only referenced by the commented-out loading code in __init__. 22 | def __init__(self, root='', verbose=True, **kwargs): 23 | super(AICity20ReCam, self).__init__() 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | 26 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 27 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 28 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 29 | 30 | train_list_path = osp.join(self.dataset_dir, 'name_train.txt') 31 | query_list_path = osp.join(self.dataset_dir, 'name_query.txt') 32 | gallery_list_path = osp.join(self.dataset_dir, 'name_test.txt') 33 | 34 | self.train_label_path = osp.join(self.dataset_dir, 'train_label.xml') 35 | self.query_label_path = osp.join(self.dataset_dir, 'query_label.xml') 36 | self.gallery_label_path = osp.join(self.dataset_dir, 'test_label.xml') 37 | # NOTE(review): query_label_path / gallery_label_path are set but never read; query and gallery below are loaded with label_path=None — confirm intended. 38 | self._check_before_run() 39 | 40 | train = self._process_dir(self.train_dir, train_list_path, self.train_label_path, relabel=False) 41 | query = self._process_dir(self.query_dir, query_list_path, None) 42 | gallery = self._process_dir(self.gallery_dir, gallery_list_path, None) 43 | # train += self._process_dir(self.train_aug_dir, train_list_path, relabel=False) 44 | # train += self._process_dir(os.path.join(root, self.dataset_blend_dir, 'image_train') 45 | # , train_list_path, relabel=False) 46 | # query/gallery entries all carry pid=0, camid=0 (no labels were read), so after this merge they share pid 0 in the training split. NOTE(review): confirm that extra class is intended. 47 | train = train+query+gallery 48 | 49 | train = self.relabel(train) 50 | if verbose: 51 | print("=> aicity trainval for ReCamID loaded") 52 | # self.print_dataset_statistics(train, query, gallery) 53 | 54 | self.train = train 55 | self.query = query 56 | self.gallery = gallery 57 | 58 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 59 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 60 |
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 61 | 62 | self.train_tracks = self._read_tracks(osp.join(self.dataset_dir, 'train_track.txt')) 63 | self.test_tracks = self._read_tracks(osp.join(self.dataset_dir, 'test_track.txt')) 64 | # import ipdb; ipdb.set_trace() 65 | def _check_before_run(self): 66 | """Check if all files are available before going deeper""" 67 | if not osp.exists(self.dataset_dir): 68 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 69 | if not osp.exists(self.train_dir): 70 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 71 | if not osp.exists(self.query_dir): 72 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 73 | if not osp.exists(self.gallery_dir): 74 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 75 | # Build (img_path, pid, camid, domain) samples. With label_path, the camera ID int(cameraID[1:]) is used as BOTH pid and camid; otherwise every image listed in list_path gets pid=0, camid=0. domain is always 0. 76 | def _process_dir(self, img_dir, list_path, label_path, relabel=False): 77 | dataset = [] 78 | if label_path: 79 | tree = ET.parse(label_path, parser=ET.XMLParser(encoding='utf-8')) 80 | objs = tree.find('Items') 81 | for obj in objs: 82 | image_name = obj.attrib['imageName'] 83 | img_path = osp.join(img_dir, image_name) 84 | pid = int(obj.attrib['cameraID'][1:]) 85 | camid = int(obj.attrib['cameraID'][1:]) 86 | domain=0 87 | dataset.append((img_path, pid, camid, domain)) 88 | if relabel: dataset = self.relabel(dataset) 89 | else: 90 | with open(list_path, 'r') as f: 91 | lines = f.readlines() 92 | for line in lines: 93 | line = line.strip() 94 | img_path = osp.join(img_dir, line) 95 | pid = 0 96 | camid = 0 97 | domain=0 98 | dataset.append((img_path, pid, camid, domain)) 99 | return dataset 100 | # NOTE(review): the relative import of .bases means this module cannot run directly as a script — confirm how this entry point is meant to be used. 101 | if __name__ == '__main__': 102 | dataset = AICity20ReCam(root='/home/zxy/data/ReID/vehicle') 103 | -------------------------------------------------------------------------------- /lib/data/datasets/aicity20_ReColor.py:
# ------------------------------------------------------------------------------
# encoding: utf-8

import glob
import re
import os
import os.path as osp
import xml.etree.ElementTree as ET
import json

from .bases import BaseImageDataset
from .aicity20 import AICity20


class AICity20ReColor(AICity20):
    """Vehicle-color view of the AIC20 simulation set.

    Every labelled image becomes one sample whose ``pid`` is the XML
    ``colorID`` attribute, so the ReID training pipeline can be reused as a
    color classifier. Samples are split sequentially (in XML order) into
    train / query / gallery subsets of fixed size.
    """
    dataset_dir = 'AIC20_ReID_Simulation'

    def __init__(self, root='', verbose=True, **kwargs):
        # Deliberately skip AICity20.__init__ (which loads the real-data
        # train/query/gallery splits) and run BaseImageDataset.__init__ instead.
        super(AICity20, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.train_dir = osp.join(self.dataset_dir, 'image_train')
        self.list_train_path = osp.join(self.dataset_dir, 'name_train.txt')
        self.train_label_path = osp.join(self.dataset_dir, 'train_label.xml')
        self._check_before_run()

        train = self._process_dir(self.train_dir, self.list_train_path, self.train_label_path, relabel=False)

        # Sequential split: the first `train_num` samples are train, the next
        # `query_num` are query, the next `gallery_num` are gallery. Slices past
        # the end of the list are simply shorter (or empty).
        train_num = 180000
        query_num = 500
        gallery_num = 5000
        query = train[train_num:train_num + query_num]
        gallery = train[train_num + query_num:train_num + query_num + gallery_num]
        train = train[:train_num]

        if verbose:
            print("=> AI CITY 2020 sim data loaded")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _process_dir(self, img_dir, list_path, label_path, relabel=False, domain='syn'):
        """Build ``(img_path, pid, camid, domain)`` samples.

        Args:
            img_dir: directory that contains the images.
            list_path: text file with one image name per line; only read when
                ``label_path`` is falsy.
            label_path: XML label file. When given, ``pid`` is the
                ``colorID`` attribute and ``camid`` is ``cameraID`` with its
                first character dropped (presumably a ``'c'`` prefix).
            relabel: if True, map pids to contiguous labels via ``self.relabel``.
            domain: tag stored as the fourth field of every sample; defaults to
                ``'syn'`` because these images come from the simulation set
                (the same tag AICity20Sim uses).

        Returns:
            list of 4-field samples; unlabelled samples get ``pid = camid = 0``.
        """
        dataset = []
        if label_path:
            tree = ET.parse(label_path, parser=ET.XMLParser(encoding='utf-8'))
            for obj in tree.find('Items'):
                img_path = osp.join(img_dir, obj.attrib['imageName'])
                pid = int(obj.attrib['colorID'])
                camid = int(obj.attrib['cameraID'][1:])
                # Four fields are required: BaseDataset.get_imagedata_info and
                # BaseImageDataset.relabel unpack (path, pid, camid, domain).
                dataset.append((img_path, pid, camid, domain))
            if relabel:
                dataset = self.relabel(dataset)
        else:
            with open(list_path, 'r') as f:
                for line in f:
                    img_path = osp.join(img_dir, line.strip())
                    dataset.append((img_path, 0, 0, domain))
        return dataset


if __name__ == '__main__':
    dataset = AICity20ReColor(root='/home/zxy/data/ReID/vehicle')

# ------------------------------------------------------------------------------
# /lib/data/datasets/aicity20_ReOri.py:
# ------------------------------------------------------------------------------
# encoding: utf-8

import glob
import re
import os
import os.path as osp
import xml.etree.ElementTree as ET
import json

from .bases import BaseImageDataset
from .aicity20 import AICity20


class AICity20ReOri(AICity20):
    """Vehicle-orientation view of the AIC20 simulation set.

    Every labelled image becomes one sample whose ``pid`` is the XML
    ``orientation`` attribute bucketed by ``int(orientation / 10)``, so the
    ReID training pipeline can be reused as an orientation classifier.
    Samples are split sequentially (in XML order) into train / query /
    gallery subsets of fixed size.
    """
    dataset_dir = 'AIC20_ReID_Simulation'

    def __init__(self, root='', verbose=True, **kwargs):
        # Deliberately skip AICity20.__init__ (which loads the real-data
        # train/query/gallery splits) and run BaseImageDataset.__init__ instead.
        super(AICity20, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.train_dir = osp.join(self.dataset_dir, 'image_train')
        self.list_train_path = osp.join(self.dataset_dir, 'name_train.txt')
        self.train_label_path = osp.join(self.dataset_dir, 'train_label.xml')
        self._check_before_run()

        train = self._process_dir(self.train_dir, self.list_train_path, self.train_label_path, relabel=False)

        # Sequential split: the first `train_num` samples are train, the next
        # `query_num` are query, the next `gallery_num` are gallery. Slices past
        # the end of the list are simply shorter (or empty).
        train_num = 100000
        query_num = 500
        gallery_num = 5000
        query = train[train_num:train_num + query_num]
        gallery = train[train_num + query_num:train_num + query_num + gallery_num]
        train = train[:train_num]

        if verbose:
            print("=> AI CITY 2020 sim data loaded")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _process_dir(self, img_dir, list_path, label_path, relabel=False, domain='syn'):
        """Build ``(img_path, pid, camid, domain)`` samples.

        Args:
            img_dir: directory that contains the images.
            list_path: text file with one image name per line; only read when
                ``label_path`` is falsy.
            label_path: XML label file. When given, ``pid`` is
                ``int(float(orientation) / 10)`` (presumably degrees, giving
                10-degree buckets -- TODO confirm the attribute's range) and
                ``camid`` is ``cameraID`` with its first character dropped.
            relabel: if True, map pids to contiguous labels via ``self.relabel``.
            domain: tag stored as the fourth field of every sample; defaults to
                ``'syn'`` because these images come from the simulation set
                (the same tag AICity20Sim uses).

        Returns:
            list of 4-field samples; unlabelled samples get ``pid = camid = 0``.
        """
        dataset = []
        if label_path:
            tree = ET.parse(label_path, parser=ET.XMLParser(encoding='utf-8'))
            for obj in tree.find('Items'):
                img_path = osp.join(img_dir, obj.attrib['imageName'])
                pid = int(float(obj.attrib['orientation']) / 10)
                camid = int(obj.attrib['cameraID'][1:])
                # Four fields are required: BaseDataset.get_imagedata_info and
                # BaseImageDataset.relabel unpack (path, pid, camid, domain).
                dataset.append((img_path, pid, camid, domain))
            if relabel:
                dataset = self.relabel(dataset)
        else:
            with open(list_path, 'r') as f:
                for line in f:
                    img_path = osp.join(img_dir, line.strip())
                    dataset.append((img_path, 0, 0, domain))
        return dataset


if __name__ == '__main__':
    dataset = AICity20ReOri(root='/home/zxy/data/ReID/vehicle')

# ------------------------------------------------------------------------------
# /lib/data/datasets/aicity20_ReType.py:
# ------------------------------------------------------------------------------
# encoding: utf-8

import glob
import re
| import os 6 | import os.path as osp 7 | import xml.etree.ElementTree as ET 8 | import json 9 | 10 | from .bases import BaseImageDataset 11 | from .aicity20 import AICity20 12 | 13 | class AICity20ReType(AICity20): 14 | """ 15 | Simulation data: include attribute information 16 | - orientation 17 | - color 18 | - cls type (truck, suv) 19 | """ 20 | dataset_dir = 'AIC20_ReID_Simulation' 21 | def __init__(self, root='', verbose=True, **kwargs): 22 | super(AICity20, self).__init__() 23 | self.dataset_dir = osp.join(root, self.dataset_dir) 24 | 25 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 26 | self.list_train_path = osp.join(self.dataset_dir, 'name_train.txt') 27 | self.train_label_path = osp.join(self.dataset_dir, 'train_label.xml') 28 | self._check_before_run() 29 | 30 | train = self._process_dir(self.train_dir, self.list_train_path, self.train_label_path, relabel=False) 31 | 32 | train_num = 180000 33 | #train_num = 100000 34 | #train_num = 50000 35 | query_num = 500 36 | gallery_num = 5000 37 | query = train[train_num:train_num+query_num] 38 | gallery = train[train_num+query_num: train_num+query_num+gallery_num] 39 | train = train[:train_num] 40 | 41 | if verbose: 42 | print("=> AI CITY 2020 sim data loaded") 43 | 44 | self.train = train 45 | self.query = query 46 | self.gallery = gallery 47 | 48 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 49 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 50 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 51 | 52 | def _process_dir(self, img_dir, list_path, label_path, relabel=False): 53 | dataset = [] 54 | if label_path: 55 | tree = ET.parse(label_path, parser=ET.XMLParser(encoding='utf-8')) 56 | objs = tree.find('Items') 57 | for obj in objs: 58 | image_name = obj.attrib['imageName'] 59 | img_path = osp.join(img_dir, image_name) 60 | pid 
= int(obj.attrib['typeID']) 61 | camid = int(obj.attrib['cameraID'][1:]) 62 | dataset.append((img_path, pid, camid)) 63 | if relabel: dataset = self.relabel(dataset) 64 | else: 65 | with open(list_path, 'r') as f: 66 | lines = f.readlines() 67 | for line in lines: 68 | line = line.strip() 69 | img_path = osp.join(img_dir, line) 70 | pid = 0 71 | camid = 0 72 | dataset.append((img_path, pid, camid)) 73 | return dataset 74 | 75 | 76 | -------------------------------------------------------------------------------- /lib/data/datasets/aicity20_sim.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import glob 4 | import re 5 | import os 6 | import os.path as osp 7 | import xml.etree.ElementTree as ET 8 | 9 | from .bases import BaseImageDataset 10 | from .aicity20 import AICity20 11 | 12 | class AICity20Sim(AICity20): 13 | """ 14 | Simulation data: include attribute information 15 | - orientation 16 | - color 17 | - cls type (truck, suv) 18 | """ 19 | dataset_dir = 'AIC21_Track2_ReID_Simulation' 20 | def __init__(self, root='', verbose=True, **kwargs): 21 | super(AICity20, self).__init__() 22 | self.dataset_dir = osp.join(root, self.dataset_dir) 23 | 24 | self.train_dir = osp.join(self.dataset_dir, 'output_UNIT') 25 | self.list_train_path = osp.join(self.dataset_dir, 'name_train.txt') 26 | self.train_label_path = osp.join(self.dataset_dir, 'train_label.xml') 27 | self._check_before_run() 28 | 29 | train = self._process_dir(self.train_dir, self.list_train_path, self.train_label_path, relabel=True, domain='syn') 30 | 31 | if verbose: 32 | print("=> AI CITY 2021 sim data loaded") 33 | #self.print_dataset_statistics(train, query, gallery) 34 | 35 | self.train = train 36 | self.query = [] 37 | self.gallery = [] 38 | 39 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 40 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = 
self.get_imagedata_info(self.query) 41 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 42 | 43 | 44 | if __name__ == '__main__': 45 | dataset = AICity20Sim(root='/media/data/ai-city/Track2/') 46 | -------------------------------------------------------------------------------- /lib/data/datasets/aicity20_split.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import glob 4 | import re 5 | import os 6 | import os.path as osp 7 | import xml.etree.ElementTree as ET 8 | 9 | 10 | from .bases import BaseImageDataset 11 | 12 | 13 | class AICity20_Split(BaseImageDataset): 14 | """ 15 | ---------------------------------------- 16 | subset | # ids | # images | # cameras 17 | ---------------------------------------- 18 | train | 333 | 36935 | 36 19 | query | 333 | 1052 | ? 20 | gallery | 333 | 18290 | ? 21 | ---------------------------------------- 22 | 23 | """ 24 | dataset_dir = 'AIC21_Track2_ReID/AIC21_Track2_ReID' 25 | dataset_aug_dir = 'AIC20_ReID_Cropped' 26 | def __init__(self, root='', verbose=True, **kwargs): 27 | super(AICity20_Split, self).__init__() 28 | # import ipdb; ipdb.set_trace() 29 | self.dataset_dir = osp.join(root, self.dataset_dir) 30 | # self.dataset_aug_dir = osp.join(root, self.dataset_aug_dir) 31 | 32 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 33 | self.query_dir = osp.join(self.dataset_dir, 'image_train') 34 | self.gallery_dir = osp.join(self.dataset_dir, 'image_train') 35 | # self.train_aug_dir = osp.join(self.dataset_aug_dir, 'image_train') 36 | 37 | # self.list_train_path = osp.join(self.dataset_dir, 'name_train.txt') 38 | # self.list_query_path = osp.join(self.dataset_dir, 'name_query.txt') 39 | # self.list_gallery_path = osp.join(self.dataset_dir, 'name_test.txt') 40 | 41 | self.train_label_path = osp.join(self.dataset_dir, 'train_split.xml') 42 | self.test_label_path = osp.join(self.dataset_dir, 
'test_split.xml') 43 | self.query_label_path = osp.join(self.dataset_dir, 'query_split.xml') 44 | 45 | self._check_before_run() 46 | 47 | train = self._process_dir(self.train_dir, self.train_label_path, relabel=False) 48 | query = self._process_dir(self.query_dir, self.query_label_path, relabel=False) 49 | gallery = self._process_dir(self.gallery_dir, self.test_label_path, relabel=False) 50 | # import ipdb; ipdb.set_trace() 51 | # train += self._process_dir(self.train_aug_dir, self.list_train_path, self.train_label_path, relabel=False) 52 | train = self.relabel(train) 53 | if verbose: 54 | print("=> AI CITY SPLIT 2021 data loaded") 55 | #self.print_dataset_statistics(train, query, gallery) 56 | 57 | self.train = train 58 | self.query = query 59 | self.gallery = gallery 60 | 61 | self.train_tracks = self._read_tracks(os.path.join(self.dataset_dir, 'train_track.txt')) 62 | self.test_tracks = self._read_tracks(os.path.join(self.dataset_dir, 'test_track.txt')) 63 | 64 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 65 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 66 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 67 | 68 | 69 | def _check_before_run(self): 70 | """Check if all files are available before going deeper""" 71 | if not osp.exists(self.dataset_dir): 72 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 73 | if not osp.exists(self.train_dir): 74 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 75 | 76 | def _process_dir(self, img_dir, label_path, relabel=False, domain='real'): 77 | dataset = [] 78 | if label_path: 79 | tree = ET.parse(label_path, parser=ET.XMLParser(encoding='utf-8')) 80 | objs = tree.find('Items') 81 | for obj in objs: 82 | image_name = obj.attrib['imageName'] 83 | img_path = osp.join(img_dir, image_name) 84 | pid = 
int(obj.attrib['vehicleID']) 85 | camid = int(obj.attrib['cameraID'][1:]) 86 | dataset.append((img_path, pid, camid, domain)) 87 | #dataset.append((img_path, camid, pid)) 88 | if relabel: dataset = self.relabel(dataset) 89 | else: 90 | with open(list_path, 'r') as f: 91 | lines = f.readlines() 92 | for line in lines: 93 | line = line.strip() 94 | img_path = osp.join(img_dir, line) 95 | pid = 0 96 | camid = 0 97 | dataset.append((img_path, pid, camid, domain)) 98 | return dataset 99 | 100 | if __name__ == '__main__': 101 | dataset = AICity20_Split(root='/media/data/ai-city/Track2/') -------------------------------------------------------------------------------- /lib/data/datasets/aicity20_trainval.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import glob 4 | import re 5 | import os 6 | import os.path as osp 7 | import xml.etree.ElementTree as ET 8 | 9 | 10 | from .bases import BaseImageDataset 11 | 12 | 13 | class AICity20Trainval(BaseImageDataset): 14 | """ 15 | 将AI City train 中333个ID, 1-95为测试集, 241-478为训练集 16 | 测试集中随机取500张作为query 17 | """ 18 | dataset_dir = 'AIC20_ReID/' 19 | dataset_aug_dir = 'AIC20_ReID_Cropped/' 20 | dataset_blend_dir = 'AIC20_ReID_blend/' 21 | 22 | def __init__(self, root='', verbose=True, **kwargs): 23 | super(AICity20Trainval, self).__init__() 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.dataset_aug_dir = osp.join(root, self.dataset_aug_dir) 26 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 27 | self.query_dir = osp.join(self.dataset_aug_dir, 'image_train') 28 | self.gallery_dir = osp.join(self.dataset_aug_dir, 'image_train') 29 | self.train_aug_dir = osp.join(self.dataset_aug_dir, 'image_train') 30 | 31 | train_list_path = osp.join(self.dataset_dir, 'trainval_partial', 'train.txt') 32 | query_list_path = osp.join(self.dataset_dir, 'trainval_partial', 'query.txt') 33 | gallery_list_path = osp.join(self.dataset_dir, 'trainval_partial', 
'test.txt') 34 | #train_aug_list_path = osp.join(self.dataset_dir, 'trainval_partial', 'train.txt') 35 | 36 | self._check_before_run() 37 | 38 | train = self._process_dir(self.train_dir, train_list_path, relabel=False) 39 | query = self._process_dir(self.query_dir, query_list_path, relabel=False) 40 | gallery = self._process_dir(self.gallery_dir, gallery_list_path, relabel=False) 41 | # train += self._process_dir(self.train_aug_dir, train_list_path, relabel=False) 42 | # train += self._process_dir(os.path.join(root, self.dataset_blend_dir, 'image_train') 43 | # , train_list_path, relabel=False) 44 | 45 | 46 | train = self.relabel(train) 47 | if verbose: 48 | print("=> aicity trainval loaded") 49 | # self.print_dataset_statistics(train, query, gallery) 50 | 51 | self.train = train 52 | self.query = query 53 | self.gallery = gallery 54 | 55 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 56 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 57 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 58 | 59 | self.train_tracks = self._read_tracks(osp.join(self.dataset_dir, 'train_track.txt')) 60 | self.test_tracks = self._read_tracks(osp.join(self.dataset_dir, 'trainval_partial', 'test_track.txt')) 61 | 62 | def _check_before_run(self): 63 | """Check if all files are available before going deeper""" 64 | if not osp.exists(self.dataset_dir): 65 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 66 | if not osp.exists(self.train_dir): 67 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 68 | if not osp.exists(self.query_dir): 69 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 70 | if not osp.exists(self.gallery_dir): 71 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 72 | 73 | def _process_dir(self, dir_path, list_path, 
relabel=False): 74 | dataset = [] 75 | with open(list_path, 'r') as f: 76 | lines = f.readlines() 77 | 78 | for line in lines: 79 | line = line.strip() 80 | pid, camid, trackid, image_name = line.split('_') 81 | pid = int(pid) 82 | camid = int(camid[1:]) 83 | img_path = osp.join(dir_path, image_name) 84 | dataset.append((img_path, pid, camid)) 85 | #dataset.append((img_path, camid, pid)) 86 | if relabel: dataset = self.relabel(dataset) 87 | 88 | return dataset 89 | 90 | if __name__ == '__main__': 91 | dataset = AICity20Trainval(root='/home/zxy/data/ReID/vehicle') 92 | -------------------------------------------------------------------------------- /lib/data/datasets/bases.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | 9 | 10 | class BaseDataset(object): 11 | """ 12 | Base class of reid dataset 13 | """ 14 | 15 | def get_imagedata_info(self, data): 16 | pids, cams = [], [] 17 | for _, pid, camid,_ in data: 18 | pids += [pid] 19 | cams += [camid] 20 | pids = set(pids) 21 | cams = set(cams) 22 | num_pids = len(pids) 23 | num_cams = len(cams) 24 | num_imgs = len(data) 25 | return num_pids, num_imgs, num_cams 26 | 27 | def get_videodata_info(self, data, return_tracklet_stats=False): 28 | pids, cams, tracklet_stats = [], [], [] 29 | for img_paths, pid, camid in data: 30 | pids += [pid] 31 | cams += [camid] 32 | tracklet_stats += [len(img_paths)] 33 | pids = set(pids) 34 | cams = set(cams) 35 | num_pids = len(pids) 36 | num_cams = len(cams) 37 | num_tracklets = len(data) 38 | if return_tracklet_stats: 39 | return num_pids, num_tracklets, num_cams, tracklet_stats 40 | return num_pids, num_tracklets, num_cams 41 | 42 | def print_dataset_statistics(self): 43 | raise NotImplementedError 44 | 45 | 46 | class BaseImageDataset(BaseDataset): 47 | """ 48 | Base class of image reid dataset 49 | """ 50 | def 
__init__(self): 51 | self.train = [] 52 | self.query = [] 53 | self.gallery = [] 54 | self.train_tracks = [] # track information 55 | self.test_tracks = [] 56 | self.query_orientation = None 57 | self.gallery_orientation = None 58 | 59 | def longtail_data_process(self, data, NUM_INSTANCE_PER_CLS=2): 60 | labels = {} 61 | for img_path, pid, camid in data: 62 | if pid in labels: 63 | labels[pid].append([img_path, pid, camid]) 64 | else: 65 | labels[pid] = [[img_path, pid, camid]] 66 | 67 | # cut-off long-tail data 68 | keep_data = [] 69 | remove_data = [] 70 | for key, value in labels.items(): 71 | if len(value) < NUM_INSTANCE_PER_CLS: 72 | remove_data.extend(value) 73 | continue 74 | keep_data.extend(value) 75 | keep_data = self.relabel(keep_data) 76 | 77 | # import shutil 78 | # import os 79 | # dst_dir = './longtailed-N3' 80 | # for img_path, pid, camid in remove_data: 81 | # dst_path = os.path.join(dst_dir, str(pid).zfill(5)) 82 | # if not os.path.exists(dst_path): 83 | # os.makedirs(dst_path) 84 | # shutil.copyfile(img_path, os.path.join(dst_path, os.path.basename(img_path))) 85 | 86 | return keep_data 87 | 88 | def combine_all(self): 89 | # combine train, query, gallery 90 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(self.train) 91 | new_train = self.query + self.gallery 92 | #new_train = self.relabel(new_train) 93 | 94 | for img_path, pid, camid in new_train: 95 | self.train.append([img_path, pid + num_train_pids, camid]) 96 | self.train = self.relabel(self.train) 97 | self.query = [] 98 | self.gallery = [] 99 | 100 | def get_id_range(self, lists): 101 | pid_container = set() 102 | for img_path, pid, camid, domain in lists: 103 | pid_container.add(pid) 104 | 105 | if len(pid_container) == 0: 106 | min_id, max_id = 0, 0 107 | else: 108 | min_id, max_id = min(pid_container), max(pid_container) 109 | return min_id, max_id 110 | 111 | def relabel(self, lists): 112 | relabeled = [] 113 | pid_container = set() 114 | for img_path, pid, 
camid, domain in lists: 115 | pid_container.add(pid) 116 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 117 | for img_path, pid, camid, domain in lists: 118 | pid = pid2label[pid] 119 | relabeled.append([img_path, pid, camid, domain]) 120 | return relabeled 121 | 122 | def _read_tracks(self, path): 123 | tracks = [] 124 | with open(path, 'r') as f: 125 | lines = f.readlines() 126 | for line in lines: 127 | line = line.strip() 128 | track = line.split(' ') 129 | tracks.append(track) 130 | return tracks 131 | 132 | def print_dataset_statistics(self, train, query, gallery): 133 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 134 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 135 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 136 | 137 | print("Dataset statistics:") 138 | print(" ----------------------------------------") 139 | print(" subset | # ids | # images | # cameras") 140 | print(" ----------------------------------------") 141 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 142 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 143 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 144 | print(" ----------------------------------------") 145 | 146 | 147 | class BaseVideoDataset(BaseDataset): 148 | """ 149 | Base class of video reid dataset 150 | """ 151 | 152 | def print_dataset_statistics(self, train, query, gallery): 153 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ 154 | self.get_videodata_info(train, return_tracklet_stats=True) 155 | 156 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ 157 | self.get_videodata_info(query, return_tracklet_stats=True) 158 | 159 | num_gallery_pids, num_gallery_tracklets, 
num_gallery_cams, gallery_tracklet_stats = \ 160 | self.get_videodata_info(gallery, return_tracklet_stats=True) 161 | 162 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats 163 | min_num = np.min(tracklet_stats) 164 | max_num = np.max(tracklet_stats) 165 | avg_num = np.mean(tracklet_stats) 166 | 167 | print("Dataset statistics:") 168 | print(" -------------------------------------------") 169 | print(" subset | # ids | # tracklets | # cameras") 170 | print(" -------------------------------------------") 171 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) 172 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) 173 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) 174 | print(" -------------------------------------------") 175 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) 176 | print(" -------------------------------------------") 177 | 178 | 179 | def apply_id_bias(train, id_bias=0): 180 | # add id bias 181 | id_biased_train = [] 182 | for img_path, pid, camid, domain in train: 183 | id_biased_train.append([img_path, pid + id_bias, camid, domain]) 184 | return id_biased_train 185 | -------------------------------------------------------------------------------- /lib/data/datasets/cuhk03.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | import os 7 | import glob 8 | import re 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | 12 | from lib.utils.iotools import mkdir_if_missing, write_json, read_json 13 | from .bases import BaseImageDataset 14 | 15 | 16 | class CUHK03(BaseImageDataset): 17 | """ 18 | CUHK03 19 | Reference: 20 | Li et al. 
DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014. 21 | URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#! 22 | 23 | Dataset statistics: 24 | # identities: 1360 25 | # images: 13164 26 | # cameras: 6 27 | # splits: 20 (classic) 28 | Args: 29 | split_id (int): split index (default: 0) 30 | cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False) 31 | """ 32 | dataset_dir = 'cuhk03' 33 | 34 | def __init__(self, root='', cuhk03_labeled=False, verbose=True, 35 | **kwargs): 36 | super(CUHK03, self).__init__() 37 | self.dataset_dir = osp.join(root, self.dataset_dir) 38 | 39 | self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected') 40 | self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled') 41 | self._check_before_run() 42 | 43 | if cuhk03_labeled: 44 | image_type = 'cuhk03_labeled' 45 | else: 46 | image_type = 'cuhk03_detected' 47 | self.dataset_dir = osp.join(self.dataset_dir, image_type) 48 | 49 | train = self.process_dir(self.dataset_dir, relabel=True) 50 | query = [] 51 | gallery = [] 52 | 53 | if verbose: 54 | print("=> CUHK03 ({}) loaded".format(image_type)) 55 | # self.print_dataset_statistics(train, query, gallery) 56 | 57 | self.train = train 58 | self.query = query 59 | self.gallery = gallery 60 | 61 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 62 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 63 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 64 | 65 | def _check_before_run(self): 66 | """Check if all files are available before going deeper""" 67 | if not osp.exists(self.dataset_dir): 68 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 69 | 70 | def process_dir(self, dir_path, relabel=True): 71 | img_paths = glob.glob(osp.join(dir_path, 
'*.png')) 72 | pid_container = set() 73 | for img_path in img_paths: 74 | img_name = os.path.basename(img_path) 75 | video, pid, camid, _ = img_name.split('_') 76 | video, pid, camid = int(video), int(pid), int(camid) 77 | pid = (video-1) * 1000 + pid 78 | pid_container.add(pid) 79 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 80 | 81 | dataset = [] 82 | for img_path in img_paths: 83 | img_name = os.path.basename(img_path) 84 | video, pid, camid, _ = img_name.split('_') 85 | video, pid, camid = int(video), int(pid), int(camid) 86 | pid = (video-1) * 1000 + pid 87 | if relabel: pid = pid2label[pid] 88 | dataset.append((img_path, pid, camid)) 89 | 90 | return dataset 91 | 92 | if __name__ == '__main__': 93 | dataset = CUHK03(root='/home/zxy/data/ReID') -------------------------------------------------------------------------------- /lib/data/datasets/dataset_loader.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import os.path as osp 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | import cv2 11 | from glob import glob 12 | from random import randint 13 | import numpy as np 14 | import random 15 | from PIL import Image 16 | 17 | def read_image(img_path): 18 | """Keep reading image until succeed. 19 | This can avoid IOError incurred by heavy IO process.""" 20 | got_img = False 21 | if not osp.exists(img_path): 22 | raise IOError("{} does not exist".format(img_path)) 23 | while not got_img: 24 | try: 25 | img = Image.open(img_path).convert('RGB') 26 | # img = cv2.imread(img_path, 1) #BGR 27 | # img = Image.fromarray(img) 28 | got_img = True 29 | except IOError: 30 | print("IOError incurred when reading '{}'. Will redo. Don't worry. 
Just chill.".format(img_path)) 31 | pass 32 | return img 33 | 34 | 35 | class ImageDataset(Dataset): 36 | """Image Person ReID Dataset""" 37 | 38 | def __init__(self, dataset, transform=None, change_background=False): 39 | self.dataset = dataset 40 | self.transform = transform 41 | self.change_background = change_background 42 | if self.change_background: 43 | self.mask_list = glob('/media/data/ai-city/Track2/AIC21_Track2_ReID/AIC21_Track2_ReID/track2_segmented/mask/*.npy') 44 | 45 | self.path = '/media/data/ai-city/Track2/AIC21_Track2_ReID/AIC21_Track2_ReID/track2_segmented/' 46 | self._ori_len = len(self.dataset) 47 | self.times = 1 48 | # print("I'm ngocnt") 49 | 50 | def __len__(self): 51 | return len(self.dataset) 52 | 53 | def __getitem__(self, index): 54 | img_path, pid, camid, domain = self.dataset[index] 55 | if self.change_background==True: 56 | prob = randint(1, 10)/10 57 | img_name = img_path.split('/')[-1] 58 | mask_path = self.path+'mask/'+img_name.split('.')[0]+'.npy' 59 | if (prob>=0.5) and (mask_path in self.mask_list): 60 | # import ipdb; ipdb.set_trace() 61 | foreground=read_image(self.path+'foreground/'+img_name.split('.')[0]+'.jpg') 62 | mask = np.load(mask_path) 63 | width, height = foreground.size 64 | 65 | # select background 66 | background_path = random.choice(glob(self.path+'background_painted/*.jpg')) 67 | background = read_image(background_path) 68 | background = background.resize((width, height)) 69 | # merge 70 | merge = background * (1 - np.stack([mask, mask, mask], axis=2)) + foreground 71 | img = Image.fromarray(merge) 72 | 73 | else: 74 | img = read_image(img_path) 75 | else: 76 | img = read_image(img_path) 77 | 78 | if self.transform is not None: 79 | img = self.transform(img) 80 | return img, pid, camid, domain, img_path 81 | -------------------------------------------------------------------------------- /lib/data/datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | # 
encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import urllib 10 | import zipfile 11 | 12 | import os.path as osp 13 | import os 14 | from .bases import BaseImageDataset 15 | 16 | 17 | class DukeMTMCreID(BaseImageDataset): 18 | """ 19 | DukeMTMC-reID 20 | Reference: 21 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 22 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 23 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 24 | 25 | Dataset statistics: 26 | # identities: 1404 (train + query) 27 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 28 | # cameras: 8 29 | """ 30 | dataset_dir = 'dukemtmc-reid' 31 | 32 | def __init__(self, root='', verbose=True, **kwargs): 33 | super(DukeMTMCreID, self).__init__() 34 | self.dataset_dir = osp.join(root, self.dataset_dir) 35 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 36 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 37 | self.query_dir = osp.join(self.dataset_dir, 'query') 38 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 39 | 40 | self._download_data() 41 | self._check_before_run() 42 | 43 | train = self._process_dir(self.train_dir, relabel=True) 44 | query = self._process_dir(self.query_dir, relabel=False) 45 | gallery = self._process_dir(self.gallery_dir, relabel=False) 46 | 47 | if verbose: 48 | print("=> DukeMTMC-reID loaded") 49 | #self.print_dataset_statistics(train, query, gallery) 50 | 51 | self.train = train 52 | self.query = query 53 | self.gallery = gallery 54 | 55 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 56 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 57 | self.num_gallery_pids, 
self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 58 | 59 | def _download_data(self): 60 | if osp.exists(self.dataset_dir): 61 | print("This dataset has been downloaded.") 62 | return 63 | 64 | print("Creating directory {}".format(self.dataset_dir)) 65 | os.mkdir(self.dataset_dir) 66 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 67 | 68 | print("Downloading DukeMTMC-reID dataset") 69 | urllib.request.urlretrieve(self.dataset_url, fpath) 70 | 71 | print("Extracting files") 72 | zip_ref = zipfile.ZipFile(fpath, 'r') 73 | zip_ref.extractall(self.dataset_dir) 74 | zip_ref.close() 75 | 76 | def _check_before_run(self): 77 | """Check if all files are available before going deeper""" 78 | if not osp.exists(self.dataset_dir): 79 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 80 | if not osp.exists(self.train_dir): 81 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 82 | if not osp.exists(self.query_dir): 83 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 84 | if not osp.exists(self.gallery_dir): 85 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 86 | 87 | def _process_dir(self, dir_path, relabel=False): 88 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 89 | pattern = re.compile(r'([-\d]+)_c(\d)') 90 | 91 | pid_container = set() 92 | for img_path in img_paths: 93 | pid, _ = map(int, pattern.search(img_path).groups()) 94 | pid_container.add(pid) 95 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 96 | 97 | dataset = [] 98 | for img_path in img_paths: 99 | pid, camid = map(int, pattern.search(img_path).groups()) 100 | assert 1 <= camid <= 8 101 | camid -= 1 # index starts from 0 102 | if relabel: pid = pid2label[pid] 103 | dataset.append((img_path, pid, camid)) 104 | 105 | return dataset 106 | -------------------------------------------------------------------------------- /lib/data/datasets/market1501.py: 
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import glob
import re

import os.path as osp

from .bases import BaseImageDataset


class Market1501(BaseImageDataset):
    """
    Market1501
    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)
    """
    dataset_dir = 'market1501'

    def __init__(self, root='', verbose=True, **kwargs):
        super(Market1501, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        self._check_before_run()

        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> Market1501 loaded")
            #self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False, domain='real'):
        """Build ``(img_path, pid, camid, domain)`` records from ``*.jpg`` names like ``<pid>_c<cam>``.

        pid -1 (junk) images are skipped; camera ids are shifted to start at 0.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            dataset.append((img_path, pid, camid, domain))

        return dataset

# ------------------------------------------------------------------------------
# /lib/data/datasets/msmt17.py
# ------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2019/1/17 15:00
# @Author : Hao Luo
# @File : msmt17.py

import glob
import re

import os.path as osp

from .bases import BaseImageDataset


class MSMT17(BaseImageDataset):
    """
    MSMT17

    Reference:
    Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.

    URL: http://www.pkuvmc.com/publications/msmt17.html

    Dataset statistics:
    # identities: 4101
    # images: 32621 (train) + 11659 (query) + 82161 (gallery)
    # cameras: 15
    """
    dataset_dir = 'msmt17'

    def __init__(self, root='', verbose=True, **kwargs):
        super(MSMT17, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_train_v2')
        self.test_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_test_v2')
        self.list_train_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_train.txt')
        self.list_val_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_val.txt')
        self.list_query_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_query.txt')
        self.list_gallery_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_gallery.txt')

        self._check_before_run()
        train = self._process_dir(self.train_dir, self.list_train_path)
        query = self._process_dir(self.test_dir, self.list_query_path)
        gallery = self._process_dir(self.test_dir, self.list_gallery_path)
        if verbose:
            print("=> MSMT17 loaded")
            #self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.test_dir):
            raise RuntimeError("'{}' is not available".format(self.test_dir))

    def _process_dir(self, dir_path, list_path, domain='real'):
        """Build records from a list file of ``<relative_img_path> <pid>`` lines.

        The camera id is the 3rd ``_``-separated field of the image path. pids
        are used as-is and asserted to be exactly ``0..N-1``.

        Returns:
            list of ``(img_path, pid, camid, domain)`` tuples.
        """
        with open(list_path, 'r') as txt:
            lines = txt.readlines()
        dataset = []
        pid_container = set()
        for img_info in lines:
            img_path, pid = img_info.split(' ')
            pid = int(pid)  # no need to relabel
            camid = int(img_path.split('_')[2])
            img_path = osp.join(dir_path, img_path)
            dataset.append((img_path, pid, camid, domain))
            pid_container.add(pid)

        # check if pid starts from 0 and increments with 1 (sorted: set order is not guaranteed)
        for idx, pid in enumerate(sorted(pid_container)):
            assert idx == pid, "See code comment for explanation"
        return dataset

# ------------------------------------------------------------------------------
# /lib/data/datasets/track3_h5.py
# ------------------------------------------------------------------------------
# encoding: utf-8

import glob
import re
import os
import os.path as osp
import xml.etree.ElementTree as ET
import h5py

from .bases import BaseImageDataset

class Track3(BaseImageDataset):
    """
    ----------------------------------------
    subset | # ids | # images | # cameras
    ----------------------------------------
    train | 333 | 36935 | 36
    query | 333 | 1052 | ?
    gallery | 333 | 18290 | ?
    ----------------------------------------

    """
    dataset_dir = '/media/data/ai-city/Aic_track3/train'
    h5_files = ['S01_data.h5', 'S03_data.h5', 'S04_data.h5']

    def __init__(self, root='', verbose=True, **kwargs):
        super(Track3, self).__init__()
        self.img_dir = osp.join(self.dataset_dir, 'JPEGImages')
        train = []
        # `h5_files` is a class attribute, so it must be reached through `self`.
        for h5_file_name in self.h5_files:
            self.data_file = osp.join(self.dataset_dir, h5_file_name)
            # extend, not append: each call returns a list of records.
            train.extend(self._process_dir(self.img_dir, self.data_file))

        train = self.relabel(train)
        if verbose:
            print("=> AI CITY 2020 data loaded")
            #self.print_dataset_statistics(train, query, gallery)

        self.train = train
        # query / gallery stay as the empty lists set by BaseImageDataset.__init__.

        self.train_tracks = self._read_tracks(os.path.join(self.dataset_dir, 'train_track.txt'))
        self.test_tracks = self._read_tracks(os.path.join(self.dataset_dir, 'test_track.txt'))

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _process_dir(self, img_dir, data_file, relabel=False, label_path=None,
                     list_path=None, domain='real'):
        """Build ``(img_path, pid, camid, domain)`` records for one scene.

        Args:
            img_dir: directory that image names are joined onto.
            data_file: the scene's ``.h5`` file. Reading it is not implemented.
            relabel: relabel pids (XML branch only).
            label_path: optional XML file whose ``Items`` children carry
                ``imageName``, ``vehicleID`` and ``cameraID`` attributes.
            list_path: optional text file of image names; those records get
                pid 0 and camid 0.
            domain: value stored as the 4th field of every record.

        Raises:
            NotImplementedError: when neither ``label_path`` nor ``list_path``
                is given, because parsing ``data_file`` has not been written.
        """
        dataset = []
        if label_path:
            tree = ET.parse(label_path, parser=ET.XMLParser(encoding='utf-8'))
            objs = tree.find('Items')
            for obj in objs:
                image_name = obj.attrib['imageName']
                img_path = osp.join(img_dir, image_name)
                pid = int(obj.attrib['vehicleID'])
                camid = int(obj.attrib['cameraID'][1:])
                dataset.append((img_path, pid, camid, domain))
            if relabel:
                dataset = self.relabel(dataset)
        elif list_path:
            with open(list_path, 'r') as f:
                lines = f.readlines()
            for line in lines:
                line = line.strip()
                img_path = osp.join(img_dir, line)
                pid = 0
                camid = 0
                dataset.append((img_path, pid, camid, domain))
        else:
            raise NotImplementedError(
                "Reading annotations from '{}' is not implemented".format(data_file))
        return dataset

if __name__ == '__main__':
    dataset = Track3(root='/media/data/ai-city/Track2/AIC21_Track2_ReID')

# ------------------------------------------------------------------------------
# /lib/data/datasets/veri.py
# ------------------------------------------------------------------------------
import glob
import re

import os.path as osp

from .bases import BaseImageDataset


class VeRi(BaseImageDataset):
    """
    VeRi-776
    Reference:
    Liu, Xinchen, et al. "Large-scale vehicle re-identification in urban surveillance videos." ICME 2016.

    URL:https://vehiclereid.github.io/VeRi/

    Dataset statistics:
    # identities: 776
    # images: 37778 (train) + 1678 (query) + 11579 (gallery)
    # cameras: 20
    """

    dataset_dir = 'VeRi'

    def __init__(self, root='../', verbose=True, **kwargs):
        super(VeRi, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'image_train')
        self.query_dir = osp.join(self.dataset_dir, 'image_query')
        self.gallery_dir = osp.join(self.dataset_dir, 'image_test')

        self._check_before_run()

        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> VeRi-776 loaded")
            #self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

        self.test_tracks = self._read_tracks(osp.join(self.dataset_dir, 'test_track.txt'))

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False, domain='real'):
        """Build ``(img_path, pid, camid, domain)`` records from ``*.jpg`` names like ``<pid>_c<cam>``.

        pid -1 (junk) images are skipped; camera ids are shifted to start at 0.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d+)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            assert 0 <= pid <= 776  # pid == 0 means background
            assert 1 <= camid <= 20
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            dataset.append((img_path, pid, camid, domain))
        return dataset

# ------------------------------------------------------------------------------
# /lib/data/samplers/__init__.py
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from .triplet_sampler import RandomIdentitySampler, RandomIdentitySampler_alignedreid, MPerClassSampler, RandomIdentityCrossDomainSampler # new add by gu

# ------------------------------------------------------------------------------
# /lib/data/samplers/__pycache__/__init__.cpython-37.pyc
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/samplers/__pycache__/__init__.cpython-37.pyc
# ------------------------------------------------------------------------------
# /lib/data/samplers/__pycache__/triplet_sampler.cpython-37.pyc
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/samplers/__pycache__/triplet_sampler.cpython-37.pyc
# ------------------------------------------------------------------------------
# /lib/data/transforms/__init__.py
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

from .build import build_transforms

# ------------------------------------------------------------------------------
# /lib/data/transforms/__pycache__/__init__.cpython-37.pyc
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/transforms/__pycache__/__init__.cpython-37.pyc
# ------------------------------------------------------------------------------
# /lib/data/transforms/__pycache__/augmix.cpython-37.pyc
# ------------------------------------------------------------------------------
https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/transforms/__pycache__/augmix.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/transforms/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/transforms/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/transforms/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/data/transforms/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /lib/data/transforms/augmix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================
"""Base augmentations operators."""

import numpy as np
from PIL import Image
from PIL import ImageOps
import torch
import random


# Image size constant carried over from the reference AugMix code
# ("ImageNet code should change this value").
IMAGE_SIZE = [256, 128]


def int_parameter(level, maxval):
    """Scale ``level`` from the nominal range [0, 10] onto [0, ``maxval``], truncated to int.

    Args:
        level: operation strength; 10 maps to ``maxval``.
        maxval: upper end of the output range.

    Returns:
        ``int(level * maxval / 10)``.
    """
    scaled = level * maxval
    return int(scaled / 10)


def float_parameter(level, maxval):
    """Scale ``level`` from the nominal range [0, 10] onto [0, ``maxval``] as a float.

    Args:
        level: operation strength; 10 maps to ``maxval``.
        maxval: upper end of the output range.

    Returns:
        ``float(level) * maxval / 10``.
    """
    scaled = float(level) * maxval
    return scaled / 10.0
def sample_level(n):
    """Return a random augmentation strength drawn uniformly from [0.1, n)."""
    return np.random.uniform(low=0.1, high=n)


def autocontrast(pil_img, _):
    """Apply PIL autocontrast; the severity argument is ignored."""
    return ImageOps.autocontrast(pil_img)


def equalize(pil_img, _):
    """Apply PIL histogram equalisation; the severity argument is ignored."""
    return ImageOps.equalize(pil_img)


def posterize(pil_img, level):
    """Posterize to ``4 - k`` bits where ``k = int_parameter(sample_level(level), 4)``."""
    level = int_parameter(sample_level(level), 4)
    return ImageOps.posterize(pil_img, 4 - level)


def rotate(pil_img, level):
    """Rotate by ``int_parameter(sample_level(level), 30)`` degrees with a random sign."""
    degrees = int_parameter(sample_level(level), 30)
    if np.random.uniform() > 0.5:
        degrees = -degrees
    return pil_img.rotate(degrees, resample=Image.BILINEAR)


def solarize(pil_img, level):
    """Solarize with threshold ``256 - int_parameter(sample_level(level), 256)``."""
    level = int_parameter(sample_level(level), 256)
    return ImageOps.solarize(pil_img, 256 - level)


def shear_x(pil_img, level):
    """Shear along x by ``float_parameter(sample_level(level), 0.3)`` with a random sign."""
    level = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        level = -level
    # PIL affine data (a, b, c, d, e, f) maps each output pixel (x, y) to the
    # input pixel (a*x + b*y + c, d*x + e*y + f).
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, level, 0, 0, 1, 0),
                             resample=Image.BILINEAR)


def shear_y(pil_img, level):
    """Shear along y by ``float_parameter(sample_level(level), 0.3)`` with a random sign."""
    level = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        level = -level
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, level, 1, 0),
                             resample=Image.BILINEAR)


def translate_x(pil_img, level):
    """Shift horizontally by ``int_parameter(sample_level(level), width / 3)`` pixels, random sign."""
    # PIL's ``size`` is (width, height). A horizontal shift must be bounded by
    # the width; the previous code used size[1] (the height), which is wrong
    # for non-square inputs.
    level = int_parameter(sample_level(level), pil_img.size[0] / 3)
    if np.random.random() > 0.5:
        level = -level
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, level, 0, 1, 0),
                             resample=Image.BILINEAR)


def translate_y(pil_img, level):
    """Shift vertically by ``int_parameter(sample_level(level), height / 3)`` pixels, random sign."""
    level = int_parameter(sample_level(level), pil_img.size[1] / 3)
    if np.random.random() > 0.5:
        level = -level
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, 0, 1, level),
                             resample=Image.BILINEAR)


# Pool of operations AugMix samples from (uniformly, via np.random.choice).
augmentations = [
    autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
    translate_x, translate_y
]


class AugMix(object):
    """AugMix-style augmentation for PIL images.

    With probability ``prob`` the image is augmented. ``mixture_width`` chains
    of ops drawn from ``augmentations`` are applied to copies of the image. Each
    chain is ``mixture_depth`` ops long, or a random 1-3 ops when
    ``mixture_depth <= 0``. The chains are blended with Dirichlet weights, and
    the blend is mixed with the original image using a Beta-distributed factor.

    Args:
        prob: probability of augmenting a given image.
        aug_prob_coeff: concentration for both the Dirichlet and Beta draws.
        mixture_width: number of augmentation chains.
        mixture_depth: ops per chain; ``<= 0`` selects a random depth in [1, 3].
        aug_severity: ``level`` argument passed to every op.

    ``__call__`` returns an ``np.ndarray``. When augmentation is skipped this is
    ``np.asarray(img)``; otherwise it is the uint8 mix. The mixing buffer has 3
    channels, so an RGB image is expected.
    """

    def __init__(self, prob=0.5, aug_prob_coeff=0.1,
                 mixture_width=3,
                 mixture_depth=1,
                 aug_severity=1):
        self.prob = prob
        self.aug_prob_coeff = aug_prob_coeff
        self.mixture_width = mixture_width
        self.mixture_depth = mixture_depth
        self.aug_severity = aug_severity

    def __call__(self, img):
        if random.random() > self.prob:
            return np.asarray(img)
        ws = np.float32(
            np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
        m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))

        width, height = img.size
        mix = np.zeros([height, width, 3])
        for weight in ws:
            image_aug = img.copy()
            depth = (self.mixture_depth if self.mixture_depth > 0
                     else np.random.randint(1, 4))
            for _ in range(depth):
                op = np.random.choice(augmentations)
                image_aug = op(image_aug, self.aug_severity)
            # The weights are convex, so the mix stays within the pixel range.
            mix += weight * np.asarray(image_aug)

        mixed = (1 - m) * np.asarray(img) + m * mix
        return mixed.astype(np.uint8)


# ------------------------------------------------------------------------------
# /lib/data/transforms/build.py
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""

import torchvision.transforms as T

from .transforms import RandomErasing, RandomPatch, ColorSpaceConvert, ColorAugmentation, RandomBlur, GaussianBlur
from .augmix import AugMix


def build_transforms(cfg, is_train=True):
    """Return the torchvision ``Compose`` pipeline for training or testing.

    Training: resize, flip, pad, random crop, random patch, optional colour
    jitter, AugMix, random blur, ToTensor, normalize and random erasing, all
    parameterised from ``cfg.INPUT``. Testing: resize, ToTensor, normalize.
    """
    normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)

    if not is_train:
        return T.Compose([
            T.Resize(cfg.INPUT.SIZE_TEST),
            T.ToTensor(),
            normalize_transform
        ])

    return T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN),
        T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
        T.Pad(cfg.INPUT.PADDING),
        T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
        RandomPatch(prob_happen=cfg.INPUT.RANDOM_PATCH_PROB, patch_max_area=0.16),
        T.RandomApply([T.ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)], p=cfg.INPUT.COLORJIT_PROB),
        # AugMix returns a numpy array; ToTensor below accepts it.
        AugMix(prob=cfg.INPUT.AUGMIX_PROB),
        RandomBlur(p=cfg.INPUT.RANDOM_BLUR_PROB),
        T.ToTensor(),
        normalize_transform,
        RandomErasing(probability=cfg.INPUT.RE_PROB, sh=cfg.INPUT.RE_SH, mean=cfg.INPUT.PIXEL_MEAN)
    ])


# ------------------------------------------------------------------------------
# /lib/data/transforms/vis_transform.py
# ------------------------------------------------------------------------------
import torchvision.transforms as T

from transforms import RandomErasing, RandomPatch, ColorSpaceConvert, ColorAugmentation, RandomBlur
from augmix import AugMix

if __name__ == '__main__':
    # Debug script: run one image through the augmentation pipeline and save it.
    from PIL import Image
    import cv2
    img_path = '/home/zxy/data/ReID/vehicle/AIC20_ReID/image_query/000345.jpg'
    img = Image.open(img_path).convert('RGB')

    transform = T.Compose([
        T.Resize([256, 256]),
        T.RandomHorizontalFlip(0.0),
        T.Pad(0),
        T.RandomCrop([256, 256]),
        #RandomPatch(prob_happen=0.0, patch_max_area=0.5),
        #T.RandomApply([T.transforms.RandomAffine(degrees=20, scale=(0.8, 1.3))], p=0.5),
        #T.RandomApply([T.ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)], p=0.0),
        AugMix(prob=0.5),
        RandomBlur(p=1.0),
    ])
    canvas = transform(img)
    #img.save('test.jpg')
    # Reverse the channel axis (RGB -> BGR) because OpenCV writes BGR images.
    cv2.imwrite('test.jpg', canvas[:, :, ::-1])

# ------------------------------------------------------------------------------
# /lib/engine/__pycache__/train_net.cpython-37.pyc:
# ------------------------------------------------------------------------------
# (binary) https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/engine/__pycache__/train_net.cpython-37.pyc

# ------------------------------------------------------------------------------
# /lib/engine/inference.py
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import logging
import time
import torch
import torch.nn as nn
from lib.utils.reid_eval import evaluator


def inference(
        cfg,
        model,
        val_loader,
        num_query,
        dataset
):
    """Evaluate ``model`` on ``val_loader`` and log mAP and CMC rank-1/5/10.

    When ``cfg.TEST.FLIP_TEST`` is set, features are averaged with those of the
    horizontally flipped batch.

    Returns:
        The third value of ``evaluator.compute()`` (``indices_np``), presumably
        the per-query gallery ranking. TODO confirm in lib/utils/reid_eval.
    """
    device = cfg.MODEL.DEVICE
    model.to(device)
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Enter inferencing")
    metric = evaluator(num_query, dataset, cfg, max_rank=50)
    model.eval()
    start = time.time()
    with torch.no_grad():
        for data, pid, camid, img_path in val_loader:
            # Put inputs on the same device as the model (was hard-coded .cuda()).
            data = data.to(device)
            feats = model(data)
            if cfg.TEST.FLIP_TEST:
                data_flip = data.flip(dims=[3])  # NCHW: dim 3 is width
                feats = (feats + model(data_flip)) / 2
            metric.update([feats, pid, camid, img_path])
    end = time.time()
    logger.info("inference takes {:.3f}s".format((end - start)))
    torch.cuda.empty_cache()
    cmc, mAP, indices_np = metric.compute()
    logger.info('Validation Results')
    logger.info("mAP: {:.1%}".format(mAP))
    for r in [1, 5, 10]:
        logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
    return indices_np


def select_topk(indices, query, gallery, topk=10):
    """For each query row ``i``, return ``[query[i][0]]`` followed by the first
    field of its top-``topk`` gallery entries as ranked by ``indices[i]``."""
    return [
        [query[i][0]] + [gallery[idx][0] for idx in indices[i][:topk]]
        for i in range(indices.shape[0])
    ]


def extract_features(cfg, model, loader):
    """Run ``model`` over ``loader`` and return L2-normalised, concatenated features."""
    device = cfg.MODEL.DEVICE
    model.to(device)
    model.eval()
    feats = []
    with torch.no_grad():
        for data, pid, camid, img_path in loader:
            # Match the model's device instead of assuming CUDA.
            feats.append(model(data.to(device)))
    feats = torch.cat(feats, dim=0)
    return torch.nn.functional.normalize(feats, dim=1, p=2)


# ------------------------------------------------------------------------------
# /lib/engine/train_net.py
# ------------------------------------------------------------------------------
'''
@author: Xiangyu
'''
import os
import logging
import time
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from lib.utils.reid_eval import evaluator

# Module-level iteration counter (train() keeps its own local counter).
ITER = 0

# TensorBoard global step, advanced each time train() logs.
ITER_LOG = 0
# NOTE(review): constructed at import time, so importing this module opens a
# SummaryWriter on output/logs.
WRITER = SummaryWriter(log_dir='output/logs')


try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *  # provides to_python_float used in train()
    from apex import amp, optimizers
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError as err:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") from err


def do_train(
        cfg,
        model,
        dataset,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        loss_fn,
        num_query,
        start_epoch
):
    """Train from epoch ``start_epoch + 1`` to ``cfg.SOLVER.MAX_EPOCHS``.

    Validates every ``cfg.SOLVER.EVAL_PERIOD`` epochs and on the last epoch.
    Writes ``best.pth`` whenever mAP does not decrease, ``resume.pth.tar``
    after each epoch, and ``final.pth`` at the end; the resume file is removed
    once training finishes.
    """
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    logger = logging.getLogger("reid_baseline.train")
    resume_path = os.path.join(output_dir, 'resume.pth.tar')

    if device:
        model.cuda()
    # Apex FP16 training
    if cfg.SOLVER.FP16:
        # Previously passed to logging.getLogger(), which created a logger
        # instead of emitting this message.
        logger.info("Using Mix Precision training")
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    logger.info("Start training")

    best_mAP = 0

    for epoch in range(start_epoch + 1, cfg.SOLVER.MAX_EPOCHS + 1):
        logger.info("Epoch[{}] lr={:.2e}"
                    .format(epoch, scheduler.get_lr()[0]))

        # Freeze children whose name contains 'base' during warmup epochs,
        # then unfreeze everything once FREEZE_BASE_EPOCHS is reached.
        if cfg.SOLVER.FREEZE_BASE_EPOCHS != 0:
            if epoch < cfg.SOLVER.FREEZE_BASE_EPOCHS:
                logger.info("freeze base layers")
                frozen_feature_layers(model)
            elif epoch == cfg.SOLVER.FREEZE_BASE_EPOCHS:
                logger.info("open all layers")
                open_all_layers(model)
        train(model, dataset, train_loader, optimizer, loss_fn, epoch, cfg, logger)

        if epoch % cfg.SOLVER.EVAL_PERIOD == 0 or epoch == cfg.SOLVER.MAX_EPOCHS:
            mAP, cmc = validate(model, dataset, val_loader, num_query, epoch, cfg, logger)
            if mAP >= best_mAP:
                best_mAP = mAP
                torch.save(model.state_dict(), os.path.join(output_dir, 'best.pth'))

        scheduler.step()
        torch.cuda.empty_cache()  # release cache
        torch.save({'state_dict': model.state_dict(), 'epoch': epoch, 'optimizer': optimizer.state_dict()},
                   resume_path)

    logger.info("best mAP: {:.1%}".format(best_mAP))
    torch.save(model.state_dict(), os.path.join(output_dir, 'final.pth'))
    # The resume checkpoint is absent when no epoch ran (start_epoch >= MAX_EPOCHS).
    if os.path.exists(resume_path):
        os.remove(resume_path)


def train(model, dataset, train_loader, optimizer, loss_fn, epoch, cfg, logger):
    """Run one training epoch.

    Every ``cfg.SOLVER.LOG_PERIOD`` iterations, logs the id, metric and total
    loss plus data and model timings, and mirrors the losses to TensorBoard at
    global step ``ITER_LOG``.
    """
    global ITER_LOG
    losses = AverageMeter()
    data_time = AverageMeter()
    model_time = AverageMeter()

    start = time.time()
    model.train()
    iteration = 0
    log_period = cfg.SOLVER.LOG_PERIOD
    data_start = time.time()
    for batch in train_loader:
        data_time.update(time.time() - data_start)
        # Only images and identity labels are used; the other fields are ignored.
        images, target, _, _, _ = batch
        images = images.cuda()
        target = target.cuda()
        model_start = time.time()
        iteration += 1
        optimizer.zero_grad()
        score, feat = model(images, target)
        id_loss, metric_loss = loss_fn(score, feat, target)
        loss = id_loss + metric_loss
        if cfg.SOLVER.FP16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # Wait for queued GPU work so model_time reflects actual compute.
        torch.cuda.synchronize()

        model_time.update(time.time() - model_start)
        losses.update(to_python_float(loss.data), images.size(0))

        if iteration % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] id_loss: {:.3f}, metric_loss: {:.5f}, total_loss: {:.3f}, data time: {:.3f}s, model time: {:.3f}s"
                        .format(epoch, iteration, len(train_loader),
                                id_loss.item(), metric_loss.item(), losses.val, data_time.val, model_time.val))
            WRITER.add_scalar('Loss_Train_id_loss', id_loss.item(), ITER_LOG)
            WRITER.add_scalar('Loss_Train_metric_loss', metric_loss.item(), ITER_LOG)
            WRITER.add_scalar('Loss_Train_totals', losses.val, ITER_LOG)
            ITER_LOG += 1
        data_start = time.time()
    end = time.time()
    logger.info("epoch takes {:.3f}s".format((end - start)))


def validate(model, dataset, val_loader, num_query, epoch, cfg, logger):
    """Evaluate on ``val_loader``, log mAP and CMC rank-1/5/10, and return ``(mAP, cmc)``."""
    metric = evaluator(num_query, dataset, cfg, max_rank=50)
    model.eval()
    with torch.no_grad():
        for data, pid, camid, img_path in val_loader:
            feats = model(data.cuda())
            metric.update([feats, pid, camid, img_path])
    cmc, mAP, _ = metric.compute()
    logger.info("Validation Results - Epoch: {}".format(epoch))
    logger.info("mAP: {:.1%}".format(mAP))
    for r in [1, 5, 10]:
        logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
    return mAP, cmc


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples in the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def frozen_feature_layers(model):
    """Put every direct child whose name contains 'base' in eval mode and freeze its parameters.

    NOTE(review): train() calls model.train() at the start of each epoch, which
    switches these modules back to training mode; only requires_grad=False persists.
    """
    for name, module in model.named_children():
        if 'base' in name:
            module.eval()
            for p in module.parameters():
                p.requires_grad = False


def open_all_layers(model):
    """Put every direct child in training mode and make all its parameters trainable."""
    for module in model.children():
        module.train()
        for p in module.parameters():
            p.requires_grad = True


# ------------------------------------------------------------------------------
# /lib/layers/__init__.py
# ------------------------------------------------------------------------------
from .build import make_loss

# /lib/layers/__pycache__/__init__.cpython-37.pyc: (binary) https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/layers/__pycache__/__init__.cpython-37.pyc
# /lib/layers/__pycache__/build.cpython-37.pyc: (binary) https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/layers/__pycache__/build.cpython-37.pyc
# /lib/layers/__pycache__/metric_learning.cpython-37.pyc:
# (binary) https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/layers/__pycache__/metric_learning.cpython-37.pyc
# /lib/layers/__pycache__/pooling.cpython-37.pyc: (binary) https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/layers/__pycache__/pooling.cpython-37.pyc
# /lib/layers/__pycache__/triplet_loss.cpython-37.pyc: (binary) https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/layers/__pycache__/triplet_loss.cpython-37.pyc

# ------------------------------------------------------------------------------
# /lib/layers/build.py
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .triplet_loss import TripletLoss, CrossEntropyLabelSmooth
from .metric_learning import ContrastiveLoss, SupConLoss
from torch.utils.tensorboard import SummaryWriter


def make_loss(cfg, num_classes):  # modified by gu
    """Build ``loss_func(score, feat, target) -> (weighted_id_loss, weighted_metric_loss)``.

    The ID loss is label-smoothed cross-entropy when ``cfg.MODEL.IF_LABELSMOOTH``
    is 'on', otherwise ``F.cross_entropy``. The metric loss is chosen by
    ``cfg.MODEL.METRIC_LOSS_TYPE``: 'triplet', 'contrastive', 'supconloss' or
    'none'.

    Every ``make_loss.update_iter_interval`` calls, the ID-loss weight
    (``make_loss.ID_LOSS_WEIGHT``) is nudged toward ``metric_std / id_std``
    when the ID-loss history is noisier than the metric-loss history.

    Raises:
        ValueError: unsupported ``METRIC_LOSS_TYPE``. Previously this was only
        printed, and the returned closure then failed on an undefined name.
    """
    make_loss.update_iter_interval = 500
    make_loss.id_loss_history = []
    make_loss.metric_loss_history = []
    make_loss.ID_LOSS_WEIGHT = cfg.MODEL.ID_LOSS_WEIGHT
    make_loss.TRIPLET_LOSS_WEIGHT = cfg.MODEL.TRIPLET_LOSS_WEIGHT

    loss_type = cfg.MODEL.METRIC_LOSS_TYPE
    if loss_type == 'triplet':
        metric_loss_func = TripletLoss(cfg.SOLVER.MARGIN, cfg.SOLVER.HARD_EXAMPLE_MINING_METHOD)  # triplet loss
    elif loss_type == 'contrastive':
        metric_loss_func = ContrastiveLoss(cfg.SOLVER.MARGIN)
    elif loss_type == 'supconloss':
        metric_loss_func = SupConLoss(num_ids=int(cfg.SOLVER.IMS_PER_BATCH / cfg.DATALOADER.NUM_INSTANCE),
                                      views=cfg.DATALOADER.NUM_INSTANCE)
    elif loss_type == 'none':
        def metric_loss_func(feat, target):
            # A zero *tensor*, not int 0: loss_func and train() call .item() on it.
            device = feat.device if torch.is_tensor(feat) else None
            return torch.zeros((), device=device)
    else:
        raise ValueError('got unsupported metric loss type {}'.format(loss_type))

    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        id_loss_func = CrossEntropyLabelSmooth(num_classes=num_classes)  # new add by luo
        print("label smooth on, numclasses:", num_classes)
    else:
        id_loss_func = F.cross_entropy

    def loss_func(score, feat, target):
        _id_loss = id_loss_func(score, target)
        _metric_loss = metric_loss_func(feat, target)
        make_loss.id_loss_history.append(_id_loss.item())
        make_loss.metric_loss_history.append(_metric_loss.item())

        # The history was just appended to, so it is never empty here.
        if len(make_loss.id_loss_history) % make_loss.update_iter_interval == 0:
            id_std = np.std(make_loss.id_loss_history)
            metric_std = np.std(make_loss.metric_loss_history)
            if id_std > metric_std:
                # id_std > metric_std >= 0 guarantees a non-zero divisor.
                new_weight = 1 - (id_std - metric_std) / id_std
                make_loss.ID_LOSS_WEIGHT = make_loss.ID_LOSS_WEIGHT * 0.9 + new_weight * 0.1

            # Start a fresh statistics window.
            make_loss.id_loss_history = []
            make_loss.metric_loss_history = []
            print(f"update weighted loss ID_LOSS_WEIGHT={round(make_loss.ID_LOSS_WEIGHT,3)},TRIPLET_LOSS_WEIGHT={make_loss.TRIPLET_LOSS_WEIGHT}")
        return make_loss.ID_LOSS_WEIGHT * _id_loss, make_loss.TRIPLET_LOSS_WEIGHT * _metric_loss

    return loss_func


# ------------------------------------------------------------------------------
# /lib/layers/pooling.py
# ------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


# Generalized-mean (GeM) Pooling, borrowed from cirtorch.
class GeM(nn.Module):
    """Generalized-mean pooling to a 1x1 spatial output.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` over the spatial dims.
    ``p`` is a fixed float when ``freeze_p`` is True; otherwise it is a
    learnable 1-element Parameter initialised to ``p``.
    """

    def __init__(self, p=3.0, eps=1e-6, freeze_p=True):
        super(GeM, self).__init__()
        self.p = p if freeze_p else Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return F.adaptive_avg_pool2d(x.clamp(min=self.eps).pow(self.p),
                                     (1, 1)).pow(1. / self.p)

    def __repr__(self):
        p = self.p if isinstance(self.p, float) else self.p.data.tolist()[0]
        return self.__class__.__name__ + \
            '(' + 'p=' + '{:.4f}'.format(p) + \
            ', ' + 'eps=' + str(self.eps) + ')'


# ------------------------------------------------------------------------------
# /lib/modeling/__init__.py
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

from .baseline import Baseline_reduce, Baseline, Baseline_2_Head
from .multiheads_baseline import Baseline as MultiHeadsBaseline


def build_model(cfg, num_classes):
    """Instantiate the model named by ``cfg.MODEL.MODEL_TYPE``.

    Every supported model class receives the same constructor arguments.
    Returns None (after printing a message) for an unsupported type.
    """
    # MODEL_TYPE -> (log message, model class)
    model_types = {
        'baseline_reduce': ("using global feature baseline reduce", Baseline_reduce),
        'baseline': ("using global feature baseline", Baseline),
        'baseline_2_head': ("using low-level feature + high-level feature and GeM Pooling + Adaptive Pooling",
                            Baseline_2_Head),
        'baseline_multiheads': ("using multi-heads global feature baseline", MultiHeadsBaseline),
    }
    entry = model_types.get(cfg.MODEL.MODEL_TYPE)
    if entry is None:
        print("unsupported model type: {}".format(cfg.MODEL.MODEL_TYPE))
        return None

    message, model_cls = entry
    print(message)
    return model_cls(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK,
                     cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE,
                     cfg)

# /lib/modeling/__pycache__/__init__.cpython-37.pyc: (binary) https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/__pycache__/__init__.cpython-37.pyc
# /lib/modeling/__pycache__/baseline.cpython-37.pyc: (binary) https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/__pycache__/baseline.cpython-37.pyc


# ------------------------------------------------------------------------------
# /lib/modeling/backbones/STNModule.py
# ------------------------------------------------------------------------------
""" A plug and play Spatial Transformer Module in Pytorch """
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialTransformer(nn.Module):
    """Spatial transformer (Jaderberg et al.) that warps its input with a learned affine map.

    It has three parts:
      1. A localization net: five conv layers (``kernel_size``, padding=1,
         32 channels) with four 2x max-pools, then two fully connected layers
         that predict the 6 affine parameters.
      2. A grid generator (``F.affine_grid``) producing a grid of size
         ``spatial_dims``.
      3. A sampler (``F.grid_sample``) that resamples the input on that grid.

    Returns ``(rois, grid)``, where ``rois`` has spatial size ``spatial_dims``.
    """

    def __init__(self, in_channels, spatial_dims, kernel_size, use_dropout=True):
        super(SpatialTransformer, self).__init__()
        self._h, self._w = spatial_dims
        self._in_ch = in_channels
        self._ksize = kernel_size
        self.dropout = use_dropout

        # localization net
        self.conv1_stn = nn.Conv2d(in_channels, 32, kernel_size=self._ksize, stride=1, padding=1, bias=False)
        self.conv2_stn = nn.Conv2d(32, 32, kernel_size=self._ksize, stride=1, padding=1, bias=False)
        self.conv3_stn = nn.Conv2d(32, 32, kernel_size=self._ksize, stride=1, padding=1, bias=False)
        self.conv4_stn = nn.Conv2d(32, 32, kernel_size=self._ksize, stride=1, padding=1, bias=False)
        self.conv5_stn = nn.Conv2d(32, 32, kernel_size=self._ksize, stride=1, padding=1, bias=False)

        # NOTE(review): 32*20*20 matches a 320x320 input only if kernel_size=3
        # (padding is fixed at 1) and after four 2x pools; other sizes break in forward().
        self.fc1_stn = nn.Linear(32 * 20 * 20, 512)
        self.fc2_stn = nn.Linear(512, 6)

    def forward(self, x):
        """Warp ``x`` by an affine transform predicted from ``x``.

        The localization branch sees ``x.detach()``, so no gradient reaches
        ``x`` through it; gradients reach ``x`` only via ``grid_sample``.
        """
        batch_images = x
        x = F.relu(self.conv1_stn(x.detach()))
        x = F.relu(self.conv2_stn(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv3_stn(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv4_stn(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv5_stn(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 32 * 20 * 20)
        x = self.fc1_stn(x)
        if self.dropout:
            # F.dropout defaults to training=True, which previously applied
            # dropout at eval time as well; tie it to the module's mode.
            x = F.dropout(x, p=0.3, training=self.training)
        x = self.fc2_stn(x)  # params [Nx6]
        theta = x.view(-1, 2, 3)  # one 2x3 affine matrix per sample
        affine_grid_points = F.affine_grid(theta, torch.Size((theta.size(0), self._in_ch, self._h, self._w)))
        assert(affine_grid_points.size(0) == batch_images.size(0)), "The batch sizes of the input images must be same as the generated grid."
        rois = F.grid_sample(batch_images, affine_grid_points)
        return rois, affine_grid_points


# ------------------------------------------------------------------------------
# /lib/modeling/backbones/__init__.py
# ------------------------------------------------------------------------------
from .resnet import resnet50, resnet152
from .resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a, se_resnet101_ibn_a
from .resnext_ibn_a import resnext50_ibn_a, resnext101_ibn_a
from .resnext_ibn_a_2_head import resnext101_ibn_a_2_head
from .resnest import resnest50
from .regnet.regnet import regnety_800mf, regnety_1600mf, regnety_3200mf
from .mixstyle import MixStyle, MixStyle2
from .STNModule import SpatialTransformer
from .resnext_ibn_a_attention import resnext101_ibn_a_attention
# from .nfnet import dm_nfnet_f0

# Backbone name -> constructor.
factory = {
    'resnet50': resnet50,
    'resnet50_ibn_a': resnet50_ibn_a,
    'resnet101_ibn_a': resnet101_ibn_a,
    'resnext101_ibn_a': resnext101_ibn_a,
    'resnext101_ibn_a_2_head': resnext101_ibn_a_2_head,
    'resnext101_ibn_a_attention': resnext101_ibn_a_attention,
    'resnest50': resnest50,
    'regnety_800mf': regnety_800mf,
    'regnety_1600mf': regnety_1600mf,
    'regnety_3200mf': regnety_3200mf,
    'resnet152': resnet152,
}


def build_backbone(name, *args, **kwargs):
    """Construct the backbone registered as ``name`` in ``factory``.

    Raises:
        KeyError: ``name`` is not a registered backbone.
    """
    if name not in factory:
        # Message previously said "Unknown datasets" (copy-paste from a dataset factory).
        raise KeyError("Unknown backbone: {}".format(name))
    return factory[name](*args, **kwargs)

# /lib/modeling/backbones/__pycache__/STNModule.cpython-37.pyc: (binary) https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/__pycache__/STNModule.cpython-37.pyc
# /lib/modeling/backbones/__pycache__/__init__.cpython-37.pyc: (binary) https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/__pycache__/__init__.cpython-37.pyc
# /lib/modeling/backbones/__pycache__/mixstyle.cpython-37.pyc: (binary) https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/__pycache__/mixstyle.cpython-37.pyc
# /lib/modeling/backbones/__pycache__/resnest.cpython-37.pyc:
https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/__pycache__/resnest.cpython-37.pyc -------------------------------------------------------------------------------- /lib/modeling/backbones/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/modeling/backbones/__pycache__/resnet_ibn_a.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/__pycache__/resnet_ibn_a.cpython-37.pyc -------------------------------------------------------------------------------- /lib/modeling/backbones/__pycache__/resnext_ibn_a.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/__pycache__/resnext_ibn_a.cpython-37.pyc -------------------------------------------------------------------------------- /lib/modeling/backbones/__pycache__/resnext_ibn_a_2_head.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/__pycache__/resnext_ibn_a_2_head.cpython-37.pyc -------------------------------------------------------------------------------- /lib/modeling/backbones/__pycache__/resnext_ibn_a_attention.cpython-37.pyc: 
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/__pycache__/resnext_ibn_a_attention.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/modeling/backbones/densenet.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from torch import Tensor
from torch.jit.annotations import List
import re

__all__ = ['DenseNet', 'densenet121']

model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


class _DenseLayer(nn.Sequential):
    """BN-ReLU-Conv1x1-BN-ReLU-Conv3x3 layer whose output is appended to its input.

    Args:
        num_input_features (int): channels of the incoming feature map.
        growth_rate (int): channels produced by this layer (``k`` in the paper).
        bn_size (int): bottleneck multiplier; the 1x1 conv outputs ``bn_size * k`` channels.
        drop_rate (float): dropout probability on the new features (0 disables dropout).
    """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        # Dense connectivity: keep the input channels and append the new ones.
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    """Stack of ``num_layers`` dense layers; layer i sees ``num_input_features + i * growth_rate`` channels."""

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate,
                                bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    """BN-ReLU-Conv1x1 channel reduction, followed by 2x2 average pooling only when ``last_stride == 2``."""

    def __init__(self, num_input_features, num_output_features, last_stride=2):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        print('last_stride: ', last_stride)
        if last_stride == 2:
            self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC backbone, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    ``forward`` returns the ReLU-activated feature map; ``classifier`` is
    built (so ImageNet checkpoints load) but is not applied in ``forward``.

    Args:
        last_stride (int) - stride of the third transition: 2 halves the spatial
          resolution there, any other value keeps it
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """

    def __init__(self, last_stride, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):

        super(DenseNet, self).__init__()

        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate,
                                drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                # The third transition is the backbone's "last stride"; it now
                # honours the constructor argument instead of always using 1.
                stride = last_stride if i == 2 else 2
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2,
                                    last_stride=stride)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        return out

    def load_param(self, model_path):
        """Load a torchvision-style DenseNet checkpoint.

        Old checkpoints name layers ``...denselayerN.norm.1.weight``; they are
        renamed to this module's ``...denselayerN.norm1.weight`` before a
        strict ``load_state_dict``.
        """
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies onto the module's own device.
        state_dict = torch.load(model_path, map_location='cpu')

        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        self.load_state_dict(state_dict)


def densenet121(last_stride, pretrained=False, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        last_stride (int): stride of the third transition (2 down-samples, otherwise no pooling).
        pretrained (bool): accepted for API compatibility but currently ignored;
            use :meth:`DenseNet.load_param` to load weights.
        **kwargs: forwarded to :class:`DenseNet`.
    """
    return DenseNet(last_stride, 32, (6, 12, 24, 16), 64, **kwargs)


# --------------------------------------------------------------------------------
# /lib/modeling/backbones/mixstyle.py:
# --------------------------------------------------------------------------------
import random
import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """MixStyle.
    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.3, eps=1e-6):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha

        print('* MixStyle params')
        print(f'- p: {p}')
        print(f'- alpha: {alpha}')

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})'

    def forward(self, x):
        """Mix per-channel spatial mean/std of x with a random batch permutation (training only, with prob p)."""
        if not self.training:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        # Instance statistics over the spatial dims (dims 2 and 3).
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x - mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        perm = torch.randperm(B)
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        return x_normed * sig_mix + mu_mix


class MixStyle2(nn.Module):
    """MixStyle (w/ domain prior).
    The input should contain two equal-sized mini-batches from two distinct domains.
    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.3, eps=1e-6):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha

        print('* MixStyle params')
        print(f'- p: {p}')
        print(f'- alpha: {alpha}')

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})'

    def forward(self, x):
        """
        For the input x, the first half comes from one domain,
        while the second half comes from the other domain.

        Each sample's statistics are mixed with those of a random sample from
        the *other* half. A batch of fewer than 2 samples is returned unchanged.
        """
        if not self.training:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)
        if B < 2:
            # Nothing to mix with (and chunk(2) would yield a single chunk).
            return x

        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x - mu) / sig

        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        perm = torch.arange(B - 1, -1, -1)  # inverse index
        perm_b, perm_a = perm.chunk(2)
        # Shuffle each half by its own length: for odd B the halves differ in
        # size, and shuffling both with B // 2 would drop an index.
        perm_b = perm_b[torch.randperm(perm_b.size(0))]
        perm_a = perm_a[torch.randperm(perm_a.size(0))]
        perm = torch.cat([perm_b, perm_a], 0)

        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        return x_normed * sig_mix + mu_mix

# --------------------------------------------------------------------------------
# /lib/modeling/backbones/regnet/RegNetY-1.6GF_dds_8gpu.yaml:
# --------------------------------------------------------------------------------
# 1 | MODEL:
# 2 | TYPE: regnet
# 3 | NUM_CLASSES: 1000
# 4 | REGNET:
# 5 | SE_ON: True
# 6 | DEPTH: 27
# 7 | W0: 48
# 8 | WA: 20.71
# 9 | WM: 2.65
# 10 | GROUP_W: 24
# 11 | OPTIM:
# 12 | LR_POLICY: cos
# 13 | BASE_LR: 0.8
# 14 | MAX_EPOCH: 100
# 15 | MOMENTUM: 0.9
# 16 | WEIGHT_DECAY: 5e-5
# 17 | WARMUP_EPOCHS: 5
# 18 | TRAIN:
# 19 | DATASET: imagenet
# 20 | IM_SIZE: 224
# 21 | BATCH_SIZE: 1024
# 22 | TEST:
# 23 | DATASET: imagenet
# 24 | IM_SIZE: 256
# 25 | BATCH_SIZE: 800
# 26 | NUM_GPUS: 8
# 27 |
OUT_DIR: . -------------------------------------------------------------------------------- /lib/modeling/backbones/regnet/RegNetY-3.2GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | 2 | MODEL: 3 | TYPE: regnet 4 | NUM_CLASSES: 1000 5 | REGNET: 6 | SE_ON: True 7 | DEPTH: 21 8 | W0: 80 9 | WA: 42.63 10 | WM: 2.66 11 | GROUP_W: 24 12 | OPTIM: 13 | LR_POLICY: cos 14 | BASE_LR: 0.4 15 | MAX_EPOCH: 100 16 | MOMENTUM: 0.9 17 | WEIGHT_DECAY: 5e-5 18 | WARMUP_EPOCHS: 5 19 | TRAIN: 20 | DATASET: imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 512 23 | TEST: 24 | DATASET: imagenet 25 | IM_SIZE: 256 26 | BATCH_SIZE: 400 27 | NUM_GPUS: 8 28 | OUT_DIR: . -------------------------------------------------------------------------------- /lib/modeling/backbones/regnet/RegNetY-800MF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 14 7 | W0: 56 8 | WA: 38.84 9 | WM: 2.4 10 | GROUP_W: 16 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.8 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_EPOCHS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 1024 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 800 26 | NUM_GPUS: 8 27 | OUT_DIR: . 
# --------------------------------------------------------------------------------
# /lib/modeling/backbones/regnet/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/regnet/__init__.py
# --------------------------------------------------------------------------------
# /lib/modeling/backbones/regnet/__pycache__/__init__.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/regnet/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/modeling/backbones/regnet/__pycache__/config.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/regnet/__pycache__/config.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/modeling/backbones/regnet/__pycache__/regnet.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/modeling/backbones/regnet/__pycache__/regnet.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/modeling/backbones/resnet.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import math

import torch
from torch import nn


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two 3x3 conv-BN layers with an identity (or ``downsample``) shortcut."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 (strided) -> 1x1 bottleneck; output has ``planes * 4`` channels."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet feature extractor (no pooling/fc head); ``last_stride`` sets layer4's stride."""

    # Tuple default: a shared mutable list default is an avoidable hazard.
    def __init__(self, last_stride=2, block=Bottleneck, layers=(3, 4, 6, 3)):
        self.inplanes = 64
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=last_stride)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def load_param(self, model_path):
        """Copy every checkpoint tensor except ``fc`` ones into this model in place."""
        # map_location='cpu' allows GPU-saved checkpoints on CPU-only hosts.
        param_dict = torch.load(model_path, map_location='cpu')
        for i in param_dict:
            if 'fc' in i:
                continue
            self.state_dict()[i].copy_(param_dict[i])

    def random_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def resnet50(last_stride):
    return ResNet(last_stride=last_stride, block=Bottleneck, layers=[3, 4, 6, 3])


def resnet152(last_stride):
    return ResNet(last_stride=last_stride, block=Bottleneck, layers=[3, 8, 36, 3])

# --------------------------------------------------------------------------------
# /lib/modeling/backbones/resnet_ibn_b.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet50_ibn_b', 'resnet101_ibn_b',
           'resnet152_ibn_b']

model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two 3x3 conv-BN layers with a shortcut (no IN support; see Bottleneck)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck block; with ``IN=True`` an affine InstanceNorm is applied after the residual add (IBN-b)."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, IN=False):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.IN = None
        if IN:
            # Same width as the block output (previously spelled planes * 4).
            self.IN = nn.InstanceNorm2d(planes * self.expansion, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        if self.IN is not None:
            out = self.IN(out)
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """IBN-b ResNet backbone: IN replaces the stem BN and ends stages 1-2.

    ``forward`` returns the layer4 feature map; ``avgpool``/``fc`` are built
    but unused (``load_param`` skips ``fc`` weights).
    """

    def __init__(self, last_stride, block, layers, num_classes=1000):
        scale = 64
        self.inplanes = scale
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.InstanceNorm2d(scale, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, scale, layers[0], stride=1, IN=True)
        self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2, IN=True)
        self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=last_stride)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(scale * 8 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, IN=False):
        """Build a stage of exactly ``blocks`` blocks; only the last one receives the ``IN`` flag."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        if blocks == 1:
            # The single block is also the last block, so it carries IN.
            # (Previously an extra, second block was appended here.)
            layers = [block(self.inplanes, planes, stride, downsample, IN=IN)]
            self.inplanes = planes * block.expansion
            return nn.Sequential(*layers)

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks - 1):
            layers.append(block(self.inplanes, planes))
        layers.append(block(self.inplanes, planes, IN=IN))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def load_param(self, model_path):
        """Copy every checkpoint tensor except ``fc`` ones into this model in place."""
        param_dict = torch.load(model_path, map_location='cpu')
        for i in param_dict:
            if 'fc' in i:
                continue
            self.state_dict()[i].copy_(param_dict[i])


def resnet50_ibn_b(last_stride, pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        last_stride (int): stride of layer4.
        pretrained (bool): If True, loads the torchvision ImageNet checkpoint
            (NOTE(review): strict loading of a plain-ResNet checkpoint into the
            IBN-b layout may fail on IN keys — confirm).
    """
    model = ResNet(last_stride, Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101_ibn_b(last_stride, pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        last_stride (int): stride of layer4.
        pretrained (bool): If True, loads the torchvision ImageNet checkpoint
            (same strict-loading caveat as resnet50_ibn_b).
    """
    model = ResNet(last_stride, Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152_ibn_b(last_stride, pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Takes ``last_stride`` first like the other factories; previously it was
    missing, so ``Bottleneck`` was passed as the stride and the call always failed.

    Args:
        last_stride (int): stride of layer4.
        pretrained (bool): If True, loads the torchvision ImageNet checkpoint
            (same strict-loading caveat as resnet50_ibn_b).
    """
    model = ResNet(last_stride, Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model

# --------------------------------------------------------------------------------
# /lib/solver/__init__.py:
# --------------------------------------------------------------------------------
# 1 | # encoding: utf-8
# 2 | """
# 3 | @author: sherlock
# 4 | @contact: sherlockliao01@gmail.com
# 5 | """
# 6 |
# 7 | from .build import make_optimizer
# 8 | from .lr_scheduler import build_lr_scheduler, WarmupMultiStepLR
# --------------------------------------------------------------------------------
# /lib/solver/__pycache__/__init__.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/solver/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/solver/__pycache__/build.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/solver/__pycache__/build.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/solver/__pycache__/lr_scheduler.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/solver/__pycache__/lr_scheduler.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/solver/__pycache__/ranger.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/solver/__pycache__/ranger.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/solver/__pycache__/swa.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/solver/__pycache__/swa.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/solver/build.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn as nn

from .ranger import Ranger
from .swa import SWA


def make_optimizer(cfg, model):
    """Build the optimizer named by ``cfg.SOLVER.OPTIMIZER_NAME`` with per-parameter groups.

    Every trainable parameter gets its own group:
      * bias params use ``BASE_LR * BIAS_LR_FACTOR`` and ``WEIGHT_DECAY_BIAS``;
      * params whose name contains ``'classifier'`` have their lr scaled by ``FC_LR_FACTOR``.

    ``'SGD'`` and ``'SWA'`` (SGD wrapped in SWA) use ``cfg.SOLVER.MOMENTUM``;
    ``'Ranger'`` uses the project Ranger optimizer; any other name is looked
    up on ``torch.optim``.
    """
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        if 'classifier' in key:  # different learning rate for initialized fc layers
            lr = cfg.SOLVER.FC_LR_FACTOR * lr
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.OPTIMIZER_NAME == 'Ranger':
        optimizer = Ranger(params)
    elif cfg.SOLVER.OPTIMIZER_NAME == 'SWA':
        print('using SWA')
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
        optimizer = SWA(optimizer, swa_start=0, swa_freq=1)
    else:
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
    return optimizer

# --------------------------------------------------------------------------------
# /lib/solver/lr_scheduler.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from bisect import bisect_right
import torch
import math


# FIXME ideally this would be achieved with a CombinedLRScheduler,
# separating MultiStepLR with WarmupLR
# but the current LRScheduler design doesn't allow it

class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Multi-step decay (``gamma`` at each milestone) with an initial warmup.

    During the first ``warmup_iters`` steps the lr is scaled by
    ``warmup_factor`` ("constant") or by a factor rising linearly from
    ``warmup_factor`` to 1 ("linear").

    Raises:
        ValueError: if ``milestones`` is not sorted ascending, or
            ``warmup_method`` is not "constant"/"linear".
    """

    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=10,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if list(milestones) != sorted(milestones):
            # The value is formatted into the message (it used to be passed
            # as a second exception argument, leaving "{}" in the text).
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = self.last_epoch / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr
            * warmup_factor
            * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]


class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine annealing to ``eta_min``.

    From "Bag of Tricks for Image Classification with Convolutional Neural
    Networks". The cosine phase spans epochs ``warmup_epochs .. max_epochs - 1``.
    """

    def __init__(
        self,
        optimizer,
        max_epochs,
        warmup_epochs=10,
        eta_min=1e-7,
        last_epoch=-1,
    ):
        self.max_epochs = max_epochs - 1
        self.eta_min = eta_min
        self.warmup_epochs = warmup_epochs
        super(WarmupCosineLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # 1e-32 guards against division by zero when warmup_epochs == 0.
            lr = [base_lr * (self.last_epoch + 1) / (self.warmup_epochs + 1e-32) for base_lr in self.base_lrs]
        else:
            lr = [self.eta_min + (base_lr - self.eta_min) *
                  (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) / 2
                  for base_lr in self.base_lrs]
        return lr


class CosineStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Cosine annealing, then a geometric step-decay tail.

    For epochs ``< max_epochs - step_epochs`` the lr follows a cosine from the
    base lr towards ``eta_min``. Afterwards the lr of the last cosine epoch is
    multiplied by ``gamma`` once per epoch in the tail.
    """

    def __init__(
        self,
        optimizer,
        max_epochs,
        step_epochs=2,
        gamma=0.3,
        eta_min=0,
        last_epoch=-1,
    ):
        self.max_epochs = max_epochs
        self.eta_min = eta_min
        self.step_epochs = step_epochs
        self.gamma = gamma
        # lrs of the most recent cosine-phase epoch; None until first computed
        # (was 0, which crashed when resuming directly into the step phase).
        self.last_cosine_lr = None
        super(CosineStepLR, self).__init__(optimizer, last_epoch)

    def _cosine_lr(self, epoch):
        """Cosine-phase lrs for ``epoch``."""
        cosine_epochs = self.max_epochs - self.step_epochs
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * epoch / cosine_epochs)) / 2
                for base_lr in self.base_lrs]

    def get_lr(self):
        cosine_epochs = self.max_epochs - self.step_epochs
        if self.last_epoch < cosine_epochs:
            lr = self._cosine_lr(self.last_epoch)
            self.last_cosine_lr = lr
        else:
            if self.last_cosine_lr is None:
                # Resumed past the cosine phase: rebuild the lrs that the
                # final cosine epoch would have produced in an uninterrupted run.
                self.last_cosine_lr = self._cosine_lr(math.ceil(cosine_epochs) - 1)
            lr = [self.gamma ** (self.step_epochs - self.max_epochs + self.last_epoch + 1) * base_lr
                  for base_lr in self.last_cosine_lr]

        return lr


class CyclicCosineLR(torch.optim.lr_scheduler._LRScheduler):
    """Cosine annealing restarted every ``cycle_epoch`` epochs, each cycle scaled by ``cycle_decay``."""

    def __init__(self,
                 optimizer,
                 cycle_epoch,
                 cycle_decay=0.7,
                 last_epoch=-1):
        self.cycle_decay = cycle_decay
        self.cycle_epoch = cycle_epoch
        self.cur_count = 0
        super(CyclicCosineLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # NOTE(review): (last_epoch + 1) // cycle_epoch switches the decay one
        # epoch before the cosine phase (last_epoch % cycle_epoch) restarts —
        # confirm whether that offset is intended.
        self.cur_count = (self.last_epoch + 1) // self.cycle_epoch
        decay = self.cycle_decay ** self.cur_count
        return [base_lr * decay *
                (1 + math.cos(math.pi * (self.last_epoch % self.cycle_epoch) / self.cycle_epoch)) / 2
                for base_lr in self.base_lrs]


def build_lr_scheduler(optimizer, lr_scheduler, cfg, last_epoch):
    """Create the lr scheduler named by ``lr_scheduler``; unknown names fall back to MultiStepLR."""
    if lr_scheduler == 'warmup_multi_step':
        scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
                                      cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, last_epoch=last_epoch)
    elif lr_scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(cfg.SOLVER.MAX_EPOCHS), last_epoch=last_epoch)
    elif lr_scheduler == 'warmup_cosine':
        scheduler = WarmupCosineLR(optimizer, max_epochs=float(cfg.SOLVER.MAX_EPOCHS),
                                   warmup_epochs=cfg.SOLVER.WARMUP_ITERS, last_epoch=last_epoch)
    elif lr_scheduler == 'cyclic_cosine':
        # NOTE(review): unlike the other branches, last_epoch is not forwarded,
        # so a resumed run restarts this schedule — confirm intent.
        scheduler = CyclicCosineLR(optimizer, cfg.SOLVER.CYCLE_EPOCH)
    elif lr_scheduler == 'cosine_step':
        scheduler = CosineStepLR(optimizer, max_epochs=float(cfg.SOLVER.MAX_EPOCHS), last_epoch=last_epoch)
    else:  # multi-steps as default
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=cfg.SOLVER.STEPS, gamma=cfg.SOLVER.GAMMA, last_epoch=last_epoch)

    return scheduler

# --------------------------------------------------------------------------------
# /lib/utils/__init__.py:
# --------------------------------------------------------------------------------
# 1 | # encoding: utf-8
# 2 | """
# 3 | @author: sherlock
# 4 | @contact: sherlockliao01@gmail.com
# 5 | """
# 6 |
# 7 |
# --------------------------------------------------------------------------------
# /lib/utils/__pycache__/__init__.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/utils/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/utils/__pycache__/iotools.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/utils/__pycache__/iotools.cpython-37.pyc
# --------------------------------------------------------------------------------
# /lib/utils/__pycache__/logger.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/utils/__pycache__/logger.cpython-37.pyc
-------------------------------------------------------------------------------- /lib/utils/__pycache__/post_process.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/utils/__pycache__/post_process.cpython-37.pyc -------------------------------------------------------------------------------- /lib/utils/__pycache__/reid_eval.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybercore-co-ltd/track2_aicity_2021/d4262b43b4cea4b43e6fe880be6af8eac82dd6e9/lib/utils/__pycache__/reid_eval.cpython-37.pyc -------------------------------------------------------------------------------- /lib/utils/actmap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | ''' 5 | input: featmap: torch.tensor.float(), N*C*H*W 6 | output: foreground_mask, background_mask: torch.tensor.bool(), N*1*H*W 7 | ''' 8 | def batch_attention_mask(featmap, threshold=0.75): 9 | actmap = (featmap**2).sum(dim=1) # N*W*H 10 | val = actmap.view(actmap.size(0), -1) 11 | min_val, _ = val.min(dim=1) 12 | max_val, _ = val.max(dim=1) 13 | thr = min_val + (max_val - min_val) * threshold 14 | for i in range(actmap.size(0)): 15 | actmap[i] = actmap[i] < thr[i] 16 | return actmap.unsqueeze(dim=1) 17 | 18 | 19 | def generate_attention_mask(actmap, threshold=1.0): 20 | actmap = (actmap - np.min(actmap)) / ( 21 | np.max(actmap) - np.min(actmap) + 1e-12 22 | ) 23 | foreground_mask = actmap >= (actmap.mean() * threshold) 24 | background_mask = actmap < (actmap.mean() * threshold) 25 | return foreground_mask, background_mask 26 | -------------------------------------------------------------------------------- /lib/utils/bbox_utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.ndimage import label 4 | import cv2 5 | 6 | def extract_bbox_from_mask(input): 7 | assert input.ndim == 2, 'Invalid input shape' 8 | rows = np.any(input, axis=1) 9 | cols = np.any(input, axis=0) 10 | ymin, ymax = np.where(rows)[0][[0, -1]] 11 | xmin, xmax = np.where(cols)[0][[0, -1]] 12 | return xmin, ymin, xmax, ymax 13 | 14 | 15 | def localize_from_map(actmap, threshold_ratio=0.5): 16 | foreground_map = actmap >= (actmap.mean() * threshold_ratio) 17 | # single object 18 | try: 19 | bbox = extract_bbox_from_mask(foreground_map) 20 | except: 21 | bbox = None 22 | return bbox 23 | 24 | 25 | def bbox_nms(bbox_list, threshold=0.5): 26 | bbox_list = sorted(bbox_list, key=lambda x: x[-1], reverse=True) 27 | selected_bboxes = [] 28 | while len(bbox_list) > 0: 29 | obj = bbox_list.pop(0) 30 | selected_bboxes.append(obj) 31 | def iou_filter(x): 32 | iou = compute_iou(obj[1:5], x[1:5]) 33 | if (x[0] == obj[0] and iou >= threshold): 34 | return None 35 | else: 36 | return x 37 | bbox_list = list(filter(iou_filter, bbox_list)) 38 | return selected_bboxes 39 | 40 | 41 | def compute_iou(box_a, box_b): 42 | x_a = max(box_a[0], box_b[0]) 43 | y_a = max(box_a[1], box_b[1]) 44 | x_b = min(box_a[2], box_b[2]) 45 | y_b = min(box_a[3], box_b[3]) 46 | inter_area = max(x_b - x_a + 1, 0) * max(y_b - y_a + 1, 0) 47 | box_a_area = (box_a[2] - box_a[0] + 1) * (box_a[3] - box_a[1] + 1) 48 | box_b_area = (box_b[2] - box_b[0] + 1) * (box_b[3] - box_b[1] + 1) 49 | return inter_area / float(box_a_area + box_b_area - inter_area) 50 | 51 | 52 | def draw_bbox(canvas, bboxes, color=[255, 0, 255]): 53 | for bbox in bboxes: 54 | if bbox is not None: 55 | cv2.rectangle(canvas, (int(bbox[0]), int(bbox[1])), ((int(bbox[2]), int(bbox[3]))), color=color) 56 | return canvas 57 | -------------------------------------------------------------------------------- 
/lib/utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import errno 8 | import json 9 | import os 10 | 11 | import os.path as osp 12 | 13 | 14 | def mkdir_if_missing(directory): 15 | if not osp.exists(directory): 16 | try: 17 | os.makedirs(directory) 18 | except OSError as e: 19 | if e.errno != errno.EEXIST: 20 | raise 21 | 22 | 23 | def check_isfile(path): 24 | isfile = osp.isfile(path) 25 | if not isfile: 26 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 27 | return isfile 28 | 29 | 30 | def read_json(fpath): 31 | with open(fpath, 'r') as f: 32 | obj = json.load(f) 33 | return obj 34 | 35 | 36 | def write_json(obj, fpath): 37 | mkdir_if_missing(osp.dirname(fpath)) 38 | with open(fpath, 'w') as f: 39 | json.dump(obj, f, indent=4, separators=(',', ': ')) 40 | -------------------------------------------------------------------------------- /lib/utils/logger.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | 12 | def setup_logger(name, save_dir, distributed_rank): 13 | logger = logging.getLogger(name) 14 | logger.setLevel(logging.DEBUG) 15 | # don't log results for the non-master process 16 | if distributed_rank > 0: 17 | return logger 18 | ch = logging.StreamHandler(stream=sys.stdout) 19 | ch.setLevel(logging.DEBUG) 20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 21 | ch.setFormatter(formatter) 22 | logger.addHandler(ch) 23 | 24 | if save_dir: 25 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w') 26 | fh.setLevel(logging.DEBUG) 27 | fh.setFormatter(formatter) 28 | logger.addHandler(fh) 29 | 30 | return logger 31 | 
-------------------------------------------------------------------------------- /lib/utils/reid_eval.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | import json 10 | import os.path as osp 11 | import os 12 | import time 13 | from .post_process import * 14 | 15 | def eval_func(indices, q_pids, g_pids, q_camids, g_camids, max_rank=50): 16 | """Evaluation with market1501 metric 17 | Key: for each query identity, its gallery images from the same camera view are discarded. 18 | """ 19 | num_q, num_g = indices.shape 20 | if num_g < max_rank: 21 | max_rank = num_g 22 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 23 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 24 | 25 | # compute cmc curve for each query 26 | all_cmc = [] 27 | all_AP = [] 28 | total_camids = len(set(g_camids)) 29 | num_valid_q = 0. # number of valid query 30 | for q_idx in range(num_q): 31 | # get query pid and camid 32 | q_pid = q_pids[q_idx] 33 | q_camid = q_camids[q_idx] 34 | 35 | # remove gallery samples that have the same pid and camid with query 36 | order = indices[q_idx] 37 | if total_camids == 1:# for NAIC 38 | remove = (g_pids[order] == q_pid) & (g_camids[order] != q_camid) 39 | else: # for others like market1501 40 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 41 | keep = np.invert(remove) 42 | 43 | # compute cmc curve 44 | # binary vector, positions with value 1 are correct matches 45 | orig_cmc = matches[q_idx][keep] 46 | if not np.any(orig_cmc): 47 | # this condition is true when query identity does not appear in gallery 48 | continue 49 | 50 | cmc = orig_cmc.cumsum() 51 | cmc[cmc > 1] = 1 52 | 53 | all_cmc.append(cmc[:max_rank]) 54 | num_valid_q += 1. 
55 | 56 | # compute average precision 57 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 58 | num_rel = orig_cmc.sum() 59 | tmp_cmc = orig_cmc.cumsum() 60 | tmp_cmc = np.array(tmp_cmc) / (np.arange(len(tmp_cmc)) + 1.) 61 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 62 | AP = tmp_cmc.sum() / num_rel 63 | all_AP.append(AP) 64 | 65 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 66 | 67 | all_cmc = np.asarray(all_cmc).astype(np.float32) 68 | all_cmc = all_cmc.sum(0) / num_valid_q 69 | mAP = np.mean(all_AP) 70 | 71 | return all_cmc, mAP 72 | 73 | 74 | class evaluator(object): 75 | def __init__(self, num_query, dataset, cfg, max_rank=50): 76 | super(evaluator, self).__init__() 77 | self.num_query = num_query 78 | self.max_rank = max_rank 79 | self.feat_norm = cfg.TEST.FEAT_NORM 80 | self.query_expansion = cfg.TEST.QUERY_EXPANSION 81 | self.query_expansion_topk = 6 82 | self.do_DBA = cfg.TEST.DO_DBA 83 | self.dataset = dataset 84 | self.do_rerank = cfg.TEST.DO_RERANK 85 | self.rerank_param = cfg.TEST.RERANK_PARAM 86 | self.cfg = cfg 87 | 88 | self.feats = [] 89 | self.pids = [] 90 | self.camids = [] 91 | self.img_paths = [] 92 | 93 | def update(self, output): 94 | feat, pid, camid, img_path = output 95 | self.feats.append(feat) 96 | self.pids.extend(np.asarray(pid)) 97 | self.camids.extend(np.asarray(camid)) 98 | self.img_paths.extend(img_path) 99 | 100 | def compute(self): 101 | feats = torch.cat(self.feats, dim=0) 102 | if self.feat_norm == 'yes': 103 | print("The test feature is normalized") 104 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 105 | 106 | if self.do_DBA: 107 | feats = database_aug(feats, top_k=6) 108 | # query 109 | qf = feats[:self.num_query] 110 | q_pids = np.asarray(self.pids[:self.num_query]) 111 | q_camids = np.asarray(self.camids[:self.num_query]) 112 | # gallery 113 | gf = feats[self.num_query:] 114 | g_pids = 
np.asarray(self.pids[self.num_query:]) 115 | g_camids = np.asarray(self.camids[self.num_query:]) 116 | g_names = [os.path.basename(img_path) for img_path in self.img_paths[self.num_query:]] 117 | 118 | if self.query_expansion: 119 | qf = average_query_expansion(qf, feats, top_k=6) 120 | 121 | if self.cfg.TEST.TRACK_AUG: 122 | gf = track_aug(gf, self.dataset.test_tracks, self.img_paths[self.num_query:]) 123 | 124 | #qf, gf = pca_whiten(qf, gf) 125 | if self.cfg.TEST.USE_VOC: 126 | print('using VOC-ReID') 127 | cam_dist = np.load(self.cfg.TEST.CAM_DIST_PATH) 128 | ori_dist = np.load(self.cfg.TEST.ORI_DIST_PATH) 129 | else: 130 | cam_dist = None 131 | ori_dist = None 132 | 133 | if self.do_rerank: 134 | distmat_np = re_ranking(qf, gf, 135 | k1=self.rerank_param[0], 136 | k2=self.rerank_param[1], 137 | lambda_value=self.rerank_param[2], USE_VOC=self.cfg.TEST.USE_VOC, cam_dist=cam_dist, ori_dist=ori_dist) 138 | else: 139 | distmat, indices = comput_distmat(qf, gf) 140 | distmat_np = distmat.cpu().numpy() 141 | 142 | # track_idxs = generate_track_idxs(g_names, self.dataset.test_tracks) 143 | # distmat_track_np = generate_track_distmat(distmat_np, track_idxs) 144 | # np.save(os.path.dirname(self.cfg.TEST.WEIGHT) + '/distmat_track', 145 | # distmat_track_np) 146 | 147 | # cam_distmat = np.load('./output/aicity20/experiments/ReCamID/distmat.npy') 148 | # ori_distmat = np.load('./output/aicity20/experiments/ReOriID/distmat.npy') 149 | #cam_distmat = np.load('./output/aicity20/0410-test/ReCamID/distmat.npy') 150 | #ori_distmat = np.load('./output/aicity20/0410-test/ReOriID/distmat.npy') 151 | 152 | # cam_distmat = np.load('./output/veri/0411-search/ReCamID/distmat.npy') 153 | #ori_distmat = np.load('./output/veri/ReOriID/distmat.npy') 154 | #distmat_np = distmat_np - 0.1 * ori_distmat# - 0.1 * cam_distmat 155 | 156 | indices_np = np.argsort(distmat_np, axis=1) 157 | if self.cfg.TEST.TRACK_RERANK and len(self.dataset.test_tracks) > 0: 158 | rerank_indice_by_track(indices_np, 
self.img_paths[self.num_query:], self.dataset.test_tracks) 159 | 160 | cmc, mAP = eval_func(indices_np, q_pids, g_pids, q_camids, g_camids) 161 | if self.cfg.TEST.WRITE_RESULT: 162 | np.save(os.path.dirname(self.cfg.TEST.WEIGHT) + '/distmat', 163 | distmat_np) 164 | np.save(os.path.dirname(self.cfg.TEST.WEIGHT) + '/feats', feats.cpu().numpy()) 165 | write_results(indices_np, os.path.dirname(self.cfg.TEST.WEIGHT), 166 | self.img_paths[self.num_query:]) 167 | 168 | return cmc, mAP, indices_np -------------------------------------------------------------------------------- /lib/utils/vis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 5 | def concat_vis(imgs, size = [256, 128]): 6 | n = len(imgs) 7 | canvas = np.zeros((size[0], n * size[1], 3)) #(h*w*c) 8 | for i, img in enumerate(imgs): 9 | img = cv2.resize(img, (size[1], size[0])) # (w*h) 10 | canvas[:, i*size[1]:(i+1)*size[1], :] = img 11 | return canvas -------------------------------------------------------------------------------- /scripts/ReCamID.sh: -------------------------------------------------------------------------------- 1 | python tools/train.py --config_file='configs/aicity20.yml' \ 2 | MODEL.DEVICE_ID "('1')" \ 3 | MODEL.MODEL_TYPE "baseline" \ 4 | MODEL.NAME "('resnext101_ibn_a')" \ 5 | MODEL.PRETRAIN_PATH "('/home/cybercore/su/AICity2021-VOC-ReID/resnext101_ibn_a.pth.tar')" \ 6 | SOLVER.LR_SCHEDULER 'cosine_step' \ 7 | DATALOADER.NUM_INSTANCE 8 \ 8 | MODEL.ID_LOSS_TYPE 'circle' \ 9 | SOLVER.WARMUP_ITERS 0 \ 10 | SOLVER.MAX_EPOCHS 12 \ 11 | SOLVER.COSINE_MARGIN 0.35 \ 12 | SOLVER.COSINE_SCALE 64 \ 13 | SOLVER.FREEZE_BASE_EPOCHS 2 \ 14 | SOLVER.IMS_PER_BATCH 32 \ 15 | MODEL.TRIPLET_LOSS_WEIGHT 1.0 \ 16 | INPUT.SIZE_TRAIN '([320, 320])' \ 17 | INPUT.SIZE_TEST '([320, 320])' \ 18 | DATASETS.TRAIN "('aicity20-ReCam',)" \ 19 | DATASETS.TEST "('veri',)" \ 20 | DATASETS.ROOT_DIR "('/media/data/ai-city/Track2')" \ 
21 | OUTPUT_DIR "('./output/aicity20/0409-ensemble/ReCamID')" -------------------------------------------------------------------------------- /scripts/ReOriID.sh: -------------------------------------------------------------------------------- 1 | python tools/train.py --config_file='configs/aicity20.yml' \ 2 | MODEL.DEVICE_ID "('1')" \ 3 | MODEL.MODEL_TYPE "baseline" \ 4 | MODEL.NAME "('resnext101_ibn_a')" \ 5 | MODEL.PRETRAIN_PATH "('/home/cybercore/su/AICity2021-VOC-ReID/resnext101_ibn_a.pth.tar')" \ 6 | SOLVER.LR_SCHEDULER 'cosine_step' \ 7 | DATALOADER.NUM_INSTANCE 16 \ 8 | MODEL.ID_LOSS_TYPE 'softmax' \ 9 | MODEL.METRIC_LOSS_TYPE 'none' \ 10 | DATALOADER.SAMPLER 'softmax' \ 11 | INPUT.PROB 0.0 \ 12 | INPUT.SIZE_TRAIN '([320, 320])' \ 13 | INPUT.SIZE_TEST '([320, 320])' \ 14 | MODEL.IF_LABELSMOOTH 'on' \ 15 | SOLVER.WARMUP_ITERS 0 \ 16 | SOLVER.MAX_EPOCHS 12 \ 17 | SOLVER.FREEZE_BASE_EPOCHS 2 \ 18 | SOLVER.IMS_PER_BATCH 32 \ 19 | DATASETS.TRAIN "('aicity20-ReOri',)" \ 20 | DATASETS.TEST "('aicity20-ReOri',)" \ 21 | DATASETS.ROOT_DIR "('/media/data/ai-city/Track2/AIC21_Track2_ReID_Simulation/')" \ 22 | OUTPUT_DIR "('./output/aicity21_ori')" -------------------------------------------------------------------------------- /scripts/submit.sh: -------------------------------------------------------------------------------- 1 | # Submit ReID 2 | python tools/aicity20/submit.py --config_file='configs/aicity20.yml' \ 3 | MODEL.DEVICE_ID "('1')" \ 4 | MODEL.NAME "('resnext101_ibn_a')" \ 5 | MODEL.MODEL_TYPE "baseline" \ 6 | DATASETS.TRAIN "('aicity20',)" \ 7 | DATASETS.TEST "('aicity20',)" \ 8 | DATASETS.ROOT_DIR "('/media/data/ai-city/Track2')" \ 9 | MODEL.PRETRAIN_CHOICE "('self')" \ 10 | INPUT.SIZE_TRAIN '([320, 320])' \ 11 | INPUT.SIZE_TEST '([320, 320])' \ 12 | TEST.DO_RERANK True \ 13 | TEST.RERANK_PARAM "([50, 15, 0.5])" \ 14 | TEST.FLIP_TEST True \ 15 | TEST.TRACK_RERANK False \ 16 | TEST.WRITE_RESULT True \ 17 | TEST.USE_VOC True \ 18 | TEST.WEIGHT 
"('./output/aicity20/0409-ensemble/A-Nam-SynData-next101-320-circle/final.pth')" -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | python tools/aicity20/submit.py --config_file='configs/aicity20.yml' \ 2 | MODEL.DEVICE_ID "('1')" \ 3 | MODEL.NAME "('resnext101_ibn_a')" \ 4 | MODEL.MODEL_TYPE "baseline" \ 5 | DATASETS.TRAIN "('aicity20',)" \ 6 | DATASETS.TEST "('aicity20',)" \ 7 | DATASETS.ROOT_DIR "('/home/zxy/data/ReID/vehicle')" \ 8 | MODEL.PRETRAIN_CHOICE "('self')" \ 9 | INPUT.SIZE_TRAIN '([320, 320])' \ 10 | INPUT.SIZE_TEST '([320, 320])' \ 11 | TEST.DO_RERANK True \ 12 | TEST.RERANK_PARAM "([50, 15, 0.5])" \ 13 | TEST.FLIP_TEST True \ 14 | TEST.WRITE_RESULT True \ 15 | TEST.USE_VOC True \ 16 | TEST.CAM_DIST_PATH './output/aicity20/0409-ensemble/ReCamID/feat_distmat.npy' \ 17 | TEST.ORI_DIST_PATH './output/aicity20/0409-ensemble/ReOriID/feat_distmat.npy' \ 18 | TEST.WEIGHT "('./output/aicity20/0409-ensemble/next101-320-circle/best.pth')" -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | python -W ignore tools/train.py --config_file='configs/aicity20.yml' 2 | 3 | python -W ignore tools/train.py --config_file='configs/aicity20.yml' 4 | 5 | python -W ignore tools/train.py --config_file='configs/aicity20.yml' 6 | 7 | python -W ignore tools/train.py --config_file='configs/aicity20.yml' 8 | 9 | python -W ignore tools/train.py --config_file='configs/aicity20.yml' 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | from setuptools import find_packages, setup 4 | 5 | import torch 6 | from torch.utils.cpp_extension import (BuildExtension, 
CppExtension, 7 | CUDAExtension) 8 | 9 | 10 | def readme(): 11 | with open('README.md', encoding='utf-8') as f: 12 | content = f.read() 13 | return content 14 | 15 | 16 | def make_cuda_ext(name, module, sources, sources_cuda=[]): 17 | 18 | define_macros = [] 19 | extra_compile_args = {'cxx': []} 20 | 21 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 22 | define_macros += [('WITH_CUDA', None)] 23 | extension = CUDAExtension 24 | extra_compile_args['nvcc'] = [ 25 | '-D__CUDA_NO_HALF_OPERATORS__', 26 | '-D__CUDA_NO_HALF_CONVERSIONS__', 27 | '-D__CUDA_NO_HALF2_OPERATORS__', 28 | ] 29 | sources += sources_cuda 30 | else: 31 | print(f'Compiling {name} without CUDA') 32 | extension = CppExtension 33 | 34 | return extension( 35 | name=f'{module}.{name}', 36 | sources=[os.path.join(*module.split('.'), p) for p in sources], 37 | define_macros=define_macros, 38 | extra_compile_args=extra_compile_args) 39 | 40 | 41 | def parse_requirements(fname='requirements.txt', with_version=True): 42 | """Parse the package dependencies listed in a requirements file but strips 43 | specific versioning information. 
44 | 45 | Args: 46 | fname (str): path to requirements file 47 | with_version (bool, default=False): if True include version specs 48 | 49 | Returns: 50 | List[str]: list of requirements items 51 | 52 | CommandLine: 53 | python -c "import setup; print(setup.parse_requirements())" 54 | """ 55 | import sys 56 | from os.path import exists 57 | import re 58 | require_fpath = fname 59 | 60 | def parse_line(line): 61 | """Parse information from a line in a requirements text file.""" 62 | if line.startswith('-r '): 63 | # Allow specifying requirements in other files 64 | target = line.split(' ')[1] 65 | for info in parse_require_file(target): 66 | yield info 67 | else: 68 | info = {'line': line} 69 | if line.startswith('-e '): 70 | info['package'] = line.split('#egg=')[1] 71 | elif '@git+' in line: 72 | info['package'] = line 73 | else: 74 | # Remove versioning from the package 75 | pat = '(' + '|'.join(['>=', '==', '>']) + ')' 76 | parts = re.split(pat, line, maxsplit=1) 77 | parts = [p.strip() for p in parts] 78 | 79 | info['package'] = parts[0] 80 | if len(parts) > 1: 81 | op, rest = parts[1:] 82 | if ';' in rest: 83 | # Handle platform specific dependencies 84 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies 85 | version, platform_deps = map(str.strip, 86 | rest.split(';')) 87 | info['platform_deps'] = platform_deps 88 | else: 89 | version = rest # NOQA 90 | info['version'] = (op, version) 91 | yield info 92 | 93 | def parse_require_file(fpath): 94 | with open(fpath, 'r') as f: 95 | for line in f.readlines(): 96 | line = line.strip() 97 | if line and not line.startswith('#'): 98 | for info in parse_line(line): 99 | yield info 100 | 101 | def gen_packages_items(): 102 | if exists(require_fpath): 103 | for info in parse_require_file(require_fpath): 104 | parts = [info['package']] 105 | if with_version and 'version' in info: 106 | parts.extend(info['version']) 107 | if not sys.version.startswith('3.4'): 108 | # 
apparently package_deps are broken in 3.4 109 | platform_deps = info.get('platform_deps') 110 | if platform_deps is not None: 111 | parts.append(';' + platform_deps) 112 | item = ''.join(parts) 113 | yield item 114 | 115 | packages = list(gen_packages_items()) 116 | return packages 117 | 118 | 119 | if __name__ == '__main__': 120 | setup( 121 | name='mmdet', 122 | version=get_version(), 123 | description='OpenMMLab Detection Toolbox and Benchmark', 124 | long_description=readme(), 125 | long_description_content_type='text/markdown', 126 | author='OpenMMLab', 127 | author_email='openmmlab@gmail.com', 128 | keywords='computer vision, object detection', 129 | url='https://github.com/open-mmlab/mmdetection', 130 | packages=find_packages(exclude=('configs', 'tools', 'demo')), 131 | classifiers=[ 132 | 'Development Status :: 5 - Production/Stable', 133 | 'License :: OSI Approved :: Apache Software License', 134 | 'Operating System :: OS Independent', 135 | 'Programming Language :: Python :: 3', 136 | 'Programming Language :: Python :: 3.6', 137 | 'Programming Language :: Python :: 3.7', 138 | 'Programming Language :: Python :: 3.8', 139 | ], 140 | license='Apache License 2.0', 141 | setup_requires=parse_requirements('requirements/build.txt'), 142 | tests_require=parse_requirements('requirements/tests.txt'), 143 | install_requires=parse_requirements('requirements/runtime.txt'), 144 | extras_require={ 145 | 'all': parse_requirements('requirements.txt'), 146 | 'tests': parse_requirements('requirements/tests.txt'), 147 | 'build': parse_requirements('requirements/build.txt'), 148 | 'optional': parse_requirements('requirements/optional.txt'), 149 | }, 150 | ext_modules=[], 151 | cmdclass={'build_ext': BuildExtension}, 152 | zip_safe=False) 153 | -------------------------------------------------------------------------------- /tools/aicity20/compute_distmat_from_feats.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import numpy as np 3 | import argparse 4 | 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser(description="ReID Baseline Inference") 8 | parser.add_argument( 9 | "--src_dir", default="./output/aicity20/0410-test/r50-320-circle", help="path to config file", type=str 10 | ) 11 | args = parser.parse_args() 12 | src_dir = args.src_dir 13 | 14 | feat = np.load(src_dir + '/' + 'feats.npy') 15 | feat = torch.tensor(feat, device='cpu') 16 | all_num = len(feat) 17 | distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \ 18 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t() 19 | distmat.addmm_(1, -2, feat, feat.t()) 20 | distmat = distmat.cpu().numpy() 21 | np.save(src_dir + '/' + 'feat_distmat', distmat) 22 | -------------------------------------------------------------------------------- /tools/aicity20/eval_by_distmat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import torch 5 | 6 | sys.path.append('.') 7 | from lib.data.datasets.veri import VeRi 8 | from lib.data.datasets.aicity20_trainval import AICity20Trainval 9 | from lib.utils.post_process import build_track_lookup, re_ranking 10 | 11 | 12 | def generate_track_results(distmat, tracks, topk=100): 13 | indice = np.argsort(distmat, axis=1) 14 | results = [] 15 | m, n =distmat.shape 16 | for i in range(m): 17 | result = [] 18 | track_idxs = indice[i] 19 | for idx in track_idxs: 20 | result.extend(tracks[idx]) 21 | results.append(result[:topk]) 22 | return results 23 | 24 | 25 | def results_to_pid(results, img_to_pid): 26 | 27 | result_pids = [] 28 | for line in results: 29 | result_pid = [] 30 | for name in line: 31 | result_pid.append(img_to_pid[name]) 32 | result_pids.append(result_pid) 33 | return result_pids 34 | 35 | 36 | def eval_results(query_pids, gallery_pids, result_pids): 37 | query_pids = np.array(query_pids) 38 | gallery_pids = 
np.array(gallery_pids) 39 | result_pids = np.array(result_pids) 40 | gt_match = gallery_pids == query_pids[:, np.newaxis] 41 | 42 | all_cmc = [] 43 | all_AP = [] 44 | num_valid_q = 0 45 | for i in range(len(query_pids)): 46 | if not np.any(gt_match[i]): 47 | continue 48 | num_valid_q += 1 49 | num_rel = gt_match[i].sum() 50 | match = query_pids[i] == result_pids[i] 51 | cmc = match.cumsum() 52 | cmc[cmc > 1] = 1 53 | all_cmc.append(cmc) 54 | tmp_cmc = match.cumsum() 55 | tmp_cmc = np.array(tmp_cmc) / (np.arange(len(tmp_cmc)) + 1.) 56 | tmp_cmc = np.asarray(tmp_cmc) * match 57 | AP = tmp_cmc.sum() / num_rel 58 | all_AP.append(AP) 59 | 60 | all_cmc = np.array(all_cmc).sum(0) / num_valid_q 61 | mAP = np.mean(all_AP) 62 | print("mAP: {:.1%}".format(mAP)) 63 | for r in [1, 5, 10]: 64 | print("CMC curve, Rank-{:<3}:{:.1%}".format(r, all_cmc[r - 1])) 65 | 66 | 67 | def generate_results(distmat, gallery, topk=100): 68 | assert distmat.shape[1] == len(gallery) 69 | names = [os.path.basename(img_path) for img_path, pid, camid in gallery] 70 | indice = np.argsort(distmat, axis=1) 71 | indice = indice[:, :topk] 72 | results = [] 73 | m, n = indice.shape 74 | for i in range(m): 75 | result = [] 76 | for j in range(n): 77 | result.append(names[indice[i, j]]) 78 | results.append(result) 79 | return results 80 | 81 | 82 | def results_to_track(results, tracks, topk=100): 83 | m, n = len(results), len(results[0]) 84 | lookup_map = {} 85 | for i, track in enumerate(tracks): 86 | for img_id in track: 87 | lookup_map[img_id] = i 88 | reranked_results = [] 89 | for i in range(m): 90 | used_track_id = set() 91 | reranked_result = [] 92 | for j in range(n): 93 | track_id = lookup_map[results[i][j]] 94 | if track_id in used_track_id: 95 | continue 96 | used_track_id.add(track_id) 97 | reranked_result.extend(tracks[track_id]) 98 | reranked_results.append(reranked_result[:topk]) 99 | return reranked_results 100 | 101 | 102 | 103 | if __name__ == '__main__': 104 | dataset = 
VeRi(root='/home/zxy/data/ReID/vehicle') 105 | distmat1 = np.load('./output/veri/0411-search/circle-N16/distmat.npy') 106 | cam_distmat = np.load('./output/veri/0411-search/ReCamID/distmat.npy') 107 | ori_distmat = np.load('./output/veri/0411-search/ReOriID/distmat.npy') 108 | 109 | # type_distmat = np.load('./output/aicity20/0409-ReTypeID/feat_distmat.npy') 110 | # color_distmat = np.load('./output/aicity20/0409-ReColorID/feat_distmat.npy') 111 | 112 | # cam_distmat = np.load('./output/aicity20/experiments/ReCamID/distmat.npy') 113 | # ori_distmat = np.load('./output/aicity20/experiments/ReOriID/distmat.npy') 114 | 115 | #cam_distmat = np.load('./output/aicity20/0407-ReCamID/distmat_test.npy') 116 | #ori_distmat = np.load('./output/aicity20/0409-ReOriID/distmat.npy') 117 | # distmat5 = np.load('./output/aicity20/0407-ensemble/se-r101/distmat.npy') 118 | 119 | qf = torch.rand(len(dataset.query), 1) 120 | gf = torch.rand(len(dataset.gallery), 1) 121 | 122 | distmat = distmat1- 0.1 * ori_distmat - 0.1 * cam_distmat 123 | #distmat = distmat[:len(qf), len(qf):] 124 | #distmat = re_ranking(qf, gf, k1=50, k2=15, lambda_value=0.5, local_distmat=distmat, only_local=True) 125 | 126 | 127 | query_pids = [pid for _, pid, _ in dataset.query] 128 | gallery_pids = [] 129 | img_to_pid = {} 130 | for img_path, pid, _ in dataset.gallery: 131 | name = os.path.basename(img_path) 132 | gallery_pids.append(pid) 133 | img_to_pid[name] = pid 134 | 135 | #distmat = (distmat1 + distmat2 + distmat3 + distmat5) / 4 #- 0.1 * distmat4 136 | #distmat = distmat1 * distmat2 * distmat3 * distmat5 137 | results = generate_results(distmat, dataset.gallery, topk=-1) 138 | #results = results_to_track(results, dataset.test_tracks) 139 | result_pids = results_to_pid(results, img_to_pid) 140 | eval_results(query_pids, gallery_pids, result_pids) 141 | -------------------------------------------------------------------------------- /tools/aicity20/fix_track.py: 
# ------------------------------------------------------------------------------
# Convert the numeric image ids in ``train_track_id.txt`` (one track per line,
# ids separated by spaces) into zero-padded image file names ("123" ->
# "000123.jpg") and write them to ``train_track.txt`` next to the source file.
import os

src_path = '/home/zxy/data/ReID/vehicle/AIC20_ReID/train_track_id.txt'

results = []
with open(src_path, 'r') as f:
    for line in f:
        # split() with no argument tolerates repeated spaces and yields [] for
        # blank lines; split(' ') produced empty fields that became bogus
        # '000000.jpg' names.
        ids = line.split()
        if ids:
            results.append([ele.zfill(6) + '.jpg' for ele in ids])

with open(os.path.join(os.path.dirname(src_path), 'train_track.txt'), 'w') as f:
    for result in results:
        f.write(' '.join(result) + '\n')

# ------------------------------------------------------------------------------
# /tools/aicity20/multi_model_ensemble.py
# ------------------------------------------------------------------------------
import numpy as np
import sys
sys.path.append('.')
from lib.data.datasets.aicity20 import AICity20
from tools.aicity20.submit import write_result_with_track


if __name__ == '__main__':
    dataset = AICity20('/home/zxy/data/ReID/vehicle')
    # Distance matrices saved by separately trained models; presumably each is
    # (num_query, num_gallery) - TODO confirm they share the same ordering.
    distmat_path = ['./output/aicity20/0409-ensemble/r50-320-circle/distmat.npy',
                    './output/aicity20/0409-ensemble/next101-320-circle/distmat.npy',
                    './output/aicity20/0409-ensemble/r101-320-circle/distmat.npy',
                    ]
    # Optional re-weighting experiments (disabled):
    # cam_distmat = np.load('./output/aicity20/0407-ReCamID/distmat_submit.npy')
    # ori_distmat = np.load('./output/aicity20/0409-ensemble/ReTypeID/distmat_submit.npy')

    # Equal-weight average of all model distance matrices.
    distmat = sum(np.load(path) for path in distmat_path) / len(distmat_path)
    # distmat = distmat - 0.1 * cam_distmat - 0.1 * ori_distmat

    # Rank gallery entries per query by ascending distance.
    indices = np.argsort(distmat, axis=1)
    write_result_with_track(indices, './output/aicity20/submit/', dataset.test_tracks)

# ------------------------------------------------------------------------------
# /tools/aicity20/submit.py
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import argparse
import os
import sys
from os import mkdir

import torch
from torch.backends import cudnn

sys.path.append('.')
from lib.config import cfg
from lib.data import make_data_loader
from lib.engine.inference import inference, select_topk
from lib.modeling import build_model
from lib.utils.logger import setup_logger
from lib.data.datasets.aicity20 import AICity20


def write_result(indices, dst_dir, topk=100):
    """Write the top-``topk`` ranked gallery ids of every query to ``track2.txt``.

    Args:
        indices: 2-D array of 0-based gallery indices, one ranked row per query.
        dst_dir: output directory (created if missing).
        topk: number of ranked entries kept per query.

    Each output line holds the 1-based ids (index + 1) separated by spaces.
    """
    indices = indices[:, :topk]
    os.makedirs(dst_dir, exist_ok=True)
    m, n = indices.shape
    print('m: {} n: {}'.format(m, n))
    with open(os.path.join(dst_dir, 'track2.txt'), 'w') as f:
        for row in indices:
            f.write(' '.join(map(str, (row + 1).tolist())) + '\n')


def write_result_with_track(indices, dst_dir, tracks, topk=100):
    """Expand each ranked list to whole tracks until ``topk`` entries are filled.

    Walks each query's ranking in order. The first time an image of a given
    track appears, every image name of that track is appended. Later hits from
    an already-used track are skipped. Each list is truncated to ``topk`` and
    written as one line of ``track2.txt`` in ``dst_dir``.

    Args:
        indices: 2-D array of 0-based gallery indices, one ranked row per query.
            Index ``k`` corresponds to the image whose numeric name is ``k + 1``.
        dst_dir: output directory (created if missing).
        tracks: sequence of tracks, each a sequence of image names such as
            ``'000123.jpg'``.
        topk: maximum number of image names written per query.

    Raises:
        KeyError: if a ranked image does not belong to any track.
    """
    indices = indices[:, :topk]
    os.makedirs(dst_dir, exist_ok=True)
    m, n = indices.shape
    print('m: {} n: {}'.format(m, n))

    # Numeric image id ("000123.jpg" -> 123) -> index of the track it belongs to.
    lookup_map = {}
    for track_idx, track in enumerate(tracks):
        for img_name in track:
            lookup_map[int(img_name.split('.')[0])] = track_idx

    reranked_results = []
    for row in indices:
        used_track_id = set()
        reranked_result = []
        # Iterate the row itself (length n): when the gallery is smaller than
        # topk, indexing range(topk) would raise IndexError.
        for img_id in (row + 1).tolist():
            track_id = lookup_map[img_id]
            if track_id in used_track_id:
                continue
            used_track_id.add(track_id)
            reranked_result.extend(tracks[track_id])
            if len(reranked_result) >= topk:
                break
        reranked_results.append(reranked_result[:topk])

    with open(os.path.join(dst_dir, 'track2.txt'), 'w') as f:
        for reranked_result in reranked_results:
            f.write(' '.join(map(str, reranked_result)) + '\n')


def main():
    """Run inference with ``cfg.TEST.WEIGHT`` and write a track-expanded submission."""
    parser = argparse.ArgumentParser(description="ReID Baseline Inference")
    parser.add_argument(
        "--config_file", default="./configs/debug.yml", help="path to config file", type=str
    )
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(output_dir, exist_ok=True)

    logger = setup_logger("reid_baseline", output_dir, 0)
    logger.info("Using {} GPUS".format(num_gpus))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    if cfg.MODEL.DEVICE == "cuda":
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID
    cudnn.benchmark = True

    train_loader, val_loader, num_query, num_classes, dataset = make_data_loader(cfg)
    model = build_model(cfg, num_classes)
    model.load_param(cfg.TEST.WEIGHT)

    indices_np = inference(cfg, model, val_loader, num_query, dataset)
    # Re-read dataset meta information to obtain the test tracks.
    dataset = AICity20(cfg.DATASETS.ROOT_DIR)
    # Alternative without track expansion:
    # write_result(indices_np, os.path.dirname(cfg.TEST.WEIGHT), topk=100)
    write_result_with_track(indices_np, os.path.dirname(cfg.TEST.WEIGHT), dataset.test_tracks)


if __name__ == '__main__':
    main()

# ------------------------------------------------------------------------------
# /tools/aicity20/vis_result.py
# ------------------------------------------------------------------------------
import numpy as np
import cv2
import os
import sys

sys.path.append('.')
from lib.data.datasets.aicity20_trainval import AICity20Trainval


def _read_resized(path, size):
    """Load an image with OpenCV and resize it, failing loudly if it is missing."""
    img = cv2.imread(path)
    if img is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(path)
    return cv2.resize(img, size)


def visualize_submit(dataset, out_dir, submit_txt_path, topk=5):
    """Save, per query, the query image followed by its top-``topk`` results.

    ``submit_txt_path`` holds one line per query (same order as
    ``dataset.query``) with space-separated gallery image names. Results whose
    pid differs from the query pid are outlined with a red rectangle. One
    horizontally concatenated image per query is written to ``out_dir``.
    """
    query_dir = dataset.query_dir
    gallery_dir = dataset.gallery_dir

    vis_size = (256, 256)
    os.makedirs(out_dir, exist_ok=True)
    with open(submit_txt_path, 'r') as f:
        results = [line.split() for line in f if line.strip()]

    query_pids = [pid for _, pid, _ in dataset.query]
    img_to_pid = {os.path.basename(img_path): pid for img_path, pid, _ in dataset.gallery}

    for i, result in enumerate(results):
        has_mismatch = False
        query_path = os.path.join(query_dir, os.path.basename(dataset.query[i][0]))

        imgs = [_read_resized(query_path, vis_size)]
        # Guard against result lines shorter than topk.
        for name in result[:topk]:
            img = _read_resized(os.path.join(gallery_dir, name), vis_size)
            if query_pids[i] != img_to_pid[name]:
                img = cv2.rectangle(img, (0, 0), vis_size, (0, 0, 255), 2)
                has_mismatch = True
            imgs.append(img)

        canvas = np.concatenate(imgs, axis=1)
        # To keep only failing queries, guard the write with `if has_mismatch:`.
        cv2.imwrite(os.path.join(out_dir, os.path.basename(query_path)), canvas)


if __name__ == '__main__':
    dataset = AICity20Trainval(root='/home/zxy/data/ReID/vehicle')

    out_dir = 'vis/'
    submit_txt_path = './output/aicity20/experiments/circle-sim-aug/result_voc.txt'
    visualize_submit(dataset, out_dir, submit_txt_path)

# ------------------------------------------------------------------------------
# /tools/aicity20/weakly_supervised_crop_aug.py
# ------------------------------------------------------------------------------
'''
detect object by actmap
'''

# encoding: utf-8
import argparse
import os
import sys
from os import mkdir
import cv2
import numpy as np
import torch
from torch.backends import cudnn
from torch.nn import functional as F
import json
sys.path.append('.')
from lib.config import cfg
from lib.data import make_data_loader
from lib.engine.inference import inference
from lib.modeling import build_model
from lib.utils.logger import setup_logger
from lib.utils.bbox_utils import localize_from_map, draw_bbox


def vis_actmap(model, cfg, loader, out_dir):
    """Crop every image in ``loader`` to the box localized from its activation map.

    For each image, the squared feature map returned by
    ``model(data, return_featmap=True)`` is summed over dim 1. A 4-cell border
    is zeroed, and the map is resized to ``cfg.INPUT.SIZE_TEST`` and min-max
    scaled to [0, 255]. ``localize_from_map`` then turns it into a box. The box
    is rescaled to the original image resolution. Its top-left coordinates are
    multiplied by 0.7 and its bottom-right by 1.2, and it is clipped to the
    image. The crop is written to ``out_dir/<parent dir>/<file name>``.

    Returns:
        list of dicts ``{'img_path': '<parent dir>/<file name>',
        'bbox': [x1, y1, x2, y2]}``.
    """
    device = cfg.MODEL.DEVICE
    model.to(device)
    model.eval()

    img_size = cfg.INPUT.SIZE_TEST
    for sub_dir in ('image_train', 'image_query', 'image_test'):
        os.makedirs(os.path.join(out_dir, sub_dir), exist_ok=True)

    results = []

    with torch.no_grad():
        for batch in loader:
            data, pid, camid, img_path = batch
            data = data.to(device)  # follow cfg.MODEL.DEVICE like the model
            featmap = model(data, return_featmap=True)
            featmap = (featmap ** 2).sum(1)  # per-location energy, dim 1 reduced
            for j in range(featmap.size(0)):
                fm = featmap[j].detach().cpu().numpy()

                # Suppress the border before localization.
                # NOTE(review): indices 12:16 assume a 16x16 map - confirm.
                fm[0:4, :] = 0
                fm[12:16, :] = 0
                fm[:, 0:4] = 0
                fm[:, 12:16] = 0

                fm = cv2.resize(fm, (img_size[1], img_size[0]))
                fm = 255 * (fm - np.min(fm)) / (np.max(fm) - np.min(fm) + 1e-12)
                bbox = localize_from_map(fm, threshold_ratio=1.0)

                img = cv2.imread(img_path[j])
                height, width, _ = img.shape
                # Rescale from test-input resolution to original resolution.
                bbox = np.array(bbox, dtype=np.float32)
                bbox[0::2] *= width / img_size[1]
                bbox[1::2] *= height / img_size[0]

                bbox[:2] *= 0.7
                bbox[2:] *= 1.2

                # np.int was removed in NumPy 1.24; astype(int) truncates the
                # same way.
                bbox = bbox.astype(int)
                # Keep the recorded box consistent with the crop actually taken.
                bbox[0::2] = np.clip(bbox[0::2], 0, width)
                bbox[1::2] = np.clip(bbox[1::2], 0, height)

                rel_path = '/'.join(img_path[j].split('/')[-2:])
                results.append({'img_path': rel_path, 'bbox': bbox.tolist()})

                crop = img[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
                cv2.imwrite(os.path.join(out_dir, rel_path), crop)
    return results


def main():
    """Crop train and val images by activation map and dump boxes to detection.json."""
    parser = argparse.ArgumentParser(description="ReID Baseline Inference")
    parser.add_argument(
        "--config_file", default="./configs/debug.yml", help="path to config file", type=str
    )
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    logger = setup_logger("reid_baseline", output_dir, 0)
    logger.info("Using {} GPUS".format(num_gpus))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
    logger.info("Running with config:\n{}".format(cfg))

    if cfg.MODEL.DEVICE == "cuda":
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID
    cudnn.benchmark = True

    train_loader, val_loader, num_query, num_classes, dataset = make_data_loader(cfg)
    model = build_model(cfg, num_classes)
    model.load_param(cfg.TEST.WEIGHT)

    out_dir = os.path.dirname(cfg.TEST.WEIGHT)
    results = []
    results += vis_actmap(model, cfg, train_loader, out_dir)
    results += vis_actmap(model, cfg, val_loader, out_dir)

    with open(os.path.join(out_dir, 'detection.json'), 'w') as f:
        json.dump(results, f)


if __name__ == '__main__':
    main()


'''
python tools/aicity20/weakly_supervised_crop_aug.py --config_file='configs/aicity20.yml' \
MODEL.DEVICE_ID "('0')" \
MODEL.NAME "('resnet50_ibn_a')" \
MODEL.MODEL_TYPE "baseline" \
DATASETS.TRAIN "('aicity20',)" \
DATASETS.TEST "('aicity20',)" \
DATALOADER.SAMPLER 'softmax' \
DATASETS.ROOT_DIR "('/home/zxy/data/ReID/vehicle')" \
MODEL.PRETRAIN_CHOICE "('self')" \
TEST.WEIGHT "('./output/aicity20/0326-search/augmix/best.pth')"
'''

# ------------------------------------------------------------------------------
# /tools/gen_vis.py
# ------------------------------------------------------------------------------
# Split a track2.txt submission into one file per query (000001.txt, ...),
# each listing that query's ranked image ids (file extension stripped), one
# per line.
import os

# make folder
folder = './vis_best_with_ori_camid_track_rank'
os.makedirs(folder, mode=0o777, exist_ok=True)

# read results
file_path = '/home/cybercore/su/AICity2021-VOC-ReID/output/aicity20/0409-ensemble/A-Nam-SynData-next101-320-circle/track2.txt'
with open(file_path, 'r') as file_submit:
    all_id_list = file_submit.readlines()

for idx, id_list in enumerate(all_id_list):
    # split() also drops the trailing newline that split(' ') left on the
    # last entry.
    with open(f'{folder}/{idx+1:06}.txt', 'w') as out_file:
        for img_name in id_list.split():
            out_file.write(img_name.split('.')[0] + '\n')

# ------------------------------------------------------------------------------
# /tools/test.py
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import argparse
import os
import sys
from os import mkdir

import torch
from torch.backends import cudnn

sys.path.append('.')
from lib.config import cfg
from lib.data import make_data_loader
from lib.engine.inference import inference
from lib.modeling import build_model
from lib.utils.logger import setup_logger


def main():
    """Load ``cfg.TEST.WEIGHT`` and run ``inference`` on the validation loader."""
    parser = argparse.ArgumentParser(description="ReID Baseline Inference")
    parser.add_argument(
        "--config_file", default="./configs/debug.yml", help="path to config file", type=str
    )
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    logger = setup_logger("reid_baseline", output_dir, 0)
    logger.info("Using {} GPUS".format(num_gpus))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
    logger.info("Running with config:\n{}".format(cfg))

    if cfg.MODEL.DEVICE == "cuda":
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID
    cudnn.benchmark = True

    train_loader, val_loader, num_query, num_classes, dataset = make_data_loader(cfg)
    model = build_model(cfg, num_classes)
    model.load_param(cfg.TEST.WEIGHT)

    inference(cfg, model, val_loader, num_query, dataset)


if __name__ == '__main__':
    main()

# ------------------------------------------------------------------------------
# /tools/train.py
# ------------------------------------------------------------------------------
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import argparse
import os
import sys
import torch

from torch.backends import cudnn

sys.path.append('.')
from lib.config import cfg
from lib.data import make_data_loader
from lib.engine.train_net import do_train
from lib.modeling import build_model
from lib.layers import make_loss
from lib.solver import make_optimizer, build_lr_scheduler

from lib.utils.logger import setup_logger


def train(cfg):
    """Build data, model, optimizer, loss and LR scheduler from ``cfg``, then train.

    ``cfg.MODEL.PRETRAIN_CHOICE`` selects the starting point:
      * ``'imagenet'``: start at epoch 0; no extra weights are loaded here.
      * ``'finetune'``: start at epoch 0 after
        ``model.load_param(cfg.MODEL.PRETRAIN_PATH, skip_fc=False)``.
      * ``'resume'``: restore model and optimizer state and the epoch from the
        checkpoint at ``cfg.MODEL.PRETRAIN_PATH``.

    Raises:
        ValueError: for any other ``PRETRAIN_CHOICE``. Previously this only
            printed a message and then failed with UnboundLocalError.
    """
    # prepare dataset
    train_loader, val_loader, num_query, num_classes, dataset = make_data_loader(cfg)

    # prepare model
    model = build_model(cfg, num_classes)
    optimizer = make_optimizer(cfg, model)

    loss_func = make_loss(cfg, num_classes)  # modified by gu

    choice = cfg.MODEL.PRETRAIN_CHOICE
    if choice == 'imagenet':
        start_epoch = 0
        last_epoch = -1
    elif choice == 'finetune':
        start_epoch = 0
        last_epoch = -1
        model.load_param(cfg.MODEL.PRETRAIN_PATH, skip_fc=False)
    elif choice == 'resume':
        checkpoint = torch.load(cfg.MODEL.PRETRAIN_PATH, map_location='cuda')
        start_epoch = checkpoint['epoch']
        last_epoch = start_epoch
        model.load_state_dict(checkpoint['state_dict'])
        model.cuda()
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('resume from {}'.format(cfg.MODEL.PRETRAIN_PATH))
    else:
        raise ValueError(
            'Only support pretrain_choice for imagenet, finetune and resume, but got {}'.format(choice))

    scheduler = build_lr_scheduler(optimizer, cfg.SOLVER.LR_SCHEDULER, cfg, last_epoch)

    do_train(
        cfg,
        model,
        dataset,
        train_loader,
        val_loader,
        optimizer,
        scheduler,  # modify for using self trained model
        loss_func,
        num_query,
        start_epoch  # add for using self trained model
    )


def main():
    """Parse CLI/config options, set up logging and devices, then call ``train``."""
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        "--config_file", default="./configs/debug.yml", help="path to config file", type=str
    )
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    logger = setup_logger("reid_baseline", output_dir, 0)
    logger.info("Using {} GPUS".format(num_gpus))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
    logger.info("Running with config:\n{}".format(cfg))

    if cfg.MODEL.DEVICE == "cuda":
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID  # new add by gu
    cudnn.benchmark = True
    train(cfg)


if __name__ == '__main__':
    main()

# ------------------------------------------------------------------------------
# /tools/vis_actmap.py
# ------------------------------------------------------------------------------
# encoding: utf-8
import argparse
import os
import sys
from os import mkdir
import cv2
import numpy as np
import torch
from torch.backends import cudnn
from torch.nn import functional as F
sys.path.append('.')
from lib.config import cfg
from lib.data import make_data_loader
from lib.engine.inference import inference
from lib.modeling import build_model
from lib.utils.logger import setup_logger
from lib.utils.bbox_utils import localize_from_map, draw_bbox


def vis_actmap(model, cfg, val_loader, max_num=100):
    """Save activation-map overlays with localized boxes for up to ``max_num`` batches.

    For each batch, the first 4 samples are rendered as the image blended with a
    JET-colored activation map (0.3 image / 0.7 map) plus the box from
    ``localize_from_map``. They are tiled horizontally and written to
    ``<cfg.OUTPUT_DIR>/actmap/<batch index>.jpg``.
    """
    device = cfg.MODEL.DEVICE
    model.to(device)
    model.eval()
    out_dir = os.path.join(cfg.OUTPUT_DIR, 'actmap')

    img_size = cfg.INPUT.SIZE_TEST
    os.makedirs(out_dir, exist_ok=True)

    num_tiles = 4  # samples per saved canvas
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            if i >= max_num:
                break
            data, pid, camid, img_path = batch
            data = data.to(device)  # follow cfg.MODEL.DEVICE like the model
            featmap = model(data, return_featmap=True)
            featmap = (featmap ** 2).sum(1)  # per-location energy, dim 1 reduced
            canvas = []
            # Only the first `num_tiles` overlays are saved, so skip the rest.
            for j in range(min(featmap.size(0), num_tiles)):
                fm = featmap[j].detach().cpu().numpy()

                # NOTE(review): the original author flagged this masking as
                # suspicious ("something is not right!"); index 15 assumes a
                # map of at least 16x16 - confirm against the backbone.
                fm[0:3, 0:3] = 0
                fm[0, 15] = 0
                fm[15, 0] = 0
                fm[15, 15] = 0

                fm = cv2.resize(fm, (img_size[1], img_size[0]))
                fm = 255 * (fm - np.min(fm)) / (np.max(fm) - np.min(fm) + 1e-12)
                bbox = localize_from_map(fm, threshold_ratio=1.0)
                fm = np.uint8(np.floor(fm))
                fm = cv2.applyColorMap(fm, cv2.COLORMAP_JET)

                img = cv2.imread(img_path[j])
                img = cv2.resize(img, (img_size[1], img_size[0]))

                overlapped = img * 0.3 + fm * 0.7
                overlapped = draw_bbox(overlapped, [bbox])
                canvas.append(overlapped.astype(np.uint8))
            canvas = np.concatenate(canvas, axis=1)
            cv2.imwrite(os.path.join(out_dir, '{}.jpg'.format(i)), canvas)


def main():
    """Load ``cfg.TEST.WEIGHT`` and dump activation-map visualizations."""
    parser = argparse.ArgumentParser(description="ReID Baseline Inference")
    parser.add_argument(
        "--config_file", default="./configs/debug.yml", help="path to config file", type=str
    )
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    logger = setup_logger("reid_baseline", output_dir, 0)
    logger.info("Using {} GPUS".format(num_gpus))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
    logger.info("Running with config:\n{}".format(cfg))

    if cfg.MODEL.DEVICE == "cuda":
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID
    cudnn.benchmark = True

    train_loader, val_loader, num_query, num_classes, dataset = make_data_loader(cfg)
    model = build_model(cfg, num_classes)
    model.load_param(cfg.TEST.WEIGHT)

    vis_actmap(model, cfg, val_loader)


if __name__ == '__main__':
    main()

# ------------------------------------------------------------------------------