├── .gitignore ├── imgs └── SSiT.png ├── requirements.txt ├── utils ├── folder2pkl.py ├── crop.py ├── saliency_detect.py └── attn_visualize.py ├── funcs.py ├── README.md ├── main.py ├── ssit.py ├── train.py ├── data.py ├── knn.py ├── eval.py ├── eval_seg.py └── vits.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode 3 | data_index 4 | checkpoints -------------------------------------------------------------------------------- /imgs/SSiT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YijinHuang/SSiT/HEAD/imgs/SSiT.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.5.1 2 | numpy==1.21.2 3 | opencv_python==4.5.5.64 4 | opencv-contrib-python==4.6.0.66 5 | scikit_image==0.19.2 6 | timm==0.5.4 7 | torch==1.11.0 8 | torchvision==0.12.0 9 | tqdm==4.64.0 10 | tensorboard==2.9.1 11 | albumentations==1.3.1 12 | -------------------------------------------------------------------------------- /utils/folder2pkl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import pickle 4 | import argparse 5 | from pathlib import Path 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--output-file', type=str, default='../data_index/pretraining_dataset.pkl', help='path to save dataset index') 10 | parser.add_argument('--image-folder', type=str, help='path to image folder') 11 | parser.add_argument('--saliency-folder', type=str, help='path to saliency folder') 12 | 13 | 14 | def folder2pkl(): 15 | args = parser.parse_args() 16 | image_folder = Path(args.image_folder) 17 | saliency_folder = Path(args.saliency_folder) 18 | output_file = Path(args.output_file) 19 | 20 | # sort images by name 21 | image_paths = sorted(image_folder.glob('**/*.*'), key=lambda x: x.stem) 22 | saliency_paths = sorted(saliency_folder.glob('**/*.*'), key=lambda x: x.stem) 23 | assert len(image_paths) == len(saliency_paths), 'Number of images and number of saliency maps are not equal' 24 | 25 | # absolute path 26 | image_paths = [path.absolute() for path in image_paths] 27 | saliency_paths = [path.absolute() for path in saliency_paths] 28 | 29 | dataset = list(zip(image_paths, saliency_paths)) 30 | for image_path, saliency_path in dataset: 31 | image_id = image_path.stem 32 | saliency_id = saliency_path.stem 33 | assert image_id == saliency_id, 'Image id {} and saliency id {} do not match'.format(image_id, saliency_id) 34 | 35 | output_file.parent.mkdir(parents=True, exist_ok=True) 36 | with open(output_file, 'wb') as f: 37 | pickle.dump(dataset, f) 38 | 39 | 40 | if __name__ == '__main__': 41 | folder2pkl() 42 | -------------------------------------------------------------------------------- /utils/crop.py: -------------------------------------------------------------------------------- 1 | # ========================================================================== 2 | # Base on https://github.com/sveitser/kaggle_diabetic/blob/master/convert.py 3 | # ========================================================================== 4 | import os 5 | import random 6 | import argparse 7 | import numpy as np 8 | 9 | from pathlib import Path 10 | from tqdm import tqdm 11 | from PIL import Image, ImageFilter 12 | from multiprocessing import Process 13 | 14 | 15 | parser = 
argparse.ArgumentParser() 16 | parser.add_argument('--image-folder', type=str, help='path to image folder') 17 | parser.add_argument('--output-folder', type=str, help='path to output folder') 18 | parser.add_argument('--crop-size', type=int, default=512, help='crop size of image') 19 | parser.add_argument('-n', '--num-processes', type=int, default=8, help='number of processes to use') 20 | 21 | 22 | def main(): 23 | args = parser.parse_args() 24 | image_folder = Path(args.image_folder) 25 | output_folder = Path(args.output_folder) 26 | 27 | jobs = [] 28 | for root, _, imgs in os.walk(args.image_folder): 29 | root = Path(root) 30 | subfolders = root.relative_to(image_folder) 31 | output_root = output_folder.joinpath(subfolders) 32 | output_root.mkdir(parents=True, exist_ok=True) 33 | 34 | for img in tqdm(imgs): 35 | src_path = root.joinpath(img) 36 | tgt_path = output_root.joinpath(img) 37 | jobs.append((src_path, tgt_path, args.crop_size)) 38 | random.shuffle(jobs) 39 | 40 | procs = [] 41 | job_size = len(jobs) // args.num_processes 42 | for i in range(args.num_processes): 43 | if i < args.num_processes - 1: 44 | procs.append(Process(target=convert_list, args=(i, jobs[i * job_size:(i + 1) * job_size]))) 45 | else: 46 | procs.append(Process(target=convert_list, args=(i, jobs[i * job_size:]))) 47 | 48 | for p in procs: 49 | p.start() 50 | 51 | for p in procs: 52 | p.join() 53 | 54 | 55 | def convert_list(i, jobs): 56 | for j, job in enumerate(jobs): 57 | if j % 100 == 0: 58 | print('worker{} has finished {}.'.format(i, j)) 59 | convert(*job) 60 | 61 | 62 | def convert(fname, tgt_path, crop_size): 63 | img = Image.open(fname) 64 | 65 | blurred = img.filter(ImageFilter.BLUR) 66 | ba = np.array(blurred) 67 | h, w, _ = ba.shape 68 | 69 | if w > 1.2 * h: 70 | left_max = ba[:, : w // 32, :].max(axis=(0, 1)).astype(int) 71 | right_max = ba[:, - w // 32:, :].max(axis=(0, 1)).astype(int) 72 | max_bg = np.maximum(left_max, right_max) 73 | 74 | foreground = (ba > max_bg + 10).astype(np.uint8) 75 | bbox = Image.fromarray(foreground).getbbox() 76 | 77 | if bbox is None: 78 | print('bbox none for {} (???)'.format(fname)) 79 | else: 80 | left, upper, right, lower = bbox 81 | # if we selected less than 80% of the original 82 | # height, just crop the square 83 | if right - left < 0.8 * h or lower - upper < 0.8 * h: 84 | print('bbox too small for {}'.format(fname)) 85 | bbox = None 86 | else: 87 | bbox = None 88 | 89 | if bbox is None: 90 | bbox = square_bbox(img) 91 | 92 | cropped = img.crop(bbox) 93 | cropped = cropped.resize([crop_size, crop_size], Image.ANTIALIAS) 94 | save(cropped, tgt_path) 95 | 96 | 97 | def square_bbox(img): 98 | w, h = img.size 99 | left = max((w - h) // 2, 0) 100 | upper = 0 101 | right = min(w - (w - h) // 2, w) 102 | lower = h 103 | return (left, upper, right, lower) 104 | 105 | 106 | def save(img, fname): 107 | img.save(fname, quality=100, subsampling=0) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() -------------------------------------------------------------------------------- /utils/saliency_detect.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import cv2 as cv 4 | import numpy as np 5 | 6 | from pathlib import Path 7 | from multiprocessing import Pool 8 | 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('-n', '--num_process', type=int, default=8, help='number of processes') 12 | parser.add_argument('--saliency-model', type=str, default='fine_grained', help='saliency 
model (fine_grained / spectral_residual)') 13 | parser.add_argument('--image-folder', type=str, help='path to image folder') 14 | parser.add_argument('--output-folder', type=str, help='path to save saliency map') 15 | parser.add_argument('--visualize-folder', type=str, default='', help='path to save saliency map visualizeualization') 16 | 17 | 18 | circle = np.zeros((512, 512)) 19 | circle = cv.circle(circle, (256, 256), 240, 1, -1) 20 | 21 | 22 | def saliency_detect(i, saliency_model, src_path, output_path, visualize_path): 23 | image = cv.imread(str(src_path)) 24 | image = preprocess(image) 25 | 26 | if saliency_model == 'fine_grained': 27 | saliency = cv.saliency.StaticSaliencyFineGrained_create() 28 | elif saliency_model == 'spectral_residual': 29 | saliency = cv.saliency.StaticSaliencySpectralResidual_create() 30 | else: 31 | raise ValueError('Unknown saliency model: {}'.format(saliency_model)) 32 | 33 | (_, raw_saliencyMap) = saliency.computeSaliency(image) 34 | raw_saliencyMap *= circle 35 | 36 | np.save(output_path, raw_saliencyMap) 37 | 38 | if visualize_path: 39 | int_saliencyMap = (raw_saliencyMap * 255).astype("uint8") 40 | cv.imwrite(str(visualize_path), int_saliencyMap) 41 | 42 | if i % 500 == 0: 43 | print('Processed {} images'.format(i)) 44 | 45 | 46 | def main(): 47 | args = parser.parse_args() 48 | image_folder = Path(args.image_folder) 49 | output_folder = Path(args.output_folder) 50 | visualize_folder = Path(args.visualize_folder) 51 | 52 | i = 0 53 | res = [] 54 | pool = Pool(processes=args.num_process) 55 | print('Loading tasks...') 56 | for folder, _, imgs in os.walk(args.image_folder): 57 | folder = Path(folder) 58 | subfolders = folder.relative_to(image_folder) 59 | output_subfolder = output_folder.joinpath(subfolders) 60 | output_subfolder.mkdir(parents=True, exist_ok=True) 61 | 62 | if args.visualize_folder: 63 | visualize_subfolder = visualize_folder.joinpath(subfolders) 64 | visualize_subfolder.mkdir(parents=True, exist_ok=True) 65 | 66 | for img in imgs: 67 | i += 1 68 | src_path = folder.joinpath(img) 69 | output_path = output_subfolder.joinpath(img).with_suffix('.npy') 70 | visualize_path = visualize_subfolder.joinpath(img) if args.visualize_folder else '' 71 | res.append(pool.apply_async(saliency_detect, args=(i, args.saliency_model, src_path, output_path, visualize_path))) 72 | 73 | print('Waiting for all subprocesses done...') 74 | for re in res: 75 | re.get() 76 | pool.close() 77 | pool.join() 78 | print('All subprocesses done.') 79 | 80 | 81 | def preprocess(img): 82 | scale = 512 83 | mask = np.zeros(img.shape) 84 | cv.circle(mask, (int(img.shape[1]/2), int(img.shape[0]/2)), 85 | int(scale/2*0.98), (1, 1, 1), -1, 8, 0) 86 | weighted_img = cv.addWeighted(img, 4, cv.GaussianBlur(img, (0, 0), scale/30), -4, 128) 87 | processed_img = weighted_img * mask + 128 * (1 - mask) 88 | 89 | # To reproduce the saliency map used in the paper, 90 | # we simulated the processing of saving the processed image in jpeg format and then reading it. 91 | # These codes can be removed if error or performance degradation is observed. 
92 | processed_img = processed_img.astype(np.uint8) 93 | _, jpeg = cv.imencode('.jpeg', processed_img) 94 | processed_img = cv.imdecode(jpeg, cv.IMREAD_COLOR) 95 | 96 | return processed_img 97 | 98 | 99 | if __name__ == '__main__': 100 | main() 101 | -------------------------------------------------------------------------------- /funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from vits import resize_pos_embed 4 | 5 | 6 | def print_msg(msg, appendixs=[]): 7 | max_len = len(max([msg, *appendixs], key=len)) 8 | print('=' * max_len) 9 | print(msg) 10 | for appendix in appendixs: 11 | print(appendix) 12 | print('=' * max_len) 13 | 14 | 15 | def print_config(args): 16 | print('=================================') 17 | for key, value in args.__dict__.items(): 18 | print('{}: {}'.format(key, value)) 19 | print('=================================') 20 | 21 | 22 | def print_dataset_info(datasets): 23 | train_dataset, test_dataset, val_dataset = datasets 24 | print('=========================') 25 | print('Dataset Loaded.') 26 | print('Categories:\t{}'.format(len(train_dataset.classes))) 27 | print('Training:\t{}'.format(len(train_dataset))) 28 | print('Validation:\t{}'.format(len(val_dataset))) 29 | print('Test:\t\t{}'.format(len(test_dataset))) 30 | print('=========================') 31 | 32 | 33 | def inverse_normalize(tensor, mean, std): 34 | for t, m, s in zip(tensor, mean, std): 35 | t.mul_(s).add_(m) 36 | return tensor 37 | 38 | 39 | def is_main(args): 40 | return (not args.distributed) or args.rank == 0 41 | 42 | 43 | def to_devices(args, *tensors): 44 | if args.distributed: 45 | return [tensor.cuda(args.gpu) for tensor in tensors] 46 | else: 47 | return [tensor.to(args.device) for tensor in tensors] 48 | 49 | 50 | def quadratic_weighted_kappa(conf_mat): 51 | assert conf_mat.shape[0] == conf_mat.shape[1] 52 | cate_num = conf_mat.shape[0] 53 | 54 | # Quadratic weighted matrix 55 | weighted_matrix = np.zeros((cate_num, cate_num)) 56 | for i in range(cate_num): 57 | for j in range(cate_num): 58 | weighted_matrix[i][j] = 1 - float(((i - j)**2) / ((cate_num - 1)**2)) 59 | 60 | # Expected matrix 61 | ground_truth_count = np.sum(conf_mat, axis=1) 62 | pred_count = np.sum(conf_mat, axis=0) 63 | expected_matrix = np.outer(ground_truth_count, pred_count) 64 | 65 | # Normalization 66 | conf_mat = conf_mat / conf_mat.sum() 67 | expected_matrix = expected_matrix / expected_matrix.sum() 68 | 69 | observed = (conf_mat * weighted_matrix).sum() 70 | expected = (expected_matrix * weighted_matrix).sum() 71 | return (observed - expected) / (1 - expected) 72 | 73 | 74 | def load_checkpoint(model, checkpoint_path, checkpoint_key, linear_key): 75 | checkpoint = torch.load(checkpoint_path) 76 | state_dict = checkpoint.state_dict() 77 | for k in list(state_dict.keys()): 78 | # retain only base_encoder up to before the embedding layer 79 | if k.startswith(checkpoint_key) and not k.startswith('%s.%s' % (checkpoint_key, linear_key)): 80 | # remove prefix 81 | state_dict[k[len("%s." 
% checkpoint_key):]] = state_dict[k]
82 |         # delete renamed or unused k
83 |         del state_dict[k]
84 | 
85 |     # position embedding
86 |     pos_embed_w = state_dict['pos_embed']
87 |     pos_embed_w = resize_pos_embed(pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
88 |     state_dict['pos_embed'] = pos_embed_w
89 | 
90 |     msg = model.load_state_dict(state_dict, strict=False)
91 |     assert set(msg.missing_keys) == {"%s.weight" % linear_key, "%s.bias" % linear_key}
92 |     print_msg('Load weights from {}'.format(checkpoint_path))
93 | 
94 | 
95 | def get_dataset_stats(dataset):
96 |     # mean and std from train set
97 |     dataset_stats = {
98 |         'ddr': (
99 |             [0.423737496137619, 0.2609460651874542, 0.128403902053833],
100 |             [0.29482534527778625, 0.20167365670204163, 0.13668020069599152]
101 |         ),
102 |         'aptos2019': (
103 |             [0.46100369095802307, 0.246780663728714, 0.07989078760147095],
104 |             [0.24873991310596466, 0.13842609524726868, 0.08025242388248444]
105 |         ),
106 |         'messidor2': (
107 |             [0.48436370491981506, 0.2238118201494217, 0.07583174854516983],
108 |             [0.2939208149909973, 0.14721707999706268, 0.06350880116224289]
109 |         )
110 |     }
111 |     if dataset in dataset_stats.keys():
112 |         mean, std = dataset_stats[dataset]
113 |     else:
114 |         raise NotImplementedError(
115 |             'Not implemented dataset: {}. '
116 |             'Please specify the dataset name [--dataset ddr / aptos2019 / messidor2]. '
117 |             'If you are training on a customized dataset, '
118 |             'please add the mean and std of your dataset in dataset_stats in funcs.py.'.format(dataset)
119 |         )
120 |     return mean, std
121 | 
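The `NotImplementedError` above asks users of a customized dataset to register its per-channel mean and std in `dataset_stats`. Below is a minimal editorial sketch (not part of the repository) of how such statistics could be computed on the training split, assuming the `ImageFolder` layout described in the README; the function name, paths, and defaults are placeholders:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def compute_mean_std(train_folder, input_size=512, batch_size=32, num_workers=4):
    # Accumulate per-channel statistics over the training images only.
    dataset = datasets.ImageFolder(
        train_folder,
        transform=transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),  # scales pixel values to [0, 1]
        ])
    )
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

    channel_sum = torch.zeros(3)
    channel_sq_sum = torch.zeros(3)
    num_pixels = 0
    for images, _ in loader:
        # images: (B, 3, H, W)
        channel_sum += images.sum(dim=[0, 2, 3])
        channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])
        num_pixels += images.shape[0] * images.shape[2] * images.shape[3]

    mean = channel_sum / num_pixels
    std = (channel_sq_sum / num_pixels - mean ** 2).sqrt()
    return mean.tolist(), std.tolist()  # paste these values into dataset_stats
```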
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SSiT: Saliency-guided Self-supervised Image Transformer
2 | 
3 | This is the PyTorch implementation of the paper:
4 | 
5 | > Y. Huang, J. Lyu, P. Cheng, R. Tam and X. Tang, "SSiT: Saliency-guided Self-supervised Image Transformer for Diabetic Retinopathy Grading", IEEE Journal of Biomedical and Health Informatics (JBHI), 2024. \[[arxiv](https://arxiv.org/abs/2210.10969)\] \[[JBHI](https://ieeexplore.ieee.org/abstract/document/10423096)\]
6 | 
7 | ![](./imgs/SSiT.png)
8 | 
9 | 
10 | 
11 | ## Dataset
12 | The datasets used in this work are listed as follows:
13 | 
14 | Pretraining:
15 | - EyePACS [[homepage](https://www.kaggle.com/c/diabetic-retinopathy-detection/overview)].
16 | 
17 | Evaluation for classification:
18 | - DDR [[homepage](https://github.com/nkicsl/DDR-dataset)].
19 | - APTOS 2019 [[homepage](https://www.kaggle.com/c/aptos2019-blindness-detection/overview)].
20 | - Messidor-2 [[images](https://www.adcis.net/en/third-party/messidor2/)] [[labels](https://www.kaggle.com/datasets/google-brain/messidor2-dr-grades)].
21 | - IChallenge-AMD [[homepage](https://refuge.grand-challenge.org/iChallenge-AMD/)]
22 | - IChallenge-PM [[homepage](https://palm.grand-challenge.org/)]
23 | 
24 | Evaluation for segmentation:
25 | - DRIVE [[homepage](https://drive.grand-challenge.org/)]
26 | - IDRiD [[homepage](https://idrid.grand-challenge.org/)]
27 | 
28 | 
29 | ## Installation
30 | To install the dependencies, run:
31 | ```shell
32 | git clone https://github.com/YijinHuang/SSiT.git
33 | cd SSiT
34 | conda create -n ssit python=3.8.0
35 | conda activate ssit
36 | pip install -r requirements.txt
37 | ```
38 | 
39 | 
40 | 
41 | ## Usage
42 | ### Dataset preparation
43 | 
44 | 1\. Organize each dataset as follows:
45 | ```
46 | ├── dataset
47 |     ├── train
48 |         ├── class1
49 |             ├── image1.jpg
50 |             ├── image2.jpg
51 |             ├── ...
52 |         ├── class2
53 |             ├── image3.jpg
54 |             ├── image4.jpg
55 |             ├── ...
56 |         ├── class3
57 |         ├── ...
58 |     ├── val
59 |     ├── test
60 | ```
61 | Here, `val` and `test` have the same structure as `train`. Note that we do not use image labels in the pretraining stage, so this folder structure is not required for the pretraining dataset (EyePACS in this work).
62 | 
63 | 2\. Data preprocessing for all datasets (crop and resize):
64 | ```shell
65 | cd utils
66 | python crop.py -n 8 --crop-size 512 --image-folder <path/to/dataset> --output-folder <path/to/processed/dataset>
67 | cd ..
68 | ```
69 | Here, `-n` is the number of workers. The processed dataset will be saved in the `--output-folder`.
70 | 
71 | 3\. Data preprocessing for the **pretraining dataset (EyePACS)** only:
72 | 
73 | ```shell
74 | cd utils
75 | python saliency_detect.py -n 8 --image-folder <path/to/processed/EyePACS> --output-folder <path/to/saliency/maps>
76 | python folder2pkl.py --image-folder <path/to/processed/EyePACS> --saliency-folder <path/to/saliency/maps> --output-file ../data_index/pretraining_dataset.pkl
77 | cd ..
78 | ```
79 | 
80 | 
81 | 
82 | ### Pretraining
83 | Pretraining with ViT-S on a single multi-GPU node:
84 | ```shell
85 | python main.py \
86 |     --distributed --port 28888 --num-workers 32 \
87 |     --arch ViT-S-p16 --batch-size 512 \
88 |     --epochs 300 --warmup-epochs 40 \
89 |     --data-index ./data_index/pretraining_dataset.pkl \
90 |     --save-path <path/to/save/checkpoints>
91 | ```
92 | Specify `CUDA_VISIBLE_DEVICES` (e.g. `CUDA_VISIBLE_DEVICES=0,1,2,3`) to control the number of GPUs. To reproduce the results in the paper, at least 64GB of total GPU memory is required (4 NVIDIA RTX 3090 GPUs with 24GB memory each were used in our experiments).
93 | 
94 | 
95 | 
96 | ### Evaluation
97 | #### Classification
98 | 1\. Fine-tuning evaluation on the DDR dataset on one GPU:
99 | ```shell
100 | python eval.py \
101 |     --dataset ddr --arch ViT-S-p16 --kappa-prior \
102 |     --data-path <path/to/DDR/dataset> \
103 |     --checkpoint <path/to/pretrained/checkpoint> \
104 |     --save-path <path/to/save/checkpoints>
105 | ```
106 | 
107 | 2\. Linear evaluation on the DDR dataset on one GPU:
108 | ```shell
109 | python eval.py \
110 |     --dataset ddr --arch ViT-S-p16 --kappa-prior \
111 |     --linear --learning-rate 0.002 --weight-decay 0 \
112 |     --data-path <path/to/DDR/dataset> \
113 |     --checkpoint <path/to/pretrained/checkpoint> \
114 |     --save-path <path/to/save/checkpoints>
115 | ```
116 | 
117 | 3\. Perform kNN classification on the DDR dataset:
118 | ```shell
119 | python knn.py \
120 |     --dataset ddr --arch ViT-S-p16 \
121 |     --data-path <path/to/DDR/dataset> \
122 |     --checkpoint <path/to/pretrained/checkpoint>
123 | ```
124 | 
125 | Note that the `--checkpoint` should be `epoch_xxx.pt` instead of `checkpoint.pt` in the pretraining save path. To evaluate on other datasets, update `--dataset` to `messidor2` or `aptos2019` and `--data-path` to the corresponding dataset folder.
126 | 
127 | #### Segmentation
128 | 1\. Save the segmentation dataset as a pickle file:
129 | ```
130 | data_index = {
131 |     'train': [
132 |         ('path/to/image1', 'path/to/mask_of_image1'),
133 |         ('path/to/image2', 'path/to/mask_of_image2'),
134 |         ...
135 |     ],
136 |     'test': [
137 |         ('path/to/image3', 'path/to/mask_of_image3'),
138 |         ...
139 |     ],
140 |     'val': [
141 |         ('path/to/image4', 'path/to/mask_of_image4'),
142 |         ...
143 |     ]
144 | }
145 | 
146 | import pickle
147 | pickle.dump(data_index, open('path/to/data_index', 'wb'))
148 | ```
149 | The mask should be an 8-bit image in which a pixel value of 255 indicates the region of interest.
150 | 
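For reference, a minimal editorial sketch (not part of the repository) of one way to assemble such an index, assuming images and masks are stored in parallel `train`/`val`/`test` folders whose files share the same stem; all paths and names below are placeholders:

```python
import pickle
from pathlib import Path


def build_seg_index(image_root, mask_root, splits=('train', 'val', 'test')):
    data_index = {}
    for split in splits:
        # Pair images and masks by sorted file stem, mirroring utils/folder2pkl.py.
        images = sorted(Path(image_root, split).glob('*.*'), key=lambda p: p.stem)
        masks = sorted(Path(mask_root, split).glob('*.*'), key=lambda p: p.stem)
        assert len(images) == len(masks), 'number of images and masks must match'
        for img, msk in zip(images, masks):
            assert img.stem == msk.stem, 'image {} and mask {} do not match'.format(img.stem, msk.stem)
        data_index[split] = [(str(img), str(msk)) for img, msk in zip(images, masks)]
    return data_index


if __name__ == '__main__':
    index = build_seg_index('<path/to/DRIVE/images>', '<path/to/DRIVE/masks>')
    with open('./data_index/drive_seg.pkl', 'wb') as f:
        pickle.dump(index, f)
```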
151 | 2\. Perform segmentation on the DRIVE dataset:
152 | ```shell
153 | python eval_seg.py \
154 |     --dataset drive --arch ViT-S-p16 \
155 |     --data-index <path/to/segmentation/data_index> \
156 |     --checkpoint <path/to/pretrained/checkpoint>
157 | ```
158 | 
159 | 
160 | ### Visualization
161 | To visualize self-attention maps from a folder of images:
162 | ```shell
163 | cd utils
164 | python attn_visualize.py \
165 |     --arch ViT-S-p16 --image-size 1024 \
166 |     --image-folder <path/to/image/folder> \
167 |     --checkpoint <path/to/pretrained/checkpoint> \
168 |     --output-dir <path/to/output/folder>
169 | cd ..
170 | ```
171 | 
172 | 
173 | 
174 | ## Citation
175 | 
176 | If you find this repository useful, please cite the paper:
177 | 
178 | ```
179 | @article{huang2024ssit,
180 |   title={SSiT: Saliency-guided self-supervised image transformer for diabetic retinopathy grading},
181 |   author={Huang, Yijin and Lyu, Junyan and Cheng, Pujin and Tam, Roger and Tang, Xiaoying},
182 |   journal={IEEE Journal of Biomedical and Health Informatics},
183 |   year={2024},
184 |   publisher={IEEE}
185 | }
186 | ```
187 | 
188 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import shutil
4 | import argparse
5 | import builtins
6 | 
7 | import torch
8 | import numpy as np
9 | import torch.distributed as dist
10 | import torch.multiprocessing as mp
11 | from torch.utils.tensorboard import SummaryWriter
12 | 
13 | from train import train
14 | from ssit import build_model
15 | from data import build_dataset
16 | from funcs import print_config, print_msg, is_main
17 | 
18 | 
19 | parser = argparse.ArgumentParser()
20 | # base setting
21 | parser.add_argument('--arch', '--architecture', type=str, default='ViT-S-p16', help='network architecture, should be a key of archs in vits.py')
22 | parser.add_argument('--data-index', type=str, default='./data_index/pretraining_dataset.pkl', help='pickle file of (image, saliency map) pairs for pretraining')
23 | parser.add_argument('--save-path', type=str, default='./checkpoints', help='path to save checkpoints')
24 | parser.add_argument('--record-path', type=str, default=None, help='path to save log')
25 | parser.add_argument('--pretrained', action='store_true', help='load parameters pretrained on ImageNet')
26 | parser.add_argument('--device', type=str, default='cuda', help='only support cuda')
27 | parser.add_argument('--seed', type=int, default=-1, help='random seed for reproducibility. 
Set to -1 to disable.') 28 | parser.add_argument('--resume', action='store_true', help='resume training from the latest checkpoint') 29 | 30 | # DDP setting 31 | parser.add_argument('--distributed', action='store_true', help='distributed training') 32 | parser.add_argument('--backend', type=str, default='nccl', help='distributed backend') 33 | parser.add_argument('--nodes', type=int, default=1, help='number of nodes for distributed training') 34 | parser.add_argument('--n-gpus', type=int, default=None, help='number of gpus per node') 35 | parser.add_argument('--addr', type=str, default='127.0.0.1', help='master address') 36 | parser.add_argument('--port', type=str, default='28888', help='master port') 37 | parser.add_argument('--rank', type=int, default=0, help='rank of current process') 38 | 39 | # training setting 40 | parser.add_argument('--input-size', type=int, default=224, help='input size') 41 | parser.add_argument('--start-epoch', type=int, default=0, help='start epoch for training') 42 | parser.add_argument('--epochs', type=int, default=300, help='total training epochs') 43 | parser.add_argument('--warmup-epochs', type=int, default=40, help='number of warmup epochs') 44 | parser.add_argument('--mask-ratio', type=float, default=0.25, help='ratio of masked pixels') 45 | parser.add_argument('--disable-progress', action='store_true', help='do not show progress bar') 46 | parser.add_argument('--ss', type=float, default=10, help='weight of saliency segmentation loss') 47 | parser.add_argument('--ss-decay', action='store_true', help='cosine decay weight of saliency segmentation loss') 48 | parser.add_argument('--cl', type=float, default=1, help='weight of contrastive learning loss') 49 | parser.add_argument('--saliency-threshold', type=float, default=0.5, help='threshold for saliency map') 50 | parser.add_argument('--batch-size', type=int, default=512, help='total training batch size') 51 | parser.add_argument('--optimizer', type=str, default='ADAMW', help='SGD / ADAM / ADAMW') 52 | parser.add_argument('--moco-m', type=float, default=0.99, help='momentum for moco') 53 | parser.add_argument('--temperature', type=float, default=0.2, help='temperature for moco') 54 | parser.add_argument('--learning-rate', type=float, default=0.001, help='initial learning rate') 55 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum for SGD optimizer') 56 | parser.add_argument('--weight-decay', type=float, default=0.1, help='weight decay for SGD and ADAM') 57 | parser.add_argument('--num-workers', type=int, default=32, help='total number of workers') 58 | parser.add_argument('--save-interval', type=int, default=20, help='number of interval to store model and checkpoint') 59 | parser.add_argument('--pool-mode', type=str, default='max', help="'max' / 'avg', pooling mode for saliency map patch") 60 | parser.add_argument('--dataset-ratio', type=float, default=1.0, help='ratio of dataset for pre-training') 61 | 62 | 63 | def main(): 64 | # print configuration 65 | args = parser.parse_args() 66 | print_config(args) 67 | 68 | # create folder 69 | save_path = args.save_path 70 | os.makedirs(save_path, exist_ok=True) 71 | 72 | # create logger 73 | if args.record_path is None: 74 | args.record_path = os.path.join(save_path, 'log') 75 | 76 | n_gpus = args.n_gpus if args.n_gpus else torch.cuda.device_count() 77 | if not n_gpus: 78 | raise NotImplementedError('No GPU found. 
Only GPU training is supported.') 79 | 80 | if args.distributed: 81 | print_msg('Distributed mode with {} GPUs'.format(n_gpus)) 82 | args.world_size = n_gpus * args.nodes 83 | os.environ['MASTER_ADDR'] = args.addr 84 | os.environ['MASTER_PORT'] = args.port 85 | mp.spawn(worker, nprocs=n_gpus, args=(n_gpus, args)) 86 | else: 87 | print_msg('Single GPU mode') 88 | worker(0, n_gpus, args) 89 | 90 | 91 | def worker(gpu, n_gpus, args): 92 | if args.distributed: 93 | torch.cuda.set_device(gpu) 94 | args.gpu = gpu 95 | args.rank = args.rank * n_gpus + gpu 96 | dist.init_process_group( 97 | backend=args.backend, 98 | init_method='env://', 99 | world_size=args.world_size, 100 | rank=args.rank 101 | ) 102 | torch.distributed.barrier() 103 | 104 | args.batch_size = int(args.batch_size / args.world_size) 105 | args.num_workers = int((args.num_workers + n_gpus - 1) / n_gpus) 106 | 107 | # suppress printing 108 | if args.gpu != 0 or args.rank != 0: 109 | def print_pass(*args): 110 | pass 111 | builtins.print = print_pass 112 | 113 | if args.seed >= 0: 114 | set_random_seed(args.seed + args.rank) 115 | 116 | model = build_model(args) 117 | train_dataset = build_dataset(args) 118 | logger = SummaryWriter(args.record_path) if is_main(args) else None 119 | scaler = torch.cuda.amp.GradScaler() 120 | 121 | train( 122 | args=args, 123 | model=model, 124 | train_dataset=train_dataset, 125 | logger=logger, 126 | scaler=scaler 127 | ) 128 | torch.distributed.barrier() 129 | 130 | 131 | def set_random_seed(seed): 132 | random.seed(seed) 133 | np.random.seed(seed) 134 | torch.manual_seed(seed) 135 | torch.cuda.manual_seed(seed) 136 | # torch.backends.cudnn.deterministic = True 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /ssit.py: -------------------------------------------------------------------------------- 1 | # ===================================================================== 2 | # Based on moco-v3/moco/builder.py 3 | # https://github.com/facebookresearch/moco-v3/blob/main/moco/builder.py 4 | # ===================================================================== 5 | from functools import partial 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from vits import archs 12 | 13 | 14 | def build_model(args): 15 | assert args.arch in archs.keys(), 'Not implemented architecture.' 
16 | encoder = partial( 17 | archs[args.arch], 18 | pretrained=args.pretrained, 19 | img_size=args.input_size, 20 | mask_ratio=args.mask_ratio, 21 | ) 22 | 23 | model = SSiT( 24 | encoder, 25 | dim=256, 26 | mlp_dim=4096, 27 | T=args.temperature, 28 | pool_mode=args.pool_mode, 29 | saliency_threshold=args.saliency_threshold, 30 | ) 31 | 32 | if args.distributed: 33 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 34 | model = model.cuda(args.gpu) 35 | model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], gradient_as_bucket_view=True) 36 | else: 37 | model = model.to(args.device) 38 | 39 | return model 40 | 41 | 42 | class SSiT(nn.Module): 43 | def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0, pool_mode='max', saliency_threshold=0.25): 44 | super(SSiT, self).__init__() 45 | 46 | self.T = T 47 | self.saliency_threshold = saliency_threshold 48 | 49 | # build encoders 50 | self.base_encoder = base_encoder(num_classes=mlp_dim) 51 | self.momentum_encoder = base_encoder(num_classes=mlp_dim) 52 | 53 | patch_size = self.base_encoder.patch_size 54 | if pool_mode == 'avg': 55 | self.pool = nn.AvgPool2d(kernel_size=patch_size, stride=patch_size) 56 | elif pool_mode == 'max': 57 | self.pool = nn.MaxPool2d(kernel_size=patch_size, stride=patch_size) 58 | else: 59 | self.pool = None 60 | 61 | self.build_saliency_segmentor_mlps(mlp_dim, patch_size) 62 | self.build_projector_and_predictor_mlps(dim, mlp_dim) 63 | 64 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): 65 | param_m.data.copy_(param_b.data) # initialize 66 | param_m.requires_grad = False # not update by gradient 67 | 68 | def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True): 69 | mlp = [] 70 | for l in range(num_layers): 71 | dim1 = input_dim if l == 0 else mlp_dim 72 | dim2 = output_dim if l == num_layers - 1 else mlp_dim 73 | 74 | mlp.append(nn.Linear(dim1, dim2, bias=False)) 75 | 76 | if l < num_layers - 1: 77 | mlp.append(nn.BatchNorm1d(dim2)) 78 | mlp.append(nn.ReLU(inplace=True)) 79 | elif last_bn: 80 | # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157 81 | # for simplicity, we further removed gamma in BN 82 | mlp.append(nn.BatchNorm1d(dim2, affine=False)) 83 | 84 | return nn.Sequential(*mlp) 85 | 86 | def build_projector_and_predictor_mlps(self, dim, mlp_dim): 87 | hidden_dim = self.base_encoder.head.weight.shape[1] 88 | del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer 89 | 90 | # projectors 91 | self.base_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim) 92 | self.momentum_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim) 93 | 94 | # predictor 95 | self.predictor = self._build_mlp(2, dim, mlp_dim, dim) 96 | 97 | def build_saliency_segmentor_mlps(self, mlp_dim, patch_size): 98 | hidden_dim = self.base_encoder.head.weight.shape[1] 99 | self.saliency_segmentor = nn.Sequential( 100 | nn.Conv2d( 101 | in_channels=hidden_dim, 102 | out_channels=patch_size ** 2, 103 | kernel_size=1, 104 | ), 105 | nn.PixelShuffle(upscale_factor=patch_size), 106 | ) 107 | 108 | @torch.no_grad() 109 | def _update_momentum_encoder(self, m): 110 | """Momentum update of the momentum encoder""" 111 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()): 112 | param_m.data = param_m.data * m + param_b.data * (1. 
- m) 113 | 114 | def contrastive_loss(self, q, k): 115 | # normalize 116 | q = nn.functional.normalize(q, dim=1) 117 | k = nn.functional.normalize(k, dim=1) 118 | k = concat_all_gather(k) 119 | 120 | # Einstein sum is more intuitive 121 | logits = torch.einsum('nc,mc->nm', [q, k]) / self.T 122 | N = logits.shape[0] # batch size per GPU 123 | rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 124 | labels = (torch.arange(N, dtype=torch.long) + N * rank).cuda() 125 | return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T) 126 | 127 | def saliency_segmentation_loss(self, f, m): 128 | f = f[:, 1:] 129 | m = (m > self.saliency_threshold).float() 130 | 131 | B, L, C = f.shape 132 | H = W = int(L ** 0.5) 133 | f = f.permute(0, 2, 1).reshape(B, C, H, W) 134 | ss = self.saliency_segmentor(f) 135 | 136 | bce = F.binary_cross_entropy_with_logits(ss, m) 137 | return bce 138 | 139 | def forward(self, x1, x2, m1, m2, m): 140 | mp1 = None if self.pool is None else self.pool(m1) 141 | mp2 = None if self.pool is None else self.pool(m2) 142 | 143 | # compute features 144 | t1, f1 = self.base_encoder(x1) 145 | t2, f2 = self.base_encoder(x2) 146 | q1 = self.predictor(t1) 147 | q2 = self.predictor(t2) 148 | 149 | with torch.no_grad(): # no gradient 150 | self._update_momentum_encoder(m) # update the momentum encoder 151 | 152 | # compute momentum features as targets 153 | k1, _ = self.momentum_encoder(x1, mp1) 154 | k2, _ = self.momentum_encoder(x2, mp2) 155 | 156 | cl_loss = self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1) 157 | sp_loss = self.saliency_segmentation_loss(f1, m1) + self.saliency_segmentation_loss(f2, m2) 158 | return cl_loss, sp_loss 159 | 160 | 161 | @torch.no_grad() 162 | def concat_all_gather(tensor): 163 | if not torch.distributed.is_initialized(): 164 | return tensor 165 | 166 | tensors_gather = [torch.ones_like(tensor) 167 | for _ in range(torch.distributed.get_world_size())] 168 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 169 | 170 | output = torch.cat(tensors_gather, dim=0) 171 | return output 172 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | from torch.utils.data import DataLoader 9 | from torch.distributed import all_reduce, ReduceOp 10 | 11 | from funcs import is_main, to_devices, print_msg 12 | 13 | 14 | def train(args, model, train_dataset, logger=None, scaler=None): 15 | optimizer = initialize_optimizer(args, model) 16 | train_sampler = initialize_sampler(args, train_dataset) if args.distributed == True else None 17 | train_loader = initialize_dataloader(args, train_dataset, train_sampler) 18 | 19 | if args.resume: 20 | resume(args, model, optimizer, scaler) 21 | 22 | # start training 23 | model.train() 24 | avg_cl_loss = 0 25 | avg_ss_loss = 0 26 | for epoch in range(args.start_epoch, args.epochs): 27 | if args.distributed: 28 | train_sampler.set_epoch(epoch) 29 | 30 | epoch_cl_loss = 0 31 | epoch_ss_loss = 0 32 | current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) 33 | progress = enumerate(train_loader) 34 | if is_main(args) and not args.disable_progress: 35 | progress = tqdm(progress) 36 | for step, train_data in progress: 37 | scheduler_step = epoch + step / len(train_loader) 38 | lr = adjust_learning_rate(args, optimizer, 
scheduler_step) 39 | moco_m = adjust_moco_momentum(args, scheduler_step) 40 | ss = adjust_lambda_ss(args, scheduler_step) if args.ss_decay else args.ss 41 | 42 | X1, X2, M1, M2 = train_data 43 | X1, X2, M1, M2 = to_devices(args, X1, X2, M1, M2) 44 | 45 | # forward 46 | with torch.cuda.amp.autocast(True): 47 | cl_loss, ss_loss = model(X1, X2, M1, M2, moco_m) 48 | loss = args.cl * cl_loss + ss * ss_loss 49 | 50 | # backward 51 | optimizer.zero_grad() 52 | scaler.scale(loss).backward() 53 | 54 | scaler.unscale_(optimizer) 55 | nn.utils.clip_grad_norm_(model.parameters(), 1) 56 | 57 | scaler.step(optimizer) 58 | scaler.update() 59 | 60 | if args.distributed: 61 | all_reduce(cl_loss, ReduceOp.AVG) 62 | all_reduce(ss_loss, ReduceOp.AVG) 63 | 64 | # metrics 65 | if is_main(args): 66 | epoch_cl_loss += cl_loss 67 | epoch_ss_loss += ss_loss 68 | avg_cl_loss = epoch_cl_loss / (step + 1) 69 | avg_ss_loss = epoch_ss_loss / (step + 1) 70 | 71 | message = '[{}] epoch: {}/{}, cl loss: {:.6f}, ss loss: {:.6f}, lr: {:.6f}, moco_m: {:.6f}'.format( 72 | current_time, epoch + 1, args.epochs, avg_cl_loss, avg_ss_loss, lr, moco_m) 73 | if not args.disable_progress: 74 | progress.set_description(message) 75 | 76 | if is_main(args) and args.disable_progress: 77 | print(message) 78 | 79 | if is_main(args) and (epoch + 1) % args.save_interval == 0 and (epoch + 1) < args.epochs: 80 | save_checkpoint(args, epoch, model, optimizer, scaler) 81 | 82 | # record 83 | if is_main(args) and logger: 84 | logger.add_scalar('contrastive loss', avg_cl_loss, epoch) 85 | logger.add_scalar('saliency segmentation loss', avg_ss_loss, epoch) 86 | logger.add_scalar('learning rate', lr, epoch) 87 | logger.add_scalar('moco momentum', moco_m, epoch) 88 | 89 | # save final model 90 | if is_main(args): 91 | save_checkpoint(args, epoch, model, optimizer, scaler) 92 | if logger: 93 | logger.close() 94 | 95 | 96 | # define data loader 97 | def initialize_dataloader(args, train_dataset, train_sampler): 98 | batch_size = args.batch_size 99 | num_workers = args.num_workers 100 | train_loader = DataLoader( 101 | train_dataset, 102 | batch_size=batch_size, 103 | shuffle=(train_sampler is None), 104 | sampler=train_sampler, 105 | num_workers=num_workers, 106 | drop_last=True, 107 | pin_memory=True 108 | ) 109 | return train_loader 110 | 111 | 112 | # define optmizer 113 | def initialize_optimizer(args, model): 114 | optimizer_strategy = args.optimizer 115 | learning_rate = args.learning_rate 116 | weight_decay = args.weight_decay 117 | momentum = args.momentum 118 | 119 | if optimizer_strategy == 'SGD': 120 | optimizer = torch.optim.SGD( 121 | model.parameters(), 122 | lr=learning_rate, 123 | momentum=momentum, 124 | weight_decay=weight_decay 125 | ) 126 | elif optimizer_strategy == 'ADAM': 127 | optimizer = torch.optim.Adam( 128 | model.parameters(), 129 | lr=learning_rate, 130 | weight_decay=weight_decay 131 | ) 132 | elif optimizer_strategy == 'ADAMW': 133 | optimizer = torch.optim.AdamW( 134 | model.parameters(), 135 | lr=learning_rate, 136 | weight_decay=weight_decay 137 | ) 138 | else: 139 | raise NotImplementedError('Not implemented optimizer.') 140 | 141 | return optimizer 142 | 143 | 144 | def initialize_sampler(args, train_dataset): 145 | train_sampler = torch.utils.data.distributed.DistributedSampler( 146 | train_dataset, 147 | num_replicas=args.world_size, 148 | rank=args.rank 149 | ) 150 | return train_sampler 151 | 152 | 153 | def adjust_learning_rate(args, optimizer, epoch): 154 | """Decays the learning rate with half-cycle cosine 
after warmup"""
155 |     if epoch < args.warmup_epochs:
156 |         lr = args.learning_rate * epoch / args.warmup_epochs
157 |     else:
158 |         lr = args.learning_rate * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
159 |     for param_group in optimizer.param_groups:
160 |         param_group['lr'] = lr
161 |     return lr
162 | 
163 | 
164 | def adjust_moco_momentum(args, epoch):
165 |     """Adjust moco momentum based on current epoch"""
166 |     m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.moco_m)
167 |     return m
168 | 
169 | 
170 | def adjust_lambda_ss(args, epoch):
171 |     """Cosine decay of the saliency segmentation loss weight based on current epoch"""
172 |     ss = args.ss * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
173 |     return ss
174 | 
175 | 
176 | def save_checkpoint(args, epoch, model, optimizer, scaler):
177 |     checkpoint = {
178 |         'epoch': epoch,
179 |         'state_dict': model.state_dict(),
180 |         'optimizer' : optimizer.state_dict(),
181 |         'scaler': scaler.state_dict(),
182 |     }
183 |     model = model.module if args.distributed else model
184 | 
185 |     torch.save(checkpoint, os.path.join(args.save_path, 'checkpoint.pt'))
186 |     torch.save(model, os.path.join(args.save_path, 'epoch_{}.pt'.format(epoch + 1)))
187 |     print_msg('Saved checkpoint to {}'.format(args.save_path))
188 | 
189 | 
190 | def resume(args, model, optimizer, scaler):
191 |     checkpoint_path = os.path.join(args.save_path, 'checkpoint.pt')
192 |     if os.path.exists(checkpoint_path):
193 |         print_msg('Loading checkpoint {}'.format(checkpoint_path))
194 |         checkpoint = torch.load(checkpoint_path, map_location='cpu')
195 |         args.start_epoch = checkpoint['epoch'] + 1
196 |         model.load_state_dict(checkpoint['state_dict'])
197 |         optimizer.load_state_dict(checkpoint['optimizer'])
198 |         scaler.load_state_dict(checkpoint['scaler'])
199 |         print_msg('Loaded checkpoint {} from epoch {}'.format(checkpoint_path, checkpoint['epoch']))
200 |     else:
201 |         print_msg('No checkpoint found at {}'.format(checkpoint_path))
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import random
3 | 
4 | import numpy as np
5 | from torchvision import transforms
6 | from torch.utils.data import Dataset
7 | from PIL import Image, ImageFilter, ImageOps
8 | from torchvision.transforms import functional as F
9 | 
10 | from funcs import print_msg
11 | 
12 | 
13 | def build_dataset(args):
14 |     transform = data_transforms(args.input_size)
15 |     datasets = generate_dataset_from_pickle(args.data_index, transform, args.dataset_ratio)
16 |     return datasets
17 | 
18 | 
19 | def generate_dataset_from_pickle(pkl, transform, ratio=1.0):
20 |     data = pickle.load(open(pkl, 'rb'))
21 |     if ratio < 1.0:
22 |         random.shuffle(data)
23 |         data = data[:int(len(data)*ratio)]
24 |     print_msg('Number of training samples: {}'.format(len(data)))
25 | 
26 |     train_dataset = PairGenerator(data, transform)
27 |     return train_dataset
28 | 
29 | 
30 | def data_transforms(input_size):
31 |     mean = [0.425753653049469, 0.29737451672554016, 0.21293757855892181] # eyepacs mean
32 |     std = [0.27670302987098694, 0.20240527391433716, 0.1686241775751114] # eyepacs std
33 |     data_aug = {
34 |         'brightness': 0.4,
35 |         'contrast': 0.4,
36 |         'saturation': 0.2,
37 |         'hue': 0.1,
38 |         'scale_stu': (0.08, 0.8),
39 |         'scale_tea': (0.8, 1.0),
40 |         'degrees': (-180, 180),
41 |     }
42 | 
43 |     transform = TransformWithMask(input_size, mean, std, data_aug)
44 |     return transform
45 | 
46 | 
47 | 
class PairGenerator(Dataset): 48 | def __init__(self, imgs, transform=None): 49 | super(PairGenerator, self).__init__() 50 | self.imgs = imgs 51 | self.transform = transform 52 | 53 | def __len__(self): 54 | return len(self.imgs) 55 | 56 | def __getitem__(self, index): 57 | img_path, mask_path = self.imgs[index] 58 | img = self.pil_loader(img_path) 59 | mask = self.npy_loader(mask_path) 60 | mask = Image.fromarray(np.uint8(mask*255)) 61 | if self.transform is not None: 62 | img_stu, img_tea, mask_stu, mask_tea = self.transform(img, mask) 63 | 64 | return img_stu, img_tea, mask_stu, mask_tea 65 | 66 | def pil_loader(self, path): 67 | with open(path, 'rb') as f: 68 | img = Image.open(f) 69 | return img.convert('RGB') 70 | 71 | def npy_loader(self, path): 72 | with open(path, 'rb') as f: 73 | img = np.load(f) 74 | return img 75 | 76 | 77 | class TwoCropTransform(): 78 | def __init__(self, transform): 79 | self.transform = transform 80 | 81 | def __call__(self, x): 82 | return [self.transform(x), self.transform(x)] 83 | 84 | 85 | class BYOLTransform(): 86 | def __init__(self, transform_stu, transform_tea): 87 | self.transform_stu = transform_stu 88 | self.transform_tea = transform_tea 89 | 90 | def __call__(self, x1, x2): 91 | return [self.transform_stu(x1), self.transform_tea(x2)] 92 | 93 | 94 | class GaussianBlur(object): 95 | """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709""" 96 | 97 | def __init__(self, sigma=[.1, 2.]): 98 | self.sigma = sigma 99 | 100 | def __call__(self, x): 101 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 102 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 103 | return x 104 | 105 | 106 | class Solarize(object): 107 | """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733""" 108 | 109 | def __call__(self, x): 110 | return ImageOps.solarize(x) 111 | 112 | 113 | class TransformWithMask(object): 114 | def __init__(self, input_size, mean, std, data_aug): 115 | scale_stu = data_aug['scale_stu'] 116 | scale_tea = data_aug['scale_tea'] 117 | jitter_param = (data_aug['brightness'], data_aug['contrast'], data_aug['saturation'], data_aug['hue']) 118 | degree = data_aug['degrees'] 119 | 120 | self.resized_crop_stu = transforms.RandomResizedCrop(input_size, scale=scale_stu) 121 | self.color_jitter_stu = transforms.RandomApply([transforms.ColorJitter(*jitter_param)], p=0.8) 122 | self.grayscale_stu = transforms.RandomGrayscale(p=0.2) 123 | self.gaussian_blur_stu = transforms.RandomApply([GaussianBlur([.1, 2.])], p=1.0) 124 | self.rotation_stu = transforms.RandomRotation(degree) 125 | self.p_rotation_stu = 0.8 126 | self.p_hflip_stu = 0.5 127 | self.p_vflip_stu = 0.5 128 | 129 | self.resized_crop_tea = transforms.RandomResizedCrop(input_size, scale=scale_tea) 130 | self.color_jitter_tea = transforms.RandomApply([transforms.ColorJitter(*jitter_param)], p=0.8) 131 | self.grayscale_tea = transforms.RandomGrayscale(p=0.2) 132 | self.gaussian_blur_tea = transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.1) 133 | self.rotation_tea = transforms.RandomRotation(degree) 134 | self.solarize_tea = transforms.RandomApply([Solarize()], p=0.2) 135 | self.p_rotation_tea = 0.8 136 | self.p_hflip_tea = 0.5 137 | self.p_vflip_tea = 0.5 138 | 139 | self.to_tensor = transforms.ToTensor() 140 | self.normalize = transforms.Normalize(mean, std) 141 | 142 | def __call__(self, img, mask): 143 | img_stu, mask_stu = self.resized_crop_with_mask(self.resized_crop_stu, img, mask) 144 | img_stu = self.color_jitter_stu(img_stu) 145 | img_stu = 
self.grayscale_stu(img_stu) 146 | img_stu = self.gaussian_blur_stu(img_stu) 147 | img_stu, mask_stu = self.rotation_with_mask(self.rotation_stu, img_stu, mask_stu, self.p_rotation_stu) 148 | img_stu, mask_stu = self.horizontal_flip_with_mask(img_stu, mask_stu, self.p_hflip_stu) 149 | img_stu, mask_stu = self.vertical_flip_with_mask(img_stu, mask_stu, self.p_vflip_stu) 150 | img_stu, mask_stu = self.to_tensor(img_stu), self.to_tensor(mask_stu) 151 | img_stu = self.normalize(img_stu) 152 | 153 | img_tea, mask_tea = self.resized_crop_with_mask(self.resized_crop_tea, img, mask) 154 | img_tea = self.color_jitter_tea(img_tea) 155 | img_tea = self.grayscale_tea(img_tea) 156 | img_tea = self.gaussian_blur_tea(img_tea) 157 | img_tea = self.solarize_tea(img_tea) 158 | img_tea, mask_tea = self.rotation_with_mask(self.rotation_tea, img_tea, mask_tea, self.p_rotation_tea) 159 | img_tea, mask_tea = self.horizontal_flip_with_mask(img_tea, mask_tea, self.p_hflip_tea) 160 | img_tea, mask_tea = self.vertical_flip_with_mask(img_tea, mask_tea, self.p_vflip_tea) 161 | img_tea, mask_tea = self.to_tensor(img_tea), self.to_tensor(mask_tea) 162 | img_tea = self.normalize(img_tea) 163 | 164 | return img_stu, img_tea, mask_stu, mask_tea 165 | 166 | def resized_crop_with_mask(self, tf, img, mask): 167 | assert isinstance(tf, transforms.RandomResizedCrop) 168 | i, j, h, w = tf.get_params(img, tf.scale, tf.ratio) 169 | img = F.resized_crop(img, i, j, h, w, tf.size, tf.interpolation) 170 | mask = F.resized_crop(mask, i, j, h, w, tf.size, tf.interpolation) 171 | return img, mask 172 | 173 | def rotation_with_mask(self, tf, img, mask, p): 174 | assert isinstance(tf, transforms.RandomRotation) 175 | if random.random() < p: 176 | angle = tf.get_params(tf.degrees) 177 | img = F.rotate(img, angle, tf.resample, tf.expand, tf.center, tf.fill) 178 | mask = F.rotate(mask, angle, tf.resample, tf.expand, tf.center, tf.fill) 179 | return img, mask 180 | 181 | def horizontal_flip_with_mask(self, img, mask, p): 182 | if random.random() < p: 183 | img = F.hflip(img) 184 | mask = F.hflip(mask) 185 | return img, mask 186 | 187 | def vertical_flip_with_mask(self, img, mask, p): 188 | if random.random() < p: 189 | img = F.vflip(img) 190 | mask = F.vflip(mask) 191 | return img, mask 192 | -------------------------------------------------------------------------------- /utils/attn_visualize.py: -------------------------------------------------------------------------------- 1 | # ================================================================================== 2 | # Based on https://github.com/facebookresearch/dino/blob/main/visualize_attention.py 3 | # ================================================================================== 4 | import os 5 | import sys 6 | import cv2 7 | import random 8 | import argparse 9 | import colorsys 10 | 11 | import torch 12 | import skimage.io 13 | import torchvision 14 | import numpy as np 15 | import torch.nn as nn 16 | import matplotlib.pyplot as plt 17 | from PIL import Image, ImageFilter 18 | from matplotlib.patches import Polygon 19 | from skimage.measure import find_contours 20 | from torchvision import transforms as pth_transforms 21 | 22 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) 23 | from vits import archs, resize_pos_embed 24 | from funcs import load_checkpoint 25 | 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--arch', type=str, default='ViT-S-p16', help='Architecture (support only ViT).') 29 | parser.add_argument('--patch-size', default=16, 
type=int, help='Patch resolution of the model.') 30 | parser.add_argument('--checkpoint', default='', type=str, 31 | help="Path to pretrained weights to load.") 32 | parser.add_argument("--checkpoint-key", default="base_encoder", type=str, 33 | help='Key to use in the checkpoint (example: "teacher")') 34 | parser.add_argument("--image-folder", default=None, type=str, help="Path of the image to load.") 35 | parser.add_argument("--image-size", default=1024, type=int, nargs="+", help="Resize image.") 36 | parser.add_argument('--output-dir', default='.', help='Path where to save visualizations.') 37 | parser.add_argument("--threshold", type=float, default=None, help="""We visualize masks 38 | obtained by thresholding the self-attention maps to keep xx% of the mass.""") 39 | 40 | 41 | def apply_mask(image, mask, color, alpha=0.5): 42 | for c in range(3): 43 | image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255 44 | return image 45 | 46 | 47 | def random_colors(N, bright=True): 48 | """ 49 | Generate random colors. 50 | """ 51 | brightness = 1.0 if bright else 0.7 52 | hsv = [(i / N, 1, brightness) for i in range(N)] 53 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 54 | random.shuffle(colors) 55 | return colors 56 | 57 | 58 | def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5): 59 | fig = plt.figure(figsize=figsize, frameon=False) 60 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 61 | ax.set_axis_off() 62 | fig.add_axes(ax) 63 | ax = plt.gca() 64 | 65 | N = 1 66 | mask = mask[None, :, :] 67 | # Generate random colors 68 | colors = random_colors(N) 69 | 70 | # Show area outside image boundaries. 71 | height, width = image.shape[:2] 72 | margin = 0 73 | ax.set_ylim(height + margin, -margin) 74 | ax.set_xlim(-margin, width + margin) 75 | ax.axis('off') 76 | masked_image = image.astype(np.uint32).copy() 77 | for i in range(N): 78 | color = colors[i] 79 | _mask = mask[i] 80 | if blur: 81 | _mask = cv2.blur(_mask,(10,10)) 82 | # Mask 83 | masked_image = apply_mask(masked_image, _mask, color, alpha) 84 | # Mask Polygon 85 | # Pad to ensure proper polygons for masks that touch image edges. 
86 |         if contour:
87 |             padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
88 |             padded_mask[1:-1, 1:-1] = _mask
89 |             contours = find_contours(padded_mask, 0.5)
90 |             for verts in contours:
91 |                 # Subtract the padding and flip (y, x) to (x, y)
92 |                 verts = np.fliplr(verts) - 1
93 |                 p = Polygon(verts, facecolor="none", edgecolor=color)
94 |                 ax.add_patch(p)
95 |     ax.imshow(masked_image.astype(np.uint8), aspect='auto')
96 |     fig.savefig(fname)
97 |     print(f"{fname} saved.")
98 |     return
99 | 
100 | 
101 | def main():
102 |     args = parser.parse_args()
103 | 
104 |     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
105 |     # build model
106 | 
107 |     model = archs[args.arch](
108 |         pretrained=False,
109 |         num_classes=1,
110 |         img_size=args.image_size,
111 |     )
112 | 
113 |     linear_key = 'head'
114 |     checkpoint_key = args.checkpoint_key
115 |     load_checkpoint(model, args.checkpoint, checkpoint_key, linear_key)
116 | 
117 |     for p in model.parameters():
118 |         p.requires_grad = False
119 |     model.eval()
120 |     model.to(device)
121 | 
122 |     if isinstance(args.image_size, list):
123 |         if len(args.image_size) == 1:
124 |             args.image_size = (args.image_size[0], args.image_size[0])
125 |         elif len(args.image_size) == 2:
126 |             args.image_size = tuple(args.image_size)
127 |         else:
128 |             raise ValueError("image_size list must have one or two elements")
129 | 
130 |     # open image (normalization statistics must be defined before the type checks below)
131 |     mean = [0.425753653049469, 0.29737451672554016, 0.21293757855892181]
132 |     std = [0.27670302987098694, 0.20240527391433716, 0.1686241775751114]
133 | 
134 |     # Ensure correct types
135 |     assert isinstance(args.image_size, (int, tuple)), "image_size must be an integer or a tuple of integers"
136 |     assert all(isinstance(m, float) for m in mean), "mean values must be floats"
137 |     assert all(isinstance(s, float) for s in std), "std values must be floats"
138 |     transform = pth_transforms.Compose([
139 |         pth_transforms.Resize(args.image_size),
140 |         pth_transforms.ToTensor(),
141 |         pth_transforms.Normalize(mean, std),
142 |     ])
143 |     for root, _, imgs in os.walk(args.image_folder):
144 |         for img in imgs:
145 |             name = img.split('.')[0]
146 |             out_dir = os.path.join(args.output_dir, name)
147 |             img_path = os.path.join(root, img)
148 | 
149 |             with open(img_path, 'rb') as f:
150 |                 raw_img = Image.open(f)
151 |                 raw_img = raw_img.convert('RGB')
152 | 
153 |             img = transform(raw_img)
154 |             H, W = img.shape[1:]
155 | 
156 |             # make the image divisible by the patch size
157 |             w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size
158 |             img = img[:, :w, :h].unsqueeze(0)
159 | 
160 |             w_featmap = img.shape[-2] // args.patch_size
161 |             h_featmap = img.shape[-1] // args.patch_size
162 | 
163 |             _, f = model(img.to(device))
164 |             f = f[:, 1:]
165 |             B, L, C = f.shape
166 |             H = W = int(L ** 0.5)
167 |             f = f.permute(0, 2, 1).reshape(B, C, H, W)
168 | 
169 |             attentions = model.get_last_selfattention(img.to(device))
170 |             nh = attentions.shape[1] # number of head
171 | 
172 |             # we keep only the output patch attention
173 |             attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
174 |             attentions = torch.cat([attentions, attentions.mean(dim=0, keepdim=True)], dim=0)
175 |             nh += 1
176 | 
177 |             if args.threshold is not None:
178 |                 # we keep only a certain percentage of the mass
179 |                 val, idx = torch.sort(attentions)
180 |                 val /= torch.sum(val, dim=1, keepdim=True)
181 |                 cumval = torch.cumsum(val, dim=1)
182 |                 th_attn = cumval > (1 - args.threshold)
183 |                 idx2 = torch.argsort(idx)
184 |                 for head in range(nh):
185 | 
th_attn[head] = th_attn[head][idx2[head]] 186 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() 187 | # interpolate 188 | th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy() 189 | 190 | attentions = attentions.reshape(nh, w_featmap, h_featmap) 191 | attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy() 192 | 193 | # save attentions heatmaps 194 | os.makedirs(out_dir, exist_ok=True) 195 | torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(out_dir, "img.png")) 196 | for j in range(nh): 197 | fname = os.path.join(out_dir, "attn-head" + str(j) + ".png") 198 | plt.imsave(fname=fname, arr=attentions[j], format='png') 199 | print(f"{fname} saved.") 200 | 201 | normalized_attn = [] 202 | for j in range(nh-1): 203 | normalized_attn.append((attentions[j] - attentions[j].min()) / (attentions[j].max() - attentions[j].min())) 204 | normalized_attn = np.stack(normalized_attn, axis=0) 205 | 206 | mean_attn = np.mean(normalized_attn, axis=0) 207 | fname = os.path.join(out_dir, "attn-head" + "-mean.png") 208 | plt.imsave(fname=fname, arr=mean_attn, format='png') 209 | print(f"{fname} saved.") 210 | 211 | max_attn = np.max(normalized_attn, axis=0) 212 | fname = os.path.join(out_dir, "attn-head" + "-max.png") 213 | plt.imsave(fname=fname, arr=max_attn, format='png') 214 | print(f"{fname} saved.") 215 | 216 | if args.threshold is not None: 217 | image = skimage.io.imread(os.path.join(out_dir, "img.png")) 218 | for j in range(nh): 219 | display_instances(image, th_attn[j], fname=os.path.join(out_dir, "mask_th" + str(args.threshold) + "_head" + str(j) +".png"), blur=False) 220 | 221 | 222 | class GaussianBlur(object): 223 | """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709""" 224 | 225 | def __init__(self, sigma=[.1, 2.]): 226 | self.sigma = sigma 227 | 228 | def __call__(self, x): 229 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 230 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 231 | return x 232 | 233 | 234 | if __name__ == '__main__': 235 | main() 236 | -------------------------------------------------------------------------------- /knn.py: -------------------------------------------------------------------------------- 1 | # ======================================================================= 2 | # Based on https://github.com/facebookresearch/dino/blob/main/eval_knn.py 3 | # ======================================================================= 4 | 5 | import os 6 | import argparse 7 | 8 | import torch 9 | from torch import nn 10 | from torchvision import datasets 11 | from torchvision import transforms as pth_transforms 12 | 13 | from vits import archs 14 | from funcs import * 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--dataset', type=str, help='Dataset to evaluate (ddr / messidor2 / aptos2019).') 19 | parser.add_argument('--arch', default='ViT-S-p16', type=str, help='Architecture (support only ViT).') 20 | parser.add_argument('--data-path', type=str, help='Path to the fundus dataset.') 21 | parser.add_argument('--checkpoint', default='', type=str, help="Path to pretrained weights to evaluate.") 22 | parser.add_argument('--input-size', default=384, type=int, help='Input size of the model.') 23 | parser.add_argument('--batch-size', default=16, type=int, help='Per-GPU batch-size') 24 | parser.add_argument('--nb-knn', 
default=[5, 10, 20], nargs='+', type=int, 25 | help='Number of NN to use. 20 is usually working the best.') 26 | parser.add_argument('--temperature', default=0.07, type=float, 27 | help='Temperature used in the voting coefficient') 28 | parser.add_argument('--network', default='', type=str, help="network") 29 | parser.add_argument('--use-cuda', default=True, 30 | help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM") 31 | parser.add_argument('--patch-size', default=16, type=int, help='Patch resolution of the model.') 32 | parser.add_argument("--checkpoint-key", default="base_encoder", type=str, 33 | help='Key to use in the checkpoint') 34 | parser.add_argument('--dump-features', default=None, 35 | help='Path where to save computed features, empty for no saving') 36 | parser.add_argument('--load-features', default=None, help="""If the features have 37 | already been computed, where to find them.""") 38 | parser.add_argument('--num-workers', default=12, type=int, help='Number of data loading workers per GPU.') 39 | parser.add_argument('--device', default='cuda', type=str, help='Device to use.') 40 | 41 | 42 | def extract_feature_pipeline(args): 43 | # ============ preparing data ... ============ 44 | mean, std = get_dataset_stats(args.dataset) 45 | transform = pth_transforms.Compose([ 46 | pth_transforms.Resize((args.input_size, args.input_size)), 47 | pth_transforms.ToTensor(), 48 | pth_transforms.Normalize(mean, std), 49 | ]) 50 | 51 | dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform) 52 | dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "test"), transform=transform) 53 | data_loader_train = torch.utils.data.DataLoader( 54 | dataset_train, 55 | batch_size=args.batch_size, 56 | num_workers=args.num_workers, 57 | pin_memory=True, 58 | drop_last=False, 59 | ) 60 | data_loader_val = torch.utils.data.DataLoader( 61 | dataset_val, 62 | batch_size=args.batch_size, 63 | num_workers=args.num_workers, 64 | pin_memory=True, 65 | drop_last=False, 66 | ) 67 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} test imgs.") 68 | 69 | # ============ building network ... ============ 70 | model = archs[args.arch]( 71 | pretrained=False, 72 | num_classes=1, 73 | img_size=args.input_size, 74 | ) 75 | 76 | linear_key = 'head' 77 | checkpoint_key = args.checkpoint_key 78 | if args.checkpoint: 79 | load_checkpoint(model, args.checkpoint, checkpoint_key, linear_key) 80 | else: 81 | print_msg('No checkpoint provided. Training from scratch.') 82 | 83 | model = model.to(args.device) 84 | model.eval() 85 | 86 | for _, param in model.named_parameters(): 87 | param.requires_grad = False 88 | 89 | # ============ extract features ... 
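    # Note on the similarity used downstream: features are taken from the CLS token
    # (forward_features(...)[:, 0]) and L2-normalised, so the dot products computed in
    # knn_classifier() act as cosine similarities. A minimal, self-contained sketch of
    # that step (random tensors and shapes only, not repo data):
    #
    #   feats = torch.randn(4, 384)                         # 4 images, ViT-S embed dim
    #   feats = nn.functional.normalize(feats, dim=1, p=2)
    #   sims = feats @ feats.t()                            # entries lie in [-1, 1]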
============ 90 | print("Extracting features for train set...") 91 | train_features = extract_features(model, data_loader_train, args.use_cuda) 92 | print("Extracting features for test set...") 93 | test_features = extract_features(model, data_loader_val, args.use_cuda) 94 | 95 | train_features = nn.functional.normalize(train_features, dim=1, p=2) 96 | test_features = nn.functional.normalize(test_features, dim=1, p=2) 97 | 98 | train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long() 99 | test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long() 100 | # save features and labels 101 | if args.dump_features: 102 | torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth")) 103 | torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth")) 104 | torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth")) 105 | torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth")) 106 | return train_features, test_features, train_labels, test_labels 107 | 108 | 109 | @torch.no_grad() 110 | def extract_features(model, data_loader, use_cuda=True, multiscale=False): 111 | features = None 112 | for samples, index in data_loader: 113 | samples = samples.cuda(non_blocking=True) 114 | index = index.cuda(non_blocking=True) 115 | if multiscale: 116 | feats = multi_scale(samples, model) 117 | else: 118 | feats = model.forward_features(samples).clone()[:,0] 119 | 120 | # init storage feature matrix 121 | if features is None: 122 | features = torch.zeros(len(data_loader.dataset), feats.shape[-1]) 123 | if use_cuda: 124 | features = features.cuda(non_blocking=True) 125 | print(f"Storing features into tensor of shape {features.shape}") 126 | 127 | # update storage feature matrix 128 | if use_cuda: 129 | features.index_copy_(0, index, feats) 130 | else: 131 | features.index_copy_(0, index.cpu(), feats.cpu()) 132 | return features 133 | 134 | 135 | @torch.no_grad() 136 | def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=5): 137 | top1, total = 0.0, 0 138 | train_features = train_features.t() 139 | num_test_images, num_chunks = test_labels.shape[0], 50 140 | imgs_per_chunk = num_test_images // num_chunks 141 | retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device) 142 | 143 | conf_mat = torch.zeros(num_classes, num_classes).to(train_features.device) 144 | for idx in range(0, num_test_images, imgs_per_chunk): 145 | # get the features for test images 146 | features = test_features[ 147 | idx : min((idx + imgs_per_chunk), num_test_images), : 148 | ] 149 | targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)] 150 | batch_size = targets.shape[0] 151 | 152 | # calculate the dot product and compute top-k neighbors 153 | similarity = torch.mm(features, train_features) 154 | distances, indices = similarity.topk(k, largest=True, sorted=True) 155 | candidates = train_labels.view(1, -1).expand(batch_size, -1) 156 | retrieved_neighbors = torch.gather(candidates, 1, indices) 157 | 158 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() 159 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) 160 | distances_transform = distances.clone().div_(T).exp_() 161 | probs = torch.sum( 162 | torch.mul( 163 | retrieval_one_hot.view(batch_size, -1, num_classes), 164 | distances_transform.view(batch_size, -1, 1), 165 | ), 166 | 1, 167 | ) 168 | _, predictions = probs.sort(1, True) 169 | 170 | # find the predictions 
that match the target
171 |         correct = predictions.eq(targets.data.view(-1, 1))
172 |         top1 = top1 + correct.narrow(1, 0, 1).sum().item()
173 |         total += targets.size(0)
174 | 
175 |         # kappa
176 |         tgt = targets.data.view(-1, 1)
177 |         for i, p in enumerate(predictions.narrow(1, 0, 1)):
178 |             conf_mat[int(tgt[i])][int(p.item())] += 1
179 | 
180 |     conf_mat = conf_mat.cpu().numpy()
181 |     top1 = top1 * 100.0 / total
182 |     kappa = quadratic_weighted_kappa(conf_mat)
183 |     f1 = conf_mat[1][1] / (conf_mat[1][1] + 0.5 * (conf_mat[0][1] + conf_mat[1][0]))
184 |     return top1, kappa, f1
185 | 
186 | 
187 | def multi_scale(samples, model):
188 |     v = None
189 |     for s in [1, 1/2**(1/2), 1/2]:  # we use 3 different scales
190 |         if s == 1:
191 |             inp = samples.clone()
192 |         else:
193 |             inp = nn.functional.interpolate(samples, scale_factor=s, mode='bilinear', align_corners=False)
194 |         feats = model.forward_features(inp)[:, 0].clone()
195 |         if v is None:
196 |             v = feats
197 |         else:
198 |             v += feats
199 |     v /= 3
200 |     v /= v.norm()
201 |     return v
202 | 
203 | 
204 | class ReturnIndexDataset(datasets.ImageFolder):
205 |     def __getitem__(self, idx):
206 |         img, _ = super(ReturnIndexDataset, self).__getitem__(idx)
207 |         return img, idx
208 | 
209 | 
210 | if __name__ == '__main__':
211 |     args = parser.parse_args()
212 | 
213 |     if args.load_features:
214 |         train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
215 |         test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
216 |         train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
217 |         test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
218 |     else:
219 |         # need to extract features!
220 |         train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args)
221 | 
222 |     train_features = train_features.cuda()
223 |     test_features = test_features.cuda()
224 |     train_labels = train_labels.cuda()
225 |     test_labels = test_labels.cuda()
226 | 
227 |     print("Features are ready!\nStart the k-NN classification.")
228 |     print("Evaluating on {}:".format(args.dataset))
229 |     for k in args.nb_knn:
230 |         acc, kappa, _ = knn_classifier(train_features, train_labels,
231 |                                        test_features, test_labels, k, args.temperature)
232 |         print(f"{k}-NN classifier result: Acc: {acc}, Kappa: {kappa}")
233 | 
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 | import argparse
5 | 
6 | import torch
7 | import numpy as np
8 | import torch.nn as nn
9 | from PIL import Image
10 | from tqdm import tqdm
11 | from torchvision import datasets
12 | from torchvision import transforms
13 | from torch.utils.data import DataLoader
14 | from torch.utils.tensorboard import SummaryWriter
15 | 
16 | from funcs import *
17 | from vits import archs
18 | 
19 | 
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--dataset', type=str, default=None, help='ddr / aptos2019 / messidor2')
22 | parser.add_argument('--arch', type=str, default='ViT-S-p16', help='model architecture')
23 | parser.add_argument('--data-path', type=str, help='dataset folder')
24 | parser.add_argument('--save-path', type=str, default='./eval_checkpoints', help='save path')
25 | parser.add_argument('--log-path', type=str, default='./log', help='log path')
26 | parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint path')
27 | 
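# A typical launch with the arguments declared here looks like the following
# (all paths are placeholders, not files shipped with the repo):
#
#   python eval.py --dataset ddr --arch ViT-S-p16 --data-path /path/to/ddr \
#       --checkpoint /path/to/pretrained.pt --save-path ./eval_checkpoints
#
# Adding --linear freezes the backbone and trains only the classification head.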
parser.add_argument('--checkpoint-key', type=str, default='base_encoder', help='base_encoder / momentum_encoder') 28 | parser.add_argument('--linear', action='store_true', help='use linear eval') 29 | parser.add_argument('--num-classes', type=int, default=5, help='number of classes') 30 | parser.add_argument('--seed', type=int, default=0, help='random seed') 31 | parser.add_argument('--device', type=str, default='cuda', help='device') 32 | 33 | parser.add_argument('--epochs', type=int, default=25, help='number of epochs') 34 | parser.add_argument('--input-size', type=int, default=384, help='input size') 35 | parser.add_argument('--learning-rate', type=float, default=0.00002, help='learning rate') 36 | parser.add_argument('--criterion', type=str, default='mse', help='mse / ce') 37 | parser.add_argument('--optimizer', type=str, default='ADAM', help='optimizer') 38 | parser.add_argument('--weight-decay', type=float, default=0.00001, help='weight decay') 39 | parser.add_argument('--batch-size', type=int, default=16, help='batch size') 40 | parser.add_argument('--num-workers', type=int, default=8, help='number of workers') 41 | parser.add_argument('--kappa-prior', action='store_true', help='use kappa as best model indicator') 42 | parser.add_argument('--eval-interval', type=int, default=1, help='the epoch interval of evaluating model on val dataset') 43 | parser.add_argument('--save-interval', type=int, default=5, help='the epoch interval of saving model') 44 | parser.add_argument('--disable-progress', action='store_true', help='disable progress bar') 45 | 46 | 47 | def main(): 48 | args = parser.parse_args() 49 | 50 | save_path = args.save_path 51 | log_path = args.log_path 52 | if log_path is None: 53 | log_path = os.path.join(save_path, 'log') 54 | os.makedirs(save_path, exist_ok=True) 55 | logger = SummaryWriter(log_path) 56 | 57 | set_random_seed(args.seed) 58 | model = generate_model(args) 59 | train_dataset, test_dataset, val_dataset = generate_dataset(args) 60 | estimator = Estimator(args.criterion, args.num_classes) 61 | scaler = torch.cuda.amp.GradScaler() 62 | train( 63 | args=args, 64 | model=model, 65 | train_dataset=train_dataset, 66 | val_dataset=val_dataset, 67 | estimator=estimator, 68 | logger=logger, 69 | scaler=scaler 70 | ) 71 | 72 | # test 73 | print('This is the performance of the best validation model:') 74 | checkpoint = os.path.join(save_path, 'best_validation_weights.pt') 75 | evaluate(args, model, checkpoint, test_dataset, estimator) 76 | print('This is the performance of the final model:') 77 | checkpoint = os.path.join(save_path, 'final_weights.pt') 78 | evaluate(args, model, checkpoint, test_dataset, estimator) 79 | 80 | 81 | def train(args, model, train_dataset, val_dataset, estimator, logger=None, scaler=None): 82 | device = args.device 83 | optimizer = torch.optim.Adam( 84 | model.parameters(), 85 | lr=args.learning_rate, 86 | weight_decay=args.weight_decay 87 | ) 88 | loss_function = nn.MSELoss() if args.criterion == 'mse' else nn.CrossEntropyLoss() 89 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) 90 | train_loader = DataLoader( 91 | train_dataset, 92 | batch_size=args.batch_size, 93 | shuffle=True, 94 | num_workers=args.num_workers, 95 | drop_last=True, 96 | pin_memory=True 97 | ) 98 | val_loader = DataLoader( 99 | val_dataset, 100 | batch_size=args.batch_size, 101 | num_workers=args.num_workers, 102 | pin_memory=True 103 | ) 104 | 105 | # start training 106 | model.train() 107 | max_indicator = 0 108 | avg_loss, 
avg_acc, avg_kappa = 0, 0, 0 109 | for epoch in range(args.epochs): 110 | epoch_loss = 0 111 | estimator.reset() 112 | progress = tqdm(enumerate(train_loader)) if not args.disable_progress else enumerate(train_loader) 113 | for step, train_data in progress: 114 | X, y = train_data 115 | X, y = X.to(device), y.to(device).float() 116 | 117 | if scaler is not None: 118 | with torch.cuda.amp.autocast(True): 119 | y_pred, _ = model(X) 120 | y_pred = y_pred.squeeze() if args.criterion == 'mse' else y_pred 121 | loss = loss_function(y_pred, y) 122 | 123 | optimizer.zero_grad() 124 | scaler.scale(loss).backward() 125 | 126 | scaler.unscale_(optimizer) 127 | nn.utils.clip_grad_norm_(model.parameters(), 1) 128 | 129 | scaler.step(optimizer) 130 | scaler.update() 131 | else: 132 | y_pred = model(X) 133 | y_pred = y_pred.squeeze() if args.criterion == 'mse' else y_pred 134 | loss = loss_function(y_pred, y) 135 | 136 | optimizer.zero_grad() 137 | loss.backward() 138 | nn.utils.clip_grad_norm_(model.parameters(), 1) 139 | optimizer.step() 140 | 141 | # metrics 142 | epoch_loss += loss.item() 143 | avg_loss = epoch_loss / (step + 1) 144 | estimator.update(y_pred, y) 145 | avg_acc = estimator.get_accuracy(6) 146 | avg_kappa = estimator.get_kappa(6) 147 | 148 | current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) 149 | message = '[{}] epoch: [{} / {}], loss: {:.6f}, acc: {:.4f}, kappa: {:.4f}'.format(current_time, epoch + 1, args.epochs, avg_loss, avg_acc, avg_kappa) 150 | if not args.disable_progress: 151 | progress.set_description(message) 152 | 153 | if args.disable_progress: 154 | print(message) 155 | 156 | # validation performance 157 | if epoch % args.eval_interval == 0: 158 | eval(model, val_loader, estimator, device) 159 | acc = estimator.get_accuracy(6) 160 | kappa = estimator.get_kappa(6) 161 | print('validation accuracy: {}, kappa: {}'.format(acc, kappa)) 162 | if logger: 163 | logger.add_scalar('validation accuracy', acc, epoch) 164 | logger.add_scalar('validation kappa', kappa, epoch) 165 | 166 | # save model 167 | indicator = kappa if args.kappa_prior else acc 168 | if indicator > max_indicator: 169 | torch.save( 170 | model.state_dict(), 171 | os.path.join(args.save_path, 'best_validation_weights.pt') 172 | ) 173 | max_indicator = indicator 174 | print_msg('Best in validation set. 
Model save at {}'.format(args.save_path)) 175 | 176 | if epoch % args.save_interval == 0: 177 | torch.save( 178 | model.state_dict(), 179 | os.path.join(args.save_path, 'epoch_{}.pt'.format(epoch)) 180 | ) 181 | 182 | # update learning rate 183 | curr_lr = optimizer.param_groups[0]['lr'] 184 | if lr_scheduler: 185 | lr_scheduler.step() 186 | 187 | # record 188 | if logger: 189 | logger.add_scalar('training loss', avg_loss, epoch) 190 | logger.add_scalar('training accuracy', avg_acc, epoch) 191 | logger.add_scalar('training kappa', avg_kappa, epoch) 192 | logger.add_scalar('learning rate', curr_lr, epoch) 193 | 194 | # save final model 195 | torch.save( 196 | model.state_dict(), 197 | os.path.join(args.save_path, 'final_weights.pt') 198 | ) 199 | 200 | if logger: 201 | logger.close() 202 | 203 | 204 | def evaluate(args, model, checkpoint, test_dataset, estimator): 205 | weights = torch.load(checkpoint) 206 | model.load_state_dict(weights, strict=True) 207 | test_loader = DataLoader( 208 | test_dataset, 209 | batch_size=args.batch_size, 210 | num_workers=args.num_workers, 211 | shuffle=False, 212 | pin_memory=True 213 | ) 214 | 215 | print('Running on Test set...') 216 | eval(model, test_loader, estimator, args.device) 217 | 218 | print('========================================') 219 | print('Finished! test acc: {}'.format(estimator.get_accuracy(6))) 220 | print('Confusion Matrix:') 221 | print(estimator.conf_mat) 222 | print('quadratic kappa: {}'.format(estimator.get_kappa(6))) 223 | print('========================================') 224 | 225 | 226 | def eval(model, dataloader, estimator, device): 227 | model.eval() 228 | torch.set_grad_enabled(False) 229 | 230 | estimator.reset() 231 | for test_data in dataloader: 232 | X, y = test_data 233 | X, y = X.to(device), y.to(device).float() 234 | 235 | y_pred, _ = model(X) 236 | estimator.update(y_pred, y) 237 | 238 | model.train() 239 | torch.set_grad_enabled(True) 240 | 241 | 242 | def generate_model(args): 243 | assert args.arch in archs.keys(), 'Not implemented architecture.' 244 | out_features = 1 if args.criterion == 'mse' else args.num_classes 245 | model = archs[args.arch]( 246 | num_classes=out_features, 247 | img_size=args.input_size, 248 | feat_concat=True 249 | ) 250 | 251 | linear_key = 'head' 252 | checkpoint_key = args.checkpoint_key 253 | if args.checkpoint: 254 | load_checkpoint(model, args.checkpoint, checkpoint_key, linear_key) 255 | else: 256 | print_msg('No checkpoint provided. 
Training from scratch.') 257 | 258 | if args.linear: 259 | # freeze all layers but the last fc 260 | for name, param in model.named_parameters(): 261 | if name not in ['%s.weight' % linear_key, '%s.bias' % linear_key]: 262 | param.requires_grad = False 263 | # init the fc layer 264 | nn.init.normal_(getattr(model, linear_key).weight, mean=0.0, std=0.01 ) 265 | nn.init.constant_(getattr(model, linear_key).bias, 0) 266 | 267 | model = model.to(args.device) 268 | return model 269 | 270 | 271 | def generate_dataset(args): 272 | train_transform, test_transform = data_transforms(args) 273 | train_path = os.path.join(args.data_path, 'train') 274 | test_path = os.path.join(args.data_path, 'test') 275 | val_path = os.path.join(args.data_path, 'val') 276 | 277 | train_dataset = datasets.ImageFolder(train_path, train_transform, loader=pil_loader) 278 | test_dataset = datasets.ImageFolder(test_path, test_transform, loader=pil_loader) 279 | val_dataset = datasets.ImageFolder(val_path, test_transform, loader=pil_loader) 280 | 281 | dataset = train_dataset, test_dataset, val_dataset 282 | 283 | print_dataset_info(dataset) 284 | return dataset 285 | 286 | 287 | def data_transforms(args): 288 | mean, std = get_dataset_stats(args.dataset) 289 | augmentations = [ 290 | transforms.RandomHorizontalFlip(p=0.5), 291 | transforms.RandomVerticalFlip(p=0.5), 292 | transforms.RandomResizedCrop( 293 | size=(args.input_size, args.input_size), 294 | scale=(0.87, 1.15), 295 | ratio=(0.7, 1.3) 296 | ), 297 | transforms.ColorJitter( 298 | brightness=0.2, 299 | contrast=0.2, 300 | saturation=0.1, 301 | hue=0.1 302 | ), 303 | transforms.RandomRotation(degrees=(-180, 180)), 304 | transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)) 305 | ] 306 | 307 | normalization = [ 308 | transforms.Resize((args.input_size, args.input_size)), 309 | transforms.ToTensor(), 310 | transforms.Normalize(mean, std) 311 | ] 312 | 313 | train_preprocess = transforms.Compose([ 314 | *augmentations, 315 | *normalization 316 | ]) 317 | 318 | test_preprocess = transforms.Compose(normalization) 319 | return train_preprocess, test_preprocess 320 | 321 | 322 | class Estimator(): 323 | def __init__(self, criterion, num_classes, thresholds=None): 324 | self.criterion = criterion 325 | self.num_classes = num_classes 326 | self.thresholds = [-0.5 + i for i in range(num_classes)] if not thresholds else thresholds 327 | 328 | self.reset() # intitialization 329 | 330 | def update(self, predictions, targets): 331 | targets = targets.cpu() 332 | predictions = predictions.cpu() 333 | predictions = self.to_prediction(predictions) 334 | 335 | # update metrics 336 | self.num_samples += len(predictions) 337 | self.correct += (predictions == targets).sum().item() 338 | for i, p in enumerate(predictions): 339 | self.conf_mat[int(targets[i])][int(p.item())] += 1 340 | 341 | def get_accuracy(self, digits=-1): 342 | acc = self.correct / self.num_samples 343 | acc = acc if digits == -1 else round(acc, digits) 344 | return acc 345 | 346 | def get_kappa(self, digits=-1): 347 | kappa = quadratic_weighted_kappa(self.conf_mat) 348 | kappa = kappa if digits == -1 else round(kappa, digits) 349 | return kappa 350 | 351 | def reset(self): 352 | self.correct = 0 353 | self.num_samples = 0 354 | self.conf_mat = np.zeros((self.num_classes, self.num_classes), dtype=int) 355 | 356 | def to_prediction(self, predictions): 357 | if self.criterion == 'ce': 358 | predictions = torch.tensor( 359 | [torch.argmax(p) for p in predictions] 360 | ).long() 361 | elif self.criterion == 'mse': 362 
| predictions = torch.tensor( 363 | [self.classify(p.item()) for p in predictions] 364 | ).float() 365 | else: 366 | raise NotImplementedError('Not implemented criterion.') 367 | 368 | return predictions 369 | 370 | def classify(self, predict): 371 | thresholds = self.thresholds 372 | predict = max(predict, thresholds[0]) 373 | for i in reversed(range(len(thresholds))): 374 | if predict >= thresholds[i]: 375 | return i 376 | 377 | 378 | def pil_loader(path): 379 | with open(path, 'rb') as f: 380 | img = Image.open(f) 381 | return img.convert('RGB') 382 | 383 | 384 | def set_random_seed(seed): 385 | random.seed(seed) 386 | np.random.seed(seed) 387 | torch.manual_seed(seed) 388 | torch.cuda.manual_seed(seed) 389 | # torch.backends.cudnn.deterministic = True 390 | 391 | 392 | if __name__ == '__main__': 393 | main() 394 | -------------------------------------------------------------------------------- /eval_seg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import random 5 | import pickle 6 | import argparse 7 | 8 | import torch 9 | import cv2 as cv 10 | import numpy as np 11 | import torch.nn as nn 12 | import albumentations as A 13 | import torch.nn.functional as F 14 | import segmentation_models_pytorch as smp 15 | 16 | from torch.utils.data import DataLoader 17 | from albumentations.pytorch import ToTensorV2 18 | from torch.utils.tensorboard import SummaryWriter 19 | from albumentations.augmentations.crops.transforms import CropNonEmptyMaskIfExists 20 | 21 | from funcs import * 22 | from vits import archs 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--dataset', type=str, default=None, help='drive / idrid') 27 | parser.add_argument('--arch', type=str, default='ViT-S-p16', help='model architecture') 28 | parser.add_argument('--data-index', type=str, help='dataset index') 29 | parser.add_argument('--save-path', type=str, default='./eval_checkpoints', help='save path') 30 | parser.add_argument('--log-path', type=str, default=None, help='log path') 31 | parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint path') 32 | parser.add_argument('--checkpoint-key', type=str, default='base_encoder', help='base_encoder / momentum_encoder') 33 | parser.add_argument('--linear', action='store_true', help='use linear eval') 34 | parser.add_argument('--seed', type=int, default=0, help='random seed') 35 | parser.add_argument('--device', type=str, default='cuda', help='device') 36 | 37 | parser.add_argument('--iterations', type=int, default=2000, help='number of iterations') 38 | parser.add_argument('--warmup-iterations', type=int, default=200, help='number of warmup iterations') 39 | parser.add_argument('--input-size', type=int, default=512, help='input size') 40 | parser.add_argument('--patch-size', type=int, default=256, help='input size') 41 | parser.add_argument('--learning-rate', type=float, default=0.001, help='learning rate') 42 | parser.add_argument('--optimizer', type=str, default='ADAM', help='optimizer') 43 | parser.add_argument('--weight-decay', type=float, default=0.000001, help='weight decay') 44 | parser.add_argument('--batch-size', type=int, default=16, help='batch size') 45 | parser.add_argument('--num-workers', type=int, default=8, help='number of workers') 46 | parser.add_argument('--kappa-prior', action='store_true', help='use kappa as best model indicator') 47 | parser.add_argument('--eval-interval', type=int, default=50, help='the epoch interval of 
evaluating model on val dataset') 48 | parser.add_argument('--save-interval', type=int, default=500, help='the epoch interval of saving model') 49 | parser.add_argument('--disable-progress', action='store_true', help='disable progress bar') 50 | parser.add_argument('--ce-weight', type=float, default=10, help='weight of cross entropy loss') 51 | 52 | 53 | def main(): 54 | args = parser.parse_args() 55 | 56 | log_path = args.log_path 57 | if log_path is None: 58 | log_path = os.path.join(args.save_path, 'log') 59 | os.makedirs(args.save_path, exist_ok=True) 60 | logger = SummaryWriter(log_path) 61 | 62 | set_random_seed(args.seed) 63 | model = generate_model(args) 64 | train_dataset, test_dataset, val_dataset = generate_dataset(args) 65 | estimator = Estimator() 66 | scaler = torch.cuda.amp.GradScaler() 67 | train( 68 | args=args, 69 | model=model, 70 | train_dataset=train_dataset, 71 | val_dataset=val_dataset, 72 | estimator=estimator, 73 | logger=logger, 74 | scaler=scaler 75 | ) 76 | 77 | # test 78 | print('This is the performance of the best validation model:') 79 | checkpoint = os.path.join(args.save_path, 'best_validation_weights.pt') 80 | evaluate(args, model, checkpoint, test_dataset, estimator) 81 | print('This is the performance of the final model:') 82 | checkpoint = os.path.join(args.save_path, 'final_weights.pt') 83 | evaluate(args, model, checkpoint, test_dataset, estimator) 84 | 85 | 86 | def train(args, model, train_dataset, val_dataset, estimator, logger=None, scaler=None): 87 | device = args.device 88 | optimizer = torch.optim.Adam( 89 | model.parameters(), 90 | lr=args.learning_rate, 91 | weight_decay=args.weight_decay 92 | ) 93 | dice_loss = smp.losses.DiceLoss('binary') 94 | bce_loss = nn.BCEWithLogitsLoss() 95 | train_loader = DataLoader( 96 | train_dataset, 97 | batch_size=args.batch_size, 98 | shuffle=True, 99 | num_workers=args.num_workers, 100 | drop_last=True, 101 | pin_memory=True 102 | ) 103 | val_loader = DataLoader( 104 | val_dataset, 105 | batch_size=1, 106 | num_workers=args.num_workers, 107 | pin_memory=True 108 | ) 109 | 110 | # start training 111 | model.train() 112 | max_indicator = 0 113 | avg_loss, avg_dice = 0, 0 114 | cum_loss = 0 115 | data_iter = iter(train_loader) 116 | for step in range(args.iterations): 117 | lr = adjust_learning_rate(args, optimizer, step) 118 | X, y, data_iter = load_sample(train_loader, data_iter) 119 | X, y = X.to(device), y.to(device).long() 120 | 121 | if scaler is not None: 122 | with torch.cuda.amp.autocast(True): 123 | y_pred = model(X).squeeze() 124 | loss = dice_loss(y_pred, y) + args.ce_weight * bce_loss(y_pred, y.float()) 125 | 126 | optimizer.zero_grad() 127 | scaler.scale(loss).backward() 128 | 129 | scaler.unscale_(optimizer) 130 | nn.utils.clip_grad_norm_(model.parameters(), 1) 131 | 132 | scaler.step(optimizer) 133 | scaler.update() 134 | else: 135 | y_pred = model(X).squeeze() 136 | loss = dice_loss(y_pred, y) + args.ce_weight * bce_loss(y_pred, y.float()) 137 | 138 | optimizer.zero_grad() 139 | loss.backward() 140 | nn.utils.clip_grad_norm_(model.parameters(), 1) 141 | optimizer.step() 142 | 143 | # metrics 144 | cum_loss += loss.item() 145 | avg_loss = cum_loss / (step + 1) 146 | estimator.update(y_pred, y) 147 | 148 | if (step+1) % 10 == 0: 149 | avg_dice = estimator.get_dice(6) 150 | current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) 151 | message = '[{}] step: [{} / {}], loss: {:.6f}, dice: {:.4f}, lr: {:.6f}'.format(current_time, step + 1, args.iterations, avg_loss, avg_dice, 
lr) 152 | print(message) 153 | estimator.reset() 154 | cum_loss = 0 155 | 156 | # validation performance 157 | if (step+1) % args.eval_interval == 0: 158 | eval(args, model, val_loader, estimator, device) 159 | dice = estimator.get_dice(6) 160 | print('validation dice: {}'.format(dice)) 161 | if logger: 162 | logger.add_scalar('validation dice', dice, step) 163 | 164 | # save model 165 | indicator = dice 166 | if indicator > max_indicator: 167 | torch.save( 168 | model.state_dict(), 169 | os.path.join(args.save_path, 'best_validation_weights.pt') 170 | ) 171 | max_indicator = indicator 172 | print_msg('Best in validation set. Model save at {}'.format(args.save_path)) 173 | 174 | if (step+1) % args.save_interval == 0: 175 | torch.save( 176 | model.state_dict(), 177 | os.path.join(args.save_path, 'step_{}.pt'.format(step)) 178 | ) 179 | 180 | # record 181 | if logger: 182 | logger.add_scalar('training loss', avg_loss, step) 183 | logger.add_scalar('training dice', avg_dice, step) 184 | logger.add_scalar('learning rate', lr, step) 185 | 186 | # save final model 187 | torch.save( 188 | model.state_dict(), 189 | os.path.join(args.save_path, 'final_weights.pt') 190 | ) 191 | 192 | if logger: 193 | logger.close() 194 | 195 | 196 | def evaluate(args, model, checkpoint, test_dataset, estimator): 197 | weights = torch.load(checkpoint) 198 | model.load_state_dict(weights, strict=True) 199 | test_loader = DataLoader( 200 | test_dataset, 201 | batch_size=1, 202 | num_workers=args.num_workers, 203 | shuffle=False, 204 | pin_memory=True 205 | ) 206 | 207 | print('Running on Test set...') 208 | eval(args, model, test_loader, estimator, args.device) 209 | 210 | print('========================================') 211 | print('Finished! Dice: {}'.format(estimator.get_dice(6))) 212 | print('========================================') 213 | 214 | 215 | def eval(args, model, dataloader, estimator, device): 216 | model.eval() 217 | torch.set_grad_enabled(False) 218 | 219 | estimator.reset() 220 | for test_data in dataloader: 221 | X, y = test_data 222 | X, y = X.to(device), y.to(device).float() 223 | 224 | X = patchify(X, kernel_size=args.patch_size, stride=args.patch_size // 2) 225 | y_pred = model(X) 226 | y_pred = unpatchify(y_pred, kernel_size=args.patch_size, stride=args.patch_size // 2, target_shape=y.shape) 227 | 228 | estimator.update(y_pred, y) 229 | 230 | model.train() 231 | torch.set_grad_enabled(True) 232 | 233 | 234 | def generate_model(args): 235 | assert args.arch in archs.keys(), 'Not implemented architecture.' 236 | encoder = archs[args.arch]( 237 | img_size=args.patch_size 238 | ) 239 | 240 | linear_key = 'head' 241 | checkpoint_key = args.checkpoint_key 242 | if args.checkpoint: 243 | load_checkpoint(encoder, args.checkpoint, checkpoint_key, linear_key) 244 | else: 245 | print_msg('No checkpoint provided. 
Training from scratch.') 246 | 247 | if args.linear: 248 | for name, param in encoder.named_parameters(): 249 | param.requires_grad = False 250 | encoder.eval() 251 | 252 | model = Segmentor(encoder) 253 | model = model.to(args.device) 254 | return model 255 | 256 | 257 | def generate_dataset(args): 258 | train_transform, test_transform = data_transforms(args) 259 | datasets = pickle.load(open(args.data_index, 'rb')) 260 | print(datasets) 261 | 262 | train_dataset = SegmentationDataset(datasets['train'], train_transform, loader=cv_loader) 263 | test_dataset = SegmentationDataset(datasets['test'], test_transform, loader=cv_loader) 264 | val_dataset = SegmentationDataset(datasets['val'], test_transform, loader=cv_loader) 265 | 266 | dataset = train_dataset, test_dataset, val_dataset 267 | 268 | print('train dataset: {}'.format(len(train_dataset))) 269 | print('test dataset: {}'.format(len(test_dataset))) 270 | print('val dataset: {}'.format(len(val_dataset))) 271 | return dataset 272 | 273 | 274 | def data_transforms(args): 275 | mean, std = get_dataset_stats(args.dataset) 276 | train_preprocess = A.Compose([ 277 | A.Resize(args.input_size, args.input_size), 278 | CropNonEmptyMaskIfExists(args.patch_size, args.patch_size), 279 | A.RandomBrightnessContrast(p=0.2), 280 | A.ShiftScaleRotate(p=0.8), 281 | A.HorizontalFlip(p=0.5), 282 | A.VerticalFlip(p=0.5), 283 | A.Normalize(mean=mean, std=std), 284 | ToTensorV2() 285 | ]) 286 | test_preprocess = A.Compose([ 287 | A.Resize(args.input_size, args.input_size), 288 | A.Normalize(mean=mean, std=std), 289 | ToTensorV2() 290 | ]) 291 | 292 | return train_preprocess, test_preprocess 293 | 294 | 295 | class Segmentor(nn.Module): 296 | def __init__(self, encoder): 297 | super(Segmentor, self).__init__() 298 | 299 | self.encoder = encoder 300 | patch_size = encoder.patch_size 301 | hidden_dim = encoder.head.weight.shape[1] 302 | self.segmentor = nn.Sequential( 303 | nn.Conv2d( 304 | in_channels=hidden_dim, 305 | out_channels=(patch_size ** 2), 306 | kernel_size=1, 307 | ), 308 | nn.PixelShuffle(upscale_factor=patch_size), 309 | ) 310 | 311 | def forward(self, x): 312 | _, f = self.encoder(x) 313 | f = f[:, 1:] 314 | 315 | B, L, C = f.shape 316 | H = W = int(L ** 0.5) 317 | f = f.permute(0, 2, 1).reshape(B, C, H, W) 318 | y = self.segmentor(f) 319 | return y.squeeze() 320 | 321 | 322 | class Estimator(): 323 | def __init__(self): 324 | self.reset() # intitialization 325 | 326 | def update(self, predictions, targets): 327 | targets = targets.cpu() 328 | predictions = self.predict(predictions) 329 | 330 | self.num_samples += targets.shape[0] 331 | self.dice += sum(self.compute_dice(targets, predictions)) 332 | 333 | def compute_dice(self, targets, predictions): 334 | dices = [] 335 | for i in range(targets.shape[0]): 336 | target = targets[i] 337 | prediction = predictions[i] 338 | tp = torch.sum(target * prediction) 339 | fp = torch.sum(prediction) - tp 340 | fn = torch.sum(target) - tp 341 | dice = (2 * tp) / (2 * tp + fp + fn) 342 | dices.append(dice.item()) 343 | return dices 344 | 345 | def predict(self, predictions): 346 | predictions = torch.sigmoid(predictions) 347 | predictions = predictions > 0.5 348 | return predictions.cpu() 349 | 350 | def get_dice(self, digits=-1): 351 | score = self.dice / self.num_samples 352 | score = score if digits == -1 else round(score, digits) 353 | return score 354 | 355 | def reset(self): 356 | self.dice = 0 357 | self.num_samples = 0 358 | 359 | 360 | class SegmentationDataset(torch.utils.data.Dataset): 361 | def 
__init__(self, dataset, transform=None, loader=None): 362 | self.dataset = dataset 363 | self.transform = transform 364 | self.loader = loader 365 | 366 | def __getitem__(self, index): 367 | image, mask = self.dataset[index] 368 | image = self.loader(image) 369 | mask = cv.imread(mask).astype(np.uint8) 370 | mask = mask[:, :, 2] if len(mask.shape) == 3 else mask 371 | mask = mask / 255 372 | 373 | augmented = self.transform(image=image, mask=mask) 374 | image = augmented['image'] 375 | mask = augmented['mask'] 376 | return image, mask 377 | 378 | def __len__(self): 379 | return len(self.dataset) 380 | 381 | def cv_loader(path): 382 | img = cv.imread(path) 383 | img = cv.cvtColor(img, cv.COLOR_BGR2RGB) 384 | return img 385 | 386 | 387 | def set_random_seed(seed): 388 | random.seed(seed) 389 | np.random.seed(seed) 390 | torch.manual_seed(seed) 391 | torch.cuda.manual_seed(seed) 392 | 393 | 394 | def load_sample(data_loader, data_iter): 395 | try: 396 | X, y = next(data_iter) 397 | except StopIteration: 398 | data_iter = iter(data_loader) 399 | X, y = next(data_iter) 400 | return X, y, data_iter 401 | 402 | 403 | def patchify(x, kernel_size, stride): 404 | pad_size = kernel_size - stride 405 | x = F.pad(x, (pad_size, pad_size, pad_size, pad_size), mode='reflect') 406 | x = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride) 407 | x = x.permute(0, 2, 3, 1, 4, 5) 408 | x = x.contiguous().view(-1, *x.shape[3:]) 409 | return x 410 | 411 | 412 | def unpatchify(x, kernel_size, stride, target_shape, num_channel=1): 413 | B, H, W = target_shape 414 | C = num_channel 415 | pad_size = kernel_size - stride 416 | 417 | x = x.contiguous().view(B, -1, C, kernel_size, kernel_size) 418 | x = x.contiguous().view(B, -1, C * kernel_size * kernel_size) 419 | x = x.permute(0, 2, 1) 420 | x = F.fold(x, output_size=(H, W), kernel_size=kernel_size, padding=pad_size, stride=stride) 421 | x = x / (kernel_size / stride) ** 2 422 | return x 423 | 424 | 425 | def adjust_learning_rate(args, optimizer, step): 426 | """Decays the learning rate with half-cycle cosine after warmup""" 427 | if step < args.warmup_iterations: 428 | lr = args.learning_rate * step / args.warmup_iterations 429 | else: 430 | lr = args.learning_rate * 0.5 * (1. 
+ math.cos(math.pi * (step - args.warmup_iterations) / (args.iterations - args.warmup_iterations))) 431 | for param_group in optimizer.param_groups: 432 | param_group['lr'] = lr 433 | return lr 434 | 435 | 436 | 437 | if __name__ == '__main__': 438 | main() 439 | 440 | -------------------------------------------------------------------------------- /vits.py: -------------------------------------------------------------------------------- 1 | # ============================================================================================= 2 | # Based on timm/models/vision_transformer.py 3 | # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/vision_transformer.py 4 | # ============================================================================================= 5 | 6 | import math 7 | import logging 8 | from copy import deepcopy 9 | from functools import partial 10 | from collections import OrderedDict 11 | from itertools import chain 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | from torch.utils.checkpoint import checkpoint 18 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 19 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 20 | from timm.models.layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple 21 | from timm.models.registry import register_model 22 | _logger = logging.getLogger(__name__) 23 | 24 | 25 | def _cfg(url='', **kwargs): 26 | return { 27 | 'url': url, 28 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 29 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 30 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 31 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 32 | **kwargs 33 | } 34 | 35 | 36 | default_cfgs = { 37 | 'vit_tiny_patch16_384': _cfg( 38 | url='https://storage.googleapis.com/vit_models/augreg/' 39 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 40 | input_size=(3, 384, 384), crop_pct=1.0), 41 | 'vit_tiny_patch32_384': _cfg( 42 | url='https://storage.googleapis.com/vit_models/augreg/' 43 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 44 | input_size=(3, 384, 384), crop_pct=1.0), 45 | 'vit_small_patch32_384': _cfg( 46 | url='https://storage.googleapis.com/vit_models/augreg/' 47 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 48 | input_size=(3, 384, 384), crop_pct=1.0), 49 | 'vit_small_patch16_384': _cfg( 50 | url='https://storage.googleapis.com/vit_models/augreg/' 51 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 52 | input_size=(3, 384, 384), crop_pct=1.0), 53 | 'vit_base_patch32_384': _cfg( 54 | url='https://storage.googleapis.com/vit_models/augreg/' 55 | 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 56 | input_size=(3, 384, 384), crop_pct=1.0), 57 | 'vit_base_patch16_384': _cfg( 58 | url='https://storage.googleapis.com/vit_models/augreg/' 59 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', 60 | input_size=(3, 384, 384), crop_pct=1.0), 61 | 'vit_large_patch32_384': _cfg( 62 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 63 | 
input_size=(3, 384, 384), crop_pct=1.0), 64 | 'vit_large_patch16_384': _cfg( 65 | url='https://storage.googleapis.com/vit_models/augreg/' 66 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', 67 | input_size=(3, 384, 384), crop_pct=1.0), 68 | } 69 | 70 | 71 | class Attention(nn.Module): 72 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 73 | super().__init__() 74 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 75 | self.num_heads = num_heads 76 | head_dim = dim // num_heads 77 | self.scale = head_dim ** -0.5 78 | 79 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 80 | self.attn_drop = nn.Dropout(attn_drop) 81 | self.proj = nn.Linear(dim, dim) 82 | self.proj_drop = nn.Dropout(proj_drop) 83 | 84 | def forward(self, x): 85 | B, N, C = x.shape 86 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 87 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 88 | 89 | attn = (q @ k.transpose(-2, -1)) * self.scale 90 | attn = attn.softmax(dim=-1) 91 | attn = self.attn_drop(attn) 92 | 93 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 94 | x = self.proj(x) 95 | x = self.proj_drop(x) 96 | return x, attn 97 | 98 | 99 | class LayerScale(nn.Module): 100 | def __init__(self, dim, init_values=1e-5, inplace=False): 101 | super().__init__() 102 | self.inplace = inplace 103 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 104 | 105 | def forward(self, x): 106 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 107 | 108 | 109 | class Block(nn.Module): 110 | 111 | def __init__( 112 | self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, 113 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 114 | super().__init__() 115 | self.norm1 = norm_layer(dim) 116 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 117 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 118 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 119 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 120 | 121 | self.norm2 = norm_layer(dim) 122 | mlp_hidden_dim = int(dim * mlp_ratio) 123 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 124 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 125 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 126 | 127 | def forward(self, x, return_attention=False): 128 | y, attn = self.attn(self.norm1(x)) 129 | if return_attention: 130 | return attn 131 | x = x + self.drop_path1(self.ls1(y)) 132 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 133 | return x 134 | 135 | 136 | class ParallelBlock(nn.Module): 137 | def __init__( 138 | self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, 139 | drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 140 | super().__init__() 141 | self.num_parallel = num_parallel 142 | self.attns = nn.ModuleList() 143 | self.ffns = nn.ModuleList() 144 | for _ in range(num_parallel): 145 | self.attns.append(nn.Sequential(OrderedDict([ 146 | ('norm', norm_layer(dim)), 147 | ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), 148 | ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), 149 | ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) 150 | ]))) 151 | self.ffns.append(nn.Sequential(OrderedDict([ 152 | ('norm', norm_layer(dim)), 153 | ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), 154 | ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), 155 | ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) 156 | ]))) 157 | 158 | def _forward_jit(self, x): 159 | x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) 160 | x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) 161 | return x 162 | 163 | @torch.jit.ignore 164 | def _forward(self, x): 165 | x = x + sum(attn(x) for attn in self.attns) 166 | x = x + sum(ffn(x) for ffn in self.ffns) 167 | return x 168 | 169 | def forward(self, x): 170 | if torch.jit.is_scripting() or torch.jit.is_tracing(): 171 | return self._forward_jit(x) 172 | else: 173 | return self._forward(x) 174 | 175 | 176 | class PatchEmbed(nn.Module): 177 | """ 2D Image to Patch Embedding 178 | """ 179 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 180 | super().__init__() 181 | img_size = to_2tuple(img_size) 182 | patch_size = to_2tuple(patch_size) 183 | self.img_size = img_size 184 | self.patch_size = patch_size 185 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 186 | self.num_patches = self.grid_size[0] * self.grid_size[1] 187 | self.flatten = flatten 188 | 189 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 190 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 191 | 192 | def forward(self, x): 193 | B, C, H, W = x.shape 194 | x = self.proj(x) 195 | if self.flatten: 196 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 197 | x = self.norm(x) 198 | return x 199 | 200 | 201 | class PatchSampler(object): 202 | def __init__(self, mask_size, patch_size, mask_ratio=0.25): 203 | self.mask_size = mask_size 204 | self.patch_size = patch_size 205 | self.mask_ratio = mask_ratio 206 | 207 | def __call__(self, pmap): 208 | B, C, H, W = pmap.shape 209 | num_sample = int((1 - self.mask_ratio) * H * W) 210 | 211 | feat_idx = pmap.flatten(1).argsort(descending=True)[:,:num_sample] 212 | feat_idx += 1 # class embedding concat before the image embedding 213 | cls_idx = torch.zeros((B, 1), dtype=torch.int64, device=pmap.device) 214 | active_idx = torch.cat([cls_idx, feat_idx], dim=1) 215 | return active_idx 216 | 217 | 218 | 
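# PatchSampler keeps the (1 - mask_ratio) patch indices with the highest value in the
# priority map (the saliency map in SSiT) and always prepends index 0 so the class
# token is never dropped. A shape-level sketch (values are illustrative only):
#
#   sampler = PatchSampler(mask_size=384, patch_size=16, mask_ratio=0.25)
#   pmap = torch.rand(2, 1, 24, 24)        # per-patch saliency, B x 1 x H x W
#   idx = sampler(pmap)                    # LongTensor of shape (2, 1 + 432)
#
# These indices feed torch.gather in VisionTransformer.forward_features, discarding
# the low-saliency tokens before the transformer blocks are applied.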
class VisionTransformer(nn.Module): 219 | """ Vision Transformer 220 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 221 | - https://arxiv.org/abs/2010.11929 222 | """ 223 | 224 | def __init__( 225 | self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', 226 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, 227 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None, 228 | embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block, mask_ratio=0.25, feat_concat=False): 229 | """ 230 | Args: 231 | img_size (int, tuple): input image size 232 | patch_size (int, tuple): patch size 233 | in_chans (int): number of input channels 234 | num_classes (int): number of classes for classification head 235 | global_pool (str): type of global pooling for final sequence (default: 'token') 236 | embed_dim (int): embedding dimension 237 | depth (int): depth of transformer 238 | num_heads (int): number of attention heads 239 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 240 | qkv_bias (bool): enable bias for qkv if True 241 | init_values: (float): layer-scale init values 242 | class_token (bool): use class token 243 | fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) 244 | drop_rate (float): dropout rate 245 | attn_drop_rate (float): attention dropout rate 246 | drop_path_rate (float): stochastic depth rate 247 | weight_init (str): weight init scheme 248 | embed_layer (nn.Module): patch embedding layer 249 | norm_layer: (nn.Module): normalization layer 250 | act_layer: (nn.Module): MLP activation layer 251 | """ 252 | super().__init__() 253 | assert global_pool in ('', 'avg', 'token') 254 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 255 | act_layer = act_layer or nn.GELU 256 | 257 | # store grid_size 258 | self.grid_sizes = {} 259 | self.patch_size = patch_size 260 | 261 | self.num_classes = num_classes 262 | self.global_pool = global_pool 263 | self.feat_concat = feat_concat 264 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 265 | self.num_tokens = 1 266 | self.grad_checkpointing = False 267 | 268 | self.patch_embed = embed_layer( 269 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 270 | num_patches = self.patch_embed.num_patches 271 | 272 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 273 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 274 | self.pos_drop = nn.Dropout(p=drop_rate) 275 | 276 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 277 | self.blocks = nn.Sequential(*[ 278 | block_fn( 279 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, 280 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 281 | for i in range(depth)]) 282 | use_fc_norm = self.global_pool == 'avg' 283 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() 284 | 285 | # Representation layer. Used for original ViT models w/ in21k pretraining. 
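        # When feat_concat=True the classifier input is the CLS token concatenated with
        # the mean of the patch tokens (see forward_head), which is why final_chs is
        # doubled below. Rough shape sketch for ViT-S (embed_dim=384, 384x384 input):
        #
        #   tokens = torch.randn(2, 1 + 24 * 24, 384)                              # [CLS] + patches
        #   pooled = torch.cat((tokens[:, 0], tokens[:, 1:].mean(dim=1)), dim=1)   # (2, 768)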
286 | self.representation_size = representation_size 287 | self.pre_logits = nn.Identity() 288 | if representation_size: 289 | self._reset_representation(representation_size) 290 | 291 | # Classifier Head 292 | final_chs = self.representation_size if self.representation_size else self.embed_dim 293 | final_chs = final_chs * 2 if self.feat_concat else final_chs 294 | self.fc_norm = norm_layer(final_chs) if use_fc_norm else nn.Identity() 295 | self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() 296 | 297 | self.mask_ratio = mask_ratio 298 | self.patch_sampler = PatchSampler(mask_size=img_size, patch_size=patch_size, mask_ratio=mask_ratio) 299 | 300 | if weight_init != 'skip': 301 | self.init_weights(weight_init) 302 | 303 | def init_weights(self, mode=''): 304 | assert mode in ('jax', 'jax_nlhb', 'moco', '') 305 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 306 | trunc_normal_(self.pos_embed, std=.02) 307 | nn.init.normal_(self.cls_token, std=1e-6) 308 | named_apply(get_init_weights_vit(mode, head_bias), self) 309 | 310 | def _init_weights(self, m): 311 | # this fn left here for compat with downstream users 312 | init_weights_vit_timm(m) 313 | 314 | @torch.jit.ignore() 315 | def load_pretrained(self, checkpoint_path, prefix=''): 316 | _load_weights(self, checkpoint_path, prefix) 317 | 318 | @torch.jit.ignore 319 | def no_weight_decay(self): 320 | return {'pos_embed', 'cls_token', 'dist_token'} 321 | 322 | @torch.jit.ignore 323 | def group_matcher(self, coarse=False): 324 | return dict( 325 | stem=r'^cls_token|pos_embed|patch_embed', # stem and embed 326 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] 327 | ) 328 | 329 | @torch.jit.ignore 330 | def set_grad_checkpointing(self, enable=True): 331 | self.grad_checkpointing = enable 332 | 333 | @torch.jit.ignore 334 | def get_classifier(self): 335 | return self.head 336 | 337 | def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None): 338 | self.num_classes = num_classes 339 | if global_pool is not None: 340 | assert global_pool in ('', 'avg', 'token') 341 | self.global_pool = global_pool 342 | if representation_size is not None: 343 | self._reset_representation(representation_size) 344 | final_chs = self.representation_size if self.representation_size else self.embed_dim 345 | self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() 346 | 347 | def forward_features(self, x, pmap=None): 348 | x = self.patch_embed(x) 349 | if self.cls_token is not None: 350 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 351 | x = x + self.pos_embed 352 | 353 | if pmap is not None and self.mask_ratio < 1: 354 | active_idx = self.patch_sampler(pmap) 355 | active_idx = active_idx.unsqueeze(-1).repeat(1, 1, self.embed_dim) 356 | x = torch.gather(x, dim=1, index=active_idx) 357 | 358 | x = self.pos_drop(x) 359 | if self.grad_checkpointing and not torch.jit.is_scripting(): 360 | x = checkpoint_seq(self.blocks, x) 361 | else: 362 | x = self.blocks(x) 363 | x = self.norm(x) 364 | return x 365 | 366 | def forward_head(self, x, pre_logits: bool = False): 367 | if self.feat_concat: 368 | feats = x[:, 1:].mean(dim=1) 369 | x = torch.cat((x[:, 0], feats), dim=1) 370 | elif self.global_pool: 371 | x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] 372 | x = self.fc_norm(x) 373 | x = self.pre_logits(x) 374 | return x if pre_logits else self.head(x) 375 | 376 | def forward(self, x, pmap=None): 377 | f = 
self.forward_features(x, pmap) 378 | x = self.forward_head(f) 379 | return x, f 380 | 381 | def get_last_selfattention(self, x): 382 | x = self.patch_embed(x) 383 | if self.cls_token is not None: 384 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 385 | x = x + self.pos_embed 386 | x = self.pos_drop(x) 387 | 388 | for i, blk in enumerate(self.blocks): 389 | if i < len(self.blocks) - 1: 390 | x = blk(x) 391 | else: 392 | # return attention of the last block 393 | return blk(x, return_attention=True) 394 | 395 | 396 | def init_weights_vit_timm(module: nn.Module, name: str = ''): 397 | """ ViT weight initialization, original timm impl (for reproducibility) """ 398 | if isinstance(module, nn.Linear): 399 | trunc_normal_(module.weight, std=.02) 400 | if module.bias is not None: 401 | nn.init.zeros_(module.bias) 402 | 403 | 404 | def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): 405 | """ ViT weight initialization, matching JAX (Flax) impl """ 406 | if isinstance(module, nn.Linear): 407 | if name.startswith('head'): 408 | nn.init.zeros_(module.weight) 409 | nn.init.constant_(module.bias, head_bias) 410 | elif name.startswith('pre_logits'): 411 | lecun_normal_(module.weight) 412 | nn.init.zeros_(module.bias) 413 | else: 414 | nn.init.xavier_uniform_(module.weight) 415 | if module.bias is not None: 416 | nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) 417 | elif isinstance(module, nn.Conv2d): 418 | lecun_normal_(module.weight) 419 | if module.bias is not None: 420 | nn.init.zeros_(module.bias) 421 | 422 | 423 | def init_weights_vit_moco(module: nn.Module, name: str = ''): 424 | """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ 425 | if isinstance(module, nn.Linear): 426 | if 'qkv' in name: 427 | # treat the weights of Q, K, V separately 428 | val = math.sqrt(6. 
/ float(module.weight.shape[0] // 3 + module.weight.shape[1])) 429 | nn.init.uniform_(module.weight, -val, val) 430 | else: 431 | nn.init.xavier_uniform_(module.weight) 432 | if module.bias is not None: 433 | nn.init.zeros_(module.bias) 434 | 435 | 436 | def get_init_weights_vit(mode='jax', head_bias: float = 0.): 437 | if 'jax' in mode: 438 | return partial(init_weights_vit_jax, head_bias=head_bias) 439 | elif 'moco' in mode: 440 | return init_weights_vit_moco 441 | else: 442 | return init_weights_vit_timm 443 | 444 | 445 | @torch.no_grad() 446 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 447 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 448 | """ 449 | import numpy as np 450 | 451 | def _n2p(w, t=True): 452 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 453 | w = w.flatten() 454 | if t: 455 | if w.ndim == 4: 456 | w = w.transpose([3, 2, 0, 1]) 457 | elif w.ndim == 3: 458 | w = w.transpose([2, 0, 1]) 459 | elif w.ndim == 2: 460 | w = w.transpose([1, 0]) 461 | return torch.from_numpy(w) 462 | 463 | w = np.load(checkpoint_path) 464 | if not prefix and 'opt/target/embedding/kernel' in w: 465 | prefix = 'opt/target/' 466 | 467 | if hasattr(model.patch_embed, 'backbone'): 468 | # hybrid 469 | backbone = model.patch_embed.backbone 470 | stem_only = not hasattr(backbone, 'stem') 471 | stem = backbone if stem_only else backbone.stem 472 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 473 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 474 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 475 | if not stem_only: 476 | for i, stage in enumerate(backbone.stages): 477 | for j, block in enumerate(stage.blocks): 478 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 479 | for r in range(3): 480 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 481 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 482 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 483 | if block.downsample is not None: 484 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 485 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 486 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 487 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 488 | else: 489 | embed_conv_w = adapt_input_conv( 490 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 491 | model.patch_embed.proj.weight.copy_(embed_conv_w) 492 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 493 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 494 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 495 | if pos_embed_w.shape != model.pos_embed.shape: 496 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 497 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 498 | model.pos_embed.copy_(pos_embed_w) 499 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 500 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 501 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 502 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 503 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 
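    # _n2p above converts Flax-layout arrays (HWIO conv kernels, (in, out) dense kernels)
    # into PyTorch layout before copying. Loading an official .npz checkpoint can go
    # through VisionTransformer.load_pretrained, e.g. (checkpoint path is a placeholder):
    #
    #   model = vit_small_patch16_384(num_classes=5)
    #   model.load_pretrained('/path/to/augreg_vit_s16_384.npz')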
504 |     if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
505 |         model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
506 |         model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
507 |     for i, block in enumerate(model.blocks.children()):
508 |         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
509 |         mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
510 |         block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
511 |         block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
512 |         block.attn.qkv.weight.copy_(torch.cat([
513 |             _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
514 |         block.attn.qkv.bias.copy_(torch.cat([
515 |             _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
516 |         block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
517 |         block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
518 |         for r in range(2):
519 |             getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
520 |             getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
521 |         block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
522 |         block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
523 | 
524 | 
525 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
526 |     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
527 |     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
528 |     _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
529 |     ntok_new = posemb_new.shape[1]
530 |     if num_tokens:
531 |         posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
532 |         ntok_new -= num_tokens
533 |     else:
534 |         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
535 |     gs_old = int(math.sqrt(len(posemb_grid)))
536 |     if not len(gs_new):  # backwards compatibility
537 |         gs_new = [int(math.sqrt(ntok_new))] * 2
538 |     assert len(gs_new) >= 2
539 |     _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
540 |     posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
541 |     posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
542 |     posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
543 |     posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
544 |     return posemb
545 | 
546 | 
547 | def checkpoint_filter_fn(state_dict, model):
548 |     """ convert patch embedding weight from manual patchify + linear proj to conv"""
549 |     out_dict = {}
550 |     if 'model' in state_dict:
551 |         # For deit models
552 |         state_dict = state_dict['model']
553 |     for k, v in state_dict.items():
554 |         if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
555 |             # For old models that I trained prior to conv based patchification
556 |             O, I, H, W = model.patch_embed.proj.weight.shape
557 |             v = v.reshape(O, -1, H, W)
558 |         elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
559 |             # To resize pos embedding when using model at different size from pretrained weights
560 |             v = resize_pos_embed(
561 |                 v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
562 |         out_dict[k] = v
563 |     return out_dict
564 | 
565 | 
566 | def _create_vision_transformer(variant, pretrained=False, **kwargs):
567 |     if kwargs.get('features_only', None):
568 |         raise RuntimeError('features_only not implemented for Vision Transformer models.')
569 | 
570 |     # NOTE this extra code to support handling of repr size for in21k pretrained models
571 |     pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
572 |     default_num_classes = pretrained_cfg['num_classes']
573 |     num_classes = kwargs.get('num_classes', default_num_classes)
574 |     repr_size = kwargs.pop('representation_size', None)
575 |     if repr_size is not None and num_classes != default_num_classes:
576 |         # Remove representation layer if fine-tuning. This may not always be the desired action,
577 |         # but I feel it's better than doing nothing by default for fine-tuning. Perhaps a better interface?
578 |         _logger.warning("Removing representation layer for fine-tuning.")
579 |         repr_size = None
580 | 
581 |     model = build_model_with_cfg(
582 |         VisionTransformer, variant, pretrained,
583 |         default_cfg=pretrained_cfg,
584 |         representation_size=repr_size,
585 |         pretrained_filter_fn=checkpoint_filter_fn,
586 |         pretrained_custom_load='npz' in pretrained_cfg['url'],
587 |         **kwargs)
588 |     return model
589 | 
590 | 
591 | @register_model
592 | def vit_tiny_patch16_384(pretrained=False, **kwargs):
593 |     """ ViT-Tiny (ViT-Ti/16) @ 384x384.
594 |     """
595 |     model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
596 |     model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
597 |     return model
598 | 
599 | 
600 | @register_model
601 | def vit_tiny_patch32_384(pretrained=False, **kwargs):
602 |     """ ViT-Tiny (ViT-Ti/32) @ 384x384.
603 |     """
604 |     model_kwargs = dict(patch_size=32, embed_dim=192, depth=12, num_heads=3, **kwargs)
605 |     model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)  # NOTE: reuses the vit_tiny_patch16_384 pretrained cfg; patch_size=32 comes from model_kwargs
606 |     return model
607 | 
608 | 
609 | @register_model
610 | def vit_small_patch32_384(pretrained=False, **kwargs):
611 |     """ ViT-Small (ViT-S/32) @ 384x384.
612 |     """
613 |     model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
614 |     model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
615 |     return model
616 | 
617 | 
618 | @register_model
619 | def vit_small_patch16_384(pretrained=False, **kwargs):
620 |     """ ViT-Small (ViT-S/16) @ 384x384.
621 |     NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
622 |     """
623 |     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
624 |     model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
625 |     return model
626 | 
627 | 
628 | @register_model
629 | def vit_base_patch32_384(pretrained=False, **kwargs):
630 |     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
631 |     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
632 |     """
633 |     model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
634 |     model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
635 |     return model
636 | 
637 | 
638 | @register_model
639 | def vit_base_patch16_384(pretrained=False, **kwargs):
640 |     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
641 |     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
642 |     """
643 |     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
644 |     model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
645 |     return model
646 | 
647 | 
648 | @register_model
649 | def vit_large_patch32_384(pretrained=False, **kwargs):
650 |     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
651 |     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
652 |     """
653 |     model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
654 |     model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
655 |     return model
656 | 
657 | 
658 | @register_model
659 | def vit_large_patch16_384(pretrained=False, **kwargs):
660 |     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
661 |     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
662 |     """
663 |     model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
664 |     model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
665 |     return model
666 | 
667 | 
668 | def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None):
669 |     if pretrained_cfg and isinstance(pretrained_cfg, dict):
670 |         # highest priority, pretrained_cfg available and passed explicitly
671 |         return deepcopy(pretrained_cfg)
672 |     if kwargs and 'pretrained_cfg' in kwargs:
673 |         # next highest, pretrained_cfg in a kwargs dict, pop and return
674 |         pretrained_cfg = kwargs.pop('pretrained_cfg', {})
675 |         if pretrained_cfg:
676 |             return deepcopy(pretrained_cfg)
677 |     # lookup pretrained cfg in model registry by variant
678 |     pretrained_cfg = get_pretrained_cfg(variant)
679 |     assert pretrained_cfg
680 |     return pretrained_cfg
681 | 
682 | 
683 | def get_pretrained_cfg(model_name):
684 |     if model_name in default_cfgs:
685 |         return deepcopy(default_cfgs[model_name])
686 |     return {}
687 | 
688 | 
689 | def checkpoint_seq(
690 |         functions,
691 |         x,
692 |         every=1,
693 |         flatten=False,
694 |         skip_last=False,
695 |         preserve_rng_state=True
696 | ):
697 |     r"""A helper function for checkpointing sequential models.
698 |     Sequential models execute a list of modules/functions in order
699 |     (sequentially). Therefore, we can divide such a sequence into segments
700 |     and checkpoint each segment. All checkpointed segments run in :func:`torch.no_grad`
701 |     manner, i.e., without storing the intermediate activations. The inputs of each
702 |     checkpointed segment will be saved for re-running the segment in the backward pass.
703 |     See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
704 |     .. warning::
705 |         Checkpointing currently only supports :func:`torch.autograd.backward`
706 |         and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
707 |         is not supported.
708 |     .. warning::
709 |         At least one of the inputs needs to have :code:`requires_grad=True` if
710 |         grads are needed for model inputs, otherwise the checkpointed part of the
711 |         model won't have gradients.
712 |     Args:
713 |         functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
714 |         x: A Tensor that is input to :attr:`functions`
715 |         every: checkpoint every-n functions (default: 1)
716 |         flatten (bool): flatten nn.Sequential of nn.Sequentials
717 |         skip_last (bool): skip checkpointing the last function in the sequence if True
718 |         preserve_rng_state (bool, optional, default=True): stash and restore the RNG state
719 |             during each checkpoint; set to False to omit this.
720 |     Returns:
721 |         Output of running :attr:`functions` sequentially on :attr:`x`
722 |     Example:
723 |         >>> model = nn.Sequential(...)
724 |         >>> input_var = checkpoint_seq(model, input_var, every=2)
725 |     """
726 |     def run_function(start, end, functions):
727 |         def forward(_x):
728 |             for j in range(start, end + 1):
729 |                 _x = functions[j](_x)
730 |             return _x
731 |         return forward
732 | 
733 |     if isinstance(functions, torch.nn.Sequential):
734 |         functions = functions.children()
735 |     if flatten:
736 |         functions = chain.from_iterable(functions)
737 |     if not isinstance(functions, (tuple, list)):
738 |         functions = tuple(functions)
739 | 
740 |     num_checkpointed = len(functions)
741 |     if skip_last:
742 |         num_checkpointed -= 1
743 |     end = -1
744 |     for start in range(0, num_checkpointed, every):
745 |         end = min(start + every - 1, num_checkpointed - 1)
746 |         x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
747 |     if skip_last:
748 |         return run_function(end + 1, len(functions) - 1, functions)(x)
749 |     return x
750 | 
751 | 
752 | archs = {
753 |     'ViT-T-p16': vit_tiny_patch16_384,
754 |     'ViT-T-p32': vit_tiny_patch32_384,
755 |     'ViT-S-p16': vit_small_patch16_384,
756 |     'ViT-S-p32': vit_small_patch32_384,
757 |     'ViT-B-p16': vit_base_patch16_384,
758 |     'ViT-B-p32': vit_base_patch32_384,
759 |     'ViT-L-p16': vit_large_patch16_384,
760 |     'ViT-L-p32': vit_large_patch32_384,
761 | }
762 | 
--------------------------------------------------------------------------------
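Illustrative note (not part of vits.py): the `archs` dictionary above maps short names such as 'ViT-S-p16' to the registered model constructors. A minimal sketch of building a backbone and querying the attention of the last block follows; it assumes the repository root is on PYTHONPATH and that the constructor accepts timm-style keyword arguments such as `img_size` (the constructor is defined earlier in this file), so treat the exact kwargs as assumptions.

# Illustrative sketch, not part of the repository.
# Build a ViT-S/16 backbone from the `archs` registry and inspect the
# self-attention of the last block via get_last_selfattention().
import torch
from vits import archs  # assumes the repository root is importable

# img_size follows the timm VisionTransformer convention and is an assumption here
model = archs['ViT-S-p16'](pretrained=False, img_size=224)
model.eval()

x = torch.randn(1, 3, 224, 224)             # dummy RGB input
with torch.no_grad():
    attn = model.get_last_selfattention(x)  # (batch, heads, 1 + num_patches, 1 + num_patches)
print(attn.shape)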
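Similarly, a self-contained sketch of `checkpoint_seq`: segments of `every` consecutive modules are wrapped in torch.utils.checkpoint, so their activations are recomputed during the backward pass instead of being stored. The toy nn.Sequential and layer sizes below are arbitrary.

# Illustrative sketch, not part of the repository.
# Six Linear layers are checkpointed as three segments of two layers each.
import torch
import torch.nn as nn
from vits import checkpoint_seq

seq = nn.Sequential(*[nn.Linear(16, 16) for _ in range(6)])
x = torch.randn(4, 16, requires_grad=True)  # input needs requires_grad (see the warning in the docstring)
y = checkpoint_seq(seq, x, every=2)         # segments: layers [0, 1], [2, 3], [4, 5]
y.sum().backward()
print(seq[0].weight.grad.shape)             # gradients flow through the checkpointed segments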
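Finally, a shape-only sketch of `resize_pos_embed`, which bicubically interpolates the patch-position grid while keeping the class token, as called by `_load_weights` and `checkpoint_filter_fn` when the input resolution differs from the pretrained weights. The embedding width of 384 (ViT-S) is just an example.

# Illustrative sketch, not part of the repository.
# Resize a 14x14 positional grid (224px / patch 16) plus one cls token
# to a 24x24 grid (384px / patch 16).
import torch
from vits import resize_pos_embed

old_pos = torch.randn(1, 1 + 14 * 14, 384)  # pretrained pos_embed: cls token + 196 patch positions
new_pos = torch.zeros(1, 1 + 24 * 24, 384)  # tensor shaped like the target model's pos_embed
resized = resize_pos_embed(old_pos, new_pos, num_tokens=1, gs_new=(24, 24))
print(resized.shape)                        # torch.Size([1, 577, 384])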