├── LICENSE
├── README.md
├── config.py
├── config_validator.py
├── dataset.py
├── docs
│   └── JCOL.jpg
├── environment.yml
├── infer_produce_predict_map_wsi.py
├── loss
│   ├── __init__.py
│   ├── cancer_loss.py
│   ├── ceo_loss.py
│   ├── dorn_loss.py
│   ├── mtmr_loss.py
│   └── rank_ordinal_loss.py
├── misc
│   ├── infer_wsi_utils.py
│   ├── train_ultils_all_iter.py
│   └── train_ultils_validator.py
├── model_lib
│   ├── __init__.py
│   └── efficientnet_pytorch
│       ├── __init__.py
│       ├── model.py
│       ├── model_dorn.py
│       ├── model_mtmr.py
│       ├── model_rank_ordinal.py
│       └── utils.py
├── requirements.txt
├── scheduler_lr
│   ├── __init__.py
│   └── warmup_cosine_lr.py
├── scripts
│   ├── __init__.py
│   └── run_train.sh
├── train_val.py
└── train_val_ceo_for_cancer_only.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 TrinhVg
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # JCO_Learning: Joint Categorical and Ordinal Learning for Cancer Grading in Pathology Images
2 | ## About
3 | A multi-task deep learning model for cancer grading in pathology images that jointly performs categorical classification
4 | and auxiliary ordinal classification, using the L_CEO loss for the auxiliary ordinal task.
5 | [Link](https://www.sciencedirect.com/science/article/pii/S1361841521002516) to Medical Image Analysis paper.
6 |
7 | 
8 | ## Datasets
9 | All the models in this project were evaluated on the following datasets:
10 |
11 | - [Colon_KBSMC](https://github.com/QuIIL/KBSMC_colon_cancer_grading_dataset) (Colon TMA from Kangbuk Samsung Hospital)
12 | - [Colon_KBSMC](https://github.com/QuIIL/KBSMC_colon_cancer_grading_dataset) (Colon WSI from Kangbuk Samsung Hospital)
13 | - [Prostate_UHU](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/OCYCMP) (Prostate TMA from University Hospital Zurich - Harvard dataverse)
14 | - [Prostate_UBC](https://gleason2019.grand-challenge.org/) (Prostate TMA from UBC - MICCAI 2019)
15 |
16 | ## Set Up Environment
17 |
18 | ```
19 | conda env create -f environment.yml
20 | conda activate jco_learning
21 | pip install torch~=1.8.1+cu111
22 | ```
23 |
24 | Above, we install PyTorch version 1.8.1 with CUDA 11.1.
25 | The code also works with older PyTorch versions (PyTorch >= 1.1).
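
A quick sanity check that the installed PyTorch sees the GPU:

```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```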
26 | ## Repository Structure
27 |
28 | Below are the main directories in the repository:
29 |
30 | - `loss/`: loss function definitions
31 | - `docs/`: figures/GIFs used in the repo
32 | - `misc/`: miscellaneous utilities for training and inference
33 | - `model_lib/`: model definitions (EfficientNet variants)
34 | - `scripts/`: shell script to launch training (`run_train.sh`)
35 |
36 | Below are the main executable scripts in the repository:
37 |
38 | - `config.py`: configuration file
39 | - `config_validator.py`: configuration for the validation/test phase and for generating the predicted maps
40 | - `dataset.py`: defines the dataset classes
41 | - `train_val.py`: main training script
42 | - `train_val_ceo_for_cancer_only.py`: training script variant in which the ordinal loss is applied only to the cancer classes (the benign class is excluded)
43 | - `infer_produce_predict_map_wsi.py`: generates a prediction/probability map for a WSI or core image in a sliding-window fashion
44 |
45 | # Running the Code
46 |
47 | ## Training and Options
48 |
49 | ```
50 | python train_val.py [--gpu=] [--run_info=] [--dataset=]
51 | ```
52 |
53 | Options:
54 | **Our proposed method and 9 common/state-of-the-art categorical and ordinal classification methods:**
55 |
56 | | METHOD | run_info | Description |
57 | | -------------|----------------------| ----------------------|
58 | | C_CE | CLASS_ce | Classification: Cross-Entropy loss
59 | | C_FOCAL | CLASS_FocalLoss | Classification: Focal loss, Focal loss for dense object detection [[paper]](https://arxiv.org/abs/1708.02002)
60 | | R_MAE | REGRESS_mae | Regression: MAE loss
61 | | R_MSE | REGRESS_mse | Regression: MSE loss
62 | | R_SL | REGRESS_soft_label | Regression: Soft-Label loss, Deep learning regression for prostate cancer detection and grading in Bi-parametric MRI [[paper]](https://ieeexplore.ieee.org/document/9090311)
63 | | O_DORN | REGRESS_rank_dorn | Ordinal regression: Deep ordinal regression network for monocular depth estimation [[paper]](https://arxiv.org/abs/1806.02446) [[code]](https://github.com/hufu6371/DORN?utm_source=catalyzex.com)
64 | | O_CORAL | REGRESS_rank_coral | Ordinal regression: Rank consistent ordinal regression for neural networks with application to age estimation [[paper]](https://arxiv.org/abs/1901.07884) [[code]](https://github.com/Raschka-research-group/coral-cnn?utm_source=catalyzex.com)
65 | | O_FOCAL | REGRESS_FocalOrdinal | Ordinal regression: Joint prostate cancer detection and Gleason score prediction in mp-MRI via FocalNet [[paper]](https://ieeexplore.ieee.org/document/8653866)
66 | | M_MTMR | MULTI_mtmr | Multitask: Multi-task deep model with margin ranking loss for lung nodule analysis [[paper]](https://ieeexplore.ieee.org/document/8794587) [[code]](https://github.com/lihaoliu-cambridge/mtmr-net)
67 | | M_MAE | MULTI_ce_mae | Multitask: Class_CE + Regression_MAE
68 | | M_MSE | MULTI_ce_mse | Multitask: Class_CE + Regression_MSE
69 | | M_MAE_CEO | MULTI_ce_mae_ceo | Multitask: Class_CE + Regression_MAE_CEO (Ours)
70 | | M_MSE_CEO | MULTI_ce_mse_ceo | Multitask: Class_CE + Regression_MSE_CEO (Ours)
71 |
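For example, to train our multi-task model with the CEO loss on the colon dataset (a sketch; the argument values shown are illustrative):

```
python train_val.py --gpu='0' --run_info=MULTI_ce_mse_ceo --dataset=colon_manual
```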
72 |
73 |
74 |
75 | ## Inference
76 |
77 | ```
78 | python infer_produce_predict_map_wsi.py [--gpu=] [--run_info=]
79 | ```
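
For example, to produce prediction maps with the proposed multi-task model (the `run_info` value is illustrative; checkpoint and data paths are set via the script's arguments):

```
python infer_produce_predict_map_wsi.py --gpu='0' --run_info=MULTI_ce_mse_ceo
```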
80 |
81 | ### Model Weights
82 |
83 | Model weights obtained by training MULTI_ce_mse_ceo:
84 | - [Colon checkpoint](https://drive.google.com/drive/folders/1Gf2HjjcjJw4h1VvFUbnF2xvr9SJ6_r48?usp=sharing)
85 | - [Prostate checkpoint](https://drive.google.com/drive/folders/1Gf2HjjcjJw4h1VvFUbnF2xvr9SJ6_r48?usp=sharing)
86 |
87 | Access the entire checkpoints [here](https://drive.google.com/drive/folders/1KQMD0iRibfAP9AxBE4TuU1NtPGvw-h5R?usp=sharing).
88 |
89 | If any of the above checkpoints are used, please cite the corresponding paper.
90 |
91 | ## Authors
92 |
93 | * [Trinh Thi Le Vuong](https://github.com/trinhvg), Kyungeun Kim, Boram Song, and [Jin Tae Kwak](https://github.com/JinTaeKwak)
94 |
95 |
96 | ## Citation
97 |
98 | If any part of this code is used, please cite our paper.
99 |
100 | BibTex entry:
101 | ```
102 | @article{le2021joint,
103 | title={Joint categorical and ordinal learning for cancer grading in pathology images},
104 | author={Le Vuong, Trinh Thi and Kim, Kyungeun and Song, Boram and Kwak, Jin Tae},
105 | journal={Medical image analysis},
106 | pages={102206},
107 | year={2021},
108 | publisher={Elsevier}
109 | }
110 | ```
111 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import imgaug # https://github.com/aleju/imgaug
2 | from imgaug import augmenters as iaa
3 | import imgaug as ia
4 | import os
5 |
6 | ####
7 | class Config(object):
8 | def __init__(self, _args=None):
9 | if _args is not None:
10 | self.__dict__.update(_args.__dict__)
11 | self.seed = self.seed
12 | self.init_lr = 1.0e-3
13 | self.lr_steps = 20 # decrease at every n-th epoch
14 | self.gamma = 0.2
15 | self.train_batch_size = 64
16 | self.infer_batch_size = 256
17 | self.nr_classes = 4
18 | self.nr_epochs = 60
19 | self.epoch_length = 50
20 |
21 | # nr of processes for parallel processing input
22 | self.nr_procs_train = 8
23 | self.nr_procs_valid = 8
24 |
25 | self.nr_fold = 5
26 | self.fold_idx = 0
27 | self.cross_valid = False
28 |
29 | self.load_network = False
30 | self.save_net_path = ""
31 |
32 | #
33 | self.dataset = 'colon_manual'
34 | self.logging = True # True for debug run only
35 |
36 | self.log_path = '/media/data1/trinh_2021/data/workspace_data/join_learning_2021/colon/ordinalforcancer_v0/'
37 |
38 | self.chkpts_prefix = 'model'
39 | if _args is not None:
40 | self.__dict__.update(_args.__dict__)
41 | self.task_type = self.run_info.split('_')[0]
42 | self.loss_type = self.run_info.replace(self.task_type + "_", "")
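            # e.g. run_info='MULTI_ce_mse_ceo' -> task_type='MULTI', loss_type='ce_mse_ceo'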
43 | self.model_name = f'/{self.task_type}_{self.loss_type}_cancer_Effi_seed{self.seed}_BS64'
44 | self.log_dir = self.log_path + self.model_name
45 | print(self.model_name)
46 |
47 | def train_augmentors(self):
48 | if self.dataset == "prostate_hv":
49 | shape_augs = [
50 | iaa.Resize(0.5, interpolation='nearest'),
51 | iaa.CropToFixedSize(width=350, height=350),
52 | ]
53 | else:
54 | shape_augs = []
55 | #
56 | sometimes = lambda aug: iaa.Sometimes(0.2, aug)
57 | input_augs = iaa.Sequential(
58 | [
59 | # apply the following augmenters to most images
60 | iaa.Fliplr(0.5), # horizontally flip 50% of all images
61 | iaa.Flipud(0.5), # vertically flip 50% of all images
62 | sometimes(iaa.Affine(
63 | rotate=(-45, 45), # rotate by -45 to +45 degrees
64 | shear=(-16, 16), # shear by -16 to +16 degrees
65 | order=[0, 1], # use nearest neighbour or bilinear interpolation (fast)
66 | cval=(0, 255), # if mode is constant, use a cval between 0 and 255
67 | mode='symmetric'
68 | # use any of scikit-image's warping modes (see 2nd image from the top for examples)
69 | )),
70 | # execute 0 to 5 of the following (less important) augmenters per image
71 | # don't execute all of them, as that would often be way too strong
72 | iaa.SomeOf((0, 5),
73 | [
74 | iaa.OneOf([
75 | iaa.GaussianBlur((0, 3.0)), # blur images with a sigma between 0 and 3.0
76 | iaa.AverageBlur(k=(2, 7)),
77 | # blur image using local means with kernel sizes between 2 and 7
78 |                         iaa.MedianBlur(k=(3, 11)),
79 |                         # blur image using local medians with kernel sizes between 3 and 11
80 | ]),
81 | iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5),
82 | # add gaussian noise to images
83 | iaa.Dropout((0.01, 0.1), per_channel=0.5), # randomly remove up to 10% of the pixels
85 |                     iaa.AddToHueAndSaturation((-20, 20)),  # change hue and saturation
86 | iaa.LinearContrast((0.5, 2.0), per_channel=0.5), # improve or worsen the contrast
87 | ],
88 | random_order=True
89 | )
90 | ],
91 | random_order=True
92 | )
93 | return shape_augs, input_augs
94 |
95 | ####
96 | def infer_augmentors(self):
97 | if self.dataset == "prostate_hv":
98 | shape_augs = [
99 | iaa.Resize(0.5, interpolation='nearest'),
100 | iaa.CropToFixedSize(width=350, height=350, position="center"),
101 | ]
102 | else:
103 | shape_augs = []
104 | return shape_augs, None
105 |
106 | ###########################################################################
--------------------------------------------------------------------------------
/config_validator.py:
--------------------------------------------------------------------------------
1 | import imgaug # https://github.com/aleju/imgaug
2 | from imgaug import augmenters as iaa
3 | import imgaug as ia
4 |
5 |
6 | ####
7 | class Config(object):
8 | def __init__(self, _args=None):
9 | if _args is not None:
10 | self.__dict__.update(_args.__dict__)
11 | self.seed = 5 #self.seed
12 | self.infer_batch_size = 128
13 | self.nr_classes = 4
14 |
15 | # nr of processes for parallel processing input
16 | self.nr_procs_valid = 8
17 |
18 | self.load_network = False
19 | self.save_net_path = ""
20 |
21 | self.dataset = 'colon_manual'
22 | self.logging = False # True for debug run only
23 | self.log_path = ""
24 | self.chkpts_prefix = 'model'
25 | self.model_name = 'validator'
26 | self.log_dir = self.log_path + self.model_name
27 | print(self.model_name)
28 |
29 | ####
30 | def infer_augmentors(self):
31 | if self.dataset == "prostate_hv":
32 | shape_augs = [
33 | iaa.Resize(0.5, interpolation='nearest'),
34 | iaa.CropToFixedSize(width=350, height=350, position="center"),
35 | ]
36 | else:
37 | shape_augs = []
38 | return shape_augs, None
39 |
40 | ############################################################################
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import glob
4 | import random
5 | from collections import Counter
6 | import cv2
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import torch.utils.data as data
10 | from torchvision import transforms
11 | from imgaug import augmenters as iaa
12 |
13 | ####
14 |
15 |
16 | class DatasetSerial(data.Dataset):
17 |
18 | def __init__(self, pair_list, shape_augs=None, input_augs=None, has_aux=False, test_aux=False):
19 | self.test_aux = test_aux
20 | self.pair_list = pair_list
21 | self.shape_augs = shape_augs
22 | self.input_augs = input_augs
23 |
24 | def __getitem__(self, idx):
25 | pair = self.pair_list[idx]
26 | # print(pair)
27 | input_img = cv2.imread(pair[0])
28 | input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
29 | img_label = pair[1]
30 | # print(input_img.shape)
31 | transform = transforms.Compose([
32 | transforms.ToTensor(),
33 | transforms.Normalize(mean=[0., 0., 0.],
34 | std=[1., 1., 1.])
35 | ])
36 |
37 | if not self.test_aux:
38 |
39 | # shape must be deterministic so it can be reused
40 | if self.shape_augs is not None:
41 | shape_augs = self.shape_augs.to_deterministic()
42 | input_img = shape_augs.augment_image(input_img)
43 |
44 |             # additional augmentation just for the input
45 | if self.input_augs is not None:
46 | input_img = self.input_augs.augment_image(input_img)
47 |
48 | input_img = np.array(input_img).copy()
54 |
55 | out_img = np.array(transform(input_img)).transpose(1, 2, 0)
56 | else:
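            # test-time augmentation: build five differently augmented copies of the image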
57 | out_img = []
58 | for idx in range(5):
59 | input_img_ = input_img.copy()
60 | if self.shape_augs is not None:
61 | shape_augs = self.shape_augs.to_deterministic()
62 | input_img_ = shape_augs.augment_image(input_img_)
63 | input_img_ = iaa.Sequential(self.input_augs[idx]).augment_image(input_img_)
64 | input_img_ = np.array(input_img_).copy()
65 | input_img_ = np.array(transform(input_img_)).transpose(1, 2, 0)
66 | out_img.append(input_img_)
67 | return np.array(out_img), img_label
68 |
69 | def __len__(self):
70 | return len(self.pair_list)
71 |
72 |
73 | class DatasetSerialWSI(data.Dataset):
74 | def __init__(self, path_list):
75 | self.path_list = path_list
76 |
77 | def __getitem__(self, idx):
78 | input_img = cv2.imread(self.path_list[idx])
79 | input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
80 | input_img = np.array(input_img).copy()
81 | transform = transforms.Compose([
82 | transforms.ToTensor(),
83 | transforms.Normalize(mean=[0., 0., 0.],
84 | std=[1., 1., 1.])
85 | ])
86 | input_img = np.array(transform(input_img)).transpose(1, 2, 0)
87 | location = self.path_list[idx].split('/')[-1].split('.')[0].split('_')
88 | return input_img, location
89 |
90 | def __len__(self):
91 | return len(self.path_list)
92 |
93 | def prepare_colon_tma_data():
94 | def load_data_info(pathname, parse_label=True, label_value=0):
95 | file_list = glob.glob(pathname)
96 | cancer_test = False
97 | if cancer_test:
98 | file_list_bn = glob.glob(pathname.replace('*.jpg', '*0.jpg'))
99 | file_list = [elem for elem in file_list if elem not in file_list_bn]
100 | label_list = [int(file_path.split('_')[-1].split('.')[0])-1 for file_path in file_list]
101 | else:
102 | if parse_label:
103 | label_list = [int(file_path.split('_')[-1].split('.')[0]) for file_path in file_list]
104 | else:
105 | label_list = [label_value for file_path in file_list]
106 | print(Counter(label_list))
107 | return list(zip(file_list, label_list))
108 |
109 | data_root_dir = '/media/data1/member1/projects/workspace_data/COLON_MANUAL_512/COLON_MANUAL_512'
110 |
111 | set_1010711 = load_data_info('%s/1010711/*.jpg' % data_root_dir)
112 | set_1010712 = load_data_info('%s/1010712/*.jpg' % data_root_dir)
113 | set_1010713 = load_data_info('%s/1010713/*.jpg' % data_root_dir)
114 | set_1010714 = load_data_info('%s/1010714/*.jpg' % data_root_dir)
115 | set_1010715 = load_data_info('%s/1010715/*.jpg' % data_root_dir)
116 | set_1010716 = load_data_info('%s/1010716/*.jpg' % data_root_dir)
117 | wsi_00016 = load_data_info('%s/wsi_00016/*.jpg' % data_root_dir, parse_label=True,
118 | label_value=0) # benign exclusively
119 | wsi_00017 = load_data_info('%s/wsi_00017/*.jpg' % data_root_dir, parse_label=True,
120 | label_value=0) # benign exclusively
121 | wsi_00018 = load_data_info('%s/wsi_00018/*.jpg' % data_root_dir, parse_label=True,
122 | label_value=0) # benign exclusively
123 |
124 | train_set = set_1010711 + set_1010712 + set_1010713 + set_1010715 + wsi_00016
125 | valid_set = set_1010716 + wsi_00018
126 | test_set = set_1010714 + wsi_00017
127 | return train_set, valid_set, test_set
128 |
129 |
130 | def prepare_colon_wsi_patch(data_visual=False):
131 | def load_data_info_from_list(data_dir, path_list):
132 | file_list = []
133 | for WSI_name in path_list:
134 | pathname = glob.glob(f'{data_dir}/{WSI_name}/*/*.png')
135 | file_list.extend(pathname)
136 | label_list = [int(file_path.split('_')[-1].split('.')[0]) - 1 for file_path in file_list]
137 | print(Counter(label_list))
138 | list_out = list(zip(file_list, label_list))
139 | return list_out
140 |
141 |     data_root_dir = '/media/data1/trinh/data/workspace_data/colon_wsi/patches_colon_edit_MD/colon_45WSIs_1144_08_step05_05'
142 |     data_visual_dir = '/media/data1/trinh/data/workspace_data/colon_wsi/patches_colon_edit_MD/colon_45WSIs_1144_01_step05_visualize/patch_512/'
143 | 
144 |     df_test = []  # Note: will be updated later
145 | 
146 |     if data_visual:
147 |         test_set = load_data_info_from_list(data_visual_dir, df_test)
148 | else:
149 | test_set = load_data_info_from_list(data_root_dir, df_test)
150 | return test_set
151 |
152 |
153 | def prepare_prostate_uhu_data():
154 | def load_data_info(pathname, parse_label=True, label_value=0, cancer_test=False):
155 | file_list = glob.glob(pathname)
156 |
157 | if cancer_test:
158 | file_list_bn = glob.glob(pathname.replace('*.jpg', '*0.jpg'))
159 | file_list = [elem for elem in file_list if elem not in file_list_bn]
160 | label_list = [int(file_path.split('_')[-1].split('.')[0])-1 for file_path in file_list]
161 | else:
162 | if parse_label:
163 | label_list = [int(file_path.split('_')[-1].split('.')[0]) for file_path in file_list]
164 | else:
165 | label_list = [label_value for file_path in file_list]
166 | print(Counter(label_list))
167 | return list(zip(file_list, label_list))
168 |
169 | data_root_dir = '/data1/trinh/data/patches_data/prostate_harvard/'
170 | data_root_dir_train = f'{data_root_dir}/train_validation_patches_750/'
171 | data_root_dir_test = f'{data_root_dir}/test_patches_750/'
172 |
173 | train_set_111 = load_data_info('%s/ZT111*/*.jpg' % data_root_dir_train)
174 | train_set_199 = load_data_info('%s/ZT199*/*.jpg' % data_root_dir_train)
175 | train_set_204 = load_data_info('%s/ZT204*/*.jpg' % data_root_dir_train)
176 | valid_set = load_data_info('%s/ZT76*/*.jpg' % data_root_dir_train)
177 | test_set = load_data_info('%s/patho_1/*/*.jpg' % data_root_dir_test)
178 |
179 | train_set = train_set_111 + train_set_199 + train_set_204
180 | return train_set, valid_set, test_set
181 |
182 |
183 | def prepare_prostate_ubc_data(fold_idx=0):
184 | def load_data_info(pathname, parse_label=True, label_value=0):
185 | file_list = glob.glob(pathname)
186 | cancer_test = False
187 | if cancer_test:
188 | file_list_bn = glob.glob(pathname.replace('*.jpg', '*0.jpg'))
189 | file_list = [elem for elem in file_list if elem not in file_list_bn]
190 | label_list = [int(file_path.split('_')[-1].split('.')[0]) for file_path in file_list]
191 | label_dict = {2: 0, 3: 1, 4: 2}
192 | label_list = [label_dict[k] for k in label_list]
193 | else:
194 | if parse_label:
195 | label_list = [int(file_path.split('_')[-1].split('.')[0]) for file_path in file_list]
196 | else:
197 | label_list = [label_value for file_path in file_list]
198 | label_dict = {0: 0, 2: 1, 3: 2, 4: 3}
199 | label_list = [label_dict[k] for k in label_list]
200 | print(Counter(label_list))
201 | return list(zip(file_list, label_list))
202 |
203 |     assert fold_idx < 3, "Currently only supports fold_idx in {0, 1, 2}; each fold is 1 TMA"
204 |
205 | data_root_dir = '/data1/trinh/data/patches_data/'
206 | data_root_dir_train_ubc = f'{data_root_dir}/prostate_miccai_2019_patches_690_80_step05_test/'
207 | test_set_ubc = load_data_info('%s/*/*.jpg' % data_root_dir_train_ubc)
208 | return test_set_ubc
209 |
210 |
211 | def visualize(ds, batch_size, nr_steps=100):
212 | data_idx = 0
213 | cmap = plt.get_cmap('jet')
214 | for i in range(0, nr_steps):
215 | if data_idx >= len(ds):
216 | data_idx = 0
217 | for j in range(1, batch_size + 1):
218 | sample = ds[data_idx + j]
219 | if len(sample) == 2:
220 | img = sample[0]
221 | else:
222 | img = sample[0]
223 | # TODO: case with multiple channels
224 | aux = np.squeeze(sample[-1])
225 | aux = cmap(aux)[..., :3] # gray to RGB heatmap
226 |             aux = (aux * 255).astype('uint8')
227 | img = np.concatenate([img, aux], axis=0)
228 | img = cv2.resize(img, (40, 80), interpolation=cv2.INTER_CUBIC)
229 | plt.subplot(1, batch_size, j)
230 | plt.title(str(sample[1]))
231 | plt.imshow(img)
232 | plt.show()
233 | data_idx += batch_size
234 |
235 |
236 |
237 |
238 |
--------------------------------------------------------------------------------
/docs/JCOL.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/trinhvg/JCO_Learning-pytorch/6a6044801ac0a2f6ea2c479d4417e1128e5a0b72/docs/JCOL.jpg
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: jco_learning
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - python=3.7
7 | - pip=20.3.1
8 | - openslide
9 | - pip:
10 | - -r file:requirements.txt
11 | - openslide-python==1.1.2
12 |
--------------------------------------------------------------------------------
/infer_produce_predict_map_wsi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import cv2
4 | import matplotlib.pyplot as plt
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import importlib
9 | import glob
10 | import numpy as np
11 | import torch.utils.data as data
12 | 
11 | import dataset
12 | from config_validator import Config
13 | from misc.infer_wsi_utils import *
14 | from loss.ceo_loss import count_pred
15 |
16 |
17 | def compute_acc(pred_, ano_):
18 | pred, ano = pred_.copy(), ano_.copy()
19 | pred = pred[ano > 0]
20 | ano = ano[ano > 0]
21 | acc = np.mean(pred == ano)
22 | return np.round(acc, 4)
23 |
24 |
25 | class Inferer(Config):
26 | def __init__(self, _args=None):
27 | super(Inferer, self).__init__(_args=_args)
28 | if _args is not None:
29 | self.__dict__.update(_args.__dict__)
30 | self.run_info = self.run_info
31 | self.net_name = self.run_info
32 | self.net_dir = self.net_dir
33 | self.in_img_path = self.in_img_path
34 | self.in_ano_path = self.in_ano_path
35 | self.in_patch = self.in_patch
36 | self.out_img_path = self.out_img_path
37 | self.net_name = self.net_name
38 | self.infer_batch_size = 256
39 | self.nr_procs_valid = 31
40 | self.patch_size = 1144
41 | self.patch_stride = 1144 // 2
42 | self.nr_classes = 4
43 |
44 | def resize_save(self, svs_code, save_name, img, scale=1.0):
45 | ano = img.copy()
46 | cmap = plt.get_cmap('jet')
47 | path = f'{self.out_img_path}/{svs_code}/'
48 | img = (cmap(img / scale)[..., :3] * 255).astype('uint8')
49 | img[ano == 0] = [10, 10, 10]
50 | img = cv2.resize(img, (0, 0), fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC)
51 | cv2.imwrite(f'{path}/{save_name}.png', cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
52 | return 0
53 |
54 | def infer_step_m(self, net, batch, net_name):
55 | net.eval() # infer mode
56 |
57 | imgs = batch # batch is NHWC
58 | imgs = imgs.permute(0, 3, 1, 2) # to NCHW
59 |
60 | # push data to GPUs and convert to float32
61 | imgs = imgs.to('cuda').float()
62 |
63 | with torch.no_grad(): # dont compute gradient
64 | logit_class, _ = net(imgs) # forward
65 | prob = nn.functional.softmax(logit_class, dim=1)
66 | # prob = prob.permute(0, 2, 3, 1) # to NHWC
67 | return prob.cpu().numpy()
68 |
69 | def infer_step_c(self, net, batch, net_name):
70 | net.eval() # infer mode
71 |
72 | imgs = batch # batch is NHWC
73 | imgs = imgs.permute(0, 3, 1, 2) # to NCHW
74 |
75 | # push data to GPUs and convert to float32
76 | imgs = imgs.to('cuda').float()
77 |
78 | with torch.no_grad(): # dont compute gradient
79 | logit_class = net(imgs) # forward
80 | prob = nn.functional.softmax(logit_class, dim=1)
81 | # prob = prob.permute(0, 2, 3, 1) # to NHWC
82 | return prob.cpu().numpy()
83 |
84 | def infer_step_r(self, net, batch, net_name):
85 | net.eval() # infer mode
86 |
87 | imgs = batch # batch is NHWC
88 | imgs = imgs.permute(0, 3, 1, 2) # to NCHW
89 |
90 | # push data to GPUs and convert to float32
91 | imgs = imgs.to('cuda').float()
92 |
93 | with torch.no_grad(): # dont compute gradient
94 | if "rank_ordinal" in net_name:
95 | logits, probas = net(imgs)
96 | predict_levels = probas > 0.5
97 | pred = torch.sum(predict_levels, dim=1)
98 | return pred.cpu().numpy()
99 | elif "rank_dorn" in net_name:
100 | pred, softmax = net(imgs)
101 | return pred.cpu().numpy()
102 | elif "soft_label" in net_name:
103 | logit_regres = net(imgs) # forward
104 | label = torch.tensor([0., 1. / 3., 2. / 3., 1.]).repeat(len(logit_regres), 1).permute(1, 0).cuda()
105 | idx = torch.argmin(torch.abs(logit_regres - label), 0)
106 | return idx.cpu().numpy()
107 | elif "FocalOrdinal" in net_name:
108 | logit_regress = net(imgs)
109 | pred = count_pred(logit_regress)
110 | return pred.cpu().numpy()
111 | else:
112 | logit_regres = net(imgs) # forward
113 | label = torch.tensor([0., 1., 2., 3.]).repeat(len(logit_regres), 1).permute(1, 0).cuda()
114 | idx = torch.argmin(torch.abs(logit_regres - label), 0)
115 | return idx.cpu().numpy()
116 |
117 | def predict_one_model(self, net, svs_code, net_name="Multi_512_mse"):
118 | infer_step = Inferer.__getattribute__(self, f'infer_step_{net_name[0].lower()}')
119 | ano = np.float32(np.load(f'{self.in_ano_path}/{svs_code}.npy')) # [h, w]
120 | inf_output_dir = f'{self.out_img_path}/{svs_code}/'
121 | if not os.path.isdir(inf_output_dir):
122 | os.makedirs(inf_output_dir)
123 |
124 | path_pairs = glob.glob(f'{self.in_patch}/{svs_code}/*/*.png')
125 | infer_dataset = dataset.DatasetSerialWSI(path_pairs)
126 | dataloader = data.DataLoader(infer_dataset,
127 | num_workers=self.nr_procs_valid,
128 | batch_size=256,
129 | shuffle=False,
130 | drop_last=False)
131 |
132 | out_prob = np.zeros([self.nr_classes, ano.shape[0], ano.shape[1]], dtype=np.float32) # [h, w]
133 | out_prob_count = np.zeros([ano.shape[0], ano.shape[1]], dtype=np.float32) # [h, w]
134 |
135 | for batch_data in dataloader:
136 | imgs_input, imgs_path = batch_data
137 | imgs_path = np.array(imgs_path).transpose(1, 0)
138 | output_prob = infer_step(net, imgs_input, net_name)
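            # patch top-left coordinates are parsed from the filename; class probabilities
            # are accumulated on a 1/16-scale canvas and overlaps averaged via out_prob_count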
139 | for idx, patch_loc in enumerate(imgs_path):
140 | patch_loc = patch_loc.astype(int) // 16
141 | patch_loc = [patch_loc[0], patch_loc[1]]
142 | out_prob_count[patch_loc[0]:patch_loc[0] + self.patch_size // 16,
143 | patch_loc[1]:patch_loc[1] + self.patch_size // 16] += 1
144 | for grade in range(self.nr_classes):
145 | out_prob[grade][patch_loc[0]:patch_loc[0] + self.patch_size // 16,
146 | patch_loc[1]:patch_loc[1] + self.patch_size // 16] += output_prob[idx][grade]
147 |
148 | out_prob_count[out_prob_count == 0.] = 1.
149 | out_prob /= out_prob_count
150 | predict = np.argmax(out_prob, axis=0) + 1
151 |
152 | for c in range(self.nr_classes):
153 | out_prob[c][ano == 0] = 0
154 | predict[ano == 0] = 0
155 |
156 | acc = compute_acc(predict, ano)
157 | print(acc)
158 |
159 | self.resize_save(svs_code, f'predict_{net_name}_{acc}', predict, scale=4.0)
160 | self.resize_save(svs_code, 'ano', ano, scale=4.0)
161 | np.save(f'{self.out_img_path}/{svs_code}/predict_{net_name}', predict)
162 | np.save(f'{self.out_img_path}/{svs_code}/ano', ano)
163 | print('done')
164 | return 0
165 |
166 | def predict_one_model_regress(self, net, svs_code, net_name="Multi_512_mse"):
167 | infer_step = Inferer.__getattribute__(self, f'infer_step_{net_name[0].lower()}')
168 | ano = np.float32(np.load(f'{self.in_ano_path}/{svs_code}.npy')) # [h, w]
169 | inf_output_dir = f'{self.out_img_path}/{svs_code}/'
170 | if not os.path.isdir(inf_output_dir):
171 | os.makedirs(inf_output_dir)
172 |
173 | path_pairs = glob.glob(f'{self.in_patch}/{svs_code}/*/*.png')
174 | infer_dataset = dataset.DatasetSerialWSI(path_pairs)
175 | dataloader = data.DataLoader(infer_dataset,
176 | num_workers=self.nr_procs_valid,
177 | batch_size=128,
178 | shuffle=False,
179 | drop_last=False)
180 | out_prob = np.zeros([self.nr_classes, ano.shape[0], ano.shape[1]], dtype=np.float32) # [h, w]
181 |
182 | for batch_data in dataloader:
183 | imgs_input, imgs_path = batch_data
184 | imgs_path = np.array(imgs_path).transpose(1, 0)
185 | output_prob = infer_step(net, imgs_input, net_name)
186 | for idx, patch_loc in enumerate(imgs_path):
187 | patch_loc = patch_loc.astype(int) // 16
188 | patch_loc = [patch_loc[0], patch_loc[1]]
189 | for grade in range(self.nr_classes):
190 | if grade == output_prob[idx]:
191 | out_prob[grade][patch_loc[0]:patch_loc[0] + self.patch_size // 16,
192 | patch_loc[1]:patch_loc[1] + self.patch_size // 16] += 1
193 | predict = np.argmax(out_prob, axis=0) + 1
194 |
195 | for c in range(self.nr_classes):
196 | out_prob[c][ano == 0] = 0
197 | predict[ano == 0] = 0
198 |
199 | acc = compute_acc(predict, ano)
200 | plt.imshow(predict)
201 | plt.show()
202 | print(acc)
203 | self.resize_save(svs_code, f'predict_{net_name}_{acc}', predict, scale=4.0)
204 | self.resize_save(svs_code, 'ano', ano, scale=4.0)
205 | np.save(f'{self.out_img_path}/{svs_code}/predict_{net_name}', predict)
206 | np.save(f'{self.out_img_path}/{svs_code}/ano', ano)
207 | print('done')
208 | return 0
209 |
210 | def run_wsi(self):
211 | device = 'cuda'
212 |
213 | self.task_type = self.net_name.split('_')[0]
214 |
215 | if "rank_dorn" in self.net_name:
216 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_rank_ordinal') # dynamic import
217 | net = net_def.jl_efficientnet(task_mode='regress_rank_dorn', pretrained=True)
218 |
219 | elif "FocalOrdinalLoss" in self.net_name:
220 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model') # dynamic import
221 | net = net_def.jl_efficientnet(task_mode='class', pretrained=True, num_classes=3)
222 | else:
223 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model') # dynamic import
224 | net = net_def.jl_efficientnet(task_mode=self.task_type.lower(), pretrained=True)
225 |
226 | net = torch.nn.DataParallel(net).to(device)
227 | inf_model_path = os.path.join(self.net_dir, self.net_name, f'trained_net.pth')
228 | saved_state = torch.load(inf_model_path)
229 | net.load_state_dict(saved_state)
230 |
231 | name_wsi_list = findExtension(self.in_ano_path, '.npy')
232 |
233 | for name in name_wsi_list:
234 | svs_code = name[:-4]
235 | print(svs_code)
236 | acc_wsi = []
237 | if 'REGRESS' in self.net_name:
238 | acc_one_model = self.predict_one_model_regress(net, svs_code, net_name=self.net_name)
239 | else:
240 | acc_one_model = self.predict_one_model(net, svs_code, net_name=self.net_name)
241 | acc_wsi.append(acc_one_model)
242 |
243 |
244 | ####
245 | if __name__ == '__main__':
246 | parser = argparse.ArgumentParser()
247 | parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
248 | parser.add_argument('--run_info', type=str, default='REGRESS_rank_dorn',
249 | help='CLASS, REGRESS, MULTI + loss, '
250 |                              'loss ex: Class_ce, MULTI_mtmr, REGRESS_rank_ordinal, REGRESS_rank_dorn, '
251 | 'REGRESS_FocalOrdinalLoss, REGRESS_soft_ordinal')
252 | parser.add_argument('--net_dir', type=str,
253 | default='/media/trinh/Data0/submit_paper_data/JL_pred/model/JL_model/JL_colon_model/',
254 | help='path to checkpoint model')
255 | parser.add_argument('--in_img_path', type=str,
256 | default='/media/data1/trinh/data/workspace_data/colon_wsi/ColonWSI/',
257 | help='path to wsi image')
258 | parser.add_argument('--in_ano_path', type=str,
259 | default='/media/data1/trinh/data/workspace_data/colon_wsi/Colon_WSI_annotation_npy/',
260 | help='path to wsi npy annotation')
261 | parser.add_argument('--in_patch', type=str,
262 | default='/media/data1/trinh/data/workspace_data/colon_wsi/patches_colon/colon_45WSIs_1144_01_step05_visualize_resize512/',
263 | help='path to patch image')
264 | parser.add_argument('--out_img_path', type=str,
265 | default='/media/data1/trinh/data/workspace_data/colon_wsi/JointLearning_wsi_pred/',
266 |                         help='path to the output prediction maps')
267 |
269 | args = parser.parse_args()
270 | inferer = Inferer(_args=args)
271 | inferer.run_wsi()
272 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/trinhvg/JCO_Learning-pytorch/6a6044801ac0a2f6ea2c479d4417e1128e5a0b72/loss/__init__.py
--------------------------------------------------------------------------------
/loss/cancer_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def mae_cancer_v0(input, target):
8 | input_ = input[target != 0]
9 | target_ = target[target != 0]
10 | return F.l1_loss(input_, target_) if len(target_) != 0 else 0
11 |
12 |
13 | # def mse_cancer(input, target):
14 | # input_ = input[target != 0]
15 | # target_ = target[target != 0]
16 | # return F.mse_loss(input_, target_) if len(target_) != 0 else 0
17 |
18 |
19 | def mse_cancer_v0(input, target):
20 | input_ = input[target != 0]
21 | target_ = target[target != 0]
22 | return F.mse_loss(input_, target_) if len(target_) != 0 else 0
23 |
24 |
25 | def ceo_cancer_v0(input, target):
26 | input_ = input[target != 0]
27 | target_ = target[target != 0]
28 | if len(target_) == 0:
29 | return 0
30 | label_ = torch.tensor([1., 2., 3.]).repeat(len(target_), 1).cuda()
31 | logit_proposed_ = input_.repeat(3, 1).permute(1, 0)
32 | logit_proposed_ = torch.abs(logit_proposed_ - label_)
33 | return F.cross_entropy(-logit_proposed_, target_ - 1)
34 |
35 | def mae_cancer(input, target):
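    # per-sample L1 loss; `select` zeroes benign samples (sign(target) == 0) and
    # randomly drops roughly half of the cancer samples before averaging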
36 | mae_loss = F.l1_loss(input, target, reduction='none')
37 | select = torch.randint(0, 2, (target.shape[0],)).float().cuda() * torch.sign(target)
38 | return (mae_loss*select).mean()
39 |
40 |
41 | # def mse_cancer(input, target):
42 | # input_ = input[target != 0]
43 | # target_ = target[target != 0]
44 | # return F.mse_loss(input_, target_) if len(target_) != 0 else 0
45 |
46 |
47 | def mse_cancer(input, target):
48 | mse_loss = F.mse_loss(input, target, reduction='none')
49 | # print(mse_loss.shape)
50 | # print(torch.sign(target).shape)
51 | # print(torch.sign(target))
52 | select = torch.randint(0, 2, (target.shape[0],)).float().cuda() * torch.sign(target)
53 | return (mse_loss*select).mean()
54 |
55 |
56 | def ceo_cancer(input, target):
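    # CEO loss over all four grades; benign samples (target == 0) are zeroed out
    # by select = sign(target) before averaging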
57 | label = torch.tensor([0., 1., 2., 3.]).repeat(len(target), 1).cuda()
58 | logit_proposed = input.repeat(4, 1).permute(1, 0)
59 | logit_proposed = torch.abs(logit_proposed - label)
60 | ceo_loss = F.cross_entropy(-logit_proposed, target, reduction='none')
61 | # select = (torch.randint(0, 2, (target.shape[0],)).cuda() * torch.sign(target)).float()
62 | select = torch.sign(target).float()
63 | return (ceo_loss*select).mean()
64 |
65 | # class CeoCancer:
66 | # def __init__(self, ):
67 | # super(CeoCancer, self).__init__()
68 |
--------------------------------------------------------------------------------
/loss/ceo_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
6 |
7 | class CEOLoss(nn.Module):
8 | """
9 | Args:
10 | num_classes (int): number of classes.
11 | """
12 | def __init__(self, num_classes=4):
13 | super(CEOLoss, self).__init__()
14 | self.num_classes = num_classes
15 | self.level = torch.arange(self.num_classes)
16 |
17 | def forward(self, x, y):
18 | """"
19 | Args:
20 | x (tensor): Regression/ordinal output, size (B), type: float
21 | y (tensor): Ground truth, size (B), type: int/long
22 |
23 | Returns:
24 | CEOLoss: Cross-Entropy Ordinal loss
25 | """
26 | levels = self.level.repeat(len(y), 1).cuda()
27 | logit = x.repeat(self.num_classes, 1).permute(1, 0)
28 | logit = torch.abs(logit - levels)
29 | return F.cross_entropy(-logit, y, reduction='mean')
30 |
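# Minimal usage sketch (assumes CUDA, since forward() moves the levels to .cuda()):
#   criterion = CEOLoss(num_classes=4)
#   x = torch.tensor([0.2, 1.1, 2.7, 0.9]).cuda()  # scalar ordinal outputs, size (B,)
#   y = torch.tensor([0, 1, 3, 1]).cuda()          # integer labels, size (B,)
#   loss = criterion(x, y)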
31 |
32 |
33 | class FocalLoss(nn.Module):
34 | def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
35 | super(FocalLoss, self).__init__()
36 | self.alpha = alpha
37 | self.gamma = gamma
38 | self.logits = logits
39 | self.reduce = reduce
40 |
41 | def forward(self, inputs, targets):
42 |         if self.logits:
43 |             BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
44 |         else:
45 |             BCE_loss = F.cross_entropy(inputs, targets, reduction='none')
46 | pt = torch.exp(-BCE_loss)
47 | F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
48 |
49 | if self.reduce:
50 | return torch.mean(F_loss)
51 | else:
52 | return F_loss
53 |
54 |
55 | class SoftLabelOrdinalLoss(nn.Module):
56 | def __init__(self, alpha=1.):
57 | super(SoftLabelOrdinalLoss, self).__init__()
58 | self.alpha = alpha
59 |
60 | def forward(self, x, y):
61 | """Validates model name.
62 |
63 | Args:
64 | x (Tensor): [0, 1, 2, 3]
65 | y (Tensor): [0, 1, 2, 3]
66 |
67 | Returns:
68 | loss: scalar
69 | """
70 | # y /= 3
71 | # y /= 2
72 | x = torch.sigmoid(x)
73 | soft_loss = -(1 - y) * torch.log(1 - x) - self.alpha * y * torch.log(x)
74 | return torch.mean(soft_loss)
75 |
76 |
77 |
78 | def label_to_levels(label, num_classes=4):
79 | levels = [1] * label + [0] * (num_classes - 1 - label)
80 | levels = torch.tensor(levels, dtype=torch.float32)
81 | return levels
82 |
83 |
84 | def labels_to_labels(class_labels, num_classes =4):
85 | """
86 | class_labels = [2, 1, 3]
87 | """
88 | levels = []
89 | for label in class_labels:
90 | levels_from_label = label_to_levels(int(label), num_classes=num_classes)
91 | levels.append(levels_from_label)
92 | return torch.stack(levels).cuda()
93 |
94 |
95 | def cost_fn(logits, label):
96 | num_classes = 3 #Note
97 | imp = torch.ones(num_classes - 1, dtype=torch.float).cuda()
98 | levels = labels_to_labels(label, num_classes)
99 | val = (-torch.sum((F.log_softmax(logits, dim=2)[:, :, 1] * levels
100 | + F.log_softmax(logits, dim=2)[:, :, 0] * (1 - levels)) * imp, dim=1))
101 | return torch.mean(val)
102 |
103 |
104 | def loss_fn2(logits, label):
105 | num_classes = 3 #Note
106 | imp = torch.ones(num_classes - 1, dtype=torch.float)
107 | levels = labels_to_labels(label)
108 | val = (-torch.sum((F.logsigmoid(logits) * levels
109 | + (F.logsigmoid(logits) - logits) * (1 - levels)) * imp,
110 | dim=1))
111 | return torch.mean(val)
112 |
113 |
114 | class FocalOrdinalLoss(nn.Module):
115 | def __init__(self, alpha=0.75, pooling=False, num_classes=4):
116 | super(FocalOrdinalLoss, self).__init__()
117 | self.alpha = alpha
118 | self.pooling = pooling
119 | self.num_classes = num_classes
120 |
121 | def forward(self, x, y):
122 | # convert one-hot y to ordinal y
123 | levels = labels_to_labels(y, num_classes=self.num_classes)
124 | q, _ = torch.max(levels*(1-x)**2 + (1-levels)*x**2, dim=1)
125 | if self.pooling:
126 | q = q.unsqueeze(0)
127 | q = q.unsqueeze(0)
128 | q = nn.MaxPool1d(3, 1, padding=1)(q)
129 | x = torch.sigmoid(x)
130 | # compute the loss
131 | f_loss = q*torch.sum(-self.alpha*levels*torch.log(x) - (1-self.alpha)*(1-levels)*torch.log(1-x))
132 | return torch.mean(f_loss)
133 |
134 |
135 |
136 |
137 |
138 | def count_pred(x):
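    # predicts the ordinal class as the length of the leading run of outputs > 0.5,
    # e.g. a row [0.7, 0.6, 0.2] -> levels [1, 1, 0] -> prediction 2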
139 | N = x.shape[0]
140 | x = x.cuda() > 0.5
141 | pred = torch.zeros(N).long().cuda()
142 | pred = pred.view(N, 1)
143 | for i in range(x.shape[1]):
144 | pred_i = x[:, :i+1].prod(1)*x[:, :i+1].sum(1)
145 | pred = torch.cat([pred, pred_i.view(N, 1)], dim =1)
146 | return pred.max(1)[0]
147 |
148 |
149 | # #
150 | # import os
151 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1'
152 | # def test():
153 | # # x = torch.Tensor([[0.7, 0.5, 0.6], [0.5, 0.8, 0.2], [0.8, 0.6, 0.1], [0.1, 0.5, 0.6]])
154 | # # y = torch.Tensor([1., 2., 3., 0.])
155 | # # x = x.to("cuda")
156 | # # y = y.to("cuda")
157 | # # FocalOrdinalLoss()(x, y)
158 | # # count_pred(x)
159 | #
160 | #
161 | # x = torch.Tensor([0.7, 2., 0.6, 1.])
162 | # y = torch.Tensor([1, 2, 3, 0])
163 | # x = x.to("cuda")
164 | # y = y.to("cuda")
165 | # CEOLoss(4)(x, y)
166 | # count_pred(x)
167 | #
168 | # test()
169 | #
170 | # #
171 | # #
172 | # #
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
--------------------------------------------------------------------------------
/loss/dorn_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | """
5 | refer to https://github.com/liviniuk/DORN_depth_estimation_Pytorch
6 | """
7 |
8 |
9 | class OrdinalLoss(nn.Module):
10 | """
11 | Ordinal loss as defined in the paper "DORN for Monocular Depth Estimation".
12 | refer to https://github.com/liviniuk/DORN_depth_estimation_Pytorch
13 | """
14 |
15 | def __init__(self):
16 | super(OrdinalLoss, self).__init__()
17 |
18 | def forward(self, pred_softmax, target_labels):
19 | """
20 | :param pred_softmax: predicted softmax probabilities P
21 | :param target_labels: ground truth ordinal labels
22 | :return: ordinal loss
23 | """
24 |
25 | n, c = pred_softmax.size() # C - number of discrete sub-intervals (= number of channels)
26 | target_labels = target_labels.int().view(n, 1)
27 |
28 | K = torch.zeros((n, c), dtype=torch.int).cuda()
29 | for i in range(c):
30 | K[:, i] = K[:, i] + i * torch.ones(n, dtype=torch.int).cuda()
31 |
32 | mask = (K <= target_labels).detach()
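        # mask[i, k] is True for every ordinal threshold k at or below the target label;
        # the loss pushes P up inside the mask and (1 - P) up outside it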
33 |
34 | loss = pred_softmax[mask].clamp(1e-8, 1e8).log().sum() + (1 - pred_softmax[~mask]).clamp(1e-8, 1e8).log().sum()
35 | loss /= -n
36 | return loss
37 |
--------------------------------------------------------------------------------
/loss/mtmr_loss.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python
2 | # coding=utf-8
3 | """
4 | https://github.com/liulihao-cuhk/MTMR-NET
5 | """
6 | import os
7 | from torch.autograd import Variable
8 | from collections import OrderedDict
9 | import torch.nn as nn
10 | import numpy as np
11 | import torch
12 | import math
13 |
14 | def get_loss_mtmr(output_score_1, cat_subtlety_score, gt_score_1, gt_attribute_score_1):
15 | xcentloss_func_1 = nn.CrossEntropyLoss()
16 | xcentloss_1 = xcentloss_func_1(output_score_1, gt_score_1)
17 |
18 | # ranking loss
19 | ranking_loss_sum = 0
20 | half_size_of_output_score = output_score_1.size()[0] // 2
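    # pair each sample i in the first half of the batch with sample i + half;
    # MarginRankingLoss gets target -1 when the two labels differ and +1 when equal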
21 | for i in range(half_size_of_output_score):
22 | tmp_output_1 = output_score_1[i]
23 | tmp_output_2 = output_score_1[i + half_size_of_output_score]
24 | tmp_gt_score_1 = gt_score_1[i]
25 | tmp_gt_score_2 = gt_score_1[i + half_size_of_output_score]
26 |
27 | rankingloss_func = nn.MarginRankingLoss()
28 |
29 | if tmp_gt_score_1.item() != tmp_gt_score_2.item():
30 | target = torch.ones(1) * -1
31 | ranking_loss_sum += rankingloss_func(tmp_output_1, tmp_output_2, Variable(target.cuda()))
32 | else:
33 | target = torch.ones(1)
34 | ranking_loss_sum += rankingloss_func(tmp_output_1, tmp_output_2, Variable(target.cuda()))
35 |
36 | ranking_loss = ranking_loss_sum / half_size_of_output_score
37 |
38 | # attribute loss
39 | attribute_mseloss_func_1 = nn.MSELoss()
40 | attribute_mseloss_1 = attribute_mseloss_func_1(cat_subtlety_score, gt_attribute_score_1.float())
41 |
42 | loss = 1 * xcentloss_1 + 5.0e-1 * ranking_loss + 1.0e-3 * attribute_mseloss_1
43 |
44 | return loss
45 |
--------------------------------------------------------------------------------
/loss/rank_ordinal_loss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import pandas as pd
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | """
9 | https://github.com/Raschka-research-group/coral-cnn/blob/master/coral-implementation-recipe.ipynb
10 | """
11 |
12 |
13 | def label_to_levels(label, num_classes=4):
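    # CORAL extended binary encoding, e.g. label=2, num_classes=4 -> [1, 1, 0]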
14 | levels = [1] * label + [0] * (num_classes - 1 - label)
15 | levels = torch.tensor(levels, dtype=torch.float32)
16 | return levels
17 |
18 |
19 | def labels_to_labels(class_labels, num_classes):
20 | """
21 | class_labels = [2, 1, 3]
22 | """
23 | levels = []
24 | for label in class_labels:
25 | levels_from_label = label_to_levels(int(label), num_classes=num_classes)
26 | levels.append(levels_from_label)
27 | return torch.stack(levels).cuda()
28 |
29 |
30 | def cost_fn(logits, label, num_classes):
31 | imp = torch.ones(num_classes - 1, dtype=torch.float).cuda()
32 | levels = labels_to_labels(label, num_classes)
33 | val = (-torch.sum((F.log_softmax(logits, dim=2)[:, :, 1] * levels
34 | + F.log_softmax(logits, dim=2)[:, :, 0] * (1 - levels)) * imp, dim=1))
35 | return torch.mean(val)
36 |
37 |
38 | def loss_fn2(logits, label):
39 | num_classes = 4
40 | imp = torch.ones(num_classes - 1, dtype=torch.float)
41 |     levels = labels_to_labels(label, num_classes)
42 | val = (-torch.sum((F.logsigmoid(logits) * levels
43 | + (F.logsigmoid(logits) - logits) * (1 - levels)) * imp,
44 | dim=1))
45 | return torch.mean(val)
46 |
--------------------------------------------------------------------------------
/misc/infer_wsi_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil # High-level file operations
3 | from itertools import chain
4 | from sklearn.metrics import f1_score
5 | import random
6 | import cv2
7 | import numpy as np
8 | import torch.utils.data as data
9 | from torchvision import transforms
10 |
11 |
12 | def color_mask(a, r, g, b):
13 | ch_r = a[..., 0] == r
14 | ch_g = a[..., 1] == g
15 | ch_b = a[..., 2] == b
16 | return ch_r & ch_g & ch_b
17 |
18 |
19 | def normalize(mask, dtype=np.uint8):
20 | return (255 * mask / np.amax(mask)).astype(dtype)
21 |
22 |
23 | def bounding_box(img):
24 | rows = np.any(img, axis=1)
25 | cols = np.any(img, axis=0)
26 | rmin, rmax = np.where(rows)[0][[0, -1]]
27 | cmin, cmax = np.where(cols)[0][[0, -1]]
28 | return rmin, rmax, cmin, cmax
29 |
30 |
31 | def cropping_center(x, crop_shape, batch=False):
32 | orig_shape = x.shape
33 | if not batch:
34 | h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
35 | w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
36 | x = x[h0:h0 + crop_shape[0], w0:w0 + crop_shape[1]]
37 | else:
38 | h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
39 | w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
40 | x = x[:, h0:h0 + crop_shape[0], w0:w0 + crop_shape[1]]
41 | return x
42 |
43 |
44 | # to make it easier for visualization
45 | def randomize_label(label_map):
46 | label_list = np.unique(label_map)
47 | label_list = label_list[1:] # exclude the background
48 |     label_rand = list(label_list)  # dup first cause shuffle is done in place
49 |     random.shuffle(label_rand)
50 |     new_map = np.zeros(label_map.shape, dtype=label_map.dtype)
51 |     for idx, lab in enumerate(label_list):  # remap each label to its shuffled counterpart
52 |         new_map[label_map == lab] = label_rand[idx]
53 |     return new_map
51 |
52 |
53 | """Recursive directory creation function. Like mkdir(),
54 | but makes all intermediate-level directories needed to contain the leaf directory.
55 | A leaf is a node on a tree with no child nodes."""
56 |
57 |
58 | def rm_n_mkdir(dir):
59 | if os.path.isdir(dir):
60 | shutil.rmtree(dir)
61 | os.makedirs(dir)
62 |
63 |
64 | ###
65 | # test
66 |
67 | # import cv2
68 | # import matplotlib.pyplot as plt
69 | #
70 | # img = cv2.imread('/media/vtltrinh/Data1/COLON_MANUAL_PATCHES/v1/1010711/000_3.jpg')
71 | # im = np.array(img)
72 | # im_mask = color_mask(im, 1, 1, 1)
73 | #
74 | # bound = bounding_box(im)
75 | # print(bound)
76 |
77 | def findExtension(directory, extension='.txt'):
78 | files = []
79 | for file in os.listdir(directory):
80 | if file.endswith(extension):
81 | files += [file]
82 | files.sort()
83 | return files
84 |
85 |
86 | def generate_patch_list_(roi, patch_size, stride):
87 | min_height, min_width, max_height, max_width = roi
88 | min_height, min_width, max_height, max_width = min_height - stride, min_width - stride, max_height + stride, max_width + stride
89 | h_list = np.arange(min_height, max_height - patch_size, stride)
90 | w_list = np.arange(min_width, max_width - patch_size, stride)
91 | out = [[[h_list[h], w_list[w]] for w in range(len(w_list))] for h in range(len(h_list))]
92 | return list(chain(*out))
93 |
94 |
95 | def generate_patch_list(ano, roi, patch_size, stride):
96 | min_height, min_width, max_height, max_width = roi
97 | min_height, min_width, max_height, max_width = min_height - stride, min_width - stride, max_height + stride, max_width + stride
98 | h_list = np.arange(min_height, max_height - patch_size, stride)
99 | w_list = np.arange(min_width, max_width - patch_size, stride)
100 | out = [[[h_list[h], w_list[w]] for w in range(len(w_list))] for h in range(len(h_list))]
101 | path_list = list(chain(*out))
102 | # print(len(path_list))
103 | infer_dataset = DatasetSelectPatch(ano, path_list, patch_size)
104 | path_loader = data.DataLoader(infer_dataset, num_workers=31, batch_size=1144, shuffle=False, drop_last=False)
105 | for keeps, loca in path_loader:
106 | keeps_ = keeps.to('cuda')
107 | keeps_ += 1
108 | for idx in range(len(keeps)):
109 | if keeps[idx] == 1:
110 | a = eval(loca[idx])
111 | path_list.remove(a)
112 | # print('hi', len(path_list))
113 | return path_list
114 |
115 |
116 | def read_ano_text(text_path):
117 | list_labels = {
118 | "BG": 0,
119 | "BN": 1,
120 | "WD": 2,
121 | "MD": 3,
122 | "PD": 4,
123 | "Ad": 5,
124 | }
125 | text_file = open(text_path, "r")
126 | lines = text_file.readlines()
127 | lines = [line.replace('\n', '').replace('\t', '') for line in lines]
128 | anos_dict = {}
129 | count_ROIs = np.zeros(shape=5, dtype=int)
130 | for label in list_labels:
131 | anos_dict.__setitem__(label, {})
132 |
133 | for line in lines[1:-1]:
134 | if line[1:3] in list_labels:
135 | label_id = line[1:3]
136 | coordinates = []
137 | count_ROIs[list_labels[label_id] - 1] += 1
138 | ROIs_id = count_ROIs[list_labels[label_id] - 1]
139 | else:
140 | if 'X' in line:
141 | dims_val = eval(line.replace("},", "}"))
142 | coordinates.append([int(dims_val[dim]) for dim in dims_val.keys()])
143 | else:
144 | anos_dict[label_id].__setitem__(ROIs_id, coordinates)
145 |
146 | keys_to_remove = ["BG", "Ad"]
147 | for key in keys_to_remove:
148 | del anos_dict[key]
149 | return anos_dict
150 |
151 |
152 | def find_roi(anos_dict):
153 | min_height = []
154 | min_width = []
155 | max_height = []
156 | max_width = []
157 | valid_ano = ['BN', 'WD', 'MD', 'PD']
158 | for label_key in anos_dict.keys():
159 | if label_key in valid_ano:
160 | for polygon_key in anos_dict[label_key]:
161 | region = anos_dict[label_key][polygon_key]
162 |                 min_height.append(np.int32([region])[0, :, 1].min())  # np uses (height, width) while openslide uses (width, height)
163 |                 min_width.append(np.int32([region])[0, :, 0].min())
164 |                 max_height.append(np.int32([region])[0, :, 1].max())  # np uses (height, width) while openslide uses (width, height)
165 | max_width.append(np.int32([region])[0, :, 0].max())
166 | min_height = min(min_height)
167 | min_width = min(min_width)
168 | max_height = max(max_height)
169 | max_width = max(max_width)
170 | return [min_height, min_width, max_height, max_width]
171 |
172 |
173 | def compute_f1(pred, ano):
174 | pred, ano = pred.flatten(), ano.flatten()
175 | pred = pred[ano != 0]
176 | ano = ano[ano != 0]
177 | f1 = f1_score(ano, pred, average='macro', labels=np.unique(ano))
178 | return int(f1 * 10000)
179 |
180 |
181 | class DatasetSelectPatch(data.Dataset):
182 | def __init__(self, ano, path_list, patch_size):
183 | self.ano = ano
184 | self.path_list = path_list
185 | self.patch_size = patch_size
186 |
187 | def __getitem__(self, idx):
188 | w = self.path_list[idx][0]//16
189 | h = self.path_list[idx][1]//16
190 | patch_size = self.patch_size//16
191 | input_img = self.ano[w: w + patch_size, h: h + patch_size]
192 |
193 | if input_img.size == 0:
194 | keep = np.array([0])
195 | elif input_img.mean() > 0:
196 | keep = np.array([1])
197 | else:
198 | keep = np.array([0])
199 | return keep, str(self.path_list[idx])
200 |
201 | def __len__(self):
202 | return len(self.path_list)
203 |
--------------------------------------------------------------------------------
/misc/train_ultils_all_iter.py:
--------------------------------------------------------------------------------
1 | import io
2 | import itertools
3 | import json
4 | import os
5 |
6 | import random
7 | import re
8 | import shutil
9 | import textwrap
10 |
11 | import cv2
12 | import matplotlib
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | import pandas as pd
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from sklearn.metrics import confusion_matrix
20 | from termcolor import colored
21 |
25 | import imgaug as ia
26 |
27 |
28 | def check_manual_seed(seed):
29 | """
30 | If manual seed is not specified, choose a random one and notify it to the user
31 | """
32 | seed = seed
33 | random.seed(seed)
34 | np.random.seed(seed)
35 | torch.manual_seed(seed)
36 | torch.cuda.manual_seed(seed)
37 | ia.seed(seed)
38 | torch.cuda.manual_seed_all(seed)
39 | torch.backends.cudnn.benchmark = False
40 | torch.backends.cudnn.deterministic = True
41 |
42 | print('Using manual seed: {seed}'.format(seed=seed))
43 | return
44 |
45 |
46 | def check_log_dir(log_dir):
47 | # check if log dir exist
48 | if os.path.isdir(log_dir):
49 |         color_word = colored('WARNING', color='red', attrs=['bold', 'blink'])
50 | print('%s: %s exist!' % (color_word, colored(log_dir, attrs=['underline'])))
51 | while (True):
52 | print('Select Action: d (delete)/ q (quit)', end='')
53 | key = input()
54 | if key == 'd':
55 | shutil.rmtree(log_dir)
56 | break
57 | elif key == 'q':
58 | exit()
59 | else:
60 | color_word = colored('ERR', color='red')
61 | print('---[%s] Unrecognized character!' % color_word)
62 | return
63 |
64 |
65 | def plot_confusion_matrix(conf_mat, label):
66 | """
67 | Parameters:
68 | title='Confusion matrix' : Title for your matrix
69 | tensor_name = 'MyFigure/image' : Name for the output summay tensor
70 | Returns:
71 | summary: image of plot figure
72 | Other items to note:
73 | - Depending on the number of category and the data , you may have to modify the figzie, font sizes etc.
74 | - Currently, some of the ticks dont line up due to rotations.
75 | """
76 |
77 | cm = conf_mat
78 |
79 | np.set_printoptions(precision=2) # print numpy array with 2 decimal places
80 |
81 | fig = matplotlib.figure.Figure(figsize=(7, 7), dpi=320, facecolor='w', edgecolor='k')
82 | ax = fig.add_subplot(1, 1, 1)
83 | im = ax.imshow(cm, cmap='Oranges')
84 |
85 | classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in label]
86 | classes = ['\n'.join(textwrap.wrap(l, 40)) for l in classes]
87 |
88 | tick_marks = np.arange(len(classes))
89 |
90 | ax.set_xlabel('Predicted', fontsize=7)
91 | ax.set_xticks(tick_marks)
92 | c = ax.set_xticklabels(classes, fontsize=4, rotation=-90, ha='center')
93 | ax.xaxis.set_label_position('bottom')
94 | ax.xaxis.tick_bottom()
95 |
96 | ax.set_ylabel('True Label', fontsize=7)
97 | ax.set_yticks(tick_marks)
98 | ax.set_yticklabels(classes, fontsize=4, va='center')
99 | ax.yaxis.set_label_position('left')
100 | ax.yaxis.tick_left()
101 |
102 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
103 | ax.text(j, i, format(cm[i, j], 'd') if cm[i, j] != 0 else '.',
104 | horizontalalignment="center", fontsize=6,
105 | verticalalignment='center', color="black")
106 | fig.set_tight_layout(True)
107 |
108 | fig.canvas.draw()
109 | w, h = fig.canvas.get_width_height()
110 |
111 | # get PNG data from the figure
112 | png_buffer = io.BytesIO()
113 | fig.canvas.print_png(png_buffer)
114 | png_encoded = png_buffer.getvalue()
115 | png_buffer.close()
116 |
117 | return png_encoded
118 |
119 |
120 | ####
121 | def update_log(output, epoch, prefix, color, tfwriter, log_file, logging):
122 | # print values and convert
123 | max_length = len(max(output.keys(), key=len))
124 | for metric in output:
125 | key = colored(prefix + '-' + metric.ljust(max_length), color)
126 | print('------%s : ' % key, end='')
127 | if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
128 | print('%0.7f' % output[metric])
129 | elif metric == 'conf_mat_c':
130 | conf_mat_c = output['conf_mat_c'] # use pivot to turn back
131 | conf_mat_c_df = pd.DataFrame(conf_mat_c)
132 | conf_mat_c_df.index.name = 'True'
133 | conf_mat_c_df.columns.name = 'Pred'
134 | output['conf_mat_c'] = conf_mat_c_df
135 | print('\n', conf_mat_c_df)
136 | elif metric == 'conf_mat_r':
137 | conf_mat_r = output['conf_mat_r'] # use pivot to turn back
138 | conf_mat_r_df = pd.DataFrame(conf_mat_r)
139 | conf_mat_r_df.index.name = 'True'
140 | conf_mat_r_df.columns.name = 'Pred'
141 | output['conf_mat_r'] = conf_mat_r_df
142 | print('\n', conf_mat_r_df)
143 | elif metric == 'box_plot_data':
144 | box_plot_data = output['box_plot_data'] # use pivot to turn back
145 | box_plot_data_df = pd.DataFrame(box_plot_data)
146 | box_plot_data_df.columns.name = 'Pred'
147 | output['box_plot_data'] = box_plot_data_df
148 |
149 | if not logging:
150 | return
151 |
152 | # create stat dicts
153 | stat_dict = {}
154 | for metric in output:
155 | if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
156 | metric_value = output[metric]
157 |         elif metric == 'conf_mat_c':
158 |             # unstack the confusion-matrix DataFrame (built in the printing
159 |             # loop above) into long form and serialize that, instead of
160 |             # relying on a variable leaked out of that loop
161 |             conf_mat_df = output['conf_mat_c'].unstack().rename('value').reset_index()
162 |             metric_value = conf_mat_df.to_json(orient='records')
163 |         elif metric == 'conf_mat_r':
164 |             conf_mat_regres_df = output['conf_mat_r'].unstack().rename('value').reset_index()
165 |             metric_value = conf_mat_regres_df.to_json(orient='records')
166 |         elif metric == 'box_plot_data':
167 |             # the box-plot payload was already converted to a DataFrame above
168 |             box_plot_data_df = output['box_plot_data']
169 |             metric_value = box_plot_data_df.to_json(orient='records')
170 | stat_dict['%s-%s' % (prefix, metric)] = metric_value
171 |
172 | # json stat log file, update and overwrite
173 | with open(log_file) as json_file:
174 | json_data = json.load(json_file)
175 |
176 | current_epoch = str(epoch)
177 | if current_epoch in json_data:
178 | old_stat_dict = json_data[current_epoch]
179 | stat_dict.update(old_stat_dict)
180 | current_epoch_dict = {current_epoch: stat_dict}
181 | json_data.update(current_epoch_dict)
182 |
183 | with open(log_file, 'w') as json_file:
184 | json.dump(json_data, json_file)
185 |
186 |     # log scalar values to tensorboard (pass the int step, not its str form)
187 |     for metric in output:
188 |         if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
189 |             tfwriter.add_scalar(prefix + '-' + metric, output[metric], epoch)
190 |
191 |
192 | ####
193 | def log_train_ema_results(engine, info):
194 | """
195 | running training measurement
196 | """
197 |     training_ema_output = engine.state.metrics
198 | training_ema_output['lr'] = float(info['optimizer'].param_groups[0]['lr'])
199 | update_log(training_ema_output, engine.state.iteration, 'train-ema', 'green',
200 | info['tfwriter'], info['json_file'], info['logging'])
201 |
202 |
203 | ####
204 | def process_accumulated_output_multi(output, batch_size, nr_classes):
205 | #
206 | def uneven_seq_to_np(seq):
207 | item_count = batch_size * (len(seq) - 1) + len(seq[-1])
208 | cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
209 |         # the last batch may be smaller than batch_size, so copy it separately
210 |         for idx in range(0, len(seq) - 1):
211 |             cat_array[idx * batch_size: (idx + 1) * batch_size] = seq[idx]
212 |         idx = -1 if len(seq) == 1 else idx  # if len(seq) == 1 the loop above was skipped
213 | cat_array[(idx + 1) * batch_size:] = seq[-1]
214 | return cat_array
215 |
216 | proc_output = dict()
217 | true = uneven_seq_to_np(output['true'])
218 | # threshold then get accuracy
219 | if 'logit_c' in output.keys():
220 | logit_c = uneven_seq_to_np(output['logit_c'])
221 | pred_c = np.argmax(logit_c, axis=-1)
222 | acc_c = np.mean(pred_c == true)
223 | # confusion matrix
224 | conf_mat_c = confusion_matrix(true, pred_c, labels=np.arange(nr_classes))
225 | proc_output.update(acc_c=acc_c, conf_mat_c=conf_mat_c, )
226 | if 'logit_r' in output.keys():
227 | logit_r = uneven_seq_to_np(output['logit_r'])
228 | label = np.transpose(np.array([[0., 1., 2., 3.]]).repeat(len(true), axis=0), (1, 0))
229 | pred_r = np.argmin(abs((logit_r - label)), axis=0)
230 | acc_r = np.mean(pred_r == true)
231 | # confusion matrix
232 | conf_mat_r = confusion_matrix(true, pred_r, labels=np.arange(nr_classes))
233 | proc_output.update(acc_r=acc_r, conf_mat_r=conf_mat_r)
234 |
235 | # proc_output.update(box_plot_data=np.concatenate(
236 | # [true[np.newaxis, :], pred_c[np.newaxis, :], pred_r[np.newaxis, :], logit_r.transpose(1, 0)], 0))
237 | return proc_output
238 |
239 |
240 | ####
241 | def inference(engine, inferer, prefix, dataloader, info):
242 | """
243 | inference measurement
244 | """
245 | inferer.accumulator = {metric: [] for metric in info['metric_names']}
246 | inferer.run(dataloader)
247 | output_stat = process_accumulated_output_multi(inferer.accumulator,
248 | info['infer_batch_size'], info['nr_classes'])
249 | update_log(output_stat, engine.state.iteration, prefix, 'red',
250 | info['tfwriter'], info['json_file'], info['logging'])
251 | return
252 |
253 |
254 | ####
255 | def accumulate_outputs(engine):
256 | batch_output = engine.state.output
257 | for key, item in batch_output.items():
258 | engine.accumulator[key].extend([item])
259 | return
260 |
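261 | # Minimal sketch (hypothetical) of what process_accumulated_output_multi
262 | # computes: it stitches uneven batches back into one array, then decodes the
263 | # regression head by snapping each scalar output to the nearest of the
264 | # anchor labels 0..3.
265 | # if __name__ == '__main__':
266 | #     accumulated = {
267 | #         'true': [np.array([0, 1, 2]), np.array([3])],        # batch_size=3, smaller last batch
268 | #         'logit_c': [np.eye(4)[[0, 1, 2]], np.eye(4)[[3]]],   # one-hot "logits"
269 | #         'logit_r': [np.array([0.2, 0.9, 2.4]), np.array([2.8])],
270 | #     }
271 | #     stats = process_accumulated_output_multi(accumulated, batch_size=3, nr_classes=4)
272 | #     print(stats['acc_c'], stats['acc_r'])   # -> 1.0 1.0 (2.8 snaps to class 3)
273 | 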
--------------------------------------------------------------------------------
/misc/train_ultils_validator.py:
--------------------------------------------------------------------------------
1 | import io
2 | import itertools
3 | import json
4 | import os
5 |
6 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
7 | import random
8 | import re
9 | import shutil
10 | import textwrap
11 |
12 | import cv2
13 | import matplotlib
14 | import matplotlib.pyplot as plt
15 | import numpy as np
16 | import pandas as pd
17 | import torch
18 | from sklearn.metrics import confusion_matrix
19 | from termcolor import colored
20 |
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 | 
24 | import imgaug as ia
25 | from scipy.special import softmax
26 | from sklearn.metrics import classification_report
27 |
28 | def check_manual_seed(seed):
29 |     """
30 |     Seed every RNG in use (python, numpy, torch CPU/CUDA, imgaug) and force
31 |     cuDNN into deterministic mode so that runs are reproducible.
32 |     """
33 | random.seed(seed)
34 | np.random.seed(seed)
35 | torch.manual_seed(seed)
36 | torch.cuda.manual_seed(seed)
37 | ia.seed(seed)
38 | torch.cuda.manual_seed_all(seed)
39 | torch.backends.cudnn.benchmark = False
40 | torch.backends.cudnn.deterministic = True
41 |
42 | print('Using manual seed: {seed}'.format(seed=seed))
43 | return
44 |
45 |
46 | def check_log_dir(log_dir):
47 | # check if log dir exist
48 | if os.path.isdir(log_dir):
49 |         color_word = colored('WARNING', color='red', attrs=['bold', 'blink'])
50 |         print('%s: %s exists!' % (color_word, colored(log_dir, attrs=['underline'])))
51 |         while True:
52 |             print('Select Action: d (delete) / q (quit): ', end='')
53 | key = input()
54 | if key == 'd':
55 | shutil.rmtree(log_dir)
56 | break
57 | elif key == 'q':
58 | exit()
59 | else:
60 | color_word = colored('ERR', color='red')
61 | print('---[%s] Unrecognized character!' % color_word)
62 | return
63 |
64 |
65 | def plot_confusion_matrix(conf_mat, label):
66 | """
67 | Parameters:
68 | title='Confusion matrix' : Title for your matrix
69 |         tensor_name = 'MyFigure/image' : Name for the output summary tensor
70 |     Returns:
71 |         summary: image of plot figure
72 |     Other items to note:
73 |         - Depending on the number of categories and the data, you may have to modify the figsize, font sizes etc.
74 |         - Currently, some of the ticks don't line up due to rotations.
75 | """
76 |
77 | cm = conf_mat
78 |
79 | np.set_printoptions(precision=2) # print numpy array with 2 decimal places
80 |
81 | fig = matplotlib.figure.Figure(figsize=(7, 7), dpi=320, facecolor='w', edgecolor='k')
82 | ax = fig.add_subplot(1, 1, 1)
83 | im = ax.imshow(cm, cmap='Oranges')
84 |
85 | classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in label]
86 | classes = ['\n'.join(textwrap.wrap(l, 40)) for l in classes]
87 |
88 | tick_marks = np.arange(len(classes))
89 |
90 | ax.set_xlabel('Predicted', fontsize=7)
91 | ax.set_xticks(tick_marks)
92 | c = ax.set_xticklabels(classes, fontsize=4, rotation=-90, ha='center')
93 | ax.xaxis.set_label_position('bottom')
94 | ax.xaxis.tick_bottom()
95 |
96 | ax.set_ylabel('True Label', fontsize=7)
97 | ax.set_yticks(tick_marks)
98 | ax.set_yticklabels(classes, fontsize=4, va='center')
99 | ax.yaxis.set_label_position('left')
100 | ax.yaxis.tick_left()
101 |
102 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
103 | ax.text(j, i, format(cm[i, j], 'd') if cm[i, j] != 0 else '.',
104 | horizontalalignment="center", fontsize=6,
105 | verticalalignment='center', color="black")
106 | fig.set_tight_layout(True)
107 |
108 | fig.canvas.draw()
109 | w, h = fig.canvas.get_width_height()
110 |
111 | # get PNG data from the figure
112 | png_buffer = io.BytesIO()
113 | fig.canvas.print_png(png_buffer)
114 | png_encoded = png_buffer.getvalue()
115 | png_buffer.close()
116 |
117 | return png_encoded
118 |
119 |
120 | ####
121 | def update_log(output, epoch, net_name, prefix, color, tfwriter, log_file, logging):
122 | # print values and convert
123 | max_length = len(max(output.keys(), key=len))
124 | for metric in output:
125 | key = colored(prefix + '-' + metric.ljust(max_length), color)
126 | print('------%s : ' % key, end='')
127 | if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
128 | print('%0.7f' % output[metric])
129 | elif metric == 'conf_mat_c':
130 | conf_mat_c = output['conf_mat_c'] # use pivot to turn back
131 | conf_mat_c_df = pd.DataFrame(conf_mat_c)
132 | conf_mat_c_df.index.name = 'True'
133 | conf_mat_c_df.columns.name = 'Pred'
134 | output['conf_mat_c'] = conf_mat_c_df
135 | print('\n', conf_mat_c_df)
136 | elif metric == 'conf_mat_r':
137 | conf_mat_r = output['conf_mat_r'] # use pivot to turn back
138 | conf_mat_r_df = pd.DataFrame(conf_mat_r)
139 | conf_mat_r_df.index.name = 'True'
140 | conf_mat_r_df.columns.name = 'Pred'
141 | output['conf_mat_r'] = conf_mat_r_df
142 | print('\n', conf_mat_r_df)
143 | elif metric == 'box_plot_data':
144 | box_plot_data = output['box_plot_data'] # use pivot to turn back
145 | box_plot_data_df = pd.DataFrame(box_plot_data)
146 | box_plot_data_df.columns.name = 'Pred'
147 | output['box_plot_data'] = box_plot_data_df
148 |
149 | if not logging:
150 | return
151 |
152 | # create stat dicts
153 | stat_dict = {}
154 | for metric in output:
155 | if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
156 | metric_value = output[metric]
157 |         elif metric == 'conf_mat_c':
158 |             # unstack the confusion-matrix DataFrame (built in the printing
159 |             # loop above) into long form and serialize that, instead of
160 |             # relying on a variable leaked out of that loop
161 |             conf_mat_df = output['conf_mat_c'].unstack().rename('value').reset_index()
162 |             metric_value = conf_mat_df.to_json(orient='records')
163 |         elif metric == 'conf_mat_r':
164 |             conf_mat_regres_df = output['conf_mat_r'].unstack().rename('value').reset_index()
165 |             metric_value = conf_mat_regres_df.to_json(orient='records')
166 |         elif metric == 'box_plot_data':
167 |             # the box-plot payload was already converted to a DataFrame above
168 |             box_plot_data_df = output['box_plot_data']
169 |             metric_value = box_plot_data_df.to_json(orient='records')
170 | stat_dict['%s-%s' % (prefix, metric)] = metric_value
171 |
172 | # json stat log file, update and overwrite
173 | with open(log_file) as json_file:
174 | json_data = json.load(json_file)
175 |
176 |     current_model = str(net_name)
177 |     # this validator log is keyed by model name rather than by epoch; the
178 |     # original guarded on the epoch key, which could KeyError just below
179 |     if current_model in json_data:
180 |         old_stat_dict = json_data[current_model]
181 |         stat_dict.update(old_stat_dict)
182 |     json_data.update({current_model: stat_dict})
183 |
184 | with open(log_file, 'w') as json_file:
185 | json.dump(json_data, json_file)
186 |
187 |     # log scalar values to tensorboard (pass the int step, not its str form)
188 |     for metric in output:
189 |         if metric not in ['conf_mat_c', 'conf_mat_r', 'box_plot_data']:
190 |             tfwriter.add_scalar(prefix + '-' + metric, output[metric], epoch)
191 |
192 |
193 | ####
194 | def log_train_ema_results(engine, info):
195 | """
196 | running training measurement
197 | """
198 |     training_ema_output = engine.state.metrics
199 | training_ema_output['lr'] = float(info['optimizer'].param_groups[0]['lr'])
200 |     update_log(training_ema_output, engine.state.epoch, 'train-ema', 'train-ema',  # net_name, prefix
201 |                'green', info['tfwriter'], info['json_file'], info['logging'])
202 |
203 |
204 | ####
205 | def process_accumulated_output_multi(output, batch_size, nr_classes):
206 | #
207 | def uneven_seq_to_np(seq):
208 | item_count = batch_size * (len(seq) - 1) + len(seq[-1])
209 | cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
210 |         # a single accumulated batch can be returned as-is
211 |         if len(seq) < 2:
212 |             return seq[0]
213 | for idx in range(0, len(seq) - 1):
214 | cat_array[idx * batch_size:
215 | (idx + 1) * batch_size] = seq[idx]
216 | cat_array[(idx + 1) * batch_size:] = seq[-1]
217 | return cat_array
218 |
219 | proc_output = dict()
220 | true = uneven_seq_to_np(output['true'])
221 | # threshold then get accuracy
222 | if 'logit_c' in output.keys():
223 | logit_c = uneven_seq_to_np(output['logit_c'])
224 | pred_c = np.argmax(logit_c, axis=-1)
225 | # pred_c = [covert_dict[pred_c[idx]] for idx in range(len(pred_c))]
226 | acc_c = np.mean(pred_c == true)
227 |             print('acc_c', acc_c)
228 | # confusion matrix
229 | conf_mat_c = confusion_matrix(true, pred_c, labels=np.arange(nr_classes))
230 | proc_output.update(acc_c=acc_c, conf_mat_c=conf_mat_c,)
231 | if 'logit_r' in output.keys():
232 | logit_r = uneven_seq_to_np(output['logit_r'])
233 | label = np.transpose(np.array([[0., 1., 2., 3.]]).repeat(len(true), axis=0), (1, 0))
234 | pred_r = np.argmin(abs((logit_r - label)), axis=0)
235 | # pred_r = [covert_dict[pred_r[idx]] for idx in range(len(pred_r))]
236 | acc_r = np.mean(pred_r == true)
237 | # print(acc_r)
238 | # confusion matrix
239 | conf_mat_r = confusion_matrix(true, pred_r, labels=np.arange(nr_classes))
240 | proc_output.update(acc_r=acc_r, conf_mat_r=conf_mat_r)
241 |
242 | # proc_output.update(box_plot_data=np.concatenate(
243 | # [true[np.newaxis, :], pred_c[np.newaxis, :], pred_r[np.newaxis, :], logit_r.transpose(1, 0)], 0))
244 | return proc_output
245 |
246 | def process_accumulated_output_multi_mix(output, batch_size, nr_classes):
247 | #
248 | def uneven_seq_to_np(seq):
249 | item_count = batch_size * (len(seq) - 1) + len(seq[-1])
250 | cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
251 |         # a single accumulated batch can be returned as-is
252 |         if len(seq) < 2:
253 |             return seq[0]
254 | for idx in range(0, len(seq) - 1):
255 | cat_array[idx * batch_size:
256 | (idx + 1) * batch_size] = seq[idx]
257 | cat_array[(idx + 1) * batch_size:] = seq[-1]
258 | return cat_array
259 |
260 | proc_output = dict()
261 | true = uneven_seq_to_np(output['true'])
262 | # threshold then get accuracy
263 | if 'logit_c' in output.keys():
264 | logit_c = uneven_seq_to_np(output['logit_c'])
265 |
266 | pred_c = np.argmax(logit_c, axis=-1)
267 | # pred_c = [covert_dict[pred_c[idx]] for idx in range(len(pred_c))]
268 | acc_c = np.mean(pred_c == true)
269 |             print('acc_c', acc_c)
270 | # print(classification_report(true, pred_c, labels=[0, 1, 2, 3]))
271 | # confusion matrix
272 | conf_mat_c = confusion_matrix(true, pred_c, labels=np.arange(nr_classes))
273 | proc_output.update(acc_c=acc_c, conf_mat_c=conf_mat_c,)
274 | if 'logit_r' in output.keys():
275 | logit_r = uneven_seq_to_np(output['logit_r'])
276 | label = np.transpose(np.array([[0., 1., 2., 3.]]).repeat(len(true), axis=0), (1, 0))
277 | pred_r = np.argmin(abs((logit_r - label)), axis=0)
278 | # pred_r = [covert_dict[pred_r[idx]] for idx in range(len(pred_r))]
279 | acc_r = np.mean(pred_r == true)
280 |             print('acc_r', acc_r)
281 | # print(classification_report(true, pred_r, labels=[0, 1, 2, 3]))
282 | # confusion matrix
283 | conf_mat_r = confusion_matrix(true, pred_r, labels=np.arange(nr_classes))
284 | proc_output.update(acc_r=acc_r, conf_mat_r=conf_mat_r)
285 |
286 | # if ('logit_r' in output.keys()) and ('logit_c' in output.keys()):
287 | # a = abs((logit_r - label)).transpose(1, 0)
288 | # prob_r = softmax(-a, 1)
289 | # logit_c +=prob_r
290 | # pred_c = np.argmax(logit_c, axis=-1)
291 | # acc_c = np.mean(pred_c == true)
292 | # print('acc_mix',acc_c)
293 |
294 | # proc_output.update(box_plot_data=np.concatenate(
295 | # [true[np.newaxis, :], pred_c[np.newaxis, :], pred_r[np.newaxis, :], logit_r.transpose(1, 0)], 0))
296 | return proc_output
297 |
298 | def process_accumulated_output_multi_testAUG(output, batch_size, nr_classes):
299 | #
300 | def uneven_seq_to_np(seq):
301 | item_count = batch_size * (len(seq) - 1) + len(seq[-1])
302 | cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
303 |         # the last batch may be smaller than batch_size, so copy it separately
304 | for idx in range(0, len(seq) - 1):
305 | cat_array[idx * batch_size:
306 | (idx + 1) * batch_size] = seq[idx]
307 | cat_array[(idx + 1) * batch_size:] = seq[-1]
308 | return cat_array
309 |
310 | proc_output = dict()
311 | true = uneven_seq_to_np(output['true'])
312 | # threshold then get accuracy
313 | if 'pred_c' in output.keys():
314 | pred_c = uneven_seq_to_np(output['pred_c'])
315 | acc_c = np.mean(pred_c == true)
316 | # confusion matrix
317 | conf_mat_c = confusion_matrix(true, pred_c, labels=np.arange(nr_classes))
318 | proc_output.update(acc_c=acc_c, conf_mat_c=conf_mat_c,)
319 | if 'pred_r' in output.keys():
320 | pred_r = uneven_seq_to_np(output['pred_r'])
321 | acc_r = np.mean(pred_r == true)
322 | # confusion matrix
323 | conf_mat_r = confusion_matrix(true, pred_r, labels=np.arange(nr_classes))
324 | proc_output.update(acc_r=acc_r, conf_mat_r=conf_mat_r)
325 | return proc_output
326 |
327 |
328 | ####
329 | def inference(engine, inferer, prefix, dataloader, info):
330 | """
331 | inference measurement
332 | """
333 | inferer.accumulator = {metric: [] for metric in info['metric_names']}
334 | inferer.run(dataloader)
335 | output_stat = process_accumulated_output_multi(inferer.accumulator,
336 | info['infer_batch_size'], info['nr_classes'])
337 |     update_log(output_stat, engine.state.epoch, prefix, prefix, 'red',  # reuse prefix as net_name
338 |                info['tfwriter'], info['json_file'], info['logging'])
339 | return
340 |
341 |
342 | ####
343 | def accumulate_outputs(engine):
344 | batch_output = engine.state.output
345 | for key, item in batch_output.items():
346 | engine.accumulator[key].extend([item])
347 | return
348 |
349 |
350 | def accumulate_predict(pred_patch):
351 | unique, counts = np.unique(pred_patch.cpu(), return_counts=True)
352 | pred_count = dict(zip(unique, counts))
353 | patch_label = max(pred_count, key=pred_count.get)
354 | return patch_label
355 |
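356 | # Minimal sketch (hypothetical): accumulate_predict majority-votes the
357 | # patch-level predictions of one image into a single image-level label.
358 | # if __name__ == '__main__':
359 | #     pred_patch = torch.tensor([2, 2, 1, 2, 0])
360 | #     print(accumulate_predict(pred_patch))   # -> 2 (the most frequent class)
361 | 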
--------------------------------------------------------------------------------
/model_lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/trinhvg/JCO_Learning-pytorch/6a6044801ac0a2f6ea2c479d4417e1128e5a0b72/model_lib/__init__.py
--------------------------------------------------------------------------------
/model_lib/efficientnet_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.7.0"
2 | from .model import EfficientNet
3 | from .utils import (
4 | GlobalParams,
5 | BlockArgs,
6 | BlockDecoder,
7 | efficientnet,
8 | get_model_params,
9 | )
10 |
11 |
--------------------------------------------------------------------------------
/model_lib/efficientnet_pytorch/model.py:
--------------------------------------------------------------------------------
1 | """model.py - Model and module class for EfficientNet.
2 | They are built to mirror those in the official TensorFlow implementation.
3 | """
4 |
5 | # Author: lukemelas (github username)
6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7 | # With adjustments and added comments by workingcoder (github username).
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from torchsummary import summary  # used only by the commented-out _test() below
13 | from model_lib.efficientnet_pytorch.utils import (
14 | round_filters,
15 | round_repeats,
16 | drop_connect,
17 | get_same_padding_conv2d,
18 | get_model_params,
19 | efficientnet_params,
20 | load_pretrained_weights,
21 | Swish,
22 | MemoryEfficientSwish,
23 | calculate_output_image_size
24 | )
25 |
26 | class MBConvBlock(nn.Module):
27 | """Mobile Inverted Residual Bottleneck Block.
28 |
29 | Args:
30 | block_args (namedtuple): BlockArgs, defined in utils.py.
31 | global_params (namedtuple): GlobalParam, defined in utils.py.
32 | image_size (tuple or list): [image_height, image_width].
33 |
34 | References:
35 | [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
36 | [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
37 | [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
38 | """
39 |
40 | def __init__(self, block_args, global_params, image_size=None):
41 | super().__init__()
42 | self._block_args = block_args
43 | self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
44 | self._bn_eps = global_params.batch_norm_epsilon
45 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
46 | self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
47 |
48 | # Expansion phase (Inverted Bottleneck)
49 | inp = self._block_args.input_filters # number of input channels
50 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
51 | if self._block_args.expand_ratio != 1:
52 | Conv2d = get_same_padding_conv2d(image_size=image_size)
53 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
54 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
55 | # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
56 |
57 | # Depthwise convolution phase
58 | k = self._block_args.kernel_size
59 | s = self._block_args.stride
60 | Conv2d = get_same_padding_conv2d(image_size=image_size)
61 | self._depthwise_conv = Conv2d(
62 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
63 | kernel_size=k, stride=s, bias=False)
64 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
65 | image_size = calculate_output_image_size(image_size, s)
66 |
67 | # Squeeze and Excitation layer, if desired
68 | if self.has_se:
69 | Conv2d = get_same_padding_conv2d(image_size=(1,1))
70 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
71 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
72 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
73 |
74 | # Pointwise convolution phase
75 | final_oup = self._block_args.output_filters
76 | Conv2d = get_same_padding_conv2d(image_size=image_size)
77 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
78 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
79 | self._swish = MemoryEfficientSwish()
80 |
81 | def forward(self, inputs, drop_connect_rate=None):
82 | """MBConvBlock's forward function.
83 |
84 | Args:
85 | inputs (tensor): Input tensor.
86 | drop_connect_rate (bool): Drop connect rate (float, between 0 and 1).
87 |
88 | Returns:
89 | Output of this block after processing.
90 | """
91 |
92 | # Expansion and Depthwise Convolution
93 | x = inputs
94 | if self._block_args.expand_ratio != 1:
95 | x = self._expand_conv(inputs)
96 | x = self._bn0(x)
97 | x = self._swish(x)
98 |
99 | x = self._depthwise_conv(x)
100 | x = self._bn1(x)
101 | x = self._swish(x)
102 |
103 | # Squeeze and Excitation
104 | if self.has_se:
105 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
106 | x_squeezed = self._se_reduce(x_squeezed)
107 | x_squeezed = self._swish(x_squeezed)
108 | x_squeezed = self._se_expand(x_squeezed)
109 | x = torch.sigmoid(x_squeezed) * x
110 |
111 | # Pointwise Convolution
112 | x = self._project_conv(x)
113 | x = self._bn2(x)
114 |
115 | # Skip connection and drop connect
116 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
117 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
118 | # The combination of skip connection and drop connect brings about stochastic depth.
119 | if drop_connect_rate:
120 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
121 | x = x + inputs # skip connection
122 | return x
123 |
124 | def set_swish(self, memory_efficient=True):
125 | """Sets swish function as memory efficient (for training) or standard (for export).
126 |
127 | Args:
128 | memory_efficient (bool): Whether to use memory-efficient version of swish.
129 | """
130 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
131 |
132 |
133 | class EfficientNet(nn.Module):
134 | """EfficientNet model.
135 | Most easily loaded with the .from_name or .from_pretrained methods.
136 |
137 | Args:
138 | blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
139 | global_params (namedtuple): A set of GlobalParams shared between blocks.
140 |
141 | References:
142 | [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
143 |
144 | # Example:
145 | # >>> import torch
146 | # >>> from efficientnet.model import EfficientNet
147 | # >>> inputs = torch.rand(1, 3, 224, 224)
148 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
149 | # >>> model.eval()
150 | # >>> outputs = model(inputs)
151 | """
152 |
153 | def __init__(self, task_mode='class', blocks_args=None, global_params=None):
154 | super().__init__()
155 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
156 | assert len(blocks_args) > 0, 'block args must be greater than 0'
157 | self._global_params = global_params
158 | self._blocks_args = blocks_args
159 | self.task_mode = task_mode
160 |
161 | # Batch norm parameters
162 | bn_mom = 1 - self._global_params.batch_norm_momentum
163 | bn_eps = self._global_params.batch_norm_epsilon
164 |
165 | # Get stem static or dynamic convolution depending on image size
166 | image_size = global_params.image_size
167 | Conv2d = get_same_padding_conv2d(image_size=image_size)
168 |
169 | # Stem
170 | in_channels = 3 # rgb
171 | out_channels = round_filters(32, self._global_params) # number of output channels
172 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
173 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
174 | image_size = calculate_output_image_size(image_size, 2)
175 |
176 | # Build blocks
177 | self._blocks = nn.ModuleList([])
178 | for block_args in self._blocks_args:
179 |
180 | # Update block input and output filters based on depth multiplier.
181 | block_args = block_args._replace(
182 | input_filters=round_filters(block_args.input_filters, self._global_params),
183 | output_filters=round_filters(block_args.output_filters, self._global_params),
184 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)
185 | )
186 |
187 | # The first block needs to take care of stride and filter size increase.
188 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
189 | image_size = calculate_output_image_size(image_size, block_args.stride)
190 | if block_args.num_repeat > 1: # modify block_args to keep same output size
191 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
192 | for _ in range(block_args.num_repeat - 1):
193 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
194 | # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
195 |
196 | # Head
197 | in_channels = block_args.output_filters # output of final block
198 | out_channels = round_filters(1280, self._global_params)
199 | Conv2d = get_same_padding_conv2d(image_size=image_size)
200 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
201 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
202 |
203 | # Final linear layer
204 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
205 | # self._dropout = nn.Dropout(self._global_params.dropout_rate)
206 | # self._fc = nn.Linear(out_channels, self._global_params.num_classes)
207 | # building classifier
208 | if self.task_mode in ['class', 'multi']:
209 | self.classifier_ = nn.Sequential(
210 | nn.Dropout(self._global_params.dropout_rate),
211 | nn.Linear(out_channels, self._global_params.num_classes),
212 | )
213 | if self.task_mode in ['regress', 'multi']:
214 | self.regressioner_ = nn.Sequential(
215 | nn.Dropout(self._global_params.dropout_rate),
216 | nn.Linear(out_channels, 1),
217 | )
218 | self._swish = MemoryEfficientSwish()
219 |
220 | # weight initialization
221 | for m in self.modules():
222 | if isinstance(m, nn.Conv2d):
223 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
224 | if m.bias is not None:
225 | nn.init.zeros_(m.bias)
226 | elif isinstance(m, nn.BatchNorm2d):
227 | nn.init.ones_(m.weight)
228 | nn.init.zeros_(m.bias)
229 | elif isinstance(m, nn.Linear):
230 | nn.init.normal_(m.weight, 0, 0.01)
231 | nn.init.zeros_(m.bias)
232 |
233 | def set_swish(self, memory_efficient=True):
234 | """Sets swish function as memory efficient (for training) or standard (for export).
235 |
236 | Args:
237 | memory_efficient (bool): Whether to use memory-efficient version of swish.
238 |
239 | """
240 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
241 | for block in self._blocks:
242 | block.set_swish(memory_efficient)
243 |
244 | def extract_endpoints(self, inputs):
245 | """Use convolution layer to extract features
246 | from reduction levels i in [1, 2, 3, 4, 5].
247 |
248 | Args:
249 | inputs (tensor): Input tensor.
250 |
251 | Returns:
252 | Dictionary of last intermediate features
253 | with reduction levels i in [1, 2, 3, 4, 5].
254 | Example:
255 | # >>> import torch
256 | # >>> from efficientnet.model import EfficientNet
257 | # >>> inputs = torch.rand(1, 3, 224, 224)
258 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
259 | # >>> endpoints = model.extract_features(inputs)
260 | # >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
261 | # >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
262 | # >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
263 | # >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
264 | # >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
265 | """
266 | endpoints = dict()
267 |
268 | # Stem
269 | x = self._swish(self._bn0(self._conv_stem(inputs)))
270 | prev_x = x
271 |
272 | # Blocks
273 | for idx, block in enumerate(self._blocks):
274 | drop_connect_rate = self._global_params.drop_connect_rate
275 | if drop_connect_rate:
276 | drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
277 | x = block(x, drop_connect_rate=drop_connect_rate)
278 | if prev_x.size(2) > x.size(2):
279 | endpoints[f'reduction_{len(endpoints)+1}'] = prev_x
280 | prev_x = x
281 |
282 | # Head
283 | x = self._swish(self._bn1(self._conv_head(x)))
284 | endpoints[f'reduction_{len(endpoints)+1}'] = x
285 |
286 | return endpoints
287 |
288 | def extract_features(self, inputs):
289 | """use convolution layer to extract feature .
290 |
291 | Args:
292 | inputs (tensor): Input tensor.
293 |
294 | Returns:
295 | Output of the final convolution
296 | layer in the efficientnet model.
297 | """
298 | # Stem
299 | x = self._swish(self._bn0(self._conv_stem(inputs)))
300 |
301 | # Blocks
302 | for idx, block in enumerate(self._blocks):
303 | drop_connect_rate = self._global_params.drop_connect_rate
304 | if drop_connect_rate:
305 | drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
306 | x = block(x, drop_connect_rate=drop_connect_rate)
307 |
308 | # Head
309 | x = self._swish(self._bn1(self._conv_head(x)))
310 |
311 | return x
312 |
313 | def forward(self, inputs):
314 | """EfficientNet's forward function.
315 | Calls extract_features to extract features, applies final linear layer, and returns logits.
316 |
317 | Args:
318 | inputs (tensor): Input tensor.
319 |
320 | Returns:
321 | Output of this model after processing.
322 | """
323 | # Convolution layers
324 | x = self.extract_features(inputs)
325 |
326 | # Pooling and final linear layer
327 | x = self._avg_pooling(x)
328 | x = x.flatten(start_dim=1)
329 | # x = self._dropout(x)
330 | # x = self._fc(x)
331 | # return x
332 |
333 | if self.task_mode == 'class':
334 | c_out = self.classifier_(x)
335 | return c_out
336 | elif self.task_mode == 'regress':
337 | r_out = self.regressioner_(x)
338 | return r_out[:, 0]
339 | elif self.task_mode == 'multi':
340 | c_out = self.classifier_(x)
341 | r_out = self.regressioner_(x)
342 | return c_out, r_out[:, 0]
343 |         else:
344 |             raise ValueError(f'Unsupported task_mode: {self.task_mode}. '
345 |                              f'Expected one of [multi, class, regress].')
346 |
347 |
348 | @classmethod
349 | def from_name(cls, task_mode, model_name, in_channels=3, **override_params):
350 | """create an efficientnet model according to name.
351 |
352 | Args:
353 | task_mode (str): class, multi, regress
354 | model_name (str): Name for efficientnet.
355 | in_channels (int): Input data's channel number.
356 | override_params (other key word params):
357 | Params to override model's global_params.
358 | Optional key:
359 | 'width_coefficient', 'depth_coefficient',
360 | 'image_size', 'dropout_rate',
361 | 'num_classes', 'batch_norm_momentum',
362 | 'batch_norm_epsilon', 'drop_connect_rate',
363 | 'depth_divisor', 'min_depth'
364 |
365 | Returns:
366 | An efficientnet model.
367 | """
368 | cls._check_model_name_is_valid(model_name)
369 | blocks_args, global_params = get_model_params(model_name, override_params)
370 | model = cls(task_mode, blocks_args, global_params)
371 | model._change_in_channels(in_channels)
372 | return model
373 |
374 | @classmethod
375 | def from_pretrained(cls, task_mode, model_name, weights_path=None, advprop=False,
376 | in_channels=3, num_classes=1000, **override_params):
377 | """create an efficientnet model according to name.
378 |
379 | Args:
380 | task_mode (str): class, multi, regress
381 | model_name (str): Name for efficientnet.
382 | weights_path (None or str):
383 | str: path to pretrained weights file on the local disk.
384 | None: use pretrained weights downloaded from the Internet.
385 | advprop (bool):
386 | Whether to load pretrained weights
387 | trained with advprop (valid when weights_path is None).
388 | in_channels (int): Input data's channel number.
389 | num_classes (int):
390 | Number of categories for classification.
391 | It controls the output size for final linear layer.
392 | override_params (other key word params):
393 | Params to override model's global_params.
394 | Optional key:
395 | 'width_coefficient', 'depth_coefficient',
396 | 'image_size', 'dropout_rate',
397 | 'num_classes', 'batch_norm_momentum',
398 | 'batch_norm_epsilon', 'drop_connect_rate',
399 | 'depth_divisor', 'min_depth'
400 |
401 | Returns:
402 | A pretrained efficientnet model.
403 | """
404 | model = cls.from_name(task_mode, model_name, num_classes=num_classes, **override_params)
405 | load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop)
406 | model._change_in_channels(in_channels)
407 | return model
408 |
409 | @classmethod
410 | def get_image_size(cls, model_name):
411 | """Get the input image size for a given efficientnet model.
412 |
413 | Args:
414 | model_name (str): Name for efficientnet.
415 |
416 | Returns:
417 | Input image size (resolution).
418 | """
419 | cls._check_model_name_is_valid(model_name)
420 | _, _, res, _ = efficientnet_params(model_name)
421 | return res
422 |
423 | @classmethod
424 | def _check_model_name_is_valid(cls, model_name):
425 | """Validates model name.
426 |
427 | Args:
428 | model_name (str): Name for efficientnet.
429 |
430 | Returns:
431 | bool: Is a valid name or not.
432 | """
433 | valid_models = ['efficientnet-b'+str(i) for i in range(9)]
434 |
435 | # Support the construction of 'efficientnet-l2' without pretrained weights
436 | valid_models += ['efficientnet-l2']
437 |
438 | if model_name not in valid_models:
439 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
440 |
441 | def _change_in_channels(self, in_channels):
442 | """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
443 |
444 | Args:
445 | in_channels (int): Input data's channel number.
446 | """
447 | if in_channels != 3:
448 | Conv2d = get_same_padding_conv2d(image_size = self._global_params.image_size)
449 | out_channels = round_filters(32, self._global_params)
450 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
451 |
452 |
453 | def jl_efficientnet(task_mode='class', pretrained=True, num_classes=4, **kwargs):
454 | """
455 |     Joint-learning EfficientNet (EfficientNet-b0 backbone).
456 |
457 | Args:
458 | task_mode (string): multi, class, regress
459 | pretrained (bool): If True, returns a model pre-trained on ImageNet
460 |         num_classes (int): number of classes, i.e. the number of output nodes
461 | """
462 | func = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
463 | model = func(task_mode=task_mode, model_name='efficientnet-b0', num_classes=num_classes)
464 | return model
465 |
466 |
467 | # def _test():
468 | # net = jl_efficientnet(task_mode='regress', pretrained=True, num_classes=3).cuda()
469 | # y = net(torch.randn(48, 3, 224, 224).cuda())
470 | # # print(y_class.size(), y_regres.size())
471 | # # y_class = net(torch.randn(48, 3, 224, 224).cuda())
472 | # # print(y_class.size())
473 | #
474 | # model = net.cuda()
475 | # summary(model, (3, 224, 224))
476 | # _test()
477 |
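478 | # Minimal smoke-test sketch (hypothetical, in the spirit of the commented-out
479 | # _test() above): in 'multi' mode the model returns one categorical logit
480 | # vector and one scalar regression value per image; pretrained=False avoids
481 | # the weight download, so this runs offline on CPU.
482 | # if __name__ == '__main__':
483 | #     net = jl_efficientnet(task_mode='multi', pretrained=False, num_classes=4)
484 | #     net.eval()
485 | #     with torch.no_grad():
486 | #         c_out, r_out = net(torch.randn(2, 3, 224, 224))
487 | #     print(c_out.shape, r_out.shape)   # torch.Size([2, 4]) torch.Size([2])
488 | 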
--------------------------------------------------------------------------------
/model_lib/efficientnet_pytorch/model_dorn.py:
--------------------------------------------------------------------------------
1 | """model.py - Model and module class for EfficientNet.
2 | They are built to mirror those in the official TensorFlow implementation.
3 | """
4 |
5 | # Author: lukemelas (github username)
6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7 | # With adjustments and added comments by workingcoder (github username).
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from model_lib.efficientnet_pytorch.utils import (
13 | round_filters,
14 | round_repeats,
15 | drop_connect,
16 | get_same_padding_conv2d,
17 | get_model_params,
18 | efficientnet_params,
19 | load_pretrained_weights,
20 | Swish,
21 | MemoryEfficientSwish,
22 | calculate_output_image_size
23 | )
24 |
25 |
26 | class OrdinalRegressionLayer(nn.Module):
27 | def __init__(self):
28 | super(OrdinalRegressionLayer, self).__init__()
29 |
30 | def forward(self, x):
31 |         """
32 |         :param x: N x 2K tensor; N - batch size, 2K - channels, K - number of discrete sub-intervals
33 |         :return: labels - ordinal labels (one per sample, corresponding to discrete ordinal values) of size N
34 |                  softmax - predicted probabilities P (as in the DORN paper) of size N x K
35 |         """
36 |         N, K = x.size()
37 |         K = K // 2  # number of discrete sub-intervals
38 |
39 | odd = x[:, ::2].clone()
40 | even = x[:, 1::2].clone()
41 |
42 | odd = odd.view(N, 1, K)
43 | even = even.view(N, 1, K)
44 |
45 | paired_channels = torch.cat((odd, even), dim=1)
46 | paired_channels = paired_channels.clamp(min=1e-8, max=1e8) # prevent nans
47 |
48 | softmax = nn.functional.softmax(paired_channels, dim=1)
49 |
50 | softmax = softmax[:, 1, :]
51 | softmax = softmax.view(-1, K)
52 | labels = torch.sum((softmax > 0.5), dim=1).view(-1, 1) - 1
53 | return labels[:, 0], softmax
54 |
55 |
56 | class MBConvBlock(nn.Module):
57 | """Mobile Inverted Residual Bottleneck Block.
58 |
59 | Args:
60 | block_args (namedtuple): BlockArgs, defined in utils.py.
61 | global_params (namedtuple): GlobalParam, defined in utils.py.
62 | image_size (tuple or list): [image_height, image_width].
63 |
64 | References:
65 | [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
66 | [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
67 | [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
68 | """
69 |
70 | def __init__(self, block_args, global_params, image_size=None):
71 | super().__init__()
72 | self._block_args = block_args
73 | self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
74 | self._bn_eps = global_params.batch_norm_epsilon
75 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
76 | self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
77 |
78 | # Expansion phase (Inverted Bottleneck)
79 | inp = self._block_args.input_filters # number of input channels
80 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
81 | if self._block_args.expand_ratio != 1:
82 | Conv2d = get_same_padding_conv2d(image_size=image_size)
83 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
84 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
85 | # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
86 |
87 | # Depthwise convolution phase
88 | k = self._block_args.kernel_size
89 | s = self._block_args.stride
90 | Conv2d = get_same_padding_conv2d(image_size=image_size)
91 | self._depthwise_conv = Conv2d(
92 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
93 | kernel_size=k, stride=s, bias=False)
94 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
95 | image_size = calculate_output_image_size(image_size, s)
96 |
97 | # Squeeze and Excitation layer, if desired
98 | if self.has_se:
99 | Conv2d = get_same_padding_conv2d(image_size=(1,1))
100 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
101 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
102 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
103 |
104 | # Pointwise convolution phase
105 | final_oup = self._block_args.output_filters
106 | Conv2d = get_same_padding_conv2d(image_size=image_size)
107 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
108 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
109 | self._swish = MemoryEfficientSwish()
110 |
111 | def forward(self, inputs, drop_connect_rate=None):
112 | """MBConvBlock's forward function.
113 |
114 | Args:
115 | inputs (tensor): Input tensor.
116 | drop_connect_rate (bool): Drop connect rate (float, between 0 and 1).
117 |
118 | Returns:
119 | Output of this block after processing.
120 | """
121 |
122 | # Expansion and Depthwise Convolution
123 | x = inputs
124 | if self._block_args.expand_ratio != 1:
125 | x = self._expand_conv(inputs)
126 | x = self._bn0(x)
127 | x = self._swish(x)
128 |
129 | x = self._depthwise_conv(x)
130 | x = self._bn1(x)
131 | x = self._swish(x)
132 |
133 | # Squeeze and Excitation
134 | if self.has_se:
135 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
136 | x_squeezed = self._se_reduce(x_squeezed)
137 | x_squeezed = self._swish(x_squeezed)
138 | x_squeezed = self._se_expand(x_squeezed)
139 | x = torch.sigmoid(x_squeezed) * x
140 |
141 | # Pointwise Convolution
142 | x = self._project_conv(x)
143 | x = self._bn2(x)
144 |
145 | # Skip connection and drop connect
146 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
147 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
148 | # The combination of skip connection and drop connect brings about stochastic depth.
149 | if drop_connect_rate:
150 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
151 | x = x + inputs # skip connection
152 | return x
153 |
154 | def set_swish(self, memory_efficient=True):
155 | """Sets swish function as memory efficient (for training) or standard (for export).
156 |
157 | Args:
158 | memory_efficient (bool): Whether to use memory-efficient version of swish.
159 | """
160 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
161 |
162 |
163 | class EfficientNet(nn.Module):
164 | """EfficientNet model.
165 | Most easily loaded with the .from_name or .from_pretrained methods.
166 |
167 | Args:
168 | blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
169 | global_params (namedtuple): A set of GlobalParams shared between blocks.
170 |
171 | References:
172 | [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
173 |
174 | # Example:
175 | # >>> import torch
176 | # >>> from efficientnet.model import EfficientNet
177 | # >>> inputs = torch.rand(1, 3, 224, 224)
178 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
179 | # >>> model.eval()
180 | # >>> outputs = model(inputs)
181 | """
182 |
183 | def __init__(self, task_mode='class', blocks_args=None, global_params=None):
184 | super().__init__()
185 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
186 | assert len(blocks_args) > 0, 'block args must be greater than 0'
187 | self._global_params = global_params
188 | self._blocks_args = blocks_args
189 | self.task_mode = task_mode
190 |
191 | # Batch norm parameters
192 | bn_mom = 1 - self._global_params.batch_norm_momentum
193 | bn_eps = self._global_params.batch_norm_epsilon
194 |
195 | # Get stem static or dynamic convolution depending on image size
196 | image_size = global_params.image_size
197 | Conv2d = get_same_padding_conv2d(image_size=image_size)
198 |
199 | # Stem
200 | in_channels = 3 # rgb
201 | out_channels = round_filters(32, self._global_params) # number of output channels
202 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
203 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
204 | image_size = calculate_output_image_size(image_size, 2)
205 |
206 | # Build blocks
207 | self._blocks = nn.ModuleList([])
208 | for block_args in self._blocks_args:
209 |
210 | # Update block input and output filters based on depth multiplier.
211 | block_args = block_args._replace(
212 | input_filters=round_filters(block_args.input_filters, self._global_params),
213 | output_filters=round_filters(block_args.output_filters, self._global_params),
214 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)
215 | )
216 |
217 | # The first block needs to take care of stride and filter size increase.
218 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
219 | image_size = calculate_output_image_size(image_size, block_args.stride)
220 | if block_args.num_repeat > 1: # modify block_args to keep same output size
221 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
222 | for _ in range(block_args.num_repeat - 1):
223 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
224 | # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
225 |
226 | # Head
227 | in_channels = block_args.output_filters # output of final block
228 | out_channels = round_filters(1280, self._global_params)
229 | Conv2d = get_same_padding_conv2d(image_size=image_size)
230 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
231 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
232 |
233 | # Final linear layer
234 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
235 | # self._dropout = nn.Dropout(self._global_params.dropout_rate)
236 | # self._fc = nn.Linear(out_channels, self._global_params.num_classes)
237 | # building classifier
238 | self.classifier_ = nn.Sequential(
239 | nn.Dropout(self._global_params.dropout_rate),
240 | nn.Linear(out_channels, self._global_params.num_classes),
241 | )
242 | self.ordinal_regression = OrdinalRegressionLayer()
243 | self._swish = MemoryEfficientSwish()
244 |
245 | # weight initialization
246 | for m in self.modules():
247 | if isinstance(m, nn.Conv2d):
248 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
249 | if m.bias is not None:
250 | nn.init.zeros_(m.bias)
251 | elif isinstance(m, nn.BatchNorm2d):
252 | nn.init.ones_(m.weight)
253 | nn.init.zeros_(m.bias)
254 | elif isinstance(m, nn.Linear):
255 | nn.init.normal_(m.weight, 0, 0.01)
256 | nn.init.zeros_(m.bias)
257 |
258 | def set_swish(self, memory_efficient=True):
259 | """Sets swish function as memory efficient (for training) or standard (for export).
260 |
261 | Args:
262 | memory_efficient (bool): Whether to use memory-efficient version of swish.
263 |
264 | """
265 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
266 | for block in self._blocks:
267 | block.set_swish(memory_efficient)
268 |
269 | def extract_endpoints(self, inputs):
270 | """Use convolution layer to extract features
271 | from reduction levels i in [1, 2, 3, 4, 5].
272 |
273 | Args:
274 | inputs (tensor): Input tensor.
275 |
276 | Returns:
277 | Dictionary of last intermediate features
278 | with reduction levels i in [1, 2, 3, 4, 5].
279 | Example:
280 | # >>> import torch
281 | # >>> from efficientnet.model import EfficientNet
282 | # >>> inputs = torch.rand(1, 3, 224, 224)
283 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
284 | # >>> endpoints = model.extract_features(inputs)
285 | # >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
286 | # >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
287 | # >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
288 | # >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
289 | # >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
290 | """
291 | endpoints = dict()
292 |
293 | # Stem
294 | x = self._swish(self._bn0(self._conv_stem(inputs)))
295 | prev_x = x
296 |
297 | # Blocks
298 | for idx, block in enumerate(self._blocks):
299 | drop_connect_rate = self._global_params.drop_connect_rate
300 | if drop_connect_rate:
301 | drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
302 | x = block(x, drop_connect_rate=drop_connect_rate)
303 | if prev_x.size(2) > x.size(2):
304 | endpoints[f'reduction_{len(endpoints)+1}'] = prev_x
305 | prev_x = x
306 |
307 | # Head
308 | x = self._swish(self._bn1(self._conv_head(x)))
309 | endpoints[f'reduction_{len(endpoints)+1}'] = x
310 |
311 | return endpoints
312 |
313 | def extract_features(self, inputs):
314 | """use convolution layer to extract feature .
315 |
316 | Args:
317 | inputs (tensor): Input tensor.
318 |
319 | Returns:
320 | Output of the final convolution
321 | layer in the efficientnet model.
322 | """
323 | # Stem
324 | x = self._swish(self._bn0(self._conv_stem(inputs)))
325 |
326 | # Blocks
327 | for idx, block in enumerate(self._blocks):
328 | drop_connect_rate = self._global_params.drop_connect_rate
329 | if drop_connect_rate:
330 | drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate
331 | x = block(x, drop_connect_rate=drop_connect_rate)
332 |
333 | # Head
334 | x = self._swish(self._bn1(self._conv_head(x)))
335 |
336 | return x
337 |
338 | def forward(self, inputs):
339 | """EfficientNet's forward function.
340 | Calls extract_features to extract features, applies final linear layer, and returns logits.
341 |
342 | Args:
343 | inputs (tensor): Input tensor.
344 |
345 | Returns:
346 | Output of this model after processing.
347 | """
348 | # Convolution layers
349 | x = self.extract_features(inputs)
350 |
351 | # Pooling and final linear layer
352 | x = self._avg_pooling(x)
353 | x = x.flatten(start_dim=1)
354 | # x = self._dropout(x)
355 | # x = self._fc(x)
356 | # return x
357 |
358 | if self.task_mode == 'class':
359 | c_out = self.classifier_(x)
360 | predicts, softmax = self.ordinal_regression(c_out)
361 | return predicts, softmax
362 |         else:
363 |             # Unlike model.py, this DORN variant only builds the ordinal
364 |             # classification head: 'regress' and 'multi' would reference a
365 |             # regressioner_ attribute that is never created here, so fail
366 |             # fast instead of silently returning None.
367 |             raise ValueError(
368 |                 f'Unsupported task_mode: {self.task_mode}. '
369 |                 f"The DORN model only supports the 'class' task_mode.")
370 | 
371 | 
372 |
373 |
374 | @classmethod
375 | def from_name(cls, task_mode, model_name, in_channels=3, **override_params):
376 | """create an efficientnet model according to name.
377 |
378 | Args:
379 | task_mode (str): class, multi, regress
380 | model_name (str): Name for efficientnet.
381 | in_channels (int): Input data's channel number.
382 | override_params (other key word params):
383 | Params to override model's global_params.
384 | Optional key:
385 | 'width_coefficient', 'depth_coefficient',
386 | 'image_size', 'dropout_rate',
387 | 'num_classes', 'batch_norm_momentum',
388 | 'batch_norm_epsilon', 'drop_connect_rate',
389 | 'depth_divisor', 'min_depth'
390 |
391 | Returns:
392 | An efficientnet model.
393 | """
394 | cls._check_model_name_is_valid(model_name)
395 | blocks_args, global_params = get_model_params(model_name, override_params)
396 | model = cls(task_mode, blocks_args, global_params)
397 | model._change_in_channels(in_channels)
398 | return model
399 |
400 | @classmethod
401 | def from_pretrained(cls, task_mode, model_name, weights_path=None, advprop=False,
402 | in_channels=3, num_classes=1000, **override_params):
403 | """create an efficientnet model according to name.
404 |
405 | Args:
406 | task_mode (str): class, multi, regress
407 | model_name (str): Name for efficientnet.
408 | weights_path (None or str):
409 | str: path to pretrained weights file on the local disk.
410 | None: use pretrained weights downloaded from the Internet.
411 | advprop (bool):
412 | Whether to load pretrained weights
413 | trained with advprop (valid when weights_path is None).
414 | in_channels (int): Input data's channel number.
415 | num_classes (int):
416 | Number of categories for classification.
417 | It controls the output size for final linear layer.
418 | override_params (other key word params):
419 | Params to override model's global_params.
420 | Optional key:
421 | 'width_coefficient', 'depth_coefficient',
422 | 'image_size', 'dropout_rate',
423 | 'num_classes', 'batch_norm_momentum',
424 | 'batch_norm_epsilon', 'drop_connect_rate',
425 | 'depth_divisor', 'min_depth'
426 |
427 | Returns:
428 | A pretrained efficientnet model.
429 | """
430 | model = cls.from_name(task_mode, model_name, num_classes=num_classes, **override_params)
431 | load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000), advprop=advprop)
432 | model._change_in_channels(in_channels)
433 | return model
434 |
435 | @classmethod
436 | def get_image_size(cls, model_name):
437 | """Get the input image size for a given efficientnet model.
438 |
439 | Args:
440 | model_name (str): Name for efficientnet.
441 |
442 | Returns:
443 | Input image size (resolution).
444 | """
445 | cls._check_model_name_is_valid(model_name)
446 | _, _, res, _ = efficientnet_params(model_name)
447 | return res
448 |
449 | @classmethod
450 | def _check_model_name_is_valid(cls, model_name):
451 | """Validates model name.
452 |
453 | Args:
454 | model_name (str): Name for efficientnet.
455 |
456 | Returns:
457 | bool: Is a valid name or not.
458 | """
459 | valid_models = ['efficientnet-b'+str(i) for i in range(9)]
460 |
461 | # Support the construction of 'efficientnet-l2' without pretrained weights
462 | valid_models += ['efficientnet-l2']
463 |
464 | if model_name not in valid_models:
465 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
466 |
467 | def _change_in_channels(self, in_channels):
468 |         """Adjust model's first convolution layer to in_channels, if in_channels does not equal 3.
469 |
470 | Args:
471 | in_channels (int): Input data's channel number.
472 | """
473 | if in_channels != 3:
474 | Conv2d = get_same_padding_conv2d(image_size = self._global_params.image_size)
475 | out_channels = round_filters(32, self._global_params)
476 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
477 |
478 |
479 | def dorn_efficientnet(task_mode='class', pretrained=True, num_classes=4, **kwargs):
480 | """
481 |     DORN EfficientNet (ordinal-regression head on EfficientNet-B0)
482 | 
483 |     Args:
484 |         task_mode (string): multi, class, regress
485 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
486 |         num_classes (int): number of classes; the head emits (num_classes - 1) * 2 output nodes
487 | """
488 | func = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
489 | model = func(task_mode=task_mode, model_name='efficientnet-b0', num_classes=(num_classes-1)*2)
490 | return model
491 |
492 |
493 | # def _test():
494 | # net = dorn_efficientnet(task_mode='class', pretrained=True).cuda()
495 | # y_class, y_regres = net(torch.randn(48, 3, 224, 224).cuda())
496 | # print(y_class.size(), y_regres.size())
497 | # # y_class = net(torch.randn(48, 3, 224, 224).cuda())
498 | # # print(y_class.size())
499 | #
500 | # # model = net.cuda()
501 | # # summary(model, (3, 224, 224))
502 | # _test()
503 |
--------------------------------------------------------------------------------
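For reference, `dorn_efficientnet` above sizes its head at `(num_classes - 1) * 2` logits, which the ordinal-regression layer (the same decoding as `OrdinalRegressionLayer` in `model_rank_ordinal.py` below) pairs up, softmaxes, and thresholds at 0.5 to count how many discrete sub-intervals a sample passes. Below is a minimal standalone sketch of that decoding with a hypothetical `decode_dorn` helper, assuming `num_classes = 4`; it is an illustration, not the repo's API.

```python
import torch
import torch.nn.functional as F

def decode_dorn(x):
    # x: (N, 2K) logits; K = number of discrete sub-intervals.
    N, two_k = x.size()
    K = two_k // 2
    paired = torch.stack((x[:, ::2], x[:, 1::2]), dim=1)  # (N, 2, K): per-interval logit pairs
    probas = F.softmax(paired, dim=1)[:, 1, :]            # (N, K): P(sample passes interval k)
    labels = (probas > 0.5).sum(dim=1) - 1                # count passed intervals, shifted by -1 as in the layer
    return labels, probas

logits = torch.randn(8, (4 - 1) * 2)  # fake batch of 8 samples
labels, probas = decode_dorn(logits)
print(labels.shape, probas.shape)     # torch.Size([8]) torch.Size([8, 3])
```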
/model_lib/efficientnet_pytorch/model_mtmr.py:
--------------------------------------------------------------------------------
1 | """model.py - Model and module class for EfficientNet.
2 | They are built to mirror those in the official TensorFlow implementation.
3 | """
4 |
5 | # Author: lukemelas (github username)
6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7 | # With adjustments and added comments by workingcoder (github username).
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from model_lib.efficientnet_pytorch.utils import (
13 | round_filters,
14 | round_repeats,
15 | drop_connect,
16 | get_same_padding_conv2d,
17 | get_model_params,
18 | efficientnet_params,
19 | load_pretrained_weights,
20 | Swish,
21 | MemoryEfficientSwish,
22 | calculate_output_image_size
23 | )
24 |
25 |
26 | class MBConvBlock(nn.Module):
27 | """Mobile Inverted Residual Bottleneck Block.
28 |
29 | Args:
30 | block_args (namedtuple): BlockArgs, defined in utils.py.
31 | global_params (namedtuple): GlobalParam, defined in utils.py.
32 | image_size (tuple or list): [image_height, image_width].
33 |
34 | References:
35 | [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
36 | [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
37 | [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
38 | """
39 |
40 | def __init__(self, block_args, global_params, image_size=None):
41 | super().__init__()
42 | self._block_args = block_args
43 | self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
44 | self._bn_eps = global_params.batch_norm_epsilon
45 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
46 | self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
47 |
48 | # Expansion phase (Inverted Bottleneck)
49 | inp = self._block_args.input_filters # number of input channels
50 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
51 | if self._block_args.expand_ratio != 1:
52 | Conv2d = get_same_padding_conv2d(image_size=image_size)
53 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
54 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
55 | # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
56 |
57 | # Depthwise convolution phase
58 | k = self._block_args.kernel_size
59 | s = self._block_args.stride
60 | Conv2d = get_same_padding_conv2d(image_size=image_size)
61 | self._depthwise_conv = Conv2d(
62 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
63 | kernel_size=k, stride=s, bias=False)
64 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
65 | image_size = calculate_output_image_size(image_size, s)
66 |
67 | # Squeeze and Excitation layer, if desired
68 | if self.has_se:
69 | Conv2d = get_same_padding_conv2d(image_size=(1, 1))
70 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
71 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
72 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
73 |
74 | # Pointwise convolution phase
75 | final_oup = self._block_args.output_filters
76 | Conv2d = get_same_padding_conv2d(image_size=image_size)
77 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
78 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
79 | self._swish = MemoryEfficientSwish()
80 |
81 | def forward(self, inputs, drop_connect_rate=None):
82 | """MBConvBlock's forward function.
83 |
84 | Args:
85 | inputs (tensor): Input tensor.
86 |             drop_connect_rate (float): Drop connect rate (between 0 and 1).
87 |
88 | Returns:
89 | Output of this block after processing.
90 | """
91 |
92 | # Expansion and Depthwise Convolution
93 | x = inputs
94 | if self._block_args.expand_ratio != 1:
95 | x = self._expand_conv(inputs)
96 | x = self._bn0(x)
97 | x = self._swish(x)
98 |
99 | x = self._depthwise_conv(x)
100 | x = self._bn1(x)
101 | x = self._swish(x)
102 |
103 | # Squeeze and Excitation
104 | if self.has_se:
105 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
106 | x_squeezed = self._se_reduce(x_squeezed)
107 | x_squeezed = self._swish(x_squeezed)
108 | x_squeezed = self._se_expand(x_squeezed)
109 | x = torch.sigmoid(x_squeezed) * x
110 |
111 | # Pointwise Convolution
112 | x = self._project_conv(x)
113 | x = self._bn2(x)
114 |
115 | # Skip connection and drop connect
116 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
117 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
118 | # The combination of skip connection and drop connect brings about stochastic depth.
119 | if drop_connect_rate:
120 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
121 | x = x + inputs # skip connection
122 | return x
123 |
124 | def set_swish(self, memory_efficient=True):
125 | """Sets swish function as memory efficient (for training) or standard (for export).
126 |
127 | Args:
128 | memory_efficient (bool): Whether to use memory-efficient version of swish.
129 | """
130 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
131 |
132 |
133 | class EfficientNet(nn.Module):
134 | """EfficientNet model.
135 | Most easily loaded with the .from_name or .from_pretrained methods.
136 |
137 | Args:
138 | blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
139 | global_params (namedtuple): A set of GlobalParams shared between blocks.
140 |
141 | References:
142 | [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
143 |
144 | # Example:
145 | # >>> import torch
146 | # >>> from efficientnet.model import EfficientNet
147 | # >>> inputs = torch.rand(1, 3, 224, 224)
148 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
149 | # >>> model.eval()
150 | # >>> outputs = model(inputs)
151 | """
152 |
153 | def __init__(self, task_mode='class', blocks_args=None, global_params=None):
154 | super().__init__()
155 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
156 |         assert len(blocks_args) > 0, 'blocks_args must be non-empty'
157 | self._global_params = global_params
158 | self._blocks_args = blocks_args
159 | self.task_mode = task_mode
160 |
161 | # Batch norm parameters
162 | bn_mom = 1 - self._global_params.batch_norm_momentum
163 | bn_eps = self._global_params.batch_norm_epsilon
164 |
165 | # Get stem static or dynamic convolution depending on image size
166 | image_size = global_params.image_size
167 | Conv2d = get_same_padding_conv2d(image_size=image_size)
168 |
169 | # Stem
170 | in_channels = 3 # rgb
171 | out_channels = round_filters(32, self._global_params) # number of output channels
172 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
173 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
174 | image_size = calculate_output_image_size(image_size, 2)
175 |
176 | # Build blocks
177 | self._blocks = nn.ModuleList([])
178 | for block_args in self._blocks_args:
179 |
180 | # Update block input and output filters based on depth multiplier.
181 | block_args = block_args._replace(
182 | input_filters=round_filters(block_args.input_filters, self._global_params),
183 | output_filters=round_filters(block_args.output_filters, self._global_params),
184 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)
185 | )
186 |
187 | # The first block needs to take care of stride and filter size increase.
188 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
189 | image_size = calculate_output_image_size(image_size, block_args.stride)
190 | if block_args.num_repeat > 1: # modify block_args to keep same output size
191 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
192 | for _ in range(block_args.num_repeat - 1):
193 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
194 | # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
195 |
196 | # Head
197 | in_channels = block_args.output_filters # output of final block
198 | out_channels = round_filters(1280, self._global_params)
199 | Conv2d = get_same_padding_conv2d(image_size=image_size)
200 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
201 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
202 |
203 | # Final linear layer
204 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
205 | # self._dropout = nn.Dropout(self._global_params.dropout_rate)
206 | # self._fc = nn.Linear(out_channels, self._global_params.num_classes)
207 | # building classifier
208 | if self.task_mode in ['class', 'multi']:
209 | self.classifier_ = nn.Sequential(
210 | nn.Dropout(self._global_params.dropout_rate),
211 | nn.Linear(out_channels, self._global_params.num_classes),
212 | )
213 | if self.task_mode in ['regress', 'multi']:
214 | self.regressioner_ = nn.Sequential(
215 | nn.Dropout(self._global_params.dropout_rate),
216 | nn.Linear(out_channels, 1),
217 | )
218 | if self.task_mode in ['multi_mtmr',]:
219 | self.attribute_feature_fc = nn.Linear(out_channels, 256)
220 | self.regression_ = nn.Linear(256, 1)
221 | self.classifier_ = nn.Sequential(
222 | nn.Dropout(self._global_params.dropout_rate),
223 | nn.Linear(out_channels + 256, self._global_params.num_classes),
224 | )
225 | self._swish = MemoryEfficientSwish()
226 |
227 | # weight initialization
228 | for m in self.modules():
229 | if isinstance(m, nn.Conv2d):
230 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
231 | if m.bias is not None:
232 | nn.init.zeros_(m.bias)
233 | elif isinstance(m, nn.BatchNorm2d):
234 | nn.init.ones_(m.weight)
235 | nn.init.zeros_(m.bias)
236 | elif isinstance(m, nn.Linear):
237 | nn.init.normal_(m.weight, 0, 0.01)
238 | nn.init.zeros_(m.bias)
239 |
240 | def set_swish(self, memory_efficient=True):
241 | """Sets swish function as memory efficient (for training) or standard (for export).
242 |
243 | Args:
244 | memory_efficient (bool): Whether to use memory-efficient version of swish.
245 |
246 | """
247 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
248 | for block in self._blocks:
249 | block.set_swish(memory_efficient)
250 |
251 | def extract_endpoints(self, inputs):
252 | """Use convolution layer to extract features
253 | from reduction levels i in [1, 2, 3, 4, 5].
254 |
255 | Args:
256 | inputs (tensor): Input tensor.
257 |
258 | Returns:
259 | Dictionary of last intermediate features
260 | with reduction levels i in [1, 2, 3, 4, 5].
261 | Example:
262 | # >>> import torch
263 | # >>> from efficientnet.model import EfficientNet
264 | # >>> inputs = torch.rand(1, 3, 224, 224)
265 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
266 |                 # >>> endpoints = model.extract_endpoints(inputs)
267 | # >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
268 | # >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
269 | # >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
270 | # >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
271 | # >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
272 | """
273 | endpoints = dict()
274 |
275 | # Stem
276 | x = self._swish(self._bn0(self._conv_stem(inputs)))
277 | prev_x = x
278 |
279 | # Blocks
280 | for idx, block in enumerate(self._blocks):
281 | drop_connect_rate = self._global_params.drop_connect_rate
282 | if drop_connect_rate:
283 |                 drop_connect_rate *= float(idx) / len(self._blocks) # scale drop_connect_rate
284 | x = block(x, drop_connect_rate=drop_connect_rate)
285 | if prev_x.size(2) > x.size(2):
286 | endpoints[f'reduction_{len(endpoints) + 1}'] = prev_x
287 | prev_x = x
288 |
289 | # Head
290 | x = self._swish(self._bn1(self._conv_head(x)))
291 | endpoints[f'reduction_{len(endpoints) + 1}'] = x
292 |
293 | return endpoints
294 |
295 | def extract_features(self, inputs):
296 |         """Use convolution layers to extract features.
297 |
298 | Args:
299 | inputs (tensor): Input tensor.
300 |
301 | Returns:
302 | Output of the final convolution
303 | layer in the efficientnet model.
304 | """
305 | # Stem
306 | x = self._swish(self._bn0(self._conv_stem(inputs)))
307 |
308 | # Blocks
309 | for idx, block in enumerate(self._blocks):
310 | drop_connect_rate = self._global_params.drop_connect_rate
311 | if drop_connect_rate:
312 |                 drop_connect_rate *= float(idx) / len(self._blocks) # scale drop_connect_rate
313 | x = block(x, drop_connect_rate=drop_connect_rate)
314 |
315 | # Head
316 | x = self._swish(self._bn1(self._conv_head(x)))
317 |
318 | return x
319 |
320 | def forward_once(self, inputs):
321 | """EfficientNet's forward function.
322 | Calls extract_features to extract features, applies final linear layer, and returns logits.
323 |
324 | Args:
325 | inputs (tensor): Input tensor.
326 |
327 | Returns:
328 | Output of this model after processing.
329 | """
330 | # Convolution layers
331 | x = self.extract_features(inputs)
332 |
333 | # Pooling and final linear layer
334 | x = self._avg_pooling(x)
335 | x = x.flatten(start_dim=1)
336 | # x = self._dropout(x)
337 | # x = self._fc(x)
338 | # return x
339 |
340 | if self.task_mode == 'class':
341 | c_out = self.classifier_(x)
342 | return c_out
343 | elif self.task_mode == 'regress':
344 | r_out = self.regressioner_(x)
345 | return r_out[:, 0]
346 | elif self.task_mode == 'multi_mtmr':
347 | attribute_feature = self.attribute_feature_fc(x)
348 | r_out = self.regression_(attribute_feature)
349 | c_out = torch.cat([attribute_feature, x], dim=1)
350 | c_out = self.classifier_(c_out)
351 | return c_out, r_out[:, 0]
352 | elif self.task_mode == 'multi':
353 | c_out = self.classifier_(x)
354 | r_out = self.regressioner_(x)
355 | return c_out, r_out[:, 0]
356 |         else:
357 |             raise ValueError(f'Unsupported task_mode: {self.task_mode}; '
358 |                              f'expected one of [multi, multi_mtmr, class, regress].')
359 |
360 |     def forward(self, inputs):
361 |         half = inputs.shape[0] // 2  # split the batch into two halves for pairwise ranking
362 |         input_1, input_2 = inputs[:half], inputs[half:]
363 | output_1, attribute_score_1 = self.forward_once(input_1)
364 | output_2, attribute_score_2 = self.forward_once(input_2)
365 |
366 | cat_output = torch.cat([output_1, output_2])
367 | cat_subtlety_score = torch.cat([attribute_score_1, attribute_score_2])
368 |
369 | return cat_output, cat_subtlety_score
370 |
371 | @classmethod
372 | def from_name(cls, task_mode, model_name, in_channels=3, **override_params):
373 |         """Create an EfficientNet model according to name.
374 | 
375 |         Args:
376 |             task_mode (str): class, multi, multi_mtmr, regress
377 |             model_name (str): Name for efficientnet.
378 |             in_channels (int): Input data's channel number.
379 |             override_params (other keyword params):
380 | Params to override model's global_params.
381 | Optional key:
382 | 'width_coefficient', 'depth_coefficient',
383 | 'image_size', 'dropout_rate',
384 | 'num_classes', 'batch_norm_momentum',
385 | 'batch_norm_epsilon', 'drop_connect_rate',
386 | 'depth_divisor', 'min_depth'
387 |
388 | Returns:
389 | An efficientnet model.
390 | """
391 | cls._check_model_name_is_valid(model_name)
392 | blocks_args, global_params = get_model_params(model_name, override_params)
393 | model = cls(task_mode, blocks_args, global_params)
394 | model._change_in_channels(in_channels)
395 | return model
396 |
397 | @classmethod
398 | def from_pretrained(cls, task_mode, model_name, weights_path=None, advprop=False,
399 | in_channels=3, num_classes=1000, **override_params):
400 |         """Create a pretrained EfficientNet model according to name.
401 | 
402 |         Args:
403 |             task_mode (str): class, multi, multi_mtmr, regress
404 | model_name (str): Name for efficientnet.
405 | weights_path (None or str):
406 | str: path to pretrained weights file on the local disk.
407 | None: use pretrained weights downloaded from the Internet.
408 | advprop (bool):
409 | Whether to load pretrained weights
410 | trained with advprop (valid when weights_path is None).
411 | in_channels (int): Input data's channel number.
412 | num_classes (int):
413 | Number of categories for classification.
414 | It controls the output size for final linear layer.
415 |             override_params (other keyword params):
416 | Params to override model's global_params.
417 | Optional key:
418 | 'width_coefficient', 'depth_coefficient',
419 | 'image_size', 'dropout_rate',
420 | 'num_classes', 'batch_norm_momentum',
421 | 'batch_norm_epsilon', 'drop_connect_rate',
422 | 'depth_divisor', 'min_depth'
423 |
424 | Returns:
425 | A pretrained efficientnet model.
426 | """
427 | model = cls.from_name(task_mode, model_name, num_classes=num_classes, **override_params)
428 | load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000),
429 | advprop=advprop)
430 | model._change_in_channels(in_channels)
431 | return model
432 |
433 | @classmethod
434 | def get_image_size(cls, model_name):
435 | """Get the input image size for a given efficientnet model.
436 |
437 | Args:
438 | model_name (str): Name for efficientnet.
439 |
440 | Returns:
441 | Input image size (resolution).
442 | """
443 | cls._check_model_name_is_valid(model_name)
444 | _, _, res, _ = efficientnet_params(model_name)
445 | return res
446 |
447 | @classmethod
448 | def _check_model_name_is_valid(cls, model_name):
449 | """Validates model name.
450 |
451 | Args:
452 | model_name (str): Name for efficientnet.
453 |
454 | Returns:
455 | bool: Is a valid name or not.
456 | """
457 | valid_models = ['efficientnet-b' + str(i) for i in range(9)]
458 |
459 | # Support the construction of 'efficientnet-l2' without pretrained weights
460 | valid_models += ['efficientnet-l2']
461 |
462 | if model_name not in valid_models:
463 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
464 |
465 | def _change_in_channels(self, in_channels):
466 |         """Adjust model's first convolution layer to in_channels, if in_channels does not equal 3.
467 |
468 | Args:
469 | in_channels (int): Input data's channel number.
470 | """
471 | if in_channels != 3:
472 | Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
473 | out_channels = round_filters(32, self._global_params)
474 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
475 |
476 |
477 | def jl_efficientnet(task_mode='multi_mtmr', pretrained=True, num_classes=4, **kwargs):
478 | """
479 |     Joint-learning EfficientNet (MTMR variant)
480 | 
481 |     Args:
482 |         task_mode (string): multi_mtmr, multi, class, regress
483 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
484 |         num_classes (int): number of classes or number of output nodes
485 | """
486 | func = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
487 | model = func(task_mode=task_mode, model_name='efficientnet-b0', num_classes=num_classes)
488 | return model
489 |
490 |
491 |
492 | def get_loss_mtmr(output_score_1, cat_subtlety_score, gt_score_1, gt_attribute_score_1):
493 | xcentloss_func_1 = nn.CrossEntropyLoss()
494 | xcentloss_1 = xcentloss_func_1(output_score_1, gt_score_1)
495 |
496 | # ranking loss
497 | ranking_loss_sum = 0
498 | half_size_of_output_score = output_score_1.size()[0] // 2
499 | for i in range(half_size_of_output_score):
500 | tmp_output_1 = output_score_1[i]
501 | tmp_output_2 = output_score_1[i + half_size_of_output_score]
502 | tmp_gt_score_1 = gt_score_1[i]
503 | tmp_gt_score_2 = gt_score_1[i + half_size_of_output_score]
504 |
505 | rankingloss_func = nn.MarginRankingLoss()
506 |
507 |         if tmp_gt_score_1.item() != tmp_gt_score_2.item():
508 |             target = torch.ones(1) * -1  # labels differ: y = -1 ranks the second output higher
509 |             ranking_loss_sum += rankingloss_func(tmp_output_1, tmp_output_2, target.cuda())
510 |         else:
511 |             target = torch.ones(1)  # labels match: y = +1 keeps the pair order
512 |             ranking_loss_sum += rankingloss_func(tmp_output_1, tmp_output_2, target.cuda())
513 |
514 | ranking_loss = ranking_loss_sum / half_size_of_output_score
515 |
516 | # attribute loss
517 | attribute_mseloss_func_1 = nn.MSELoss()
518 | attribute_mseloss_1 = attribute_mseloss_func_1(cat_subtlety_score, gt_attribute_score_1.float())
519 |
520 | # loss = 0.6 * xcentloss_1 + 0.2 * ranking_loss + 0.2 * attribute_mseloss_1
521 | loss = 1 * xcentloss_1 + 5.0e-1 * ranking_loss + 1.0e-3 * attribute_mseloss_1
522 |
523 | return loss
524 | # def _test():
525 | # import os
526 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
527 | # net = jl_efficientnet(task_mode='multi_mtmr', pretrained=True, num_classes=3)
528 | # net = torch.nn.DataParallel(net).cuda()
529 | # y_class, y_regres = net(torch.randn(48, 3, 512, 512).cuda())
530 | # print(y_class.size(), y_regres.size())
531 | # # y_class = net(torch.randn(48, 3, 224, 224).cuda())
532 | # # print(y_class.size())
533 | #
534 | # # model = net.cuda()
535 | # # summary(model, (3, 224, 224))
536 | # _test()
537 |
--------------------------------------------------------------------------------
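The MTMR model's `forward` scores two half-batches so that `get_loss_mtmr` can combine three terms: cross-entropy on the class logits (weight 1), a pairwise `MarginRankingLoss` between the halves (weight 0.5), and an MSE on the subtlety scores (weight 1e-3). Below is a minimal CPU sketch of the ranking term alone, with hypothetical per-sample scores (the training code keeps everything on CUDA):

```python
import torch
from torch import nn

# MarginRankingLoss(x1, x2, y) penalizes max(0, -y * (x1 - x2) + margin):
# y = -1 asks x2 to score higher; y = +1 asks x1 to score higher (used above when labels match).
rank_loss = nn.MarginRankingLoss()
out_1 = torch.randn(5)            # scores of the first half-batch
out_2 = torch.randn(5)            # scores of the second half-batch
gt_1 = torch.randint(0, 4, (5,))  # grades of the first half
gt_2 = torch.randint(0, 4, (5,))  # grades of the second half
target = torch.where(gt_1 != gt_2, -torch.ones(5), torch.ones(5))
print(rank_loss(out_1, out_2, target).item())
```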
/model_lib/efficientnet_pytorch/model_rank_ordinal.py:
--------------------------------------------------------------------------------
1 | """model.py - Model and module class for EfficientNet.
2 | They are built to mirror those in the official TensorFlow implementation.
3 | """
4 |
5 | # Author: lukemelas (github username)
6 | # Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
7 | # With adjustments and added comments by workingcoder (github username).
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from model_lib.efficientnet_pytorch.utils import (
13 | round_filters,
14 | round_repeats,
15 | drop_connect,
16 | get_same_padding_conv2d,
17 | get_model_params,
18 | efficientnet_params,
19 | load_pretrained_weights,
20 | Swish,
21 | MemoryEfficientSwish,
22 | calculate_output_image_size
23 | )
24 | from torchsummary import summary
25 |
26 | class OrdinalRegressionLayer(nn.Module):
27 | def __init__(self):
28 | super(OrdinalRegressionLayer, self).__init__()
29 |
30 | def forward(self, x):
31 | """
32 |         :param x: N x 2K; N - batch size, 2K - paired logits, K - number of discrete sub-intervals
33 |         :return: labels - ordinal labels (one per sample) of size N
34 |                  softmax - predicted probabilities P (as in the DORN paper) of size N x K
35 |         """
36 |         N, K = x.size()
37 | K = K // 2 # number of discrete sub-intervals
38 |
39 | odd = x[:, ::2].clone()
40 | even = x[:, 1::2].clone()
41 |
42 | odd = odd.view(N, 1, K)
43 | even = even.view(N, 1, K)
44 |
45 | paired_channels = torch.cat((odd, even), dim=1)
46 | paired_channels = paired_channels.clamp(min=1e-8, max=1e8) # prevent nans
47 |
48 | softmax = F.softmax(paired_channels, dim=1)
49 |
50 | softmax = softmax[:, 1, :]
51 | softmax = softmax.view(-1, K)
52 | labels = torch.sum((softmax > 0.5), dim=1).view(-1, 1) - 1
53 | return labels[:, 0], softmax
54 |
55 |
56 | class MBConvBlock(nn.Module):
57 | """Mobile Inverted Residual Bottleneck Block.
58 |
59 | Args:
60 | block_args (namedtuple): BlockArgs, defined in utils.py.
61 | global_params (namedtuple): GlobalParam, defined in utils.py.
62 | image_size (tuple or list): [image_height, image_width].
63 |
64 | References:
65 | [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
66 | [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
67 | [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
68 | """
69 |
70 | def __init__(self, block_args, global_params, image_size=None):
71 | super().__init__()
72 | self._block_args = block_args
73 | self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow
74 | self._bn_eps = global_params.batch_norm_epsilon
75 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
76 | self.id_skip = block_args.id_skip # whether to use skip connection and drop connect
77 |
78 | # Expansion phase (Inverted Bottleneck)
79 | inp = self._block_args.input_filters # number of input channels
80 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
81 | if self._block_args.expand_ratio != 1:
82 | Conv2d = get_same_padding_conv2d(image_size=image_size)
83 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
84 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
85 | # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
86 |
87 | # Depthwise convolution phase
88 | k = self._block_args.kernel_size
89 | s = self._block_args.stride
90 | Conv2d = get_same_padding_conv2d(image_size=image_size)
91 | self._depthwise_conv = Conv2d(
92 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
93 | kernel_size=k, stride=s, bias=False)
94 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
95 | image_size = calculate_output_image_size(image_size, s)
96 |
97 | # Squeeze and Excitation layer, if desired
98 | if self.has_se:
99 | Conv2d = get_same_padding_conv2d(image_size=(1, 1))
100 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
101 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
102 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
103 |
104 | # Pointwise convolution phase
105 | final_oup = self._block_args.output_filters
106 | Conv2d = get_same_padding_conv2d(image_size=image_size)
107 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
108 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
109 | self._swish = MemoryEfficientSwish()
110 |
111 | def forward(self, inputs, drop_connect_rate=None):
112 | """MBConvBlock's forward function.
113 |
114 | Args:
115 | inputs (tensor): Input tensor.
116 |             drop_connect_rate (float): Drop connect rate (between 0 and 1).
117 |
118 | Returns:
119 | Output of this block after processing.
120 | """
121 |
122 | # Expansion and Depthwise Convolution
123 | x = inputs
124 | if self._block_args.expand_ratio != 1:
125 | x = self._expand_conv(inputs)
126 | x = self._bn0(x)
127 | x = self._swish(x)
128 |
129 | x = self._depthwise_conv(x)
130 | x = self._bn1(x)
131 | x = self._swish(x)
132 |
133 | # Squeeze and Excitation
134 | if self.has_se:
135 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
136 | x_squeezed = self._se_reduce(x_squeezed)
137 | x_squeezed = self._swish(x_squeezed)
138 | x_squeezed = self._se_expand(x_squeezed)
139 | x = torch.sigmoid(x_squeezed) * x
140 |
141 | # Pointwise Convolution
142 | x = self._project_conv(x)
143 | x = self._bn2(x)
144 |
145 | # Skip connection and drop connect
146 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
147 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
148 | # The combination of skip connection and drop connect brings about stochastic depth.
149 | if drop_connect_rate:
150 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
151 | x = x + inputs # skip connection
152 | return x
153 |
154 | def set_swish(self, memory_efficient=True):
155 | """Sets swish function as memory efficient (for training) or standard (for export).
156 |
157 | Args:
158 | memory_efficient (bool): Whether to use memory-efficient version of swish.
159 | """
160 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
161 |
162 |
163 | class EfficientNet(nn.Module):
164 | """EfficientNet model.
165 | Most easily loaded with the .from_name or .from_pretrained methods.
166 |
167 | Args:
168 | blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
169 | global_params (namedtuple): A set of GlobalParams shared between blocks.
170 |
171 | References:
172 | [1] https://arxiv.org/abs/1905.11946 (EfficientNet)
173 |
174 | # Example:
175 | # >>> import torch
176 | # >>> from efficientnet.model import EfficientNet
177 | # >>> inputs = torch.rand(1, 3, 224, 224)
178 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
179 | # >>> model.eval()
180 | # >>> outputs = model(inputs)
181 | """
182 |
183 | def __init__(self, task_mode='class', blocks_args=None, global_params=None):
184 | super().__init__()
185 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
186 |         assert len(blocks_args) > 0, 'blocks_args must be non-empty'
187 | self._global_params = global_params
188 | self._blocks_args = blocks_args
189 | self.task_mode = task_mode
190 |
191 | # Batch norm parameters
192 | bn_mom = 1 - self._global_params.batch_norm_momentum
193 | bn_eps = self._global_params.batch_norm_epsilon
194 |
195 | # Get stem static or dynamic convolution depending on image size
196 | image_size = global_params.image_size
197 | Conv2d = get_same_padding_conv2d(image_size=image_size)
198 |
199 | # Stem
200 | in_channels = 3 # rgb
201 | out_channels = round_filters(32, self._global_params) # number of output channels
202 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
203 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
204 | image_size = calculate_output_image_size(image_size, 2)
205 |
206 | # Build blocks
207 | self._blocks = nn.ModuleList([])
208 | for block_args in self._blocks_args:
209 |
210 | # Update block input and output filters based on depth multiplier.
211 | block_args = block_args._replace(
212 | input_filters=round_filters(block_args.input_filters, self._global_params),
213 | output_filters=round_filters(block_args.output_filters, self._global_params),
214 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)
215 | )
216 |
217 | # The first block needs to take care of stride and filter size increase.
218 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
219 | image_size = calculate_output_image_size(image_size, block_args.stride)
220 | if block_args.num_repeat > 1: # modify block_args to keep same output size
221 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
222 | for _ in range(block_args.num_repeat - 1):
223 | self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
224 | # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1
225 |
226 | # Head
227 | in_channels = block_args.output_filters # output of final block
228 | out_channels = round_filters(1280, self._global_params)
229 | Conv2d = get_same_padding_conv2d(image_size=image_size)
230 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
231 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
232 |
233 | # Final linear layer
234 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
235 | # self._dropout = nn.Dropout(self._global_params.dropout_rate)
236 | # self._fc = nn.Linear(out_channels, self._global_params.num_classes)
237 | # building classifier
238 | if self.task_mode in ['class', 'multi']:
239 | self.classifier_ = nn.Sequential(
240 | nn.Dropout(self._global_params.dropout_rate),
241 | nn.Linear(out_channels, self._global_params.num_classes),
242 | )
243 | if self.task_mode in ['regress', 'multi']:
244 | self.regressioner_ = nn.Sequential(
245 | nn.Dropout(self._global_params.dropout_rate),
246 | nn.Linear(out_channels, 1),
247 | )
248 | if self.task_mode in ['regress_rank_ordinal',]:
249 | self.regressioner_ = nn.Sequential(
250 | nn.Dropout(self._global_params.dropout_rate),
251 | nn.Linear(out_channels, (self._global_params.num_classes - 1) * 2),
252 | )
253 | if self.task_mode in ['regress_rank_dorn', ]:
254 | self.regressioner_ = nn.Sequential(
255 | nn.Dropout(self._global_params.dropout_rate),
256 | nn.Linear(out_channels, (self._global_params.num_classes - 1) * 2),
257 | )
258 | self.ordinal_regression = OrdinalRegressionLayer()
259 | self._swish = MemoryEfficientSwish()
260 |
261 | # weight initialization
262 | for m in self.modules():
263 | if isinstance(m, nn.Conv2d):
264 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
265 | if m.bias is not None:
266 | nn.init.zeros_(m.bias)
267 | elif isinstance(m, nn.BatchNorm2d):
268 | nn.init.ones_(m.weight)
269 | nn.init.zeros_(m.bias)
270 | elif isinstance(m, nn.Linear):
271 | nn.init.normal_(m.weight, 0, 0.01)
272 | nn.init.zeros_(m.bias)
273 |
274 | def set_swish(self, memory_efficient=True):
275 | """Sets swish function as memory efficient (for training) or standard (for export).
276 |
277 | Args:
278 | memory_efficient (bool): Whether to use memory-efficient version of swish.
279 |
280 | """
281 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
282 | for block in self._blocks:
283 | block.set_swish(memory_efficient)
284 |
285 | def extract_endpoints(self, inputs):
286 | """Use convolution layer to extract features
287 | from reduction levels i in [1, 2, 3, 4, 5].
288 |
289 | Args:
290 | inputs (tensor): Input tensor.
291 |
292 | Returns:
293 | Dictionary of last intermediate features
294 | with reduction levels i in [1, 2, 3, 4, 5].
295 | Example:
296 | # >>> import torch
297 | # >>> from efficientnet.model import EfficientNet
298 | # >>> inputs = torch.rand(1, 3, 224, 224)
299 | # >>> model = EfficientNet.from_pretrained('efficientnet-b0')
300 |                 # >>> endpoints = model.extract_endpoints(inputs)
301 | # >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112])
302 | # >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56])
303 | # >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28])
304 | # >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14])
305 | # >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 1280, 7, 7])
306 | """
307 | endpoints = dict()
308 |
309 | # Stem
310 | x = self._swish(self._bn0(self._conv_stem(inputs)))
311 | prev_x = x
312 |
313 | # Blocks
314 | for idx, block in enumerate(self._blocks):
315 | drop_connect_rate = self._global_params.drop_connect_rate
316 | if drop_connect_rate:
317 |                 drop_connect_rate *= float(idx) / len(self._blocks) # scale drop_connect_rate
318 | x = block(x, drop_connect_rate=drop_connect_rate)
319 | if prev_x.size(2) > x.size(2):
320 | endpoints[f'reduction_{len(endpoints) + 1}'] = prev_x
321 | prev_x = x
322 |
323 | # Head
324 | x = self._swish(self._bn1(self._conv_head(x)))
325 | endpoints[f'reduction_{len(endpoints) + 1}'] = x
326 |
327 | return endpoints
328 |
329 | def extract_features(self, inputs):
330 |         """Use convolution layers to extract features.
331 |
332 | Args:
333 | inputs (tensor): Input tensor.
334 |
335 | Returns:
336 | Output of the final convolution
337 | layer in the efficientnet model.
338 | """
339 | # Stem
340 | x = self._swish(self._bn0(self._conv_stem(inputs)))
341 |
342 | # Blocks
343 | for idx, block in enumerate(self._blocks):
344 | drop_connect_rate = self._global_params.drop_connect_rate
345 | if drop_connect_rate:
346 |                 drop_connect_rate *= float(idx) / len(self._blocks) # scale drop_connect_rate
347 | x = block(x, drop_connect_rate=drop_connect_rate)
348 |
349 | # Head
350 | x = self._swish(self._bn1(self._conv_head(x)))
351 |
352 | return x
353 |
354 | def forward(self, inputs):
355 | """EfficientNet's forward function.
356 | Calls extract_features to extract features, applies final linear layer, and returns logits.
357 |
358 | Args:
359 | inputs (tensor): Input tensor.
360 |
361 | Returns:
362 | Output of this model after processing.
363 | """
364 | # Convolution layers
365 | x = self.extract_features(inputs)
366 |
367 | # Pooling and final linear layer
368 | x = self._avg_pooling(x)
369 | x = x.flatten(start_dim=1)
370 | # x = self._dropout(x)
371 | # x = self._fc(x)
372 | # return x
373 |
374 | if self.task_mode == 'class':
375 | c_out = self.classifier_(x)
376 | return c_out
377 | elif self.task_mode == 'regress':
378 | r_out = self.regressioner_(x)
379 | return r_out[:, 0]
380 | elif self.task_mode == 'regress_rank_ordinal':
381 | r_out = self.regressioner_(x)
382 | r_out = r_out.view(-1, (self._global_params.num_classes - 1), 2)
383 | probas = F.softmax(r_out, dim=2)[:, :, 1]
384 | return r_out, probas
385 | elif self.task_mode in ['regress_rank_dorn', ]:
386 | r_out = self.regressioner_(x)
387 | predicts, softmax = self.ordinal_regression(r_out)
388 | return predicts, softmax
389 | elif self.task_mode == 'multi':
390 | c_out = self.classifier_(x)
391 | r_out = self.regressioner_(x)
392 | return c_out, r_out[:, 0]
393 |         else:
394 |             raise ValueError(f'Unsupported task_mode: {self.task_mode}; expected one of '
395 |                              f'[multi, class, regress, regress_rank_ordinal, regress_rank_dorn].')
396 |
397 | @classmethod
398 | def from_name(cls, task_mode, model_name, in_channels=3, **override_params):
399 |         """Create an EfficientNet model according to name.
400 | 
401 |         Args:
402 |             task_mode (str): class, multi, regress, regress_rank_ordinal, regress_rank_dorn
403 |             model_name (str): Name for efficientnet.
404 |             in_channels (int): Input data's channel number.
405 |             override_params (other keyword params):
406 | Params to override model's global_params.
407 | Optional key:
408 | 'width_coefficient', 'depth_coefficient',
409 | 'image_size', 'dropout_rate',
410 | 'num_classes', 'batch_norm_momentum',
411 | 'batch_norm_epsilon', 'drop_connect_rate',
412 | 'depth_divisor', 'min_depth'
413 |
414 | Returns:
415 | An efficientnet model.
416 | """
417 | cls._check_model_name_is_valid(model_name)
418 | blocks_args, global_params = get_model_params(model_name, override_params)
419 | model = cls(task_mode, blocks_args, global_params)
420 | model._change_in_channels(in_channels)
421 | return model
422 |
423 | @classmethod
424 | def from_pretrained(cls, task_mode, model_name, weights_path=None, advprop=False,
425 | in_channels=3, num_classes=1000, **override_params):
426 |         """Create a pretrained EfficientNet model according to name.
427 | 
428 |         Args:
429 |             task_mode (str): class, multi, regress, regress_rank_ordinal, regress_rank_dorn
430 | model_name (str): Name for efficientnet.
431 | weights_path (None or str):
432 | str: path to pretrained weights file on the local disk.
433 | None: use pretrained weights downloaded from the Internet.
434 | advprop (bool):
435 | Whether to load pretrained weights
436 | trained with advprop (valid when weights_path is None).
437 | in_channels (int): Input data's channel number.
438 | num_classes (int):
439 | Number of categories for classification.
440 | It controls the output size for final linear layer.
441 |             override_params (other keyword params):
442 | Params to override model's global_params.
443 | Optional key:
444 | 'width_coefficient', 'depth_coefficient',
445 | 'image_size', 'dropout_rate',
446 | 'num_classes', 'batch_norm_momentum',
447 | 'batch_norm_epsilon', 'drop_connect_rate',
448 | 'depth_divisor', 'min_depth'
449 |
450 | Returns:
451 | A pretrained efficientnet model.
452 | """
453 | model = cls.from_name(task_mode, model_name, num_classes=num_classes, **override_params)
454 | load_pretrained_weights(model, model_name, weights_path=weights_path, load_fc=(num_classes == 1000),
455 | advprop=advprop)
456 | model._change_in_channels(in_channels)
457 | return model
458 |
459 | @classmethod
460 | def get_image_size(cls, model_name):
461 | """Get the input image size for a given efficientnet model.
462 |
463 | Args:
464 | model_name (str): Name for efficientnet.
465 |
466 | Returns:
467 | Input image size (resolution).
468 | """
469 | cls._check_model_name_is_valid(model_name)
470 | _, _, res, _ = efficientnet_params(model_name)
471 | return res
472 |
473 | @classmethod
474 | def _check_model_name_is_valid(cls, model_name):
475 | """Validates model name.
476 |
477 | Args:
478 | model_name (str): Name for efficientnet.
479 |
480 | Returns:
481 | bool: Is a valid name or not.
482 | """
483 | valid_models = ['efficientnet-b' + str(i) for i in range(9)]
484 |
485 | # Support the construction of 'efficientnet-l2' without pretrained weights
486 | valid_models += ['efficientnet-l2']
487 |
488 | if model_name not in valid_models:
489 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
490 |
491 | def _change_in_channels(self, in_channels):
492 |         """Adjust model's first convolution layer to in_channels, if in_channels does not equal 3.
493 |
494 | Args:
495 | in_channels (int): Input data's channel number.
496 | """
497 | if in_channels != 3:
498 | Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
499 | out_channels = round_filters(32, self._global_params)
500 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
501 |
502 |
503 | def jl_efficientnet(task_mode='class', pretrained=True, num_classes=4, **kwargs):
504 | """
505 |     Joint-learning EfficientNet
506 | 
507 |     Args:
508 |         task_mode (string): multi, class, regress, regress_rank_ordinal, regress_rank_dorn
509 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
510 |         num_classes (int): number of classes or number of output nodes
511 | """
512 | func = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
513 | model = func(task_mode=task_mode, model_name='efficientnet-b0', num_classes=num_classes)
514 | return model
515 |
516 |
517 | # def _test():
518 | # net = jl_efficientnet(task_mode='REGRESS_rank_ordinal', pretrained=True, num_classes=4).cuda()
519 | # y_class, y_regres = net(torch.randn(48, 3, 224, 224).cuda())
520 | # print(y_class.size(), y_regres.size())
521 | # # y_class = net(torch.randn(48, 3, 224, 224).cuda())
522 | # # print(y_class.size())
523 | #
524 | # # model = net.cuda()
525 | # # summary(model, (3, 224, 224))
526 | # _test()
527 |
528 | def label_to_levels(label, num_classes=4):
529 | levels = [1] * label + [0] * (num_classes - 1 - label)
530 | levels = torch.tensor(levels, dtype=torch.float32)
531 | return levels
532 |
533 |
534 | def labels_to_labels(class_labels):
535 | """
536 |     Convert a batch of class labels (e.g. class_labels = [2, 1, 3]) into stacked level encodings.
537 | """
538 | levels = []
539 | for label in class_labels:
540 | levels_from_label = label_to_levels(int(label), num_classes=4)
541 | levels.append(levels_from_label)
542 | return torch.stack(levels).cuda()
543 |
544 |
545 | def cost_fn(logits, label):
546 | num_classes = 4
547 | imp = torch.ones(num_classes - 1, dtype=torch.float).cuda()
548 | levels = labels_to_labels(label)
549 | val = (-torch.sum((F.log_softmax(logits, dim=2)[:, :, 1] * levels
550 | + F.log_softmax(logits, dim=2)[:, :, 0] * (1 - levels)) * imp, dim=1))
551 | return torch.mean(val)
552 |
553 |
554 | def loss_fn2(logits, label):
555 | num_classes = 4
556 |     imp = torch.ones(num_classes - 1, dtype=torch.float).cuda()
557 | levels = labels_to_labels(label)
558 | val = (-torch.sum((F.logsigmoid(logits) * levels
559 | + (F.logsigmoid(logits) - logits) * (1 - levels)) * imp,
560 | dim=1))
561 | return torch.mean(val)
562 |
--------------------------------------------------------------------------------
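`label_to_levels` above produces a cumulative (CORAL-style) binary encoding: the first `label` thresholds are on, the rest off, and `cost_fn` / `loss_fn2` then score those `num_classes - 1` binary decisions. A quick CPU check of the encoding (the repo's helpers move tensors to CUDA; this hypothetical copy drops that for illustration):

```python
import torch

def label_to_levels_cpu(label, num_classes=4):
    # Cumulative encoding: grade k turns on the first k of (num_classes - 1) thresholds.
    return torch.tensor([1] * label + [0] * (num_classes - 1 - label), dtype=torch.float32)

for grade in range(4):
    print(grade, label_to_levels_cpu(grade).tolist())
# 0 [0.0, 0.0, 0.0]
# 1 [1.0, 0.0, 0.0]
# 2 [1.0, 1.0, 0.0]
# 3 [1.0, 1.0, 1.0]
```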
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | torch~=1.8.1+cu111
5 | numpy~=1.20.2
6 | opencv-python~=4.5.1.48
7 | scikit-learn~=0.24.1
8 | torchvision~=0.9.1+cu111
9 | matplotlib~=3.3.4
10 | pandas~=1.2.4
11 | imgaug~=0.4.0
12 | termcolor~=1.1.0
13 | scipy~=1.6.2
14 | torchsummary~=1.5.1
15 | pytorch-ignite~=0.4.4
16 | tensorboardx~=2.2
--------------------------------------------------------------------------------
/scheduler_lr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/trinhvg/JCO_Learning-pytorch/6a6044801ac0a2f6ea2c479d4417e1128e5a0b72/scheduler_lr/__init__.py
--------------------------------------------------------------------------------
/scheduler_lr/warmup_cosine_lr.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 | import torch
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | class GradualWarmupScheduler(_LRScheduler):
8 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
9 | self.multiplier = multiplier
10 | self.total_epoch = total_epoch
11 | self.after_scheduler = after_scheduler
12 | self.finished = False
13 | super().__init__(optimizer)
14 |
15 | def get_lr(self):
16 | if self.last_epoch > self.total_epoch:
17 | if self.after_scheduler:
18 | if not self.finished:
19 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
20 | self.finished = True
21 | return self.after_scheduler.get_lr()
22 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
23 |
24 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
25 | self.base_lrs]
26 |
27 | def step(self, epoch=None, metrics=None):
28 | if self.finished and self.after_scheduler:
29 | if epoch is None:
30 | self.after_scheduler.step(None)
31 | else:
32 | self.after_scheduler.step(epoch - self.total_epoch)
33 | else:
34 | return super(GradualWarmupScheduler, self).step(epoch)
35 |
36 |
37 | if __name__ == '__main__':
38 | v = torch.zeros(10)
39 | optim = torch.optim.SGD([v], lr=0.01)
40 | cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 100, eta_min=0, last_epoch=-1)
41 | scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=5, after_scheduler=cosine_scheduler)
42 | a = []
43 | b = []
44 | for epoch in range(1, 100):
45 | scheduler.step(epoch)
46 | a.append(epoch)
47 | b.append(optim.param_groups[0]['lr'])
48 | print(epoch, optim.param_groups[0]['lr'])
49 |
50 | plt.plot(a, b)
51 | plt.show()
--------------------------------------------------------------------------------
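During warm-up, `get_lr` scales each base learning rate linearly, `lr = base_lr * ((multiplier - 1) * last_epoch / total_epoch + 1)`, reaching `multiplier * base_lr` at `total_epoch` before handing off to the wrapped cosine scheduler (whose base LRs are rescaled by `multiplier` once). With the settings in the `__main__` demo (`base_lr = 0.01`, `multiplier = 8`, `total_epoch = 5`), epoch 3 gives `0.01 * (7 * 3 / 5 + 1) = 0.052`.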
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/trinhvg/JCO_Learning-pytorch/6a6044801ac0a2f6ea2c479d4417e1128e5a0b72/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/run_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #for run_info in CLASS_ce MULTI_ce_mse_ceo MULTI_ce_mse MULTI_ce_mae MULTI_ce_mae_ceo REGRESS_mae REGRESS_mse
3 | #for run_info in MULTI_ce_mse MULTI_ce_mae REGRESS_mae REGRESS_mse CLASS_FocalLoss MULTI_mtmr
4 | #for run_info in REGRESS_rank_ordinal REGRESS_FocalOrdinalLoss REGRESS_rank_dorn REGRESS_soft_ordinal REGRESS
5 |
6 | for run_info in MULTI_ce_mse_ceo MULTI_ce_mse MULTI_ce_mae MULTI_ce_mae_ceo
7 | do
8 | python train_val.py --run_info ${run_info} --seed 5 --gpu 0,1
9 | done
10 |
11 |
--------------------------------------------------------------------------------
/train_val.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | import torch.utils.data as data
6 |
7 | from ignite.contrib.handlers import ProgressBar
8 | from ignite.engine import Engine, Events
9 | from ignite.handlers import ModelCheckpoint, Timer
10 | from ignite.metrics import RunningAverage
11 | from tensorboardX import SummaryWriter
12 | from imgaug import augmenters as iaa
13 | from misc.train_ultils_all_iter import *
14 | import importlib
15 |
16 |
17 | from loss.mtmr_loss import get_loss_mtmr
18 | from loss.rank_ordinal_loss import cost_fn
19 | from loss.dorn_loss import OrdinalLoss
20 | import dataset as dataset
21 | from config import Config
22 | from loss.ceo_loss import CEOLoss, FocalLoss, SoftLabelOrdinalLoss, FocalOrdinalLoss, count_pred
23 |
24 |
25 | ####
26 |
27 | class Trainer(Config):
28 | def __init__(self, _args=None):
29 | super(Trainer, self).__init__(_args=_args)
30 | if _args is not None:
31 | self.__dict__.update(_args.__dict__)
32 | print(self.run_info)
33 |
34 | ####
35 | def view_dataset(self, mode='train'):
36 | train_pairs, valid_pairs = getattr(dataset, ('prepare_%s_data' % self.dataset))()
37 | if mode == 'train':
38 | train_augmentors = self.train_augmentors()
39 | ds = dataset.DatasetSerial(train_pairs, has_aux=False,
40 | shape_augs=iaa.Sequential(train_augmentors[0]),
41 | input_augs=iaa.Sequential(train_augmentors[1]))
42 | else:
43 | infer_augmentors = self.infer_augmentors() # HACK
44 | ds = dataset.DatasetSerial(valid_pairs, has_aux=False,
45 |                                        shape_augs=iaa.Sequential(infer_augmentors[0]))
46 | dataset.visualize(ds, 4)
47 | return
48 |
49 | ####
50 | def train_step(self, engine, net, batch, iters, scheduler, optimizer, device):
51 | net.train() # train mode
52 |
53 | imgs_cpu, true_cpu = batch
54 | imgs_cpu = imgs_cpu.permute(0, 3, 1, 2) # to NCHW
55 | scheduler.step(engine.state.epoch + engine.state.iteration / iters) # scheduler.step(epoch + i / iters)
56 | # push data to GPUs
57 | imgs = imgs_cpu.to(device).float()
58 | true = true_cpu.to(device).long() # not one-hot
59 |
60 | # -----------------------------------------------------------
61 | net.zero_grad() # not rnn so not accumulate
62 |         out_net = net(imgs)  # a list containing all the outputs of the network
63 | loss = 0.
64 |
65 | # assign output
66 | if "CLASS" in self.task_type:
67 | logit_class = out_net
68 | if "REGRESS" in self.task_type:
69 | if ("rank_ordinal" in self.loss_type) or ("dorn" in self.loss_type):
70 | logit_regress, probas = out_net[0], out_net[1]
71 | else:
72 | logit_regress = out_net
73 | if "MULTI" in self.task_type:
74 | logit_class, logit_regress = out_net[0], out_net[1]
75 |
76 | # compute loss function
77 | if "ce" in self.loss_type:
78 | prob = F.softmax(logit_class, dim=-1)
79 | loss_entropy = F.cross_entropy(logit_class, true, reduction='mean')
80 | pred = torch.argmax(prob, dim=-1)
81 | loss += loss_entropy
82 | if 'FocalLoss' in self.loss_type:
83 | loss_focal = FocalLoss()(logit_class, true)
84 | prob = F.softmax(logit_class, dim=-1)
85 | pred = torch.argmax(prob, dim=-1)
86 | loss += loss_focal
87 |
88 | if "mse" in self.loss_type:
89 | criterion = torch.nn.MSELoss()
90 | loss_regres = criterion(logit_regress, true.float())
91 | loss += loss_regres
92 | if "REGRESS" in self.task_type:
93 | label = torch.tensor([0., 1., 2., 3.]).repeat(len(true), 1).permute(1, 0).cuda()
94 | pred = torch.argmin(torch.abs(logit_regress - label), 0)
95 | if "mae" in self.loss_type:
96 | criterion = torch.nn.L1Loss()
97 | loss_regres = criterion(logit_regress, true.float())
98 | loss += loss_regres
99 | if "REGRESS" in self.task_type:
100 | label = torch.tensor([0., 1., 2., 3.]).repeat(len(true), 1).permute(1, 0).cuda()
101 | pred = torch.argmin(torch.abs(logit_regress - label), 0)
102 | if "soft_label" in self.loss_type:
103 | criterion = SoftLabelOrdinalLoss(alpha=self.alpha)
104 | loss_regres = criterion(logit_regress, true.float())
105 | loss += loss_regres
106 | if "REGRESS" in self.task_type:
107 | label = torch.tensor([0., 1 / 3, 2 / 3, 1.]).repeat(len(true), 1).permute(1, 0).cuda()
108 | pred = torch.argmin(torch.abs(logit_regress - label), 0)
109 | if "FocalOrdinal" in self.loss_type:
110 | criterion = FocalOrdinalLoss(pooling=True)
111 | loss_regres = criterion(logit_regress, true.float())
112 | loss += loss_regres
113 | pred = count_pred(logit_regress)
114 | if "ceo" in self.loss_type:
115 | criterion = CEOLoss(num_classes=self.nr_classes)
116 | loss_ordinal = criterion(logit_regress, true)
117 | loss += loss_ordinal
118 | if "mtmr" in self.loss_type:
119 | loss = get_loss_mtmr(logit_class, logit_regress, true, true)
120 | prob = F.softmax(logit_class, dim=-1)
121 | pred = torch.argmax(prob, dim=-1)
122 |         if "rank_ordinal" in self.loss_type:
123 | loss = cost_fn(logit_regress, true)
124 | predict_levels = probas > 0.5
125 | pred = torch.sum(predict_levels, dim=1)
126 | if "rank_dorn" in self.loss_type:
127 |             pred, softmax = logit_regress, probas  # reuse the forward outputs computed above
128 | loss = OrdinalLoss()(softmax, true)
129 |
130 | acc = torch.mean((pred == true).float()) # batch accuracy
131 | # gradient update
132 | loss.backward()
133 | optimizer.step()
134 |
135 | # -----------------------------------------------------------
136 | return dict(
137 | loss=loss.item(),
138 | acc=acc.item(),
139 | )
140 |
141 | ####
142 | def infer_step(self, net, batch, device):
143 | net.eval() # infer mode
144 |
145 | imgs, true = batch
146 | imgs = imgs.permute(0, 3, 1, 2) # to NCHW
147 |
148 | # push data to GPUs and convert to float32
149 | imgs = imgs.to(device).float()
150 | true = true.to(device).long() # not one-hot
151 |
152 | # -----------------------------------------------------------
153 | with torch.no_grad(): # dont compute gradient
154 |             out_net = net(imgs)  # a list containing all the outputs of the network
155 | if "CLASS" in self.task_type:
156 | logit_class = out_net
157 | prob = nn.functional.softmax(logit_class, dim=-1)
158 | return dict(logit_c=prob.cpu().numpy(), # from now prob of class task is called by logit_c
159 | true=true.cpu().numpy())
160 |
161 | if "REGRESS" in self.task_type:
162 | if "rank_ordinal" in self.loss_type:
163 | logits, probas = out_net[0], out_net[1]
164 | predict_levels = probas > 0.5
165 | pred = torch.sum(predict_levels, dim=1)
166 | return dict(logit_r=pred.cpu().numpy(),
167 | true=true.cpu().numpy())
168 | if "rank_dorn" in self.loss_type:
169 |                 pred, softmax = out_net[0], out_net[1]  # reuse the forward pass above
170 | return dict(logit_r=pred.cpu().numpy(),
171 | true=true.cpu().numpy())
172 | if "soft_label" in self.loss_type:
173 |                 logit_regress = (self.nr_classes - 1) * out_net  # rescale the [0, 1] soft-label output to the class range
174 | return dict(logit_r=logit_regress.cpu().numpy(),
175 | true=true.cpu().numpy())
176 | if "FocalOrdinal" in self.loss_type:
177 | logit_regress = out_net
178 | pred = count_pred(logit_regress)
179 | return dict(logit_r=pred.cpu().numpy(),
180 | true=true.cpu().numpy())
181 | else:
182 | logit_regress = out_net
183 | return dict(logit_r=logit_regress.cpu().numpy(),
184 | true=true.cpu().numpy())
185 |
186 | if "MULTI" in self.task_type:
187 | logit_class, logit_regress = out_net[0], out_net[1]
188 | prob = nn.functional.softmax(logit_class, dim=-1)
189 | return dict(logit_c=prob.cpu().numpy(),
190 | logit_r=logit_regress.cpu().numpy(),
191 | true=true.cpu().numpy())
192 |
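Downstream aggregation of these dicts happens in `misc/train_ultils_all_iter.py`. Purely as an illustration (a hypothetical standalone snippet, not the project's code), the returned arrays map to hard predictions as follows, mirroring the nearest-anchor rule used in `train_step`:

```
import numpy as np

# stand-in for a batch of infer_step outputs from a MULTI model
out = dict(logit_c=np.random.rand(8, 4), logit_r=3 * np.random.rand(8))

pred_c = np.argmax(out['logit_c'], axis=-1)        # categorical prediction
anchors = np.arange(4, dtype=np.float32)           # class anchors 0..3
pred_r = np.argmin(np.abs(out['logit_r'][:, None] - anchors), axis=-1)  # nearest anchor
```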
193 | ####
194 | def run_once(self, fold_idx):
195 |
196 | log_dir = self.log_dir
197 | check_manual_seed(self.seed)
198 | train_pairs, valid_pairs, test_pairs = getattr(dataset, ('prepare_%s_data' % self.dataset))(fold_idx)
199 | # --------------------------- Dataloader
200 |
201 | train_augmentors = self.train_augmentors()
202 | train_dataset = dataset.DatasetSerial(train_pairs[:], has_aux=False,
203 | shape_augs=iaa.Sequential(train_augmentors[0]),
204 | input_augs=iaa.Sequential(train_augmentors[1]))
205 |
206 | infer_augmentors = self.infer_augmentors() # HACK at has_aux
207 | infer_dataset = dataset.DatasetSerial(valid_pairs[:], has_aux=False,
208 | shape_augs=iaa.Sequential(infer_augmentors[0]))
209 | test_dataset = dataset.DatasetSerial(test_pairs[:], has_aux=False,
210 | shape_augs=iaa.Sequential(infer_augmentors[0]))
211 |
212 | train_loader = data.DataLoader(train_dataset,
213 | num_workers=self.nr_procs_train,
214 | batch_size=self.train_batch_size,
215 | shuffle=True, drop_last=True)
216 | valid_loader = data.DataLoader(infer_dataset,
217 | num_workers=self.nr_procs_valid,
218 | batch_size=self.infer_batch_size,
219 | shuffle=False, drop_last=False)
220 | test_loader = data.DataLoader(test_dataset,
221 | num_workers=self.nr_procs_valid,
222 | batch_size=self.infer_batch_size,
223 | shuffle=False, drop_last=False)
224 |
225 | # --------------------------- Training Sequence
226 |
227 | if self.logging:
228 | check_log_dir(log_dir)
229 |
230 | device = 'cuda'
231 |
232 | # Define your network here
233 |         # Note: this configuration uses EfficientNet-B0
234 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model') # dynamic import
235 | if "FocalOrdinal" in self.loss_type:
236 | net = net_def.jl_efficientnet(task_mode='class', pretrained=True, num_classes=3)
237 |
238 | elif "rank_ordinal" in self.loss_type:
239 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_rank_ordinal') # dynamic import
240 | net = net_def.jl_efficientnet(task_mode='regress_rank_ordinal', pretrained=True)
241 |
242 | elif "mtmr" in self.loss_type:
243 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_mtmr') # dynamic import
244 | net = net_def.jl_efficientnet(task_mode='multi_mtmr', pretrained=True)
245 |
246 | elif "rank_dorn" in self.loss_type:
247 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_rank_ordinal') # dynamic import
248 | net = net_def.jl_efficientnet(task_mode='regress_rank_dorn', pretrained=True)
249 |
250 | else:
251 | net = net_def.jl_efficientnet(task_mode=self.task_type.lower(), pretrained=True)
252 |
253 |
254 | net = torch.nn.DataParallel(net).to(device)
255 | # optimizers
256 | optimizer = optim.Adam(net.parameters(), lr=self.init_lr)
257 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=self.nr_epochs // 3, T_mult=1,
258 | eta_min=self.init_lr * 0.1, last_epoch=-1)
259 |
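For reference, the canonical PyTorch pattern for stepping `CosineAnnealingWarmRestarts` with fractional epochs looks like this (a sketch with made-up sizes; `train_step` above passes `engine.state.epoch + engine.state.iteration / iters` with `iters` the total iteration count rather than iterations per epoch):

```
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
# restarts every T_0 epochs; eta_min is the floor of each cosine cycle
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=20, T_mult=1, eta_min=1e-4)

iters_per_epoch = 100
for epoch in range(3):
    for i in range(iters_per_epoch):
        # ... forward / backward / opt.step() ...
        sched.step(epoch + i / iters_per_epoch)  # fractional-epoch stepping
```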
260 | #
261 | iters = self.nr_epochs * self.epoch_length
262 | trainer = Engine(lambda engine, batch: self.train_step(engine, net, batch, iters, scheduler, optimizer, device))
263 | valider = Engine(lambda engine, batch: self.infer_step(net, batch, device))
264 | test = Engine(lambda engine, batch: self.infer_step(net, batch, device))
265 |
266 | # assign output
267 | if "CLASS" in self.task_type:
268 | infer_output = ['logit_c', 'true']
269 | if "REGRESS" in self.task_type:
270 | infer_output = ['logit_r', 'true']
271 | if "MULTI" in self.task_type:
272 | infer_output = ['logit_c', 'logit_r', 'pred_c', 'pred_r', 'true']
273 |
274 | ##
275 | events = Events.EPOCH_COMPLETED
276 | if self.logging:
277 | @trainer.on(events)
278 | def save_chkpoints(engine):
279 | torch.save(net.state_dict(), self.log_dir + '/_net_' + str(engine.state.iteration) + '.pth')
280 |
281 | timer = Timer(average=True)
282 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
283 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
284 | timer.attach(valider, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
285 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
286 | timer.attach(test, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
287 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
288 |
289 | # attach running average metrics computation
290 | # decay of EMA to 0.95 to match tensorpack default
291 | # TODO: refactor this
292 | RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(trainer, 'acc')
293 | RunningAverage(alpha=0.95, output_transform=lambda x: x['loss']).attach(trainer, 'loss')
294 |
295 | # attach progress bar
296 | pbar = ProgressBar(persist=True)
297 | pbar.attach(trainer, metric_names=['loss'])
298 | pbar.attach(valider)
299 | pbar.attach(test)
300 |
301 | # writer for tensorboard logging
302 | tfwriter = None # HACK temporary
303 | if self.logging:
304 | tfwriter = SummaryWriter(logdir=log_dir)
305 | json_log_file = log_dir + '/stats.json'
306 | with open(json_log_file, 'w') as json_file:
307 | json.dump({}, json_file) # create empty file
308 |
309 | ### TODO refactor again
310 | log_info_dict = {
311 | 'logging': self.logging,
312 | 'optimizer': optimizer,
313 | 'tfwriter': tfwriter,
314 | 'json_file': json_log_file if self.logging else None,
315 | 'nr_classes': self.nr_classes,
316 | 'metric_names': infer_output,
317 | 'infer_batch_size': self.infer_batch_size # too cumbersome
318 | }
319 | trainer.add_event_handler(Events.EPOCH_COMPLETED,
320 | lambda engine: scheduler.step(engine.state.epoch - 1)) # to change the lr
321 | trainer.add_event_handler(Events.EPOCH_COMPLETED, log_train_ema_results, log_info_dict)
322 | trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, valider, 'valid', valid_loader, log_info_dict)
323 | trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, test, 'test', test_loader, log_info_dict)
324 | valider.add_event_handler(Events.ITERATION_COMPLETED, accumulate_outputs)
325 | test.add_event_handler(Events.ITERATION_COMPLETED, accumulate_outputs)
326 |
327 | # Setup is done. Now let's run the training
328 | # trainer.run(train_loader, self.nr_epochs)
329 | trainer.run(train_loader, self.nr_epochs, self.epoch_length)
330 | return
331 |
332 | ####
333 | def run(self):
334 | if self.cross_valid:
335 |             for fold_idx in range(0, self.nr_fold):
336 |                 self.run_once(fold_idx)
337 | else:
338 | self.run_once(self.fold_idx)
339 | return
340 |
341 |
342 | ####
343 | if __name__ == '__main__':
344 | parser = argparse.ArgumentParser()
345 | parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
346 | parser.add_argument('--view', help='view dataset', action='store_true')
347 |     parser.add_argument('--run_info', type=str, default='REGRESS_rank_dorn',
348 |                         help='task (CLASS, REGRESS, MULTI) + loss, '
349 |                              'e.g. MULTI_mtmr, REGRESS_rank_ordinal, REGRESS_rank_dorn, '
350 |                              'REGRESS_FocalOrdinal, REGRESS_soft_label')
351 |     parser.add_argument('--dataset', type=str, default='colon_tma', help='colon_tma, prostate_uhu')
352 |     parser.add_argument('--seed', type=int, default=5, help='random seed')
353 |     parser.add_argument('--alpha', type=int, default=5, help='alpha for SoftLabelOrdinalLoss')
354 | args = parser.parse_args()
355 |
356 | trainer = Trainer(_args=args)
357 | if args.view:
358 | trainer.view_dataset()
359 | exit()
360 | if args.gpu:
361 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
362 | trainer.run()
363 |
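Given the flags defined in the `__main__` block above, a plausible invocation (GPU id and values are placeholders) is:

```
python train_val.py --gpu 0 --run_info REGRESS_rank_dorn --dataset colon_tma --seed 5
```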
--------------------------------------------------------------------------------
/train_val_ceo_for_cancer_only.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | import torch.utils.data as data
6 |
7 | from ignite.contrib.handlers import ProgressBar
8 | from ignite.engine import Engine, Events
9 | from ignite.handlers import ModelCheckpoint, Timer
10 | from ignite.metrics import RunningAverage
11 | from tensorboardX import SummaryWriter
12 | from imgaug import augmenters as iaa
13 | from misc.train_ultils_all_iter import *
14 | from loss.cancer_loss import *
15 | # from misc.train_utils import *
16 | # from misc.focalloss_regression import *
17 |
18 | import importlib
19 | import dataset as dataset
20 | from config import Config
21 | from loss.ceo_loss import CEOLoss, FocalLoss, SoftLabelOrdinalLoss, FocalOrdinalLoss, count_pred
22 |
23 |
24 | ####
25 |
26 | class Trainer(Config):
27 | def __init__(self, _args=None):
28 | super(Trainer, self).__init__(_args=_args)
29 | if _args is not None:
30 | self.__dict__.update(_args.__dict__)
31 | print(self.run_info)
32 |
33 | ####
34 | def view_dataset(self, mode='train'):
35 |         train_pairs, valid_pairs, _ = getattr(dataset, ('prepare_%s_data' % self.dataset))(self.fold_idx)
36 | if mode == 'train':
37 | train_augmentors = self.train_augmentors()
38 | ds = dataset.DatasetSerial(train_pairs, has_aux=False,
39 | shape_augs=iaa.Sequential(train_augmentors[0]),
40 | input_augs=iaa.Sequential(train_augmentors[1]))
41 | else:
42 | infer_augmentors = self.infer_augmentors() # HACK
43 | ds = dataset.DatasetSerial(valid_pairs, has_aux=False,
44 |                                        shape_augs=iaa.Sequential(infer_augmentors[0]))
45 | dataset.visualize(ds, 4)
46 | return
47 |
48 | ####
49 | def train_step(self, engine, net, batch, iters, scheduler, optimizer, device):
50 | net.train() # train mode
51 |
52 | imgs_cpu, true_cpu = batch
53 | imgs_cpu = imgs_cpu.permute(0, 3, 1, 2) # to NCHW
54 | scheduler.step(engine.state.epoch + engine.state.iteration / iters) # scheduler.step(epoch + i / iters)
55 | # push data to GPUs
56 | imgs = imgs_cpu.to(device).float()
57 | true = true_cpu.to(device).long() # not one-hot
58 |
59 | # -----------------------------------------------------------
60 |         net.zero_grad()  # reset gradients each step (no accumulation needed)
61 |         out_net = net(imgs)  # network output: a tensor or a tuple, depending on the task head
62 | loss = 0.
63 |
64 | # assign output
65 | if "CLASS" in self.task_type:
66 | logit_class = out_net
67 | if "REGRESS" in self.task_type:
68 | logit_regress = out_net
69 | if "MULTI" in self.task_type:
70 | logit_class, logit_regress = out_net[0], out_net[1]
71 |
72 | # compute loss function
73 | if "ce" in self.loss_type:
74 | prob = F.softmax(logit_class, dim=-1)
75 | loss_entropy = F.cross_entropy(logit_class, true, reduction='mean')
76 | pred = torch.argmax(prob, dim=-1)
77 | loss += loss_entropy
78 | if 'FocalLoss' in self.loss_type:
79 | loss_focal = FocalLoss()(logit_class, true)
80 | prob = F.softmax(logit_class, dim=-1)
81 | pred = torch.argmax(prob, dim=-1)
82 | loss += loss_focal
83 |
84 | if "mse" in self.loss_type:
85 | loss += mse_cancer_v0(logit_regress, true.float())
86 | # criterion = torch.nn.MSELoss()
87 | # loss_regres = criterion(logit_regress, true.float())
88 | # loss += loss_regres
89 | if "REGRESS" in self.task_type:
90 |                 label = torch.arange(self.nr_classes, dtype=torch.float32, device=true.device).repeat(len(true), 1).permute(1, 0)
91 | pred = torch.argmin(torch.abs(logit_regress - label), 0)
92 | if "mae" in self.loss_type:
93 | loss += mae_cancer_v0(logit_regress, true.float())
94 | # criterion = torch.nn.L1Loss()
95 | # loss_regres = criterion(logit_regress, true.float())
96 | # loss += loss_regres
97 | if "REGRESS" in self.task_type:
98 |                 label = torch.arange(self.nr_classes, dtype=torch.float32, device=true.device).repeat(len(true), 1).permute(1, 0)
99 | pred = torch.argmin(torch.abs(logit_regress - label), 0)
100 |         if "ceo" in self.loss_type:  # CEO loss applied to cancer samples only
101 | loss += ceo_cancer_v0(logit_regress, true)
102 |
103 |         acc = torch.mean((pred == true).float())  # batch accuracy (pred is set by one of the loss branches above)
104 | # gradient update
105 | loss.backward()
106 | optimizer.step()
107 |
108 | # -----------------------------------------------------------
109 | return dict(
110 | loss=loss.item(),
111 | acc=acc.item(),
112 | )
113 |
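`mse_cancer_v0`, `mae_cancer_v0`, and `ceo_cancer_v0` come from `loss/cancer_loss.py` (imported via `*` above). As a purely hypothetical sketch of the "cancer only" idea, the ordinal penalty would be restricted to non-benign samples, along these lines:

```
import torch

def ceo_cancer_sketch(logit_regress: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    # hypothetical: assumes label 0 = benign and labels > 0 = cancer grades;
    # the real ceo_cancer_v0 in loss/cancer_loss.py may differ
    cancer = true > 0
    if cancer.sum() == 0:
        return 0.0 * logit_regress.sum()  # keep the graph alive when a batch has no cancer
    return torch.mean(torch.abs(logit_regress[cancer] - true[cancer].float()))
```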
114 | ####
115 | def infer_step(self, net, batch, device):
116 | net.eval() # infer mode
117 |
118 | imgs, true = batch
119 | imgs = imgs.permute(0, 3, 1, 2) # to NCHW
120 |
121 | # push data to GPUs and convert to float32
122 | imgs = imgs.to(device).float()
123 | true = true.to(device).long() # not one-hot
124 |
125 | # -----------------------------------------------------------
126 |         with torch.no_grad():  # don't compute gradients during inference
127 |             out_net = net(imgs)  # network output: a tensor or a tuple, depending on the task head
128 | if "CLASS" in self.task_type:
129 | logit_class = out_net
130 | prob = nn.functional.softmax(logit_class, dim=-1)
131 |                 return dict(logit_c=prob.cpu().numpy(),  # class probabilities are referred to as logit_c from here on
132 | true=true.cpu().numpy())
133 |
134 | if "REGRESS" in self.task_type:
135 | if "rank_ordinal" in self.loss_type:
136 | logits, probas = out_net[0], out_net[1]
137 | predict_levels = probas > 0.5
138 | pred = torch.sum(predict_levels, dim=1)
139 | return dict(logit_r=pred.cpu().numpy(),
140 | true=true.cpu().numpy())
141 | if "rank_dorn" in self.loss_type:
142 | pred, softmax = net(imgs)
143 | return dict(logit_r=pred.cpu().numpy(),
144 | true=true.cpu().numpy())
145 | if "soft_ordinal" in self.loss_type:
146 |                 logit_regress = (self.nr_classes - 1) * out_net  # rescale the [0, 1] soft-label output to the class range
147 | return dict(logit_r=logit_regress.cpu().numpy(),
148 | true=true.cpu().numpy())
149 | if "FocalOrdinalLoss" in self.loss_type:
150 | logit_regress = out_net
151 | pred = count_pred(logit_regress)
152 | return dict(logit_r=pred.cpu().numpy(),
153 | true=true.cpu().numpy())
154 | else:
155 | logit_regress = out_net
156 | return dict(logit_r=logit_regress.cpu().numpy(),
157 | true=true.cpu().numpy())
158 |
159 | if "MULTI" in self.task_type:
160 | logit_class, logit_regress = out_net[0], out_net[1]
161 | prob = nn.functional.softmax(logit_class, dim=-1)
162 | return dict(logit_c=prob.cpu().numpy(),
163 | logit_r=logit_regress.cpu().numpy(),
164 | true=true.cpu().numpy())
165 |
166 | ####
167 | def run_once(self, fold_idx):
168 |
169 | log_dir = self.log_dir
170 | check_manual_seed(self.seed)
171 | train_pairs, valid_pairs, test_pairs = getattr(dataset, ('prepare_%s_data' % self.dataset))(fold_idx)
172 | # --------------------------- Dataloader
173 |
174 | train_augmentors = self.train_augmentors()
175 | train_dataset = dataset.DatasetSerial(train_pairs[:], has_aux=False,
176 | shape_augs=iaa.Sequential(train_augmentors[0]),
177 | input_augs=iaa.Sequential(train_augmentors[1]))
178 |
179 | infer_augmentors = self.infer_augmentors() # HACK at has_aux
180 | infer_dataset = dataset.DatasetSerial(valid_pairs[:], has_aux=False,
181 | shape_augs=iaa.Sequential(infer_augmentors[0]))
182 | test_dataset = dataset.DatasetSerial(test_pairs[:], has_aux=False,
183 | shape_augs=iaa.Sequential(infer_augmentors[0]))
184 |
185 | train_loader = data.DataLoader(train_dataset,
186 | num_workers=self.nr_procs_train,
187 | batch_size=self.train_batch_size,
188 | shuffle=True, drop_last=True)
189 | valid_loader = data.DataLoader(infer_dataset,
190 | num_workers=self.nr_procs_valid,
191 | batch_size=self.infer_batch_size,
192 | shuffle=False, drop_last=False)
193 | test_loader = data.DataLoader(test_dataset,
194 | num_workers=self.nr_procs_valid,
195 | batch_size=self.infer_batch_size,
196 | shuffle=False, drop_last=False)
197 |
198 | # --------------------------- Training Sequence
199 |
200 | if self.logging:
201 | check_log_dir(log_dir)
202 |
203 | device = 'cuda'
204 |
205 | # Define your network here
206 |         # Note: this configuration uses EfficientNet-B0
207 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model') # dynamic import
208 | if "FocalOrdinalLoss" in self.loss_type:
209 | net = net_def.jl_efficientnet(task_mode='class', pretrained=True, num_classes=3)
210 |
211 | elif "rank_ordinal" in self.loss_type:
212 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_rank_ordinal') # dynamic import
213 | net = net_def.jl_efficientnet(task_mode='regress_rank_ordinal', pretrained=True)
214 |
215 | elif "mtmr" in self.loss_type:
216 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_mtmr') # dynamic import
217 | net = net_def.jl_efficientnet(task_mode='multi_mtmr', pretrained=True)
218 |
219 | elif "rank_dorn" in self.loss_type:
220 | net_def = importlib.import_module('model_lib.efficientnet_pytorch.model_rank_ordinal') # dynamic import
221 | net = net_def.jl_efficientnet(task_mode='regress_rank_dorn', pretrained=True)
222 |
223 | else:
224 | net = net_def.jl_efficientnet(task_mode=self.task_type.lower(), pretrained=True)
225 |
226 |
227 | net = torch.nn.DataParallel(net).to(device)
228 | # optimizers
229 | optimizer = optim.Adam(net.parameters(), lr=self.init_lr)
230 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=self.nr_epochs // 3, T_mult=1,
231 | eta_min=self.init_lr * 0.1, last_epoch=-1)
232 |
233 | #
234 | iters = self.nr_epochs * self.epoch_length
235 | trainer = Engine(lambda engine, batch: self.train_step(engine, net, batch, iters, scheduler, optimizer, device))
236 | valider = Engine(lambda engine, batch: self.infer_step(net, batch, device))
237 | test = Engine(lambda engine, batch: self.infer_step(net, batch, device))
238 |
239 | # assign output
240 | if "CLASS" in self.task_type:
241 | infer_output = ['logit_c', 'true']
242 | if "REGRESS" in self.task_type:
243 | infer_output = ['logit_r', 'true']
244 | if "MULTI" in self.task_type:
245 | infer_output = ['logit_c', 'logit_r', 'pred_c', 'pred_r', 'true']
246 |
247 | ##
248 | events = Events.EPOCH_COMPLETED
249 | if self.logging:
250 | @trainer.on(events)
251 | def save_chkpoints(engine):
252 | torch.save(net.state_dict(), self.log_dir + '/_net_' + str(engine.state.iteration) + '.pth')
253 |
254 | timer = Timer(average=True)
255 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
256 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
257 | timer.attach(valider, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
258 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
259 | timer.attach(test, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
260 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
261 |
262 | # attach running average metrics computation
263 | # decay of EMA to 0.95 to match tensorpack default
264 | # TODO: refactor this
265 | RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(trainer, 'acc')
266 | RunningAverage(alpha=0.95, output_transform=lambda x: x['loss']).attach(trainer, 'loss')
267 |
268 | # attach progress bar
269 | pbar = ProgressBar(persist=True)
270 | pbar.attach(trainer, metric_names=['loss'])
271 | pbar.attach(valider)
272 | pbar.attach(test)
273 |
274 | # writer for tensorboard logging
275 | tfwriter = None # HACK temporary
276 | if self.logging:
277 | tfwriter = SummaryWriter(logdir=log_dir)
278 | json_log_file = log_dir + '/stats.json'
279 | with open(json_log_file, 'w') as json_file:
280 | json.dump({}, json_file) # create empty file
281 |
282 | ### TODO refactor again
283 | log_info_dict = {
284 | 'logging': self.logging,
285 | 'optimizer': optimizer,
286 | 'tfwriter': tfwriter,
287 | 'json_file': json_log_file if self.logging else None,
288 | 'nr_classes': self.nr_classes,
289 | 'metric_names': infer_output,
290 | 'infer_batch_size': self.infer_batch_size # too cumbersome
291 | }
292 | trainer.add_event_handler(Events.EPOCH_COMPLETED,
293 | lambda engine: scheduler.step(engine.state.epoch - 1)) # to change the lr
294 | trainer.add_event_handler(Events.EPOCH_COMPLETED, log_train_ema_results, log_info_dict)
295 | trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, valider, 'valid', valid_loader, log_info_dict)
296 | trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, test, 'test', test_loader, log_info_dict)
297 | valider.add_event_handler(Events.ITERATION_COMPLETED, accumulate_outputs)
298 | test.add_event_handler(Events.ITERATION_COMPLETED, accumulate_outputs)
299 |
300 | # Setup is done. Now let's run the training
301 | # trainer.run(train_loader, self.nr_epochs)
302 | trainer.run(train_loader, self.nr_epochs, self.epoch_length)
303 | return
304 |
305 | ####
306 | def run(self):
307 | if self.cross_valid:
308 |             for fold_idx in range(0, self.nr_fold):
309 |                 self.run_once(fold_idx)
310 | else:
311 | self.run_once(self.fold_idx)
312 | return
313 |
314 |
315 | ####
316 | if __name__ == '__main__':
317 | parser = argparse.ArgumentParser()
318 | parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
319 | parser.add_argument('--view', help='view dataset', action='store_true')
320 |     parser.add_argument('--run_info', type=str, default='REGRESS_rank_dorn',
321 |                         help='task (CLASS, REGRESS, MULTI) + loss, '
322 |                              'e.g. MULTI_mtmr, REGRESS_rank_ordinal, REGRESS_rank_dorn, '
323 |                              'REGRESS_FocalOrdinalLoss, REGRESS_soft_ordinal')
324 |     parser.add_argument('--dataset', type=str, default='colon_tma', help='colon_tma, prostate_uhu')
325 |     parser.add_argument('--seed', type=int, default=5, help='random seed')
326 |     parser.add_argument('--alpha', type=int, default=5, help='alpha for SoftLabelOrdinalLoss')
327 | args = parser.parse_args()
328 |
329 | trainer = Trainer(_args=args)
330 | if args.view:
331 | trainer.view_dataset()
332 | exit()
333 | if args.gpu:
334 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
335 | trainer.run()
336 |
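As with train_val.py, a plausible invocation (the run_info string here is assembled from the substring checks in `train_step` and is illustrative, not a documented preset):

```
python train_val_ceo_for_cancer_only.py --gpu 0 --run_info REGRESS_mae_ceo --dataset colon_tma --seed 5
```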
--------------------------------------------------------------------------------