├── utils ├── __init__.py └── util.py ├── experiments ├── camelyon17 │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── layers.py │ │ ├── resnet_multispectral.py │ │ └── initializer.py │ ├── networks │ │ ├── __init__.py │ │ ├── models.py │ │ └── densenet.py │ ├── data_augmentation │ │ ├── __init__.py │ │ └── randaugment.py │ ├── configs │ │ ├── data_loader.py │ │ ├── scheduler.py │ │ ├── model.py │ │ ├── supported.py │ │ ├── algorithm.py │ │ └── utils.py │ ├── autoencoder.py │ ├── losses.py │ ├── ttt.py │ ├── optimizer.py │ ├── evaluate │ │ ├── MbPA.py │ │ ├── models.py │ │ ├── train.py │ │ ├── mbtt.py │ │ ├── conf.py │ │ └── camelyon17_dataset.py │ ├── scheduler.py │ ├── train.py │ └── evaluate.py └── prostate │ ├── nets │ ├── __init__.py │ └── unet.py │ ├── train │ ├── __init__.py │ ├── model_trainer.py │ ├── center.py │ ├── configs.py │ ├── test_dataset.py │ ├── api.py │ └── model_trainer_segmentation.py │ ├── utils │ ├── __init__.py │ ├── data_preprocess.py │ └── loss.py │ ├── data │ ├── __init__.py │ ├── prostate │ │ ├── __init__.py │ │ ├── generate_data.py │ │ └── dataset.py │ └── generate_data_loader.py │ └── main.py ├── assets └── Figure.png ├── models ├── __init__.py ├── adapmodel.py ├── segmodel.py └── backends.py ├── scripts ├── test_oct.sh └── train_oct.sh ├── README.md ├── memory.py ├── datasets ├── __init__.py ├── dataset.py └── transform.py ├── train.py └── config.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/camelyon17/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/prostate/nets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/prostate/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/prostate/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/camelyon17/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/camelyon17/networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/prostate/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /experiments/camelyon17/data_augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/prostate/data/prostate/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /assets/Figure.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/med-air/DLTTA/HEAD/assets/Figure.png -------------------------------------------------------------------------------- /experiments/camelyon17/configs/data_loader.py: -------------------------------------------------------------------------------- 1 | loader_defaults = { 2 | 'loader_kwargs': { 3 | 'num_workers': 4, 4 | 'pin_memory': True, 5 | }, 6 | 'unlabeled_loader_kwargs': { 7 | 'num_workers': 8, 8 | 'pin_memory': True, 9 | }, 10 | 'n_groups_per_batch': 4, 11 | } 12 | -------------------------------------------------------------------------------- /experiments/camelyon17/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Identity(nn.Module): 6 | """An identity layer""" 7 | def __init__(self, d): 8 | super().__init__() 9 | self.in_features = d 10 | self.out_features = d 11 | 12 | def forward(self, x): 13 | return x 14 | -------------------------------------------------------------------------------- /experiments/prostate/data/generate_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.dataloader import DataLoader 3 | 4 | def generate_data_loader(args, client_num, source_set, ood_set): 5 | 6 | 7 | ood_loader = torch.utils.data.DataLoader(ood_set, batch_size=args.batch, shuffle=False) 8 | source_loader = torch.utils.data.DataLoader(source_set, batch_size=args.batch, shuffle=True) 9 | return (client_num, [source_loader, ood_loader]) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.basemodel import AdaptorNet 2 | from models.segmodel import SegANet 3 | 4 | import pdb 5 | import logging 6 | logger = logging.getLogger('global') 7 | def create_model(args): 8 | """Create a model and load its weights if given 9 | """ 10 | 11 | adaptorNet = SegANet 12 | 13 | model = adaptorNet(args) 14 | if args.resume_T: 15 | model.load_nets(args.resume_T, name='tnet') 16 | if args.resume_AE: 17 | model.load_nets(args.resume_AE, name='aenet') 18 | return model -------------------------------------------------------------------------------- /experiments/camelyon17/configs/scheduler.py: -------------------------------------------------------------------------------- 1 | scheduler_defaults = { 2 | 'linear_schedule_with_warmup': { 3 | 'scheduler_kwargs':{ 4 | 'num_warmup_steps': 0, 5 | }, 6 | }, 7 | 'cosine_schedule_with_warmup': { 8 | 'scheduler_kwargs':{ 9 | 'num_warmup_steps': 0, 10 | }, 11 | }, 12 | 'ReduceLROnPlateau': { 13 | 'scheduler_kwargs':{}, 14 | }, 15 | 'StepLR': { 16 | 'scheduler_kwargs':{ 17 | 'step_size': 1, 18 | } 19 | }, 20 | 'FixMatchLR': { 21 | 'scheduler_kwargs': {}, 22 | }, 23 | 'MultiStepLR': { 24 | 'scheduler_kwargs':{ 25 | 'gamma': 0.1, 26 | } 27 | }, 28 | } 29 | -------------------------------------------------------------------------------- /experiments/prostate/data/prostate/generate_data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import dataset 2 | import torchvision.transforms as transforms 3 | from torch.utils.data.dataloader import DataLoader 4 | import torch 5 | 6 | from .dataset import Prostate 7 | from data.generate_data_loader 
import generate_data_loader 8 | 9 | def load_prostate(args): 10 | sites = args.source 11 | ood_site = args.target if args.target is not None else 'HK' 12 | client_num = 1 13 | 14 | transform = None 15 | 16 | ood_set = Prostate(site=ood_site) 17 | 18 | source_set = Prostate(site=sites[0]) 19 | 20 | 21 | 22 | dataset = generate_data_loader(args, client_num, source_set, ood_set) 23 | 24 | return dataset 25 | -------------------------------------------------------------------------------- /experiments/camelyon17/autoencoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch import nn 3 | 4 | 5 | class autoencoder(nn.Module): 6 | def __init__(self): 7 | super(autoencoder, self).__init__() 8 | self.encoder = nn.Sequential( 9 | nn.Linear(1024, 512), 10 | nn.ReLU(True), 11 | nn.Linear(512, 256), 12 | nn.ReLU(True), nn.Linear(256, 128), nn.ReLU(True), nn.Linear(128, 64)) 13 | self.decoder = nn.Sequential( 14 | nn.Linear(64, 128), 15 | nn.ReLU(True), 16 | nn.Linear(128, 256), 17 | nn.ReLU(True), 18 | nn.Linear(256, 512), 19 | nn.ReLU(True), nn.Linear(512, 1024)) 20 | 21 | def forward(self, x): 22 | x = self.encoder(x) 23 | x = self.decoder(x) 24 | return x -------------------------------------------------------------------------------- /scripts/test_oct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT=/research/pheng4/qdliu/hzyang/; 3 | DATAROOT="/research/pheng4/qdliu/hzyang/seg/dataset/" 4 | BATCH_SIZE=2;AUG=0; 5 | results_dir="$ROOT/oct/atta/exps/oct_a{$AUG}b{$BATCH_SIZE}" 6 | export CUDA_VISIBLE_DEVICES="$1" 7 | echo "test OCT on GPU $1" 8 | mkdir -p "$results_dir" 9 | cp -rf "$0" "$results_dir" 10 | python3 train.py \ 11 | --tepochs=1 \ 12 | --alr=1e-3 \ 13 | --task=seg_oct \ 14 | --batch-size=1 \ 15 | --td=1,64,64,64,64,11 \ 16 | --img_path="$DATAROOT"/spectralis/hctrain/image/ \ 17 | --label_path="$DATAROOT"/spectralis/hctrain/label/ \ 18 | --vimg_path="$DATAROOT"/cirrus/image/ \ 19 | --vlabel_path="$DATAROOT"/cirrus/label/ \ 20 | --sub_name="$DATAROOT"/cirrus/cirrus_name.txt \ 21 | --img_ext=png \ 22 | --label_ext=txt \ 23 | --results_dir="$results_dir"/ \ 24 | --ss=1 \ 25 | --resume_T=$results_dir/checkpoints/tnet_checkpoint.pth \ 26 | --resume_AE=$results_dir/checkpoints/aenet_checkpoint.pth \ 27 | --wo=1 \ 28 | --wt=1,0,1,1,1,1 \ 29 | --seq=1,2,3 \ 30 | --si \ 31 | -t 32 | -------------------------------------------------------------------------------- /experiments/prostate/train/model_trainer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class ModelTrainer(ABC): 5 | """Abstract base class for federated learning trainer. 6 | - This class can be used in both server and client side 7 | - This class is an operator which does not cache any states inside. 
8 | """ 9 | def __init__(self, model, args=None): 10 | self.model = model 11 | self.id = 0 12 | self.args = args 13 | 14 | def set_id(self, trainer_id): 15 | self.id = trainer_id 16 | 17 | @abstractmethod 18 | def get_model_params(self): 19 | pass 20 | 21 | @abstractmethod 22 | def set_model_params(self, model_parameters): 23 | pass 24 | 25 | 26 | # @abstractmethod 27 | # def validate(self, val_data, device, args=None): 28 | # pass 29 | 30 | 31 | 32 | # @abstractmethod 33 | # def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool: 34 | # pass 35 | -------------------------------------------------------------------------------- /experiments/camelyon17/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss 3 | from wilds.common.metrics.all_metrics import MSE 4 | from utils import cross_entropy_with_logits_loss 5 | 6 | def initialize_loss(loss, config): 7 | if loss == 'cross_entropy': 8 | return ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none', ignore_index=-100)) 9 | 10 | elif loss == 'lm_cross_entropy': 11 | return MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none', ignore_index=-100)) 12 | 13 | elif loss == 'mse': 14 | return MSE(name='loss') 15 | 16 | elif loss == 'multitask_bce': 17 | return MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')) 18 | 19 | elif loss == 'fasterrcnn_criterion': 20 | from models.detection.fasterrcnn import FasterRCNNLoss 21 | return ElementwiseLoss(loss_fn=FasterRCNNLoss(config.device)) 22 | 23 | elif loss == 'cross_entropy_logits': 24 | return ElementwiseLoss(loss_fn=cross_entropy_with_logits_loss) 25 | 26 | else: 27 | raise ValueError(f'loss {loss} not recognized') 28 | -------------------------------------------------------------------------------- /experiments/prostate/train/center.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import os 4 | import copy 5 | 6 | class Center: 7 | def __init__(self, client_idx, training_data, args, device, 8 | model_trainer): 9 | self.client_idx = client_idx 10 | self.local_training_data = training_data 11 | 12 | 13 | self.args = args 14 | self.device = device 15 | self.model_trainer = model_trainer 16 | 17 | 18 | 19 | 20 | def train(self, w_global): 21 | self.model_trainer.set_model_params(w_global) 22 | self.model_trainer.train(self.local_training_data, self.device, self.args) 23 | weights = self.model_trainer.get_model_params() 24 | 25 | return weights 26 | 27 | 28 | def ood_test(self, ood_data, w_global): 29 | self.model_trainer.set_model_params(w_global) 30 | metrics = self.model_trainer.test(ood_data, self.device, self.args, True) 31 | return metrics 32 | 33 | 34 | def test_time_adaptation(self, w_global): 35 | if w_global != None: 36 | self.model_trainer.set_model_params(w_global) 37 | metrics = self.model_trainer.test_time(self.local_training_data, self.device, self.args) 38 | return metrics 39 | -------------------------------------------------------------------------------- /experiments/camelyon17/ttt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import numpy as np 4 | 5 | # Assumes that tensor is (nchannels, height, width) 6 | def tensor_rot_90(x): 7 | return x.flip(2).transpose(1, 2) 8 | 9 | def tensor_rot_180(x): 10 | return 
x.flip(2).flip(1) 11 | 12 | def tensor_rot_270(x): 13 | return x.transpose(1, 2).flip(2) 14 | 15 | def rotate_batch_with_labels(batch, labels): 16 | images = [] 17 | for img, label in zip(batch, labels): 18 | if label == 1: 19 | img = tensor_rot_90(img) 20 | elif label == 2: 21 | img = tensor_rot_180(img) 22 | elif label == 3: 23 | img = tensor_rot_270(img) 24 | #print(img.shape) 25 | images.append(img.unsqueeze(0)) 26 | return torch.cat(images) 27 | 28 | def rotate_batch(batch, label): 29 | if label == 'rand': 30 | labels = torch.randint(4, (len(batch),), dtype=torch.long) 31 | elif label == 'expand': 32 | labels = torch.cat([torch.zeros(len(batch), dtype=torch.long), 33 | torch.zeros(len(batch), dtype=torch.long) + 1, 34 | torch.zeros(len(batch), dtype=torch.long) + 2, 35 | torch.zeros(len(batch), dtype=torch.long) + 3]) 36 | batch = batch.repeat((4,1,1,1)) 37 | else: 38 | assert isinstance(label, int) 39 | labels = torch.zeros((len(batch),), dtype=torch.long) + label 40 | return rotate_batch_with_labels(batch, labels), labels -------------------------------------------------------------------------------- /scripts/train_oct.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT=/research/pheng4/qdliu/hzyang/; 3 | TRAINER=$2; BATCH_SIZE=2;AUG=0;SEGLOSS='ce'; 4 | DATAROOT="/research/pheng4/qdliu/hzyang/seg/dataset/" 5 | results_dir="$ROOT/oct/atta/exps/oct_a{$AUG}b{$BATCH_SIZE}" 6 | export CUDA_VISIBLE_DEVICES="$1" 7 | echo "train OCT $TRAINER on GPU $1" 8 | mkdir -p "$results_dir" 9 | cp -rf "$0" "$results_dir" 10 | if [ $TRAINER == "tnet" ]; then 11 | python3 train.py \ 12 | --epochs=20 \ 13 | --task=seg_oct \ 14 | --segloss=$SEGLOSS \ 15 | --batch-size=$BATCH_SIZE \ 16 | --td=1,64,64,64,64,11 \ 17 | --aangle=0,0 \ 18 | --aprob=$AUG \ 19 | --img_path="$DATAROOT"/spectralis/hctrain/image/ \ 20 | --label_path="$DATAROOT"/spectralis/hctrain/label/ \ 21 | --vimg_path="$DATAROOT"/spectralis/hcval/image/ \ 22 | --vlabel_path="$DATAROOT"/spectralis/hcval/label/ \ 23 | --trainer=$TRAINER \ 24 | --img_ext=png \ 25 | --label_ext=txt \ 26 | --results_dir="$results_dir" 27 | else 28 | python3 train.py \ 29 | --epochs=20 \ 30 | --task=seg_oct \ 31 | --segloss=$SEGLOSS \ 32 | --batch-size=$BATCH_SIZE \ 33 | --td=1,64,64,64,64,11 \ 34 | --aangle=0,0 \ 35 | --aprob=0 \ 36 | --img_path="$DATAROOT"/spectralis/hctrain/image/ \ 37 | --label_path="$DATAROOT"/spectralis/hctrain/label/ \ 38 | --vimg_path="$DATAROOT"/spectralis/hcval/image/ \ 39 | --vlabel_path="$DATAROOT"/spectralis/hcval/label/ \ 40 | --trainer=$TRAINER \ 41 | --img_ext=png \ 42 | --label_ext=txt \ 43 | --segaeloss=mse \ 44 | --results_dir="$results_dir"/ \ 45 | --wt=1,0,1,1,1,1,1 \ 46 | --resume_T=$results_dir/checkpoints/tnet_checkpoint_e20.pth 47 | fi 48 | -------------------------------------------------------------------------------- /experiments/camelyon17/configs/model.py: -------------------------------------------------------------------------------- 1 | model_defaults = { 2 | 'bert-base-uncased': { 3 | 'optimizer': 'AdamW', 4 | 'max_grad_norm': 1.0, 5 | 'scheduler': 'linear_schedule_with_warmup', 6 | }, 7 | 'distilbert-base-uncased': { 8 | 'optimizer': 'AdamW', 9 | 'max_grad_norm': 1.0, 10 | 'scheduler': 'linear_schedule_with_warmup', 11 | }, 12 | 'code-gpt-py': { 13 | 'optimizer': 'AdamW', 14 | 'max_grad_norm': 1.0, 15 | 'scheduler': 'linear_schedule_with_warmup', 16 | }, 17 | 'densenet121': { 18 | 'model_kwargs': { 19 | 'pretrained':True, 20 | }, 21 | 'target_resolution': 
(224, 224), 22 | }, 23 | 'wideresnet50': { 24 | 'model_kwargs': { 25 | 'pretrained':True, 26 | }, 27 | 'target_resolution': (224, 224), 28 | }, 29 | 'resnet18': { 30 | 'model_kwargs':{ 31 | 'pretrained':True, 32 | }, 33 | 'target_resolution': (224, 224), 34 | }, 35 | 'resnet34': { 36 | 'model_kwargs':{ 37 | 'pretrained':True, 38 | }, 39 | 'target_resolution': (224, 224), 40 | }, 41 | 'resnet50': { 42 | 'model_kwargs': { 43 | 'pretrained': True, 44 | }, 45 | 'target_resolution': (224, 224), 46 | }, 47 | 'resnet101': { 48 | 'model_kwargs': { 49 | 'pretrained': True, 50 | }, 51 | 'target_resolution': (224, 224), 52 | }, 53 | 'gin-virtual': {}, 54 | 'resnet18_ms': { 55 | 'target_resolution': (224, 224), 56 | }, 57 | 'logistic_regression': {}, 58 | 'unet-seq': { 59 | 'optimizer': 'Adam' 60 | }, 61 | 'fasterrcnn': { 62 | 'model_kwargs': { 63 | 'pretrained_model': True, 64 | 'pretrained_backbone': True, 65 | 'min_size' :1024, 66 | 'max_size' :1024 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /experiments/prostate/utils/data_preprocess.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import numpy as np 3 | import os 4 | 5 | def normalize(arr): 6 | return (arr-np.mean(arr)) / np.std(arr) 7 | 8 | 9 | clients = ['BIDMC', 'HK', 'I2CVB', 'ISBI', 'ISBI_1.5', 'UCL'] 10 | base_path = '../dataset/Prostate/data' 11 | tar_path = '../dataset/Prostate/processed' 12 | for client in clients: 13 | folder = os.path.join(base_path, client) 14 | nii_seg_list = [nii for nii in os.listdir(folder) if 'segmentation' in str(nii).lower()] 15 | slice_count = dict() 16 | tar_folder_path = os.path.join(tar_path,client) 17 | if not os.path.exists(tar_folder_path): os.makedirs(tar_folder_path) 18 | 19 | for nii_seg in nii_seg_list: 20 | nii_path = os.path.join(folder, nii_seg[:6]+'.nii.gz') 21 | nii_seg_path = os.path.join(folder, nii_seg) 22 | 23 | image_vol = sitk.ReadImage(nii_path) 24 | label_vol = sitk.ReadImage(nii_seg_path) 25 | image_vol = sitk.GetArrayFromImage(image_vol) 26 | label_vol = sitk.GetArrayFromImage(label_vol) 27 | label_vol[label_vol > 1] = 1 28 | has_label = list(set(np.where(label_vol>0)[0])) 29 | 30 | label_vol = label_vol[has_label] 31 | image_vol = image_vol[has_label] 32 | 33 | image_v3 = [] 34 | for i in range(image_vol.shape[0]): 35 | if i==0: 36 | image = np.concatenate([np.expand_dims(image_vol[0, :, :],0),image_vol[i:i+2, :, :]],axis=0) 37 | elif i==image_vol.shape[0]-1: 38 | image = np.concatenate([image_vol[i-2:i, :, :],np.expand_dims(image_vol[i, :, :],0)]) 39 | else: 40 | image = np.array(image_vol[i-1:i+2, :, :]) 41 | 42 | image = np.transpose(image,(1,2,0)) 43 | assert image.shape == (384, 384,3) 44 | 45 | image_v3.append(image) 46 | image_v3 = np.asarray(image_v3) 47 | slice_count[nii_seg[:6]] = image_vol.shape[0] 48 | 49 | np.save(os.path.join(tar_folder_path,nii_seg[:6]+'.npy'),image_v3) 50 | np.save(os.path.join(tar_folder_path,nii_seg[:6]+'_segmentation.npy'),label_vol) 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /experiments/prostate/train/configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | 4 | prosate = ['BIDMC', 'HK', 'ISBI', 'ISBI_1.5', 'UCL', 'I2CVB', None] 5 | available_datasets = prosate 6 | 7 | def set_configs(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--log', action='store_true', 
help='whether to log') 10 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 11 | parser.add_argument('--early', action='store_true', help='early stop w/o improvement over 10 epochs') 12 | parser.add_argument('--batch', type = int, default= 1, help ='batch size') 13 | parser.add_argument("--source", choices=available_datasets, help="Source", nargs='+') 14 | parser.add_argument("--target", choices=available_datasets, default=None, help="Target") 15 | parser.add_argument('--rounds', type = int, default=500, help = 'rounds') 16 | parser.add_argument('--wk_iters', type = int, default=1, help = 'optimization iters in local worker between communication') 17 | parser.add_argument('--save_path', type = str, default='../checkpoint/', help='path to save the checkpoint') 18 | parser.add_argument('--resume', action='store_true', help ='resume training from the save path checkpoint') 19 | parser.add_argument('--gpu', type = str, default="0", help = 'gpu device number') 20 | parser.add_argument('--seed', type = int, default=0, help = 'random seed') 21 | parser.add_argument('--client_optimizer', type = str, default='adam', help='local optimizer') 22 | parser.add_argument('--data', type = str, default='prostate', help='datasets') 23 | parser.add_argument('--test_time', type = str, default='tent', help='test time adaptation methods') 24 | parser.add_argument('--debug', action='store_true', help = 'use small data to debug') 25 | parser.add_argument('--test', action='store_true', help='test on local clients') 26 | parser.add_argument('--ood_test', action='store_true', help='test on ood client') 27 | parser.add_argument('--every_save', action='store_true', help='Save ckpt with explicit name every iter') 28 | 29 | args = parser.parse_args() 30 | 31 | 32 | return args 33 | 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DLTTA: Dynamic Learning Rate for Test-time Adaptation on Cross-domain Medical Images 2 | PyTorch implementation for the TMI paper DLTTA: Dynamic Learning Rate for Test-time Adaptation on Cross-domain Medical Images, by [Hongzheng Yang](https://github.com/HongZhengYang), [Cheng Chen](https://cchen-cc.github.io/), [Meirui Jiang](https://meiruijiang.github.io/MeiruiJiang/), [Quande Liu](https://liuquande.github.io/), [Jianfeng Cao](), [Pheng-Ann Heng](http://www.cse.cuhk.edu.hk/~pheng/), and [Qi Dou](http://www.cse.cuhk.edu.hk/~qdou/). 3 | 4 | ## Abstract 5 | 6 | 7 | 8 | ![](assets/Figure.png) 9 | 10 | ## Files 11 | 12 | In this repository, we provide the implementation of our dynamic learning rate method on the OCT dataset. The ATTA and Tent implementations were adapted from their official code. ([Tent](https://github.com/DequanWang/tent), [ATTA]()) 13 | 14 | To reproduce the results on the Camelyon17 and Prostate datasets, please refer to the `experiments` folder. 15 | 16 | ## Datasets 17 | 18 | The OCT dataset can be downloaded from [here](http://iacl.ece.jhu.edu/index.php?title=Resources). 19 | 20 | The Camelyon17 dataset can be downloaded from [here](https://wilds.stanford.edu/). 21 | 22 | The Prostate dataset can be downloaded from [here](https://liuquande.github.io/SAML/). 23 | 24 | ## Usage 25 | 26 | 1. Create the conda environment: 27 | 28 | conda create -n DLTTA python=3.7 29 | conda activate DLTTA 30 | 31 | 2. Install dependencies: 32 | 33 | 1. install pytorch==1.7.0 torchvision==0.9.0 (via conda, recommended) 34 | 35 | 3. Download the datasets 36 | 37 | 4. Download the pretrained model from [google drive](https://drive.google.com/drive/folders/1-Y63KlYmBsEQp5vz3gm2IjdY8TOvidCA?usp=sharing) 38 | 39 | 5. Modify the corresponding data path and model path in `scripts/test_oct.sh` 40 | 41 | 6. Run `scripts/test_oct.sh` to adapt the model 42 | 
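43 | For example, to adapt on GPU 0 (a minimal invocation sketch; the script takes the GPU id as its only argument and assumes the paths from step 5 are already set): 44 | 45 | bash scripts/test_oct.sh 0 46 | 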
47 | ## Citation 48 | 49 | If this repository is useful for your research, please cite: 50 | 51 | @article{2022DLTTA, 52 | title={DLTTA: Dynamic Learning Rate for Test-time Adaptation on Cross-domain Medical Images}, 53 | author={Hongzheng Yang and Cheng Chen and Meirui Jiang and Quande Liu and Jianfeng Cao and Pheng-Ann Heng and Qi Dou}, 54 | journal={IEEE Transactions on Medical Imaging}, 55 | year={2022} 56 | } 57 | 58 | ### Questions 59 | 60 | Please feel free to contact hzyang05@gmail.com if you have any questions. 61 | 62 | -------------------------------------------------------------------------------- /experiments/camelyon17/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import SGD, Adam 2 | from transformers import AdamW 3 | 4 | def initialize_optimizer(config, model): 5 | # initialize optimizers 6 | if config.optimizer=='SGD': 7 | params = filter(lambda p: p.requires_grad, model.parameters()) 8 | optimizer = SGD( 9 | params, 10 | lr=config.lr, 11 | weight_decay=config.weight_decay, 12 | **config.optimizer_kwargs) 13 | elif config.optimizer=='AdamW': 14 | if 'bert' in config.model or 'gpt' in config.model: 15 | no_decay = ['bias', 'LayerNorm.weight'] 16 | else: 17 | no_decay = [] 18 | 19 | params = [ 20 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': config.weight_decay}, 21 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 22 | ] 23 | optimizer = AdamW( 24 | params, 25 | lr=config.lr, 26 | **config.optimizer_kwargs) 27 | elif config.optimizer == 'Adam': 28 | params = filter(lambda p: p.requires_grad, model.parameters()) 29 | optimizer = Adam( 30 | params, 31 | lr=config.lr, 32 | weight_decay=config.weight_decay, 33 | **config.optimizer_kwargs) 34 | else: 35 | raise ValueError(f'Optimizer {config.optimizer} not recognized.') 36 | 37 | return optimizer 38 | 39 | def initialize_optimizer_with_model_params(config, params): 40 | if config.optimizer=='SGD': 41 | optimizer = SGD( 42 | params, 43 | lr=config.lr, 44 | weight_decay=config.weight_decay, 45 | **config.optimizer_kwargs 46 | ) 47 | elif config.optimizer=='AdamW': 48 | optimizer = AdamW( 49 | params, 50 | lr=config.lr, 51 | weight_decay=config.weight_decay, 52 | **config.optimizer_kwargs 53 | ) 54 | elif config.optimizer == 'Adam': 55 | optimizer = Adam( 56 | params, 57 | lr=config.lr, 58 | weight_decay=config.weight_decay, 59 | **config.optimizer_kwargs 60 | ) 61 | else: 62 | raise ValueError(f'Optimizer {config.optimizer} not supported.') 63 | 64 | return optimizer 65 | 
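66 | # A minimal usage sketch (not part of the original file; the config object only needs 67 | # the attributes read by initialize_optimizer above, mocked here with SimpleNamespace): 68 | if __name__ == '__main__': 69 | from types import SimpleNamespace 70 | import torch.nn as nn 71 | demo_config = SimpleNamespace(optimizer='SGD', lr=1e-3, weight_decay=1e-4, optimizer_kwargs={'momentum': 0.9}) 72 | demo_model = nn.Linear(4, 2) 73 | print(initialize_optimizer(demo_config, demo_model)) 74 | 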
-------------------------------------------------------------------------------- /experiments/camelyon17/evaluate/MbPA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from tqdm import trange 5 | import copy 6 | 7 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 8 | """Entropy of softmax distribution from logits.""" 9 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 10 | 11 | 12 | class ReplayMemory(object): 13 | """ 14 | Create the empty memory buffer 15 | """ 16 | 17 | def __init__(self, size): 18 | self.memory = {} 19 | self.size = size 20 | 21 | def get_size(self): 22 | return len(self.memory) 23 | 24 | def push(self, keys, logits): 25 | """Store (key, logit) pairs, evicting the oldest entry once the buffer is full.""" 26 | dimension = 1024*3*3 27 | 28 | for i, key in enumerate(keys): 29 | 30 | if len(self.memory.keys())>self.size: 31 | self.memory.pop(list(self.memory)[0]) 32 | self.memory.update( 33 | {key.reshape(dimension).tobytes(): (logits[i])}) 34 | 35 | def _prepare_batch(self, sample): 36 | # Uniformly average the logits of the retrieved neighbours. 37 | ensemble_prediction = sum(sample) / len(sample) 38 | return torch.FloatTensor(ensemble_prediction) 39 | 40 | 41 | def get_neighbours(self, keys, k): 42 | """ 43 | Returns samples from buffer using nearest neighbour approach 44 | """ 45 | samples = [] 46 | 47 | dimension = 1024*3*3 48 | keys = keys.reshape(len(keys), dimension) 49 | total_keys = len(self.memory.keys()) 50 | self.all_keys = np.frombuffer( 51 | np.asarray(list(self.memory.keys())), dtype=np.float32).reshape(total_keys, dimension) 52 | 53 | for key in keys: 54 | # Dot-product similarity between the query key and all stored keys. 55 | similarity_scores = np.dot(self.all_keys, key.T) 56 | K_neighbour_keys = self.all_keys[np.argpartition( 57 | similarity_scores, -k)[-k:]] 58 | neighbours = [self.memory[nkey.tobytes()] 59 | for nkey in K_neighbour_keys] 60 | batch = self._prepare_batch(neighbours) 61 | 62 | samples.append(batch) 63 | 64 | return torch.stack(samples) 65 | -------------------------------------------------------------------------------- /experiments/camelyon17/configs/supported.py: -------------------------------------------------------------------------------- 1 | from wilds.common.metrics.all_metrics import ( 2 | Accuracy, 3 | MultiTaskAccuracy, 4 | MSE, 5 | multiclass_logits_to_pred, 6 | binary_logits_to_pred, 7 | pseudolabel_binary_logits, 8 | pseudolabel_multiclass_logits, 9 | pseudolabel_identity, 10 | pseudolabel_detection, 11 | pseudolabel_detection_discard_empty, 12 | MultiTaskAveragePrecision 13 | ) 14 | 15 | algo_log_metrics = { 16 | 'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred), 17 | 'mse': MSE(), 18 | 'multitask_accuracy': MultiTaskAccuracy(prediction_fn=multiclass_logits_to_pred), 19 | 'multitask_binary_accuracy': MultiTaskAccuracy(prediction_fn=binary_logits_to_pred), 20 | 'multitask_avgprec': MultiTaskAveragePrecision(prediction_fn=None), 21 | None: None, 22 | } 23 | 24 | process_outputs_functions = { 25 | 'binary_logits_to_pred': binary_logits_to_pred, 26 | 'multiclass_logits_to_pred': multiclass_logits_to_pred, 27 | None: None, 28 | } 29 | 30 | process_pseudolabels_functions = { 31 | 'pseudolabel_binary_logits': pseudolabel_binary_logits, 32 | 'pseudolabel_multiclass_logits': pseudolabel_multiclass_logits, 33 | 'pseudolabel_identity': pseudolabel_identity, 34 | 'pseudolabel_detection': pseudolabel_detection, 35 | 'pseudolabel_detection_discard_empty': pseudolabel_detection_discard_empty, 36 | } 37 | 38 | # See initialize_*() functions for correspondence. 39 | # See algorithms/initializer.py 40 | algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM', 'DANN', 'AFN', 'FixMatch', 'PseudoLabel', 'NoisyStudent'] 41 | 42 | # See transforms.py 43 | transforms = ['bert', 'image_base', 'image_resize', 'image_resize_and_center_crop', 'poverty', 'rxrx1'] 44 | additional_transforms = ['randaugment', 'weak'] 45 | 46 | # See models/initializer.py 47 | models = ['resnet18_ms', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'wideresnet50', 48 | 'densenet121', 
'bert-base-uncased', 'distilbert-base-uncased', 49 | 'gin-virtual', 'logistic_regression', 'code-gpt-py', 50 | 'fasterrcnn', 'unet-seq'] 51 | 52 | # See optimizer.py 53 | optimizers = ['SGD', 'Adam', 'AdamW'] 54 | 55 | # See scheduler.py 56 | schedulers = ['linear_schedule_with_warmup', 'cosine_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR', 'FixMatchLR', 'MultiStepLR'] 57 | 58 | # See losses.py 59 | losses = ['cross_entropy', 'lm_cross_entropy', 'MSE', 'multitask_bce', 'fasterrcnn_criterion', 'cross_entropy_logits'] 60 | -------------------------------------------------------------------------------- /memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from tqdm import trange 5 | import copy 6 | from numpy.linalg import norm 7 | 8 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 9 | """Entropy of softmax distribution from logits.""" 10 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 11 | 12 | 13 | class Memory(object): 14 | """ 15 | Create the empty memory buffer 16 | """ 17 | 18 | def __init__(self, size): 19 | 20 | self.memory = {} 21 | self.size = size 22 | 23 | def get_size(self): 24 | return len(self.memory) 25 | 26 | def push(self, keys, logits): 27 | """Store (key, logit) pairs, evicting the oldest entry once the buffer is full.""" 28 | dimension = 131072 29 | 30 | for i, key in enumerate(keys): 31 | 32 | if len(self.memory.keys())>self.size: 33 | self.memory.pop(list(self.memory)[0]) 34 | 35 | self.memory.update( 36 | {key.reshape(dimension).tobytes(): (logits[i])}) 37 | 38 | def _prepare_batch(self, sample, attention_weight): 39 | # Temperature-scaled (T=0.2) softmax over the cosine similarities. 40 | attention_weight = np.array(attention_weight/0.2) 41 | attention_weight = np.exp(attention_weight) / (np.sum(np.exp(attention_weight))) 42 | # Attention-weighted ensemble of the retrieved logits. 43 | ensemble_prediction = sample[0] * attention_weight[0] 44 | for i in range(1, len(sample)): 45 | ensemble_prediction = ensemble_prediction + sample[i] * attention_weight[i] 46 | 47 | return torch.FloatTensor(ensemble_prediction) 48 | 49 | 50 | def get_neighbours(self, keys, k): 51 | """ 52 | Returns samples from buffer using nearest neighbour approach 53 | """ 54 | samples = [] 55 | 56 | dimension = 131072 57 | keys = keys.reshape(len(keys), dimension) 58 | total_keys = len(self.memory.keys()) 59 | self.all_keys = np.frombuffer( 60 | np.asarray(list(self.memory.keys())), dtype=np.float32).reshape(total_keys, dimension) 61 | 62 | for key in keys: 63 | # Cosine similarity between the query key and all stored keys. 64 | similarity_scores = np.dot(self.all_keys, key.T)/(norm(self.all_keys, axis=1) * norm(key.T) ) 65 | 66 | K_neighbour_keys = self.all_keys[np.argpartition( 67 | similarity_scores, -k)[-k:]] 68 | neighbours = [self.memory[nkey.tobytes()] 69 | for nkey in K_neighbour_keys] 70 | 71 | attention_weight = np.dot(K_neighbour_keys, key.T) /(norm(K_neighbour_keys, axis=1) * norm(key.T) ) 72 | batch = self._prepare_batch(neighbours, attention_weight) 73 | samples.append(batch) 74 | 75 | return torch.stack(samples) 76 | 77 | 
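78 | # A minimal usage sketch (hypothetical shapes; assumes float32 feature maps that 79 | # flatten to the 131072-dim keys used above, e.g. 512*16*16): 80 | if __name__ == '__main__': 81 | mem = Memory(size=10) 82 | feats = np.random.rand(4, 512, 16, 16).astype(np.float32) 83 | logits = np.random.rand(4, 2).astype(np.float32) 84 | mem.push(feats, logits) 85 | ensembled = mem.get_neighbours(feats[:2], k=3) # (2, 2) attention-weighted logits 86 | print(mem.get_size(), ensembled.shape) 87 | 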
-------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.dataset import OCTSegDataset 2 | from utils.util import split_data 3 | from torch.utils.data import Dataset, DataLoader 4 | import os 5 | import pdb 6 | from copy import deepcopy 7 | import logging 8 | logger = logging.getLogger('global') 9 | def create_dataset(args): 10 | ''' Create the train and val datasets 11 | ''' 12 | # Data loading code 13 | 14 | dataset = OCTSegDataset 15 | 16 | train_dataset = dataset(args, train=True, augment=True) 17 | val_dataset_list = [] 18 | if args.vimg_path: 19 | val_dataset = dataset(args, train=False, augment=False) 20 | # return a list of val_dataset, each contains one subject 21 | if args.__dict__.get('sub_name',False): 22 | if os.path.isfile(args.sub_name): 23 | with open(args.sub_name) as f: 24 | dataname = f.read().splitlines() 25 | datalist = val_dataset.datalist 26 | labellist = val_dataset.labellist 27 | for name in dataname: 28 | val_dataset.datalist = sorted([_ for _ in datalist if name in str(_)]) 29 | val_dataset.labellist = sorted([_ for _ in labellist if name in str(_)]) 30 | val_dataset_list.append(deepcopy(val_dataset)) 31 | else: 32 | val_dataset_list = [val_dataset] 33 | elif args.split and not args.__dict__.get('sub_name',False): 34 | train_dataset, val_dataset = split_data(dataset=train_dataset, 35 | split=args.split, 36 | switch=args.test or args.evaluate) 37 | val_dataset_list = [val_dataset] 38 | else: 39 | # no validation used 40 | val_dataset = dataset(args, train=False) 41 | val_dataset_list = [val_dataset] 42 | logger.info('Found {} training samples'.format(len(train_dataset))) 43 | logger.info('Found {} validation subjects with total {} samples'.format(len(val_dataset_list),len(val_dataset))) 44 | 45 | train_loader = DataLoader(dataset=train_dataset, 46 | batch_size=args.batch_size, 47 | shuffle=True, 48 | num_workers=args.workers, 49 | pin_memory=True) 50 | val_loader = [DataLoader(dataset=_, 51 | batch_size=args.batch_size, 52 | shuffle=True, 53 | num_workers=args.workers, 54 | pin_memory=True) for _ in val_dataset_list] 55 | return train_loader, val_loader 56 | -------------------------------------------------------------------------------- /experiments/camelyon17/evaluate/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision 3 | 4 | config = { 5 | 'pretrained':True, 6 | } 7 | 8 | def initialize_model(model_name, d_out, is_featurizer=False): 9 | """ 10 | Initializes a model according to the model name and the module-level config 11 | Args: 12 | - model_name (str): name of the model to initialize 13 | - d_out (int): the dimensionality of the model output 14 | - is_featurizer (bool): whether to return a model or a (featurizer, classifier) pair that constitutes a model. 15 | Output: 16 | If is_featurizer=True: 17 | - featurizer: a model that outputs feature Tensors of shape (batch_size, ..., feature dimensionality) 18 | - classifier: a model that takes in feature Tensors and outputs predictions. In most cases, this is a linear layer. 
19 | 20 | If is_featurizer=False: 21 | - model: a model that is equivalent to nn.Sequential(featurizer, classifier) 22 | """ 23 | if model_name in ('resnet50', 'resnet34', 'wideresnet50', 'densenet121'): 24 | if is_featurizer: 25 | featurizer = initialize_torchvision_model( 26 | name=model_name, 27 | d_out=None, 28 | **config) 29 | classifier = nn.Linear(featurizer.d_out, d_out) 30 | model = (featurizer, classifier) 31 | else: 32 | model = initialize_torchvision_model( 33 | name=model_name, 34 | d_out=d_out, 35 | **config) 36 | return model 37 | 38 | 39 | class Identity(nn.Module): 40 | """An identity layer (as in models/layers.py), used as the replacement last layer of a featurizer""" 41 | def __init__(self, d): 42 | super().__init__() 43 | self.in_features = d 44 | self.out_features = d 45 | 46 | def forward(self, x): 47 | return x 48 | 49 | 50 | def initialize_torchvision_model(name, d_out, **kwargs): 51 | # get constructor and last layer names 52 | if name == 'wideresnet50': 53 | constructor_name = 'wide_resnet50_2' 54 | last_layer_name = 'fc' 55 | elif name == 'densenet121': 56 | constructor_name = name 57 | last_layer_name = 'classifier' 58 | elif name in ('resnet50', 'resnet34'): 59 | constructor_name = name 60 | last_layer_name = 'fc' 61 | else: 62 | raise ValueError(f'Torchvision model {name} not recognized') 63 | # construct the default model, which has the default last layer 64 | constructor = getattr(torchvision.models, constructor_name) 65 | model = constructor(**kwargs) 66 | # adjust the last layer 67 | d_features = getattr(model, last_layer_name).in_features 68 | if d_out is None: # want to initialize a featurizer model 69 | last_layer = Identity(d_features) 70 | model.d_out = d_features 71 | else: # want to initialize a classifier for a particular num_classes 72 | last_layer = nn.Linear(d_features, d_out) 73 | model.d_out = d_out 74 | setattr(model, last_layer_name, last_layer) 75 | return model -------------------------------------------------------------------------------- /experiments/prostate/data/prostate/dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import sys, os 3 | base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(base_path) 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | import os 12 | import h5py 13 | import scipy.io as scio 14 | from glob import glob 15 | import SimpleITK as sitk 16 | import random 17 | import cv2 18 | import torchio as tio 19 | 20 | 21 | 22 | def _label_decomp(label_vol, num_cls): 23 | """ 24 | decompose label for softmax classifier 25 | original labels are batchsize * W * H * 1, with label values 0,1,2,3... 
26 | this function decompse it to one hot, e.g.: 0,0,0,1,0,0 in channel dimension 27 | numpy version of tf.one_hot 28 | """ 29 | one_hot = [] 30 | for i in range(num_cls): 31 | _vol = np.zeros(label_vol.shape) 32 | _vol[label_vol == i] = 1 33 | one_hot.append(_vol) 34 | 35 | return np.stack(one_hot, axis=0) 36 | 37 | 38 | 39 | class Normalization(object): 40 | 41 | 42 | def __init__(self): 43 | self.name = 'Normalization' 44 | 45 | def __call__(self, sample): 46 | resacleFilter = sitk.RescaleIntensityImageFilter() 47 | resacleFilter.SetOutputMaximum(255) 48 | resacleFilter.SetOutputMinimum(0) 49 | image, label = sample['image'], sample['label'] 50 | image = resacleFilter.Execute(image) 51 | 52 | return {'image':image, 'label':label} 53 | 54 | class Prostate(Dataset): 55 | ''' 56 | Six prostate dataset (BIDMC, HK, I2CVB, ISBI, ISBI_1.5, UCL) 57 | ''' 58 | def __init__(self, site, base_path=None): 59 | channels = {'BIDMC':3, 'HK':3, 'I2CVB':3, 'ISBI':3, 'ISBI_1.5':3, 'UCL':3} 60 | assert site in list(channels.keys()) 61 | 62 | base_path = '/research/pheng4/qdliu/hzyang/prostate/test/' 63 | self.site = site 64 | self.base_path = base_path 65 | self.f_names = os.listdir(os.path.join(base_path, self.site)) 66 | 67 | 68 | 69 | def __len__(self): 70 | return len(self.f_names) 71 | 72 | def __getitem__(self, idx): 73 | f_name = self.f_names[idx] 74 | #print(idx) 75 | sampledir = os.path.join(self.base_path, self.site, f_name) 76 | 77 | test_patch = np.load(sampledir) 78 | image_np = test_patch[0] 79 | label_np = test_patch[1] 80 | 81 | 82 | image = torch.Tensor(image_np) 83 | label = torch.Tensor(label_np) 84 | image = torch.unsqueeze(image, dim=0) 85 | label = torch.unsqueeze(label, dim=0) 86 | 87 | 88 | label = _label_decomp(label[0], 2) 89 | return image, label 90 | 91 | 92 | if __name__=='__main__': 93 | pass -------------------------------------------------------------------------------- /experiments/prostate/train/test_dataset.py: -------------------------------------------------------------------------------- 1 | from email.mime import base 2 | import sys, os 3 | base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(base_path) 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | import os 12 | import h5py 13 | import scipy.io as scio 14 | from glob import glob 15 | import SimpleITK as sitk 16 | import random 17 | import cv2 18 | import torchio as tio 19 | 20 | 21 | 22 | def _label_decomp(label_vol, num_cls): 23 | """ 24 | decompose label for softmax classifier 25 | original labels are batchsize * W * H * 1, with label values 0,1,2,3... 
26 | this function decompse it to one hot, e.g.: 0,0,0,1,0,0 in channel dimension 27 | numpy version of tf.one_hot 28 | """ 29 | one_hot = [] 30 | for i in range(num_cls): 31 | _vol = np.zeros(label_vol.shape) 32 | _vol[label_vol == i] = 1 33 | one_hot.append(_vol) 34 | 35 | return np.stack(one_hot, axis=0) 36 | 37 | 38 | 39 | class Normalization(object): 40 | 41 | 42 | def __init__(self): 43 | self.name = 'Normalization' 44 | 45 | def __call__(self, sample): 46 | resacleFilter = sitk.RescaleIntensityImageFilter() 47 | resacleFilter.SetOutputMaximum(255) 48 | resacleFilter.SetOutputMinimum(0) 49 | image, label = sample['image'], sample['label'] 50 | image = resacleFilter.Execute(image) 51 | 52 | return {'image':image, 'label':label} 53 | 54 | class Prostate(Dataset): 55 | ''' 56 | Six prostate dataset (BIDMC, HK, I2CVB, ISBI, ISBI_1.5, UCL) 57 | ''' 58 | def __init__(self, site, base_path=None): 59 | channels = {'BIDMC':3, 'HK':3, 'I2CVB':3, 'ISBI':3, 'ISBI_1.5':3, 'UCL':3} 60 | assert site in list(channels.keys()) 61 | 62 | base_path = '/research/pheng4/qdliu/hzyang/prostate/test/' 63 | self.site = site 64 | self.base_path = base_path 65 | self.f_names = os.listdir(os.path.join(base_path, self.site)) 66 | 67 | 68 | 69 | def __len__(self): 70 | return len(self.f_names) 71 | 72 | def __getitem__(self, idx): 73 | f_name = self.f_names[idx] 74 | #print(idx) 75 | sampledir = os.path.join(self.base_path, self.site, f_name) 76 | 77 | test_patch = np.load(sampledir) 78 | image_np = test_patch[0] 79 | label_np = test_patch[1] 80 | 81 | 82 | image = torch.Tensor(image_np) 83 | label = torch.Tensor(label_np) 84 | image = torch.unsqueeze(image, dim=0) 85 | label = torch.unsqueeze(label, dim=0) 86 | 87 | 88 | label = _label_decomp(label[0], 2) 89 | return image, label 90 | 91 | 92 | if __name__=='__main__': 93 | pass -------------------------------------------------------------------------------- /experiments/prostate/train/api.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import random 4 | import sys,os 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from .center import Center 10 | 11 | 12 | class API(object): 13 | def __init__(self, dataset, device, args, model_trainer): 14 | """ 15 | dataset: data loaders and data size info 16 | """ 17 | self.device = device 18 | self.args = args 19 | client_num, [source_data, ood_data] = dataset 20 | print(len(source_data)) 21 | 22 | 23 | 24 | self.center_list = [] 25 | self.ood_data = ood_data 26 | 27 | self.model_trainer = model_trainer 28 | self._setup_centers(source_data, model_trainer) 29 | logging.info("############setup ood centers#############") 30 | self.ood_center = Center(-1, ood_data, self.args, self.device, model_trainer) 31 | 32 | 33 | def _setup_centers(self, training_data, model_trainer): 34 | 35 | center_idx = 0 36 | c = Center(center_idx, training_data, self.args, self.device, model_trainer) 37 | self.center_list.append(c) 38 | 39 | def train(self): 40 | w_global = self.model_trainer.get_model_params() 41 | for round_idx in range(self.args.rounds): 42 | 43 | logging.info("============ round : {}".format(round_idx)) 44 | 45 | 46 | center_idx = 0 47 | center = self.center_list[center_idx] 48 | 49 | w = center.train(copy.deepcopy(w_global)) 50 | w_global = copy.deepcopy(w) 51 | torch.save(w_global, os.path.join(self.args.save_path, "global_round{}".format(round_idx))) 52 | self.model_trainer.set_model_params(w_global) 53 | self._ood_test(round_idx, 
self.ood_center, self.ood_data, w_global) 54 | self._test_time_adaptation() 55 | 56 | 57 | 58 | def _ood_test(self, round_idx, ood_client, ood_data, w_global): 59 | logging.info("============ ood_test : {}".format(round_idx)) 60 | metrics = ood_client.ood_test(ood_data, w_global) 61 | ''' unify key''' 62 | test_acc = metrics["test_acc"] 63 | test_loss = metrics["test_loss"] 64 | stats = {'test_acc': '{:.4f}'.format(test_acc), 'test_loss': '{:.4f}'.format(test_loss)} 65 | 66 | logging.info(stats) 67 | return metrics 68 | 69 | 70 | 71 | def _test_time_adaptation(self, w_global=None): 72 | metrics = self.ood_center.test_time_adaptation(copy.deepcopy(w_global)) 73 | 74 | test_acc = metrics["test_acc"] 75 | test_loss = metrics["test_loss"] 76 | stats = {'test_acc': test_acc, 'test_loss': test_loss} 77 | logging.info("############ performance after test time adaptation #############") 78 | logging.info(stats) 79 | return metrics 80 | 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /experiments/camelyon17/configs/algorithm.py: -------------------------------------------------------------------------------- 1 | algorithm_defaults = { 2 | 'ERM': { 3 | 'train_loader': 'standard', 4 | 'uniform_over_groups': False, 5 | 'eval_loader': 'standard', 6 | 'randaugment_n': 2, # When running ERM + data augmentation 7 | }, 8 | 'groupDRO': { 9 | 'train_loader': 'standard', 10 | 'uniform_over_groups': True, 11 | 'distinct_groups': True, 12 | 'eval_loader': 'standard', 13 | 'group_dro_step_size': 0.01, 14 | }, 15 | 'deepCORAL': { 16 | 'train_loader': 'group', 17 | 'uniform_over_groups': True, 18 | 'distinct_groups': True, 19 | 'eval_loader': 'standard', 20 | 'coral_penalty_weight': 1., 21 | 'randaugment_n': 2, 22 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples 23 | }, 24 | 'IRM': { 25 | 'train_loader': 'group', 26 | 'uniform_over_groups': True, 27 | 'distinct_groups': True, 28 | 'eval_loader': 'standard', 29 | 'irm_lambda': 100., 30 | 'irm_penalty_anneal_iters': 500, 31 | }, 32 | 'DANN': { 33 | 'train_loader': 'group', 34 | 'uniform_over_groups': True, 35 | 'distinct_groups': True, 36 | 'eval_loader': 'standard', 37 | 'randaugment_n': 2, 38 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples 39 | }, 40 | 'AFN': { 41 | 'train_loader': 'standard', 42 | 'uniform_over_groups': False, 43 | 'eval_loader': 'standard', 44 | 'use_hafn': False, 45 | 'afn_penalty_weight': 0.01, 46 | 'safn_delta_r': 1.0, 47 | 'hafn_r': 1.0, 48 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples 49 | 'randaugment_n': 2, 50 | }, 51 | 'FixMatch': { 52 | 'train_loader': 'standard', 53 | 'uniform_over_groups': False, 54 | 'eval_loader': 'standard', 55 | 'self_training_lambda': 1, 56 | 'self_training_threshold': 0.7, 57 | 'scheduler': 'FixMatchLR', 58 | 'randaugment_n': 2, 59 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled examples 60 | }, 61 | 'PseudoLabel': { 62 | 'train_loader': 'standard', 63 | 'uniform_over_groups': False, 64 | 'eval_loader': 'standard', 65 | 'self_training_lambda': 1, 66 | 'self_training_threshold': 0.7, 67 | 'pseudolabel_T2': 0.4, 68 | 'scheduler': 'FixMatchLR', 69 | 'randaugment_n': 2, 70 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples 71 | }, 72 | 'NoisyStudent': { 73 | 'train_loader': 'standard', 74 | 
'uniform_over_groups': False, 75 | 'eval_loader': 'standard', 76 | 'noisystudent_add_dropout': True, 77 | 'noisystudent_dropout_rate': 0.5, 78 | 'scheduler': 'FixMatchLR', 79 | 'randaugment_n': 2, 80 | 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /experiments/camelyon17/evaluate/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.optim as optim 5 | import torch.utils.data as data 6 | from torchvision import transforms 7 | 8 | 9 | 10 | import mbtt 11 | from camelyon17_dataset import CamelyonDataset 12 | 13 | from conf import cfg, load_cfg_fom_args 14 | from models import initialize_model 15 | import random 16 | import numpy as np 17 | 18 | from tqdm import tqdm 19 | 20 | import pandas as pd 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | err_rates = [] 25 | 26 | torch.set_printoptions(precision=5) 27 | 28 | 29 | def evaluate(description): 30 | load_cfg_fom_args(description) 31 | # configure model 32 | base_model = initialize_model('densenet121', 2) 33 | checkpoint = torch.load('best.pth') 34 | base_model.load_state_dict(checkpoint['state_dict']) 35 | base_model = torch.nn.DataParallel(base_model) 36 | base_model = base_model.cuda() 37 | 38 | model = setup_mbtt(base_model) 39 | 40 | try: 41 | model.reset() 42 | logger.info("resetting model") 43 | except: 44 | logger.warning("not resetting model") 45 | normalize = transforms.Normalize([0.485, 0.456, 0.406], 46 | [0.229, 0.224, 0.225]) 47 | transform=transforms.Compose([ transforms.Resize((96, 96)), 48 | transforms.ToTensor(), 49 | normalize, 50 | ]) 51 | dataset = CamelyonDataset(None, transform, 'test') 52 | print(len(dataset)) 53 | test_loader = data.DataLoader(dataset, 54 | batch_size=200, 55 | shuffle=True, 56 | num_workers=8) 57 | print('done') 58 | acc = 0 59 | with torch.no_grad(): 60 | for i, (x, y) in enumerate(tqdm(test_loader)): 61 | model.gt = y 62 | x, y = x.cuda(), y.cuda() 63 | output = model(x) 64 | acc += (output.max(1)[1] == y).float().sum() 65 | print(acc) 66 | 67 | 68 | acc = acc.item() / len(dataset) 69 | err = acc 70 | err_rates.append(err) 71 | avg_err = sum(err_rates)/len(err_rates) 72 | logging.info("\n error: {:6f}".format(avg_err)) 73 | 74 | 75 | def setup_mbtt(model): 76 | model = mbtt.configure_model(model) 77 | 78 | params, param_names = mbtt.collect_params(model) 79 | optimizer = setup_optimizer(params) 80 | mbtt_model = mbtt.Tent(model, optimizer, cfg, 81 | steps=cfg.OPTIM.STEPS, 82 | episodic=cfg.MODEL.EPISODIC) 83 | logger.info(f"optimizer for adaptation: %s", optimizer) 84 | return mbtt_model 85 | 86 | 87 | def setup_optimizer(params): 88 | 89 | if cfg.OPTIM.METHOD == 'Adam': 90 | return optim.Adam(params, 91 | lr=cfg.OPTIM.LR, 92 | betas=(cfg.OPTIM.BETA, 0.999), 93 | weight_decay=cfg.OPTIM.WD) 94 | elif cfg.OPTIM.METHOD == 'SGD': 95 | return optim.SGD(params, 96 | lr=cfg.OPTIM.LR, 97 | momentum=cfg.OPTIM.MOMENTUM, 98 | dampening=cfg.OPTIM.DAMPENING, 99 | weight_decay=cfg.OPTIM.WD, 100 | nesterov=cfg.OPTIM.NESTEROV) 101 | else: 102 | raise NotImplementedError 103 | 104 | def setup_determinism(seed): 105 | torch.backends.cudnn.deterministic = True 106 | torch.backends.cudnn.benchmark = False 107 | torch.manual_seed(seed) 108 | np.random.seed(seed) 109 | random.seed(seed) 110 | 111 | if __name__ == '__main__': 112 | setup_determinism(0) 113 | 
evaluate('evaluate') 114 | -------------------------------------------------------------------------------- /experiments/prostate/main.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import logging 3 | import pandas as pd 4 | 5 | from torch.utils import data 6 | from torch.utils.data import dataset 7 | base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | sys.path.append(base_path) 9 | import numpy as np 10 | import random 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import torch.backends.cudnn as cudnn 15 | 16 | from train.configs import set_configs 17 | from data.prostate.generate_data import load_prostate 18 | from train.api import API 19 | from train.model_trainer_segmentation import ModelTrainerSegmentation 20 | 21 | 22 | 23 | def deterministic(seed): 24 | cudnn.benchmark = False 25 | cudnn.deterministic = True 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | random.seed(seed) 30 | 31 | 32 | def set_paths(args): 33 | args.save_path = './checkpoint/' 34 | if not os.path.exists(args.save_path): 35 | os.makedirs(args.save_path) 36 | 37 | def custom_model_trainer(args): 38 | 39 | 40 | from nets.unet import UNet3D 41 | model = UNet3D(in_channels=1, out_channels=2, layer_order='cbr') 42 | model_trainer = ModelTrainerSegmentation(model, args) 43 | 44 | return model_trainer 45 | 46 | def custom_dataset(args): 47 | 48 | datasets = load_prostate(args) 49 | return datasets 50 | 51 | def custom_api(args, model_trainer, datasets): 52 | device = torch.device(f'cuda' if torch.cuda.is_available() else 'cpu') 53 | api = API(datasets, device, args, model_trainer) 54 | return api 55 | 56 | 57 | if __name__ == "__main__": 58 | args = set_configs() 59 | args.generalize = False 60 | deterministic(args.seed) 61 | set_paths(args) 62 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 63 | log_path = args.save_path.replace('checkpoint', 'log') 64 | if not os.path.exists(log_path): os.makedirs(log_path) 65 | log_path = log_path+'/log.txt' if args.log else './log.txt' 66 | logging.basicConfig(filename=log_path, level=logging.INFO, 67 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 68 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 69 | logging.info(str(args)) 70 | 71 | model_trainer = custom_model_trainer(args) 72 | datasets = custom_dataset(args) 73 | manager = custom_api(args, model_trainer, datasets) 74 | if args.ood_test: 75 | 76 | global_round = 99 77 | ckpt = torch.load(f'/research/dept5/mrjiang/hzyang/3d/checkpoint/prostate/{args.target}/seed0/fedavg_global_round{global_round}') 78 | 79 | model_trainer.set_model_params(ckpt) 80 | print('Finish intialization') 81 | 82 | metrics = manager.ood_client.test_time_adaptation(None) 83 | 84 | elif args.test: 85 | from tqdm import tqdm 86 | rounds = [i for i in range(500)] 87 | ood_performance = {"before":[]} 88 | for epoch in tqdm(rounds): 89 | ckpt = torch.load(f'/research/pheng4/qdliu/hzyang/prostate/test_time/checkpoint/prostate/I2CVB/seed0/fedavg/fedavg_global_round{epoch}') 90 | model_trainer.set_model_params(ckpt) 91 | test_data = datasets[1][-1] 92 | metrics = model_trainer.test(test_data, manager.device, args) 93 | test_acc = metrics["test_acc"] 94 | ood_performance['before'].append(test_acc) 95 | ood_performance_pd = pd.DataFrame.from_dict(ood_performance) 96 | ood_performance_pd.to_csv(f"{args.target}_ood_performance_cbr_prostate_center.csv") 97 | 98 
98 | else: 99 | manager.train() 100 | 101 | 102 | -------------------------------------------------------------------------------- /experiments/camelyon17/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau, StepLR, CosineAnnealingLR, MultiStepLR 2 | 3 | def initialize_scheduler(config, optimizer, n_train_steps): 4 | # construct schedulers 5 | if config.scheduler is None: 6 | return None 7 | elif config.scheduler == 'linear_schedule_with_warmup': 8 | from transformers import get_linear_schedule_with_warmup 9 | scheduler = get_linear_schedule_with_warmup( 10 | optimizer, 11 | num_training_steps=n_train_steps, 12 | **config.scheduler_kwargs) 13 | step_every_batch = True 14 | use_metric = False 15 | elif config.scheduler == 'cosine_schedule_with_warmup': 16 | from transformers import get_cosine_schedule_with_warmup 17 | scheduler = get_cosine_schedule_with_warmup( 18 | optimizer, 19 | num_training_steps=n_train_steps, 20 | **config.scheduler_kwargs) 21 | step_every_batch = True 22 | use_metric = False 23 | elif config.scheduler=='ReduceLROnPlateau': 24 | assert config.scheduler_metric_name, f'scheduler metric must be specified for {config.scheduler}' 25 | scheduler = ReduceLROnPlateau( 26 | optimizer, 27 | **config.scheduler_kwargs) 28 | step_every_batch = False 29 | use_metric = True 30 | elif config.scheduler == 'StepLR': 31 | scheduler = StepLR(optimizer, **config.scheduler_kwargs) 32 | step_every_batch = False 33 | use_metric = False 34 | elif config.scheduler == 'FixMatchLR': 35 | scheduler = LambdaLR( 36 | optimizer, 37 | lambda x: (1.0 + 10 * float(x) / n_train_steps) ** -0.75 38 | ) 39 | step_every_batch = True 40 | use_metric = False 41 | elif config.scheduler == 'MultiStepLR': 42 | scheduler = MultiStepLR(optimizer, **config.scheduler_kwargs) 43 | step_every_batch = False 44 | use_metric = False 45 | else: 46 | raise ValueError(f'Scheduler: {config.scheduler} not supported.') 47 | 48 | # add a step_every_batch field 49 | scheduler.step_every_batch = step_every_batch 50 | scheduler.use_metric = use_metric 51 | return scheduler 52 | 53 | def step_scheduler(scheduler, metric=None): 54 | if isinstance(scheduler, ReduceLROnPlateau): 55 | assert metric is not None 56 | scheduler.step(metric) 57 | else: 58 | scheduler.step() 59 | 
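# Illustrative usage sketch (added; not part of the original file). It shows how
# a training loop is expected to consume the flags that initialize_scheduler()
# attaches to the scheduler object; `config` is an assumed namespace here.
def _scheduler_usage_sketch(config, optimizer, n_train_steps, val_metric=None):
    scheduler = initialize_scheduler(config, optimizer, n_train_steps)
    if scheduler is None:
        return
    if scheduler.step_every_batch:
        step_scheduler(scheduler)                     # call once per batch
    elif scheduler.use_metric:
        step_scheduler(scheduler, metric=val_metric)  # e.g. ReduceLROnPlateau
    else:
        step_scheduler(scheduler)                     # call once per epoch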
60 | class LinearScheduleWithWarmupAndThreshold(): 61 | """ 62 | Linear schedule with warmup and threshold for non-LR parameters. 63 | The parameter is held at 0 until T1, increased linearly over [T1, T2), and then 64 | held at its max value from T2 onwards. 65 | Designed to be called by step_scheduler() above and used within the Algorithm class. 66 | Args: 67 | - last_warmup_step: aka T1; for steps [0, T1) keep param = 0 68 | - threshold_step: aka T2; step over period [T1, T2) to reach param = max value 69 | - max_value: end value of the param 70 | """ 71 | def __init__(self, max_value, last_warmup_step=0, threshold_step=1, step_every_batch=False): 72 | self.max_value = max_value 73 | self.T1 = last_warmup_step 74 | self.T2 = threshold_step 75 | assert (0 <= self.T1) and (self.T1 < self.T2) 76 | 77 | # internal tracker of which step we're on 78 | self.current_step = 0 79 | self.value = 0 80 | 81 | # required fields called in Algorithm when stepping schedulers 82 | self.step_every_batch = step_every_batch 83 | self.use_metric = False 84 | 85 | def step(self): 86 | """This function is first called AFTER step 0, so increment first to set value for next step""" 87 | self.current_step += 1 88 | if self.current_step < self.T1: 89 | self.value = 0 90 | elif self.current_step < self.T2: 91 | self.value = (self.current_step - self.T1) / (self.T2 - self.T1) * self.max_value 92 | else: 93 | self.value = self.max_value 94 | -------------------------------------------------------------------------------- /experiments/camelyon17/data_augmentation/randaugment.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/YBZh/Bridging_UDA_SSL 2 | 3 | import torch 4 | from PIL import Image, ImageOps, ImageEnhance, ImageDraw 5 | 6 | 7 | def AutoContrast(img, _): 8 | return ImageOps.autocontrast(img) 9 | 10 | 11 | def Brightness(img, v): 12 | assert v >= 0.0 13 | return ImageEnhance.Brightness(img).enhance(v) 14 | 15 | 16 | def Color(img, v): 17 | assert v >= 0.0 18 | return ImageEnhance.Color(img).enhance(v) 19 | 20 | 21 | def Contrast(img, v): 22 | assert v >= 0.0 23 | return ImageEnhance.Contrast(img).enhance(v) 24 | 25 | 26 | def Equalize(img, _): 27 | return ImageOps.equalize(img) 28 | 29 | 30 | def Invert(img, _): 31 | return ImageOps.invert(img) 32 | 33 | 34 | def Identity(img, v): 35 | return img 36 | 37 | 38 | def Posterize(img, v): # [4, 8] 39 | v = int(v) 40 | v = max(1, v) 41 | return ImageOps.posterize(img, v) 42 | 43 | 44 | def Rotate(img, v): # [-30, 30] 45 | return img.rotate(v) 46 | 47 | 48 | def Sharpness(img, v): # [0.1,1.9] 49 | assert v >= 0.0 50 | return ImageEnhance.Sharpness(img).enhance(v) 51 | 52 | 53 | def ShearX(img, v): # [-0.3, 0.3] 54 | return img.transform(img.size, Image.AFFINE, (1, v, 0, 0, 1, 0)) 55 | 56 | 57 | def ShearY(img, v): # [-0.3, 0.3] 58 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, v, 1, 0)) 59 | 60 | 61 | def TranslateX(img, v): # fraction of image width; the pool below uses [-0.3, 0.3] 62 | v = v * img.size[0] 63 | return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) 64 | 65 | 66 | def TranslateXabs(img, v): # absolute pixels 67 | return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) 68 | 69 | 70 | def TranslateY(img, v): # fraction of image height; the pool below uses [-0.3, 0.3] 71 | v = v * img.size[1] 72 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) 73 | 74 | 75 | def TranslateYabs(img, v): # absolute pixels 76 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) 77 | 78 | 79 | def Solarize(img, v): # [0, 256] 80 | assert 0 <= v <= 256 81 | return ImageOps.solarize(img, v) 82 | 83 | 84 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] => change to [0, 0.5] 85 | assert 0.0 <= v <= 0.5 86 | 87 | v = v * img.size[0] 88 | return CutoutAbs(img, v) 89 | 90 | 91 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 92 | if v < 0: 93 | return img 94 | w, h = img.size 95 | x_center = _sample_uniform(0, w) 96 | y_center = _sample_uniform(0, h) 97 | 98 | x0 = int(max(0, x_center - v / 2.0)) 99 | y0 = int(max(0, y_center - v / 2.0)) 100 | x1 = min(w, x0 + v) 101 | y1 = min(h, y0 + v) 102 | 103 | xy = (x0, y0, x1, y1) 104 | color = (125, 123, 114) 105 | img = img.copy() 106 | ImageDraw.Draw(img).rectangle(xy, color) 107 | return img 108 | 109 | 110 | FIX_MATCH_AUGMENTATION_POOL = [ 111 | (AutoContrast, 0, 1), 112 | (Brightness, 0.05, 0.95), 113 | (Color, 0.05, 0.95), 114 | (Contrast, 0.05, 0.95), 115 | (Equalize, 0, 1), 116 | (Identity, 0, 1), 117 | (Posterize, 4, 8), 118 | (Rotate, -30, 30), 119 | (Sharpness, 0.05, 0.95), 120 | (ShearX, -0.3, 0.3), 121 | (ShearY, -0.3, 0.3), 122 | (Solarize, 0, 256), 123 | (TranslateX, -0.3, 0.3), 124 | (TranslateY, -0.3, 0.3), 125 | ] 126 | 127 | 128 | def _sample_uniform(a, b): 129 | return torch.empty(1).uniform_(a, b).item() 130 | 131 | 132 | class RandAugment: 133 | def __init__(self, n, augmentation_pool): 134 | assert n >= 1, "RandAugment N has to be a value greater than or equal to 1." 135 | self.n = n 136 | self.augmentation_pool = augmentation_pool 137 | 138 | def __call__(self, img): 139 | ops = [ 140 | self.augmentation_pool[torch.randint(len(self.augmentation_pool), (1,))] 141 | for _ in range(self.n) 142 | ] 143 | for op, min_val, max_val in ops: 144 | val = min_val + float(max_val - min_val) * _sample_uniform(0, 1) 145 | img = op(img, val) 146 | cutout_val = _sample_uniform(0, 1) * 0.5 147 | img = Cutout(img, cutout_val) 148 | return img 149 | 
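# Minimal usage sketch (added for illustration): n random ops per image from
# the FixMatch pool, each at a strength sampled from its (min, max) range; a
# random Cutout is then applied inside __call__ itself.
def _randaugment_usage_sketch():
    img = Image.new('RGB', (96, 96))  # dummy 96x96 patch, Camelyon17-sized
    augment = RandAugment(n=2, augmentation_pool=FIX_MATCH_AUGMENTATION_POOL)
    return augment(img)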
-------------------------------------------------------------------------------- /experiments/camelyon17/networks/models.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | """ 4 | The main CheXpert models implementation. 5 | Including: 6 | DenseNet-121 7 | """ 8 | 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from . import densenet 14 | 15 | class ViewFlatten(nn.Module): 16 | def __init__(self): 17 | super(ViewFlatten, self).__init__() 18 | 19 | def forward(self, x): 20 | return x.view(x.size(0), -1) 21 | 22 | class DenseNet121(nn.Module): 23 | """Model modified. 24 | The architecture is standard DenseNet121 with the classifier replaced by a single Linear layer 25 | (the original sigmoid is disabled); an auxiliary 4-way head (self.fc) serves the self-supervised branch. 26 | """ 27 | def __init__(self, out_size, drop_rate=0.2): 28 | super(DenseNet121, self).__init__() 29 | 30 | self.densenet121 = densenet.densenet121(pretrained=True, drop_rate=drop_rate) 31 | num_ftrs = self.densenet121.classifier.in_features 32 | 33 | self.densenet121.classifier = nn.Sequential( 34 | nn.Linear(num_ftrs, out_size), 35 | #nn.Sigmoid() 36 | ) 37 | 38 | 39 | # Official init from torch repo.
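# (added note) only the replaced classifier needs re-initialising: the loop
# below zeroes the Linear biases, while the convolutional/BN weights keep the
# pretrained torchvision values loaded above.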
40 | for m in self.densenet121.modules(): 41 | if isinstance(m, nn.Linear): 42 | nn.init.constant_(m.bias, 0) 43 | 44 | self.drop_rate = drop_rate 45 | self.drop_layer = nn.Dropout(p=drop_rate) 46 | 47 | self.fc = nn.Linear(1024, 4) 48 | self.flat = ViewFlatten() 49 | 50 | def forward(self, x): 51 | features = self.densenet121.features(x) 52 | out = F.relu(features, inplace=True) 53 | 54 | 55 | out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) 56 | #print(out.shape) 57 | ssh_logits = self.fc(self.flat(out)) 58 | 59 | if self.drop_rate > 0: 60 | out = self.drop_layer(out) 61 | self.activations = out 62 | 63 | out = self.densenet121.classifier(out) 64 | 65 | 66 | return self.activations, out, ssh_logits 67 | 68 | class DenseNet161(nn.Module): 69 | """Model modified. 70 | The architecture is standard DenseNet161, except that the classifier layer 71 | is replaced according to the label-uncertainty mode. 72 | """ 73 | def __init__(self, out_size, mode, drop_rate=0): 74 | super(DenseNet161, self).__init__() 75 | assert mode in ('U-Ones', 'U-Zeros', 'U-MultiClass') 76 | self.densenet161 = densenet.densenet161(pretrained=True, drop_rate=drop_rate) 77 | num_ftrs = self.densenet161.classifier.in_features 78 | if mode in ('U-Ones', 'U-Zeros'): 79 | self.densenet161.classifier = nn.Sequential( 80 | nn.Linear(num_ftrs, out_size), 81 | #nn.Sigmoid() 82 | ) 83 | elif mode in ('U-MultiClass', ): 84 | self.densenet161.classifier = None 85 | self.densenet161.Linear_0 = nn.Linear(num_ftrs, out_size) 86 | self.densenet161.Linear_1 = nn.Linear(num_ftrs, out_size) 87 | self.densenet161.Linear_u = nn.Linear(num_ftrs, out_size) 88 | 89 | self.mode = mode 90 | 91 | # Official init from torch repo. 92 | for m in self.densenet161.modules(): 93 | if isinstance(m, nn.Linear): 94 | nn.init.constant_(m.bias, 0) 95 | 96 | self.drop_rate = drop_rate 97 | self.drop_layer = nn.Dropout(p=drop_rate) 98 | 99 | def forward(self, x): 100 | features = self.densenet161.features(x) 101 | out = F.relu(features, inplace=True) 102 | 103 | out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) 104 | 105 | if self.drop_rate > 0: 106 | out = self.drop_layer(out) 107 | self.activations = out 108 | 109 | if self.mode in ('U-Ones', 'U-Zeros'): 110 | out = self.densenet161.classifier(out) 111 | elif self.mode in ('U-MultiClass', ): 112 | n_batch = x.size(0) 113 | out_0 = self.densenet161.Linear_0(out).view(n_batch, 1, -1) 114 | out_1 = self.densenet161.Linear_1(out).view(n_batch, 1, -1) 115 | out_u = self.densenet161.Linear_u(out).view(n_batch, 1, -1) 116 | out = torch.cat((out_0, out_1, out_u), dim=1) 117 | 118 | return self.activations, out -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import matplotlib 3 | from matplotlib import pyplot as plt 4 | from scipy.io import savemat 5 | import pdb 6 | import argparse 7 | import functools 8 | import warnings 9 | import logging 10 | logging.basicConfig(level=logging.INFO) 11 | import os 12 | import shutil 13 | import time 14 | from scipy import misc 15 | import numpy as np 16 | import json,csv 17 | 18 | import torch 19 | import torch.backends.cudnn as cudnn 20 | cudnn.benchmark = True 21 | import torch.optim 22 | import torch.nn as nn 23 | from datasets import create_dataset 24 | from models import create_model 25 | from utils.util import setlogger, deterministic 26 | from config
import * 27 | from memory import Memory 28 | 29 | buffer_size = 20 30 | 31 | 32 | 33 | def configure_model(model): 34 | model.train() 35 | for m in model.modules(): 36 | if isinstance(m, nn.BatchNorm2d): 37 | m.requires_grad_(True) 38 | m.track_running_stats = False 39 | return model 40 | 41 | 42 | def main(): 43 | warnings.filterwarnings('ignore') 44 | # set random seed 45 | deterministic() 46 | args = parser.parse_args() 47 | os.makedirs(args.results_dir, exist_ok=True) 48 | logger = logging.getLogger('global') 49 | logger = setlogger(logger,args) 50 | # build dataset 51 | # val_loader is a list of dataloader for a list of test subject 52 | train_loader, val_loader = create_dataset(args) 53 | logger.info('build dataset done') 54 | # build model 55 | model = create_model(args) 56 | #model.TNet = configure_model(model.TNet) 57 | #model.ANet = configure_model(model.ANet) 58 | memory = Memory(buffer_size) 59 | model.memory = memory 60 | model.buffer_size = buffer_size 61 | model.diff = [] 62 | 63 | logger.info('build model done') 64 | # logger.info(model) 65 | # evaluate 66 | if args.evaluate: 67 | logger.info('begin evaluation') 68 | val_loader[0].dataset.augment = False 69 | loss = validate(model, val_loader[0], args, 0, logger) 70 | return loss 71 | if args.test: 72 | # put test image in vimg_path 73 | logger.info('begin testing') 74 | model.retri_size = 8 75 | # train adaptor 76 | metric_adps = [] 77 | for sub in range(len(val_loader)): 78 | logger.info('testing subject:{}/{}'.format(sub+1,len(val_loader))) 79 | val_loader[sub].dataset.augment = args.val_augment 80 | prev_loss = np.inf 81 | sub_metric_adp, sub_metric_nadp = [], [] 82 | start_time = time.time() 83 | model.set_opt() 84 | for epoch in range(args.tepochs): 85 | m_loss = 0 86 | for iters, data in enumerate(val_loader[sub]): 87 | model.set_input(data) 88 | loss = model.opt_ANet(epoch) 89 | logger.info('[{}/{}][{}/{}] Adaptor Loss: {}'.format(\ 90 | epoch+1, args.tepochs, iters, len(val_loader[sub]), loss)) 91 | m_loss += np.sum(loss)/len(val_loader[sub]) 92 | 93 | start_time = time.time() 94 | # turn off augmentation on test inference 95 | val_loader[sub].dataset.augment = False 96 | # allow 3D metric calculation 97 | labels, preds = [], [] 98 | for iters, data in enumerate(val_loader[sub]): 99 | model.set_input(data) 100 | _metric, pred = model.test(return_pred=True) 101 | labels.append(model.label) 102 | preds.append(pred) 103 | 104 | 105 | labels = torch.stack(labels) 106 | preds = torch.stack(preds) 107 | 108 | _metric_adps = model.cal_metric3d(preds.view(-1,preds.shape[-3],preds.shape[-2], preds.shape[-1]),\ 109 | labels.view(-1,labels.shape[-2], labels.shape[-1])) 110 | 111 | metric_adps.append(_metric_adps) 112 | 113 | 114 | metric_adps = np.vstack(metric_adps) 115 | logger.info('Overall 3D mean metric adp:\n{}[{}]'.\ 116 | format(str(np.nanmean(metric_adps,axis=0)).replace('\n',''),\ 117 | np.nanmean(metric_adps))) 118 | 119 | return 120 | 121 | if __name__ == '__main__': 122 | main() 123 | 124 | -------------------------------------------------------------------------------- /experiments/camelyon17/evaluate/mbtt.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.jit 6 | import torch.nn.functional as F 7 | 8 | from MbPA import ReplayMemory 9 | 10 | import numpy as np 11 | 12 | torch.set_printoptions(precision=5) 13 | 14 | buffer_size = 800 15 | 16 | class Tent(nn.Module): 17 | 18 | def 
__init__(self, model, optimizer, cfg, steps=1, episodic=False): 19 | super().__init__() 20 | self.model = model 21 | self.optimizer = optimizer 22 | self.steps = steps 23 | self.episodic = episodic 24 | self.memory = ReplayMemory(buffer_size) 25 | self.mse = nn.MSELoss() 26 | self.cfg = cfg 27 | self.model_state, self.optimizer_state = \ 28 | copy_model_and_optimizer(self.model, self.optimizer) 29 | self.gt = None 30 | 31 | 32 | def forward(self, x): 33 | if self.episodic: 34 | self.reset() 35 | print('Image-specific') 36 | 37 | #for _ in range(self.steps): 38 | outputs = forward_and_adapt(x, self.model, self.optimizer, self.memory, self.mse, self.gt, self.cfg) 39 | 40 | return outputs 41 | 42 | def reset(self): 43 | if self.model_state is None or self.optimizer_state is None: 44 | raise Exception("cannot reset without saved model/optimizer state") 45 | self.memory = ReplayMemory(buffer_size) 46 | load_model_and_optimizer(self.model, self.optimizer, 47 | self.model_state, self.optimizer_state) 48 | 49 | 50 | @torch.jit.script 51 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 52 | """Entropy of softmax distribution from logits.""" 53 | return -((x).softmax(1) * (x).log_softmax(1)).sum(1) 54 | 55 | 56 | @torch.enable_grad() 57 | def forward_and_adapt(x, model, optimizer, memory, mse, gt, cfg): 58 | 59 | outputs = model(x) 60 | features = model.module.features(x) 61 | memory_size = memory.get_size() 62 | if memory_size > buffer_size: 63 | # memory is full: estimate the prediction discrepancy against retrieved neighbours 64 | with torch.no_grad(): 65 | retrieved_batches = memory.get_neighbours(features.cpu().numpy(), k=4) #4 66 | pseudo_past_logits = retrieved_batches.cuda() 67 | pseudo_current_logits = outputs 68 | pseudo_past_labels = nn.functional.softmax(pseudo_past_logits, dim=1) 69 | pseudo_current_labels = nn.functional.softmax(pseudo_current_logits/2, dim=1) # temperature-smoothed 70 | diff_loss = (F.kl_div(pseudo_past_labels.log(), pseudo_current_labels, None, None, 'none') + F.kl_div(pseudo_current_labels.log(), pseudo_past_labels, None, None, 'none')) / 2 71 | diff_loss = torch.sum(diff_loss, dim=1) 72 | diff_loss = diff_loss.cpu().numpy().tolist() 73 | diff_loss = sum(diff_loss) / len(diff_loss) 74 | # dynamic learning rate: scale the base LR by the discrepancy 75 | for param_group in optimizer.param_groups: 76 | param_group['lr'] = diff_loss * cfg.OPTIM.LR 77 | 78 | # entropy minimisation step (as in Tent), followed by a fresh forward pass 79 | loss = softmax_entropy(outputs) 80 | loss = loss.mean(0) 81 | loss.backward() 82 | optimizer.step() 83 | optimizer.zero_grad() 84 | outputs = model(x) 85 | # push the new features/predictions into the memory bank 86 | with torch.no_grad(): 87 | pseudo_logits = model(x) 88 | keys = model.module.features(x) 89 | memory.push(keys.cpu().numpy(), pseudo_logits.cpu().numpy()) 90 | 91 | 92 | return outputs 93 | 94 | 95 | 96 | 97 | def collect_params(model): 98 | """Collect the affine scale/shift parameters of all BatchNorm2d layers.""" 99 | params = [] 100 | names = [] 101 | for nm, m in model.named_modules(): 102 | if isinstance(m, nn.BatchNorm2d): 103 | for p_name, p in m.named_parameters(): # renamed from `np` to avoid shadowing numpy 104 | if p_name in ['weight', 'bias']: 105 | params.append(p) 106 | names.append(f"{nm}.{p_name}") 107 | return params, names 108 | 109 | 110 | def copy_model_and_optimizer(model, optimizer): 111 | model_state = deepcopy(model.state_dict()) 112 | optimizer_state = deepcopy(optimizer.state_dict()) 113 | return model_state, optimizer_state 114 | 115 | 116 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 117 | model.load_state_dict(model_state, strict=True) 118 | optimizer.load_state_dict(optimizer_state) 119 | 120 | 
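# Minimal sketch (added for illustration) of the dynamic-learning-rate rule used
# in forward_and_adapt above: the symmetric KL divergence between retrieved past
# predictions and the current predictions scales the base learning rate. The
# argument names here are illustrative, not the repo's API.
def _dynamic_lr_sketch(past_logits, current_logits, base_lr, optimizer):
    p = F.softmax(past_logits, dim=1)
    q = F.softmax(current_logits, dim=1)
    sym_kl = 0.5 * (F.kl_div(p.log(), q, reduction='none')
                    + F.kl_div(q.log(), p, reduction='none')).sum(dim=1).mean()
    for group in optimizer.param_groups:
        group['lr'] = float(sym_kl) * base_lr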
121 | def configure_model(model): 122 | # train mode, but only the BatchNorm affine params stay trainable; running stats are disabled so test-batch statistics are used 123 | model.train() 124 | model.requires_grad_(False) 125 | for m in model.modules(): 126 | if isinstance(m, nn.BatchNorm2d): 127 | m.requires_grad_(True) 128 | m.track_running_stats = False 129 | m.running_mean = None 130 | m.running_var = None 131 | return model -------------------------------------------------------------------------------- /experiments/camelyon17/evaluate/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Configuration file (powered by YACS).""" 7 | 8 | import argparse 9 | import os 10 | import sys 11 | import logging 12 | import random 13 | import torch 14 | import numpy as np 15 | from datetime import datetime 16 | from iopath.common.file_io import g_pathmgr 17 | from yacs.config import CfgNode as CfgNode 18 | 19 | 20 | # Global config object (example usage: from core.config import cfg) 21 | _C = CfgNode() 22 | cfg = _C 23 | 24 | 25 | # ----------------------------- Model options ------------------------------- # 26 | _C.MODEL = CfgNode() 27 | 28 | _C.MODEL.EPISODIC = False 29 | 30 | # ------------------------------- Batch norm options ------------------------ # 31 | _C.BN = CfgNode() 32 | 33 | # BN epsilon 34 | _C.BN.EPS = 1e-3 35 | 36 | # BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2) 37 | _C.BN.MOM = 0.1 38 | 39 | # ------------------------------- Optimizer options ------------------------- # 40 | _C.OPTIM = CfgNode() 41 | 42 | # Number of updates per batch 43 | _C.OPTIM.STEPS = 1 44 | 45 | # Learning rate 46 | _C.OPTIM.LR = 1e-3 47 | 48 | # Choices: Adam, SGD 49 | _C.OPTIM.METHOD = 'Adam' 50 | 51 | # Beta 52 | _C.OPTIM.BETA = 0.9 53 | 54 | # Momentum 55 | _C.OPTIM.MOMENTUM = 0.9 56 | 57 | # Momentum dampening 58 | _C.OPTIM.DAMPENING = 0.0 59 | 60 | # Nesterov momentum 61 | _C.OPTIM.NESTEROV = True 62 | 63 | # L2 regularization 64 | _C.OPTIM.WD = 0 65 | # ------------------------------- Testing options --------------------------- # 66 | _C.TEST = CfgNode() 67 | 68 | # Batch size for evaluation (and updates for norm + tent) 69 | _C.TEST.BATCH_SIZE = 200 70 | 71 | # --------------------------------- CUDNN options --------------------------- # 72 | _C.CUDNN = CfgNode() 73 | 74 | # Benchmark to select fastest CUDNN algorithms (best for fixed input sizes) 75 | _C.CUDNN.BENCHMARK = False 76 | 77 | # ---------------------------------- Misc options --------------------------- # 78 | 79 | # Optional description of a config 80 | _C.DESC = "" 81 | 82 | # Note that non-determinism is still present due to non-deterministic GPU ops 83 | _C.RNG_SEED = 1 84 | 85 | # Output directory 86 | _C.SAVE_DIR = "./output" 87 | 88 | # Data directory 89 | _C.DATA_DIR = "./data" 90 | 91 | # Weight directory 92 | _C.CKPT_DIR = "./ckpt" 93 | 94 | # Log destination (in SAVE_DIR) 95 | _C.LOG_DEST = "log.txt" 96 | 97 | # Log datetime 98 | _C.LOG_TIME = '' 99 | 100 | # # Config destination (in SAVE_DIR) 101 | # _C.CFG_DEST = "cfg.yaml" 102 | 103 | # --------------------------------- Default config -------------------------- # 104 | _CFG_DEFAULT = _C.clone() 105 | _CFG_DEFAULT.freeze() 106 | 107 | 108 | def assert_and_infer_cfg(): 109 | """Checks config values invariants.""" 110 | err_str = "Unknown adaptation method."
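# (added note) MODEL.ADAPTATION has no default above; it is expected to be
# supplied by the merged experiment config before this check runs.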
111 | assert _C.MODEL.ADAPTATION in ["source", "norm", "tent"], err_str 112 | err_str = "Log destination '{}' not supported" 113 | assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST) 114 | 115 | 116 | def merge_from_file(cfg_file): 117 | with g_pathmgr.open(cfg_file, "r") as f: 118 | cfg = _C.load_cfg(f) 119 | _C.merge_from_other_cfg(cfg) 120 | 121 | 122 | def dump_cfg(): 123 | """Dumps the config to the output directory (requires _C.CFG_DEST, which is commented out in the defaults above).""" 124 | cfg_file = os.path.join(_C.SAVE_DIR, _C.CFG_DEST) 125 | with g_pathmgr.open(cfg_file, "w") as f: 126 | _C.dump(stream=f) 127 | 128 | 129 | def load_cfg(out_dir, cfg_dest="config.yaml"): 130 | """Loads config from specified output directory.""" 131 | cfg_file = os.path.join(out_dir, cfg_dest) 132 | merge_from_file(cfg_file) 133 | 134 | 135 | def reset_cfg(): 136 | """Reset config to initial state.""" 137 | cfg.merge_from_other_cfg(_CFG_DEFAULT) 138 | 139 | 140 | def load_cfg_fom_args(description="Config options."): 141 | """Load config from command line args and set any specified options.""" 142 | current_time = datetime.now().strftime("%y%m%d_%H%M%S") 143 | parser = argparse.ArgumentParser(description=description) 144 | 145 | args = parser.parse_args() 146 | 147 | 148 | 149 | 150 | g_pathmgr.mkdirs(cfg.SAVE_DIR) 151 | 152 | cfg.freeze() 153 | 154 | logging.basicConfig( 155 | level=logging.INFO, 156 | format="[%(asctime)s] [%(filename)s: %(lineno)4d]: %(message)s", 157 | datefmt="%y/%m/%d %H:%M:%S", 158 | handlers=[ 159 | logging.FileHandler(os.path.join(cfg.SAVE_DIR, cfg.LOG_DEST)), 160 | logging.StreamHandler() 161 | ]) 162 | 163 | np.random.seed(cfg.RNG_SEED) 164 | torch.manual_seed(cfg.RNG_SEED) 165 | random.seed(cfg.RNG_SEED) 166 | torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK 167 | 168 | logger = logging.getLogger(__name__) 169 | version = [torch.__version__, torch.version.cuda, 170 | torch.backends.cudnn.version()] 171 | -------------------------------------------------------------------------------- /experiments/camelyon17/evaluate/camelyon17_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pandas as pd 4 | from PIL import Image 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | import torch.nn.functional as F 8 | 9 | class CamelyonDataset(Dataset): 10 | def __init__(self, root_dir, transform, split): 11 | """ 12 | Args: 13 | root_dir: path to the image directory (note: the data path is currently hard-coded in __init__ below). 14 | transform: optional transform to be applied on a sample. 15 | split: which subset to load, one of 'train', 'id_val', 'test', 'val'. 16 | 
17 | """ 18 | super(CamelyonDataset, self).__init__() 19 | self.data_dir = '/research/pheng4/qdliu/hzyang/test_time_medical/dataset/camelyon17_v1.0/' 20 | self.original_resolution = (96,96) 21 | 22 | # Read in metadata 23 | self.metadata_df = pd.read_csv( 24 | os.path.join(self.data_dir, 'metadata.csv'), 25 | index_col=0, 26 | dtype={'patient': 'str'}) 27 | 28 | # Get the y values 29 | self.y_array = torch.LongTensor(self.metadata_df['tumor'].values) 30 | self.y_size = 1 31 | self.n_classes = 2 32 | 33 | # Get filenames 34 | self.input_array = [ 35 | f'patches/patient_{patient}_node_{node}/patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png' 36 | for patient, node, x, y in 37 | self.metadata_df.loc[:, ['patient', 'node', 'x_coord', 'y_coord']].itertuples(index=False, name=None)] 38 | 39 | # Extract splits 40 | # Note that the hospital numbering here is different from what's in the paper, 41 | # where to avoid confusing readers we used a 1-indexed scheme and just labeled the test hospital as 5. 42 | # Here, the numbers are 0-indexed. 43 | test_center = 2 44 | val_center = 1 45 | 46 | self.split_dict = { 47 | 'train': 0, 48 | 'id_val': 1, 49 | 'test': 2, 50 | 'val': 3 51 | } 52 | self.split_names = { 53 | 'train': 'Train', 54 | 'id_val': 'Validation (ID)', 55 | 'test': 'Test', 56 | 'val': 'Validation (OOD)', 57 | } 58 | centers = self.metadata_df['center'].values.astype('long') 59 | num_centers = int(np.max(centers)) + 1 60 | val_center_mask = (self.metadata_df['center'] == val_center) 61 | test_center_mask = (self.metadata_df['center'] == test_center) 62 | self.metadata_df.loc[val_center_mask, 'split'] = self.split_dict['val'] 63 | self.metadata_df.loc[test_center_mask, 'split'] = self.split_dict['test'] 64 | ''' 65 | self._split_scheme = split_scheme 66 | if self._split_scheme == 'official': 67 | pass 68 | elif self._split_scheme == 'in-dist': 69 | # For the in-distribution oracle, 70 | # we move slide 23 (corresponding to patient 042, node 3 in the original dataset) 71 | # from the test set to the training set 72 | slide_mask = (self._metadata_df['slide'] == 23) 73 | self._metadata_df.loc[slide_mask, 'split'] = self.split_dict['train'] 74 | else: 75 | raise ValueError(f'Split scheme {self._split_scheme} not recognized') 76 | ''' 77 | self.split_array = self.metadata_df['split'].values 78 | split_mask = self.split_array == self.split_dict[split] 79 | split_idx = np.where(split_mask)[0] 80 | self.y_array = self.y_array[split_idx] 81 | print(split_idx) 82 | tmp = [] 83 | for idx in split_idx: 84 | tmp.append(self.input_array[idx]) 85 | 86 | self.input_array = tmp #self.input_array[split_idx] 87 | ''' 88 | self.metadata_array = torch.stack( 89 | (torch.LongTensor(centers), 90 | torch.LongTensor(self.metadata_df['slide'].values), 91 | self.y_array), 92 | dim=1) 93 | self.metadata_fields = ['hospital', 'slide', 'y'] 94 | self._eval_grouper = CombinatorialGrouper( 95 | dataset=self, 96 | groupby_fields=['slide']) 97 | ''' 98 | self.transform = transform 99 | 100 | print('Total # images:{}, labels:{}'.format(len(self.input_array),len(self.y_array))) 101 | 102 | def __getitem__(self, index): 103 | """ 104 | Args: 105 | index: the index of item 106 | Returns: 107 | image and its labels 108 | """ 109 | img_filename = os.path.join( 110 | self.data_dir, 111 | self.input_array[index]) 112 | x = Image.open(img_filename).convert('RGB') 113 | y = self.y_array[index] 114 | #y = F.one_hot(y, num_classes=2) 115 | if self.transform is not None: 116 | x = self.transform(x) 117 | return x, y 118 | 119 | def 
__len__(self): 120 | return len(self.y_array) 121 | 122 | 123 | -------------------------------------------------------------------------------- /models/adapmodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import copy 7 | import pdb 8 | from models.backends import UNet, ConvBlock 9 | from utils.util import ncc, l2_reg_ortho 10 | from matplotlib import pyplot as plt 11 | import matplotlib.colors as mcolors 12 | import logging 13 | import tifffile as tiff 14 | from functools import partial 15 | logger = logging.getLogger('global') 16 | 17 | def init_weights(x): 18 | torch.manual_seed(0) 19 | if type(x) == nn.Conv2d: 20 | nn.init.kaiming_normal_(x.weight.data) 21 | nn.init.zeros_(x.bias.data) 22 | 23 | def init_weights_eye(x, channel=64): 24 | # identity init; only works when input and output channels match 25 | if type(x) == nn.Conv2d: 26 | eye = nn.init.eye_(torch.empty(x.weight.shape[0], x.weight.shape[1])).unsqueeze(-1).unsqueeze(-1) 27 | init_bias = nn.init.zeros_(torch.empty(x.weight.shape[0])) 28 | x.weight.data = eye 29 | x.bias.data = init_bias 30 | 31 | class DTTAnorm(nn.Module): 32 | def __init__(self): 33 | super(DTTAnorm,self).__init__() 34 | self.conv1 = nn.Conv2d(1,16,3,padding=1) 35 | self.conv2 = nn.Conv2d(16,16,3,padding=1) 36 | self.conv3 = nn.Conv2d(16,1,3,padding=1) 37 | def forward(self,x): 38 | x_ = self.conv1(x) 39 | scale = (torch.randn([1,16,1,1]) * 0.05 + 0.2).to(x_.device) 40 | x_ = torch.exp(-(x_**2) / (scale**2)) 41 | x_ = self.conv2(x_) 42 | x_ = torch.exp(-(x_**2) / (scale**2)) 43 | x_ = self.conv3(x_) 44 | return x_ + x 45 | 46 | class ANet(nn.Module): 47 | def __init__(self, adpNet=None, tnet_dim=[1,64,64,64,64,1], seq=None): 48 | """Adaptor Net: default is for a 4-level UNet with fixed 64 channels 49 | Args: 50 | adpNet: nn.Module, optional pre-network for input contrast manipulation (a default 1x1-conv stack is built if None) 51 | tnet_dim: list->int, channel dims of the task network, e.g. [1,64,64,64,64,1] 52 | seq: list->int, indices of the 1x1 convs to be used (default: all) 53 | """ 54 | super(ANet,self).__init__() 55 | self.conv = nn.ModuleList() 56 | feature_channel = tnet_dim[1:-1] 57 | self.channel = feature_channel + feature_channel[::-1] 58 | nums = len(self.channel) 59 | self.nums = nums 60 | if seq is None: 61 | self.seq = np.arange(nums) 62 | else: 63 | self.seq = seq 64 | # use pre-contrast manipulation 65 | self.adpNet = adpNet 66 | if adpNet is None: 67 | self.adpNet = nn.Sequential( 68 | nn.Conv2d(1,64,1), 69 | nn.LeakyReLU(negative_slope=0.2), 70 | nn.InstanceNorm2d(64), 71 | nn.Conv2d(64,64,1), 72 | nn.LeakyReLU(negative_slope=0.2), 73 | nn.InstanceNorm2d(64), 74 | nn.Conv2d(64,1,1), 75 | nn.LeakyReLU(negative_slope=0.2), 76 | nn.InstanceNorm2d(1) 77 | ) 78 | self.adpNet.apply(init_weights) 79 | # use feature affine transform 80 | for c in self.channel: 81 | convs = nn.Conv2d(c,c,1) 82 | self.conv.append(convs) 83 | self.conv.apply(init_weights_eye) 84 | def reset(self): 85 | # reset the fine-tuned weights for a new test subject 86 | np.random.seed(0) 87 | torch.manual_seed(0) 88 | self.conv.apply(init_weights_eye) 89 | self.adpNet.apply(init_weights) 90 | self.cuda() 91 | def forward(self, x, TNet, AENet, side_out=False): 92 | """ 93 | Forward for a 4-level UNet 94 | Args: 95 | TNet: nn.Module. The pretrained task network 96 | side_out: bool. If True, return all intermediate outputs 97 | seq: list->int or np array. Position of the 1x1 convolutions 98 | """ 99 | x = self.adpNet(x) 100 | xh = [x] 101 | x = TNet.inblocks(x) 102 | ct = 0 103 | # apply 1x1 conv on input blocks 104 | seq = self.seq 105 | if ct in seq: 106 | x = self.conv[ct](x) 107 | ct += 1 108 | xh.append(x) 109 | for i in range(TNet.depth): 110 | x = TNet.downblocks[i](x) 111 | # apply 1x1 conv on every downsample output except bottleneck 112 | if ct in seq: 113 | x = self.conv[ct](x) 114 | ct += 1 115 | xh.append(x) 116 | x = TNet.bottleneck(x) 117 | if ct in seq: 118 | x = self.conv[ct](x) 119 | ct += 1 120 | xh.append(x) 121 | for i in range(TNet.depth): 122 | x = TNet.upblocks[TNet.depth-i-1](x,xh[TNet.depth-i]) 123 | if ct in seq: 124 | x = self.conv[ct](x) 125 | ct += 1 126 | xh.append(x) 127 | x = TNet.outblock(x) 128 | xh.append(x) 129 | if side_out: 130 | return xh 131 | else: 132 | return x
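# Minimal usage sketch (added for illustration; `anet` and `tnet` are assumed,
# already-constructed instances): ANet adapts the input contrast and TNet's
# intermediate features while TNet's own weights stay frozen.
def _anet_usage_sketch(anet, tnet, x):
    logits = anet(x, TNet=tnet, AENet=None)                        # adapted prediction
    intermediates = anet(x, TNet=tnet, AENet=None, side_out=True)  # all feature maps
    return logits, intermediates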
-------------------------------------------------------------------------------- /models/segmodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import pdb 7 | from models.backends import UNet, ConvBlock 8 | from utils.util import ncc, grams 9 | from matplotlib import pyplot as plt 10 | import matplotlib.colors as mcolors 11 | import logging 12 | logger = logging.getLogger('global') 13 | from models.basemodel import tonp, AdaptorNet 14 | from functools import partial 15 | 16 | 17 | 18 | def configure_model(model): 19 | model.train() 20 | for m in model.modules(): 21 | if isinstance(m, nn.BatchNorm2d): 22 | m.requires_grad_(True) 23 | m.track_running_stats = False 24 | return model 25 | 26 | 27 | 28 | class SegANet(AdaptorNet): 29 | def __init__(self, opt): 30 | super(SegANet,self).__init__(opt) 31 | def def_loss(self): 32 | if self.opt.segloss == 'ce': 33 | self.TLoss = self.ce_loss 34 | elif self.opt.segloss == 'dice': 35 | self.TLoss = partial(self.dice_loss, softmax=True) 36 | self.AELoss = nn.MSELoss() 37 | if self.opt.segaeloss == 'mse': 38 | self.AELoss_out = nn.MSELoss() 39 | elif self.opt.segaeloss == 'dice': 40 | self.AELoss_out = self.dice_loss 41 | def ce_loss(self, pred, label): 42 | if pred.shape != label.shape: 43 | return F.cross_entropy(pred, label) 44 | else: 45 | return F.cross_entropy(pred, torch.argmax(label, 1)) 46 | def dice_loss(self, pred, label, softmax=False): 47 | """Calculate dice loss: 48 | Args: 49 | pred: [batch, channel, img_row, img_col] 50 | label: [batch, img_row, img_col] or the same size as pred 51 | """ 52 | if softmax: 53 | pred = F.softmax(pred,1) 54 | if pred.shape == label.shape: 55 | label_ = label 56 | else: 57 | label_ = pred.clone().detach().zero_() 58 | label_.scatter_(1, label.unsqueeze(1), 1) 59 | dice = (2*torch.sum(torch.sum(pred*label_,-1),-1)+0.001)/ \ 60 | (torch.sum(torch.sum(pred,-1),-1) + torch.sum(torch.sum(label_,-1),-1) + 0.001) 61 | return 1 - torch.mean(dice) 62 | 63 | def cal_metric3d(self, pred, label, ignore=[0,9,10], fg=False): 64 | """Calculate quantitative metric: Dice coefficient 65 | Args: 66 | pred: [slice, channel, img_row, img_col] 67 | label: [slice, img_row, img_col] 68 | ignore: ignore labels 69 | fg: return foreground dice 70 | Return: 71 | dice: list of dice coefficients 72 | """ 73 | if torch.is_tensor(pred): 74 | pred = pred.data.cpu().numpy() 75 | label = label.data.cpu().numpy() 76 | C = pred.shape[1] 77 | pred = np.argmax(pred, axis=1) 78 | if fg: 79 | C = 2 80 | pred[pred>0] = 1 81 | dice = np.zeros(C) 82
| for i in range(C): 83 | pl = pred == i 84 | ll = label == i 85 | sump = np.nansum(pl) 86 | sumt = np.nansum(ll) 87 | inter = np.nansum(pl*ll) 88 | if sumt == 0: 89 | dice[i] = np.nan 90 | else: 91 | dice[i] = 2*inter/(sump+sumt) 92 | dice = np.delete(dice,ignore) 93 | return dice 94 | def cal_metric(self, pred, label, ignore=[0,9,10]): 95 | """Calculate quantitative metric: Dice coefficient 96 | Args: 97 | pred: [batch, channel, img_row, img_col] 98 | label: [batch, img_row, img_col] 99 | ignore: ignore labels 100 | Return: 101 | dice: list of dice coefficients, [batch] 102 | """ 103 | if torch.is_tensor(pred): 104 | pred = pred.data.cpu().numpy() 105 | label = label.data.cpu().numpy() 106 | dice = [] 107 | batch_size, C = pred.shape[:2] 108 | pred = np.argmax(pred,1) 109 | for b_id in range(batch_size): 110 | _dice = np.zeros(C) 111 | for i in range(C): 112 | pl = pred[b_id,:,:] == i 113 | ll = label[b_id,:,:] == i 114 | sump = np.nansum(pl) 115 | sumt = np.nansum(ll) 116 | inter = np.nansum(pl*ll) 117 | if sumt == 0: 118 | _dice[i] = np.nan 119 | else: 120 | _dice[i] = 2*inter/(sump+sumt) 121 | _dice = np.delete(_dice,ignore) 122 | dice.append(_dice) 123 | return dice 124 | def test(self, return_pred=False): 125 | """Test using ANet and TNet 126 | """ 127 | 128 | 129 | self.TNet.train() 130 | self.ANet.train() 131 | with torch.no_grad(): 132 | pred = self.ANet(self.image, self.TNet, self.AENet[0], side_out=True) 133 | pred, adapt_img = pred[-1], pred[0] 134 | pred = F.softmax(pred, dim=1) 135 | 136 | 137 | metric = [[],[]] 138 | if self.opt.__dict__.get('cal_metric',True): 139 | metric[0] = self.cal_metric(pred, self.label.squeeze(1)) 140 | 141 | if return_pred: 142 | return metric, pred 143 | else: 144 | return metric 145 | -------------------------------------------------------------------------------- /models/backends.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn.functional as F 4 | import functools 5 | import torch.nn as nn 6 | import math 7 | import pdb 8 | import logging 9 | import numpy as np 10 | 11 | class _ResConvBlock(nn.Module): 12 | def __init__(self, inplane, outplane, kernel=(3,3), **kwargs): 13 | super(_ResConvBlock,self).__init__() 14 | if isinstance(kernel, int): 15 | kernel = (kernel, kernel) 16 | pad = kwargs.get('pad', None) 17 | gpn = kwargs.get('gpn', False) 18 | isn = kwargs.get('isn', False) 19 | reflect = kwargs.get('reflect', False) 20 | stride = kwargs.get('stride',1) 21 | if pad is None: 22 | pad = (kernel[0]//2, kernel[1]//2) 23 | # use reflect padding to solve the boundary problem 24 | if reflect: 25 | self.reflectpad1 = nn.ReflectionPad2d((pad[0], pad[0], pad[1], pad[1])) 26 | self.reflectpad2 = nn.ReflectionPad2d((2*pad[0], 2*pad[0], 2*pad[1], 2*pad[1])) 27 | pad = (0, 0) 28 | self.conv = nn.Conv2d(inplane, outplane, kernel, stride=stride, padding=pad) 29 | # use group norm instead of batch norm 30 | if gpn: 31 | self.norm = functools.partial(nn.GroupNorm,4) 32 | elif isn: 33 | self.norm = nn.InstanceNorm2d 34 | else: 35 | self.norm = nn.BatchNorm2d 36 | self.resconv = nn.Sequential( 37 | nn.Conv2d(outplane, outplane, kernel, padding=pad), 38 | self.norm(outplane), 39 | nn.PReLU(), 40 | #nn.Dropout2d(p=0.2), 41 | nn.Conv2d(outplane, outplane, kernel, padding=pad), 42 | self.norm(outplane), 43 | #nn.Dropout2d(p=0.2) 44 | ) 45 | self.relu = nn.PReLU() 46 | 47 | def forward(self, x): 48 | if hasattr(self,'reflectpad1'): 49 | x = self.reflectpad1(x) 50 | 
residual = self.conv(x) 51 | if hasattr(self,'reflectpad2'): 52 | temp = self.reflectpad2(residual) 53 | else: 54 | temp = residual.clone() 55 | x = self.resconv(temp) + residual 56 | x = self.relu(x) 57 | return x 58 | 59 | ConvBlock = _ResConvBlock 60 | 61 | class UpBlock(nn.Module): 62 | def __init__(self, inplane, outplane, kernel=(3,3), **kwargs): 63 | super(UpBlock,self).__init__() 64 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 65 | self.skip=kwargs.get('skip', True) 66 | self.uconv = ConvBlock(inplane, outplane, kernel, **kwargs) 67 | 68 | def forward(self, xl, xh=None): 69 | if xh is not None: 70 | x = F.interpolate(xl,size=xh.shape[-2:],mode='bilinear',align_corners=True) 71 | if self.skip: 72 | x = torch.cat([x, xh], dim=1) 73 | else: 74 | x = self.up(xl) 75 | x = self.uconv(x) 76 | return x 77 | 78 | class DownBlock(nn.Module): 79 | def __init__(self, inplane, outplane, kernel=(3,3), **kwargs): 80 | super(DownBlock,self).__init__() 81 | stride = kwargs.get('down_stride',None) 82 | if stride: 83 | kwargs['stride'] = stride 84 | self.dconv = ConvBlock(inplane, outplane, kernel, **kwargs) 85 | else: 86 | self.dconv = nn.Sequential( 87 | nn.MaxPool2d(2), 88 | ConvBlock(inplane, outplane, kernel, **kwargs) 89 | ) 90 | def forward(self, x): 91 | x = self.dconv(x) 92 | return x 93 | 94 | class UNet(nn.Module): 95 | def __init__(self, inplane, midplane, outplane, 96 | kernel=(3,3), **kwargs): 97 | super(UNet,self).__init__() 98 | self.skip=kwargs.get('skip', True) 99 | self.inplane = inplane 100 | self.midplane = midplane 101 | self.outplane = outplane 102 | self.depth = len(midplane) -1 103 | self.downblocks = nn.ModuleList() 104 | self.upblocks = nn.ModuleList() 105 | self.inblocks = ConvBlock(self.inplane, self.midplane[0], kernel, **kwargs) 106 | # allow user defined bottleneck 107 | self.bottleneck = kwargs.get('bottleneck', \ 108 | ConvBlock(self.midplane[-1], self.midplane[-1], 1, **kwargs)) 109 | self.outblock = nn.Conv2d(self.midplane[0], self.outplane, 1) 110 | for i in range(self.depth): 111 | self.downblocks.append(DownBlock(self.midplane[i], 112 | self.midplane[i+1], 113 | kernel, **kwargs)) 114 | self.upblocks.append(UpBlock(self.skip*self.midplane[i] + \ 115 | self.midplane[i+1],\ 116 | self.midplane[i], 117 | kernel, **kwargs)) 118 | 119 | def forward(self,x,side_out=False,bot_out=False, adaptive=False): 120 | xh = [x] 121 | x = self.inblocks(x) 122 | xh.append(x) 123 | for i in range(self.depth): 124 | x = self.downblocks[i](x) 125 | xh.append(x) 126 | x = self.bottleneck(x) 127 | features = x 128 | if adaptive : 129 | return features 130 | xh.append(x) 131 | for i in range(self.depth): 132 | x = self.upblocks[self.depth-i-1](x,xh[self.depth-i]) 133 | xh.append(x) 134 | x = self.outblock(x) 135 | xh.append(x) 136 | if side_out: 137 | return xh 138 | else: 139 | return x 140 | 
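# Shape sketch (added for illustration): a 4-level UNet matching the
# tnet_dim=[1,64,64,64,64,1] convention used by ANet in models/adapmodel.py;
# side_out=True returns every intermediate feature map.
def _unet_shapes_sketch():
    net = UNet(inplane=1, midplane=[64, 64, 64, 64], outplane=1)
    x = torch.randn(2, 1, 128, 128)
    feats = net(x, side_out=True)
    # feats = [input, in-block, 3 x down, bottleneck, 3 x up, logits] -> 10 tensors
    return [f.shape for f in feats]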
-------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | import numpy as np 5 | import random 6 | from copy import deepcopy 7 | import torch 8 | 9 | from torch.nn.functional import normalize 10 | def deterministic(seed=0): 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | torch.backends.cudnn.deterministic=True 15 | torch.backends.cudnn.benchmark=False 16 | np.random.seed(seed) 17 | random.seed(seed) 18 | 19 | def grams(x): 20 | """Return the gram matrix for the input features 21 | Args: 22 | x: feature maps, [batch, channel, rows, cols] 23 | Return: 24 | gram matrix of shape [batch*channel, batch*channel], normalised by the total number of elements 25 | """ 26 | a, b, c, d = x.size() 27 | features = x.view(a * b, c * d) 28 | G = torch.mm(features, features.t()) # compute the gram product 29 | return G.div(a * b * c * d) 30 | 31 | def load_config(config_path,args): 32 | assert(os.path.exists(config_path)) 33 | cfg = json.load(open(config_path, 'r')) 34 | logger = logging.getLogger('global') 35 | logger.info(json.dumps(cfg, indent=2)) 36 | save_path = os.path.join(args.results_dir, 37 | os.path.basename(args.config)) 38 | if not args.evaluate: 39 | with open(save_path, 'w') as fp: 40 | fp.write(json.dumps(cfg, indent=2)) 41 | for key in cfg.keys(): 42 | if key != 'shared': 43 | cfg[key].update(cfg['shared']) 44 | return cfg 45 | 46 | def setlogger(logger,args): 47 | if args.test or args.evaluate: 48 | hdlr = logging.FileHandler(os.path.join(args.results_dir,'eval.log'),'w+') 49 | else: 50 | if args.resume_T or args.resume_AE: 51 | hdlr = logging.FileHandler(os.path.join(args.results_dir,args.trainer+'_train.log')) 52 | else: 53 | hdlr = logging.FileHandler(os.path.join(args.results_dir,args.trainer+'_train.log'),'w+') 54 | formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') 55 | hdlr.setFormatter(formatter) 56 | logger.addHandler(hdlr) 57 | logger.setLevel(logging.INFO) 58 | logger.info(json.dumps(vars(args), indent=2)) 59 | return logger 60 | 61 | def ncc(img1, img2): 62 | """Normalized cross correlation loss 63 | """ 64 | return 1 - 1.*torch.sum((img1 - torch.mean(img1,dim=[-1,-2]))*(img2 - torch.mean(img2,dim=[-1,-2])))\ 65 | /torch.sqrt(torch.sum((img1 - torch.mean(img1,dim=[-1,-2]))**2) \ 66 | *torch.sum((img2 - torch.mean(img2,dim=[-1,-2]))**2) + 1e-10) 67 | 68 | def generate_mask(bd_pts=None, img_rows=None, lesion=None): 69 | """ Generate masks from boundary surfaces and lesion masks 70 | 71 | Args: 72 | bd_pts: boundary surfaces, shape = [bds, img_cols] 73 | lesion: lesion masks, shape = [img_rows, img_cols] 74 | 75 | Returns: 76 | label: training labels, shape = (img_rows, img_cols) 77 | 78 | """ 79 | bds,img_cols = bd_pts.shape 80 | if img_rows is None: 81 | assert lesion is not None 82 | img_rows = lesion.shape[0] 83 | label = np.zeros((img_rows,img_cols))*np.nan 84 | for j in range(img_cols): 85 | 86 | cols = np.arange(img_rows) 87 | index = cols - bd_pts[0,j] < 0 88 | label[index,j] = 0 89 | index = cols - bd_pts[bds-1,j] >= 0 90 | label[index,j] = bds 91 | for k in range(bds-1): 92 | index_up = cols - bd_pts[k,j] >= 0 93 | index_down = cols - bd_pts[k+1,j] < 0 94 | label[index_up&index_down,j] = k+1 95 | if lesion is not None: 96 | label[lesion>0] = bds+1 97 | 98 | return label 99 | 100 | def l2_reg_ortho(mdl): 101 | l2_reg = None 102 | for W in mdl.parameters(): 103 | if W.ndimension() < 2: 104 | continue 105 | else: 106 | cols = W[0].numel() 107 | rows = W.shape[0] 108 | w1 = W.view(-1,cols) 109 | wt = torch.transpose(w1,0,1) 110 | m = torch.matmul(wt,w1) 111 | ident = torch.eye(cols, cols, device=W.device) 112 | 113 | 114 | w_tmp = (m - ident) 115 | height = w_tmp.size(0) 116 | u = normalize(w_tmp.new_empty(height).normal_(0,1), dim=0, eps=1e-12) 117 | v = normalize(torch.matmul(w_tmp.t(), u), dim=0, eps=1e-12) 118 | u = normalize(torch.matmul(w_tmp, v), dim=0, eps=1e-12) 119 | sigma = torch.dot(u, torch.matmul(w_tmp, v)) 120 | 121 | if l2_reg is None: 122 | l2_reg = (sigma)**2 123 | else: 124 | l2_reg = l2_reg + (sigma)**2 125 | return l2_reg 126 | 127 | def split_data(dataset,split,switch=False): 128 | ''' split training data into training/validation 129 | Args: 130 | split[0] - split[1]: validation range 131 | split[1] - split[2]: training range 132 | switch: switch the returned datasets 133 | ''' 134 | traindataset = deepcopy(dataset) 135 | valdataset = deepcopy(dataset) 136 | if len(split)>2: 137 | idx = np.arange(split[0],split[-1]) 138 | else: 139 | idx = np.arange(len(dataset)) 140 | validx = np.arange(int(split[0]),int(split[1]),dtype=int) # plain int dtype so more than 255 subjects index correctly 141 | traidx = np.array(list(set(idx)-set(validx)),dtype=int) 142 | traindataset.datalist = [dataset.datalist[i] for i in traidx] 143 | traindataset.labellist = [dataset.labellist[i] for i in traidx] 144 | valdataset.datalist = [dataset.datalist[i] for i in validx] 145 | valdataset.labellist = [dataset.labellist[i] for i in validx] 146 | if switch: # switch train/val to return the correct test set 147 | return valdataset, traindataset 148 | else: 149 | return traindataset, valdataset
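# Usage sketch (added; the 1e-4 weight is illustrative, not the repo's setting):
# the spectral orthogonality penalty is simply added to the task loss.
def _ortho_reg_usage_sketch(model, task_loss, weight=1e-4):
    return task_loss + weight * l2_reg_ortho(model)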
-------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import torch 3 | import numpy as np 4 | import random 5 | import json 6 | 7 | from torch.utils.data import Dataset, DataLoader 8 | from pathlib import Path 9 | import time 10 | import copy 11 | import warnings 12 | import logging 13 | from scipy.io import loadmat 14 | from scipy import misc 15 | from PIL import Image 16 | import imageio 17 | import torchvision.transforms as transforms 18 | from copy import deepcopy 19 | import pdb 20 | logger = logging.getLogger('global') 21 | warnings.filterwarnings("ignore") 22 | from utils.util import generate_mask 23 | import tifffile as tiff 24 | import nibabel 25 | from datasets.transform import Composer, registered_workers 26 | from matplotlib import pyplot as plt 27 | class PairedDataset(Dataset): 28 | def __init__(self, opt, train=True, augment=True): 29 | super(PairedDataset, self).__init__() 30 | self.train = train 31 | self.augment = augment 32 | self.opt = opt 33 | # find images from folder/txt/name list 34 | try: 35 | if self.train: 36 | self.img_path = opt.img_path 37 | self.label_path = opt.label_path 38 | else: 39 | self.img_path = opt.vimg_path 40 | self.label_path = opt.vlabel_path 41 | self.img_ext = opt.img_ext 42 | self.label_ext = opt.label_ext 43 | if os.path.isfile(self.img_path): 44 | with open(self.img_path) as f: 45 | self.datalist = f.read().splitlines() 46 | elif type(self.img_path) == list: 47 | self.datalist = self.img_path 48 | elif type(self.img_path) == str: 49 | self.datalist = sorted(list(Path(self.img_path).glob('*.'+self.img_ext))) 50 | # find labels from folder/txt/name list 51 | if os.path.isfile(self.label_path): 52 | with open(self.label_path) as f: 53 | self.labellist = f.read().splitlines() 54 | elif type(self.label_path) == list: 55 | self.labellist = self.label_path 56 | elif type(self.label_path) == str: 57 | self.labellist = sorted(list(Path(self.label_path).glob('*.'+self.label_ext))) 58 | # assert label and image are aligned 59 | assert len(self.datalist) == len(self.labellist) 60 | except Exception: # fall back to an empty dataset when no paths are given 61 | self.datalist, self.labellist = [], [] 62 | # define transformation using opt 63 | self.transform = self._get_transform() 64 | def name(self): 65 | return 'PairedDataset' 66 | def __getitem__(self, index): 67 | # load image 68 | self.image = self._get_image(index) 69 | # load label 70 | self.label = self._get_label(index) 71 | # add additional dimension for augmentation 72 | self.image = self.image[np.newaxis] 73 | self.label = self.label[np.newaxis] 74 | self.image, self.label = self.transform(self.image, self.label) 75 | # remove the additional dimension of label 76 | # image: [1, H, W], label: [H, W] 77 | sample = {'data':self.image, 'label':self.label[0], 'filename':str(self.datalist[index]), 78 | 'transform':self.transform.get_params()} 79 | return sample 80 | def __len__(self): 81 | return len(self.datalist) 82 | def _get_image(self,index): 83 | pass 84 | def _get_label(self,index): 85 | pass 86 | def _get_transform(self): 87 | pass 88 | 89 | 90 | 
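# Sketch of the subclass contract (added; names and shapes are hypothetical):
# a concrete dataset only has to supply image/label loading and an
# augmentation pipeline; PairedDataset.__getitem__ handles the rest.
class _MinimalPairedDataset(PairedDataset):
    def _get_image(self, index):
        return np.zeros((128, 128))                  # [H, W] float image in [0, 1]
    def _get_label(self, index):
        return np.zeros((128, 128), dtype=np.int64)  # [H, W] label map
    def _get_transform(self):
        return Composer([registered_workers['normalize']({'n_label': False})])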
91 | class OCTSegDataset(PairedDataset): 92 | # specifically for OCT segmentation dataset 93 | # https://github.com/YufanHe/oct_preprocess 94 | def __init__(self, opt, train=True, augment=True): 95 | super(OCTSegDataset, self).__init__(opt, train, augment) 96 | def _get_image(self,index): 97 | try: 98 | image = np.array(imageio.imread(str(self.datalist[index]),pilmode = 'L'))/255 99 | except: 100 | raise(RuntimeError("image type not supported")) 101 | return image 102 | def _get_label(self,index): 103 | with open(str(self.labellist[index]),'r') as f: 104 | dicts = json.loads(f.read()) 105 | if 'lesion' in dicts.keys(): 106 | mask = np.array(dicts['lesion']) 107 | mask[mask>1] = 1 108 | else: 109 | mask = np.zeros(self.image.shape) 110 | # add dtype=float to make sure None is converted to NaN 111 | bds = np.array(dicts['bds'], dtype=float) - 1 112 | label = generate_mask(bd_pts=bds,lesion=mask) 113 | return label 114 | def _get_transform(self): 115 | # dataset specific augmentation 116 | transform_list = [] 117 | if self.augment: 118 | transform_list +=[registered_workers['gamma'](\ 119 | {'p':self.opt.aug_prob, 120 | 'gamma':self.opt.gamma})] 121 | transform_list += [registered_workers['affine'](\ 122 | {'p':self.opt.aug_prob, 123 | 'angle':self.opt.affine_angle, 124 | 'translate':self.opt.affine_translate, 125 | 'scale':self.opt.affine_scale})] 126 | transform_list += [registered_workers['hflip']({'p':self.opt.aug_prob})] 127 | transform_list +=[registered_workers['noise'](\ 128 | {'p':self.opt.aug_prob, 129 | 'std':self.opt.noise_std})] 130 | transform_list += [registered_workers['normalize']({'n_label':False})] 131 | return Composer(transform_list) 132 | -------------------------------------------------------------------------------- /experiments/prostate/nets/unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .buildingblocks import DoubleConv, ExtResNetBlock, create_encoders, \ 4 | create_decoders 5 | from .utils import number_of_features_per_level, get_class 6 | 7 | class Abstract3DUNet(nn.Module): 8 | """ 9 | Base class for standard and residual UNet. 10 | Args: 11 | in_channels (int): number of input channels 12 | out_channels (int): number of output segmentation masks; 13 | note that out_channels might correspond either to different 14 | semantic classes or to different binary segmentation masks. 15 | It's up to the user of the class to interpret the out_channels and 16 | use the proper loss criterion during training (i.e.
CrossEntropyLoss (multi-class) 17 | or BCEWithLogitsLoss (two-class) respectively) 18 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 19 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 20 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the 21 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used 22 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. 23 | basic_module: basic model for the encoder/decoder (DoubleConv, ExtResNetBlock, ....) 24 | layer_order (string): determines the order of layers 25 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 26 | See `SingleConv` for more info 27 | num_groups (int): number of groups for the GroupNorm 28 | num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int) 29 | is_segmentation (bool): if True (semantic segmentation problem) Sigmoid/Softmax normalization is applied 30 | after the final convolution; if False (regression problem) the normalization layer is skipped at the end 31 | conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module 32 | pool_kernel_size (int or tuple): the size of the window 33 | conv_padding (int or tuple): add zero-padding added to all three sides of the input 34 | """ 35 | 36 | def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr', 37 | num_groups=8, num_levels=4, is_segmentation=True, conv_kernel_size=3, pool_kernel_size=2, 38 | conv_padding=1, **kwargs): 39 | super(Abstract3DUNet, self).__init__() 40 | 41 | if isinstance(f_maps, int): 42 | f_maps = number_of_features_per_level(f_maps, num_levels=num_levels) 43 | 44 | assert isinstance(f_maps, list) or isinstance(f_maps, tuple) 45 | assert len(f_maps) > 1, "Required at least 2 levels in the U-Net" 46 | 47 | # create encoder path 48 | self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, 49 | num_groups, pool_kernel_size) 50 | 51 | # create decoder path 52 | self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, num_groups, 53 | upsample=True) 54 | 55 | # in the last layer a 1×1 convolution reduces the number of output 56 | # channels to the number of labels 57 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 58 | 59 | if is_segmentation: 60 | # semantic segmentation problem 61 | if final_sigmoid: 62 | self.final_activation = nn.Sigmoid() 63 | else: 64 | self.final_activation = nn.Softmax(dim=1) 65 | else: 66 | # regression problem 67 | self.final_activation = None 68 | 69 | def forward(self, x): 70 | # encoder part 71 | encoders_features = [] 72 | for encoder in self.encoders: 73 | x = encoder(x) 74 | # reverse the encoder outputs to be aligned with the decoder 75 | encoders_features.insert(0, x) 76 | 77 | # remove the last encoder's output from the list 78 | # !!remember: it's the 1st in the list 79 | encoders_features = encoders_features[1:] 80 | 81 | # decoder part 82 | for decoder, encoder_features in zip(self.decoders, encoders_features): 83 | # pass the output from the corresponding encoder and the output 84 | # of the previous decoder 85 | x = decoder(encoder_features, x) 86 | 87 | x = self.final_conv(x) 88 | 89 | # apply final_activation (i.e. Sigmoid or Softmax) only during prediction. 
During training the network outputs logits 90 | #if not self.training and self.final_activation is not None: 91 | # x = self.final_activation(x) 92 | 93 | return x 94 | 95 | 96 | class UNet3D(Abstract3DUNet): 97 | """ 98 | 3DUnet model from 99 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" 100 | <https://arxiv.org/abs/1606.06650>`. 101 | Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder 102 | """ 103 | 104 | def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr', 105 | num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1, **kwargs): 106 | super(UNet3D, self).__init__(in_channels=in_channels, 107 | out_channels=out_channels, 108 | final_sigmoid=final_sigmoid, 109 | basic_module=DoubleConv, 110 | f_maps=f_maps, 111 | layer_order=layer_order, 112 | num_groups=num_groups, 113 | num_levels=num_levels, 114 | is_segmentation=is_segmentation, 115 | conv_padding=conv_padding, 116 | **kwargs) 117 | -------------------------------------------------------------------------------- /experiments/prostate/utils/loss.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Loss for brain segmentation (not used) 4 | """ 5 | import sys, os 6 | base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 7 | sys.path.append(base_path) 8 | 9 | import torchvision.transforms as transforms 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from medpy import metric 14 | import numpy as np 15 | 16 | 17 | 18 | 19 | 20 | 21 | def entropy_loss(p, c=3): 22 | # p: N*C*W*H*D logits; mean entropy taken over the class dimension, normalized by log(c) 23 | p = F.softmax(p, dim=1) 24 | y1 = -1 * torch.sum(p * torch.log(p + 1e-6), dim=1) / np.log(c) 25 | ent = torch.mean(y1) 26 | return ent 27 | 28 | def flatten(tensor): 29 | """Flattens a given tensor such that the channel axis is first. 30 | The shapes are transformed as follows: 31 | (N, C, D, H, W) -> (C, N * D * H * W) 32 | """ 33 | # number of channels 34 | C = tensor.size(1) 35 | # new axis order 36 | axis_order = (1, 0) + tuple(range(2, tensor.dim())) 37 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) 38 | transposed = tensor.permute(axis_order) 39 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) 40 | return transposed.contiguous().view(C, -1) 41 | 42 | def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None): 43 | """ 44 | Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target. 45 | Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function.
46 | Args: 47 | input (torch.Tensor): NxCxSpatial input tensor 48 | target (torch.Tensor): NxCxSpatial target tensor 49 | epsilon (float): prevents division by zero 50 | weight (torch.Tensor): Cx1 tensor of weight per channel/class 51 | """ 52 | 53 | # input and target shapes must match 54 | assert input.size() == target.size(), "'input' and 'target' must have the same shape" 55 | 56 | input = flatten(input) 57 | target = flatten(target) 58 | target = target.float() 59 | 60 | # compute per channel Dice Coefficient 61 | intersect = (input * target).sum(-1) 62 | if weight is not None: 63 | intersect = weight * intersect 64 | 65 | # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1) 66 | denominator = (input * input).sum(-1) + (target * target).sum(-1) 67 | return 2 * (intersect / denominator.clamp(min=epsilon)) 68 | 69 | 70 | class _MaskingLossWrapper(nn.Module): 71 | """ 72 | Loss wrapper which prevents the gradient of the loss to be computed where target is equal to `ignore_index`. 73 | """ 74 | 75 | def __init__(self, loss, ignore_index): 76 | super(_MaskingLossWrapper, self).__init__() 77 | assert ignore_index is not None, 'ignore_index cannot be None' 78 | self.loss = loss 79 | self.ignore_index = ignore_index 80 | 81 | def forward(self, input, target): 82 | mask = target.clone().ne_(self.ignore_index) 83 | mask.requires_grad = False 84 | 85 | # mask out input/target so that the gradient is zero where on the mask 86 | input = input * mask 87 | target = target * mask 88 | 89 | # forward masked input and target to the loss 90 | return self.loss(input, target) 91 | 92 | 93 | class SkipLastTargetChannelWrapper(nn.Module): 94 | """ 95 | Loss wrapper which removes additional target channel 96 | """ 97 | 98 | def __init__(self, loss, squeeze_channel=False): 99 | super(SkipLastTargetChannelWrapper, self).__init__() 100 | self.loss = loss 101 | self.squeeze_channel = squeeze_channel 102 | 103 | def forward(self, input, target): 104 | assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel' 105 | 106 | # skips last target channel if needed 107 | target = target[:, :-1, ...] 108 | 109 | if self.squeeze_channel: 110 | # squeeze channel dimension if singleton 111 | target = torch.squeeze(target, dim=1) 112 | return self.loss(input, target) 113 | 114 | 115 | class _AbstractDiceLoss(nn.Module): 116 | """ 117 | Base class for different implementations of Dice loss. 118 | """ 119 | 120 | def __init__(self, weight=None, normalization='softmax'): 121 | super(_AbstractDiceLoss, self).__init__() 122 | self.register_buffer('weight', weight) 123 | # The output from the network during training is assumed to be un-normalized probabilities and we would 124 | # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data, 125 | # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems. 
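# (with Sigmoid, each channel is scored as an independent binary mask by the per-channel Dice below.)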
126 | # However, if one would like to apply Softmax in order to get the proper probability distribution from the 127 | # output, just specify `normalization='softmax'` 128 | assert normalization in ['sigmoid', 'softmax', 'none'] 129 | if normalization == 'sigmoid': 130 | self.normalization = nn.Sigmoid() 131 | elif normalization == 'softmax': 132 | self.normalization = nn.Softmax(dim=1) 133 | else: 134 | self.normalization = lambda x: x 135 | 136 | def dice(self, input, target, weight): 137 | # actual Dice score computation; to be implemented by the subclass 138 | raise NotImplementedError 139 | 140 | def forward(self, input, target): 141 | # get probabilities from logits 142 | input = self.normalization(input) 143 | 144 | # compute per channel Dice coefficient 145 | per_channel_dice = self.dice(input, target, weight=self.weight) 146 | 147 | # average Dice score across all channels/classes 148 | return 1. - torch.mean(per_channel_dice) 149 | 150 | 151 | class DiceLoss(_AbstractDiceLoss): 152 | """Computes Dice Loss according to https://arxiv.org/abs/1606.04797. 153 | For multi-class segmentation `weight` parameter can be used to assign different weights per class. 154 | The input to the loss function is assumed to be a logit and is normalized by Sigmoid or Softmax, per `normalization` (softmax by default here). 155 | """ 156 | 157 | def __init__(self, weight=None, normalization='softmax'): 158 | super().__init__(weight, normalization) 159 | 160 | def dice(self, input, target, weight): 161 | return compute_per_channel_dice(input, target, weight=weight) 162 | 163 | 164 | -------------------------------------------------------------------------------- /experiments/prostate/train/model_trainer_segmentation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import torch.backends.cudnn as cudnn 7 | from .model_trainer import ModelTrainer 8 | from utils.loss import DiceLoss, entropy_loss 9 | import copy 10 | import numpy as np 11 | import random 12 | from .test_dataset import Prostate 13 | from torch.utils.data.dataloader import DataLoader 14 | def deterministic(seed): 15 | cudnn.benchmark = False 16 | cudnn.deterministic = True 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | random.seed(seed) 21 | 22 | 23 | class ModelTrainerSegmentation(ModelTrainer): 24 | def get_model_params(self): 25 | return self.model.cpu().state_dict() 26 | 27 | def set_model_params(self, model_parameters): 28 | self.model.load_state_dict(model_parameters) 29 | 30 | @torch.enable_grad() 31 | def test_time(self, test_data, device, args): 32 | test_set = Prostate(args.target) 33 | test_data = DataLoader(test_set, batch_size=1, shuffle=True) 34 | deterministic(args.seed) 35 | metrics = { 36 | 'test_acc': 0, 37 | 'test_loss': 0, 38 | } 39 | best_dice = 0.
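# Test-time adaptation below: clone the trained model, leave only the normalization
# layers trainable, and clear the BatchNorm running statistics so that normalization
# uses the statistics of each incoming test batch; one entropy-minimization step is
# taken per batch before the adapted model is re-run to compute the Dice-based metrics.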
40 | dice_buffer = [] 41 | model = self.model 42 | model_adapt = copy.deepcopy(model) 43 | model_adapt.to(device) 44 | model_adapt.train() 45 | for m in model_adapt.modules(): 46 | if isinstance(m, nn.BatchNorm3d): 47 | m.requires_grad_(True) 48 | m.track_running_stats = False 49 | m.running_mean = None 50 | m.running_var = None 51 | var_list = model_adapt.named_parameters() 52 | update_var_list = [] 53 | update_name_list = [] 54 | names = [] 55 | 56 | for idx, (name, param) in enumerate(var_list): 57 | param.requires_grad_(False) 58 | names.append(name) 59 | if "batchnorm" in name: 60 | param.requires_grad_(True) 61 | update_var_list.append(param) 62 | update_name_list.append(name) 63 | params = model_adapt.parameters() 64 | optimizer = torch.optim.Adam(update_var_list, lr=1e-3, betas=(0.9, 0.999)) 65 | criterion = DiceLoss().to(device) 66 | loss_all = 0 67 | test_acc = 0. 68 | 69 | for epoch in range(1): 70 | loss_all = 0 71 | test_acc = 0. 72 | 73 | deterministic(args.seed) 74 | for step, (data, target) in enumerate(test_data): 75 | deterministic(args.seed) 76 | data = data.to(device) 77 | target = target.to(device) 78 | output = model_adapt(data) 79 | loss_entropy_before = entropy_loss(output, c=2) 80 | all_loss = loss_entropy_before 81 | weight = 1 82 | all_loss = weight*all_loss 83 | #print(all_loss) 84 | optimizer.zero_grad() 85 | all_loss.backward() 86 | optimizer.step() 87 | output = model_adapt(data) 88 | loss = criterion(output, target) 89 | loss_all += loss.item() 90 | loss = loss_all / len(test_data) 91 | acc = 1 - loss 92 | metrics['test_loss'] = loss 93 | metrics["test_acc"] = acc 94 | return metrics 95 | 96 | def train(self, train_data, device, args): 97 | model = self.model 98 | 99 | model.to(device) 100 | model.train() 101 | 102 | # train and update 103 | criterion = DiceLoss().to(device) 104 | if args.client_optimizer == "sgd": 105 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), lr=args.lr) 106 | else: 107 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=args.lr, amsgrad=True) 108 | 109 | epoch_loss = [] 110 | epoch_acc = [] 111 | for epoch in range(args.wk_iters): 112 | batch_loss = [] 113 | batch_acc = [] 114 | for batch_idx, (x, labels) in enumerate(train_data): 115 | model.zero_grad() 116 | x, labels = x.to(device), labels.to(device) 117 | 118 | #print(x.shape) 119 | log_probs = model(x) 120 | loss = criterion(log_probs, labels) 121 | 122 | loss.backward() 123 | optimizer.step() 124 | batch_loss.append(loss.item()) 125 | batch_acc.append(1-loss.item()) 126 | 127 | epoch_loss.append(sum(batch_loss) / len(batch_loss)) 128 | epoch_acc.append(sum(batch_acc) / len(batch_acc)) 129 | logging.info('Client Index = {}\tEpoch: {}\tAcc:{:.4f}\tLoss: {:.4f}'.format( 130 | self.id, epoch, sum(epoch_acc) / len(epoch_acc),sum(epoch_loss) / len(epoch_loss))) 131 | 132 | def test(self, test_data, device, args, ood=False): 133 | model = copy.deepcopy(self.model) 134 | 135 | model.to(device) 136 | if ood: 137 | model.train() 138 | else: 139 | model.eval() 140 | 141 | metrics = { 142 | 'test_acc': 0, 143 | 'test_loss': 0, 144 | 'test_hd': 0, 145 | } 146 | test_set = Prostate(args.target) 147 | test_data = DataLoader(test_set, batch_size=1, shuffle=True) 148 | criterion = DiceLoss().to(device) 149 | test_epoches = 1 150 | with torch.no_grad(): 151 | for test_epoch in range(test_epoches): 152 | for batch_idx, (x, target) in enumerate(test_data): 153 | x = x.to(device) 154 | target = target.to(device) 155 
| pred = model(x) 156 | loss = criterion(pred, target) 157 | 158 | acc = 1 - loss.item() 159 | #print(acc) 160 | 161 | metrics['test_loss'] += loss.item() 162 | metrics['test_acc'] += acc 163 | 164 | metrics["test_loss"] = metrics["test_loss"] / (len(test_data)*test_epoches) 165 | metrics["test_acc"] = metrics["test_acc"] / (len(test_data)*test_epoches) 166 | 167 | return metrics 168 | 169 | 170 | -------------------------------------------------------------------------------- /experiments/camelyon17/configs/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from configs.algorithm import algorithm_defaults 3 | from configs.model import model_defaults 4 | from configs.scheduler import scheduler_defaults 5 | from configs.data_loader import loader_defaults 6 | from configs.datasets import dataset_defaults, split_defaults 7 | 8 | def populate_defaults(config): 9 | """Populates hyperparameters with defaults implied by choices 10 | of other hyperparameters.""" 11 | 12 | orig_config = copy.deepcopy(config) 13 | assert config.dataset is not None, 'dataset must be specified' 14 | assert config.algorithm is not None, 'algorithm must be specified' 15 | 16 | # Run oracle using ERM with unlabeled split 17 | if config.use_unlabeled_y: 18 | assert config.algorithm == 'ERM', 'Only ERM is currently supported for training on the true labels of unlabeled data.' 19 | assert config.unlabeled_split is not None, 'Specify an unlabeled split' 20 | assert config.dataset in ['amazon', 'civilcomments', 'fmow', 'iwildcam'], 'The unlabeled data in this dataset are truly unlabeled, and we do not have true labels for them.' 21 | 22 | # Validations 23 | if config.groupby_fields == ['from_source_domain']: 24 | if config.n_groups_per_batch is None: 25 | config.n_groups_per_batch = 1 26 | elif config.n_groups_per_batch != 1: 27 | raise ValueError( 28 | f"from_source_domain was specified for groupby_fields, but n_groups_per_batch " 29 | f"was {config.n_groups_per_batch}, when it should be 1." 30 | ) 31 | 32 | if config.unlabeled_n_groups_per_batch is None: 33 | config.unlabeled_n_groups_per_batch = 1 34 | elif config.unlabeled_n_groups_per_batch != 1: 35 | raise ValueError( 36 | f"from_source_domain was specified for groupby_fields, but unlabeled_n_groups_per_batch " 37 | f"was {config.unlabeled_n_groups_per_batch}, when it should be 1." 38 | ) 39 | 40 | if config.algorithm == 'DANN' and config.lr is not None: 41 | raise ValueError( 42 | "Cannot pass in a value for lr. For DANN, only dann_classifier_lr, dann_featurizer_lr " 43 | "and dann_discriminator_lr are valid learning rate parameters." 44 | ) 45 | 46 | if config.additional_train_transform is not None: 47 | if config.algorithm == "NoisyStudent": 48 | raise ValueError( 49 | "Cannot pass in a value for additional_train_transform, NoisyStudent " 50 | "already has a default transformation for the training data." 51 | ) 52 | 53 | if config.load_featurizer_only: 54 | if config.pretrained_model_path is None: 55 | raise ValueError( 56 | "load_featurizer_only cannot be set when there is no pretrained_model_path " 57 | "specified." 58 | ) 59 | 60 | if config.dataset == 'globalwheat': 61 | if config.additional_train_transform is not None: 62 | raise ValueError( 63 | f"Augmentations not supported for detection dataset: {config.dataset}." 
64 | ) 65 | config.additional_train_transform = '' 66 | 67 | if config.algorithm == "NoisyStudent": 68 | if config.process_pseudolabels_function is None: 69 | config.process_pseudolabels_function = 'pseudolabel_detection' 70 | elif config.process_pseudolabels_function == 'pseudolabel_detection_discard_empty': 71 | raise ValueError( 72 | f"Filtering out empty images when generating pseudo-labels for {config.algorithm} " 73 | f"is not supported for detection." 74 | ) 75 | 76 | # implied defaults from choice of dataset 77 | config = populate_config( 78 | config, 79 | dataset_defaults[config.dataset] 80 | ) 81 | 82 | # implied defaults from choice of split 83 | if config.dataset in split_defaults and config.split_scheme in split_defaults[config.dataset]: 84 | config = populate_config( 85 | config, 86 | split_defaults[config.dataset][config.split_scheme] 87 | ) 88 | 89 | # implied defaults from choice of algorithm 90 | config = populate_config( 91 | config, 92 | algorithm_defaults[config.algorithm] 93 | ) 94 | 95 | # implied defaults from choice of loader 96 | config = populate_config( 97 | config, 98 | loader_defaults 99 | ) 100 | # implied defaults from choice of model 101 | if config.model: config = populate_config( 102 | config, 103 | model_defaults[config.model], 104 | ) 105 | 106 | # implied defaults from choice of scheduler 107 | if config.scheduler: config = populate_config( 108 | config, 109 | scheduler_defaults[config.scheduler] 110 | ) 111 | 112 | # misc implied defaults 113 | if config.groupby_fields is None: 114 | config.no_group_logging = True 115 | config.no_group_logging = bool(config.no_group_logging) 116 | 117 | # basic checks 118 | required_fields = [ 119 | 'split_scheme', 'train_loader', 'uniform_over_groups', 'batch_size', 'eval_loader', 'model', 'loss_function', 120 | 'val_metric', 'val_metric_decreasing', 'n_epochs', 'optimizer', 'lr', 'weight_decay', 121 | ] 122 | for field in required_fields: 123 | assert getattr(config, field) is not None, f"Must manually specify {field} for this setup." 124 | 125 | # data loader validations 126 | # we only raise this error if the train_loader is standard, and 127 | # n_groups_per_batch or distinct_groups are 128 | # specified by the user (instead of populated as a default) 129 | if config.train_loader == 'standard': 130 | if orig_config.n_groups_per_batch is not None: 131 | raise ValueError("n_groups_per_batch cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.") 132 | if orig_config.distinct_groups is not None: 133 | raise ValueError("distinct_groups cannot be specified if the data loader is 'standard'. Consider using a 'group' data loader instead.") 134 | 135 | return config 136 | 137 | 138 | def populate_config(config, template: dict, force_compatibility=False): 139 | """Populates missing (key, val) pairs in config with (key, val) in template. 
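For instance, given template = {'lr': 1e-3, 'scheduler_kwargs': {'step_size': 1}}, a config whose lr is None receives lr=1e-3, while kwarg dicts such as scheduler_kwargs are merged key by key rather than overwritten.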
140 | Example usage: populate config with defaults 141 | Args: 142 | - config: namespace 143 | - template: dict 144 | - force_compatibility: option to raise errors if config.key != template[key] 145 | """ 146 | if template is None: 147 | return config 148 | 149 | d_config = vars(config) 150 | for key, val in template.items(): 151 | if not isinstance(val, dict): # config[key] expected to be a plain (non-dict) value 152 | if key not in d_config or d_config[key] is None: 153 | d_config[key] = val 154 | elif d_config[key] != val and force_compatibility: 155 | raise ValueError(f"Argument {key} must be set to {val}") 156 | 157 | else: # config[key] expected to be a kwarg dict 158 | for kwargs_key, kwargs_val in val.items(): 159 | if kwargs_key not in d_config[key] or d_config[key][kwargs_key] is None: 160 | d_config[key][kwargs_key] = kwargs_val 161 | elif d_config[key][kwargs_key] != kwargs_val and force_compatibility: 162 | raise ValueError(f"Argument {key}[{kwargs_key}] must be set to {kwargs_val}") 163 | return config 164 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser(description='PyTorch parser') 3 | # training 4 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 5 | help='number of data loading workers (default: 2)') 6 | parser.add_argument('--epochs', default=50, type=int, metavar='N', 7 | help='number of total epochs to run') 8 | parser.add_argument('--sepochs', default=0, type=int, metavar='N', 9 | help='number of total epochs to initialize adaptor') 10 | parser.add_argument('--tepochs', default=1, type=int, metavar='N', 11 | help='number of total epochs during test-time training') 12 | parser.add_argument('-b', '--batch-size', default=2, type=int, 13 | metavar='N', help='mini-batch size (default: 2)') 14 | parser.add_argument('--segloss', dest='segloss', default='ce',type=str, 15 | help='wce, ce, dice.
segmentation loss') 16 | parser.add_argument('--segaeloss', dest='segaeloss', default='mse',type=str, 17 | help='segmentation output ae loss') 18 | parser.add_argument('--usegt', dest='use_gt', action='store_true', default=False, 19 | help='use gt for last segmentation map autoencoder training') 20 | parser.add_argument('--usedtta', dest='usedtta', action='store_true', default=False, 21 | help='use dtta in MEDIA paper') 22 | parser.add_argument('--segsoftmax', dest='segsoftmax', action='store_true', default=False, 23 | help='use softmax before last autoencoder for segmentation') 24 | parser.add_argument('--tlr', default=0.001, type=float, 25 | metavar='LR', help='initial learning rate for TNet') 26 | parser.add_argument('--aelr', default=0.001, type=float, 27 | metavar='LR', help='initial learning rate for AENet') 28 | parser.add_argument('--alr', default=0.001, type=float, 29 | metavar='LR', help='initial learning rate for ANet') 30 | parser.add_argument('--or_weight', default=0.05, type=float, 31 | metavar='LR', help='or_weight') 32 | parser.add_argument('--wt',dest='weights', type=lambda x: list(map(float, x.split(','))), 33 | help='weights in training ae net') 34 | parser.add_argument('--wo', default=10.0, type=float, 35 | help='orthogonal weights in training ANet') 36 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 37 | help='evaluate model on validation set') 38 | parser.add_argument('-t', '--test', dest='test', action='store_true', 39 | help='test model on validation set') 40 | parser.add_argument('--resume_T', default='', type=str, metavar='PATH', 41 | help='path to latest checkpoint (default: none)') 42 | parser.add_argument('--resume_AE', default='', type=str, metavar='PATH', 43 | help='path to latest checkpoint (default: none)') 44 | # network 45 | parser.add_argument('--trainer',dest='trainer', default='tnet', 46 | type=str, help='select which net to train') 47 | parser.add_argument('--task',dest='task', default='syn', 48 | type=str, help='select task (syn:synthesis or seg:segmentation)') 49 | parser.add_argument('--seq',dest='seq', type=lambda x: list(map(int, x.split(','))), 50 | help='the 1x1 conv seq to be used in A-Net') 51 | parser.add_argument('--td',dest='tnet_dim', type=lambda x: list(map(int, x.split(','))), 52 | help='task net input, encoder, output channels') 53 | parser.add_argument('--ad', dest='aenet_dim', default=64, type=int, 54 | help = 'starting feature channel in auto-encoder') 55 | parser.add_argument('--na', dest='no_adpt', action='store_true', 56 | help = 'not using pre-adaptation') 57 | parser.add_argument('--config', dest='config', default='config.json', 58 | help='hyperparameter in json format') 59 | # datasets 60 | parser.add_argument('--label_path', dest='label_path', default='../data/label/',type=str, 61 | help='path to the label') 62 | parser.add_argument('--label_ext', dest='label_ext', default='json',type=str, 63 | help='label extension') 64 | parser.add_argument('--img_path', dest='img_path', default='../data/image/',type=str, 65 | help='path to the image') 66 | parser.add_argument('--img_ext', dest='img_ext', default='png',type=str, 67 | help='image extension') 68 | parser.add_argument('--vlabel_path', dest='vlabel_path', default='',type=str, 69 | help='path to the validation label') 70 | parser.add_argument('--vimg_path', dest='vimg_path', default='',type=str, 71 | help='path to the validation image') 72 | parser.add_argument('--sub_name', dest='sub_name', default='',type=str, 73 | help='path to the txt file 
name containing subject unique ID') 74 | parser.add_argument('--split',dest='split', type=lambda x: list(map(int, x.split(','))), 75 | help='the start and end index for validation dataset') 76 | # preprocessing 77 | parser.add_argument('--ps',dest='pad_size', type=lambda x: list(map(int, x.split(','))), 78 | help='padding all the input image to this size') 79 | parser.add_argument('--scs',dest='scale_size', type=lambda x: list(map(int, x.split(','))), 80 | help='interpolate all the input image to this size') 81 | parser.add_argument('--an', dest='add_noise', action='store_true', 82 | help='add gaussian noise in preprocessing') 83 | # augmentation 84 | parser.add_argument('--fnoise', dest='feat_noise', action='store_true', default=False, 85 | help='use feature noise to train denoising auto-encoders') 86 | parser.add_argument('--aprob', dest='aug_prob', type=float, default=0, 87 | help='use augmentation during tnet and aenet training') 88 | parser.add_argument('--aangle',dest='affine_angle', type=lambda x: list(map(float, x.split(','))), 89 | default='-30,30', help='affine transformation angle range') 90 | parser.add_argument('--atrans',dest='affine_translate', type=lambda x: list(map(float, x.split(','))), 91 | default='-10,10,-10,10', help='affine transformation translation range') 92 | parser.add_argument('--ascale',dest='affine_scale', type=lambda x: list(map(float, x.split(','))), 93 | default='0.8,1.2',help='affine transformation scaling range') 94 | parser.add_argument('--agamma',dest='gamma', type=lambda x: list(map(float, x.split(','))), 95 | default='0.7,1.5', help='gamma scaling range') 96 | parser.add_argument('--anoise',dest='noise_std', type=float,default=0.1, 97 | help='gaussian noise std') 98 | parser.add_argument('--width',dest='width', type=int,default=400, 99 | help='center crop width') 100 | parser.add_argument('--height',dest='height', type=int,default=400, 101 | help='center crop height') 102 | parser.add_argument('-vaug', dest='val_augment', action='store_true', default=False, 103 | help='use augmentation during test time training') 104 | # logging 105 | parser.add_argument('--results_dir', dest='results_dir', default='results_dir', 106 | help='results dir of output') 107 | parser.add_argument('--ss', dest='save_step', default=5, type=int, 108 | help='the step of epochs to save checkpoints and validate') 109 | parser.add_argument('--saveimage','--si', dest='saveimage', action='store_true', 110 | help='save image with surfaces and layers') 111 | parser.add_argument('--dpi', dest='dpi', type=int, default=100, help='dpi of saved image') 112 | -------------------------------------------------------------------------------- /datasets/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import pdb 4 | import logging 5 | logger = logging.getLogger('global') 6 | from matplotlib import pyplot as plt 7 | from scipy import interpolate 8 | import torch 9 | from PIL import Image 10 | import torchvision.transforms.functional as TF 11 | from monai.transforms import Affine, AdjustContrast, ResizeWithPadOrCrop 12 | import random 13 | 14 | class Composer(object): 15 | def __init__(self, workers): 16 | self.workers = workers 17 | def __call__(self, image, label): 18 | for _ in self.workers: 19 | image,label = _(image,label) 20 | return image, label 21 | def get_params(self): 22 | # get the latest transformation parameters; 23 | # these change after the composer is called 24 | params = {} 25 | for _ in
self.workers: 26 | params[_.name()] = _.params 27 | return params 28 | def get_rdparams(self): 29 | # get the transformation random range 30 | rdparams = {} 31 | for _ in self.workers: 32 | rdparams[_.name()] = _.rdparams 33 | return rdparams 34 | 35 | class Worker(object): 36 | def __init__(self, rdparams={}): 37 | pass 38 | def name(self): 39 | return 'Base class for transform' 40 | def _random(self): 41 | # generate final transformation params from rdparams 42 | pass 43 | def __call__(self, image, label, params=None): 44 | pass 45 | 46 | class _normalize_worker(Worker): 47 | def __init__(self,rdparams={}): 48 | super(_normalize_worker,self).__init__(rdparams) 49 | default_rdparams = {'n_label':False,'mean':0.0,'std':1.0} 50 | default_rdparams.update(rdparams) 51 | self.rdparams = default_rdparams 52 | def name(self): 53 | return 'normalization' 54 | def __call__(self, image, label, params=None): 55 | if params is None: 56 | params = self._random() 57 | self.params = params 58 | # convert back to numpy and normalize to [0,1] 59 | image = image.astype(np.float32) 60 | image = (image - np.mean(image)) / np.std(image) * self.params['std'] + self.params['mean'] 61 | if self.params['n_label']: 62 | label = (label - np.mean(label)) / np.std(label) 63 | label = label.astype(np.float32) 64 | else: 65 | label = label.astype(np.uint8) 66 | return image, label 67 | def _random(self): 68 | params = self.rdparams 69 | return params 70 | 71 | class _affine_worker(Worker): 72 | def __init__(self,rdparams={}): 73 | super(_affine_worker,self).__init__(rdparams) 74 | default_rdparams = {'p':0.5, 'angle':[-30,30], 'label_mode':'nearest', 75 | 'translate':[-10,10,-10,10], 76 | 'scale':[0.8,1.2],'shear':[0,0,0,0]} 77 | default_rdparams.update(rdparams) 78 | self.rdparams = default_rdparams 79 | def name(self): 80 | return 'affine' 81 | def __call__(self, image, label, params=None): 82 | ''' Affine transformation on image and label 83 | Args: 84 | image: np array or PIL, [img_rows, img_cols] 85 | label: np array or PIL, [img_rows, img_cols] 86 | ''' 87 | if params is None: 88 | params = self._random() 89 | self.params = params 90 | if self.params['p']: 91 | image = Affine(rotate_params=self.params['angle'], translate_params=self.params['translate'], 92 | scale_params=self.params['scale'], shear_params=self.params['shear'], mode='bilinear', 93 | padding_mode='zeros')(image) 94 | label = Affine(rotate_params=self.params['angle'], translate_params=self.params['translate'], 95 | scale_params=self.params['scale'], shear_params=self.params['shear'], 96 | mode=self.params['label_mode'], padding_mode='zeros')(label) 97 | return image, label 98 | def _random(self): 99 | params = {} 100 | params['p'] = random.random() < self.rdparams['p'] 101 | params['label_mode'] = self.rdparams['label_mode'] 102 | params['angle'] = random.randint(self.rdparams['angle'][0],self.rdparams['angle'][1]) 103 | params['translate'] = [random.randint(self.rdparams['translate'][0],self.rdparams['translate'][1]), 104 | random.randint(self.rdparams['translate'][2],self.rdparams['translate'][3])] 105 | params['scale'] = random.random()*(self.rdparams['scale'][1]-self.rdparams['scale'][0]) \ 106 | + self.rdparams['scale'][0] 107 | params['shear'] = [random.randint(self.rdparams['shear'][0],self.rdparams['shear'][1]), 108 | random.randint(self.rdparams['shear'][2],self.rdparams['shear'][3])] 109 | return params 110 | 111 | class _hflip_worker(Worker): 112 | def __init__(self,rdparams={}): 113 | super(_hflip_worker,self).__init__(rdparams) 114 | 
default_rdparams = {'p':0.5} 115 | default_rdparams.update(rdparams) 116 | self.rdparams = default_rdparams 117 | def name(self): 118 | return 'hflip' 119 | def __call__(self, image, label, params=None): 120 | if params is None: 121 | params = self._random() 122 | self.params = params 123 | if self.params['p']: 124 | image = image[:,:,::-1] 125 | label = label[:,:,::-1] 126 | return image, label 127 | def _random(self): 128 | params = {} 129 | params['p'] = random.random() < self.rdparams['p'] 130 | return params 131 | 132 | class _gamma_worker(Worker): 133 | def __init__(self,rdparams={}): 134 | super(_gamma_worker,self).__init__(rdparams) 135 | default_rdparams = {'p':0.5, 'gamma':[0.7,1.5], 'gain':1} 136 | default_rdparams.update(rdparams) 137 | self.rdparams = default_rdparams 138 | def name(self): 139 | return 'gamma' 140 | def __call__(self, image, label, params=None): 141 | if params is None: 142 | params = self._random() 143 | self.params = params 144 | if self.params['p']: 145 | image = AdjustContrast(self.params['gamma'])(image) 146 | return image, label 147 | def _random(self): 148 | params = {} 149 | params['p'] = random.random() < self.rdparams['p'] 150 | params['gamma'] = random.random()*(self.rdparams['gamma'][1]-self.rdparams['gamma'][0]) \ 151 | + self.rdparams['gamma'][0] 152 | return params 153 | 154 | class _noise_worker(Worker): 155 | def __init__(self,rdparams={}): 156 | super(_noise_worker,self).__init__(rdparams) 157 | default_rdparams = {'p':0.5, 'type':'gaussian', 'std':0.1} 158 | default_rdparams.update(rdparams) 159 | self.rdparams = default_rdparams 160 | def name(self): 161 | return 'noise' 162 | def __call__(self, image, label, params=None): 163 | if params is None: 164 | params = self._random() 165 | self.params = params 166 | if self.params['p']: 167 | if self.params['type'] == 'gaussian': 168 | image = image + self.params['std']*np.random.randn(image.shape[0], image.shape[1], image.shape[2]).astype(np.float32) 169 | else: 170 | raise NotImplementedError 171 | return image, label 172 | def _random(self): 173 | params = {} 174 | params['p'] = random.random() < self.rdparams['p'] 175 | params['std'] = self.rdparams['std'] 176 | params['type'] = self.rdparams['type'] 177 | return params 178 | 179 | class _padcrop_worker(Worker): 180 | def __init__(self,rdparams={}): 181 | super(_padcrop_worker,self).__init__(rdparams) 182 | default_rdparams = {'width':400, 'height':400} 183 | default_rdparams.update(rdparams) 184 | self.rdparams = default_rdparams 185 | def name(self): 186 | return 'padcrop' 187 | def __call__(self, image, label, params=None): 188 | if params is None: 189 | params = self._random() 190 | self.params = params 191 | if self.params['width'] > 0 and self.params['height'] > 0: 192 | image = ResizeWithPadOrCrop([self.params['width'], self.params['height']])(image) 193 | label = ResizeWithPadOrCrop([self.params['width'], self.params['height']])(label) 194 | return image, label 195 | def _random(self): 196 | params = {} 197 | params = self.rdparams 198 | return params 199 | 200 | 201 | registered_workers = \ 202 | { 203 | 'affine':_affine_worker, 204 | 'hflip':_hflip_worker, 205 | 'normalize':_normalize_worker, 206 | 'gamma':_gamma_worker, 207 | 'noise':_noise_worker, 208 | 'padcrop':_padcrop_worker 209 | } 210 | -------------------------------------------------------------------------------- /experiments/camelyon17/models/resnet_multispectral.py: -------------------------------------------------------------------------------- 1 | ##### 2 | # Adapted 
from torchvision.models.resnet 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=dilation, groups=groups, bias=False, dilation=dilation) 12 | 13 | 14 | def conv1x1(in_planes, out_planes, stride=1): 15 | """1x1 convolution""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | __constants__ = ['downsample'] 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 24 | base_width=64, dilation=1, norm_layer=None): 25 | super(BasicBlock, self).__init__() 26 | if norm_layer is None: 27 | norm_layer = nn.BatchNorm2d 28 | if groups != 1 or base_width != 64: 29 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 30 | if dilation > 1: 31 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 32 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = norm_layer(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = norm_layer(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | identity = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | identity = self.downsample(x) 53 | 54 | out += identity 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | __constants__ = ['downsample'] 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 65 | base_width=64, dilation=1, norm_layer=None): 66 | super(Bottleneck, self).__init__() 67 | if norm_layer is None: 68 | norm_layer = nn.BatchNorm2d 69 | width = int(planes * (base_width / 64.)) * groups 70 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 71 | self.conv1 = conv1x1(inplanes, width) 72 | self.bn1 = norm_layer(width) 73 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 74 | self.bn2 = norm_layer(width) 75 | self.conv3 = conv1x1(width, planes * self.expansion) 76 | self.bn3 = norm_layer(planes * self.expansion) 77 | self.relu = nn.ReLU(inplace=True) 78 | self.downsample = downsample 79 | self.stride = stride 80 | 81 | def forward(self, x): 82 | identity = x 83 | 84 | out = self.conv1(x) 85 | out = self.bn1(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv2(out) 89 | out = self.bn2(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv3(out) 93 | out = self.bn3(out) 94 | 95 | if self.downsample is not None: 96 | identity = self.downsample(x) 97 | 98 | out += identity 99 | out = self.relu(out) 100 | 101 | return out 102 | 103 | 104 | class ResNet(nn.Module): 105 | 106 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 107 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 108 | norm_layer=None, num_channels=3): 109 | super(ResNet, self).__init__() 110 | if norm_layer is None: 111 | norm_layer = nn.BatchNorm2d 112 | self._norm_layer = norm_layer 113 | 114 | self.inplanes = 64 115 | self.dilation = 1 116 | if replace_stride_with_dilation is None: 
117 | # each element in the tuple indicates if we should replace 118 | # the 2x2 stride with a dilated convolution instead 119 | replace_stride_with_dilation = [False, False, False] 120 | if len(replace_stride_with_dilation) != 3: 121 | raise ValueError("replace_stride_with_dilation should be None " 122 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 123 | self.groups = groups 124 | self.base_width = width_per_group 125 | self.conv1 = nn.Conv2d(num_channels, self.inplanes, kernel_size=7, stride=2, padding=3, 126 | bias=False) 127 | self.bn1 = norm_layer(self.inplanes) 128 | self.relu = nn.ReLU(inplace=True) 129 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 130 | self.layer1 = self._make_layer(block, 64, layers[0]) 131 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 132 | dilate=replace_stride_with_dilation[0]) 133 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 134 | dilate=replace_stride_with_dilation[1]) 135 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 136 | dilate=replace_stride_with_dilation[2]) 137 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 138 | if num_classes is not None: 139 | self.fc = nn.Linear(512 * block.expansion, num_classes) 140 | self.d_out = num_classes 141 | else: 142 | self.fc = None 143 | self.d_out = 512 * block.expansion 144 | 145 | for m in self.modules(): 146 | if isinstance(m, nn.Conv2d): 147 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 148 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 149 | nn.init.constant_(m.weight, 1) 150 | nn.init.constant_(m.bias, 0) 151 | 152 | # Zero-initialize the last BN in each residual branch, 153 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
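# (zero-initializing the final BN weight makes each residual branch output zero at the start, so out += identity reduces to the identity mapping.)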
154 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 155 | if zero_init_residual: 156 | for m in self.modules(): 157 | if isinstance(m, Bottleneck): 158 | nn.init.constant_(m.bn3.weight, 0) 159 | elif isinstance(m, BasicBlock): 160 | nn.init.constant_(m.bn2.weight, 0) 161 | 162 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 163 | norm_layer = self._norm_layer 164 | downsample = None 165 | previous_dilation = self.dilation 166 | if dilate: 167 | self.dilation *= stride 168 | stride = 1 169 | if stride != 1 or self.inplanes != planes * block.expansion: 170 | downsample = nn.Sequential( 171 | conv1x1(self.inplanes, planes * block.expansion, stride), 172 | norm_layer(planes * block.expansion), 173 | ) 174 | 175 | layers = [] 176 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 177 | self.base_width, previous_dilation, norm_layer)) 178 | self.inplanes = planes * block.expansion 179 | for _ in range(1, blocks): 180 | layers.append(block(self.inplanes, planes, groups=self.groups, 181 | base_width=self.base_width, dilation=self.dilation, 182 | norm_layer=norm_layer)) 183 | 184 | return nn.Sequential(*layers) 185 | 186 | def get_feats(self, x, layer=4): 187 | # See note [TorchScript super()] 188 | x = self.conv1(x) 189 | x = self.bn1(x) 190 | x = self.relu(x) 191 | x = self.maxpool(x) 192 | 193 | x = self.layer1(x) 194 | if layer == 1: 195 | return x 196 | x = self.layer2(x) 197 | if layer == 2: 198 | return x 199 | x = self.layer3(x) 200 | if layer == 3: 201 | return x 202 | x = self.layer4(x) 203 | 204 | x = self.avgpool(x) 205 | x = torch.flatten(x, 1) 206 | return x 207 | 208 | 209 | def _forward_impl(self, x, with_feats=False): 210 | x = feats = self.get_feats(x) 211 | if self.fc is not None: 212 | x = self.fc(feats) 213 | 214 | if with_feats: 215 | return x, feats 216 | else: 217 | return x 218 | 219 | def forward(self, x, with_feats=False): 220 | return self._forward_impl(x, with_feats) 221 | 222 | 223 | class ResNet18(ResNet): 224 | def __init__(self, num_classes=10, num_channels=3): 225 | super().__init__( 226 | BasicBlock, [2, 2, 2, 2], num_classes=num_classes, num_channels=num_channels) 227 | 228 | class ResNet34(ResNet): 229 | def __init__(self, num_classes=10, num_channels=3): 230 | super().__init__( 231 | BasicBlock, [3, 4, 6, 3], num_classes=num_classes, num_channels=num_channels) 232 | 233 | class ResNet50(ResNet): 234 | def __init__(self, num_classes=10, num_channels=3): 235 | super().__init__( 236 | Bottleneck, [3, 4, 6, 3], num_classes=num_classes, num_channels=num_channels) 237 | 238 | class ResNet101(ResNet): 239 | def __init__(self, num_classes=10, num_channels=3): 240 | super().__init__( 241 | Bottleneck, [3, 4, 23, 3], num_classes=num_classes, num_channels=num_channels) 242 | 243 | class ResNet152(ResNet): 244 | def __init__(self, num_classes=10, num_channels=3): 245 | super().__init__( 246 | Bottleneck, [3, 8, 36, 3], num_classes=num_classes, num_channels=num_channels) 247 | 248 | 249 | DEPTH_TO_MODEL = {18: ResNet18, 34: ResNet34, 50: ResNet50, 101: ResNet101, 152: ResNet152} 250 | 251 | -------------------------------------------------------------------------------- /experiments/camelyon17/models/initializer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import traceback 5 | 6 | from models.layers import Identity 7 | from utils import load 8 | 9 | def initialize_model(config, d_out,
is_featurizer=False): 10 | """ 11 | Initializes models according to the config 12 | Args: 13 | - config (dictionary): config dictionary 14 | - d_out (int): the dimensionality of the model output 15 | - is_featurizer (bool): whether to return a model or a (featurizer, classifier) pair that constitutes a model. 16 | Output: 17 | If is_featurizer=True: 18 | - featurizer: a model that outputs feature Tensors of shape (batch_size, ..., feature dimensionality) 19 | - classifier: a model that takes in feature Tensors and outputs predictions. In most cases, this is a linear layer. 20 | 21 | If is_featurizer=False: 22 | - model: a model that is equivalent to nn.Sequential(featurizer, classifier) 23 | 24 | Pretrained weights are loaded according to config.pretrained_model_path using either transformers.from_pretrained (for bert-based models) 25 | or our own utils.load function (for torchvision models, resnet18-ms, and gin-virtual). 26 | There is currently no support for loading pretrained weights from disk for other models. 27 | """ 28 | # If load_featurizer_only is True, 29 | # then split into (featurizer, classifier) for the purposes of loading only the featurizer, 30 | # before recombining them at the end 31 | featurize = is_featurizer or config.load_featurizer_only 32 | 33 | if config.model in ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'wideresnet50', 'densenet121'): 34 | if featurize: 35 | featurizer = initialize_torchvision_model( 36 | name=config.model, 37 | d_out=None, 38 | **config.model_kwargs) 39 | classifier = nn.Linear(featurizer.d_out, d_out) 40 | model = (featurizer, classifier) 41 | else: 42 | model = initialize_torchvision_model( 43 | name=config.model, 44 | d_out=d_out, 45 | **config.model_kwargs) 46 | 47 | elif 'bert' in config.model: 48 | if featurize: 49 | featurizer = initialize_bert_based_model(config, d_out, featurize) 50 | classifier = nn.Linear(featurizer.d_out, d_out) 51 | model = (featurizer, classifier) 52 | else: 53 | model = initialize_bert_based_model(config, d_out) 54 | 55 | elif config.model == 'resnet18_ms': # multispectral resnet 18 56 | from models.resnet_multispectral import ResNet18 57 | if featurize: 58 | featurizer = ResNet18(num_classes=None, **config.model_kwargs) 59 | classifier = nn.Linear(featurizer.d_out, d_out) 60 | model = (featurizer, classifier) 61 | else: 62 | model = ResNet18(num_classes=d_out, **config.model_kwargs) 63 | 64 | elif config.model == 'gin-virtual': 65 | from models.gnn import GINVirtual 66 | if featurize: 67 | featurizer = GINVirtual(num_tasks=None, **config.model_kwargs) 68 | classifier = nn.Linear(featurizer.d_out, d_out) 69 | model = (featurizer, classifier) 70 | else: 71 | model = GINVirtual(num_tasks=d_out, **config.model_kwargs) 72 | 73 | elif config.model == 'code-gpt-py': 74 | from models.code_gpt import GPT2LMHeadLogit, GPT2FeaturizerLMHeadLogit 75 | from transformers import GPT2Tokenizer 76 | name = 'microsoft/CodeGPT-small-py' 77 | tokenizer = GPT2Tokenizer.from_pretrained(name) 78 | if featurize: 79 | model = GPT2FeaturizerLMHeadLogit.from_pretrained(name) 80 | model.resize_token_embeddings(len(tokenizer)) 81 | featurizer = model.transformer 82 | classifier = model.lm_head 83 | model = (featurizer, classifier) 84 | else: 85 | model = GPT2LMHeadLogit.from_pretrained(name) 86 | model.resize_token_embeddings(len(tokenizer)) 87 | 88 | elif config.model == 'logistic_regression': 89 | assert not featurize, "Featurizer not supported for logistic regression" 90 | model = nn.Linear(out_features=d_out, **config.model_kwargs) 
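# (nn.Linear also requires in_features, so for logistic_regression it must be supplied via config.model_kwargs.)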
91 | elif config.model == 'unet-seq': 92 | from models.CNN_genome import UNet 93 | if featurize: 94 | featurizer = UNet(num_tasks=None, **config.model_kwargs) 95 | classifier = nn.Linear(featurizer.d_out, d_out) 96 | model = (featurizer, classifier) 97 | else: 98 | model = UNet(num_tasks=d_out, **config.model_kwargs) 99 | 100 | elif config.model == 'fasterrcnn': 101 | if featurize: 102 | raise NotImplementedError('Featurizer not implemented for detection yet') 103 | else: 104 | model = initialize_fasterrcnn_model(config, d_out) 105 | model.needs_y = True 106 | 107 | else: 108 | raise ValueError(f'Model: {config.model} not recognized.') 109 | 110 | # Load pretrained weights from disk using our utils.load function 111 | if config.pretrained_model_path is not None: 112 | if config.model in ('code-gpt-py', 'logistic_regression', 'unet-seq'): 113 | # This has only been tested on some models (mostly vision), so run this code iff we're sure it works 114 | raise NotImplementedError(f"Model loading not yet tested for {config.model}.") 115 | 116 | if 'bert' not in config.model: # We've already loaded pretrained weights for bert-based models using the transformers library 117 | try: 118 | if featurize: 119 | if config.load_featurizer_only: 120 | model_to_load = model[0] 121 | else: 122 | model_to_load = nn.Sequential(*model) 123 | else: 124 | model_to_load = model 125 | 126 | prev_epoch, best_val_metric = load( 127 | model_to_load, 128 | config.pretrained_model_path, 129 | device=config.device) 130 | 131 | print( 132 | (f'Initialized model with pretrained weights from {config.pretrained_model_path} ') 133 | + (f'previously trained for {prev_epoch} epochs ' if prev_epoch else '') 134 | + (f'with previous val metric {best_val_metric} ' if best_val_metric else '') 135 | ) 136 | except Exception as e: 137 | print('Something went wrong loading the pretrained model:') 138 | traceback.print_exc() 139 | raise 140 | 141 | # Recombine model if we originally split it up just for loading 142 | if featurize and not is_featurizer: 143 | model = nn.Sequential(*model) 144 | 145 | # The `needs_y` attribute specifies whether the model's forward function 146 | # needs to take in both (x, y). 147 | # If False, Algorithm.process_batch will call model(x). 148 | # If True, Algorithm.process_batch() will call model(x, y) during training, 149 | # and model(x, None) during eval. 
150 | if not hasattr(model, 'needs_y'): 151 | # Sometimes model is a tuple of (featurizer, classifier) 152 | if is_featurizer: 153 | for submodel in model: 154 | submodel.needs_y = False 155 | else: 156 | model.needs_y = False 157 | 158 | return model 159 | 160 | 161 | def initialize_bert_based_model(config, d_out, featurize=False): 162 | from models.bert.bert import BertClassifier, BertFeaturizer 163 | from models.bert.distilbert import DistilBertClassifier, DistilBertFeaturizer 164 | 165 | if config.pretrained_model_path: 166 | print(f'Initialized model with pretrained weights from {config.pretrained_model_path}') 167 | config.model_kwargs['state_dict'] = torch.load(config.pretrained_model_path, map_location=config.device) 168 | 169 | if config.model == 'bert-base-uncased': 170 | if featurize: 171 | model = BertFeaturizer.from_pretrained(config.model, **config.model_kwargs) 172 | else: 173 | model = BertClassifier.from_pretrained( 174 | config.model, 175 | num_labels=d_out, 176 | **config.model_kwargs) 177 | elif config.model == 'distilbert-base-uncased': 178 | if featurize: 179 | model = DistilBertFeaturizer.from_pretrained(config.model, **config.model_kwargs) 180 | else: 181 | model = DistilBertClassifier.from_pretrained( 182 | config.model, 183 | num_labels=d_out, 184 | **config.model_kwargs) 185 | else: 186 | raise ValueError(f'Model: {config.model} not recognized.') 187 | return model 188 | 189 | def initialize_torchvision_model(name, d_out, **kwargs): 190 | import torchvision 191 | 192 | # get constructor and last layer names 193 | if name == 'wideresnet50': 194 | constructor_name = 'wide_resnet50_2' 195 | last_layer_name = 'fc' 196 | elif name == 'densenet121': 197 | constructor_name = name 198 | last_layer_name = 'classifier' 199 | elif name in ('resnet18', 'resnet34', 'resnet50', 'resnet101'): 200 | constructor_name = name 201 | last_layer_name = 'fc' 202 | else: 203 | raise ValueError(f'Torchvision model {name} not recognized') 204 | # construct the default model, which has the default last layer 205 | constructor = getattr(torchvision.models, constructor_name) 206 | model = constructor(**kwargs) 207 | # adjust the last layer 208 | d_features = getattr(model, last_layer_name).in_features 209 | if d_out is None: # want to initialize a featurizer model 210 | last_layer = Identity(d_features) 211 | model.d_out = d_features 212 | else: # want to initialize a classifier for a particular num_classes 213 | last_layer = nn.Linear(d_features, d_out) 214 | model.d_out = d_out 215 | setattr(model, last_layer_name, last_layer) 216 | 217 | return model 218 | 219 | def initialize_fasterrcnn_model(config, d_out): 220 | from models.detection.fasterrcnn import fasterrcnn_resnet50_fpn 221 | 222 | # load a model pre-trained on COCO 223 | model = fasterrcnn_resnet50_fpn( 224 | pretrained=config.model_kwargs["pretrained_model"], 225 | pretrained_backbone=config.model_kwargs["pretrained_backbone"], 226 | num_classes=d_out, 227 | min_size=config.model_kwargs["min_size"], 228 | max_size=config.model_kwargs["max_size"] 229 | ) 230 | 231 | return model 232 | -------------------------------------------------------------------------------- /experiments/camelyon17/networks/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from collections import OrderedDict 7 | import os 8 | 9 | __all__ = ['DenseNet', 'densenet121', 
'densenet169', 'densenet201', 'densenet161'] 10 | 11 | os.environ['TORCH_HOME'] = '/research/pheng4/qdliu/anaconda/envs/pytorch/models' 12 | 13 | model_urls = { 14 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 15 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 16 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 17 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 18 | } 19 | 20 | 21 | class _DenseLayer(nn.Sequential): 22 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 23 | super(_DenseLayer, self).__init__() 24 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 25 | self.add_module('relu1', nn.ReLU(inplace=True)), 26 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 27 | growth_rate, kernel_size=1, stride=1, bias=False)), 28 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 29 | self.add_module('relu2', nn.ReLU(inplace=True)), 30 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 31 | kernel_size=3, stride=1, padding=1, bias=False)), 32 | self.drop_rate = drop_rate 33 | self.drop_layer = nn.Dropout(p=drop_rate) 34 | def forward(self, x): 35 | new_features = super(_DenseLayer, self).forward(x) 36 | # if self.drop_rate > 0: 37 | # print (self.drop_rate) 38 | # new_features = self.drop_layer(new_features) 39 | return torch.cat([x, new_features], 1) 40 | 41 | 42 | class _DenseBlock(nn.Sequential): 43 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 44 | super(_DenseBlock, self).__init__() 45 | for i in range(num_layers): 46 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 47 | self.add_module('denselayer%d' % (i + 1), layer) 48 | 49 | 50 | class _Transition(nn.Sequential): 51 | def __init__(self, num_input_features, num_output_features): 52 | super(_Transition, self).__init__() 53 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 54 | self.add_module('relu', nn.ReLU(inplace=True)) 55 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 56 | kernel_size=1, stride=1, bias=False)) 57 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 58 | 59 | 60 | class DenseNet(nn.Module): 61 | r"""Densenet-BC model class, based on 62 | `"Densely Connected Convolutional Networks" `_ 63 | Args: 64 | growth_rate (int) - how many filters to add each layer (`k` in paper) 65 | block_config (list of 4 ints) - how many layers in each pooling block 66 | num_init_features (int) - the number of filters to learn in the first convolution layer 67 | bn_size (int) - multiplicative factor for number of bottle neck layers 68 | (i.e. 
bn_size * k features in the bottleneck layer) 69 | drop_rate (float) - dropout rate after each dense layer 70 | num_classes (int) - number of classification classes 71 | """ 72 | 73 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 74 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 75 | 76 | super(DenseNet, self).__init__() 77 | 78 | # First convolution 79 | self.features = nn.Sequential(OrderedDict([ 80 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 81 | ('norm0', nn.BatchNorm2d(num_init_features)), 82 | ('relu0', nn.ReLU(inplace=True)), 83 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 84 | ])) 85 | 86 | # Each denseblock 87 | num_features = num_init_features 88 | for i, num_layers in enumerate(block_config): 89 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 90 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 91 | self.features.add_module('denseblock%d' % (i + 1), block) 92 | num_features = num_features + num_layers * growth_rate 93 | if i != len(block_config) - 1: 94 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 95 | self.features.add_module('transition%d' % (i + 1), trans) 96 | num_features = num_features // 2 97 | 98 | # Final batch norm 99 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 100 | 101 | # Linear layer 102 | self.classifier = nn.Linear(num_features, num_classes) 103 | 104 | # Official init from torch repo. 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | nn.init.kaiming_normal_(m.weight) 108 | elif isinstance(m, nn.BatchNorm2d): 109 | nn.init.constant_(m.weight, 1) 110 | nn.init.constant_(m.bias, 0) 111 | elif isinstance(m, nn.Linear): 112 | nn.init.constant_(m.bias, 0) 113 | 114 | def forward(self, x): 115 | features = self.features(x) 116 | out = F.relu(features, inplace=True) 117 | 118 | out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) 119 | 120 | out = self.classifier(out) 121 | return out 122 | 123 | 124 | def densenet121(pretrained=False, **kwargs): 125 | r"""Densenet-121 model from 126 | `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_ 127 | Args: 128 | pretrained (bool): If True, returns a model pre-trained on ImageNet 129 | """ 130 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 131 | **kwargs) 132 | if pretrained: 133 | # '.'s are no longer allowed in module names, but previous _DenseLayer 134 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 135 | # They are also in the checkpoints in model_urls. This pattern is used 136 | # to find such keys.
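# e.g. 'features.denseblock1.denselayer1.norm.1.weight' -> 'features.denseblock1.denselayer1.norm1.weight'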
137 |         pattern = re.compile(
138 |             r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
139 |         state_dict = model_zoo.load_url(model_urls['densenet121'])
140 |         for key in list(state_dict.keys()):
141 |             res = pattern.match(key)
142 |             if res:
143 |                 new_key = res.group(1) + res.group(2)
144 |                 state_dict[new_key] = state_dict[key]
145 |                 del state_dict[key]
146 |         model.load_state_dict(state_dict)
147 |     return model
148 | 
149 | 
150 | def densenet169(pretrained=False, **kwargs):
151 |     r"""Densenet-169 model from
152 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
153 |     Args:
154 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
155 |     """
156 |     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
157 |                      **kwargs)
158 |     if pretrained:
159 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
160 |         # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
161 |         # They are also in the checkpoints in model_urls. This pattern is used
162 |         # to find such keys.
163 |         pattern = re.compile(
164 |             r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
165 |         state_dict = model_zoo.load_url(model_urls['densenet169'])
166 |         for key in list(state_dict.keys()):
167 |             res = pattern.match(key)
168 |             if res:
169 |                 new_key = res.group(1) + res.group(2)
170 |                 state_dict[new_key] = state_dict[key]
171 |                 del state_dict[key]
172 |         model.load_state_dict(state_dict)
173 |     return model
174 | 
175 | 
176 | def densenet201(pretrained=False, **kwargs):
177 |     r"""Densenet-201 model from
178 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
179 |     Args:
180 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
181 |     """
182 |     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
183 |                      **kwargs)
184 |     if pretrained:
185 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
186 |         # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
187 |         # They are also in the checkpoints in model_urls. This pattern is used
188 |         # to find such keys.
189 |         pattern = re.compile(
190 |             r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
191 |         state_dict = model_zoo.load_url(model_urls['densenet201'])
192 |         for key in list(state_dict.keys()):
193 |             res = pattern.match(key)
194 |             if res:
195 |                 new_key = res.group(1) + res.group(2)
196 |                 state_dict[new_key] = state_dict[key]
197 |                 del state_dict[key]
198 |         model.load_state_dict(state_dict)
199 |     return model
200 | 
201 | 
202 | def densenet161(pretrained=False, **kwargs):
203 |     r"""Densenet-161 model from
204 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
205 |     Args:
206 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
207 |     """
208 |     model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
209 |                      **kwargs)
210 |     if pretrained:
211 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
212 |         # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
213 |         # They are also in the checkpoints in model_urls. This pattern is used
214 |         # to find such keys.
215 |         pattern = re.compile(
216 |             r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
217 |         state_dict = model_zoo.load_url(model_urls['densenet161'])
218 |         for key in list(state_dict.keys()):
219 |             res = pattern.match(key)
220 |             if res:
221 |                 new_key = res.group(1) + res.group(2)
222 |                 state_dict[new_key] = state_dict[key]
223 |                 del state_dict[key]
224 |         model.load_state_dict(state_dict)
225 |     return model
--------------------------------------------------------------------------------
/experiments/camelyon17/train.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from tqdm import tqdm
4 | import math
5 | 
6 | from configs.supported import process_outputs_functions, process_pseudolabels_functions
7 | from utils import save_model, save_pred, get_pred_prefix, get_model_prefix, collate_list, detach_and_clone, InfiniteDataIterator
8 | 
9 | def run_epoch(algorithm, dataset, general_logger, epoch, config, train, unlabeled_dataset=None):
10 |     if dataset['verbose']:
11 |         general_logger.write(f"\n{dataset['name']}:\n")
12 | 
13 |     if train:
14 |         algorithm.train()
15 |         torch.set_grad_enabled(True)
16 |     else:
17 |         algorithm.eval()
18 |         torch.set_grad_enabled(False)
19 | 
20 |     # Not preallocating memory is slower, but it makes it easier to handle
21 |     # different types of data loaders (which might not return exactly the
22 |     # same number of examples per epoch).
23 |     epoch_y_true = []
24 |     epoch_y_pred = []
25 |     epoch_metadata = []
26 | 
27 |     # Assert that data loaders are defined for the datasets
28 |     assert 'loader' in dataset, "A data loader must be defined for the dataset."
29 |     if unlabeled_dataset:
30 |         assert 'loader' in unlabeled_dataset, "A data loader must be defined for the unlabeled dataset."
31 | 
32 |     batches = dataset['loader']
33 |     if config.progress_bar:
34 |         batches = tqdm(batches)
35 |     last_batch_idx = len(batches) - 1
36 | 
37 |     if unlabeled_dataset:
38 |         unlabeled_data_iterator = InfiniteDataIterator(unlabeled_dataset['loader'])
39 | 
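    # InfiniteDataIterator (imported from utils) is assumed to cycle its loader
    # indefinitely, so the next() call below never raises StopIteration even
    # when the unlabeled loader is shorter than the labeled one. A rough sketch
    # of that assumed behaviour:
    #
    #     def infinite(loader):
    #         while True:
    #             yield from loader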
40 |     # Using enumerate(iterator) can sometimes leak memory in some environments (!),
41 |     # so we manually increment batch_idx
42 |     batch_idx = 0
43 |     for labeled_batch in batches:
44 |         if train:
45 |             if unlabeled_dataset:
46 |                 unlabeled_batch = next(unlabeled_data_iterator)
47 |                 batch_results = algorithm.update(labeled_batch, unlabeled_batch, is_epoch_end=(batch_idx==last_batch_idx))
48 |             else:
49 |                 batch_results = algorithm.update(labeled_batch, is_epoch_end=(batch_idx==last_batch_idx))
50 |         else:
51 |             batch_results = algorithm.evaluate(labeled_batch)
52 | 
53 |         # The result tensors come back detached; clone them again before
54 |         # accumulating so that we do not hold views into batch-sized buffers,
55 |         # which would keep those buffers alive and prevent them from being
56 |         # garbage collected properly in some versions.
57 |         epoch_y_true.append(detach_and_clone(batch_results['y_true']))
58 |         y_pred = detach_and_clone(batch_results['y_pred'])
59 |         if config.process_outputs_function is not None:
60 |             y_pred = process_outputs_functions[config.process_outputs_function](y_pred)
61 |         epoch_y_pred.append(y_pred)
62 |         epoch_metadata.append(detach_and_clone(batch_results['metadata']))
63 | 
64 |         if train:
65 |             effective_batch_idx = (batch_idx + 1) / config.gradient_accumulation_steps
66 |         else:
67 |             effective_batch_idx = batch_idx + 1
68 | 
69 |         if train and effective_batch_idx % config.log_every==0:
70 |             log_results(algorithm, dataset, general_logger, epoch, math.ceil(effective_batch_idx))
71 | 
72 |         batch_idx += 1
73 | 
74 |     epoch_y_pred = collate_list(epoch_y_pred)
75 |     epoch_y_true = collate_list(epoch_y_true)
76 |     epoch_metadata = collate_list(epoch_metadata)
77 | 
78 |     results, results_str = dataset['dataset'].eval(
79 |         epoch_y_pred,
80 |         epoch_y_true,
81 |         epoch_metadata)
82 | 
83 |     if config.scheduler_metric_split==dataset['split']:
84 |         algorithm.step_schedulers(
85 |             is_epoch=True,
86 |             metrics=results,
87 |             log_access=(not train))
88 | 
89 |     # Log after updating the scheduler, in case it needs to access the internal logs
90 |     log_results(algorithm, dataset, general_logger, epoch, math.ceil(effective_batch_idx))
91 | 
92 |     results['epoch'] = epoch
93 |     dataset['eval_logger'].log(results)
94 |     if dataset['verbose']:
95 |         general_logger.write('Epoch eval:\n')
96 |         general_logger.write(results_str)
97 | 
98 |     return results, epoch_y_pred
99 | 
100 | 
101 | def train(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric, unlabeled_dataset=None):
102 |     """
103 |     Train loop that, each epoch:
104 |     - Steps an algorithm on the datasets['train'] split and the unlabeled split
105 |     - Evaluates the algorithm on the datasets['val'] split
106 |     - Saves models / preds with frequency according to the configs
107 |     - Evaluates on any other specified splits in the configs
108 |     Assumes that the datasets dict contains labeled data.
109 |     """
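    # Expected shape of each entry in `datasets` (inferred from the accesses in
    # run_epoch above; a sketch, not an exhaustive schema):
    #
    #     datasets['train'] = {
    #         'dataset': ...,      # exposes .eval(y_pred, y_true, metadata)
    #         'loader': ...,       # a torch DataLoader over that split
    #         'name': str, 'split': str, 'verbose': bool,
    #         'eval_logger': ..., 'algo_logger': ...,
    #     }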
110 |     for epoch in range(epoch_offset, config.n_epochs):
111 |         general_logger.write(f'\nEpoch [{epoch}]:\n')
112 | 
113 |         # First run training
114 |         run_epoch(algorithm, datasets['train'], general_logger, epoch, config, train=True, unlabeled_dataset=unlabeled_dataset)
115 | 
116 |         # Then run val
117 |         val_results, y_pred = run_epoch(algorithm, datasets['val'], general_logger, epoch, config, train=False)
118 |         curr_val_metric = val_results[config.val_metric]
119 |         general_logger.write(f'Validation {config.val_metric}: {curr_val_metric:.3f}\n')
120 | 
121 |         if best_val_metric is None:
122 |             is_best = True
123 |         else:
124 |             if config.val_metric_decreasing:
125 |                 is_best = curr_val_metric < best_val_metric
126 |             else:
127 |                 is_best = curr_val_metric > best_val_metric
128 |         if is_best:
129 |             best_val_metric = curr_val_metric
130 |             general_logger.write(f'Epoch {epoch} has the best validation performance so far.\n')
131 | 
132 |         save_model_if_needed(algorithm, datasets['val'], epoch, config, is_best, best_val_metric)
133 |         save_pred_if_needed(y_pred, datasets['val'], epoch, config, is_best)
134 | 
135 |         # Then run everything else
136 |         if config.evaluate_all_splits:
137 |             additional_splits = [split for split in datasets.keys() if split not in ['train','val']]
138 |         else:
139 |             additional_splits = config.eval_splits
140 |         for split in additional_splits:
141 |             _, y_pred = run_epoch(algorithm, datasets[split], general_logger, epoch, config, train=False)
142 |             save_pred_if_needed(y_pred, datasets[split], epoch, config, is_best)
143 | 
144 |         general_logger.write('\n')
145 | 
146 | 
147 | def evaluate(algorithm, datasets, epoch, general_logger, config, is_best):
148 |     algorithm.eval()
149 |     torch.set_grad_enabled(False)
150 |     for split, dataset in datasets.items():
151 |         if (not config.evaluate_all_splits) and (split not in config.eval_splits):
152 |             continue
153 |         epoch_y_true = []
154 |         epoch_y_pred = []
155 |         epoch_metadata = []
156 |         iterator = tqdm(dataset['loader']) if config.progress_bar else dataset['loader']
157 |         for batch in iterator:
158 |             batch_results = algorithm.evaluate(batch)
159 |             epoch_y_true.append(detach_and_clone(batch_results['y_true']))
160 |             y_pred = detach_and_clone(batch_results['y_pred'])
161 |             if config.process_outputs_function is not None:
162 |                 y_pred = process_outputs_functions[config.process_outputs_function](y_pred)
163 |             epoch_y_pred.append(y_pred)
164 |             epoch_metadata.append(detach_and_clone(batch_results['metadata']))
165 | 
166 |         epoch_y_pred = collate_list(epoch_y_pred)
167 |         epoch_y_true = collate_list(epoch_y_true)
168 |         epoch_metadata = collate_list(epoch_metadata)
169 |         results, results_str = dataset['dataset'].eval(
170 |             epoch_y_pred,
171 |             epoch_y_true,
172 |             epoch_metadata)
173 | 
174 |         results['epoch'] = epoch
175 |         dataset['eval_logger'].log(results)
176 |         general_logger.write(f'Eval split {split} at epoch {epoch}:\n')
177 |         general_logger.write(results_str)
178 | 
179 |         # Skip saving train preds, since the train loader generally shuffles the data
180 |         if split != 'train':
181 |             save_pred_if_needed(epoch_y_pred, dataset, epoch, config, is_best, force_save=True)
182 | 
183 | def infer_predictions(model, loader, config):
184 |     """
185 |     Simple inference loop that performs inference using a model (not an algorithm) and returns model outputs.
186 |     Compatible with both labeled and unlabeled WILDS datasets.
187 |     """
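    # Two pseudolabel modes are handled below (the numbers are illustrative):
    # soft pseudolabels keep a probability vector, e.g. logits [2.0, 0.5, 0.1]
    # -> softmax -> ~[0.73, 0.16, 0.11]; hard pseudolabels instead go through a
    # process_pseudolabels_function (e.g. a thresholded argmax), which would
    # reduce the same logits to class 0.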
188 |     model.eval()
189 |     y_pred = []
190 |     iterator = tqdm(loader) if config.progress_bar else loader
191 |     for batch in iterator:
192 |         x = batch[0]
193 |         x = x.to(config.device)
194 |         with torch.no_grad():
195 |             output = model(x)
196 |             if not config.soft_pseudolabels and config.process_pseudolabels_function is not None:
197 |                 _, output, _, _ = process_pseudolabels_functions[config.process_pseudolabels_function](
198 |                     output,
199 |                     confidence_threshold=config.self_training_threshold if config.dataset == 'globalwheat' else 0
200 |                 )
201 |             elif config.soft_pseudolabels:
202 |                 output = torch.nn.functional.softmax(output, dim=1)
203 |         if isinstance(output, list):
204 |             y_pred.extend(detach_and_clone(output))
205 |         else:
206 |             y_pred.append(detach_and_clone(output))
207 | 
208 |     return torch.cat(y_pred, 0) if torch.is_tensor(y_pred[0]) else y_pred
209 | 
210 | def log_results(algorithm, dataset, general_logger, epoch, effective_batch_idx):
211 |     if algorithm.has_log:
212 |         log = algorithm.get_log()
213 |         log['epoch'] = epoch
214 |         log['batch'] = effective_batch_idx
215 |         dataset['algo_logger'].log(log)
216 |         if dataset['verbose']:
217 |             general_logger.write(algorithm.get_pretty_log_str())
218 |         algorithm.reset_log()
219 | 
220 | 
221 | def save_pred_if_needed(y_pred, dataset, epoch, config, is_best, force_save=False):
222 |     if config.save_pred:
223 |         prefix = get_pred_prefix(dataset, config)
224 |         if force_save or (config.save_step is not None and (epoch + 1) % config.save_step == 0):
225 |             save_pred(y_pred, prefix + f'epoch:{epoch}_pred')
226 |         if (not force_save) and config.save_last:
227 |             save_pred(y_pred, prefix + 'epoch:last_pred')
228 |         if config.save_best and is_best:
229 |             save_pred(y_pred, prefix + 'epoch:best_pred')
230 | 
231 | 
232 | def save_model_if_needed(algorithm, dataset, epoch, config, is_best, best_val_metric):
233 |     prefix = get_model_prefix(dataset, config)
234 |     if config.save_step is not None and (epoch + 1) % config.save_step == 0:
235 |         save_model(algorithm, epoch, best_val_metric, prefix + f'epoch:{epoch}_model.pth')
236 |     if config.save_last:
237 |         save_model(algorithm, epoch, best_val_metric, prefix + 'epoch:last_model.pth')
238 |     if config.save_best and is_best:
239 |         save_model(algorithm, epoch, best_val_metric, prefix + 'epoch:best_model.pth')
240 | 
--------------------------------------------------------------------------------
/experiments/camelyon17/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import urllib.request
5 | from ast import literal_eval
6 | from typing import Dict, List
7 | from urllib.parse import urlparse
8 | 
9 | import numpy as np
10 | import torch
11 | 
12 | from wilds import benchmark_datasets
13 | from wilds import get_dataset
14 | from wilds.datasets.wilds_dataset import WILDSDataset, WILDSSubset
15 | 
16 | 
17 | """
18 | Evaluate predictions for WILDS datasets.
19 | 
20 | Usage:
21 | 
22 |     python examples/evaluate.py <predictions_dir> <output_dir>
23 |     python examples/evaluate.py <predictions_dir> <output_dir> --dataset <dataset>
24 | 
25 | """
26 | 
27 | 
28 | def evaluate_all_benchmarks(predictions_dir: str, output_dir: str, root_dir: str):
29 |     """
30 |     Evaluate predictions for all the WILDS benchmarks.
31 | 
32 |     Parameters:
33 |         predictions_dir (str): Path to the directory with predictions.
Can be a URL 34 | output_dir (str): Output directory 35 | root_dir (str): The directory where datasets can be found 36 | """ 37 | all_results: Dict[str, Dict[str, Dict[str, float]]] = dict() 38 | for dataset in benchmark_datasets: 39 | try: 40 | all_results[dataset] = evaluate_benchmark( 41 | dataset, os.path.join(predictions_dir, dataset), output_dir, root_dir 42 | ) 43 | except Exception as e: 44 | print(f"Could not evaluate predictions for {dataset}:\n{str(e)}") 45 | 46 | # Write out aggregated results to output file 47 | print(f"Writing complete results to {output_dir}...") 48 | with open(os.path.join(output_dir, "all_results.json"), "w") as f: 49 | json.dump(all_results, f, indent=4) 50 | 51 | 52 | def evaluate_benchmark( 53 | dataset_name: str, predictions_dir: str, output_dir: str, root_dir: str 54 | ) -> Dict[str, Dict[str, float]]: 55 | """ 56 | Evaluate across multiple replicates for a single benchmark. 57 | 58 | Parameters: 59 | dataset_name (str): Name of the dataset. See datasets.py for the complete list of datasets. 60 | predictions_dir (str): Path to the directory with predictions. Can be a URL. 61 | output_dir (str): Output directory 62 | root_dir (str): The directory where datasets can be found 63 | 64 | Returns: 65 | Metrics as a dictionary with metrics as the keys and metric values as the values 66 | """ 67 | 68 | def get_replicates(dataset_name: str) -> List[str]: 69 | if dataset_name == "poverty": 70 | return [f"fold:{fold}" for fold in ["A", "B", "C", "D", "E"]] 71 | else: 72 | if dataset_name == "camelyon17": 73 | seeds = range(0, 10) 74 | elif dataset_name == "civilcomments": 75 | seeds = range(0, 5) 76 | else: 77 | seeds = range(0, 3) 78 | return [f"seed:{seed}" for seed in seeds] 79 | 80 | def get_prediction_file( 81 | predictions_dir: str, dataset_name: str, split: str, replicate: str 82 | ) -> str: 83 | run_id = f"{dataset_name}_split:{split}_{replicate}" 84 | for file in os.listdir(predictions_dir): 85 | if file.startswith(run_id) and ( 86 | file.endswith(".csv") or file.endswith(".pth") 87 | ): 88 | return file 89 | raise FileNotFoundError( 90 | f"Could not find CSV or pth prediction file that starts with {run_id}." 
91 | ) 92 | 93 | def get_metrics(dataset_name: str) -> List[str]: 94 | if "amazon" == dataset_name: 95 | return ["10th_percentile_acc", "acc_avg"] 96 | elif "camelyon17" == dataset_name: 97 | return ["acc_avg"] 98 | elif "civilcomments" == dataset_name: 99 | return ["acc_wg", "acc_avg"] 100 | elif "fmow" == dataset_name: 101 | return ["acc_worst_region", "acc_avg"] 102 | elif "iwildcam" == dataset_name: 103 | return ["F1-macro_all", "acc_avg"] 104 | elif "ogb-molpcba" == dataset_name: 105 | return ["ap"] 106 | elif "poverty" == dataset_name: 107 | return ["r_wg", "r_all"] 108 | elif "py150" == dataset_name: 109 | return ["acc", "Acc (Overall)"] 110 | elif "globalwheat" == dataset_name: 111 | return ["detection_acc_avg_dom"] 112 | elif "rxrx1" == dataset_name: 113 | return ["acc_avg"] 114 | else: 115 | raise ValueError(f"Invalid dataset: {dataset_name}") 116 | 117 | # Dataset will only be downloaded if it does not exist 118 | wilds_dataset: WILDSDataset = get_dataset( 119 | dataset=dataset_name, root_dir=root_dir, download=True 120 | ) 121 | splits: List[str] = list(wilds_dataset.split_dict.keys()) 122 | if "train" in splits: 123 | splits.remove("train") 124 | 125 | replicates_results: Dict[str, Dict[str, List[float]]] = dict() 126 | replicates: List[str] = get_replicates(dataset_name) 127 | metrics: List[str] = get_metrics(dataset_name) 128 | 129 | # Store the results for each replicate 130 | for split in splits: 131 | replicates_results[split] = {} 132 | for metric in metrics: 133 | replicates_results[split][metric] = [] 134 | 135 | for replicate in replicates: 136 | predictions_file = get_prediction_file( 137 | predictions_dir, dataset_name, split, replicate 138 | ) 139 | print( 140 | f"Processing split={split}, replicate={replicate}, predictions_file={predictions_file}..." 
141 |             )
142 |             full_path = os.path.join(predictions_dir, predictions_file)
143 | 
144 |             # GlobalWheat's predictions are a list of dictionaries, so it has to be handled separately
145 |             if dataset_name == "globalwheat":
146 |                 metric_results: Dict[str, float] = evaluate_replicate_for_globalwheat(
147 |                     wilds_dataset, split, full_path
148 |                 )
149 |             else:
150 |                 predicted_labels: torch.Tensor = get_predictions(full_path)
151 |                 if dataset_name == "poverty":
152 |                     # Poverty is special because we need to pass in the fold when calling `get_dataset`,
153 |                     # e.g., {"fold": "A"}
154 |                     dataset_kwargs = {"fold": replicate.split(":")[1]}
155 |                     wilds_dataset: WILDSDataset = get_dataset(
156 |                         dataset=dataset_name, root_dir=root_dir, download=True, **dataset_kwargs
157 |                     )
158 | 
159 |                 metric_results = evaluate_replicate(
160 |                     wilds_dataset, split, predicted_labels
161 |                 )
162 |             for metric in metrics:
163 |                 replicates_results[split][metric].append(metric_results[metric])
164 | 
165 |     aggregated_results: Dict[str, Dict[str, float]] = dict()
166 | 
167 |     # Aggregate results of replicates
168 |     for split in splits:
169 |         aggregated_results[split] = {}
170 |         for metric in metrics:
171 |             replicates_metric_values: List[float] = replicates_results[split][metric]
172 |             aggregated_results[split][f"{metric}_std"] = np.std(
173 |                 replicates_metric_values, ddof=1
174 |             )
175 |             aggregated_results[split][metric] = np.mean(replicates_metric_values)
176 | 
177 |     # Write out aggregated results to output file
178 |     print(f"Writing aggregated results for {dataset_name} to {output_dir}...")
179 |     with open(os.path.join(output_dir, f"{dataset_name}_results.json"), "w") as f:
180 |         json.dump(aggregated_results, f, indent=4)
181 | 
182 |     return aggregated_results
183 | 
184 | 
185 | def evaluate_replicate(
186 |     dataset: WILDSDataset, split: str, predicted_labels: torch.Tensor
187 | ) -> Dict[str, float]:
188 |     """
189 |     Evaluate the given predictions and return the appropriate metrics.
190 | 
191 |     Parameters:
192 |         dataset (WILDSDataset): A WILDS Dataset
193 |         split (str): split we are evaluating on
194 |         predicted_labels (torch.Tensor): Predictions
195 | 
196 |     Returns:
197 |         Metrics as a dictionary with metrics as the keys and metric values as the values
198 |     """
199 |     # Evaluate the predictions on the requested split of the given dataset
200 |     subset: WILDSSubset = dataset.get_subset(split)
201 |     metadata: torch.Tensor = subset.metadata_array
202 |     true_labels = subset.y_array
203 |     if predicted_labels.shape != true_labels.shape:
204 |         predicted_labels.unsqueeze_(-1)
205 |     return dataset.eval(predicted_labels, true_labels, metadata)[0]
206 | 
207 | 
208 | def evaluate_replicate_for_globalwheat(
209 |     dataset: WILDSDataset, split: str, path_to_predictions: str
210 | ) -> Dict[str, float]:
211 |     predicted_labels = torch.load(path_to_predictions)
212 |     subset: WILDSSubset = dataset.get_subset(split)
213 |     metadata: torch.Tensor = subset.metadata_array
214 |     true_labels = [subset.dataset.y_array[idx] for idx in subset.indices]
215 |     return dataset.eval(predicted_labels, true_labels, metadata)[0]
216 | 
217 | 
218 | def get_predictions(path: str) -> torch.Tensor:
219 |     """
220 |     Extract out the predictions from the file at path.
221 | 
222 |     Parameters:
223 |         path (str): Path to the file that has the predicted labels. Can be a URL.
224 | 
225 |     Returns:
226 |         Tensor representing predictions
227 |     """
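    # Example of the expected prediction-file format (assumed from the
    # literal_eval parsing below): one label per line, e.g. a CSV containing
    #
    #     0
    #     1
    #     1
    #
    # parses to tensor([0, 1, 1]).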
228 |     if is_path_url(path):
229 |         data = urllib.request.urlopen(path)
230 |     else:
231 |         # Use a context manager so the file handle is always closed.
232 |         with open(path, mode="r") as file:
233 |             data = file.readlines()
234 | 
235 |     predicted_labels = [literal_eval(line.rstrip()) for line in data if line.rstrip()]
236 |     return torch.from_numpy(np.array(predicted_labels))
237 | 
238 | 
239 | def is_path_url(path: str) -> bool:
240 |     """
241 |     Returns True if the path is a URL.
242 |     """
243 |     try:
244 |         result = urlparse(path)
245 |         return all([result.scheme, result.netloc, result.path])
246 |     except ValueError:
247 |         return False
248 | 
249 | 
250 | def main():
251 |     if args.dataset:
252 |         evaluate_benchmark(
253 |             args.dataset, args.predictions_dir, args.output_dir, args.root_dir
254 |         )
255 |     else:
256 |         print("A dataset was not specified. Evaluating for all WILDS datasets...")
257 |         evaluate_all_benchmarks(args.predictions_dir, args.output_dir, args.root_dir)
258 |     print("\nDone.")
259 | 
260 | 
261 | if __name__ == "__main__":
262 |     parser = argparse.ArgumentParser(
263 |         description="Evaluate predictions for WILDS datasets."
264 |     )
265 |     parser.add_argument(
266 |         "predictions_dir",
267 |         type=str,
268 |         help="Path to prediction CSV or pth files.",
269 |     )
270 |     parser.add_argument(
271 |         "output_dir",
272 |         type=str,
273 |         help="Path to output directory.",
274 |     )
275 |     parser.add_argument(
276 |         "--dataset",
277 |         type=str,
278 |         choices=benchmark_datasets,
279 |         help="WILDS dataset to evaluate for.",
280 |     )
281 |     parser.add_argument(
282 |         "--root-dir",
283 |         type=str,
284 |         default="data",
285 |         help="The directory where the datasets can be found (or should be downloaded to, if they do not exist).",
286 |     )
287 | 
288 |     # Parse args and run this script
289 |     args = parser.parse_args()
290 |     main()
291 | 
--------------------------------------------------------------------------------