├── .github └── ISSUE_TEMPLATE │ └── bug_report.md ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── configs └── cuhkpedes │ ├── baseline_gru_cliprn101_ls_bs128.yaml │ ├── baseline_gru_cliprn50_ls_bs128.yaml │ ├── baseline_gru_rn50_ls_bs128.yaml │ ├── moco_gru_cliprn101_ls_bs128_2048.yaml │ └── moco_gru_cliprn50_ls_bs128_2048.yaml ├── lib ├── __init__.py ├── config │ ├── __init__.py │ ├── defaults.py │ └── paths_catalog.py ├── data │ ├── __init__.py │ ├── build.py │ ├── collate_batch.py │ ├── datasets │ │ ├── __init__.py │ │ ├── concat_dataset.py │ │ └── cuhkpedes.py │ ├── metrics │ │ ├── __init__.py │ │ └── evaluation.py │ ├── samplers │ │ ├── __init__.py │ │ └── triplet_batch_sampler.py │ └── transforms.py ├── engine │ ├── __init__.py │ ├── inference.py │ └── trainer.py ├── models │ ├── backbones │ │ ├── __init__.py │ │ ├── build.py │ │ ├── gru.py │ │ ├── m_resnet.py │ │ └── resnet.py │ ├── embeddings │ │ ├── __init__.py │ │ ├── build.py │ │ ├── moco_head │ │ │ ├── __init__.py │ │ │ ├── head.py │ │ │ └── loss.py │ │ └── simple_head │ │ │ ├── __init__.py │ │ │ ├── head.py │ │ │ └── loss.py │ ├── losses.py │ └── model.py ├── solver │ ├── __init__.py │ ├── build.py │ └── lr_scheduler.py └── utils │ ├── caption.py │ ├── checkpoint.py │ ├── comm.py │ ├── directory.py │ ├── logger.py │ └── metric_logger.py ├── requirements.txt ├── run.sh ├── run.submit_file ├── test_net.py └── train_net.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 
39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | *.pyc 3 | *.ipynb 4 | *.npy 5 | 6 | output/ 7 | datasets/ 8 | pretrained/ 9 | __pycache__/ 10 | condor_log/ 11 | .cache/ 12 | .nv/ 13 | docker_stderror 14 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | 3 | - repo: https://github.com/psf/black 4 | rev: 20.8b1 # Replace by any tag/version: https://github.com/psf/black/tags 5 | hooks: 6 | - id: black 7 | language_version: python3 # Should be a command that runs python3.6+ 8 | 9 | # isort 10 | - repo: https://github.com/timothycrosley/isort 11 | rev: 5.6.4 12 | hooks: 13 | - id: isort 14 | 15 | # flake8 16 | - repo: https://github.com/PyCQA/flake8 17 | rev: 3.8.3 18 | hooks: 19 | - id: flake8 20 | args: ["--config=setup.cfg", "--ignore=W504, W503, E501, E203, E741, F821"] 21 | 22 | # pre-commit-hooks 23 | - repo: https://github.com/pre-commit/pre-commit-hooks 24 | rev: v3.2.0 25 | hooks: 26 | - id: trailing-whitespace # Trim trailing whitespace 27 | - id: check-merge-conflict # Check for files that contain merge conflict strings 28 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline 29 | - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 30 | - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*- 31 | args: ["--remove"] 32 | - id: mixed-line-ending # Replace or check mixed line ending 33 | args: ["--fix=lf"] 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text Based Person Search with Limited Data 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/text-based-person-search-with-limited-data/nlp-based-person-retrival-on-cuhk-pedes)](https://paperswithcode.com/sota/nlp-based-person-retrival-on-cuhk-pedes?p=text-based-person-search-with-limited-data) 4 | 5 | This is the codebase for our [BMVC 2021 paper](https://arxiv.org/abs/2110.10807). 6 | 7 | Slides and video for the online presentation are now available at [BMVC 2021 virtual conference website](https://www.bmvc2021-virtualconference.com/conference/papers/paper_0044.html). 8 | 9 | ## Updates 10 | - (10/12/2021) Add download link of trained models. 11 | - (06/12/2021) Code refactor for easy reproduce. 12 | - (20/10/2021) Code released! 13 | 14 | ## Abstract 15 | Text-based person search (TBPS) aims at retrieving a target person from an image gallery with a descriptive text query. 16 | Solving such a fine-grained cross-modal retrieval task is challenging, which is further hampered by the lack of large-scale datasets. 17 | In this paper, we present a framework with two novel components to handle the problems brought by limited data. 18 | Firstly, to fully utilize the existing small-scale benchmarking datasets for more discriminative feature learning, we introduce a cross-modal momentum contrastive learning framework to enrich the training data for a given mini-batch. 
Secondly, we propose to transfer knowledge learned from existing coarse-grained large-scale datasets containing image-text pairs from drastically different problem domains to compensate for the lack of TBPS training data. A transfer learning method is designed so that useful information can be transferred despite the large domain gap. Armed with these components, our method achieves a new state of the art on the CUHK-PEDES dataset with significant improvements over the prior art in terms of Rank-1 and mAP. 19 | 20 | ## Results 21 | ![image](https://user-images.githubusercontent.com/37724292/144879635-86ab9c7b-0317-4b42-ac46-a37b06853d18.png) 22 | 23 | ## Installation 24 | ### Setup environment 25 | ```bash 26 | conda create -n txtreid-env python=3.7 27 | conda activate txtreid-env 28 | git clone https://github.com/BrandonHanx/TextReID.git 29 | cd TextReID 30 | pip install -r requirements.txt 31 | pre-commit install 32 | ``` 33 | ### Get CUHK-PEDES dataset 34 | - Request the images from [Dr. Shuang Li](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description). 35 | - Download the pre-processed captions we provide from [Google Drive](https://drive.google.com/file/d/1V4d8OjFket5SaQmBVozFFeflNs6f9e1R/view?usp=sharing). 36 | - Organize the dataset as follows: 37 | ```bash 38 | datasets 39 | └── cuhkpedes 40 | ├── annotations 41 | │ ├── test.json 42 | │ ├── train.json 43 | │ └── val.json 44 | ├── clip_vocab_vit.npy 45 | └── imgs 46 | ├── cam_a 47 | ├── cam_b 48 | ├── CUHK01 49 | ├── CUHK03 50 | ├── Market 51 | ├── test_query 52 | └── train_query 53 | ``` 54 | 55 | ### Download CLIP weights 56 | ```bash 57 | mkdir -p pretrained/clip/ 58 | cd pretrained/clip 59 | wget https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt 60 | wget https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt 61 | cd - 62 | 63 | ``` 64 | 65 | ### Train 66 | ```bash 67 | python train_net.py \ 68 | --config-file configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml \ 69 | --use-tensorboard 70 | ``` 71 | ### Inference 72 | ```bash 73 | python test_net.py \ 74 | --config-file configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml \ 75 | --checkpoint-file output/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048/best.pth 76 | ``` 77 | You can download our trained models (with CLIP RN50 and RN101) from [Google Drive](https://drive.google.com/drive/folders/1MoceVsLiByg3Sg8_9yByGSvR3ru15hJL?usp=sharing). 78 | 79 | ## TODO 80 | - [ ] Try larger pre-trained CLIP models. 81 | - [ ] Fix the bug of multi-GPU running. 82 | - [ ] Add dataloader for [ICFG-PEDES](https://github.com/zifyloo/SSAN). 83 | 84 | ## Citation 85 | If you find this project useful for your research, please use the following BibTeX entry.
86 | ``` 87 | @inproceedings{han2021textreid, 88 | title={Text-Based Person Search with Limited Data}, 89 | author={Han, Xiao and He, Sen and Zhang, Li and Xiang, Tao}, 90 | booktitle={BMVC}, 91 | year={2021} 92 | } 93 | ``` 94 | -------------------------------------------------------------------------------- /configs/cuhkpedes/baseline_gru_cliprn101_ls_bs128.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHT: "imagenet" 3 | FREEZE: False 4 | VISUAL_MODEL: "m_resnet101" 5 | TEXTUAL_MODEL: "bigru" 6 | NUM_CLASSES: 11003 7 | GRU: 8 | ONEHOT: "clip_vit" 9 | EMBEDDING_SIZE: 512 10 | NUM_UNITS: 512 11 | VOCABULARY_SIZE: 512 12 | DROPOUT_KEEP_PROB: 1.0 13 | MAX_LENGTH: 100 14 | RESNET: 15 | RES5_STRIDE: 1 16 | EMBEDDING: 17 | EMBED_HEAD: 'simple' 18 | FEATURE_SIZE: 256 19 | DROPOUT_PROB: 0.0 20 | EPSILON: 0.1 21 | INPUT: 22 | HEIGHT: 384 23 | WIDTH: 128 24 | USE_AUG: True 25 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 26 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 27 | DATASETS: 28 | TRAIN: ("cuhkpedes_train", ) 29 | TEST: ("cuhkpedes_test", ) 30 | SOLVER: 31 | IMS_PER_BATCH: 128 32 | NUM_EPOCHS: 80 33 | BASE_LR: 0.0001 34 | WEIGHT_DECAY: 0.00004 35 | CHECKPOINT_PERIOD: 40 36 | LRSCHEDULER: 'step' 37 | STEPS: (40, 70) 38 | WARMUP_FACTOR: 0.1 39 | WARMUP_EPOCHS: 5 40 | TEST: 41 | IMS_PER_BATCH: 128 42 | -------------------------------------------------------------------------------- /configs/cuhkpedes/baseline_gru_cliprn50_ls_bs128.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHT: "imagenet" 3 | FREEZE: False 4 | VISUAL_MODEL: "m_resnet50" 5 | TEXTUAL_MODEL: "bigru" 6 | NUM_CLASSES: 11003 7 | GRU: 8 | ONEHOT: "clip_vit" 9 | EMBEDDING_SIZE: 512 10 | NUM_UNITS: 512 11 | VOCABULARY_SIZE: 512 12 | DROPOUT_KEEP_PROB: 1.0 13 | MAX_LENGTH: 100 14 | RESNET: 15 | RES5_STRIDE: 1 16 | EMBEDDING: 17 | EMBED_HEAD: 'simple' 18 | FEATURE_SIZE: 256 19 | DROPOUT_PROB: 0.0 20 | EPSILON: 0.1 21 | INPUT: 22 | HEIGHT: 384 23 | WIDTH: 128 24 | USE_AUG: True 25 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 26 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 27 | DATASETS: 28 | TRAIN: ("cuhkpedes_train", ) 29 | TEST: ("cuhkpedes_test", ) 30 | SOLVER: 31 | IMS_PER_BATCH: 128 32 | NUM_EPOCHS: 80 33 | BASE_LR: 0.0001 34 | WEIGHT_DECAY: 0.00004 35 | CHECKPOINT_PERIOD: 40 36 | LRSCHEDULER: 'step' 37 | STEPS: (40, 70) 38 | WARMUP_FACTOR: 0.1 39 | WARMUP_EPOCHS: 5 40 | TEST: 41 | IMS_PER_BATCH: 128 42 | -------------------------------------------------------------------------------- /configs/cuhkpedes/baseline_gru_rn50_ls_bs128.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHT: "imagenet" 3 | FREEZE: False 4 | VISUAL_MODEL: "resnet50" 5 | TEXTUAL_MODEL: "bigru" 6 | NUM_CLASSES: 11003 7 | GRU: 8 | ONEHOT: "yes" 9 | EMBEDDING_SIZE: 512 10 | NUM_UNITS: 512 11 | VOCABULARY_SIZE: 12000 12 | DROPOUT_KEEP_PROB: 1.0 13 | MAX_LENGTH: 100 14 | RESNET: 15 | RES5_STRIDE: 1 16 | EMBEDDING: 17 | EMBED_HEAD: 'simple' 18 | FEATURE_SIZE: 256 19 | DROPOUT_PROB: 0.0 20 | EPSILON: 0.1 21 | INPUT: 22 | HEIGHT: 384 23 | WIDTH: 128 24 | USE_AUG: True 25 | DATASETS: 26 | TRAIN: ("cuhkpedes_train", ) 27 | TEST: ("cuhkpedes_test", ) 28 | SOLVER: 29 | IMS_PER_BATCH: 128 30 | NUM_EPOCHS: 80 31 | BASE_LR: 0.0001 32 | WEIGHT_DECAY: 0.00004 33 | CHECKPOINT_PERIOD: 40 34 | LRSCHEDULER: 'step' 35 | STEPS: (40, 70) 36 | WARMUP_FACTOR: 0.1 37 | WARMUP_EPOCHS: 5 38 | 
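# Unlike the baseline_gru_cliprn* configs, this baseline keeps the ImageNet
# PIXEL_MEAN/PIXEL_STD defaults from lib/config/defaults.py and learns its word
# embeddings from scratch (ONEHOT: "yes" with VOCABULARY_SIZE: 12000). Any key
# not set in this file falls back to the value in defaults.py.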
TEST: 39 | IMS_PER_BATCH: 128 40 | -------------------------------------------------------------------------------- /configs/cuhkpedes/moco_gru_cliprn101_ls_bs128_2048.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHT: "imagenet" 3 | FREEZE: False 4 | VISUAL_MODEL: "m_resnet101" 5 | TEXTUAL_MODEL: "bigru" 6 | NUM_CLASSES: 11003 7 | GRU: 8 | ONEHOT: "clip_vit" 9 | EMBEDDING_SIZE: 512 10 | NUM_UNITS: 512 11 | VOCABULARY_SIZE: 512 12 | DROPOUT_KEEP_PROB: 1.0 13 | MAX_LENGTH: 100 14 | RESNET: 15 | RES5_STRIDE: 1 16 | EMBEDDING: 17 | EMBED_HEAD: 'moco' 18 | FEATURE_SIZE: 256 19 | DROPOUT_PROB: 0.0 20 | EPSILON: 0.1 21 | MOCO: 22 | FC: False 23 | K: 2048 24 | INPUT: 25 | HEIGHT: 384 26 | WIDTH: 128 27 | USE_AUG: True 28 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 29 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 30 | DATASETS: 31 | TRAIN: ("cuhkpedes_train", ) 32 | TEST: ("cuhkpedes_test", ) 33 | SOLVER: 34 | IMS_PER_BATCH: 128 35 | NUM_EPOCHS: 80 36 | BASE_LR: 0.0001 37 | WEIGHT_DECAY: 0.00004 38 | CHECKPOINT_PERIOD: 40 39 | LRSCHEDULER: 'step' 40 | STEPS: (40, 70) 41 | WARMUP_FACTOR: 0.1 42 | WARMUP_EPOCHS: 5 43 | TEST: 44 | IMS_PER_BATCH: 128 45 | -------------------------------------------------------------------------------- /configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHT: "imagenet" 3 | FREEZE: False 4 | VISUAL_MODEL: "m_resnet50" 5 | TEXTUAL_MODEL: "bigru" 6 | NUM_CLASSES: 11003 7 | GRU: 8 | ONEHOT: "clip_vit" 9 | EMBEDDING_SIZE: 512 10 | NUM_UNITS: 512 11 | VOCABULARY_SIZE: 512 12 | DROPOUT_KEEP_PROB: 1.0 13 | MAX_LENGTH: 100 14 | RESNET: 15 | RES5_STRIDE: 1 16 | EMBEDDING: 17 | EMBED_HEAD: 'moco' 18 | FEATURE_SIZE: 256 19 | DROPOUT_PROB: 0.0 20 | EPSILON: 0.1 21 | MOCO: 22 | FC: False 23 | K: 2048 24 | INPUT: 25 | HEIGHT: 384 26 | WIDTH: 128 27 | USE_AUG: True 28 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 29 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 30 | DATASETS: 31 | TRAIN: ("cuhkpedes_train", ) 32 | TEST: ("cuhkpedes_test", ) 33 | SOLVER: 34 | IMS_PER_BATCH: 128 35 | NUM_EPOCHS: 80 36 | BASE_LR: 0.0001 37 | WEIGHT_DECAY: 0.00004 38 | CHECKPOINT_PERIOD: 40 39 | LRSCHEDULER: 'step' 40 | STEPS: (40, 70) 41 | WARMUP_FACTOR: 0.1 42 | WARMUP_EPOCHS: 5 43 | TEST: 44 | IMS_PER_BATCH: 128 45 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrandonHanx/TextReID/0d00d8e0844fbd3f322147786affcc19d0e42b68/lib/__init__.py -------------------------------------------------------------------------------- /lib/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg 2 | 3 | __all__ = ["cfg"] 4 | -------------------------------------------------------------------------------- /lib/config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _C = CN() 4 | _C.ROOT = "./" 5 | 6 | # ----------------------------------------------------------------------------- 7 | # Dataset 8 | # ----------------------------------------------------------------------------- 9 | _C.DATASETS = CN() 10 | _C.DATASETS.TRAIN = () 11 | _C.DATASETS.TEST = () 12 | _C.DATASETS.USE_ONEHOT = True 13 | 14 | 15 | # 
----------------------------------------------------------------------------- 16 | # DataLoader 17 | # ----------------------------------------------------------------------------- 18 | _C.DATALOADER = CN() 19 | # Number of data loading threads 20 | _C.DATALOADER.NUM_WORKERS = 4 21 | _C.DATALOADER.IMS_PER_ID = 4 22 | _C.DATALOADER.EN_SAMPLER = True 23 | 24 | 25 | # ----------------------------------------------------------------------------- 26 | # Input 27 | # ----------------------------------------------------------------------------- 28 | _C.INPUT = CN() 29 | _C.INPUT.HEIGHT = 224 30 | _C.INPUT.WIDTH = 224 31 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 32 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 33 | _C.INPUT.PADDING = 10 34 | _C.INPUT.USE_AUG = False 35 | 36 | 37 | # ----------------------------------------------------------------------------- 38 | # Model 39 | # ----------------------------------------------------------------------------- 40 | _C.MODEL = CN() 41 | _C.MODEL.DEVICE = "cuda" 42 | _C.MODEL.VISUAL_MODEL = "resnet50" 43 | _C.MODEL.TEXTUAL_MODEL = "bilstm" 44 | _C.MODEL.NUM_CLASSES = 11003 45 | _C.MODEL.FREEZE = False 46 | _C.MODEL.WEIGHT = "imagenet" 47 | 48 | 49 | # ----------------------------------------------------------------------------- 50 | # MoCo 51 | # ----------------------------------------------------------------------------- 52 | _C.MODEL.MOCO = CN() 53 | _C.MODEL.MOCO.K = 1024 54 | _C.MODEL.MOCO.M = 0.999 55 | _C.MODEL.MOCO.FC = True 56 | 57 | 58 | # ----------------------------------------------------------------------------- 59 | # GRU 60 | # ----------------------------------------------------------------------------- 61 | _C.MODEL.GRU = CN() 62 | _C.MODEL.GRU.ONEHOT = "yes" 63 | _C.MODEL.GRU.EMBEDDING_SIZE = 512 64 | _C.MODEL.GRU.NUM_UNITS = 512 65 | _C.MODEL.GRU.VOCABULARY_SIZE = 12000 66 | _C.MODEL.GRU.DROPOUT_KEEP_PROB = 0.7 67 | _C.MODEL.GRU.MAX_LENGTH = 100 68 | _C.MODEL.GRU.NUM_LAYER = 1 69 | 70 | 71 | # ----------------------------------------------------------------------------- 72 | # Resnet 73 | # ----------------------------------------------------------------------------- 74 | _C.MODEL.RESNET = CN() 75 | _C.MODEL.RESNET.RES5_STRIDE = 2 76 | _C.MODEL.RESNET.RES5_DILATION = 1 77 | _C.MODEL.RESNET.PRETRAINED = None 78 | 79 | 80 | # ----------------------------------------------------------------------------- 81 | # Embedding 82 | # ----------------------------------------------------------------------------- 83 | _C.MODEL.EMBEDDING = CN() 84 | _C.MODEL.EMBEDDING.EMBED_HEAD = "simple" 85 | _C.MODEL.EMBEDDING.FEATURE_SIZE = 512 86 | _C.MODEL.EMBEDDING.DROPOUT_PROB = 0.3 87 | _C.MODEL.EMBEDDING.EPSILON = 0.0 88 | 89 | 90 | # ----------------------------------------------------------------------------- 91 | # Solver 92 | # ----------------------------------------------------------------------------- 93 | _C.SOLVER = CN() 94 | _C.SOLVER.IMS_PER_BATCH = 16 95 | _C.SOLVER.NUM_EPOCHS = 100 96 | _C.SOLVER.CHECKPOINT_PERIOD = 1 97 | _C.SOLVER.EVALUATE_PERIOD = 1 98 | 99 | _C.SOLVER.OPTIMIZER = "Adam" 100 | _C.SOLVER.BASE_LR = 0.0002 101 | _C.SOLVER.BIAS_LR_FACTOR = 2 102 | 103 | _C.SOLVER.WEIGHT_DECAY = 0.00004 104 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.0 105 | 106 | _C.SOLVER.ADAM_ALPHA = 0.9 107 | _C.SOLVER.ADAM_BETA = 0.999 108 | _C.SOLVER.SGD_MOMENTUM = 0.9 109 | 110 | _C.SOLVER.LRSCHEDULER = "step" 111 | 112 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3 113 | _C.SOLVER.WARMUP_EPOCHS = 10 114 | _C.SOLVER.WARMUP_METHOD = "linear" 115 | 116 | _C.SOLVER.GAMMA = 
0.1 117 | _C.SOLVER.STEPS = (500,) 118 | 119 | _C.SOLVER.POWER = 0.9 120 | _C.SOLVER.TARGET_LR = 0.0001 121 | 122 | 123 | # ---------------------------------------------------------------------------- # 124 | # Specific test options 125 | # ---------------------------------------------------------------------------- # 126 | _C.TEST = CN() 127 | # Number of images per batch 128 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will 129 | # see 2 images per batch 130 | _C.TEST.IMS_PER_BATCH = 16 131 | 132 | 133 | # ---------------------------------------------------------------------------- # 134 | # Misc options 135 | # ---------------------------------------------------------------------------- # 136 | 137 | 138 | # ---------------------------------------------------------------------------- # 139 | # Precision options 140 | # ---------------------------------------------------------------------------- # 141 | # Precision of input, allowable: (float32, float16) 142 | _C.DTYPE = "float32" 143 | # Enable verbosity in apex.amp 144 | _C.AMP_VERBOSE = False 145 | -------------------------------------------------------------------------------- /lib/config/paths_catalog.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class DatasetCatalog: 5 | DATA_DIR = "datasets" 6 | DATASETS = { 7 | "cuhkpedes_train": { 8 | "img_dir": "cuhkpedes", 9 | "ann_file": "cuhkpedes/annotations/train.json", 10 | }, 11 | "cuhkpedes_val": { 12 | "img_dir": "cuhkpedes", 13 | "ann_file": "cuhkpedes/annotations/val.json", 14 | }, 15 | "cuhkpedes_test": { 16 | "img_dir": "cuhkpedes", 17 | "ann_file": "cuhkpedes/annotations/test.json", 18 | }, 19 | } 20 | 21 | @staticmethod 22 | def get(root, name): 23 | if "cuhkpedes" in name: 24 | data_dir = DatasetCatalog.DATA_DIR 25 | attrs = DatasetCatalog.DATASETS[name] 26 | args = dict( 27 | root=os.path.join(root, data_dir, attrs["img_dir"]), 28 | ann_file=os.path.join(root, data_dir, attrs["ann_file"]), 29 | ) 30 | return dict( 31 | factory="CUHKPEDESDataset", 32 | args=args, 33 | ) 34 | raise RuntimeError("Dataset not available: {}".format(name)) 35 | -------------------------------------------------------------------------------- /lib/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from .build import make_data_loader 3 | 4 | __all__ = ["make_data_loader"] 5 | -------------------------------------------------------------------------------- /lib/data/build.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | 3 | from lib.config.paths_catalog import DatasetCatalog 4 | from lib.utils.comm import get_world_size 5 | 6 | from . import datasets as D 7 | from . 
import samplers 8 | from .collate_batch import collate_fn 9 | from .transforms import build_transforms 10 | 11 | 12 | def build_dataset(cfg, dataset_list, transforms, dataset_catalog, is_train=True): 13 | if not isinstance(dataset_list, (list, tuple)): 14 | raise RuntimeError( 15 | "dataset_list should be a list of strings, got {}".format(dataset_list) 16 | ) 17 | datasets = [] 18 | for dataset_name in dataset_list: 19 | data = dataset_catalog.get(cfg.ROOT, dataset_name) 20 | factory = getattr(D, data["factory"]) 21 | args = data["args"] 22 | args["transforms"] = transforms 23 | 24 | if data["factory"] == "CUHKPEDESDataset": 25 | args["use_onehot"] = cfg.DATASETS.USE_ONEHOT 26 | args["max_length"] = 105 27 | 28 | # make dataset from factory 29 | dataset = factory(**args) 30 | datasets.append(dataset) 31 | 32 | # for testing, return a list of datasets 33 | if not is_train: 34 | return datasets 35 | 36 | # for training, concatenate all datasets into a single one 37 | dataset = datasets[0] 38 | if len(datasets) > 1: 39 | dataset = D.ConcatDataset(datasets) 40 | 41 | return [dataset] 42 | 43 | 44 | def make_data_sampler(dataset, shuffle, distributed): 45 | if distributed: 46 | return torch.utils.data.distributed.DistributedSampler(dataset) 47 | if shuffle: 48 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 49 | else: 50 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 51 | return sampler 52 | 53 | 54 | def make_batch_data_sampler(cfg, dataset, sampler, images_per_batch, is_train=True): 55 | if is_train and cfg.DATALOADER.EN_SAMPLER: 56 | batch_sampler = samplers.TripletSampler( 57 | sampler, 58 | dataset, 59 | images_per_batch, 60 | cfg.DATALOADER.IMS_PER_ID, 61 | drop_last=True, 62 | ) 63 | else: 64 | batch_sampler = torch.utils.data.sampler.BatchSampler( 65 | sampler, images_per_batch, drop_last=is_train 66 | ) 67 | return batch_sampler 68 | 69 | 70 | def make_data_loader(cfg, is_train=True, is_distributed=False): 71 | num_gpus = get_world_size() 72 | if is_train: 73 | images_per_batch = cfg.SOLVER.IMS_PER_BATCH 74 | assert ( 75 | images_per_batch % num_gpus == 0 76 | ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of GPUs ({}) used.".format( 77 | images_per_batch, num_gpus 78 | ) 79 | images_per_gpu = images_per_batch // num_gpus 80 | shuffle = True 81 | else: 82 | images_per_batch = cfg.TEST.IMS_PER_BATCH 83 | assert ( 84 | images_per_batch % num_gpus == 0 85 | ), "TEST.IMS_PER_BATCH ({}) must be divisible by the number of GPUs ({}) used.".format( 86 | images_per_batch, num_gpus 87 | ) 88 | images_per_gpu = images_per_batch // num_gpus 89 | shuffle = is_distributed 90 | 91 | dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST 92 | 93 | transforms = build_transforms(cfg, is_train) 94 | 95 | datasets = build_dataset(cfg, dataset_list, transforms, DatasetCatalog, is_train) 96 | 97 | data_loaders = [] 98 | for dataset in datasets: 99 | sampler = make_data_sampler(dataset, shuffle, is_distributed) 100 | batch_sampler = make_batch_data_sampler( 101 | cfg, dataset, sampler, images_per_gpu, is_train 102 | ) 103 | num_workers = cfg.DATALOADER.NUM_WORKERS 104 | data_loader = torch.utils.data.DataLoader( 105 | dataset, 106 | num_workers=num_workers, 107 | batch_sampler=batch_sampler, 108 | collate_fn=collate_fn, 109 | ) 110 | data_loaders.append(data_loader) 111 | if is_train: 112 | # during training, a single (possibly concatenated) data_loader is returned 113 | assert len(data_loaders) == 1 114 | return data_loaders[0] 115 | return 
data_loaders 116 | -------------------------------------------------------------------------------- /lib/data/collate_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def collate_fn(batch): 5 | transposed_batch = list(zip(*batch)) 6 | images = torch.stack(transposed_batch[0]) 7 | captions = transposed_batch[1] 8 | img_ids = transposed_batch[2] 9 | return images, captions, img_ids 10 | -------------------------------------------------------------------------------- /lib/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .concat_dataset import ConcatDataset 2 | from .cuhkpedes import CUHKPEDESDataset 3 | 4 | __all__ = ["ConcatDataset", "CUHKPEDESDataset"] 5 | -------------------------------------------------------------------------------- /lib/data/datasets/concat_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import bisect 3 | 4 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 5 | 6 | 7 | class ConcatDataset(_ConcatDataset): 8 | """ 9 | Same as torch.utils.data.dataset.ConcatDataset, but exposes an extra 10 | method for querying the sizes of the image 11 | """ 12 | 13 | def get_idxs(self, idx): 14 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 15 | if dataset_idx == 0: 16 | sample_idx = idx 17 | else: 18 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 19 | return dataset_idx, sample_idx 20 | 21 | def get_id_info(self, idx): 22 | dataset_idx, sample_idx = self.get_idxs(idx) 23 | return self.datasets[dataset_idx].get_id_info(sample_idx) 24 | -------------------------------------------------------------------------------- /lib/data/datasets/cuhkpedes.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | from PIL import Image 6 | 7 | from lib.utils.caption import Caption 8 | 9 | 10 | class CUHKPEDESDataset(torch.utils.data.Dataset): 11 | def __init__( 12 | self, 13 | root, 14 | ann_file, 15 | use_onehot=True, 16 | max_length=100, 17 | transforms=None, 18 | ): 19 | self.root = root 20 | self.use_onehot = use_onehot 21 | self.max_length = max_length 22 | self.transforms = transforms 23 | 24 | self.img_dir = os.path.join(self.root, "imgs") 25 | 26 | print("loading annotations into memory...") 27 | dataset = json.load(open(ann_file, "r")) 28 | self.dataset = dataset["annotations"] 29 | 30 | def __getitem__(self, index): 31 | """ 32 | Args: 33 | index(int): Index 34 | Returns: 35 | tuple: (images, labels, captions) 36 | """ 37 | data = self.dataset[index] 38 | 39 | img_path = data["file_path"] 40 | img = Image.open(os.path.join(self.img_dir, img_path)).convert("RGB") 41 | 42 | if self.use_onehot: 43 | caption = data["onehot"] 44 | caption = torch.tensor(caption) 45 | caption = Caption([caption], max_length=self.max_length, padded=False) 46 | else: 47 | caption = data["sentence"] 48 | caption = Caption(caption) 49 | 50 | caption.add_field("img_path", img_path) 51 | 52 | label = data["id"] 53 | label = torch.tensor(label) 54 | caption.add_field("id", label) 55 | 56 | if self.transforms is not None: 57 | img = self.transforms(img) 58 | 59 | return img, caption, index 60 | 61 | def __len__(self): 62 | return len(self.dataset) 63 | 64 | def get_id_info(self, index): 65 | image_id = 
self.dataset[index]["image_id"] 66 | pid = self.dataset[index]["id"] 67 | return image_id, pid 68 | -------------------------------------------------------------------------------- /lib/data/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluation import evaluation 2 | 3 | __all__ = ["evaluation"] 4 | -------------------------------------------------------------------------------- /lib/data/metrics/evaluation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from lib.utils.logger import table_log 9 | 10 | 11 | def rank(similarity, q_pids, g_pids, topk=[1, 5, 10], get_mAP=True): 12 | max_rank = max(topk) 13 | if get_mAP: 14 | indices = torch.argsort(similarity, dim=1, descending=True) 15 | else: 16 | # acclerate sort with topk 17 | _, indices = torch.topk( 18 | similarity, k=max_rank, dim=1, largest=True, sorted=True 19 | ) # q * topk 20 | pred_labels = g_pids[indices] # q * k 21 | matches = pred_labels.eq(q_pids.view(-1, 1)) # q * k 22 | 23 | all_cmc = matches[:, :max_rank].cumsum(1) 24 | all_cmc[all_cmc > 1] = 1 25 | all_cmc = all_cmc.float().mean(0) * 100 26 | all_cmc = all_cmc[topk - 1] 27 | 28 | if not get_mAP: 29 | return all_cmc, indices 30 | 31 | num_rel = matches.sum(1) # q 32 | tmp_cmc = matches.cumsum(1) # q * k 33 | tmp_cmc = [tmp_cmc[:, i] / (i + 1.0) for i in range(tmp_cmc.shape[1])] 34 | tmp_cmc = torch.stack(tmp_cmc, 1) * matches 35 | AP = tmp_cmc.sum(1) / num_rel # q 36 | mAP = AP.mean() * 100 37 | return all_cmc, mAP, indices 38 | 39 | 40 | def jaccard(a_list, b_list): 41 | return float(len(set(a_list) & set(b_list))) / float(len(set(a_list) | set(b_list))) 42 | 43 | 44 | def jaccard_mat(row_nn, col_nn): 45 | jaccard_sim = np.zeros((row_nn.shape[0], col_nn.shape[0])) 46 | # FIXME: need optimization 47 | for i in range(row_nn.shape[0]): 48 | for j in range(col_nn.shape[0]): 49 | jaccard_sim[i, j] = jaccard(row_nn[i], col_nn[j]) 50 | return torch.from_numpy(jaccard_sim) 51 | 52 | 53 | def k_reciprocal(q_feats, g_feats, neighbor_num=5, alpha=0.05): 54 | qg_sim = torch.matmul(q_feats, g_feats.t()) # q * g 55 | gg_sim = torch.matmul(g_feats, g_feats.t()) # g * g 56 | 57 | qg_indices = torch.argsort(qg_sim, dim=1, descending=True) 58 | gg_indices = torch.argsort(gg_sim, dim=1, descending=True) 59 | 60 | qg_nn = qg_indices[:, :neighbor_num] # q * n 61 | gg_nn = gg_indices[:, :neighbor_num] # g * n 62 | 63 | jaccard_sim = jaccard_mat(qg_nn.cpu().numpy(), gg_nn.cpu().numpy()) # q * g 64 | jaccard_sim = jaccard_sim.to(qg_sim.device) 65 | return alpha * jaccard_sim # q * g 66 | 67 | 68 | def get_unique(image_ids): 69 | keep_idx = {} 70 | for idx, image_id in enumerate(image_ids): 71 | if image_id not in keep_idx.keys(): 72 | keep_idx[image_id] = idx 73 | return torch.tensor(list(keep_idx.values())) 74 | 75 | 76 | def evaluation( 77 | dataset, 78 | predictions, 79 | output_folder, 80 | topk, 81 | save_data=True, 82 | rerank=True, 83 | ): 84 | logger = logging.getLogger("PersonSearch.inference") 85 | data_dir = os.path.join(output_folder, "inference_data.npz") 86 | 87 | if predictions is None: 88 | inference_data = np.load(data_dir) 89 | logger.info("Load inference data from {}".format(data_dir)) 90 | image_pid = torch.tensor(inference_data["image_pid"]) 91 | text_pid = torch.tensor(inference_data["text_pid"]) 92 | similarity = 
torch.tensor(inference_data["similarity"]) 93 | if rerank: 94 | rvn_mat = torch.tensor(inference_data["rvn_mat"]) 95 | rtn_mat = torch.tensor(inference_data["rtn_mat"]) 96 | else: 97 | image_ids, pids = [], [] 98 | image_global, text_global = [], [] 99 | 100 | # FIXME: need optimization 101 | for idx, prediction in predictions.items(): 102 | image_id, pid = dataset.get_id_info(idx) 103 | image_ids.append(image_id) 104 | pids.append(pid) 105 | image_global.append(prediction[0]) 106 | text_global.append(prediction[1]) 107 | 108 | image_pid = torch.tensor(pids) 109 | text_pid = torch.tensor(pids) 110 | image_global = torch.stack(image_global, dim=0) 111 | text_global = torch.stack(text_global, dim=0) 112 | 113 | keep_idx = get_unique(image_ids) 114 | image_global = image_global[keep_idx] 115 | image_pid = image_pid[keep_idx] 116 | 117 | image_global = F.normalize(image_global, p=2, dim=1) 118 | text_global = F.normalize(text_global, p=2, dim=1) 119 | 120 | similarity = torch.matmul(text_global, image_global.t()) 121 | 122 | if rerank: 123 | rtn_mat = k_reciprocal(image_global, text_global) 124 | rvn_mat = k_reciprocal(text_global, image_global) 125 | 126 | if save_data: 127 | if not rerank: 128 | np.savez( 129 | data_dir, 130 | image_pid=image_pid.cpu().numpy(), 131 | text_pid=text_pid.cpu().numpy(), 132 | similarity=similarity.cpu().numpy(), 133 | ) 134 | else: 135 | np.savez( 136 | data_dir, 137 | image_pid=image_pid.cpu().numpy(), 138 | text_pid=text_pid.cpu().numpy(), 139 | similarity=similarity.cpu().numpy(), 140 | rvn_mat=rvn_mat.cpu().numpy(), 141 | rtn_mat=rtn_mat.cpu().numpy(), 142 | ) 143 | 144 | topk = torch.tensor(topk) 145 | 146 | if rerank: 147 | i2t_cmc, i2t_mAP, _ = rank( 148 | similarity.t(), image_pid, text_pid, topk, get_mAP=True 149 | ) 150 | t2i_cmc, t2i_mAP, _ = rank(similarity, text_pid, image_pid, topk, get_mAP=True) 151 | re_i2t_cmc, re_i2t_mAP, _ = rank( 152 | rtn_mat + similarity.t(), image_pid, text_pid, topk, get_mAP=True 153 | ) 154 | re_t2i_cmc, re_t2i_mAP, _ = rank( 155 | rvn_mat + similarity, text_pid, image_pid, topk, get_mAP=True 156 | ) 157 | cmc_results = torch.stack([topk, t2i_cmc, re_t2i_cmc, i2t_cmc, re_i2t_cmc]) 158 | mAP_results = torch.stack( 159 | [torch.zeros_like(t2i_mAP), t2i_mAP, re_t2i_mAP, i2t_mAP, re_i2t_mAP] 160 | ).unsqueeze(-1) 161 | results = torch.cat([cmc_results, mAP_results], dim=1) 162 | results = results.t().cpu().numpy().tolist() 163 | results[-1][0] = "mAP" 164 | logger.info( 165 | "\n" 166 | + table_log(results, headers=["topk", "t2i", "re-t2i", "i2t", "re-i2t"]) 167 | ) 168 | else: 169 | t2i_cmc, _ = rank(similarity, text_pid, image_pid, topk, get_mAP=False) 170 | i2t_cmc, _ = rank(similarity.t(), image_pid, text_pid, topk, get_mAP=False) 171 | results = torch.stack((topk, t2i_cmc, i2t_cmc)).t().cpu().numpy() 172 | logger.info("\n" + table_log(results, headers=["topk", "t2i", "i2t"])) 173 | return t2i_cmc[0] 174 | -------------------------------------------------------------------------------- /lib/data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .triplet_batch_sampler import TripletSampler 2 | 3 | __all__ = ["TripletSampler"] 4 | -------------------------------------------------------------------------------- /lib/data/samplers/triplet_batch_sampler.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | import torch 7 | from 
torch.utils.data.sampler import BatchSampler 8 | 9 | 10 | def _split(tensor, size, dim=0, drop_last=False): 11 | if dim < 0: 12 | dim += tensor.dim() 13 | dim_size = tensor.size(dim) 14 | 15 | if dim_size < size: 16 | times = math.ceil(size / dim_size) 17 | tensor = tensor.repeat_interleave(times) 18 | dim_size = size 19 | 20 | split_size = size 21 | num_splits = (dim_size + split_size - 1) // split_size 22 | last_split_size = split_size - (split_size * num_splits - dim_size) 23 | 24 | def get_split_size(i): 25 | return split_size if i < num_splits - 1 else last_split_size 26 | 27 | if drop_last and last_split_size != split_size: 28 | total_num_splits = num_splits - 1 29 | else: 30 | total_num_splits = num_splits 31 | 32 | return list( 33 | tensor.narrow(int(dim), int(i * split_size), int(get_split_size(i))) 34 | for i in range(0, total_num_splits) 35 | ) 36 | 37 | 38 | def _merge(splits, pids, num_pids_per_batch): 39 | avaible_pids = copy.deepcopy(pids) 40 | merged = [] 41 | 42 | while len(avaible_pids) >= num_pids_per_batch: 43 | batch = [] 44 | selected_pids = random.sample(avaible_pids, num_pids_per_batch) 45 | for pid in selected_pids: 46 | batch_idxs = splits[pid].pop(0) 47 | batch.extend(batch_idxs.tolist()) 48 | if len(splits[pid]) == 0: 49 | avaible_pids.remove(pid) 50 | merged.append(batch) 51 | return merged 52 | 53 | 54 | def _map(dataset): 55 | id_to_img_map = [] 56 | for i in range(len(dataset)): 57 | _, pid = dataset.get_id_info(i) 58 | id_to_img_map.append(pid) 59 | return id_to_img_map 60 | 61 | 62 | class TripletSampler(BatchSampler): 63 | """ 64 | Randomly sample N identities, then for each identity, 65 | randomly sample K instances, therefore batch size is N*K. 66 | Args: 67 | - data_source (list): list of (img_path, pid, camid). 68 | - num_instances (int): number of instances per identity in a batch. 69 | - batch_size (int): number of examples in a batch. 
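    For example, with batch_size=128 and num_instances=4 (SOLVER.IMS_PER_BATCH
    and DATALOADER.IMS_PER_ID in the provided configs), each batch covers
    128 // 4 = 32 distinct identities with 4 instances per identity.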
70 | """ 71 | 72 | def __init__(self, sampler, data_source, batch_size, images_per_pid, drop_last): 73 | super(TripletSampler, self).__init__(sampler, batch_size, drop_last) 74 | self.num_instances = images_per_pid 75 | self.num_pids_per_batch = batch_size // images_per_pid 76 | self.id_to_img_map = _map(data_source) 77 | self.index_dict = defaultdict(list) 78 | for index, pid in enumerate(self.id_to_img_map): 79 | self.index_dict[pid].append(index) 80 | self.pids = list(self.index_dict.keys()) 81 | 82 | self.group_ids = torch.as_tensor(self.id_to_img_map) 83 | self.groups = torch.unique(self.group_ids).sort(0)[0] 84 | 85 | self._can_reuse_batches = False 86 | 87 | def _prepare_batches(self): 88 | dataset_size = len(self.group_ids) 89 | sampled_ids = torch.as_tensor(list(self.sampler)) 90 | order = torch.full((dataset_size,), -1, dtype=torch.int64) 91 | order[sampled_ids] = torch.arange(len(sampled_ids)) 92 | 93 | mask = order >= 0 94 | clusters = [(self.group_ids == i) & mask for i in self.groups] 95 | relative_order = [order[cluster] for cluster in clusters] 96 | permutation_ids = [s.sort()[0] for s in relative_order] 97 | permuted_clusters = [sampled_ids[idx] for idx in permutation_ids] 98 | 99 | splits = defaultdict(list) 100 | for idx, c in enumerate(permuted_clusters): 101 | splits[idx] = _split(c, self.num_instances, drop_last=True) 102 | merged = _merge(splits, self.pids, self.num_pids_per_batch) 103 | 104 | first_element_of_batch = [t[0] for t in merged] 105 | inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())} 106 | first_index_of_batch = torch.as_tensor( 107 | [inv_sampled_ids_map[s] for s in first_element_of_batch] 108 | ) 109 | permutation_order = first_index_of_batch.sort(0)[1].tolist() 110 | batches = [merged[i] for i in permutation_order] 111 | 112 | return batches 113 | 114 | def __iter__(self): 115 | if self._can_reuse_batches: 116 | batches = self._batches 117 | self._can_reuse_batches = False 118 | else: 119 | batches = self._prepare_batches() 120 | self._batches = batches 121 | 122 | for batch in iter(batches): 123 | yield batch 124 | 125 | def __len__(self): 126 | if not hasattr(self, "_batches"): 127 | self._batches = self._prepare_batches() 128 | self._can_reuse_batches = True 129 | return len(self._batches) 130 | -------------------------------------------------------------------------------- /lib/data/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | 3 | 4 | def build_transforms(cfg, is_train=True): 5 | height = cfg.INPUT.HEIGHT 6 | width = cfg.INPUT.WIDTH 7 | use_aug = cfg.INPUT.USE_AUG 8 | 9 | normalize_transform = T.Normalize( 10 | mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD 11 | ) 12 | 13 | if is_train: 14 | if use_aug: 15 | transform = T.Compose( 16 | [ 17 | T.Resize((height, width)), 18 | T.RandomHorizontalFlip(0.5), 19 | T.Pad(cfg.INPUT.PADDING), 20 | T.RandomCrop((height, width)), 21 | T.ToTensor(), 22 | normalize_transform, 23 | T.RandomErasing(scale=(0.02, 0.4), value=cfg.INPUT.PIXEL_MEAN), 24 | ] 25 | ) 26 | else: 27 | transform = T.Compose( 28 | [ 29 | T.Resize((height, width)), 30 | T.RandomHorizontalFlip(0.5), 31 | T.ToTensor(), 32 | normalize_transform, 33 | ] 34 | ) 35 | else: 36 | transform = T.Compose( 37 | [ 38 | T.Resize((height, width)), 39 | T.ToTensor(), 40 | normalize_transform, 41 | ] 42 | ) 43 | return transform 44 | -------------------------------------------------------------------------------- /lib/engine/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrandonHanx/TextReID/0d00d8e0844fbd3f322147786affcc19d0e42b68/lib/engine/__init__.py -------------------------------------------------------------------------------- /lib/engine/inference.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | import time 5 | from collections import defaultdict 6 | 7 | import torch 8 | from tqdm import tqdm 9 | 10 | from lib.data.metrics import evaluation 11 | from lib.utils.comm import all_gather, is_main_process, synchronize 12 | 13 | 14 | def compute_on_dataset(model, data_loader, device): 15 | model.eval() 16 | results_dict = defaultdict(list) 17 | for batch in tqdm(data_loader): 18 | images, captions, image_ids = batch 19 | images = images.to(device) 20 | captions = [caption.to(device) for caption in captions] 21 | with torch.no_grad(): 22 | output = model(images, captions) 23 | for result in output: 24 | for img_id, pred in zip(image_ids, result): 25 | results_dict[img_id].append(pred) 26 | return results_dict 27 | 28 | 29 | def _accumulate_predictions_from_multiple_gpus(predictions_per_gpu): 30 | all_predictions = all_gather(predictions_per_gpu) 31 | if not is_main_process(): 32 | return 33 | # merge the list of dicts 34 | predictions = {} 35 | for p in all_predictions: 36 | predictions.update(p) 37 | # convert a dict where the key is the index in a list 38 | image_ids = list(sorted(predictions.keys())) 39 | if len(image_ids) != image_ids[-1] + 1: 40 | logger = logging.getLogger("PersonSearch.inference") 41 | logger.warning( 42 | "Number of images that were gathered from multiple processes is not " 43 | "a contiguous set. 
Some images might be missing from the evaluation" 44 | ) 45 | return predictions 46 | 47 | 48 | def inference( 49 | model, 50 | data_loader, 51 | dataset_name="cuhkpedes-test", 52 | device="cuda", 53 | output_folder="", 54 | save_data=True, 55 | rerank=True, 56 | ): 57 | logger = logging.getLogger("PersonSearch.inference") 58 | dataset = data_loader.dataset 59 | logger.info( 60 | "Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)) 61 | ) 62 | 63 | predictions = None 64 | if not os.path.exists(os.path.join(output_folder, "inference_data.npz")): 65 | # convert to a torch.device for efficiency 66 | device = torch.device(device) 67 | num_devices = ( 68 | torch.distributed.get_world_size() 69 | if torch.distributed.is_initialized() 70 | else 1 71 | ) 72 | start_time = time.time() 73 | 74 | predictions = compute_on_dataset(model, data_loader, device) 75 | # wait for all processes to complete before measuring the time 76 | synchronize() 77 | total_time = time.time() - start_time 78 | total_time_str = str(datetime.timedelta(seconds=total_time)) 79 | logger.info( 80 | "Total inference time: {} ({} s / img per device, on {} devices)".format( 81 | total_time_str, total_time * num_devices / len(dataset), num_devices 82 | ) 83 | ) 84 | predictions = _accumulate_predictions_from_multiple_gpus(predictions) 85 | 86 | if not is_main_process(): 87 | return 88 | 89 | return evaluation( 90 | dataset=dataset, 91 | predictions=predictions, 92 | output_folder=output_folder, 93 | save_data=save_data, 94 | rerank=rerank, 95 | topk=[1, 5, 10], 96 | ) 97 | -------------------------------------------------------------------------------- /lib/engine/trainer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | from lib.utils.comm import get_world_size 9 | 10 | from .inference import inference 11 | 12 | 13 | def reduce_loss_dict(loss_dict): 14 | """ 15 | Reduce the loss dictionary from all processes so that process with rank 16 | 0 has the averaged results. Returns a dict with the same fields as 17 | loss_dict, after reduction. 
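    When running with a single process (world_size < 2), the input dict is
    returned unchanged and no distributed communication happens.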
18 | """ 19 | world_size = get_world_size() 20 | if world_size < 2: 21 | return loss_dict 22 | with torch.no_grad(): 23 | loss_names = [] 24 | all_losses = [] 25 | for k in sorted(loss_dict.keys()): 26 | loss_names.append(k) 27 | all_losses.append(loss_dict[k]) 28 | all_losses = torch.stack(all_losses, dim=0) 29 | dist.reduce(all_losses, dst=0) 30 | if dist.get_rank() == 0: 31 | # only main process gets accumulated, so only divide by 32 | # world_size in this case 33 | all_losses /= world_size 34 | reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} 35 | return reduced_losses 36 | 37 | 38 | def do_train( 39 | model, 40 | data_loader, 41 | data_loader_val, 42 | optimizer, 43 | scheduler, 44 | checkpointer, 45 | meters, 46 | device, 47 | checkpoint_period, 48 | evaluate_period, 49 | arguments, 50 | ): 51 | logger = logging.getLogger("PersonSearch.trainer") 52 | logger.info("Start training") 53 | 54 | max_epoch = arguments["max_epoch"] 55 | epoch = arguments["epoch"] 56 | max_iter = max_epoch * len(data_loader) 57 | iteration = arguments["iteration"] 58 | distributed = arguments["distributed"] 59 | 60 | best_top1 = 0.0 61 | start_training_time = time.time() 62 | end = time.time() 63 | 64 | while epoch < max_epoch: 65 | if distributed: 66 | data_loader.sampler.set_epoch(epoch) 67 | 68 | epoch += 1 69 | model.train() 70 | arguments["epoch"] = epoch 71 | 72 | for step, (images, captions, _) in enumerate(data_loader): 73 | data_time = time.time() - end 74 | inner_iter = step 75 | iteration += 1 76 | arguments["iteration"] = iteration 77 | 78 | images = images.to(device) 79 | captions = [caption.to(device) for caption in captions] 80 | 81 | loss_dict = model(images, captions) 82 | losses = sum(loss for loss in loss_dict.values()) 83 | 84 | # reduce losses over all GPUs for logging purposes 85 | loss_dict_reduced = reduce_loss_dict(loss_dict) 86 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 87 | meters.update(loss=losses_reduced, **loss_dict_reduced) 88 | 89 | optimizer.zero_grad() 90 | losses.backward() 91 | optimizer.step() 92 | 93 | batch_time = time.time() - end 94 | end = time.time() 95 | meters.update(time=batch_time, data=data_time) 96 | 97 | eta_seconds = meters.time.global_avg * (max_iter - iteration) 98 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 99 | 100 | if inner_iter % 1 == 0: 101 | logger.info( 102 | meters.delimiter.join( 103 | [ 104 | "eta: {eta}", 105 | "epoch [{epoch}][{inner_iter}/{num_iter}]", 106 | "{meters}", 107 | "lr: {lr:.6f}", 108 | "max mem: {memory:.0f}", 109 | ] 110 | ).format( 111 | eta=eta_string, 112 | epoch=epoch, 113 | inner_iter=inner_iter, 114 | num_iter=len(data_loader), 115 | meters=str(meters), 116 | lr=optimizer.param_groups[-1]["lr"], 117 | memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, 118 | ) 119 | ) 120 | 121 | scheduler.step() 122 | 123 | if epoch % evaluate_period == 0: 124 | top1 = inference(model, data_loader_val[0], save_data=False, rerank=False) 125 | meters.update(top1=top1) 126 | if top1 > best_top1: 127 | best_top1 = top1 128 | checkpointer.save("best", **arguments) 129 | 130 | if epoch % checkpoint_period == 0: 131 | checkpointer.save("epoch_{:d}".format(epoch), **arguments) 132 | 133 | total_training_time = time.time() - start_training_time 134 | total_time_str = str(datetime.timedelta(seconds=total_training_time)) 135 | logger.info( 136 | "Total training time: {} ({:.4f} s / it)".format( 137 | total_time_str, total_training_time / (max_iter) 138 | ) 139 | ) 140 | 
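In `do_train` above, the raw per-process losses are what gets backpropagated, while the reduced dict only feeds the meters and the log line. Below is a minimal single-process sketch of that split; the loss names and values are made up for illustration (the real keys come from the embedding heads' loss modules in lib/models/embeddings, not shown here).

```python
import torch

# Made-up stand-in for the dict returned by model(images, captions) in do_train.
loss_dict = {
    "loss_a": torch.tensor(2.3, requires_grad=True),
    "loss_b": torch.tensor(1.1, requires_grad=True),
}

# Optimization path: a single scalar, the plain sum of all loss terms.
losses = sum(loss for loss in loss_dict.values())
losses.backward()

# Logging path: with one process, reduce_loss_dict returns the dict unchanged;
# with N processes only rank 0 would hold the per-term averages.
loss_dict_reduced = dict(loss_dict)  # stands in for reduce_loss_dict(loss_dict)
losses_reduced = sum(loss.item() for loss in loss_dict_reduced.values())
print(f"loss: {losses_reduced:.4f}")
```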
-------------------------------------------------------------------------------- /lib/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_textual_model, build_visual_model 2 | 3 | __all__ = ["build_textual_model", "build_visual_model"] 4 | -------------------------------------------------------------------------------- /lib/models/backbones/build.py: -------------------------------------------------------------------------------- 1 | from .gru import build_gru 2 | from .m_resnet import build_m_resnet 3 | from .resnet import build_resnet 4 | 5 | 6 | def build_visual_model(cfg): 7 | if cfg.MODEL.VISUAL_MODEL in ["resnet50", "resnet101"]: 8 | return build_resnet(cfg) 9 | if cfg.MODEL.VISUAL_MODEL in ["m_resnet50", "m_resnet101"]: 10 | return build_m_resnet(cfg) 11 | raise NotImplementedError 12 | 13 | 14 | def build_textual_model(cfg): 15 | if cfg.MODEL.TEXTUAL_MODEL == "bigru": 16 | return build_gru(cfg, bidirectional=True) 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /lib/models/backbones/gru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from lib.utils.directory import load_vocab_dict 5 | 6 | 7 | class GRU(nn.Module): 8 | def __init__( 9 | self, 10 | hidden_dim, 11 | vocab_size, 12 | embed_size, 13 | num_layers, 14 | drop_out, 15 | bidirectional, 16 | use_onehot, 17 | root, 18 | ): 19 | super().__init__() 20 | 21 | self.use_onehot = use_onehot 22 | 23 | # word embedding 24 | if use_onehot == "yes": 25 | self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0) 26 | else: 27 | if vocab_size == embed_size: 28 | self.embed = None 29 | else: 30 | self.embed = nn.Linear(vocab_size, embed_size) 31 | 32 | vocab_dict = load_vocab_dict(root, use_onehot) 33 | assert vocab_size == vocab_dict.shape[1] 34 | self.vocab_dict = torch.tensor(vocab_dict).cuda().float() 35 | 36 | self.gru = nn.GRU( 37 | embed_size, 38 | hidden_dim, 39 | num_layers=num_layers, 40 | dropout=drop_out, 41 | bidirectional=bidirectional, 42 | bias=False, 43 | ) 44 | self.out_channels = hidden_dim * 2 if bidirectional else hidden_dim 45 | 46 | self._init_weight() 47 | 48 | def forward(self, captions): 49 | text = torch.stack([caption.text for caption in captions], dim=1) 50 | text_length = torch.stack([caption.length for caption in captions], dim=1) 51 | 52 | text_length = text_length.view(-1) 53 | text = text.view(-1, text.size(-1)) # b x l 54 | 55 | if not self.use_onehot == "yes": 56 | bs, length = text.shape[0], text.shape[-1] 57 | text = text.view(-1) # bl 58 | text = self.vocab_dict[text].reshape(bs, length, -1) # b x l x vocab_size 59 | if self.embed is not None: 60 | text = self.embed(text) 61 | 62 | gru_out = self.gru_out(text, text_length) 63 | gru_out, _ = torch.max(gru_out, dim=1) 64 | return gru_out 65 | 66 | def gru_out(self, embed, text_length): 67 | 68 | _, idx_sort = torch.sort(text_length, dim=0, descending=True) 69 | _, idx_unsort = torch.sort(idx_sort, dim=0) 70 | 71 | embed_sort = embed.index_select(0, idx_sort) 72 | length_list = text_length[idx_sort] 73 | pack = nn.utils.rnn.pack_padded_sequence( 74 | embed_sort, length_list.cpu(), batch_first=True 75 | ) 76 | 77 | gru_sort_out, _ = self.gru(pack) 78 | gru_sort_out = nn.utils.rnn.pad_packed_sequence(gru_sort_out, batch_first=True) 79 | gru_sort_out = gru_sort_out[0] 80 | 81 | gru_out = 
gru_sort_out.index_select(0, idx_unsort) 82 | return gru_out 83 | 84 | def _init_weight(self): 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | nn.init.xavier_uniform_(m.weight.data, 1) 88 | nn.init.constant(m.bias.data, 0) 89 | 90 | 91 | def build_gru(cfg, bidirectional): 92 | use_onehot = cfg.MODEL.GRU.ONEHOT 93 | hidden_dim = cfg.MODEL.GRU.NUM_UNITS 94 | vocab_size = cfg.MODEL.GRU.VOCABULARY_SIZE 95 | embed_size = cfg.MODEL.GRU.EMBEDDING_SIZE 96 | num_layer = cfg.MODEL.GRU.NUM_LAYER 97 | drop_out = 1 - cfg.MODEL.GRU.DROPOUT_KEEP_PROB 98 | root = cfg.ROOT 99 | 100 | model = GRU( 101 | hidden_dim, 102 | vocab_size, 103 | embed_size, 104 | num_layer, 105 | drop_out, 106 | bidirectional, 107 | use_onehot, 108 | root, 109 | ) 110 | 111 | if cfg.MODEL.FREEZE: 112 | for m in [model.embed, model.gru]: 113 | m.eval() 114 | for param in m.parameters(): 115 | param.requires_grad = False 116 | 117 | return model 118 | -------------------------------------------------------------------------------- /lib/models/backbones/m_resnet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | 11 | class Bottleneck(nn.Module): 12 | expansion = 4 13 | 14 | def __init__(self, inplanes, planes, stride=1): 15 | super().__init__() 16 | 17 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 18 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 25 | 26 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 27 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 28 | 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = None 31 | self.stride = stride 32 | 33 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 34 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 35 | self.downsample = nn.Sequential( 36 | OrderedDict( 37 | [ 38 | ("-1", nn.AvgPool2d(stride)), 39 | ( 40 | "0", 41 | nn.Conv2d( 42 | inplanes, 43 | planes * self.expansion, 44 | 1, 45 | stride=1, 46 | bias=False, 47 | ), 48 | ), 49 | ("1", nn.BatchNorm2d(planes * self.expansion)), 50 | ] 51 | ) 52 | ) 53 | 54 | def forward(self, x): 55 | identity = x 56 | 57 | out = self.relu(self.bn1(self.conv1(x))) 58 | out = self.relu(self.bn2(self.conv2(out))) 59 | out = self.avgpool(out) 60 | out = self.bn3(self.conv3(out)) 61 | 62 | if self.downsample is not None: 63 | identity = self.downsample(x) 64 | 65 | out += identity 66 | out = self.relu(out) 67 | return out 68 | 69 | 70 | class AttentionPool2d(nn.Module): 71 | def __init__( 72 | self, 73 | spacial_dim, 74 | embed_dim, 75 | num_heads, 76 | output_dim=None, 77 | patch_size=1, 78 | ): 79 | super().__init__() 80 | self.spacial_dim = spacial_dim 81 | self.proj_conv = None 82 | if patch_size > 1: 83 | self.proj_conv = nn.Conv2d( 84 | embed_dim, 85 | embed_dim, 86 | kernel_size=patch_size, 87 | stride=patch_size, 88 | bias=False, 89 | ) 90 | self.positional_embedding = nn.Parameter( 91 | torch.randn( 92 | (spacial_dim[0] // patch_size) * (spacial_dim[1] // patch_size) + 1, 93 | embed_dim, 94 | ) 95 | / embed_dim ** 0.5 96 | ) 
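        # One positional-embedding row per spatial location of the (optionally
        # patch-merged) feature map, plus one extra row for the global mean
        # token that forward() prepends before the attention pooling.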
97 | self.k_proj = nn.Linear(embed_dim, embed_dim) 98 | self.q_proj = nn.Linear(embed_dim, embed_dim) 99 | self.v_proj = nn.Linear(embed_dim, embed_dim) 100 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 101 | self.num_heads = num_heads 102 | 103 | def forward(self, x): 104 | if self.proj_conv is not None: 105 | x = self.proj_conv(x) 106 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( 107 | 2, 0, 1 108 | ) # NCHW -> (HW)NC 109 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 110 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 111 | x, _ = F.multi_head_attention_forward( 112 | query=x, 113 | key=x, 114 | value=x, 115 | embed_dim_to_check=x.shape[-1], 116 | num_heads=self.num_heads, 117 | q_proj_weight=self.q_proj.weight, 118 | k_proj_weight=self.k_proj.weight, 119 | v_proj_weight=self.v_proj.weight, 120 | in_proj_weight=None, 121 | in_proj_bias=torch.cat( 122 | [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] 123 | ), 124 | bias_k=None, 125 | bias_v=None, 126 | add_zero_attn=False, 127 | dropout_p=0, 128 | out_proj_weight=self.c_proj.weight, 129 | out_proj_bias=self.c_proj.bias, 130 | use_separate_proj_weight=True, 131 | training=self.training, 132 | need_weights=False, 133 | ) 134 | 135 | return x[0] 136 | 137 | 138 | class ModifiedResNet(nn.Module): 139 | """ 140 | A ResNet class that is similar to torchvision's but contains the following changes: 141 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 142 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 143 | - The final pooling layer is a QKV attention instead of an average pool 144 | """ 145 | 146 | def __init__( 147 | self, 148 | layers, 149 | output_dim, 150 | heads, 151 | last_stride=1, 152 | input_resolution=(224, 224), 153 | width=64, 154 | ): 155 | super().__init__() 156 | self.output_dim = output_dim 157 | self.out_channels = output_dim 158 | self.input_resolution = input_resolution 159 | 160 | # the 3-layer stem 161 | self.conv1 = nn.Conv2d( 162 | 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False 163 | ) 164 | self.bn1 = nn.BatchNorm2d(width // 2) 165 | self.conv2 = nn.Conv2d( 166 | width // 2, width // 2, kernel_size=3, padding=1, bias=False 167 | ) 168 | self.bn2 = nn.BatchNorm2d(width // 2) 169 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 170 | self.bn3 = nn.BatchNorm2d(width) 171 | self.avgpool = nn.AvgPool2d(2) 172 | self.relu = nn.ReLU(inplace=True) 173 | 174 | # residual layers 175 | self._inplanes = width # this is a *mutable* variable used during construction 176 | self.layer1 = self._make_layer(width, layers[0]) 177 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 178 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 179 | self.layer4 = self._make_layer(width * 8, layers[3], stride=last_stride) 180 | 181 | embed_dim = width * 32 # the ResNet feature dimension 182 | down_ratio = 16 if last_stride == 1 else 32 183 | spacial_dim = ( 184 | input_resolution[0] // down_ratio, 185 | input_resolution[1] // down_ratio, 186 | ) 187 | self.attnpool = AttentionPool2d(spacial_dim, embed_dim, heads, output_dim) 188 | 189 | def _make_layer(self, planes, blocks, stride=1): 190 | layers = [Bottleneck(self._inplanes, planes, stride)] 191 | 192 | self._inplanes = planes * Bottleneck.expansion 193 | for _ in range(1, blocks): 194 | 
layers.append(Bottleneck(self._inplanes, planes)) 195 | 196 | return nn.Sequential(*layers) 197 | 198 | def forward(self, x): 199 | def stem(x): 200 | for conv, bn in [ 201 | (self.conv1, self.bn1), 202 | (self.conv2, self.bn2), 203 | (self.conv3, self.bn3), 204 | ]: 205 | x = self.relu(bn(conv(x))) 206 | x = self.avgpool(x) 207 | return x 208 | 209 | x = x.type(self.conv1.weight.dtype) 210 | x = stem(x) 211 | x = self.layer1(x) 212 | x = self.layer2(x) 213 | x = self.layer3(x) 214 | x = self.layer4(x) 215 | x = self.attnpool(x) 216 | 217 | return x 218 | 219 | 220 | def resize_pos_embed(posemb, gs_new): 221 | # Rescale the grid of position embeddings when loading from state_dict. 222 | logger = logging.getLogger("PersonSearch.train") 223 | posemb_tok, posemb_grid = posemb[:1], posemb[1:] 224 | gs_old = int(math.sqrt(len(posemb_grid))) 225 | logger.info("Resized position embedding: {} to {}".format((gs_old, gs_old), gs_new)) 226 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 227 | posemb_grid = F.interpolate( 228 | posemb_grid, size=gs_new, mode="bilinear", align_corners=False 229 | ) 230 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(gs_new[0] * gs_new[1], -1) 231 | posemb = torch.cat([posemb_tok, posemb_grid], dim=0) 232 | return posemb 233 | 234 | 235 | def state_filter(state_dict, final_stage_resolution): 236 | out_dict = {} 237 | for k, v in state_dict.items(): 238 | if k.startswith("visual."): 239 | k = k[7:] 240 | if k == "attnpool.positional_embedding" and final_stage_resolution != (7, 7): 241 | v = resize_pos_embed(v, final_stage_resolution) 242 | out_dict[k] = v 243 | return out_dict 244 | 245 | 246 | def modified_resnet50( 247 | input_resolution, 248 | last_stride, 249 | pretrained_path=None, 250 | ): 251 | model = ModifiedResNet( 252 | layers=[3, 4, 6, 3], 253 | output_dim=1024, 254 | heads=32, 255 | last_stride=last_stride, 256 | input_resolution=input_resolution, 257 | ) 258 | if pretrained_path: 259 | p = torch.jit.load(pretrained_path).state_dict() 260 | model.load_state_dict( 261 | state_filter( 262 | p, 263 | final_stage_resolution=model.attnpool.spacial_dim, 264 | ), 265 | strict=False, 266 | ) 267 | return model 268 | 269 | 270 | def modified_resnet101( 271 | input_resolution, 272 | last_stride, 273 | pretrained_path=None, 274 | ): 275 | model = ModifiedResNet( 276 | layers=[3, 4, 23, 3], 277 | output_dim=512, 278 | heads=32, 279 | last_stride=last_stride, 280 | input_resolution=input_resolution, 281 | ) 282 | if pretrained_path: 283 | p = torch.jit.load(pretrained_path).state_dict() 284 | model.load_state_dict( 285 | state_filter( 286 | p, 287 | final_stage_resolution=model.attnpool.spacial_dim, 288 | ), 289 | strict=False, 290 | ) 291 | return model 292 | 293 | 294 | def build_m_resnet(cfg): 295 | if cfg.MODEL.VISUAL_MODEL in ["m_resnet50", "m_resnet"]: 296 | model = modified_resnet50( 297 | (cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH), 298 | cfg.MODEL.RESNET.RES5_STRIDE, 299 | pretrained_path=os.path.join(cfg.ROOT, "pretrained/clip/RN50.pt"), 300 | ) 301 | elif cfg.MODEL.VISUAL_MODEL == "m_resnet101": 302 | model = modified_resnet101( 303 | (cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH), 304 | cfg.MODEL.RESNET.RES5_STRIDE, 305 | pretrained_path=os.path.join(cfg.ROOT, "pretrained/clip/RN101.pt"), 306 | ) 307 | return model 308 | -------------------------------------------------------------------------------- /lib/models/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | from collections import 
namedtuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 9 | """3x3 convolution with padding""" 10 | # original padding is 1; original dilation is 1 11 | return nn.Conv2d( 12 | in_planes, 13 | out_planes, 14 | kernel_size=3, 15 | stride=stride, 16 | padding=dilation, 17 | bias=False, 18 | dilation=dilation, 19 | ) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | # original padding is 1; original dilation is 1 62 | self.conv2 = nn.Conv2d( 63 | planes, 64 | planes, 65 | kernel_size=3, 66 | stride=stride, 67 | padding=dilation, 68 | bias=False, 69 | dilation=dilation, 70 | ) 71 | self.bn2 = nn.BatchNorm2d(planes) 72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 73 | self.bn3 = nn.BatchNorm2d(planes * 4) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | def __init__( 103 | self, 104 | model_arch, 105 | res5_stride=2, 106 | res5_dilation=1, 107 | pretrained=True, 108 | ): 109 | super().__init__() 110 | block = model_arch.block 111 | layers = model_arch.stage 112 | 113 | self.inplanes = 64 114 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 115 | self.bn1 = nn.BatchNorm2d(64) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 118 | self.layer1 = self._make_layer(block, 64, layers[0]) 119 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 120 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 121 | self.layer4 = self._make_layer( 122 | block, 512, layers[3], stride=res5_stride, dilation=res5_dilation 123 | ) 124 | 125 | if pretrained is None: 126 | self.load_state_dict(remove_fc(model_zoo.load_url(model_arch.url))) 127 | else: 128 | self.load_state_dict(torch.load(pretrained)) 129 | 130 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 131 | self.out_channels = 512 * block.expansion 132 
| 133 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 134 | downsample = None 135 | if stride != 1 or self.inplanes != planes * block.expansion: 136 | downsample = nn.Sequential( 137 | nn.Conv2d( 138 | self.inplanes, 139 | planes * block.expansion, 140 | kernel_size=1, 141 | stride=stride, 142 | bias=False, 143 | ), 144 | nn.BatchNorm2d(planes * block.expansion), 145 | ) 146 | 147 | layers = [] 148 | layers.append(block(self.inplanes, planes, stride, downsample, dilation)) 149 | self.inplanes = planes * block.expansion 150 | for i in range(1, blocks): 151 | layers.append(block(self.inplanes, planes)) 152 | 153 | return nn.Sequential(*layers) 154 | 155 | def forward(self, x): 156 | x = self.conv1(x) 157 | x = self.bn1(x) 158 | x = self.relu(x) 159 | x = self.maxpool(x) 160 | 161 | x = self.layer1(x) 162 | x = self.layer2(x) 163 | x = self.layer3(x) 164 | x = self.layer4(x) 165 | x = self.avgpool(x) 166 | 167 | return x 168 | 169 | def _init_weight(self): 170 | for m in self.modules(): 171 | if isinstance(m, nn.Conv2d): 172 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 173 | elif isinstance(m, nn.BatchNorm2d): 174 | nn.init.constant_(m.weight, 1) 175 | nn.init.constant_(m.bias, 0) 176 | 177 | 178 | def remove_fc(state_dict): 179 | """Remove the fc layer parameters from state_dict.""" 180 | for key in list(state_dict.keys()): 181 | if key.startswith("fc."): 182 | del state_dict[key] 183 | return state_dict 184 | 185 | 186 | resnet = namedtuple("resnet", ["block", "stage", "url"]) 187 | model_archs = {} 188 | model_archs["resnet18"] = resnet( 189 | BasicBlock, 190 | [2, 2, 2, 2], 191 | "https://download.pytorch.org/models/resnet18-5c106cde.pth", 192 | ) 193 | model_archs["resnet34"] = resnet( 194 | BasicBlock, 195 | [3, 4, 6, 3], 196 | "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 197 | ) 198 | model_archs["resnet50"] = resnet( 199 | Bottleneck, 200 | [3, 4, 6, 3], 201 | "https://download.pytorch.org/models/resnet50-19c8e357.pth", 202 | ) 203 | model_archs["resnet101"] = resnet( 204 | Bottleneck, 205 | [3, 4, 23, 3], 206 | "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 207 | ) 208 | model_archs["resnet152"] = resnet( 209 | Bottleneck, 210 | [3, 8, 36, 3], 211 | "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 212 | ) 213 | 214 | 215 | def build_resnet(cfg): 216 | arch = cfg.MODEL.VISUAL_MODEL 217 | res5_stride = cfg.MODEL.RESNET.RES5_STRIDE 218 | res5_dilation = cfg.MODEL.RESNET.RES5_DILATION 219 | pretrained = cfg.MODEL.RESNET.PRETRAINED 220 | 221 | model_arch = model_archs[arch] 222 | model = ResNet( 223 | model_arch, 224 | res5_stride, 225 | res5_dilation, 226 | pretrained=pretrained, 227 | ) 228 | 229 | if cfg.MODEL.FREEZE: 230 | for m in [model.conv1, model.bn1, model.layer1, model.layer2, model.layer3]: 231 | m.eval() 232 | for param in m.parameters(): 233 | param.requires_grad = False 234 | 235 | return model 236 | -------------------------------------------------------------------------------- /lib/models/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_embed 2 | 3 | __all__ = ["build_embed"] 4 | -------------------------------------------------------------------------------- /lib/models/embeddings/build.py: -------------------------------------------------------------------------------- 1 | from .simple_head.head import build_simple_head 2 | 3 | 4 | def build_embed(cfg, visual_out_channels, 
textual_out_channels): 5 | 6 | if cfg.MODEL.EMBEDDING.EMBED_HEAD == "simple": 7 | return build_simple_head(cfg, visual_out_channels, textual_out_channels) 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /lib/models/embeddings/moco_head/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrandonHanx/TextReID/0d00d8e0844fbd3f322147786affcc19d0e42b68/lib/models/embeddings/moco_head/__init__.py -------------------------------------------------------------------------------- /lib/models/embeddings/moco_head/head.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .loss import make_loss_evaluator 8 | 9 | 10 | class MoCoHead(nn.Module): 11 | def __init__( 12 | self, 13 | cfg, 14 | visual_model, 15 | textual_model, 16 | ): 17 | super().__init__() 18 | self.embed_size = cfg.MODEL.EMBEDDING.FEATURE_SIZE 19 | self.K = cfg.MODEL.MOCO.K 20 | self.m = cfg.MODEL.MOCO.M 21 | self.fc = cfg.MODEL.MOCO.FC 22 | 23 | self.v_encoder_q = visual_model 24 | self.t_encoder_q = textual_model 25 | self.v_encoder_k = copy.deepcopy(visual_model) 26 | self.t_encoder_k = copy.deepcopy(textual_model) 27 | for param in self.v_encoder_k.parameters(): 28 | param.requires_grad = False 29 | for param in self.t_encoder_k.parameters(): 30 | param.requires_grad = False 31 | 32 | if self.fc: 33 | self.v_fc_q = nn.Sequential( 34 | nn.Linear(visual_model.out_channels, self.embed_size), 35 | nn.ReLU(), 36 | nn.Linear(self.embed_size, self.embed_size), 37 | ) 38 | self.t_fc_q = nn.Sequential( 39 | nn.Linear(textual_model.out_channels, self.embed_size), 40 | nn.ReLU(), 41 | nn.Linear(self.embed_size, self.embed_size), 42 | ) 43 | self.v_fc_k = copy.deepcopy(self.v_fc_q) 44 | self.t_fc_k = copy.deepcopy(self.t_fc_q) 45 | for param in self.v_fc_k.parameters(): 46 | param.requires_grad = False 47 | for param in self.t_fc_k.parameters(): 48 | param.requires_grad = False 49 | 50 | self.v_embed_layer = nn.Linear(visual_model.out_channels, self.embed_size) 51 | self.t_embed_layer = nn.Linear(textual_model.out_channels, self.embed_size) 52 | 53 | self.register_buffer("t_queue", torch.rand(self.embed_size, self.K)) 54 | self.t_queue = F.normalize(self.t_queue, dim=0) 55 | self.register_buffer("v_queue", torch.rand(self.embed_size, self.K)) 56 | self.v_queue = F.normalize(self.v_queue, dim=0) 57 | # initialize id label as -1 58 | self.register_buffer("id_queue", -torch.ones((1, self.K), dtype=torch.long)) 59 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 60 | 61 | self.loss_evaluator = make_loss_evaluator(cfg) 62 | self._init_weight() 63 | 64 | def _init_weight(self): 65 | for m in self.modules(): 66 | if isinstance(m, nn.Linear): 67 | nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out") 68 | nn.init.constant_(m.bias, 0) 69 | elif isinstance(m, nn.BatchNorm1d): 70 | nn.init.constant_(m.weight, 1) 71 | nn.init.constant_(m.bias, 0) 72 | 73 | @torch.no_grad() 74 | def _momentum_update_key_encoder(self): 75 | """ 76 | Momentum update of the key encoder 77 | """ 78 | for param_q, param_k in zip( 79 | self.v_encoder_q.parameters(), self.v_encoder_k.parameters() 80 | ): 81 | param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) 82 | for param_q, param_k in zip( 83 | self.t_encoder_q.parameters(), 
self.t_encoder_k.parameters() 84 | ): 85 | param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) 86 | if self.fc: 87 | for param_q, param_k in zip( 88 | self.v_fc_q.parameters(), self.v_fc_k.parameters() 89 | ): 90 | param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) 91 | for param_q, param_k in zip( 92 | self.t_fc_q.parameters(), self.t_fc_k.parameters() 93 | ): 94 | param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) 95 | 96 | @torch.no_grad() 97 | def _dequeue_and_enqueue(self, v_keys, t_keys, id_keys): 98 | batch_size = v_keys.shape[0] 99 | 100 | ptr = int(self.queue_ptr) 101 | assert self.K % batch_size == 0 # for simplicity 102 | 103 | # replace the keys at ptr (dequeue and enqueue) 104 | self.v_queue[:, ptr : ptr + batch_size] = v_keys.T 105 | self.t_queue[:, ptr : ptr + batch_size] = t_keys.T 106 | self.id_queue[:, ptr : ptr + batch_size] = id_keys.T 107 | 108 | ptr = (ptr + batch_size) % self.K # move pointer 109 | self.queue_ptr[0] = ptr 110 | 111 | def forward(self, images, captions): 112 | N = images.shape[0] 113 | 114 | v_embed = self.v_encoder_q(images) 115 | t_embed = self.t_encoder_q(captions) 116 | 117 | if self.training: 118 | if self.fc: 119 | v_embed_q = self.v_fc_q(v_embed) 120 | t_embed_q = self.t_fc_q(t_embed) 121 | v_embed = self.v_embed_layer(v_embed) 122 | t_embed = self.t_embed_layer(t_embed) 123 | v_embed_q = F.normalize(v_embed_q, dim=1) 124 | t_embed_q = F.normalize(t_embed_q, dim=1) 125 | else: 126 | v_embed = self.v_embed_layer(v_embed) 127 | t_embed = self.t_embed_layer(t_embed) 128 | v_embed_q = F.normalize(v_embed, dim=1) 129 | t_embed_q = F.normalize(t_embed, dim=1) 130 | id_q = torch.stack([caption.get_field("id") for caption in captions]).long() 131 | 132 | with torch.no_grad(): 133 | self._momentum_update_key_encoder() 134 | v_embed_k = self.v_encoder_k(images) 135 | if self.fc: 136 | v_embed_k = self.v_fc_k(v_embed_k) 137 | else: 138 | v_embed_k = self.v_embed_layer(v_embed_k) 139 | v_embed_k = F.normalize(v_embed_k, dim=1) 140 | t_embed_k = self.t_encoder_k(captions) 141 | if self.fc: 142 | t_embed_k = self.t_fc_k(t_embed_k) 143 | else: 144 | t_embed_k = self.t_embed_layer(t_embed_k) 145 | t_embed_k = F.normalize(t_embed_k, dim=1) 146 | 147 | # regard same instance ids as positive sapmles, we need filter them out 148 | pos_idx = ( 149 | self.id_queue.expand(N, self.K) 150 | .eq(id_q.unsqueeze(-1)) 151 | .nonzero(as_tuple=False)[:, 1] 152 | ) 153 | unique, counts = torch.unique( 154 | torch.cat([torch.arange(self.K).long().cuda(), pos_idx]), 155 | return_counts=True, 156 | ) 157 | neg_idx = unique[counts == 1] 158 | 159 | # v positive logits: Nx1 160 | v_pos = torch.einsum("nc,nc->n", [v_embed_q, t_embed_k]).unsqueeze(-1) 161 | # v negative logits: NxK 162 | t_queue = self.t_queue.clone().detach() 163 | t_queue = t_queue[:, neg_idx] 164 | v_neg = torch.einsum("nc,ck->nk", [v_embed_q, t_queue]) 165 | # t positive logits: Nx1 166 | t_pos = torch.einsum("nc,nc->n", [t_embed_q, v_embed_k]).unsqueeze(-1) 167 | # t negative logits: NxK 168 | v_queue = self.v_queue.clone().detach() 169 | v_queue = v_queue[:, neg_idx] 170 | t_neg = torch.einsum("nc,ck->nk", [t_embed_q, v_queue]) 171 | 172 | losses = self.loss_evaluator( 173 | v_embed, t_embed, v_pos, v_neg, t_pos, t_neg, id_q 174 | ) 175 | self._dequeue_and_enqueue(v_embed_k, t_embed_k, id_q) 176 | return losses 177 | 178 | v_embed = self.v_embed_layer(v_embed) 179 | t_embed = self.t_embed_layer(t_embed) 180 | outputs = list() 181 | outputs.append(v_embed) 
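# A toy sketch of how the MoCo-style logits assembled in forward() above are
# shaped; N, C and K below are made-up sizes, not values from the config.
import torch
import torch.nn.functional as F

N, C, K = 4, 8, 16
v_embed_q = F.normalize(torch.randn(N, C), dim=1)   # image queries
t_embed_k = F.normalize(torch.randn(N, C), dim=1)   # paired text keys
t_queue = F.normalize(torch.randn(C, K), dim=0)     # text memory queue (negatives)

v_pos = torch.einsum("nc,nc->n", [v_embed_q, t_embed_k]).unsqueeze(-1)  # N x 1
v_neg = torch.einsum("nc,ck->nk", [v_embed_q, t_queue])                 # N x K
logits = torch.cat([v_pos, v_neg], dim=1) / 0.07                        # N x (1 + K)
# The positive similarity always sits at column 0, which is why infonce_loss
# below pairs these logits with an all-zero label vector for cross-entropy.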
182 | outputs.append(t_embed) 183 | return outputs 184 | 185 | 186 | def build_moco_head(cfg, visual_model, textual_model): 187 | return MoCoHead(cfg, visual_model, textual_model) 188 | -------------------------------------------------------------------------------- /lib/models/embeddings/moco_head/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | import lib.models.losses as losses 6 | 7 | 8 | class LossComputation(nn.Module): 9 | def __init__(self, cfg): 10 | super().__init__() 11 | 12 | self.projection = Parameter( 13 | torch.randn(cfg.MODEL.EMBEDDING.FEATURE_SIZE, cfg.MODEL.NUM_CLASSES), 14 | requires_grad=True, 15 | ) 16 | self.epsilon = cfg.MODEL.EMBEDDING.EPSILON 17 | # self.T = Parameter(torch.tensor(0.07), requires_grad=True) 18 | self.T = 0.07 19 | nn.init.xavier_uniform_(self.projection.data, gain=1) 20 | 21 | def forward(self, v_embed, t_embed, v_pos, v_neg, t_pos, t_neg, labels): 22 | loss = { 23 | "instance_loss": losses.instance_loss( 24 | self.projection, 25 | v_embed, 26 | t_embed, 27 | labels, 28 | epsilon=self.epsilon, 29 | ), 30 | "infonce_loss": losses.infonce_loss( 31 | v_pos, 32 | v_neg, 33 | t_pos, 34 | t_neg, 35 | self.T, 36 | ), 37 | "global_align_loss": losses.global_align_loss(v_embed, t_embed, labels), 38 | } 39 | return loss 40 | 41 | 42 | def make_loss_evaluator(cfg): 43 | return LossComputation(cfg) 44 | -------------------------------------------------------------------------------- /lib/models/embeddings/simple_head/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrandonHanx/TextReID/0d00d8e0844fbd3f322147786affcc19d0e42b68/lib/models/embeddings/simple_head/__init__.py -------------------------------------------------------------------------------- /lib/models/embeddings/simple_head/head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .loss import make_loss_evaluator 4 | 5 | 6 | class SimpleHead(nn.Module): 7 | def __init__( 8 | self, 9 | cfg, 10 | visual_size, 11 | textual_size, 12 | ): 13 | super().__init__() 14 | self.embed_size = cfg.MODEL.EMBEDDING.FEATURE_SIZE 15 | 16 | self.visual_embed_layer = nn.Linear(visual_size, self.embed_size) 17 | self.textual_embed_layer = nn.Linear(textual_size, self.embed_size) 18 | 19 | self.loss_evaluator = make_loss_evaluator(cfg) 20 | self._init_weight() 21 | 22 | def _init_weight(self): 23 | for m in self.modules(): 24 | if isinstance(m, nn.Linear): 25 | nn.init.kaiming_normal_(m.weight, a=0, mode="fan_out") 26 | nn.init.constant_(m.bias, 0) 27 | elif isinstance(m, nn.BatchNorm1d): 28 | nn.init.constant_(m.weight, 1) 29 | nn.init.constant_(m.bias, 0) 30 | 31 | def forward(self, visual_feature, textual_feature, captions): 32 | batch_size = visual_feature.size(0) 33 | 34 | visual_embed = visual_feature.view(batch_size, -1) 35 | textual_embed = textual_feature.view(batch_size, -1) 36 | 37 | visual_embed = self.visual_embed_layer(visual_embed) 38 | textual_embed = self.textual_embed_layer(textual_embed) 39 | 40 | if self.training: 41 | losses = self.loss_evaluator(visual_embed, textual_embed, captions) 42 | return None, losses 43 | 44 | outputs = list() 45 | outputs.append(visual_embed) 46 | outputs.append(textual_embed) 47 | return outputs, None 48 | 49 | 50 | def build_simple_head(cfg, visual_size, textual_size): 51 | model = 
SimpleHead(cfg, visual_size, textual_size) 52 | return model 53 | -------------------------------------------------------------------------------- /lib/models/embeddings/simple_head/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | import lib.models.losses as losses 6 | 7 | 8 | class LossComputation(nn.Module): 9 | def __init__(self, cfg): 10 | super().__init__() 11 | self.epsilon = cfg.MODEL.EMBEDDING.EPSILON 12 | self.scale_pos = 10.0 13 | self.scale_neg = 40.0 14 | 15 | self.projection = Parameter( 16 | torch.randn(cfg.MODEL.EMBEDDING.FEATURE_SIZE, cfg.MODEL.NUM_CLASSES), 17 | requires_grad=True, 18 | ) 19 | nn.init.xavier_uniform_(self.projection.data, gain=1) 20 | 21 | def forward( 22 | self, 23 | visual_embed, 24 | textual_embed, 25 | captions, 26 | ): 27 | labels = torch.stack([caption.get_field("id") for caption in captions]).long() 28 | loss = { 29 | "instance_loss": losses.instance_loss( 30 | self.projection, 31 | visual_embed, 32 | textual_embed, 33 | labels, 34 | epsilon=self.epsilon, 35 | ), 36 | "global_align_loss": losses.global_align_loss( 37 | visual_embed, 38 | textual_embed, 39 | labels, 40 | scale_pos=self.scale_pos, 41 | scale_neg=self.scale_neg, 42 | ), 43 | } 44 | return loss 45 | 46 | 47 | def make_loss_evaluator(cfg): 48 | return LossComputation(cfg) 49 | -------------------------------------------------------------------------------- /lib/models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CrossEntropyLabelSmooth(nn.Module): 7 | """Cross entropy loss with label smoothing regularizer. 8 | 9 | Reference: 10 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 11 | Equation: y = (1 - epsilon) * y + epsilon / K. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | epsilon (float): weight. 
16 | """ 17 | 18 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 19 | super().__init__() 20 | self.num_classes = num_classes 21 | self.epsilon = epsilon 22 | self.use_gpu = use_gpu 23 | self.logsoftmax = nn.LogSoftmax(dim=1) 24 | 25 | def forward(self, inputs, targets): 26 | """ 27 | Args: 28 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 29 | targets: ground truth labels with shape (num_classes) 30 | """ 31 | log_probs = self.logsoftmax(inputs) 32 | targets = torch.zeros(log_probs.size()).scatter_( 33 | 1, targets.unsqueeze(1).data.cpu(), 1 34 | ) 35 | if self.use_gpu: 36 | targets = targets.cuda() 37 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 38 | loss = (-targets * log_probs).mean(0).sum() 39 | return loss 40 | 41 | 42 | def instance_loss( 43 | projection, visual_embed, textual_embed, labels, scale=1, norm=False, epsilon=0.0 44 | ): 45 | if norm: 46 | visual_norm = F.normalize(visual_embed, p=2, dim=-1) 47 | textual_norm = F.normalize(textual_embed, p=2, dim=-1) 48 | else: 49 | visual_norm = visual_embed 50 | textual_norm = textual_embed 51 | projection_norm = F.normalize(projection, p=2, dim=0) 52 | 53 | visual_logits = scale * torch.matmul(visual_norm, projection_norm) 54 | textual_logits = scale * torch.matmul(textual_norm, projection_norm) 55 | 56 | if epsilon > 0: 57 | criterion = CrossEntropyLabelSmooth(num_classes=projection_norm.shape[1]) 58 | else: 59 | criterion = nn.CrossEntropyLoss(reduction="mean") 60 | loss = criterion(visual_logits, labels) + criterion(textual_logits, labels) 61 | 62 | return loss 63 | 64 | 65 | def cmpc_loss(projection, visual_embed, textual_embed, labels, verbose=False): 66 | """ 67 | Cross-Modal Projection Classfication loss (CMPC) 68 | :param image_embeddings: Tensor with dtype torch.float32 69 | :param text_embeddings: Tensor with dtype torch.float32 70 | :param labels: Tensor with dtype torch.int32 71 | :return: 72 | """ 73 | visual_norm = F.normalize(visual_embed, p=2, dim=1) 74 | textual_norm = F.normalize(textual_embed, p=2, dim=1) 75 | projection_norm = F.normalize(projection, p=2, dim=0) 76 | 77 | image_proj_text = ( 78 | torch.sum(visual_embed * textual_norm, dim=1, keepdim=True) * textual_norm 79 | ) 80 | text_proj_image = ( 81 | torch.sum(textual_embed * visual_norm, dim=1, keepdim=True) * visual_norm 82 | ) 83 | 84 | image_logits = torch.matmul(image_proj_text, projection_norm) 85 | text_logits = torch.matmul(text_proj_image, projection_norm) 86 | 87 | criterion = nn.CrossEntropyLoss(reduction="mean") 88 | loss = criterion(image_logits, labels) + criterion(text_logits, labels) 89 | 90 | # classification accuracy for observation 91 | if verbose: 92 | image_pred = torch.argmax(image_logits, dim=1) 93 | text_pred = torch.argmax(text_logits, dim=1) 94 | 95 | image_precision = torch.mean((image_pred == labels).float()) 96 | text_precision = torch.mean((text_pred == labels).float()) 97 | 98 | return loss, image_precision, text_precision 99 | return loss 100 | 101 | 102 | def global_align_loss( 103 | visual_embed, 104 | textual_embed, 105 | labels, 106 | alpha=0.6, 107 | beta=0.4, 108 | scale_pos=10, 109 | scale_neg=40, 110 | ): 111 | batch_size = labels.size(0) 112 | visual_norm = F.normalize(visual_embed, p=2, dim=1) 113 | textual_norm = F.normalize(textual_embed, p=2, dim=1) 114 | similarity = torch.matmul(visual_norm, textual_norm.t()) 115 | labels_ = ( 116 | labels.expand(batch_size, batch_size) 117 | .eq(labels.expand(batch_size, batch_size).t()) 118 | .float() 
119 | ) 120 | 121 | pos_inds = labels_ == 1 122 | neg_inds = labels_ == 0 123 | loss_pos = torch.log(1 + torch.exp(-scale_pos * (similarity[pos_inds] - alpha))) 124 | loss_neg = torch.log(1 + torch.exp(scale_neg * (similarity[neg_inds] - beta))) 125 | loss = (loss_pos.sum() + loss_neg.sum()) * 2.0 126 | 127 | loss /= batch_size 128 | return loss 129 | 130 | 131 | def global_align_loss_from_sim( 132 | similarity, 133 | labels, 134 | alpha=0.6, 135 | beta=0.4, 136 | scale_pos=10, 137 | scale_neg=40, 138 | ): 139 | batch_size = labels.size(0) 140 | labels_ = ( 141 | labels.expand(batch_size, batch_size) 142 | .eq(labels.expand(batch_size, batch_size).t()) 143 | .float() 144 | ) 145 | 146 | pos_inds = labels_ == 1 147 | neg_inds = labels_ == 0 148 | loss_pos = torch.log(1 + torch.exp(-scale_pos * (similarity[pos_inds] - alpha))) 149 | loss_neg = torch.log(1 + torch.exp(scale_neg * (similarity[neg_inds] - beta))) 150 | loss = (loss_pos.sum() + loss_neg.sum()) * 2.0 151 | 152 | loss /= batch_size 153 | return loss 154 | 155 | 156 | def cmpm_loss(visual_embed, textual_embed, labels, verbose=False, epsilon=1e-8): 157 | """ 158 | Cross-Modal Projection Matching Loss(CMPM) 159 | :param image_embeddings: Tensor with dtype torch.float32 160 | :param text_embeddings: Tensor with dtype torch.float32 161 | :param labels: Tensor with dtype torch.int32 162 | :return: 163 | i2t_loss: cmpm loss for image projected to text 164 | t2i_loss: cmpm loss for text projected to image 165 | pos_avg_sim: average cosine-similarity for positive pairs 166 | neg_avg_sim: averate cosine-similarity for negative pairs 167 | """ 168 | 169 | batch_size = visual_embed.shape[0] 170 | labels_reshape = torch.reshape(labels, (batch_size, 1)) 171 | labels_dist = labels_reshape - labels_reshape.t() 172 | labels_mask = labels_dist == 0 173 | 174 | visual_norm = F.normalize(visual_embed, p=2, dim=1) 175 | textual_norm = F.normalize(textual_embed, p=2, dim=1) 176 | image_proj_text = torch.matmul(visual_embed, textual_norm.t()) 177 | text_proj_image = torch.matmul(textual_embed, visual_norm.t()) 178 | 179 | # normalize the true matching distribution 180 | labels_mask_norm = labels_mask.float() / labels_mask.float().norm(dim=1) 181 | 182 | i2t_pred = F.softmax(image_proj_text, dim=1) 183 | i2t_loss = i2t_pred * ( 184 | F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + epsilon) 185 | ) 186 | 187 | t2i_pred = F.softmax(text_proj_image, dim=1) 188 | t2i_loss = t2i_pred * ( 189 | F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + epsilon) 190 | ) 191 | 192 | loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean( 193 | torch.sum(t2i_loss, dim=1) 194 | ) 195 | 196 | if verbose: 197 | sim_cos = torch.matmul(visual_norm, textual_norm.t()) 198 | 199 | pos_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask)) 200 | neg_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask == 0)) 201 | 202 | return loss, pos_avg_sim, neg_avg_sim 203 | return loss 204 | 205 | 206 | def infonce_loss( 207 | v_pos, 208 | v_neg, 209 | t_pos, 210 | t_neg, 211 | T=0.07, 212 | ): 213 | v_logits = torch.cat([v_pos, v_neg], dim=1) / T 214 | t_logits = torch.cat([t_pos, t_neg], dim=1) / T 215 | labels = torch.zeros(v_logits.shape[0], dtype=torch.long).cuda() 216 | loss = F.cross_entropy(v_logits, labels) + F.cross_entropy(t_logits, labels) 217 | return loss 218 | -------------------------------------------------------------------------------- /lib/models/model.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .backbones import build_textual_model, build_visual_model 4 | from .embeddings import build_embed 5 | from .embeddings.moco_head.head import build_moco_head 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, cfg): 10 | super().__init__() 11 | self.visual_model = build_visual_model(cfg) 12 | self.textual_model = build_textual_model(cfg) 13 | 14 | if cfg.MODEL.EMBEDDING.EMBED_HEAD == "moco": 15 | self.embed_model = build_moco_head( 16 | cfg, self.visual_model, self.textual_model 17 | ) 18 | self.embed_type = "moco" 19 | else: 20 | self.embed_model = build_embed( 21 | cfg, self.visual_model.out_channels, self.textual_model.out_channels 22 | ) 23 | self.embed_type = "normal" 24 | 25 | def forward(self, images, captions): 26 | if self.embed_type == "moco": 27 | return self.embed_model(images, captions) 28 | 29 | visual_feat = self.visual_model(images) 30 | textual_feat = self.textual_model(captions) 31 | 32 | outputs_embed, losses_embed = self.embed_model( 33 | visual_feat, textual_feat, captions 34 | ) 35 | 36 | if self.training: 37 | losses = {} 38 | losses.update(losses_embed) 39 | return losses 40 | 41 | return outputs_embed 42 | 43 | 44 | def build_model(cfg): 45 | return Model(cfg) 46 | -------------------------------------------------------------------------------- /lib/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from .build import make_lr_scheduler, make_optimizer 3 | from .lr_scheduler import LRSchedulerWithWarmup 4 | 5 | __all__ = ["make_lr_scheduler", "make_optimizer", "LRSchedulerWithWarmup"] 6 | -------------------------------------------------------------------------------- /lib/solver/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .lr_scheduler import LRSchedulerWithWarmup 4 | 5 | 6 | def make_optimizer(cfg, model): 7 | params = [] 8 | 9 | for key, value in model.named_parameters(): 10 | if not value.requires_grad: 11 | continue 12 | lr = cfg.SOLVER.BASE_LR 13 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 14 | if "bias" in key: 15 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 16 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 17 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 18 | 19 | if cfg.SOLVER.OPTIMIZER == "SGD": 20 | optimizer = torch.optim.SGD( 21 | params, lr=cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.SGD_MOMENTUM 22 | ) 23 | elif cfg.SOLVER.OPTIMIZER == "Adam": 24 | optimizer = torch.optim.Adam( 25 | params, 26 | lr=cfg.SOLVER.BASE_LR, 27 | betas=(cfg.SOLVER.ADAM_ALPHA, cfg.SOLVER.ADAM_BETA), 28 | eps=1e-8, 29 | ) 30 | elif cfg.SOLVER.OPTIMIZER == "AdamW": 31 | optimizer = torch.optim.AdamW( 32 | params, 33 | lr=cfg.SOLVER.BASE_LR, 34 | betas=(cfg.SOLVER.ADAM_ALPHA, cfg.SOLVER.ADAM_BETA), 35 | eps=1e-8, 36 | ) 37 | else: 38 | NotImplementedError 39 | 40 | return optimizer 41 | 42 | 43 | def make_lr_scheduler(cfg, optimizer): 44 | return LRSchedulerWithWarmup( 45 | optimizer, 46 | milestones=cfg.SOLVER.STEPS, 47 | gamma=cfg.SOLVER.GAMMA, 48 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 49 | warmup_epochs=cfg.SOLVER.WARMUP_EPOCHS, 50 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 51 | total_epochs=cfg.SOLVER.NUM_EPOCHS, 52 | mode=cfg.SOLVER.LRSCHEDULER, 53 | target_lr=cfg.SOLVER.TARGET_LR, 54 | power=cfg.SOLVER.POWER, 55 | ) 
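# A minimal sketch of the per-parameter groups built by make_optimizer above:
# biases can receive a scaled learning rate and their own weight decay. The toy
# module and the numeric values are made up; only the grouping logic mirrors
# the code above.
import torch
import torch.nn as nn

toy = nn.Linear(16, 4)
base_lr, bias_lr_factor, wd, wd_bias = 1e-4, 2.0, 4e-5, 0.0

param_groups = []
for name, p in toy.named_parameters():
    if not p.requires_grad:
        continue
    lr = base_lr * bias_lr_factor if "bias" in name else base_lr
    decay = wd_bias if "bias" in name else wd
    param_groups.append({"params": [p], "lr": lr, "weight_decay": decay})

optimizer = torch.optim.Adam(param_groups, lr=base_lr, betas=(0.9, 0.999), eps=1e-8)
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.0002]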
56 | -------------------------------------------------------------------------------- /lib/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_right 2 | from math import cos, pi 3 | 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | 7 | class LRSchedulerWithWarmup(_LRScheduler): 8 | def __init__( 9 | self, 10 | optimizer, 11 | milestones, 12 | gamma=0.1, 13 | mode="step", 14 | warmup_factor=1.0 / 3, 15 | warmup_epochs=10, 16 | warmup_method="linear", 17 | total_epochs=100, 18 | target_lr=0, 19 | power=0.9, 20 | last_epoch=-1, 21 | ): 22 | if not list(milestones) == sorted(milestones): 23 | raise ValueError( 24 | "Milestones should be a list of" 25 | " increasing integers. Got {}".format(milestones), 26 | ) 27 | if mode not in ("step", "exp", "poly", "cosine", "linear"): 28 | raise ValueError( 29 | "Only 'step', 'exp', 'poly' or 'cosine' learning rate scheduler accepted" 30 | "got {}".format(mode) 31 | ) 32 | if warmup_method not in ("constant", "linear"): 33 | raise ValueError( 34 | "Only 'constant' or 'linear' warmup_method accepted" 35 | "got {}".format(warmup_method) 36 | ) 37 | self.milestones = milestones 38 | self.mode = mode 39 | self.gamma = gamma 40 | self.warmup_factor = warmup_factor 41 | self.warmup_epochs = warmup_epochs 42 | self.warmup_method = warmup_method 43 | self.total_epochs = total_epochs 44 | self.target_lr = target_lr 45 | self.power = power 46 | super().__init__(optimizer, last_epoch) 47 | 48 | def get_lr(self): 49 | 50 | if self.last_epoch < self.warmup_epochs: 51 | if self.warmup_method == "constant": 52 | warmup_factor = self.warmup_factor 53 | elif self.warmup_method == "linear": 54 | alpha = self.last_epoch / self.warmup_epochs 55 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 56 | return [base_lr * warmup_factor for base_lr in self.base_lrs] 57 | 58 | if self.mode == "step": 59 | return [ 60 | base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch) 61 | for base_lr in self.base_lrs 62 | ] 63 | 64 | epoch_ratio = (self.last_epoch - self.warmup_epochs) / ( 65 | self.total_epochs - self.warmup_epochs 66 | ) 67 | 68 | if self.mode == "exp": 69 | factor = epoch_ratio 70 | return [base_lr * self.power ** factor for base_lr in self.base_lrs] 71 | if self.mode == "linear": 72 | factor = 1 - epoch_ratio 73 | return [base_lr * factor for base_lr in self.base_lrs] 74 | 75 | if self.mode == "poly": 76 | factor = 1 - epoch_ratio 77 | return [ 78 | self.target_lr + (base_lr - self.target_lr) * self.power ** factor 79 | for base_lr in self.base_lrs 80 | ] 81 | if self.mode == "cosine": 82 | factor = 0.5 * (1 + cos(pi * epoch_ratio)) 83 | return [ 84 | self.target_lr + (base_lr - self.target_lr) * factor 85 | for base_lr in self.base_lrs 86 | ] 87 | raise NotImplementedError 88 | -------------------------------------------------------------------------------- /lib/utils/caption.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Caption(object): 5 | def __init__( 6 | self, text, length=None, max_length=None, padded=False, dtype=torch.int64 7 | ): 8 | device = text.device if isinstance(text, torch.Tensor) else torch.device("cpu") 9 | if isinstance(text, list): 10 | text = [torch.as_tensor(line, dtype=dtype, device=device) for line in text] 11 | if length is None: 12 | length = torch.stack( 13 | [ 14 | torch.tensor(line.size(0), dtype=torch.int64, device=device) 15 | for line in text 16 | ] 
17 | ) 18 | if max_length is None: 19 | max_length = max([line.size(-1) for line in text]) 20 | elif isinstance(text, str): 21 | if length is None: 22 | length = len(text.split()) 23 | else: 24 | text = torch.as_tensor(text, dtype=dtype, device=device) 25 | if length is None: 26 | length = torch.tensor(text.size(-1), dtype=torch.int64, device=device) 27 | if max_length is None: 28 | max_length = text.size(-1) 29 | 30 | if not padded and not isinstance(text, str): 31 | text = self.pad(text, max_length, device) 32 | 33 | self.text = text 34 | self.length = length 35 | self.max_length = max_length 36 | self.padded = True 37 | self.dtype = dtype 38 | self.extra_fields = {} 39 | 40 | @staticmethod 41 | def pad(text, max_length, device): 42 | padded = [] 43 | for line in text: 44 | length = line.size(0) 45 | if length < max_length: 46 | pad = torch.zeros( 47 | (max_length - length), dtype=torch.int64, device=device 48 | ) 49 | padded.append(torch.cat((line, pad))) 50 | else: 51 | padded.append(line[:max_length]) 52 | return torch.stack(padded) 53 | 54 | def add_field(self, field, field_data): 55 | self.extra_fields[field] = field_data 56 | 57 | def get_field(self, field): 58 | return self.extra_fields[field] 59 | 60 | def has_field(self, field): 61 | return field in self.extra_fields 62 | 63 | def fields(self): 64 | return list(self.extra_fields.keys()) 65 | 66 | # Tensor-like methods 67 | 68 | def to(self, device): 69 | cap = Caption( 70 | self.text, 71 | self.length, 72 | self.max_length, 73 | self.padded, 74 | self.dtype, 75 | ) 76 | if not isinstance(self.text, str): 77 | cap.text = cap.text.to(device) 78 | cap.length = cap.length.to(device) 79 | for k, v in self.extra_fields.items(): 80 | if hasattr(v, "to"): 81 | v = v.to(device) 82 | cap.add_field(k, v) 83 | return cap 84 | 85 | def __getitem__(self, item): 86 | cap = Caption(self.text[item], self.max_length, self.padded) 87 | for k, v in self.extra_fields.items(): 88 | cap.add_field(k, v[item]) 89 | return cap 90 | 91 | def __len__(self): 92 | return len(self.text) 93 | 94 | def __repr__(self): 95 | s = self.__class__.__name__ + "(" 96 | s += "length={}, ".format(self.length) 97 | s += "max_length={}, ".format(self.max_length) 98 | s += "padded={}, ".format(self.padded) 99 | return s 100 | -------------------------------------------------------------------------------- /lib/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
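# A small usage sketch for the Caption container defined in lib/utils/caption.py
# above: variable-length token-id lists are padded to a shared length, and extra
# fields (such as the person id consumed by the losses) travel with the object.
# The token ids and the id value here are made up.
import torch

from lib.utils.caption import Caption

cap = Caption([[5, 12, 7], [3, 9]], max_length=4)
cap.add_field("id", torch.tensor(42))
print(cap.text)             # padded LongTensor of shape (2, 4)
print(cap.length)           # tensor([3, 2])
print(cap.get_field("id"))  # tensor(42)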
2 | import logging 3 | import os 4 | from collections import OrderedDict 5 | 6 | import torch 7 | 8 | 9 | class Checkpointer: 10 | def __init__( 11 | self, 12 | model, 13 | optimizer=None, 14 | scheduler=None, 15 | save_dir="", 16 | save_to_disk=None, 17 | logger=None, 18 | ): 19 | self.model = model 20 | self.optimizer = optimizer 21 | self.scheduler = scheduler 22 | self.save_dir = save_dir 23 | self.save_to_disk = save_to_disk 24 | if logger is None: 25 | logger = logging.getLogger(__name__) 26 | self.logger = logger 27 | 28 | def save(self, name, **kwargs): 29 | if not self.save_dir: 30 | return 31 | 32 | if not self.save_to_disk: 33 | return 34 | 35 | data = {} 36 | data["model"] = self.model.state_dict() 37 | if self.optimizer is not None: 38 | data["optimizer"] = self.optimizer.state_dict() 39 | if self.scheduler is not None: 40 | data["scheduler"] = self.scheduler.state_dict() 41 | data.update(kwargs) 42 | 43 | save_file = os.path.join(self.save_dir, "{}.pth".format(name)) 44 | self.logger.info("Saving checkpoint to {}".format(save_file)) 45 | torch.save(data, save_file) 46 | 47 | def load(self, f=None): 48 | if not f: 49 | # no checkpoint could be found 50 | self.logger.info("No checkpoint found.") 51 | return {} 52 | self.logger.info("Loading checkpoint from {}".format(f)) 53 | checkpoint = self._load_file(f) 54 | self._load_model(checkpoint) 55 | 56 | def resume(self, f=None): 57 | if not f: 58 | # no checkpoint could be found 59 | self.logger.info("No checkpoint found.") 60 | return {} 61 | self.logger.info("Loading checkpoint from {}".format(f)) 62 | checkpoint = self._load_file(f) 63 | self._load_model(checkpoint) 64 | if "optimizer" in checkpoint and self.optimizer: 65 | self.logger.info("Loading optimizer from {}".format(f)) 66 | self.optimizer.load_state_dict(checkpoint.pop("optimizer")) 67 | if "scheduler" in checkpoint and self.scheduler: 68 | self.logger.info("Loading scheduler from {}".format(f)) 69 | self.scheduler.load_state_dict(checkpoint.pop("scheduler")) 70 | # return any further checkpoint data 71 | return checkpoint 72 | 73 | def _load_file(self, f): 74 | return torch.load(f, map_location=torch.device("cpu")) 75 | 76 | def _load_model(self, checkpoint, except_keys=None): 77 | load_state_dict(self.model, checkpoint.pop("model"), except_keys) 78 | 79 | 80 | def check_key(key, except_keys): 81 | if except_keys is None: 82 | return False 83 | else: 84 | for except_key in except_keys: 85 | if except_key in key: 86 | return True 87 | return False 88 | 89 | 90 | def align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys=None): 91 | current_keys = sorted(list(model_state_dict.keys())) 92 | loaded_keys = sorted(list(loaded_state_dict.keys())) 93 | # get a matrix of string matches, where each (i, j) entry correspond to the size of the 94 | # loaded_key string, if it matches 95 | match_matrix = [ 96 | len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys 97 | ] 98 | match_matrix = torch.as_tensor(match_matrix).view( 99 | len(current_keys), len(loaded_keys) 100 | ) 101 | max_match_size, idxs = match_matrix.max(1) 102 | # remove indices that correspond to no-match 103 | idxs[max_match_size == 0] = -1 104 | 105 | # used for logging 106 | max_size = max([len(key) for key in current_keys]) if current_keys else 1 107 | max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 108 | log_str_template = "{: <{}} loaded from {: <{}} of shape {}" 109 | logger = logging.getLogger("PersonSearch.checkpoint") 110 | for 
idx_new, idx_old in enumerate(idxs.tolist()): 111 | if idx_old == -1: 112 | continue 113 | key = current_keys[idx_new] 114 | key_old = loaded_keys[idx_old] 115 | if check_key(key, except_keys): 116 | continue 117 | model_state_dict[key] = loaded_state_dict[key_old] 118 | logger.info( 119 | log_str_template.format( 120 | key, 121 | max_size, 122 | key_old, 123 | max_size_loaded, 124 | tuple(loaded_state_dict[key_old].shape), 125 | ) 126 | ) 127 | 128 | 129 | def strip_prefix_if_present(state_dict, prefix): 130 | keys = sorted(state_dict.keys()) 131 | if not all(key.startswith(prefix) for key in keys): 132 | return state_dict 133 | stripped_state_dict = OrderedDict() 134 | for key, value in state_dict.items(): 135 | stripped_state_dict[key.replace(prefix, "")] = value 136 | return stripped_state_dict 137 | 138 | 139 | def load_state_dict(model, loaded_state_dict, except_keys=None): 140 | model_state_dict = model.state_dict() 141 | # if the state_dict comes from a model that was wrapped in a 142 | # DataParallel or DistributedDataParallel during serialization, 143 | # remove the "module" prefix before performing the matching 144 | loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") 145 | align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys) 146 | 147 | # use strict loading 148 | model.load_state_dict(model_state_dict) 149 | -------------------------------------------------------------------------------- /lib/utils/comm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains primitives for multi-gpu communication. 3 | This is useful when doing distributed training. 4 | """ 5 | 6 | import pickle 7 | 8 | import torch 9 | import torch.distributed as dist 10 | 11 | 12 | def get_world_size(): 13 | if not dist.is_available(): 14 | return 1 15 | if not dist.is_initialized(): 16 | return 1 17 | return dist.get_world_size() 18 | 19 | 20 | def get_rank(): 21 | if not dist.is_available(): 22 | return 0 23 | if not dist.is_initialized(): 24 | return 0 25 | return dist.get_rank() 26 | 27 | 28 | def is_main_process(): 29 | return get_rank() == 0 30 | 31 | 32 | def synchronize(): 33 | """ 34 | Helper function to synchronize (barrier) among all processes when 35 | using distributed training 36 | """ 37 | if not dist.is_available(): 38 | return 39 | if not dist.is_initialized(): 40 | return 41 | world_size = dist.get_world_size() 42 | if world_size == 1: 43 | return 44 | dist.barrier() 45 | 46 | 47 | def all_gather(data): 48 | """ 49 | Run all_gather on arbitrary picklable data (not necessarily tensors) 50 | Args: 51 | data: any picklable object 52 | Returns: 53 | list[data]: list of data gathered from each rank 54 | """ 55 | world_size = get_world_size() 56 | if world_size == 1: 57 | return [data] 58 | 59 | # serialized to a Tensor 60 | buffer = pickle.dumps(data) 61 | storage = torch.ByteStorage.from_buffer(buffer) 62 | tensor = torch.ByteTensor(storage).to("cuda") 63 | 64 | # obtain Tensor size of each rank 65 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 66 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] 67 | dist.all_gather(size_list, local_size) 68 | size_list = [int(size.item()) for size in size_list] 69 | max_size = max(size_list) 70 | 71 | # receiving Tensor from all ranks 72 | # we pad the tensor because torch all_gather does not support 73 | # gathering tensors of different shapes 74 | tensor_list = [] 75 | for _ in size_list: 76 | 
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 77 | if local_size != max_size: 78 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 79 | tensor = torch.cat((tensor, padding), dim=0) 80 | dist.all_gather(tensor_list, tensor) 81 | 82 | data_list = [] 83 | for size, tensor in zip(size_list, tensor_list): 84 | buffer = tensor.cpu().numpy().tobytes()[:size] 85 | data_list.append(pickle.loads(buffer)) 86 | 87 | return data_list 88 | 89 | 90 | def reduce_dict(input_dict, average=True): 91 | """ 92 | Args: 93 | input_dict (dict): all the values will be reduced 94 | average (bool): whether to do average or sum 95 | Reduce the values in the dictionary from all processes so that process with rank 96 | 0 has the averaged results. Returns a dict with the same fields as 97 | input_dict, after reduction. 98 | """ 99 | world_size = get_world_size() 100 | if world_size < 2: 101 | return input_dict 102 | with torch.no_grad(): 103 | names = [] 104 | values = [] 105 | # sort the keys so that they are consistent across processes 106 | for k in sorted(input_dict.keys()): 107 | names.append(k) 108 | values.append(input_dict[k]) 109 | values = torch.stack(values, dim=0) 110 | dist.reduce(values, dst=0) 111 | if dist.get_rank() == 0 and average: 112 | # only main process gets accumulated, so only divide by 113 | # world_size in this case 114 | values /= world_size 115 | reduced_dict = {k: v for k, v in zip(names, values)} 116 | return reduced_dict 117 | -------------------------------------------------------------------------------- /lib/utils/directory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | 6 | def makedir(root): 7 | if not os.path.exists(root): 8 | os.makedirs(root) 9 | 10 | 11 | def load_vocab_dict(root, use_onehot): 12 | if use_onehot == "bert_c4": 13 | vocab_dict = np.load( 14 | os.path.join(root, "./datasets/cuhkpedes/bert_vocab_c4.npy") 15 | ) 16 | elif use_onehot == "bert_l2": 17 | vocab_dict = np.load( 18 | os.path.join(root, "./datasets/cuhkpedes/bert_vocab_l2.npy") 19 | ) 20 | elif use_onehot == "clip_vit": 21 | vocab_dict = np.load( 22 | os.path.join(root, "./datasets/cuhkpedes/clip_vocab_vit.npy") 23 | ) 24 | elif use_onehot == "clip_rn50x4": 25 | vocab_dict = np.load( 26 | os.path.join(root, "./datasets/cuhkpedes/clip_vocab_rn50x4.npy") 27 | ) 28 | else: 29 | NotImplementedError 30 | return vocab_dict 31 | -------------------------------------------------------------------------------- /lib/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
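# A single-process sketch of the communication helpers in lib/utils/comm.py
# above: when torch.distributed is not initialised, all_gather and reduce_dict
# fall back to identity behaviour, so the same training code runs with or
# without multiple GPUs. The dict contents are made up.
import torch

from lib.utils.comm import all_gather, get_world_size, reduce_dict

print(get_world_size())                            # 1 without distributed init
print(all_gather({"rank": 0}))                     # [{'rank': 0}]
print(reduce_dict({"loss": torch.tensor(0.5)}))    # returned unchanged when world_size < 2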
2 | import logging 3 | import os 4 | import sys 5 | 6 | from tabulate import tabulate 7 | 8 | 9 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt"): 10 | logger = logging.getLogger(name) 11 | logger.setLevel(logging.DEBUG) 12 | # don't log results for the non-master process 13 | if distributed_rank > 0: 14 | return logger 15 | ch = logging.StreamHandler(stream=sys.stdout) 16 | ch.setLevel(logging.DEBUG) 17 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 18 | ch.setFormatter(formatter) 19 | logger.addHandler(ch) 20 | 21 | if save_dir: 22 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode="w") 23 | fh.setLevel(logging.DEBUG) 24 | fh.setFormatter(formatter) 25 | logger.addHandler(fh) 26 | 27 | return logger 28 | 29 | 30 | def table_log(cols, headers): 31 | return tabulate(cols, headers=headers, tablefmt="grid") 32 | -------------------------------------------------------------------------------- /lib/utils/metric_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import time 3 | from collections import defaultdict, deque 4 | from datetime import datetime 5 | 6 | import torch 7 | 8 | from .comm import is_main_process 9 | 10 | 11 | class SmoothedValue(object): 12 | """Track a series of values and provide access to smoothed values over a 13 | window or the global series average. 14 | """ 15 | 16 | def __init__(self, window_size=20): 17 | self.deque = deque(maxlen=window_size) 18 | self.series = [] 19 | self.total = 0.0 20 | self.count = 0 21 | 22 | def update(self, value): 23 | self.deque.append(value) 24 | self.series.append(value) 25 | self.count += 1 26 | self.total += value 27 | 28 | @property 29 | def median(self): 30 | d = torch.tensor(list(self.deque)) 31 | return d.median().item() 32 | 33 | @property 34 | def avg(self): 35 | d = torch.tensor(list(self.deque)) 36 | return d.mean().item() 37 | 38 | @property 39 | def global_avg(self): 40 | return self.total / self.count 41 | 42 | 43 | class MetricLogger(object): 44 | def __init__(self, delimiter="\t"): 45 | self.meters = defaultdict(SmoothedValue) 46 | self.delimiter = delimiter 47 | 48 | def update(self, **kwargs): 49 | for k, v in kwargs.items(): 50 | if isinstance(v, torch.Tensor): 51 | v = v.item() 52 | assert isinstance(v, (float, int)) 53 | self.meters[k].update(v) 54 | 55 | def __getattr__(self, attr): 56 | if attr in self.meters: 57 | return self.meters[attr] 58 | if attr in self.__dict__: 59 | return self.__dict__[attr] 60 | raise AttributeError( 61 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 62 | ) 63 | 64 | def __str__(self): 65 | loss_str = [] 66 | for name, meter in self.meters.items(): 67 | loss_str.append( 68 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) 69 | ) 70 | return self.delimiter.join(loss_str) 71 | 72 | 73 | class TensorboardLogger(MetricLogger): 74 | def __init__(self, log_dir, start_iter=0, delimiter="\t"): 75 | super(TensorboardLogger, self).__init__(delimiter) 76 | self.iteration = start_iter 77 | self.writer = self._get_tensorboard_writer(log_dir) 78 | 79 | @staticmethod 80 | def _get_tensorboard_writer(log_dir): 81 | try: 82 | from tensorboardX import SummaryWriter 83 | except ImportError: 84 | raise ImportError( 85 | "To use tensorboard please install tensorboardX " 86 | "[ pip install tensorflow tensorboardX ]." 
87 | ) 88 | 89 | if is_main_process(): 90 | timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H:%M") 91 | tb_logger = SummaryWriter("{}-{}".format(log_dir, timestamp)) 92 | return tb_logger 93 | else: 94 | return None 95 | 96 | def update(self, **kwargs): 97 | super(TensorboardLogger, self).update(**kwargs) 98 | if self.writer: 99 | for k, v in kwargs.items(): 100 | if isinstance(v, torch.Tensor): 101 | v = v.item() 102 | assert isinstance(v, (float, int)) 103 | self.writer.add_scalar(k, v, self.iteration) 104 | self.iteration += 1 105 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backports.entry-points-selectable==1.1.1 2 | certifi==2021.10.8 3 | cfgv==3.3.1 4 | distlib==0.3.3 5 | filelock==3.4.0 6 | identify==2.4.0 7 | importlib-metadata==4.8.2 8 | nodeenv==1.6.0 9 | numpy==1.21.4 10 | Pillow==8.4.0 11 | platformdirs==2.4.0 12 | pre-commit==2.16.0 13 | PyYAML==6.0 14 | six==1.16.0 15 | tabulate==0.8.9 16 | toml==0.10.2 17 | torch==1.10.0 18 | torchvision==0.11.1 19 | tqdm==4.62.3 20 | typing_extensions==4.0.1 21 | virtualenv==20.10.0 22 | yacs==0.1.8 23 | zipp==3.6.0 24 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PYTHONHOME="/vol/research/xmodal_dl/txtreid-env/bin" 4 | HOME="/vol/research/xmodal_dl/TextReID" 5 | 6 | echo $HOME 7 | echo 'args:' $@ 8 | 9 | $PYTHONHOME/python $HOME/train_net.py --root $HOME $@ 10 | -------------------------------------------------------------------------------- /run.submit_file: -------------------------------------------------------------------------------- 1 | executable = run.sh 2 | 3 | universe = docker 4 | docker_image = nvidia/cuda:11.1-runtime-ubuntu18.04 5 | 6 | log = condor_log/c$(cluster).p$(process).log 7 | output = condor_log/c$(cluster).p$(process).out 8 | error = condor_log/c$(cluster).p$(process).error 9 | 10 | environment = "mount=/vol/research/xmodal_dl/" 11 | 12 | +CanCheckpoint = True 13 | +GPUMem = 11000 14 | +JobRunTime = 12 15 | 16 | should_transfer_files = True 17 | stream_output = True 18 | 19 | request_GPUs = 1 20 | request_CPUs = 1 21 | request_memory = 11G 22 | requirements = (CUDAGlobalMemoryMb > 4500) && \ 23 | (HasDocker) && \ 24 | (CUDACapability > 2.0) && \ 25 | (CUDADeviceName == "GeForce RTX 3090") 26 | 27 | queue arguments from ( 28 | --config-file $ENV(PWD)/configs/cuhkpedes/moco_gru_cliprn50_ls_bs128_2048.yaml 29 | ) 30 | -------------------------------------------------------------------------------- /test_net.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.nn.parallel 6 | import torch.optim 7 | import torch.utils.data 8 | import torch.utils.data.distributed 9 | 10 | from lib.config import cfg 11 | from lib.data import make_data_loader 12 | from lib.engine.inference import inference 13 | from lib.models.model import build_model 14 | from lib.utils.checkpoint import Checkpointer 15 | from lib.utils.comm import get_rank, synchronize 16 | from lib.utils.directory import makedir 17 | from lib.utils.logger import setup_logger 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser( 22 | description="PyTorch Image-Text Matching Inference" 23 | ) 24 | parser.add_argument( 25 | "--root", 26 | default="./", 
27 |         help="root path",
28 |         type=str,
29 |     )
30 |     parser.add_argument(
31 |         "--config-file",
32 |         default="",
33 |         metavar="FILE",
34 |         help="path to config file",
35 |         type=str,
36 |     )
37 |     parser.add_argument(
38 |         "--checkpoint-file",
39 |         default="",
40 |         metavar="FILE",
41 |         help="path to checkpoint file",
42 |         type=str,
43 |     )
44 |     parser.add_argument(
45 |         "--local_rank",
46 |         default=0,
47 |         type=int,
48 |     )
49 |     parser.add_argument(
50 |         "opts",
51 |         help="Modify config options using the command-line",
52 |         default=None,
53 |         nargs=argparse.REMAINDER,
54 |     )
55 |     parser.add_argument(
56 |         "--load-result",
57 |         help="Use saved result as prediction",
58 |         action="store_true",
59 |     )
60 |
61 |     args = parser.parse_args()
62 |
63 |     num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
64 |     distributed = num_gpus > 1
65 |
66 |     if distributed:
67 |         torch.cuda.set_device(args.local_rank)
68 |         torch.distributed.init_process_group(backend="nccl", init_method="env://")
69 |         synchronize()
70 |
71 |     cfg.merge_from_file(args.config_file)
72 |     cfg.merge_from_list(args.opts)
73 |     cfg.ROOT = args.root
74 |     cfg.freeze()
75 |
76 |     model = build_model(cfg)
77 |     model.to(cfg.MODEL.DEVICE)
78 |
79 |     output_dir = os.path.join(
80 |         args.root, "./output", "/".join(args.config_file.split("/")[-2:])[:-5]
81 |     )
82 |     checkpointer = Checkpointer(model, save_dir=output_dir)
83 |     _ = checkpointer.load(args.checkpoint_file)
84 |
85 |     output_folders = list()
86 |     dataset_names = cfg.DATASETS.TEST
87 |     for dataset_name in dataset_names:
88 |         output_folder = os.path.join(output_dir, "inference", dataset_name)
89 |         makedir(output_folder)
90 |         output_folders.append(output_folder)
91 |
92 |     data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
93 |     for output_folder, dataset_name, data_loader_val in zip(
94 |         output_folders, dataset_names, data_loaders_val
95 |     ):
96 |         logger = setup_logger("PersonSearch", output_folder, get_rank())
97 |         logger.info("Using {} GPUs".format(num_gpus))
98 |         logger.info(cfg)
99 |
100 |         inference(
101 |             model,
102 |             data_loader_val,
103 |             dataset_name=dataset_name,
104 |             device=cfg.MODEL.DEVICE,
105 |             output_folder=output_folder,
106 |             save_data=False,
107 |             rerank=True,
108 |         )
109 |         synchronize()
110 |
111 |
112 | if __name__ == "__main__":
113 |     main()
114 |
--------------------------------------------------------------------------------
/train_net.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.parallel
8 | import torch.optim
9 | import torch.utils.data
10 | import torch.utils.data.distributed
11 |
12 | from lib.config import cfg
13 | from lib.data import make_data_loader
14 | from lib.engine.trainer import do_train
15 | from lib.models.model import build_model
16 | from lib.solver import make_lr_scheduler, make_optimizer
17 | from lib.utils.checkpoint import Checkpointer
18 | from lib.utils.comm import get_rank, synchronize
19 | from lib.utils.directory import makedir
20 | from lib.utils.logger import setup_logger
21 | from lib.utils.metric_logger import MetricLogger, TensorboardLogger
22 |
23 |
24 | def set_random_seed(seed=0):
25 |     random.seed(seed)
26 |     np.random.seed(seed)
27 |     torch.manual_seed(seed)
28 |     torch.cuda.manual_seed_all(seed)
29 |
30 |
31 | def train(cfg, output_dir, local_rank, distributed, resume_from, use_tensorboard):
32 |     data_loader = make_data_loader(
33 |
cfg, 34 | is_train=True, 35 | is_distributed=distributed, 36 | ) 37 | data_loader_val = make_data_loader( 38 | cfg, 39 | is_train=False, 40 | is_distributed=distributed, 41 | ) 42 | model = build_model(cfg) 43 | device = torch.device(cfg.MODEL.DEVICE) 44 | model.to(device) 45 | 46 | optimizer = make_optimizer(cfg, model) 47 | scheduler = make_lr_scheduler(cfg, optimizer) 48 | 49 | if distributed: 50 | model = torch.nn.parallel.DistributedDataParallel( 51 | model, 52 | device_ids=[local_rank], 53 | output_device=local_rank, 54 | # this should be removed if we update BatchNorm stats 55 | broadcast_buffers=False, 56 | ) 57 | 58 | arguments = {} 59 | arguments["iteration"] = 0 60 | arguments["epoch"] = 0 61 | 62 | save_to_disk = get_rank() == 0 63 | checkpointer = Checkpointer(model, optimizer, scheduler, output_dir, save_to_disk) 64 | if cfg.MODEL.WEIGHT != "imagenet": 65 | if os.path.isfile(cfg.MODEL.WEIGHT): 66 | checkpointer.load(cfg.MODEL.WEIGHT) 67 | else: 68 | raise IOError("{} is not a checkpoint file".format(cfg.MODEL.WEIGHT)) 69 | if resume_from: 70 | if os.path.isfile(resume_from): 71 | extra_checkpoint_data = checkpointer.resume(resume_from) 72 | arguments.update(extra_checkpoint_data) 73 | else: 74 | raise IOError("{} is not a checkpoint file".format(resume_from)) 75 | 76 | if use_tensorboard: 77 | meters = TensorboardLogger( 78 | log_dir=os.path.join(output_dir, "tensorboard"), 79 | start_iter=arguments["iteration"], 80 | delimiter=" ", 81 | ) 82 | else: 83 | meters = MetricLogger(delimiter=" ") 84 | 85 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 86 | evaluate_period = cfg.SOLVER.EVALUATE_PERIOD 87 | arguments["max_epoch"] = cfg.SOLVER.NUM_EPOCHS 88 | arguments["distributed"] = distributed 89 | 90 | do_train( 91 | model, 92 | data_loader, 93 | data_loader_val, 94 | optimizer, 95 | scheduler, 96 | checkpointer, 97 | meters, 98 | device, 99 | checkpoint_period, 100 | evaluate_period, 101 | arguments, 102 | ) 103 | 104 | 105 | def main(): 106 | set_random_seed() 107 | 108 | parser = argparse.ArgumentParser(description="PyTorch Person Search Training") 109 | parser.add_argument( 110 | "--root", 111 | default="./", 112 | help="root path", 113 | type=str, 114 | ) 115 | parser.add_argument( 116 | "--config-file", 117 | default="", 118 | metavar="FILE", 119 | help="path to config file", 120 | type=str, 121 | ) 122 | parser.add_argument( 123 | "--resume-from", 124 | help="the checkpoint file to resume from", 125 | type=str, 126 | ) 127 | parser.add_argument( 128 | "--local_rank", 129 | default=0, 130 | type=int, 131 | ) 132 | parser.add_argument( 133 | "opts", 134 | help="Modify config options using the command-line", 135 | default=None, 136 | nargs=argparse.REMAINDER, 137 | ) 138 | parser.add_argument( 139 | "--use-tensorboard", 140 | dest="use_tensorboard", 141 | help="Use tensorboardX logger (Requires tensorboardX and tensorflow installed)", 142 | action="store_true", 143 | default=False, 144 | ) 145 | 146 | args = parser.parse_args() 147 | 148 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 149 | args.distributed = num_gpus > 1 150 | 151 | if args.distributed: 152 | torch.cuda.set_device(args.local_rank) 153 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 154 | synchronize() 155 | 156 | cfg.merge_from_file(args.config_file) 157 | cfg.merge_from_list(args.opts) 158 | cfg.ROOT = args.root 159 | cfg.freeze() 160 | 161 | output_dir = os.path.join( 162 | args.root, "./output", 
"/".join(args.config_file.split("/")[-2:])[:-5] 163 | ) 164 | makedir(output_dir) 165 | 166 | logger = setup_logger("PersonSearch", output_dir, get_rank()) 167 | logger.info("Using {} GPUs".format(num_gpus)) 168 | logger.info(args) 169 | 170 | logger.info("Loaded configuration file {}".format(args.config_file)) 171 | with open(args.config_file, "r") as cf: 172 | config_str = "\n" + cf.read() 173 | logger.info(config_str) 174 | logger.info("Running with config:\n{}".format(cfg)) 175 | 176 | train( 177 | cfg, 178 | output_dir, 179 | args.local_rank, 180 | args.distributed, 181 | args.resume_from, 182 | args.use_tensorboard, 183 | ) 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | --------------------------------------------------------------------------------