├── LICENSE
├── README.md
├── config.py
├── config_validator.py
├── dataset.py
├── docs
│   └── JCOL.jpg
├── environment.yml
├── infer_produce_predict_map_wsi.py
├── loss
│   ├── __init__.py
│   ├── cancer_loss.py
│   ├── ceo_loss.py
│   ├── dorn_loss.py
│   ├── mtmr_loss.py
│   └── rank_ordinal_loss.py
├── misc
│   ├── infer_wsi_utils.py
│   ├── train_ultils_all_iter.py
│   └── train_ultils_validator.py
├── model_lib
│   ├── __init__.py
│   └── efficientnet_pytorch
│       ├── __init__.py
│       ├── model.py
│       ├── model_dorn.py
│       ├── model_mtmr.py
│       ├── model_rank_ordinal.py
│       └── utils.py
├── requirements.txt
├── scheduler_lr
│   ├── __init__.py
│   └── warmup_cosine_lr.py
├── scripts
│   ├── __init__.py
│   └── run_train.sh
├── train_val.py
└── train_val_ceo_for_cancer_only.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 TrinhVg

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# JCO_Learning: Joint Categorical and Ordinal Learning for Cancer Grading in Pathology Images
## About
A multi-task deep learning model for cancer grading in pathology images: the model jointly performs a categorical classification task and an auxiliary ordinal classification task, and uses the L_CEO loss for the auxiliary ordinal task.<br/>
[Link](https://www.sciencedirect.com/science/article/pii/S1361841521002516) to the Medical Image Analysis paper.<br/>
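
For readers who want the objective in code, here is a minimal, illustrative sketch (not the repository's training loop) of the joint CE + MSE + CEO objective (the `M_MSE_CEO` setting described below). The CEO term mirrors `CEOLoss` in `loss/ceo_loss.py`; the dummy head outputs, batch size, and four-class setting are assumptions made for illustration.

```
# Minimal sketch of the joint objective (illustrative only, CPU-friendly).
import torch
import torch.nn.functional as F

def ceo_loss(out_r, y, num_classes=4):
    # L_CEO: cross-entropy over the negative distances between the scalar
    # regression output and each ordinal level (mirrors loss/ceo_loss.py).
    levels = torch.arange(num_classes, dtype=out_r.dtype).repeat(len(y), 1)  # (B, C)
    dist = (out_r.unsqueeze(1).repeat(1, num_classes) - levels).abs()        # |x - level|
    return F.cross_entropy(-dist, y)

logit_c = torch.randn(8, 4)    # dummy categorical head output (B, C)
out_r = torch.rand(8) * 3      # dummy regression/ordinal head output (B,)
y = torch.randint(0, 4, (8,))  # dummy ground-truth grades

loss = F.cross_entropy(logit_c, y) + F.mse_loss(out_r, y.float()) + ceo_loss(out_r, y)
```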

![](docs/JCOL.jpg)
## Datasets
All the models in this project were evaluated on the following datasets:

- [Colon_KBSMC](https://github.com/QuIIL/KBSMC_colon_cancer_grading_dataset) (Colon TMA from Kangbuk Samsung Hospital)
- [Colon_KBSMC](https://github.com/QuIIL/KBSMC_colon_cancer_grading_dataset) (Colon WSI from Kangbuk Samsung Hospital)
- [Prostate_UHU](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/OCYCMP) (Prostate TMA from University Hospital Zurich - Harvard Dataverse)
- [Prostate_UBC](https://gleason2019.grand-challenge.org/) (Prostate TMA from UBC - MICCAI 2019)

## Set Up Environment

```
conda env create -f environment.yml
conda activate jco_learning
pip install torch~=1.8.1+cu111
```

Above, we install PyTorch version 1.8.1 with CUDA 11.1. The code should also work with older PyTorch versions (>= 1.1).
## Repository Structure

Below are the main directories in the repository:

- `docs/`: figures/GIFs used in the repo
- `loss/`: loss function definitions
- `misc/`: utilities for training and WSI inference
- `model_lib/`: model definitions (EfficientNet variants for each method)
- `scheduler_lr/`: learning-rate schedulers (warm-up cosine)
- `scripts/`: shell scripts to launch training

Below are the main executable scripts in the repository:

- `config.py`: training configuration, including the augmentation pipeline
- `config_validator.py`: configuration for the validation/test phase and for generating prediction maps
- `dataset.py`: defines the dataset classes
- `train_val.py`: main training script
- `train_val_ceo_for_cancer_only.py`: training script in which the ordinal loss is applied to the cancer classes only (the benign class is excluded)
- `infer_produce_predict_map_wsi.py`: generates a prediction or probability map for a WSI/core image in a sliding-window fashion

# Running the Code

## Training and Options

```
python train_val.py [--gpu=<id>] [--run_info=<run_info>] [--dataset=<dataset>]
```

Options:
**Our proposed method and 9 common/state-of-the-art categorical and ordinal classification methods, including:**

| METHOD | run_info | Description |
| -------------|----------------------| ----------------------|
| C_CE | CLASS_ce | Classification: Cross-Entropy loss |
| C_FOCAL | CLASS_FocalLoss | Classification: Focal loss, Focal loss for dense object detection [[paper]](https://arxiv.org/abs/1708.02002) |
| R_MAE | REGRESS_mae | Regression: MAE loss |
| R_MSE | REGRESS_mse | Regression: MSE loss |
| R_SL | REGRESS_soft_label | Regression: Soft-Label loss, Deep learning regression for prostate cancer detection and grading in Bi-parametric MRI [[paper]](https://ieeexplore.ieee.org/document/9090311) |
| O_DORN | REGRESS_rank_dorn | Ordinal regression: Deep ordinal regression network for monocular depth estimation [[paper]](https://arxiv.org/abs/1806.02446) [[code]](https://github.com/hufu6371/DORN?utm_source=catalyzex.com) |
| O_CORAL | REGRESS_rank_coral | Ordinal regression: Rank consistent ordinal regression for neural networks with application to age estimation [[paper]](https://arxiv.org/abs/1901.07884) [[code]](https://github.com/Raschka-research-group/coral-cnn?utm_source=catalyzex.com) |
| O_FOCAL | REGRESS_FocalOrdinal | Ordinal regression: Joint prostate cancer detection and Gleason score prediction in mp-MRI via FocalNet [[paper]](https://ieeexplore.ieee.org/document/8653866) |
| M_MTMR | MULTI_mtmr | Multitask: Multi-task deep model with margin ranking loss for lung nodule analysis [[paper]](https://ieeexplore.ieee.org/document/8794587) [[code]](https://github.com/lihaoliu-cambridge/mtmr-net) |
| M_MAE | MULTI_ce_mae | Multitask: Class_CE + Regression_MAE |
| M_MSE | MULTI_ce_mse | Multitask: Class_CE + Regression_MSE |
| M_MAE_CEO | MULTI_ce_mae_ceo | Multitask: Class_CE + Regression_MAE_CEO (Ours) |
| M_MSE_CEO | MULTI_ce_mse_ceo | Multitask: Class_CE + Regression_MSE_CEO (Ours) |
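
For reference, `config.py` turns the `run_info` string from the table above into a task type and a loss type with a simple split; a minimal sketch of that parsing:

```
# How config.py parses --run_info (see Config.__init__ in config.py).
run_info = "MULTI_ce_mse_ceo"
task_type = run_info.split("_")[0]                 # "MULTI" (or "CLASS"/"REGRESS")
loss_type = run_info.replace(task_type + "_", "")  # "ce_mse_ceo"
```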

## Inference

```
python infer_produce_predict_map_wsi.py [--gpu=<id>] [--run_info=<run_info>]
```

### Model Weights

Model weights obtained by training MULTI_ce_mse_ceo:
- [Colon checkpoint](https://drive.google.com/drive/folders/1Gf2HjjcjJw4h1VvFUbnF2xvr9SJ6_r48?usp=sharing)
- [Prostate checkpoint](https://drive.google.com/drive/folders/1Gf2HjjcjJw4h1VvFUbnF2xvr9SJ6_r48?usp=sharing)

Access the full set of checkpoints [here](https://drive.google.com/drive/folders/1KQMD0iRibfAP9AxBE4TuU1NtPGvw-h5R?usp=sharing).

If any of the above checkpoints are used, please cite the corresponding paper.
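
A minimal sketch of loading one of these checkpoints, following `run_wsi()` in `infer_produce_predict_map_wsi.py` (the checkpoint path below is a placeholder; the weights were saved from a `DataParallel`-wrapped model, so the network is wrapped the same way before loading):

```
# Sketch only: restore a trained multi-task model for inference.
import importlib
import torch

net_def = importlib.import_module('model_lib.efficientnet_pytorch.model')
net = net_def.jl_efficientnet(task_mode='multi', pretrained=True)
net = torch.nn.DataParallel(net).to('cuda')
net.load_state_dict(torch.load('/path/to/MULTI_ce_mse_ceo/trained_net.pth'))  # placeholder path
net.eval()
```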

## Authors

* [Trinh Thi Le Vuong](https://github.com/trinhvg), Kyungeun Kim, Boram Song, and [Jin Tae Kwak](https://github.com/JinTaeKwak)

## Citation

If any part of this code is used, please give appropriate citation to our paper.<br/>

BibTeX entry:<br/>
101 | ``` 102 | @article{le2021joint, 103 | title={Joint categorical and ordinal learning for cancer grading in pathology images}, 104 | author={Le Vuong, Trinh Thi and Kim, Kyungeun and Song, Boram and Kwak, Jin Tae}, 105 | journal={Medical image analysis}, 106 | pages={102206}, 107 | year={2021}, 108 | publisher={Elsevier} 109 | } 110 | ``` 111 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import imgaug # https://github.com/aleju/imgaug 2 | from imgaug import augmenters as iaa 3 | import imgaug as ia 4 | import os 5 | 6 | #### 7 | class Config(object): 8 | def __init__(self, _args=None): 9 | if _args is not None: 10 | self.__dict__.update(_args.__dict__) 11 | self.seed = self.seed 12 | self.init_lr = 1.0e-3 13 | self.lr_steps = 20 # decrease at every n-th epoch 14 | self.gamma = 0.2 15 | self.train_batch_size = 64 16 | self.infer_batch_size = 256 17 | self.nr_classes = 4 18 | self.nr_epochs = 60 19 | self.epoch_length = 50 20 | 21 | # nr of processes for parallel processing input 22 | self.nr_procs_train = 8 23 | self.nr_procs_valid = 8 24 | 25 | self.nr_fold = 5 26 | self.fold_idx = 0 27 | self.cross_valid = False 28 | 29 | self.load_network = False 30 | self.save_net_path = "" 31 | 32 | # 33 | self.dataset = 'colon_manual' 34 | self.logging = True # True for debug run only 35 | 36 | self.log_path = '/media/data1/trinh_2021/data/workspace_data/join_learning_2021/colon/ordinalforcancer_v0/' 37 | 38 | self.chkpts_prefix = 'model' 39 | if _args is not None: 40 | self.__dict__.update(_args.__dict__) 41 | self.task_type = self.run_info.split('_')[0] 42 | self.loss_type = self.run_info.replace(self.task_type + "_", "") 43 | self.model_name = f'/{self.task_type}_{self.loss_type}_cancer_Effi_seed{self.seed}_BS64' 44 | self.log_dir = self.log_path + self.model_name 45 | print(self.model_name) 46 | 47 | def train_augmentors(self): 48 | if self.dataset == "prostate_hv": 49 | shape_augs = [ 50 | iaa.Resize(0.5, interpolation='nearest'), 51 | iaa.CropToFixedSize(width=350, height=350), 52 | ] 53 | else: 54 | shape_augs = [] 55 | # 56 | sometimes = lambda aug: iaa.Sometimes(0.2, aug) 57 | input_augs = iaa.Sequential( 58 | [ 59 | # apply the following augmenters to most images 60 | iaa.Fliplr(0.5), # horizontally flip 50% of all images 61 | iaa.Flipud(0.5), # vertically flip 50% of all images 62 | sometimes(iaa.Affine( 63 | rotate=(-45, 45), # rotate by -45 to +45 degrees 64 | shear=(-16, 16), # shear by -16 to +16 degrees 65 | order=[0, 1], # use nearest neighbour or bilinear interpolation (fast) 66 | cval=(0, 255), # if mode is constant, use a cval between 0 and 255 67 | mode='symmetric' 68 | # use any of scikit-image's warping modes (see 2nd image from the top for examples) 69 | )), 70 | # execute 0 to 5 of the following (less important) augmenters per image 71 | # don't execute all of them, as that would often be way too strong 72 | iaa.SomeOf((0, 5), 73 | [ 74 | iaa.OneOf([ 75 | iaa.GaussianBlur((0, 3.0)), # blur images with a sigma between 0 and 3.0 76 | iaa.AverageBlur(k=(2, 7)), 77 | # blur image using local means with kernel sizes between 2 and 7 78 | iaa.MedianBlur(k=(3, 11)), 79 | # blur image using local medians with kernel sizes between 2 and 7 80 | ]), 81 | iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5), 82 | # add gaussian noise to images 83 | iaa.Dropout((0.01, 0.1), per_channel=0.5), # randomly remove up to 10% of the pixels 
84 | # change brightness of images (by -10 to 10 of original value) 85 | iaa.AddToHueAndSaturation((-20, 20)), # change hue and saturation 86 | iaa.LinearContrast((0.5, 2.0), per_channel=0.5), # improve or worsen the contrast 87 | ], 88 | random_order=True 89 | ) 90 | ], 91 | random_order=True 92 | ) 93 | return shape_augs, input_augs 94 | 95 | #### 96 | def infer_augmentors(self): 97 | if self.dataset == "prostate_hv": 98 | shape_augs = [ 99 | iaa.Resize(0.5, interpolation='nearest'), 100 | iaa.CropToFixedSize(width=350, height=350, position="center"), 101 | ] 102 | else: 103 | shape_augs = [] 104 | return shape_augs, None 105 | 106 | ########################################################################### -------------------------------------------------------------------------------- /config_validator.py: -------------------------------------------------------------------------------- 1 | import imgaug # https://github.com/aleju/imgaug 2 | from imgaug import augmenters as iaa 3 | import imgaug as ia 4 | 5 | 6 | #### 7 | class Config(object): 8 | def __init__(self, _args=None): 9 | if _args is not None: 10 | self.__dict__.update(_args.__dict__) 11 | self.seed = 5 #self.seed 12 | self.infer_batch_size = 128 13 | self.nr_classes = 4 14 | 15 | # nr of processes for parallel processing input 16 | self.nr_procs_valid = 8 17 | 18 | self.load_network = False 19 | self.save_net_path = "" 20 | 21 | self.dataset = 'colon_manual' 22 | self.logging = False # True for debug run only 23 | self.log_path = "" 24 | self.chkpts_prefix = 'model' 25 | self.model_name = 'validator' 26 | self.log_dir = self.log_path + self.model_name 27 | print(self.model_name) 28 | 29 | #### 30 | def infer_augmentors(self): 31 | if self.dataset == "prostate_hv": 32 | shape_augs = [ 33 | iaa.Resize(0.5, interpolation='nearest'), 34 | iaa.CropToFixedSize(width=350, height=350, position="center"), 35 | ] 36 | else: 37 | shape_augs = [] 38 | return shape_augs, None 39 | 40 | ############################################################################ -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import glob 4 | import random 5 | from collections import Counter 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch.utils.data as data 10 | from torchvision import transforms 11 | from imgaug import augmenters as iaa 12 | 13 | #### 14 | 15 | 16 | class DatasetSerial(data.Dataset): 17 | 18 | def __init__(self, pair_list, shape_augs=None, input_augs=None, has_aux=False, test_aux=False): 19 | self.test_aux = test_aux 20 | self.pair_list = pair_list 21 | self.shape_augs = shape_augs 22 | self.input_augs = input_augs 23 | 24 | def __getitem__(self, idx): 25 | pair = self.pair_list[idx] 26 | # print(pair) 27 | input_img = cv2.imread(pair[0]) 28 | input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB) 29 | img_label = pair[1] 30 | # print(input_img.shape) 31 | transform = transforms.Compose([ 32 | transforms.ToTensor(), 33 | transforms.Normalize(mean=[0., 0., 0.], 34 | std=[1., 1., 1.]) 35 | ]) 36 | 37 | if not self.test_aux: 38 | 39 | # shape must be deterministic so it can be reused 40 | if self.shape_augs is not None: 41 | shape_augs = self.shape_augs.to_deterministic() 42 | input_img = shape_augs.augment_image(input_img) 43 | 44 | # additional augmenattion just for the input 45 | if self.input_augs is not None: 46 | input_img = 
self.input_augs.augment_image(input_img) 47 | 48 | input_img = np.array(input_img).copy() 49 | transform = transforms.Compose([ 50 | transforms.ToTensor(), 51 | transforms.Normalize(mean=[0., 0., 0.], 52 | std=[1., 1., 1.]) 53 | ]) 54 | 55 | out_img = np.array(transform(input_img)).transpose(1, 2, 0) 56 | else: 57 | out_img = [] 58 | for idx in range(5): 59 | input_img_ = input_img.copy() 60 | if self.shape_augs is not None: 61 | shape_augs = self.shape_augs.to_deterministic() 62 | input_img_ = shape_augs.augment_image(input_img_) 63 | input_img_ = iaa.Sequential(self.input_augs[idx]).augment_image(input_img_) 64 | input_img_ = np.array(input_img_).copy() 65 | input_img_ = np.array(transform(input_img_)).transpose(1, 2, 0) 66 | out_img.append(input_img_) 67 | return np.array(out_img), img_label 68 | 69 | def __len__(self): 70 | return len(self.pair_list) 71 | 72 | 73 | class DatasetSerialWSI(data.Dataset): 74 | def __init__(self, path_list): 75 | self.path_list = path_list 76 | 77 | def __getitem__(self, idx): 78 | input_img = cv2.imread(self.path_list[idx]) 79 | input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB) 80 | input_img = np.array(input_img).copy() 81 | transform = transforms.Compose([ 82 | transforms.ToTensor(), 83 | transforms.Normalize(mean=[0., 0., 0.], 84 | std=[1., 1., 1.]) 85 | ]) 86 | input_img = np.array(transform(input_img)).transpose(1, 2, 0) 87 | location = self.path_list[idx].split('/')[-1].split('.')[0].split('_') 88 | return input_img, location 89 | 90 | def __len__(self): 91 | return len(self.path_list) 92 | 93 | def prepare_colon_tma_data(): 94 | def load_data_info(pathname, parse_label=True, label_value=0): 95 | file_list = glob.glob(pathname) 96 | cancer_test = False 97 | if cancer_test: 98 | file_list_bn = glob.glob(pathname.replace('*.jpg', '*0.jpg')) 99 | file_list = [elem for elem in file_list if elem not in file_list_bn] 100 | label_list = [int(file_path.split('_')[-1].split('.')[0])-1 for file_path in file_list] 101 | else: 102 | if parse_label: 103 | label_list = [int(file_path.split('_')[-1].split('.')[0]) for file_path in file_list] 104 | else: 105 | label_list = [label_value for file_path in file_list] 106 | print(Counter(label_list)) 107 | return list(zip(file_list, label_list)) 108 | 109 | data_root_dir = '/media/data1/member1/projects/workspace_data/COLON_MANUAL_512/COLON_MANUAL_512' 110 | 111 | set_1010711 = load_data_info('%s/1010711/*.jpg' % data_root_dir) 112 | set_1010712 = load_data_info('%s/1010712/*.jpg' % data_root_dir) 113 | set_1010713 = load_data_info('%s/1010713/*.jpg' % data_root_dir) 114 | set_1010714 = load_data_info('%s/1010714/*.jpg' % data_root_dir) 115 | set_1010715 = load_data_info('%s/1010715/*.jpg' % data_root_dir) 116 | set_1010716 = load_data_info('%s/1010716/*.jpg' % data_root_dir) 117 | wsi_00016 = load_data_info('%s/wsi_00016/*.jpg' % data_root_dir, parse_label=True, 118 | label_value=0) # benign exclusively 119 | wsi_00017 = load_data_info('%s/wsi_00017/*.jpg' % data_root_dir, parse_label=True, 120 | label_value=0) # benign exclusively 121 | wsi_00018 = load_data_info('%s/wsi_00018/*.jpg' % data_root_dir, parse_label=True, 122 | label_value=0) # benign exclusively 123 | 124 | train_set = set_1010711 + set_1010712 + set_1010713 + set_1010715 + wsi_00016 125 | valid_set = set_1010716 + wsi_00018 126 | test_set = set_1010714 + wsi_00017 127 | return train_set, valid_set, test_set 128 | 129 | 130 | def prepare_colon_wsi_patch(data_visual=False): 131 | def load_data_info_from_list(data_dir, path_list): 132 | file_list = [] 
        for WSI_name in path_list:
            pathname = glob.glob(f'{data_dir}/{WSI_name}/*/*.png')
            file_list.extend(pathname)
        label_list = [int(file_path.split('_')[-1].split('.')[0]) - 1 for file_path in file_list]
        print(Counter(label_list))
        list_out = list(zip(file_list, label_list))
        return list_out

    data_root_dir = '/media/data1/trinh/data/workspace_data/colon_wsi/patches_colon_edit_MD/colon_45WSIs_1144_08_step05_05'
    # patches prepared for visualization (used when data_visual=True)
    data_visual_dir = '/media/data1/trinh/data/workspace_data/colon_wsi/patches_colon_edit_MD/colon_45WSIs_1144_01_step05_visualize/patch_512/'

    df_test = []  # Note: will be updated later

    if data_visual:
        test_set = load_data_info_from_list(data_visual_dir, df_test)
    else:
        test_set = load_data_info_from_list(data_root_dir, df_test)
    return test_set


def prepare_prostate_uhu_data():
    def load_data_info(pathname, parse_label=True, label_value=0, cancer_test=False):
        file_list = glob.glob(pathname)

        if cancer_test:
            file_list_bn = glob.glob(pathname.replace('*.jpg', '*0.jpg'))
            file_list = [elem for elem in file_list if elem not in file_list_bn]
            label_list = [int(file_path.split('_')[-1].split('.')[0]) - 1 for file_path in file_list]
        else:
            if parse_label:
                label_list = [int(file_path.split('_')[-1].split('.')[0]) for file_path in file_list]
            else:
                label_list = [label_value for file_path in file_list]
        print(Counter(label_list))
        return list(zip(file_list, label_list))

    data_root_dir = '/data1/trinh/data/patches_data/prostate_harvard/'
    data_root_dir_train = f'{data_root_dir}/train_validation_patches_750/'
    data_root_dir_test = f'{data_root_dir}/test_patches_750/'

    train_set_111 = load_data_info('%s/ZT111*/*.jpg' % data_root_dir_train)
    train_set_199 = load_data_info('%s/ZT199*/*.jpg' % data_root_dir_train)
    train_set_204 = load_data_info('%s/ZT204*/*.jpg' % data_root_dir_train)
    valid_set = load_data_info('%s/ZT76*/*.jpg' % data_root_dir_train)
    test_set = load_data_info('%s/patho_1/*/*.jpg' % data_root_dir_test)

    train_set = train_set_111 + train_set_199 + train_set_204
    return train_set, valid_set, test_set


def prepare_prostate_ubc_data(fold_idx=0):
    def load_data_info(pathname, parse_label=True, label_value=0):
        file_list = glob.glob(pathname)
        cancer_test = False
        if cancer_test:
            file_list_bn = glob.glob(pathname.replace('*.jpg', '*0.jpg'))
            file_list = [elem for elem in file_list if elem not in file_list_bn]
            label_list = [int(file_path.split('_')[-1].split('.')[0]) for file_path in file_list]
            label_dict = {2: 0, 3: 1, 4: 2}
            label_list = [label_dict[k] for k in label_list]
        else:
            if parse_label:
                label_list = [int(file_path.split('_')[-1].split('.')[0]) for file_path in file_list]
            else:
                label_list = [label_value for file_path in file_list]
            label_dict = {0: 0, 2: 1, 3: 2, 4: 3}
            label_list = [label_dict[k] for k in label_list]
        print(Counter(label_list))
        return list(zip(file_list, label_list))

    assert fold_idx < 3, "Currently only support fold_idx < 3; each fold is one TMA"

    data_root_dir = '/data1/trinh/data/patches_data/'
    data_root_dir_train_ubc = f'{data_root_dir}/prostate_miccai_2019_patches_690_80_step05_test/'
    test_set_ubc = load_data_info('%s/*/*.jpg' % data_root_dir_train_ubc)
    return test_set_ubc
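

# Note: the prepare_* functions above assume that each patch filename encodes its
# class label as the last underscore-separated token, e.g. 'xxx_3.jpg' -> label 3,
# which is what int(file_path.split('_')[-1].split('.')[0]) extracts.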


def visualize(ds, batch_size, nr_steps=100):
    data_idx = 0
    cmap = plt.get_cmap('jet')
    for i in range(0, nr_steps):
        if data_idx >= len(ds):
            data_idx = 0
        for j in range(1, batch_size + 1):
            sample = ds[data_idx + j]
            img = sample[0]
            if len(sample) > 2:
                # TODO: case with multiple channels
                aux = np.squeeze(sample[-1])
                aux = cmap(aux)[..., :3]  # gray to RGB heatmap
                aux = (aux * 255).astype('uint8')
                img = np.concatenate([img, aux], axis=0)
                img = cv2.resize(img, (40, 80), interpolation=cv2.INTER_CUBIC)
            plt.subplot(1, batch_size, j)
            plt.title(str(sample[1]))
            plt.imshow(img)
        plt.show()
        data_idx += batch_size

--------------------------------------------------------------------------------
/docs/JCOL.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/trinhvg/JCO_Learning-pytorch/6a6044801ac0a2f6ea2c479d4417e1128e5a0b72/docs/JCOL.jpg
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: jco_learning
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.7
  - pip=20.3.1
  - openslide
  - pip:
    - -r file:requirements.txt
    - openslide-python==1.1.2
--------------------------------------------------------------------------------
/infer_produce_predict_map_wsi.py:
--------------------------------------------------------------------------------
import os
import argparse
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
import glob

import dataset
from config_validator import Config
from misc.infer_wsi_utils import *
from loss.ceo_loss import count_pred


def compute_acc(pred_, ano_):
    pred, ano = pred_.copy(), ano_.copy()
    pred = pred[ano > 0]
    ano = ano[ano > 0]
    acc = np.mean(pred == ano)
    return np.round(acc, 4)


class Inferer(Config):
    def __init__(self, _args=None):
        super(Inferer, self).__init__(_args=_args)
        if _args is not None:
            # run_info, net_dir, in_img_path, in_ano_path, in_patch and
            # out_img_path all come in through the command-line arguments
            self.__dict__.update(_args.__dict__)
        self.net_name = self.run_info
        self.infer_batch_size = 256
        self.nr_procs_valid = 31
        self.patch_size = 1144
        self.patch_stride = 1144 // 2
        self.nr_classes = 4

    def resize_save(self, svs_code, save_name, img, scale=1.0):
        ano = img.copy()
        cmap = plt.get_cmap('jet')
        path = f'{self.out_img_path}/{svs_code}/'
        img = (cmap(img / scale)[..., :3] * 255).astype('uint8')
        img[ano == 0] = [10, 10, 10]
        img = cv2.resize(img, (0, 0), fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC)
        cv2.imwrite(f'{path}/{save_name}.png', cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        return 0

    def infer_step_m(self, net, batch, net_name):
        net.eval()  # infer mode

        imgs = batch  # batch is NHWC
        imgs = imgs.permute(0, 3, 1, 2)  # to NCHW

        # push data to GPUs and
convert to float32 61 | imgs = imgs.to('cuda').float() 62 | 63 | with torch.no_grad(): # dont compute gradient 64 | logit_class, _ = net(imgs) # forward 65 | prob = nn.functional.softmax(logit_class, dim=1) 66 | # prob = prob.permute(0, 2, 3, 1) # to NHWC 67 | return prob.cpu().numpy() 68 | 69 | def infer_step_c(self, net, batch, net_name): 70 | net.eval() # infer mode 71 | 72 | imgs = batch # batch is NHWC 73 | imgs = imgs.permute(0, 3, 1, 2) # to NCHW 74 | 75 | # push data to GPUs and convert to float32 76 | imgs = imgs.to('cuda').float() 77 | 78 | with torch.no_grad(): # dont compute gradient 79 | logit_class = net(imgs) # forward 80 | prob = nn.functional.softmax(logit_class, dim=1) 81 | # prob = prob.permute(0, 2, 3, 1) # to NHWC 82 | return prob.cpu().numpy() 83 | 84 | def infer_step_r(self, net, batch, net_name): 85 | net.eval() # infer mode 86 | 87 | imgs = batch # batch is NHWC 88 | imgs = imgs.permute(0, 3, 1, 2) # to NCHW 89 | 90 | # push data to GPUs and convert to float32 91 | imgs = imgs.to('cuda').float() 92 | 93 | with torch.no_grad(): # dont compute gradient 94 | if "rank_ordinal" in net_name: 95 | logits, probas = net(imgs) 96 | predict_levels = probas > 0.5 97 | pred = torch.sum(predict_levels, dim=1) 98 | return pred.cpu().numpy() 99 | elif "rank_dorn" in net_name: 100 | pred, softmax = net(imgs) 101 | return pred.cpu().numpy() 102 | elif "soft_label" in net_name: 103 | logit_regres = net(imgs) # forward 104 | label = torch.tensor([0., 1. / 3., 2. / 3., 1.]).repeat(len(logit_regres), 1).permute(1, 0).cuda() 105 | idx = torch.argmin(torch.abs(logit_regres - label), 0) 106 | return idx.cpu().numpy() 107 | elif "FocalOrdinal" in net_name: 108 | logit_regress = net(imgs) 109 | pred = count_pred(logit_regress) 110 | return pred.cpu().numpy() 111 | else: 112 | logit_regres = net(imgs) # forward 113 | label = torch.tensor([0., 1., 2., 3.]).repeat(len(logit_regres), 1).permute(1, 0).cuda() 114 | idx = torch.argmin(torch.abs(logit_regres - label), 0) 115 | return idx.cpu().numpy() 116 | 117 | def predict_one_model(self, net, svs_code, net_name="Multi_512_mse"): 118 | infer_step = Inferer.__getattribute__(self, f'infer_step_{net_name[0].lower()}') 119 | ano = np.float32(np.load(f'{self.in_ano_path}/{svs_code}.npy')) # [h, w] 120 | inf_output_dir = f'{self.out_img_path}/{svs_code}/' 121 | if not os.path.isdir(inf_output_dir): 122 | os.makedirs(inf_output_dir) 123 | 124 | path_pairs = glob.glob(f'{self.in_patch}/{svs_code}/*/*.png') 125 | infer_dataset = dataset.DatasetSerialWSI(path_pairs) 126 | dataloader = data.DataLoader(infer_dataset, 127 | num_workers=self.nr_procs_valid, 128 | batch_size=256, 129 | shuffle=False, 130 | drop_last=False) 131 | 132 | out_prob = np.zeros([self.nr_classes, ano.shape[0], ano.shape[1]], dtype=np.float32) # [h, w] 133 | out_prob_count = np.zeros([ano.shape[0], ano.shape[1]], dtype=np.float32) # [h, w] 134 | 135 | for batch_data in dataloader: 136 | imgs_input, imgs_path = batch_data 137 | imgs_path = np.array(imgs_path).transpose(1, 0) 138 | output_prob = infer_step(net, imgs_input, net_name) 139 | for idx, patch_loc in enumerate(imgs_path): 140 | patch_loc = patch_loc.astype(int) // 16 141 | patch_loc = [patch_loc[0], patch_loc[1]] 142 | out_prob_count[patch_loc[0]:patch_loc[0] + self.patch_size // 16, 143 | patch_loc[1]:patch_loc[1] + self.patch_size // 16] += 1 144 | for grade in range(self.nr_classes): 145 | out_prob[grade][patch_loc[0]:patch_loc[0] + self.patch_size // 16, 146 | patch_loc[1]:patch_loc[1] + self.patch_size // 16] += 
output_prob[idx][grade] 147 | 148 | out_prob_count[out_prob_count == 0.] = 1. 149 | out_prob /= out_prob_count 150 | predict = np.argmax(out_prob, axis=0) + 1 151 | 152 | for c in range(self.nr_classes): 153 | out_prob[c][ano == 0] = 0 154 | predict[ano == 0] = 0 155 | 156 | acc = compute_acc(predict, ano) 157 | print(acc) 158 | 159 | self.resize_save(svs_code, f'predict_{net_name}_{acc}', predict, scale=4.0) 160 | self.resize_save(svs_code, 'ano', ano, scale=4.0) 161 | np.save(f'{self.out_img_path}/{svs_code}/predict_{net_name}', predict) 162 | np.save(f'{self.out_img_path}/{svs_code}/ano', ano) 163 | print('done') 164 | return 0 165 | 166 | def predict_one_model_regress(self, net, svs_code, net_name="Multi_512_mse"): 167 | infer_step = Inferer.__getattribute__(self, f'infer_step_{net_name[0].lower()}') 168 | ano = np.float32(np.load(f'{self.in_ano_path}/{svs_code}.npy')) # [h, w] 169 | inf_output_dir = f'{self.out_img_path}/{svs_code}/' 170 | if not os.path.isdir(inf_output_dir): 171 | os.makedirs(inf_output_dir) 172 | 173 | path_pairs = glob.glob(f'{self.in_patch}/{svs_code}/*/*.png') 174 | infer_dataset = dataset.DatasetSerialWSI(path_pairs) 175 | dataloader = data.DataLoader(infer_dataset, 176 | num_workers=self.nr_procs_valid, 177 | batch_size=128, 178 | shuffle=False, 179 | drop_last=False) 180 | out_prob = np.zeros([self.nr_classes, ano.shape[0], ano.shape[1]], dtype=np.float32) # [h, w] 181 | 182 | for batch_data in dataloader: 183 | imgs_input, imgs_path = batch_data 184 | imgs_path = np.array(imgs_path).transpose(1, 0) 185 | output_prob = infer_step(net, imgs_input, net_name) 186 | for idx, patch_loc in enumerate(imgs_path): 187 | patch_loc = patch_loc.astype(int) // 16 188 | patch_loc = [patch_loc[0], patch_loc[1]] 189 | for grade in range(self.nr_classes): 190 | if grade == output_prob[idx]: 191 | out_prob[grade][patch_loc[0]:patch_loc[0] + self.patch_size // 16, 192 | patch_loc[1]:patch_loc[1] + self.patch_size // 16] += 1 193 | predict = np.argmax(out_prob, axis=0) + 1 194 | 195 | for c in range(self.nr_classes): 196 | out_prob[c][ano == 0] = 0 197 | predict[ano == 0] = 0 198 | 199 | acc = compute_acc(predict, ano) 200 | plt.imshow(predict) 201 | plt.show() 202 | print(acc) 203 | self.resize_save(svs_code, f'predict_{net_name}_{acc}', predict, scale=4.0) 204 | self.resize_save(svs_code, 'ano', ano, scale=4.0) 205 | np.save(f'{self.out_img_path}/{svs_code}/predict_{net_name}', predict) 206 | np.save(f'{self.out_img_path}/{svs_code}/ano', ano) 207 | print('done') 208 | return 0 209 | 210 | def run_wsi(self): 211 | device = 'cuda' 212 | 213 | self.task_type = self.net_name.split('_')[0] 214 | 215 | if "rank_dorn" in self.net_name: 216 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_rank_ordinal') # dynamic import 217 | net = net_def.jl_efficientnet(task_mode='regress_rank_dorn', pretrained=True) 218 | 219 | elif "FocalOrdinalLoss" in self.net_name: 220 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model') # dynamic import 221 | net = net_def.jl_efficientnet(task_mode='class', pretrained=True, num_classes=3) 222 | else: 223 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model') # dynamic import 224 | net = net_def.jl_efficientnet(task_mode=self.task_type.lower(), pretrained=True) 225 | 226 | net = torch.nn.DataParallel(net).to(device) 227 | inf_model_path = os.path.join(self.net_dir, self.net_name, f'trained_net.pth') 228 | saved_state = torch.load(inf_model_path) 229 | net.load_state_dict(saved_state) 230 | 
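        # run inference for every WSI that has an annotation .npy file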
        name_wsi_list = findExtension(self.in_ano_path, '.npy')

        for name in name_wsi_list:
            svs_code = name[:-4]
            print(svs_code)
            acc_wsi = []
            if 'REGRESS' in self.net_name:
                acc_one_model = self.predict_one_model_regress(net, svs_code, net_name=self.net_name)
            else:
                acc_one_model = self.predict_one_model(net, svs_code, net_name=self.net_name)
            acc_wsi.append(acc_one_model)


####
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--run_info', type=str, default='REGRESS_rank_dorn',
                        help='CLASS, REGRESS, MULTI + loss, '
                             'loss ex: Class_ce, MULTI_mtmr, REGRESS_rank_ordinal, REGRESS_rank_dorn, '
                             'REGRESS_FocalOrdinalLoss, REGRESS_soft_ordinal')
    parser.add_argument('--net_dir', type=str,
                        default='/media/trinh/Data0/submit_paper_data/JL_pred/model/JL_model/JL_colon_model/',
                        help='path to checkpoint model')
    parser.add_argument('--in_img_path', type=str,
                        default='/media/data1/trinh/data/workspace_data/colon_wsi/ColonWSI/',
                        help='path to wsi image')
    parser.add_argument('--in_ano_path', type=str,
                        default='/media/data1/trinh/data/workspace_data/colon_wsi/Colon_WSI_annotation_npy/',
                        help='path to wsi npy annotation')
    parser.add_argument('--in_patch', type=str,
                        default='/media/data1/trinh/data/workspace_data/colon_wsi/patches_colon/colon_45WSIs_1144_01_step05_visualize_resize512/',
                        help='path to patch images')
    parser.add_argument('--out_img_path', type=str,
                        default='/media/data1/trinh/data/workspace_data/colon_wsi/JointLearning_wsi_pred/',
                        help='path to output prediction maps')

    args = parser.parse_args()
    inferer = Inferer(_args=args)
    inferer.run_wsi()
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/trinhvg/JCO_Learning-pytorch/6a6044801ac0a2f6ea2c479d4417e1128e5a0b72/loss/__init__.py
--------------------------------------------------------------------------------
/loss/cancer_loss.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


def mae_cancer_v0(input, target):
    input_ = input[target != 0]
    target_ = target[target != 0]
    return F.l1_loss(input_, target_) if len(target_) != 0 else 0


# def mse_cancer(input, target):
#     input_ = input[target != 0]
#     target_ = target[target != 0]
#     return F.mse_loss(input_, target_) if len(target_) != 0 else 0


def mse_cancer_v0(input, target):
    input_ = input[target != 0]
    target_ = target[target != 0]
    return F.mse_loss(input_, target_) if len(target_) != 0 else 0


def ceo_cancer_v0(input, target):
    input_ = input[target != 0]
    target_ = target[target != 0]
    if len(target_) == 0:
        return 0
    label_ = torch.tensor([1., 2., 3.]).repeat(len(target_), 1).cuda()
    logit_proposed_ = input_.repeat(3, 1).permute(1, 0)
    logit_proposed_ = torch.abs(logit_proposed_ - label_)
    return F.cross_entropy(-logit_proposed_, target_ - 1)


def mae_cancer(input, target):
    mae_loss = F.l1_loss(input,
target, reduction='none') 37 | select = torch.randint(0, 2, (target.shape[0],)).float().cuda() * torch.sign(target) 38 | return (mae_loss*select).mean() 39 | 40 | 41 | # def mse_cancer(input, target): 42 | # input_ = input[target != 0] 43 | # target_ = target[target != 0] 44 | # return F.mse_loss(input_, target_) if len(target_) != 0 else 0 45 | 46 | 47 | def mse_cancer(input, target): 48 | mse_loss = F.mse_loss(input, target, reduction='none') 49 | # print(mse_loss.shape) 50 | # print(torch.sign(target).shape) 51 | # print(torch.sign(target)) 52 | select = torch.randint(0, 2, (target.shape[0],)).float().cuda() * torch.sign(target) 53 | return (mse_loss*select).mean() 54 | 55 | 56 | def ceo_cancer(input, target): 57 | label = torch.tensor([0., 1., 2., 3.]).repeat(len(target), 1).cuda() 58 | logit_proposed = input.repeat(4, 1).permute(1, 0) 59 | logit_proposed = torch.abs(logit_proposed - label) 60 | ceo_loss = F.cross_entropy(-logit_proposed, target, reduction='none') 61 | # select = (torch.randint(0, 2, (target.shape[0],)).cuda() * torch.sign(target)).float() 62 | select = torch.sign(target).float() 63 | return (ceo_loss*select).mean() 64 | 65 | # class CeoCancer: 66 | # def __init__(self, ): 67 | # super(CeoCancer, self).__init__() 68 | -------------------------------------------------------------------------------- /loss/ceo_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | class CEOLoss(nn.Module): 8 | """ 9 | Args: 10 | num_classes (int): number of classes. 11 | """ 12 | def __init__(self, num_classes=4): 13 | super(CEOLoss, self).__init__() 14 | self.num_classes = num_classes 15 | self.level = torch.arange(self.num_classes) 16 | 17 | def forward(self, x, y): 18 | """" 19 | Args: 20 | x (tensor): Regression/ordinal output, size (B), type: float 21 | y (tensor): Ground truth, size (B), type: int/long 22 | 23 | Returns: 24 | CEOLoss: Cross-Entropy Ordinal loss 25 | """ 26 | levels = self.level.repeat(len(y), 1).cuda() 27 | logit = x.repeat(self.num_classes, 1).permute(1, 0) 28 | logit = torch.abs(logit - levels) 29 | return F.cross_entropy(-logit, y, reduction='mean') 30 | 31 | 32 | 33 | class FocalLoss(nn.Module): 34 | def __init__(self, alpha=1, gamma=2, logits=False, reduce=True): 35 | super(FocalLoss, self).__init__() 36 | self.alpha = alpha 37 | self.gamma = gamma 38 | self.logits = logits 39 | self.reduce = reduce 40 | 41 | def forward(self, inputs, targets): 42 | if self.logits: 43 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False) 44 | else: 45 | BCE_loss = F.cross_entropy(inputs, targets, reduce=None, reduction='none') 46 | pt = torch.exp(-BCE_loss) 47 | F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss 48 | 49 | if self.reduce: 50 | return torch.mean(F_loss) 51 | else: 52 | return F_loss 53 | 54 | 55 | class SoftLabelOrdinalLoss(nn.Module): 56 | def __init__(self, alpha=1.): 57 | super(SoftLabelOrdinalLoss, self).__init__() 58 | self.alpha = alpha 59 | 60 | def forward(self, x, y): 61 | """Validates model name. 
62 | 63 | Args: 64 | x (Tensor): [0, 1, 2, 3] 65 | y (Tensor): [0, 1, 2, 3] 66 | 67 | Returns: 68 | loss: scalar 69 | """ 70 | # y /= 3 71 | # y /= 2 72 | x = torch.sigmoid(x) 73 | soft_loss = -(1 - y) * torch.log(1 - x) - self.alpha * y * torch.log(x) 74 | return torch.mean(soft_loss) 75 | 76 | 77 | 78 | def label_to_levels(label, num_classes=4): 79 | levels = [1] * label + [0] * (num_classes - 1 - label) 80 | levels = torch.tensor(levels, dtype=torch.float32) 81 | return levels 82 | 83 | 84 | def labels_to_labels(class_labels, num_classes =4): 85 | """ 86 | class_labels = [2, 1, 3] 87 | """ 88 | levels = [] 89 | for label in class_labels: 90 | levels_from_label = label_to_levels(int(label), num_classes=num_classes) 91 | levels.append(levels_from_label) 92 | return torch.stack(levels).cuda() 93 | 94 | 95 | def cost_fn(logits, label): 96 | num_classes = 3 #Note 97 | imp = torch.ones(num_classes - 1, dtype=torch.float).cuda() 98 | levels = labels_to_labels(label, num_classes) 99 | val = (-torch.sum((F.log_softmax(logits, dim=2)[:, :, 1] * levels 100 | + F.log_softmax(logits, dim=2)[:, :, 0] * (1 - levels)) * imp, dim=1)) 101 | return torch.mean(val) 102 | 103 | 104 | def loss_fn2(logits, label): 105 | num_classes = 3 #Note 106 | imp = torch.ones(num_classes - 1, dtype=torch.float) 107 | levels = labels_to_labels(label) 108 | val = (-torch.sum((F.logsigmoid(logits) * levels 109 | + (F.logsigmoid(logits) - logits) * (1 - levels)) * imp, 110 | dim=1)) 111 | return torch.mean(val) 112 | 113 | 114 | class FocalOrdinalLoss(nn.Module): 115 | def __init__(self, alpha=0.75, pooling=False, num_classes=4): 116 | super(FocalOrdinalLoss, self).__init__() 117 | self.alpha = alpha 118 | self.pooling = pooling 119 | self.num_classes = num_classes 120 | 121 | def forward(self, x, y): 122 | # convert one-hot y to ordinal y 123 | levels = labels_to_labels(y, num_classes=self.num_classes) 124 | q, _ = torch.max(levels*(1-x)**2 + (1-levels)*x**2, dim=1) 125 | if self.pooling: 126 | q = q.unsqueeze(0) 127 | q = q.unsqueeze(0) 128 | q = nn.MaxPool1d(3, 1, padding=1)(q) 129 | x = torch.sigmoid(x) 130 | # compute the loss 131 | f_loss = q*torch.sum(-self.alpha*levels*torch.log(x) - (1-self.alpha)*(1-levels)*torch.log(1-x)) 132 | return torch.mean(f_loss) 133 | 134 | 135 | 136 | 137 | 138 | def count_pred(x): 139 | N = x.shape[0] 140 | x = x.cuda() > 0.5 141 | pred = torch.zeros(N).long().cuda() 142 | pred = pred.view(N, 1) 143 | for i in range(x.shape[1]): 144 | pred_i = x[:, :i+1].prod(1)*x[:, :i+1].sum(1) 145 | pred = torch.cat([pred, pred_i.view(N, 1)], dim =1) 146 | return pred.max(1)[0] 147 | 148 | 149 | # # 150 | # import os 151 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1' 152 | # def test(): 153 | # # x = torch.Tensor([[0.7, 0.5, 0.6], [0.5, 0.8, 0.2], [0.8, 0.6, 0.1], [0.1, 0.5, 0.6]]) 154 | # # y = torch.Tensor([1., 2., 3., 0.]) 155 | # # x = x.to("cuda") 156 | # # y = y.to("cuda") 157 | # # FocalOrdinalLoss()(x, y) 158 | # # count_pred(x) 159 | # 160 | # 161 | # x = torch.Tensor([0.7, 2., 0.6, 1.]) 162 | # y = torch.Tensor([1, 2, 3, 0]) 163 | # x = x.to("cuda") 164 | # y = y.to("cuda") 165 | # CEOLoss(4)(x, y) 166 | # count_pred(x) 167 | # 168 | # test() 169 | # 170 | # # 171 | # # 172 | # # 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | -------------------------------------------------------------------------------- /loss/dorn_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | """ 5 | refer 
to https://github.com/liviniuk/DORN_depth_estimation_Pytorch 6 | """ 7 | 8 | 9 | class OrdinalLoss(nn.Module): 10 | """ 11 | Ordinal loss as defined in the paper "DORN for Monocular Depth Estimation". 12 | refer to https://github.com/liviniuk/DORN_depth_estimation_Pytorch 13 | """ 14 | 15 | def __init__(self): 16 | super(OrdinalLoss, self).__init__() 17 | 18 | def forward(self, pred_softmax, target_labels): 19 | """ 20 | :param pred_softmax: predicted softmax probabilities P 21 | :param target_labels: ground truth ordinal labels 22 | :return: ordinal loss 23 | """ 24 | 25 | n, c = pred_softmax.size() # C - number of discrete sub-intervals (= number of channels) 26 | target_labels = target_labels.int().view(n, 1) 27 | 28 | K = torch.zeros((n, c), dtype=torch.int).cuda() 29 | for i in range(c): 30 | K[:, i] = K[:, i] + i * torch.ones(n, dtype=torch.int).cuda() 31 | 32 | mask = (K <= target_labels).detach() 33 | 34 | loss = pred_softmax[mask].clamp(1e-8, 1e8).log().sum() + (1 - pred_softmax[~mask]).clamp(1e-8, 1e8).log().sum() 35 | loss /= -n 36 | return loss 37 | -------------------------------------------------------------------------------- /loss/mtmr_loss.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | https://github.com/liulihao-cuhk/MTMR-NET 5 | """ 6 | import os 7 | from torch.autograd import Variable 8 | from collections import OrderedDict 9 | import torch.nn as nn 10 | import numpy as np 11 | import torch 12 | import math 13 | 14 | def get_loss_mtmr(output_score_1, cat_subtlety_score, gt_score_1, gt_attribute_score_1): 15 | xcentloss_func_1 = nn.CrossEntropyLoss() 16 | xcentloss_1 = xcentloss_func_1(output_score_1, gt_score_1) 17 | 18 | # ranking loss 19 | ranking_loss_sum = 0 20 | half_size_of_output_score = output_score_1.size()[0] // 2 21 | for i in range(half_size_of_output_score): 22 | tmp_output_1 = output_score_1[i] 23 | tmp_output_2 = output_score_1[i + half_size_of_output_score] 24 | tmp_gt_score_1 = gt_score_1[i] 25 | tmp_gt_score_2 = gt_score_1[i + half_size_of_output_score] 26 | 27 | rankingloss_func = nn.MarginRankingLoss() 28 | 29 | if tmp_gt_score_1.item() != tmp_gt_score_2.item(): 30 | target = torch.ones(1) * -1 31 | ranking_loss_sum += rankingloss_func(tmp_output_1, tmp_output_2, Variable(target.cuda())) 32 | else: 33 | target = torch.ones(1) 34 | ranking_loss_sum += rankingloss_func(tmp_output_1, tmp_output_2, Variable(target.cuda())) 35 | 36 | ranking_loss = ranking_loss_sum / half_size_of_output_score 37 | 38 | # attribute loss 39 | attribute_mseloss_func_1 = nn.MSELoss() 40 | attribute_mseloss_1 = attribute_mseloss_func_1(cat_subtlety_score, gt_attribute_score_1.float()) 41 | 42 | loss = 1 * xcentloss_1 + 5.0e-1 * ranking_loss + 1.0e-3 * attribute_mseloss_1 43 | 44 | return loss 45 | -------------------------------------------------------------------------------- /loss/rank_ordinal_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import pandas as pd 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | """ 9 | https://github.com/Raschka-research-group/coral-cnn/blob/master/coral-implementation-recipe.ipynb 10 | """ 11 | 12 | 13 | def label_to_levels(label, num_classes=4): 14 | levels = [1] * label + [0] * (num_classes - 1 - label) 15 | levels = torch.tensor(levels, dtype=torch.float32) 16 | return levels 17 | 18 | 19 | def labels_to_labels(class_labels, 
num_classes): 20 | """ 21 | class_labels = [2, 1, 3] 22 | """ 23 | levels = [] 24 | for label in class_labels: 25 | levels_from_label = label_to_levels(int(label), num_classes=num_classes) 26 | levels.append(levels_from_label) 27 | return torch.stack(levels).cuda() 28 | 29 | 30 | def cost_fn(logits, label, num_classes): 31 | imp = torch.ones(num_classes - 1, dtype=torch.float).cuda() 32 | levels = labels_to_labels(label, num_classes) 33 | val = (-torch.sum((F.log_softmax(logits, dim=2)[:, :, 1] * levels 34 | + F.log_softmax(logits, dim=2)[:, :, 0] * (1 - levels)) * imp, dim=1)) 35 | return torch.mean(val) 36 | 37 | 38 | def loss_fn2(logits, label): 39 | num_classes = 4 40 | imp = torch.ones(num_classes - 1, dtype=torch.float) 41 | levels = labels_to_labels(label) 42 | val = (-torch.sum((F.logsigmoid(logits) * levels 43 | + (F.logsigmoid(logits) - logits) * (1 - levels)) * imp, 44 | dim=1)) 45 | return torch.mean(val) 46 | -------------------------------------------------------------------------------- /misc/infer_wsi_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil # High-level file operations 3 | from itertools import chain 4 | from sklearn.metrics import f1_score 5 | import random 6 | import cv2 7 | import numpy as np 8 | import torch.utils.data as data 9 | from torchvision import transforms 10 | 11 | 12 | def color_mask(a, r, g, b): 13 | ch_r = a[..., 0] == r 14 | ch_g = a[..., 1] == g 15 | ch_b = a[..., 2] == b 16 | return ch_r & ch_g & ch_b 17 | 18 | 19 | def normalize(mask, dtype=np.uint8): 20 | return (255 * mask / np.amax(mask)).astype(dtype) 21 | 22 | 23 | def bounding_box(img): 24 | rows = np.any(img, axis=1) 25 | cols = np.any(img, axis=0) 26 | rmin, rmax = np.where(rows)[0][[0, -1]] 27 | cmin, cmax = np.where(cols)[0][[0, -1]] 28 | return rmin, rmax, cmin, cmax 29 | 30 | 31 | def cropping_center(x, crop_shape, batch=False): 32 | orig_shape = x.shape 33 | if not batch: 34 | h0 = int((orig_shape[0] - crop_shape[0]) * 0.5) 35 | w0 = int((orig_shape[1] - crop_shape[1]) * 0.5) 36 | x = x[h0:h0 + crop_shape[0], w0:w0 + crop_shape[1]] 37 | else: 38 | h0 = int((orig_shape[1] - crop_shape[0]) * 0.5) 39 | w0 = int((orig_shape[2] - crop_shape[1]) * 0.5) 40 | x = x[:, h0:h0 + crop_shape[0], w0:w0 + crop_shape[1]] 41 | return x 42 | 43 | 44 | # to make it easier for visualization 45 | def randomize_label(label_map): 46 | label_list = np.unique(label_map) 47 | label_list = label_list[1:] # exclude the background 48 | label_rand = list(label_list) # dup frist cause shuffle is done in place 49 | random.shuffle(label_rand) 50 | new_map = np.zeros(label_map.shape, dtype=label_map.dtype) 51 | 52 | 53 | """Recursive directory creation function. Like mkdir(), 54 | but makes all intermediate-level directories needed to contain the leaf directory. 
55 | A leaf is a node on a tree with no child nodes.""" 56 | 57 | 58 | def rm_n_mkdir(dir): 59 | if os.path.isdir(dir): 60 | shutil.rmtree(dir) 61 | os.makedirs(dir) 62 | 63 | 64 | ### 65 | # test 66 | 67 | # import cv2 68 | # import matplotlib.pyplot as plt 69 | # 70 | # img = cv2.imread('/media/vtltrinh/Data1/COLON_MANUAL_PATCHES/v1/1010711/000_3.jpg') 71 | # im = np.array(img) 72 | # im_mask = color_mask(im, 1, 1, 1) 73 | # 74 | # bound = bounding_box(im) 75 | # print(bound) 76 | 77 | def findExtension(directory, extension='.txt'): 78 | files = [] 79 | for file in os.listdir(directory): 80 | if file.endswith(extension): 81 | files += [file] 82 | files.sort() 83 | return files 84 | 85 | 86 | def generate_patch_list_(roi, patch_size, stride): 87 | min_height, min_width, max_height, max_width = roi 88 | min_height, min_width, max_height, max_width = min_height - stride, min_width - stride, max_height + stride, max_width + stride 89 | h_list = np.arange(min_height, max_height - patch_size, stride) 90 | w_list = np.arange(min_width, max_width - patch_size, stride) 91 | out = [[[h_list[h], w_list[w]] for w in range(len(w_list))] for h in range(len(h_list))] 92 | return list(chain(*out)) 93 | 94 | 95 | def generate_patch_list(ano, roi, patch_size, stride): 96 | min_height, min_width, max_height, max_width = roi 97 | min_height, min_width, max_height, max_width = min_height - stride, min_width - stride, max_height + stride, max_width + stride 98 | h_list = np.arange(min_height, max_height - patch_size, stride) 99 | w_list = np.arange(min_width, max_width - patch_size, stride) 100 | out = [[[h_list[h], w_list[w]] for w in range(len(w_list))] for h in range(len(h_list))] 101 | path_list = list(chain(*out)) 102 | # print(len(path_list)) 103 | infer_dataset = DatasetSelectPatch(ano, path_list, patch_size) 104 | path_loader = data.DataLoader(infer_dataset, num_workers=31, batch_size=1144, shuffle=False, drop_last=False) 105 | for keeps, loca in path_loader: 106 | keeps_ = keeps.to('cuda') 107 | keeps_ += 1 108 | for idx in range(len(keeps)): 109 | if keeps[idx] == 1: 110 | a = eval(loca[idx]) 111 | path_list.remove(a) 112 | # print('hi', len(path_list)) 113 | return path_list 114 | 115 | 116 | def read_ano_text(text_path): 117 | list_labels = { 118 | "BG": 0, 119 | "BN": 1, 120 | "WD": 2, 121 | "MD": 3, 122 | "PD": 4, 123 | "Ad": 5, 124 | } 125 | text_file = open(text_path, "r") 126 | lines = text_file.readlines() 127 | lines = [line.replace('\n', '').replace('\t', '') for line in lines] 128 | anos_dict = {} 129 | count_ROIs = np.zeros(shape=5, dtype=int) 130 | for label in list_labels: 131 | anos_dict.__setitem__(label, {}) 132 | 133 | for line in lines[1:-1]: 134 | if line[1:3] in list_labels: 135 | label_id = line[1:3] 136 | coordinates = [] 137 | count_ROIs[list_labels[label_id] - 1] += 1 138 | ROIs_id = count_ROIs[list_labels[label_id] - 1] 139 | else: 140 | if 'X' in line: 141 | dims_val = eval(line.replace("},", "}")) 142 | coordinates.append([int(dims_val[dim]) for dim in dims_val.keys()]) 143 | else: 144 | anos_dict[label_id].__setitem__(ROIs_id, coordinates) 145 | 146 | keys_to_remove = ["BG", "Ad"] 147 | for key in keys_to_remove: 148 | del anos_dict[key] 149 | return anos_dict 150 | 151 | 152 | def find_roi(anos_dict): 153 | min_height = [] 154 | min_width = [] 155 | max_height = [] 156 | max_width = [] 157 | valid_ano = ['BN', 'WD', 'MD', 'PD'] 158 | for label_key in anos_dict.keys(): 159 | if label_key in valid_ano: 160 | for polygon_key in anos_dict[label_key]: 161 | region = 
anos_dict[label_key][polygon_key] 162 | min_height.append(np.int32([region])[0, :, 1].min()) # np(height, width) while openslide (with,height) 163 | min_width.append(np.int32([region])[0, :, 0].min()) 164 | max_height.append(np.int32([region])[0, :, 1].max()) # np(height, width) while openslide (with,height) 165 | max_width.append(np.int32([region])[0, :, 0].max()) 166 | min_height = min(min_height) 167 | min_width = min(min_width) 168 | max_height = max(max_height) 169 | max_width = max(max_width) 170 | return [min_height, min_width, max_height, max_width] 171 | 172 | 173 | def compute_f1(pred, ano): 174 | pred, ano = pred.flatten(), ano.flatten() 175 | pred = pred[ano != 0] 176 | ano = ano[ano != 0] 177 | f1 = f1_score(ano, pred, average='macro', labels=np.unique(ano)) 178 | return int(f1 * 10000) 179 | 180 | 181 | class DatasetSelectPatch(data.Dataset): 182 | def __init__(self, ano, path_list, patch_size): 183 | self.ano = ano 184 | self.path_list = path_list 185 | self.patch_size = patch_size 186 | 187 | def __getitem__(self, idx): 188 | w = self.path_list[idx][0]//16 189 | h = self.path_list[idx][1]//16 190 | patch_size = self.patch_size//16 191 | input_img = self.ano[w: w + patch_size, h: h + patch_size] 192 | 193 | if input_img.size == 0: 194 | keep = np.array([0]) 195 | elif input_img.mean() > 0: 196 | keep = np.array([1]) 197 | else: 198 | keep = np.array([0]) 199 | return keep, str(self.path_list[idx]) 200 | 201 | def __len__(self): 202 | return len(self.path_list) 203 | -------------------------------------------------------------------------------- /misc/train_ultils_all_iter.py: -------------------------------------------------------------------------------- 1 | import io 2 | import itertools 3 | import json 4 | import os 5 | 6 | import random 7 | import re 8 | import shutil 9 | import textwrap 10 | 11 | import cv2 12 | import matplotlib 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import pandas as pd 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from sklearn.metrics import confusion_matrix 20 | from termcolor import colored 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | import imgaug as ia 26 | 27 | 28 | def check_manual_seed(seed): 29 | """ 30 | If manual seed is not specified, choose a random one and notify it to the user 31 | """ 32 | seed = seed 33 | random.seed(seed) 34 | np.random.seed(seed) 35 | torch.manual_seed(seed) 36 | torch.cuda.manual_seed(seed) 37 | ia.seed(seed) 38 | torch.cuda.manual_seed_all(seed) 39 | torch.backends.cudnn.benchmark = False 40 | torch.backends.cudnn.deterministic = True 41 | 42 | print('Using manual seed: {seed}'.format(seed=seed)) 43 | return 44 | 45 | 46 | def check_log_dir(log_dir): 47 | # check if log dir exist 48 | if os.path.isdir(log_dir): 49 | color_word = colored('WARMING', color='red', attrs=['bold', 'blink']) 50 | print('%s: %s exist!' % (color_word, colored(log_dir, attrs=['underline']))) 51 | while (True): 52 | print('Select Action: d (delete)/ q (quit)', end='') 53 | key = input() 54 | if key == 'd': 55 | shutil.rmtree(log_dir) 56 | break 57 | elif key == 'q': 58 | exit() 59 | else: 60 | color_word = colored('ERR', color='red') 61 | print('---[%s] Unrecognized character!' 
% color_word) 62 | return 63 | 64 | 65 | def plot_confusion_matrix(conf_mat, label): 66 | """ 67 | Parameters: 68 | title='Confusion matrix' : Title for your matrix 69 | tensor_name = 'MyFigure/image' : Name for the output summay tensor 70 | Returns: 71 | summary: image of plot figure 72 | Other items to note: 73 | - Depending on the number of category and the data , you may have to modify the figzie, font sizes etc. 74 | - Currently, some of the ticks dont line up due to rotations. 75 | """ 76 | 77 | cm = conf_mat 78 | 79 | np.set_printoptions(precision=2) # print numpy array with 2 decimal places 80 | 81 | fig = matplotlib.figure.Figure(figsize=(7, 7), dpi=320, facecolor='w', edgecolor='k') 82 | ax = fig.add_subplot(1, 1, 1) 83 | im = ax.imshow(cm, cmap='Oranges') 84 | 85 | classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in label] 86 | classes = ['\n'.join(textwrap.wrap(l, 40)) for l in classes] 87 | 88 | tick_marks = np.arange(len(classes)) 89 | 90 | ax.set_xlabel('Predicted', fontsize=7) 91 | ax.set_xticks(tick_marks) 92 | c = ax.set_xticklabels(classes, fontsize=4, rotation=-90, ha='center') 93 | ax.xaxis.set_label_position('bottom') 94 | ax.xaxis.tick_bottom() 95 | 96 | ax.set_ylabel('True Label', fontsize=7) 97 | ax.set_yticks(tick_marks) 98 | ax.set_yticklabels(classes, fontsize=4, va='center') 99 | ax.yaxis.set_label_position('left') 100 | ax.yaxis.tick_left() 101 | 102 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 103 | ax.text(j, i, format(cm[i, j], 'd') if cm[i, j] != 0 else '.', 104 | horizontalalignment="center", fontsize=6, 105 | verticalalignment='center', color="black") 106 | fig.set_tight_layout(True) 107 | 108 | fig.canvas.draw() 109 | w, h = fig.canvas.get_width_height() 110 | 111 | # get PNG data from the figure 112 | png_buffer = io.BytesIO() 113 | fig.canvas.print_png(png_buffer) 114 | png_encoded = png_buffer.getvalue() 115 | png_buffer.close() 116 | 117 | return png_encoded 118 | 119 | 120 | #### 121 | def update_log(output, epoch, prefix, color, tfwriter, log_file, logging): 122 | # print values and convert 123 | max_length = len(max(output.keys(), key=len)) 124 | for metric in output: 125 | key = colored(prefix + '-' + metric.ljust(max_length), color) 126 | print('------%s : ' % key, end='') 127 | if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']: 128 | print('%0.7f' % output[metric]) 129 | elif metric == 'conf_mat_c': 130 | conf_mat_c = output['conf_mat_c'] # use pivot to turn back 131 | conf_mat_c_df = pd.DataFrame(conf_mat_c) 132 | conf_mat_c_df.index.name = 'True' 133 | conf_mat_c_df.columns.name = 'Pred' 134 | output['conf_mat_c'] = conf_mat_c_df 135 | print('\n', conf_mat_c_df) 136 | elif metric == 'conf_mat_r': 137 | conf_mat_r = output['conf_mat_r'] # use pivot to turn back 138 | conf_mat_r_df = pd.DataFrame(conf_mat_r) 139 | conf_mat_r_df.index.name = 'True' 140 | conf_mat_r_df.columns.name = 'Pred' 141 | output['conf_mat_r'] = conf_mat_r_df 142 | print('\n', conf_mat_r_df) 143 | elif metric == 'box_plot_data': 144 | box_plot_data = output['box_plot_data'] # use pivot to turn back 145 | box_plot_data_df = pd.DataFrame(box_plot_data) 146 | box_plot_data_df.columns.name = 'Pred' 147 | output['box_plot_data'] = box_plot_data_df 148 | 149 | if not logging: 150 | return 151 | 152 | # create stat dicts 153 | stat_dict = {} 154 | for metric in output: 155 | if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']: 156 | metric_value = output[metric] 157 | elif metric == 'conf_mat_c': 158 
119 | 
120 | ####
121 | def update_log(output, epoch, prefix, color, tfwriter, log_file, logging):
122 |     # print values and convert
123 |     max_length = len(max(output.keys(), key=len))
124 |     for metric in output:
125 |         key = colored(prefix + '-' + metric.ljust(max_length), color)
126 |         print('------%s : ' % key, end='')
127 |         if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
128 |             print('%0.7f' % output[metric])
129 |         elif metric == 'conf_mat_c':
130 |             conf_mat_c = output['conf_mat_c']  # use pivot to turn back
131 |             conf_mat_c_df = pd.DataFrame(conf_mat_c)
132 |             conf_mat_c_df.index.name = 'True'
133 |             conf_mat_c_df.columns.name = 'Pred'
134 |             output['conf_mat_c'] = conf_mat_c_df
135 |             print('\n', conf_mat_c_df)
136 |         elif metric == 'conf_mat_r':
137 |             conf_mat_r = output['conf_mat_r']  # use pivot to turn back
138 |             conf_mat_r_df = pd.DataFrame(conf_mat_r)
139 |             conf_mat_r_df.index.name = 'True'
140 |             conf_mat_r_df.columns.name = 'Pred'
141 |             output['conf_mat_r'] = conf_mat_r_df
142 |             print('\n', conf_mat_r_df)
143 |         elif metric == 'box_plot_data':
144 |             box_plot_data = output['box_plot_data']  # use pivot to turn back
145 |             box_plot_data_df = pd.DataFrame(box_plot_data)
146 |             box_plot_data_df.columns.name = 'Pred'
147 |             output['box_plot_data'] = box_plot_data_df
148 | 
149 |     if not logging:
150 |         return
151 | 
152 |     # create stat dicts
153 |     stat_dict = {}
154 |     for metric in output:
155 |         if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
156 |             metric_value = output[metric]
157 |         elif metric == 'conf_mat_c':
158 |             # NOTE: conf_mat_c still holds the raw ndarray captured in the print loop above;
159 |             #       serialize it directly (the DataFrame form is only for pretty printing)
160 |             conf_mat_df = pd.Series({'conf_mat_c': conf_mat_c}).to_json(orient='records')
161 |             metric_value = conf_mat_df
162 |         elif metric == 'conf_mat_r':
163 |             # NOTE: as above, conf_mat_r is the raw ndarray captured in the print loop,
164 |             #       not the DataFrame stored back into `output`
165 |             conf_mat_regres_df = pd.Series({'conf_mat_r': conf_mat_r}).to_json(orient='records')
166 |             metric_value = conf_mat_regres_df
167 |         elif metric == 'box_plot_data':
168 |             box_plot_data_df = pd.Series({'box_plot_data': box_plot_data}).to_json(orient='records')
169 |             metric_value = box_plot_data_df
170 |         stat_dict['%s-%s' % (prefix, metric)] = metric_value
171 | 
172 |     # json stat log file, update and overwrite
173 |     with open(log_file) as json_file:
174 |         json_data = json.load(json_file)
175 | 
176 |     current_epoch = str(epoch)
177 |     if current_epoch in json_data:
178 |         old_stat_dict = json_data[current_epoch]
179 |         stat_dict.update(old_stat_dict)
180 |     current_epoch_dict = {current_epoch: stat_dict}
181 |     json_data.update(current_epoch_dict)
182 | 
183 |     with open(log_file, 'w') as json_file:
184 |         json.dump(json_data, json_file)
185 | 
186 |     # log values to tensorboard
187 |     for metric in output:
188 |         if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
189 |             tfwriter.add_scalar(prefix + '-' + metric, output[metric], epoch)  # global_step must be an int, not str
190 | 
191 | 
192 | ####
193 | def log_train_ema_results(engine, info):
194 |     """
195 |     Log the running (EMA) training measurements
196 |     """
197 |     training_ema_output = engine.state.metrics
198 |     training_ema_output['lr'] = float(info['optimizer'].param_groups[0]['lr'])
199 |     update_log(training_ema_output, engine.state.iteration, 'train-ema', 'green',
200 |                info['tfwriter'], info['json_file'], info['logging'])
201 | 
202 | 
203 | ####
204 | def process_accumulated_output_multi(output, batch_size, nr_classes):
205 |     #
206 |     def uneven_seq_to_np(seq):
207 |         item_count = batch_size * (len(seq) - 1) + len(seq[-1])
208 |         cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
209 |         # concatenate per-batch arrays; the last batch may be smaller than batch_size
210 |         for idx in range(0, len(seq) - 1):
211 |             cat_array[idx * batch_size: (idx + 1) * batch_size] = seq[idx]
212 |         idx = -1 if len(seq) == 1 else idx  # in case len(seq) == 1 the for loop above is skipped
213 |         cat_array[(idx + 1) * batch_size:] = seq[-1]
214 |         return cat_array
215 | 
216 |     proc_output = dict()
217 |     true = uneven_seq_to_np(output['true'])
218 |     # threshold then get accuracy
219 |     if 'logit_c' in output.keys():
220 |         logit_c = uneven_seq_to_np(output['logit_c'])
221 |         pred_c = np.argmax(logit_c, axis=-1)
222 |         acc_c = np.mean(pred_c == true)
223 |         # confusion matrix
224 |         conf_mat_c = confusion_matrix(true, pred_c, labels=np.arange(nr_classes))
225 |         proc_output.update(acc_c=acc_c, conf_mat_c=conf_mat_c, )
226 |     if 'logit_r' in output.keys():
227 |         logit_r = uneven_seq_to_np(output['logit_r'])
228 |         label = np.transpose(np.array([[0., 1., 2., 3.]]).repeat(len(true), axis=0), (1, 0))
229 |         pred_r = np.argmin(abs((logit_r - label)), axis=0)  # snap each regression output to the nearest grade
230 |         acc_r = np.mean(pred_r == true)
231 |         # confusion matrix
232 |         conf_mat_r = confusion_matrix(true, pred_r, labels=np.arange(nr_classes))
233 |         proc_output.update(acc_r=acc_r, conf_mat_r=conf_mat_r)
234 | 
235 |     # proc_output.update(box_plot_data=np.concatenate(
236 |     #     [true[np.newaxis, :], pred_c[np.newaxis, :], pred_r[np.newaxis, :], logit_r.transpose(1, 0)], 0))
237 |     return proc_output
238 | 
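# Example (a minimal sketch with toy accumulated batches; batch_size=2, so the
# second, partial batch exercises the uneven tail handling):
# >>> acc = {'true': [np.array([0, 1]), np.array([3])],
# ...        'logit_r': [np.array([0.1, 1.2]), np.array([2.8])]}
# >>> stats = process_accumulated_output_multi(acc, batch_size=2, nr_classes=4)
# >>> stats['acc_r']  # each regression output snaps to the nearest grade in {0, 1, 2, 3}
# 1.0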
239 | 
240 | ####
241 | def inference(engine, inferer, prefix, dataloader, info):
242 |     """
243 |     Run inference and log the accumulated measurements
244 |     """
245 |     inferer.accumulator = {metric: [] for metric in info['metric_names']}
246 |     inferer.run(dataloader)
247 |     output_stat = process_accumulated_output_multi(inferer.accumulator,
248 |                                                    info['infer_batch_size'], info['nr_classes'])
249 |     update_log(output_stat, engine.state.iteration, prefix, 'red',
250 |                info['tfwriter'], info['json_file'], info['logging'])
251 |     return
252 | 
253 | 
254 | ####
255 | def accumulate_outputs(engine):
256 |     batch_output = engine.state.output
257 |     for key, item in batch_output.items():
258 |         engine.accumulator[key].append(item)
259 |     return
260 | 
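# Example (a sketch of how these helpers are typically wired into pytorch-ignite
# engines; `infer_step`, `trainer` and the loaders are illustrative names only):
# >>> from ignite.engine import Engine, Events
# >>> inferer = Engine(infer_step)  # infer_step returns e.g. {'true': ..., 'logit_c': ...}
# >>> inferer.add_event_handler(Events.ITERATION_COMPLETED, accumulate_outputs)
# >>> trainer.add_event_handler(Events.EPOCH_COMPLETED, inference,
# ...                           inferer, 'valid', valid_loader, info)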
--------------------------------------------------------------------------------
/misc/train_ultils_validator.py:
--------------------------------------------------------------------------------
1 | import io
2 | import itertools
3 | import json
4 | import os
5 | 
6 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
7 | import random
8 | import re
9 | import shutil
10 | import textwrap
11 | 
12 | import cv2
13 | import matplotlib
14 | import matplotlib.pyplot as plt
15 | import numpy as np
16 | import pandas as pd
17 | import torch
18 | from sklearn.metrics import confusion_matrix
19 | from termcolor import colored
20 | 
21 | 
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 | import imgaug as ia
25 | from scipy.special import softmax
26 | from sklearn.metrics import classification_report
27 | 
28 | def check_manual_seed(seed):
29 |     """
30 |     Seed all random number generators (python, numpy, torch, imgaug) with the given value
31 |     """
32 |     # seed every source of randomness below with the same value
33 |     random.seed(seed)
34 |     np.random.seed(seed)
35 |     torch.manual_seed(seed)
36 |     torch.cuda.manual_seed(seed)
37 |     ia.seed(seed)
38 |     torch.cuda.manual_seed_all(seed)
39 |     torch.backends.cudnn.benchmark = False
40 |     torch.backends.cudnn.deterministic = True
41 | 
42 |     print('Using manual seed: {seed}'.format(seed=seed))
43 |     return
44 | 
45 | 
46 | def check_log_dir(log_dir):
47 |     # check if the log dir exists
48 |     if os.path.isdir(log_dir):
49 |         color_word = colored('WARNING', color='red', attrs=['bold', 'blink'])
50 |         print('%s: %s exists!' % (color_word, colored(log_dir, attrs=['underline'])))
51 |         while True:
52 |             print('Select Action: d (delete)/ q (quit)', end='')
53 |             key = input()
54 |             if key == 'd':
55 |                 shutil.rmtree(log_dir)
56 |                 break
57 |             elif key == 'q':
58 |                 exit()
59 |             else:
60 |                 color_word = colored('ERR', color='red')
61 |                 print('---[%s] Unrecognized character!' % color_word)
62 |     return
63 | 
64 | 
65 | def plot_confusion_matrix(conf_mat, label):
66 |     """
67 |     Parameters:
68 |         title='Confusion matrix' : Title for your matrix
69 |         tensor_name = 'MyFigure/image' : Name for the output summary tensor
70 |     Returns:
71 |         png_encoded: the plotted figure as PNG-encoded bytes
72 |     Other items to note:
73 |         - Depending on the number of categories and the data, you may have to modify the figsize, font sizes, etc.
74 |         - Currently, some of the ticks don't line up due to rotations.
75 |     """
76 | 
77 |     cm = conf_mat
78 | 
79 |     np.set_printoptions(precision=2)  # print numpy array with 2 decimal places
80 | 
81 |     fig = matplotlib.figure.Figure(figsize=(7, 7), dpi=320, facecolor='w', edgecolor='k')
82 |     ax = fig.add_subplot(1, 1, 1)
83 |     im = ax.imshow(cm, cmap='Oranges')
84 | 
85 |     classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in label]
86 |     classes = ['\n'.join(textwrap.wrap(l, 40)) for l in classes]
87 | 
88 |     tick_marks = np.arange(len(classes))
89 | 
90 |     ax.set_xlabel('Predicted', fontsize=7)
91 |     ax.set_xticks(tick_marks)
92 |     c = ax.set_xticklabels(classes, fontsize=4, rotation=-90, ha='center')
93 |     ax.xaxis.set_label_position('bottom')
94 |     ax.xaxis.tick_bottom()
95 | 
96 |     ax.set_ylabel('True Label', fontsize=7)
97 |     ax.set_yticks(tick_marks)
98 |     ax.set_yticklabels(classes, fontsize=4, va='center')
99 |     ax.yaxis.set_label_position('left')
100 |     ax.yaxis.tick_left()
101 | 
102 |     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
103 |         ax.text(j, i, format(cm[i, j], 'd') if cm[i, j] != 0 else '.',
104 |                 horizontalalignment="center", fontsize=6,
105 |                 verticalalignment='center', color="black")
106 |     fig.set_tight_layout(True)
107 | 
108 |     fig.canvas.draw()
109 |     w, h = fig.canvas.get_width_height()
110 | 
111 |     # get PNG data from the figure
112 |     png_buffer = io.BytesIO()
113 |     fig.canvas.print_png(png_buffer)
114 |     png_encoded = png_buffer.getvalue()
115 |     png_buffer.close()
116 | 
117 |     return png_encoded
118 | 
119 | 
120 | ####
121 | def update_log(output, epoch, net_name, prefix, color, tfwriter, log_file, logging):
122 |     # print values and convert
123 |     max_length = len(max(output.keys(), key=len))
124 |     for metric in output:
125 |         key = colored(prefix + '-' + metric.ljust(max_length), color)
126 |         print('------%s : ' % key, end='')
127 |         if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
128 |             print('%0.7f' % output[metric])
129 |         elif metric == 'conf_mat_c':
130 |             conf_mat_c = output['conf_mat_c']  # use pivot to turn back
131 |             conf_mat_c_df = pd.DataFrame(conf_mat_c)
132 |             conf_mat_c_df.index.name = 'True'
133 |             conf_mat_c_df.columns.name = 'Pred'
134 |             output['conf_mat_c'] = conf_mat_c_df
135 |             print('\n', conf_mat_c_df)
136 |         elif metric == 'conf_mat_r':
137 |             conf_mat_r = output['conf_mat_r']  # use pivot to turn back
138 |             conf_mat_r_df = pd.DataFrame(conf_mat_r)
139 |             conf_mat_r_df.index.name = 'True'
140 |             conf_mat_r_df.columns.name = 'Pred'
141 |             output['conf_mat_r'] = conf_mat_r_df
142 |             print('\n', conf_mat_r_df)
143 |         elif metric == 'box_plot_data':
144 |             box_plot_data = output['box_plot_data']  # use pivot to turn back
145 |             box_plot_data_df = pd.DataFrame(box_plot_data)
146 |             box_plot_data_df.columns.name = 'Pred'
147 |             output['box_plot_data'] = box_plot_data_df
148 | 
149 |     if not logging:
150 |         return
151 | 
152 |     # create stat dicts
153 |     stat_dict = {}
154 |     for metric in output:
155 |         if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
156 |             metric_value = output[metric]
157 |         elif metric == 'conf_mat_c':
158 |             # NOTE: conf_mat_c still holds the raw ndarray captured in the print loop above;
159 |             #       serialize it directly (the DataFrame form is only for pretty printing)
160 |             conf_mat_df = pd.Series({'conf_mat_c': conf_mat_c}).to_json(orient='records')
161 |             metric_value = conf_mat_df
162 |         elif metric == 'conf_mat_r':
163 |             # NOTE: as above, conf_mat_r is the raw ndarray captured in the print loop,
164 |             #       not the DataFrame stored back into `output`
165 |             conf_mat_regres_df = pd.Series({'conf_mat_r': conf_mat_r}).to_json(orient='records')
166 |             metric_value = conf_mat_regres_df
167 |         elif metric == 'box_plot_data':
168 |             box_plot_data_df = pd.Series({'box_plot_data': box_plot_data}).to_json(orient='records')
169 |             metric_value = box_plot_data_df
170 |         stat_dict['%s-%s' % (prefix, metric)] = metric_value
171 | 
172 |     # json stat log file, update and overwrite
173 |     with open(log_file) as json_file:
174 |         json_data = json.load(json_file)
175 | 
176 |     current_epoch = str(epoch)  # unused for keying; validator stats are keyed by model name below
177 |     current_model = str(net_name)
178 |     if current_model in json_data:
179 |         old_stat_dict = json_data[current_model]
180 |         stat_dict.update(old_stat_dict)
181 |     current_epoch_dict = {current_model: stat_dict}
182 |     json_data.update(current_epoch_dict)
183 | 
184 |     with open(log_file, 'w') as json_file:
185 |         json.dump(json_data, json_file)
186 | 
187 |     # log values to tensorboard
188 |     for metric in output:
189 |         if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
190 |             tfwriter.add_scalar(prefix + '-' + metric, output[metric], epoch)  # global_step must be an int, not str
191 | 
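# The JSON stats file written by update_log() maps one key per model (net_name) to a
# flat dict of '<prefix>-<metric>' entries, roughly (illustrative values only):
# {
#     "efficientnet-b0": {"valid-acc_c": 0.85, "valid-conf_mat_c": "[{...}]"}
# }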
192 | 
193 | ####
194 | def log_train_ema_results(engine, info):
195 |     """
196 |     Log the running (EMA) training measurements
197 |     """
198 |     training_ema_output = engine.state.metrics
199 |     training_ema_output['lr'] = float(info['optimizer'].param_groups[0]['lr'])
200 |     update_log(training_ema_output, engine.state.epoch, info.get('net_name', 'train'),  # update_log() here also expects a net_name; assume it comes via info
201 |                'train-ema', 'green', info['tfwriter'], info['json_file'], info['logging'])
202 | 
203 | 
204 | ####
205 | def process_accumulated_output_multi(output, batch_size, nr_classes):
206 |     #
207 |     def uneven_seq_to_np(seq):
208 |         item_count = batch_size * (len(seq) - 1) + len(seq[-1])
209 |         cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
210 |         # concatenate per-batch arrays; a single accumulated batch is returned as-is
211 |         if len(seq) < 2:
212 |             return seq[0]
213 |         for idx in range(0, len(seq) - 1):
214 |             cat_array[idx * batch_size:
215 |                       (idx + 1) * batch_size] = seq[idx]
216 |         cat_array[(idx + 1) * batch_size:] = seq[-1]
217 |         return cat_array
218 | 
219 |     proc_output = dict()
220 |     true = uneven_seq_to_np(output['true'])
221 |     # threshold then get accuracy
222 |     if 'logit_c' in output.keys():
223 |         logit_c = uneven_seq_to_np(output['logit_c'])
224 |         pred_c = np.argmax(logit_c, axis=-1)
225 |         # pred_c = [covert_dict[pred_c[idx]] for idx in range(len(pred_c))]
226 |         acc_c = np.mean(pred_c == true)
227 |         print(acc_c)
228 |         # confusion matrix
229 |         conf_mat_c = confusion_matrix(true, pred_c, labels=np.arange(nr_classes))
230 |         proc_output.update(acc_c=acc_c, conf_mat_c=conf_mat_c,)
231 |     if 'logit_r' in output.keys():
232 |         logit_r = uneven_seq_to_np(output['logit_r'])
233 |         label = np.transpose(np.array([[0., 1., 2., 3.]]).repeat(len(true), axis=0), (1, 0))
234 |         pred_r = np.argmin(abs((logit_r - label)), axis=0)  # snap each regression output to the nearest grade
235 |         # pred_r = [covert_dict[pred_r[idx]] for idx in range(len(pred_r))]
236 |         acc_r = np.mean(pred_r == true)
237 |         # print(acc_r)
238 |         # confusion matrix
239 |         conf_mat_r = confusion_matrix(true, pred_r, labels=np.arange(nr_classes))
240 |         proc_output.update(acc_r=acc_r, conf_mat_r=conf_mat_r)
241 | 
242 |     # proc_output.update(box_plot_data=np.concatenate(
243 |     #     [true[np.newaxis, :], pred_c[np.newaxis, :], pred_r[np.newaxis, :], logit_r.transpose(1, 0)], 0))
244 |     return proc_output
245 | 
246 | def process_accumulated_output_multi_mix(output, batch_size, nr_classes):
247 |     #
248 |     def uneven_seq_to_np(seq):
249 |         item_count = batch_size * (len(seq) - 1) + len(seq[-1])
250 |         cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
251 |         # concatenate per-batch arrays; a single accumulated batch is returned as-is
252 |         if len(seq) < 2:
253 |             return seq[0]
254 |         for idx in range(0, len(seq) - 1):
255 |             cat_array[idx * batch_size:
256 |                       (idx + 1) * batch_size] = seq[idx]
257 |         cat_array[(idx + 1) * batch_size:] = seq[-1]
258 |         return cat_array
259 | 
260 |     proc_output = dict()
261 |     true = uneven_seq_to_np(output['true'])
262 |     # threshold then get accuracy
263 |     if 'logit_c' in output.keys():
264 |         logit_c = uneven_seq_to_np(output['logit_c'])
265 | 
266 |         pred_c = np.argmax(logit_c, axis=-1)
267 |         # pred_c = [covert_dict[pred_c[idx]] for idx in range(len(pred_c))]
268 |         acc_c = np.mean(pred_c == true)
269 |         print('acc_c', acc_c)
270 |         # print(classification_report(true, pred_c, labels=[0, 1, 2, 3]))
271 |         # confusion matrix
272 |         conf_mat_c = confusion_matrix(true, pred_c, labels=np.arange(nr_classes))
273 |         proc_output.update(acc_c=acc_c, conf_mat_c=conf_mat_c,)
274 |     if 'logit_r' in output.keys():
275 |         logit_r = uneven_seq_to_np(output['logit_r'])
276 |         label = np.transpose(np.array([[0., 1., 2., 3.]]).repeat(len(true), axis=0), (1, 0))
277 |         pred_r = np.argmin(abs((logit_r - label)), axis=0)
278 |         # pred_r = [covert_dict[pred_r[idx]] for idx in range(len(pred_r))]
279 |         acc_r = np.mean(pred_r == true)
280 |         print('acc_r', acc_r)
281 |         # print(classification_report(true, pred_r, labels=[0, 1, 2, 3]))
282 |         # confusion matrix
283 |         conf_mat_r = confusion_matrix(true, pred_r, labels=np.arange(nr_classes))
284 |         proc_output.update(acc_r=acc_r, conf_mat_r=conf_mat_r)
285 | 
286 |     # if ('logit_r' in output.keys()) and ('logit_c' in output.keys()):
287 |     #     a = abs((logit_r - label)).transpose(1, 0)
288 |     #     prob_r = softmax(-a, 1)
289 |     #     logit_c += prob_r
290 |     #     pred_c = np.argmax(logit_c, axis=-1)
291 |     #     acc_c = np.mean(pred_c == true)
292 |     #     print('acc_mix', acc_c)
293 | 
294 |     # proc_output.update(box_plot_data=np.concatenate(
295 |     #     [true[np.newaxis, :], pred_c[np.newaxis, :], pred_r[np.newaxis, :], logit_r.transpose(1, 0)], 0))
296 |     return proc_output
297 | 
298 | def process_accumulated_output_multi_testAUG(output, batch_size, nr_classes):
299 |     #
300 |     def uneven_seq_to_np(seq):
301 |         item_count = batch_size * (len(seq) - 1) + len(seq[-1])
302 |         cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
303 |         if len(seq) < 2: return seq[0]  # guard: without it the loop below leaves idx undefined for a single batch
304 |         for idx in range(0, len(seq) - 1):
305 |             cat_array[idx * batch_size:
306 |                       (idx + 1) * batch_size] = seq[idx]
307 |         cat_array[(idx + 1) * batch_size:] = seq[-1]
308 |         return cat_array
309 | 
310 |     proc_output = dict()
311 |     true = uneven_seq_to_np(output['true'])
312 |     # threshold then get accuracy
313 |     if 'pred_c' in output.keys():
314 |         pred_c = uneven_seq_to_np(output['pred_c'])
315 |         acc_c = np.mean(pred_c == true)
316 |         # confusion matrix
317 |         conf_mat_c = confusion_matrix(true, pred_c, labels=np.arange(nr_classes))
318 |         proc_output.update(acc_c=acc_c, conf_mat_c=conf_mat_c,)
319 |     if 'pred_r' in output.keys():
320 |         pred_r = uneven_seq_to_np(output['pred_r'])
321 |         acc_r = np.mean(pred_r == true)
322 |         # confusion matrix
323 |         conf_mat_r = confusion_matrix(true, pred_r, labels=np.arange(nr_classes))
324 |         proc_output.update(acc_r=acc_r, conf_mat_r=conf_mat_r)
325 |     return proc_output
326 | 
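# Example (a sketch of test-time augmentation with majority voting; `model` and
# `aug_views` are illustrative names, accumulate_predict is defined further below):
# >>> logits = model(aug_views)                            # one prediction per augmented view
# >>> patch_label = accumulate_predict(logits.argmax(-1))  # majority vote over the views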
327 | 
328 | ####
329 | def inference(engine, inferer, prefix, dataloader, info):
330 |     """
331 |     Run inference and log the accumulated measurements
332 |     """
333 |     inferer.accumulator = {metric: [] for metric in info['metric_names']}
334 |     inferer.run(dataloader)
335 |     output_stat = process_accumulated_output_multi(inferer.accumulator,
336 |                                                    info['infer_batch_size'], info['nr_classes'])
337 |     update_log(output_stat, engine.state.epoch, info.get('net_name', prefix),  # update_log() here also expects a net_name; assume it comes via info
338 |                prefix, 'red', info['tfwriter'], info['json_file'], info['logging'])
339 |     return
340 | 
341 | 
342 | ####
343 | def accumulate_outputs(engine):
344 |     batch_output = engine.state.output
345 |     for key, item in batch_output.items():
346 |         engine.accumulator[key].append(item)
347 |     return
348 | 
349 | 
350 | def accumulate_predict(pred_patch):
351 |     unique, counts = np.unique(pred_patch.cpu(), return_counts=True)
352 |     pred_count = dict(zip(unique, counts))
353 |     patch_label = max(pred_count, key=pred_count.get)
354 |     return patch_label
355 | 
--------------------------------------------------------------------------------
/model_lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/trinhvg/JCO_Learning-pytorch/6a6044801ac0a2f6ea2c479d4417e1128e5a0b72/model_lib/__init__.py
--------------------------------------------------------------------------------
/model_lib/efficientnet_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.7.0"
2 | from .model import EfficientNet
3 | from .utils import (
4 |     GlobalParams,
5 |     BlockArgs,
6 |     BlockDecoder,
7 |     efficientnet,
8 |     get_model_params,
9 | )
10 | 
--------------------------------------------------------------------------------
/model_lib/efficientnet_pytorch/model.py:
--------------------------------------------------------------------------------
1 | """model.py - Model and module class for EfficientNet.
2 | They are built to mirror those in the official TensorFlow implementation.
3 | """
4 | 
5 | # Author: lukemelas (github username)
6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7 | # With adjustments and added comments by workingcoder (github username).
8 | 
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from torchsummary import summary
13 | from model_lib.efficientnet_pytorch.utils import (
14 |     round_filters,
15 |     round_repeats,
16 |     drop_connect,
17 |     get_same_padding_conv2d,
18 |     get_model_params,
19 |     efficientnet_params,
20 |     load_pretrained_weights,
21 |     Swish,
22 |     MemoryEfficientSwish,
23 |     calculate_output_image_size
24 | )
25 | 
26 | class MBConvBlock(nn.Module):
27 |     """Mobile Inverted Residual Bottleneck Block.
28 | 
29 |     Args:
30 |         block_args (namedtuple): BlockArgs, defined in utils.py.
31 |         global_params (namedtuple): GlobalParam, defined in utils.py.
32 |         image_size (tuple or list): [image_height, image_width].
33 | 
34 |     References:
35 |         [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
36 |         [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
37 |         [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
38 |     """
39 | 
40 |     def __init__(self, block_args, global_params, image_size=None):
41 |         super().__init__()
42 |         self._block_args = block_args
43 |         self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
44 |         self._bn_eps = global_params.batch_norm_epsilon
45 |         self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
46 |         self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect
47 | 
48 |         # Expansion phase (Inverted Bottleneck)
49 |         inp = self._block_args.input_filters  # number of input channels
50 |         oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
51 |         if self._block_args.expand_ratio != 1:
52 |             Conv2d = get_same_padding_conv2d(image_size=image_size)
53 |             self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
54 |             self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
55 |             # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
56 | 
57 |         # Depthwise convolution phase
58 |         k = self._block_args.kernel_size
59 |         s = self._block_args.stride
60 |         Conv2d = get_same_padding_conv2d(image_size=image_size)
61 |         self._depthwise_conv = Conv2d(
62 |             in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
63 |             kernel_size=k, stride=s, bias=False)
64 |         self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
65 |         image_size = calculate_output_image_size(image_size, s)
66 | 
67 |         # Squeeze and Excitation layer, if desired
68 |         if self.has_se:
69 |             Conv2d = get_same_padding_conv2d(image_size=(1,1))
70 |             num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
71 |             self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
72 |             self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
73 | 
74 |         # Pointwise convolution phase
75 |         final_oup = self._block_args.output_filters
76 |         Conv2d = get_same_padding_conv2d(image_size=image_size)
77 |         self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
78 |         self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
79 |         self._swish = MemoryEfficientSwish()
80 | 
81 |     def forward(self, inputs, drop_connect_rate=None):
82 |         """MBConvBlock's forward function.
83 | 
84 |         Args:
85 |             inputs (tensor): Input tensor.
86 |             drop_connect_rate (float): Drop connect rate (between 0 and 1).
87 | 
88 |         Returns:
89 |             Output of this block after processing.
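        Example (a shape sketch, assuming input (N, C_in, H, W), expand_ratio e > 1 and
        stride s; 'same' padding rounds the spatial sizes up):
            expansion : (N, C_in, H, W)              -> (N, C_in * e, H, W)
            depthwise : (N, C_in * e, H, W)          -> (N, C_in * e, H / s, W / s)
            project   : (N, C_in * e, H / s, W / s)  -> (N, C_out, H / s, W / s)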
90 | """ 91 | 92 | # Expansion and Depthwise Convolution 93 | x = inputs 94 | if self._block_args.expand_ratio != 1: 95 | x = self._expand_conv(inputs) 96 | x = self._bn0(x) 97 | x = self._swish(x) 98 | 99 | x = self._depthwise_conv(x) 100 | x = self._bn1(x) 101 | x = self._swish(x) 102 | 103 | # Squeeze and Excitation 104 | if self.has_se: 105 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 106 | x_squeezed = self._se_reduce(x_squeezed) 107 | x_squeezed = self._swish(x_squeezed) 108 | x_squeezed = self._se_expand(x_squeezed) 109 | x = torch.sigmoid(x_squeezed) * x 110 | 111 | # Pointwise Convolution 112 | x = self._project_conv(x) 113 | x = self._bn2(x) 114 | 115 | # Skip connection and drop connect 116 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 117 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 118 | # The combination of skip connection and drop connect brings about stochastic depth. 119 | if drop_connect_rate: 120 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 121 | x = x + inputs # skip connection 122 | return x 123 | 124 | def set_swish(self, memory_efficient=True): 125 | """Sets swish function as memory efficient (for training) or standard (for export). 126 | 127 | Args: 128 | memory_efficient (bool): Whether to use memory-efficient version of swish. 129 | """ 130 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 131 | 132 | 133 | class EfficientNet(nn.Module): 134 | """EfficientNet model. 135 | Most easily loaded with the .from_name or .from_pretrained methods. 136 | 137 | Args: 138 | blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. 139 | global_params (namedtuple): A set of GlobalParams shared between blocks. 140 | 141 | References: 142 | [1] https://arxiv.org/abs/1905.11946 (EfficientNet) 143 | 144 | # Example: 145 | # >>> import torch 146 | # >>> from efficientnet.model import EfficientNet 147 | # >>> inputs = torch.rand(1, 3, 224, 224) 148 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0') 149 | # >>> model.eval() 150 | # >>> outputs = model(inputs) 151 | """ 152 | 153 | def __init__(self, task_mode='class', blocks_args=None, global_params=None): 154 | super().__init__() 155 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 156 | assert len(blocks_args) > 0, 'block args must be greater than 0' 157 | self._global_params = global_params 158 | self._blocks_args = blocks_args 159 | self.task_mode = task_mode 160 | 161 | # Batch norm parameters 162 | bn_mom = 1 - self._global_params.batch_norm_momentum 163 | bn_eps = self._global_params.batch_norm_epsilon 164 | 165 | # Get stem static or dynamic convolution depending on image size 166 | image_size = global_params.image_size 167 | Conv2d = get_same_padding_conv2d(image_size=image_size) 168 | 169 | # Stem 170 | in_channels = 3 # rgb 171 | out_channels = round_filters(32, self._global_params) # number of output channels 172 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 173 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 174 | image_size = calculate_output_image_size(image_size, 2) 175 | 176 | # Build blocks 177 | self._blocks = nn.ModuleList([]) 178 | for block_args in self._blocks_args: 179 | 180 | # Update block input and output filters based on depth multiplier. 
181 | block_args = block_args._replace( 182 | input_filters=round_filters(block_args.input_filters, self._global_params), 183 | output_filters=round_filters(block_args.output_filters, self._global_params), 184 | num_repeat=round_repeats(block_args.num_repeat, self._global_params) 185 | ) 186 | 187 | # The first block needs to take care of stride and filter size increase. 188 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 189 | image_size = calculate_output_image_size(image_size, block_args.stride) 190 | if block_args.num_repeat > 1: # modify block_args to keep same output size 191 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) 192 | for _ in range(block_args.num_repeat - 1): 193 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 194 | # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 195 | 196 | # Head 197 | in_channels = block_args.output_filters # output of final block 198 | out_channels = round_filters(1280, self._global_params) 199 | Conv2d = get_same_padding_conv2d(image_size=image_size) 200 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 201 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 202 | 203 | # Final linear layer 204 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 205 | # self._dropout = nn.Dropout(self._global_params.dropout_rate) 206 | # self._fc = nn.Linear(out_channels, self._global_params.num_classes) 207 | # building classifier 208 | if self.task_mode in ['class', 'multi']: 209 | self.classifier_ = nn.Sequential( 210 | nn.Dropout(self._global_params.dropout_rate), 211 | nn.Linear(out_channels, self._global_params.num_classes), 212 | ) 213 | if self.task_mode in ['regress', 'multi']: 214 | self.regressioner_ = nn.Sequential( 215 | nn.Dropout(self._global_params.dropout_rate), 216 | nn.Linear(out_channels, 1), 217 | ) 218 | self._swish = MemoryEfficientSwish() 219 | 220 | # weight initialization 221 | for m in self.modules(): 222 | if isinstance(m, nn.Conv2d): 223 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 224 | if m.bias is not None: 225 | nn.init.zeros_(m.bias) 226 | elif isinstance(m, nn.BatchNorm2d): 227 | nn.init.ones_(m.weight) 228 | nn.init.zeros_(m.bias) 229 | elif isinstance(m, nn.Linear): 230 | nn.init.normal_(m.weight, 0, 0.01) 231 | nn.init.zeros_(m.bias) 232 | 233 | def set_swish(self, memory_efficient=True): 234 | """Sets swish function as memory efficient (for training) or standard (for export). 235 | 236 | Args: 237 | memory_efficient (bool): Whether to use memory-efficient version of swish. 238 | 239 | """ 240 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 241 | for block in self._blocks: 242 | block.set_swish(memory_efficient) 243 | 244 | def extract_endpoints(self, inputs): 245 | """Use convolution layer to extract features 246 | from reduction levels i in [1, 2, 3, 4, 5]. 247 | 248 | Args: 249 | inputs (tensor): Input tensor. 250 | 251 | Returns: 252 | Dictionary of last intermediate features 253 | with reduction levels i in [1, 2, 3, 4, 5]. 
254 |         Example:
255 |             # >>> import torch
256 |             # >>> from efficientnet.model import EfficientNet
257 |             # >>> inputs = torch.rand(1, 3, 224, 224)
258 |             # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
259 |             # >>> endpoints = model.extract_endpoints(inputs)
260 |             # >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
261 |             # >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
262 |             # >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
263 |             # >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
264 |             # >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 1280, 7, 7])
265 |         """
266 |         endpoints = dict()
267 | 
268 |         # Stem
269 |         x = self._swish(self._bn0(self._conv_stem(inputs)))
270 |         prev_x = x
271 | 
272 |         # Blocks
273 |         for idx, block in enumerate(self._blocks):
274 |             drop_connect_rate = self._global_params.drop_connect_rate
275 |             if drop_connect_rate:
276 |                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
277 |             x = block(x, drop_connect_rate=drop_connect_rate)
278 |             if prev_x.size(2) > x.size(2):
279 |                 endpoints[f'reduction_{len(endpoints)+1}'] = prev_x
280 |             prev_x = x
281 | 
282 |         # Head
283 |         x = self._swish(self._bn1(self._conv_head(x)))
284 |         endpoints[f'reduction_{len(endpoints)+1}'] = x
285 | 
286 |         return endpoints
287 | 
288 |     def extract_features(self, inputs):
289 |         """Use convolution layers to extract features.
290 | 
291 |         Args:
292 |             inputs (tensor): Input tensor.
293 | 
294 |         Returns:
295 |             Output of the final convolution
296 |             layer in the efficientnet model.
297 |         """
298 |         # Stem
299 |         x = self._swish(self._bn0(self._conv_stem(inputs)))
300 | 
301 |         # Blocks
302 |         for idx, block in enumerate(self._blocks):
303 |             drop_connect_rate = self._global_params.drop_connect_rate
304 |             if drop_connect_rate:
305 |                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
306 |             x = block(x, drop_connect_rate=drop_connect_rate)
307 | 
308 |         # Head
309 |         x = self._swish(self._bn1(self._conv_head(x)))
310 | 
311 |         return x
312 | 
313 |     def forward(self, inputs):
314 |         """EfficientNet's forward function.
315 |         Calls extract_features to extract features, applies final linear layer, and returns logits.
316 | 
317 |         Args:
318 |             inputs (tensor): Input tensor.
319 | 
320 |         Returns:
321 |             Output of this model after processing.
322 |         """
323 |         # Convolution layers
324 |         x = self.extract_features(inputs)
325 | 
326 |         # Pooling and final linear layer
327 |         x = self._avg_pooling(x)
328 |         x = x.flatten(start_dim=1)
329 |         # x = self._dropout(x)
330 |         # x = self._fc(x)
331 |         # return x
332 | 
333 |         if self.task_mode == 'class':
334 |             c_out = self.classifier_(x)
335 |             return c_out
336 |         elif self.task_mode == 'regress':
337 |             r_out = self.regressioner_(x)
338 |             return r_out[:, 0]
339 |         elif self.task_mode == 'multi':
340 |             c_out = self.classifier_(x)
341 |             r_out = self.regressioner_(x)
342 |             return c_out, r_out[:, 0]
343 |         else:
344 |             raise ValueError(f'Unsupported task_mode: {self.task_mode}. '
345 |                              f'Only one of [multi, class, regress] is supported.')
346 | 
347 | 
348 |     @classmethod
349 |     def from_name(cls, task_mode, model_name, in_channels=3, **override_params):
350 |         """Create an EfficientNet model according to name.
351 | 
352 |         Args:
353 |             task_mode (str): class, multi, regress
354 |             model_name (str): Name for efficientnet.
355 |             in_channels (int): Input data's channel number.
356 |             override_params (other key word params):
357 |                 Params to override model's global_params.
358 |                 Optional key:
359 |                     'width_coefficient', 'depth_coefficient',
360 |                     'image_size', 'dropout_rate',
361 |                     'num_classes', 'batch_norm_momentum',
362 |                     'batch_norm_epsilon', 'drop_connect_rate',
363 |                     'depth_divisor', 'min_depth'
364 | 
365 |         Returns:
366 |             An efficientnet model.
367 |         """
368 |         cls._check_model_name_is_valid(model_name)
369 |         blocks_args, global_params = get_model_params(model_name, override_params)
370 |         model = cls(task_mode, blocks_args, global_params)
371 |         model._change_in_channels(in_channels)
372 |         return model
373 | 
374 |     @classmethod
375 |     def from_pretrained(cls, task_mode, model_name, weights_path=None, advprop=False,
376 |                         in_channels=3, num_classes=1000, **override_params):
377 |         """Create an EfficientNet model according to name.
378 | 
379 |         Args:
380 |             task_mode (str): class, multi, regress
381 |             model_name (str): Name for efficientnet.
382 |             weights_path (None or str):
383 |                 str: path to pretrained weights file on the local disk.
384 |                 None: use pretrained weights downloaded from the Internet.
385 |             advprop (bool):
386 |                 Whether to load pretrained weights
387 |                 trained with advprop (valid when weights_path is None).
388 |             in_channels (int): Input data's channel number.
389 |             num_classes (int):
390 |                 Number of categories for classification.
391 |                 It controls the output size for final linear layer.
392 |             override_params (other key word params):
393 |                 Params to override model's global_params.
394 |                 Optional key:
395 |                     'width_coefficient', 'depth_coefficient',
396 |                     'image_size', 'dropout_rate',
397 |                     'num_classes', 'batch_norm_momentum',
398 |                     'batch_norm_epsilon', 'drop_connect_rate',
399 |                     'depth_divisor', 'min_depth'
400 | 
401 |         Returns:
402 |             A pretrained efficientnet model.
403 |         """
404 |         model = cls.from_name(task_mode, model_name, num_classes=num_classes, **override_params)
405 |         load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop)
406 |         model._change_in_channels(in_channels)
407 |         return model
408 | 
409 |     @classmethod
410 |     def get_image_size(cls, model_name):
411 |         """Get the input image size for a given efficientnet model.
412 | 
413 |         Args:
414 |             model_name (str): Name for efficientnet.
415 | 
416 |         Returns:
417 |             Input image size (resolution).
418 |         """
419 |         cls._check_model_name_is_valid(model_name)
420 |         _, _, res, _ = efficientnet_params(model_name)
421 |         return res
422 | 
423 |     @classmethod
424 |     def _check_model_name_is_valid(cls, model_name):
425 |         """Validates model name.
426 | 
427 |         Args:
428 |             model_name (str): Name for efficientnet.
429 | 
430 |         Returns:
431 |             bool: Is a valid name or not.
432 |         """
433 |         valid_models = ['efficientnet-b'+str(i) for i in range(9)]
434 | 
435 |         # Support the construction of 'efficientnet-l2' without pretrained weights
436 |         valid_models += ['efficientnet-l2']
437 | 
438 |         if model_name not in valid_models:
439 |             raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
440 | 
441 |     def _change_in_channels(self, in_channels):
442 |         """Adjust model's first convolution layer to in_channels, if in_channels does not equal 3.
443 | 
444 |         Args:
445 |             in_channels (int): Input data's channel number.
446 |         """
447 |         if in_channels != 3:
448 |             Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
449 |             out_channels = round_filters(32, self._global_params)
450 |             self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
451 | 
452 | 
453 | def jl_efficientnet(task_mode='class', pretrained=True, num_classes=4, **kwargs):
454 |     """
455 |     Joint-learning EfficientNet
456 | 
457 |     Args:
458 |         task_mode (string): multi, class, regress
459 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
460 |         num_classes (int): number of classes or number of output nodes
461 |     """
462 |     func = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
463 |     model = func(task_mode=task_mode, model_name='efficientnet-b0', num_classes=num_classes)
464 |     return model
465 | 
466 | 
467 | # def _test():
468 | #     net = jl_efficientnet(task_mode='regress', pretrained=True, num_classes=3).cuda()
469 | #     y = net(torch.randn(48, 3, 224, 224).cuda())
470 | #     # print(y_class.size(), y_regres.size())
471 | #     # y_class = net(torch.randn(48, 3, 224, 224).cuda())
472 | #     # print(y_class.size())
473 | #
474 | #     model = net.cuda()
475 | #     summary(model, (3, 224, 224))
476 | # _test()
477 | 
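# Example (a minimal multi-task usage sketch; shapes assume num_classes=4):
# >>> net = jl_efficientnet(task_mode='multi', pretrained=False, num_classes=4)
# >>> c_out, r_out = net(torch.randn(2, 3, 224, 224))
# >>> c_out.shape, r_out.shape
# (torch.Size([2, 4]), torch.Size([2]))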
--------------------------------------------------------------------------------
/model_lib/efficientnet_pytorch/model_dorn.py:
--------------------------------------------------------------------------------
1 | """model.py - Model and module class for EfficientNet.
2 | They are built to mirror those in the official TensorFlow implementation.
3 | """
4 | 
5 | # Author: lukemelas (github username)
6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7 | # With adjustments and added comments by workingcoder (github username).
8 | 
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from model_lib.efficientnet_pytorch.utils import (
13 |     round_filters,
14 |     round_repeats,
15 |     drop_connect,
16 |     get_same_padding_conv2d,
17 |     get_model_params,
18 |     efficientnet_params,
19 |     load_pretrained_weights,
20 |     Swish,
21 |     MemoryEfficientSwish,
22 |     calculate_output_image_size
23 | )
24 | 
25 | 
26 | class OrdinalRegressionLayer(nn.Module):
27 |     def __init__(self):
28 |         super(OrdinalRegressionLayer, self).__init__()
29 | 
30 |     def forward(self, x):
31 |         """
32 |         :param x: N x 2K; N - batch size, 2K - channels, K - number of discrete sub-intervals
33 |         :return: labels - ordinal labels (corresponding to discrete depth values) of size N
34 |                  softmax - predicted softmax probabilities P (as in the paper) of size N x K
35 |         """
36 |         N, K = x.size()
37 |         K = K // 2  # number of discrete sub-intervals
38 | 
39 |         odd = x[:, ::2].clone()
40 |         even = x[:, 1::2].clone()
41 | 
42 |         odd = odd.view(N, 1, K)
43 |         even = even.view(N, 1, K)
44 | 
45 |         paired_channels = torch.cat((odd, even), dim=1)
46 |         paired_channels = paired_channels.clamp(min=1e-8, max=1e8)  # prevent nans
47 | 
48 |         softmax = nn.functional.softmax(paired_channels, dim=1)
49 | 
50 |         softmax = softmax[:, 1, :]
51 |         softmax = softmax.view(-1, K)
52 |         labels = torch.sum((softmax > 0.5), dim=1).view(-1, 1) - 1
53 |         return labels[:, 0], softmax
54 | 
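# Decoding example (a sketch: the 2K logits are paired into K binary softmaxes and
# the ordinal label is the count of probabilities above 0.5, minus one):
# >>> layer = OrdinalRegressionLayer()
# >>> x = torch.tensor([[-2., 2., -2., 2., 2., -2.]])  # K = 3 sub-intervals
# >>> labels, softmax = layer(x)
# >>> labels  # the first two P > 0.5, the third is not -> 2 - 1 = 1
# tensor([1])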
55 | 
56 | class MBConvBlock(nn.Module):
57 |     """Mobile Inverted Residual Bottleneck Block.
58 | 
59 |     Args:
60 |         block_args (namedtuple): BlockArgs, defined in utils.py.
61 |         global_params (namedtuple): GlobalParam, defined in utils.py.
62 |         image_size (tuple or list): [image_height, image_width].
63 | 
64 |     References:
65 |         [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
66 |         [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
67 |         [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
68 |     """
69 | 
70 |     def __init__(self, block_args, global_params, image_size=None):
71 |         super().__init__()
72 |         self._block_args = block_args
73 |         self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
74 |         self._bn_eps = global_params.batch_norm_epsilon
75 |         self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
76 |         self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect
77 | 
78 |         # Expansion phase (Inverted Bottleneck)
79 |         inp = self._block_args.input_filters  # number of input channels
80 |         oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
81 |         if self._block_args.expand_ratio != 1:
82 |             Conv2d = get_same_padding_conv2d(image_size=image_size)
83 |             self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
84 |             self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
85 |             # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
86 | 
87 |         # Depthwise convolution phase
88 |         k = self._block_args.kernel_size
89 |         s = self._block_args.stride
90 |         Conv2d = get_same_padding_conv2d(image_size=image_size)
91 |         self._depthwise_conv = Conv2d(
92 |             in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
93 |             kernel_size=k, stride=s, bias=False)
94 |         self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
95 |         image_size = calculate_output_image_size(image_size, s)
96 | 
97 |         # Squeeze and Excitation layer, if desired
98 |         if self.has_se:
99 |             Conv2d = get_same_padding_conv2d(image_size=(1,1))
100 |             num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
101 |             self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
102 |             self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
103 | 
104 |         # Pointwise convolution phase
105 |         final_oup = self._block_args.output_filters
106 |         Conv2d = get_same_padding_conv2d(image_size=image_size)
107 |         self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
108 |         self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
109 |         self._swish = MemoryEfficientSwish()
110 | 
111 |     def forward(self, inputs, drop_connect_rate=None):
112 |         """MBConvBlock's forward function.
113 | 
114 |         Args:
115 |             inputs (tensor): Input tensor.
116 |             drop_connect_rate (float): Drop connect rate (between 0 and 1).
117 | 
118 |         Returns:
119 |             Output of this block after processing.
120 | """ 121 | 122 | # Expansion and Depthwise Convolution 123 | x = inputs 124 | if self._block_args.expand_ratio != 1: 125 | x = self._expand_conv(inputs) 126 | x = self._bn0(x) 127 | x = self._swish(x) 128 | 129 | x = self._depthwise_conv(x) 130 | x = self._bn1(x) 131 | x = self._swish(x) 132 | 133 | # Squeeze and Excitation 134 | if self.has_se: 135 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 136 | x_squeezed = self._se_reduce(x_squeezed) 137 | x_squeezed = self._swish(x_squeezed) 138 | x_squeezed = self._se_expand(x_squeezed) 139 | x = torch.sigmoid(x_squeezed) * x 140 | 141 | # Pointwise Convolution 142 | x = self._project_conv(x) 143 | x = self._bn2(x) 144 | 145 | # Skip connection and drop connect 146 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 147 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 148 | # The combination of skip connection and drop connect brings about stochastic depth. 149 | if drop_connect_rate: 150 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 151 | x = x + inputs # skip connection 152 | return x 153 | 154 | def set_swish(self, memory_efficient=True): 155 | """Sets swish function as memory efficient (for training) or standard (for export). 156 | 157 | Args: 158 | memory_efficient (bool): Whether to use memory-efficient version of swish. 159 | """ 160 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 161 | 162 | 163 | class EfficientNet(nn.Module): 164 | """EfficientNet model. 165 | Most easily loaded with the .from_name or .from_pretrained methods. 166 | 167 | Args: 168 | blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. 169 | global_params (namedtuple): A set of GlobalParams shared between blocks. 170 | 171 | References: 172 | [1] https://arxiv.org/abs/1905.11946 (EfficientNet) 173 | 174 | # Example: 175 | # >>> import torch 176 | # >>> from efficientnet.model import EfficientNet 177 | # >>> inputs = torch.rand(1, 3, 224, 224) 178 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0') 179 | # >>> model.eval() 180 | # >>> outputs = model(inputs) 181 | """ 182 | 183 | def __init__(self, task_mode='class', blocks_args=None, global_params=None): 184 | super().__init__() 185 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 186 | assert len(blocks_args) > 0, 'block args must be greater than 0' 187 | self._global_params = global_params 188 | self._blocks_args = blocks_args 189 | self.task_mode = task_mode 190 | 191 | # Batch norm parameters 192 | bn_mom = 1 - self._global_params.batch_norm_momentum 193 | bn_eps = self._global_params.batch_norm_epsilon 194 | 195 | # Get stem static or dynamic convolution depending on image size 196 | image_size = global_params.image_size 197 | Conv2d = get_same_padding_conv2d(image_size=image_size) 198 | 199 | # Stem 200 | in_channels = 3 # rgb 201 | out_channels = round_filters(32, self._global_params) # number of output channels 202 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 203 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 204 | image_size = calculate_output_image_size(image_size, 2) 205 | 206 | # Build blocks 207 | self._blocks = nn.ModuleList([]) 208 | for block_args in self._blocks_args: 209 | 210 | # Update block input and output filters based on depth multiplier. 
211 | block_args = block_args._replace( 212 | input_filters=round_filters(block_args.input_filters, self._global_params), 213 | output_filters=round_filters(block_args.output_filters, self._global_params), 214 | num_repeat=round_repeats(block_args.num_repeat, self._global_params) 215 | ) 216 | 217 | # The first block needs to take care of stride and filter size increase. 218 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 219 | image_size = calculate_output_image_size(image_size, block_args.stride) 220 | if block_args.num_repeat > 1: # modify block_args to keep same output size 221 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) 222 | for _ in range(block_args.num_repeat - 1): 223 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 224 | # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 225 | 226 | # Head 227 | in_channels = block_args.output_filters # output of final block 228 | out_channels = round_filters(1280, self._global_params) 229 | Conv2d = get_same_padding_conv2d(image_size=image_size) 230 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 231 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 232 | 233 | # Final linear layer 234 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 235 | # self._dropout = nn.Dropout(self._global_params.dropout_rate) 236 | # self._fc = nn.Linear(out_channels, self._global_params.num_classes) 237 | # building classifier 238 | self.classifier_ = nn.Sequential( 239 | nn.Dropout(self._global_params.dropout_rate), 240 | nn.Linear(out_channels, self._global_params.num_classes), 241 | ) 242 | self.ordinal_regression = OrdinalRegressionLayer() 243 | self._swish = MemoryEfficientSwish() 244 | 245 | # weight initialization 246 | for m in self.modules(): 247 | if isinstance(m, nn.Conv2d): 248 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 249 | if m.bias is not None: 250 | nn.init.zeros_(m.bias) 251 | elif isinstance(m, nn.BatchNorm2d): 252 | nn.init.ones_(m.weight) 253 | nn.init.zeros_(m.bias) 254 | elif isinstance(m, nn.Linear): 255 | nn.init.normal_(m.weight, 0, 0.01) 256 | nn.init.zeros_(m.bias) 257 | 258 | def set_swish(self, memory_efficient=True): 259 | """Sets swish function as memory efficient (for training) or standard (for export). 260 | 261 | Args: 262 | memory_efficient (bool): Whether to use memory-efficient version of swish. 263 | 264 | """ 265 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 266 | for block in self._blocks: 267 | block.set_swish(memory_efficient) 268 | 269 | def extract_endpoints(self, inputs): 270 | """Use convolution layer to extract features 271 | from reduction levels i in [1, 2, 3, 4, 5]. 272 | 273 | Args: 274 | inputs (tensor): Input tensor. 275 | 276 | Returns: 277 | Dictionary of last intermediate features 278 | with reduction levels i in [1, 2, 3, 4, 5]. 
279 |         Example:
280 |             # >>> import torch
281 |             # >>> from efficientnet.model import EfficientNet
282 |             # >>> inputs = torch.rand(1, 3, 224, 224)
283 |             # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
284 |             # >>> endpoints = model.extract_endpoints(inputs)
285 |             # >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
286 |             # >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
287 |             # >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
288 |             # >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
289 |             # >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 1280, 7, 7])
290 |         """
291 |         endpoints = dict()
292 | 
293 |         # Stem
294 |         x = self._swish(self._bn0(self._conv_stem(inputs)))
295 |         prev_x = x
296 | 
297 |         # Blocks
298 |         for idx, block in enumerate(self._blocks):
299 |             drop_connect_rate = self._global_params.drop_connect_rate
300 |             if drop_connect_rate:
301 |                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
302 |             x = block(x, drop_connect_rate=drop_connect_rate)
303 |             if prev_x.size(2) > x.size(2):
304 |                 endpoints[f'reduction_{len(endpoints)+1}'] = prev_x
305 |             prev_x = x
306 | 
307 |         # Head
308 |         x = self._swish(self._bn1(self._conv_head(x)))
309 |         endpoints[f'reduction_{len(endpoints)+1}'] = x
310 | 
311 |         return endpoints
312 | 
313 |     def extract_features(self, inputs):
314 |         """Use convolution layers to extract features.
315 | 
316 |         Args:
317 |             inputs (tensor): Input tensor.
318 | 
319 |         Returns:
320 |             Output of the final convolution
321 |             layer in the efficientnet model.
322 |         """
323 |         # Stem
324 |         x = self._swish(self._bn0(self._conv_stem(inputs)))
325 | 
326 |         # Blocks
327 |         for idx, block in enumerate(self._blocks):
328 |             drop_connect_rate = self._global_params.drop_connect_rate
329 |             if drop_connect_rate:
330 |                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
331 |             x = block(x, drop_connect_rate=drop_connect_rate)
332 | 
333 |         # Head
334 |         x = self._swish(self._bn1(self._conv_head(x)))
335 | 
336 |         return x
337 | 
338 |     def forward(self, inputs):
339 |         """EfficientNet's forward function.
340 |         Calls extract_features to extract features, applies final linear layer, and returns logits.
341 | 
342 |         Args:
343 |             inputs (tensor): Input tensor.
344 | 
345 |         Returns:
346 |             Output of this model after processing.
347 |         """
348 |         # Convolution layers
349 |         x = self.extract_features(inputs)
350 | 
351 |         # Pooling and final linear layer
352 |         x = self._avg_pooling(x)
353 |         x = x.flatten(start_dim=1)
354 |         # x = self._dropout(x)
355 |         # x = self._fc(x)
356 |         # return x
357 | 
358 |         if self.task_mode == 'class':
359 |             c_out = self.classifier_(x)
360 |             predicts, softmax = self.ordinal_regression(c_out)
361 |             return predicts, softmax
362 |         elif self.task_mode == 'regress':
363 |             r_out = self.regressioner_(x)  # NOTE: no regressioner_ is defined in this DORN variant; only 'class' is usable
364 |             return r_out[:, 0]
365 |         elif self.task_mode == 'multi':
366 |             c_out = self.classifier_(x)
367 |             r_out = self.regressioner_(x)
368 |             return c_out, r_out[:, 0]
369 |         else:
370 |             raise ValueError(f'Unsupported task_mode: {self.task_mode}. '
371 |                              f'Only one of [multi, class, regress] is supported.')
372 | 
373 | 
374 |     @classmethod
375 |     def from_name(cls, task_mode, model_name, in_channels=3, **override_params):
376 |         """Create an EfficientNet model according to name.
377 | 
378 |         Args:
379 |             task_mode (str): class, multi, regress
380 |             model_name (str): Name for efficientnet.
381 |             in_channels (int): Input data's channel number.
382 | override_params (other key word params): 383 | Params to override model's global_params. 384 | Optional key: 385 | 'width_coefficient', 'depth_coefficient', 386 | 'image_size', 'dropout_rate', 387 | 'num_classes', 'batch_norm_momentum', 388 | 'batch_norm_epsilon', 'drop_connect_rate', 389 | 'depth_divisor', 'min_depth' 390 | 391 | Returns: 392 | An efficientnet model. 393 | """ 394 | cls._check_model_name_is_valid(model_name) 395 | blocks_args, global_params = get_model_params(model_name, override_params) 396 | model = cls(task_mode, blocks_args, global_params) 397 | model._change_in_channels(in_channels) 398 | return model 399 | 400 | @classmethod 401 | def from_pretrained(cls, task_mode, model_name, weights_path=None, advprop=False, 402 | in_channels=3, num_classes=1000, **override_params): 403 | """create an efficientnet model according to name. 404 | 405 | Args: 406 | task_mode (str): class, multi, regress 407 | model_name (str): Name for efficientnet. 408 | weights_path (None or str): 409 | str: path to pretrained weights file on the local disk. 410 | None: use pretrained weights downloaded from the Internet. 411 | advprop (bool): 412 | Whether to load pretrained weights 413 | trained with advprop (valid when weights_path is None). 414 | in_channels (int): Input data's channel number. 415 | num_classes (int): 416 | Number of categories for classification. 417 | It controls the output size for final linear layer. 418 | override_params (other key word params): 419 | Params to override model's global_params. 420 | Optional key: 421 | 'width_coefficient', 'depth_coefficient', 422 | 'image_size', 'dropout_rate', 423 | 'num_classes', 'batch_norm_momentum', 424 | 'batch_norm_epsilon', 'drop_connect_rate', 425 | 'depth_divisor', 'min_depth' 426 | 427 | Returns: 428 | A pretrained efficientnet model. 429 | """ 430 | model = cls.from_name(task_mode, model_name, num_classes=num_classes, **override_params) 431 | load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop) 432 | model._change_in_channels(in_channels) 433 | return model 434 | 435 | @classmethod 436 | def get_image_size(cls, model_name): 437 | """Get the input image size for a given efficientnet model. 438 | 439 | Args: 440 | model_name (str): Name for efficientnet. 441 | 442 | Returns: 443 | Input image size (resolution). 444 | """ 445 | cls._check_model_name_is_valid(model_name) 446 | _, _, res, _ = efficientnet_params(model_name) 447 | return res 448 | 449 | @classmethod 450 | def _check_model_name_is_valid(cls, model_name): 451 | """Validates model name. 452 | 453 | Args: 454 | model_name (str): Name for efficientnet. 455 | 456 | Returns: 457 | bool: Is a valid name or not. 458 | """ 459 | valid_models = ['efficientnet-b'+str(i) for i in range(9)] 460 | 461 | # Support the construction of 'efficientnet-l2' without pretrained weights 462 | valid_models += ['efficientnet-l2'] 463 | 464 | if model_name not in valid_models: 465 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) 466 | 467 | def _change_in_channels(self, in_channels): 468 | """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. 469 | 470 | Args: 471 | in_channels (int): Input data's channel number. 
472 |         """
473 |         if in_channels != 3:
474 |             Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
475 |             out_channels = round_filters(32, self._global_params)
476 |             self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
477 | 
478 | 
479 | def dorn_efficientnet(task_mode='class', pretrained=True, num_classes=4, **kwargs):
480 |     """
481 |     DORN EfficientNet (deep ordinal regression head on the EfficientNet backbone)
482 | 
483 |     Args:
484 |         task_mode (string): multi, class, regress
485 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
486 |         num_classes (int): number of classes or number of output nodes
487 |     """
488 |     func = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
489 |     model = func(task_mode=task_mode, model_name='efficientnet-b0', num_classes=(num_classes-1)*2)  # 2 logits per ordinal sub-interval
490 |     return model
491 | 
492 | 
493 | # def _test():
494 | #     net = dorn_efficientnet(task_mode='class', pretrained=True).cuda()
495 | #     predicts, softmax = net(torch.randn(48, 3, 224, 224).cuda())
496 | #     print(predicts.size(), softmax.size())
497 | #     # y_class = net(torch.randn(48, 3, 224, 224).cuda())
498 | #     # print(y_class.size())
499 | #
500 | #     # model = net.cuda()
501 | #     # summary(model, (3, 224, 224))
502 | # _test()
503 | 
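# Sizing example (follows from the factory above: num_classes=4 gives (4-1)*2 = 6
# head logits, i.e. K = 3 paired binary classifiers for OrdinalRegressionLayer):
# >>> net = dorn_efficientnet(task_mode='class', pretrained=False, num_classes=4)
# >>> predicts, softmax = net(torch.randn(2, 3, 224, 224))
# >>> predicts.shape, softmax.shape
# (torch.Size([2]), torch.Size([2, 3]))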
33 | 34 | References: 35 | [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) 36 | [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) 37 | [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) 38 | """ 39 | 40 | def __init__(self, block_args, global_params, image_size=None): 41 | super().__init__() 42 | self._block_args = block_args 43 | self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow 44 | self._bn_eps = global_params.batch_norm_epsilon 45 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 46 | self.id_skip = block_args.id_skip # whether to use skip connection and drop connect 47 | 48 | # Expansion phase (Inverted Bottleneck) 49 | inp = self._block_args.input_filters # number of input channels 50 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 51 | if self._block_args.expand_ratio != 1: 52 | Conv2d = get_same_padding_conv2d(image_size=image_size) 53 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 54 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 55 | # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size 56 | 57 | # Depthwise convolution phase 58 | k = self._block_args.kernel_size 59 | s = self._block_args.stride 60 | Conv2d = get_same_padding_conv2d(image_size=image_size) 61 | self._depthwise_conv = Conv2d( 62 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 63 | kernel_size=k, stride=s, bias=False) 64 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 65 | image_size = calculate_output_image_size(image_size, s) 66 | 67 | # Squeeze and Excitation layer, if desired 68 | if self.has_se: 69 | Conv2d = get_same_padding_conv2d(image_size=(1, 1)) 70 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 71 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 72 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) 73 | 74 | # Pointwise convolution phase 75 | final_oup = self._block_args.output_filters 76 | Conv2d = get_same_padding_conv2d(image_size=image_size) 77 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) 78 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) 79 | self._swish = MemoryEfficientSwish() 80 | 81 | def forward(self, inputs, drop_connect_rate=None): 82 | """MBConvBlock's forward function. 83 | 84 | Args: 85 | inputs (tensor): Input tensor. 86 | drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). 87 | 88 | Returns: 89 | Output of this block after processing. 
90 | """ 91 | 92 | # Expansion and Depthwise Convolution 93 | x = inputs 94 | if self._block_args.expand_ratio != 1: 95 | x = self._expand_conv(inputs) 96 | x = self._bn0(x) 97 | x = self._swish(x) 98 | 99 | x = self._depthwise_conv(x) 100 | x = self._bn1(x) 101 | x = self._swish(x) 102 | 103 | # Squeeze and Excitation 104 | if self.has_se: 105 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 106 | x_squeezed = self._se_reduce(x_squeezed) 107 | x_squeezed = self._swish(x_squeezed) 108 | x_squeezed = self._se_expand(x_squeezed) 109 | x = torch.sigmoid(x_squeezed) * x 110 | 111 | # Pointwise Convolution 112 | x = self._project_conv(x) 113 | x = self._bn2(x) 114 | 115 | # Skip connection and drop connect 116 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 117 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 118 | # The combination of skip connection and drop connect brings about stochastic depth. 119 | if drop_connect_rate: 120 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 121 | x = x + inputs # skip connection 122 | return x 123 | 124 | def set_swish(self, memory_efficient=True): 125 | """Sets swish function as memory efficient (for training) or standard (for export). 126 | 127 | Args: 128 | memory_efficient (bool): Whether to use memory-efficient version of swish. 129 | """ 130 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 131 | 132 | 133 | class EfficientNet(nn.Module): 134 | """EfficientNet model. 135 | Most easily loaded with the .from_name or .from_pretrained methods. 136 | 137 | Args: 138 | blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. 139 | global_params (namedtuple): A set of GlobalParams shared between blocks. 140 | 141 | References: 142 | [1] https://arxiv.org/abs/1905.11946 (EfficientNet) 143 | 144 | # Example: 145 | # >>> import torch 146 | # >>> from efficientnet.model import EfficientNet 147 | # >>> inputs = torch.rand(1, 3, 224, 224) 148 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0') 149 | # >>> model.eval() 150 | # >>> outputs = model(inputs) 151 | """ 152 | 153 | def __init__(self, task_mode='class', blocks_args=None, global_params=None): 154 | super().__init__() 155 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 156 | assert len(blocks_args) > 0, 'block args must be greater than 0' 157 | self._global_params = global_params 158 | self._blocks_args = blocks_args 159 | self.task_mode = task_mode 160 | 161 | # Batch norm parameters 162 | bn_mom = 1 - self._global_params.batch_norm_momentum 163 | bn_eps = self._global_params.batch_norm_epsilon 164 | 165 | # Get stem static or dynamic convolution depending on image size 166 | image_size = global_params.image_size 167 | Conv2d = get_same_padding_conv2d(image_size=image_size) 168 | 169 | # Stem 170 | in_channels = 3 # rgb 171 | out_channels = round_filters(32, self._global_params) # number of output channels 172 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 173 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 174 | image_size = calculate_output_image_size(image_size, 2) 175 | 176 | # Build blocks 177 | self._blocks = nn.ModuleList([]) 178 | for block_args in self._blocks_args: 179 | 180 | # Update block input and output filters based on depth multiplier. 
181 | block_args = block_args._replace( 182 | input_filters=round_filters(block_args.input_filters, self._global_params), 183 | output_filters=round_filters(block_args.output_filters, self._global_params), 184 | num_repeat=round_repeats(block_args.num_repeat, self._global_params) 185 | ) 186 | 187 | # The first block needs to take care of stride and filter size increase. 188 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 189 | image_size = calculate_output_image_size(image_size, block_args.stride) 190 | if block_args.num_repeat > 1: # modify block_args to keep same output size 191 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) 192 | for _ in range(block_args.num_repeat - 1): 193 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 194 | # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 195 | 196 | # Head 197 | in_channels = block_args.output_filters # output of final block 198 | out_channels = round_filters(1280, self._global_params) 199 | Conv2d = get_same_padding_conv2d(image_size=image_size) 200 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 201 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 202 | 203 | # Final linear layer 204 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 205 | # self._dropout = nn.Dropout(self._global_params.dropout_rate) 206 | # self._fc = nn.Linear(out_channels, self._global_params.num_classes) 207 | # building classifier 208 | if self.task_mode in ['class', 'multi']: 209 | self.classifier_ = nn.Sequential( 210 | nn.Dropout(self._global_params.dropout_rate), 211 | nn.Linear(out_channels, self._global_params.num_classes), 212 | ) 213 | if self.task_mode in ['regress', 'multi']: 214 | self.regressioner_ = nn.Sequential( 215 | nn.Dropout(self._global_params.dropout_rate), 216 | nn.Linear(out_channels, 1), 217 | ) 218 | if self.task_mode in ['multi_mtmr',]: 219 | self.attribute_feature_fc = nn.Linear(out_channels, 256) 220 | self.regression_ = nn.Linear(256, 1) 221 | self.classifier_ = nn.Sequential( 222 | nn.Dropout(self._global_params.dropout_rate), 223 | nn.Linear(out_channels + 256, self._global_params.num_classes), 224 | ) 225 | self._swish = MemoryEfficientSwish() 226 | 227 | # weight initialization 228 | for m in self.modules(): 229 | if isinstance(m, nn.Conv2d): 230 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 231 | if m.bias is not None: 232 | nn.init.zeros_(m.bias) 233 | elif isinstance(m, nn.BatchNorm2d): 234 | nn.init.ones_(m.weight) 235 | nn.init.zeros_(m.bias) 236 | elif isinstance(m, nn.Linear): 237 | nn.init.normal_(m.weight, 0, 0.01) 238 | nn.init.zeros_(m.bias) 239 | 240 | def set_swish(self, memory_efficient=True): 241 | """Sets swish function as memory efficient (for training) or standard (for export). 242 | 243 | Args: 244 | memory_efficient (bool): Whether to use memory-efficient version of swish. 245 | 246 | """ 247 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 248 | for block in self._blocks: 249 | block.set_swish(memory_efficient) 250 | 251 | def extract_endpoints(self, inputs): 252 | """Use convolution layer to extract features 253 | from reduction levels i in [1, 2, 3, 4, 5]. 254 | 255 | Args: 256 | inputs (tensor): Input tensor. 257 | 258 | Returns: 259 | Dictionary of last intermediate features 260 | with reduction levels i in [1, 2, 3, 4, 5]. 
261 |         Example:
262 |             # >>> import torch
263 |             # >>> from efficientnet.model import EfficientNet
264 |             # >>> inputs = torch.rand(1, 3, 224, 224)
265 |             # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
266 |             # >>> endpoints = model.extract_endpoints(inputs)
267 |             # >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
268 |             # >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
269 |             # >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
270 |             # >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
271 |             # >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 1280, 7, 7])
272 |         """
273 |         endpoints = dict()
274 | 
275 |         # Stem
276 |         x = self._swish(self._bn0(self._conv_stem(inputs)))
277 |         prev_x = x
278 | 
279 |         # Blocks
280 |         for idx, block in enumerate(self._blocks):
281 |             drop_connect_rate = self._global_params.drop_connect_rate
282 |             if drop_connect_rate:
283 |                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
284 |             x = block(x, drop_connect_rate=drop_connect_rate)
285 |             if prev_x.size(2) > x.size(2):
286 |                 endpoints[f'reduction_{len(endpoints) + 1}'] = prev_x
287 |             prev_x = x
288 | 
289 |         # Head
290 |         x = self._swish(self._bn1(self._conv_head(x)))
291 |         endpoints[f'reduction_{len(endpoints) + 1}'] = x
292 | 
293 |         return endpoints
294 | 
295 |     def extract_features(self, inputs):
296 |         """Use convolution layers to extract features.
297 | 
298 |         Args:
299 |             inputs (tensor): Input tensor.
300 | 
301 |         Returns:
302 |             Output of the final convolution
303 |             layer in the efficientnet model.
304 |         """
305 |         # Stem
306 |         x = self._swish(self._bn0(self._conv_stem(inputs)))
307 | 
308 |         # Blocks
309 |         for idx, block in enumerate(self._blocks):
310 |             drop_connect_rate = self._global_params.drop_connect_rate
311 |             if drop_connect_rate:
312 |                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
313 |             x = block(x, drop_connect_rate=drop_connect_rate)
314 | 
315 |         # Head
316 |         x = self._swish(self._bn1(self._conv_head(x)))
317 | 
318 |         return x
319 | 
320 |     def forward_once(self, inputs):
321 |         """EfficientNet's single-branch forward pass (one half of the ranked pair).
322 |            Calls extract_features to extract features, applies the task-specific head(s), and returns the outputs.
323 | 
324 |         Args:
325 |             inputs (tensor): Input tensor.
326 | 
327 |         Returns:
328 |             Output of this model after processing.
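            For 'class': class logits; for 'regress': a regression score;
            for 'multi' and 'multi_mtmr': a (class logits, regression score) pair.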
329 | """ 330 | # Convolution layers 331 | x = self.extract_features(inputs) 332 | 333 | # Pooling and final linear layer 334 | x = self._avg_pooling(x) 335 | x = x.flatten(start_dim=1) 336 | # x = self._dropout(x) 337 | # x = self._fc(x) 338 | # return x 339 | 340 | if self.task_mode == 'class': 341 | c_out = self.classifier_(x) 342 | return c_out 343 | elif self.task_mode == 'regress': 344 | r_out = self.regressioner_(x) 345 | return r_out[:, 0] 346 | elif self.task_mode == 'multi_mtmr': 347 | attribute_feature = self.attribute_feature_fc(x) 348 | r_out = self.regression_(attribute_feature) 349 | c_out = torch.cat([attribute_feature, x], dim=1) 350 | c_out = self.classifier_(c_out) 351 | return c_out, r_out[:, 0] 352 | elif self.task_mode == 'multi': 353 | c_out = self.classifier_(x) 354 | r_out = self.regressioner_(x) 355 | return c_out, r_out[:, 0] 356 | else: 357 | print(f'Do not support: {self.task_mode}' 358 | f'Only support one of [multi, class, and regress] task_mode') 359 | 360 | def forward(self, input): 361 | input_1 = input[0:int(input.shape[0]/2), :, :, :] 362 | input_2 = input[int(input.shape[0]/2):input.shape[0], :, :, :] 363 | output_1, attribute_score_1 = self.forward_once(input_1) 364 | output_2, attribute_score_2 = self.forward_once(input_2) 365 | 366 | cat_output = torch.cat([output_1, output_2]) 367 | cat_subtlety_score = torch.cat([attribute_score_1, attribute_score_2]) 368 | 369 | return cat_output, cat_subtlety_score 370 | 371 | @classmethod 372 | def from_name(cls, task_mode, model_name, in_channels=3, **override_params): 373 | """create an efficientnet model according to name. 374 | 375 | Args: 376 | task_mode (str): class, multi, regress 377 | model_name (str): Name for efficientnet. 378 | in_channels (int): Input data's channel number. 379 | override_params (other key word params): 380 | Params to override model's global_params. 381 | Optional key: 382 | 'width_coefficient', 'depth_coefficient', 383 | 'image_size', 'dropout_rate', 384 | 'num_classes', 'batch_norm_momentum', 385 | 'batch_norm_epsilon', 'drop_connect_rate', 386 | 'depth_divisor', 'min_depth' 387 | 388 | Returns: 389 | An efficientnet model. 390 | """ 391 | cls._check_model_name_is_valid(model_name) 392 | blocks_args, global_params = get_model_params(model_name, override_params) 393 | model = cls(task_mode, blocks_args, global_params) 394 | model._change_in_channels(in_channels) 395 | return model 396 | 397 | @classmethod 398 | def from_pretrained(cls, task_mode, model_name, weights_path=None, advprop=False, 399 | in_channels=3, num_classes=1000, **override_params): 400 | """create an efficientnet model according to name. 401 | 402 | Args: 403 | task_mode (str): class, multi, regress 404 | model_name (str): Name for efficientnet. 405 | weights_path (None or str): 406 | str: path to pretrained weights file on the local disk. 407 | None: use pretrained weights downloaded from the Internet. 408 | advprop (bool): 409 | Whether to load pretrained weights 410 | trained with advprop (valid when weights_path is None). 411 | in_channels (int): Input data's channel number. 412 | num_classes (int): 413 | Number of categories for classification. 414 | It controls the output size for final linear layer. 415 | override_params (other key word params): 416 | Params to override model's global_params. 
417 | Optional key: 418 | 'width_coefficient', 'depth_coefficient', 419 | 'image_size', 'dropout_rate', 420 | 'num_classes', 'batch_norm_momentum', 421 | 'batch_norm_epsilon', 'drop_connect_rate', 422 | 'depth_divisor', 'min_depth' 423 | 424 | Returns: 425 | A pretrained efficientnet model. 426 | """ 427 | model = cls.from_name(task_mode, model_name, num_classes=num_classes, **override_params) 428 | load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), 429 | advprop=advprop) 430 | model._change_in_channels(in_channels) 431 | return model 432 | 433 | @classmethod 434 | def get_image_size(cls, model_name): 435 | """Get the input image size for a given efficientnet model. 436 | 437 | Args: 438 | model_name (str): Name for efficientnet. 439 | 440 | Returns: 441 | Input image size (resolution). 442 | """ 443 | cls._check_model_name_is_valid(model_name) 444 | _, _, res, _ = efficientnet_params(model_name) 445 | return res 446 | 447 | @classmethod 448 | def _check_model_name_is_valid(cls, model_name): 449 | """Validates model name. 450 | 451 | Args: 452 | model_name (str): Name for efficientnet. 453 | 454 | Returns: 455 | bool: Is a valid name or not. 456 | """ 457 | valid_models = ['efficientnet-b' + str(i) for i in range(9)] 458 | 459 | # Support the construction of 'efficientnet-l2' without pretrained weights 460 | valid_models += ['efficientnet-l2'] 461 | 462 | if model_name not in valid_models: 463 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) 464 | 465 | def _change_in_channels(self, in_channels): 466 | """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. 467 | 468 | Args: 469 | in_channels (int): Input data's channel number. 470 | """ 471 | if in_channels != 3: 472 | Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size) 473 | out_channels = round_filters(32, self._global_params) 474 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 475 | 476 | 477 | def jl_efficientnet(task_mode='multi_mtmr', pretrained=True, num_classes=4, **kwargs): 478 | """ 479 | Joint_learning efficient net 480 | 481 | Args: 482 | task_mode (string): multi, class, regress 483 | pretrained (bool): If True, returns a model pre-trained on ImageNet 484 | num_classes (int): number of class or number of output node 485 | """ 486 | func = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name 487 | model = func(task_mode=task_mode, model_name='efficientnet-b0', num_classes=num_classes) 488 | return model 489 | 490 | 491 | 492 | def get_loss_mtmr(output_score_1, cat_subtlety_score, gt_score_1, gt_attribute_score_1): 493 | xcentloss_func_1 = nn.CrossEntropyLoss() 494 | xcentloss_1 = xcentloss_func_1(output_score_1, gt_score_1) 495 | 496 | # ranking loss 497 | ranking_loss_sum = 0 498 | half_size_of_output_score = output_score_1.size()[0] // 2 499 | for i in range(half_size_of_output_score): 500 | tmp_output_1 = output_score_1[i] 501 | tmp_output_2 = output_score_1[i + half_size_of_output_score] 502 | tmp_gt_score_1 = gt_score_1[i] 503 | tmp_gt_score_2 = gt_score_1[i + half_size_of_output_score] 504 | 505 | rankingloss_func = nn.MarginRankingLoss() 506 | 507 | if tmp_gt_score_1.item() != tmp_gt_score_2.item(): 508 | target = torch.ones(1) * -1 509 | ranking_loss_sum += rankingloss_func(tmp_output_1, tmp_output_2, Variable(target.cuda())) 510 | else: 511 | target = torch.ones(1) 512 | ranking_loss_sum += rankingloss_func(tmp_output_1, 
tmp_output_2, Variable(target.cuda()))
513 | 
514 |     ranking_loss = ranking_loss_sum / half_size_of_output_score
515 | 
516 |     # attribute loss
517 |     attribute_mseloss_func_1 = nn.MSELoss()
518 |     attribute_mseloss_1 = attribute_mseloss_func_1(cat_subtlety_score, gt_attribute_score_1.float())
519 | 
520 |     # loss = 0.6 * xcentloss_1 + 0.2 * ranking_loss + 0.2 * attribute_mseloss_1
521 |     loss = 1 * xcentloss_1 + 5.0e-1 * ranking_loss + 1.0e-3 * attribute_mseloss_1
522 | 
523 |     return loss
524 | # def _test():
525 | #     import os
526 | #     os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
527 | #     net = jl_efficientnet(task_mode='multi_mtmr', pretrained=True, num_classes=3)
528 | #     net = torch.nn.DataParallel(net).cuda()
529 | #     y_class, y_regres = net(torch.randn(48, 3, 512, 512).cuda())
530 | #     print(y_class.size(), y_regres.size())
531 | #     # y_class = net(torch.randn(48, 3, 224, 224).cuda())
532 | #     # print(y_class.size())
533 | #
534 | #     # model = net.cuda()
535 | #     # summary(model, (3, 224, 224))
536 | # _test()
537 | 
--------------------------------------------------------------------------------
/model_lib/efficientnet_pytorch/model_rank_ordinal.py:
--------------------------------------------------------------------------------
1 | """model.py - Model and module class for EfficientNet.
2 |    They are built to mirror those in the official TensorFlow implementation.
3 | """
4 | 
5 | # Author: lukemelas (github username)
6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7 | # With adjustments and added comments by workingcoder (github username).
8 | 
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from model_lib.efficientnet_pytorch.utils import (
13 |     round_filters,
14 |     round_repeats,
15 |     drop_connect,
16 |     get_same_padding_conv2d,
17 |     get_model_params,
18 |     efficientnet_params,
19 |     load_pretrained_weights,
20 |     Swish,
21 |     MemoryEfficientSwish,
22 |     calculate_output_image_size
23 | )
24 | from torchsummary import summary
25 | 
26 | class OrdinalRegressionLayer(nn.Module):
27 |     def __init__(self):
28 |         super(OrdinalRegressionLayer, self).__init__()
29 | 
30 |     def forward(self, x):
31 |         """
32 |         :param x: N x 2K; N - batch size, 2K - logits (two per threshold), K - number of discrete sub-intervals
33 |         :return: labels - ordinal labels (corresponding to discrete ordinal values) of size N
34 |                  softmax - predicted probabilities P (as in the DORN paper) of size N x K
35 |         """
36 |         N, K = x.size()
37 |         K = K // 2  # number of discrete sub-intervals
38 | 
39 |         odd = x[:, ::2].clone()
40 |         even = x[:, 1::2].clone()
41 | 
42 |         odd = odd.view(N, 1, K)
43 |         even = even.view(N, 1, K)
44 | 
45 |         paired_channels = torch.cat((odd, even), dim=1)
46 |         paired_channels = paired_channels.clamp(min=1e-8, max=1e8)  # prevent nans
47 | 
48 |         softmax = F.softmax(paired_channels, dim=1)
49 | 
50 |         softmax = softmax[:, 1, :]
51 |         softmax = softmax.view(-1, K)
52 |         labels = torch.sum((softmax > 0.5), dim=1).view(-1, 1) - 1
53 |         return labels[:, 0], softmax
54 | 
55 | 
56 | class MBConvBlock(nn.Module):
57 |     """Mobile Inverted Residual Bottleneck Block.
58 | 
59 |     Args:
60 |         block_args (namedtuple): BlockArgs, defined in utils.py.
61 |         global_params (namedtuple): GlobalParam, defined in utils.py.
62 |         image_size (tuple or list): [image_height, image_width].
63 | 64 | References: 65 | [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) 66 | [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) 67 | [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) 68 | """ 69 | 70 | def __init__(self, block_args, global_params, image_size=None): 71 | super().__init__() 72 | self._block_args = block_args 73 | self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow 74 | self._bn_eps = global_params.batch_norm_epsilon 75 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 76 | self.id_skip = block_args.id_skip # whether to use skip connection and drop connect 77 | 78 | # Expansion phase (Inverted Bottleneck) 79 | inp = self._block_args.input_filters # number of input channels 80 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 81 | if self._block_args.expand_ratio != 1: 82 | Conv2d = get_same_padding_conv2d(image_size=image_size) 83 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 84 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 85 | # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size 86 | 87 | # Depthwise convolution phase 88 | k = self._block_args.kernel_size 89 | s = self._block_args.stride 90 | Conv2d = get_same_padding_conv2d(image_size=image_size) 91 | self._depthwise_conv = Conv2d( 92 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 93 | kernel_size=k, stride=s, bias=False) 94 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 95 | image_size = calculate_output_image_size(image_size, s) 96 | 97 | # Squeeze and Excitation layer, if desired 98 | if self.has_se: 99 | Conv2d = get_same_padding_conv2d(image_size=(1, 1)) 100 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 101 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 102 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) 103 | 104 | # Pointwise convolution phase 105 | final_oup = self._block_args.output_filters 106 | Conv2d = get_same_padding_conv2d(image_size=image_size) 107 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) 108 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) 109 | self._swish = MemoryEfficientSwish() 110 | 111 | def forward(self, inputs, drop_connect_rate=None): 112 | """MBConvBlock's forward function. 113 | 114 | Args: 115 | inputs (tensor): Input tensor. 116 | drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). 117 | 118 | Returns: 119 | Output of this block after processing. 
120 | """ 121 | 122 | # Expansion and Depthwise Convolution 123 | x = inputs 124 | if self._block_args.expand_ratio != 1: 125 | x = self._expand_conv(inputs) 126 | x = self._bn0(x) 127 | x = self._swish(x) 128 | 129 | x = self._depthwise_conv(x) 130 | x = self._bn1(x) 131 | x = self._swish(x) 132 | 133 | # Squeeze and Excitation 134 | if self.has_se: 135 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 136 | x_squeezed = self._se_reduce(x_squeezed) 137 | x_squeezed = self._swish(x_squeezed) 138 | x_squeezed = self._se_expand(x_squeezed) 139 | x = torch.sigmoid(x_squeezed) * x 140 | 141 | # Pointwise Convolution 142 | x = self._project_conv(x) 143 | x = self._bn2(x) 144 | 145 | # Skip connection and drop connect 146 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 147 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 148 | # The combination of skip connection and drop connect brings about stochastic depth. 149 | if drop_connect_rate: 150 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 151 | x = x + inputs # skip connection 152 | return x 153 | 154 | def set_swish(self, memory_efficient=True): 155 | """Sets swish function as memory efficient (for training) or standard (for export). 156 | 157 | Args: 158 | memory_efficient (bool): Whether to use memory-efficient version of swish. 159 | """ 160 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 161 | 162 | 163 | class EfficientNet(nn.Module): 164 | """EfficientNet model. 165 | Most easily loaded with the .from_name or .from_pretrained methods. 166 | 167 | Args: 168 | blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. 169 | global_params (namedtuple): A set of GlobalParams shared between blocks. 170 | 171 | References: 172 | [1] https://arxiv.org/abs/1905.11946 (EfficientNet) 173 | 174 | # Example: 175 | # >>> import torch 176 | # >>> from efficientnet.model import EfficientNet 177 | # >>> inputs = torch.rand(1, 3, 224, 224) 178 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0') 179 | # >>> model.eval() 180 | # >>> outputs = model(inputs) 181 | """ 182 | 183 | def __init__(self, task_mode='class', blocks_args=None, global_params=None): 184 | super().__init__() 185 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 186 | assert len(blocks_args) > 0, 'block args must be greater than 0' 187 | self._global_params = global_params 188 | self._blocks_args = blocks_args 189 | self.task_mode = task_mode 190 | 191 | # Batch norm parameters 192 | bn_mom = 1 - self._global_params.batch_norm_momentum 193 | bn_eps = self._global_params.batch_norm_epsilon 194 | 195 | # Get stem static or dynamic convolution depending on image size 196 | image_size = global_params.image_size 197 | Conv2d = get_same_padding_conv2d(image_size=image_size) 198 | 199 | # Stem 200 | in_channels = 3 # rgb 201 | out_channels = round_filters(32, self._global_params) # number of output channels 202 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 203 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 204 | image_size = calculate_output_image_size(image_size, 2) 205 | 206 | # Build blocks 207 | self._blocks = nn.ModuleList([]) 208 | for block_args in self._blocks_args: 209 | 210 | # Update block input and output filters based on depth multiplier. 
211 | block_args = block_args._replace( 212 | input_filters=round_filters(block_args.input_filters, self._global_params), 213 | output_filters=round_filters(block_args.output_filters, self._global_params), 214 | num_repeat=round_repeats(block_args.num_repeat, self._global_params) 215 | ) 216 | 217 | # The first block needs to take care of stride and filter size increase. 218 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 219 | image_size = calculate_output_image_size(image_size, block_args.stride) 220 | if block_args.num_repeat > 1: # modify block_args to keep same output size 221 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) 222 | for _ in range(block_args.num_repeat - 1): 223 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) 224 | # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 225 | 226 | # Head 227 | in_channels = block_args.output_filters # output of final block 228 | out_channels = round_filters(1280, self._global_params) 229 | Conv2d = get_same_padding_conv2d(image_size=image_size) 230 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 231 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 232 | 233 | # Final linear layer 234 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 235 | # self._dropout = nn.Dropout(self._global_params.dropout_rate) 236 | # self._fc = nn.Linear(out_channels, self._global_params.num_classes) 237 | # building classifier 238 | if self.task_mode in ['class', 'multi']: 239 | self.classifier_ = nn.Sequential( 240 | nn.Dropout(self._global_params.dropout_rate), 241 | nn.Linear(out_channels, self._global_params.num_classes), 242 | ) 243 | if self.task_mode in ['regress', 'multi']: 244 | self.regressioner_ = nn.Sequential( 245 | nn.Dropout(self._global_params.dropout_rate), 246 | nn.Linear(out_channels, 1), 247 | ) 248 | if self.task_mode in ['regress_rank_ordinal',]: 249 | self.regressioner_ = nn.Sequential( 250 | nn.Dropout(self._global_params.dropout_rate), 251 | nn.Linear(out_channels, (self._global_params.num_classes - 1) * 2), 252 | ) 253 | if self.task_mode in ['regress_rank_dorn', ]: 254 | self.regressioner_ = nn.Sequential( 255 | nn.Dropout(self._global_params.dropout_rate), 256 | nn.Linear(out_channels, (self._global_params.num_classes - 1) * 2), 257 | ) 258 | self.ordinal_regression = OrdinalRegressionLayer() 259 | self._swish = MemoryEfficientSwish() 260 | 261 | # weight initialization 262 | for m in self.modules(): 263 | if isinstance(m, nn.Conv2d): 264 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 265 | if m.bias is not None: 266 | nn.init.zeros_(m.bias) 267 | elif isinstance(m, nn.BatchNorm2d): 268 | nn.init.ones_(m.weight) 269 | nn.init.zeros_(m.bias) 270 | elif isinstance(m, nn.Linear): 271 | nn.init.normal_(m.weight, 0, 0.01) 272 | nn.init.zeros_(m.bias) 273 | 274 | def set_swish(self, memory_efficient=True): 275 | """Sets swish function as memory efficient (for training) or standard (for export). 276 | 277 | Args: 278 | memory_efficient (bool): Whether to use memory-efficient version of swish. 279 | 280 | """ 281 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 282 | for block in self._blocks: 283 | block.set_swish(memory_efficient) 284 | 285 | def extract_endpoints(self, inputs): 286 | """Use convolution layer to extract features 287 | from reduction levels i in [1, 2, 3, 4, 5]. 
288 | 
289 |         Args:
290 |             inputs (tensor): Input tensor.
291 | 
292 |         Returns:
293 |             Dictionary of last intermediate features
294 |             with reduction levels i in [1, 2, 3, 4, 5].
295 |         Example:
296 |             # >>> import torch
297 |             # >>> from efficientnet.model import EfficientNet
298 |             # >>> inputs = torch.rand(1, 3, 224, 224)
299 |             # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
300 |             # >>> endpoints = model.extract_endpoints(inputs)
301 |             # >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
302 |             # >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
303 |             # >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
304 |             # >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
305 |             # >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 1280, 7, 7])
306 |         """
307 |         endpoints = dict()
308 | 
309 |         # Stem
310 |         x = self._swish(self._bn0(self._conv_stem(inputs)))
311 |         prev_x = x
312 | 
313 |         # Blocks
314 |         for idx, block in enumerate(self._blocks):
315 |             drop_connect_rate = self._global_params.drop_connect_rate
316 |             if drop_connect_rate:
317 |                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
318 |             x = block(x, drop_connect_rate=drop_connect_rate)
319 |             if prev_x.size(2) > x.size(2):
320 |                 endpoints[f'reduction_{len(endpoints) + 1}'] = prev_x
321 |             prev_x = x
322 | 
323 |         # Head
324 |         x = self._swish(self._bn1(self._conv_head(x)))
325 |         endpoints[f'reduction_{len(endpoints) + 1}'] = x
326 | 
327 |         return endpoints
328 | 
329 |     def extract_features(self, inputs):
330 |         """Use convolution layers to extract features.
331 | 
332 |         Args:
333 |             inputs (tensor): Input tensor.
334 | 
335 |         Returns:
336 |             Output of the final convolution
337 |             layer in the efficientnet model.
338 |         """
339 |         # Stem
340 |         x = self._swish(self._bn0(self._conv_stem(inputs)))
341 | 
342 |         # Blocks
343 |         for idx, block in enumerate(self._blocks):
344 |             drop_connect_rate = self._global_params.drop_connect_rate
345 |             if drop_connect_rate:
346 |                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
347 |             x = block(x, drop_connect_rate=drop_connect_rate)
348 | 
349 |         # Head
350 |         x = self._swish(self._bn1(self._conv_head(x)))
351 | 
352 |         return x
353 | 
354 |     def forward(self, inputs):
355 |         """EfficientNet's forward function.
356 |         Calls extract_features to extract features, applies the task-specific head(s), and returns the outputs.
357 | 
358 |         Args:
359 |             inputs (tensor): Input tensor.
360 | 
361 |         Returns:
362 |             Output of this model after processing.
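            For 'regress_rank_ordinal': a (logits, probas) pair with logits of
            shape [N, num_classes - 1, 2] and per-threshold probabilities of
            shape [N, num_classes - 1]; for 'regress_rank_dorn': a pair of
            (predicted ordinal labels, softmax scores).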
363 | """ 364 | # Convolution layers 365 | x = self.extract_features(inputs) 366 | 367 | # Pooling and final linear layer 368 | x = self._avg_pooling(x) 369 | x = x.flatten(start_dim=1) 370 | # x = self._dropout(x) 371 | # x = self._fc(x) 372 | # return x 373 | 374 | if self.task_mode == 'class': 375 | c_out = self.classifier_(x) 376 | return c_out 377 | elif self.task_mode == 'regress': 378 | r_out = self.regressioner_(x) 379 | return r_out[:, 0] 380 | elif self.task_mode == 'regress_rank_ordinal': 381 | r_out = self.regressioner_(x) 382 | r_out = r_out.view(-1, (self._global_params.num_classes - 1), 2) 383 | probas = F.softmax(r_out, dim=2)[:, :, 1] 384 | return r_out, probas 385 | elif self.task_mode in ['regress_rank_dorn', ]: 386 | r_out = self.regressioner_(x) 387 | predicts, softmax = self.ordinal_regression(r_out) 388 | return predicts, softmax 389 | elif self.task_mode == 'multi': 390 | c_out = self.classifier_(x) 391 | r_out = self.regressioner_(x) 392 | return c_out, r_out[:, 0] 393 | else: 394 | print(f'Do not support: {self.task_mode}' 395 | f'Only support one of [multi, class, and regress] task_mode') 396 | 397 | @classmethod 398 | def from_name(cls, task_mode, model_name, in_channels=3, **override_params): 399 | """create an efficientnet model according to name. 400 | 401 | Args: 402 | task_mode (str): class, multi, regress 403 | model_name (str): Name for efficientnet. 404 | in_channels (int): Input data's channel number. 405 | override_params (other key word params): 406 | Params to override model's global_params. 407 | Optional key: 408 | 'width_coefficient', 'depth_coefficient', 409 | 'image_size', 'dropout_rate', 410 | 'num_classes', 'batch_norm_momentum', 411 | 'batch_norm_epsilon', 'drop_connect_rate', 412 | 'depth_divisor', 'min_depth' 413 | 414 | Returns: 415 | An efficientnet model. 416 | """ 417 | cls._check_model_name_is_valid(model_name) 418 | blocks_args, global_params = get_model_params(model_name, override_params) 419 | model = cls(task_mode, blocks_args, global_params) 420 | model._change_in_channels(in_channels) 421 | return model 422 | 423 | @classmethod 424 | def from_pretrained(cls, task_mode, model_name, weights_path=None, advprop=False, 425 | in_channels=3, num_classes=1000, **override_params): 426 | """create an efficientnet model according to name. 427 | 428 | Args: 429 | task_mode (str): class, multi, regress 430 | model_name (str): Name for efficientnet. 431 | weights_path (None or str): 432 | str: path to pretrained weights file on the local disk. 433 | None: use pretrained weights downloaded from the Internet. 434 | advprop (bool): 435 | Whether to load pretrained weights 436 | trained with advprop (valid when weights_path is None). 437 | in_channels (int): Input data's channel number. 438 | num_classes (int): 439 | Number of categories for classification. 440 | It controls the output size for final linear layer. 441 | override_params (other key word params): 442 | Params to override model's global_params. 443 | Optional key: 444 | 'width_coefficient', 'depth_coefficient', 445 | 'image_size', 'dropout_rate', 446 | 'num_classes', 'batch_norm_momentum', 447 | 'batch_norm_epsilon', 'drop_connect_rate', 448 | 'depth_divisor', 'min_depth' 449 | 450 | Returns: 451 | A pretrained efficientnet model. 
452 | """ 453 | model = cls.from_name(task_mode, model_name, num_classes=num_classes, **override_params) 454 | load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), 455 | advprop=advprop) 456 | model._change_in_channels(in_channels) 457 | return model 458 | 459 | @classmethod 460 | def get_image_size(cls, model_name): 461 | """Get the input image size for a given efficientnet model. 462 | 463 | Args: 464 | model_name (str): Name for efficientnet. 465 | 466 | Returns: 467 | Input image size (resolution). 468 | """ 469 | cls._check_model_name_is_valid(model_name) 470 | _, _, res, _ = efficientnet_params(model_name) 471 | return res 472 | 473 | @classmethod 474 | def _check_model_name_is_valid(cls, model_name): 475 | """Validates model name. 476 | 477 | Args: 478 | model_name (str): Name for efficientnet. 479 | 480 | Returns: 481 | bool: Is a valid name or not. 482 | """ 483 | valid_models = ['efficientnet-b' + str(i) for i in range(9)] 484 | 485 | # Support the construction of 'efficientnet-l2' without pretrained weights 486 | valid_models += ['efficientnet-l2'] 487 | 488 | if model_name not in valid_models: 489 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) 490 | 491 | def _change_in_channels(self, in_channels): 492 | """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. 493 | 494 | Args: 495 | in_channels (int): Input data's channel number. 496 | """ 497 | if in_channels != 3: 498 | Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size) 499 | out_channels = round_filters(32, self._global_params) 500 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 501 | 502 | 503 | def jl_efficientnet(task_mode='class', pretrained=True, num_classes=4, **kwargs): 504 | """ 505 | Joint_learning efficient net 506 | 507 | Args: 508 | task_mode (string): multi, class, regress 509 | pretrained (bool): If True, returns a model pre-trained on ImageNet 510 | num_classes (int): number of class or number of output node 511 | """ 512 | func = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name 513 | model = func(task_mode=task_mode, model_name='efficientnet-b0', num_classes=num_classes) 514 | return model 515 | 516 | 517 | # def _test(): 518 | # net = jl_efficientnet(task_mode='REGRESS_rank_ordinal', pretrained=True, num_classes=4).cuda() 519 | # y_class, y_regres = net(torch.randn(48, 3, 224, 224).cuda()) 520 | # print(y_class.size(), y_regres.size()) 521 | # # y_class = net(torch.randn(48, 3, 224, 224).cuda()) 522 | # # print(y_class.size()) 523 | # 524 | # # model = net.cuda() 525 | # # summary(model, (3, 224, 224)) 526 | # _test() 527 | 528 | def label_to_levels(label, num_classes=4): 529 | levels = [1] * label + [0] * (num_classes - 1 - label) 530 | levels = torch.tensor(levels, dtype=torch.float32) 531 | return levels 532 | 533 | 534 | def labels_to_labels(class_labels): 535 | """ 536 | class_labels = [2, 1, 3] 537 | """ 538 | levels = [] 539 | for label in class_labels: 540 | levels_from_label = label_to_levels(int(label), num_classes=4) 541 | levels.append(levels_from_label) 542 | return torch.stack(levels).cuda() 543 | 544 | 545 | def cost_fn(logits, label): 546 | num_classes = 4 547 | imp = torch.ones(num_classes - 1, dtype=torch.float).cuda() 548 | levels = labels_to_labels(label) 549 | val = (-torch.sum((F.log_softmax(logits, dim=2)[:, :, 1] * levels 550 | + F.log_softmax(logits, dim=2)[:, :, 0] * (1 - levels)) * 
imp, dim=1)) 551 | return torch.mean(val) 552 | 553 | 554 | def loss_fn2(logits, label): 555 | num_classes = 4 556 | imp = torch.ones(num_classes - 1, dtype=torch.float) 557 | levels = labels_to_labels(label) 558 | val = (-torch.sum((F.logsigmoid(logits) * levels 559 | + (F.logsigmoid(logits) - logits) * (1 - levels)) * imp, 560 | dim=1)) 561 | return torch.mean(val) 562 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | torch~=1.8.1+cu111 5 | numpy~=1.20.2 6 | opencv-python~=4.5.1.48 7 | scikit-learn~=0.24.1 8 | torchvision~=0.9.1+cu111 9 | matplotlib~=3.3.4 10 | pandas~=1.2.4 11 | imgaug~=0.4.0 12 | termcolor~=1.1.0 13 | scipy~=1.6.2 14 | torchsummary~=1.5.1 15 | ignite~=0.4.4 16 | tensorboardx~=2.2 -------------------------------------------------------------------------------- /scheduler_lr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trinhvg/JCO_Learning-pytorch/6a6044801ac0a2f6ea2c479d4417e1128e5a0b72/scheduler_lr/__init__.py -------------------------------------------------------------------------------- /scheduler_lr/warmup_cosine_lr.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | import torch 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | class GradualWarmupScheduler(_LRScheduler): 8 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 9 | self.multiplier = multiplier 10 | self.total_epoch = total_epoch 11 | self.after_scheduler = after_scheduler 12 | self.finished = False 13 | super().__init__(optimizer) 14 | 15 | def get_lr(self): 16 | if self.last_epoch > self.total_epoch: 17 | if self.after_scheduler: 18 | if not self.finished: 19 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 20 | self.finished = True 21 | return self.after_scheduler.get_lr() 22 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 23 | 24 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in
25 |                 self.base_lrs]
26 | 
27 |     def step(self, epoch=None, metrics=None):
28 |         if self.finished and self.after_scheduler:
29 |             if epoch is None:
30 |                 self.after_scheduler.step(None)
31 |             else:
32 |                 self.after_scheduler.step(epoch - self.total_epoch)
33 |         else:
34 |             return super(GradualWarmupScheduler, self).step(epoch)
35 | 
36 | 
37 | if __name__ == '__main__':
38 |     v = torch.zeros(10)
39 |     optim = torch.optim.SGD([v], lr=0.01)
40 |     cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 100, eta_min=0, last_epoch=-1)
41 |     scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=5, after_scheduler=cosine_scheduler)
42 |     a = []
43 |     b = []
44 |     for epoch in range(1, 100):
45 |         scheduler.step(epoch)
46 |         a.append(epoch)
47 |         b.append(optim.param_groups[0]['lr'])
48 |         print(epoch, optim.param_groups[0]['lr'])
49 | 
50 |     plt.plot(a, b)
51 |     plt.show()
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/trinhvg/JCO_Learning-pytorch/6a6044801ac0a2f6ea2c479d4417e1128e5a0b72/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/run_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #for run_info in CLASS_ce MULTI_ce_mse_ceo MULTI_ce_mse MULTI_ce_mae MULTI_ce_mae_ceo REGRESS_mae REGRESS_mse
3 | #for run_info in MULTI_ce_mse MULTI_ce_mae REGRESS_mae REGRESS_mse CLASS_FocalLoss MULTI_mtmr
4 | #for run_info in REGRESS_rank_ordinal REGRESS_FocalOrdinalLoss REGRESS_rank_dorn REGRESS_soft_ordinal REGRESS
5 | 
6 | for run_info in MULTI_ce_mse_ceo MULTI_ce_mse MULTI_ce_mae MULTI_ce_mae_ceo
7 | do
8 |   python train_val_ceo_for_cancer_only.py --run_info ${run_info} --seed 5 --gpu 0,1
9 | done
10 | 
--------------------------------------------------------------------------------
/train_val.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | import torch.utils.data as data
6 | 
7 | from ignite.contrib.handlers import ProgressBar
8 | from ignite.engine import Engine, Events
9 | from ignite.handlers import ModelCheckpoint, Timer
10 | from ignite.metrics import RunningAverage
11 | from tensorboardX import SummaryWriter
12 | from imgaug import augmenters as iaa
13 | from misc.train_ultils_all_iter import *
14 | import importlib
15 | 
16 | 
17 | from loss.mtmr_loss import get_loss_mtmr
18 | from loss.rank_ordinal_loss import cost_fn
19 | from loss.dorn_loss import OrdinalLoss
20 | import dataset as dataset
21 | from config import Config
22 | from loss.ceo_loss import CEOLoss, FocalLoss, SoftLabelOrdinalLoss, FocalOrdinalLoss, count_pred
23 | 
24 | 
25 | ####
26 | 
27 | class Trainer(Config):
28 |     def __init__(self, _args=None):
29 |         super(Trainer, self).__init__(_args=_args)
30 |         if _args is not None:
31 |             self.__dict__.update(_args.__dict__)
32 |             print(self.run_info)
33 | 
34 |     ####
35 |     def view_dataset(self, mode='train'):
36 |         train_pairs, valid_pairs = getattr(dataset, ('prepare_%s_data' % self.dataset))()
37 |         if mode == 'train':
38 |             train_augmentors = self.train_augmentors()
39 |             ds = dataset.DatasetSerial(train_pairs, has_aux=False,
40 |                                        shape_augs=iaa.Sequential(train_augmentors[0]),
41 |                                        input_augs=iaa.Sequential(train_augmentors[1]))
42 |         else:
43 | 
infer_augmentors = self.infer_augmentors() # HACK 44 | ds = dataset.DatasetSerial(valid_pairs, has_aux=False, 45 | shape_augs=iaa.Sequential(infer_augmentors)[0]) 46 | dataset.visualize(ds, 4) 47 | return 48 | 49 | #### 50 | def train_step(self, engine, net, batch, iters, scheduler, optimizer, device): 51 | net.train() # train mode 52 | 53 | imgs_cpu, true_cpu = batch 54 | imgs_cpu = imgs_cpu.permute(0, 3, 1, 2) # to NCHW 55 | scheduler.step(engine.state.epoch + engine.state.iteration / iters) # scheduler.step(epoch + i / iters) 56 | # push data to GPUs 57 | imgs = imgs_cpu.to(device).float() 58 | true = true_cpu.to(device).long() # not one-hot 59 | 60 | # ----------------------------------------------------------- 61 | net.zero_grad() # not rnn so not accumulate 62 | out_net = net(imgs) # a list contains all the out put of the network 63 | loss = 0. 64 | 65 | # assign output 66 | if "CLASS" in self.task_type: 67 | logit_class = out_net 68 | if "REGRESS" in self.task_type: 69 | if ("rank_ordinal" in self.loss_type) or ("dorn" in self.loss_type): 70 | logit_regress, probas = out_net[0], out_net[1] 71 | else: 72 | logit_regress = out_net 73 | if "MULTI" in self.task_type: 74 | logit_class, logit_regress = out_net[0], out_net[1] 75 | 76 | # compute loss function 77 | if "ce" in self.loss_type: 78 | prob = F.softmax(logit_class, dim=-1) 79 | loss_entropy = F.cross_entropy(logit_class, true, reduction='mean') 80 | pred = torch.argmax(prob, dim=-1) 81 | loss += loss_entropy 82 | if 'FocalLoss' in self.loss_type: 83 | loss_focal = FocalLoss()(logit_class, true) 84 | prob = F.softmax(logit_class, dim=-1) 85 | pred = torch.argmax(prob, dim=-1) 86 | loss += loss_focal 87 | 88 | if "mse" in self.loss_type: 89 | criterion = torch.nn.MSELoss() 90 | loss_regres = criterion(logit_regress, true.float()) 91 | loss += loss_regres 92 | if "REGRESS" in self.task_type: 93 | label = torch.tensor([0., 1., 2., 3.]).repeat(len(true), 1).permute(1, 0).cuda() 94 | pred = torch.argmin(torch.abs(logit_regress - label), 0) 95 | if "mae" in self.loss_type: 96 | criterion = torch.nn.L1Loss() 97 | loss_regres = criterion(logit_regress, true.float()) 98 | loss += loss_regres 99 | if "REGRESS" in self.task_type: 100 | label = torch.tensor([0., 1., 2., 3.]).repeat(len(true), 1).permute(1, 0).cuda() 101 | pred = torch.argmin(torch.abs(logit_regress - label), 0) 102 | if "soft_label" in self.loss_type: 103 | criterion = SoftLabelOrdinalLoss(alpha=self.alpha) 104 | loss_regres = criterion(logit_regress, true.float()) 105 | loss += loss_regres 106 | if "REGRESS" in self.task_type: 107 | label = torch.tensor([0., 1 / 3, 2 / 3, 1.]).repeat(len(true), 1).permute(1, 0).cuda() 108 | pred = torch.argmin(torch.abs(logit_regress - label), 0) 109 | if "FocalOrdinal" in self.loss_type: 110 | criterion = FocalOrdinalLoss(pooling=True) 111 | loss_regres = criterion(logit_regress, true.float()) 112 | loss += loss_regres 113 | pred = count_pred(logit_regress) 114 | if "ceo" in self.loss_type: 115 | criterion = CEOLoss(num_classes=self.nr_classes) 116 | loss_ordinal = criterion(logit_regress, true) 117 | loss += loss_ordinal 118 | if "mtmr" in self.loss_type: 119 | loss = get_loss_mtmr(logit_class, logit_regress, true, true) 120 | prob = F.softmax(logit_class, dim=-1) 121 | pred = torch.argmax(prob, dim=-1) 122 | if "rank_coral" in self.loss_type: 123 | loss = cost_fn(logit_regress, true) 124 | predict_levels = probas > 0.5 125 | pred = torch.sum(predict_levels, dim=1) 126 | if "rank_dorn" in self.loss_type: 127 | pred, softmax = net(imgs) # 
forward 128 | loss = OrdinalLoss()(softmax, true) 129 | 130 | acc = torch.mean((pred == true).float()) # batch accuracy 131 | # gradient update 132 | loss.backward() 133 | optimizer.step() 134 | 135 | # ----------------------------------------------------------- 136 | return dict( 137 | loss=loss.item(), 138 | acc=acc.item(), 139 | ) 140 | 141 | #### 142 | def infer_step(self, net, batch, device): 143 | net.eval() # infer mode 144 | 145 | imgs, true = batch 146 | imgs = imgs.permute(0, 3, 1, 2) # to NCHW 147 | 148 | # push data to GPUs and convert to float32 149 | imgs = imgs.to(device).float() 150 | true = true.to(device).long() # not one-hot 151 | 152 | # ----------------------------------------------------------- 153 | with torch.no_grad(): # dont compute gradient 154 | out_net = net(imgs) # a list contains all the out put of the network 155 | if "CLASS" in self.task_type: 156 | logit_class = out_net 157 | prob = nn.functional.softmax(logit_class, dim=-1) 158 | return dict(logit_c=prob.cpu().numpy(), # from now prob of class task is called by logit_c 159 | true=true.cpu().numpy()) 160 | 161 | if "REGRESS" in self.task_type: 162 | if "rank_ordinal" in self.loss_type: 163 | logits, probas = out_net[0], out_net[1] 164 | predict_levels = probas > 0.5 165 | pred = torch.sum(predict_levels, dim=1) 166 | return dict(logit_r=pred.cpu().numpy(), 167 | true=true.cpu().numpy()) 168 | if "rank_dorn" in self.loss_type: 169 | pred, softmax = net(imgs) 170 | return dict(logit_r=pred.cpu().numpy(), 171 | true=true.cpu().numpy()) 172 | if "soft_label" in self.loss_type: 173 | logit_regress = (self.nr_classes - 1) * out_net 174 | return dict(logit_r=logit_regress.cpu().numpy(), 175 | true=true.cpu().numpy()) 176 | if "FocalOrdinal" in self.loss_type: 177 | logit_regress = out_net 178 | pred = count_pred(logit_regress) 179 | return dict(logit_r=pred.cpu().numpy(), 180 | true=true.cpu().numpy()) 181 | else: 182 | logit_regress = out_net 183 | return dict(logit_r=logit_regress.cpu().numpy(), 184 | true=true.cpu().numpy()) 185 | 186 | if "MULTI" in self.task_type: 187 | logit_class, logit_regress = out_net[0], out_net[1] 188 | prob = nn.functional.softmax(logit_class, dim=-1) 189 | return dict(logit_c=prob.cpu().numpy(), 190 | logit_r=logit_regress.cpu().numpy(), 191 | true=true.cpu().numpy()) 192 | 193 | #### 194 | def run_once(self, fold_idx): 195 | 196 | log_dir = self.log_dir 197 | check_manual_seed(self.seed) 198 | train_pairs, valid_pairs, test_pairs = getattr(dataset, ('prepare_%s_data' % self.dataset))(fold_idx) 199 | # --------------------------- Dataloader 200 | 201 | train_augmentors = self.train_augmentors() 202 | train_dataset = dataset.DatasetSerial(train_pairs[:], has_aux=False, 203 | shape_augs=iaa.Sequential(train_augmentors[0]), 204 | input_augs=iaa.Sequential(train_augmentors[1])) 205 | 206 | infer_augmentors = self.infer_augmentors() # HACK at has_aux 207 | infer_dataset = dataset.DatasetSerial(valid_pairs[:], has_aux=False, 208 | shape_augs=iaa.Sequential(infer_augmentors[0])) 209 | test_dataset = dataset.DatasetSerial(test_pairs[:], has_aux=False, 210 | shape_augs=iaa.Sequential(infer_augmentors[0])) 211 | 212 | train_loader = data.DataLoader(train_dataset, 213 | num_workers=self.nr_procs_train, 214 | batch_size=self.train_batch_size, 215 | shuffle=True, drop_last=True) 216 | valid_loader = data.DataLoader(infer_dataset, 217 | num_workers=self.nr_procs_valid, 218 | batch_size=self.infer_batch_size, 219 | shuffle=False, drop_last=False) 220 | test_loader = data.DataLoader(test_dataset, 
221 | num_workers=self.nr_procs_valid, 222 | batch_size=self.infer_batch_size, 223 | shuffle=False, drop_last=False) 224 | 225 | # --------------------------- Training Sequence 226 | 227 | if self.logging: 228 | check_log_dir(log_dir) 229 | 230 | device = 'cuda' 231 | 232 | # Define your network here 233 | # # # # # Note: this code for EfficientNet B0 234 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model') # dynamic import 235 | if "FocalOrdinal" in self.loss_type: 236 | net = net_def.jl_efficientnet(task_mode='class', pretrained=True, num_classes=3) 237 | 238 | elif "rank_ordinal" in self.loss_type: 239 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_rank_ordinal') # dynamic import 240 | net = net_def.jl_efficientnet(task_mode='regress_rank_ordinal', pretrained=True) 241 | 242 | elif "mtmr" in self.loss_type: 243 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_mtmr') # dynamic import 244 | net = net_def.jl_efficientnet(task_mode='multi_mtmr', pretrained=True) 245 | 246 | elif "rank_dorn" in self.loss_type: 247 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_rank_ordinal') # dynamic import 248 | net = net_def.jl_efficientnet(task_mode='regress_rank_dorn', pretrained=True) 249 | 250 | else: 251 | net = net_def.jl_efficientnet(task_mode=self.task_type.lower(), pretrained=True) 252 | 253 | 254 | net = torch.nn.DataParallel(net).to(device) 255 | # optimizers 256 | optimizer = optim.Adam(net.parameters(), lr=self.init_lr) 257 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=self.nr_epochs // 3, T_mult=1, 258 | eta_min=self.init_lr * 0.1, last_epoch=-1) 259 | 260 | # 261 | iters = self.nr_epochs * self.epoch_length 262 | trainer = Engine(lambda engine, batch: self.train_step(engine, net, batch, iters, scheduler, optimizer, device)) 263 | valider = Engine(lambda engine, batch: self.infer_step(net, batch, device)) 264 | test = Engine(lambda engine, batch: self.infer_step(net, batch, device)) 265 | 266 | # assign output 267 | if "CLASS" in self.task_type: 268 | infer_output = ['logit_c', 'true'] 269 | if "REGRESS" in self.task_type: 270 | infer_output = ['logit_r', 'true'] 271 | if "MULTI" in self.task_type: 272 | infer_output = ['logit_c', 'logit_r', 'pred_c', 'pred_r', 'true'] 273 | 274 | ## 275 | events = Events.EPOCH_COMPLETED 276 | if self.logging: 277 | @trainer.on(events) 278 | def save_chkpoints(engine): 279 | torch.save(net.state_dict(), self.log_dir + '/_net_' + str(engine.state.iteration) + '.pth') 280 | 281 | timer = Timer(average=True) 282 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, 283 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) 284 | timer.attach(valider, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, 285 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) 286 | timer.attach(test, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, 287 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) 288 | 289 | # attach running average metrics computation 290 | # decay of EMA to 0.95 to match tensorpack default 291 | # TODO: refactor this 292 | RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(trainer, 'acc') 293 | RunningAverage(alpha=0.95, output_transform=lambda x: x['loss']).attach(trainer, 'loss') 294 | 295 | # attach progress bar 296 | pbar = ProgressBar(persist=True) 297 | pbar.attach(trainer, metric_names=['loss']) 298 | 
        # attach progress bar
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=['loss'])
        pbar.attach(valider)
        pbar.attach(test)

        # writer for tensorboard logging
        tfwriter = None  # HACK temporary
        if self.logging:
            tfwriter = SummaryWriter(logdir=log_dir)
            json_log_file = log_dir + '/stats.json'
            with open(json_log_file, 'w') as json_file:
                json.dump({}, json_file)  # create empty file

        ### TODO refactor again
        log_info_dict = {
            'logging': self.logging,
            'optimizer': optimizer,
            'tfwriter': tfwriter,
            'json_file': json_log_file if self.logging else None,
            'nr_classes': self.nr_classes,
            'metric_names': infer_output,
            'infer_batch_size': self.infer_batch_size  # too cumbersome
        }
        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  lambda engine: scheduler.step(engine.state.epoch - 1))  # to change the lr
        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_train_ema_results, log_info_dict)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, valider, 'valid', valid_loader, log_info_dict)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, test, 'test', test_loader, log_info_dict)
        valider.add_event_handler(Events.ITERATION_COMPLETED, accumulate_outputs)
        test.add_event_handler(Events.ITERATION_COMPLETED, accumulate_outputs)

        # Setup is done. Now let's run the training
        # trainer.run(train_loader, self.nr_epochs)
        trainer.run(train_loader, self.nr_epochs, self.epoch_length)
        return

    ####
    def run(self):
        if self.cross_valid:
            for fold_idx in range(0, self.nr_fold):
                self.run_once(fold_idx)
        else:
            self.run_once(self.fold_idx)
        return


####
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--view', help='view dataset', action='store_true')
    parser.add_argument('--run_info', type=str, default='REGRESS_rank_dorn',
                        help='CLASS, REGRESS, MULTI + loss, '
                             'loss ex: MULTI_mtmr, REGRESS_rank_ordinal, REGRESS_rank_dorn, '
                             'REGRESS_FocalOrdinalLoss, REGRESS_soft_ordinal')
    parser.add_argument('--dataset', type=str, default='colon_tma', help='colon_tma, prostate_uhu')
    parser.add_argument('--seed', type=int, default=5, help='random seed')
    parser.add_argument('--alpha', type=int, default=5, help='weight of the auxiliary ordinal loss')
    args = parser.parse_args()

    trainer = Trainer(_args=args)
    if args.view:
        trainer.view_dataset()
        exit()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    trainer.run()
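# Example invocation (illustrative values; see the README table for all run_info
# options -- the same flags also apply to train_val_ceo_for_cancer_only.py below):
#   python train_val.py --gpu=0 --run_info=CLASS_ce --dataset=colon_tma --seed=5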
--------------------------------------------------------------------------------
/train_val_ceo_for_cancer_only.py:
--------------------------------------------------------------------------------
import os
import argparse
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage
from tensorboardX import SummaryWriter
from imgaug import augmenters as iaa
from misc.train_ultils_all_iter import *
from loss.cancer_loss import *
# from misc.train_utils import *
# from misc.focalloss_regression import *

import importlib
import dataset
from config import Config
from loss.ceo_loss import CEOLoss, FocalLoss, SoftLabelOrdinalLoss, FocalOrdinalLoss, count_pred


####
class Trainer(Config):
    def __init__(self, _args=None):
        super(Trainer, self).__init__(_args=_args)
        if _args is not None:
            self.__dict__.update(_args.__dict__)
            print(self.run_info)

    ####
    def view_dataset(self, mode='train'):
        # NOTE: prepare_%s_data is called here without a fold index and with two
        # return values, whereas run_once() passes fold_idx and unpacks three splits
        train_pairs, valid_pairs = getattr(dataset, ('prepare_%s_data' % self.dataset))()
        if mode == 'train':
            train_augmentors = self.train_augmentors()
            ds = dataset.DatasetSerial(train_pairs, has_aux=False,
                                       shape_augs=iaa.Sequential(train_augmentors[0]),
                                       input_augs=iaa.Sequential(train_augmentors[1]))
        else:
            infer_augmentors = self.infer_augmentors()  # HACK
            ds = dataset.DatasetSerial(valid_pairs, has_aux=False,
                                       shape_augs=iaa.Sequential(infer_augmentors[0]))
        dataset.visualize(ds, 4)
        return

    ####
    def train_step(self, engine, net, batch, iters, scheduler, optimizer, device):
        net.train()  # train mode

        imgs_cpu, true_cpu = batch
        imgs_cpu = imgs_cpu.permute(0, 3, 1, 2)  # to NCHW
        scheduler.step(engine.state.epoch + engine.state.iteration / iters)  # scheduler.step(epoch + i / iters)

        # push data to GPUs
        imgs = imgs_cpu.to(device).float()
        true = true_cpu.to(device).long()  # not one-hot

        # -----------------------------------------------------------
        net.zero_grad()  # not an RNN, so no gradient accumulation needed
        out_net = net(imgs)  # a list containing all outputs of the network
        loss = 0.

        # assign output
        if "CLASS" in self.task_type:
            logit_class = out_net
        if "REGRESS" in self.task_type:
            logit_regress = out_net
        if "MULTI" in self.task_type:
            logit_class, logit_regress = out_net[0], out_net[1]

        # compute the loss function
        # NOTE: "ce" is a substring of "ceo", so the cross-entropy branch also
        # fires for the joint CE + CEO setting
        if "ce" in self.loss_type:
            prob = F.softmax(logit_class, dim=-1)
            loss_entropy = F.cross_entropy(logit_class, true, reduction='mean')
            pred = torch.argmax(prob, dim=-1)
            loss += loss_entropy
        if 'FocalLoss' in self.loss_type:
            loss_focal = FocalLoss()(logit_class, true)
            prob = F.softmax(logit_class, dim=-1)
            pred = torch.argmax(prob, dim=-1)
            loss += loss_focal

        if "mse" in self.loss_type:
            loss += mse_cancer_v0(logit_regress, true.float())
            # criterion = torch.nn.MSELoss()
            # loss_regres = criterion(logit_regress, true.float())
            # loss += loss_regres
            if "REGRESS" in self.task_type:
                label = torch.tensor(np.arange(self.nr_classes)).float().repeat(len(true), 1).permute(1, 0).cuda()
                pred = torch.argmin(torch.abs(logit_regress - label), 0)
        if "mae" in self.loss_type:
            loss += mae_cancer_v0(logit_regress, true.float())
            # criterion = torch.nn.L1Loss()
            # loss_regres = criterion(logit_regress, true.float())
            # loss += loss_regres
            if "REGRESS" in self.task_type:
                label = torch.tensor(np.arange(self.nr_classes)).float().repeat(len(true), 1).permute(1, 0).cuda()
                pred = torch.argmin(torch.abs(logit_regress - label), 0)
        if "ceo" in self.loss_type:  # ceo applied only to the cancer samples
            loss += ceo_cancer_v0(logit_regress, true)

        acc = torch.mean((pred == true).float())  # batch accuracy

        # gradient update
        loss.backward()
        optimizer.step()

        # -----------------------------------------------------------
        return dict(
            loss=loss.item(),
            acc=acc.item(),
        )
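    # A worked sketch of the nearest-label decoding used above (illustrative values):
    # with nr_classes=4 and logit_regress = [0.2, 2.7], `label` becomes
    # [[0, 0], [1, 1], [2, 2], [3, 3]]; |logit_regress - label| is then minimised
    # at rows 0 and 3 respectively, giving pred = [0, 3].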

    ####
    def infer_step(self, net, batch, device):
        net.eval()  # infer mode

        imgs, true = batch
        imgs = imgs.permute(0, 3, 1, 2)  # to NCHW

        # push data to GPUs and convert to float32
        imgs = imgs.to(device).float()
        true = true.to(device).long()  # not one-hot

        # -----------------------------------------------------------
        with torch.no_grad():  # don't compute gradients
            out_net = net(imgs)  # a list containing all outputs of the network
            if "CLASS" in self.task_type:
                logit_class = out_net
                prob = nn.functional.softmax(logit_class, dim=-1)
                return dict(logit_c=prob.cpu().numpy(),  # from here on, the class probabilities are called logit_c
                            true=true.cpu().numpy())

            if "REGRESS" in self.task_type:
                if "rank_ordinal" in self.loss_type:
                    logits, probas = out_net[0], out_net[1]
                    predict_levels = probas > 0.5
                    pred = torch.sum(predict_levels, dim=1)
                    return dict(logit_r=pred.cpu().numpy(),
                                true=true.cpu().numpy())
                elif "rank_dorn" in self.loss_type:
                    pred, softmax = out_net  # the forward pass already returns (pred, softmax)
                    return dict(logit_r=pred.cpu().numpy(),
                                true=true.cpu().numpy())
                elif "soft_ordinal" in self.loss_type:
                    logit_regress = (self.nr_classes - 1) * out_net
                    return dict(logit_r=logit_regress.cpu().numpy(),
                                true=true.cpu().numpy())
                elif "FocalOrdinalLoss" in self.loss_type:
                    logit_regress = out_net
                    pred = count_pred(logit_regress)
                    return dict(logit_r=pred.cpu().numpy(),
                                true=true.cpu().numpy())
                else:
                    logit_regress = out_net
                    return dict(logit_r=logit_regress.cpu().numpy(),
                                true=true.cpu().numpy())

            if "MULTI" in self.task_type:
                logit_class, logit_regress = out_net[0], out_net[1]
                prob = nn.functional.softmax(logit_class, dim=-1)
                return dict(logit_c=prob.cpu().numpy(),
                            logit_r=logit_regress.cpu().numpy(),
                            true=true.cpu().numpy())

    ####
    def run_once(self, fold_idx):

        log_dir = self.log_dir
        check_manual_seed(self.seed)
        train_pairs, valid_pairs, test_pairs = getattr(dataset, ('prepare_%s_data' % self.dataset))(fold_idx)

        # --------------------------- Dataloader

        train_augmentors = self.train_augmentors()
        train_dataset = dataset.DatasetSerial(train_pairs[:], has_aux=False,
                                              shape_augs=iaa.Sequential(train_augmentors[0]),
                                              input_augs=iaa.Sequential(train_augmentors[1]))

        infer_augmentors = self.infer_augmentors()  # HACK at has_aux
        infer_dataset = dataset.DatasetSerial(valid_pairs[:], has_aux=False,
                                              shape_augs=iaa.Sequential(infer_augmentors[0]))
        test_dataset = dataset.DatasetSerial(test_pairs[:], has_aux=False,
                                             shape_augs=iaa.Sequential(infer_augmentors[0]))

        train_loader = data.DataLoader(train_dataset,
                                       num_workers=self.nr_procs_train,
                                       batch_size=self.train_batch_size,
                                       shuffle=True, drop_last=True)
        valid_loader = data.DataLoader(infer_dataset,
                                       num_workers=self.nr_procs_valid,
                                       batch_size=self.infer_batch_size,
                                       shuffle=False, drop_last=False)
        test_loader = data.DataLoader(test_dataset,
                                      num_workers=self.nr_procs_valid,
                                      batch_size=self.infer_batch_size,
                                      shuffle=False, drop_last=False)

        # --------------------------- Training Sequence

        if self.logging:
            check_log_dir(log_dir)

        device = 'cuda'

        # Define your network here
        # Note: this code is for EfficientNet-B0
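        # importlib.import_module resolves a module by its dotted path at runtime;
        # e.g. importlib.import_module('model_lib.efficientnet_pytorch.model_mtmr')
        # returns that module object, whose jl_efficientnet(...) factory then builds
        # the variant matching the requested loss.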
        net_def = importlib.import_module('model_lib.efficientnet_pytorch.model')  # dynamic import
        if "FocalOrdinalLoss" in self.loss_type:
            net = net_def.jl_efficientnet(task_mode='class', pretrained=True, num_classes=3)
        elif "rank_ordinal" in self.loss_type:
            net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_rank_ordinal')  # dynamic import
            net = net_def.jl_efficientnet(task_mode='regress_rank_ordinal', pretrained=True)
        elif "mtmr" in self.loss_type:
            net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_mtmr')  # dynamic import
            net = net_def.jl_efficientnet(task_mode='multi_mtmr', pretrained=True)
        elif "rank_dorn" in self.loss_type:
            net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_rank_ordinal')  # dynamic import
            net = net_def.jl_efficientnet(task_mode='regress_rank_dorn', pretrained=True)
        else:
            net = net_def.jl_efficientnet(task_mode=self.task_type.lower(), pretrained=True)

        net = torch.nn.DataParallel(net).to(device)

        # optimizers
        optimizer = optim.Adam(net.parameters(), lr=self.init_lr)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=self.nr_epochs // 3, T_mult=1,
                                                                   eta_min=self.init_lr * 0.1, last_epoch=-1)
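        # Schedule sketch (illustrative numbers, not from the config): with
        # nr_epochs=60, T_0=20 yields three cosine cycles; within a cycle the lr
        # follows eta_min + 0.5 * (init_lr - eta_min) * (1 + cos(pi * t / T_0)),
        # annealing from init_lr down to 0.1 * init_lr before each warm restart.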
        #
        iters = self.nr_epochs * self.epoch_length
        trainer = Engine(lambda engine, batch: self.train_step(engine, net, batch, iters, scheduler, optimizer, device))
        valider = Engine(lambda engine, batch: self.infer_step(net, batch, device))
        test = Engine(lambda engine, batch: self.infer_step(net, batch, device))

        # assign the outputs to accumulate during inference
        if "CLASS" in self.task_type:
            infer_output = ['logit_c', 'true']
        if "REGRESS" in self.task_type:
            infer_output = ['logit_r', 'true']
        if "MULTI" in self.task_type:
            infer_output = ['logit_c', 'logit_r', 'pred_c', 'pred_r', 'true']

        ##
        events = Events.EPOCH_COMPLETED
        if self.logging:
            @trainer.on(events)
            def save_chkpoints(engine):
                torch.save(net.state_dict(), self.log_dir + '/_net_' + str(engine.state.iteration) + '.pth')

        timer = Timer(average=True)
        timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                     pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
        timer.attach(valider, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                     pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
        timer.attach(test, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                     pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

        # attach running-average metric computation
        # decay of EMA set to 0.95 to match the tensorpack default
        # TODO: refactor this
        RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(trainer, 'acc')
        RunningAverage(alpha=0.95, output_transform=lambda x: x['loss']).attach(trainer, 'loss')

        # attach progress bar
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=['loss'])
        pbar.attach(valider)
        pbar.attach(test)

        # writer for tensorboard logging
        tfwriter = None  # HACK temporary
        if self.logging:
            tfwriter = SummaryWriter(logdir=log_dir)
            json_log_file = log_dir + '/stats.json'
            with open(json_log_file, 'w') as json_file:
                json.dump({}, json_file)  # create empty file

        ### TODO refactor again
        log_info_dict = {
            'logging': self.logging,
            'optimizer': optimizer,
            'tfwriter': tfwriter,
            'json_file': json_log_file if self.logging else None,
            'nr_classes': self.nr_classes,
            'metric_names': infer_output,
            'infer_batch_size': self.infer_batch_size  # too cumbersome
        }
        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  lambda engine: scheduler.step(engine.state.epoch - 1))  # to change the lr
        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_train_ema_results, log_info_dict)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, valider, 'valid', valid_loader, log_info_dict)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, test, 'test', test_loader, log_info_dict)
        valider.add_event_handler(Events.ITERATION_COMPLETED, accumulate_outputs)
        test.add_event_handler(Events.ITERATION_COMPLETED, accumulate_outputs)

        # Setup is done. Now let's run the training
        # trainer.run(train_loader, self.nr_epochs)
        trainer.run(train_loader, self.nr_epochs, self.epoch_length)
        return

    ####
    def run(self):
        if self.cross_valid:
            for fold_idx in range(0, self.nr_fold):
                self.run_once(fold_idx)
        else:
            self.run_once(self.fold_idx)
        return


####
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--view', help='view dataset', action='store_true')
    parser.add_argument('--run_info', type=str, default='REGRESS_rank_dorn',
                        help='CLASS, REGRESS, MULTI + loss, '
                             'loss ex: MULTI_mtmr, REGRESS_rank_ordinal, REGRESS_rank_dorn, '
                             'REGRESS_FocalOrdinalLoss, REGRESS_soft_ordinal')
    parser.add_argument('--dataset', type=str, default='colon_tma', help='colon_tma, prostate_uhu')
    parser.add_argument('--seed', type=int, default=5, help='random seed')
    parser.add_argument('--alpha', type=int, default=5, help='weight of the auxiliary ordinal loss')
    args = parser.parse_args()

    trainer = Trainer(_args=args)
    if args.view:
        trainer.view_dataset()
        exit()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    trainer.run()
--------------------------------------------------------------------------------