├── .gitignore
├── LICENSE
├── MIL
├── .gitignore
├── LICENSE.md
├── __init__.py
├── build_preset.py
├── create_heatmaps.py
├── create_patches.py
├── create_patches_fp.py
├── create_splits_seq.py
├── datasets
│   ├── __init__.py
│   ├── dataset_generic.py
│   ├── dataset_h5.py
│   └── wsi_dataset.py
├── eval.py
├── extract_features.py
├── extract_features_fp.py
├── get_experiments.ipynb
├── main.py
├── main_cam.py
├── models
│   ├── model_clam.py
│   ├── model_mil.py
│   └── resnet_custom.py
├── splits
│   ├── test
│   │   ├── splits_0.csv
│   │   └── splits_0_bool.csv
│   └── train
│   │   ├── splits_0.csv
│   │   └── splits_0_bool.csv
├── utils
│   ├── core_utils.py
│   ├── eval_utils.py
│   ├── file_utils.py
│   └── utils.py
├── vis_utils
│   └── heatmap_utils.py
└── wsi_core
│   ├── WholeSlideImage.py
│   ├── batch_process_utils.py
│   ├── util_classes.py
│   └── wsi_utils.py
├── MIL_data_creation.ipynb
├── README.md
├── __init__.py
├── camelyon16_extraction.ipynb
├── check_preprocess.ipynb
├── dino
├── LICENSE
├── README.md
├── __init__.py
├── causal-conv1d
│   ├── =1.1.0
│   ├── AUTHORS
│   ├── LICENSE
│   ├── README.md
│   ├── build
│   │   ├── lib.linux-x86_64-cpython-310
│   │   │   ├── causal_conv1d
│   │   │   │   ├── __init__.py
│   │   │   │   └── causal_conv1d_interface.py
│   │   │   └── causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so
│   │   └── temp.linux-x86_64-cpython-310
│   │   │   ├── .ninja_deps
│   │   │   ├── .ninja_log
│   │   │   ├── build.ninja
│   │   │   └── csrc
│   │   │   ├── causal_conv1d.o
│   │   │   ├── causal_conv1d_bwd.o
│   │   │   ├── causal_conv1d_fwd.o
│   │   │   └── causal_conv1d_update.o
│   ├── causal_conv1d-1.0.0+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
│   ├── causal_conv1d.egg-info
│   │   └── PKG-INFO
│   ├── causal_conv1d
│   │   ├── __init__.py
│   │   └── causal_conv1d_interface.py
│   ├── csrc
│   │   ├── causal_conv1d.cpp
│   │   ├── causal_conv1d.h
│   │   ├── causal_conv1d_bwd.cu
│   │   ├── causal_conv1d_common.h
│   │   ├── causal_conv1d_fwd.cu
│   │   ├── causal_conv1d_update.cu
│   │   └── static_switch.h
│   ├── setup.py
│   └── tests
│   │   └── test_causal_conv1d.py
├── config.py
├── eval_finetune.py
├── eval_linear.py
├── main.py
├── mamba-1p1p1
│   ├── .gitignore
│   ├── .gitmodules
│   ├── AUTHORS
│   ├── LICENSE
│   ├── README.md
│   ├── benchmarks
│   │   └── benchmark_generation_mamba_simple.py
│   ├── csrc
│   │   └── selective_scan
│   │   │   ├── reverse_scan.cuh
│   │   │   ├── selective_scan.cpp
│   │   │   ├── selective_scan.h
│   │   │   ├── selective_scan_bwd_bf16_complex.cu
│   │   │   ├── selective_scan_bwd_bf16_real.cu
│   │   │   ├── selective_scan_bwd_fp16_complex.cu
│   │   │   ├── selective_scan_bwd_fp16_real.cu
│   │   │   ├── selective_scan_bwd_fp32_complex.cu
│   │   │   ├── selective_scan_bwd_fp32_real.cu
│   │   │   ├── selective_scan_bwd_kernel.cuh
│   │   │   ├── selective_scan_common.h
│   │   │   ├── selective_scan_fwd_bf16.cu
│   │   │   ├── selective_scan_fwd_fp16.cu
│   │   │   ├── selective_scan_fwd_fp32.cu
│   │   │   ├── selective_scan_fwd_kernel.cuh
│   │   │   ├── static_switch.h
│   │   │   └── uninitialized_copy.cuh
│   ├── evals
│   │   └── lm_harness_eval.py
│   ├── mamba_ssm
│   │   ├── __init__.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── config_mamba.py
│   │   │   └── mixer_seq_simple.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   └── mamba_simple.py
│   │   ├── ops
│   │   │   ├── __init__.py
│   │   │   ├── selective_scan_interface.py
│   │   │   └── triton
│   │   │   │   ├── __init__.py
│   │   │   │   ├── layernorm.py
│   │   │   │   └── selective_state_update.py
│   │   └── utils
│   │   │   ├── __init__.py
│   │   │   ├── generation.py
│   │   │   └── hf.py
│   ├── setup.py
│   └── tests
│   │   └── ops
│   │   ├── test_selective_scan.py
│   │   └── triton
│   │   └── test_selective_state_update.py
├── run_with_submitit.py
├── utils.py
├── utils_dino.py
├── vim
│   ├── .gitignore
│   ├── LICENSE
│   ├── __init__.py
│   ├── augment.py
│   ├── datasets.py
│   ├── engine.py
│   ├── hubconf.py
│   ├── losses.py
│   ├── main.py
│   ├── models_mamba.py
│   ├── rope.py
│   ├── run_with_submitit.py
│   ├── samplers.py
│   ├── scripts
│   │   ├── ft-vim-s.sh
│   │   ├── ft-vim-t.sh
│   │   ├── pt-vim-s.sh
│   │   └── pt-vim-t.sh
│   └── utils.py
├── vision_transformer.py
└── visualize_attention.py
├── media
└── Vim4Path.webp
├── mil_data_creation.py
├── patch_heatmaps.ipynb
├── patch_heatmaps.py
├── preprocess
├── check_images.py
├── create_patches.py
├── create_patches_fp.py
├── datasets
│   ├── dataset_generic.py
│   ├── dataset_h5.py
│   └── wsi_dataset.py
├── extract_patches.py
├── extract_patches.sh
├── extract_patches_camelyon.py
├── extract_patches_tar.py
├── sample.sh
├── utils
│   ├── core_utils.py
│   ├── eval_utils.py
│   ├── file_utils.py
│   └── utils.py
└── wsi_core
│   ├── WholeSlideImage.py
│   ├── batch_process_utils.py
│   ├── util_classes.py
│   └── wsi_utils.py
├── test_patches.py
└── test_speed.py
/.gitignore:
--------------------------------------------------------------------------------
1 | output/
2 | data/
3 | *ndpi
4 | *svs
5 | *csv
6 | *jpg
7 | *h5
8 | *png
9 | *tar
10 | *pth
11 | *txt
12 | *out
13 | *pyc
14 | *out
15 | *so
16 | *zip
17 | *sh
18 | *pt
19 | 
20 | dino/causal-conv1d/build/
21 | checkpoints/
22 | .ipynb_checkpoints/
23 | __pycache__/
24 | .idea/
25 | HIPT/
26 | #datasets/
27 | dataset/
28 | Vim
29 | .github
30 | wandb/
31 | MIL/results/
32 | MIL/results_old/
33 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 AtlasAnalyticsLab
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /MIL/.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | __pycache__/ 3 | splits/ 4 | results/ 5 | presets/ 6 | heatmaps/ 7 | eval_results/ 8 | dataset_csv/ 9 | .ipynb_checkpoints/ 10 | docs/ -------------------------------------------------------------------------------- /MIL/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/MIL/__init__.py -------------------------------------------------------------------------------- /MIL/build_preset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import argparse 4 | 5 | 6 | parser = argparse.ArgumentParser(description='preset_builder') 7 | parser.add_argument('--preset_name', type=str, 8 | help='name of preset') 9 | parser.add_argument('--seg_level', type=int, default=-1, 10 | help='downsample level at which to segment') 11 | parser.add_argument('--sthresh', type=int, default=8, 12 | help='segmentation threshold') 13 | parser.add_argument('--mthresh', type=int, default=7, 14 | help='median filter threshold') 15 | parser.add_argument('--use_otsu', action='store_true', default=False) 16 | parser.add_argument('--close', type=int, default=4, 17 | help='additional morphological closing') 18 | parser.add_argument('--a_t', type=int, default=100, 19 | help='area filter for tissue') 20 | parser.add_argument('--a_h', type=int, default=16, 21 | help='area filter for holes') 22 | parser.add_argument('--max_n_holes', type=int, default=8, 23 | help='maximum number of holes to consider for each tissue contour') 24 | parser.add_argument('--vis_level', type=int, default=-1, 25 | help='downsample level at which to visualize') 26 | parser.add_argument('--line_thickness', type=int, default=250, 27 | help='line_thickness to visualize segmentation') 28 | parser.add_argument('--white_thresh', type=int, default=5, 29 | help='saturation threshold for whether to consider a patch as blank for exclusion') 30 | parser.add_argument('--black_thresh', type=int, default=50, 31 | help='mean rgb threshold for whether to consider a patch as black for exclusion') 32 | parser.add_argument('--no_padding', action='store_false', default=True) 33 | parser.add_argument('--contour_fn', type=str, choices=['four_pt', 'center', 'basic', 'four_pt_hard'], default='four_pt', 34 | help='contour checking function') 35 | 36 | 37 | if __name__ == '__main__': 38 | args = parser.parse_args() 39 | seg_params = {'seg_level': args.seg_level, 'sthresh': args.sthresh, 'mthresh': args.mthresh, 40 | 'close': args.close, 'use_otsu': args.use_otsu, 'keep_ids': 'none', 'exclude_ids': 'none'} 41 | filter_params = {'a_t':args.a_t, 'a_h': args.a_h, 'max_n_holes': args.max_n_holes} 42 | vis_params = {'vis_level': args.vis_level, 'line_thickness': args.line_thickness} 43 | patch_params = {'white_thresh': args.white_thresh, 'black_thresh': args.black_thresh, 44 | 'use_padding': args.no_padding, 'contour_fn': args.contour_fn} 45 | 46 | all_params = {} 47 | all_params.update(seg_params) 48 | all_params.update(filter_params) 49 | all_params.update(vis_params) 50 | all_params.update(patch_params) 51 | params_df = pd.DataFrame(all_params, index=[0]) 52 | params_df.to_csv('presets/{}'.format(args.preset_name), index=False) 53 | 
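# Usage sketch (illustrative; the preset file name below is hypothetical).
# build_preset.py collects the segmentation, filtering, visualization, and
# patching parameters defined above into a one-row CSV that the CLAM-style
# patching scripts in this folder can reuse as a preset:
#
#   python build_preset.py --preset_name camelyon16.csv --use_otsu \
#       --sthresh 8 --mthresh 7 --a_t 100 --a_h 16 --contour_fn four_pt
#
# Note: pandas.DataFrame.to_csv() does not create directories, so presets/
# must already exist before running this script (it is listed in
# MIL/.gitignore), e.g. run `mkdir presets` from the MIL/ directory first.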
-------------------------------------------------------------------------------- /MIL/create_splits_seq.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import os 3 | import pandas as pd 4 | from datasets.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset, save_splits 5 | import argparse 6 | import numpy as np 7 | 8 | parser = argparse.ArgumentParser(description='Creating splits for whole slide classification') 9 | parser.add_argument('--label_frac', type=float, default= 1.0, 10 | help='fraction of labels (default: 1)') 11 | parser.add_argument('--seed', type=int, default=1, 12 | help='random seed (default: 1)') 13 | parser.add_argument('--k', type=int, default=10, 14 | help='number of splits (default: 10)') 15 | parser.add_argument('--task', type=str, choices=['task_1_tumor_vs_normal', 'task_2_tumor_subtyping']) 16 | parser.add_argument('--val_frac', type=float, default= 0.1, 17 | help='fraction of labels for validation (default: 0.1)') 18 | parser.add_argument('--test_frac', type=float, default= 0.1, 19 | help='fraction of labels for test (default: 0.1)') 20 | 21 | args = parser.parse_args() 22 | 23 | if args.task == 'task_1_tumor_vs_normal': 24 | args.n_classes=2 25 | dataset = Generic_WSI_Classification_Dataset(csv_path = '../clam_data/224_10x/vim-s/testing/tumor_vs_normal.csv', 26 | shuffle = False, 27 | seed = args.seed, 28 | print_info = True, 29 | label_dict = {'normal':0, 'tumor':1}, 30 | patient_strat=True, 31 | ignore=[]) 32 | 33 | elif args.task == 'task_2_tumor_subtyping': 34 | args.n_classes=3 35 | dataset = Generic_WSI_Classification_Dataset(csv_path = 'dataset_csv/tumor_subtyping_dummy_clean.csv', 36 | shuffle = False, 37 | seed = args.seed, 38 | print_info = True, 39 | label_dict = {'subtype_1':0, 'subtype_2':1, 'subtype_3':2}, 40 | patient_strat= True, 41 | patient_voting='maj', 42 | ignore=[]) 43 | 44 | else: 45 | raise NotImplementedError 46 | 47 | num_slides_cls = np.array([len(cls_ids) for cls_ids in dataset.patient_cls_ids]) 48 | val_num = np.round(num_slides_cls * args.val_frac).astype(int) 49 | test_num = np.round(num_slides_cls * args.test_frac).astype(int) 50 | 51 | if __name__ == '__main__': 52 | if args.label_frac > 0: 53 | label_fracs = [args.label_frac] 54 | else: 55 | label_fracs = [0.1, 0.25, 0.5, 0.75, 1.0] 56 | 57 | for lf in label_fracs: 58 | split_dir = 'splits/test' 59 | os.makedirs(split_dir, exist_ok=True) 60 | dataset.create_splits(k = args.k, val_num = val_num, test_num = test_num, label_frac=lf) 61 | for i in range(args.k): 62 | dataset.set_splits() 63 | descriptor_df = dataset.test_split_gen(return_descriptor=True) 64 | splits = [dataset.return_splits(from_id=True)] 65 | save_splits(splits, ['train'], os.path.join(split_dir, 'splits_{}.csv'.format(i))) 66 | save_splits(splits, ['train'], os.path.join(split_dir, 'splits_{}_bool.csv'.format(i)), boolean_style=True) 67 | # descriptor_df.to_csv(os.path.join(split_dir, 'splits_{}_descriptor.csv'.format(i))) 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /MIL/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/MIL/datasets/__init__.py -------------------------------------------------------------------------------- /MIL/datasets/dataset_h5.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | import numpy as np 5 | import pandas as pd 6 | import math 7 | import re 8 | import pdb 9 | import pickle 10 | 11 | from torch.utils.data import Dataset, DataLoader, sampler 12 | from torchvision import transforms, utils, models 13 | import torch.nn.functional as F 14 | 15 | from PIL import Image 16 | import h5py 17 | 18 | from random import randrange 19 | 20 | def eval_transforms(pretrained=False): 21 | if pretrained: 22 | mean = (0.485, 0.456, 0.406) 23 | std = (0.229, 0.224, 0.225) 24 | 25 | else: 26 | mean = (0.5,0.5,0.5) 27 | std = (0.5,0.5,0.5) 28 | 29 | trnsfrms_val = transforms.Compose( 30 | [ 31 | transforms.ToTensor(), 32 | transforms.Normalize(mean = mean, std = std) 33 | ] 34 | ) 35 | 36 | return trnsfrms_val 37 | 38 | class Whole_Slide_Bag(Dataset): 39 | def __init__(self, 40 | file_path, 41 | pretrained=False, 42 | custom_transforms=None, 43 | target_patch_size=-1, 44 | ): 45 | """ 46 | Args: 47 | file_path (string): Path to the .h5 file containing patched data. 48 | pretrained (bool): Use ImageNet transforms 49 | custom_transforms (callable, optional): Optional transform to be applied on a sample 50 | """ 51 | self.pretrained=pretrained 52 | if target_patch_size > 0: 53 | self.target_patch_size = (target_patch_size, target_patch_size) 54 | else: 55 | self.target_patch_size = None 56 | 57 | if not custom_transforms: 58 | self.roi_transforms = eval_transforms(pretrained=pretrained) 59 | else: 60 | self.roi_transforms = custom_transforms 61 | 62 | self.file_path = file_path 63 | 64 | with h5py.File(self.file_path, "r") as f: 65 | dset = f['imgs'] 66 | self.length = len(dset) 67 | 68 | self.summary() 69 | 70 | def __len__(self): 71 | return self.length 72 | 73 | def summary(self): 74 | hdf5_file = h5py.File(self.file_path, "r") 75 | dset = hdf5_file['imgs'] 76 | for name, value in dset.attrs.items(): 77 | print(name, value) 78 | 79 | print('pretrained:', self.pretrained) 80 | print('transformations:', self.roi_transforms) 81 | if self.target_patch_size is not None: 82 | print('target_size: ', self.target_patch_size) 83 | 84 | def __getitem__(self, idx): 85 | with h5py.File(self.file_path,'r') as hdf5_file: 86 | img = hdf5_file['imgs'][idx] 87 | coord = hdf5_file['coords'][idx] 88 | 89 | img = Image.fromarray(img) 90 | if self.target_patch_size is not None: 91 | img = img.resize(self.target_patch_size) 92 | img = self.roi_transforms(img).unsqueeze(0) 93 | return img, coord 94 | 95 | class Whole_Slide_Bag_FP(Dataset): 96 | def __init__(self, 97 | file_path, 98 | wsi, 99 | pretrained=False, 100 | custom_transforms=None, 101 | custom_downsample=1, 102 | target_patch_size=-1 103 | ): 104 | """ 105 | Args: 106 | file_path (string): Path to the .h5 file containing patched data. 
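wsi: an OpenSlide handle to the source whole-slide image (patches are read on the fly via read_region at the stored coordinates)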
107 | pretrained (bool): Use ImageNet transforms 108 | custom_transforms (callable, optional): Optional transform to be applied on a sample 109 | custom_downsample (int): Custom defined downscale factor (overruled by target_patch_size) 110 | target_patch_size (int): Custom defined image size before embedding 111 | """ 112 | self.pretrained=pretrained 113 | self.wsi = wsi 114 | if not custom_transforms: 115 | self.roi_transforms = eval_transforms(pretrained=pretrained) 116 | else: 117 | self.roi_transforms = custom_transforms 118 | 119 | self.file_path = file_path 120 | 121 | with h5py.File(self.file_path, "r") as f: 122 | dset = f['coords'] 123 | self.patch_level = f['coords'].attrs['patch_level'] 124 | self.patch_size = f['coords'].attrs['patch_size'] 125 | self.length = len(dset) 126 | if target_patch_size > 0: 127 | self.target_patch_size = (target_patch_size, ) * 2 128 | elif custom_downsample > 1: 129 | self.target_patch_size = (self.patch_size // custom_downsample, ) * 2 130 | else: 131 | self.target_patch_size = None 132 | # self.summary() 133 | 134 | def __len__(self): 135 | return self.length 136 | 137 | def summary(self): 138 | hdf5_file = h5py.File(self.file_path, "r") 139 | dset = hdf5_file['coords'] 140 | for name, value in dset.attrs.items(): 141 | print(name, value) 142 | 143 | print('\nfeature extraction settings') 144 | print('target patch size: ', self.target_patch_size) 145 | print('pretrained: ', self.pretrained) 146 | print('transformations: ', self.roi_transforms) 147 | 148 | def __getitem__(self, idx): 149 | with h5py.File(self.file_path,'r') as hdf5_file: 150 | coord = hdf5_file['coords'][idx] 151 | img = self.wsi.read_region(coord, self.patch_level, (self.patch_size, self.patch_size)).convert('RGB') 152 | 153 | if self.target_patch_size is not None: 154 | img = img.resize(self.target_patch_size) 155 | img = self.roi_transforms(img).unsqueeze(0) 156 | return img, coord 157 | 158 | class Dataset_All_Bags(Dataset): 159 | 160 | def __init__(self, csv_path): 161 | self.df = pd.read_csv(csv_path) 162 | 163 | def __len__(self): 164 | return len(self.df) 165 | 166 | def __getitem__(self, idx): 167 | return self.df['slide_id'][idx] 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /MIL/datasets/wsi_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import pandas as pd 3 | import numpy as np 4 | import time 5 | import pdb 6 | import PIL.Image as Image 7 | import h5py 8 | from torch.utils.data import Dataset 9 | import torch 10 | from wsi_core.util_classes import Contour_Checking_fn, isInContourV1, isInContourV2, isInContourV3_Easy, isInContourV3_Hard 11 | 12 | def default_transforms(mean = (0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 13 | t = transforms.Compose( 14 | [transforms.ToTensor(), 15 | transforms.Normalize(mean = mean, std = std)]) 16 | return t 17 | 18 | def get_contour_check_fn(contour_fn='four_pt_hard', cont=None, ref_patch_size=None, center_shift=None): 19 | if contour_fn == 'four_pt_hard': 20 | cont_check_fn = isInContourV3_Hard(contour=cont, patch_size=ref_patch_size, center_shift=center_shift) 21 | elif contour_fn == 'four_pt_easy': 22 | cont_check_fn = isInContourV3_Easy(contour=cont, patch_size=ref_patch_size, center_shift=0.5) 23 | elif contour_fn == 'center': 24 | cont_check_fn = isInContourV2(contour=cont, patch_size=ref_patch_size) 25 | elif contour_fn == 'basic': 26 | cont_check_fn = 
isInContourV1(contour=cont) 27 | else: 28 | raise NotImplementedError 29 | return cont_check_fn 30 | 31 | 32 | 33 | class Wsi_Region(Dataset): 34 | ''' 35 | args: 36 | wsi_object: instance of WholeSlideImage wrapper over a WSI 37 | top_left: tuple of coordinates representing the top left corner of WSI region (Default: None) 38 | bot_right: tuple of coordinates representing the bottom right corner of WSI region (Default: None) 39 | level: downsample level at which to process the WSI region 40 | patch_size: tuple of width, height representing the patch size 41 | step_size: tuple of w_step, h_step representing the step size 42 | contour_fn (str): 43 | contour checking fn to use 44 | choice of ['four_pt_hard', 'four_pt_easy', 'center', 'basic'] (Default: 'four_pt_hard') 45 | t: custom torchvision transformation to apply 46 | custom_downsample (int): additional downscale factor to apply 47 | use_center_shift: for 'four_pt_hard' contour check, how far out to shift the 4 points 48 | ''' 49 | def __init__(self, wsi_object, top_left=None, bot_right=None, level=0, 50 | patch_size = (256, 256), step_size=(256, 256), 51 | contour_fn='four_pt_hard', 52 | t=None, custom_downsample=1, use_center_shift=False): 53 | 54 | self.custom_downsample = custom_downsample 55 | 56 | # downscale factor in reference to level 0 57 | self.ref_downsample = wsi_object.level_downsamples[level] 58 | # patch size in reference to level 0 59 | self.ref_size = tuple((np.array(patch_size) * np.array(self.ref_downsample)).astype(int)) 60 | 61 | if self.custom_downsample > 1: 62 | self.target_patch_size = patch_size 63 | patch_size = tuple((np.array(patch_size) * np.array(self.ref_downsample) * custom_downsample).astype(int)) 64 | step_size = tuple((np.array(step_size) * custom_downsample).astype(int)) 65 | self.ref_size = patch_size 66 | else: 67 | step_size = tuple((np.array(step_size)).astype(int)) 68 | self.ref_size = tuple((np.array(patch_size) * np.array(self.ref_downsample)).astype(int)) 69 | 70 | self.wsi = wsi_object.wsi 71 | self.level = level 72 | self.patch_size = patch_size 73 | 74 | if not use_center_shift: 75 | center_shift = 0.
76 | else: 77 | overlap = 1 - float(step_size[0] / patch_size[0]) 78 | if overlap < 0.25: 79 | center_shift = 0.375 80 | elif overlap >= 0.25 and overlap < 0.75: 81 | center_shift = 0.5 82 | elif overlap >=0.75 and overlap < 0.95: 83 | center_shift = 0.5 84 | else: 85 | center_shift = 0.625 86 | #center_shift = 0.375 # 25% overlap 87 | #center_shift = 0.625 #50%, 75% overlap 88 | #center_shift = 1.0 #95% overlap 89 | 90 | filtered_coords = [] 91 | #iterate through tissue contours for valid patch coordinates 92 | for cont_idx, contour in enumerate(wsi_object.contours_tissue): 93 | print('processing {}/{} contours'.format(cont_idx, len(wsi_object.contours_tissue))) 94 | cont_check_fn = get_contour_check_fn(contour_fn, contour, self.ref_size[0], center_shift) 95 | coord_results, _ = wsi_object.process_contour(contour, wsi_object.holes_tissue[cont_idx], level, '', 96 | patch_size = patch_size[0], step_size = step_size[0], contour_fn=cont_check_fn, 97 | use_padding=True, top_left = top_left, bot_right = bot_right) 98 | if len(coord_results) > 0: 99 | filtered_coords.append(coord_results['coords']) 100 | 101 | coords=np.vstack(filtered_coords) 102 | 103 | self.coords = coords 104 | print('filtered a total of {} coordinates'.format(len(self.coords))) 105 | 106 | # apply transformation 107 | if t is None: 108 | self.transforms = default_transforms() 109 | else: 110 | self.transforms = t 111 | 112 | def __len__(self): 113 | return len(self.coords) 114 | 115 | def __getitem__(self, idx): 116 | coord = self.coords[idx] 117 | patch = self.wsi.read_region(tuple(coord), self.level, self.patch_size).convert('RGB') 118 | if self.custom_downsample > 1: 119 | patch = patch.resize(self.target_patch_size) 120 | patch = self.transforms(patch).unsqueeze(0) 121 | return patch, coord 122 | -------------------------------------------------------------------------------- /MIL/extract_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from math import floor 4 | import os 5 | import random 6 | import numpy as np 7 | import pdb 8 | import time 9 | from datasets.dataset_h5 import Dataset_All_Bags, Whole_Slide_Bag 10 | from torch.utils.data import DataLoader 11 | from models.resnet_custom import resnet50_baseline 12 | import argparse 13 | from utils.utils import print_network, collate_features 14 | from utils.file_utils import save_hdf5 15 | from PIL import Image 16 | import h5py 17 | 18 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 19 | 20 | 21 | def compute_w_loader(file_path, output_path, model, batch_size = 8, verbose = 0, 22 | print_every=20, pretrained=True, target_patch_size=-1): 23 | """ 24 | args: 25 | file_path: directory of bag (.h5 file) 26 | output_path: directory to save computed features (.h5 file) 27 | model: pytorch model 28 | batch_size: batch_size for computing features in batches 29 | verbose: level of feedback 30 | pretrained: use weights pretrained on imagenet 31 | """ 32 | dataset = Whole_Slide_Bag(file_path=file_path, pretrained=pretrained, 33 | target_patch_size=target_patch_size) 34 | x, y = dataset[0] 35 | kwargs = {'num_workers': 4, 'pin_memory': True} if device.type == "cuda" else {} 36 | loader = DataLoader(dataset=dataset, batch_size=batch_size, **kwargs, collate_fn=collate_features) 37 | 38 | if verbose > 0: 39 | print('processing {}: total of {} batches'.format(file_path,len(loader))) 40 | 41 | mode = 'w' 42 | for count, (batch, coords) in enumerate(loader): 
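# the first batch creates the output .h5 in 'w' (truncate) mode; save_hdf5 then appends later batches in 'a' mode by resizing its extendable datasets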
43 | with torch.no_grad(): 44 | if count % print_every == 0: 45 | print('batch {}/{}, {} files processed'.format(count, len(loader), count * batch_size)) 46 | batch = batch.to(device, non_blocking=True) 47 | mini_bs = coords.shape[0] 48 | 49 | features = model(batch) 50 | 51 | features = features.cpu().numpy() 52 | 53 | asset_dict = {'features': features, 'coords': coords} 54 | save_hdf5(output_path, asset_dict, attr_dict= None, mode=mode) 55 | mode = 'a' 56 | 57 | return output_path 58 | 59 | 60 | parser = argparse.ArgumentParser(description='Feature Extraction') 61 | parser.add_argument('--data_dir', type=str) 62 | parser.add_argument('--csv_path', type=str) 63 | parser.add_argument('--feat_dir', type=str) 64 | parser.add_argument('--batch_size', type=int, default=256) 65 | parser.add_argument('--slide_ext', type=str, default= '.svs') 66 | parser.add_argument('--no_auto_skip', default=False, action='store_true') 67 | parser.add_argument('--target_patch_size', type=int, default=-1, 68 | help='the desired size of patches for optional scaling before feature embedding') 69 | args = parser.parse_args() 70 | 71 | 72 | if __name__ == '__main__': 73 | 74 | print('initializing dataset') 75 | csv_path = args.csv_path 76 | bags_dataset = Dataset_All_Bags(csv_path) 77 | 78 | os.makedirs(args.feat_dir, exist_ok=True); os.makedirs(os.path.join(args.feat_dir, 'pt_files'), exist_ok=True); os.makedirs(os.path.join(args.feat_dir, 'h5_files'), exist_ok=True) # create the output subfolders written to below, as extract_features_fp.py does 79 | dest_files = os.listdir(os.path.join(args.feat_dir, 'pt_files')) 80 | 81 | print('loading model checkpoint') 82 | model = resnet50_baseline(pretrained=True) 83 | model = model.to(device) 84 | 85 | # print_network(model) 86 | if torch.cuda.device_count() > 1: 87 | model = nn.DataParallel(model) 88 | 89 | model.eval() 90 | total = len(bags_dataset) 91 | 92 | for bag_candidate_idx in range(total): 93 | slide_id = bags_dataset[bag_candidate_idx].split(args.slide_ext)[0] 94 | bag_name = slide_id + '.h5' 95 | bag_candidate = os.path.join(args.data_dir, 'patches', bag_name) 96 | 97 | print('\nprogress: {}/{}'.format(bag_candidate_idx, total)) 98 | print(bag_name) 99 | if not args.no_auto_skip and slide_id+'.pt' in dest_files: 100 | print('skipped {}'.format(slide_id)) 101 | continue 102 | 103 | output_path = os.path.join(args.feat_dir, 'h5_files', bag_name) 104 | file_path = bag_candidate 105 | time_start = time.time() 106 | output_file_path = compute_w_loader(file_path, output_path, 107 | model = model, batch_size = args.batch_size, 108 | verbose = 1, print_every = 20, 109 | target_patch_size=args.target_patch_size) 110 | time_elapsed = time.time() - time_start 111 | print('\ncomputing features for {} took {} s'.format(output_file_path, time_elapsed)) 112 | file = h5py.File(output_file_path, "r") 113 | 114 | features = file['features'][:] 115 | print('features size: ', features.shape) 116 | print('coordinates size: ', file['coords'].shape) 117 | features = torch.from_numpy(features) 118 | bag_base, _ = os.path.splitext(bag_name) 119 | torch.save(features, os.path.join(args.feat_dir, 'pt_files', bag_base+'.pt')) -------------------------------------------------------------------------------- /MIL/extract_features_fp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from math import floor 4 | import os 5 | import random 6 | import numpy as np 7 | import pdb 8 | import time 9 | from datasets.dataset_h5 import Dataset_All_Bags, Whole_Slide_Bag_FP 10 | from torch.utils.data import DataLoader 11 | from models.resnet_custom import resnet50_baseline 12 | import argparse 13 | from utils.utils import print_network, collate_features 14 | 
from utils.file_utils import save_hdf5 15 | from PIL import Image 16 | import h5py 17 | import openslide 18 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 19 | 20 | def compute_w_loader(file_path, output_path, wsi, model, 21 | batch_size = 8, verbose = 0, print_every=20, pretrained=True, 22 | custom_downsample=1, target_patch_size=-1): 23 | """ 24 | args: 25 | file_path: directory of bag (.h5 file) 26 | output_path: directory to save computed features (.h5 file) 27 | model: pytorch model 28 | batch_size: batch_size for computing features in batches 29 | verbose: level of feedback 30 | pretrained: use weights pretrained on imagenet 31 | custom_downsample: custom defined downscale factor of image patches 32 | target_patch_size: custom defined, rescaled image size before embedding 33 | """ 34 | dataset = Whole_Slide_Bag_FP(file_path=file_path, wsi=wsi, pretrained=pretrained, 35 | custom_downsample=custom_downsample, target_patch_size=target_patch_size) 36 | x, y = dataset[0] 37 | kwargs = {'num_workers': 4, 'pin_memory': True} if device.type == "cuda" else {} 38 | loader = DataLoader(dataset=dataset, batch_size=batch_size, **kwargs, collate_fn=collate_features) 39 | 40 | if verbose > 0: 41 | print('processing {}: total of {} batches'.format(file_path,len(loader))) 42 | 43 | mode = 'w' 44 | for count, (batch, coords) in enumerate(loader): 45 | with torch.no_grad(): 46 | if count % print_every == 0: 47 | print('batch {}/{}, {} files processed'.format(count, len(loader), count * batch_size)) 48 | batch = batch.to(device, non_blocking=True) 49 | 50 | features = model(batch) 51 | features = features.cpu().numpy() 52 | 53 | asset_dict = {'features': features, 'coords': coords} 54 | save_hdf5(output_path, asset_dict, attr_dict= None, mode=mode) 55 | mode = 'a' 56 | 57 | return output_path 58 | 59 | 60 | parser = argparse.ArgumentParser(description='Feature Extraction') 61 | parser.add_argument('--data_h5_dir', type=str, default=None) 62 | parser.add_argument('--data_slide_dir', type=str, default=None) 63 | parser.add_argument('--slide_ext', type=str, default= '.svs') 64 | parser.add_argument('--csv_path', type=str, default=None) 65 | parser.add_argument('--feat_dir', type=str, default=None) 66 | parser.add_argument('--batch_size', type=int, default=256) 67 | parser.add_argument('--no_auto_skip', default=False, action='store_true') 68 | parser.add_argument('--custom_downsample', type=int, default=1) 69 | parser.add_argument('--target_patch_size', type=int, default=-1) 70 | args = parser.parse_args() 71 | 72 | 73 | if __name__ == '__main__': 74 | 75 | print('initializing dataset') 76 | csv_path = args.csv_path 77 | if csv_path is None: 78 | raise NotImplementedError 79 | 80 | bags_dataset = Dataset_All_Bags(csv_path) 81 | 82 | os.makedirs(args.feat_dir, exist_ok=True) 83 | os.makedirs(os.path.join(args.feat_dir, 'pt_files'), exist_ok=True) 84 | os.makedirs(os.path.join(args.feat_dir, 'h5_files'), exist_ok=True) 85 | dest_files = os.listdir(os.path.join(args.feat_dir, 'pt_files')) 86 | 87 | print('loading model checkpoint') 88 | model = resnet50_baseline(pretrained=True) 89 | model = model.to(device) 90 | 91 | # print_network(model) 92 | if torch.cuda.device_count() > 1: 93 | model = nn.DataParallel(model) 94 | 95 | model.eval() 96 | total = len(bags_dataset) 97 | 98 | for bag_candidate_idx in range(total): 99 | slide_id = bags_dataset[bag_candidate_idx].split(args.slide_ext)[0] 100 | bag_name = slide_id+'.h5' 101 | h5_file_path = os.path.join(args.data_h5_dir, 
'patches', bag_name) 102 | slide_file_path = os.path.join(args.data_slide_dir, slide_id+args.slide_ext) 103 | print('\nprogress: {}/{}'.format(bag_candidate_idx, total)) 104 | print(slide_id) 105 | 106 | if not args.no_auto_skip and slide_id+'.pt' in dest_files: 107 | print('skipped {}'.format(slide_id)) 108 | continue 109 | 110 | output_path = os.path.join(args.feat_dir, 'h5_files', bag_name) 111 | time_start = time.time() 112 | wsi = openslide.open_slide(slide_file_path) 113 | output_file_path = compute_w_loader(h5_file_path, output_path, wsi, 114 | model = model, batch_size = args.batch_size, verbose = 1, print_every = 20, 115 | custom_downsample=args.custom_downsample, target_patch_size=args.target_patch_size) 116 | time_elapsed = time.time() - time_start 117 | print('\ncomputing features for {} took {} s'.format(output_file_path, time_elapsed)) 118 | file = h5py.File(output_file_path, "r") 119 | 120 | features = file['features'][:] 121 | print('features size: ', features.shape) 122 | print('coordinates size: ', file['coords'].shape) 123 | features = torch.from_numpy(features) 124 | bag_base, _ = os.path.splitext(bag_name) 125 | torch.save(features, os.path.join(args.feat_dir, 'pt_files', bag_base+'.pt')) 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /MIL/models/model_mil.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.utils import initialize_weights 5 | import numpy as np 6 | 7 | class MIL_fc(nn.Module): 8 | def __init__(self, gate = True, size_arg = "small", dropout = False, n_classes = 2, top_k=1, embedding_dim=1536): 9 | super(MIL_fc, self).__init__() 10 | assert n_classes == 2 11 | self.size_dict = {"small": [embedding_dim, 512]} 12 | size = self.size_dict[size_arg] 13 | fc = [nn.Linear(size[0], size[1]), nn.ReLU()] 14 | if dropout: 15 | fc.append(nn.Dropout(0.25)) 16 | 17 | fc.append(nn.Linear(size[1], n_classes)) 18 | self.classifier= nn.Sequential(*fc) 19 | initialize_weights(self) 20 | self.top_k=top_k 21 | 22 | def relocate(self): 23 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | self.classifier.to(device) 25 | 26 | def forward(self, h, return_features=False): 27 | if return_features: 28 | h = self.classifier.module[:3](h) 29 | logits = self.classifier.module[3](h) 30 | else: 31 | logits = self.classifier(h) # K x 1 32 | 33 | y_probs = F.softmax(logits, dim = 1) 34 | top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1,) 35 | top_instance = torch.index_select(logits, dim=0, index=top_instance_idx) 36 | Y_hat = torch.topk(top_instance, 1, dim = 1)[1] 37 | Y_prob = F.softmax(top_instance, dim = 1) 38 | results_dict = {} 39 | 40 | if return_features: 41 | top_features = torch.index_select(h, dim=0, index=top_instance_idx) 42 | results_dict.update({'features': top_features}) 43 | return top_instance, Y_prob, Y_hat, y_probs, results_dict 44 | 45 | 46 | class MIL_fc_mc(nn.Module): 47 | def __init__(self, gate = True, size_arg = "small", dropout = False, n_classes = 2, top_k=1): 48 | super(MIL_fc_mc, self).__init__() 49 | assert n_classes > 2 50 | self.size_dict = {"small": [1024, 512]} 51 | size = self.size_dict[size_arg] 52 | fc = [nn.Linear(size[0], size[1]), nn.ReLU()] 53 | if dropout: 54 | fc.append(nn.Dropout(0.25)) 55 | self.fc = nn.Sequential(*fc) 56 | 57 | self.classifiers = nn.ModuleList([nn.Linear(size[1], 1) for i in range(n_classes)]) 
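# one independent single-logit head per class: forward() assembles the K-class logit matrix column by column, then picks the single highest-scoring (instance, class) pair, which is why top_k is asserted to be 1 below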
58 | initialize_weights(self) 59 | self.top_k=top_k 60 | self.n_classes = n_classes 61 | assert self.top_k == 1 62 | 63 | def relocate(self): 64 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 65 | self.fc = self.fc.to(device) 66 | self.classifiers = self.classifiers.to(device) 67 | 68 | def forward(self, h, return_features=False): 69 | device = h.device 70 | 71 | h = self.fc(h) 72 | logits = torch.empty(h.size(0), self.n_classes).float().to(device) 73 | 74 | for c in range(self.n_classes): 75 | if isinstance(self.classifiers, nn.DataParallel): 76 | logits[:, c] = self.classifiers.module[c](h).squeeze(1) 77 | else: 78 | logits[:, c] = self.classifiers[c](h).squeeze(1) 79 | 80 | y_probs = F.softmax(logits, dim = 1) 81 | m = y_probs.view(1, -1).argmax(1) 82 | top_indices = torch.cat(((m // self.n_classes).view(-1, 1), (m % self.n_classes).view(-1, 1)), dim=1).view(-1, 1) 83 | top_instance = logits[top_indices[0]] 84 | 85 | Y_hat = top_indices[1] 86 | Y_prob = y_probs[top_indices[0]] 87 | 88 | results_dict = {} 89 | 90 | if return_features: 91 | top_features = torch.index_select(h, dim=0, index=top_indices[0]) 92 | results_dict.update({'features': top_features}) 93 | return top_instance, Y_prob, Y_hat, y_probs, results_dict 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /MIL/models/resnet_custom.py: -------------------------------------------------------------------------------- 1 | # modified from Pytorch official resnet.py 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | class Bottleneck_Baseline(nn.Module): 19 | expansion = 4 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(Bottleneck_Baseline, self).__init__() 23 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 29 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | class ResNet_Baseline(nn.Module): 57 | 58 | def __init__(self, block, layers): 59 | self.inplanes = 64 60 | super(ResNet_Baseline, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = 
nn.BatchNorm2d(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | self.layer1 = self._make_layer(block, 64, layers[0]) 67 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 68 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 69 | self.avgpool = nn.AdaptiveAvgPool2d(1) 70 | 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 74 | elif isinstance(m, nn.BatchNorm2d): 75 | nn.init.constant_(m.weight, 1) 76 | nn.init.constant_(m.bias, 0) 77 | 78 | def _make_layer(self, block, planes, blocks, stride=1): 79 | downsample = None 80 | if stride != 1 or self.inplanes != planes * block.expansion: 81 | downsample = nn.Sequential( 82 | nn.Conv2d(self.inplanes, planes * block.expansion, 83 | kernel_size=1, stride=stride, bias=False), 84 | nn.BatchNorm2d(planes * block.expansion), 85 | ) 86 | 87 | layers = [] 88 | layers.append(block(self.inplanes, planes, stride, downsample)) 89 | self.inplanes = planes * block.expansion 90 | for i in range(1, blocks): 91 | layers.append(block(self.inplanes, planes)) 92 | 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | x = self.conv1(x) 97 | x = self.bn1(x) 98 | x = self.relu(x) 99 | x = self.maxpool(x) 100 | 101 | x = self.layer1(x) 102 | x = self.layer2(x) 103 | x = self.layer3(x) 104 | 105 | x = self.avgpool(x) 106 | x = x.view(x.size(0), -1) 107 | 108 | return x 109 | 110 | def resnet50_baseline(pretrained=False): 111 | """Constructs a Modified ResNet-50 model. 112 | Args: 113 | pretrained (bool): If True, returns a model pre-trained on ImageNet 114 | """ 115 | model = ResNet_Baseline(Bottleneck_Baseline, [3, 4, 6, 3]) 116 | if pretrained: 117 | model = load_pretrained_weights(model, 'resnet50') 118 | return model 119 | 120 | def load_pretrained_weights(model, name): 121 | pretrained_dict = model_zoo.load_url(model_urls[name]) 122 | model.load_state_dict(pretrained_dict, strict=False) 123 | return model 124 | 125 | 126 | -------------------------------------------------------------------------------- /MIL/splits/test/splits_0.csv: -------------------------------------------------------------------------------- 1 | ,train 2 | 0,test_003 3 | 1,test_005 4 | 2,test_006 5 | 3,test_007 6 | 4,test_009 7 | 5,test_012 8 | 6,test_014 9 | 7,test_015 10 | 8,test_017 11 | 9,test_018 12 | 10,test_019 13 | 11,test_020 14 | 12,test_022 15 | 13,test_023 16 | 14,test_024 17 | 15,test_025 18 | 16,test_028 19 | 17,test_031 20 | 18,test_032 21 | 19,test_034 22 | 20,test_035 23 | 21,test_036 24 | 22,test_037 25 | 23,test_039 26 | 24,test_041 27 | 25,test_042 28 | 26,test_043 29 | 27,test_044 30 | 28,test_045 31 | 29,test_047 32 | 30,test_050 33 | 31,test_053 34 | 32,test_054 35 | 33,test_055 36 | 34,test_056 37 | 35,test_057 38 | 36,test_058 39 | 37,test_059 40 | 38,test_060 41 | 39,test_062 42 | 40,test_063 43 | 41,test_067 44 | 42,test_070 45 | 43,test_072 46 | 44,test_076 47 | 45,test_077 48 | 46,test_078 49 | 47,test_080 50 | 48,test_081 51 | 49,test_083 52 | 50,test_085 53 | 51,test_086 54 | 52,test_087 55 | 53,test_088 56 | 54,test_089 57 | 55,test_091 58 | 56,test_093 59 | 57,test_095 60 | 58,test_096 61 | 59,test_098 62 | 60,test_100 63 | 61,test_101 64 | 62,test_103 65 | 63,test_106 66 | 64,test_107 67 | 65,test_109 68 | 66,test_111 69 | 67,test_112 70 | 68,test_115 71 | 69,test_118 72 | 70,test_119 73 | 71,test_120 74 | 72,test_123 75 | 73,test_124 76 | 74,test_125 77 | 
75,test_126 78 | 76,test_127 79 | 77,test_128 80 | 78,test_129 81 | 79,test_130 82 | 80,test_001 83 | 81,test_002 84 | 82,test_004 85 | 83,test_008 86 | 84,test_010 87 | 85,test_011 88 | 86,test_013 89 | 87,test_016 90 | 88,test_021 91 | 89,test_026 92 | 90,test_027 93 | 91,test_029 94 | 92,test_030 95 | 93,test_033 96 | 94,test_038 97 | 95,test_040 98 | 96,test_046 99 | 97,test_048 100 | 98,test_051 101 | 99,test_052 102 | 100,test_061 103 | 101,test_064 104 | 102,test_065 105 | 103,test_066 106 | 104,test_068 107 | 105,test_069 108 | 106,test_071 109 | 107,test_073 110 | 108,test_074 111 | 109,test_075 112 | 110,test_079 113 | 111,test_082 114 | 112,test_084 115 | 113,test_090 116 | 114,test_092 117 | 115,test_094 118 | 116,test_097 119 | 117,test_099 120 | 118,test_102 121 | 119,test_104 122 | 120,test_105 123 | 121,test_108 124 | 122,test_110 125 | 123,test_113 126 | 124,test_114 127 | 125,test_116 128 | 126,test_117 129 | 127,test_121 130 | 128,test_122 131 | -------------------------------------------------------------------------------- /MIL/splits/test/splits_0_bool.csv: -------------------------------------------------------------------------------- 1 | ,train 2 | test_003,True 3 | test_005,True 4 | test_006,True 5 | test_007,True 6 | test_009,True 7 | test_012,True 8 | test_014,True 9 | test_015,True 10 | test_017,True 11 | test_018,True 12 | test_019,True 13 | test_020,True 14 | test_022,True 15 | test_023,True 16 | test_024,True 17 | test_025,True 18 | test_028,True 19 | test_031,True 20 | test_032,True 21 | test_034,True 22 | test_035,True 23 | test_036,True 24 | test_037,True 25 | test_039,True 26 | test_041,True 27 | test_042,True 28 | test_043,True 29 | test_044,True 30 | test_045,True 31 | test_047,True 32 | test_050,True 33 | test_053,True 34 | test_054,True 35 | test_055,True 36 | test_056,True 37 | test_057,True 38 | test_058,True 39 | test_059,True 40 | test_060,True 41 | test_062,True 42 | test_063,True 43 | test_067,True 44 | test_070,True 45 | test_072,True 46 | test_076,True 47 | test_077,True 48 | test_078,True 49 | test_080,True 50 | test_081,True 51 | test_083,True 52 | test_085,True 53 | test_086,True 54 | test_087,True 55 | test_088,True 56 | test_089,True 57 | test_091,True 58 | test_093,True 59 | test_095,True 60 | test_096,True 61 | test_098,True 62 | test_100,True 63 | test_101,True 64 | test_103,True 65 | test_106,True 66 | test_107,True 67 | test_109,True 68 | test_111,True 69 | test_112,True 70 | test_115,True 71 | test_118,True 72 | test_119,True 73 | test_120,True 74 | test_123,True 75 | test_124,True 76 | test_125,True 77 | test_126,True 78 | test_127,True 79 | test_128,True 80 | test_129,True 81 | test_130,True 82 | test_001,True 83 | test_002,True 84 | test_004,True 85 | test_008,True 86 | test_010,True 87 | test_011,True 88 | test_013,True 89 | test_016,True 90 | test_021,True 91 | test_026,True 92 | test_027,True 93 | test_029,True 94 | test_030,True 95 | test_033,True 96 | test_038,True 97 | test_040,True 98 | test_046,True 99 | test_048,True 100 | test_051,True 101 | test_052,True 102 | test_061,True 103 | test_064,True 104 | test_065,True 105 | test_066,True 106 | test_068,True 107 | test_069,True 108 | test_071,True 109 | test_073,True 110 | test_074,True 111 | test_075,True 112 | test_079,True 113 | test_082,True 114 | test_084,True 115 | test_090,True 116 | test_092,True 117 | test_094,True 118 | test_097,True 119 | test_099,True 120 | test_102,True 121 | test_104,True 122 | test_105,True 123 | test_108,True 124 | test_110,True 125 | 
test_113,True 126 | test_114,True 127 | test_116,True 128 | test_117,True 129 | test_121,True 130 | test_122,True 131 | -------------------------------------------------------------------------------- /MIL/splits/train/splits_0.csv: -------------------------------------------------------------------------------- 1 | ,train 2 | 0,normal_001 3 | 1,normal_002 4 | 2,normal_003 5 | 3,normal_004 6 | 4,normal_005 7 | 5,normal_006 8 | 6,normal_007 9 | 7,normal_008 10 | 8,normal_009 11 | 9,normal_010 12 | 10,normal_011 13 | 11,normal_012 14 | 12,normal_013 15 | 13,normal_014 16 | 14,normal_015 17 | 15,normal_016 18 | 16,normal_017 19 | 17,normal_018 20 | 18,normal_019 21 | 19,normal_020 22 | 20,normal_021 23 | 21,normal_022 24 | 22,normal_023 25 | 23,normal_024 26 | 24,normal_025 27 | 25,normal_026 28 | 26,normal_027 29 | 27,normal_028 30 | 28,normal_029 31 | 29,normal_030 32 | 30,normal_031 33 | 31,normal_032 34 | 32,normal_033 35 | 33,normal_034 36 | 34,normal_035 37 | 35,normal_036 38 | 36,normal_037 39 | 37,normal_038 40 | 38,normal_039 41 | 39,normal_040 42 | 40,normal_041 43 | 41,normal_042 44 | 42,normal_043 45 | 43,normal_044 46 | 44,normal_045 47 | 45,normal_046 48 | 46,normal_047 49 | 47,normal_048 50 | 48,normal_049 51 | 49,normal_050 52 | 50,normal_051 53 | 51,normal_052 54 | 52,normal_053 55 | 53,normal_054 56 | 54,normal_055 57 | 55,normal_056 58 | 56,normal_057 59 | 57,normal_058 60 | 58,normal_059 61 | 59,normal_060 62 | 60,normal_061 63 | 61,normal_062 64 | 62,normal_063 65 | 63,normal_064 66 | 64,normal_065 67 | 65,normal_066 68 | 66,normal_067 69 | 67,normal_068 70 | 68,normal_069 71 | 69,normal_070 72 | 70,normal_071 73 | 71,normal_072 74 | 72,normal_073 75 | 73,normal_074 76 | 74,normal_075 77 | 75,normal_076 78 | 76,normal_077 79 | 77,normal_078 80 | 78,normal_079 81 | 79,normal_080 82 | 80,normal_081 83 | 81,normal_082 84 | 82,normal_083 85 | 83,normal_084 86 | 84,normal_085 87 | 85,normal_087 88 | 86,normal_088 89 | 87,normal_089 90 | 88,normal_090 91 | 89,normal_091 92 | 90,normal_092 93 | 91,normal_093 94 | 92,normal_094 95 | 93,normal_095 96 | 94,normal_096 97 | 95,normal_097 98 | 96,normal_098 99 | 97,normal_099 100 | 98,normal_100 101 | 99,normal_101 102 | 100,normal_102 103 | 101,normal_103 104 | 102,normal_104 105 | 103,normal_105 106 | 104,normal_106 107 | 105,normal_107 108 | 106,normal_108 109 | 107,normal_109 110 | 108,normal_110 111 | 109,normal_111 112 | 110,normal_112 113 | 111,normal_113 114 | 112,normal_114 115 | 113,normal_115 116 | 114,normal_116 117 | 115,normal_117 118 | 116,normal_118 119 | 117,normal_119 120 | 118,normal_120 121 | 119,normal_121 122 | 120,normal_122 123 | 121,normal_123 124 | 122,normal_124 125 | 123,normal_125 126 | 124,normal_126 127 | 125,normal_127 128 | 126,normal_128 129 | 127,normal_129 130 | 128,normal_130 131 | 129,normal_131 132 | 130,normal_132 133 | 131,normal_133 134 | 132,normal_134 135 | 133,normal_135 136 | 134,normal_136 137 | 135,normal_137 138 | 136,normal_138 139 | 137,normal_139 140 | 138,normal_140 141 | 139,normal_141 142 | 140,normal_142 143 | 141,normal_143 144 | 142,normal_144 145 | 143,normal_145 146 | 144,normal_146 147 | 145,normal_147 148 | 146,normal_148 149 | 147,normal_149 150 | 148,normal_150 151 | 149,normal_151 152 | 150,normal_152 153 | 151,normal_153 154 | 152,normal_154 155 | 153,normal_155 156 | 154,normal_156 157 | 155,normal_157 158 | 156,normal_158 159 | 157,normal_159 160 | 158,tumor_001 161 | 159,tumor_002 162 | 160,tumor_003 163 | 161,tumor_004 164 | 162,tumor_005 165 | 163,tumor_006 166 | 
164,tumor_007 167 | 165,tumor_008 168 | 166,tumor_009 169 | 167,tumor_010 170 | 168,tumor_011 171 | 169,tumor_012 172 | 170,tumor_013 173 | 171,tumor_014 174 | 172,tumor_015 175 | 173,tumor_016 176 | 174,tumor_017 177 | 175,tumor_018 178 | 176,tumor_019 179 | 177,tumor_020 180 | 178,tumor_021 181 | 179,tumor_022 182 | 180,tumor_023 183 | 181,tumor_024 184 | 182,tumor_025 185 | 183,tumor_026 186 | 184,tumor_027 187 | 185,tumor_028 188 | 186,tumor_029 189 | 187,tumor_030 190 | 188,tumor_031 191 | 189,tumor_032 192 | 190,tumor_033 193 | 191,tumor_034 194 | 192,tumor_035 195 | 193,tumor_036 196 | 194,tumor_037 197 | 195,tumor_038 198 | 196,tumor_039 199 | 197,tumor_040 200 | 198,tumor_041 201 | 199,tumor_042 202 | 200,tumor_043 203 | 201,tumor_044 204 | 202,tumor_045 205 | 203,tumor_046 206 | 204,tumor_047 207 | 205,tumor_048 208 | 206,tumor_049 209 | 207,tumor_050 210 | 208,tumor_051 211 | 209,tumor_052 212 | 210,tumor_053 213 | 211,tumor_054 214 | 212,tumor_055 215 | 213,tumor_056 216 | 214,tumor_057 217 | 215,tumor_058 218 | 216,tumor_059 219 | 217,tumor_060 220 | 218,tumor_061 221 | 219,tumor_062 222 | 220,tumor_063 223 | 221,tumor_064 224 | 222,tumor_065 225 | 223,tumor_066 226 | 224,tumor_067 227 | 225,tumor_068 228 | 226,tumor_069 229 | 227,tumor_070 230 | 228,tumor_071 231 | 229,tumor_072 232 | 230,tumor_073 233 | 231,tumor_074 234 | 232,tumor_075 235 | 233,tumor_076 236 | 234,tumor_077 237 | 235,tumor_078 238 | 236,tumor_079 239 | 237,tumor_080 240 | 238,tumor_081 241 | 239,tumor_082 242 | 240,tumor_083 243 | 241,tumor_084 244 | 242,tumor_085 245 | 243,tumor_086 246 | 244,tumor_087 247 | 245,tumor_088 248 | 246,tumor_089 249 | 247,tumor_090 250 | 248,tumor_091 251 | 249,tumor_092 252 | 250,tumor_093 253 | 251,tumor_094 254 | 252,tumor_095 255 | 253,tumor_096 256 | 254,tumor_097 257 | 255,tumor_098 258 | 256,tumor_099 259 | 257,tumor_100 260 | 258,tumor_101 261 | 259,tumor_102 262 | 260,tumor_103 263 | 261,tumor_104 264 | 262,tumor_105 265 | 263,tumor_106 266 | 264,tumor_107 267 | 265,tumor_108 268 | 266,tumor_109 269 | 267,tumor_110 270 | -------------------------------------------------------------------------------- /MIL/splits/train/splits_0_bool.csv: -------------------------------------------------------------------------------- 1 | ,train 2 | normal_001,True 3 | normal_002,True 4 | normal_003,True 5 | normal_004,True 6 | normal_005,True 7 | normal_006,True 8 | normal_007,True 9 | normal_008,True 10 | normal_009,True 11 | normal_010,True 12 | normal_011,True 13 | normal_012,True 14 | normal_013,True 15 | normal_014,True 16 | normal_015,True 17 | normal_016,True 18 | normal_017,True 19 | normal_018,True 20 | normal_019,True 21 | normal_020,True 22 | normal_021,True 23 | normal_022,True 24 | normal_023,True 25 | normal_024,True 26 | normal_025,True 27 | normal_026,True 28 | normal_027,True 29 | normal_028,True 30 | normal_029,True 31 | normal_030,True 32 | normal_031,True 33 | normal_032,True 34 | normal_033,True 35 | normal_034,True 36 | normal_035,True 37 | normal_036,True 38 | normal_037,True 39 | normal_038,True 40 | normal_039,True 41 | normal_040,True 42 | normal_041,True 43 | normal_042,True 44 | normal_043,True 45 | normal_044,True 46 | normal_045,True 47 | normal_046,True 48 | normal_047,True 49 | normal_048,True 50 | normal_049,True 51 | normal_050,True 52 | normal_051,True 53 | normal_052,True 54 | normal_053,True 55 | normal_054,True 56 | normal_055,True 57 | normal_056,True 58 | normal_057,True 59 | normal_058,True 60 | normal_059,True 61 | normal_060,True 62 | 
normal_061,True 63 | normal_062,True 64 | normal_063,True 65 | normal_064,True 66 | normal_065,True 67 | normal_066,True 68 | normal_067,True 69 | normal_068,True 70 | normal_069,True 71 | normal_070,True 72 | normal_071,True 73 | normal_072,True 74 | normal_073,True 75 | normal_074,True 76 | normal_075,True 77 | normal_076,True 78 | normal_077,True 79 | normal_078,True 80 | normal_079,True 81 | normal_080,True 82 | normal_081,True 83 | normal_082,True 84 | normal_083,True 85 | normal_084,True 86 | normal_085,True 87 | normal_087,True 88 | normal_088,True 89 | normal_089,True 90 | normal_090,True 91 | normal_091,True 92 | normal_092,True 93 | normal_093,True 94 | normal_094,True 95 | normal_095,True 96 | normal_096,True 97 | normal_097,True 98 | normal_098,True 99 | normal_099,True 100 | normal_100,True 101 | normal_101,True 102 | normal_102,True 103 | normal_103,True 104 | normal_104,True 105 | normal_105,True 106 | normal_106,True 107 | normal_107,True 108 | normal_108,True 109 | normal_109,True 110 | normal_110,True 111 | normal_111,True 112 | normal_112,True 113 | normal_113,True 114 | normal_114,True 115 | normal_115,True 116 | normal_116,True 117 | normal_117,True 118 | normal_118,True 119 | normal_119,True 120 | normal_120,True 121 | normal_121,True 122 | normal_122,True 123 | normal_123,True 124 | normal_124,True 125 | normal_125,True 126 | normal_126,True 127 | normal_127,True 128 | normal_128,True 129 | normal_129,True 130 | normal_130,True 131 | normal_131,True 132 | normal_132,True 133 | normal_133,True 134 | normal_134,True 135 | normal_135,True 136 | normal_136,True 137 | normal_137,True 138 | normal_138,True 139 | normal_139,True 140 | normal_140,True 141 | normal_141,True 142 | normal_142,True 143 | normal_143,True 144 | normal_144,True 145 | normal_145,True 146 | normal_146,True 147 | normal_147,True 148 | normal_148,True 149 | normal_149,True 150 | normal_150,True 151 | normal_151,True 152 | normal_152,True 153 | normal_153,True 154 | normal_154,True 155 | normal_155,True 156 | normal_156,True 157 | normal_157,True 158 | normal_158,True 159 | normal_159,True 160 | tumor_001,True 161 | tumor_002,True 162 | tumor_003,True 163 | tumor_004,True 164 | tumor_005,True 165 | tumor_006,True 166 | tumor_007,True 167 | tumor_008,True 168 | tumor_009,True 169 | tumor_010,True 170 | tumor_011,True 171 | tumor_012,True 172 | tumor_013,True 173 | tumor_014,True 174 | tumor_015,True 175 | tumor_016,True 176 | tumor_017,True 177 | tumor_018,True 178 | tumor_019,True 179 | tumor_020,True 180 | tumor_021,True 181 | tumor_022,True 182 | tumor_023,True 183 | tumor_024,True 184 | tumor_025,True 185 | tumor_026,True 186 | tumor_027,True 187 | tumor_028,True 188 | tumor_029,True 189 | tumor_030,True 190 | tumor_031,True 191 | tumor_032,True 192 | tumor_033,True 193 | tumor_034,True 194 | tumor_035,True 195 | tumor_036,True 196 | tumor_037,True 197 | tumor_038,True 198 | tumor_039,True 199 | tumor_040,True 200 | tumor_041,True 201 | tumor_042,True 202 | tumor_043,True 203 | tumor_044,True 204 | tumor_045,True 205 | tumor_046,True 206 | tumor_047,True 207 | tumor_048,True 208 | tumor_049,True 209 | tumor_050,True 210 | tumor_051,True 211 | tumor_052,True 212 | tumor_053,True 213 | tumor_054,True 214 | tumor_055,True 215 | tumor_056,True 216 | tumor_057,True 217 | tumor_058,True 218 | tumor_059,True 219 | tumor_060,True 220 | tumor_061,True 221 | tumor_062,True 222 | tumor_063,True 223 | tumor_064,True 224 | tumor_065,True 225 | tumor_066,True 226 | tumor_067,True 227 | tumor_068,True 228 | 
tumor_069,True 229 | tumor_070,True 230 | tumor_071,True 231 | tumor_072,True 232 | tumor_073,True 233 | tumor_074,True 234 | tumor_075,True 235 | tumor_076,True 236 | tumor_077,True 237 | tumor_078,True 238 | tumor_079,True 239 | tumor_080,True 240 | tumor_081,True 241 | tumor_082,True 242 | tumor_083,True 243 | tumor_084,True 244 | tumor_085,True 245 | tumor_086,True 246 | tumor_087,True 247 | tumor_088,True 248 | tumor_089,True 249 | tumor_090,True 250 | tumor_091,True 251 | tumor_092,True 252 | tumor_093,True 253 | tumor_094,True 254 | tumor_095,True 255 | tumor_096,True 256 | tumor_097,True 257 | tumor_098,True 258 | tumor_099,True 259 | tumor_100,True 260 | tumor_101,True 261 | tumor_102,True 262 | tumor_103,True 263 | tumor_104,True 264 | tumor_105,True 265 | tumor_106,True 266 | tumor_107,True 267 | tumor_108,True 268 | tumor_109,True 269 | tumor_110,True 270 | -------------------------------------------------------------------------------- /MIL/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from models.model_mil import MIL_fc, MIL_fc_mc 7 | from models.model_clam import CLAM_SB, CLAM_MB 8 | import pdb 9 | import os 10 | import pandas as pd 11 | from utils.utils import * 12 | from utils.core_utils import Accuracy_Logger 13 | from sklearn.metrics import roc_auc_score, roc_curve, auc 14 | from sklearn.preprocessing import label_binarize 15 | import matplotlib.pyplot as plt 16 | 17 | def initiate_model(args, ckpt_path): 18 | print('Init Model') 19 | model_dict = {"dropout": args.drop_out, 'n_classes': args.n_classes} 20 | 21 | if args.model_size is not None and args.model_type in ['clam_sb', 'clam_mb']: 22 | model_dict.update({"size_arg": args.model_size}) 23 | 24 | if args.model_type =='clam_sb': 25 | model = CLAM_SB(**model_dict) 26 | elif args.model_type =='clam_mb': 27 | model = CLAM_MB(**model_dict) 28 | else: # args.model_type == 'mil' 29 | if args.n_classes > 2: 30 | model = MIL_fc_mc(**model_dict) 31 | else: 32 | model = MIL_fc(**model_dict) 33 | 34 | print_network(model) 35 | 36 | ckpt = torch.load(ckpt_path) 37 | ckpt_clean = {} 38 | for key in ckpt.keys(): 39 | if 'instance_loss_fn' in key: 40 | continue 41 | ckpt_clean.update({key.replace('.module', ''):ckpt[key]}) 42 | model.load_state_dict(ckpt_clean, strict=True) 43 | 44 | model.relocate() 45 | model.eval() 46 | return model 47 | 48 | def eval(dataset, args, ckpt_path): 49 | model = initiate_model(args, ckpt_path) 50 | 51 | print('Init Loaders') 52 | loader = get_simple_loader(dataset) 53 | patient_results, test_error, auc, df, _ = summary(model, loader, args) 54 | print('test_error: ', test_error) 55 | print('auc: ', auc) 56 | return model, patient_results, test_error, auc, df 57 | 58 | def summary(model, loader, args): 59 | acc_logger = Accuracy_Logger(n_classes=args.n_classes) 60 | model.eval() 61 | test_loss = 0. 62 | test_error = 0. 
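    # loader below yields one bag (i.e. one slide) per batch; per-slide class
    # probabilities are cached so split-level AUC can be computed once the loop finishes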
63 | 64 | all_probs = np.zeros((len(loader), args.n_classes)) 65 | all_labels = np.zeros(len(loader)) 66 | all_preds = np.zeros(len(loader)) 67 | 68 | slide_ids = loader.dataset.slide_data['slide_id'] 69 | patient_results = {} 70 | for batch_idx, (data, label) in enumerate(loader): 71 | data, label = data.to(device), label.to(device) 72 | slide_id = slide_ids.iloc[batch_idx] 73 | with torch.no_grad(): 74 | logits, Y_prob, Y_hat, _, results_dict = model(data) 75 | 76 | acc_logger.log(Y_hat, label) 77 | 78 | probs = Y_prob.cpu().numpy() 79 | 80 | all_probs[batch_idx] = probs 81 | all_labels[batch_idx] = label.item() 82 | all_preds[batch_idx] = Y_hat.item() 83 | 84 | patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}}) 85 | 86 | error = calculate_error(Y_hat, label) 87 | test_error += error 88 | 89 | del data 90 | test_error /= len(loader) 91 | 92 | aucs = [] 93 | if len(np.unique(all_labels)) == 1: 94 | auc_score = -1 95 | 96 | else: 97 | if args.n_classes == 2: 98 | auc_score = roc_auc_score(all_labels, all_probs[:, 1]) 99 | else: 100 | binary_labels = label_binarize(all_labels, classes=[i for i in range(args.n_classes)]) 101 | for class_idx in range(args.n_classes): 102 | if class_idx in all_labels: 103 | fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx]) 104 | aucs.append(auc(fpr, tpr)) 105 | else: 106 | aucs.append(float('nan')) 107 | if args.micro_average: 108 | binary_labels = label_binarize(all_labels, classes=[i for i in range(args.n_classes)]) 109 | fpr, tpr, _ = roc_curve(binary_labels.ravel(), all_probs.ravel()) 110 | auc_score = auc(fpr, tpr) 111 | else: 112 | auc_score = np.nanmean(np.array(aucs)) 113 | 114 | results_dict = {'slide_id': slide_ids, 'Y': all_labels, 'Y_hat': all_preds} 115 | for c in range(args.n_classes): 116 | results_dict.update({'p_{}'.format(c): all_probs[:,c]}) 117 | df = pd.DataFrame(results_dict) 118 | return patient_results, test_error, auc_score, df, acc_logger 119 | -------------------------------------------------------------------------------- /MIL/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import h5py 3 | 4 | def save_pkl(filename, save_object): 5 | writer = open(filename,'wb') 6 | pickle.dump(save_object, writer) 7 | writer.close() 8 | 9 | def load_pkl(filename): 10 | loader = open(filename,'rb') 11 | file = pickle.load(loader) 12 | loader.close() 13 | return file 14 | 15 | 16 | def save_hdf5(output_path, asset_dict, attr_dict= None, mode='a'): 17 | file = h5py.File(output_path, mode) 18 | for key, val in asset_dict.items(): 19 | data_shape = val.shape 20 | if key not in file: 21 | data_type = val.dtype 22 | chunk_shape = (1, ) + data_shape[1:] 23 | maxshape = (None, ) + data_shape[1:] 24 | dset = file.create_dataset(key, shape=data_shape, maxshape=maxshape, chunks=chunk_shape, dtype=data_type) 25 | dset[:] = val 26 | if attr_dict is not None: 27 | if key in attr_dict.keys(): 28 | for attr_key, attr_val in attr_dict[key].items(): 29 | dset.attrs[attr_key] = attr_val 30 | else: 31 | dset = file[key] 32 | dset.resize(len(dset) + data_shape[0], axis=0) 33 | dset[-data_shape[0]:] = val 34 | file.close() 35 | return output_path -------------------------------------------------------------------------------- /MIL/utils/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | 
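# Hedged usage sketch for save_hdf5 from utils/file_utils.py above; the file name,
# array shapes, and this helper itself are illustrative assumptions, not part of the repo's API.
def _demo_save_hdf5(path='demo_features.h5'):
    import numpy as np
    from utils.file_utils import save_hdf5
    feats = np.zeros((8, 384), dtype=np.float32)   # e.g. one batch of patch embeddings
    coords = np.zeros((8, 2), dtype=np.int64)      # matching (x, y) patch coordinates
    save_hdf5(path, {'features': feats, 'coords': coords}, mode='w')  # 'w' creates resizable datasets
    save_hdf5(path, {'features': feats, 'coords': coords}, mode='a')  # 'a' appends along axis 0
    return path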
import pdb 6 | 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | from torchvision import transforms 11 | from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler 12 | import torch.optim as optim 13 | import pdb 14 | import torch.nn.functional as F 15 | import math 16 | from itertools import islice 17 | import collections 18 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | class SubsetSequentialSampler(Sampler): 21 | """Samples elements sequentially from a given list of indices, without replacement. 22 | 23 | Arguments: 24 | indices (sequence): a sequence of indices 25 | """ 26 | def __init__(self, indices): 27 | self.indices = indices 28 | 29 | def __iter__(self): 30 | return iter(self.indices) 31 | 32 | def __len__(self): 33 | return len(self.indices) 34 | 35 | def collate_MIL(batch): 36 | img = torch.cat([item[0] for item in batch], dim = 0) 37 | label = torch.LongTensor([item[1] for item in batch]) 38 | return [img, label] 39 | 40 | def collate_features(batch): 41 | img = torch.Tensor(np.concatenate([item[0] for item in batch], axis= 0)) 42 | coords = np.vstack([item[1] for item in batch]) 43 | return [img, coords] 44 | 45 | 46 | def get_simple_loader(dataset, batch_size=1, num_workers=1): 47 | kwargs = {'num_workers': num_workers, 'pin_memory': False} if device.type == "cuda" else {} 48 | loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler.SequentialSampler(dataset), collate_fn = collate_MIL, **kwargs) 49 | return loader 50 | 51 | def get_split_loader(split_dataset, training = False, testing = False, weighted = False): 52 | """ 53 | return either the validation loader or training loader 54 | """ 55 | kwargs = {'num_workers': 4} if device.type == "cuda" else {} 56 | if not testing: 57 | if training: 58 | if weighted: 59 | weights = make_weights_for_balanced_classes_split(split_dataset) 60 | loader = DataLoader(split_dataset, batch_size=1, sampler = WeightedRandomSampler(weights, len(weights)), collate_fn = collate_MIL, **kwargs) 61 | else: 62 | loader = DataLoader(split_dataset, batch_size=1, sampler = RandomSampler(split_dataset), collate_fn = collate_MIL, **kwargs) 63 | else: 64 | loader = DataLoader(split_dataset, batch_size=1, sampler = SequentialSampler(split_dataset), collate_fn = collate_MIL, **kwargs) 65 | 66 | else: 67 | ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset)*0.1), replace = False) # sequentially evaluate a random 10% subset when testing 68 | loader = DataLoader(split_dataset, batch_size=1, sampler = SubsetSequentialSampler(ids), collate_fn = collate_MIL, **kwargs ) 69 | 70 | return loader 71 | 72 | def get_optim(model, args): 73 | if args.opt == "adam": 74 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg) 75 | elif args.opt == 'sgd': 76 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg) 77 | else: 78 | raise NotImplementedError 79 | return optimizer 80 | 81 | def print_network(net): 82 | num_params = 0 83 | num_params_train = 0 84 | print(net) 85 | 86 | for param in net.parameters(): 87 | n = param.numel() 88 | num_params += n 89 | if param.requires_grad: 90 | num_params_train += n 91 | 92 | print('Total number of parameters: %d' % num_params) 93 | print('Total number of trainable parameters: %d' % num_params_train) 94 | 95 | 96 | def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5, 97 | seed = 
7, label_frac = 1.0, custom_test_ids = None): 98 | indices = np.arange(samples).astype(int) 99 | 100 | if custom_test_ids is not None: 101 | indices = np.setdiff1d(indices, custom_test_ids) 102 | 103 | np.random.seed(seed) 104 | for i in range(n_splits): 105 | all_val_ids = [] 106 | all_test_ids = [] 107 | sampled_train_ids = [] 108 | 109 | if custom_test_ids is not None: # pre-built test split, do not need to sample 110 | all_test_ids.extend(custom_test_ids) 111 | 112 | for c in range(len(val_num)): 113 | possible_indices = np.intersect1d(cls_ids[c], indices) #all indices of this class 114 | val_ids = np.random.choice(possible_indices, val_num[c], replace = False) # validation ids 115 | 116 | remaining_ids = np.setdiff1d(possible_indices, val_ids) #indices of this class left after validation 117 | all_val_ids.extend(val_ids) 118 | 119 | if custom_test_ids is None: # sample test split 120 | 121 | test_ids = np.random.choice(remaining_ids, test_num[c], replace = False) 122 | remaining_ids = np.setdiff1d(remaining_ids, test_ids) 123 | all_test_ids.extend(test_ids) 124 | 125 | if label_frac == 1: 126 | sampled_train_ids.extend(remaining_ids) 127 | 128 | else: 129 | sample_num = math.ceil(len(remaining_ids) * label_frac) 130 | slice_ids = np.arange(sample_num) 131 | sampled_train_ids.extend(remaining_ids[slice_ids]) 132 | 133 | yield sampled_train_ids, all_val_ids, all_test_ids 134 | 135 | 136 | def nth(iterator, n, default=None): 137 | if n is None: 138 | return collections.deque(iterator, maxlen=0) 139 | else: 140 | return next(islice(iterator,n, None), default) 141 | 142 | def calculate_error(Y_hat, Y): 143 | error = 1. - Y_hat.float().eq(Y.float()).float().mean().item() 144 | 145 | return error 146 | 147 | def make_weights_for_balanced_classes_split(dataset): 148 | N = float(len(dataset)) 149 | weight_per_class = [N/len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))] 150 | weight = [0] * int(N) 151 | for idx in range(len(dataset)): 152 | y = dataset.getlabel(idx) 153 | weight[idx] = weight_per_class[y] 154 | 155 | return torch.DoubleTensor(weight) 156 | 157 | def initialize_weights(module): 158 | for m in module.modules(): 159 | if isinstance(m, nn.Linear): 160 | nn.init.xavier_normal_(m.weight) 161 | m.bias.data.zero_() 162 | 163 | elif isinstance(m, nn.BatchNorm1d): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | -------------------------------------------------------------------------------- /MIL/vis_utils/heatmap_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import pdb 6 | import os 7 | import pandas as pd 8 | from utils.utils import * 9 | from PIL import Image 10 | from math import floor 11 | import matplotlib.pyplot as plt 12 | from datasets.wsi_dataset import Wsi_Region 13 | import h5py 14 | from wsi_core.WholeSlideImage import WholeSlideImage 15 | from scipy.stats import percentileofscore 16 | import math 17 | from utils.file_utils import save_hdf5 18 | from scipy.stats import percentileofscore 19 | 20 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | def score2percentile(score, ref): 23 | percentile = percentileofscore(ref, score) 24 | return percentile 25 | 26 | def drawHeatmap(scores, coords, slide_path=None, wsi_object=None, vis_level = -1, **kwargs): 27 | if wsi_object is None: 28 | wsi_object = WholeSlideImage(slide_path) 29 | 
print(wsi_object.name) 30 | 31 | wsi = wsi_object.getOpenSlide() 32 | if vis_level < 0: 33 | vis_level = wsi.get_best_level_for_downsample(32) 34 | 35 | heatmap = wsi_object.visHeatmap(scores=scores, coords=coords, vis_level=vis_level, **kwargs) 36 | return heatmap 37 | 38 | def initialize_wsi(wsi_path, seg_mask_path=None, seg_params=None, filter_params=None): 39 | wsi_object = WholeSlideImage(wsi_path) 40 | if seg_params['seg_level'] < 0: 41 | best_level = wsi_object.wsi.get_best_level_for_downsample(32) 42 | seg_params['seg_level'] = best_level 43 | 44 | wsi_object.segmentTissue(**seg_params, filter_params=filter_params) 45 | wsi_object.saveSegmentation(seg_mask_path) 46 | return wsi_object 47 | 48 | def compute_from_patches(args, wsi_object, clam_pred=None, model=None, feature_extractor=None, batch_size=512, 49 | attn_save_path=None, ref_scores=None, feat_save_path=None, **wsi_kwargs): 50 | top_left = wsi_kwargs['top_left'] 51 | bot_right = wsi_kwargs['bot_right'] 52 | patch_size = wsi_kwargs['patch_size'] 53 | 54 | roi_dataset = Wsi_Region(wsi_object, **wsi_kwargs) 55 | roi_loader = get_simple_loader(roi_dataset, batch_size=batch_size, num_workers=8) 56 | print('total number of patches to process: ', len(roi_dataset)) 57 | num_batches = len(roi_loader) 58 | print('number of batches: ', num_batches) 59 | mode = "w" 60 | for idx, (roi, coords) in enumerate(roi_loader): 61 | roi = roi.to(device) 62 | coords = coords.numpy() 63 | 64 | with torch.no_grad(): 65 | if "vit" in args['arch']: 66 | intermediate_output = feature_extractor.get_intermediate_layers(roi, args['n_last_blocks']) 67 | features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1) 68 | if args['avgpool_patchtokens']: 69 | features = torch.cat((features.unsqueeze(-1), 70 | torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1) 71 | features = features.reshape(features.shape[0], -1) 72 | else: 73 | features = feature_extractor(roi) 74 | 75 | 76 | 77 | if attn_save_path is not None and False: # attention-map saving is currently disabled by the hard-coded False 78 | A = model(features, attention_only=True) 79 | 80 | if A.size(0) > 1: #CLAM multi-branch attention 81 | A = A[clam_pred] 82 | 83 | A = A.view(-1, 1).cpu().numpy() 84 | 85 | if ref_scores is not None: 86 | for score_idx in range(len(A)): 87 | A[score_idx] = score2percentile(A[score_idx], ref_scores) 88 | 89 | asset_dict = {'attention_scores': A, 'coords': coords} 90 | save_path = save_hdf5(attn_save_path, asset_dict, mode=mode) 91 | 92 | if idx % math.ceil(num_batches * 0.05) == 0: 93 | print('processed {} / {}'.format(idx, num_batches)) 94 | 95 | if feat_save_path is not None: 96 | asset_dict = {'features': features.cpu().numpy(), 'coords': coords} 97 | save_hdf5(feat_save_path, asset_dict, mode=mode) 98 | 99 | mode = "a" 100 | return attn_save_path, feat_save_path, wsi_object 101 | -------------------------------------------------------------------------------- /MIL/wsi_core/batch_process_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pdb 4 | 5 | ''' 6 | initiate a pandas df describing a list of slides to process 7 | args: 8 | slides (df or array-like): 9 | array-like structure containing list of slide ids, if df, these ids are assumed to be 10 | stored under the 'slide_id' column 11 | seg_params (dict): segmentation parameters 12 | filter_params (dict): filter parameters 13 | vis_params (dict): visualization parameters 14 | patch_params (dict): patching parameters 15 | use_heatmap_args (bool): whether to 
include heatmap arguments such as ROI coordinates 16 | ''' 17 | def initialize_df(slides, seg_params, filter_params, vis_params, patch_params, 18 | use_heatmap_args=False, save_patches=False): 19 | 20 | total = len(slides) 21 | if isinstance(slides, pd.DataFrame): 22 | slide_ids = slides.slide_id.values 23 | else: 24 | slide_ids = slides 25 | default_df_dict = {'slide_id': slide_ids, 'process': np.full((total), 1, dtype=np.uint8)} 26 | 27 | # initiate empty labels in case not provided 28 | if use_heatmap_args: 29 | default_df_dict.update({'label': np.full((total), -1)}) 30 | 31 | default_df_dict.update({ 32 | 'status': np.full((total), 'tbp'), 33 | # seg params 34 | 'seg_level': np.full((total), int(seg_params['seg_level']), dtype=np.int8), 35 | 'sthresh': np.full((total), int(seg_params['sthresh']), dtype=np.uint8), 36 | 'mthresh': np.full((total), int(seg_params['mthresh']), dtype=np.uint8), 37 | 'close': np.full((total), int(seg_params['close']), dtype=np.uint32), 38 | 'use_otsu': np.full((total), bool(seg_params['use_otsu']), dtype=bool), 39 | 'keep_ids': np.full((total), seg_params['keep_ids']), 40 | 'exclude_ids': np.full((total), seg_params['exclude_ids']), 41 | 42 | # filter params 43 | 'a_t': np.full((total), int(filter_params['a_t']), dtype=np.float32), 44 | 'a_h': np.full((total), int(filter_params['a_h']), dtype=np.float32), 45 | 'max_n_holes': np.full((total), int(filter_params['max_n_holes']), dtype=np.uint32), 46 | 47 | # vis params 48 | 'vis_level': np.full((total), int(vis_params['vis_level']), dtype=np.int8), 49 | 'line_thickness': np.full((total), int(vis_params['line_thickness']), dtype=np.uint32), 50 | 51 | # patching params 52 | 'use_padding': np.full((total), bool(patch_params['use_padding']), dtype=bool), 53 | 'contour_fn': np.full((total), patch_params['contour_fn']) 54 | }) 55 | 56 | if save_patches: 57 | default_df_dict.update({ 58 | 'white_thresh': np.full((total), int(patch_params['white_thresh']), dtype=np.uint8), 59 | 'black_thresh': np.full((total), int(patch_params['black_thresh']), dtype=np.uint8)}) 60 | 61 | if use_heatmap_args: 62 | # initiate empty x,y coordinates in case not provided 63 | default_df_dict.update({'x1': np.full((total), np.nan), 64 | 'x2': np.full((total), np.nan), 65 | 'y1': np.full((total), np.nan), 66 | 'y2': np.full((total), np.nan)}) 67 | 68 | 69 | if isinstance(slides, pd.DataFrame): 70 | temp_copy = pd.DataFrame(default_df_dict) # temporary dataframe w/ default params 71 | # find key in provided df 72 | # if exist, fill empty fields w/ default values, else, insert the default values as a new column 73 | for key in default_df_dict.keys(): 74 | if key in slides.columns: 75 | mask = slides[key].isna() 76 | slides.loc[mask, key] = temp_copy.loc[mask, key] 77 | else: 78 | slides.insert(len(slides.columns), key, default_df_dict[key]) 79 | else: 80 | slides = pd.DataFrame(default_df_dict) 81 | 82 | return slides -------------------------------------------------------------------------------- /MIL/wsi_core/util_classes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import pdb 5 | import cv2 6 | class Mosaic_Canvas(object): 7 | def __init__(self,patch_size=256, n=100, downscale=4, n_per_row=10, bg_color=(0,0,0), alpha=-1): 8 | self.patch_size = patch_size 9 | self.downscaled_patch_size = int(np.ceil(patch_size/downscale)) 10 | self.n_rows = int(np.ceil(n / n_per_row)) 11 | self.n_cols = n_per_row 12 | w = 
self.n_cols * self.downscaled_patch_size 13 | h = self.n_rows * self.downscaled_patch_size 14 | if alpha < 0: 15 | canvas = Image.new(size=(w,h), mode="RGB", color=bg_color) 16 | else: 17 | canvas = Image.new(size=(w,h), mode="RGBA", color=bg_color + (int(255 * alpha),)) 18 | 19 | self.canvas = canvas 20 | self.dimensions = np.array([w, h]) 21 | self.reset_coord() 22 | 23 | def reset_coord(self): 24 | self.coord = np.array([0, 0]) 25 | 26 | def increment_coord(self): 27 | #print('current coord: {} x {} / {} x {}'.format(self.coord[0], self.coord[1], self.dimensions[0], self.dimensions[1])) 28 | assert np.all(self.coord<=self.dimensions) 29 | if self.coord[0] + self.downscaled_patch_size <=self.dimensions[0] - self.downscaled_patch_size: 30 | self.coord[0]+=self.downscaled_patch_size 31 | else: 32 | self.coord[0] = 0 33 | self.coord[1]+=self.downscaled_patch_size 34 | 35 | 36 | def save(self, save_path, **kwargs): 37 | self.canvas.save(save_path, **kwargs) 38 | 39 | def paste_patch(self, patch): 40 | assert patch.size[0] == self.patch_size 41 | assert patch.size[1] == self.patch_size 42 | self.canvas.paste(patch.resize(tuple([self.downscaled_patch_size, self.downscaled_patch_size])), tuple(self.coord)) 43 | self.increment_coord() 44 | 45 | def get_painting(self): 46 | return self.canvas 47 | 48 | class Contour_Checking_fn(object): 49 | # Defining __call__ method 50 | def __call__(self, pt): 51 | raise NotImplementedError 52 | 53 | class isInContourV1(Contour_Checking_fn): 54 | def __init__(self, contour): 55 | self.cont = contour 56 | 57 | def __call__(self, pt): 58 | return 1 if cv2.pointPolygonTest(self.cont, tuple(np.array(pt).astype(float)), False) >= 0 else 0 59 | 60 | class isInContourV2(Contour_Checking_fn): 61 | def __init__(self, contour, patch_size): 62 | self.cont = contour 63 | self.patch_size = patch_size 64 | 65 | def __call__(self, pt): 66 | pt = np.array((pt[0]+self.patch_size//2, pt[1]+self.patch_size//2)).astype(float) 67 | return 1 if cv2.pointPolygonTest(self.cont, tuple(np.array(pt).astype(float)), False) >= 0 else 0 68 | 69 | # Easy version of 4pt contour checking function - 1 of 4 points need to be in the contour for test to pass 70 | class isInContourV3_Easy(Contour_Checking_fn): 71 | def __init__(self, contour, patch_size, center_shift=0.5): 72 | self.cont = contour 73 | self.patch_size = patch_size 74 | self.shift = int(patch_size//2*center_shift) 75 | def __call__(self, pt): 76 | center = (pt[0]+self.patch_size//2, pt[1]+self.patch_size//2) 77 | if self.shift > 0: 78 | all_points = [(center[0]-self.shift, center[1]-self.shift), 79 | (center[0]+self.shift, center[1]+self.shift), 80 | (center[0]+self.shift, center[1]-self.shift), 81 | (center[0]-self.shift, center[1]+self.shift) 82 | ] 83 | else: 84 | all_points = [center] 85 | 86 | for points in all_points: 87 | if cv2.pointPolygonTest(self.cont, tuple(np.array(points).astype(float)), False) >= 0: 88 | return 1 89 | return 0 90 | 91 | # Hard version of 4pt contour checking function - all 4 points need to be in the contour for test to pass 92 | class isInContourV3_Hard(Contour_Checking_fn): 93 | def __init__(self, contour, patch_size, center_shift=0.5): 94 | self.cont = contour 95 | self.patch_size = patch_size 96 | self.shift = int(patch_size//2*center_shift) 97 | def __call__(self, pt): 98 | center = (pt[0]+self.patch_size//2, pt[1]+self.patch_size//2) 99 | if self.shift > 0: 100 | all_points = [(center[0]-self.shift, center[1]-self.shift), 101 | (center[0]+self.shift, center[1]+self.shift), 102 | 
(center[0]+self.shift, center[1]-self.shift), 103 | (center[0]-self.shift, center[1]+self.shift) 104 | ] 105 | else: 106 | all_points = [center] 107 | 108 | for points in all_points: 109 | if cv2.pointPolygonTest(self.cont, tuple(np.array(points).astype(float)), False) < 0: 110 | return 0 111 | return 1 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/__init__.py -------------------------------------------------------------------------------- /dino/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/__init__.py -------------------------------------------------------------------------------- /dino/causal-conv1d/=1.1.0: -------------------------------------------------------------------------------- 1 | Defaulting to user installation because normal site-packages is not writeable 2 | Obtaining file:///home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/causal_conv1d 3 | -------------------------------------------------------------------------------- /dino/causal-conv1d/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | -------------------------------------------------------------------------------- /dino/causal-conv1d/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /dino/causal-conv1d/README.md: -------------------------------------------------------------------------------- 1 | # Causal depthwise conv1d in CUDA with a PyTorch interface 2 | -------------------------------------------------------------------------------- /dino/causal-conv1d/build/lib.linux-x86_64-cpython-310/causal_conv1d/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update 4 | -------------------------------------------------------------------------------- /dino/causal-conv1d/build/lib.linux-x86_64-cpython-310/causal_conv1d/causal_conv1d_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | import causal_conv1d_cuda 8 | 9 | 10 | class CausalConv1dFn(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, x, weight, bias=None, activation=None): 13 | if activation not in [None, "silu", "swish"]: 14 | raise NotImplementedError("activation must be None, silu, or swish") 15 | if x.stride(2) != 1 and x.stride(1) != 1: 16 | x = x.contiguous() 17 | bias = bias.contiguous() if bias is not None else None 18 | ctx.save_for_backward(x, weight, bias) 19 | ctx.activation = activation in ["silu", "swish"] 20 | out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation) 21 | return out 22 | 23 | @staticmethod 24 | def backward(ctx, dout): 25 | x, weight, bias = ctx.saved_tensors 26 | if dout.stride(2) != 1 and dout.stride(1) != 1: 27 | dout = dout.contiguous() 28 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the 29 | # backward of conv1d with the backward of chunk). 30 | # Here we just pass in None and dx will be allocated in the C++ code. 
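        # Note: autograd expects one gradient per input to forward(x, weight, bias,
        # activation); the conditional expression binds before the comma, so this
        # returns a 4-tuple whose trailing None matches the non-tensor activation flag.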
31 | dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd( 32 | x, weight, bias, dout, None, ctx.activation 33 | ) 34 | return dx, dweight, dbias if bias is not None else None, None 35 | 36 | 37 | def causal_conv1d_fn(x, weight, bias=None, activation=None): 38 | """ 39 | x: (batch, dim, seqlen) 40 | weight: (dim, width) 41 | bias: (dim,) 42 | activation: either None or "silu" or "swish" 43 | 44 | out: (batch, dim, seqlen) 45 | """ 46 | return CausalConv1dFn.apply(x, weight, bias, activation) 47 | 48 | 49 | def causal_conv1d_ref(x, weight, bias=None, activation=None): 50 | """ 51 | x: (batch, dim, seqlen) 52 | weight: (dim, width) 53 | bias: (dim,) 54 | 55 | out: (batch, dim, seqlen) 56 | """ 57 | if activation not in [None, "silu", "swish"]: 58 | raise NotImplementedError("activation must be None, silu, or swish") 59 | dtype_in = x.dtype 60 | x = x.to(weight.dtype) 61 | seqlen = x.shape[-1] 62 | dim, width = weight.shape 63 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) 64 | out = out[..., :seqlen] 65 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 66 | 67 | 68 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): 69 | """ 70 | x: (batch, dim) 71 | conv_state: (batch, dim, width) 72 | weight: (dim, width) 73 | bias: (dim,) 74 | 75 | out: (batch, dim) 76 | """ 77 | if activation not in [None, "silu", "swish"]: 78 | raise NotImplementedError("activation must be None, silu, or swish") 79 | activation = activation in ["silu", "swish"] 80 | return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation) 81 | 82 | 83 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None): 84 | """ 85 | x: (batch, dim) 86 | conv_state: (batch, dim, width) 87 | weight: (dim, width) 88 | bias: (dim,) 89 | 90 | out: (batch, dim) 91 | """ 92 | if activation not in [None, "silu", "swish"]: 93 | raise NotImplementedError("activation must be None, silu, or swish") 94 | dtype_in = x.dtype 95 | batch, dim = x.shape 96 | width = weight.shape[1] 97 | assert conv_state.shape == (batch, dim, width) 98 | assert weight.shape == (dim, width) 99 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 100 | conv_state[:, :, -1] = x 101 | out = torch.sum(conv_state * weight, dim=-1) # (B D) 102 | if bias is not None: 103 | out += bias 104 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 105 | -------------------------------------------------------------------------------- /dino/causal-conv1d/build/lib.linux-x86_64-cpython-310/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/causal-conv1d/build/lib.linux-x86_64-cpython-310/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/.ninja_deps: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/.ninja_deps -------------------------------------------------------------------------------- /dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/.ninja_log: 
-------------------------------------------------------------------------------- 1 | # ninja log v5 2 | 2 3899 1708464862707292896 /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_update.o 7773c9b1419bbd99 3 | 2 8863 1708464867647297838 /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_fwd.o c0bf9e706b098ad2 4 | 1 23753 1708464882443312635 /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_bwd.o 7cd1b8307e16c275 5 | 1 28015 1708464886819317010 /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d.o 932ee24097499248 6 | -------------------------------------------------------------------------------- /dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /usr/local/pkg/cuda/12.3.2/root/bin/nvcc 4 | 5 | cflags = -pthread -B /usr/local/pkg/anaconda/v3.2023.03/root/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /usr/local/pkg/anaconda/v3.2023.03/root/include -fPIC -O2 -isystem /usr/local/pkg/anaconda/v3.2023.03/root/include -fPIC -I/home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d -I/home/a_n29343/.local/lib/python3.10/site-packages/torch/include -I/home/a_n29343/.local/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/a_n29343/.local/lib/python3.10/site-packages/torch/include/TH -I/home/a_n29343/.local/lib/python3.10/site-packages/torch/include/THC -I/usr/local/pkg/cuda/12.3.2/root/include -I/usr/local/pkg/anaconda/v3.2023.03/root/include/python3.10 -c 6 | post_cflags = -O3 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=causal_conv1d_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17 7 | cuda_cflags = -I/home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d -I/home/a_n29343/.local/lib/python3.10/site-packages/torch/include -I/home/a_n29343/.local/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/a_n29343/.local/lib/python3.10/site-packages/torch/include/TH -I/home/a_n29343/.local/lib/python3.10/site-packages/torch/include/THC -I/usr/local/pkg/cuda/12.3.2/root/include -I/usr/local/pkg/anaconda/v3.2023.03/root/include/python3.10 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -U__CUDA_NO_BFLOAT162_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math --ptxas-options=-v -lineinfo -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90 --threads 4 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=causal_conv1d_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17 9 | cuda_dlink_post_cflags = 10 | ldflags = 11 | 12 | rule compile 13 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 14 | depfile = $out.d 15 | deps = gcc 16 | 17 | rule 
cuda_compile 18 | depfile = $out.d 19 | deps = gcc 20 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 21 | 22 | 23 | 24 | 25 | 26 | build /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d.o: compile /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/csrc/causal_conv1d.cpp 27 | build /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_bwd.o: cuda_compile /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/csrc/causal_conv1d_bwd.cu 28 | build /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_fwd.o: cuda_compile /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/csrc/causal_conv1d_fwd.cu 29 | build /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_update.o: cuda_compile /home/a_n29343/CHUM/VIM4Path/dino/causal-conv1d/csrc/causal_conv1d_update.cu 30 | 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d.o -------------------------------------------------------------------------------- /dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_bwd.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_bwd.o -------------------------------------------------------------------------------- /dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_fwd.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_fwd.o -------------------------------------------------------------------------------- /dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_update.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/causal-conv1d/build/temp.linux-x86_64-cpython-310/csrc/causal_conv1d_update.o -------------------------------------------------------------------------------- /dino/causal-conv1d/causal_conv1d-1.0.0+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/causal-conv1d/causal_conv1d-1.0.0+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl -------------------------------------------------------------------------------- /dino/causal-conv1d/causal_conv1d.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: causal_conv1d 3 | Version: 1.0.0 4 | Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface 5 | Home-page: 
https://github.com/Dao-AILab/causal-conv1d 6 | Author: Tri Dao 7 | Author-email: tri@tridao.me 8 | Classifier: Programming Language :: Python :: 3 9 | Classifier: License :: OSI Approved :: BSD License 10 | Classifier: Operating System :: Unix 11 | Requires-Python: >=3.7 12 | Description-Content-Type: text/markdown 13 | License-File: LICENSE 14 | License-File: AUTHORS 15 | Requires-Dist: torch 16 | Requires-Dist: packaging 17 | Requires-Dist: ninja 18 | 19 | # Causal depthwise conv1d in CUDA with a PyTorch interface 20 | -------------------------------------------------------------------------------- /dino/causal-conv1d/causal_conv1d/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update 4 | -------------------------------------------------------------------------------- /dino/causal-conv1d/causal_conv1d/causal_conv1d_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | import causal_conv1d_cuda 8 | 9 | 10 | class CausalConv1dFn(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, x, weight, bias=None, activation=None): 13 | if activation not in [None, "silu", "swish"]: 14 | raise NotImplementedError("activation must be None, silu, or swish") 15 | if x.stride(2) != 1 and x.stride(1) != 1: 16 | x = x.contiguous() 17 | bias = bias.contiguous() if bias is not None else None 18 | ctx.save_for_backward(x, weight, bias) 19 | ctx.activation = activation in ["silu", "swish"] 20 | out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation) 21 | return out 22 | 23 | @staticmethod 24 | def backward(ctx, dout): 25 | x, weight, bias = ctx.saved_tensors 26 | if dout.stride(2) != 1 and dout.stride(1) != 1: 27 | dout = dout.contiguous() 28 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the 29 | # backward of conv1d with the backward of chunk). 30 | # Here we just pass in None and dx will be allocated in the C++ code. 
31 | dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd( 32 | x, weight, bias, dout, None, ctx.activation 33 | ) 34 | return dx, dweight, dbias if bias is not None else None, None 35 | 36 | 37 | def causal_conv1d_fn(x, weight, bias=None, activation=None): 38 | """ 39 | x: (batch, dim, seqlen) 40 | weight: (dim, width) 41 | bias: (dim,) 42 | activation: either None or "silu" or "swish" 43 | 44 | out: (batch, dim, seqlen) 45 | """ 46 | return CausalConv1dFn.apply(x, weight, bias, activation) 47 | 48 | 49 | def causal_conv1d_ref(x, weight, bias=None, activation=None): 50 | """ 51 | x: (batch, dim, seqlen) 52 | weight: (dim, width) 53 | bias: (dim,) 54 | 55 | out: (batch, dim, seqlen) 56 | """ 57 | if activation not in [None, "silu", "swish"]: 58 | raise NotImplementedError("activation must be None, silu, or swish") 59 | dtype_in = x.dtype 60 | x = x.to(weight.dtype) 61 | seqlen = x.shape[-1] 62 | dim, width = weight.shape 63 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) 64 | out = out[..., :seqlen] 65 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 66 | 67 | 68 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): 69 | """ 70 | x: (batch, dim) 71 | conv_state: (batch, dim, width) 72 | weight: (dim, width) 73 | bias: (dim,) 74 | 75 | out: (batch, dim) 76 | """ 77 | if activation not in [None, "silu", "swish"]: 78 | raise NotImplementedError("activation must be None, silu, or swish") 79 | activation = activation in ["silu", "swish"] 80 | return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation) 81 | 82 | 83 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None): 84 | """ 85 | x: (batch, dim) 86 | conv_state: (batch, dim, width) 87 | weight: (dim, width) 88 | bias: (dim,) 89 | 90 | out: (batch, dim) 91 | """ 92 | if activation not in [None, "silu", "swish"]: 93 | raise NotImplementedError("activation must be None, silu, or swish") 94 | dtype_in = x.dtype 95 | batch, dim = x.shape 96 | width = weight.shape[1] 97 | assert conv_state.shape == (batch, dim, width) 98 | assert weight.shape == (dim, width) 99 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 100 | conv_state[:, :, -1] = x 101 | out = torch.sum(conv_state * weight, dim=-1) # (B D) 102 | if bias is not None: 103 | out += bias 104 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 105 | -------------------------------------------------------------------------------- /dino/causal-conv1d/csrc/causal_conv1d.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct ConvParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, dim, seqlen, width; 13 | bool silu_activation; 14 | 15 | index_t x_batch_stride; 16 | index_t x_c_stride; 17 | index_t x_l_stride; 18 | index_t weight_c_stride; 19 | index_t weight_width_stride; 20 | index_t out_batch_stride; 21 | index_t out_c_stride; 22 | index_t out_l_stride; 23 | 24 | index_t conv_state_batch_stride; 25 | index_t conv_state_c_stride; 26 | index_t conv_state_l_stride; 27 | 28 | // Common data pointers. 
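    // (type-erased device pointers; concrete element types are supplied when the
    // kernels are instantiated for a given dtype)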
29 | void *__restrict__ x_ptr; 30 | void *__restrict__ weight_ptr; 31 | void *__restrict__ bias_ptr; 32 | void *__restrict__ out_ptr; 33 | 34 | void *__restrict__ conv_state_ptr; 35 | }; 36 | 37 | struct ConvParamsBwd: public ConvParamsBase { 38 | index_t dx_batch_stride; 39 | index_t dx_c_stride; 40 | index_t dx_l_stride; 41 | index_t dweight_c_stride; 42 | index_t dweight_width_stride; 43 | index_t dout_batch_stride; 44 | index_t dout_c_stride; 45 | index_t dout_l_stride; 46 | 47 | // Common data pointers. 48 | void *__restrict__ dx_ptr; 49 | void *__restrict__ dweight_ptr; 50 | void *__restrict__ dbias_ptr; 51 | void *__restrict__ dout_ptr; 52 | }; 53 | 54 | -------------------------------------------------------------------------------- /dino/causal-conv1d/csrc/causal_conv1d_common.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include <cuda_bf16.h> 8 | #include <cuda_fp16.h> 9 | 10 | //////////////////////////////////////////////////////////////////////////////////////////////////// 11 | 12 | template<int BYTES> struct BytesToType {}; 13 | 14 | template<> struct BytesToType<16> { 15 | using Type = uint4; 16 | static_assert(sizeof(Type) == 16); 17 | }; 18 | 19 | template<> struct BytesToType<8> { 20 | using Type = uint64_t; 21 | static_assert(sizeof(Type) == 8); 22 | }; 23 | 24 | template<> struct BytesToType<4> { 25 | using Type = uint32_t; 26 | static_assert(sizeof(Type) == 4); 27 | }; 28 | 29 | template<> struct BytesToType<2> { 30 | using Type = uint16_t; 31 | static_assert(sizeof(Type) == 2); 32 | }; 33 | 34 | template<> struct BytesToType<1> { 35 | using Type = uint8_t; 36 | static_assert(sizeof(Type) == 1); 37 | }; 38 | 39 | //////////////////////////////////////////////////////////////////////////////////////////////////// 40 | 41 | template<typename T> 42 | struct SumOp { 43 | __device__ inline T operator()(T const & x, T const & y) { return x + y; } 44 | }; 45 | 46 | template<int THREADS> 47 | struct Allreduce { 48 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); 49 | template<typename T, typename Operator> 50 | static __device__ inline T run(T x, Operator &op) { 51 | constexpr int OFFSET = THREADS / 2; 52 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); 53 | return Allreduce<OFFSET>::run(x, op); 54 | } 55 | }; 56 | 57 | template<> 58 | struct Allreduce<2> { 59 | template<typename T, typename Operator> 60 | static __device__ inline T run(T x, Operator &op) { 61 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); 62 | return x; 63 | } 64 | }; 65 | -------------------------------------------------------------------------------- /dino/causal-conv1d/csrc/causal_conv1d_update.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #include <c10/util/BFloat16.h> 6 | #include <c10/util/Half.h> 7 | #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 8 | 9 | #include <cub/block/block_load.cuh> 10 | #include <cub/block/block_store.cuh> 11 | 12 | #include "causal_conv1d.h" 13 | #include "causal_conv1d_common.h" 14 | #include "static_switch.h" 15 | 16 | template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_> 17 | struct Causal_conv1d_update_kernel_traits { 18 | using input_t = input_t_; 19 | using weight_t = weight_t_; 20 | static constexpr int kNThreads = kNThreads_; 21 | static constexpr int kWidth = kWidth_; 22 | static constexpr int kNBytes = sizeof(input_t); 23 | static_assert(kNBytes == 2 || kNBytes == 4); 24 | }; 25 | 26 | template<typename Ktraits> 27 | __global__ __launch_bounds__(Ktraits::kNThreads) 28 | void causal_conv1d_update_kernel(ConvParamsBase params) { 29 | constexpr int kWidth = Ktraits::kWidth; 30 | constexpr int kNThreads = Ktraits::kNThreads; 31 | using input_t = typename Ktraits::input_t; 32 | using weight_t = typename Ktraits::weight_t; 33 | 34 | const int tidx = threadIdx.x; 35 | const int batch_id = blockIdx.x; 36 | const int channel_id = blockIdx.y * kNThreads + tidx; 37 | input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride 38 | + channel_id * params.x_c_stride; 39 | input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride 40 | + channel_id * params.conv_state_c_stride; 41 | weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride; 42 | input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride 43 | + channel_id * params.out_c_stride; 44 | float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]); 45 | 46 | float weight_vals[kWidth] = {0}; 47 | if (channel_id < params.dim) { 48 | #pragma unroll 49 | for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } 50 | } 51 | 52 | float x_vals[kWidth] = {0}; 53 | if (channel_id < params.dim) { 54 | #pragma unroll 55 | for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } 56 | x_vals[kWidth - 1] = float(x[0]); 57 | #pragma unroll 58 | for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } 59 | } 60 | 61 | float out_val = bias_val; 62 | #pragma unroll 63 | for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } 64 | if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } 65 | if (channel_id < params.dim) { out[0] = input_t(out_val); } 66 | } 67 | 68 | template<int kNThreads, int kWidth, typename input_t, typename weight_t> 69 | void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) { 70 | using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>; 71 | dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); 72 | auto kernel = &causal_conv1d_update_kernel<Ktraits>; 73 | kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params); 74 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 75 | } 76 | 77 | template<typename input_t, typename weight_t> 78 | void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) { 79 | if (params.width == 2) { 80 | causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); 81 | } else if (params.width == 3) { 82 | causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); 83 | } else if (params.width == 4) { 84 | causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); 85 | } 86 | } 87 | 88 | template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream); 89 | template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream); 90 | template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream); 91 | template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream); 92 | template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream); 93 | template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream); 94 | template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream); 95 | template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream); 96 | template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /dino/causal-conv1d/csrc/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | static constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | static constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /dino/config.py: -------------------------------------------------------------------------------- 1 | configurations = { 2 | "vim-t": { 3 | "img_size": 224, 4 | "patch_size": 16, 5 | "stride": 16, 6 | "embed_dim": 192, 7 | "depth": 24, 8 | "rms_norm": True, 9 | "attn_drop_rate": 0.0, 10 | "drop_path_rate": 0.1, 11 | "residual_in_fp32": True, 12 | "fused_add_norm": True, 13 | "final_pool_type": "mean", 14 | "if_abs_pos_embed": True, 15 | "if_rope": False, 16 | "if_rope_residual": False, 17 | "if_cls_token": True, 18 | "if_devide_out": True, 19 | "use_middle_cls_token": True, 20 | "bimamba_type": "v2" 21 | }, 22 | "vim-t-plus": { 23 | "img_size": 224, 24 | "patch_size": 16, 25 | "stride": 16, 26 | "embed_dim": 384, 27 | "depth": 12, 28 | "rms_norm": True, 29 | "attn_drop_rate": 0.0, 30 | "drop_path_rate": 0.1, 31 | "residual_in_fp32": True, 32 | "fused_add_norm": True, 33 | "final_pool_type": "mean", 34 | "if_abs_pos_embed": True, 35 | "if_rope": False, 36 | "if_rope_residual": False, 37 | "if_cls_token": True, 38 | "if_devide_out": True, 39 | "use_middle_cls_token": True, 40 | "bimamba_type": "v2" 41 | }, 42 | "vim-s": { 43 | "img_size": 224, 44 | "patch_size": 16, 45 | "stride": 16, 46 | "embed_dim": 384, 47 | "depth": 24, 48 | "rms_norm": True, 49 | "attn_drop_rate": 0.0, 50 | "drop_path_rate": 0.1, 51 | "residual_in_fp32": True, 52 | "fused_add_norm": True, 53 | "final_pool_type": "mean", 54 | "if_abs_pos_embed": True, 55 | "if_rope": False, 56 | "if_rope_residual": False, 57 | "if_cls_token": True, 58 | "if_devide_out": True, 59 | "use_middle_cls_token": True, 60 | "bimamba_type": "v2" 61 | }, 62 | "vit-t": { 63 | "img_size": 512, 64 | "patch_size": 16, 65 | "in_chans": 3, 66 | "num_classes": 2, 67 | "embed_dim": 192, 68 | "depth": 12, 69 | "num_heads": 
3, 70 | "mlp_ratio": 4, 71 | "qkv_bias": True, 72 | "qk_scale": None, 73 | "drop_rate": 0.0, 74 | "attn_drop_rate": 0.0, 75 | "drop_path_rate": 0.1, 76 | "norm_layer": "nn.LayerNorm", 77 | "eps": 1e-6 78 | }, 79 | "vit-s": { 80 | "img_size": 512, 81 | "patch_size": 16, 82 | "in_chans": 3, 83 | "num_classes": 2, 84 | "embed_dim": 384, 85 | "depth": 12, 86 | "num_heads": 6, 87 | "mlp_ratio": 4, 88 | "qkv_bias": True, 89 | "qk_scale": None, 90 | "drop_rate": 0.0, 91 | "attn_drop_rate": 0.0, 92 | "drop_path_rate": 0.1, 93 | "norm_layer": "nn.LayerNorm", 94 | "eps": 1e-6 95 | } 96 | } -------------------------------------------------------------------------------- /dino/mamba-1p1p1/.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | *.egg-info/ 3 | build/ 4 | **.so 5 | -------------------------------------------------------------------------------- /dino/mamba-1p1p1/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/lm-evaluation-harness"] 2 | path = 3rdparty/lm-evaluation-harness 3 | url = https://github.com/EleutherAI/lm-evaluation-harness/ 4 | -------------------------------------------------------------------------------- /dino/mamba-1p1p1/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | Albert Gu, agu@andrew.cmu.edu 3 | -------------------------------------------------------------------------------- /dino/mamba-1p1p1/benchmarks/benchmark_generation_mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 2 | 3 | import argparse 4 | import time 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | 14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Generation benchmarking") 18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") 19 | parser.add_argument("--prompt", type=str, default=None) 20 | parser.add_argument("--promptlen", type=int, default=100) 21 | parser.add_argument("--genlen", type=int, default=100) 22 | parser.add_argument("--temperature", type=float, default=1.0) 23 | parser.add_argument("--topk", type=int, default=1) 24 | parser.add_argument("--topp", type=float, default=1.0) 25 | parser.add_argument("--repetition-penalty", type=float, default=1.0) 26 | parser.add_argument("--batch", type=int, default=1) 27 | args = parser.parse_args() 28 | 29 | repeats = 3 30 | device = "cuda" 31 | dtype = torch.float16 32 | 33 | print(f"Loading model {args.model_name}") 34 | is_mamba = args.model_name.startswith("state-spaces/mamba-") 35 | if is_mamba: 36 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 37 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) 38 | else: 39 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 40 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) 41 | model.eval() 42 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 43 | 44 | torch.random.manual_seed(0) 45 | if args.prompt is None: 46 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, 
device="cuda") 47 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") 48 | else: 49 | tokens = tokenizer(args.prompt, return_tensors="pt") 50 | input_ids = tokens.input_ids.to(device=device) 51 | attn_mask = tokens.attention_mask.to(device=device) 52 | max_length = input_ids.shape[1] + args.genlen 53 | 54 | if is_mamba: 55 | fn = lambda: model.generate( 56 | input_ids=input_ids, 57 | max_length=max_length, 58 | cg=True, 59 | return_dict_in_generate=True, 60 | output_scores=True, 61 | enable_timing=False, 62 | temperature=args.temperature, 63 | top_k=args.topk, 64 | top_p=args.topp, 65 | repetition_penalty=args.repetition_penalty, 66 | ) 67 | else: 68 | fn = lambda: model.generate( 69 | input_ids=input_ids, 70 | attention_mask=attn_mask, 71 | max_length=max_length, 72 | return_dict_in_generate=True, 73 | pad_token_id=tokenizer.eos_token_id, 74 | do_sample=True, 75 | temperature=args.temperature, 76 | top_k=args.topk, 77 | top_p=args.topp, 78 | repetition_penalty=args.repetition_penalty, 79 | ) 80 | out = fn() 81 | if args.prompt is not None: 82 | print(tokenizer.batch_decode(out.sequences.tolist())) 83 | 84 | torch.cuda.synchronize() 85 | start = time.time() 86 | for _ in range(repeats): 87 | fn() 88 | torch.cuda.synchronize() 89 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") 90 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") 91 | -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/selective_scan.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, dstate, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | bool is_variable_B; 32 | bool is_variable_C; 33 | 34 | bool delta_softplus; 35 | 36 | index_t A_d_stride; 37 | index_t A_dstate_stride; 38 | index_t B_batch_stride; 39 | index_t B_d_stride; 40 | index_t B_dstate_stride; 41 | index_t B_group_stride; 42 | index_t C_batch_stride; 43 | index_t C_d_stride; 44 | index_t C_dstate_stride; 45 | index_t C_group_stride; 46 | index_t u_batch_stride; 47 | index_t u_d_stride; 48 | index_t delta_batch_stride; 49 | index_t delta_d_stride; 50 | index_t z_batch_stride; 51 | index_t z_d_stride; 52 | index_t out_batch_stride; 53 | index_t out_d_stride; 54 | index_t out_z_batch_stride; 55 | index_t out_z_d_stride; 56 | 57 | // Common data pointers. 
58 | void *__restrict__ A_ptr; 59 | void *__restrict__ B_ptr; 60 | void *__restrict__ C_ptr; 61 | void *__restrict__ D_ptr; 62 | void *__restrict__ u_ptr; 63 | void *__restrict__ delta_ptr; 64 | void *__restrict__ delta_bias_ptr; 65 | void *__restrict__ out_ptr; 66 | void *__restrict__ x_ptr; 67 | void *__restrict__ z_ptr; 68 | void *__restrict__ out_z_ptr; 69 | }; 70 | 71 | struct SSMParamsBwd: public SSMParamsBase { 72 | index_t dout_batch_stride; 73 | index_t dout_d_stride; 74 | index_t dA_d_stride; 75 | index_t dA_dstate_stride; 76 | index_t dB_batch_stride; 77 | index_t dB_group_stride; 78 | index_t dB_d_stride; 79 | index_t dB_dstate_stride; 80 | index_t dC_batch_stride; 81 | index_t dC_group_stride; 82 | index_t dC_d_stride; 83 | index_t dC_dstate_stride; 84 | index_t du_batch_stride; 85 | index_t du_d_stride; 86 | index_t dz_batch_stride; 87 | index_t dz_d_stride; 88 | index_t ddelta_batch_stride; 89 | index_t ddelta_d_stride; 90 | 91 | // Common data pointers. 92 | void *__restrict__ dout_ptr; 93 | void *__restrict__ dA_ptr; 94 | void *__restrict__ dB_ptr; 95 | void *__restrict__ dC_ptr; 96 | void *__restrict__ dD_ptr; 97 | void *__restrict__ du_ptr; 98 | void *__restrict__ dz_ptr; 99 | void *__restrict__ ddelta_ptr; 100 | void *__restrict__ ddelta_bias_ptr; 101 | }; 102 | -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_bf16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp32.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /dino/mamba-1p1p1/csrc/selective_scan/uninitialized_copy.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | #include <cub/config.cuh> 31 | 32 | #include <cuda/std/type_traits> 33 | 34 | 35 | namespace detail 36 | { 37 | 38 | #if defined(_NVHPC_CUDA) 39 | template <typename T, typename U> 40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 41 | { 42 | // NVBug 3384810 43 | new (ptr) T(::cuda::std::forward<U>(val)); 44 | } 45 | #else 46 | template <typename T, 47 | typename U, 48 | typename ::cuda::std::enable_if< 49 | ::cuda::std::is_trivially_copyable<T>::value, 50 | int 51 | >::type = 0> 52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 53 | { 54 | *ptr = ::cuda::std::forward<U>(val); 55 | } 56 | 57 | template <typename T, 58 | typename U, 59 | typename ::cuda::std::enable_if< 60 | !::cuda::std::is_trivially_copyable<T>::value, 61 | int 62 | >::type = 0> 63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 64 | { 65 | new (ptr) T(::cuda::std::forward<U>(val)); 66 | } 67 | #endif 68 | 69 | } // namespace detail 70 | -------------------------------------------------------------------------------- /dino/mamba-1p1p1/evals/lm_harness_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import transformers 4 | from transformers import AutoTokenizer 5 | 6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 7 | 8 | from lm_eval.api.model import LM 9 | from lm_eval.models.huggingface import HFLM 10 | from lm_eval.api.registry import register_model 11 | from lm_eval.__main__ import cli_evaluate 12 | 13 | 14 | @register_model("mamba") 15 | class MambaEvalWrapper(HFLM): 16 | 17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM 18 | 19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda", 20 | dtype=torch.float16): 21 | LM.__init__(self) 22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype) 23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 25 | self.vocab_size = self.tokenizer.vocab_size 26 | self._batch_size = int(batch_size) if batch_size is not None else 64 27 | self._max_length = max_length 28 | self._device = torch.device(device) 29 | 30 | @property 31 | def batch_size(self): 32 | return self._batch_size 33 | 34 | def _model_generate(self, context, max_length, stop, **generation_kwargs): 35 | raise NotImplementedError() 36 | 37 | 38 | if __name__ == "__main__": 39 | cli_evaluate() 40 | -------------------------------------------------------------------------------- /dino/mamba-1p1p1/mamba_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.1" 2 | 3 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn 4 | from mamba_ssm.modules.mamba_simple import Mamba 5 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 6 | -------------------------------------------------------------------------------- /dino/mamba-1p1p1/mamba_ssm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/mamba-1p1p1/mamba_ssm/models/__init__.py -------------------------------------------------------------------------------- /dino/mamba-1p1p1/mamba_ssm/models/config_mamba.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class MambaConfig: 6 | 7 | d_model: int = 2560 8 | n_layer: int = 64 9 | vocab_size: int = 50277 10 | ssm_cfg: dict
= field(default_factory=dict) 11 | rms_norm: bool = True 12 | residual_in_fp32: bool = True 13 | fused_add_norm: bool = True 14 | pad_vocab_size_multiple: int = 8 15 | -------------------------------------------------------------------------------- /dino/mamba-1p1p1/mamba_ssm/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/mamba-1p1p1/mamba_ssm/modules/__init__.py -------------------------------------------------------------------------------- /dino/mamba-1p1p1/mamba_ssm/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/mamba-1p1p1/mamba_ssm/ops/__init__.py -------------------------------------------------------------------------------- /dino/mamba-1p1p1/mamba_ssm/ops/triton/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/mamba-1p1p1/mamba_ssm/ops/triton/__init__.py -------------------------------------------------------------------------------- /dino/mamba-1p1p1/mamba_ssm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/mamba-1p1p1/mamba_ssm/utils/__init__.py -------------------------------------------------------------------------------- /dino/mamba-1p1p1/mamba_ssm/utils/hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | 5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 6 | from transformers.utils.hub import cached_file 7 | 8 | 9 | def load_config_hf(model_name): 10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 11 | return json.load(open(resolved_archive_file)) 12 | 13 | 14 | def load_state_dict_hf(model_name, device=None, dtype=None): 15 | # If not fp32, then we don't want to load directly to the GPU 16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device 17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) 18 | state_dict = torch.load(resolved_archive_file, map_location=mapped_device) 19 | # Convert dtype before moving to GPU to save memory 20 | if dtype is not None: 21 | state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} 22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()} 23 | return state_dict 24 | -------------------------------------------------------------------------------- /dino/mamba-1p1p1/tests/ops/triton/test_selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao.
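# Editor's note (annotation, not upstream code): the test below runs the fused
# Triton kernel `selective_state_update` and the eager PyTorch reference
# `selective_state_update_ref` on identical inputs, then checks that both the
# updated SSM state and the output match within dtype-dependent tolerances
# (loosened for fp16 and further for bf16), over a grid of dim/dstate/has_z.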
2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import pytest 8 | 9 | from einops import rearrange 10 | 11 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 15 | # @pytest.mark.parametrize('itype', [torch.float16]) 16 | @pytest.mark.parametrize("has_z", [False, True]) 17 | # @pytest.mark.parametrize('has_z', [True]) 18 | @pytest.mark.parametrize("dstate", [16, 32, 64]) 19 | # @pytest.mark.parametrize("dstate", [16]) 20 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 21 | # @pytest.mark.parametrize("dim", [2048]) 22 | def test_selective_state_update(dim, dstate, has_z, itype): 23 | device = "cuda" 24 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) 25 | if itype == torch.bfloat16: 26 | rtol, atol = 1e-2, 5e-2 27 | # set seed 28 | torch.random.manual_seed(0) 29 | batch_size = 2 30 | state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) 31 | x = torch.randn(batch_size, dim, device=device, dtype=itype) 32 | dt = torch.randn(batch_size, dim, device=device, dtype=itype) 33 | dt_bias = torch.rand(dim, device=device) - 4.0 34 | A = -torch.rand(dim, dstate, device=device) - 1.0 35 | B = torch.randn(batch_size, dstate, device=device) 36 | C = torch.randn(batch_size, dstate, device=device) 37 | D = torch.randn(dim, device=device) 38 | if has_z: 39 | z = torch.randn_like(x) 40 | else: 41 | z = None 42 | state_ref = state.detach().clone() 43 | out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 44 | out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 45 | 46 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 47 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 48 | assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) 49 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 50 | -------------------------------------------------------------------------------- /dino/run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | A script to run multinode training with submitit.
16 | Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py 17 | """ 18 | import argparse 19 | import os 20 | import uuid 21 | from pathlib import Path 22 | 23 | import main_dino 24 | import submitit 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser("Submitit for DINO", parents=[main_dino.get_args_parser()]) 29 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 30 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 31 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job, in minutes") 32 | 33 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 34 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 35 | parser.add_argument('--comment', default="", type=str, 36 | help='Comment to pass to scheduler, e.g. priority message') 37 | return parser.parse_args() 38 | 39 | 40 | def get_shared_folder() -> Path: 41 | user = os.getenv("USER") 42 | if Path("/checkpoint/").is_dir(): 43 | p = Path(f"/checkpoint/{user}/experiments") 44 | p.mkdir(exist_ok=True) 45 | return p 46 | raise RuntimeError("No shared folder available") 47 | 48 | 49 | def get_init_file(): 50 | # Init file must not exist, but its parent dir must exist. 51 | os.makedirs(str(get_shared_folder()), exist_ok=True) 52 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 53 | if init_file.exists(): 54 | os.remove(str(init_file)) 55 | return init_file 56 | 57 | 58 | class Trainer(object): 59 | def __init__(self, args): 60 | self.args = args 61 | 62 | def __call__(self): 63 | import main_dino 64 | 65 | self._setup_gpu_args() 66 | main_dino.train_dino(self.args) 67 | 68 | def checkpoint(self): 69 | import os 70 | import submitit 71 | 72 | self.args.dist_url = get_init_file().as_uri() 73 | print("Requeuing ", self.args) 74 | empty_trainer = type(self)(self.args) 75 | return submitit.helpers.DelayedSubmission(empty_trainer) 76 | 77 | def _setup_gpu_args(self): 78 | import submitit 79 | from pathlib import Path 80 | 81 | job_env = submitit.JobEnvironment() 82 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 83 | self.args.gpu = job_env.local_rank 84 | self.args.rank = job_env.global_rank 85 | self.args.world_size = job_env.num_tasks 86 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 87 | 88 | 89 | def main(): 90 | args = parse_args() 91 | if args.output_dir == "": 92 | args.output_dir = get_shared_folder() / "%j" 93 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 94 | executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) 95 | 96 | num_gpus_per_node = args.ngpus 97 | nodes = args.nodes 98 | timeout_min = args.timeout 99 | 100 | partition = args.partition 101 | kwargs = {} 102 | if args.use_volta32: 103 | kwargs['slurm_constraint'] = 'volta32gb' 104 | if args.comment: 105 | kwargs['slurm_comment'] = args.comment 106 | 107 | executor.update_parameters( 108 | mem_gb=40 * num_gpus_per_node, 109 | gpus_per_node=num_gpus_per_node, 110 | tasks_per_node=num_gpus_per_node, # one task per GPU 111 | cpus_per_task=10, 112 | nodes=nodes, 113 | timeout_min=timeout_min, # max is 60 * 72 114 | # Below are cluster dependent parameters 115 | slurm_partition=partition, 116 | slurm_signal_delay_s=120, 117 | **kwargs 118 | ) 119 | 120 | executor.update_parameters(name="dino") 121 | 122 |
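    # Editor's note (annotation, not upstream code): the uuid-named init file
    # created by get_init_file() doubles as the torch.distributed rendezvous
    # point via its file:// URI; Trainer.checkpoint() regenerates it on
    # requeue so a restarted job never reuses a stale rendezvous file.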
args.dist_url = get_init_file().as_uri() 123 | 124 | trainer = Trainer(args) 125 | job = executor.submit(trainer) 126 | 127 | print(f"Submitted job_id: {job.job_id}") 128 | print(f"Logs and checkpoints will be saved at: {args.output_dir}") 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /dino/vim/.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | **/__pycache__/** 3 | imnet_resnet50_scratch/timm_temp/ 4 | .dumbo.json 5 | checkpoints/ -------------------------------------------------------------------------------- /dino/vim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/dino/vim/__init__.py -------------------------------------------------------------------------------- /dino/vim/augment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | """ 5 | 3Augment implementation 6 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino) 7 | and timm DA(https://github.com/rwightman/pytorch-image-models) 8 | """ 9 | import torch 10 | from torchvision import transforms 11 | 12 | # error: cannot import name '_pil_interp' from 'timm.data.transforms' 13 | # from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor 14 | 15 | # fix: timm version problem 16 | # from timm.data.transforms import str_pil_interp as _pil_interp 17 | from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor 18 | 19 | import numpy as np 20 | from torchvision import datasets, transforms 21 | import random 22 | 23 | 24 | 25 | from PIL import ImageFilter, ImageOps 26 | import torchvision.transforms.functional as TF 27 | 28 | 29 | class GaussianBlur(object): 30 | """ 31 | Apply Gaussian Blur to the PIL image. 32 | """ 33 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.): 34 | self.prob = p 35 | self.radius_min = radius_min 36 | self.radius_max = radius_max 37 | 38 | def __call__(self, img): 39 | do_it = random.random() <= self.prob 40 | if not do_it: 41 | return img 42 | 43 | img = img.filter( 44 | ImageFilter.GaussianBlur( 45 | radius=random.uniform(self.radius_min, self.radius_max) 46 | ) 47 | ) 48 | return img 49 | 50 | class Solarization(object): 51 | """ 52 | Apply Solarization to the PIL image. 53 | """ 54 | def __init__(self, p=0.2): 55 | self.p = p 56 | 57 | def __call__(self, img): 58 | if random.random() < self.p: 59 | return ImageOps.solarize(img) 60 | else: 61 | return img 62 | 63 | class gray_scale(object): 64 | """ 65 | Apply Grayscale to the PIL image. 66 | """ 67 | def __init__(self, p=0.2): 68 | self.p = p 69 | self.transf = transforms.Grayscale(3) 70 | 71 | def __call__(self, img): 72 | if random.random() < self.p: 73 | return self.transf(img) 74 | else: 75 | return img 76 | 77 | 78 | 79 | class horizontal_flip(object): 80 | """ 81 | Apply horizontal flip to the PIL image.
82 | """ 83 | def __init__(self, p=0.2,activate_pred=False): 84 | self.p = p 85 | self.transf = transforms.RandomHorizontalFlip(p=1.0) 86 | 87 | def __call__(self, img): 88 | if random.random() < self.p: 89 | return self.transf(img) 90 | else: 91 | return img 92 | 93 | 94 | 95 | def new_data_aug_generator(args = None): 96 | img_size = args.input_size 97 | remove_random_resized_crop = args.src 98 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 99 | primary_tfl = [] 100 | scale=(0.08, 1.0) 101 | interpolation='bicubic' 102 | if remove_random_resized_crop: 103 | primary_tfl = [ 104 | transforms.Resize(img_size, interpolation=3), 105 | transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'), 106 | transforms.RandomHorizontalFlip() 107 | ] 108 | else: 109 | primary_tfl = [ 110 | RandomResizedCropAndInterpolation( 111 | img_size, scale=scale, interpolation=interpolation), 112 | transforms.RandomHorizontalFlip() 113 | ] 114 | 115 | 116 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0), 117 | Solarization(p=1.0), 118 | GaussianBlur(p=1.0)])] 119 | 120 | if args.color_jitter is not None and not args.color_jitter==0: 121 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter)) 122 | final_tfl = [ 123 | transforms.ToTensor(), 124 | transforms.Normalize( 125 | mean=torch.tensor(mean), 126 | std=torch.tensor(std)) 127 | ] 128 | return transforms.Compose(primary_tfl+secondary_tfl+final_tfl) 129 | -------------------------------------------------------------------------------- /dino/vim/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import os 4 | import json 5 | 6 | from torchvision import datasets, transforms 7 | from torchvision.datasets.folder import ImageFolder, default_loader 8 | 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | from timm.data import create_transform 11 | 12 | 13 | class INatDataset(ImageFolder): 14 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 15 | category='name', loader=default_loader): 16 | self.transform = transform 17 | self.loader = loader 18 | self.target_transform = target_transform 19 | self.year = year 20 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 21 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 22 | with open(path_json) as json_file: 23 | data = json.load(json_file) 24 | 25 | with open(os.path.join(root, 'categories.json')) as json_file: 26 | data_catg = json.load(json_file) 27 | 28 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 29 | 30 | with open(path_json_for_targeter) as json_file: 31 | data_for_targeter = json.load(json_file) 32 | 33 | targeter = {} 34 | indexer = 0 35 | for elem in data_for_targeter['annotations']: 36 | king = [] 37 | king.append(data_catg[int(elem['category_id'])][category]) 38 | if king[0] not in targeter.keys(): 39 | targeter[king[0]] = indexer 40 | indexer += 1 41 | self.nb_classes = len(targeter) 42 | 43 | self.samples = [] 44 | for elem in data['images']: 45 | cut = elem['file_name'].split('/') 46 | target_current = int(cut[2]) 47 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 48 | 49 | categors = data_catg[target_current] 50 | target_current_true = targeter[categors[category]] 51 | self.samples.append((path_current, target_current_true)) 52 | 
53 | # __getitem__ and __len__ inherited from ImageFolder 54 | 55 | 56 | def build_dataset(is_train, args): 57 | transform = build_transform(is_train, args) 58 | 59 | if args.data_set == 'CIFAR': 60 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 61 | nb_classes = 100 62 | elif args.data_set == 'IMNET': 63 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 64 | dataset = datasets.ImageFolder(root, transform=transform) 65 | nb_classes = 1000 66 | elif args.data_set == 'INAT': 67 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 68 | category=args.inat_category, transform=transform) 69 | nb_classes = dataset.nb_classes 70 | elif args.data_set == 'INAT19': 71 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 72 | category=args.inat_category, transform=transform) 73 | nb_classes = dataset.nb_classes 74 | 75 | return dataset, nb_classes 76 | 77 | 78 | def build_transform(is_train, args): 79 | resize_im = args.input_size > 32 80 | if is_train: 81 | # this should always dispatch to transforms_imagenet_train 82 | transform = create_transform( 83 | input_size=args.input_size, 84 | is_training=True, 85 | color_jitter=args.color_jitter, 86 | auto_augment=args.aa, 87 | interpolation=args.train_interpolation, 88 | re_prob=args.reprob, 89 | re_mode=args.remode, 90 | re_count=args.recount, 91 | ) 92 | if not resize_im: 93 | # replace RandomResizedCropAndInterpolation with 94 | # RandomCrop 95 | transform.transforms[0] = transforms.RandomCrop( 96 | args.input_size, padding=4) 97 | return transform 98 | 99 | t = [] 100 | if resize_im: 101 | size = int(args.input_size / args.eval_crop_ratio) 102 | t.append( 103 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 104 | ) 105 | t.append(transforms.CenterCrop(args.input_size)) 106 | 107 | t.append(transforms.ToTensor()) 108 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 109 | return transforms.Compose(t) 110 | -------------------------------------------------------------------------------- /dino/vim/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
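# Editor's note (annotation, not upstream code): this file follows the DeiT
# training engine; train_one_epoch below additionally threads the Vim-specific
# flags if_random_cls_token_position and if_random_token_rank through to the
# model's forward pass.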
3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | 10 | import torch 11 | 12 | import timm 13 | from timm.data import Mixup 14 | from timm.utils import accuracy, ModelEma 15 | 16 | from vim.losses import DistillationLoss 17 | from vim import utils 18 | 19 | 20 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 21 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 22 | device: torch.device, epoch: int, loss_scaler, amp_autocast, max_norm: float = 0, 23 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 24 | set_training_mode=True, args = None): 25 | model.train(set_training_mode) 26 | metric_logger = utils.MetricLogger(delimiter=" ") 27 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 28 | header = 'Epoch: [{}]'.format(epoch) 29 | print_freq = 10 30 | 31 | if args.cosub: 32 | criterion = torch.nn.BCEWithLogitsLoss() 33 | 34 | # debug 35 | # count = 0 36 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 37 | # count += 1 38 | # if count > 20: 39 | # break 40 | 41 | samples = samples.to(device, non_blocking=True) 42 | targets = targets.to(device, non_blocking=True) 43 | 44 | if mixup_fn is not None: 45 | samples, targets = mixup_fn(samples, targets) 46 | 47 | if args.cosub: 48 | samples = torch.cat((samples,samples),dim=0) 49 | 50 | if args.bce_loss: 51 | targets = targets.gt(0.0).type(targets.dtype) 52 | 53 | with amp_autocast(): 54 | outputs = model(samples, if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank) 55 | # outputs = model(samples) 56 | if not args.cosub: 57 | loss = criterion(samples, outputs, targets) 58 | else: 59 | outputs = torch.split(outputs, outputs.shape[0]//2, dim=0) 60 | loss = 0.25 * criterion(outputs[0], targets) 61 | loss = loss + 0.25 * criterion(outputs[1], targets) 62 | loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid()) 63 | loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid()) 64 | 65 | if args.if_nan2num: 66 | with amp_autocast(): 67 | loss = torch.nan_to_num(loss) 68 | 69 | loss_value = loss.item() 70 | 71 | if not math.isfinite(loss_value): 72 | print("Loss is {}, stopping training".format(loss_value)) 73 | if args.if_continue_inf: 74 | optimizer.zero_grad() 75 | continue 76 | else: 77 | sys.exit(1) 78 | 79 | optimizer.zero_grad() 80 | 81 | # this attribute is added by timm on one optimizer (adahessian) 82 | if isinstance(loss_scaler, timm.utils.NativeScaler): 83 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 84 | loss_scaler(loss, optimizer, clip_grad=max_norm, 85 | parameters=model.parameters(), create_graph=is_second_order) 86 | else: 87 | loss.backward() 88 | if max_norm != None: 89 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 90 | optimizer.step() 91 | 92 | torch.cuda.synchronize() 93 | if model_ema is not None: 94 | model_ema.update(model) 95 | 96 | metric_logger.update(loss=loss_value) 97 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 98 | # gather the stats from all processes 99 | metric_logger.synchronize_between_processes() 100 | print("Averaged stats:", metric_logger) 101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 102 | 103 | 104 | @torch.no_grad() 105 | def evaluate(data_loader, model, device, amp_autocast): 106 | criterion = 
torch.nn.CrossEntropyLoss() 107 | 108 | metric_logger = utils.MetricLogger(delimiter=" ") 109 | header = 'Test:' 110 | 111 | # switch to evaluation mode 112 | model.eval() 113 | 114 | for images, target in metric_logger.log_every(data_loader, 10, header): 115 | images = images.to(device, non_blocking=True) 116 | target = target.to(device, non_blocking=True) 117 | 118 | # compute output 119 | with amp_autocast(): 120 | output = model(images) 121 | loss = criterion(output, target) 122 | 123 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 124 | 125 | batch_size = images.shape[0] 126 | metric_logger.update(loss=loss.item()) 127 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 128 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 129 | # gather the stats from all processes 130 | metric_logger.synchronize_between_processes() 131 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 132 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 133 | 134 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 135 | -------------------------------------------------------------------------------- /dino/vim/hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | from models import * 4 | from cait_models import * 5 | from resmlp_models import * 6 | #from patchconvnet_models import * 7 | 8 | dependencies = ["torch", "torchvision", "timm"] 9 | -------------------------------------------------------------------------------- /dino/vim/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | 10 | class DistillationLoss(torch.nn.Module): 11 | """ 12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 13 | taking a teacher model prediction and using it as additional supervision. 14 | """ 15 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 16 | distillation_type: str, alpha: float, tau: float): 17 | super().__init__() 18 | self.base_criterion = base_criterion 19 | self.teacher_model = teacher_model 20 | assert distillation_type in ['none', 'soft', 'hard'] 21 | self.distillation_type = distillation_type 22 | self.alpha = alpha 23 | self.tau = tau 24 | 25 | def forward(self, inputs, outputs, labels): 26 | """ 27 | Args: 28 | inputs: The original inputs that are fed to the teacher model 29 | outputs: the outputs of the model to be trained.
It is expected to be 30 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 31 | in the first position and the distillation predictions as the second output 32 | labels: the labels for the base criterion 33 | """ 34 | outputs_kd = None 35 | if not isinstance(outputs, torch.Tensor): 36 | # assume that the model outputs a tuple of [outputs, outputs_kd] 37 | outputs, outputs_kd = outputs 38 | base_loss = self.base_criterion(outputs, labels) 39 | if self.distillation_type == 'none': 40 | return base_loss 41 | 42 | if outputs_kd is None: 43 | raise ValueError("When knowledge distillation is enabled, the model is " 44 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 45 | "class_token and the dist_token") 46 | # don't backprop through the teacher 47 | with torch.no_grad(): 48 | teacher_outputs = self.teacher_model(inputs) 49 | 50 | if self.distillation_type == 'soft': 51 | T = self.tau 52 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 53 | # with slight modifications 54 | distillation_loss = F.kl_div( 55 | F.log_softmax(outputs_kd / T, dim=1), 56 | #We provide the teacher's targets in log probability because we use log_target=True 57 | #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719) 58 | #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both. 59 | F.log_softmax(teacher_outputs / T, dim=1), 60 | reduction='sum', 61 | log_target=True 62 | ) * (T * T) / outputs_kd.numel() 63 | #We divide by outputs_kd.numel() to have the legacy PyTorch behavior. 64 | #But we also experimented with output_kd.size(0) 65 | #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details 66 | elif self.distillation_type == 'hard': 67 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 68 | 69 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 70 | return loss 71 | -------------------------------------------------------------------------------- /dino/vim/rope.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # EVA-02: A Visual Representation for Neon Genesis 3 | # Github source: https://github.com/baaivision/EVA/EVA02 4 | # Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI) 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Yuxin Fang 7 | # 8 | # Based on https://github.com/lucidrains/rotary-embedding-torch 9 | # --------------------------------------------------------' 10 | 11 | from math import pi 12 | 13 | import torch 14 | from torch import nn 15 | 16 | from einops import rearrange, repeat 17 | 18 | 19 | 20 | def broadcat(tensors, dim = -1): 21 | num_tensors = len(tensors) 22 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 23 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' 24 | shape_len = list(shape_lens)[0] 25 | dim = (dim + shape_len) if dim < 0 else dim 26 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 27 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 28 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation' 29 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 30 | expanded_dims = list(map(lambda
t: (t[0], (t[1],) * num_tensors), max_dims)) 31 | expanded_dims.insert(dim, (dim, dims[dim])) 32 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 33 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 34 | return torch.cat(tensors, dim = dim) 35 | 36 | 37 | 38 | def rotate_half(x): 39 | x = rearrange(x, '... (d r) -> ... d r', r = 2) 40 | x1, x2 = x.unbind(dim = -1) 41 | x = torch.stack((-x2, x1), dim = -1) 42 | return rearrange(x, '... d r -> ... (d r)') 43 | 44 | 45 | 46 | class VisionRotaryEmbedding(nn.Module): 47 | def __init__( 48 | self, 49 | dim, 50 | pt_seq_len, 51 | ft_seq_len=None, 52 | custom_freqs = None, 53 | freqs_for = 'lang', 54 | theta = 10000, 55 | max_freq = 10, 56 | num_freqs = 1, 57 | ): 58 | super().__init__() 59 | if custom_freqs: 60 | freqs = custom_freqs 61 | elif freqs_for == 'lang': 62 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 63 | elif freqs_for == 'pixel': 64 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 65 | elif freqs_for == 'constant': 66 | freqs = torch.ones(num_freqs).float() 67 | else: 68 | raise ValueError(f'unknown modality {freqs_for}') 69 | 70 | if ft_seq_len is None: ft_seq_len = pt_seq_len 71 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 72 | 73 | freqs_h = torch.einsum('..., f -> ... f', t, freqs) 74 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) 75 | 76 | freqs_w = torch.einsum('..., f -> ... f', t, freqs) 77 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) 78 | 79 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) 80 | 81 | self.register_buffer("freqs_cos", freqs.cos()) 82 | self.register_buffer("freqs_sin", freqs.sin()) 83 | 84 | print('======== shape of rope freq', self.freqs_cos.shape, '========') 85 | 86 | def forward(self, t, start_index = 0): 87 | rot_dim = self.freqs_cos.shape[-1] 88 | end_index = start_index + rot_dim 89 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 90 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 91 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 92 | return torch.cat((t_left, t, t_right), dim = -1) 93 | 94 | 95 | 96 | class VisionRotaryEmbeddingFast(nn.Module): 97 | def __init__( 98 | self, 99 | dim, 100 | pt_seq_len=16, 101 | ft_seq_len=None, 102 | custom_freqs = None, 103 | freqs_for = 'lang', 104 | theta = 10000, 105 | max_freq = 10, 106 | num_freqs = 1, 107 | ): 108 | super().__init__() 109 | if custom_freqs: 110 | freqs = custom_freqs 111 | elif freqs_for == 'lang': 112 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 113 | elif freqs_for == 'pixel': 114 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 115 | elif freqs_for == 'constant': 116 | freqs = torch.ones(num_freqs).float() 117 | else: 118 | raise ValueError(f'unknown modality {freqs_for}') 119 | 120 | if ft_seq_len is None: ft_seq_len = pt_seq_len 121 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 122 | 123 | freqs = torch.einsum('..., f -> ... f', t, freqs) 124 | freqs = repeat(freqs, '... n -> ... 
(n r)', r = 2) 125 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) 126 | 127 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 128 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 129 | 130 | self.register_buffer("freqs_cos", freqs_cos) 131 | self.register_buffer("freqs_sin", freqs_sin) 132 | 133 | print('======== shape of rope freq', self.freqs_cos.shape, '========') 134 | 135 | def forward(self, t): 136 | if t.shape[1] % 2 != 0: 137 | t_spatial = t[:, 1:, :] 138 | t_spatial = t_spatial * self.freqs_cos + rotate_half(t_spatial) * self.freqs_sin 139 | return torch.cat((t[:, :1, :], t_spatial), dim=1) 140 | else: 141 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin -------------------------------------------------------------------------------- /dino/vim/run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | A script to run multinode training with submitit. 5 | """ 6 | import argparse 7 | import os 8 | import uuid 9 | from pathlib import Path 10 | 11 | import main as classification 12 | import submitit 13 | 14 | 15 | def parse_args(): 16 | classification_parser = classification.get_args_parser() 17 | parser = argparse.ArgumentParser("Submitit for DeiT", parents=[classification_parser]) 18 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 19 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 20 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job, in minutes") 21 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 22 | 23 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 24 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 25 | parser.add_argument('--comment', default="", type=str, 26 | help='Comment to pass to scheduler, e.g. priority message') 27 | return parser.parse_args() 28 | 29 | 30 | def get_shared_folder() -> Path: 31 | user = os.getenv("USER") 32 | if Path("/checkpoint/").is_dir(): 33 | p = Path(f"/checkpoint/{user}/experiments") 34 | p.mkdir(exist_ok=True) 35 | return p 36 | raise RuntimeError("No shared folder available") 37 | 38 | 39 | def get_init_file(): 40 | # Init file must not exist, but its parent dir must exist.
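    # Editor's note (annotation, not upstream code): makedirs(exist_ok=True)
    # below guarantees only the parent directory; the uuid4-named init file
    # itself is deleted if a stale one is somehow present, since it must not
    # exist when torch.distributed rendezvous starts.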
41 | os.makedirs(str(get_shared_folder()), exist_ok=True) 42 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 43 | if init_file.exists(): 44 | os.remove(str(init_file)) 45 | return init_file 46 | 47 | 48 | class Trainer(object): 49 | def __init__(self, args): 50 | self.args = args 51 | 52 | def __call__(self): 53 | import main as classification 54 | 55 | self._setup_gpu_args() 56 | classification.main(self.args) 57 | 58 | def checkpoint(self): 59 | import os 60 | import submitit 61 | 62 | self.args.dist_url = get_init_file().as_uri() 63 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 64 | if os.path.exists(checkpoint_file): 65 | self.args.resume = checkpoint_file 66 | print("Requeuing ", self.args) 67 | empty_trainer = type(self)(self.args) 68 | return submitit.helpers.DelayedSubmission(empty_trainer) 69 | 70 | def _setup_gpu_args(self): 71 | import submitit 72 | from pathlib import Path 73 | 74 | job_env = submitit.JobEnvironment() 75 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 76 | self.args.gpu = job_env.local_rank 77 | self.args.rank = job_env.global_rank 78 | self.args.world_size = job_env.num_tasks 79 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 80 | 81 | 82 | def main(): 83 | args = parse_args() 84 | if args.job_dir == "": 85 | args.job_dir = get_shared_folder() / "%j" 86 | 87 | # Note that the folder will depend on the job_id, to easily track experiments 88 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 89 | 90 | num_gpus_per_node = args.ngpus 91 | nodes = args.nodes 92 | timeout_min = args.timeout 93 | 94 | partition = args.partition 95 | kwargs = {} 96 | if args.use_volta32: 97 | kwargs['slurm_constraint'] = 'volta32gb' 98 | if args.comment: 99 | kwargs['slurm_comment'] = args.comment 100 | 101 | executor.update_parameters( 102 | mem_gb=40 * num_gpus_per_node, 103 | gpus_per_node=num_gpus_per_node, 104 | tasks_per_node=num_gpus_per_node, # one task per GPU 105 | cpus_per_task=10, 106 | nodes=nodes, 107 | timeout_min=timeout_min, # max is 60 * 72 108 | # Below are cluster dependent parameters 109 | slurm_partition=partition, 110 | slurm_signal_delay_s=120, 111 | **kwargs 112 | ) 113 | 114 | executor.update_parameters(name="deit") 115 | 116 | args.dist_url = get_init_file().as_uri() 117 | args.output_dir = args.job_dir 118 | 119 | trainer = Trainer(args) 120 | job = executor.submit(trainer) 121 | 122 | print("Submitted job_id:", job.job_id) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /dino/vim/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 
11 | It ensures that each augmented version of a sample will be visible to a 12 | different process (GPU) 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | if num_repeats < 1: 26 | raise ValueError("num_repeats should be greater than 0") 27 | self.dataset = dataset 28 | self.num_replicas = num_replicas 29 | self.rank = rank 30 | self.num_repeats = num_repeats 31 | self.epoch = 0 32 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas)) 33 | self.total_size = self.num_samples * self.num_replicas 34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 36 | self.shuffle = shuffle 37 | 38 | def __iter__(self): 39 | if self.shuffle: 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | indices = torch.randperm(len(self.dataset), generator=g) 44 | else: 45 | indices = torch.arange(start=0, end=len(self.dataset)) 46 | 47 | # add extra samples to make it evenly divisible 48 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist() 49 | padding_size: int = self.total_size - len(indices) 50 | if padding_size > 0: 51 | indices += indices[:padding_size] 52 | assert len(indices) == self.total_size 53 | 54 | # subsample 55 | indices = indices[self.rank:self.total_size:self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices[:self.num_selected_samples]) 59 | 60 | def __len__(self): 61 | return self.num_selected_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch 65 | -------------------------------------------------------------------------------- /dino/vim/scripts/ft-vim-s.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate 3 | cd /vim; 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --lr 5e-6 --min-lr 1e-5 --warmup-lr 1e-5 --drop-path 0.0 --weight-decay 1e-8 --num_workers 25 --data-path --output_dir ./output/vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --epochs 30 --finetune --no_amp 6 | -------------------------------------------------------------------------------- /dino/vim/scripts/ft-vim-t.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate 3 | cd /vim; 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model vim_tiny_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --lr 5e-6 --min-lr 1e-5 --warmup-lr 1e-5 --drop-path 0.0 --weight-decay 1e-8 --num_workers 25 --data-path --output_dir
./output/vim_tiny_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --epochs 30 --finetune --no_amp 6 | -------------------------------------------------------------------------------- /dino/vim/scripts/pt-vim-s.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate 3 | cd /vim; 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 64 --drop-path 0.05 --weight-decay 0.05 --lr 1e-3 --num_workers 25 --data-path --output_dir ./output/vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp 6 | -------------------------------------------------------------------------------- /dino/vim/scripts/pt-vim-t.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate 3 | cd /vim; 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-path --output_dir ./output/vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp 6 | -------------------------------------------------------------------------------- /media/Vim4Path.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AtlasAnalyticsLab/Vim4Path/5a85840a439fe9fa3f419e52b846b049b97df9d0/media/Vim4Path.webp -------------------------------------------------------------------------------- /patch_heatmaps.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | from natsort import os_sorted 4 | from dino.vision_transformer import DINOHead, VisionTransformer 5 | from dino.vim.models_mamba import VisionMamba 6 | from dino.config import configurations 7 | from dino.main import get_args_parser 8 | from functools import partial 9 | from dino.utils import load_pretrained_weights 10 | from torchvision import transforms 11 | from torch import nn 12 | import torch 13 | from PIL import Image 14 | import torchvision 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | from pytorch_grad_cam import GradCAM 18 | from pytorch_grad_cam.utils.image import show_cam_on_image 19 | 20 | 21 | from tqdm import tqdm 22 | import random 23 | import matplotlib.gridspec as gridspec 24 | import cv2 25 | 26 | 27 | def get_model(args): 28 | 29 | # look up the architecture once and fail fast on an unknown name 30 | if args.arch not in configurations: 31 | raise ValueError(f"Unknown architecture: {args.arch}") 32 | 33 | config = configurations[args.arch] 34 | config['img_size'] = args.image_size 35 | config['patch_size'] = args.patch_size 36 | config['num_classes'] = args.num_classes 37 | 38 | 39 | if 'norm_layer' in config and config['norm_layer'] == "nn.LayerNorm": 40 | config['norm_layer'] = partial(nn.LayerNorm, eps=config['eps']) 41 | config['drop_path_rate'] = 0 42 | if args.arch.startswith('vim'): 43 | model = VisionMamba(return_features=True, **config) 44 | embed_dim = model.embed_dim 45 | elif args.arch.startswith('vit'): 46 | model = VisionTransformer(**config) 47 | embed_dim = model.embed_dim * 
(args.n_last_blocks + int(args.avgpool_patchtokens)) 48 | print('EMBEDDED DIM:', embed_dim) 49 | else: 50 | raise ValueError(f"Unknown architecture: {args.arch}") 51 | return model 52 | 53 | 54 | dataset_dir = 'path_to_test_candidate_images' 55 | parser = get_args_parser() 56 | args = parser.parse_known_args()[0] 57 | 58 | val_transform = transforms.Compose([ 59 | transforms.Resize(args.image_size, interpolation=3), 60 | transforms.CenterCrop(args.image_size), 61 | transforms.ToTensor(), 62 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 63 | ]) 64 | 65 | def reshape_transform_vit(tensor, height=14, width=14): 66 | result = tensor[:, 1 : , :].reshape(tensor.size(0), 67 | height, width, tensor.size(2)) 68 | 69 | result = result.transpose(2, 3).transpose(1, 2) 70 | return result 71 | 72 | 73 | def reshape_transform_vim(tensor, height=14, width=14, token_position=98): 74 | hidden_state = tensor 75 | hidden_state = torch.cat((hidden_state[:, 1:token_position, :], hidden_state[:, token_position+1:, :]), dim=1) 76 | result = hidden_state.reshape(hidden_state.size(0), height, width, hidden_state.size(2)) 77 | result = result.transpose(2, 3).transpose(1, 2) 78 | return result 79 | 80 | args.image_size = 224 81 | args.patch_size = 16 82 | args.num_classes = 2 83 | args.n_last_blocks = 4 84 | args.avgpool_patchtokens = False 85 | 86 | args.checkpoint_key = 'teacher' 87 | 88 | args.arch = 'vim-s' 89 | args.pretrained_weights = '/home/ubuntu/checkpoints/camelyon16_224_10x/vim-s_224-96/checkpoint.pth' 90 | model_vim_s = get_model(args) 91 | model_vim_s.cuda() 92 | model_vim_s.eval() 93 | load_pretrained_weights(model_vim_s, args.pretrained_weights, 94 | args.checkpoint_key, args.arch, args.patch_size) 95 | 96 | args.arch = 'vit-s' 97 | args.pretrained_weights = '/home/ubuntu/checkpoints/camelyon16_224_10x/vit-s_224-96/checkpoint.pth' 98 | model_vit_s = get_model(args) 99 | model_vit_s.cuda() 100 | model_vit_s.eval() 101 | load_pretrained_weights(model_vit_s, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size) 102 | 103 | 104 | models = { 105 | 'Vim-s':(model_vim_s, model_vim_s.layers[-1].drop_path), 106 | 'ViT-s':(model_vit_s, model_vit_s.blocks[-1].norm1) 107 | } 108 | 109 | for class_name in ['tumor', 'normal']: 110 | 111 | img_paths = glob(os.path.join(dataset_dir, class_name, "*jpg")) 112 | img_paths = os_sorted(img_paths) 113 | target_image_idx = list(np.random.randint(0, len(img_paths), 60)) 114 | 115 | os.makedirs(f'heatmaps/heatmaps_diverse/{class_name}', exist_ok=True) 116 | os.makedirs(f'heatmaps/heatmaps_diverse/{class_name}/raw', exist_ok=True) 117 | 118 | for i in tqdm(target_image_idx): 119 | img = Image.open(img_paths[i]) 120 | img_transformed = val_transform(img).unsqueeze(0) 121 | img_show = img_transformed.cpu().squeeze().permute(1, 2, 0).numpy() 122 | img_show = (img_show - img_show.min()) / (img_show.max() - img_show.min()) # Normalize to [0,1] 123 | plt.figure(figsize=(50, 23)) 124 | gs = gridspec.GridSpec(2, 5) 125 | 126 | # Original image 127 | final_img = np.array(img)/255 128 | 129 | ax0 = plt.subplot(gs[0:2, 0:2]) 130 | ax0.imshow(final_img) 131 | ax0.set_title('Original Image', fontsize=40) 132 | ax0.axis('off') 133 | 134 | img_name = os.path.splitext(os.path.basename(img_paths[i]))[0] # filename stem used when saving the heatmaps below 135 | cams = [] 136 | for idx, (model_name, (model, target_layer)) in enumerate(models.items()): 137 | cam = GradCAM(model=model, target_layers=[target_layer], 138 | reshape_transform=reshape_transform_vim if 'mamba' in model.__class__.__name__.lower() else reshape_transform_vit) 139 | grayscale_cam = 
cam(input_tensor=img_transformed)[0, :] 140 | grayscale_cam = cv2.resize(grayscale_cam, (final_img.shape[1], final_img.shape[0])) # cv2.resize expects dsize as (width, height) 141 | cam_image = show_cam_on_image(final_img, grayscale_cam, use_rgb=True) 142 | 143 | ax = plt.subplot(gs[idx // 3, idx % 3 + 2]) 144 | ax.imshow(cam_image) 145 | Image.fromarray(cam_image).save(f'heatmaps/heatmaps_diverse/{class_name}/raw/{img_name}_{model_name}.png') 146 | ax.set_title(f'{model_name} Heatmap', fontsize=40) 147 | ax.axis('off') 148 | 149 | plt.tight_layout() 150 | plt.savefig(f'heatmaps/heatmaps_diverse/{class_name}/{img_name}.jpg', bbox_inches='tight', dpi=200) 151 | plt.close() 152 | 153 | -------------------------------------------------------------------------------- /preprocess/check_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image, ImageFile 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | def check_and_remove_images(directory): 7 | corrupted_count = 0 8 | ImageFile.LOAD_TRUNCATED_IMAGES = False # Do not load truncated images 9 | 10 | for subdir, dirs, files in os.walk(directory): 11 | for file in tqdm(files, position=0, leave=False): 12 | file_path = os.path.join(subdir, file) 13 | try: 14 | with Image.open(file_path) as img: 15 | img.load() # Attempt to load the image to catch truncation errors 16 | except (IOError, SyntaxError, ValueError) as e: 17 | print(f'Removing corrupted image: {file_path} with error {e}') 18 | os.remove(file_path) # Remove the corrupted image 19 | corrupted_count += 1 20 | return corrupted_count 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser(description="Check and remove corrupted images in a directory") 25 | parser.add_argument('--dir', type=str, help='The path to the directory containing the images') 26 | 27 | args = parser.parse_args() 28 | 29 | corrupted_images_count = check_and_remove_images(args.dir) 30 | print(f'Number of corrupted images removed: {corrupted_images_count}') 31 | 32 | if __name__ == '__main__': 33 | main() 34 | -------------------------------------------------------------------------------- /preprocess/datasets/dataset_h5.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | import numpy as np 5 | import pandas as pd 6 | import math 7 | import re 8 | import pdb 9 | import pickle 10 | 11 | from torch.utils.data import Dataset, DataLoader, sampler 12 | from torchvision import transforms, utils, models 13 | import torch.nn.functional as F 14 | 15 | from PIL import Image 16 | import h5py 17 | 18 | from random import randrange 19 | 20 | def eval_transforms(pretrained=False): 21 | if pretrained: 22 | mean = (0.485, 0.456, 0.406) 23 | std = (0.229, 0.224, 0.225) 24 | 25 | else: 26 | mean = (0.5,0.5,0.5) 27 | std = (0.5,0.5,0.5) 28 | 29 | trnsfrms_val = transforms.Compose( 30 | [ 31 | transforms.ToTensor(), 32 | transforms.Normalize(mean = mean, std = std) 33 | ] 34 | ) 35 | 36 | return trnsfrms_val 37 | 38 | class Whole_Slide_Bag(Dataset): 39 | def __init__(self, 40 | file_path, 41 | pretrained=False, 42 | custom_transforms=None, 43 | target_patch_size=-1, 44 | ): 45 | """ 46 | Args: 47 | file_path (string): Path to the .h5 file containing patched data. 
48 | pretrained (bool): Use ImageNet transforms 49 | custom_transforms (callable, optional): Optional transform to be applied on a sample 50 | """ 51 | self.pretrained=pretrained 52 | if target_patch_size > 0: 53 | self.target_patch_size = (target_patch_size, target_patch_size) 54 | else: 55 | self.target_patch_size = None 56 | 57 | if not custom_transforms: 58 | self.roi_transforms = eval_transforms(pretrained=pretrained) 59 | else: 60 | self.roi_transforms = custom_transforms 61 | 62 | self.file_path = file_path 63 | 64 | with h5py.File(self.file_path, "r") as f: 65 | dset = f['imgs'] 66 | self.length = len(dset) 67 | 68 | self.summary() 69 | 70 | def __len__(self): 71 | return self.length 72 | 73 | def summary(self): 74 | hdf5_file = h5py.File(self.file_path, "r") 75 | dset = hdf5_file['imgs'] 76 | for name, value in dset.attrs.items(): 77 | print(name, value) 78 | 79 | print('pretrained:', self.pretrained) 80 | print('transformations:', self.roi_transforms) 81 | if self.target_patch_size is not None: 82 | print('target_size: ', self.target_patch_size) 83 | 84 | def __getitem__(self, idx): 85 | with h5py.File(self.file_path,'r') as hdf5_file: 86 | img = hdf5_file['imgs'][idx] 87 | coord = hdf5_file['coords'][idx] 88 | 89 | img = Image.fromarray(img) 90 | if self.target_patch_size is not None: 91 | img = img.resize(self.target_patch_size) 92 | img = self.roi_transforms(img).unsqueeze(0) 93 | return img, coord 94 | 95 | class Whole_Slide_Bag_FP(Dataset): 96 | def __init__(self, 97 | file_path, 98 | wsi, 99 | pretrained=False, 100 | custom_transforms=None, 101 | custom_downsample=1, 102 | target_patch_size=-1 103 | ): 104 | """ 105 | Args: 106 | file_path (string): Path to the .h5 file containing patched data. 107 | pretrained (bool): Use ImageNet transforms 108 | custom_transforms (callable, optional): Optional transform to be applied on a sample 109 | custom_downsample (int): Custom defined downscale factor (overruled by target_patch_size) 110 | target_patch_size (int): Custom defined image size before embedding 111 | """ 112 | self.pretrained=pretrained 113 | self.wsi = wsi 114 | if not custom_transforms: 115 | self.roi_transforms = eval_transforms(pretrained=pretrained) 116 | else: 117 | self.roi_transforms = custom_transforms 118 | 119 | self.file_path = file_path 120 | 121 | with h5py.File(self.file_path, "r") as f: 122 | dset = f['coords'] 123 | self.patch_level = f['coords'].attrs['patch_level'] 124 | self.patch_size = f['coords'].attrs['patch_size'] 125 | self.length = len(dset) 126 | if target_patch_size > 0: 127 | self.target_patch_size = (target_patch_size, ) * 2 128 | elif custom_downsample > 1: 129 | self.target_patch_size = (self.patch_size // custom_downsample, ) * 2 130 | else: 131 | self.target_patch_size = None 132 | self.summary() 133 | 134 | def __len__(self): 135 | return self.length 136 | 137 | def summary(self): 138 | hdf5_file = h5py.File(self.file_path, "r") 139 | dset = hdf5_file['coords'] 140 | for name, value in dset.attrs.items(): 141 | print(name, value) 142 | 143 | print('\nfeature extraction settings') 144 | print('target patch size: ', self.target_patch_size) 145 | print('pretrained: ', self.pretrained) 146 | print('transformations: ', self.roi_transforms) 147 | 148 | def __getitem__(self, idx): 149 | with h5py.File(self.file_path,'r') as hdf5_file: 150 | coord = hdf5_file['coords'][idx] 151 | img = self.wsi.read_region(coord, self.patch_level, (self.patch_size, self.patch_size)).convert('RGB') 152 | 153 | if self.target_patch_size is not None: 154 | img = 
img.resize(self.target_patch_size) 155 | img = self.roi_transforms(img).unsqueeze(0) 156 | return img, coord 157 | 158 | class Dataset_All_Bags(Dataset): 159 | 160 | def __init__(self, csv_path): 161 | self.df = pd.read_csv(csv_path) 162 | 163 | def __len__(self): 164 | return len(self.df) 165 | 166 | def __getitem__(self, idx): 167 | return self.df['slide_id'][idx] 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /preprocess/datasets/wsi_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import pandas as pd 3 | import numpy as np 4 | import time 5 | import pdb 6 | import PIL.Image as Image 7 | import h5py 8 | from torch.utils.data import Dataset 9 | import torch 10 | from wsi_core.util_classes import Contour_Checking_fn, isInContourV1, isInContourV2, isInContourV3_Easy, isInContourV3_Hard 11 | 12 | def default_transforms(mean = (0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 13 | t = transforms.Compose( 14 | [transforms.ToTensor(), 15 | transforms.Normalize(mean = mean, std = std)]) 16 | return t 17 | 18 | def get_contour_check_fn(contour_fn='four_pt_hard', cont=None, ref_patch_size=None, center_shift=None): 19 | if contour_fn == 'four_pt_hard': 20 | cont_check_fn = isInContourV3_Hard(contour=cont, patch_size=ref_patch_size, center_shift=center_shift) 21 | elif contour_fn == 'four_pt_easy': 22 | cont_check_fn = isInContourV3_Easy(contour=cont, patch_size=ref_patch_size, center_shift=0.5) 23 | elif contour_fn == 'center': 24 | cont_check_fn = isInContourV2(contour=cont, patch_size=ref_patch_size) 25 | elif contour_fn == 'basic': 26 | cont_check_fn = isInContourV1(contour=cont) 27 | else: 28 | raise NotImplementedError 29 | return cont_check_fn 30 | 31 | 32 | 33 | class Wsi_Region(Dataset): 34 | ''' 35 | args: 36 | wsi_object: instance of WholeSlideImage wrapper over a WSI 37 | top_left: tuple of coordinates representing the top left corner of WSI region (Default: None) 38 | bot_right: tuple of coordinates representing the bottom right corner of WSI region (Default: None) 39 | level: downsample level at which to process the WSI region 40 | patch_size: tuple of width, height representing the patch size 41 | step_size: tuple of w_step, h_step representing the step size 42 | contour_fn (str): 43 | contour checking fn to use 44 | choice of ['four_pt_hard', 'four_pt_easy', 'center', 'basic'] (Default: 'four_pt_hard') 45 | t: custom torchvision transformation to apply 46 | custom_downsample (int): additional downscale factor to apply 47 | use_center_shift: for 'four_pt_hard' contour check, how far out to shift the 4 points 48 | ''' 49 | def __init__(self, wsi_object, top_left=None, bot_right=None, level=0, 50 | patch_size = (256, 256), step_size=(256, 256), 51 | contour_fn='four_pt_hard', 52 | t=None, custom_downsample=1, use_center_shift=False): 53 | 54 | self.custom_downsample = custom_downsample 55 | 56 | # downscale factor in reference to level 0 57 | self.ref_downsample = wsi_object.level_downsamples[level] 58 | # patch size in reference to level 0 59 | self.ref_size = tuple((np.array(patch_size) * np.array(self.ref_downsample)).astype(int)) 60 | 61 | if self.custom_downsample > 1: 62 | self.target_patch_size = patch_size 63 | patch_size = tuple((np.array(patch_size) * np.array(self.ref_downsample) * custom_downsample).astype(int)) 64 | step_size = tuple((np.array(step_size) * custom_downsample).astype(int)) 65 | self.ref_size = patch_size 66 | 
else: 67 | step_size = tuple((np.array(step_size)).astype(int)) 68 | self.ref_size = tuple((np.array(patch_size) * np.array(self.ref_downsample)).astype(int)) 69 | 70 | self.wsi = wsi_object.wsi 71 | self.level = level 72 | self.patch_size = patch_size 73 | 74 | if not use_center_shift: 75 | center_shift = 0. 76 | else: 77 | overlap = 1 - float(step_size[0] / patch_size[0]) 78 | if overlap < 0.25: 79 | center_shift = 0.375 80 | elif overlap >= 0.25 and overlap < 0.75: 81 | center_shift = 0.5 82 | elif overlap >=0.75 and overlap < 0.95: 83 | center_shift = 0.5 84 | else: 85 | center_shift = 0.625 86 | #center_shift = 0.375 # 25% overlap 87 | #center_shift = 0.625 #50%, 75% overlap 88 | #center_shift = 1.0 #95% overlap 89 | 90 | filtered_coords = [] 91 | #iterate through tissue contours for valid patch coordinates 92 | for cont_idx, contour in enumerate(wsi_object.contours_tissue): 93 | print('processing {}/{} contours'.format(cont_idx, len(wsi_object.contours_tissue))) 94 | cont_check_fn = get_contour_check_fn(contour_fn, contour, self.ref_size[0], center_shift) 95 | coord_results, _ = wsi_object.process_contour(contour, wsi_object.holes_tissue[cont_idx], level, '', 96 | patch_size = patch_size[0], step_size = step_size[0], contour_fn=cont_check_fn, 97 | use_padding=True, top_left = top_left, bot_right = bot_right) 98 | if len(coord_results) > 0: 99 | filtered_coords.append(coord_results['coords']) 100 | 101 | coords=np.vstack(filtered_coords) 102 | 103 | self.coords = coords 104 | print('filtered a total of {} coordinates'.format(len(self.coords))) 105 | 106 | # apply transformation 107 | if t is None: 108 | self.transforms = default_transforms() 109 | else: 110 | self.transforms = t 111 | 112 | def __len__(self): 113 | return len(self.coords) 114 | 115 | def __getitem__(self, idx): 116 | coord = self.coords[idx] 117 | patch = self.wsi.read_region(tuple(coord), self.level, self.patch_size).convert('RGB') 118 | if self.custom_downsample > 1: 119 | patch = patch.resize(self.target_patch_size) 120 | patch = self.transforms(patch).unsqueeze(0) 121 | return patch, coord 122 | -------------------------------------------------------------------------------- /preprocess/extract_patches.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import glob 4 | import openslide 5 | from tqdm import tqdm 6 | from multiprocessing import Pool, cpu_count 7 | import cv2 8 | import numpy as np 9 | import argparse 10 | 11 | 12 | def process_patch(patch_path): 13 | class_name = patch_path.split(os.sep)[-3] # Adjust index based on the structure of patch_path 14 | file_name = os.path.basename(patch_path) 15 | patch_name = os.path.splitext(file_name)[0] 16 | 17 | raw_patch_path = os.path.join(args.raw_data_folder, class_name, f"{patch_name}.{args.wsi_extension}") 18 | wsi = openslide.open_slide(raw_patch_path) 19 | with h5py.File(patch_path, 'r') as hdf5_file: 20 | patch_level = hdf5_file['coords'].attrs['patch_level'] 21 | patch_size = hdf5_file['coords'].attrs['patch_size'] 22 | for idx in range(len(hdf5_file['coords'])): 23 | coord = hdf5_file['coords'][idx] 24 | if os.path.isfile(os.path.join(args.output_folder, class_name, f"{patch_name}_{idx}.jpg")): 25 | continue 26 | img = wsi.read_region(coord, patch_level, (patch_size, patch_size)).convert('RGB') 27 | img = np.array(img) 28 | cv2.imwrite(os.path.join(args.output_folder, class_name, f"{patch_name}_{idx}.jpg"), cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 29 | 30 | def main(args): 31 | class_names = 
os.listdir(args.input_folder) 32 | for class_name in class_names: 33 | os.makedirs(os.path.join(args.output_folder, class_name), exist_ok=True) 34 | print(f"Processing for Class {class_name}") 35 | if args.sample_count>0: 36 | patch_paths = sorted(glob.glob(os.path.join(args.input_folder, class_name, "patches", "*h5")))[ 37 | :args.sample_count] 38 | else: 39 | patch_paths = sorted(glob.glob(os.path.join(args.input_folder, class_name, "patches", "*h5"))) 40 | total = len(patch_paths) 41 | progress_bar = tqdm(total=total) 42 | with Pool(min(cpu_count(), 8)) as p: 43 | for _ in p.imap_unordered(process_patch, patch_paths): 44 | progress_bar.update() 45 | progress_bar.close() 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--raw_data_folder', type=str, help='Path to the folder containing raw WSIs.') 50 | parser.add_argument('--wsi_extension', type=str, choices=['ndpi', 'tif', 'svs'], help='Extension of WSI file type. Valid choices are [ndpi, tif, svs]') 51 | parser.add_argument('--input_folder', type=str, help='Folder that contains h5 files extracted from WSI using ' 52 | 'create_patches_fp.py') 53 | parser.add_argument('--output_folder', type=str, help='Folder to save extracted patches.') 54 | parser.add_argument('--sample_count', type=int, default=-1, help='Maximum number of WSIs to extract patches. If -1, it will extract all the patches.') 55 | args = parser.parse_args() 56 | main(args) -------------------------------------------------------------------------------- /preprocess/extract_patches.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -J CHUM # job name 3 | #SBATCH -w virya1 4 | #SBATCH -n8 # of CPU cores 5 | #SBATCH --mem=30GB # memory reserved (mandatory) 6 | 7 | source /etc/profile.d/modules.sh # adding module binaries 8 | 9 | module load anaconda/3.2023.03 10 | python extract_patches.py -------------------------------------------------------------------------------- /preprocess/extract_patches_tar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import glob 4 | import openslide 5 | from tqdm import tqdm 6 | import webdataset as wds 7 | from PIL import Image 8 | import io 9 | import json 10 | raw_data_paths = { 11 | 'digestive_benign': "/home/atlas-gp/Transfer_CC/10930_chx_digestive_benigne/", 12 | 'digestive_malign': "/home/atlas-gp/Transfer_CC/10931_chx_digestive_maligne" 13 | } 14 | 15 | input_folder = 'output' 16 | class_names = raw_data_paths.keys() 17 | hipt_patch_folder = 'extracted_mag10x_patch256_fp' 18 | output_pattern = 'output_tar/dataset-%02d.tar' # Pattern for sharded output files 19 | 20 | # Configuration for sharding 21 | max_shard_size = 3e10 # Maximum shard size in bytes (e.g., 30GB) 22 | 23 | # Create a ShardWriter object to write the dataset to sharded tar archives 24 | 25 | total_count =0 26 | # , maxcount=3e5 27 | with wds.ShardWriter(output_pattern, maxsize=max_shard_size) as sink: 28 | for class_name in class_names: 29 | print(f"Processing for Class {class_name}") 30 | patch_h5_paths = glob.glob(os.path.join(input_folder, class_name, hipt_patch_folder, "patches", "*h5"))[:20] 31 | for patch_h5_path in tqdm(patch_h5_paths): 32 | file_name = os.path.basename(patch_h5_path) 33 | wsi_id = os.path.splitext(file_name)[0] 34 | raw_wsi_path = os.path.join(raw_data_paths[class_name], f"{wsi_id}.ndpi") 35 | wsi = openslide.open_slide(raw_wsi_path) 36 | 37 | with 
h5py.File(patch_h5_path, 'r') as hdf5_file: 38 | patch_level = hdf5_file['coords'].attrs['patch_level'] 39 | patch_size = hdf5_file['coords'].attrs['patch_size'] 40 | for idx, coord in enumerate(hdf5_file['coords']): 41 | img = wsi.read_region((coord[0], coord[1]), patch_level, (patch_size, patch_size)).convert('RGB') 42 | # Convert PIL image to bytes 43 | img_byte_arr = io.BytesIO() 44 | img.save(img_byte_arr, format='JPEG') 45 | img_bytes = img_byte_arr.getvalue() 46 | img_byte_arr.close() 47 | 48 | # Key, class label, and coordinates for each image 49 | key = f"{wsi_id}_{idx}" 50 | class_label = class_name 51 | coord_str = f"{coord[0]}_{coord[1]}" 52 | 53 | # Write the image, metadata, and coordinates to the ShardWriter 54 | sink.write({ 55 | "__key__": key, 56 | "class": class_label, 57 | "coords": coord_str, 58 | "jpg": img_bytes 59 | }) 60 | total_count += 1 61 | 62 | print(total_count ) 63 | -------------------------------------------------------------------------------- /preprocess/sample.sh: -------------------------------------------------------------------------------- 1 | python create_patches_fp.py --source /home/atlas-gp/Transfer_CC/10930_chx_digestive_benigne/ --save_dir /home/a_n29343/CHUM/VIM4Path/datasets/CHUM/output/digestive_benign/extracted_mag5x_patch640_fp --patch_size 640 --step_size 640 --patch_level 3 --seg --patch --stitch 2 | python create_patches_fp.py --source /home/atlas-gp/Transfer_CC/10931_chx_digestive_maligne/ --save_dir /home/a_n29343/CHUM/VIM4Path/datasets/CHUM/output/digestive_malign/extracted_mag5x_patch640_fp --patch_size 640 --step_size 640 --patch_level 3 --seg --patch --stitch -------------------------------------------------------------------------------- /preprocess/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from models.model_mil import MIL_fc, MIL_fc_mc 7 | from models.model_clam import CLAM_SB, CLAM_MB 8 | import pdb 9 | import os 10 | import pandas as pd 11 | from utils.utils import * 12 | from utils.core_utils import Accuracy_Logger 13 | from sklearn.metrics import roc_auc_score, roc_curve, auc 14 | from sklearn.preprocessing import label_binarize 15 | import matplotlib.pyplot as plt 16 | 17 | def initiate_model(args, ckpt_path): 18 | print('Init Model') 19 | model_dict = {"dropout": args.drop_out, 'n_classes': args.n_classes} 20 | 21 | if args.model_size is not None and args.model_type in ['clam_sb', 'clam_mb']: 22 | model_dict.update({"size_arg": args.model_size}) 23 | 24 | if args.model_type =='clam_sb': 25 | model = CLAM_SB(**model_dict) 26 | elif args.model_type =='clam_mb': 27 | model = CLAM_MB(**model_dict) 28 | else: # args.model_type == 'mil' 29 | if args.n_classes > 2: 30 | model = MIL_fc_mc(**model_dict) 31 | else: 32 | model = MIL_fc(**model_dict) 33 | 34 | print_network(model) 35 | 36 | ckpt = torch.load(ckpt_path) 37 | ckpt_clean = {} 38 | for key in ckpt.keys(): 39 | if 'instance_loss_fn' in key: 40 | continue 41 | ckpt_clean.update({key.replace('.module', ''):ckpt[key]}) 42 | model.load_state_dict(ckpt_clean, strict=True) 43 | 44 | model.relocate() 45 | model.eval() 46 | return model 47 | 48 | def eval(dataset, args, ckpt_path): 49 | model = initiate_model(args, ckpt_path) 50 | 51 | print('Init Loaders') 52 | loader = get_simple_loader(dataset) 53 | patient_results, test_error, auc, df, _ = summary(model, loader, args) 54 | print('test_error: ', test_error) 55 
| print('auc: ', auc) 56 | return model, patient_results, test_error, auc, df 57 | 58 | def summary(model, loader, args): 59 | acc_logger = Accuracy_Logger(n_classes=args.n_classes) 60 | model.eval() 61 | test_loss = 0. 62 | test_error = 0. 63 | 64 | all_probs = np.zeros((len(loader), args.n_classes)) 65 | all_labels = np.zeros(len(loader)) 66 | all_preds = np.zeros(len(loader)) 67 | 68 | slide_ids = loader.dataset.slide_data['slide_id'] 69 | patient_results = {} 70 | for batch_idx, (data, label) in enumerate(loader): 71 | data, label = data.to(device), label.to(device) 72 | slide_id = slide_ids.iloc[batch_idx] 73 | with torch.no_grad(): 74 | logits, Y_prob, Y_hat, _, results_dict = model(data) 75 | 76 | acc_logger.log(Y_hat, label) 77 | 78 | probs = Y_prob.cpu().numpy() 79 | 80 | all_probs[batch_idx] = probs 81 | all_labels[batch_idx] = label.item() 82 | all_preds[batch_idx] = Y_hat.item() 83 | 84 | patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}}) 85 | 86 | error = calculate_error(Y_hat, label) 87 | test_error += error 88 | 89 | del data 90 | test_error /= len(loader) 91 | 92 | aucs = [] 93 | if len(np.unique(all_labels)) == 1: 94 | auc_score = -1 95 | 96 | else: 97 | if args.n_classes == 2: 98 | auc_score = roc_auc_score(all_labels, all_probs[:, 1]) 99 | else: 100 | binary_labels = label_binarize(all_labels, classes=[i for i in range(args.n_classes)]) 101 | for class_idx in range(args.n_classes): 102 | if class_idx in all_labels: 103 | fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx]) 104 | aucs.append(auc(fpr, tpr)) 105 | else: 106 | aucs.append(float('nan')) 107 | if args.micro_average: 108 | binary_labels = label_binarize(all_labels, classes=[i for i in range(args.n_classes)]) 109 | fpr, tpr, _ = roc_curve(binary_labels.ravel(), all_probs.ravel()) 110 | auc_score = auc(fpr, tpr) 111 | else: 112 | auc_score = np.nanmean(np.array(aucs)) 113 | 114 | results_dict = {'slide_id': slide_ids, 'Y': all_labels, 'Y_hat': all_preds} 115 | for c in range(args.n_classes): 116 | results_dict.update({'p_{}'.format(c): all_probs[:,c]}) 117 | df = pd.DataFrame(results_dict) 118 | return patient_results, test_error, auc_score, df, acc_logger 119 | -------------------------------------------------------------------------------- /preprocess/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import h5py 3 | 4 | def save_pkl(filename, save_object): 5 | writer = open(filename,'wb') 6 | pickle.dump(save_object, writer) 7 | writer.close() 8 | 9 | def load_pkl(filename): 10 | loader = open(filename,'rb') 11 | file = pickle.load(loader) 12 | loader.close() 13 | return file 14 | 15 | 16 | def save_hdf5(output_path, asset_dict, attr_dict= None, mode='a'): 17 | file = h5py.File(output_path, mode) 18 | for key, val in asset_dict.items(): 19 | data_shape = val.shape 20 | if key not in file: 21 | data_type = val.dtype 22 | chunk_shape = (1, ) + data_shape[1:] 23 | maxshape = (None, ) + data_shape[1:] 24 | dset = file.create_dataset(key, shape=data_shape, maxshape=maxshape, chunks=chunk_shape, dtype=data_type) 25 | dset[:] = val 26 | if attr_dict is not None: 27 | if key in attr_dict.keys(): 28 | for attr_key, attr_val in attr_dict[key].items(): 29 | dset.attrs[attr_key] = attr_val 30 | else: 31 | dset = file[key] 32 | dset.resize(len(dset) + data_shape[0], axis=0) 33 | dset[-data_shape[0]:] = val 34 | file.close() 35 | return output_path 
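A quick usage sketch for save_hdf5 above (illustrative only: the file name, array shapes, and attribute values are placeholders, and the import assumes preprocess/ is the working directory). The first call creates chunked, resizable datasets and attaches any attributes; later calls with the default mode='a' append rows along axis 0.

import numpy as np
from utils.file_utils import save_hdf5

coords = np.array([[0, 0], [256, 0]], dtype=np.int64)
features = np.random.rand(2, 384).astype(np.float32)

# first batch: creates the 'coords' and 'features' datasets and tags coords with patching attributes
save_hdf5('demo_bag.h5', {'coords': coords, 'features': features},
          attr_dict={'coords': {'patch_size': 256, 'patch_level': 0}}, mode='w')

# later batches: the datasets already exist, so rows are appended along axis 0
save_hdf5('demo_bag.h5', {'coords': coords + 512, 'features': features})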
-------------------------------------------------------------------------------- /preprocess/utils/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import pdb 6 | 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | from torchvision import transforms 11 | from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler 12 | import torch.optim as optim 13 | import pdb 14 | import torch.nn.functional as F 15 | import math 16 | from itertools import islice 17 | import collections 18 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | class SubsetSequentialSampler(Sampler): 21 | """Samples elements sequentially from a given list of indices, without replacement. 22 | 23 | Arguments: 24 | indices (sequence): a sequence of indices 25 | """ 26 | def __init__(self, indices): 27 | self.indices = indices 28 | 29 | def __iter__(self): 30 | return iter(self.indices) 31 | 32 | def __len__(self): 33 | return len(self.indices) 34 | 35 | def collate_MIL(batch): 36 | img = torch.cat([item[0] for item in batch], dim = 0) 37 | label = torch.LongTensor([item[1] for item in batch]) 38 | return [img, label] 39 | 40 | def collate_features(batch): 41 | img = torch.cat([item[0] for item in batch], dim = 0) 42 | coords = np.vstack([item[1] for item in batch]) 43 | return [img, coords] 44 | 45 | 46 | def get_simple_loader(dataset, batch_size=1, num_workers=1): 47 | kwargs = {'num_workers': num_workers, 'pin_memory': False} if device.type == "cuda" else {} 48 | loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler.SequentialSampler(dataset), collate_fn = collate_MIL, **kwargs) 49 | return loader 50 | 51 | def get_split_loader(split_dataset, training = False, testing = False, weighted = False): 52 | """ 53 | return either the validation loader or training loader 54 | """ 55 | kwargs = {'num_workers': 4} if device.type == "cuda" else {} 56 | if not testing: 57 | if training: 58 | if weighted: 59 | weights = make_weights_for_balanced_classes_split(split_dataset) 60 | loader = DataLoader(split_dataset, batch_size=1, sampler = WeightedRandomSampler(weights, len(weights)), collate_fn = collate_MIL, **kwargs) 61 | else: 62 | loader = DataLoader(split_dataset, batch_size=1, sampler = RandomSampler(split_dataset), collate_fn = collate_MIL, **kwargs) 63 | else: 64 | loader = DataLoader(split_dataset, batch_size=1, sampler = SequentialSampler(split_dataset), collate_fn = collate_MIL, **kwargs) 65 | 66 | else: 67 | ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset)*0.1), replace = False) # subsample 10% of the dataset for quick testing 68 | loader = DataLoader(split_dataset, batch_size=1, sampler = SubsetSequentialSampler(ids), collate_fn = collate_MIL, **kwargs ) 69 | 70 | return loader 71 | 72 | def get_optim(model, args): 73 | if args.opt == "adam": 74 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg) 75 | elif args.opt == 'sgd': 76 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg) 77 | else: 78 | raise NotImplementedError 79 | return optimizer 80 | 81 | def print_network(net): 82 | num_params = 0 83 | num_params_train = 0 84 | print(net) 85 | 86 | for param in net.parameters(): 87 | n = param.numel() 88 | num_params += n 89 | if param.requires_grad: 90 
| num_params_train += n 91 | 92 | print('Total number of parameters: %d' % num_params) 93 | print('Total number of trainable parameters: %d' % num_params_train) 94 | 95 | 96 | def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5, 97 | seed = 7, label_frac = 1.0, custom_test_ids = None): 98 | indices = np.arange(samples).astype(int) 99 | 100 | if custom_test_ids is not None: 101 | indices = np.setdiff1d(indices, custom_test_ids) 102 | 103 | np.random.seed(seed) 104 | for i in range(n_splits): 105 | all_val_ids = [] 106 | all_test_ids = [] 107 | sampled_train_ids = [] 108 | 109 | if custom_test_ids is not None: # pre-built test split, do not need to sample 110 | all_test_ids.extend(custom_test_ids) 111 | 112 | for c in range(len(val_num)): 113 | possible_indices = np.intersect1d(cls_ids[c], indices) #all indices of this class 114 | val_ids = np.random.choice(possible_indices, val_num[c], replace = False) # validation ids 115 | 116 | remaining_ids = np.setdiff1d(possible_indices, val_ids) #indices of this class left after validation 117 | all_val_ids.extend(val_ids) 118 | 119 | if custom_test_ids is None: # sample test split 120 | 121 | test_ids = np.random.choice(remaining_ids, test_num[c], replace = False) 122 | remaining_ids = np.setdiff1d(remaining_ids, test_ids) 123 | all_test_ids.extend(test_ids) 124 | 125 | if label_frac == 1: 126 | sampled_train_ids.extend(remaining_ids) 127 | 128 | else: 129 | sample_num = math.ceil(len(remaining_ids) * label_frac) 130 | slice_ids = np.arange(sample_num) 131 | sampled_train_ids.extend(remaining_ids[slice_ids]) 132 | 133 | yield sampled_train_ids, all_val_ids, all_test_ids 134 | 135 | 136 | def nth(iterator, n, default=None): 137 | if n is None: 138 | return collections.deque(iterator, maxlen=0) 139 | else: 140 | return next(islice(iterator,n, None), default) 141 | 142 | def calculate_error(Y_hat, Y): 143 | error = 1. 
- Y_hat.float().eq(Y.float()).float().mean().item() 144 | 145 | return error 146 | 147 | def make_weights_for_balanced_classes_split(dataset): 148 | N = float(len(dataset)) 149 | weight_per_class = [N/len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))] 150 | weight = [0] * int(N) 151 | for idx in range(len(dataset)): 152 | y = dataset.getlabel(idx) 153 | weight[idx] = weight_per_class[y] 154 | 155 | return torch.DoubleTensor(weight) 156 | 157 | def initialize_weights(module): 158 | for m in module.modules(): 159 | if isinstance(m, nn.Linear): 160 | nn.init.xavier_normal_(m.weight) 161 | m.bias.data.zero_() 162 | 163 | elif isinstance(m, nn.BatchNorm1d): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | -------------------------------------------------------------------------------- /preprocess/wsi_core/batch_process_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pdb 4 | 5 | ''' 6 | initiate a pandas df describing a list of slides to process 7 | args: 8 | slides (df or array-like): 9 | array-like structure containing a list of slide ids; if a df, these ids are assumed to be 10 | stored under the 'slide_id' column 11 | seg_params (dict): segmentation parameters 12 | filter_params (dict): filter parameters 13 | vis_params (dict): visualization parameters 14 | patch_params (dict): patching parameters 15 | use_heatmap_args (bool): whether to include heatmap arguments such as ROI coordinates 16 | ''' 17 | def initialize_df(slides, seg_params, filter_params, vis_params, patch_params, 18 | use_heatmap_args=False, save_patches=False): 19 | 20 | total = len(slides) 21 | if isinstance(slides, pd.DataFrame): 22 | slide_ids = slides.slide_id.values 23 | else: 24 | slide_ids = slides 25 | default_df_dict = {'slide_id': slide_ids, 'process': np.full((total), 1, dtype=np.uint8)} 26 | 27 | # initiate empty labels in case not provided 28 | if use_heatmap_args: 29 | default_df_dict.update({'label': np.full((total), -1)}) 30 | 31 | default_df_dict.update({ 32 | 'status': np.full((total), 'tbp'), 33 | # seg params 34 | 'seg_level': np.full((total), int(seg_params['seg_level']), dtype=np.int8), 35 | 'sthresh': np.full((total), int(seg_params['sthresh']), dtype=np.uint8), 36 | 'mthresh': np.full((total), int(seg_params['mthresh']), dtype=np.uint8), 37 | 'close': np.full((total), int(seg_params['close']), dtype=np.uint32), 38 | 'use_otsu': np.full((total), bool(seg_params['use_otsu']), dtype=bool), 39 | 'keep_ids': np.full((total), seg_params['keep_ids']), 40 | 'exclude_ids': np.full((total), seg_params['exclude_ids']), 41 | 42 | # filter params 43 | 'a_t': np.full((total), int(filter_params['a_t']), dtype=np.float32), 44 | 'a_h': np.full((total), int(filter_params['a_h']), dtype=np.float32), 45 | 'max_n_holes': np.full((total), int(filter_params['max_n_holes']), dtype=np.uint32), 46 | 47 | # vis params 48 | 'vis_level': np.full((total), int(vis_params['vis_level']), dtype=np.int8), 49 | 'line_thickness': np.full((total), int(vis_params['line_thickness']), dtype=np.uint32), 50 | 51 | # patching params 52 | 'use_padding': np.full((total), bool(patch_params['use_padding']), dtype=bool), 53 | 'contour_fn': np.full((total), patch_params['contour_fn']) 54 | }) 55 | 56 | if save_patches: 57 | default_df_dict.update({ 58 | 'white_thresh': np.full((total), int(patch_params['white_thresh']), dtype=np.uint8), 59 | 'black_thresh': np.full((total), 
int(patch_params['black_thresh']), dtype=np.uint8)}) 60 | 61 | if use_heatmap_args: 62 | # initiate empty x,y coordinates in case not provided 63 | default_df_dict.update({'x1': np.full((total), np.nan), 64 | 'x2': np.full((total), np.nan), 65 | 'y1': np.full((total), np.nan), 66 | 'y2': np.full((total), np.nan)}) 67 | 68 | 69 | if isinstance(slides, pd.DataFrame): 70 | temp_copy = pd.DataFrame(default_df_dict) # temporary dataframe w/ default params 71 | # find key in provided df 72 | # if exist, fill empty fields w/ default values, else, insert the default values as a new column 73 | for key in default_df_dict.keys(): 74 | if key in slides.columns: 75 | mask = slides[key].isna() 76 | slides.loc[mask, key] = temp_copy.loc[mask, key] 77 | else: 78 | slides.insert(len(slides.columns), key, default_df_dict[key]) 79 | else: 80 | slides = pd.DataFrame(default_df_dict) 81 | 82 | return slides -------------------------------------------------------------------------------- /preprocess/wsi_core/util_classes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import pdb 5 | import cv2 6 | class Mosaic_Canvas(object): 7 | def __init__(self,patch_size=256, n=100, downscale=4, n_per_row=10, bg_color=(0,0,0), alpha=-1): 8 | self.patch_size = patch_size 9 | self.downscaled_patch_size = int(np.ceil(patch_size/downscale)) 10 | self.n_rows = int(np.ceil(n / n_per_row)) 11 | self.n_cols = n_per_row 12 | w = self.n_cols * self.downscaled_patch_size 13 | h = self.n_rows * self.downscaled_patch_size 14 | if alpha < 0: 15 | canvas = Image.new(size=(w,h), mode="RGB", color=bg_color) 16 | else: 17 | canvas = Image.new(size=(w,h), mode="RGBA", color=bg_color + (int(255 * alpha),)) 18 | 19 | self.canvas = canvas 20 | self.dimensions = np.array([w, h]) 21 | self.reset_coord() 22 | 23 | def reset_coord(self): 24 | self.coord = np.array([0, 0]) 25 | 26 | def increment_coord(self): 27 | #print('current coord: {} x {} / {} x {}'.format(self.coord[0], self.coord[1], self.dimensions[0], self.dimensions[1])) 28 | assert np.all(self.coord<=self.dimensions) 29 | if self.coord[0] + self.downscaled_patch_size <=self.dimensions[0] - self.downscaled_patch_size: 30 | self.coord[0]+=self.downscaled_patch_size 31 | else: 32 | self.coord[0] = 0 33 | self.coord[1]+=self.downscaled_patch_size 34 | 35 | 36 | def save(self, save_path, **kwargs): 37 | self.canvas.save(save_path, **kwargs) 38 | 39 | def paste_patch(self, patch): 40 | assert patch.size[0] == self.patch_size 41 | assert patch.size[1] == self.patch_size 42 | self.canvas.paste(patch.resize(tuple([self.downscaled_patch_size, self.downscaled_patch_size])), tuple(self.coord)) 43 | self.increment_coord() 44 | 45 | def get_painting(self): 46 | return self.canvas 47 | 48 | class Contour_Checking_fn(object): 49 | # Defining __call__ method 50 | def __call__(self, pt): 51 | raise NotImplementedError 52 | 53 | class isInContourV1(Contour_Checking_fn): 54 | def __init__(self, contour): 55 | self.cont = contour 56 | 57 | def __call__(self, pt): 58 | return 1 if cv2.pointPolygonTest(self.cont, tuple(np.array(pt).astype(float)), False) >= 0 else 0 59 | 60 | class isInContourV2(Contour_Checking_fn): 61 | def __init__(self, contour, patch_size): 62 | self.cont = contour 63 | self.patch_size = patch_size 64 | 65 | def __call__(self, pt): 66 | pt = np.array((pt[0]+self.patch_size//2, pt[1]+self.patch_size//2)).astype(float) 67 | return 1 if 
cv2.pointPolygonTest(self.cont, tuple(np.array(pt).astype(float)), False) >= 0 else 0 68 | 69 | # Easy version of 4pt contour checking function - 1 of 4 points need to be in the contour for test to pass 70 | class isInContourV3_Easy(Contour_Checking_fn): 71 | def __init__(self, contour, patch_size, center_shift=0.5): 72 | self.cont = contour 73 | self.patch_size = patch_size 74 | self.shift = int(patch_size//2*center_shift) 75 | def __call__(self, pt): 76 | center = (pt[0]+self.patch_size//2, pt[1]+self.patch_size//2) 77 | if self.shift > 0: 78 | all_points = [(center[0]-self.shift, center[1]-self.shift), 79 | (center[0]+self.shift, center[1]+self.shift), 80 | (center[0]+self.shift, center[1]-self.shift), 81 | (center[0]-self.shift, center[1]+self.shift) 82 | ] 83 | else: 84 | all_points = [center] 85 | 86 | for points in all_points: 87 | if cv2.pointPolygonTest(self.cont, tuple(np.array(points).astype(float)), False) >= 0: 88 | return 1 89 | return 0 90 | 91 | # Hard version of 4pt contour checking function - all 4 points need to be in the contour for test to pass 92 | class isInContourV3_Hard(Contour_Checking_fn): 93 | def __init__(self, contour, patch_size, center_shift=0.5): 94 | self.cont = contour 95 | self.patch_size = patch_size 96 | self.shift = int(patch_size//2*center_shift) 97 | def __call__(self, pt): 98 | center = (pt[0]+self.patch_size//2, pt[1]+self.patch_size//2) 99 | if self.shift > 0: 100 | all_points = [(center[0]-self.shift, center[1]-self.shift), 101 | (center[0]+self.shift, center[1]+self.shift), 102 | (center[0]+self.shift, center[1]-self.shift), 103 | (center[0]-self.shift, center[1]+self.shift) 104 | ] 105 | else: 106 | all_points = [center] 107 | 108 | for points in all_points: 109 | if cv2.pointPolygonTest(self.cont, tuple(np.array(points).astype(float)), False) < 0: 110 | return 0 111 | return 1 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /test_patches.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import glob 4 | import openslide 5 | from tqdm import tqdm 6 | raw_data_paths = {'digestive_benign': "/home/atlas-gp/Transfer_CC/10930_chx_digestive_benigne/", 7 | 'digestive_malign': "/home/atlas-gp/Transfer_CC/10931_chx_digestive_maligne"} 8 | 9 | input_folder = 'preprocess/output' 10 | class_names = os.listdir(input_folder) 11 | hipt_patch_folder = 'extracted_mag10x_patch256_fp' 12 | 13 | for class_name in class_names: 14 | total_count = 0 15 | print(f"Processing for Class {class_name}") 16 | patch_paths = sorted(glob.glob(os.path.join(input_folder, class_name, hipt_patch_folder, "patches", "*h5")))#[:300] 17 | print(patch_paths[0]) 18 | for patch_path in tqdm(patch_paths): 19 | file_name = os.path.basename(patch_path) 20 | 21 | with h5py.File(patch_path,'r') as hdf5_file: 22 | patch_level = hdf5_file['coords'].attrs['patch_level'] 23 | patch_size = hdf5_file['coords'].attrs['patch_size'] 24 | patch_count = len(hdf5_file['coords']) 25 | total_count += patch_count 26 | 27 | print (f"Total patch count for {class_name} is {total_count}") --------------------------------------------------------------------------------
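To close, a minimal sketch of the patch-extraction flow shared by extract_patches.py, extract_patches_tar.py and test_patches.py above: each .h5 file written by create_patches_fp.py stores a 'coords' dataset carrying 'patch_level' and 'patch_size' attributes, and patches are cropped on demand from the WSI with OpenSlide. The slide and h5 paths below are hypothetical placeholders.

import h5py
import openslide

wsi = openslide.open_slide('slide_001.ndpi')
with h5py.File('patches/slide_001.h5', 'r') as f:
    coords = f['coords']
    patch_level = coords.attrs['patch_level']
    patch_size = coords.attrs['patch_size']
    # crop and save the first few patches only; coords are level-0 coordinates
    for x, y in coords[:4]:
        patch = wsi.read_region((int(x), int(y)), patch_level, (patch_size, patch_size)).convert('RGB')
        patch.save(f'slide_001_{x}_{y}.jpg')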