├── results └── Readme.txt ├── LICENSE ├── dataset └── npy_datasets.py ├── loader.py ├── README.md ├── test.py ├── train_continue.py ├── engine.py ├── train.py ├── configs └── config_setting.py ├── utils.py └── models └── HF_UNet.py /results/Readme.txt: -------------------------------------------------------------------------------- 1 | Address of the result file. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 wurenkai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dataset/npy_datasets.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | 6 | 7 | class NPY_datasets(Dataset): 8 | def __init__(self, path_Data, config, train=True,test=False): 9 | super(NPY_datasets, self) 10 | if train: 11 | images_list = os.listdir(path_Data+'train/images/') 12 | masks_list = os.listdir(path_Data+'train/masks/') 13 | self.data = [] 14 | for i in range(len(images_list)): 15 | img_path = path_Data+'train/images/' + images_list[i] 16 | mask_path = path_Data+'train/masks/' + masks_list[i] 17 | self.data.append([img_path, mask_path]) 18 | self.transformer = config.train_transformer 19 | elif test: 20 | images_list = os.listdir(path_Data+'test/images/') 21 | masks_list = os.listdir(path_Data+'test/masks/') 22 | self.data = [] 23 | for i in range(len(images_list)): 24 | img_path = path_Data+'test/images/' + images_list[i] 25 | mask_path = path_Data+'test/masks/' + masks_list[i] 26 | self.data.append([img_path, mask_path]) 27 | self.transformer = config.test_transformer 28 | else: 29 | images_list = os.listdir(path_Data+'val/images/') 30 | masks_list = os.listdir(path_Data+'val/masks/') 31 | self.data = [] 32 | for i in range(len(images_list)): 33 | img_path = path_Data+'val/images/' + images_list[i] 34 | mask_path = path_Data+'val/masks/' + masks_list[i] 35 | self.data.append([img_path, mask_path]) 36 | self.transformer = config.val_transformer 37 | 38 | def __getitem__(self, indx): 39 | img_path, msk_path = self.data[indx] 40 | img = np.array(Image.open(img_path).convert('RGB')) 41 | msk = np.expand_dims(np.array(Image.open(msk_path).convert('L')), axis=2) / 255 42 | img, msk = 
self.transformer((img, msk)) 43 | return img, msk 44 | 45 | def __len__(self): 46 | return len(self.data) 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /loader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import torch 3 | import numpy as np 4 | import random 5 | import os 6 | from PIL import Image 7 | from einops.layers.torch import Rearrange 8 | from scipy.ndimage.morphology import binary_dilation 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | from scipy import ndimage 12 | from utils import * 13 | 14 | 15 | # ===== normalize over the dataset 16 | def dataset_normalized(imgs): 17 | imgs_normalized = np.empty(imgs.shape) 18 | imgs_std = np.std(imgs) 19 | imgs_mean = np.mean(imgs) 20 | imgs_normalized = (imgs-imgs_mean)/imgs_std 21 | for i in range(imgs.shape[0]): 22 | imgs_normalized[i] = ((imgs_normalized[i] - np.min(imgs_normalized[i])) / (np.max(imgs_normalized[i])-np.min(imgs_normalized[i])))*255 23 | return imgs_normalized 24 | 25 | 26 | ## Temporary 27 | class isic_loader(Dataset): 28 | """ 29 | """ 30 | def __init__(self, path_Data, train = True, Test = False): 31 | super(isic_loader, self) 32 | self.train = train 33 | if train: 34 | self.data = np.load(path_Data+'data_train.npy') 35 | self.mask = np.load(path_Data+'mask_train.npy') 36 | else: 37 | if Test: 38 | self.data = np.load(path_Data+'data_test.npy') 39 | self.mask = np.load(path_Data+'mask_test.npy') 40 | else: 41 | self.data = np.load(path_Data+'data_val.npy') 42 | self.mask = np.load(path_Data+'mask_val.npy') 43 | 44 | self.data = dataset_normalized(self.data) 45 | self.mask = np.expand_dims(self.mask, axis=3) 46 | self.mask = self.mask/255. 47 | 48 | def __getitem__(self, indx): 49 | img = self.data[indx] 50 | seg = self.mask[indx] 51 | if self.train: 52 | if random.random() > 0.5: 53 | img, seg = self.random_rot_flip(img, seg) 54 | if random.random() > 0.5: 55 | img, seg = self.random_rotate(img, seg) 56 | 57 | seg = torch.tensor(seg.copy()) 58 | img = torch.tensor(img.copy()) 59 | img = img.permute( 2, 0, 1) 60 | seg = seg.permute( 2, 0, 1) 61 | 62 | return img, seg 63 | 64 | def random_rot_flip(self,image, label): 65 | k = np.random.randint(0, 4) 66 | image = np.rot90(image, k) 67 | label = np.rot90(label, k) 68 | axis = np.random.randint(0, 2) 69 | image = np.flip(image, axis=axis).copy() 70 | label = np.flip(label, axis=axis).copy() 71 | return image, label 72 | 73 | def random_rotate(self,image, label): 74 | angle = np.random.randint(20, 80) 75 | image = ndimage.rotate(image, angle, order=0, reshape=False) 76 | label = ndimage.rotate(label, angle, order=0, reshape=False) 77 | return image, label 78 | 79 | 80 | 81 | def __len__(self): 82 | return len(self.data) 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
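The `isic_loader` above expects six arrays on disk: `data_train.npy`/`mask_train.npy`, `data_val.npy`/`mask_val.npy` and `data_test.npy`/`mask_test.npy`. The repository's `Prepare_Autooral.py` is not reproduced in this dump, so the snippet below is only a minimal sketch of how such files could be produced from paired image/mask folders; the `train/val/test` `images`/`masks` layout mirrors what `dataset/npy_datasets.py` reads, while the 256×256 size (borrowed from `input_size_h`/`input_size_w` in `config_setting.py`) and the identical image/mask filenames are assumptions, not the authors' exact preprocessing.

```python
# Minimal sketch (not the repository's Prepare_Autooral.py): build the .npy
# files that loader.py's isic_loader reads. Image size and the filename
# convention below are assumptions for illustration only.
import os
import numpy as np
from PIL import Image

def build_split(img_dir, mask_dir, size=(256, 256)):
    """Stack paired RGB images and grayscale masks into two uint8 arrays."""
    imgs, masks = [], []
    for name in sorted(os.listdir(img_dir)):
        img = Image.open(os.path.join(img_dir, name)).convert('RGB').resize(size)
        msk = Image.open(os.path.join(mask_dir, name)).convert('L').resize(size)
        imgs.append(np.array(img))    # (H, W, 3), uint8
        masks.append(np.array(msk))   # (H, W), 0-255; isic_loader divides by 255
    return np.stack(imgs), np.stack(masks)

if __name__ == '__main__':
    root = 'dataset/Autooral_dataset/'   # matches data_path in config_setting.py
    for split in ('train', 'val', 'test'):
        data, mask = build_split(os.path.join(root, split, 'images'),
                                 os.path.join(root, split, 'masks'))
        np.save(os.path.join(root, f'data_{split}.npy'), data)
        np.save(os.path.join(root, f'mask_{split}.npy'), mask)
```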
2 |
4 | Chenghao Jiang, Renkai Wu, Yinghao Liu, Yue Wang, Qing Chang, Pengchen Liang*, and Yuan Fan* 5 |
6 |7 | 1. Nanjing Medical University, Nanjing, China 8 | 2. Shanghai University, Shanghai, China 9 | 3. The Affiliated Stomatological Hospital of Nanjing Medical University, Nanjing, China 10 | 4. Jiangsu Province Engineering Research Center of Stomatological Translational Medicine, Nanjing, China 11 | 5. University of Shanghai for Science and Technology, Shanghai, China 12 |
13 | 14 | 15 | ## Examples of proposed dataset segmentation tasks 16 | 17 | https://github.com/wurenkai/HF-UNet-and-Autooral-dataset/assets/124028634/de265270-450c-4395-bb97-0cf9d048c601 18 | 19 | 20 | ## Examples of proposed dataset classification tasks 21 | 22 | https://github.com/wurenkai/HF-UNet-and-Autooral-dataset/assets/124028634/94c56a9f-0130-4a96-8cfa-a2e5048e504b 23 | 24 | **0. Main Environments.** 25 | - python 3.8 26 | - pytorch 1.12.0 27 | 28 | **1. The proposed datasets (Autooral dataset).** 29 | (1) The Autooral dataset is available [here](https://drive.google.com/file/d/1n29L25N4H0XFfyWxle95PqS6lyQwNv64/view?usp=sharing). It should be noted: 30 | 1. If you use the dataset, please cite the paper: https://www.nature.com/articles/s41598-024-69125-9 31 | 2. The Autooral dataset may only be used for academic research, not for commercial purposes. 32 | 3. If you can, please give us a like (Starred) for our GitHub project: https://github.com/wurenkai/HF-UNet-and-Autooral-dataset 33 | 34 | (2) After getting the Autooral dataset, execute 'Prepare_Autooral.py' for preprocessing to generate the npy file. We also provide annotations for categorization to provide more richness to the study. 35 | 36 | **2. Train the HF-UNet.** 37 | Modify the dataset address in the config_setting.py file to the address where the npy is stored after preprocessing. Then, perform the following operation: 38 | ``` 39 | python train.py 40 | ``` 41 | - After trianing, you could obtain the outputs in './results/' 42 | 43 | **3. Test the HF-UNet.** 44 | First, in the test.py file, you should change the address of the checkpoint in 'resume_model' and fill in the location of the test data in 'data_path'. 45 | ``` 46 | python test.py 47 | ``` 48 | - After testing, you could obtain the outputs in './results/' 49 | 50 | ## Citation 51 | If you find this repository helpful, please consider citing: 52 | ``` 53 | @article{jiang2024high, 54 | title={A high-order focus interaction model and oral ulcer dataset for oral ulcer segmentation}, 55 | author={Jiang, Chenghao and Wu, Renkai and Liu, Yinghao and Wang, Yue and Chang, Qing and Liang, Pengchen and Fan, Yuan}, 56 | journal={Scientific Reports}, 57 | volume={14}, 58 | number={1}, 59 | pages={20085}, 60 | year={2024}, 61 | publisher={Nature Publishing Group UK London} 62 | } 63 | ``` 64 | 65 | ## References 66 | - [MHorUNet](https://github.com/wurenkai/MHorUNet) 67 | - [HorNet](https://github.com/raoyongming/HorNet) 68 | --- 69 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.cuda.amp import autocast, GradScaler 4 | from torch.utils.data import DataLoader 5 | from loader import * 6 | 7 | from models.HF_UNet import HFUNet 8 | from dataset.npy_datasets import NPY_datasets 9 | from engine import * 10 | import os 11 | import sys 12 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3" 13 | 14 | from utils import * 15 | from configs.config_setting import setting_config 16 | 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | 21 | 22 | def main(config): 23 | 24 | print('#----------Creating logger----------#') 25 | sys.path.append(config.work_dir + '/') 26 | log_dir = os.path.join(config.work_dir, 'log') 27 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints') 28 | resume_model = os.path.join(r'results/HF_UNet.pth') 29 | outputs = os.path.join(config.work_dir, 
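# NOTE (README, step 3): before running `python test.py`, point resume_model
# above at the checkpoint you want to evaluate and data_path below at the
# preprocessed test data; adjust both if your locations differ from the defaults.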
'outputs') 30 | if not os.path.exists(checkpoint_dir): 31 | os.makedirs(checkpoint_dir) 32 | if not os.path.exists(outputs): 33 | os.makedirs(outputs) 34 | 35 | global logger 36 | logger = get_logger('train', log_dir) 37 | 38 | log_config_info(config, logger) 39 | 40 | 41 | 42 | 43 | 44 | print('#----------GPU init----------#') 45 | set_seed(config.seed) 46 | gpu_ids = [0]# [0, 1, 2, 3] 47 | torch.cuda.empty_cache() 48 | 49 | 50 | 51 | 52 | 53 | 54 | print('#----------Preparing dataset----------#') 55 | data_path = './dataset/Autooral_dataset' 56 | test_dataset = isic_loader(path_Data = data_path, train = False, Test = True) 57 | test_loader = DataLoader(test_dataset, 58 | batch_size=1, 59 | shuffle=False, 60 | pin_memory=True, 61 | num_workers=config.num_workers, 62 | drop_last=True) 63 | 64 | 65 | 66 | 67 | 68 | print('#----------Prepareing Models----------#') 69 | model_cfg = config.model_config 70 | model = HFUNet(num_classes=model_cfg['num_classes'], 71 | input_channels=model_cfg['input_channels'], 72 | c_list=model_cfg['c_list'], 73 | split_att=model_cfg['split_att'], 74 | bridge=model_cfg['bridge'], 75 | drop_path_rate=model_cfg['drop_path_rate']) 76 | 77 | model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0]) 78 | 79 | 80 | 81 | 82 | 83 | print('#----------Prepareing loss, opt, sch and amp----------#') 84 | criterion = config.criterion 85 | optimizer = get_optimizer(config, model) 86 | scheduler = get_scheduler(config, optimizer) 87 | scaler = GradScaler() 88 | 89 | 90 | 91 | 92 | 93 | print('#----------Set other params----------#') 94 | min_loss = 999 95 | start_epoch = 1 96 | min_epoch = 1 97 | 98 | 99 | 100 | 101 | 102 | 103 | print('#----------Testing----------#') 104 | best_weight = torch.load(resume_model, map_location=torch.device('cpu')) 105 | model.module.load_state_dict(best_weight['model_state_dict']) 106 | loss = test_one_epoch( 107 | test_loader, 108 | model, 109 | criterion, 110 | logger, 111 | config, 112 | ) 113 | 114 | 115 | 116 | if __name__ == '__main__': 117 | config = setting_config 118 | main(config) -------------------------------------------------------------------------------- /train_continue.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.cuda.amp import autocast, GradScaler 4 | from torch.utils.data import DataLoader 5 | from loader import * 6 | 7 | from models.HF_UNet import HFUNet 8 | from dataset.npy_datasets import NPY_datasets 9 | from engine import * 10 | import os 11 | import sys 12 | 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3" 14 | 15 | from utils import * 16 | from configs.config_setting import setting_config 17 | 18 | import warnings 19 | 20 | warnings.filterwarnings("ignore") 21 | 22 | 23 | def main(config): 24 | print('#----------Creating logger----------#') 25 | sys.path.append(config.work_dir + '/') 26 | log_dir = os.path.join(config.work_dir, 'log') 27 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints') 28 | resume_model = os.path.join(checkpoint_dir, 'latest.pth') 29 | outputs = os.path.join(config.work_dir, 'outputs') 30 | if not os.path.exists(checkpoint_dir): 31 | os.makedirs(checkpoint_dir) 32 | if not os.path.exists(outputs): 33 | os.makedirs(outputs) 34 | 35 | global logger 36 | logger = get_logger('train', log_dir) 37 | 38 | log_config_info(config, logger) 39 | 40 | print('#----------GPU init----------#') 41 | set_seed(config.seed) 42 | gpu_ids = [0] # [0, 1, 2, 3] 43 | 
torch.cuda.empty_cache() 44 | 45 | print('#----------Preparing dataset----------#') 46 | train_dataset = isic_loader(path_Data=config.data_path, train=True) 47 | train_loader = DataLoader(train_dataset, 48 | batch_size=config.batch_size, 49 | shuffle=True, 50 | pin_memory=True, 51 | num_workers=config.num_workers) 52 | val_dataset = isic_loader(path_Data=config.data_path, train=False) 53 | val_loader = DataLoader(val_dataset, 54 | batch_size=1, 55 | shuffle=False, 56 | pin_memory=True, 57 | num_workers=config.num_workers, 58 | drop_last=True) 59 | test_dataset = isic_loader(path_Data=config.data_path, train=False, Test=True) 60 | test_loader = DataLoader(test_dataset, 61 | batch_size=1, 62 | shuffle=False, 63 | pin_memory=True, 64 | num_workers=config.num_workers, 65 | drop_last=True) 66 | 67 | print('#----------Prepareing Models----------#') 68 | model_cfg = config.model_config 69 | model = HFUNet(num_classes=model_cfg['num_classes'], 70 | input_channels=model_cfg['input_channels'], 71 | c_list=model_cfg['c_list'], 72 | split_att=model_cfg['split_att'], 73 | bridge=model_cfg['bridge'], 74 | drop_path_rate=model_cfg['drop_path_rate']) 75 | 76 | model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0]) 77 | 78 | print('#----------Prepareing loss, opt, sch and amp----------#') 79 | criterion = config.criterion 80 | optimizer = get_optimizer(config, model) 81 | scheduler = get_scheduler(config, optimizer) 82 | scaler = GradScaler() 83 | 84 | print('#----------Set other params----------#') 85 | min_loss = 999 86 | start_epoch = 1 87 | min_epoch = 1 88 | 89 | print('#----------Resume Model and Other params----------#') 90 | checkpoint = torch.load(r'results/HF_UNet.pth', map_location=torch.device('cpu')) 91 | model.module.load_state_dict(checkpoint['model_state_dict']) 92 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 93 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 94 | #saved_epoch = checkpoint['epoch'] 95 | #start_epoch += saved_epoch 96 | min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss'] 97 | 98 | 99 | print('#----------Training----------#') 100 | for epoch in range(start_epoch, config.epochs + 1): 101 | 102 | torch.cuda.empty_cache() 103 | 104 | train_one_epoch( 105 | train_loader, 106 | model, 107 | criterion, 108 | optimizer, 109 | scheduler, 110 | epoch, 111 | logger, 112 | config, 113 | scaler=scaler 114 | ) 115 | 116 | loss = val_one_epoch( 117 | val_loader, 118 | model, 119 | criterion, 120 | epoch, 121 | logger, 122 | config 123 | ) 124 | 125 | if loss < min_loss: 126 | torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, 'best.pth')) 127 | min_loss = loss 128 | min_epoch = epoch 129 | 130 | torch.save( 131 | { 132 | 'epoch': epoch, 133 | 'min_loss': min_loss, 134 | 'min_epoch': min_epoch, 135 | 'loss': loss, 136 | 'model_state_dict': model.module.state_dict(), 137 | 'optimizer_state_dict': optimizer.state_dict(), 138 | 'scheduler_state_dict': scheduler.state_dict(), 139 | }, os.path.join(checkpoint_dir, 'latest.pth')) 140 | 141 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')): 142 | print('#----------Testing_best----------#') 143 | best_weight = torch.load(config.work_dir + 'checkpoints/best.pth', map_location=torch.device('cpu')) 144 | model.module.load_state_dict(best_weight) 145 | loss = test_one_epoch( 146 | test_loader, 147 | model, 148 | criterion, 149 | logger, 150 | config, 151 | ) 152 | os.rename( 153 | os.path.join(checkpoint_dir, 
'best.pth'), 154 | os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth') 155 | ) 156 | 157 | if __name__ == '__main__': 158 | config = setting_config 159 | main(config) -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from torch.cuda.amp import autocast as autocast 5 | from sklearn.metrics import confusion_matrix 6 | from utils import save_imgs 7 | 8 | 9 | def train_one_epoch(train_loader, 10 | model, 11 | criterion, 12 | optimizer, 13 | scheduler, 14 | epoch, 15 | logger, 16 | config, 17 | scaler=None): 18 | ''' 19 | train model for one epoch 20 | ''' 21 | # switch to train mode 22 | model.train() 23 | 24 | loss_list = [] 25 | 26 | for iter, data in enumerate(train_loader): 27 | optimizer.zero_grad() 28 | images, targets = data 29 | images, targets = images.cuda(non_blocking=True).float(), targets.cuda(non_blocking=True).float() 30 | if config.amp: 31 | with autocast(): 32 | out = model(images) 33 | loss = criterion(out, targets) 34 | scaler.scale(loss).backward() 35 | scaler.step(optimizer) 36 | scaler.update() 37 | else: 38 | out = model(images) 39 | loss = criterion(out, targets) 40 | loss.backward() 41 | optimizer.step() 42 | 43 | loss_list.append(loss.item()) 44 | 45 | now_lr = optimizer.state_dict()['param_groups'][0]['lr'] 46 | if iter % config.print_interval == 0: 47 | log_info = f'train: epoch {epoch}, iter:{iter}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}' 48 | print(log_info) 49 | logger.info(log_info) 50 | scheduler.step() 51 | 52 | 53 | def val_one_epoch(test_loader, 54 | model, 55 | criterion, 56 | epoch, 57 | logger, 58 | config): 59 | # switch to evaluate mode 60 | model.eval() 61 | preds = [] 62 | gts = [] 63 | loss_list = [] 64 | with torch.no_grad(): 65 | for data in tqdm(test_loader): 66 | img, msk = data 67 | img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() 68 | out = model(img) 69 | loss = criterion(out, msk) 70 | loss_list.append(loss.item()) 71 | gts.append(msk.squeeze(1).cpu().detach().numpy()) 72 | if type(out) is tuple: 73 | out = out[0] 74 | out = out.squeeze(1).cpu().detach().numpy() 75 | preds.append(out) 76 | 77 | if epoch % config.val_interval == 0: 78 | preds = np.array(preds).reshape(-1) 79 | gts = np.array(gts).reshape(-1) 80 | 81 | y_pre = np.where(preds>=config.threshold, 1, 0) 82 | y_true = np.where(gts>=0.5, 1, 0) 83 | 84 | confusion = confusion_matrix(y_true, y_pre) 85 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] 86 | 87 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 88 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 89 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 90 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 91 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0 92 | 93 | log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}, miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ 94 | specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' 95 | print(log_info) 96 | logger.info(log_info) 97 | 98 | else: 99 | log_info = f'val epoch: {epoch}, loss: {np.mean(loss_list):.4f}' 100 | print(log_info) 101 | logger.info(log_info) 102 | 103 | return 
np.mean(loss_list) 104 | 105 | 106 | def test_one_epoch(test_loader, 107 | model, 108 | criterion, 109 | logger, 110 | config, 111 | test_data_name=None): 112 | # switch to evaluate mode 113 | model.eval() 114 | preds = [] 115 | gts = [] 116 | loss_list = [] 117 | with torch.no_grad(): 118 | for i, data in enumerate(tqdm(test_loader)): 119 | img, msk = data 120 | img, msk = img.cuda(non_blocking=True).float(), msk.cuda(non_blocking=True).float() 121 | out = model(img) 122 | loss = criterion(out, msk) 123 | loss_list.append(loss.item()) 124 | msk = msk.squeeze(1).cpu().detach().numpy() 125 | gts.append(msk) 126 | if type(out) is tuple: 127 | out = out[0] 128 | out = out.squeeze(1).cpu().detach().numpy() 129 | preds.append(out) 130 | save_imgs(img, msk, out, i, config.work_dir + 'outputs/', config.datasets, config.threshold, test_data_name=test_data_name) 131 | 132 | preds = np.array(preds).reshape(-1) 133 | gts = np.array(gts).reshape(-1) 134 | 135 | y_pre = np.where(preds>=config.threshold, 1, 0) 136 | y_true = np.where(gts>=0.5, 1, 0) 137 | 138 | confusion = confusion_matrix(y_true, y_pre) 139 | TN, FP, FN, TP = confusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1] 140 | 141 | accuracy = float(TN + TP) / float(np.sum(confusion)) if float(np.sum(confusion)) != 0 else 0 142 | sensitivity = float(TP) / float(TP + FN) if float(TP + FN) != 0 else 0 143 | specificity = float(TN) / float(TN + FP) if float(TN + FP) != 0 else 0 144 | f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if float(2 * TP + FP + FN) != 0 else 0 145 | miou = float(TP) / float(TP + FP + FN) if float(TP + FP + FN) != 0 else 0 146 | 147 | if test_data_name is not None: 148 | log_info = f'test_datasets_name: {test_data_name}' 149 | print(log_info) 150 | logger.info(log_info) 151 | log_info = f'test of best model, loss: {np.mean(loss_list):.4f},miou: {miou}, f1_or_dsc: {f1_or_dsc}, accuracy: {accuracy}, \ 152 | specificity: {specificity}, sensitivity: {sensitivity}, confusion_matrix: {confusion}' 153 | print(log_info) 154 | logger.info(log_info) 155 | 156 | return np.mean(loss_list) 157 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.cuda.amp import autocast, GradScaler 4 | from torch.utils.data import DataLoader 5 | from loader import * 6 | 7 | from models.HF_UNet import HFUNet 8 | from dataset.npy_datasets import NPY_datasets 9 | from engine import * 10 | import os 11 | import sys 12 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3" 13 | 14 | from utils import * 15 | from configs.config_setting import setting_config 16 | 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | 21 | 22 | def main(config): 23 | 24 | print('#----------Creating logger----------#') 25 | sys.path.append(config.work_dir + '/') 26 | log_dir = os.path.join(config.work_dir, 'log') 27 | checkpoint_dir = os.path.join(config.work_dir, 'checkpoints') 28 | resume_model = os.path.join(checkpoint_dir, 'latest.pth') 29 | outputs = os.path.join(config.work_dir, 'outputs') 30 | if not os.path.exists(checkpoint_dir): 31 | os.makedirs(checkpoint_dir) 32 | if not os.path.exists(outputs): 33 | os.makedirs(outputs) 34 | 35 | global logger 36 | logger = get_logger('train', log_dir) 37 | 38 | log_config_info(config, logger) 39 | 40 | 41 | 42 | 43 | 44 | print('#----------GPU init----------#') 45 | set_seed(config.seed) 46 | gpu_ids = [0]# [0, 1, 2, 3] 47 | 
torch.cuda.empty_cache() 48 | 49 | 50 | 51 | 52 | 53 | print('#----------Preparing dataset----------#') 54 | train_dataset = isic_loader(path_Data = config.data_path, train = True) 55 | train_loader = DataLoader(train_dataset, 56 | batch_size=config.batch_size, 57 | shuffle=True, 58 | pin_memory=True, 59 | num_workers=config.num_workers) 60 | val_dataset = isic_loader(path_Data = config.data_path, train = False) 61 | val_loader = DataLoader(val_dataset, 62 | batch_size=1, 63 | shuffle=False, 64 | pin_memory=True, 65 | num_workers=config.num_workers, 66 | drop_last=True) 67 | test_dataset = isic_loader(path_Data = config.data_path, train = False, Test = True) 68 | test_loader = DataLoader(test_dataset, 69 | batch_size=1, 70 | shuffle=False, 71 | pin_memory=True, 72 | num_workers=config.num_workers, 73 | drop_last=True) 74 | 75 | 76 | 77 | 78 | print('#----------Prepareing Models----------#') 79 | model_cfg = config.model_config 80 | model = HFUNet(num_classes=model_cfg['num_classes'], 81 | input_channels=model_cfg['input_channels'], 82 | c_list=model_cfg['c_list'], 83 | split_att=model_cfg['split_att'], 84 | bridge=model_cfg['bridge'], 85 | drop_path_rate=model_cfg['drop_path_rate']) 86 | 87 | model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0]) 88 | 89 | 90 | 91 | 92 | 93 | print('#----------Prepareing loss, opt, sch and amp----------#') 94 | criterion = config.criterion 95 | optimizer = get_optimizer(config, model) 96 | scheduler = get_scheduler(config, optimizer) 97 | scaler = GradScaler() 98 | 99 | 100 | 101 | 102 | 103 | print('#----------Set other params----------#') 104 | min_loss = 999 105 | start_epoch = 1 106 | min_epoch = 1 107 | 108 | 109 | 110 | 111 | 112 | if os.path.exists(resume_model): 113 | print('#----------Resume Model and Other params----------#') 114 | checkpoint = torch.load(resume_model, map_location=torch.device('cpu')) 115 | model.module.load_state_dict(checkpoint['model_state_dict']) 116 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 117 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 118 | saved_epoch = checkpoint['epoch'] 119 | start_epoch += saved_epoch 120 | min_loss, min_epoch, loss = checkpoint['min_loss'], checkpoint['min_epoch'], checkpoint['loss'] 121 | 122 | log_info = f'resuming model from {resume_model}. 
resume_epoch: {saved_epoch}, min_loss: {min_loss:.4f}, min_epoch: {min_epoch}, loss: {loss:.4f}' 123 | logger.info(log_info) 124 | 125 | 126 | 127 | 128 | 129 | print('#----------Training----------#') 130 | for epoch in range(start_epoch, config.epochs + 1): 131 | 132 | torch.cuda.empty_cache() 133 | 134 | train_one_epoch( 135 | train_loader, 136 | model, 137 | criterion, 138 | optimizer, 139 | scheduler, 140 | epoch, 141 | logger, 142 | config, 143 | scaler=scaler 144 | ) 145 | 146 | loss = val_one_epoch( 147 | val_loader, 148 | model, 149 | criterion, 150 | epoch, 151 | logger, 152 | config 153 | ) 154 | 155 | 156 | if loss < min_loss: 157 | torch.save(model.module.state_dict(), os.path.join(checkpoint_dir, 'best.pth')) 158 | min_loss = loss 159 | min_epoch = epoch 160 | 161 | torch.save( 162 | { 163 | 'epoch': epoch, 164 | 'min_loss': min_loss, 165 | 'min_epoch': min_epoch, 166 | 'loss': loss, 167 | 'model_state_dict': model.module.state_dict(), 168 | 'optimizer_state_dict': optimizer.state_dict(), 169 | 'scheduler_state_dict': scheduler.state_dict(), 170 | }, os.path.join(checkpoint_dir, 'latest.pth')) 171 | 172 | if os.path.exists(os.path.join(checkpoint_dir, 'best.pth')): 173 | print('#----------Testing----------#') 174 | best_weight = torch.load(config.work_dir + 'checkpoints/best.pth', map_location=torch.device('cpu')) 175 | model.module.load_state_dict(best_weight) 176 | loss = test_one_epoch( 177 | test_loader, 178 | model, 179 | criterion, 180 | logger, 181 | config, 182 | ) 183 | os.rename( 184 | os.path.join(checkpoint_dir, 'best.pth'), 185 | os.path.join(checkpoint_dir, f'best-epoch{min_epoch}-loss{min_loss:.4f}.pth') 186 | ) 187 | 188 | if __name__ == '__main__': 189 | config = setting_config 190 | main(config) -------------------------------------------------------------------------------- /configs/config_setting.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from utils import * 3 | 4 | from datetime import datetime 5 | 6 | class setting_config: 7 | """ 8 | the config of training setting. 
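Key fields to adjust for a new run: data_path (the folder holding the preprocessed npy files), batch_size, epochs, opt/sch (the optimizer and scheduler choices below) and work_dir (where logs, checkpoints and outputs are written).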
9 | """ 10 | network = 'HFUNet' 11 | model_config = { 12 | 'num_classes': 1, 13 | 'input_channels': 3, 14 | 'c_list': [32, 64, 128, 256, 512, 1024], 15 | 'split_att': 'fc', 16 | 'bridge': True, 17 | 'drop_path_rate':0.4 18 | 19 | } 20 | 21 | test_weights = '' 22 | 23 | datasets = 'Autooral' 24 | if datasets == 'Autooral': 25 | data_path = 'dataset/Autooral_dataset/' 26 | else: 27 | raise Exception('datasets in not right!') 28 | 29 | criterion = BceDiceLoss() 30 | 31 | num_classes = 1 32 | input_size_h = 256 33 | input_size_w = 256 34 | input_channels = 3 35 | distributed = False 36 | local_rank = -1 37 | num_workers = 0 38 | seed = 42 39 | world_size = None 40 | rank = None 41 | amp = False 42 | batch_size = 8 43 | epochs = 250 44 | 45 | work_dir = 'results/' + network + '_' + datasets + '_' + datetime.now().strftime('%A_%d_%B_%Y_%Hh_%Mm_%Ss') + '/' 46 | 47 | print_interval = 20 48 | val_interval = 30 49 | save_interval = 100 50 | threshold = 0.5 51 | 52 | train_transformer = transforms.Compose([ 53 | myNormalize(datasets, train=True), 54 | myToTensor(), 55 | myRandomHorizontalFlip(p=0.5), 56 | myRandomVerticalFlip(p=0.5), 57 | myRandomRotation(p=0.5, degree=[0, 360]), 58 | myResize(input_size_h, input_size_w) 59 | ]) 60 | val_transformer = transforms.Compose([ 61 | myNormalize(datasets, train=False), 62 | myToTensor(), 63 | myResize(input_size_h, input_size_w) 64 | ]) 65 | test_transformer = transforms.Compose([ 66 | myToTensor(), 67 | myResize(input_size_h, input_size_w) 68 | ]) 69 | 70 | opt = 'AdamW' 71 | assert opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!' 72 | if opt == 'Adadelta': 73 | lr = 0.01 # default: 1.0 – coefficient that scale delta before it is applied to the parameters 74 | rho = 0.9 # default: 0.9 – coefficient used for computing a running average of squared gradients 75 | eps = 1e-6 # default: 1e-6 – term added to the denominator to improve numerical stability 76 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 77 | elif opt == 'Adagrad': 78 | lr = 0.01 # default: 0.01 – learning rate 79 | lr_decay = 0 # default: 0 – learning rate decay 80 | eps = 1e-10 # default: 1e-10 – term added to the denominator to improve numerical stability 81 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 82 | elif opt == 'Adam': 83 | lr = 0.001 # default: 1e-3 – learning rate 84 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 85 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 86 | weight_decay = 0.0001 # default: 0 – weight decay (L2 penalty) 87 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 88 | elif opt == 'AdamW': 89 | lr = 0.001 # default: 1e-3 – learning rate 90 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and its square 91 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 92 | weight_decay = 1e-2 # default: 1e-2 – weight decay coefficient 93 | amsgrad = False # default: False – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond 94 | elif opt == 'Adamax': 95 | lr = 2e-3 # default: 2e-3 – learning rate 96 | betas = (0.9, 0.999) # default: (0.9, 0.999) – coefficients used for computing running averages of gradient and 
its square 97 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 98 | weight_decay = 0 # default: 0 – weight decay (L2 penalty) 99 | elif opt == 'ASGD': 100 | lr = 0.01 # default: 1e-2 – learning rate 101 | lambd = 1e-4 # default: 1e-4 – decay term 102 | alpha = 0.75 # default: 0.75 – power for eta update 103 | t0 = 1e6 # default: 1e6 – point at which to start averaging 104 | weight_decay = 0 # default: 0 – weight decay 105 | elif opt == 'RMSprop': 106 | lr = 1e-2 # default: 1e-2 – learning rate 107 | momentum = 0 # default: 0 – momentum factor 108 | alpha = 0.99 # default: 0.99 – smoothing constant 109 | eps = 1e-8 # default: 1e-8 – term added to the denominator to improve numerical stability 110 | centered = False # default: False – if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance 111 | weight_decay = 0 # default: 0 – weight decay (L2 penalty) 112 | elif opt == 'Rprop': 113 | lr = 1e-2 # default: 1e-2 – learning rate 114 | etas = (0.5, 1.2) # default: (0.5, 1.2) – pair of (etaminus, etaplis), that are multiplicative increase and decrease factors 115 | step_sizes = (1e-6, 50) # default: (1e-6, 50) – a pair of minimal and maximal allowed step sizes 116 | elif opt == 'SGD': 117 | lr = 0.01 # – learning rate 118 | momentum = 0.9 # default: 0 – momentum factor 119 | weight_decay = 0.05 # default: 0 – weight decay (L2 penalty) 120 | dampening = 0 # default: 0 – dampening for momentum 121 | nesterov = False # default: False – enables Nesterov momentum 122 | 123 | sch = 'CosineAnnealingLR' 124 | if sch == 'StepLR': 125 | step_size = epochs // 5 # – Period of learning rate decay. 126 | gamma = 0.5 # – Multiplicative factor of learning rate decay. Default: 0.1 127 | last_epoch = -1 # – The index of last epoch. Default: -1. 128 | elif sch == 'MultiStepLR': 129 | milestones = [60, 120, 150] # – List of epoch indices. Must be increasing. 130 | gamma = 0.1 # – Multiplicative factor of learning rate decay. Default: 0.1. 131 | last_epoch = -1 # – The index of last epoch. Default: -1. 132 | elif sch == 'ExponentialLR': 133 | gamma = 0.99 # – Multiplicative factor of learning rate decay. 134 | last_epoch = -1 # – The index of last epoch. Default: -1. 135 | elif sch == 'CosineAnnealingLR': 136 | T_max = 50 # – Maximum number of iterations. Cosine function period. 137 | eta_min = 0.00001 # – Minimum learning rate. Default: 0. 138 | last_epoch = -1 # – The index of last epoch. Default: -1. 139 | elif sch == 'ReduceLROnPlateau': 140 | mode = 'min' # – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’. 141 | factor = 0.1 # – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1. 142 | patience = 10 # – Number of epochs with no improvement after which learning rate will be reduced. For example, if patience = 2, then we will ignore the first 2 epochs with no improvement, and will only decrease the LR after the 3rd epoch if the loss still hasn’t improved then. Default: 10. 143 | threshold = 0.0001 # – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4. 144 | threshold_mode = 'rel' # – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in ‘max’ mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. 
Default: ‘rel’. 145 | cooldown = 0 # – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0. 146 | min_lr = 0 # – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0. 147 | eps = 1e-08 # – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8. 148 | elif sch == 'CosineAnnealingWarmRestarts': 149 | T_0 = 50 # – Number of iterations for the first restart. 150 | T_mult = 2 # – A factor increases T_{i} after a restart. Default: 1. 151 | eta_min = 1e-6 # – Minimum learning rate. Default: 0. 152 | last_epoch = -1 # – The index of last epoch. Default: -1. 153 | elif sch == 'WP_MultiStepLR': 154 | warm_up_epochs = 10 155 | gamma = 0.1 156 | milestones = [125, 225] 157 | elif sch == 'WP_CosineLR': 158 | warm_up_epochs = 20 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.backends.cudnn as cudnn 5 | import torchvision.transforms.functional as TF 6 | import numpy as np 7 | import os 8 | import math 9 | import random 10 | import logging 11 | import logging.handlers 12 | from matplotlib import pyplot as plt 13 | 14 | 15 | def set_seed(seed): 16 | # for hash 17 | os.environ['PYTHONHASHSEED'] = str(seed) 18 | # for python and numpy 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | # for cpu gpu 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | # for cudnn 26 | cudnn.benchmark = False 27 | cudnn.deterministic = True 28 | 29 | 30 | def get_logger(name, log_dir): 31 | ''' 32 | Args: 33 | name(str): name of logger 34 | log_dir(str): path of log 35 | ''' 36 | 37 | if not os.path.exists(log_dir): 38 | os.makedirs(log_dir) 39 | 40 | logger = logging.getLogger(name) 41 | logger.setLevel(logging.INFO) 42 | 43 | info_name = os.path.join(log_dir, '{}.info.log'.format(name)) 44 | info_handler = logging.handlers.TimedRotatingFileHandler(info_name, 45 | when='D', 46 | encoding='utf-8') 47 | info_handler.setLevel(logging.INFO) 48 | 49 | formatter = logging.Formatter('%(asctime)s - %(message)s', 50 | datefmt='%Y-%m-%d %H:%M:%S') 51 | 52 | info_handler.setFormatter(formatter) 53 | 54 | logger.addHandler(info_handler) 55 | 56 | return logger 57 | 58 | 59 | def log_config_info(config, logger): 60 | config_dict = config.__dict__ 61 | log_info = f'#----------Config info----------#' 62 | logger.info(log_info) 63 | for k, v in config_dict.items(): 64 | if k[0] == '_': 65 | continue 66 | else: 67 | log_info = f'{k}: {v},' 68 | logger.info(log_info) 69 | 70 | 71 | 72 | def get_optimizer(config, model): 73 | assert config.opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!' 
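    # The hyperparameters below (lr, betas, eps, weight_decay, ...) are read from
    # the matching per-optimizer block in configs/config_setting.py.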
74 | 75 | if config.opt == 'Adadelta': 76 | return torch.optim.Adadelta( 77 | model.parameters(), 78 | lr = config.lr, 79 | rho = config.rho, 80 | eps = config.eps, 81 | weight_decay = config.weight_decay 82 | ) 83 | elif config.opt == 'Adagrad': 84 | return torch.optim.Adagrad( 85 | model.parameters(), 86 | lr = config.lr, 87 | lr_decay = config.lr_decay, 88 | eps = config.eps, 89 | weight_decay = config.weight_decay 90 | ) 91 | elif config.opt == 'Adam': 92 | return torch.optim.Adam( 93 | model.parameters(), 94 | lr = config.lr, 95 | betas = config.betas, 96 | eps = config.eps, 97 | weight_decay = config.weight_decay, 98 | amsgrad = config.amsgrad 99 | ) 100 | elif config.opt == 'AdamW': 101 | return torch.optim.AdamW( 102 | model.parameters(), 103 | lr = config.lr, 104 | betas = config.betas, 105 | eps = config.eps, 106 | weight_decay = config.weight_decay, 107 | amsgrad = config.amsgrad 108 | ) 109 | elif config.opt == 'Adamax': 110 | return torch.optim.Adamax( 111 | model.parameters(), 112 | lr = config.lr, 113 | betas = config.betas, 114 | eps = config.eps, 115 | weight_decay = config.weight_decay 116 | ) 117 | elif config.opt == 'ASGD': 118 | return torch.optim.ASGD( 119 | model.parameters(), 120 | lr = config.lr, 121 | lambd = config.lambd, 122 | alpha = config.alpha, 123 | t0 = config.t0, 124 | weight_decay = config.weight_decay 125 | ) 126 | elif config.opt == 'RMSprop': 127 | return torch.optim.RMSprop( 128 | model.parameters(), 129 | lr = config.lr, 130 | momentum = config.momentum, 131 | alpha = config.alpha, 132 | eps = config.eps, 133 | centered = config.centered, 134 | weight_decay = config.weight_decay 135 | ) 136 | elif config.opt == 'Rprop': 137 | return torch.optim.Rprop( 138 | model.parameters(), 139 | lr = config.lr, 140 | etas = config.etas, 141 | step_sizes = config.step_sizes, 142 | ) 143 | elif config.opt == 'SGD': 144 | return torch.optim.SGD( 145 | model.parameters(), 146 | lr = config.lr, 147 | momentum = config.momentum, 148 | weight_decay = config.weight_decay, 149 | dampening = config.dampening, 150 | nesterov = config.nesterov 151 | ) 152 | else: # default opt is SGD 153 | return torch.optim.SGD( 154 | model.parameters(), 155 | lr = 0.01, 156 | momentum = 0.9, 157 | weight_decay = 0.05, 158 | ) 159 | 160 | 161 | 162 | def get_scheduler(config, optimizer): 163 | assert config.sch in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'CosineAnnealingLR', 'ReduceLROnPlateau', 164 | 'CosineAnnealingWarmRestarts', 'WP_MultiStepLR', 'WP_CosineLR'], 'Unsupported scheduler!' 
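    # Scheduler hyperparameters likewise come from configs/config_setting.py, where
    # step_size, milestones and T_max are intended as epoch counts (engine.py's
    # train_one_epoch calls scheduler.step()).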
165 | if config.sch == 'StepLR': 166 | scheduler = torch.optim.lr_scheduler.StepLR( 167 | optimizer, 168 | step_size = config.step_size, 169 | gamma = config.gamma, 170 | last_epoch = config.last_epoch 171 | ) 172 | elif config.sch == 'MultiStepLR': 173 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 174 | optimizer, 175 | milestones = config.milestones, 176 | gamma = config.gamma, 177 | last_epoch = config.last_epoch 178 | ) 179 | elif config.sch == 'ExponentialLR': 180 | scheduler = torch.optim.lr_scheduler.ExponentialLR( 181 | optimizer, 182 | gamma = config.gamma, 183 | last_epoch = config.last_epoch 184 | ) 185 | elif config.sch == 'CosineAnnealingLR': 186 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 187 | optimizer, 188 | T_max = config.T_max, 189 | eta_min = config.eta_min, 190 | last_epoch = config.last_epoch 191 | ) 192 | elif config.sch == 'ReduceLROnPlateau': 193 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 194 | optimizer, 195 | mode = config.mode, 196 | factor = config.factor, 197 | patience = config.patience, 198 | threshold = config.threshold, 199 | threshold_mode = config.threshold_mode, 200 | cooldown = config.cooldown, 201 | min_lr = config.min_lr, 202 | eps = config.eps 203 | ) 204 | elif config.sch == 'CosineAnnealingWarmRestarts': 205 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 206 | optimizer, 207 | T_0 = config.T_0, 208 | T_mult = config.T_mult, 209 | eta_min = config.eta_min, 210 | last_epoch = config.last_epoch 211 | ) 212 | elif config.sch == 'WP_MultiStepLR': 213 | lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else config.gamma**len( 214 | [m for m in config.milestones if m <= epoch]) 215 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func) 216 | elif config.sch == 'WP_CosineLR': 217 | lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else 0.5 * ( 218 | math.cos((epoch - config.warm_up_epochs) / (config.epochs - config.warm_up_epochs) * math.pi) + 1) 219 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func) 220 | 221 | return scheduler 222 | 223 | 224 | 225 | def save_imgs(img, msk, msk_pred, i, save_path, datasets, threshold=0.5, test_data_name=None): 226 | img = img.squeeze(0).permute(1,2,0).detach().cpu().numpy() 227 | img = img / 255. 
if img.max() > 1.1 else img 228 | if datasets == 'retinal': 229 | msk = np.squeeze(msk, axis=0) 230 | msk_pred = np.squeeze(msk_pred, axis=0) 231 | else: 232 | msk = np.where(np.squeeze(msk, axis=0) > 0.5, 1, 0) 233 | msk_pred = np.where(np.squeeze(msk_pred, axis=0) > threshold, 1, 0) 234 | 235 | plt.figure(figsize=(7,15)) 236 | 237 | plt.subplot(3,1,1) 238 | plt.imshow(img) 239 | plt.axis('off') 240 | 241 | plt.subplot(3,1,2) 242 | plt.imshow(msk, cmap= 'gray') 243 | plt.axis('off') 244 | 245 | plt.subplot(3,1,3) 246 | plt.imshow(msk_pred, cmap = 'gray') 247 | plt.axis('off') 248 | 249 | if test_data_name is not None: 250 | save_path = save_path + test_data_name + '_' 251 | plt.savefig(save_path + str(i) +'.png') 252 | plt.close() 253 | 254 | 255 | 256 | class BCELoss(nn.Module): 257 | def __init__(self): 258 | super(BCELoss, self).__init__() 259 | self.bceloss = nn.BCELoss() 260 | 261 | def forward(self, pred, target): 262 | size = pred.size(0) 263 | pred_ = pred.view(size, -1) 264 | target_ = target.view(size, -1) 265 | 266 | return self.bceloss(pred_, target_) 267 | 268 | 269 | class DiceLoss(nn.Module): 270 | def __init__(self): 271 | super(DiceLoss, self).__init__() 272 | 273 | def forward(self, pred, target): 274 | smooth = 1 275 | size = pred.size(0) 276 | 277 | pred_ = pred.view(size, -1) 278 | target_ = target.view(size, -1) 279 | intersection = pred_ * target_ 280 | dice_score = (2 * intersection.sum(1) + smooth)/(pred_.sum(1) + target_.sum(1) + smooth) 281 | dice_loss = 1 - dice_score.sum()/size 282 | 283 | return dice_loss 284 | 285 | 286 | class BceDiceLoss(nn.Module): 287 | def __init__(self, wb=1, wd=1): 288 | super(BceDiceLoss, self).__init__() 289 | self.bce = BCELoss() 290 | self.dice = DiceLoss() 291 | self.wb = wb 292 | self.wd = wd 293 | 294 | def forward(self, pred, target): 295 | bceloss = self.bce(pred, target) 296 | diceloss = self.dice(pred, target) 297 | 298 | loss = self.wd * diceloss + self.wb * bceloss 299 | return loss 300 | 301 | 302 | 303 | class myToTensor: 304 | def __init__(self): 305 | pass 306 | def __call__(self, data): 307 | image, mask = data 308 | return torch.tensor(image).permute(2,0,1), torch.tensor(mask).permute(2,0,1) 309 | 310 | 311 | class myResize: 312 | def __init__(self, size_h=256, size_w=256): 313 | self.size_h = size_h 314 | self.size_w = size_w 315 | def __call__(self, data): 316 | image, mask = data 317 | return TF.resize(image, [self.size_h, self.size_w]), TF.resize(mask, [self.size_h, self.size_w]) 318 | 319 | 320 | class myRandomHorizontalFlip: 321 | def __init__(self, p=0.5): 322 | self.p = p 323 | def __call__(self, data): 324 | image, mask = data 325 | if random.random() < self.p: return TF.hflip(image), TF.hflip(mask) 326 | else: return image, mask 327 | 328 | 329 | class myRandomVerticalFlip: 330 | def __init__(self, p=0.5): 331 | self.p = p 332 | def __call__(self, data): 333 | image, mask = data 334 | if random.random() < self.p: return TF.vflip(image), TF.vflip(mask) 335 | else: return image, mask 336 | 337 | 338 | class myRandomRotation: 339 | def __init__(self, p=0.5, degree=[0,360]): 340 | self.angle = random.uniform(degree[0], degree[1]) 341 | self.p = p 342 | def __call__(self, data): 343 | image, mask = data 344 | if random.random() < self.p: return TF.rotate(image,self.angle), TF.rotate(mask,self.angle) 345 | else: return image, mask 346 | 347 | 348 | class myNormalize: 349 | def __init__(self, data_name, train=True,test=False): 350 | if data_name == 'isic17': 351 | if train: 352 | self.mean = 156.9704 353 | 
self.std = 27.8712 354 | elif test: 355 | self.mean = 156.3398 356 | self.std = 28.2681 357 | else: 358 | self.mean = 159.0763 359 | self.std = 29.6363 360 | elif data_name == 'isic18': 361 | if train: 362 | self.mean = 154.7468 363 | self.std = 28.6875 364 | elif test: 365 | self.mean = 156.6547 366 | self.std = 28.5943 367 | else: 368 | self.mean = 156.0735 369 | self.std = 28.9695 370 | elif data_name == 'PH2': 371 | if train: 372 | self.mean = 154.8111 373 | self.std = 41.3169 374 | elif test: 375 | self.mean = 154.3235 376 | self.std = 40.8471 377 | else: 378 | self.mean = 154.3235 379 | self.std = 40.8471 380 | 381 | def __call__(self, data): 382 | img, msk = data 383 | img_normalized = (img-self.mean)/self.std 384 | img_normalized = ((img_normalized - np.min(img_normalized)) 385 | / (np.max(img_normalized)-np.min(img_normalized))) * 255. 386 | return img_normalized, msk 387 | 388 | 389 | 390 | -------------------------------------------------------------------------------- /models/HF_UNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | from timm.models.layers import trunc_normal_, DropPath 6 | import os 7 | import sys 8 | import torch.fft 9 | import math 10 | from torch.autograd import Variable 11 | import numpy as np 12 | 13 | import traceback 14 | 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | 17 | class DepthWiseConv2d(nn.Module): 18 | def __init__(self, dim_in, dim_out, kernel_size=3, padding=1, stride=1, dilation=1): 19 | super().__init__() 20 | 21 | self.conv1 = nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, 22 | stride=stride, dilation=dilation, groups=dim_in) 23 | self.norm_layer = nn.GroupNorm(4, dim_in) 24 | self.conv2 = nn.Conv2d(dim_in, dim_out, kernel_size=1) 25 | 26 | def forward(self, x): 27 | return self.conv2(self.norm_layer(self.conv1(x))) 28 | 29 | 30 | if 'DWCONV_IMPL' in os.environ: 31 | try: 32 | sys.path.append(os.environ['DWCONV_IMPL']) 33 | from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM 34 | 35 | 36 | def get_dwconv(dim, kernel, bias): 37 | return DepthWiseConv2dImplicitGEMM(dim, kernel, bias) 38 | # print('Using Megvii large kernel dw conv impl') 39 | except: 40 | print(traceback.format_exc()) 41 | 42 | 43 | def get_dwconv(dim, kernel, bias): 44 | return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel - 1) // 2, bias=bias, groups=dim) 45 | 46 | # print('[fail to use Megvii Large kernel] Using PyTorch large kernel dw conv impl') 47 | else: 48 | def get_dwconv(dim, kernel, bias): 49 | return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel - 1) // 2, bias=bias, groups=dim) 50 | 51 | # print('Using PyTorch large kernel dw conv impl') 52 | 53 | 54 | class Attention_block(nn.Module): 55 | def __init__(self, F_g, F_l, F_int): 56 | super(Attention_block, self).__init__() 57 | self.W_g = nn.Sequential( 58 | nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True), 59 | nn.BatchNorm2d(F_int) 60 | ) 61 | 62 | self.W_x = nn.Sequential( 63 | nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True), 64 | nn.BatchNorm2d(F_int) 65 | ) 66 | 67 | self.psi = nn.Sequential( 68 | nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), 69 | nn.BatchNorm2d(1), 70 | nn.Sigmoid() 71 | ) 72 | 73 | self.relu = nn.ReLU(inplace=True) 74 | 75 | def forward(self, g, x): 76 | g1 = self.W_g(g) 77 | x1 = 
self.W_x(x) 78 | psi = self.relu(g1 + x1) 79 | psi = self.psi(psi) 80 | 81 | return x * psi 82 | 83 | 84 | class HFConv(nn.Module): 85 | def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0): 86 | super().__init__() 87 | self.order = order 88 | self.dims = [dim // 2 ** i for i in range(order)] 89 | self.dims.reverse() 90 | self.proj_in = nn.Conv2d(dim, 2 * dim, 1) 91 | 92 | if gflayer is None: 93 | self.dwconv = get_dwconv(sum(self.dims), 7, True) 94 | else: 95 | self.dwconv = gflayer(sum(self.dims), h=h, w=w) 96 | 97 | self.proj_out = nn.Conv2d(dim, dim, 1) 98 | 99 | self.pws = nn.ModuleList( 100 | [nn.Conv2d(self.dims[i], self.dims[i + 1], 1) for i in range(order - 1)] 101 | ) 102 | if self.order == 3: 103 | self.conv0 = nn.Conv2d(self.dims[0], self.dims[1], 1) 104 | self.conv1 = nn.Conv2d(self.dims[1], self.dims[2], 1) 105 | else: 106 | if self.order == 4: 107 | self.conv0 = nn.Conv2d(self.dims[0], self.dims[1], 1) 108 | self.conv1 = nn.Conv2d(self.dims[1], self.dims[2], 1) 109 | self.conv2 = nn.Conv2d(self.dims[2], self.dims[3], 1) 110 | else: 111 | if self.order == 5: 112 | self.conv0 = nn.Conv2d(self.dims[0], self.dims[1], 1) 113 | self.conv1 = nn.Conv2d(self.dims[1], self.dims[2], 1) 114 | self.conv2 = nn.Conv2d(self.dims[2], self.dims[3], 1) 115 | self.conv3 = nn.Conv2d(self.dims[3], self.dims[4], 1) 116 | else: 117 | if self.order == 2: 118 | self.conv0 = nn.Conv2d(self.dims[0], self.dims[1], 1) 119 | 120 | 121 | 122 | self.focol = nn.ModuleList([ 123 | BasicLayer( 124 | dim=self.dims[i + 1], 125 | depth=2, 126 | mlp_ratio=4., 127 | drop=0., 128 | drop_path=0.4, 129 | norm_layer=nn.LayerNorm, 130 | focal_window=9, 131 | focal_level=2, 132 | use_layerscale=False, 133 | use_checkpoint=False) for i in range(order - 1)] 134 | ) 135 | if self.order == 3: 136 | self.focol0 = BasicLayer(dim=self.dims[0], depth=2, mlp_ratio=4., drop=0., drop_path=0.4, 137 | norm_layer=nn.LayerNorm, focal_window=9, focal_level=2, use_layerscale=False, 138 | use_checkpoint=False) 139 | self.focol1 = BasicLayer(dim=self.dims[1], depth=2, mlp_ratio=4., drop=0., drop_path=0.4, 140 | norm_layer=nn.LayerNorm, focal_window=9, focal_level=2, use_layerscale=False, 141 | use_checkpoint=False) 142 | self.focol2 = BasicLayer( 143 | dim=self.dims[2], 144 | depth=2, 145 | mlp_ratio=4., 146 | drop=0., 147 | drop_path=0.4, 148 | norm_layer=nn.LayerNorm, 149 | focal_window=9, 150 | focal_level=2, 151 | use_layerscale=False, 152 | use_checkpoint=False) 153 | else: 154 | if self.order == 4: 155 | self.focol0 = BasicLayer( 156 | dim=self.dims[0], 157 | depth=2, 158 | mlp_ratio=4., 159 | drop=0., 160 | drop_path=0.4, 161 | norm_layer=nn.LayerNorm, 162 | focal_window=9, 163 | focal_level=2, 164 | use_layerscale=False, 165 | use_checkpoint=False) 166 | 167 | self.focol1 = BasicLayer( 168 | dim=self.dims[1], 169 | depth=2, 170 | mlp_ratio=4., 171 | drop=0., 172 | drop_path=0.4, 173 | norm_layer=nn.LayerNorm, 174 | focal_window=9, 175 | focal_level=2, 176 | use_layerscale=False, 177 | use_checkpoint=False) 178 | 179 | self.focol2 = BasicLayer( 180 | dim=self.dims[2], 181 | depth=2, 182 | mlp_ratio=4., 183 | drop=0., 184 | drop_path=0.4, 185 | norm_layer=nn.LayerNorm, 186 | focal_window=9, 187 | focal_level=2, 188 | use_layerscale=False, 189 | use_checkpoint=False) 190 | 191 | self.focol3 = BasicLayer( 192 | dim=self.dims[3], 193 | depth=2, 194 | mlp_ratio=4., 195 | drop=0., 196 | drop_path=0.4, 197 | norm_layer=nn.LayerNorm, 198 | focal_window=9, 199 | focal_level=2, 200 | use_layerscale=False, 201 | 
use_checkpoint=False) 202 | else: 203 | if self.order == 5: 204 | self.focol0 = BasicLayer( 205 | dim=self.dims[0], 206 | depth=2, 207 | mlp_ratio=4., 208 | drop=0., 209 | drop_path=0.4, 210 | norm_layer=nn.LayerNorm, 211 | focal_window=9, 212 | focal_level=2, 213 | use_layerscale=False, 214 | use_checkpoint=False) 215 | 216 | self.focol1 = BasicLayer( 217 | dim=self.dims[1], 218 | depth=2, 219 | mlp_ratio=4., 220 | drop=0., 221 | drop_path=0.4, 222 | norm_layer=nn.LayerNorm, 223 | focal_window=9, 224 | focal_level=2, 225 | use_layerscale=False, 226 | use_checkpoint=False) 227 | 228 | self.focol2 = BasicLayer( 229 | dim=self.dims[2], 230 | depth=2, 231 | mlp_ratio=4., 232 | drop=0., 233 | drop_path=0.4, 234 | norm_layer=nn.LayerNorm, 235 | focal_window=9, 236 | focal_level=2, 237 | use_layerscale=False, 238 | use_checkpoint=False) 239 | 240 | self.focol3 = BasicLayer( 241 | dim=self.dims[3], 242 | depth=2, 243 | mlp_ratio=4., 244 | drop=0., 245 | drop_path=0.4, 246 | norm_layer=nn.LayerNorm, 247 | focal_window=9, 248 | focal_level=2, 249 | use_layerscale=False, 250 | use_checkpoint=False) 251 | 252 | self.focol4 = BasicLayer( 253 | dim=self.dims[4], 254 | depth=2, 255 | mlp_ratio=4., 256 | drop=0., 257 | drop_path=0.4, 258 | norm_layer=nn.LayerNorm, 259 | focal_window=9, 260 | focal_level=2, 261 | use_layerscale=False, 262 | use_checkpoint=False) 263 | else: 264 | if self.order == 2: 265 | self.focol0 = BasicLayer( 266 | dim=self.dims[0], 267 | depth=2, 268 | mlp_ratio=4., 269 | drop=0., 270 | drop_path=0.4, 271 | norm_layer=nn.LayerNorm, 272 | focal_window=9, 273 | focal_level=2, 274 | use_layerscale=False, 275 | use_checkpoint=False) 276 | 277 | self.focol1 = BasicLayer( 278 | dim=self.dims[1], 279 | depth=2, 280 | mlp_ratio=4., 281 | drop=0., 282 | drop_path=0.4, 283 | norm_layer=nn.LayerNorm, 284 | focal_window=9, 285 | focal_level=2, 286 | use_layerscale=False, 287 | use_checkpoint=False) 288 | 289 | 290 | self.AG = nn.ModuleList([ 291 | Attention_block(F_g=self.dims[i + 1], F_l=self.dims[i + 1], F_int=self.dims[i + 1]) for i in 292 | range(order - 1)] 293 | ) 294 | 295 | self.AG0 = Attention_block(F_g=self.dims[0], F_l=self.dims[0], F_int=self.dims[0]) 296 | 297 | if self.order == 3: 298 | self.AG1 = Attention_block(F_g=self.dims[1], F_l=self.dims[1], F_int=self.dims[1]) 299 | self.AG2 = Attention_block(F_g=self.dims[2], F_l=self.dims[2], F_int=self.dims[2]) 300 | else: 301 | if self.order == 4: 302 | self.AG1 = Attention_block(F_g=self.dims[1], F_l=self.dims[1], F_int=self.dims[1]) 303 | self.AG2 = Attention_block(F_g=self.dims[2], F_l=self.dims[2], F_int=self.dims[2]) 304 | self.AG3 = Attention_block(F_g=self.dims[3], F_l=self.dims[3], F_int=self.dims[3]) 305 | else: 306 | if self.order == 5: 307 | self.AG1 = Attention_block(F_g=self.dims[1], F_l=self.dims[1], F_int=self.dims[1]) 308 | self.AG2 = Attention_block(F_g=self.dims[2], F_l=self.dims[2], F_int=self.dims[2]) 309 | self.AG3 = Attention_block(F_g=self.dims[3], F_l=self.dims[3], F_int=self.dims[3]) 310 | self.AG4 = Attention_block(F_g=self.dims[4], F_l=self.dims[4], F_int=self.dims[4]) 311 | else: 312 | if self.order == 2: 313 | self.AG1 = Attention_block(F_g=self.dims[1], F_l=self.dims[1], F_int=self.dims[1]) 314 | 315 | 316 | self.scale = s 317 | 318 | print('[HFConv]', order, 'order with dims=', self.dims, 'scale=%.4f' % self.scale) 319 | 320 | def forward(self, x, mask=None, dummy=False): 321 | B, C, H, W = x.shape 322 | 323 | fused_x = self.proj_in(x) 324 | pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), 
dim=1) 325 | 326 | dw_abc = self.dwconv(abc) * self.scale 327 | 328 | dw_list = torch.split(dw_abc, self.dims, dim=1) 329 | x = self.focol0(pwa) 330 | x = self.AG0(g=x, x=dw_list[0]) # pwa * dw_list[0] 331 | x = self.focol1(self.conv0(x)) 332 | x = self.AG1(g=x, x=dw_list[1]) 333 | 334 | if self.order == 3: 335 | x = self.focol2(self.conv1(x)) 336 | x = self.AG2(g=x, x=dw_list[2]) 337 | else: 338 | if self.order == 4: 339 | x = self.focol2(self.conv1(x)) 340 | x = self.AG2(g=x, x=dw_list[2]) 341 | x = self.focol3(self.conv2(x)) 342 | x = self.AG3(g=x, x=dw_list[3]) 343 | else: 344 | if self.order == 5: 345 | x = self.focol2(self.conv1(x)) 346 | x = self.AG2(g=x, x=dw_list[2]) 347 | x = self.focol3(self.conv2(x)) 348 | x = self.AG3(g=x, x=dw_list[3]) 349 | x = self.focol4(self.conv3(x)) 350 | x = self.AG4(g=x, x=dw_list[4]) 351 | #else: 352 | # print('Please select 2, 3, 4 and 5 order') 353 | 354 | x = self.proj_out(x) 355 | 356 | return x 357 | 358 | 359 | class Block(nn.Module): 360 | """ 361 | HFblock 362 | """ 363 | 364 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, HFConv=HFConv): 365 | super().__init__() 366 | 367 | self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first') 368 | self.HFConv = HFConv(dim) # depthwise conv 369 | self.norm2 = LayerNorm(dim, eps=1e-6) 370 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 371 | self.act = nn.GELU() 372 | self.pwconv2 = nn.Linear(4 * dim, dim) 373 | 374 | self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim), 375 | requires_grad=True) if layer_scale_init_value > 0 else None 376 | 377 | self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 378 | requires_grad=True) if layer_scale_init_value > 0 else None 379 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 380 | 381 | def forward(self, x): 382 | B, C, H, W = x.shape 383 | if self.gamma1 is not None: 384 | gamma1 = self.gamma1.view(C, 1, 1) 385 | else: 386 | gamma1 = 1 387 | x = x + self.drop_path(gamma1 * self.HFConv(self.norm1(x))) 388 | 389 | input = x 390 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 391 | x = self.norm2(x) 392 | x = self.pwconv1(x) 393 | x = self.act(x) 394 | x = self.pwconv2(x) 395 | if self.gamma2 is not None: 396 | x = self.gamma2 * x 397 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 398 | 399 | x = input + self.drop_path(x) 400 | return x 401 | 402 | class SingleConv1(nn.Module): 403 | def __init__(self, in_channels, out_channels, ker_size=3, padding=1): 404 | super().__init__() 405 | self.Single_Conv = nn.Sequential( 406 | nn.Conv2d(in_channels, out_channels, kernel_size=ker_size, padding=padding), 407 | nn.BatchNorm2d(out_channels), 408 | nn.ReLU(inplace=True), 409 | ) 410 | def forward(self, x): 411 | return self.Single_Conv(x) 412 | 413 | class Edge_aware_unit(nn.Module): 414 | """ 415 | Lesion localization module(LL-M) 416 | """ 417 | 418 | def __init__(self, in_channels, out_channels): 419 | super(Edge_aware_unit, self).__init__() 420 | self.conv1 = SingleConv1(in_channels, out_channels, ker_size=1, padding=0) 421 | self.conv2 = SingleConv1(out_channels, out_channels // 2, ker_size=3) 422 | self.conv3 = SingleConv1(out_channels//2, out_channels//2, ker_size=3) 423 | self.conv4 = SingleConv1(in_channels, out_channels//2, ker_size=3) 424 | 425 | def detect_edge_1(self, inputs, sobel_kernel): 426 | kernel = np.array(sobel_kernel, dtype='float32') 427 | kernel = kernel.reshape((1, 1, 3, 3)) 428 | weight = Variable(torch.from_numpy(kernel)).to(device) 429 | edge = torch.zeros(inputs.size()[1],inputs.size()[0],inputs.size()[2],inputs.size()[3]).to(device) 430 | for k in range(inputs.size()[1]): 431 | fea_input = inputs[:,k,:,:] 432 | fea_input = fea_input.unsqueeze(1) 433 | edge_c = F.conv2d(fea_input, weight, padding=1) 434 | edge[k] = edge_c.squeeze(1) 435 | edge = edge.permute(1, 0, 2, 3) 436 | return edge 437 | 438 | def detect_edge_2(self, inputs, sobel_kernel): 439 | kernel = np.array(sobel_kernel, dtype='float32') 440 | kernel = kernel.reshape((1, 1, 5, 5)) 441 | weight = Variable(torch.from_numpy(kernel)).to(device) 442 | edge = torch.zeros(inputs.size()[1],inputs.size()[0],inputs.size()[2],inputs.size()[3]).to(device) 443 | for k in range(inputs.size()[1]): 444 | fea_input = inputs[:,k,:,:] 445 | fea_input = fea_input.unsqueeze(1) 446 | edge_c = F.conv2d(fea_input, weight, padding=2) 447 | edge[k] = edge_c.squeeze(1) 448 | edge = edge.permute(1, 0, 2, 3) 449 | return edge 450 | 451 | def sobel_conv2d(self, inputs): 452 | edge_detect1 = self.detect_edge_1(inputs, [[2, 4, 2], [0, 0, 0], [-2, -4, -2]]) 453 | edge_detect2 = self.detect_edge_2(inputs, [[0, 0, 0, 0, 0], [0, -2, -4, -2, 0], [-1, -4, 0, 4, 1], [0, 2, 4, 2, 0], [0, 0, 0, 0, 0]]) 454 | edge_detect3 = self.detect_edge_1(inputs, [[4, 2, 0], [2, 0, -2], [0, -2, -4]]) 455 | edge_detect4 = self.detect_edge_2(inputs, [[0, 0, -1, 0, 0], [0, -2, -4, 2, 0], [0, -4, 0, 4, 0], [0, -2, 4, 2, 0], [0, 0, 1, 0, 0]]) 456 | edge_detect5 = self.detect_edge_1(inputs, [[2, 0, -2], [4, 0, -4], [2, 0, -2]]) 457 | edge_detect6 = self.detect_edge_2(inputs, [[0, 0, 1, 0, 0], [0, -2, 4, 2, 0], [0, -4, 0, 4, 0], [0, -2, -4, 2, 0], [0, 0, -1, 0, 0]]) 458 | edge_detect7 = self.detect_edge_1(inputs, [[0, -2, -4], [2, 0, -2], [4, 2, 0]]) 459 | edge_detect8 = self.detect_edge_2(inputs, 
[[0, 0, 0, 0, 0], [0, 2, 4, 2, 0], [-1, -4, 0, 4, 1], [0, -2, -4, -2, 0], [0, 0, 0, 0, 0]]) 460 | edge = edge_detect1+edge_detect2+edge_detect3+edge_detect4+edge_detect5+edge_detect6+edge_detect7+edge_detect8 461 | return edge 462 | 463 | def forward(self, input_f): 464 | conv1 = self.conv1(input_f) 465 | conv2 = self.conv2(conv1) 466 | edge_f = self.sobel_conv2d(conv2) 467 | conv3 = self.conv3(edge_f) 468 | input_f = self.conv4(input_f) 469 | conca = torch.cat([input_f, conv3], dim=1) 470 | return conca 471 | 472 | 473 | class Spatial_Att_Bridge(nn.Module): 474 | def __init__(self): 475 | super().__init__() 476 | self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3), 477 | nn.Sigmoid()) 478 | 479 | def forward(self, t1, t2, t3, t4, t5): 480 | t_list = [t1, t2, t3, t4, t5] 481 | att_list = [] 482 | for t in t_list: 483 | avg_out = torch.mean(t, dim=1, keepdim=True) 484 | max_out, _ = torch.max(t, dim=1, keepdim=True) 485 | att = torch.cat([avg_out, max_out], dim=1) 486 | att = self.shared_conv2d(att) 487 | att_list.append(att) 488 | return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4] 489 | 490 | 491 | class SC_Att_Bridge(nn.Module): 492 | def __init__(self): 493 | super().__init__() 494 | 495 | self.satt = Spatial_Att_Bridge() 496 | 497 | def forward(self, t1, t2, t3, t4, t5): 498 | r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5 499 | 500 | satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5) 501 | t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5 502 | 503 | t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5 504 | return t1, t2, t3, t4, t5 505 | 506 | class HFUNet(nn.Module): 507 | 508 | def __init__(self, num_classes=1, input_channels=3, layer_scale_init_value=1e-6, HFConv=HFConv, block=Block, 509 | pretrained=None, 510 | use_checkpoint=False, c_list=[32, 64, 128, 256, 512, 1024], depths=[1, 1, 1], depths_out=[1, 1, 1], 511 | drop_path_rate=0., 512 | split_att='fc', mlp_ratio=4., drop_rate=0., depths_foc=[2, 2, 2], focal_levels=[2, 2, 2, 2], 513 | focal_windows=[9, 9, 9, 9], norm_layer=nn.LayerNorm, use_layerscale=False, bridge=True): 514 | super().__init__() 515 | self.pretrained = pretrained 516 | self.use_checkpoint = use_checkpoint 517 | self.bridge = bridge 518 | 519 | self.encoder1 = nn.Sequential( 520 | nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1), 521 | ) 522 | self.encoder2 = nn.Sequential( 523 | nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1), 524 | ) 525 | 526 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 527 | 528 | if not isinstance(HFConv, list): 529 | HFConv = [partial(HFConv, order=2, s=1 / 3, gflayer=GlobalLocalFilter), 530 | partial(HFConv, order=3, s=1 / 3, gflayer=GlobalLocalFilter), 531 | partial(HFConv, order=4, s=1 / 3, h=24, w=13, gflayer=GlobalLocalFilter), 532 | partial(HFConv, order=5, s=1 / 3, h=12, w=7, gflayer=GlobalLocalFilter)] 533 | else: 534 | HFConv = HFConv 535 | assert len(HFConv) == 3 536 | 537 | if isinstance(HFConv[0], str): 538 | HFConv = [eval(h) for h in HFConv] 539 | 540 | if isinstance(block, str): 541 | block = eval(block) 542 | 543 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_foc))] 544 | 545 | self.encoder3 = nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1) 546 | 547 | self.encoder4 = nn.Sequential( 548 | *[block(dim=c_list[2], drop_path=dp_rates[0 + j], 549 | layer_scale_init_value=layer_scale_init_value, HFConv=HFConv[2]) for j in 
range(depths[0])], 550 | Edge_aware_unit(c_list[2], c_list[2]), 551 | nn.Conv2d(c_list[2], c_list[3], 3, stride=1, padding=1), 552 | ) 553 | 554 | self.encoder5 = nn.Sequential( 555 | *[block(dim=c_list[3], drop_path=dp_rates[1 + j], 556 | layer_scale_init_value=layer_scale_init_value, HFConv=HFConv[2]) for j in range(depths[1])], 557 | Edge_aware_unit(c_list[3], c_list[3]), 558 | nn.Conv2d(c_list[3], c_list[4], 3, stride=1, padding=1), 559 | ) 560 | 561 | self.encoder6 = nn.Sequential( 562 | *[block(dim=c_list[4], drop_path=dp_rates[2 + j], 563 | layer_scale_init_value=layer_scale_init_value, HFConv=HFConv[2]) for j in range(depths[2])], 564 | Edge_aware_unit(c_list[4], c_list[4]), 565 | nn.Conv2d(c_list[4], c_list[5], 3, stride=1, padding=1), 566 | ) 567 | 568 | # build Bottleneck layers 569 | self.ConvMixer = ConvMixerBlock(dim=c_list[5], depth=7, k=7) 570 | # Skip-connection 571 | self.mdag4 = MDAG(channel=c_list[4]) 572 | self.mdag3 = MDAG(channel=c_list[3]) 573 | self.mdag2 = MDAG(channel=c_list[2]) 574 | self.mdag1 = MDAG(channel=c_list[1]) 575 | self.mdag0 = MDAG(channel=c_list[0]) 576 | 577 | if bridge: 578 | self.scab = SC_Att_Bridge() 579 | print('SC_Att_Bridge was used') 580 | 581 | dp_rates_out = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_out))] 582 | 583 | self.decoder1 = nn.Sequential( 584 | nn.Conv2d(c_list[5], c_list[4], 3, stride=1, padding=1), 585 | Edge_aware_unit(c_list[4], c_list[4]), 586 | *[block(dim=c_list[4], drop_path=dp_rates_out[2 + j], 587 | layer_scale_init_value=layer_scale_init_value, HFConv=HFConv[2]) for j in range(depths_out[2])], 588 | ) 589 | 590 | self.decoder2 = nn.Sequential( 591 | nn.Conv2d(c_list[4], c_list[3], 3, stride=1, padding=1), 592 | Edge_aware_unit(c_list[3], c_list[3]), 593 | *[block(dim=c_list[3], drop_path=dp_rates_out[1 + j], 594 | layer_scale_init_value=layer_scale_init_value, HFConv=HFConv[2]) for j in range(depths_out[1])], 595 | ) 596 | 597 | self.decoder3 = nn.Sequential( 598 | nn.Conv2d(c_list[3], c_list[2], 3, stride=1, padding=1), 599 | Edge_aware_unit(c_list[2], c_list[2]), 600 | *[block(dim=c_list[2], drop_path=dp_rates_out[0 + j], 601 | layer_scale_init_value=layer_scale_init_value, HFConv=HFConv[2]) for j in range(depths_out[0])], 602 | ) 603 | 604 | self.decoder4 = nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1) 605 | 606 | self.decoder5 = nn.Sequential( 607 | nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1), 608 | ) 609 | 610 | self.ebn1 = nn.GroupNorm(4, c_list[0]) 611 | self.ebn2 = nn.GroupNorm(4, c_list[1]) 612 | self.ebn3 = nn.GroupNorm(4, c_list[2]) 613 | self.ebn4 = nn.GroupNorm(4, c_list[3]) 614 | self.ebn5 = nn.GroupNorm(4, c_list[4]) 615 | self.ebn6 = nn.GroupNorm(4, c_list[5]) 616 | self.dbn1 = nn.GroupNorm(4, c_list[4]) 617 | self.dbn2 = nn.GroupNorm(4, c_list[3]) 618 | self.dbn3 = nn.GroupNorm(4, c_list[2]) 619 | self.dbn4 = nn.GroupNorm(4, c_list[1]) 620 | self.dbn5 = nn.GroupNorm(4, c_list[0]) 621 | 622 | self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1) 623 | 624 | self.apply(self._init_weights) 625 | 626 | def _init_weights(self, m): 627 | if isinstance(m, nn.Linear): 628 | trunc_normal_(m.weight, std=.02) 629 | if isinstance(m, nn.Linear) and m.bias is not None: 630 | nn.init.constant_(m.bias, 0) 631 | elif isinstance(m, nn.Conv1d): 632 | n = m.kernel_size[0] * m.out_channels 633 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 634 | elif isinstance(m, nn.Conv2d): 635 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 636 | fan_out //= m.groups 637 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 638 | if m.bias is not None: 639 | m.bias.data.zero_() 640 | 641 | def forward(self, x): 642 | 643 | out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)), 2, 2)) 644 | t1 = out # b, c0, H/2, W/2 645 | 646 | out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)), 2, 2)) 647 | t2 = out # b, c1, H/4, W/4 648 | 649 | out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)), 2, 2)) 650 | t3 = out # b, c2, H/8, W/8 651 | 652 | out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)), 2, 2)) 653 | t4 = out # b, c3, H/16, W/16 654 | 655 | out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)), 2, 2)) 656 | t5 = out # b, c4, H/32, W/32 657 | 658 | t5 = self.mdag4(x=t5) 659 | t4 = self.mdag3(x=t4) 660 | t3 = self.mdag2(x=t3) 661 | t2 = self.mdag1(x=t2) 662 | t1 = self.mdag0(x=t1) 663 | 664 | out = F.gelu((self.ebn6(self.encoder6(out)))) 665 | out = self.ConvMixer(out) 666 | 667 | 668 | out5 = F.gelu(self.dbn1(self.decoder1(out))) 669 | out5 = torch.add(out5, t5) 670 | 671 | out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)), scale_factor=(2, 2), mode='bilinear', 672 | align_corners=True)) 673 | out4 = torch.add(out4, t4) 674 | 675 | out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)), scale_factor=(2, 2), mode='bilinear', 676 | align_corners=True)) 677 | out3 = torch.add(out3, t3) 678 | 679 | out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)), scale_factor=(2, 2), mode='bilinear', 680 | align_corners=True)) 681 | out2 = torch.add(out2, t2) 682 | 683 | out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)), scale_factor=(2, 2), mode='bilinear', 684 | align_corners=True)) 685 | out1 = torch.add(out1, t1) 686 | 687 | out0 = F.interpolate(self.final(out1), scale_factor=(2, 2), mode='bilinear', 688 | align_corners=True) 689 | 690 | return torch.sigmoid(out0) 691 | 692 | class LayerNorm(nn.Module): 693 | 694 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 695 | super().__init__() 696 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 697 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 698 | self.eps = eps 699 | self.data_format = data_format 700 | if self.data_format not in ["channels_last", "channels_first"]: 701 | raise NotImplementedError 702 | self.normalized_shape = (normalized_shape,) 703 | 704 | def forward(self, x): 705 | if self.data_format == "channels_last": 706 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 707 | elif self.data_format == "channels_first": 708 | u = x.mean(1, keepdim=True) 709 | s = (x - u).pow(2).mean(1, keepdim=True) 710 | x = (x - u) / torch.sqrt(s + self.eps) 711 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 712 | return x 713 | 714 | 715 | class GlobalLocalFilter(nn.Module): 716 | def __init__(self, dim, h=14, w=8): 717 | super().__init__() 718 | self.dw = nn.Sequential( 719 | nn.Conv2d(dim, dim // 2, kernel_size=1, padding=0, bias=False, groups=dim // 2), 720 | nn.Conv2d(dim // 2, dim // 2, kernel_size=3, padding=1, bias=False, groups=dim // 2), 721 | ) 722 | self.conv2d = nn.Conv2d(dim, dim // 2, kernel_size=1, padding=0, bias=False, groups=dim // 2) 723 | self.complex_weight = nn.Parameter(torch.randn(dim, h, w, 2, dtype=torch.float32) * 0.02) 724 | trunc_normal_(self.complex_weight, std=.02) 725 | self.pre_norm = LayerNorm(dim, eps=1e-6, 
data_format='channels_first') 726 | self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first') 727 | 728 | def forward(self, x): 729 | x = self.pre_norm(x) 730 | x1 = self.dw(x) 731 | 732 | x2 = x.to(torch.float32) 733 | B, C, a, b = x2.shape 734 | x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho') 735 | 736 | weight = self.complex_weight 737 | if not weight.shape[1:3] == x2.shape[2:4]: 738 | weight = F.interpolate(weight.permute(3, 0, 1, 2), size=x2.shape[2:4], mode='bilinear', 739 | align_corners=True).permute(1, 2, 3, 0) 740 | 741 | weight = torch.view_as_complex(weight.contiguous()) 742 | 743 | x2 = x2 * weight 744 | x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho') 745 | x2 = self.conv2d(x2) 746 | 747 | x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2).reshape(B, C, a, b) 748 | x = self.post_norm(x) 749 | return x 750 | 751 | 752 | class Mlp(nn.Module): 753 | """ Multilayer perceptron.""" 754 | 755 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 756 | super().__init__() 757 | out_features = out_features or in_features 758 | hidden_features = hidden_features or in_features 759 | self.fc1 = nn.Linear(in_features, hidden_features) 760 | self.act = act_layer() 761 | self.fc2 = nn.Linear(hidden_features, out_features) 762 | self.drop = nn.Dropout(drop) 763 | 764 | def forward(self, x): 765 | x = self.fc1(x) 766 | x = self.act(x) 767 | x = self.drop(x) 768 | x = self.fc2(x) 769 | x = self.drop(x) 770 | return x 771 | 772 | 773 | class FocalModulation(nn.Module): 774 | """ Focal Modulation 775 | 776 | Args: 777 | dim (int): Number of input channels. 778 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 779 | focal_level (int): Number of focal levels 780 | focal_window (int): Focal window size at focal level 1 781 | focal_factor (int, default=2): Step to increase the focal window 782 | use_postln (bool, default=False): Whether use post-modulation layernorm 783 | """ 784 | 785 | def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False): 786 | 787 | super().__init__() 788 | self.dim = dim 789 | 790 | # specific args for focalv3 791 | self.focal_level = focal_level 792 | self.focal_window = focal_window 793 | self.focal_factor = focal_factor 794 | self.use_postln = use_postln 795 | 796 | self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=True) 797 | self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True) 798 | 799 | self.act = nn.GELU() 800 | self.proj = nn.Linear(dim, dim) 801 | self.proj_drop = nn.Dropout(proj_drop) 802 | self.focal_layers = nn.ModuleList() 803 | 804 | if self.use_postln: 805 | self.ln = nn.LayerNorm(dim) 806 | 807 | for k in range(self.focal_level): 808 | kernel_size = self.focal_factor * k + self.focal_window 809 | self.focal_layers.append( 810 | nn.Sequential( 811 | nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim, 812 | padding=kernel_size // 2, bias=False), 813 | nn.GELU(), 814 | ) 815 | ) 816 | 817 | def forward(self, x): 818 | """ Forward function. 
819 | 820 | Args: 821 | x: input features with shape of (B, H, W, C) 822 | """ 823 | B, nH, nW, C = x.shape 824 | x = self.f(x) 825 | x = x.permute(0, 3, 1, 2).contiguous() 826 | q, ctx, gates = torch.split(x, (C, C, self.focal_level + 1), 1) 827 | 828 | ctx_all = 0 829 | for l in range(self.focal_level): 830 | ctx = self.focal_layers[l](ctx) 831 | ctx_all = ctx_all + ctx * gates[:, l:l + 1] 832 | ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) 833 | ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:] 834 | 835 | x_out = q * self.h(ctx_all) 836 | x_out = x_out.permute(0, 2, 3, 1).contiguous() 837 | if self.use_postln: 838 | x_out = self.ln(x_out) 839 | x_out = self.proj(x_out) 840 | x_out = self.proj_drop(x_out) 841 | return x_out 842 | 843 | 844 | class FocalModulationBlock(nn.Module): 845 | """ Focal Modulation Block. 846 | 847 | Args: 848 | dim (int): Number of input channels. 849 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 850 | drop (float, optional): Dropout rate. Default: 0.0 851 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 852 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 853 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 854 | focal_level (int): number of focal levels 855 | focal_window (int): focal kernel size at level 1 856 | """ 857 | 858 | def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., 859 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, 860 | focal_level=2, focal_window=9, use_layerscale=False, layerscale_value=1e-4): 861 | super().__init__() 862 | self.dim = dim 863 | self.mlp_ratio = mlp_ratio 864 | self.focal_window = focal_window 865 | self.focal_level = focal_level 866 | self.use_layerscale = use_layerscale 867 | 868 | self.norm1 = norm_layer(dim) 869 | self.modulation = FocalModulation( 870 | dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop 871 | ) 872 | 873 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 874 | self.norm2 = norm_layer(dim) 875 | mlp_hidden_dim = int(dim * mlp_ratio) 876 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 877 | 878 | self.H = None 879 | self.W = None 880 | 881 | self.gamma_1 = 1.0 882 | self.gamma_2 = 1.0 883 | if self.use_layerscale: 884 | self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) 885 | self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) 886 | 887 | def forward(self, x): 888 | """ Forward function. 889 | 890 | Args: 891 | x: Input feature, tensor size (B, H*W, C). 892 | H, W: Spatial resolution of the input feature. 893 | """ 894 | B, C, H, W = x.shape 895 | x = x.view(B, H, W, C) 896 | x = self.norm1(x) 897 | shortcut = x.view(B, H * W, C) 898 | x = x.view(B, H, W, C) 899 | 900 | # FM 901 | x = self.modulation(x).view(B, H * W, C) 902 | 903 | # FFN 904 | x = shortcut + self.drop_path(self.gamma_1 * x) 905 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 906 | x = x.view(B, C, H, W) 907 | 908 | return x 909 | 910 | 911 | class BasicLayer(nn.Module): 912 | """ A basic focal modulation layer for one stage. 913 | 914 | Args: 915 | dim (int): Number of feature channels 916 | depth (int): Depths of this stage. 917 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 918 | drop (float, optional): Dropout rate. 
Default: 0.0 919 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 920 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 921 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 922 | focal_level (int): Number of focal levels 923 | focal_window (int): Focal window size at focal level 1 924 | use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False 925 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 926 | """ 927 | 928 | def __init__(self, 929 | dim, 930 | depth, 931 | mlp_ratio=4., 932 | drop=0., 933 | drop_path=0., 934 | norm_layer=nn.LayerNorm, 935 | downsample=None, 936 | focal_window=9, 937 | focal_level=2, 938 | use_conv_embed=False, 939 | use_layerscale=False, 940 | use_checkpoint=False 941 | ): 942 | super().__init__() 943 | self.depth = depth 944 | self.use_checkpoint = use_checkpoint 945 | 946 | # build blocks 947 | self.blocks = nn.ModuleList([ 948 | FocalModulationBlock( 949 | dim=dim, 950 | mlp_ratio=mlp_ratio, 951 | drop=drop, 952 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 953 | focal_window=focal_window, 954 | focal_level=focal_level, 955 | use_layerscale=use_layerscale, 956 | norm_layer=norm_layer) 957 | for i in range(depth)]) 958 | 959 | # patch merging layer 960 | if downsample is not None: 961 | self.downsample = downsample( 962 | patch_size=2, 963 | in_chans=dim, embed_dim=2 * dim, 964 | use_conv_embed=use_conv_embed, 965 | norm_layer=norm_layer, 966 | is_stem=False 967 | ) 968 | 969 | else: 970 | self.downsample = None 971 | 972 | def forward(self, x): 973 | """ Forward function. 974 | 975 | Args: 976 | x: Input feature, tensor size (B, H*W, C). 977 | H, W: Spatial resolution of the input feature. 
978 | """ 979 | 980 | for blk in self.blocks: 981 | if self.use_checkpoint: 982 | x = checkpoint.checkpoint(blk, x) 983 | else: 984 | x = blk(x) 985 | return x 986 | 987 | 988 | class Residual(nn.Module): 989 | def __init__(self, fn): 990 | super().__init__() 991 | self.fn = fn 992 | 993 | def forward(self, x): 994 | return self.fn(x) + x 995 | 996 | 997 | # bottleneck 998 | class ConvMixerBlock(nn.Module): 999 | def __init__(self, dim=256, depth=7, k=7): 1000 | super(ConvMixerBlock, self).__init__() 1001 | self.block = nn.Sequential( 1002 | *[nn.Sequential( 1003 | Residual(nn.Sequential( 1004 | # deep wise 1005 | nn.Conv2d(dim, dim, kernel_size=(k, k), groups=dim, padding=(k // 2, k // 2)), 1006 | nn.GELU(), 1007 | nn.BatchNorm2d(dim) 1008 | )), 1009 | nn.Conv2d(dim, dim, kernel_size=(1, 1)), 1010 | nn.GELU(), 1011 | nn.BatchNorm2d(dim) 1012 | ) for i in range(depth)] 1013 | ) 1014 | 1015 | def forward(self, x): 1016 | x = self.block(x) 1017 | return x 1018 | 1019 | class MDAG(nn.Module): 1020 | """ 1021 | Multi-dilation attention gate 1022 | """ 1023 | 1024 | def __init__(self, channel, k_size=3, dilated_ratio=[7, 5, 2, 1]): 1025 | super(MDAG, self).__init__() 1026 | self.channel = channel 1027 | self.mda0 = nn.Sequential( 1028 | nn.Conv2d(channel, channel, kernel_size=k_size, stride=1, 1029 | padding=(k_size + (k_size - 1) * (dilated_ratio[0] - 1)) // 2, 1030 | dilation=dilated_ratio[0]), 1031 | nn.BatchNorm2d(self.channel)) 1032 | self.mda1 = nn.Sequential( 1033 | nn.Conv2d(channel, channel, kernel_size=k_size, stride=1, 1034 | padding=(k_size + (k_size - 1) * (dilated_ratio[1] - 1)) // 2, 1035 | dilation=dilated_ratio[1]), 1036 | nn.BatchNorm2d(self.channel)) 1037 | self.mda2 = nn.Sequential( 1038 | nn.Conv2d(channel, channel, kernel_size=k_size, stride=1, 1039 | padding=(k_size + (k_size - 1) * (dilated_ratio[2] - 1)) // 2, 1040 | dilation=dilated_ratio[2]), 1041 | nn.BatchNorm2d(self.channel)) 1042 | self.mda3 = nn.Sequential( 1043 | nn.Conv2d(channel, channel, kernel_size=k_size, stride=1, 1044 | padding=(k_size + (k_size - 1) * (dilated_ratio[3] - 1)) // 2, 1045 | dilation=dilated_ratio[3]), 1046 | nn.BatchNorm2d(self.channel)) 1047 | self.voteConv = nn.Sequential( 1048 | nn.Conv2d(self.channel * 4, self.channel, kernel_size=(1, 1)), 1049 | nn.BatchNorm2d(self.channel), 1050 | nn.Sigmoid() 1051 | ) 1052 | self.relu = nn.ReLU(inplace=True) 1053 | self.AG = Attention_block(F_g=self.channel, F_l=self.channel, F_int=self.channel) 1054 | 1055 | def forward(self, x): 1056 | x1 = self.mda0(x) 1057 | x2 = self.mda1(x) 1058 | x3 = self.mda2(x) 1059 | x4 = self.mda3(x) 1060 | _x = self.relu(torch.cat((x1, x2, x3, x4), dim=1)) 1061 | _x = self.voteConv(_x) 1062 | x = x + x * _x 1063 | return x 1064 | 1065 | --------------------------------------------------------------------------------
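Illustrative usage sketches for models/HF_UNet.py (assuming the classes are importable exactly as laid out in the tree above). The FocalModulation block takes channels-last features, projects them into a query, a context map and per-level gates, and multiplies the query by the gated, aggregated context. A minimal shape check:

import torch
from models.HF_UNet import FocalModulation  # assumed importable from the file above

# Channels-last input, as FocalModulation.forward expects: (B, H, W, C).
fm = FocalModulation(dim=64, focal_level=2, focal_window=9)
x = torch.randn(2, 16, 16, 64)
out = fm(x)
# self.f maps C -> 2*C + (focal_level + 1); the split gives q (C), ctx (C) and 3 gates,
# and the modulated output keeps the input shape.
print(out.shape)  # torch.Size([2, 16, 16, 64])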
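The order-dependent if/else ladders inside HFConv instantiate one BasicLayer (focol0 ... focolN) and one Attention_block (AG0 ... AGN) per split dimension, all with the same hyperparameters. A loop-based construction that builds the same modules for the branches shown above, sketched on the assumption that BasicLayer and Attention_block are importable from models/HF_UNet.py:

import torch.nn as nn
from models.HF_UNet import BasicLayer, Attention_block  # assumed importable

def build_focal_and_gates(dims, order):
    # One focal-modulation stage per gate dimension, mirroring focol0..focolN above.
    focols = nn.ModuleList([
        BasicLayer(dim=dims[i], depth=2, mlp_ratio=4., drop=0., drop_path=0.4,
                   norm_layer=nn.LayerNorm, focal_window=9, focal_level=2,
                   use_layerscale=False, use_checkpoint=False)
        for i in range(order)
    ])
    # One attention gate per dimension, mirroring AG0..AGN above.
    gates = nn.ModuleList([
        Attention_block(F_g=dims[i], F_l=dims[i], F_int=dims[i])
        for i in range(order)
    ])
    return focols, gates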
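Lastly, an end-to-end smoke test for HFUNet with its default channel list. It assumes 256x256 RGB inputs (any spatial size divisible by 32 should work, matching the ISIC-style configs) and that the module-level device used by Edge_aware_unit's Sobel kernels is the same device the model and batch are moved to:

import torch
from models.HF_UNet import HFUNet  # assumed importable from the file above

# Keep model, input and Edge_aware_unit's Sobel kernels on the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = HFUNet(num_classes=1, input_channels=3).to(device).eval()
x = torch.randn(2, 3, 256, 256, device=device)  # dummy batch: B=2, RGB, 256x256
with torch.no_grad():
    y = model(x)  # the final 1x1 conv output is upsampled and passed through sigmoid
print(y.shape)  # expected: torch.Size([2, 1, 256, 256])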