├── figs
│   ├── overview.png
│   └── comparison.png
├── .gitignore
├── configs
│   ├── few-shot
│   │   ├── dist_fewshot_sim_base.sh
│   │   └── slurm_fewshot_sim_base.sh
│   ├── semisup_rebuttal.sh
│   ├── semisup_sim_base_400ep.sh
│   ├── linprobe
│   │   ├── dist_linprobe_sim_base.sh
│   │   └── slurm_linprobe_sim_base.sh
│   ├── finetune
│   │   ├── dist_finetune_sim_base.sh
│   │   ├── dist_finetune_sim_base_eval.sh
│   │   ├── slurm_finetune_sim_base.sh
│   │   └── slurm_finetune_sim_base_eval.sh
│   ├── semisup_sim_large_1600ep.sh
│   └── pretrain
│       ├── dist_sim_base_1600ep.sh
│       └── slurm_sim_base_1600ep.sh
├── util
│   ├── lr_sched.py
│   ├── crop.py
│   ├── lars.py
│   ├── lr_decay.py
│   ├── datasets.py
│   ├── augmentation.py
│   ├── masking_generator.py
│   ├── pos_embed.py
│   ├── tcs_datasets.py
│   └── misc.py
├── docs
│   ├── few_shot.md
│   ├── linear_eval.md
│   ├── checkpoints.md
│   ├── pretrain.md
│   ├── finetune.md
│   └── prepare.md
├── engine_finetune.py
├── engine_pretrain.py
├── README.md
├── models_vit.py
├── main_linprobe.py
├── main_pretrain.py
├── main_logistic.py
├── main_finetune.py
└── LICENSE
/figs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/Siamese-Image-Modeling/HEAD/figs/overview.png
--------------------------------------------------------------------------------
/figs/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/Siamese-Image-Modeling/HEAD/figs/comparison.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | exp/
2 | **/__pycache__/
3 | *.pth
4 | batchscript-*
5 | phoenix-slurm-*
6 | .ipynb_checkpoints/
7 | .idea/
8 | .vscode/
9 |
--------------------------------------------------------------------------------
/configs/few-shot/dist_fewshot_sim_base.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | IP=${1}
4 | RANK=${2}
5 | NNODES=${3}
6 | CKPT_PATH=${4}
7 | DATA_PATH=${5}
8 | PORT=${PORT:-28500}
9 | PY_ARGS=${PY_ARGS:-""}
10 |
11 | BASENAME=$(basename ${CKPT_PATH})
12 | EXP_NAME=$(basename $(dirname ${CKPT_PATH}))
13 | DIR=./exp/fewshot/${EXP_NAME}
14 |
15 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=${NNODES} --node_rank=${RANK} --master_addr=${IP} --master_port=${PORT} \
16 | main_logistic.py \
17 | --subset-path imagenet_subset1/1percent.txt \
18 | --root-path ${DATA_PATH} \
19 | --image-folder imagenet_full_size/061417/ \
20 | --device cuda:0 \
21 | --pretrained ${CKPT_PATH} \
22 | --fname 'fewshot_1percent.pth' \
23 | --model-name 'vit_base_patch16' \
24 | --penalty l2 \
25 | --lambd 0.1 \
26 | --preload
--------------------------------------------------------------------------------
/configs/semisup_rebuttal.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | GPUS=${1}
4 | GPUS_PER_NODE=${2}
5 | JOB_NAME=${3}
6 | QUOTATYPE=${4}
7 | PARTITION=${5}
8 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
9 |
10 | DIR=./exp/semisup_ibot_400ep
11 | CKPT=./ckpt/ibot.pth
12 |
13 | srun --partition=vc_research_${PARTITION} \
14 | --mpi=pmi2 \
15 | --quotatype=${QUOTATYPE} \
16 | --job-name=${JOB_NAME} \
17 | -n$GPUS \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks-per-node=${GPUS_PER_NODE} \
20 | --cpus-per-task=$CPUS_PER_TASK \
21 | --kill-on-bad-exit=1 \
22 | --dependency=singleton \
23 | python -W ignore -u main_logistic.py \
24 | --subset-path imagenet_subset1/1percent.txt \
25 | --root-path /mnt/cache/share/images \
26 | --image-folder imagenet_full_size/061417/ \
27 | --device cuda:0 \
28 | --pretrained ${CKPT} \
29 | --fname 'semisup.pth' \
30 | --model-name 'vit_base_patch16' \
31 | --penalty l2 \
32 | --lambd 0.1
--------------------------------------------------------------------------------
/configs/semisup_sim_base_400ep.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | GPUS=${1}
4 | GPUS_PER_NODE=${2}
5 | JOB_NAME=${3}
6 | QUOTATYPE=${4}
7 | PARTITION=${5}
8 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
9 |
10 | DIR=./exp/semisup_sim_base_400ep
11 | CKPT=./exp/pretrain_sim_base_400ep/checkpoint-399.pth
12 |
13 | srun --partition=vc_research_${PARTITION} \
14 | --mpi=pmi2 \
15 | --quotatype=${QUOTATYPE} \
16 | --job-name=${JOB_NAME} \
17 | -n$GPUS \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks-per-node=${GPUS_PER_NODE} \
20 | --cpus-per-task=$CPUS_PER_TASK \
21 | --kill-on-bad-exit=1 \
22 | python -W ignore -u main_logistic.py \
23 | --subset-path imagenet_subset1/1percent.txt \
24 | --root-path /mnt/cache/share/images \
25 | --image-folder imagenet_full_size/061417/ \
26 | --device cuda:0 \
27 | --pretrained ${CKPT} \
28 | --fname 'semisup.pth' \
29 | --model-name 'vit_base_patch16' \
30 | --penalty l2 \
31 | --lambd 0.1
--------------------------------------------------------------------------------
/configs/linprobe/dist_linprobe_sim_base.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | IP=${1}
4 | RANK=${2}
5 | NNODES=${3}
6 | CKPT_PATH=${4}
7 | DATA_PATH=${5}
8 | PORT=${PORT:-28500}
9 | PY_ARGS=${PY_ARGS:-""}
10 |
11 | TOTAL_BATCH_SIZE=16384
12 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${NNODES}/8
13 |
14 | BASENAME=$(basename ${CKPT_PATH})
15 | EXP_NAME=$(basename $(dirname ${CKPT_PATH}))
16 | DIR=./exp/linear/${EXP_NAME}
17 |
18 | mkdir -p ${DIR}
19 |
20 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=${NNODES} --node_rank=${RANK} --master_addr=${IP} --master_port=${PORT} \
21 | main_linprobe.py \
22 | --batch_size ${BATCH_SIZE} \
23 | --model vit_base_patch16 \
24 | --finetune ${CKPT_PATH} \
25 | --epochs 90 \
26 | --blr 0.1 \
27 | --weight_decay 0.0 \
28 | --dist_eval \
29 | --output_dir ${DIR} \
30 | --log_dir ${DIR} \
31 | --global_pool \
32 | --data_path ${DATA_PATH} \
33 | --use_tcs_dataset \
34 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt
--------------------------------------------------------------------------------
/configs/finetune/dist_finetune_sim_base.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | IP=${1}
4 | RANK=${2}
5 | NNODES=${3}
6 | CKPT_PATH=${4}
7 | DATA_PATH=${5}
8 | PORT=${PORT:-28500}
9 | PY_ARGS=${PY_ARGS:-""}
10 |
11 | TOTAL_BATCH_SIZE=1024
12 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${NNODES}/8
13 |
14 | BASENAME=$(basename ${CKPT_PATH})
15 | EXP_NAME=$(basename $(dirname ${CKPT_PATH}))
16 | DIR=./exp/finetune/${EXP_NAME}
17 |
18 | mkdir -p ${DIR}
19 |
20 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=${NNODES} --node_rank=${RANK} --master_addr=${IP} --master_port=${PORT} \
21 | main_finetune.py \
22 | --output_dir ${DIR} \
23 | --log_dir ${DIR} \
24 | --batch_size ${BATCH_SIZE} \
25 | --model vit_base_patch16 \
26 | --finetune ${CKPT_PATH} \
27 | --epochs 100 \
28 | --blr 2.5e-4 --layer_decay 0.65 \
29 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
30 | --dist_eval --data_path ${DATA_PATH} \
31 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt
--------------------------------------------------------------------------------
/configs/few-shot/slurm_fewshot_sim_base.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | GPUS=${1}
4 | GPUS_PER_NODE=${2}
5 | QUOTATYPE=${3}
6 | PARTITION=${4}
7 | CKPT_PATH=${5}
8 | DATA_PATH=${6}
9 | CPUS_PER_TASK=${CPUS_PER_TASK:-12}
10 |
11 | BASENAME=$(basename ${CKPT_PATH})
12 | EXP_NAME=$(basename $(dirname ${CKPT_PATH}))
13 | DIR=./exp/fewshot/${EXP_NAME}
14 | JOB_NAME=fewshot-${EXP_NAME}
15 |
16 | srun --partition=${PARTITION} \
17 | --mpi=pmi2 \
18 | --quotatype=${QUOTATYPE} \
19 | --job-name=${JOB_NAME} \
20 | -n$GPUS \
21 | --gres=gpu:${GPUS_PER_NODE} \
22 | --ntasks-per-node=${GPUS_PER_NODE} \
23 | --cpus-per-task=$CPUS_PER_TASK \
24 | --kill-on-bad-exit=1 \
25 | python -W ignore -u main_logistic.py \
26 | --subset-path imagenet_subset1/1percent.txt \
27 | --root-path ${DATA_PATH} \
28 | --image-folder imagenet_full_size/061417/ \
29 | --device cuda:0 \
30 | --pretrained ${CKPT_PATH} \
31 | --fname 'fewshot_1percent.pth' \
32 | --model-name 'vit_base_patch16' \
33 | --penalty l2 \
34 | --lambd 0.1
--------------------------------------------------------------------------------
/configs/finetune/dist_finetune_sim_base_eval.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | IP=${1}
4 | RANK=${2}
5 | NNODES=${3}
6 | CKPT_PATH=${4}
7 | DATA_PATH=${5}
8 | PORT=${PORT:-28500}
9 | PY_ARGS=${PY_ARGS:-""}
10 |
11 | TOTAL_BATCH_SIZE=1024
12 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${NNODES}/8
13 |
14 | BASENAME=$(basename ${CKPT_PATH})
15 | EXP_NAME=$(basename $(dirname ${CKPT_PATH}))
16 | DIR=./exp/finetune/${EXP_NAME}
17 |
18 | mkdir -p ${DIR}
19 |
20 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=${NNODES} --node_rank=${RANK} --master_addr=${IP} --master_port=${PORT} \
21 | main_finetune.py \
22 | --output_dir ${DIR} \
23 | --log_dir ${DIR} \
24 | --batch_size ${BATCH_SIZE} \
25 | --model vit_base_patch16 \
26 | --resume ${CKPT_PATH} \
27 | --epochs 100 \
28 | --blr 2.5e-4 --layer_decay 0.65 \
29 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
30 | --dist_eval --data_path ${DATA_PATH} \
31 | --eval \
32 | --use_tcs_dataset \
33 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt
--------------------------------------------------------------------------------
/configs/semisup_sim_large_1600ep.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | GPUS=${1}
4 | GPUS_PER_NODE=${2}
5 | JOB_NAME=${3}
6 | QUOTATYPE=${4}
7 | PARTITION=${5}
8 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
9 |
10 | DIR=./exp/semisup_sim_large_1600ep
11 | CKPT=./exp/pretrain_sim_large_1600ep/checkpoint-latest.pth
12 |
13 | srun --partition=vc_research_${PARTITION} \
14 | --mpi=pmi2 \
15 | --quotatype=${QUOTATYPE} \
16 | --job-name=${JOB_NAME} \
17 | -n$GPUS \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks-per-node=${GPUS_PER_NODE} \
20 | --cpus-per-task=$CPUS_PER_TASK \
21 | --kill-on-bad-exit=1 \
22 | --dependency=singleton \
23 | -x SH-IDC1-10-142-5-[45,13,70,198],SH-IDC1-10-142-4-[187,93,188,46,165,83,151,146,26] \
24 | python -W ignore -u main_logistic.py \
25 | --subset-path imagenet_subset1/1percent.txt \
26 | --root-path /mnt/cache/share/images \
27 | --image-folder imagenet_full_size/061417/ \
28 | --device cuda:0 \
29 | --pretrained ${CKPT} \
30 | --fname 'semisup.pth' \
31 | --model-name 'vit_large_patch16' \
32 | --penalty l2 \
33 | --lambd 0.01
--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 |
9 |
10 | import math
11 |
12 | def adjust_learning_rate(optimizer, epoch, args):
13 | """Decay the learning rate with half-cycle cosine after warmup"""
14 | if epoch < args.warmup_epochs:
15 | lr = args.lr * epoch / args.warmup_epochs
16 | else:
17 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
18 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
19 | for param_group in optimizer.param_groups:
20 | if "lr_scale" in param_group:
21 | param_group["lr"] = lr * param_group["lr_scale"]
22 | else:
23 | param_group["lr"] = lr
24 | return lr
25 |
--------------------------------------------------------------------------------
/configs/pretrain/dist_sim_base_1600ep.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | IP=${1}
4 | RANK=${2}
5 | NNODES=${3}
6 | DATA_PATH=${4}
7 | PORT=${PORT:-28500}
8 | PY_ARGS=${PY_ARGS:-""}
9 |
10 | BASENAME=`basename ${0} .sh`
11 | DIR=./exp/pretrain/${BASENAME}
12 | mkdir -p ${DIR}
13 |
14 | TOTAL_BATCH_SIZE=4096
15 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${NNODES}/8
16 |
17 | EPOCHS=1600
18 |
19 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=${NNODES} --node_rank=${RANK} --master_addr=${IP} --master_port=${PORT} \
20 | main_pretrain.py \
21 | --model sim_vit_base_patch16 \
22 | --decoder_embed_dim 768 \
23 | --batch_size ${BATCH_SIZE} \
24 | --epochs ${EPOCHS} \
25 | --warmup_epochs 40 \
26 | --crop_min 0.08 \
27 | --with_blockwise_mask \
28 | --blockwise_num_masking_patches 118 \
29 | --blr 6.25e-5 --weight_decay 0.05 \
30 | --mm 0.995 \
31 | --mmschedule 'cosine' \
32 | --clip_grad 1.0 \
33 | --loss_type 'sim' \
34 | --neg_weight 0.02 \
35 | --save_latest_freq 5 \
36 | --output_dir ${DIR} \
37 | --log_dir ${DIR} \
38 | --data_path ${DATA_PATH} \
39 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt
40 |
--------------------------------------------------------------------------------
/docs/few_shot.md:
--------------------------------------------------------------------------------
1 | # Few-shot Evaluation
2 |
3 | We provide the few-shot evaluation scripts here. Only 1% of the labelled ImageNet data is used to train the classifier. Following [MSN](https://github.com/facebookresearch/msn/blob/main/logistic_eval.py), we train a linear classifier on the frozen representation without tuning the model's parameters (a sketch of this protocol is given at the end of this page).
4 |
5 | ## Train with torch.distributed.launch
6 | Few-shot evaluation does not require much compute, so running the script on a single node is sufficient, as shown below.
7 |
8 | ```
9 | sh ./configs/few-shot/dist_fewshot_sim_base.sh ${MASTER_ADDR} 0 1 ${CKPT_PATH} ${DATA_PATH}
10 | ```
11 |
12 | Note:
13 | `${MASTER_ADDR}` is the IP address of the rank 0 node. The second and third arguments specify the node rank and the total number of nodes respectively; adjust them if a different number of nodes is used.
14 |
15 | ## Train on a slurm cluster
16 | If you need to run the few-shot evaluation on a slurm cluster, use the command below to run on `${GPUS}/${GPUS_PER_NODE}` nodes with `${GPUS_PER_NODE}` GPUs on each node:
17 | ```
18 | sh ./configs/few-shot/slurm_fewshot_sim_base.sh ${GPUS} ${GPUS_PER_NODE} ${QUOTATYPE} ${PARTITION} ${CKPT_PATH} ${DATA_PATH}
19 | ```
20 |
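21 | ## Protocol sketch
22 | 
23 | The protocol freezes the pretrained backbone, extracts features on the 1% labelled subset, and fits an L2-regularized logistic regression on top. The sketch below illustrates this with scikit-learn as a stand-in for the cyanure solver used by `main_logistic.py`; the `model`, the data loaders and the `C = 1 / lambd` mapping are illustrative assumptions, not the exact implementation.
24 | 
25 | ```
26 | import torch
27 | from sklearn.linear_model import LogisticRegression
28 | 
29 | @torch.no_grad()
30 | def extract_features(model, loader, device='cuda'):
31 |     # Freeze the backbone and collect (N, D) features plus labels.
32 |     model.eval()
33 |     feats, labels = [], []
34 |     for images, targets in loader:
35 |         f = model(images.to(device, non_blocking=True))
36 |         feats.append(f.cpu())
37 |         labels.append(targets)
38 |     return torch.cat(feats).numpy(), torch.cat(labels).numpy()
39 | 
40 | def fewshot_eval(model, train_loader, val_loader, lambd=0.1):
41 |     # Fit an L2-regularized linear classifier on the frozen features.
42 |     x_tr, y_tr = extract_features(model, train_loader)
43 |     x_va, y_va = extract_features(model, val_loader)
44 |     clf = LogisticRegression(C=1.0 / lambd, max_iter=1000)
45 |     clf.fit(x_tr, y_tr)
46 |     return clf.score(x_va, y_va)
47 | ```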
--------------------------------------------------------------------------------
/docs/linear_eval.md:
--------------------------------------------------------------------------------
1 | # Linear Evaluation
2 |
3 | We provide the linear evaluation scripts here. The evaluation setting mainly follows MAE, using a batch size of 16384 and the LARS optimizer.
4 |
5 | ## Train with torch.distributed.launch
6 | This method supports training on multi-nodes with torch.distributed.launch. For example, to conduct linear evaluation on 2 nodes, run the command below.
7 |
8 | On node 1:
9 | ```
10 | sh ./configs/linprobe/dist_linprobe_sim_base.sh ${MASTER_ADDR} 0 2 ${CKPT_PATH} ${DATA_PATH}
11 | ```
12 |
13 | On node 2:
14 | ```
15 | sh ./configs/linprobe/dist_linprobe_sim_base.sh ${MASTER_ADDR} 1 2 ${CKPT_PATH} ${DATA_PATH}
16 | ```
17 |
18 | Note:
19 | `${MASTER_ADDR}` is the IP address of the rank 0 node. The second and third arguments specify the node rank and the total number of nodes respectively; adjust them if a different number of nodes is used.
20 |
21 | ## Train on a slurm cluster
22 | If you need to run the linear evaluation on a slurm cluster, use the command below to run on `${GPUS}/${GPUS_PER_NODE}` nodes with `${GPUS_PER_NODE}` GPUs on each node:
23 | ```
24 | sh ./configs/linprobe/slurm_linprobe_sim_base.sh ${GPUS} ${GPUS_PER_NODE} ${QUOTATYPE} ${PARTITION} ${CKPT_PATH} ${DATA_PATH}
25 | ```
26 |
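27 | ## Learning-rate scaling
28 | 
29 | For reference, the sketch below shows the learning-rate scaling that this setting implies, assuming the MAE convention of scaling the base learning rate by the effective batch size (check `main_linprobe.py` for the exact rule used in this repo).
30 | 
31 | ```
32 | # MAE-style scaling rule (assumed): lr = blr * effective_batch_size / 256
33 | blr = 0.1                # value passed by the linprobe scripts
34 | eff_batch_size = 16384   # total batch size across all GPUs
35 | lr = blr * eff_batch_size / 256
36 | print(lr)                # 6.4, the absolute lr used by the LARS optimizer
37 | ```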
--------------------------------------------------------------------------------
/docs/checkpoints.md:
--------------------------------------------------------------------------------
1 | # Checkpoints
2 | We provide links for you to download the checkpoints of SiameseIM models here.
3 |
4 |
5 |
6 |
7 | | Model | Backbone | Pretrained Epoch | Finetuned on ImageNet | Link |
8 |
9 |
10 | | SiameseIM | ViT-Base | 1600 | w/o | Download |
11 |
12 |
13 | | SiameseIMft | ViT-Base | 1600 | w/ | Download |
14 |
15 |
16 |
17 |
18 | * The SiameseIM model is only pretrained on ImageNet datasets for 1600 epochs. For pretraining details, see [pretrain.md](./pretrain.md).
19 | * The SiameseIM$`_{\mathrm{ft}}`$ model is first pretrained for 1600 epochs and then finetuned on the ImageNet classification task for 100 epochs. For finetuning details, see [finetune.md](./finetune.md).
20 | * More pre-trained weights will be released.
21 |
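22 | The snippet below is a minimal sketch of loading a downloaded pretrained checkpoint into the classification backbone defined in `models_vit.py`. The file name is a placeholder, and the assumption that the weights live under a `model` key mirrors what the finetuning entry point expects via `--finetune`; adjust both if your download differs.
23 | 
24 | ```
25 | import torch
26 | import models_vit
27 | 
28 | # Placeholder file name; replace with the downloaded checkpoint.
29 | ckpt = torch.load('siameseim_vit_base_1600ep.pth', map_location='cpu')
30 | state_dict = ckpt.get('model', ckpt)
31 | 
32 | model = models_vit.vit_base_patch16(num_classes=1000, global_pool=True)
33 | msg = model.load_state_dict(state_dict, strict=False)
34 | # The classification head is expected to be missing for the pretrained model.
35 | print(msg.missing_keys)
36 | ```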
--------------------------------------------------------------------------------
/configs/linprobe/slurm_linprobe_sim_base.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | GPUS=${1}
4 | GPUS_PER_NODE=${2}
5 | QUOTATYPE=${3}
6 | PARTITION=${4}
7 | CKPT_PATH=${5}
8 | DATA_PATH=${6}
9 | CPUS_PER_TASK=${CPUS_PER_TASK:-12}
10 | SRUN_ARGS=${SRUN_ARGS:-""}
11 | PY_ARGS=${PY_ARGS:-""}
12 |
13 |
14 | TOTAL_BATCH_SIZE=16384
15 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${GPUS}
16 |
17 | BASENAME=$(basename ${CKPT_PATH})
18 | EXP_NAME=$(basename $(dirname ${CKPT_PATH}))
19 | DIR=./exp/linear/${EXP_NAME}
20 | JOB_NAME=lin-${EXP_NAME}
21 |
22 | mkdir -p ${DIR}
23 |
24 | srun --partition=${PARTITION} \
25 | --mpi=pmi2 \
26 | --open-mode=append \
27 | --quotatype=${QUOTATYPE} \
28 | --job-name=${JOB_NAME} \
29 | -n$GPUS \
30 | --gres=gpu:${GPUS_PER_NODE} \
31 | --ntasks-per-node=${GPUS_PER_NODE} \
32 | --cpus-per-task=$CPUS_PER_TASK \
33 | --kill-on-bad-exit=1 \
34 | ${SRUN_ARGS} \
35 | python -u main_linprobe.py \
36 | --batch_size ${BATCH_SIZE} \
37 | --model vit_base_patch16 \
38 | --finetune ${CKPT_PATH} \
39 | --epochs 90 \
40 | --blr 0.1 \
41 | --weight_decay 0.0 \
42 | --dist_eval \
43 | --output_dir ${DIR} \
44 | --log_dir ${DIR} \
45 | --global_pool \
46 | --data_path ${DATA_PATH} \
47 | --use_tcs_dataset \
48 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt
--------------------------------------------------------------------------------
/configs/finetune/slurm_finetune_sim_base.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | GPUS=${1}
4 | GPUS_PER_NODE=${2}
5 | QUOTATYPE=${3}
6 | PARTITION=${4}
7 | CPUS_PER_TASK=${CPUS_PER_TASK:-12}
8 | CKPT_PATH=${5}
9 | DATA_PATH=${6}
10 | SRUN_ARGS=${SRUN_ARGS:-""}
11 | PY_ARGS=${PY_ARGS:-""}
12 |
13 |
14 | TOTAL_BATCH_SIZE=1024
15 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${GPUS}
16 |
17 | BASENAME=$(basename ${CKPT_PATH})
18 | EXP_NAME=$(basename $(dirname ${CKPT_PATH}))
19 | DIR=./exp/finetune/${EXP_NAME}
20 | JOB_NAME=ft-${EXP_NAME}
21 |
22 | mkdir -p ${DIR}
23 |
24 | srun --partition=${PARTITION} \
25 | --mpi=pmi2 \
26 | --quotatype=${QUOTATYPE} \
27 | --job-name=${JOB_NAME} \
28 | -n$GPUS \
29 | --gres=gpu:${GPUS_PER_NODE} \
30 | --ntasks-per-node=${GPUS_PER_NODE} \
31 | --cpus-per-task=$CPUS_PER_TASK \
32 | --kill-on-bad-exit=1 \
33 | ${SRUN_ARGS} \
34 | python -u main_finetune.py \
35 | --output_dir ${DIR} \
36 | --log_dir ${DIR} \
37 | --batch_size ${BATCH_SIZE} \
38 | --model vit_base_patch16 \
39 | --finetune ${CKPT_PATH} \
40 | --epochs 100 \
41 | --blr 2.5e-4 --layer_decay 0.65 \
42 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
43 | --dist_eval --data_path ${DATA_PATH} \
44 | --use_tcs_dataset \
45 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt
46 |
--------------------------------------------------------------------------------
/configs/finetune/slurm_finetune_sim_base_eval.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | GPUS=${1}
4 | GPUS_PER_NODE=${2}
5 | QUOTATYPE=${3}
6 | PARTITION=${4}
7 | CPUS_PER_TASK=${CPUS_PER_TASK:-12}
8 | CKPT_PATH=${5}
9 | DATA_PATH=${6}
10 | SRUN_ARGS=${SRUN_ARGS:-""}
11 | PY_ARGS=${PY_ARGS:-""}
12 |
13 |
14 | TOTAL_BATCH_SIZE=1024
15 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${GPUS}
16 |
17 | BASENAME=$(basename ${CKPT_PATH})
18 | EXP_NAME=$(basename $(dirname ${CKPT_PATH}))
19 | DIR=./exp/finetune/${EXP_NAME}
20 | JOB_NAME=ft-${EXP_NAME}
21 |
22 | mkdir -p ${DIR}
23 |
24 | srun --partition=${PARTITION} \
25 | --mpi=pmi2 \
26 | --quotatype=${QUOTATYPE} \
27 | --job-name=${JOB_NAME} \
28 | -n$GPUS \
29 | --gres=gpu:${GPUS_PER_NODE} \
30 | --ntasks-per-node=${GPUS_PER_NODE} \
31 | --cpus-per-task=$CPUS_PER_TASK \
32 | --kill-on-bad-exit=1 \
33 | ${SRUN_ARGS} \
34 | python -u main_finetune.py \
35 | --output_dir ${DIR} \
36 | --log_dir ${DIR} \
37 | --batch_size ${BATCH_SIZE} \
38 | --model vit_base_patch16 \
39 | --resume ${CKPT_PATH} \
40 | --epochs 100 \
41 | --blr 2.5e-4 --layer_decay 0.65 \
42 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
43 | --dist_eval --data_path ${DATA_PATH} \
44 | --eval \
45 | --use_tcs_dataset \
46 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt
47 |
--------------------------------------------------------------------------------
/configs/pretrain/slurm_sim_base_1600ep.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | GPUS=${1}
4 | GPUS_PER_NODE=${2}
5 | JOB_NAME=${3}
6 | QUOTATYPE=${4}
7 | PARTITION=${5}
8 | DATA_PATH=${6}
9 | CPUS_PER_TASK=${CPUS_PER_TASK:-8}
10 | SRUN_ARGS=${SRUN_ARGS:-""}
11 | PY_ARGS=${PY_ARGS:-""}
12 |
13 | BASENAME=`basename ${0} .sh`
14 | DIR=./exp/pretrain/${BASENAME}
15 | mkdir -p ${DIR}
16 |
17 | TOTAL_BATCH_SIZE=4096
18 | let BATCH_SIZE=${TOTAL_BATCH_SIZE}/${GPUS}
19 |
20 | EPOCHS=1600
21 |
22 | srun --partition=${PARTITION} \
23 | --mpi=pmi2 \
24 | --quotatype=${QUOTATYPE} \
25 | --job-name=${JOB_NAME} \
26 | -n$GPUS \
27 | --gres=gpu:${GPUS_PER_NODE} \
28 | --ntasks-per-node=${GPUS_PER_NODE} \
29 | --cpus-per-task=$CPUS_PER_TASK \
30 | --kill-on-bad-exit=1 \
31 | ${SRUN_ARGS} \
32 | python -u main_pretrain.py \
33 | --model sim_vit_base_patch16 \
34 | --decoder_embed_dim 768 \
35 | --batch_size ${BATCH_SIZE} \
36 | --epochs ${EPOCHS} \
37 | --warmup_epochs 40 \
38 | --crop_min 0.08 \
39 | --with_blockwise_mask \
40 | --blockwise_num_masking_patches 118 \
41 | --blr 6.25e-5 --weight_decay 0.05 \
42 | --mm 0.995 \
43 | --mmschedule 'cosine' \
44 | --clip_grad 1.0 \
45 | --loss_type 'sim' \
46 | --neg_weight 0.02 \
47 | --save_latest_freq 5 \
48 | --output_dir ${DIR} \
49 | --log_dir ${DIR} \
50 | --data_path ${DATA_PATH} \
51 | ${PY_ARGS} 2>&1 | tee -a ${DIR}/stdout.txt
52 |
--------------------------------------------------------------------------------
/docs/pretrain.md:
--------------------------------------------------------------------------------
1 | # Pretrain
2 |
3 | We provide the pretraining scripts here. To pretrain a SiameseIM model, we recommend
4 | * using a total batch size of 4096, which should fit into 32 V100 GPUs with 32G memory (see the batch-size sketch at the end of this page);
5 | * pretraining for 1600 epochs for better performance. We note that pretraining SiameseIM for 400 epochs can already match the performance of MAE pretrained for 1600 epochs on some tasks.
6 | * The 1600-epoch pretrained checkpoint is provided in [checkpoints.md](./checkpoints.md).
7 |
8 | ## Train with torch.distributed.launch
9 | This method supports training on multi-nodes with torch.distributed.launch. For example, to pretrain a SiameseIM model on 2 nodes, run the command below.
10 |
11 | On node 1:
12 | ```
13 | sh ./configs/pretrain/dist_sim_base_1600ep.sh ${MASTER_ADDR} 0 2 ${DATA_PATH}
14 | ```
15 |
16 | On node 2:
17 | ```
18 | sh ./configs/pretrain/dist_sim_base_1600ep.sh ${MASTER_ADDR} 1 2 ${DATA_PATH}
19 | ```
20 |
21 | Note:
22 | `${MASTER_ADDR}` is the IP address of the rank 0 node. The second and third arguments specify the node rank and the total number of nodes respectively; adjust them if a different number of nodes is used.
23 |
24 | ## Train on a slurm cluster
25 | If you need to run the pretraining on a slurm cluster, use the command below to run on `${GPUS}/${GPUS_PER_NODE}` nodes with `${GPUS_PER_NODE}` GPUs on each node:
26 | ```
27 | sh ./configs/pretrain/slurm_sim_base_1600ep.sh ${GPUS} ${GPUS_PER_NODE} ${JOB_NAME} ${QUOTATYPE} ${PARTITION} ${DATA_PATH}
28 | ```
29 |
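30 | ## Batch-size sketch
31 | 
32 | The launch scripts derive the per-GPU batch size from the total batch size, the number of nodes and the 8 GPUs per node hard-coded in `dist_sim_base_1600ep.sh`. A small worked example (the node count is illustrative):
33 | 
34 | ```
35 | total_batch_size = 4096
36 | nnodes = 4                                 # e.g. 32 V100 GPUs = 4 nodes x 8 GPUs
37 | per_gpu = total_batch_size // (nnodes * 8)
38 | print(per_gpu)                             # 128 images per GPU per iteration
39 | ```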
--------------------------------------------------------------------------------
/util/crop.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 |
9 | import math
10 |
11 | import torch
12 |
13 | from torchvision import transforms
14 | from torchvision.transforms import functional as F
15 |
16 |
17 | class RandomResizedCrop(transforms.RandomResizedCrop):
18 | """
19 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
20 | This may lead to results different with torchvision's version.
21 | Following BYOL's TF code:
22 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
23 | """
24 | @staticmethod
25 | def get_params(img, scale, ratio):
26 | width, height = F.get_image_size(img)
27 | area = height * width
28 |
29 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
30 | log_ratio = torch.log(torch.tensor(ratio))
31 | aspect_ratio = torch.exp(
32 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
33 | ).item()
34 |
35 | w = int(round(math.sqrt(target_area * aspect_ratio)))
36 | h = int(round(math.sqrt(target_area / aspect_ratio)))
37 |
38 | w = min(w, width)
39 | h = min(h, height)
40 |
41 | i = torch.randint(0, height - h + 1, size=(1,)).item()
42 | j = torch.randint(0, width - w + 1, size=(1,)).item()
43 |
44 | return i, j, h, w
45 |
--------------------------------------------------------------------------------
/docs/finetune.md:
--------------------------------------------------------------------------------
1 | # Finetune
2 |
3 | We provide the finetuning scripts here. To finetune a SiameseIM model, we recommend
4 | * using a total batch size of 1024, which should fit into 8 V100 GPUs with 32G memory (see the learning-rate sketch at the end of this page).
5 | * The finetuned checkpoint is provided in [checkpoints.md](./checkpoints.md).
6 |
7 | ## Train with torch.distributed.launch
8 | This method supports training on multi-nodes with torch.distributed.launch. For example, to finetune a SiameseIM model on 2 nodes, run the command below.
9 |
10 | On node 1:
11 | ```
12 | sh ./configs/finetune/dist_finetune_sim_base.sh ${MASTER_ADDR} 0 2 ${CKPT_PATH} ${DATA_PATH}
13 | ```
14 |
15 | On node 2:
16 | ```
17 | sh ./configs/finetune/dist_finetune_sim_base.sh ${MASTER_ADDR} 1 2 ${CKPT_PATH} ${DATA_PATH}
18 | ```
19 |
20 | Note:
21 | `${MASTER_ADDR}` is the IP address of the rank 0 node. The second and third arguments specify the node rank and the total number of nodes respectively; adjust them if a different number of nodes is used.
22 |
23 | ## Train on a slurm cluster
24 | If you need to run the finetuning on a slurm cluster, use the command below to run on `${GPUS}/${GPUS_PER_NODE}` nodes with `${GPUS_PER_NODE}` GPUs on each node:
25 | ```
26 | sh ./configs/finetune/slurm_finetune_sim_base.sh ${GPUS} ${GPUS_PER_NODE} ${QUOTATYPE} ${PARTITION} ${CKPT_PATH} ${DATA_PATH}
27 | ```
28 |
29 | ## Evaluation
30 | We also provide the evaluation scripts as follows.
31 |
32 | For torch.distributed.launch, use
33 | ```
34 | sh ./configs/finetune/dist_finetune_sim_base_eval.sh ${MASTER_ADDR} 0 1 ${CKPT_PATH} ${DATA_PATH}
35 | ```
36 |
37 | For slurm launch, use
38 | ```
39 | sh ./configs/finetune/slurm_finetune_sim_base_eval.sh ${GPUS} ${GPUS_PER_NODE} ${QUOTATYPE} ${PARTITION} ${CKPT_PATH} ${DATA_PATH}
40 | ```
41 | You should get
42 | ```
43 | * Acc@1 84.118 Acc@5 96.766 loss 0.728
44 | ```
45 | for the provided checkpoint.
46 |
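47 | ## Learning-rate sketch
48 | 
49 | The sketch below shows how the finetuning hyper-parameters above translate into per-layer learning rates: the base lr is scaled by the effective batch size (assuming the MAE convention; see `main_finetune.py`), and every transformer block is further scaled by the layer-wise decay factor implemented in `util/lr_decay.py`.
50 | 
51 | ```
52 | blr, eff_batch_size = 2.5e-4, 1024
53 | lr = blr * eff_batch_size / 256          # 1e-3 peak learning rate
54 | 
55 | layer_decay = 0.65
56 | num_layers = 12 + 1                      # ViT-Base blocks + 1, as in util/lr_decay.py
57 | scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]
58 | # scales[0] applies to patch embedding / cls token, scales[-1] = 1.0 to the head
59 | print(lr, scales[0], scales[-1])
60 | ```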
--------------------------------------------------------------------------------
/docs/prepare.md:
--------------------------------------------------------------------------------
1 | # Preparation
2 |
3 | * The only dataset required in this repo is ImageNet, which is enough for pretraining, finetuning, linear evaluation and few-shot evaluation. If you want to evaluate on COCO, LVIS, ADE20k and robustness datasets, please follow the corresponding repos to prepare the data.
4 |
5 | ## Installation
6 |
7 | * Python >=3.7
8 | * We recommend PyTorch 1.11 for faster training.
9 | * timm == 0.6.12
10 | * numpy == 1.21.5
11 | * tensorboard
12 |
13 | To run the few-shot evaluation, the [cyanure](https://github.com/inria-thoth/cyanure) package is additionally required. You can install it with
14 | ```
15 | pip install cyanure-openblas
16 | # or pip install cyanure-mkl
17 | ```
18 |
19 | ## Data preparation
20 |
21 | Download and extract ImageNet train and val images from http://image-net.org/.
22 | The directory structure is the standard layout expected by torchvision's [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder), with the training and validation data in the `train/` and `val/` folders respectively:
23 |
24 | ```
25 | /path/to/imagenet/
26 | ├── train/
27 | │   ├── class1/
28 | │   │   ├── img1.JPEG
29 | │   │   ├── img2.JPEG
30 | │   │   ├── img3.JPEG
31 | │   │   └── ...
32 | │   ├── class2/
33 | │   │   └── ...
34 | │   ├── class3/
35 | │   │   └── ...
36 | │   └── ...
37 | └── val/
38 |     ├── class1/
39 |     │   ├── img4.JPEG
40 |     │   ├── img5.JPEG
41 |     │   ├── img6.JPEG
42 |     │   └── ...
43 |     ├── class2/
44 |     │   └── ...
45 |     └── class3/
46 |         └── ...
47 | ```
48 |
49 | Note that the raw val images are not organized into class folders; use [this script](https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh) to obtain the correct layout.
50 |
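51 | The snippet below is a quick sanity check (not part of the repo) that the layout above is readable by torchvision's `ImageFolder`, which the dataset classes in `util/datasets.py` build upon; the path and transform are illustrative.
52 | 
53 | ```
54 | from torchvision import datasets, transforms
55 | 
56 | transform = transforms.Compose([
57 |     transforms.Resize(256),
58 |     transforms.CenterCrop(224),
59 |     transforms.ToTensor(),
60 | ])
61 | dataset = datasets.ImageFolder('/path/to/imagenet/train', transform=transform)
62 | print(len(dataset), len(dataset.classes))   # ~1.28M images, 1000 classes
63 | ```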
--------------------------------------------------------------------------------
/util/lars.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 | # LARS optimizer, implementation from MoCo v3:
9 | # https://github.com/facebookresearch/moco-v3
10 | # ------------------------------------------------------------------------
11 |
12 |
13 | import torch
14 |
15 |
16 | class LARS(torch.optim.Optimizer):
17 | """
18 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
19 | """
20 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
21 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
22 | super().__init__(params, defaults)
23 |
24 | @torch.no_grad()
25 | def step(self):
26 | for g in self.param_groups:
27 | for p in g['params']:
28 | dp = p.grad
29 |
30 | if dp is None:
31 | continue
32 |
33 | if p.ndim > 1: # if not normalization gamma/beta or bias
34 | dp = dp.add(p, alpha=g['weight_decay'])
35 | param_norm = torch.norm(p)
36 | update_norm = torch.norm(dp)
37 | one = torch.ones_like(param_norm)
38 | q = torch.where(param_norm > 0.,
39 | torch.where(update_norm > 0,
40 | (g['trust_coefficient'] * param_norm / update_norm), one),
41 | one)
42 | dp = dp.mul(q)
43 |
44 | param_state = self.state[p]
45 | if 'mu' not in param_state:
46 | param_state['mu'] = torch.zeros_like(p)
47 | mu = param_state['mu']
48 | mu.mul_(g['momentum']).add_(dp)
49 | p.add_(mu, alpha=-g['lr'])
50 |
--------------------------------------------------------------------------------
/util/lr_decay.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 | # References:
9 | # ELECTRA https://github.com/google-research/electra
10 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
11 | # ------------------------------------------------------------------------
12 |
13 | import json
14 |
15 |
16 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75,
17 | small_lr_keywords=('offset',), small_lr_ratio=0.1):
18 | """
19 | Parameter groups for layer-wise lr decay
20 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
21 | """
22 | param_group_names = {}
23 | param_groups = {}
24 |
25 | if hasattr(model, 'blocks'):
26 | num_layers = len(model.blocks) + 1
27 | else:
28 | raise NotImplementedError
29 |
30 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
31 |
32 | for n, p in model.named_parameters():
33 | if not p.requires_grad:
34 | continue
35 |
36 | small_lr = False
37 | for small_lr_keyword in small_lr_keywords:
38 | if small_lr_keyword in n:
39 | g_decay = 'decay_small_lr'
40 | this_decay = weight_decay
41 | layer_id = get_layer_id_for_vit(n, num_layers)
42 | group_name = "layer_%d_%s" % (layer_id, g_decay)
43 | small_lr = True
44 | this_scale = layer_scales[layer_id] * 0.1
45 |
46 | if not small_lr:
47 | # no decay: all 1D parameters and model specific ones
48 | if p.ndim == 1 or n in no_weight_decay_list:
49 | g_decay = "no_decay"
50 | this_decay = 0.
51 | else:
52 | g_decay = "decay"
53 | this_decay = weight_decay
54 |
55 | layer_id = get_layer_id_for_vit(n, num_layers)
56 | group_name = "layer_%d_%s" % (layer_id, g_decay)
57 | this_scale = layer_scales[layer_id]
58 |
59 | if group_name not in param_group_names:
60 | param_group_names[group_name] = {
61 | "lr_scale": this_scale,
62 | "weight_decay": this_decay,
63 | "params": [],
64 | }
65 | param_groups[group_name] = {
66 | "lr_scale": this_scale,
67 | "weight_decay": this_decay,
68 | "params": [],
69 | }
70 |
71 | param_group_names[group_name]["params"].append(n)
72 | param_groups[group_name]["params"].append(p)
73 |
74 | print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
75 |
76 | return list(param_groups.values())
77 |
78 |
79 | def get_layer_id_for_vit(name, num_layers):
80 | """
81 | Assign a parameter with its layer id
82 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
83 | """
84 | if name in ['cls_token', 'pos_embed']:
85 | return 0
86 | elif name.startswith('patch_embed'):
87 | return 0
88 | elif name.startswith('blocks'):
89 | return int(name.split('.')[1]) + 1
90 | else:
91 | return num_layers
92 |
--------------------------------------------------------------------------------
/util/datasets.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 | # References:
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # ------------------------------------------------------------------------
11 |
12 | import os
13 | import PIL
14 |
15 | from torchvision import datasets, transforms
16 |
17 | from timm.data import create_transform
18 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
19 |
20 |
21 | def build_transform(is_train, args):
22 | mean = IMAGENET_DEFAULT_MEAN
23 | std = IMAGENET_DEFAULT_STD
24 | # train transform
25 | if is_train:
26 | # this should always dispatch to transforms_imagenet_train
27 | transform = create_transform(
28 | input_size=args.input_size,
29 | is_training=True,
30 | color_jitter=args.color_jitter,
31 | auto_augment=args.aa,
32 | interpolation='bicubic',
33 | re_prob=args.reprob,
34 | re_mode=args.remode,
35 | re_count=args.recount,
36 | mean=mean,
37 | std=std,
38 | )
39 | return transform
40 |
41 | # eval transform
42 | t = []
43 | if args.input_size <= 224:
44 | crop_pct = 224 / 256
45 | else:
46 | crop_pct = 1.0
47 | size = int(args.input_size / crop_pct)
48 | t.append(
49 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images
50 | )
51 | t.append(transforms.CenterCrop(args.input_size))
52 |
53 | t.append(transforms.ToTensor())
54 | t.append(transforms.Normalize(mean, std))
55 | return transforms.Compose(t)
56 |
57 |
58 | class ImagenetWithMask(datasets.ImageFolder):
59 | def __init__(self, root,
60 | transform = None,
61 | with_blockwise_mask=False, ### !!! set to True, enable blockwise masking
62 | blockwise_num_masking_patches=75, ### !!! 75 / 196 = 0.38 -> Modify this to increase mask ratio
63 | input_size=224, patch_size=16, # no need to change now
64 | max_mask_patches_per_block=None, # BEiT default setting, no need to change
65 | min_mask_patches_per_block=16, # BEiT default setting, no need to change
66 | fixed_num_masking_patches=True, ### set to true, fixed number of masking patch to blockwise_num_masking_patches for sim training
67 | ):
68 | super().__init__(root, transform)
69 | self.with_blockwise_mask = with_blockwise_mask
70 | if with_blockwise_mask:
71 | from .masking_generator import MaskingGenerator
72 | window_size = input_size // patch_size
73 | self.masked_position_generator = MaskingGenerator(
74 | (window_size, window_size),
75 | num_masking_patches=blockwise_num_masking_patches,
76 | max_num_patches=max_mask_patches_per_block,
77 | min_num_patches=min_mask_patches_per_block,
78 | fixed_num_masking_patches=fixed_num_masking_patches
79 | )
80 |
81 | def __getitem__(self, index):
82 | sample, target = super().__getitem__(index)
83 | if self.with_blockwise_mask:
84 | return sample, target, self.masked_position_generator()
85 | return sample, target
86 |
--------------------------------------------------------------------------------
/engine_finetune.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 | # References:
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # DeiT: https://github.com/facebookresearch/deit
11 | # ------------------------------------------------------------------------
12 |
13 | import math
14 | import sys
15 | from typing import Iterable, Optional
16 |
17 | import torch
18 |
19 | from timm.data import Mixup
20 | from timm.utils import accuracy
21 |
22 | import util.misc as misc
23 | import util.lr_sched as lr_sched
24 |
25 |
26 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
27 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
28 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
29 | mixup_fn: Optional[Mixup] = None, log_writer=None,
30 | args=None):
31 | model.train(True)
32 | metric_logger = misc.MetricLogger(delimiter=" ")
33 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
34 | header = 'Epoch: [{}]'.format(epoch)
35 | print_freq = 20
36 |
37 | accum_iter = args.accum_iter
38 |
39 | optimizer.zero_grad()
40 |
41 | if log_writer is not None:
42 | print('log_dir: {}'.format(log_writer.log_dir))
43 |
44 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
45 |
46 | # we use a per iteration (instead of per epoch) lr scheduler
47 | if data_iter_step % accum_iter == 0:
48 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
49 |
50 | samples = samples.to(device, non_blocking=True)
51 | targets = targets.to(device, non_blocking=True)
52 |
53 | if mixup_fn is not None:
54 | samples, targets = mixup_fn(samples, targets)
55 |
56 | with torch.cuda.amp.autocast():
57 | outputs = model(samples)
58 | loss = criterion(outputs, targets)
59 |
60 | loss_value = loss.item()
61 |
62 | if not math.isfinite(loss_value):
63 | print("Loss is {}, stopping training".format(loss_value))
64 | sys.exit(1)
65 |
66 | loss /= accum_iter
67 | loss_scaler(loss, optimizer, clip_grad=max_norm,
68 | parameters=model.parameters(), create_graph=False,
69 | update_grad=(data_iter_step + 1) % accum_iter == 0)
70 | if (data_iter_step + 1) % accum_iter == 0:
71 | optimizer.zero_grad()
72 |
73 | torch.cuda.synchronize()
74 |
75 | metric_logger.update(loss=loss_value)
76 | min_lr = 10.
77 | max_lr = 0.
78 | for group in optimizer.param_groups:
79 | min_lr = min(min_lr, group["lr"])
80 | max_lr = max(max_lr, group["lr"])
81 |
82 | metric_logger.update(lr=max_lr)
83 |
84 | loss_value_reduce = misc.all_reduce_mean(loss_value)
85 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
86 | """ We use epoch_1000x as the x-axis in tensorboard.
87 | This calibrates different curves when batch size changes.
88 | """
89 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
90 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
91 | log_writer.add_scalar('lr', max_lr, epoch_1000x)
92 |
93 | # gather the stats from all processes
94 | metric_logger.synchronize_between_processes()
95 | print("Averaged stats:", metric_logger)
96 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
97 |
98 |
99 | @torch.no_grad()
100 | def evaluate(data_loader, model, device):
101 | criterion = torch.nn.CrossEntropyLoss()
102 |
103 | metric_logger = misc.MetricLogger(delimiter=" ")
104 | header = 'Test:'
105 |
106 | # switch to evaluation mode
107 | model.eval()
108 |
109 | for batch in metric_logger.log_every(data_loader, 10, header):
110 | images = batch[0]
111 | target = batch[-1]
112 | images = images.to(device, non_blocking=True)
113 | target = target.to(device, non_blocking=True)
114 |
115 | # compute output
116 | with torch.cuda.amp.autocast():
117 | output = model(images)
118 | loss = criterion(output, target)
119 |
120 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
121 |
122 | batch_size = images.shape[0]
123 | metric_logger.update(loss=loss.item())
124 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
125 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
126 | # gather the stats from all processes
127 | metric_logger.synchronize_between_processes()
128 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
129 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
130 |
131 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
132 |
--------------------------------------------------------------------------------
/engine_pretrain.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 | # References:
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # DeiT: https://github.com/facebookresearch/deit
11 | # ------------------------------------------------------------------------
12 |
13 |
14 | import math
15 | import os
16 | import sys
18 | from typing import Iterable
19 | from pathlib import Path
20 |
21 | import torch
22 |
23 | import util.misc as misc
24 | import util.lr_sched as lr_sched
25 |
26 |
27 | def train_one_epoch(model: torch.nn.Module,
28 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
29 | device: torch.device, epoch: int, loss_scaler,
30 | log_writer=None,
31 | args=None):
32 | model.train(True)
33 | metric_logger = misc.MetricLogger(delimiter=" ")
34 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
35 | header = 'Epoch: [{}]'.format(epoch)
36 | print_freq = 50
37 |
38 | accum_iter = args.accum_iter
39 |
40 | optimizer.zero_grad()
41 |
42 | if log_writer is not None:
43 | print('log_dir: {}'.format(log_writer.log_dir))
44 |
45 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
46 | if args.with_blockwise_mask:
47 | samples, labels, mask = data
48 | else:
49 | samples, labels = data
50 | mask = None
51 |
52 | # we use a per iteration (instead of per epoch) lr scheduler
53 | if data_iter_step % accum_iter == 0:
54 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
55 |
56 | if args.mmschedule == 'const':
57 | mm = args.mm
58 | elif args.mmschedule == 'cosine':
59 | mm = 1. - 0.5 * (1. + math.cos(math.pi * (data_iter_step / len(data_loader) + epoch) / args.epochs)) * (1. - args.mm)
60 | metric_logger.update(mm=mm)
61 | update_mm = (data_iter_step % accum_iter == 0)
62 |
63 | if args.loss_type in ['sim',]:
64 | x1, x2, delta_i, delta_j, delta_h, delta_w, relative_flip, flip_delta_j = samples
65 | x1 = x1.to(device, non_blocking=True)
66 | x2 = x2.to(device, non_blocking=True)
67 | delta_i = delta_i.to(x1)
68 | delta_j = delta_j.to(x1)
69 | delta_h = delta_h.to(x1)
70 | delta_w = delta_w.to(x1)
71 | flip_delta_j = flip_delta_j.to(x1)
72 |
73 | rel_pos_21 = (delta_i, delta_j, delta_h, delta_w, relative_flip, flip_delta_j)
74 |
75 | with torch.cuda.amp.autocast(enabled=(not args.fp32)):
76 | loss, outputs = model(x1, x2, rel_pos_21, mm, update_mm, mask=mask)
77 | metric_logger.update(**outputs)
78 | else:
79 | samples = samples.to(device, non_blocking=True)
80 |
81 | with torch.cuda.amp.autocast(enabled=(not args.fp32)):
82 | loss, _, _ = model(samples, mask_ratio=args.mask_ratio)
83 |
84 | loss_value = loss.item()
85 |
86 | if not math.isfinite(loss_value):
87 | print("Loss is {}, stopping training".format(loss_value))
88 | sys.exit(1)
89 |
90 | loss /= accum_iter
91 | grad_norm = loss_scaler(loss, optimizer, parameters=model.parameters(),
92 | update_grad=(data_iter_step + 1) % accum_iter == 0, clip_grad=args.clip_grad)
93 | if args.fp32:
94 | loss_scale = None
95 | else:
96 | loss_scale = loss_scaler.state_dict()['scale']
97 |
98 | metric_logger.update(grad_norm=grad_norm)
99 | metric_logger.update(loss_scale=loss_scale)
100 |
101 | if (data_iter_step + 1) % accum_iter == 0:
102 | optimizer.zero_grad()
103 |
104 | torch.cuda.synchronize()
105 |
106 | metric_logger.update(loss=loss_value)
107 |
108 | lr = optimizer.param_groups[0]["lr"]
109 | metric_logger.update(lr=lr)
110 |
111 | loss_value_reduce = misc.all_reduce_mean(loss_value)
112 | outputs_reduced = {k_: misc.all_reduce_mean(v_) for k_, v_ in outputs.items()}
113 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
114 | """ We use epoch_1000x as the x-axis in tensorboard.
115 | This calibrates different curves when batch size changes.
116 | """
117 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
118 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
119 | log_writer.add_scalar('lr', lr, epoch_1000x)
120 | log_writer.add_scalar('grad_norm', grad_norm, epoch_1000x)
121 | if loss_scale is not None:
122 | log_writer.add_scalar('loss_scale', loss_scale, epoch_1000x)
123 | log_writer.add_scalar('mm', mm, epoch_1000x)
124 | for k_, v_ in outputs_reduced.items():
125 | log_writer.add_scalar(f'train/{k_}', v_, epoch_1000x)
126 |
127 | # gather the stats from all processes
128 | metric_logger.synchronize_between_processes()
129 | print("Averaged stats:", metric_logger)
130 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
131 |
--------------------------------------------------------------------------------
/util/augmentation.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | from PIL import ImageFilter, ImageOps
5 | import torch
6 | import torchvision.transforms as transforms
7 | import torchvision.transforms.functional as F
8 |
9 |
10 | class RandomResizedCrop(transforms.RandomResizedCrop):
11 | def __init__(self, cfg, *args, **kwargs):
12 | super().__init__(*args, **kwargs)
13 | self.args = cfg
14 |
15 | @staticmethod
16 | def get_params(img, scale, ratio):
17 | """Get parameters for ``crop`` for a random sized crop.
18 |
19 | Args:
20 | img (PIL Image or Tensor): Input image.
21 | scale (list): range of scale of the origin size cropped
22 | ratio (list): range of aspect ratio of the origin aspect ratio cropped
23 |
24 | Returns:
25 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random
26 | sized crop.
27 | """
28 | width, height = F.get_image_size(img)
29 | area = height * width
30 |
31 | log_ratio = torch.log(torch.tensor(ratio))
32 | for _ in range(10):
33 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
34 | aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
35 |
36 | w = int(round(math.sqrt(target_area * aspect_ratio)))
37 | h = int(round(math.sqrt(target_area / aspect_ratio)))
38 |
39 | if 0 < w <= width and 0 < h <= height:
40 | i1 = torch.randint(0, height - h + 1, size=(1,)).item()
41 | i2 = torch.randint(0, height - h + 1, size=(1,)).item()
42 | j1 = torch.randint(0, width - w + 1, size=(1,)).item()
43 | j2 = torch.randint(0, width - w + 1, size=(1,)).item()
44 |
45 | return i1, j1, i2, j2, h, w
46 |
47 | # Fallback to central crop
48 | in_ratio = float(width) / float(height)
49 | if in_ratio < min(ratio):
50 | w = width
51 | h = int(round(w / min(ratio)))
52 | elif in_ratio > max(ratio):
53 | h = height
54 | w = int(round(h * max(ratio)))
55 | else: # whole image
56 | w = width
57 | h = height
58 | i = (height - h) // 2
59 | j = (width - w) // 2
60 | return i, j, i, j, h, w
61 |
62 | def forward(self, img):
63 | """
64 | Args:
65 | img (PIL Image or Tensor): Image to be cropped and resized.
66 |
67 | Returns:
68 | PIL Image or Tensor: Randomly cropped and resized image.
69 | """
70 | i1, j1, i2, j2, h, w = self.get_params(img, self.scale, self.ratio)
71 | return F.resized_crop(img, i1, j1, h, w, self.size, self.interpolation), \
72 | F.resized_crop(img, i2, j2, h, w, self.size, self.interpolation), (i2-i1)/h, (j2-j1)/w, h/h, w/w
73 |
74 |
75 | class GaussianBlur(object):
76 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
77 |
78 | def __init__(self, sigma=[.1, 2.]):
79 | self.sigma = sigma
80 |
81 | def __call__(self, x):
82 | sigma = random.uniform(self.sigma[0], self.sigma[1])
83 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
84 | return x
85 |
86 |
87 | class Solarize(object):
88 | """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733"""
89 |
90 | def __call__(self, x):
91 | return ImageOps.solarize(x)
92 |
93 |
94 | class SingleRandomResizedCrop(transforms.RandomResizedCrop):
95 | @staticmethod
96 | def get_params(img, scale, ratio):
97 | """Get parameters for ``crop`` for a random sized crop.
98 |
99 | Args:
100 | img (PIL Image or Tensor): Input image.
101 | scale (list): range of scale of the origin size cropped
102 | ratio (list): range of aspect ratio of the origin aspect ratio cropped
103 |
104 | Returns:
105 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random
106 | sized crop.
107 | """
108 | width, height = F.get_image_size(img)
109 | area = height * width
110 |
111 | log_ratio = torch.log(torch.tensor(ratio))
112 | for _ in range(10):
113 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
114 | aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
115 |
116 | w = int(round(math.sqrt(target_area * aspect_ratio)))
117 | h = int(round(math.sqrt(target_area / aspect_ratio)))
118 |
119 | if 0 < w <= width and 0 < h <= height:
120 | i = torch.randint(0, height - h + 1, size=(1,)).item()
121 | j = torch.randint(0, width - w + 1, size=(1,)).item()
122 | return i, j, h, w, width
123 |
124 | # Fallback to central crop
125 | in_ratio = float(width) / float(height)
126 | if in_ratio < min(ratio):
127 | w = width
128 | h = int(round(w / min(ratio)))
129 | elif in_ratio > max(ratio):
130 | h = height
131 | w = int(round(h * max(ratio)))
132 | else: # whole image
133 | w = width
134 | h = height
135 | i = (height - h) // 2
136 | j = (width - w) // 2
137 | return i, j, h, w, width
138 |
139 | def forward(self, img):
140 | """
141 | Args:
142 | img (PIL Image or Tensor): Image to be cropped and resized.
143 |
144 | Returns:
145 | PIL Image or Tensor: Randomly cropped and resized image.
146 | """
147 | i, j, h, w, width = self.get_params(img, self.scale, self.ratio)
148 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), i, j, h, w, width
149 |
150 |
151 | class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
152 | def forward(self, img):
153 | if torch.rand(1) < self.p:
154 | return F.hflip(img), True
155 | return img, False
156 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Siamese-Image-Modeling
2 |
3 | By [Chenxin Tao](https://scholar.google.com/citations?user=sXHFIBkAAAAJ&hl=zh-CN),
4 | [Xizhou Zhu](https://scholar.google.com/citations?user=02RXI00AAAAJ),
5 | [Weijie Su](https://www.weijiesu.com/),
6 | [Gao Huang](http://www.gaohuang.net/),
7 | [Bin Li](http://staff.ustc.edu.cn/~binli/),
8 | [Jie Zhou](https://scholar.google.com/citations?user=6a79aPwAAAAJ&hl=en),
9 | [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ&hl=en),
10 | [Xiaogang Wang](http://www.ee.cuhk.edu.hk/~xgwang/),
11 | [Jifeng Dai](https://jifengdai.org/)
12 |
13 | This is the official implementation of the CVPR 2023 paper [Siamese Image Modeling for Self-Supervised Vision Representation Learning](https://arxiv.org/pdf/2206.01204.pdf).
14 |
15 | 
16 |
17 | ## 🏠 Introduction
18 |
19 | SiameseIM is a new form of self-supervised learning that can learn semantic alignment and spatial sensitivity with a single dense loss. We note the following key observations from SiameseIM:
20 |
21 | - Compared with MIM methods, SiameseIM shows that reconstructing another view helps to obtain good semantic alignment.
22 |
23 | - Compared with ID methods, SiameseIM shows that dense supervision can be applied by matching the dense correspondence between two views strictly through their relative positions.
24 |
25 | - SiameseIM is able to surpass both MIM and ID methods over a wide range of tasks. SiameseIM obtains more improvements in few-shot, long-tail and robustness-concerned scenarios.
26 |
27 |
28 | 
29 |
30 |
31 | ## 📈 Main Results
32 |
33 |
34 | | Method | ImageNet FT | ImageNet LIN | ImageNet 1% FT | COCO AP box | COCO AP mask | ADE20k mIoU | LVIS AP box | LVIS AP box rare | LVIS AP mask | LVIS AP mask rare | IN-A top-1 | IN-R top-1 | IN-Sketch top-1 | IN-C 1-mCE |
35 | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
36 | | MoCo-v3 (ID method) | 83.0 | 76.7 | 63.4 | 47.9 | 42.7 | 47.3 | 37.3 | 25.5 | 35.3 | 25.8 | 32.4 | 49.8 | 35.9 | 55.4 |
37 | | MAE (MIM method) | 83.6 | 68.0 | 51.1 | 51.6 | 45.9 | 48.1 | 40.1 | 29.3 | 38.1 | 29.1 | 35.9 | 48.3 | 34.5 | 48.3 |
38 | | SiameseIM | 84.1 | 78.0 | 65.1 | 52.1 | 46.2 | 51.1 | 40.5 | 30.9 | 38.1 | 30.1 | 43.8 | 52.5 | 38.3 | 57.1 |
39 | | Improve w.r.t. MoCo-v3 | +1.1 | +1.3 | +1.7 | +4.2 | +3.5 | +3.8 | +3.2 | +5.4 | +2.8 | +4.3 | +11.4 | +2.7 | +2.4 | +1.7 |
40 | | Improve w.r.t. MAE | +0.5 | +10.0 | +14.0 | +0.5 | +0.3 | +3.0 | +0.4 | +1.6 | +0.0 | +1.0 | +7.9 | +4.2 | +3.8 | +8.8 |
41 |
63 | Note:
64 |
65 | (1) Compared with MoCo-v3, SiameseIM improves dense prediction tasks (COCO detection, ADE20k segmentation, LVIS detection) significantly;
66 |
67 | (2) Compared with MAE, SiameseIM improves long-tail, few-shot, robustness tasks (ImageNet linear evaluation & few-shot classification, ADE20k segmentation, LVIS detection) significantly;
68 |
69 | (3) Notably, ADE20k segmentation and LVIS detection both contain long-tail classes, which place a high requirement on semantic alignment, and both are dense prediction tasks, which demand good spatial alignment. This is why SiameseIM can surpass both MoCo-v3 and MAE by a large margin on them.
70 |
71 |
72 | ## 🛠️ Usage
73 | ### Preparation
74 |
75 | See [prepare.md](docs/prepare.md)
76 |
77 | ### Model Checkpoint
78 |
79 | See [checkpoints.md](docs/checkpoints.md)
80 |
81 | ### Pretrain
82 |
83 | See [pretrain.md](docs/pretrain.md)
84 |
85 | ### Finetune
86 |
87 | See [finetune.md](docs/finetune.md)
88 |
89 | ### Linear Evaluation
90 |
91 | See [linear_eval.md](docs/linear_eval.md)
92 |
93 | ### Few-shot Evaluation
94 |
95 | See [few_shot.md](docs/few_shot.md)
96 |
97 | ### COCO & LVIS Detection
98 |
99 | We use ViTDet for the detection tasks; please refer to [detectron2](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet).
100 |
101 | ### ADE20k Segmentation
102 |
103 | Following MAE, we use UPerNet for the segmentation task; please refer to [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/mae).
104 |
105 | ### Robustness Evaluation
106 |
107 | We evaluate the ImageNet finetuned model on [ImageNet-A](https://github.com/hendrycks/natural-adv-examples), [ImageNet-R](https://github.com/hendrycks/imagenet-r), [ImageNet-Sketch](https://github.com/HaohanWang/ImageNet-Sketch) and [ImageNet-C](https://github.com/hendrycks/robustness) datasets.
108 |
109 |
110 | ## 📃 License
111 |
112 | This project is released under the [CC-BY-NC 4.0 license](./LICENSE).
113 |
114 | ## 🖊️ Citing SiameseIM
115 | If you find SiameseIM useful in your research, please consider citing:
116 | ```bibtex
117 | @inproceedings{tao2023siamese,
118 | title={Siamese image modeling for self-supervised vision representation learning},
119 | author={Tao, Chenxin and Zhu, Xizhou and Su, Weijie and Huang, Gao and Li, Bin and Zhou, Jie and Qiao, Yu and Wang, Xiaogang and Dai, Jifeng},
120 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
121 | pages={2132--2141},
122 | year={2023}
123 | }
124 | ```
125 |
--------------------------------------------------------------------------------
/util/masking_generator.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | """
6 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
7 | Copyright Zhun Zhong & Liang Zheng
8 |
9 | Hacked together by / Copyright 2020 Ross Wightman
10 |
11 | Modified by Hangbo Bao, for generating the masked position for visual image transformer
12 | """
13 | # --------------------------------------------------------
14 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
15 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
16 | # Copyright (c) 2021 Microsoft
17 | # Licensed under The MIT License [see LICENSE for details]
18 | # By Hangbo Bao
19 | # Based on timm, DINO and DeiT code bases
20 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
21 | # Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
22 | # Copyright Zhun Zhong & Liang Zheng
23 | #
24 | # Hacked together by / Copyright 2020 Ross Wightman
25 | #
26 | # Modified by Hangbo Bao, for generating the masked position for visual image transformer
27 | # --------------------------------------------------------'
28 | import random
29 | import math
30 | import numpy as np
31 |
32 |
33 | class MaskingGenerator:
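    """Block-wise mask generator in the BEiT style.

    Repeatedly samples rectangular blocks with random area (between min_num_patches and
    max_num_patches grid cells) and random aspect ratio, and marks them on a height x width
    grid until roughly num_masking_patches cells are masked. If fixed_num_masking_patches is
    True, leftover cells are masked at random so the total equals num_masking_patches exactly.
    """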
34 | def __init__(
35 | self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None,
36 | min_aspect=0.3, max_aspect=None, fixed_num_masking_patches=False):
37 | if not isinstance(input_size, tuple):
38 | input_size = (input_size, ) * 2
39 | self.height, self.width = input_size
40 |
41 | self.num_patches = self.height * self.width
42 | self.num_masking_patches = num_masking_patches
43 |
44 | self.min_num_patches = min_num_patches
45 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
46 |
47 | max_aspect = max_aspect or 1 / min_aspect
48 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
49 | self.fixed_num_masking_patches = fixed_num_masking_patches
50 |
51 | def __repr__(self):
52 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
53 | self.height, self.width, self.min_num_patches, self.max_num_patches,
54 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
55 | return repr_str
56 |
57 | def get_shape(self):
58 | return self.height, self.width
59 |
60 | def _mask(self, mask, max_mask_patches):
61 | delta = 0
62 | for attempt in range(10):
63 | target_area = random.uniform(self.min_num_patches, max_mask_patches)
64 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
65 | h = int(round(math.sqrt(target_area * aspect_ratio)))
66 | w = int(round(math.sqrt(target_area / aspect_ratio)))
67 | if w < self.width and h < self.height:
68 | top = random.randint(0, self.height - h)
69 | left = random.randint(0, self.width - w)
70 |
71 | num_masked = mask[top: top + h, left: left + w].sum()
72 | # Overlap
73 | if 0 < h * w - num_masked <= max_mask_patches:
74 | for i in range(top, top + h):
75 | for j in range(left, left + w):
76 | if mask[i, j] == 0:
77 | mask[i, j] = 1
78 | delta += 1
79 |
80 | if delta > 0:
81 | break
82 | return delta
83 |
84 | def __call__(self):
85 | mask = np.zeros(shape=self.get_shape(), dtype=int)  # np.int was removed in recent NumPy versions
86 | mask_count = 0
87 | while mask_count < self.num_masking_patches:
88 | max_mask_patches = self.num_masking_patches - mask_count
89 | max_mask_patches = min(max_mask_patches, self.max_num_patches)
90 |
91 | delta = self._mask(mask, max_mask_patches)
92 | if delta == 0:
93 | break
94 | else:
95 | mask_count += delta
96 |
97 | if self.fixed_num_masking_patches and (mask_count < self.num_masking_patches):
98 | non_masked_inds_i, non_masked_inds_j = (mask == 0).nonzero()
99 | shuffle_inds = list(range(non_masked_inds_i.shape[0]))
100 | random.shuffle(shuffle_inds)
101 | num_to_mask = self.num_masking_patches - mask_count
102 | to_mask_inds_i = non_masked_inds_i[shuffle_inds[:num_to_mask]]
103 | to_mask_inds_j = non_masked_inds_j[shuffle_inds[:num_to_mask]]
104 | mask[to_mask_inds_i, to_mask_inds_j] = 1
105 | mask_count += num_to_mask
106 |
107 | return mask
108 |
109 |
110 | if __name__ == '__main__':
111 | blockwise_num_masking_patches=75 ### TODO: 75 / 196 = 0.38 -> Modify this to increase mask ratio
112 | input_size=224
113 | patch_size=16 # BEiT default setting, no need to change
114 | max_mask_patches_per_block=None # BEiT default setting, no need to change
115 | min_mask_patches_per_block=16 # BEiT default setting, no need to change
116 | fixed_num_masking_patches=True ### TODO: fixed number of masking patch to blockwise_num_masking_patches for sim training
117 | window_size = input_size // patch_size
118 | masked_position_generator = MaskingGenerator(
119 | (window_size, window_size),
120 | num_masking_patches=blockwise_num_masking_patches,
121 | max_num_patches=max_mask_patches_per_block,
122 | min_num_patches=min_mask_patches_per_block,
123 | fixed_num_masking_patches=fixed_num_masking_patches
124 | )
125 | mask_num = []
126 | for _ in range(10000):
127 | mask = masked_position_generator()
128 | if _ < 10:
129 | print(mask)
130 | mask_num.append(mask.sum())
131 | print(f"Max Patches: {max(mask_num)} Min Patches: {min(mask_num)} Mean Patches: {sum(mask_num) / len(mask_num)}")
132 | print(f"Max Ratio: {max(mask_num)/196.0} Min Ratio: {min(mask_num)/196.0} Mean Ratio: {sum(mask_num) / len(mask_num) / 196.0}")
133 |
--------------------------------------------------------------------------------
/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 |
9 |
10 | import numpy as np
11 |
12 | import torch
13 |
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 | """
22 | grid_size: int of the grid height and width
23 | return:
24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 | """
26 | grid_h = np.arange(grid_size, dtype=np.float32)
27 | grid_w = np.arange(grid_size, dtype=np.float32)
28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
29 | grid = np.stack(grid, axis=0)
30 |
31 | grid = grid.reshape([2, 1, grid_size, grid_size])
32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 | if cls_token:
34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 | return pos_embed
36 |
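# Usage sketch (illustrative; not called elsewhere in this file): for a ViT-B/16 at 224x224
# input the patch grid is 14x14 and the embedding dim is 768, so
#   get_2d_sincos_pos_embed(768, 14, cls_token=True)  # -> array of shape (1 + 14*14, 768)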
37 |
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 | assert embed_dim % 2 == 0
40 |
41 | # use half of dimensions to encode grid_h
42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
44 |
45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46 | return emb
47 |
48 |
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 | """
51 | embed_dim: output dimension for each position
52 | pos: a list of positions to be encoded: size (M,)
53 | out: (M, D)
54 | """
55 | assert embed_dim % 2 == 0
56 | omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in recent NumPy versions
57 | omega /= embed_dim / 2.
58 | omega = 1. / 10000**omega # (D/2,)
59 |
60 | pos = pos.reshape(-1) # (M,)
61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62 |
63 | emb_sin = np.sin(out) # (M, D/2)
64 | emb_cos = np.cos(out) # (M, D/2)
65 |
66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67 | return emb
68 |
69 |
70 | def get_2d_sincos_pos_embed_relative(delta_i, delta_j, delta_h, delta_w, relative_flip, flip_delta_j, embed_dim, grid_size, cls_token=False):
71 | """
72 | grid_size: int of the grid height and width
73 | return:
74 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
75 | """
76 | delta_i = delta_i * grid_size
77 | delta_j = delta_j * grid_size
78 | flip_delta_j = flip_delta_j * grid_size
79 | grid_h = torch.arange(grid_size, dtype=torch.float32)
80 | grid_w = torch.arange(grid_size, dtype=torch.float32)
81 | raw_grid_h, raw_grid_w = torch.meshgrid(grid_h, grid_w)
82 |
83 | raw_grid_h = raw_grid_h + 0.5
84 | raw_grid_w = raw_grid_w + 0.5
85 | grid_h = torch.einsum('b,n->bn', delta_h, raw_grid_h.flatten().to(delta_h)) + delta_i.unsqueeze(-1)
86 | grid_w = torch.einsum('b,n->bn', delta_w, raw_grid_w.flatten().to(delta_w)) + delta_j.unsqueeze(-1)
87 |
88 | flip_grid_w = -torch.einsum('b,n->bn', [delta_w, raw_grid_w.flatten().to(delta_h)]) + flip_delta_j[:, None]
89 | relative_flip = relative_flip.float().unsqueeze(-1)
90 | grid_w = relative_flip * flip_grid_w + (1-relative_flip) * grid_w
91 | grid_w = grid_w - 0.5
92 | grid_h = grid_h - 0.5
93 |
94 | omega = torch.arange(embed_dim//4, dtype=torch.float32) / (embed_dim/4)
95 | omega = 1. / (10000**omega)
96 | out_h = torch.einsum('bn,c->bnc', [grid_h, omega.to(grid_h)])
97 | out_w = torch.einsum('bn,c->bnc', [grid_w, omega.to(grid_w)])
98 | out_scale_h = torch.einsum('b,c->bc', [10*torch.log(delta_h), omega.to(grid_h)]).unsqueeze(1).expand(-1, out_h.shape[1], -1)
99 | out_scale_w = torch.einsum('b,c->bc', [10*torch.log(delta_w), omega.to(grid_w)]).unsqueeze(1).expand(-1, out_h.shape[1], -1)
100 | pos_embed = torch.cat([torch.sin(out_h), torch.cos(out_h), torch.sin(out_w), torch.cos(out_w),
101 | torch.sin(out_scale_h), torch.cos(out_scale_h),
102 | torch.sin(out_scale_w), torch.cos(out_scale_w),], dim=2).detach()
103 |
104 | return pos_embed
105 |
106 |
107 | # --------------------------------------------------------
108 | # Interpolate position embeddings for high-resolution
109 | # References:
110 | # DeiT: https://github.com/facebookresearch/deit
111 | # --------------------------------------------------------
112 | def interpolate_pos_embed(model, checkpoint_model):
113 | if 'pos_embed' in checkpoint_model:
114 | pos_embed_checkpoint = checkpoint_model['pos_embed']
115 | embedding_size = pos_embed_checkpoint.shape[-1]
116 | num_patches = model.patch_embed.num_patches
117 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
118 | # height (== width) for the checkpoint position embedding
119 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
120 | # height (== width) for the new position embedding
121 | new_size = int(num_patches ** 0.5)
122 | # class_token and dist_token are kept unchanged
123 | if orig_size != new_size:
124 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
125 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
126 | # only the position tokens are interpolated
127 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
128 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
129 | pos_tokens = torch.nn.functional.interpolate(
130 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
131 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
132 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
133 | checkpoint_model['pos_embed'] = new_pos_embed
134 |
--------------------------------------------------------------------------------
/util/tcs_datasets.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 |
6 |
7 | import os
8 | import io
9 | from PIL import Image
10 | from torch.utils.data import Dataset
11 | import pyarrow as pa
12 | import numpy as np
13 | from io import BytesIO
14 | import tqdm
15 | from tqdm import trange
16 | try:
17 | from petrel_client.client import Client
18 | except ImportError:
19 | # petrel_client is only required for the internal TCS / Ceph storage backend
20 | print('petrel_client.client cannot be imported')
21 |
22 |
23 | def tcs_pil_loader(img_str):
24 | buff = io.BytesIO(img_str)
25 | return Image.open(buff)
26 |
27 |
28 | class TCSLoader(object):
29 |
30 | def __init__(self, conf_path):
31 | self.client = Client(conf_path)
32 |
33 | def __call__(self, fn):
34 | try:
35 | img_value_str = self.client.get(fn)
36 | img = tcs_pil_loader(img_value_str)
37 | except Exception:
38 | print('Read image failed ({})'.format(fn))
39 | return None
40 | else:
41 | return img
42 |
43 |
44 | def _get_images(annotations):
45 | images = []
46 | classes = []
47 | for line in annotations:
48 | if isinstance(line, bytes):
49 | line = line.decode()
50 | image_name, cls = line.strip('\n').split()
51 | images.append(image_name)
52 | classes.append(cls)
53 | return images, classes
54 |
55 |
56 | class ImageNetTCSDatasetQK(Dataset):
57 | def __init__(self, image_set, data_path, transform=None, use_tcs=False,
58 | tcs_conf_path='/mnt/lustre/share_data/taochenxin/tcs/petreloss.conf',
59 | test_mode=False,
60 | on_memory=False, local_rank=None, local_size=None,
61 | **kwargs):
62 | ann_file = os.path.join(data_path, f'meta/{image_set}.txt')
63 | data_path = os.path.join(data_path, image_set)
64 | self.image_set = image_set
65 | self.transform = transform
66 | self.data_path = data_path
67 | self.test_mode = test_mode
68 | if use_tcs:
69 | self.tcs_loader = TCSLoader(tcs_conf_path)
70 | self.use_tcs = use_tcs
71 | self.images, self.classes, self.class_to_idx = self._load_database(ann_file)
72 | self.on_memory = on_memory
73 | if on_memory:
74 | if local_rank is None:
75 | local_rank = int(os.environ.get('LOCAL_RANK', 0))
76 | if local_size is None:
77 | local_size = int(os.environ.get('LOCAL_SIZE', 1))
78 | self.local_rank = local_rank
79 | self.local_size = local_size
80 | self.holder = {}
81 | self.load_onto_memory()
82 |
83 | def load_onto_memory(self):
84 | print("Loading images onto memory...")
85 | for index in trange(len(self.images)):
86 | if index % self.local_size != self.local_rank:
87 | continue
88 | path = self.images[index].as_py()
89 | full_path = os.path.join(self.data_path, path)
90 | if self.use_tcs:
91 | sample = self.tcs_loader.client.get(full_path)
92 | else:
93 | with open(full_path, 'rb') as f:
94 | sample = f.read()
95 | self.holder[path] = sample
96 | # print('Loading: path {}, full_path {}, data length {}'.format(path, full_path,
97 | # len(self.tcs_loader.client.get(full_path))))
98 | print("Loading complete!")
99 |
100 | def _load_database(self, annotation_file):
101 | if not self.use_tcs:
102 | annotation_file = os.path.abspath(annotation_file)
103 | print(f'loading annotations from {annotation_file} ...')
104 | if self.use_tcs:
105 | with BytesIO(self.tcs_loader.client.get(annotation_file)) as annotations:
106 | images, classes = _get_images(annotations)
107 | else:
108 | with open(annotation_file, 'rt') as annotations:
109 | images, classes = _get_images(annotations)
110 |
111 | # convert possible classes to indices
112 | class_names = sorted(set(classes))
113 | # class_to_idx = {class_name: idx for idx, class_name in enumerate(class_names)}
114 | class_to_idx = {class_name: int(class_name) for class_name in class_names}
115 | return pa.array(images), pa.array([class_to_idx[class_name] for class_name in classes]), class_to_idx
116 |
117 | def __len__(self):
118 | return len(self.images)
119 |
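    # __getitem__ returns two independently augmented views (a "query" and a "key") of the
    # same image and discards the class label; this is the two-view input expected by
    # Siamese-style pretraining.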
120 | def __getitem__(self, index):
121 | path = self.images[index].as_py()
122 | target = self.classes[index].as_py()
123 | sample = self._load_image(path)
124 | if self.transform is not None:
125 | sample_q = self.transform(sample)
126 | sample_k = self.transform(sample)
127 | return sample_q, sample_k
128 | else:
129 | return sample, sample
130 |
131 | def _load_image(self, path):
132 | full_path = os.path.join(self.data_path, path)
133 | if self.on_memory:
134 | try:
135 | return Image.open(BytesIO(self.holder[path])).convert('RGB')
136 | except Exception:
137 | print('error acquiring data from {}'.format(path))
138 | return self.tcs_loader(full_path).convert('RGB')
139 | elif self.use_tcs:
140 | return self.tcs_loader(full_path).convert('RGB')
141 | else:
142 | with open(full_path, 'rb') as f:
143 | return Image.open(f).convert('RGB')
144 |
145 |
146 | class ImagenetTCSDataset(ImageNetTCSDatasetQK):
147 | def __init__(self, image_set, data_path, transform=None, use_tcs=False,
148 | tcs_conf_path='/mnt/lustre/share_data/taochenxin/tcs/petreloss.conf',
149 | test_mode=False, on_memory=False, local_rank=None, local_size=None,
150 | with_blockwise_mask=False, ### !!! set to True, enable blockwise masking
151 | blockwise_num_masking_patches=75, ### !!! 75 / 196 = 0.38 -> Modify this to increase mask ratio
152 | input_size=224, patch_size=16, # no need to change now
153 | max_mask_patches_per_block=None, # BEiT default setting, no need to change
154 | min_mask_patches_per_block=16, # BEiT default setting, no need to change
155 | fixed_num_masking_patches=True, ### set to true, fixed number of masking patch to blockwise_num_masking_patches for sim training
156 | **kwargs):
157 | super().__init__(image_set, data_path, transform=transform, use_tcs=use_tcs,
158 | tcs_conf_path=tcs_conf_path, test_mode=test_mode, on_memory=on_memory,
159 | local_rank=local_rank, local_size=local_size, **kwargs)
160 | self.with_blockwise_mask = with_blockwise_mask
161 | if with_blockwise_mask:
162 | from .masking_generator import MaskingGenerator
163 | window_size = input_size // patch_size
164 | self.masked_position_generator = MaskingGenerator(
165 | (window_size, window_size),
166 | num_masking_patches=blockwise_num_masking_patches,
167 | max_num_patches=max_mask_patches_per_block,
168 | min_num_patches=min_mask_patches_per_block,
169 | fixed_num_masking_patches=fixed_num_masking_patches
170 | )
171 |
172 | def __getitem__(self, index):
173 | path = self.images[index].as_py()
174 | target = self.classes[index].as_py()
175 | sample = self._load_image(path)
176 | if self.transform is not None:
177 | sample = self.transform(sample)
178 | if self.with_blockwise_mask:
179 | return sample, target, self.masked_position_generator()
180 | return sample, target
181 |
182 |
183 | if __name__ == '__main__':
184 | # these imports are only needed for this quick self-test
185 | import torchvision.transforms as transforms
186 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
187 | transform = transforms.Compose([
188 | transforms.RandomResizedCrop(224),
189 | transforms.RandomHorizontalFlip(0.5),
190 | transforms.ToTensor(),
191 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
192 | ])
190 | dataset = ImagenetTCSDataset(
191 | 'val',
192 | 's3://imagenet',
193 | tcs_conf_path='./petreloss.conf',
194 | transform=transform,
195 | with_blockwise_mask=True,
196 | blockwise_num_masking_patches=75)
197 | for i, (sample, target, mask) in enumerate(dataset):
198 | if i < 10:
199 | print(mask.sum())
200 | print(mask)
201 | else:
202 | break
203 |
--------------------------------------------------------------------------------
/models_vit.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 | # References:
9 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
10 | # DeiT: https://github.com/facebookresearch/deit
11 | # ------------------------------------------------------------------------
12 |
13 | from functools import partial
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 |
19 | import timm.models.vision_transformer
20 | from timm.models.layers import Mlp, DropPath
21 | from timm.models.layers.helpers import to_2tuple
22 |
23 | from util.misc import LayerNorm
24 |
25 |
26 | class PatchEmbed(nn.Module):
27 | """ 2D Image to Patch Embedding
28 | """
29 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
30 | super().__init__()
31 | img_size = to_2tuple(img_size)
32 | patch_size = to_2tuple(patch_size)
33 | self.img_size = img_size
34 | self.patch_size = patch_size
35 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
36 | self.num_patches = self.grid_size[0] * self.grid_size[1]
37 | self.flatten = flatten
38 |
39 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
40 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
41 |
42 | def forward(self, x):
43 | B, C, H, W = x.shape
44 | x = self.proj(x)
45 | if self.flatten:
46 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
47 | x = self.norm(x)
48 | return x
49 |
50 |
51 | class Attention(nn.Module):
52 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
53 | super().__init__()
54 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
55 | self.num_heads = num_heads
56 | head_dim = dim // num_heads
57 | self.scale = head_dim ** -0.5
58 |
59 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
60 | self.attn_drop = nn.Dropout(attn_drop)
61 | self.proj = nn.Linear(dim, dim)
62 | self.proj_drop = nn.Dropout(proj_drop)
63 |
64 | def forward(self, x):
65 | B, N, C = x.shape
66 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
67 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
68 |
69 | attn = ((q * self.scale) @ k.transpose(-2, -1))
70 | attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
71 | attn = attn.softmax(dim=-1)
72 | attn = self.attn_drop(attn)
73 |
74 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
75 | x = self.proj(x)
76 | x = self.proj_drop(x)
77 | return x
78 |
79 |
80 | class CrossAttention(nn.Module):
81 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
82 | super().__init__()
83 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
84 | self.num_heads = num_heads
85 | head_dim = dim // num_heads
86 | self.scale = head_dim ** -0.5
87 |
88 | # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
89 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
90 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
91 | self.attn_drop = nn.Dropout(attn_drop)
92 | self.proj = nn.Linear(dim, dim)
93 | self.proj_drop = nn.Dropout(proj_drop)
94 |
95 | def forward(self, query, key):
96 | B, Nq, C = query.shape
97 | _, Nk, _ = key.shape
98 | q = self.q(query).reshape(B, Nq, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
99 | kv = self.kv(key).reshape(B, Nk, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
100 | k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
101 |
102 | attn = ((q * self.scale) @ k.transpose(-2, -1))
103 | attn = attn - attn.max(-1)[0].unsqueeze(-1) # in case of overflow for fp16
104 | attn = attn.softmax(dim=-1)
105 | attn = self.attn_drop(attn)
106 |
107 | x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
108 | x = self.proj(x)
109 | x = self.proj_drop(x)
110 | return x
111 |
112 |
113 | class LayerScale(nn.Module):
114 | def __init__(self, dim, init_values=1e-5, inplace=False):
115 | super().__init__()
116 | self.inplace = inplace
117 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
118 |
119 | @torch.cuda.amp.autocast(enabled=False)
120 | def forward(self, x):
121 | return x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
122 |
123 |
124 | class Block(nn.Module):
125 |
126 | def __init__(
127 | self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
128 | drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm):
129 | super().__init__()
130 | self.norm1 = norm_layer(dim)
131 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
132 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
133 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
134 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
135 |
136 | self.norm2 = norm_layer(dim)
137 | mlp_hidden_dim = int(dim * mlp_ratio)
138 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
139 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
140 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
141 |
142 | def forward(self, x):
143 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
144 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
145 |
146 | return x
147 |
148 |
149 | class CrossBlock(nn.Module):
150 |
151 | def __init__(
152 | self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
153 | drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm):
154 | super().__init__()
155 | self.norm1 = norm_layer(dim)
156 | self.norm2 = norm_layer(dim)
157 | self.cross_attn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
158 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
159 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
160 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
161 |
162 | self.norm3 = norm_layer(dim)
163 | mlp_hidden_dim = int(dim * mlp_ratio)
164 | self.mlp1 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
165 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
166 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
167 |
168 | self.norm4 = norm_layer(dim)
169 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
170 | self.ls3 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
171 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
172 | self.drop_path3 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
173 |
174 | self.norm5 = norm_layer(dim)
175 | mlp_hidden_dim = int(dim * mlp_ratio)
176 | self.mlp2 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
177 | self.ls4 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
178 | self.drop_path4 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
179 |
180 |
181 | def forward(self, query, key):
182 | query = query + self.drop_path1(self.ls1(self.cross_attn(self.norm1(query), self.norm2(key))))
183 | query = query + self.drop_path2(self.ls2(self.mlp1(self.norm3(query))))
184 | query = query + self.drop_path3(self.ls3(self.attn(self.norm4(query))))
185 | query = query + self.drop_path4(self.ls4(self.mlp2(self.norm5(query))))
186 | return query
187 |
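# Shape sketch (illustrative): CrossBlock cross-attends the query tokens to the key tokens and
# then self-attends within the query, preserving the query's shape, e.g.
#   blk = CrossBlock(dim=768, num_heads=12)
#   out = blk(torch.randn(2, 196, 768), torch.randn(2, 49, 768))  # -> (2, 196, 768)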
188 |
189 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
190 | """ Vision Transformer with support for global average pooling
191 | """
192 | def __init__(self, global_pool=False, **kwargs):
193 | init_values = kwargs.pop('init_values')
194 | super(VisionTransformer, self).__init__(**kwargs)
195 |
196 | drop_path_rate = kwargs['drop_path_rate']
197 | depth = kwargs['depth']
198 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
199 | self.blocks = nn.Sequential(*[
200 | Block(
201 | dim=kwargs['embed_dim'], num_heads=kwargs['num_heads'], mlp_ratio=kwargs['mlp_ratio'], qkv_bias=kwargs['qkv_bias'],
202 | init_values=init_values, norm_layer=kwargs['norm_layer'], drop_path=dpr[i])
203 | for i in range(kwargs['depth'])])
204 |
205 | self.global_pool = global_pool
206 | norm_layer = kwargs['norm_layer']
207 | embed_dim = kwargs['embed_dim']
208 | if self.global_pool:
209 | self.fc_norm = norm_layer(embed_dim)
210 |
211 | del self.norm # remove the original norm
212 |
213 | # remove cls token embedding
214 | # delattr(self, 'cls_token')
215 |
216 | num_patches = self.patch_embed.num_patches
217 | if self.global_pool:
218 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim), requires_grad=False)
219 | else:
220 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
221 | self.cls_pos_embed = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=False)
222 |
223 | def forward_features(self, x):
224 | B = x.shape[0]
225 | x = self.patch_embed(x)
226 |
227 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
228 | x = torch.cat((cls_tokens, x), dim=1)
229 | x = x + self.pos_embed
230 | x = self.pos_drop(x)
231 |
232 | for blk in self.blocks:
233 | x = blk(x)
234 |
235 | outcome = x
236 |
237 | return outcome
238 |
239 | def forward_head(self, x, pre_logits: bool = False):
240 | if self.global_pool:
241 | x = x[:, 1:, :].mean(dim=1)
242 | else:
243 | x = x[:, 0]
244 | x = self.fc_norm(x)
245 | return x if pre_logits else self.head(x)
246 |
247 |
248 | def vit_base_patch16(**kwargs):
249 | model = VisionTransformer(
250 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
251 | norm_layer=partial(LayerNorm, eps=1e-6), **kwargs)
252 | return model
253 |
254 |
255 | def vit_large_patch16(**kwargs):
256 | model = VisionTransformer(
257 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
258 | norm_layer=partial(LayerNorm, eps=1e-6), **kwargs)
259 | return model
260 |
261 |
262 | def vit_huge_patch14(**kwargs):
263 | model = VisionTransformer(
264 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
265 | norm_layer=partial(LayerNorm, eps=1e-6), **kwargs)
266 | return model
267 |
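# Usage sketch (illustrative): this mirrors how main_linprobe.py builds the classifier;
# init_values and drop_path_rate must be passed because VisionTransformer.__init__ reads them.
#   model = vit_base_patch16(num_classes=1000, global_pool=True, init_values=None, drop_path_rate=0.0)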
--------------------------------------------------------------------------------
/main_linprobe.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 | # References:
9 | # MoCo v3: https://github.com/facebookresearch/moco-v3
10 | # DeiT: https://github.com/facebookresearch/deit
11 | # ------------------------------------------------------------------------
12 |
13 |
14 | import argparse
15 | import datetime
16 | import json
17 | import numpy as np
18 | import os
19 | import time
20 | from pathlib import Path
21 |
22 | import torch
23 | import torch.backends.cudnn as cudnn
24 | from torch.utils.tensorboard import SummaryWriter
25 | import torchvision.transforms as transforms
26 | import torchvision.datasets as datasets
27 |
28 | import timm
29 |
30 | assert timm.__version__ == "0.6.12" # version check
31 | from timm.models.layers import trunc_normal_
32 |
33 | import util.misc as misc
34 | from util.pos_embed import interpolate_pos_embed
35 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
36 | from util.lars import LARS
37 | from util.crop import RandomResizedCrop
38 | import models_vit
39 | from engine_finetune import train_one_epoch, evaluate
40 |
41 |
42 | def get_args_parser():
43 | parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False)
44 | parser.add_argument('--batch_size', default=512, type=int,
45 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
46 | parser.add_argument('--epochs', default=90, type=int)
47 | parser.add_argument('--accum_iter', default=1, type=int,
48 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
49 |
50 | # Model parameters
51 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
52 | help='Name of model to train')
53 |
54 | # Optimizer parameters
55 | parser.add_argument('--weight_decay', type=float, default=0,
56 | help='weight decay (default: 0 for linear probe following MoCo v1)')
57 |
58 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
59 | help='learning rate (absolute lr)')
60 | parser.add_argument('--blr', type=float, default=0.1, metavar='LR',
61 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
62 |
63 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
64 | help='lower lr bound for cyclic schedulers that hit 0')
65 |
66 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
67 | help='epochs to warmup LR')
68 |
69 | # * Finetuning params
70 | parser.add_argument('--finetune', default='',
71 | help='finetune from checkpoint')
72 | parser.add_argument('--global_pool', action='store_true')
73 | parser.set_defaults(global_pool=False)
74 | parser.add_argument('--cls_token', action='store_false', dest='global_pool',
75 | help='Use class token instead of global pool for classification')
76 |
77 | # Dataset parameters
78 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
79 | help='dataset path')
80 | parser.add_argument('--nb_classes', default=1000, type=int,
81 | help='number of the classification types')
82 | parser.add_argument('--use_tcs_dataset', default=False, action='store_true')
83 |
84 | parser.add_argument('--output_dir', default='./output_dir',
85 | help='path where to save, empty for no saving')
86 | parser.add_argument('--log_dir', default='./output_dir',
87 | help='path where to tensorboard log')
88 | parser.add_argument('--device', default='cuda',
89 | help='device to use for training / testing')
90 | parser.add_argument('--seed', default=0, type=int)
91 | parser.add_argument('--resume', default='',
92 | help='resume from checkpoint')
93 |
94 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
95 | help='start epoch')
96 | parser.add_argument('--eval', action='store_true',
97 | help='Perform evaluation only')
98 | parser.add_argument('--dist_eval', action='store_true', default=False,
99 | help='Enable distributed evaluation (recommended during training for faster monitoring)')
100 | parser.add_argument('--num_workers', default=10, type=int)
101 | parser.add_argument('--pin_mem', action='store_true',
102 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
103 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
104 | # parser.set_defaults(pin_mem=True)
105 |
106 | # distributed training parameters
107 | parser.add_argument('--world_size', default=1, type=int,
108 | help='number of distributed processes')
109 | parser.add_argument('--local_rank', default=-1, type=int)
110 | parser.add_argument('--dist_on_itp', action='store_true')
111 | parser.add_argument('--dist_url', default='env://',
112 | help='url used to set up distributed training')
113 |
114 | parser.add_argument('--auto_resume', action='store_true', default=True)
115 | parser.add_argument('--init_values', default=1.0, type=float)
116 |
117 | return parser
118 |
119 |
120 | def main(args):
121 | misc.init_distributed_mode(args)
122 |
123 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
124 | print("{}".format(args).replace(', ', ',\n'))
125 |
126 | device = torch.device(args.device)
127 |
128 | # fix the seed for reproducibility
129 | seed = args.seed + misc.get_rank()
130 | torch.manual_seed(seed)
131 | np.random.seed(seed)
132 |
133 | cudnn.benchmark = True
134 | torch.backends.cuda.matmul.allow_tf32 = False
135 | torch.backends.cudnn.allow_tf32 = False
136 |
137 | # linear probe: weak augmentation
138 | transform_train = transforms.Compose([
139 | RandomResizedCrop(224, interpolation=3),
140 | transforms.RandomHorizontalFlip(),
141 | transforms.ToTensor(),
142 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
143 | transform_val = transforms.Compose([
144 | transforms.Resize(256, interpolation=3),
145 | transforms.CenterCrop(224),
146 | transforms.ToTensor(),
147 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
148 | if not args.use_tcs_dataset:
149 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
150 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val)
151 | else: # for internal use only
152 | from util.tcs_datasets import ImagenetTCSDataset
153 | dataset_train = ImagenetTCSDataset(
154 | 'train',
155 | 's3://imagenet',
156 | use_tcs=True,
157 | transform=transform_train)
158 | dataset_val = ImagenetTCSDataset(
159 | 'val',
160 | 's3://imagenet',
161 | use_tcs=True,
162 | transform=transform_val)
163 | print(dataset_train)
164 | print(dataset_val)
165 |
166 | # build dataloader
167 | if True: # args.distributed:
168 | num_tasks = misc.get_world_size()
169 | global_rank = misc.get_rank()
170 | sampler_train = torch.utils.data.DistributedSampler(
171 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
172 | )
173 | print("Sampler_train = %s" % str(sampler_train))
174 | if args.dist_eval:
175 | if len(dataset_val) % num_tasks != 0:
176 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
177 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
178 | 'equal num of samples per-process.')
179 | sampler_val = torch.utils.data.DistributedSampler(
180 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias
181 | else:
182 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
183 | else:
184 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
185 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
186 |
187 | data_loader_train = torch.utils.data.DataLoader(
188 | dataset_train, sampler=sampler_train,
189 | batch_size=args.batch_size,
190 | num_workers=args.num_workers,
191 | pin_memory=args.pin_mem,
192 | drop_last=True,
193 | )
194 |
195 | data_loader_val = torch.utils.data.DataLoader(
196 | dataset_val, sampler=sampler_val,
197 | batch_size=args.batch_size,
198 | num_workers=args.num_workers,
199 | pin_memory=args.pin_mem,
200 | drop_last=False
201 | )
202 |
203 | if global_rank == 0 and args.log_dir is not None and not args.eval:
204 | os.makedirs(args.log_dir, exist_ok=True)
205 | log_writer = SummaryWriter(log_dir=args.log_dir)
206 | else:
207 | log_writer = None
208 |
209 | # build model
210 | model = models_vit.__dict__[args.model](
211 | num_classes=args.nb_classes,
212 | global_pool=args.global_pool,
213 | init_values=args.init_values if args.init_values != 1.0 else None,
214 | drop_path_rate=0.0
215 | )
216 |
217 | # load ckpt
218 | if args.finetune and not args.eval:
219 | checkpoint = torch.load(args.finetune, map_location='cpu')
220 |
221 | print("Load pre-trained checkpoint from: %s" % args.finetune)
222 | checkpoint_model = checkpoint['model']
223 |
224 | state_dict = model.state_dict()
225 | for k in ['head.weight', 'head.bias']:
226 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
227 | print(f"Removing key {k} from pretrained checkpoint")
228 | del checkpoint_model[k]
229 |
230 | # interpolate position embedding
231 | interpolate_pos_embed(model, checkpoint_model)
232 |
233 | # load pre-trained model
234 | msg = model.load_state_dict(checkpoint_model, strict=False)
235 | print(msg)
236 |
237 | if args.global_pool:
238 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
239 | else:
240 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
241 |
242 | # manually initialize fc layer: following MoCo v3
243 | trunc_normal_(model.head.weight, std=0.01)
244 |
245 | # for linear prob only
246 | # hack: revise model's head with BN
247 | # model.bn = torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6)
248 | model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head)
249 | # freeze all but the head
250 | for _, p in model.named_parameters():
251 | p.requires_grad = False
252 | for _, p in model.head.named_parameters():
253 | p.requires_grad = True
254 |
255 | model.to(device)
256 |
257 | model_without_ddp = model
258 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
259 |
260 | print("Model = %s" % str(model_without_ddp))
261 | print('number of params (M): %.2f' % (n_parameters / 1.e6))
262 |
263 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
264 |
265 | if args.lr is None: # only base_lr is specified
266 | args.lr = args.blr * eff_batch_size / 256
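        # Worked example (illustrative): with the defaults --batch_size 512 and --blr 0.1 on a
        # single 8-GPU node with accum_iter=1, eff_batch_size = 512 * 1 * 8 = 4096, so the
        # absolute lr becomes 0.1 * 4096 / 256 = 1.6.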
267 |
268 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
269 | print("actual lr: %.2e" % args.lr)
270 |
271 | print("accumulate grad iterations: %d" % args.accum_iter)
272 | print("effective batch size: %d" % eff_batch_size)
273 |
274 | if args.distributed:
275 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
276 | model_without_ddp = model.module
277 |
278 | # build optimizer
279 | optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay)
280 | print(optimizer)
281 | loss_scaler = NativeScaler()
282 |
283 | criterion = torch.nn.CrossEntropyLoss()
284 |
285 | print("criterion = %s" % str(criterion))
286 |
287 | # misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
288 | misc.auto_load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
289 |
290 | if args.eval:
291 | test_stats = evaluate(data_loader_val, model, device)
292 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
293 | exit(0)
294 |
295 | # start training
296 | print(f"Start training for {args.epochs} epochs")
297 | start_time = time.time()
298 | max_accuracy = 0.0
299 | for epoch in range(args.start_epoch, args.epochs):
300 | if args.distributed:
301 | data_loader_train.sampler.set_epoch(epoch)
302 | train_stats = train_one_epoch(
303 | model, criterion, data_loader_train,
304 | optimizer, device, epoch, loss_scaler,
305 | max_norm=None,
306 | log_writer=log_writer,
307 | args=args
308 | )
309 |
310 | # save ckpt
311 | if args.output_dir:
312 | misc.save_model(
313 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
314 | loss_scaler=loss_scaler, epoch=epoch, latest=True)
315 |
316 | if (epoch+1)%1 == 0:
317 | test_stats = evaluate(data_loader_val, model, device)
318 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
319 | max_accuracy = max(max_accuracy, test_stats["acc1"])
320 | print(f'Max accuracy: {max_accuracy:.2f}%')
321 |
322 | if log_writer is not None:
323 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
324 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
325 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)
326 |
327 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
328 | **{f'test_{k}': v for k, v in test_stats.items()},
329 | 'epoch': epoch,
330 | 'n_parameters': n_parameters}
331 |
332 | if args.output_dir and misc.is_main_process():
333 | if log_writer is not None:
334 | log_writer.flush()
335 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
336 | f.write(json.dumps(log_stats) + "\n")
337 |
338 | total_time = time.time() - start_time
339 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
340 | print('Training time {}'.format(total_time_str))
341 |
342 |
343 | if __name__ == '__main__':
344 | args = get_args_parser()
345 | args = args.parse_args()
346 | if args.output_dir:
347 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
348 | main(args)
349 |
--------------------------------------------------------------------------------
/main_pretrain.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 | # References:
9 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
10 | # DeiT: https://github.com/facebookresearch/deit
11 | # ------------------------------------------------------------------------
12 |
13 | import argparse
14 | import datetime
15 | import json
16 | import numpy as np
17 | import os
18 | import time
19 | from pathlib import Path
20 |
21 | import torch
22 | import torch.distributed as dist
23 | import torch.backends.cudnn as cudnn
24 | from torch.utils.tensorboard import SummaryWriter
25 | import torchvision.transforms as transforms
26 | import torchvision.datasets as datasets
27 |
28 | import timm
29 | assert timm.__version__ == "0.6.12" # version check
30 | from timm.optim.optim_factory import param_groups_weight_decay
31 | from timm.optim import create_optimizer
32 |
33 | import util.misc as misc
34 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
35 | from util.augmentation import RandomResizedCrop, GaussianBlur, SingleRandomResizedCrop, RandomHorizontalFlip, Solarize
36 | from util.datasets import ImagenetWithMask
37 | import models_sim
38 | from engine_pretrain import train_one_epoch
39 |
40 |
41 | def get_args_parser():
42 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
43 | parser.add_argument('--batch_size', default=64, type=int,
44 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
45 | parser.add_argument('--epochs', default=400, type=int)
46 | parser.add_argument('--accum_iter', default=1, type=int,
47 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
48 |
49 | # Model parameters
50 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL',
51 | help='Name of model to train')
52 |
53 | parser.add_argument('--input_size', default=224, type=int,
54 | help='images input size')
55 |
56 | parser.add_argument('--mask_ratio', default=0.75, type=float,
57 | help='Masking ratio (percentage of removed patches).')
58 |
59 | parser.add_argument('--norm_pix_loss', action='store_true',
60 | help='Use (per-patch) normalized pixels as targets for computing loss')
61 | parser.set_defaults(norm_pix_loss=False)
62 |
63 | parser.add_argument('--use_abs_pos_emb', default=True, action='store_true')
64 | parser.add_argument('--disable_abs_pos_emb', dest='use_abs_pos_emb', action='store_false')
65 | parser.add_argument('--use_shared_rel_pos_bias', default=False, action='store_true')
66 |
67 | # Optimizer parameters
68 | parser.add_argument('--weight_decay', type=float, default=0.05,
69 | help='weight decay (default: 0.05)')
70 |
71 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
72 | help='learning rate (absolute lr)')
73 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
74 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
75 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
76 | help='lower lr bound for cyclic schedulers that hit 0')
77 |
78 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
79 | help='epochs to warmup LR')
80 |
81 | # Dataset parameters
82 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
83 | help='dataset path')
84 |
85 | parser.add_argument('--output_dir', default='./output_dir',
86 | help='path where to save, empty for no saving')
87 | parser.add_argument('--log_dir', default='./output_dir',
88 | help='path where to tensorboard log')
89 | parser.add_argument('--device', default='cuda',
90 | help='device to use for training / testing')
91 | parser.add_argument('--seed', default=0, type=int)
92 | parser.add_argument('--resume', default='',
93 | help='resume from checkpoint')
94 |
95 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
96 | help='start epoch')
97 | parser.add_argument('--num_workers', default=10, type=int)
98 | parser.add_argument('--pin_mem', action='store_true',
99 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
100 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
101 | # parser.set_defaults(pin_mem=True)
102 |
103 | # distributed training parameters
104 | parser.add_argument('--world_size', default=1, type=int,
105 | help='number of distributed processes')
106 | parser.add_argument('--local_rank', default=-1, type=int)
107 | parser.add_argument('--dist_on_itp', action='store_true')
108 | parser.add_argument('--dist_url', default='env://',
109 | help='url used to set up distributed training')
110 |
111 | # SiameseIM parameters
112 | # data
113 | parser.add_argument('--crop_min', default=0.2, type=float)
114 | parser.add_argument('--use_tcs_dataset', default=False, action='store_true')
115 |
116 | # model
117 | parser.add_argument('--decoder_embed_dim', default=512, type=int)
118 | parser.add_argument('--drop_path_rate', default=0.0, type=float)
119 | parser.add_argument('--init_values', default=None, type=float)
120 | parser.add_argument('--projector_depth', default=2, type=int)
121 | parser.add_argument('--predictor_depth', default=4, type=int)
122 | parser.add_argument('--use_proj_ln', default=False, action='store_true')
123 | parser.add_argument('--use_pred_ln', default=False, action='store_true')
124 | parser.add_argument('--train_patch_embed', default=False, action='store_true')
125 | parser.add_argument('--online_ln', default=False, action='store_true', help='also use frozen LN in online branch')
126 |
127 | parser.add_argument('--loss_type', default='mae')
128 | parser.add_argument('--neg_weight', default=0.02, type=float)
129 |
130 | parser.add_argument('--with_blockwise_mask', default=False, action='store_true')
131 | parser.add_argument('--blockwise_num_masking_patches', default=75, type=int)
132 |
133 | # hyper-parameter
134 | parser.add_argument('--mm', default=0.996, type=float)
135 | parser.add_argument('--mmschedule', default='const')
136 | parser.add_argument('--lambda_F', default=50, type=float) # may not be needed
137 | parser.add_argument('--T', default=0.2, type=float) # check
138 | parser.add_argument('--clip_grad', default=None, type=float)
139 | parser.add_argument('--beta2', default=0.95, type=float)
140 |
141 | # misc
142 | parser.add_argument('--auto_resume', default=True)
143 | parser.add_argument('--save_freq', default=50, type=int)
144 | parser.add_argument('--save_latest_freq', default=1, type=int)
145 | parser.add_argument('--fp32', default=False, action='store_true')
146 | parser.add_argument('--amp_growth_interval', default=2000, type=int)
147 |
148 | return parser
149 |
150 |
151 | class DataAugmentationForSIM(object):
152 | def __init__(self, args):
153 | self.args = args
154 |
155 | self.random_resized_crop = SingleRandomResizedCrop(args.input_size, scale=(args.crop_min, 1.0), interpolation=3)
156 | self.random_flip = RandomHorizontalFlip()
157 |
158 | self.color_transform1 = transforms.Compose([
159 | transforms.RandomApply([
160 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened
161 | ], p=0.8),
162 | transforms.RandomGrayscale(p=0.2),
163 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=1.0),
164 | ])
165 |
166 | self.color_transform2 = transforms.Compose([
167 | transforms.RandomApply([
168 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened
169 | ], p=0.8),
170 | transforms.RandomGrayscale(p=0.2),
171 | transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1),
172 | transforms.RandomApply([Solarize()], p=0.2),
173 | ])
174 |
175 | self.format_transform = transforms.Compose([
176 | transforms.ToTensor(),
177 | transforms.Normalize(
178 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
179 | ])
180 |
181 | def __call__(self, image):
182 | spatial_image1, flip1 = self.random_flip(image)
183 | spatial_image2, flip2 = self.random_flip(image)
184 | spatial_image1, i1, j1, h1, w1, W = self.random_resized_crop(spatial_image1)
185 | spatial_image2, i2, j2, h2, w2, W = self.random_resized_crop(spatial_image2)
186 | color_image1 = self.color_transform1(spatial_image1)
187 | color_image2 = self.color_transform2(spatial_image2)
188 |
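        # geometry of view 2 relative to view 1: offsets and scales are normalized by view 1's crop size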
189 | relative_flip = (flip1 and not flip2) or (flip2 and not flip1)
190 | return self.format_transform(color_image1), self.format_transform(color_image2), \
191 | (i2-i1)/h1, (j2-j1)/w1, h2/h1, w2/w1, relative_flip, (W-j1-j2)/w1
192 |
193 | def __repr__(self):
194 |         repr_str = "(DataAugmentationForSIM,\n"
195 |         repr_str += " transform = %s,\n" % (str(self.random_resized_crop) + str(self.random_flip) + str(self.color_transform1) + str(self.format_transform))
196 |         repr_str += ")"
197 |         return repr_str
198 |
199 |
200 | def main(args):
201 | misc.init_distributed_mode(args) # need change to torch.engine
202 |
203 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
204 | print("{}".format(args).replace(', ', ',\n'))
205 |
206 | device = torch.device(args.device)
207 |
208 | # fix the seed for reproducibility
209 | seed = args.seed + misc.get_rank()
210 | torch.manual_seed(seed)
211 | np.random.seed(seed)
212 |
213 | cudnn.benchmark = True
214 |
215 | # disable tf32
216 | torch.backends.cuda.matmul.allow_tf32 = False
217 | torch.backends.cudnn.allow_tf32 = False
218 |
219 | # build augmentation and dataset
220 | if args.loss_type in ['sim']:
221 | transform_train = DataAugmentationForSIM(args)
222 | else:
223 | transform_train = transforms.Compose([
224 | transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic
225 | transforms.RandomHorizontalFlip(),
226 | transforms.ToTensor(),
227 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
228 | if not args.use_tcs_dataset:
229 |         dataset_train = ImagenetWithMask(
230 |             os.path.join(args.data_path, 'train'),
231 |             transform=transform_train,
232 |             with_blockwise_mask=args.with_blockwise_mask,
233 |             blockwise_num_masking_patches=args.blockwise_num_masking_patches)
234 | else: # for internal use only
235 | from util.tcs_datasets import ImagenetTCSDataset
236 | dataset_train = ImagenetTCSDataset('train',
237 | 's3://imagenet',
238 | use_tcs=True,
239 | transform=transform_train,
240 | with_blockwise_mask=args.with_blockwise_mask,
241 | blockwise_num_masking_patches=args.blockwise_num_masking_patches,
242 | local_rank=int(os.environ['LOCAL_RANK']),
243 | local_size=int(os.environ['LOCAL_SIZE']),
244 | tcs_conf_path='./petreloss.conf')
245 | print(dataset_train)
246 |
247 | # build dataloader
248 | if True: # args.distributed:
249 | num_tasks = misc.get_world_size()
250 | global_rank = misc.get_rank()
251 | sampler_train = torch.utils.data.DistributedSampler(
252 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
253 | )
254 | print("Sampler_train = %s" % str(sampler_train))
255 | else:
256 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
257 |
258 | if global_rank == 0 and args.log_dir is not None:
259 | os.makedirs(args.log_dir, exist_ok=True)
260 | log_writer = SummaryWriter(log_dir=args.log_dir)
261 | else:
262 | log_writer = None
263 |
264 | data_loader_train = torch.utils.data.DataLoader(
265 | dataset_train, sampler=sampler_train,
266 | batch_size=args.batch_size,
267 | num_workers=args.num_workers,
268 | pin_memory=args.pin_mem,
269 | drop_last=True,
270 | )
271 |
272 |
273 | # build model
274 | model = models_sim.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, args=args)
275 | model.to(device)
276 | model_without_ddp = model
277 |
278 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
279 | if args.lr is None: # only base_lr is specified
280 | args.lr = args.blr * eff_batch_size / 256
281 |
282 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
283 | print("actual lr: %.2e" % args.lr)
284 |
285 | print("accumulate grad iterations: %d" % args.accum_iter)
286 | print("effective batch size: %d" % eff_batch_size)
287 |
288 | if args.distributed:
289 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
290 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
291 | model_without_ddp = model.module
292 | print("Model = %s" % str(model_without_ddp))
293 |
294 | # build optimizer
295 | # following timm: set wd as 0 for bias and norm layers
296 | param_groups = param_groups_weight_decay(model_without_ddp, args.weight_decay)
297 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, args.beta2))
298 | print(optimizer)
299 | loss_scaler = NativeScaler(enabled=(not args.fp32), growth_interval=args.amp_growth_interval)
300 |
301 | misc.auto_load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,
302 | loss_scaler=loss_scaler)
303 |
304 | # start training
305 | print(f"Start training for {args.epochs} epochs")
306 | start_time = time.time()
307 | for epoch in range(args.start_epoch, args.epochs):
308 | epoch_start_time = time.time()
309 | if args.distributed:
310 | data_loader_train.sampler.set_epoch(epoch)
311 | train_stats = train_one_epoch(
312 | model, data_loader_train,
313 | optimizer, device, epoch, loss_scaler,
314 | log_writer=log_writer,
315 | args=args
316 | )
317 | dist.barrier()
318 |
319 | # save ckpt
320 | if args.output_dir and ((epoch+1) % args.save_freq == 0 or epoch + 1 == args.epochs):
321 | misc.save_model(
322 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
323 | loss_scaler=loss_scaler, epoch=epoch)
324 | if (epoch+1) % args.save_latest_freq == 0:
325 | misc.save_model(
326 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
327 | loss_scaler=loss_scaler, epoch=epoch, latest=True)
328 |
329 | # log information
330 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
331 | 'epoch': epoch,}
332 |
333 | if args.output_dir and misc.is_main_process():
334 | if log_writer is not None:
335 | log_writer.flush()
336 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
337 | f.write(json.dumps(log_stats) + "\n")
338 | if misc.is_main_process():
339 | epoch_total_time = time.time() - epoch_start_time
340 | now = datetime.datetime.today()
341 | eta = now + datetime.timedelta(seconds=(args.epochs-epoch-1)*int(epoch_total_time))
342 | next_50_ep = ((epoch + 1) // 50 + 1) * 50
343 |             eta_to_next_50 = now + datetime.timedelta(seconds=(next_50_ep - epoch - 1) * int(epoch_total_time))
344 | print(f"ETA to {args.epochs:4d}ep:\t{str(eta)}")
345 | print(f"ETA to {next_50_ep:4d}ep:\t{str(eta_to_next_50)}")
346 | dist.barrier()
347 |
348 | total_time = time.time() - start_time
349 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
350 | print('Training time {}'.format(total_time_str))
351 |
352 |
353 | if __name__ == '__main__':
354 | args = get_args_parser()
355 | args = args.parse_args()
356 | if args.output_dir:
357 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
358 | main(args)
359 |
--------------------------------------------------------------------------------
/main_logistic.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MSN (https://github.com/facebookresearch/msn)
6 | # Copyright (c) Facebook, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 |
9 |
10 | import os
11 | import argparse
12 | import logging
13 | import pprint
14 |
15 | import numpy as np
16 | import torch
17 | import torchvision.transforms as transforms
18 | import cyanure as cyan
19 |
20 |
21 | logging.basicConfig()
22 | logger = logging.getLogger()
23 | logger.setLevel(logging.INFO)
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument(
27 | '--lambd', type=float,
28 | default=0.00025,
29 | help='regularization')
30 | parser.add_argument(
31 | '--penalty', type=str,
32 | help='regularization for logistic classifier',
33 | default='l2',
34 | choices=[
35 | 'l2',
36 | 'elastic-net'
37 | ])
38 | parser.add_argument(
39 | '--mask', type=float,
40 | default=0.0,
41 |     help='masking fraction')
42 | parser.add_argument(
43 | '--preload', action='store_true',
44 | help='whether to preload embs if possible')
45 | parser.add_argument(
46 | '--fname', type=str,
47 |     help='filename tag used for cached embedding files')
48 | parser.add_argument(
49 | '--model-name', type=str,
50 | help='model architecture')
51 | parser.add_argument(
52 | '--pretrained', type=str,
53 | help='path to pretrained model',
54 | default='')
55 | parser.add_argument(
56 | '--device', type=str,
57 | default='cuda:0',
58 | help='device to run script on')
59 | parser.add_argument(
60 | '--normalize', type=bool,
61 | default=True,
62 |     help='whether to standardize images before feeding to network')
63 | parser.add_argument(
64 | '--root-path', type=str,
65 | default='/datasets/',
66 | help='root directory to data')
67 | parser.add_argument(
68 | '--image-folder', type=str,
69 | default='imagenet_full_size/061417/',
70 | help='image directory inside root_path')
71 | parser.add_argument(
72 | '--subset-path', type=str,
73 | default=None,
74 |     help='path to a text file listing the training image subset')
75 | parser.add_argument('--local_rank', default=-1, type=int)
76 |
77 | logging.basicConfig()
78 | logger = logging.getLogger()
79 | logger.setLevel(logging.INFO)
80 |
81 | _GLOBAL_SEED = 0
82 | np.random.seed(_GLOBAL_SEED)
83 | torch.manual_seed(_GLOBAL_SEED)
84 | torch.backends.cudnn.benchmark = True
85 |
86 | pp = pprint.PrettyPrinter(indent=4)
87 |
88 |
89 | def main(
90 | blocks,
91 | lambd,
92 | mask_frac,
93 | preload,
94 | pretrained,
95 | fname,
96 | subset_path,
97 | root_path,
98 | image_folder,
99 | penalty='l2',
100 | model_name=None,
101 | normalize=True,
102 | device_str='cuda:0',
103 | args=None
104 | ):
105 | init_distributed_mode(args)
106 | # torch.cuda.set_device(args.rank)
107 | # device = torch.device('cuda')
108 | # device = torch.device(device_str)
109 | # if 'cuda' in device_str:
110 | # torch.cuda.set_device(device)
111 |
112 | # -- Define file names used to save computed embeddings (for efficient
113 | # -- reuse if running the script more than once)
114 |     subset_tag = '-'.join(subset_path.split('/')).split('.txt')[0] if subset_path is not None else 'imagenet_subset1-100percent'
115 | train_embs_path = f'train-features-{subset_tag}-{fname}'
116 | test_embs_path = f'val-features-{fname}'
117 | logger.info(train_embs_path)
118 | logger.info(test_embs_path)
119 |
120 | # pretrained = os.path.join(pretrained, fname)
121 |
122 | # -- Function to make train/test dataloader
123 | def init_pipe(training):
124 | # -- make data transforms
125 | transform = transforms.Compose([
126 | transforms.Resize(size=256),
127 | transforms.CenterCrop(size=224),
128 | transforms.ToTensor(),
129 | transforms.Normalize(
130 | (0.485, 0.456, 0.406),
131 | (0.229, 0.224, 0.225))])
132 | # -- init data-loaders/samplers
133 | subset_file = subset_path if training else None
134 | data_loader, _ = init_data(
135 | transform=transform,
136 | batch_size=64,
137 | num_workers=0,
138 | world_size=args.world_size,
139 | rank=args.rank,
140 | root_path=root_path,
141 | image_folder=image_folder,
142 | training=training,
143 | copy_data=False,
144 | drop_last=False,
145 | subset_file=subset_file)
146 | return data_loader
147 |
148 | # -- Initialize the model
149 | encoder = init_model(
150 | # device=device,
151 | pretrained=pretrained,
152 | model_name=model_name)
153 | encoder.eval()
154 |
155 | # -- If train embeddings already computed, load file, otherwise, compute
156 | # -- embeddings and save
157 | if preload and os.path.exists(train_embs_path):
158 | checkpoint = torch.load(train_embs_path, map_location='cpu')
159 | embs, labs = checkpoint['embs'], checkpoint['labs']
160 | logger.info(f'loaded embs of shape {embs.shape}')
161 | else:
162 | data_loader = init_pipe(True)
163 | embs, labs = make_embeddings(
164 | blocks=blocks,
165 | # device=device,
166 | mask_frac=mask_frac,
167 | data_loader=data_loader,
168 | encoder=encoder)
169 | torch.save({
170 | 'embs': embs,
171 | 'labs': labs
172 | }, train_embs_path)
173 | logger.info(f'saved train embs of shape {embs.shape}')
174 |     # -- Normalize embeddings
175 | cyan.preprocess(embs, normalize=normalize, columns=False, centering=True)
176 |
177 | # import pdb; pdb.set_trace()
178 |
179 | # -- Fit Logistic Regression Classifier
180 | classifier = cyan.MultiClassifier(loss='multiclass-logistic', penalty=penalty, fit_intercept=False)
181 | lambd /= len(embs)
182 | classifier.fit(
183 | embs.numpy(),
184 | labs.numpy(),
185 | it0=10,
186 | lambd=lambd,
187 | lambd2=lambd,
188 | nthreads=-1,
189 | tol=1e-3,
190 | solver='auto',
191 | seed=0,
192 | max_epochs=300)
193 |
194 | # -- Evaluate and log
195 | train_score = classifier.score(embs.numpy(), labs.numpy())
196 | # -- (save train score)
197 | logger.info(f'train score: {train_score}')
198 |
199 | # -- If test embeddings already computed, load file, otherwise, compute
200 | # -- embeddings and save
201 | if preload and os.path.exists(test_embs_path):
202 | checkpoint = torch.load(test_embs_path, map_location='cpu')
203 | test_embs, test_labs = checkpoint['embs'], checkpoint['labs']
204 | logger.info(f'loaded test embs of shape {test_embs.shape}')
205 | else:
206 | data_loader = init_pipe(False)
207 | test_embs, test_labs = make_embeddings(
208 | blocks=blocks,
209 | # device=device,
210 | mask_frac=0.0,
211 | data_loader=data_loader,
212 | encoder=encoder)
213 | torch.save({
214 | 'embs': test_embs,
215 | 'labs': test_labs
216 | }, test_embs_path)
217 | logger.info(f'saved test embs of shape {test_embs.shape}')
218 | # -- Normalize embeddings
219 | cyan.preprocess(test_embs, normalize=normalize, columns=False, centering=True)
220 |
221 | # -- Evaluate and log
222 | test_score = classifier.score(test_embs.numpy(), test_labs.numpy())
223 | # -- (save test score)
224 | logger.info(f'test score: {test_score}\n\n')
225 |
226 | return test_score
227 |
228 |
229 | def init_distributed_mode(args):
230 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
231 | args.rank = int(os.environ["RANK"])
232 | args.world_size = int(os.environ['WORLD_SIZE'])
233 | args.gpu = int(os.environ['LOCAL_RANK'])
234 | elif 'SLURM_PROCID' in os.environ:
235 | args.rank = int(os.environ['SLURM_PROCID'])
236 | args.world_size = int(os.environ['SLURM_NTASKS'])
237 | node_list = os.environ['SLURM_NODELIST']
238 | num_gpus = torch.cuda.device_count()
239 | args.gpu = args.rank % torch.cuda.device_count()
240 | torch.cuda.set_device(args.rank % num_gpus)
241 | import subprocess
242 | addr = subprocess.getoutput(
243 | f'scontrol show hostname {node_list} | head -n1')
244 | # specify master port
245 | if hasattr(args, 'port'):
246 | os.environ['MASTER_PORT'] = str(args.port)
247 | elif 'MASTER_PORT' in os.environ:
248 | pass # use MASTER_PORT in the environment variable
249 | else:
250 | # 29500 is torch.distributed default port
251 | os.environ['MASTER_PORT'] = '29502'
252 | # use MASTER_ADDR in the environment variable if it already exists
253 | if 'MASTER_ADDR' not in os.environ:
254 | os.environ['MASTER_ADDR'] = addr
255 | os.environ['WORLD_SIZE'] = str(args.world_size)
256 | os.environ['LOCAL_RANK'] = str(args.rank % num_gpus)
257 | os.environ['RANK'] = str(args.rank)
258 | # dist.init_process_group(backend='nccl')
259 | else:
260 | print('Not using distributed mode')
261 | setup_for_distributed(is_master=True) # hack
262 | args.distributed = False
263 | return
264 |
265 | args.distributed = True
266 |
267 | torch.cuda.set_device(args.gpu)
268 | args.dist_backend = 'nccl'
269 | args.dist_url = 'env://'
270 | print('| distributed init (rank {}): {}, gpu {}'.format(
271 | args.rank, args.dist_url, args.gpu), flush=True)
272 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
273 | world_size=args.world_size, rank=args.rank)
274 | torch.distributed.barrier()
275 | # setup_for_distributed(args.rank == 0)
276 |
277 |
278 | def init_data(
279 | transform,
280 | batch_size,
281 | pin_mem=True,
282 | num_workers=8,
283 | world_size=1,
284 | rank=0,
285 | root_path=None,
286 | image_folder=None,
287 | training=True,
288 | copy_data=False,
289 | drop_last=True,
290 | subset_file=None
291 | ):
292 |
293 | # dataset = ImageNet(
294 | # root=root_path,
295 | # image_folder=image_folder,
296 | # transform=transform,
297 | # train=training,
298 | # copy_data=copy_data)
299 | # if subset_file is not None:
300 | # dataset = ImageNetSubset(dataset, subset_file)
301 | import torchvision
302 | if training:
303 | dataset = torchvision.datasets.ImageFolder(os.path.join(root_path, 'train'), transform=transform)
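        # restrict training data to the images listed in the subset file (ImageNet file names start with their class id)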
304 | with open(subset_file) as subset_file:
305 | list_imgs = [li.split('\n')[0] for li in subset_file.readlines()]
306 | dataset.samples = [(
307 | os.path.join(os.path.join(root_path, 'train'), li.split('_')[0], li),
308 | dataset.class_to_idx[li.split('_')[0]]
309 | ) for li in list_imgs]
310 | else:
311 | dataset = torchvision.datasets.ImageFolder(os.path.join(root_path, 'val'), transform=transform)
312 |
313 | logger.info('ImageNet dataset created')
314 | dist_sampler = torch.utils.data.distributed.DistributedSampler(
315 | dataset=dataset,
316 | num_replicas=world_size,
317 | rank=rank)
318 | data_loader = torch.utils.data.DataLoader(
319 | dataset,
320 | sampler=dist_sampler,
321 | batch_size=batch_size,
322 | drop_last=drop_last,
323 | pin_memory=pin_mem,
324 | num_workers=num_workers)
325 | logger.info('ImageNet unsupervised data loader created')
326 |
327 | return (data_loader, dist_sampler)
328 |
329 |
330 | def make_embeddings(
331 | blocks,
332 | # device,
333 | mask_frac,
334 | data_loader,
335 | encoder,
336 | epochs=1
337 | ):
338 | ipe = len(data_loader)
339 |
340 | z_mem, l_mem = [], []
341 |
342 | for _ in range(epochs):
343 | for itr, (imgs, labels) in enumerate(data_loader):
344 | imgs = imgs.cuda()
345 | with torch.no_grad():
346 | z = encoder.forward_features(imgs)[:, 0].cpu()
347 | labels = labels.cpu()
348 | z_mem.append(z)
349 | l_mem.append(labels)
350 | if itr % 50 == 0:
351 | logger.info(f'[{itr}/{ipe}]')
352 |
353 | z_mem = torch.cat(z_mem, 0)
354 | l_mem = torch.cat(l_mem, 0)
355 | z_mem = all_gather(z_mem)
356 | z_mem = torch.cat(z_mem, 0)
357 | l_mem = all_gather(l_mem)
358 | l_mem = torch.cat(l_mem, 0)
359 | logger.info(z_mem.shape)
360 | logger.info(l_mem.shape)
361 |
362 | return z_mem, l_mem
363 |
364 |
365 | def all_gather(data):
366 | """
367 | Run all_gather on arbitrary picklable data (not necessarily tensors)
368 | Args:
369 | data: any picklable object
370 | Returns:
371 | list[data]: list of data gathered from each rank
372 | """
373 | world_size = torch.distributed.get_world_size()
374 | if world_size == 1:
375 | return [data]
376 |
377 | # serialized to a Tensor
378 | import pickle
379 | buffer = pickle.dumps(data)
380 | storage = torch.ByteStorage.from_buffer(buffer)
381 | tensor = torch.ByteTensor(storage).to("cuda")
382 |
383 | # obtain Tensor size of each rank
384 | local_size = torch.LongTensor([tensor.numel()]).to("cuda")
385 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
386 | torch.distributed.all_gather(size_list, local_size)
387 | size_list = [int(size.item()) for size in size_list]
388 | max_size = max(size_list)
389 |
390 | # receiving Tensor from all ranks
391 | # we pad the tensor because torch all_gather does not support
392 | # gathering tensors of different shapes
393 | tensor_list = []
394 | for _ in size_list:
395 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
396 | if local_size != max_size:
397 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
398 | tensor = torch.cat((tensor, padding), dim=0)
399 | torch.distributed.all_gather(tensor_list, tensor)
400 |
401 | data_list = []
402 | for size, tensor in zip(size_list, tensor_list):
403 | buffer = tensor.cpu().numpy().tobytes()[:size]
404 | data_list.append(pickle.loads(buffer))
405 |
406 | return data_list
407 |
408 |
409 | def load_pretrained(
410 | encoder,
411 | pretrained
412 | ):
413 | checkpoint = torch.load(pretrained, map_location='cpu')
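    # strip the 'module.' prefix that DistributedDataParallel adds to parameter names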
414 | pretrained_dict = {k.replace('module.', ''): v for k, v in checkpoint['target_encoder'].items()}
415 | for k, v in encoder.state_dict().items():
416 | if k not in pretrained_dict:
417 | logger.info(f'key "{k}" could not be found in loaded state dict')
418 | elif pretrained_dict[k].shape != v.shape:
419 | logger.info(f'key "{k}" is of different shape in model and loaded state dict')
420 | pretrained_dict[k] = v
421 | msg = encoder.load_state_dict(pretrained_dict, strict=False)
422 | print(encoder)
423 | logger.info(f'loaded pretrained model with msg: {msg}')
424 | try:
425 | logger.info(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]} '
426 | f'path: {pretrained}')
427 | except Exception:
428 | pass
429 | del checkpoint
430 | return encoder
431 |
432 |
433 | def init_model(
434 | # device,
435 | pretrained,
436 | model_name,
437 | ):
438 | # encoder = deit.__dict__[model_name]()
439 | # encoder.fc = None
440 | # encoder.to(device)
441 | # encoder = load_pretrained(encoder=encoder, pretrained=pretrained)
442 |
443 | import models_vit
444 | model = models_vit.__dict__[model_name](
445 | num_classes=1000,
446 | global_pool=True,
447 | init_values=None,
448 | drop_path_rate=0.0
449 | )
450 |
451 | checkpoint = torch.load(pretrained, map_location='cpu')
452 |
453 | print("Load pre-trained checkpoint from: %s" % pretrained)
454 | checkpoint_model = checkpoint['model']
455 | state_dict = model.state_dict()
456 | for k in ['head.weight', 'head.bias']:
457 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
458 | print(f"Removing key {k} from pretrained checkpoint")
459 | del checkpoint_model[k]
460 |
461 | # load pre-trained model
462 | msg = model.load_state_dict(checkpoint_model, strict=False)
463 | print(msg)
464 |
465 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
466 |
467 | model.head = None
468 |
469 | model.cuda()
470 |
471 | return model
472 |
473 |
474 | if __name__ == '__main__':
475 | """'main' for launching script using params read from command line"""
476 | global args
477 | args = parser.parse_args()
478 | pp.pprint(args)
479 | main(
480 | blocks=1,
481 | lambd=args.lambd,
482 | penalty=args.penalty,
483 | mask_frac=args.mask,
484 | preload=args.preload,
485 | pretrained=args.pretrained,
486 | fname=args.fname,
487 | subset_path=args.subset_path,
488 | root_path=args.root_path,
489 | image_folder=args.image_folder,
490 | model_name=args.model_name,
491 | normalize=args.normalize,
492 | device_str=args.device,
493 | args=args
494 | )
495 |
--------------------------------------------------------------------------------
/main_finetune.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 | # References:
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # DeiT: https://github.com/facebookresearch/deit
11 | # ------------------------------------------------------------------------
12 |
13 |
14 | import argparse
15 | import datetime
16 | import json
17 | import numpy as np
18 | import os
19 | import time
20 | from pathlib import Path
21 |
22 | import torch
23 | import torch.backends.cudnn as cudnn
24 | from torch.utils.tensorboard import SummaryWriter
25 | import torchvision.datasets as datasets
26 |
27 | import timm
28 | assert timm.__version__ == "0.6.12" # version check
29 | from timm.models.layers import trunc_normal_
30 | from timm.data.mixup import Mixup
31 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
32 |
33 | import util.lr_decay as lrd
34 | import util.misc as misc
35 | from util.datasets import build_transform
36 | from util.pos_embed import interpolate_pos_embed
37 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
38 | import models_vit
39 | from engine_finetune import train_one_epoch, evaluate
40 |
41 |
42 | def get_args_parser():
43 | parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
44 | parser.add_argument('--batch_size', default=64, type=int,
45 |                         help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
46 | parser.add_argument('--epochs', default=50, type=int)
47 | parser.add_argument('--accum_iter', default=1, type=int,
48 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
49 |
50 | # Model parameters
51 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
52 | help='Name of model to train')
53 |
54 | parser.add_argument('--input_size', default=224, type=int,
55 | help='images input size')
56 |
57 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
58 | help='Drop path rate (default: 0.1)')
59 |
60 | # Optimizer parameters
61 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
62 | help='Clip gradient norm (default: None, no clipping)')
63 | parser.add_argument('--weight_decay', type=float, default=0.05,
64 | help='weight decay (default: 0.05)')
65 |
66 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
67 | help='learning rate (absolute lr)')
68 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
69 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
70 | parser.add_argument('--layer_decay', type=float, default=0.75,
71 | help='layer-wise lr decay from ELECTRA/BEiT')
72 |
73 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
74 | help='lower lr bound for cyclic schedulers that hit 0')
75 |
76 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
77 | help='epochs to warmup LR')
78 |
79 | # Augmentation parameters
80 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
81 | help='Color jitter factor (enabled only when not using Auto/RandAug)')
82 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
83 |                         help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
84 | parser.add_argument('--smoothing', type=float, default=0.1,
85 | help='Label smoothing (default: 0.1)')
86 |
87 | # * Random Erase params
88 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
89 | help='Random erase prob (default: 0.25)')
90 | parser.add_argument('--remode', type=str, default='pixel',
91 | help='Random erase mode (default: "pixel")')
92 | parser.add_argument('--recount', type=int, default=1,
93 | help='Random erase count (default: 1)')
94 | parser.add_argument('--resplit', action='store_true', default=False,
95 | help='Do not random erase first (clean) augmentation split')
96 |
97 | # * Mixup params
98 | parser.add_argument('--mixup', type=float, default=0,
99 | help='mixup alpha, mixup enabled if > 0.')
100 | parser.add_argument('--cutmix', type=float, default=0,
101 | help='cutmix alpha, cutmix enabled if > 0.')
102 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
103 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
104 | parser.add_argument('--mixup_prob', type=float, default=1.0,
105 | help='Probability of performing mixup or cutmix when either/both is enabled')
106 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
107 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
108 | parser.add_argument('--mixup_mode', type=str, default='batch',
109 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
110 |
111 | # * Finetuning params
112 | parser.add_argument('--finetune', default='',
113 | help='finetune from checkpoint')
114 | parser.add_argument('--global_pool', action='store_true')
115 | parser.set_defaults(global_pool=True)
116 | parser.add_argument('--cls_token', action='store_false', dest='global_pool',
117 | help='Use class token instead of global pool for classification')
118 |
119 | # Dataset parameters
120 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
121 | help='dataset path')
122 | parser.add_argument('--nb_classes', default=1000, type=int,
123 | help='number of the classification types')
124 | parser.add_argument('--use_tcs_dataset', default=False, action='store_true')
125 |
126 | parser.add_argument('--output_dir', default='./output_dir',
127 | help='path where to save, empty for no saving')
128 | parser.add_argument('--log_dir', default='./output_dir',
129 | help='path where to tensorboard log')
130 | parser.add_argument('--device', default='cuda',
131 | help='device to use for training / testing')
132 | parser.add_argument('--seed', default=0, type=int)
133 | parser.add_argument('--resume', default='',
134 | help='resume from checkpoint')
135 |
136 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
137 | help='start epoch')
138 | parser.add_argument('--eval', action='store_true',
139 | help='Perform evaluation only')
140 | parser.add_argument('--dist_eval', action='store_true', default=False,
141 |                         help='Enabling distributed evaluation (recommended during training for faster monitoring)')
142 | parser.add_argument('--num_workers', default=10, type=int)
143 | parser.add_argument('--pin_mem', action='store_true',
144 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
145 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
146 | parser.set_defaults(pin_mem=True)
147 |
148 | # distributed training parameters
149 | parser.add_argument('--world_size', default=1, type=int,
150 | help='number of distributed processes')
151 | parser.add_argument('--local_rank', default=-1, type=int)
152 | parser.add_argument('--dist_on_itp', action='store_true')
153 | parser.add_argument('--dist_url', default='env://',
154 | help='url used to set up distributed training')
155 |
156 | parser.add_argument('--auto_resume', action='store_true', default=True)
157 | parser.add_argument('--init_values', default=1.0, type=float)
158 |
159 | return parser
160 |
161 |
162 | def main(args):
163 | misc.init_distributed_mode(args)
164 |
165 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
166 | print("{}".format(args).replace(', ', ',\n'))
167 |
168 | device = torch.device(args.device)
169 |
170 | # fix the seed for reproducibility
171 | seed = args.seed + misc.get_rank()
172 | torch.manual_seed(seed)
173 | np.random.seed(seed)
174 |
175 | cudnn.benchmark = True
176 | torch.backends.cuda.matmul.allow_tf32 = False
177 | torch.backends.cudnn.allow_tf32 = False
178 |
179 | # build dataset
180 | transform_train = build_transform(is_train=True, args=args)
181 | transform_val = build_transform(is_train=False, args=args)
182 | if not args.use_tcs_dataset:
183 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
184 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val)
185 | else:
186 | from util.tcs_datasets import ImagenetTCSDataset
187 | dataset_train = ImagenetTCSDataset('train',
188 | 's3://imagenet',
189 | transform=transform_train,
190 | use_tcs=True)
191 | dataset_val = ImagenetTCSDataset('val',
192 | 's3://imagenet',
193 | transform=transform_val,
194 | use_tcs=True)
195 |
196 | print(dataset_train)
197 | print(dataset_val)
198 |
199 | if True: # args.distributed:
200 | num_tasks = misc.get_world_size()
201 | global_rank = misc.get_rank()
202 | sampler_train = torch.utils.data.DistributedSampler(
203 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
204 | )
205 | print("Sampler_train = %s" % str(sampler_train))
206 | if args.dist_eval:
207 | if len(dataset_val) % num_tasks != 0:
208 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
209 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
210 | 'equal num of samples per-process.')
211 | sampler_val = torch.utils.data.DistributedSampler(
212 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias
213 | else:
214 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
215 | else:
216 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
217 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
218 |
219 | if global_rank == 0 and args.log_dir is not None and not args.eval:
220 | os.makedirs(args.log_dir, exist_ok=True)
221 | log_writer = SummaryWriter(log_dir=args.log_dir)
222 | else:
223 | log_writer = None
224 |
225 | data_loader_train = torch.utils.data.DataLoader(
226 | dataset_train, sampler=sampler_train,
227 | batch_size=args.batch_size,
228 | num_workers=args.num_workers,
229 | pin_memory=args.pin_mem,
230 | drop_last=True,
231 | )
232 |
233 | data_loader_val = torch.utils.data.DataLoader(
234 | dataset_val, sampler=sampler_val,
235 | batch_size=args.batch_size,
236 | num_workers=args.num_workers,
237 | pin_memory=args.pin_mem,
238 | drop_last=False
239 | )
240 |
241 | # build mixup
242 | mixup_fn = None
243 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
244 | if mixup_active:
245 | print("Mixup is activated!")
246 | mixup_fn = Mixup(
247 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
248 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
249 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
250 |
251 | # build model
252 | model = models_vit.__dict__[args.model](
253 | num_classes=args.nb_classes,
254 | drop_path_rate=args.drop_path,
255 | global_pool=args.global_pool,
256 |         init_values=args.init_values if args.init_values != 1.0 else None,  # the default 1.0 is treated as disabled (None)
257 | )
258 |
259 | # load ckpt
260 | if args.finetune and not args.eval:
261 | checkpoint = torch.load(args.finetune, map_location='cpu')
262 | print("Load pre-trained checkpoint from: %s" % args.finetune)
263 | checkpoint_model = checkpoint['model']
264 | state_dict = model.state_dict()
265 | for k in ['head.weight', 'head.bias']:
266 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
267 | print(f"Removing key {k} from pretrained checkpoint")
268 | del checkpoint_model[k]
269 |
270 | # interpolate position embedding
271 | interpolate_pos_embed(model, checkpoint_model)
272 |
273 | # load pre-trained model
274 | msg = model.load_state_dict(checkpoint_model, strict=False)
275 | print(msg)
276 |
277 | if args.global_pool:
278 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
279 | else:
280 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
281 |
282 | # manually initialize fc layer
283 | if hasattr(model, 'head'):
284 | trunc_normal_(model.head.weight, std=2e-5)
285 |
286 | model.to(device)
287 |
288 | model_without_ddp = model
289 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
290 |
291 | print("Model = %s" % str(model_without_ddp))
292 | print('number of params (M): %.2f' % (n_parameters / 1.e6))
293 |
294 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
295 |
296 | if args.lr is None: # only base_lr is specified
297 | args.lr = args.blr * eff_batch_size / 256
298 |
299 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
300 | print("actual lr: %.2e" % args.lr)
301 |
302 | print("accumulate grad iterations: %d" % args.accum_iter)
303 | print("effective batch size: %d" % eff_batch_size)
304 |
305 | if args.distributed:
306 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
307 | model_without_ddp = model.module
308 |
309 | # build optimizer with layer-wise lr decay (lrd)
310 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
311 | no_weight_decay_list=model_without_ddp.no_weight_decay(),
312 | layer_decay=args.layer_decay
313 | )
314 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
315 | loss_scaler = NativeScaler()
316 |
317 | if mixup_fn is not None:
318 | # smoothing is handled with mixup label transform
319 | criterion = SoftTargetCrossEntropy()
320 | elif args.smoothing > 0.:
321 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
322 | else:
323 | criterion = torch.nn.CrossEntropyLoss()
324 |
325 | print("criterion = %s" % str(criterion))
326 |
327 | # misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
328 | misc.auto_load_model(
329 | args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
330 |
331 | if args.eval:
332 | test_stats = evaluate(data_loader_val, model, device)
333 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
334 | exit(0)
335 |
336 | # start training
337 | print(f"Start training for {args.epochs} epochs")
338 | start_time = time.time()
339 | max_accuracy = 0.0
340 | for epoch in range(args.start_epoch, args.epochs):
341 | if args.distributed:
342 | data_loader_train.sampler.set_epoch(epoch)
343 | train_stats = train_one_epoch(
344 | model, criterion, data_loader_train,
345 | optimizer, device, epoch, loss_scaler,
346 | args.clip_grad, mixup_fn,
347 | log_writer=log_writer,
348 | args=args
349 | )
350 |
351 | # save model
352 | if args.output_dir:
353 | misc.save_model(
354 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
355 | loss_scaler=loss_scaler, epoch=epoch, latest=True)
356 |
357 |         if (epoch + 1) % 1 == 0:  # evaluate after every epoch
358 | test_stats = evaluate(data_loader_val, model, device)
359 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
360 | max_accuracy = max(max_accuracy, test_stats["acc1"])
361 | print(f'Max accuracy: {max_accuracy:.2f}%')
362 |
363 | if log_writer is not None:
364 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
365 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
366 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)
367 |
368 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
369 | **{f'test_{k}': v for k, v in test_stats.items()},
370 | 'epoch': epoch,
371 | 'n_parameters': n_parameters}
372 |
373 | if args.output_dir and misc.is_main_process():
374 | if log_writer is not None:
375 | log_writer.flush()
376 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
377 | f.write(json.dumps(log_stats) + "\n")
378 |
379 | total_time = time.time() - start_time
380 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
381 | print('Training time {}'.format(total_time_str))
382 |
383 |
384 | if __name__ == '__main__':
385 | args = get_args_parser()
386 | args = args.parse_args()
387 | if args.output_dir:
388 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
389 | main(args)
390 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Attribution-NonCommercial 4.0 International
3 |
4 | =======================================================================
5 |
6 | Creative Commons Corporation ("Creative Commons") is not a law firm and
7 | does not provide legal services or legal advice. Distribution of
8 | Creative Commons public licenses does not create a lawyer-client or
9 | other relationship. Creative Commons makes its licenses and related
10 | information available on an "as-is" basis. Creative Commons gives no
11 | warranties regarding its licenses, any material licensed under their
12 | terms and conditions, or any related information. Creative Commons
13 | disclaims all liability for damages resulting from their use to the
14 | fullest extent possible.
15 |
16 | Using Creative Commons Public Licenses
17 |
18 | Creative Commons public licenses provide a standard set of terms and
19 | conditions that creators and other rights holders may use to share
20 | original works of authorship and other material subject to copyright
21 | and certain other rights specified in the public license below. The
22 | following considerations are for informational purposes only, are not
23 | exhaustive, and do not form part of our licenses.
24 |
25 | Considerations for licensors: Our public licenses are
26 | intended for use by those authorized to give the public
27 | permission to use material in ways otherwise restricted by
28 | copyright and certain other rights. Our licenses are
29 | irrevocable. Licensors should read and understand the terms
30 | and conditions of the license they choose before applying it.
31 | Licensors should also secure all rights necessary before
32 | applying our licenses so that the public can reuse the
33 | material as expected. Licensors should clearly mark any
34 | material not subject to the license. This includes other CC-
35 | licensed material, or material used under an exception or
36 | limitation to copyright. More considerations for licensors:
37 | wiki.creativecommons.org/Considerations_for_licensors
38 |
39 | Considerations for the public: By using one of our public
40 | licenses, a licensor grants the public permission to use the
41 | licensed material under specified terms and conditions. If
42 | the licensor's permission is not necessary for any reason--for
43 | example, because of any applicable exception or limitation to
44 | copyright--then that use is not regulated by the license. Our
45 | licenses grant only permissions under copyright and certain
46 | other rights that a licensor has authority to grant. Use of
47 | the licensed material may still be restricted for other
48 | reasons, including because others have copyright or other
49 | rights in the material. A licensor may make special requests,
50 | such as asking that all changes be marked or described.
51 | Although not required by our licenses, you are encouraged to
52 | respect those requests where reasonable. More_considerations
53 | for the public:
54 | wiki.creativecommons.org/Considerations_for_licensees
55 |
56 | =======================================================================
57 |
58 | Creative Commons Attribution-NonCommercial 4.0 International Public
59 | License
60 |
61 | By exercising the Licensed Rights (defined below), You accept and agree
62 | to be bound by the terms and conditions of this Creative Commons
63 | Attribution-NonCommercial 4.0 International Public License ("Public
64 | License"). To the extent this Public License may be interpreted as a
65 | contract, You are granted the Licensed Rights in consideration of Your
66 | acceptance of these terms and conditions, and the Licensor grants You
67 | such rights in consideration of benefits the Licensor receives from
68 | making the Licensed Material available under these terms and
69 | conditions.
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 | Section 2 -- Scope.
142 |
143 | a. License grant.
144 |
145 | 1. Subject to the terms and conditions of this Public License,
146 | the Licensor hereby grants You a worldwide, royalty-free,
147 | non-sublicensable, non-exclusive, irrevocable license to
148 | exercise the Licensed Rights in the Licensed Material to:
149 |
150 | a. reproduce and Share the Licensed Material, in whole or
151 | in part, for NonCommercial purposes only; and
152 |
153 | b. produce, reproduce, and Share Adapted Material for
154 | NonCommercial purposes only.
155 |
156 | 2. Exceptions and Limitations. For the avoidance of doubt, where
157 | Exceptions and Limitations apply to Your use, this Public
158 | License does not apply, and You do not need to comply with
159 | its terms and conditions.
160 |
161 | 3. Term. The term of this Public License is specified in Section
162 | 6(a).
163 |
164 | 4. Media and formats; technical modifications allowed. The
165 | Licensor authorizes You to exercise the Licensed Rights in
166 | all media and formats whether now known or hereafter created,
167 | and to make technical modifications necessary to do so. The
168 | Licensor waives and/or agrees not to assert any right or
169 | authority to forbid You from making technical modifications
170 | necessary to exercise the Licensed Rights, including
171 | technical modifications necessary to circumvent Effective
172 | Technological Measures. For purposes of this Public License,
173 | simply making modifications authorized by this Section 2(a)
174 | (4) never produces Adapted Material.
175 |
176 | 5. Downstream recipients.
177 |
178 | a. Offer from the Licensor -- Licensed Material. Every
179 | recipient of the Licensed Material automatically
180 | receives an offer from the Licensor to exercise the
181 | Licensed Rights under the terms and conditions of this
182 | Public License.
183 |
184 | b. No downstream restrictions. You may not offer or impose
185 | any additional or different terms or conditions on, or
186 | apply any Effective Technological Measures to, the
187 | Licensed Material if doing so restricts exercise of the
188 | Licensed Rights by any recipient of the Licensed
189 | Material.
190 |
191 | 6. No endorsement. Nothing in this Public License constitutes or
192 | may be construed as permission to assert or imply that You
193 | are, or that Your use of the Licensed Material is, connected
194 | with, or sponsored, endorsed, or granted official status by,
195 | the Licensor or others designated to receive attribution as
196 | provided in Section 3(a)(1)(A)(i).
197 |
198 | b. Other rights.
199 |
200 | 1. Moral rights, such as the right of integrity, are not
201 | licensed under this Public License, nor are publicity,
202 | privacy, and/or other similar personality rights; however, to
203 | the extent possible, the Licensor waives and/or agrees not to
204 | assert any such rights held by the Licensor to the limited
205 | extent necessary to allow You to exercise the Licensed
206 | Rights, but not otherwise.
207 |
208 | 2. Patent and trademark rights are not licensed under this
209 | Public License.
210 |
211 | 3. To the extent possible, the Licensor waives any right to
212 | collect royalties from You for the exercise of the Licensed
213 | Rights, whether directly or through a collecting society
214 | under any voluntary or waivable statutory or compulsory
215 | licensing scheme. In all other cases the Licensor expressly
216 | reserves any right to collect such royalties, including when
217 | the Licensed Material is used other than for NonCommercial
218 | purposes.
219 |
220 | Section 3 -- License Conditions.
221 |
222 | Your exercise of the Licensed Rights is expressly made subject to the
223 | following conditions.
224 |
225 | a. Attribution.
226 |
227 | 1. If You Share the Licensed Material (including in modified
228 | form), You must:
229 |
230 | a. retain the following if it is supplied by the Licensor
231 | with the Licensed Material:
232 |
233 | i. identification of the creator(s) of the Licensed
234 | Material and any others designated to receive
235 | attribution, in any reasonable manner requested by
236 | the Licensor (including by pseudonym if
237 | designated);
238 |
239 | ii. a copyright notice;
240 |
241 | iii. a notice that refers to this Public License;
242 |
243 | iv. a notice that refers to the disclaimer of
244 | warranties;
245 |
246 | v. a URI or hyperlink to the Licensed Material to the
247 | extent reasonably practicable;
248 |
249 | b. indicate if You modified the Licensed Material and
250 | retain an indication of any previous modifications; and
251 |
252 | c. indicate the Licensed Material is licensed under this
253 | Public License, and include the text of, or the URI or
254 | hyperlink to, this Public License.
255 |
256 | 2. You may satisfy the conditions in Section 3(a)(1) in any
257 | reasonable manner based on the medium, means, and context in
258 | which You Share the Licensed Material. For example, it may be
259 | reasonable to satisfy the conditions by providing a URI or
260 | hyperlink to a resource that includes the required
261 | information.
262 |
263 | 3. If requested by the Licensor, You must remove any of the
264 | information required by Section 3(a)(1)(A) to the extent
265 | reasonably practicable.
266 |
267 | 4. If You Share Adapted Material You produce, the Adapter's
268 | License You apply must not prevent recipients of the Adapted
269 | Material from complying with this Public License.
270 |
271 | Section 4 -- Sui Generis Database Rights.
272 |
273 | Where the Licensed Rights include Sui Generis Database Rights that
274 | apply to Your use of the Licensed Material:
275 |
276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
277 | to extract, reuse, reproduce, and Share all or a substantial
278 | portion of the contents of the database for NonCommercial purposes
279 | only;
280 |
281 | b. if You include all or a substantial portion of the database
282 | contents in a database in which You have Sui Generis Database
283 | Rights, then the database in which You have Sui Generis Database
284 | Rights (but not its individual contents) is Adapted Material; and
285 |
286 | c. You must comply with the conditions in Section 3(a) if You Share
287 | all or a substantial portion of the contents of the database.
288 |
289 | For the avoidance of doubt, this Section 4 supplements and does not
290 | replace Your obligations under this Public License where the Licensed
291 | Rights include other Copyright and Similar Rights.
292 |
293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
294 |
295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
305 |
306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
315 |
316 | c. The disclaimer of warranties and limitation of liability provided
317 | above shall be interpreted in a manner that, to the extent
318 | possible, most closely approximates an absolute disclaimer and
319 | waiver of all liability.
320 |
321 | Section 6 -- Term and Termination.
322 |
323 | a. This Public License applies for the term of the Copyright and
324 | Similar Rights licensed here. However, if You fail to comply with
325 | this Public License, then Your rights under this Public License
326 | terminate automatically.
327 |
328 | b. Where Your right to use the Licensed Material has terminated under
329 | Section 6(a), it reinstates:
330 |
331 | 1. automatically as of the date the violation is cured, provided
332 | it is cured within 30 days of Your discovery of the
333 | violation; or
334 |
335 | 2. upon express reinstatement by the Licensor.
336 |
337 | For the avoidance of doubt, this Section 6(b) does not affect any
338 | right the Licensor may have to seek remedies for Your violations
339 | of this Public License.
340 |
341 | c. For the avoidance of doubt, the Licensor may also offer the
342 | Licensed Material under separate terms or conditions or stop
343 | distributing the Licensed Material at any time; however, doing so
344 | will not terminate this Public License.
345 |
346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347 | License.
348 |
349 | Section 7 -- Other Terms and Conditions.
350 |
351 | a. The Licensor shall not be bound by any additional or different
352 | terms or conditions communicated by You unless expressly agreed.
353 |
354 | b. Any arrangements, understandings, or agreements regarding the
355 | Licensed Material not stated herein are separate from and
356 | independent of the terms and conditions of this Public License.
357 |
358 | Section 8 -- Interpretation.
359 |
360 | a. For the avoidance of doubt, this Public License does not, and
361 | shall not be interpreted to, reduce, limit, restrict, or impose
362 | conditions on any use of the Licensed Material that could lawfully
363 | be made without permission under this Public License.
364 |
365 | b. To the extent possible, if any provision of this Public License is
366 | deemed unenforceable, it shall be automatically reformed to the
367 | minimum extent necessary to make it enforceable. If the provision
368 | cannot be reformed, it shall be severed from this Public License
369 | without affecting the enforceability of the remaining terms and
370 | conditions.
371 |
372 | c. No term or condition of this Public License will be waived and no
373 | failure to comply consented to unless expressly agreed to by the
374 | Licensor.
375 |
376 | d. Nothing in this Public License constitutes or may be interpreted
377 | as a limitation upon, or waiver of, any privileges and immunities
378 | that apply to the Licensor or You, including from the legal
379 | processes of any jurisdiction or authority.
380 |
381 | =======================================================================
382 |
383 | Creative Commons is not a party to its public
384 | licenses. Notwithstanding, Creative Commons may elect to apply one of
385 | its public licenses to material it publishes and in those instances
386 | will be considered the “Licensor.” The text of the Creative Commons
387 | public licenses is dedicated to the public domain under the CC0 Public
388 | Domain Dedication. Except for the limited purpose of indicating that
389 | material is shared under a Creative Commons public license or as
390 | otherwise permitted by the Creative Commons policies published at
391 | creativecommons.org/policies, Creative Commons does not authorize the
392 | use of the trademark "Creative Commons" or any other trademark or logo
393 | of Creative Commons without its prior written consent including,
394 | without limitation, in connection with any unauthorized modifications
395 | to any of its public licenses or any other arrangements,
396 | understandings, or agreements concerning use of licensed material. For
397 | the avoidance of doubt, this paragraph does not form part of the
398 | public licenses.
399 |
400 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/util/misc.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # SiameseIM
3 | # Copyright (c) SenseTime. All Rights Reserved.
4 | # ------------------------------------------------------------------------
5 | # Modified from MAE (https://github.com/facebookresearch/mae)
6 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
7 | # ------------------------------------------------------------------------
8 | # References:
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
11 | # ------------------------------------------------------------------------
12 |
13 |
14 | import builtins
15 | import datetime
16 | import os
17 | import io
18 | import time
19 | from collections import defaultdict, deque
20 | from pathlib import Path
21 |
22 | import torch
23 | import torch.distributed as dist
24 | from math import inf  # torch._six was removed in recent PyTorch releases; math.inf is equivalent here
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 |
28 |
29 | class SmoothedValue(object):
30 | """Track a series of values and provide access to smoothed values over a
31 | window or the global series average.
32 | """
33 |
34 | def __init__(self, window_size=20, fmt=None):
35 | if fmt is None:
36 | fmt = "{median:.4f} ({global_avg:.4f})"
37 | self.deque = deque(maxlen=window_size)
38 | self.total = 0.0
39 | self.count = 0
40 | self.fmt = fmt
41 |
42 | def update(self, value, n=1):
43 | self.deque.append(value)
44 | self.count += n
45 | self.total += value * n
46 |
47 | def synchronize_between_processes(self):
48 | """
49 | Warning: does not synchronize the deque!
50 | """
51 | if not is_dist_avail_and_initialized():
52 | return
53 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
54 | dist.barrier()
55 | dist.all_reduce(t)
56 | t = t.tolist()
57 | self.count = int(t[0])
58 | self.total = t[1]
59 |
60 | @property
61 | def median(self):
62 | d = torch.tensor(list(self.deque))
63 | return d.median().item()
64 |
65 | @property
66 | def avg(self):
67 | d = torch.tensor(list(self.deque), dtype=torch.float32)
68 | return d.mean().item()
69 |
70 | @property
71 | def global_avg(self):
72 | return self.total / self.count
73 |
74 | @property
75 | def max(self):
76 | return max(self.deque)
77 |
78 | @property
79 | def value(self):
80 | return self.deque[-1]
81 |
82 | def __str__(self):
83 | return self.fmt.format(
84 | median=self.median,
85 | avg=self.avg,
86 | global_avg=self.global_avg,
87 | max=self.max,
88 | value=self.value)
89 |
90 |
91 | class MetricLogger(object):
92 | def __init__(self, delimiter="\t"):
93 | self.meters = defaultdict(SmoothedValue)
94 | self.delimiter = delimiter
95 |
96 | def update(self, **kwargs):
97 | for k, v in kwargs.items():
98 | if v is None:
99 | continue
100 | if isinstance(v, torch.Tensor):
101 | v = v.item()
102 | assert isinstance(v, (float, int))
103 | self.meters[k].update(v)
104 |
105 | def __getattr__(self, attr):
106 | if attr in self.meters:
107 | return self.meters[attr]
108 | if attr in self.__dict__:
109 | return self.__dict__[attr]
110 | raise AttributeError("'{}' object has no attribute '{}'".format(
111 | type(self).__name__, attr))
112 |
113 | def __str__(self):
114 | loss_str = []
115 | for name, meter in self.meters.items():
116 | loss_str.append(
117 | "{}: {}".format(name, str(meter))
118 | )
119 | return self.delimiter.join(loss_str)
120 |
121 | def synchronize_between_processes(self):
122 | for meter in self.meters.values():
123 | meter.synchronize_between_processes()
124 |
125 | def add_meter(self, name, meter):
126 | self.meters[name] = meter
127 |
128 | def log_every(self, iterable, print_freq, header=None):
129 | i = 0
130 | if not header:
131 | header = ''
132 | start_time = time.time()
133 | end = time.time()
134 | iter_time = SmoothedValue(fmt='{avg:.4f}')
135 | data_time = SmoothedValue(fmt='{avg:.4f}')
136 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
137 | log_msg = [
138 | header,
139 | '[{0' + space_fmt + '}/{1}]',
140 | 'eta: {eta}',
141 | '{meters}',
142 | 'time: {time}',
143 | 'data: {data}'
144 | ]
145 | if torch.cuda.is_available():
146 | log_msg.append('max mem: {memory:.0f}')
147 | log_msg = self.delimiter.join(log_msg)
148 | MB = 1024.0 * 1024.0
149 | for obj in iterable:
150 | data_time.update(time.time() - end)
151 | yield obj
152 | iter_time.update(time.time() - end)
153 | if i % print_freq == 0 or i == len(iterable) - 1:
154 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
155 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
156 | if torch.cuda.is_available():
157 | print(log_msg.format(
158 | i, len(iterable), eta=eta_string,
159 | meters=str(self),
160 | time=str(iter_time), data=str(data_time),
161 | memory=torch.cuda.max_memory_allocated() / MB))
162 | else:
163 | print(log_msg.format(
164 | i, len(iterable), eta=eta_string,
165 | meters=str(self),
166 | time=str(iter_time), data=str(data_time)))
167 | i += 1
168 | end = time.time()
169 | total_time = time.time() - start_time
170 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
171 | print('{} Total time: {} ({:.4f} s / it)'.format(
172 | header, total_time_str, total_time / len(iterable)))
173 |
174 |
175 | def setup_for_distributed(is_master):
176 | """
177 | This function disables printing when not in master process
178 | """
179 | builtin_print = builtins.print
180 |
181 | def print(*args, **kwargs):
182 | force = kwargs.pop('force', False)
183 | # force = force or (get_world_size() > 8)
184 | if is_master or force:
185 | now = datetime.datetime.now().time()
186 | builtin_print('[{}] '.format(now), end='') # print with time stamp
187 | builtin_print(*args, **kwargs)
188 |
189 | builtins.print = print
190 |
191 |
192 | def is_dist_avail_and_initialized():
193 | if not dist.is_available():
194 | return False
195 | if not dist.is_initialized():
196 | return False
197 | return True
198 |
199 |
200 | def get_world_size():
201 | if not is_dist_avail_and_initialized():
202 | return 1
203 | return dist.get_world_size()
204 |
205 |
206 | def get_rank():
207 | if not is_dist_avail_and_initialized():
208 | return 0
209 | return dist.get_rank()
210 |
211 |
212 | def is_main_process():
213 | return get_rank() == 0
214 |
215 |
216 | def save_on_master(*args, **kwargs):
217 | if is_main_process():
218 | torch.save(*args, **kwargs)
219 |
220 |
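# Initialize torch.distributed from one of three launch modes: ITP/OpenMPI environment variables,
# torchrun-style RANK/WORLD_SIZE/LOCAL_RANK, or SLURM variables; if none are present, fall back to
# single-process (non-distributed) mode.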
221 | def init_distributed_mode(args):
222 | if args.dist_on_itp:
223 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
224 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
225 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
226 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
227 | os.environ['LOCAL_RANK'] = str(args.gpu)
228 | os.environ['RANK'] = str(args.rank)
229 | os.environ['WORLD_SIZE'] = str(args.world_size)
230 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
231 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
232 | args.rank = int(os.environ["RANK"])
233 | args.world_size = int(os.environ['WORLD_SIZE'])
234 | args.gpu = int(os.environ['LOCAL_RANK'])
235 | elif 'SLURM_PROCID' in os.environ:
236 | args.rank = int(os.environ['SLURM_PROCID'])
237 | args.world_size = int(os.environ['SLURM_NTASKS'])
238 | node_list = os.environ['SLURM_STEP_NODELIST']
239 | num_gpus = torch.cuda.device_count()
240 | args.gpu = args.rank % torch.cuda.device_count()
241 | torch.cuda.set_device(args.rank % num_gpus)
242 | import subprocess
243 | addr = subprocess.getoutput(
244 | f'scontrol show hostname {node_list} | head -n1')
245 | # specify master port
246 | if hasattr(args, 'port'):
247 | os.environ['MASTER_PORT'] = str(args.port)
248 | elif 'MASTER_PORT' in os.environ:
249 | pass # use MASTER_PORT in the environment variable
250 | else:
251 |             # torch.distributed's default port is 29500; use a fixed alternative here to avoid clashes
252 | os.environ['MASTER_PORT'] = '28506'
253 | # use MASTER_ADDR in the environment variable if it already exists
254 | if 'MASTER_ADDR' not in os.environ:
255 | os.environ['MASTER_ADDR'] = addr
256 | os.environ['WORLD_SIZE'] = str(args.world_size)
257 | os.environ['LOCAL_RANK'] = str(args.rank % num_gpus)
258 | os.environ['LOCAL_SIZE'] = str(num_gpus)
259 | os.environ['RANK'] = str(args.rank)
260 | # dist.init_process_group(backend='nccl')
261 | else:
262 | print('Not using distributed mode')
263 | setup_for_distributed(is_master=True) # hack
264 | args.distributed = False
265 | return
266 |
267 | args.distributed = True
268 |
269 | torch.cuda.set_device(args.gpu)
270 | args.dist_backend = 'nccl'
271 | print('| distributed init (rank {}): {}, gpu {}'.format(
272 | args.rank, args.dist_url, args.gpu), flush=True)
273 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
274 | world_size=args.world_size, rank=args.rank)
275 | torch.distributed.barrier()
276 | setup_for_distributed(args.rank == 0)
277 |
278 |
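# AMP loss scaler built on torch.cuda.amp.GradScaler: scales the loss for backward, unscales and
# optionally clips gradients, and returns the gradient norm (None when gradients are only accumulated).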
279 | class NativeScalerWithGradNormCount:
280 | state_dict_key = "amp_scaler"
281 |
282 | def __init__(self, enabled=True, growth_interval=2000):
283 | self.enabled = enabled
284 | self._scaler = torch.cuda.amp.GradScaler(
285 | enabled=enabled, growth_interval=growth_interval)
286 |
287 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
288 | self._scaler.scale(loss).backward(create_graph=create_graph)
289 | if update_grad:
290 | if clip_grad is not None:
291 | assert parameters is not None
292 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
293 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
294 | else:
295 | self._scaler.unscale_(optimizer)
296 | norm = get_grad_norm_(parameters)
297 | self._scaler.step(optimizer)
298 | self._scaler.update()
299 | else:
300 | norm = None
301 | return norm
302 |
303 | def state_dict(self):
304 | return self._scaler.state_dict()
305 |
306 | def load_state_dict(self, state_dict):
307 | if self.enabled:
308 | self._scaler.load_state_dict(state_dict)
309 |
310 |
311 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
312 | if isinstance(parameters, torch.Tensor):
313 | parameters = [parameters]
314 | parameters = [p for p in parameters if p.grad is not None]
315 | norm_type = float(norm_type)
316 | if len(parameters) == 0:
317 | return torch.tensor(0.)
318 | device = parameters[0].grad.device
319 | if norm_type == inf:
320 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
321 | else:
322 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
323 | return total_norm
324 |
325 |
326 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler,
327 | latest=False,
328 | latest_postfix='latest'):
329 | output_dir = Path(args.output_dir)
330 | epoch_name = str(epoch)
331 | if loss_scaler is not None:
332 | checkpoint_paths = []
333 | if latest:
334 | checkpoint_paths = [output_dir / (f'checkpoint-{latest_postfix}.pth')]
335 | else:
336 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
337 | to_save = {
338 | 'model': model_without_ddp.state_dict(),
339 | 'optimizer': optimizer.state_dict(),
340 | 'epoch': epoch,
341 | 'scaler': loss_scaler.state_dict(),
342 | 'args': args,
343 | }
344 | for checkpoint_path in checkpoint_paths:
345 | save_on_master(to_save, checkpoint_path)
346 | else:
347 | client_state = {'epoch': epoch}
348 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
349 |
350 |
351 | def load_model(args, model_without_ddp, optimizer, loss_scaler):
352 | if args.resume:
353 | if args.resume.startswith('https'):
354 | checkpoint = torch.hub.load_state_dict_from_url(
355 | args.resume, map_location='cpu', check_hash=True)
356 | else:
357 | checkpoint = torch.load(args.resume, map_location='cpu')
358 | model_without_ddp.load_state_dict(checkpoint['model'])
359 | print("Resume checkpoint %s" % args.resume)
360 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
361 | optimizer.load_state_dict(checkpoint['optimizer'])
362 | args.start_epoch = checkpoint['epoch'] + 1
363 | if 'scaler' in checkpoint:
364 | loss_scaler.load_state_dict(checkpoint['scaler'])
365 | print("With optim & sched!")
366 |
367 |
368 | def auto_load_model(args, model_without_ddp, optimizer, loss_scaler):
369 | # torch.amp
370 | output_dir = Path(args.output_dir)
371 |
372 | if args.auto_resume and len(args.resume) == 0:
373 | import glob
374 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
375 | latest_ckpt = -1
376 | for ckpt in all_checkpoints:
377 | t = ckpt.split('-')[-1].split('.')[0]
378 | if t.isdigit():
379 | latest_ckpt = max(int(t), latest_ckpt)
380 | if latest_ckpt >= 0:
381 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
382 | if os.path.exists(os.path.join(output_dir, 'checkpoint-latest.pth')):
383 | args.resume = os.path.join(output_dir, 'checkpoint-latest.pth')
384 | print("Auto resume checkpoint: %s" % args.resume)
385 |
386 | if args.resume:
387 | if args.resume.startswith('https'):
388 | checkpoint = torch.hub.load_state_dict_from_url(
389 | args.resume, map_location='cpu', check_hash=True)
390 | else:
391 | checkpoint = torch.load(args.resume, map_location='cpu')
392 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
393 | print("Resume checkpoint %s" % args.resume)
394 | if 'optimizer' in checkpoint and 'epoch' in checkpoint:
395 | optimizer.load_state_dict(checkpoint['optimizer'])
396 | args.start_epoch = checkpoint['epoch'] + 1
397 | if 'scaler' in checkpoint:
398 | loss_scaler.load_state_dict(checkpoint['scaler'])
399 | print("With optim & sched!")
400 |
401 |
402 | def all_reduce_mean(x):
403 | world_size = get_world_size()
404 | if world_size > 1:
405 | x_reduce = torch.tensor(x).cuda()
406 | dist.all_reduce(x_reduce)
407 | x_reduce /= world_size
408 | return x_reduce.item()
409 | else:
410 | return x
411 |
412 |
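# LayerNorm variant that always computes in fp32, even inside an autocast region.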
413 | class LayerNorm(nn.LayerNorm):
414 |
415 | @torch.cuda.amp.autocast(enabled=False)
416 | def forward(self, input):
417 | return super(LayerNorm, self).forward(input.float())
418 |
419 |
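# Build optimizer parameter groups: 1-D parameters and biases get no weight decay, and parameters
# whose names contain 'offset' keep weight decay but use a 10x smaller learning rate.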
420 | def add_lr_weight_decay(model, weight_decay=1e-5, lr=1e-4, skip_list=()):
421 | decay = []
422 | no_decay = []
423 | no_decay_names = []
424 | decay_small_lr = []
425 | decay_small_lr_names = []
426 | for name, param in model.named_parameters():
427 | if not param.requires_grad:
428 | continue # frozen weights
429 | if 'offset' in name:
430 | decay_small_lr.append(param)
431 | decay_small_lr_names.append(name)
432 |
433 | elif len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
434 | no_decay.append(param)
435 | no_decay_names.append(name)
436 | else:
437 | decay.append(param)
438 | print(f'decay_small_lr_names: {decay_small_lr_names}')
439 | print(f'no_decay_names: {no_decay_names}')
440 | return [
441 | {'params': no_decay, 'weight_decay': 0., 'lr': lr},
442 | {'params': decay, 'weight_decay': weight_decay, 'lr': lr},
443 | {'params': decay_small_lr, 'weight_decay': weight_decay, 'lr': lr*0.1},
444 | ]
445 |
446 |
447 |
448 | import math
449 | from torch.utils.data.sampler import Sampler
450 |
451 |
452 | class NodeDistributedSampler(Sampler):
453 | """Sampler that restricts data loading to a subset of the dataset.
454 | It is especially useful in conjunction with
455 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
456 | process can pass a DistributedSampler instance as a DataLoader sampler,
457 | and load a subset of the original dataset that is exclusive to it.
458 | .. note::
459 | Dataset is assumed to be of constant size.
460 | Arguments:
461 | dataset: Dataset used for sampling.
462 | num_replicas (optional): Number of processes participating in
463 | distributed training.
464 | rank (optional): Rank of the current process within num_replicas.
465 | """
466 |
467 | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True):
468 | if num_replicas is None:
469 | if not dist.is_available():
470 | raise RuntimeError("Requires distributed package to be available")
471 | num_replicas = dist.get_world_size()
472 | if rank is None:
473 | if not dist.is_available():
474 | raise RuntimeError("Requires distributed package to be available")
475 | rank = dist.get_rank()
476 | if local_rank is None:
477 | local_rank = int(os.environ.get('LOCAL_RANK', 0))
478 | if local_size is None:
479 | local_size = int(os.environ.get('LOCAL_SIZE', 1))
480 | self.dataset = dataset
481 | self.shuffle = shuffle
482 | self.num_replicas = num_replicas
483 | self.num_parts = local_size
484 | self.rank = rank
485 | self.local_rank = local_rank
486 | self.epoch = 0
487 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
488 | self.total_size = self.num_samples * self.num_replicas
489 |
490 | self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts
491 |
492 | def __iter__(self):
493 | if self.shuffle:
494 | # deterministically shuffle based on epoch
495 | g = torch.Generator()
496 | g.manual_seed(self.epoch)
497 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
498 | else:
499 | indices = torch.arange(len(self.dataset)).tolist()
500 | indices = [i for i in indices if i % self.num_parts == self.local_rank]
501 |
502 | # add extra samples to make it evenly divisible
503 | indices += indices[:(self.total_size_parts - len(indices))]
504 | assert len(indices) == self.total_size_parts
505 |
506 | # subsample
507 | indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts]
508 | assert len(indices) == self.num_samples
509 |
510 | return iter(indices)
511 |
512 | def __len__(self):
513 | return self.num_samples
514 |
515 | def set_epoch(self, epoch):
516 | self.epoch = epoch
517 |
518 |
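# Autograd-aware all_gather: forward gathers tensors from every rank; backward all-reduces the
# incoming gradients and returns the slice for the local rank, so gradients reach the local input.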
519 | class GatherLayer(torch.autograd.Function):
520 | """Gather tensors from all process, supporting backward propagation.
521 | """
522 |
523 | @staticmethod
524 | def forward(ctx, input):
525 | ctx.save_for_backward(input)
526 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
527 | dist.all_gather(output, input)
528 | return torch.stack(output, 0)
529 |
530 | @staticmethod
531 | def backward(ctx, grads):
532 | input, = ctx.saved_tensors
533 | dist.all_reduce(grads)
534 | grad_out = torch.zeros_like(input)
535 | grad_out[:] = grads[dist.get_rank()]
536 | return grad_out
537 |
538 |
539 | class LabelSmoothingCrossEntropy(nn.Module):
540 | """
541 | NLL loss with label smoothing.
542 | """
543 | def __init__(self, smoothing=0.1):
544 | """
545 | Constructor for the LabelSmoothing module.
546 | :param smoothing: label smoothing factor
547 | """
548 | super(LabelSmoothingCrossEntropy, self).__init__()
549 | assert smoothing < 1.0
550 | self.smoothing = smoothing
551 | self.confidence = 1. - smoothing
552 |
553 | def forward(self, x, target, reduction='mean'):
554 | logprobs = F.log_softmax(x, dim=-1)
555 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
556 | nll_loss = nll_loss.squeeze(1)
557 | smooth_loss = -logprobs.mean(dim=-1)
558 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
559 | if reduction == 'mean':
560 | return loss.mean()
561 | elif reduction == 'none':
562 | return loss
563 | else:
564 | raise NotImplementedError
565 |
566 |
567 | class LabelSmoothingCrossEntropyWithSoftTarget(nn.Module):
568 |     """
569 |     NLL loss with label smoothing, for soft (probability-distribution) targets.
570 |     """
571 | def __init__(self, smoothing=0.1):
572 | """
573 | Constructor for the LabelSmoothing module.
574 | :param smoothing: label smoothing factor
575 | """
576 | super(LabelSmoothingCrossEntropyWithSoftTarget, self).__init__()
577 | assert smoothing < 1.0
578 | self.smoothing = smoothing
579 | self.confidence = 1. - smoothing
580 |
581 | def forward(self, x, target, reduction='mean'):
582 | logprobs = F.log_softmax(x, dim=-1)
583 | # nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
584 | # nll_loss = nll_loss.squeeze(1)
585 | nll_loss = - (logprobs * target).sum(dim=-1)
586 | smooth_loss = -logprobs.mean(dim=-1)
587 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
588 | if reduction == 'mean':
589 | return loss.mean()
590 | elif reduction == 'none':
591 | return loss
592 | else:
593 | raise NotImplementedError
594 |
--------------------------------------------------------------------------------