├── figs ├── overview.png └── comparison.png ├── .gitignore ├── configs ├── few-shot │ ├── dist_fewshot_sim_base.sh │ └── slurm_fewshot_sim_base.sh ├── semisup_rebuttal.sh ├── semisup_sim_base_400ep.sh ├── linprobe │ ├── dist_linprobe_sim_base.sh │ └── slurm_linprobe_sim_base.sh ├── finetune │ ├── dist_finetune_sim_base.sh │ ├── dist_finetune_sim_base_eval.sh │ ├── slurm_finetune_sim_base.sh │ └── slurm_finetune_sim_base_eval.sh ├── semisup_sim_large_1600ep.sh └── pretrain │ ├── dist_sim_base_1600ep.sh │ └── slurm_sim_base_1600ep.sh ├── util ├── lr_sched.py ├── crop.py ├── lars.py ├── lr_decay.py ├── datasets.py ├── augmentation.py ├── masking_generator.py ├── pos_embed.py ├── tcs_datasets.py └── misc.py ├── docs ├── few_shot.md ├── linear_eval.md ├── checkpoints.md ├── pretrain.md ├── finetune.md └── prepare.md ├── engine_finetune.py ├── engine_pretrain.py ├── README.md ├── models_vit.py ├── main_linprobe.py ├── main_pretrain.py ├── main_logistic.py ├── main_finetune.py └── LICENSE /figs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/Siamese-Image-Modeling/HEAD/figs/overview.png -------------------------------------------------------------------------------- /figs/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/Siamese-Image-Modeling/HEAD/figs/comparison.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | exp/ 2 | **/__pycache__/ 3 | *.pth 4 | batchscript-* 5 | phoenix-slurm-* 6 | .ipynb_checkpoints/ 7 | .idea/ 8 | .vscode/ 9 | -------------------------------------------------------------------------------- /configs/few-shot/dist_fewshot_sim_base.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | IP=${1} 4 | RANK=${2} 5 | NNODES=${3} 6 | CKPT_PATH=${4} 7 | DATA_PATH=${5} 8 | PORT=${PORT:-28500} 9 | PY_ARGS=${PY_ARGS:-""} 10 | 11 | BASENAME=$(basename ${CKPT_PATH}) 12 | EXP_NAME=$(basename $(dirname ${CKPT_PATH})) 13 | DIR=./exp/fewshot/${EXP_NAME} 14 | 15 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=${NNODES} --node_rank=${RANK} --master_addr=${IP} --master_port=${PORT} \ 16 | main_logistic.py \ 17 | --subset-path imagenet_subset1/1percent.txt \ 18 | --root-path ${DATA_PATH} \ 19 | --image-folder imagenet_full_size/061417/ \ 20 | --device cuda:0 \ 21 | --pretrained ${CKPT_PATH} \ 22 | --fname 'fewshot_1percent.pth' \ 23 | --model-name 'vit_base_patch16' \ 24 | --penalty l2 \ 25 | --lambd 0.1 \ 26 | --preload -------------------------------------------------------------------------------- /configs/semisup_rebuttal.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | GPUS=${1} 4 | GPUS_PER_NODE=${2} 5 | JOB_NAME=${3} 6 | QUOTATYPE=${4} 7 | PARTITION=${5} 8 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 9 | 10 | DIR=./exp/semisup_ibot_400ep 11 | CKPT=./ckpt/ibot.pth 12 | 13 | srun --partition=vc_research_${PARTITION} \ 14 | --mpi=pmi2 \ 15 | --quotatype=${QUOTATYPE} \ 16 | --job-name=${JOB_NAME} \ 17 | -n$GPUS \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks-per-node=${GPUS_PER_NODE} \ 20 | --cpus-per-task=$CPUS_PER_TASK \ 21 | --kill-on-bad-exit=1 \ 22 | --dependency=singleton \ 23 | python -W ignore -u main_logistic.py \ 24 | --subset-path 
imagenet_subset1/1percent.txt \ 25 | --root-path /mnt/cache/share/images \ 26 | --image-folder imagenet_full_size/061417/ \ 27 | --device cuda:0 \ 28 | --pretrained ${CKPT} \ 29 | --fname 'semisup.pth' \ 30 | --model-name 'vit_base_patch16' \ 31 | --penalty l2 \ 32 | --lambd 0.1 -------------------------------------------------------------------------------- /configs/semisup_sim_base_400ep.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | GPUS=${1} 4 | GPUS_PER_NODE=${2} 5 | JOB_NAME=${3} 6 | QUOTATYPE=${4} 7 | PARTITION=${5} 8 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 9 | 10 | DIR=./exp/semisup_sim_base_1600ep 11 | CKPT=./exp/pretrain_sim_base_400ep/checkpoint-399.pth 12 | 13 | srun --partition=vc_research_${PARTITION} \ 14 | --mpi=pmi2 \ 15 | --quotatype=${QUOTATYPE} \ 16 | --job-name=${JOB_NAME} \ 17 | -n$GPUS \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks-per-node=${GPUS_PER_NODE} \ 20 | --cpus-per-task=$CPUS_PER_TASK \ 21 | --kill-on-bad-exit=1 \ 22 | python -W ignore -u main_logistic.py \ 23 | --subset-path imagenet_subset1/1percent.txt \ 24 | --root-path /mnt/cache/share/images \ 25 | --image-folder imagenet_full_size/061417/ \ 26 | --device cuda:0 \ 27 | --pretrained ${CKPT} \ 28 | --fname 'semisup.pth' \ 29 | --model-name 'vit_base_patch16' \ 30 | --penalty l2 \ 31 | --lambd 0.1 -------------------------------------------------------------------------------- /configs/linprobe/dist_linprobe_sim_base.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | IP=${1} 4 | RANK=${2} 5 | NNODES=${3} 6 | CKPT_PATH=${4} 7 | DATA_PATH=${5} 8 | PORT=${PORT:-28500} 9 | PY_ARGS=${PY_ARGS:-""} 10 | 11 | TOTAL_BATCH_SIZE=16384 12 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${NNODES}/8 13 | 14 | BASENAME=$(basename ${CKPT_PATH}) 15 | EXP_NAME=$(basename $(dirname ${CKPT_PATH})) 16 | DIR=./exp/linear/${EXP_NAME} 17 | 18 | mkdir -p ${DIR} 19 | 20 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=${NNODES} --node_rank=${RANK} --master_addr=${IP} --master_port=${PORT} \ 21 | main_linprobe.py \ 22 | --batch_size ${BATCH_SIZE} \ 23 | --model vit_base_patch16 \ 24 | --finetune ${CKPT_PATH} \ 25 | --epochs 90 \ 26 | --blr 0.1 \ 27 | --weight_decay 0.0 \ 28 | --dist_eval \ 29 | --output_dir ${DIR} \ 30 | --log_dir ${DIR} \ 31 | --global_pool \ 32 | --data_path ${DATA_PATH} \ 33 | --use_tcs_dataset \ 34 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt -------------------------------------------------------------------------------- /configs/finetune/dist_finetune_sim_base.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | IP=${1} 4 | RANK=${2} 5 | NNODES=${3} 6 | CKPT_PATH=${4} 7 | DATA_PATH=${5} 8 | PORT=${PORT:-28500} 9 | PY_ARGS=${PY_ARGS:-""} 10 | 11 | TOTAL_BATCH_SIZE=1024 12 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${NNODES}/8 13 | 14 | BASENAME=$(basename ${CKPT_PATH}) 15 | EXP_NAME=$(basename $(dirname ${CKPT_PATH})) 16 | DIR=./exp/finetune/${EXP_NAME} 17 | 18 | mkdir -p ${DIR} 19 | 20 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=${NNODES} --node_rank=${RANK} --master_addr=${IP} --master_port=${PORT} \ 21 | main_finetune.py \ 22 | --output_dir ${DIR} \ 23 | --log_dir ${DIR} \ 24 | --batch_size ${BATCH_SIZE} \ 25 | --model vit_base_patch16 \ 26 | --finetune ${CKPT_PATH} \ 27 | --epochs 100 \ 28 | --blr 2.5e-4 --layer_decay 0.65 \ 29 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 30 | --dist_eval 
--data_path ${DATA_PATH} \ 31 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt -------------------------------------------------------------------------------- /configs/few-shot/slurm_fewshot_sim_base.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | GPUS=${1} 4 | GPUS_PER_NODE=${2} 5 | QUOTATYPE=${3} 6 | PARTITION=${4} 7 | CKPT_PATH=${5} 8 | DATA_PATH=${6} 9 | CPUS_PER_TASK=${CPUS_PER_TASK:-12} 10 | 11 | BASENAME=$(basename ${CKPT_PATH}) 12 | EXP_NAME=$(basename $(dirname ${CKPT_PATH})) 13 | DIR=./exp/fewshot/${EXP_NAME} 14 | JOB_NAME=fewshot-${EXP} 15 | 16 | srun --partition=${PARTITION} \ 17 | --mpi=pmi2 \ 18 | --quotatype=${QUOTATYPE} \ 19 | --job-name=${JOB_NAME} \ 20 | -n$GPUS \ 21 | --gres=gpu:${GPUS_PER_NODE} \ 22 | --ntasks-per-node=${GPUS_PER_NODE} \ 23 | --cpus-per-task=$CPUS_PER_TASK \ 24 | --kill-on-bad-exit=1 \ 25 | python -W ignore -u main_logistic.py \ 26 | --subset-path imagenet_subset1/1percent.txt \ 27 | --root-path ${DATA_PATH} \ 28 | --image-folder imagenet_full_size/061417/ \ 29 | --device cuda:0 \ 30 | --pretrained ${CKPT_PATH} \ 31 | --fname 'fewshot_1percent.pth' \ 32 | --model-name 'vit_base_patch16' \ 33 | --penalty l2 \ 34 | --lambd 0.1 -------------------------------------------------------------------------------- /configs/finetune/dist_finetune_sim_base_eval.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | IP=${1} 4 | RANK=${2} 5 | NNODES=${3} 6 | CKPT_PATH=${4} 7 | DATA_PATH=${5} 8 | PORT=${PORT:-28500} 9 | PY_ARGS=${PY_ARGS:-""} 10 | 11 | TOTAL_BATCH_SIZE=1024 12 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${NNODES}/8 13 | 14 | BASENAME=$(basename ${CKPT_PATH}) 15 | EXP_NAME=$(basename $(dirname ${CKPT_PATH})) 16 | DIR=./exp/finetune/${EXP_NAME} 17 | 18 | mkdir -p ${DIR} 19 | 20 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=${NNODES} --node_rank=${RANK} --master_addr=${IP} --master_port=${PORT} \ 21 | main_finetune.py \ 22 | --output_dir ${DIR} \ 23 | --log_dir ${DIR} \ 24 | --batch_size ${BATCH_SIZE} \ 25 | --model vit_base_patch16 \ 26 | --resume ${CKPT_PATH} \ 27 | --epochs 100 \ 28 | --blr 2.5e-4 --layer_decay 0.65 \ 29 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 30 | --dist_eval --data_path ${DATA_PATH} \ 31 | --eval \ 32 | --use_tcs_dataset \ 33 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt -------------------------------------------------------------------------------- /configs/semisup_sim_large_1600ep.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | GPUS=${1} 4 | GPUS_PER_NODE=${2} 5 | JOB_NAME=${3} 6 | QUOTATYPE=${4} 7 | PARTITION=${5} 8 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 9 | 10 | DIR=./exp/semisup_sim_large_1600ep 11 | CKPT=./exp/pretrain_sim_large_1600ep/checkpoint-latest.pth 12 | 13 | srun --partition=vc_research_${PARTITION} \ 14 | --mpi=pmi2 \ 15 | --quotatype=${QUOTATYPE} \ 16 | --job-name=${JOB_NAME} \ 17 | -n$GPUS \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks-per-node=${GPUS_PER_NODE} \ 20 | --cpus-per-task=$CPUS_PER_TASK \ 21 | --kill-on-bad-exit=1 \ 22 | --dependency=singleton \ 23 | -x SH-IDC1-10-142-5-[45,13,70,198],SH-IDC1-10-142-4-[187,93,188,46,165,83,151,146,26] \ 24 | python -W ignore -u main_logistic.py \ 25 | --subset-path imagenet_subset1/1percent.txt \ 26 | --root-path /mnt/cache/share/images \ 27 | --image-folder imagenet_full_size/061417/ \ 28 | --device cuda:0 \ 29 | --pretrained ${CKPT} \ 30 | --fname 'semisup.pth' 
\ 31 | --model-name 'vit_large_patch16' \ 32 | --penalty l2 \ 33 | --lambd 0.01 -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | 9 | 10 | import math 11 | 12 | def adjust_learning_rate(optimizer, epoch, args): 13 | """Decay the learning rate with half-cycle cosine after warmup""" 14 | if epoch < args.warmup_epochs: 15 | lr = args.lr * epoch / args.warmup_epochs 16 | else: 17 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 18 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 19 | for param_group in optimizer.param_groups: 20 | if "lr_scale" in param_group: 21 | param_group["lr"] = lr * param_group["lr_scale"] 22 | else: 23 | param_group["lr"] = lr 24 | return lr 25 | -------------------------------------------------------------------------------- /configs/pretrain/dist_sim_base_1600ep.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | IP=${1} 4 | RANK=${2} 5 | NNODES=${3} 6 | DATA_PATH=${4} 7 | PORT=${PORT:-28500} 8 | PY_ARGS=${PY_ARGS:-""} 9 | 10 | BASENAME=`basename ${0} .sh` 11 | DIR=./exp/pretrain/${BASENAME} 12 | mkdir -p ${DIR} 13 | 14 | TOTAL_BATCH_SIZE=4096 15 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${NNODES}/8 16 | 17 | EPOCHS=1600 18 | 19 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=${NNODES} --node_rank=${RANK} --master_addr=${IP} --master_port=${PORT} \ 20 | main_pretrain.py \ 21 | --model sim_vit_base_patch16 \ 22 | --decoder_embed_dim 768 \ 23 | --batch_size ${BATCH_SIZE} \ 24 | --epochs ${EPOCHS} \ 25 | --warmup_epochs 40 \ 26 | --crop_min 0.08 \ 27 | --with_blockwise_mask \ 28 | --blockwise_num_masking_patches 118 \ 29 | --blr 6.25e-5 --weight_decay 0.05 \ 30 | --mm 0.995 \ 31 | --mmschedule 'cosine' \ 32 | --clip_grad 1.0 \ 33 | --loss_type 'sim' \ 34 | --neg_weight 0.02 \ 35 | --save_latest_freq 5 \ 36 | --output_dir ${DIR} \ 37 | --log_dir ${DIR} \ 38 | --data_path ${DATA_PATH} \ 39 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt 40 | -------------------------------------------------------------------------------- /docs/few_shot.md: -------------------------------------------------------------------------------- 1 | # Few-shot Evaluation 2 | 3 | We provide the few-shot evaluation scripts here. We only use 1% ImageNet labelled data to train the model. We follow [MSN](https://github.com/facebookresearch/msn/blob/main/logistic_eval.py) to train a linear classifier on the representation, without tuning model's parameters. 4 | 5 | ## Train with torch.distributed.launch 6 | Few-shot evaluation does not require high computational resources, so it is enough to run the scripts on a single node, shown as follows. 7 | 8 | ``` 9 | sh ./configs/few-shot/dist_fewshot_sim_base.sh ${MASTER_ADDR} 0 1 ${CKPT_PATH} ${DATA_PATH} 10 | ``` 11 | 12 | Note: 13 | The `${MASTER_ADDR}` is the ip address of rank 0 node. The second and third arguments specify the node rank and node number respectively. 
You need to adjust them if different node numbers are used. 14 | 15 | ## Train on a slurm cluster 16 | If you need to run the few-shot evaluation on a slurm cluster, use the command below to run on `${GPUS}/${GPUS_PER_NODE}` nodes with `${GPUS_PER_NODE}` gpus on each node: 17 | ``` 18 | sh ./configs/few-shot/slurm_fewshot_sim_base.sh ${GPUS} ${GPUS_PER_NODE} ${QUOTATYPE} ${PARTITION} ${CKPT_PATH} ${DATA_PATH} 19 | ``` 20 | -------------------------------------------------------------------------------- /docs/linear_eval.md: -------------------------------------------------------------------------------- 1 | # Linear Evaluation 2 | 3 | We provide the linear evaluation scripts here. The evaluation setting mainly follows MAE, which uses a 16384 batch size and the LARS optimizer. 4 | 5 | ## Train with torch.distributed.launch 6 | This method supports training on multi-nodes with torch.distributed.launch. For example, to conduct linear evaluation on 2 nodes, run the command below. 7 | 8 | On node 1: 9 | ``` 10 | sh ./configs/linprobe/dist_linprobe_sim_base.sh ${MASTER_ADDR} 0 2 ${CKPT_PATH} ${DATA_PATH} 11 | ``` 12 | 13 | On node 2: 14 | ``` 15 | sh ./configs/linprobe/dist_linprobe_sim_base.sh ${MASTER_ADDR} 1 2 ${CKPT_PATH} ${DATA_PATH} 16 | ``` 17 | 18 | Note: 19 | The `${MASTER_ADDR}` is the ip address of rank 0 node. The second and third arguments specify the node rank and node number respectively. You need to adjust them if different node numbers are used. 20 | 21 | ## Train on a slurm cluster 22 | If you need to run the linear evaluation on a slurm cluster, use the command below to run on `${GPUS}/${GPUS_PER_NODE}` nodes with `${GPUS_PER_NODE}` gpus on each node: 23 | ``` 24 | sh ./configs/linprobe/slurm_linprobe_sim_base.sh ${GPUS} ${GPUS_PER_NODE} ${QUOTATYPE} ${PARTITION} ${CKPT_PATH} ${DATA_PATH} 25 | ``` 26 | -------------------------------------------------------------------------------- /docs/checkpoints.md: -------------------------------------------------------------------------------- 1 | # Checkpoints 2 | We provide links for you to download the checkpoints of SiameseIM models here. 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 |
| Model | Backbone | Pretrained Epoch | Finetuned on ImageNet | Link |
| :--- | :---: | :---: | :---: | :---: |
| SiameseIM | ViT-Base | 1600 | w/o | Download |
| SiameseIM<sub>ft</sub> | ViT-Base | 1600 | w/ | Download |
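For illustration only (this sketch is not part of the original checkpoints.md): one way to inspect a downloaded checkpoint before passing it to the evaluation scripts via `--finetune` or `--pretrained`. The filename is a placeholder, and the `'model'` key is an assumption based on the MAE-style checkpoint layout this codebase is derived from.

```python
# Illustrative sketch: the filename and the 'model' key are assumptions, not guaranteed by this repo.
import torch

ckpt = torch.load('siameseim_vit_base_1600ep.pth', map_location='cpu')  # placeholder filename
state_dict = ckpt.get('model', ckpt)  # fall back to a plain state_dict if there is no 'model' key

print('top-level keys:', list(ckpt.keys()))
print('number of parameter tensors:', len(state_dict))
print('example parameter names:', list(state_dict)[:5])
```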
17 | 18 | * The SiameseIM model is only pretrained on ImageNet datasets for 1600 epochs. For pretraining details, see [pretrain.md](./pretrain.md). 19 | * The SiameseIM$`_{\mathrm{ft}}`$ model is first pretrained for 1600 epochs, and the finetuned with ImageNet classification task for 100 epochs. For finetuning details, see [finetune.md](./finetune.md). 20 | * More pre-trained weights will be released. 21 | -------------------------------------------------------------------------------- /configs/linprobe/slurm_linprobe_sim_base.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | GPUS=${1} 4 | GPUS_PER_NODE=${2} 5 | QUOTATYPE=${3} 6 | PARTITION=${4} 7 | CKPT_PATH=${5} 8 | DATA_PATH=${6} 9 | CPUS_PER_TASK=${CPUS_PER_TASK:-12} 10 | SRUN_ARGS=${SRUN_ARGS:-""} 11 | PY_ARGS=${PY_ARGS:-""} 12 | 13 | 14 | TOTAL_BATCH_SIZE=16384 15 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${GPUS} 16 | 17 | BASENAME=$(basename ${CKPT_PATH}) 18 | EXP_NAME=$(basename $(dirname ${CKPT_PATH})) 19 | DIR=./exp/linear/${EXP_NAME} 20 | JOB_NAME=lin-${EXP} 21 | 22 | mkdir -p ${DIR} 23 | 24 | srun --partition=${PARTITION} \ 25 | --mpi=pmi2 \ 26 | --open-mode=append \ 27 | --quotatype=${QUOTATYPE} \ 28 | --job-name=${JOB_NAME} \ 29 | -n$GPUS \ 30 | --gres=gpu:${GPUS_PER_NODE} \ 31 | --ntasks-per-node=${GPUS_PER_NODE} \ 32 | --cpus-per-task=$CPUS_PER_TASK \ 33 | --kill-on-bad-exit=1 \ 34 | ${SRUN_ARGS} \ 35 | python -u main_linprobe.py \ 36 | --batch_size ${BATCH_SIZE} \ 37 | --model vit_base_patch16 \ 38 | --finetune ${CKPT_PATH} \ 39 | --epochs 90 \ 40 | --blr 0.1 \ 41 | --weight_decay 0.0 \ 42 | --dist_eval \ 43 | --output_dir ${DIR} \ 44 | --log_dir ${DIR} \ 45 | --global_pool \ 46 | --data_path ${DATA_PATH} \ 47 | --use_tcs_dataset \ 48 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt -------------------------------------------------------------------------------- /configs/finetune/slurm_finetune_sim_base.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | GPUS=${1} 4 | GPUS_PER_NODE=${2} 5 | QUOTATYPE=${3} 6 | PARTITION=${4} 7 | CPUS_PER_TASK=${CPUS_PER_TASK:-12} 8 | CKPT_PATH=${5} 9 | DATA_PATH=${6} 10 | SRUN_ARGS=${SRUN_ARGS:-""} 11 | PY_ARGS=${PY_ARGS:-""} 12 | 13 | 14 | TOTAL_BATCH_SIZE=1024 15 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${GPUS} 16 | 17 | BASENAME=$(basename ${CKPT_PATH}) 18 | EXP_NAME=$(basename $(dirname ${CKPT_PATH})) 19 | DIR=./exp/finetune/${EXP_NAME} 20 | JOB_NAME=ft-${EXP} 21 | 22 | mkdir -p ${DIR} 23 | 24 | srun --partition=${PARTITION} \ 25 | --mpi=pmi2 \ 26 | --quotatype=${QUOTATYPE} \ 27 | --job-name=${JOB_NAME} \ 28 | -n$GPUS \ 29 | --gres=gpu:${GPUS_PER_NODE} \ 30 | --ntasks-per-node=${GPUS_PER_NODE} \ 31 | --cpus-per-task=$CPUS_PER_TASK \ 32 | --kill-on-bad-exit=1 \ 33 | ${SRUN_ARGS} \ 34 | python -u main_finetune.py \ 35 | --output_dir ${DIR} \ 36 | --log_dir ${DIR} \ 37 | --batch_size ${BATCH_SIZE} \ 38 | --model vit_base_patch16 \ 39 | --finetune ${CKPT_PATH} \ 40 | --epochs 100 \ 41 | --blr 2.5e-4 --layer_decay 0.65 \ 42 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 43 | --dist_eval --data_path ${DATA_PATH} \ 44 | --use_tcs_dataset \ 45 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt 46 | -------------------------------------------------------------------------------- /configs/finetune/slurm_finetune_sim_base_eval.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | GPUS=${1} 4 | GPUS_PER_NODE=${2} 5 | QUOTATYPE=${3} 
6 | PARTITION=${4} 7 | CPUS_PER_TASK=${CPUS_PER_TASK:-12} 8 | CKPT_PATH=${5} 9 | DATA_PATH=${6} 10 | SRUN_ARGS=${SRUN_ARGS:-""} 11 | PY_ARGS=${PY_ARGS:-""} 12 | 13 | 14 | TOTAL_BATCH_SIZE=1024 15 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${GPUS} 16 | 17 | BASENAME=$(basename ${CKPT_PATH}) 18 | EXP_NAME=$(basename $(dirname ${CKPT_PATH})) 19 | DIR=./exp/finetune/${EXP_NAME} 20 | JOB_NAME=ft-${EXP} 21 | 22 | mkdir -p ${DIR} 23 | 24 | srun --partition=${PARTITION} \ 25 | --mpi=pmi2 \ 26 | --quotatype=${QUOTATYPE} \ 27 | --job-name=${JOB_NAME} \ 28 | -n$GPUS \ 29 | --gres=gpu:${GPUS_PER_NODE} \ 30 | --ntasks-per-node=${GPUS_PER_NODE} \ 31 | --cpus-per-task=$CPUS_PER_TASK \ 32 | --kill-on-bad-exit=1 \ 33 | ${SRUN_ARGS} \ 34 | python -u main_finetune.py \ 35 | --output_dir ${DIR} \ 36 | --log_dir ${DIR} \ 37 | --batch_size ${BATCH_SIZE} \ 38 | --model vit_base_patch16 \ 39 | --resume ${CKPT_PATH} \ 40 | --epochs 100 \ 41 | --blr 2.5e-4 --layer_decay 0.65 \ 42 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 43 | --dist_eval --data_path ${DATA_PATH} \ 44 | --eval \ 45 | --use_tcs_dataset \ 46 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt 47 | -------------------------------------------------------------------------------- /configs/pretrain/slurm_sim_base_1600ep.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | GPUS=${1} 4 | GPUS_PER_NODE=${2} 5 | JOB_NAME=${3} 6 | QUOTATYPE=${4} 7 | PARTITION=${5} 8 | DATA_PATH=${6} 9 | CPUS_PER_TASK=${CPUS_PER_TASK:-8} 10 | SRUN_ARGS=${SRUN_ARGS:-""} 11 | PY_ARGS=${PY_ARGS:-""} 12 | 13 | BASENAME=`basename ${0} .sh` 14 | DIR=./exp/pretrain/${BASENAME} 15 | mkdir -p ${DIR} 16 | 17 | TOTAL_BATCH_SIZE=4096 18 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${GPUS} 19 | 20 | EPOCHS=1600 21 | 22 | srun --partition=${PARTITION} \ 23 | --mpi=pmi2 \ 24 | --quotatype=${QUOTATYPE} \ 25 | --job-name=${JOB_NAME} \ 26 | -n$GPUS \ 27 | --gres=gpu:${GPUS_PER_NODE} \ 28 | --ntasks-per-node=${GPUS_PER_NODE} \ 29 | --cpus-per-task=$CPUS_PER_TASK \ 30 | --kill-on-bad-exit=1 \ 31 | ${SRUN_ARGS} \ 32 | python -u main_pretrain.py \ 33 | --model sim_vit_base_patch16 \ 34 | --decoder_embed_dim 768 \ 35 | --batch_size ${BATCH_SIZE} \ 36 | --epochs ${EPOCHS} \ 37 | --warmup_epochs 40 \ 38 | --crop_min 0.08 \ 39 | --with_blockwise_mask \ 40 | --blockwise_num_masking_patches 118 \ 41 | --blr 6.25e-5 --weight_decay 0.05 \ 42 | --mm 0.995 \ 43 | --mmschedule 'cosine' \ 44 | --clip_grad 1.0 \ 45 | --loss_type 'sim' \ 46 | --neg_weight 0.02 \ 47 | --save_latest_freq 5 \ 48 | --output_dir ${DIR} \ 49 | --log_dir ${DIR} \ 50 | --data_path ${DATA_PATH} \ 51 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt 52 | -------------------------------------------------------------------------------- /docs/pretrain.md: -------------------------------------------------------------------------------- 1 | # Pretrain 2 | 3 | We provide the pretraining scripts here. To pretrain a SiameseIM model, it is recommended that 4 | * use 4096 batch size, which should fit into 32 V100 gpus with 32G memory; 5 | * pretrain for 1600 epochs for better performance. We also note that pretraining SiameseIM for 400 epochs can already match the performances of 1600 epoch MAE on some tasks; 6 | * We provide the 1600 epoch pretrained checkpoint in [checkpoints.md](./checkpoints.md). 7 | 8 | ## Train with torch.distributed.launch 9 | This method supports training on multi-nodes with torch.distributed.launch. 
For example, to pretrain a SiameseIM model on 2 nodes, run the command below. 10 | 11 | On node 1: 12 | ``` 13 | sh ./configs/pretrain/dist_sim_base_1600ep.sh ${MASTER_ADDR} 0 2 ${DATA_PATH} 14 | ``` 15 | 16 | On node 2: 17 | ``` 18 | sh ./configs/pretrain/dist_sim_base_1600ep.sh ${MASTER_ADDR} 1 2 ${DATA_PATH} 19 | ``` 20 | 21 | Note: 22 | The `${MASTER_ADDR}` is the ip address of rank 0 node. The second and third arguments specify the node rank and node number respectively. You need to adjust them if different node numbders are used. 23 | 24 | ## Train on a slurm cluster 25 | If you need to run the pretraining on a slurm cluster, use the command below to run on `${GPUS}/${GPUS_PER_NODE}` nodes with `${GPUS_PER_NODE}` gpus on each node: 26 | ``` 27 | sh ./configs/pretrain/slurm_sim_base_1600ep.sh ${GPUS} ${GPUS_PER_NODE} ${JOB_NAME} ${QUOTATYPE} ${PARTITION} ${DATA_PATH} 28 | ``` 29 | -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | 9 | import math 10 | 11 | import torch 12 | 13 | from torchvision import transforms 14 | from torchvision.transforms import functional as F 15 | 16 | 17 | class RandomResizedCrop(transforms.RandomResizedCrop): 18 | """ 19 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 20 | This may lead to results different with torchvision's version. 21 | Following BYOL's TF code: 22 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 23 | """ 24 | @staticmethod 25 | def get_params(img, scale, ratio): 26 | width, height = F.get_image_size(img) 27 | area = height * width 28 | 29 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 30 | log_ratio = torch.log(torch.tensor(ratio)) 31 | aspect_ratio = torch.exp( 32 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 33 | ).item() 34 | 35 | w = int(round(math.sqrt(target_area * aspect_ratio))) 36 | h = int(round(math.sqrt(target_area / aspect_ratio))) 37 | 38 | w = min(w, width) 39 | h = min(h, height) 40 | 41 | i = torch.randint(0, height - h + 1, size=(1,)).item() 42 | j = torch.randint(0, width - w + 1, size=(1,)).item() 43 | 44 | return i, j, h, w 45 | -------------------------------------------------------------------------------- /docs/finetune.md: -------------------------------------------------------------------------------- 1 | # Finetune 2 | 3 | We provide the finetuning scripts here. To finetune a SiameseIM model, it is recommended that 4 | * use 1024 batch size, which should fit into 8 V100 gpus with 32G memory; 5 | * We provide the finetuned checkpoint in [checkpoints.md](./checkpoints.md). 6 | 7 | ## Train with torch.distributed.launch 8 | This method supports training on multi-nodes with torch.distributed.launch. For example, to finetune a SiameseIM model on 2 nodes, run the command below. 
9 | 10 | On node 1: 11 | ``` 12 | sh ./configs/finetune/dist_finetune_sim_base.sh ${MASTER_ADDR} 0 2 ${CKPT_PATH} ${DATA_PATH} 13 | ``` 14 | 15 | On node 2: 16 | ``` 17 | sh ./configs/finetune/dist_finetune_sim_base.sh ${MASTER_ADDR} 1 2 ${CKPT_PATH} ${DATA_PATH} 18 | ``` 19 | 20 | Note: 21 | The `${MASTER_ADDR}` is the ip address of rank 0 node. The second and third arguments specify the node rank and node number respectively. You need to adjust them if different node numbders are used. 22 | 23 | ## Train on a slurm cluster 24 | If you need to run the finetuning on a slurm cluster, use the command below to run on `${GPUS}/${GPUS_PER_NODE}` nodes with `${GPUS_PER_NODE}` gpus on each node: 25 | ``` 26 | sh ./configs/finetune/slurm_finetune_sim_base.sh ${GPUS} ${GPUS_PER_NODE} ${QUOTATYPE} ${PARTITION} ${CKPT_PATH} ${DATA_PATH} 27 | ``` 28 | 29 | ## Evaluation 30 | We also provide the evaluation scripts as follows. 31 | 32 | For torch.distributed.launch, use 33 | ``` 34 | sh ./configs/finetune/dist_finetune_sim_base_eval.sh ${MASTER_ADDR} 0 1 ${CKPT_PATH} ${DATA_PATH} 35 | ``` 36 | 37 | For slurm launch, use 38 | ``` 39 | sh ./configs/finetune/slurm_finetune_sim_base_eval.sh ${GPUS} ${GPUS_PER_NODE} ${QUOTATYPE} ${PARTITION} ${CKPT_PATH} ${DATA_PATH} 40 | ``` 41 | You should get 42 | ``` 43 | * Acc@1 84.118 Acc@5 96.766 loss 0.728 44 | ``` 45 | for the provided checkpoint. 46 | -------------------------------------------------------------------------------- /docs/prepare.md: -------------------------------------------------------------------------------- 1 | # Preparation 2 | 3 | * The only dataset required in this repo is ImageNet, which is enough for pretraining, finetuning, linear evaluation and few-shot evaluation. If you want to evaluate on COCO, LVIS, ADE20k and robustness datasets, please follow the corresponding repos to prepare the data. 4 | 5 | ## Installation 6 | 7 | * Python >=3.7 8 | * We recommend to use Pytorch1.11 for a faster training speed. 9 | * timm == 0.6.12 10 | * numpy == 1.21.5 11 | * tensorboard 12 | 13 | To run few-shot evaluation, [cyanure](https://github.com/inria-thoth/cyanure) package is further required. You can install it with 14 | ``` 15 | pip install cyanure-openblas 16 | # or pip install cyanure-mkl 17 | ``` 18 | 19 | ## Data preparation 20 | 21 | Download and extract ImageNet train and val images from http://image-net.org/. 22 | The directory structure is the standard layout for the torchvision [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder), and the training and validation data is expected to be in the `train/` folder and `val` folder respectively: 23 | 24 | ``` 25 | /path/to/imagenet/ 26 | ├── train/ 27 | │   ├── class1/ 28 | │ │   ├── img1.JPEG 29 | | │   ├── img2.JPEG 30 | | │   ├── img3.JPEG 31 | | │   └── ... 32 | │   ├── class2/ 33 | | │   └── ... 34 | │   ├── class3/ 35 | | │   └── ... 36 | | └── ... 37 | └─── val 38 | │   ├── class1/ 39 | │ │   ├── img4.JPEG 40 | | │   ├── img5.JPEG 41 | | │   ├── img6.JPEG 42 | | │   └── ... 43 | │   ├── class2/ 44 | | │   └── ... 45 | │   ├── class3/ 46 | | │   └── ... 47 | ``` 48 | 49 | Note that raw val images are not put into class folders, use [this script](https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh) to get correct layout. 
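As a quick reference (not part of the original prepare.md), a small sanity-check sketch that verifies the directory layout above is readable by torchvision's `ImageFolder`, which is what `util/datasets.py` builds on; `/path/to/imagenet` is a placeholder for your own data root.

```python
# Illustrative sanity check; '/path/to/imagenet' is a placeholder path.
from torchvision import datasets, transforms

for split in ('train', 'val'):
    ds = datasets.ImageFolder(
        f'/path/to/imagenet/{split}',
        transform=transforms.ToTensor(),
    )
    # The standard ImageNet-1k layout exposes 1000 class folders per split.
    print(split, 'classes:', len(ds.classes), 'images:', len(ds))
```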
50 | -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | # LARS optimizer, implementation from MoCo v3: 9 | # https://github.com/facebookresearch/moco-v3 10 | # ------------------------------------------------------------------------ 11 | 12 | 13 | import torch 14 | 15 | 16 | class LARS(torch.optim.Optimizer): 17 | """ 18 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 19 | """ 20 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 21 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 22 | super().__init__(params, defaults) 23 | 24 | @torch.no_grad() 25 | def step(self): 26 | for g in self.param_groups: 27 | for p in g['params']: 28 | dp = p.grad 29 | 30 | if dp is None: 31 | continue 32 | 33 | if p.ndim > 1: # if not normalization gamma/beta or bias 34 | dp = dp.add(p, alpha=g['weight_decay']) 35 | param_norm = torch.norm(p) 36 | update_norm = torch.norm(dp) 37 | one = torch.ones_like(param_norm) 38 | q = torch.where(param_norm > 0., 39 | torch.where(update_norm > 0, 40 | (g['trust_coefficient'] * param_norm / update_norm), one), 41 | one) 42 | dp = dp.mul(q) 43 | 44 | param_state = self.state[p] 45 | if 'mu' not in param_state: 46 | param_state['mu'] = torch.zeros_like(p) 47 | mu = param_state['mu'] 48 | mu.mul_(g['momentum']).add_(dp) 49 | p.add_(mu, alpha=-g['lr']) 50 | -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
7 | # ------------------------------------------------------------------------ 8 | # References: 9 | # ELECTRA https://github.com/google-research/electra 10 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 11 | # ------------------------------------------------------------------------ 12 | 13 | import json 14 | 15 | 16 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75, 17 | small_lr_keywords=('offset',), small_lr_ratio=0.1): 18 | """ 19 | Parameter groups for layer-wise lr decay 20 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 21 | """ 22 | param_group_names = {} 23 | param_groups = {} 24 | 25 | if hasattr(model, 'blocks'): 26 | num_layers = len(model.blocks) + 1 27 | else: 28 | raise NotImplementedError 29 | 30 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 31 | 32 | for n, p in model.named_parameters(): 33 | if not p.requires_grad: 34 | continue 35 | 36 | small_lr = False 37 | for small_lr_keyword in small_lr_keywords: 38 | if small_lr_keyword in n: 39 | g_decay = 'decay_small_lr' 40 | this_decay = weight_decay 41 | layer_id = get_layer_id_for_vit(n, num_layers) 42 | group_name = "layer_%d_%s" % (layer_id, g_decay) 43 | small_lr = True 44 | this_scale = layer_scales[layer_id] * 0.1 45 | 46 | if not small_lr: 47 | # no decay: all 1D parameters and model specific ones 48 | if p.ndim == 1 or n in no_weight_decay_list: 49 | g_decay = "no_decay" 50 | this_decay = 0. 51 | else: 52 | g_decay = "decay" 53 | this_decay = weight_decay 54 | 55 | layer_id = get_layer_id_for_vit(n, num_layers) 56 | group_name = "layer_%d_%s" % (layer_id, g_decay) 57 | this_scale = layer_scales[layer_id] 58 | 59 | if group_name not in param_group_names: 60 | param_group_names[group_name] = { 61 | "lr_scale": this_scale, 62 | "weight_decay": this_decay, 63 | "params": [], 64 | } 65 | param_groups[group_name] = { 66 | "lr_scale": this_scale, 67 | "weight_decay": this_decay, 68 | "params": [], 69 | } 70 | 71 | param_group_names[group_name]["params"].append(n) 72 | param_groups[group_name]["params"].append(p) 73 | 74 | print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 75 | 76 | return list(param_groups.values()) 77 | 78 | 79 | def get_layer_id_for_vit(name, num_layers): 80 | """ 81 | Assign a parameter with its layer id 82 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 83 | """ 84 | if name in ['cls_token', 'pos_embed']: 85 | return 0 86 | elif name.startswith('patch_embed'): 87 | return 0 88 | elif name.startswith('blocks'): 89 | return int(name.split('.')[1]) + 1 90 | else: 91 | return num_layers 92 | -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
7 | # ------------------------------------------------------------------------ 8 | # References: 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # ------------------------------------------------------------------------ 11 | 12 | import os 13 | import PIL 14 | 15 | from torchvision import datasets, transforms 16 | 17 | from timm.data import create_transform 18 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | 20 | 21 | def build_transform(is_train, args): 22 | mean = IMAGENET_DEFAULT_MEAN 23 | std = IMAGENET_DEFAULT_STD 24 | # train transform 25 | if is_train: 26 | # this should always dispatch to transforms_imagenet_train 27 | transform = create_transform( 28 | input_size=args.input_size, 29 | is_training=True, 30 | color_jitter=args.color_jitter, 31 | auto_augment=args.aa, 32 | interpolation='bicubic', 33 | re_prob=args.reprob, 34 | re_mode=args.remode, 35 | re_count=args.recount, 36 | mean=mean, 37 | std=std, 38 | ) 39 | return transform 40 | 41 | # eval transform 42 | t = [] 43 | if args.input_size <= 224: 44 | crop_pct = 224 / 256 45 | else: 46 | crop_pct = 1.0 47 | size = int(args.input_size / crop_pct) 48 | t.append( 49 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 50 | ) 51 | t.append(transforms.CenterCrop(args.input_size)) 52 | 53 | t.append(transforms.ToTensor()) 54 | t.append(transforms.Normalize(mean, std)) 55 | return transforms.Compose(t) 56 | 57 | 58 | class ImagenetWithMask(datasets.ImageFolder): 59 | def __init__(self, root, 60 | transform = None, 61 | with_blockwise_mask=False, ### !!! set to True, enable blockwise masking 62 | blockwise_num_masking_patches=75, ### !!! 75 / 196 = 0.38 -> Modify this to increase mask ratio 63 | input_size=224, patch_size=16, # no need to change now 64 | max_mask_patches_per_block=None, # BEiT default setting, no need to change 65 | min_mask_patches_per_block=16, # BEiT default setting, no need to change 66 | fixed_num_masking_patches=True, ### set to true, fixed number of masking patch to blockwise_num_masking_patches for sim training 67 | ): 68 | super().__init__(root, transform) 69 | self.with_blockwise_mask = with_blockwise_mask 70 | if with_blockwise_mask: 71 | from .masking_generator import MaskingGenerator 72 | window_size = input_size // patch_size 73 | self.masked_position_generator = MaskingGenerator( 74 | (window_size, window_size), 75 | num_masking_patches=blockwise_num_masking_patches, 76 | max_num_patches=max_mask_patches_per_block, 77 | min_num_patches=min_mask_patches_per_block, 78 | fixed_num_masking_patches=fixed_num_masking_patches 79 | ) 80 | 81 | def __getitem__(self, index): 82 | sample, target = super().__getitem__(index) 83 | if self.with_blockwise_mask: 84 | return sample, target, self.masked_position_generator() 85 | return sample, target 86 | -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
7 | # ------------------------------------------------------------------------ 8 | # References: 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # DeiT: https://github.com/facebookresearch/deit 11 | # ------------------------------------------------------------------------ 12 | 13 | import math 14 | import sys 15 | from typing import Iterable, Optional 16 | 17 | import torch 18 | 19 | from timm.data import Mixup 20 | from timm.utils import accuracy 21 | 22 | import util.misc as misc 23 | import util.lr_sched as lr_sched 24 | 25 | 26 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 27 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 28 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 29 | mixup_fn: Optional[Mixup] = None, log_writer=None, 30 | args=None): 31 | model.train(True) 32 | metric_logger = misc.MetricLogger(delimiter=" ") 33 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 34 | header = 'Epoch: [{}]'.format(epoch) 35 | print_freq = 20 36 | 37 | accum_iter = args.accum_iter 38 | 39 | optimizer.zero_grad() 40 | 41 | if log_writer is not None: 42 | print('log_dir: {}'.format(log_writer.log_dir)) 43 | 44 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 45 | 46 | # we use a per iteration (instead of per epoch) lr scheduler 47 | if data_iter_step % accum_iter == 0: 48 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 49 | 50 | samples = samples.to(device, non_blocking=True) 51 | targets = targets.to(device, non_blocking=True) 52 | 53 | if mixup_fn is not None: 54 | samples, targets = mixup_fn(samples, targets) 55 | 56 | with torch.cuda.amp.autocast(): 57 | outputs = model(samples) 58 | loss = criterion(outputs, targets) 59 | 60 | loss_value = loss.item() 61 | 62 | if not math.isfinite(loss_value): 63 | print("Loss is {}, stopping training".format(loss_value)) 64 | sys.exit(1) 65 | 66 | loss /= accum_iter 67 | loss_scaler(loss, optimizer, clip_grad=max_norm, 68 | parameters=model.parameters(), create_graph=False, 69 | update_grad=(data_iter_step + 1) % accum_iter == 0) 70 | if (data_iter_step + 1) % accum_iter == 0: 71 | optimizer.zero_grad() 72 | 73 | torch.cuda.synchronize() 74 | 75 | metric_logger.update(loss=loss_value) 76 | min_lr = 10. 77 | max_lr = 0. 78 | for group in optimizer.param_groups: 79 | min_lr = min(min_lr, group["lr"]) 80 | max_lr = max(max_lr, group["lr"]) 81 | 82 | metric_logger.update(lr=max_lr) 83 | 84 | loss_value_reduce = misc.all_reduce_mean(loss_value) 85 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 86 | """ We use epoch_1000x as the x-axis in tensorboard. 87 | This calibrates different curves when batch size changes. 
88 | """ 89 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 90 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 91 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 92 | 93 | # gather the stats from all processes 94 | metric_logger.synchronize_between_processes() 95 | print("Averaged stats:", metric_logger) 96 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 97 | 98 | 99 | @torch.no_grad() 100 | def evaluate(data_loader, model, device): 101 | criterion = torch.nn.CrossEntropyLoss() 102 | 103 | metric_logger = misc.MetricLogger(delimiter=" ") 104 | header = 'Test:' 105 | 106 | # switch to evaluation mode 107 | model.eval() 108 | 109 | for batch in metric_logger.log_every(data_loader, 10, header): 110 | images = batch[0] 111 | target = batch[-1] 112 | images = images.to(device, non_blocking=True) 113 | target = target.to(device, non_blocking=True) 114 | 115 | # compute output 116 | with torch.cuda.amp.autocast(): 117 | output = model(images) 118 | loss = criterion(output, target) 119 | 120 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 121 | 122 | batch_size = images.shape[0] 123 | metric_logger.update(loss=loss.item()) 124 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 125 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 126 | # gather the stats from all processes 127 | metric_logger.synchronize_between_processes() 128 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 129 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 130 | 131 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 132 | -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
7 | # ------------------------------------------------------------------------ 8 | # References: 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # DeiT: https://github.com/facebookresearch/deit 11 | # ------------------------------------------------------------------------ 12 | 13 | 14 | import math 15 | import os 16 | import sys 17 | from turtle import update 18 | from typing import Iterable 19 | from pathlib import Path 20 | 21 | import torch 22 | 23 | import util.misc as misc 24 | import util.lr_sched as lr_sched 25 | 26 | 27 | def train_one_epoch(model: torch.nn.Module, 28 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 29 | device: torch.device, epoch: int, loss_scaler, 30 | log_writer=None, 31 | args=None): 32 | model.train(True) 33 | metric_logger = misc.MetricLogger(delimiter=" ") 34 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 35 | header = 'Epoch: [{}]'.format(epoch) 36 | print_freq = 50 37 | 38 | accum_iter = args.accum_iter 39 | 40 | optimizer.zero_grad() 41 | 42 | if log_writer is not None: 43 | print('log_dir: {}'.format(log_writer.log_dir)) 44 | 45 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 46 | if args.with_blockwise_mask: 47 | samples, labels, mask = data 48 | else: 49 | samples, labels = data 50 | mask = None 51 | 52 | # we use a per iteration (instead of per epoch) lr scheduler 53 | if data_iter_step % accum_iter == 0: 54 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 55 | 56 | if args.mmschedule == 'const': 57 | mm = args.mm 58 | elif args.mmschedule == 'cosine': 59 | mm = 1. - 0.5 * (1. + math.cos(math.pi * (data_iter_step / len(data_loader) + epoch) / args.epochs)) * (1. 
- args.mm) 60 | metric_logger.update(mm=mm) 61 | update_mm = (data_iter_step % accum_iter == 0) 62 | 63 | if args.loss_type in ['sim',]: 64 | x1, x2, delta_i, delta_j, delta_h, delta_w, relative_flip, flip_delta_j = samples 65 | x1 = x1.to(device, non_blocking=True) 66 | x2 = x2.to(device, non_blocking=True) 67 | delta_i = delta_i.to(x1) 68 | delta_j = delta_j.to(x1) 69 | delta_h = delta_h.to(x1) 70 | delta_w = delta_w.to(x1) 71 | flip_delta_j = flip_delta_j.to(x1) 72 | 73 | rel_pos_21 = (delta_i, delta_j, delta_h, delta_w, relative_flip, flip_delta_j) 74 | 75 | with torch.cuda.amp.autocast(enabled=(not args.fp32)): 76 | loss, outputs = model(x1, x2, rel_pos_21, mm, update_mm, mask=mask) 77 | metric_logger.update(**outputs) 78 | else: 79 | samples = samples.to(device, non_blocking=True) 80 | 81 | with torch.cuda.amp.autocast(enabled=(not args.fp32)): 82 | loss, _, _ = model(samples, mask_ratio=args.mask_ratio) 83 | 84 | loss_value = loss.item() 85 | 86 | if not math.isfinite(loss_value): 87 | print("Loss is {}, stopping training".format(loss_value)) 88 | sys.exit(1) 89 | 90 | loss /= accum_iter 91 | grad_norm = loss_scaler(loss, optimizer, parameters=model.parameters(), 92 | update_grad=(data_iter_step + 1) % accum_iter == 0, clip_grad=args.clip_grad) 93 | if args.fp32: 94 | loss_scale = None 95 | else: 96 | loss_scale = loss_scaler.state_dict()['scale'] 97 | 98 | metric_logger.update(grad_norm=grad_norm) 99 | metric_logger.update(loss_scale=loss_scale) 100 | 101 | if (data_iter_step + 1) % accum_iter == 0: 102 | optimizer.zero_grad() 103 | 104 | torch.cuda.synchronize() 105 | 106 | metric_logger.update(loss=loss_value) 107 | 108 | lr = optimizer.param_groups[0]["lr"] 109 | metric_logger.update(lr=lr) 110 | 111 | loss_value_reduce = misc.all_reduce_mean(loss_value) 112 | outputs_reduced = {k_: misc.all_reduce_mean(v_) for k_, v_ in outputs.items()} 113 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 114 | """ We use epoch_1000x as the x-axis in tensorboard. 115 | This calibrates different curves when batch size changes. 116 | """ 117 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 118 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 119 | log_writer.add_scalar('lr', lr, epoch_1000x) 120 | log_writer.add_scalar('grad_norm', grad_norm, epoch_1000x) 121 | if loss_scale is not None: 122 | log_writer.add_scalar('loss_scale', loss_scale, epoch_1000x) 123 | log_writer.add_scalar('mm', mm, epoch_1000x) 124 | for k_, v_ in outputs_reduced.items(): 125 | log_writer.add_scalar(f'train/{k_}', v_, epoch_1000x) 126 | 127 | # gather the stats from all processes 128 | metric_logger.synchronize_between_processes() 129 | print("Averaged stats:", metric_logger) 130 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 131 | -------------------------------------------------------------------------------- /util/augmentation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | from PIL import ImageFilter, ImageOps 5 | import torch 6 | import torchvision.transforms as transforms 7 | import torchvision.transforms.functional as F 8 | 9 | 10 | class RandomResizedCrop(transforms.RandomResizedCrop): 11 | def __init__(self, cfg, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.args = cfg 14 | 15 | @staticmethod 16 | def get_params(img, scale, ratio): 17 | """Get parameters for ``crop`` for a random sized crop. 
18 | 19 | Args: 20 | img (PIL Image or Tensor): Input image. 21 | scale (list): range of scale of the origin size cropped 22 | ratio (list): range of aspect ratio of the origin aspect ratio cropped 23 | 24 | Returns: 25 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 26 | sized crop. 27 | """ 28 | width, height = F.get_image_size(img) 29 | area = height * width 30 | 31 | log_ratio = torch.log(torch.tensor(ratio)) 32 | for _ in range(10): 33 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 34 | aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() 35 | 36 | w = int(round(math.sqrt(target_area * aspect_ratio))) 37 | h = int(round(math.sqrt(target_area / aspect_ratio))) 38 | 39 | if 0 < w <= width and 0 < h <= height: 40 | i1 = torch.randint(0, height - h + 1, size=(1,)).item() 41 | i2 = torch.randint(0, height - h + 1, size=(1,)).item() 42 | j1 = torch.randint(0, width - w + 1, size=(1,)).item() 43 | j2 = torch.randint(0, width - w + 1, size=(1,)).item() 44 | 45 | return i1, j1, i2, j2, h, w 46 | 47 | # Fallback to central crop 48 | in_ratio = float(width) / float(height) 49 | if in_ratio < min(ratio): 50 | w = width 51 | h = int(round(w / min(ratio))) 52 | elif in_ratio > max(ratio): 53 | h = height 54 | w = int(round(h * max(ratio))) 55 | else: # whole image 56 | w = width 57 | h = height 58 | i = (height - h) // 2 59 | j = (width - w) // 2 60 | return i, j, i, j, h, w 61 | 62 | def forward(self, img): 63 | """ 64 | Args: 65 | img (PIL Image or Tensor): Image to be cropped and resized. 66 | 67 | Returns: 68 | PIL Image or Tensor: Randomly cropped and resized image. 69 | """ 70 | i1, j1, i2, j2, h, w = self.get_params(img, self.scale, self.ratio) 71 | return F.resized_crop(img, i1, j1, h, w, self.size, self.interpolation), \ 72 | F.resized_crop(img, i2, j2, h, w, self.size, self.interpolation), (i2-i1)/h, (j2-j1)/w, h/h, w/w 73 | 74 | 75 | class GaussianBlur(object): 76 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 77 | 78 | def __init__(self, sigma=[.1, 2.]): 79 | self.sigma = sigma 80 | 81 | def __call__(self, x): 82 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 83 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 84 | return x 85 | 86 | 87 | class Solarize(object): 88 | """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733""" 89 | 90 | def __call__(self, x): 91 | return ImageOps.solarize(x) 92 | 93 | 94 | class SingleRandomResizedCrop(transforms.RandomResizedCrop): 95 | @staticmethod 96 | def get_params(img, scale, ratio): 97 | """Get parameters for ``crop`` for a random sized crop. 98 | 99 | Args: 100 | img (PIL Image or Tensor): Input image. 101 | scale (list): range of scale of the origin size cropped 102 | ratio (list): range of aspect ratio of the origin aspect ratio cropped 103 | 104 | Returns: 105 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 106 | sized crop. 
107 | """ 108 | width, height = F.get_image_size(img) 109 | area = height * width 110 | 111 | log_ratio = torch.log(torch.tensor(ratio)) 112 | for _ in range(10): 113 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 114 | aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() 115 | 116 | w = int(round(math.sqrt(target_area * aspect_ratio))) 117 | h = int(round(math.sqrt(target_area / aspect_ratio))) 118 | 119 | if 0 < w <= width and 0 < h <= height: 120 | i = torch.randint(0, height - h + 1, size=(1,)).item() 121 | j = torch.randint(0, width - w + 1, size=(1,)).item() 122 | return i, j, h, w, width 123 | 124 | # Fallback to central crop 125 | in_ratio = float(width) / float(height) 126 | if in_ratio < min(ratio): 127 | w = width 128 | h = int(round(w / min(ratio))) 129 | elif in_ratio > max(ratio): 130 | h = height 131 | w = int(round(h * max(ratio))) 132 | else: # whole image 133 | w = width 134 | h = height 135 | i = (height - h) // 2 136 | j = (width - w) // 2 137 | return i, j, h, w, width 138 | 139 | def forward(self, img): 140 | """ 141 | Args: 142 | img (PIL Image or Tensor): Image to be cropped and resized. 143 | 144 | Returns: 145 | PIL Image or Tensor: Randomly cropped and resized image. 146 | """ 147 | i, j, h, w, width = self.get_params(img, self.scale, self.ratio) 148 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), i, j, h, w, width 149 | 150 | 151 | class RandomHorizontalFlip(transforms.RandomHorizontalFlip): 152 | def forward(self, img): 153 | if torch.rand(1) < self.p: 154 | return F.hflip(img), True 155 | return img, False 156 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Siamese-Image-Modeling 2 | 3 | By [Chenxin Tao](https://scholar.google.com/citations?user=sXHFIBkAAAAJ&hl=zh-CN), 4 | [Xizhou Zhu](https://scholar.google.com/citations?user=02RXI00AAAAJ), 5 | [Weijie Su](https://www.weijiesu.com/), 6 | [Gao Huang](http://www.gaohuang.net/), 7 | [Bin Li](http://staff.ustc.edu.cn/~binli/), 8 | [Jie Zhou](https://scholar.google.com/citations?user=6a79aPwAAAAJ&hl=en), 9 | [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ&hl=en), 10 | [Xiaogang Wang](http://www.ee.cuhk.edu.hk/~xgwang/), 11 | [Jifeng Dai](https://jifengdai.org/) 12 | 13 | This is the official implementation of the CVPR 2023 paper [Siamese Image Modeling for Self-Supervised Vision Representation Learning](https://arxiv.org/pdf/2206.01204.pdf). 14 | 15 | ![SiameseIM-overview](./figs/overview.png) 16 | 17 | ## 🏠 Introduction 18 | 19 | SiameseIM is a new form of self-supervised learning that can learn semantic alignment and spatial sensitivity with a single dense loss. We note the following key observations from SiameseIM: 20 | 21 | - Compared with MIM methods, SiameseIM shows that reconstructing another view helps to obtain good semantic alignment. 22 | 23 | - Compared with ID methods, SiameseIM shows that dense supervision can be applied by matching the dense correspondence between two views strictly through their relative positions. 24 | 25 | - SiameseIM is able to surpass both MIM and ID methods over a wide range of tasks. SiameseIM obtains more improvements in few-shot, long-tail and robustness-concerned scenarios. 
26 | 27 | 28 | ![SiameseIM-comparison](./figs/comparison.png) 29 | 30 | 31 | ## 📈 Main Results 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 |
| Method | ImageNet FT | ImageNet LIN | ImageNet 1% FT | COCO AP box | COCO AP mask | ADE20k mIoU | LVIS AP box | LVIS AP box rare | LVIS AP mask | LVIS AP mask rare | Robustness IN-A top-1 | Robustness IN-R top-1 | Robustness IN-Sketch top-1 | Robustness IN-C 1-mCE |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| MoCo-v3 (ID method) | 83.0 | 76.7 | 63.4 | 47.9 | 42.7 | 47.3 | 37.3 | 25.5 | 35.3 | 25.8 | 32.4 | 49.8 | 35.9 | 55.4 |
| MAE (MIM method) | 83.6 | 68.0 | 51.1 | 51.6 | 45.9 | 48.1 | 40.1 | 29.3 | 38.1 | 29.1 | 35.9 | 48.3 | 34.5 | 48.3 |
| SiameseIM | 84.1 | 78.0 | 65.1 | 52.1 | 46.2 | 51.1 | 40.5 | 30.9 | 38.1 | 30.1 | 43.8 | 52.5 | 38.3 | 57.1 |
| Improve w.r.t. MoCo-v3 | +1.1 | +1.3 | +1.7 | +4.2 | +3.5 | +3.8 | +3.2 | +5.4 | +2.8 | +4.3 | +11.4 | +2.7 | +2.4 | +1.7 |
| Improve w.r.t. MAE | +0.5 | +10.0 | +14.0 | +0.5 | +0.3 | +3.0 | +0.4 | +1.6 | +0.0 | +1.0 | +7.9 | +4.2 | +3.8 | +8.8 |
61 | 62 | 63 | Note: 64 | 65 | (1) Compared with MoCo-v3, SiameseIM improves dense prediction tasks (COCO detection, ADE20k segmentation, LVIS detection) significantly; 66 | 67 | (2) Compared with MAE, SiameseIM improves long-tail, few-shot, robustness tasks (ImageNet linear evaluation & few-shot classification, ADE20k segmentation, LVIS detection) significantly; 68 | 69 | (3) Notably, ADE20k segmentation and LVIS detection both contain long-tail classes, which put forward high requirement for semantic alignment, and detection tasks, which demand good spatial alignment. Thus, SiameseIM can surpass both MoCo-v3 and MAE by a large margin on these tasks. 70 | 71 | 72 | ## 🛠️ Usage 73 | ### Preparation 74 | 75 | See [prepare.md](docs/prepare.md) 76 | 77 | ### Model Checkpoint 78 | 79 | See [checkpoints.md](docs/checkpoints.md) 80 | 81 | ### Pretrain 82 | 83 | See [pretrain.md](docs/pretrain.md) 84 | 85 | ### Finetune 86 | 87 | See [finetune.md](docs/finetune.md) 88 | 89 | ### Linear Evaluation 90 | 91 | See [linear_eval.md](docs/linear_eval.md) 92 | 93 | ### Few-shot Evaluation 94 | 95 | See [few_shot.md](docs/few_shot.md) 96 | 97 | ### COCO & LVIS Detection 98 | 99 | We use ViTDet for detection tasks, please refer to [detectron2](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet). 100 | 101 | ### ADE20k Segmentation 102 | 103 | We follow MAE to use UPerNet for segmentation task, please refer to [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/mae). 104 | 105 | ### Robustness Evaluation 106 | 107 | We evaluate the ImageNet finetuned model on [ImageNet-A](https://github.com/hendrycks/natural-adv-examples), [ImageNet-R](https://github.com/hendrycks/imagenet-r), [ImageNet-Sketch](https://github.com/HaohanWang/ImageNet-Sketch) and [ImageNet-C](https://github.com/hendrycks/robustness) datasets. 108 | 109 | 110 | ## 📃 License 111 | 112 | This project is released under the [CC-BY-NC 4.0 license](./LICENSE). 113 | 114 | ## 🖊️ Citing SiameseIM 115 | If you find SiameseIM useful in your research, please consider citing: 116 | ```bibtex 117 | @inproceedings{tao2023siamese, 118 | title={Siamese image modeling for self-supervised vision representation learning}, 119 | author={Tao, Chenxin and Zhu, Xizhou and Su, Weijie and Huang, Gao and Li, Bin and Zhou, Jie and Qiao, Yu and Wang, Xiaogang and Dai, Jifeng}, 120 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 121 | pages={2132--2141}, 122 | year={2023} 123 | ``` 124 | -------------------------------------------------------------------------------- /util/masking_generator.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 
4 | # ------------------------------------------------------------------------ 5 | """ 6 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 7 | Copyright Zhun Zhong & Liang Zheng 8 | 9 | Hacked together by / Copyright 2020 Ross Wightman 10 | 11 | Modified by Hangbo Bao, for generating the masked position for visual image transformer 12 | """ 13 | # -------------------------------------------------------- 14 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 15 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 16 | # Copyright (c) 2021 Microsoft 17 | # Licensed under The MIT License [see LICENSE for details] 18 | # By Hangbo Bao 19 | # Based on timm, DINO and DeiT code bases 20 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 21 | # Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 22 | # Copyright Zhun Zhong & Liang Zheng 23 | # 24 | # Hacked together by / Copyright 2020 Ross Wightman 25 | # 26 | # Modified by Hangbo Bao, for generating the masked position for visual image transformer 27 | # --------------------------------------------------------' 28 | import random 29 | import math 30 | import numpy as np 31 | 32 | 33 | class MaskingGenerator: 34 | def __init__( 35 | self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None, 36 | min_aspect=0.3, max_aspect=None, fixed_num_masking_patches=False): 37 | if not isinstance(input_size, tuple): 38 | input_size = (input_size, ) * 2 39 | self.height, self.width = input_size 40 | 41 | self.num_patches = self.height * self.width 42 | self.num_masking_patches = num_masking_patches 43 | 44 | self.min_num_patches = min_num_patches 45 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches 46 | 47 | max_aspect = max_aspect or 1 / min_aspect 48 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 49 | self.fixed_num_masking_patches = fixed_num_masking_patches 50 | 51 | def __repr__(self): 52 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 53 | self.height, self.width, self.min_num_patches, self.max_num_patches, 54 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1]) 55 | return repr_str 56 | 57 | def get_shape(self): 58 | return self.height, self.width 59 | 60 | def _mask(self, mask, max_mask_patches): 61 | delta = 0 62 | for attempt in range(10): 63 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 64 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 65 | h = int(round(math.sqrt(target_area * aspect_ratio))) 66 | w = int(round(math.sqrt(target_area / aspect_ratio))) 67 | if w < self.width and h < self.height: 68 | top = random.randint(0, self.height - h) 69 | left = random.randint(0, self.width - w) 70 | 71 | num_masked = mask[top: top + h, left: left + w].sum() 72 | # Overlap 73 | if 0 < h * w - num_masked <= max_mask_patches: 74 | for i in range(top, top + h): 75 | for j in range(left, left + w): 76 | if mask[i, j] == 0: 77 | mask[i, j] = 1 78 | delta += 1 79 | 80 | if delta > 0: 81 | break 82 | return delta 83 | 84 | def __call__(self): 85 | mask = np.zeros(shape=self.get_shape(), dtype=np.int) 86 | mask_count = 0 87 | while mask_count < self.num_masking_patches: 88 | max_mask_patches = self.num_masking_patches - mask_count 89 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 90 | 91 | delta = self._mask(mask, 
max_mask_patches) 92 | if delta == 0: 93 | break 94 | else: 95 | mask_count += delta 96 | 97 | if self.fixed_num_masking_patches and (mask_count < self.num_masking_patches): 98 | non_masked_inds_i, non_masked_inds_j = (mask == 0).nonzero() 99 | shuffle_inds = list(range(non_masked_inds_i.shape[0])) 100 | random.shuffle(shuffle_inds) 101 | num_to_mask = self.num_masking_patches - mask_count 102 | to_mask_inds_i = non_masked_inds_i[shuffle_inds[:num_to_mask]] 103 | to_mask_inds_j = non_masked_inds_j[shuffle_inds[:num_to_mask]] 104 | mask[to_mask_inds_i, to_mask_inds_j] = 1 105 | mask_count += num_to_mask 106 | 107 | return mask 108 | 109 | 110 | if __name__ == '__main__': 111 | blockwise_num_masking_patches=75 ### TODO: 75 / 196 = 0.38 -> Modify this to increase mask ratio 112 | input_size=224 113 | patch_size=16 # BEiT default setting, no need to change 114 | max_mask_patches_per_block=None # BEiT default setting, no need to change 115 | min_mask_patches_per_block=16 # BEiT default setting, no need to change 116 | fixed_num_masking_patches=True ### TODO: fixed number of masking patch to blockwise_num_masking_patches for sim training 117 | window_size = input_size // patch_size 118 | masked_position_generator = MaskingGenerator( 119 | (window_size, window_size), 120 | num_masking_patches=blockwise_num_masking_patches, 121 | max_num_patches=max_mask_patches_per_block, 122 | min_num_patches=min_mask_patches_per_block, 123 | fixed_num_masking_patches=fixed_num_masking_patches 124 | ) 125 | mask_num = [] 126 | for _ in range(10000): 127 | mask = masked_position_generator() 128 | if _ < 10: 129 | print(mask) 130 | mask_num.append(mask.sum()) 131 | print(f"Max Patches: {max(mask_num)} Min Patches: {min(mask_num)} Mean Patches: {sum(mask_num) / len(mask_num)}") 132 | print(f"Max Ratio: {max(mask_num)/196.0} Min Ratio: {min(mask_num)/196.0} Mean Ratio: {sum(mask_num) / len(mask_num) / 196.0}") 133 | -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
7 | # ------------------------------------------------------------------------ 8 | 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | def get_2d_sincos_pos_embed_relative(delta_i, delta_j, delta_h, delta_w, relative_flip, flip_delta_j, embed_dim, grid_size, cls_token=False): 71 | """ 72 | grid_size: int of the grid height and width 73 | return: 74 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 75 | """ 76 | delta_i = delta_i * grid_size 77 | delta_j = delta_j * grid_size 78 | flip_delta_j = flip_delta_j * grid_size 79 | grid_h = torch.arange(grid_size, dtype=torch.float32) 80 | grid_w = torch.arange(grid_size, dtype=torch.float32) 81 | raw_grid_h, raw_grid_w = torch.meshgrid(grid_h, grid_w) 82 | 83 | raw_grid_h = raw_grid_h + 0.5 84 | raw_grid_w = raw_grid_w + 0.5 85 | grid_h = torch.einsum('b,n->bn', delta_h, raw_grid_h.flatten().to(delta_h)) + delta_i.unsqueeze(-1) 86 | grid_w = torch.einsum('b,n->bn', delta_w, raw_grid_w.flatten().to(delta_w)) + delta_j.unsqueeze(-1) 87 | 88 | flip_grid_w = -torch.einsum('b,n->bn', [delta_w, raw_grid_w.flatten().to(delta_h)]) + flip_delta_j[:, None] 89 | relative_flip = relative_flip.float().unsqueeze(-1) 90 | grid_w = relative_flip * flip_grid_w + (1-relative_flip) * grid_w 91 | grid_w = grid_w - 0.5 92 | grid_h = grid_h - 0.5 93 | 94 | omega = torch.arange(embed_dim//4, dtype=torch.float32) / (embed_dim/4) 95 | omega = 1. / (10000**omega) 96 | out_h = torch.einsum('bn,c->bnc', [grid_h, omega.to(grid_h)]) 97 | out_w = torch.einsum('bn,c->bnc', [grid_w, omega.to(grid_w)]) 98 | out_scale_h = torch.einsum('b,c->bc', [10*torch.log(delta_h), omega.to(grid_h)]).unsqueeze(1).expand(-1, out_h.shape[1], -1) 99 | out_scale_w = torch.einsum('b,c->bc', [10*torch.log(delta_w), omega.to(grid_w)]).unsqueeze(1).expand(-1, out_h.shape[1], -1) 100 | pos_embed = torch.cat([torch.sin(out_h), torch.cos(out_h), torch.sin(out_w), torch.cos(out_w), 101 | torch.sin(out_scale_h), torch.cos(out_scale_h), 102 | torch.sin(out_scale_w), torch.cos(out_scale_w),], dim=2).detach() 103 | 104 | return pos_embed 105 | 106 | 107 | # -------------------------------------------------------- 108 | # Interpolate position embeddings for high-resolution 109 | # References: 110 | # DeiT: https://github.com/facebookresearch/deit 111 | # -------------------------------------------------------- 112 | def interpolate_pos_embed(model, checkpoint_model): 113 | if 'pos_embed' in checkpoint_model: 114 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 115 | embedding_size = pos_embed_checkpoint.shape[-1] 116 | num_patches = model.patch_embed.num_patches 117 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 118 | # height (== width) for the checkpoint position embedding 119 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 120 | # height (== width) for the new position embedding 121 | new_size = int(num_patches ** 0.5) 122 | # class_token and dist_token are kept unchanged 123 | if orig_size != new_size: 124 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 125 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 126 | # only the position tokens are interpolated 127 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 128 | pos_tokens = pos_tokens.reshape(-1, orig_size, 
orig_size, embedding_size).permute(0, 3, 1, 2) 129 | pos_tokens = torch.nn.functional.interpolate( 130 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 131 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 132 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 133 | checkpoint_model['pos_embed'] = new_pos_embed 134 | -------------------------------------------------------------------------------- /util/tcs_datasets.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | 7 | import os 8 | import io 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | import pyarrow as pa 12 | import numpy as np 13 | from io import BytesIO 14 | import tqdm 15 | from tqdm import trange 16 | try: 17 | from petrel_client.client import Client 18 | except ImportError as E: 19 | "petrel_client.client cannot be imported" 20 | pass 21 | 22 | 23 | def tcs_pil_loader(img_str): 24 | buff = io.BytesIO(img_str) 25 | return Image.open(buff) 26 | 27 | 28 | class TCSLoader(object): 29 | 30 | def __init__(self, conf_path): 31 | self.client = Client(conf_path) 32 | 33 | def __call__(self, fn): 34 | try: 35 | img_value_str = self.client.get(fn) 36 | img = tcs_pil_loader(img_value_str) 37 | except: 38 | print('Read image failed ({})'.format(fn)) 39 | return None 40 | else: 41 | return img 42 | 43 | 44 | def _get_images(annotations): 45 | images = [] 46 | classes = [] 47 | for line in annotations: 48 | if isinstance(line, bytes): 49 | line = line.decode() 50 | image_name, cls = line.strip('\n').split() 51 | images.append(image_name) 52 | classes.append(cls) 53 | return images, classes 54 | 55 | 56 | class ImageNetTCSDatasetQK(Dataset): 57 | def __init__(self, image_set, data_path, transform=None, use_tcs=False, 58 | tcs_conf_path='/mnt/lustre/share_data/taochenxin/tcs/petreloss.conf', 59 | test_mode=False, 60 | on_memory=False, local_rank=None, local_size=None, 61 | **kwargs): 62 | ann_file = os.path.join(data_path, f'meta/{image_set}.txt') 63 | data_path = os.path.join(data_path, image_set) 64 | self.image_set = image_set 65 | self.transform = transform 66 | self.data_path = data_path 67 | self.test_mode = test_mode 68 | if use_tcs: 69 | self.tcs_loader = TCSLoader(tcs_conf_path) 70 | self.use_tcs = use_tcs 71 | self.images, self.classes, self.class_to_idx = self._load_database(ann_file) 72 | self.on_memory = on_memory 73 | if on_memory: 74 | if local_rank is None: 75 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 76 | if local_size is None: 77 | local_size = int(os.environ.get('LOCAL_SIZE', 1)) 78 | self.local_rank = local_rank 79 | self.local_size = local_size 80 | self.holder = {} 81 | self.load_onto_memory() 82 | 83 | def load_onto_memory(self): 84 | print("Loading images onto memory...") 85 | for index in trange(len(self.images)): 86 | if index % self.local_size != self.local_rank: 87 | continue 88 | path = self.images[index].as_py() 89 | full_path = os.path.join(self.data_path, path) 90 | if self.use_tcs: 91 | sample = self.tcs_loader.client.get(full_path) 92 | else: 93 | with open(full_path, 'rb') as f: 94 | sample = f.read() 95 | self.holder[path] = sample 96 | # print('Loading: path {}, full_path {}, data length {}'.format(path, full_path, 97 | # 
len(self.tcs_loader.client.get(full_path)))) 98 | print("Loading complete!") 99 | 100 | def _load_database(self, annotation_file): 101 | if not self.use_tcs: 102 | annotation_file = os.path.abspath(annotation_file) 103 | print(f'loading annotations from {annotation_file} ...') 104 | if self.use_tcs: 105 | with BytesIO(self.tcs_loader.client.get(annotation_file)) as annotations: 106 | images, classes = _get_images(annotations) 107 | else: 108 | with open(annotation_file, 'rt') as annotations: 109 | images, classes = _get_images(annotations) 110 | 111 | # convert possible classes to indices 112 | class_names = sorted(set(classes)) 113 | # class_to_idx = {class_name: idx for idx, class_name in enumerate(class_names)} 114 | class_to_idx = {class_name: int(class_name) for class_name in class_names} 115 | return pa.array(images), pa.array([class_to_idx[class_name] for class_name in classes]), class_to_idx 116 | 117 | def __len__(self): 118 | return len(self.images) 119 | 120 | def __getitem__(self, index): 121 | path = self.images[index].as_py() 122 | target = self.classes[index].as_py() 123 | sample = self._load_image(path) 124 | if self.transform is not None: 125 | sample_q = self.transform(sample) 126 | sample_k = self.transform(sample) 127 | return sample_q, sample_k 128 | else: 129 | return sample, sample 130 | 131 | def _load_image(self, path): 132 | full_path = os.path.join(self.data_path, path) 133 | if self.on_memory: 134 | try: 135 | return Image.open(BytesIO(self.holder[path])).convert('RGB') 136 | except: 137 | print('error acquiring data from {}'.format(path)) 138 | return self.tcs_loader(full_path).convert('RGB') 139 | elif self.use_tcs: 140 | return self.tcs_loader(full_path).convert('RGB') 141 | else: 142 | with open(full_path, 'rb') as f: 143 | return Image.open(f).convert('RGB') 144 | 145 | 146 | class ImagenetTCSDataset(ImageNetTCSDatasetQK): 147 | def __init__(self, image_set, data_path, transform=None, use_tcs=False, 148 | tcs_conf_path='/mnt/lustre/share_data/taochenxin/tcs/petreloss.conf', 149 | test_mode=False, on_memory=False, local_rank=None, local_size=None, 150 | with_blockwise_mask=False, ### !!! set to True, enable blockwise masking 151 | blockwise_num_masking_patches=75, ### !!! 
75 / 196 = 0.38 -> Modify this to increase mask ratio 152 | input_size=224, patch_size=16, # no need to change now 153 | max_mask_patches_per_block=None, # BEiT default setting, no need to change 154 | min_mask_patches_per_block=16, # BEiT default setting, no need to change 155 | fixed_num_masking_patches=True, ### set to true, fixed number of masking patch to blockwise_num_masking_patches for sim training 156 | **kwargs): 157 | super().__init__(image_set, data_path, transform=transform, use_tcs=use_tcs, 158 | tcs_conf_path=tcs_conf_path, test_mode=test_mode, on_memory=on_memory, 159 | local_rank=local_rank, local_size=local_size, **kwargs) 160 | self.with_blockwise_mask = with_blockwise_mask 161 | if with_blockwise_mask: 162 | from .masking_generator import MaskingGenerator 163 | window_size = input_size // patch_size 164 | self.masked_position_generator = MaskingGenerator( 165 | (window_size, window_size), 166 | num_masking_patches=blockwise_num_masking_patches, 167 | max_num_patches=max_mask_patches_per_block, 168 | min_num_patches=min_mask_patches_per_block, 169 | fixed_num_masking_patches=fixed_num_masking_patches 170 | ) 171 | 172 | def __getitem__(self, index): 173 | path = self.images[index].as_py() 174 | target = self.classes[index].as_py() 175 | sample = self._load_image(path) 176 | if self.transform is not None: 177 | sample = self.transform(sample) 178 | if self.with_blockwise_mask: 179 | return sample, target, self.masked_position_generator() 180 | return sample, target 181 | 182 | 183 | if __name__ == '__main__': 184 | transform = transforms.Compose([ 185 | transforms.RandomResizedCrop(224), 186 | transforms.RandomHorizontalFlip(0.5), 187 | transforms.ToTensor(), 188 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), 189 | ]) 190 | dataset = ImagenetTCSDataset( 191 | 'val', 192 | 's3://imagenet', 193 | tcs_conf_path='./petreloss.conf', 194 | transform=transform, 195 | with_blockwise_mask=True, 196 | blockwise_num_masking_patches=75) 197 | for i, (sample, target, mask) in enumerate(dataset): 198 | if i < 10: 199 | print(mask.sum()) 200 | print(mask) 201 | else: 202 | break 203 | -------------------------------------------------------------------------------- /models_vit.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
7 | # ------------------------------------------------------------------------ 8 | # References: 9 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 10 | # DeiT: https://github.com/facebookresearch/deit 11 | # ------------------------------------------------------------------------ 12 | 13 | from functools import partial 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | import timm.models.vision_transformer 20 | from timm.models.layers import Mlp, DropPath 21 | from timm.models.layers.helpers import to_2tuple 22 | 23 | from util.misc import LayerNorm 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 2D Image to Patch Embedding 28 | """ 29 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 30 | super().__init__() 31 | img_size = to_2tuple(img_size) 32 | patch_size = to_2tuple(patch_size) 33 | self.img_size = img_size 34 | self.patch_size = patch_size 35 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 36 | self.num_patches = self.grid_size[0] * self.grid_size[1] 37 | self.flatten = flatten 38 | 39 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 40 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 41 | 42 | def forward(self, x): 43 | B, C, H, W = x.shape 44 | x = self.proj(x) 45 | if self.flatten: 46 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 47 | x = self.norm(x) 48 | return x 49 | 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 53 | super().__init__() 54 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 55 | self.num_heads = num_heads 56 | head_dim = dim // num_heads 57 | self.scale = head_dim ** -0.5 58 | 59 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 60 | self.attn_drop = nn.Dropout(attn_drop) 61 | self.proj = nn.Linear(dim, dim) 62 | self.proj_drop = nn.Dropout(proj_drop) 63 | 64 | def forward(self, x): 65 | B, N, C = x.shape 66 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 67 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 68 | 69 | attn = ((q * self.scale) @ k.transpose(-2, -1)) 70 | attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16 71 | attn = attn.softmax(dim=-1) 72 | attn = self.attn_drop(attn) 73 | 74 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 75 | x = self.proj(x) 76 | x = self.proj_drop(x) 77 | return x 78 | 79 | 80 | class CrossAttention(nn.Module): 81 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 82 | super().__init__() 83 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 84 | self.num_heads = num_heads 85 | head_dim = dim // num_heads 86 | self.scale = head_dim ** -0.5 87 | 88 | # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 89 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 90 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 91 | self.attn_drop = nn.Dropout(attn_drop) 92 | self.proj = nn.Linear(dim, dim) 93 | self.proj_drop = nn.Dropout(proj_drop) 94 | 95 | def forward(self, query, key): 96 | B, Nq, C = query.shape 97 | _, Nk, _ = key.shape 98 | q = self.q(query).reshape(B, Nq, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 99 | kv = self.kv(key).reshape(B, Nk, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 100 | k, v = kv.unbind(0) # make 
torchscript happy (cannot use tensor as tuple) 101 | 102 | attn = ((q * self.scale) @ k.transpose(-2, -1)) 103 | attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16 104 | attn = attn.softmax(dim=-1) 105 | attn = self.attn_drop(attn) 106 | 107 | x = (attn @ v).transpose(1, 2).reshape(B, Nq, C) 108 | x = self.proj(x) 109 | x = self.proj_drop(x) 110 | return x 111 | 112 | 113 | class LayerScale(nn.Module): 114 | def __init__(self, dim, init_values=1e-5, inplace=False): 115 | super().__init__() 116 | self.inplace = inplace 117 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 118 | 119 | @torch.cuda.amp.autocast(enabled=False) 120 | def forward(self, x): 121 | return x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float() 122 | 123 | 124 | class Block(nn.Module): 125 | 126 | def __init__( 127 | self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, 128 | drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm): 129 | super().__init__() 130 | self.norm1 = norm_layer(dim) 131 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 132 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 133 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 134 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 135 | 136 | self.norm2 = norm_layer(dim) 137 | mlp_hidden_dim = int(dim * mlp_ratio) 138 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 139 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 140 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 141 | 142 | def forward(self, x): 143 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 144 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 145 | 146 | return x 147 | 148 | 149 | class CrossBlock(nn.Module): 150 | 151 | def __init__( 152 | self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, 153 | drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm): 154 | super().__init__() 155 | self.norm1 = norm_layer(dim) 156 | self.norm2 = norm_layer(dim) 157 | self.cross_attn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 158 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 159 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 160 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 161 | 162 | self.norm3 = norm_layer(dim) 163 | mlp_hidden_dim = int(dim * mlp_ratio) 164 | self.mlp1 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 165 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 166 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 167 | 168 | self.norm4 = norm_layer(dim) 169 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 170 | self.ls3 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 171 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 172 | self.drop_path3 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 173 | 174 | self.norm5 = norm_layer(dim) 175 | mlp_hidden_dim = int(dim * mlp_ratio) 176 | self.mlp2 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 177 | self.ls4 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 178 | self.drop_path4 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 179 | 180 | 181 | def forward(self, query, key): 182 | query = query + self.drop_path1(self.ls1(self.cross_attn(self.norm1(query), self.norm2(key)))) 183 | query = query + self.drop_path2(self.ls2(self.mlp1(self.norm3(query)))) 184 | query = query + self.drop_path3(self.ls3(self.attn(self.norm4(query)))) 185 | query = query + self.drop_path4(self.ls4(self.mlp2(self.norm5(query)))) 186 | return query 187 | 188 | 189 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 190 | """ Vision Transformer with support for global average pooling 191 | """ 192 | def __init__(self, global_pool=False, **kwargs): 193 | init_values = kwargs.pop('init_values') 194 | super(VisionTransformer, self).__init__(**kwargs) 195 | 196 | drop_path_rate = kwargs['drop_path_rate'] 197 | depth = kwargs['depth'] 198 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 199 | self.blocks = nn.Sequential(*[ 200 | Block( 201 | dim=kwargs['embed_dim'], num_heads=kwargs['num_heads'], mlp_ratio=kwargs['mlp_ratio'], qkv_bias=kwargs['qkv_bias'], 202 | init_values=init_values, norm_layer=kwargs['norm_layer'], drop_path=dpr[i]) 203 | for i in range(kwargs['depth'])]) 204 | 205 | self.global_pool = global_pool 206 | norm_layer = kwargs['norm_layer'] 207 | embed_dim = kwargs['embed_dim'] 208 | if self.global_pool: 209 | self.fc_norm = norm_layer(embed_dim) 210 | 211 | del self.norm # remove the original norm 212 | 213 | # remove cls token embedding 214 | # delattr(self, 'cls_token') 215 | 216 | num_patches = self.patch_embed.num_patches 217 | if self.global_pool: 218 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim), requires_grad=False) 219 | else: 220 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False) 221 | self.cls_pos_embed = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=False) 222 | 223 | def forward_features(self, x): 224 | B = x.shape[0] 225 | x = self.patch_embed(x) 226 | 227 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 228 | x = torch.cat((cls_tokens, x), dim=1) 229 | x = x + self.pos_embed 230 | x = self.pos_drop(x) 231 | 232 | for blk in self.blocks: 233 | x = blk(x) 234 | 235 | outcome = x 236 | 237 | return outcome 238 | 239 | def forward_head(self, x, pre_logits: bool = False): 240 | if self.global_pool: 241 | x = x[:, 1:, :].mean(dim=1) 242 | else: 243 | x[:, 0] 244 | x = self.fc_norm(x) 245 | return x if pre_logits else self.head(x) 246 | 247 | 248 | def vit_base_patch16(**kwargs): 249 | model = VisionTransformer( 250 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 251 | norm_layer=partial(LayerNorm, eps=1e-6), **kwargs) 252 | return model 253 | 254 | 255 | def vit_large_patch16(**kwargs): 256 | model = VisionTransformer( 257 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 258 | norm_layer=partial(LayerNorm, eps=1e-6), **kwargs) 259 | return model 260 | 261 | 262 | def vit_huge_patch14(**kwargs): 263 | model = VisionTransformer( 264 | patch_size=14, embed_dim=1280, depth=32, 
num_heads=16, mlp_ratio=4, qkv_bias=True, 265 | norm_layer=partial(LayerNorm, eps=1e-6), **kwargs) 266 | return model 267 | -------------------------------------------------------------------------------- /main_linprobe.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | # References: 9 | # MoCo v3: https://github.com/facebookresearch/moco-v3 10 | # DeiT: https://github.com/facebookresearch/deit 11 | # ------------------------------------------------------------------------ 12 | 13 | 14 | import argparse 15 | import datetime 16 | import json 17 | import numpy as np 18 | import os 19 | import time 20 | from pathlib import Path 21 | 22 | import torch 23 | import torch.backends.cudnn as cudnn 24 | from torch.utils.tensorboard import SummaryWriter 25 | import torchvision.transforms as transforms 26 | import torchvision.datasets as datasets 27 | 28 | import timm 29 | 30 | assert timm.__version__ == "0.6.12" # version check 31 | from timm.models.layers import trunc_normal_ 32 | 33 | import util.misc as misc 34 | from util.pos_embed import interpolate_pos_embed 35 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 36 | from util.lars import LARS 37 | from util.crop import RandomResizedCrop 38 | import models_vit 39 | from engine_finetune import train_one_epoch, evaluate 40 | 41 | 42 | def get_args_parser(): 43 | parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False) 44 | parser.add_argument('--batch_size', default=512, type=int, 45 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 46 | parser.add_argument('--epochs', default=90, type=int) 47 | parser.add_argument('--accum_iter', default=1, type=int, 48 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 49 | 50 | # Model parameters 51 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 52 | help='Name of model to train') 53 | 54 | # Optimizer parameters 55 | parser.add_argument('--weight_decay', type=float, default=0, 56 | help='weight decay (default: 0 for linear probe following MoCo v1)') 57 | 58 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 59 | help='learning rate (absolute lr)') 60 | parser.add_argument('--blr', type=float, default=0.1, metavar='LR', 61 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 62 | 63 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 64 | help='lower lr bound for cyclic schedulers that hit 0') 65 | 66 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', 67 | help='epochs to warmup LR') 68 | 69 | # * Finetuning params 70 | parser.add_argument('--finetune', default='', 71 | help='finetune from checkpoint') 72 | parser.add_argument('--global_pool', action='store_true') 73 | parser.set_defaults(global_pool=False) 74 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 75 | help='Use class token instead of global pool for classification') 
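    # Added reader note (not in the original source): --global_pool and --cls_token
    # write to the same destination, args.global_pool, so whichever flag appears later
    # on the command line wins (argparse store_true / store_false semantics); linear
    # probing with global average pooling is selected by passing --global_pool.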
76 | 77 | # Dataset parameters 78 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 79 | help='dataset path') 80 | parser.add_argument('--nb_classes', default=1000, type=int, 81 | help='number of the classification types') 82 | parser.add_argument('--use_tcs_dataset', default=False, action='store_true') 83 | 84 | parser.add_argument('--output_dir', default='./output_dir', 85 | help='path where to save, empty for no saving') 86 | parser.add_argument('--log_dir', default='./output_dir', 87 | help='path where to tensorboard log') 88 | parser.add_argument('--device', default='cuda', 89 | help='device to use for training / testing') 90 | parser.add_argument('--seed', default=0, type=int) 91 | parser.add_argument('--resume', default='', 92 | help='resume from checkpoint') 93 | 94 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 95 | help='start epoch') 96 | parser.add_argument('--eval', action='store_true', 97 | help='Perform evaluation only') 98 | parser.add_argument('--dist_eval', action='store_true', default=False, 99 | help='Enabling distributed evaluation (recommended during training for faster monitor') 100 | parser.add_argument('--num_workers', default=10, type=int) 101 | parser.add_argument('--pin_mem', action='store_true', 102 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 103 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 104 | # parser.set_defaults(pin_mem=True) 105 | 106 | # distributed training parameters 107 | parser.add_argument('--world_size', default=1, type=int, 108 | help='number of distributed processes') 109 | parser.add_argument('--local_rank', default=-1, type=int) 110 | parser.add_argument('--dist_on_itp', action='store_true') 111 | parser.add_argument('--dist_url', default='env://', 112 | help='url used to set up distributed training') 113 | 114 | parser.add_argument('--auto_resume', action='store_true', default=True) 115 | parser.add_argument('--init_values', default=1.0, type=float) 116 | 117 | return parser 118 | 119 | 120 | def main(args): 121 | misc.init_distributed_mode(args) 122 | 123 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 124 | print("{}".format(args).replace(', ', ',\n')) 125 | 126 | device = torch.device(args.device) 127 | 128 | # fix the seed for reproducibility 129 | seed = args.seed + misc.get_rank() 130 | torch.manual_seed(seed) 131 | np.random.seed(seed) 132 | 133 | cudnn.benchmark = True 134 | torch.backends.cuda.matmul.allow_tf32 = False 135 | torch.backends.cudnn.allow_tf32 = False 136 | 137 | # linear probe: weak augmentation 138 | transform_train = transforms.Compose([ 139 | RandomResizedCrop(224, interpolation=3), 140 | transforms.RandomHorizontalFlip(), 141 | transforms.ToTensor(), 142 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 143 | transform_val = transforms.Compose([ 144 | transforms.Resize(256, interpolation=3), 145 | transforms.CenterCrop(224), 146 | transforms.ToTensor(), 147 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 148 | if not args.use_tcs_dataset: 149 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 150 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val) 151 | else: # for internal use only 152 | from util.tcs_datasets import ImagenetTCSDataset 153 | dataset_train = ImagenetTCSDataset( 154 | 
'train', 155 | 's3://imagenet', 156 | use_tcs=True, 157 | transform=transform_train) 158 | dataset_val = ImagenetTCSDataset( 159 | 'val', 160 | 's3://imagenet', 161 | use_tcs=True, 162 | transform=transform_val) 163 | print(dataset_train) 164 | print(dataset_val) 165 | 166 | # build dataloader 167 | if True: # args.distributed: 168 | num_tasks = misc.get_world_size() 169 | global_rank = misc.get_rank() 170 | sampler_train = torch.utils.data.DistributedSampler( 171 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 172 | ) 173 | print("Sampler_train = %s" % str(sampler_train)) 174 | if args.dist_eval: 175 | if len(dataset_val) % num_tasks != 0: 176 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 177 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 178 | 'equal num of samples per-process.') 179 | sampler_val = torch.utils.data.DistributedSampler( 180 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 181 | else: 182 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 183 | else: 184 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 185 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 186 | 187 | data_loader_train = torch.utils.data.DataLoader( 188 | dataset_train, sampler=sampler_train, 189 | batch_size=args.batch_size, 190 | num_workers=args.num_workers, 191 | pin_memory=args.pin_mem, 192 | drop_last=True, 193 | ) 194 | 195 | data_loader_val = torch.utils.data.DataLoader( 196 | dataset_val, sampler=sampler_val, 197 | batch_size=args.batch_size, 198 | num_workers=args.num_workers, 199 | pin_memory=args.pin_mem, 200 | drop_last=False 201 | ) 202 | 203 | if global_rank == 0 and args.log_dir is not None and not args.eval: 204 | os.makedirs(args.log_dir, exist_ok=True) 205 | log_writer = SummaryWriter(log_dir=args.log_dir) 206 | else: 207 | log_writer = None 208 | 209 | # build model 210 | model = models_vit.__dict__[args.model]( 211 | num_classes=args.nb_classes, 212 | global_pool=args.global_pool, 213 | init_values=args.init_values if args.init_values != 1.0 else None, 214 | drop_path_rate=0.0 215 | ) 216 | 217 | # load ckpt 218 | if args.finetune and not args.eval: 219 | checkpoint = torch.load(args.finetune, map_location='cpu') 220 | 221 | print("Load pre-trained checkpoint from: %s" % args.finetune) 222 | checkpoint_model = checkpoint['model'] 223 | 224 | state_dict = model.state_dict() 225 | for k in ['head.weight', 'head.bias']: 226 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 227 | print(f"Removing key {k} from pretrained checkpoint") 228 | del checkpoint_model[k] 229 | 230 | # interpolate position embedding 231 | interpolate_pos_embed(model, checkpoint_model) 232 | 233 | # load pre-trained model 234 | msg = model.load_state_dict(checkpoint_model, strict=False) 235 | print(msg) 236 | 237 | if args.global_pool: 238 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 239 | else: 240 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 241 | 242 | # manually initialize fc layer: following MoCo v3 243 | trunc_normal_(model.head.weight, std=0.01) 244 | 245 | # for linear prob only 246 | # hack: revise model's head with BN 247 | # model.bn = torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6) 248 | model.head = 
torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head) 249 | # freeze all but the head 250 | for _, p in model.named_parameters(): 251 | p.requires_grad = False 252 | for _, p in model.head.named_parameters(): 253 | p.requires_grad = True 254 | 255 | model.to(device) 256 | 257 | model_without_ddp = model 258 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 259 | 260 | print("Model = %s" % str(model_without_ddp)) 261 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 262 | 263 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 264 | 265 | if args.lr is None: # only base_lr is specified 266 | args.lr = args.blr * eff_batch_size / 256 267 | 268 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 269 | print("actual lr: %.2e" % args.lr) 270 | 271 | print("accumulate grad iterations: %d" % args.accum_iter) 272 | print("effective batch size: %d" % eff_batch_size) 273 | 274 | if args.distributed: 275 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 276 | model_without_ddp = model.module 277 | 278 | # build optimizer 279 | optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay) 280 | print(optimizer) 281 | loss_scaler = NativeScaler() 282 | 283 | criterion = torch.nn.CrossEntropyLoss() 284 | 285 | print("criterion = %s" % str(criterion)) 286 | 287 | # misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 288 | misc.auto_load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 289 | 290 | if args.eval: 291 | test_stats = evaluate(data_loader_val, model, device) 292 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 293 | exit(0) 294 | 295 | # start training 296 | print(f"Start training for {args.epochs} epochs") 297 | start_time = time.time() 298 | max_accuracy = 0.0 299 | for epoch in range(args.start_epoch, args.epochs): 300 | if args.distributed: 301 | data_loader_train.sampler.set_epoch(epoch) 302 | train_stats = train_one_epoch( 303 | model, criterion, data_loader_train, 304 | optimizer, device, epoch, loss_scaler, 305 | max_norm=None, 306 | log_writer=log_writer, 307 | args=args 308 | ) 309 | 310 | # save ckpt 311 | if args.output_dir: 312 | misc.save_model( 313 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 314 | loss_scaler=loss_scaler, epoch=epoch, latest=True) 315 | 316 | if (epoch+1)%1 == 0: 317 | test_stats = evaluate(data_loader_val, model, device) 318 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 319 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 320 | print(f'Max accuracy: {max_accuracy:.2f}%') 321 | 322 | if log_writer is not None: 323 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 324 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 325 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 326 | 327 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 328 | **{f'test_{k}': v for k, v in test_stats.items()}, 329 | 'epoch': epoch, 330 | 'n_parameters': n_parameters} 331 | 332 | if args.output_dir and misc.is_main_process(): 333 | if log_writer is not None: 334 | log_writer.flush() 335 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 
336 | f.write(json.dumps(log_stats) + "\n") 337 | 338 | total_time = time.time() - start_time 339 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 340 | print('Training time {}'.format(total_time_str)) 341 | 342 | 343 | if __name__ == '__main__': 344 | args = get_args_parser() 345 | args = args.parse_args() 346 | if args.output_dir: 347 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 348 | main(args) 349 | -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | # References: 9 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 10 | # DeiT: https://github.com/facebookresearch/deit 11 | # ------------------------------------------------------------------------ 12 | 13 | import argparse 14 | import datetime 15 | import json 16 | import numpy as np 17 | import os 18 | import time 19 | from pathlib import Path 20 | 21 | import torch 22 | import torch.distributed as dist 23 | import torch.backends.cudnn as cudnn 24 | from torch.utils.tensorboard import SummaryWriter 25 | import torchvision.transforms as transforms 26 | import torchvision.datasets as datasets 27 | 28 | import timm 29 | assert timm.__version__ == "0.6.12" # version check 30 | from timm.optim.optim_factory import param_groups_weight_decay 31 | from timm.optim import create_optimizer 32 | 33 | import util.misc as misc 34 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 35 | from util.augmentation import RandomResizedCrop, GaussianBlur, SingleRandomResizedCrop, RandomHorizontalFlip, Solarize 36 | from util.datasets import ImagenetWithMask 37 | import models_sim 38 | from engine_pretrain import train_one_epoch 39 | 40 | 41 | def get_args_parser(): 42 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 43 | parser.add_argument('--batch_size', default=64, type=int, 44 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 45 | parser.add_argument('--epochs', default=400, type=int) 46 | parser.add_argument('--accum_iter', default=1, type=int, 47 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 48 | 49 | # Model parameters 50 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 51 | help='Name of model to train') 52 | 53 | parser.add_argument('--input_size', default=224, type=int, 54 | help='images input size') 55 | 56 | parser.add_argument('--mask_ratio', default=0.75, type=float, 57 | help='Masking ratio (percentage of removed patches).') 58 | 59 | parser.add_argument('--norm_pix_loss', action='store_true', 60 | help='Use (per-patch) normalized pixels as targets for computing loss') 61 | parser.set_defaults(norm_pix_loss=False) 62 | 63 | parser.add_argument('--use_abs_pos_emb', default=True, action='store_true') 64 | parser.add_argument('--disable_abs_pos_emb', dest='use_abs_pos_emb', action='store_false') 65 | 
parser.add_argument('--use_shared_rel_pos_bias', default=False, action='store_true') 66 | 67 | # Optimizer parameters 68 | parser.add_argument('--weight_decay', type=float, default=0.05, 69 | help='weight decay (default: 0.05)') 70 | 71 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 72 | help='learning rate (absolute lr)') 73 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 74 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 75 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 76 | help='lower lr bound for cyclic schedulers that hit 0') 77 | 78 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 79 | help='epochs to warmup LR') 80 | 81 | # Dataset parameters 82 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 83 | help='dataset path') 84 | 85 | parser.add_argument('--output_dir', default='./output_dir', 86 | help='path where to save, empty for no saving') 87 | parser.add_argument('--log_dir', default='./output_dir', 88 | help='path where to tensorboard log') 89 | parser.add_argument('--device', default='cuda', 90 | help='device to use for training / testing') 91 | parser.add_argument('--seed', default=0, type=int) 92 | parser.add_argument('--resume', default='', 93 | help='resume from checkpoint') 94 | 95 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 96 | help='start epoch') 97 | parser.add_argument('--num_workers', default=10, type=int) 98 | parser.add_argument('--pin_mem', action='store_true', 99 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 100 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 101 | # parser.set_defaults(pin_mem=True) 102 | 103 | # distributed training parameters 104 | parser.add_argument('--world_size', default=1, type=int, 105 | help='number of distributed processes') 106 | parser.add_argument('--local_rank', default=-1, type=int) 107 | parser.add_argument('--dist_on_itp', action='store_true') 108 | parser.add_argument('--dist_url', default='env://', 109 | help='url used to set up distributed training') 110 | 111 | # SiameseIM parameters 112 | # data 113 | parser.add_argument('--crop_min', default=0.2, type=float) 114 | parser.add_argument('--use_tcs_dataset', default=False, action='store_true') 115 | 116 | # model 117 | parser.add_argument('--decoder_embed_dim', default=512, type=int) 118 | parser.add_argument('--drop_path_rate', default=0.0, type=float) 119 | parser.add_argument('--init_values', default=None, type=float) 120 | parser.add_argument('--projector_depth', default=2, type=int) 121 | parser.add_argument('--predictor_depth', default=4, type=int) 122 | parser.add_argument('--use_proj_ln', default=False, action='store_true') 123 | parser.add_argument('--use_pred_ln', default=False, action='store_true') 124 | parser.add_argument('--train_patch_embed', default=False, action='store_true') 125 | parser.add_argument('--online_ln', default=False, action='store_true', help='also use frozen LN in online branch') 126 | 127 | parser.add_argument('--loss_type', default='mae') 128 | parser.add_argument('--neg_weight', default=0.02, type=float) 129 | 130 | parser.add_argument('--with_blockwise_mask', default=False, action='store_true') 131 | parser.add_argument('--blockwise_num_masking_patches', default=75, type=int) 132 | 133 | # hyper-parameter 134 | parser.add_argument('--mm', default=0.996, type=float) 135 | 
parser.add_argument('--mmschedule', default='const') 136 | parser.add_argument('--lambda_F', default=50, type=float) # may no need 137 | parser.add_argument('--T', default=0.2, type=float) # check 138 | parser.add_argument('--clip_grad', default=None, type=float) 139 | parser.add_argument('--beta2', default=0.95, type=float) 140 | 141 | # misc 142 | parser.add_argument('--auto_resume', default=True) 143 | parser.add_argument('--save_freq', default=50, type=int) 144 | parser.add_argument('--save_latest_freq', default=1, type=int) 145 | parser.add_argument('--fp32', default=False, action='store_true') 146 | parser.add_argument('--amp_growth_interval', default=2000, type=int) 147 | 148 | return parser 149 | 150 | 151 | class DataAugmentationForSIM(object): 152 | def __init__(self, args): 153 | self.args = args 154 | 155 | self.random_resized_crop = SingleRandomResizedCrop(args.input_size, scale=(args.crop_min, 1.0), interpolation=3) 156 | self.random_flip = RandomHorizontalFlip() 157 | 158 | self.color_transform1 = transforms.Compose([ 159 | transforms.RandomApply([ 160 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened 161 | ], p=0.8), 162 | transforms.RandomGrayscale(p=0.2), 163 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=1.0), 164 | ]) 165 | 166 | self.color_transform2 = transforms.Compose([ 167 | transforms.RandomApply([ 168 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened 169 | ], p=0.8), 170 | transforms.RandomGrayscale(p=0.2), 171 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1), 172 | transforms.RandomApply([Solarize()], p=0.2), 173 | ]) 174 | 175 | self.format_transform = transforms.Compose([ 176 | transforms.ToTensor(), 177 | transforms.Normalize( 178 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 179 | ]) 180 | 181 | def __call__(self, image): 182 | spatial_image1, flip1 = self.random_flip(image) 183 | spatial_image2, flip2 = self.random_flip(image) 184 | spatial_image1, i1, j1, h1, w1, W = self.random_resized_crop(spatial_image1) 185 | spatial_image2, i2, j2, h2, w2, W = self.random_resized_crop(spatial_image2) 186 | color_image1 = self.color_transform1(spatial_image1) 187 | color_image2 = self.color_transform2(spatial_image2) 188 | 189 | relative_flip = (flip1 and not flip2) or (flip2 and not flip1) 190 | return self.format_transform(color_image1), self.format_transform(color_image2), \ 191 | (i2-i1)/h1, (j2-j1)/w1, h2/h1, w2/w1, relative_flip, (W-j1-j2)/w1 192 | 193 | def __repr__(self): 194 | repr = "(DataAugmentation,\n" 195 | repr += " transform = %s,\n" % str(self.random_resized_crop) + str(self.random_flip) + str(self.color_transform1) + str(self.format_transform) 196 | repr += ")" 197 | return repr 198 | 199 | 200 | def main(args): 201 | misc.init_distributed_mode(args) # need change to torch.engine 202 | 203 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 204 | print("{}".format(args).replace(', ', ',\n')) 205 | 206 | device = torch.device(args.device) 207 | 208 | # fix the seed for reproducibility 209 | seed = args.seed + misc.get_rank() 210 | torch.manual_seed(seed) 211 | np.random.seed(seed) 212 | 213 | cudnn.benchmark = True 214 | 215 | # disable tf32 216 | torch.backends.cuda.matmul.allow_tf32 = False 217 | torch.backends.cudnn.allow_tf32 = False 218 | 219 | # build augmentation and dataset 220 | if args.loss_type in ['sim']: 221 | transform_train = DataAugmentationForSIM(args) 222 | else: 223 | transform_train = transforms.Compose([ 224 | transforms.RandomResizedCrop(args.input_size, 
scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 225 | transforms.RandomHorizontalFlip(), 226 | transforms.ToTensor(), 227 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 228 | if not args.use_tcs_dataset: 229 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 230 | dataset_train = ImagenetWithMask(os.path.join(args.data_path, 'train'), 231 | transform=transform_train, 232 | with_blockwise_mask=args.with_blockwise_mask, 233 | blockwise_num_masking_patches=args.blockwise_num_masking_patches) 234 | else: # for internal use only 235 | from util.tcs_datasets import ImagenetTCSDataset 236 | dataset_train = ImagenetTCSDataset('train', 237 | 's3://imagenet', 238 | use_tcs=True, 239 | transform=transform_train, 240 | with_blockwise_mask=args.with_blockwise_mask, 241 | blockwise_num_masking_patches=args.blockwise_num_masking_patches, 242 | local_rank=int(os.environ['LOCAL_RANK']), 243 | local_size=int(os.environ['LOCAL_SIZE']), 244 | tcs_conf_path='./petreloss.conf') 245 | print(dataset_train) 246 | 247 | # build dataloader 248 | if True: # args.distributed: 249 | num_tasks = misc.get_world_size() 250 | global_rank = misc.get_rank() 251 | sampler_train = torch.utils.data.DistributedSampler( 252 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 253 | ) 254 | print("Sampler_train = %s" % str(sampler_train)) 255 | else: 256 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 257 | 258 | if global_rank == 0 and args.log_dir is not None: 259 | os.makedirs(args.log_dir, exist_ok=True) 260 | log_writer = SummaryWriter(log_dir=args.log_dir) 261 | else: 262 | log_writer = None 263 | 264 | data_loader_train = torch.utils.data.DataLoader( 265 | dataset_train, sampler=sampler_train, 266 | batch_size=args.batch_size, 267 | num_workers=args.num_workers, 268 | pin_memory=args.pin_mem, 269 | drop_last=True, 270 | ) 271 | 272 | 273 | # build model 274 | model = models_sim.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, args=args) 275 | model.to(device) 276 | model_without_ddp = model 277 | 278 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 279 | if args.lr is None: # only base_lr is specified 280 | args.lr = args.blr * eff_batch_size / 256 281 | 282 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 283 | print("actual lr: %.2e" % args.lr) 284 | 285 | print("accumulate grad iterations: %d" % args.accum_iter) 286 | print("effective batch size: %d" % eff_batch_size) 287 | 288 | if args.distributed: 289 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 290 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 291 | model_without_ddp = model.module 292 | print("Model = %s" % str(model_without_ddp)) 293 | 294 | # build optimizer 295 | # following timm: set wd as 0 for bias and norm layers 296 | param_groups = param_groups_weight_decay(model_without_ddp, args.weight_decay) 297 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, args.beta2)) 298 | print(optimizer) 299 | loss_scaler = NativeScaler(enabled=(not args.fp32), growth_interval=args.amp_growth_interval) 300 | 301 | misc.auto_load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, 302 | loss_scaler=loss_scaler) 303 | 304 | # start training 305 | print(f"Start training for {args.epochs} epochs") 306 | start_time = time.time() 307 | for epoch in range(args.start_epoch, args.epochs): 308 | 
epoch_start_time = time.time() 309 | if args.distributed: 310 | data_loader_train.sampler.set_epoch(epoch) 311 | train_stats = train_one_epoch( 312 | model, data_loader_train, 313 | optimizer, device, epoch, loss_scaler, 314 | log_writer=log_writer, 315 | args=args 316 | ) 317 | dist.barrier() 318 | 319 | # save ckpt 320 | if args.output_dir and ((epoch+1) % args.save_freq == 0 or epoch + 1 == args.epochs): 321 | misc.save_model( 322 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 323 | loss_scaler=loss_scaler, epoch=epoch) 324 | if (epoch+1) % args.save_latest_freq == 0: 325 | misc.save_model( 326 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 327 | loss_scaler=loss_scaler, epoch=epoch, latest=True) 328 | 329 | # log information 330 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 331 | 'epoch': epoch,} 332 | 333 | if args.output_dir and misc.is_main_process(): 334 | if log_writer is not None: 335 | log_writer.flush() 336 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 337 | f.write(json.dumps(log_stats) + "\n") 338 | if misc.is_main_process(): 339 | epoch_total_time = time.time() - epoch_start_time 340 | now = datetime.datetime.today() 341 | eta = now + datetime.timedelta(seconds=(args.epochs-epoch-1)*int(epoch_total_time)) 342 | next_50_ep = ((epoch + 1) // 50 + 1) * 50 343 | eta_to_next_50 =now + datetime.timedelta(seconds=(next_50_ep - epoch - 1) * int(epoch_total_time)) 344 | print(f"ETA to {args.epochs:4d}ep:\t{str(eta)}") 345 | print(f"ETA to {next_50_ep:4d}ep:\t{str(eta_to_next_50)}") 346 | dist.barrier() 347 | 348 | total_time = time.time() - start_time 349 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 350 | print('Training time {}'.format(total_time_str)) 351 | 352 | 353 | if __name__ == '__main__': 354 | args = get_args_parser() 355 | args = args.parse_args() 356 | if args.output_dir: 357 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 358 | main(args) 359 | -------------------------------------------------------------------------------- /main_logistic.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MSN (https://github.com/facebookresearch/msn) 6 | # Copyright (c) Facebook, Inc. and affiliates. All Rights Reserved. 
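# NOTE: a minimal sketch of what this script does, assuming `embs` / `labs`
# already hold the frozen ViT features and labels extracted below. It fits a
# regularized multiclass logistic classifier with the `cyanure` package on an
# ImageNet subset selected via --subset-path, mirroring the calls made further
# down in this file:
#
#   import cyanure as cyan
#   cyan.preprocess(embs, normalize=True, columns=False, centering=True)
#   classifier = cyan.MultiClassifier(loss='multiclass-logistic', penalty='l2',
#                                     fit_intercept=False)
#   lambd = 0.00025 / len(embs)   # --lambd default, divided by the dataset size as below
#   classifier.fit(embs.numpy(), labs.numpy(), it0=10, lambd=lambd, lambd2=lambd,
#                  nthreads=-1, tol=1e-3, solver='auto', seed=0, max_epochs=300)
#   print(classifier.score(embs.numpy(), labs.numpy()))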
7 | # ------------------------------------------------------------------------ 8 | 9 | 10 | import os 11 | import argparse 12 | import logging 13 | import pprint 14 | 15 | import numpy as np 16 | import torch 17 | import torchvision.transforms as transforms 18 | import cyanure as cyan 19 | 20 | 21 | logging.basicConfig() 22 | logger = logging.getLogger() 23 | logger.setLevel(logging.INFO) 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | '--lambd', type=float, 28 | default=0.00025, 29 | help='regularization') 30 | parser.add_argument( 31 | '--penalty', type=str, 32 | help='regularization for logistic classifier', 33 | default='l2', 34 | choices=[ 35 | 'l2', 36 | 'elastic-net' 37 | ]) 38 | parser.add_argument( 39 | '--mask', type=float, 40 | default=0.0, 41 | help='regularization') 42 | parser.add_argument( 43 | '--preload', action='store_true', 44 | help='whether to preload embs if possible') 45 | parser.add_argument( 46 | '--fname', type=str, 47 | help='model architecture') 48 | parser.add_argument( 49 | '--model-name', type=str, 50 | help='model architecture') 51 | parser.add_argument( 52 | '--pretrained', type=str, 53 | help='path to pretrained model', 54 | default='') 55 | parser.add_argument( 56 | '--device', type=str, 57 | default='cuda:0', 58 | help='device to run script on') 59 | parser.add_argument( 60 | '--normalize', type=bool, 61 | default=True, 62 | help='whether to standardize images before feeding to nework') 63 | parser.add_argument( 64 | '--root-path', type=str, 65 | default='/datasets/', 66 | help='root directory to data') 67 | parser.add_argument( 68 | '--image-folder', type=str, 69 | default='imagenet_full_size/061417/', 70 | help='image directory inside root_path') 71 | parser.add_argument( 72 | '--subset-path', type=str, 73 | default=None, 74 | help='name of dataset to evaluate on') 75 | parser.add_argument('--local_rank', default=-1, type=int) 76 | 77 | logging.basicConfig() 78 | logger = logging.getLogger() 79 | logger.setLevel(logging.INFO) 80 | 81 | _GLOBAL_SEED = 0 82 | np.random.seed(_GLOBAL_SEED) 83 | torch.manual_seed(_GLOBAL_SEED) 84 | torch.backends.cudnn.benchmark = True 85 | 86 | pp = pprint.PrettyPrinter(indent=4) 87 | 88 | 89 | def main( 90 | blocks, 91 | lambd, 92 | mask_frac, 93 | preload, 94 | pretrained, 95 | fname, 96 | subset_path, 97 | root_path, 98 | image_folder, 99 | penalty='l2', 100 | model_name=None, 101 | normalize=True, 102 | device_str='cuda:0', 103 | args=None 104 | ): 105 | init_distributed_mode(args) 106 | # torch.cuda.set_device(args.rank) 107 | # device = torch.device('cuda') 108 | # device = torch.device(device_str) 109 | # if 'cuda' in device_str: 110 | # torch.cuda.set_device(device) 111 | 112 | # -- Define file names used to save computed embeddings (for efficient 113 | # -- reuse if running the script more than once) 114 | subset_tag = '-'.join(subset_path.split('/')).split('.txt')[0] if subset_path is not None else 'imagenet_subses1-100percent' 115 | train_embs_path = f'train-features-{subset_tag}-{fname}' 116 | test_embs_path = f'val-features-{fname}' 117 | logger.info(train_embs_path) 118 | logger.info(test_embs_path) 119 | 120 | # pretrained = os.path.join(pretrained, fname) 121 | 122 | # -- Function to make train/test dataloader 123 | def init_pipe(training): 124 | # -- make data transforms 125 | transform = transforms.Compose([ 126 | transforms.Resize(size=256), 127 | transforms.CenterCrop(size=224), 128 | transforms.ToTensor(), 129 | transforms.Normalize( 130 | (0.485, 0.456, 0.406), 131 | 
(0.229, 0.224, 0.225))]) 132 | # -- init data-loaders/samplers 133 | subset_file = subset_path if training else None 134 | data_loader, _ = init_data( 135 | transform=transform, 136 | batch_size=64, 137 | num_workers=0, 138 | world_size=args.world_size, 139 | rank=args.rank, 140 | root_path=root_path, 141 | image_folder=image_folder, 142 | training=training, 143 | copy_data=False, 144 | drop_last=False, 145 | subset_file=subset_file) 146 | return data_loader 147 | 148 | # -- Initialize the model 149 | encoder = init_model( 150 | # device=device, 151 | pretrained=pretrained, 152 | model_name=model_name) 153 | encoder.eval() 154 | 155 | # -- If train embeddings already computed, load file, otherwise, compute 156 | # -- embeddings and save 157 | if preload and os.path.exists(train_embs_path): 158 | checkpoint = torch.load(train_embs_path, map_location='cpu') 159 | embs, labs = checkpoint['embs'], checkpoint['labs'] 160 | logger.info(f'loaded embs of shape {embs.shape}') 161 | else: 162 | data_loader = init_pipe(True) 163 | embs, labs = make_embeddings( 164 | blocks=blocks, 165 | # device=device, 166 | mask_frac=mask_frac, 167 | data_loader=data_loader, 168 | encoder=encoder) 169 | torch.save({ 170 | 'embs': embs, 171 | 'labs': labs 172 | }, train_embs_path) 173 | logger.info(f'saved train embs of shape {embs.shape}') 174 | # # -- Normalize embeddings 175 | cyan.preprocess(embs, normalize=normalize, columns=False, centering=True) 176 | 177 | # import pdb; pdb.set_trace() 178 | 179 | # -- Fit Logistic Regression Classifier 180 | classifier = cyan.MultiClassifier(loss='multiclass-logistic', penalty=penalty, fit_intercept=False) 181 | lambd /= len(embs) 182 | classifier.fit( 183 | embs.numpy(), 184 | labs.numpy(), 185 | it0=10, 186 | lambd=lambd, 187 | lambd2=lambd, 188 | nthreads=-1, 189 | tol=1e-3, 190 | solver='auto', 191 | seed=0, 192 | max_epochs=300) 193 | 194 | # -- Evaluate and log 195 | train_score = classifier.score(embs.numpy(), labs.numpy()) 196 | # -- (save train score) 197 | logger.info(f'train score: {train_score}') 198 | 199 | # -- If test embeddings already computed, load file, otherwise, compute 200 | # -- embeddings and save 201 | if preload and os.path.exists(test_embs_path): 202 | checkpoint = torch.load(test_embs_path, map_location='cpu') 203 | test_embs, test_labs = checkpoint['embs'], checkpoint['labs'] 204 | logger.info(f'loaded test embs of shape {test_embs.shape}') 205 | else: 206 | data_loader = init_pipe(False) 207 | test_embs, test_labs = make_embeddings( 208 | blocks=blocks, 209 | # device=device, 210 | mask_frac=0.0, 211 | data_loader=data_loader, 212 | encoder=encoder) 213 | torch.save({ 214 | 'embs': test_embs, 215 | 'labs': test_labs 216 | }, test_embs_path) 217 | logger.info(f'saved test embs of shape {test_embs.shape}') 218 | # -- Normalize embeddings 219 | cyan.preprocess(test_embs, normalize=normalize, columns=False, centering=True) 220 | 221 | # -- Evaluate and log 222 | test_score = classifier.score(test_embs.numpy(), test_labs.numpy()) 223 | # -- (save test score) 224 | logger.info(f'test score: {test_score}\n\n') 225 | 226 | return test_score 227 | 228 | 229 | def init_distributed_mode(args): 230 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 231 | args.rank = int(os.environ["RANK"]) 232 | args.world_size = int(os.environ['WORLD_SIZE']) 233 | args.gpu = int(os.environ['LOCAL_RANK']) 234 | elif 'SLURM_PROCID' in os.environ: 235 | args.rank = int(os.environ['SLURM_PROCID']) 236 | args.world_size = int(os.environ['SLURM_NTASKS']) 237 | node_list 
= os.environ['SLURM_NODELIST'] 238 | num_gpus = torch.cuda.device_count() 239 | args.gpu = args.rank % torch.cuda.device_count() 240 | torch.cuda.set_device(args.rank % num_gpus) 241 | import subprocess 242 | addr = subprocess.getoutput( 243 | f'scontrol show hostname {node_list} | head -n1') 244 | # specify master port 245 | if hasattr(args, 'port'): 246 | os.environ['MASTER_PORT'] = str(args.port) 247 | elif 'MASTER_PORT' in os.environ: 248 | pass # use MASTER_PORT in the environment variable 249 | else: 250 | # 29500 is torch.distributed default port 251 | os.environ['MASTER_PORT'] = '29502' 252 | # use MASTER_ADDR in the environment variable if it already exists 253 | if 'MASTER_ADDR' not in os.environ: 254 | os.environ['MASTER_ADDR'] = addr 255 | os.environ['WORLD_SIZE'] = str(args.world_size) 256 | os.environ['LOCAL_RANK'] = str(args.rank % num_gpus) 257 | os.environ['RANK'] = str(args.rank) 258 | # dist.init_process_group(backend='nccl') 259 | else: 260 | print('Not using distributed mode') 261 | setup_for_distributed(is_master=True) # hack 262 | args.distributed = False 263 | return 264 | 265 | args.distributed = True 266 | 267 | torch.cuda.set_device(args.gpu) 268 | args.dist_backend = 'nccl' 269 | args.dist_url = 'env://' 270 | print('| distributed init (rank {}): {}, gpu {}'.format( 271 | args.rank, args.dist_url, args.gpu), flush=True) 272 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 273 | world_size=args.world_size, rank=args.rank) 274 | torch.distributed.barrier() 275 | # setup_for_distributed(args.rank == 0) 276 | 277 | 278 | def init_data( 279 | transform, 280 | batch_size, 281 | pin_mem=True, 282 | num_workers=8, 283 | world_size=1, 284 | rank=0, 285 | root_path=None, 286 | image_folder=None, 287 | training=True, 288 | copy_data=False, 289 | drop_last=True, 290 | subset_file=None 291 | ): 292 | 293 | # dataset = ImageNet( 294 | # root=root_path, 295 | # image_folder=image_folder, 296 | # transform=transform, 297 | # train=training, 298 | # copy_data=copy_data) 299 | # if subset_file is not None: 300 | # dataset = ImageNetSubset(dataset, subset_file) 301 | import torchvision 302 | if training: 303 | dataset = torchvision.datasets.ImageFolder(os.path.join(root_path, 'train'), transform=transform) 304 | with open(subset_file) as subset_file: 305 | list_imgs = [li.split('\n')[0] for li in subset_file.readlines()] 306 | dataset.samples = [( 307 | os.path.join(os.path.join(root_path, 'train'), li.split('_')[0], li), 308 | dataset.class_to_idx[li.split('_')[0]] 309 | ) for li in list_imgs] 310 | else: 311 | dataset = torchvision.datasets.ImageFolder(os.path.join(root_path, 'val'), transform=transform) 312 | 313 | logger.info('ImageNet dataset created') 314 | dist_sampler = torch.utils.data.distributed.DistributedSampler( 315 | dataset=dataset, 316 | num_replicas=world_size, 317 | rank=rank) 318 | data_loader = torch.utils.data.DataLoader( 319 | dataset, 320 | sampler=dist_sampler, 321 | batch_size=batch_size, 322 | drop_last=drop_last, 323 | pin_memory=pin_mem, 324 | num_workers=num_workers) 325 | logger.info('ImageNet unsupervised data loader created') 326 | 327 | return (data_loader, dist_sampler) 328 | 329 | 330 | def make_embeddings( 331 | blocks, 332 | # device, 333 | mask_frac, 334 | data_loader, 335 | encoder, 336 | epochs=1 337 | ): 338 | ipe = len(data_loader) 339 | 340 | z_mem, l_mem = [], [] 341 | 342 | for _ in range(epochs): 343 | for itr, (imgs, labels) in enumerate(data_loader): 344 | imgs = imgs.cuda() 345 | with 
torch.no_grad(): 346 | z = encoder.forward_features(imgs)[:, 0].cpu() 347 | labels = labels.cpu() 348 | z_mem.append(z) 349 | l_mem.append(labels) 350 | if itr % 50 == 0: 351 | logger.info(f'[{itr}/{ipe}]') 352 | 353 | z_mem = torch.cat(z_mem, 0) 354 | l_mem = torch.cat(l_mem, 0) 355 | z_mem = all_gather(z_mem) 356 | z_mem = torch.cat(z_mem, 0) 357 | l_mem = all_gather(l_mem) 358 | l_mem = torch.cat(l_mem, 0) 359 | logger.info(z_mem.shape) 360 | logger.info(l_mem.shape) 361 | 362 | return z_mem, l_mem 363 | 364 | 365 | def all_gather(data): 366 | """ 367 | Run all_gather on arbitrary picklable data (not necessarily tensors) 368 | Args: 369 | data: any picklable object 370 | Returns: 371 | list[data]: list of data gathered from each rank 372 | """ 373 | world_size = torch.distributed.get_world_size() 374 | if world_size == 1: 375 | return [data] 376 | 377 | # serialized to a Tensor 378 | import pickle 379 | buffer = pickle.dumps(data) 380 | storage = torch.ByteStorage.from_buffer(buffer) 381 | tensor = torch.ByteTensor(storage).to("cuda") 382 | 383 | # obtain Tensor size of each rank 384 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 385 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 386 | torch.distributed.all_gather(size_list, local_size) 387 | size_list = [int(size.item()) for size in size_list] 388 | max_size = max(size_list) 389 | 390 | # receiving Tensor from all ranks 391 | # we pad the tensor because torch all_gather does not support 392 | # gathering tensors of different shapes 393 | tensor_list = [] 394 | for _ in size_list: 395 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 396 | if local_size != max_size: 397 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 398 | tensor = torch.cat((tensor, padding), dim=0) 399 | torch.distributed.all_gather(tensor_list, tensor) 400 | 401 | data_list = [] 402 | for size, tensor in zip(size_list, tensor_list): 403 | buffer = tensor.cpu().numpy().tobytes()[:size] 404 | data_list.append(pickle.loads(buffer)) 405 | 406 | return data_list 407 | 408 | 409 | def load_pretrained( 410 | encoder, 411 | pretrained 412 | ): 413 | checkpoint = torch.load(pretrained, map_location='cpu') 414 | pretrained_dict = {k.replace('module.', ''): v for k, v in checkpoint['target_encoder'].items()} 415 | for k, v in encoder.state_dict().items(): 416 | if k not in pretrained_dict: 417 | logger.info(f'key "{k}" could not be found in loaded state dict') 418 | elif pretrained_dict[k].shape != v.shape: 419 | logger.info(f'key "{k}" is of different shape in model and loaded state dict') 420 | pretrained_dict[k] = v 421 | msg = encoder.load_state_dict(pretrained_dict, strict=False) 422 | print(encoder) 423 | logger.info(f'loaded pretrained model with msg: {msg}') 424 | try: 425 | logger.info(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]} ' 426 | f'path: {pretrained}') 427 | except Exception: 428 | pass 429 | del checkpoint 430 | return encoder 431 | 432 | 433 | def init_model( 434 | # device, 435 | pretrained, 436 | model_name, 437 | ): 438 | # encoder = deit.__dict__[model_name]() 439 | # encoder.fc = None 440 | # encoder.to(device) 441 | # encoder = load_pretrained(encoder=encoder, pretrained=pretrained) 442 | 443 | import models_vit 444 | model = models_vit.__dict__[model_name]( 445 | num_classes=1000, 446 | global_pool=True, 447 | init_values=None, 448 | drop_path_rate=0.0 449 | ) 450 | 451 | checkpoint = torch.load(pretrained, map_location='cpu') 452 | 453 | print("Load 
pre-trained checkpoint from: %s" % pretrained) 454 | checkpoint_model = checkpoint['model'] 455 | state_dict = model.state_dict() 456 | for k in ['head.weight', 'head.bias']: 457 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 458 | print(f"Removing key {k} from pretrained checkpoint") 459 | del checkpoint_model[k] 460 | 461 | # load pre-trained model 462 | msg = model.load_state_dict(checkpoint_model, strict=False) 463 | print(msg) 464 | 465 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 466 | 467 | model.head = None 468 | 469 | model.cuda() 470 | 471 | return model 472 | 473 | 474 | if __name__ == '__main__': 475 | """'main' for launching script using params read from command line""" 476 | global args 477 | args = parser.parse_args() 478 | pp.pprint(args) 479 | main( 480 | blocks=1, 481 | lambd=args.lambd, 482 | penalty=args.penalty, 483 | mask_frac=args.mask, 484 | preload=args.preload, 485 | pretrained=args.pretrained, 486 | fname=args.fname, 487 | subset_path=args.subset_path, 488 | root_path=args.root_path, 489 | image_folder=args.image_folder, 490 | model_name=args.model_name, 491 | normalize=args.normalize, 492 | device_str=args.device, 493 | args=args 494 | ) 495 | -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
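# NOTE: like MAE, this script derives the actual learning rate from the base
# learning rate with a linear scaling rule (see main() below):
# eff_batch_size = batch_size * accum_iter * world_size and
# lr = blr * eff_batch_size / 256. A small worked example, assuming the default
# --blr 1e-3 and an effective batch size of 1024:
#
#   blr, eff_batch_size = 1e-3, 1024
#   lr = blr * eff_batch_size / 256   # = 4e-3
#
# The optimizer is then built with layer-wise learning-rate decay
# (util.lr_decay.param_groups_lrd), so earlier transformer blocks are trained
# with a smaller learning rate, scaled by powers of --layer_decay.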
7 | # ------------------------------------------------------------------------ 8 | # References: 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # DeiT: https://github.com/facebookresearch/deit 11 | # ------------------------------------------------------------------------ 12 | 13 | 14 | import argparse 15 | import datetime 16 | import json 17 | import numpy as np 18 | import os 19 | import time 20 | from pathlib import Path 21 | 22 | import torch 23 | import torch.backends.cudnn as cudnn 24 | from torch.utils.tensorboard import SummaryWriter 25 | import torchvision.datasets as datasets 26 | 27 | import timm 28 | assert timm.__version__ == "0.6.12" # version check 29 | from timm.models.layers import trunc_normal_ 30 | from timm.data.mixup import Mixup 31 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 32 | 33 | import util.lr_decay as lrd 34 | import util.misc as misc 35 | from util.datasets import build_transform 36 | from util.pos_embed import interpolate_pos_embed 37 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 38 | import models_vit 39 | from engine_finetune import train_one_epoch, evaluate 40 | 41 | 42 | def get_args_parser(): 43 | parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) 44 | parser.add_argument('--batch_size', default=64, type=int, 45 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 46 | parser.add_argument('--epochs', default=50, type=int) 47 | parser.add_argument('--accum_iter', default=1, type=int, 48 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 49 | 50 | # Model parameters 51 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 52 | help='Name of model to train') 53 | 54 | parser.add_argument('--input_size', default=224, type=int, 55 | help='images input size') 56 | 57 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 58 | help='Drop path rate (default: 0.1)') 59 | 60 | # Optimizer parameters 61 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 62 | help='Clip gradient norm (default: None, no clipping)') 63 | parser.add_argument('--weight_decay', type=float, default=0.05, 64 | help='weight decay (default: 0.05)') 65 | 66 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 67 | help='learning rate (absolute lr)') 68 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 69 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 70 | parser.add_argument('--layer_decay', type=float, default=0.75, 71 | help='layer-wise lr decay from ELECTRA/BEiT') 72 | 73 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 74 | help='lower lr bound for cyclic schedulers that hit 0') 75 | 76 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 77 | help='epochs to warmup LR') 78 | 79 | # Augmentation parameters 80 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', 81 | help='Color jitter factor (enabled only when not using Auto/RandAug)') 82 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 83 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), 84 | parser.add_argument('--smoothing', type=float, default=0.1, 85 | help='Label smoothing (default: 0.1)') 86 | 87 | # * Random Erase params 88 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 89 | help='Random erase prob (default: 0.25)') 90 | parser.add_argument('--remode', type=str, default='pixel', 91 | help='Random erase mode (default: "pixel")') 92 | parser.add_argument('--recount', type=int, default=1, 93 | help='Random erase count (default: 1)') 94 | parser.add_argument('--resplit', action='store_true', default=False, 95 | help='Do not random erase first (clean) augmentation split') 96 | 97 | # * Mixup params 98 | parser.add_argument('--mixup', type=float, default=0, 99 | help='mixup alpha, mixup enabled if > 0.') 100 | parser.add_argument('--cutmix', type=float, default=0, 101 | help='cutmix alpha, cutmix enabled if > 0.') 102 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 103 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 104 | parser.add_argument('--mixup_prob', type=float, default=1.0, 105 | help='Probability of performing mixup or cutmix when either/both is enabled') 106 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 107 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 108 | parser.add_argument('--mixup_mode', type=str, default='batch', 109 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 110 | 111 | # * Finetuning params 112 | parser.add_argument('--finetune', default='', 113 | help='finetune from checkpoint') 114 | parser.add_argument('--global_pool', action='store_true') 115 | parser.set_defaults(global_pool=True) 116 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 117 | help='Use class token instead of global pool for classification') 118 | 119 | # Dataset parameters 120 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 121 | help='dataset path') 122 | parser.add_argument('--nb_classes', default=1000, type=int, 123 | help='number of the classification types') 124 | parser.add_argument('--use_tcs_dataset', default=False, action='store_true') 125 | 126 | parser.add_argument('--output_dir', default='./output_dir', 127 | help='path where to save, empty for no saving') 128 | parser.add_argument('--log_dir', default='./output_dir', 129 | help='path where to tensorboard log') 130 | parser.add_argument('--device', default='cuda', 131 | help='device to use for training / testing') 132 | parser.add_argument('--seed', default=0, type=int) 133 | parser.add_argument('--resume', default='', 134 | help='resume from checkpoint') 135 | 136 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 137 | help='start epoch') 138 | parser.add_argument('--eval', action='store_true', 139 | help='Perform evaluation only') 140 | parser.add_argument('--dist_eval', action='store_true', default=False, 141 | help='Enabling distributed evaluation (recommended during training for faster monitor') 142 | parser.add_argument('--num_workers', default=10, type=int) 143 | parser.add_argument('--pin_mem', action='store_true', 144 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 145 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 146 | parser.set_defaults(pin_mem=True) 147 | 148 | # distributed training parameters 149 | 
parser.add_argument('--world_size', default=1, type=int, 150 | help='number of distributed processes') 151 | parser.add_argument('--local_rank', default=-1, type=int) 152 | parser.add_argument('--dist_on_itp', action='store_true') 153 | parser.add_argument('--dist_url', default='env://', 154 | help='url used to set up distributed training') 155 | 156 | parser.add_argument('--auto_resume', action='store_true', default=True) 157 | parser.add_argument('--init_values', default=1.0, type=float) 158 | 159 | return parser 160 | 161 | 162 | def main(args): 163 | misc.init_distributed_mode(args) 164 | 165 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 166 | print("{}".format(args).replace(', ', ',\n')) 167 | 168 | device = torch.device(args.device) 169 | 170 | # fix the seed for reproducibility 171 | seed = args.seed + misc.get_rank() 172 | torch.manual_seed(seed) 173 | np.random.seed(seed) 174 | 175 | cudnn.benchmark = True 176 | torch.backends.cuda.matmul.allow_tf32 = False 177 | torch.backends.cudnn.allow_tf32 = False 178 | 179 | # build dataset 180 | transform_train = build_transform(is_train=True, args=args) 181 | transform_val = build_transform(is_train=False, args=args) 182 | if not args.use_tcs_dataset: 183 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 184 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val) 185 | else: 186 | from util.tcs_datasets import ImagenetTCSDataset 187 | dataset_train = ImagenetTCSDataset('train', 188 | 's3://imagenet', 189 | transform=transform_train, 190 | use_tcs=True) 191 | dataset_val = ImagenetTCSDataset('val', 192 | 's3://imagenet', 193 | transform=transform_val, 194 | use_tcs=True) 195 | 196 | print(dataset_train) 197 | print(dataset_val) 198 | 199 | if True: # args.distributed: 200 | num_tasks = misc.get_world_size() 201 | global_rank = misc.get_rank() 202 | sampler_train = torch.utils.data.DistributedSampler( 203 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 204 | ) 205 | print("Sampler_train = %s" % str(sampler_train)) 206 | if args.dist_eval: 207 | if len(dataset_val) % num_tasks != 0: 208 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 209 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 210 | 'equal num of samples per-process.') 211 | sampler_val = torch.utils.data.DistributedSampler( 212 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 213 | else: 214 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 215 | else: 216 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 217 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 218 | 219 | if global_rank == 0 and args.log_dir is not None and not args.eval: 220 | os.makedirs(args.log_dir, exist_ok=True) 221 | log_writer = SummaryWriter(log_dir=args.log_dir) 222 | else: 223 | log_writer = None 224 | 225 | data_loader_train = torch.utils.data.DataLoader( 226 | dataset_train, sampler=sampler_train, 227 | batch_size=args.batch_size, 228 | num_workers=args.num_workers, 229 | pin_memory=args.pin_mem, 230 | drop_last=True, 231 | ) 232 | 233 | data_loader_val = torch.utils.data.DataLoader( 234 | dataset_val, sampler=sampler_val, 235 | batch_size=args.batch_size, 236 | num_workers=args.num_workers, 237 | pin_memory=args.pin_mem, 238 | drop_last=False 239 | ) 240 | 241 | # build mixup 242 | mixup_fn = None 243 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 244 | if mixup_active: 245 | print("Mixup is activated!") 246 | mixup_fn = Mixup( 247 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 248 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 249 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 250 | 251 | # build model 252 | model = models_vit.__dict__[args.model]( 253 | num_classes=args.nb_classes, 254 | drop_path_rate=args.drop_path, 255 | global_pool=args.global_pool, 256 | init_values=args.init_values if args.init_values != 1.0 else None, 257 | ) 258 | 259 | # load ckpt 260 | if args.finetune and not args.eval: 261 | checkpoint = torch.load(args.finetune, map_location='cpu') 262 | print("Load pre-trained checkpoint from: %s" % args.finetune) 263 | checkpoint_model = checkpoint['model'] 264 | state_dict = model.state_dict() 265 | for k in ['head.weight', 'head.bias']: 266 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 267 | print(f"Removing key {k} from pretrained checkpoint") 268 | del checkpoint_model[k] 269 | 270 | # interpolate position embedding 271 | interpolate_pos_embed(model, checkpoint_model) 272 | 273 | # load pre-trained model 274 | msg = model.load_state_dict(checkpoint_model, strict=False) 275 | print(msg) 276 | 277 | if args.global_pool: 278 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 279 | else: 280 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 281 | 282 | # manually initialize fc layer 283 | if hasattr(model, 'head'): 284 | trunc_normal_(model.head.weight, std=2e-5) 285 | 286 | model.to(device) 287 | 288 | model_without_ddp = model 289 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 290 | 291 | print("Model = %s" % str(model_without_ddp)) 292 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 293 | 294 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 295 | 296 | if args.lr is None: # only base_lr is specified 297 | args.lr = args.blr * eff_batch_size / 256 298 | 299 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 300 | 
print("actual lr: %.2e" % args.lr) 301 | 302 | print("accumulate grad iterations: %d" % args.accum_iter) 303 | print("effective batch size: %d" % eff_batch_size) 304 | 305 | if args.distributed: 306 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 307 | model_without_ddp = model.module 308 | 309 | # build optimizer with layer-wise lr decay (lrd) 310 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, 311 | no_weight_decay_list=model_without_ddp.no_weight_decay(), 312 | layer_decay=args.layer_decay 313 | ) 314 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr) 315 | loss_scaler = NativeScaler() 316 | 317 | if mixup_fn is not None: 318 | # smoothing is handled with mixup label transform 319 | criterion = SoftTargetCrossEntropy() 320 | elif args.smoothing > 0.: 321 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 322 | else: 323 | criterion = torch.nn.CrossEntropyLoss() 324 | 325 | print("criterion = %s" % str(criterion)) 326 | 327 | # misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 328 | misc.auto_load_model( 329 | args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 330 | 331 | if args.eval: 332 | test_stats = evaluate(data_loader_val, model, device) 333 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 334 | exit(0) 335 | 336 | # start training 337 | print(f"Start training for {args.epochs} epochs") 338 | start_time = time.time() 339 | max_accuracy = 0.0 340 | for epoch in range(args.start_epoch, args.epochs): 341 | if args.distributed: 342 | data_loader_train.sampler.set_epoch(epoch) 343 | train_stats = train_one_epoch( 344 | model, criterion, data_loader_train, 345 | optimizer, device, epoch, loss_scaler, 346 | args.clip_grad, mixup_fn, 347 | log_writer=log_writer, 348 | args=args 349 | ) 350 | 351 | # save model 352 | if args.output_dir: 353 | misc.save_model( 354 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 355 | loss_scaler=loss_scaler, epoch=epoch, latest=True) 356 | 357 | if (epoch+1)%1 == 0: 358 | test_stats = evaluate(data_loader_val, model, device) 359 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 360 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 361 | print(f'Max accuracy: {max_accuracy:.2f}%') 362 | 363 | if log_writer is not None: 364 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 365 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 366 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 367 | 368 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 369 | **{f'test_{k}': v for k, v in test_stats.items()}, 370 | 'epoch': epoch, 371 | 'n_parameters': n_parameters} 372 | 373 | if args.output_dir and misc.is_main_process(): 374 | if log_writer is not None: 375 | log_writer.flush() 376 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 377 | f.write(json.dumps(log_stats) + "\n") 378 | 379 | total_time = time.time() - start_time 380 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 381 | print('Training time {}'.format(total_time_str)) 382 | 383 | 384 | if __name__ == '__main__': 385 | args = get_args_parser() 386 | args = args.parse_args() 387 | if args.output_dir: 388 | Path(args.output_dir).mkdir(parents=True, 
exist_ok=True) 389 | main(args) 390 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. 
More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. 
For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. 
Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. 
For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # SiameseIM 3 | # Copyright (c) SenseTime. All Rights Reserved. 
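# NOTE: this module collects the distributed helpers and logging utilities
# shared by the training scripts. A minimal usage sketch for the two logging
# classes defined below:
#
#   logger = MetricLogger(delimiter="  ")
#   logger.update(loss=0.7, lr=1.5e-4)   # each keyword gets its own SmoothedValue meter
#   print(str(logger.loss))              # "0.7000 (0.7000)": windowed median (global avg)
#
# SmoothedValue keeps a fixed-size deque (default window_size=20) for the
# windowed statistics plus a running count/total for the global average;
# synchronize_between_processes() all-reduces only the count and total across
# ranks, not the deque itself.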
4 | # ------------------------------------------------------------------------ 5 | # Modified from MAE (https://github.com/facebookresearch/mae) 6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | # References: 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 11 | # ------------------------------------------------------------------------ 12 | 13 | 14 | import builtins 15 | import datetime 16 | import os 17 | import io 18 | import time 19 | from collections import defaultdict, deque 20 | from pathlib import Path 21 | 22 | import torch 23 | import torch.distributed as dist 24 | from torch._six import inf 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | 29 | class SmoothedValue(object): 30 | """Track a series of values and provide access to smoothed values over a 31 | window or the global series average. 32 | """ 33 | 34 | def __init__(self, window_size=20, fmt=None): 35 | if fmt is None: 36 | fmt = "{median:.4f} ({global_avg:.4f})" 37 | self.deque = deque(maxlen=window_size) 38 | self.total = 0.0 39 | self.count = 0 40 | self.fmt = fmt 41 | 42 | def update(self, value, n=1): 43 | self.deque.append(value) 44 | self.count += n 45 | self.total += value * n 46 | 47 | def synchronize_between_processes(self): 48 | """ 49 | Warning: does not synchronize the deque! 50 | """ 51 | if not is_dist_avail_and_initialized(): 52 | return 53 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 54 | dist.barrier() 55 | dist.all_reduce(t) 56 | t = t.tolist() 57 | self.count = int(t[0]) 58 | self.total = t[1] 59 | 60 | @property 61 | def median(self): 62 | d = torch.tensor(list(self.deque)) 63 | return d.median().item() 64 | 65 | @property 66 | def avg(self): 67 | d = torch.tensor(list(self.deque), dtype=torch.float32) 68 | return d.mean().item() 69 | 70 | @property 71 | def global_avg(self): 72 | return self.total / self.count 73 | 74 | @property 75 | def max(self): 76 | return max(self.deque) 77 | 78 | @property 79 | def value(self): 80 | return self.deque[-1] 81 | 82 | def __str__(self): 83 | return self.fmt.format( 84 | median=self.median, 85 | avg=self.avg, 86 | global_avg=self.global_avg, 87 | max=self.max, 88 | value=self.value) 89 | 90 | 91 | class MetricLogger(object): 92 | def __init__(self, delimiter="\t"): 93 | self.meters = defaultdict(SmoothedValue) 94 | self.delimiter = delimiter 95 | 96 | def update(self, **kwargs): 97 | for k, v in kwargs.items(): 98 | if v is None: 99 | continue 100 | if isinstance(v, torch.Tensor): 101 | v = v.item() 102 | assert isinstance(v, (float, int)) 103 | self.meters[k].update(v) 104 | 105 | def __getattr__(self, attr): 106 | if attr in self.meters: 107 | return self.meters[attr] 108 | if attr in self.__dict__: 109 | return self.__dict__[attr] 110 | raise AttributeError("'{}' object has no attribute '{}'".format( 111 | type(self).__name__, attr)) 112 | 113 | def __str__(self): 114 | loss_str = [] 115 | for name, meter in self.meters.items(): 116 | loss_str.append( 117 | "{}: {}".format(name, str(meter)) 118 | ) 119 | return self.delimiter.join(loss_str) 120 | 121 | def synchronize_between_processes(self): 122 | for meter in self.meters.values(): 123 | meter.synchronize_between_processes() 124 | 125 | def add_meter(self, name, meter): 126 | self.meters[name] = meter 127 | 128 | def log_every(self, iterable, print_freq, header=None): 129 | i = 
0 130 | if not header: 131 | header = '' 132 | start_time = time.time() 133 | end = time.time() 134 | iter_time = SmoothedValue(fmt='{avg:.4f}') 135 | data_time = SmoothedValue(fmt='{avg:.4f}') 136 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 137 | log_msg = [ 138 | header, 139 | '[{0' + space_fmt + '}/{1}]', 140 | 'eta: {eta}', 141 | '{meters}', 142 | 'time: {time}', 143 | 'data: {data}' 144 | ] 145 | if torch.cuda.is_available(): 146 | log_msg.append('max mem: {memory:.0f}') 147 | log_msg = self.delimiter.join(log_msg) 148 | MB = 1024.0 * 1024.0 149 | for obj in iterable: 150 | data_time.update(time.time() - end) 151 | yield obj 152 | iter_time.update(time.time() - end) 153 | if i % print_freq == 0 or i == len(iterable) - 1: 154 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 155 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 156 | if torch.cuda.is_available(): 157 | print(log_msg.format( 158 | i, len(iterable), eta=eta_string, 159 | meters=str(self), 160 | time=str(iter_time), data=str(data_time), 161 | memory=torch.cuda.max_memory_allocated() / MB)) 162 | else: 163 | print(log_msg.format( 164 | i, len(iterable), eta=eta_string, 165 | meters=str(self), 166 | time=str(iter_time), data=str(data_time))) 167 | i += 1 168 | end = time.time() 169 | total_time = time.time() - start_time 170 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 171 | print('{} Total time: {} ({:.4f} s / it)'.format( 172 | header, total_time_str, total_time / len(iterable))) 173 | 174 | 175 | def setup_for_distributed(is_master): 176 | """ 177 | This function disables printing when not in master process 178 | """ 179 | builtin_print = builtins.print 180 | 181 | def print(*args, **kwargs): 182 | force = kwargs.pop('force', False) 183 | # force = force or (get_world_size() > 8) 184 | if is_master or force: 185 | now = datetime.datetime.now().time() 186 | builtin_print('[{}] '.format(now), end='') # print with time stamp 187 | builtin_print(*args, **kwargs) 188 | 189 | builtins.print = print 190 | 191 | 192 | def is_dist_avail_and_initialized(): 193 | if not dist.is_available(): 194 | return False 195 | if not dist.is_initialized(): 196 | return False 197 | return True 198 | 199 | 200 | def get_world_size(): 201 | if not is_dist_avail_and_initialized(): 202 | return 1 203 | return dist.get_world_size() 204 | 205 | 206 | def get_rank(): 207 | if not is_dist_avail_and_initialized(): 208 | return 0 209 | return dist.get_rank() 210 | 211 | 212 | def is_main_process(): 213 | return get_rank() == 0 214 | 215 | 216 | def save_on_master(*args, **kwargs): 217 | if is_main_process(): 218 | torch.save(*args, **kwargs) 219 | 220 | 221 | def init_distributed_mode(args): 222 | if args.dist_on_itp: 223 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 224 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 225 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 226 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 227 | os.environ['LOCAL_RANK'] = str(args.gpu) 228 | os.environ['RANK'] = str(args.rank) 229 | os.environ['WORLD_SIZE'] = str(args.world_size) 230 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 231 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 232 | args.rank = int(os.environ["RANK"]) 233 | args.world_size = int(os.environ['WORLD_SIZE']) 234 | args.gpu = int(os.environ['LOCAL_RANK']) 235 | elif 'SLURM_PROCID' in os.environ: 236 | args.rank = int(os.environ['SLURM_PROCID']) 
237 | args.world_size = int(os.environ['SLURM_NTASKS']) 238 | node_list = os.environ['SLURM_STEP_NODELIST'] 239 | num_gpus = torch.cuda.device_count() 240 | args.gpu = args.rank % torch.cuda.device_count() 241 | torch.cuda.set_device(args.rank % num_gpus) 242 | import subprocess 243 | addr = subprocess.getoutput( 244 | f'scontrol show hostname {node_list} | head -n1') 245 | # specify master port 246 | if hasattr(args, 'port'): 247 | os.environ['MASTER_PORT'] = str(args.port) 248 | elif 'MASTER_PORT' in os.environ: 249 | pass # use MASTER_PORT in the environment variable 250 | else: 251 | # 29500 is torch.distributed default port 252 | os.environ['MASTER_PORT'] = '28506' 253 | # use MASTER_ADDR in the environment variable if it already exists 254 | if 'MASTER_ADDR' not in os.environ: 255 | os.environ['MASTER_ADDR'] = addr 256 | os.environ['WORLD_SIZE'] = str(args.world_size) 257 | os.environ['LOCAL_RANK'] = str(args.rank % num_gpus) 258 | os.environ['LOCAL_SIZE'] = str(num_gpus) 259 | os.environ['RANK'] = str(args.rank) 260 | # dist.init_process_group(backend='nccl') 261 | else: 262 | print('Not using distributed mode') 263 | setup_for_distributed(is_master=True) # hack 264 | args.distributed = False 265 | return 266 | 267 | args.distributed = True 268 | 269 | torch.cuda.set_device(args.gpu) 270 | args.dist_backend = 'nccl' 271 | print('| distributed init (rank {}): {}, gpu {}'.format( 272 | args.rank, args.dist_url, args.gpu), flush=True) 273 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 274 | world_size=args.world_size, rank=args.rank) 275 | torch.distributed.barrier() 276 | setup_for_distributed(args.rank == 0) 277 | 278 | 279 | class NativeScalerWithGradNormCount: 280 | state_dict_key = "amp_scaler" 281 | 282 | def __init__(self, enabled=True, growth_interval=2000): 283 | self.enabled = enabled 284 | self._scaler = torch.cuda.amp.GradScaler( 285 | enabled=enabled, growth_interval=growth_interval) 286 | 287 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 288 | self._scaler.scale(loss).backward(create_graph=create_graph) 289 | if update_grad: 290 | if clip_grad is not None: 291 | assert parameters is not None 292 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 293 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 294 | else: 295 | self._scaler.unscale_(optimizer) 296 | norm = get_grad_norm_(parameters) 297 | self._scaler.step(optimizer) 298 | self._scaler.update() 299 | else: 300 | norm = None 301 | return norm 302 | 303 | def state_dict(self): 304 | return self._scaler.state_dict() 305 | 306 | def load_state_dict(self, state_dict): 307 | if self.enabled: 308 | self._scaler.load_state_dict(state_dict) 309 | 310 | 311 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 312 | if isinstance(parameters, torch.Tensor): 313 | parameters = [parameters] 314 | parameters = [p for p in parameters if p.grad is not None] 315 | norm_type = float(norm_type) 316 | if len(parameters) == 0: 317 | return torch.tensor(0.) 
318 | device = parameters[0].grad.device 319 | if norm_type == inf: 320 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 321 | else: 322 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 323 | return total_norm 324 | 325 | 326 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, 327 | latest=False, 328 | latest_postfix='latest'): 329 | output_dir = Path(args.output_dir) 330 | epoch_name = str(epoch) 331 | if loss_scaler is not None: 332 | checkpoint_paths = [] 333 | if latest: 334 | checkpoint_paths = [output_dir / (f'checkpoint-{latest_postfix}.pth')] 335 | else: 336 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 337 | to_save = { 338 | 'model': model_without_ddp.state_dict(), 339 | 'optimizer': optimizer.state_dict(), 340 | 'epoch': epoch, 341 | 'scaler': loss_scaler.state_dict(), 342 | 'args': args, 343 | } 344 | for checkpoint_path in checkpoint_paths: 345 | save_on_master(to_save, checkpoint_path) 346 | else: 347 | client_state = {'epoch': epoch} 348 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 349 | 350 | 351 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 352 | if args.resume: 353 | if args.resume.startswith('https'): 354 | checkpoint = torch.hub.load_state_dict_from_url( 355 | args.resume, map_location='cpu', check_hash=True) 356 | else: 357 | checkpoint = torch.load(args.resume, map_location='cpu') 358 | model_without_ddp.load_state_dict(checkpoint['model']) 359 | print("Resume checkpoint %s" % args.resume) 360 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 361 | optimizer.load_state_dict(checkpoint['optimizer']) 362 | args.start_epoch = checkpoint['epoch'] + 1 363 | if 'scaler' in checkpoint: 364 | loss_scaler.load_state_dict(checkpoint['scaler']) 365 | print("With optim & sched!") 366 | 367 | 368 | def auto_load_model(args, model_without_ddp, optimizer, loss_scaler): 369 | # torch.amp 370 | output_dir = Path(args.output_dir) 371 | 372 | if args.auto_resume and len(args.resume) == 0: 373 | import glob 374 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 375 | latest_ckpt = -1 376 | for ckpt in all_checkpoints: 377 | t = ckpt.split('-')[-1].split('.')[0] 378 | if t.isdigit(): 379 | latest_ckpt = max(int(t), latest_ckpt) 380 | if latest_ckpt >= 0: 381 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 382 | if os.path.exists(os.path.join(output_dir, 'checkpoint-latest.pth')): 383 | args.resume = os.path.join(output_dir, 'checkpoint-latest.pth') 384 | print("Auto resume checkpoint: %s" % args.resume) 385 | 386 | if args.resume: 387 | if args.resume.startswith('https'): 388 | checkpoint = torch.hub.load_state_dict_from_url( 389 | args.resume, map_location='cpu', check_hash=True) 390 | else: 391 | checkpoint = torch.load(args.resume, map_location='cpu') 392 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 393 | print("Resume checkpoint %s" % args.resume) 394 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 395 | optimizer.load_state_dict(checkpoint['optimizer']) 396 | args.start_epoch = checkpoint['epoch'] + 1 397 | if 'scaler' in checkpoint: 398 | loss_scaler.load_state_dict(checkpoint['scaler']) 399 | print("With optim & sched!") 400 | 401 | 402 | def all_reduce_mean(x): 403 | world_size = get_world_size() 404 
| if world_size > 1: 405 | x_reduce = torch.tensor(x).cuda() 406 | dist.all_reduce(x_reduce) 407 | x_reduce /= world_size 408 | return x_reduce.item() 409 | else: 410 | return x 411 | 412 | 413 | class LayerNorm(nn.LayerNorm): 414 | 415 | @torch.cuda.amp.autocast(enabled=False) 416 | def forward(self, input): 417 | return super(LayerNorm, self).forward(input.float()) 418 | 419 | 420 | def add_lr_weight_decay(model, weight_decay=1e-5, lr=1e-4, skip_list=()): 421 | decay = [] 422 | no_decay = [] 423 | no_decay_names = [] 424 | decay_small_lr = [] 425 | decay_small_lr_names = [] 426 | for name, param in model.named_parameters(): 427 | if not param.requires_grad: 428 | continue # frozen weights 429 | if 'offset' in name: 430 | decay_small_lr.append(param) 431 | decay_small_lr_names.append(name) 432 | 433 | elif len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 434 | no_decay.append(param) 435 | no_decay_names.append(name) 436 | else: 437 | decay.append(param) 438 | print(f'decay_small_lr_names: {decay_small_lr_names}') 439 | print(f'no_decay_names: {no_decay_names}') 440 | return [ 441 | {'params': no_decay, 'weight_decay': 0., 'lr': lr}, 442 | {'params': decay, 'weight_decay': weight_decay, 'lr': lr}, 443 | {'params': decay_small_lr, 'weight_decay': weight_decay, 'lr': lr*0.1}, 444 | ] 445 | 446 | 447 | 448 | import math 449 | from torch.utils.data.sampler import Sampler 450 | 451 | 452 | class NodeDistributedSampler(Sampler): 453 | """Sampler that restricts data loading to a subset of the dataset. 454 | It is especially useful in conjunction with 455 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 456 | process can pass a DistributedSampler instance as a DataLoader sampler, 457 | and load a subset of the original dataset that is exclusive to it. 458 | .. note:: 459 | Dataset is assumed to be of constant size. 460 | Arguments: 461 | dataset: Dataset used for sampling. 462 | num_replicas (optional): Number of processes participating in 463 | distributed training. 464 | rank (optional): Rank of the current process within num_replicas. 
465 | """ 466 | 467 | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): 468 | if num_replicas is None: 469 | if not dist.is_available(): 470 | raise RuntimeError("Requires distributed package to be available") 471 | num_replicas = dist.get_world_size() 472 | if rank is None: 473 | if not dist.is_available(): 474 | raise RuntimeError("Requires distributed package to be available") 475 | rank = dist.get_rank() 476 | if local_rank is None: 477 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 478 | if local_size is None: 479 | local_size = int(os.environ.get('LOCAL_SIZE', 1)) 480 | self.dataset = dataset 481 | self.shuffle = shuffle 482 | self.num_replicas = num_replicas 483 | self.num_parts = local_size 484 | self.rank = rank 485 | self.local_rank = local_rank 486 | self.epoch = 0 487 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 488 | self.total_size = self.num_samples * self.num_replicas 489 | 490 | self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts 491 | 492 | def __iter__(self): 493 | if self.shuffle: 494 | # deterministically shuffle based on epoch 495 | g = torch.Generator() 496 | g.manual_seed(self.epoch) 497 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 498 | else: 499 | indices = torch.arange(len(self.dataset)).tolist() 500 | indices = [i for i in indices if i % self.num_parts == self.local_rank] 501 | 502 | # add extra samples to make it evenly divisible 503 | indices += indices[:(self.total_size_parts - len(indices))] 504 | assert len(indices) == self.total_size_parts 505 | 506 | # subsample 507 | indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts] 508 | assert len(indices) == self.num_samples 509 | 510 | return iter(indices) 511 | 512 | def __len__(self): 513 | return self.num_samples 514 | 515 | def set_epoch(self, epoch): 516 | self.epoch = epoch 517 | 518 | 519 | class GatherLayer(torch.autograd.Function): 520 | """Gather tensors from all process, supporting backward propagation. 521 | """ 522 | 523 | @staticmethod 524 | def forward(ctx, input): 525 | ctx.save_for_backward(input) 526 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] 527 | dist.all_gather(output, input) 528 | return torch.stack(output, 0) 529 | 530 | @staticmethod 531 | def backward(ctx, grads): 532 | input, = ctx.saved_tensors 533 | dist.all_reduce(grads) 534 | grad_out = torch.zeros_like(input) 535 | grad_out[:] = grads[dist.get_rank()] 536 | return grad_out 537 | 538 | 539 | class LabelSmoothingCrossEntropy(nn.Module): 540 | """ 541 | NLL loss with label smoothing. 542 | """ 543 | def __init__(self, smoothing=0.1): 544 | """ 545 | Constructor for the LabelSmoothing module. 546 | :param smoothing: label smoothing factor 547 | """ 548 | super(LabelSmoothingCrossEntropy, self).__init__() 549 | assert smoothing < 1.0 550 | self.smoothing = smoothing 551 | self.confidence = 1. 
- smoothing 552 | 553 | def forward(self, x, target, reduction='mean'): 554 | logprobs = F.log_softmax(x, dim=-1) 555 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 556 | nll_loss = nll_loss.squeeze(1) 557 | smooth_loss = -logprobs.mean(dim=-1) 558 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 559 | if reduction == 'mean': 560 | return loss.mean() 561 | elif reduction == 'none': 562 | return loss 563 | else: 564 | raise NotImplementedError 565 | 566 | 567 | class LabelSmoothingCrossEntropyWithSoftTarget(nn.Module): 568 | """ 569 | NLL loss with label smoothing, computed against soft (probability distribution) targets. 570 | """ 571 | def __init__(self, smoothing=0.1): 572 | """ 573 | Constructor for the LabelSmoothing module. 574 | :param smoothing: label smoothing factor 575 | """ 576 | super(LabelSmoothingCrossEntropyWithSoftTarget, self).__init__() 577 | assert smoothing < 1.0 578 | self.smoothing = smoothing 579 | self.confidence = 1. - smoothing 580 | 581 | def forward(self, x, target, reduction='mean'): 582 | logprobs = F.log_softmax(x, dim=-1) 583 | # nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 584 | # nll_loss = nll_loss.squeeze(1) 585 | nll_loss = - (logprobs * target).sum(dim=-1) 586 | smooth_loss = -logprobs.mean(dim=-1) 587 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 588 | if reduction == 'mean': 589 | return loss.mean() 590 | elif reduction == 'none': 591 | return loss 592 | else: 593 | raise NotImplementedError 594 | --------------------------------------------------------------------------------
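Usage sketch (editor's addition; illustrative only, not a file in the repository). util/misc.py provides the MAE-style training utilities that the rest of this codebase builds on: SmoothedValue and MetricLogger handle windowed metric logging with distributed synchronization, and NativeScalerWithGradNormCount wraps torch.cuda.amp.GradScaler so that mixed-precision scaling, optional gradient clipping, gradient-norm reporting and gradient accumulation go through a single call. The minimal sketch below shows how these pieces are typically wired together over one training epoch. The function name train_one_epoch_sketch, the accum_iter argument and the assumption that the model's forward pass returns its training loss are placeholders for illustration; they do not reproduce engine_pretrain.py line for line.

# Minimal, self-contained sketch (assumes util/misc.py is importable as `misc`).
# All model/data/optimizer objects are hypothetical placeholders.
import torch
import util.misc as misc


def train_one_epoch_sketch(model, data_loader, optimizer, device, epoch, accum_iter=1):
    model.train()
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    loss_scaler = misc.NativeScalerWithGradNormCount()
    header = 'Epoch: [{}]'.format(epoch)

    optimizer.zero_grad()
    for step, (samples, _) in enumerate(metric_logger.log_every(data_loader, 20, header)):
        samples = samples.to(device, non_blocking=True)
        with torch.cuda.amp.autocast():
            loss = model(samples)  # placeholder: assumes forward() returns the loss
        loss = loss / accum_iter
        # Scaled backward pass; the optimizer only steps on accumulation boundaries.
        grad_norm = loss_scaler(loss, optimizer,
                                parameters=model.parameters(),
                                update_grad=(step + 1) % accum_iter == 0)
        if (step + 1) % accum_iter == 0:
            optimizer.zero_grad()
        metric_logger.update(loss=loss.item() * accum_iter, grad_norm=grad_norm)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # All-reduce each meter's count/total so epoch averages cover every process.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

Because synchronize_between_processes() all-reduces each meter's running count and total, the global_avg values returned here reflect every worker rather than only rank 0; save_on_master() and is_main_process() from the same file can then be used to write checkpoints or logs from a single process.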