├── AttentionMap.png ├── CS-MIL_Docker ├── CS-MIL_docker_commandline.txt ├── Dockerfile ├── requirements.txt └── src │ ├── DeepAttnMISL_CS_MIL.py │ ├── GCA512_0411_train4w.yaml │ ├── MIL_dataloader_csv_MAg_clustering_pair_stack.py │ ├── arguments.py │ ├── augmentations │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── byol_aug.cpython-36.pyc │ │ ├── byol_aug.cpython-38.pyc │ │ ├── eval_aug.cpython-36.pyc │ │ ├── eval_aug.cpython-38.pyc │ │ ├── simclr_aug.cpython-36.pyc │ │ ├── simclr_aug.cpython-38.pyc │ │ ├── simsiam_aug.cpython-36.pyc │ │ └── simsiam_aug.cpython-38.pyc │ ├── byol_aug.py │ ├── eval_aug.py │ ├── gaussian_blur.py │ ├── simclr_aug.py │ ├── simsiam_aug.py │ └── swav_aug.py │ ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── random_dataset.cpython-36.pyc │ │ └── random_dataset.cpython-38.pyc │ └── random_dataset.py │ ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── byol.cpython-36.pyc │ │ ├── byol.cpython-38.pyc │ │ ├── simclr.cpython-36.pyc │ │ ├── simclr.cpython-38.pyc │ │ ├── simsiam.cpython-36.pyc │ │ └── simsiam.cpython-38.pyc │ ├── backbones │ │ ├── TCGA_eval_onehot.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── cifar_resnet_1.cpython-36.pyc │ │ │ ├── cifar_resnet_1.cpython-38.pyc │ │ │ ├── cifar_resnet_2.cpython-36.pyc │ │ │ └── cifar_resnet_2.cpython-38.pyc │ │ ├── cifar_resnet_1.py │ │ └── cifar_resnet_2.py │ ├── byol.py │ ├── simclr.py │ ├── simsiam.py │ └── swav.py │ ├── optimizers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── larc.cpython-36.pyc │ │ ├── larc.cpython-38.pyc │ │ ├── lars.cpython-36.pyc │ │ ├── lars.cpython-38.pyc │ │ ├── lars_simclr.cpython-36.pyc │ │ ├── lars_simclr.cpython-38.pyc │ │ ├── lr_scheduler.cpython-36.pyc │ │ └── lr_scheduler.cpython-38.pyc │ ├── larc.py │ ├── lars.py │ ├── lars_simclr.py │ └── lr_scheduler.py │ ├── run_inference.py │ └── tools │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── accuracy.cpython-36.pyc │ ├── accuracy.cpython-38.pyc │ ├── average_meter.cpython-36.pyc │ ├── average_meter.cpython-38.pyc │ ├── file_exist_fn.cpython-36.pyc │ ├── file_exist_fn.cpython-38.pyc │ ├── knn_monitor.cpython-36.pyc │ ├── knn_monitor.cpython-38.pyc │ ├── logger.cpython-36.pyc │ ├── logger.cpython-38.pyc │ ├── plotter.cpython-36.pyc │ └── plotter.cpython-38.pyc │ ├── accuracy.py │ ├── average_meter.py │ ├── file_exist_fn.py │ ├── knn_monitor.py │ ├── logger.py │ └── plotter.py ├── Cross-modality.png ├── Cross-scale.png ├── Cross_modality.png ├── Emb_Clustering_Code ├── LICENSE ├── MicrosoftVision.ResNet50.tar ├── arguments.py ├── configs │ └── GCA512_0411_train4w.yaml ├── create_kmeans_features_local_multi-resolution.py ├── create_kmeans_features_local_singleresolution.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── random_dataset.cpython-36.pyc │ │ └── random_dataset.cpython-38.pyc │ └── random_dataset.py ├── get_features_simsiam_1024.py ├── get_features_simsiam_256.py ├── get_features_simsiam_512.py ├── get_features_simsiam_multi-resolution.py ├── main_mixprecision.py └── optimizers │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── 
larc.cpython-36.pyc │ ├── larc.cpython-38.pyc │ ├── lars.cpython-36.pyc │ ├── lars.cpython-38.pyc │ ├── lars_simclr.cpython-36.pyc │ ├── lars_simclr.cpython-38.pyc │ ├── lr_scheduler.cpython-36.pyc │ └── lr_scheduler.cpython-38.pyc │ ├── larc.py │ ├── lars.py │ ├── lars_simclr.py │ └── lr_scheduler.py ├── LICENSE ├── Pipeline.png ├── README.md ├── Toydataset.png ├── ToydatasetResults.png ├── Toydataset_Code ├── cs-mil-toydataset │ ├── AttentionMapping_dataset1.py │ ├── AttentionMapping_dataset2.py │ ├── Background_filter.py │ ├── Boxplot_dataset1.py │ ├── Boxplot_dataset1_classmark.py │ ├── Boxplot_dataset2_classmark.py │ ├── Classifier_model_MAg.py │ ├── DeepAttnMISL_model.py │ ├── DeepAttnMISL_model_no21.py │ ├── DeepAttnMISL_model_no21_attentionscore.py │ ├── DeepAttnMISL_model_no21_withResnet.py │ ├── GetDataList.py │ ├── LICENSE │ ├── MILDataset.py │ ├── MIL_dataloader_csv_data1.py │ ├── MIL_dataloader_csv_data2.py │ ├── MIL_dataloader_csv_image.py │ ├── MIL_dataloader_image.py │ ├── MIL_global_Stage1_Testing_all.py │ ├── MIL_main.py │ ├── MIL_main_DeepSurv.py │ ├── MIL_main_DeepSurv_batch_dataset1_getattention.py │ ├── MIL_main_DeepSurv_batch_dataset2_getattention.py │ ├── MIL_main_DeepSurv_dataset1.py │ ├── MIL_main_DeepSurv_dataset2.py │ ├── Patch_select_MIL_dataloader_csv_MAg_clustering_check_stack.py │ ├── Patch_select_MIL_global_Stage1_attention_cross.py │ ├── ROCAUC_dataset1_classmark.py │ ├── ROCAUC_dataset2_classmark.py │ ├── __init__.py │ ├── create_kmeans.py │ ├── data_split.py │ ├── data_split_tranvalbalanced.py │ ├── dataloader.py │ └── mnist_bags_loader.py └── data_processing │ ├── MIL_bag_generation.py │ ├── MIL_bag_generation_forattention_data1.py │ ├── MIL_bag_generation_forattention_data2.py │ ├── MIL_main_DeepSurv_batch_dataset1_getattention.py │ └── MIL_main_DeepSurv_batch_dataset2_getattention.py └── Train_Test_Code ├── Background_filter.py ├── Classifier_model_MAg.py ├── DeepAttnMISL_model.py ├── DeepAttnMISL_model_no21.py ├── GetDataList.py ├── MILDataset.py ├── MIL_dataloader.py ├── MIL_dataloader_csv.py ├── MIL_dataloader_csv_MAg.py ├── MIL_dataloader_csv_MAg_clustering.py ├── MIL_dataloader_csv_MAg_clustering_pair_stack.py ├── MIL_dataloader_csv_clustering.py ├── MIL_dataloader_csv_clustering_pair_stack.py ├── MIL_global_Stage1_Testing.py ├── MIL_global_Stage1_Training.py ├── Patch_select_MIL_dataloader_csv_MAg_clustering_check_stack.py ├── Patch_select_MIL_global_Stage1_attention_cross.py ├── Regions_to_multiscale_patches.py ├── __init__.py ├── create_kmeans.py ├── data_split.py ├── data_split_tranvalbalanced.py ├── dataloader.py ├── mnist_bags_loader.py ├── model.py └── model3.py /AttentionMap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/AttentionMap.png -------------------------------------------------------------------------------- /CS-MIL_Docker/CS-MIL_docker_commandline.txt: -------------------------------------------------------------------------------- 1 | # create the docker 2 | docker build -f Dockerfile -t cs-mil . 
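# log in to Docker Hub before tagging and pushing (the push below publishes the image under the ddrrnn123 account)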
3 | docker login 4 | docker tag cs-mil ddrrnn123/cs-mil:2.0 5 | docker push ddrrnn123/cs-mil:2.0 6 | 7 | # run the code online with gpu 8 | docker run --rm -v /Data2/CS-MIL_data/input:/input/:ro -v /Data2/CS-MIL_data/output:/output --gpus all -it ddrrnn123/cs-mil:2.0 9 | 10 | 11 | # run the docker locally with gpu 12 | docker run --rm -v /Data2/CS-MIL_data/input:/input/:ro -v /Data2/CS-MIL_data/output:/output --gpus all -it cs-mil:2.0 13 | 14 | -------------------------------------------------------------------------------- /CS-MIL_Docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ## Pull from existing image 2 | FROM nvcr.io/nvidia/pytorch:21.05-py3 3 | 4 | ## Copy requirements 5 | COPY ./requirements.txt . 6 | 7 | ## Install Python packages in Docker image 8 | 9 | RUN pip3 install --upgrade pip 10 | RUN pip3 install -r requirements.txt 11 | RUN pip3 install "opencv-python-headless<4.3" 12 | RUN pip3 install openslide-python 13 | 14 | RUN apt-get update \ 15 | && DEBIAN_FRONTEND="noninteractive" apt-get install -y libopenslide0 \ 16 | && rm -rf /var/lib/apt/lists/* 17 | 18 | #RUN export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libffi.so.7 19 | RUN export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libopenslide.so.0 20 | 21 | 22 | ## Copy all files (here "./src/run_inference.py") 23 | COPY ./ ./ 24 | 25 | 26 | RUN mkdir /myhome/ 27 | COPY ./src /myhome 28 | RUN chmod -R 777 /myhome 29 | 30 | 31 | ## Execute the inference command 32 | CMD ["./src/run_inference.py"] 33 | ENTRYPOINT ["python3"] 34 | -------------------------------------------------------------------------------- /CS-MIL_Docker/requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2023.5.7 2 | charset-normalizer==3.2.0 3 | cmake==3.26.4 4 | contourpy==1.1.0 5 | cycler==0.11.0 6 | filelock==3.12.2 7 | fonttools==4.41.0 8 | idna==3.4 9 | imageio==2.31.1 10 | imgaug==0.4.0 11 | importlib-resources==6.0.0 12 | Jinja2==3.1.2 13 | joblib==1.3.1 14 | kiwisolver==1.4.4 15 | lazy_loader==0.3 16 | lit==16.0.6 17 | MarkupSafe==2.1.3 18 | matplotlib==3.7.2 19 | mpmath==1.3.0 20 | networkx==3.1 21 | numpy==1.24.4 22 | nvidia-cublas-cu11==11.10.3.66 23 | nvidia-cuda-cupti-cu11==11.7.101 24 | nvidia-cuda-nvrtc-cu11==11.7.99 25 | nvidia-cuda-runtime-cu11==11.7.99 26 | nvidia-cudnn-cu11==8.5.0.96 27 | nvidia-cufft-cu11==10.9.0.58 28 | nvidia-curand-cu11==10.2.10.91 29 | nvidia-cusolver-cu11==11.4.0.1 30 | nvidia-cusparse-cu11==11.7.4.91 31 | nvidia-nccl-cu11==2.14.3 32 | nvidia-nvtx-cu11==11.7.91 33 | opencv-python 34 | openslide-python==1.2.0 35 | packaging==23.1 36 | pandas==2.0.3 37 | Pillow==10.0.0 38 | protobuf==4.23.4 39 | pyparsing==3.0.9 40 | python-dateutil==2.8.2 41 | pytz==2023.3 42 | PyWavelets==1.4.1 43 | PyYAML==6.0 44 | requests==2.31.0 45 | scikit-image==0.21.0 46 | scikit-learn==1.3.0 47 | scipy==1.10.1 48 | shapely==2.0.1 49 | six==1.16.0 50 | sympy==1.12 51 | tensorboardX==2.6.1 52 | threadpoolctl==3.2.0 53 | tifffile==2023.7.10 54 | torch==2.0.1 55 | torchaudio==2.0.2 56 | torchvision==0.15.2 57 | tqdm==4.65.0 58 | triton==2.0.0 59 | typing_extensions==4.7.1 60 | tzdata==2023.3 61 | urllib3==2.0.3 62 | zipp==3.16.2 63 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/DeepAttnMISL_CS_MIL.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model definition of DeepAttnMISL 3 | 4 | If this work is useful for your research, please 
consider citing our papers: 5 | 6 | [1] "Whole Slide Images based Cancer Survival Prediction using Attention Guided Deep Multiple Instance Learning Networks" 7 | Jiawen Yao, Xinliang Zhu, Jitendra Jonnagaddala, Nicholas Hawkins, Junzhou Huang, 8 | Medical Image Analysis, Available online 19 July 2020, 101789 9 | 10 | [2] "Deep Multi-instance Learning for Survival Prediction from Whole Slide Images", In MICCAI 2019 11 | 12 | """ 13 | 14 | import torch.nn as nn 15 | import torch 16 | 17 | class DeepAttnMIL_Surv(nn.Module): 18 | """ 19 | Deep AttnMISL Model definition 20 | """ 21 | 22 | def __init__(self, cluster_num): 23 | super(DeepAttnMIL_Surv, self).__init__() 24 | self.embedding_net = nn.Sequential(nn.Conv2d(2048, 64, 1), 25 | nn.ReLU(), 26 | nn.AdaptiveAvgPool2d((1,1)) 27 | ) 28 | 29 | self.res_attention = nn.Sequential( 30 | nn.Conv2d(64, 32, 1), # V 31 | nn.ReLU(), 32 | nn.Conv2d(32, 1, 1), 33 | ) 34 | 35 | self.attention = nn.Sequential( 36 | nn.Linear(64, 32), # V 37 | nn.Tanh(), 38 | nn.Linear(32, 1) # W 39 | ) 40 | 41 | self.fc6 = nn.Sequential( 42 | nn.Linear(64, 32), 43 | nn.ReLU(), 44 | nn.Dropout(p=0.5), 45 | nn.Linear(32, 1), 46 | nn.Sigmoid() 47 | ) 48 | self.cluster_num = cluster_num 49 | 50 | 51 | def masked_softmax(self, x, mask=None): 52 | """ 53 | Performs masked softmax, as simply masking post-softmax can be 54 | inaccurate 55 | :param x: [batch_size, num_items] 56 | :param mask: [batch_size, num_items] 57 | :return: 58 | """ 59 | if mask is not None: 60 | mask = mask.float() 61 | if mask is not None: 62 | x_masked = x * mask + (1 - 1 / (mask+1e-5)) 63 | else: 64 | x_masked = x 65 | x_max = x_masked.max(1)[0] 66 | x_exp = (x - x_max.unsqueeze(-1)).exp() 67 | if mask is not None: 68 | x_exp = x_exp * mask.float() 69 | return x_exp / x_exp.sum(1).unsqueeze(-1) 70 | 71 | 72 | def forward(self, x, mask): 73 | 74 | # x is a list of tensors, one per cluster 75 | res = [] 76 | for i in range(self.cluster_num): 77 | hh = x[i].type(torch.FloatTensor).to("cuda") 78 | output1 = self.embedding_net(hh[:,:,0:1,:]) 79 | output2 = self.embedding_net(hh[:,:,1:2,:]) 80 | output3 = self.embedding_net(hh[:,:,2:3,:]) 81 | output = torch.cat([output1, output2, output3],2) 82 | res_attention = self.res_attention(output).squeeze(-1) 83 | 84 | final_output = torch.matmul(output.squeeze(-1), torch.transpose(res_attention,2,1)).squeeze(-1) 85 | res.append(final_output) 86 | 87 | h = torch.cat(res) 88 | 89 | b = h.size(0) 90 | c = h.size(1) 91 | 92 | h = h.view(b, c) 93 | 94 | A = self.attention(h) 95 | A = torch.transpose(A, 1, 0) # KxN 96 | 97 | A = self.masked_softmax(A, mask) 98 | 99 | M = torch.mm(A, h) # KxL 100 | 101 | Y_pred = self.fc6(M) 102 | 103 | return Y_pred 104 | 105 |
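# --- Hypothetical smoke test for DeepAttnMIL_Surv (an editor's sketch, not part of the original file). ---
# The tensor layout is inferred from forward() above: each x[i] holds one cluster of a bag as
# [1, 2048 ResNet features, 3 scales, n_patches]; mask flags which clusters are non-empty.
# Assumes a CUDA device, since forward() moves its inputs to "cuda".
if __name__ == "__main__":
    cluster_num = 8
    model = DeepAttnMIL_Surv(cluster_num).to("cuda")
    bag = [torch.randn(1, 2048, 3, 5) for _ in range(cluster_num)]  # 5 patches per cluster
    mask = torch.ones(1, cluster_num).to("cuda")  # 1 = cluster present in this bag
    print(model(bag, mask).shape)  # torch.Size([1, 1]): a sigmoid bag-level score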
-------------------------------------------------------------------------------- /CS-MIL_Docker/src/GCA512_0411_train4w.yaml: -------------------------------------------------------------------------------- 1 | name: simsiam-GCA-b6128s256_ori1024patch 2 | dataset: 3 | name: random 4 | image_size: 256 5 | num_workers: 8 6 | 7 | model: 8 | name: simsiam 9 | backbone: resnet50 10 | proj_layers: 2 11 | 12 | train: 13 | optimizer: 14 | name: sgd 15 | weight_decay: 0.0001 16 | momentum: 0.9 17 | warmup_epochs: 10 18 | warmup_lr: 0 19 | base_lr: 0.05 20 | final_lr: 0 21 | num_epochs: 800 # this parameter influences the lr decay 22 | stop_at_epoch: 200 # has to be smaller than num_epochs 23 | batch_size: 128 24 | save_interval: 1 25 | knn_monitor: False # knn monitor will take more time 26 | knn_interval: 1 27 | knn_k: 200 28 | eval: # linear evaluation; False will turn off automatic evaluation after training 29 | optimizer: 30 | name: sgd 31 | weight_decay: 0 32 | momentum: 0.9 33 | warmup_lr: 0 34 | warmup_epochs: 0 35 | base_lr: 30 36 | final_lr: 0 37 | batch_size: 128 38 | num_epochs: 200 39 | 40 | logger: 41 | tensorboard: True 42 | matplotlib: True 43 | 44 | seed: null # None type for yaml file 45 | # two things might lead to stochastic behavior other than the seed: 46 | # worker_init_fn from the dataloader and torch.nn.functional.interpolate 47 | # (keep this in mind if you want fully deterministic runs) 48 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | 5 | import numpy as np 6 | import torch 7 | import random 8 | 9 | import re 10 | import yaml 11 | 12 | import shutil 13 | import warnings 14 | 15 | from datetime import datetime 16 | 17 | 18 | class Namespace(object): 19 | def __init__(self, somedict): 20 | for key, value in somedict.items(): 21 | assert isinstance(key, str) and re.match("[A-Za-z_-]", key) 22 | if isinstance(value, dict): 23 | self.__dict__[key] = Namespace(value) 24 | else: 25 | self.__dict__[key] = value 26 | 27 | def __getattr__(self, attribute): 28 | 29 | raise AttributeError(f"Cannot find {attribute} in namespace. Please define {attribute} in your config file (xxx.yaml)!") 30 | 31 | 32 | def set_deterministic(seed): 33 | # seed by default is None 34 | if seed is not None: 35 | print(f"Deterministic with seed = {seed}") 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed(seed) 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | 43 | def get_args(): 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('-c', '--config-file', type=str, help="xxx.yaml", default = '/myhome/GCA512_0411_train4w.yaml') 46 | parser.add_argument('--debug', action='store_true') 47 | parser.add_argument('--debug_subset_size', type=int, default=8) 48 | parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web") 49 | parser.add_argument('--data_dir', type=str, default=os.getenv('DATA')) 50 | parser.add_argument('--log_dir', type=str, default=os.getenv('LOG')) 51 | parser.add_argument('--ckpt_dir', type=str, default=os.getenv('CHECKPOINT')) 52 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 53 | parser.add_argument('--eval_from', type=str, default=None) 54 | parser.add_argument('--hide_progress', action='store_true') 55 | args = parser.parse_args() 56 | 57 | 58 | with open(args.config_file, 'r') as f: 59 | for key, value in Namespace(yaml.load(f, Loader=yaml.FullLoader)).__dict__.items(): 60 | vars(args)[key] = value 61 | 62 | if args.debug: 63 | if args.train: 64 | args.train.batch_size = 2 # use a tiny batch while debugging (the original bare expression was a no-op) 65 | args.train.num_epochs = 1 66 | args.train.stop_at_epoch = 1 67 | if args.eval: 68 | args.eval.batch_size = 2 69 | args.eval.num_epochs = 1 # train only one epoch 70 | args.dataset.num_workers = 0 71 | 72 | args.log_dir = '' 73 | args.data_dir = '' 74 | args.ckpt_dir = '' 75 | args.name = '' 76 | 77 | assert not None in [args.log_dir, args.data_dir, args.ckpt_dir, args.name] 78 | 79 | args.log_dir = os.path.join(args.log_dir, 'in-progress_'+datetime.now().strftime('%m%d%H%M%S_')+args.name) 80 | 81 | # os.makedirs(args.log_dir, exist_ok=False) 82
| # print(f'creating file {args.log_dir}') 83 | # os.makedirs(args.ckpt_dir, exist_ok=True) 84 | # 85 | # shutil.copy2(args.config_file, args.log_dir) 86 | set_deterministic(args.seed) 87 | 88 | 89 | vars(args)['aug_kwargs'] = { 90 | 'name':args.model.name, 91 | 'image_size': args.dataset.image_size 92 | } 93 | vars(args)['dataset_kwargs'] = { 94 | 'dataset':args.dataset.name, 95 | 'data_dir': args.data_dir, 96 | 'download':args.download, 97 | 'debug_subset_size': args.debug_subset_size if args.debug else None, 98 | } 99 | vars(args)['dataloader_kwargs'] = { 100 | 'drop_last': True, 101 | 'pin_memory': True, 102 | 'num_workers': args.dataset.num_workers, 103 | } 104 | 105 | return args 106 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .simsiam_aug import SimSiamTransform 2 | from .eval_aug import Transform_single 3 | from .byol_aug import BYOL_transform 4 | from .simclr_aug import SimCLRTransform 5 | def get_aug(name='byol', image_size=256, train=True, train_classifier=None): 6 | 7 | if train==True: 8 | if name == 'simsiam': 9 | augmentation = SimSiamTransform(image_size) 10 | elif name == 'byol': 11 | # augmentation = BYOL_transform(image_size) 12 | augmentation = SimSiamTransform(image_size) 13 | elif name == 'simclr': 14 | augmentation = SimCLRTransform(image_size) 15 | else: 16 | raise NotImplementedError 17 | elif train==False: 18 | if train_classifier is None: 19 | raise Exception 20 | augmentation = Transform_single(image_size, train=train_classifier) 21 | else: 22 | raise Exception 23 | 24 | return augmentation 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/augmentations/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/augmentations/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/__pycache__/byol_aug.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/augmentations/__pycache__/byol_aug.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/__pycache__/byol_aug.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/augmentations/__pycache__/byol_aug.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/__pycache__/eval_aug.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/augmentations/__pycache__/eval_aug.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/__pycache__/eval_aug.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/augmentations/__pycache__/eval_aug.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/__pycache__/simclr_aug.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/augmentations/__pycache__/simclr_aug.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/__pycache__/simclr_aug.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/augmentations/__pycache__/simclr_aug.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/__pycache__/simsiam_aug.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/augmentations/__pycache__/simsiam_aug.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/__pycache__/simsiam_aug.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/augmentations/__pycache__/simsiam_aug.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/byol_aug.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image, ImageOps 3 | try: 4 | from torchvision.transforms import GaussianBlur 5 | except ImportError: 6 | from .gaussian_blur import GaussianBlur 7 | transforms.GaussianBlur = GaussianBlur 8 | 9 | imagenet_norm = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]] 10 | 11 | class BYOL_transform: # Table 6 12 | def __init__(self, image_size, normalize=imagenet_norm): 13 | 14 | self.transform1 = transforms.Compose([ 15 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 16 | transforms.RandomHorizontalFlip(p=0.5), 17 | transforms.RandomApply([transforms.ColorJitter(0.4,0.4,0.2,0.1)], p=0.8), 18 | transforms.RandomGrayscale(p=0.2), 19 | transforms.GaussianBlur(kernel_size=image_size//20*2+1, sigma=(0.1, 2.0)), # simclr paper gives the kernel size. 
20 | # Kernel size has to be odd positive number with torchvision 21 | transforms.ToTensor(), 22 | transforms.Normalize(*normalize) 23 | ]) 24 | self.transform2 = transforms.Compose([ 25 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 26 | transforms.RandomHorizontalFlip(p=0.5), 27 | transforms.RandomApply([transforms.ColorJitter(0.4,0.4,0.2,0.1)], p=0.8), 28 | transforms.RandomGrayscale(p=0.2), 29 | # transforms.RandomApply([GaussianBlur(kernel_size=int(0.1 * image_size))], p=0.1), 30 | transforms.RandomApply([transforms.GaussianBlur(kernel_size=image_size//20*2+1, sigma=(0.1, 2.0))], p=0.1), 31 | transforms.RandomApply([Solarization()], p=0.2), 32 | 33 | transforms.ToTensor(), 34 | transforms.Normalize(*normalize) 35 | ]) 36 | 37 | 38 | def __call__(self, x): 39 | x1 = self.transform1(x) 40 | x2 = self.transform2(x) 41 | return x1, x2 42 | 43 | 44 | class Transform_single: 45 | def __init__(self, image_size, train, normalize=imagenet_norm): 46 | self.denormalize = Denormalize(*imagenet_norm) 47 | if train == True: 48 | self.transform = transforms.Compose([ 49 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 50 | transforms.RandomHorizontalFlip(), 51 | transforms.ToTensor(), 52 | transforms.Normalize(*normalize) 53 | ]) 54 | else: 55 | self.transform = transforms.Compose([ 56 | transforms.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256 57 | transforms.CenterCrop(image_size), 58 | transforms.ToTensor(), 59 | transforms.Normalize(*normalize) 60 | ]) 61 | 62 | def __call__(self, x): 63 | return self.transform(x) 64 | 65 | 66 | 67 | class Solarization(): 68 | # ImageFilter 69 | def __init__(self, threshold=128): 70 | self.threshold = threshold 71 | def __call__(self, image): 72 | return ImageOps.solarize(image, self.threshold) 73 | 74 | 75 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/eval_aug.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | 4 | imagenet_norm = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]] 5 | 6 | class Transform_single(): 7 | def __init__(self, image_size, train, normalize=imagenet_norm): 8 | if train == True: 9 | self.transform = transforms.Compose([ 10 | # transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC), 11 | # transforms.RandomHorizontalFlip(), 12 | transforms.Resize(image_size), 13 | transforms.ToTensor(), 14 | transforms.Normalize(*normalize) 15 | ]) 16 | else: 17 | self.transform = transforms.Compose([ 18 | # transforms.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256 19 | # transforms.CenterCrop(image_size), 20 | transforms.Resize(image_size), 21 | transforms.ToTensor(), 22 | transforms.Normalize(*normalize) 23 | ]) 24 | 25 | def __call__(self, x): 26 | return self.transform(x) 27 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/simclr_aug.py: -------------------------------------------------------------------------------- 1 | 2 | import torchvision.transforms as T 3 | try: 4 | from torchvision.transforms import GaussianBlur 5 | except ImportError: 6 | from .gaussian_blur import GaussianBlur 7 | T.GaussianBlur = GaussianBlur 8 | 9 | imagenet_mean_std = [[0.485, 0.456, 0.406],[0.229, 
0.224, 0.225]] 10 | 11 | class SimCLRTransform(): 12 | def __init__(self, image_size, mean_std=imagenet_mean_std, s=1.0): 13 | image_size = 224 if image_size is None else image_size 14 | self.transform = T.Compose([ 15 | T.RandomResizedCrop(image_size, scale=(0.2, 1.0)), 16 | T.RandomHorizontalFlip(), 17 | T.RandomApply([T.ColorJitter(0.8*s,0.8*s,0.8*s,0.2*s)], p=0.8), 18 | T.RandomGrayscale(p=0.2), 19 | T.RandomApply([T.GaussianBlur(kernel_size=image_size//20*2+1, sigma=(0.1, 2.0))], p=0.5), 20 | # We blur the image 50% of the time using a Gaussian kernel. We randomly sample σ ∈ [0.1, 2.0], and the kernel size is set to be 10% of the image height/width. 21 | T.ToTensor(), 22 | T.Normalize(*mean_std) 23 | ]) 24 | def __call__(self, x): 25 | x1 = self.transform(x) 26 | x2 = self.transform(x) 27 | return x1, x2 28 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/simsiam_aug.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | import numpy as np 3 | import torch 4 | try: 5 | from torchvision.transforms import GaussianBlur 6 | except ImportError: 7 | from .gaussian_blur import GaussianBlur 8 | T.GaussianBlur = GaussianBlur 9 | 10 | imagenet_mean_std = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]] 11 | 12 | class SimSiamTransform(): 13 | def __init__(self, image_size, mean_std=imagenet_mean_std): 14 | image_size = 128 if image_size is None else image_size # by default simsiam use image size 224 15 | p_blur = 0.5 if image_size > 32 else 0 # exclude cifar 16 | # the paper didn't specify this, feel free to change this value 17 | # I use the setting from simclr which is 50% chance applying the gaussian blur 18 | # the 32 is prepared for cifar training where they disabled gaussian blur 19 | self.transform = T.Compose([ 20 | T.RandomResizedCrop(image_size, scale=(0.2, 1.0)), 21 | T.RandomHorizontalFlip(), 22 | T.RandomApply([T.ColorJitter(0.4,0.4,0.4,0.1)], p=0.8), 23 | T.RandomGrayscale(p=0.2), 24 | T.RandomApply([T.GaussianBlur(kernel_size=image_size//20*2+1, sigma=(0.1, 2.0))], p=p_blur), 25 | T.ToTensor(), 26 | T.Normalize(*mean_std) 27 | ]) 28 | 29 | # self.transform = T.Compose([ 30 | # T.Resize(image_size), 31 | # # T.Resize(224), 32 | # T.ToTensor(), 33 | # T.Normalize(*mean_std) 34 | # ]) 35 | def __call__(self, x): 36 | x1 = self.transform(x) 37 | x2 = self.transform(x) 38 | return x1, x2 39 | 40 | 41 | def to_pil_image(pic, mode=None): 42 | """Convert a tensor or an ndarray to PIL Image. 43 | 44 | See :class:`~torchvision.transforms.ToPILImage` for more details. 45 | 46 | Args: 47 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 48 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 49 | 50 | .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes 51 | 52 | Returns: 53 | PIL Image: Image converted to PIL Image. 54 | """ 55 | if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): 56 | raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) 57 | 58 | elif isinstance(pic, torch.Tensor): 59 | if pic.ndimension() not in {2, 3}: 60 | raise ValueError('pic should be 2/3 dimensional. 
Got {} dimensions.'.format(pic.ndimension())) 61 | 62 | elif pic.ndimension() == 2: 63 | # if 2D image, add channel dimension (CHW) 64 | pic = pic.unsqueeze(0) 65 | 66 | elif isinstance(pic, np.ndarray): 67 | if pic.ndim not in {2, 3}: 68 | raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) 69 | 70 | elif pic.ndim == 2: 71 | # if 2D image, add channel dimension (HWC) 72 | pic = np.expand_dims(pic, 2) 73 | 74 | npimg = pic 75 | if isinstance(pic, torch.Tensor): 76 | if pic.is_floating_point() and mode != 'F': 77 | pic = pic.mul(255).byte() 78 | npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0)) 79 | 80 | if not isinstance(npimg, np.ndarray): 81 | raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 82 | 'not {}'.format(type(npimg))) 83 | 84 | if npimg.shape[2] == 1: 85 | expected_mode = None 86 | npimg = npimg[:, :, 0] 87 | if npimg.dtype == np.uint8: 88 | expected_mode = 'L' 89 | elif npimg.dtype == np.int16: 90 | expected_mode = 'I;16' 91 | elif npimg.dtype == np.int32: 92 | expected_mode = 'I' 93 | elif npimg.dtype == np.float32: 94 | expected_mode = 'F' 95 | if mode is not None and mode != expected_mode: 96 | raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}" 97 | .format(mode, np.dtype, expected_mode)) 98 | mode = expected_mode 99 | 100 | elif npimg.shape[2] == 2: 101 | permitted_2_channel_modes = ['LA'] 102 | if mode is not None and mode not in permitted_2_channel_modes: 103 | raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes)) 104 | 105 | if mode is None and npimg.dtype == np.uint8: 106 | mode = 'LA' 107 | 108 | elif npimg.shape[2] == 4: 109 | permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX'] 110 | if mode is not None and mode not in permitted_4_channel_modes: 111 | raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) 112 | 113 | if mode is None and npimg.dtype == np.uint8: 114 | mode = 'RGBA' 115 | else: 116 | permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] 117 | if mode is not None and mode not in permitted_3_channel_modes: 118 | raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) 119 | if mode is None and npimg.dtype == np.uint8: 120 | mode = 'RGB' 121 | 122 | if mode is None: 123 | raise TypeError('Input type {} is not supported'.format(npimg.dtype)) 124 | 125 | return Image.fromarray(npimg, mode=mode) 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/augmentations/swav_aug.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/augmentations/swav_aug.py -------------------------------------------------------------------------------- /CS-MIL_Docker/src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from .random_dataset import RandomDataset 4 | 5 | 6 | def get_dataset(dataset, data_dir, transform, train=True, download=False, debug_subset_size=None): 7 | if dataset == 'mnist': 8 | dataset = torchvision.datasets.MNIST(data_dir, train=train, transform=transform, download=download) 9 | elif dataset == 'stl10': 10 | dataset = torchvision.datasets.STL10(data_dir, split='train+unlabeled' if train else 'test', 
transform=transform, download=download) 11 | elif dataset == 'cifar10': 12 | dataset = torchvision.datasets.CIFAR10(data_dir, train=train, transform=transform, download=download) 13 | elif dataset == 'cifar100': 14 | dataset = torchvision.datasets.CIFAR100(data_dir, train=train, transform=transform, download=download) 15 | elif dataset == 'imagenet': 16 | dataset = torchvision.datasets.ImageNet(data_dir, split='train' if train == True else 'val', transform=transform, download=download) 17 | elif dataset == 'random': 18 | dataset = RandomDataset() 19 | else: 20 | raise NotImplementedError 21 | 22 | if debug_subset_size is not None: 23 | dataset = torch.utils.data.Subset(dataset, range(0, debug_subset_size)) # take only one batch 24 | dataset.classes = dataset.dataset.classes 25 | dataset.targets = dataset.dataset.targets 26 | return dataset -------------------------------------------------------------------------------- /CS-MIL_Docker/src/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/datasets/__pycache__/random_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/datasets/__pycache__/random_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/datasets/__pycache__/random_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/datasets/__pycache__/random_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/datasets/random_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class RandomDataset(torch.utils.data.Dataset): 4 | def __init__(self, root=None, train=True, transform=None, target_transform=None): 5 | self.transform = transform 6 | self.target_transform = target_transform 7 | 8 | self.size = 1000 9 | 10 | def __getitem__(self, idx): 11 | if idx < self.size: 12 | # return [torch.randn((3, 224, 224)), torch.randn((3, 224, 224))], [0,0,0] 13 | return [torch.randn((3, 128, 128)), torch.randn((3, 128, 128))], [0,0,0] 14 | else: 15 | raise Exception 16 | 17 | def __len__(self): 18 | return self.size 19 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .simsiam import SimSiam 2 | from .byol import BYOL 3 | from .simclr import SimCLR 4 | from torchvision.models import resnet50, resnet18 5 | import torch 6 | from .backbones 
import resnet18_cifar_variant1, resnet18_cifar_variant2, resnet50_TCGA 7 | 8 | def get_backbone(backbone, castrate=True): #lq debug 9 | backbone = eval(f"{backbone}()") 10 | 11 | if castrate: 12 | backbone.output_dim = backbone.fc.in_features 13 | backbone.fc = torch.nn.Identity() 14 | 15 | return backbone 16 | 17 | def get_model(model_cfg): 18 | 19 | if model_cfg.name == 'simsiam': 20 | model = SimSiam(get_backbone(model_cfg.backbone)) 21 | # model = SimSiam() 22 | if model_cfg.proj_layers is not None: 23 | model.projector.set_layers(model_cfg.proj_layers) 24 | 25 | elif model_cfg.name == 'byol': 26 | model = BYOL(get_backbone(model_cfg.backbone)) 27 | elif model_cfg.name == 'simclr': 28 | model = SimCLR(get_backbone(model_cfg.backbone)) 29 | elif model_cfg.name == 'swav': 30 | raise NotImplementedError 31 | else: 32 | raise NotImplementedError 33 | return model 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/__pycache__/byol.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/__pycache__/byol.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/__pycache__/byol.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/__pycache__/byol.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/__pycache__/simclr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/__pycache__/simclr.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/__pycache__/simclr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/__pycache__/simclr.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/__pycache__/simsiam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/__pycache__/simsiam.cpython-36.pyc 
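A brief usage sketch for get_model()/get_backbone() in models/__init__.py above (an editor's illustration, not a repository file; SimpleNamespace stands in for the YAML-backed Namespace built by arguments.py):

from types import SimpleNamespace
from models import get_model

# Mirrors the "model:" block of GCA512_0411_train4w.yaml.
model_cfg = SimpleNamespace(name='simsiam', backbone='resnet50', proj_layers=2)
model = get_model(model_cfg)  # ResNet-50 trunk (fc replaced by Identity) with SimSiam projector/predictor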
-------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/__pycache__/simsiam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/__pycache__/simsiam.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/backbones/TCGA_eval_onehot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from tqdm import tqdm 7 | from arguments import get_args 8 | from augmentations import get_aug 9 | from models import get_model, get_backbone 10 | from tools import AverageMeter 11 | from datasets import get_dataset 12 | from optimizers import get_optimizer, LR_Scheduler 13 | from torchvision import datasets 14 | 15 | 16 | def main(args): 17 | train_directory = '/share/contrastive_learning/data/sup_data/data_0124_10000/train_patch' 18 | train_loader = torch.utils.data.DataLoader( 19 | # dataset=get_dataset( 20 | # transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs), 21 | # train=True, 22 | # **args.dataset_kwargs 23 | # ), 24 | dataset=datasets.ImageFolder(root=train_directory, transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs)), 25 | batch_size=args.eval.batch_size, 26 | shuffle=True, 27 | **args.dataloader_kwargs 28 | ) 29 | test_dictionary = '/share/contrastive_learning/data/sup_data/data_0124_10000/val_patch' 30 | test_loader = torch.utils.data.DataLoader( 31 | # dataset=get_dataset( 32 | # transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs), 33 | # train=False, 34 | # **args.dataset_kwargs 35 | # ), 36 | dataset=datasets.ImageFolder(root=test_dictionary, transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs)), 37 | batch_size=args.eval.batch_size, 38 | shuffle=False, 39 | **args.dataloader_kwargs 40 | ) 41 | 42 | model = get_backbone(args.model.backbone) 43 | classifier = nn.Linear(in_features=model.output_dim, out_features=16, bias=True).to(args.device) 44 | 45 | assert args.eval_from is not None 46 | save_dict = torch.load(args.eval_from, map_location='cpu') 47 | msg = model.load_state_dict({k[9:]: v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, 48 | strict=True) 49 | 50 | # print(msg) 51 | model = model.to(args.device) 52 | model = torch.nn.DataParallel(model) 53 | 54 | # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier) 55 | classifier = torch.nn.DataParallel(classifier) 56 | # define optimizer 57 | optimizer = get_optimizer( 58 | args.eval.optimizer.name, classifier, 59 | lr=args.eval.base_lr * args.eval.batch_size / 256, 60 | momentum=args.eval.optimizer.momentum, 61 | weight_decay=args.eval.optimizer.weight_decay) 62 | 63 | # define lr scheduler 64 | lr_scheduler = LR_Scheduler( 65 | optimizer, 66 | args.eval.warmup_epochs, args.eval.warmup_lr * args.eval.batch_size / 256, 67 | args.eval.num_epochs, args.eval.base_lr * args.eval.batch_size / 256, 68 | args.eval.final_lr * args.eval.batch_size / 256, 69 | len(train_loader), 70 | ) 71 | 72 | loss_meter = AverageMeter(name='Loss') 73 | acc_meter = AverageMeter(name='Accuracy') 74 | 75 | # Start training 76 | global_progress = tqdm(range(0, args.eval.num_epochs), 
desc=f'Evaluating') 77 | for epoch in global_progress: 78 | loss_meter.reset() 79 | model.eval() 80 | classifier.train() 81 | local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.eval.num_epochs}', disable=False) 82 | 83 | for idx, (images, labels) in enumerate(local_progress): 84 | classifier.zero_grad() 85 | with torch.no_grad(): 86 | feature = model(images.to(args.device)) 87 | 88 | preds = classifier(feature) 89 | 90 | loss = F.cross_entropy(preds, labels.to(args.device)) 91 | 92 | loss.backward() 93 | optimizer.step() 94 | loss_meter.update(loss.item()) 95 | lr = lr_scheduler.step() 96 | local_progress.set_postfix({'lr': lr, "loss": loss_meter.val, 'loss_avg': loss_meter.avg}) 97 | 98 | classifier.eval() 99 | correct, total = 0, 0 100 | acc_meter.reset() 101 | for idx, (images, labels) in enumerate(test_loader): 102 | with torch.no_grad(): 103 | feature = model(images.to(args.device)) 104 | preds = classifier(feature).argmax(dim=1) 105 | correct = (preds == labels.to(args.device)).sum().item() 106 | acc_meter.update(correct / preds.shape[0]) 107 | print(f'Accuracy = {acc_meter.avg * 100:.2f}') 108 | 109 | 110 | if __name__ == "__main__": 111 | main(args=get_args()) 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar_resnet_1 import resnet18 as resnet18_cifar_variant1 2 | from .cifar_resnet_2 import ResNet18 as resnet18_cifar_variant2 3 | from .cifar_resnet_1 import resnet50 as resnet50_TCGA 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/backbones/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/backbones/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/backbones/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/backbones/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/backbones/__pycache__/cifar_resnet_1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/backbones/__pycache__/cifar_resnet_1.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/backbones/__pycache__/cifar_resnet_1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/backbones/__pycache__/cifar_resnet_1.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/backbones/__pycache__/cifar_resnet_2.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/backbones/__pycache__/cifar_resnet_2.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/backbones/__pycache__/cifar_resnet_2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/models/backbones/__pycache__/cifar_resnet_2.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/backbones/cifar_resnet_2.py: -------------------------------------------------------------------------------- 1 | # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py 2 | '''ResNet in PyTorch. 3 | 4 | For Pre-activation ResNet, see 'preact_resnet.py'. 5 | 6 | Reference: 7 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 8 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 9 | ''' 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, in_planes, planes, stride=1): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = nn.Conv2d( 21 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 24 | stride=1, padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride != 1 or in_planes != self.expansion*planes: 29 | self.shortcut = nn.Sequential( 30 | nn.Conv2d(in_planes, self.expansion*planes, 31 | kernel_size=1, stride=stride, bias=False), 32 | nn.BatchNorm2d(self.expansion*planes) 33 | ) 34 | 35 | def forward(self, x): 36 | out = F.relu(self.bn1(self.conv1(x))) 37 | out = self.bn2(self.conv2(out)) 38 | out += self.shortcut(x) 39 | out = F.relu(out) 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | expansion = 4 45 | 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(Bottleneck, self).__init__() 48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 51 | stride=stride, padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, self.expansion * 54 | planes, kernel_size=1, bias=False) 55 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 56 | 57 | self.shortcut = nn.Sequential() 58 | if stride != 1 or in_planes != self.expansion*planes: 59 | self.shortcut = nn.Sequential( 60 | nn.Conv2d(in_planes, self.expansion*planes, 61 | kernel_size=1, stride=stride, bias=False), 62 | nn.BatchNorm2d(self.expansion*planes) 63 | ) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = F.relu(self.bn2(self.conv2(out))) 68 | out = self.bn3(self.conv3(out)) 69 | out += self.shortcut(x) 70 | out = F.relu(out) 71 | return out 72 | 73 | 74 | class ResNet(nn.Module): 75 | def __init__(self, block, num_blocks, num_classes=16): 76 | super(ResNet, self).__init__() 77 | self.in_planes = 64 78 | 79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 80 | stride=1, padding=1, bias=False) 81 | self.bn1 = nn.BatchNorm2d(64) 82 | self.layer1 
= self._make_layer(block, 64, num_blocks[0], stride=1) 83 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 84 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 85 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 86 | self.fc = nn.Linear(512*block.expansion, num_classes) 87 | 88 | def _make_layer(self, block, planes, num_blocks, stride): 89 | strides = [stride] + [1]*(num_blocks-1) 90 | layers = [] 91 | for stride in strides: 92 | layers.append(block(self.in_planes, planes, stride)) 93 | self.in_planes = planes * block.expansion 94 | return nn.Sequential(*layers) 95 | 96 | def forward(self, x): 97 | out = F.relu(self.bn1(self.conv1(x))) 98 | out = self.layer1(out) 99 | out = self.layer2(out) 100 | out = self.layer3(out) 101 | out = self.layer4(out) 102 | out = F.avg_pool2d(out, 4) 103 | out = out.view(out.size(0), -1) 104 | out = self.fc(out) 105 | return out 106 | 107 | 108 | def ResNet18(): 109 | return ResNet(BasicBlock, [2, 2, 2, 2]) 110 | 111 | 112 | def ResNet34(): 113 | return ResNet(BasicBlock, [3, 4, 6, 3]) 114 | 115 | 116 | def ResNet50(): 117 | return ResNet(Bottleneck, [3, 4, 6, 3]) 118 | 119 | 120 | def ResNet101(): 121 | return ResNet(Bottleneck, [3, 4, 23, 3]) 122 | 123 | 124 | def ResNet152(): 125 | return ResNet(Bottleneck, [3, 8, 36, 3]) 126 | 127 | 128 | def test(): 129 | net = ResNet18() 130 | print(sum(p.numel() for p in net.parameters() if p.requires_grad)) 131 | import torchvision 132 | net2 = torchvision.models.resnet18() 133 | print(sum(p.numel() for p in net2.parameters() if p.requires_grad)) 134 | # y = net(torch.randn(1, 3, 32, 32)) 135 | # print(y.size()) 136 | # 11173962 137 | # 11689512 138 | # test() -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/byol.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torchvision import transforms 7 | from math import pi, cos 8 | from collections import OrderedDict 9 | HPS = dict( 10 | max_steps=int(1000. 
* 1281167 / 4096), # 1000 epochs * 1281167 samples / batch size = 1000 epochs * N of steps/epoch 11 | # = total_epochs * len(dataloader) 12 | mlp_hidden_size=4096, 13 | projection_size=256, 14 | base_target_ema=4e-3, 15 | optimizer_config=dict( 16 | optimizer_name='lars', 17 | beta=0.9, 18 | trust_coef=1e-3, 19 | weight_decay=1.5e-6, 20 | exclude_bias_from_adaption=True), 21 | learning_rate_schedule=dict( 22 | base_learning_rate=0.2, 23 | warmup_steps=int(10.0 * 1281167 / 4096), # 10 epochs * N of steps/epoch = 10 epochs * len(dataloader) 24 | anneal_schedule='cosine'), 25 | batchnorm_kwargs=dict( 26 | decay_rate=0.9, 27 | eps=1e-5), 28 | seed=1337, 29 | ) 30 | 31 | # def loss_fn(x, y, version='simplified'): 32 | 33 | # if version == 'original': 34 | # y = y.detach() 35 | # x = F.normalize(x, dim=-1, p=2) 36 | # y = F.normalize(y, dim=-1, p=2) 37 | # return (2 - 2 * (x * y).sum(dim=-1)).mean() 38 | # elif version == 'simplified': 39 | # return (2 - 2 * F.cosine_similarity(x,y.detach(), dim=-1)).mean() 40 | # else: 41 | # raise NotImplementedError 42 | 43 | from .simsiam import D # a bit different, but essentially the same thing: negative cosine similarity & stop gradient 44 | 45 | 46 | class MLP(nn.Module): 47 | def __init__(self, in_dim): 48 | super().__init__() 49 | 50 | self.layer1 = nn.Sequential( 51 | nn.Linear(in_dim, HPS['mlp_hidden_size']), 52 | nn.BatchNorm1d(HPS['mlp_hidden_size'], eps=HPS['batchnorm_kwargs']['eps'], momentum=1-HPS['batchnorm_kwargs']['decay_rate']), 53 | nn.ReLU(inplace=True) 54 | ) 55 | self.layer2 = nn.Linear(HPS['mlp_hidden_size'], HPS['projection_size']) 56 | 57 | def forward(self, x): 58 | x = self.layer1(x) 59 | x = self.layer2(x) 60 | return x 61 | 62 | class BYOL(nn.Module): 63 | def __init__(self, backbone): 64 | super().__init__() 65 | 66 | self.backbone = backbone 67 | self.projector = MLP(backbone.output_dim) 68 | self.online_encoder = nn.Sequential( 69 | self.backbone, 70 | self.projector 71 | ) 72 | 73 | self.target_encoder = copy.deepcopy(self.online_encoder) 74 | self.online_predictor = MLP(HPS['projection_size']) 75 | raise NotImplementedError('Please call update_moving_average from the training loop') 76 | 77 | def target_ema(self, k, K, base_ema=HPS['base_target_ema']): 78 | # tau_base = 0.996 79 | # base_ema = 1 - tau_base = 0.004 80 | return 1 - base_ema * (cos(pi*k/K)+1)/2 81 | # return 1 - (1-self.tau_base) * (cos(pi*k/K)+1)/2 82 | 83 | @torch.no_grad() 84 | def update_moving_average(self, global_step, max_steps): 85 | tau = self.target_ema(global_step, max_steps) 86 | for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()): 87 | target.data = tau * target.data + (1 - tau) * online.data 88 | 89 | def forward(self, x1, x2): 90 | f_o, h_o = self.online_encoder, self.online_predictor 91 | f_t = self.target_encoder 92 | 93 | z1_o = f_o(x1) 94 | z2_o = f_o(x2) 95 | 96 | p1_o = h_o(z1_o) 97 | p2_o = h_o(z2_o) 98 | 99 | with torch.no_grad(): 100 | z1_t = f_t(x1) 101 | z2_t = f_t(x2) 102 | 103 | L = D(p1_o, z2_t) / 2 + D(p2_o, z1_t) / 2 104 | return {'loss': L} 105 | 106 | 107 | 108 | if __name__ == "__main__": 109 | pass -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/simclr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models import resnet50 5 | 6 | def NT_XentLoss(z1, z2, temperature=0.5): 7 | z1 = F.normalize(z1, dim=1) 8 | z2 =
F.normalize(z2, dim=1) 9 | N, Z = z1.shape 10 | device = z1.device 11 | representations = torch.cat([z1, z2], dim=0) 12 | similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) 13 | l_pos = torch.diag(similarity_matrix, N) 14 | r_pos = torch.diag(similarity_matrix, -N) 15 | positives = torch.cat([l_pos, r_pos]).view(2 * N, 1) 16 | diag = torch.eye(2*N, dtype=torch.bool, device=device) 17 | diag[N:,:N] = diag[:N,N:] = diag[:N,:N] 18 | 19 | negatives = similarity_matrix[~diag].view(2*N, -1) 20 | 21 | logits = torch.cat([positives, negatives], dim=1) 22 | logits /= temperature 23 | 24 | labels = torch.zeros(2*N, device=device, dtype=torch.int64) 25 | 26 | loss = F.cross_entropy(logits, labels, reduction='sum') 27 | return loss / (2 * N) 28 | 29 | 30 | class projection_MLP(nn.Module): 31 | def __init__(self, in_dim, out_dim=256): 32 | super().__init__() 33 | hidden_dim = in_dim 34 | self.layer1 = nn.Sequential( 35 | nn.Linear(in_dim, hidden_dim), 36 | nn.ReLU(inplace=True) 37 | ) 38 | self.layer2 = nn.Linear(hidden_dim, out_dim) 39 | def forward(self, x): 40 | x = self.layer1(x) 41 | x = self.layer2(x) 42 | return x 43 | 44 | class SimCLR(nn.Module): 45 | 46 | def __init__(self, backbone=resnet50()): 47 | super().__init__() 48 | 49 | self.backbone = backbone 50 | self.projector = projection_MLP(backbone.output_dim) 51 | self.encoder = nn.Sequential( 52 | self.backbone, 53 | self.projector 54 | ) 55 | 56 | 57 | 58 | def forward(self, x1, x2): 59 | z1 = self.encoder(x1) 60 | z2 = self.encoder(x2) 61 | 62 | loss = NT_XentLoss(z1, z2) 63 | return {'loss':loss} 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/simsiam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models import resnet50 5 | 6 | 7 | def D(p, z, version='simplified'): # negative cosine similarity 8 | if version == 'original': 9 | z = z.detach() # stop gradient 10 | p = F.normalize(p, dim=1) # l2-normalize 11 | z = F.normalize(z, dim=1) # l2-normalize 12 | return -(p*z).sum(dim=1).mean() 13 | 14 | elif version == 'simplified':# same thing, much faster. Scroll down, speed test in __main__ 15 | return - F.cosine_similarity(p, z.detach(), dim=-1).mean() 16 | else: 17 | raise Exception 18 | 19 | 20 | 21 | class projection_MLP(nn.Module): 22 | def __init__(self, in_dim, hidden_dim=2048, out_dim=2048): 23 | super().__init__() 24 | ''' page 3 baseline setting 25 | Projection MLP. The projection MLP (in f) has BN ap- 26 | plied to each fully-connected (fc) layer, including its out- 27 | put fc. Its output fc has no ReLU. The hidden fc is 2048-d. 28 | This MLP has 3 layers. 
29 | ''' 30 | self.layer1 = nn.Sequential( 31 | nn.Linear(in_dim, hidden_dim), 32 | nn.BatchNorm1d(hidden_dim), 33 | nn.ReLU(inplace=True) 34 | ) 35 | self.layer2 = nn.Sequential( 36 | nn.Linear(hidden_dim, hidden_dim), 37 | nn.BatchNorm1d(hidden_dim), 38 | nn.ReLU(inplace=True) 39 | ) 40 | self.layer3 = nn.Sequential( 41 | nn.Linear(hidden_dim, out_dim), 42 | nn.BatchNorm1d(out_dim) # BN width must match the Linear output; hidden_dim only worked here because out_dim == hidden_dim by default 43 | ) 44 | self.num_layers = 3 45 | def set_layers(self, num_layers): 46 | self.num_layers = num_layers 47 | 48 | def forward(self, x): 49 | if self.num_layers == 3: 50 | x = self.layer1(x) 51 | x = self.layer2(x) 52 | x = self.layer3(x) 53 | elif self.num_layers == 2: 54 | x = self.layer1(x) 55 | x = self.layer3(x) 56 | else: 57 | raise Exception 58 | return x 59 | 60 | 61 | class prediction_MLP(nn.Module): 62 | def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048): # bottleneck structure 63 | super().__init__() 64 | ''' page 3 baseline setting 65 | Prediction MLP. The prediction MLP (h) has BN applied 66 | to its hidden fc layers. Its output fc does not have BN 67 | (ablation in Sec. 4.4) or ReLU. This MLP has 2 layers. 68 | The dimension of h’s input and output (z and p) is d = 2048, 69 | and h’s hidden layer’s dimension is 512, making h a 70 | bottleneck structure (ablation in supplement). 71 | ''' 72 | self.layer1 = nn.Sequential( 73 | nn.Linear(in_dim, hidden_dim), 74 | nn.BatchNorm1d(hidden_dim), 75 | nn.ReLU(inplace=True) 76 | ) 77 | self.layer2 = nn.Linear(hidden_dim, out_dim) 78 | """ 79 | Adding BN to the output of the prediction MLP h does not work 80 | well (Table 3d). We find that this is not about collapsing. 81 | The training is unstable and the loss oscillates. 82 | """ 83 | 84 | def forward(self, x): 85 | x = self.layer1(x) 86 | x = self.layer2(x) 87 | return x 88 | 89 | class SimSiam(nn.Module): 90 | def __init__(self, backbone=resnet50()): 91 | super().__init__() 92 | 93 | self.backbone = backbone 94 | self.projector = projection_MLP(backbone.output_dim) 95 | 96 | self.encoder = nn.Sequential( # f encoder 97 | self.backbone, 98 | self.projector 99 | ) 100 | self.predictor = prediction_MLP() 101 | 102 | # self.backbone = backbone 103 | # # self.projector = projection_MLP(backbone.output_dim) 104 | # self.encoder = nn.Sequential( # f encoder 105 | # self.backbone, 106 | # # self.projector 107 | # ) 108 | # # self.predictor = prediction_MLP() 109 | 110 | def forward(self, x1, x2): 111 | 112 | f, h = self.encoder, self.predictor 113 | # z1, z2 = f(x1), f(x2) 114 | z1 = f(x1) 115 | z2 = f(x2) 116 | p1, p2 = h(z1), h(z2) 117 | L = D(p1, z2) / 2 + D(p2, z1) / 2 118 | return {'loss': L} 119 | 120 | # f = self.encoder 121 | # # z1, z2 = f(x1), f(x2) 122 | # z1 = f(x1) 123 | # # z2 = f(x2) 124 | # # p1, p2 = h(z1), h(z2) 125 | # L = D(z1, z1) / 2 126 | # return {'loss': L} 127 | 128 | 129 | if __name__ == "__main__": 130 | model = SimSiam() 131 | x1 = torch.randn((2, 3, 224, 224)) 132 | x2 = torch.randn_like(x1) 133 | 134 | model.forward(x1, x2)['loss'].backward() # forward returns a dict, so backprop through its 'loss' entry 135 | print("forward backward check") 136 | 137 | z1 = torch.randn((200, 2560)) 138 | z2 = torch.randn_like(z1) 139 | import time 140 | tic = time.time() 141 | print(D(z1, z2, version='original')) 142 | toc = time.time() 143 | print(toc - tic) 144 | tic = time.time() 145 | print(D(z1, z2, version='simplified')) 146 | toc = time.time() 147 | print(toc - tic) 148 | 149 | # Output: 150 | # tensor(-0.0010) 151 | # 0.005159854888916016 152 | # tensor(-0.0010) 153 | # 0.0014872550964355469 154 | 155 |
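156 | # Minimal training-step sketch for the SimSiam model above (a hedged example, not part of the original file): `loader` is an assumed dataloader yielding two augmented views per image, and the backbone must expose an `output_dim` attribute, as the repo's get_backbone does. 157 | # model = SimSiam(backbone) 158 | # optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4) 159 | # for (x1, x2), _ in loader: 160 | #     optimizer.zero_grad() 161 | #     model(x1, x2)['loss'].backward() 162 | #     optimizer.step()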
-------------------------------------------------------------------------------- /CS-MIL_Docker/src/models/swav.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models import resnet50 5 | 6 | class SwAV(nn.Module): 7 | def __init__(self, backbone=resnet50()): 8 | super().__init__() 9 | 10 | backbone.fc = nn.Identity() 11 | self.backbone = backbone 12 | 13 | def forward(self, x1, x2): 14 | raise NotImplementedError # the SwAV forward pass/loss is not implemented in this snapshot; this stub keeps the module importable 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .lars import LARS 2 | from .lars_simclr import LARS_simclr 3 | from .larc import LARC 4 | import torch 5 | from .lr_scheduler import LR_Scheduler 6 | 7 | 8 | def get_optimizer(name, model, lr, momentum, weight_decay): 9 | 10 | predictor_prefix = ('module.predictor', 'predictor') 11 | parameters = [{ 12 | 'name': 'base', 13 | 'params': [param for name, param in model.named_parameters() if not name.startswith(predictor_prefix)], 14 | 'lr': lr 15 | },{ 16 | 'name': 'predictor', 17 | 'params': [param for name, param in model.named_parameters() if name.startswith(predictor_prefix)], 18 | 'lr': lr 19 | }] 20 | if name == 'lars': 21 | optimizer = LARS(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay) 22 | elif name == 'sgd': 23 | optimizer = torch.optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay) 24 | elif name == 'lars_simclr': # Careful 25 | optimizer = LARS_simclr(model.named_modules(), lr=lr, momentum=momentum, weight_decay=weight_decay) 26 | elif name == 'larc': 27 | optimizer = LARC( 28 | torch.optim.SGD( 29 | parameters, 30 | lr=lr, 31 | momentum=momentum, 32 | weight_decay=weight_decay 33 | ), 34 | trust_coefficient=0.001, 35 | clip=False 36 | ) 37 | else: 38 | raise NotImplementedError 39 | return optimizer 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/optimizers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/optimizers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/__pycache__/larc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/optimizers/__pycache__/larc.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/__pycache__/larc.cpython-38.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/optimizers/__pycache__/larc.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/__pycache__/lars.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/optimizers/__pycache__/lars.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/__pycache__/lars.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/optimizers/__pycache__/lars.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/__pycache__/lars_simclr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/optimizers/__pycache__/lars_simclr.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/__pycache__/lars_simclr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/optimizers/__pycache__/lars_simclr.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/__pycache__/lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/optimizers/__pycache__/lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/optimizers/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/larc.py: -------------------------------------------------------------------------------- 1 | """SwAV use larc instead of lars optimizer""" 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.parameter import Parameter 6 | from torch.optim.optimizer import Optimizer 7 | 8 | def main(): # Example 9 | import torchvision 10 | model = torchvision.models.resnet18(pretrained=False) 11 | # optim = torch.optim.Adam(model.parameters(), lr=0.0001) 12 | optim = torch.optim.SGD(model.parameters(),lr=0.2, momentum=0.9, weight_decay=1.5e-6) 13 | optim = LARC(optim) 14 | 15 | class LARC(Optimizer): 16 | """ 17 | :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, 18 | in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive 19 | local learning rate for each individual parameter. The algorithm is designed to improve 20 | convergence of large batch training. 
21 | 22 | See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. 23 | In practice it modifies the gradients of parameters as a proxy for modifying the learning rate 24 | of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. 25 | ``` 26 | model = ... 27 | optim = torch.optim.Adam(model.parameters(), lr=...) 28 | optim = LARC(optim) 29 | ``` 30 | It can even be used in conjunction with apex.fp16_utils.FP16_optimizer. 31 | ``` 32 | model = ... 33 | optim = torch.optim.Adam(model.parameters(), lr=...) 34 | optim = LARC(optim) 35 | optim = apex.fp16_utils.FP16_Optimizer(optim) 36 | ``` 37 | Args: 38 | optimizer: Pytorch optimizer to wrap and modify learning rate for. 39 | trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 40 | clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. 41 | eps: epsilon kludge to help with numerical stability while calculating adaptive_lr 42 | """ 43 | 44 | def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8): 45 | self.optim = optimizer 46 | self.trust_coefficient = trust_coefficient 47 | self.eps = eps 48 | self.clip = clip 49 | 50 | def __getstate__(self): 51 | return self.optim.__getstate__() 52 | 53 | def __setstate__(self, state): 54 | self.optim.__setstate__(state) 55 | 56 | @property 57 | def state(self): 58 | return self.optim.state 59 | 60 | def __repr__(self): 61 | return self.optim.__repr__() 62 | 63 | @property 64 | def param_groups(self): 65 | return self.optim.param_groups 66 | 67 | @param_groups.setter 68 | def param_groups(self, value): 69 | self.optim.param_groups = value 70 | 71 | def state_dict(self): 72 | return self.optim.state_dict() 73 | 74 | def load_state_dict(self, state_dict): 75 | self.optim.load_state_dict(state_dict) 76 | 77 | def zero_grad(self): 78 | self.optim.zero_grad() 79 | 80 | def add_param_group(self, param_group): 81 | self.optim.add_param_group( param_group) 82 | 83 | def step(self): 84 | with torch.no_grad(): 85 | weight_decays = [] 86 | for group in self.optim.param_groups: 87 | # absorb weight decay control from optimizer 88 | weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 89 | weight_decays.append(weight_decay) 90 | group['weight_decay'] = 0 91 | for p in group['params']: 92 | if p.grad is None: 93 | continue 94 | param_norm = torch.norm(p.data) 95 | grad_norm = torch.norm(p.grad.data) 96 | 97 | if param_norm != 0 and grad_norm != 0: 98 | # calculate adaptive lr + weight decay 99 | adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps) 100 | 101 | # clip learning rate for LARC 102 | if self.clip: 103 | # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` 104 | adaptive_lr = min(adaptive_lr/group['lr'], 1) 105 | 106 | p.grad.data += weight_decay * p.data 107 | p.grad.data *= adaptive_lr 108 | 109 | self.optim.step() 110 | # return weight decay control to optimizer 111 | for i, group in enumerate(self.optim.param_groups): 112 | group['weight_decay'] = weight_decays[i] 113 | 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- 
/CS-MIL_Docker/src/optimizers/lars.py: -------------------------------------------------------------------------------- 1 | """ Layer-wise adaptive rate scaling for SGD in PyTorch! """ 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class LARS(Optimizer): 6 | r"""Implements layer-wise adaptive rate scaling for SGD. 7 | 8 | Args: 9 | params (iterable): iterable of parameters to optimize or dicts defining 10 | parameter groups 11 | lr (float): base learning rate (\gamma_0) 12 | momentum (float, optional): momentum factor (default: 0) ("m") 13 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 14 | ("\beta") 15 | eta (float, optional): LARS coefficient 16 | max_epoch: maximum training epoch to determine polynomial LR decay. 17 | 18 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 19 | Large Batch Training of Convolutional Networks: 20 | https://arxiv.org/abs/1708.03888 21 | 22 | Example: 23 | >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3) 24 | >>> optimizer.zero_grad() 25 | >>> loss_fn(model(input), target).backward() 26 | >>> optimizer.step() 27 | """ 28 | def __init__(self, params, lr=required, momentum=.9, 29 | weight_decay=.0005, eta=0.001, max_epoch=200): 30 | if lr is not required and lr < 0.0: 31 | raise ValueError("Invalid learning rate: {}".format(lr)) 32 | if momentum < 0.0: 33 | raise ValueError("Invalid momentum value: {}".format(momentum)) 34 | if weight_decay < 0.0: 35 | raise ValueError("Invalid weight_decay value: {}" 36 | .format(weight_decay)) 37 | if eta < 0.0: 38 | raise ValueError("Invalid LARS coefficient value: {}".format(eta)) 39 | 40 | self.epoch = 0 41 | defaults = dict(lr=lr, momentum=momentum, 42 | weight_decay=weight_decay, 43 | eta=eta, max_epoch=max_epoch) 44 | super(LARS, self).__init__(params, defaults) 45 | 46 | def step(self, epoch=None, closure=None): 47 | """Performs a single optimization step. 48 | 49 | Arguments: 50 | closure (callable, optional): A closure that reevaluates the model 51 | and returns the loss. 52 | epoch: current epoch to calculate polynomial LR decay schedule. 53 | if None, uses self.epoch and increments it. 
54 | """ 55 | loss = None 56 | if closure is not None: 57 | loss = closure() 58 | 59 | if epoch is None: 60 | epoch = self.epoch 61 | self.epoch += 1 62 | 63 | for group in self.param_groups: 64 | weight_decay = group['weight_decay'] 65 | momentum = group['momentum'] 66 | eta = group['eta'] 67 | lr = group['lr'] 68 | max_epoch = group['max_epoch'] 69 | 70 | for p in group['params']: 71 | if p.grad is None: 72 | continue 73 | 74 | param_state = self.state[p] 75 | d_p = p.grad.data 76 | 77 | weight_norm = torch.norm(p.data) 78 | grad_norm = torch.norm(d_p) 79 | 80 | # Global LR computed on polynomial decay schedule 81 | decay = (1 - float(epoch) / max_epoch) ** 2 82 | global_lr = lr * decay 83 | 84 | # Compute local learning rate for this layer 85 | local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm) 86 | 87 | # Update the momentum term 88 | actual_lr = local_lr * global_lr 89 | 90 | if 'momentum_buffer' not in param_state: 91 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 92 | else: 93 | buf = param_state['momentum_buffer'] 94 | buf.mul_(momentum).add_(d_p + weight_decay * p.data, alpha=actual_lr) 95 | p.data.add_(-buf) 96 | 97 | return loss -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/lars_simclr.py: -------------------------------------------------------------------------------- 1 | """The lars optimizer used in simclr is a bit different from the paper where they exclude certain parameters""" 2 | """I asked the author of byol, they also stick to the simclr lars implementation""" 3 | 4 | 5 | 6 | import torch 7 | import torchvision 8 | from torch.optim.optimizer import Optimizer 9 | import torch.nn as nn 10 | # comments from the lead author of byol 11 | # 2. + 3. We follow the same implementation as the one used in SimCLR for LARS. This is indeed a bit 12 | # different from the one described in the LARS paper and the implementation you attached to your email. 13 | # In particular as in SimCLR we first modify the gradient to include the weight decay (with beta corresponding 14 | # to self.weight_decay in the SimCLR code) and then adapt the learning rate by dividing by the norm of this 15 | # sum, this is different from the LARS pseudo code where they divide by the sum of the norm (instead of the 16 | # norm of the sum as SimCLR and us are doing). This is done in the SimCLR code by first adding the weight 17 | # decay term to the gradient and then using this sum to perform the adaptation. We also use a term (usually 18 | # referred to as trust_coefficient but referred as eeta in SimCLR code) set to 1e-3 to multiply the updates 19 | # of linear layers. 20 | # Note that the logic "if w_norm > 0 and g_norm > 0 else 1.0" is there to tackle numerical instabilities. 21 | # In general we closely followed SimCLR implementation of LARS. 22 | class LARS_simclr(Optimizer): 23 | def __init__(self, 24 | named_modules, 25 | lr, 26 | momentum=0.9, # beta? 
YES 27 | trust_coef=1e-3, 28 | weight_decay=1.5e-6, 29 | exclude_bias_from_adaption=True): 30 | '''byol: As in SimCLR and official implementation of LARS, we exclude bias # and batchnorm weight from the Lars adaptation and weightdecay''' 31 | defaults = dict(momentum=momentum, 32 | lr=lr, 33 | weight_decay=weight_decay, 34 | trust_coef=trust_coef) 35 | parameters = self.exclude_from_model(named_modules, exclude_bias_from_adaption) 36 | super(LARS_simclr, self).__init__(parameters, defaults) 37 | 38 | @torch.no_grad() 39 | def step(self): 40 | for group in self.param_groups: # only 1 group in most cases 41 | weight_decay = group['weight_decay'] 42 | momentum = group['momentum'] 43 | lr = group['lr'] 44 | 45 | trust_coef = group['trust_coef'] 46 | # print(group['name']) 47 | # eps = group['eps'] 48 | for p in group['params']: 49 | # breakpoint() 50 | if p.grad is None: 51 | continue 52 | global_lr = lr 53 | velocity = self.state[p].get('velocity', 0) 54 | # if name in self.exclude_from_layer_adaptation: 55 | if self._use_weight_decay(group): 56 | p.grad.data += weight_decay * p.data 57 | 58 | trust_ratio = 1.0 59 | if self._do_layer_adaptation(group): 60 | w_norm = torch.norm(p.data, p=2) 61 | g_norm = torch.norm(p.grad.data, p=2) 62 | trust_ratio = trust_coef * w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1.0 63 | scaled_lr = global_lr * trust_ratio # trust_ratio is the local_lr 64 | next_v = momentum * velocity + scaled_lr * p.grad.data 65 | update = next_v 66 | p.data = p.data - update 67 | 68 | 69 | def _use_weight_decay(self, group): 70 | return False if group['name'] == 'exclude' else True 71 | def _do_layer_adaptation(self, group): 72 | return False if group['name'] == 'exclude' else True 73 | 74 | def exclude_from_model(self, named_modules, exclude_bias_from_adaption=True): 75 | base = [] 76 | exclude = [] 77 | for name, module in named_modules: 78 | if type(module) in [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]: 79 | # if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) 80 | for name2, param in module.named_parameters(): 81 | exclude.append(param) 82 | else: 83 | for name2, param in module.named_parameters(): 84 | if name2 == 'bias': 85 | exclude.append(param) 86 | elif name2 == 'weight': 87 | base.append(param) 88 | else: 89 | pass # non leaf modules 90 | return [{ 91 | 'name': 'base', 92 | 'params': base 93 | },{ 94 | 'name': 'exclude', 95 | 'params': exclude 96 | }] if exclude_bias_from_adaption == True else [{ 97 | 'name': 'base', 98 | 'params': base+exclude 99 | }] 100 | 101 | if __name__ == "__main__": 102 | 103 | resnet = torchvision.models.resnet18(pretrained=False) 104 | model = resnet 105 | 106 | optimizer = LARS_simclr(model.named_modules(), lr=0.1) 107 | # print() 108 | # out = optimizer.exclude_from_model(model.named_modules(),exclude_bias_from_adaption=False) 109 | # print(len(out[0]['params'])) 110 | # exit() 111 | 112 | criterion = torch.nn.CrossEntropyLoss() 113 | for i in range(100): 114 | model.zero_grad() 115 | pred = model(torch.randn((2,3,32,32))) 116 | loss = pred.mean() 117 | loss.backward() 118 | optimizer.step() 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/optimizers/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | # from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class LR_Scheduler(object): 7 | def __init__(self, 
optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False): 8 | self.base_lr = base_lr 9 | self.constant_predictor_lr = constant_predictor_lr 10 | warmup_iter = iter_per_epoch * warmup_epochs 11 | warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter) 12 | decay_iter = iter_per_epoch * (num_epochs - warmup_epochs) 13 | cosine_lr_schedule = final_lr+0.5*(base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter)) 14 | 15 | self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) 16 | self.optimizer = optimizer 17 | self.iter = 0 18 | self.current_lr = 0 19 | def step(self): 20 | for param_group in self.optimizer.param_groups: 21 | 22 | if self.constant_predictor_lr and param_group['name'] == 'predictor': 23 | param_group['lr'] = self.base_lr 24 | else: 25 | lr = param_group['lr'] = self.lr_schedule[self.iter] 26 | 27 | self.iter += 1 28 | self.current_lr = lr 29 | return lr 30 | def get_lr(self): 31 | return self.current_lr 32 | 33 | if __name__ == "__main__": 34 | import torchvision 35 | model = torchvision.models.resnet50() 36 | optimizer = torch.optim.SGD(model.parameters(), lr=999) 37 | epochs = 100 38 | n_iter = 1000 39 | scheduler = LR_Scheduler(optimizer, 10, 1, epochs, 3, 0, n_iter) 40 | import matplotlib.pyplot as plt 41 | lrs = [] 42 | for epoch in range(epochs): 43 | for it in range(n_iter): 44 | lr = scheduler.step() 45 | lrs.append(lr) 46 | plt.plot(lrs) 47 | plt.show() -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .average_meter import AverageMeter 2 | from .accuracy import accuracy 3 | from .knn_monitor import knn_monitor 4 | from .logger import Logger 5 | from .file_exist_fn import file_exist_check -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/accuracy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/accuracy.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/accuracy.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/accuracy.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/average_meter.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/average_meter.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/average_meter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/average_meter.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/file_exist_fn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/file_exist_fn.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/file_exist_fn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/file_exist_fn.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/knn_monitor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/knn_monitor.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/knn_monitor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/knn_monitor.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/plotter.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/plotter.cpython-36.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/__pycache__/plotter.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/CS-MIL_Docker/src/tools/__pycache__/plotter.cpython-38.pyc -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch # required for torch.no_grad() below 2 | 3 | def accuracy(output, target, topk=(1,)): 4 | """Computes the accuracy over the k top predictions for the specified values of k""" 5 | with torch.no_grad(): 6 | maxk = max(topk) 7 | batch_size = target.size(0) 8 | 9 | _, pred = output.topk(maxk, 1, True, True) 10 | pred = pred.t() 11 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 12 | 13 | res = [] 14 | for k in topk: 15 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) # reshape, not view: the slice of the transposed tensor is non-contiguous 16 | res.append(correct_k.mul_(100.0 / batch_size)) 17 | return res 18 | 
 -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/average_meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(): 2 | """Computes and stores the average and current value""" 3 | def __init__(self, name, fmt=':f'): 4 | self.name = name 5 | self.fmt = fmt 6 | self.log = [] 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | 12 | def reset(self): 13 | self.log.append(self.avg) 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | def __str__(self): 26 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 27 | return fmtstr.format(**self.__dict__) 28 | 29 | if __name__ == "__main__": 30 | meter = AverageMeter('sldk') 31 | print(meter.log) 32 | 33 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/file_exist_fn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | 5 | def file_exist_check(file_dir): 6 | 7 | if os.path.isdir(file_dir): 8 | for i in range(2, 1000): 9 | if not os.path.isdir(file_dir + f'({i})'): 10 | file_dir += f'({i})' 11 | break 12 | return file_dir 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/knn_monitor.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch.nn.functional as F 3 | import torch 4 | # code copied from https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb#scrollTo=RI1Y8bSImD7N 5 | # test using a knn monitor 6 | def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1, hide_progress=False): 7 | net.eval() 8 | classes = len(memory_data_loader.dataset.classes) 9 | total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, [] 10 | with torch.no_grad(): 11 | # generate feature bank 12 | for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=hide_progress): 13 | feature = net(data.cuda(non_blocking=True)) 14 | feature = F.normalize(feature, dim=1) 15 | feature_bank.append(feature) 16 | # [D, N] 17 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() 18 | # [N] 19 | feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device) 20 | # loop test data to predict
the label by weighted knn search 21 | test_bar = tqdm(test_data_loader, desc='kNN', disable=hide_progress) 22 | for data, target in test_bar: 23 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 24 | feature = net(data) 25 | feature = F.normalize(feature, dim=1) 26 | 27 | pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t) 28 | 29 | total_num += data.size(0) 30 | total_top1 += (pred_labels[:, 0] == target).float().sum().item() 31 | test_bar.set_postfix({'Accuracy':total_top1 / total_num * 100}) 32 | return total_top1 / total_num * 100 33 | 34 | # knn monitor as in InstDisc https://arxiv.org/abs/1805.01978 35 | # implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR 36 | def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t): 37 | # compute cos similarity between each feature vector and feature bank ---> [B, N] 38 | sim_matrix = torch.mm(feature, feature_bank) 39 | # [B, K] 40 | sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) 41 | # [B, K] 42 | sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices) 43 | sim_weight = (sim_weight / knn_t).exp() 44 | 45 | # counts for each class 46 | one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device) 47 | # [B*K, C] 48 | one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0) 49 | # weighted score ---> [B, C] 50 | pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1) 51 | 52 | pred_labels = pred_scores.argsort(dim=-1, descending=True) 53 | return pred_labels 54 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/logger.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # try: 3 | # from torch.utils.tensorboard import SummaryWriter 4 | # except ImportError: 5 | from tensorboardX import SummaryWriter 6 | 7 | from torch import Tensor 8 | from collections import OrderedDict 9 | import os 10 | from .plotter import Plotter 11 | 12 | 13 | class Logger(object): 14 | def __init__(self, log_dir, tensorboard=True, matplotlib=True): 15 | 16 | self.reset(log_dir, tensorboard, matplotlib) 17 | 18 | def reset(self, log_dir=None, tensorboard=True, matplotlib=True): 19 | 20 | if log_dir is not None: self.log_dir=log_dir 21 | self.writer = SummaryWriter(log_dir=self.log_dir) if tensorboard else None 22 | self.plotter = Plotter() if matplotlib else None 23 | self.counter = OrderedDict() 24 | 25 | def update_scalers(self, ordered_dict): 26 | 27 | for key, value in ordered_dict.items(): 28 | if isinstance(value, Tensor): 29 | ordered_dict[key] = value.item() 30 | if self.counter.get(key) is None: 31 | self.counter[key] = 1 32 | else: 33 | self.counter[key] += 1 34 | 35 | if self.writer: 36 | self.writer.add_scalar(key, value, self.counter[key]) 37 | 38 | 39 | if self.plotter: 40 | self.plotter.update(ordered_dict) 41 | self.plotter.save(os.path.join(self.log_dir, 'plotter.svg')) 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /CS-MIL_Docker/src/tools/plotter.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 
#https://stackoverflow.com/questions/49921721/runtimeerror-main-thread-is-not-in-main-loop-with-matplotlib-and-flask 3 | import matplotlib.pyplot as plt 4 | from collections import OrderedDict 5 | from torch import Tensor 6 | 7 | class Plotter(object): 8 | def __init__(self): 9 | self.logger = OrderedDict() 10 | def update(self, ordered_dict): 11 | for key, value in ordered_dict.items(): 12 | if isinstance(value, Tensor): 13 | ordered_dict[key] = value.item() 14 | if self.logger.get(key) is None: 15 | self.logger[key] = [value] 16 | else: 17 | self.logger[key].append(value) 18 | 19 | def save(self, file, **kwargs): 20 | fig, axes = plt.subplots(nrows=len(self.logger), ncols=1, figsize=(8,2*len(self.logger))) 21 | fig.tight_layout() 22 | for ax, (key, value) in zip(axes, self.logger.items()): 23 | ax.plot(value) 24 | ax.set_title(key) 25 | 26 | plt.savefig(file, **kwargs) 27 | plt.close() 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /Cross-modality.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Cross-modality.png -------------------------------------------------------------------------------- /Cross-scale.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Cross-scale.png -------------------------------------------------------------------------------- /Cross_modality.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Cross_modality.png -------------------------------------------------------------------------------- /Emb_Clustering_Code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 PatrickHua 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Emb_Clustering_Code/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | 5 | import numpy as np 6 | import torch 7 | import random 8 | 9 | import re 10 | import yaml 11 | 12 | import shutil 13 | import warnings 14 | 15 | from datetime import datetime 16 | 17 | 18 | class Namespace(object): 19 | def __init__(self, somedict): 20 | for key, value in somedict.items(): 21 | assert isinstance(key, str) and re.match("[A-Za-z_-]", key) 22 | if isinstance(value, dict): 23 | self.__dict__[key] = Namespace(value) 24 | else: 25 | self.__dict__[key] = value 26 | 27 | def __getattr__(self, attribute): 28 | 29 | raise AttributeError(f"Cannot find {attribute} in namespace. Please add {attribute} to your config file (xxx.yaml)!") 30 | 31 | 32 | def set_deterministic(seed): 33 | # seed by default is None 34 | if seed is not None: 35 | print(f"Deterministic with seed = {seed}") 36 | random.seed(seed) 37 | np.random.seed(seed) 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed(seed) 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | 43 | def get_args(): 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('-c', '--config-file', required=True, type=str, help="xxx.yaml") 46 | parser.add_argument('--debug', action='store_true') 47 | parser.add_argument('--debug_subset_size', type=int, default=8) 48 | parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web") 49 | parser.add_argument('--data_dir', type=str, default=os.getenv('DATA')) 50 | parser.add_argument('--log_dir', type=str, default=os.getenv('LOG')) 51 | parser.add_argument('--ckpt_dir', type=str, default=os.getenv('CHECKPOINT')) 52 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 53 | parser.add_argument('--eval_from', type=str, default=None) 54 | parser.add_argument('--hide_progress', action='store_true') 55 | args = parser.parse_args() 56 | 57 | 58 | with open(args.config_file, 'r') as f: 59 | for key, value in Namespace(yaml.load(f, Loader=yaml.FullLoader)).__dict__.items(): 60 | vars(args)[key] = value 61 | 62 | if args.debug: 63 | if args.train: 64 | args.train.batch_size = 2 # shrink the batch size in debug mode (mirrors the eval setting below; the original line was a no-op) 65 | args.train.num_epochs = 1 66 | args.train.stop_at_epoch = 1 67 | if args.eval: 68 | args.eval.batch_size = 2 69 | args.eval.num_epochs = 1 # train only one epoch 70 | args.dataset.num_workers = 0 71 | 72 | 73 | assert None not in [args.log_dir, args.data_dir, args.ckpt_dir, args.name] 74 | 75 | args.log_dir = os.path.join(args.log_dir, 'in-progress_'+datetime.now().strftime('%m%d%H%M%S_')+args.name) 76 | 77 | os.makedirs(args.log_dir, exist_ok=False) 78 | print(f'creating log directory {args.log_dir}') 79 | os.makedirs(args.ckpt_dir, exist_ok=True) 80 | 81 | shutil.copy2(args.config_file, args.log_dir) 82 | set_deterministic(args.seed) 83 | 84 | 85 | vars(args)['aug_kwargs'] = { 86 | 'name':args.model.name, 87 | 'image_size': args.dataset.image_size 88 | } 89 | vars(args)['dataset_kwargs'] = { 90 | 'dataset':args.dataset.name, 91 | 'data_dir': args.data_dir, 92 | 'download':args.download, 93 | 'debug_subset_size': args.debug_subset_size if args.debug else None, 94 | } 95 | vars(args)['dataloader_kwargs'] = { 96 | 'drop_last': True, 97 | 'pin_memory': True, 98 | 'num_workers': args.dataset.num_workers, 99 | } 100 | 101 | return args 102 |
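103 | # Minimal invocation sketch using the flags defined in get_args() above; the entry script and config exist in this repo, but the data/log/checkpoint paths are placeholders: 104 | # python main_mixprecision.py -c configs/GCA512_0411_train4w.yaml --data_dir /path/to/patches --log_dir ./logs --ckpt_dir ./checkpoints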
-------------------------------------------------------------------------------- /Emb_Clustering_Code/configs/GCA512_0411_train4w.yaml: -------------------------------------------------------------------------------- 1 | name: simsiam-GCA-b6128s256_ori1024patch 2 | dataset: 3 | name: random 4 | image_size: 256 5 | num_workers: 8 6 | 7 | model: 8 | name: simsiam 9 | backbone: resnet50 10 | proj_layers: 2 11 | 12 | train: 13 | optimizer: 14 | name: sgd 15 | weight_decay: 0.0001 16 | momentum: 0.9 17 | warmup_epochs: 10 18 | warmup_lr: 0 19 | base_lr: 0.05 20 | final_lr: 0 21 | num_epochs: 800 # this parameter influence the lr decay 22 | stop_at_epoch: 200 # has to be smaller than num_epochs 23 | batch_size: 128 24 | save_interval: 1 25 | knn_monitor: False # knn monitor will take more time 26 | knn_interval: 1 27 | knn_k: 200 28 | eval: # linear evaluation, False will turn off automatic evaluation after training 29 | optimizer: 30 | name: sgd 31 | weight_decay: 0 32 | momentum: 0.9 33 | warmup_lr: 0 34 | warmup_epochs: 0 35 | base_lr: 30 36 | final_lr: 0 37 | batch_size: 128 38 | num_epochs: 200 39 | 40 | logger: 41 | tensorboard: True 42 | matplotlib: True 43 | 44 | seed: null # None type for yaml file 45 | # two things might lead to stochastic behavior other than seed: 46 | # worker_init_fn from dataloader and torch.nn.functional.interpolate 47 | # (keep this in mind if you want to achieve 100% deterministic) 48 | -------------------------------------------------------------------------------- /Emb_Clustering_Code/create_kmeans_features_local_multi-resolution.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import glob 6 | import matplotlib.pyplot as plt 7 | 8 | severe = 3 9 | 10 | if severe == 0: 11 | feature_dir = '/Data2/GCA/simsiam/feature_data_multi-resolution/' 12 | csv_dir = '/Data2/GCA/simsiam/feature_cluster_multi-resolution_csv_local' 13 | label_file = pd.read_csv('/Data2/GCA/simsiam/data_list.csv') 14 | elif severe == 1: 15 | feature_dir = '/Data2/GCA/simsiam/feature_data_multi-resolution_severe/' 16 | csv_dir = '/Data2/GCA/simsiam/feature_cluster_multi-resolution_csv_local_severe' 17 | label_file = pd.read_csv('/Data2/GCA/data_list_severe.csv') 18 | elif severe == 2: 19 | feature_dir = '/Data2/GCA/simsiam/feature_data_multi-resolution_severe/' 20 | csv_dir = '/Data2/GCA/simsiam/feature_cluster_multi-resolution_csv_local_CD' 21 | label_file = pd.read_csv('/Data2/GCA/data_list_CD.csv') 22 | else: 23 | feature_dir = '/Data2/GCA/simsiam/feature_data_multi-resolution_severe_extend/' 24 | csv_dir = '/Data2/GCA/simsiam/feature_cluster_multi-resolution_csv_local_severe_extend' 25 | label_file = pd.read_csv('/Data2/GCA/data_list_severe_extend.csv') 26 | 27 | cases = glob.glob(os.path.join(feature_dir, "*")) 28 | 29 | mapping = 0 30 | 31 | for now_case in cases: 32 | # for now_case in cases[:5]: 33 | print(now_case) 34 | now_wsi = os.path.basename(now_case) 35 | csv_folder = csv_dir 36 | 37 | if len(label_file[label_file['filename'] == now_wsi + '.svs']['class'].tolist()) == 0: 38 | continue 39 | 40 | now_label = label_file[label_file['filename'] == now_wsi + '.svs']['class'].tolist()[0] 41 | 42 | if not os.path.exists(csv_folder): 43 | os.makedirs(csv_folder) 44 | 45 | feature_files = glob.glob(os.path.join(now_case,"*")) 46 | 47 | features = np.zeros((len(feature_files), 2048)) 48 | 49 | for ii in range(len(feature_files)): 50 | features[ii] = 
np.load(feature_files[ii]) 51 | 52 | local_features = features 53 | local_features_file = feature_files 54 | # if len(local_features) == 0: 55 | # local_features = features 56 | # local_features_file = feature_files 57 | # else: 58 | # local_features = np.concatenate((local_features, features), axis = 0) 59 | # local_features_file.extend((feature_files)) 60 | 61 | 62 | kmeans = KMeans(n_clusters=8, random_state=0).fit(local_features) 63 | df = pd.DataFrame(columns=['wsi_name', 'root', 'label', 'cluster']) 64 | for ii in range(len(local_features)): 65 | now_root = local_features_file[ii] 66 | now_wsi = os.path.basename(now_root).split('_')[0] 67 | now_label = label_file[label_file['filename'] == now_wsi + '.svs']['class'].tolist()[0] 68 | 69 | now_cluster = kmeans.labels_[ii] 70 | row = len(df) 71 | df.loc[row] = [now_wsi, now_root, now_label, now_cluster] 72 | 73 | 74 | now_wsi = os.path.basename(now_case) 75 | csv_folder = csv_dir 76 | if len(label_file[label_file['filename'] == now_wsi + '.svs']['class'].tolist()) == 0: 77 | continue 78 | 79 | now_label = label_file[label_file['filename'] == now_wsi + '.svs']['class'].tolist()[0] 80 | 81 | save_root = os.path.join(csv_folder, '%s_cluster.csv' % (now_wsi)) 82 | 83 | now_df = df[df['wsi_name'] == now_wsi] 84 | now_df.to_csv(save_root, index = False) 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /Emb_Clustering_Code/create_kmeans_features_local_singleresolution.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import glob 6 | import matplotlib.pyplot as plt 7 | 8 | severe_list = [3] 9 | resolution_list = [256,512,1024] 10 | 11 | for sl in range(len(severe_list)): 12 | severe = severe_list[sl] 13 | for rl in range(len(resolution_list)): 14 | resolution = resolution_list[rl] 15 | 16 | print('now is : severe %d, resolution %d' % (severe, resolution)) 17 | 18 | if severe == 0: 19 | feature_dir = '/Data2/GCA/simsiam/feature_data_multi-resolution/' 20 | csv_dir = '/Data2/GCA/simsiam/feature_cluster_multi-resolution_csv_local_oneclustering_only%d' % (resolution) 21 | label_file = pd.read_csv('/Data2/GCA/simsiam/data_list.csv') 22 | elif severe == 1: 23 | feature_dir = '/Data2/GCA/simsiam/feature_data_multi-resolution_severe/' 24 | csv_dir = '/Data2/GCA/simsiam/feature_cluster_multi-resolution_csv_local_severe_oneclustering_only%d' % (resolution) 25 | label_file = pd.read_csv('/Data2/GCA/data_list_severe.csv') 26 | elif severe == 2: 27 | feature_dir = '/Data2/GCA/simsiam/feature_data_multi-resolution_severe/' 28 | csv_dir = '/Data2/GCA/simsiam/feature_cluster_multi-resolution_csv_local_CD_oneclustering_only%d' % (resolution) 29 | label_file = pd.read_csv('/Data2/GCA/data_list_CD.csv') 30 | else: 31 | feature_dir = '/Data2/GCA/simsiam/feature_data_multi-resolution_severe_extend/' 32 | csv_dir = '/Data2/GCA/simsiam/feature_cluster_multi-resolution_csv_local_severe_extend_oneclustering_only%d' % (resolution) 33 | label_file = pd.read_csv('/Data2/GCA/data_list_severe_extend.csv') 34 | 35 | cases = glob.glob(os.path.join(feature_dir, "*")) 36 | 37 | mapping = 0 38 | 39 | 40 | for now_case in cases: 41 | # for now_case in cases[:5]: 42 | print(now_case) 43 | 44 | now_wsi = os.path.basename(now_case) 45 | csv_folder = csv_dir 46 | 47 | if len(label_file[label_file['filename'] == now_wsi + '.svs']['class'].tolist()) == 0: 48 | continue 49 | 50 | now_label 
= label_file[label_file['filename'] == now_wsi + '.svs']['class'].tolist()[0] 51 | 52 | if not os.path.exists(csv_folder): 53 | os.makedirs(csv_folder) 54 | 55 | feature_files = glob.glob(os.path.join(now_case,"*size%d.npy" % (resolution))) 56 | 57 | features = np.zeros((len(feature_files), 2048)) 58 | 59 | for ii in range(len(feature_files)): 60 | features[ii] = np.load(feature_files[ii]) 61 | 62 | local_features = features 63 | local_features_file = feature_files 64 | 65 | kmeans = KMeans(n_clusters=8, random_state=0).fit(local_features) 66 | df = pd.DataFrame(columns=['wsi_name', 'root', 'label', 'cluster']) 67 | for ii in range(len(local_features)): 68 | now_root = local_features_file[ii] 69 | now_wsi = os.path.basename(now_root).split('_')[0] 70 | now_label = label_file[label_file['filename'] == now_wsi + '.svs']['class'].tolist()[0] 71 | 72 | now_cluster = kmeans.labels_[ii] 73 | row = len(df) 74 | df.loc[row] = [now_wsi, now_root, now_label, now_cluster] 75 | 76 | 77 | now_wsi = os.path.basename(now_case) 78 | csv_folder = csv_dir 79 | if len(label_file[label_file['filename'] == now_wsi + '.svs']['class'].tolist()) == 0: 80 | continue 81 | 82 | now_label = label_file[label_file['filename'] == now_wsi + '.svs']['class'].tolist()[0] 83 | 84 | save_root = os.path.join(csv_folder, '%s_cluster.csv' % (now_wsi)) 85 | 86 | now_df = df[df['wsi_name'] == now_wsi] 87 | now_df.to_csv(save_root, index = False) 88 | 89 | 90 | -------------------------------------------------------------------------------- /Emb_Clustering_Code/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from .random_dataset import RandomDataset 4 | 5 | 6 | def get_dataset(dataset, data_dir, transform, train=True, download=False, debug_subset_size=None): 7 | if dataset == 'mnist': 8 | dataset = torchvision.datasets.MNIST(data_dir, train=train, transform=transform, download=download) 9 | elif dataset == 'stl10': 10 | dataset = torchvision.datasets.STL10(data_dir, split='train+unlabeled' if train else 'test', transform=transform, download=download) 11 | elif dataset == 'cifar10': 12 | dataset = torchvision.datasets.CIFAR10(data_dir, train=train, transform=transform, download=download) 13 | elif dataset == 'cifar100': 14 | dataset = torchvision.datasets.CIFAR100(data_dir, train=train, transform=transform, download=download) 15 | elif dataset == 'imagenet': 16 | dataset = torchvision.datasets.ImageNet(data_dir, split='train' if train == True else 'val', transform=transform, download=download) 17 | elif dataset == 'random': 18 | dataset = RandomDataset() 19 | else: 20 | raise NotImplementedError 21 | 22 | if debug_subset_size is not None: 23 | dataset = torch.utils.data.Subset(dataset, range(0, debug_subset_size)) # take only one batch 24 | dataset.classes = dataset.dataset.classes 25 | dataset.targets = dataset.dataset.targets 26 | return dataset -------------------------------------------------------------------------------- /Emb_Clustering_Code/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /Emb_Clustering_Code/datasets/__pycache__/__init__.cpython-38.pyc: 
/Emb_Clustering_Code/datasets/__init__.py: --------------------------------------------------------------------------------
1 | import torch 2 | import torchvision 3 | from .random_dataset import RandomDataset 4 | 5 | 6 | def get_dataset(dataset, data_dir, transform, train=True, download=False, debug_subset_size=None): 7 | if dataset == 'mnist': 8 | dataset = torchvision.datasets.MNIST(data_dir, train=train, transform=transform, download=download) 9 | elif dataset == 'stl10': 10 | dataset = torchvision.datasets.STL10(data_dir, split='train+unlabeled' if train else 'test', transform=transform, download=download) 11 | elif dataset == 'cifar10': 12 | dataset = torchvision.datasets.CIFAR10(data_dir, train=train, transform=transform, download=download) 13 | elif dataset == 'cifar100': 14 | dataset = torchvision.datasets.CIFAR100(data_dir, train=train, transform=transform, download=download) 15 | elif dataset == 'imagenet': 16 | dataset = torchvision.datasets.ImageNet(data_dir, split='train' if train else 'val', transform=transform, download=download) 17 | elif dataset == 'random': 18 | dataset = RandomDataset() 19 | else: 20 | raise NotImplementedError 21 | 22 | if debug_subset_size is not None: 23 | dataset = torch.utils.data.Subset(dataset, range(0, debug_subset_size)) # take only one batch 24 | dataset.classes = dataset.dataset.classes 25 | dataset.targets = dataset.dataset.targets 26 | return dataset
-------------------------------------------------------------------------------- /Emb_Clustering_Code/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/datasets/__pycache__/__init__.cpython-36.pyc
-------------------------------------------------------------------------------- /Emb_Clustering_Code/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/datasets/__pycache__/__init__.cpython-38.pyc
-------------------------------------------------------------------------------- /Emb_Clustering_Code/datasets/__pycache__/random_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/datasets/__pycache__/random_dataset.cpython-36.pyc
-------------------------------------------------------------------------------- /Emb_Clustering_Code/datasets/__pycache__/random_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/datasets/__pycache__/random_dataset.cpython-38.pyc
-------------------------------------------------------------------------------- /Emb_Clustering_Code/datasets/random_dataset.py: --------------------------------------------------------------------------------
1 | import torch 2 | 3 | class RandomDataset(torch.utils.data.Dataset): 4 | def __init__(self, root=None, train=True, transform=None, target_transform=None): 5 | self.transform = transform 6 | self.target_transform = target_transform 7 | 8 | self.size = 1000 9 | 10 | def __getitem__(self, idx): 11 | if idx < self.size: 12 | # return [torch.randn((3, 224, 224)), torch.randn((3, 224, 224))], [0,0,0] 13 | return [torch.randn((3, 128, 128)), torch.randn((3, 128, 128))], [0,0,0] 14 | else: 15 | raise IndexError(idx) 16 | 17 | def __len__(self): 18 | return self.size 19 |
-------------------------------------------------------------------------------- /Emb_Clustering_Code/get_features_simsiam_1024.py: --------------------------------------------------------------------------------
1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from tqdm import tqdm 7 | from arguments import get_args 8 | from augmentations import get_aug 9 | from models import get_model, get_backbone 10 | from tools import AverageMeter 11 | from datasets import get_dataset 12 | from optimizers import get_optimizer, LR_Scheduler 13 | import glob 14 | import matplotlib.pyplot as plt 15 | import torchvision.transforms as T 16 | import pandas as pd 17 | import numpy as np 18 | import random 19 | 20 | def main(args): 21 | tensor_transform = T.Compose([ 22 | T.Resize([256,]), 23 | # T.ToTensor(), 24 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 25 | ]) 26 | 27 | model = get_backbone(args.model.backbone) 28 | 29 | assert args.eval_from is not None 30 | save_dict = torch.load(args.eval_from, map_location='cpu') 31 | msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True) # keep only backbone weights, stripping the 'backbone.' prefix 32 | 33 | # print(msg) 34 | model = model.to(args.device) 35 | model = torch.nn.DataParallel(model) 36 | 37 | model.eval() 38 | 39 | label_file = pd.read_csv('/Data2/GCA/data_list_severe.csv') 40 | features_output = '/Data2/GCA/simsiam/feature_data_severe_1024/' 41 | data_dir = '/Data2/GCA/GCA_Original_Series_patch_1024_0407' 42 | 43 | images = [] 44 | 45 | for row in range(len(label_file)): 46 | now_case = label_file.iloc[row]['filename'].replace('.svs','') 47 | now_images = glob.glob(os.path.join(data_dir,
now_case, '*')) 48 | images.extend((now_images)) 49 | 50 | random.shuffle(images) 51 | patch_size = 1024 52 | batch = 16 #int(args.eval.batch_size) 53 | bag_num = int(len(images) / batch) + 1 54 | 55 | if not os.path.exists(features_output): 56 | os.makedirs(features_output) 57 | 58 | for ri in range(bag_num): 59 | if ri != bag_num - 1: 60 | now_images = images[ri * batch : (ri + 1) * batch] 61 | else: 62 | now_images = images[ri * batch:] 63 | 64 | tensor = np.zeros((len(now_images), patch_size , patch_size, 3)) 65 | for ni in range(len(now_images)): 66 | image_folder = now_images[ni] 67 | tensor[ni] = plt.imread(image_folder)[:,:,:3] 68 | 69 | # tensor = tensor.transpose([0,3,1,2]) 70 | tensor = torch.from_numpy(tensor).permute([0,3,1,2]) 71 | inputs = tensor_transform(tensor) 72 | features = model(inputs.to(args.device).float()) 73 | 74 | for fi in range(len(features)): 75 | now_name = os.path.basename(now_images[fi]) 76 | wsi_name = now_name.split('_')[0] 77 | now_feature = features[fi].detach().cpu().numpy() 78 | save_root = os.path.join(features_output, wsi_name) 79 | 80 | if not os.path.exists(save_root): 81 | os.makedirs(save_root) 82 | 83 | save_dir = os.path.join(save_root, now_name.replace('.png', '.npy')) 84 | np.save(save_dir, now_feature) 85 | 86 | 87 | if __name__ == "__main__": 88 | main(args=get_args()) 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /Emb_Clustering_Code/get_features_simsiam_256.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from tqdm import tqdm 7 | from arguments import get_args 8 | from augmentations import get_aug 9 | from models import get_model, get_backbone 10 | from tools import AverageMeter 11 | from datasets import get_dataset 12 | from optimizers import get_optimizer, LR_Scheduler 13 | import glob 14 | import matplotlib.pyplot as plt 15 | import torchvision.transforms as T 16 | import pandas as pd 17 | import numpy as np 18 | import random 19 | 20 | def main(args): 21 | tensor_transform = T.Compose([ 22 | T.Resize([256,]), 23 | # T.ToTensor(), 24 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 25 | ]) 26 | 27 | model = get_backbone(args.model.backbone) 28 | 29 | assert args.eval_from is not None 30 | save_dict = torch.load(args.eval_from, map_location='cpu') 31 | msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True) 32 | 33 | # print(msg) 34 | model = model.to(args.device) 35 | model = torch.nn.DataParallel(model) 36 | 37 | model.eval() 38 | 39 | label_file = pd.read_csv('/Data2/GCA/simsiam/data_list.csv') 40 | features_output = '/Data2/GCA/simsiam/feature_data_1024/' 41 | data_dir = '/Data2/GCA/GCA_Original_Series_patch_1024_0407' 42 | 43 | images = [] 44 | 45 | for row in range(len(label_file)): 46 | now_case = label_file.iloc[row]['filename'].replace('.svs','') 47 | now_images = glob.glob(os.path.join(data_dir, now_case, '*')) 48 | images.extend((now_images)) 49 | 50 | random.shuffle(images) 51 | patch_size = 1024 52 | batch = 32 #int(args.eval.batch_size) 53 | bag_num = int(len(images) / batch) + 1 54 | 55 | if not os.path.exists(features_output): 56 | os.makedirs(features_output) 57 | 58 | for ri in range(bag_num): 59 | if ri != bag_num - 1: 60 | now_images = images[ri * batch : (ri 
+ 1) * batch] 61 | else: 62 | now_images = images[ri * batch:] 63 | 64 | tensor = np.zeros((len(now_images), patch_size , patch_size, 3)) 65 | for ni in range(len(now_images)): 66 | image_folder = now_images[ni] 67 | tensor[ni] = plt.imread(image_folder)[:,:,:3] 68 | 69 | # tensor = tensor.transpose([0,3,1,2]) 70 | tensor = torch.from_numpy(tensor).permute([0,3,1,2]) 71 | inputs = tensor_transform(tensor) 72 | features = model(inputs.to(args.device).float()) 73 | 74 | for fi in range(len(features)): 75 | now_name = os.path.basename(now_images[fi]) 76 | wsi_name = now_name.split('_')[0] 77 | now_feature = features[fi].detach().cpu().numpy() 78 | save_root = os.path.join(features_output, wsi_name) 79 | 80 | if not os.path.exists(save_root): 81 | os.makedirs(save_root) 82 | 83 | save_dir = os.path.join(save_root, now_name.replace('.png', '.npy')) 84 | np.save(save_dir, now_feature) 85 | 86 | 87 | if __name__ == "__main__": 88 | main(args=get_args()) 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /Emb_Clustering_Code/get_features_simsiam_512.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from tqdm import tqdm 7 | from arguments import get_args 8 | from augmentations import get_aug 9 | from models import get_model, get_backbone 10 | from tools import AverageMeter 11 | from datasets import get_dataset 12 | from optimizers import get_optimizer, LR_Scheduler 13 | import glob 14 | import matplotlib.pyplot as plt 15 | import torchvision.transforms as T 16 | import pandas as pd 17 | import numpy as np 18 | import random 19 | 20 | def main(args): 21 | tensor_transform = T.Compose([ 22 | T.Resize([256,]), 23 | # T.ToTensor(), 24 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 25 | ]) 26 | 27 | model = get_backbone(args.model.backbone) 28 | 29 | assert args.eval_from is not None 30 | save_dict = torch.load(args.eval_from, map_location='cpu') 31 | msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True) 32 | 33 | # print(msg) 34 | model = model.to(args.device) 35 | model = torch.nn.DataParallel(model) 36 | 37 | model.eval() 38 | 39 | label_file = pd.read_csv('/Data2/GCA/simsiam/data_list.csv') 40 | features_output = '/Data2/GCA/simsiam/feature_data_512/' 41 | data_dir = '/Data2/GCA/GCA_Original_Series_patch_512_0407' 42 | 43 | images = [] 44 | 45 | for row in range(len(label_file)): 46 | now_case = label_file.iloc[row]['filename'].replace('.svs','') 47 | now_images = glob.glob(os.path.join(data_dir, now_case, '*')) 48 | images.extend((now_images)) 49 | 50 | random.shuffle(images) 51 | patch_size = 512 52 | batch = 16 #int(args.eval.batch_size) 53 | bag_num = int(len(images) / batch) + 1 54 | 55 | if not os.path.exists(features_output): 56 | os.makedirs(features_output) 57 | 58 | for ri in range(bag_num): 59 | if ri != bag_num - 1: 60 | now_images = images[ri * batch : (ri + 1) * batch] 61 | else: 62 | now_images = images[ri * batch:] 63 | 64 | tensor = np.zeros((len(now_images), patch_size , patch_size, 3)) 65 | for ni in range(len(now_images)): 66 | image_folder = now_images[ni] 67 | tensor[ni] = plt.imread(image_folder)[:,:,:3] 68 | 69 | # tensor = tensor.transpose([0,3,1,2]) 70 | tensor = torch.from_numpy(tensor).permute([0,3,1,2]) 71 | inputs = 
tensor_transform(tensor) 72 | features = model(inputs.to(args.device).float()) 73 | 74 | for fi in range(len(features)): 75 | now_name = os.path.basename(now_images[fi]) 76 | wsi_name = now_name.split('_')[0] 77 | now_feature = features[fi].detach().cpu().numpy() 78 | save_root = os.path.join(features_output, wsi_name) 79 | 80 | if not os.path.exists(save_root): 81 | os.makedirs(save_root) 82 | 83 | save_dir = os.path.join(save_root, now_name.replace('.png', '.npy')) 84 | np.save(save_dir, now_feature) 85 | 86 | 87 | if __name__ == "__main__": 88 | main(args=get_args()) 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /Emb_Clustering_Code/get_features_simsiam_multi-resolution.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from tqdm import tqdm 7 | from arguments import get_args 8 | from augmentations import get_aug 9 | from models import get_model, get_backbone 10 | from tools import AverageMeter 11 | from datasets import get_dataset 12 | from optimizers import get_optimizer, LR_Scheduler 13 | import glob 14 | import matplotlib.pyplot as plt 15 | import torchvision.transforms as T 16 | import pandas as pd 17 | import numpy as np 18 | import random 19 | 20 | def main(args): 21 | tensor_transform = T.Compose([ 22 | T.Resize([256,]), 23 | # T.ToTensor(), 24 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 25 | ]) 26 | 27 | resolution_list = [256,512,1024] 28 | for rrri in range(len(resolution_list)): 29 | 30 | resolution = resolution_list[rrri] 31 | 32 | if resolution == 1024: 33 | args.eval_from = 'checkpoint/simsiam-GCA-b6128s256_ori1024patch_0414131202.pth' 34 | elif resolution == 512: 35 | args.eval_from = 'checkpoint/simsiam-GCA-b6128s256_ori512patch_0418160825.pth' 36 | else: 37 | args.eval_from = 'checkpoint/simsiam-GCA-b6128s256_ori256patch_0422134301.pth' 38 | 39 | severe = 2 40 | 41 | if severe == 0: 42 | label_file = pd.read_csv('/Data2/GCA/simsiam/data_list.csv') 43 | features_output = '/Data2/GCA/simsiam/feature_data_multi-resolution/' 44 | data_dir = '/media/dengr/Seagate Backup Plus Drive/GCA_multi-resolution/GCA_Original_Series_patch_multi-resolution_0421' 45 | 46 | elif severe == 1: 47 | label_file = pd.read_csv('/Data2/GCA/data_list_severe.csv') 48 | features_output = '/Data2/GCA/simsiam/feature_data_multi-resolution_severe/' 49 | data_dir = '/media/dengr/Seagate Backup Plus Drive/GCA_multi-resolution/GCA_Original_Series_patch_multi-resolution_0421' 50 | else: 51 | label_file = pd.read_csv('/Data2/GCA/data_list_severe_extend.csv') 52 | features_output = '/Data2/GCA/simsiam/feature_data_multi-resolution_severe_extend/' 53 | data_dir = '/media/dengr/Seagate Backup Plus Drive/GCA_multi-resolution/GCA_Original_Series_patch_multi-resolution_extend_0621' 54 | 55 | images = [] 56 | 57 | model = get_backbone(args.model.backbone) 58 | 59 | assert args.eval_from is not None 60 | save_dict = torch.load(args.eval_from, map_location='cpu') 61 | msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True) 62 | 63 | # print(msg) 64 | model = model.to(args.device) 65 | model = torch.nn.DataParallel(model) 66 | 67 | model.eval() 68 | 69 | for row in range(len(label_file)): 70 | now_case = label_file.iloc[row]['filename'].replace('.svs','') 71 
| now_images = glob.glob(os.path.join(data_dir, now_case, '*size%d.png' % (resolution))) 72 | images.extend((now_images)) 73 | 74 | random.shuffle(images) 75 | patch_size = resolution 76 | batch = 32 #int(args.eval.batch_size) 77 | bag_num = int(len(images) / batch) + 1 78 | 79 | if not os.path.exists(features_output): 80 | os.makedirs(features_output) 81 | 82 | for ri in range(bag_num): 83 | if ri != bag_num - 1: 84 | now_images = images[ri * batch : (ri + 1) * batch] 85 | else: 86 | now_images = images[ri * batch:] 87 | 88 | tensor = np.zeros((len(now_images), patch_size , patch_size, 3)) 89 | for ni in range(len(now_images)): 90 | image_folder = now_images[ni] 91 | tensor[ni] = plt.imread(image_folder)[:,:,:3] 92 | 93 | # tensor = tensor.transpose([0,3,1,2]) 94 | tensor = torch.from_numpy(tensor).permute([0,3,1,2]) 95 | inputs = tensor_transform(tensor) 96 | features = model(inputs.to(args.device).float()) 97 | 98 | for fi in range(len(features)): 99 | now_name = os.path.basename(now_images[fi]) 100 | wsi_name = now_name.split('_')[0] 101 | now_feature = features[fi].detach().cpu().numpy() 102 | save_root = os.path.join(features_output, wsi_name) 103 | 104 | if not os.path.exists(save_root): 105 | os.makedirs(save_root) 106 | 107 | save_dir = os.path.join(save_root, now_name.replace('.png', '.npy')) 108 | np.save(save_dir, now_feature) 109 | 110 | 111 | if __name__ == "__main__": 112 | main(args=get_args()) 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .lars import LARS 2 | from .lars_simclr import LARS_simclr 3 | from .larc import LARC 4 | import torch 5 | from .lr_scheduler import LR_Scheduler 6 | 7 | 8 | def get_optimizer(name, model, lr, momentum, weight_decay): 9 | 10 | predictor_prefix = ('module.predictor', 'predictor') 11 | parameters = [{ 12 | 'name': 'base', 13 | 'params': [param for name, param in model.named_parameters() if not name.startswith(predictor_prefix)], 14 | 'lr': lr 15 | },{ 16 | 'name': 'predictor', 17 | 'params': [param for name, param in model.named_parameters() if name.startswith(predictor_prefix)], 18 | 'lr': lr 19 | }] 20 | if name == 'lars': 21 | optimizer = LARS(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay) 22 | elif name == 'sgd': 23 | optimizer = torch.optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay) 24 | elif name == 'lars_simclr': # Careful 25 | optimizer = LARS_simclr(model.named_modules(), lr=lr, momentum=momentum, weight_decay=weight_decay) 26 | elif name == 'larc': 27 | optimizer = LARC( 28 | torch.optim.SGD( 29 | parameters, 30 | lr=lr, 31 | momentum=momentum, 32 | weight_decay=weight_decay 33 | ), 34 | trust_coefficient=0.001, 35 | clip=False 36 | ) 37 | else: 38 | raise NotImplementedError 39 | return optimizer 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/optimizers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- 
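As a usage note for the optimizers package above, here is a minimal sketch of wiring get_optimizer and LR_Scheduler together (the model choice and every hyper-parameter below are placeholders, not values from this repository):

import torchvision
from optimizers import get_optimizer, LR_Scheduler

model = torchvision.models.resnet50()
optimizer = get_optimizer('sgd', model, lr=0.05, momentum=0.9, weight_decay=1e-4)
# linear warmup for 10 epochs, then cosine decay to final_lr; 100 iterations per epoch
scheduler = LR_Scheduler(optimizer, warmup_epochs=10, warmup_lr=0.0, num_epochs=100,
                         base_lr=0.05, final_lr=0.0, iter_per_epoch=100)
for it in range(10):
    lr = scheduler.step()   # call once per training iteration

Note that get_optimizer puts parameters whose names start with 'predictor' into their own group, so LR_Scheduler(..., constant_predictor_lr=True) can keep a SimSiam-style predictor at the base learning rate.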
/Emb_Clustering_Code/optimizers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/optimizers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/__pycache__/larc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/optimizers/__pycache__/larc.cpython-36.pyc -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/__pycache__/larc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/optimizers/__pycache__/larc.cpython-38.pyc -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/__pycache__/lars.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/optimizers/__pycache__/lars.cpython-36.pyc -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/__pycache__/lars.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/optimizers/__pycache__/lars.cpython-38.pyc -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/__pycache__/lars_simclr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/optimizers/__pycache__/lars_simclr.cpython-36.pyc -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/__pycache__/lars_simclr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/optimizers/__pycache__/lars_simclr.cpython-38.pyc -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/__pycache__/lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/optimizers/__pycache__/lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Emb_Clustering_Code/optimizers/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/larc.py: 
-------------------------------------------------------------------------------- 1 | """SwAV uses LARC instead of the LARS optimizer""" 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.parameter import Parameter 6 | from torch.optim.optimizer import Optimizer 7 | 8 | def main(): # Example 9 | import torchvision 10 | model = torchvision.models.resnet18(pretrained=False) 11 | # optim = torch.optim.Adam(model.parameters(), lr=0.0001) 12 | optim = torch.optim.SGD(model.parameters(),lr=0.2, momentum=0.9, weight_decay=1.5e-6) 13 | optim = LARC(optim) 14 | 15 | class LARC(Optimizer): 16 | """ 17 | :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, 18 | in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive 19 | local learning rate for each individual parameter. The algorithm is designed to improve 20 | convergence of large batch training. 21 | 22 | See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. 23 | In practice it modifies the gradients of parameters as a proxy for modifying the learning rate 24 | of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. 25 | ``` 26 | model = ... 27 | optim = torch.optim.Adam(model.parameters(), lr=...) 28 | optim = LARC(optim) 29 | ``` 30 | It can even be used in conjunction with apex.fp16_utils.FP16_optimizer. 31 | ``` 32 | model = ... 33 | optim = torch.optim.Adam(model.parameters(), lr=...) 34 | optim = LARC(optim) 35 | optim = apex.fp16_utils.FP16_Optimizer(optim) 36 | ``` 37 | Args: 38 | optimizer: Pytorch optimizer to wrap and modify learning rate for. 39 | trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 40 | clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
41 | eps: epsilon kludge to help with numerical stability while calculating adaptive_lr 42 | """ 43 | 44 | def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8): 45 | self.optim = optimizer 46 | self.trust_coefficient = trust_coefficient 47 | self.eps = eps 48 | self.clip = clip 49 | 50 | def __getstate__(self): 51 | return self.optim.__getstate__() 52 | 53 | def __setstate__(self, state): 54 | self.optim.__setstate__(state) 55 | 56 | @property 57 | def state(self): 58 | return self.optim.state 59 | 60 | def __repr__(self): 61 | return self.optim.__repr__() 62 | 63 | @property 64 | def param_groups(self): 65 | return self.optim.param_groups 66 | 67 | @param_groups.setter 68 | def param_groups(self, value): 69 | self.optim.param_groups = value 70 | 71 | def state_dict(self): 72 | return self.optim.state_dict() 73 | 74 | def load_state_dict(self, state_dict): 75 | self.optim.load_state_dict(state_dict) 76 | 77 | def zero_grad(self): 78 | self.optim.zero_grad() 79 | 80 | def add_param_group(self, param_group): 81 | self.optim.add_param_group( param_group) 82 | 83 | def step(self): 84 | with torch.no_grad(): 85 | weight_decays = [] 86 | for group in self.optim.param_groups: 87 | # absorb weight decay control from optimizer 88 | weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 89 | weight_decays.append(weight_decay) 90 | group['weight_decay'] = 0 91 | for p in group['params']: 92 | if p.grad is None: 93 | continue 94 | param_norm = torch.norm(p.data) 95 | grad_norm = torch.norm(p.grad.data) 96 | 97 | if param_norm != 0 and grad_norm != 0: 98 | # calculate adaptive lr + weight decay 99 | adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps) 100 | 101 | # clip learning rate for LARC 102 | if self.clip: 103 | # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` 104 | adaptive_lr = min(adaptive_lr/group['lr'], 1) 105 | 106 | p.grad.data += weight_decay * p.data 107 | p.grad.data *= adaptive_lr 108 | 109 | self.optim.step() 110 | # return weight decay control to optimizer 111 | for i, group in enumerate(self.optim.param_groups): 112 | group['weight_decay'] = weight_decays[i] 113 | 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/lars.py: -------------------------------------------------------------------------------- 1 | """ Layer-wise adaptive rate scaling for SGD in PyTorch! """ 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class LARS(Optimizer): 6 | r"""Implements layer-wise adaptive rate scaling for SGD. 7 | 8 | Args: 9 | params (iterable): iterable of parameters to optimize or dicts defining 10 | parameter groups 11 | lr (float): base learning rate (\gamma_0) 12 | momentum (float, optional): momentum factor (default: 0) ("m") 13 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 14 | ("\beta") 15 | eta (float, optional): LARS coefficient 16 | max_epoch: maximum training epoch to determine polynomial LR decay. 17 | 18 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 
19 | Large Batch Training of Convolutional Networks: 20 | https://arxiv.org/abs/1708.03888 21 | 22 | Example: 23 | >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3) 24 | >>> optimizer.zero_grad() 25 | >>> loss_fn(model(input), target).backward() 26 | >>> optimizer.step() 27 | """ 28 | def __init__(self, params, lr=required, momentum=.9, 29 | weight_decay=.0005, eta=0.001, max_epoch=200): 30 | if lr is not required and lr < 0.0: 31 | raise ValueError("Invalid learning rate: {}".format(lr)) 32 | if momentum < 0.0: 33 | raise ValueError("Invalid momentum value: {}".format(momentum)) 34 | if weight_decay < 0.0: 35 | raise ValueError("Invalid weight_decay value: {}" 36 | .format(weight_decay)) 37 | if eta < 0.0: 38 | raise ValueError("Invalid LARS coefficient value: {}".format(eta)) 39 | 40 | self.epoch = 0 41 | defaults = dict(lr=lr, momentum=momentum, 42 | weight_decay=weight_decay, 43 | eta=eta, max_epoch=max_epoch) 44 | super(LARS, self).__init__(params, defaults) 45 | 46 | def step(self, epoch=None, closure=None): 47 | """Performs a single optimization step. 48 | 49 | Arguments: 50 | closure (callable, optional): A closure that reevaluates the model 51 | and returns the loss. 52 | epoch: current epoch to calculate polynomial LR decay schedule. 53 | if None, uses self.epoch and increments it. 54 | """ 55 | loss = None 56 | if closure is not None: 57 | loss = closure() 58 | 59 | if epoch is None: 60 | epoch = self.epoch 61 | self.epoch += 1 62 | 63 | for group in self.param_groups: 64 | weight_decay = group['weight_decay'] 65 | momentum = group['momentum'] 66 | eta = group['eta'] 67 | lr = group['lr'] 68 | max_epoch = group['max_epoch'] 69 | 70 | for p in group['params']: 71 | if p.grad is None: 72 | continue 73 | 74 | param_state = self.state[p] 75 | d_p = p.grad.data 76 | 77 | weight_norm = torch.norm(p.data) 78 | grad_norm = torch.norm(d_p) 79 | 80 | # Global LR computed on polynomial decay schedule 81 | decay = (1 - float(epoch) / max_epoch) ** 2 82 | global_lr = lr * decay 83 | 84 | # Compute local learning rate for this layer 85 | local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm) 86 | 87 | # Update the momentum term 88 | actual_lr = local_lr * global_lr 89 | 90 | if 'momentum_buffer' not in param_state: 91 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 92 | else: 93 | buf = param_state['momentum_buffer'] 94 | buf.mul_(momentum).add_(d_p + weight_decay * p.data, alpha=actual_lr) 95 | p.data.add_(-buf) 96 | 97 | return loss -------------------------------------------------------------------------------- /Emb_Clustering_Code/optimizers/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | # from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class LR_Scheduler(object): 7 | def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False): 8 | self.base_lr = base_lr 9 | self.constant_predictor_lr = constant_predictor_lr 10 | warmup_iter = iter_per_epoch * warmup_epochs 11 | warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter) 12 | decay_iter = iter_per_epoch * (num_epochs - warmup_epochs) 13 | cosine_lr_schedule = final_lr+0.5*(base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter)) 14 | 15 | self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) 16 | self.optimizer = optimizer 17 | self.iter = 0 18 | self.current_lr = 0 19 | def 
step(self): 20 | for param_group in self.optimizer.param_groups: 21 | 22 | if self.constant_predictor_lr and param_group['name'] == 'predictor': 23 | param_group['lr'] = self.base_lr 24 | else: 25 | lr = param_group['lr'] = self.lr_schedule[self.iter] 26 | 27 | self.iter += 1 28 | self.current_lr = lr 29 | return lr 30 | def get_lr(self): 31 | return self.current_lr 32 | 33 | if __name__ == "__main__": 34 | import torchvision 35 | model = torchvision.models.resnet50() 36 | optimizer = torch.optim.SGD(model.parameters(), lr=999) 37 | epochs = 100 38 | n_iter = 1000 39 | scheduler = LR_Scheduler(optimizer, 10, 1, epochs, 3, 0, n_iter) 40 | import matplotlib.pyplot as plt 41 | lrs = [] 42 | for epoch in range(epochs): 43 | for it in range(n_iter): 44 | lr = scheduler.step() 45 | lrs.append(lr) 46 | plt.plot(lrs) 47 | plt.show() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Maximilian Ilse and Jakub Tomczak 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
-------------------------------------------------------------------------------- /Pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Pipeline.png
-------------------------------------------------------------------------------- /Toydataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Toydataset.png
-------------------------------------------------------------------------------- /ToydatasetResults.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/ToydatasetResults.png
-------------------------------------------------------------------------------- /Toydataset_Code/cs-mil-toydataset/Background_filter.py: --------------------------------------------------------------------------------
1 | from sklearn.cluster import KMeans 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import glob 6 | import matplotlib.pyplot as plt 7 | 8 | data_dir = '/Data2/GCA/AttentionDeepMIL-master/data/train' 9 | output_dir = '/Data2/GCA/AttentionDeepMIL-master/contrastive_learning_data' 10 | 11 | if not os.path.exists(output_dir): 12 | os.makedirs(output_dir) 13 | 14 | label = glob.glob(os.path.join(data_dir, "*")) 15 | 16 | for ki in range(len(label)): 17 | cases = glob.glob(os.path.join(label[ki], "*")) 18 | 19 | # df = pd.DataFrame(columns = ['root', 'label', 'cluster']) 20 | 21 | for now_case in cases: 22 | images = glob.glob(os.path.join(now_case,"*")) 23 | images.sort() 24 | 25 | for now_image in images: 26 | patch = plt.imread(now_image)[:,:,:3] 27 | image_name = os.path.basename(now_image) 28 | if (patch.mean(2) > 230 / 255).sum() < 512 * 512 / 2: # keep the patch only if fewer than half of its pixels are near-white background 29 | plt.imsave(os.path.join(output_dir, image_name), patch) 30 | 31 |
-------------------------------------------------------------------------------- /Toydataset_Code/cs-mil-toydataset/Classifier_model_MAg.py: --------------------------------------------------------------------------------
1 | """ 2 | Simple MLP classifier used with DeepAttnMISL (MAg) 3 | 4 | If this work is useful for your research, please consider citing our papers: 5 | 6 | [1] "Whole Slide Images based Cancer Survival Prediction using Attention Guided Deep Multiple Instance Learning Networks" 7 | Jiawen Yao, Xinliang Zhu, Jitendra Jonnagaddala, Nicholas Hawkins, Junzhou Huang, 8 | Medical Image Analysis, Available online 19 July 2020, 101789 9 | 10 | [2] "Deep Multi-instance Learning for Survival Prediction from Whole Slide Images", In MICCAI 2019 11 | 12 | """ 13 | 14 | import torch.nn as nn 15 | import torch 16 | 17 | class Classifier(nn.Module): 18 | """ 19 | Two-layer MLP classifier definition 20 | """ 21 | 22 | def __init__(self): 23 | super(Classifier, self).__init__() 24 | 25 | self.classifier = nn.Sequential( 26 | nn.Linear(10, 5), 27 | nn.ReLU(), 28 | # nn.Dropout(p=0.5), 29 | nn.Linear(5, 2), 30 | nn.Softmax(dim = 1) 31 | ) 32 | 33 | def forward(self, x): 34 | 35 | out = self.classifier(x) 36 | 37 | return out 38 | 39 |
-------------------------------------------------------------------------------- /Toydataset_Code/cs-mil-toydataset/DeepAttnMISL_model.py: --------------------------------------------------------------------------------
1 | """ 2 | Model definition of DeepAttnMISL 3 | 4 | If
this work is useful for your research, please consider citing our papers: 5 | 6 | [1] "Whole Slide Images based Cancer Survival Prediction using Attention Guided Deep Multiple Instance Learning Networks" 7 | Jiawen Yao, Xinliang Zhu, Jitendra Jonnagaddala, Nicholas Hawkins, Junzhou Huang, 8 | Medical Image Analysis, Available online 19 July 2020, 101789 9 | 10 | [2] "Deep Multi-instance Learning for Survival Prediction from Whole Slide Images", In MICCAI 2019 11 | 12 | """ 13 | 14 | import torch.nn as nn 15 | import torch 16 | 17 | class DeepAttnMIL_Surv(nn.Module): 18 | """ 19 | Deep AttnMISL Model definition 20 | """ 21 | 22 | def __init__(self, cluster_num): 23 | super(DeepAttnMIL_Surv, self).__init__() 24 | self.embedding_net = nn.Sequential(nn.Conv2d(2048, 64, 1), 25 | nn.ReLU(), 26 | nn.AdaptiveAvgPool2d((1,1)) 27 | ) 28 | 29 | 30 | self.attention = nn.Sequential( 31 | nn.Linear(64, 32), # V 32 | nn.Tanh(), 33 | nn.Linear(32, 1) # W 34 | ) 35 | 36 | self.fc6 = nn.Sequential( 37 | nn.Linear(64, 32), 38 | nn.ReLU(), 39 | nn.Dropout(p=0.5), 40 | nn.Linear(32, 1), 41 | nn.Sigmoid() 42 | ) 43 | self.cluster_num = cluster_num 44 | 45 | 46 | def masked_softmax(self, x, mask=None): 47 | """ 48 | Performs masked softmax, as simply masking post-softmax can be 49 | inaccurate 50 | :param x: [batch_size, num_items] 51 | :param mask: [batch_size, num_items] 52 | :return: 53 | """ 54 | if mask is not None: 55 | mask = mask.float() 56 | if mask is not None: 57 | x_masked = x * mask + (1 - 1 / (mask+1e-5)) # masked entries receive a huge negative offset so they vanish after exp 58 | else: 59 | x_masked = x 60 | x_max = x_masked.max(1)[0] 61 | x_exp = (x - x_max.unsqueeze(-1)).exp() 62 | if mask is not None: 63 | x_exp = x_exp * mask.float() 64 | return x_exp / x_exp.sum(1).unsqueeze(-1) 65 | 66 | 67 | def forward(self, x, mask): 68 | 69 | " x is a tensor list" 70 | res = [] 71 | for i in range(self.cluster_num): 72 | hh = x[i].type(torch.FloatTensor).to("cuda") 73 | output = self.embedding_net(hh) 74 | output = output.view(output.size()[0], -1) 75 | res.append(output) 76 | 77 | 78 | h = torch.cat(res) 79 | 80 | b = h.size(0) 81 | c = h.size(1) 82 | 83 | h = h.view(b, c) 84 | 85 | A = self.attention(h) 86 | A = torch.transpose(A, 1, 0) # KxN 87 | 88 | A = self.masked_softmax(A, mask) 89 | 90 | 91 | M = torch.mm(A, h) # KxL 92 | 93 | Y_pred = self.fc6(M) 94 | 95 | return Y_pred 96 | 97 |
--------------------------------------------------------------------------------
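The masked_softmax docstring above notes that simply masking after a softmax is inaccurate; this tiny CPU-only check (illustrative values, assuming DeepAttnMISL_model.py is importable) shows why: the post-hoc masked row no longer sums to 1, while the masked version renormalizes over the surviving items.

import torch
from DeepAttnMISL_model import DeepAttnMIL_Surv

x = torch.tensor([[2.0, 1.0, -1.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])            # third item is padding
naive = torch.softmax(x, dim=1) * mask            # row sum drops below 1
masked = DeepAttnMIL_Surv(cluster_num=1).masked_softmax(x, mask)
print(naive.sum(1), masked.sum(1))                # ~0.96 vs 1.0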
/Toydataset_Code/cs-mil-toydataset/DeepAttnMISL_model_no21.py: --------------------------------------------------------------------------------
1 | """ 2 | Model definition of DeepAttnMISL 3 | 4 | If this work is useful for your research, please consider citing our papers: 5 | 6 | [1] "Whole Slide Images based Cancer Survival Prediction using Attention Guided Deep Multiple Instance Learning Networks" 7 | Jiawen Yao, Xinliang Zhu, Jitendra Jonnagaddala, Nicholas Hawkins, Junzhou Huang, 8 | Medical Image Analysis, Available online 19 July 2020, 101789 9 | 10 | [2] "Deep Multi-instance Learning for Survival Prediction from Whole Slide Images", In MICCAI 2019 11 | 12 | """ 13 | 14 | import torch.nn as nn 15 | import torch 16 | 17 | class DeepAttnMIL_Surv(nn.Module): 18 | """ 19 | Cross-scale Deep AttnMISL Model definition 20 | """ 21 | 22 | def __init__(self, cluster_num): 23 | super(DeepAttnMIL_Surv, self).__init__() 24 | self.embedding_net = nn.Sequential(nn.Conv2d(2048, 64, 1), 25 | nn.ReLU(), 26 | nn.AdaptiveAvgPool2d((1,1)) 27 | ) 28 | 29 | self.res_attention = nn.Sequential( 30 | nn.Conv2d(64, 32, 1), # V 31 | nn.ReLU(), 32 | nn.Conv2d(32, 1, 1), 33 | ) 34 | 35 | self.attention = nn.Sequential( 36 | nn.Linear(64, 32), # V 37 | nn.Tanh(), 38 | nn.Linear(32, 1) # W 39 | ) 40 | 41 | self.fc6 = nn.Sequential( 42 | nn.Linear(64, 32), 43 | nn.ReLU(), 44 | nn.Dropout(p=0.5), 45 | nn.Linear(32, 1), 46 | nn.Sigmoid() 47 | ) 48 | self.cluster_num = cluster_num 49 | 50 | 51 | def masked_softmax(self, x, mask=None): 52 | """ 53 | Performs masked softmax, as simply masking post-softmax can be 54 | inaccurate 55 | :param x: [batch_size, num_items] 56 | :param mask: [batch_size, num_items] 57 | :return: 58 | """ 59 | if mask is not None: 60 | mask = mask.float() 61 | if mask is not None: 62 | x_masked = x * mask + (1 - 1 / (mask+1e-5)) 63 | else: 64 | x_masked = x 65 | x_max = x_masked.max(1)[0] 66 | x_exp = (x - x_max.unsqueeze(-1)).exp() 67 | if mask is not None: 68 | x_exp = x_exp * mask.float() 69 | return x_exp / x_exp.sum(1).unsqueeze(-1) 70 | 71 | 72 | def forward(self, x, mask): 73 | 74 | " x is a tensor list" 75 | res = [] 76 | for i in range(self.cluster_num): 77 | hh = x[i].type(torch.FloatTensor).to("cuda") 78 | output1 = self.embedding_net(hh[:,:,0:1,:]) # one embedding per scale 79 | output2 = self.embedding_net(hh[:,:,1:2,:]) 80 | output3 = self.embedding_net(hh[:,:,2:3,:]) 81 | output = torch.cat([output1, output2, output3],2) 82 | res_attention = self.res_attention(output).squeeze(-1) # cross-scale attention scores 83 | 84 | final_output = torch.matmul(output.squeeze(-1), torch.transpose(res_attention,2,1)).squeeze(-1) # scale-weighted fusion 85 | res.append(final_output) 86 | 87 | h = torch.cat(res) 88 | 89 | b = h.size(0) 90 | c = h.size(1) 91 | 92 | h = h.view(b, c) 93 | 94 | A = self.attention(h) 95 | A = torch.transpose(A, 1, 0) # KxN 96 | 97 | A = self.masked_softmax(A, mask) 98 | 99 | M = torch.mm(A, h) # KxL 100 | 101 | Y_pred = self.fc6(M) 102 | 103 | return Y_pred 104 | 105 |
--------------------------------------------------------------------------------
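The cross-scale step in forward() above reduces to this shape walk-through (random tensors standing in for the embedding_net and res_attention outputs; CPU-only):

import torch

emb = torch.randn(5, 64, 3)      # 5 instances, 64-d embedding at 3 scales
score = torch.randn(5, 1, 3)     # per-scale attention scores from the 1x1 convs
fused = torch.matmul(emb, score.transpose(2, 1)).squeeze(-1)
print(fused.shape)               # torch.Size([5, 64]): one fused embedding per instance

Note that in this variant the scale scores are not normalized; the withResnet variant further down applies a softmax over the three scales before the same fusion.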
/Toydataset_Code/cs-mil-toydataset/DeepAttnMISL_model_no21_attentionscore.py: --------------------------------------------------------------------------------
1 | """ 2 | Model definition of DeepAttnMISL 3 | 4 | If this work is useful for your research, please consider citing our papers: 5 | 6 | [1] "Whole Slide Images based Cancer Survival Prediction using Attention Guided Deep Multiple Instance Learning Networks" 7 | Jiawen Yao, Xinliang Zhu, Jitendra Jonnagaddala, Nicholas Hawkins, Junzhou Huang, 8 | Medical Image Analysis, Available online 19 July 2020, 101789 9 | 10 | [2] "Deep Multi-instance Learning for Survival Prediction from Whole Slide Images", In MICCAI 2019 11 | 12 | """ 13 | 14 | import torch.nn as nn 15 | import torch 16 | 17 | class DeepAttnMIL_Surv(nn.Module): 18 | """ 19 | Deep AttnMISL Model definition 20 | """ 21 | 22 | def __init__(self, cluster_num): 23 | super(DeepAttnMIL_Surv, self).__init__() 24 | self.embedding_net = nn.Sequential(nn.Conv2d(2048, 64, 1), 25 | nn.ReLU(), 26 | nn.AdaptiveAvgPool2d((1,1)) 27 | ) 28 | 29 | self.res_attention = nn.Sequential( 30 | nn.Conv2d(64, 32, 1), # V 31 | nn.ReLU(), 32 | nn.Conv2d(32, 1, 1), 33 | ) 34 | 35 | self.attention = nn.Sequential( 36 | nn.Linear(64, 32), # V 37 | nn.Tanh(), 38 | nn.Linear(32, 1) # W 39 | ) 40 | 41 | self.fc6 = nn.Sequential( 42 | nn.Linear(64, 32), 43 | nn.ReLU(), 44 | nn.Dropout(p=0.5), 45 | nn.Linear(32, 1), 46 | nn.Sigmoid() 47 | ) 48 | self.cluster_num = cluster_num 49 | 50 | 51 | def masked_softmax(self, x, mask=None): 52 | """ 53 | Performs masked softmax, as simply masking post-softmax can be 54 | inaccurate 55 | :param x: [batch_size, num_items] 56 | :param mask: [batch_size, num_items] 57 | :return: 58 | """ 59 | if mask is not None: 60 | mask = mask.float() 61 | if mask is not None: 62 | x_masked = x * mask + (1 - 1 / (mask+1e-5)) 63 | else: 64 | x_masked = x 65 | x_max = x_masked.max(1)[0] 66 | x_exp = (x - x_max.unsqueeze(-1)).exp() 67 | if mask is not None: 68 | x_exp = x_exp * mask.float() 69 | return x_exp / x_exp.sum(1).unsqueeze(-1) 70 | 71 | 72 | def forward(self, x, mask): 73 | 74 | " x is a tensor list" 75 | res = [] 76 | resolution_attention_256 = [] 77 | resolution_attention_512 = [] 78 | resolution_attention_1024 = [] 79 | for i in range(self.cluster_num): 80 | hh = x[i].type(torch.FloatTensor).to("cuda") 81 | output1 = self.embedding_net(hh[:,:,0:1,:]) 82 | output2 = self.embedding_net(hh[:,:,1:2,:]) 83 | output3 = self.embedding_net(hh[:,:,2:3,:]) 84 | output = torch.cat([output1, output2, output3],2) 85 | res_attention = self.res_attention(output).squeeze(-1) 86 | res_attention_1 = res_attention[:,:,0] # per-instance attention for the 256 scale 87 | res_attention_2 = res_attention[:,:,1] 88 | res_attention_3 = res_attention[:,:,2] 89 | 90 | resolution_attention_256.append(res_attention_1) 91 | resolution_attention_512.append(res_attention_2) 92 | resolution_attention_1024.append(res_attention_3) 93 | 94 | final_output = torch.matmul(output.squeeze(-1), torch.transpose(res_attention,2,1)).squeeze(-1) 95 | res.append(final_output) 96 | 97 | h = torch.cat(res) 98 | 99 | b = h.size(0) 100 | c = h.size(1) 101 | 102 | h = h.view(b, c) 103 | 104 | A = self.attention(h) 105 | A = torch.transpose(A, 1, 0) # KxN 106 | 107 | A = self.masked_softmax(A, mask) 108 | 109 | M = torch.mm(A, h) # KxL 110 | 111 | Y_pred = self.fc6(M) 112 | 113 | return Y_pred, resolution_attention_256, resolution_attention_512, resolution_attention_1024, A 114 | 115 |
-------------------------------------------------------------------------------- /Toydataset_Code/cs-mil-toydataset/DeepAttnMISL_model_no21_withResnet.py: --------------------------------------------------------------------------------
1 | """ 2 | Model definition of DeepAttnMISL 3 | 4 | If this work is useful for your research, please consider citing our papers: 5 | 6 | [1] "Whole Slide Images based Cancer Survival Prediction using Attention Guided Deep Multiple Instance Learning Networks" 7 | Jiawen Yao, Xinliang Zhu, Jitendra Jonnagaddala, Nicholas Hawkins, Junzhou Huang, 8 | Medical Image Analysis, Available online 19 July 2020, 101789 9 | 10 | [2] "Deep Multi-instance Learning for Survival Prediction from Whole Slide Images", In MICCAI 2019 11 | 12 | """ 13 | 14 | import torch.nn as nn 15 | import torch 16 | 17 | from torchvision.models import resnet18, densenet121 18 | 19 | 20 | 21 | class DeepAttnMIL_Surv(nn.Module): 22 | """ 23 | Deep AttnMISL Model definition 24 | """ 25 | 26 | def __init__(self, cluster_num): 27 | super(DeepAttnMIL_Surv, self).__init__() 28 | 29 | self.resnet = resnet18(pretrained=True) 30 | self.feature_extractor = torch.nn.Sequential(*list(self.resnet.children())[:-1]).cuda() 31 | 32 | for param in self.feature_extractor.parameters(): 33 | param.requires_grad = True 34 | 35 | self.embedding_net = nn.Sequential(nn.Conv2d(512, 64, 1), 36 | nn.ReLU(), 37 | nn.AdaptiveAvgPool2d((1,1)) 38 | ) 39 | 40 | self.res_attention = nn.Sequential( 41 | nn.Conv2d(64, 32, 1), # V 42 | nn.ReLU(), 43 | nn.Conv2d(32, 1, 1), 44 | ) 45 | 46 | self.attention = nn.Sequential( 47 | nn.Linear(64, 32), # V 48 | nn.Tanh(), 49 | nn.Linear(32, 1) # W 50 | ) 51 | 52 | self.fc6 = nn.Sequential( 53 |
nn.Linear(64, 32), 54 | nn.ReLU(), 55 | nn.Dropout(p=0.5), 56 | nn.Linear(32, 1), 57 | nn.Sigmoid() 58 | ) 59 | self.cluster_num = cluster_num 60 | 61 | self.softmax = nn.Softmax(2) 62 | 63 | 64 | def masked_softmax(self, x, mask=None): 65 | """ 66 | Performs masked softmax, as simply masking post-softmax can be 67 | inaccurate 68 | :param x: [batch_size, num_items] 69 | :param mask: [batch_size, num_items] 70 | :return: 71 | """ 72 | if mask is not None: 73 | mask = mask.float() 74 | if mask is not None: 75 | x_masked = x * mask + (1 - 1 / (mask+1e-5)) 76 | else: 77 | x_masked = x 78 | x_max = x_masked.max(1)[0] 79 | x_exp = (x - x_max.unsqueeze(-1)).exp() 80 | if mask is not None: 81 | x_exp = x_exp * mask.float() 82 | return x_exp / x_exp.sum(1).unsqueeze(-1) 83 | 84 | 85 | def forward(self, x, mask): 86 | 87 | " x is a tensor list" 88 | res = [] 89 | x = x.float() 90 | hh1 = self.feature_extractor(x[0,:, 0,...].permute([0,3,1,2])) 91 | hh2 = self.feature_extractor(x[0,:, 1,...].permute([0,3,1,2])) 92 | hh3 = self.feature_extractor(x[0,:, 2,...].permute([0,3,1,2])) 93 | 94 | 95 | output1 = self.embedding_net(hh1) 96 | output2 = self.embedding_net(hh2) 97 | output3 = self.embedding_net(hh3) 98 | output = torch.cat([output1, output2, output3],2) 99 | res_attention = self.res_attention(output).squeeze(-1) 100 | res_attention = self.softmax(res_attention) 101 | final_output = torch.matmul(output.squeeze(-1), torch.transpose(res_attention,2,1)).squeeze(-1) 102 | res.append(final_output) 103 | 104 | h = torch.cat(res) 105 | 106 | b = h.size(0) 107 | c = h.size(1) 108 | 109 | h = h.view(b, c) 110 | 111 | A = self.attention(h) 112 | A = torch.transpose(A, 1, 0) # KxN 113 | 114 | A = self.masked_softmax(A, mask) 115 | 116 | M = torch.mm(A, h) # KxL 117 | 118 | Y_pred = self.fc6(M) 119 | 120 | return Y_pred, A, res_attention 121 | 122 | -------------------------------------------------------------------------------- /Toydataset_Code/cs-mil-toydataset/GetDataList.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import glob 6 | import matplotlib.pyplot as plt 7 | 8 | data_dir = '/Data2/GCA/AttentionDeepMIL/Normal_control_vs_patient.csv' 9 | csv_dir = '/Data2/GCA/AttentionDeepMIL/data_list.csv' 10 | 11 | label_df = pd.read_csv(data_dir) 12 | list_df = pd.DataFrame(columns = ['filename', 'region', 'patient_type', 'score', 'class', 'train']) 13 | 14 | for ki in range(len(label_df)): 15 | #if label_df.iloc[ki]['patient_type'].replace(" ", "") == 'CD' or label_df.iloc[ki]['patient_type'].replace(" ", "") == 'Control': 16 | if not pd.isna(label_df.iloc[ki]['patient_type']): 17 | now_file = label_df.iloc[ki]['filename'].replace(" ", "") 18 | now_region = label_df.iloc[ki]['region'].replace(" ", "") 19 | now_patient_type = label_df.iloc[ki]['patient_type'].replace(" ", "") 20 | if now_patient_type == 'CD': 21 | now_class = 1 22 | else: 23 | now_class = 0 24 | 25 | now_score = int(label_df.iloc[ki]['score']) 26 | now_train = 1 27 | 28 | row = len(list_df) 29 | list_df.loc[row] = [now_file, now_region, now_patient_type, now_score, now_class, now_train] 30 | 31 | list_df.to_csv(csv_dir, index = False) -------------------------------------------------------------------------------- /Toydataset_Code/cs-mil-toydataset/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Maximilian Ilse and Jakub 
Tomczak 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 |
-------------------------------------------------------------------------------- /Toydataset_Code/cs-mil-toydataset/MIL_dataloader_csv_data1.py: --------------------------------------------------------------------------------
1 | """PyTorch dataset that loads multi-scale patch bags listed in per-bag CSV files.""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data_utils 6 | from torchvision import datasets, transforms 7 | import glob 8 | import os 9 | #import imgaug.augmenters as iaa 10 | import pandas as pd 11 | import random 12 | 13 | import matplotlib.pyplot as plt 14 | # import imgaug.augmenters as iaa 15 | 16 | 17 | class MILBags(data_utils.Dataset): 18 | def __init__(self, bag_root, MIL): 19 | self.bag_root = bag_root 20 | self.bag_list = glob.glob(os.path.join(self.bag_root,'*')) 21 | self.MIL = MIL 22 | 23 | def _create_bags(self, index): 24 | 25 | df = pd.read_csv(self.bag_list[index]) 26 | now_images = df['img_root'].tolist() 27 | 28 | images_bags = np.zeros((len(df), 3, 256, 256, 3)) 29 | for ii in range(len(df)): 30 | if self.MIL: 31 | images_bags[ii, 0, ...] = plt.imread(now_images[ii].replace('size1024', 'size256'))[:, :,:3] 32 | images_bags[ii, 1, ...] = plt.imread(now_images[ii].replace('size1024', 'size512'))[:, :,:3] 33 | images_bags[ii, 2, ...] = plt.imread(now_images[ii].replace('size1024', 'size1024'))[:, :,:3] 34 | 35 | else: 36 | images_bags[ii, 0, ...] = plt.imread(now_images[ii].replace('size1024', 'size256'))[:, :, :3] # single-scale baseline: the 256 patch fills all three slots 37 | images_bags[ii, 1, ...] = plt.imread(now_images[ii].replace('size1024', 'size256'))[:, :, :3] 38 | images_bags[ii, 2, ...]
= plt.imread(now_images[ii].replace('size1024', 'size256'))[:, :, :3] 39 | 40 | 41 | label_bag = int(df['class'].tolist()[0]) 42 | mask = 1 43 | #print(now_images) 44 | 45 | # if label_bag == 0: 46 | # images_bags = images_bags * 0 47 | return images_bags, label_bag, mask, now_images 48 | 49 | def __len__(self): 50 | return len(self.bag_list) 51 | 52 | def __getitem__(self, index): 53 | bag, label, mask, root = self._create_bags(index) 54 | return bag, label, mask, root 55 | 56 | if __name__ == "__main__": 57 | 58 | train_loader = data_utils.DataLoader(MILBags(bag_root = '/Data2/GCA_Demo/Bag_1000_1/Train', 59 | MIL = False), 60 | batch_size=1, 61 | shuffle=True) 62 | 63 | validation_loader = data_utils.DataLoader(MILBags(bag_root = '/Data2/GCA_Demo/Bag_1000_1/Train', 64 | MIL = False), 65 | batch_size=1, 66 | shuffle=False) 67 | 68 | 69 | len_bag_list_train = [] 70 | mnist_bags_train = 0 71 | for batch_idx, (bag, label, mask, case_name) in enumerate(train_loader): 72 | print('aaa') 73 | len_bag_list_train.append(int(bag.squeeze(0).size()[0])) 74 | 75 |
--------------------------------------------------------------------------------
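The three-scale filename convention these loaders rely on can be seen in one line; given a bag-CSV img_root entry (hypothetical path), the scale-aligned patches are recovered by string substitution:

row = '/Data2/GCA_Demo/patches/caseA_10_20_size1024.png'   # hypothetical img_root value
paths = [row.replace('size1024', 'size%d' % s) for s in (256, 512, 1024)]
# ['..._size256.png', '..._size512.png', '..._size1024.png']

With MIL=True the loader stacks all three scales into images_bags[ii, 0..2]; the single-scale baselines repeat one scale three times (size256 in data1 above, size1024 in data2 below).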
/Toydataset_Code/cs-mil-toydataset/MIL_dataloader_csv_data2.py: --------------------------------------------------------------------------------
1 | """PyTorch dataset that loads single- or multi-scale patch bags listed in per-bag CSV files.""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data_utils 6 | from torchvision import datasets, transforms 7 | import glob 8 | import os 9 | # import imgaug.augmenters as iaa 10 | import pandas as pd 11 | import random 12 | 13 | import matplotlib.pyplot as plt 14 | # import imgaug.augmenters as iaa 15 | 16 | 17 | class MILBags(data_utils.Dataset): 18 | def __init__(self, bag_root, MIL): 19 | self.bag_root = bag_root 20 | self.bag_list = glob.glob(os.path.join(self.bag_root,'*')) 21 | self.MIL = MIL 22 | 23 | def _create_bags(self, index): 24 | 25 | df = pd.read_csv(self.bag_list[index]) 26 | now_images = df['img_root'].tolist() 27 | 28 | images_bags = np.zeros((len(df), 3, 256, 256, 3)) 29 | for ii in range(len(df)): 30 | if self.MIL: 31 | images_bags[ii, 0, ...] = plt.imread(now_images[ii].replace('size1024', 'size256'))[:, :,:3] 32 | images_bags[ii, 1, ...] = plt.imread(now_images[ii].replace('size1024', 'size512'))[:, :,:3] 33 | images_bags[ii, 2, ...] = plt.imread(now_images[ii].replace('size1024', 'size1024'))[:, :,:3] 34 | 35 | else: 36 | images_bags[ii, 0, ...] = plt.imread(now_images[ii].replace('size1024', 'size1024'))[:, :, :3] # single-scale baseline: the 1024 patch fills all three slots 37 | images_bags[ii, 1, ...] = plt.imread(now_images[ii].replace('size1024', 'size1024'))[:, :, :3] 38 | images_bags[ii, 2, ...] = plt.imread(now_images[ii].replace('size1024', 'size1024'))[:, :, :3] 39 | 40 | 41 | label_bag = int(df['class'].tolist()[0]) 42 | mask = 1 43 | #print(now_images) 44 | 45 | # if label_bag == 0: 46 | # images_bags = images_bags * 0 47 | return images_bags, label_bag, mask, now_images 48 | 49 | def __len__(self): 50 | return len(self.bag_list) 51 | 52 | def __getitem__(self, index): 53 | bag, label, mask, root = self._create_bags(index) 54 | return bag, label, mask, root 55 | 56 | if __name__ == "__main__": 57 | 58 | train_loader = data_utils.DataLoader(MILBags(bag_root = '/Data2/GCA_Demo/Bag_1000_1/Train', 59 | MIL = False), 60 | batch_size=1, 61 | shuffle=True) 62 | 63 | validation_loader = data_utils.DataLoader(MILBags(bag_root = '/Data2/GCA_Demo/Bag_1000_1/Val', 64 | MIL = False), 65 | batch_size=1, 66 | shuffle=False) 67 | 68 | 69 | len_bag_list_train = [] 70 | mnist_bags_train = 0 71 | for batch_idx, (bag, label, mask, case_name) in enumerate(train_loader): 72 | print('aaa') 73 | len_bag_list_train.append(int(bag.squeeze(0).size()[0])) 74 | 75 |
-------------------------------------------------------------------------------- /Toydataset_Code/cs-mil-toydataset/MIL_dataloader_csv_image.py: --------------------------------------------------------------------------------
1 | """PyTorch dataset that samples multi-scale patch bags from per-WSI CSV files.""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data_utils 6 | from torchvision import datasets, transforms 7 | import glob 8 | import os 9 | # import imgaug.augmenters as iaa 10 | import pandas as pd 11 | import random 12 | 13 | import matplotlib.pyplot as plt 14 | # import imgaug.augmenters as iaa 15 | 16 | 17 | class MILBags(data_utils.Dataset): 18 | def __init__(self, data_csv_root, feature_csv_root, size = (256, 256), mean_bag_length = 64, var_bag_length = 4, seed=1, train=True): 19 | self.feature_csv_root = feature_csv_root 20 | self.data_csv_root = data_csv_root 21 | self.size = size 22 | self.mean_bag_length = mean_bag_length 23 | self.var_bag_length = var_bag_length 24 | self.train = train 25 | 26 | self.r = np.random.RandomState(seed) 27 | self.data_csv = pd.read_csv(data_csv_root) 28 | 29 | self.image_list = [] 30 | self.label_list = [] 31 | 32 | # self.data_aug = 33 | 34 | 35 | for ki in range(len(self.data_csv)): 36 | now_case, now_label, train = self.data_csv.iloc[ki]['filename'], self.data_csv.iloc[ki]['class'], self.data_csv.iloc[ki]['train'] 37 | if (self.train == 1) and (train == 1): 38 | self.image_list.append((pd.read_csv(os.path.join(self.feature_csv_root,now_case.replace('.svs', '.csv'))))) 39 | self.label_list.append((now_label)) 40 | 41 | elif (self.train == 2) and (train == 2): 42 | self.image_list.append((pd.read_csv(os.path.join(self.feature_csv_root,now_case.replace('.svs', '.csv'))))) 43 | self.label_list.append((now_label)) 44 | 45 | self.case_num = len(self.image_list) 46 | 47 | def _create_bags(self, index): 48 | now_images = self.image_list[index]['root'].tolist() 49 | now_label = self.label_list[index] 50 | 51 | bag_length = int(self.r.normal(self.mean_bag_length, self.var_bag_length, 1)) # sample a bag length around mean_bag_length 52 | if bag_length < 1: 53 | bag_length = 1 54 | 55 | indices = torch.LongTensor(self.r.randint(0, len(now_images), bag_length)) 56 | images_bags = np.zeros((bag_length, 3, 256, 256, 3)) 57 | 58 | for ii in range(len(indices)): 59 | images_bags[ii,2,...] = plt.imread(now_images[indices[ii]])[:,:,:3] 60 | images_bags[ii,1,...]
61 |             images_bags[ii,0,...] = plt.imread(now_images[indices[ii]].replace('size1024', 'size256'))[:,:,:3]
62 | 
63 |         # if now_label == 0:
64 |         #     images_bags = images_bags * 0
65 | 
66 |         label_bag = now_label
67 |         mask = 1
68 |         case_name = os.path.basename(now_images[0]).split('_')[0]
69 | 
70 |         return images_bags, label_bag, mask, case_name
71 | 
72 |     def __len__(self):
73 |         return len(self.image_list)
74 | 
75 |     def __getitem__(self, index):
76 |         bag, label, mask, case_name = self._create_bags(index)
77 |         return bag, label, mask, case_name
78 | 
79 | 
80 | if __name__ == "__main__":
81 | 
82 |     train_loader = data_utils.DataLoader(MILBags(data_csv_root = '/Data2/GCA_Demo/data_list_severe.csv',
83 |                                                  feature_csv_root = '/Data2/GCA_Demo/Datase1_csv',
84 |                                                  size = (256, 256),
85 |                                                  mean_bag_length=64,
86 |                                                  var_bag_length=4,
87 |                                                  seed=1,
88 |                                                  train=1),
89 |                                          batch_size=1,
90 |                                          shuffle=True)
91 | 
92 |     validation_loader = data_utils.DataLoader(MILBags(data_csv_root = '/Data2/GCA_Demo/data_list_severe.csv',
93 |                                                       feature_csv_root = '/Data2/GCA_Demo/Datase1_csv',
94 |                                                       size=(256, 256),
95 |                                                       mean_bag_length=64,
96 |                                                       var_bag_length=4,
97 |                                                       seed=1,
98 |                                                       train=2),
99 |                                               batch_size=1,
100 |                                               shuffle=False)
101 | 
102 | 
103 |     len_bag_list_train = []
104 |     mnist_bags_train = 0
105 |     for batch_idx, (bag, label, mask, case_name) in enumerate(train_loader):
106 |         print('bag %d: label %d, %d instances' % (batch_idx, int(label), int(bag.squeeze(0).size()[0])))  # was a bare print('aaa') debug line
107 |         len_bag_list_train.append(int(bag.squeeze(0).size()[0]))
108 |         # mnist_bags_train += label[0].numpy()[0]
109 |         # print('Number positive train bags: {}/{}\n'
110 |         #       'Number of instances per bag, mean: {}, max: {}, min {}\n'.format(
111 |         #     mnist_bags_train, len(train_loader),
112 |         #     np.mean(len_bag_list_train), np.max(len_bag_list_train), np.min(len_bag_list_train)))
113 | 
114 | 
--------------------------------------------------------------------------------
/Toydataset_Code/cs-mil-toydataset/ROCAUC_dataset1_classmark.py:
--------------------------------------------------------------------------------
1 | from sklearn.cluster import KMeans
2 | import numpy as np
3 | import os
4 | import pandas as pd
5 | import glob
6 | import matplotlib.pyplot as plt
7 | 
8 | from skimage.transform import resize
9 | from scipy.ndimage import gaussian_filter
10 | import cv2
11 | 
12 | attention_df = pd.read_csv('results_stage1_clustering_dataset1/attention_score_dataset1_region.csv')
13 | 
14 | region_image_folder = '/Data3/GCA_Demo/4864patch_dataset1'
15 | region_folder = '/Data3/GCA_Demo/4864patch_dataset1_patch'
16 | attention_folder = '/Data3/GCA_Demo/4864patch_dataset1_attention'
17 | 
18 | if not os.path.exists(attention_folder):
19 |     os.makedirs(attention_folder)
20 | 
21 | 
22 | patch_size = 256
23 | 
24 | get_bagscore_flag = 0
25 | 
26 | if get_bagscore_flag:
27 |     bag_score_df = pd.DataFrame(columns=['bag_num', 'label', 'pred'])
28 | 
29 |     for bi in range(int(len(attention_df)/8)):
30 |         print('%d/%d' % (bi, int(len(attention_df)/8)))
31 | 
32 |         label = 0
33 |         pred = attention_df['score'][bi * 8]
34 |         now_bag = attention_df['bag_num'][bi * 8]
35 | 
36 |         now_root_list = attention_df['root'][bi * 8].split('/')
37 |         region_name = now_root_list[4]
38 |         now_classmark = glob.glob(os.path.join(attention_folder, "%s*classmark.png" % (region_name)))[0]
39 |         img = plt.imread(now_classmark)[:,:,:3]
40 | 
41 | 
42 |         for ki in range(8):
43 |             idx = bi * 8 + ki
44 |             now_root_list = attention_df['root'][idx].split('/')
45 |             region_name = now_root_list[4]
46 |             now_x = 
int(now_root_list[5].split('_')[1]) 47 | now_y = int(now_root_list[5].split('_')[2]) 48 | 49 | patch_red = img[now_x - 384:now_x + patch_size - 384, now_y - 384:now_y + patch_size - 384, 0].mean() 50 | patch_green = img[now_x - 384:now_x + patch_size - 384, now_y - 384:now_y + patch_size - 384, 1].mean() 51 | 52 | if patch_red > 0.9 and patch_green < 0.8: 53 | label = 1 54 | break 55 | 56 | bag_score_df.loc[bi] = [now_bag, label, pred] 57 | 58 | bag_score_df.to_csv('results_stage1_clustering_dataset1/bag_score.csv', index = False) 59 | 60 | else: 61 | bag_score_df = pd.read_csv('results_stage1_clustering_dataset1/bag_score.csv') 62 | y_test = np.array(bag_score_df['label'].tolist()) 63 | preds = np.array(bag_score_df['pred'].tolist()) 64 | 65 | 66 | # cnt = 0 67 | # for ki in range(len(y_test)): 68 | # if (preds[ki] > 0.5 and y_test[ki] == 1) or (preds[ki] < 0.5 and y_test[ki] == 0): 69 | # cnt += 1 70 | # 71 | # print(cnt/len(y_test)) 72 | 73 | acc = ((preds > 0.5) * (y_test == 1) + (preds < 0.5) * (y_test == 0)).mean() 74 | 75 | 76 | import sklearn.metrics as metrics 77 | 78 | import matplotlib.pyplot as plt 79 | 80 | plt.title('Receiver Operating Characteristic') 81 | # plt.plot(fpr, tpr, 'b', label='AUC = %0.4f' % (roc_auc)) 82 | # plt.legend(loc='lower right') 83 | plt.plot([0, 1], [0, 1], 'r--') 84 | plt.xlim([0, 1]) 85 | plt.ylim([0, 1]) 86 | plt.ylabel('True Positive Rate') 87 | plt.xlabel('False Positive Rate') 88 | 89 | 90 | fpr, tpr, threshold = metrics.roc_curve(y_test, preds, pos_label=1) 91 | roc_auc = metrics.auc(fpr, tpr) 92 | print(roc_auc) 93 | 94 | 95 | plt.plot(fpr, tpr, linewidth=2, label= '%s, AUC = %0.4f' % ('ours', roc_auc)) 96 | plt.legend(loc='lower right') 97 | 98 | plt.savefig('ROC-AUC_dataset1.png') 99 | plt.clf() 100 | 101 | 102 | import matplotlib.pyplot as plt 103 | 104 | plt.title('Precision Recall Curve') 105 | # plt.legend(loc='lower right') 106 | plt.plot([0, 1], [0.5, 0.5], 'r--') 107 | plt.xlim([0, 1]) 108 | plt.ylim([0, 1]) 109 | plt.ylabel('Precision') 110 | plt.xlabel('Recall') 111 | 112 | precision, recall, threshold = metrics.precision_recall_curve(y_test, preds, pos_label=1) 113 | ap = metrics.average_precision_score(y_test, preds) 114 | 115 | plt.plot(recall, precision, linewidth=2, 116 | label='%s, AP = %0.4f' % ('ours', ap)) 117 | plt.legend(loc='lower right') 118 | plt.savefig('PR-AP_dataset1.png') 119 | plt.clf() 120 | 121 | print(ap) 122 | print(metrics.accuracy_score(y_test, preds > 0.5)) 123 | print(metrics.f1_score(y_test, preds > 0.5)) -------------------------------------------------------------------------------- /Toydataset_Code/cs-mil-toydataset/ROCAUC_dataset2_classmark.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import glob 6 | import matplotlib.pyplot as plt 7 | #from scipy.interpolate import int 8 | from skimage.transform import resize 9 | from scipy.ndimage import gaussian_filter 10 | import cv2 11 | 12 | attention_df = pd.read_csv('results_stage1_clustering_dataset2/attention_score_dataset2_region.csv') 13 | 14 | region_image_folder = '/Data3/GCA_Demo/4864patch_dataset2' 15 | region_folder = '/Data3/GCA_Demo/4864patch_dataset2_patch' 16 | attention_folder = '/Data3/GCA_Demo/4864patch_dataset2_attention' 17 | 18 | if not os.path.exists(attention_folder): 19 | os.makedirs(attention_folder) 20 | 21 | 22 | patch_size = 1024 23 | 24 | get_bagscore_flag = 0 25 | 26 | if get_bagscore_flag: 
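#  As in the dataset1 script above: with get_bagscore_flag set, each bag of eight
#  patches is re-labelled against its region's class-mark image (a predominantly
#  red crop marks a positive instance) and bag_score.csv is rebuilt; with the flag
#  left at 0, the cached bag_score.csv is loaded and only the ROC/PR evaluation runs.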
27 | bag_score_df = pd.DataFrame(columns=['bag_num', 'label', 'pred']) 28 | 29 | for bi in range(int(len(attention_df)/8)): 30 | print('%d/%d' % (bi, int(len(attention_df)/8))) 31 | 32 | label = 0 33 | pred = attention_df['score'][bi * 8] 34 | now_bag = attention_df['bag_num'][bi * 8] 35 | 36 | now_root_list = attention_df['root'][bi * 8].split('/') 37 | region_name = now_root_list[4] 38 | now_classmark = glob.glob(os.path.join(attention_folder, "%s*classmark.png" % (region_name)))[0] 39 | img = plt.imread(now_classmark)[:,:,:3] 40 | 41 | 42 | for ki in range(8): 43 | idx = bi * 8 + ki 44 | now_root_list = attention_df['root'][idx].split('/') 45 | region_name = now_root_list[4] 46 | now_x = int(now_root_list[5].split('_')[1]) 47 | now_y = int(now_root_list[5].split('_')[2]) 48 | 49 | patch_red = img[now_x - 384* 2:now_x + patch_size - 384* 2, now_y - 384* 2:now_y + patch_size - 384* 2, 0].mean() 50 | patch_green = img[now_x - 384* 2:now_x + patch_size - 384* 2, now_y - 384* 2:now_y + patch_size - 384* 2, 1].mean() 51 | 52 | if patch_red > 0.9 and patch_green < 0.8: 53 | label = 1 54 | break 55 | 56 | bag_score_df.loc[bi] = [now_bag, label, pred] 57 | 58 | bag_score_df.to_csv('results_stage1_clustering_dataset2/bag_score.csv', index = False) 59 | 60 | else: 61 | bag_score_df = pd.read_csv('results_stage1_clustering_dataset2/bag_score.csv') 62 | 63 | y_test = np.array(bag_score_df['label'].tolist()) 64 | preds = np.array(bag_score_df['pred'].tolist()) 65 | 66 | 67 | # cnt = 0 68 | # for ki in range(len(y_test)): 69 | # if (preds[ki] > 0.5 and y_test[ki] == 1) or (preds[ki] < 0.5 and y_test[ki] == 0): 70 | # cnt += 1 71 | # 72 | # print(cnt/len(y_test)) 73 | 74 | acc = ((preds > 0.5) * (y_test == 1) + (preds < 0.5) * (y_test == 0)).mean() 75 | 76 | 77 | import sklearn.metrics as metrics 78 | 79 | import matplotlib.pyplot as plt 80 | 81 | plt.title('Receiver Operating Characteristic') 82 | # plt.plot(fpr, tpr, 'b', label='AUC = %0.4f' % (roc_auc)) 83 | # plt.legend(loc='lower right') 84 | plt.plot([0, 1], [0, 1], 'r--') 85 | plt.xlim([0, 1]) 86 | plt.ylim([0, 1]) 87 | plt.ylabel('True Positive Rate') 88 | plt.xlabel('False Positive Rate') 89 | 90 | 91 | fpr, tpr, threshold = metrics.roc_curve(y_test, preds, pos_label=1) 92 | roc_auc = metrics.auc(fpr, tpr) 93 | print(roc_auc) 94 | 95 | 96 | plt.plot(fpr, tpr, linewidth=2, label= '%s, AUC = %0.4f' % ('ours', roc_auc)) 97 | plt.legend(loc='lower right') 98 | 99 | plt.savefig('ROC-AUC_dataset2.png') 100 | plt.clf() 101 | 102 | 103 | 104 | 105 | import matplotlib.pyplot as plt 106 | 107 | plt.title('Precision Recall Curve') 108 | # plt.legend(loc='lower right') 109 | plt.plot([0, 1], [0.5, 0.5], 'r--') 110 | plt.xlim([0, 1]) 111 | plt.ylim([0, 1]) 112 | plt.ylabel('Precision') 113 | plt.xlabel('Recall') 114 | 115 | precision, recall, threshold = metrics.precision_recall_curve(y_test, preds, pos_label=1) 116 | ap = metrics.average_precision_score(y_test, preds) 117 | 118 | plt.plot(recall, precision, linewidth=2, 119 | label='%s, AP = %0.4f' % ('ours', ap)) 120 | plt.legend(loc='lower right') 121 | plt.savefig('PR-AP_dataset2.png') 122 | plt.clf() 123 | 124 | print(ap) 125 | print(metrics.accuracy_score(y_test, preds > 0.5)) 126 | print(metrics.f1_score(y_test, preds > 0.5)) -------------------------------------------------------------------------------- /Toydataset_Code/cs-mil-toydataset/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Toydataset_Code/cs-mil-toydataset/__init__.py
--------------------------------------------------------------------------------
/Toydataset_Code/cs-mil-toydataset/create_kmeans.py:
--------------------------------------------------------------------------------
1 | from sklearn.cluster import KMeans
2 | import numpy as np
3 | import os
4 | import pandas as pd
5 | import glob
6 | import matplotlib.pyplot as plt
7 | 
8 | data_dir = '/Data2/GCA/AttentionDeepMIL-master/data/train/'
9 | csv_dir = '/Data2/GCA/AttentionDeepMIL-master/data/csv/'
10 | 
11 | label = glob.glob(os.path.join(data_dir, "*"))
12 | 
13 | for ki in range(len(label)):
14 |     cases = glob.glob(os.path.join(label[ki], "*"))  # was data_dir, which re-listed the label folders instead of the cases inside label[ki]
15 | 
16 |     df = pd.DataFrame(columns = ['root', 'label', 'cluster'])
17 | 
18 |     for now_case in cases:
19 |         images = glob.glob(os.path.join(now_case,"*"))
20 |         images.sort()
21 | 
22 |         features = np.zeros((len(images), 1024))
23 | 
24 |         #for ii in range(len(images)):
25 |         for ii in range(10):
26 |             features[ii] = plt.imread(images[ii])[:,:,:3].reshape(-1)[:1024]  # was array[ii], a NameError; flatten-and-truncate is a crude stand-in for a real feature extractor
27 | 
28 |         kmeans = KMeans(n_clusters=8, random_state=0).fit(features[:10])  # was fit(array); cluster only the rows actually filled above
29 | 
--------------------------------------------------------------------------------
/Toydataset_Code/cs-mil-toydataset/data_split.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import random
4 | 
5 | pd.options.mode.chained_assignment = None
6 | split_num = 5
7 | original_file = pd.read_csv('/Data2/GCA/AttentionDeepMIL/data_split/data_list.csv')
8 | class1 = original_file[original_file['class'] == 1]['filename'].tolist()
9 | class0 = original_file[original_file['class'] == 0]['filename'].tolist()
10 | 
11 | test_size = 4
12 | 
13 | for si in range(split_num):
14 |     random.shuffle(class0)
15 |     random.shuffle(class1)
16 |     # ind_0 = np.random.randint(0, len(class0), test_size)
17 |     # ind_1 = np.random.randint(0, len(class1), test_size)
18 | 
19 |     testing_0 = class0[:11]
20 |     testing_1 = class1[:11]
21 | 
22 |     validation_0 = class0[11:14]
23 |     validation_1 = class1[11:14]
24 | 
25 |     training_0 = class0[14:]
26 |     training_1 = class1[14:]
27 | 
28 |     for oi in range(len(original_file)):
29 |         if (original_file.iloc[oi]['filename'] in testing_0) or (original_file.iloc[oi]['filename'] in testing_1):
30 |             original_file.iloc[oi, original_file.columns.get_loc('train')] = 2
31 |         elif (original_file.iloc[oi]['filename'] in validation_0) or (original_file.iloc[oi]['filename'] in validation_1):
32 |             original_file.iloc[oi, original_file.columns.get_loc('train')] = 1
33 |         else:
34 |             original_file.iloc[oi, original_file.columns.get_loc('train')] = 0
35 |     original_file.to_csv('/Data2/GCA/AttentionDeepMIL/data_split/data_list_%d.csv' % (si), index = False)
36 | 
37 | 
--------------------------------------------------------------------------------
/Toydataset_Code/cs-mil-toydataset/data_split_tranvalbalanced.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import random
4 | 
5 | pd.options.mode.chained_assignment = None
6 | split_num = 5
7 | original_file = pd.read_csv('/Data2/GCA/AttentionDeepMIL/data_split/data_list.csv')
8 | class1 = original_file[original_file['class'] == 1]['filename'].tolist()
9 | class0 = original_file[original_file['class'] == 0]['filename'].tolist()
10 | 
11 | test_size = 4
12 | 
13 | for si in range(split_num):
14 | 
random.shuffle(class0) 15 | random.shuffle(class1) 16 | # ind_0 = np.random.randint(0, len(class0), test_size) 17 | # ind_1 = np.random.randint(0, len(class1), test_size) 18 | 19 | testing_0 = class0[:4] 20 | testing_1 = class1[:18] 21 | 22 | validation_0 = class0[4:7] 23 | validation_1 = class1[18:21] 24 | 25 | training_0 = class0[7:] 26 | training_1 = class1[21:] 27 | 28 | for oi in range(len(original_file)): 29 | if (original_file.iloc[oi]['filename'] in testing_0) or (original_file.iloc[oi]['filename'] in testing_1): 30 | original_file.iloc[oi, original_file.columns.get_loc('train')] = 2 31 | elif (original_file.iloc[oi]['filename'] in validation_0) or (original_file.iloc[oi]['filename'] in validation_1): 32 | original_file.iloc[oi, original_file.columns.get_loc('train')] = 1 33 | else: 34 | original_file.iloc[oi, original_file.columns.get_loc('train')] = 0 35 | original_file.to_csv('/Data2/GCA/AttentionDeepMIL/data_split/data_list_%d.csv' % (si), index = False) 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /Toydataset_Code/data_processing/MIL_bag_generation.py: -------------------------------------------------------------------------------- 1 | import cv2 as cv2 2 | import numpy as np 3 | import pandas as pd 4 | from PIL import Image 5 | import os 6 | import SimpleITK as sitk 7 | import matplotlib 8 | matplotlib.use('TkAgg') 9 | import matplotlib.pyplot as plt 10 | from pathlib import Path 11 | import argparse 12 | import random 13 | import numpy as np 14 | import matplotlib.cm as cm 15 | import glob 16 | 17 | import re 18 | 19 | import matplotlib.pyplot as plt 20 | from matplotlib.cbook import get_sample_data 21 | from skimage.transform import resize 22 | import torch 23 | 24 | def atoi(text): 25 | return int(text) if text.isdigit() else text 26 | 27 | 28 | def natural_keys(text): 29 | ''' 30 | alist.sort(key=natural_keys) sorts in human order 31 | http://nedbatchelder.com/blog/200712/human_sorting.html 32 | (See Toothy's implementation in the comments) 33 | ''' 34 | return [ atoi(c) for c in re.split(r'(\d+)', text) ] 35 | 36 | 37 | 38 | if __name__ == "__main__": 39 | # dataset = 2 40 | # bag_num = [1000, 200] 41 | dataset = 2 42 | bag_num = [1000, 200] 43 | bag_size = 8 44 | MIL = 1 45 | pos_ratio = 0.5 46 | 47 | bag_folder = '/Data3/GCA_Demo/Dataset%d_Bag_%d_%d_%d' % (dataset, bag_num[0],bag_size, MIL) 48 | 49 | data_folder = ['/Data3/GCA_Demo/Dataset%d_demo' % (dataset),'/Data3/GCA_Demo/Dataset%d_demo_val' % (dataset)] 50 | 51 | mode = ['Train','Val'] 52 | r = np.random.RandomState(1) 53 | 54 | for mi in range(len(mode)): 55 | save_root = os.path.join(bag_folder, mode[mi]) 56 | if not os.path.exists(save_root): 57 | os.makedirs(save_root) 58 | 59 | img_root = data_folder[mi] 60 | now_bag_num = int(bag_num[mi]) 61 | 62 | for bi in range(now_bag_num): 63 | now_label = np.random.randint(0, 2, 1)[0] 64 | pos_list = glob.glob(os.path.join(img_root,'1','*size1024.png')) 65 | neg_list = glob.glob(os.path.join(img_root,'0','*size1024.png')) 66 | 67 | if now_label == 0: 68 | indices = torch.LongTensor(r.randint(0, len(neg_list), bag_size)) 69 | 70 | if bag_size == 1: 71 | now_images = [neg_list[indices]] 72 | else: 73 | now_images = list(np.array(neg_list)[indices]) 74 | 75 | else: 76 | if MIL: 77 | pos_length = r.randint(1, int(pos_ratio * bag_size), 1) 78 | else: 79 | pos_length = bag_size 80 | neg_length = bag_size - pos_length 81 | 82 | indices_pos = torch.LongTensor(r.randint(0, len(pos_list), pos_length)) 83 | 
indices_neg = torch.LongTensor(r.randint(0, len(neg_list), neg_length)) 84 | 85 | if len(indices_pos) == 1: 86 | now_images_pos = [pos_list[indices_pos]] 87 | now_images_neg = list(np.array(neg_list)[indices_neg]) 88 | else: 89 | now_images_pos = list(np.array(pos_list)[ 90 | indices_pos]) # + list(np.array(self.image_list_class0)[indices_neg]) # self.image_list_class0[indices_neg] 91 | now_images_neg = list(np.array(neg_list)[indices_neg]) 92 | 93 | now_images = now_images_pos + now_images_neg 94 | random.shuffle(now_images) 95 | 96 | 97 | df = pd.DataFrame(columns = ['img_root', 'class']) 98 | for ki in range(len(now_images)): 99 | df.loc[ki] = [now_images[ki], now_label] 100 | 101 | df.to_csv(os.path.join(save_root, 'bag_%d.csv' % (bi)), index = False) 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /Train_Test_Code/Background_filter.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import glob 6 | import matplotlib.pyplot as plt 7 | 8 | data_dir = '/Data2/GCA/AttentionDeepMIL-master/data/train' 9 | output_dir = '/Data2/GCA/AttentionDeepMIL-master/contrastive_learning_data' 10 | 11 | if not os.path.exists(output_dir): 12 | os.makedirs(output_dir) 13 | 14 | label = glob.glob(os.path.join(data_dir, "*")) 15 | 16 | for ki in range(len(label)): 17 | cases = glob.glob(os.path.join(label[ki], "*")) 18 | 19 | # df = pd.DataFrame(columns = ['root', 'label', 'cluster']) 20 | 21 | for now_case in cases: 22 | images = glob.glob(os.path.join(now_case,"*")) 23 | images.sort() 24 | 25 | for now_image in images: 26 | patch = plt.imread(now_image)[:,:,:3] 27 | image_name = os.path.basename(now_image) 28 | if (patch.mean(2) > 230 / 255).sum() < 512 * 512 / 2: # for dodnet 29 | plt.imsave(os.path.join(output_dir, image_name), patch) 30 | 31 | -------------------------------------------------------------------------------- /Train_Test_Code/Classifier_model_MAg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model definition of DeepAttnMISL 3 | 4 | If this work is useful for your research, please consider to cite our papers: 5 | 6 | [1] "Whole Slide Images based Cancer Survival Prediction using Attention Guided Deep Multiple Instance Learning Networks" 7 | Jiawen Yao, XinliangZhu, Jitendra Jonnagaddala, NicholasHawkins, Junzhou Huang, 8 | Medical Image Analysis, Available online 19 July 2020, 101789 9 | 10 | [2] "Deep Multi-instance Learning for Survival Prediction from Whole Slide Images", In MICCAI 2019 11 | 12 | """ 13 | 14 | import torch.nn as nn 15 | import torch 16 | 17 | class Classifier(nn.Module): 18 | """ 19 | Deep AttnMISL Model definition 20 | """ 21 | 22 | def __init__(self): 23 | super(Classifier, self).__init__() 24 | 25 | self.classifier = nn.Sequential( 26 | nn.Linear(10, 5), 27 | nn.ReLU(), 28 | # nn.Dropout(p=0.5), 29 | nn.Linear(5, 2), 30 | nn.Softmax(dim = 1) 31 | ) 32 | 33 | def forward(self, x): 34 | 35 | out = self.classifier(x) 36 | 37 | return out 38 | 39 | -------------------------------------------------------------------------------- /Train_Test_Code/DeepAttnMISL_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model definition of DeepAttnMISL 3 | 4 | If this work is useful for your research, please consider to cite our papers: 5 | 6 | [1] "Whole Slide Images based Cancer Survival 
Prediction using Attention Guided Deep Multiple Instance Learning Networks" 7 | Jiawen Yao, XinliangZhu, Jitendra Jonnagaddala, NicholasHawkins, Junzhou Huang, 8 | Medical Image Analysis, Available online 19 July 2020, 101789 9 | 10 | [2] "Deep Multi-instance Learning for Survival Prediction from Whole Slide Images", In MICCAI 2019 11 | 12 | """ 13 | 14 | import torch.nn as nn 15 | import torch 16 | 17 | class DeepAttnMIL_Surv(nn.Module): 18 | """ 19 | Deep AttnMISL Model definition 20 | """ 21 | 22 | def __init__(self, cluster_num): 23 | super(DeepAttnMIL_Surv, self).__init__() 24 | self.embedding_net = nn.Sequential(nn.Conv2d(2048, 64, 1), 25 | nn.ReLU(), 26 | nn.AdaptiveAvgPool2d((1,1)) 27 | ) 28 | 29 | 30 | self.attention = nn.Sequential( 31 | nn.Linear(64, 32), # V 32 | nn.Tanh(), 33 | nn.Linear(32, 1) # W 34 | ) 35 | 36 | self.fc6 = nn.Sequential( 37 | nn.Linear(64, 32), 38 | nn.ReLU(), 39 | nn.Dropout(p=0.5), 40 | nn.Linear(32, 1), 41 | nn.Sigmoid() 42 | ) 43 | self.cluster_num = cluster_num 44 | 45 | 46 | def masked_softmax(self, x, mask=None): 47 | """ 48 | Performs masked softmax, as simply masking post-softmax can be 49 | inaccurate 50 | :param x: [batch_size, num_items] 51 | :param mask: [batch_size, num_items] 52 | :return: 53 | """ 54 | if mask is not None: 55 | mask = mask.float() 56 | if mask is not None: 57 | x_masked = x * mask + (1 - 1 / (mask+1e-5)) 58 | else: 59 | x_masked = x 60 | x_max = x_masked.max(1)[0] 61 | x_exp = (x - x_max.unsqueeze(-1)).exp() 62 | if mask is not None: 63 | x_exp = x_exp * mask.float() 64 | return x_exp / x_exp.sum(1).unsqueeze(-1) 65 | 66 | 67 | def forward(self, x, mask): 68 | 69 | " x is a tensor list" 70 | res = [] 71 | for i in range(self.cluster_num): 72 | hh = x[i].type(torch.FloatTensor).to("cuda") 73 | output = self.embedding_net(hh) 74 | output = output.view(output.size()[0], -1) 75 | res.append(output) 76 | 77 | 78 | h = torch.cat(res) 79 | 80 | b = h.size(0) 81 | c = h.size(1) 82 | 83 | h = h.view(b, c) 84 | 85 | A = self.attention(h) 86 | A = torch.transpose(A, 1, 0) # KxN 87 | 88 | A = self.masked_softmax(A, mask) 89 | 90 | 91 | M = torch.mm(A, h) # KxL 92 | 93 | Y_pred = self.fc6(M) 94 | 95 | return Y_pred 96 | 97 | -------------------------------------------------------------------------------- /Train_Test_Code/DeepAttnMISL_model_no21.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model definition of DeepAttnMISL 3 | 4 | If this work is useful for your research, please consider to cite our papers: 5 | 6 | [1] "Whole Slide Images based Cancer Survival Prediction using Attention Guided Deep Multiple Instance Learning Networks" 7 | Jiawen Yao, XinliangZhu, Jitendra Jonnagaddala, NicholasHawkins, Junzhou Huang, 8 | Medical Image Analysis, Available online 19 July 2020, 101789 9 | 10 | [2] "Deep Multi-instance Learning for Survival Prediction from Whole Slide Images", In MICCAI 2019 11 | 12 | """ 13 | 14 | import torch.nn as nn 15 | import torch 16 | 17 | class DeepAttnMIL_Surv(nn.Module): 18 | """ 19 | Deep AttnMISL Model definition 20 | """ 21 | 22 | def __init__(self, cluster_num): 23 | super(DeepAttnMIL_Surv, self).__init__() 24 | self.embedding_net = nn.Sequential(nn.Conv2d(2048, 64, 1), 25 | nn.ReLU(), 26 | nn.AdaptiveAvgPool2d((1,1)) 27 | ) 28 | 29 | self.res_attention = nn.Sequential( 30 | nn.Conv2d(64, 32, 1), # V 31 | nn.ReLU(), 32 | nn.Conv2d(32, 1, 1), 33 | ) 34 | 35 | self.attention = nn.Sequential( 36 | nn.Linear(64, 32), # V 37 | nn.Tanh(), 38 | nn.Linear(32, 1) # W 39 
| ) 40 | 41 | self.fc6 = nn.Sequential( 42 | nn.Linear(64, 32), 43 | nn.ReLU(), 44 | nn.Dropout(p=0.5), 45 | nn.Linear(32, 1), 46 | nn.Sigmoid() 47 | ) 48 | self.cluster_num = cluster_num 49 | 50 | 51 | def masked_softmax(self, x, mask=None): 52 | """ 53 | Performs masked softmax, as simply masking post-softmax can be 54 | inaccurate 55 | :param x: [batch_size, num_items] 56 | :param mask: [batch_size, num_items] 57 | :return: 58 | """ 59 | if mask is not None: 60 | mask = mask.float() 61 | if mask is not None: 62 | x_masked = x * mask + (1 - 1 / (mask+1e-5)) 63 | else: 64 | x_masked = x 65 | x_max = x_masked.max(1)[0] 66 | x_exp = (x - x_max.unsqueeze(-1)).exp() 67 | if mask is not None: 68 | x_exp = x_exp * mask.float() 69 | return x_exp / x_exp.sum(1).unsqueeze(-1) 70 | 71 | 72 | def forward(self, x, mask): 73 | 74 | " x is a tensor list" 75 | res = [] 76 | for i in range(self.cluster_num): 77 | hh = x[i].type(torch.FloatTensor).to("cuda") 78 | output1 = self.embedding_net(hh[:,:,0:1,:]) 79 | output2 = self.embedding_net(hh[:,:,1:2,:]) 80 | output3 = self.embedding_net(hh[:,:,2:3,:]) 81 | output = torch.cat([output1, output2, output3],2) 82 | res_attention = self.res_attention(output).squeeze(-1) 83 | 84 | final_output = torch.matmul(output.squeeze(-1), torch.transpose(res_attention,2,1)).squeeze(-1) 85 | res.append(final_output) 86 | 87 | h = torch.cat(res) 88 | 89 | b = h.size(0) 90 | c = h.size(1) 91 | 92 | h = h.view(b, c) 93 | 94 | A = self.attention(h) 95 | A = torch.transpose(A, 1, 0) # KxN 96 | 97 | A = self.masked_softmax(A, mask) 98 | 99 | M = torch.mm(A, h) # KxL 100 | 101 | Y_pred = self.fc6(M) 102 | 103 | return Y_pred 104 | 105 | -------------------------------------------------------------------------------- /Train_Test_Code/GetDataList.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import glob 6 | import matplotlib.pyplot as plt 7 | 8 | data_dir = '/Data2/GCA/AttentionDeepMIL/Normal_control_vs_patient.csv' 9 | csv_dir = '/Data2/GCA/AttentionDeepMIL/data_list.csv' 10 | 11 | label_df = pd.read_csv(data_dir) 12 | list_df = pd.DataFrame(columns = ['filename', 'region', 'patient_type', 'score', 'class', 'train']) 13 | 14 | for ki in range(len(label_df)): 15 | #if label_df.iloc[ki]['patient_type'].replace(" ", "") == 'CD' or label_df.iloc[ki]['patient_type'].replace(" ", "") == 'Control': 16 | if not pd.isna(label_df.iloc[ki]['patient_type']): 17 | now_file = label_df.iloc[ki]['filename'].replace(" ", "") 18 | now_region = label_df.iloc[ki]['region'].replace(" ", "") 19 | now_patient_type = label_df.iloc[ki]['patient_type'].replace(" ", "") 20 | if now_patient_type == 'CD': 21 | now_class = 1 22 | else: 23 | now_class = 0 24 | 25 | now_score = int(label_df.iloc[ki]['score']) 26 | now_train = 1 27 | 28 | row = len(list_df) 29 | list_df.loc[row] = [now_file, now_region, now_patient_type, now_score, now_class, now_train] 30 | 31 | list_df.to_csv(csv_dir, index = False) -------------------------------------------------------------------------------- /Train_Test_Code/Regions_to_multiscale_patches.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import glob 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import numpy as np 7 | import random 8 | from skimage.transform import resize 9 | 10 | if __name__ == "__main__": 11 | folder_name = '/Data2/CSMIL_TCGA/TCGA_regions' 12 
| output_folder = '/Data2/CSMIL_TCGA/TCGA_multi-scale_patches'
13 | 
14 | cases = glob.glob(os.path.join(folder_name,'*'))
15 | 
16 | for case in cases:
17 |     regions = glob.glob(os.path.join(case, '*.png'))
18 | 
19 |     now_folder = os.path.join(output_folder, os.path.basename(case))
20 | 
21 |     if not os.path.exists(now_folder):
22 |         os.makedirs(now_folder)
23 | 
24 |     for ri in range(len(regions)):
25 |         now_img = plt.imread(regions[ri])[:,:,:3]
26 |         for xi in range(0,16):
27 |             for yi in range(0,16):
28 |                 now_x = xi * 256 + 384
29 |                 now_y = yi * 256 + 384
30 | 
31 |                 patch_256 = now_img[now_x:now_x + 256, now_y:now_y + 256,:]
32 |                 patch_512 = resize(now_img[now_x - 128:now_x + 256 + 128, now_y - 128:now_y + 256 + 128,:], (256, 256,3), anti_aliasing= False)
33 |                 patch_1024 = resize(now_img[now_x - 128 -256:now_x + 256 + 128 +256, now_y - 128 - 256:now_y + 256 + 128 + 256,:], (256, 256,3), anti_aliasing= False)
34 | 
35 |                 plt.imsave(os.path.join(now_folder, os.path.basename(regions[ri]).replace('size4864.png', '%d_%d_size256.png' % (now_x, now_y))),patch_256)
36 |                 plt.imsave(os.path.join(now_folder, os.path.basename(regions[ri]).replace('size4864.png', '%d_%d_size512.png' % (now_x, now_y))),patch_512)
37 |                 plt.imsave(os.path.join(now_folder, os.path.basename(regions[ri]).replace('size4864.png', '%d_%d_size1024.png' % (now_x, now_y))),patch_1024)
38 | 
39 | 
--------------------------------------------------------------------------------
/Train_Test_Code/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hrlblab/CS-MIL/d51b97f411caeff092e397427003021d1b23f00f/Train_Test_Code/__init__.py
--------------------------------------------------------------------------------
/Train_Test_Code/create_kmeans.py:
--------------------------------------------------------------------------------
1 | from sklearn.cluster import KMeans
2 | import numpy as np
3 | import os
4 | import pandas as pd
5 | import glob
6 | import matplotlib.pyplot as plt
7 | 
8 | data_dir = '/Data2/GCA/AttentionDeepMIL-master/data/train/'
9 | csv_dir = '/Data2/GCA/AttentionDeepMIL-master/data/csv/'
10 | 
11 | label = glob.glob(os.path.join(data_dir, "*"))
12 | 
13 | for ki in range(len(label)):
14 |     cases = glob.glob(os.path.join(label[ki], "*"))  # same fix as the toy-dataset copy: list cases inside label[ki], not data_dir again
15 | 
16 |     df = pd.DataFrame(columns = ['root', 'label', 'cluster'])
17 | 
18 |     for now_case in cases:
19 |         images = glob.glob(os.path.join(now_case,"*"))
20 |         images.sort()
21 | 
22 |         features = np.zeros((len(images), 1024))
23 | 
24 |         #for ii in range(len(images)):
25 |         for ii in range(10):
26 |             features[ii] = plt.imread(images[ii])[:,:,:3].reshape(-1)[:1024]  # was array[ii], a NameError; flatten-and-truncate is a crude stand-in for a real feature extractor
27 | 
28 |         kmeans = KMeans(n_clusters=8, random_state=0).fit(features[:10])  # was fit(array); cluster only the rows actually filled above
29 | 
--------------------------------------------------------------------------------
/Train_Test_Code/data_split.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import random
4 | 
5 | pd.options.mode.chained_assignment = None
6 | split_num = 5
7 | original_file = pd.read_csv('/Data2/GCA/AttentionDeepMIL/data_split/data_list.csv')
8 | class1 = original_file[original_file['class'] == 1]['filename'].tolist()
9 | class0 = original_file[original_file['class'] == 0]['filename'].tolist()
10 | 
11 | test_size = 4
12 | 
13 | for si in range(split_num):
14 |     random.shuffle(class0)
15 |     random.shuffle(class1)
16 |     # ind_0 = np.random.randint(0, len(class0), test_size)
17 |     # ind_1 = 
np.random.randint(0, len(class1), test_size) 18 | 19 | testing_0 = class0[:11] 20 | testing_1 = class1[:11] 21 | 22 | validation_0 = class0[11:14] 23 | validation_1 = class1[11:14] 24 | 25 | training_0 = class0[14:] 26 | training_1 = class1[14:] 27 | 28 | for oi in range(len(original_file)): 29 | if (original_file.iloc[oi]['filename'] in testing_0) or (original_file.iloc[oi]['filename'] in testing_1): 30 | original_file.iloc[oi, original_file.columns.get_loc('train')] = 2 31 | elif (original_file.iloc[oi]['filename'] in validation_0) or (original_file.iloc[oi]['filename'] in validation_1): 32 | original_file.iloc[oi, original_file.columns.get_loc('train')] = 1 33 | else: 34 | original_file.iloc[oi, original_file.columns.get_loc('train')] = 0 35 | original_file.to_csv('/Data2/GCA/AttentionDeepMIL/data_split/data_list_%d.csv' % (si), index = False) 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /Train_Test_Code/data_split_tranvalbalanced.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import random 4 | 5 | pd.options.mode.chained_assignment = None 6 | split_num = 5 7 | original_file = pd.read_csv('/Data2/GCA/AttentionDeepMIL/data_split/data_list.csv') 8 | class1 = original_file[original_file['class'] == 1]['filename'].tolist() 9 | class0 = original_file[original_file['class'] == 0]['filename'].tolist() 10 | 11 | test_size = 4 12 | 13 | for si in range(split_num): 14 | random.shuffle(class0) 15 | random.shuffle(class1) 16 | # ind_0 = np.random.randint(0, len(class0), test_size) 17 | # ind_1 = np.random.randint(0, len(class1), test_size) 18 | 19 | testing_0 = class0[:4] 20 | testing_1 = class1[:18] 21 | 22 | validation_0 = class0[4:7] 23 | validation_1 = class1[18:21] 24 | 25 | training_0 = class0[7:] 26 | training_1 = class1[21:] 27 | 28 | for oi in range(len(original_file)): 29 | if (original_file.iloc[oi]['filename'] in testing_0) or (original_file.iloc[oi]['filename'] in testing_1): 30 | original_file.iloc[oi, original_file.columns.get_loc('train')] = 2 31 | elif (original_file.iloc[oi]['filename'] in validation_0) or (original_file.iloc[oi]['filename'] in validation_1): 32 | original_file.iloc[oi, original_file.columns.get_loc('train')] = 1 33 | else: 34 | original_file.iloc[oi, original_file.columns.get_loc('train')] = 0 35 | original_file.to_csv('/Data2/GCA/AttentionDeepMIL/data_split/data_list_%d.csv' % (si), index = False) 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /Train_Test_Code/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Attention(nn.Module): 7 | def __init__(self): 8 | super(Attention, self).__init__() 9 | self.L = 500 10 | self.D = 128 11 | self.K = 1 12 | 13 | self.feature_extractor_part1 = nn.Sequential( 14 | nn.Conv2d(1, 20, kernel_size=5), 15 | nn.ReLU(), 16 | nn.MaxPool2d(2, stride=2), 17 | nn.Conv2d(20, 50, kernel_size=5), 18 | nn.ReLU(), 19 | nn.MaxPool2d(2, stride=2) 20 | ) 21 | 22 | self.feature_extractor_part2 = nn.Sequential( 23 | nn.Linear(50 * 4 * 4, self.L), 24 | nn.ReLU(), 25 | ) 26 | 27 | self.attention = nn.Sequential( 28 | nn.Linear(self.L, self.D), 29 | nn.Tanh(), 30 | nn.Linear(self.D, self.K) 31 | ) 32 | 33 | self.classifier = nn.Sequential( 34 | nn.Linear(self.L*self.K, 1), 35 | 
nn.Sigmoid()
36 |         )
37 | 
38 |     def forward(self, x):
39 |         x = x.squeeze(0)
40 | 
41 |         H = self.feature_extractor_part1(x)
42 |         H = H.view(-1, 50 * 4 * 4)
43 |         H = self.feature_extractor_part2(H)  # NxL
44 | 
45 |         A = self.attention(H)  # NxK
46 |         A = torch.transpose(A, 1, 0)  # KxN
47 |         A = F.softmax(A, dim=1)  # softmax over N
48 | 
49 |         M = torch.mm(A, H)  # KxL
50 | 
51 |         Y_prob = self.classifier(M)
52 |         Y_hat = torch.ge(Y_prob, 0.5).float()
53 | 
54 |         return Y_prob, Y_hat, A
55 | 
56 |     # AUXILIARY METHODS
57 |     def calculate_classification_error(self, X, Y):
58 |         Y = Y.float()
59 |         _, Y_hat, _ = self.forward(X)
60 |         # error = 1. - Y_hat.eq(Y).cpu().float().mean().data[0]
61 |         error = 1. - Y_hat.eq(Y).cpu().float().mean().item()  # .item() replaces the deprecated .data access, matching GatedAttention below
62 | 
63 |         return error, Y_hat
64 | 
65 |     def calculate_objective(self, X, Y):
66 |         Y = Y.float()
67 |         Y_prob, _, A = self.forward(X)
68 |         Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
69 |         neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob))  # negative log bernoulli
70 | 
71 |         return neg_log_likelihood, A
72 | 
73 | class GatedAttention(nn.Module):
74 |     def __init__(self):
75 |         super(GatedAttention, self).__init__()
76 |         self.L = 500
77 |         self.D = 128
78 |         self.K = 1
79 | 
80 |         self.feature_extractor_part1 = nn.Sequential(
81 |             nn.Conv2d(1, 20, kernel_size=5),
82 |             nn.ReLU(),
83 |             nn.MaxPool2d(2, stride=2),
84 |             nn.Conv2d(20, 50, kernel_size=5),
85 |             nn.ReLU(),
86 |             nn.MaxPool2d(2, stride=2)
87 |         )
88 | 
89 |         self.feature_extractor_part2 = nn.Sequential(
90 |             nn.Linear(50 * 4 * 4, self.L),
91 |             nn.ReLU(),
92 |         )
93 | 
94 |         self.attention_V = nn.Sequential(
95 |             nn.Linear(self.L, self.D),
96 |             nn.Tanh()
97 |         )
98 | 
99 |         self.attention_U = nn.Sequential(
100 |             nn.Linear(self.L, self.D),
101 |             nn.Sigmoid()
102 |         )
103 | 
104 |         self.attention_weights = nn.Linear(self.D, self.K)
105 | 
106 |         self.classifier = nn.Sequential(
107 |             nn.Linear(self.L*self.K, 1),
108 |             nn.Sigmoid()
109 |         )
110 | 
111 |     def forward(self, x):
112 |         x = x.squeeze(0)
113 | 
114 |         H = self.feature_extractor_part1(x)
115 |         H = H.view(-1, 50 * 4 * 4)
116 |         H = self.feature_extractor_part2(H)  # NxL
117 | 
118 |         A_V = self.attention_V(H)  # NxD
119 |         A_U = self.attention_U(H)  # NxD
120 |         A = self.attention_weights(A_V * A_U)  # element wise multiplication # NxK
121 |         A = torch.transpose(A, 1, 0)  # KxN
122 |         A = F.softmax(A, dim=1)  # softmax over N
123 | 
124 |         M = torch.mm(A, H)  # KxL
125 | 
126 |         Y_prob = self.classifier(M)
127 |         Y_hat = torch.ge(Y_prob, 0.5).float()
128 | 
129 |         return Y_prob, Y_hat, A
130 | 
131 |     # AUXILIARY METHODS
132 |     def calculate_classification_error(self, X, Y):
133 |         Y = Y.float()
134 |         _, Y_hat, _ = self.forward(X)
135 |         error = 1. - Y_hat.eq(Y).cpu().float().mean().item()
136 | 
137 |         return error, Y_hat
138 | 
139 |     def calculate_objective(self, X, Y):
140 |         Y = Y.float()
141 |         Y_prob, _, A = self.forward(X)
142 |         Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
143 |         neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob))  # negative log bernoulli
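#  The clamp above keeps Y_prob inside [1e-5, 1 - 1e-5] so neither log() can reach
#  -inf; the expression is the Bernoulli negative log-likelihood of the bag label.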
144 | 
145 |         return neg_log_likelihood, A
146 | 
--------------------------------------------------------------------------------
/Train_Test_Code/model3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class Attention(nn.Module):
7 |     def __init__(self):
8 |         super(Attention, self).__init__()
9 |         self.L = 500
10 |         self.D = 128
11 |         self.K = 1
12 | 
13 |         self.feature_extractor_part1 = nn.Sequential(
14 |             nn.Conv2d(3, 4, kernel_size=5),
15 |             nn.ReLU(),
16 |             nn.MaxPool2d(2, stride=2),
17 |             nn.Conv2d(4, 8, kernel_size=5),
18 |             nn.ReLU(),
19 |             nn.MaxPool2d(2, stride=2),
20 |             nn.Conv2d(8, 16, kernel_size=5),
21 |             nn.ReLU(),
22 |             nn.MaxPool2d(2, stride=2),
23 |             nn.Conv2d(16, 20, kernel_size=5),
24 |             nn.ReLU(),
25 |             nn.MaxPool2d(2, stride=2),
26 |             nn.Conv2d(20, 32, kernel_size=5),
27 |             nn.ReLU(),
28 |             nn.MaxPool2d(2, stride=2),
29 |             nn.Conv2d(32, 50, kernel_size=5),
30 |             nn.ReLU(),
31 |             nn.MaxPool2d(2, stride=2)
32 |         )
33 | 
34 |         self.feature_extractor_part2 = nn.Sequential(
35 |             nn.Linear(50 * 4 * 4, self.L),
36 |             nn.ReLU(),
37 |         )
38 | 
39 |         self.attention = nn.Sequential(
40 |             nn.Linear(self.L, self.D),
41 |             nn.Tanh(),
42 |             nn.Linear(self.D, self.K)
43 |         )
44 | 
45 |         self.classifier = nn.Sequential(
46 |             nn.Linear(self.L*self.K, 1),
47 |             nn.Sigmoid()
48 |         )
49 | 
50 |     def forward(self, x):
51 |         x = x.squeeze(0)
52 | 
53 |         H = self.feature_extractor_part1(x)
54 |         H = H.view(-1, 50 * 4 * 4)
55 |         H = self.feature_extractor_part2(H)  # NxL
56 | 
57 |         A = self.attention(H)  # NxK
58 |         A = torch.transpose(A, 1, 0)  # KxN
59 |         A = F.softmax(A, dim=1)  # softmax over N
60 | 
61 |         M = torch.mm(A, H)  # KxL
62 | 
63 |         Y_prob = self.classifier(M)
64 |         Y_hat = torch.ge(Y_prob, 0.5).float()
65 | 
66 |         return Y_prob, Y_hat, A
67 | 
68 |     # AUXILIARY METHODS
69 |     def calculate_classification_error(self, X, Y):
70 |         Y = Y.float()
71 |         _, Y_hat, _ = self.forward(X)
72 |         # error = 1. - Y_hat.eq(Y).cpu().float().mean().data[0]
73 |         error = 1. - Y_hat.eq(Y).cpu().float().mean().item()  # .item() replaces the deprecated .data access, matching GatedAttention below
74 | 
75 |         return error, Y_hat
76 | 
77 |     def calculate_objective(self, X, Y):
78 |         Y = Y.float()
79 |         Y_prob, _, A = self.forward(X)
80 |         Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
81 |         neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. 
- Y_prob)) # negative log bernoulli 82 | 83 | return neg_log_likelihood, A 84 | 85 | class GatedAttention(nn.Module): 86 | def __init__(self): 87 | super(GatedAttention, self).__init__() 88 | self.L = 500 89 | self.D = 128 90 | self.K = 1 91 | 92 | self.feature_extractor_part1 = nn.Sequential( 93 | nn.Conv2d(3, 20, kernel_size=5), 94 | nn.ReLU(), 95 | nn.MaxPool2d(2, stride=2), 96 | nn.Conv2d(20, 50, kernel_size=5), 97 | nn.ReLU(), 98 | nn.MaxPool2d(2, stride=2) 99 | ) 100 | 101 | self.feature_extractor_part2 = nn.Sequential( 102 | nn.Linear(50 * 4 * 4, self.L), 103 | nn.ReLU(), 104 | ) 105 | 106 | self.attention_V = nn.Sequential( 107 | nn.Linear(self.L, self.D), 108 | nn.Tanh() 109 | ) 110 | 111 | self.attention_U = nn.Sequential( 112 | nn.Linear(self.L, self.D), 113 | nn.Sigmoid() 114 | ) 115 | 116 | self.attention_weights = nn.Linear(self.D, self.K) 117 | 118 | self.classifier = nn.Sequential( 119 | nn.Linear(self.L*self.K, 1), 120 | nn.Sigmoid() 121 | ) 122 | 123 | def forward(self, x): 124 | x = x.squeeze(0) 125 | 126 | H = self.feature_extractor_part1(x) 127 | H = H.view(-1, 50 * 4 * 4) 128 | H = self.feature_extractor_part2(H) # NxL 129 | 130 | A_V = self.attention_V(H) # NxD 131 | A_U = self.attention_U(H) # NxD 132 | A = self.attention_weights(A_V * A_U) # element wise multiplication # NxK 133 | A = torch.transpose(A, 1, 0) # KxN 134 | A = F.softmax(A, dim=1) # softmax over N 135 | 136 | M = torch.mm(A, H) # KxL 137 | 138 | Y_prob = self.classifier(M) 139 | Y_hat = torch.ge(Y_prob, 0.5).float() 140 | 141 | return Y_prob, Y_hat, A 142 | 143 | # AUXILIARY METHODS 144 | def calculate_classification_error(self, X, Y): 145 | Y = Y.float() 146 | _, Y_hat, _ = self.forward(X) 147 | error = 1. - Y_hat.eq(Y).cpu().float().mean().item() 148 | 149 | return error, Y_hat 150 | 151 | def calculate_objective(self, X, Y): 152 | Y = Y.float() 153 | Y_prob, _, A = self.forward(X) 154 | Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5) 155 | neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob)) # negative log bernoulli 156 | 157 | return neg_log_likelihood, A 158 | --------------------------------------------------------------------------------
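The two ROCAUC_*_classmark.py scripts above duplicate the same ROC/PR plotting boilerplate, with only the dataset tag, patch_size, and crop offset differing. A minimal sketch of how that evaluation step could be factored out is given below; it assumes only a bag_score.csv with 'label' and 'pred' columns, as written by those scripts, and the evaluate_bag_scores name is ours, not the repo's.

import matplotlib.pyplot as plt
import pandas as pd
import sklearn.metrics as metrics

def evaluate_bag_scores(csv_path, tag):
    """Plot ROC and PR curves and print summary metrics for one bag_score.csv."""
    df = pd.read_csv(csv_path)
    y_true = df['label'].to_numpy()
    y_score = df['pred'].to_numpy()

    # ROC curve, as in the per-dataset scripts
    fpr, tpr, _ = metrics.roc_curve(y_true, y_score, pos_label=1)
    roc_auc = metrics.auc(fpr, tpr)
    plt.title('Receiver Operating Characteristic')
    plt.plot([0, 1], [0, 1], 'r--')
    plt.plot(fpr, tpr, linewidth=2, label='ours, AUC = %0.4f' % roc_auc)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc='lower right')
    plt.savefig('ROC-AUC_%s.png' % tag)
    plt.clf()

    # Precision-recall curve and average precision
    precision, recall, _ = metrics.precision_recall_curve(y_true, y_score, pos_label=1)
    ap = metrics.average_precision_score(y_true, y_score)
    plt.title('Precision Recall Curve')
    plt.plot([0, 1], [0.5, 0.5], 'r--')
    plt.plot(recall, precision, linewidth=2, label='ours, AP = %0.4f' % ap)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.legend(loc='lower right')
    plt.savefig('PR-AP_%s.png' % tag)
    plt.clf()

    # Hard predictions at the 0.5 threshold, matching the scripts above
    print('AUC %.4f | AP %.4f | Acc %.4f | F1 %.4f' % (
        roc_auc, ap,
        metrics.accuracy_score(y_true, y_score > 0.5),
        metrics.f1_score(y_true, y_score > 0.5)))

for tag in ('dataset1', 'dataset2'):
    evaluate_bag_scores('results_stage1_clustering_%s/bag_score.csv' % tag, tag)

Factoring the evaluation out this way would leave each per-dataset script with only the bag-labelling loop, which is the part that genuinely differs (patch_size and the 384 versus 384*2 crop offset).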