├── .gitignore ├── ICFG-SH ├── vit_itc_ICFG.sh ├── vit_itc_ICFG2.sh ├── vit_itc_ICFG3.sh ├── vit_itc_ICFG4.sh ├── vit_itc_ICFG5.sh ├── vit_itc_ICFG6.sh └── vit_itc_ICFG7.sh ├── README.md ├── data ├── CUHK-PEDES ├── ICFG-PEDES └── bpe_simple_vocab_16e6.txt.gz ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── bases.cpython-36.pyc │ ├── bases.cpython-38.pyc │ ├── build.cpython-36.pyc │ ├── build.cpython-38.pyc │ ├── cuhkpedes.cpython-38.pyc │ ├── cuhkpedes_merge.cpython-38.pyc │ ├── f30k.cpython-38.pyc │ ├── icfgpedes.cpython-38.pyc │ ├── rstpreid.cpython-38.pyc │ ├── sampler.cpython-36.pyc │ ├── sampler.cpython-38.pyc │ ├── sampler_ddp.cpython-36.pyc │ └── sampler_ddp.cpython-38.pyc ├── bases.py ├── build.py ├── cuhkpedes.py ├── cuhkpedes_merge.py ├── f30k.py ├── icfgpedes.py ├── preprocessing.py ├── rstpreid.py ├── sampler.py └── sampler_ddp.py ├── merge.py ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── build.cpython-38.pyc │ ├── clip_model.cpython-38.pyc │ └── objectives.cpython-38.pyc ├── build.py ├── clip_model.py └── objectives.py ├── processor ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── processor.cpython-38.pyc └── processor.py ├── solver ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── build.cpython-38.pyc │ └── lr_scheduler.cpython-38.pyc ├── build.py └── lr_scheduler.py ├── test.py ├── train.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-38.pyc ├── checkpoint.cpython-38.pyc ├── comm.cpython-36.pyc ├── comm.cpython-38.pyc ├── iotools.cpython-36.pyc ├── iotools.cpython-38.pyc ├── logger.cpython-38.pyc ├── meter.cpython-38.pyc ├── metrics.cpython-38.pyc ├── options.cpython-38.pyc └── simple_tokenizer.cpython-38.pyc ├── checkpoint.py ├── comm.py ├── iotools.py ├── logger.py ├── meter.py ├── metrics.py ├── options.py └── simple_tokenizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # data/* 2 | **.pth 3 | **.pt 4 | 5 | **/data/ 6 | 7 | **/archive/ 8 | 9 | **.out 10 | **/__pycache__/ 11 | 12 | .vscode/ 13 | .pycharm/ 14 | -------------------------------------------------------------------------------- /ICFG-SH/vit_itc_ICFG.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | flag=0 3 | file=${0%%.*} 4 | name=${file##*/} 5 | # time=$(date +%Y%m%d%H%M%S) 6 | while [ $flag -eq 0 ] 7 | do 8 | count=0 9 | for i in $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits) 10 | do 11 | if [ $i -lt 8000 ] 12 | then 13 | echo 'GPU' $count ' is available, start training...' 14 | CUDA_VISIBLE_DEVICES=$count \ 15 | nohup python train.py \ 16 | --dataset_name 'ICFG-PEDES' \ 17 | --root_dir '/data0/data_ccq/ICFG/' \ 18 | --output_dir '/data1/ccq/multimodality-ICFG' \ 19 | --img_aug \ 20 | --name 'sketch2_add-fusion-twofocal-1-35-fusion-itcloss_05kl-text-label' \ 21 | --fusion_way 'add' \ 22 | --batch_size 64 \ 23 | --pa 0.1 \ 24 | --pretrain_choice 'ViT-B/16' \ 25 | --loss_names 'itc' \ 26 | --lrscheduler 'cosine' \ 27 | --target_lr 0 \ 28 | --num_epoch 60 \ 29 | --al 1.0 \ 30 | --ga 3.5 \ 31 | --klp 0.5 \ 32 | --focal_three_fusion_loss3 \ 33 | > scripts/ICFG-PEDES/ViT/nohup.out & 34 | _pid=$!
35 | echo "training pid: $_pid" 36 | flag=1 37 | break 38 | fi 39 | count=$(($count+1)) 40 | done 41 | sleep 20 42 | done 43 | 44 | #--name $name \ 45 | # --root_dir '/data0/data_ccq/CUHK-PEDES/' \ -------------------------------------------------------------------------------- /ICFG-SH/vit_itc_ICFG2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | flag=0 3 | file=${0%%.*} 4 | name=${file##*/} 5 | # time=$(date +%Y%m%d%H%M%S) 6 | while [ $flag -eq 0 ] 7 | do 8 | count=0 9 | for i in $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits) 10 | do 11 | if [ $i -lt 8000 ] 12 | then 13 | echo 'GPU' $count ' is avaiable, start training...' 14 | CUDA_VISIBLE_DEVICES=$count \ 15 | nohup python train.py \ 16 | --dataset_name 'ICFG-PEDES' \ 17 | --root_dir '/data0/data_ccq/ICFG/' \ 18 | --output_dir '/data1/ccq/multimodality-ICFG' \ 19 | --img_aug \ 20 | --name 'sketch2_add-fusion-twofocal-1-35-fusion-itcloss_03kl-text-label' \ 21 | --fusion_way 'add' \ 22 | --batch_size 64 \ 23 | --pa 0.1 \ 24 | --pretrain_choice 'ViT-B/16' \ 25 | --loss_names 'itc' \ 26 | --lrscheduler 'cosine' \ 27 | --target_lr 0 \ 28 | --num_epoch 60 \ 29 | --al 1.0 \ 30 | --ga 3.5 \ 31 | --klp 0.3 \ 32 | --focal_three_fusion_loss3 \ 33 | > scripts/ICFG-PEDES/ViT/nohup2.out & 34 | _pid=$! 35 | echo "training pid: $_pid" 36 | flag=1 37 | break 38 | fi 39 | count=$(($count+1)) 40 | done 41 | sleep 20 42 | done 43 | 44 | #--name $name \ 45 | # --root_dir '/data0/data_ccq/CUHK-PEDES/' \ -------------------------------------------------------------------------------- /ICFG-SH/vit_itc_ICFG3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | flag=0 3 | file=${0%%.*} 4 | name=${file##*/} 5 | # time=$(date +%Y%m%d%H%M%S) 6 | while [ $flag -eq 0 ] 7 | do 8 | count=0 9 | for i in $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits) 10 | do 11 | if [ $i -lt 8000 ] 12 | then 13 | echo 'GPU' $count ' is avaiable, start training...' 14 | CUDA_VISIBLE_DEVICES=$count \ 15 | nohup python train.py \ 16 | --dataset_name 'ICFG-PEDES' \ 17 | --root_dir '/data0/data_ccq/ICFG/' \ 18 | --output_dir '/data1/ccq/multimodality-ICFG' \ 19 | --img_aug \ 20 | --name 'sketch2_sketch-text-jointtrain_itcloss' \ 21 | --fusion_way 'no' \ 22 | --batch_size 64 \ 23 | --pa 0.1 \ 24 | --pretrain_choice 'ViT-B/16' \ 25 | --loss_names 'itc' \ 26 | --lrscheduler 'cosine' \ 27 | --target_lr 0 \ 28 | --num_epoch 60 \ 29 | --al 1.0 \ 30 | --ga 2.0 \ 31 | --klp 1.0 \ 32 | > scripts/ICFG-PEDES/ViT/nohup3.out & 33 | _pid=$! 34 | echo "training pid: $_pid" 35 | flag=1 36 | break 37 | fi 38 | count=$(($count+1)) 39 | done 40 | sleep 20 41 | done 42 | 43 | #--name $name \ 44 | # --root_dir '/data0/data_ccq/CUHK-PEDES/' \ -------------------------------------------------------------------------------- /ICFG-SH/vit_itc_ICFG4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | flag=0 3 | file=${0%%.*} 4 | name=${file##*/} 5 | # time=$(date +%Y%m%d%H%M%S) 6 | while [ $flag -eq 0 ] 7 | do 8 | count=0 9 | for i in $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits) 10 | do 11 | if [ $i -lt 8000 ] 12 | then 13 | echo 'GPU' $count ' is avaiable, start training...' 
14 | CUDA_VISIBLE_DEVICES=$count \ 15 | nohup python train.py \ 16 | --dataset_name 'ICFG-PEDES' \ 17 | --root_dir '/data0/data_ccq/ICFG/' \ 18 | --output_dir '/data1/ccq/multimodality-ICFG' \ 19 | --img_aug \ 20 | --name 'sketch2_fusion-add-three_itcloss' \ 21 | --fusion_way 'add' \ 22 | --batch_size 64 \ 23 | --pa 0.1 \ 24 | --pretrain_choice 'ViT-B/16' \ 25 | --loss_names 'itc' \ 26 | --lrscheduler 'cosine' \ 27 | --target_lr 0 \ 28 | --num_epoch 60 \ 29 | --al 1.0 \ 30 | --ga 2.0 \ 31 | --klp 1.0 \ 32 | > scripts/ICFG-PEDES/ViT/nohup4.out & 33 | _pid=$! 34 | echo "training pid: $_pid" 35 | flag=1 36 | break 37 | fi 38 | count=$(($count+1)) 39 | done 40 | sleep 20 41 | done 42 | 43 | #--name $name \ 44 | # --root_dir '/data0/data_ccq/CUHK-PEDES/' \ -------------------------------------------------------------------------------- /ICFG-SH/vit_itc_ICFG5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | flag=0 3 | file=${0%%.*} 4 | name=${file##*/} 5 | # time=$(date +%Y%m%d%H%M%S) 6 | while [ $flag -eq 0 ] 7 | do 8 | count=0 9 | for i in $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits) 10 | do 11 | if [ $i -lt 8000 ] 12 | then 13 | echo 'GPU' $count ' is avaiable, start training...' 14 | CUDA_VISIBLE_DEVICES=$count \ 15 | nohup python train.py \ 16 | --dataset_name 'ICFG-PEDES' \ 17 | --root_dir '/data0/data_ccq/ICFG/' \ 18 | --output_dir '/data1/ccq/multimodality-ICFG' \ 19 | --img_aug \ 20 | --name 'sketch2_add-fusion-twofocal-1-35-fusion-itcloss' \ 21 | --fusion_way 'add' \ 22 | --batch_size 64 \ 23 | --pa 0.1 \ 24 | --pretrain_choice 'ViT-B/16' \ 25 | --loss_names 'itc' \ 26 | --lrscheduler 'cosine' \ 27 | --target_lr 0 \ 28 | --num_epoch 60 \ 29 | --al 1.0 \ 30 | --ga 3.5 \ 31 | --klp 1.0 \ 32 | --focal_three_fusion_loss \ 33 | > scripts/ICFG-PEDES/ViT/nohup5.out & 34 | _pid=$! 35 | echo "training pid: $_pid" 36 | flag=1 37 | break 38 | fi 39 | count=$(($count+1)) 40 | done 41 | sleep 20 42 | done 43 | 44 | #--name $name \ 45 | # --root_dir '/data0/data_ccq/CUHK-PEDES/' \ -------------------------------------------------------------------------------- /ICFG-SH/vit_itc_ICFG6.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | flag=0 3 | file=${0%%.*} 4 | name=${file##*/} 5 | # time=$(date +%Y%m%d%H%M%S) 6 | while [ $flag -eq 0 ] 7 | do 8 | count=0 9 | for i in $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits) 10 | do 11 | if [ $i -lt 8000 ] 12 | then 13 | echo 'GPU' $count ' is avaiable, start training...' 14 | CUDA_VISIBLE_DEVICES=$count \ 15 | nohup python train.py \ 16 | --dataset_name 'ICFG-PEDES' \ 17 | --root_dir '/data0/data_ccq/ICFG/' \ 18 | --output_dir '/data1/ccq/multimodality-ICFG' \ 19 | --img_aug \ 20 | --name 'sketch2_add-fusion-twofocal-1-35-fusion-itcloss_07kl-text-label' \ 21 | --fusion_way 'add' \ 22 | --batch_size 64 \ 23 | --pa 0.1 \ 24 | --pretrain_choice 'ViT-B/16' \ 25 | --loss_names 'itc' \ 26 | --lrscheduler 'cosine' \ 27 | --target_lr 0 \ 28 | --num_epoch 60 \ 29 | --al 1.0 \ 30 | --ga 3.5 \ 31 | --klp 0.7 \ 32 | --focal_three_fusion_loss3 \ 33 | > scripts/ICFG-PEDES/ViT/nohup6.out & 34 | _pid=$! 
35 | echo "training pid: $_pid" 36 | flag=1 37 | break 38 | fi 39 | count=$(($count+1)) 40 | done 41 | sleep 20 42 | done 43 | 44 | #--name $name \ 45 | # --root_dir '/data0/data_ccq/CUHK-PEDES/' \ -------------------------------------------------------------------------------- /ICFG-SH/vit_itc_ICFG7.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | flag=0 3 | file=${0%%.*} 4 | name=${file##*/} 5 | # time=$(date +%Y%m%d%H%M%S) 6 | while [ $flag -eq 0 ] 7 | do 8 | count=0 9 | for i in $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits) 10 | do 11 | if [ $i -lt 8000 ] 12 | then 13 | echo 'GPU' $count ' is avaiable, start training...' 14 | CUDA_VISIBLE_DEVICES=$count \ 15 | nohup python train.py \ 16 | --dataset_name 'ICFG-PEDES' \ 17 | --root_dir '/data0/data_ccq/ICFG/' \ 18 | --output_dir '/data1/ccq/multimodality-ICFG' \ 19 | --img_aug \ 20 | --name 'sketch2_add-fusion-twofocal-1-35-fusion-itcloss_1kl-text-label' \ 21 | --fusion_way 'add' \ 22 | --batch_size 64 \ 23 | --pa 0.1 \ 24 | --pretrain_choice 'ViT-B/16' \ 25 | --loss_names 'itc' \ 26 | --lrscheduler 'cosine' \ 27 | --target_lr 0 \ 28 | --num_epoch 60 \ 29 | --al 1.0 \ 30 | --ga 3.5 \ 31 | --klp 1.0 \ 32 | --focal_three_fusion_loss3 \ 33 | > scripts/ICFG-PEDES/ViT/nohup7.out & 34 | _pid=$! 35 | echo "training pid: $_pid" 36 | flag=1 37 | break 38 | fi 39 | count=$(($count+1)) 40 | done 41 | sleep 20 42 | done 43 | 44 | #--name $name \ 45 | # --root_dir '/data0/data_ccq/CUHK-PEDES/' \ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Modality-Agnostic Person Re-identification with Descriptive Query CVPR2023 2 | 3 | # Highlight 4 | 1. This paper start the first attempts to investigate the modality-agnostic person re-identification with the descriptive query. 5 | 2. This paper introduces a novel unified person re-identification (UNIReID) architecture based on a dual-encoder to jointly integrate cross-modal and multi-modal task learning. With task-specific modality learning and task-aware dynamic training, UNIReID enhances generalization ability across tasks and domains. 6 | 3. This paper contributes three multi-modal ReID datasets to support unified ReID evaluation. 7 | 8 | # Dataset 9 | Based on existing text-based datasets (CUHK-PEDES, ICFG-PEDES, and RSTPReid), we collect the sketches from photo modality to obtain multi-modality datasets (Tri-CUHK-PEDES, Tri-ICFG-PEDES, and Tri-RSTPReid). The collected sketches can be found in: https://pan.baidu.com/s/1c0h2utqisEx6OzGuoSaQhA (提取码: ndau) Google Drive(https://drive.google.com/file/d/12FIN-93Y4vXqVDVWLvLBwg3q0z0Vtwij/view?usp=sharing). 10 | 11 | # Citation 12 | @inproceedings{chen2023towards, 13 | title={Towards Modality-Agnostic Person Re-identification with Descriptive Query}, 14 | author={Cuiqun Chen, Mang Ye, Ding Jiang}, 15 | booktitle={Conference on Computer Vision and Pattern Recognition 2023}, 16 | year={2023} 17 | } 18 | 19 | # Contact 20 | chencuiqun@whu.edu.cn; yemang@whu.edu.cn. 
21 | 22 | -------------------------------------------------------------------------------- /data/CUHK-PEDES: -------------------------------------------------------------------------------- 1 | /data0/data_jd/datasets/t2i_reid/CUHK-PEDES -------------------------------------------------------------------------------- /data/ICFG-PEDES: -------------------------------------------------------------------------------- 1 | /data0/data_jd/datasets/t2i_reid/ICFG-PEDES -------------------------------------------------------------------------------- /data/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/data/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_dataloader -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/bases.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/bases.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/bases.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/bases.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cuhkpedes.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/cuhkpedes.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/cuhkpedes_merge.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/cuhkpedes_merge.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/f30k.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/f30k.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/icfgpedes.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/icfgpedes.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/rstpreid.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/rstpreid.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/sampler.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/sampler.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler_ddp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/sampler_ddp.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler_ddp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/datasets/__pycache__/sampler_ddp.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/build.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torchvision.transforms as T 4 | from torch.utils.data import DataLoader 5 | from datasets.sampler import RandomIdentitySampler 6 | from datasets.sampler_ddp import RandomIdentitySampler_DDP 7 | from torch.utils.data.distributed import DistributedSampler 8 | 9 | from utils.comm import get_world_size 10 | 11 | from .bases import ImageDataset, SketchDataset, ImageTextMSMDataset, ImageTextMSMMLMDataset, TextDataset, ImageTextDataset, ImageTextMCQDataset, ImageTextMaskColorDataset, ImageTextMLMDataset, ImageTextMCQMLMDataset, SketchTextDataset 12 | 13 | from .f30k import F30K 14 | from .cuhkpedes import CUHKPEDES 15 | from .icfgpedes import ICFGPEDES 16 | from .rstpreid import RSTPReid 17 
| __factory = {'CUHK-PEDES': CUHKPEDES, 'ICFG-PEDES': ICFGPEDES, 'F30K': F30K, 'RSTPReid': RSTPReid} 18 | 19 | 20 | def build_transforms(img_size=(384, 128), aug=False, is_train=True): 21 | height, width = img_size 22 | 23 | mean = [0.48145466, 0.4578275, 0.40821073] 24 | std = [0.26862954, 0.26130258, 0.27577711] 25 | 26 | if not is_train: 27 | transform = T.Compose([ 28 | T.Resize((height, width)), 29 | T.ToTensor(), 30 | T.Normalize(mean=mean, std=std), 31 | ]) 32 | return transform 33 | 34 | # transform for training 35 | if aug: 36 | transform = T.Compose([ 37 | T.Resize((height, width)), 38 | T.RandomHorizontalFlip(0.5), 39 | T.Pad(10), 40 | T.RandomCrop((height, width)), 41 | T.ToTensor(), 42 | T.Normalize(mean=mean, std=std), 43 | T.RandomErasing(scale=(0.02, 0.4), value=mean), 44 | ]) 45 | else: 46 | transform = T.Compose([ 47 | T.Resize((height, width)), 48 | T.RandomHorizontalFlip(0.5), 49 | T.ToTensor(), 50 | T.Normalize(mean=mean, std=std), 51 | ]) 52 | return transform 53 | 54 | 55 | def collate(batch): 56 | keys = set([key for b in batch for key in b.keys()]) 57 | # turn list of dicts data structure to dict of lists data structure 58 | dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys} 59 | 60 | batch_tensor_dict = {} 61 | for k, v in dict_batch.items(): 62 | if isinstance(v[0], int): 63 | batch_tensor_dict.update({k: torch.tensor(v)}) 64 | elif torch.is_tensor(v[0]): 65 | batch_tensor_dict.update({k: torch.stack(v)}) 66 | else: 67 | raise TypeError(f"Unexpect data type: {type(v[0])} in a batch.") 68 | 69 | return batch_tensor_dict 70 | 71 | def build_dataloader(args, tranforms=None): 72 | logger = logging.getLogger("CLIP2ReID.dataset") 73 | 74 | num_workers = args.num_workers 75 | dataset = __factory[args.dataset_name](root=args.root_dir, nlp_aug=args.nlp_aug) 76 | 77 | if args.training: 78 | train_transforms = build_transforms(img_size=args.img_size, 79 | aug=args.img_aug, 80 | is_train=True) 81 | val_transforms = build_transforms(img_size=args.img_size, 82 | is_train=False) 83 | 84 | if args.MCQ: 85 | train_set = ImageTextMCQDataset(dataset.train, 86 | train_transforms, 87 | text_length=args.text_length) 88 | elif args.MCM: 89 | train_set = ImageTextMaskColorDataset(dataset.train, 90 | train_transforms, 91 | text_length=args.text_length, 92 | masked_token_rate=args.masked_token_rate, 93 | masked_token_unchanged_rate=args.masked_token_unchanged_rate) 94 | elif args.MLM: 95 | train_set = ImageTextMLMDataset(dataset.train, 96 | train_transforms, 97 | text_length=args.text_length) 98 | elif args.MSM: 99 | train_set = ImageTextMSMDataset(dataset.train, 100 | train_transforms, 101 | text_length=args.text_length) 102 | elif args.MCQMLM: 103 | train_set = ImageTextMCQMLMDataset(dataset.train, 104 | train_transforms, 105 | text_length=args.text_length) 106 | elif args.MSMMLM: 107 | train_set = ImageTextMSMMLMDataset(dataset.train, 108 | train_transforms, 109 | text_length=args.text_length) 110 | else: 111 | train_set = ImageTextDataset(dataset.train, 112 | train_transforms, 113 | text_length=args.text_length) 114 | 115 | num_classes = len(dataset.train_id_container) 116 | 117 | if args.sampler == 'identity': 118 | if args.distributed: 119 | logger.info('using ddp random identity sampler') 120 | logger.info('DISTRIBUTED TRAIN START') 121 | mini_batch_size = args.batch_size // get_world_size() 122 | # TODO wait to fix bugs 123 | data_sampler = RandomIdentitySampler_DDP( 124 | dataset.train, args.batch_size, args.num_instance) 125 | batch_sampler = 
torch.utils.data.sampler.BatchSampler( 126 | data_sampler, mini_batch_size, True) 127 | # sampler = DistributedSampler(train_set) 128 | # train_loader = DataLoader( 129 | # train_set, 130 | # num_workers=num_workers, 131 | # # sampler=sampler, 132 | # # batch_size=mini_batch_size, 133 | # batch_sampler=batch_sampler, 134 | # collate_fn=collate) 135 | else: 136 | logger.info( 137 | f'using random identity sampler: batch_size: {args.batch_size}, id: {args.batch_size // args.num_instance}, instance: {args.num_instance}' 138 | ) 139 | train_loader = DataLoader(train_set, 140 | batch_size=args.batch_size, 141 | sampler=RandomIdentitySampler( 142 | dataset.train, args.batch_size, 143 | args.num_instance), 144 | num_workers=num_workers, 145 | collate_fn=collate) 146 | elif args.sampler == 'random': 147 | # TODO add distributed condition 148 | logger.info('using random sampler') 149 | train_loader = DataLoader(train_set, 150 | batch_size=args.batch_size, 151 | shuffle=True, 152 | num_workers=num_workers, 153 | collate_fn=collate) 154 | else: 155 | logger.error('unsupported sampler! expected softmax or triplet but got {}'.format(args.sampler)) 156 | 157 | # use test set as validate set 158 | ds = dataset.val if args.val_dataset == 'val' else dataset.test 159 | 160 | val_img_set = ImageDataset(ds['image_pids'], ds['img_paths'], ds['image_ids'], 161 | val_transforms) 162 | val_txt_set = SketchTextDataset(ds['simg_paths'], ds['simage_ids'], ds['caption_pids'], 163 | ds['captions'], val_transforms, 164 | text_length=args.text_length) 165 | 166 | val_sketch_set = SketchDataset(ds['simg_paths'], ds['simage_ids'], ds['simage_pids'], val_transforms) 167 | 168 | 169 | val_img_loader = DataLoader(val_img_set, 170 | batch_size=args.batch_size, 171 | shuffle=False, 172 | num_workers=num_workers) 173 | val_txt_loader = DataLoader(val_txt_set, 174 | batch_size=args.batch_size, 175 | shuffle=False, 176 | num_workers=num_workers) 177 | val_sketch_loader = DataLoader(val_sketch_set, 178 | batch_size=args.batch_size, 179 | shuffle=False, 180 | num_workers=num_workers) 181 | 182 | return train_loader, val_img_loader, val_txt_loader, val_sketch_loader, num_classes 183 | 184 | else: 185 | # build dataloader for testing 186 | if tranforms: 187 | test_transforms = tranforms 188 | else: 189 | test_transforms = build_transforms(img_size=args.img_size, 190 | is_train=False) 191 | 192 | ds = dataset.test 193 | test_img_set = ImageDataset(ds['image_pids'], ds['img_paths'], ds['image_ids'], 194 | test_transforms) 195 | test_txt_set = SketchTextDataset(ds['simg_paths'], ds['simage_ids'], ds['caption_pids'], 196 | ds['captions'], test_transforms, 197 | text_length=args.text_length) 198 | test_sketch_set = SketchDataset(ds['simg_paths'], ds['simage_ids'], ds['simage_pids'], test_transforms) 199 | 200 | test_img_loader = DataLoader(test_img_set, 201 | batch_size=args.test_batch_size, 202 | shuffle=False, 203 | num_workers=num_workers) 204 | test_txt_loader = DataLoader(test_txt_set, 205 | batch_size=args.test_batch_size, 206 | shuffle=False, 207 | num_workers=num_workers) 208 | test_sketch_loader = DataLoader(test_sketch_set, 209 | batch_size=args.batch_size, 210 | shuffle=False, 211 | num_workers=num_workers) 212 | 213 | return test_img_loader, test_txt_loader, test_sketch_loader -------------------------------------------------------------------------------- /datasets/cuhkpedes.py: -------------------------------------------------------------------------------- 1 | from json.encoder import py_encode_basestring 2 | import os.path 
as op 3 | from typing import List 4 | 5 | from utils.iotools import read_json 6 | from .bases import BaseDataset 7 | import numpy as np 8 | import pdb 9 | 10 | class CUHKPEDES(BaseDataset): 11 | """ 12 | CUHK-PEDES 13 | 14 | Reference: 15 | Person Search With Natural Language Description (CVPR 2017) 16 | 17 | URL: https://openaccess.thecvf.com/content_cvpr_2017/html/Li_Person_Search_With_CVPR_2017_paper.html 18 | 19 | Dataset statistics: 20 | ### identities: 13003 21 | ### images: 40206, (train) (test) (val) 22 | ### captions: 23 | ### 9 images have more than 2 captions 24 | ### 4 identity have only one image 25 | 26 | annotation format: 27 | [{'split', str, 28 | 'captions', list, 29 | 'file_path', str, 30 | 'processed_tokens', list, 31 | 'id', int}...] 32 | """ 33 | dataset_dir = 'CUHK-PEDES' 34 | 35 | def __init__(self, root='', nlp_aug=False, verbose=True): 36 | super(CUHKPEDES, self).__init__() 37 | self.dataset_dir = op.join(root, self.dataset_dir) 38 | self.img_dir = op.join(self.dataset_dir, 'imgs/') 39 | self.simg_dir = op.join(self.dataset_dir, 'imgs-sketch2/') 40 | if nlp_aug: 41 | self.anno_path = op.join(self.dataset_dir, 'nlp_aug.json') 42 | else: 43 | self.anno_path = op.join(self.dataset_dir, 'reid_raw.json') 44 | self._check_before_run() 45 | 46 | self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path) 47 | 48 | self.train, self.train_id_container = self._process_anno(self.train_annos, training=True) 49 | self.test, self.test_id_container = self._process_anno(self.test_annos) 50 | self.val, self.val_id_container = self._process_anno(self.val_annos) 51 | 52 | if verbose: 53 | self.logger.info("=> CUHK-PEDES Images and Captions are loaded") 54 | self.show_dataset_info() 55 | 56 | 57 | def _split_anno(self, anno_path: str): 58 | train_annos, test_annos, val_annos = [], [], [] 59 | annos = read_json(anno_path) 60 | for anno in annos: 61 | if anno['split'] == 'train': 62 | train_annos.append(anno) 63 | elif anno['split'] == 'test': 64 | test_annos.append(anno) 65 | else: 66 | val_annos.append(anno) 67 | return train_annos, test_annos, val_annos 68 | 69 | 70 | def _process_anno(self, annos: List[dict], training=False): 71 | pid_container = set() 72 | if training: 73 | dataset = [] 74 | image_id = 0 75 | 76 | for anno in annos: 77 | pid = int(anno['id']) - 1 # make pid begin from 0 78 | img_path = op.join(self.img_dir, anno['file_path']) 79 | captions = anno['captions'] # caption list 80 | # if pid not in pid_container: 81 | simg_path = op.join(self.simg_dir, anno['file_path']) 82 | pid_container.add(pid) 83 | 84 | for caption in captions: 85 | dataset.append((pid, image_id, img_path, simg_path, caption)) 86 | image_id += 1 87 | 88 | 89 | for idx, pid in enumerate(pid_container): 90 | # check pid begin from 0 and no break 91 | assert idx == pid, f"idx: {idx} and pid: {pid} are not match" 92 | 93 | # for i in pid_container: 94 | # indexs = [z[0] for z in list(enumerate(images_pid)) if z[1] == i] 95 | # index = int(np.random.choice(indexs,1)) # 每个身份随机选择一张sektch #,选择多张sketch呢,与caption对应张sektch? 
numpy.random.choice(aaa, 5) 96 | # simg_paths[i] = op.join(self.simg_dir, rgb_name[index]) 97 | # simage_ids[i] = image_ids[index] 98 | 99 | return dataset, pid_container 100 | else: 101 | dataset = {} 102 | img_paths = [] 103 | simg_paths = [] 104 | captions = [] 105 | image_pids = [] 106 | caption_pids = [] 107 | image_id = 0 108 | image_ids = [] 109 | simage_ids = [] 110 | simage_pids = [] 111 | 112 | for anno in annos: 113 | pid = int(anno['id']) 114 | img_path = op.join(self.img_dir, anno['file_path']) 115 | img_paths.append(img_path) 116 | image_pids.append(pid) 117 | caption_list = anno['captions'] # caption list 118 | image_ids.append(image_id) 119 | 120 | # if pid not in pid_container: 121 | simg_path = op.join(self.simg_dir, anno['file_path']) 122 | # simage_id = image_id 123 | # simage_pid = pid 124 | 125 | pid_container.add(pid) 126 | 127 | for caption in caption_list: 128 | captions.append(caption) 129 | caption_pids.append(pid) 130 | 131 | simg_paths.append(simg_path) 132 | simage_ids.append(image_id) 133 | simage_pids.append(pid) 134 | 135 | image_id += 1 136 | 137 | # for i in pid_container: 138 | # indexs = [z[0] for z in list(enumerate(images_pid)) if z[1] == i] 139 | # index = int(np.random.choice(indexs,1)) # 每个身份随机选择一张sektch #,选择多张sketch呢,与caption对应张sektch? numpy.random.choice(aaa, 5) 140 | # simg_paths[i] = op.join(self.simg_dir, rgb_name[index]) 141 | # simage_ids[i] = image_ids[index] 142 | 143 | dataset = { 144 | "image_pids": image_pids, 145 | "img_paths": img_paths, 146 | "image_ids": image_ids, 147 | "simage_pids": simage_pids, 148 | "simg_paths": simg_paths, 149 | "simage_ids": simage_ids, 150 | "caption_pids": caption_pids, 151 | "captions": captions 152 | } 153 | return dataset, pid_container 154 | 155 | 156 | def _check_before_run(self): 157 | """Check if all files are available before going deeper""" 158 | if not op.exists(self.dataset_dir): 159 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 160 | if not op.exists(self.img_dir): 161 | raise RuntimeError("'{}' is not available".format(self.img_dir)) 162 | if not op.exists(self.anno_path): 163 | raise RuntimeError("'{}' is not available".format(self.anno_path)) 164 | -------------------------------------------------------------------------------- /datasets/cuhkpedes_merge.py: -------------------------------------------------------------------------------- 1 | from json.encoder import py_encode_basestring 2 | import os.path as op 3 | from typing import List 4 | import logging, os 5 | import torch 6 | from utils.iotools import read_image 7 | from PIL import Image 8 | from utils.iotools import read_json 9 | from .bases import BaseDataset 10 | import numpy as np 11 | import pdb 12 | 13 | def read_image(img_path): 14 | """Keep reading image until succeed. 15 | This can avoid IOError incurred by heavy IO process.""" 16 | got_img = False 17 | if not op.exists(img_path): 18 | raise IOError("{} does not exist".format(img_path)) 19 | while not got_img: 20 | try: 21 | img = Image.open(img_path).convert('RGB') 22 | got_img = True 23 | except IOError: 24 | print("IOError incurred when reading '{}'. Will redo. Don't worry. 
Just chill.".format(img_path)) 25 | pass 26 | return img 27 | 28 | class CUHKPEDES_M(BaseDataset): 29 | """ 30 | CUHK-PEDES 31 | 32 | Reference: 33 | Person Search With Natural Language Description (CVPR 2017) 34 | 35 | URL: https://openaccess.thecvf.com/content_cvpr_2017/html/Li_Person_Search_With_CVPR_2017_paper.html 36 | 37 | Dataset statistics: 38 | ### identities: 13003 39 | ### images: 40206, (train) (test) (val) 40 | ### captions: 41 | ### 9 images have more than 2 captions 42 | ### 4 identity have only one image 43 | 44 | annotation format: 45 | [{'split', str, 46 | 'captions', list, 47 | 'file_path', str, 48 | 'processed_tokens', list, 49 | 'id', int}...] 50 | """ 51 | dataset_dir = 'CUHK-PEDES' 52 | 53 | def __init__(self, root='', nlp_aug=False, verbose=True): 54 | super(CUHKPEDES_M, self).__init__() 55 | self.dataset_dir = op.join(root, self.dataset_dir) 56 | self.img_dir = op.join(self.dataset_dir, 'imgs/') 57 | self.simg_dir = op.join(self.dataset_dir, 'imgs-sketch2/') 58 | if nlp_aug: 59 | self.anno_path = op.join(self.dataset_dir, 'nlp_aug.json') 60 | else: 61 | self.anno_path = op.join(self.dataset_dir, 'reid_raw.json') 62 | self._check_before_run() 63 | 64 | self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path) 65 | 66 | self.train, self.train_id_container = self._process_anno(self.train_annos, training=True) 67 | self.test, self.test_id_container = self._process_anno(self.test_annos) 68 | self.val, self.val_id_container = self._process_anno(self.val_annos) 69 | 70 | if verbose: 71 | self.logger.info("=> CUHK-PEDES Images and Captions are loaded") 72 | self.show_dataset_info() 73 | 74 | 75 | def _split_anno(self, anno_path: str): 76 | train_annos, test_annos, val_annos = [], [], [] 77 | annos = read_json(anno_path) 78 | for anno in annos: 79 | if anno['split'] == 'train': 80 | train_annos.append(anno) 81 | elif anno['split'] == 'test': 82 | test_annos.append(anno) 83 | else: 84 | val_annos.append(anno) 85 | return train_annos, test_annos, val_annos 86 | 87 | 88 | def _process_anno(self, annos: List[dict], training=False): 89 | pid_container = set() 90 | if training: 91 | dataset = [] 92 | image_id = 0 93 | images_pid = [] 94 | rgb_names = [] 95 | for anno in annos: 96 | pid = int(anno['id']) - 1 # make pid begin from 0 97 | img_path = op.join(self.img_dir, anno['file_path']) 98 | captions = anno['captions'] # caption list 99 | # if pid not in pid_container: 100 | simg_path = op.join(self.simg_dir, anno['file_path']) 101 | pid_container.add(pid) 102 | images_pid.append(pid) 103 | rgb_names.append(anno['file_path']) 104 | 105 | for caption in captions: 106 | dataset.append((pid, image_id, img_path, simg_path, caption)) 107 | image_id += 1 108 | 109 | 110 | for idx, pid in enumerate(pid_container): 111 | # check pid begin from 0 and no break 112 | assert idx == pid, f"idx: {idx} and pid: {pid} are not match" 113 | 114 | for i in pid_container: 115 | indexs = [z[0] for z in list(enumerate(images_pid)) if z[1] == i] 116 | # index = int(np.random.choice(indexs,1)) # 每个身份随机选择一张sektch #,选择多张sketch呢,与caption对应张sektch? 
numpy.random.choice(aaa, 5) 117 | j = 0 118 | for ind in indexs: 119 | simg_path = op.join(self.simg_dir, rgb_names[ind]) 120 | if j == 0: 121 | img = read_image(simg_path) 122 | # img.resize((384, 128),Image.ANTIALIAS) 123 | else: 124 | w,h = img.size 125 | img = Image.blend(img, read_image(simg_path).resize((w, h),Image.ANTIALIAS),0.5) 126 | j = j + 1 127 | 128 | img = Image.fromarray(np.uint8(2*np.array(img) / len(indexs))) 129 | path = op.join(self.simg_dir, rgb_names[indexs[0]]).split('/') 130 | out_path = op.join('/data0/data_ccq/CUHK-PEDES/CUHK-PEDES/imgs-sketchmerge/', path[-2]) 131 | if not op.exists(out_path): 132 | os.makedirs(out_path) 133 | img.save(op.join(out_path, path[-1])) 134 | 135 | 136 | return dataset, pid_container 137 | else: 138 | dataset = {} 139 | img_paths = [] 140 | simg_paths = [] 141 | captions = [] 142 | image_pids = [] 143 | caption_pids = [] 144 | image_id = 0 145 | image_ids = [] 146 | simage_ids = [] 147 | simage_pids = [] 148 | rgb_names = [] 149 | 150 | for anno in annos: 151 | pid = int(anno['id']) 152 | img_path = op.join(self.img_dir, anno['file_path']) 153 | img_paths.append(img_path) 154 | image_pids.append(pid) 155 | caption_list = anno['captions'] # caption list 156 | image_ids.append(image_id) 157 | rgb_names.append(anno['file_path']) 158 | # if pid not in pid_container: 159 | simg_path = op.join(self.simg_dir, anno['file_path']) 160 | # simage_id = image_id 161 | # simage_pid = pid 162 | 163 | pid_container.add(pid) 164 | 165 | for caption in caption_list: 166 | captions.append(caption) 167 | caption_pids.append(pid) 168 | 169 | simg_paths.append(simg_path) 170 | simage_ids.append(image_id) 171 | simage_pids.append(pid) 172 | 173 | 174 | image_id += 1 175 | 176 | # for i in pid_container: 177 | # indexs = [z[0] for z in list(enumerate(images_pid)) if z[1] == i] 178 | # index = int(np.random.choice(indexs,1)) # 每个身份随机选择一张sektch #,选择多张sketch呢,与caption对应张sektch? numpy.random.choice(aaa, 5) 179 | # simg_paths[i] = op.join(self.simg_dir, rgb_name[index]) 180 | # simage_ids[i] = image_ids[index] 181 | 182 | for i in pid_container: 183 | indexs = [z[0] for z in list(enumerate(image_pids)) if z[1] == i] 184 | # index = int(np.random.choice(indexs,1)) # 每个身份随机选择一张sektch #,选择多张sketch呢,与caption对应张sektch? 
numpy.random.choice(aaa, 5) 185 | j = 0 186 | for ind in indexs: 187 | simg_path = op.join(self.simg_dir, rgb_names[ind]) 188 | if j == 0: 189 | img = read_image(simg_path)#.resizeresize((384, 128),Image.ANTIALIAS) 190 | else: 191 | w,h = img.size 192 | img = Image.blend(img, read_image(simg_path).resize((w, h),Image.ANTIALIAS),0.5) 193 | j = j + 1 194 | 195 | img = Image.fromarray(np.uint8(2*np.array(img) / len(indexs))) 196 | path = op.join(self.simg_dir, rgb_names[indexs[0]]).split('/') 197 | out_path = op.join('/data0/data_ccq/CUHK-PEDES/CUHK-PEDES/imgs-sketchmerge/', path[-2]) 198 | if not op.exists(out_path): 199 | os.makedirs(out_path) 200 | img.save(op.join(out_path, path[-1])) 201 | 202 | dataset = { 203 | "image_pids": image_pids, 204 | "img_paths": img_paths, 205 | "image_ids": image_ids, 206 | "simage_pids": simage_pids, 207 | "simg_paths": simg_paths, 208 | "simage_ids": simage_ids, 209 | "caption_pids": caption_pids, 210 | "captions": captions 211 | } 212 | return dataset, pid_container 213 | 214 | 215 | def _check_before_run(self): 216 | """Check if all files are available before going deeper""" 217 | if not op.exists(self.dataset_dir): 218 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 219 | if not op.exists(self.img_dir): 220 | raise RuntimeError("'{}' is not available".format(self.img_dir)) 221 | if not op.exists(self.anno_path): 222 | raise RuntimeError("'{}' is not available".format(self.anno_path)) 223 | -------------------------------------------------------------------------------- /datasets/f30k.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | from typing import List 3 | from utils.iotools import read_json 4 | from .bases import BaseDataset 5 | 6 | 7 | class F30K(BaseDataset): 8 | """ 9 | Flickr30K 10 | 11 | Reference: 12 | From image descriptions to visual denotations: 13 | New similarity metrics for semantic inference over event descriptions 14 | 15 | URL: https://aclanthology.org/Q14-1006/ 16 | 17 | Dataset statistics: 18 | The Flickr30k dataset contains 31,000 images collected from Flickr, together with 5 reference sentences provided by human annotators. 
19 | 20 | annotation format: 21 | { 22 | 'images': list[{ 23 | 'sentids': [10, 11, 12, 13, 14], 24 | 'imgid': 2, 25 | 'sentences': [{ 26 | 'tokens': ['a', 'child', 'in', 'a', 'pink', 'dress'], 27 | 'raw': 'A child in a pink dress', 28 | 'imgid': 2 , 29 | 'sentid': 10 30 | }, ...], 31 | 'split': 'train', 32 | 'filename': '1000268201.jpg' 33 | }, ...], 34 | 'dataset': 'flickr30k', 35 | } 36 | """ 37 | dataset_dir = 'F30K' 38 | 39 | def __init__(self, root='', nlp_aug=False, verbose=True): 40 | super(F30K, self).__init__() 41 | self.dataset_dir = op.join(root, self.dataset_dir) 42 | self.img_dir = op.join(self.dataset_dir, 'flickr30k-images/') 43 | if nlp_aug: 44 | self.anno_path = op.join(self.dataset_dir, 'nlp_aug.json') 45 | else: 46 | self.anno_path = op.join(self.dataset_dir, 'karpathy/dataset_flickr30k.json') 47 | self._check_before_run() 48 | 49 | self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path) 50 | 51 | self.train, self.train_id_container = self._process_anno(self.train_annos, training=True) 52 | self.test, self.test_id_container = self._process_anno(self.test_annos) 53 | self.val, self.val_id_container = self._process_anno(self.val_annos) 54 | 55 | if verbose: 56 | self.logger.info("=> F30K Images and Captions are loaded") 57 | self.show_dataset_info() 58 | 59 | 60 | def _split_anno(self, anno_path: str): 61 | train_annos, test_annos, val_annos = [], [], [] 62 | annos = read_json(anno_path)['images'] 63 | for anno in annos: 64 | if anno['split'] == 'train': 65 | train_annos.append(anno) 66 | elif anno['split'] == 'test': 67 | test_annos.append(anno) 68 | else: 69 | val_annos.append(anno) 70 | return train_annos, test_annos, val_annos 71 | 72 | 73 | def _process_anno(self, annos: List[dict], training=False): 74 | pid_container = set() 75 | if training: 76 | dataset = [] 77 | for anno in annos: 78 | img_path = op.join(self.img_dir, anno['filename']) 79 | img_id = anno['imgid'] 80 | sentences = anno['sentences'] # caption list 81 | for sentence in sentences: 82 | assert img_id == sentence['imgid'] 83 | caption = sentence['raw'] 84 | dataset.append((img_id, -1, img_path, caption)) 85 | 86 | return dataset, pid_container 87 | else: 88 | dataset = {} 89 | img_paths = [] 90 | captions = [] 91 | img_ids = [] 92 | caption_pids = [] 93 | for anno in annos: 94 | 95 | img_path = op.join(self.img_dir, anno['filename']) 96 | img_paths.append(img_path) 97 | img_id = anno['imgid'] 98 | img_ids.append(img_id) 99 | 100 | sentences = anno['sentences'] # caption list 101 | for sentence in sentences: 102 | captions.append(sentence['raw']) 103 | caption_pids.append(sentence['imgid']) 104 | dataset = { 105 | "image_pids": img_ids, 106 | "img_paths": img_paths, 107 | "caption_pids": caption_pids, 108 | "captions": captions 109 | } 110 | return dataset, pid_container 111 | 112 | 113 | def _check_before_run(self): 114 | """Check if all files are available before going deeper""" 115 | if not op.exists(self.dataset_dir): 116 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 117 | if not op.exists(self.img_dir): 118 | raise RuntimeError("'{}' is not available".format(self.img_dir)) 119 | if not op.exists(self.anno_path): 120 | raise RuntimeError("'{}' is not available".format(self.anno_path)) 121 | 122 | -------------------------------------------------------------------------------- /datasets/icfgpedes.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | from typing import List 3 | 4 | from 
utils.iotools import read_json 5 | from .bases import BaseDataset 6 | 7 | 8 | class ICFGPEDES(BaseDataset): 9 | """ 10 | ICFG-PEDES 11 | 12 | Reference: 13 | Semantically Self-Aligned Network for Text-to-Image Part-aware Person Re-identification arXiv 2107 14 | 15 | URL: http://arxiv.org/abs/2107.12666 16 | 17 | Dataset statistics: 18 | # identities: 4102 19 | # images: 34674 (train) + 4855 (query) + 14993 (gallery) 20 | # cameras: 15 21 | """ 22 | dataset_dir = 'ICFG-PEDES' 23 | 24 | def __init__(self, root='', nlp_aug=False, verbose=True): 25 | super(ICFGPEDES, self).__init__() 26 | self.dataset_dir = op.join(root, self.dataset_dir) 27 | self.img_dir = op.join(self.dataset_dir, 'imgs/') 28 | self.simg_dir = op.join(self.dataset_dir, 'imgs-sketch/') 29 | if nlp_aug: 30 | self.anno_path = op.join(self.dataset_dir, 'nlp_aug.json') 31 | else: 32 | self.anno_path = op.join(self.dataset_dir, 'ICFG-PEDES.json') 33 | self._check_before_run() 34 | 35 | self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path) 36 | 37 | self.train, self.train_id_container = self._process_anno(self.train_annos, training=True) 38 | self.test, self.test_id_container = self._process_anno(self.test_annos) 39 | self.val, self.val_id_container = self._process_anno(self.val_annos) 40 | 41 | if verbose: 42 | self.logger.info("=> CUHK-PEDES Images and Captions are loaded") 43 | self.show_dataset_info() 44 | 45 | 46 | def _split_anno(self, anno_path: str): 47 | train_annos, test_annos, val_annos = [], [], [] 48 | annos = read_json(anno_path) 49 | for anno in annos: 50 | if anno['split'] == 'train': 51 | train_annos.append(anno) 52 | elif anno['split'] == 'test': 53 | test_annos.append(anno) 54 | else: 55 | val_annos.append(anno) 56 | return train_annos, test_annos, val_annos 57 | 58 | 59 | def _process_anno(self, annos: List[dict], training=False): 60 | pid_container = set() 61 | if training: 62 | dataset = [] 63 | image_id = 0 64 | for anno in annos: 65 | pid = int(anno['id']) 66 | pid_container.add(pid) 67 | img_path = op.join(self.img_dir, anno['file_path']) 68 | captions = anno['captions'] # caption list 69 | simg_path = op.join(self.simg_dir, anno['file_path']) 70 | 71 | for caption in captions: 72 | dataset.append((pid, image_id, img_path, simg_path, caption)) 73 | image_id += 1 74 | for idx, pid in enumerate(pid_container): 75 | # check pid begin from 0 and no break 76 | assert idx == pid, f"idx: {idx} and pid: {pid} are not match" 77 | return dataset, pid_container 78 | else: 79 | dataset = {} 80 | img_paths = [] 81 | captions = [] 82 | image_pids = [] 83 | caption_pids = [] 84 | image_id = 0 85 | simage_ids = [] 86 | simage_pids = [] 87 | simg_paths = [] 88 | image_ids = [] 89 | 90 | for anno in annos: 91 | pid = int(anno['id']) 92 | pid_container.add(pid) 93 | img_path = op.join(self.img_dir, anno['file_path']) 94 | img_paths.append(img_path) 95 | image_pids.append(pid) 96 | caption_list = anno['captions'] # caption list 97 | image_ids.append(image_id) 98 | simg_path = op.join(self.simg_dir, anno['file_path']) 99 | 100 | for caption in caption_list: 101 | captions.append(caption) 102 | caption_pids.append(pid) 103 | simg_paths.append(simg_path) 104 | simage_ids.append(image_id) 105 | 106 | simage_pids.append(pid) 107 | image_id += 1 108 | 109 | dataset = { 110 | "image_pids": image_pids, 111 | "img_paths": img_paths, 112 | "image_ids": image_ids, 113 | "simage_pids": simage_pids, 114 | "simg_paths": simg_paths, 115 | "simage_ids": simage_ids, 116 | "caption_pids": caption_pids, 117 | 
"captions": captions 118 | } 119 | return dataset, pid_container 120 | 121 | 122 | def _check_before_run(self): 123 | """Check if all files are available before going deeper""" 124 | if not op.exists(self.dataset_dir): 125 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 126 | if not op.exists(self.img_dir): 127 | raise RuntimeError("'{}' is not available".format(self.img_dir)) 128 | if not op.exists(self.anno_path): 129 | raise RuntimeError("'{}' is not available".format(self.anno_path)) 130 | -------------------------------------------------------------------------------- /datasets/preprocessing.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | 5 | class RandomErasing(object): 6 | """ Randomly selects a rectangle region in an image and erases its pixels. 7 | 'Random Erasing Data Augmentation' by Zhong et al. 8 | See https://arxiv.org/pdf/1708.04896.pdf 9 | Args: 10 | probability: The probability that the Random Erasing operation will be performed. 11 | sl: Minimum proportion of erased area against input image. 12 | sh: Maximum proportion of erased area against input image. 13 | r1: Minimum aspect ratio of erased area. 14 | mean: Erasing value. 15 | """ 16 | 17 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 18 | self.probability = probability 19 | self.mean = mean 20 | self.sl = sl 21 | self.sh = sh 22 | self.r1 = r1 23 | 24 | def __call__(self, img): 25 | 26 | if random.uniform(0, 1) >= self.probability: 27 | return img 28 | 29 | for attempt in range(100): 30 | area = img.size()[1] * img.size()[2] 31 | 32 | target_area = random.uniform(self.sl, self.sh) * area 33 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 34 | 35 | h = int(round(math.sqrt(target_area * aspect_ratio))) 36 | w = int(round(math.sqrt(target_area / aspect_ratio))) 37 | 38 | if w < img.size()[2] and h < img.size()[1]: 39 | x1 = random.randint(0, img.size()[1] - h) 40 | y1 = random.randint(0, img.size()[2] - w) 41 | if img.size()[0] == 3: 42 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 43 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 44 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 45 | else: 46 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 47 | return img 48 | 49 | return img 50 | 51 | -------------------------------------------------------------------------------- /datasets/rstpreid.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | from typing import List 3 | 4 | from utils.iotools import read_json 5 | from .bases import BaseDataset 6 | 7 | 8 | class RSTPReid(BaseDataset): 9 | """ 10 | ICFG-PEDES 11 | 12 | Reference: 13 | Semantically Self-Aligned Network for Text-to-Image Part-aware Person Re-identification arXiv 2107 14 | 15 | URL: https://github.com/NjtechCVLab/RSTPReid-Dataset 16 | 17 | Dataset statistics: 18 | # identities: 4101, 3701 + 200 + 200 19 | # Each person has 5 corresponding images taken by different cameras with complex both indoor and outdoor scene, Each image is annotated with 2 textual descriptions. 
20 | # images: 18505 (train) + 1000 (val) + 1000 (text) 21 | # cameras: 15 22 | """ 23 | dataset_dir = 'RSTPReid' 24 | 25 | def __init__(self, root='', nlp_aug=False, verbose=True): 26 | super(RSTPReid, self).__init__() 27 | self.dataset_dir = op.join(root, self.dataset_dir) 28 | self.img_dir = op.join(self.dataset_dir, 'imgs/') 29 | self.simg_dir = op.join(self.dataset_dir, 'imgs-sketch/') 30 | if nlp_aug: 31 | ## not implement yet 32 | # self.anno_path = op.join(self.dataset_dir, 'nlp_aug.json') 33 | raise FileNotFoundError 34 | else: 35 | self.anno_path = op.join(self.dataset_dir, 'data_captions.json') 36 | self._check_before_run() 37 | 38 | self.train_annos, self.test_annos, self.val_annos = self._split_anno(self.anno_path) 39 | 40 | self.train, self.train_id_container = self._process_anno(self.train_annos, training=True) 41 | self.test, self.test_id_container = self._process_anno(self.test_annos) 42 | self.val, self.val_id_container = self._process_anno(self.val_annos) 43 | 44 | if verbose: 45 | self.logger.info("=> RSTPReid Images and Captions are loaded") 46 | self.show_dataset_info() 47 | 48 | 49 | def _split_anno(self, anno_path: str): 50 | train_annos, test_annos, val_annos = [], [], [] 51 | annos = read_json(anno_path) 52 | for anno in annos: 53 | if anno['split'] == 'train': 54 | train_annos.append(anno) 55 | elif anno['split'] == 'test': 56 | test_annos.append(anno) 57 | else: 58 | val_annos.append(anno) 59 | return train_annos, test_annos, val_annos 60 | 61 | 62 | def _process_anno(self, annos: List[dict], training=False): 63 | pid_container = set() 64 | if training: 65 | dataset = [] 66 | image_id = 0 67 | for anno in annos: 68 | pid = int(anno['id']) 69 | pid_container.add(pid) 70 | img_path = op.join(self.img_dir, anno['img_path']) 71 | captions = anno['captions'] # caption list 72 | simg_path = op.join(self.simg_dir, anno['img_path']) 73 | 74 | for caption in captions: 75 | dataset.append((pid, image_id, img_path, simg_path, caption)) 76 | image_id += 1 77 | for idx, pid in enumerate(pid_container): 78 | # check pid begin from 0 and no break 79 | assert idx == pid, f"idx: {idx} and pid: {pid} are not match" 80 | return dataset, pid_container 81 | else: 82 | dataset = {} 83 | img_paths = [] 84 | captions = [] 85 | image_pids = [] 86 | caption_pids = [] 87 | image_id = 0 88 | simage_ids = [] 89 | simage_pids = [] 90 | simg_paths = [] 91 | image_ids = [] 92 | 93 | for anno in annos: 94 | pid = int(anno['id']) 95 | pid_container.add(pid) 96 | img_path = op.join(self.img_dir, anno['img_path']) 97 | img_paths.append(img_path) 98 | image_pids.append(pid) 99 | caption_list = anno['captions'] # caption list 100 | image_ids.append(image_id) 101 | simg_path = op.join(self.simg_dir, anno['img_path']) 102 | 103 | for caption in caption_list: 104 | captions.append(caption) 105 | caption_pids.append(pid) 106 | simg_paths.append(simg_path) 107 | simage_ids.append(image_id) 108 | 109 | simage_pids.append(pid) 110 | image_id += 1 111 | 112 | dataset = { 113 | "image_pids": image_pids, 114 | "img_paths": img_paths, 115 | "image_ids": image_ids, 116 | "simage_pids": simage_pids, 117 | "simg_paths": simg_paths, 118 | "simage_ids": simage_ids, 119 | "caption_pids": caption_pids, 120 | "captions": captions 121 | } 122 | return dataset, pid_container 123 | 124 | 125 | def _check_before_run(self): 126 | """Check if all files are available before going deeper""" 127 | if not op.exists(self.dataset_dir): 128 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 129 | if not 
op.exists(self.img_dir): 130 | raise RuntimeError("'{}' is not available".format(self.img_dir)) 131 | if not op.exists(self.anno_path): 132 | raise RuntimeError("'{}' is not available".format(self.anno_path)) 133 | -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | 7 | class RandomIdentitySampler(Sampler): 8 | """ 9 | Randomly sample N identities, then for each identity, 10 | randomly sample K instances, therefore batch size is N*K. 11 | Args: 12 | - data_source (list): list of (img_path, pid, camid). 13 | - num_instances (int): number of instances per identity in a batch. 14 | - batch_size (int): number of examples in a batch. 15 | """ 16 | 17 | def __init__(self, data_source, batch_size, num_instances): 18 | self.data_source = data_source 19 | self.batch_size = batch_size 20 | self.num_instances = num_instances 21 | self.num_pids_per_batch = self.batch_size // self.num_instances 22 | self.index_dic = defaultdict(list) #dict with list value 23 | #{783: [0, 5, 116, 876, 1554, 2041],...,} 24 | for index, (pid, _, _, _) in enumerate(self.data_source): 25 | self.index_dic[pid].append(index) 26 | self.pids = list(self.index_dic.keys()) 27 | 28 | # estimate number of examples in an epoch 29 | self.length = 0 30 | for pid in self.pids: 31 | idxs = self.index_dic[pid] 32 | num = len(idxs) 33 | if num < self.num_instances: 34 | num = self.num_instances 35 | self.length += num - num % self.num_instances 36 | 37 | def __iter__(self): 38 | batch_idxs_dict = defaultdict(list) 39 | 40 | for pid in self.pids: 41 | idxs = copy.deepcopy(self.index_dic[pid]) 42 | if len(idxs) < self.num_instances: 43 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 44 | random.shuffle(idxs) 45 | batch_idxs = [] 46 | for idx in idxs: 47 | batch_idxs.append(idx) 48 | if len(batch_idxs) == self.num_instances: 49 | batch_idxs_dict[pid].append(batch_idxs) 50 | batch_idxs = [] 51 | 52 | avai_pids = copy.deepcopy(self.pids) 53 | final_idxs = [] 54 | 55 | while len(avai_pids) >= self.num_pids_per_batch: 56 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 57 | for pid in selected_pids: 58 | batch_idxs = batch_idxs_dict[pid].pop(0) 59 | final_idxs.extend(batch_idxs) 60 | if len(batch_idxs_dict[pid]) == 0: 61 | avai_pids.remove(pid) 62 | 63 | return iter(final_idxs) 64 | 65 | def __len__(self): 66 | return self.length 67 | 68 | -------------------------------------------------------------------------------- /datasets/sampler_ddp.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | import math 7 | import torch.distributed as dist 8 | _LOCAL_PROCESS_GROUP = None 9 | import torch 10 | import pickle 11 | 12 | def _get_global_gloo_group(): 13 | """ 14 | Return a process group based on gloo backend, containing all the ranks 15 | The result is cached. 
16 | """ 17 | if dist.get_backend() == "nccl": 18 | return dist.new_group(backend="gloo") 19 | else: 20 | return dist.group.WORLD 21 | 22 | def _serialize_to_tensor(data, group): 23 | backend = dist.get_backend(group) 24 | assert backend in ["gloo", "nccl"] 25 | device = torch.device("cpu" if backend == "gloo" else "cuda") 26 | 27 | buffer = pickle.dumps(data) 28 | if len(buffer) > 1024 ** 3: 29 | print( 30 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 31 | dist.get_rank(), len(buffer) / (1024 ** 3), device 32 | ) 33 | ) 34 | storage = torch.ByteStorage.from_buffer(buffer) 35 | tensor = torch.ByteTensor(storage).to(device=device) 36 | return tensor 37 | 38 | def _pad_to_largest_tensor(tensor, group): 39 | """ 40 | Returns: 41 | list[int]: size of the tensor, on each rank 42 | Tensor: padded tensor that has the max size 43 | """ 44 | world_size = dist.get_world_size(group=group) 45 | assert ( 46 | world_size >= 1 47 | ), "comm.gather/all_gather must be called from ranks within the given group!" 48 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 49 | size_list = [ 50 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 51 | ] 52 | dist.all_gather(size_list, local_size, group=group) 53 | size_list = [int(size.item()) for size in size_list] 54 | 55 | max_size = max(size_list) 56 | 57 | # we pad the tensor because torch all_gather does not support 58 | # gathering tensors of different shapes 59 | if local_size != max_size: 60 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 61 | tensor = torch.cat((tensor, padding), dim=0) 62 | return size_list, tensor 63 | 64 | def all_gather(data, group=None): 65 | """ 66 | Run all_gather on arbitrary picklable data (not necessarily tensors). 67 | Args: 68 | data: any picklable object 69 | group: a torch process group. By default, will use a group which 70 | contains all ranks on gloo backend. 71 | Returns: 72 | list[data]: list of data gathered from each rank 73 | """ 74 | if dist.get_world_size() == 1: 75 | return [data] 76 | if group is None: 77 | group = _get_global_gloo_group() 78 | if dist.get_world_size(group) == 1: 79 | return [data] 80 | 81 | tensor = _serialize_to_tensor(data, group) 82 | 83 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 84 | max_size = max(size_list) 85 | 86 | # receiving Tensor from all ranks 87 | tensor_list = [ 88 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 89 | ] 90 | dist.all_gather(tensor_list, tensor, group=group) 91 | 92 | data_list = [] 93 | for size, tensor in zip(size_list, tensor_list): 94 | buffer = tensor.cpu().numpy().tobytes()[:size] 95 | data_list.append(pickle.loads(buffer)) 96 | 97 | return data_list 98 | 99 | def shared_random_seed(): 100 | """ 101 | Returns: 102 | int: a random number that is the same across all workers. 103 | If workers need a shared RNG, they can use this shared seed to 104 | create one. 105 | All workers must call this function, otherwise it will deadlock. 106 | """ 107 | ints = np.random.randint(2 ** 31) 108 | all_ints = all_gather(ints) 109 | return all_ints[0] 110 | 111 | class RandomIdentitySampler_DDP(Sampler): 112 | """ 113 | Randomly sample N identities, then for each identity, 114 | randomly sample K instances, therefore batch size is N*K. 115 | Args: 116 | - data_source (list): list of (img_path, pid, camid). 117 | - num_instances (int): number of instances per identity in a batch. 
118 | - batch_size (int): number of examples in a batch. 119 | """ 120 | 121 | def __init__(self, data_source, batch_size, num_instances): 122 | self.data_source = data_source 123 | self.batch_size = batch_size 124 | self.world_size = dist.get_world_size() 125 | self.num_instances = num_instances 126 | self.mini_batch_size = self.batch_size // self.world_size 127 | self.num_pids_per_batch = self.mini_batch_size // self.num_instances 128 | self.index_dic = defaultdict(list) 129 | 130 | for index, (pid, _, _, _) in enumerate(self.data_source): 131 | self.index_dic[pid].append(index) 132 | self.pids = list(self.index_dic.keys()) 133 | 134 | # estimate number of examples in an epoch 135 | self.length = 0 136 | for pid in self.pids: 137 | idxs = self.index_dic[pid] 138 | num = len(idxs) 139 | if num < self.num_instances: 140 | num = self.num_instances 141 | self.length += num - num % self.num_instances 142 | 143 | self.rank = dist.get_rank() 144 | #self.world_size = dist.get_world_size() 145 | self.length //= self.world_size 146 | 147 | def __iter__(self): 148 | seed = shared_random_seed() 149 | np.random.seed(seed) 150 | self._seed = int(seed) 151 | final_idxs = self.sample_list() 152 | length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size)) 153 | #final_idxs = final_idxs[self.rank * length:(self.rank + 1) * length] 154 | final_idxs = self.__fetch_current_node_idxs(final_idxs, length) 155 | self.length = len(final_idxs) 156 | return iter(final_idxs) 157 | 158 | 159 | def __fetch_current_node_idxs(self, final_idxs, length): 160 | total_num = len(final_idxs) 161 | block_num = (length // self.mini_batch_size) 162 | index_target = [] 163 | for i in range(0, block_num * self.world_size, self.world_size): 164 | index = range(self.mini_batch_size * self.rank + self.mini_batch_size * i, min(self.mini_batch_size * self.rank + self.mini_batch_size * (i+1), total_num)) 165 | index_target.extend(index) 166 | index_target_npy = np.array(index_target) 167 | final_idxs = list(np.array(final_idxs)[index_target_npy]) 168 | return final_idxs 169 | 170 | 171 | def sample_list(self): 172 | #np.random.seed(self._seed) 173 | avai_pids = copy.deepcopy(self.pids) 174 | batch_idxs_dict = {} 175 | 176 | batch_indices = [] 177 | while len(avai_pids) >= self.num_pids_per_batch: 178 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist() 179 | for pid in selected_pids: 180 | if pid not in batch_idxs_dict: 181 | idxs = copy.deepcopy(self.index_dic[pid]) 182 | if len(idxs) < self.num_instances: 183 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist() 184 | np.random.shuffle(idxs) 185 | batch_idxs_dict[pid] = idxs 186 | 187 | avai_idxs = batch_idxs_dict[pid] 188 | for _ in range(self.num_instances): 189 | batch_indices.append(avai_idxs.pop(0)) 190 | 191 | if len(avai_idxs) < self.num_instances: avai_pids.remove(pid) 192 | 193 | return batch_indices 194 | 195 | def __len__(self): 196 | return self.length 197 | 198 | -------------------------------------------------------------------------------- /merge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as op 3 | import torch 4 | import numpy as np 5 | import random 6 | import time 7 | 8 | from datasets import build_dataloader 9 | from processor.processor import do_train 10 | from utils.checkpoint import Checkpointer 11 | from utils.iotools import save_train_configs 12 | from utils.logger import setup_logger 13 | from solver import 
build_optimizer, build_lr_scheduler 14 | from model import build_model 15 | from utils.metrics import Evaluator 16 | from utils.options import get_args 17 | from utils.comm import get_rank, synchronize 18 | 19 | 20 | import logging 21 | import torch 22 | import torchvision.transforms as T 23 | from torch.utils.data import DataLoader 24 | from datasets.sampler import RandomIdentitySampler 25 | from datasets.sampler_ddp import RandomIdentitySampler_DDP 26 | from torch.utils.data.distributed import DistributedSampler 27 | 28 | from utils.comm import get_world_size 29 | from datasets.cuhkpedes_merge import CUHKPEDES_M 30 | from datasets.icfgpedes import ICFGPEDES 31 | from datasets.bases import ImageDataset, SketchDataset, ImageTextMSMDataset, ImageTextMSMMLMDataset, TextDataset, ImageTextDataset, ImageTextMCQDataset, ImageTextMaskColorDataset, ImageTextMLMDataset, ImageTextMCQMLMDataset, SketchTextDataset 32 | 33 | root= '/data0/data_ccq/CUHK-PEDES/' 34 | dataset = CUHKPEDES_M(root=root) 35 | 36 | print('finished!') 37 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model 2 | -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/model/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/clip_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/model/__pycache__/clip_model.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/objectives.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/model/__pycache__/objectives.cpython-38.pyc -------------------------------------------------------------------------------- /model/build.py: -------------------------------------------------------------------------------- 1 | from model import objectives 2 | from .clip_model import Transformer, QuickGELU, LayerNorm, build_CLIP_from_openai_pretrained, convert_weights 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from collections import OrderedDict 7 | 8 | class FM_cat(nn.Module): 9 | def __init__(self,in_channels): 10 | super(FM_cat, self).__init__() 11 | 12 | self.W = nn.Sequential( 13 | nn.Conv2d(in_channels * 2, in_channels, 14 | kernel_size=1, stride=1, padding=0, bias=True), 15 | nn.BatchNorm2d(in_channels) 16 | ) 17 | nn.init.normal_(self.W[1].weight.data, 1.0, 0.01) 18 | nn.init.zeros_(self.W[1].bias.data) 19 | 20 | 21 | # self.bottleneck = nn.BatchNorm1d(in_channels) 22 | # self.bottleneck.bias.requires_grad_(False) # no shift 23 | 24 | # 
nn.init.normal_(self.bottleneck.weight.data, 1.0, 0.01) 25 | # nn.init.zeros_(self.bottleneck.bias.data) 26 | 27 | def forward(self,f): 28 | 29 | f = f.view(f.size(0),f.size(1),1,1) 30 | f = self.W(f) 31 | f = f.view(f.size(0),-1) 32 | # f = self.bottleneck(f+feat) 33 | 34 | return f 35 | 36 | class CLIP2ReID(nn.Module): 37 | def __init__(self, args, num_classes=11003): 38 | super().__init__() 39 | self.args = args 40 | self.num_classes = num_classes 41 | self.test_setting = args.test_setting 42 | 43 | self._set_task() 44 | 45 | self.base_model, base_cfg = build_CLIP_from_openai_pretrained(args.pretrain_choice, args.img_size, args.stride_size) 46 | self.embed_dim = base_cfg['embed_dim'] 47 | 48 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / args.temperature)) # 0.07 49 | # self.logit_scale = torch.ones([]) * np.log(1 / args.temperature) # 0.07 50 | 51 | if args.fusion_way == 'weight add': 52 | self.gate = nn.Parameter(torch.FloatTensor(2)) 53 | nn.init.constant_(self.gate, 0.5) 54 | if args.fusion_way == 'concat': 55 | scale = 512**-0.5 56 | proj_std = scale * ((2 * 4)**-0.5) 57 | self.dim_conv = nn.Linear(512*2,512) 58 | nn.init.normal_(self.dim_conv.weight, std=proj_std) 59 | 60 | if args.fusion_way == 'global concat': 61 | self.dim_conv = nn.Linear(512*2,512) 62 | self.global_attn_s = nn.MultiheadAttention(self.embed_dim, 63 | self.embed_dim // 64, 64 | batch_first=True) 65 | self.global_attn_t = nn.MultiheadAttention(self.embed_dim, 66 | self.embed_dim // 64, 67 | batch_first=True) 68 | # init cross attn 69 | scale = 512**-0.5 70 | proj_std = scale * ((2 * 4)**-0.5) 71 | attn_std = scale 72 | fc_std = (2 * 512)**-0.5 73 | nn.init.normal_(self.global_attn_s.in_proj_weight, std=attn_std) 74 | nn.init.normal_(self.global_attn_s.out_proj.weight, std=proj_std) 75 | 76 | # init cross attn 77 | nn.init.normal_(self.global_attn_t.in_proj_weight, std=attn_std) 78 | nn.init.normal_(self.global_attn_t.out_proj.weight, std=proj_std) 79 | 80 | nn.init.normal_(self.dim_conv.weight, std=proj_std) 81 | 82 | 83 | if 'concat' in args.fusion_way: 84 | self.cross_modal_transformer = Transformer(width=self.embed_dim, 85 | layers=args.cmt_depth, 86 | heads=self.embed_dim // 87 | 64) 88 | 89 | if 'cross' in args.fusion_way: 90 | self.cross_attn = nn.MultiheadAttention(self.embed_dim, 91 | self.embed_dim // 64, 92 | batch_first=True) 93 | self.cross_modal_transformer = Transformer(width=self.embed_dim, 94 | layers=args.cmt_depth, 95 | heads=self.embed_dim // 96 | 64) 97 | scale = self.cross_modal_transformer.width**-0.5 98 | # self.pos_embedding = nn.Parameter(scale * torch.randn(self.embed_dim)) 99 | 100 | self.ln_pre_t = LayerNorm(self.embed_dim) 101 | self.ln_pre_i = LayerNorm(self.embed_dim) 102 | self.ln_post = LayerNorm(self.embed_dim) 103 | 104 | proj_std = scale * ((2 * self.cross_modal_transformer.layers)**-0.5) 105 | attn_std = scale 106 | fc_std = (2 * self.cross_modal_transformer.width)**-0.5 107 | for block in self.cross_modal_transformer.resblocks: 108 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 109 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 110 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 111 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 112 | 113 | # init cross attn 114 | nn.init.normal_(self.cross_attn.in_proj_weight, std=attn_std) 115 | nn.init.normal_(self.cross_attn.out_proj.weight, std=proj_std) 116 | 117 | if 'id' in args.loss_names: 118 | self.classifier = nn.Linear(self.embed_dim, self.num_classes) 119 | 
nn.init.normal_(self.classifier.weight.data, std=0.001) 120 | nn.init.constant_(self.classifier.bias.data, val=0.0) 121 | 122 | if 'mcm' in args.loss_names or 'mcq' in args.loss_names or 'mlm' in args.loss_names or 'msm' in args.loss_names: 123 | self.cross_attn = nn.MultiheadAttention(self.embed_dim, 124 | self.embed_dim // 64, 125 | batch_first=True) 126 | self.cross_modal_transformer = Transformer(width=self.embed_dim, 127 | layers=args.cmt_depth, 128 | heads=self.embed_dim // 129 | 64) 130 | scale = self.cross_modal_transformer.width**-0.5 131 | # self.pos_embedding = nn.Parameter(scale * torch.randn(self.embed_dim)) 132 | 133 | self.ln_pre_t = LayerNorm(self.embed_dim) 134 | self.ln_pre_i = LayerNorm(self.embed_dim) 135 | self.ln_post = LayerNorm(self.embed_dim) 136 | 137 | proj_std = scale * ((2 * self.cross_modal_transformer.layers)**-0.5) 138 | attn_std = scale 139 | fc_std = (2 * self.cross_modal_transformer.width)**-0.5 140 | for block in self.cross_modal_transformer.resblocks: 141 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 142 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 143 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 144 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 145 | 146 | # init cross attn 147 | nn.init.normal_(self.cross_attn.in_proj_weight, std=attn_std) 148 | nn.init.normal_(self.cross_attn.out_proj.weight, std=proj_std) 149 | 150 | # mcm 151 | if 'mcm' in args.loss_names: 152 | self.mcm_head = nn.Sequential( 153 | OrderedDict([('dense', nn.Linear(self.embed_dim, self.embed_dim)), 154 | ('gelu', QuickGELU()), 155 | ('ln', LayerNorm(self.embed_dim)), 156 | ('fc', nn.Linear(self.embed_dim, args.num_colors))])) 157 | # self.mcm_head = nn.Linear(self.embed_dim, args.num_colors) 158 | # init mcm head 159 | nn.init.normal_(self.mcm_head.dense.weight, std=fc_std) 160 | nn.init.normal_(self.mcm_head.fc.weight, std=proj_std) 161 | 162 | # mcq 163 | # if 'mcq' in args.loss_names: 164 | # self.mcq_proj = nn.Parameter(scale * torch.randn(self.cross_modal_transformer.width, self.embed_dim)) 165 | 166 | # TODO mlm 167 | if 'mlm' in args.loss_names: 168 | self.mlm_head = nn.Sequential( 169 | OrderedDict([('dense', nn.Linear(self.embed_dim, self.embed_dim)), 170 | ('gelu', QuickGELU()), 171 | ('ln', LayerNorm(self.embed_dim)), 172 | ('fc', nn.Linear(self.embed_dim, args.vocab_size))])) 173 | # self.mlm_head = nn.Linear(self.embed_dim, args.num_colors) 174 | # init mlm head 175 | nn.init.normal_(self.mlm_head.dense.weight, std=fc_std) 176 | nn.init.normal_(self.mlm_head.fc.weight, std=proj_std) 177 | 178 | def _set_task(self): 179 | loss_names = self.args.loss_names 180 | self.current_task = [l.strip() for l in loss_names.split('+')] 181 | print(f'Training Model with {self.current_task} tasks') 182 | 183 | def cross_former(self, q, k, v): 184 | x = self.cross_attn( 185 | self.ln_pre_t(q), 186 | self.ln_pre_i(k), 187 | self.ln_pre_i(v), 188 | need_weights=False)[0] 189 | # x = q + x # residual connection (invalid for mcq and mcqmlm, valid for mlm) 190 | # x = x.permute(1, 0, 2) # NLD -> LND 191 | # x = self.cross_modal_transformer(x) 192 | # x = x.permute(1, 0, 2) # LND -> NLD 193 | 194 | x = self.ln_post(x) 195 | return x 196 | 197 | def global_former_s(self, q, k, v): 198 | x = self.global_attn_s( 199 | self.ln_pre_t(q), 200 | self.ln_pre_i(k), 201 | self.ln_pre_i(v), 202 | need_weights=False)[0] 203 | x = self.ln_post(x) 204 | return x 205 | 206 | def global_former_t(self, q, k, v): 207 | x = self.global_attn_t( 208 | 
self.ln_pre_t(q), 209 | self.ln_pre_i(k), 210 | self.ln_pre_i(v), 211 | need_weights=False)[0] 212 | x = self.ln_post(x) 213 | return x 214 | 215 | def encode_image(self, image): 216 | x = self.base_model.encode_image(image) 217 | return x 218 | # return x.float() # for CLIP ResNet visual model 219 | 220 | def encode_text(self, text): 221 | x = self.base_model.encode_text(text) 222 | return x #[torch.arange(x.shape[0]), text.argmax(dim=-1)].float() 223 | 224 | def fusion_layer(self, text, sketch, caption_ids, pa=0.1, way='add'): 225 | 226 | if way == 'weight add': 227 | f_feats = self.gate[0] * text[torch.arange(text.shape[0]), caption_ids.argmax(dim=-1)] + self.gate[1] * sketch[:, 0, :] 228 | elif way == 'cross attention': 229 | f_feats = (self.cross_former(text[torch.arange(text.shape[0]), caption_ids.argmax(dim=-1)].unsqueeze(1),sketch,sketch) + self.cross_former(sketch[:,0,:].unsqueeze(1),text,text)) 230 | f_feats = f_feats.squeeze(1).contiguous() 231 | elif way == 'cross attention text': 232 | # f_feats = (self.cross_former(text,sketch,sketch)[:, 0, :] + self.cross_former(sketch,text,text)[torch.arange(text.shape[0]), caption_ids.argmax(dim=-1)]) 233 | f_feats = self.cross_former(sketch[:,0,:].unsqueeze(1),text,text).squeeze(1).contiguous() 234 | elif way == 'cross attention sketch': 235 | # f_feats = (self.cross_former(text,sketch,sketch)[:, 0, :] + self.cross_former(sketch,text,text)[torch.arange(text.shape[0]), caption_ids.argmax(dim=-1)]) 236 | f_feats = self.cross_former(text[torch.arange(text.shape[0]), caption_ids.argmax(dim=-1)].unsqueeze(1),sketch,sketch).squeeze(1).contiguous() 237 | elif way == 'parameter add': 238 | f_feats = (1-pa)*text[torch.arange(text.shape[0]), caption_ids.argmax(dim=-1)] + pa*sketch[:, 0, :] 239 | elif way == 'concat': 240 | f_feats = self.dim_conv(torch.cat((text[torch.arange(text.shape[0]), caption_ids.argmax(dim=-1)], sketch[:, 0, :]),dim=1)) 241 | elif way == 'global concat': 242 | s_global = self.global_former_s(sketch[:,0,:,],sketch[:,1:,:],sketch[:,1:,:]) 243 | eos_indices = caption_ids.argmax(dim=-1) 244 | t_globel = text[torch.arange(text.shape[0]), eos_indices] 245 | text[torch.arange(text.shape[0]), eos_indices] = 0 246 | t_local = text 247 | t_global = self.global_former_t(t_globel,t_local, t_local) 248 | # f_feats = self.dim_conv(torch.cat((, caption_ids.argmax(dim=-1)], sketch[:, 0, :]),dim=1)) 249 | elif way == 'concat transformer': 250 | l_t = text.size(1) 251 | f_feats = self.cross_modal_transformer(torch.cat((text,sketch),dim=1)) 252 | f_feats = f_feats[:,l_t:,:][:, 0, :] + f_feats[:,:l_t,:][torch.arange(text.shape[0]), caption_ids.argmax(dim=-1)] 253 | elif way == 'concat transformer-s': 254 | l_t = text.size(1) 255 | f_feats = self.cross_modal_transformer(torch.cat((text,sketch),dim=1)) 256 | f_feats = f_feats[:,l_t:,:][:, 0, :] 257 | elif way == 'concat transformer-t': 258 | l_t = text.size(1) 259 | f_feats = self.cross_modal_transformer(torch.cat((text,sketch),dim=1)) 260 | f_feats = f_feats[:,:l_t,:][torch.arange(text.shape[0]), caption_ids.argmax(dim=-1)] 261 | else: 262 | 263 | f_feats = text[torch.arange(text.shape[0]), caption_ids.argmax(dim=-1)] + sketch[:, 0, :] 264 | 265 | return f_feats.float() 266 | 267 | def forward(self, batch): 268 | ret = dict() 269 | images = batch['images'] 270 | caption_ids = batch['caption_ids'] 271 | simages = batch['simages'] 272 | label = batch['pids'] 273 | 274 | image_feats, text_feats = self.base_model(torch.cat((images,simages),dim=0), caption_ids) 275 | b = image_feats.size(0) 276 | 
simage_feats = image_feats[int(b/2):,:,:] # [64, 193, 512] text:[64, 77, 512] 277 | image_feats = image_feats[:int(b/2),:,:] 278 | 279 | logit_scale = self.logit_scale.exp() 280 | ret.update({'temperature': 1 / logit_scale}) 281 | 282 | if self.args.only_sketch: 283 | i_feats = image_feats[:, 0, :].float() 284 | si_feats = simage_feats[:, 0, :].float() 285 | ret.update({'itc_loss':(objectives.compute_itc(i_feats, si_feats, logit_scale))*self.args.cmm_loss_weight}) 286 | elif self.args.only_text: 287 | i_feats = image_feats[:, 0, :].float() 288 | t_feats = text_feats[torch.arange(text_feats.shape[0]), caption_ids.argmax(dim=-1)].float() #[64, 512] 289 | ret.update({'itc_loss':(objectives.compute_itc(i_feats, t_feats, logit_scale))*self.args.cmm_loss_weight}) 290 | # elif self.args.only_fusion: 291 | # i_feats = image_feats[:, 0, :].float() 292 | # si_feats = simage_feats[:, 0, :].float() 293 | # t_feats = text_feats[torch.arange(text_feats.shape[0]), caption_ids.argmax(dim=-1)].float() #[64, 512] 294 | # f_feats = t_feats + si_feats 295 | # ret.update({'itc_loss':(objectives.compute_itc(i_feats, f_feats, logit_scale))*self.args.cmm_loss_weight}) 296 | else: 297 | if self.args.fusion_way in ['add', 'weight add', 'cross attention', 'parameter add', 'concat', 'global concat', 'cross attention text', 'cross attention sketch', 'concat transformer']: 298 | f_feats = self.fusion_layer(text_feats, simage_feats, caption_ids, pa=self.args.pa, way=self.args.fusion_way) 299 | i_feats = image_feats[:, 0, :].float() 300 | si_feats = simage_feats[:, 0, :].float() 301 | t_feats = text_feats[torch.arange(text_feats.shape[0]), caption_ids.argmax(dim=-1)].float() #[64, 512] 302 | if self.args.only_fusion_loss: 303 | ret.update({'itc_loss':(objectives.compute_itc(i_feats, f_feats, logit_scale))*self.args.cmm_loss_weight}) 304 | elif self.args.four_fusion_loss: 305 | ret.update({'itc_loss':(objectives.compute_itc(i_feats, t_feats, logit_scale) + objectives.compute_itc(i_feats, si_feats, logit_scale) + objectives.compute_itc(i_feats, f_feats, logit_scale)+objectives.compute_itc(si_feats, t_feats, logit_scale))*self.args.cmm_loss_weight}) 306 | elif self.args.focal_three_fusion_loss3: 307 | ret.update({'itc_loss':(objectives.compute_itc_focal3(i_feats, t_feats, si_feats, f_feats, logit_scale, self.args.al, self.args.ga, self.args.klp))*self.args.cmm_loss_weight}) 308 | else: 309 | ret.update({'itc_loss':(objectives.compute_itc(i_feats, t_feats, logit_scale) + objectives.compute_itc(i_feats, si_feats, logit_scale) + objectives.compute_itc(i_feats, f_feats, logit_scale))*self.args.cmm_loss_weight}) 310 | else: 311 | i_feats = image_feats[:, 0, :].float() 312 | si_feats = simage_feats[:, 0, :].float() 313 | t_feats = text_feats[torch.arange(text_feats.shape[0]), caption_ids.argmax(dim=-1)].float() #[64, 512] 314 | ret.update({'itc_loss':(objectives.compute_itc(i_feats, t_feats, logit_scale)+objectives.compute_itc(i_feats, si_feats, logit_scale))*self.args.cmm_loss_weight}) 315 | 316 | return ret 317 | 318 | # if 'tcmpm' in self.current_task: 319 | # if self.args.use_imageid: 320 | # ret.update({'tcmpm_loss':objectives.compute_tcmpm(i_feats, t_feats, batch['pids'], logit_scale, image_id=batch['image_ids'])*self.args.cmm_loss_weight}) 321 | # else: 322 | # ret.update({'tcmpm_loss':objectives.compute_tcmpm(i_feats, t_feats, batch['pids'], logit_scale)*self.args.cmm_loss_weight}) 323 | 324 | # if 'itc' in self.current_task: 325 | # ret.update({'itc_loss':(objectives.compute_itc(i_feats, t_feats, 
logit_scale)+objectives.compute_itc(i_feats, si_feats, logit_scale)+objectives.compute_itc(i_feats, f_feats, logit_scale))*self.args.cmm_loss_weight}) 326 | 327 | # if 'sdm' in self.current_task: 328 | # ret.update({'sdm_loss':objectives.compute_sdm(i_feats, t_feats, batch['pids'], logit_scale)*self.args.cmm_loss_weight}) 329 | 330 | # if 'cmpm' in self.current_task: 331 | # ret.update({'cmpm_loss':objectives.compute_cmpm(i_feats, t_feats, batch['pids'])*self.args.cmm_loss_weight}) 332 | 333 | # if 'supcon' in self.current_task: 334 | # bs, d = i_feats.size() 335 | # i_feats = i_feats.view(-1, self.args.num_instance, d) 336 | # si_feats = si_feats.view(-1, self.args.num_instance, d) 337 | # t_feats = t_feats.view(-1, self.args.num_instance, d) 338 | # f_feats = f_feats.view(-1, self.args.num_instance, d) 339 | # label = label.view(-1, self.args.num_instance)[:,0] 340 | 341 | # ret.update({'supcon_loss':(objectives.SupConLoss(torch.cat((i_feats, t_feats),dim=1), label)+objectives.SupConLoss(torch.cat((i_feats, si_feats),dim=1), label)+objectives.SupConLoss(torch.cat((i_feats, f_feats),dim=1), label))*self.args.cmm_loss_weight}) 342 | 343 | # if 'mcm' in self.current_task: 344 | # masked_caption_ids = batch['masked_caption_ids'] 345 | # # with torch.no_grad(): 346 | # masked_caption_feats = self.base_model.encode_text(masked_caption_ids) 347 | 348 | # x = self.cross_former(masked_caption_feats, image_feats, image_feats) 349 | 350 | # x = self.mcm_head(x) # [batch_size, text_len, num_colors] 351 | 352 | # scores = x.float().reshape(-1, self.args.num_colors) 353 | # mcm_labels = batch['mcm_labels'].reshape(-1) 354 | # ret.update({'mcm_loss': objectives.compute_mcm_or_mlm(scores, mcm_labels)*self.args.mcm_loss_weight}) 355 | 356 | # pred = scores.max(1)[1] 357 | # mcm_label_idx = torch.nonzero(mcm_labels) 358 | # acc = (pred[mcm_label_idx] == mcm_labels[mcm_label_idx]).float().mean() 359 | # ret.update({'acc': acc}) 360 | 361 | # if 'mlm' in self.current_task: 362 | # mlm_ids = batch['mlm_ids'] 363 | 364 | # mlm_feats = self.base_model.encode_text(mlm_ids) 365 | 366 | # x = self.cross_former(mlm_feats, image_feats, image_feats) 367 | 368 | # x = self.mlm_head(x) # [batch_size, text_len, num_colors] 369 | 370 | # scores = x.float().reshape(-1, self.args.vocab_size) 371 | # mlm_labels = batch['mlm_labels'].reshape(-1) 372 | # ret.update({'mlm_loss': objectives.compute_mcm_or_mlm(scores, mlm_labels)*self.args.mlm_loss_weight}) 373 | 374 | # pred = scores.max(1)[1] 375 | # mlm_label_idx = torch.nonzero(mlm_labels) 376 | # acc = (pred[mlm_label_idx] == mlm_labels[mlm_label_idx]).float().mean() 377 | # ret.update({'acc': acc}) 378 | 379 | # if 'mcq' in self.current_task or 'msm' in self.current_task: 380 | # question_ids = batch['question_ids'] 381 | # answer_ids = batch['answer_ids'] 382 | 383 | # question_feats = self.base_model.encode_text(question_ids) 384 | # answer_feats = self.encode_text(answer_ids) 385 | 386 | # x = self.cross_former(question_feats, image_feats, image_feats) 387 | 388 | # # x = x @ self.mcq_proj 389 | 390 | # pred_answer_feats = x[torch.arange(x.shape[0]), question_ids.argmax(dim=-1)].float() 391 | # ret.update({'mcq_loss': objectives.compute_mcq(pred_answer_feats, answer_feats)*self.args.mcq_loss_weight}) 392 | 393 | # return ret 394 | 395 | 396 | def build_model(args, num_classes=11003): 397 | model = CLIP2ReID(args, num_classes) 398 | # covert model to fp16 399 | convert_weights(model) 400 | return model 
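A minimal usage sketch of the build_model()/CLIP2ReID pipeline defined above, assuming get_args() from utils/options.py supplies the argument fields the model reads (pretrain_choice, img_size, stride_size, fusion_way, loss_names, pa, al, ga, klp, cmm_loss_weight, ...). The batch keys mirror what CLIP2ReID.forward() expects; the tensor shapes, vocabulary size and identity count below are illustrative assumptions, not values fixed by the repository.

import torch
from utils.options import get_args
from model import build_model

args = get_args()
model = build_model(args, num_classes=11003).cuda()   # weights are cast to fp16 by convert_weights

batch = {
    'images':      torch.randn(64, 3, 384, 128).cuda(),       # RGB person crops
    'simages':     torch.randn(64, 3, 384, 128).cuda(),       # paired sketch images
    'caption_ids': torch.randint(1, 49408, (64, 77)).cuda(),  # tokenized captions (assumed CLIP BPE vocab size)
    'pids':        torch.randint(0, 11003, (64,)).cuda(),     # identity labels
}

ret = model(batch)                                     # e.g. {'temperature': ..., 'itc_loss': ...}
loss = sum(v for k, v in ret.items() if 'loss' in k)
loss.backward()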
-------------------------------------------------------------------------------- /model/clip_model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 3 | """ 4 | from collections import OrderedDict 5 | import logging 6 | import math 7 | import os 8 | from typing import List, Tuple, Union 9 | import hashlib 10 | import urllib 11 | from tqdm import tqdm 12 | import warnings 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import nn 17 | 18 | 19 | logger = logging.getLogger("CLIP2ReID.model") 20 | 21 | _MODELS = { 22 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 23 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 24 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 25 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 26 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 27 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 28 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 29 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 30 | } 31 | 32 | def available_models() -> List[str]: 33 | """Returns the names of available CLIP models""" 34 | return list(_MODELS.keys()) 35 | 36 | def _download(url: str, root: str): 37 | os.makedirs(root, exist_ok=True) 38 | filename = os.path.basename(url) 39 | 40 | expected_sha256 = url.split("/")[-2] 41 | download_target = os.path.join(root, filename) 42 | 43 | if os.path.exists(download_target) and not os.path.isfile(download_target): 44 | raise RuntimeError(f"{download_target} exists and is not a regular file") 45 | 46 | if os.path.isfile(download_target): 47 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 48 | return download_target 49 | else: 50 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 51 | 52 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 53 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 54 | while True: 55 | buffer = source.read(8192) 56 | if not buffer: 57 | break 58 | 59 | output.write(buffer) 60 | loop.update(len(buffer)) 61 | 62 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 63 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 64 | 65 | return download_target 66 | 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | 71 | def __init__(self, inplanes, planes, stride=1): 72 | super().__init__() 73 | 74 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 75 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 76 | self.bn1 = nn.BatchNorm2d(planes) 77 | 78 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 79 | self.bn2 = nn.BatchNorm2d(planes) 80 | 81 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 82 | 83 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 84 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 85 | 86 | self.relu = nn.ReLU(inplace=True) 87 | self.downsample = None 88 | self.stride = stride 89 | 90 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 91 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 92 | self.downsample = nn.Sequential(OrderedDict([ 93 | ("-1", nn.AvgPool2d(stride)), 94 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 95 | ("1", nn.BatchNorm2d(planes * self.expansion)) 96 | ])) 97 | 98 | def forward(self, x: torch.Tensor): 99 | identity = x 100 | 101 | out = self.relu(self.bn1(self.conv1(x))) 102 | out = self.relu(self.bn2(self.conv2(out))) 103 | out = self.avgpool(out) 104 | out = self.bn3(self.conv3(out)) 105 | 106 | if self.downsample is not None: 107 | identity = self.downsample(x) 108 | 109 | out += identity 110 | out = self.relu(out) 111 | return out 112 | 113 | 114 | class AttentionPool2d(nn.Module): 115 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 116 | super().__init__() 117 | # self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 118 | self.positional_embedding = nn.Parameter(torch.randn((spacial_dim[0] * spacial_dim[1]) + 1, embed_dim)/ embed_dim ** 0.5) 119 | self.k_proj = nn.Linear(embed_dim, embed_dim) 120 | self.q_proj = nn.Linear(embed_dim, embed_dim) 121 | self.v_proj = nn.Linear(embed_dim, embed_dim) 122 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 123 | self.num_heads = num_heads 124 | 125 | def forward(self, x): 126 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 127 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 128 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 129 | x, _ = F.multi_head_attention_forward( 130 | query=x, key=x, value=x, 131 | embed_dim_to_check=x.shape[-1], 132 | num_heads=self.num_heads, 133 | q_proj_weight=self.q_proj.weight, 134 | k_proj_weight=self.k_proj.weight, 135 | v_proj_weight=self.v_proj.weight, 136 | in_proj_weight=None, 137 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 138 | bias_k=None, 139 | bias_v=None, 140 | add_zero_attn=False, 141 | dropout_p=0, 142 | out_proj_weight=self.c_proj.weight, 143 | out_proj_bias=self.c_proj.bias, 144 | use_separate_proj_weight=True, 145 | training=self.training, 146 | need_weights=False 147 | ) 148 | 149 | return x[0] 150 | 151 | 152 | class ModifiedResNet(nn.Module): 153 | """ 154 | A ResNet class that is similar to torchvision's but contains the following changes: 155 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
156 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 157 | - The final pooling layer is a QKV attention instead of an average pool 158 | """ 159 | 160 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 161 | super().__init__() 162 | self.output_dim = output_dim 163 | self.input_resolution = input_resolution 164 | 165 | # the 3-layer stem 166 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 167 | self.bn1 = nn.BatchNorm2d(width // 2) 168 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 169 | self.bn2 = nn.BatchNorm2d(width // 2) 170 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 171 | self.bn3 = nn.BatchNorm2d(width) 172 | self.avgpool = nn.AvgPool2d(2) 173 | self.relu = nn.ReLU(inplace=True) 174 | 175 | # residual layers 176 | self._inplanes = width # this is a *mutable* variable used during construction 177 | self.layer1 = self._make_layer(width, layers[0]) 178 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 179 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 180 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 181 | 182 | embed_dim = width * 32 # the ResNet feature dimension 183 | spacial_dim = ( 184 | input_resolution[0] // 32, 185 | input_resolution[1] // 32, 186 | ) 187 | self.attnpool = AttentionPool2d(spacial_dim, embed_dim, heads, output_dim) 188 | 189 | def _make_layer(self, planes, blocks, stride=1): 190 | layers = [Bottleneck(self._inplanes, planes, stride)] 191 | 192 | self._inplanes = planes * Bottleneck.expansion 193 | for _ in range(1, blocks): 194 | layers.append(Bottleneck(self._inplanes, planes)) 195 | 196 | return nn.Sequential(*layers) 197 | 198 | def forward(self, x): 199 | def stem(x): 200 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 201 | x = self.relu(bn(conv(x))) 202 | x = self.avgpool(x) 203 | return x 204 | 205 | x = x.type(self.conv1.weight.dtype) 206 | x = stem(x) 207 | x = self.layer1(x) 208 | x = self.layer2(x) 209 | x = self.layer3(x) 210 | x = self.layer4(x) 211 | x = self.attnpool(x) 212 | 213 | return x 214 | 215 | 216 | class LayerNorm(nn.LayerNorm): 217 | """Subclass torch's LayerNorm to handle fp16.""" 218 | 219 | def forward(self, x: torch.Tensor): 220 | orig_type = x.dtype 221 | ret = super().forward(x.type(torch.float32)) 222 | return ret.type(orig_type) 223 | 224 | 225 | class QuickGELU(nn.Module): 226 | def forward(self, x: torch.Tensor): 227 | return x * torch.sigmoid(1.702 * x) 228 | 229 | 230 | class ResidualAttentionBlock(nn.Module): 231 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 232 | super().__init__() 233 | 234 | self.attn = nn.MultiheadAttention(d_model, n_head) 235 | self.ln_1 = LayerNorm(d_model) 236 | self.mlp = nn.Sequential(OrderedDict([ 237 | ("c_fc", nn.Linear(d_model, d_model * 4)), 238 | ("gelu", QuickGELU()), 239 | ("c_proj", nn.Linear(d_model * 4, d_model)) 240 | ])) 241 | self.ln_2 = LayerNorm(d_model) 242 | self.attn_mask = attn_mask 243 | 244 | def attention(self, x: torch.Tensor): 245 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 246 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 247 | 248 | def forward(self, x: torch.Tensor): 249 | x = x + self.attention(self.ln_1(x)) 250 | x = x + 
self.mlp(self.ln_2(x)) 251 | return x 252 | 253 | 254 | class Transformer(nn.Module): 255 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 256 | super().__init__() 257 | self.width = width 258 | self.layers = layers 259 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 260 | 261 | def forward(self, x: torch.Tensor): 262 | return self.resblocks(x) 263 | 264 | 265 | class VisionTransformer(nn.Module): 266 | def __init__(self, input_resolution: Tuple[int, int], patch_size: int, stride_size: int, width: int, layers: int, heads: int, output_dim: int): 267 | super().__init__() 268 | self.input_resolution = input_resolution # (384, 128) 269 | self.num_x = (input_resolution[1] - patch_size) // stride_size + 1 270 | self.num_y = (input_resolution[0] - patch_size) // stride_size + 1 271 | num_patches = self.num_x * self.num_y 272 | 273 | self.output_dim = output_dim 274 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=stride_size, bias=False) 275 | 276 | scale = width ** -0.5 # 1/sqrt(768) 277 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 278 | self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width)) 279 | self.ln_pre = LayerNorm(width) 280 | 281 | self.transformer = Transformer(width, layers, heads) 282 | 283 | self.ln_post = LayerNorm(width) 284 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 285 | 286 | 287 | def forward(self, x: torch.Tensor): 288 | x = self.conv1(x) # shape = [*, width, grid, grid] 289 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 290 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 291 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 292 | x = x + self.positional_embedding.to(x.dtype) 293 | x = self.ln_pre(x) 294 | 295 | x = x.permute(1, 0, 2) # NLD -> LND 296 | x = self.transformer(x) 297 | x = x.permute(1, 0, 2) # LND -> NLD 298 | 299 | # x = self.ln_post(x[:, 0, :]) 300 | x = self.ln_post(x) 301 | 302 | if self.proj is not None: 303 | x = x @ self.proj 304 | 305 | return x 306 | 307 | 308 | 309 | class CLIP(nn.Module): 310 | def __init__(self, 311 | embed_dim: int, 312 | # vision 313 | image_resolution: Union[int, Tuple[int, int]], 314 | vision_layers: Union[Tuple[int, int, int, int], int], 315 | vision_width: int, 316 | vision_patch_size: int, 317 | stride_size: int, 318 | # text 319 | context_length: int, 320 | vocab_size: int, 321 | transformer_width: int, 322 | transformer_heads: int, 323 | transformer_layers: int 324 | ): 325 | super().__init__() 326 | 327 | self.context_length = context_length 328 | 329 | if isinstance(vision_layers, (tuple, list)): 330 | vision_heads = vision_width * 32 // 64 331 | self.visual = ModifiedResNet( 332 | layers=vision_layers, 333 | output_dim=embed_dim, 334 | heads=vision_heads, 335 | input_resolution=image_resolution, 336 | width=vision_width 337 | ) 338 | else: 339 | vision_heads = vision_width // 64 340 | self.visual = VisionTransformer( 341 | input_resolution=image_resolution, 342 | patch_size=vision_patch_size, 343 | stride_size=stride_size, 344 | width=vision_width, 345 | layers=vision_layers, 346 | heads=vision_heads, 347 | output_dim=embed_dim 348 | ) 349 | 350 | self.transformer = Transformer( 351 | width=transformer_width, 352 | layers=transformer_layers, 353 | 
heads=transformer_heads, 354 | attn_mask=self.build_attention_mask() 355 | ) 356 | 357 | self.vocab_size = vocab_size 358 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 359 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 360 | self.ln_final = LayerNorm(transformer_width) 361 | 362 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 363 | # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 364 | 365 | self.initialize_parameters() 366 | 367 | def initialize_parameters(self): 368 | nn.init.normal_(self.token_embedding.weight, std=0.02) 369 | nn.init.normal_(self.positional_embedding, std=0.01) 370 | 371 | if isinstance(self.visual, ModifiedResNet): 372 | if self.visual.attnpool is not None: 373 | std = self.visual.attnpool.c_proj.in_features ** -0.5 374 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 375 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 376 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 377 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 378 | 379 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 380 | for name, param in resnet_block.named_parameters(): 381 | if name.endswith("bn3.weight"): 382 | nn.init.zeros_(param) 383 | 384 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 385 | attn_std = self.transformer.width ** -0.5 386 | fc_std = (2 * self.transformer.width) ** -0.5 387 | for block in self.transformer.resblocks: 388 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 389 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 390 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 391 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 392 | 393 | if self.text_projection is not None: 394 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 395 | 396 | def build_attention_mask(self): 397 | # lazily create causal attention mask, with full attention between the vision tokens 398 | # pytorch uses additive attention mask; fill with -inf 399 | mask = torch.empty(self.context_length, self.context_length) 400 | mask.fill_(float("-inf")) 401 | mask.triu_(1) # zero out the lower diagonal 402 | return mask 403 | 404 | @property 405 | def dtype(self): 406 | return self.visual.conv1.weight.dtype 407 | 408 | def encode_image(self, image): 409 | return self.visual(image.type(self.dtype)) 410 | 411 | def encode_text(self, text): 412 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 413 | 414 | x = x + self.positional_embedding.type(self.dtype) 415 | x = x.permute(1, 0, 2) # NLD -> LND 416 | x = self.transformer(x) 417 | x = x.permute(1, 0, 2) # LND -> NLD 418 | x = self.ln_final(x).type(self.dtype) 419 | 420 | # x.shape = [batch_size, n_ctx, transformer.width] 421 | # take features from the eot embedding (eot_token is the highest number in each sequence) 422 | # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 423 | x = x @ self.text_projection 424 | 425 | return x 426 | 427 | def forward(self, image, text): 428 | image_features = self.encode_image(image) 429 | text_features = self.encode_text(text) 430 | 431 | # # normalized features 432 | # image_features = image_features / image_features.norm(dim=-1, keepdim=True) 433 | # text_features = text_features / text_features.norm(dim=-1, keepdim=True) 434 | 435 | # # cosine 
similarity as logits 436 | # logit_scale = self.logit_scale.exp() 437 | # logits_per_image = logit_scale * image_features @ text_features.t() 438 | # logits_per_text = logits_per_image.t() 439 | 440 | # # shape = [global_batch_size, global_batch_size] 441 | # return logits_per_image, logits_per_text 442 | 443 | return image_features, text_features 444 | 445 | 446 | def load_param(self, state_dict): 447 | # 将pretrained_dict里不属于model_dict的键剔除掉 448 | param_dict = {k: v for k, v in state_dict.items() if k in self.state_dict()} 449 | 450 | if 'model' in param_dict: 451 | param_dict = param_dict['model'] 452 | if 'state_dict' in param_dict: 453 | param_dict = param_dict['state_dict'] 454 | for k, v in param_dict.items(): 455 | if k == 'visual.positional_embedding' and v.shape != self.visual.positional_embedding.shape: 456 | v = resize_pos_embed(v, self.visual.positional_embedding, self.visual.num_y, self.visual.num_x) 457 | elif k == 'positional_embedding' and v.shape != self.positional_embedding.shape: 458 | v = resize_text_pos_embed(v, self.context_length) 459 | try: 460 | self.state_dict()[k].copy_(v) 461 | except: 462 | print(f'===========================ERROR occur in copy {k}, {v.shape}=========================') 463 | print('shape do not match in k :{}: param_dict{} vs self.state_dict(){}'.format(k, v.shape, self.state_dict()[k].shape)) 464 | 465 | 466 | 467 | def resize_pos_embed(posemb, posemb_new, hight, width): 468 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from 469 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 470 | posemb = posemb.unsqueeze(0) 471 | posemb_new = posemb_new.unsqueeze(0) 472 | 473 | posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:] 474 | 475 | gs_old = int(math.sqrt(len(posemb_grid))) 476 | print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width)) 477 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 478 | posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') 479 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) 480 | posemb = torch.cat([posemb_token, posemb_grid], dim=1) 481 | return posemb.squeeze(0) 482 | 483 | 484 | def resize_text_pos_embed(posemb, length): 485 | old_h, old_w = posemb.shape 486 | print(f'Resized position embedding from size:{old_h} * {old_w} to size: {length} * {old_w}') 487 | 488 | posemb = posemb.reshape(1, 1, old_h, old_w) # [1, 1, 77, 512] 489 | posemb = F.interpolate(posemb, length, mode='bilinear') 490 | 491 | return posemb.squeeze(0) 492 | 493 | 494 | def convert_weights(model: nn.Module): 495 | """Convert applicable model parameters to fp16""" 496 | 497 | def _convert_weights_to_fp16(l): 498 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 499 | l.weight.data = l.weight.data.half() 500 | if l.bias is not None: 501 | l.bias.data = l.bias.data.half() 502 | 503 | if isinstance(l, nn.MultiheadAttention): 504 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 505 | tensor = getattr(l, attr) 506 | if tensor is not None: 507 | tensor.data = tensor.data.half() 508 | 509 | for name in ["text_projection", "proj", "mcq_proj"]: 510 | if hasattr(l, name): 511 | attr = getattr(l, name) 512 | if attr is not None: 513 | attr.data = attr.data.half() 514 | 515 | model.apply(_convert_weights_to_fp16) 516 | 
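# A minimal sketch of the effect of convert_weights() above, assuming a CLIP
# instance named `model`: parameters of Conv/Linear/MultiheadAttention modules and
# the `text_projection`/`proj` matrices are cast to fp16 (matching the OpenAI
# checkpoints), while embeddings and LayerNorms are left in fp32:
#
#     convert_weights(model)
#     assert model.visual.conv1.weight.dtype == torch.float16
#     assert model.token_embedding.weight.dtype == torch.float32
#
# build_CLIP_from_openai_pretrained() below keeps its own convert_weights(model)
# call commented out and instead copies the (resized) pretrained weights in via
# model.load_param(state_dict); build_model() in model/build.py applies the fp16
# conversion to the whole CLIP2ReID module after construction.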
517 | 518 | def build_CLIP_from_openai_pretrained(name: str, image_size: Union[int, Tuple[int, int]], stride_size: int, jit: bool = False, download_root: str = None): 519 | """Load a CLIP model 520 | 521 | Parameters 522 | ---------- 523 | name : str 524 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 525 | 526 | image_size: Union[int, Tuple[int, int]] 527 | Input image size, in Re-ID task, image size commonly set to 384x128, instead of 224x224 528 | 529 | jit : bool 530 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 531 | 532 | download_root: str 533 | path to download the model files; by default, it uses "~/.cache/clip" 534 | 535 | Returns 536 | ------- 537 | model : torch.nn.Module 538 | The CLIP model 539 | """ 540 | if name in _MODELS: 541 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 542 | elif os.path.isfile(name): 543 | model_path = name 544 | else: 545 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 546 | 547 | try: 548 | # loading JIT archive 549 | model = torch.jit.load(model_path, map_location="cpu") 550 | state_dict = None 551 | except RuntimeError: 552 | # loading saved state dict 553 | if jit: 554 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 555 | jit = False 556 | state_dict = torch.load(model_path, map_location="cpu") 557 | 558 | state_dict = state_dict or model.state_dict() 559 | 560 | vit = "visual.proj" in state_dict 561 | 562 | if vit: 563 | vision_width = state_dict["visual.conv1.weight"].shape[0] 564 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 565 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 566 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 567 | image_resolution = vision_patch_size * grid_size 568 | else: 569 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 570 | vision_layers = tuple(counts) 571 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 572 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 573 | vision_patch_size = None 574 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 575 | image_resolution = output_width * 32 576 | 577 | embed_dim = state_dict["text_projection"].shape[1] 578 | context_length = state_dict["positional_embedding"].shape[0] 579 | vocab_size = state_dict["token_embedding.weight"].shape[0] 580 | transformer_width = state_dict["ln_final.weight"].shape[0] 581 | transformer_heads = transformer_width // 64 582 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 583 | 584 | model_cfg = { 585 | 'embed_dim': embed_dim, 586 | 'image_resolution': image_resolution, 587 | 'vision_layers': vision_layers, 588 | 'vision_width': vision_width, 589 | 'vision_patch_size': vision_patch_size, 590 | 'context_length': context_length, 591 | 'vocab_size': vocab_size, 592 | 'transformer_width': transformer_width, 593 | 'transformer_heads': transformer_heads, 594 | 'transformer_layers': transformer_layers 595 | } 596 | 597 | 598 | # modify image resolution to adapt Re-ID task 599 | model_cfg['image_resolution'] = image_size 600 | model_cfg['stride_size'] = 
stride_size 601 | logger.info(f"Load pretrained {name} CLIP model with model config: {model_cfg}") 602 | model = CLIP(**model_cfg) 603 | 604 | # covert model to fp16 605 | # convert_weights(model) 606 | 607 | # resize modified pos embedding 608 | model.load_param(state_dict) 609 | return model, model_cfg 610 | 611 | 612 | -------------------------------------------------------------------------------- /model/objectives.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def compute_tcmpm(image_fetures, text_fetures, pid, logit_scale, image_id=None, factor=0.3, epsilon=1e-8): 6 | """ 7 | Cross Modal Projection Matching 8 | t2i_proj = ||t|| * cos(theta) 9 | i2j_proj = ||v|| * cos(theta) 10 | """ 11 | batch_size = image_fetures.shape[0] 12 | pid = pid.reshape((batch_size, 1)) # make sure pid size is [batch_size, 1] 13 | pid_dist = pid - pid.t() 14 | labels = (pid_dist == 0).float() 15 | 16 | if image_id != None: 17 | # print("Mix PID and ImageID to create soft label.") 18 | image_id = image_id.reshape((-1, 1)) 19 | image_id_dist = image_id - image_id.t() 20 | image_id_mask = (image_id_dist == 0).float() 21 | labels = (labels - image_id_mask) * factor + image_id_mask 22 | # labels = (labels + image_id_mask) / 2 23 | 24 | image_norm = image_fetures / image_fetures.norm(dim=1, keepdim=True) 25 | text_norm = text_fetures / text_fetures.norm(dim=1, keepdim=True) 26 | 27 | image_proj_text = logit_scale * torch.matmul(image_fetures, text_norm.t()) 28 | text_proj_image = logit_scale * torch.matmul(text_fetures, image_norm.t()) 29 | 30 | 31 | # normalize the true matching distribution 32 | labels_distribute = labels / labels.sum(dim=1) # original paper use sum, and use norm will lead minus loss 33 | # labels_distribute = F.softmax((labels * logit_scale), dim=1) 34 | # labels_distribute = F.softmax(labels, dim=1) 35 | 36 | i2t_pred = F.softmax(image_proj_text, dim=1) 37 | i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_distribute + epsilon)) 38 | t2i_pred = F.softmax(text_proj_image, dim=1) 39 | t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_distribute + epsilon)) 40 | 41 | # i2t2t2i_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - F.log_softmax(text_proj_image, dim=1)) 42 | 43 | # loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1)) + torch.mean(torch.sum(i2t2t2i_loss, dim=1)) 44 | loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1)) 45 | 46 | return loss 47 | 48 | 49 | def compute_sdm(image_fetures, text_fetures, pid, logit_scale, image_id=None, factor=0.3, epsilon=1e-8): 50 | """ 51 | Similarity Distribution Matching 52 | """ 53 | batch_size = image_fetures.shape[0] 54 | pid = pid.reshape((batch_size, 1)) # make sure pid size is [batch_size, 1] 55 | pid_dist = pid - pid.t() 56 | labels = (pid_dist == 0).float() 57 | 58 | if image_id != None: 59 | # print("Mix PID and ImageID to create soft label.") 60 | image_id = image_id.reshape((-1, 1)) 61 | image_id_dist = image_id - image_id.t() 62 | image_id_mask = (image_id_dist == 0).float() 63 | labels = (labels - image_id_mask) * factor + image_id_mask 64 | # labels = (labels + image_id_mask) / 2 65 | 66 | image_norm = image_fetures / image_fetures.norm(dim=1, keepdim=True) 67 | text_norm = text_fetures / text_fetures.norm(dim=1, keepdim=True) 68 | 69 | t2i_cosine_theta = text_norm @ image_norm.t() 
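# i2t is simply the transpose of t2i; both are scaled by logit_scale below and their
# row-wise softmax is matched, via a KL term, against the pid-based label distribution
# built above (e.g. pids = [5, 5, 9] gives rows [[.5, .5, 0], [.5, .5, 0], [0, 0, 1]]
# after normalization).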
70 | i2t_cosine_theta = t2i_cosine_theta.t() 71 | 72 | # text_proj_image = logit_scale * (text_norm_value * t2i_cosine_theta) 73 | # image_proj_text = logit_scale * (image_norm_value * i2t_cosine_theta) 74 | 75 | # mean_norm_value = (text_norm_value + image_norm_value) / 2 76 | # text_proj_image = logit_scale * (mean_norm_value * t2i_cosine_theta) 77 | # image_proj_text = logit_scale * (mean_norm_value * i2t_cosine_theta) 78 | 79 | # k_value = 8 80 | # text_proj_image = logit_scale * (k_value * t2i_cosine_theta) 81 | # image_proj_text = logit_scale * (k_value * i2t_cosine_theta) 82 | 83 | text_proj_image = logit_scale * t2i_cosine_theta 84 | image_proj_text = logit_scale * i2t_cosine_theta 85 | 86 | # normalize the true matching distribution 87 | labels_distribute = labels / labels.sum(dim=1) # original paper use sum, and use norm will lead minus loss 88 | # labels_distribute = F.softmax((labels * logit_scale), dim=1) 89 | 90 | i2t_pred = F.softmax(image_proj_text, dim=1) 91 | i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_distribute + epsilon)) 92 | t2i_pred = F.softmax(text_proj_image, dim=1) 93 | t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_distribute + epsilon)) 94 | 95 | loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1)) 96 | 97 | return loss 98 | 99 | def compute_mcm_or_mlm(scores, labels): 100 | ce = nn.CrossEntropyLoss(ignore_index=0) 101 | return ce(scores, labels) 102 | 103 | 104 | def compute_itc(image_features, text_features, logit_scale): 105 | """ 106 | image-text contrastive (ITC) loss, InfoNCE 107 | """ 108 | batch_size = image_features.shape[0] 109 | labels = torch.arange(start=0, end=batch_size, dtype=torch.int64) 110 | labels = labels.to(image_features.device) 111 | 112 | 113 | # normalized features 114 | image_norm = image_features / image_features.norm(dim=-1, keepdim=True) 115 | text_norm = text_features / text_features.norm(dim=-1, keepdim=True) 116 | 117 | # cosine similarity as logits 118 | logits_per_image = logit_scale * image_norm @ text_norm.t() 119 | logits_per_text = logits_per_image.t() 120 | 121 | loss_i = F.cross_entropy(logits_per_image, labels) 122 | loss_t =F.cross_entropy(logits_per_text, labels) 123 | loss = (loss_i + loss_t)/2 124 | 125 | return loss 126 | 127 | class CMFL(nn.Module): 128 | """ 129 | Cross Modal Focal Loss 130 | """ 131 | 132 | def __init__(self, alpha=1, gamma=2, binary=False, multiplier=2, sg=False): 133 | super(CMFL, self).__init__() 134 | self.alpha = alpha 135 | self.gamma = gamma 136 | self.binary = binary 137 | self.multiplier = multiplier 138 | self.sg = sg 139 | 140 | def forward(self, inputs_a, inputs_b, targets): 141 | 142 | # bce_loss_a = F.binary_cross_entropy(inputs_a, targets, reduce=False) 143 | # bce_loss_b = F.binary_cross_entropy(inputs_b, targets, reduce=False) 144 | 145 | bce_loss_a = F.cross_entropy(inputs_a, targets, reduce=False) 146 | bce_loss_b = F.cross_entropy(inputs_b, targets, reduce=False) 147 | 148 | pt_a = torch.exp(-bce_loss_a) 149 | pt_b = torch.exp(-bce_loss_b) 150 | 151 | eps = 0.000000001 152 | 153 | if self.sg: 154 | d_pt_a = pt_a.detach() 155 | d_pt_b = pt_b.detach() 156 | wt_a = ((d_pt_b + eps) * (self.multiplier * pt_a * d_pt_b)) / (pt_a + d_pt_b + eps) 157 | wt_b = ((d_pt_a + eps) * (self.multiplier * d_pt_a * pt_b)) / (d_pt_a + pt_b + eps) 158 | else: 159 | wt_a = ((pt_b + eps) * (self.multiplier * pt_a * pt_b)) / (pt_a + pt_b + eps) 160 | wt_b = ((pt_a + eps) * (self.multiplier * pt_a * 
pt_b)) / (pt_a + pt_b + eps) 161 | 162 | if self.binary: 163 | wt_a = wt_a * (1 - targets) 164 | wt_b = wt_b * (1 - targets) 165 | 166 | f_loss_a = self.alpha * (1 - wt_a) ** self.gamma * bce_loss_a 167 | f_loss_b = self.alpha * (1 - wt_b) ** self.gamma * bce_loss_b 168 | 169 | loss = 0.5 * torch.mean(f_loss_a) + 0.5 * torch.mean(f_loss_b) 170 | 171 | return loss 172 | 173 | def focal_loss_two(inputs_a, inputs_b, alpha, gamma): 174 | 175 | pt_a = torch.exp(-inputs_a) 176 | pt_b = torch.exp(-inputs_b) 177 | 178 | eps = 0.000000001 179 | 180 | 181 | wt_a = ((pt_b + eps) * (2 * pt_a * pt_b)) / (pt_a + pt_b + eps) 182 | wt_b = ((pt_a + eps) * (2 * pt_a * pt_b)) / (pt_a + pt_b + eps) 183 | 184 | 185 | f_loss_a = alpha * (1 + wt_a) ** gamma * inputs_a 186 | f_loss_b = alpha * (1 + wt_b) ** gamma * inputs_b 187 | 188 | loss = torch.mean(f_loss_a) + torch.mean(f_loss_b) 189 | 190 | return loss 191 | 192 | 193 | def compute_itc_focal3(image_features, text_features, simage_features, fusion_features, logit_scale, alpha, gamma, klp): 194 | """ 195 | image-text contrastive (ITC) loss, InfoNCE 196 | """ 197 | batch_size = image_features.shape[0] 198 | labels = torch.arange(start=0, end=batch_size, dtype=torch.int64) 199 | labels = labels.to(image_features.device) 200 | 201 | 202 | # normalized features 203 | image_norm = image_features / image_features.norm(dim=-1, keepdim=True) 204 | text_norm = text_features / text_features.norm(dim=-1, keepdim=True) 205 | simage_norm = simage_features / simage_features.norm(dim=-1, keepdim=True) 206 | fusion_norm = fusion_features / fusion_features.norm(dim=-1, keepdim=True) 207 | 208 | # cosine similarity as logits 209 | logits_per_image0 = logit_scale * image_norm @ text_norm.t() 210 | logits_per_text0 = logits_per_image0.t() 211 | 212 | loss_i = F.cross_entropy(logits_per_image0, labels, reduce=False) 213 | loss_t =F.cross_entropy(logits_per_text0, labels, reduce=False) 214 | loss_it = (loss_i + loss_t)/2 215 | 216 | # cosine similarity as logits 217 | logits_per_image1 = logit_scale * image_norm @ simage_norm.t() 218 | logits_per_text1 = logits_per_image1.t() 219 | 220 | loss_i = F.cross_entropy(logits_per_image1, labels, reduce=False) 221 | loss_t =F.cross_entropy(logits_per_text1, labels, reduce=False) 222 | loss_is = (loss_i + loss_t)/2 223 | 224 | # cosine similarity as logits 225 | logits_per_image = logit_scale * image_norm @ fusion_norm.t() 226 | logits_per_text = logits_per_image.t() 227 | 228 | loss_i = F.cross_entropy(logits_per_image, labels) 229 | loss_t =F.cross_entropy(logits_per_text, labels) 230 | loss_if = (loss_i + loss_t)/2 231 | 232 | # focal loss 233 | # kl = F.kl_div(logits_per_text1.softmax(dim=-1).log(), logits_per_text0.detach().softmax(dim=-1), reduction='sum') + F.kl_div(logits_per_text0.softmax(dim=-1).log(), logits_per_text1.detach().softmax(dim=-1), reduction='sum') 234 | 235 | loss = focal_loss_two(loss_it, loss_is, alpha, gamma) + loss_if + klp*(CoRefineLoss(logits_per_text1, logits_per_text0.detach())) 236 | 237 | return loss 238 | 239 | 240 | 241 | def CoRefineLoss(output1, output2): 242 | 243 | # Target is ignored at training time. Loss is defined as KL divergence 244 | # between the model output and the refined labels. 
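# In effect the detached output2 serves as a soft target for output1: the bmm below computes -sum_c softmax(output2)_c * log_softmax(output1)_c per sample, i.e. a soft-label cross entropy, which equals KL(softmax(output2) || softmax(output1)) plus the entropy of softmax(output2) (a constant w.r.t. output1).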
245 | if output2.requires_grad: 246 | raise ValueError("Refined labels should not require gradients.") 247 | 248 | output1_log_prob = F.log_softmax(output1, dim=1) 249 | output2_prob = F.softmax(output2, dim=1) 250 | 251 | _, pred_label = output2_prob.max(1) 252 | 253 | # Loss is normal cross entropy loss 254 | # base_loss = F.cross_entropy(output1, pred_label) 255 | 256 | # Loss is -dot(model_output_log_prob, refined_labels). Prepare tensors 257 | # for batch matrix multiplicatio 258 | 259 | model_output1_log_prob = output1_log_prob.unsqueeze(2) 260 | model_output2_prob = output2_prob.unsqueeze(1) 261 | 262 | # Compute the loss, and average/sum for the batch. 263 | kl_loss = -torch.bmm(model_output2_prob, model_output1_log_prob) 264 | 265 | return kl_loss.mean() 266 | 267 | def compute_id(classifier, image_embeddings, text_embeddings, labels, verbose=False): 268 | image_logits = classifier(image_embeddings) 269 | text_logits = classifier(text_embeddings) 270 | 271 | criterion = nn.CrossEntropyLoss(reduction="mean") 272 | loss = criterion(image_logits, labels) + criterion(text_logits, labels) 273 | 274 | # classification accuracy for observation 275 | if verbose: 276 | image_pred = torch.argmax(image_logits, dim=1) 277 | text_pred = torch.argmax(text_logits, dim=1) 278 | 279 | image_precision = torch.mean((image_pred == labels).float()) 280 | text_precision = torch.mean((text_pred == labels).float()) 281 | 282 | return loss, image_precision, text_precision 283 | 284 | return loss 285 | 286 | 287 | def compute_cmpm(image_embeddings, text_embeddings, labels, epsilon=1e-8): 288 | """ 289 | Cross-Modal Projection Matching Loss(CMPM) 290 | :param image_embeddings: Tensor with dtype torch.float32 291 | :param text_embeddings: Tensor with dtype torch.float32 292 | :param labels: Tensor with dtype torch.int32 293 | :return: 294 | i2t_loss: cmpm loss for image projected to text 295 | t2i_loss: cmpm loss for text projected to image 296 | pos_avg_sim: average cosine-similarity for positive pairs 297 | neg_avg_sim: averate cosine-similarity for negative pairs 298 | """ 299 | 300 | batch_size = image_embeddings.shape[0] 301 | labels_reshape = torch.reshape(labels, (batch_size, 1)) 302 | labels_dist = labels_reshape - labels_reshape.t() 303 | labels_mask = (labels_dist == 0).float() 304 | 305 | image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True) 306 | text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True) 307 | image_proj_text = torch.matmul(image_embeddings, text_norm.t()) 308 | text_proj_image = torch.matmul(text_embeddings, image_norm.t()) 309 | 310 | # normalize the true matching distribution 311 | labels_mask_norm = labels_mask / labels_mask.norm(dim=1) 312 | 313 | i2t_pred = F.softmax(image_proj_text, dim=1) 314 | i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + epsilon)) 315 | t2i_pred = F.softmax(text_proj_image, dim=1) 316 | t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + epsilon)) 317 | 318 | cmpm_loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1)) 319 | 320 | return cmpm_loss 321 | 322 | 323 | def compute_mcq(a, b, temperature=0.05, eps=1e-8): 324 | a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] 325 | a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) 326 | b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) 327 | x = torch.mm(a_norm, b_norm.transpose(0, 1)) 328 | 329 | i_logsm = F.log_softmax(x/temperature, 
dim=1) 330 | j_logsm = F.log_softmax(x.t()/temperature, dim=1) 331 | 332 | # sum over positives 333 | idiag = torch.diag(i_logsm) 334 | loss_i = idiag.sum() / len(idiag) 335 | 336 | jdiag = torch.diag(j_logsm) 337 | loss_j = jdiag.sum() / len(jdiag) 338 | 339 | return - loss_i - loss_j 340 | 341 | 342 | def CrossModalSupConLoss(image_fetures, text_fetures, labels, temperature): 343 | """ 344 | Args: 345 | features: hidden vector of shape [bsz, n_views, ...]. 346 | labels: ground truth of shape [bsz]. 347 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 348 | has the same class as sample i. Can be asymmetric. 349 | Returns: 350 | A loss scalar. 351 | """ 352 | device = (torch.device('cuda') if image_fetures.is_cuda else torch.device('cpu')) 353 | 354 | 355 | batch_size = image_fetures.shape[0] 356 | 357 | labels = labels.contiguous().view(-1, 1) 358 | if labels.shape[0] != batch_size: 359 | raise ValueError('Num of labels does not match num of features') 360 | mask = torch.eq(labels, labels.T).float().to(device) 361 | 362 | 363 | contrast_count = 2 364 | contrast_feature = torch.cat([image_fetures, text_fetures], dim=0) 365 | 366 | anchor_feature = contrast_feature 367 | anchor_count = contrast_count 368 | 369 | # compute logits 370 | anchor_dot_contrast = torch.matmul(anchor_feature, contrast_feature.T) * temperature 371 | # for numerical stability 372 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 373 | logits = anchor_dot_contrast - logits_max.detach() 374 | 375 | # tile mask 376 | mask = mask.repeat(anchor_count, contrast_count) 377 | # mask-out self-contrast cases 378 | # logits_mask = torch.scatter( 379 | # torch.ones_like(mask), 380 | # 1, 381 | # torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 382 | # 0 383 | # ) 384 | # mask = mask * logits_mask 385 | 386 | # compute log_prob 387 | # exp_logits = torch.exp(logits) * logits_mask 388 | exp_logits = torch.exp(logits) 389 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 390 | 391 | # compute mean of log-likelihood over positive 392 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 393 | 394 | # loss 395 | loss = -mean_log_prob_pos 396 | loss = loss.view(anchor_count, batch_size).mean() 397 | 398 | return loss 399 | 400 | 401 | class CrossEntropyLabelSmooth(nn.Module): 402 | """Cross entropy loss with label smoothing regularizer. 403 | 404 | Reference: 405 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 406 | Equation: y = (1 - epsilon) * y + epsilon / K. 407 | 408 | Args: 409 | num_classes (int): number of classes. 410 | epsilon (float): weight. 
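Example: following the equation above, with num_classes=10 and epsilon=0.1 the smoothed target is 0.91 for the ground-truth class and 0.01 for every other class (each row still sums to 1).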
411 | """ 412 | 413 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 414 | super().__init__() 415 | self.num_classes = num_classes 416 | self.epsilon = epsilon 417 | self.use_gpu = use_gpu 418 | self.logsoftmax = nn.LogSoftmax(dim=1) 419 | 420 | def forward(self, inputs, targets): 421 | """ 422 | Args: 423 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 424 | targets: ground truth labels with shape (num_classes) 425 | """ 426 | log_probs = self.logsoftmax(inputs) 427 | targets = torch.zeros(log_probs.size()).scatter_( 428 | 1, targets.unsqueeze(1).data.cpu(), 1 429 | ) 430 | if self.use_gpu: 431 | targets = targets.cuda() 432 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 433 | loss = (-targets * log_probs).mean(0).sum() 434 | return loss 435 | 436 | 437 | 438 | # class SupConLoss(nn.Module): 439 | # """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 440 | # It also supports the unsupervised contrastive loss in SimCLR""" 441 | # def __init__(self, temperature=0.07, contrast_mode='all', 442 | # base_temperature=0.07): 443 | # super(SupConLoss, self).__init__() 444 | # self.temperature = temperature 445 | # self.contrast_mode = contrast_mode 446 | # self.base_temperature = base_temperature 447 | 448 | def SupConLoss(features, labels=None, mask=None, temperature=2.0, contrast_mode='all', 449 | base_temperature=0.07): 450 | """Compute loss for model. If both `labels` and `mask` are None, 451 | it degenerates to SimCLR unsupervised loss: 452 | https://arxiv.org/pdf/2002.05709.pdf 453 | Args: 454 | features: hidden vector of shape [bsz, n_views, ...]. 455 | labels: ground truth of shape [bsz]. 456 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 457 | has the same class as sample i. Can be asymmetric. 458 | Returns: 459 | A loss scalar. 
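Example (illustrative only): for a two-view setup such as matched image/text pairs, `features` could be built as torch.stack([image_feats, text_feats], dim=1), giving shape [bsz, 2, feat_dim], with `labels` holding the person IDs of shape [bsz].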
460 | """ 461 | device = (torch.device('cuda') 462 | if features.is_cuda 463 | else torch.device('cpu')) 464 | 465 | if len(features.shape) < 3: 466 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 467 | 'at least 3 dimensions are required') 468 | if len(features.shape) > 3: 469 | features = features.view(features.shape[0], features.shape[1], -1) 470 | 471 | batch_size = features.shape[0] 472 | if labels is not None and mask is not None: 473 | raise ValueError('Cannot define both `labels` and `mask`') 474 | elif labels is None and mask is None: 475 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 476 | elif labels is not None: 477 | labels = labels.contiguous().view(-1, 1) 478 | if labels.shape[0] != batch_size: 479 | raise ValueError('Num of labels does not match num of features') 480 | mask = torch.eq(labels, labels.T).float().to(device) 481 | else: 482 | mask = mask.float().to(device) 483 | 484 | contrast_count = features.shape[1] 485 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 486 | if contrast_mode == 'one': 487 | anchor_feature = features[:, 0] 488 | anchor_count = 1 489 | elif contrast_mode == 'all': 490 | anchor_feature = contrast_feature 491 | anchor_count = contrast_count 492 | else: 493 | raise ValueError('Unknown mode: {}'.format(contrast_mode)) 494 | 495 | # compute logits 496 | anchor_dot_contrast = torch.div( 497 | torch.matmul(anchor_feature, contrast_feature.T), 498 | temperature) 499 | # for numerical stability 500 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 501 | logits = anchor_dot_contrast - logits_max.detach() 502 | 503 | # tile mask 504 | mask = mask.repeat(anchor_count, contrast_count) 505 | # mask-out self-contrast cases 506 | logits_mask = torch.scatter( 507 | torch.ones_like(mask), 508 | 1, 509 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 510 | 0 511 | ) 512 | mask = mask * logits_mask 513 | 514 | # compute log_prob 515 | exp_logits = torch.exp(logits) * logits_mask 516 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6) 517 | 518 | # compute mean of log-likelihood over positive 519 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 520 | 521 | # loss 522 | loss = - (temperature / base_temperature) * mean_log_prob_pos 523 | loss = loss.view(anchor_count, batch_size).mean() 524 | 525 | return loss 526 | 527 | 528 | -------------------------------------------------------------------------------- /processor/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import do_train, do_inference -------------------------------------------------------------------------------- /processor/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/processor/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /processor/__pycache__/processor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/processor/__pycache__/processor.cpython-38.pyc -------------------------------------------------------------------------------- /processor/processor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 
import time 3 | import torch 4 | from utils.meter import AverageMeter 5 | from utils.metrics import Evaluator 6 | from utils.comm import get_rank, synchronize 7 | from torch.utils.tensorboard import SummaryWriter 8 | from prettytable import PrettyTable 9 | 10 | 11 | def do_train(start_epoch, args, model, train_loader, evaluator, optimizer, 12 | scheduler, checkpointer): 13 | 14 | log_period = args.log_period 15 | eval_period = args.eval_period 16 | device = "cuda" 17 | num_epoch = args.num_epoch 18 | arguments = {} 19 | arguments["num_epoch"] = num_epoch 20 | arguments["iteration"] = 0 21 | 22 | arguments2 = {} 23 | arguments2["num_epoch"] = num_epoch 24 | arguments2["iteration"] = 0 25 | 26 | arguments3 = {} 27 | arguments3["num_epoch"] = num_epoch 28 | arguments3["iteration"] = 0 29 | 30 | arguments4 = {} 31 | arguments4["num_epoch"] = num_epoch 32 | arguments4["iteration"] = 0 33 | 34 | logger = logging.getLogger("CLIP2ReID.train") 35 | logger.info('start training') 36 | 37 | loss_meter = AverageMeter() 38 | mcm_loss_meter = AverageMeter() 39 | mlm_loss_meter = AverageMeter() 40 | mcq_loss_meter = AverageMeter() 41 | acc_meter = AverageMeter() 42 | 43 | tb_writer = SummaryWriter(log_dir=args.output_dir) 44 | 45 | best_ttop1 = 0.0 46 | best_stop1 = 0.0 47 | best_itop1 = 0.0 48 | best_ftop1 = 0.0 49 | 50 | # train 51 | for epoch in range(start_epoch, num_epoch + 1): 52 | start_time = time.time() 53 | loss_meter.reset() 54 | acc_meter.reset() 55 | mcm_loss_meter.reset() 56 | mlm_loss_meter.reset() 57 | mcq_loss_meter.reset() 58 | model.train() 59 | 60 | for n_iter, batch in enumerate(train_loader): 61 | batch = {k: v.to(device) for k, v in batch.items()} 62 | 63 | ret = model(batch) 64 | total_loss = sum([v for k, v in ret.items() if "loss" in k]) 65 | 66 | loss_meter.update(total_loss.item(), batch['images'].shape[0]) 67 | acc_meter.update(ret.get('acc', 0), 1) 68 | 69 | mcm_loss_meter.update(ret.get('mcm_loss', 0), batch['images'].shape[0]) 70 | mlm_loss_meter.update(ret.get('mlm_loss', 0), batch['images'].shape[0]) 71 | mcq_loss_meter.update(ret.get('mcq_loss', 0), batch['images'].shape[0]) 72 | 73 | optimizer.zero_grad() 74 | total_loss.backward() 75 | optimizer.step() 76 | synchronize() 77 | 78 | if (n_iter + 1) % log_period == 0: 79 | logger.info( 80 | f"Epoch[{epoch}] Iteration[{n_iter + 1}/{len(train_loader)}] Loss: {loss_meter.avg:.4f}, mcm_loss: {mcm_loss_meter.avg:.4f}, mcq_loss: {mcq_loss_meter.avg:.4f}, mlm_loss: {mlm_loss_meter.avg:.4f}, Acc: {acc_meter.avg:.3f}, Base Lr: {scheduler.get_lr()[0]:.2e}" 81 | ) 82 | 83 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], epoch) 84 | tb_writer.add_scalar('temperature', ret['temperature'], epoch) 85 | tb_writer.add_scalar('loss', loss_meter.avg, epoch) 86 | tb_writer.add_scalar('mcm_loss', mcm_loss_meter.avg, epoch) 87 | tb_writer.add_scalar('mlm_loss', mlm_loss_meter.avg, epoch) 88 | tb_writer.add_scalar('mcq_loss', mcq_loss_meter.avg, epoch) 89 | tb_writer.add_scalar('acc', acc_meter.avg, epoch) 90 | 91 | scheduler.step() 92 | if get_rank() == 0: 93 | end_time = time.time() 94 | time_per_batch = (end_time - start_time) / (n_iter + 1) 95 | logger.info( 96 | "Epoch {} done. 
Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]" 97 | .format(epoch, time_per_batch, 98 | train_loader.batch_size / time_per_batch)) 99 | if epoch % eval_period == 0: 100 | if get_rank() == 0: 101 | logger.info("Validation Results - Epoch: {}".format(epoch)) 102 | if args.distributed: 103 | ttop1, stop1, itop1 = evaluator.eval(model.module.eval()) 104 | else: ttop1, stop1, itop1 = evaluator.eval(model.eval()) 105 | 106 | ftop1 = (ttop1 + stop1 + itop1)/3.0 107 | 108 | torch.cuda.empty_cache() 109 | if best_ttop1 < ttop1: 110 | best_ttop1 = ttop1 111 | arguments["epoch"] = epoch 112 | checkpointer.save("text_best", **arguments) 113 | 114 | if best_stop1 < stop1: 115 | best_stop1 = stop1 116 | arguments2["epoch"] = epoch 117 | checkpointer.save("sketch_best", **arguments2) 118 | 119 | if best_itop1 < itop1: 120 | best_itop1 = itop1 121 | arguments3["epoch"] = epoch 122 | checkpointer.save("fusion_best", **arguments3) 123 | 124 | if best_ftop1 < ftop1: 125 | best_ftop1 = ftop1 126 | arguments4["epoch"] = epoch 127 | checkpointer.save("average_best", **arguments4) 128 | 129 | logger.info(f"text best R1: {best_ttop1} at epoch {arguments['epoch']}") 130 | logger.info(f"sketch best R1: {best_stop1} at epoch {arguments2['epoch']}") 131 | logger.info(f"fusion best R1: {best_itop1} at epoch {arguments3['epoch']}") 132 | logger.info(f"average best R1: {best_ftop1} at epoch {arguments4['epoch']}") 133 | 134 | 135 | def do_inference(args, model, test_img_loader, test_txt_loader, test_sketch_loader): 136 | 137 | logger = logging.getLogger("CLIP2ReID.test") 138 | logger.info("Enter inferencing") 139 | 140 | evaluator = Evaluator(args, test_img_loader, test_txt_loader, test_sketch_loader) 141 | ttop1, stop1, itop1 = evaluator.eval(model.eval()) 142 | # top1 = evaluator.eval_by_proj(model.eval()) 143 | 144 | # table = PrettyTable(["task", "R1", "R5", "R10", "mAP"]) 145 | # table.float_format = '.4' 146 | # table.add_row(['t2i', cmc[0], cmc[4], cmc[9], mAP]) 147 | # logger.info("Validation Results: ") 148 | # logger.info('\n' + str(table)) 149 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_optimizer, build_lr_scheduler 2 | 3 | __all__ = ["build_optimizer", "build_lr_scheduler"] -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/solver/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/solver/__pycache__/build.cpython-38.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/solver/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /solver/build.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .lr_scheduler import LRSchedulerWithWarmup 4 | 5 | 6 | def build_optimizer(args, model): 7 | params = [] 8 | 9 | print(f'Using {args.lr_factor} times learning rate for random init module ') 10 | 11 | for key, value in model.named_parameters(): 12 | if not value.requires_grad: 13 | continue 14 | lr = args.lr 15 | weight_decay = args.weight_decay 16 | 17 | if "cross" in key: 18 | # use large learning rate for random initialized cross modal module 19 | lr = args.lr * args.lr_factor # default 5.0 20 | if "bias" in key: 21 | lr = args.lr * args.bias_lr_factor 22 | weight_decay = args.weight_decay_bias 23 | if "classifier" in key or "mcm_head" in key or "mlm_head" in key: 24 | lr = args.lr * args.lr_factor 25 | 26 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 27 | 28 | if args.optimizer == "SGD": 29 | optimizer = torch.optim.SGD( 30 | params, lr=args.lr, momentum=args.momentum 31 | ) 32 | elif args.optimizer == "Adam": 33 | optimizer = torch.optim.Adam( 34 | params, 35 | lr=args.lr, 36 | betas=(args.alpha, args.beta), 37 | eps=1e-3, 38 | ) 39 | elif args.optimizer == "AdamW": 40 | optimizer = torch.optim.AdamW( 41 | params, 42 | lr=args.lr, 43 | betas=(args.alpha, args.beta), 44 | eps=1e-8, 45 | ) 46 | else: 47 | NotImplementedError 48 | 49 | return optimizer 50 | 51 | 52 | def build_lr_scheduler(args, optimizer): 53 | return LRSchedulerWithWarmup( 54 | optimizer, 55 | milestones=args.milestones, 56 | gamma=args.gamma, 57 | warmup_factor=args.warmup_factor, 58 | warmup_epochs=args.warmup_epochs, 59 | warmup_method=args.warmup_method, 60 | total_epochs=args.num_epoch, 61 | mode=args.lrscheduler, 62 | target_lr=args.target_lr, 63 | power=args.power, 64 | ) 65 | -------------------------------------------------------------------------------- /solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_right 2 | from math import cos, pi 3 | 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | 7 | class LRSchedulerWithWarmup(_LRScheduler): 8 | def __init__( 9 | self, 10 | optimizer, 11 | milestones, 12 | gamma=0.1, 13 | mode="step", 14 | warmup_factor=1.0 / 3, 15 | warmup_epochs=10, 16 | warmup_method="linear", 17 | total_epochs=100, 18 | target_lr=0, 19 | power=0.9, 20 | last_epoch=-1, 21 | ): 22 | if not list(milestones) == sorted(milestones): 23 | raise ValueError( 24 | "Milestones should be a list of" 25 | " increasing integers. 
Got {}".format(milestones), 26 | ) 27 | if mode not in ("step", "exp", "poly", "cosine", "linear"): 28 | raise ValueError( 29 | "Only 'step', 'exp', 'poly' or 'cosine' learning rate scheduler accepted" 30 | "got {}".format(mode) 31 | ) 32 | if warmup_method not in ("constant", "linear"): 33 | raise ValueError( 34 | "Only 'constant' or 'linear' warmup_method accepted" 35 | "got {}".format(warmup_method) 36 | ) 37 | self.milestones = milestones 38 | self.mode = mode 39 | self.gamma = gamma 40 | self.warmup_factor = warmup_factor 41 | self.warmup_epochs = warmup_epochs 42 | self.warmup_method = warmup_method 43 | self.total_epochs = total_epochs 44 | self.target_lr = target_lr 45 | self.power = power 46 | super().__init__(optimizer, last_epoch) 47 | 48 | def get_lr(self): 49 | 50 | if self.last_epoch < self.warmup_epochs: 51 | if self.warmup_method == "constant": 52 | warmup_factor = self.warmup_factor 53 | elif self.warmup_method == "linear": 54 | alpha = self.last_epoch / self.warmup_epochs 55 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 56 | return [base_lr * warmup_factor for base_lr in self.base_lrs] 57 | 58 | if self.mode == "step": 59 | return [ 60 | base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch) 61 | for base_lr in self.base_lrs 62 | ] 63 | 64 | epoch_ratio = (self.last_epoch - self.warmup_epochs) / ( 65 | self.total_epochs - self.warmup_epochs 66 | ) 67 | 68 | if self.mode == "exp": 69 | factor = epoch_ratio 70 | return [base_lr * self.power ** factor for base_lr in self.base_lrs] 71 | if self.mode == "linear": 72 | factor = 1 - epoch_ratio 73 | return [base_lr * factor for base_lr in self.base_lrs] 74 | 75 | if self.mode == "poly": 76 | factor = 1 - epoch_ratio 77 | return [ 78 | self.target_lr + (base_lr - self.target_lr) * self.power ** factor 79 | for base_lr in self.base_lrs 80 | ] 81 | if self.mode == "cosine": 82 | factor = 0.5 * (1 + cos(pi * epoch_ratio)) 83 | return [ 84 | self.target_lr + (base_lr - self.target_lr) * factor 85 | for base_lr in self.base_lrs 86 | ] 87 | raise NotImplementedError 88 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from prettytable import PrettyTable 2 | import os 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 4 | import torch 5 | import torch.nn.parallel 6 | import numpy as np 7 | import time 8 | import os.path as op 9 | 10 | # import clip 11 | from datasets import build_dataloader 12 | from processor.processor import do_inference 13 | from utils.checkpoint import Checkpointer 14 | from utils.logger import setup_logger 15 | from model import build_model 16 | from utils.metrics import Evaluator 17 | import argparse 18 | from utils.iotools import load_train_configs 19 | 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser(description="TranTextReID Text") 23 | parser.add_argument("--config_file", default='/data1/ccq/multimodality-ICFG/ICFG-PEDES/20221008_110136_sketch2_text_itcloss/configs.yaml') 24 | args = parser.parse_args() 25 | args = load_train_configs(args.config_file) 26 | 27 | 28 | args.training = False 29 | logger = setup_logger('CLIP2ReID', save_dir=args.output_dir, if_train=args.training) 30 | logger.info(args) 31 | device = "cuda" 32 | 33 | test_img_loader, test_txt_loader, test_sketch_loader = build_dataloader(args) 34 | model = build_model(args) 35 | checkpointer = Checkpointer(model) 36 | checkpointer.load(f=op.join(args.output_dir, 
'text_best.pth')) 37 | model.to(device) 38 | 39 | do_inference(args, model, test_img_loader, test_txt_loader, test_sketch_loader) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as op 3 | import torch 4 | import numpy as np 5 | import random 6 | import time 7 | 8 | from datasets import build_dataloader 9 | from processor.processor import do_train 10 | from utils.checkpoint import Checkpointer 11 | from utils.iotools import save_train_configs 12 | from utils.logger import setup_logger 13 | from solver import build_optimizer, build_lr_scheduler 14 | from model import build_model 15 | from utils.metrics import Evaluator 16 | from utils.options import get_args 17 | from utils.comm import get_rank, synchronize 18 | 19 | 20 | def set_seed(seed=0): 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | np.random.seed(seed) 25 | random.seed(seed) 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = True 28 | 29 | 30 | if __name__ == '__main__': 31 | args = get_args() 32 | set_seed(1+get_rank()) 33 | name = args.name 34 | 35 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 36 | args.distributed = num_gpus > 1 37 | 38 | if args.distributed: 39 | torch.cuda.set_device(args.local_rank) 40 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 41 | synchronize() 42 | 43 | device = "cuda" 44 | cur_time = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 45 | args.output_dir = op.join(args.output_dir, args.dataset_name, f'{cur_time}_{name}') 46 | # args.log_dir = op.join(args.output_dir, args.log_dir) 47 | logger = setup_logger('CLIP2ReID', save_dir=args.output_dir, if_train=args.training, distributed_rank=get_rank()) 48 | logger.info("Using {} GPUs".format(num_gpus)) 49 | logger.info(str(args).replace(',', '\n')) 50 | logger.info("Training only sketch {}".format(args.only_sketch)) 51 | logger.info("Using only text {}".format(args.only_text)) 52 | logger.info("Using only fusion {}".format(args.only_fusion_loss)) 53 | logger.info("Using {} fusion method".format(args.fusion_way)) 54 | 55 | save_train_configs(args.output_dir, args) 56 | 57 | # get image-text pair datasets dataloader 58 | train_loader, val_img_loader, val_txt_loader, val_sketch_loader, num_classes = build_dataloader(args) 59 | model = build_model(args, num_classes) 60 | model.to(device) 61 | 62 | if args.distributed: 63 | model = torch.nn.parallel.DistributedDataParallel( 64 | model, 65 | device_ids=[args.local_rank], 66 | output_device=args.local_rank, 67 | # this should be removed if we update BatchNorm stats 68 | broadcast_buffers=False, 69 | ) 70 | optimizer = build_optimizer(args, model) 71 | scheduler = build_lr_scheduler(args, optimizer) 72 | 73 | is_master = get_rank() == 0 74 | checkpointer = Checkpointer(model, optimizer, scheduler, args.output_dir, is_master) 75 | evaluator = Evaluator(args, val_img_loader, val_txt_loader, val_sketch_loader) 76 | 77 | start_epoch = 1 78 | 79 | if args.resume: 80 | checkpoint = checkpointer.resume(args.resume_ckpt_file) 81 | start_epoch = checkpoint['epoch'] 82 | 83 | do_train(start_epoch, args, model, train_loader, evaluator, optimizer, scheduler, checkpointer) -------------------------------------------------------------------------------- /utils/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/checkpoint.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/checkpoint.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/comm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/comm.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/iotools.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/iotools.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/meter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/meter.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/metrics.cpython-38.pyc 
-------------------------------------------------------------------------------- /utils/__pycache__/options.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/options.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccq195/UNIReID/756571cc6cccbb6a787af2daeed446836ba87ef6/utils/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import os 4 | from collections import OrderedDict 5 | 6 | import torch 7 | 8 | 9 | class Checkpointer: 10 | def __init__( 11 | self, 12 | model, 13 | optimizer=None, 14 | scheduler=None, 15 | save_dir="", 16 | save_to_disk=None, 17 | logger=None, 18 | ): 19 | self.model = model 20 | self.optimizer = optimizer 21 | self.scheduler = scheduler 22 | self.save_dir = save_dir 23 | self.save_to_disk = save_to_disk 24 | if logger is None: 25 | logger = logging.getLogger(__name__) 26 | self.logger = logger 27 | 28 | def save(self, name, **kwargs): 29 | if not self.save_dir: 30 | return 31 | 32 | if not self.save_to_disk: 33 | return 34 | 35 | data = {} 36 | data["model"] = self.model.state_dict() 37 | if self.optimizer is not None: 38 | data["optimizer"] = self.optimizer.state_dict() 39 | if self.scheduler is not None: 40 | data["scheduler"] = self.scheduler.state_dict() 41 | data.update(kwargs) 42 | 43 | save_file = os.path.join(self.save_dir, "{}.pth".format(name)) 44 | self.logger.info("Saving checkpoint to {}".format(save_file)) 45 | torch.save(data, save_file) 46 | 47 | def load(self, f=None): 48 | if not f: 49 | # no checkpoint could be found 50 | self.logger.info("No checkpoint found.") 51 | return {} 52 | self.logger.info("Loading checkpoint from {}".format(f)) 53 | checkpoint = self._load_file(f) 54 | self._load_model(checkpoint) 55 | 56 | def resume(self, f=None): 57 | if not f: 58 | # no checkpoint could be found 59 | self.logger.info("No checkpoint found.") 60 | raise IOError(f"No Checkpoint file found on {f}") 61 | self.logger.info("Loading checkpoint from {}".format(f)) 62 | checkpoint = self._load_file(f) 63 | self._load_model(checkpoint) 64 | if "optimizer" in checkpoint and self.optimizer: 65 | self.logger.info("Loading optimizer from {}".format(f)) 66 | self.optimizer.load_state_dict(checkpoint.pop("optimizer")) 67 | if "scheduler" in checkpoint and self.scheduler: 68 | self.logger.info("Loading scheduler from {}".format(f)) 69 | self.scheduler.load_state_dict(checkpoint.pop("scheduler")) 70 | # return any further checkpoint data 71 | return checkpoint 72 | 73 | def _load_file(self, f): 74 | return torch.load(f, map_location=torch.device("cpu")) 75 | 76 | def _load_model(self, checkpoint, except_keys=None): 77 | load_state_dict(self.model, checkpoint.pop("model"), except_keys) 78 | 79 | 80 | def check_key(key, except_keys): 81 | if except_keys is None: 82 | return False 83 | else: 84 | for except_key in except_keys: 85 | if except_key in key: 86 | return True 87 | return False 88 | 89 | 90 | def 
align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys=None): 91 | current_keys = sorted(list(model_state_dict.keys())) 92 | loaded_keys = sorted(list(loaded_state_dict.keys())) 93 | # get a matrix of string matches, where each (i, j) entry correspond to the size of the 94 | # loaded_key string, if it matches 95 | match_matrix = [ 96 | len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys 97 | ] 98 | match_matrix = torch.as_tensor(match_matrix).view( 99 | len(current_keys), len(loaded_keys) 100 | ) 101 | max_match_size, idxs = match_matrix.max(1) 102 | # remove indices that correspond to no-match 103 | idxs[max_match_size == 0] = -1 104 | 105 | # used for logging 106 | max_size = max([len(key) for key in current_keys]) if current_keys else 1 107 | max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 108 | log_str_template = "{: <{}} loaded from {: <{}} of shape {}" 109 | logger = logging.getLogger("PersonSearch.checkpoint") 110 | for idx_new, idx_old in enumerate(idxs.tolist()): 111 | if idx_old == -1: 112 | continue 113 | key = current_keys[idx_new] 114 | key_old = loaded_keys[idx_old] 115 | if check_key(key, except_keys): 116 | continue 117 | model_state_dict[key] = loaded_state_dict[key_old] 118 | logger.info( 119 | log_str_template.format( 120 | key, 121 | max_size, 122 | key_old, 123 | max_size_loaded, 124 | tuple(loaded_state_dict[key_old].shape), 125 | ) 126 | ) 127 | 128 | 129 | def strip_prefix_if_present(state_dict, prefix): 130 | keys = sorted(state_dict.keys()) 131 | if not all(key.startswith(prefix) for key in keys): 132 | return state_dict 133 | stripped_state_dict = OrderedDict() 134 | for key, value in state_dict.items(): 135 | stripped_state_dict[key.replace(prefix, "")] = value 136 | return stripped_state_dict 137 | 138 | 139 | def load_state_dict(model, loaded_state_dict, except_keys=None): 140 | model_state_dict = model.state_dict() 141 | # if the state_dict comes from a model that was wrapped in a 142 | # DataParallel or DistributedDataParallel during serialization, 143 | # remove the "module" prefix before performing the matching 144 | loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") 145 | align_and_update_state_dicts(model_state_dict, loaded_state_dict, except_keys) 146 | 147 | # use strict loading 148 | model.load_state_dict(model_state_dict) 149 | -------------------------------------------------------------------------------- /utils/comm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains primitives for multi-gpu communication. 3 | This is useful when doing distributed training. 
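All helpers below fall back to single-process behaviour (world size 1, rank 0, no-op barrier) when torch.distributed is unavailable or not initialized.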
4 | """ 5 | 6 | import pickle 7 | 8 | import torch 9 | import torch.distributed as dist 10 | 11 | 12 | def get_world_size(): 13 | if not dist.is_available(): 14 | return 1 15 | if not dist.is_initialized(): 16 | return 1 17 | return dist.get_world_size() 18 | 19 | 20 | def get_rank(): 21 | if not dist.is_available(): 22 | return 0 23 | if not dist.is_initialized(): 24 | return 0 25 | return dist.get_rank() 26 | 27 | 28 | def is_main_process(): 29 | return get_rank() == 0 30 | 31 | 32 | def synchronize(): 33 | """ 34 | Helper function to synchronize (barrier) among all processes when 35 | using distributed training 36 | """ 37 | if not dist.is_available(): 38 | return 39 | if not dist.is_initialized(): 40 | return 41 | world_size = dist.get_world_size() 42 | if world_size == 1: 43 | return 44 | dist.barrier() 45 | 46 | 47 | def all_gather(data): 48 | """ 49 | Run all_gather on arbitrary picklable data (not necessarily tensors) 50 | Args: 51 | data: any picklable object 52 | Returns: 53 | list[data]: list of data gathered from each rank 54 | """ 55 | world_size = get_world_size() 56 | if world_size == 1: 57 | return [data] 58 | 59 | # serialized to a Tensor 60 | buffer = pickle.dumps(data) 61 | storage = torch.ByteStorage.from_buffer(buffer) 62 | tensor = torch.ByteTensor(storage).to("cuda") 63 | 64 | # obtain Tensor size of each rank 65 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 66 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] 67 | dist.all_gather(size_list, local_size) 68 | size_list = [int(size.item()) for size in size_list] 69 | max_size = max(size_list) 70 | 71 | # receiving Tensor from all ranks 72 | # we pad the tensor because torch all_gather does not support 73 | # gathering tensors of different shapes 74 | tensor_list = [] 75 | for _ in size_list: 76 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 77 | if local_size != max_size: 78 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 79 | tensor = torch.cat((tensor, padding), dim=0) 80 | dist.all_gather(tensor_list, tensor) 81 | 82 | data_list = [] 83 | for size, tensor in zip(size_list, tensor_list): 84 | buffer = tensor.cpu().numpy().tobytes()[:size] 85 | data_list.append(pickle.loads(buffer)) 86 | 87 | return data_list 88 | 89 | 90 | def reduce_dict(input_dict, average=True): 91 | """ 92 | Args: 93 | input_dict (dict): all the values will be reduced 94 | average (bool): whether to do average or sum 95 | Reduce the values in the dictionary from all processes so that process with rank 96 | 0 has the averaged results. Returns a dict with the same fields as 97 | input_dict, after reduction. 
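Note: the values are stacked with torch.stack and reduced with dist.reduce, so they are expected to be same-shaped tensors (typically scalar losses) already placed on the current device.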
98 | """ 99 | world_size = get_world_size() 100 | if world_size < 2: 101 | return input_dict 102 | with torch.no_grad(): 103 | names = [] 104 | values = [] 105 | # sort the keys so that they are consistent across processes 106 | for k in sorted(input_dict.keys()): 107 | names.append(k) 108 | values.append(input_dict[k]) 109 | values = torch.stack(values, dim=0) 110 | dist.reduce(values, dst=0) 111 | if dist.get_rank() == 0 and average: 112 | # only main process gets accumulated, so only divide by 113 | # world_size in this case 114 | values /= world_size 115 | reduced_dict = {k: v for k, v in zip(names, values)} 116 | return reduced_dict 117 | -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from PIL import Image, ImageFile 7 | import errno 8 | import json 9 | import pickle as pkl 10 | import os 11 | import os.path as osp 12 | import yaml 13 | from easydict import EasyDict as edict 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | def read_image(img_path): 19 | """Keep reading image until succeed. 20 | This can avoid IOError incurred by heavy IO process.""" 21 | got_img = False 22 | if not osp.exists(img_path): 23 | raise IOError("{} does not exist".format(img_path)) 24 | while not got_img: 25 | try: 26 | img = Image.open(img_path).convert('RGB') 27 | got_img = True 28 | except IOError: 29 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 30 | pass 31 | return img 32 | 33 | 34 | def mkdir_if_missing(directory): 35 | if not osp.exists(directory): 36 | try: 37 | os.makedirs(directory) 38 | except OSError as e: 39 | if e.errno != errno.EEXIST: 40 | raise 41 | 42 | 43 | def check_isfile(path): 44 | isfile = osp.isfile(path) 45 | if not isfile: 46 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 47 | return isfile 48 | 49 | 50 | def read_json(fpath): 51 | with open(fpath, 'r') as f: 52 | obj = json.load(f) 53 | return obj 54 | 55 | 56 | def write_json(obj, fpath): 57 | mkdir_if_missing(osp.dirname(fpath)) 58 | with open(fpath, 'w') as f: 59 | json.dump(obj, f, indent=4, separators=(',', ': ')) 60 | 61 | 62 | def get_text_embedding(path, length): 63 | with open(path, 'rb') as f: 64 | word_frequency = pkl.load(f) 65 | 66 | 67 | def save_train_configs(path, args): 68 | if not os.path.exists(path): 69 | os.makedirs(path) 70 | with open(f'{path}/configs.yaml', 'w') as f: 71 | yaml.dump(vars(args), f, default_flow_style=False) 72 | 73 | def load_train_configs(path): 74 | with open(path, 'r') as f: 75 | args = yaml.load(f, Loader=yaml.FullLoader) 76 | return edict(args) -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import os.path as op 5 | 6 | 7 | def setup_logger(name, save_dir, if_train, distributed_rank=0): 8 | logger = logging.getLogger(name) 9 | logger.setLevel(logging.DEBUG) 10 | 11 | # don't log results for the non-master process 12 | if distributed_rank > 0: 13 | return logger 14 | 15 | ch = logging.StreamHandler(stream=sys.stdout) 16 | ch.setLevel(logging.DEBUG) 17 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 18 | ch.setFormatter(formatter) 19 | 
logger.addHandler(ch) 20 | 21 | if not op.exists(save_dir): 22 | print(f"{save_dir} is not exists, create given directory") 23 | os.makedirs(save_dir) 24 | if if_train: 25 | fh = logging.FileHandler(os.path.join(save_dir, "train_log.txt"), mode='w') 26 | else: 27 | fh = logging.FileHandler(os.path.join(save_dir, "test_log.txt"), mode='a+') 28 | fh.setLevel(logging.DEBUG) 29 | fh.setFormatter(formatter) 30 | logger.addHandler(fh) 31 | 32 | return logger -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.val = 0 6 | self.avg = 0 7 | self.sum = 0 8 | self.count = 0 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from prettytable import PrettyTable 2 | import torch 3 | import numpy as np 4 | import os 5 | import torch.nn.functional as F 6 | import logging 7 | 8 | 9 | def rank(similarity, q_pids, g_pids, max_rank=10, get_mAP=True): 10 | if get_mAP: 11 | indices = torch.argsort(similarity, dim=1, descending=True) 12 | else: 13 | # acclerate sort with topk 14 | _, indices = torch.topk( 15 | similarity, k=max_rank, dim=1, largest=True, sorted=True 16 | ) # q * topk 17 | pred_labels = g_pids[indices] # q * k 18 | matches = pred_labels.eq(q_pids.view(-1, 1)) # q * k 19 | 20 | all_cmc = matches[:, :max_rank].cumsum(1) # cumulative sum 21 | all_cmc[all_cmc > 1] = 1 22 | all_cmc = all_cmc.float().mean(0) * 100 23 | # all_cmc = all_cmc[topk - 1] 24 | 25 | if not get_mAP: 26 | return all_cmc, indices 27 | 28 | num_rel = matches.sum(1) # q 29 | tmp_cmc = matches.cumsum(1) # q * k 30 | tmp_cmc = [tmp_cmc[:, i] / (i + 1.0) for i in range(tmp_cmc.shape[1])] 31 | tmp_cmc = torch.stack(tmp_cmc, 1) * matches 32 | AP = tmp_cmc.sum(1) / num_rel # q 33 | mAP = AP.mean() * 100 34 | return all_cmc, mAP, indices 35 | 36 | 37 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, set=0, max_rank=50): 38 | """Evaluation with market1501 metric 39 | Key: for each query identity, its gallery images from the same camera view are discarded. 40 | """ 41 | num_q, num_g = distmat.shape 42 | if num_g < max_rank: 43 | max_rank = num_g 44 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 45 | indices = np.argsort(distmat, axis=1) 46 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 47 | 48 | # compute cmc curve for each query 49 | all_cmc = [] 50 | all_AP = [] 51 | all_INP = [] 52 | num_valid_q = 0. 
# number of valid query 53 | for q_idx in range(num_q): 54 | # get query pid and camid 55 | q_pid = q_pids[q_idx] 56 | q_camid = q_camids[q_idx] 57 | 58 | # remove gallery samples that have the same pid and camid with query 59 | if set == 2: 60 | order = indices[q_idx] 61 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 62 | keep = np.invert(remove) 63 | 64 | # compute cmc curve 65 | # binary vector, positions with value 1 are correct matches 66 | orig_cmc = matches[q_idx][keep] 67 | else: 68 | orig_cmc = matches[q_idx] 69 | 70 | if not np.any(orig_cmc): 71 | # this condition is true when query identity does not appear in gallery 72 | continue 73 | 74 | cmc = orig_cmc.cumsum() 75 | 76 | pos_idx = np.where(orig_cmc == 1) 77 | max_pos_idx = np.max(pos_idx) 78 | inp = cmc[max_pos_idx]/ (max_pos_idx + 1.0) 79 | all_INP.append(inp) 80 | 81 | cmc[cmc > 1] = 1 82 | 83 | all_cmc.append(cmc[:max_rank]) 84 | num_valid_q += 1. 85 | 86 | # compute average precision 87 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 88 | num_rel = orig_cmc.sum() 89 | tmp_cmc = orig_cmc.cumsum() 90 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 91 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 92 | AP = tmp_cmc.sum() / num_rel 93 | all_AP.append(AP) 94 | 95 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 96 | 97 | all_cmc = np.asarray(all_cmc).astype(np.float32) 98 | all_cmc = all_cmc.sum(0) / num_valid_q 99 | mAP = np.mean(all_AP) 100 | mINP = np.mean(all_INP) 101 | 102 | return all_cmc* 100, mAP* 100, mINP* 100 103 | 104 | class Evaluator(): 105 | def __init__(self, args, img_loader, txt_loader, sketch_loader, test_setting=0): 106 | self.img_loader = img_loader # gallery 107 | self.txt_loader = txt_loader # query 108 | self.sketch_loader = sketch_loader # query2 109 | self.args = args 110 | 111 | self.test_setting = test_setting 112 | self.logger = logging.getLogger("CLIP2ReID.eval") 113 | 114 | def _compute_embedding(self, model): 115 | model = model.eval() 116 | device = next(model.parameters()).device 117 | 118 | qids, qids_sketch, gids, qfeats_text, qfeats_sketch, qfeats_text_sketch, gfeats, qimage_ids, qimage_ids_sketch, gimage_ids = [], [], [], [], [], [], [], [], [], [] 119 | 120 | # text+sketch 121 | for simg, simage_id, pid, caption in self.txt_loader: 122 | caption = caption.to(device) 123 | simg = simg.to(device) 124 | with torch.no_grad(): 125 | text_feat = model.encode_text(caption) 126 | sketch_feat = model.encode_image(simg) 127 | if self.args.fusion_way in ['add', 'weight add', 'cross attention', 'parameter add', 'concat', 'global concat', 'cross attention text', 'cross attention sketch', 'concat transformer']: 128 | text_sketch_fu = model.fusion_layer(text_feat, sketch_feat, caption, way=self.args.fusion_way) 129 | # text_sketch_fu = model.fusion_layer(text_feat, sketch_feat, caption, pa=self.args.pa, way=self.args.fusion_way) 130 | else: 131 | text_sketch_fu = text_feat[torch.arange(text_feat.shape[0]), caption.argmax(dim=-1)].float() 132 | text_feat = text_feat[torch.arange(text_feat.shape[0]), caption.argmax(dim=-1)].float() 133 | 134 | qids.append(pid.view(-1)) # flatten 135 | qfeats_text.append(text_feat) 136 | qfeats_text_sketch.append(text_sketch_fu) 137 | qimage_ids.append(simage_id) 138 | 139 | 140 | qids = torch.cat(qids, 0) 141 | qfeats_text = torch.cat(qfeats_text, 0) 142 | qfeats_text_sketch = torch.cat(qfeats_text_sketch, 0) 143 | qimage_ids = torch.cat(qimage_ids, 0) 144 
| 145 | # image 146 | for pid, img, image_id in self.img_loader: 147 | img = img.to(device) 148 | with torch.no_grad(): 149 | img_feat = model.encode_image(img)[:, 0, :].float() 150 | gids.append(pid.view(-1)) # flatten 151 | gfeats.append(img_feat) 152 | gimage_ids.append(image_id) 153 | 154 | gids = torch.cat(gids, 0) 155 | gfeats = torch.cat(gfeats, 0) 156 | gimage_ids = torch.cat(gimage_ids, 0) 157 | 158 | # sketch 159 | for pid, simg, simage_id in self.sketch_loader: 160 | simg = simg.to(device) 161 | with torch.no_grad(): 162 | simg_feat = model.encode_image(simg)[:, 0, :].float() 163 | qids_sketch.append(pid.view(-1)) # flatten 164 | qfeats_sketch.append(simg_feat) 165 | qimage_ids_sketch.append(simage_id) 166 | 167 | qids_sketch = torch.cat(qids_sketch, 0) 168 | qfeats_sketch = torch.cat(qfeats_sketch, 0) 169 | qimage_ids_sketch = torch.cat(qimage_ids_sketch, 0) 170 | 171 | return qfeats_text, qfeats_sketch, qfeats_text_sketch, gfeats, qids, qids_sketch, gids, qimage_ids, qimage_ids_sketch, gimage_ids 172 | 173 | def eval(self, model, i2t_metric=False): 174 | 175 | qfeats_text, qfeats_sketch, qfeats_text_sketch, gfeats, qids, qids_sketch, gids, qimage_ids, qimage_ids_sketch, gimage_ids = self._compute_embedding(model) 176 | 177 | qfeats_text = F.normalize(qfeats_text, p=2, dim=1) # text features 178 | qfeats_sketch = F.normalize(qfeats_sketch, p=2, dim=1) # sketch features 179 | qfeats_text_sketch = F.normalize(qfeats_text_sketch, p=2, dim=1) # sketch+text features 180 | 181 | gfeats = F.normalize(gfeats, p=2, dim=1) # image features 182 | 183 | similarity_text_rgb = qfeats_text @ gfeats.t() 184 | similarity_sketch_rgb = qfeats_sketch @ gfeats.t() 185 | similarity_textsketch_rgb = qfeats_text_sketch @ gfeats.t() 186 | 187 | #original gallery set for text-to-rgb retrieval 188 | t2i_cmc, t2i_mAP, t2i_mINP = eval_func(-similarity_text_rgb.detach().cpu().numpy() , qids.numpy(), gids.numpy(), qimage_ids.numpy(), gimage_ids.numpy(), set=0, max_rank=10) 189 | 190 | # remove the rgb images that used for generated sketches from gallery set 191 | t2i_cmc0, t2i_mAP0, t2i_mINP0 = eval_func(-similarity_text_rgb.detach().cpu().numpy() , qids.numpy(), gids.numpy(), qimage_ids.numpy(), gimage_ids.numpy(), set=2, max_rank=10) 192 | 193 | t2i_cmc1, t2i_mAP1, t2i_mINP1 = eval_func(-similarity_sketch_rgb.detach().cpu().numpy() , qids_sketch.numpy(), gids.numpy(), qimage_ids_sketch.numpy(), gimage_ids.numpy(), set=2, max_rank=10) 194 | t2i_cmc2, t2i_mAP2, t2i_mINP2 = eval_func(-similarity_textsketch_rgb.detach().cpu().numpy() , qids.numpy(), gids.numpy(), qimage_ids.numpy(), gimage_ids.numpy(), set=2, max_rank=10) 195 | 196 | table = PrettyTable(["task", "R1", "R5", "R10", "mAP", "mINP"]) 197 | table.add_row(['t2i-text_RGB_original', t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_mAP, t2i_mINP]) 198 | table.add_row(['t2i-text_RGB', t2i_cmc0[0], t2i_cmc0[4], t2i_cmc0[9], t2i_mAP0, t2i_mINP0]) 199 | table.add_row(['t2i-sketch_RGB', t2i_cmc1[0], t2i_cmc1[4], t2i_cmc1[9], t2i_mAP1, t2i_mINP1]) 200 | table.add_row(['t2i-textsketch_RGB', t2i_cmc2[0], t2i_cmc2[4], t2i_cmc2[9], t2i_mAP2, t2i_mINP2]) 201 | # table.add_row(['t2i-text_RGB', t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_mAP, t2i_mAP]) 202 | 203 | if i2t_metric: 204 | i2t_cmc, i2t_mAP, _ = rank(similarity=similarity_text_rgb.t(), q_pids=gids, g_pids=qids, max_rank=10, get_mAP=True) 205 | i2t_cmc, i2t_mAP = i2t_cmc.cpu().numpy(), i2t_mAP.cpu().numpy() 206 | table.add_row(['i2t', i2t_cmc[0], i2t_cmc[4], i2t_cmc[9], i2t_mAP]) 207 | 208 | table.float_format = '.4' 
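# The three R1 values returned below correspond to the three query types above: text-only queries on the original gallery, and sketch-only / fused text+sketch queries evaluated with set=2, which drops gallery RGB images sharing the query's person id and image id (the photos the sketches were generated from).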
209 | self.logger.info('\n' + str(table)) 210 | 211 | return t2i_cmc[0], t2i_cmc1[0], t2i_cmc2[0] 212 | 213 | 214 | # def eval_by_proj(self, model, i2t_metric=False): 215 | 216 | # qfeats, gfeats, qids, gids = self._compute_embedding(model) 217 | 218 | # # qfeats_norm = F.normalize(qfeats, p=2, dim=1) # text features 219 | # gfeats_norm = F.normalize(gfeats, p=2, dim=1) # image features 220 | 221 | # similarity = qfeats @ gfeats_norm.t() 222 | 223 | # t2i_cmc, t2i_mAP, _ = rank(similarity=similarity, q_pids=qids, g_pids=gids, max_rank=10, get_mAP=True) 224 | # t2i_cmc, t2i_mAP = t2i_cmc.cpu().numpy(), t2i_mAP.cpu().numpy() 225 | # table = PrettyTable(["task", "R1", "R5", "R10", "mAP"]) 226 | # table.add_row(['t2i', t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_mAP]) 227 | 228 | # if i2t_metric: 229 | # i2t_cmc, i2t_mAP, _ = rank(similarity=similarity.t(), q_pids=gids, g_pids=qids, max_rank=10, get_mAP=True) 230 | # i2t_cmc, i2t_mAP = i2t_cmc.cpu().numpy(), i2t_mAP.cpu().numpy() 231 | # table.add_row(['i2t', i2t_cmc[0], i2t_cmc[4], i2t_cmc[9], i2t_mAP]) 232 | # table.float_format = '.4' 233 | # self.logger.info('\n' + str(table)) 234 | 235 | # return t2i_cmc[0] 236 | 237 | -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser(description="TransTextReID") 6 | ######################## general settings ######################## 7 | parser.add_argument("--local_rank", default=0, type=int) 8 | parser.add_argument("--name", default="baseline", help="experiment name to save") 9 | # parser.add_argument("--log_dir", default="logs") 10 | parser.add_argument("--output_dir", default="/data1/ccq/multimodality") 11 | # parser.add_argument("--gpu_id", default="0", help="select gpu to run") 12 | parser.add_argument("--log_period", type=int, default=100) 13 | parser.add_argument("--eval_period", type=int, default=1) 14 | parser.add_argument("--val_dataset", default="test") # which split to evaluate on: 'val' or 'test' 15 | parser.add_argument("--resume", default=False, action='store_true') 16 | parser.add_argument("--resume_ckpt_file", default="", help='resume from ...') 17 | 18 | ######################## model general settings ######################## 19 | parser.add_argument("--pretrain_choice", default='ViT-B/16') # which pretrained backbone to load (e.g. 'ViT-B/16') 20 | parser.add_argument("--temperature", type=float, default=0.07, help="initial temperature value; if 0, don't use temperature") 21 | parser.add_argument("--img_aug", default=False, action='store_true') 22 | parser.add_argument("--nlp_aug", default=False, action='store_true') 23 | # parser.add_argument("--embed_dim", type=int, default=512, help="the final visual and textual feature dim") 24 | 25 | ## cross transformer settings 26 | parser.add_argument("--num_colors", type=int, default=60, help="num colors of Mask Color Modeling labels") 27 | parser.add_argument("--cmt_depth", type=int, default=4, help="cross modal transformer self attn layers") 28 | parser.add_argument("--masked_token_rate", type=float, default=0.8, help="masked token rate for mcm task, 1.0 indicates mask every color in a caption") 29 | parser.add_argument("--masked_token_unchanged_rate", type=float, default=0.1, help="masked token unchanged rate") 30 | parser.add_argument("--lr_factor", type=float, default=5.0, help="lr factor for randomly initialized, self-implemented modules") 31 | parser.add_argument("--use_imageid",
default=False, action='store_true', help="whether to use image_id info to build soft label.") 32 | parser.add_argument("--MCQ", default=False, action='store_true', help="whether to use Multiple Choice Questions dataset") 33 | parser.add_argument("--MCM", default=False, action='store_true', help="whether to use Mask Color Modeling dataset") 34 | parser.add_argument("--MLM", default=False, action='store_true', help="whether to use Mask Language Modeling dataset") 35 | parser.add_argument("--MSM", default=False, action='store_true', help="whether to use Mask Subsequence Matching dataset") 36 | parser.add_argument("--MCQMLM", default=False, action='store_true', help="whether to use MCQMLM dataset") 37 | parser.add_argument("--MSMMLM", default=False, action='store_true', help="whether to use MSMMLM dataset") 38 | 39 | ######################## loss settings ######################## 40 | parser.add_argument("--loss_names", default='itc', help="which loss to use ['tcmpm','mcm', 'mcq', 'mlm', 'msm', 'id', 'itc', 'sdm']") 41 | parser.add_argument("--cmm_loss_weight", type=float, default=1.0, help="cross modal matching loss (tcmpm, cmpm, infonce...) weight") 42 | parser.add_argument("--mcm_loss_weight", type=float, default=1.0, help="mcm loss weight") 43 | parser.add_argument("--mlm_loss_weight", type=float, default=1.0, help="mlm loss weight") 44 | parser.add_argument("--mcq_loss_weight", type=float, default=1.0, help="mcq loss weight") 45 | parser.add_argument("--id_loss_weight", type=float, default=1.0, help="id loss weight") 46 | 47 | ######################## vision transformer settings ######################## 48 | parser.add_argument("--img_size", type=tuple, default=(384, 128)) # (height, width); the default is used as-is since argparse's type=tuple cannot parse CLI strings 49 | parser.add_argument("--stride_size", type=int, default=16) 50 | 51 | ######################## text transformer settings ######################## 52 | parser.add_argument("--text_length", type=int, default=77) 53 | parser.add_argument("--vocab_size", type=int, default=49408) 54 | 55 | ######################## solver ######################## 56 | parser.add_argument("--learnable_loss_weight", default=False, action='store_true') 57 | parser.add_argument("--label_mix", default=False, action='store_true', help="whether to mix pid and image_id labels") 58 | parser.add_argument("--optimizer", type=str, default="Adam", help="[SGD, Adam, Adamw]") 59 | parser.add_argument("--lr", type=float, default=1e-5) 60 | parser.add_argument("--bias_lr_factor", type=float, default=2.) 61 | parser.add_argument("--momentum", type=float, default=0.9) 62 | parser.add_argument("--weight_decay", type=float, default=4e-5) 63 | parser.add_argument("--weight_decay_bias", type=float, default=0.)
64 | parser.add_argument("--alpha", type=float, default=0.9) 65 | parser.add_argument("--beta", type=float, default=0.999) 66 | 67 | ######################## scheduler ######################## 68 | parser.add_argument("--num_epoch", type=int, default=60) 69 | parser.add_argument("--milestones", type=int, nargs='+', default=(20, 50)) 70 | parser.add_argument("--gamma", type=float, default=0.1) 71 | parser.add_argument("--warmup_factor", type=float, default=0.1) 72 | parser.add_argument("--warmup_epochs", type=int, default=5) 73 | parser.add_argument("--warmup_method", type=str, default="linear") 74 | parser.add_argument("--lrscheduler", type=str, default="step") 75 | parser.add_argument("--target_lr", type=float, default=1e-8) 76 | parser.add_argument("--power", type=float, default=0.9) 77 | 78 | ######################## dataset ######################## 79 | parser.add_argument("--dataset_name", default="CUHK-PEDES", help="[CUHK-PEDES, ICFG-PEDES, F30K, RSTPReid]") 80 | parser.add_argument("--sampler", default="random", help="sampler type: 'identity' or 'random'") 81 | parser.add_argument("--num_instance", type=int, default=4) 82 | parser.add_argument("--root_dir", type=str, default="/data0/data_ccq/CUHK-PEDES/") 83 | parser.add_argument("--batch_size", type=int, default=64) 84 | parser.add_argument("--test_batch_size", type=int, default=512) 85 | parser.add_argument("--num_workers", type=int, default=8) 86 | parser.add_argument("--test", dest='training', default=True, action='store_false') # pass --test to run in test mode (sets training=False) 87 | parser.add_argument("--test_setting", type=int, default=0) 88 | 89 | ######################## multi-modality model settings ######################## 90 | parser.add_argument("--fusion_way", default='add', help="[add, weight add, cross attention]") # how to fuse sketch and text features 91 | parser.add_argument("--only_sketch", default=False, action='store_true', help="whether to train with sketch only") 92 | parser.add_argument("--only_text", default=False, action='store_true', help="whether to train with text only") 93 | parser.add_argument("--pa", type=float, default=0.1, help="weight for the parameter-add fusion") 94 | parser.add_argument("--only_fusion_loss", default=False, action='store_true', help="whether to train with only the fusion loss") 95 | parser.add_argument("--four_fusion_loss", default=False, action='store_true', help="whether to use four fusion loss terms") 96 | parser.add_argument("--focal_three_fusion_loss", default=False, action='store_true', help="whether to use the focal three-fusion loss") 97 | parser.add_argument("--focal_three_fusion_loss2", default=False, action='store_true', help="focal three-fusion loss, variant 2") 98 | parser.add_argument("--focal_three_fusion_loss3", default=False, action='store_true', help="focal three-fusion loss with sketch-label KL") 99 | parser.add_argument("--focal_three_fusion_loss4", default=False, action='store_true', help="focal three-fusion loss with text-label KL") 100 | parser.add_argument("--focal_three_fusion_loss5", default=False, action='store_true', help="focal three-fusion loss with two text-label KL terms") 101 | parser.add_argument("--focal_three_fusion_loss6", default=False, action='store_true', help="focal three-fusion loss with two text-label KL terms, variant 6") 102 | parser.add_argument("--focalthree_fusion_loss", default=False, action='store_true', help="whether to use the focalthree fusion loss") 103 | parser.add_argument("--focalthree_four_fusion_loss", default=False, action='store_true', help="whether to use the focalthree four-fusion loss") 104 | parser.add_argument("--al", type=float, default=1.0, help="alpha of the focal loss") 105 | parser.add_argument("--ga", type=float,
default=2.0, help="gamma of the focal loss") 106 | parser.add_argument("--klp", type=float, default=1.0, help="weight of the KL-divergence term") 107 | 108 | args = parser.parse_args() 109 | 110 | return args -------------------------------------------------------------------------------- /utils/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "../data/bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns a list of utf-8 bytes and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a significant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | This also avoids mapping to whitespace/control characters that the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | 74 | vocab.pop(-1) # remove last one in vocab(jekyll) to keep vocab_size unchanged 75 | vocab.extend(['<|mask|>', '<|startoftext|>', '<|endoftext|>']) # vocab_size 49408 76 | # vocab.extend(['<|startoftext|>', '<|endoftext|>']) # vocab_size 49408 77 | self.encoder = dict(zip(vocab, range(len(vocab)))) 78 | self.decoder = {v: k for k, v in self.encoder.items()} 79 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 80 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|mask|>': '<|mask|>', '<|endoftext|>': '<|endoftext|>'} 81 | self.pat = re.compile(r"""<\|startoftext\|>|<\|mask\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 82 | 83 | def bpe(self, token): 84 | if token in self.cache: 85 | return self.cache[token] 86 | word = tuple(token[:-1]) + ( token[-1] + '',) 87 | pairs = get_pairs(word) 88 | 89 | if not pairs: 90 | return token+'' 91 | 92 | while True: 93 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 94 | if bigram not in self.bpe_ranks: 95 | break 96 | first, second = bigram 97 | new_word = [] 98 | i = 0 99 | while i < len(word): 100 | try: 101 | j = word.index(first, i) 102 | new_word.extend(word[i:j]) 103 | i = j 104 | except: 105 | new_word.extend(word[i:]) 106 | break 107 | 108 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 109 | new_word.append(first+second) 110 | i += 2 111 | else: 112 | new_word.append(word[i]) 113 | i += 1 114 | new_word = tuple(new_word) 115 | word = new_word 116 | if len(word) == 1: 117 | break 118 | else: 119 | pairs = get_pairs(word) 120 | word = ' '.join(word) 121 | self.cache[token] = word 122 | return word 123 | 124 | def encode(self, text): 125 | bpe_tokens = [] 126 | text = whitespace_clean(basic_clean(text)).lower() 127 | for token in re.findall(self.pat, text): 128 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 129 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 130 | return bpe_tokens 131 | 132 | def decode(self, tokens): 133 | text = ''.join([self.decoder[token] for token in tokens]) 134 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 135 | return text 136 | --------------------------------------------------------------------------------