├── images ├── pipeline.png ├── tutor_1.png ├── tutor_2.png ├── tutor_3.png ├── tutor_4.png ├── tutor_5.png ├── tutor_6.png ├── tutor_7.png ├── tutor_8.png ├── generation.png ├── onlinedemo.png └── RETFound-DE.png ├── .gitignore ├── __pycache__ ├── utility.cpython-37.pyc ├── utility.cpython-38.pyc ├── models_mae.cpython-37.pyc ├── models_mae.cpython-38.pyc └── models_vit.cpython-38.pyc ├── util ├── __pycache__ │ ├── data.cpython-37.pyc │ ├── misc.cpython-37.pyc │ ├── data_AMD.cpython-37.pyc │ ├── datasets.cpython-37.pyc │ ├── lr_decay.cpython-37.pyc │ ├── lr_sched.cpython-37.pyc │ ├── pos_embed.cpython-37.pyc │ ├── pos_embed.cpython-38.pyc │ ├── data_Cataract.cpython-37.pyc │ ├── data_DRGrading.cpython-37.pyc │ ├── data_Glaucoma.cpython-37.pyc │ └── data_MultiDisease.cpython-37.pyc ├── lr_sched.py ├── datasets.py ├── lr_decay.py ├── data.py ├── pos_embed.py ├── data_Radiology.py ├── misc.py ├── data_AMD.py ├── data_Cataract.py ├── data_Glaucoma.py └── data_MultiDisease.py ├── requirement.txt ├── mae_pretrain.sh ├── main_finetune.sh ├── main_evaluation.sh ├── models_vit.py ├── README_tutorial.md ├── README_SD.md ├── engine_pretrain.py ├── submitit_pretrain.py ├── Example.ipynb ├── README.md ├── main_pretrain.py ├── models_mae.py ├── visualize.py ├── engine_finetune.py ├── utility.py └── LICENSE /images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/pipeline.png -------------------------------------------------------------------------------- /images/tutor_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/tutor_1.png -------------------------------------------------------------------------------- /images/tutor_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/tutor_2.png -------------------------------------------------------------------------------- /images/tutor_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/tutor_3.png -------------------------------------------------------------------------------- /images/tutor_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/tutor_4.png -------------------------------------------------------------------------------- /images/tutor_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/tutor_5.png -------------------------------------------------------------------------------- /images/tutor_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/tutor_6.png -------------------------------------------------------------------------------- /images/tutor_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/tutor_7.png -------------------------------------------------------------------------------- /images/tutor_8.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/tutor_8.png -------------------------------------------------------------------------------- /images/generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/generation.png -------------------------------------------------------------------------------- /images/onlinedemo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/onlinedemo.png -------------------------------------------------------------------------------- /images/RETFound-DE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/images/RETFound-DE.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoint/ 2 | exampledata/ 3 | data/ 4 | RetinaDiffusion/ 5 | sd_app.py 6 | images/dml.png 7 | *.zip 8 | *.pth -------------------------------------------------------------------------------- /__pycache__/utility.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/__pycache__/utility.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utility.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/__pycache__/utility.cpython-38.pyc -------------------------------------------------------------------------------- /util/__pycache__/data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/data.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/models_mae.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/__pycache__/models_mae.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/models_mae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/__pycache__/models_mae.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/models_vit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/__pycache__/models_vit.cpython-38.pyc -------------------------------------------------------------------------------- /util/__pycache__/data_AMD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/data_AMD.cpython-37.pyc 
-------------------------------------------------------------------------------- /util/__pycache__/datasets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/datasets.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/lr_decay.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/lr_decay.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/lr_sched.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/lr_sched.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/pos_embed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/pos_embed.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/pos_embed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/pos_embed.cpython-38.pyc -------------------------------------------------------------------------------- /util/__pycache__/data_Cataract.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/data_Cataract.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/data_DRGrading.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/data_DRGrading.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/data_Glaucoma.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/data_Glaucoma.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/data_MultiDisease.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jonlysun/DERETFound/HEAD/util/__pycache__/data_MultiDisease.cpython-37.pyc -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | imageio==2.33.0 2 | opencv-python==4.5.3.56 3 | pandas==2.0.3 4 | Pillow==8.3.2 5 | protobuf==3.19.6 6 | pycm==3.2 7 | pydicom==2.3.0 8 | scikit-image==0.17.2 9 | scikit-learn==0.24.2 10 | scipy==1.5.4 11 | tensorboard==2.14.0 12 | tensorboard-data-server==0.7.0 13 | tensorboard-plugin-wit==1.8.0 14 | timm==0.3.2 15 | tqdm==4.62.1 16 | gradio==3.40.0 17 | grad-cam==1.4.8 18 | -------------------------------------------------------------------------------- /mae_pretrain.sh: -------------------------------------------------------------------------------- 1 | 
IMAGENET_DIR=YOUR_OWN_PATH 2 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=48797 main_pretrain.py \ 3 | --batch_size 224 \ 4 | --model mae_vit_large_patch16 \ 5 | --norm_pix_loss \ 6 | --mask_ratio 0.75 \ 7 | --epochs 200 \ 8 | --warmup_epochs 20 \ 9 | --blr 1.5e-4 --weight_decay 0.05 \ 10 | --data_path $IMAGENET_DIR \ 11 | --task './DERETFound/' \ 12 | --output_dir './DERETFound_log/' \ 13 | --resume ./mae_pretrain_vit_large.pth \ 14 | 15 | -------------------------------------------------------------------------------- /main_finetune.sh: -------------------------------------------------------------------------------- 1 | # DATASETS : ['DR_APTOS2019','DR_IDRID','DR_MESSIDOR2','Glaucoma_PAPILA','Glaucoma_Glaucoma_Fundus','Glaucoma_ORIGA','AMD_AREDS','Multi_Retina', 'Multi_JSIEC'] 2 | DATASET='DR_APTOS2019' 3 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=48797 main_finetune.py \ 4 | --batch_size 32 \ 5 | --world_size 1 \ 6 | --model vit_large_patch16 \ 7 | --epochs 50 \ 8 | --blr 5e-3 --layer_decay 0.65 \ 9 | --weight_decay 0.05 --drop_path 0.2 \ 10 | --root YOUR_OWN_PATH \ 11 | --task ./checkpoint/$DATASET/ \ 12 | --dataset_name $DATASET \ 13 | --finetune ./checkpoint/PreTraining/checkpoint-best.pth 14 | -------------------------------------------------------------------------------- /main_evaluation.sh: -------------------------------------------------------------------------------- 1 | # DATASETS : ['DR_APTOS2019','DR_IDRID','DR_MESSIDOR2','Glaucoma_PAPILA','Glaucoma_Glaucoma_Fundus','Glaucoma_ORIGA','AMD_AREDS','Multi_Retina', 'Multi_JSIEC'] 2 | DATASET='DR_APTOS2019' 3 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=48797 main_finetune.py \ 4 | --eval --batch_size 16 \ 5 | --world_size 1 \ 6 | --model vit_large_patch16 \ 7 | --epochs 50 \ 8 | --blr 5e-3 --layer_decay 0.65 \ 9 | --weight_decay 0.05 --drop_path 0.2 \ 10 | --nb_classes 5 \ 11 | --root YOUR_OWN_PATH \ 12 | --task ./Results/internal_$DATASET/ \ 13 | --resume ./checkpoint/$DATASET/checkpoint-best.pth \ 14 | --dataset_name $DATASET 15 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # Partly revised by YZ @UCL&Moorfields 4 | # -------------------------------------------------------- 5 | 6 | import math 7 | 8 | def adjust_learning_rate(optimizer, epoch, args): 9 | """Decay the learning rate with half-cycle cosine after warmup""" 10 | if epoch < args.warmup_epochs: 11 | lr = args.lr * epoch / args.warmup_epochs 12 | else: 13 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 14 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 15 | for param_group in optimizer.param_groups: 16 | if "lr_scale" in param_group: 17 | param_group["lr"] = lr * param_group["lr_scale"] 18 | else: 19 | param_group["lr"] = lr 20 | return lr 21 | -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # Partly revised by YZ @UCL&Moorfields 4 | # -------------------------------------------------------- 5 | 6 | import os 7 | from torchvision import datasets, transforms 8 | from timm.data import create_transform 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | 11 | 12 | def build_dataset(is_train, args): 13 | 14 | transform = build_transform(is_train, args) 15 | root = os.path.join(args.data_path, is_train) 16 | dataset = datasets.ImageFolder(root, transform=transform) 17 | 18 | return dataset 19 | 20 | 21 | def build_transform(is_train, args): 22 | mean = IMAGENET_DEFAULT_MEAN 23 | std = IMAGENET_DEFAULT_STD 24 | # train transform 25 | if is_train=='train': 26 | # this should always dispatch to transforms_imagenet_train 27 | transform = create_transform( 28 | input_size=args.input_size, 29 | is_training=True, 30 | color_jitter=args.color_jitter, 31 | auto_augment=args.aa, 32 | interpolation='bicubic', 33 | re_prob=args.reprob, 34 | re_mode=args.remode, 35 | re_count=args.recount, 36 | mean=mean, 37 | std=std, 38 | ) 39 | return transform 40 | 41 | # eval transform 42 | t = [] 43 | if args.input_size <= 224: 44 | crop_pct = 224 / 256 45 | else: 46 | crop_pct = 1.0 47 | size = int(args.input_size / crop_pct) 48 | t.append( 49 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 50 | ) 51 | t.append(transforms.CenterCrop(args.input_size)) 52 | t.append(transforms.ToTensor()) 53 | t.append(transforms.Normalize(mean, std)) 54 | return transforms.Compose(t) 55 | -------------------------------------------------------------------------------- /models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # Partly revised by YZ @UCL&Moorfields 4 | # -------------------------------------------------------- 5 | 6 | from functools import partial 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | import timm.models.vision_transformer 12 | 13 | 14 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 15 | """ Vision Transformer with support for global average pooling 16 | """ 17 | def __init__(self, global_pool=False, **kwargs): 18 | super(VisionTransformer, self).__init__(**kwargs) 19 | 20 | self.global_pool = global_pool 21 | if self.global_pool: 22 | norm_layer = kwargs['norm_layer'] 23 | embed_dim = kwargs['embed_dim'] 24 | self.fc_norm = norm_layer(embed_dim) 25 | 26 | del self.norm # remove the original norm 27 | 28 | def forward_features(self, x): 29 | B = x.shape[0] 30 | x = self.patch_embed(x) 31 | 32 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 33 | x = torch.cat((cls_tokens, x), dim=1) 34 | x = x + self.pos_embed 35 | x = self.pos_drop(x) 36 | 37 | for blk in self.blocks: 38 | x = blk(x) 39 | 40 | if self.global_pool: 41 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 42 | outcome = self.fc_norm(x) 43 | else: 44 | x = self.norm(x) 45 | outcome = x[:, 0] 46 | 47 | return outcome 48 | 49 | 50 | def vit_large_patch16(**kwargs): 51 | model = VisionTransformer( 52 | img_size=224,patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 53 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 54 | return model 55 | 56 | -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # Partly revised by YZ @UCL&Moorfields 4 | # -------------------------------------------------------- 5 | 6 | import json 7 | 8 | 9 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 10 | """ 11 | Parameter groups for layer-wise lr decay 12 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 13 | """ 14 | param_group_names = {} 15 | param_groups = {} 16 | 17 | num_layers = len(model.blocks) + 1 18 | 19 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 20 | 21 | for n, p in model.named_parameters(): 22 | if not p.requires_grad: 23 | continue 24 | 25 | # no decay: all 1D parameters and model specific ones 26 | if p.ndim == 1 or n in no_weight_decay_list: 27 | g_decay = "no_decay" 28 | this_decay = 0. 
29 | else: 30 | g_decay = "decay" 31 | this_decay = weight_decay 32 | 33 | layer_id = get_layer_id_for_vit(n, num_layers) 34 | group_name = "layer_%d_%s" % (layer_id, g_decay) 35 | 36 | if group_name not in param_group_names: 37 | this_scale = layer_scales[layer_id] 38 | 39 | param_group_names[group_name] = { 40 | "lr_scale": this_scale, 41 | "weight_decay": this_decay, 42 | "params": [], 43 | } 44 | param_groups[group_name] = { 45 | "lr_scale": this_scale, 46 | "weight_decay": this_decay, 47 | "params": [], 48 | } 49 | 50 | param_group_names[group_name]["params"].append(n) 51 | param_groups[group_name]["params"].append(p) 52 | 53 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 54 | 55 | return list(param_groups.values()) 56 | 57 | 58 | def get_layer_id_for_vit(name, num_layers): 59 | """ 60 | Assign a parameter with its layer id 61 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 62 | """ 63 | if name in ['cls_token', 'pos_embed']: 64 | return 0 65 | elif name.startswith('patch_embed'): 66 | return 0 67 | elif name.startswith('blocks'): 68 | return int(name.split('.')[1]) + 1 69 | else: 70 | return num_layers -------------------------------------------------------------------------------- /util/data.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | from torch.utils import data 7 | from torchvision import transforms 8 | import glob 9 | import random 10 | 11 | class DS(data.Dataset): 12 | def __init__(self, root, transform=None): 13 | 14 | self.samples = [] 15 | self.labels = [] 16 | 17 | imgs_path_list = [ 18 | "EyePACS/train-FULL/train/", 19 | "DDR/DDR-dataset/DR_grading/train/", 20 | "DDR/DDR-dataset/lesion_segmentation/train/image", 21 | "AIROGS", 22 | "ODIR-5K/odir5k/ODIR-5K/ODIR-5K/Training_Images", 23 | 24 | ] 25 | 26 | for img_path in imgs_path_list: 27 | sample = self.get_image_labels(root + img_path) 28 | print(img_path, len(sample)) 29 | self.samples.extend(sample) 30 | 31 | image_number = len(self.samples) 32 | print(f'Real image number: {image_number}') 33 | 34 | if len(self.samples) == 0: 35 | raise RuntimeError("Found 0 files in subfolders of: " + root) 36 | else: 37 | # print("Val_dataset:", val_data_name) 38 | print("Samples:", len(self.samples)) 39 | 40 | self.transform = transform 41 | 42 | def __len__(self): 43 | return len(self.samples) 44 | 45 | def get_image_labels(self, imgs_path): 46 | samples = [] 47 | for imgs_path, _, fnames in sorted(os.walk(imgs_path)): 48 | for fname in sorted(fnames): 49 | if '.jpg' in fname or '.png' in fname or '.jpeg' in fname or '.JPG' in fname: 50 | path = os.path.join(imgs_path, fname) 51 | samples.append(path) 52 | return samples 53 | 54 | def pad(self, im, padding=64): 55 | h, w = im.shape[-2:] 56 | mh = h % padding 57 | ph = 0 if mh == 0 else padding - mh 58 | mw = w % padding 59 | pw = 0 if mw == 0 else padding - mw 60 | shape = [s for s in im.shape] 61 | shape[-2] += ph 62 | shape[-1] += pw 63 | im_p = np.zeros(shape, dtype=im.dtype) 64 | im_p[..., :h, :w] = im 65 | im = im_p 66 | return im 67 | 68 | def __getitem__(self, index): 69 | 70 | sample_path = self.samples[index] 71 | sample = Image.open(sample_path).convert('RGB') 72 | sample_name = sample_path.split('/')[-1] 73 | 74 | if self.transform is not None: 75 | sample = self.transform(sample) 76 | 77 | # return sample 78 | return sample 79 | 
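For reference, here is a minimal sketch of how the `DS` dataset above is consumed during MAE pre-training; it mirrors the augmentation and loader settings used in `main_pretrain.py`, and the data root is a placeholder you would replace with your own path (it must contain the sub-folders listed in `imgs_path_list`):

```python
import torch
from torchvision import transforms
from util.data import DS

# Augmentation mirroring main_pretrain.py: RandomResizedCrop + flip + ImageNet normalisation.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0), interpolation=3),  # 3 = bicubic
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Placeholder data root; DS concatenates it with each entry of imgs_path_list.
dataset_train = DS("/path/to/retinal_images/", transform_train)

data_loader_train = torch.utils.data.DataLoader(
    dataset_train, batch_size=64, shuffle=True,
    num_workers=10, pin_memory=True, drop_last=True,
)

images = next(iter(data_loader_train))  # one batch of shape [64, 3, 224, 224]
```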
-------------------------------------------------------------------------------- /README_tutorial.md: -------------------------------------------------------------------------------- 1 | # Tutorial for online demo 2 | 3 | Here we provide a detailed introduction to the use of the [online demo](http://fdudml.cn:12001/) for RETFound-DE. 4 | 5 | In the demo, we offer three features: 6 | 7 | - MAE reconstructed images: We show the input and output of MAE self-supervision. 8 | - Diagnostic probability: We present the diagnostic probability for diseases predicted by RETFound-DE in the form of a bar chart. 9 | - Interpretable heatmaps: We display the heatmap of model inference on the input retinal image. 10 | 11 | Below, we provide a detailed introduction to the specific usage process. 12 | 13 | ## Upload Image 14 | 15 | The online demo starts with uploading an image. 16 | 17 | Users can upload their own fundus image by clicking here: 18 | ![](images/tutor_1.png) 19 | 20 | Alternatively, users can use the images we have provided. Click an image name to select it, as shown below: 21 | ![](images/tutor_2.png) 22 | 23 | Finally, click **Check input** to visualize the input image. 24 | ![](images/tutor_3.png) 25 | 26 | ## MAE reconstructed images (Optional) 27 | After uploading an image, you can click **Run** in **Load and Run Pre-Training model** if you want to see the MAE reconstructed images. 28 | 29 | - "Masked" is the masked version of the original image and the input fed into RETFound-DE. 30 | - "Reconstruction" is the output of our self-supervised pre-training model RETFound-DE. 31 | - "Reconstruction + Visible" is the fusion of "Reconstruction" and "Masked". 32 | 33 | ![](images/tutor_4.png) 34 | 35 | ## Diagnostic probability and interpretable heatmaps 36 | 37 | Before you can get the disease diagnosis probability and interpretable heatmaps, you first need to select the disease diagnosis task and model type. 38 | 39 | We provide models for different tasks and datasets, including: 40 | - diabetic retinopathy grading (Kaggle APTOS-2019, IDRiD, MESSIDOR2) 41 | - glaucoma diagnosis (PAPILA, Glaucoma Fundus, ORIGA) 42 | - age-related macular degeneration grading (AREDS) 43 | - multi-disease classification (Retina, JSIEC) 44 | 45 | Each dataset above corresponds to a fine-tuned downstream model. Please note that different downstream task datasets may represent different classification tasks. For example, glaucoma diagnosis is a three-class task on PAPILA but a two-class task on ORIGA. 46 | 47 | ### Step 1: Select the 'Model Type' in 'Load and Run Fine-tuning Model' like this: 48 | ![](images/tutor_5.png) 49 | 50 | ### Step 2: Click 'Load Model' to load the model. This may take a while. 51 | After the model has loaded successfully, you will see the correct model information. 52 | ![](images/tutor_6.png) 53 | 54 | ### Step 3: Click 'Run' to run the model on the input image.
55 | Once you have successfully run the model, the online demo will show a disease diagnostic probability bar chart like this: 56 | ![](images/tutor_7.png) 57 | and a heatmap indicating how the model makes its decision: 58 | ![](images/tutor_8.png) 59 | -------------------------------------------------------------------------------- /README_SD.md: -------------------------------------------------------------------------------- 1 | # Controllable Generative Model Enables Ultra-High Data Efficiency for Building Generalist Medical Foundation Model 2 | 3 | Here we present the information and deployment instructions for the retinal image stable diffusion model v1.4 (ReSDv1.4) used in RETFound-DE. 4 | 5 | ## Basic Information 6 | - Retinal image stable diffusion model v1.4 (ReSDv1.4) is based on stable diffusion v1.4. 7 | - We fine-tuned it for 60,000 iterations on 150k retinal image-text pairs. 8 | - It follows a text-to-image strategy to generate retinal images and takes about 7 seconds per image on an NVIDIA RTX 3090. 9 | - The resolution of the generated retinal images is 512x512. 10 | 11 | ## Prepare the environment 12 | 13 | 1. Download the stable diffusion v1.4 and ReSDv1.4 models 14 | 15 | You can download the stable diffusion v1.4 from [HuggingFace](https://huggingface.co/CompVis/stable-diffusion-v1-4) and ReSDv1.4 from [Zenodo:sd-retina-model](https://zenodo.org/records/10947092) or [baiduDisk code:7n7v](https://pan.baidu.com/s/1TBVNlaR9xW_rqA8ZdrRuOg). 16 | 17 | 18 | 2. Install Diffusers 19 | 20 | Install [Diffusers](https://github.com/huggingface/diffusers) in a virtual environment from PyPI or Conda. 21 | 22 | Please note that the version of Diffusers may influence the deployment. In our experiments, we used Diffusers v0.21.4. 23 | 24 | ## Inference 25 | 26 | ``` 27 | import os 28 | from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel 29 | from transformers import CLIPTextModel 30 | 31 | my_model_path = "ReSDv1.4 path, e.g., /home/user/sd-retinal-model/checkpoint-60000/ " 32 | pre_trained_model = "stable diffusion v1.4 path, e.g., /home/user/stable-diffusion-v1-4 " 33 | text_encoder = CLIPTextModel.from_pretrained(pre_trained_model, subfolder="text_encoder") 34 | vae = AutoencoderKL.from_pretrained(pre_trained_model, subfolder="vae") 35 | 36 | unet = UNet2DConditionModel.from_pretrained(my_model_path, subfolder="unet") 37 | 38 | pipe = StableDiffusionPipeline.from_pretrained( 39 | pre_trained_model, 40 | text_encoder=text_encoder, 41 | vae=vae, 42 | unet=unet, 43 | ) 44 | 45 | pipe.to("cuda") 46 | 47 | # define your prompt 48 | prompt = "No Diabetic Retinopathy" 49 | 50 | image = pipe(prompt=prompt).images[0] 51 | image.save("test.png") 52 | ``` 53 | 54 | We present some disease prompts here; for more information, please refer to our paper: 55 | ``` 56 | # disease prompts for retinal image generation 57 | 58 | Normal fundus 59 | No referable glaucoma 60 | Referable glaucoma 61 | No Diabetic Retinopathy 62 | Diabetic Retinopathy 63 | Mild Non-Proliferative Diabetic Retinopathy 64 | Moderate Non-Proliferative Diabetic Retinopathy 65 | Severe Non-Proliferative Diabetic Retinopathy 66 | Proliferative Diabetic Retinopathy 67 | ``` 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | Please contact **sunyuqi387@gmail.com** if you have questions.
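As a further illustration, below is a minimal, hedged sketch of generating one image per disease prompt from the list above, reusing the `pipe` object built in the Inference section; the output directory and random seed are arbitrary choices, not values from the paper:

```python
import os
import torch

# Assumes `pipe` is the StableDiffusionPipeline constructed in the Inference section above.
prompts = [
    "Normal fundus",
    "No referable glaucoma",
    "Referable glaucoma",
    "No Diabetic Retinopathy",
    "Diabetic Retinopathy",
    "Mild Non-Proliferative Diabetic Retinopathy",
    "Moderate Non-Proliferative Diabetic Retinopathy",
    "Severe Non-Proliferative Diabetic Retinopathy",
    "Proliferative Diabetic Retinopathy",
]

os.makedirs("generated", exist_ok=True)                     # arbitrary output directory
generator = torch.Generator(device="cuda").manual_seed(42)  # arbitrary seed for reproducibility

for prompt in prompts:
    image = pipe(prompt=prompt, generator=generator).images[0]  # 512x512 retinal image
    image.save(os.path.join("generated", prompt.replace(" ", "_") + ".png"))
```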
79 | -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import math 12 | import sys 13 | from typing import Iterable 14 | 15 | import torch 16 | 17 | import util.misc as misc 18 | import util.lr_sched as lr_sched 19 | 20 | 21 | def train_one_epoch(model: torch.nn.Module, 22 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 23 | device: torch.device, epoch: int, loss_scaler, 24 | log_writer=None, 25 | args=None): 26 | model.train(True) 27 | metric_logger = misc.MetricLogger(delimiter=" ") 28 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 29 | header = 'Epoch: [{}]'.format(epoch) 30 | print_freq = 20 31 | 32 | accum_iter = args.accum_iter 33 | 34 | optimizer.zero_grad() 35 | 36 | if log_writer is not None: 37 | print('log_dir: {}'.format(log_writer.log_dir)) 38 | 39 | for data_iter_step, samples in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 40 | 41 | # we use a per iteration (instead of per epoch) lr scheduler 42 | if data_iter_step % accum_iter == 0: 43 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 44 | 45 | samples = samples.to(device, non_blocking=True) 46 | 47 | with torch.cuda.amp.autocast(): 48 | loss, _, _ = model(samples, mask_ratio=args.mask_ratio) 49 | 50 | loss_value = loss.item() 51 | 52 | if not math.isfinite(loss_value): 53 | print("Loss is {}, stopping training".format(loss_value)) 54 | # sys.exit(1) 55 | continue 56 | 57 | loss /= accum_iter 58 | loss_scaler(loss, optimizer, parameters=model.parameters(), 59 | update_grad=(data_iter_step + 1) % accum_iter == 0) 60 | if (data_iter_step + 1) % accum_iter == 0: 61 | optimizer.zero_grad() 62 | 63 | torch.cuda.synchronize() 64 | 65 | metric_logger.update(loss=loss_value) 66 | 67 | lr = optimizer.param_groups[0]["lr"] 68 | metric_logger.update(lr=lr) 69 | 70 | loss_value_reduce = misc.all_reduce_mean(loss_value) 71 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 72 | """ We use epoch_1000x as the x-axis in tensorboard. 73 | This calibrates different curves when batch size changes. 74 | """ 75 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 76 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 77 | log_writer.add_scalar('lr', lr, epoch_1000x) 78 | 79 | 80 | # gather the stats from all processes 81 | metric_logger.synchronize_between_processes() 82 | print("Averaged stats:", metric_logger) 83 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # Partly revised by YZ @UCL&Moorfields 4 | # -------------------------------------------------------- 5 | 6 | import numpy as np 7 | 8 | import torch 9 | 10 | # -------------------------------------------------------- 11 | # 2D sine-cosine position embedding 12 | # References: 13 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 14 | # MoCo v3: https://github.com/facebookresearch/moco-v3 15 | # -------------------------------------------------------- 16 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 17 | """ 18 | grid_size: int of the grid height and width 19 | return: 20 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 21 | """ 22 | grid_h = np.arange(grid_size, dtype=np.float32) 23 | grid_w = np.arange(grid_size, dtype=np.float32) 24 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 25 | grid = np.stack(grid, axis=0) 26 | 27 | grid = grid.reshape([2, 1, grid_size, grid_size]) 28 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 29 | if cls_token: 30 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 31 | return pos_embed 32 | 33 | 34 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 35 | assert embed_dim % 2 == 0 36 | 37 | # use half of dimensions to encode grid_h 38 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 39 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 40 | 41 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 42 | return emb 43 | 44 | 45 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 46 | """ 47 | embed_dim: output dimension for each position 48 | pos: a list of positions to be encoded: size (M,) 49 | out: (M, D) 50 | """ 51 | assert embed_dim % 2 == 0 52 | omega = np.arange(embed_dim // 2, dtype=np.float32) 53 | omega /= embed_dim / 2. 54 | omega = 1. 
/ 10000**omega # (D/2,) 55 | 56 | pos = pos.reshape(-1) # (M,) 57 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 58 | 59 | emb_sin = np.sin(out) # (M, D/2) 60 | emb_cos = np.cos(out) # (M, D/2) 61 | 62 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 63 | return emb 64 | 65 | 66 | # -------------------------------------------------------- 67 | # Interpolate position embeddings for high-resolution 68 | # References: 69 | # DeiT: https://github.com/facebookresearch/deit 70 | # -------------------------------------------------------- 71 | def interpolate_pos_embed(model, checkpoint_model): 72 | if 'pos_embed' in checkpoint_model: 73 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 74 | embedding_size = pos_embed_checkpoint.shape[-1] 75 | num_patches = model.patch_embed.num_patches 76 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 77 | # height (== width) for the checkpoint position embedding 78 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 79 | # height (== width) for the new position embedding 80 | new_size = int(num_patches ** 0.5) 81 | # class_token and dist_token are kept unchanged 82 | if orig_size != new_size: 83 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 84 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 85 | # only the position tokens are interpolated 86 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 87 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 88 | pos_tokens = torch.nn.functional.interpolate( 89 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 90 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 91 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 92 | checkpoint_model['pos_embed'] = new_pos_embed 93 | -------------------------------------------------------------------------------- /submitit_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_pretrain as trainer 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | trainer_parser = trainer.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE pretrain", parents=[trainer_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. 
Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("checkpoint/").is_dir(): 36 | p = Path(f"checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | path = Path('/cpfs01/projects-HDD/neikiuyiliaodamoxing_HDD/sunyuqi/code/RETFound_MAE/') / get_shared_folder() 46 | init_file = path / f"{uuid.uuid4().hex}_init" 47 | if init_file.exists(): 48 | os.remove(str(init_file)) 49 | 50 | return init_file 51 | 52 | class Trainer(object): 53 | def __init__(self, args): 54 | self.args = args 55 | 56 | def __call__(self): 57 | import main_pretrain as trainer 58 | 59 | self._setup_gpu_args() 60 | trainer.main(self.args) 61 | 62 | def checkpoint(self): 63 | import os 64 | import submitit 65 | 66 | self.args.dist_url = get_init_file().as_uri() 67 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 68 | if os.path.exists(checkpoint_file): 69 | self.args.resume = checkpoint_file 70 | print("Requeuing ", self.args) 71 | empty_trainer = type(self)(self.args) 72 | return submitit.helpers.DelayedSubmission(empty_trainer) 73 | 74 | def _setup_gpu_args(self): 75 | import submitit 76 | from pathlib import Path 77 | 78 | job_env = submitit.JobEnvironment() 79 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 80 | self.args.log_dir = self.args.output_dir 81 | self.args.gpu = job_env.local_rank 82 | self.args.rank = job_env.global_rank 83 | self.args.world_size = job_env.num_tasks 84 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 85 | 86 | 87 | def main(): 88 | args = parse_args() 89 | print(args.job_dir) 90 | if args.job_dir == "": 91 | args.job_dir = get_shared_folder() / "%j" 92 | 93 | 94 | # Note that the folder will depend on the job_id, to easily track experiments 95 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 96 | 97 | num_gpus_per_node = args.ngpus 98 | nodes = args.nodes 99 | timeout_min = args.timeout 100 | 101 | partition = args.partition 102 | kwargs = {} 103 | if args.use_volta32: 104 | kwargs['slurm_constraint'] = 'volta32gb' 105 | if args.comment: 106 | kwargs['slurm_comment'] = args.comment 107 | 108 | executor.update_parameters( 109 | mem_gb=40 * num_gpus_per_node, 110 | gpus_per_node=num_gpus_per_node, 111 | tasks_per_node=num_gpus_per_node, # one task per GPU 112 | cpus_per_task=10, 113 | nodes=nodes, 114 | timeout_min=timeout_min, # max is 60 * 72 115 | # Below are cluster dependent parameters 116 | slurm_partition=partition, 117 | slurm_signal_delay_s=120, 118 | **kwargs 119 | ) 120 | 121 | executor.update_parameters(name="mae") 122 | 123 | args.dist_url = get_init_file().as_uri() 124 | args.output_dir = args.job_dir 125 | 126 | trainer = Trainer(args) 127 | job = executor.submit(trainer) 128 | 129 | # print("Submitted job_id:", job.job_id) 130 | print(job.job_id) 131 | 132 | output = job.result() 133 | 134 | if __name__ == "__main__": 135 | 
main() 136 | -------------------------------------------------------------------------------- /Example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b2e049c7-d5db-45e6-b651-2601c02f4b7d", 6 | "metadata": {}, 7 | "source": [ 8 | "## Data organisation example - IDRiD" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 34, 14 | "id": "16b65740-249b-4eef-9298-1db01f72d050", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import os\n", 19 | "import shutil\n", 20 | "import pickle\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "from sklearn.model_selection import train_test_split\n", 24 | "\n", 25 | "# replace with your own data path to save IDRiD datasets\n", 26 | "DATAPATH = './datasets/IDRiD/B_Disease_Grading/1__Original_Images/'" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "b12bad44", 32 | "metadata": {}, 33 | "source": [] 34 | }, 35 | { 36 | "attachments": {}, 37 | "cell_type": "markdown", 38 | "id": "ff0bf26e-c657-49de-8761-89d5a94c390d", 39 | "metadata": {}, 40 | "source": [ 41 | "### Split val set from train data\n", 42 | "- Download dataset from [official website](https://ieee-dataport.org/open-access/indian-diabetic-retinopathy-image-dataset-idrid) " 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 35, 48 | "id": "4bc1cb67-0adf-4640-8640-d0740a39366b", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "list_ = pd.read_csv('a__IDRiD_Disease_Grading_Training_Labels_csv')" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 36, 58 | "id": "b85fc0d1-2049-4550-bdec-76240b1bc759", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "noDR = list_.loc[list_['Retinopathy grade']==0, 'Image name']\n", 63 | "mildDR = list_.loc[list_['Retinopathy grade']==1, 'Image name']\n", 64 | "moderateDR = list_.loc[list_['Retinopathy grade']==2, 'Image name']\n", 65 | "severeDR = list_.loc[list_['Retinopathy grade']==3, 'Image name']\n", 66 | "proDR = list_.loc[list_['Retinopathy grade']==4, 'Image name']" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 37, 72 | "id": "d0617e35-8b91-45d3-90d5-d5e5bf2d7762", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "noDR_train, noDR_val = train_test_split(noDR, test_size=0.2,random_state=1)\n", 77 | "mildDR_train, mildDR_val = train_test_split(mildDR, test_size=0.2,random_state=1)\n", 78 | "moderateDR_train, moderateDR_val = train_test_split(moderateDR, test_size=0.2,random_state=1)\n", 79 | "severeDR_train, severeDR_val = train_test_split(severeDR, test_size=0.2,random_state=1)\n", 80 | "proDR_train, proDR_val = train_test_split(proDR, test_size=0.2,random_state=1)\n" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 38, 86 | "id": "f30ce03f-5730-4e68-b6c5-8e1b6b9167f8", 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "{'img_root': './datasets/IDRiD/B_Disease_Grading/1__Original_Images/a__Training_Set\\\\IDRiD_178.jpg', 'label': 4}\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "train_list = [noDR_train, mildDR_train, moderateDR_train, severeDR_train, proDR_train]\n", 99 | "for idx, disease in enumerate(train_list):\n", 100 | " data = [{'img_root': os.path.join(DATAPATH, 'a__Training_Set', value+'.jpg'), 'label': idx} for value in disease]\n", 101 | "print(data[0])\n", 102 | "save_path = 
'data/IDRiD'\n", 103 | "os.makedirs(save_path, exist_ok=True)\n", 104 | "with open(os.path.join(save_path, 'train.pkl') , 'wb') as file:\n", 105 | " pickle.dump(np.array(data), file)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 39, 111 | "id": "196d1845-3e5e-4d38-82e5-66057a693962", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": [ 118 | "{'img_root': './datasets/IDRiD/B_Disease_Grading/1__Original_Images/a__Training_Set\\\\IDRiD_100.jpg', 'label': 4}\n" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "val_list = [noDR_val, mildDR_val, moderateDR_val, severeDR_val, proDR_val]\n", 124 | "for idx, disease in enumerate(val_list):\n", 125 | " data = [{'img_root': os.path.join(DATAPATH, 'a__Training_Set', value+'.jpg'), 'label': idx} for value in disease]\n", 126 | "print(data[0])\n", 127 | "with open(os.path.join(save_path, 'val.pkl') , 'wb') as file:\n", 128 | " pickle.dump(np.array(data), file)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "faf285f4-9079-49ca-9d99-8f3f5718afbf", 134 | "metadata": {}, 135 | "source": [ 136 | "### Organise test set" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 40, 142 | "id": "118d15d0-9e94-4f6e-855d-dfa3796b24d2", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "list_test = pd.read_csv('b__IDRiD_Disease_Grading_Testing_Labels_csv')" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 41, 152 | "id": "89a098fe-0aad-41d4-ab09-476ff0354c77", 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "noDR_test = list_test.loc[list_test['Retinopathy grade']==0, 'Image name']\n", 157 | "mildDR_test = list_test.loc[list_test['Retinopathy grade']==1, 'Image name']\n", 158 | "moderateDR_test = list_test.loc[list_test['Retinopathy grade']==2, 'Image name']\n", 159 | "severeDR_test = list_test.loc[list_test['Retinopathy grade']==3, 'Image name']\n", 160 | "proDR_test = list_test.loc[list_test['Retinopathy grade']==4, 'Image name']" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 42, 166 | "id": "33a207c1-1fef-4e79-8ff2-84329062495b", 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "{'img_root': './datasets/IDRiD/B_Disease_Grading/1__Original_Images/b__Testing_Set\\\\IDRiD_001.jpg', 'label': 4}\n" 174 | ] 175 | } 176 | ], 177 | "source": [ 178 | "test_list = [noDR_test, mildDR_test, moderateDR_test, severeDR_test, proDR_test]\n", 179 | "for idx, disease in enumerate(test_list):\n", 180 | " data = [{'img_root': os.path.join(DATAPATH, 'b__Testing_Set', value+'.jpg'), 'label': idx} for value in disease]\n", 181 | "print(data[0])\n", 182 | "with open(os.path.join(save_path, 'test.pkl') , 'wb') as file:\n", 183 | " pickle.dump(np.array(data), file)" 184 | ] 185 | } 186 | ], 187 | "metadata": { 188 | "environment": { 189 | "kernel": "python3", 190 | "name": "common-cu110.m91", 191 | "type": "gcloud", 192 | "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91" 193 | }, 194 | "kernelspec": { 195 | "display_name": "Python 3", 196 | "language": "python", 197 | "name": "python3" 198 | }, 199 | "language_info": { 200 | "codemirror_mode": { 201 | "name": "ipython", 202 | "version": 3 203 | }, 204 | "file_extension": ".py", 205 | "mimetype": "text/x-python", 206 | "name": "python", 207 | "nbconvert_exporter": "python", 208 | "pygments_lexer": "ipython3", 209 | 
"version": "3.8.16" 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 5 214 | } 215 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Controllable Generative Model Enables High Data Efficiency for Building Medical Foundation Model 2 | 3 | RETFound-DE is a medical foundation model from retinal images that enables high data efficiency. 4 | 5 | This is the official repo for RETFound-DE, which is based on [MAE](https://github.com/facebookresearch/mae) and [RETFound](https://github.com/rmaphoh/RETFound_MAE/tree/main) (Y. Zhou et al, Nature 2023): 6 | 7 | ## News 8 | - [x] Release the code of RETFound-DE 9 | - [x] Release the pre-training model and fine-tuning models of RETFound-DE 10 | - [x] Release the stable diffusion model of RETFound-DE 11 | 12 | ## Key features 13 | 14 | - **Ultra-High Data Efficiency:** RETFound-DE enables ultra-high data efficiency and only uses 16.7% of the colour fundus photography retinal image required in RETFound. 15 | - **Excellent performance:** Extensive experiments on nine datasets across four ocular disease detection tasks demonstrate the excellent performance of RETFound-DE in improving the detection of eye diseases, label and fine-tuning time efficiency. 16 | - **Transferable:** RETFound-DE provides an effective solution for other diseases that were once discouraged from building foundation models due to limited data, which has profound significance for generalist medical AI. 17 | 18 | 23 | 24 | ## Prepare the environment 25 | 26 | 1. Download the pre-training and fine-tuning model 27 | 28 | You can download the pre-training model and fine-tuning models from [Zenodo](https://zenodo.org/records/13340936) or [baiduDisk code:7n7v ](https://pan.baidu.com/s/1TBVNlaR9xW_rqA8ZdrRuOg) and the example images named exampledata.zip from [here](https://drive.google.com/file/d/1f1Lmdtf1LELYpKpthEawastSJWXEWIWb/view?usp=drive_link). Then, you can unzip the file and put the folder `exampledata` and `checkpoint` in the root directory of RETFound-DE. 29 | 30 | ``` 31 | exampledata/ 32 | AMD/ 33 | DR/ 34 | Glaucoma/ 35 | Multi-disease/ 36 | 37 | checkpoint/ 38 | AMD_AREDS/ 39 | DR_APTOS2019/ 40 | DR_IDRID/ 41 | DR_MESSIDOR2/ 42 | Glaucoma_Glaucoma_Fundus/ 43 | Glaucoma_ORIGA/ 44 | Glaucoma_PAPILA/ 45 | Multi_JSIEC/ 46 | Multi_Retina/ 47 | PreTraining/ 48 | ``` 49 | 50 | 1. Install enviroment 51 | 52 | Create enviroment with conda: 53 | 54 | ``` 55 | conda create -n RETFound-DE python=3.8 -y 56 | conda activate RETFound-DE 57 | ``` 58 | Install Pytorch 1.13 (cuda 11.7) 59 | ``` 60 | pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 -f https://download.pytorch.org/whl/torch_stable.html 61 | ``` 62 | 63 | Install others 64 | ``` 65 | git clone https://github.com/Jonlysun/RETFound-DE/ 66 | cd RETFound-DE 67 | pip install -r requirement.txt 68 | ``` 69 | If you have the following error: 70 | ``` 71 | ImportError: cannot import name 'container_abcs' from 'torch._six' 72 | ``` 73 | please refer to the solution in [here](https://github.com/huggingface/pytorch-image-models/issues/420). 74 | 75 | ## Offline Demo 76 | ### User Interface for RETFound-DE 77 | 78 | You can run the web interface locally by the following command: 79 | ``` 80 | python app.py 81 | ``` 82 | 83 | Then, you can visit the web interface at [http://127.0.0.1:7891/](http://127.0.0.1:7860/). You can upload your own image or use our examples to run RETFound-DE. 
84 | 85 | ### Visualize with code 86 | We also provide `visualize.py` to generate the **MAE reconstructed images**, **diagnostic probability** and **interpretable heatmaps**. You can run the following commands: 87 | ``` 88 | # MAE reconstructed images. Result is saved as 'mae.png' 89 | python visualize.py --mode mae --img_path XXXX 90 | 91 | # Diagnostic probability. Result is saved as 'classification.png' 92 | python visualize.py --mode classification --img_path XXXX --ft_model XXXX (e.g., DR_APTOS2019) 93 | 94 | # Interpretable heatmaps. Result is saved as 'cam.png' 95 | python visualize.py --mode cam --img_path XXXX --ft_model XXXX (e.g., DR_APTOS2019) 96 | ``` 97 | 98 | ## Evaluate or fine-tune RETFound-DE 99 | ### 1. Prepare the datasets 100 | - First, download the public datasets following the URLs in `Data availability` in our paper. 101 | - Then, split each dataset into train, val, and test sets following Supplementary Table 1 in our paper. 102 | - Finally, generate 'train.pkl', 'val.pkl', and 'test.pkl' files containing the 'img_root' and 'label' information for each dataset (using IDRiD as an example). 103 | 104 | We use IDRiD as an [example](Example.ipynb). 105 | ``` 106 | data/ 107 | IDRiD/ 108 | train.pkl 109 | val.pkl 110 | test.pkl 111 | ``` 112 | If you want to follow the same split as in our paper, you can download the '.pkl' files from [here](https://drive.google.com/file/d/1lMYGntHw9H9XsPxelfNHTrZG4z5pqKo3/view?usp=drive_link) and put `data` in the root directory. Also, you may need to post-process these files with your own paths and replace the `train_data_dir` in main_finetune.py with your own path. 113 | 114 | ### 2. Evaluation 115 | You can use the following command or run 'bash main_evaluation.sh'. Please remember to replace the root path with your own dataset path. 116 | ``` 117 | # choose the dataset 118 | DATASET='DR_APTOS2019' 119 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=48797 main_finetune.py \ 120 | --eval --batch_size 16 \ 121 | --world_size 1 \ 122 | --model vit_large_patch16 \ 123 | --epochs 50 \ 124 | --blr 5e-3 --layer_decay 0.65 \ 125 | --weight_decay 0.05 --drop_path 0.2 \ 126 | --nb_classes 5 \ 127 | --root YOUR_OWN_DATASET_PATH \ 128 | --task ./Results/internal_$DATASET/ \ 129 | --resume ./checkpoint/$DATASET/checkpoint-best.pth \ 130 | --dataset_name $DATASET 131 | ``` 132 | ### 3. Fine-tuning 133 | You can use the following command or run 'bash main_finetune.sh'. Please remember to replace the root path with your own dataset path. 134 | ``` 135 | # choose the dataset 136 | DATASET='DR_APTOS2019' 137 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=40003 main_finetune.py \ 138 | --batch_size 64 \ 139 | --world_size 1 \ 140 | --model vit_large_patch16 \ 141 | --epochs 50 \ 142 | --blr 5e-3 --layer_decay 0.65 \ 143 | --weight_decay 0.05 --drop_path 0.2 \ 144 | --root YOUR_OWN_DATASET_PATH \ 145 | --task ./Results/$DATASET/ \ 146 | --dataset_name $DATASET \ 147 | --finetune ./checkpoint/PreTraining/checkpoint-best.pth 148 | 149 | ``` 150 | 151 | ## Pre-Training 152 | You can use the following command or run 'bash main_pretrain.sh'. Please remember to replace the root path with your own dataset path. You can download `mae_pretrain_vit_large.pth` from the official repo of [MAE](https://github.com/facebookresearch/mae).
153 | ``` 154 | IMAGE_DIR='YOUR_IMAGE_DIR' 155 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=48797 main_pretrain.py \ 156 | --batch_size 224 \ 157 | --model mae_vit_large_patch16 \ 158 | --norm_pix_loss \ 159 | --mask_ratio 0.75 \ 160 | --epochs 200 \ 161 | --warmup_epochs 20 \ 162 | --blr 1.5e-4 --weight_decay 0.05 \ 163 | --data_path ${IMAGE_DIR} \ 164 | --task './RETFound-DE/' \ 165 | --output_dir './RETFound-DE_log/' \ 166 | --resume ./mae_pretrain_vit_large.pth \ 167 | ``` 168 | 169 | ## Retinal Image Stable Diffusion Model 170 | For detailed information about the retinal image stable diffusion model, please refer to [README_SD.md](README_SD.md). 171 | 172 | 173 | ## Additional results on Chest X-ray images 174 | Following our data-efficient framework, we conducted additional experiments on Chest X-ray (CXR) images to further demonstrate the potential of our framework to extend to other medical fields. We present the pretrained CXR foundation model [here](https://zenodo.org/records/13340936) (ChestX_Pretraining.zip), which was pretrained on 20k real and 80k synthetic CXR images. 175 | 176 | For downstream tasks, we provide two fine-tuned models [here](https://zenodo.org/records/13340936) (ChestX_Shenzhen.zip, ChestX_TBChest.zip) to show the performance of the foundation model on tuberculosis. 177 | 178 | Please follow the pipeline above and our paper to evaluate the performance on Chest X-ray images. 179 | 180 | 181 | Please contact **sunyuqi387@gmail.com** if you have questions. 182 | -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import numpy as np 15 | import os 16 | import time 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.backends.cudnn as cudnn 21 | from torch.utils.tensorboard import SummaryWriter 22 | import torchvision.transforms as transforms 23 | import torchvision.datasets as datasets 24 | 25 | import timm 26 | 27 | assert timm.__version__ == "0.3.2" # version check 28 | import timm.optim.optim_factory as optim_factory 29 | 30 | import util.misc as misc 31 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 32 | from util.data import DS 33 | 34 | import models_mae 35 | 36 | from engine_pretrain import train_one_epoch 37 | 38 | 39 | def get_args_parser(): 40 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 41 | parser.add_argument('--batch_size', default=64, type=int, 42 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 43 | parser.add_argument('--epochs', default=400, type=int) 44 | parser.add_argument('--accum_iter', default=1, type=int, 45 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 46 | 47 | # Model parameters 48 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 49 | help='Name of model to train') 50 | 51 | parser.add_argument('--input_size', default=224, type=int, 52 | help='images input size') 53 | 54 | parser.add_argument('--mask_ratio', default=0.75, type=float, 55 | help='Masking ratio (percentage of removed patches).') 56 | 57 | parser.add_argument('--norm_pix_loss', action='store_true', 58 | help='Use (per-patch) normalized pixels as targets for computing loss') 59 | parser.set_defaults(norm_pix_loss=False) 60 | 61 | # Optimizer parameters 62 | parser.add_argument('--weight_decay', type=float, default=0.05, 63 | help='weight decay (default: 0.05)') 64 | 65 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 66 | help='learning rate (absolute lr)') 67 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 68 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 69 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 70 | help='lower lr bound for cyclic schedulers that hit 0') 71 | 72 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 73 | help='epochs to warmup LR') 74 | 75 | # Dataset parameters 76 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 77 | help='dataset path') 78 | 79 | parser.add_argument('--output_dir', default='./output_dir', 80 | help='path where to save, empty for no saving') 81 | parser.add_argument('--log_dir', default='./output_dir', 82 | help='path where to tensorboard log') 83 | parser.add_argument('--device', default='cuda', 84 | help='device to use for training / testing') 85 | parser.add_argument('--seed', default=0, type=int) 86 | parser.add_argument('--resume', default='', 87 | help='resume from checkpoint') 88 | parser.add_argument('--task', default='',type=str, 89 | help='finetune from checkpoint') 90 | 91 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 92 | help='start 
epoch') 93 | parser.add_argument('--num_workers', default=10, type=int) 94 | parser.add_argument('--pin_mem', action='store_true', 95 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 96 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 97 | parser.set_defaults(pin_mem=True) 98 | 99 | # distributed training parameters 100 | parser.add_argument('--world_size', default=1, type=int, 101 | help='number of distributed processes') 102 | parser.add_argument('--local_rank', default=-1, type=int) 103 | parser.add_argument('--dist_on_itp', action='store_true') 104 | parser.add_argument('--dist_url', default='env://', 105 | help='url used to set up distributed training') 106 | 107 | return parser 108 | 109 | 110 | def main(args): 111 | misc.init_distributed_mode(args) 112 | 113 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 114 | print("{}".format(args).replace(', ', ',\n')) 115 | 116 | device = torch.device(args.device) 117 | 118 | # fix the seed for reproducibility 119 | seed = args.seed + misc.get_rank() 120 | torch.manual_seed(seed) 121 | np.random.seed(seed) 122 | 123 | cudnn.benchmark = True 124 | 125 | # simple augmentation 126 | transform_train = transforms.Compose([ 127 | transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 128 | transforms.RandomHorizontalFlip(), 129 | transforms.ToTensor(), 130 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 131 | 132 | dataset_train = DS(args.data_path, transform_train) 133 | 134 | if True: # args.distributed: 135 | num_tasks = misc.get_world_size() 136 | global_rank = misc.get_rank() 137 | sampler_train = torch.utils.data.DistributedSampler( 138 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 139 | ) 140 | print("Sampler_train = %s" % str(sampler_train)) 141 | else: 142 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 143 | 144 | if global_rank == 0 and args.log_dir is not None: 145 | os.makedirs(args.log_dir, exist_ok=True) 146 | log_writer = SummaryWriter(log_dir=args.log_dir) 147 | else: 148 | log_writer = None 149 | 150 | data_loader_train = torch.utils.data.DataLoader( 151 | dataset_train, sampler=sampler_train, 152 | batch_size=args.batch_size, 153 | num_workers=args.num_workers, 154 | pin_memory=args.pin_mem, 155 | drop_last=True, 156 | ) 157 | 158 | # define the model 159 | model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss) 160 | 161 | model.to(device) 162 | 163 | model_without_ddp = model 164 | print("Model = %s" % str(model_without_ddp)) 165 | 166 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 167 | 168 | if args.lr is None: # only base_lr is specified 169 | args.lr = args.blr * eff_batch_size / 256 170 | 171 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 172 | print("actual lr: %.2e" % args.lr) 173 | 174 | print("accumulate grad iterations: %d" % args.accum_iter) 175 | print("effective batch size: %d" % eff_batch_size) 176 | 177 | if args.distributed: 178 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 179 | model_without_ddp = model.module 180 | 181 | # following timm: set wd as 0 for bias and norm layers 182 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 183 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 184 | print(optimizer) 185 | loss_scaler 
= NativeScaler() 186 | 187 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 188 | 189 | print(f"Start training for {args.epochs} epochs") 190 | start_time = time.time() 191 | for epoch in range(args.start_epoch, args.epochs): 192 | if args.distributed: 193 | data_loader_train.sampler.set_epoch(epoch) 194 | train_stats = train_one_epoch( 195 | model, data_loader_train, 196 | optimizer, device, epoch, loss_scaler, 197 | log_writer=log_writer, 198 | args=args 199 | ) 200 | if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs): 201 | misc.save_model( 202 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 203 | loss_scaler=loss_scaler, epoch=epoch) 204 | 205 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 206 | 'epoch': epoch,} 207 | 208 | if args.output_dir and misc.is_main_process(): 209 | if log_writer is not None: 210 | log_writer.flush() 211 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 212 | f.write(json.dumps(log_stats) + "\n") 213 | 214 | total_time = time.time() - start_time 215 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 216 | print('Training time {}'.format(total_time_str)) 217 | 218 | 219 | if __name__ == '__main__': 220 | args = get_args_parser() 221 | args = args.parse_args() 222 | if args.output_dir: 223 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 224 | main(args) -------------------------------------------------------------------------------- /models_mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from timm.models.vision_transformer import PatchEmbed, Block 10 | 11 | from util.pos_embed import get_2d_sincos_pos_embed 12 | 13 | 14 | class MaskedAutoencoderViT(nn.Module): 15 | """ Masked Autoencoder with VisionTransformer backbone 16 | """ 17 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 18 | embed_dim=1024, depth=24, num_heads=16, 19 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 20 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 21 | super().__init__() 22 | 23 | # -------------------------------------------------------------------------- 24 | # MAE encoder specifics 25 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 26 | num_patches = self.patch_embed.num_patches 27 | 28 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 29 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 30 | 31 | self.blocks = nn.ModuleList([ 32 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 33 | for i in range(depth)]) 34 | self.norm = norm_layer(embed_dim) 35 | # -------------------------------------------------------------------------- 36 | 37 | # -------------------------------------------------------------------------- 38 | # MAE decoder specifics 39 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 40 | 41 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 42 | 43 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 44 | 45 | self.decoder_blocks = 
nn.ModuleList([ 46 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 47 | for i in range(decoder_depth)]) 48 | 49 | self.decoder_norm = norm_layer(decoder_embed_dim) 50 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch 51 | # -------------------------------------------------------------------------- 52 | 53 | self.norm_pix_loss = norm_pix_loss 54 | 55 | self.initialize_weights() 56 | 57 | def initialize_weights(self): 58 | # initialization 59 | # initialize (and freeze) pos_embed by sin-cos embedding 60 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 61 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 62 | 63 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 64 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 65 | 66 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 67 | w = self.patch_embed.proj.weight.data 68 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 69 | 70 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 71 | torch.nn.init.normal_(self.cls_token, std=.02) 72 | torch.nn.init.normal_(self.mask_token, std=.02) 73 | 74 | # initialize nn.Linear and nn.LayerNorm 75 | self.apply(self._init_weights) 76 | 77 | def _init_weights(self, m): 78 | if isinstance(m, nn.Linear): 79 | # we use xavier_uniform following official JAX ViT: 80 | torch.nn.init.xavier_uniform_(m.weight) 81 | if isinstance(m, nn.Linear) and m.bias is not None: 82 | nn.init.constant_(m.bias, 0) 83 | elif isinstance(m, nn.LayerNorm): 84 | nn.init.constant_(m.bias, 0) 85 | nn.init.constant_(m.weight, 1.0) 86 | 87 | def patchify(self, imgs): 88 | """ 89 | imgs: (N, 3, H, W) 90 | x: (N, L, patch_size**2 *3) 91 | """ 92 | p = self.patch_embed.patch_size[0] 93 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 94 | 95 | h = w = imgs.shape[2] // p 96 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 97 | x = torch.einsum('nchpwq->nhwpqc', x) 98 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 99 | return x 100 | 101 | def unpatchify(self, x): 102 | """ 103 | x: (N, L, patch_size**2 *3) 104 | imgs: (N, 3, H, W) 105 | """ 106 | p = self.patch_embed.patch_size[0] 107 | h = w = int(x.shape[1]**.5) 108 | assert h * w == x.shape[1] 109 | 110 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 111 | x = torch.einsum('nhwpqc->nchpwq', x) 112 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 113 | return imgs 114 | 115 | def random_masking(self, x, mask_ratio): 116 | """ 117 | Perform per-sample random masking by per-sample shuffling. 118 | Per-sample shuffling is done by argsort random noise. 
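For example, a 224x224 input with 16x16 patches gives L = 196 tokens; with mask_ratio=0.75, len_keep = int(196 * 0.25) = 49 patches are kept per sample.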
119 | x: [N, L, D], sequence 120 | """ 121 | N, L, D = x.shape # batch, length, dim 122 | len_keep = int(L * (1 - mask_ratio)) 123 | 124 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 125 | 126 | # sort noise for each sample 127 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 128 | ids_restore = torch.argsort(ids_shuffle, dim=1) 129 | 130 | # keep the first subset 131 | ids_keep = ids_shuffle[:, :len_keep] 132 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 133 | 134 | # generate the binary mask: 0 is keep, 1 is remove 135 | mask = torch.ones([N, L], device=x.device) 136 | mask[:, :len_keep] = 0 137 | # unshuffle to get the binary mask 138 | mask = torch.gather(mask, dim=1, index=ids_restore) 139 | 140 | return x_masked, mask, ids_restore 141 | 142 | def forward_encoder(self, x, mask_ratio): 143 | # embed patches 144 | x = self.patch_embed(x) 145 | 146 | # add pos embed w/o cls token 147 | x = x + self.pos_embed[:, 1:, :] 148 | 149 | # masking: length -> length * mask_ratio 150 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 151 | 152 | # append cls token 153 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 154 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 155 | x = torch.cat((cls_tokens, x), dim=1) 156 | 157 | # apply Transformer blocks 158 | for blk in self.blocks: 159 | x = blk(x) 160 | x = self.norm(x) 161 | 162 | return x, mask, ids_restore 163 | 164 | def forward_decoder(self, x, ids_restore): 165 | # embed tokens 166 | x = self.decoder_embed(x) 167 | 168 | # append mask tokens to sequence 169 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 170 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 171 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 172 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 173 | 174 | # add pos embed 175 | x = x + self.decoder_pos_embed 176 | 177 | # apply Transformer blocks 178 | for blk in self.decoder_blocks: 179 | x = blk(x) 180 | x = self.decoder_norm(x) 181 | 182 | # predictor projection 183 | x = self.decoder_pred(x) 184 | 185 | # remove cls token 186 | x = x[:, 1:, :] 187 | 188 | return x 189 | 190 | def forward_loss(self, imgs, pred, mask): 191 | """ 192 | imgs: [N, 3, H, W] 193 | pred: [N, L, p*p*3] 194 | mask: [N, L], 0 is keep, 1 is remove, 195 | """ 196 | target = self.patchify(imgs) 197 | if self.norm_pix_loss: 198 | mean = target.mean(dim=-1, keepdim=True) 199 | var = target.var(dim=-1, keepdim=True) 200 | target = (target - mean) / (var + 1.e-6)**.5 201 | 202 | loss = (pred - target) ** 2 203 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 204 | 205 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 206 | return loss 207 | 208 | # def forward(self, imgs, mask_ratio=0.75): 209 | # latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 210 | # pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 211 | # loss = self.forward_loss(imgs, pred, mask) 212 | # return loss, pred, mask, latent 213 | 214 | def forward(self, imgs, mask_ratio=0.75): 215 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 216 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 217 | loss = self.forward_loss(imgs, pred, mask) 218 | return loss, pred, mask 219 | 220 | def forward_latent(self, imgs, mask_ratio=1): 221 | latent, mask, ids_restore = 
self.forward_encoder(imgs, mask_ratio) 222 | return latent 223 | 224 | def mae_vit_large_patch16_dec512d8b(**kwargs): 225 | model = MaskedAutoencoderViT( 226 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 227 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 228 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 229 | return model 230 | 231 | 232 | 233 | # set recommended archs 234 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 235 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import requests 4 | import torch 5 | import numpy as np 6 | import argparse 7 | import matplotlib.pyplot as plt 8 | from PIL import Image 9 | import models_mae 10 | import models_vit 11 | import torch.nn as nn 12 | from pytorch_grad_cam import GradCAM 13 | from pytorch_grad_cam.utils.image import show_cam_on_image 14 | from utility import remove_black_borders_fast 15 | import utility 16 | 17 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 18 | imagenet_std = np.array([0.229, 0.224, 0.225]) 19 | 20 | def show_image(image, title=''): 21 | # image is [H, W, 3] 22 | assert image.shape[2] == 3 23 | plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()) 24 | plt.title(title, fontsize=32) 25 | plt.axis('off') 26 | return 27 | 28 | def save_image(image, save_path): 29 | # image is [H, W, 3] 30 | assert image.shape[2] == 3 31 | image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).detach().numpy().astype(np.uint8) 32 | Image.fromarray(image).save(save_path) 33 | 34 | def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'): 35 | # build model 36 | model = getattr(models_mae, arch)() 37 | # load model 38 | checkpoint = torch.load(chkpt_dir, map_location='cpu') 39 | msg = model.load_state_dict(checkpoint['model'], strict=False) 40 | print(msg) 41 | return model 42 | 43 | def prepare_ft_model(chkpt_dir, arch='vit_large_patch16', model_type='DR'): 44 | if 'DR' in model_type: 45 | nb_classes = 5 46 | elif 'Glaucoma_PAPILA' in model_type: 47 | nb_classes = 3 48 | elif 'Glaucoma_Glaucoma_Fundus' in model_type: 49 | nb_classes = 3 50 | elif 'Glaucoma_ORIGA' in model_type: 51 | nb_classes = 2 52 | elif 'AMD_AREDS' in model_type: 53 | nb_classes = 4 54 | elif 'Multi_Retina' in model_type: 55 | nb_classes = 4 56 | elif 'Multi_JSIEC' in model_type: 57 | nb_classes = 39 58 | 59 | model = models_vit.__dict__[arch]( 60 | num_classes=nb_classes, 61 | drop_path_rate=0.1, 62 | global_pool=True, 63 | ) 64 | checkpoint = torch.load(chkpt_dir, map_location='cpu') 65 | model.load_state_dict(checkpoint['model'], strict=False) 66 | print("Resume checkpoint %s" % chkpt_dir) 67 | 68 | model.eval() 69 | return model 70 | 71 | def run_one_image(img, model): 72 | 73 | x = torch.tensor(img) 74 | # make it a batch-like 75 | x = x.unsqueeze(dim=0) 76 | x = torch.einsum('nhwc->nchw', x) 77 | 78 | # run MAE 79 | loss, y, mask = model(x.float(), mask_ratio=0.25) 80 | y = model.unpatchify(y) 81 | y = torch.einsum('nchw->nhwc', y).detach().cpu() 82 | 83 | # visualize the mask 84 | mask = mask.detach() 85 | mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3) # (N, H*W, p*p*3) 86 | mask = model.unpatchify(mask) # 1 is removing, 0 is keeping 87 | mask = torch.einsum('nchw->nhwc', mask).detach().cpu() 88 | 89 | x = torch.einsum('nchw->nhwc', x) 90 
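# mask is 1 for removed patches and 0 for kept ones; the compositions below blank out the masked regions and paste the MAE reconstruction over them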
| 91 | # masked image 92 | im_masked = x * (1 - mask) 93 | 94 | # MAE reconstruction pasted with visible patches 95 | im_paste = x * (1 - mask) + y * mask 96 | 97 | # make the plt figure larger 98 | plt.rcParams['figure.figsize'] = [24, 24] 99 | 100 | # img = torch.cat([x[0], im_masked[0], y[0], im_paste[0]]) 101 | 102 | plt.subplot(2, 2, 1) 103 | show_image(x[0], "original") 104 | # save_image(x[0], save_path='original.png') 105 | 106 | plt.subplot(2, 2, 2) 107 | show_image(im_masked[0], "masked") 108 | # save_image(im_masked[0], save_path='masked.png') 109 | 110 | plt.subplot(2, 2, 3) 111 | show_image(y[0], "reconstruction") 112 | # save_image(y[0], save_path='reconstruction.png') 113 | 114 | plt.subplot(2, 2, 4) 115 | show_image(im_paste[0], "reconstruction + visible") 116 | # save_image(im_paste[0], save_path='reconstruction_visible.png') 117 | 118 | # plt.show() 119 | plt.savefig('mae.png') 120 | 121 | def reshape_transform(tensor, height=14, width=14): 122 | # remove the cls token 123 | result = tensor[:, 1:, :].reshape(tensor.size(0), 124 | height, width, tensor.size(2)) 125 | 126 | # move the channel dimension to the first position 127 | result = result.transpose(2, 3).transpose(1, 2) 128 | return result 129 | 130 | def run_cam(img, cam, label=None): 131 | input_img = img - imagenet_mean 132 | input_img = input_img / imagenet_std 133 | x = torch.tensor(input_img) 134 | # make it a batch-like 135 | x = x.unsqueeze(dim=0) 136 | x = torch.einsum('nhwc->nchw', x).float() 137 | target_category = None # a specific class index can be given; None means the highest-probability class 138 | input_tensor = x 139 | grayscale_cam = cam(input_tensor=input_tensor, targets=target_category, aug_smooth=True) 140 | grayscale_cam = grayscale_cam[0, :] 141 | 142 | # overlay the Grad-CAM output on the original image 143 | visualization = show_cam_on_image(img, grayscale_cam, image_weight=0.5) 144 | Image.fromarray(visualization).save('cam.png') 145 | 146 | def run_classification(img, model, type): 147 | x = utility.prepare_data(img) 148 | # model inference 149 | with torch.no_grad(): 150 | output = model(x) 151 | output = nn.Softmax(dim=1)(output) 152 | output = output.squeeze(0).cpu().detach().numpy() 153 | 154 | if 'DR' in type: 155 | # visualization 156 | categories = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR'] 157 | colors = ['blue', 'green', 'red', 'purple', 'orange'] 158 | prob_result = utility.draw_result(output, categories, colors) 159 | 160 | elif 'Glaucoma_PAPILA' in type: 161 | # visualization 162 | categories = ['No glaucoma', 'Suspected glaucoma', 'Glaucoma'] 163 | colors = ['blue', 'green', 'red'] 164 | prob_result = utility.draw_result(output, categories, colors) 165 | 166 | elif 'Glaucoma_Glaucoma_Fundus' in type: 167 | # visualization 168 | categories = ['No glaucoma', 'Early glaucoma', 'Advanced glaucoma'] 169 | colors = ['blue', 'green', 'red'] 170 | prob_result = utility.draw_result(output, categories, colors) 171 | 172 | elif 'Glaucoma_ORIGA' in type: 173 | # visualization 174 | categories = ['No glaucoma', 'Glaucoma'] 175 | colors = ['blue', 'red'] 176 | prob_result = utility.draw_result(output, categories, colors) 177 | 178 | elif 'AMD_AREDS' in type: 179 | # visualization 180 | categories = ['Non AMD', 'Mild AMD', 'Moderate AMD', 'Advanced AMD'] 181 | colors = ['blue', 'green', 'red', 'orange'] 182 | prob_result = utility.draw_result(output, categories, colors) 183 | 184 | elif 'Multi_Retina' in type: 185 | # visualization 186 | categories = ['Normal', 'Cataract', 'Glaucoma', 'Others'] 187 | colors = ['blue', 'green', 'red', 'orange'] 188 | prob_result =
utility.draw_result(output, categories, colors) 189 | 190 | elif 'Multi_JSIEC' in type: 191 | # visualization 192 | categories = ['Normal', 'Tessellated fundus', 'Large optic cup', 'DR1', 'DR2', 'DR3', \ 193 | 'BRVO', 'CRVO', 'RAO', 'Rhegmatogenous RD', 'CSCR', 'VKH disease', 'Maculopathy', \ 194 | 'ERM', 'MH', 'Pathological myopia', 'Possible glaucoma', 'Optic atrophy', \ 195 | 'Severe hypertensive retinopathy', 'Disc swelling and elevation', 'Dragged Disc', \ 196 | 'Congenital disc abnormality', 'Retinitis pigmentosa', 'Bietti crystalline dystrophy', \ 197 | 'Peripheral retinal degeneration and break', 'Myelinated nerve fiber', 'Vitreous particles', \ 198 | 'Fundus neoplasm', 'Massive hard exudates', 'Yellow-white spots-flecks', 'Cotton-wool spots', \ 199 | 'Vessel tortuosity', 'Chorioretinal atrophy-coloboma', 'Preretinal hemorrhage', 'Fibrosis', \ 200 | 'Laser Spots', 'Silicon oil in eye', 'Blur fundus without PDR', 'Blur fundus with suspected PDR'] 201 | 202 | colors = ['aliceblue', 'antiquewhite', 'aqua', 'aquamarine', 'azure', 'beige', 'bisque', 'black', \ 203 | 'blanchedalmond', 'blue', 'blueviolet', 'brown', 'burlywood', 'cadetblue', 'chartreuse', \ 204 | 'chocolate', 'coral', 'cornflowerblue', 'cornsilk', 'crimson', 'cyan', 'darkblue', 'darkcyan', \ 205 | 'darkgoldenrod', 'darkgray', 'darkgreen', 'darkgrey', 'darkkhaki', 'darkmagenta', 'darkolivegreen', \ 206 | 'darkorange', 'darkorchid', 'darkred', 'darksalmon', 'darkseagreen', 'darkslateblue', 'darkslategray', \ 207 | 'darkslategrey', 'darkturquoise'] 208 | 209 | prob_result = utility.draw_result(output, categories, colors) 210 | 211 | Image.fromarray(prob_result).save('classification.png') 212 | 213 | if __name__ == '__main__': 214 | parser = argparse.ArgumentParser() 215 | parser.add_argument('--mode', type=str, default='mae', choices=['mae', 'cam', 'classification']) 216 | parser.add_argument('--img_path', type=str, default='./exampledata/DR/APTOS2019.png') 217 | parser.add_argument('--ft_model', type=str, default='DR_APTOS2019', choices=['DR_APTOS2019','DR_IDRID', \ 218 | 'DR_MESSIDOR2','Glaucoma_PAPILA', \ 219 | 'Glaucoma_Glaucoma_Fundus','Glaucoma_ORIGA',\ 220 | 'AMD_AREDS','Multi_Retina', 'Multi_JSIEC']) 221 | args = parser.parse_args() 222 | 223 | # load an image 224 | img_path = args.img_path 225 | img = remove_black_borders_fast(img_path) 226 | img = img.resize((224, 224)) 227 | img = np.array(img) / 255. 
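# img is now a float array in [0, 1] with shape (224, 224, 3); each branch below applies its own mode-specific preprocessing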
228 | 229 | assert img.shape == (224, 224, 3) 230 | if args.mode == 'mae': 231 | chkpt_dir = './checkpoint/PreTraining/checkpoint-best.pth' 232 | model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16') 233 | 234 | # make random mask reproducible (comment out to make it change) 235 | torch.manual_seed(4) 236 | # normalize by ImageNet mean and std 237 | img = img - imagenet_mean 238 | img = img / imagenet_std 239 | run_one_image(img, model_mae) 240 | 241 | elif args.mode == 'cam': 242 | # chose fine-tuned model from 'checkpoint' 243 | chkpt_dir = os.path.join('./checkpoint', args.ft_model, 'checkpoint-best.pth') 244 | model = prepare_ft_model(chkpt_dir) 245 | cam = GradCAM(model=model, target_layers=[model.blocks[-1].norm1], use_cuda=True, reshape_transform=reshape_transform) 246 | run_cam(img, cam) 247 | 248 | elif args.mode == 'classification': 249 | # chose fine-tuned model from 'checkpoint' 250 | chkpt_dir = os.path.join('./checkpoint', args.ft_model, 'checkpoint-best.pth') 251 | model = prepare_ft_model(chkpt_dir, model_type = args.ft_model) 252 | run_classification(img, model, args.ft_model) 253 | 254 | 255 | -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # Partly revised by YZ @UCL&Moorfields 4 | # -------------------------------------------------------- 5 | 6 | import math 7 | import sys 8 | import csv 9 | import os 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from timm.data import Mixup 14 | from timm.utils import accuracy 15 | from typing import Iterable, Optional 16 | import util.misc as misc 17 | import util.lr_sched as lr_sched 18 | from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, average_precision_score,multilabel_confusion_matrix, precision_recall_curve, roc_curve, auc 19 | from pycm import * 20 | import matplotlib.pyplot as plt 21 | import numpy as np 22 | from scipy import interp 23 | 24 | 25 | 26 | def misc_measures(confusion_matrix): 27 | 28 | acc = [] 29 | sensitivity = [] 30 | specificity = [] 31 | precision = [] 32 | G = [] 33 | F1_score_2 = [] 34 | mcc_ = [] 35 | 36 | for i in range(1, confusion_matrix.shape[0]): 37 | cm1=confusion_matrix[i] 38 | acc.append(1.*(cm1[0,0]+cm1[1,1])/np.sum(cm1)) 39 | sensitivity_ = 1.*cm1[1,1]/(cm1[1,0]+cm1[1,1]) 40 | sensitivity.append(sensitivity_) 41 | specificity_ = 1.*cm1[0,0]/(cm1[0,1]+cm1[0,0]) 42 | specificity.append(specificity_) 43 | precision_ = 1.*cm1[1,1]/(cm1[1,1]+cm1[0,1]) 44 | precision.append(precision_) 45 | G.append(np.sqrt(sensitivity_*specificity_)) 46 | F1_score_2.append(2*precision_*sensitivity_/(precision_+sensitivity_)) 47 | mcc = (cm1[0,0]*cm1[1,1]-cm1[0,1]*cm1[1,0])/np.sqrt((cm1[0,0]+cm1[0,1])*(cm1[0,0]+cm1[1,0])*(cm1[1,1]+cm1[1,0])*(cm1[1,1]+cm1[0,1])) 48 | mcc_.append(mcc) 49 | 50 | acc = np.array(acc).mean() 51 | sensitivity = np.array(sensitivity).mean() 52 | specificity = np.array(specificity).mean() 53 | precision = np.array(precision).mean() 54 | G = np.array(G).mean() 55 | F1_score_2 = np.array(F1_score_2).mean() 56 | mcc_ = np.array(mcc_).mean() 57 | 58 | return acc, sensitivity, specificity, precision, G, F1_score_2, mcc_ 59 | 60 | 61 | 62 | 63 | 64 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 65 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 66 | device: torch.device, epoch: 
int, loss_scaler, max_norm: float = 0, 67 | mixup_fn: Optional[Mixup] = None, log_writer=None, 68 | args=None): 69 | model.train(True) 70 | metric_logger = misc.MetricLogger(delimiter=" ") 71 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 72 | header = 'Epoch: [{}]'.format(epoch) 73 | print_freq = 20 74 | 75 | accum_iter = args.accum_iter 76 | 77 | optimizer.zero_grad() 78 | 79 | if log_writer is not None: 80 | print('log_dir: {}'.format(log_writer.log_dir)) 81 | 82 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 83 | 84 | # we use a per iteration (instead of per epoch) lr scheduler 85 | if data_iter_step % accum_iter == 0: 86 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 87 | 88 | samples = samples.to(device, non_blocking=True) 89 | targets = targets.to(device, non_blocking=True) 90 | 91 | if mixup_fn is not None: 92 | samples, targets = mixup_fn(samples, targets) 93 | 94 | with torch.cuda.amp.autocast(): 95 | outputs = model(samples) 96 | loss = criterion(outputs, targets) 97 | 98 | loss_value = loss.item() 99 | 100 | if not math.isfinite(loss_value): 101 | print("Loss is {}, stopping training".format(loss_value)) 102 | sys.exit(1) 103 | 104 | loss /= accum_iter 105 | loss_scaler(loss, optimizer, clip_grad=max_norm, 106 | parameters=model.parameters(), create_graph=False, 107 | update_grad=(data_iter_step + 1) % accum_iter == 0) 108 | if (data_iter_step + 1) % accum_iter == 0: 109 | optimizer.zero_grad() 110 | 111 | torch.cuda.synchronize() 112 | 113 | metric_logger.update(loss=loss_value) 114 | min_lr = 10. 115 | max_lr = 0. 116 | for group in optimizer.param_groups: 117 | min_lr = min(min_lr, group["lr"]) 118 | max_lr = max(max_lr, group["lr"]) 119 | 120 | metric_logger.update(lr=max_lr) 121 | 122 | loss_value_reduce = misc.all_reduce_mean(loss_value) 123 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 124 | """ We use epoch_1000x as the x-axis in tensorboard. 125 | This calibrates different curves when batch size changes. 
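For example, half-way through epoch 3 (data_iter_step / len(data_loader) = 0.5) a point is logged at epoch_1000x = int(3.5 * 1000) = 3500.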
126 | """ 127 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 128 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 129 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 130 | 131 | # gather the stats from all processes 132 | metric_logger.synchronize_between_processes() 133 | print("Averaged stats:", metric_logger) 134 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 135 | 136 | def draw_pr(lr_label, lr_precision, save_name, n_classes): 137 | precision = dict() 138 | recall = dict() 139 | average_precision = dict() 140 | lr_label = np.array(lr_label) 141 | lr_precision = np.array(lr_precision) 142 | 143 | for i in range(n_classes): 144 | precision[i], recall[i], _ = precision_recall_curve(lr_label[:, i], lr_precision[:, i]) 145 | 146 | average_precision[i] = average_precision_score(lr_label[:, i], 147 | lr_precision[:, i]) 148 | # (2) A "macro-average": quantifying score on all classes jointly 149 | precision["macro"], recall["macro"], _ = precision_recall_curve(lr_label.ravel(), 150 | lr_precision.ravel()) 151 | 152 | average_precision["macro"] = average_precision_score(lr_label, lr_precision, 153 | average="macro") 154 | 155 | plt.figure() 156 | plt.step(recall['macro'], precision['macro'], where='post') 157 | plt.xlabel('Recall') 158 | plt.ylabel('Precision') 159 | plt.ylim([0.0, 1.05]) 160 | plt.xlim([0.0, 1.0]) 161 | plt.title('Average precision score, macro-averaged over all classes: AP={0:0.3f}'.format(average_precision["macro"])) 162 | plt.savefig(save_name) 163 | 164 | def draw_roc(lr_label, lr_precision, save_name, n_classes): 165 | fpr = dict() 166 | tpr = dict() 167 | roc_auc = dict() 168 | lr_label = np.array(lr_label) 169 | lr_precision = np.array(lr_precision) 170 | 171 | for i in range(n_classes): 172 | 173 | fpr[i], tpr[i], _ = roc_curve(lr_label[:, i], lr_precision[:, i]) 174 | roc_auc[i] = auc(fpr[i], tpr[i]) 175 | 176 | fpr["micro"], tpr["micro"], _ = roc_curve(lr_label.ravel(), lr_precision.ravel()) 177 | roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) 178 | all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)])) 179 | 180 | mean_tpr = np.zeros_like(all_fpr) 181 | for i in range(n_classes): 182 | mean_tpr += interp(all_fpr, fpr[i], tpr[i]) 183 | 184 | mean_tpr /= n_classes 185 | fpr["macro"] = all_fpr 186 | tpr["macro"] = mean_tpr 187 | roc_auc["macro"] = auc(fpr["macro"], tpr["macro"]) 188 | 189 | lw=2 190 | plt.figure() 191 | plt.plot(fpr["micro"], tpr["micro"], 192 | label='micro-average ROC curve (area = {0:0.2f})' 193 | ''.format(roc_auc["micro"]), 194 | color='deeppink', linestyle=':', linewidth=4) 195 | 196 | plt.plot(fpr["macro"], tpr["macro"], 197 | label='macro-average ROC curve (area = {0:0.2f})' 198 | ''.format(roc_auc["macro"]), 199 | color='navy', linestyle=':', linewidth=4) 200 | 201 | plt.plot([0, 1], [0, 1], 'k--', lw=lw) 202 | plt.xlim([0.0, 1.0]) 203 | plt.ylim([0.0, 1.05]) 204 | plt.xlabel('False Positive Rate') 205 | plt.ylabel('True Positive Rate') 206 | plt.title('Some extension of Receiver operating characteristic to multi-class') 207 | plt.legend(loc="lower right") 208 | plt.savefig(save_name) 209 | 210 | @torch.no_grad() 211 | def evaluate(data_loader, model, device, task, epoch, mode, num_class): 212 | criterion = torch.nn.CrossEntropyLoss() 213 | 214 | metric_logger = misc.MetricLogger(delimiter=" ") 215 | header = 'Test:' 216 | 217 | if not os.path.exists(task): 218 | os.makedirs(task) 219 | 220 | prediction_decode_list = [] 221 | prediction_list = [] 222 
| true_label_decode_list = [] 223 | true_label_onehot_list = [] 224 | 225 | # switch to evaluation mode 226 | model.eval() 227 | 228 | for batch in metric_logger.log_every(data_loader, 10, header): 229 | images = batch[0] 230 | target = batch[-1] 231 | images = images.to(device, non_blocking=True) 232 | target = target.to(device, non_blocking=True) 233 | true_label=F.one_hot(target.to(torch.int64), num_classes=num_class) 234 | 235 | # compute output 236 | with torch.cuda.amp.autocast(): 237 | output = model(images) 238 | loss = criterion(output, target) 239 | prediction_softmax = nn.Softmax(dim=1)(output) 240 | _,prediction_decode = torch.max(prediction_softmax, 1) 241 | _,true_label_decode = torch.max(true_label, 1) 242 | 243 | prediction_decode_list.extend(prediction_decode.cpu().detach().numpy()) 244 | true_label_decode_list.extend(true_label_decode.cpu().detach().numpy()) 245 | true_label_onehot_list.extend(true_label.cpu().detach().numpy()) 246 | prediction_list.extend(prediction_softmax.cpu().detach().numpy()) 247 | 248 | acc1,_ = accuracy(output, target, topk=(1,2)) 249 | 250 | batch_size = images.shape[0] 251 | metric_logger.update(loss=loss.item()) 252 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 253 | 254 | # gather the stats from all processes 255 | true_label_decode_list = np.array(true_label_decode_list) 256 | prediction_decode_list = np.array(prediction_decode_list) 257 | 258 | confusion_matrix = multilabel_confusion_matrix(true_label_decode_list, prediction_decode_list,labels=[i for i in range(num_class)]) 259 | acc, sensitivity, specificity, precision, G, F1, mcc = misc_measures(confusion_matrix) 260 | 261 | auc_roc = roc_auc_score(true_label_onehot_list, prediction_list,multi_class='ovr',average='macro') 262 | auc_pr = average_precision_score(true_label_onehot_list, prediction_list,average='macro') 263 | 264 | metric_logger.synchronize_between_processes() 265 | 266 | print('Sklearn Metrics - Acc: {:.4f} AUC-roc: {:.4f} AUC-pr: {:.4f} F1-score: {:.4f} MCC: {:.4f}'.format(acc, auc_roc, auc_pr, F1, mcc)) 267 | results_path = task+'_metrics_{}.csv'.format(mode) 268 | with open(results_path,mode='a',newline='',encoding='utf8') as cfa: 269 | wf = csv.writer(cfa) 270 | data2=[[acc,sensitivity,specificity,precision,auc_roc,auc_pr,F1,mcc,metric_logger.loss]] 271 | for i in data2: 272 | wf.writerow(i) 273 | 274 | # if mode=='test': 275 | # cm = ConfusionMatrix(actual_vector=true_label_decode_list, predict_vector=prediction_decode_list) 276 | # cm.plot(cmap=plt.cm.Blues,number_label=True,normalized=True,plot_lib="matplotlib") 277 | # plt.savefig(task+'confusion_matrix_test.jpg',dpi=600,bbox_inches ='tight') 278 | 279 | # draw_pr(true_label_onehot_list, prediction_list, save_name=task+'pr_curve.jpg', n_classes=num_class) 280 | # draw_roc(true_label_onehot_list, prediction_list, save_name=task+'roc_curve.jpg', n_classes=num_class) 281 | # true_label_onehot_list = np.array(true_label_onehot_list) 282 | # prediction_list = np.array(prediction_list) 283 | # save_result = np.stack([true_label_onehot_list, prediction_list], axis=0) 284 | # np.save(task+'save_result', save_result) 285 | 286 | 287 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()},auc_roc 288 | 289 | -------------------------------------------------------------------------------- /util/data_Radiology.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from PIL import Image 4 | import pickle 5 | from torch.utils.data 
import Dataset, ConcatDataset 6 | from torch.utils.data import DataLoader 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | from PIL import ImageFile 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | from glob import glob 12 | import pickle 13 | import random 14 | # from dataset.Mytransforms import * 15 | 16 | class ShenzhenDataset(Dataset): 17 | def __init__(self, data_dir, label_path, data_type, data_ratio=100, opt=None, use_syn=False, syn_data_dir=None, syn_label_path=None): 18 | self.trainsize = (224,224) 19 | self.data_dir = data_dir 20 | self.data_type = data_type 21 | self.use_syn = use_syn 22 | self.train = True if data_type == 'train' else False 23 | 24 | self.data_list = [] 25 | 26 | if data_ratio == 100: 27 | with open(label_path, "rb") as f: 28 | tr_dl = pickle.load(f) 29 | self.data_list = tr_dl 30 | print('Total Real Samples:', len(self.data_list)) 31 | else: 32 | dataset_name = label_path.split('/')[1] 33 | sample_data_path = os.path.join('SampleData', dataset_name, f'ratio_{data_ratio}', 'train.pkl') 34 | with open(sample_data_path, "rb") as f: 35 | tr_dl = pickle.load(f) 36 | self.data_list = tr_dl 37 | print(f'Total Ratio {data_ratio} Samples: {len(self.data_list)}') 38 | 39 | self.size = len(self.data_list) 40 | print('Total Samples:', self.size) 41 | 42 | self.size = len(self.data_list) 43 | if self.train: 44 | self.transform_center = transforms.Compose([ 45 | transforms.Resize(self.trainsize), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.RandomVerticalFlip(), 48 | transforms.RandomGrayscale(p=0.2), 49 | transforms.ColorJitter(), 50 | transforms.RandomRotation(degrees=(-180, 180)), 51 | transforms.ToTensor(), 52 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 53 | ]) 54 | else: 55 | self.transform_center = transforms.Compose([ 56 | # CropCenterSquare(), 57 | transforms.Resize(self.trainsize), 58 | transforms.ToTensor(), 59 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 60 | ]) 61 | 62 | def __getitem__(self, index): 63 | data_pac = self.data_list[index] 64 | basename, ext = data_pac['img_root'].split('.') 65 | imgname = data_pac['img_root'] 66 | # if ext == 'jpg': 67 | # imgname = basename + '.JPG' 68 | img_path = os.path.join(self.data_dir, imgname) 69 | img = Image.open(img_path).convert('RGB') 70 | 71 | img_torch = self.transform_center(img) 72 | 73 | label = int(data_pac['label']) 74 | 75 | return img_torch, label 76 | 77 | def __len__(self): 78 | return self.size 79 | 80 | class TBChestDataset(Dataset): 81 | def __init__(self, data_dir, label_path, data_type, data_ratio=100, opt=None, use_syn=False, syn_data_dir=None, syn_label_path=None): 82 | self.trainsize = (224,224) 83 | self.data_dir = data_dir 84 | self.data_type = data_type 85 | self.use_syn = use_syn 86 | self.train = True if data_type == 'train' else False 87 | 88 | self.data_list = [] 89 | 90 | if data_ratio == 100: 91 | with open(label_path, "rb") as f: 92 | tr_dl = pickle.load(f) 93 | self.data_list = tr_dl 94 | print('Total Real Samples:', len(self.data_list)) 95 | else: 96 | dataset_name = label_path.split('/')[1] 97 | sample_data_path = os.path.join('SampleData', dataset_name, f'ratio_{data_ratio}', 'train.pkl') 98 | with open(sample_data_path, "rb") as f: 99 | tr_dl = pickle.load(f) 100 | self.data_list = tr_dl 101 | print(f'Total Ratio {data_ratio} Samples: {len(self.data_list)}') 102 | 103 | self.size = len(self.data_list) 104 | print('Total Samples:', self.size) 105 | 106 | self.size = len(self.data_list) 107 | if 
self.train: 108 | self.transform_center = transforms.Compose([ 109 | transforms.Resize(self.trainsize), 110 | transforms.RandomHorizontalFlip(), 111 | transforms.RandomVerticalFlip(), 112 | transforms.RandomGrayscale(p=0.2), 113 | transforms.ColorJitter(), 114 | transforms.RandomRotation(degrees=(-180, 180)), 115 | transforms.ToTensor(), 116 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 117 | ]) 118 | else: 119 | self.transform_center = transforms.Compose([ 120 | # CropCenterSquare(), 121 | transforms.Resize(self.trainsize), 122 | transforms.ToTensor(), 123 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 124 | ]) 125 | 126 | def __getitem__(self, index): 127 | data_pac = self.data_list[index] 128 | basename, ext = data_pac['img_root'].split('.') 129 | imgname = data_pac['img_root'] 130 | # if ext == 'jpg': 131 | # imgname = basename + '.JPG' 132 | img_path = os.path.join(self.data_dir, imgname) 133 | img = Image.open(img_path).convert('RGB') 134 | 135 | img_torch = self.transform_center(img) 136 | 137 | label = int(data_pac['label']) 138 | 139 | return img_torch, label 140 | 141 | def __len__(self): 142 | return self.size 143 | 144 | class MyDataset(Dataset): 145 | def __init__(self, data_dir, label_path, transforms, data_type="train", debug=True, opt = None): 146 | self.data_dir = data_dir 147 | self.label_path = label_path 148 | # df = pd.read_csv(self.label_path) 149 | # self.img_list = df[df.columns[0]].values 150 | # self.label_list = df[df.columns[1]].values 151 | self.transforms = transforms 152 | self.imgs = sorted(glob.glob(os.path.join(self.data_dir, "*.*"))) 153 | 154 | if debug: 155 | self.imgs = self.imgs[:10] 156 | 157 | self.length = len(self.imgs) 158 | 159 | def __getitem__(self, idx): 160 | img_path = self.imgs[idx] 161 | img_name = os.path.split(img_path)[-1] 162 | pil_img = Image.open(img_path).convert("RGB") 163 | img = self.transforms(pil_img) 164 | return img, img_name 165 | 166 | def __len__(self): 167 | return self.length 168 | 169 | def get_loaders(opt): 170 | DatasetClass = eval(opt.datasetM) 171 | num_train_sets = len(opt.train_sets) 172 | train_dataset_list = [] 173 | for i in range(num_train_sets): 174 | train_dataset = DatasetClass(data_dir = opt.TRAIN_DATA_DIR[opt.train_sets[i]], 175 | label_path = opt.PATH_TO_TRAIN_LABEL[opt.train_sets[i]], 176 | data_type = 'train', 177 | debug = opt.debug, 178 | opt = opt) 179 | train_dataset_list.append(train_dataset) 180 | train_dataset = ConcatDataset(train_dataset_list) 181 | test_dataset = DatasetClass(data_dir = opt.TEST_DATA_DIR[opt.test_sets[0]], 182 | label_path = opt.PATH_TO_VAL_LABEL[opt.test_sets[0]], 183 | data_type = 'test', 184 | debug = opt.debug, 185 | opt = opt) 186 | train_loader = DataLoader( 187 | train_dataset, 188 | batch_size=opt.batch_size, 189 | shuffle=True, 190 | num_workers=opt.num_workers, 191 | ) 192 | test_loader = DataLoader( 193 | test_dataset, 194 | batch_size=opt.batch_size, 195 | shuffle=False, 196 | num_workers=opt.num_workers, 197 | ) 198 | return train_loader, test_loader 199 | 200 | ## for five-fold cross-validation on Train&Val, return Train&Val loaders 201 | # def get_loaders(opt): 202 | # DatasetClass = eval(opt.datasetM) 203 | # train_dataset = DatasetClass(data_dir = opt.DATA_DIR, 204 | # label_path = opt.PATH_TO_LABEL[opt.train_dataset], 205 | # data_type = 'train', 206 | # debug = opt.debug, 207 | # opt = opt) 208 | 209 | # # gain indices for cross-validation 210 | # whole_folder = [] 211 | # whole_num = len(train_dataset) 212 | 
# indices = np.arange(whole_num) 213 | # random.seed(opt.ds_seed) 214 | # random.shuffle(indices) 215 | 216 | # # split indices into five-fold 217 | # num_folder = opt.num_folder 218 | # each_folder_num = int(whole_num / num_folder) 219 | # for ii in range(num_folder-1): 220 | # each_folder = indices[each_folder_num*ii: each_folder_num*(ii+1)] 221 | # whole_folder.append(each_folder) 222 | # each_folder = indices[each_folder_num*(num_folder-1):] 223 | # whole_folder.append(each_folder) 224 | # assert len(whole_folder) == num_folder 225 | # assert sum([len(each) for each in whole_folder if 1==1]) == whole_num 226 | 227 | # ## split into train/eval 228 | # train_eval_idxs = [] 229 | # for ii in range(num_folder): 230 | # eval_idxs = whole_folder[ii] 231 | # train_idxs = [] 232 | # for jj in range(num_folder): 233 | # if jj != ii: train_idxs.extend(whole_folder[jj]) 234 | # train_eval_idxs.append([train_idxs, eval_idxs]) 235 | 236 | # ## gain train and eval loaders 237 | # train_loaders = [] 238 | # eval_loaders = [] 239 | # for ii in range(len(train_eval_idxs)): 240 | # train_idxs = train_eval_idxs[ii][0] 241 | # eval_idxs = train_eval_idxs[ii][1] 242 | # train_loader = DataLoader(train_dataset, 243 | # batch_size=opt.batch_size, 244 | # sampler=SubsetRandomSampler(train_idxs), 245 | # num_workers=opt.num_workers, 246 | # pin_memory=True) 247 | # eval_loader = DataLoader(train_dataset, 248 | # batch_size=opt.batch_size, 249 | # sampler=SubsetRandomSampler(eval_idxs), 250 | # num_workers=opt.num_workers, 251 | # pin_memory=True) 252 | # train_loaders.append(train_loader) 253 | # eval_loaders.append(eval_loader) 254 | 255 | # return train_loaders, eval_loaders 256 | 257 | def get_test_loaders(opt): 258 | test_loaders = [] 259 | if opt.havetest_sets: 260 | for test_set in opt.test_sets: 261 | DatasetClass = eval(test_set) 262 | test_dataset = DatasetClass(data_dir = opt.TEST_DATA_DIR[test_set], 263 | label_path = opt.PATH_TO_TEST_LABEL[test_set], 264 | data_type = test_set, 265 | debug = opt.debug) 266 | 267 | test_loader = DataLoader(test_dataset, 268 | batch_size=opt.batch_size, 269 | num_workers=opt.num_workers, 270 | shuffle=False, 271 | pin_memory=False) 272 | test_loaders.append(test_loader) 273 | return test_loaders 274 | 275 | if __name__ == '__main__': 276 | # train_dataset = APTOSDataset(data_dir = "/mnt/gzy/DiffMed/APTOS/train_images", 277 | # label_path = "/mnt/gzy/DiffMed/APTOS/aptos_test.pkl", 278 | # data_type = 'train', 279 | # debug = False) 280 | train_dataset = ISICDataset(data_dir = "/mnt/gzy/DiffMed/ISIC/rec_subset", 281 | label_path = "/mnt/gzy/DiffMed/ISIC/isic2018_test.pkl", 282 | data_type = 'test', 283 | debug = False) 284 | train_loader = DataLoader( 285 | train_dataset, 286 | batch_size=4, 287 | shuffle=True, 288 | num_workers=8, 289 | ) 290 | print(len(train_dataset)) 291 | # for batch in train_loader: 292 | # img, label = batch 293 | # print(img.shape) 294 | # print(label.shape) -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # Partly revised by YZ @UCL&Moorfields 4 | # -------------------------------------------------------- 5 | 6 | import builtins 7 | import datetime 8 | import os 9 | import time 10 | from collections import defaultdict, deque 11 | from pathlib import Path 12 | 13 | import torch 14 | import torch.distributed as dist 15 | from torch._six import inf 16 | 17 | 18 | class SmoothedValue(object): 19 | """Track a series of values and provide access to smoothed values over a 20 | window or the global series average. 21 | """ 22 | 23 | def __init__(self, window_size=20, fmt=None): 24 | if fmt is None: 25 | fmt = "{median:.4f} ({global_avg:.4f})" 26 | self.deque = deque(maxlen=window_size) 27 | self.total = 0.0 28 | self.count = 0 29 | self.fmt = fmt 30 | 31 | def update(self, value, n=1): 32 | self.deque.append(value) 33 | self.count += n 34 | self.total += value * n 35 | 36 | def synchronize_between_processes(self): 37 | """ 38 | Warning: does not synchronize the deque! 39 | """ 40 | if not is_dist_avail_and_initialized(): 41 | return 42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 43 | dist.barrier() 44 | dist.all_reduce(t) 45 | t = t.tolist() 46 | self.count = int(t[0]) 47 | self.total = t[1] 48 | 49 | @property 50 | def median(self): 51 | d = torch.tensor(list(self.deque)) 52 | return d.median().item() 53 | 54 | @property 55 | def avg(self): 56 | d = torch.tensor(list(self.deque), dtype=torch.float32) 57 | return d.mean().item() 58 | 59 | @property 60 | def global_avg(self): 61 | return self.total / self.count 62 | 63 | @property 64 | def max(self): 65 | return max(self.deque) 66 | 67 | @property 68 | def value(self): 69 | return self.deque[-1] 70 | 71 | def __str__(self): 72 | return self.fmt.format( 73 | median=self.median, 74 | avg=self.avg, 75 | global_avg=self.global_avg, 76 | max=self.max, 77 | value=self.value) 78 | 79 | 80 | class MetricLogger(object): 81 | def __init__(self, delimiter="\t"): 82 | self.meters = defaultdict(SmoothedValue) 83 | self.delimiter = delimiter 84 | 85 | def update(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | if v is None: 88 | continue 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError("'{}' object has no attribute '{}'".format( 100 | type(self).__name__, attr)) 101 | 102 | def __str__(self): 103 | loss_str = [] 104 | for name, meter in self.meters.items(): 105 | loss_str.append( 106 | "{}: {}".format(name, str(meter)) 107 | ) 108 | return self.delimiter.join(loss_str) 109 | 110 | def synchronize_between_processes(self): 111 | for meter in self.meters.values(): 112 | meter.synchronize_between_processes() 113 | 114 | def add_meter(self, name, meter): 115 | self.meters[name] = meter 116 | 117 | def log_every(self, iterable, print_freq, header=None): 118 | i = 0 119 | if not header: 120 | header = '' 121 | start_time = time.time() 122 | end = time.time() 123 | iter_time = SmoothedValue(fmt='{avg:.4f}') 124 | data_time = SmoothedValue(fmt='{avg:.4f}') 125 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 126 | log_msg = [ 127 | header, 128 | '[{0' + space_fmt + '}/{1}]', 129 | 'eta: {eta}', 130 | '{meters}', 131 | 'time: {time}', 132 | 'data: {data}' 133 | ] 134 | if torch.cuda.is_available(): 135 | log_msg.append('max mem: {memory:.0f}') 136 | log_msg = 
self.delimiter.join(log_msg) 137 | MB = 1024.0 * 1024.0 138 | for obj in iterable: 139 | data_time.update(time.time() - end) 140 | yield obj 141 | iter_time.update(time.time() - end) 142 | if i % print_freq == 0 or i == len(iterable) - 1: 143 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 144 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 145 | if torch.cuda.is_available(): 146 | print(log_msg.format( 147 | i, len(iterable), eta=eta_string, 148 | meters=str(self), 149 | time=str(iter_time), data=str(data_time), 150 | memory=torch.cuda.max_memory_allocated() / MB)) 151 | else: 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time))) 156 | i += 1 157 | end = time.time() 158 | total_time = time.time() - start_time 159 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 160 | print('{} Total time: {} ({:.4f} s / it)'.format( 161 | header, total_time_str, total_time / len(iterable))) 162 | 163 | 164 | def setup_for_distributed(is_master): 165 | """ 166 | This function disables printing when not in master process 167 | """ 168 | builtin_print = builtins.print 169 | 170 | def print(*args, **kwargs): 171 | force = kwargs.pop('force', False) 172 | force = force or (get_world_size() > 8) 173 | if is_master or force: 174 | now = datetime.datetime.now().time() 175 | builtin_print('[{}] '.format(now), end='') # print with time stamp 176 | builtin_print(*args, **kwargs) 177 | 178 | builtins.print = print 179 | 180 | 181 | def is_dist_avail_and_initialized(): 182 | if not dist.is_available(): 183 | return False 184 | if not dist.is_initialized(): 185 | return False 186 | return True 187 | 188 | 189 | def get_world_size(): 190 | if not is_dist_avail_and_initialized(): 191 | return 1 192 | return dist.get_world_size() 193 | 194 | 195 | def get_rank(): 196 | if not is_dist_avail_and_initialized(): 197 | return 0 198 | return dist.get_rank() 199 | 200 | 201 | def is_main_process(): 202 | return get_rank() == 0 203 | 204 | 205 | def save_on_master(*args, **kwargs): 206 | if is_main_process(): 207 | torch.save(*args, **kwargs) 208 | 209 | 210 | def init_distributed_mode(args): 211 | if args.dist_on_itp: 212 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 213 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 214 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 215 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 216 | os.environ['LOCAL_RANK'] = str(args.gpu) 217 | os.environ['RANK'] = str(args.rank) 218 | os.environ['WORLD_SIZE'] = str(args.world_size) 219 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 220 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 221 | args.rank = int(os.environ["RANK"]) 222 | args.world_size = int(os.environ['WORLD_SIZE']) 223 | args.gpu = int(os.environ['LOCAL_RANK']) 224 | elif 'SLURM_PROCID' in os.environ: 225 | args.rank = int(os.environ['SLURM_PROCID']) 226 | args.gpu = args.rank % torch.cuda.device_count() 227 | else: 228 | print('Not using distributed mode') 229 | setup_for_distributed(is_master=True) # hack 230 | args.distributed = False 231 | return 232 | 233 | args.distributed = True 234 | 235 | torch.cuda.set_device(args.gpu) 236 | args.dist_backend = 'nccl' 237 | print('| distributed init (rank {}): {}, gpu {}'.format( 238 | args.rank, args.dist_url, args.gpu), flush=True) 239 | torch.distributed.init_process_group(backend=args.dist_backend, 
init_method=args.dist_url, 240 | world_size=args.world_size, rank=args.rank) 241 | torch.distributed.barrier() 242 | setup_for_distributed(args.rank == 0) 243 | 244 | 245 | class NativeScalerWithGradNormCount: 246 | state_dict_key = "amp_scaler" 247 | 248 | def __init__(self): 249 | self._scaler = torch.cuda.amp.GradScaler() 250 | 251 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 252 | self._scaler.scale(loss).backward(create_graph=create_graph) 253 | if update_grad: 254 | if clip_grad is not None: 255 | assert parameters is not None 256 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 257 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 258 | else: 259 | self._scaler.unscale_(optimizer) 260 | norm = get_grad_norm_(parameters) 261 | self._scaler.step(optimizer) 262 | self._scaler.update() 263 | else: 264 | norm = None 265 | return norm 266 | 267 | def state_dict(self): 268 | return self._scaler.state_dict() 269 | 270 | def load_state_dict(self, state_dict): 271 | self._scaler.load_state_dict(state_dict) 272 | 273 | 274 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 275 | if isinstance(parameters, torch.Tensor): 276 | parameters = [parameters] 277 | parameters = [p for p in parameters if p.grad is not None] 278 | norm_type = float(norm_type) 279 | if len(parameters) == 0: 280 | return torch.tensor(0.) 281 | device = parameters[0].grad.device 282 | if norm_type == inf: 283 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 284 | else: 285 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 286 | return total_norm 287 | 288 | 289 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 290 | output_dir = Path(args.output_dir) 291 | epoch_name = str(epoch) 292 | if loss_scaler is not None: 293 | checkpoint_paths = [args.task + f'checkpoint-{str(epoch)}.pth'] 294 | for checkpoint_path in checkpoint_paths: 295 | to_save = { 296 | 'model': model_without_ddp.state_dict(), 297 | 'optimizer': optimizer.state_dict(), 298 | 'epoch': epoch, 299 | 'scaler': loss_scaler.state_dict(), 300 | 'args': args, 301 | } 302 | 303 | save_on_master(to_save, checkpoint_path) 304 | else: 305 | client_state = {'epoch': epoch} 306 | model.save_checkpoint(save_dir=args.task, tag=f"checkpoint-{str(epoch)}", client_state=client_state) 307 | 308 | def save_best_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 309 | output_dir = Path(args.output_dir) 310 | epoch_name = str(epoch) 311 | if loss_scaler is not None: 312 | checkpoint_paths = [args.task + f'checkpoint-best.pth'] 313 | for checkpoint_path in checkpoint_paths: 314 | to_save = { 315 | 'model': model_without_ddp.state_dict(), 316 | 'optimizer': optimizer.state_dict(), 317 | 'epoch': epoch, 318 | 'scaler': loss_scaler.state_dict(), 319 | 'args': args, 320 | } 321 | 322 | save_on_master(to_save, checkpoint_path) 323 | else: 324 | client_state = {'epoch': epoch} 325 | model.save_checkpoint(save_dir=args.task, tag=f"checkpoint-best", client_state=client_state) 326 | 327 | 328 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 329 | if args.resume: 330 | if args.resume.startswith('https'): 331 | checkpoint = torch.hub.load_state_dict_from_url( 332 | args.resume, map_location='cpu', check_hash=True) 333 | else: 334 | checkpoint = torch.load(args.resume, 
map_location='cpu') 335 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 336 | print("Resume checkpoint %s" % args.resume) 337 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 338 | optimizer.load_state_dict(checkpoint['optimizer']) 339 | args.start_epoch = checkpoint['epoch'] + 1 340 | if 'scaler' in checkpoint: 341 | loss_scaler.load_state_dict(checkpoint['scaler']) 342 | print("With optim & sched!") 343 | 344 | 345 | def all_reduce_mean(x): 346 | world_size = get_world_size() 347 | if world_size > 1: 348 | x_reduce = torch.tensor(x).cuda() 349 | dist.all_reduce(x_reduce) 350 | x_reduce /= world_size 351 | return x_reduce.item() 352 | else: 353 | return x 354 | -------------------------------------------------------------------------------- /util/data_AMD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from PIL import Image 4 | import pickle 5 | from torch.utils.data import Dataset, ConcatDataset 6 | from torch.utils.data import DataLoader 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | from PIL import ImageFile 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | from glob import glob 12 | import pickle 13 | import random 14 | # from dataset.Mytransforms import * 15 | 16 | # Customized dataset 17 | class AMDDataset(Dataset): 18 | def __init__(self, data_dir, label_path, data_type, data_ratio=100, opt=None, use_syn=False, syn_data_dir=None, syn_label_path=None): 19 | self.trainsize = (224,224) 20 | self.data_dir = data_dir 21 | self.data_type = data_type 22 | self.use_syn = use_syn 23 | self.train = True if data_type == 'train' else False 24 | 25 | self.data_list = [] 26 | 27 | if data_ratio == 100: 28 | with open(label_path, "rb") as f: 29 | tr_dl = pickle.load(f) 30 | self.data_list = tr_dl 31 | print('Total Real Samples:', len(self.data_list)) 32 | else: 33 | dataset_name = label_path.split('/')[1] 34 | sample_data_path = os.path.join('SampleData', dataset_name, f'ratio_{data_ratio}', 'train.pkl') 35 | with open(sample_data_path, "rb") as f: 36 | tr_dl = pickle.load(f) 37 | self.data_list = tr_dl 38 | print(f'Total Ratio {data_ratio} Samples: {len(self.data_list)}') 39 | 40 | self.size = len(self.data_list) 41 | print('Total Samples:', self.size) 42 | 43 | self.size = len(self.data_list) 44 | if self.train: 45 | self.transform_center = transforms.Compose([ 46 | transforms.Resize(self.trainsize), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.RandomVerticalFlip(), 49 | transforms.RandomGrayscale(p=0.2), 50 | transforms.ColorJitter(), 51 | transforms.RandomRotation(degrees=(-180, 180)), 52 | transforms.ToTensor(), 53 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 54 | ]) 55 | else: 56 | self.transform_center = transforms.Compose([ 57 | # CropCenterSquare(), 58 | transforms.Resize(self.trainsize), 59 | transforms.ToTensor(), 60 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 61 | ]) 62 | 63 | def __getitem__(self, index): 64 | data_pac = self.data_list[index] 65 | basename, ext = data_pac['img_root'].split('.') 66 | imgname = data_pac['img_root'] 67 | # if ext == 'jpg': 68 | # imgname = basename + '.JPG' 69 | img_path = os.path.join(self.data_dir, imgname) 70 | img = Image.open(img_path).convert('RGB') 71 | 72 | img_torch = self.transform_center(img) 73 | 74 | label = int(data_pac['label']) 75 | 76 | return img_torch, label 77 | 78 | def __len__(self): 79 | return self.size 80
| 81 | class ISICDataset(Dataset): 82 | def __init__(self, data_dir, label_path, data_type, opt=None): 83 | self.trainsize = (224,224) 84 | self.data_dir = data_dir 85 | self.data_type = data_type 86 | self.train = True if data_type == 'train' else False 87 | with open(label_path, "rb") as f: 88 | tr_dl = pickle.load(f) 89 | self.data_list = tr_dl 90 | 91 | # test subset of dataset 92 | if data_type != 'train': 93 | test_image_files = os.listdir(self.data_dir) 94 | matching_items = [] 95 | for item in self.data_list: 96 | if item['img_root'] in test_image_files: 97 | matching_items.append(item) 98 | self.data_list = matching_items 99 | self.size = len(self.data_list) 100 | 101 | if self.train: 102 | self.transform_center = transforms.Compose([ 103 | # CropCenterSquare(), 104 | transforms.Resize(self.trainsize), 105 | # RandomHorizontalFlip(), 106 | # RandomRotation(30), 107 | transforms.ToTensor(), 108 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 109 | ]) 110 | else: 111 | self.transform_center = transforms.Compose([ 112 | # CropCenterSquare(), 113 | transforms.Resize(self.trainsize), 114 | transforms.ToTensor(), 115 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 116 | ]) 117 | #self.depths_transform = transforms.Compose([transforms.Resize((self.trainsize, self.trainsize)),transforms.ToTensor()]) 118 | 119 | def __getitem__(self, index): 120 | data_pac = self.data_list[index] 121 | img_path = os.path.join(self.data_dir, data_pac['img_root']) 122 | img = Image.open(img_path).convert('RGB') 123 | 124 | img_torch = self.transform_center(img) 125 | 126 | label = int(data_pac['label']) 127 | 128 | return img_torch, label 129 | 130 | def __len__(self): 131 | return self.size 132 | 133 | class EyePACSDataset(Dataset): 134 | def __init__(self, data_dir, label_path, data_type, opt=None): 135 | self.trainsize = (224,224) 136 | self.data_dir = data_dir 137 | self.data_type = data_type 138 | self.train = True if data_type == 'train' else False 139 | with open(label_path, "rb") as f: 140 | tr_dl = pickle.load(f) 141 | self.data_list = tr_dl 142 | self.size = len(self.data_list) 143 | 144 | if self.train: 145 | self.transform_center = transforms.Compose([ 146 | transforms.Resize(self.trainsize), 147 | transforms.RandomHorizontalFlip(), 148 | transforms.RandomVerticalFlip(), 149 | transforms.RandomGrayscale(p=0.2), 150 | transforms.ColorJitter(), 151 | transforms.RandomRotation(degrees=(-180, 180)), 152 | transforms.ToTensor(), 153 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 154 | transforms.Pad(16, fill=0, padding_mode='constant'), 155 | ]) 156 | else: 157 | self.transform_center = transforms.Compose([ 158 | transforms.Resize(self.trainsize), 159 | transforms.ToTensor(), 160 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 161 | transforms.Pad(16, fill=0, padding_mode='constant'), 162 | ]) 163 | 164 | def __getitem__(self, index): 165 | data_pac = self.data_list[index] 166 | img_path = os.path.join(self.data_dir, data_pac['img_root']) 167 | img = Image.open(img_path).convert('RGB') 168 | 169 | img_torch = self.transform_center(img) 170 | 171 | label = int(data_pac['label']) 172 | 173 | return img_torch, label 174 | 175 | def __len__(self): 176 | return self.size 177 | 178 | class MyDataset(Dataset): 179 | def __init__(self, data_dir, label_path, transforms, data_type="train", debug=True, opt = None): 180 | self.data_dir = data_dir 181 | self.label_path = label_path 182 | # df = pd.read_csv(self.label_path) 183 | 
# self.img_list = df[df.columns[0]].values 184 | # self.label_list = df[df.columns[1]].values 185 | self.transforms = transforms 186 | self.imgs = sorted(glob.glob(os.path.join(self.data_dir, "*.*"))) 187 | 188 | if debug: 189 | self.imgs = self.imgs[:10] 190 | 191 | self.length = len(self.imgs) 192 | 193 | def __getitem__(self, idx): 194 | img_path = self.imgs[idx] 195 | img_name = os.path.split(img_path)[-1] 196 | pil_img = Image.open(img_path).convert("RGB") 197 | img = self.transforms(pil_img) 198 | return img, img_name 199 | 200 | def __len__(self): 201 | return self.length 202 | 203 | def get_loaders(opt): 204 | DatasetClass = eval(opt.datasetM) 205 | num_train_sets = len(opt.train_sets) 206 | train_dataset_list = [] 207 | for i in range(num_train_sets): 208 | train_dataset = DatasetClass(data_dir = opt.TRAIN_DATA_DIR[opt.train_sets[i]], 209 | label_path = opt.PATH_TO_TRAIN_LABEL[opt.train_sets[i]], 210 | data_type = 'train', 211 | debug = opt.debug, 212 | opt = opt) 213 | train_dataset_list.append(train_dataset) 214 | train_dataset = ConcatDataset(train_dataset_list) 215 | test_dataset = DatasetClass(data_dir = opt.TEST_DATA_DIR[opt.test_sets[0]], 216 | label_path = opt.PATH_TO_VAL_LABEL[opt.test_sets[0]], 217 | data_type = 'test', 218 | debug = opt.debug, 219 | opt = opt) 220 | train_loader = DataLoader( 221 | train_dataset, 222 | batch_size=opt.batch_size, 223 | shuffle=True, 224 | num_workers=opt.num_workers, 225 | ) 226 | test_loader = DataLoader( 227 | test_dataset, 228 | batch_size=opt.batch_size, 229 | shuffle=False, 230 | num_workers=opt.num_workers, 231 | ) 232 | return train_loader, test_loader 233 | 234 | ## for five-fold cross-validation on Train&Val, return Train&Val loaders 235 | # def get_loaders(opt): 236 | # DatasetClass = eval(opt.datasetM) 237 | # train_dataset = DatasetClass(data_dir = opt.DATA_DIR, 238 | # label_path = opt.PATH_TO_LABEL[opt.train_dataset], 239 | # data_type = 'train', 240 | # debug = opt.debug, 241 | # opt = opt) 242 | 243 | # # gain indices for cross-validation 244 | # whole_folder = [] 245 | # whole_num = len(train_dataset) 246 | # indices = np.arange(whole_num) 247 | # random.seed(opt.ds_seed) 248 | # random.shuffle(indices) 249 | 250 | # # split indices into five-fold 251 | # num_folder = opt.num_folder 252 | # each_folder_num = int(whole_num / num_folder) 253 | # for ii in range(num_folder-1): 254 | # each_folder = indices[each_folder_num*ii: each_folder_num*(ii+1)] 255 | # whole_folder.append(each_folder) 256 | # each_folder = indices[each_folder_num*(num_folder-1):] 257 | # whole_folder.append(each_folder) 258 | # assert len(whole_folder) == num_folder 259 | # assert sum([len(each) for each in whole_folder if 1==1]) == whole_num 260 | 261 | # ## split into train/eval 262 | # train_eval_idxs = [] 263 | # for ii in range(num_folder): 264 | # eval_idxs = whole_folder[ii] 265 | # train_idxs = [] 266 | # for jj in range(num_folder): 267 | # if jj != ii: train_idxs.extend(whole_folder[jj]) 268 | # train_eval_idxs.append([train_idxs, eval_idxs]) 269 | 270 | # ## gain train and eval loaders 271 | # train_loaders = [] 272 | # eval_loaders = [] 273 | # for ii in range(len(train_eval_idxs)): 274 | # train_idxs = train_eval_idxs[ii][0] 275 | # eval_idxs = train_eval_idxs[ii][1] 276 | # train_loader = DataLoader(train_dataset, 277 | # batch_size=opt.batch_size, 278 | # sampler=SubsetRandomSampler(train_idxs), 279 | # num_workers=opt.num_workers, 280 | # pin_memory=True) 281 | # eval_loader = DataLoader(train_dataset, 282 | # batch_size=opt.batch_size, 
283 | # sampler=SubsetRandomSampler(eval_idxs), 284 | # num_workers=opt.num_workers, 285 | # pin_memory=True) 286 | # train_loaders.append(train_loader) 287 | # eval_loaders.append(eval_loader) 288 | 289 | # return train_loaders, eval_loaders 290 | 291 | def get_test_loaders(opt): 292 | test_loaders = [] 293 | if opt.havetest_sets: 294 | for test_set in opt.test_sets: 295 | DatasetClass = eval(test_set) 296 | test_dataset = DatasetClass(data_dir = opt.TEST_DATA_DIR[test_set], 297 | label_path = opt.PATH_TO_TEST_LABEL[test_set], 298 | data_type = test_set, 299 | debug = opt.debug) 300 | 301 | test_loader = DataLoader(test_dataset, 302 | batch_size=opt.batch_size, 303 | num_workers=opt.num_workers, 304 | shuffle=False, 305 | pin_memory=False) 306 | test_loaders.append(test_loader) 307 | return test_loaders 308 | 309 | if __name__ == '__main__': 310 | # train_dataset = APTOSDataset(data_dir = "/mnt/gzy/DiffMed/APTOS/train_images", 311 | # label_path = "/mnt/gzy/DiffMed/APTOS/aptos_test.pkl", 312 | # data_type = 'train', 313 | # debug = False) 314 | train_dataset = ISICDataset(data_dir = "/mnt/gzy/DiffMed/ISIC/rec_subset", 315 | label_path = "/mnt/gzy/DiffMed/ISIC/isic2018_test.pkl", 316 | data_type = 'test', 317 | debug = False) 318 | train_loader = DataLoader( 319 | train_dataset, 320 | batch_size=4, 321 | shuffle=True, 322 | num_workers=8, 323 | ) 324 | print(len(train_dataset)) 325 | # for batch in train_loader: 326 | # img, label = batch 327 | # print(img.shape) 328 | # print(label.shape) -------------------------------------------------------------------------------- /util/data_Cataract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from PIL import Image 4 | import pickle 5 | from torch.utils.data import Dataset, ConcatDataset 6 | from torch.utils.data import DataLoader 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | from PIL import ImageFile 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | from glob import glob 12 | import pickle 13 | import random 14 | # from dataset.Mytransforms import * 15 | 16 | # 定制数据集 17 | class CataractDataset(Dataset): 18 | def __init__(self, data_dir, label_path, data_type, data_ratio=100, opt=None, use_syn=False, syn_data_dir=None, syn_label_path=None): 19 | self.trainsize = (224,224) 20 | self.data_dir = data_dir 21 | self.data_type = data_type 22 | self.use_syn = use_syn 23 | self.train = True if data_type == 'train' else False 24 | 25 | self.data_list = [] 26 | 27 | if data_ratio == 100: 28 | with open(label_path, "rb") as f: 29 | tr_dl = pickle.load(f) 30 | self.data_list = tr_dl 31 | print('Total Real Samples:', len(self.data_list)) 32 | else: 33 | dataset_name = label_path.split('/')[1] 34 | sample_data_path = os.path.join('SampleData', dataset_name, f'ratio_{data_ratio}', 'train.pkl') 35 | with open(sample_data_path, "rb") as f: 36 | tr_dl = pickle.load(f) 37 | self.data_list = tr_dl 38 | print(f'Total Ratio {data_ratio} Samples: {len(self.data_list)}') 39 | 40 | self.size = len(self.data_list) 41 | print('Total Samples:', self.size) 42 | 43 | self.size = len(self.data_list) 44 | if self.train: 45 | self.transform_center = transforms.Compose([ 46 | transforms.Resize(self.trainsize), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.RandomVerticalFlip(), 49 | transforms.RandomGrayscale(p=0.2), 50 | transforms.ColorJitter(), 51 | transforms.RandomRotation(degrees=(-180, 180)), 52 | transforms.ToTensor(), 53 | transforms.Normalize([0.485, 0.456, 
0.406], [0.229, 0.224, 0.225]), 54 | ]) 55 | else: 56 | self.transform_center = transforms.Compose([ 57 | # CropCenterSquare(), 58 | transforms.Resize(self.trainsize), 59 | transforms.ToTensor(), 60 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 61 | ]) 62 | 63 | def __getitem__(self, index): 64 | data_pac = self.data_list[index] 65 | basename, ext = data_pac['img_root'].split('.') 66 | imgname = data_pac['img_root'] 67 | # if ext == 'jpg': 68 | # imgname = basename + '.JPG' 69 | img_path = os.path.join(self.data_dir, imgname) 70 | img = Image.open(img_path).convert('RGB') 71 | 72 | img_torch = self.transform_center(img) 73 | 74 | label = int(data_pac['label']) 75 | 76 | return img_torch, label 77 | 78 | def __len__(self): 79 | return self.size 80 | 81 | class ISICDataset(Dataset): 82 | def __init__(self, data_dir, label_path, data_type, opt=None): 83 | self.trainsize = (224,224) 84 | self.data_dir = data_dir 85 | self.data_type = data_type 86 | self.train = True if data_type == 'train' else False 87 | with open(label_path, "rb") as f: 88 | tr_dl = pickle.load(f) 89 | self.data_list = tr_dl 90 | 91 | # test subset of dataset 92 | if data_type != 'train': 93 | test_image_files = os.listdir(self.data_dir) 94 | matching_items = [] 95 | for item in self.data_list: 96 | if item['img_root'] in test_image_files: 97 | matching_items.append(item) 98 | self.data_list = matching_items 99 | self.size = len(self.data_list) 100 | 101 | if self.train: 102 | self.transform_center = transforms.Compose([ 103 | # CropCenterSquare(), 104 | transforms.Resize(self.trainsize), 105 | # RandomHorizontalFlip(), 106 | # RandomRotation(30), 107 | transforms.ToTensor(), 108 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 109 | ]) 110 | else: 111 | self.transform_center = transforms.Compose([ 112 | # CropCenterSquare(), 113 | transforms.Resize(self.trainsize), 114 | transforms.ToTensor(), 115 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 116 | ]) 117 | #self.depths_transform = transforms.Compose([transforms.Resize((self.trainsize, self.trainsize)),transforms.ToTensor()]) 118 | 119 | def __getitem__(self, index): 120 | data_pac = self.data_list[index] 121 | img_path = os.path.join(self.data_dir, data_pac['img_root']) 122 | img = Image.open(img_path).convert('RGB') 123 | 124 | img_torch = self.transform_center(img) 125 | 126 | label = int(data_pac['label']) 127 | 128 | return img_torch, label 129 | 130 | def __len__(self): 131 | return self.size 132 | 133 | class EyePACSDataset(Dataset): 134 | def __init__(self, data_dir, label_path, data_type, opt=None): 135 | self.trainsize = (224,224) 136 | self.data_dir = data_dir 137 | self.data_type = data_type 138 | self.train = True if data_type == 'train' else False 139 | with open(label_path, "rb") as f: 140 | tr_dl = pickle.load(f) 141 | self.data_list = tr_dl 142 | self.size = len(self.data_list) 143 | 144 | if self.train: 145 | self.transform_center = transforms.Compose([ 146 | transforms.Resize(self.trainsize), 147 | transforms.RandomHorizontalFlip(), 148 | transforms.RandomVerticalFlip(), 149 | transforms.RandomGrayscale(p=0.2), 150 | transforms.ColorJitter(), 151 | transforms.RandomRotation(degrees=(-180, 180)), 152 | transforms.ToTensor(), 153 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 154 | transforms.Pad(16, fill=0, padding_mode='constant'), 155 | ]) 156 | else: 157 | self.transform_center = transforms.Compose([ 158 | transforms.Resize(self.trainsize), 159 | 
transforms.ToTensor(), 160 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 161 | transforms.Pad(16, fill=0, padding_mode='constant'), 162 | ]) 163 | 164 | def __getitem__(self, index): 165 | data_pac = self.data_list[index] 166 | img_path = os.path.join(self.data_dir, data_pac['img_root']) 167 | img = Image.open(img_path).convert('RGB') 168 | 169 | img_torch = self.transform_center(img) 170 | 171 | label = int(data_pac['label']) 172 | 173 | return img_torch, label 174 | 175 | def __len__(self): 176 | return self.size 177 | 178 | class MyDataset(Dataset): 179 | def __init__(self, data_dir, label_path, transforms, data_type="train", debug=True, opt = None): 180 | self.data_dir = data_dir 181 | self.label_path = label_path 182 | # df = pd.read_csv(self.label_path) 183 | # self.img_list = df[df.columns[0]].values 184 | # self.label_list = df[df.columns[1]].values 185 | self.transforms = transforms 186 | self.imgs = sorted(glob.glob(os.path.join(self.data_dir, "*.*"))) 187 | 188 | if debug: 189 | self.imgs = self.imgs[:10] 190 | 191 | self.length = len(self.imgs) 192 | 193 | def __getitem__(self, idx): 194 | img_path = self.imgs[idx] 195 | img_name = os.path.split(img_path)[-1] 196 | pil_img = Image.open(img_path).convert("RGB") 197 | img = self.transforms(pil_img) 198 | return img, img_name 199 | 200 | def __len__(self): 201 | return self.length 202 | 203 | def get_loaders(opt): 204 | DatasetClass = eval(opt.datasetM) 205 | num_train_sets = len(opt.train_sets) 206 | train_dataset_list = [] 207 | for i in range(num_train_sets): 208 | train_dataset = DatasetClass(data_dir = opt.TRAIN_DATA_DIR[opt.train_sets[i]], 209 | label_path = opt.PATH_TO_TRAIN_LABEL[opt.train_sets[i]], 210 | data_type = 'train', 211 | debug = opt.debug, 212 | opt = opt) 213 | train_dataset_list.append(train_dataset) 214 | train_dataset = ConcatDataset(train_dataset_list) 215 | test_dataset = DatasetClass(data_dir = opt.TEST_DATA_DIR[opt.test_sets[0]], 216 | label_path = opt.PATH_TO_VAL_LABEL[opt.test_sets[0]], 217 | data_type = 'test', 218 | debug = opt.debug, 219 | opt = opt) 220 | train_loader = DataLoader( 221 | train_dataset, 222 | batch_size=opt.batch_size, 223 | shuffle=True, 224 | num_workers=opt.num_workers, 225 | ) 226 | test_loader = DataLoader( 227 | test_dataset, 228 | batch_size=opt.batch_size, 229 | shuffle=False, 230 | num_workers=opt.num_workers, 231 | ) 232 | return train_loader, test_loader 233 | 234 | ## for five-fold cross-validation on Train&Val, return Train&Val loaders 235 | # def get_loaders(opt): 236 | # DatasetClass = eval(opt.datasetM) 237 | # train_dataset = DatasetClass(data_dir = opt.DATA_DIR, 238 | # label_path = opt.PATH_TO_LABEL[opt.train_dataset], 239 | # data_type = 'train', 240 | # debug = opt.debug, 241 | # opt = opt) 242 | 243 | # # gain indices for cross-validation 244 | # whole_folder = [] 245 | # whole_num = len(train_dataset) 246 | # indices = np.arange(whole_num) 247 | # random.seed(opt.ds_seed) 248 | # random.shuffle(indices) 249 | 250 | # # split indices into five-fold 251 | # num_folder = opt.num_folder 252 | # each_folder_num = int(whole_num / num_folder) 253 | # for ii in range(num_folder-1): 254 | # each_folder = indices[each_folder_num*ii: each_folder_num*(ii+1)] 255 | # whole_folder.append(each_folder) 256 | # each_folder = indices[each_folder_num*(num_folder-1):] 257 | # whole_folder.append(each_folder) 258 | # assert len(whole_folder) == num_folder 259 | # assert sum([len(each) for each in whole_folder if 1==1]) == whole_num 260 | 261 | # ## split 
into train/eval 262 | # train_eval_idxs = [] 263 | # for ii in range(num_folder): 264 | # eval_idxs = whole_folder[ii] 265 | # train_idxs = [] 266 | # for jj in range(num_folder): 267 | # if jj != ii: train_idxs.extend(whole_folder[jj]) 268 | # train_eval_idxs.append([train_idxs, eval_idxs]) 269 | 270 | # ## gain train and eval loaders 271 | # train_loaders = [] 272 | # eval_loaders = [] 273 | # for ii in range(len(train_eval_idxs)): 274 | # train_idxs = train_eval_idxs[ii][0] 275 | # eval_idxs = train_eval_idxs[ii][1] 276 | # train_loader = DataLoader(train_dataset, 277 | # batch_size=opt.batch_size, 278 | # sampler=SubsetRandomSampler(train_idxs), 279 | # num_workers=opt.num_workers, 280 | # pin_memory=True) 281 | # eval_loader = DataLoader(train_dataset, 282 | # batch_size=opt.batch_size, 283 | # sampler=SubsetRandomSampler(eval_idxs), 284 | # num_workers=opt.num_workers, 285 | # pin_memory=True) 286 | # train_loaders.append(train_loader) 287 | # eval_loaders.append(eval_loader) 288 | 289 | # return train_loaders, eval_loaders 290 | 291 | def get_test_loaders(opt): 292 | test_loaders = [] 293 | if opt.havetest_sets: 294 | for test_set in opt.test_sets: 295 | DatasetClass = eval(test_set) 296 | test_dataset = DatasetClass(data_dir = opt.TEST_DATA_DIR[test_set], 297 | label_path = opt.PATH_TO_TEST_LABEL[test_set], 298 | data_type = test_set, 299 | debug = opt.debug) 300 | 301 | test_loader = DataLoader(test_dataset, 302 | batch_size=opt.batch_size, 303 | num_workers=opt.num_workers, 304 | shuffle=False, 305 | pin_memory=False) 306 | test_loaders.append(test_loader) 307 | return test_loaders 308 | 309 | if __name__ == '__main__': 310 | # train_dataset = APTOSDataset(data_dir = "/mnt/gzy/DiffMed/APTOS/train_images", 311 | # label_path = "/mnt/gzy/DiffMed/APTOS/aptos_test.pkl", 312 | # data_type = 'train', 313 | # debug = False) 314 | train_dataset = ISICDataset(data_dir = "/mnt/gzy/DiffMed/ISIC/rec_subset", 315 | label_path = "/mnt/gzy/DiffMed/ISIC/isic2018_test.pkl", 316 | data_type = 'test', 317 | debug = False) 318 | train_loader = DataLoader( 319 | train_dataset, 320 | batch_size=4, 321 | shuffle=True, 322 | num_workers=8, 323 | ) 324 | print(len(train_dataset)) 325 | # for batch in train_loader: 326 | # img, label = batch 327 | # print(img.shape) 328 | # print(label.shape) -------------------------------------------------------------------------------- /util/data_Glaucoma.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from PIL import Image 4 | import pickle 5 | from torch.utils.data import Dataset, ConcatDataset 6 | from torch.utils.data import DataLoader 7 | import os 8 | import torchvision.transforms as transforms 9 | from PIL import Image 10 | from PIL import ImageFile 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | import numpy as np 13 | from glob import glob 14 | import pickle 15 | # from dataset.Mytransforms import * 16 | 17 | # 定制数据集 18 | class ORIGADataset(Dataset): 19 | def __init__(self, data_dir, label_path, data_type, data_ratio=100, opt=None, use_syn=False, syn_data_dir=None, syn_label_path=None): 20 | self.trainsize = (224,224) 21 | self.data_dir = data_dir 22 | self.data_type = data_type 23 | self.use_syn = use_syn 24 | self.train = True if data_type == 'train' else False 25 | self.data_list = [] 26 | 27 | 28 | if data_ratio == 100: 29 | with open(label_path, "rb") as f: 30 | tr_dl = pickle.load(f) 31 | new_tr_dl = [] 32 | for data in tr_dl: 33 | data['img_root'] = os.path.join(data_dir, 
data['img_root']) 34 | new_tr_dl.append(data) 35 | self.data_list = new_tr_dl 36 | print('Total Real Samples:', len(self.data_list)) 37 | else: 38 | dataset_name = label_path.split('/')[1] 39 | sample_data_path = os.path.join('SampleData', dataset_name, f'ratio_{data_ratio}', 'train.pkl') 40 | with open(sample_data_path, "rb") as f: 41 | tr_dl = pickle.load(f) 42 | new_tr_dl = [] 43 | for data in tr_dl: 44 | data['img_root'] = os.path.join(data_dir, data['img_root']) 45 | new_tr_dl.append(data) 46 | self.data_list = new_tr_dl 47 | print(f'Total Ratio {data_ratio} Samples: {len(self.data_list)}') 48 | 49 | if self.use_syn: 50 | with open(syn_label_path, "rb") as f: 51 | tr_dl = pickle.load(f) 52 | new_tr_dl = [] 53 | for data in tr_dl: 54 | data['img_root'] = os.path.join(syn_data_dir, data['img_root']) 55 | new_tr_dl.append(data) 56 | self.data_list.extend(new_tr_dl) 57 | print('Synthesised Samples:', len(new_tr_dl)) 58 | 59 | self.size = len(self.data_list) 60 | print('Total Samples:', self.size) 61 | 62 | if self.train: 63 | self.transform_center = transforms.Compose([ 64 | transforms.Resize(self.trainsize), 65 | transforms.RandomHorizontalFlip(), 66 | transforms.RandomVerticalFlip(), 67 | transforms.RandomGrayscale(p=0.2), 68 | transforms.ColorJitter(), 69 | transforms.RandomRotation(degrees=(-180, 180)), 70 | transforms.ToTensor(), 71 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 72 | ]) 73 | else: 74 | self.transform_center = transforms.Compose([ 75 | # CropCenterSquare(), 76 | transforms.Resize(self.trainsize), 77 | transforms.ToTensor(), 78 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 79 | ]) 80 | 81 | def __getitem__(self, index): 82 | data_pac = self.data_list[index] 83 | img_path = data_pac['img_root'] 84 | img = Image.open(img_path).convert('RGB') 85 | 86 | img_torch = self.transform_center(img) 87 | 88 | label = int(data_pac['label']) 89 | 90 | return img_torch, label 91 | 92 | def __len__(self): 93 | return self.size 94 | 95 | 96 | class PAPILADataset(Dataset): 97 | def __init__(self, data_dir, label_path, data_type, data_ratio=100, opt=None, use_syn=False, syn_data_dir=None, syn_label_path=None): 98 | self.trainsize = (224,224) 99 | self.data_dir = data_dir 100 | self.data_type = data_type 101 | self.use_syn = use_syn 102 | self.train = True if data_type == 'train' else False 103 | self.data_list = [] 104 | 105 | 106 | if data_ratio == 100: 107 | with open(label_path, "rb") as f: 108 | tr_dl = pickle.load(f) 109 | new_tr_dl = [] 110 | for data in tr_dl: 111 | data['img_root'] = os.path.join(data_dir, data['img_root']) 112 | new_tr_dl.append(data) 113 | self.data_list = new_tr_dl 114 | print('Total Real Samples:', len(self.data_list)) 115 | else: 116 | dataset_name = label_path.split('/')[1] 117 | sample_data_path = os.path.join('SampleData', dataset_name, f'ratio_{data_ratio}', 'train.pkl') 118 | with open(sample_data_path, "rb") as f: 119 | tr_dl = pickle.load(f) 120 | new_tr_dl = [] 121 | for data in tr_dl: 122 | data['img_root'] = os.path.join(data_dir, data['img_root']) 123 | new_tr_dl.append(data) 124 | self.data_list = new_tr_dl 125 | print(f'Total Ratio {data_ratio} Samples: {len(self.data_list)}') 126 | 127 | if self.use_syn: 128 | with open(syn_label_path, "rb") as f: 129 | tr_dl = pickle.load(f) 130 | new_tr_dl = [] 131 | for data in tr_dl: 132 | data['img_root'] = os.path.join(syn_data_dir, data['img_root']) 133 | new_tr_dl.append(data) 134 | self.data_list.extend(new_tr_dl) 135 | print('Synthesised Samples:', 
len(new_tr_dl)) 136 | 137 | self.size = len(self.data_list) 138 | print('Total Samples:', self.size) 139 | 140 | if self.train: 141 | self.transform_center = transforms.Compose([ 142 | transforms.Resize(self.trainsize), 143 | transforms.RandomHorizontalFlip(), 144 | transforms.RandomVerticalFlip(), 145 | transforms.RandomGrayscale(p=0.2), 146 | transforms.ColorJitter(), 147 | transforms.RandomRotation(degrees=(-180, 180)), 148 | transforms.ToTensor(), 149 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 150 | ]) 151 | else: 152 | self.transform_center = transforms.Compose([ 153 | # CropCenterSquare(), 154 | transforms.Resize(self.trainsize), 155 | transforms.ToTensor(), 156 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 157 | ]) 158 | 159 | def __getitem__(self, index): 160 | data_pac = self.data_list[index] 161 | img_path = data_pac['img_root'] 162 | img = Image.open(img_path).convert('RGB') 163 | 164 | img_torch = self.transform_center(img) 165 | 166 | label = int(data_pac['label']) 167 | 168 | return img_torch, label 169 | 170 | def __len__(self): 171 | return self.size 172 | 173 | import random 174 | class GFDataset(Dataset): 175 | def __init__(self, data_dir, label_path, data_type, data_ratio=100, opt=None, use_syn=False, syn_data_dir=None, syn_label_path=None): 176 | self.trainsize = (224,224) 177 | self.data_dir = data_dir 178 | self.data_type = data_type 179 | self.use_syn = use_syn 180 | self.train = True if data_type == 'train' else False 181 | self.data_list = [] 182 | 183 | if data_ratio == 100: 184 | with open(label_path, "rb") as f: 185 | tr_dl = pickle.load(f) 186 | new_tr_dl = [] 187 | for data in tr_dl: 188 | data['img_root'] = os.path.join(data_dir, data['img_root']) 189 | new_tr_dl.append(data) 190 | self.data_list = new_tr_dl 191 | print('Total Real Samples:', len(self.data_list)) 192 | else: 193 | dataset_name = label_path.split('/')[1] 194 | sample_data_path = os.path.join('SampleData', dataset_name, f'ratio_{data_ratio}', 'train.pkl') 195 | with open(sample_data_path, "rb") as f: 196 | tr_dl = pickle.load(f) 197 | new_tr_dl = [] 198 | for data in tr_dl: 199 | data['img_root'] = os.path.join(data_dir, data['img_root']) 200 | new_tr_dl.append(data) 201 | self.data_list = new_tr_dl 202 | print(f'Total Ratio {data_ratio} Samples: {len(self.data_list)}') 203 | 204 | if self.use_syn: 205 | with open(syn_label_path, "rb") as f: 206 | tr_dl = pickle.load(f) 207 | new_tr_dl = [] 208 | for data in tr_dl: 209 | data['img_root'] = os.path.join(syn_data_dir, data['img_root']) 210 | new_tr_dl.append(data) 211 | self.data_list.extend(new_tr_dl) 212 | print('Synthesised Samples:', len(new_tr_dl)) 213 | 214 | self.size = len(self.data_list) 215 | print('Total Samples:', self.size) 216 | 217 | if self.train: 218 | self.transform_center = transforms.Compose([ 219 | transforms.Resize(self.trainsize), 220 | transforms.RandomHorizontalFlip(), 221 | transforms.RandomVerticalFlip(), 222 | transforms.RandomGrayscale(p=0.2), 223 | transforms.ColorJitter(), 224 | transforms.RandomRotation(degrees=(-180, 180)), 225 | transforms.ToTensor(), 226 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 227 | ]) 228 | else: 229 | self.transform_center = transforms.Compose([ 230 | # CropCenterSquare(), 231 | transforms.Resize(self.trainsize), 232 | transforms.ToTensor(), 233 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 234 | ]) 235 | 236 | def __getitem__(self, index): 237 | data_pac = self.data_list[index] 238 | 
img_path = data_pac['img_root'] 239 | img = Image.open(img_path).convert('RGB') 240 | 241 | img_torch = self.transform_center(img) 242 | 243 | label = int(data_pac['label']) 244 | 245 | return img_torch, label 246 | 247 | def __len__(self): 248 | return self.size 249 | 250 | class MyDataset(Dataset): 251 | def __init__(self, data_dir, label_path, transforms, data_type="train", debug=True, opt = None): 252 | self.data_dir = data_dir 253 | self.label_path = label_path 254 | # df = pd.read_csv(self.label_path) 255 | # self.img_list = df[df.columns[0]].values 256 | # self.label_list = df[df.columns[1]].values 257 | self.transforms = transforms 258 | self.imgs = sorted(glob.glob(os.path.join(self.data_dir, "*.*"))) 259 | 260 | if debug: 261 | self.imgs = self.imgs[:10] 262 | 263 | self.length = len(self.imgs) 264 | 265 | def __getitem__(self, idx): 266 | img_path = self.imgs[idx] 267 | img_name = os.path.split(img_path)[-1] 268 | pil_img = Image.open(img_path).convert("RGB") 269 | img = self.transforms(pil_img) 270 | return img, img_name 271 | 272 | def __len__(self): 273 | return self.length 274 | 275 | def get_loaders(opt): 276 | DatasetClass = eval(opt.datasetM) 277 | num_train_sets = len(opt.train_sets) 278 | train_dataset_list = [] 279 | for i in range(num_train_sets): 280 | train_dataset = DatasetClass(data_dir = opt.TRAIN_DATA_DIR[opt.train_sets[i]], 281 | label_path = opt.PATH_TO_TRAIN_LABEL[opt.train_sets[i]], 282 | data_type = 'train', 283 | debug = opt.debug, 284 | opt = opt) 285 | train_dataset_list.append(train_dataset) 286 | train_dataset = ConcatDataset(train_dataset_list) 287 | test_dataset = DatasetClass(data_dir = opt.TEST_DATA_DIR[opt.test_sets[0]], 288 | label_path = opt.PATH_TO_VAL_LABEL[opt.test_sets[0]], 289 | data_type = 'test', 290 | debug = opt.debug, 291 | opt = opt) 292 | train_loader = DataLoader( 293 | train_dataset, 294 | batch_size=opt.batch_size, 295 | shuffle=True, 296 | num_workers=opt.num_workers, 297 | ) 298 | test_loader = DataLoader( 299 | test_dataset, 300 | batch_size=opt.batch_size, 301 | shuffle=False, 302 | num_workers=opt.num_workers, 303 | ) 304 | return train_loader, test_loader 305 | 306 | ## for five-fold cross-validation on Train&Val, return Train&Val loaders 307 | # def get_loaders(opt): 308 | # DatasetClass = eval(opt.datasetM) 309 | # train_dataset = DatasetClass(data_dir = opt.DATA_DIR, 310 | # label_path = opt.PATH_TO_LABEL[opt.train_dataset], 311 | # data_type = 'train', 312 | # debug = opt.debug, 313 | # opt = opt) 314 | 315 | # # gain indices for cross-validation 316 | # whole_folder = [] 317 | # whole_num = len(train_dataset) 318 | # indices = np.arange(whole_num) 319 | # random.seed(opt.ds_seed) 320 | # random.shuffle(indices) 321 | 322 | # # split indices into five-fold 323 | # num_folder = opt.num_folder 324 | # each_folder_num = int(whole_num / num_folder) 325 | # for ii in range(num_folder-1): 326 | # each_folder = indices[each_folder_num*ii: each_folder_num*(ii+1)] 327 | # whole_folder.append(each_folder) 328 | # each_folder = indices[each_folder_num*(num_folder-1):] 329 | # whole_folder.append(each_folder) 330 | # assert len(whole_folder) == num_folder 331 | # assert sum([len(each) for each in whole_folder if 1==1]) == whole_num 332 | 333 | # ## split into train/eval 334 | # train_eval_idxs = [] 335 | # for ii in range(num_folder): 336 | # eval_idxs = whole_folder[ii] 337 | # train_idxs = [] 338 | # for jj in range(num_folder): 339 | # if jj != ii: train_idxs.extend(whole_folder[jj]) 340 | # train_eval_idxs.append([train_idxs, 
eval_idxs]) 341 | 342 | # ## gain train and eval loaders 343 | # train_loaders = [] 344 | # eval_loaders = [] 345 | # for ii in range(len(train_eval_idxs)): 346 | # train_idxs = train_eval_idxs[ii][0] 347 | # eval_idxs = train_eval_idxs[ii][1] 348 | # train_loader = DataLoader(train_dataset, 349 | # batch_size=opt.batch_size, 350 | # sampler=SubsetRandomSampler(train_idxs), 351 | # num_workers=opt.num_workers, 352 | # pin_memory=True) 353 | # eval_loader = DataLoader(train_dataset, 354 | # batch_size=opt.batch_size, 355 | # sampler=SubsetRandomSampler(eval_idxs), 356 | # num_workers=opt.num_workers, 357 | # pin_memory=True) 358 | # train_loaders.append(train_loader) 359 | # eval_loaders.append(eval_loader) 360 | 361 | # return train_loaders, eval_loaders 362 | 363 | def get_test_loaders(opt): 364 | test_loaders = [] 365 | if opt.havetest_sets: 366 | for test_set in opt.test_sets: 367 | DatasetClass = eval(test_set) 368 | test_dataset = DatasetClass(data_dir = opt.TEST_DATA_DIR[test_set], 369 | label_path = opt.PATH_TO_TEST_LABEL[test_set], 370 | data_type = test_set, 371 | debug = opt.debug) 372 | 373 | test_loader = DataLoader(test_dataset, 374 | batch_size=opt.batch_size, 375 | num_workers=opt.num_workers, 376 | shuffle=False, 377 | pin_memory=False) 378 | test_loaders.append(test_loader) 379 | return test_loaders 380 | 381 | if __name__ == '__main__': 382 | # train_dataset = APTOSDataset(data_dir = "/mnt/gzy/DiffMed/APTOS/train_images", 383 | # label_path = "/mnt/gzy/DiffMed/APTOS/aptos_test.pkl", 384 | # data_type = 'train', 385 | # debug = False) 386 | train_dataset = ISICDataset(data_dir = "/mnt/gzy/DiffMed/ISIC/rec_subset", 387 | label_path = "/mnt/gzy/DiffMed/ISIC/isic2018_test.pkl", 388 | data_type = 'test', 389 | debug = False) 390 | train_loader = DataLoader( 391 | train_dataset, 392 | batch_size=4, 393 | shuffle=True, 394 | num_workers=8, 395 | ) 396 | print(len(train_dataset)) 397 | # for batch in train_loader: 398 | # img, label = batch 399 | # print(img.shape) 400 | # print(label.shape) -------------------------------------------------------------------------------- /util/data_MultiDisease.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from PIL import Image 4 | import pickle 5 | from torch.utils.data import Dataset, ConcatDataset 6 | from torch.utils.data import DataLoader 7 | import os, torch, random 8 | import numpy as np 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | from PIL import ImageFile 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | import numpy as np 14 | from glob import glob 15 | import pickle 16 | # from dataset.Mytransforms import * 17 | 18 | # 定制数据集 19 | 20 | class ODIRDataset(Dataset): 21 | def __init__(self, data_dir, label_path, data_type, data_ratio=100, opt=None, use_syn=False, syn_data_dir=None, syn_label_path=None): 22 | self.trainsize = (224,224) 23 | self.data_dir = data_dir 24 | self.data_type = data_type 25 | self.use_syn = use_syn 26 | self.train = True if data_type == 'train' else False 27 | 28 | if data_ratio == 100: 29 | with open(label_path, "rb") as f: 30 | tr_dl = pickle.load(f) 31 | new_tr_dl = [] 32 | for data in tr_dl: 33 | data['img_root'] = os.path.join(data_dir, data['img_root']) 34 | new_tr_dl.append(data) 35 | self.data_list = new_tr_dl 36 | print('Total Real Samples:', len(self.data_list)) 37 | else: 38 | dataset_name = label_path.split('/')[1] 39 | sample_data_path = os.path.join('SampleData', dataset_name, 
f'ratio_{data_ratio}', 'train.pkl') 40 | with open(sample_data_path, "rb") as f: 41 | tr_dl = pickle.load(f) 42 | new_tr_dl = [] 43 | for data in tr_dl: 44 | data['img_root'] = os.path.join(data_dir, data['img_root']) 45 | new_tr_dl.append(data) 46 | self.data_list = new_tr_dl 47 | print(f'Total Ratio {data_ratio} Samples: {len(self.data_list)}') 48 | 49 | if self.use_syn: 50 | with open(syn_label_path, "rb") as f: 51 | tr_dl = pickle.load(f) 52 | new_tr_dl = [] 53 | for data in tr_dl: 54 | data['img_root'] = os.path.join(syn_data_dir, data['img_root']) 55 | new_tr_dl.append(data) 56 | self.data_list.extend(new_tr_dl) 57 | print('Synthesised Samples:', len(new_tr_dl)) 58 | 59 | self.size = len(self.data_list) 60 | print('Total Samples:', self.size) 61 | 62 | if self.train: 63 | self.transform_center = transforms.Compose([ 64 | transforms.Resize(self.trainsize), 65 | transforms.RandomHorizontalFlip(), 66 | transforms.RandomVerticalFlip(), 67 | transforms.RandomGrayscale(p=0.2), 68 | transforms.ColorJitter(), 69 | transforms.RandomRotation(degrees=(-180, 180)), 70 | transforms.ToTensor(), 71 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 72 | ]) 73 | else: 74 | self.transform_center = transforms.Compose([ 75 | # CropCenterSquare(), 76 | transforms.Resize(self.trainsize), 77 | transforms.ToTensor(), 78 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 79 | ]) 80 | 81 | def __getitem__(self, index): 82 | data_pac = self.data_list[index] 83 | img_path = data_pac['img_root'] 84 | img = Image.open(img_path).convert('RGB') 85 | 86 | img_torch = self.transform_center(img) 87 | 88 | label = int(data_pac['label']) 89 | 90 | return img_torch, label 91 | 92 | def __len__(self): 93 | return self.size 94 | 95 | class JSIECDataset(Dataset): 96 | def __init__(self, data_dir, label_path, data_type, data_ratio=100, opt=None, use_syn=False, syn_data_dir=None, syn_label_path=None): 97 | self.trainsize = (224,224) 98 | self.data_dir = data_dir 99 | self.data_type = data_type 100 | self.use_syn = use_syn 101 | self.train = True if data_type == 'train' else False 102 | self.data_list = [] 103 | 104 | if data_ratio == 100: 105 | with open(label_path, "rb") as f: 106 | tr_dl = pickle.load(f) 107 | new_tr_dl = [] 108 | for data in tr_dl: 109 | data['img_root'] = os.path.join(data_dir, data['img_root']) 110 | new_tr_dl.append(data) 111 | self.data_list = new_tr_dl 112 | print('Total Real Samples:', len(self.data_list)) 113 | else: 114 | dataset_name = label_path.split('/')[1] 115 | sample_data_path = os.path.join('SampleData', dataset_name, f'ratio_{data_ratio}', 'train.pkl') 116 | with open(sample_data_path, "rb") as f: 117 | tr_dl = pickle.load(f) 118 | new_tr_dl = [] 119 | for data in tr_dl: 120 | data['img_root'] = os.path.join(data_dir, data['img_root']) 121 | new_tr_dl.append(data) 122 | self.data_list = new_tr_dl 123 | print(f'Total Ratio {data_ratio} Samples: {len(self.data_list)}') 124 | 125 | if self.use_syn: 126 | with open(syn_label_path, "rb") as f: 127 | tr_dl = pickle.load(f) 128 | new_tr_dl = [] 129 | for data in tr_dl: 130 | data['img_root'] = os.path.join(syn_data_dir, data['img_root']) 131 | new_tr_dl.append(data) 132 | self.data_list.extend(new_tr_dl) 133 | print('Synthesised Samples:', len(new_tr_dl)) 134 | 135 | self.size = len(self.data_list) 136 | print('Total Samples:', self.size) 137 | 138 | if self.train: 139 | self.transform_center = transforms.Compose([ 140 | transforms.Resize(self.trainsize), 141 | transforms.RandomHorizontalFlip(), 142 | 
transforms.RandomVerticalFlip(), 143 | transforms.RandomGrayscale(p=0.2), 144 | transforms.ColorJitter(), 145 | transforms.RandomRotation(degrees=(-180, 180)), 146 | transforms.ToTensor(), 147 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 148 | ]) 149 | else: 150 | self.transform_center = transforms.Compose([ 151 | # CropCenterSquare(), 152 | transforms.Resize(self.trainsize), 153 | transforms.ToTensor(), 154 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 155 | ]) 156 | 157 | def __getitem__(self, index): 158 | data_pac = self.data_list[index] 159 | img_path = data_pac['img_root'] 160 | img = Image.open(img_path).convert('RGB') 161 | 162 | img_torch = self.transform_center(img) 163 | 164 | label = int(data_pac['label']) 165 | 166 | return img_torch, label 167 | 168 | def __len__(self): 169 | return self.size 170 | 171 | class RetinaDataset(Dataset): 172 | def __init__(self, data_dir, label_path, data_type, data_ratio=100, opt=None, use_syn=False, syn_data_dir=None, syn_label_path=None): 173 | self.trainsize = (224,224) 174 | self.data_dir = data_dir 175 | self.data_type = data_type 176 | self.use_syn = use_syn 177 | self.train = True if data_type == 'train' else False 178 | self.data_list = [] 179 | 180 | 181 | if data_ratio == 100: 182 | with open(label_path, "rb") as f: 183 | tr_dl = pickle.load(f) 184 | new_tr_dl = [] 185 | for data in tr_dl: 186 | data['img_root'] = os.path.join(data_dir, data['img_root']) 187 | new_tr_dl.append(data) 188 | self.data_list = new_tr_dl 189 | print('Total Real Samples:', len(self.data_list)) 190 | else: 191 | dataset_name = label_path.split('/')[1] 192 | sample_data_path = os.path.join('SampleData', dataset_name, f'ratio_{data_ratio}', 'train.pkl') 193 | with open(sample_data_path, "rb") as f: 194 | tr_dl = pickle.load(f) 195 | new_tr_dl = [] 196 | for data in tr_dl: 197 | data['img_root'] = os.path.join(data_dir, data['img_root']) 198 | new_tr_dl.append(data) 199 | self.data_list = new_tr_dl 200 | print(f'Total Ratio {data_ratio} Samples: {len(self.data_list)}') 201 | 202 | 203 | if self.use_syn: 204 | with open(syn_label_path, "rb") as f: 205 | tr_dl = pickle.load(f) 206 | new_tr_dl = [] 207 | for data in tr_dl: 208 | data['img_root'] = os.path.join(syn_data_dir, data['img_root']) 209 | new_tr_dl.append(data) 210 | self.data_list.extend(new_tr_dl) 211 | print('Synthesised Samples:', len(new_tr_dl)) 212 | 213 | self.size = len(self.data_list) 214 | print('Total Samples:', self.size) 215 | 216 | if self.train: 217 | self.transform_center = transforms.Compose([ 218 | transforms.Resize(self.trainsize), 219 | transforms.RandomHorizontalFlip(), 220 | transforms.RandomVerticalFlip(), 221 | transforms.RandomGrayscale(p=0.2), 222 | transforms.ColorJitter(), 223 | transforms.RandomRotation(degrees=(-180, 180)), 224 | transforms.ToTensor(), 225 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 226 | ]) 227 | else: 228 | self.transform_center = transforms.Compose([ 229 | # CropCenterSquare(), 230 | transforms.Resize(self.trainsize), 231 | transforms.ToTensor(), 232 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 233 | ]) 234 | 235 | def __getitem__(self, index): 236 | data_pac = self.data_list[index] 237 | img_path = data_pac['img_root'] 238 | img = Image.open(img_path).convert('RGB') 239 | 240 | img_torch = self.transform_center(img) 241 | 242 | label = int(data_pac['label']) 243 | 244 | return img_torch, label 245 | 246 | def __len__(self): 247 | return self.size 248 | 249 | 
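# --- Editor's note (added for illustration; not part of the original data_MultiDisease.py) ---
# ODIRDataset, JSIECDataset and RetinaDataset above all follow the same pattern:
# load a pickled label list (or the 'SampleData/<name>/ratio_<r>/train.pkl' subset
# when data_ratio != 100), prefix every 'img_root' with the data directory, and,
# when use_syn=True, append a second pickle of synthesized images. A hypothetical
# call mixing a 10% real subset with synthetic data (all paths are placeholders):
#
#     train_set = RetinaDataset(data_dir='data/Retina/train_images',
#                               label_path='DataPKL/Retina/train.pkl',
#                               data_type='train',
#                               data_ratio=10,
#                               use_syn=True,
#                               syn_data_dir='SynData/Retina',
#                               syn_label_path='SynData/Retina/syn_train.pkl')
#     # __getitem__ treats real and synthesized entries identically, since both
#     # were given full 'img_root' paths in __init__.
# ---------------------------------------------------------------------------------------------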
250 | class MyDataset(Dataset): 251 | def __init__(self, data_dir, label_path, transforms, data_type="train", debug=True, opt = None): 252 | self.data_dir = data_dir 253 | self.label_path = label_path 254 | # df = pd.read_csv(self.label_path) 255 | # self.img_list = df[df.columns[0]].values 256 | # self.label_list = df[df.columns[1]].values 257 | self.transforms = transforms 258 | self.imgs = sorted(glob.glob(os.path.join(self.data_dir, "*.*"))) 259 | 260 | if debug: 261 | self.imgs = self.imgs[:10] 262 | 263 | self.length = len(self.imgs) 264 | 265 | def __getitem__(self, idx): 266 | img_path = self.imgs[idx] 267 | img_name = os.path.split(img_path)[-1] 268 | pil_img = Image.open(img_path).convert("RGB") 269 | img = self.transforms(pil_img) 270 | return img, img_name 271 | 272 | def __len__(self): 273 | return self.length 274 | 275 | def get_loaders(opt): 276 | DatasetClass = eval(opt.datasetM) 277 | num_train_sets = len(opt.train_sets) 278 | train_dataset_list = [] 279 | for i in range(num_train_sets): 280 | train_dataset = DatasetClass(data_dir = opt.TRAIN_DATA_DIR[opt.train_sets[i]], 281 | label_path = opt.PATH_TO_TRAIN_LABEL[opt.train_sets[i]], 282 | data_type = 'train', 283 | debug = opt.debug, 284 | opt = opt) 285 | train_dataset_list.append(train_dataset) 286 | train_dataset = ConcatDataset(train_dataset_list) 287 | test_dataset = DatasetClass(data_dir = opt.TEST_DATA_DIR[opt.test_sets[0]], 288 | label_path = opt.PATH_TO_VAL_LABEL[opt.test_sets[0]], 289 | data_type = 'test', 290 | debug = opt.debug, 291 | opt = opt) 292 | train_loader = DataLoader( 293 | train_dataset, 294 | batch_size=opt.batch_size, 295 | shuffle=True, 296 | num_workers=opt.num_workers, 297 | ) 298 | test_loader = DataLoader( 299 | test_dataset, 300 | batch_size=opt.batch_size, 301 | shuffle=False, 302 | num_workers=opt.num_workers, 303 | ) 304 | return train_loader, test_loader 305 | 306 | ## for five-fold cross-validation on Train&Val, return Train&Val loaders 307 | # def get_loaders(opt): 308 | # DatasetClass = eval(opt.datasetM) 309 | # train_dataset = DatasetClass(data_dir = opt.DATA_DIR, 310 | # label_path = opt.PATH_TO_LABEL[opt.train_dataset], 311 | # data_type = 'train', 312 | # debug = opt.debug, 313 | # opt = opt) 314 | 315 | # # gain indices for cross-validation 316 | # whole_folder = [] 317 | # whole_num = len(train_dataset) 318 | # indices = np.arange(whole_num) 319 | # random.seed(opt.ds_seed) 320 | # random.shuffle(indices) 321 | 322 | # # split indices into five-fold 323 | # num_folder = opt.num_folder 324 | # each_folder_num = int(whole_num / num_folder) 325 | # for ii in range(num_folder-1): 326 | # each_folder = indices[each_folder_num*ii: each_folder_num*(ii+1)] 327 | # whole_folder.append(each_folder) 328 | # each_folder = indices[each_folder_num*(num_folder-1):] 329 | # whole_folder.append(each_folder) 330 | # assert len(whole_folder) == num_folder 331 | # assert sum([len(each) for each in whole_folder if 1==1]) == whole_num 332 | 333 | # ## split into train/eval 334 | # train_eval_idxs = [] 335 | # for ii in range(num_folder): 336 | # eval_idxs = whole_folder[ii] 337 | # train_idxs = [] 338 | # for jj in range(num_folder): 339 | # if jj != ii: train_idxs.extend(whole_folder[jj]) 340 | # train_eval_idxs.append([train_idxs, eval_idxs]) 341 | 342 | # ## gain train and eval loaders 343 | # train_loaders = [] 344 | # eval_loaders = [] 345 | # for ii in range(len(train_eval_idxs)): 346 | # train_idxs = train_eval_idxs[ii][0] 347 | # eval_idxs = train_eval_idxs[ii][1] 348 | # train_loader = 
DataLoader(train_dataset, 349 | # batch_size=opt.batch_size, 350 | # sampler=SubsetRandomSampler(train_idxs), 351 | # num_workers=opt.num_workers, 352 | # pin_memory=True) 353 | # eval_loader = DataLoader(train_dataset, 354 | # batch_size=opt.batch_size, 355 | # sampler=SubsetRandomSampler(eval_idxs), 356 | # num_workers=opt.num_workers, 357 | # pin_memory=True) 358 | # train_loaders.append(train_loader) 359 | # eval_loaders.append(eval_loader) 360 | 361 | # return train_loaders, eval_loaders 362 | 363 | def get_test_loaders(opt): 364 | test_loaders = [] 365 | if opt.havetest_sets: 366 | for test_set in opt.test_sets: 367 | DatasetClass = eval(test_set) 368 | test_dataset = DatasetClass(data_dir = opt.TEST_DATA_DIR[test_set], 369 | label_path = opt.PATH_TO_TEST_LABEL[test_set], 370 | data_type = test_set, 371 | debug = opt.debug) 372 | 373 | test_loader = DataLoader(test_dataset, 374 | batch_size=opt.batch_size, 375 | num_workers=opt.num_workers, 376 | shuffle=False, 377 | pin_memory=False) 378 | test_loaders.append(test_loader) 379 | return test_loaders 380 | 381 | if __name__ == '__main__': 382 | # train_dataset = APTOSDataset(data_dir = "/mnt/gzy/DiffMed/APTOS/train_images", 383 | # label_path = "/mnt/gzy/DiffMed/APTOS/aptos_test.pkl", 384 | # data_type = 'train', 385 | # debug = False) 386 | train_dataset = ISICDataset(data_dir = "/mnt/gzy/DiffMed/ISIC/rec_subset", 387 | label_path = "/mnt/gzy/DiffMed/ISIC/isic2018_test.pkl", 388 | data_type = 'test', 389 | debug = False) 390 | train_loader = DataLoader( 391 | train_dataset, 392 | batch_size=4, 393 | shuffle=True, 394 | num_workers=8, 395 | ) 396 | print(len(train_dataset)) 397 | # for batch in train_loader: 398 | # img, label = batch 399 | # print(img.shape) 400 | # print(label.shape) -------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | import sys 2 | try: 3 | from skimage.metrics import structural_similarity as compare_ssim 4 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr 5 | except: 6 | from skimage.measure import compare_psnr, compare_ssim 7 | 8 | import os 9 | import math 10 | import time 11 | import datetime 12 | import matplotlib 13 | matplotlib.use('Agg') 14 | from PIL import Image 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | import torch 18 | import torch.optim as optim 19 | import torch.optim.lr_scheduler as lrs 20 | from pytorch_grad_cam import GradCAM 21 | from pytorch_grad_cam.utils.image import show_cam_on_image 22 | 23 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 24 | imagenet_std = np.array([0.229, 0.224, 0.225]) 25 | 26 | def remove_black_borders_fast(image_path, tolerance=20): 27 | with Image.open(image_path) as img: 28 | # Convert the image to a numpy array 29 | img_array = np.array(img) 30 | brightness_sum = img_array.sum(axis=2) 31 | 32 | # Find all rows and columns where the brightness is above a certain threshold 33 | # Assuming black pixels have very low brightness values 34 | threshold = brightness_sum.max() * 0.1 # 10% of the max value 35 | non_black_rows = np.where(brightness_sum.max(axis=1) > threshold)[0] 36 | non_black_cols = np.where(brightness_sum.max(axis=0) > threshold)[0] 37 | 38 | # Find the bounding box of the non-black areas 39 | top_row = non_black_rows[0] 40 | bottom_row = non_black_rows[-1] 41 | left_col = non_black_cols[0] 42 | right_col = non_black_cols[-1] 43 | 44 | cropped_img = img.crop((left_col, top_row, right_col, 
bottom_row)) 45 | return cropped_img 46 | 47 | def show_image(image, title=''): 48 | # image is [H, W, 3] 49 | assert image.shape[2] == 3 50 | # plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()) 51 | plt.imshow(image) 52 | plt.title(title, fontsize=64) 53 | plt.axis('off') 54 | return 55 | 56 | def prepare_data(image): 57 | input_img = image - imagenet_mean 58 | input_img = image / imagenet_std 59 | x = torch.tensor(input_img) 60 | # make it a batch-like 61 | x = x.unsqueeze(dim=0) 62 | x = torch.einsum('nhwc->nchw', x).float() 63 | return x 64 | 65 | def draw_mae(image, mask, y): 66 | # make the plt figure larger 67 | fig = plt.figure(figsize=(24, 24)) 68 | 69 | image = torch.tensor(image).unsqueeze(0) 70 | im_masked = image * (1 - mask) 71 | y = torch.clip((y * imagenet_std + imagenet_mean) * 255, 0, 255) 72 | x = image * 255 73 | im_masked = im_masked * 255 74 | im_paste = x * (1 - mask) + y * mask 75 | 76 | plt.subplot(2, 2, 1) 77 | show_image(x[0].numpy().astype(np.uint8), "original") 78 | 79 | plt.subplot(2, 2, 2) 80 | show_image(im_masked[0].numpy().astype(np.uint8), "masked") 81 | 82 | plt.subplot(2, 2, 3) 83 | show_image(y[0].numpy().astype(np.uint8), "reconstruction") 84 | 85 | plt.subplot(2, 2, 4) 86 | show_image(im_paste[0].numpy().astype(np.uint8), "reconstruction + visible") 87 | 88 | fig.canvas.draw() 89 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 90 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 91 | 92 | return data 93 | 94 | 95 | def draw_heatmap(MODEL, img, x, gpu=True): 96 | cam = GradCAM(model=MODEL, target_layers=[MODEL.blocks[-1].norm1], use_cuda=gpu, reshape_transform=reshape_transform) 97 | target_category = None # 可以指定一个类别,或者使用 None 表示最高概率的类别 98 | input_tensor = x 99 | grayscale_cam = cam(input_tensor=input_tensor, targets=target_category) 100 | grayscale_cam = grayscale_cam[0, :] 101 | 102 | # 将 grad-cam 的输出叠加到原始图像上 103 | visualization = show_cam_on_image(img, grayscale_cam, image_weight=0.5) 104 | return visualization 105 | 106 | def reshape_transform(tensor, height=14, width=14): 107 | # 去掉cls token 108 | result = tensor[:, 1:, :].reshape(tensor.size(0), 109 | height, width, tensor.size(2)) 110 | 111 | # 将通道维度放到第一个位置 112 | result = result.transpose(2, 3).transpose(1, 2) 113 | return result 114 | 115 | def draw_result(probabilities, categories, colors): 116 | 117 | # Creating the bar plot 118 | fig = plt.figure(figsize=(12, 10)) 119 | plt.barh(categories, probabilities, color=colors) 120 | if len(categories) == 39: 121 | fontsize = 8 122 | else: 123 | fontsize = 12 124 | plt.xticks(fontsize=fontsize) 125 | plt.yticks(fontsize=fontsize) 126 | plt.xlabel('Probability', fontsize=fontsize) 127 | plt.ylabel('DR Category', fontsize=fontsize) 128 | plt.title('Probability Distribution for Different DR Categories', fontsize=fontsize) 129 | plt.xlim(0, 1) # Ensuring the x-axis ranges from 0 to 1 130 | 131 | # plt.show() 132 | fig.canvas.draw() 133 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 134 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 135 | return data 136 | 137 | def to_color(arr, pmin=1, pmax=99.8, gamma=1., colors=((0, 1, 0), (1, 0, 1), (0, 1, 1))): 138 | """Converts a 2D or 3D stack to a colored image (maximal 3 channels). 
139 | 140 | Parameters 141 | ---------- 142 | arr : numpy.ndarray 143 | 2D or 3D input data 144 | pmin : float 145 | lower percentile, pass -1 if no lower normalization is required 146 | pmax : float 147 | upper percentile, pass -1 if no upper normalization is required 148 | gamma : float 149 | gamma correction 150 | colors : list 151 | list of colors (r,g,b) for each channel of the input 152 | 153 | Returns 154 | ------- 155 | numpy.ndarray 156 | colored image 157 | """ 158 | if not arr.ndim in (2, 3): 159 | raise ValueError("only 2d or 3d arrays supported") 160 | 161 | if arr.ndim == 2: 162 | arr = arr[np.newaxis] 163 | 164 | ind_min = np.argmin(arr.shape) 165 | arr = np.moveaxis(arr, ind_min, 0).astype(np.float32) 166 | 167 | out = np.zeros(arr.shape[1:] + (3,)) 168 | 169 | eps = 1.e-20 170 | if pmin >= 0: 171 | mi = np.percentile(arr, pmin, axis=(1, 2), keepdims=True) 172 | else: 173 | mi = 0 174 | 175 | if pmax >= 0: 176 | ma = np.percentile(arr, pmax, axis=(1, 2), keepdims=True) 177 | else: 178 | ma = 1. + eps 179 | 180 | arr_norm = (1. * arr - mi) / (ma - mi + eps) 181 | 182 | for i_stack, col_stack in enumerate(colors): 183 | if i_stack >= len(arr): 184 | break 185 | for j, c in enumerate(col_stack): 186 | out[..., j] += c * arr_norm[i_stack] 187 | 188 | return np.clip(out, 0, 1) 189 | 190 | def savecolorim(save, im, norm=True, **imshow_kwargs): 191 | # im: Uint8 192 | imshow_kwargs['cmap'] = 'magma' 193 | if not norm: # 不对当前图片归一化处理,直接保存 194 | imshow_kwargs['vmin'] = 0 195 | imshow_kwargs['vmax'] = 255 196 | 197 | im = np.asarray(im) 198 | 199 | if save is not None: 200 | plt.imsave(save, im, **imshow_kwargs) 201 | else: 202 | # Make a random plot... 203 | fig = plt.figure() 204 | fig.add_subplot(111) 205 | 206 | # If we haven't already shown or saved the plot, then we need to 207 | # draw the figure first... 208 | plt.imshow(im, **imshow_kwargs) 209 | plt.axis('off') 210 | fig.subplots_adjust(bottom = 0) 211 | fig.subplots_adjust(top = 1) 212 | fig.subplots_adjust(right = 1) 213 | fig.subplots_adjust(left = 0) 214 | 215 | # Now we can save it to a numpy array. 
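# --- Editor's note (added comments; not part of the original utility.py) ---
# The draw()/tostring_rgb() pair used here, and earlier in draw_mae and draw_result,
# rasterizes the current Matplotlib figure into an (H, W, 3) uint8 array instead of
# writing it to disk. tostring_rgb() has been deprecated in recent Matplotlib
# releases; on newer versions an equivalent sketch would be roughly:
#     fig.canvas.draw()
#     data = np.asarray(fig.canvas.buffer_rgba())[..., :3]   # drop the alpha channel
# ----------------------------------------------------------------------------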
216 | fig.canvas.draw() 217 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 218 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 219 | return data 220 | 221 | 222 | class timer(): 223 | def __init__(self): 224 | self.acc = 0 225 | self.tic() 226 | 227 | def tic(self): 228 | self.t0 = time.time() 229 | 230 | def toc(self, restart=False): 231 | diff = time.time() - self.t0 232 | if restart: self.t0 = time.time() 233 | return diff 234 | 235 | def hold(self): 236 | self.acc += self.toc() 237 | 238 | def release(self): 239 | ret = self.acc 240 | self.acc = 0 241 | 242 | return ret 243 | 244 | def reset(self): 245 | self.acc = 0 246 | 247 | 248 | class checkpoint(): 249 | def __init__(self, args): 250 | self.args = args 251 | self.ok = True 252 | self.log = torch.Tensor() 253 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 254 | rp = os.path.dirname(__file__) 255 | 256 | if not args.load: 257 | # if not args.save: 258 | # args.save = now 259 | self.dir = os.path.join(rp, 'experiment', args.save) 260 | else: 261 | self.dir = os.path.join(rp, 'experiment', args.load) 262 | if os.path.exists(self.dir): 263 | self.log = torch.load(self.get_path('psnr_log.pt')) 264 | print('Continue from epoch {}...'.format(len(self.log))) 265 | else: 266 | args.load = '' 267 | 268 | os.makedirs(self.dir, exist_ok=True) 269 | os.makedirs(self.get_path('model'), exist_ok=True) 270 | # os.makedirs(self.get_path('results-{}'.format(args.data_test)), exist_ok=True) 271 | 272 | open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w' 273 | self.log_file = open(self.get_path('log.txt'), open_type) 274 | with open(self.get_path('config.txt'), open_type) as f: 275 | f.write(now + '\n\n') 276 | for arg in vars(args): 277 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 278 | f.write('\n') 279 | 280 | self.n_processes = 0 # 8 281 | 282 | def get_path(self, *subdir): 283 | return os.path.join(self.dir, *subdir) 284 | 285 | def save(self, trainer, epoch, is_best=False): 286 | trainer.model.save(self.get_path('model'), epoch, is_best=is_best) 287 | trainer.loss.save(self.dir) 288 | # trainer.loss.plot_loss(self.dir, epoch) 289 | 290 | # self.plot_psnr(epoch) 291 | trainer.optimizer.save(self.dir) 292 | # torch.save(self.log, self.get_path('psnr_log.pt')) 293 | 294 | def add_log(self, log): 295 | self.log = torch.cat([self.log, log]) 296 | 297 | def write_log(self, log, refresh=False): 298 | print(log) 299 | self.log_file.write(log + '\n') 300 | if refresh: 301 | self.log_file.close() 302 | self.log_file = open(self.get_path('log.txt'), 'a') 303 | 304 | def done(self): 305 | self.log_file.close() 306 | 307 | def plot_psnr(self, epoch): 308 | axis = np.linspace(1, epoch, epoch) 309 | for idx_data, d in enumerate(self.args.data_test): 310 | label = 'SR on {}'.format(d) 311 | fig = plt.figure() 312 | plt.title(label) 313 | for idx_scale, scale in enumerate(self.args.scale): 314 | plt.plot( 315 | axis, 316 | self.log[:, idx_data, idx_scale].numpy(), 317 | label='Scale {}'.format(scale) 318 | ) 319 | plt.legend() 320 | plt.xlabel('Epochs') 321 | plt.ylabel('PSNR') 322 | plt.grid(True) 323 | plt.savefig(self.get_path('test_{}.pdf'.format(d))) 324 | plt.close(fig) 325 | 326 | 327 | def quantize(img, rgb_range): 328 | pixel_range = 255 / rgb_range 329 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 330 | 331 | 332 | def calc_psnr(sr, hr, scale, rgb_range, dataset=None): 333 | if hr.nelement() == 1: return 0 334 | 335 | diff = (sr - hr) / rgb_range 336 | 
336 |     if dataset and dataset.dataset.benchmark:
337 |         shave = scale
338 |         if diff.size(1) > 1:
339 |             gray_coeffs = [65.738, 129.057, 25.064]
340 |             convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
341 |             diff = diff.mul(convert).sum(dim=1)
342 |     else:
343 |         shave = scale + 6
344 |
345 |     valid = diff[..., shave:-shave, shave:-shave]
346 |     mse = valid.pow(2).mean()
347 |
348 |     return -10 * math.log10(mse)
349 |
350 |
351 | def compute_psnr_and_ssim(image1, image2, border_size=0):
352 |     """
353 |     Computes PSNR and SSIM index from 2 images.
354 |     Pixel values are assumed to lie in 0 - 255. 'border_size' pixels are shaved from each border before comparison.
355 |     """
356 |     if len(image1.shape) == 2:
357 |         image1 = image1.reshape(image1.shape[0], image1.shape[1], 1)
358 |     if len(image2.shape) == 2:
359 |         image2 = image2.reshape(image2.shape[0], image2.shape[1], 1)
360 |
361 |     if image1.shape[0] != image2.shape[0] or image1.shape[1] != image2.shape[1] or image1.shape[2] != image2.shape[2]:
362 |         return None
363 |
364 |     if border_size > 0:
365 |         image1 = image1[border_size:-border_size, border_size:-border_size, :]
366 |         image2 = image2[border_size:-border_size, border_size:-border_size, :]
367 |
368 |     psnr = compare_psnr(image1, image2, data_range=255)
369 |     ssim = compare_ssim(image1, image2, win_size=11, gaussian_weights=True, multichannel=True, K1=0.01, K2=0.03,
370 |                         sigma=1.5, data_range=255)
371 |
372 |     return psnr, ssim
373 |
374 |
375 | def make_optimizer(args, target):
376 |     '''
377 |     make optimizer and scheduler together
378 |     '''
379 |     # optimizer
380 |     trainable = filter(lambda x: x.requires_grad, target.parameters())
381 |     kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}
382 |
383 |     if args.optimizer == 'SGD':
384 |         optimizer_class = optim.SGD
385 |         kwargs_optimizer['momentum'] = args.momentum
386 |     elif args.optimizer == 'ADAM':
387 |         optimizer_class = optim.Adam
388 |         kwargs_optimizer['betas'] = args.betas
389 |         kwargs_optimizer['eps'] = args.epsilon
390 |     elif args.optimizer == 'RMSprop':
391 |         optimizer_class = optim.RMSprop
392 |         kwargs_optimizer['eps'] = args.epsilon
393 |
394 |     # scheduler
395 |     milestones = list(map(lambda x: int(x), args.decay.split('-')))
396 |     kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
397 |     scheduler_class = lrs.MultiStepLR
398 |
399 |     class CustomOptimizer(optimizer_class):
400 |         def __init__(self, *args, **kwargs):
401 |             super(CustomOptimizer, self).__init__(*args, **kwargs)
402 |
403 |         def _register_scheduler(self, scheduler_class, **kwargs):
404 |             self.scheduler = scheduler_class(self, **kwargs)
405 |
406 |         def save(self, save_dir):
407 |             torch.save(self.state_dict(), self.get_dir(save_dir))
408 |
409 |         def load(self, load_dir, epoch=1):
410 |             self.load_state_dict(torch.load(self.get_dir(load_dir)))
411 |             if epoch > 1:
412 |                 for _ in range(epoch): self.scheduler.step()
413 |
414 |         def get_dir(self, dir_path):
415 |             return os.path.join(dir_path, 'optimizer.pt')
416 |
417 |         def schedule(self):
418 |             self.scheduler.step()
419 |
420 |         def get_lr(self):
421 |             return self.scheduler.get_lr()[0]
422 |
423 |         def get_last_epoch(self):
424 |             return self.scheduler.last_epoch
425 |
426 |     optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
427 |     optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
428 |     return optimizer
429 |
430 | from tifffile import imsave
431 | import warnings
432 | def save_tiff_imagej_compatible(file, img, axes, **imsave_kwargs):
433 |     """Save image in ImageJ-compatible TIFF format.
434 |
435 |     Parameters
436 |     ----------
437 |     file : str
438 |         File name
439 |     img : numpy.ndarray
440 |         Image
441 |     axes: str
442 |         Axes of ``img``
443 |     imsave_kwargs : dict, optional
444 |         Keyword arguments for :func:`tifffile.imsave`
445 |
446 |     """
447 |     axes = axes_check_and_normalize(axes, img.ndim, disallowed='S')
448 |
449 |     # convert to imagej-compatible data type
450 |     t = img.dtype
451 |     if 'float' in t.name:
452 |         t_new = np.float32
453 |     elif 'uint' in t.name:
454 |         t_new = np.uint16 if t.itemsize >= 2 else np.uint8
455 |     elif 'int' in t.name:
456 |         t_new = np.int16
457 |     else:
458 |         t_new = t
459 |     img = img.astype(t_new, copy=False)
460 |     if t != t_new:
461 |         warnings.warn("Converting data type from '%s' to ImageJ-compatible '%s'." % (t, np.dtype(t_new)))
462 |
463 |     # move axes to correct positions for imagej
464 |     img = move_image_axes(img, axes, 'TZCYX', True)
465 |
466 |     imsave_kwargs['imagej'] = True
467 |     imsave(file, img, **imsave_kwargs)
468 | import collections
469 |
470 | # https://docs.python.org/3/library/itertools.html#itertools-recipes
471 | def consume(iterator):
472 |     collections.deque(iterator, maxlen=0)
473 |
474 |
475 | def _raise(e):
476 |     raise e
477 |
478 |
479 | def axes_check_and_normalize(axes, length=None, disallowed=None, return_allowed=False):
480 |     """
481 |     S(ample), T(ime), C(hannel), Z, Y, X
482 |     """
483 |     allowed = 'STCZYX'
484 |     assert axes is not None
485 |     axes = str(axes).upper()
486 |     consume(
487 |         a in allowed or _raise(ValueError("invalid axis '%s', must be one of %s." % (a, list(allowed)))) for a in axes)
488 |     disallowed is None or consume(a not in disallowed or _raise(ValueError("disallowed axis '%s'." % a)) for a in axes)
489 |     consume(axes.count(a) == 1 or _raise(ValueError("axis '%s' occurs more than once." % a)) for a in axes)
490 |     length is None or len(axes) == length or _raise(ValueError('axes (%s) must be of length %d.' % (axes, length)))
491 |     return (axes, allowed) if return_allowed else axes
492 |
493 |
494 | def move_image_axes(x, fr, to, adjust_singletons=False):
495 |     """
496 |     x: ndarray
497 |     fr,to: axes string (see `axes_dict`)
498 |     """
499 |     fr = axes_check_and_normalize(fr, length=x.ndim)
500 |     to = axes_check_and_normalize(to)
501 |
502 |     fr_initial = fr
503 |     x_shape_initial = x.shape
504 |     adjust_singletons = bool(adjust_singletons)
505 |     if adjust_singletons:
506 |         # remove axes not present in 'to'
507 |         slices = [slice(None) for _ in x.shape]
508 |         for i, a in enumerate(fr):
509 |             if (a not in to) and (x.shape[i] == 1):
510 |                 # remove singleton axis
511 |                 slices[i] = 0
512 |                 fr = fr.replace(a, '')
513 |         x = x[tuple(slices)]
514 |         # add dummy axes present in 'to'
515 |         for i, a in enumerate(to):
516 |             if (a not in fr):
517 |                 # add singleton axis
518 |                 x = np.expand_dims(x, -1)
519 |                 fr += a
520 |
521 |     if set(fr) != set(to):
522 |         _adjusted = '(adjusted to %s and %s) ' % (x.shape, fr) if adjust_singletons else ''
523 |         raise ValueError(
524 |             'image with shape %s and axes %s %snot compatible with target axes %s.'
525 |             % (x_shape_initial, fr_initial, _adjusted, to)
526 |         )
527 |
528 |     ax_from, ax_to = axes_dict(fr), axes_dict(to)
529 |     if fr == to:
530 |         return x
531 |     return np.moveaxis(x, [ax_from[a] for a in fr], [ax_to[a] for a in fr])
532 |
533 |
534 | def axes_dict(axes):
535 |     """
536 |     from axes string to dict
537 |     """
538 |     axes, allowed = axes_check_and_normalize(axes, return_allowed=True)
539 |     return {a: None if axes.find(a) == -1 else axes.find(a) for a in allowed}
540 |     # return collections.namedtuple('Axes',list(allowed))(*[None if axes.find(a) == -1 else axes.find(a) for a in allowed ])
541 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material.
A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. 
NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. 
You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. 
If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 
335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | --------------------------------------------------------------------------------