├── model
│   ├── iterativeRefinementModels
│   │   ├── __init__.py
│   │   ├── RITM_modules
│   │   │   ├── __init__.py
│   │   │   ├── RITM_ocr.py
│   │   │   └── RITM_hrnet.py
│   │   └── RITM_SE_HRNet32.py
│   └── __init__.py
├── train.sh
├── test.sh
├── dataset
│   ├── transforms.py
│   ├── __init__.py
│   └── dataset.py
├── config
│   └── spineweb_ours.yaml
├── morph_pairs
│   └── dataset16
│       └── dataset16.json
├── misc
│   ├── optimizer.py
│   ├── loss.py
│   ├── metric.py
│   ├── heatmap_maker.py
│   └── train.py
├── main.py
├── data_preprocessing.py
├── README.md
└── util.py
/model/iterativeRefinementModels/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/iterativeRefinementModels/RITM_modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | seed="42"
4 | gpu='1'
5 |
6 | config='spineweb_ours'
7 | default_command="--seed ${seed} --config ${config}"
8 | custom_command=""
9 | CUDA_VISIBLE_DEVICES="${gpu}" python -u main.py ${default_command} ${custom_command} --save_test_prediction
10 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | seed="42"
4 | gpu='1'
5 |
6 | config='_'
7 | default_command="--seed ${seed} --config ${config}"
8 | custom_command=""
9 | CUDA_VISIBLE_DEVICES="${gpu}" python -u main.py ${default_command} ${custom_command} --save_test_prediction --only_test_version "ExpNum[00001]_Dataset[dataset16]_Model[RITM_SE_HRNet32]_config[spineweb_ours]_seed[42]"
10 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def get_model(save_manager):
4 | name = save_manager.config.Model.NAME
5 |
6 | if name == 'RITM_SE_HRNet32':
7 | from model.iterativeRefinementModels.RITM_SE_HRNet32 import RITM as Model
8 | else:
9 | save_manager.write_log('ERROR: NOT SPECIFIED MODEL NAME: {}'.format(name))
10 | raise NotImplementedError(name)
11 |
12 | model = Model(save_manager.config)
13 | return model
--------------------------------------------------------------------------------
/dataset/transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import albumentations
3 |
4 | def default_aug(img_size):
5 | return albumentations.Compose([
6 | albumentations.augmentations.geometric.rotate.SafeRotate((-15, 15), p=0.5, border_mode=cv2.BORDER_CONSTANT),
7 | albumentations.HorizontalFlip(p=0.5),
8 | albumentations.augmentations.geometric.resize.RandomScale((0.9, 1.2), p=0.5),
9 | albumentations.RandomBrightnessContrast(),
10 | albumentations.augmentations.geometric.resize.Resize(img_size[0], img_size[1], p=1)
11 | ], keypoint_params=albumentations.KeypointParams(format='xy', remove_invisible=False))
12 |
13 | def fake(**kwargs):
14 | return {**kwargs}
15 |
--------------------------------------------------------------------------------
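A quick usage sketch (not part of the repo; the shapes are illustrative): albumentations pipelines are called with keyword arguments and return a dict, which is why `fake` simply echoes its kwargs when no augmentation is wanted at validation/test time.

```
import numpy as np
from dataset.transforms import default_aug

aug = default_aug((512, 256))
out = aug(image=np.zeros((800, 640, 3), dtype=np.uint8),
          keypoints=[(100.0, 200.0)])  # keypoints are (x, y) pairs
print(out["image"].shape)  # (512, 256, 3)
```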
/config/spineweb_ours.yaml:
--------------------------------------------------------------------------------
1 | PATH:
2 | ROOT_PATH: './'
3 | DATA:
4 | IMAGE: './data/dataset16_512/'
5 | TABLE: './data/dataset16_512/'
6 |
7 | Dataset:
8 | NAME: 'dataset16'
9 | image_size: [512, 256]
10 | num_keypoint: 68
11 | heatmap_std: 7.5
12 | aug:
13 | type: 'aug_default'
14 | p: [0.2, 0.2, 0.2, 0.2]
15 | subpixel_decoding_patch_size: 15
16 | subpixel_decoding: True
17 |
18 | Model:
19 | NAME: 'RITM_SE_HRNet32'
20 | SE_maxpool: True
21 |
22 | Optimizer:
23 | optimizer: 'Adam'
24 | lr: 0.001
25 | scheduler: ''
26 |
27 | Train:
28 | patience: 50
29 | batch_size: 4
30 | epoch: 5000
31 | metric: ["MAE", "RMSE", "MRE"]
32 | decision_metric: 'hargmax_mm_MRE'
33 | SR_standard: ''
34 |
35 |
36 | Hint:
37 | max_hint: 13
38 | num_dist: dataset16
39 |
40 |
41 | MISC:
42 | TB: True
43 | gpu: '0'
44 | num_workers: 0
45 |
46 | Morph:
47 | use: True
48 | pairs: 'dataset16'
49 | angle_lambda: 0.01
50 | distance_lambda: 0.01
51 | distance_l1: True
52 | cosineSimilarityLoss: True
53 | threePointAngle: True
54 |
--------------------------------------------------------------------------------
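The rest of the code reads this file through an attribute-style Munch object. A minimal loading sketch (the actual loading lives in util.py, which is not shown in this dump):

```
import yaml
from munch import Munch

with open('./config/spineweb_ours.yaml') as f:
    config = Munch.fromDict(yaml.safe_load(f))

print(config.Dataset.image_size)  # [512, 256]
print(config.Model.NAME)          # RITM_SE_HRNet32
```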
/morph_pairs/dataset16/dataset16.json:
--------------------------------------------------------------------------------
1 | [[[2, 4], [6, 8], [3, 5], [18, 20], [10, 12], [7, 9], [14, 16], [11, 13], [22, 24], [23, 25], [19, 21], [26, 28], [15, 17], [30, 32], [4, 6], [8, 12], [8, 10], [31, 33], [0, 2], [27, 29], [34, 36], [6, 10], [38, 40], [4, 8], [1, 3], [12, 16], [35, 37], [12, 14], [46, 48], [29, 33], [29, 31], [2, 6], [33, 37], [39, 41], [42, 44], [17, 21], [27, 31], [35, 39], [23, 27], [5, 7], [10, 14], [9, 13], [16, 20], [25, 27], [26, 30], [1, 5], [31, 35], [43, 45], [21, 25], [19, 23], [6, 12], [16, 18], [5, 9], [33, 35], [14, 18], [28, 30], [25, 29], [0, 4], [28, 32], [55, 57], [59, 61], [9, 11], [21, 23], [24, 26], [13, 17], [36, 40], [18, 22], [47, 49], [7, 11], [3, 7]], [[2, 63, 4], [2, 62, 4], [2, 61, 4], [2, 59, 4], [2, 60, 4], [2, 58, 4], [3, 63, 5], [3, 62, 5], [2, 57, 4], [2, 56, 4], [2, 55, 4], [3, 61, 5], [3, 60, 5], [2, 54, 4], [6, 63, 8], [3, 59, 5], [3, 58, 5], [2, 53, 4], [6, 62, 8], [2, 52, 4], [2, 51, 4], [6, 61, 8], [7, 62, 9], [3, 56, 5], [7, 63, 9], [3, 57, 5], [2, 50, 4], [6, 59, 8], [6, 60, 8], [3, 54, 5], [3, 55, 5], [7, 60, 9], [6, 58, 8], [2, 49, 4], [7, 61, 9], [2, 48, 4], [11, 62, 13], [11, 63, 13], [6, 57, 8], [0, 63, 2], [2, 47, 4], [7, 58, 9], [2, 46, 4], [3, 52, 5], [7, 59, 9], [10, 63, 12], [0, 62, 2], [3, 53, 5], [6, 56, 8], [6, 55, 8], [3, 50, 5], [0, 61, 2], [10, 62, 12], [11, 61, 13], [3, 51, 5], [7, 65, 9], [11, 60, 13], [6, 54, 8], [0, 60, 2], [7, 56, 9], [10, 61, 12], [2, 44, 4], [0, 59, 2], [2, 45, 4], [7, 57, 9], [0, 58, 2], [11, 59, 13], [11, 58, 13], [6, 53, 8], [3, 48, 5]]]
--------------------------------------------------------------------------------
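The file stores two lists: two-point index pairs for the distance term and three-point index triples for the angle term, with the middle index as the vertex; this matches how `HeatmapMaker.get_morph` consumes `morph_pairs[0]` and `morph_pairs[1]` (util.py presumably resolves the `pairs: 'dataset16'` config entry to this file). A loading sketch:

```
import json

with open('./morph_pairs/dataset16/dataset16.json') as f:
    dist_pairs, angle_triples = json.load(f)

print(dist_pairs[0])     # [2, 4]      -> distance between keypoints 2 and 4
print(angle_triples[0])  # [2, 63, 4]  -> angle at vertex 63 spanned by 2 and 4
```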
/misc/optimizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch import optim
3 |
4 | def get_optimizer(config, model):
5 | optimizer = base_optimizer(config, model)
6 | return optimizer
7 |
8 | class base_optimizer(object):
9 | def __init__(self, config, model):
10 | super(base_optimizer, self).__init__()
11 |
12 | # optimizer
13 | if config.optimizer == 'Adam':
14 | self.optimizer = optim.Adam(model.parameters(), lr=config.lr)
15 | elif config.optimizer == 'Adadelta':
16 | self.optimizer = optim.Adadelta(model.parameters(), lr=config.lr)
17 | else:
18 | raise ValueError('Unknown optimizer: {}'.format(config.optimizer))
19 |
20 | # scheduler
21 | if config.scheduler == 'ReduceLROnPlateau':
22 | from torch.optim.lr_scheduler import ReduceLROnPlateau
23 | self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', verbose=True)
24 | elif config.scheduler == 'StepLR':
25 | from torch.optim.lr_scheduler import StepLR
26 | self.scheduler = StepLR(self.optimizer, 100, gamma=0.1, last_epoch=-1)
27 | else:
28 | self.scheduler = None
29 |
30 | def update_model(self, loss):
31 | self.optimizer.zero_grad()
32 |
33 | if np.isnan(loss.item()):
34 | print('\n\n\nERROR::: THE LOSS IS NAN\n\n\n')
35 | raise ValueError('loss is NaN')
36 | else:
37 | loss.backward()
38 | self.optimizer.step()
39 | return None
40 |
41 | def scheduler_step(self, metric):
42 | if self.scheduler is not None:
43 | self.scheduler.step(metric)
44 | return None
--------------------------------------------------------------------------------
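main.py passes only the `Optimizer` sub-config here, so a self-contained usage sketch needs just the three fields read above:

```
import torch.nn as nn
from munch import Munch
from misc.optimizer import get_optimizer

config = Munch(optimizer='Adam', lr=0.001, scheduler='')  # mirrors the YAML's Optimizer block
optimizer = get_optimizer(config, nn.Linear(4, 2))        # any nn.Module works here
```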
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 | import torch
5 | from dataset.dataset import Dataset, collate_fn
6 |
7 | def dataloader(config, split:str, data:list):
8 | # split : 'train' / 'val' / 'test'
9 | dataset = Dataset(config=config, split=split, data=data)
10 |
11 | # loader
12 | if split == 'train':
13 | shuffle = True
14 | drop_last = True
15 | else:
16 | shuffle = False
17 | drop_last = False
18 |
19 | def _init_fn(worker_id):
20 | np.random.seed(config.seed + worker_id)
21 |
22 | data_loader = torch.utils.data.DataLoader(dataset, shuffle=shuffle, worker_init_fn=_init_fn,
23 | batch_size=config.Train.batch_size, num_workers=config.MISC.num_workers, collate_fn=collate_fn, drop_last=drop_last)
24 | return data_loader
25 |
26 | def get_split_data(config, split=None):
27 | if split == 'train':
28 | with open(os.path.join(config.PATH.DATA.TABLE, 'train.json'), 'r') as f:
29 | train_data = json.load(f)
30 | return train_data
31 | elif split =='val':
32 | with open(os.path.join(config.PATH.DATA.TABLE, 'val.json'), 'r') as f:
33 | val_data = json.load(f)
34 | return val_data
35 |
36 | elif split =='test':
37 | with open(os.path.join(config.PATH.DATA.TABLE, 'test.json'), 'r') as f:
38 | test_data = json.load(f)
39 | return test_data
40 |
41 | def get_dataloader(config, split):
42 | data = get_split_data(config, split)
43 | loader = dataloader(config, split, data)
44 | return loader
45 |
--------------------------------------------------------------------------------
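A hedged end-to-end sketch, assuming `config` is the YAML above loaded into a Munch and the preprocessed JSON/npy files from data_preprocessing.py are in place:

```
from dataset import get_dataloader

train_loader = get_dataloader(config, 'train')  # reads ./data/dataset16_512/train.json
print(len(train_loader))                        # number of training batches
```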
/misc/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from munch import Munch
5 |
6 | class LossManager():
7 | def __init__(self, config, heatmap_maker):
8 | self.config = config
9 | self.heatmap_maker = heatmap_maker
10 | self.mse_criterion = nn.MSELoss()
11 | self.mae_criterion = nn.L1Loss()
12 | self.bce_loss = nn.BCELoss()
13 | if self.config.Morph.use :
14 | if self.config.Morph.cosineSimilarityLoss:
15 | self.angle_criterion = nn.CosineEmbeddingLoss()
16 | else:
17 | self.angle_criterion = self.mse_criterion
18 |
19 | def __call__(self, pred_heatmap, label):
20 | loss, pred_heatmap = self.get_heatmap_loss(pred_heatmap=pred_heatmap, label_heatmap=label.heatmap)
21 | pred_coord = self.heatmap_maker.get_heatmap2sargmax_coord(pred_heatmap=pred_heatmap)
22 | if self.config.Morph.use:
23 | morph_loss = self.get_morph_loss(pred_coord=pred_coord, label_coord=label.coord, morph_loss_mask=label.morph_loss_mask)
24 | loss = loss + morph_loss
25 |
26 | if self.config.Morph.coord_use:
27 | coord_loss = nn.L1Loss()(pred_coord, label.coord)
28 | loss = loss + 0.01 * coord_loss
29 |
30 | if torch.isnan(loss).item():
31 | print("========== ERROR ::: Loss is nan ===========")
32 | raise ValueError('loss is NaN')
33 |
34 |
35 | out = Munch.fromDict({'pred':{'sargmax_coord':pred_coord, 'heatmap':pred_heatmap}, 'loss':loss, })
36 | return out
37 |
38 | def get_heatmap_loss(self, pred_heatmap, label_heatmap, mse_flag=False):
39 | if mse_flag:
40 | heatmap_loss = self.mse_criterion(pred_heatmap, label_heatmap)
41 | else: # BCE loss
42 | pred_heatmap = pred_heatmap.sigmoid()
43 | heatmap_loss = self.bce_loss(pred_heatmap, label_heatmap)
44 | return heatmap_loss, pred_heatmap
45 |
46 | def get_morph_loss(self, pred_coord, label_coord, morph_loss_mask):
47 | pred_dist, pred_angle = self.heatmap_maker.get_morph(pred_coord)
48 | with torch.no_grad():
49 | label_dist, label_angle = self.heatmap_maker.get_morph(label_coord)
50 |
51 | if morph_loss_mask.sum()>0:
52 | pred_dist, pred_angle = pred_dist[morph_loss_mask], pred_angle[morph_loss_mask]
53 | label_dist, label_angle = label_dist[morph_loss_mask], label_angle[morph_loss_mask]
54 |
55 | if self.config.Morph.distance_l1:
56 | loss_dist = self.mae_criterion(pred_dist, label_dist)
57 | else:
58 | loss_dist = self.mse_criterion(pred_dist, label_dist)
59 |
60 | if self.config.Morph.cosineSimilarityLoss:
61 | N = pred_angle.shape[0] * pred_angle.shape[1]
62 | label_similarity = torch.ones(N, dtype=torch.long, device=pred_angle.device)
63 | loss_angle = self.angle_criterion(pred_angle.reshape(N, 2), label_angle.reshape(N, 2), label_similarity)
64 | else:
65 | pred_angle_normalized = torch.nn.functional.normalize(pred_angle, dim=-1) # (batch, 13, 2)
66 | label_angle_normalized = torch.nn.functional.normalize(label_angle, dim=-1)
67 | loss_angle = self.angle_criterion(pred_angle_normalized, label_angle_normalized)
68 |
69 | return loss_dist * self.config.Morph.distance_lambda + loss_angle * self.config.Morph.angle_lambda
70 | return 0
71 |
72 |
--------------------------------------------------------------------------------
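For intuition, a mini-example (a sketch, not repo code) of the cosine-similarity angle term above: with targets of 1, `nn.CosineEmbeddingLoss` pulls each predicted (cos Δθ, sin Δθ) vector toward its label vector.

```
import torch
import torch.nn as nn

pred_angle = torch.tensor([[0.9, 0.1], [0.0, 1.0]])   # (N, 2) angle vectors
label_angle = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
target = torch.ones(2, dtype=torch.long)              # 1 = "make them similar"
print(nn.CosineEmbeddingLoss()(pred_angle, label_angle, target))
```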
/main.py:
--------------------------------------------------------------------------------
1 | from pytz import timezone
2 | import time
3 | import argparse
4 | import datetime
5 |
6 | import numpy as np
7 | import random
8 | import torch
9 | import torch.nn as nn
10 |
11 | from util import SaveManager, TensorBoardManager
12 | from model import get_model
13 | from dataset import get_dataloader
14 | from misc.metric import MetricManager
15 | from misc.optimizer import get_optimizer
16 | from misc.train import Trainer
17 |
18 | torch.set_num_threads(4)
19 |
20 | def set_seed(config):
21 | torch.backends.cudnn.deterministic = True
22 | torch.backends.cudnn.benchmark = False
23 | torch.manual_seed(config.seed)
24 | torch.cuda.manual_seed(config.seed)
25 | torch.cuda.manual_seed_all(config.seed)
26 | np.random.seed(config.seed)
27 | random.seed(config.seed)
28 | return
29 |
30 | def parse_args():
31 | parser = argparse.ArgumentParser(description='TMI experiments')
32 | parser.add_argument('--config', type=str, help='config name, required')
33 | parser.add_argument('--seed', type=int, default=42, help='random seed')
34 | parser.add_argument('--only_test_version', type=str, default=None, help='If set, skip training and only load & test the experiment with this version name')
35 | parser.add_argument('--save_test_prediction', action='store_true', default=False, help='If activated, save test predictions at save path')
36 | arg = parser.parse_args()
37 | set_seed(arg)
38 | return arg
39 |
40 |
41 |
42 | def main(save_manager):
43 | if save_manager.config.MISC.TB and save_manager.config.only_test_version is None:
44 | writer = TensorBoardManager(save_manager)
45 | else:
46 | writer = None
47 |
48 | # model initialization
49 | device_ids = list(range(len(save_manager.config.MISC.gpu.split(','))))
50 | model = nn.DataParallel(get_model(save_manager), device_ids=device_ids)
51 | model.to(save_manager.config.MISC.device)
52 |
53 | # calculate the number of model parameters
54 | n_params = 0
55 | for k, v in model.named_parameters():
56 | n_params += v.reshape(-1).shape[0]
57 | save_manager.write_log('Number of model parameters : {}'.format(n_params), 0)
58 |
59 | # optimizer initialization
60 | optimizer = get_optimizer(save_manager.config.Optimizer, model)
61 |
62 | metric_manager = MetricManager(save_manager)
63 | trainer = Trainer(model, metric_manager)
64 |
65 | if not save_manager.config.only_test_version:
66 | # dataloader
67 | train_loader = get_dataloader(save_manager.config, 'train')
68 | val_loader = get_dataloader(save_manager.config, 'val')
69 |
70 | # training
71 | save_manager.write_log('Start Training...', 4)
72 | trainer.train(save_manager=save_manager,
73 | train_loader=train_loader,
74 | val_loader=val_loader,
75 | optimizer=optimizer,
76 | writer=writer)
77 | # deallocate data loaders from the memory
78 | del train_loader
79 | del val_loader
80 |
81 | trainer.best_param, trainer.best_epoch, trainer.best_metric = save_manager.load_model()
82 |
83 | save_manager.write_log('Start Test Evaluation...', 4)
84 | test_loader = get_dataloader(save_manager.config, 'test')
85 | trainer.test(save_manager=save_manager, test_loader=test_loader, writer=writer)
86 | del test_loader
87 |
88 |
89 | if __name__ == '__main__':
90 | start_time = time.time()
91 |
92 | arg = parse_args()
93 | save_manager = SaveManager(arg)
94 | save_manager.write_log('Process Start ::: {}'.format(datetime.datetime.now()), n_mark=16)
95 |
96 | main(save_manager)
97 |
98 | end_time = time.time()
99 | save_manager.write_log('Process End ::: {} {:.2f} hours'.format(datetime.datetime.now(), (end_time - start_time) / 3600), n_mark=16)
100 | save_manager.write_log('Version ::: {}'.format(save_manager.config.version), n_mark=16)
101 |
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | from munch import Munch
4 | import numpy as np
5 |
6 | import torch
7 |
8 | import dataset.transforms as transforms
9 | # from detectron2.data import transforms as T
10 |
11 | class Dataset(torch.utils.data.Dataset):
12 | def __init__(self, config, split:str, data:list):
13 |
14 | config.DICT_KEY = Munch.fromDict({})
15 | config.DICT_KEY.IMAGE = 'image'
16 | config.DICT_KEY.BBOX = 'bbox_{}'.format(config.Dataset.image_size[0])
17 | config.DICT_KEY.POINTS = 'points_{}'.format(config.Dataset.image_size[0])
18 | config.DICT_KEY.RAW_SIZE = 'raw_size_row_col'
19 | config.DICT_KEY.PSPACE = 'pixelSpacing'
20 |
21 |
22 |
23 |
24 | # init
25 | self.config = config
26 | self.split = split
27 | self.data = data
28 |
29 | if self.split == 'train':
30 | self.transformer = transforms.default_aug(self.config.Dataset.image_size)
31 |
32 | else:
33 | self.transformer = transforms.fake
34 |
35 | self.loadimage = self.load_npy
36 |
37 | for item in self.data:
38 | item[self.config.DICT_KEY.IMAGE] = item[self.config.DICT_KEY.IMAGE].replace('.png', '.npy')
39 |
40 | def __len__(self):
41 | return len(self.data)
42 |
43 | def __getitem__(self, index):
44 | # indexing
45 | item = self.data[index]
46 |
47 | # image load
48 | img_path = os.path.join(self.config.PATH.DATA.IMAGE, item[self.config.DICT_KEY.IMAGE])
49 | img, row, column = self.loadimage(img_path, item[self.config.DICT_KEY.RAW_SIZE])
50 |
51 | # pixel spacing
52 | pspace_list = item[self.config.DICT_KEY.PSPACE] # row, column
53 | raw_size_and_pspace = torch.tensor([row, column] + pspace_list)
54 |
55 | # points load (13,2) (column, row)==(xy)
56 | coords = copy.deepcopy(item[self.config.DICT_KEY.POINTS])
57 | coords.append([1.0,1.0])
58 |
59 |
60 | transformed = self.transformer(image=img, keypoints=coords)
61 | img, coords = transformed["image"], transformed["keypoints"]
62 | additional = torch.tensor([])
63 |
64 | coords = np.array(coords)
65 |
66 | # np array to tensor: (row, col, 3) -> (3, row, col)
67 | img = torch.tensor(img, dtype=torch.float)
68 | img = img.permute(2, 0, 1)
69 | img /= 255.0 # 0~255 to 0~1
70 | img = img * 2 - 1 # 0~1 to -1~1
71 |
72 |
73 | coords = torch.tensor(copy.deepcopy(coords[:, ::-1]), dtype=torch.float)
74 | morph_loss_mask = (coords[-1] == torch.tensor([1.0, 1.0], dtype=torch.float)).all()
75 | coords = coords[:-1]
76 |
77 | # hint
78 | if self.split == 'train':
79 | # random hint
80 | num_hint = np.random.choice(range(self.config.Dataset.num_keypoint ), size=None, p=self.config.Hint.num_dist)
81 | hint_indices = np.random.choice(range(self.config.Dataset.num_keypoint ), size=num_hint, replace=False) #[1,2,3]
82 | else:
83 | hint_indices = None
84 |
85 | return img_path, img, raw_size_and_pspace, hint_indices, coords, additional, index, morph_loss_mask
86 |
87 | def load_npy(self, img_path, size=None):
88 | img = np.load(img_path)
89 | if size is not None:
90 | row, column = size
91 | else:
92 | row, column = img.shape[:2]
93 | return img, row, column
94 |
95 | def collate_fn(batch):
96 | batch = list(zip(*batch))
97 | batch_dict = {
98 | 'input_image_path':batch[0], # list
99 | 'input_image':torch.stack(batch[1]),
100 | 'label':{'morph_offset':torch.stack(batch[2]),
101 | 'coord': torch.stack(batch[4]),
102 | 'morph_loss_mask':torch.stack(batch[7])
103 | },
104 | 'pspace':torch.stack(batch[2]),
105 | 'hint':{'index': list(batch[3])},
106 | 'additional': torch.stack(batch[5]),
107 | 'index':list(batch[6])
108 | }
109 | return Munch.fromDict(batch_dict)
--------------------------------------------------------------------------------
/data_preprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from PIL import Image
5 | import scipy.io
6 | from tqdm.auto import tqdm
7 | import copy
8 |
9 | from albumentations import ( Resize, Compose, KeypointParams)
10 |
11 | source_path = './data/dataset16/boostnet_labeldata/'
12 | target_path = './data/'
13 |
14 | base_image_path= os.path.join(source_path,'data')
15 | base_label_path= os.path.join(source_path,'labels')
16 |
17 |
18 |
19 |
20 | train_image_paths = []
21 | test_image_paths = []
22 |
23 | for name in os.listdir(os.path.join(base_image_path,'training')):
24 | p = os.path.join(base_image_path, 'training/', name)
25 | train_image_paths.append(p)
26 |
27 | for name in os.listdir(os.path.join(base_image_path , 'test')):
28 | p = os.path.join(base_image_path, 'test/', name)
29 | test_image_paths.append(p)
30 |
31 |
32 | train_label_paths = [(i+'.mat').replace('boostnet_labeldata/data/','boostnet_labeldata/labels/') for i in train_image_paths]
33 | test_label_paths = [(i+'.mat').replace('boostnet_labeldata/data/','boostnet_labeldata/labels/') for i in test_image_paths]
34 |
35 |
36 | train_image_paths.sort()
37 | test_image_paths.sort()
38 | train_label_paths.sort()
39 | test_label_paths.sort()
40 |
41 |
42 | # select random validation dataset (val size = 128)
43 | val_idx = sorted(np.random.choice(range(len(train_image_paths)), size=128, replace=False, p=None))
44 | train_idx = [i for i in range(len(train_image_paths)) if i not in val_idx]
45 |
46 | val_image_paths = np.array(train_image_paths)[val_idx].tolist()
47 | train_image_paths = np.array(train_image_paths)[train_idx].tolist()
48 |
49 | val_label_paths = np.array(train_label_paths)[val_idx].tolist()
50 | train_label_paths = np.array(train_label_paths)[train_idx].tolist()
51 |
52 |
53 | # make json items
54 | def make_data(image_paths, label_paths):
55 | data = []
56 |
57 |
58 |
59 | for idx in range(len(image_paths)):
60 |
61 | item = {'image':None, 'label':None, 'raw_size_row_col':None, 'pixelSpacing':[1,1]}
62 |
63 |
64 | # indexing
65 | image_path = image_paths[idx]
66 | label_path = label_paths[idx]
67 |
68 | # load data
69 | img = np.repeat(np.array(Image.open(os.path.join(image_path)))[:,:,None], 3, axis=-1) # (row, col) -> (row,col,3)
70 | label = scipy.io.loadmat(os.path.join(label_path))['p2'] # x,y
71 |
72 | # make items
73 | item['image'] = image_path.replace(source_path, '') # remove base path
74 | item['label'] = label.tolist()
75 | item['raw_size_row_col'] = (img.shape[0], img.shape[1])
76 |
77 | data.append(item)
78 | return data
79 |
80 | train_data = make_data(train_image_paths, train_label_paths)
81 | val_data = make_data(val_image_paths, val_label_paths)
82 | test_data = make_data(test_image_paths, test_label_paths)
83 |
84 |
85 | def inference_aug(img_size):
86 | return Compose([
87 | Resize(img_size[0], img_size[1]),
88 | ], keypoint_params=KeypointParams(format='xy'))
89 |
90 |
91 | # check keypoints are inside of the corresponding image
92 | print('train')
93 | remove_list = []
94 | for t, item in (enumerate(train_data)):
95 | row, col = item['raw_size_row_col']
96 | coord = np.array(item['label'])
97 | if coord[:, 1].max() > row:
98 | print('error:', t, '- exceed max y')
99 | remove_list.append(t)
100 | elif coord[:, 0].max() > col:
101 | print('error:', t, '- exceed max x')
102 | remove_list.append(t)
103 | for t in sorted(set(remove_list), reverse=True):  # delete back-to-front so indices stay valid
104 | del train_data[t]
105 |
106 | print('val')
107 | remove_list=[]
108 | for t, item in (enumerate(val_data)):
109 | row, col = item['raw_size_row_col']
110 | coord = np.array(item['label'])
111 | if coord[:, 1].max() > row:
112 | print('error:', t, '- exceed max y')
113 | remove_list.append(t)
114 | if coord[:, 0].max() > col:
115 | print('error:', t, '- exceed max x')
116 | remove_list.append(t)
117 | for t in sorted(set(remove_list), reverse=True):  # delete back-to-front so indices stay valid
118 | del val_data[t]
119 |
120 |
121 | map_dic = {}
122 | for sizes in [(512, 256)]:
123 | aug = inference_aug((sizes[0], sizes[1]))
124 | print(sizes, 'starts')
125 | size = sizes[0]
126 |
127 | if not os.path.exists('{}/dataset16_{}'.format(target_path,size)):
128 | os.makedirs('{}/dataset16_{}'.format(target_path,size))
129 |
130 | for temp in ['train', 'val', 'test']:
131 | print(temp, 'starts')
132 |
133 | if temp == 'train':
134 | table = copy.deepcopy(train_data)
135 | elif temp == 'val':
136 | table = copy.deepcopy(val_data)
137 | else:
138 | table = copy.deepcopy(test_data)
139 |
140 | for t, item in tqdm(enumerate(table)):
141 | # image load
142 | img_path = os.path.join(
143 | source_path,
144 | item['image'])
145 | img = np.repeat(np.array(Image.open(img_path))[:, :, None], 3, axis=-1) # (row, col) -> (row,col,3)
146 |
147 | points = item['label']
148 |
149 | if img_path not in map_dic:
150 | map_dic[img_path] = {'points': points}
151 |
152 | transformed = aug(image=img, keypoints=points)
153 | img, points = transformed["image"], transformed["keypoints"]
154 |
155 | img_save_path = os.path.join(target_path,'dataset16_{}'.format(size),
156 | item['image'].replace('.jpg', '.npy'))
157 | if not os.path.exists( os.path.dirname(img_save_path)):
158 | os.makedirs(os.path.dirname(img_save_path))
159 | np.save(img_save_path, img)
160 | item['points_{}'.format(size)] = points
161 | item['image'] = item['image'].replace('.jpg', '.npy')
162 |
163 | map_dic[img_path].update({'points_{}'.format(size): points})
164 |
165 | with open(os.path.join(target_path,'dataset16_{}'.format(size),'{}.json'.format(temp)), 'w') as f:
166 | json.dump(table, f)
167 |
--------------------------------------------------------------------------------
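After the script finishes, the output layout can be sanity-checked like this (a sketch using the default paths above):

```
import json, os
import numpy as np

with open('./data/dataset16_512/train.json') as f:
    table = json.load(f)

item = table[0]
img = np.load(os.path.join('./data/dataset16_512', item['image']))
print(img.shape)                # (512, 256, 3)
print(len(item['points_512']))  # 68 keypoints per image
print(item['raw_size_row_col'], item['pixelSpacing'])
```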
/misc/metric.py:
--------------------------------------------------------------------------------
1 | from munch import Munch
2 | from misc.heatmap_maker import HeatmapMaker, heatmap2hargmax_coord
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 | class MetricManager():
8 | def __init__(self, save_manager):
9 | self.config = save_manager.config
10 |
11 | if save_manager.config.Train.decision_metric.split('_')[-1] in ['MAE', 'MSE', 'MRE']:
12 | self.minimize_metric = True
13 | elif save_manager.config.Train.decision_metric.split('_')[-1] in ['AUC', 'ACC']:
14 | self.minimize_metric = False
15 | else:
16 | save_manager.write_log('ERROR::: Specify the decision metric in metric.py')
17 | raise ValueError('unknown decision metric')
18 | self.heatmap_maker = HeatmapMaker(save_manager.config)
19 | self.init_running_metric()
20 | self.device = self.config.MISC.device
21 |
22 |
23 | def init_running_metric(self):
24 | self.running_metric=Munch.fromDict({})
25 | for name in ['hargmax_pixel','hargmax_mm','sargmax_pixel','sargmax_mm']:
26 | for metric in self.config.Train.metric:
27 | metric_name = '{}_{}'.format(name,metric)
28 | self.running_metric[metric_name] = []
29 |
30 | if metric == 'MRE':
31 | for standard in self.config.Train.SR_standard:
32 | self.running_metric['{}_SR[{}]'.format(name, standard)] = []
33 |
34 | def average_running_metric(self):
35 |
36 | self.metric = Munch.fromDict({})
37 | for metric in self.running_metric:
38 | try:
39 | self.metric[metric] = np.mean(self.running_metric[metric])
40 | except Exception:  # per-sample tensors of varying batch sizes: average per-batch means
41 | self.metric[metric] = np.mean([m.mean() for m in self.running_metric[metric]])
42 | self.init_running_metric()
43 | return self.metric
44 |
45 | def init_best_metric(self):
46 | if self.minimize_metric:
47 | best_metric = {self.config.Train.decision_metric: 1e4}
48 | else:
49 | best_metric = {self.config.Train.decision_metric: -1e4}
50 | return best_metric
51 |
52 | def is_new_best(self, old, new):
53 | if old[self.config.Train.decision_metric] > new[self.config.Train.decision_metric]:
54 | if self.minimize_metric:
55 | return True
56 | else:
57 | return False
58 | else:
59 | if self.minimize_metric:
60 | return False
61 | else:
62 | return True
63 |
64 | def measure_metric(self, pred, label, pspace, metric_flag, average_flag):
65 | with torch.no_grad():
66 | # to cuda
67 | pspace = pspace.detach().to(self.device)
68 | label.coord = label.coord.detach().to(self.device)
69 | label.heatmap = label.heatmap.detach().to(self.device)
70 | pred.sargmax_coord = pred.sargmax_coord.detach().to(self.device)
71 | pred.heatmap = pred.heatmap.detach().to(self.device)
72 |
73 | if self.config.Model.facto_heatmap:
74 | pred.hargmax_coord = pred.hard_coord.float().to(self.device)
75 | else:
76 | # pred coord (hard-argmax)
77 | pred.hargmax_coord = heatmap2hargmax_coord(pred.heatmap)
78 | pred.hargmax_coord_mm = self.pixel2mm(self.config, pred.hargmax_coord, pspace)
79 |
80 | # pred coord (soft-argmax, model output)
81 | pred.sargmax_coord_mm = self.pixel2mm(self.config, pred.sargmax_coord, pspace)
82 |
83 | label.coord_mm = self.pixel2mm(self.config, label.coord, pspace)
84 |
85 | coord_pairs = [(pred.hargmax_coord, label.coord, 'hargmax_pixel'),
86 | (pred.hargmax_coord_mm, label.coord_mm, 'hargmax_mm'),
87 | (pred.sargmax_coord, label.coord, 'sargmax_pixel'),
88 | (pred.sargmax_coord_mm, label.coord_mm, 'sargmax_mm')]
89 |
90 | if metric_flag:
91 | metric_list = self.config.Train.metric
92 | else:
93 | metric_list = [self.config.Train.decision_metric.split('_')[-1]]
94 |
95 | for item in coord_pairs:
96 | pred_coord, label_coord, name = item
97 |
98 | if 'MAE' in metric_list:
99 | if average_flag:
100 | self.running_metric['{}_MAE'.format(name)].append(nn.L1Loss()(pred_coord, label_coord).item())
101 | else:
102 | self.running_metric['{}_MAE'.format(name)].append(
103 | (nn.L1Loss(reduction='none')(pred_coord, label_coord)).cpu())
104 |
105 | if 'RMSE' in metric_list:
106 | self.running_metric['{}_RMSE'.format(name)].append(torch.sqrt(nn.MSELoss()(pred_coord, label_coord)).item())
107 |
108 | if 'MRE' in metric_list:
109 | # (batch, 13, 2)
110 | y_diff_sq = (pred_coord[:, :, 0] - label_coord[:, :, 0]) ** 2
111 | x_diff_sq = (pred_coord[:, :, 1] - label_coord[:, :, 1]) ** 2
112 | sqrt_x2y2 = torch.sqrt(y_diff_sq + x_diff_sq) # (batch,13)
113 | if average_flag:
114 | mre = torch.mean(sqrt_x2y2).item()
115 | else:
116 | mre = sqrt_x2y2.cpu()
117 | self.running_metric['{}_MRE'.format(name)].append(mre)
118 |
119 | for standard in self.config.Train.SR_standard:
120 | self.running_metric['{}_SR[{}]'.format(name, standard)].append((sqrt_x2y2 < standard).float().mean().item())
121 |
122 |
123 | def pixel2mm(self, config, points, pspace):
124 | mm_points = torch.zeros_like(points)
125 | pspace = pspace.to(points.device)
126 | mm_points[:, :, 1] = points[:, :, 1] / config.Dataset.image_size[1] * pspace[:, 1:2] * pspace[:, 3:4]
127 | mm_points[:, :, 0] = points[:, :, 0] / config.Dataset.image_size[0] * pspace[:, 0:1] * pspace[:, 2:3]
128 | return mm_points
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
--------------------------------------------------------------------------------
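A worked sketch of the `pixel2mm` mapping above, with made-up raw sizes and spacings: a coordinate in the resized image is rescaled back to raw pixels, then multiplied by the pixel spacing.

```
import torch

# pspace layout per sample: [raw_rows, raw_cols, spacing_row, spacing_col]
pspace = torch.tensor([[1024.0, 512.0, 0.5, 0.5]])
points = torch.tensor([[[256.0, 128.0]]])  # (row, col) in the 512x256 resized image
image_size = [512, 256]
mm_row = points[:, :, 0] / image_size[0] * pspace[:, 0:1] * pspace[:, 2:3]
mm_col = points[:, :, 1] / image_size[1] * pspace[:, 1:2] * pspace[:, 3:4]
print(mm_row.item(), mm_col.item())  # 256.0 128.0 (millimeters)
```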
/model/iterativeRefinementModels/RITM_modules/RITM_ocr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch._utils
4 | import torch.nn.functional as F
10 |
11 |
12 | class SpatialGather_Module(nn.Module):
13 | """
14 | Aggregate the context features according to the initial
15 | predicted probability distribution.
16 | Employ the soft-weighted method to aggregate the context.
17 | """
18 |
19 | def __init__(self, cls_num=0, scale=1):
20 | super(SpatialGather_Module, self).__init__()
21 | self.cls_num = cls_num
22 | self.scale = scale
23 |
24 | def forward(self, feats, probs):
25 | # feats: batch, c, H, W
26 | # probs: batch x c x num_keypoint x 1,
27 | batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
28 | probs = probs.view(batch_size, c, -1) # (b, num_keypoint, -1)
29 | feats = feats.view(batch_size, feats.size(1), -1)
30 | feats = feats.permute(0, 2, 1) # batch x hw x c
31 | probs = F.softmax(self.scale * probs, dim=2) # batch x num_keypoint x hw
32 | ocr_context = torch.matmul(probs, feats) \
33 | .permute(0, 2, 1).unsqueeze(3) # batch x c x num_keypoint x 1
34 | return ocr_context
35 |
36 |
37 | class SpatialOCR_Module(nn.Module):
38 | """
39 | Implementation of the OCR module:
40 | We aggregate the global object representation to update the representation for each pixel.
41 | """
42 |
43 | def __init__(self,
44 | in_channels,
45 | key_channels,
46 | out_channels,
47 | scale=1,
48 | dropout=0.1,
49 | norm_layer=nn.BatchNorm2d,
50 | align_corners=True):
51 | super(SpatialOCR_Module, self).__init__()
52 | self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
53 | norm_layer, align_corners)
54 | _in_channels = 2 * in_channels
55 |
56 | self.conv_bn_dropout = nn.Sequential(
57 | nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
58 | nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
59 | nn.Dropout2d(dropout)
60 | )
61 |
62 | def forward(self, feats, proxy_feats):
63 | # proxy_feats=context : batch x c x num_keypoint x 1
64 | # feats: (b, feature_dim, H/4,W/4)
65 | context = self.object_context_block(feats, proxy_feats)
66 |
67 | output = self.conv_bn_dropout(torch.cat([context, feats], 1))
68 |
69 | return output
70 |
71 |
72 | class ObjectAttentionBlock2D(nn.Module):
73 | '''
74 | The basic implementation for object context block
75 | Input:
76 | N X C X H X W
77 | Parameters:
78 | in_channels : the dimension of the input feature map
79 | key_channels : the dimension after the key/query transform
80 | scale : choose the scale to downsample the input feature maps (save memory cost)
81 | bn_type : specify the bn type
82 | Return:
83 | N X C X H X W
84 | '''
85 |
86 | def __init__(self,
87 | in_channels,
88 | key_channels,
89 | scale=1,
90 | norm_layer=nn.BatchNorm2d,
91 | align_corners=True):
92 | super(ObjectAttentionBlock2D, self).__init__()
93 | self.scale = scale
94 | self.in_channels = in_channels
95 | self.key_channels = key_channels
96 | self.align_corners = align_corners
97 |
98 | self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
99 | self.f_pixel = nn.Sequential(
100 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
101 | kernel_size=1, stride=1, padding=0, bias=False),
102 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
103 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
104 | kernel_size=1, stride=1, padding=0, bias=False),
105 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
106 | )
107 | self.f_object = nn.Sequential(
108 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
109 | kernel_size=1, stride=1, padding=0, bias=False),
110 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
111 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
112 | kernel_size=1, stride=1, padding=0, bias=False),
113 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
114 | )
115 | self.f_down = nn.Sequential(
116 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
117 | kernel_size=1, stride=1, padding=0, bias=False),
118 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
119 | )
120 | self.f_up = nn.Sequential(
121 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
122 | kernel_size=1, stride=1, padding=0, bias=False),
123 | nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
124 | )
125 |
126 | def forward(self, x, proxy):
127 | # proxy: batch, c, num_keypoint, 1
128 | batch_size, h, w = x.size(0), x.size(2), x.size(3)
129 | if self.scale > 1:
130 | x = self.pool(x)
131 |
132 | query = self.f_pixel(x).view(batch_size, self.key_channels, -1) # (b,c,hw)
133 | query = query.permute(0, 2, 1) # (b, hw, c)
134 |
135 |
136 | key = self.f_object(proxy).view(batch_size, self.key_channels, -1) # (b, c, num_key)
137 | value = self.f_down(proxy).view(batch_size, self.key_channels, -1) # (b, c, num_key)
138 | value = value.permute(0, 2, 1) # (b, num_key, c)
139 |
140 | sim_map = torch.matmul(query, key) # b, hw, num_key
141 | sim_map = (self.key_channels ** -.5) * sim_map  # scale by 1/sqrt(key_channels)
142 | sim_map = F.softmax(sim_map, dim=-1)  # softmax over num_key
143 |
144 | context = torch.matmul(sim_map, value) # b, hw, c)
145 | context = context.permute(0, 2, 1).contiguous() # b,c,hw
146 | context = context.view(batch_size, self.key_channels, *x.size()[2:]) # b, c, h, w
147 | context = self.f_up(context) # b, c, h, w
148 | if self.scale > 1:
149 | context = F.interpolate(input=context, size=(h, w),
150 | mode='bilinear', align_corners=self.align_corners)
151 |
152 | return context
--------------------------------------------------------------------------------
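A shape-check sketch for the two modules above (random tensors; the channel sizes are hypothetical, not the ones HRNet uses):

```
import torch
from model.iterativeRefinementModels.RITM_modules.RITM_ocr import (
    SpatialGather_Module, SpatialOCR_Module)

feats = torch.randn(2, 64, 32, 16)  # (batch, c, H, W) pixel features
probs = torch.randn(2, 68, 32, 16)  # (batch, num_keypoint, H, W) heatmap logits

context = SpatialGather_Module()(feats, probs)
print(context.shape)  # torch.Size([2, 64, 68, 1]) -> one feature vector per keypoint

ocr = SpatialOCR_Module(in_channels=64, key_channels=32, out_channels=64)
print(ocr(feats, context).shape)  # torch.Size([2, 64, 32, 16])
```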
/README.md:
--------------------------------------------------------------------------------
1 | # Morphology-Aware Interactive Keypoint Estimation (MICCAI 2022) - Official PyTorch Implementation
2 |
3 | [__[Paper]__](https://arxiv.org/abs/2209.07163)
4 |
5 | [__[Project page]__](https://seharanul17.github.io/interactive_keypoint_estimation/)
6 |
7 | [__[Video]__](https://youtu.be/Z5gtLviQ_TU)
8 |
9 | ## Introduction
10 | This is the official PyTorch implementation of [Morphology-Aware Interactive Keypoint Estimation](https://arxiv.org/abs/2209.07163).
11 |
12 | Diagnosis based on medical images, such as X-ray images, often involves manual annotation of anatomical keypoints. However, this process requires significant human effort and can thus be a bottleneck in the diagnostic process. To fully automate this procedure, deep-learning-based methods have been widely proposed and have achieved high performance in detecting keypoints in medical images. However, these methods still have clinical limitations: accuracy cannot be guaranteed for all cases, and it is necessary for doctors to double-check all predictions of models. In response, we propose a novel deep neural network that, given an X-ray image, automatically detects and refines the anatomical keypoints through a user-interactive system in which doctors can fix mispredicted keypoints with fewer clicks than needed during manual revision. Using our own collected data and the publicly available AASCE dataset, we demonstrate the effectiveness of the proposed method in reducing the annotation costs via extensive quantitative and qualitative results.
13 |
14 | ## Environment
15 | The code was developed using Python 3.8 on Ubuntu 18.04.
16 |
17 | The experiments were performed on a single GeForce RTX 3090 in the training and evaluation phases.
18 |
19 | ## Quick start
20 |
21 | ### Prerequisites
22 | Install the following dependencies (an example pip command is given below the list):
23 | - Python 3.8
24 | - torch == 1.8.0
25 | - albumentations == 1.1.0
26 | - munch
27 | - tensorboard
28 | - pytz
29 | - tqdm
30 |
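For example, with pip (exact torch wheels may depend on your CUDA setup):
```
pip install torch==1.8.0 albumentations==1.1.0 munch tensorboard pytz tqdm
```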
31 |
32 | ### Preparing code and model files
33 |
34 |
35 | #### Directory layout
36 | .
37 | ├── code
38 | │ ├── data
39 | │ ├── pretrained_models
40 | │ ├── data_preprocessing.py
41 | │ ├── train.sh
42 | │ ├── test.sh
43 | │ └── ...
44 | ├── save
45 | └── ...
46 |
47 | 1. Clone this repository in the ``code`` folder:
48 | ```
49 | git clone https://github.com/seharanul17/interactive_keypoint_estimation code
50 | ```
51 |
52 | 2. Create ``code/pretrained_models`` and ``save`` folders.
53 | ```
54 | mkdir code/pretrained_models
55 | mkdir save
56 | ```
57 |
58 | 3. To train our model using the pretrained HRNet backbone model, download the model file from the [HRNet Github repository](https://github.com/HRNet/HRNet-Image-Classification).
Place the downloaded file in the ``pretrained_models`` folder. The related code line can be found [here](https://github.com/seharanul17/interactive_keypoint_estimation/blob/7f50ec271b9ae9613c839533d3958110405d04f5/model/iterativeRefinementModels/RITM_SE_HRNet32.py#L29).
60 |
61 |
62 | 4. To test our pre-trained model, download our model file and config file from [here](https://www.dropbox.com/sh/m53iqw9loddqhfq/AAD0KuCCxpXsBE435Hw3KJU8a?dl=0).
Place the downloaded folder containing the files into the ``save`` folder.
The related code line can be found [here](https://github.com/seharanul17/interactive_keypoint_estimation/blob/7f50ec271b9ae9613c839533d3958110405d04f5/util.py#L77).
65 |
66 |
67 |
68 | ### Preparing data
69 |
70 | We provide the code to conduct experiments on a public dataset, the AASCE challenge dataset.
71 |
72 | 1. Create the ``data`` folder inside the ``code`` folder.
73 | ```
74 | cd code
75 | mkdir data
76 | ```
77 |
78 | 2. Download the data and place it inside the ``data`` folder.
79 | - The AASCE challenge dataset can be obtained from [SpineWeb](http://spineweb.digitalimaginggroup.ca/index.php?n=main.datasets#Dataset_16.3A_609_spinal_anterior-posterior_x-ray_images).
80 |   - The AASCE challenge dataset corresponds to "Dataset 16: 609 spinal anterior-posterior x-ray images" on the webpage.
81 |
82 | 3. Preprocess the downloaded data.
83 |   - The related code line is [here](https://github.com/seharanul17/interactive_keypoint_estimation/blob/b85c22e26dd315289219cbe1baecdc815ba1d097/data_preprocessing.py#L11).
84 | - Run the following command:
85 | ```
86 | python data_preprocessing.py
87 | ```
88 |
89 |
90 | ### Usage
91 | - To run the training code, run the following command:
92 | ```
93 | bash train.sh
94 | ```
95 | - To test the pre-trained model:
96 | 1. Locate the pre-trained model in the ``../save/`` folder.
97 | 2. Run the test code:
98 | ```
99 | bash test.sh
100 | ```
101 | - To test your own model:
102 | 1. Change the value of the argument ``--only_test_version {your_model_name}`` in the ``test.sh`` file.
103 | 2. Run the test code:
104 | ```
105 | bash test.sh
106 | ```
107 |
108 | When the evaluation ends, the mean radial error (MRE) of model prediction and manual revision will be reported.
109 | The ``sargmax_mm_MRE`` corresponds to the MRE reported in Fig. 4.
110 |
111 |
112 | ## Results
113 | The following table compares the refinement performance of our proposed interactive model with that of manual revision.
114 | Both approaches revise the same initial predictions of our model. The number of user modifications ranges from zero (initial prediction) to five.
115 | The model performance is measured using mean radial errors on the AASCE dataset.
116 | For more information, please see Fig. 4 in our main manuscript.
117 |
118 | - "Ours (model revision)" indicates automatically revised results by the proposed interactive keypoint estimation approach.
119 | - "Ours (manual revision)" indicates fully-manually revised results by a user without the assistance of an interactive model.
120 |
121 | | Method | No. of user modifications | | | | | |
122 | |:----------------: |:-----------------------: |:-------: |:-------: |:-------: |:-------: |:-------: |
123 | | | 0 (initial prediction) | 1 | 2 | 3 | 4 | 5 |
124 | | Ours (model revision) | 58.58 | 35.39 | 29.35 | 24.02 | 21.06 | 17.67 |
125 | | Ours (manual revision) | 58.58 | 55.85 | 53.33 | 50.90 | 48.55 | 47.03 |
126 |
127 |
128 | ## Citation
129 |
130 | If you find this work or code is helpful in your research, please cite:
131 | ```
132 | @inproceedings{kim2022morphology,
133 | title={Morphology-Aware Interactive Keypoint Estimation},
134 | author={Kim, Jinhee and
135 | Kim, Taesung and
136 | Kim, Taewoo and
137 | Choo, Jaegul and
138 | Kim, Dong-Wook and
139 | Ahn, Byungduk and
140 | Song, In-Seok and
141 | Kim, Yoon-Ji},
142 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
143 | pages={675--685},
144 | year={2022},
145 | organization={Springer}
146 | }
147 | ```
148 |
149 |
--------------------------------------------------------------------------------
/model/iterativeRefinementModels/RITM_SE_HRNet32.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | import numpy as np
6 |
7 | from misc.heatmap_maker import HeatmapMaker, heatmap2hargmax_coord
8 | from misc.loss import LossManager
9 |
10 | from model.iterativeRefinementModels.RITM_modules.RITM_hrnet import HighResolutionNet
11 |
12 | class RITM(nn.Module):
13 | def __init__(self, config):
14 | super(RITM, self).__init__()
15 | # default
16 | self.config = config
17 | self.device = config.MISC.device
18 |
19 | self.heatmap_maker = HeatmapMaker(config)
20 | self.LossManager = LossManager(config, self.heatmap_maker)
21 |
22 |
23 | # RITM hyper-param
24 | self.max_iter = 3
25 |
26 | # backbone
27 | self.ritm_hrnet = HighResolutionNet(width=32, ocr_width=128, small=False, num_classes=config.Dataset.num_keypoint,
28 | norm_layer=nn.BatchNorm2d, addHintEncSENet=True, SE_maxpool=config.Model.SE_maxpool, SE_softmax=config.Model.SE_softmax)
29 | self.ritm_hrnet.load_pretrained_weights('./pretrained_models/hrnetv2_w32_imagenet_pretrained.pth')
30 |
31 |
32 | if self.config.Model.no_iterative_training:
33 | hintmap_input_channels = config.Dataset.num_keypoint
34 | else:
35 | hintmap_input_channels = config.Dataset.num_keypoint*2
36 |
37 | self.hintmap_encoder = nn.Sequential(
38 | nn.Conv2d(in_channels=hintmap_input_channels, out_channels=16, kernel_size=1),
39 | nn.LeakyReLU(negative_slope=0.2),
40 | nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1),
41 | ScaleLayer(init_value=0.05, lr_mult=1)
42 | )
43 |
44 | def forward_model(self, input_image, input_hint_heatmap, prev_pred_heatmap):
45 | if prev_pred_heatmap is None:
46 | hint_features = input_hint_heatmap
47 | else:
48 | hint_features = torch.cat((prev_pred_heatmap, input_hint_heatmap), dim=1)
49 |
50 | encoded_input_hint_heatmap = self.hintmap_encoder(hint_features)
51 | pred_logit, aux_pred_logit = self.ritm_hrnet(input_image, encoded_input_hint_heatmap, input_hint_heatmap=input_hint_heatmap)
52 |
53 | pred_logit = F.interpolate(pred_logit, size=self.config.Dataset.image_size, mode='bilinear', align_corners=True)
54 | aux_pred_logit = F.interpolate(aux_pred_logit, size=self.config.Dataset.image_size, mode='bilinear', align_corners=True)
55 |
56 | return pred_logit, aux_pred_logit
57 |
58 | def forward(self, batch):
59 | #make label, hint heatmap
60 | with torch.no_grad():
61 | batch.label.coord = batch.label.coord.to(self.device)
62 | batch.label.heatmap = self.heatmap_maker.coord2heatmap(batch.label.coord)
63 | batch.hint.heatmap = torch.zeros_like(batch.label.heatmap)
64 | for i in range(batch.label.heatmap.shape[0]):
65 | if batch.hint.index[i] is not None:
66 | batch.hint.heatmap[i, batch.hint.index[i]] = batch.label.heatmap[i, batch.hint.index[i]]
67 |
68 | input_image = batch.input_image.to(self.device)
69 | input_hint_heatmap = batch.hint.heatmap.to(self.device)
70 |
71 | if batch.is_training:
72 | # 1. random number of iteration without update
73 | if self.config.Model.no_iterative_training:
74 | pred_heatmap = None
75 | else:
76 | with torch.no_grad():
77 | self.eval()
78 | num_iters = np.random.randint(0, self.max_iter)
79 | pred_heatmap = torch.zeros_like(input_hint_heatmap)
80 | for click_indx in range(num_iters):
81 | # prediction
82 | pred_logit, aux_pred_logit = self.forward_model(input_image, input_hint_heatmap, pred_heatmap)
83 | pred_heatmap = pred_logit.sigmoid()
84 |
85 | # hint update (during training, hints are added one by one so the model learns to refine its predictions)
86 | batch = self.get_next_points(batch, pred_heatmap)
87 | for i in range(batch.label.heatmap.shape[0]):
88 | if batch.hint.index[i] is not None:
89 | batch.hint.heatmap[i, batch.hint.index[i]] = batch.label.heatmap[i, batch.hint.index[i]]
90 |
91 | self.train()
92 | # 2. forward for model update
93 | pred_logit, aux_pred_logit = self.forward_model(input_image, input_hint_heatmap, pred_heatmap)
94 | out = self.LossManager(pred_heatmap=pred_logit, label=batch.label)
95 | out.loss += self.LossManager.get_heatmap_loss(pred_heatmap=aux_pred_logit, label_heatmap=batch.label.heatmap)[0]
96 | else: # test
97 | # forward
98 | if self.config.Model.no_iterative_training:
99 | pred_heatmap = None
100 | else:
101 | if batch.get('prev_heatmap', None) is None:
102 | pred_heatmap = torch.zeros_like(input_hint_heatmap)
103 | else:
104 | pred_heatmap = batch.prev_heatmap.to(input_hint_heatmap.device)
105 | pred_logit, aux_pred_logit = self.forward_model(input_image, input_hint_heatmap, pred_heatmap)
106 | out = self.LossManager(pred_heatmap=pred_logit, label=batch.label)
107 | return out, batch
108 |
109 | def get_next_points(self, batch, pred_heatmap):
110 | worst_index = self.find_worst_pred_index(batch, pred_heatmap)
111 | for i, idx in enumerate(batch.hint.index):
112 | if idx is None:
113 | batch.hint.index[i] = worst_index[i] # (batch, 1)
114 | else:
115 | if not torch.is_tensor(idx):
116 | batch.hint.index[i] = torch.tensor(batch.hint.index[i], dtype=torch.long, device=worst_index.device)
117 | batch.hint.index[i] = torch.cat((batch.hint.index[i], worst_index[i])) # ... (batch, max hint)
118 | return batch
119 |
120 | def find_worst_pred_index(self, batch, pred_heatmap):
121 | # ==== calculate pixel MRE ====
122 | with torch.no_grad():
123 | hargmax_coord_pred = heatmap2hargmax_coord(pred_heatmap)
124 | batch_metric_value = torch.sqrt(torch.sum((hargmax_coord_pred-batch.label.coord)**2,dim=-1)) #MRE (batch, 13)
125 | for j, idx in enumerate(batch.hint.index):
126 | if idx is not None:
127 | batch_metric_value[j, idx] = torch.full_like(batch_metric_value[j,idx], -1000)
128 | worst_index = batch_metric_value.argmax(-1, keepdim=True)
129 | return worst_index
130 |
131 |
132 | class ScaleLayer(nn.Module):
133 | def __init__(self, init_value=1.0, lr_mult=1):
134 | super().__init__()
135 | self.lr_mult = lr_mult
136 | self.scale = nn.Parameter(
137 | torch.full((1,), init_value / lr_mult, dtype=torch.float32)
138 | )
139 |
140 | def forward(self, x):
141 | scale = torch.abs(self.scale * self.lr_mult)
142 | return x * scale
143 |
--------------------------------------------------------------------------------
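A hedged sketch of how the iterative path is exercised at test time: the test branch of `forward` reads `batch.prev_heatmap`, so an interaction loop feeds each prediction back and adds one corrected keypoint per round. It assumes a single-GPU `model` and a `batch` prepared as in main.py; the real loop lives in `Trainer.test` (misc/train.py), which is truncated in this dump.

```
import torch

model.eval()
batch.is_training = False
with torch.no_grad():
    out, batch = model(batch)                  # initial prediction
    for click in range(5):                     # five simulated user corrections
        batch.prev_heatmap = out.pred.heatmap  # feed the prediction back
        # ...append a user-corrected keypoint index to batch.hint.index here...
        out, batch = model(batch)
```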
/misc/heatmap_maker.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class HeatmapMaker():
4 | def __init__(self, config):
5 | self.config = config
6 | self.image_size = config.Dataset.image_size
7 | self.heatmap_std = config.Dataset.heatmap_std
8 | self.morph_pairs = config.Morph.pairs
9 | # self.morph_unit_vector_size = config.Morph.unit_vector_size
10 |
11 | def make_gaussian_heatmap(self, mean, size, std):
12 | # coord : (13,2)
13 | mean = mean.unsqueeze(1).unsqueeze(1)
14 | var = std ** 2 # 64, 1
15 | grid = torch.stack(torch.meshgrid([torch.arange(size[0]), torch.arange(size[1])]), dim=-1).unsqueeze(0)
16 | grid = grid.to(mean.device)
17 | x_minus_mean = grid - mean # 13, 1024, 1024, 2
18 |
19 | # (x-u)^2: (13, 512, 512, 2) inverse_cov: (1, 1, 1, 1) > (13, 512, 512)
20 | gaus = (-0.5 * (x_minus_mean.pow(2) / var)).sum(-1).exp()
21 | if self.config.Dataset.get('heatmap_encoding_maxone', False):
22 | gaus /= gaus.max(1, keepdim=True)[0].max(2, keepdim=True)[0] #(13, 512, 512)
23 | # (13, 512, 512)
24 | return gaus
25 |
26 |
27 |
28 | # subpixel paper encoding
29 | def sample_gaussian_heatmap(self, mean):
30 | integer = torch.floor(mean)
31 | rest = mean - integer
32 | grid_gaussian = self.make_gaussian_heatmap(rest, (2, 2), 0.5)
33 | sampled_offset = torch.multinomial(grid_gaussian.reshape(mean.shape[0], 4), num_samples=1)
34 | row = torch.floor(torch.div(sampled_offset, 2))  # floor division, (13,1)
35 | col = torch.remainder(sampled_offset, 2) # (13,1)
36 |
37 | new_sampled_offset = torch.cat((row, col), dim=1)
38 | return new_sampled_offset + integer
39 |
40 | def coord2heatmap(self, coord):
41 | # coord : (batch, 13, 2), torch tensor, gpu
42 | with torch.no_grad():
43 | if self.config.Dataset.get('subpixel_heatmap_encoding', False):
44 | coord = torch.stack([self.sample_gaussian_heatmap(coord_item) for coord_item in coord])
45 | heatmap = torch.stack([
46 | self.make_gaussian_heatmap(coord_item, size=self.image_size, std=self.heatmap_std) for coord_item in coord])
47 | return heatmap
48 |
49 | def get_heatmap2sargmax_coord(self, pred_heatmap):
50 | if self.config.Dataset.subpixel_decoding:
51 | pred_coord = self.heatmap2subpixel_argmax_coord(pred_heatmap)
52 | else:
53 | pred_coord = self.heatmap2sargmax_coord(pred_heatmap)
54 | return pred_coord
55 |
56 | def heatmap2sargmax_coord(self, heatmap):
57 | # heatmap: batch, 13, 1024, 1024 (batch, c, row, column)
58 |
59 | if self.config.Dataset.label_smoothing:
60 | heatmap = torch.clamp((heatmap - 0.1) / 0.8, 0, 1)
61 |
62 | pred_col = torch.sum(heatmap, (-2)) # bach, c, column
63 | pred_row = torch.sum(heatmap, (-1)) # batch, c, row
64 |
65 | # 1, 1, 1024
66 | mesh_c = torch.arange(pred_col.shape[-1]).unsqueeze(0).unsqueeze(0).to(heatmap.device)
67 | mesh_r = torch.arange(pred_row.shape[-1]).unsqueeze(0).unsqueeze(0).to(heatmap.device)
68 |
69 | # batch, 13
70 | coord_c = torch.sum(pred_col * mesh_c, (-1)) / torch.sum(pred_col, (-1))
71 | coord_r = torch.sum(pred_row * mesh_r, (-1)) / torch.sum(pred_row, (-1))
72 |
73 | # batch, 13, 2 (row, column)
74 | coord = torch.stack((coord_r, coord_c), -1)
75 |
76 | return coord
77 |
78 | def heatmap2subpixel_argmax_coord(self, heatmap): # row: y, col: x tensor->(row, col)=(y,x)
79 | hard_coord = heatmap2hargmax_coord(heatmap).to(heatmap.device) #(batch, 13, 2)
80 | patch_size = self.config.Dataset.subpixel_decoding_patch_size # scalar
81 | reshaped_heatmap = heatmap.reshape(heatmap.shape[0] * heatmap.shape[1], 1, heatmap.shape[2], heatmap.shape[3])
82 | reshaped_hard_coord = hard_coord.reshape(hard_coord.shape[0] * hard_coord.shape[1], 1,
83 | hard_coord.shape[2]) # (bk, 1, 2)
84 |
85 | patch_index = torch.arange(patch_size, device=heatmap.device) - patch_size // 2 # (5)
86 | patch_index_y = patch_index[:, None].expand(patch_size, patch_size)[None, :] # (1, 5, 5)
87 | patch_index_x = patch_index[None, :].expand(patch_size, patch_size)[None, :] # (1, 5, 5)
88 | patch_index_y = patch_index_y + reshaped_hard_coord[:, :, 0:1] # (bk, 5, 5)
89 | patch_index_x = patch_index_x + reshaped_hard_coord[:, :, 1:2]
90 |
91 | # pad heatmap
92 | padded_reshaped_heatmap = torch.nn.functional.pad(reshaped_heatmap,
93 | (patch_size, patch_size, patch_size, patch_size),
94 | mode='constant', value=0.0) # pad (left, right, top, bottom)
95 | pad_patch_index_y = patch_size + patch_index_y
96 | pad_patch_index_x = patch_size + patch_index_x
97 |
98 | patch = padded_reshaped_heatmap[:, :, pad_patch_index_y.long(), pad_patch_index_x.long()]
99 |
100 | patch = patch.diagonal(dim1=0, dim2=2).permute(dims=[3, 0, 1, 2])
101 | patch = patch.reshape(heatmap.shape[0], heatmap.shape[1], patch_size, patch_size) #b, 13, 5, 5
102 | patch = (patch*10).reshape(patch.shape[0], patch.shape[1], -1).softmax(-1).reshape(patch.shape)
103 |
104 | soft_patch_offset = self.heatmap2sargmax_coord(patch) #batch, 13, 2
105 | final_coord = hard_coord + soft_patch_offset - patch_size//2
106 | return final_coord
107 |
108 | def get_morph(self, points):
109 | # normalize
110 | points = self.points_normalize(points, dim0=self.image_size[0], dim1=self.image_size[1])
111 |
112 | dist_pairs = torch.tensor(self.morph_pairs[0])
113 | vec_pairs = torch.tensor(self.morph_pairs[1])
114 |
115 | # dist
116 | # batch, 13, 2
117 | from_vec = points[:, dist_pairs[:, 0]]
118 | to_vec = points[:, dist_pairs[:, 1]]
119 | diffs = to_vec - from_vec
120 | pred_dist = torch.norm(diffs, dim = -1).unsqueeze(-1) # (batch, 16,1)
121 |
122 | # unit_vec_pairs
123 | if self.config.Morph.threePointAngle:
124 | # batch, 13, 2
125 | x = points[:, vec_pairs[:, 0]]
126 | y = points[:, vec_pairs[:, 1]]
127 | z = points[:, vec_pairs[:, 2]]
128 | pred_angle = self.get_angle(x,y,z)
129 | else:
130 | from_vec = points[:, vec_pairs[:, 0]]
131 | to_vec = points[:, vec_pairs[:, 1]]
132 | diffs = to_vec - from_vec
133 | pred_angle = diffs
134 |
135 | return pred_dist, pred_angle
136 |
137 | def get_angle(self, x, y, z):
138 | # y is the vertex of the angle: z-y-x
139 | N = x.shape[0] * x.shape[1]
140 | delta_vector_1 = torch.reshape(x - y, (N, 2)) # batch*16, 2
141 | delta_vector_2 = torch.reshape(z - y, (N, 2))
142 |
143 | # (y,x) = (row, col)
144 | angle_1 = torch.atan2(delta_vector_1[:, 0], delta_vector_1[:, 1] + 1e-8)
145 | angle_2 = torch.atan2(delta_vector_2[:, 0], delta_vector_2[:, 1] + 1e-8)
146 |
147 | delta_angle = angle_2 - angle_1
148 |
149 | angle_vector = torch.stack((torch.cos(delta_angle), torch.sin(delta_angle)), -1)
150 |
151 | angle_vector = torch.reshape(angle_vector, (x.shape[0], x.shape[1], 2))
152 | return angle_vector
153 |
154 | def points_normalize(self, points, dim0, dim1):
155 | new_coord = torch.zeros_like(points)
156 | new_coord[..., 0] = points[..., 0] / dim0
157 | new_coord[..., 1] = points[..., 1] / dim1
158 |
159 | return new_coord
160 |
161 | def heatmap2hargmax_coord(heatmap):
162 | b, c, row, column = heatmap.shape
163 | heatmap = heatmap.reshape(b, c, -1)
164 | max_indices = heatmap.argmax(-1)
165 | keypoint = torch.zeros(b, c, 2, device=heatmap.device)
166 | # keypoint[:, :, 0] = torch.floor(torch.div(max_indices, column)) # old environment
167 | keypoint[:, :, 0] = torch.div(max_indices, column, rounding_mode='floor')
168 | keypoint[:, :, 1] = max_indices % column
169 | return keypoint
170 |
171 |
--------------------------------------------------------------------------------
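Editor's note — heatmap2subpixel_argmax_coord above combines a hard argmax with a softmax-weighted expectation over a small patch around it. Below is a minimal self-contained sketch of the same idea for a single (H, W) heatmap; the function name subpixel_argmax, the toy Gaussian, and the default temperature are illustrative assumptions, not repository code:

import torch

def subpixel_argmax(heatmap, patch_size=5, temperature=10.0):
    """heatmap: (H, W) -> (row, col) with sub-pixel precision."""
    H, W = heatmap.shape
    flat_idx = heatmap.reshape(-1).argmax()
    hard_y = torch.div(flat_idx, W, rounding_mode='floor')  # integer row
    hard_x = flat_idx % W                                   # integer col

    # zero-pad so the patch never runs off the border (as in the padded version above)
    pad = patch_size
    padded = torch.nn.functional.pad(heatmap[None, None], (pad, pad, pad, pad))[0, 0]
    y0 = int(hard_y) + pad - patch_size // 2
    x0 = int(hard_x) + pad - patch_size // 2
    patch = padded[y0:y0 + patch_size, x0:x0 + patch_size]

    # temperature-sharpened softmax over the patch, then a soft-argmax expectation
    w = (patch * temperature).reshape(-1).softmax(-1).reshape(patch_size, patch_size)
    grid = torch.arange(patch_size, dtype=w.dtype)
    off_y = (w.sum(1) * grid).sum()   # expected row inside the patch
    off_x = (w.sum(0) * grid).sum()   # expected col inside the patch
    return torch.stack([hard_y + off_y, hard_x + off_x]) - patch_size // 2

# a Gaussian bump whose true peak (7.3, 5.8) lies between pixels:
yy, xx = torch.meshgrid(torch.arange(16.), torch.arange(16.), indexing='ij')
hm = torch.exp(-((yy - 7.3) ** 2 + (xx - 5.8) ** 2) / 4)
print(subpixel_argmax(hm))  # close to (7.3, 5.8), not the integer argmax (7, 6)
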
/misc/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from munch import Munch
3 | import torch
4 | import copy
5 | from misc.metric import MetricManager
6 | from tqdm.auto import tqdm
7 |
8 | class Trainer(object):
9 | def __init__(self, model, metric_manager):
10 | self.model = model
11 |
12 | self.metric_manager = metric_manager
13 | self.best_epoch = None
14 | self.best_param = None
15 | self.best_metric = self.metric_manager.init_best_metric()
16 |
17 | self.patience = 0
18 |
19 | def train(self, save_manager, train_loader, val_loader, optimizer, writer=None):
20 | for epoch in range(1, save_manager.config.Train.epoch+1):
21 | start_time = time.time()
22 |
23 | #train
24 | self.model.train()
25 | train_loss = 0
26 | for i, batch in enumerate(train_loader):
27 | batch.is_training = True
28 | out, batch = self.forward_batch(batch)
29 | optimizer.update_model(out.loss)
30 | train_loss += out.loss.item()
31 |
32 | if writer is not None:
33 | writer.write_loss(out.loss.item(), 'train')
34 | if i % 50 == 0:
35 | save_manager.write_log('iter [{}/{}] loss [{:.6f}]'.format(i, len(train_loader), out.loss.item()))
36 |
37 | train_loss = train_loss / len(train_loader)
38 | train_metric = self.metric_manager.average_running_metric()
39 |
40 | if writer is not None:
41 | writer.write_metric(train_metric, 'train')
42 |
43 | #val
44 | self.model.eval()
45 | with torch.no_grad():
46 | val_loss = 0
47 | for i, batch in enumerate(val_loader):
48 | batch.is_training = True
49 | out, batch = self.forward_batch(batch)
50 | val_loss += out.loss.item()
51 |
52 | if writer is not None:
53 | writer.write_loss(out.loss.item(), 'val')
54 | if epoch % 20 == 0 and i == 0:
55 | writer.plot_image_heatmap(image=batch.input_image.cpu(), pred_heatmap=out.pred.heatmap.cpu(), label_heatmap=batch.label.heatmap.cpu(), epoch=epoch)
56 | val_metric = self.metric_manager.average_running_metric()
57 | val_loss = val_loss / len(val_loader)
58 |
59 | optimizer.scheduler_step(val_metric[save_manager.config.Train.decision_metric])
60 |
61 |
62 | if writer is not None:
63 | writer.write_metric(val_metric, 'val')
64 | if epoch % 5 == 0:
65 | writer.plot_model_param_histogram(self.model, epoch)
66 |
67 | save_manager.write_log('Epoch [{}/{}] train loss [{:.6f}] train {} [{:.6f}] val loss [{:.6f}] val {} [{:.6f}] Epoch time [{:.1f}min]'.format(epoch,
68 | save_manager.config.Train.epoch,
69 | train_loss,
70 | save_manager.config.Train.decision_metric,
71 | train_metric[save_manager.config.Train.decision_metric],
72 | val_loss,
73 | save_manager.config.Train.decision_metric,
74 | val_metric[save_manager.config.Train.decision_metric],
75 | (time.time()-start_time)/60))
76 | print('version: {}'.format(save_manager.config.version))
77 | if self.metric_manager.is_new_best(old=self.best_metric, new=val_metric):
78 | self.patience = 0
79 | self.best_epoch = epoch
80 | self.best_metric = val_metric
81 | save_manager.save_model(self.best_epoch, self.model.state_dict(), self.best_metric)
82 | save_manager.write_log('Model saved after Epoch {}'.format(epoch), 4)
83 |                 else:
84 | self.patience += 1
85 | if self.patience > save_manager.config.Train.patience:
86 | save_manager.write_log('Training Early Stopped after Epoch {}'.format(epoch), 16)
87 | break
88 |
89 | def test(self, save_manager, test_loader, writer=None):
90 |         # To reduce GPU memory usage, load the params on CPU first, then move the model to GPU (params loaded directly on GPU would need extra GPU memory)
91 | self.model.cpu()
92 | self.model.load_state_dict(self.best_param)
93 | self.model.to(save_manager.config.MISC.device)
94 |
95 |
96 | save_manager.write_log('Best model at epoch {} loaded'.format(self.best_epoch))
97 | self.model.eval()
98 | with torch.no_grad():
99 | post_metric_managers = [MetricManager(save_manager) for _ in range(save_manager.config.Dataset.num_keypoint+1)]
100 | manual_metric_managers = [MetricManager(save_manager) for _ in range(save_manager.config.Dataset.num_keypoint+1)]
101 |
102 | for i, batch in enumerate(tqdm(test_loader)):
103 | batch = Munch.fromDict(batch)
104 | batch.is_training = False
105 |                 max_hint = 10 + 1  # evaluate with 0..10 hints
106 | for n_hint in range(max_hint):
107 |                     # n_hint runs from 0 to max_hint-1
108 |
109 | ## model forward
110 | out, batch, post_processing_pred = self.forward_batch(batch, metric_flag=True, average_flag=False, metric_manager=post_metric_managers[n_hint], return_post_processing_pred=True)
111 | if save_manager.config.Model.use_prev_heatmap:
112 | batch.prev_heatmap = out.pred.heatmap.detach()
113 |
114 | ## manual forward
115 | if n_hint == 0:
116 | manual_pred = copy.deepcopy(out.pred)
117 |                         # copy.deepcopy above already yields an independent prediction;
118 |                         # its sargmax_coord / heatmap are overwritten with label values as hints arrive
119 | manual_hint_index = []
120 | else:
121 | # manual prediction update
122 | for k in range(manual_hint_index.shape[0]):
123 | manual_pred.sargmax_coord[k, manual_hint_index[k]] = batch.label.coord[
124 | k, manual_hint_index[k]].to(manual_pred.sargmax_coord.device)
125 | manual_pred.heatmap[k, manual_hint_index[k]] = batch.label.heatmap[
126 | k, manual_hint_index[k]].to(manual_pred.heatmap.device)
127 | manual_metric_managers[n_hint].measure_metric(manual_pred, batch.label, batch.pspace, metric_flag=True, average_flag=False)
128 |
129 | # ============================= model hint =================================
130 | worst_index = self.find_worst_pred_index(batch.hint.index, post_metric_managers,
131 | save_manager, n_hint)
132 | # hint index update
133 | if n_hint == 0:
134 | batch.hint.index = worst_index # (batch, 1)
135 | else:
136 | batch.hint.index = torch.cat((batch.hint.index, worst_index.to(batch.hint.index.device)), dim=1) # ... (batch, max hint)
137 |
138 |                 # if configured, keep only the hinted channels of the previous heatmap (zeros elsewhere)
139 | if save_manager.config.Model.use_prev_heatmap_only_for_hint_index:
140 | new_prev_heatmap = torch.zeros_like(out.pred.heatmap)
141 | for j in range(len(batch.hint.index)):
142 | new_prev_heatmap[j, batch.hint.index[j]] = out.pred.heatmap[j, batch.hint.index[j]]
143 | batch.prev_heatmap = new_prev_heatmap
144 | # ================================= manual hint =========================
145 | worst_index = self.find_worst_pred_index(manual_hint_index, manual_metric_managers,
146 | save_manager, n_hint)
147 | # hint index update
148 | if n_hint == 0:
149 | manual_hint_index = worst_index # (batch, 1)
150 | else:
151 | manual_hint_index = torch.cat((manual_hint_index, worst_index.to(manual_hint_index.device)), dim=1) # ... (batch, max hint)
152 |
153 | # save result
154 | if save_manager.config.save_test_prediction:
155 | save_manager.add_test_prediction_for_save(batch, post_processing_pred, manual_pred, n_hint, post_metric_managers[n_hint], manual_metric_managers[n_hint])
156 |
157 | post_metrics = [metric_manager.average_running_metric() for metric_manager in post_metric_managers]
158 | manual_metrics = [metric_manager.average_running_metric() for metric_manager in manual_metric_managers]
159 |
160 | # save metrics
161 | for t in range(min(max_hint, len(post_metrics))):
162 | save_manager.write_log('(model ) Hint {} ::: {}'.format(t, post_metrics[t]))
163 | save_manager.write_log('(manual) Hint {} ::: {}'.format(t, manual_metrics[t]))
164 | save_manager.save_metric(post_metrics, manual_metrics)
165 |
166 | if save_manager.config.save_test_prediction:
167 | save_manager.save_test_prediction()
168 |
169 |
170 | def forward_batch(self, batch, metric_flag=False, average_flag=True, metric_manager=None, return_post_processing_pred=False):
171 | out, batch = self.model(batch)
172 | with torch.no_grad():
173 | if metric_manager is None:
174 | self.metric_manager.measure_metric(out.pred, batch.label, batch.pspace, metric_flag, average_flag)
175 | else:
176 | #post processing
177 | post_processing_pred = copy.deepcopy(out.pred)
178 | post_processing_pred.sargmax_coord = post_processing_pred.sargmax_coord.detach()
179 | post_processing_pred.heatmap = post_processing_pred.heatmap.detach()
180 |                 for i in range(len(batch.hint.index)): # loop over the batch; i indexes the i-th sample
181 | if batch.hint.index[i] is not None:
182 | post_processing_pred.sargmax_coord[i, batch.hint.index[i]] = batch.label.coord[i, batch.hint.index[i]].detach().to(post_processing_pred.sargmax_coord.device)
183 | post_processing_pred.heatmap[i, batch.hint.index[i]] = batch.label.heatmap[i, batch.hint.index[i]].detach().to(post_processing_pred.heatmap.device)
184 |
185 | metric_manager.measure_metric(post_processing_pred, copy.deepcopy(batch.label), batch.pspace, metric_flag=True, average_flag=False)
186 | if return_post_processing_pred:
187 | return out, batch, post_processing_pred
188 | else:
189 | return out, batch
190 |
191 | def find_worst_pred_index(self, previous_hint_index, metric_managers, save_manager, n_hint):
192 | batch_metric_value = metric_managers[n_hint].running_metric[save_manager.config.Train.decision_metric][-1]
193 | if len(batch_metric_value.shape) == 3:
194 | batch_metric_value = batch_metric_value.mean(-1) # (batch, 13)
195 |
196 | tmp_metric = batch_metric_value.clone().detach()
197 | if metric_managers[n_hint].minimize_metric:
198 | for j, idx in enumerate(previous_hint_index):
199 | if idx is not None:
200 | tmp_metric[j, idx] = -1000
201 | worst_index = tmp_metric.argmax(-1, keepdim=True)
202 | else:
203 | for j, idx in enumerate(previous_hint_index):
204 | if idx is not None:
205 | tmp_metric[j, idx] = 1000
206 | worst_index = tmp_metric.argmin(-1, keepdim=True)
207 | return worst_index
208 |
209 |
210 |
--------------------------------------------------------------------------------
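Editor's note — Trainer.find_worst_pred_index implements the hint-selection rule: mask out keypoints that already received a hint with a sentinel value, then take the arg-worst of the per-keypoint metric. A toy sketch of that rule in isolation (the function name and example tensors are illustrative assumptions):

import torch

def pick_worst_keypoint(per_kp_error, previous_hints, minimize_metric=True):
    """per_kp_error: (batch, K); previous_hints: per-sample LongTensor of
    hinted keypoint indices, or None. Returns (batch, 1) worst-keypoint indices."""
    masked = per_kp_error.clone()
    sentinel = -1000.0 if minimize_metric else 1000.0
    for j, idx in enumerate(previous_hints):
        if idx is not None:
            masked[j, idx] = sentinel  # never pick an already-hinted keypoint again
    return (masked.argmax(-1, keepdim=True) if minimize_metric
            else masked.argmin(-1, keepdim=True))

errors = torch.tensor([[0.2, 3.1, 0.9],   # sample 0: keypoint 1 is worst, but hinted
                       [1.5, 0.1, 2.7]])  # sample 1: keypoint 2 is worst
hints = [torch.tensor([1]), None]
print(pick_worst_keypoint(errors, hints))  # tensor([[2], [2]])
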
/util.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import json
3 | import pickle
4 | from munch import Munch
5 | import os
6 | import torch
7 | from misc.metric import heatmap2hargmax_coord
8 | from torch.utils.tensorboard import SummaryWriter
9 | import torchvision
10 |
11 | class SaveManager(object):
12 | def __init__(self, arg):
13 | print('\n\n\n\n')
14 |
15 | self.read_config(arg)
16 | self.unspecified_configs_to_default()
17 | self.save_config()
18 | self.config.MISC.device = torch.device('cuda')
19 | self.config_name2value()
20 |
21 | def config_name2value(self):
22 | with open('./morph_pairs/{}/{}.json'.format(self.config.Dataset.NAME, self.config.Morph.pairs), 'r') as f:
23 | self.config.Morph.pairs = json.load(f)
24 |
25 |         if self.config.Hint.num_dist in ('dataset16', 'datset16'):  # accept the misspelling 'datset16' for backward compatibility
26 | self.config.Hint.num_dist = [1 / 8, 1 / 2, 1 / 4, 1 / 16, 1 / 32, 1 / 64, 1 / 128, 1 / 256, 1 / 512, 1 / 1024, 1 / 2048, 1 / 4096, 1 / 4096]+[0 for _ in range(68-13)]
27 |
28 | def read_config(self, arg):
29 | # load config
30 | if arg.only_test_version:
31 | config_path = '../save/{}/config.yaml'.format(arg.only_test_version)
32 | else:
33 | config_path = './config/{}.yaml'.format(arg.config)
34 | with open(config_path) as f:
35 | config = yaml.safe_load(f)
36 |
37 | self.config = Munch.fromDict(config)
38 |
39 | if not arg.only_test_version:
40 | os.makedirs('../save/', exist_ok=True)
41 | exp_num = self.get_new_exp_num()
42 | self.config.version = "ExpNum[{}]_Dataset[{}]_Model[{}]_config[{}]_seed[{}]".format(exp_num,
43 | self.config.Dataset.NAME,
44 | self.config.Model.NAME,
45 | arg.config, arg.seed)
46 | os.makedirs('../save/{}'.format(self.config.version))
47 | else:
48 | self.config.version = arg.only_test_version
49 |
50 | # set path and write version
51 | self.set_config_path()
52 | self.write_log('Version : {}'.format(self.config.version), n_mark=16)
53 |
54 | # update arg to config
55 | items = arg.__dict__.items() if 'namespace' in str(arg.__class__).lower() else arg.items()
56 | for key, value in items:
57 | self.config[key] = value
58 |
59 |
60 | if arg.save_test_prediction:
61 | self.predictions_for_save = Munch({})
62 |
63 | self.config.Model.use_prev_heatmap_only_for_hint_index = True
64 | self.write_log('Use previous heatmap only for hint index, else zero', n_mark=8)
65 |
66 | def write_log(self, text, n_mark=0, save_flag=True):
67 | log = '{} {} {}'.format('='*n_mark, text, '='*n_mark)
68 | print(log)
69 | if save_flag:
70 | with open(self.config.PATH.LOG_PATH, 'a+') as f:
71 | f.write('{}\n'.format(log))
72 | return
73 |
74 | def set_config_path(self):
75 | self.config.PATH.LOG_PATH = '../save/{}/log.txt'.format(self.config.version)
76 | self.config.PATH.CONFIG_PATH = '../save/{}/config.yaml'.format(self.config.version)
77 | self.config.PATH.MODEL_PATH = '../save/{}/model.pth'.format(self.config.version)
78 | self.config.PATH.post_RESULT_PATH = '../save/{}/post_result.pickle'.format(self.config.version)
79 |         self.config.PATH.manual_RESULT_PATH = '../save/{}/manual_result.pickle'.format(self.config.version)
80 | self.config.PATH.PREDICTION_RESULT_PATH = '../save/{}/predictions.pickle'.format(self.config.version)
81 |
82 | def save_config(self):
83 | with open(self.config.PATH.CONFIG_PATH, 'w') as f:
84 | yaml.dump(self.config.toDict(), f)
85 | return
86 |
87 | def load_model(self):
88 | best_save = torch.load(self.config.PATH.MODEL_PATH, map_location=torch.device('cpu'))
89 | best_param = best_save['model']
90 | best_epoch = best_save['epoch']
91 | return best_param, best_epoch, None
92 |
93 | def save_model(self, epoch, param, metric):
94 | save_dict = {
95 | 'model': param,
96 | 'epoch':epoch,
97 | 'metric':metric
98 | }
99 | torch.save(save_dict, self.config.PATH.MODEL_PATH)
100 | return
101 |
102 | def save_metric(self, metric, manual_metric=None):
103 | with open(self.config.PATH.post_RESULT_PATH, 'wb') as f:
104 | pickle.dump(metric, f)
105 | if manual_metric is not None:
106 |             with open(self.config.PATH.manual_RESULT_PATH, 'wb') as f:
107 | pickle.dump(manual_metric, f)
108 |
109 | def save_test_prediction(self):
110 | with open(self.config.PATH.PREDICTION_RESULT_PATH, 'wb') as f:
111 | pickle.dump(self.predictions_for_save, f)
112 |
113 | def add_test_prediction_for_save(self, batch, post_processing_pred, manual_pred, n_hint, post_metric_manager, manual_metric_manager):
114 | manual_pred.hargmax_coord = heatmap2hargmax_coord(manual_pred.heatmap).detach()
115 | post_sargmax_mm_MRE = post_metric_manager.running_metric['sargmax_mm_MRE'][-1].detach().cpu().numpy()
116 | post_hargmax_mm_MRE = post_metric_manager.running_metric['hargmax_mm_MRE'][-1].detach().cpu().numpy()
117 | manual_sargmax_mm_MRE = manual_metric_manager.running_metric['sargmax_mm_MRE'][-1].detach().cpu().numpy()
118 | manual_hargmax_mm_MRE = manual_metric_manager.running_metric['hargmax_mm_MRE'][-1].detach().cpu().numpy()
119 | for b in range(len(batch.label.coord)):
120 | name = 'batch{}_hint{}'.format(batch.index[b], n_hint)
121 | self.predictions_for_save[name] = Munch({})
122 | self.predictions_for_save[name].post = Munch({})
123 | self.predictions_for_save[name].post.sargmax_coord = post_processing_pred.sargmax_coord[b].detach().cpu().numpy()
124 | self.predictions_for_save[name].post.hargmax_coord = post_processing_pred.hargmax_coord[b].detach().cpu().numpy()
125 |             # saving heatmaps is too slow and too large on disk; re-running a model forward pass is cheaper
126 | self.predictions_for_save[name].manual = Munch({})
127 | self.predictions_for_save[name].manual.sargmax_coord = manual_pred.sargmax_coord[b].detach().cpu().numpy()
128 | self.predictions_for_save[name].manual.hargmax_coord = manual_pred.hargmax_coord[b].detach().cpu().numpy()
129 |
130 | self.predictions_for_save[name].hint = Munch({})
131 | if torch.is_tensor(batch.hint.index[b]):
132 | self.predictions_for_save[name].hint.index = batch.hint.index[b].detach().cpu().numpy()
133 | else:
134 | self.predictions_for_save[name].hint.index = batch.hint.index[b]
135 |
136 | self.predictions_for_save[name].metric = Munch({})
137 | self.predictions_for_save[name].metric.post = Munch({})
138 | self.predictions_for_save[name].metric.post.sargmax_mm_MRE = post_sargmax_mm_MRE[b]
139 | self.predictions_for_save[name].metric.post.hargmax_mm_MRE = post_hargmax_mm_MRE[b]
140 | self.predictions_for_save[name].metric.manual = Munch({})
141 | self.predictions_for_save[name].metric.manual.sargmax_mm_MRE = manual_sargmax_mm_MRE[b]
142 | self.predictions_for_save[name].metric.manual.hargmax_mm_MRE = manual_hargmax_mm_MRE[b]
143 |
144 |
145 |
146 | def get_new_exp_num(self):
147 | save_path = '../save/'
148 | save_folder_names = os.listdir(save_path)
149 | max_exp_num = 0
150 | for folder_name in save_folder_names:
151 | exp_num = folder_name.split('_')[0][:-1].split('[')[-1]
152 | max_exp_num = max(int(exp_num), max_exp_num)
153 | new_exp_num = '{:05d}'.format(max_exp_num+1)
154 |
155 | return new_exp_num
156 |
157 | def unspecified_configs_to_default(self):
158 | if self.config.Dataset.get('subpixel_decoding',None) is None:
159 | self.config.Dataset.subpixel_decoding = False
160 | if self.config.Dataset.get('subpixel_decoding_patch_size',None) is None:
161 | self.config.Dataset.subpixel_decoding_patch_size = 5
162 | if self.config.Dataset.get('heatmap_encoding_maxone',None) is None:
163 | self.config.Dataset.heatmap_encoding_maxone = False
164 | if self.config.Dataset.get('subpixel_heatmap_encoding', None) is None:
165 | self.config.Dataset.subpixel_heatmap_encoding = False
166 | if self.config.Model.get('subpixel_decoding_coord_loss', None) is None:
167 | self.config.Model.subpixel_decoding_coord_loss = False
168 | if self.config.Model.get('facto_heatmap', None) is None:
169 | self.config.Model.facto_heatmap = False
170 | if self.config.Model.get('HintEncoder', None) is None:
171 | self.config.Model.HintEncoder = Munch({})
172 | if self.config.Model.HintEncoder.get('dilation', None) is None:
173 | self.config.Model.HintEncoder.dilation = 5
174 | if self.config.Model.get('Decoder', None) is None:
175 | self.config.Model.Decoder = Munch({})
176 | if self.config.Model.Decoder.get('dilation', None) is None:
177 | self.config.Model.Decoder.dilation = 5
178 | if self.config.Dataset.get('label_smoothing', None) is None:
179 | self.config.Dataset.label_smoothing = False
180 | if self.config.Model.get('MSELoss', None) is None:
181 | self.config.Model.MSELoss = 0.0
182 | if self.config.Dataset.get('heatmap_max_norm',None) is None:
183 | self.config.Dataset.heatmap_max_norm = False
184 | if self.config.Model.get('input_padding') is None:
185 | self.config.Model.input_padding = None
186 | if self.config.Morph.get('cosineSimilarityLoss') is None:
187 | self.config.Morph.cosineSimilarityLoss = False
188 | if self.config.Morph.get('threePointAngle') is None:
189 | self.config.Morph.threePointAngle = False
190 | if self.config.MISC.get('free_memory',None) is None:
191 | self.config.MISC.free_memory = False
192 | if self.config.Model.get('bbox_predictor', None) is None:
193 | self.config.Model.bbox_predictor = False
194 | if self.config.Morph.get('distance_l1', None) is None:
195 | self.config.Morph.distance_l1 = False
196 | if self.config.Model.get('SE_maxpool',None) is None:
197 | self.config.Model.SE_maxpool = False
198 | if self.config.Model.get('SE_softmax', None) is None:
199 | self.config.Model.SE_softmax = False
200 | if self.config.Model.get('use_prev_heatmap', None) is None:
201 | self.config.Model.use_prev_heatmap = False
202 | if self.config.Model.get('no_iterative_training', None) is None:
203 |             self.config.Model.no_iterative_training = False # applies only to RITM
204 | if self.config.Morph.get('coord_use', None) is None:
205 | self.config.Morph.coord_use = False
206 |
207 | class TensorBoardManager(object):
208 | def __init__(self, save_manager):
209 | tensorboard_path = '../tensorboard/{}/{}/'.format(save_manager.config.Dataset.NAME, save_manager.config.version)
210 | os.makedirs(tensorboard_path, exist_ok=True)
211 | self.writer = SummaryWriter(tensorboard_path)
212 |
213 | self.n_iter = {'train':0, 'val':0, 'test':0}
214 | self.n_epoch = {'train':0, 'val':0, 'test':0}
215 |
216 | def plot_image_heatmap(self, image, pred_heatmap, label_heatmap, epoch):
217 | # image: (batch, 3, H, W) -1 <= x <= 1
218 | # heatmap: (batch, num_keypoint, Height, Width) 0 <= x <= 1
219 | image_unNorm = image * 0.5 + 0.5
220 |
221 | pred_label_heatmap = torch.cat((pred_heatmap[:,:,None,:,:], label_heatmap[:,:,None,:,:], torch.zeros_like(pred_heatmap)[:,:,None,:,:]), dim=2)
222 |
223 | grids = [torchvision.utils.make_grid(
224 | torch.cat([image_unNorm[i].unsqueeze(0), pred_label_heatmap[i]], dim=0)
225 | )
226 | for i in range(pred_label_heatmap.shape[0])] # (1,3,H,W), (num_keypoint,3,H,W)
227 |
228 | for i, grid in enumerate(grids):
229 | text = 'Epoch [{}] - {} ::: Pred(red) Label(green)'.format(epoch, i+1)
230 | self.writer.add_image(text, grid)
231 |
232 | def plot_outlier_image_heatmap(self, image, pred_heatmap, label_heatmap, epoch):
233 | # image: (batch, 3, H, W) -1 <= x <= 1
234 | # heatmap: (batch, num_keypoint, Height, Width) 0 <= x <= 1
235 | image_unNorm = image * 0.5 + 0.5
236 |
237 | pred_label_heatmap = torch.cat((pred_heatmap[:,:,None,:,:], label_heatmap[:,:,None,:,:], torch.zeros_like(pred_heatmap)[:,:,None,:,:]), dim=2)
238 |
239 | grids = [torchvision.utils.make_grid(
240 | torch.cat([image_unNorm[i].unsqueeze(0), pred_label_heatmap[i]], dim=0)
241 | )
242 | for i in range(pred_label_heatmap.shape[0])] # (1,3,H,W), (num_keypoint,3,H,W)
243 |
244 | for i, grid in enumerate(grids):
245 | text = 'iter [{}] - outlier - {}'.format(self.n_iter['train'], i+1)
246 | self.writer.add_image(text, grid)
247 |
248 | def plot_model_param_histogram(self, model, epoch):
249 | for k, v in model.named_parameters():
250 | self.writer.add_histogram(k, v.data.cpu().reshape(-1), epoch)
251 | return
252 |
253 | def write_loss(self, loss, split):
254 | # loss: scalar
255 | # split: 'train', 'val', 'test'
256 | self.n_iter[split] += 1
257 | self.writer.add_scalar('Loss/{}'.format(split), loss, self.n_iter[split])
258 |
259 | def write_metric(self, metric, split):
260 | self.n_epoch[split] += 1
261 | for key in metric:
262 | self.writer.add_scalar('{}/{}'.format(key, split), metric[key], self.n_epoch[split])
263 |
264 |
--------------------------------------------------------------------------------
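Editor's note — unspecified_configs_to_default above backfills YAML keys that older config files may omit. The same pattern can be written table-driven; a minimal sketch using Munch (the DEFAULTS table below covers only a few of the keys handled above and is an illustrative assumption):

from munch import Munch

DEFAULTS = {
    ('Dataset', 'subpixel_decoding'): False,
    ('Dataset', 'subpixel_decoding_patch_size'): 5,
    ('Model', 'MSELoss'): 0.0,
    ('Model', 'use_prev_heatmap'): False,
}

def apply_defaults(config):
    for (section, key), value in DEFAULTS.items():
        node = config.setdefault(section, Munch({}))  # create missing sections
        if node.get(key, None) is None:               # backfill missing keys only
            node[key] = value
    return config

cfg = apply_defaults(Munch.fromDict({'Dataset': {'subpixel_decoding': True}}))
print(cfg.Dataset.subpixel_decoding_patch_size)  # 5  (backfilled)
print(cfg.Dataset.subpixel_decoding)             # True (explicit value kept)
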
/model/iterativeRefinementModels/RITM_modules/RITM_hrnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch._utils
6 | import torch.nn.functional as F
7 | from model.iterativeRefinementModels.RITM_modules.RITM_ocr import SpatialOCR_Module, SpatialGather_Module
8 |
9 |
10 | relu_inplace = True
11 |
12 |
13 | class HighResolutionModule(nn.Module):
14 | def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
15 | num_channels, fuse_method,multi_scale_output=True,
16 | norm_layer=nn.BatchNorm2d, align_corners=True):
17 | super(HighResolutionModule, self).__init__()
18 | self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
19 |
20 | self.num_inchannels = num_inchannels
21 | self.fuse_method = fuse_method
22 | self.num_branches = num_branches
23 | self.norm_layer = norm_layer
24 | self.align_corners = align_corners
25 |
26 | self.multi_scale_output = multi_scale_output
27 |
28 | self.branches = self._make_branches(
29 | num_branches, blocks, num_blocks, num_channels)
30 | self.fuse_layers = self._make_fuse_layers()
31 | self.relu = nn.ReLU(inplace=relu_inplace)
32 |
33 | def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
34 | if num_branches != len(num_blocks):
35 | error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
36 | num_branches, len(num_blocks))
37 | raise ValueError(error_msg)
38 |
39 | if num_branches != len(num_channels):
40 | error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
41 | num_branches, len(num_channels))
42 | raise ValueError(error_msg)
43 |
44 | if num_branches != len(num_inchannels):
45 | error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
46 | num_branches, len(num_inchannels))
47 | raise ValueError(error_msg)
48 |
49 | def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
50 | stride=1):
51 | downsample = None
52 | if stride != 1 or \
53 | self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
54 | downsample = nn.Sequential(
55 | nn.Conv2d(self.num_inchannels[branch_index],
56 | num_channels[branch_index] * block.expansion,
57 | kernel_size=1, stride=stride, bias=False),
58 | self.norm_layer(num_channels[branch_index] * block.expansion),
59 | )
60 |
61 | layers = []
62 | layers.append(block(self.num_inchannels[branch_index],
63 | num_channels[branch_index], stride,
64 | downsample=downsample, norm_layer=self.norm_layer))
65 | self.num_inchannels[branch_index] = \
66 | num_channels[branch_index] * block.expansion
67 | for i in range(1, num_blocks[branch_index]):
68 | layers.append(block(self.num_inchannels[branch_index],
69 | num_channels[branch_index],
70 | norm_layer=self.norm_layer))
71 |
72 | return nn.Sequential(*layers)
73 |
74 | def _make_branches(self, num_branches, block, num_blocks, num_channels):
75 | branches = []
76 |
77 | for i in range(num_branches):
78 | branches.append(
79 | self._make_one_branch(i, block, num_blocks, num_channels))
80 |
81 | return nn.ModuleList(branches)
82 |
83 | def _make_fuse_layers(self):
84 | if self.num_branches == 1:
85 | return None
86 |
87 | num_branches = self.num_branches
88 | num_inchannels = self.num_inchannels
89 | fuse_layers = []
90 | for i in range(num_branches if self.multi_scale_output else 1):
91 | fuse_layer = []
92 | for j in range(num_branches):
93 | if j > i:
94 | fuse_layer.append(nn.Sequential(
95 | nn.Conv2d(in_channels=num_inchannels[j],
96 | out_channels=num_inchannels[i],
97 | kernel_size=1,
98 | bias=False),
99 | self.norm_layer(num_inchannels[i])))
100 | elif j == i:
101 | fuse_layer.append(None)
102 | else:
103 | conv3x3s = []
104 | for k in range(i - j):
105 | if k == i - j - 1:
106 | num_outchannels_conv3x3 = num_inchannels[i]
107 | conv3x3s.append(nn.Sequential(
108 | nn.Conv2d(num_inchannels[j],
109 | num_outchannels_conv3x3,
110 | kernel_size=3, stride=2, padding=1, bias=False),
111 | self.norm_layer(num_outchannels_conv3x3)))
112 | else:
113 | num_outchannels_conv3x3 = num_inchannels[j]
114 | conv3x3s.append(nn.Sequential(
115 | nn.Conv2d(num_inchannels[j],
116 | num_outchannels_conv3x3,
117 | kernel_size=3, stride=2, padding=1, bias=False),
118 | self.norm_layer(num_outchannels_conv3x3),
119 | nn.ReLU(inplace=relu_inplace)))
120 | fuse_layer.append(nn.Sequential(*conv3x3s))
121 | fuse_layers.append(nn.ModuleList(fuse_layer))
122 |
123 | return nn.ModuleList(fuse_layers)
124 |
125 | def get_num_inchannels(self):
126 | return self.num_inchannels
127 |
128 | def forward(self, x):
129 | if self.num_branches == 1:
130 | return [self.branches[0](x[0])]
131 |
132 | for i in range(self.num_branches):
133 | x[i] = self.branches[i](x[i])
134 |
135 | x_fuse = []
136 | for i in range(len(self.fuse_layers)):
137 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
138 | for j in range(1, self.num_branches):
139 | if i == j:
140 | y = y + x[j]
141 | elif j > i:
142 | width_output = x[i].shape[-1]
143 | height_output = x[i].shape[-2]
144 | y = y + F.interpolate(
145 | self.fuse_layers[i][j](x[j]),
146 | size=[height_output, width_output],
147 | mode='bilinear', align_corners=self.align_corners)
148 | else:
149 | y = y + self.fuse_layers[i][j](x[j])
150 | x_fuse.append(self.relu(y))
151 |
152 | return x_fuse
153 |
154 |
155 | class SqueezeExitationBlock(nn.Module):  # Squeeze-and-Excitation (SE) channel gate
156 | def __init__(self, in_ch, mid_ch1, out_ch, SE_maxpool=False, SE_softmax=False):
157 | super(SqueezeExitationBlock, self).__init__()
158 | self.conv1 = nn.Conv2d(in_ch, mid_ch1, kernel_size=(1, 1))
159 | self.conv2 = nn.Conv2d(mid_ch1, out_ch, kernel_size=(1, 1))
160 | self.SE_maxpool = SE_maxpool
161 | self.SE_softmax = SE_softmax
162 |
163 | def forward(self, x):
164 | # === squeeze ===
165 | # pool
166 | if self.SE_maxpool:
167 | x = x.max(-1)[0].max(-1)[0]
168 | x = x[:,:,None,None]
169 | else:
170 | x = x.mean(-1, keepdim=True).mean(-2, keepdim=True) # b, in_ch, 1, 1
171 | # conv1, relu
172 | x = self.conv1(x).relu() # b, mid_ch, 1, 1
173 | # conv2, relu
174 | x = self.conv2(x)
175 |
176 | if self.SE_softmax:
177 | x = x.softmax(1)
178 | else:
179 | x = x.sigmoid() # b, out_ch, 1, 1
180 | return x
181 | class ConvBlocks(nn.Module):
182 | def __init__(self, in_ch, channels, kernel_sizes=None, strides=None, dilations=None, paddings=None,
183 | BatchNorm=nn.BatchNorm2d):
184 | super(ConvBlocks, self).__init__()
185 | self.num = len(channels)
186 | if kernel_sizes is None: kernel_sizes = [3 for c in channels]
187 | if strides is None: strides = [1 for c in channels]
188 | if dilations is None: dilations = [1 for c in channels]
189 | if paddings is None: paddings = [
190 | ((kernel_sizes[i] // 2) if dilations[i] == 1 else (kernel_sizes[i] // 2 * dilations[i])) for i in
191 | range(self.num)]
192 | convs_tmp = []
193 | for i in range(self.num):
194 | if channels[i] == 1:
195 | convs_tmp.append(
196 | nn.Conv2d(in_ch if i == 0 else channels[i - 1], channels[i], kernel_size=kernel_sizes[i],
197 | stride=strides[i], padding=paddings[i], dilation=dilations[i]))
198 | else:
199 | convs_tmp.append(nn.Sequential(
200 | nn.Conv2d(in_ch if i == 0 else channels[i - 1], channels[i], kernel_size=kernel_sizes[i],
201 | stride=strides[i], padding=paddings[i], dilation=dilations[i], bias=False),
202 | BatchNorm(channels[i]), nn.ReLU()))
203 | self.convs = nn.Sequential(*convs_tmp)
204 |
205 | # weight initialization
206 | for m in self.convs.modules():
207 | if isinstance(m, nn.Conv2d):
208 | torch.nn.init.kaiming_normal_(m.weight)
209 | elif isinstance(m, nn.BatchNorm2d):
210 | m.weight.data.fill_(1)
211 | m.bias.data.zero_()
212 |
213 | def forward(self, x):
214 | return self.convs(x)
215 | class HintEncSENet(nn.Module):
216 | def __init__(self, se_output_channel, num_classes, SE_maxpool=False, SE_softmax=False, input_channel=256):
217 | super(HintEncSENet, self).__init__()
218 | self.SENet = SqueezeExitationBlock(256, 256//16, se_output_channel, SE_maxpool=SE_maxpool, SE_softmax=SE_softmax)
219 | self.hintEncoder = ConvBlocks(input_channel+num_classes, [256, 256, 256], [3, 3, 3], [2, 1, 1])
220 |
221 |
222 | def forward(self,x,hint):
223 | hint = F.interpolate(hint, size=x.size()[2:], mode='bilinear', align_corners=True)
224 | se = self.hintEncoder(torch.cat((x, hint),dim=1))
225 | se = self.SENet(se)
226 | return se
227 |
228 |
229 | class HighResolutionNet(nn.Module):
230 | def __init__(self, width, num_classes, ocr_width=256, small=False,
231 | norm_layer=nn.BatchNorm2d, align_corners=True, addHintEncSENet=False, SE_maxpool=False, SE_softmax=False):
232 | super(HighResolutionNet, self).__init__()
233 |
234 |
235 | self.norm_layer = norm_layer
236 | self.width = width
237 | self.ocr_width = ocr_width
238 | self.align_corners = align_corners
239 |
240 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
241 | self.bn1 = norm_layer(64)
242 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
243 | self.bn2 = norm_layer(64)
244 | self.relu = nn.ReLU(inplace=relu_inplace)
245 |
246 | num_blocks = 2 if small else 4
247 |
248 | stage1_num_channels = 64
249 | self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
250 | stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
251 |
252 | self.stage2_num_branches = 2
253 | num_channels = [width, 2 * width]
254 | num_inchannels = [
255 | num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
256 | self.transition1 = self._make_transition_layer(
257 | [stage1_out_channel], num_inchannels)
258 | self.stage2, pre_stage_channels = self._make_stage(
259 | BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
260 | num_blocks=2 * [num_blocks], num_channels=num_channels)
261 |
262 | self.stage3_num_branches = 3
263 | num_channels = [width, 2 * width, 4 * width]
264 | num_inchannels = [
265 | num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
266 | self.transition2 = self._make_transition_layer(
267 | pre_stage_channels, num_inchannels)
268 | self.stage3, pre_stage_channels = self._make_stage(
269 | BasicBlockV1b, num_inchannels=num_inchannels,
270 | num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
271 | num_blocks=3 * [num_blocks], num_channels=num_channels)
272 |
273 | self.stage4_num_branches = 4
274 | num_channels = [width, 2 * width, 4 * width, 8 * width]
275 | num_inchannels = [
276 | num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
277 | self.transition3 = self._make_transition_layer(
278 | pre_stage_channels, num_inchannels)
279 | self.stage4, pre_stage_channels = self._make_stage(
280 | BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
281 | num_branches=self.stage4_num_branches,
282 | num_blocks=4 * [num_blocks], num_channels=num_channels)
283 |
284 |         last_inp_channels = int(np.sum(pre_stage_channels))  # np.int was removed in NumPy >= 1.24
285 | if self.ocr_width > 0:
286 | ocr_mid_channels = 2 * self.ocr_width
287 | ocr_key_channels = self.ocr_width
288 |
289 | self.conv3x3_ocr = nn.Sequential(
290 | nn.Conv2d(last_inp_channels, ocr_mid_channels,
291 | kernel_size=3, stride=1, padding=1),
292 | norm_layer(ocr_mid_channels),
293 | nn.ReLU(inplace=relu_inplace),
294 | )
295 | self.ocr_gather_head = SpatialGather_Module(num_classes)
296 |
297 | self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
298 | key_channels=ocr_key_channels,
299 | out_channels=ocr_mid_channels,
300 | scale=1,
301 | dropout=0.05,
302 | norm_layer=norm_layer,
303 | align_corners=align_corners)
304 | self.cls_head = nn.Conv2d(
305 | ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
306 |
307 | self.aux_head = nn.Sequential(
308 | nn.Conv2d(last_inp_channels, last_inp_channels,
309 | kernel_size=1, stride=1, padding=0),
310 | norm_layer(last_inp_channels),
311 | nn.ReLU(inplace=relu_inplace),
312 | nn.Conv2d(last_inp_channels, num_classes,
313 | kernel_size=1, stride=1, padding=0, bias=True)
314 | )
315 | else:
316 | self.cls_head = nn.Sequential(
317 | nn.Conv2d(last_inp_channels, last_inp_channels,
318 | kernel_size=3, stride=1, padding=1),
319 | norm_layer(last_inp_channels),
320 | nn.ReLU(inplace=relu_inplace),
321 | nn.Conv2d(last_inp_channels, num_classes,
322 | kernel_size=1, stride=1, padding=0, bias=True)
323 | )
324 |
325 |
326 | self.addHintEncSENet = addHintEncSENet
327 | if self.addHintEncSENet:
328 | self.HintEncSENet = HintEncSENet( se_output_channel=last_inp_channels, num_classes=num_classes, SE_maxpool=SE_maxpool, SE_softmax=SE_softmax)
329 |
330 |
331 | def _make_transition_layer(
332 | self, num_channels_pre_layer, num_channels_cur_layer):
333 | num_branches_cur = len(num_channels_cur_layer)
334 | num_branches_pre = len(num_channels_pre_layer)
335 |
336 | transition_layers = []
337 | for i in range(num_branches_cur):
338 | if i < num_branches_pre:
339 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
340 | transition_layers.append(nn.Sequential(
341 | nn.Conv2d(num_channels_pre_layer[i],
342 | num_channels_cur_layer[i],
343 | kernel_size=3,
344 | stride=1,
345 | padding=1,
346 | bias=False),
347 | self.norm_layer(num_channels_cur_layer[i]),
348 | nn.ReLU(inplace=relu_inplace)))
349 | else:
350 | transition_layers.append(None)
351 | else:
352 | conv3x3s = []
353 | for j in range(i + 1 - num_branches_pre):
354 | inchannels = num_channels_pre_layer[-1]
355 | outchannels = num_channels_cur_layer[i] \
356 | if j == i - num_branches_pre else inchannels
357 | conv3x3s.append(nn.Sequential(
358 | nn.Conv2d(inchannels, outchannels,
359 | kernel_size=3, stride=2, padding=1, bias=False),
360 | self.norm_layer(outchannels),
361 | nn.ReLU(inplace=relu_inplace)))
362 | transition_layers.append(nn.Sequential(*conv3x3s))
363 |
364 | return nn.ModuleList(transition_layers)
365 |
366 | def _make_layer(self, block, inplanes, planes, blocks, stride=1):
367 | downsample = None
368 | if stride != 1 or inplanes != planes * block.expansion:
369 | downsample = nn.Sequential(
370 | nn.Conv2d(inplanes, planes * block.expansion,
371 | kernel_size=1, stride=stride, bias=False),
372 | self.norm_layer(planes * block.expansion),
373 | )
374 |
375 | layers = []
376 | layers.append(block(inplanes, planes, stride,
377 | downsample=downsample, norm_layer=self.norm_layer))
378 | inplanes = planes * block.expansion
379 | for i in range(1, blocks):
380 | layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
381 |
382 | return nn.Sequential(*layers)
383 |
384 | def _make_stage(self, block, num_inchannels,
385 | num_modules, num_branches, num_blocks, num_channels,
386 | fuse_method='SUM',
387 | multi_scale_output=True):
388 | modules = []
389 | for i in range(num_modules):
390 |             # multi_scale_output is only used by the last module
391 | if not multi_scale_output and i == num_modules - 1:
392 | reset_multi_scale_output = False
393 | else:
394 | reset_multi_scale_output = True
395 | modules.append(
396 | HighResolutionModule(num_branches,
397 | block,
398 | num_blocks,
399 | num_inchannels,
400 | num_channels,
401 | fuse_method,
402 | reset_multi_scale_output,
403 | norm_layer=self.norm_layer,
404 | align_corners=self.align_corners)
405 | )
406 | num_inchannels = modules[-1].get_num_inchannels()
407 |
408 | return nn.Sequential(*modules), num_inchannels
409 |
410 | def forward(self, x, additional_features=None, input_hint_heatmap=None):
411 | feats = self.compute_hrnet_feats(x, additional_features, input_hint_heatmap)
412 | if self.ocr_width > 0:
413 | out_aux = self.aux_head(feats) # aux_head : conv norm relu conv (soft object regions), output channel: num_classes
414 |             feats = self.conv3x3_ocr(feats) # conv3x3_ocr : conv norm relu (pixel representations)
415 |
416 | context = self.ocr_gather_head(feats, out_aux) # context : batch x c x num_keypoint x 1, feats: batch, c, H, W
417 | feats = self.ocr_distri_head(feats, context)
418 | out = self.cls_head(feats)
419 | return [out, out_aux]
420 | else:
421 | return [self.cls_head(feats), None]
422 |
423 | def compute_hrnet_feats(self, x, additional_features, input_hint_heatmap):
424 | x = self.compute_pre_stage_features(x, additional_features)
425 | x = self.layer1(x)
426 |
427 | if input_hint_heatmap is not None:
428 | hint_encoder_output = self.HintEncSENet(x, input_hint_heatmap)
429 |
430 | x_list = []
431 | for i in range(self.stage2_num_branches):
432 | if self.transition1[i] is not None:
433 | x_list.append(self.transition1[i](x))
434 | else:
435 | x_list.append(x)
436 | y_list = self.stage2(x_list)
437 |
438 | x_list = []
439 | for i in range(self.stage3_num_branches):
440 | if self.transition2[i] is not None:
441 | if i < self.stage2_num_branches:
442 | x_list.append(self.transition2[i](y_list[i]))
443 | else:
444 | x_list.append(self.transition2[i](y_list[-1]))
445 | else:
446 | x_list.append(y_list[i])
447 | y_list = self.stage3(x_list)
448 |
449 | x_list = []
450 | for i in range(self.stage4_num_branches):
451 | if self.transition3[i] is not None:
452 | if i < self.stage3_num_branches:
453 | x_list.append(self.transition3[i](y_list[i]))
454 | else:
455 | x_list.append(self.transition3[i](y_list[-1]))
456 | else:
457 | x_list.append(y_list[i])
458 | x = self.stage4(x_list)
459 |
460 | out = self.aggregate_hrnet_features(x)
461 | if input_hint_heatmap is not None:
462 | return hint_encoder_output * out
463 | else:
464 | return out
465 |
466 |
467 | def compute_pre_stage_features(self, x, additional_features):
468 | x = self.conv1(x)
469 | x = self.bn1(x)
470 | x = self.relu(x)
471 | if additional_features is not None:
472 | x = x + additional_features
473 | x = self.conv2(x)
474 | x = self.bn2(x)
475 | return self.relu(x)
476 |
477 | def aggregate_hrnet_features(self, x):
478 | # Upsampling
479 | x0_h, x0_w = x[0].size(2), x[0].size(3)
480 | x1 = F.interpolate(x[1], size=(x0_h, x0_w),
481 | mode='bilinear', align_corners=self.align_corners)
482 | x2 = F.interpolate(x[2], size=(x0_h, x0_w),
483 | mode='bilinear', align_corners=self.align_corners)
484 | x3 = F.interpolate(x[3], size=(x0_h, x0_w),
485 | mode='bilinear', align_corners=self.align_corners)
486 |
487 | return torch.cat([x[0], x1, x2, x3], 1)
488 |
489 | def load_pretrained_weights(self, pretrained_path=''):
490 | model_dict = self.state_dict()
491 |
492 | if not os.path.exists(pretrained_path):
493 | print(f'\nFile "{pretrained_path}" does not exist.')
494 | print('You need to specify the correct path to the pre-trained weights.\n'
495 | 'You can download the weights for HRNet from the repository:\n'
496 | 'https://github.com/HRNet/HRNet-Image-Classification')
497 | exit(1)
498 | pretrained_dict = torch.load(pretrained_path, map_location='cpu')
499 | pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
500 | pretrained_dict.items()}
501 |
502 | pretrained_dict = {k: v for k, v in pretrained_dict.items()
503 | if k in model_dict.keys()}
504 |
505 | model_dict.update(pretrained_dict)
506 | self.load_state_dict(model_dict, strict=False)
507 |
508 |
509 |
510 |
511 | # ResNet-style building blocks used by the HRNet above
512 | # (torch and torch.nn are already imported at the top of this file)
513 | GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
514 |
515 |
516 | class BasicBlockV1b(nn.Module):
517 | expansion = 1
518 |
519 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
520 | previous_dilation=1, norm_layer=nn.BatchNorm2d):
521 | super(BasicBlockV1b, self).__init__()
522 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
523 | padding=dilation, dilation=dilation, bias=False)
524 | self.bn1 = norm_layer(planes)
525 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
526 | padding=previous_dilation, dilation=previous_dilation, bias=False)
527 | self.bn2 = norm_layer(planes)
528 |
529 | self.relu = nn.ReLU(inplace=True)
530 | self.downsample = downsample
531 | self.stride = stride
532 |
533 | def forward(self, x):
534 | residual = x
535 |
536 | out = self.conv1(x)
537 | out = self.bn1(out)
538 | out = self.relu(out)
539 |
540 | out = self.conv2(out)
541 | out = self.bn2(out)
542 |
543 | if self.downsample is not None:
544 | residual = self.downsample(x)
545 |
546 | out = out + residual
547 | out = self.relu(out)
548 |
549 | return out
550 |
551 |
552 | class BottleneckV1b(nn.Module):
553 | expansion = 4
554 |
555 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
556 | previous_dilation=1, norm_layer=nn.BatchNorm2d):
557 | super(BottleneckV1b, self).__init__()
558 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
559 | self.bn1 = norm_layer(planes)
560 |
561 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
562 | padding=dilation, dilation=dilation, bias=False)
563 | self.bn2 = norm_layer(planes)
564 |
565 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
566 | self.bn3 = norm_layer(planes * self.expansion)
567 |
568 | self.relu = nn.ReLU(inplace=True)
569 | self.downsample = downsample
570 | self.stride = stride
571 |
572 | def forward(self, x):
573 | residual = x
574 |
575 | out = self.conv1(x)
576 | out = self.bn1(out)
577 | out = self.relu(out)
578 |
579 | out = self.conv2(out)
580 | out = self.bn2(out)
581 | out = self.relu(out)
582 |
583 | out = self.conv3(out)
584 | out = self.bn3(out)
585 |
586 | if self.downsample is not None:
587 | residual = self.downsample(x)
588 |
589 | out = out + residual
590 | out = self.relu(out)
591 |
592 | return out
--------------------------------------------------------------------------------
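Editor's note — SqueezeExitationBlock and HintEncSENet above apply the standard squeeze-and-excitation pattern: global pooling, a channel bottleneck, and a per-channel gate multiplied back onto the feature map. A minimal sketch of that pattern (the class name SEGate and the random input are illustrative assumptions):

import torch
import torch.nn as nn

class SEGate(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1)

    def forward(self, x):
        s = x.mean((-2, -1), keepdim=True)  # squeeze: global average pool -> (B, C, 1, 1)
        s = self.fc2(self.fc1(s).relu())    # excitation: channel bottleneck MLP
        return x * s.sigmoid()              # per-channel gate, broadcast over H x W

feats = torch.randn(2, 64, 32, 32)
print(SEGate(64)(feats).shape)  # torch.Size([2, 64, 32, 32])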