├── model
│   ├── iterativeRefinementModels
│   │   ├── __init__.py
│   │   ├── RITM_modules
│   │   │   ├── __init__.py
│   │   │   ├── RITM_ocr.py
│   │   │   └── RITM_hrnet.py
│   │   └── RITM_SE_HRNet32.py
│   └── __init__.py
├── train.sh
├── test.sh
├── dataset
│   ├── transforms.py
│   ├── __init__.py
│   └── dataset.py
├── config
│   └── spineweb_ours.yaml
├── morph_pairs
│   └── dataset16
│       └── dataset16.json
├── misc
│   ├── optimizer.py
│   ├── loss.py
│   ├── metric.py
│   ├── heatmap_maker.py
│   └── train.py
├── main.py
├── data_preprocessing.py
├── README.md
└── util.py

/model/iterativeRefinementModels/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/model/iterativeRefinementModels/RITM_modules/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

seed="42"
gpu='1'

config='spineweb_ours'
default_command="--seed ${seed} --config ${config}"
custom_command=""
CUDA_VISIBLE_DEVICES="${gpu}" python -u main.py ${default_command} ${custom_command} --save_test_prediction
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

seed="42"
gpu='1'

config='_'
default_command="--seed ${seed} --config ${config}"
custom_command=""
CUDA_VISIBLE_DEVICES="${gpu}" python -u main.py ${default_command} ${custom_command} --save_test_prediction --only_test_version "ExpNum[00001]_Dataset[dataset16]_Model[RITM_SE_HRNet32]_config[spineweb_ours]_seed[42]"
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------


def get_model(save_manager):
    name = save_manager.config.Model.NAME

    if name == 'RITM_SE_HRNet32':
        from model.iterativeRefinementModels.RITM_SE_HRNet32 import RITM as Model
    else:
        save_manager.write_log('ERROR: NOT SPECIFIED MODEL NAME: {}'.format(name))
        raise NotImplementedError  # NotImplemented is a constant, not an exception; raise the error class

    model = Model(save_manager.config)
    return model
--------------------------------------------------------------------------------
/dataset/transforms.py:
--------------------------------------------------------------------------------
import cv2
import albumentations

def default_aug(img_size):
    return albumentations.Compose([
        albumentations.augmentations.geometric.rotate.SafeRotate((-15, 15), p=0.5, border_mode=cv2.BORDER_CONSTANT),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.augmentations.geometric.resize.RandomScale((0.9, 1.2), p=0.5),
        albumentations.RandomBrightnessContrast(),
        albumentations.augmentations.geometric.resize.Resize(img_size[0], img_size[1], p=1)
    ], keypoint_params=albumentations.KeypointParams(format='xy', remove_invisible=False))

def fake(**kwargs):
    return {**kwargs}
--------------------------------------------------------------------------------
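The augmentation pipeline above is consumed through albumentations' dict interface. A minimal sketch with a dummy image and keypoints (the (512, 256) target size follows `config/spineweb_ours.yaml`):

```python
import numpy as np
from dataset.transforms import default_aug

aug = default_aug(img_size=(512, 256))
image = np.zeros((800, 500, 3), dtype=np.uint8)   # (row, col, channel)
keypoints = [(100.0, 200.0), (250.0, 400.0)]      # 'xy' format, as declared in KeypointParams
out = aug(image=image, keypoints=keypoints)
print(out["image"].shape)    # (512, 256, 3) after the final Resize
print(out["keypoints"])      # keypoints mapped into the resized frame
```

Because `remove_invisible=False`, keypoints pushed outside the frame by SafeRotate or RandomScale are kept rather than dropped, which preserves the one-to-one keypoint ordering the dataset relies on.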
/config/spineweb_ours.yaml:
--------------------------------------------------------------------------------
PATH:
  ROOT_PATH: './'
  DATA:
    IMAGE: './data/dataset16_512/'
    TABLE: './data/dataset16_512/'

Dataset:
  NAME: 'dataset16'
  image_size: [512, 256]
  num_keypoint: 68
  heatmap_std: 7.5
  aug:
    type: 'aug_default'
    p: [0.2, 0.2, 0.2, 0.2]
  subpixel_decoding_patch_size: 15
  subpixel_decoding: True

Model:
  NAME: 'RITM_SE_HRNet32'
  SE_maxpool: True

Optimizer:
  optimizer: 'Adam'
  lr: 0.001
  scheduler: ''

Train:
  patience: 50
  batch_size: 4
  epoch: 5000
  metric: ["MAE", "RMSE", "MRE"]
  decision_metric: 'hargmax_mm_MRE'
  SR_standard: ''

Hint:
  max_hint: 13
  num_dist: dataset16

MISC:
  TB: True
  gpu: '0'
  num_workers: 0

Morph:
  use: True
  pairs: 'dataset16'
  angle_lambda: 0.01
  distance_lambda: 0.01
  distance_l1: True
  cosineSimilarityLoss: True
  threePointAngle: True
--------------------------------------------------------------------------------
/morph_pairs/dataset16/dataset16.json:
--------------------------------------------------------------------------------
[[[2, 4], [6, 8], [3, 5], [18, 20], [10, 12], [7, 9], [14, 16], [11, 13], [22, 24], [23, 25], [19, 21], [26, 28], [15, 17], [30, 32], [4, 6], [8, 12], [8, 10], [31, 33], [0, 2], [27, 29], [34, 36], [6, 10], [38, 40], [4, 8], [1, 3], [12, 16], [35, 37], [12, 14], [46, 48], [29, 33], [29, 31], [2, 6], [33, 37], [39, 41], [42, 44], [17, 21], [27, 31], [35, 39], [23, 27], [5, 7], [10, 14], [9, 13], [16, 20], [25, 27], [26, 30], [1, 5], [31, 35], [43, 45], [21, 25], [19, 23], [6, 12], [16, 18], [5, 9], [33, 35], [14, 18], [28, 30], [25, 29], [0, 4], [28, 32], [55, 57], [59, 61], [9, 11], [21, 23], [24, 26], [13, 17], [36, 40], [18, 22], [47, 49], [7, 11], [3, 7]], [[2, 63, 4], [2, 62, 4], [2, 61, 4], [2, 59, 4], [2, 60, 4], [2, 58, 4], [3, 63, 5], [3, 62, 5], [2, 57, 4], [2, 56, 4], [2, 55, 4], [3, 61, 5], [3, 60, 5], [2, 54, 4], [6, 63, 8], [3, 59, 5], [3, 58, 5], [2, 53, 4], [6, 62, 8], [2, 52, 4], [2, 51, 4], [6, 61, 8], [7, 62, 9], [3, 56, 5], [7, 63, 9], [3, 57, 5], [2, 50, 4], [6, 59, 8], [6, 60, 8], [3, 54, 5], [3, 55, 5], [7, 60, 9], [6, 58, 8], [2, 49, 4], [7, 61, 9], [2, 48, 4], [11, 62, 13], [11, 63, 13], [6, 57, 8], [0, 63, 2], [2, 47, 4], [7, 58, 9], [2, 46, 4], [3, 52, 5], [7, 59, 9], [10, 63, 12], [0, 62, 2], [3, 53, 5], [6, 56, 8], [6, 55, 8], [3, 50, 5], [0, 61, 2], [10, 62, 12], [11, 61, 13], [3, 51, 5], [7, 65, 9], [11, 60, 13], [6, 54, 8], [0, 60, 2], [7, 56, 9], [10, 61, 12], [2, 44, 4], [0, 59, 2], [2, 45, 4], [7, 57, 9], [0, 58, 2], [11, 59, 13], [11, 58, 13], [6, 53, 8], [3, 48, 5]]]
--------------------------------------------------------------------------------
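A short sketch of how this pair file is read (paths follow the tree above): entry `[0]` holds keypoint-index pairs used for the distance terms, and entry `[1]` holds `(x, y, z)` triplets used for the three-point angle terms, where `y` is the vertex; see `HeatmapMaker.get_morph` and `get_angle` in `misc/heatmap_maker.py`.

```python
import json

with open('morph_pairs/dataset16/dataset16.json', 'r') as f:
    dist_pairs, angle_triplets = json.load(f)

print(dist_pairs[0])      # [2, 4]     -> distance between keypoints 2 and 4
print(angle_triplets[0])  # [2, 63, 4] -> angle at vertex 63, spanned by keypoints 2 and 4
```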
/misc/optimizer.py:
--------------------------------------------------------------------------------
import numpy as np
from torch import optim

def get_optimizer(config, model):
    optimizer = base_optimizer(config, model)
    return optimizer

class base_optimizer(object):
    def __init__(self, config, model):
        super(base_optimizer, self).__init__()

        # optimizer
        if config.optimizer == 'Adam':
            self.optimizer = optim.Adam(model.parameters(), lr=config.lr)
        elif config.optimizer == 'Adadelta':
            self.optimizer = optim.Adadelta(model.parameters(), lr=config.lr)
        else:
            raise ValueError('Unknown optimizer: {}'.format(config.optimizer))  # a bare raise here has no active exception

        # scheduler
        if config.scheduler == 'ReduceLROnPlateau':
            from torch.optim.lr_scheduler import ReduceLROnPlateau
            self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', verbose=True)
        elif config.scheduler == 'StepLR':
            from torch.optim.lr_scheduler import StepLR
            self.scheduler = StepLR(self.optimizer, 100, gamma=0.1, last_epoch=-1)
        else:
            self.scheduler = None

    def update_model(self, loss):
        self.optimizer.zero_grad()

        if np.isnan(loss.item()):
            print('\n\n\nERROR::: THE LOSS IS NAN\n\n\n')
            raise RuntimeError('loss is NaN')  # raise() was not a valid exception
        else:
            loss.backward()
            self.optimizer.step()
        return None

    def scheduler_step(self, metric):
        if self.scheduler is not None:
            self.scheduler.step(metric)
        return None
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import json
import torch
from dataset.dataset import Dataset, collate_fn

def dataloader(config, split: str, data: list):
    # split : 'train' / 'val' / 'test'
    dataset = Dataset(config=config, split=split, data=data)

    # loader
    if split == 'train':
        shuffle = True
        drop_last = True
    else:
        shuffle = False
        drop_last = False

    def _init_fn(worker_id):
        np.random.seed(config.seed + worker_id)

    data_loader = torch.utils.data.DataLoader(dataset, shuffle=shuffle, worker_init_fn=_init_fn,
                                              batch_size=config.Train.batch_size, num_workers=config.MISC.num_workers, collate_fn=collate_fn, drop_last=drop_last)
    return data_loader

def get_split_data(config, split=None):
    if split == 'train':
        with open(os.path.join(config.PATH.DATA.TABLE, 'train.json'), 'r') as f:
            train_data = json.load(f)
        return train_data
    elif split == 'val':
        with open(os.path.join(config.PATH.DATA.TABLE, 'val.json'), 'r') as f:
            val_data = json.load(f)
        return val_data
    elif split == 'test':
        with open(os.path.join(config.PATH.DATA.TABLE, 'test.json'), 'r') as f:
            test_data = json.load(f)
        return test_data

def get_dataloader(config, split):
    data = get_split_data(config, split)
    loader = dataloader(config, split, data)
    return loader
--------------------------------------------------------------------------------
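A minimal usage sketch for `get_dataloader`, assuming `data_preprocessing.py` has already produced the json tables. The hand-built Munch config below is hypothetical and only lists the fields the dataset code accesses; in the real pipeline the config comes from `config/spineweb_ours.yaml` via the SaveManager, which also resolves `Hint.num_dist` from its `'dataset16'` name into a probability vector (here `None` simply means a uniform choice).

```python
from munch import Munch
from dataset import get_dataloader

config = Munch.fromDict({
    'seed': 42,
    'PATH': {'DATA': {'IMAGE': './data/dataset16_512/', 'TABLE': './data/dataset16_512/'}},
    'Dataset': {'image_size': [512, 256], 'num_keypoint': 68},
    'Hint': {'num_dist': None},          # placeholder; the real config resolves this
    'Train': {'batch_size': 4},
    'MISC': {'num_workers': 0},
})

train_loader = get_dataloader(config, 'train')   # reads train.json from PATH.DATA.TABLE
batch = next(iter(train_loader))                 # Munch: input_image, label, hint, pspace, ...
print(batch.input_image.shape)                   # torch.Size([4, 3, 512, 256])
```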
===========") 32 | raise 33 | 34 | 35 | out = Munch.fromDict({'pred':{'sargmax_coord':pred_coord, 'heatmap':pred_heatmap}, 'loss':loss, }) 36 | return out 37 | 38 | def get_heatmap_loss(self, pred_heatmap, label_heatmap, mse_flag=False): 39 | if mse_flag: 40 | heatmap_loss = self.mse_criterion(pred_heatmap, label_heatmap) 41 | else: # BCE loss 42 | pred_heatmap = pred_heatmap.sigmoid() 43 | heatmap_loss = self.bce_loss(pred_heatmap, label_heatmap) 44 | return heatmap_loss, pred_heatmap 45 | 46 | def get_morph_loss(self, pred_coord, label_coord, morph_loss_mask): 47 | pred_dist, pred_angle = self.heatmap_maker.get_morph(pred_coord) 48 | with torch.no_grad(): 49 | label_dist, label_angle = self.heatmap_maker.get_morph(label_coord) 50 | 51 | if morph_loss_mask.sum()>0: 52 | pred_dist, pred_angle = pred_dist[morph_loss_mask], pred_angle[morph_loss_mask] 53 | label_dist, label_angle = label_dist[morph_loss_mask], label_angle[morph_loss_mask] 54 | 55 | if self.config.Morph.distance_l1: 56 | loss_dist = self.mae_criterion(pred_dist, label_dist) 57 | else: 58 | loss_dist = self.mse_criterion(pred_dist, label_dist) 59 | 60 | if self.config.Morph.cosineSimilarityLoss: 61 | N = pred_angle.shape[0] * pred_angle.shape[1] 62 | label_similairty = torch.ones(N, dtype=torch.long, device=pred_angle.device) 63 | loss_angle = self.angle_criterion(pred_angle.reshape(N, 2), label_angle.reshape(N, 2), label_similairty) 64 | else: 65 | pred_angle_normalized = torch.nn.functional.normalize(pred_angle, dim=-1) # (batch, 13, 2) 66 | label_angle_normalized = torch.nn.functional.normalize(label_angle, dim=-1) 67 | loss_angle = self.angle_criterion(pred_angle_normalized, label_angle_normalized) 68 | 69 | return loss_dist * self.config.Morph.distance_lambda + loss_angle * self.config.Morph.angle_lambda 70 | return 0 71 | 72 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from pytz import timezone 2 | import time 3 | import argparse 4 | import datetime 5 | 6 | import numpy as np 7 | import random 8 | import torch 9 | import torch.nn as nn 10 | 11 | from util import SaveManager, TensorBoardManager 12 | from model import get_model 13 | from dataset import get_dataloader 14 | from misc.metric import MetricManager 15 | from misc.optimizer import get_optimizer 16 | from misc.train import Trainer 17 | 18 | torch.set_num_threads(4) 19 | 20 | def set_seed(config): 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | torch.manual_seed(config.seed) 24 | torch.cuda.manual_seed(config.seed) 25 | torch.cuda.manual_seed_all(config.seed) 26 | np.random.seed(config.seed) 27 | random.seed(config.seed) 28 | return 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser(description='TMI experiments') 32 | parser.add_argument('--config', type=str, help='config name, required') 33 | parser.add_argument('--seed', type=int, default=42, help='random seed') 34 | parser.add_argument('--only_test_version', type=str, default=None, help='If activated, there is no training. The number is the experiment number. 
/main.py:
--------------------------------------------------------------------------------
from pytz import timezone
import time
import argparse
import datetime

import numpy as np
import random
import torch
import torch.nn as nn

from util import SaveManager, TensorBoardManager
from model import get_model
from dataset import get_dataloader
from misc.metric import MetricManager
from misc.optimizer import get_optimizer
from misc.train import Trainer

torch.set_num_threads(4)

def set_seed(config):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)
    return

def parse_args():
    parser = argparse.ArgumentParser(description='TMI experiments')
    parser.add_argument('--config', type=str, help='config name, required')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--only_test_version', type=str, default=None, help='If given, skip training and load & test the model saved under this experiment version name')
    parser.add_argument('--save_test_prediction', action='store_true', default=False, help='If activated, save test predictions at save path')
    arg = parser.parse_args()
    set_seed(arg)
    return arg


def main(save_manager):
    if save_manager.config.MISC.TB and save_manager.config.only_test_version is None:
        writer = TensorBoardManager(save_manager)
    else:
        writer = None

    # model initialization
    device_ids = list(range(len(save_manager.config.MISC.gpu.split(','))))
    model = nn.DataParallel(get_model(save_manager), device_ids=device_ids)
    model.to(save_manager.config.MISC.device)

    # calculate the number of model parameters
    n_params = 0
    for k, v in model.named_parameters():
        n_params += v.reshape(-1).shape[0]
    save_manager.write_log('Number of model parameters : {}'.format(n_params), 0)

    # optimizer initialization
    optimizer = get_optimizer(save_manager.config.Optimizer, model)

    metric_manager = MetricManager(save_manager)
    trainer = Trainer(model, metric_manager)

    if not save_manager.config.only_test_version:
        # dataloader
        train_loader = get_dataloader(save_manager.config, 'train')
        val_loader = get_dataloader(save_manager.config, 'val')

        # training
        save_manager.write_log('Start Training...', 4)
        trainer.train(save_manager=save_manager,
                      train_loader=train_loader,
                      val_loader=val_loader,
                      optimizer=optimizer,
                      writer=writer)
        # deallocate data loaders from the memory
        del train_loader
        del val_loader

    trainer.best_param, trainer.best_epoch, trainer.best_metric = save_manager.load_model()

    save_manager.write_log('Start Test Evaluation...', 4)
    test_loader = get_dataloader(save_manager.config, 'test')
    trainer.test(save_manager=save_manager, test_loader=test_loader, writer=writer)
    del test_loader


if __name__ == '__main__':
    start_time = time.time()

    arg = parse_args()
    save_manager = SaveManager(arg)
    save_manager.write_log('Process Start ::: {}'.format(datetime.datetime.now()), n_mark=16)

    main(save_manager)

    end_time = time.time()
    save_manager.write_log('Process End ::: {} {:.2f} hours'.format(datetime.datetime.now(), (end_time - start_time) / 3600), n_mark=16)
    save_manager.write_log('Version ::: {}'.format(save_manager.config.version), n_mark=16)
--------------------------------------------------------------------------------
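The parameter count computed in `main()` with `v.reshape(-1).shape[0]` is equivalent to the more idiomatic `numel()`; a tiny self-contained check on a toy module:

```python
import torch.nn as nn

model = nn.Linear(10, 5)
n_params = sum(v.reshape(-1).shape[0] for _, v in model.named_parameters())
assert n_params == sum(p.numel() for p in model.parameters())  # 10*5 weights + 5 biases = 55
```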
/dataset/dataset.py:
--------------------------------------------------------------------------------
import os
import copy
from munch import Munch
import numpy as np

import torch

import dataset.transforms as transforms
# from detectron2.data import transforms as T

class Dataset(torch.utils.data.Dataset):
    def __init__(self, config, split: str, data: list):

        config.DICT_KEY = Munch.fromDict({})
        config.DICT_KEY.IMAGE = 'image'
        config.DICT_KEY.BBOX = 'bbox_{}'.format(config.Dataset.image_size[0])
        config.DICT_KEY.POINTS = 'points_{}'.format(config.Dataset.image_size[0])
        config.DICT_KEY.RAW_SIZE = 'raw_size_row_col'
        config.DICT_KEY.PSPACE = 'pixelSpacing'

        # init
        self.config = config
        self.split = split
        self.data = data

        if self.split == 'train':
            self.transformer = transforms.default_aug(self.config.Dataset.image_size)
        else:
            self.transformer = transforms.fake

        self.loadimage = self.load_npy

        for item in self.data:
            item[self.config.DICT_KEY.IMAGE] = item[self.config.DICT_KEY.IMAGE].replace('.png', '.npy')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # indexing
        item = self.data[index]

        # image load
        img_path = os.path.join(self.config.PATH.DATA.IMAGE, item[self.config.DICT_KEY.IMAGE])
        img, row, column = self.loadimage(img_path, item[self.config.DICT_KEY.RAW_SIZE])

        # pixel spacing
        pspace_list = item[self.config.DICT_KEY.PSPACE]  # row, column
        raw_size_and_pspace = torch.tensor([row, column] + pspace_list)

        # points load (num_keypoint, 2), (column, row) == (x, y)
        coords = copy.deepcopy(item[self.config.DICT_KEY.POINTS])
        coords.append([1.0, 1.0])  # sentinel point; if augmentation moves it, morph_loss_mask becomes False below

        transformed = self.transformer(image=img, keypoints=coords)
        img, coords = transformed["image"], transformed["keypoints"]
        additional = torch.tensor([])

        coords = np.array(coords)

        # np array to tensor, (row, col) layout
        img = torch.tensor(img, dtype=torch.float)
        img = img.permute(2, 0, 1)
        img /= 255.0       # 0~255 to 0~1
        img = img * 2 - 1  # 0~1 to -1~1

        coords = torch.tensor(copy.deepcopy(coords[:, ::-1]), dtype=torch.float)
        morph_loss_mask = (coords[-1] == torch.tensor([1.0, 1.0], dtype=torch.float)).all()
        coords = coords[:-1]

        # hint
        if self.split == 'train':
            # random hint
            num_hint = np.random.choice(range(self.config.Dataset.num_keypoint), size=None, p=self.config.Hint.num_dist)
            hint_indices = np.random.choice(range(self.config.Dataset.num_keypoint), size=num_hint, replace=False)  # e.g. [1, 2, 3]
        else:
            hint_indices = None

        return img_path, img, raw_size_and_pspace, hint_indices, coords, additional, index, morph_loss_mask

    def load_npy(self, img_path, size=None):
        img = np.load(img_path)
        if size is not None:
            row, column = size
        else:
            row, column = img.shape[:2]
        return img, row, column

def collate_fn(batch):
    batch = list(zip(*batch))
    batch_dict = {
        'input_image_path': batch[0],  # list
        'input_image': torch.stack(batch[1]),
        'label': {'morph_offset': torch.stack(batch[2]),
                  'coord': torch.stack(batch[4]),
                  'morph_loss_mask': torch.stack(batch[7])
                  },
        'pspace': torch.stack(batch[2]),
        'hint': {'index': list(batch[3])},
        'additional': torch.stack(batch[5]),
        'index': list(batch[6])
    }
    return Munch.fromDict(batch_dict)
--------------------------------------------------------------------------------
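A toy demonstration of `collate_fn`: it zips the per-sample tuples returned by `Dataset.__getitem__` into a Munch, so downstream code can use attribute access such as `batch.label.coord` or `batch.hint.index`. The fake samples below are hypothetical stand-ins with the same tuple layout.

```python
import torch
from munch import Munch
from dataset.dataset import collate_fn

def fake_sample(i):
    img_path = 'img_{}.npy'.format(i)
    img = torch.zeros(3, 512, 256)
    pspace = torch.tensor([800., 500., 1., 1.])  # raw rows, raw cols, row/col pixel spacing
    hint_indices = None
    coords = torch.zeros(68, 2)
    additional = torch.tensor([])
    morph_loss_mask = torch.tensor(True)
    return img_path, img, pspace, hint_indices, coords, additional, i, morph_loss_mask

batch = collate_fn([fake_sample(0), fake_sample(1)])
print(batch.input_image.shape, batch.label.coord.shape)  # torch.Size([2, 3, 512, 256]) torch.Size([2, 68, 2])
```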
/data_preprocessing.py:
--------------------------------------------------------------------------------
import os
import json
import numpy as np
from PIL import Image
import scipy.io
from tqdm.auto import tqdm
import copy

from albumentations import (Resize, Compose, KeypointParams)

source_path = './data/dataset16/boostnet_labeldata/'
target_path = './data/'

base_image_path = os.path.join(source_path, 'data')
base_label_path = os.path.join(source_path, 'labels')


train_image_paths = []
test_image_paths = []

for name in os.listdir(os.path.join(base_image_path, 'training')):
    p = os.path.join(base_image_path, 'training/', name)
    train_image_paths.append(p)

for name in os.listdir(os.path.join(base_image_path, 'test')):
    p = os.path.join(base_image_path, 'test/', name)
    test_image_paths.append(p)


train_label_paths = [(i + '.mat').replace('boostnet_labeldata/data/', 'boostnet_labeldata/labels/') for i in train_image_paths]
test_label_paths = [(i + '.mat').replace('boostnet_labeldata/data/', 'boostnet_labeldata/labels/') for i in test_image_paths]


train_image_paths.sort()
test_image_paths.sort()
train_label_paths.sort()
test_label_paths.sort()


# select random validation dataset (val size = 128)
val_idx = sorted(np.random.choice(range(len(train_image_paths)), size=128, replace=False, p=None))
train_idx = [i for i in range(len(train_image_paths)) if i not in val_idx]

val_image_paths = np.array(train_image_paths)[val_idx].tolist()
train_image_paths = np.array(train_image_paths)[train_idx].tolist()

val_label_paths = np.array(train_label_paths)[val_idx].tolist()
train_label_paths = np.array(train_label_paths)[train_idx].tolist()


# make json items
def make_data(image_paths, label_paths):
    data = []

    for idx in range(len(image_paths)):

        item = {'image': None, 'label': None, 'raw_size_row_col': None, 'pixelSpacing': [1, 1]}

        # indexing
        image_path = image_paths[idx]
        label_path = label_paths[idx]

        # load data
        img = np.repeat(np.array(Image.open(os.path.join(image_path)))[:, :, None], 3, axis=-1)  # (row, col) -> (row, col, 3)
        label = scipy.io.loadmat(os.path.join(label_path))['p2']  # x, y

        # make items
        item['image'] = image_path.replace(source_path, '')  # remove base path
        item['label'] = label.tolist()
        item['raw_size_row_col'] = (img.shape[0], img.shape[1])

        data.append(item)
    return data

train_data = make_data(train_image_paths, train_label_paths)
val_data = make_data(val_image_paths, val_label_paths)
test_data = make_data(test_image_paths, test_label_paths)


def inference_aug(img_size):
    return Compose([
        Resize(img_size[0], img_size[1]),
    ], keypoint_params=KeypointParams(format='xy'))


# check keypoints are inside of the corresponding image
print('train')
remove_list = []
for t, item in enumerate(train_data):
    row, col = item['raw_size_row_col']
    coord = np.array(item['label'])
    if coord[:, 1].max() > row:
        print('error:', t, '- exceed max y')
        remove_list.append(t)
    elif coord[:, 0].max() > col:
        print('error:', t, '- exceed max x')
        remove_list.append(t)
# delete from the back so earlier indices stay valid
for t in sorted(remove_list, reverse=True):
    del train_data[t]

print('val')
remove_list = []
for t, item in enumerate(val_data):
    row, col = item['raw_size_row_col']
    coord = np.array(item['label'])
    if coord[:, 1].max() > row:
        print('error:', t, '- exceed max y')
        remove_list.append(t)
    elif coord[:, 0].max() > col:
        print('error:', t, '- exceed max x')
        remove_list.append(t)
for t in sorted(remove_list, reverse=True):
    del val_data[t]

map_dic = {}
for sizes in [(512, 256)]:
    aug = inference_aug((sizes[0], sizes[1]))
    print(sizes, 'starts')
    size = sizes[0]

    if not os.path.exists('{}/dataset16_{}'.format(target_path, size)):
        os.makedirs('{}/dataset16_{}'.format(target_path, size))

    for temp in ['train', 'val', 'test']:
        print(temp, 'starts')

        if temp == 'train':
            table = copy.deepcopy(train_data)
        elif temp == 'val':
            table = copy.deepcopy(val_data)
        else:
            table = copy.deepcopy(test_data)

        for t, item in tqdm(enumerate(table)):
            # image load
            img_path = os.path.join(
                source_path,
                item['image'])
            img = np.repeat(np.array(Image.open(img_path))[:, :, None], 3, axis=-1)  # (row, col) -> (row, col, 3)

            points = item['label']

            if img_path not in map_dic:
                map_dic[img_path] = {'points': points}

            transformed = aug(image=img, keypoints=points)
            img, points = transformed["image"], transformed["keypoints"]

            img_save_path = os.path.join(target_path, 'dataset16_{}'.format(size),
                                         item['image'].replace('.jpg', '.npy'))
            if not os.path.exists(os.path.dirname(img_save_path)):
                os.makedirs(os.path.dirname(img_save_path))
            np.save(img_save_path, img)
            item['points_{}'.format(size)] = points
            item['image'] = item['image'].replace('.jpg', '.npy')

            map_dic[img_path].update({'points_{}'.format(size): points})

        with open(os.path.join(target_path, 'dataset16_{}'.format(size), '{}.json'.format(temp)), 'w') as f:
            json.dump(table, f)
--------------------------------------------------------------------------------
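After the script runs, each entry of the generated `train.json` / `val.json` / `test.json` has the item layout assembled above. A quick sanity check, assuming preprocessing has completed (the keypoint count of 68 matches `num_keypoint` in the config):

```python
import json

with open('./data/dataset16_512/train.json', 'r') as f:
    table = json.load(f)

item = table[0]
print(item['image'])             # relative path of the saved .npy image
print(item['raw_size_row_col'])  # original (rows, cols) before resizing
print(len(item['points_512']))  # 68 keypoints in the 512x256 frame
```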
/misc/metric.py:
--------------------------------------------------------------------------------
from munch import Munch
from misc.heatmap_maker import HeatmapMaker, heatmap2hargmax_coord
import numpy as np
import torch
import torch.nn as nn

class MetricManager():
    def __init__(self, save_manager):
        self.config = save_manager.config

        if save_manager.config.Train.decision_metric.split('_')[-1] in ['MAE', 'MSE', 'MRE']:
            self.minimize_metric = True
        elif save_manager.config.Train.decision_metric.split('_')[-1] in ['AUC', 'ACC']:
            self.minimize_metric = False
        else:
            save_manager.write_log('ERROR::: Specify the decision metric in metric.py')
            raise ValueError('unsupported decision metric')  # raise() was not a valid exception
        self.heatmap_maker = HeatmapMaker(save_manager.config)
        self.init_running_metric()
        self.device = self.config.MISC.device

    def init_running_metric(self):
        self.running_metric = Munch.fromDict({})
        for name in ['hargmax_pixel', 'hargmax_mm', 'sargmax_pixel', 'sargmax_mm']:
            for metric in self.config.Train.metric:
                metric_name = '{}_{}'.format(name, metric)
                self.running_metric[metric_name] = []

                if metric == 'MRE':
                    for standard in self.config.Train.SR_standard:
                        self.running_metric['{}_SR[{}]'.format(name, standard)] = []

    def average_running_metric(self):
        self.metric = Munch.fromDict({})
        for metric in self.running_metric:
            try:
                self.metric[metric] = np.mean(self.running_metric[metric])
            except:
                self.metric[metric] = np.mean([m.mean() for m in self.running_metric[metric]])
        self.init_running_metric()
        return self.metric

    def init_best_metric(self):
        if self.minimize_metric:
            best_metric = {self.config.Train.decision_metric: 1e4}
        else:
            best_metric = {self.config.Train.decision_metric: -1e4}
        return best_metric

    def is_new_best(self, old, new):
        if old[self.config.Train.decision_metric] > new[self.config.Train.decision_metric]:
            if self.minimize_metric:
                return True
            else:
                return False
        else:
            if self.minimize_metric:
                return False
            else:
                return True

    def measure_metric(self, pred, label, pspace, metric_flag, average_flag):
        with torch.no_grad():
            # to cuda
            pspace = pspace.detach().to(self.device)
            label.coord = label.coord.detach().to(self.device)
            label.heatmap = label.heatmap.detach().to(self.device)
            pred.sargmax_coord = pred.sargmax_coord.detach().to(self.device)
            pred.heatmap = pred.heatmap.detach().to(self.device)

            if self.config.Model.facto_heatmap:
                pred.hargmax_coord = pred.hard_coord.float().to(self.device)
            else:
                # pred coord (hard-argmax)
                pred.hargmax_coord = heatmap2hargmax_coord(pred.heatmap)
            pred.hargmax_coord_mm = self.pixel2mm(self.config, pred.hargmax_coord, pspace)

            # pred coord (soft-argmax, model output)
            pred.sargmax_coord_mm = self.pixel2mm(self.config, pred.sargmax_coord, pspace)

            label.coord_mm = self.pixel2mm(self.config, label.coord, pspace)

            coord_pairs = [(pred.hargmax_coord, label.coord, 'hargmax_pixel'),
                           (pred.hargmax_coord_mm, label.coord_mm, 'hargmax_mm'),
                           (pred.sargmax_coord, label.coord, 'sargmax_pixel'),
                           (pred.sargmax_coord_mm, label.coord_mm, 'sargmax_mm')]

            if metric_flag:
                metric_list = self.config.Train.metric
            else:
                metric_list = [self.config.Train.decision_metric.split('_')[-1]]

            for item in coord_pairs:
                pred_coord, label_coord, name = item

                if 'MAE' in metric_list:
                    if average_flag:
                        self.running_metric['{}_MAE'.format(name)].append(nn.L1Loss()(pred_coord, label_coord).item())
                    else:
                        self.running_metric['{}_MAE'.format(name)].append(
                            (nn.L1Loss(reduction='none')(pred_coord, label_coord)).cpu())

                if 'RMSE' in metric_list:
                    self.running_metric['{}_RMSE'.format(name)].append(torch.sqrt(nn.MSELoss()(pred_coord, label_coord)).item())

                if 'MRE' in metric_list:
                    # (batch, num_keypoint, 2)
                    y_diff_sq = (pred_coord[:, :, 0] - label_coord[:, :, 0]) ** 2
                    x_diff_sq = (pred_coord[:, :, 1] - label_coord[:, :, 1]) ** 2
                    sqrt_x2y2 = torch.sqrt(y_diff_sq + x_diff_sq)  # (batch, num_keypoint)
                    if average_flag:
                        mre = torch.mean(sqrt_x2y2).item()
                    else:
                        mre = sqrt_x2y2.cpu()
                    self.running_metric['{}_MRE'.format(name)].append(mre)

                    for standard in self.config.Train.SR_standard:
                        self.running_metric['{}_SR[{}]'.format(name, standard)].append((sqrt_x2y2 < standard).float().mean().item())

    def pixel2mm(self, config, points, pspace):
        mm_points = torch.zeros_like(points)
        pspace = pspace.to(points.device)
        mm_points[:, :, 1] = points[:, :, 1] / config.Dataset.image_size[1] * pspace[:, 1:2] * pspace[:, 3:4]
        mm_points[:, :, 0] = points[:, :, 0] / config.Dataset.image_size[0] * pspace[:, 0:1] * pspace[:, 2:3]
        return mm_points
--------------------------------------------------------------------------------
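The MRE branch above in isolation: a per-keypoint Euclidean (radial) distance between predicted and label coordinates, averaged when `average_flag` is set.

```python
import torch

pred_coord = torch.tensor([[[10., 10.], [20., 20.]]])   # (batch=1, K=2, 2) as (row, col)
label_coord = torch.tensor([[[13., 14.], [20., 20.]]])

radial_error = torch.sqrt(((pred_coord - label_coord) ** 2).sum(-1))  # (1, 2)
print(radial_error)          # tensor([[5., 0.]])  (3-4-5 triangle for the first keypoint)
print(radial_error.mean())   # MRE = 2.5
```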
/model/iterativeRefinementModels/RITM_modules/RITM_ocr.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F


class SpatialGather_Module(nn.Module):
    """
    Aggregate the context features according to the initial
    predicted probability distribution.
    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, cls_num=0, scale=1):
        super(SpatialGather_Module, self).__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        # feats: batch x c x H x W
        # probs: batch x num_keypoint x H x W
        batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
        probs = probs.view(batch_size, c, -1)  # (b, num_keypoint, hw)
        feats = feats.view(batch_size, feats.size(1), -1)
        feats = feats.permute(0, 2, 1)  # batch x hw x c
        probs = F.softmax(self.scale * probs, dim=2)  # batch x num_keypoint x hw
        ocr_context = torch.matmul(probs, feats) \
            .permute(0, 2, 1).unsqueeze(3)  # batch x c x num_keypoint x 1
        return ocr_context


class SpatialOCR_Module(nn.Module):
    """
    Implementation of the OCR module:
    We aggregate the global object representation to update the representation for each pixel.
    """

    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 scale=1,
                 dropout=0.1,
                 norm_layer=nn.BatchNorm2d,
                 align_corners=True):
        super(SpatialOCR_Module, self).__init__()
        self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
                                                           norm_layer, align_corners)
        _in_channels = 2 * in_channels

        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
            nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
            nn.Dropout2d(dropout)
        )

    def forward(self, feats, proxy_feats):
        # proxy_feats = context : batch x c x num_keypoint x 1
        # feats: (b, feature_dim, H/4, W/4)
        context = self.object_context_block(feats, proxy_feats)

        output = self.conv_bn_dropout(torch.cat([context, feats], 1))

        return output


class ObjectAttentionBlock2D(nn.Module):
    '''
    The basic implementation for object context block
    Input:
        N X C X H X W
    Parameters:
        in_channels : the dimension of the input feature map
        key_channels : the dimension after the key/query transform
        scale : choose the scale to downsample the input feature maps (save memory cost)
        bn_type : specify the bn type
    Return:
        N X C X H X W
    '''

    def __init__(self,
                 in_channels,
                 key_channels,
                 scale=1,
                 norm_layer=nn.BatchNorm2d,
                 align_corners=True):
        super(ObjectAttentionBlock2D, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.align_corners = align_corners

        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        self.f_pixel = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
        )
        self.f_object = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
        )
        self.f_down = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
        )
        self.f_up = nn.Sequential(
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
        )

    def forward(self, x, proxy):
        # proxy: batch, c, num_keypoint, 1
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            x = self.pool(x)

        query = self.f_pixel(x).view(batch_size, self.key_channels, -1)  # (b, c, hw)
        query = query.permute(0, 2, 1)  # (b, hw, c)

        key = self.f_object(proxy).view(batch_size, self.key_channels, -1)   # (b, c, num_key)
        value = self.f_down(proxy).view(batch_size, self.key_channels, -1)   # (b, c, num_key)
        value = value.permute(0, 2, 1)  # (b, num_key, c)

        sim_map = torch.matmul(query, key)  # (b, hw, num_key)
        sim_map = (self.key_channels ** -.5) * sim_map  # scaling
        sim_map = F.softmax(sim_map, dim=-1)  # softmax over the keypoints

        context = torch.matmul(sim_map, value)  # (b, hw, c)
        context = context.permute(0, 2, 1).contiguous()  # (b, c, hw)
        context = context.view(batch_size, self.key_channels, *x.size()[2:])  # (b, c, h, w)
        context = self.f_up(context)  # (b, c, h, w)
        if self.scale > 1:
            context = F.interpolate(input=context, size=(h, w),
                                    mode='bilinear', align_corners=self.align_corners)

        return context
--------------------------------------------------------------------------------
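A shape walk-through for `SpatialGather_Module`, matching the comments in its `forward` (assuming the repository root is on `PYTHONPATH`): per-keypoint probability maps are softmax-normalized over the spatial dimension and used to pool one context vector per keypoint.

```python
import torch
from model.iterativeRefinementModels.RITM_modules.RITM_ocr import SpatialGather_Module

feats = torch.randn(2, 256, 64, 32)   # (batch, c, H, W) backbone features
probs = torch.randn(2, 68, 64, 32)    # (batch, num_keypoint, H, W) keypoint logits
context = SpatialGather_Module()(feats, probs)
print(context.shape)                  # torch.Size([2, 256, 68, 1])
```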
/README.md:
--------------------------------------------------------------------------------
# Morphology-Aware Interactive Keypoint Estimation (MICCAI 2022) - Official PyTorch Implementation

[__[Paper]__](https://arxiv.org/abs/2209.07163) &nbsp; [__[Project page]__](https://seharanul17.github.io/interactive_keypoint_estimation/) &nbsp; [__[Video]__](https://youtu.be/Z5gtLviQ_TU)

## Introduction
This is the official PyTorch implementation of [Morphology-Aware Interactive Keypoint Estimation](https://arxiv.org/abs/2209.07163).

Diagnosis based on medical images, such as X-ray images, often involves manual annotation of anatomical keypoints. However, this process involves significant human effort and can thus be a bottleneck in the diagnostic process. To fully automate this procedure, deep-learning-based methods have been widely proposed and have achieved high performance in detecting keypoints in medical images. However, these methods still have clinical limitations: accuracy cannot be guaranteed for all cases, and it is necessary for doctors to double-check all predictions of models. In response, we propose a novel deep neural network that, given an X-ray image, automatically detects and refines the anatomical keypoints through a user-interactive system in which doctors can fix mispredicted keypoints with fewer clicks than needed during manual revision. Using our own collected data and the publicly available AASCE dataset, we demonstrate the effectiveness of the proposed method in reducing the annotation costs via extensive quantitative and qualitative results.

## Environment
The code was developed using Python 3.8 on Ubuntu 18.04.

The experiments were performed on a single GeForce RTX 3090 in the training and evaluation phases.

## Quick start

### Prerequisites
Install the following dependencies (a one-line pip command is sketched below):
- Python 3.8
- torch == 1.8.0
- albumentations == 1.1.0
- munch
- tensorboard
- pytz
- tqdm
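For example (choose the torch 1.8.0 build that matches your CUDA version; `data_preprocessing.py` additionally imports `numpy`, `scipy`, and `Pillow`):
```
pip install torch==1.8.0 albumentations==1.1.0 munch tensorboard pytz tqdm scipy Pillow
```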

### Preparing code and model files


#### Directory layout
```
.
├── code
│   ├── data
│   ├── pretrained_models
│   ├── data_preprocessing.py
│   ├── train.sh
│   ├── test.sh
│   └── ...
├── save
└── ...
```

1. Clone this repository in the ``code`` folder:
```
git clone https://github.com/seharanul17/interactive_keypoint_estimation code
```

2. Create ``code/pretrained_models`` and ``save`` folders.
```
mkdir code/pretrained_models
mkdir save
```

3. To train our model using the pretrained HRNet backbone model, download the model file from the [HRNet Github repository](https://github.com/HRNet/HRNet-Image-Classification).
Place the downloaded file in the ``pretrained_models`` folder. The related code line can be found [here](https://github.com/seharanul17/interactive_keypoint_estimation/blob/7f50ec271b9ae9613c839533d3958110405d04f5/model/iterativeRefinementModels/RITM_SE_HRNet32.py#L29).


4. To test our pre-trained model, download our model file and config file from [here](https://www.dropbox.com/sh/m53iqw9loddqhfq/AAD0KuCCxpXsBE435Hw3KJU8a?dl=0).
Place the downloaded folder containing the files into the ``save`` folder.
The related code line can be found [here](https://github.com/seharanul17/interactive_keypoint_estimation/blob/7f50ec271b9ae9613c839533d3958110405d04f5/util.py#L77).



### Preparing data

We provide the code to conduct experiments on a public dataset, the AASCE challenge dataset.

1. Create the ``data`` folder inside the ``code`` folder.
```
cd code
mkdir data
```

2. Download the data and place it inside the ``data`` folder.
   - The AASCE challenge dataset can be obtained from [SpineWeb](http://spineweb.digitalimaginggroup.ca/index.php?n=main.datasets#Dataset_16.3A_609_spinal_anterior-posterior_x-ray_images).
   - The AASCE challenge dataset corresponds to "Dataset 16: 609 spinal anterior-posterior x-ray images" on the webpage.

3. Preprocess the downloaded data.
   - The related code line is [here](https://github.com/seharanul17/interactive_keypoint_estimation/blob/b85c22e26dd315289219cbe1baecdc815ba1d097/data_preprocessing.py#L11).
   - Run the following command:
```
python data_preprocessing.py
```


### Usage
- To run the training code, run the following command:
```
bash train.sh
```
- To test the pre-trained model:
   1. Locate the pre-trained model in the ``../save/`` folder.
   2. Run the test code:
```
bash test.sh
```
- To test your own model:
   1. Change the value of the argument ``--only_test_version {your_model_name}`` in the ``test.sh`` file.
   2. Run the test code:
```
bash test.sh
```

When the evaluation ends, the mean radial error (MRE) of the model prediction and of manual revision will be reported.
The ``sargmax_mm_MRE`` corresponds to the MRE reported in Fig. 4.

## Results
The following table compares the refinement performance of our proposed interactive model and manual revision.
Both models revise the same initial prediction results of our model. The number of user modifications ranges from zero (initial prediction) to five.
The model performance is measured using mean radial errors on the AASCE dataset.
For more information, please see Fig. 4 in our main manuscript.

- "Ours (model revision)" indicates automatically revised results by the proposed interactive keypoint estimation approach.
- "Ours (manual revision)" indicates fully-manually revised results by a user without the assistance of an interactive model.

| Method | No. of user modifications | | | | | |
|:----------------:|:-----------------------:|:-------:|:-------:|:-------:|:-------:|:-------:|
| | 0 (initial prediction) | 1 | 2 | 3 | 4 | 5 |
| Ours (model revision) | 58.58 | 35.39 | 29.35 | 24.02 | 21.06 | 17.67 |
| Ours (manual revision) | 58.58 | 55.85 | 53.33 | 50.90 | 48.55 | 47.03 |


## Citation

If you find this work or code is helpful in your research, please cite:
```
@inproceedings{kim2022morphology,
  title={Morphology-Aware Interactive Keypoint Estimation},
  author={Kim, Jinhee and
          Kim, Taesung and
          Kim, Taewoo and
          Choo, Jaegul and
          Kim, Dong-Wook and
          Ahn, Byungduk and
          Song, In-Seok and
          Kim, Yoon-Ji},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={675--685},
  year={2022},
  organization={Springer}
}
```
--------------------------------------------------------------------------------
/model/iterativeRefinementModels/RITM_SE_HRNet32.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import numpy as np

from misc.heatmap_maker import HeatmapMaker, heatmap2hargmax_coord
from misc.loss import LossManager

from model.iterativeRefinementModels.RITM_modules.RITM_hrnet import HighResolutionNet

class RITM(nn.Module):
    def __init__(self, config):
        super(RITM, self).__init__()
        # default
        self.config = config
        self.device = config.MISC.device

        self.heatmap_maker = HeatmapMaker(config)
        self.LossManager = LossManager(config, self.heatmap_maker)

        # RITM hyper-param
        self.max_iter = 3

        # backbone
        self.ritm_hrnet = HighResolutionNet(width=32, ocr_width=128, small=False, num_classes=config.Dataset.num_keypoint,
                                            norm_layer=nn.BatchNorm2d, addHintEncSENet=True, SE_maxpool=config.Model.SE_maxpool, SE_softmax=config.Model.SE_softmax)
        self.ritm_hrnet.load_pretrained_weights('./pretrained_models/hrnetv2_w32_imagenet_pretrained.pth')

        if self.config.Model.no_iterative_training:
            hintmap_input_channels = config.Dataset.num_keypoint
        else:
            hintmap_input_channels = config.Dataset.num_keypoint * 2

        self.hintmap_encoder = nn.Sequential(
            nn.Conv2d(in_channels=hintmap_input_channels, out_channels=16, kernel_size=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1),
            ScaleLayer(init_value=0.05, lr_mult=1)
        )

    def forward_model(self, input_image, input_hint_heatmap, prev_pred_heatmap):
        if prev_pred_heatmap is None:
            hint_features = input_hint_heatmap
        else:
            hint_features = torch.cat((prev_pred_heatmap, input_hint_heatmap), dim=1)

        encoded_input_hint_heatmap = self.hintmap_encoder(hint_features)
        pred_logit, aux_pred_logit = self.ritm_hrnet(input_image, encoded_input_hint_heatmap, input_hint_heatmap=input_hint_heatmap)

        pred_logit = F.interpolate(pred_logit, size=self.config.Dataset.image_size, mode='bilinear', align_corners=True)
        aux_pred_logit = F.interpolate(aux_pred_logit, size=self.config.Dataset.image_size, mode='bilinear', align_corners=True)

        return pred_logit, aux_pred_logit

    def forward(self, batch):
        # make label, hint heatmap
        with torch.no_grad():
            batch.label.coord = batch.label.coord.to(self.device)
            batch.label.heatmap = self.heatmap_maker.coord2heatmap(batch.label.coord)
            batch.hint.heatmap = torch.zeros_like(batch.label.heatmap)
            for i in range(batch.label.heatmap.shape[0]):
                if batch.hint.index[i] is not None:
                    batch.hint.heatmap[i, batch.hint.index[i]] = batch.label.heatmap[i, batch.hint.index[i]]

        input_image = batch.input_image.to(self.device)
        input_hint_heatmap = batch.hint.heatmap.to(self.device)

        if batch.is_training:
            # 1. random number of iterations without update
            if self.config.Model.no_iterative_training:
                pred_heatmap = None
            else:
                with torch.no_grad():
                    self.eval()
                    num_iters = np.random.randint(0, self.max_iter)
                    pred_heatmap = torch.zeros_like(input_hint_heatmap)
                    for click_indx in range(num_iters):
                        # prediction
                        pred_logit, aux_pred_logit = self.forward_model(input_image, input_hint_heatmap, pred_heatmap)
                        pred_heatmap = pred_logit.sigmoid()

                        # hint update (during training, hints are added one by one and the prediction is refreshed accordingly)
                        batch = self.get_next_points(batch, pred_heatmap)
                        for i in range(batch.label.heatmap.shape[0]):
                            if batch.hint.index[i] is not None:
                                batch.hint.heatmap[i, batch.hint.index[i]] = batch.label.heatmap[i, batch.hint.index[i]]

                    self.train()
            # 2. forward for model update
            pred_logit, aux_pred_logit = self.forward_model(input_image, input_hint_heatmap, pred_heatmap)
            out = self.LossManager(pred_heatmap=pred_logit, label=batch.label)
            out.loss += self.LossManager.get_heatmap_loss(pred_heatmap=aux_pred_logit, label_heatmap=batch.label.heatmap)[0]
        else:  # test
            # forward
            if self.config.Model.no_iterative_training:
                pred_heatmap = None
            else:
                if batch.get('prev_heatmap', None) is None:
                    pred_heatmap = torch.zeros_like(input_hint_heatmap)
                else:
                    pred_heatmap = batch.prev_heatmap.to(input_hint_heatmap.device)
            pred_logit, aux_pred_logit = self.forward_model(input_image, input_hint_heatmap, pred_heatmap)
            out = self.LossManager(pred_heatmap=pred_logit, label=batch.label)
        return out, batch

    def get_next_points(self, batch, pred_heatmap):
        worst_index = self.find_worst_pred_index(batch, pred_heatmap)
        for i, idx in enumerate(batch.hint.index):
            if idx is None:
                batch.hint.index[i] = worst_index[i]  # (batch, 1)
            else:
                if not torch.is_tensor(idx):
                    batch.hint.index[i] = torch.tensor(batch.hint.index[i], dtype=torch.long, device=worst_index.device)
                batch.hint.index[i] = torch.cat((batch.hint.index[i], worst_index[i]))  # ... (batch, max hint)
        return batch

    def find_worst_pred_index(self, batch, pred_heatmap):
        # ==== calculate pixel MRE ====
        with torch.no_grad():
            hargmax_coord_pred = heatmap2hargmax_coord(pred_heatmap)
            batch_metric_value = torch.sqrt(torch.sum((hargmax_coord_pred - batch.label.coord) ** 2, dim=-1))  # MRE, (batch, num_keypoint)
            for j, idx in enumerate(batch.hint.index):
                if idx is not None:
                    batch_metric_value[j, idx] = torch.full_like(batch_metric_value[j, idx], -1000)
            worst_index = batch_metric_value.argmax(-1, keepdim=True)
        return worst_index


class ScaleLayer(nn.Module):
    def __init__(self, init_value=1.0, lr_mult=1):
        super().__init__()
        self.lr_mult = lr_mult
        self.scale = nn.Parameter(
            torch.full((1,), init_value / lr_mult, dtype=torch.float32)
        )

    def forward(self, x):
        scale = torch.abs(self.scale * self.lr_mult)
        return x * scale
--------------------------------------------------------------------------------
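How the hint channels are built in `RITM.forward`, shown in isolation: the hint heatmap starts as zeros and, for each sample, the channels listed in `hint.index` are copied from the label heatmap, so a user "click" reveals the ground-truth response for exactly that keypoint. Shapes follow the config; the index values are illustrative.

```python
import torch

batch, K, H, W = 2, 68, 512, 256
label_heatmap = torch.rand(batch, K, H, W)
hint_index = [torch.tensor([3, 17]), None]   # sample 0 has two hints, sample 1 has none

hint_heatmap = torch.zeros_like(label_heatmap)
for i in range(batch):
    if hint_index[i] is not None:
        hint_heatmap[i, hint_index[i]] = label_heatmap[i, hint_index[i]]

print(hint_heatmap[0, 3].max() > 0)   # True: hinted channel carries the label response
print(hint_heatmap[1].sum() == 0)     # True: no hints -> all-zero hint map
```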
/misc/heatmap_maker.py:
--------------------------------------------------------------------------------
import torch

class HeatmapMaker():
    def __init__(self, config):
        self.config = config
        self.image_size = config.Dataset.image_size
        self.heatmap_std = config.Dataset.heatmap_std
        self.morph_pairs = config.Morph.pairs
        # self.morph_unit_vector_size = config.Morph.unit_vector_size

    def make_gaussian_heatmap(self, mean, size, std):
        # mean : (num_keypoint, 2)
        mean = mean.unsqueeze(1).unsqueeze(1)
        var = std ** 2
        grid = torch.stack(torch.meshgrid([torch.arange(size[0]), torch.arange(size[1])]), dim=-1).unsqueeze(0)
        grid = grid.to(mean.device)
        x_minus_mean = grid - mean  # (num_keypoint, H, W, 2)

        # (x-u)^2 / var summed over the two coordinates, then exponentiated -> (num_keypoint, H, W)
        gaus = (-0.5 * (x_minus_mean.pow(2) / var)).sum(-1).exp()
        if self.config.Dataset.get('heatmap_encoding_maxone', False):
            gaus /= gaus.max(1, keepdim=True)[0].max(2, keepdim=True)[0]  # normalize each map to peak 1
        return gaus

    # subpixel paper encoding
    def sample_gaussian_heatmap(self, mean):
        integer = torch.floor(mean)
        rest = mean - integer
        grid_gaussian = self.make_gaussian_heatmap(rest, (2, 2), 0.5)
        sampled_offset = torch.multinomial(grid_gaussian.reshape(mean.shape[0], 4), num_samples=1)
        row = torch.floor(torch.div(sampled_offset, 2))  # (num_keypoint, 1); equals div(..., rounding_mode='floor')
        col = torch.remainder(sampled_offset, 2)         # (num_keypoint, 1)

        new_sampled_offset = torch.cat((row, col), dim=1)
        return new_sampled_offset + integer

    def coord2heatmap(self, coord):
        # coord : (batch, num_keypoint, 2), torch tensor, gpu
        with torch.no_grad():
            if self.config.Dataset.get('subpixel_heatmap_encoding', False):
                coord = torch.stack([self.sample_gaussian_heatmap(coord_item) for coord_item in coord])
            heatmap = torch.stack([
                self.make_gaussian_heatmap(coord_item, size=self.image_size, std=self.heatmap_std) for coord_item in coord])
        return heatmap

    def get_heatmap2sargmax_coord(self, pred_heatmap):
        if self.config.Dataset.subpixel_decoding:
            pred_coord = self.heatmap2subpixel_argmax_coord(pred_heatmap)
        else:
            pred_coord = self.heatmap2sargmax_coord(pred_heatmap)
        return pred_coord

    def heatmap2sargmax_coord(self, heatmap):
        # heatmap: (batch, c, row, column)
        if self.config.Dataset.label_smoothing:
            heatmap = torch.clamp((heatmap - 0.1) / 0.8, 0, 1)

        pred_col = torch.sum(heatmap, (-2))  # batch, c, column
        pred_row = torch.sum(heatmap, (-1))  # batch, c, row

        # (1, 1, W) and (1, 1, H) coordinate meshes
        mesh_c = torch.arange(pred_col.shape[-1]).unsqueeze(0).unsqueeze(0).to(heatmap.device)
        mesh_r = torch.arange(pred_row.shape[-1]).unsqueeze(0).unsqueeze(0).to(heatmap.device)

        # batch, c
        coord_c = torch.sum(pred_col * mesh_c, (-1)) / torch.sum(pred_col, (-1))
        coord_r = torch.sum(pred_row * mesh_r, (-1)) / torch.sum(pred_row, (-1))

        # batch, c, 2 (row, column)
        coord = torch.stack((coord_r, coord_c), -1)

        return coord

    def heatmap2subpixel_argmax_coord(self, heatmap):  # row: y, col: x; tensors are (row, col) = (y, x)
        hard_coord = heatmap2hargmax_coord(heatmap).to(heatmap.device)  # (batch, K, 2)
        patch_size = self.config.Dataset.subpixel_decoding_patch_size  # scalar
        reshaped_heatmap = heatmap.reshape(heatmap.shape[0] * heatmap.shape[1], 1, heatmap.shape[2], heatmap.shape[3])
        reshaped_hard_coord = hard_coord.reshape(hard_coord.shape[0] * hard_coord.shape[1], 1,
                                                 hard_coord.shape[2])  # (bk, 1, 2)

        patch_index = torch.arange(patch_size, device=heatmap.device) - patch_size // 2  # (p,)
        patch_index_y = patch_index[:, None].expand(patch_size, patch_size)[None, :]  # (1, p, p)
        patch_index_x = patch_index[None, :].expand(patch_size, patch_size)[None, :]  # (1, p, p)
        patch_index_y = patch_index_y + reshaped_hard_coord[:, :, 0:1]  # (bk, p, p)
        patch_index_x = patch_index_x + reshaped_hard_coord[:, :, 1:2]

        # pad heatmap
        padded_reshaped_heatmap = torch.nn.functional.pad(reshaped_heatmap,
                                                          (patch_size, patch_size, patch_size, patch_size),
                                                          mode='constant', value=0.0)  # pad (left, right, top, bottom)
        pad_patch_index_y = patch_size + patch_index_y
        pad_patch_index_x = patch_size + patch_index_x

        patch = padded_reshaped_heatmap[:, :, pad_patch_index_y.long(), pad_patch_index_x.long()]

        patch = patch.diagonal(dim1=0, dim2=2).permute(dims=[3, 0, 1, 2])
        patch = patch.reshape(heatmap.shape[0], heatmap.shape[1], patch_size, patch_size)  # (b, K, p, p)
        patch = (patch * 10).reshape(patch.shape[0], patch.shape[1], -1).softmax(-1).reshape(patch.shape)

        soft_patch_offset = self.heatmap2sargmax_coord(patch)  # (batch, K, 2)
        final_coord = hard_coord + soft_patch_offset - patch_size // 2
        return final_coord

    def get_morph(self, points):
        # normalize
        points = self.points_normalize(points, dim0=self.image_size[0], dim1=self.image_size[1])

        dist_pairs = torch.tensor(self.morph_pairs[0])
        vec_pairs = torch.tensor(self.morph_pairs[1])

        # dist
        from_vec = points[:, dist_pairs[:, 0]]
        to_vec = points[:, dist_pairs[:, 1]]
        diffs = to_vec - from_vec
        pred_dist = torch.norm(diffs, dim=-1).unsqueeze(-1)  # (batch, num_pairs, 1)

        # unit_vec_pairs
        if self.config.Morph.threePointAngle:
            x = points[:, vec_pairs[:, 0]]
            y = points[:, vec_pairs[:, 1]]
            z = points[:, vec_pairs[:, 2]]
            pred_angle = self.get_angle(x, y, z)
        else:
            from_vec = points[:, vec_pairs[:, 0]]
            to_vec = points[:, vec_pairs[:, 1]]
            diffs = to_vec - from_vec
            pred_angle = diffs

        return pred_dist, pred_angle

    def get_angle(self, x, y, z):
        # y is the vertex of the angle: z-y-x
        N = x.shape[0] * x.shape[1]
        delta_vector_1 = torch.reshape(x - y, (N, 2))  # (batch * num_pairs, 2)
        delta_vector_2 = torch.reshape(z - y, (N, 2))

        # (y, x) = (row, col)
        angle_1 = torch.atan2(delta_vector_1[:, 0], delta_vector_1[:, 1] + 1e-8)
        angle_2 = torch.atan2(delta_vector_2[:, 0], delta_vector_2[:, 1] + 1e-8)

        delta_angle = angle_2 - angle_1

        angle_vector = torch.stack((torch.cos(delta_angle), torch.sin(delta_angle)), -1)

        angle_vector = torch.reshape(angle_vector, (x.shape[0], x.shape[1], 2))
        return angle_vector

    def points_normalize(self, points, dim0, dim1):
        new_coord = torch.zeros_like(points)
        new_coord[..., 0] = points[..., 0] / dim0
        new_coord[..., 1] = points[..., 1] / dim1

        return new_coord

def heatmap2hargmax_coord(heatmap):
    b, c, row, column = heatmap.shape
    heatmap = heatmap.reshape(b, c, -1)
    max_indices = heatmap.argmax(-1)
    keypoint = torch.zeros(b, c, 2, device=heatmap.device)
    # keypoint[:, :, 0] = torch.floor(torch.div(max_indices, column))  # old environment
    keypoint[:, :, 0] = torch.div(max_indices, column, rounding_mode='floor')
    keypoint[:, :, 1] = max_indices % column
    return keypoint
--------------------------------------------------------------------------------
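A round-trip sketch for the heatmap codec above: encode coordinates as Gaussians with `coord2heatmap`, then recover them with the hard-argmax decoder; with integer inputs the recovery is exact. The tiny Munch config is hypothetical and lists only the fields `HeatmapMaker` reads (in the real pipeline `Morph.pairs` is the loaded `dataset16.json` content).

```python
import torch
from munch import Munch
from misc.heatmap_maker import HeatmapMaker, heatmap2hargmax_coord

config = Munch.fromDict({
    'Dataset': {'image_size': [512, 256], 'heatmap_std': 7.5},
    'Morph': {'pairs': [[[0, 1]], [[0, 1, 2]]]},   # hypothetical tiny pair file
})
maker = HeatmapMaker(config)

coord = torch.tensor([[[100., 50.], [300., 200.]]])   # (batch, K, 2) as (row, col)
heatmap = maker.coord2heatmap(coord)                  # (1, 2, 512, 256) Gaussian maps
print(heatmap2hargmax_coord(heatmap))                 # tensor([[[100.,  50.], [300., 200.]]])
```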
64 |                 if epoch % 5 == 0:
65 |                     writer.plot_model_param_histogram(self.model, epoch)
66 | 
67 |             save_manager.write_log('Epoch [{}/{}] train loss [{:.6f}] train {} [{:.6f}] val loss [{:.6f}] val {} [{:.6f}] Epoch time [{:.1f}min]'.format(epoch,
68 |                                         save_manager.config.Train.epoch,
69 |                                         train_loss,
70 |                                         save_manager.config.Train.decision_metric,
71 |                                         train_metric[save_manager.config.Train.decision_metric],
72 |                                         val_loss,
73 |                                         save_manager.config.Train.decision_metric,
74 |                                         val_metric[save_manager.config.Train.decision_metric],
75 |                                         (time.time()-start_time)/60))
76 |             print('version: {}'.format(save_manager.config.version))
77 |             if self.metric_manager.is_new_best(old=self.best_metric, new=val_metric):
78 |                 self.patience = 0
79 |                 self.best_epoch = epoch
80 |                 self.best_metric = val_metric
81 |                 save_manager.save_model(self.best_epoch, self.model.state_dict(), self.best_metric)
82 |                 save_manager.write_log('Model saved after Epoch {}'.format(epoch), 4)
83 |             else:
84 |                 self.patience += 1
85 |                 if self.patience > save_manager.config.Train.patience:
86 |                     save_manager.write_log('Training Early Stopped after Epoch {}'.format(epoch), 16)
87 |                     break
88 | 
89 |     def test(self, save_manager, test_loader, writer=None):
90 |         # To reduce peak GPU memory, load the parameters on CPU first and only then move the model to GPU (loading directly onto GPU would briefly hold an extra copy of the weights)
91 |         self.model.cpu()
92 |         self.model.load_state_dict(self.best_param)
93 |         self.model.to(save_manager.config.MISC.device)
94 | 
95 | 
96 |         save_manager.write_log('Best model at epoch {} loaded'.format(self.best_epoch))
97 |         self.model.eval()
98 |         with torch.no_grad():
99 |             post_metric_managers = [MetricManager(save_manager) for _ in range(save_manager.config.Dataset.num_keypoint+1)]
100 |             manual_metric_managers = [MetricManager(save_manager) for _ in range(save_manager.config.Dataset.num_keypoint+1)]
101 | 
102 |             for i, batch in enumerate(tqdm(test_loader)):
103 |                 batch = Munch.fromDict(batch)
104 |                 batch.is_training = False
105 |                 max_hint = 10+1  # evaluate with 0..10 simulated hints
106 |                 for n_hint in range(max_hint):
107 |                     # n_hint = number of hints given so far (0 .. max_hint-1)
108 | 
109 |                     ## model forward
110 |                     out, batch, post_processing_pred = self.forward_batch(batch, metric_flag=True, average_flag=False, metric_manager=post_metric_managers[n_hint], return_post_processing_pred=True)
111 |                     if save_manager.config.Model.use_prev_heatmap:
112 |                         batch.prev_heatmap = out.pred.heatmap.detach()
113 | 
114 |                     ## manual forward
115 |                     if n_hint == 0:
116 |                         manual_pred = copy.deepcopy(out.pred)
117 |                         manual_pred.sargmax_coord = manual_pred.sargmax_coord.detach()  # detach the copies; the manual updates below must not touch the graph
118 |                         manual_pred.heatmap = manual_pred.heatmap.detach()
119 |                         manual_hint_index = []
120 |                     else:
121 |                         # manual prediction update
122 |                         for k in range(manual_hint_index.shape[0]):
123 |                             manual_pred.sargmax_coord[k, manual_hint_index[k]] = batch.label.coord[
124 |                                 k, manual_hint_index[k]].to(manual_pred.sargmax_coord.device)
125 |                             manual_pred.heatmap[k, manual_hint_index[k]] = batch.label.heatmap[
126 |                                 k, manual_hint_index[k]].to(manual_pred.heatmap.device)
127 |                     manual_metric_managers[n_hint].measure_metric(manual_pred, batch.label, batch.pspace, metric_flag=True, average_flag=False)
128 | 
129 |                     # ============================= model hint =================================
130 |                     worst_index = self.find_worst_pred_index(batch.hint.index, post_metric_managers,
131 |                                                              save_manager, n_hint)
132 |                     # hint index update
133 |                     if n_hint == 0:
134 |                         batch.hint.index = worst_index # (batch, 1)
135 |                     else:
136 |                         batch.hint.index = torch.cat((batch.hint.index, worst_index.to(batch.hint.index.device)), dim=1) # ... (batch, max hint)
137 | 
138 |                     # zero out all non-hinted channels of the previous heatmap
139 |                     if save_manager.config.Model.use_prev_heatmap_only_for_hint_index:
140 |                         new_prev_heatmap = torch.zeros_like(out.pred.heatmap)
141 |                         for j in range(len(batch.hint.index)):
142 |                             new_prev_heatmap[j, batch.hint.index[j]] = out.pred.heatmap[j, batch.hint.index[j]]
143 |                         batch.prev_heatmap = new_prev_heatmap
144 |                     # ================================= manual hint =========================
145 |                     worst_index = self.find_worst_pred_index(manual_hint_index, manual_metric_managers,
146 |                                                              save_manager, n_hint)
147 |                     # hint index update
148 |                     if n_hint == 0:
149 |                         manual_hint_index = worst_index # (batch, 1)
150 |                     else:
151 |                         manual_hint_index = torch.cat((manual_hint_index, worst_index.to(manual_hint_index.device)), dim=1) # ... (batch, max hint)
152 | 
153 |                     # save result
154 |                     if save_manager.config.save_test_prediction:
155 |                         save_manager.add_test_prediction_for_save(batch, post_processing_pred, manual_pred, n_hint, post_metric_managers[n_hint], manual_metric_managers[n_hint])
156 | 
157 |             post_metrics = [metric_manager.average_running_metric() for metric_manager in post_metric_managers]
158 |             manual_metrics = [metric_manager.average_running_metric() for metric_manager in manual_metric_managers]
159 | 
160 |             # save metrics
161 |             for t in range(min(max_hint, len(post_metrics))):
162 |                 save_manager.write_log('(model ) Hint {} ::: {}'.format(t, post_metrics[t]))
163 |                 save_manager.write_log('(manual) Hint {} ::: {}'.format(t, manual_metrics[t]))
164 |             save_manager.save_metric(post_metrics, manual_metrics)
165 | 
166 |             if save_manager.config.save_test_prediction:
167 |                 save_manager.save_test_prediction()
168 | 
169 | 
170 |     def forward_batch(self, batch, metric_flag=False, average_flag=True, metric_manager=None, return_post_processing_pred=False):
171 |         out, batch = self.model(batch)
172 |         with torch.no_grad():
173 |             if metric_manager is None:
174 |                 self.metric_manager.measure_metric(out.pred, batch.label, batch.pspace, metric_flag, average_flag)
175 |             else:
176 |                 # post processing
177 |                 post_processing_pred = copy.deepcopy(out.pred)
178 |                 post_processing_pred.sargmax_coord = post_processing_pred.sargmax_coord.detach()
179 |                 post_processing_pred.heatmap = post_processing_pred.heatmap.detach()
180 |                 for i in range(len(batch.hint.index)): # iterating over the batch; i indexes the i-th item
181 |                     if batch.hint.index[i] is not None:
182 |                         post_processing_pred.sargmax_coord[i, batch.hint.index[i]] = batch.label.coord[i, batch.hint.index[i]].detach().to(post_processing_pred.sargmax_coord.device)
183 |                         post_processing_pred.heatmap[i, batch.hint.index[i]] = batch.label.heatmap[i, batch.hint.index[i]].detach().to(post_processing_pred.heatmap.device)
184 | 
185 |                 metric_manager.measure_metric(post_processing_pred, copy.deepcopy(batch.label), batch.pspace, metric_flag=True, average_flag=False)
186 |         if return_post_processing_pred:
187 |             return out, batch, post_processing_pred
188 |         else:
189 |             return out, batch
190 | 
191 |     def find_worst_pred_index(self, previous_hint_index, metric_managers, save_manager, n_hint):
192 |         batch_metric_value = metric_managers[n_hint].running_metric[save_manager.config.Train.decision_metric][-1]
193 |         if len(batch_metric_value.shape) == 3:
194 |             batch_metric_value = batch_metric_value.mean(-1) # (batch, 13)
195 | 
196 |         tmp_metric = batch_metric_value.clone().detach()
197 |         if metric_managers[n_hint].minimize_metric:
198 |             for j, idx in enumerate(previous_hint_index):
199 |                 if idx is not None:
200 |                     tmp_metric[j, idx] = -1000 # exclude already-hinted keypoints
201 |             worst_index = tmp_metric.argmax(-1, keepdim=True)
202 |         else:
203 |             for j, idx in enumerate(previous_hint_index):
204 |                 if idx is not None:
205 |                     tmp_metric[j, idx] = 1000 # exclude already-hinted keypoints
206 |             worst_index = tmp_metric.argmin(-1, keepdim=True)
207 |         return worst_index
208 | 
209 | 
210 | 
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import json
3 | import pickle
4 | from munch import Munch
5 | import os
6 | import torch
7 | from misc.metric import heatmap2hargmax_coord
8 | from torch.utils.tensorboard import SummaryWriter
9 | import torchvision
10 | 
11 | class SaveManager(object):
12 |     def __init__(self, arg):
13 |         print('\n\n\n\n')
14 | 
15 |         self.read_config(arg)
16 |         self.unspecified_configs_to_default()
17 |         self.save_config()
18 |         self.config.MISC.device = torch.device('cuda')
19 |         self.config_name2value()
20 | 
21 |     def config_name2value(self):
22 |         with open('./morph_pairs/{}/{}.json'.format(self.config.Dataset.NAME, self.config.Morph.pairs), 'r') as f:
23 |             self.config.Morph.pairs = json.load(f)
24 | 
25 |         if self.config.Hint.num_dist == 'datset16':  # 'datset16' (sic) matches the value in config/spineweb_ours.yaml
26 |             self.config.Hint.num_dist = [1 / 8, 1 / 2, 1 / 4, 1 / 16, 1 / 32, 1 / 64, 1 / 128, 1 / 256, 1 / 512, 1 / 1024, 1 / 2048, 1 / 4096, 1 / 4096]+[0 for _ in range(68-13)]
27 | 
28 |     def read_config(self, arg):
29 |         # load config
30 |         if arg.only_test_version:
31 |             config_path = '../save/{}/config.yaml'.format(arg.only_test_version)
32 |         else:
33 |             config_path = './config/{}.yaml'.format(arg.config)
34 |         with open(config_path) as f:
35 |             config = yaml.safe_load(f)
36 | 
37 |         self.config = Munch.fromDict(config)
38 | 
39 |         if not arg.only_test_version:
40 |             os.makedirs('../save/', exist_ok=True)
41 |             exp_num = self.get_new_exp_num()
42 |             self.config.version = "ExpNum[{}]_Dataset[{}]_Model[{}]_config[{}]_seed[{}]".format(exp_num,
43 |                                         self.config.Dataset.NAME,
44 |                                         self.config.Model.NAME,
45 |                                         arg.config, arg.seed)
46 |             os.makedirs('../save/{}'.format(self.config.version))
47 |         else:
48 |             self.config.version = arg.only_test_version
49 | 
50 |         # set path and write version
51 |         self.set_config_path()
52 |         self.write_log('Version : {}'.format(self.config.version), n_mark=16)
53 | 
54 |         # update arg to config
55 |         items = arg.__dict__.items() if 'namespace' in str(arg.__class__).lower() else arg.items()
56 |         for key, value in items:
57 |             self.config[key] = value
58 | 
59 | 
60 |         if arg.save_test_prediction:
61 |             self.predictions_for_save = Munch({})
62 | 
63 |         self.config.Model.use_prev_heatmap_only_for_hint_index = True
64 |         self.write_log('Use previous heatmap only for hint index, else zero', n_mark=8)
65 | 
66 |     def write_log(self, text, n_mark=0, save_flag=True):
67 |         log = '{} {} {}'.format('='*n_mark, text, '='*n_mark)
68 |         print(log)
69 |         if save_flag:
70 |             with open(self.config.PATH.LOG_PATH, 'a+') as f:
71 |                 f.write('{}\n'.format(log))
72 |         return
73 | 
74 |     def set_config_path(self):
75 |         self.config.PATH.LOG_PATH = '../save/{}/log.txt'.format(self.config.version)
76 |         self.config.PATH.CONFIG_PATH = '../save/{}/config.yaml'.format(self.config.version)
77 |         self.config.PATH.MODEL_PATH = '../save/{}/model.pth'.format(self.config.version)
78 |         self.config.PATH.post_RESULT_PATH = '../save/{}/post_result.pickle'.format(self.config.version)
79 |         self.config.PATH.manual_RESUAL_PATH = '../save/{}/manual_result.pickle'.format(self.config.version)
80 |         self.config.PATH.PREDICTION_RESULT_PATH = '../save/{}/predictions.pickle'.format(self.config.version)
81 | 
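    # All artifacts for a run are collected under ../save/<version>/ via the
    # paths assembled in set_config_path above: config.yaml, log.txt, model.pth,
    # and the pickled metric/prediction results.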
82 |     def save_config(self):
83 |         with open(self.config.PATH.CONFIG_PATH, 'w') as f:
84 |             yaml.dump(self.config.toDict(), f)
85 |         return
86 | 
87 |     def load_model(self):
88 |         best_save = torch.load(self.config.PATH.MODEL_PATH, map_location=torch.device('cpu'))
89 |         best_param = best_save['model']
90 |         best_epoch = best_save['epoch']
91 |         return best_param, best_epoch, None
92 | 
93 |     def save_model(self, epoch, param, metric):
94 |         save_dict = {
95 |             'model': param,
96 |             'epoch': epoch,
97 |             'metric': metric
98 |         }
99 |         torch.save(save_dict, self.config.PATH.MODEL_PATH)
100 |         return
101 | 
102 |     def save_metric(self, metric, manual_metric=None):
103 |         with open(self.config.PATH.post_RESULT_PATH, 'wb') as f:
104 |             pickle.dump(metric, f)
105 |         if manual_metric is not None:
106 |             with open(self.config.PATH.manual_RESUAL_PATH, 'wb') as f:
107 |                 pickle.dump(manual_metric, f)
108 | 
109 |     def save_test_prediction(self):
110 |         with open(self.config.PATH.PREDICTION_RESULT_PATH, 'wb') as f:
111 |             pickle.dump(self.predictions_for_save, f)
112 | 
113 |     def add_test_prediction_for_save(self, batch, post_processing_pred, manual_pred, n_hint, post_metric_manager, manual_metric_manager):
114 |         manual_pred.hargmax_coord = heatmap2hargmax_coord(manual_pred.heatmap).detach()
115 |         post_sargmax_mm_MRE = post_metric_manager.running_metric['sargmax_mm_MRE'][-1].detach().cpu().numpy()
116 |         post_hargmax_mm_MRE = post_metric_manager.running_metric['hargmax_mm_MRE'][-1].detach().cpu().numpy()
117 |         manual_sargmax_mm_MRE = manual_metric_manager.running_metric['sargmax_mm_MRE'][-1].detach().cpu().numpy()
118 |         manual_hargmax_mm_MRE = manual_metric_manager.running_metric['hargmax_mm_MRE'][-1].detach().cpu().numpy()
119 |         for b in range(len(batch.label.coord)):
120 |             name = 'batch{}_hint{}'.format(batch.index[b], n_hint)
121 |             self.predictions_for_save[name] = Munch({})
122 |             self.predictions_for_save[name].post = Munch({})
123 |             self.predictions_for_save[name].post.sargmax_coord = post_processing_pred.sargmax_coord[b].detach().cpu().numpy()
124 |             self.predictions_for_save[name].post.hargmax_coord = post_processing_pred.hargmax_coord[b].detach().cpu().numpy()
125 |             # Saving the heatmaps themselves is too slow and too large; re-running a model forward pass later is cheaper.
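            # Each 'batch{i}_hint{n}' entry therefore stores only coordinates,
            # hint indices, and per-sample MRE values (filled in below).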
126 |             self.predictions_for_save[name].manual = Munch({})
127 |             self.predictions_for_save[name].manual.sargmax_coord = manual_pred.sargmax_coord[b].detach().cpu().numpy()
128 |             self.predictions_for_save[name].manual.hargmax_coord = manual_pred.hargmax_coord[b].detach().cpu().numpy()
129 | 
130 |             self.predictions_for_save[name].hint = Munch({})
131 |             if torch.is_tensor(batch.hint.index[b]):
132 |                 self.predictions_for_save[name].hint.index = batch.hint.index[b].detach().cpu().numpy()
133 |             else:
134 |                 self.predictions_for_save[name].hint.index = batch.hint.index[b]
135 | 
136 |             self.predictions_for_save[name].metric = Munch({})
137 |             self.predictions_for_save[name].metric.post = Munch({})
138 |             self.predictions_for_save[name].metric.post.sargmax_mm_MRE = post_sargmax_mm_MRE[b]
139 |             self.predictions_for_save[name].metric.post.hargmax_mm_MRE = post_hargmax_mm_MRE[b]
140 |             self.predictions_for_save[name].metric.manual = Munch({})
141 |             self.predictions_for_save[name].metric.manual.sargmax_mm_MRE = manual_sargmax_mm_MRE[b]
142 |             self.predictions_for_save[name].metric.manual.hargmax_mm_MRE = manual_hargmax_mm_MRE[b]
143 | 
144 | 
145 | 
146 |     def get_new_exp_num(self):
147 |         save_path = '../save/'
148 |         save_folder_names = os.listdir(save_path)
149 |         max_exp_num = 0
150 |         for folder_name in save_folder_names:
151 |             exp_num = folder_name.split('_')[0][:-1].split('[')[-1]
152 |             max_exp_num = max(int(exp_num), max_exp_num)
153 |         new_exp_num = '{:05d}'.format(max_exp_num+1)
154 | 
155 |         return new_exp_num
156 | 
157 |     def unspecified_configs_to_default(self):
158 |         if self.config.Dataset.get('subpixel_decoding',None) is None:
159 |             self.config.Dataset.subpixel_decoding = False
160 |         if self.config.Dataset.get('subpixel_decoding_patch_size',None) is None:
161 |             self.config.Dataset.subpixel_decoding_patch_size = 5
162 |         if self.config.Dataset.get('heatmap_encoding_maxone',None) is None:
163 |             self.config.Dataset.heatmap_encoding_maxone = False
164 |         if self.config.Dataset.get('subpixel_heatmap_encoding', None) is None:
165 |             self.config.Dataset.subpixel_heatmap_encoding = False
166 |         if self.config.Model.get('subpixel_decoding_coord_loss', None) is None:
167 |             self.config.Model.subpixel_decoding_coord_loss = False
168 |         if self.config.Model.get('facto_heatmap', None) is None:
169 |             self.config.Model.facto_heatmap = False
170 |         if self.config.Model.get('HintEncoder', None) is None:
171 |             self.config.Model.HintEncoder = Munch({})
172 |         if self.config.Model.HintEncoder.get('dilation', None) is None:
173 |             self.config.Model.HintEncoder.dilation = 5
174 |         if self.config.Model.get('Decoder', None) is None:
175 |             self.config.Model.Decoder = Munch({})
176 |         if self.config.Model.Decoder.get('dilation', None) is None:
177 |             self.config.Model.Decoder.dilation = 5
178 |         if self.config.Dataset.get('label_smoothing', None) is None:
179 |             self.config.Dataset.label_smoothing = False
180 |         if self.config.Model.get('MSELoss', None) is None:
181 |             self.config.Model.MSELoss = 0.0
182 |         if self.config.Dataset.get('heatmap_max_norm',None) is None:
183 |             self.config.Dataset.heatmap_max_norm = False
184 |         if self.config.Model.get('input_padding') is None:
185 |             self.config.Model.input_padding = None
186 |         if self.config.Morph.get('cosineSimilarityLoss') is None:
187 |             self.config.Morph.cosineSimilarityLoss = False
188 |         if self.config.Morph.get('threePointAngle') is None:
189 |             self.config.Morph.threePointAngle = False
190 |         if self.config.MISC.get('free_memory',None) is None:
191 |             self.config.MISC.free_memory = False
192 |         if self.config.Model.get('bbox_predictor', None) is None:
193 |             self.config.Model.bbox_predictor = False
194 |         if self.config.Morph.get('distance_l1', None) is None:
195 |             self.config.Morph.distance_l1 = False
196 |         if self.config.Model.get('SE_maxpool',None) is None:
197 |             self.config.Model.SE_maxpool = False
198 |         if self.config.Model.get('SE_softmax', None) is None:
199 |             self.config.Model.SE_softmax = False
200 |         if self.config.Model.get('use_prev_heatmap', None) is None:
201 |             self.config.Model.use_prev_heatmap = False
202 |         if self.config.Model.get('no_iterative_training', None) is None:
203 |             self.config.Model.no_iterative_training = False # applies only to RITM
204 |         if self.config.Morph.get('coord_use', None) is None:
205 |             self.config.Morph.coord_use = False
206 | 
207 | class TensorBoardManager():
208 |     def __init__(self, save_manager):
209 |         tensorboard_path = '../tensorboard/{}/{}/'.format(save_manager.config.Dataset.NAME, save_manager.config.version)
210 |         os.makedirs(tensorboard_path, exist_ok=True)
211 |         self.writer = SummaryWriter(tensorboard_path)
212 | 
213 |         self.n_iter = {'train':0, 'val':0, 'test':0}
214 |         self.n_epoch = {'train':0, 'val':0, 'test':0}
215 | 
216 |     def plot_image_heatmap(self, image, pred_heatmap, label_heatmap, epoch):
217 |         # image: (batch, 3, H, W) -1 <= x <= 1
218 |         # heatmap: (batch, num_keypoint, Height, Width) 0 <= x <= 1
219 |         image_unNorm = image * 0.5 + 0.5
220 | 
221 |         pred_label_heatmap = torch.cat((pred_heatmap[:,:,None,:,:], label_heatmap[:,:,None,:,:], torch.zeros_like(pred_heatmap)[:,:,None,:,:]), dim=2)
222 | 
223 |         grids = [torchvision.utils.make_grid(
224 |             torch.cat([image_unNorm[i].unsqueeze(0), pred_label_heatmap[i]], dim=0)
225 |         )
226 |             for i in range(pred_label_heatmap.shape[0])] # (1,3,H,W), (num_keypoint,3,H,W)
227 | 
228 |         for i, grid in enumerate(grids):
229 |             text = 'Epoch [{}] - {} ::: Pred(red) Label(green)'.format(epoch, i+1)
230 |             self.writer.add_image(text, grid)
231 | 
232 |     def plot_outlier_image_heatmap(self, image, pred_heatmap, label_heatmap, epoch):
233 |         # image: (batch, 3, H, W) -1 <= x <= 1
234 |         # heatmap: (batch, num_keypoint, Height, Width) 0 <= x <= 1
235 |         image_unNorm = image * 0.5 + 0.5
236 | 
237 |         pred_label_heatmap = torch.cat((pred_heatmap[:,:,None,:,:], label_heatmap[:,:,None,:,:], torch.zeros_like(pred_heatmap)[:,:,None,:,:]), dim=2)
238 | 
239 |         grids = [torchvision.utils.make_grid(
240 |             torch.cat([image_unNorm[i].unsqueeze(0), pred_label_heatmap[i]], dim=0)
241 |         )
242 |             for i in range(pred_label_heatmap.shape[0])] # (1,3,H,W), (num_keypoint,3,H,W)
243 | 
244 |         for i, grid in enumerate(grids):
245 |             text = 'iter [{}] - outlier - {}'.format(self.n_iter['train'], i+1)
246 |             self.writer.add_image(text, grid)
247 | 
248 |     def plot_model_param_histogram(self, model, epoch):
249 |         for k, v in model.named_parameters():
250 |             self.writer.add_histogram(k, v.data.cpu().reshape(-1), epoch)
251 |         return
252 | 
253 |     def write_loss(self, loss, split):
254 |         # loss: scalar
255 |         # split: 'train', 'val', 'test'
256 |         self.n_iter[split] += 1
257 |         self.writer.add_scalar('Loss/{}'.format(split), loss, self.n_iter[split])
258 | 
259 |     def write_metric(self, metric, split):
260 |         self.n_epoch[split] += 1
261 |         for key in metric:
262 |             self.writer.add_scalar('{}/{}'.format(key, split), metric[key], self.n_epoch[split])
263 | 
264 | 
--------------------------------------------------------------------------------
/model/iterativeRefinementModels/RITM_modules/RITM_hrnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch._utils
6 | import torch.nn.functional as F
7 | from model.iterativeRefinementModels.RITM_modules.RITM_ocr import SpatialOCR_Module, SpatialGather_Module
8 | 
9 | 
10 | relu_inplace = True
11 | 
12 | 
13 | class HighResolutionModule(nn.Module):
14 |     def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
15 |                  num_channels, fuse_method, multi_scale_output=True,
16 |                  norm_layer=nn.BatchNorm2d, align_corners=True):
17 |         super(HighResolutionModule, self).__init__()
18 |         self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
19 | 
20 |         self.num_inchannels = num_inchannels
21 |         self.fuse_method = fuse_method
22 |         self.num_branches = num_branches
23 |         self.norm_layer = norm_layer
24 |         self.align_corners = align_corners
25 | 
26 |         self.multi_scale_output = multi_scale_output
27 | 
28 |         self.branches = self._make_branches(
29 |             num_branches, blocks, num_blocks, num_channels)
30 |         self.fuse_layers = self._make_fuse_layers()
31 |         self.relu = nn.ReLU(inplace=relu_inplace)
32 | 
33 |     def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
34 |         if num_branches != len(num_blocks):
35 |             error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
36 |                 num_branches, len(num_blocks))
37 |             raise ValueError(error_msg)
38 | 
39 |         if num_branches != len(num_channels):
40 |             error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
41 |                 num_branches, len(num_channels))
42 |             raise ValueError(error_msg)
43 | 
44 |         if num_branches != len(num_inchannels):
45 |             error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
46 |                 num_branches, len(num_inchannels))
47 |             raise ValueError(error_msg)
48 | 
49 |     def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
50 |                          stride=1):
51 |         downsample = None
52 |         if stride != 1 or \
53 |                 self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
54 |             downsample = nn.Sequential(
55 |                 nn.Conv2d(self.num_inchannels[branch_index],
56 |                           num_channels[branch_index] * block.expansion,
57 |                           kernel_size=1, stride=stride, bias=False),
58 |                 self.norm_layer(num_channels[branch_index] * block.expansion),
59 |             )
60 | 
61 |         layers = []
62 |         layers.append(block(self.num_inchannels[branch_index],
63 |                             num_channels[branch_index], stride,
64 |                             downsample=downsample, norm_layer=self.norm_layer))
65 |         self.num_inchannels[branch_index] = \
66 |             num_channels[branch_index] * block.expansion
67 |         for i in range(1, num_blocks[branch_index]):
68 |             layers.append(block(self.num_inchannels[branch_index],
69 |                                 num_channels[branch_index],
70 |                                 norm_layer=self.norm_layer))
71 | 
72 |         return nn.Sequential(*layers)
73 | 
74 |     def _make_branches(self, num_branches, block, num_blocks, num_channels):
75 |         branches = []
76 | 
77 |         for i in range(num_branches):
78 |             branches.append(
79 |                 self._make_one_branch(i, block, num_blocks, num_channels))
80 | 
81 |         return nn.ModuleList(branches)
82 | 
83 |     def _make_fuse_layers(self):
84 |         if self.num_branches == 1:
85 |             return None
86 | 
87 |         num_branches = self.num_branches
88 |         num_inchannels = self.num_inchannels
89 |         fuse_layers = []
90 |         for i in range(num_branches if self.multi_scale_output else 1):
91 |             fuse_layer = []
92 |             for j in range(num_branches):
93 |                 if j > i:
94 |                     fuse_layer.append(nn.Sequential(
95 |                         nn.Conv2d(in_channels=num_inchannels[j],
96 |                                   out_channels=num_inchannels[i],
97 |                                   kernel_size=1,
98 |                                   bias=False),
99 |                         self.norm_layer(num_inchannels[i])))
100 |                 elif j == i:
101 |                     fuse_layer.append(None)
102 |                 else:
103 |                     conv3x3s = []
104 |                     for k in range(i - j):
105 |                         if k == i - j - 1:
106 |                             num_outchannels_conv3x3 = num_inchannels[i]
107 |                             conv3x3s.append(nn.Sequential(
108 |                                 nn.Conv2d(num_inchannels[j],
109 |                                           num_outchannels_conv3x3,
110 |                                           kernel_size=3, stride=2, padding=1, bias=False),
111 |                                 self.norm_layer(num_outchannels_conv3x3)))
112 |                         else:
113 |                             num_outchannels_conv3x3 = num_inchannels[j]
114 |                             conv3x3s.append(nn.Sequential(
115 |                                 nn.Conv2d(num_inchannels[j],
116 |                                           num_outchannels_conv3x3,
117 |                                           kernel_size=3, stride=2, padding=1, bias=False),
118 |                                 self.norm_layer(num_outchannels_conv3x3),
119 |                                 nn.ReLU(inplace=relu_inplace)))
120 |                     fuse_layer.append(nn.Sequential(*conv3x3s))
121 |             fuse_layers.append(nn.ModuleList(fuse_layer))
122 | 
123 |         return nn.ModuleList(fuse_layers)
124 | 
125 |     def get_num_inchannels(self):
126 |         return self.num_inchannels
127 | 
128 |     def forward(self, x):
129 |         if self.num_branches == 1:
130 |             return [self.branches[0](x[0])]
131 | 
132 |         for i in range(self.num_branches):
133 |             x[i] = self.branches[i](x[i])
134 | 
135 |         x_fuse = []
136 |         for i in range(len(self.fuse_layers)):
137 |             y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
138 |             for j in range(1, self.num_branches):
139 |                 if i == j:
140 |                     y = y + x[j]
141 |                 elif j > i:
142 |                     width_output = x[i].shape[-1]
143 |                     height_output = x[i].shape[-2]
144 |                     y = y + F.interpolate(
145 |                         self.fuse_layers[i][j](x[j]),
146 |                         size=[height_output, width_output],
147 |                         mode='bilinear', align_corners=self.align_corners)
148 |                 else:
149 |                     y = y + self.fuse_layers[i][j](x[j])
150 |             x_fuse.append(self.relu(y))
151 | 
152 |         return x_fuse
153 | 
154 | 
155 | class SqueezeExitationBlock(nn.Module):
156 |     def __init__(self, in_ch, mid_ch1, out_ch, SE_maxpool=False, SE_softmax=False):
157 |         super(SqueezeExitationBlock, self).__init__()
158 |         self.conv1 = nn.Conv2d(in_ch, mid_ch1, kernel_size=(1, 1))
159 |         self.conv2 = nn.Conv2d(mid_ch1, out_ch, kernel_size=(1, 1))
160 |         self.SE_maxpool = SE_maxpool
161 |         self.SE_softmax = SE_softmax
162 | 
163 |     def forward(self, x):
164 |         # === squeeze ===
165 |         # pool
166 |         if self.SE_maxpool:
167 |             x = x.max(-1)[0].max(-1)[0]
168 |             x = x[:,:,None,None]
169 |         else:
170 |             x = x.mean(-1, keepdim=True).mean(-2, keepdim=True) # b, in_ch, 1, 1
171 |         # conv1, relu
172 |         x = self.conv1(x).relu() # b, mid_ch, 1, 1
173 |         # conv2, then gate (softmax or sigmoid)
174 |         x = self.conv2(x)
175 | 
176 |         if self.SE_softmax:
177 |             x = x.softmax(1)
178 |         else:
179 |             x = x.sigmoid() # b, out_ch, 1, 1
180 |         return x
181 | class ConvBlocks(nn.Module):
182 |     def __init__(self, in_ch, channels, kernel_sizes=None, strides=None, dilations=None, paddings=None,
183 |                  BatchNorm=nn.BatchNorm2d):
184 |         super(ConvBlocks, self).__init__()
185 |         self.num = len(channels)
186 |         if kernel_sizes is None: kernel_sizes = [3 for c in channels]
187 |         if strides is None: strides = [1 for c in channels]
188 |         if dilations is None: dilations = [1 for c in channels]
189 |         if paddings is None: paddings = [
190 |             ((kernel_sizes[i] // 2) if dilations[i] == 1 else (kernel_sizes[i] // 2 * dilations[i])) for i in
191 |             range(self.num)]
192 |         convs_tmp = []
193 |         for i in range(self.num):
194 |             if channels[i] == 1:
195 |                 convs_tmp.append(
196 |                     nn.Conv2d(in_ch if i == 0 else channels[i - 1], channels[i], kernel_size=kernel_sizes[i],
197 |                               stride=strides[i], padding=paddings[i], dilation=dilations[i]))
198 |             else:
199 |                 convs_tmp.append(nn.Sequential(
200 |                     nn.Conv2d(in_ch if i == 0 else channels[i - 1], channels[i], kernel_size=kernel_sizes[i],
201 |                               stride=strides[i], padding=paddings[i], dilation=dilations[i], bias=False),
202 |                     BatchNorm(channels[i]), nn.ReLU()))
203 |         self.convs = nn.Sequential(*convs_tmp)
204 | 
205 |         # weight initialization
206 |         for m in self.convs.modules():
207 |             if isinstance(m, nn.Conv2d):
208 |                 torch.nn.init.kaiming_normal_(m.weight)
209 |             elif isinstance(m, nn.BatchNorm2d):
210 |                 m.weight.data.fill_(1)
211 |                 m.bias.data.zero_()
212 | 
213 |     def forward(self, x):
214 |         return self.convs(x)
215 | class HintEncSENet(nn.Module):
216 |     def __init__(self, se_output_channel, num_classes, SE_maxpool=False, SE_softmax=False, input_channel=256):
217 |         super(HintEncSENet, self).__init__()
218 |         self.SENet = SqueezeExitationBlock(256, 256//16, se_output_channel, SE_maxpool=SE_maxpool, SE_softmax=SE_softmax)
219 |         self.hintEncoder = ConvBlocks(input_channel+num_classes, [256, 256, 256], [3, 3, 3], [2, 1, 1])
220 | 
221 | 
222 |     def forward(self, x, hint):
223 |         hint = F.interpolate(hint, size=x.size()[2:], mode='bilinear', align_corners=True)
224 |         se = self.hintEncoder(torch.cat((x, hint), dim=1))
225 |         se = self.SENet(se)
226 |         return se
227 | 
228 | 
229 | class HighResolutionNet(nn.Module):
230 |     def __init__(self, width, num_classes, ocr_width=256, small=False,
231 |                  norm_layer=nn.BatchNorm2d, align_corners=True, addHintEncSENet=False, SE_maxpool=False, SE_softmax=False):
232 |         super(HighResolutionNet, self).__init__()
233 | 
234 | 
235 |         self.norm_layer = norm_layer
236 |         self.width = width
237 |         self.ocr_width = ocr_width
238 |         self.align_corners = align_corners
239 | 
240 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
241 |         self.bn1 = norm_layer(64)
242 |         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
243 |         self.bn2 = norm_layer(64)
244 |         self.relu = nn.ReLU(inplace=relu_inplace)
245 | 
246 |         num_blocks = 2 if small else 4
247 | 
248 |         stage1_num_channels = 64
249 |         self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
250 |         stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
251 | 
252 |         self.stage2_num_branches = 2
253 |         num_channels = [width, 2 * width]
254 |         num_inchannels = [
255 |             num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
256 |         self.transition1 = self._make_transition_layer(
257 |             [stage1_out_channel], num_inchannels)
258 |         self.stage2, pre_stage_channels = self._make_stage(
259 |             BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
260 |             num_blocks=2 * [num_blocks], num_channels=num_channels)
261 | 
262 |         self.stage3_num_branches = 3
263 |         num_channels = [width, 2 * width, 4 * width]
264 |         num_inchannels = [
265 |             num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
266 |         self.transition2 = self._make_transition_layer(
267 |             pre_stage_channels, num_inchannels)
268 |         self.stage3, pre_stage_channels = self._make_stage(
269 |             BasicBlockV1b, num_inchannels=num_inchannels,
270 |             num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
271 |             num_blocks=3 * [num_blocks], num_channels=num_channels)
272 | 
273 |         self.stage4_num_branches = 4
274 |         num_channels = [width, 2 * width, 4 * width, 8 * width]
275 |         num_inchannels = [
276 |             num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
277 |         self.transition3 = self._make_transition_layer(
278 |             pre_stage_channels, num_inchannels)
279 |         self.stage4, pre_stage_channels = self._make_stage(
280 |             BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
281 |             num_branches=self.stage4_num_branches,
282 |             num_blocks=4 * [num_blocks], num_channels=num_channels)
283 | 
284 |         last_inp_channels = int(np.sum(pre_stage_channels))  # builtin int: np.int was removed in NumPy 1.24
285 |         if self.ocr_width > 0:
286 |             ocr_mid_channels = 2 * self.ocr_width
287 |             ocr_key_channels = self.ocr_width
288 | 
289 |             self.conv3x3_ocr = nn.Sequential(
290 |                 nn.Conv2d(last_inp_channels, ocr_mid_channels,
291 |                           kernel_size=3, stride=1, padding=1),
292 |                 norm_layer(ocr_mid_channels),
293 |                 nn.ReLU(inplace=relu_inplace),
294 |             )
295 |             self.ocr_gather_head = SpatialGather_Module(num_classes)
296 | 
297 |             self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
298 |                                                      key_channels=ocr_key_channels,
299 |                                                      out_channels=ocr_mid_channels,
300 |                                                      scale=1,
301 |                                                      dropout=0.05,
302 |                                                      norm_layer=norm_layer,
303 |                                                      align_corners=align_corners)
304 |             self.cls_head = nn.Conv2d(
305 |                 ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
306 | 
307 |             self.aux_head = nn.Sequential(
308 |                 nn.Conv2d(last_inp_channels, last_inp_channels,
309 |                           kernel_size=1, stride=1, padding=0),
310 |                 norm_layer(last_inp_channels),
311 |                 nn.ReLU(inplace=relu_inplace),
312 |                 nn.Conv2d(last_inp_channels, num_classes,
313 |                           kernel_size=1, stride=1, padding=0, bias=True)
314 |             )
315 |         else:
316 |             self.cls_head = nn.Sequential(
317 |                 nn.Conv2d(last_inp_channels, last_inp_channels,
318 |                           kernel_size=3, stride=1, padding=1),
319 |                 norm_layer(last_inp_channels),
320 |                 nn.ReLU(inplace=relu_inplace),
321 |                 nn.Conv2d(last_inp_channels, num_classes,
322 |                           kernel_size=1, stride=1, padding=0, bias=True)
323 |             )
324 | 
325 | 
326 |         self.addHintEncSENet = addHintEncSENet
327 |         if self.addHintEncSENet:
328 |             self.HintEncSENet = HintEncSENet(se_output_channel=last_inp_channels, num_classes=num_classes, SE_maxpool=SE_maxpool, SE_softmax=SE_softmax)
329 | 
330 | 
331 |     def _make_transition_layer(
332 |             self, num_channels_pre_layer, num_channels_cur_layer):
333 |         num_branches_cur = len(num_channels_cur_layer)
334 |         num_branches_pre = len(num_channels_pre_layer)
335 | 
336 |         transition_layers = []
337 |         for i in range(num_branches_cur):
338 |             if i < num_branches_pre:
339 |                 if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
340 |                     transition_layers.append(nn.Sequential(
341 |                         nn.Conv2d(num_channels_pre_layer[i],
342 |                                   num_channels_cur_layer[i],
343 |                                   kernel_size=3,
344 |                                   stride=1,
345 |                                   padding=1,
346 |                                   bias=False),
347 |                         self.norm_layer(num_channels_cur_layer[i]),
348 |                         nn.ReLU(inplace=relu_inplace)))
349 |                 else:
350 |                     transition_layers.append(None)
351 |             else:
352 |                 conv3x3s = []
353 |                 for j in range(i + 1 - num_branches_pre):
354 |                     inchannels = num_channels_pre_layer[-1]
355 |                     outchannels = num_channels_cur_layer[i] \
356 |                         if j == i - num_branches_pre else inchannels
357 |                     conv3x3s.append(nn.Sequential(
358 |                         nn.Conv2d(inchannels, outchannels,
359 |                                   kernel_size=3, stride=2, padding=1, bias=False),
360 |                         self.norm_layer(outchannels),
361 |                         nn.ReLU(inplace=relu_inplace)))
362 |                 transition_layers.append(nn.Sequential(*conv3x3s))
363 | 
364 |         return nn.ModuleList(transition_layers)
365 | 
366 |     def _make_layer(self, block, inplanes, planes, blocks, stride=1):
367 |         downsample = None
368 |         if stride != 1 or inplanes != planes * block.expansion:
369 |             downsample = nn.Sequential(
370 |                 nn.Conv2d(inplanes, planes * block.expansion,
371 |                           kernel_size=1, stride=stride, bias=False),
372 |                 self.norm_layer(planes * block.expansion),
373 |             )
374 | 
375 |         layers = []
376 |         layers.append(block(inplanes, planes, stride,
377 |                             downsample=downsample, norm_layer=self.norm_layer))
378 |         inplanes = planes * block.expansion
379 |         for i in range(1, blocks):
380 |             layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
381 | 
382 |         return nn.Sequential(*layers)
383 | 
384 |     def _make_stage(self, block, num_inchannels,
385 |                     num_modules, num_branches, num_blocks, num_channels,
386 |                     fuse_method='SUM',
387 |                     multi_scale_output=True):
388 |         modules = []
389 |         for i in range(num_modules):
390 |             # multi_scale_output is only used by the last module
391 |             if not multi_scale_output and i == num_modules - 1:
392 |                 reset_multi_scale_output = False
393 |             else:
394 |                 reset_multi_scale_output = True
395 |             modules.append(
396 |                 HighResolutionModule(num_branches,
397 |                                      block,
398 |                                      num_blocks,
399 |                                      num_inchannels,
400 |                                      num_channels,
401 |                                      fuse_method,
402 |                                      reset_multi_scale_output,
403 |                                      norm_layer=self.norm_layer,
404 |                                      align_corners=self.align_corners)
405 |             )
406 |             num_inchannels = modules[-1].get_num_inchannels()
407 | 
408 |         return nn.Sequential(*modules), num_inchannels
409 | 
410 |     def forward(self, x, additional_features=None, input_hint_heatmap=None):
411 |         feats = self.compute_hrnet_feats(x, additional_features, input_hint_heatmap)
412 |         if self.ocr_width > 0:
413 |             out_aux = self.aux_head(feats) # aux_head : conv norm relu conv (soft object regions), output channel: num_classes
414 |             feats = self.conv3x3_ocr(feats) # conv3x3_ocr : conv norm relu (pixel representations)
415 | 
416 |             context = self.ocr_gather_head(feats, out_aux) # context : batch x c x num_keypoint x 1, feats: batch, c, H, W
417 |             feats = self.ocr_distri_head(feats, context)
418 |             out = self.cls_head(feats)
419 |             return [out, out_aux]
420 |         else:
421 |             return [self.cls_head(feats), None]
422 | 
423 |     def compute_hrnet_feats(self, x, additional_features, input_hint_heatmap):
424 |         x = self.compute_pre_stage_features(x, additional_features)
425 |         x = self.layer1(x)
426 | 
427 |         if input_hint_heatmap is not None:
428 |             hint_encoder_output = self.HintEncSENet(x, input_hint_heatmap)
429 | 
430 |         x_list = []
431 |         for i in range(self.stage2_num_branches):
432 |             if self.transition1[i] is not None:
433 |                 x_list.append(self.transition1[i](x))
434 |             else:
435 |                 x_list.append(x)
436 |         y_list = self.stage2(x_list)
437 | 
438 |         x_list = []
439 |         for i in range(self.stage3_num_branches):
440 |             if self.transition2[i] is not None:
441 |                 if i < self.stage2_num_branches:
442 |                     x_list.append(self.transition2[i](y_list[i]))
443 |                 else:
444 |                     x_list.append(self.transition2[i](y_list[-1]))
445 |             else:
446 |                 x_list.append(y_list[i])
447 |         y_list = self.stage3(x_list)
448 | 
449 |         x_list = []
450 |         for i in range(self.stage4_num_branches):
451 |             if self.transition3[i] is not None:
452 |                 if i < self.stage3_num_branches:
453 |                     x_list.append(self.transition3[i](y_list[i]))
454 |                 else:
455 |                     x_list.append(self.transition3[i](y_list[-1]))
456 |             else:
457 |                 x_list.append(y_list[i])
458 |         x = self.stage4(x_list)
459 | 
460 |         out = self.aggregate_hrnet_features(x)
461 |         if input_hint_heatmap is not None:
462 |             return hint_encoder_output * out
463 |         else:
464 |             return out
465 | 
466 | 
467 |     def compute_pre_stage_features(self, x, additional_features):
468 |         x = self.conv1(x)
469 |         x = self.bn1(x)
470 |         x = self.relu(x)
471 |         if additional_features is not None:
472 |             x = x + additional_features
473 |         x = self.conv2(x)
474 |         x = self.bn2(x)
475 |         return self.relu(x)
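    # aggregate_hrnet_features below upsamples the three lower-resolution branches
    # to the stride-4 resolution of branch 0 and concatenates along channels,
    # so the fused map has width * (1 + 2 + 4 + 8) channels (= last_inp_channels).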
476 | 
477 |     def aggregate_hrnet_features(self, x):
478 |         # Upsampling
479 |         x0_h, x0_w = x[0].size(2), x[0].size(3)
480 |         x1 = F.interpolate(x[1], size=(x0_h, x0_w),
481 |                            mode='bilinear', align_corners=self.align_corners)
482 |         x2 = F.interpolate(x[2], size=(x0_h, x0_w),
483 |                            mode='bilinear', align_corners=self.align_corners)
484 |         x3 = F.interpolate(x[3], size=(x0_h, x0_w),
485 |                            mode='bilinear', align_corners=self.align_corners)
486 | 
487 |         return torch.cat([x[0], x1, x2, x3], 1)
488 | 
489 |     def load_pretrained_weights(self, pretrained_path=''):
490 |         model_dict = self.state_dict()
491 | 
492 |         if not os.path.exists(pretrained_path):
493 |             print(f'\nFile "{pretrained_path}" does not exist.')
494 |             print('You need to specify the correct path to the pre-trained weights.\n'
495 |                   'You can download the weights for HRNet from the repository:\n'
496 |                   'https://github.com/HRNet/HRNet-Image-Classification')
497 |             exit(1)
498 |         pretrained_dict = torch.load(pretrained_path, map_location='cpu')
499 |         pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
500 |                            pretrained_dict.items()}
501 | 
502 |         pretrained_dict = {k: v for k, v in pretrained_dict.items()
503 |                            if k in model_dict.keys()}
504 | 
505 |         model_dict.update(pretrained_dict)
506 |         self.load_state_dict(model_dict, strict=False)
507 | 
508 | 
509 | 
510 | 
511 | import torch  # note: torch/nn are re-imported here; harmless but redundant
512 | import torch.nn as nn
513 | GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
514 | 
515 | 
516 | class BasicBlockV1b(nn.Module):
517 |     expansion = 1
518 | 
519 |     def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
520 |                  previous_dilation=1, norm_layer=nn.BatchNorm2d):
521 |         super(BasicBlockV1b, self).__init__()
522 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
523 |                                padding=dilation, dilation=dilation, bias=False)
524 |         self.bn1 = norm_layer(planes)
525 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
526 |                                padding=previous_dilation, dilation=previous_dilation, bias=False)
527 |         self.bn2 = norm_layer(planes)
528 | 
529 |         self.relu = nn.ReLU(inplace=True)
530 |         self.downsample = downsample
531 |         self.stride = stride
532 | 
533 |     def forward(self, x):
534 |         residual = x
535 | 
536 |         out = self.conv1(x)
537 |         out = self.bn1(out)
538 |         out = self.relu(out)
539 | 
540 |         out = self.conv2(out)
541 |         out = self.bn2(out)
542 | 
543 |         if self.downsample is not None:
544 |             residual = self.downsample(x)
545 | 
546 |         out = out + residual
547 |         out = self.relu(out)
548 | 
549 |         return out
550 | 
551 | 
552 | class BottleneckV1b(nn.Module):
553 |     expansion = 4
554 | 
555 |     def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
556 |                  previous_dilation=1, norm_layer=nn.BatchNorm2d):
557 |         super(BottleneckV1b, self).__init__()
558 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
559 |         self.bn1 = norm_layer(planes)
560 | 
561 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
562 |                                padding=dilation, dilation=dilation, bias=False)
563 |         self.bn2 = norm_layer(planes)
564 | 
565 |         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
566 |         self.bn3 = norm_layer(planes * self.expansion)
567 | 
568 |         self.relu = nn.ReLU(inplace=True)
569 |         self.downsample = downsample
570 |         self.stride = stride
571 | 
572 |     def forward(self, x):
573 |         residual = x
574 | 
575 |         out = self.conv1(x)
576 |         out = self.bn1(out)
577 |         out = self.relu(out)
578 | 
579 |         out = self.conv2(out)
580 |         out = self.bn2(out)
581 |         out = self.relu(out)
582 | 
583 |         out = self.conv3(out)
584 |         out = self.bn3(out)
585 | 
586 |         if self.downsample is not None:
587 |             residual = self.downsample(x)
588 | 
589 |         out = out + residual
590 |         out = self.relu(out)
591 | 
592 |         return out
--------------------------------------------------------------------------------
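
Usage note (an illustrative sketch, not part of the repository): the snippet below exercises the hard-argmax decoding implemented by heatmap2hargmax_coord in misc/metric.py, which util.py imports when saving predictions. The shapes (2, 68, 512, 256) are example values taken from config/spineweb_ours.yaml; coordinates come back in (row, col) order.

import torch
from misc.metric import heatmap2hargmax_coord

heatmap = torch.zeros(2, 68, 512, 256)  # (batch, num_keypoint, row, column)
heatmap[0, 0, 100, 37] = 1.0            # plant a peak for keypoint 0 of sample 0

coord = heatmap2hargmax_coord(heatmap)  # (batch, num_keypoint, 2)
assert coord[0, 0].tolist() == [100.0, 37.0]  # row 100, column 37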