├── .DS_Store ├── classification ├── requirements.txt ├── .DS_Store ├── configs │ ├── vvt │ │ ├── vvt_test.py │ │ ├── vvt_tiny.py │ │ ├── vvt_large.py │ │ ├── vvt_medium.py │ │ └── vvt_small.py │ ├── .DS_Store │ ├── pvt │ │ ├── pvt_large.py │ │ ├── pvt_tiny.py │ │ ├── pvt_medium.py │ │ └── pvt_small.py │ ├── pvt_v2 │ │ ├── pvt_v2_b0.py │ │ ├── pvt_v2_b1.py │ │ ├── pvt_v2_b2.py │ │ ├── pvt_v2_b3.py │ │ ├── pvt_v2_b4.py │ │ ├── pvt_v2_b5.py │ │ └── pvt_v2_b2_li.py │ └── vit │ │ ├── vit.py │ │ └── vit_tiny_patch16_224.py ├── mcloader │ ├── __init__.py │ ├── imagenet.py │ ├── data_prefetcher.py │ ├── mcloader.py │ ├── classification.py │ └── image_list.py ├── hubconf.py ├── submit.sh ├── train_imagenet_multi_node.sh ├── train_vit.sh ├── train_imagenet.sh ├── train.sh ├── test.sh ├── cifar_node_submit.sh ├── modelsize_estimate.py ├── node_submit.sh ├── train_cifar10.sh ├── samplers.py ├── losses.py ├── get_flops.py ├── datasets.py ├── utils.py ├── engine.py ├── pvt.py ├── main.py ├── pvt_v2.py └── vvt.py ├── proc_model.py ├── .gitignore ├── README.md └── LICENSE /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenNLPLab/Vicinity-Vision-Transformer/HEAD/.DS_Store -------------------------------------------------------------------------------- /classification/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1+cu101 2 | torchvision==0.8.1 3 | timm==0.4.12 4 | mmcv==1.3.1 5 | -------------------------------------------------------------------------------- /classification/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenNLPLab/Vicinity-Vision-Transformer/HEAD/classification/.DS_Store -------------------------------------------------------------------------------- /classification/configs/vvt/vvt_test.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='vvt_test', 3 | drop_path=0.1, 4 | clip_grad=None, 5 | ) -------------------------------------------------------------------------------- /classification/configs/vvt/vvt_tiny.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='vvt_tiny', 3 | drop_path=0.1, 4 | clip_grad=None, 5 | ) -------------------------------------------------------------------------------- /classification/configs/vvt/vvt_large.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='vvt_large', 3 | drop_path=0.3, 4 | clip_grad=1.0, 5 | ) -------------------------------------------------------------------------------- /classification/mcloader/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification import ClassificationDataset 2 | from .data_prefetcher import DataPrefetcher -------------------------------------------------------------------------------- /classification/configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenNLPLab/Vicinity-Vision-Transformer/HEAD/classification/configs/.DS_Store -------------------------------------------------------------------------------- /classification/configs/vvt/vvt_medium.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='vvt_medium', 3 | 
drop_path=0.3, 4 | clip_grad=1.0, 5 | ) 6 | -------------------------------------------------------------------------------- /classification/configs/vvt/vvt_small.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='vvt_small', 3 | drop_path=0.3, 4 | clip_grad=1.0, 5 | ) 6 | -------------------------------------------------------------------------------- /classification/configs/pvt/pvt_large.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='pvt_large', 3 | drop_path=0.3, 4 | clip_grad=1.0, 5 | output_dir='checkpoints/pvt_large', 6 | ) -------------------------------------------------------------------------------- /classification/configs/pvt/pvt_tiny.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='pvt_tiny', 3 | drop_path=0.1, 4 | clip_grad=None, 5 | output_dir='checkpoints/pvt_tiny', 6 | ) -------------------------------------------------------------------------------- /classification/configs/pvt/pvt_medium.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='pvt_medium', 3 | drop_path=0.3, 4 | clip_grad=1.0, 5 | output_dir='checkpoints/pvt_medium', 6 | ) -------------------------------------------------------------------------------- /classification/configs/pvt/pvt_small.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='pvt_small', 3 | drop_path=0.1, 4 | clip_grad=None, 5 | output_dir='checkpoints/pvt_small', 6 | ) -------------------------------------------------------------------------------- /classification/configs/pvt_v2/pvt_v2_b0.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='pvt_v2_b0', 3 | drop_path=0.1, 4 | clip_grad=None, 5 | output_dir='checkpoints/pvt_v2_b0', 6 | ) -------------------------------------------------------------------------------- /classification/configs/pvt_v2/pvt_v2_b1.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='pvt_v2_b1', 3 | drop_path=0.1, 4 | clip_grad=None, 5 | output_dir='checkpoints/pvt_v2_b1', 6 | ) -------------------------------------------------------------------------------- /classification/configs/pvt_v2/pvt_v2_b2.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='pvt_v2_b2', 3 | drop_path=0.1, 4 | clip_grad=None, 5 | output_dir='checkpoints/pvt_v2_b2', 6 | ) -------------------------------------------------------------------------------- /classification/configs/pvt_v2/pvt_v2_b3.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='pvt_v2_b3', 3 | drop_path=0.3, 4 | clip_grad=1.0, 5 | output_dir='checkpoints/pvt_v2_b3', 6 | ) 7 | -------------------------------------------------------------------------------- /classification/configs/pvt_v2/pvt_v2_b4.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='pvt_v2_b4', 3 | drop_path=0.3, 4 | clip_grad=1.0, 5 | output_dir='checkpoints/pvt_v2_b4', 6 | ) 7 | -------------------------------------------------------------------------------- /classification/configs/pvt_v2/pvt_v2_b5.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='pvt_v2_b5', 3 
| drop_path=0.3, 4 | clip_grad=1.0, 5 | output_dir='checkpoints/pvt_v2_b5', 6 | ) 7 | -------------------------------------------------------------------------------- /classification/hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | from models import * 4 | 5 | dependencies = ["torch", "torchvision", "timm"] 6 | -------------------------------------------------------------------------------- /classification/configs/pvt_v2/pvt_v2_b2_li.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='pvt_v2_b2_li', 3 | drop_path=0.1, 4 | clip_grad=None, 5 | output_dir='checkpoints/pvt_v2_b2_li', 6 | ) -------------------------------------------------------------------------------- /classification/configs/vit/vit.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='vit_tiny_patch16_224', 3 | drop_path=0.3, 4 | clip_grad=1.0, 5 | output_dir='checkpoints/vit_tiny_patch16_224', 6 | ) 7 | -------------------------------------------------------------------------------- /classification/configs/vit/vit_tiny_patch16_224.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='vit_tiny_patch16_224', 3 | drop_path=0.3, 4 | clip_grad=1.0, 5 | output_dir='checkpoints/vit_tiny_patch16_224', 6 | ) 7 | -------------------------------------------------------------------------------- /classification/mcloader/imagenet.py: -------------------------------------------------------------------------------- 1 | from .image_list import ImageList 2 | 3 | 4 | class ImageNet(ImageList): 5 | 6 | def __init__(self, root, list_file, memcached, mclient_path): 7 | super(ImageNet, self).__init__( 8 | root, list_file, memcached, mclient_path) 9 | -------------------------------------------------------------------------------- /classification/submit.sh: -------------------------------------------------------------------------------- 1 | export NCCL_SOCKET_IFNAME=bond0 2 | export NCCL_IB_DISABLE=1 3 | export NCCL_P2P_DISABLE=1 4 | 5 | python run_with_submitit.py --ngpus 8 --nodes 1 --partition MMG --comment spring-submit --model pvt_huge_v2 --batch-size 32 --use-mcloader --output_dir checkpoints/pvt_huge_v2_2n/ --config configs/pvc/pvc_b5.py -------------------------------------------------------------------------------- /classification/train_imagenet_multi_node.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export NCCL_LL_THRESHOLD=0 3 | 4 | CONFIG=$2 5 | GPUS=$1 6 | PORT=${PORT:-8888} 7 | 8 | spring.submit arun --gres=gpu:$GPUS -n2 --ntasks-per-node=$GPUS --gpu \ 9 | " 10 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 11 | --use_env main.py --config $CONFIG --data-set IMNET \ 12 | " -------------------------------------------------------------------------------- /classification/train_vit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export NCCL_LL_THRESHOLD=0 3 | # export NCCL_IB_DISABLE=1 4 | 5 | CONFIG=$2 6 | GPUS=$1 7 | NODES=2 8 | batch_size=$3 9 | PORT=${PORT:-8889} 10 | RESUME=$4 11 | 12 | spring.submit arun --gpu -n$GPUS --cpus-per-task 5 --ntasks-per-node 8 -p MMG \ 13 | " 14 | python main.py --config $CONFIG --data-set IMNET --batch-size $batch_size --dist-eval --input-size 
224 \ 15 | --resume $RESUME \ 16 | " 17 | -------------------------------------------------------------------------------- /classification/train_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export NCCL_LL_THRESHOLD=0 3 | 4 | CONFIG=$2 5 | GPUS=$1 6 | PORT=${PORT:-8889} 7 | batch_size=$3 8 | OUTPUT_DIR=$4 9 | 10 | spring.submit arun --gres=gpu:$GPUS -n1 --ntasks-per-node=$GPUS --gpu -p \ 11 | " 12 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 13 | --use_env main.py --config $CONFIG --data-set IMNET --batch-size $batch_size --output_dir $OUTPUT_DIR \ 14 | " -------------------------------------------------------------------------------- /proc_model.py: -------------------------------------------------------------------------------- 1 | '''turn the model pth file to weight only, no optimizer states''' 2 | import torch 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser(description='Hyperparams') 6 | parser.add_argument('filename', nargs='?', type=str, default=None) 7 | 8 | args = parser.parse_args() 9 | 10 | pth = torch.load(args.filename, map_location=torch.device('cpu')) 11 | checkpoint = pth['model'] 12 | torch.save(checkpoint, args.filename.replace(".pth", "_.pth")) 13 | print("finished") 14 | -------------------------------------------------------------------------------- /classification/train.sh: -------------------------------------------------------------------------------- 1 | export NCCL_LL_THRESHOLD=0 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-99998} 6 | N_TASK_PER_NODE=$(($2<8?$2:8)) 7 | batch_size=20 8 | OUTPUT_DIR=checkpoints/test/ 9 | DATA_PATH=/your/data/path/ 10 | 11 | 12 | spring.submit arun --gpu -n$GPUS --cpus-per-task 4 --ntasks-per-node $N_TASK_PER_NODE -p MMG --quotatype=auto \ 13 | " 14 | python main.py --config $CONFIG --data-set IMNET --batch-size $batch_size --dist-eval --data-path $DATA_PATH \ 15 | --output_dir $OUTPUT_DIR --port $PORT \ 16 | " 17 | 18 | # local multi/single GPU 19 | # python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 20 | # --use_env main.py --config $CONFIG --data-set IMNET --batch-size $batch_size --output_dir $OUTPUT_DIR \ 21 | 22 | 23 | -------------------------------------------------------------------------------- /classification/test.sh: -------------------------------------------------------------------------------- 1 | export NCCL_LL_THRESHOLD=0 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | RESUME=$3 6 | PORT=${PORT:-99998} 7 | N_TASK_PER_NODE=$(($2<8?$2:8)) 8 | batch_size=20 9 | OUTPUT_DIR=checkpoints/test/ 10 | DATA_PATH=/your/data/path/ 11 | 12 | 13 | spring.submit arun --gpu -n$GPUS --cpus-per-task 4 --ntasks-per-node $N_TASK_PER_NODE -p MMG --quotatype=auto \ 14 | " 15 | python main.py --config $CONFIG --data-set IMNET --batch-size $batch_size --dist-eval --data-path $DATA_PATH \ 16 | --output_dir $OUTPUT_DIR --port $PORT --resume $RESUME --eval \ 17 | " 18 | 19 | # local multi/single GPU 20 | # python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 21 | # --use_env main.py --config $CONFIG --data-set IMNET --batch-size $batch_size --output_dir $OUTPUT_DIR --resume $RESUME --eval\ 22 | 23 | -------------------------------------------------------------------------------- /classification/cifar_node_submit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export NCCL_LL_THRESHOLD=0 3 | # export 
NCCL_IB_DISABLE=1 4 | 5 | CONFIG=$2 6 | GPUS=$1 7 | NODES=2 8 | batch_size=$3 9 | PORT=${PORT:-8889} 10 | OUTPUT_DIR=$4 11 | RESUME=$5 12 | 13 | 14 | if [ -z $RESUME ] 15 | then 16 | spring.submit arun --gpu -n$GPUS --cpus-per-task 4 --ntasks-per-node 4 -p MMG \ 17 | " 18 | python main.py --config $CONFIG --data-set CIFAR --batch-size $batch_size --dist-eval \ 19 | --output_dir $OUTPUT_DIR \ 20 | " 21 | else 22 | echo $RESUME 23 | spring.submit arun --gpu -n$GPUS --cpus-per-task 4 --ntasks-per-node 4 -p MMG \ 24 | " 25 | python main.py --config $CONFIG --data-set CIFAR --batch-size $batch_size --dist-eval \ 26 | --output_dir $OUTPUT_DIR --resume $RESUME \ 27 | " 28 | fi 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /classification/mcloader/data_prefetcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DataPrefetcher: 5 | def __init__(self, loader): 6 | self.loader = iter(loader) 7 | self.stream = torch.cuda.Stream() 8 | self.preload() 9 | 10 | def preload(self): 11 | try: 12 | self.next_input, self.next_target = next(self.loader) 13 | except StopIteration: 14 | self.next_input = None 15 | self.next_target = None 16 | return 17 | 18 | with torch.cuda.stream(self.stream): 19 | self.next_input = self.next_input.cuda(non_blocking=True) 20 | self.next_target = self.next_target.cuda(non_blocking=True) 21 | 22 | def next(self): 23 | torch.cuda.current_stream().wait_stream(self.stream) 24 | input = self.next_input 25 | target = self.next_target 26 | if input is not None: 27 | self.preload() 28 | return input, target 29 | -------------------------------------------------------------------------------- /classification/mcloader/mcloader.py: -------------------------------------------------------------------------------- 1 | import io 2 | from PIL import Image 3 | try: 4 | import mc 5 | except ImportError as E: 6 | pass 7 | 8 | 9 | def pil_loader(img_str): 10 | buff = io.BytesIO(img_str) 11 | return Image.open(buff) 12 | 13 | 14 | class McLoader(object): 15 | 16 | def __init__(self, mclient_path): 17 | assert mclient_path is not None, \ 18 | "Please specify 'data_mclient_path' in the config." 
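        # `mclient_path` must point to a directory that contains the two
        # memcached client configs resolved below, i.e. `server_list.conf`
        # and `client.conf`.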
19 | self.mclient_path = mclient_path 20 | server_list_config_file = "{}/server_list.conf".format( 21 | self.mclient_path) 22 | client_config_file = "{}/client.conf".format(self.mclient_path) 23 | self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, 24 | client_config_file) 25 | 26 | def __call__(self, fn): 27 | try: 28 | img_value = mc.pyvector() 29 | self.mclient.Get(fn, img_value) 30 | img_value_str = mc.ConvertBuffer(img_value) 31 | img = pil_loader(img_value_str) 32 | except: 33 | print('Read image failed ({})'.format(fn)) 34 | return None 35 | else: 36 | return img -------------------------------------------------------------------------------- /classification/modelsize_estimate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def modelsize(model, input, type_size=4): 7 | para = sum([np.prod(list(p.size())) for p in model.parameters()]) 8 | print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000)) 9 | 10 | input_ = input.clone() 11 | input_.requires_grad_(requires_grad=False) 12 | 13 | mods = list(model.modules()) 14 | out_sizes = [] 15 | 16 | for i in range(1, len(mods)): 17 | m = mods[i] 18 | if isinstance(m, nn.ReLU): 19 | if m.inplace: 20 | continue 21 | out = m(input_) 22 | out_sizes.append(np.array(out.size())) 23 | input_ = out 24 | 25 | total_nums = 0 26 | for i in range(len(out_sizes)): 27 | s = out_sizes[i] 28 | nums = np.prod(np.array(s)) 29 | total_nums += nums 30 | 31 | print('Model {} : intermedite variables: {:3f} M (without backward)' 32 | .format(model._get_name(), total_nums * type_size / 1000 / 1000)) 33 | print('Model {} : intermedite variables: {:3f} M (with backward)' 34 | .format(model._get_name(), total_nums * type_size*2 / 1000 / 1000)) -------------------------------------------------------------------------------- /classification/mcloader/classification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from .imagenet import ImageNet 4 | 5 | 6 | class ClassificationDataset(Dataset): 7 | """Dataset for classification. 
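    Wraps the memcached-backed ImageNet lists; the `root`, `list_file` and
    `mclient_path` values hard-coded in `__init__` are cluster-specific and
    need to be adapted to the local data layout.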
8 | """ 9 | 10 | def __init__(self, split='train', pipeline=None): 11 | if split == 'train': 12 | self.data_source = ImageNet(root='/mnt/cache/share/images/train', 13 | list_file='/mnt/cache/share/images/meta/train.txt', 14 | memcached=True, 15 | mclient_path='/mnt/lustre/sunweixuan') 16 | else: 17 | self.data_source = ImageNet(root='/mnt/cache/share/images/val', 18 | list_file='/mnt/cache/share/images/meta/val.txt', 19 | memcached=True, 20 | mclient_path='/mnt/lustre/sunweixuan') 21 | self.pipeline = pipeline 22 | 23 | def __len__(self): 24 | return self.data_source.get_length() 25 | 26 | def __getitem__(self, idx): 27 | img, target = self.data_source.get_sample(idx) 28 | if self.pipeline is not None: 29 | img = self.pipeline(img) 30 | 31 | return img, target 32 | -------------------------------------------------------------------------------- /classification/node_submit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export NCCL_LL_THRESHOLD=0 3 | # export NCCL_IB_DISABLE=1 4 | 5 | CONFIG=$2 6 | GPUS=$1 7 | NODES=2 8 | batch_size=$3 9 | PORT=${PORT:-8889} 10 | OUTPUT_DIR=$4 11 | RESUME=$5 12 | 13 | 14 | if [ -z $RESUME ] 15 | then 16 | spring.submit arun --gpu -n$GPUS --cpus-per-task 4 --ntasks-per-node 8 -p MMG \ 17 | " 18 | python main.py --config $CONFIG --data-set IMNET --batch-size $batch_size --dist-eval \ 19 | --output_dir $OUTPUT_DIR \ 20 | " 21 | else 22 | echo $RESUME 23 | spring.submit arun --gpu -n$GPUS --cpus-per-task 4 --ntasks-per-node 8 -p MMG \ 24 | " 25 | python main.py --config $CONFIG --data-set IMNET --batch-size $batch_size --dist-eval \ 26 | --output_dir $OUTPUT_DIR --resume $RESUME \ 27 | " 28 | fi 29 | 30 | 31 | 32 | 33 | # spring.submit arun --gpu -n$GPUS --cpus-per-task 4 --ntasks-per-node $GPUS -p MMG \ 34 | # " 35 | # python main.py --config $CONFIG --data-set IMNET --batch-size $batch_size --dist-eval --input-size 224 \ 36 | # " 37 | 38 | # --nproc_per_node=8 39 | 40 | # spring.submit arun --gpu -n$GPUS --cpus-per-task 5 --ntasks-per-node 8 -p MMG \ 41 | # " 42 | # python main.py --config $CONFIG --data-set IMNET --batch-size $batch_size --dist-eval \ 43 | # --resume $RESUME 44 | # " 45 | 46 | -------------------------------------------------------------------------------- /classification/mcloader/image_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | from .mcloader import McLoader 5 | 6 | 7 | class ImageList(object): 8 | 9 | def __init__(self, root, list_file, memcached=False, mclient_path=None): 10 | with open(list_file, 'r') as f: 11 | lines = f.readlines() 12 | self.has_labels = len(lines[0].split()) == 2 13 | if self.has_labels: 14 | self.fns, self.labels = zip(*[l.strip().split() for l in lines]) 15 | self.labels = [int(l) for l in self.labels] 16 | else: 17 | self.fns = [l.strip() for l in lines] 18 | self.fns = [os.path.join(root, fn) for fn in self.fns] 19 | self.memcached = memcached 20 | self.mclient_path = mclient_path 21 | self.initialized = False 22 | 23 | def _init_memcached(self): 24 | if not self.initialized: 25 | assert self.mclient_path is not None 26 | self.mc_loader = McLoader(self.mclient_path) 27 | self.initialized = True 28 | 29 | def get_length(self): 30 | return len(self.fns) 31 | 32 | def get_sample(self, idx): 33 | if self.memcached: 34 | self._init_memcached() 35 | if self.memcached: 36 | img = self.mc_loader(self.fns[idx]) 37 | else: 38 | img = Image.open(self.fns[idx]) 39 | 
img = img.convert('RGB') 40 | if self.has_labels: 41 | target = self.labels[idx] 42 | return img, target 43 | else: 44 | return img 45 | -------------------------------------------------------------------------------- /classification/train_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export NCCL_LL_THRESHOLD=0 3 | 4 | CONFIG=$2 5 | GPUS=$1 6 | PORT=${PORT:-8889} 7 | batch_size=$3 8 | OUTPUT_DIR=$4 9 | 10 | # spring.submit arun --gres=gpu:$GPUS -n1 --ntasks-per-node=$GPUS --gpu -p MMG \ 11 | # " 12 | # python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 13 | # --use_env main.py --config $CONFIG --data-set CIFAR --batch-size $batch_size --output_dir $OUTPUT_DIR \ 14 | # " 15 | 16 | CONFIG=configs/pvc_v2/pvc_v2_b1.py 17 | # CONFIG=configs/pvc/pvc_b0.py 18 | CONFIG=configs/vvt/vvt_test.py 19 | # CONFIG=configs/vvt/vvt_tiny.py 20 | # CONFIG=configs/vvt/vvt_small.py 21 | # CONFIG=configs/vvt/vvt_medium.py 22 | # CONFIG=configs/vvt/vvt_large.py 23 | # CONFIG=configs/pvc_v2/pvc_v2_b5.py 24 | GPUS=1 25 | batch_size=2 26 | OUTPUT_DIR=./checkpoints/pvc_b0/ 27 | 28 | # python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 29 | # --use_env main.py --config $CONFIG --data-set CIFAR --batch-size $batch_size --output_dir $OUTPUT_DIR \ 30 | 31 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 32 | --use_env main.py --config $CONFIG --data-set CIFAR --batch-size $batch_size --output_dir $OUTPUT_DIR \ 33 | 34 | # CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 35 | # --use_env main.py --config $CONFIG --data-set CIFAR --batch-size $batch_size --output_dir $OUTPUT_DIR \ 36 | 37 | # echo 1 38 | # python main.py --config $CONFIG --data-set CIFAR --batch-size $batch_size --output_dir $OUTPUT_DIR -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /classification/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 
11 | It ensures that different each augmented version of a sample will be visible to a 12 | different process (GPU) 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | self.dataset = dataset 26 | self.num_replicas = num_replicas 27 | self.rank = rank 28 | self.epoch = 0 29 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 30 | self.total_size = self.num_samples * self.num_replicas 31 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 32 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 33 | self.shuffle = shuffle 34 | 35 | def __iter__(self): 36 | # deterministically shuffle based on epoch 37 | g = torch.Generator() 38 | g.manual_seed(self.epoch) 39 | if self.shuffle: 40 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 41 | else: 42 | indices = list(range(len(self.dataset))) 43 | 44 | # add extra samples to make it evenly divisible 45 | indices = [ele for ele in indices for i in range(3)] 46 | indices += indices[:(self.total_size - len(indices))] 47 | assert len(indices) == self.total_size 48 | 49 | # subsample 50 | indices = indices[self.rank:self.total_size:self.num_replicas] 51 | assert len(indices) == self.num_samples 52 | 53 | return iter(indices[:self.num_selected_samples]) 54 | 55 | def __len__(self): 56 | return self.num_selected_samples 57 | 58 | def set_epoch(self, epoch): 59 | self.epoch = epoch 60 | -------------------------------------------------------------------------------- /classification/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | 10 | class DistillationLoss(torch.nn.Module): 11 | """ 12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 13 | taking a teacher model prediction and using it as additional supervision. 14 | """ 15 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 16 | distillation_type: str, alpha: float, tau: float): 17 | super().__init__() 18 | self.base_criterion = base_criterion 19 | self.teacher_model = teacher_model 20 | assert distillation_type in ['none', 'soft', 'hard'] 21 | self.distillation_type = distillation_type 22 | self.alpha = alpha 23 | self.tau = tau 24 | 25 | def forward(self, inputs, outputs, labels): 26 | """ 27 | Args: 28 | inputs: The original inputs that are feed to the teacher model 29 | outputs: the outputs of the model to be trained. 
It is expected to be 30 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 31 | in the first position and the distillation predictions as the second output 32 | labels: the labels for the base criterion 33 | """ 34 | outputs_kd = None 35 | if not isinstance(outputs, torch.Tensor): 36 | # assume that the model outputs a tuple of [outputs, outputs_kd] 37 | outputs, outputs_kd = outputs 38 | base_loss = self.base_criterion(outputs, labels) 39 | if self.distillation_type == 'none': 40 | return base_loss 41 | 42 | if outputs_kd is None: 43 | raise ValueError("When knowledge distillation is enabled, the model is " 44 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 45 | "class_token and the dist_token") 46 | # don't backprop throught the teacher 47 | with torch.no_grad(): 48 | teacher_outputs = self.teacher_model(inputs) 49 | 50 | if self.distillation_type == 'soft': 51 | T = self.tau 52 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 53 | # with slight modifications 54 | distillation_loss = F.kl_div( 55 | F.log_softmax(outputs_kd / T, dim=1), 56 | F.log_softmax(teacher_outputs / T, dim=1), 57 | reduction='sum', 58 | log_target=True 59 | ) * (T * T) / outputs_kd.numel() 60 | elif self.distillation_type == 'hard': 61 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 62 | 63 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 64 | return loss 65 | -------------------------------------------------------------------------------- /classification/get_flops.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from timm.models import create_model 4 | import pvt 5 | import pvt_v2 6 | 7 | try: 8 | from mmcv.cnn import get_model_complexity_info 9 | from mmcv.cnn.utils.flops_counter import get_model_complexity_info, flops_to_string, params_to_string 10 | except ImportError: 11 | raise ImportError('Please upgrade mmcv to >0.6.2') 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser(description='Get FLOPS of a classification model') 16 | parser.add_argument('model', help='train config file path') 17 | parser.add_argument( 18 | '--shape', 19 | type=int, 20 | nargs='+', 21 | default=[224, 224], 22 | help='input image size') 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def sra_flops(h, w, r, dim): 28 | return 2 * h * w * (h // r) * (w // r) * dim 29 | 30 | def li_sra_flops(h, w, dim): 31 | return 2 * h * w * 7 * 7 * dim 32 | 33 | def fra_flops(h, w, sr, dim, fr, head): 34 | return 2 * 4 * (h // sr) * (w // sr) * ((dim//fr)//head) * ((dim//fr)//head) * head 35 | 36 | def get_flops(model, input_shape): 37 | flops, params = get_model_complexity_info(model, input_shape, as_strings=False) 38 | if 'pvt' in model.name: 39 | _, H, W = input_shape 40 | if 'li' in model.name: # calculate flops of PVTv2_li 41 | stage1 = li_sra_flops(H // 4, W // 4, 42 | model.block1[0].attn.dim) * len(model.block1) 43 | stage2 = li_sra_flops(H // 8, W // 8, 44 | model.block2[0].attn.dim) * len(model.block2) 45 | stage3 = li_sra_flops(H // 16, W // 16, 46 | model.block3[0].attn.dim) * len(model.block3) 47 | stage4 = li_sra_flops(H // 32, W // 32, 48 | model.block4[0].attn.dim) * len(model.block4) 49 | else: # calculate flops of PVT/PVTv2 50 | stage1 = sra_flops(H // 4, W // 4, 51 | model.block1[0].attn.sr_ratio, 52 | model.block1[0].attn.dim 53 | ) * len(model.block1) 54 | stage2 = 
sra_flops(H // 8, W // 8, 55 | model.block2[0].attn.sr_ratio, 56 | model.block2[0].attn.dim) * len(model.block2) 57 | stage3 = sra_flops(H // 16, W // 16, 58 | model.block3[0].attn.sr_ratio, 59 | model.block3[0].attn.dim) * len(model.block3) 60 | stage4 = sra_flops(H // 32, W // 32, 61 | model.block4[0].attn.sr_ratio, 62 | model.block4[0].attn.dim) * len(model.block4) 63 | print(stage1, stage2, stage3, stage4) 64 | flops += stage1 + stage2 + stage3 + stage4 65 | elif 'pvc' in model.name: 66 | _, H, W = input_shape 67 | stage1 = fra_flops(H //4 , W //4 , 68 | model.block1[0].attn.sr_ratio, 69 | model.block1[0].attn.embed_dim, 70 | model.block1[0].attn.fr, 71 | model.block1[0].attn.num_heads) * len(model.block1) 72 | stage2 = fra_flops(H //8 , W //8 , 73 | model.block2[0].attn.sr_ratio, 74 | model.block2[0].attn.embed_dim, 75 | model.block2[0].attn.fr, 76 | model.block2[0].attn.num_heads) * len(model.block2) 77 | stage3 = fra_flops(H //16, W //16, 78 | model.block3[0].attn.sr_ratio, 79 | model.block3[0].attn.embed_dim, 80 | model.block3[0].attn.fr, 81 | model.block3[0].attn.num_heads) * len(model.block3) 82 | stage4 = fra_flops(H //32, W //32, 83 | model.block4[0].attn.sr_ratio, 84 | model.block4[0].attn.embed_dim, 85 | model.block4[0].attn.fr, 86 | model.block4[0].attn.num_heads) * len(model.block4) 87 | print(stage1, stage2, stage3, stage4) 88 | flops += stage1 + stage2 + stage3 + stage4 89 | # pass 90 | 91 | return flops_to_string(flops), params_to_string(params) 92 | 93 | 94 | def main(): 95 | args = parse_args() 96 | 97 | if len(args.shape) == 1: 98 | input_shape = (3, args.shape[0], args.shape[0]) 99 | elif len(args.shape) == 2: 100 | input_shape = (3,) + tuple(args.shape) 101 | else: 102 | raise ValueError('invalid input shape') 103 | 104 | model = create_model( 105 | args.model, 106 | pretrained=False, 107 | num_classes=1000 108 | ) 109 | model.name = args.model 110 | if torch.cuda.is_available(): 111 | model.cuda() 112 | model.eval() 113 | 114 | flops, params = get_flops(model, input_shape) 115 | 116 | split_line = '=' * 30 117 | print(f'{split_line}\nInput shape: {input_shape}\n' 118 | f'Flops: {flops}\nParams: {params}\n{split_line}') 119 | print('!!!Please be cautious if you use the results in papers. ' 120 | 'You may need to check if all ops are supported and verify that the ' 121 | 'flops computation is correct.') 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vicinity-Vision-Transformer 2 | 3 | [[Project Page]](https://opennlplab.github.io/vvt/) 4 | 5 | This repo is the official implementations of [Vicinity Vision Transformer](https://arxiv.org/abs/2206.10552). 6 | It serves as a general-purpose backbone for image classification, semantic segmentation, object detction tasks. 7 | 8 | if you use this code, please cite: 9 | 10 | ``` 11 | @ARTICLE {10149455, 12 | author = {W. Sun and Z. Qin and H. Deng and J. Wang and Y. Zhang and K. Zhang and N. Barnes and S. Birchfield and L. Kong and Y. Zhong}, 13 | journal = {IEEE Transactions on Pattern Analysis & Machine Intelligence}, 14 | title = {Vicinity Vision Transformer}, 15 | year = {5555}, 16 | volume = {}, 17 | number = {01}, 18 | issn = {1939-3539}, 19 | pages = {1-14}, 20 | abstract = {Vision transformers have shown great success on numerous computer vision tasks. 
However, their central component, softmax attention, prohibits vision transformers from scaling up to high-resolution images, due to both the computational complexity and memory footprint being quadratic. Linear attention was introduced in natural language processing (NLP) which reorders the self-attention mechanism to mitigate a similar issue, but directly applying existing linear attention to vision may not lead to satisfactory results. We investigate this problem and point out that existing linear attention methods ignore an inductive bias in vision tasks, i.e., 2D locality. In this paper, we propose Vicinity Attention, which is a type of linear attention that integrates 2D locality. Specifically, for each image patch, we adjust its attention weight based on its 2D Manhattan distance from its neighbouring patches. In this case, we achieve 2D locality in a linear complexity where the neighbouring image patches receive stronger attention than far away patches. In addition, we propose a novel Vicinity Attention Block that is comprised of Feature Reduction Attention (FRA) and Feature Preserving Connection (FPC) in order to address the computational bottleneck of linear attention approaches, including our Vicinity Attention, whose complexity grows quadratically with respect to the feature dimension. The Vicinity Attention Block computes attention in a compressed feature space with an extra skip connection to retrieve the original feature distribution. We experimentally validate that the block further reduces computation without degenerating the accuracy. Finally, to validate the proposed methods, we build a linear vision transformer backbone named Vicinity Vision Transformer (VVT). Targeting general vision tasks, we build VVT in a pyramid structure with progressively reduced sequence length. We perform extensive experiments on CIFAR-100, ImageNet-1 k, and ADE20 K datasets to validate the effectiveness of our method. Our method has a slower growth rate in terms of computational overhead than previous transformer-based and convolution-based networks when the input resolution increases. In particular, our approach achieves state-of-the-art image classification accuracy with 50% fewer parameters than previous approaches.}, 21 | keywords = {transformers;task analysis;computer vision;standards;image resolution;sun;image classification}, 22 | doi = {10.1109/TPAMI.2023.3285569}, 23 | publisher = {IEEE Computer Society}, 24 | address = {Los Alamitos, CA, USA}, 25 | month = {jun} 26 | } 27 | 28 | ``` 29 | 30 | ![image](https://user-images.githubusercontent.com/13931546/175231586-5e7c46f3-29a2-4497-9ddf-bc40aeee88b3.png) 31 | 32 | 33 | 34 | ## Usage 35 | ### Data preparation 36 | 37 | Download and extract ImageNet train and val images from http://image-net.org/. 
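One common way to produce the expected layout from the official ILSVRC2012 archives is sketched below; the archive names and target paths are placeholders to adapt to your own download.

```bash
# Illustrative sketch only: build train/<class>/*.JPEG and val/<class>/*.JPEG
# from the official ILSVRC2012 archives (paths are placeholders).
mkdir -p /path/to/imagenet/train /path/to/imagenet/val

# Training set: the outer tar contains one tar per class; unpack each
# class tar into a folder named after its synset.
tar -xf ILSVRC2012_img_train.tar -C /path/to/imagenet/train
for f in /path/to/imagenet/train/*.tar; do
  d="${f%.tar}"
  mkdir -p "$d" && tar -xf "$f" -C "$d" && rm "$f"
done

# Validation set: the images come flat and still have to be moved into
# per-class folders, e.g. with the commonly used valprep.sh script.
tar -xf ILSVRC2012_img_val.tar -C /path/to/imagenet/val
```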
38 | The directory structure is the standard layout for the torchvision [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder), and the training and validation data is expected to be in the `train/` folder and `val` folder respectively: 39 | 40 | ``` 41 | /path/to/imagenet/ 42 | train/ 43 | class1/ 44 | img1.jpeg 45 | class2/ 46 | img2.jpeg 47 | val/ 48 | class1/ 49 | img3.jpeg 50 | class/2 51 | img4.jpeg 52 | ``` 53 | 54 | ### environment 55 | First, clone the repository locally: 56 | ``` 57 | git clone https://github.com/OpenNLPLab/Vicinity-Vision-Transformer.git 58 | ``` 59 | Then, install PyTorch 1.8.1 and torchvision 0.9.1 and pytorch-image-models 0.4.12 60 | 61 | 62 | ### Evaluation 63 | ``` 64 | sh test.sh configs/vvt/vvt_small.py 1 /path/to/checkpoint 65 | ``` 66 | 67 | ### Training 68 | ``` 69 | sh train.sh configs/vvt/vvt_small.py 8 70 | ``` 71 | 72 | 73 | ## Weight 74 | ### VVT on ImageNet-1K 75 | 76 | | Method | Size | Acc@1 | #Params (M) | Link | 77 | |------------------|:----:|:-----:|:-----------:|:-----:| 78 | | VVT-tiny | 224 | 79.2 | 12.9 | [link](https://1drv.ms/u/s!Ak3sXyXVg781gtRFwwHih3Yu9G3FGg?e=yHKvuc) | 79 | | VVT-tiny | 384 | 80.3 | 12.9 | - | 80 | | VVT-small | 224 | 82.6 | 25.5 | [link](https://1drv.ms/u/s!Ak3sXyXVg781gtREWfCdlLJVy1IgpA?e=l4h3Wi) | 81 | | VVT-small | 384 | 83.4 | 25.5 | - | 82 | | VVT-medium | 224 | 83.8 | 47.9 | [link](https://1drv.ms/u/s!Ak3sXyXVg781gtRG4lD_uEVyj7cPYw?e=ihjjtO) | 83 | | VVT-medium | 384 | 84.5 | 47.9 | - | 84 | | VVT-large | 224 | 84.1 | 61.8 | [link](https://1drv.ms/u/s!Ak3sXyXVg781gtRHmfu0BybCZ8k1FQ?e=fLskgG) | 85 | | VVT-large | 384 | 84.7 | 61.8 | - | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | --- 95 | Our code is developed based on [TIMM](https://github.com/rwightman/pytorch-image-models) and [PVT](https://github.com/whai362/PVT) 96 | 97 | 98 | -------------------------------------------------------------------------------- /classification/datasets.py: -------------------------------------------------------------------------------- 1 | # # Copyright (c) 2015-present, Facebook, Inc. 2 | # # All rights reserved. 
3 | # import os 4 | # import json 5 | 6 | # from torchvision import datasets, transforms 7 | # from torchvision.datasets.folder import ImageFolder, default_loader 8 | # import img_folder 9 | 10 | # from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 11 | # from timm.data import create_transform 12 | # from mcloader import ClassificationDataset 13 | 14 | 15 | # class INatDataset(ImageFolder): 16 | # def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 17 | # category='name', loader=default_loader): 18 | # self.transform = transform 19 | # self.loader = loader 20 | # self.target_transform = target_transform 21 | # self.year = year 22 | # # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 23 | # path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 24 | # with open(path_json) as json_file: 25 | # data = json.load(json_file) 26 | 27 | # with open(os.path.join(root, 'categories.json')) as json_file: 28 | # data_catg = json.load(json_file) 29 | 30 | # path_json_for_targeter = os.path.join(root, f"train{year}.json") 31 | 32 | # with open(path_json_for_targeter) as json_file: 33 | # data_for_targeter = json.load(json_file) 34 | 35 | # targeter = {} 36 | # indexer = 0 37 | # for elem in data_for_targeter['annotations']: 38 | # king = [] 39 | # king.append(data_catg[int(elem['category_id'])][category]) 40 | # if king[0] not in targeter.keys(): 41 | # targeter[king[0]] = indexer 42 | # indexer += 1 43 | # self.nb_classes = len(targeter) 44 | 45 | # self.samples = [] 46 | # for elem in data['images']: 47 | # cut = elem['file_name'].split('/') 48 | # target_current = int(cut[2]) 49 | # path_current = os.path.join(root, cut[0], cut[2], cut[3]) 50 | 51 | # categors = data_catg[target_current] 52 | # target_current_true = targeter[categors[category]] 53 | # self.samples.append((path_current, target_current_true)) 54 | 55 | # # __getitem__ and __len__ inherited from ImageFolder 56 | 57 | 58 | # def build_dataset(is_train, args): 59 | # transform = build_transform(is_train, args) 60 | 61 | # if args.data_set == 'CIFAR': 62 | # dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 63 | # nb_classes = 100 64 | # elif args.data_set == 'IMNET': 65 | # if not args.use_mcloader: 66 | # root = os.path.join(args.data_path, 'train' if is_train else 'val') 67 | # dataset = img_folder.ImageFolder2(root, transform=transform) 68 | # else: 69 | # dataset = ClassificationDataset( 70 | # 'train' if is_train else 'val', 71 | # pipeline=transform 72 | # ) 73 | # nb_classes = 1000 74 | # elif args.data_set == 'INAT': 75 | # dataset = INatDataset(args.data_path, train=is_train, year=2018, 76 | # category=args.inat_category, transform=transform) 77 | # nb_classes = dataset.nb_classes 78 | # elif args.data_set == 'INAT19': 79 | # dataset = INatDataset(args.data_path, train=is_train, year=2019, 80 | # category=args.inat_category, transform=transform) 81 | # nb_classes = dataset.nb_classes 82 | 83 | # return dataset, nb_classes 84 | 85 | 86 | # def build_transform(is_train, args): 87 | # resize_im = args.input_size > 32 88 | # if is_train: 89 | # # this should always dispatch to transforms_imagenet_train 90 | # transform = create_transform( 91 | # input_size=args.input_size, 92 | # is_training=True, 93 | # color_jitter=args.color_jitter, 94 | # auto_augment=args.aa, 95 | # interpolation=args.train_interpolation, 96 | # re_prob=args.reprob, 97 | # 
re_mode=args.remode, 98 | # re_count=args.recount, 99 | # ) 100 | # if not resize_im: 101 | # # replace RandomResizedCropAndInterpolation with 102 | # # RandomCrop 103 | # transform.transforms[0] = transforms.RandomCrop( 104 | # args.input_size, padding=4) 105 | # return transform 106 | 107 | # t = [] 108 | # if resize_im: 109 | # size = int((256 / 224) * args.input_size) 110 | # t.append( 111 | # transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 112 | # ) 113 | # t.append(transforms.CenterCrop(args.input_size)) 114 | 115 | # t.append(transforms.ToTensor()) 116 | # t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 117 | # return transforms.Compose(t) 118 | 119 | # Copyright (c) 2015-present, Facebook, Inc. 120 | # All rights reserved. 121 | import os 122 | import json 123 | 124 | from torchvision import datasets, transforms 125 | from torchvision.datasets.folder import ImageFolder, default_loader 126 | 127 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 128 | from timm.data import create_transform 129 | from mcloader import ClassificationDataset 130 | 131 | 132 | class INatDataset(ImageFolder): 133 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 134 | category='name', loader=default_loader): 135 | self.transform = transform 136 | self.loader = loader 137 | self.target_transform = target_transform 138 | self.year = year 139 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 140 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 141 | with open(path_json) as json_file: 142 | data = json.load(json_file) 143 | 144 | with open(os.path.join(root, 'categories.json')) as json_file: 145 | data_catg = json.load(json_file) 146 | 147 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 148 | 149 | with open(path_json_for_targeter) as json_file: 150 | data_for_targeter = json.load(json_file) 151 | 152 | targeter = {} 153 | indexer = 0 154 | for elem in data_for_targeter['annotations']: 155 | king = [] 156 | king.append(data_catg[int(elem['category_id'])][category]) 157 | if king[0] not in targeter.keys(): 158 | targeter[king[0]] = indexer 159 | indexer += 1 160 | self.nb_classes = len(targeter) 161 | 162 | self.samples = [] 163 | for elem in data['images']: 164 | cut = elem['file_name'].split('/') 165 | target_current = int(cut[2]) 166 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 167 | 168 | categors = data_catg[target_current] 169 | target_current_true = targeter[categors[category]] 170 | self.samples.append((path_current, target_current_true)) 171 | 172 | # __getitem__ and __len__ inherited from ImageFolder 173 | 174 | 175 | def build_dataset(is_train, args): 176 | transform = build_transform(is_train, args) 177 | 178 | if args.data_set == 'CIFAR': 179 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 180 | nb_classes = 100 181 | elif args.data_set == 'IMNET': 182 | if not args.use_mcloader: 183 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 184 | dataset = datasets.ImageFolder(root, transform=transform) 185 | else: 186 | dataset = ClassificationDataset( 187 | 'train' if is_train else 'val', 188 | pipeline=transform 189 | ) 190 | nb_classes = 1000 191 | elif args.data_set == 'INAT': 192 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 193 | category=args.inat_category, 
transform=transform) 194 | nb_classes = dataset.nb_classes 195 | elif args.data_set == 'INAT19': 196 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 197 | category=args.inat_category, transform=transform) 198 | nb_classes = dataset.nb_classes 199 | 200 | return dataset, nb_classes 201 | 202 | 203 | def build_transform(is_train, args): 204 | resize_im = args.input_size > 32 205 | if is_train: 206 | # this should always dispatch to transforms_imagenet_train 207 | transform = create_transform( 208 | input_size=args.input_size, 209 | is_training=True, 210 | color_jitter=args.color_jitter, 211 | auto_augment=args.aa, 212 | interpolation=args.train_interpolation, 213 | re_prob=args.reprob, 214 | re_mode=args.remode, 215 | re_count=args.recount, 216 | ) 217 | if not resize_im: 218 | # replace RandomResizedCropAndInterpolation with 219 | # RandomCrop 220 | transform.transforms[0] = transforms.RandomCrop( 221 | args.input_size, padding=4) 222 | return transform 223 | 224 | t = [] 225 | if resize_im: 226 | size = int((256 / 224) * args.input_size) 227 | t.append( 228 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 229 | ) 230 | t.append(transforms.CenterCrop(args.input_size)) 231 | 232 | t.append(transforms.ToTensor()) 233 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 234 | return transforms.Compose(t) -------------------------------------------------------------------------------- /classification/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | import io 9 | import os 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | from xml.sax.handler import property_dom_node 14 | 15 | import torch 16 | import torch.distributed as dist 17 | import mmcv 18 | import numpy as np 19 | import socket 20 | import subprocess 21 | 22 | class SmoothedValue(object): 23 | """Track a series of values and provide access to smoothed values over a 24 | window or the global series average. 25 | """ 26 | 27 | def __init__(self, window_size=20, fmt=None): 28 | if fmt is None: 29 | fmt = "{median:.4f} ({global_avg:.4f})" 30 | self.deque = deque(maxlen=window_size) 31 | self.total = 0.0 32 | self.count = 0 33 | self.fmt = fmt 34 | 35 | def update(self, value, n=1): 36 | self.deque.append(value) 37 | self.count += n 38 | self.total += value * n 39 | 40 | def synchronize_between_processes(self): 41 | """ 42 | Warning: does not synchronize the deque! 
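        Only `count` and `total` are all-reduced across processes, so the
        windowed statistics (median, avg, max, value) remain process-local.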
43 | """ 44 | if not is_dist_avail_and_initialized(): 45 | return 46 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 47 | dist.barrier() 48 | dist.all_reduce(t) 49 | t = t.tolist() 50 | self.count = int(t[0]) 51 | self.total = t[1] 52 | 53 | @property 54 | def median(self): 55 | d = torch.tensor(list(self.deque)) 56 | return d.median().item() 57 | 58 | @property 59 | def avg(self): 60 | d = torch.tensor(list(self.deque), dtype=torch.float32) 61 | return d.mean().item() 62 | 63 | @property 64 | def global_avg(self): 65 | return self.total / self.count 66 | 67 | @property 68 | def max(self): 69 | return max(self.deque) 70 | 71 | @property 72 | def value(self): 73 | return self.deque[-1] 74 | 75 | def __str__(self): 76 | return self.fmt.format( 77 | median=self.median, 78 | avg=self.avg, 79 | global_avg=self.global_avg, 80 | max=self.max, 81 | value=self.value) 82 | 83 | 84 | class MetricLogger(object): 85 | def __init__(self, delimiter="\t"): 86 | self.meters = defaultdict(SmoothedValue) 87 | self.delimiter = delimiter 88 | 89 | def update(self, **kwargs): 90 | for k, v in kwargs.items(): 91 | if isinstance(v, torch.Tensor): 92 | v = v.item() 93 | assert isinstance(v, (float, int)) 94 | self.meters[k].update(v) 95 | 96 | def __getattr__(self, attr): 97 | if attr in self.meters: 98 | return self.meters[attr] 99 | if attr in self.__dict__: 100 | return self.__dict__[attr] 101 | raise AttributeError("'{}' object has no attribute '{}'".format( 102 | type(self).__name__, attr)) 103 | 104 | def __str__(self): 105 | loss_str = [] 106 | for name, meter in self.meters.items(): 107 | loss_str.append( 108 | "{}: {}".format(name, str(meter)) 109 | ) 110 | return self.delimiter.join(loss_str) 111 | 112 | def synchronize_between_processes(self): 113 | for meter in self.meters.values(): 114 | meter.synchronize_between_processes() 115 | 116 | def add_meter(self, name, meter): 117 | self.meters[name] = meter 118 | 119 | def log_every(self, iterable, print_freq, header=None): 120 | i = 0 121 | if not header: 122 | header = '' 123 | start_time = time.time() 124 | end = time.time() 125 | iter_time = SmoothedValue(fmt='{avg:.4f}') 126 | data_time = SmoothedValue(fmt='{avg:.4f}') 127 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 128 | log_msg = [ 129 | header, 130 | '[{0' + space_fmt + '}/{1}]', 131 | 'eta: {eta}', 132 | '{meters}', 133 | 'time: {time}', 134 | 'data: {data}' 135 | ] 136 | if torch.cuda.is_available(): 137 | log_msg.append('max mem: {memory:.0f}') 138 | log_msg = self.delimiter.join(log_msg) 139 | MB = 1024.0 * 1024.0 140 | for obj in iterable: 141 | data_time.update(time.time() - end) 142 | yield obj 143 | iter_time.update(time.time() - end) 144 | if i % print_freq == 0 or i == len(iterable) - 1: 145 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 146 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 147 | if torch.cuda.is_available(): 148 | print(log_msg.format( 149 | i, len(iterable), eta=eta_string, 150 | meters=str(self), 151 | time=str(iter_time), data=str(data_time), 152 | memory=torch.cuda.max_memory_allocated() / MB)) 153 | else: 154 | print(log_msg.format( 155 | i, len(iterable), eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), data=str(data_time))) 158 | i += 1 159 | end = time.time() 160 | total_time = time.time() - start_time 161 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 162 | print('{} Total time: {} ({:.4f} s / it)'.format( 163 | header, total_time_str, total_time / 
len(iterable))) 164 | 165 | 166 | def _load_checkpoint_for_ema(model_ema, checkpoint): 167 | """ 168 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 169 | """ 170 | mem_file = io.BytesIO() 171 | torch.save(checkpoint, mem_file) 172 | mem_file.seek(0) 173 | model_ema._load_checkpoint(mem_file) 174 | 175 | 176 | def setup_for_distributed(is_master): 177 | """ 178 | This function disables printing when not in master process 179 | """ 180 | import builtins as __builtin__ 181 | builtin_print = __builtin__.print 182 | 183 | def print(*args, **kwargs): 184 | force = kwargs.pop('force', False) 185 | if is_master or force: 186 | builtin_print(*args, **kwargs) 187 | 188 | __builtin__.print = print 189 | 190 | 191 | def is_dist_avail_and_initialized(): 192 | if not dist.is_available(): 193 | return False 194 | if not dist.is_initialized(): 195 | return False 196 | return True 197 | 198 | 199 | def get_world_size(): 200 | if not is_dist_avail_and_initialized(): 201 | return 1 202 | return dist.get_world_size() 203 | 204 | 205 | def get_rank(): 206 | if not is_dist_avail_and_initialized(): 207 | return 0 208 | return dist.get_rank() 209 | 210 | 211 | def is_main_process(): 212 | return get_rank() == 0 213 | 214 | 215 | def save_on_master(*args, **kwargs): 216 | if is_main_process(): 217 | torch.save(*args, **kwargs) 218 | 219 | def get_ip(): 220 | hostname = socket.gethostname() 221 | ip = socket.gethostbyname(hostname) 222 | 223 | return ip 224 | 225 | def init_distributed_mode(args): 226 | # 本地多卡 227 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 228 | args.rank = int(os.environ["RANK"]) 229 | args.world_size = int(os.environ['WORLD_SIZE']) 230 | # single gpu 231 | if args.world_size == 1: 232 | return 233 | args.gpu = int(os.environ['LOCAL_RANK']) 234 | args.distributed = True 235 | torch.cuda.set_device(args.gpu) 236 | args.dist_backend = 'nccl' 237 | print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True) 238 | print(args.dist_backend) 239 | print(args.dist_url) 240 | print(args.world_size) 241 | print(args.rank) 242 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 243 | world_size=args.world_size, rank=args.rank) 244 | torch.distributed.barrier() 245 | setup_for_distributed(args.rank == 0) 246 | 247 | return 248 | 249 | # slurm 250 | local_rank = int(os.environ['SLURM_LOCALID']) 251 | port = str(args.port) 252 | proc_id = int(os.environ['SLURM_PROCID']) 253 | ntasks = int(os.environ['SLURM_NTASKS']) 254 | node_list = os.environ['SLURM_NODELIST'] 255 | hostnames = subprocess.check_output( 256 | ["scontrol", "show", "hostnames", node_list] 257 | ) 258 | addr = hostnames.split()[0].decode("utf-8") 259 | os.environ['MASTER_PORT'] = port 260 | os.environ['MASTER_ADDR'] = addr 261 | os.environ['WORLD_SIZE'] = str(ntasks) 262 | os.environ['RANK'] = str(proc_id) 263 | os.environ['LOCAL_RANK'] = str(local_rank) 264 | 265 | if 'SLURM_PROCID' in os.environ: 266 | rank = int(os.environ['SLURM_PROCID']) 267 | gpu = rank % torch.cuda.device_count() 268 | print(f"SLURM {gpu}") 269 | args.rank = int(os.environ["RANK"]) 270 | args.world_size = int(os.environ['WORLD_SIZE']) 271 | args.gpu = int(os.environ['LOCAL_RANK']) 272 | args.distributed = True 273 | torch.cuda.set_device(local_rank) 274 | args.dist_backend = 'nccl' 275 | host_addr_full = 'tcp://' + addr + ':' + port 276 | 277 | print(f"addr: {host_addr_full}") 278 | print(f"ip {get_ip()}") 279 | print(f"MASTER_ADDR {addr}") 280 | print(f"port 
{port}") 281 | print(f"world_size {os.environ['WORLD_SIZE']}") 282 | print(f"rank {os.environ['RANK']}") 283 | print(f"local_rank {os.environ['LOCAL_RANK']}") 284 | print('| distributed init (rank {})'.format(local_rank), flush=True) 285 | 286 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=host_addr_full, 287 | world_size=args.world_size, rank=args.rank) 288 | torch.distributed.barrier() 289 | setup_for_distributed(args.rank == 0) 290 | 291 | def update_from_config(args): 292 | cfg = mmcv.Config.fromfile(args.config) 293 | for _, cfg_item in cfg._cfg_dict.items(): 294 | for k, v in cfg_item.items(): 295 | setattr(args, k, v) 296 | return args 297 | -------------------------------------------------------------------------------- /classification/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | import torch.nn.functional as F 10 | import torch 11 | 12 | from timm.data import Mixup 13 | from timm.utils import accuracy, ModelEma 14 | 15 | from losses import DistillationLoss 16 | import utils 17 | import os 18 | import pdb 19 | import os 20 | 21 | import inspect 22 | import numpy as np 23 | import cv2 24 | from torchvision import transforms 25 | 26 | 27 | class ForkedPdb(pdb.Pdb): 28 | """A Pdb subclass that may be used 29 | from a forked multiprocessing child 30 | 31 | """ 32 | def interaction(self, *args, **kwargs): 33 | _stdin = sys.stdin 34 | try: 35 | sys.stdin = open('/dev/stdin') 36 | pdb.Pdb.interaction(self, *args, **kwargs) 37 | finally: 38 | sys.stdin = _stdin 39 | 40 | 41 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 42 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 43 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 44 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 45 | set_training_mode=True, 46 | fp32=False, 47 | distributed=False): 48 | 49 | model.train(set_training_mode) 50 | metric_logger = utils.MetricLogger(delimiter=" ") 51 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 52 | header = 'Epoch: [{}]'.format(epoch) 53 | print_freq = 10 54 | 55 | # ForkedPdb() 56 | 57 | 58 | # try: 59 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 60 | samples = samples.to(device, non_blocking=True) 61 | targets = targets.to(device, non_blocking=True) 62 | 63 | if mixup_fn is not None: 64 | samples, targets = mixup_fn(samples, targets) 65 | 66 | # with torch.cuda.amp.autocast(): 67 | # outputs = model(samples) 68 | # loss = criterion(samples, outputs, targets) 69 | with torch.cuda.amp.autocast(enabled=not fp32): 70 | outputs = model(samples) 71 | outputs = outputs.float() # back to fp32 72 | loss = criterion(samples, outputs, targets) 73 | 74 | # if torch.isnan(loss): 75 | # print('loss is nan!!!...') 76 | 77 | # torch.distributed.barrier() 78 | # continue 79 | 80 | 81 | 82 | loss_value = loss.item() 83 | 84 | if not math.isfinite(loss_value): 85 | print("Loss is {}, stopping training".format(loss_value)) 86 | sys.exit(1) 87 | 88 | optimizer.zero_grad() 89 | 90 | # this attribute is added by timm on one optimizer (adahessian) 91 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 92 | loss_scaler(loss, optimizer, 
clip_grad=max_norm, 93 | parameters=model.parameters(), create_graph=is_second_order) 94 | 95 | torch.cuda.synchronize() 96 | if model_ema is not None: 97 | model_ema.update(model) 98 | 99 | metric_logger.update(loss=loss_value) 100 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 101 | 102 | if distributed: 103 | torch.distributed.barrier() 104 | 105 | # gather the stats from all processes 106 | metric_logger.synchronize_between_processes() 107 | print("Averaged stats:", metric_logger) 108 | # except: 109 | # ForkedPdb().set_trace() 110 | 111 | 112 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 113 | 114 | 115 | @torch.no_grad() 116 | def evaluate(data_loader, model, device): 117 | criterion = torch.nn.CrossEntropyLoss() 118 | 119 | metric_logger = utils.MetricLogger(delimiter=" ") 120 | header = 'Test:' 121 | 122 | # switch to evaluation mode 123 | model.eval() 124 | 125 | for images, target in metric_logger.log_every(data_loader, 10, header): 126 | images = images.to(device, non_blocking=True) 127 | target = target.to(device, non_blocking=True) 128 | 129 | # print(images.shape) 130 | 131 | # compute outputs 132 | # with torch.cuda.amp.autocast(enabled = False): 133 | with torch.cuda.amp.autocast(): 134 | output = model(images) 135 | output = output.float() # back to fp32 136 | loss = criterion(output, target) 137 | 138 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 139 | 140 | batch_size = images.shape[0] 141 | metric_logger.update(loss=loss.item()) 142 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 143 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 144 | # gather the stats from all processes 145 | metric_logger.synchronize_between_processes() 146 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 147 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 148 | 149 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 150 | 151 | 152 | def generate_cam(data_loader, model, device): 153 | criterion = torch.nn.CrossEntropyLoss() 154 | 155 | metric_logger = utils.MetricLogger(delimiter=" ") 156 | header = 'Test:' 157 | 158 | # switch to evaluation mode 159 | model.eval() 160 | 161 | count = 0 162 | img_transform = transforms.Compose([transforms.Normalize(mean = [-0.4850/.229, -0.456/0.224, -0.406/0.225], std =[1/0.229, 1/0.224, 1/0.225])]) 163 | 164 | 165 | for images, target in metric_logger.log_every(data_loader, 10, header): 166 | 167 | images = images.to(device, non_blocking=True) 168 | target = target.to(device, non_blocking=True) 169 | 170 | count += 1 171 | 172 | batch_size = images.shape[0] 173 | # compute outputs 174 | # with torch.cuda.amp.autocast(enabled = False): 175 | with torch.cuda.amp.autocast(): 176 | output = model(images) 177 | output = output.float() # back to fp32 178 | 179 | for batch in range(batch_size): 180 | cur_image = images[batch,:,:,:] 181 | 182 | rgb_img = img_transform(cur_image) 183 | rgb_img = rgb_img.detach().cpu().numpy().squeeze() 184 | rgb_img = rgb_img * 255 185 | rgb_img = rgb_img.astype(np.uint8) 186 | 187 | cur_target = target[batch] 188 | # print(cur_target) 189 | one_hot = np.zeros((1, 1000), dtype=np.float32) 190 | one_hot[0, cur_target] = 1 191 | one_hot = torch.from_numpy(one_hot).requires_grad_(True) 192 | 193 | cur_output = output[batch] 194 | one_hot = (one_hot.cuda() * cur_output) 195 | # print(one_hot) 196 | one_hot = torch.sum(one_hot) 197 | # print(one_hot) 
198 | 199 | model.zero_grad() 200 | one_hot.backward(retain_graph=True) 201 | cam = model.module.grad_cam(batch, start_layer=0) 202 | print(cam.shape) 203 | cam = cam.reshape(7, 7) 204 | cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), (224, 224), mode='bilinear', align_corners=True) 205 | cam = cam.detach().cpu().numpy() 206 | norm_cam = cam / (np.max(cam, (0,1), keepdims=True) + 1e-5) 207 | print(norm_cam.shape) 208 | 209 | cv2.imwrite('/mnt/lustre/sunweixuan/pvc/classification/cam_output/pvt_v2_b1/{}_{}.png'.format(count, batch), norm_cam[0,0,:,:]*255) 210 | 211 | 212 | loss = criterion(output, target) 213 | 214 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 215 | 216 | batch_size = images.shape[0] 217 | metric_logger.update(loss=loss.item()) 218 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 219 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 220 | # gather the stats from all processes 221 | metric_logger.synchronize_between_processes() 222 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 223 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 224 | 225 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 226 | 227 | 228 | 229 | def generate_grad_cam(data_loader, model, device, methods): 230 | criterion = torch.nn.CrossEntropyLoss() 231 | 232 | metric_logger = utils.MetricLogger(delimiter=" ") 233 | header = 'Test:' 234 | 235 | # switch to evaluation mode 236 | model.eval() 237 | 238 | count = 0 239 | img_transform = transforms.Compose([transforms.Normalize(mean = [-0.4850/.229, -0.456/0.224, -0.406/0.225], std =[1/0.229, 1/0.224, 1/0.225])]) 240 | 241 | 242 | for images, target in metric_logger.log_every(data_loader, 10, header): 243 | 244 | images = images.to(device, non_blocking=True) 245 | target = target.to(device, non_blocking=True) 246 | 247 | count += 1 248 | 249 | batch_size = images.shape[0] 250 | # compute outputs 251 | # with torch.cuda.amp.autocast(enabled = False): 252 | with torch.cuda.amp.autocast(): 253 | output = model(images) 254 | output = output.float() # back to fp32 255 | 256 | for batch in range(batch_size): 257 | cur_image = images[batch,:,:,:] 258 | 259 | rgb_img = img_transform(cur_image) 260 | rgb_img = rgb_img.detach().cpu().numpy().squeeze() 261 | rgb_img = rgb_img * 255 262 | rgb_img = rgb_img.astype(np.uint8) 263 | 264 | cur_target = target[batch] 265 | # print(cur_target) 266 | one_hot = np.zeros((1, 1000), dtype=np.float32) 267 | one_hot[0, cur_target] = 1 268 | one_hot = torch.from_numpy(one_hot).requires_grad_(True) 269 | 270 | cur_output = output[batch] 271 | one_hot = (one_hot.cuda() * cur_output) 272 | # print(one_hot) 273 | one_hot = torch.sum(one_hot) 274 | # print(one_hot) 275 | 276 | model.zero_grad() 277 | one_hot.backward(retain_graph=True) 278 | cam = model.module.grad_cam(batch, start_layer=0) 279 | print(cam.shape) 280 | cam = cam.reshape(7, 7) 281 | cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), (224, 224), mode='bilinear', align_corners=True) 282 | cam = cam.detach().cpu().numpy() 283 | norm_cam = cam / (np.max(cam, (0,1), keepdims=True) + 1e-5) 284 | print(norm_cam.shape) 285 | 286 | cv2.imwrite('/mnt/lustre/sunweixuan/pvc/classification/cam_output/pvt_v2_b1/{}_{}.png'.format(count, batch), norm_cam[0,0,:,:]*255) 287 | 288 | 289 | loss = criterion(output, target) 290 | 291 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 292 | 293 | batch_size = images.shape[0] 294 | 
metric_logger.update(loss=loss.item()) 295 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 296 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 297 | # gather the stats from all processes 298 | metric_logger.synchronize_between_processes() 299 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 300 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 301 | 302 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 303 | 304 | 305 | 306 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /classification/pvt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | 10 | __all__ = [ 11 | 'pvt_tiny', 'pvt_small', 'pvt_medium', 'pvt_large' 12 | ] 13 | 14 | 15 | class Mlp(nn.Module): 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 36 | super().__init__() 37 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 38 | 39 | self.dim = dim 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = qk_scale or head_dim ** -0.5 43 | 44 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 45 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 46 | self.attn_drop = nn.Dropout(attn_drop) 47 | self.proj = nn.Linear(dim, dim) 48 | self.proj_drop = nn.Dropout(proj_drop) 49 | 50 | self.sr_ratio = sr_ratio 51 | if sr_ratio > 1: 52 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 53 | self.norm = nn.LayerNorm(dim) 54 | 55 | def forward(self, x, H, W): 56 | B, N, C = x.shape 57 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 58 | 59 | if self.sr_ratio > 1: 60 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 61 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 62 | x_ = self.norm(x_) 63 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 64 | else: 65 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 66 | k, v = kv[0], kv[1] 67 | 68 | attn = (q @ k.transpose(-2, -1)) * self.scale 69 | attn = attn.softmax(dim=-1) 70 | attn = self.attn_drop(attn) 71 | 72 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 73 | x = self.proj(x) 74 | x = self.proj_drop(x) 75 | 76 | return x 77 | 78 | 79 | class Block(nn.Module): 80 | 81 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 82 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 83 | super().__init__() 84 | self.norm1 = norm_layer(dim) 85 | self.attn = Attention( 86 | dim, 87 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 88 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 89 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 90 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 91 | self.norm2 = norm_layer(dim) 92 | mlp_hidden_dim = int(dim * mlp_ratio) 93 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 94 | 95 | def forward(self, x, H, W): 96 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 97 | x = x + self.drop_path(self.mlp(self.norm2(x))) 98 | 99 | return x 100 | 101 | 102 | class PatchEmbed(nn.Module): 103 | """ Image to Patch Embedding 104 | """ 105 | 106 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 107 | super().__init__() 108 | img_size = to_2tuple(img_size) 109 | patch_size = to_2tuple(patch_size) 110 | 111 | self.img_size = img_size 112 | self.patch_size = patch_size 113 | # assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ 114 | # f"img_size {img_size} should be divided by patch_size {patch_size}." 115 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 116 | self.num_patches = self.H * self.W 117 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 118 | self.norm = nn.LayerNorm(embed_dim) 119 | 120 | def forward(self, x): 121 | B, C, H, W = x.shape 122 | 123 | x = self.proj(x).flatten(2).transpose(1, 2) 124 | x = self.norm(x) 125 | H, W = H // self.patch_size[0], W // self.patch_size[1] 126 | 127 | return x, (H, W) 128 | 129 | 130 | class PyramidVisionTransformer(nn.Module): 131 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 132 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 133 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 134 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4): 135 | super().__init__() 136 | self.num_classes = num_classes 137 | self.depths = depths 138 | self.num_stages = num_stages 139 | 140 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 141 | cur = 0 142 | 143 | for i in range(num_stages): 144 | patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 145 | patch_size=patch_size if i == 0 else 2, 146 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 147 | embed_dim=embed_dims[i]) 148 | num_patches = patch_embed.num_patches if i != num_stages - 1 else patch_embed.num_patches + 1 149 | pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i])) 150 | pos_drop = nn.Dropout(p=drop_rate) 151 | 152 | block = nn.ModuleList([Block( 153 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, 154 | qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], 155 | norm_layer=norm_layer, sr_ratio=sr_ratios[i]) 156 | for j in range(depths[i])]) 157 | cur += depths[i] 158 | 159 | setattr(self, f"patch_embed{i + 1}", patch_embed) 160 | setattr(self, f"pos_embed{i + 1}", pos_embed) 161 | setattr(self, f"pos_drop{i + 1}", pos_drop) 162 | setattr(self, f"block{i + 1}", block) 163 | 164 | self.norm = norm_layer(embed_dims[3]) 165 | 166 | # cls_token 167 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[3])) 168 | 169 | # classification head 170 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 171 | 172 | # init weights 173 | for i in range(num_stages): 174 | pos_embed = getattr(self, f"pos_embed{i + 1}") 175 | trunc_normal_(pos_embed, std=.02) 176 | trunc_normal_(self.cls_token, std=.02) 177 | 
self.apply(self._init_weights) 178 | 179 | 180 | def _init_weights(self, m): 181 | if isinstance(m, nn.Linear): 182 | trunc_normal_(m.weight, std=.02) 183 | if isinstance(m, nn.Linear) and m.bias is not None: 184 | nn.init.constant_(m.bias, 0) 185 | elif isinstance(m, nn.LayerNorm): 186 | nn.init.constant_(m.bias, 0) 187 | nn.init.constant_(m.weight, 1.0) 188 | 189 | @torch.jit.ignore 190 | def no_weight_decay(self): 191 | # return {'pos_embed', 'cls_token'} # has pos_embed may be better 192 | return {'cls_token'} 193 | 194 | def get_classifier(self): 195 | return self.head 196 | 197 | def reset_classifier(self, num_classes, global_pool=''): 198 | self.num_classes = num_classes 199 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 200 | 201 | def _get_pos_embed(self, pos_embed, patch_embed, H, W): 202 | if H * W == self.patch_embed1.num_patches: 203 | return pos_embed 204 | else: 205 | return F.interpolate( 206 | pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 207 | size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) 208 | 209 | def forward_features(self, x): 210 | B = x.shape[0] 211 | 212 | for i in range(self.num_stages): 213 | patch_embed = getattr(self, f"patch_embed{i + 1}") 214 | pos_embed = getattr(self, f"pos_embed{i + 1}") 215 | pos_drop = getattr(self, f"pos_drop{i + 1}") 216 | block = getattr(self, f"block{i + 1}") 217 | x, (H, W) = patch_embed(x) 218 | 219 | if i == self.num_stages - 1: 220 | cls_tokens = self.cls_token.expand(B, -1, -1) 221 | x = torch.cat((cls_tokens, x), dim=1) 222 | pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W) 223 | pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1) 224 | else: 225 | pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W) 226 | 227 | x = pos_drop(x + pos_embed) 228 | for blk in block: 229 | x = blk(x, H, W) 230 | if i != self.num_stages - 1: 231 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 232 | 233 | x = self.norm(x) 234 | 235 | return x[:, 0] 236 | 237 | def forward(self, x): 238 | x = self.forward_features(x) 239 | x = self.head(x) 240 | 241 | return x 242 | 243 | 244 | def _conv_filter(state_dict, patch_size=16): 245 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 246 | out_dict = {} 247 | for k, v in state_dict.items(): 248 | if 'patch_embed.proj.weight' in k: 249 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 250 | out_dict[k] = v 251 | 252 | return out_dict 253 | 254 | 255 | @register_model 256 | def pvt_tiny(pretrained=False, **kwargs): 257 | model = PyramidVisionTransformer( 258 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 259 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 260 | **kwargs) 261 | model.default_cfg = _cfg() 262 | 263 | return model 264 | 265 | 266 | @register_model 267 | def pvt_small(pretrained=False, **kwargs): 268 | model = PyramidVisionTransformer( 269 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 270 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) 271 | model.default_cfg = _cfg() 272 | 273 | return model 274 | 275 | 276 | @register_model 277 | def pvt_medium(pretrained=False, **kwargs): 278 | model = PyramidVisionTransformer( 279 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], 
mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 280 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 281 | **kwargs) 282 | model.default_cfg = _cfg() 283 | 284 | return model 285 | 286 | 287 | @register_model 288 | def pvt_large(pretrained=False, **kwargs): 289 | model = PyramidVisionTransformer( 290 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 291 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 292 | **kwargs) 293 | model.default_cfg = _cfg() 294 | 295 | return model 296 | 297 | 298 | @register_model 299 | def pvt_huge_v2(pretrained=False, **kwargs): 300 | model = PyramidVisionTransformer( 301 | patch_size=4, embed_dims=[128, 256, 512, 768], num_heads=[2, 4, 8, 12], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 302 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 10, 60, 3], sr_ratios=[8, 4, 2, 1], 303 | # drop_rate=0.0, drop_path_rate=0.02) 304 | **kwargs) 305 | model.default_cfg = _cfg() 306 | 307 | return model 308 | -------------------------------------------------------------------------------- /classification/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import os 4 | # os.environ['MKL_THREADING_LAYER'] = 'GNU' 5 | # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 6 | # os.environ['NPY_MKL_FORCE_INTEL'] = '1' 7 | 8 | import argparse 9 | import datetime 10 | import time 11 | import torch 12 | import numpy as np 13 | import torch.backends.cudnn as cudnn 14 | import json 15 | from pathlib import Path 16 | import collections 17 | import sys 18 | import pdb 19 | import inspect 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | from timm.data import Mixup 24 | from timm.models import create_model 25 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 26 | from timm.scheduler import create_scheduler 27 | from timm.optim import create_optimizer 28 | from timm.utils import NativeScaler, get_state_dict, ModelEma 29 | 30 | from datasets import build_dataset 31 | from engine import train_one_epoch, evaluate 32 | from losses import DistillationLoss 33 | from samplers import RASampler 34 | import pvt 35 | import pvt_v2 36 | import vvt 37 | import utils 38 | 39 | class ForkedPdb(pdb.Pdb): 40 | """A Pdb subclass that may be used 41 | from a forked multiprocessing child 42 | 43 | """ 44 | def interaction(self, *args, **kwargs): 45 | _stdin = sys.stdin 46 | try: 47 | sys.stdin = open('/dev/stdin') 48 | pdb.Pdb.interaction(self, *args, **kwargs) 49 | finally: 50 | sys.stdin = _stdin 51 | 52 | def get_args_parser(): 53 | parser = argparse.ArgumentParser('PVT training and evaluation script', add_help=False) 54 | parser.add_argument('--fp32-resume', action='store_true', default=False) 55 | parser.add_argument('--batch-size', default=80, type=int) 56 | parser.add_argument('--epochs', default=300, type=int) 57 | parser.add_argument('--config', required=True, type=str, help='config') 58 | 59 | # Model parameters 60 | parser.add_argument('--model', default='pvt_small', type=str, metavar='MODEL', 61 | help='Name of model to train') 62 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 63 | 64 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 65 | help='Dropout rate (default: 0.)') 66 | parser.add_argument('--drop-path', type=float, default=0.1, 
metavar='PCT', 67 | help='Drop path rate (default: 0.1)') 68 | 69 | # parser.add_argument('--model-ema', action='store_true') 70 | # parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 71 | # parser.set_defaults(model_ema=True) 72 | # parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 73 | # parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 74 | 75 | # Optimizer parameters 76 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 77 | help='Optimizer (default: "adamw"') 78 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 79 | help='Optimizer Epsilon (default: 1e-8)') 80 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 81 | help='Optimizer Betas (default: None, use opt default)') 82 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 83 | help='Clip gradient norm (default: None, no clipping)') 84 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 85 | help='SGD momentum (default: 0.9)') 86 | parser.add_argument('--weight-decay', type=float, default=0.05, 87 | help='weight decay (default: 0.05)') 88 | # Learning rate schedule parameters 89 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 90 | help='LR scheduler (default: "cosine"') 91 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 92 | help='learning rate (default: 5e-4)') 93 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 94 | help='learning rate noise on/off epoch percentages') 95 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 96 | help='learning rate noise limit percent (default: 0.67)') 97 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 98 | help='learning rate noise std-dev (default: 1.0)') 99 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 100 | help='warmup learning rate (default: 1e-6)') 101 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 102 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 103 | 104 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 105 | help='epoch interval to decay LR') 106 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 107 | help='epochs to warmup LR, if scheduler supports') 108 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 109 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 110 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 111 | help='patience epochs for Plateau LR scheduler (default: 10') 112 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 113 | help='LR decay rate (default: 0.1)') 114 | 115 | # Augmentation parameters 116 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 117 | help='Color jitter factor (default: 0.4)') 118 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 119 | help='Use AutoAugment policy. "v0" or "original". 
" + \ 120 | "(default: rand-m9-mstd0.5-inc1)'), 121 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 122 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 123 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 124 | 125 | parser.add_argument('--repeated-aug', action='store_true') 126 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 127 | parser.set_defaults(repeated_aug=True) 128 | 129 | # * Random Erase params 130 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 131 | help='Random erase prob (default: 0.25)') 132 | parser.add_argument('--remode', type=str, default='pixel', 133 | help='Random erase mode (default: "pixel")') 134 | parser.add_argument('--recount', type=int, default=1, 135 | help='Random erase count (default: 1)') 136 | parser.add_argument('--resplit', action='store_true', default=False, 137 | help='Do not random erase first (clean) augmentation split') 138 | 139 | # * Mixup params 140 | parser.add_argument('--mixup', type=float, default=0.8, 141 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 142 | parser.add_argument('--cutmix', type=float, default=1.0, 143 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 144 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 145 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 146 | parser.add_argument('--mixup-prob', type=float, default=1.0, 147 | help='Probability of performing mixup or cutmix when either/both is enabled') 148 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 149 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 150 | parser.add_argument('--mixup-mode', type=str, default='batch', 151 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 152 | 153 | # * Finetuning params 154 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 155 | 156 | # Dataset parameters 157 | parser.add_argument('--data-path', default='./images/', type=str, 158 | help='dataset path') 159 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], 160 | type=str, help='Image Net dataset path') 161 | parser.add_argument('--use-mcloader', action='store_true', default=True, help='Use mcloader') 162 | parser.add_argument('--inat-category', default='name', 163 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 164 | type=str, help='semantic granularity') 165 | 166 | parser.add_argument('--output_dir', default='', 167 | help='path where to save, empty for no saving') 168 | parser.add_argument('--device', default='cuda', 169 | help='device to use for training / testing') 170 | parser.add_argument('--seed', default=3407, type=int) 171 | parser.add_argument('--resume', default='', help='resume from checkpoint') 172 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 173 | help='start epoch') 174 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 175 | 176 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') 177 | parser.add_argument('--num_workers', default=10, type=int) 178 | parser.add_argument('--pin-mem', action='store_true', 179 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 180 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 181 | help='') 182 | parser.set_defaults(pin_mem=True) 183 | 184 | # distributed training parameters 185 | parser.add_argument('--world_size', default=1, type=int, 186 | help='number of distributed processes') 187 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 188 | parser.add_argument('--port', default=-1, type=int, help='distributed port') 189 | parser.add_argument('--distributed', default=False, type=bool) 190 | 191 | return parser 192 | 193 | def main(args): 194 | utils.init_distributed_mode(args) 195 | 196 | device = torch.device(args.device) 197 | 198 | # fix the seed for reproducibility 199 | seed = args.seed + utils.get_rank() 200 | torch.manual_seed(seed) 201 | np.random.seed(seed) 202 | 203 | cudnn.benchmark = True 204 | 205 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 206 | dataset_val, _ = build_dataset(is_train=False, args=args) 207 | 208 | if True: # args.distributed: 209 | num_tasks = utils.get_world_size() 210 | print('world size:', num_tasks) 211 | global_rank = utils.get_rank() 212 | if args.repeated_aug: 213 | sampler_train = RASampler( 214 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 215 | ) 216 | else: 217 | sampler_train = torch.utils.data.DistributedSampler( 218 | dataset_train, 219 | # num_replicas=num_tasks, 220 | num_replicas=0, 221 | rank=global_rank, shuffle=True 222 | ) 223 | if args.dist_eval: 224 | if len(dataset_val) % num_tasks != 0: 225 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 226 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 227 | 'equal num of samples per-process.') 228 | sampler_val = torch.utils.data.DistributedSampler( 229 | dataset_val, 230 | num_replicas=num_tasks, 231 | # num_replicas=0, 232 | rank=global_rank, shuffle=False) 233 | else: 234 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 235 | else: 236 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 237 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 238 | 239 | data_loader_train = torch.utils.data.DataLoader( 240 | dataset_train, sampler=sampler_train, 241 | batch_size=args.batch_size, 242 | num_workers=args.num_workers, 243 | pin_memory=args.pin_mem, 244 | drop_last=True, 245 | ) 246 | 247 | data_loader_val = torch.utils.data.DataLoader( 248 | dataset_val, sampler=sampler_val, 249 | batch_size=int(1 * args.batch_size), 250 | num_workers=args.num_workers, 251 | pin_memory=args.pin_mem, 252 | drop_last=False 253 | ) 254 | 255 | mixup_fn = None 256 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 257 | if mixup_active: 258 | mixup_fn = Mixup( 259 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 260 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 261 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 262 | 263 | print(f"Creating model: {args.model}") 264 | model = create_model( 265 | args.model, 266 | pretrained=False, 267 | num_classes=args.nb_classes, 268 | drop_rate=args.drop, 269 | drop_path_rate=args.drop_path, 270 | drop_block_rate=None, 271 | ) 272 | 273 | 274 | if args.finetune: 275 | if args.finetune.startswith('https'): 276 | checkpoint = torch.hub.load_state_dict_from_url( 277 | args.finetune, map_location='cpu', check_hash=True) 278 | else: 279 | checkpoint = torch.load(args.finetune, map_location='cpu') 280 | 281 | if 'model' in checkpoint: 282 | checkpoint_model = checkpoint['model'] 283 | else: 284 | checkpoint_model = checkpoint 285 | state_dict = model.state_dict() 286 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 287 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 288 | print(f"Removing key {k} from pretrained checkpoint") 289 | del checkpoint_model[k] 290 | 291 | model.load_state_dict(checkpoint_model, strict=False) 292 | 293 | model.to(device) 294 | 295 | model_ema = None 296 | # if args.model_ema: 297 | # # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 298 | # model_ema = ModelEma( 299 | # model, 300 | # decay=args.model_ema_decay, 301 | # device='cpu' if args.model_ema_force_cpu else '', 302 | # resume='') 303 | 304 | model_without_ddp = model 305 | if args.distributed: 306 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 307 | model_without_ddp = model.module 308 | else: 309 | model = torch.nn.DataParallel(model) 310 | 311 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 312 | print('number of params:', n_parameters) 313 | 314 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 315 | args.lr = linear_scaled_lr 316 | optimizer = create_optimizer(args, model_without_ddp) 317 | loss_scaler = NativeScaler() 318 | lr_scheduler, _ = create_scheduler(args, optimizer) 319 | 320 | criterion = LabelSmoothingCrossEntropy() 321 | 322 | if args.mixup > 0.: 323 | # 
smoothing is handled with mixup label transform 324 | criterion = SoftTargetCrossEntropy() 325 | elif args.smoothing: 326 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 327 | else: 328 | criterion = torch.nn.CrossEntropyLoss() 329 | 330 | # wrap the criterion in our custom DistillationLoss, which 331 | # just dispatches to the original criterion if args.distillation_type is 'none' 332 | # criterion = DistillationLoss( 333 | # criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 334 | # ) 335 | criterion = DistillationLoss( 336 | criterion, None, 'none', 0, 0 337 | ) 338 | 339 | output_dir = Path(args.output_dir) 340 | if args.resume: 341 | if args.resume.startswith('https'): 342 | checkpoint = torch.hub.load_state_dict_from_url( 343 | args.resume, map_location='cpu', check_hash=True) 344 | else: 345 | checkpoint = torch.load(args.resume, map_location='cpu') 346 | print('loaded checkpoints!') 347 | if 'model' in checkpoint: 348 | msg = model_without_ddp.load_state_dict(checkpoint['model']) 349 | else: 350 | msg = model_without_ddp.load_state_dict(checkpoint) 351 | print(msg) 352 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 353 | optimizer.load_state_dict(checkpoint['optimizer']) 354 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 355 | args.start_epoch = checkpoint['epoch'] + 1 356 | # if args.model_ema: 357 | # utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 358 | if 'scaler' in checkpoint: 359 | loss_scaler.load_state_dict(checkpoint['scaler']) 360 | 361 | if args.eval: 362 | test_stats = evaluate(data_loader_val, model, device) 363 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 364 | return 365 | 366 | print(f"Start training for {args.epochs} epochs") 367 | start_time = time.time() 368 | max_accuracy = 0.0 369 | 370 | for epoch in range(args.start_epoch, args.epochs): 371 | if args.fp32_resume and epoch > args.start_epoch + 1: 372 | args.fp32_resume = False 373 | 374 | # -------------------fp16 or not--- 375 | # args.fp32_resume = True 376 | #--------------------------- 377 | print('fp 32 or not:', args.fp32_resume) 378 | 379 | loss_scaler._scaler = torch.cuda.amp.GradScaler(enabled=not args.fp32_resume) 380 | 381 | if args.distributed: 382 | data_loader_train.sampler.set_epoch(epoch) 383 | 384 | train_stats = train_one_epoch( 385 | model, criterion, data_loader_train, 386 | optimizer, device, epoch, loss_scaler, 387 | args.clip_grad, model_ema, mixup_fn, 388 | set_training_mode=args.finetune == '', # keep in eval mode during finetuning 389 | fp32=args.fp32_resume, 390 | distributed=args.distributed, 391 | ) 392 | 393 | lr_scheduler.step(epoch) 394 | if args.output_dir and (epoch+1)%20 == 0: 395 | checkpoint_paths = [output_dir / 'checkpoint_{}.pth'.format(epoch)] 396 | for checkpoint_path in checkpoint_paths: 397 | utils.save_on_master({ 398 | 'model': model_without_ddp.state_dict(), 399 | 'optimizer': optimizer.state_dict(), 400 | 'lr_scheduler': lr_scheduler.state_dict(), 401 | 'epoch': epoch, 402 | # 'model_ema': get_state_dict(model_ema), 403 | 'scaler': loss_scaler.state_dict(), 404 | 'args': args, 405 | }, checkpoint_path) 406 | 407 | if args.output_dir: 408 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 409 | for checkpoint_path in checkpoint_paths: 410 | utils.save_on_master({ 411 | 'model': model_without_ddp.state_dict(), 412 | 'optimizer': optimizer.state_dict(), 
413 | 'lr_scheduler': lr_scheduler.state_dict(), 414 | 'epoch': epoch, 415 | # 'model_ema': get_state_dict(model_ema), 416 | 'scaler': loss_scaler.state_dict(), 417 | 'args': args, 418 | }, checkpoint_path) 419 | 420 | 421 | test_stats = evaluate(data_loader_val, model, device) 422 | 423 | # save best model 424 | if args.output_dir and test_stats["acc1"] > max_accuracy: 425 | checkpoint_paths = [output_dir / 'checkpoint_best.pth'] 426 | for checkpoint_path in checkpoint_paths: 427 | utils.save_on_master({ 428 | 'model': model_without_ddp.state_dict(), 429 | 'optimizer': optimizer.state_dict(), 430 | 'lr_scheduler': lr_scheduler.state_dict(), 431 | 'epoch': epoch, 432 | # 'model_ema': get_state_dict(model_ema), 433 | 'scaler': loss_scaler.state_dict(), 434 | 'args': args, 435 | }, checkpoint_path) 436 | 437 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 438 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 439 | print(f'Max accuracy: {max_accuracy:.2f}%') 440 | 441 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 442 | **{f'test_{k}': v for k, v in test_stats.items()}, 443 | 'epoch': epoch, 444 | 'n_parameters': n_parameters} 445 | 446 | if args.output_dir and utils.is_main_process(): 447 | with (output_dir / "log.txt").open("a") as f: 448 | f.write(json.dumps(log_stats) + "\n") 449 | torch.distributed.barrier() 450 | 451 | 452 | 453 | total_time = time.time() - start_time 454 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 455 | print('Training time {}'.format(total_time_str)) 456 | 457 | 458 | if __name__ == '__main__': 459 | parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()]) 460 | args = parser.parse_args() 461 | args = utils.update_from_config(args) 462 | 463 | os.environ['OMP_NUM_THREADS'] = '16' 464 | os.environ['OMP_DYNAMIC'] = 'False' 465 | 466 | if args.output_dir: 467 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 468 | main(args) 469 | -------------------------------------------------------------------------------- /classification/pvt_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | import math 10 | 11 | 12 | def compute_rollout_attention(all_layer_matrices, start_layer=0): 13 | # adding residual consideration 14 | num_tokens = all_layer_matrices[0].shape[1] 15 | batch_size = all_layer_matrices[0].shape[0] 16 | 17 | eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device) 18 | all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))] 19 | # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True) 20 | # for i in range(len(all_layer_matrices))] 21 | joint_attention = all_layer_matrices[start_layer] 22 | for i in range(start_layer+1, len(all_layer_matrices)): 23 | joint_attention = all_layer_matrices[i].bmm(joint_attention) 24 | return joint_attention 25 | 26 | class Mlp(nn.Module): 27 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): 28 | super().__init__() 29 | out_features = out_features or in_features 30 | 
hidden_features = hidden_features or in_features 31 | self.fc1 = nn.Linear(in_features, hidden_features) 32 | self.dwconv = DWConv(hidden_features) 33 | self.act = act_layer() 34 | self.fc2 = nn.Linear(hidden_features, out_features) 35 | self.drop = nn.Dropout(drop) 36 | self.linear = linear 37 | 38 | if self.linear: 39 | self.relu = nn.ReLU(inplace=True) 40 | self.apply(self._init_weights) 41 | 42 | def _init_weights(self, m): 43 | if isinstance(m, nn.Linear): 44 | trunc_normal_(m.weight, std=.02) 45 | if isinstance(m, nn.Linear) and m.bias is not None: 46 | nn.init.constant_(m.bias, 0) 47 | elif isinstance(m, nn.LayerNorm): 48 | nn.init.constant_(m.bias, 0) 49 | nn.init.constant_(m.weight, 1.0) 50 | elif isinstance(m, nn.Conv2d): 51 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 52 | fan_out //= m.groups 53 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 54 | if m.bias is not None: 55 | m.bias.data.zero_() 56 | 57 | def forward(self, x, H, W): 58 | x = self.fc1(x) 59 | if self.linear: 60 | x = self.relu(x) 61 | x = self.dwconv(x, H, W) 62 | x = self.act(x) 63 | x = self.drop(x) 64 | x = self.fc2(x) 65 | x = self.drop(x) 66 | return x 67 | 68 | 69 | class Attention(nn.Module): 70 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False): 71 | super().__init__() 72 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 73 | 74 | self.dim = dim 75 | self.num_heads = num_heads 76 | head_dim = dim // num_heads 77 | self.scale = qk_scale or head_dim ** -0.5 78 | 79 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 80 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 81 | self.attn_drop = nn.Dropout(attn_drop) 82 | self.proj = nn.Linear(dim, dim) 83 | self.proj_drop = nn.Dropout(proj_drop) 84 | 85 | self.linear = linear 86 | self.sr_ratio = sr_ratio 87 | print('linear:', linear, 'sr_ratio:', sr_ratio) 88 | 89 | if not linear: 90 | if sr_ratio > 1: 91 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 92 | self.norm = nn.LayerNorm(dim) 93 | else: 94 | self.pool = nn.AdaptiveAvgPool2d(7) 95 | self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) 96 | self.norm = nn.LayerNorm(dim) 97 | self.act = nn.GELU() 98 | self.apply(self._init_weights) 99 | 100 | def get_attn(self): 101 | return self.attn_map 102 | 103 | def save_attn(self, attn): 104 | self.attn_map = attn 105 | 106 | def save_attn_gradients(self, attn_gradients): 107 | self.attn_map_gradients = attn_gradients 108 | 109 | def get_attn_gradients(self): 110 | return self.attn_map_gradients 111 | 112 | def _init_weights(self, m): 113 | if isinstance(m, nn.Linear): 114 | trunc_normal_(m.weight, std=.02) 115 | if isinstance(m, nn.Linear) and m.bias is not None: 116 | nn.init.constant_(m.bias, 0) 117 | elif isinstance(m, nn.LayerNorm): 118 | nn.init.constant_(m.bias, 0) 119 | nn.init.constant_(m.weight, 1.0) 120 | elif isinstance(m, nn.Conv2d): 121 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 122 | fan_out //= m.groups 123 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | 127 | def forward(self, x, H, W): 128 | B, N, C = x.shape 129 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B, h, N, C/h 130 | # print(q.shape) 131 | 132 | if not self.linear: 133 | if self.sr_ratio > 1: 134 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 135 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 136 | 
x_ = self.norm(x_) 137 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 138 | else: 139 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 140 | else: 141 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 142 | x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) 143 | x_ = self.norm(x_) 144 | x_ = self.act(x_) 145 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 146 | k, v = kv[0], kv[1] 147 | # print('1', q.shape, k.shape, v.shape) 148 | 149 | attn = (q @ k.transpose(-2, -1)) * self.scale 150 | # print('2', attn.shape) 151 | attn = attn.softmax(dim=-1) 152 | # print('3', attn.shape) 153 | attn = self.attn_drop(attn) 154 | # print('4', attn.shape) 155 | # print(attn.shape) 156 | 157 | # save the attention map for visualization 158 | if x.requires_grad: 159 | self.save_attn(attn) 160 | attn.register_hook(self.save_attn_gradients) 161 | 162 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 163 | # print('5', x.shape) 164 | x = self.proj(x) 165 | # print('6', x.shape) 166 | x = self.proj_drop(x) 167 | # print('7', x.shape) 168 | 169 | return x 170 | 171 | 172 | class AttentionV2(nn.Module): 173 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False): 174 | super().__init__() 175 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 176 | 177 | self.dim = dim 178 | self.num_heads = num_heads 179 | head_dim = dim // num_heads 180 | self.scale = qk_scale or head_dim ** -0.5 181 | 182 | self.fr_ratio = 1 183 | self.q = nn.Linear(dim, dim//self.fr_ratio, bias=qkv_bias) 184 | # self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 185 | self.k = nn.Linear(dim, dim//self.fr_ratio, bias=qkv_bias) 186 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 187 | self.attn_drop = nn.Dropout(attn_drop) 188 | self.proj = nn.Linear(dim, dim) 189 | self.proj_drop = nn.Dropout(proj_drop) 190 | 191 | self.linear = linear 192 | self.sr_ratio = sr_ratio 193 | print('linear:', linear, 'sr_ratio:', sr_ratio) 194 | 195 | if not linear: 196 | if sr_ratio > 1: 197 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 198 | self.norm = nn.LayerNorm(dim) 199 | else: 200 | self.pool = nn.AdaptiveAvgPool2d(7) 201 | self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) 202 | self.norm = nn.LayerNorm(dim) 203 | self.act = nn.GELU() 204 | self.apply(self._init_weights) 205 | 206 | def _init_weights(self, m): 207 | if isinstance(m, nn.Linear): 208 | trunc_normal_(m.weight, std=.02) 209 | if isinstance(m, nn.Linear) and m.bias is not None: 210 | nn.init.constant_(m.bias, 0) 211 | elif isinstance(m, nn.LayerNorm): 212 | nn.init.constant_(m.bias, 0) 213 | nn.init.constant_(m.weight, 1.0) 214 | elif isinstance(m, nn.Conv2d): 215 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 216 | fan_out //= m.groups 217 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 218 | if m.bias is not None: 219 | m.bias.data.zero_() 220 | 221 | def get_attn(self): 222 | return self.attn_map 223 | 224 | def save_attn(self, attn): 225 | self.attn_map = attn 226 | 227 | def save_attn_gradients(self, attn_gradients): 228 | self.attn_map_gradients = attn_gradients 229 | 230 | def get_attn_gradients(self): 231 | return self.attn_map_gradients 232 | 233 | def forward(self, x, H, W): 234 | B, N, C = x.shape 235 | q = self.q(x).reshape(B, N, self.num_heads, (C//self.fr_ratio) //
self.num_heads).permute(0, 2, 1, 3) # B, h, N, C/h 236 | 237 | if not self.linear: 238 | if self.sr_ratio > 1: 239 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 240 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 241 | x_ = self.norm(x_) 242 | k = self.k(x_).reshape(B, -1,self.num_heads, (C//self.fr_ratio) // self.num_heads).permute(0,2,1,3) 243 | v = self.v(x_).reshape(B, -1,self.num_heads, C // self.num_heads).permute(0,2,1,3) 244 | # kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 245 | else: 246 | k = self.k(x).reshape(B, -1,self.num_heads, (C//self.fr_ratio) // self.num_heads).permute(0,2,1,3) 247 | v = self.v(x).reshape(B, -1,self.num_heads, C // self.num_heads).permute(0,2,1,3) 248 | # kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 249 | else: 250 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 251 | x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) 252 | x_ = self.norm(x_) 253 | x_ = self.act(x_) 254 | k = self.k(x_).reshape(B, -1,self.num_heads, (C//self.fr_ratio) // self.num_heads).permute(0,2,1,3) 255 | v = self.v(x_).reshape(B, -1,self.num_heads, C // self.num_heads).permute(0,2,1,3) 256 | # kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 257 | # k, v = kv[0], kv[1] 258 | # print('1', q.shape, k.shape, v.shape) 259 | 260 | attn = (q @ k.transpose(-2, -1)) * self.scale 261 | # print('2', attn.shape) 262 | attn = attn.softmax(dim=-1) 263 | # print('3', attn.shape) 264 | attn = self.attn_drop(attn) 265 | # print('4', attn.shape) 266 | 267 | # save the attention map for visualization 268 | if x.requires_grad: 269 | self.save_attn(attn) 270 | attn.register_hook(self.save_attn_gradients) 271 | 272 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 273 | # print('5', x.shape) 274 | x = self.proj(x) 275 | # print('6', x.shape) 276 | x = self.proj_drop(x) 277 | # print('7', x.shape) 278 | 279 | return x 280 | 281 | 282 | class Block(nn.Module): 283 | 284 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 285 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): 286 | super().__init__() 287 | self.norm1 = norm_layer(dim) 288 | self.attn = Attention( 289 | dim, 290 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 291 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) 292 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 293 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 294 | self.norm2 = norm_layer(dim) 295 | mlp_hidden_dim = int(dim * mlp_ratio) 296 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) 297 | 298 | self.apply(self._init_weights) 299 | 300 | def _init_weights(self, m): 301 | if isinstance(m, nn.Linear): 302 | trunc_normal_(m.weight, std=.02) 303 | if isinstance(m, nn.Linear) and m.bias is not None: 304 | nn.init.constant_(m.bias, 0) 305 | elif isinstance(m, nn.LayerNorm): 306 | nn.init.constant_(m.bias, 0) 307 | nn.init.constant_(m.weight, 1.0) 308 | elif isinstance(m, nn.Conv2d): 309 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 310 | fan_out //= m.groups 311 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 312 | if m.bias is not None: 313 | m.bias.data.zero_() 314 | 315 | def forward(self, x, H, W): 316 | # print(x.shape) 317 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 318 | # print(x.shape) 319 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 320 | # print(x.shape) 321 | 322 | return x 323 | 324 | 325 | class OverlapPatchEmbed(nn.Module): 326 | """ Image to Patch Embedding 327 | """ 328 | 329 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 330 | super().__init__() 331 | img_size = to_2tuple(img_size) 332 | patch_size = to_2tuple(patch_size) 333 | 334 | self.img_size = img_size 335 | self.patch_size = patch_size 336 | self.H, self.W = img_size[0] // stride, img_size[1] // stride 337 | self.num_patches = self.H * self.W 338 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 339 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 340 | self.norm = nn.LayerNorm(embed_dim) 341 | 342 | self.apply(self._init_weights) 343 | 344 | def _init_weights(self, m): 345 | if isinstance(m, nn.Linear): 346 | trunc_normal_(m.weight, std=.02) 347 | if isinstance(m, nn.Linear) and m.bias is not None: 348 | nn.init.constant_(m.bias, 0) 349 | elif isinstance(m, nn.LayerNorm): 350 | nn.init.constant_(m.bias, 0) 351 | nn.init.constant_(m.weight, 1.0) 352 | elif isinstance(m, nn.Conv2d): 353 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 354 | fan_out //= m.groups 355 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 356 | if m.bias is not None: 357 | m.bias.data.zero_() 358 | 359 | def forward(self, x): 360 | x = self.proj(x) 361 | _, _, H, W = x.shape 362 | x = x.flatten(2).transpose(1, 2) 363 | x = self.norm(x) 364 | 365 | return x, H, W 366 | 367 | 368 | class PyramidVisionTransformerV2(nn.Module): 369 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 370 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 371 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 372 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, linear=False): 373 | super().__init__() 374 | self.num_classes = num_classes 375 | self.depths = depths 376 | self.num_stages = num_stages 377 | 378 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 379 | cur = 0 380 | 381 | 382 | for i in range(num_stages): 383 | patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 384 | patch_size=7 if i == 0 else 3, 385 | stride=4 if i == 0 else 2, 386 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 387 | embed_dim=embed_dims[i]) 388 | 389 | 390 | # # print(img_size if i == 0 
else img_size // (2 ** (i + 1))) 391 | 392 | if i == 0: 393 | seq_len = (img_size//4)**2 394 | else: 395 | seq_len = ((img_size // (2 ** (i + 1))) // 2)**2 396 | 397 | print(seq_len) 398 | 399 | block = nn.ModuleList([Block( 400 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, qk_scale=qk_scale, 401 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, 402 | sr_ratio=sr_ratios[i], linear=linear) 403 | for j in range(depths[i])]) 404 | norm = norm_layer(embed_dims[i]) 405 | cur += depths[i] 406 | 407 | setattr(self, f"patch_embed{i + 1}", patch_embed) 408 | setattr(self, f"block{i + 1}", block) 409 | setattr(self, f"norm{i + 1}", norm) 410 | 411 | # classification head 412 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 413 | 414 | self.apply(self._init_weights) 415 | 416 | def _init_weights(self, m): 417 | if isinstance(m, nn.Linear): 418 | trunc_normal_(m.weight, std=.02) 419 | if isinstance(m, nn.Linear) and m.bias is not None: 420 | nn.init.constant_(m.bias, 0) 421 | elif isinstance(m, nn.LayerNorm): 422 | nn.init.constant_(m.bias, 0) 423 | nn.init.constant_(m.weight, 1.0) 424 | elif isinstance(m, nn.Conv2d): 425 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 426 | fan_out //= m.groups 427 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 428 | if m.bias is not None: 429 | m.bias.data.zero_() 430 | 431 | def freeze_patch_emb(self): 432 | self.patch_embed1.requires_grad = False 433 | 434 | def generate_cam(self, batch, start_layer=0): 435 | cams = [] 436 | attn_list = [] 437 | grad_list = [] 438 | 439 | for i in range(self.num_stages): 440 | block = getattr(self, f"block{i + 1}") 441 | for blk in block: 442 | grad = blk.attn.get_attn_gradients() 443 | cam = blk.attn.get_attn() 444 | # attn_list.append(torch.mean(cam, dim = 1)) 445 | # grad_list.append(torch.mean(grad, dim = 1)) 446 | cam = cam[batch].reshape(-1, cam.shape[-1], cam.shape[-1]) 447 | grad = grad[batch].reshape(-1, grad.shape[-1], grad.shape[-1]) 448 | cam = grad * cam 449 | cam = cam.clamp(min=0).mean(dim=0) 450 | cams.append(cam.unsqueeze(0)) 451 | 452 | rollout = compute_rollout_attention(cams, start_layer=start_layer) 453 | cam = rollout[:, 0, 1:] 454 | return cam 455 | 456 | def grad_cam(self, batch, start_layer=0): 457 | cams = [] 458 | 459 | for i in range(self.num_stages): 460 | block = getattr(self, f"block{i + 1}") 461 | for blk in block: 462 | grad = blk.attn.get_attn_gradients() 463 | cam = blk.attn.get_attn() 464 | cam = cam[batch].reshape(-1, cam.shape[-1], cam.shape[-1]) 465 | grad = grad[batch].reshape(-1, grad.shape[-1], grad.shape[-1]) 466 | # grad = grad.mean(dim=[1, 2], keepdim=True) 467 | cam = grad * cam 468 | # print(cam.shape) 469 | cam = cam.clamp(min=0).mean(dim=0) 470 | cams.append(cam.unsqueeze(0)) 471 | 472 | # rollout = compute_rollout_attention(cams, start_layer=start_layer) 473 | # cam = rollout[:, 0, 1:] 474 | print(cams[-1].shape) 475 | 476 | return cams[-1][0, 24,:] 477 | 478 | @torch.jit.ignore 479 | def no_weight_decay(self): 480 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 481 | 482 | def get_classifier(self): 483 | return self.head 484 | 485 | def reset_classifier(self, num_classes, global_pool=''): 486 | self.num_classes = num_classes 487 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 488 | 489 | def forward_features(self, x): 490 | B = 
x.shape[0] 491 | 492 | for i in range(self.num_stages): 493 | patch_embed = getattr(self, f"patch_embed{i + 1}") 494 | block = getattr(self, f"block{i + 1}") 495 | norm = getattr(self, f"norm{i + 1}") 496 | x, H, W = patch_embed(x) 497 | for blk in block: 498 | x = blk(x, H, W) 499 | x = norm(x) 500 | if i != self.num_stages - 1: 501 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 502 | 503 | return x.mean(dim=1) 504 | 505 | def forward(self, x): 506 | x = self.forward_features(x) 507 | x = self.head(x) 508 | 509 | return x 510 | 511 | 512 | class DWConv(nn.Module): 513 | def __init__(self, dim=768): 514 | super(DWConv, self).__init__() 515 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 516 | 517 | def forward(self, x, H, W): 518 | B, N, C = x.shape 519 | x = x.transpose(1, 2).view(B, C, H, W) 520 | x = self.dwconv(x) 521 | x = x.flatten(2).transpose(1, 2) 522 | 523 | return x 524 | 525 | 526 | def _conv_filter(state_dict, patch_size=16): 527 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 528 | out_dict = {} 529 | for k, v in state_dict.items(): 530 | if 'patch_embed.proj.weight' in k: 531 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 532 | out_dict[k] = v 533 | 534 | return out_dict 535 | 536 | 537 | @register_model 538 | def pvt_v2_b0(pretrained=False, **kwargs): 539 | model = PyramidVisionTransformerV2( 540 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 541 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 542 | **kwargs) 543 | model.default_cfg = _cfg() 544 | 545 | return model 546 | 547 | 548 | @register_model 549 | def pvt_v2_b1(pretrained=False, **kwargs): 550 | model = PyramidVisionTransformerV2( 551 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 552 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 553 | **kwargs) 554 | model.default_cfg = _cfg() 555 | 556 | return model 557 | 558 | 559 | 560 | 561 | @register_model 562 | def pvt_v2_b2(pretrained=False, **kwargs): 563 | model = PyramidVisionTransformerV2( 564 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 565 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) 566 | model.default_cfg = _cfg() 567 | 568 | return model 569 | 570 | 571 | @register_model 572 | def pvt_v2_b3(pretrained=False, **kwargs): 573 | model = PyramidVisionTransformerV2( 574 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 575 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 576 | **kwargs) 577 | model.default_cfg = _cfg() 578 | 579 | return model 580 | 581 | 582 | @register_model 583 | def pvt_v2_b4(pretrained=False, **kwargs): 584 | model = PyramidVisionTransformerV2( 585 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 586 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 587 | **kwargs) 588 | model.default_cfg = _cfg() 589 | 590 | return model 591 | 592 | 593 | @register_model 594 | def pvt_v2_b5(pretrained=False, **kwargs): 595 | model = PyramidVisionTransformerV2( 596 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], 
mlp_ratios=[4, 4, 4, 4], qkv_bias=True, 597 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3,6,40,3], sr_ratios=[8,4,2,1], 598 | **kwargs) 599 | model.default_cfg = _cfg() 600 | 601 | return model 602 | 603 | 604 | @register_model 605 | def pvt_v2_b2_li(pretrained=False, **kwargs): 606 | model = PyramidVisionTransformerV2( 607 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 608 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], linear=True, **kwargs) 609 | model.default_cfg = _cfg() 610 | 611 | return model 612 | -------------------------------------------------------------------------------- /classification/vvt.py: -------------------------------------------------------------------------------- 1 | from sre_constants import AT_NON_BOUNDARY 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import partial 6 | 7 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 8 | from timm.models.registry import register_model 9 | from timm.models.vision_transformer import _cfg 10 | import math 11 | import numpy as np 12 | from torch import Tensor 13 | from typing import Optional 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): 18 | super().__init__() 19 | out_features = out_features or in_features 20 | hidden_features = hidden_features or in_features 21 | self.fc1 = nn.Linear(in_features, hidden_features) 22 | self.dwconv = DWConv(hidden_features) 23 | self.act = act_layer() 24 | self.fc2 = nn.Linear(hidden_features, out_features) 25 | self.drop = nn.Dropout(drop) 26 | self.linear = linear 27 | if self.linear: 28 | self.relu = nn.ReLU(inplace=True) 29 | self.apply(self._init_weights) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | elif isinstance(m, nn.LayerNorm): 37 | nn.init.constant_(m.bias, 0) 38 | nn.init.constant_(m.weight, 1.0) 39 | elif isinstance(m, nn.Conv2d): 40 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | fan_out //= m.groups 42 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 43 | if m.bias is not None: 44 | m.bias.data.zero_() 45 | 46 | def forward(self, x, H, W): 47 | x = self.fc1(x) 48 | if self.linear: 49 | x = self.relu(x) 50 | x = self.dwconv(x, H, W) 51 | x = self.act(x) 52 | x = self.drop(x) 53 | x = self.fc2(x) 54 | x = self.drop(x) 55 | return x 56 | 57 | class Vanilla_Attention(nn.Module): 58 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 59 | super().__init__() 60 | self.num_heads = num_heads 61 | head_dim = dim // num_heads 62 | self.scale = head_dim ** -0.5 63 | 64 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 65 | self.attn_drop = nn.Dropout(attn_drop) 66 | self.proj = nn.Linear(dim, dim) 67 | self.proj_drop = nn.Dropout(proj_drop) 68 | 69 | def forward(self, x, H, W): 70 | B, N, C = x.shape 71 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 72 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 73 | 74 | attn = (q @ k.transpose(-2, -1)) * self.scale 75 | attn = attn.softmax(dim=-1) 76 | attn = self.attn_drop(attn) 77 | 78 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 79 | x = self.proj(x) 80 
| x = self.proj_drop(x) 81 | return x 82 | 83 | class pvt_Attention(nn.Module): 84 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False): 85 | super().__init__() 86 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 87 | 88 | self.dim = dim 89 | self.num_heads = num_heads 90 | head_dim = dim // num_heads 91 | self.scale = qk_scale or head_dim ** -0.5 92 | 93 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 94 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 95 | self.attn_drop = nn.Dropout(attn_drop) 96 | self.proj = nn.Linear(dim, dim) 97 | self.proj_drop = nn.Dropout(proj_drop) 98 | 99 | self.linear = linear 100 | self.sr_ratio = sr_ratio 101 | print('linear:', linear, 'sr_ratio:', sr_ratio) 102 | 103 | if not linear: 104 | if sr_ratio > 1: 105 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 106 | self.norm = nn.LayerNorm(dim) 107 | else: 108 | self.pool = nn.AdaptiveAvgPool2d(7) 109 | self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) 110 | self.norm = nn.LayerNorm(dim) 111 | self.act = nn.GELU() 112 | self.apply(self._init_weights) 113 | 114 | def get_attn(self): 115 | return self.attn_map 116 | 117 | def save_attn(self, attn): 118 | self.attn_map = attn 119 | 120 | def save_attn_gradients(self, attn_gradients): 121 | self.attn_map_gradients = attn_gradients 122 | 123 | def get_attn_gradients(self): 124 | return self.attn_map_gradients 125 | 126 | def _init_weights(self, m): 127 | if isinstance(m, nn.Linear): 128 | trunc_normal_(m.weight, std=.02) 129 | if isinstance(m, nn.Linear) and m.bias is not None: 130 | nn.init.constant_(m.bias, 0) 131 | elif isinstance(m, nn.LayerNorm): 132 | nn.init.constant_(m.bias, 0) 133 | nn.init.constant_(m.weight, 1.0) 134 | elif isinstance(m, nn.Conv2d): 135 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 136 | fan_out //= m.groups 137 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 138 | if m.bias is not None: 139 | m.bias.data.zero_() 140 | 141 | def forward(self, x, H, W): 142 | B, N, C = x.shape 143 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B, h, N, C/h 144 | 145 | if not self.linear: 146 | if self.sr_ratio > 1: 147 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 148 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 149 | x_ = self.norm(x_) 150 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 151 | else: 152 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 153 | else: 154 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 155 | x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) 156 | x_ = self.norm(x_) 157 | x_ = self.act(x_) 158 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 159 | k, v = kv[0], kv[1] 160 | 161 | attn = (q @ k.transpose(-2, -1)) * self.scale 162 | attn = attn.softmax(dim=-1) 163 | attn = self.attn_drop(attn) 164 | 165 | # save the attention map for visualization 166 | if x.requires_grad: 167 | self.save_attn(attn) 168 | attn.register_hook(self.save_attn_gradients) 169 | 170 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 171 | x = self.proj(x) 172 | x = self.proj_drop(x) 173 | 174 | return x 175 | 176 | class VicinityVisionAttention(nn.Module): 177 | def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, 178 | dropout_rate=0.0, causal=False, use_sum=True, 179 | sr_ratio=1, 
fr_ratio=1, linear=False, se_reduction=2): 180 | super().__init__() 181 | self.embed_dim = embed_dim 182 | self.num_heads = num_heads 183 | # q, k, v projection 184 | self.fr = fr_ratio 185 | # feature reduction 186 | self.k_proj = nn.Linear(embed_dim, embed_dim // self.fr) 187 | self.v_proj = nn.Linear(embed_dim, embed_dim // self.fr) 188 | self.q_proj = nn.Linear(embed_dim, embed_dim // self.fr) 189 | 190 | self.sr_ratio = sr_ratio 191 | self.linear = linear 192 | if not linear: 193 | if sr_ratio > 1: 194 | self.sr = nn.Conv2d(embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio) 195 | self.norm = nn.LayerNorm(embed_dim) 196 | else: 197 | self.pool = nn.AdaptiveAvgPool2d(7) 198 | self.sr = nn.Conv2d(embed_dim, embed_dim, kernel_size=1, stride=1) 199 | self.norm = nn.LayerNorm(embed_dim) 200 | self.act = nn.GELU() 201 | 202 | self.apply(self._init_weights) 203 | # outprojection 204 | self.out_proj = nn.Linear(embed_dim//self.fr, embed_dim) 205 | # dropout rate 206 | self.dropout_rate = dropout_rate 207 | # causal 208 | self.causal = causal 209 | 210 | assert (self.embed_dim % self.num_heads == 0), "embed_dim must be divisible by num_heads" 211 | self.use_sum = use_sum 212 | if self.use_sum: 213 | print('use sum') 214 | print('linear:', linear, 'sr_ratio:', sr_ratio, 'fr_ratio:', fr_ratio, 'se_ratio:', se_reduction) 215 | else: 216 | print('use production') 217 | print('linear:', linear, 'sr_ratio:', sr_ratio, 'fr_ratio:', fr_ratio, 'se_ratio:', se_reduction) 218 | 219 | # se block: 220 | reduction = se_reduction 221 | self.se_pool = nn.AdaptiveAvgPool1d(1) 222 | self.se_fc = nn.Sequential( 223 | nn.Linear(embed_dim, embed_dim // reduction, bias=False), 224 | nn.ReLU(inplace=True), 225 | nn.Linear(embed_dim // reduction, embed_dim, bias=False), 226 | nn.Sigmoid() 227 | ) 228 | 229 | self.clip = True 230 | 231 | def _init_weights(self, m): 232 | if isinstance(m, nn.Linear): 233 | trunc_normal_(m.weight, std=.02) 234 | if isinstance(m, nn.Linear) and m.bias is not None: 235 | nn.init.constant_(m.bias, 0) 236 | elif isinstance(m, nn.LayerNorm): 237 | nn.init.constant_(m.bias, 0) 238 | nn.init.constant_(m.weight, 1.0) 239 | elif isinstance(m, nn.Conv2d): 240 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 241 | fan_out //= m.groups 242 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 243 | if m.bias is not None: 244 | m.bias.data.zero_() 245 | 246 | def abs_clamp(self, t): 247 | min_mag = 1e-4 248 | max_mag = 10000 249 | sign = t.sign() 250 | return t.abs_().clamp_(min_mag, max_mag)*sign 251 | 252 | def get_index(self, m, n): 253 | """ 254 | m = width, n = height 255 | """ 256 | c = np.pi / 2 257 | seq_len = m * n 258 | index = torch.arange(seq_len).reshape(1, -1, 1, 1) 259 | a = c * (index // m) / n 260 | b = c * (index % m) / m 261 | 262 | seq_len = (m/self.sr_ratio) * (n/self.sr_ratio) 263 | index = torch.arange(seq_len).reshape(1, -1, 1, 1) 264 | a_sr = c * (index // (m/self.sr_ratio) ) / (n/self.sr_ratio) 265 | b_sr = c * (index % (m/self.sr_ratio)) / (m/self.sr_ratio) 266 | 267 | return nn.Parameter(a, requires_grad=False), nn.Parameter(b, requires_grad=False), \ 268 | nn.Parameter(a_sr, requires_grad=False), nn.Parameter(b_sr, requires_grad=False) 269 | 270 | 271 | def abs_clamp(self, t): 272 | min_mag = 1e-4 273 | max_mag = 10000 274 | sign = t.sign() 275 | return t.abs_().clamp_(min_mag, max_mag)*sign 276 | 277 | def forward(self, query, H, W): 278 | # H: height, W: weight 279 | num_heads = self.num_heads 280 | B, N, C = query.shape 281 | query_se = 
query.permute(0, 2, 1) 282 | query_se = self.se_pool(query_se).view(B, C) 283 | query_se = self.se_fc(query_se).view(B, C, 1) 284 | 285 | if not self.linear: 286 | if self.sr_ratio > 1: 287 | x_ = query.permute(0, 2, 1).reshape(B, C, H, W) 288 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 289 | x_ = self.norm(x_) 290 | k = self.k_proj(x_) 291 | v = self.v_proj(x_) 292 | else: 293 | k = self.k_proj(query) 294 | v = self.v_proj(query) 295 | else: 296 | x_ = query.permute(0, 2, 1).reshape(B, C, H, W) 297 | x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) 298 | x_ = self.norm(x_) 299 | x_ = self.act(x_) 300 | k = self.k_proj(x_) 301 | v = self.v_proj(x_) 302 | 303 | k = k.permute(1, 0, 2) 304 | v = v.permute(1, 0, 2) 305 | 306 | query = query.permute(1,0,2) 307 | tgt_len, bsz, embed_dim = query.size() 308 | head_dim = embed_dim // num_heads 309 | # (L, N, E) 310 | q = self.q_proj(query) 311 | 312 | # multihead 313 | # (N, L, h, d) 314 | q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim//self.fr).transpose(0, 1) 315 | # (N, S, h, d) 316 | k = k.contiguous().view(-1, bsz, num_heads, head_dim//self.fr).transpose(0, 1) 317 | # (N, S, h, d) 318 | v = v.contiguous().view(-1, bsz, num_heads, head_dim//self.fr).transpose(0, 1) 319 | # relu 320 | q = F.relu(q) 321 | k = F.relu(k) 322 | 323 | a, b, a_sr, b_sr = self.get_index(W, H) 324 | a = a.to(q) 325 | b = b.to(q) 326 | a_sr = a_sr.to(q) 327 | b_sr = b_sr.to(q) 328 | 329 | if self.use_sum: 330 | # sum 331 | q_ = torch.cat([q * torch.cos(a), \ 332 | q * torch.sin(a), \ 333 | q * torch.cos(b), \ 334 | q * torch.sin(b)], \ 335 | dim=-1) 336 | # (N, S, h, 4 * d) 337 | k_ = torch.cat([k * torch.cos(a_sr), \ 338 | k * torch.sin(a_sr), \ 339 | k * torch.cos(b_sr), \ 340 | k * torch.sin(b_sr)], \ 341 | dim=-1) 342 | else: 343 | q_ = torch.cat([q * torch.cos(a) * torch.cos(b), \ 344 | q * torch.cos(a) * torch.sin(b), \ 345 | q * torch.sin(a) * torch.cos(b), \ 346 | q * torch.sin(a) * torch.sin(b)], \ 347 | dim=-1) 348 | # (N, S, h, 4 * d) 349 | k_ = torch.cat([k * torch.cos(a_sr) * torch.cos(b_sr), \ 350 | k * torch.cos(a_sr) * torch.sin(b_sr), \ 351 | k * torch.sin(a_sr) * torch.cos(b_sr), \ 352 | k * torch.sin(a_sr) * torch.sin(b_sr)], \ 353 | dim=-1) 354 | 355 | eps = 1e-4 356 | 357 | #--------------------------------------------------------------------------------- 358 | kv_ = torch.matmul(k_.permute(0, 2, 3, 1), v.permute(0, 2, 1, 3)) # no einsum 359 | if self.clip: 360 | kv_ = self.abs_clamp(kv_) 361 | #--------------------------------------------------------------------------------- 362 | 363 | #-------------------------------------------------------------------------------- 364 | k_sum = torch.sum(k_, axis=1, keepdim=True) # no einsum 365 | z_ = 1 / (torch.sum(torch.mul(q_, k_sum), axis=-1) + eps) # no einsum 366 | if self.clip: 367 | z_ = self.abs_clamp(z_) 368 | #-------------------------------------------------------------------------------- 369 | 370 | # no einsum--------------------------------------------------------------------- 371 | attn_output = torch.matmul(q_.transpose(1, 2), kv_).transpose(1, 2) 372 | if self.clip: 373 | attn_output = self.abs_clamp(attn_output) 374 | # nlhm,nlh -> nlhm 375 | attn_output = torch.mul(attn_output, z_.unsqueeze(-1)) 376 | if self.clip: 377 | attn_output = self.abs_clamp(attn_output) 378 | #-------------------------------------------------------------------------------- 379 | 380 | # (N, L, h, d) -> (L, N, h, d) -> (L, N, E) 381 | attn_output = attn_output.transpose(0, 
1).contiguous().view(tgt_len, bsz, embed_dim//self.fr) 382 | 383 | attn_output = self.out_proj(attn_output) 384 | if self.clip: 385 | attn_output = self.abs_clamp(attn_output) 386 | 387 | # ------------------------------------- se block 388 | attn_output = attn_output.permute(1,2,0) 389 | attn_output = attn_output + attn_output * query_se.expand_as(attn_output) 390 | if self.clip: 391 | attn_output = self.abs_clamp(attn_output) 392 | attn_output = attn_output.permute(2,0,1) 393 | # ------------------------------------------------- 394 | 395 | attn_output = attn_output.permute(1,0,2) 396 | 397 | return attn_output 398 | 399 | class Block(nn.Module): 400 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., 401 | attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, 402 | sr_ratio=1, fr_ratio = 1, linear=False, seq_len=3136, se_ratio=1, prod_type="right"): 403 | super().__init__() 404 | self.norm1 = norm_layer(dim) 405 | self.seq_len = seq_len 406 | 407 | self.attn = VicinityVisionAttention(embed_dim=dim, num_heads=num_heads, sr_ratio=sr_ratio, fr_ratio = fr_ratio, linear=linear, se_reduction=se_ratio) 408 | 409 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 410 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 411 | self.norm2 = norm_layer(dim) 412 | mlp_hidden_dim = int(dim * mlp_ratio) 413 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) 414 | 415 | self.apply(self._init_weights) 416 | 417 | def _init_weights(self, m): 418 | if isinstance(m, nn.Linear): 419 | trunc_normal_(m.weight, std=.02) 420 | if isinstance(m, nn.Linear) and m.bias is not None: 421 | nn.init.constant_(m.bias, 0) 422 | elif isinstance(m, nn.LayerNorm): 423 | nn.init.constant_(m.bias, 0) 424 | nn.init.constant_(m.weight, 1.0) 425 | elif isinstance(m, nn.Conv2d): 426 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 427 | fan_out //= m.groups 428 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 429 | if m.bias is not None: 430 | m.bias.data.zero_() 431 | 432 | def forward(self, x, H, W): 433 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 434 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 435 | 436 | return x 437 | 438 | 439 | class OverlapPatchEmbed(nn.Module): 440 | """ Image to Patch Embedding 441 | """ 442 | 443 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 444 | super().__init__() 445 | img_size = to_2tuple(img_size) 446 | patch_size = to_2tuple(patch_size) 447 | 448 | self.img_size = img_size 449 | self.patch_size = patch_size 450 | self.H, self.W = img_size[0] // stride, img_size[1] // stride 451 | self.num_patches = self.H * self.W 452 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 453 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 454 | self.norm = nn.LayerNorm(embed_dim) 455 | 456 | self.apply(self._init_weights) 457 | 458 | def _init_weights(self, m): 459 | if isinstance(m, nn.Linear): 460 | trunc_normal_(m.weight, std=.02) 461 | if isinstance(m, nn.Linear) and m.bias is not None: 462 | nn.init.constant_(m.bias, 0) 463 | elif isinstance(m, nn.LayerNorm): 464 | nn.init.constant_(m.bias, 0) 465 | nn.init.constant_(m.weight, 1.0) 466 | elif isinstance(m, nn.Conv2d): 467 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 468 | fan_out //= m.groups 469 | m.weight.data.normal_(0, math.sqrt(2.0 / 
fan_out)) 470 | if m.bias is not None: 471 | m.bias.data.zero_() 472 | 473 | def forward(self, x): 474 | x = self.proj(x) 475 | _, _, H, W = x.shape 476 | x = x.flatten(2).transpose(1, 2) 477 | x = self.norm(x) 478 | 479 | return x, H, W 480 | 481 | 482 | class VicinityVisionTransformer(nn.Module): 483 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 484 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 485 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 486 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], fr_ratios = [1,2,2,4], num_stages=4, linear=False, se_ratio=1, prod_types=["left", "left", "left", "left"]): 487 | super().__init__() 488 | self.num_classes = num_classes 489 | self.depths = depths 490 | self.num_stages = num_stages 491 | 492 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 493 | cur = 0 494 | 495 | for i in range(num_stages): 496 | patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 497 | patch_size=7 if i == 0 else 3, 498 | stride=4 if i == 0 else 2, 499 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 500 | embed_dim=embed_dims[i]) 501 | 502 | if i == 0: 503 | seq_len = (img_size//patch_size)**2 504 | else: 505 | seq_len = ((img_size // (2 ** (i + 1))) // 2)**2 506 | 507 | print('seq_len:', seq_len, 'embed dim:', embed_dims[i], 'num heads:', num_heads[i], 'depth:', depths[i]) 508 | 509 | block = nn.ModuleList([Block( 510 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, qk_scale=qk_scale, 511 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, 512 | sr_ratio=sr_ratios[i],fr_ratio = fr_ratios[i], linear=linear, seq_len=seq_len, se_ratio=se_ratio, prod_type=prod_types[i]) 513 | for j in range(depths[i])]) 514 | norm = norm_layer(embed_dims[i]) 515 | cur += depths[i] 516 | 517 | setattr(self, f"patch_embed{i + 1}", patch_embed) 518 | setattr(self, f"block{i + 1}", block) 519 | setattr(self, f"norm{i + 1}", norm) 520 | 521 | # classification head 522 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 523 | 524 | self.apply(self._init_weights) 525 | 526 | def _init_weights(self, m): 527 | if isinstance(m, nn.Linear): 528 | trunc_normal_(m.weight, std=.02) 529 | if isinstance(m, nn.Linear) and m.bias is not None: 530 | nn.init.constant_(m.bias, 0) 531 | elif isinstance(m, nn.LayerNorm): 532 | nn.init.constant_(m.bias, 0) 533 | nn.init.constant_(m.weight, 1.0) 534 | elif isinstance(m, nn.Conv2d): 535 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 536 | fan_out //= m.groups 537 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 538 | if m.bias is not None: 539 | m.bias.data.zero_() 540 | 541 | def freeze_patch_emb(self): 542 | self.patch_embed1.requires_grad = False 543 | 544 | @torch.jit.ignore 545 | def no_weight_decay(self): 546 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 547 | 548 | def get_classifier(self): 549 | return self.head 550 | 551 | def reset_classifier(self, num_classes, global_pool=''): 552 | self.num_classes = num_classes 553 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 554 | 555 | def forward_features(self, x): 556 | B = x.shape[0] 557 | 558 | for i in range(self.num_stages): 559 | patch_embed 
= getattr(self, f"patch_embed{i + 1}") 560 | block = getattr(self, f"block{i + 1}") 561 | norm = getattr(self, f"norm{i + 1}") 562 | x, H, W = patch_embed(x) 563 | for blk in block: 564 | x = blk(x, H, W) 565 | x = norm(x) 566 | if i != self.num_stages - 1: 567 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 568 | 569 | return x.mean(dim=1) 570 | 571 | def forward(self, x): 572 | x = self.forward_features(x) 573 | x = self.head(x) 574 | 575 | return x 576 | 577 | class DWConv(nn.Module): 578 | def __init__(self, dim=768): 579 | super(DWConv, self).__init__() 580 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 581 | 582 | def forward(self, x, H, W): 583 | B, N, C = x.shape 584 | x = x.transpose(1, 2).view(B, C, H, W) 585 | x = self.dwconv(x) 586 | x = x.flatten(2).transpose(1, 2) 587 | 588 | return x 589 | 590 | def _conv_filter(state_dict, patch_size=16): 591 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 592 | out_dict = {} 593 | for k, v in state_dict.items(): 594 | if 'patch_embed.proj.weight' in k: 595 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 596 | out_dict[k] = v 597 | 598 | return out_dict 599 | 600 | ### VVT 601 | @register_model 602 | def vvt_test(pretrained=False, **kwargs): 603 | model = VicinityVisionTransformer( 604 | patch_size=4, embed_dims=[32, 32, 32, 32], num_heads=[1, 1, 1, 1], mlp_ratios=[1, 1, 1, 1], qkv_bias=True, 605 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[1, 1, 1, 1], sr_ratios=[1, 1, 1, 1], 606 | **kwargs) 607 | model.default_cfg = _cfg() 608 | 609 | return model 610 | 611 | @register_model 612 | def vvt_tiny(pretrained=False, **kwargs): 613 | model = VicinityVisionTransformer( 614 | patch_size=4, embed_dims=[96, 160, 320, 512], num_heads=[1,2,5,8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 615 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], 616 | sr_ratios=[1,1,1,1], fr_ratios=[2,2,2,2],se_ratio=1, 617 | **kwargs) 618 | model.default_cfg = _cfg() 619 | 620 | return model 621 | 622 | @register_model 623 | def vvt_small(pretrained=False, **kwargs): 624 | model = VicinityVisionTransformer( 625 | patch_size=4, embed_dims=[96, 160, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 626 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 3, 9, 3], 627 | sr_ratios=[1,1,1,1],fr_ratios=[2,2,2,2],se_ratio=1, 628 | **kwargs) 629 | model.default_cfg = _cfg() 630 | 631 | return model 632 | 633 | @register_model 634 | def vvt_medium(pretrained=False, **kwargs): 635 | model = VicinityVisionTransformer( 636 | patch_size=4, embed_dims=[96, 160, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 637 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 3, 27, 3], 638 | sr_ratios=[1,1,1,1],fr_ratios=[2,2,2,2],se_ratio=1, 639 | **kwargs) 640 | model.default_cfg = _cfg() 641 | 642 | return model 643 | 644 | @register_model 645 | def vvt_large(pretrained=False, **kwargs): 646 | model = VicinityVisionTransformer( 647 | patch_size=4, embed_dims=[96, 160, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, 648 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[4, 4, 36, 4], 649 | sr_ratios=[1,1,1,1],fr_ratios=[2,2,2,2],se_ratio=1, 650 | **kwargs) 651 | model.default_cfg = _cfg() 652 | 653 | return model --------------------------------------------------------------------------------
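A minimal usage sketch for the models registered above (a sketch, not part of the repository; it assumes the snippet is run from the classification/ directory so this file imports as vvt, and that an installed timm version provides register_model/create_model). Importing the module executes the @register_model decorators, after which timm.create_model can build any of the vvt_* variants; extra keyword arguments are forwarded through **kwargs to VicinityVisionTransformer, so settings such as num_classes or drop_path_rate can be overridden at creation time.

import torch
import timm

import vvt  # noqa: F401 -- importing this module registers vvt_test/tiny/small/medium/large with timm

# Build a randomly initialised vvt_tiny and run a single 224x224 image through it.
model = timm.create_model('vvt_tiny', pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000]) with the default 1000-class head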