├── hpc_pathing.py ├── .gitignore ├── createSaveBinaryMask.asv ├── createSaveBinaryMask.m ├── sample.py ├── camvidOnly_train.slurm ├── camvidOnly_baseline.slurm ├── he_to_binary_mask_final.m ├── README.md ├── utilities ├── preprocessing.py ├── dataset.py └── main.py └── mat2npy.ipynb /hpc_pathing.py: -------------------------------------------------------------------------------- 1 | import os, time 2 | 3 | # Check what the main folder is on rivanna 4 | main_folder = os.path.dirname(os.path.abspath(__file__)) 5 | time.sleep(5) 6 | print(main_folder) 7 | # os.makedirs("/scratch/nm4wu/modelRuns") 8 | 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.iml 2 | *.pyc 3 | *.DS_Store 4 | *.vbb 5 | *.csv 6 | *.mat 7 | *.npy 8 | *.tif 9 | *.pdf 10 | *.png 11 | *.jpeg 12 | *.pkls 13 | *.pt 14 | *.tif 15 | *.pdf 16 | __pycache__/ 17 | __pycache__ 18 | ./Data/ 19 | .idea 20 | .idea/ 21 | /ralis 22 | /ralis-master 23 | /ralis-master/ 24 | /UNet-Instance-Cell-Segmentation/ 25 | 26 | 27 | -------------------------------------------------------------------------------- /createSaveBinaryMask.asv: -------------------------------------------------------------------------------- 1 | %basedir = '/home/pm2kb/RL_proj/MoNuSegTestData'; 2 | basedir = '/home/pm2kb/RL_proj/MoNuSegTrainingData/Annotations'; 3 | files = dir([basedir '/*.xml']); 4 | idx=0 5 | % note, you might run out of memory and have to load some files manually 6 | for file = files' 7 | 8 | [pathstr, name, ext] = fileparts(files(11).name); 9 | [binary_mask,color_mask] = he_to_binary_mask_final(name); 10 | save(fullfile(basedir,'/masks_true/',strcat(name,'.mat')),'binary_mask') 11 | 12 | end 13 | 14 | -------------------------------------------------------------------------------- /createSaveBinaryMask.m: -------------------------------------------------------------------------------- 1 | %basedir = '/home/pm2kb/RL_proj/MoNuSegTestData'; 2 | basedir = '/home/pm2kb/RL_proj/MoNuSegTrainingData/Annotations'; 3 | files = dir([basedir '/*.xml']); 4 | showFigures = 0; 5 | % note, you might run out of memory and have to load some files manually 6 | for file = files' 7 | [pathstr, name, ext] = fileparts(file.name); 8 | [binary_mask,color_mask] = he_to_binary_mask_final(name,showFigures); 9 | save(fullfile(basedir,'/masks_true/',strcat(name,'.mat')),'binary_mask') 10 | clear binary_mask color_mask 11 | end 12 | 13 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import shutil 3 | import numpy as np 4 | 5 | base_dir = '' 6 | 7 | ds_dir = '' 8 | dt_dir = '' 9 | dv_dir = '' 10 | set_list = [ds_dir, dt_dir, dv_dir] 11 | move_files = False 12 | 13 | #3 ds, 27 dt, 70 dv 14 | def split_three(lst, ratio=[0.03, 0.27, 0.70]): 15 | train_r, val_r, test_r = ratio 16 | assert(np.sum(ratio) == 1.0) # makes sure the splits make sense 17 | # only need to give 2 indices to split, the last one returns rest of the list/empty list 18 | indicies_for_splitting = [int(len(lst) * train_r), int(len(lst) * (train_r+val_r))] 19 | ds, dt, dv = np.split(lst, indicies_for_splitting) 20 | return ds, dt, dv 21 | 22 | ds, dt, dv = split_three(list) 23 | 24 | if move_files: 25 | for target_dir in set_list: 26 | for file in os.listdir(base_dir): 27 | if file in target_dir: 28 | shutil.copy(os.path.join(base_dir, file), os.path.join(target_dir, file)) 29 | else: 30 | print(f"File {file} not found in target directory.") 31 | 32 | -------------------------------------------------------------------------------- /camvidOnly_train.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --ntasks=1 3 | #SBATCH -t 72:00:00 4 | #SBATCH --partition=gpu 5 | #SBATCH --gres=gpu:p100:4 6 | #SBATCH -A gutintelligencelab 7 | #SBATCH --job-name=Train_Ralis_CAMVID_Only 8 | #SBATCH --output=Train_Ralis_CAMVID_Only_%A_%a.out 9 | #SBATCH --error=Train_Ralis_CAMVID_Only_%A_%a.err 10 | #SBATCH --mail-type=end 11 | #SBATCH --mail-user=pm2kb@virginia.edu 12 | 13 | module purge 14 | module --ignore-cache load anaconda/2019.10-py3.7 15 | module --ignore-cache load singularity/3.5.2 16 | 17 | ckpt_path='/scratch/pm2kb/ckpt_seg' 18 | data_path='/home/pm2kb/RL_proj/SegNet' 19 | 20 | ### Camvid ### 21 | for seed in 20 50 82 12 4560 22 | do 23 | singularity run --nv ~/pytorch-1.4.0-py37.sif /home/pm2kb/RL_proj/ralis-master/run.py --exp-name 'RALIS_camvid_train_seed'$seed --full-res --region-size 80 90 \ 24 | --snapshot 'best_jaccard_val.pth' --al-algorithm 'ralis' \ 25 | --ckpt-path $ckpt_path --data-path $data_path \ 26 | --rl-episodes 100 --rl-buffer 600 --lr-dqn 0.001\ 27 | --load-weights --exp-name-toload 'gta_pretraining_camvid' \ 28 | --dataset 'camvid' --lr 0.001 --train-batch-size 32 --val-batch-size 4 --patience 10 \ 29 | --input-size 224 224 --only-last-labeled --budget-labels 480 --num-each-iter 24 --rl-pool 20 --seed $seed 30 | done 31 | 32 | 33 | -------------------------------------------------------------------------------- /camvidOnly_baseline.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --ntasks=1 3 | #SBATCH -t 72:00:00 4 | #SBATCH --partition=gpu 5 | #SBATCH --gres=gpu:p100:4 6 | #SBATCH -A gutintelligencelab 7 | #SBATCH --job-name=Baseline_Ralis_CAMVID_Only 8 | #SBATCH --output=Baseline_Ralis_CAMVID_Only_%A_%a.out 9 | #SBATCH --error=Baseline_Ralis_CAMVID_Only_%A_%a.err 10 | #SBATCH --mail-type=end 11 | #SBATCH --mail-user=pm2kb@virginia.edu 12 | 13 | module purge 14 | module --ignore-cache load anaconda/2019.10-py3.7 15 | module --ignore-cache load singularity/3.5.2 16 | 17 | ckpt_path='/scratch/pm2kb/ckpt_seg' 18 | data_path='/home/pm2kb/RL_proj/SegNet' 19 | 20 | #### Camvid #### 21 | 22 | for al_algorithm in 'random' 23 | do 24 | for budget in 480 720 960 1200 1440 1920 25 | do 26 | for seed in 20 50 234 4560 12 27 | do 28 | singularity run --nv ~/pytorch-1.4.0-py37.sif /home/pm2kb/RL_proj/ralis-master/run.py --exp-name 'baseline_camvid_'$al_algorithm'_budget_'$budget'_seed'$seed --seed $seed --checkpointer \ 29 | --ckpt-path $ckpt_path --data-path $data_path \ 30 | --load-weights --exp-name-toload 'camvid_pretrained_dt' \ 31 | --input-size 224 224 --only-last-labeled --dataset 'camvid' --lr 0.001 --train-batch-size 32 --val-batch-size 4 \ 32 | --patience 150 --region-size 80 90 \ 33 | --budget-labels $budget --num-each-iter 24 --al-algorithm $al_algorithm --rl-pool 50 --train --test --final-test 34 | done 35 | done 36 | done 37 | 38 | for al_algorithm in 'entropy' 'bald' 39 | do 40 | for budget in 480 720 960 1200 1440 1920 41 | do 42 | for seed in 20 50 234 4560 12 43 | do 44 | singularity run --nv ~/pytorch-1.4.0-py37.sif /home/pm2kb/RL_proj/ralis-master/run.py --exp-name 'baseline_camvid_'$al_algorithm'_budget_'$budget'_seed'$seed --seed $seed --checkpointer \ 45 | --ckpt-path $ckpt_path --data-path $data_path \ 46 | --load-weights --exp-name-toload 'camvid_pretrained_dt' \ 47 | --input-size 224 224 --only-last-labeled --dataset 'camvid' --lr 0.001 --train-batch-size 32 --val-batch-size 4 \ 48 | --patience 150 --region-size 80 90 \ 49 | --budget-labels $budget --num-each-iter 24 --al-algorithm $al_algorithm --rl-pool 10 --train --test --final-test 50 | done 51 | done 52 | done 53 | 54 | -------------------------------------------------------------------------------- /he_to_binary_mask_final.m: -------------------------------------------------------------------------------- 1 | % A function to read in H&E image and xml file containing annotations 2 | % Gives the binary and colored maps based on annotated objects 3 | % Created by Neeraj Kumar, please cite the following paper if you use this code- 4 | % N. Kumar, R. Verma, S. Sharma, S. Bhargava, A. Vahadane and A. Sethi, 5 | % "A Dataset and a Technique for Generalized Nuclear Segmentation for 6 | % Computational Pathology," in IEEE Transactions on Medical Imaging, 7 | % vol. 36, no. 7, pp. 1550-1560, July 2017 8 | 9 | function [binary_mask,color_mask]=he_to_binary_mask_final(filename,showFigures) 10 | im_file=strcat(filename,'.tif'); 11 | 12 | xml_file=strcat(filename,'.xml'); 13 | 14 | xDoc = xmlread(xml_file); 15 | Regions=xDoc.getElementsByTagName('Region'); % get a list of all the region tags 16 | for regioni = 0:Regions.getLength-1 17 | Region=Regions.item(regioni); % for each region tag 18 | 19 | %get a list of all the vertexes (which are in order) 20 | verticies=Region.getElementsByTagName('Vertex'); 21 | xy{regioni+1}=zeros(verticies.getLength-1,2); %allocate space for them 22 | for vertexi = 0:verticies.getLength-1 %iterate through all verticies 23 | %get the x value of that vertex 24 | x=str2double(verticies.item(vertexi).getAttribute('X')); 25 | 26 | %get the y value of that vertex 27 | y=str2double(verticies.item(vertexi).getAttribute('Y')); 28 | xy{regioni+1}(vertexi+1,:)=[x,y]; % finally save them into the array 29 | end 30 | 31 | end 32 | im_info=imfinfo(im_file); 33 | 34 | 35 | nrow=im_info.Height; 36 | ncol=im_info.Width; 37 | binary_mask=zeros(nrow,ncol); %pre-allocate a mask 38 | color_mask = zeros(nrow,ncol,3); 39 | %mask_final = []; 40 | for zz=1:length(xy) %for each region 41 | fprintf('Processing object # %d \n',zz); 42 | smaller_x=xy{zz}(:,1); 43 | smaller_y=xy{zz}(:,2); 44 | 45 | %make a mask and add it to the current mask 46 | %this addition makes it obvious when more than 1 layer overlap each 47 | %other, can be changed to simply an OR depending on application. 48 | polygon = poly2mask(smaller_x,smaller_y,nrow,ncol); 49 | binary_mask=binary_mask+zz*(1-min(1,binary_mask)).*polygon;% 50 | color_mask = color_mask + cat(3, rand*polygon, rand*polygon,rand*polygon); 51 | %binary mask for all objects 52 | %imshow(ditance_transform) 53 | end 54 | if showFigures == 1 55 | figure;imshow(binary_mask) 56 | figure;imshow(color_mask) 57 | end 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Active Learning for Semantic Segmentation for Nuclei in Medical Imagery 2 | Semantic segmentation, the process of acquiring pixel-by-pixel data labels, requires large amounts of training data and is time consuming. Particularly in segmenting medical imagery, labeling cells requires an expert medical professional to label every pixel of interest. Many conventional deep learning methods achieve their results on the basis of large amounts of semantically segmented data. The goals of this project are to prioritize data for labeling which provides the maximum model performance given the training data. In active learning we select “the most informative samples to label so that a learning algorithm will perform better with less data than a non-selective approach” (Casanova et al., 2020). We plan to extend a reinforcement learning technique to select the best regions of whole slide images for labeling in order to train a segmentation model to maximize intersection over union (IoU) for segmenting nuclei cells in hematoxylin and eosin (H&E) stained biopsies. Hematoxylin stains the nucleus of a cell and eosin stains the cytoplasmic components, such as red blood cells and various types of fibers (collagen, elastic, and muscle) (Sampias). Successfully identifying morphological features and types of nuclei present in a pathology sample enables diagnosis and prognosis of numerous conditions, including cancer and muscular dystrophy (Zwerger et al., 2011). 3 | 4 | Specifically, the segmentation of nuclei in cells via active reinforcement learning will be explored via pathology datasets sourced from either existing cardiovascular fluorescent microscopy images utilized by the Owen’s Lab or publicly available histopathology datasets such as the Multi-organ nuclei segmentation (MoNuSeg) image set explored by Kumar et al (2020). 5 | 6 | References 7 | 8 | Casanova, A., Pinheiro, P. O., Rostamzadeh, N., & Pal, C. J. (2020). Reinforced active 9 | learning for image segmentation. arXiv preprint arXiv:2002.06583. 10 | 11 | Zwerger M, Ho CY, Lammerding J. Nuclear mechanics in disease. Annu Rev Biomed Eng. 12 | 2011;13:397-428. doi:10.1146/annurev-bioeng-071910-124736 13 | 14 | Sampias, Rolls, H&E Staining Overview: A Guide to Best Practices, 15 | https://www.leicabiosystems.com/knowledge-pathway/he-staining-overview-a-guide-to-best- 16 | practics/ 17 | 18 | N. Kumar et al., "A Multi-Organ Nucleus Segmentation Challenge," in IEEE Transactions on 19 | Medical Imaging, vol. 39, no. 5, pp. 1380-1391, May 2020, doi: 10.1109/TMI.2019.2947628. 20 | 21 | 22 | -------------------------------------------------------------------------------- /utilities/preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | 5 | # Training 6 | current_dir=os.getcwd() 7 | img_dir=os.path.join(current_dir,'MoNuSegTrainingData/Tissue Images/') 8 | save_dir=os.path.join(current_dir,'train_500_size/Tissue_Images') 9 | 10 | size_img=500 11 | 12 | 13 | for files in os.listdir(img_dir): 14 | if not files.startswith('.'): 15 | img=Image.open(os.path.join(img_dir,files)) 16 | img=np.array(img) 17 | n=0 18 | for i in range(0,np.array(img).shape[0],size_img): 19 | for j in range(0,np.array(img).shape[0],size_img): 20 | n+=1 21 | new_img=img[i:i+size_img,j:j+size_img,:] 22 | new_img=Image.fromarray(new_img) 23 | new_img.save(os.path.join(save_dir,files[:-4])+'_'+str(n)+'.jpg','JPEG') 24 | 25 | 26 | 27 | img_dir=os.path.join(current_dir,'MoNuSegTrainingData/Annotations/masks_true/masks_true_jpeg') 28 | save_dir=os.path.join(current_dir,'train_500_size/Annotations') 29 | 30 | 31 | size_img=500 32 | for files in os.listdir(img_dir): 33 | if not files.startswith('.'): 34 | img=Image.open(os.path.join(img_dir,files)) 35 | img=np.array(img) 36 | n=0 37 | for i in range(0,np.array(img).shape[0],size_img): 38 | for j in range(0,np.array(img).shape[0],size_img): 39 | n+=1 40 | new_img=img[i:i+size_img,j:j+size_img] 41 | new_img=Image.fromarray(new_img) 42 | new_img.save(os.path.join(save_dir,files[:-5])+'_'+str(n)+'.jpg','JPEG') 43 | 44 | img_dir=os.path.join(current_dir,'MoNuSegTestData/Tissue Images/') 45 | save_dir=os.path.join(current_dir,'test_500_size/Tissue_Images') 46 | 47 | #Test data 48 | size_img=500 49 | for files in os.listdir(img_dir): 50 | if not files.startswith('.'): 51 | img=Image.open(os.path.join(img_dir,files)) 52 | img=np.array(img) 53 | n=0 54 | for i in range(0,np.array(img).shape[0],size_img): 55 | for j in range(0,np.array(img).shape[0],size_img): 56 | n+=1 57 | new_img=img[i:i+size_img,j:j+size_img,:] 58 | new_img=Image.fromarray(new_img) 59 | new_img.save(os.path.join(save_dir,files[:-4])+'_'+str(n)+'.jpg','JPEG') 60 | 61 | 62 | img_dir=os.path.join(current_dir,'MoNuSegTestData/masks_true/masks_true_jpeg') 63 | save_dir=os.path.join(current_dir,'test_500_size/Annotations') 64 | 65 | size_img=500 66 | for files in os.listdir(img_dir): 67 | if not files.startswith('.'): 68 | img=Image.open(os.path.join(img_dir,files)) 69 | img=np.array(img) 70 | n=0 71 | for i in range(0,np.array(img).shape[0],size_img): 72 | for j in range(0,np.array(img).shape[0],size_img): 73 | n+=1 74 | new_img=img[i:i+size_img,j:j+size_img] 75 | new_img=Image.fromarray(new_img) 76 | new_img.save(os.path.join(save_dir,files[:-5])+'_'+str(n)+'.jpg','JPEG') 77 | 78 | -------------------------------------------------------------------------------- /utilities/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | 5 | from torch.utils.data.dataset import Dataset 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms 8 | from torch.utils.data.sampler import SubsetRandomSampler 9 | 10 | ''' 11 | Class that defines the reading and processing of the images. 12 | Specialized on BBBC006 dataset. 13 | ''' 14 | class BBBCDataset(Dataset): 15 | def __init__(self, ids, dir_data, dir_gt, extension='.png', isWCE=False, dir_weights = ''): 16 | 17 | self.dir_data = dir_data 18 | self.dir_gt = dir_gt 19 | self.extension = extension 20 | self.isWCE = isWCE 21 | self.dir_weights = dir_weights 22 | 23 | # Transforms 24 | self.data_transforms = { 25 | 'imgs': transforms.Compose([ 26 | # transforms.RandomResizedCrop(256), 27 | # transforms.RandomHorizontalFlip(), 28 | transforms.ToTensor(), 29 | transforms.Normalize([0.0054],[0.0037]) 30 | ]), 31 | 'masks': transforms.Compose([ 32 | transforms.ToTensor() 33 | ]), 34 | } 35 | 36 | # Images IDS 37 | self.ids = ids 38 | 39 | # Calculate len of data 40 | self.data_len = len(self.ids) 41 | 42 | ''' 43 | Ask for an image. 44 | ''' 45 | def __getitem__(self, index): 46 | # Get an ID of a specific image 47 | id_img = self.dir_data + self.ids[index] + self.extension 48 | id_gt = self.dir_gt + self.ids[index] + self.extension 49 | # Open Image and GroundTruth 50 | img = Image.open(id_img).convert('L') 51 | gt = Image.open(id_gt) 52 | # Applies transformations 53 | img = self.data_transforms['imgs'](img) 54 | gt = self.data_transforms['masks'](gt) 55 | if self.isWCE: 56 | id_weight = self.dir_weights + self.ids[index] + self.extension 57 | weight = Image.open(id_weight).convert('L') 58 | weight = self.data_transforms['masks'](weight) 59 | return (img, gt, weight) 60 | 61 | return (img, gt, gt) 62 | 63 | ''' 64 | Length of the dataset. 65 | It's needed for the upper class. 66 | ''' 67 | def __len__(self): 68 | return self.data_len 69 | 70 | 71 | ''' 72 | Returns the dataset separated in batches. 73 | Used inside every epoch for retrieving the images. 74 | ''' 75 | def get_dataloaders(dir_img, dir_gt, test_percent=0.2, batch_size=10, isWCE = False, dir_weights=''): 76 | # Validate a correct percentage 77 | test_percent = test_percent/100 if test_percent > 1 else test_percent 78 | # Read the names of the images 79 | ids = [f[:-4] for f in os.listdir(dir_img)] 80 | # Creates the dataset 81 | if not isWCE: 82 | dset = BBBCDataset(ids, dir_img, dir_gt) 83 | else: 84 | dset = BBBCDataset(ids, dir_img, dir_gt, isWCE = isWCE, dir_weights = dir_weights) 85 | 86 | # Calculate separation index for training and validation 87 | num_train = len(dset) 88 | indices = list(range(num_train)) 89 | split = int(np.floor(test_percent * num_train)) 90 | np.random.shuffle(indices) 91 | train_idx, valid_idx = indices[split:], indices[:split] 92 | 93 | # Create the dataloaders 94 | dataloaders = {} 95 | dataloaders['train'] = DataLoader(dset, batch_size=batch_size, 96 | sampler=SubsetRandomSampler(train_idx)) 97 | dataloaders['val'] = DataLoader(dset, batch_size=batch_size, 98 | sampler=SubsetRandomSampler(valid_idx)) 99 | 100 | return dataloaders['train'], dataloaders['val'] 101 | 102 | 103 | ''' 104 | Returns few images for showing the results. 105 | ''' 106 | def get_dataloader_show(dir_img, dir_gt): 107 | # Read the names of the images 108 | ids = [f[:-4] for f in os.listdir(dir_img)] 109 | # Creates the dataset 110 | dset = BBBCDataset(ids, dir_img, dir_gt) 111 | 112 | # Create the dataloader 113 | dataloader = DataLoader(dset, batch_size=len(ids)) 114 | 115 | return dataloader 116 | 117 | ''' 118 | Class that defines the reading and processing of the images. 119 | Specialized on BBBC006 dataset. 120 | ''' 121 | class BBBCDataset_Transform(Dataset): 122 | def __init__(self, ids, dir_data, extension='.png'): 123 | 124 | self.dir_data = dir_data 125 | self.extension = extension 126 | 127 | # Transforms 128 | self.data_transforms = { 129 | 'imgs': transforms.Compose([ 130 | # transforms.RandomResizedCrop(256), 131 | # transforms.RandomHorizontalFlip(), 132 | transforms.ToTensor(), 133 | transforms.Normalize([0.0054],[0.0037]) 134 | ]), 135 | 'masks': transforms.Compose([ 136 | transforms.ToTensor() 137 | ]), 138 | } 139 | 140 | # Images IDS 141 | self.ids = ids 142 | 143 | # Calculate len of data 144 | self.data_len = len(self.ids) 145 | 146 | ''' 147 | Ask for an image. 148 | ''' 149 | def __getitem__(self, index): 150 | # Get an ID of a specific image 151 | id_img = self.dir_data + self.ids[index] + self.extension 152 | # Open Image and GroundTruth 153 | img = Image.open(id_img).convert('L') 154 | # Applies transformations 155 | img = self.data_transforms['imgs'](img) 156 | return (img, self.ids[index]+self.extension) 157 | 158 | ''' 159 | Length of the dataset. 160 | It's needed for the upper class. 161 | ''' 162 | def __len__(self): 163 | return self.data_len 164 | 165 | ''' 166 | Returns whole dataset to transform. 167 | ''' 168 | def get_dataloader_transform(dir_img, batch_size = 1): 169 | # Read the names of the images 170 | ids = [f[:-4] for f in os.listdir(dir_img)] 171 | # Creates the dataset 172 | dset = BBBCDataset_Transform(ids, dir_img) 173 | 174 | # Create the dataloader 175 | dataloader = DataLoader(dset, batch_size=batch_size) 176 | 177 | return dataloader 178 | -------------------------------------------------------------------------------- /mat2npy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "TCGA-FG-A4MU-01B-01-TS1.mat\n", 13 | "TCGA-CU-A0YN-01A-02-BSB.mat\n", 14 | "TCGA-A6-6782-01A-01-BS1.mat\n", 15 | "TCGA-EJ-A46H-01A-03-TSC.mat\n", 16 | "TCGA-IZ-8196-01A-01-BS1.mat\n", 17 | "TCGA-GL-6846-01A-01-BS1.mat\n", 18 | "TCGA-AC-A2FO-01A-01-TS1.mat\n", 19 | "TCGA-69-7764-01A-01-TS1.mat\n", 20 | "TCGA-HT-8564-01Z-00-DX1.mat\n", 21 | "TCGA-HC-7209-01A-01-TS1.mat\n", 22 | "TCGA-2Z-A9J9-01A-01-TS1.mat\n", 23 | "TCGA-44-2665-01B-06-BS6.mat\n", 24 | "TCGA-ZF-A9R5-01A-01-TS1.mat\n", 25 | "TCGA-AO-A0J2-01A-01-BSA.mat\n" 26 | ] 27 | } 28 | ], 29 | "source": [ 30 | "#Purpose of this notebook: convert .mat to npy and jpeg files\n", 31 | "# The npy files are binary masks to be used for training and testing of MoNuSeg data\n", 32 | "#you need to create the npy and jpeg folders in the directory of he mask_path\n", 33 | "#I only use jpeg for easy viewing, not nn training\n", 34 | "\n", 35 | "\n", 36 | "import os\n", 37 | "## load names of images\n", 38 | "mask_path = '/home/pm2kb/RL_proj/MoNuSegTestData/masks_true'\n", 39 | "(_, _, filenames) = next(os.walk(mask_path))\n", 40 | "for file in filenames:\n", 41 | " print(file)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 12, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "from scipy import io \n", 51 | "from PIL import Image\n", 52 | "import numpy as np\n", 53 | "\n", 54 | "for file in filenames:\n", 55 | " dat = io.loadmat(os.path.join(mask_path,file)) \n", 56 | " mask = dat[\"binary_mask\"]\n", 57 | " bin_mask = mask.copy()\n", 58 | " #this is for semantic segmentation bc values are by count \n", 59 | " bin_mask[bin_mask>0] = 1\n", 60 | " filePath_npy = os.path.join(mask_path,'masks_true_npy',file[0:-4])\n", 61 | " np.save(filePath_npy,mask)\n", 62 | " filePath_jpeg= os.path.join(mask_path,'masks_true_jpeg',f'{file[0:-4]}.jpeg')\n", 63 | " mask_PIL =Image.fromarray(mask>0)\n", 64 | " mask_PIL = mask_PIL.convert(\"L\")\n", 65 | " mask_PIL.save(filePath_jpeg,\"JPEG\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 14, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "356" 77 | ] 78 | }, 79 | "execution_count": 14, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "#This shows the number of unique annotations. Useful for instance segmantation. For semantic segmentation we turn it all to binary \n", 86 | "np.max(mask)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 4, 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "TCGA-49-4488-01Z-00-DX1.mat\n", 99 | "TCGA-KB-A93J-01A-01-TS1.mat\n", 100 | "TCGA-G9-6336-01Z-00-DX1.mat\n", 101 | "TCGA-DK-A2I6-01A-01-TS1.mat\n", 102 | "TCGA-G9-6362-01Z-00-DX1.mat\n", 103 | "TCGA-G2-A2EK-01A-02-TSB.mat\n", 104 | "TCGA-HE-7129-01Z-00-DX1.mat\n", 105 | "TCGA-HE-7128-01Z-00-DX1.mat\n", 106 | "TCGA-38-6178-01Z-00-DX1.mat\n", 107 | "TCGA-HE-7130-01Z-00-DX1.mat\n", 108 | "TCGA-E2-A14V-01Z-00-DX1.mat\n", 109 | "TCGA-E2-A1B5-01Z-00-DX1.mat\n", 110 | "TCGA-B0-5698-01Z-00-DX1.mat\n", 111 | "TCGA-G9-6363-01Z-00-DX1.mat\n", 112 | "TCGA-B0-5710-01Z-00-DX1.mat\n", 113 | "TCGA-A7-A13E-01Z-00-DX1.mat\n", 114 | "TCGA-AY-A8YK-01A-01-TS1.mat\n", 115 | "TCGA-CH-5767-01Z-00-DX1.mat\n", 116 | "TCGA-AR-A1AK-01Z-00-DX1.mat\n", 117 | "TCGA-A7-A13F-01Z-00-DX1.mat\n", 118 | "TCGA-21-5786-01Z-00-DX1.mat\n", 119 | "TCGA-21-5784-01Z-00-DX1.mat\n", 120 | "TCGA-AR-A1AS-01Z-00-DX1.mat\n", 121 | "TCGA-G9-6348-01Z-00-DX1.mat\n", 122 | "TCGA-NH-A8F7-01A-01-TS1.mat\n", 123 | "TCGA-18-5592-01Z-00-DX1.mat\n", 124 | "TCGA-G9-6356-01Z-00-DX1.mat\n", 125 | "TCGA-RD-A8N9-01A-01-TS1.mat\n", 126 | "TCGA-B0-5711-01Z-00-DX1.mat\n", 127 | "TCGA-50-5931-01Z-00-DX1.mat\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "## load names of images\n", 133 | "mask_path = '/home/pm2kb/RL_proj/MoNuSegTrainingData/Annotations/masks_true'\n", 134 | "(_, _, filenames) = next(os.walk(mask_path))\n", 135 | "for file in filenames:\n", 136 | " print(file)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 6, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "from scipy import io \n", 146 | "from PIL import Image\n", 147 | "import numpy as np\n", 148 | "\n", 149 | "for file in filenames:\n", 150 | " dat = io.loadmat(os.path.join(mask_path,file)) \n", 151 | " mask = dat[\"binary_mask\"]\n", 152 | " bin_mask = mask.copy()\n", 153 | " #this is for semantic segmentation bc values are by count \n", 154 | " bin_mask[bin_mask>0] = 1\n", 155 | " filePath_npy = os.path.join(mask_path,'masks_true_npy',file[0:-4])\n", 156 | " np.save(filePath_npy,mask)\n", 157 | " filePath_jpeg= os.path.join(mask_path,'masks_true_jpeg',f'{file[0:-4]}.jpeg')\n", 158 | " mask_PIL =Image.fromarray(mask>0)\n", 159 | " mask_PIL = mask_PIL.convert(\"L\")\n", 160 | " mask_PIL.save(filePath_jpeg,\"JPEG\")" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [] 169 | } 170 | ], 171 | "metadata": { 172 | "kernelspec": { 173 | "display_name": "Tensorflow 2.4.1/Keras Py3.7", 174 | "language": "python", 175 | "name": "tensorflow-2.4.1" 176 | }, 177 | "language_info": { 178 | "codemirror_mode": { 179 | "name": "ipython", 180 | "version": 3 181 | }, 182 | "file_extension": ".py", 183 | "mimetype": "text/x-python", 184 | "name": "python", 185 | "nbconvert_exporter": "python", 186 | "pygments_lexer": "ipython3", 187 | "version": "3.7.3" 188 | } 189 | }, 190 | "nbformat": 4, 191 | "nbformat_minor": 4 192 | } 193 | -------------------------------------------------------------------------------- /utilities/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import time 4 | from optparse import OptionParser 5 | 6 | from model import UNet 7 | from dataset import get_dataloaders 8 | from train_val import DiceLoss, RMSELoss, CELoss, WCELoss, train_net #, val_net 9 | from misc import export_history, save_checkpoint 10 | 11 | ''' 12 | Configure every aspect of the run. 13 | Runs the training and validation. 14 | ''' 15 | def setup_and_run_train(n_channels, n_classes, dir_img, dir_gt, dir_results, load, 16 | val_perc, batch_size, epochs, lr, run, optimizer, loss, evaluation, dir_weights): 17 | 18 | # Use GPU or not 19 | use_cuda = torch.cuda.is_available() 20 | device = torch.device("cuda" if use_cuda else "cpu") 21 | 22 | # Create the model 23 | net = UNet(n_channels, n_classes).to(device) 24 | net = torch.nn.DataParallel(net, device_ids=list( 25 | range(torch.cuda.device_count()))).to(device) 26 | 27 | # Load old weights 28 | if load: 29 | net.load_state_dict(torch.load(load)) 30 | print('Model loaded from {}'.format(load)) 31 | 32 | # Load the dataset 33 | if loss != "WCE": 34 | train_loader, val_loader = get_dataloaders(dir_img, dir_gt, val_perc, batch_size) 35 | else: 36 | train_loader, val_loader = get_dataloaders(dir_img, dir_gt, val_perc, batch_size, isWCE = True, dir_weights = dir_weights) 37 | 38 | # Pretty print of the run 39 | print('''\n 40 | Starting training: 41 | Dataset: {} 42 | Num Channels: {} 43 | Groundtruth: {} 44 | Num Classes: {} 45 | Folder to save: {} 46 | Load previous: {} 47 | Training size: {} 48 | Validation size: {} 49 | Validation Percentage: {} 50 | Batch size: {} 51 | Epochs: {} 52 | Learning rate: {} 53 | Optimizer: {} 54 | Loss Function: {} 55 | Evaluation Function: {} 56 | CUDA: {} 57 | '''.format(dir_img, n_channels, dir_gt, n_classes, dir_results, load, 58 | len(train_loader)*batch_size, len(val_loader)*batch_size, 59 | val_perc, batch_size, epochs, lr, optimizer, loss, evaluation, use_cuda)) 60 | 61 | # Definition of the optimizer ADD MORE IF YOU WANT 62 | if optimizer == "Adam": 63 | optimizer = torch.optim.Adam(net.parameters(), 64 | lr=lr) 65 | elif optimizer == "SGD": 66 | optimizer = torch.optim.SGD(net.parameters(), 67 | lr=lr, 68 | momentum=0.9, 69 | weight_decay=0.0005) 70 | 71 | # Definition of the loss function ADD MORE IF YOU WANT 72 | if loss == "Dice": 73 | criterion = DiceLoss() 74 | elif loss == "RMSE": 75 | criterion = RMSELoss() 76 | elif loss == "MSE": 77 | criterion = nn.MSELoss() 78 | elif loss == "MAE": 79 | criterion = nn.L1Loss() 80 | elif loss == "CE": 81 | criterion = CELoss() 82 | elif loss == "WCE": 83 | criterion = WCELoss() 84 | 85 | # Saving History to csv 86 | header = ['epoch', 'train loss'] 87 | 88 | best_loss = 10000 89 | time_start = time.time() 90 | # Run the training and validation 91 | for epoch in range(epochs): 92 | print('\nStarting epoch {}/{}.'.format(epoch + 1, epochs)) 93 | 94 | train_loss = train_net(net, device, train_loader, optimizer, criterion, batch_size, isWCE = (loss == "WCE")) 95 | #val_loss = val_net(net, device, val_loader, criterion_val, batch_size) 96 | 97 | values = [epoch+1, train_loss] 98 | export_history(header, values, dir_results, "result"+run+".csv") 99 | 100 | # save model 101 | if train_loss < best_loss: 102 | best_loss = train_loss 103 | save_checkpoint({ 104 | 'epoch': epoch + 1, 105 | 'state_dict': net.state_dict(), 106 | 'loss': train_loss, 107 | 'optimizer' : optimizer.state_dict(), 108 | }, path=dir_results, filename="weights"+run+".pth") 109 | 110 | time_dif = time.time() - time_start 111 | print("It tooks %.4f seconds to finish the run." % (time_dif)) 112 | 113 | 114 | ''' 115 | Definition of the optional and needed parameters. 116 | ''' 117 | def get_args(): 118 | parser = OptionParser() 119 | parser.add_option('-e', '--epochs', dest='epochs', default=30, type='int', 120 | help='number of epochs') 121 | parser.add_option('-b', '--batch-size', dest='batchsize', default=25, 122 | type='int', help='batch size') 123 | parser.add_option('-l', '--learning-rate', dest='lr', default=0.0001, 124 | type='float', help='learning rate') 125 | parser.add_option('-a', '--load', dest='load', 126 | default=False, help='load file model') 127 | parser.add_option('-r', '--runs', dest='runs', type='int', 128 | default=1, help='How many runs') 129 | parser.add_option('-d', '--dataset', dest='dataset', 130 | default='Data', help='Which dataset should use.') 131 | parser.add_option('-g', '--groundtruth', dest='gt', 132 | default='GT_One_Class', help='Which gt should use.') 133 | parser.add_option('-s', '--savedir', dest='savedir', 134 | default='checkpoints/', help='Which folder should use for checkpoints.') 135 | parser.add_option('-t', '--val-percentage', dest='val_perc', default=0.3,type='float', 136 | help='Validation Percentage') 137 | parser.add_option('-i', '--n-channels', dest='n_channels', default=1, type='int', 138 | help='Number of channels of the inputs.') 139 | parser.add_option('-c', '--n-classes', dest='n_classes', default=1, type='int', 140 | help='Number of classes of the output.') 141 | parser.add_option('-o', '--optimizer', dest='optimizer', default="Adam", choices=["Adam", "SGD"], 142 | help='Optimizer to use.') 143 | parser.add_option('-f', '--loss', dest='loss', default="Dice", choices=["Dice", "RMSE", "MSE", "MAE", "CE", "WCE"], 144 | help='Loss functios to use.') 145 | parser.add_option('-v', '--evaluation', dest='evaluation', default="Dice", choices=["Dice", "RMSE", "MSE", "MAE", "CE", "WCE"], 146 | help='Evaluation function to use.') 147 | parser.add_option('-w', '--weights', dest='weights', 148 | default='', help='Which weights should use.') 149 | 150 | (options, args) = parser.parse_args() 151 | return options 152 | 153 | 154 | ''' 155 | Runs the application. 156 | ''' 157 | if __name__ == "__main__": 158 | args = get_args() 159 | for r in range(args.runs): 160 | setup_and_run_train( 161 | n_channels = args.n_channels, 162 | n_classes = args.n_classes, 163 | dir_img = args.dataset, 164 | dir_gt = args.gt, 165 | dir_results = args.savedir, 166 | load = args.load, 167 | val_perc = args.val_perc, 168 | batch_size = args.batchsize, 169 | epochs = args.epochs, 170 | lr = args.lr, 171 | run=str(r), 172 | optimizer = args.optimizer, 173 | loss = args.loss, 174 | evaluation = args.evaluation, 175 | dir_weights = args.weights) 176 | 177 | 178 | 179 | --------------------------------------------------------------------------------