├── README.md
├── core
│   ├── __pycache__
│   │   ├── abc_modules.cpython-38.pyc
│   │   ├── aff_utils.cpython-38.pyc
│   │   ├── datasets.cpython-38.pyc
│   │   ├── deeplab_utils.cpython-38.pyc
│   │   ├── networks.cpython-38.pyc
│   │   ├── puzzle_utils.cpython-38.pyc
│   │   └── seg_hrnet.cpython-38.pyc
│   ├── abc_modules.py
│   ├── aff_utils.py
│   ├── arch_resnest
│   │   ├── __pycache__
│   │   │   ├── resnest.cpython-38.pyc
│   │   │   ├── resnet.cpython-38.pyc
│   │   │   └── splat.cpython-38.pyc
│   │   ├── resnest(1).py
│   │   ├── resnest.py
│   │   ├── resnet(1).py
│   │   ├── resnet.py
│   │   ├── splat(1).py
│   │   └── splat.py
│   ├── arch_resnet
│   │   ├── __pycache__
│   │   │   └── resnet.cpython-38.pyc
│   │   ├── resnet(1).py
│   │   └── resnet.py
│   ├── config.yaml
│   ├── datasets.py
│   ├── deeplab_utils.py
│   ├── networks.py
│   ├── puzzle_utils.py
│   ├── seg_hrnet.py
│   └── sync_batchnorm
│       ├── __init__(1).py
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   ├── batchnorm.cpython-38.pyc
│       │   ├── comm.cpython-38.pyc
│       │   └── replicate.cpython-38.pyc
│       ├── batchnorm(1).py
│       ├── batchnorm.py
│       ├── comm(1).py
│       ├── comm.py
│       ├── replicate(1).py
│       ├── replicate.py
│       ├── unittest(1).py
│       └── unittest.py
├── generate_bg_masks.py
├── networks
│   └── ham_net.py
├── tools
│   ├── ai
│   │   ├── __pycache__
│   │   │   ├── augment_utils.cpython-38.pyc
│   │   │   ├── demo_utils.cpython-38.pyc
│   │   │   ├── evaluate_utils.cpython-38.pyc
│   │   │   ├── log_utils.cpython-38.pyc
│   │   │   ├── optim_utils.cpython-38.pyc
│   │   │   ├── randaugment.cpython-38.pyc
│   │   │   └── torch_utils.cpython-38.pyc
│   │   ├── augment_utils.py
│   │   ├── demo_utils.py
│   │   ├── evaluate_utils.py
│   │   ├── log_utils.py
│   │   ├── optim_utils.py
│   │   ├── randaugment.py
│   │   └── torch_utils.py
│   ├── dataset
│   │   ├── __pycache__
│   │   │   └── voc_utils.cpython-38.pyc
│   │   └── voc_utils.py
│   └── general
│       ├── __pycache__
│       │   ├── io_utils.cpython-38.pyc
│       │   ├── json_utils.cpython-38.pyc
│       │   ├── time_utils.cpython-38.pyc
│       │   ├── txt_utils.cpython-38.pyc
│       │   └── xml_utils.cpython-38.pyc
│       ├── io_utils.py
│       ├── json_utils.py
│       ├── pickle_utils.py
│       ├── time_utils.py
│       ├── txt_utils.py
│       └── xml_utils.py
├── train_cls.py
├── train_seg.py
└── utils
    ├── CRF.py
    ├── LoadData.py
    ├── LoadData_with_bg.py
    ├── Metrics.py
    ├── generate_bg_masks.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
 1 | ## HAMIL: High-resolution Activation Maps and Interleaved Learning for Weakly Supervised Segmentation of Histopathological Images
 2 | 
 3 | This repository provides the code for "HAMIL: High-resolution Activation Maps and Interleaved Learning for Weakly Supervised Segmentation of Histopathological Images",
 4 | accepted by IEEE TMI in 2023.
 5 | 
 6 | ## Usage
 7 | 1. Obtain the background masks. (Skip this step if the background does not consist of white regions.)
 8 | 
 9 | ```train_set_root```: root directory of the training set; ```gamma_path```: path to save the gamma-transformed training images; ```gamma_crf_path```: path to save the extracted backgrounds of the training set.
10 | 
11 | ```
12 | python generate_bg_masks.py --train_set_root train_set_root --gamma_path gamma_path --gamma_crf_path gamma_crf_path
13 | ```
14 | 
15 | 2. Train the classification network
16 | ```
17 | python train_cls.py --dataset_root dataset_root --gpu 0
18 | ```
19 | 
20 | 3. Train the segmentation network
21 | ```
22 | python train_seg.py --dataset_root dataset_root --gpu 0
23 | ```
24 | 
25 | ## Citation
26 | ```
27 | @article{zhong2023hamil,
28 |   title={HAMIL: High-resolution Activation Maps and Interleaved Learning for Weakly Supervised Segmentation of Histopathological Images},
29 |   author={Zhong, Lanfeng and Wang, Guotai and Liao, Xin and Zhang, Shaoting},
30 |   journal={IEEE Transactions on Medical Imaging},
31 |   year={2023},
32 |   publisher={IEEE}
33 | }
34 | ```
35 | 
36 | ## Acknowledgement
37 | The DeepLabv3+ code is borrowed from [PuzzleCAM](https://github.com/shjo-april/PuzzleCAM).
--------------------------------------------------------------------------------
/core/__pycache__/abc_modules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/__pycache__/abc_modules.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/aff_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/__pycache__/aff_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/datasets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/__pycache__/datasets.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/deeplab_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/__pycache__/deeplab_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/networks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/__pycache__/networks.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/puzzle_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/__pycache__/puzzle_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/seg_hrnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/__pycache__/seg_hrnet.cpython-38.pyc
--------------------------------------------------------------------------------
/core/abc_modules.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import math
 3 | 
 4 | import torch
 5 | import torch.nn as nn
 6 | 
 7 | from abc import ABC
 8 | 
 9 | class ABC_Model(ABC):
10 |     def global_average_pooling_2d(self, x, keepdims=False):
11 |         x = torch.mean(x.view(x.size(0), x.size(1), -1), -1)
12 |         if keepdims:
13 |             x = x.view(x.size(0), x.size(1), 1, 1)
14 |         return x
15 | 
16 |     def initialize(self, modules):
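        # How these two helpers are typically consumed in this codebase: a
        # network subclasses both nn.Module and ABC_Model, calls `initialize`
        # on its newly created (non-pretrained) layers, and applies
        # `global_average_pooling_2d` to turn class activation maps into
        # image-level logits. A minimal sketch; `ToyClassifier` and its layer
        # sizes are hypothetical, not part of HAMIL:
        #
        #     import torch
        #     import torch.nn as nn
        #
        #     class ToyClassifier(nn.Module, ABC_Model):
        #         def __init__(self, num_classes=4):
        #             super().__init__()
        #             self.features = nn.Conv2d(3, 64, 3, padding=1)
        #             self.classifier = nn.Conv2d(64, num_classes, 1)
        #             self.initialize([self.classifier])  # Kaiming init for the scratch layer
        #
        #         def forward(self, x):
        #             cams = self.classifier(self.features(x))       # (B, C, H, W) activation maps
        #             return self.global_average_pooling_2d(cams)    # (B, C) image-level logits
        #
        #     logits = ToyClassifier()(torch.randn(2, 3, 64, 64))  # -> shape (2, 4)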
17 | for m in modules: 18 | if isinstance(m, nn.Conv2d): 19 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 20 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 21 | torch.nn.init.kaiming_normal_(m.weight) 22 | 23 | elif isinstance(m, nn.BatchNorm2d): 24 | m.weight.data.fill_(1) 25 | m.bias.data.zero_() 26 | 27 | def get_parameter_groups(self, print_fn=print): 28 | groups = ([], [], [], []) 29 | 30 | for name, value in self.named_parameters(): 31 | # pretrained weights 32 | if 'model' in name: 33 | if 'weight' in name: 34 | # print_fn(f'pretrained weights : {name}') 35 | groups[0].append(value) 36 | else: 37 | # print_fn(f'pretrained bias : {name}') 38 | groups[1].append(value) 39 | 40 | # scracthed weights 41 | else: 42 | if 'weight' in name: 43 | if print_fn is not None: 44 | print_fn(f'scratched weights : {name}') 45 | groups[2].append(value) 46 | else: 47 | if print_fn is not None: 48 | print_fn(f'scratched bias : {name}') 49 | groups[3].append(value) 50 | return groups 51 | -------------------------------------------------------------------------------- /core/aff_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | class PathIndex: 6 | def __init__(self, radius, default_size): 7 | self.radius = radius 8 | self.radius_floor = int(np.ceil(radius) - 1) 9 | 10 | self.search_paths, self.search_dst = self.get_search_paths_dst(self.radius) 11 | self.path_indices, self.src_indices, self.dst_indices = self.get_path_indices(default_size) 12 | 13 | def get_search_paths_dst(self, max_radius=5): 14 | coord_indices_by_length = [[] for _ in range(max_radius * 4)] 15 | 16 | search_dirs = [] 17 | for x in range(1, max_radius): 18 | search_dirs.append((0, x)) 19 | 20 | for y in range(1, max_radius): 21 | for x in range(-max_radius + 1, max_radius): 22 | if x * x + y * y < max_radius ** 2: 23 | search_dirs.append((y, x)) 24 | 25 | for dir in search_dirs: 26 | length_sq = dir[0] ** 2 + dir[1] ** 2 27 | path_coords = [] 28 | 29 | min_y, max_y = sorted((0, dir[0])) 30 | min_x, max_x = sorted((0, dir[1])) 31 | 32 | for y in range(min_y, max_y + 1): 33 | for x in range(min_x, max_x + 1): 34 | 35 | dist_sq = (dir[0] * x - dir[1] * y) ** 2 / length_sq 36 | 37 | if dist_sq < 1: 38 | path_coords.append([y, x]) 39 | 40 | path_coords.sort(key=lambda x: -abs(x[0]) - abs(x[1])) 41 | path_length = len(path_coords) 42 | 43 | coord_indices_by_length[path_length].append(path_coords) 44 | 45 | path_list_by_length = [np.asarray(v) for v in coord_indices_by_length if v] 46 | path_destinations = np.concatenate([p[:, 0] for p in path_list_by_length], axis=0) 47 | 48 | return path_list_by_length, path_destinations 49 | 50 | def get_path_indices(self, size): 51 | full_indices = np.reshape(np.arange(0, size[0] * size[1], dtype=np.int64), (size[0], size[1])) 52 | 53 | cropped_height = size[0] - self.radius_floor 54 | cropped_width = size[1] - 2 * self.radius_floor 55 | 56 | path_indices = [] 57 | for paths in self.search_paths: 58 | 59 | path_indices_list = [] 60 | for p in paths: 61 | coord_indices_list = [] 62 | 63 | for dy, dx in p: 64 | coord_indices = full_indices[dy:dy + cropped_height, 65 | self.radius_floor + dx:self.radius_floor + dx + cropped_width] 66 | coord_indices = np.reshape(coord_indices, [-1]) 67 | 68 | coord_indices_list.append(coord_indices) 69 | 70 | path_indices_list.append(coord_indices_list) 71 | 72 | path_indices.append(np.array(path_indices_list)) 73 | 74 | src_indices = 
np.reshape(full_indices[:cropped_height, self.radius_floor:self.radius_floor + cropped_width], -1) 75 | dst_indices = np.concatenate([p[:,0] for p in path_indices], axis=0) 76 | 77 | return path_indices, src_indices, dst_indices 78 | 79 | 80 | def edge_to_affinity(edge, paths_indices): 81 | aff_list = [] 82 | edge = edge.view(edge.size(0), -1) 83 | 84 | for i in range(len(paths_indices)): 85 | if isinstance(paths_indices[i], np.ndarray): 86 | paths_indices[i] = torch.from_numpy(paths_indices[i]) 87 | paths_indices[i] = paths_indices[i].cuda(non_blocking=True) 88 | 89 | for ind in paths_indices: 90 | ind_flat = ind.view(-1) 91 | dist = torch.index_select(edge, dim=-1, index=ind_flat) 92 | dist = dist.view(dist.size(0), ind.size(0), ind.size(1), ind.size(2)) 93 | aff = torch.squeeze(1 - F.max_pool2d(dist, (dist.size(2), 1)), dim=2) 94 | aff_list.append(aff) 95 | aff_cat = torch.cat(aff_list, dim=1) 96 | 97 | return aff_cat 98 | 99 | 100 | def affinity_sparse2dense(affinity_sparse, ind_from, ind_to, n_vertices): 101 | ind_from = torch.from_numpy(ind_from) 102 | ind_to = torch.from_numpy(ind_to) 103 | 104 | affinity_sparse = affinity_sparse.view(-1).cpu() 105 | ind_from = ind_from.repeat(ind_to.size(0)).view(-1) 106 | ind_to = ind_to.view(-1) 107 | 108 | indices = torch.stack([ind_from, ind_to]) 109 | indices_tp = torch.stack([ind_to, ind_from]) 110 | 111 | indices_id = torch.stack([torch.arange(0, n_vertices).long(), torch.arange(0, n_vertices).long()]) 112 | 113 | affinity_dense = torch.sparse.FloatTensor(torch.cat([indices, indices_id, indices_tp], dim=1), 114 | torch.cat([affinity_sparse, torch.ones([n_vertices]), affinity_sparse])).to_dense().cuda() 115 | 116 | return affinity_dense 117 | 118 | 119 | def to_transition_matrix(affinity_dense, beta, times): 120 | scaled_affinity = torch.pow(affinity_dense, beta) 121 | 122 | trans_mat = scaled_affinity / torch.sum(scaled_affinity, dim=0, keepdim=True) 123 | for _ in range(times): 124 | trans_mat = torch.matmul(trans_mat, trans_mat) 125 | 126 | return trans_mat 127 | 128 | def propagate_to_edge(x, edge, radius=5, beta=10, exp_times=8): 129 | height, width = x.shape[-2:] 130 | 131 | hor_padded = width+radius*2 132 | ver_padded = height+radius 133 | 134 | path_index = PathIndex(radius=radius, default_size=(ver_padded, hor_padded)) 135 | 136 | edge_padded = F.pad(edge, (radius, radius, 0, radius), mode='constant', value=1.0) 137 | sparse_aff = edge_to_affinity(torch.unsqueeze(edge_padded, 0), 138 | path_index.path_indices) 139 | 140 | dense_aff = affinity_sparse2dense(sparse_aff, path_index.src_indices, 141 | path_index.dst_indices, ver_padded * hor_padded) 142 | dense_aff = dense_aff.view(ver_padded, hor_padded, ver_padded, hor_padded) 143 | dense_aff = dense_aff[:-radius, radius:-radius, :-radius, radius:-radius] 144 | dense_aff = dense_aff.reshape(height * width, height * width) 145 | 146 | trans_mat = to_transition_matrix(dense_aff, beta=beta, times=exp_times) 147 | 148 | x = x.view(-1, height, width) * (1 - edge) 149 | 150 | rw = torch.matmul(x.view(-1, height * width), trans_mat) 151 | rw = rw.view(rw.size(0), 1, height, width) 152 | 153 | return rw 154 | 155 | class GetAffinityLabelFromIndices(): 156 | def __init__(self, indices_from, indices_to): 157 | self.indices_from = indices_from 158 | self.indices_to = indices_to 159 | 160 | def __call__(self, segm_map): 161 | segm_map_flat = np.reshape(segm_map, -1) 162 | 163 | segm_label_from = np.expand_dims(segm_map_flat[self.indices_from], axis=0) 164 | segm_label_to = 
segm_map_flat[self.indices_to] 165 | 166 | valid_label = np.logical_and(np.less(segm_label_from, 21), np.less(segm_label_to, 21)) 167 | 168 | equal_label = np.equal(segm_label_from, segm_label_to) 169 | 170 | pos_affinity_label = np.logical_and(equal_label, valid_label) 171 | 172 | bg_pos_affinity_label = np.logical_and(pos_affinity_label, np.equal(segm_label_from, 0)).astype(np.float32) 173 | fg_pos_affinity_label = np.logical_and(pos_affinity_label, np.greater(segm_label_from, 0)).astype(np.float32) 174 | 175 | neg_affinity_label = np.logical_and(np.logical_not(equal_label), valid_label).astype(np.float32) 176 | 177 | return torch.from_numpy(bg_pos_affinity_label), torch.from_numpy(fg_pos_affinity_label), torch.from_numpy(neg_affinity_label) 178 | 179 | -------------------------------------------------------------------------------- /core/arch_resnest/__pycache__/resnest.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/arch_resnest/__pycache__/resnest.cpython-38.pyc -------------------------------------------------------------------------------- /core/arch_resnest/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/arch_resnest/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /core/arch_resnest/__pycache__/splat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/arch_resnest/__pycache__/splat.cpython-38.pyc -------------------------------------------------------------------------------- /core/arch_resnest/resnest(1).py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## Email: zhanghang0704@gmail.com 4 | ## Copyright (c) 2020 5 | ## 6 | ## LICENSE file in the root directory of this source tree 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | """ResNeSt models""" 9 | 10 | import torch 11 | from .resnet import ResNet, Bottleneck 12 | 13 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] 14 | 15 | _url_format = 'https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/{}-{}.pth' 16 | 17 | _model_sha256 = {name: checksum for checksum, name in [ 18 | ('528c19ca', 'resnest50'), 19 | ('22405ba7', 'resnest101'), 20 | ('75117900', 'resnest200'), 21 | ('0cc87c48', 'resnest269'), 22 | ]} 23 | 24 | def short_hash(name): 25 | if name not in _model_sha256: 26 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 27 | return _model_sha256[name][:8] 28 | 29 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for 30 | name in _model_sha256.keys() 31 | } 32 | 33 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 34 | model = ResNet(Bottleneck, [3, 4, 6, 3], 35 | radix=2, groups=1, bottleneck_width=64, 36 | deep_stem=True, stem_width=32, avg_down=True, 37 | avd=True, avd_first=False, **kwargs) 38 | if pretrained: 39 | model.load_state_dict(torch.hub.load_state_dict_from_url( 40 | resnest_model_urls['resnest50'], 
progress=True, check_hash=True)) 41 | return model 42 | 43 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): 44 | model = ResNet(Bottleneck, [3, 4, 23, 3], 45 | radix=2, groups=1, bottleneck_width=64, 46 | deep_stem=True, stem_width=64, avg_down=True, 47 | avd=True, avd_first=False, **kwargs) 48 | if pretrained: 49 | model.load_state_dict(torch.hub.load_state_dict_from_url( 50 | resnest_model_urls['resnest101'], progress=True, check_hash=True)) 51 | return model 52 | 53 | def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): 54 | model = ResNet(Bottleneck, [3, 24, 36, 3], 55 | radix=2, groups=1, bottleneck_width=64, 56 | deep_stem=True, stem_width=64, avg_down=True, 57 | avd=True, avd_first=False, **kwargs) 58 | if pretrained: 59 | model.load_state_dict(torch.hub.load_state_dict_from_url( 60 | resnest_model_urls['resnest200'], progress=True, check_hash=True)) 61 | return model 62 | 63 | def resnest269(pretrained=False, root='~/.encoding/models', **kwargs): 64 | model = ResNet(Bottleneck, [3, 30, 48, 8], 65 | radix=2, groups=1, bottleneck_width=64, 66 | deep_stem=True, stem_width=64, avg_down=True, 67 | avd=True, avd_first=False, **kwargs) 68 | if pretrained: 69 | model.load_state_dict(torch.hub.load_state_dict_from_url( 70 | resnest_model_urls['resnest269'], progress=True, check_hash=True)) 71 | return model 72 | -------------------------------------------------------------------------------- /core/arch_resnest/resnest.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## Email: zhanghang0704@gmail.com 4 | ## Copyright (c) 2020 5 | ## 6 | ## LICENSE file in the root directory of this source tree 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | """ResNeSt models""" 9 | 10 | import torch 11 | from .resnet import ResNet, Bottleneck 12 | 13 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] 14 | 15 | _url_format = 'https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/{}-{}.pth' 16 | 17 | _model_sha256 = {name: checksum for checksum, name in [ 18 | ('528c19ca', 'resnest50'), 19 | ('22405ba7', 'resnest101'), 20 | ('75117900', 'resnest200'), 21 | ('0cc87c48', 'resnest269'), 22 | ]} 23 | 24 | def short_hash(name): 25 | if name not in _model_sha256: 26 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 27 | return _model_sha256[name][:8] 28 | 29 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for 30 | name in _model_sha256.keys() 31 | } 32 | 33 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 34 | model = ResNet(Bottleneck, [3, 4, 6, 3], 35 | radix=2, groups=1, bottleneck_width=64, 36 | deep_stem=True, stem_width=32, avg_down=True, 37 | avd=True, avd_first=False, **kwargs) 38 | if pretrained: 39 | model.load_state_dict(torch.hub.load_state_dict_from_url( 40 | resnest_model_urls['resnest50'], progress=True, check_hash=True)) 41 | return model 42 | 43 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): 44 | model = ResNet(Bottleneck, [3, 4, 23, 3], 45 | radix=2, groups=1, bottleneck_width=64, 46 | deep_stem=True, stem_width=64, avg_down=True, 47 | avd=True, avd_first=False, **kwargs) 48 | if pretrained: 49 | model.load_state_dict(torch.hub.load_state_dict_from_url( 50 | resnest_model_urls['resnest101'], progress=True, check_hash=True)) 51 | 
return model 52 | 53 | def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): 54 | model = ResNet(Bottleneck, [3, 24, 36, 3], 55 | radix=2, groups=1, bottleneck_width=64, 56 | deep_stem=True, stem_width=64, avg_down=True, 57 | avd=True, avd_first=False, **kwargs) 58 | if pretrained: 59 | model.load_state_dict(torch.hub.load_state_dict_from_url( 60 | resnest_model_urls['resnest200'], progress=True, check_hash=True)) 61 | return model 62 | 63 | def resnest269(pretrained=False, root='~/.encoding/models', **kwargs): 64 | model = ResNet(Bottleneck, [3, 30, 48, 8], 65 | radix=2, groups=1, bottleneck_width=64, 66 | deep_stem=True, stem_width=64, avg_down=True, 67 | avd=True, avd_first=False, **kwargs) 68 | if pretrained: 69 | model.load_state_dict(torch.hub.load_state_dict_from_url( 70 | resnest_model_urls['resnest269'], progress=True, check_hash=True)) 71 | return model 72 | -------------------------------------------------------------------------------- /core/arch_resnest/resnet.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## Email: zhanghang0704@gmail.com 4 | ## Copyright (c) 2020 5 | ## 6 | ## LICENSE file in the root directory of this source tree 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | """ResNet variants""" 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | 13 | from .splat import SplAtConv2d 14 | 15 | __all__ = ['ResNet', 'Bottleneck'] 16 | 17 | class DropBlock2D(object): 18 | def __init__(self, *args, **kwargs): 19 | raise NotImplementedError 20 | 21 | class GlobalAvgPool2d(nn.Module): 22 | def __init__(self): 23 | """Global average pooling over the input's spatial dimensions""" 24 | super(GlobalAvgPool2d, self).__init__() 25 | 26 | def forward(self, inputs): 27 | return nn.functional.adaptive_avg_pool2d(inputs, 1).view(inputs.size(0), -1) 28 | 29 | class Bottleneck(nn.Module): 30 | """ResNet Bottleneck 31 | """ 32 | # pylint: disable=unused-argument 33 | expansion = 4 34 | def __init__(self, inplanes, planes, stride=1, downsample=None, 35 | radix=1, cardinality=1, bottleneck_width=64, 36 | avd=False, avd_first=False, dilation=1, is_first=False, 37 | rectified_conv=False, rectify_avg=False, 38 | norm_layer=None, dropblock_prob=0.0, last_gamma=False): 39 | super(Bottleneck, self).__init__() 40 | group_width = int(planes * (bottleneck_width / 64.)) * cardinality 41 | self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) 42 | self.bn1 = norm_layer(group_width) 43 | self.dropblock_prob = dropblock_prob 44 | self.radix = radix 45 | self.avd = avd and (stride > 1 or is_first) 46 | self.avd_first = avd_first 47 | 48 | if self.avd: 49 | self.avd_layer = nn.AvgPool2d(3, stride, padding=1) 50 | stride = 1 51 | 52 | if dropblock_prob > 0.0: 53 | self.dropblock1 = DropBlock2D(dropblock_prob, 3) 54 | if radix == 1: 55 | self.dropblock2 = DropBlock2D(dropblock_prob, 3) 56 | self.dropblock3 = DropBlock2D(dropblock_prob, 3) 57 | 58 | if radix >= 1: 59 | self.conv2 = SplAtConv2d( 60 | group_width, group_width, kernel_size=3, 61 | stride=stride, padding=dilation, 62 | dilation=dilation, groups=cardinality, bias=False, 63 | radix=radix, rectify=rectified_conv, 64 | rectify_avg=rectify_avg, 65 | norm_layer=norm_layer, 66 | dropblock_prob=dropblock_prob) 67 | elif rectified_conv: 68 | from rfconv import RFConv2d 69 | self.conv2 = RFConv2d( 70 | group_width, 
group_width, kernel_size=3, stride=stride, 71 | padding=dilation, dilation=dilation, 72 | groups=cardinality, bias=False, 73 | average_mode=rectify_avg) 74 | self.bn2 = norm_layer(group_width) 75 | else: 76 | self.conv2 = nn.Conv2d( 77 | group_width, group_width, kernel_size=3, stride=stride, 78 | padding=dilation, dilation=dilation, 79 | groups=cardinality, bias=False) 80 | self.bn2 = norm_layer(group_width) 81 | 82 | self.conv3 = nn.Conv2d( 83 | group_width, planes * 4, kernel_size=1, bias=False) 84 | self.bn3 = norm_layer(planes*4) 85 | 86 | if last_gamma: 87 | from torch.nn.init import zeros_ 88 | zeros_(self.bn3.weight) 89 | self.relu = nn.ReLU(inplace=True) 90 | self.downsample = downsample 91 | self.dilation = dilation 92 | self.stride = stride 93 | 94 | def forward(self, x): 95 | residual = x 96 | 97 | out = self.conv1(x) 98 | out = self.bn1(out) 99 | if self.dropblock_prob > 0.0: 100 | out = self.dropblock1(out) 101 | out = self.relu(out) 102 | 103 | if self.avd and self.avd_first: 104 | out = self.avd_layer(out) 105 | 106 | out = self.conv2(out) 107 | if self.radix == 0: 108 | out = self.bn2(out) 109 | if self.dropblock_prob > 0.0: 110 | out = self.dropblock2(out) 111 | out = self.relu(out) 112 | 113 | if self.avd and not self.avd_first: 114 | out = self.avd_layer(out) 115 | 116 | out = self.conv3(out) 117 | out = self.bn3(out) 118 | if self.dropblock_prob > 0.0: 119 | out = self.dropblock3(out) 120 | 121 | if self.downsample is not None: 122 | residual = self.downsample(x) 123 | 124 | out += residual 125 | out = self.relu(out) 126 | 127 | return out 128 | 129 | class ResNet(nn.Module): 130 | """ResNet Variants 131 | 132 | Parameters 133 | ---------- 134 | block : Block 135 | Class for the residual block. Options are BasicBlockV1, BottleneckV1. 136 | layers : list of int 137 | Numbers of layers in each block 138 | classes : int, default 1000 139 | Number of classification classes. 140 | dilated : bool, default False 141 | Applying dilation strategy to pretrained ResNet yielding a stride-8 model, 142 | typically used in Semantic Segmentation. 143 | norm_layer : object 144 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; 145 | for Synchronized Cross-GPU BachNormalization). 146 | 147 | Reference: 148 | 149 | - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. 150 | 151 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 
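    Example (a minimal sketch added for illustration, mirroring the
    `resnest50` constructor arguments in resnest.py; not part of the
    original docstring)::

        import torch
        import torch.nn as nn

        # `dilated=True` gives the stride-8 variant described above
        net = ResNet(Bottleneck, [3, 4, 6, 3],
                     radix=2, groups=1, bottleneck_width=64,
                     deep_stem=True, stem_width=32, avg_down=True,
                     avd=True, avd_first=False,
                     dilated=True, norm_layer=nn.BatchNorm2d)
        logits = net(torch.randn(1, 3, 224, 224))  # -> (1, 1000)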
152 | """ 153 | # pylint: disable=unused-variable 154 | def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64, 155 | num_classes=1000, dilated=False, dilation=1, 156 | deep_stem=False, stem_width=64, avg_down=False, 157 | rectified_conv=False, rectify_avg=False, 158 | avd=False, avd_first=False, 159 | final_drop=0.0, dropblock_prob=0, 160 | last_gamma=False, norm_layer=nn.BatchNorm2d): 161 | self.cardinality = groups 162 | self.bottleneck_width = bottleneck_width 163 | # ResNet-D params 164 | self.inplanes = stem_width*2 if deep_stem else 64 165 | self.avg_down = avg_down 166 | self.last_gamma = last_gamma 167 | # ResNeSt params 168 | self.radix = radix 169 | self.avd = avd 170 | self.avd_first = avd_first 171 | 172 | super(ResNet, self).__init__() 173 | self.rectified_conv = rectified_conv 174 | self.rectify_avg = rectify_avg 175 | if rectified_conv: 176 | from rfconv import RFConv2d 177 | conv_layer = RFConv2d 178 | else: 179 | conv_layer = nn.Conv2d 180 | conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {} 181 | if deep_stem: 182 | self.conv1 = nn.Sequential( 183 | conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs), 184 | norm_layer(stem_width), 185 | nn.ReLU(inplace=True), 186 | conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs), 187 | norm_layer(stem_width), 188 | nn.ReLU(inplace=True), 189 | conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs), 190 | ) 191 | else: 192 | self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3, 193 | bias=False, **conv_kwargs) 194 | self.bn1 = norm_layer(self.inplanes) 195 | self.relu = nn.ReLU(inplace=True) 196 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 197 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False) 198 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 199 | if dilated or dilation == 4: 200 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 201 | dilation=2, norm_layer=norm_layer, 202 | dropblock_prob=dropblock_prob) 203 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 204 | dilation=4, norm_layer=norm_layer, 205 | dropblock_prob=dropblock_prob) 206 | elif dilation==2: 207 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 208 | dilation=1, norm_layer=norm_layer, 209 | dropblock_prob=dropblock_prob) 210 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 211 | dilation=2, norm_layer=norm_layer, 212 | dropblock_prob=dropblock_prob) 213 | else: 214 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 215 | norm_layer=norm_layer, 216 | dropblock_prob=dropblock_prob) 217 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 218 | norm_layer=norm_layer, 219 | dropblock_prob=dropblock_prob) 220 | 221 | self.avgpool = GlobalAvgPool2d() 222 | self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None 223 | self.fc = nn.Linear(512 * block.expansion, num_classes) 224 | 225 | for m in self.modules(): 226 | if isinstance(m, nn.Conv2d): 227 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 228 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 229 | elif isinstance(m, norm_layer): 230 | m.weight.data.fill_(1) 231 | m.bias.data.zero_() 232 | 233 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, 234 | dropblock_prob=0.0, is_first=True): 235 | downsample = None 236 | if stride != 1 or self.inplanes != planes * block.expansion: 237 | down_layers = [] 238 | if self.avg_down: 239 | if dilation == 1: 240 | down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride, 241 | ceil_mode=True, count_include_pad=False)) 242 | else: 243 | down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1, 244 | ceil_mode=True, count_include_pad=False)) 245 | down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, 246 | kernel_size=1, stride=1, bias=False)) 247 | else: 248 | down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, 249 | kernel_size=1, stride=stride, bias=False)) 250 | down_layers.append(norm_layer(planes * block.expansion)) 251 | downsample = nn.Sequential(*down_layers) 252 | 253 | layers = [] 254 | if dilation == 1 or dilation == 2: 255 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 256 | radix=self.radix, cardinality=self.cardinality, 257 | bottleneck_width=self.bottleneck_width, 258 | avd=self.avd, avd_first=self.avd_first, 259 | dilation=1, is_first=is_first, rectified_conv=self.rectified_conv, 260 | rectify_avg=self.rectify_avg, 261 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 262 | last_gamma=self.last_gamma)) 263 | elif dilation == 4: 264 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 265 | radix=self.radix, cardinality=self.cardinality, 266 | bottleneck_width=self.bottleneck_width, 267 | avd=self.avd, avd_first=self.avd_first, 268 | dilation=2, is_first=is_first, rectified_conv=self.rectified_conv, 269 | rectify_avg=self.rectify_avg, 270 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 271 | last_gamma=self.last_gamma)) 272 | else: 273 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 274 | 275 | self.inplanes = planes * block.expansion 276 | for i in range(1, blocks): 277 | layers.append(block(self.inplanes, planes, 278 | radix=self.radix, cardinality=self.cardinality, 279 | bottleneck_width=self.bottleneck_width, 280 | avd=self.avd, avd_first=self.avd_first, 281 | dilation=dilation, rectified_conv=self.rectified_conv, 282 | rectify_avg=self.rectify_avg, 283 | norm_layer=norm_layer, dropblock_prob=dropblock_prob, 284 | last_gamma=self.last_gamma)) 285 | 286 | return nn.Sequential(*layers) 287 | 288 | def forward(self, x): 289 | x = self.conv1(x) 290 | x = self.bn1(x) 291 | x = self.relu(x) 292 | x = self.maxpool(x) 293 | 294 | x = self.layer1(x) 295 | x = self.layer2(x) 296 | x = self.layer3(x) 297 | x = self.layer4(x) 298 | 299 | # print(x.size()) 300 | 301 | x = self.avgpool(x) 302 | #x = x.view(x.size(0), -1) 303 | x = torch.flatten(x, 1) 304 | if self.drop: 305 | x = self.drop(x) 306 | x = self.fc(x) 307 | 308 | return x 309 | -------------------------------------------------------------------------------- /core/arch_resnest/splat(1).py: -------------------------------------------------------------------------------- 1 | """Split-Attention""" 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU 7 | from torch.nn.modules.utils import _pair 8 | 9 | __all__ = ['SplAtConv2d'] 10 | 11 | class SplAtConv2d(Module): 12 | """Split-Attention Conv2d 13 | """ 14 | def 
__init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0), 15 | dilation=(1, 1), groups=1, bias=True, 16 | radix=2, reduction_factor=4, 17 | rectify=False, rectify_avg=False, norm_layer=None, 18 | dropblock_prob=0.0, **kwargs): 19 | super(SplAtConv2d, self).__init__() 20 | padding = _pair(padding) 21 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 22 | self.rectify_avg = rectify_avg 23 | inter_channels = max(in_channels*radix//reduction_factor, 32) 24 | self.radix = radix 25 | self.cardinality = groups 26 | self.channels = channels 27 | self.dropblock_prob = dropblock_prob 28 | if self.rectify: 29 | from rfconv import RFConv2d 30 | self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 31 | groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs) 32 | else: 33 | self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 34 | groups=groups*radix, bias=bias, **kwargs) 35 | self.use_bn = norm_layer is not None 36 | if self.use_bn: 37 | self.bn0 = norm_layer(channels*radix) 38 | self.relu = ReLU(inplace=True) 39 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 40 | if self.use_bn: 41 | self.bn1 = norm_layer(inter_channels) 42 | self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality) 43 | if dropblock_prob > 0.0: 44 | self.dropblock = DropBlock2D(dropblock_prob, 3) 45 | self.rsoftmax = rSoftMax(radix, groups) 46 | 47 | def forward(self, x): 48 | x = self.conv(x) 49 | if self.use_bn: 50 | x = self.bn0(x) 51 | if self.dropblock_prob > 0.0: 52 | x = self.dropblock(x) 53 | x = self.relu(x) 54 | 55 | batch, rchannel = x.shape[:2] 56 | if self.radix > 1: 57 | if torch.__version__ < '1.5': 58 | splited = torch.split(x, int(rchannel//self.radix), dim=1) 59 | else: 60 | splited = torch.split(x, rchannel//self.radix, dim=1) 61 | gap = sum(splited) 62 | else: 63 | gap = x 64 | gap = F.adaptive_avg_pool2d(gap, 1) 65 | gap = self.fc1(gap) 66 | 67 | if self.use_bn: 68 | gap = self.bn1(gap) 69 | gap = self.relu(gap) 70 | 71 | atten = self.fc2(gap) 72 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 73 | 74 | if self.radix > 1: 75 | if torch.__version__ < '1.5': 76 | attens = torch.split(atten, int(rchannel//self.radix), dim=1) 77 | else: 78 | attens = torch.split(atten, rchannel//self.radix, dim=1) 79 | out = sum([att*split for (att, split) in zip(attens, splited)]) 80 | else: 81 | out = atten * x 82 | return out.contiguous() 83 | 84 | class rSoftMax(nn.Module): 85 | def __init__(self, radix, cardinality): 86 | super().__init__() 87 | self.radix = radix 88 | self.cardinality = cardinality 89 | 90 | def forward(self, x): 91 | batch = x.size(0) 92 | if self.radix > 1: 93 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 94 | x = F.softmax(x, dim=1) 95 | x = x.reshape(batch, -1) 96 | else: 97 | x = torch.sigmoid(x) 98 | return x 99 | 100 | -------------------------------------------------------------------------------- /core/arch_resnest/splat.py: -------------------------------------------------------------------------------- 1 | """Split-Attention""" 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU 7 | from torch.nn.modules.utils import _pair 8 | 9 | __all__ = ['SplAtConv2d'] 10 | 11 | class SplAtConv2d(Module): 12 | """Split-Attention Conv2d 13 | """ 14 | def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), 
padding=(0, 0), 15 | dilation=(1, 1), groups=1, bias=True, 16 | radix=2, reduction_factor=4, 17 | rectify=False, rectify_avg=False, norm_layer=None, 18 | dropblock_prob=0.0, **kwargs): 19 | super(SplAtConv2d, self).__init__() 20 | padding = _pair(padding) 21 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 22 | self.rectify_avg = rectify_avg 23 | inter_channels = max(in_channels*radix//reduction_factor, 32) 24 | self.radix = radix 25 | self.cardinality = groups 26 | self.channels = channels 27 | self.dropblock_prob = dropblock_prob 28 | if self.rectify: 29 | from rfconv import RFConv2d 30 | self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 31 | groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs) 32 | else: 33 | self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 34 | groups=groups*radix, bias=bias, **kwargs) 35 | self.use_bn = norm_layer is not None 36 | if self.use_bn: 37 | self.bn0 = norm_layer(channels*radix) 38 | self.relu = ReLU(inplace=True) 39 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 40 | if self.use_bn: 41 | self.bn1 = norm_layer(inter_channels) 42 | self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality) 43 | if dropblock_prob > 0.0: 44 | self.dropblock = DropBlock2D(dropblock_prob, 3) 45 | self.rsoftmax = rSoftMax(radix, groups) 46 | 47 | def forward(self, x): 48 | x = self.conv(x) 49 | if self.use_bn: 50 | x = self.bn0(x) 51 | if self.dropblock_prob > 0.0: 52 | x = self.dropblock(x) 53 | x = self.relu(x) 54 | 55 | batch, rchannel = x.shape[:2] 56 | if self.radix > 1: 57 | if torch.__version__ < '1.5': 58 | splited = torch.split(x, int(rchannel//self.radix), dim=1) 59 | else: 60 | splited = torch.split(x, rchannel//self.radix, dim=1) 61 | gap = sum(splited) 62 | else: 63 | gap = x 64 | gap = F.adaptive_avg_pool2d(gap, 1) 65 | gap = self.fc1(gap) 66 | 67 | if self.use_bn: 68 | gap = self.bn1(gap) 69 | gap = self.relu(gap) 70 | 71 | atten = self.fc2(gap) 72 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 73 | 74 | if self.radix > 1: 75 | if torch.__version__ < '1.5': 76 | attens = torch.split(atten, int(rchannel//self.radix), dim=1) 77 | else: 78 | attens = torch.split(atten, rchannel//self.radix, dim=1) 79 | out = sum([att*split for (att, split) in zip(attens, splited)]) 80 | else: 81 | out = atten * x 82 | return out.contiguous() 83 | 84 | class rSoftMax(nn.Module): 85 | def __init__(self, radix, cardinality): 86 | super().__init__() 87 | self.radix = radix 88 | self.cardinality = cardinality 89 | 90 | def forward(self, x): 91 | batch = x.size(0) 92 | if self.radix > 1: 93 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 94 | x = F.softmax(x, dim=1) 95 | x = x.reshape(batch, -1) 96 | else: 97 | x = torch.sigmoid(x) 98 | return x 99 | 100 | -------------------------------------------------------------------------------- /core/arch_resnet/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/arch_resnet/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /core/arch_resnet/resnet(1).py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 
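# A minimal sketch (the helper name is hypothetical, not part of the original
# file) of how the `urls_dic` and `layers_dic` tables below are meant to be
# used together with the `model_zoo` import above: build the trimmed backbone
# and load the matching ImageNet checkpoint. `strict=False` is needed because
# this ResNet drops the avgpool/fc head, so the checkpoint's fc.* tensors have
# no destination in the model.
#
#     def build_resnet50_backbone(pretrained=True):
#         model = ResNet(Bottleneck, layers_dic['resnet50'])  # classes defined below
#         if pretrained:
#             state_dict = model_zoo.load_url(urls_dic['resnet50'])
#             model.load_state_dict(state_dict, strict=False)
#         return model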
urls_dic = { 6 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 7 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 8 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 9 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 10 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 11 | } 12 | 13 | layers_dic = { 14 | 'resnet18' : [2, 2, 2, 2], 15 | 'resnet34' : [3, 4, 6, 3], 16 | 'resnet50' : [3, 4, 6, 3], 17 | 'resnet101' : [3, 4, 23, 3], 18 | 'resnet152' : [3, 8, 36, 3] 19 | } 20 | 21 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=dilation, groups=groups, bias=False, dilation=dilation) 25 | 26 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 27 | """1x1 convolution""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 29 | 30 | class BasicBlock(nn.Module): 31 | expansion: int = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, batch_norm_fn=nn.BatchNorm2d): 34 | super(BasicBlock, self).__init__() 35 | 36 | self.conv1 = conv3x3(inplanes, planes, stride) 37 | self.bn1 = batch_norm_fn(planes) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.conv2 = conv3x3(planes, planes) 40 | self.bn2 = batch_norm_fn(planes) 41 | self.downsample = downsample 42 | self.stride = stride 43 | 44 | def forward(self, x): 45 | identity = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | 57 | out += identity 58 | out = self.relu(out) 59 | 60 | return out 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, batch_norm_fn=nn.BatchNorm2d): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = batch_norm_fn(planes) 69 | 70 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 71 | padding=dilation, bias=False, dilation=dilation) 72 | self.bn2 = batch_norm_fn(planes) 73 | 74 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 75 | self.bn3 = batch_norm_fn(planes * 4) 76 | 77 | self.relu = nn.ReLU(inplace=True) 78 | self.downsample = downsample 79 | self.stride = stride 80 | self.dilation = dilation 81 | 82 | def forward(self, x): 83 | residual = x 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv3(out) 94 | out = self.bn3(out) 95 | 96 | if self.downsample is not None: 97 | residual = self.downsample(x) 98 | 99 | out += residual 100 | out = self.relu(out) 101 | 102 | return out 103 | 104 | class ResNet(nn.Module): 105 | 106 | def __init__(self, block, layers, strides=(2, 2, 2, 2), dilations=(1, 1, 1, 1), batch_norm_fn=nn.BatchNorm2d): 107 | self.batch_norm_fn = batch_norm_fn 108 | 109 | self.inplanes = 64 110 | super(ResNet, self).__init__() 111 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=strides[0], padding=3, 112 | bias=False) 113 | self.bn1 = self.batch_norm_fn(64) 114 | self.relu = 
nn.ReLU(inplace=True) 115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 116 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1, dilation=dilations[0]) 117 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1]) 118 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2]) 119 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3]) 120 | self.inplanes = 1024 121 | 122 | #self.avgpool = nn.AvgPool2d(7, stride=1) 123 | #self.fc = nn.Linear(512 * block.expansion, 1000) 124 | 125 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 126 | downsample = None 127 | if stride != 1 or self.inplanes != planes * block.expansion: 128 | downsample = nn.Sequential( 129 | nn.Conv2d(self.inplanes, planes * block.expansion, 130 | kernel_size=1, stride=stride, bias=False), 131 | self.batch_norm_fn(planes * block.expansion), 132 | ) 133 | 134 | layers = [block(self.inplanes, planes, stride, downsample, dilation=1, batch_norm_fn=self.batch_norm_fn)] 135 | self.inplanes = planes * block.expansion 136 | for i in range(1, blocks): 137 | layers.append(block(self.inplanes, planes, dilation=dilation, batch_norm_fn=self.batch_norm_fn)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.relu(x) 145 | x = self.maxpool(x) 146 | 147 | x = self.layer1(x) 148 | x = self.layer2(x) 149 | x = self.layer3(x) 150 | x = self.layer4(x) 151 | 152 | x = self.avgpool(x) 153 | x = x.view(x.size(0), -1) 154 | x = self.fc(x) 155 | 156 | return x 157 | 158 | -------------------------------------------------------------------------------- /core/arch_resnet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | urls_dic = { 6 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 7 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 8 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 9 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 10 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 11 | } 12 | 13 | layers_dic = { 14 | 'resnet18' : [2, 2, 2, 2], 15 | 'resnet34' : [3, 4, 6, 3], 16 | 'resnet50' : [3, 4, 6, 3], 17 | 'resnet101' : [3, 4, 23, 3], 18 | 'resnet152' : [3, 8, 36, 3] 19 | } 20 | 21 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=dilation, groups=groups, bias=False, dilation=dilation) 25 | 26 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 27 | """1x1 convolution""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 29 | 30 | class BasicBlock(nn.Module): 31 | expansion: int = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, batch_norm_fn=nn.BatchNorm2d): 34 | super(BasicBlock, self).__init__() 35 | 36 | self.conv1 = conv3x3(inplanes, planes, stride) 37 | self.bn1 = batch_norm_fn(planes) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.conv2 = conv3x3(planes, planes) 40 | self.bn2 = batch_norm_fn(planes) 41 | 
self.downsample = downsample 42 | self.stride = stride 43 | 44 | def forward(self, x): 45 | identity = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | 57 | out += identity 58 | out = self.relu(out) 59 | 60 | return out 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, batch_norm_fn=nn.BatchNorm2d): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = batch_norm_fn(planes) 69 | 70 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 71 | padding=dilation, bias=False, dilation=dilation) 72 | self.bn2 = batch_norm_fn(planes) 73 | 74 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 75 | self.bn3 = batch_norm_fn(planes * 4) 76 | 77 | self.relu = nn.ReLU(inplace=True) 78 | self.downsample = downsample 79 | self.stride = stride 80 | self.dilation = dilation 81 | 82 | def forward(self, x): 83 | residual = x 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv3(out) 94 | out = self.bn3(out) 95 | 96 | if self.downsample is not None: 97 | residual = self.downsample(x) 98 | 99 | out += residual 100 | out = self.relu(out) 101 | 102 | return out 103 | 104 | class ResNet(nn.Module): 105 | 106 | def __init__(self, block, layers, strides=(2, 2, 2, 2), dilations=(1, 1, 1, 1), batch_norm_fn=nn.BatchNorm2d): 107 | self.batch_norm_fn = batch_norm_fn 108 | 109 | self.inplanes = 64 110 | super(ResNet, self).__init__() 111 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=strides[0], padding=3, 112 | bias=False) 113 | self.bn1 = self.batch_norm_fn(64) 114 | self.relu = nn.ReLU(inplace=True) 115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 116 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1, dilation=dilations[0]) 117 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1]) 118 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2]) 119 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3]) 120 | self.inplanes = 1024 121 | 122 | #self.avgpool = nn.AvgPool2d(7, stride=1) 123 | #self.fc = nn.Linear(512 * block.expansion, 1000) 124 | 125 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 126 | downsample = None 127 | if stride != 1 or self.inplanes != planes * block.expansion: 128 | downsample = nn.Sequential( 129 | nn.Conv2d(self.inplanes, planes * block.expansion, 130 | kernel_size=1, stride=stride, bias=False), 131 | self.batch_norm_fn(planes * block.expansion), 132 | ) 133 | 134 | layers = [block(self.inplanes, planes, stride, downsample, dilation=1, batch_norm_fn=self.batch_norm_fn)] 135 | self.inplanes = planes * block.expansion 136 | for i in range(1, blocks): 137 | layers.append(block(self.inplanes, planes, dilation=dilation, batch_norm_fn=self.batch_norm_fn)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.relu(x) 145 | x = self.maxpool(x) 146 | 147 | x = self.layer1(x) 148 | x = self.layer2(x) 149 | x = self.layer3(x) 150 | x 
= self.layer4(x)
151 | 
152 |         x = self.avgpool(x)
153 |         x = x.view(x.size(0), -1)
154 |         x = self.fc(x)
155 | 
156 |         return x
157 | 
158 | 
--------------------------------------------------------------------------------
/core/config.yaml:
--------------------------------------------------------------------------------
  1 | CUDNN:
  2 |   BENCHMARK: true
  3 |   DETERMINISTIC: false
  4 |   ENABLED: true
  5 | GPUS: (0,1,2,3)
  6 | OUTPUT_DIR: 'output'
  7 | LOG_DIR: 'log'
  8 | WORKERS: 4
  9 | PRINT_FREQ: 100
 10 | 
 11 | DATASET:
 12 |   DATASET: pascal_ctx
 13 |   ROOT: 'data/'
 14 |   TEST_SET: 'val'
 15 |   TRAIN_SET: 'train'
 16 |   NUM_CLASSES: 59
 17 | MODEL:
 18 |   NAME: seg_hrnet
 19 |   PRETRAINED: 'pretrained_models/hrnetv2_w48_imagenet_pretrained.pth'
 20 |   EXTRA:
 21 |     FINAL_CONV_KERNEL: 1
 22 |     STAGE1:
 23 |       NUM_MODULES: 1
 24 |       NUM_BRANCHES: 1
 25 |       BLOCK: BOTTLENECK
 26 |       NUM_BLOCKS:
 27 |       - 4
 28 |       NUM_CHANNELS:
 29 |       - 64
 30 |       FUSE_METHOD: SUM
 31 |     STAGE2:
 32 |       NUM_MODULES: 1
 33 |       NUM_BRANCHES: 2
 34 |       BLOCK: BASIC
 35 |       NUM_BLOCKS:
 36 |       - 4
 37 |       - 4
 38 |       NUM_CHANNELS:
 39 |       - 48
 40 |       - 96
 41 |       FUSE_METHOD: SUM
 42 |     STAGE3:
 43 |       NUM_MODULES: 4
 44 |       NUM_BRANCHES: 3
 45 |       BLOCK: BASIC
 46 |       NUM_BLOCKS:
 47 |       - 4
 48 |       - 4
 49 |       - 4
 50 |       NUM_CHANNELS:
 51 |       - 48
 52 |       - 96
 53 |       - 192
 54 |       FUSE_METHOD: SUM
 55 |     STAGE4:
 56 |       NUM_MODULES: 3
 57 |       NUM_BRANCHES: 4
 58 |       BLOCK: BASIC
 59 |       NUM_BLOCKS:
 60 |       - 4
 61 |       - 4
 62 |       - 4
 63 |       - 4
 64 |       NUM_CHANNELS:
 65 |       - 48
 66 |       - 96
 67 |       - 192
 68 |       - 384
 69 |       FUSE_METHOD: SUM
 70 | LOSS:
 71 |   USE_OHEM: false
 72 |   OHEMTHRES: 0.9
 73 |   OHEMKEEP: 131072
 74 | TRAIN:
 75 |   IMAGE_SIZE:
 76 |   - 480
 77 |   - 480
 78 |   BASE_SIZE: 520
 79 |   BATCH_SIZE_PER_GPU: 4
 80 |   SHUFFLE: true
 81 |   BEGIN_EPOCH: 0
 82 |   END_EPOCH: 200
 83 |   RESUME: true
 84 |   OPTIMIZER: sgd
 85 |   LR: 0.004
 86 |   WD: 0.0001
 87 |   MOMENTUM: 0.9
 88 |   NESTEROV: false
 89 |   FLIP: true
 90 |   MULTI_SCALE: true
 91 |   DOWNSAMPLERATE: 1
 92 |   IGNORE_LABEL: -1
 93 |   SCALE_FACTOR: 16
 94 | TEST:
 95 |   IMAGE_SIZE:
 96 |   - 480
 97 |   - 480
 98 |   BASE_SIZE: 520
 99 |   BATCH_SIZE_PER_GPU: 16
100 |   FLIP_TEST: false
101 |   MULTI_SCALE: false
--------------------------------------------------------------------------------
/core/datasets.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from re import L
 3 | import cv2
 4 | import glob
 5 | import torch
 6 | 
 7 | import math
 8 | import imageio
 9 | import numpy as np
10 | import re
11 | 
12 | from PIL import Image
13 | 
14 | from core.aff_utils import *
15 | 
16 | from tools.ai.augment_utils import *
17 | from tools.ai.torch_utils import one_hot_embedding
18 | 
19 | from tools.general.xml_utils import read_xml
20 | from tools.general.json_utils import read_json
21 | from tools.dataset.voc_utils import get_color_map_dic
22 | 
23 | class Iterator:
24 |     def __init__(self, loader):
25 |         self.loader = loader
26 |         self.init()
27 | 
28 |     def init(self):
29 |         self.iterator = iter(self.loader)
30 | 
31 |     def get(self):
32 |         try:
33 |             data = next(self.iterator)
34 |         except StopIteration:
35 |             self.init()
36 |             data = next(self.iterator)
37 | 
38 |         return data
39 | 
40 | class VOC_Dataset(torch.utils.data.Dataset):
41 |     def __init__(self, root_dir, domain, with_id=False, with_tags=False, with_mask=False):
42 |         self.root_dir = root_dir
43 | 
44 |         self.image_dir = self.root_dir + '1.training/'
45 |         self.xml_dir = self.root_dir + 'Annotations/'
46 |         self.mask_dir = self.root_dir + 'SegmentationClass/'
47 | 
48 |         self.image_id_list = [image_id.strip() for image_id in open('./data/%s.txt'%domain).readlines()]
49 | 
50 | 
self.with_id = with_id 51 | self.with_tags = with_tags 52 | self.with_mask = with_mask 53 | 54 | def __len__(self): 55 | return len(self.image_id_list) 56 | 57 | def get_image(self, image_id): 58 | image = Image.open(self.image_dir + image_id + '.png').convert('RGB') 59 | return image 60 | 61 | def get_mask(self, image_id): 62 | mask_path = self.mask_dir + image_id + '.png' 63 | if os.path.isfile(mask_path): 64 | mask = Image.open(mask_path).convert('RGB') 65 | else: 66 | mask = None 67 | return mask 68 | 69 | def get_tags(self, image_id): 70 | _, tags = read_xml(self.xml_dir + image_id + '.xml') 71 | return tags 72 | 73 | def __getitem__(self, index): 74 | image_id = self.image_id_list[index] 75 | 76 | data_list = [self.get_image(image_id)] 77 | 78 | if self.with_id: 79 | data_list.append(image_id) 80 | 81 | if self.with_tags: 82 | data_list.append(self.get_tags(image_id)) 83 | 84 | if self.with_mask: 85 | data_list.append(self.get_mask(image_id)) 86 | 87 | return data_list 88 | 89 | class VOC_Dataset_For_Classification(VOC_Dataset): 90 | def __init__(self, root_dir, domain, transform=None): 91 | super().__init__(root_dir, domain, with_tags=True) 92 | self.transform = transform 93 | 94 | data = read_json('./data/VOC_2012.json') 95 | 96 | self.class_dic = data['class_dic'] 97 | self.classes = data['classes'] 98 | 99 | def __getitem__(self, index): 100 | image, tags = super().__getitem__(index) 101 | 102 | if self.transform is not None: 103 | image = self.transform(image) 104 | 105 | label = one_hot_embedding([self.class_dic[tag] for tag in tags], self.classes) 106 | return image, label 107 | 108 | class VOC_Dataset_For_Segmentation(VOC_Dataset): 109 | def __init__(self, root_dir, domain, transform=None): 110 | super().__init__(root_dir, domain, with_mask=True) 111 | self.transform = transform 112 | self.image_dir = self.root_dir + '2.validation/img_patch_256/' 113 | self.mask_dir = self.root_dir + '2.validation/mask_patch_256/' 114 | self.colors = np.array([[255, 255, 255], [0, 64, 128], [64, 128, 0], [243, 152, 0]], dtype=np.int32) 115 | 116 | def __getitem__(self, index): 117 | image, mask = super().__getitem__(index) 118 | mask = np.array(mask).astype(np.int32) 119 | mask = self.image2label(mask) 120 | if self.transform is not None: 121 | input_dic = {'image':image, 'mask':mask} 122 | output_dic = self.transform(input_dic) 123 | 124 | image = output_dic['image'] 125 | mask = output_dic['mask'] 126 | 127 | return image, mask 128 | 129 | def image2label(self, im): 130 | color2int = np.zeros(256 ** 3) # 131 | for idx, color in enumerate(self.colors): 132 | color2int[(color[0] * 256 + color[1]) * 256 + color[2]] = idx # 133 | data = np.array(im, dtype=np.int32) 134 | idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2] 135 | return np.array(color2int[idx], dtype=np.int32) # 136 | 137 | class VOC_Dataset_For_Evaluation(VOC_Dataset): 138 | def __init__(self, root_dir, domain, transform=None): 139 | super().__init__(root_dir, domain, with_id=True, with_mask=True) 140 | self.transform = transform 141 | 142 | cmap_dic, _, class_names = get_color_map_dic() 143 | self.colors = np.asarray([cmap_dic[class_name] for class_name in class_names]) 144 | 145 | def __getitem__(self, index): 146 | image, image_id, mask = super().__getitem__(index) 147 | 148 | if self.transform is not None: 149 | input_dic = {'image':image, 'mask':mask} 150 | output_dic = self.transform(input_dic) 151 | 152 | image = output_dic['image'] 153 | mask = output_dic['mask'] 154 | 155 | return image, image_id, mask 156 
| 157 | class VOC_Dataset_For_WSSS(VOC_Dataset): 158 | def __init__(self, root_dir, domain, pred_dir, transform=None): 159 | super().__init__(root_dir, domain, with_id=True) 160 | self.pred_dir = pred_dir 161 | self.transform = transform 162 | 163 | self.colors = np.array([[255, 255, 255], [0, 64, 128], [64, 128, 0], [243, 152, 0]], dtype=np.int32) 164 | 165 | def __getitem__(self, index): 166 | image, image_id = super().__getitem__(index) 167 | mask = Image.open(self.pred_dir + image_id + '.png') 168 | 169 | if self.transform is not None: 170 | input_dic = {'image':image, 'mask':mask} 171 | output_dic = self.transform(input_dic) 172 | 173 | image = output_dic['image'] 174 | mask = output_dic['mask'] 175 | 176 | return image, mask 177 | 178 | class VOC_Dataset_For_Testing_CAM(VOC_Dataset): 179 | def __init__(self, root_dir, domain, transform=None): 180 | super().__init__(root_dir, domain, with_tags=True, with_mask=True) 181 | self.transform = transform 182 | 183 | cmap_dic, _, class_names = get_color_map_dic() 184 | self.colors = np.asarray([cmap_dic[class_name] for class_name in class_names]) 185 | 186 | data = read_json('./data/VOC_2012.json') 187 | 188 | self.class_dic = data['class_dic'] 189 | self.classes = data['classes'] 190 | 191 | def __getitem__(self, index): 192 | image, tags, mask = super().__getitem__(index) 193 | 194 | if self.transform is not None: 195 | input_dic = {'image':image, 'mask':mask} 196 | output_dic = self.transform(input_dic) 197 | 198 | image = output_dic['image'] 199 | mask = output_dic['mask'] 200 | 201 | label = one_hot_embedding([self.class_dic[tag] for tag in tags], self.classes) 202 | return image, label, mask 203 | 204 | class VOC_Dataset_For_Making_CAM(VOC_Dataset): 205 | def __init__(self, root_dir, domain): 206 | super().__init__(root_dir, domain, with_id=True, with_tags=False, with_mask=False) 207 | 208 | def __getitem__(self, index): 209 | image, image_id = super().__getitem__(index) 210 | label = self.get_label(image_id) 211 | 212 | return image, image_id, label 213 | 214 | def get_label(self, img_name): 215 | res = re.findall(r"\[(.*?)\]", img_name) 216 | label = torch.tensor(list(eval(res[0]))) 217 | return label 218 | 219 | class VOC_Dataset_For_Affinity(VOC_Dataset): 220 | def __init__(self, root_dir, domain, path_index, label_dir, transform=None): 221 | super().__init__(root_dir, domain, with_id=True) 222 | 223 | self.transform = transform 224 | 225 | self.label_dir = label_dir 226 | self.path_index = path_index 227 | 228 | self.extract_aff_lab_func = GetAffinityLabelFromIndices(self.path_index.src_indices, self.path_index.dst_indices) 229 | 230 | def __getitem__(self, idx): 231 | image, image_id = super().__getitem__(idx) 232 | 233 | label = imageio.imread(self.label_dir + image_id + '.png.png') 234 | label = Image.fromarray(label) 235 | 236 | output_dic = self.transform({'image':image, 'mask':label}) 237 | image, label = output_dic['image'], output_dic['mask'] 238 | 239 | return image, self.extract_aff_lab_func(label) 240 | 241 | -------------------------------------------------------------------------------- /core/deeplab_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 * Ltd. All rights reserved. 
2 | # author : Sanghyeon Jo 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | class ASPPModule(nn.Module): 9 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, norm_fn=None): 10 | super().__init__() 11 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False) 12 | self.bn = norm_fn(planes) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | self.initialize([self.atrous_conv, self.bn]) 16 | 17 | def forward(self, x): 18 | x = self.atrous_conv(x) 19 | x = self.bn(x) 20 | return self.relu(x) 21 | 22 | def initialize(self, modules): 23 | for m in modules: 24 | if isinstance(m, nn.Conv2d): 25 | torch.nn.init.kaiming_normal_(m.weight) 26 | elif isinstance(m, nn.BatchNorm2d): 27 | m.weight.data.fill_(1) 28 | m.bias.data.zero_() 29 | 30 | class ASPP(nn.Module): 31 | def __init__(self, output_stride, norm_fn): 32 | super().__init__() 33 | 34 | inplanes = 2048 35 | 36 | if output_stride == 16: 37 | dilations = [1, 6, 12, 18] 38 | elif output_stride == 8: 39 | dilations = [1, 12, 24, 36] 40 | 41 | self.aspp1 = ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], norm_fn=norm_fn) 42 | self.aspp2 = ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], norm_fn=norm_fn) 43 | self.aspp3 = ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], norm_fn=norm_fn) 44 | self.aspp4 = ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], norm_fn=norm_fn) 45 | 46 | self.global_avg_pool = nn.Sequential( 47 | nn.AdaptiveAvgPool2d((1, 1)), 48 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 49 | norm_fn(256), 50 | nn.ReLU(inplace=True), 51 | ) 52 | 53 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 54 | self.bn1 = norm_fn(256) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.dropout = nn.Dropout(0.5) 57 | 58 | self.initialize([self.conv1, self.bn1] + list(self.global_avg_pool.modules())) 59 | 60 | def forward(self, x): 61 | x1 = self.aspp1(x) 62 | x2 = self.aspp2(x) 63 | x3 = self.aspp3(x) 64 | x4 = self.aspp4(x) 65 | 66 | x5 = self.global_avg_pool(x) 67 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 68 | 69 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 70 | 71 | x = self.conv1(x) 72 | x = self.bn1(x) 73 | x = self.relu(x) 74 | x = self.dropout(x) 75 | 76 | return x 77 | 78 | def initialize(self, modules): 79 | for m in modules: 80 | if isinstance(m, nn.Conv2d): 81 | torch.nn.init.kaiming_normal_(m.weight) 82 | elif isinstance(m, nn.BatchNorm2d): 83 | m.weight.data.fill_(1) 84 | m.bias.data.zero_() 85 | 86 | class Decoder(nn.Module): 87 | def __init__(self, num_classes, low_level_inplanes, norm_fn): 88 | super().__init__() 89 | 90 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 91 | self.bn1 = norm_fn(48) 92 | self.relu = nn.ReLU(inplace=True) 93 | 94 | self.classifier = nn.Sequential( 95 | nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 96 | norm_fn(256), 97 | nn.ReLU(inplace=True), 98 | nn.Dropout(0.5), 99 | 100 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 101 | norm_fn(256), 102 | nn.ReLU(inplace=True), 103 | nn.Dropout(0.1), 104 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1) 105 | ) 106 | 107 | self.initialize([self.conv1, self.bn1] + list(self.classifier.modules())) 108 | 109 | def forward(self, x, x_low_level): 110 | x_low_level = self.conv1(x_low_level) 111 | x_low_level = self.bn1(x_low_level) 112 | 
x_low_level = self.relu(x_low_level) 113 | 114 | x = F.interpolate(x, size=x_low_level.size()[2:], mode='bilinear', align_corners=True) 115 | x = torch.cat((x, x_low_level), dim=1) 116 | x = self.classifier(x) 117 | 118 | return x 119 | 120 | def initialize(self, modules): 121 | for m in modules: 122 | if isinstance(m, nn.Conv2d): 123 | torch.nn.init.kaiming_normal_(m.weight) 124 | elif isinstance(m, nn.BatchNorm2d): 125 | m.weight.data.fill_(1) 126 | m.bias.data.zero_() -------------------------------------------------------------------------------- /core/puzzle_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | def tile_features(features, num_pieces): 7 | _, _, h, w = features.size() 8 | 9 | num_pieces_per_line = int(math.sqrt(num_pieces)) 10 | 11 | h_per_patch = h // num_pieces_per_line 12 | w_per_patch = w // num_pieces_per_line 13 | 14 | """ 15 | +-----+-----+ 16 | | 1 | 2 | 17 | +-----+-----+ 18 | | 3 | 4 | 19 | +-----+-----+ 20 | 21 | +-----+-----+-----+-----+ 22 | | 1 | 2 | 3 | 4 | 23 | +-----+-----+-----+-----+ 24 | """ 25 | patches = [] 26 | for splitted_features in torch.split(features, h_per_patch, dim=2): 27 | for patch in torch.split(splitted_features, w_per_patch, dim=3): 28 | patches.append(patch) 29 | 30 | return torch.cat(patches, dim=0) 31 | 32 | def merge_features(features, num_pieces, batch_size): 33 | """ 34 | +-----+-----+-----+-----+ 35 | | 1 | 2 | 3 | 4 | 36 | +-----+-----+-----+-----+ 37 | 38 | +-----+-----+ 39 | | 1 | 2 | 40 | +-----+-----+ 41 | | 3 | 4 | 42 | +-----+-----+ 43 | """ 44 | features_list = list(torch.split(features, batch_size)) 45 | num_pieces_per_line = int(math.sqrt(num_pieces)) 46 | 47 | index = 0 48 | ext_h_list = [] 49 | 50 | for _ in range(num_pieces_per_line): 51 | 52 | ext_w_list = [] 53 | for _ in range(num_pieces_per_line): 54 | ext_w_list.append(features_list[index]) 55 | index += 1 56 | 57 | ext_h_list.append(torch.cat(ext_w_list, dim=3)) 58 | 59 | features = torch.cat(ext_h_list, dim=2) 60 | return features 61 | 62 | def puzzle_module(x, func_list, num_pieces): 63 | tiled_x = tile_features(x, num_pieces) 64 | 65 | for func in func_list: 66 | tiled_x = func(tiled_x) 67 | 68 | merged_x = merge_features(tiled_x, num_pieces, x.size()[0]) 69 | return merged_x 70 | -------------------------------------------------------------------------------- /core/sync_batchnorm/__init__(1).py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /core/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
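`tile_features` and `merge_features` in puzzle_utils.py above implement the Puzzle-CAM style split-and-reassemble: each image in the batch is cut into a square grid of patches that are stacked along the batch axis, and `merge_features` inverts the operation. A quick round-trip check (shapes are illustrative; `num_pieces` must be a perfect square whose root divides the spatial size):

```
import torch
from core.puzzle_utils import tile_features, merge_features

x = torch.randn(2, 8, 16, 16)            # (batch, channels, H, W)
tiled = tile_features(x, num_pieces=4)   # (8, 8, 8, 8): four patches per image on the batch axis
merged = merge_features(tiled, num_pieces=4, batch_size=2)

# Tiling then merging is lossless, so the round trip reproduces the input.
assert torch.equal(merged, x)
```

In `puzzle_module`, the functions in `func_list` run on the tiled tensor, so the network sees each patch in isolation before the outputs are stitched back into full-size maps.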
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /core/sync_batchnorm/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/sync_batchnorm/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /core/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc -------------------------------------------------------------------------------- /core/sync_batchnorm/__pycache__/comm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/sync_batchnorm/__pycache__/comm.cpython-38.pyc -------------------------------------------------------------------------------- /core/sync_batchnorm/__pycache__/replicate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/core/sync_batchnorm/__pycache__/replicate.cpython-38.pyc -------------------------------------------------------------------------------- /core/sync_batchnorm/batchnorm(1).py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
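The `SynchronizedBatchNorm*` layers defined in this file compute batch statistics over the whole multi-GPU batch instead of per device. A minimal usage sketch (the model is made up for illustration; assumes two visible GPUs):

```
import torch.nn as nn
from core.sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

# Use SynchronizedBatchNorm2d where nn.BatchNorm2d would normally go ...
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    SynchronizedBatchNorm2d(64),
    nn.ReLU(inplace=True),
)

# ... and wrap with the callback-aware DataParallel so each replica can
# register with the master copy before every forward pass.
model = DataParallelWithCallback(model.cuda(), device_ids=[0, 1])
```

On a single device the layers fall back to the stock `F.batch_norm` (see `forward` below), so the substitution costs nothing in that case.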
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | 142 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 143 | as the built-in PyTorch implementation. 144 | The mean and standard-deviation are calculated per-dimension over 145 | the mini-batches and gamma and beta are learnable parameter vectors 146 | of size C (where C is the input size). 147 | During training, this layer keeps a running estimate of its computed mean 148 | and variance. The running sum is kept with a default momentum of 0.1. 149 | During evaluation, this running mean/variance is used for normalization. 150 | Because the BatchNorm is done over the `C` dimension, computing statistics 151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 152 | Args: 153 | num_features: num_features from an expected input of size 154 | `batch_size x num_features [x width]` 155 | eps: a value added to the denominator for numerical stability. 156 | Default: 1e-5 157 | momentum: the value used for the running_mean and running_var 158 | computation. Default: 0.1 159 | affine: a boolean value that when set to ``True``, gives the layer learnable 160 | affine parameters. 
Default: ``True`` 161 | Shape: 162 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 164 | Examples: 165 | >>> # With Learnable Parameters 166 | >>> m = SynchronizedBatchNorm1d(100) 167 | >>> # Without Learnable Parameters 168 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 169 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 170 | >>> output = m(input) 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 178 | 179 | 180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 182 | of 3d inputs 183 | .. math:: 184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 186 | standard-deviation are reduced across all devices during training. 187 | For example, when one uses `nn.DataParallel` to wrap the network during 188 | training, PyTorch's implementation normalize the tensor on each device using 189 | the statistics only on that device, which accelerated the computation and 190 | is also easy to implement, but the statistics might be inaccurate. 191 | Instead, in this synchronized version, the statistics will be computed 192 | over all training samples distributed on multiple devices. 193 | 194 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 195 | as the built-in PyTorch implementation. 196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size). 199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running sum is kept with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. 
Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalize the tensor on each device using 241 | the statistics only on that device, which accelerated the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running sum is kept with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability. 261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. 
Default: ``True`` 266 | Shape: 267 | - Input: :math:`(N, C, D, H, W)` 268 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 269 | Examples: 270 | >>> # With Learnable Parameters 271 | >>> m = SynchronizedBatchNorm3d(100) 272 | >>> # Without Learnable Parameters 273 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 275 | >>> output = m(input) 276 | """ 277 | 278 | def _check_input_dim(self, input): 279 | if input.dim() != 5: 280 | raise ValueError('expected 5D input (got {}D input)' 281 | .format(input.dim())) 282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /core/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 
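# That is, rather than ((input - mean) * inv_std) * weight + bias, the
# per-channel factors inv_std and weight are combined first, so the
# full-size tensor is touched by a single broadcast multiply.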
73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | 142 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 143 | as the built-in PyTorch implementation. 144 | The mean and standard-deviation are calculated per-dimension over 145 | the mini-batches and gamma and beta are learnable parameter vectors 146 | of size C (where C is the input size). 
147 | During training, this layer keeps a running estimate of its computed mean 148 | and variance. The running sum is kept with a default momentum of 0.1. 149 | During evaluation, this running mean/variance is used for normalization. 150 | Because the BatchNorm is done over the `C` dimension, computing statistics 151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 152 | Args: 153 | num_features: num_features from an expected input of size 154 | `batch_size x num_features [x width]` 155 | eps: a value added to the denominator for numerical stability. 156 | Default: 1e-5 157 | momentum: the value used for the running_mean and running_var 158 | computation. Default: 0.1 159 | affine: a boolean value that when set to ``True``, gives the layer learnable 160 | affine parameters. Default: ``True`` 161 | Shape: 162 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 164 | Examples: 165 | >>> # With Learnable Parameters 166 | >>> m = SynchronizedBatchNorm1d(100) 167 | >>> # Without Learnable Parameters 168 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 169 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 170 | >>> output = m(input) 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 178 | 179 | 180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 182 | of 3d inputs 183 | .. math:: 184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 186 | standard-deviation are reduced across all devices during training. 187 | For example, when one uses `nn.DataParallel` to wrap the network during 188 | training, PyTorch's implementation normalize the tensor on each device using 189 | the statistics only on that device, which accelerated the computation and 190 | is also easy to implement, but the statistics might be inaccurate. 191 | Instead, in this synchronized version, the statistics will be computed 192 | over all training samples distributed on multiple devices. 193 | 194 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 195 | as the built-in PyTorch implementation. 196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size). 199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running sum is kept with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. 
Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalize the tensor on each device using 241 | the statistics only on that device, which accelerated the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running sum is kept with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability. 261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. 
Default: ``True`` 266 | Shape: 267 | - Input: :math:`(N, C, D, H, W)` 268 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 269 | Examples: 270 | >>> # With Learnable Parameters 271 | >>> m = SynchronizedBatchNorm3d(100) 272 | >>> # Without Learnable Parameters 273 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 275 | >>> output = m(input) 276 | """ 277 | 278 | def _check_input_dim(self, input): 279 | if input.dim() != 5: 280 | raise ValueError('expected 5D input (got {}D input)' 281 | .format(input.dim())) 282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /core/sync_batchnorm/comm(1).py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 
85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /core/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 
58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /core/sync_batchnorm/replicate(1).py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
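`SyncMaster` above is a small threads-and-queues rendezvous: every slave pushes a message into the shared queue, the master collects one message per registered slave, runs the callback, and hands each participant its own slice of the result. A toy demonstration outside of BatchNorm (the summing callback and the integer messages are made up for illustration):

```
import threading
from core.sync_batchnorm.comm import SyncMaster

def sum_callback(intermediates):
    # intermediates is [(id, msg), ...] with the master's (0, msg) first.
    total = sum(msg for _, msg in intermediates)
    return [(i, total) for i, _ in intermediates]

master = SyncMaster(sum_callback)
pipes = {i: master.register_slave(i) for i in (1, 2)}
results = {}

def slave(i):
    results[i] = pipes[i].run_slave(i * 10)   # slave i contributes i * 10

threads = [threading.Thread(target=slave, args=(i,)) for i in (1, 2)]
for t in threads:
    t.start()
results[0] = master.run_master(5)             # the master contributes 5
for t in threads:
    t.join()
print(results)  # every participant receives 5 + 10 + 20 = 35
```

In the real layers the messages are per-device sums and square-sums, and the callback is `_data_parallel_master` above, which reduces them and broadcasts the mean and inverse standard deviation back to every device.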
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /core/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
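The replication callback above only requires a module to define `__data_parallel_replicate__(ctx, copy_id)`; copies of the same submodule all receive the same `ctx` object, which is how replicas on different devices share state. A toy module showing the hook without any GPUs (the class is made up for illustration):

```
import torch.nn as nn
from core.sync_batchnorm.replicate import execute_replication_callbacks

class Chatty(nn.Module):
    def __data_parallel_replicate__(self, ctx, copy_id):
        self.ctx = ctx                      # the same object for every copy
        self.is_master = (copy_id == 0)     # the master copy is replicated first

replicas = [Chatty(), Chatty()]             # stand-ins for per-device copies
execute_replication_callbacks(replicas)
print(replicas[0].ctx is replicas[1].ctx)   # True: replicas share one context
```

This is exactly how `_SynchronizedBatchNorm.__data_parallel_replicate__` lets slave copies obtain a `SlavePipe` from the master copy's `SyncMaster`.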
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /core/sync_batchnorm/unittest(1).py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
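The `TorchTestCase` defined just below wraps `np.allclose` in an assertion with a readable failure message. A typical use is checking that, outside of DataParallel, the synchronized layers match stock BatchNorm (a sketch; the test class and seed are made up):

```
import torch
import torch.nn as nn
from core.sync_batchnorm import SynchronizedBatchNorm2d
from core.sync_batchnorm.unittest import TorchTestCase

class SyncBNMatchesBN(TorchTestCase):
    def test_single_device(self):
        torch.manual_seed(0)
        x = torch.randn(4, 8, 5, 5)
        bn = nn.BatchNorm2d(8)
        sbn = SynchronizedBatchNorm2d(8)
        sbn.load_state_dict(bn.state_dict())
        # Outside DataParallel, SynchronizedBatchNorm falls back to
        # F.batch_norm, so the two outputs should agree.
        self.assertTensorClose(bn(x), sbn(x))
```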
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /core/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /generate_bg_masks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | import scipy 6 | from utils.CRF import MyCRF 7 | import re 8 | import cv2 9 | from skimage import morphology 10 | 11 | 12 | def get_arguments(): 13 | parser = argparse.ArgumentParser(description="HAMIL pytorch implementation") 14 | parser.add_argument("--train_set_root", type=str, default="", help="training images root") 15 | parser.add_argument("--gamma_path", type=str, default="", help="gamma transform root") 16 | parser.add_argument("--gamma_crf_path", type=str, default="", help="background root") 17 | return parser.parse_args() 18 | 19 | def gamma_transform(img_path): 20 | image = Image.open(img_path).convert("L") 21 | image = np.array(image, dtype=np.float32) 22 | image /= 255 23 | gamma = 2.4 24 | out = np.power(image, gamma) 25 | out *= 255 26 | out = out.astype(np.uint8) 27 | 28 | return out 29 | 30 | def seg_to_color(seg): 31 | H, W = seg.shape[0], seg.shape[1] 32 | # white, green, blue, yellow 33 | classes = ["background", "Tumor", "Stroma", "Normal"] 34 | color_map = [[255, 255, 255], [0, 64, 128], [64, 128, 0], [243, 152, 0]] 35 | img = np.zeros((H, W, 3)) 36 | for i in range(H): 37 | for j in range(W): 38 | img[i, j, :] = color_map[seg[i, j]] 39 | return img 40 | 41 | def open_with_crf(img_path, open_img_path): 42 | img = Image.open(img_path).convert("RGB") 43 | img_array = np.array(img, dtype=np.float32) 44 | open_img = Image.open(open_img_path) 45 | open_img = np.array(open_img, dtype=np.float32) 46 | H, W = img_array.shape[0], img_array.shape[1] 47 | mycrf = MyCRF() 48 | p = 1.0 49 | background = open_img.copy() 50 | background /= np.max(background) 51 | foreground = (1 - 
--------------------------------------------------------------------------------
/generate_bg_masks.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 | import scipy
6 | from utils.CRF import MyCRF
7 | import re
8 | import cv2
9 | from skimage import morphology
10 | 
11 | 
12 | def get_arguments():
13 |     parser = argparse.ArgumentParser(description="HAMIL pytorch implementation")
14 |     parser.add_argument("--train_set_root", type=str, default="", help="training images root")
15 |     parser.add_argument("--gamma_path", type=str, default="", help="gamma transform root")
16 |     parser.add_argument("--gamma_crf_path", type=str, default="", help="background root")
17 |     return parser.parse_args()
18 | 
19 | def gamma_transform(img_path):
20 |     image = Image.open(img_path).convert("L")
21 |     image = np.array(image, dtype=np.float32)
22 |     image /= 255
23 |     gamma = 2.4
24 |     out = np.power(image, gamma)
25 |     out *= 255
26 |     out = out.astype(np.uint8)
27 | 
28 |     return out
29 | 
30 | def seg_to_color(seg):
31 |     H, W = seg.shape[0], seg.shape[1]
32 |     # white, green, blue, yellow
33 |     classes = ["background", "Tumor", "Stroma", "Normal"]
34 |     color_map = [[255, 255, 255], [0, 64, 128], [64, 128, 0], [243, 152, 0]]
35 |     img = np.zeros((H, W, 3))
36 |     for i in range(H):
37 |         for j in range(W):
38 |             img[i, j, :] = color_map[seg[i, j]]
39 |     return img
40 | 
41 | def open_with_crf(img_path, open_img_path):
42 |     img = Image.open(img_path).convert("RGB")
43 |     img_array = np.array(img, dtype=np.float32)
44 |     open_img = Image.open(open_img_path)
45 |     open_img = np.array(open_img, dtype=np.float32)
46 |     H, W = img_array.shape[0], img_array.shape[1]
47 |     mycrf = MyCRF()
48 |     p = 1.0
49 |     background = open_img.copy()
50 |     background /= np.max(background)
51 |     foreground = (1 - background) * p
52 |     probability_map = np.concatenate(
53 |         (foreground.reshape((1, H, W)), background.reshape((1, H, W))), axis=0
54 |     )
55 |     out = mycrf.inference(img_array, probability_map)
56 |     out = out.argmax(0)
57 |     return (out * 255).astype(np.uint8)
58 | 
59 | def get_label(img_name):
60 |     res = re.findall(r"\[(.*?)\]", img_name)
61 |     label = np.array(list(eval(res[0])), dtype=np.uint8)
62 |     return label
63 | 
64 | 
65 | if __name__ == "__main__":
66 |     args = get_arguments()
67 |     dataset_root = args.train_set_root
68 |     gamma_dir = args.gamma_path
69 |     gamma_crf_dir = args.gamma_crf_path
70 |     img_paths = os.listdir(dataset_root)
71 |     # clean: keep image files only (assumed extensions)
72 |     img_paths = [p for p in img_paths if p.lower().endswith((".png", ".jpg", ".jpeg"))]
73 |     print(f"all {len(img_paths)} images")
74 | 
75 |     for img_name in img_paths:
76 |         # save gamma crf background
77 |         img_path = dataset_root + img_name
78 |         img_gamma = gamma_transform(img_path)
79 |         img_gamma = Image.fromarray(img_gamma)
80 |         gamma_path = gamma_dir + img_name
81 |         img_gamma.save(gamma_path)
82 |         open_crf = open_with_crf(img_path, gamma_path)
83 |         img_open_crf = Image.fromarray(open_crf)
84 |         img_open_crf.save(gamma_crf_dir + img_name)
85 |         out = Image.open(img_path)
86 |         out = np.array(out).astype(np.uint8)
87 |         if len(np.unique(out)) == 1:
88 |             os.remove(img_path)
89 |         out_remove = np.array(out, dtype=bool)
90 |         morphology.remove_small_holes(out_remove, 32, 1, True)
91 |         out_remove = Image.fromarray(out_remove)
92 |         out_remove.save("gamma_crf_train/" + img_name)
93 | 
94 |     print("done!")
95 | 
96 | 
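For orientation, a minimal single-image sketch of the background-extraction pipeline above (the file paths are illustrative, not part of the repository):
```
from PIL import Image
from generate_bg_masks import gamma_transform, open_with_crf

img_path = "train/example.png"                          # illustrative input image
gamma = gamma_transform(img_path)                       # gamma-darkened grayscale, uint8
Image.fromarray(gamma).save("gamma/example.png")
bg_mask = open_with_crf(img_path, "gamma/example.png")  # 0/255 background mask after CRF
Image.fromarray(bg_mask).save("gamma_crf/example.png")
```
The gamma transform (gamma = 2.4) suppresses mid-gray tissue while leaving the bright white background near 255, which gives the CRF a clean two-class unary term to refine.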
--------------------------------------------------------------------------------
/networks/ham_net.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.getcwd())
4 | import torch
5 | import torch.nn as nn
6 | import torch.utils.model_zoo as model_zoo
7 | import torch.nn.functional as F
8 | from torchvision.models import vgg16
9 | 
10 | 
11 | model_urls = {
12 |     "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
13 |     "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
14 | }
15 | 
16 | 
17 | class ham_net_base(nn.Module):
18 |     def __init__(self, num_classes=3):
19 |         super(ham_net_base, self).__init__()
20 |         self.extra_convs = nn.Sequential(
21 |             nn.Conv2d(512, 512, kernel_size=3, padding=1),
22 |             nn.ReLU(),
23 |             nn.Conv2d(512, 512, kernel_size=3, padding=1),
24 |             nn.ReLU(),
25 |             nn.Conv2d(512, 512, kernel_size=3, padding=1),
26 |             nn.ReLU(),
27 |             nn.Conv2d(512, num_classes, 1),
28 |         )
29 |         self.classifier_b5 = nn.Conv2d(512, num_classes, kernel_size=1)
30 |         self.classifier_b4 = nn.Conv2d(512, num_classes, kernel_size=1)
31 | 
32 |         self._initialize_weights()
33 | 
34 |         model_vgg_baseline = vgg16(pretrained=True)
35 |         self.features = model_vgg_baseline.features
36 |         self.num_classes = num_classes
37 | 
38 |     def forward(self, x, CAM=False, size=None):
39 |         x = self.features[:24](x)
40 |         x_b4 = self.classifier_b4(x)
41 | 
42 |         x = self.features[24:](x)
43 |         x_b5 = self.classifier_b5(x)
44 | 
45 |         x_b6 = self.extra_convs(x)
46 | 
47 |         logit_b6 = F.avg_pool2d(x_b6, kernel_size=(x_b6.size(2), x_b6.size(3)), padding=0)
48 |         logit_b6 = logit_b6.view(-1, self.num_classes)
49 | 
50 |         logit_b5 = F.avg_pool2d(x_b5, kernel_size=(x_b5.size(2), x_b5.size(3)), padding=0)
51 |         logit_b5 = logit_b5.view(-1, self.num_classes)
52 | 
53 |         logit_b4 = F.avg_pool2d(x_b4, kernel_size=(x_b4.size(2), x_b4.size(3)), padding=0)
54 |         logit_b4 = logit_b4.view(-1, self.num_classes)
55 | 
56 |         if not CAM:
57 |             return logit_b6, logit_b5, logit_b4
58 |             # return logit_b5, logit_b4, logit_b3
59 |         else:
60 |             if size is None:
61 |                 return logit_b6, logit_b5, logit_b4, x_b6, x_b5, x_b4
62 |             else:
63 |                 x_b6 = self.cam_normalize(x_b6, size)
64 |                 x_b5 = self.cam_normalize(x_b5, size)
65 |                 x_b4 = self.cam_normalize(x_b4, size)
66 |                 return logit_b6, logit_b5, logit_b4, x_b6, x_b5, x_b4
67 | 
68 |     def cam_normalize(self, cam, size):
69 |         cam = F.relu(cam)
70 |         cam = F.interpolate(cam, size=size, mode="bilinear", align_corners=True)
71 |         cam = cam / (F.adaptive_max_pool2d(cam, 1) + 1e-5)
72 |         return cam
73 | 
74 |     def _initialize_weights(self):
75 |         for m in self.modules():
76 |             if isinstance(m, nn.Conv2d):
77 |                 m.weight.data.normal_(0, 0.01)
78 |                 if m.bias is not None:
79 |                     m.bias.data.zero_()
80 |             elif isinstance(m, nn.BatchNorm2d):
81 |                 m.weight.data.fill_(1)
82 |                 m.bias.data.zero_()
83 |             elif isinstance(m, nn.Linear):
84 |                 m.weight.data.normal_(0, 0.01)
85 |                 m.bias.data.zero_()
86 | 
87 |     def get_parameter_groups(self):
88 |         groups = ([], [], [], [])
89 | 
90 |         for name, value in self.named_parameters():
91 | 
92 |             if "extra" in name:
93 |                 if "weight" in name:
94 |                     groups[2].append(value)
95 |                 else:
96 |                     groups[3].append(value)
97 |             else:
98 |                 if "weight" in name:
99 |                     groups[0].append(value)
100 |                 else:
101 |                     groups[1].append(value)
102 |         return groups
103 | 
104 | 
105 | class ham_net(ham_net_base):
106 |     def __init__(self, num_classes=3):
107 |         super(ham_net, self).__init__(num_classes)
108 |         del self.features[30]
109 |         del self.features[23]
--------------------------------------------------------------------------------
/tools/ai/__pycache__/augment_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/ai/__pycache__/augment_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/ai/__pycache__/demo_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/ai/__pycache__/demo_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/ai/__pycache__/evaluate_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/ai/__pycache__/evaluate_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/ai/__pycache__/log_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/ai/__pycache__/log_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/ai/__pycache__/optim_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/ai/__pycache__/optim_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/ai/__pycache__/randaugment.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/ai/__pycache__/randaugment.cpython-38.pyc -------------------------------------------------------------------------------- /tools/ai/__pycache__/torch_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/ai/__pycache__/torch_utils.cpython-38.pyc -------------------------------------------------------------------------------- /tools/ai/augment_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import numpy as np 4 | 5 | from PIL import Image 6 | 7 | def convert_OpenCV_to_PIL(image): 8 | return Image.fromarray(image[..., ::-1]) 9 | 10 | def convert_PIL_to_OpenCV(image): 11 | return np.asarray(image)[..., ::-1] 12 | 13 | class RandomResize: 14 | def __init__(self, min_image_size, max_image_size): 15 | self.min_image_size = min_image_size 16 | self.max_image_size = max_image_size 17 | 18 | self.modes = [Image.BICUBIC, Image.NEAREST] 19 | 20 | def __call__(self, image, mode=Image.BICUBIC): 21 | rand_image_size = random.randint(self.min_image_size, self.max_image_size) 22 | 23 | w, h = image.size 24 | if w < h: 25 | scale = rand_image_size / h 26 | else: 27 | scale = rand_image_size / w 28 | 29 | size = (int(round(w*scale)), int(round(h*scale))) 30 | if size[0] == w and size[1] == h: 31 | return image 32 | 33 | return image.resize(size, mode) 34 | 35 | class RandomResize_For_Segmentation: 36 | def __init__(self, min_image_size, max_image_size): 37 | self.min_image_size = min_image_size 38 | self.max_image_size = max_image_size 39 | 40 | self.modes = [Image.BICUBIC, Image.NEAREST] 41 | 42 | def __call__(self, data): 43 | image, mask = data['image'], data['mask'] 44 | 45 | rand_image_size = random.randint(self.min_image_size, self.max_image_size) 46 | 47 | w, h = image.size 48 | if w < h: 49 | scale = rand_image_size / h 50 | else: 51 | scale = rand_image_size / w 52 | 53 | size = (int(round(w*scale)), int(round(h*scale))) 54 | if size[0] == w and size[1] == h: 55 | pass 56 | else: 57 | data['image'] = image.resize(size, Image.BICUBIC) 58 | data['mask'] = mask.resize(size, Image.NEAREST) 59 | 60 | return data 61 | 62 | class RandomHorizontalFlip: 63 | def __init__(self): 64 | pass 65 | 66 | def __call__(self, image): 67 | if bool(random.getrandbits(1)): 68 | return image.transpose(Image.FLIP_LEFT_RIGHT) 69 | return image 70 | 71 | class RandomHorizontalFlip_For_Segmentation: 72 | def __init__(self): 73 | pass 74 | 75 | def __call__(self, data): 76 | image, mask = data['image'], data['mask'] 77 | 78 | if bool(random.getrandbits(1)): 79 | data['image'] = image.transpose(Image.FLIP_LEFT_RIGHT) 80 | data['mask'] = mask.transpose(Image.FLIP_LEFT_RIGHT) 81 | 82 | return data 83 | 84 | class Normalize: 85 | def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 86 | self.mean = mean 87 | self.std = std 88 | 89 | def __call__(self, image): 90 | image = np.asarray(image) 91 | norm_image = np.empty_like(image, np.float32) 92 | 93 | norm_image[..., 0] = (image[..., 0] / 255. - self.mean[0]) / self.std[0] 94 | norm_image[..., 1] = (image[..., 1] / 255. - self.mean[1]) / self.std[1] 95 | norm_image[..., 2] = (image[..., 2] / 255. 
- self.mean[2]) / self.std[2] 96 | 97 | return norm_image 98 | 99 | class Normalize_For_Segmentation: 100 | def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 101 | self.mean = mean 102 | self.std = std 103 | 104 | def __call__(self, data): 105 | image, mask = data['image'], data['mask'] 106 | 107 | image = np.asarray(image, dtype=np.float32) 108 | mask = np.asarray(mask, dtype=np.int64) 109 | 110 | norm_image = np.empty_like(image, np.float32) 111 | 112 | norm_image[..., 0] = (image[..., 0] / 255. - self.mean[0]) / self.std[0] 113 | norm_image[..., 1] = (image[..., 1] / 255. - self.mean[1]) / self.std[1] 114 | norm_image[..., 2] = (image[..., 2] / 255. - self.mean[2]) / self.std[2] 115 | 116 | data['image'] = norm_image 117 | data['mask'] = mask 118 | 119 | return data 120 | 121 | class Top_Left_Crop: 122 | def __init__(self, crop_size, channels=3): 123 | self.bg_value = 0 124 | self.crop_size = crop_size 125 | self.crop_shape = (self.crop_size, self.crop_size, channels) 126 | 127 | def __call__(self, image): 128 | h, w, c = image.shape 129 | 130 | ch = min(self.crop_size, h) 131 | cw = min(self.crop_size, w) 132 | 133 | cropped_image = np.ones(self.crop_shape, image.dtype) * self.bg_value 134 | cropped_image[:ch, :cw] = image[:ch, :cw] 135 | 136 | return cropped_image 137 | 138 | class Top_Left_Crop_For_Segmentation: 139 | def __init__(self, crop_size, channels=3): 140 | self.bg_value = 0 141 | self.crop_size = crop_size 142 | self.crop_shape = (self.crop_size, self.crop_size, channels) 143 | self.crop_shape_for_mask = (self.crop_size, self.crop_size) 144 | 145 | def __call__(self, data): 146 | image, mask = data['image'], data['mask'] 147 | 148 | h, w, c = image.shape 149 | 150 | ch = min(self.crop_size, h) 151 | cw = min(self.crop_size, w) 152 | 153 | cropped_image = np.ones(self.crop_shape, image.dtype) * self.bg_value 154 | cropped_image[:ch, :cw] = image[:ch, :cw] 155 | 156 | cropped_mask = np.ones(self.crop_shape_for_mask, mask.dtype) * 255 157 | cropped_mask[:ch, :cw] = mask[:ch, :cw] 158 | 159 | data['image'] = cropped_image 160 | data['mask'] = cropped_mask 161 | 162 | return data 163 | 164 | class RandomCrop: 165 | def __init__(self, crop_size, channels=3, with_bbox=False): 166 | self.bg_value = 0 167 | self.with_bbox = with_bbox 168 | self.crop_size = crop_size 169 | self.crop_shape = (self.crop_size, self.crop_size, channels) 170 | 171 | def get_random_crop_box(self, image): 172 | h, w, c = image.shape 173 | 174 | ch = min(self.crop_size, h) 175 | cw = min(self.crop_size, w) 176 | 177 | w_space = w - self.crop_size 178 | h_space = h - self.crop_size 179 | 180 | if w_space > 0: 181 | cont_left = 0 182 | img_left = random.randrange(w_space + 1) 183 | else: 184 | cont_left = random.randrange(-w_space + 1) 185 | img_left = 0 186 | 187 | if h_space > 0: 188 | cont_top = 0 189 | img_top = random.randrange(h_space + 1) 190 | else: 191 | cont_top = random.randrange(-h_space + 1) 192 | img_top = 0 193 | 194 | dst_bbox = { 195 | 'xmin' : cont_left, 'ymin' : cont_top, 196 | 'xmax' : cont_left+cw, 'ymax' : cont_top+ch 197 | } 198 | src_bbox = { 199 | 'xmin' : img_left, 'ymin' : img_top, 200 | 'xmax' : img_left+cw, 'ymax' : img_top+ch 201 | } 202 | 203 | return dst_bbox, src_bbox 204 | 205 | def __call__(self, image, bbox_dic=None): 206 | if bbox_dic is None: 207 | dst_bbox, src_bbox = self.get_random_crop_box(image) 208 | else: 209 | dst_bbox, src_bbox = bbox_dic['dst_bbox'], bbox_dic['src_bbox'] 210 | 211 | cropped_image = np.ones(self.crop_shape, image.dtype) * 
self.bg_value 212 | cropped_image[dst_bbox['ymin']:dst_bbox['ymax'], dst_bbox['xmin']:dst_bbox['xmax']] = \ 213 | image[src_bbox['ymin']:src_bbox['ymax'], src_bbox['xmin']:src_bbox['xmax']] 214 | 215 | if self.with_bbox: 216 | return cropped_image, {'dst_bbox':dst_bbox, 'src_bbox':src_bbox} 217 | else: 218 | return cropped_image 219 | 220 | class RandomCrop_For_Segmentation(RandomCrop): 221 | def __init__(self, crop_size): 222 | super().__init__(crop_size) 223 | 224 | self.crop_shape_for_mask = (self.crop_size, self.crop_size) 225 | 226 | def __call__(self, data): 227 | image, mask = data['image'], data['mask'] 228 | 229 | dst_bbox, src_bbox = self.get_random_crop_box(image) 230 | 231 | cropped_image = np.ones(self.crop_shape, image.dtype) * self.bg_value 232 | cropped_image[dst_bbox['ymin']:dst_bbox['ymax'], dst_bbox['xmin']:dst_bbox['xmax']] = \ 233 | image[src_bbox['ymin']:src_bbox['ymax'], src_bbox['xmin']:src_bbox['xmax']] 234 | 235 | cropped_mask = np.ones(self.crop_shape_for_mask, mask.dtype) * 255 236 | cropped_mask[dst_bbox['ymin']:dst_bbox['ymax'], dst_bbox['xmin']:dst_bbox['xmax']] = \ 237 | mask[src_bbox['ymin']:src_bbox['ymax'], src_bbox['xmin']:src_bbox['xmax']] 238 | 239 | data['image'] = cropped_image 240 | data['mask'] = cropped_mask 241 | 242 | return data 243 | 244 | class Transpose: 245 | def __init__(self): 246 | pass 247 | 248 | def __call__(self, image): 249 | return image.transpose((2, 0, 1)) 250 | 251 | class Transpose_For_Segmentation: 252 | def __init__(self): 253 | pass 254 | 255 | def __call__(self, data): 256 | # h, w, c -> c, h, w 257 | data['image'] = data['image'].transpose((2, 0, 1)) 258 | return data 259 | 260 | class Resize_For_Mask: 261 | def __init__(self, size): 262 | self.size = (size, size) 263 | 264 | def __call__(self, data): 265 | mask = Image.fromarray(data['mask'].astype(np.uint8)) 266 | mask = mask.resize(self.size, Image.NEAREST) 267 | data['mask'] = np.asarray(mask, dtype=np.uint64) 268 | return data 269 | -------------------------------------------------------------------------------- /tools/ai/demo_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import numpy as np 4 | 5 | from PIL import Image 6 | 7 | def get_strided_size(orig_size, stride): 8 | return ((orig_size[0]-1)//stride+1, (orig_size[1]-1)//stride+1) 9 | 10 | def get_strided_up_size(orig_size, stride): 11 | strided_size = get_strided_size(orig_size, stride) 12 | return strided_size[0]*stride, strided_size[1]*stride 13 | 14 | def imshow(image, delay=0, mode='RGB', title='show'): 15 | if mode == 'RGB': 16 | demo_image = image[..., ::-1] 17 | else: 18 | demo_image = image 19 | 20 | cv2.imshow(title, demo_image) 21 | if delay >= 0: 22 | cv2.waitKey(delay) 23 | 24 | def transpose(image): 25 | return image.transpose((1, 2, 0)) 26 | 27 | def denormalize(image, mean=None, std=None, dtype=np.uint8, tp=True): 28 | if tp: 29 | image = transpose(image) 30 | 31 | if mean is not None: 32 | image = (image * std) + mean 33 | 34 | if dtype == np.uint8: 35 | image *= 255. 
36 |         return image.astype(np.uint8)
37 |     else:
38 |         return image
39 | 
40 | def colormap(cam, shape=None, mode=cv2.COLORMAP_JET):
41 |     if shape is not None:
42 |         h, w, c = shape
43 |         cam = cv2.resize(cam, (w, h))
44 |     cam = cv2.applyColorMap(cam, mode)
45 |     return cam
46 | 
47 | def decode_from_colormap(data, colors):
48 |     ignore = (data == 255).astype(np.int32)
49 | 
50 |     mask = 1 - ignore
51 |     data *= mask
52 | 
53 |     h, w = data.shape
54 |     image = colors[data.reshape((h * w))].reshape((h, w, 3))
55 | 
56 |     ignore = np.concatenate([ignore[..., np.newaxis], ignore[..., np.newaxis], ignore[..., np.newaxis]], axis=-1)
57 |     image[ignore.astype(bool)] = 255
58 |     return image
59 | 
60 | def normalize(cam, epsilon=1e-5):
61 |     cam = np.maximum(cam, 0)
62 |     max_value = np.max(cam, axis=(0, 1), keepdims=True)
63 |     return np.maximum(cam - epsilon, 0) / (max_value + epsilon)
64 | 
65 | def crf_inference(img, probs, t=10, scale_factor=1, labels=21):
66 |     import pydensecrf.densecrf as dcrf
67 |     from pydensecrf.utils import unary_from_softmax
68 | 
69 |     h, w = img.shape[:2]
70 |     n_labels = labels
71 | 
72 |     d = dcrf.DenseCRF2D(w, h, n_labels)
73 | 
74 |     unary = unary_from_softmax(probs)
75 |     unary = np.ascontiguousarray(unary)
76 | 
77 |     d.setUnaryEnergy(unary)
78 |     d.addPairwiseGaussian(sxy=3/scale_factor, compat=3)
79 |     d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img), compat=10)
80 |     Q = d.inference(t)
81 | 
82 |     return np.array(Q).reshape((n_labels, h, w))
83 | 
84 | def crf_with_alpha(ori_image, cams, alpha):
85 |     # h, w, c -> c, h, w
86 |     # cams = cams.transpose((2, 0, 1))
87 | 
88 |     bg_score = np.power(1 - np.max(cams, axis=0, keepdims=True), alpha)
89 |     bgcam_score = np.concatenate((bg_score, cams), axis=0)
90 | 
91 |     cams_with_crf = crf_inference(ori_image, bgcam_score, labels=bgcam_score.shape[0])
92 |     # return cams_with_crf.transpose((1, 2, 0))
93 |     return cams_with_crf
94 | 
95 | def crf_inference_label(img, labels, t=10, n_labels=4, gt_prob=0.7):
96 |     import pydensecrf.densecrf as dcrf
97 |     from pydensecrf.utils import unary_from_labels
98 | 
99 |     h, w = img.shape[:2]
100 | 
101 |     d = dcrf.DenseCRF2D(w, h, n_labels)
102 | 
103 |     unary = unary_from_labels(labels, n_labels, gt_prob=gt_prob, zero_unsure=False)
104 | 
105 |     d.setUnaryEnergy(unary)
106 |     d.addPairwiseGaussian(sxy=3, compat=3)
107 |     d.addPairwiseBilateral(sxy=50, srgb=5, rgbim=np.ascontiguousarray(np.copy(img)), compat=10)
108 | 
109 |     q = d.inference(t)
110 | 
111 |     return np.argmax(np.array(q).reshape((n_labels, h, w)), axis=0)
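For reference, a minimal sketch of how `crf_with_alpha` above can refine class activation maps into a pseudo label (the arrays are illustrative, and pydensecrf must be installed):
```
import numpy as np
from tools.ai.demo_utils import crf_with_alpha

ori_image = np.zeros((224, 224, 3), dtype=np.uint8)    # illustrative RGB image
cams = np.random.rand(3, 224, 224).astype(np.float32)  # one map per foreground class, in [0, 1]
refined = crf_with_alpha(ori_image, cams, alpha=16)    # (4, 224, 224): background score is prepended
pseudo_label = refined.argmax(axis=0)                  # per-pixel class indices
```
A larger `alpha` shrinks the background score (1 - max CAM)^alpha, so more pixels end up assigned to foreground classes.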
--------------------------------------------------------------------------------
/tools/ai/evaluate_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from tools.general.json_utils import read_json
4 | 
5 | def calculate_for_tags(pred_tags, gt_tags):
6 |     """This function calculates precision, recall, and f1-score using tags.
7 | 
8 |     Args:
9 |         pred_tags:
10 |             The type of variable is list.
11 |             The type of each element is string.
12 | 
13 |         gt_tags:
14 |             The type of variable is list.
15 |             The type of each element is string.
16 | 
17 |     Returns:
18 |         precision:
19 |             The precision in percent (float).
20 | 
21 |         recall:
22 |             The recall in percent (float).
23 | 
24 |         f1-score:
25 |             The harmonic mean of precision and recall, in percent (float).
26 |     """
27 |     if len(pred_tags) == 0 and len(gt_tags) == 0:
28 |         return 100, 100, 100
29 |     elif len(pred_tags) == 0 or len(gt_tags) == 0:
30 |         return 0, 0, 0
31 | 
32 |     pred_tags = np.asarray(pred_tags)
33 |     gt_tags = np.asarray(gt_tags)
34 | 
35 |     precision = pred_tags[:, np.newaxis] == gt_tags[np.newaxis, :]
36 |     recall = gt_tags[:, np.newaxis] == pred_tags[np.newaxis, :]
37 | 
38 |     precision = np.sum(precision) / len(precision) * 100
39 |     recall = np.sum(recall) / len(recall) * 100
40 | 
41 |     if precision == 0 and recall == 0:
42 |         f1_score = 0
43 |     else:
44 |         f1_score = 2 * ((precision * recall) / (precision + recall))
45 | 
46 |     return precision, recall, f1_score
47 | 
48 | def calculate_mIoU(pred_mask, gt_mask):
49 |     """This function calculates the mean IoU between a predicted mask and a ground-truth mask.
50 | 
51 |     Args:
52 |         pred_mask:
53 |             The type of variable is numpy array.
54 | 
55 |         gt_mask:
56 |             The type of variable is numpy array.
57 | 
58 |     Returns:
59 |         miou:
60 |             The IoU between the two binary masks, in percent.
61 |     """
62 |     inter = np.logical_and(pred_mask, gt_mask)
63 |     union = np.logical_or(pred_mask, gt_mask)
64 | 
65 |     epsilon = 1e-5
66 |     miou = (np.sum(inter) + epsilon) / (np.sum(union) + epsilon)
67 |     return miou * 100
68 | 
69 | class Calculator_For_mIoU:
70 |     def __init__(self):
71 |         self.class_names = ['background'] + ['Tumor', 'Stroma', 'Normal']
72 |         self.classes = len(self.class_names)
73 | 
74 |         self.clear()
75 | 
76 |     def get_data(self, pred_mask, gt_mask):
77 |         obj_mask = gt_mask<255
78 |         correct_mask = (pred_mask==gt_mask) * obj_mask
79 | 
80 |         P_list, T_list, TP_list = [], [], []
81 |         for i in range(self.classes):
82 |             P_list.append(np.sum((pred_mask==i)*obj_mask))
83 |             T_list.append(np.sum((gt_mask==i)*obj_mask))
84 |             TP_list.append(np.sum((gt_mask==i)*correct_mask))
85 | 
86 |         return (P_list, T_list, TP_list)
87 | 
88 |     def add_using_data(self, data):
89 |         P_list, T_list, TP_list = data
90 |         for i in range(self.classes):
91 |             self.P[i] += P_list[i]
92 |             self.T[i] += T_list[i]
93 |             self.TP[i] += TP_list[i]
94 | 
95 |     def add(self, pred_mask, gt_mask):
96 |         obj_mask = gt_mask<255
97 |         correct_mask = (pred_mask==gt_mask) * obj_mask
98 | 
99 |         for i in range(self.classes):
100 |             self.P[i] += np.sum((pred_mask==i)*obj_mask)
101 |             self.T[i] += np.sum((gt_mask==i)*obj_mask)
102 |             self.TP[i] += np.sum((gt_mask==i)*correct_mask)
103 | 
104 |     def get(self, detail=False, clear=True):
105 |         IoU_dic = {}
106 |         IoU_list = []
107 | 
108 |         FP_list = [] # over activation
109 |         FN_list = [] # under activation
110 | 
111 |         for i in range(self.classes):
112 |             IoU = self.TP[i]/(self.T[i]+self.P[i]-self.TP[i]+1e-10) * 100
113 |             FP = (self.P[i]-self.TP[i])/(self.T[i] + self.P[i] - self.TP[i] + 1e-10)
114 |             FN = (self.T[i]-self.TP[i])/(self.T[i] + self.P[i] - self.TP[i] + 1e-10)
115 | 
116 |             IoU_dic[self.class_names[i]] = IoU
117 | 
118 |             IoU_list.append(IoU)
119 |             FP_list.append(FP)
120 |             FN_list.append(FN)
121 | 
122 |         mIoU = np.mean(np.asarray(IoU_list))
123 |         mIoU_foreground = np.mean(np.asarray(IoU_list)[1:])
124 | 
125 |         FP = np.mean(np.asarray(FP_list))
126 |         FN = np.mean(np.asarray(FN_list))
127 | 
128 |         if clear:
129 |             self.clear()
130 | 
131 |         if detail:
132 |             return mIoU, mIoU_foreground, IoU_dic, FP, FN
133 |         else:
134 |             return mIoU, mIoU_foreground
135 | 
136 |     def clear(self):
137 |         self.TP = []
138 |         self.P = []
139 |         self.T = []
140 | 
141 |         for _ in range(self.classes):
142 |             self.TP.append(0)
143 |             self.P.append(0)
144 |             self.T.append(0)
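A minimal sketch of how the `Calculator_For_mIoU` accumulator above is driven (the masks are illustrative):
```
import numpy as np
from tools.ai.evaluate_utils import Calculator_For_mIoU

meter = Calculator_For_mIoU()
pred_mask = np.zeros((4, 4), dtype=np.uint8)   # everything predicted as background
gt_mask = np.zeros((4, 4), dtype=np.uint8)
gt_mask[0, 0] = 1                              # one ground-truth Tumor pixel
meter.add(pred_mask, gt_mask)                  # pixels with gt == 255 are ignored
mIoU, mIoU_foreground = meter.get(clear=True)  # mean IoU over all / foreground classes, in percent
```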
--------------------------------------------------------------------------------
/tools/ai/log_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2020 * Ltd. All rights reserved.
2 | # author : Sanghyeon Jo
3 | 
4 | import numpy as np
5 | from tools.general.txt_utils import add_txt
6 | 
7 | def log_print(message, path):
8 |     """This function prints a message and appends it to a log file.
9 | 
10 |     Args:
11 |         message:
12 |             The type of variable is string.
13 |             The message to display and record.
14 | 
15 |         path:
16 |             The type of variable is string.
17 |             The path of the text file the message is appended to.
18 |     """
19 |     print(message)
20 |     add_txt(path, message)
21 | 
22 | class Logger:
23 |     def __init__(self):
24 |         pass
25 | 
26 | class Average_Meter:
27 |     def __init__(self, keys):
28 |         self.keys = keys
29 |         self.clear()
30 | 
31 |     def add(self, dic):
32 |         for key, value in dic.items():
33 |             self.data_dic[key].append(value)
34 | 
35 |     def get(self, keys=None, clear=False):
36 |         if keys is None:
37 |             keys = self.keys
38 | 
39 |         dataset = [float(np.mean(self.data_dic[key])) for key in keys]
40 |         if clear:
41 |             self.clear()
42 | 
43 |         if len(dataset) == 1:
44 |             dataset = dataset[0]
45 | 
46 |         return dataset
47 | 
48 |     def clear(self):
49 |         self.data_dic = {key : [] for key in self.keys}
50 | 
51 | 
--------------------------------------------------------------------------------
/tools/ai/optim_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .torch_utils import *
3 | 
4 | class PolyOptimizer(torch.optim.SGD):
5 |     def __init__(self, params, lr, weight_decay, max_step, momentum=0.9, nesterov=False):
6 |         super().__init__(params, lr, weight_decay=weight_decay, nesterov=nesterov)  # keyword args so weight_decay is not silently passed into SGD's momentum slot
7 | 
8 |         self.global_step = 0
9 |         self.max_step = max_step
10 |         self.momentum = momentum  # used below as the polynomial decay power, not as SGD momentum
11 | 
12 |         self.__initial_lr = [group['lr'] for group in self.param_groups]
13 | 
14 |     def step(self, closure=None):
15 |         if self.global_step < self.max_step:
16 |             lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
17 | 
18 |             for i in range(len(self.param_groups)):
19 |                 self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult
20 | 
21 |         super().step(closure)
22 | 
23 |         self.global_step += 1
24 | 
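The schedule above multiplies each group's initial learning rate by (1 - global_step / max_step) ** momentum, where `momentum` doubles as the decay power; a short worked sketch (the step values are illustrative):
```
# poly decay used by PolyOptimizer: lr_mult = (1 - step / max_step) ** power
max_step, power = 1000, 0.9
for step in (0, 500, 900):
    lr_mult = (1 - step / max_step) ** power
    print(step, round(lr_mult, 4))  # 0 -> 1.0, 500 -> ~0.5359, 900 -> ~0.1259
```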
--------------------------------------------------------------------------------
/tools/ai/randaugment.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from
2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py
4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py
5 | import logging
6 | import random
7 | 
8 | import numpy as np
9 | import PIL
10 | import PIL.ImageOps
11 | import PIL.ImageEnhance
12 | import PIL.ImageDraw
13 | from PIL import Image
14 | 
15 | logger = logging.getLogger(__name__)
16 | 
17 | PARAMETER_MAX = 10
18 | 
19 | 
20 | def AutoContrast(img, **kwarg):
21 |     return PIL.ImageOps.autocontrast(img)
22 | 
23 | 
24 | def Brightness(img, v, max_v, bias=0):
25 |     v = _float_parameter(v, max_v) + bias
26 |     return PIL.ImageEnhance.Brightness(img).enhance(v)
27 | 
28 | 
29 | def Color(img, v, max_v, bias=0):
30 |     v = _float_parameter(v, max_v) + bias
31 |     return PIL.ImageEnhance.Color(img).enhance(v)
32 | 
33 | 
34 | def Contrast(img, v, max_v, bias=0):
35 |     v = _float_parameter(v, max_v) + bias
36 |     return PIL.ImageEnhance.Contrast(img).enhance(v)
37 | 
38 | 
39 | def Cutout(img, v, max_v, bias=0):
40 |     if v == 0:
41 |         return img
42 |     v = _float_parameter(v, max_v) + bias
43 |     v = int(v * min(img.size))
44 |     return CutoutAbs(img, v)
45 | 
46 | 
47 | def CutoutAbs(img, v, **kwarg):
48 |     w, h = img.size
49 |     x0 = np.random.uniform(0, w)
50 |     y0 = np.random.uniform(0, h)
51 |     x0 = int(max(0, x0 - v / 2.))
52 |     y0 = int(max(0, y0 - v / 2.))
53 |     x1 = int(min(w, x0 + v))
54 |     y1 = int(min(h, y0 + v))
55 |     xy = (x0, y0, x1, y1)
56 | 
57 |     # gray
58 |     # color = (127, 127, 127)
59 | 
60 |     # black
61 |     color = (0, 0, 0)
62 | 
63 |     img = img.copy()
64 |     PIL.ImageDraw.Draw(img).rectangle(xy, color)
65 |     return img
66 | 
67 | 
68 | def Equalize(img, **kwarg):
69 |     return PIL.ImageOps.equalize(img)
70 | 
71 | 
72 | def Identity(img, **kwarg):
73 |     return img
74 | 
75 | 
76 | def Invert(img, **kwarg):
77 |     return PIL.ImageOps.invert(img)
78 | 
79 | 
80 | def Posterize(img, v, max_v, bias=0):
81 |     v = _int_parameter(v, max_v) + bias
82 |     return PIL.ImageOps.posterize(img, v)
83 | 
84 | 
85 | def Rotate(img, v, max_v, bias=0):
86 |     v = _int_parameter(v, max_v) + bias
87 |     if random.random() < 0.5:
88 |         v = -v
89 |     return img.rotate(v)
90 | 
91 | 
92 | def Sharpness(img, v, max_v, bias=0):
93 |     v = _float_parameter(v, max_v) + bias
94 |     return PIL.ImageEnhance.Sharpness(img).enhance(v)
95 | 
96 | 
97 | def ShearX(img, v, max_v, bias=0):
98 |     v = _float_parameter(v, max_v) + bias
99 |     if random.random() < 0.5:
100 |         v = -v
101 |     return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
102 | 
103 | 
104 | def ShearY(img, v, max_v, bias=0):
105 |     v = _float_parameter(v, max_v) + bias
106 |     if random.random() < 0.5:
107 |         v = -v
108 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
109 | 
110 | 
111 | def Solarize(img, v, max_v, bias=0):
112 |     v = _int_parameter(v, max_v) + bias
113 |     return PIL.ImageOps.solarize(img, 256 - v)
114 | 
115 | 
116 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
117 |     v = _int_parameter(v, max_v) + bias
118 |     if random.random() < 0.5:
119 |         v = -v
120 |     img_np = np.array(img).astype(int)
121 |     img_np = img_np + v
122 |     img_np = np.clip(img_np, 0, 255)
123 |     img_np = img_np.astype(np.uint8)
124 |     img = Image.fromarray(img_np)
125 |     return PIL.ImageOps.solarize(img, threshold)
126 | 
127 | 
128 | def TranslateX(img, v, max_v, bias=0):
129 |     v = _float_parameter(v, max_v) + bias
130 |     if random.random() < 0.5:
131 |         v = -v
132 |     v = int(v * img.size[0])
133 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
134 | 
135 | 
136 | def TranslateY(img, v, max_v, bias=0):
137 |     v = _float_parameter(v, max_v) + bias
138 |     if random.random() < 0.5:
139 |         v = -v
140 |     v = int(v * img.size[1])
141 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
142 | 
143 | 
144 | def _float_parameter(v, max_v):
145 |     return float(v) * max_v / PARAMETER_MAX
146 | 
147 | 
148 | def _int_parameter(v, max_v):
149 |     return int(v * max_v / PARAMETER_MAX)
150 | 
151 | 
152 | def fixmatch_augment_pool():
153 |     # FixMatch paper
154 |     augs = [(AutoContrast, None, None),
155 |             (Brightness, 0.9, 0.05),
156 |             (Color, 0.9, 0.05),
157 |             (Contrast, 0.9, 0.05),
158 |             (Equalize, None, None),
159 |             (Identity, None, None),
160 |             (Posterize, 4, 4),
161 |             (Rotate, 30, 0),
162 |             (Sharpness, 0.9, 0.05),
163 |             (ShearX, 0.3, 0),
164 |             (ShearY, 0.3, 0),
165 |             (Solarize, 256, 0),
166 |             (TranslateX, 0.3, 0),
167 |             (TranslateY, 0.3, 0)]
168 |     return augs
169 | 
170 | 
171 | def 
my_augment_pool(): 172 | # Test 173 | augs = [(AutoContrast, None, None), 174 | (Brightness, 1.8, 0.1), 175 | (Color, 1.8, 0.1), 176 | (Contrast, 1.8, 0.1), 177 | (Cutout, 0.2, 0), 178 | (Equalize, None, None), 179 | (Invert, None, None), 180 | (Posterize, 4, 4), 181 | (Rotate, 30, 0), 182 | (Sharpness, 1.8, 0.1), 183 | (ShearX, 0.3, 0), 184 | (ShearY, 0.3, 0), 185 | (Solarize, 256, 0), 186 | (SolarizeAdd, 110, 0), 187 | (TranslateX, 0.45, 0), 188 | (TranslateY, 0.45, 0)] 189 | return augs 190 | 191 | 192 | class RandAugmentPC(object): 193 | def __init__(self, n, m): 194 | assert n >= 1 195 | assert 1 <= m <= 10 196 | self.n = n 197 | self.m = m 198 | self.augment_pool = my_augment_pool() 199 | 200 | def __call__(self, img): 201 | ops = random.choices(self.augment_pool, k=self.n) 202 | for op, max_v, bias in ops: 203 | prob = np.random.uniform(0.2, 0.8) 204 | if random.random() + prob >= 1: 205 | img = op(img, v=self.m, max_v=max_v, bias=bias) 206 | img = CutoutAbs(img, int(32*0.5)) 207 | return img 208 | 209 | 210 | class RandAugmentMC(object): 211 | def __init__(self, n, m): 212 | assert n >= 1 213 | assert 1 <= m <= 10 214 | self.n = n 215 | self.m = m 216 | self.augment_pool = fixmatch_augment_pool() 217 | 218 | def __call__(self, img): 219 | ops = random.choices(self.augment_pool, k=self.n) 220 | for op, max_v, bias in ops: 221 | v = np.random.randint(1, self.m) 222 | if random.random() < 0.5: 223 | img = op(img, v=v, max_v=max_v, bias=bias) 224 | img = CutoutAbs(img, int(32*0.5)) 225 | return img 226 | -------------------------------------------------------------------------------- /tools/ai/torch_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import torch 4 | import random 5 | import numpy as np 6 | 7 | import torch.nn.functional as F 8 | 9 | from torch.optim.lr_scheduler import LambdaLR 10 | 11 | def set_seed(seed): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | 15 | torch.manual_seed(seed) 16 | if torch.cuda.is_available(): 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | def rotation(x, k): 20 | return torch.rot90(x, k, (1, 2)) 21 | 22 | def interleave(x, size): 23 | s = list(x.shape) 24 | return x.reshape([-1, size] + s[1:]).transpose(0, 1).reshape([-1] + s[1:]) 25 | 26 | def de_interleave(x, size): 27 | s = list(x.shape) 28 | return x.reshape([size, -1] + s[1:]).transpose(0, 1).reshape([-1] + s[1:]) 29 | 30 | def resize_for_tensors(tensors, size, mode='bilinear', align_corners=False): 31 | return F.interpolate(tensors, size, mode=mode, align_corners=align_corners) 32 | 33 | def L1_Loss(A_tensors, B_tensors): 34 | return torch.abs(A_tensors - B_tensors) 35 | 36 | def L2_Loss(A_tensors, B_tensors): 37 | return torch.pow(A_tensors - B_tensors, 2) 38 | 39 | # ratio = 0.2, top=20% 40 | def Online_Hard_Example_Mining(values, ratio=0.2): 41 | b, c, h, w = values.size() 42 | return torch.topk(values.reshape(b, -1), k=int(c * h * w * ratio), dim=-1)[0] 43 | 44 | def shannon_entropy_loss(logits, activation=torch.sigmoid, epsilon=1e-5): 45 | v = activation(logits) 46 | return -torch.sum(v * torch.log(v+epsilon), dim=1).mean() 47 | 48 | def make_cam(x, epsilon=1e-5): 49 | # relu(x) = max(x, 0) 50 | x = F.relu(x) 51 | 52 | b, c, h, w = x.size() 53 | 54 | flat_x = x.view(b, c, (h * w)) 55 | max_value = flat_x.max(axis=-1)[0].view((b, c, 1, 1)) 56 | 57 | return F.relu(x - epsilon) / (max_value + epsilon) 58 | 59 | def one_hot_embedding(label, classes): 60 | """Embedding labels to one-hot form. 
61 | 
62 |     Args:
63 |         label: (list of int) class indices to set to one.
64 |         classes: (int) number of classes.
65 | 
66 |     Returns:
67 |         (numpy array) one-hot encoded labels, sized [classes].
68 |     """
69 | 
70 |     vector = np.zeros((classes), dtype = np.float32)
71 |     if len(label) > 0:
72 |         vector[label] = 1.
73 |     return vector
74 | 
75 | def calculate_parameters(model):
76 |     return sum(param.numel() for param in model.parameters())/1000000.0
77 | 
78 | def get_learning_rate_from_optimizer(optimizer):
79 |     return optimizer.param_groups[0]['lr']
80 | 
81 | def get_numpy_from_tensor(tensor):
82 |     return tensor.cpu().detach().numpy()
83 | 
84 | def load_model(model, model_path, parallel=False):
85 |     if parallel:
86 |         model.module.load_state_dict(torch.load(model_path))
87 |     else:
88 |         model.load_state_dict(torch.load(model_path))
89 | 
90 | def save_model(model, model_path, parallel=False):
91 |     if parallel:
92 |         torch.save(model.module.state_dict(), model_path)
93 |     else:
94 |         torch.save(model.state_dict(), model_path)
95 | 
96 | def transfer_model(pretrained_model, model):
97 |     pretrained_dict = pretrained_model.state_dict()
98 |     model_dict = model.state_dict()
99 | 
100 |     pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict}
101 | 
102 |     model_dict.update(pretrained_dict)
103 |     model.load_state_dict(model_dict)
104 | 
105 | def get_learning_rate(optimizer):
106 |     lr = []
107 |     for param_group in optimizer.param_groups:
108 |         lr += [param_group['lr']]
109 |     return lr
110 | 
111 | def get_cosine_schedule_with_warmup(optimizer,
112 |                                     warmup_iteration,
113 |                                     max_iteration,
114 |                                     cycles=7./16.
115 |                                     ):
116 |     def _lr_lambda(current_iteration):
117 |         if current_iteration < warmup_iteration:
118 |             return float(current_iteration) / float(max(1, warmup_iteration))
119 | 
120 |         no_progress = float(current_iteration - warmup_iteration) / float(max(1, max_iteration - warmup_iteration))
121 |         return max(0., math.cos(math.pi * cycles * no_progress))
122 | 
123 |     return LambdaLR(optimizer, _lr_lambda, -1)
--------------------------------------------------------------------------------
/tools/dataset/__pycache__/voc_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/dataset/__pycache__/voc_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/dataset/voc_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def color_map(N = 256):
4 |     def bitget(byteval, idx):
5 |         return ((byteval & (1 << idx)) != 0)
6 | 
7 |     cmap = np.zeros((N, 3), dtype = np.uint8)
8 |     for i in range(N):
9 |         r = g = b = 0
10 |         c = i
11 |         for j in range(8):
12 |             r = r | (bitget(c, 0) << 7-j)
13 |             g = g | (bitget(c, 1) << 7-j)
14 |             b = b | (bitget(c, 2) << 7-j)
15 |             c = c >> 3
16 | 
17 |         cmap[i] = np.array([b, g, r])
18 | 
19 |     return cmap
20 | 
21 | def get_color_map_dic():
22 |     labels = ['background',
23 |               'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
24 |               'bus', 'car', 'cat', 'chair', 'cow',
25 |               'diningtable', 'dog', 'horse', 'motorbike', 'person',
26 |               'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor', 'void']
27 | 
28 |     # n_classes = 21
29 |     n_classes = len(labels)
30 | 
31 |     h = 20
32 |     w = 500
33 | 
34 |     color_index_list = [index for index in range(n_classes)]
35 | 
36 |     cmap = color_map()
37 |     cmap_dic = {label : cmap[color_index] for label, color_index in zip(labels, 
range(n_classes))} 38 | cmap_image = np.empty((h * len(labels), w, 3), dtype = np.uint8) 39 | 40 | for color_index in color_index_list: 41 | cmap_image[color_index * h : (color_index + 1) * h, :] = cmap[color_index] 42 | 43 | return cmap_dic, cmap_image, labels 44 | -------------------------------------------------------------------------------- /tools/general/__pycache__/io_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/general/__pycache__/io_utils.cpython-38.pyc -------------------------------------------------------------------------------- /tools/general/__pycache__/json_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/general/__pycache__/json_utils.cpython-38.pyc -------------------------------------------------------------------------------- /tools/general/__pycache__/time_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/general/__pycache__/time_utils.cpython-38.pyc -------------------------------------------------------------------------------- /tools/general/__pycache__/txt_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/general/__pycache__/txt_utils.cpython-38.pyc -------------------------------------------------------------------------------- /tools/general/__pycache__/xml_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/HAMIL/1b87c5085a17a904051b2cbd7b4e293cc36f1b62/tools/general/__pycache__/xml_utils.cpython-38.pyc -------------------------------------------------------------------------------- /tools/general/io_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2020 * Ltd. All rights reserved. 2 | # author : Sanghyeon Jo 3 | 4 | import os 5 | import random 6 | import argparse 7 | 8 | import numpy as np 9 | 10 | def create_directory(path): 11 | if not os.path.isdir(path): 12 | os.makedirs(path) 13 | return path 14 | 15 | def str2bool(v): 16 | if isinstance(v, bool): 17 | return v 18 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 19 | return True 20 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 21 | return False 22 | else: 23 | raise argparse.ArgumentTypeError('Boolean value expected.') 24 | 25 | -------------------------------------------------------------------------------- /tools/general/json_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2020 * Ltd. All rights reserved. 
2 | # author : Sanghyeon Jo 3 | 4 | import json 5 | 6 | def read_json(filepath): 7 | with open(filepath, 'r') as f: 8 | data = json.load(f) 9 | return data 10 | 11 | def write_json(filepath, data): 12 | with open(filepath, 'w') as f: 13 | json.dump(data, f, indent = '\t') 14 | 15 | -------------------------------------------------------------------------------- /tools/general/pickle_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2020 * Ltd. All rights reserved. 2 | # author : Sanghyeon Jo 3 | 4 | import pickle 5 | 6 | def dump_pickle(path, data): 7 | pickle.dump(data, open(path, 'wb')) 8 | 9 | def load_pickle(path): 10 | return pickle.load(open(path, 'rb')) 11 | 12 | -------------------------------------------------------------------------------- /tools/general/time_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2020 * Ltd. All rights reserved. 2 | # author : Sanghyeon Jo 3 | 4 | import time 5 | 6 | def get_today(): 7 | now = time.localtime() 8 | s = "%04d-%02d-%02d-%02dh%02dm%02ds" % (now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec) 9 | return s 10 | 11 | class Timer: 12 | def __init__(self): 13 | self.start_time = 0.0 14 | self.end_time = 0.0 15 | 16 | self.tik() 17 | 18 | def tik(self): 19 | self.start_time = time.time() 20 | 21 | def tok(self, ms = False, clear=False): 22 | self.end_time = time.time() 23 | 24 | if ms: 25 | duration = int((self.end_time - self.start_time) * 1000) 26 | else: 27 | duration = int(self.end_time - self.start_time) 28 | 29 | if clear: 30 | self.tik() 31 | 32 | return duration -------------------------------------------------------------------------------- /tools/general/txt_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2020 * Ltd. All rights reserved. 2 | # author : Sanghyeon Jo 3 | 4 | def read_txt(path): 5 | with open(path, 'r') as f: 6 | return [line.strip() for line in f.readlines()] 7 | 8 | def write_txt(path, data_list): 9 | with open(path, 'w') as f: 10 | for data in data_list: 11 | f.write(data + '\n') 12 | 13 | def add_txt(path, string): 14 | with open(path, 'a+') as f: 15 | f.write(string + '\n') -------------------------------------------------------------------------------- /tools/general/xml_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2020 * Ltd. All rights reserved. 
2 | # author : Sanghyeon Jo
3 | 
4 | import xml.etree.ElementTree as ET
5 | 
6 | def read_xml(xml_path):
7 |     tree = ET.parse(xml_path)
8 |     root = tree.getroot()
9 | 
10 |     size = root.find('size')
11 |     image_width = int(size.find('width').text)
12 |     image_height = int(size.find('height').text)
13 | 
14 |     bboxes = []
15 |     classes = []
16 | 
17 |     for obj in root.findall('object'):
18 |         label = obj.find('name').text
19 |         bbox = obj.find('bndbox')
20 | 
21 |         bbox_xmin = max(min(int(bbox.find('xmin').text.split('.')[0]), image_width - 1), 0)
22 |         bbox_ymin = max(min(int(bbox.find('ymin').text.split('.')[0]), image_height - 1), 0)
23 |         bbox_xmax = max(min(int(bbox.find('xmax').text.split('.')[0]), image_width - 1), 0)
24 |         bbox_ymax = max(min(int(bbox.find('ymax').text.split('.')[0]), image_height - 1), 0)
25 | 
26 |         if (bbox_xmax - bbox_xmin) == 0 or (bbox_ymax - bbox_ymin) == 0:
27 |             continue
28 | 
29 |         bboxes.append([bbox_xmin, bbox_ymin, bbox_xmax, bbox_ymax])
30 |         classes.append(label)
31 | 
32 |     return bboxes, classes
33 | 
--------------------------------------------------------------------------------
/train_cls.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from utils.LoadData import Wsss_test_dataset
3 | from utils.LoadData import Wsss_dataset
4 | import torchvision.transforms.functional as transF
5 | from torchvision import transforms
6 | from PIL import Image
7 | from utils.Metrics import DiceMetric
8 | from networks.ham_net import ham_net
9 | import torch.nn.functional as F
10 | from utils.utils import bg2mask, monte_augmentation
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | import argparse
14 | import torch
15 | import numpy as np
16 | import sys
17 | import os
18 | import time
19 | import logging
20 | sys.path.append(os.getcwd())
21 | 
22 | 
23 | def seg_to_color(seg):
24 |     H, W = seg.shape[1], seg.shape[2]
25 |     # white, green, blue, yellow
26 |     classes = ["background", "Tumor", "Stroma", "Normal"]
27 |     color_map = [[255, 255, 255], [0, 64, 128], [64, 128, 0], [243, 152, 0]]
28 |     img = np.zeros((H, W, 3))
29 |     for i in range(H):
30 |         for j in range(W):
31 |             img[i, j, :] = color_map[seg[0, i, j]]
32 |     return img
33 | 
34 | def get_arguments():
35 |     parser = argparse.ArgumentParser(description="HAMIL pytorch implementation")
36 |     parser.add_argument("--dataset_root", type=str,
37 |                         default="", help="training images")
38 |     parser.add_argument("--batch_size", type=int,
39 |                         default=32, help="Train batch size")
40 |     parser.add_argument("--num_classes", type=int,
41 |                         default=3, help="Train class num")
42 |     parser.add_argument("--lr", type=float, default=1e-3)
43 |     parser.add_argument("--weight_decay", type=float, default=5e-4)
44 |     parser.add_argument("--gpu", nargs="+", type=int)
45 |     parser.add_argument("--train_epochs", default=100, type=int)
46 |     parser.add_argument("--save_folder", default="checkpoints")
47 |     parser.add_argument("--checkpoint", type=str, default="")
48 |     parser.add_argument("--input_size", type=int, default=256)
49 |     parser.add_argument("--crop_size", type=int, default=224)
50 | 
51 |     return parser.parse_args()
52 | 
53 | def get_model(args, pre_trained=False):
54 |     model = ham_net()
55 |     if pre_trained:
56 |         ckpt = torch.load(args.checkpoint, map_location="cpu")
57 |         model.load_state_dict(ckpt, strict=True)  # checkpoints are saved as raw state_dicts below
58 | 
59 |     model = torch.nn.DataParallel(model, device_ids=args.gpu)
60 |     param_groups = model.module.get_parameter_groups()
61 |     optimizer = optim.SGD(
62 |         [
63 | 
{"params": param_groups[0], "lr": args.lr}, 64 | {"params": param_groups[1], "lr": 2 * args.lr}, 65 | {"params": param_groups[2], "lr": 10 * args.lr}, 66 | {"params": param_groups[3], "lr": 20 * args.lr}, 67 | ], 68 | momentum=0.9, 69 | weight_decay=args.weight_decay, 70 | nesterov=True, 71 | ) 72 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40, 80]) 73 | return model, optimizer, scheduler 74 | 75 | 76 | def train(model, optimizer, train_dataloader): 77 | model.train() 78 | loss_epoch = 0 79 | for img, label, _ in train_dataloader: 80 | 81 | img, label = img.cuda(), label.cuda() 82 | logit = model(img) 83 | 84 | # loss 85 | loss = F.multilabel_soft_margin_loss(logit, label) 86 | loss_epoch += loss 87 | 88 | optimizer.zero_grad() 89 | loss.backward() 90 | optimizer.step() 91 | print(f"train loss:{loss_epoch.item()/len(train_dataloader)}") 92 | 93 | def train_deep(model, optimizer, train_dataloader): 94 | model.train() 95 | loss_epoch = 0 96 | cls_acc = 0 97 | count = 0 98 | 99 | for img, label, _ in train_dataloader: 100 | 101 | img, label = img.cuda(), label.cuda() 102 | logit_b6, logit_b5, logit_b4 = model(img) 103 | 104 | # compute classification loss 105 | loss1 = F.multilabel_soft_margin_loss(logit_b6, label) 106 | loss2 = F.multilabel_soft_margin_loss(logit_b5, label) 107 | loss3 = F.multilabel_soft_margin_loss(logit_b4, label) 108 | 109 | loss = (loss1+loss2+loss3)/3 110 | 111 | loss_epoch += loss 112 | # compute cls 113 | label_cpu = label.cpu().detach().numpy() 114 | logit_cpu = logit_b6.cpu().detach().numpy() 115 | logit_cpu = logit_cpu > 0 116 | correct_num = np.sum(label_cpu == logit_cpu, axis=0) 117 | cls_acc += correct_num 118 | count += label_cpu.shape[0] 119 | 120 | optimizer.zero_grad() 121 | loss.backward() 122 | optimizer.step() 123 | print(f"train loss:{loss_epoch.item()/len(train_dataloader)}") 124 | return sum(cls_acc) / count / 3 125 | 126 | def compute_dice(model, valid_dataloader, verbose=False, save=False): 127 | model.eval() 128 | # Dice_metric 129 | Dice_Metric = DiceMetric(4) 130 | # cls acc 131 | cls_acc = 0 132 | count = 0 133 | 134 | # my background 135 | my_background_root = "/mnt/data1/dataset/WSSS4LUAD/2.validation/my_bg_mask_patch_256/" 136 | """for every image, compute the dice""" 137 | with torch.no_grad(): 138 | for img, label, bg_mask, gt, raw_img, img_name in valid_dataloader: 139 | # H, W 140 | H, W = gt.shape[1], gt.shape[2] 141 | 142 | img, label = img.cuda(), label.cuda() 143 | # logit, cam = model(img, True, (H, W)) 144 | # multi-scale 145 | img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 146 | scale = 256 147 | cam_a = torch.zeros((1, 3, H, W)).cuda() 148 | logit = torch.zeros((1, 3)).cuda() 149 | img_trans = transforms.Compose([transforms.Resize( 150 | (scale, scale)), transforms.ToTensor(), transforms.Normalize(img_mean, img_std)]) 151 | img_path = valid_dataset.img_path + "/" + img_name[0] 152 | _img = Image.open(img_path).convert("RGB") 153 | _img = img_trans(_img) 154 | _img = torch.unsqueeze(_img, 0) 155 | _img = _img.cuda() 156 | 157 | logit1, _, _, _, _, _ = model(_img, True, (H, W)) 158 | logit = logit1 159 | 160 | cam_a = monte_augmentation(20, model, img_path, H, W) 161 | cam = cam_a.clone() 162 | 163 | # compute cls 164 | label_cpu = label.cpu().detach().numpy() 165 | logit_cpu = logit.cpu().detach().numpy() 166 | logit_cpu = logit_cpu > 0 167 | correct_num = np.sum(label_cpu == logit_cpu, axis=0) 168 | cls_acc += correct_num 169 | count += label_cpu.shape[0] 170 | 171 | cam = cam.detach() * 
label[:, :, None, None] 172 | my_bg_mask = Image.open(my_background_root + img_name[0]) 173 | my_bg_mask = np.array(my_bg_mask, np.uint8) 174 | # compute dice 175 | Dice_Metric.add_batch(cam, gt, my_bg_mask, label_cpu) 176 | print(f"cls_acc:{sum(cls_acc)/count/3}", cls_acc, count, cls_acc / count) 177 | return Dice_Metric.compute_dice(verbose=verbose, save=save), sum(cls_acc)/count/3 178 | 179 | 180 | def save_pic(model, dataloader): 181 | model.eval() 182 | # my background 183 | my_background_root = "/mnt/data1/dataset/WSSS4LUAD/1.training/gamma_crf_train/" 184 | with torch.no_grad(): 185 | for img, label, img_name in dataloader: 186 | # H, W 187 | H, W = img.shape[2], img.shape[3] 188 | img, label = img.cuda(), label.cuda() 189 | 190 | img_path = test_dataset.img_path + "/" + img_name[0] 191 | # logit, cam = model(img, True, (H, W)) 192 | img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 193 | s = [256] 194 | cam_a = torch.zeros((1, 3, H, W)).cuda() 195 | logit = torch.zeros((1, 3)).cuda() 196 | for scale in s: 197 | img_trans = transforms.Compose([transforms.Resize( 198 | (scale, scale)), transforms.ToTensor(), transforms.Normalize(img_mean, img_std)]) 199 | 200 | _img = Image.open(img_path).convert("RGB") 201 | _img = img_trans(_img) 202 | _img = torch.unsqueeze(_img, 0) 203 | _img = _img.cuda() 204 | 205 | # deep3 206 | logit, _, _, cam_b6, cam_b5, cam_b4 = model( 207 | _img, True, (H, W)) 208 | cam_a += (cam_b4+cam_b5+cam_b6)/3 209 | 210 | cam_a = monte_augmentation(20, model, img_path, H, W) 211 | # cam_a = cam_a/len(s) 212 | logit = logit/len(s) 213 | 214 | cam = cam_a.clone() 215 | 216 | # cam = cam.detach() * label[:, :, None, None] 217 | cam = cam.detach() * logit[:, :, None, None] + 1e-7 218 | my_bg_mask = Image.open(my_background_root + img_name[0]) 219 | my_bg_mask = np.array(my_bg_mask, np.uint8) 220 | 221 | # save pic 222 | cam_with_bg = np.concatenate( 223 | (np.expand_dims(my_bg_mask, 0), cam[0].cpu().numpy()), axis=0) 224 | segmentation = cam_with_bg.argmax(0) 225 | segmentation = np.reshape(segmentation, (1, H, W)) 226 | 227 | color_img = seg_to_color(segmentation) 228 | color_img = color_img.astype(np.uint8) 229 | color_img = Image.fromarray(color_img) 230 | color_img.save(f"pseudo_masks/stage1/" + img_name[0]) 231 | 232 | 233 | if __name__ == "__main__": 234 | args = get_arguments() 235 | logging.basicConfig(level=logging.INFO, filename=f'log/train_cls_deep.txt') 236 | 237 | time_start = time.time() 238 | # training and validation dataset 239 | train_dataset = Wsss_dataset(args, train=True) 240 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=6) 241 | 242 | valid_dataset = Wsss_test_dataset(args, test=False) 243 | valid_dataloader = DataLoader(valid_dataset, batch_size=1) 244 | 245 | test_dataset = Wsss_test_dataset(args, test=True) 246 | test_dataloader = DataLoader(test_dataset, batch_size=1) 247 | 248 | # network and optimizer 249 | model, optimizer, scheduler = get_model(args, pre_trained=False) 250 | 251 | # # gpu setting 252 | torch.cuda.set_device(args.gpu[0]) 253 | model.cuda() 254 | 255 | if not os.path.exists(args.save_folder): 256 | os.makedirs(args.save_folder) 257 | 258 | cls_acc_max = 0 259 | 260 | for i in range(args.train_epochs): 261 | print(f"\nepoch:{i+1} \n-------------------") 262 | 263 | t0 = time.time() 264 | train_cls_acc = train_deep(model, optimizer, train_dataloader) 265 | t1 = time.time() 266 | valid_dice, valid_cls_acc = compute_dice(model, valid_dataloader, verbose=True) 
267 |         t2 = time.time()
268 |         print("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1))
269 |         logging.info('train cls_acc {0:.4f}, valid cls_acc {1:.4f}'.format(train_cls_acc, valid_cls_acc))
270 | 
271 |         scheduler.step()
272 |         if valid_dice > cls_acc_max:
273 |             cls_acc_max = valid_dice
274 |             # save the best model
275 |             ckpt = model.module.state_dict()
276 |             print("current best model")
277 |             torch.save(ckpt, args.checkpoint)
278 | 
279 |     model_test = ham_net()
280 |     model_test.cuda()
281 |     ckpt = torch.load(args.checkpoint, map_location="cpu")
282 |     model_test.load_state_dict(ckpt, strict=True)  # the checkpoint above is a raw state_dict
283 |     compute_dice(model_test, test_dataloader, verbose=True, save='log/ham_net.csv')
284 |     save_pic(model, train_dataloader)
285 |     time_end = time.time()
286 |     print(f'done, time:{(time_end-time_start)/60}')
287 | 
--------------------------------------------------------------------------------
/train_seg.py:
--------------------------------------------------------------------------------
1 | from utils.Metrics import DiceMetric
2 | from core.networks import DeepLabv3_Plus
3 | from torch.utils.data import DataLoader
4 | from utils.LoadData_with_bg import Wsss_dataset, Wsss_test_dataset
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import argparse
9 | import torch
10 | from PIL import Image
11 | import numpy as np
12 | import sys
13 | import os
14 | import time
15 | import logging
16 | from utils.utils import monte_augmentation
17 | sys.path.append(os.getcwd())
18 | 
19 | 
20 | class KDLoss(nn.Module):
21 |     """
22 |     Distilling the Knowledge in a Neural Network
23 |     https://arxiv.org/pdf/1503.02531.pdf
24 |     """
25 | 
26 |     def __init__(self, T):
27 |         super(KDLoss, self).__init__()
28 |         self.T = T
29 | 
30 |     def forward(self, out_s, out_t):
31 |         loss = (
32 |             F.kl_div(F.log_softmax(out_s / self.T, dim=1),
33 |                      F.softmax(out_t / self.T, dim=1), reduction="batchmean")
34 |             * self.T
35 |             * self.T
36 |         )
37 |         return loss
38 | 
39 | def get_arguments():
40 |     parser = argparse.ArgumentParser(description="HAMIL pytorch implementation")
41 |     parser.add_argument("--dataset_root", type=str,
42 |                         default="", help="training images")
43 |     parser.add_argument("--batch_size", type=int,
44 |                         default=32, help="Train batch size")
45 |     parser.add_argument("--lr", type=float, default=2e-3)
46 |     parser.add_argument("--wd", type=float, default=5e-4)
47 | 
48 |     parser.add_argument("--gpu", nargs="+", type=int)
49 |     parser.add_argument("--max_epoch", default=100, type=int)
50 |     parser.add_argument("--save_folder", default="checkpoints2/train_seg")
51 |     parser.add_argument("--checkpoint1", type=str, default="")
52 |     parser.add_argument("--checkpoint2", type=str, default="")
53 |     parser.add_argument("--checkpoint3", type=str, default="")
54 | 
55 |     parser.add_argument("--image_resize", default=256, type=int)
56 |     parser.add_argument("--image_crop", default=224, type=int)
57 |     parser.add_argument("--test_image_size", default=224, type=int)
58 |     parser.add_argument("--alpha", type=float,
59 |                         default=0.2, help="Weight factor")
60 |     parser.add_argument("--T", type=float, default=30,
61 |                         help="Temperature for KD")
62 |     return parser.parse_args()
63 | 
64 | def get_model(args):
65 |     model = DeepLabv3_Plus("resnet50", use_group_norm=True)
66 |     model = torch.nn.DataParallel(model, device_ids=args.gpu)
67 |     param_groups = model.module.get_parameter_groups(None)
68 |     optimizer = optim.SGD(
69 |         [
70 |             {"params": param_groups[0], "lr": args.lr},
71 |             {"params": param_groups[1], "lr": 2 * 
args.lr}, 72 | {"params": param_groups[2], "lr": 10 * args.lr}, 73 | {"params": param_groups[3], "lr": 20 * args.lr}, 74 | ], 75 | momentum=0.9, 76 | weight_decay=args.wd, 77 | nesterov=True, 78 | ) 79 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40, 80]) 80 | return model, optimizer, scheduler 81 | 82 | def selection(outputs_main, outputs_aux1, outputs_aux2, mask): 83 | n = outputs_main.shape[0] 84 | loss_main = F.cross_entropy( 85 | outputs_main, mask, reduction='none').view(n, -1) 86 | hard_aux1 = torch.argmax(outputs_aux1, dim=1).view(n, -1) 87 | hard_aux2 = torch.argmax(outputs_aux2, dim=1).view(n, -1) 88 | loss_select = 0 89 | for i in range(n): 90 | aux1_sample = hard_aux1[i] 91 | aux2_sample = hard_aux2[i] 92 | loss_sample = loss_main[i] 93 | agree_aux = (aux1_sample == aux2_sample) 94 | disagree_aux = (aux1_sample != aux2_sample) 95 | loss_select += 2*torch.sum(loss_sample[agree_aux]) + \ 96 | 0.5*torch.sum(loss_sample[disagree_aux]) 97 | return loss_select / (n*loss_main.shape[1]) 98 | 99 | 100 | def weight_loss(loss): 101 | n = loss.shape[0] 102 | loss = loss.view(n, -1) 103 | loss_weight = F.softmax(loss.clone().detach(), dim=1) / torch.mean( 104 | F.softmax(loss.clone().detach(), dim=1), dim=1, keepdim=True 105 | ) 106 | loss = torch.sum(loss * loss_weight) / (n * loss.shape[1]) 107 | return loss 108 | 109 | def joint_optimization(outputs_main, outputs_aux1, outputs_aux2, mask, kd_weight, kd_T): 110 | kd_loss = KDLoss(T=kd_T) 111 | avg_aux = (outputs_aux1 + outputs_aux2) / 2 112 | 113 | L_kd = kd_loss(outputs_main.permute(0, 2, 3, 1).reshape(-1, 4), 114 | avg_aux.permute(0, 2, 3, 1).reshape(-1, 4)) 115 | L_ce = selection(outputs_main, outputs_aux1, outputs_aux2, mask) 116 | L = L_ce + kd_weight * L_kd 117 | return L 118 | 119 | def train_tri(model1, model2, model3, optimizer1, optimizer2, optimizer3, train_dataloader, args, epoch): 120 | model1.train() 121 | model2.train() 122 | model3.train() 123 | train_loss_epoch = 0.0 124 | 125 | for img, mask, soft_mask, img_names in train_dataloader: 126 | n, c, h, w = img.size() 127 | img, mask = img.cuda(), mask.type(torch.LongTensor).cuda() 128 | soft_mask = soft_mask.to(torch.double).cuda() 129 | 130 | outputs1 = model1(img) 131 | outputs2 = model2(img) 132 | outputs3 = model3(img) 133 | 134 | loss1 = joint_optimization(outputs1, outputs2.detach(), 135 | outputs3.detach(), mask, args.alpha, args.T) 136 | loss2 = joint_optimization(outputs2, outputs3.detach(), 137 | outputs1.detach(), mask, args.alpha, args.T) 138 | loss3 = joint_optimization(outputs3, outputs1.detach(), 139 | outputs2.detach(), mask, args.alpha, args.T) 140 | 141 | optimizer1.zero_grad() 142 | loss1.backward() 143 | optimizer1.step() 144 | optimizer2.zero_grad() 145 | loss2.backward() 146 | optimizer2.step() 147 | optimizer3.zero_grad() 148 | loss3.backward() 149 | optimizer3.step() 150 | 151 | train_loss_epoch += (loss1 + loss2 + loss3).item() / 3 152 | print(f"tri-kd, train_loss:{train_loss_epoch/len(train_dataloader):.4f}") 153 | 154 | def validate(model1, model2, model3, valid_dataloader, verbose=False, Monte=False, save=False): 155 | Dice_Metric = DiceMetric(4) 156 | model1.eval() 157 | model2.eval() 158 | model3.eval() 159 | my_background_root = "/mnt/data1/dataset/WSSS4LUAD/2.validation/my_bg_mask_patch_256/" 160 | with torch.no_grad(): 161 | for img1, mask, img_label, img_names in valid_dataloader: 162 | img1 = img1.cuda() 163 | ori_size = mask.shape[1:] 164 | H, W = ori_size[0], ori_size[1] 165 | 166 | # precomputed background mask for this patch 167 | my_bg_mask = 
Image.open(my_background_root + img_names[0]) 168 | my_bg_mask = np.array(my_bg_mask, np.uint8) 169 | 170 | img_path = valid_dataset.img_path + "/" + img_names[0] 171 | 172 | # Monte Carlo test-time augmentation 173 | if not Monte: 174 | pred1 = (model1(img1) + model2(img1) + model3(img1)) / 3 175 | else: 176 | pred1 = ( 177 | monte_augmentation(20, model1, img_path, 178 | H, W, args.test_image_size) 179 | + monte_augmentation(20, model2, img_path, 180 | H, W, args.test_image_size) 181 | + monte_augmentation(20, model3, img_path, 182 | H, W, args.test_image_size) 183 | ) / 3 184 | pred1 = F.interpolate(pred1, size=ori_size, 185 | mode="bilinear", align_corners=True) 186 | pred = pred1 187 | 188 | Dice_Metric.add_batch(pred, mask, my_bg_mask, img_label) 189 | if not verbose: 190 | return Dice_Metric.compute_dice(False, save) 191 | else: 192 | return Dice_Metric.compute_dice(True, save) 193 | 194 | 195 | if __name__ == "__main__": 196 | start_time = time.time() 197 | args = get_arguments() 198 | logging.basicConfig(level=logging.INFO, filename='log/train_tri_kd.txt') 199 | # training and validation dataset 200 | train_dataset = Wsss_dataset(args) 201 | train_dataloader = DataLoader( 202 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=6) 203 | 204 | train_dataloader2 = DataLoader( 205 | train_dataset, batch_size=1, shuffle=True, num_workers=6) 206 | 207 | valid_dataset = Wsss_test_dataset(args, valid=True) 208 | valid_dataloader = DataLoader(valid_dataset, batch_size=1) 209 | 210 | test_dataset = Wsss_test_dataset(args, valid=False) 211 | test_dataloader = DataLoader(test_dataset, batch_size=1) 212 | 213 | # network and optimizer 214 | model1, optimizer1, scheduler1 = get_model(args) 215 | model2, optimizer2, scheduler2 = get_model(args) 216 | model3, optimizer3, scheduler3 = get_model(args) 217 | # use gpu 218 | torch.cuda.set_device(args.gpu[0]) 219 | model1.cuda() 220 | model2.cuda() 221 | model3.cuda() 222 | 223 | if not os.path.exists(args.save_folder): 224 | os.makedirs(args.save_folder) 225 | 226 | dice_max = 0 227 | for i in range(args.max_epoch): 228 | print(f"\nepoch:{i+1} \n-------------------") 229 | t0 = time.time() 230 | train_tri(model1, model2, model3, optimizer1, optimizer2, 231 | optimizer3, train_dataloader, args, i) 232 | t1 = time.time() 233 | valid_dice = validate(model1, model2, model3, valid_dataloader, verbose=True) 234 | t2 = time.time() 235 | print("training/validation time: {0:.2f}s/{1:.2f}s".format(t1 - t0, t2 - t1)) 236 | logging.info('valid dice {0:.4f}'.format(valid_dice)) 237 | 238 | scheduler1.step() 239 | scheduler2.step() 240 | scheduler3.step() 241 | if valid_dice > dice_max: 242 | dice_max = valid_dice 243 | # save the best checkpoints 244 | ckpt1 = model1.module.state_dict() 245 | ckpt2 = model2.module.state_dict() 246 | ckpt3 = model3.module.state_dict() 247 | print("current best model dice=", dice_max) 248 | torch.save(ckpt1, os.path.join( 249 | args.save_folder, "tri-kd-pre1.pth")) 250 | torch.save(ckpt2, os.path.join( 251 | args.save_folder, "tri-kd-pre2.pth")) 252 | torch.save(ckpt3, os.path.join( 253 | args.save_folder, "tri-kd-pre3.pth")) 254 | # test 255 | print("test") 256 | model1 = DeepLabv3_Plus("resnet50", use_group_norm=True, mode="fix") 257 | model1.cuda() 258 | model1.eval() 259 | ckpt = torch.load(args.checkpoint1, map_location="cpu") 260 | model1.load_state_dict(ckpt, strict=True) 261 | 262 | model2 = DeepLabv3_Plus("resnet50", use_group_norm=True, mode="fix") 263 | model2.cuda() 264 | model2.eval() 265 | ckpt = torch.load(args.checkpoint2, map_location="cpu") 
266 | model2.load_state_dict(ckpt, strict=True) 267 | 268 | model3 = DeepLabv3_Plus("resnet50", use_group_norm=True, mode="fix") 269 | model3.cuda() 270 | model3.eval() 271 | ckpt = torch.load(args.checkpoint3, map_location="cpu") 272 | model3.load_state_dict(ckpt, strict=True) 273 | 274 | validate(model1, model2, model3, test_dataloader, verbose=True) 275 | end_time = time.time() 276 | print("running time / mins:", (end_time - start_time) / 60) 277 | -------------------------------------------------------------------------------- /utils/CRF.py: -------------------------------------------------------------------------------- 1 | import pydensecrf.densecrf as dcrf 2 | import numpy as np 3 | 4 | 5 | class MyCRF: 6 | def __init__( 7 | self, 8 | pos_xy_std=1, 9 | pos_w=15, 10 | bi_xy_std=16, 11 | bi_rgb_std=4, 12 | bi_w=10, 13 | maxiter=10, 14 | scale_factor=1.0, 15 | ): 16 | self.pos_xy_std = pos_xy_std 17 | self.pos_w = pos_w 18 | self.bi_xy_std = bi_xy_std 19 | self.bi_rgb_std = bi_rgb_std 20 | self.bi_w = bi_w 21 | self.maxiter = maxiter 22 | self.scale_factor = scale_factor 23 | 24 | def inference(self, im, unary): 25 | H, W = im.shape[:2] 26 | C = unary.shape[0] 27 | d = dcrf.DenseCRF2D(W, H, C) 28 | d.setUnaryEnergy(-unary.reshape(C, -1)) 29 | d.addPairwiseGaussian( 30 | sxy=( 31 | self.pos_xy_std / self.scale_factor, 32 | self.pos_xy_std / self.scale_factor, 33 | ), 34 | compat=self.pos_w, 35 | kernel=dcrf.DIAG_KERNEL, 36 | normalization=dcrf.NORMALIZE_SYMMETRIC, 37 | ) 38 | 39 | d.addPairwiseBilateral( 40 | sxy=( 41 | self.bi_xy_std / self.scale_factor, 42 | self.bi_xy_std / self.scale_factor, 43 | ), 44 | srgb=(self.bi_rgb_std, self.bi_rgb_std, self.bi_rgb_std), 45 | rgbim=im.astype(np.uint8), 46 | compat=self.bi_w, 47 | kernel=dcrf.DIAG_KERNEL, 48 | normalization=dcrf.NORMALIZE_SYMMETRIC, 49 | ) 50 | prediction = np.array(d.inference(self.maxiter), dtype=np.float32).reshape( 51 | (C, H, W) 52 | ) 53 | 54 | return prediction 55 | -------------------------------------------------------------------------------- /utils/LoadData.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | sys.path.append(os.getcwd()) 5 | import numpy as np 6 | import torch 7 | from torchvision import transforms 8 | from torch.utils.data import DataLoader, Dataset 9 | from PIL import Image 10 | import re 11 | from torchvision.transforms import functional as F 12 | import random 13 | 14 | class Normalization(object): 15 | def __call__(self, image): 16 | image_cpu = image.cpu().detach().numpy() 17 | mean_image = [0, 0, 0] 18 | std_image = [0, 0, 0] 19 | r, g, b = ( 20 | image_cpu[0, :, :].copy(), 21 | image_cpu[1, :, :].copy(), 22 | image_cpu[2, :, :].copy(), 23 | ) 24 | mean_image[0], mean_image[1], mean_image[2] = ( 25 | np.mean(r, axis=(0, 1)), 26 | np.mean(g, axis=(0, 1)), 27 | np.mean(b, axis=(0, 1)), 28 | ) 29 | std_image[0], std_image[1], std_image[2] = ( 30 | np.std(r, axis=(0, 1)), 31 | np.std(g, axis=(0, 1)), 32 | np.std(b, axis=(0, 1)), 33 | ) 34 | image = F.normalize(image, mean=mean_image, std=std_image) 35 | return image 36 | 37 | 38 | class Wsss_dataset(Dataset): 39 | def __init__(self, args, train=True): 40 | super().__init__() 41 | self.dataset_root = args.dataset_root + "/" + "1.training/" 42 | self.my_bg_path = "/mnt/data1/dataset/WSSS4LUAD/1.training/gamma_crf_train/" 43 | self.input_size = args.input_size 44 | self.crop_size = args.crop_size 45 | img_paths = os.listdir(self.dataset_root) 46 | self.img_paths = 
self._filter(img_paths) 47 | self.mean, self.std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 48 | # shuffle 49 | np.random.seed(0) 50 | np.random.shuffle(self.img_paths) 51 | 52 | self.img_paths = self.get_new_set(self.img_paths) 53 | # self.img_paths = self.get_single_cls(self.img_paths) 54 | print(f'train images:{len(self.img_paths)}') 55 | if train: 56 | self.img_transform = transforms.Compose( 57 | [ 58 | transforms.Resize((self.input_size, self.input_size)), 59 | # transforms.Resize((random.randint(args.min_image_size, args.max_image_size),)), 60 | transforms.RandomHorizontalFlip(), 61 | transforms.RandomVerticalFlip(), 62 | transforms.RandomRotation((0, 180)), 63 | transforms.ToTensor(), 64 | transforms.Normalize(self.mean, self.std), 65 | # [0.6270, 0.5013, 0.7519], [0.1627, 0.1682, 0.0977] another dataset. 66 | # [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] ImageNet 67 | # [0.737, 0.505, 0.678], [0.176, 0.210, 0.146] WSSS train dataset 68 | ] 69 | ) 70 | else: 71 | self.img_paths = self.img_paths[:]  # keep the full list when generating pseudo masks 72 | print(f'generate pseudo masks:{len(self.img_paths)}') 73 | self.img_transform = transforms.Compose( 74 | [ 75 | transforms.Resize((self.input_size, self.input_size)), 76 | transforms.ToTensor(), 77 | transforms.Normalize(self.mean, self.std), 78 | ] 79 | ) 80 | 81 | def _filter(self, img_paths): 82 | img_list_new = [] 83 | for img_name in img_paths: 84 | if img_name.endswith(".png"): 85 | img_list_new.append(img_name) 86 | 87 | return img_list_new 88 | 89 | 90 | def __getitem__(self, index): 91 | img_name = self.img_paths[index] 92 | img_path = self.dataset_root + "/" + img_name 93 | label = self.get_label(img_name) 94 | img = Image.open(img_path).convert("RGB") 95 | 96 | raw_img = np.array(img, dtype=np.uint8) 97 | img = self.img_transform(img) 98 | 99 | return img, label, img_name 100 | 101 | def get_label(self, img_name): 102 | res = re.findall(r"\[(.*?)\]", img_name) 103 | label = torch.tensor(list(eval(res[0])))  # parse the label vector embedded in the file name 104 | return label 105 | 106 | def get_single_cls(self, img_names): 107 | new_names = [] 108 | for img_name in img_names: 109 | label = self.get_label(img_name) 110 | if sum(label) == 1: 111 | new_names.append(img_name) 112 | return new_names 113 | 114 | def get_new_set(self, img_names): 115 | new_names = [] 116 | for img_name in img_names: 117 | label = self.get_label(img_name) 118 | if sum(label) != 1: 119 | new_names.append(img_name) 120 | single_names = self.get_single_cls(img_names) 121 | portion = 0.25 122 | single_names = single_names[:int(portion*len(single_names))] 123 | new_names = new_names + single_names 124 | return new_names 125 | 126 | def __len__(self): 127 | return len(self.img_paths) 128 | 129 | 130 | class Wsss_test_dataset(Dataset): 131 | def __init__(self, args, test=False): 132 | super().__init__() 133 | self.dataset_root = args.dataset_root + "/" + "2.validation" 134 | self.input_size = args.input_size 135 | self.crop_size = args.crop_size 136 | self.mean, self.std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 137 | # "_patch_224_320" 138 | self.mask_path = self.dataset_root + "/" + "mask_patch_256" 139 | self.img_path = self.dataset_root + "/" + "img_patch_256" 140 | self.my_bg_path = self.dataset_root + "/" + "my_bg_mask_patch_256" 141 | self.img_names = os.listdir(self.img_path) 142 | # self.img_names = self.img_names[:31] 143 | np.random.seed(42) 144 | np.random.shuffle(self.img_names) 145 | # colormap 146 | self.classes = ["background", "Tumor", "Stroma", "Normal"] 147 | self.color_map = [[255, 255, 255], [0, 64, 128], [64, 128, 
0], [243, 152, 0]] 148 | if not test: 149 | self.img_names = self.img_names[:300] 150 | # self.img_names = self.filter_single_class_image() 151 | print(f'valid images:{len(self.img_names)}') 152 | else: 153 | self.img_names = self.img_names[300:] 154 | print(f'test images:{len(self.img_names)}') 155 | # transform 156 | self.img_trans = transforms.Compose( 157 | [ 158 | transforms.Resize((self.crop_size, self.crop_size)), 159 | transforms.ToTensor(), 160 | transforms.Normalize(self.mean, self.std), 161 | ] 162 | ) 163 | 164 | def __getitem__(self, index): 165 | 166 | img_name = self.img_names[index] 167 | 168 | img_path = self.img_path + "/" + img_name 169 | mask_path = self.mask_path + "/" + img_name 170 | my_bg_path = self.my_bg_path + "/" + img_name 171 | 172 | img = Image.open(img_path).convert("RGB") 173 | raw_img = np.array(img, dtype=np.uint8) 174 | mask = Image.open(mask_path).convert("RGB") 175 | 176 | my_bg_mask = Image.open(my_bg_path) 177 | my_bg_mask = np.array(my_bg_mask, dtype=np.uint8) 178 | 179 | # convert the background to 0 180 | # img = np.array(img,dtype=np.uint8) 181 | # img[my_bg_mask==255]=0 182 | # img = Image.fromarray(img) 183 | 184 | mask = np.array(mask).astype(np.uint8) 185 | mask = self.image2label(mask) 186 | label = self.get_label_from_img(mask) 187 | img = self.img_trans(img) 188 | # returns: image tensor, image-level label, background mask, pixel mask, raw image, name 189 | return img, label, my_bg_mask, mask, raw_img, img_name 190 | 191 | def image2label(self, im): 192 | color2int = np.zeros(256 ** 3)  # lookup table from packed RGB value to class index 193 | for idx, color in enumerate(self.color_map): 194 | color2int[(color[0] * 256 + color[1]) * 256 + color[2]] = idx 195 | data = np.array(im, dtype=np.int32) 196 | idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2] 197 | return np.array(color2int[idx], dtype=np.int32) 198 | 199 | def get_label_from_img(self, img_label): 200 | temp = np.unique(img_label) 201 | cls_label = np.zeros(3) 202 | 203 | 204 | if 1 in temp: 205 | cls_label[0] = 1 206 | if 2 in temp: 207 | cls_label[1] = 1 208 | if 3 in temp: 209 | cls_label[2] = 1 210 | return torch.tensor(cls_label, dtype=torch.int32) 211 | 212 | def filter_single_class_image(self): 213 | single_class_images = [] 214 | for img_name in self.img_names: 215 | mask_path = self.mask_path + "/" + img_name 216 | mask = Image.open(mask_path).convert("RGB") 217 | mask = np.array(mask).astype(np.uint8) 218 | mask = self.image2label(mask) 219 | label = self.get_label_from_img(mask) 220 | if sum(label) != 1: 221 | if label[2] == 1: 222 | single_class_images.append(img_name) 223 | return single_class_images 224 | 225 | def filter_multi_class_image(self): 226 | multi_class_images = [] 227 | for img_name in self.img_names: 228 | mask_path = self.mask_path + "/" + img_name 229 | mask = Image.open(mask_path).convert("RGB") 230 | mask = np.array(mask).astype(np.uint8) 231 | mask = self.image2label(mask) 232 | label = self.get_label_from_img(mask) 233 | if sum(label) != 1: 234 | multi_class_images.append(img_name) 235 | return multi_class_images 236 | 237 | def filter(self, img_paths): 238 | img_list_new = [] 239 | for img_name in img_paths: 240 | if "_" in img_name: 241 | img_list_new.append(img_name) 242 | 243 | return img_list_new 244 | 245 | 246 | def __len__(self): 247 | return len(self.img_names) 248 | 249 | 250 | if __name__ == "__main__": 251 | 252 | parser = argparse.ArgumentParser(description="Wsss pytorch implementation") 253 | parser.add_argument("--dataset_root", type=str, default="/mnt/data1/dataset/WSSS4LUAD/", help="dataset root") 254 | 
parser.add_argument("--batch_size", type=int, default=64, help="Train batch size") 255 | parser.add_argument("--num_classes", type=int, default=3, help="Train class num") 256 | parser.add_argument("--delta", type=float, default=0, help="set 0 for the...") 257 | # 0.01 for deep3 -->81.6; 0.005 for deep2; 0.005 for deep4; 0.001 for baseline 258 | parser.add_argument("--lr", type=float, default=0.01) 259 | parser.add_argument("--weight_decay", type=float, default=0.0005) 260 | parser.add_argument("--gpu", nargs="+", type=int) 261 | parser.add_argument("--train_epochs", default=30, type=int) 262 | parser.add_argument("--save_folder", default="checkpoints") 263 | parser.add_argument("--checkpoint", type=str, default="") 264 | 265 | # (280, 256), (256, 224), (324, 256), 266 | parser.add_argument("--input_size", type=int, default=256) 267 | parser.add_argument("--crop_size", type=int, default=224) 268 | parser.add_argument("--test_size", type=int, default=384) 269 | 270 | parser.add_argument('--image_size', default=256, type=int) 271 | parser.add_argument('--min_image_size', default=256, type=int) 272 | parser.add_argument('--max_image_size', default=512, type=int) 273 | 274 | args = parser.parse_args() 275 | 276 | train_set = Wsss_dataset(args) 277 | for item in train_set: 278 | break -------------------------------------------------------------------------------- /utils/Metrics.py: -------------------------------------------------------------------------------- 1 | from re import A 2 | import numpy as np 3 | from sklearn.metrics import confusion_matrix 4 | import torch 5 | 6 | def get_soft_label(input_tensor, num_class, data_type = 'float'): 7 | """ 8 | convert a label tensor to one-hot label 9 | input_tensor: tensor with shae [B, 1, D, H, W] or [B, 1, H, W] 10 | output_tensor: shape [B, num_class, D, H, W] or [B, num_class, H, W] 11 | """ 12 | 13 | shape = input_tensor.shape 14 | if len(shape) == 5: 15 | output_tensor = torch.nn.functional.one_hot(input_tensor[:, 0], num_classes = num_class).permute(0, 4, 1, 2, 3) 16 | elif len(shape) == 4: 17 | output_tensor = torch.nn.functional.one_hot(input_tensor[:, 0], num_classes = num_class).permute(0, 3, 1, 2) 18 | else: 19 | raise ValueError("dimention of data can only be 4 or 5: {0:}".format(len(shape))) 20 | 21 | if(data_type == 'float'): 22 | output_tensor = output_tensor.float() 23 | elif(data_type == 'double'): 24 | output_tensor = output_tensor.double() 25 | else: 26 | raise ValueError("data type can only be float and double: {0:}".format(data_type)) 27 | 28 | return output_tensor 29 | 30 | 31 | def reshape_prediction_and_ground_truth(predict, soft_y): 32 | """ 33 | reshape input variables of shape [B, C, D, H, W] to [voxel_n, C] 34 | """ 35 | tensor_dim = len(predict.size()) 36 | num_class = list(predict.size())[1] 37 | if(tensor_dim == 5): 38 | soft_y = soft_y.permute(0, 2, 3, 4, 1) 39 | predict = predict.permute(0, 2, 3, 4, 1) 40 | elif(tensor_dim == 4): 41 | soft_y = soft_y.permute(0, 2, 3, 1) 42 | predict = predict.permute(0, 2, 3, 1) 43 | else: 44 | raise ValueError("{0:}D tensor not supported".format(tensor_dim)) 45 | 46 | predict = torch.reshape(predict, (-1, num_class)) 47 | soft_y = torch.reshape(soft_y, (-1, num_class)) 48 | 49 | return predict, soft_y 50 | 51 | 52 | def get_classwise_dice(predict, soft_y, pix_w = None): 53 | """ 54 | get dice scores for each class in predict (after softmax) and soft_y 55 | """ 56 | 57 | if(pix_w is None): 58 | y_vol = torch.sum(soft_y, dim = 0) 59 | p_vol = torch.sum(predict, dim = 0) 60 | intersect 
= torch.sum(soft_y * predict, dim = 0) 61 | else: 62 | y_vol = torch.sum(soft_y * pix_w, dim = 0) 63 | p_vol = torch.sum(predict * pix_w, dim = 0) 64 | intersect = torch.sum(soft_y * predict * pix_w, dim = 0) 65 | dice_score = (2.0 * intersect + 1e-5) / (y_vol + p_vol + 1e-5) 66 | return dice_score 67 | 68 | 69 | class Cls_Accuracy(): 70 | def __init__(self): 71 | self.total = 0 72 | self.correct = 0 73 | 74 | 75 | def update(self, logit, label): 76 | 77 | logit = logit.sigmoid_() 78 | logit = (logit >= 0.5) 79 | all_correct = torch.all(logit == label.byte(), dim=1).float().sum().item() 80 | 81 | self.total += logit.size(0) 82 | self.correct += all_correct 83 | 84 | def compute_avg_acc(self): 85 | return self.correct / self.total 86 | 87 | 88 | 89 | class RunningConfusionMatrix(): 90 | """Running Confusion Matrix class that enables computation of the confusion matrix 91 | on the fly and has methods to compute accuracy metrics such as mean Intersection over 92 | Union (mIoU). 93 | 94 | Attributes 95 | ---------- 96 | labels : list[int] 97 | List that contains int values that represent classes. 98 | overall_confusion_matrix : sklearn confusion_matrix object 99 | Container of the sum of all confusion matrices. Used to compute mIoU at the end. 100 | ignore_label : int 101 | A label representing parts that should be ignored during 102 | computation of metrics 103 | 104 | """ 105 | 106 | def __init__(self, labels, ignore_label=255): 107 | 108 | self.labels = labels 109 | self.ignore_label = ignore_label 110 | self.overall_confusion_matrix = None 111 | 112 | def update_matrix(self, ground_truth, prediction): 113 | """Updates overall confusion matrix statistics. 114 | If you are working with 2D data, just .flatten() it before running this 115 | function. 116 | Parameters 117 | ---------- 118 | ground_truth : array, shape = [n_samples] 119 | An array with ground-truth values 120 | prediction : array, shape = [n_samples] 121 | An array with predictions 122 | """ 123 | 124 | # The mask-out value is ignored by default in sklearn; 125 | # read the sources to see how that is handled. 126 | # But sometimes all the elements in the ground truth can 127 | # be equal to the ignore value, which would crash 128 | # sklearn.metrics.confusion_matrix(), so we check for it here 129 | if (ground_truth == self.ignore_label).all(): 130 | 131 | return 132 | 133 | current_confusion_matrix = confusion_matrix(y_true=ground_truth, 134 | y_pred=prediction, 135 | labels=self.labels) 136 | 137 | if self.overall_confusion_matrix is not None: 138 | 139 | self.overall_confusion_matrix += current_confusion_matrix 140 | else: 141 | 142 | self.overall_confusion_matrix = current_confusion_matrix 143 | 144 | def compute_current_mean_intersection_over_union(self): 145 | 146 | intersection = np.diag(self.overall_confusion_matrix) 147 | ground_truth_set = self.overall_confusion_matrix.sum(axis=1) 148 | predicted_set = self.overall_confusion_matrix.sum(axis=0) 149 | union = ground_truth_set + predicted_set - intersection 150 | 151 | #intersection_over_union = intersection / (union.astype(np.float32) + 1e-4) 152 | intersection_over_union = intersection / union.astype(np.float32) 153 | 154 | mean_intersection_over_union = np.mean(intersection_over_union) 155 | 156 | return mean_intersection_over_union 157 | 158 | 159 | class IOUMetric: 160 | """ 161 | Class to calculate mean IoU using the fast_hist method 162 | """ 163 | 164 | def __init__(self, num_classes): 165 | self.num_classes = num_classes 166 | self.hist = np.zeros((num_classes, num_classes)) 
167 | 168 | def _fast_hist(self, label_pred, label_true): 169 | mask = (label_true >= 0) & (label_true < self.num_classes) 170 | 171 | hist = np.bincount( 172 | self.num_classes*label_true[mask] + label_pred[mask], 173 | minlength=self.num_classes ** 2).reshape(self.num_classes, self.num_classes) 174 | 175 | return hist 176 | 177 | def add_batch(self, predictions, gts): 178 | for lp, lt in zip(predictions, gts): 179 | self.hist += self._fast_hist(lp.flatten(), lt.flatten()) 180 | 181 | def evaluate(self): 182 | acc = np.diag(self.hist).sum() / self.hist.sum() 183 | acc_cls = np.diag(self.hist) / self.hist.sum(axis=1) 184 | acc_cls = np.nanmean(acc_cls) 185 | iu = np.diag(self.hist) / (self.hist.sum(axis=1) + self.hist.sum(axis=0) - np.diag(self.hist)) 186 | mean_iu = np.nanmean(iu) 187 | freq = self.hist.sum(axis=1) / self.hist.sum() 188 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 189 | cls_iu = dict(zip(range(self.num_classes), iu)) 190 | 191 | return { 192 | "Pixel_Accuracy": acc, 193 | "Mean_Accuracy": acc_cls, 194 | "Frequency_Weighted_IoU": fwavacc, 195 | "Mean_IoU": mean_iu, 196 | "Class_IoU": cls_iu, 197 | } 198 | 199 | 200 | class DiceMetric: 201 | def __init__(self, num_classes): 202 | self.num_classes = num_classes 203 | self.train_dice_list = [] 204 | 205 | def add_batch(self, pred, gt, my_bg_mask, label_cpu): 206 | # single-label image: the foreground class is known, so take the prediction directly from the background mask 207 | if sum(label_cpu[0]) == 1: 208 | label_cls = list(label_cpu[0]).index(1) + 1 209 | my_bg_mask[my_bg_mask == 0] = label_cls 210 | my_bg_mask[my_bg_mask == 255] = 0 211 | outputs_argmax = np.expand_dims(np.expand_dims(my_bg_mask, 0), 0) 212 | outputs_argmax = torch.tensor(outputs_argmax, dtype=torch.int64) 213 | 214 | else: 215 | pred_seg = torch.argmax(pred, dim=1) 216 | pred_seg = pred_seg[0] + 1 217 | pred_seg = pred_seg.cpu().numpy() 218 | pred_seg[my_bg_mask == 255] = 0 219 | outputs_argmax = np.expand_dims((np.expand_dims(pred_seg, 0)), 0) 220 | outputs_argmax = torch.tensor(outputs_argmax).long() 221 | 222 | soft_out = get_soft_label(outputs_argmax, 4) 223 | labels_prob = torch.unsqueeze(gt, 1).long() 224 | labels_prob = get_soft_label(labels_prob, 4) 225 | soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob) 226 | dice_list = get_classwise_dice(soft_out, labels_prob) 227 | self.train_dice_list.append(dice_list.cpu().numpy()) 228 | 229 | def compute_dice(self, verbose=False, save=False): 230 | train_dice_list = np.asarray(self.train_dice_list)*100 231 | train_dice_list = train_dice_list[:,1:]  # drop the background class 232 | if save: 233 | np.savetxt(save, train_dice_list, delimiter=",") 234 | # print(train_dice_list) 235 | train_cls_dice = train_dice_list.mean(axis = 0) 236 | train_avg_dice = train_dice_list.mean(axis = 1) 237 | train_std_dice = train_avg_dice.std() 238 | train_scalars = {'avg_dice': train_avg_dice.mean(), 'class_dice': train_cls_dice, 'std_dice': train_std_dice} 239 | 240 | if verbose: 241 | print("%.2f"%train_cls_dice[0],"%.2f"%train_cls_dice[1],"%.2f"%train_cls_dice[2],"%.2f"%train_cls_dice.mean()) 242 | print("%.2f"%train_dice_list[:,0].std(),"%.2f"%train_dice_list[:,1].std(),"%.2f"%train_dice_list[:,2].std(),"%.2f"%train_dice_list.std(0).mean()) 243 | else: 244 | print("%.2f"%train_scalars['avg_dice']) 245 | return train_scalars['avg_dice'] 246 | 247 | def compute_dice_exist(self, verbose=False): 248 | train_dice_list = np.asarray(self.train_dice_list)*100 249 | train_dice_list = train_dice_list[:,1:] 250 | dice_cls = np.zeros((train_dice_list.shape[1],)) 251 | std_cls = np.zeros((train_dice_list.shape[1],)) 252 | for i in 
range(train_dice_list.shape[1]): 253 | i_cls_dice_list = train_dice_list[:,i].copy() 254 | i_cls_dice_list = i_cls_dice_list[i_cls_dice_list != 100]  # drop cases where the class is absent from both prediction and ground truth 255 | dice_cls[i] = i_cls_dice_list.mean() 256 | std_cls[i] = i_cls_dice_list.std() 257 | if verbose: 258 | print("%.2f"%dice_cls[0],"%.2f"%dice_cls[1],"%.2f"%dice_cls[2],"%.2f"%dice_cls.mean()) 259 | print("%.2f"%std_cls[0],"%.2f"%std_cls[1],"%.2f"%std_cls[2],"%.2f"%std_cls.mean()) 260 | else: 261 | print("%.2f"%dice_cls.mean()) 262 | return dice_cls.mean() -------------------------------------------------------------------------------- /utils/generate_bg_masks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | import scipy 5 | from utils.CRF import MyCRF 6 | import re 7 | import cv2 8 | from skimage import morphology 9 | 10 | 11 | def gamma_transform(img_path): 12 | image = Image.open(img_path).convert("L") 13 | image = np.array(image, dtype=np.float32) 14 | image /= 255 15 | gamma = 2.4 16 | out = np.power(image, gamma) 17 | out *= 255 18 | out = out.astype(np.uint8) 19 | 20 | return out 21 | 22 | def seg_to_color(seg): 23 | H, W = seg.shape[0], seg.shape[1] 24 | # white, green, blue, yellow 25 | classes = ["background", "Tumor", "Stroma", "Normal"] 26 | color_map = [[255, 255, 255], [0, 64, 128], [64, 128, 0], [243, 152, 0]] 27 | img = np.zeros((H, W, 3)) 28 | for i in range(H): 29 | for j in range(W): 30 | img[i, j, :] = color_map[seg[i, j]] 31 | return img 32 | 33 | def open_with_crf(img_path, open_img_path): 34 | img = Image.open(img_path).convert("RGB") 35 | img_array = np.array(img, dtype=np.float32) 36 | open_img = Image.open(open_img_path) 37 | open_img = np.array(open_img, dtype=np.float32) 38 | H, W = img_array.shape[0], img_array.shape[1] 39 | mycrf = MyCRF() 40 | p = 1.0 41 | background = open_img.copy() 42 | background /= np.max(background) 43 | foreground = (1 - background) * p 44 | probability_map = np.concatenate( 45 | (foreground.reshape((1, H, W)), background.reshape((1, H, W))), axis=0 46 | ) 47 | out = mycrf.inference(img_array, probability_map) 48 | out = out.argmax(0) 49 | return (out * 255).astype(np.uint8) 50 | 51 | def get_label(img_name): 52 | res = re.findall(r"\[(.*?)\]", img_name) 53 | label = np.array(list(eval(res[0])), dtype=np.uint8) 54 | return label 55 | 56 | 57 | if __name__ == "__main__": 58 | dataset_root = "/mnt/data1/dataset/WSSS4LUAD/1.training/" 59 | gamma_dir = "gamma_transform_valid_patch_v2/" 60 | gamma_crf_dir = "gamma_crf_valid_patch_v2/" 61 | img_paths = os.listdir(dataset_root) 62 | # clean: keep only image files 63 | img_paths = [name for name in img_paths if name.endswith(".png")] 64 | print(f"all {len(img_paths)} images") 65 | 66 | i = 0 67 | count = [0, 0, 0] 68 | 69 | for img_name in img_paths: 70 | # save the gamma-transformed image and its CRF-refined background 71 | img_path = dataset_root + img_name 72 | img_gamma = gamma_transform(img_path) 73 | img_gamma = Image.fromarray(img_gamma) 74 | gamma_path = gamma_dir + img_name 75 | img_gamma.save(gamma_path) 76 | open_crf = open_with_crf(img_path, gamma_path) 77 | img_open_crf = Image.fromarray(open_crf) 78 | img_open_crf.save(gamma_crf_dir + img_name) 79 | out = Image.open(gamma_crf_dir + img_name)  # reload the saved CRF background mask 80 | out = np.array(out).astype(np.uint8) 81 | if len(np.unique(out)) == 1:  # uniform mask: nothing to separate, drop the patch 82 | os.remove(img_path) 83 | out_remove = np.array(out, dtype=bool) 84 | out_remove = morphology.remove_small_holes(out_remove, area_threshold=32) 85 | out_remove = Image.fromarray(out_remove.astype(np.uint8) * 255) 86 | out_remove.save("gamma_crf_train/" + img_name) 87 | 88 | print("done!") 89 | 90 | 
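The CRF step above turns a soft background estimate into a hard mask: `open_with_crf` stacks a foreground/background probability pair into a (2, H, W) map and lets `MyCRF.inference` refine it against the RGB image. A minimal standalone sketch of that call, with a hypothetical input path (`patch.png`) and a flat 0.5 prior standing in for the gamma-transformed mask:

```
# Sketch only: exercises MyCRF (utils/CRF.py) the way open_with_crf does.
# "patch.png" and the uniform prior are placeholders, not repo defaults.
import numpy as np
from PIL import Image
from utils.CRF import MyCRF

img = np.array(Image.open("patch.png").convert("RGB"), dtype=np.float32)
H, W = img.shape[:2]
bg_prob = np.full((H, W), 0.5, dtype=np.float32)   # stand-in background probability
prob = np.stack([1.0 - bg_prob, bg_prob], axis=0)  # (2, H, W): foreground first
refined = MyCRF().inference(img, prob).argmax(0)   # 0 = foreground, 1 = background
mask = (refined * 255).astype(np.uint8)            # 255 marks background, as above
```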
-------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.getcwd()) 4 | from PIL import Image 5 | import torch 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | from torchvision import transforms 10 | import torchvision.transforms.functional as transF 11 | 12 | def bg2mask(bg_path, cls_index, H, W): 13 | """ 14 | input: single image background, [H,W], [0,255] 15 | output: mask, [H,W] 16 | """ 17 | # 0 for foreground, 255 for background 18 | bg = np.array(Image.open(bg_path), dtype=np.uint8) 19 | bg[bg == 0] = cls_index 20 | bg[bg == 255] = 0 21 | bg = torch.tensor(bg, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # float: F.interpolate does not accept uint8 22 | bg = F.interpolate(bg, (H, W), mode='nearest') 23 | return bg.squeeze(0).squeeze(0).to(torch.uint8) 24 | 25 | 26 | def monte_augmentation(n, model, img_path, H, W, size=256): 27 | """ 28 | n: number of augmented forward passes to average; size: base image size before random rescaling 29 | """ 30 | cam = torch.zeros((1, 3, H, W)).cuda() 31 | img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 32 | for _ in range(n): 33 | scale_factor = 1 + np.random.uniform(low=-0.25, high=0.25) 34 | H_scale, W_scale = int(scale_factor*size), int(scale_factor*size) 35 | p_vflip = np.random.rand() 36 | p_hflip = np.random.rand() 37 | vflip_flag = p_vflip > 0.5 38 | hflip_flag = p_hflip > 0.5 39 | rotation_degree = int(np.random.choice([0, 90, 180, 270])) 40 | 41 | img_trans = transforms.Compose([transforms.Resize((H_scale, W_scale)), transforms.ToTensor(), transforms.Normalize(img_mean, img_std)]) 42 | _img = Image.open(img_path).convert("RGB") 43 | _img = img_trans(_img) 44 | _img = torch.unsqueeze(_img, 0) 45 | _img = _img.cuda() 46 | 47 | if vflip_flag: 48 | _img = transF.vflip(_img) 49 | 50 | if hflip_flag: 51 | _img = transF.hflip(_img) 52 | 53 | if rotation_degree: 54 | # print(rotation_degree) 55 | _img = transF.rotate(_img, rotation_degree) 56 | 57 | _, _, _, cam_b6, cam_b5, cam_b4 = model(_img, True, (H, W)) 58 | 59 | # undo the augmentations before accumulating the activation maps 60 | if rotation_degree: 61 | cur_degree = 360 - rotation_degree 62 | cam_b6, cam_b5, cam_b4 = transF.rotate(cam_b6, cur_degree), transF.rotate(cam_b5, cur_degree), transF.rotate(cam_b4, cur_degree) 63 | 64 | if hflip_flag: 65 | cam_b6, cam_b5, cam_b4 = transF.hflip(cam_b6), transF.hflip(cam_b5), transF.hflip(cam_b4) 66 | 67 | if vflip_flag: 68 | cam_b6, cam_b5, cam_b4 = transF.vflip(cam_b6), transF.vflip(cam_b5), transF.vflip(cam_b4) 69 | 70 | cam += (cam_b6 + cam_b5 + cam_b4) / 3 71 | 72 | return cam / n 73 | 74 | if __name__ == "__main__": 75 | pass  # a smoke test needs a trained model and an image path; see validate() in train_seg.py --------------------------------------------------------------------------------
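The tri-branch objective in train_seg.py combines the selection-weighted cross-entropy with a temperature-scaled KD term against the average of the two peer branches. It can be exercised without any data; a minimal smoke test, assuming train_seg.py is importable from the repo root (its argument parsing only runs under `__main__`) and using made-up shapes (batch 2, 4 classes, 8x8 logits):

```
# Smoke test of joint_optimization from train_seg.py on random tensors.
import torch
from train_seg import joint_optimization

torch.manual_seed(0)
out_main = torch.randn(2, 4, 8, 8, requires_grad=True)  # logits of the branch being updated
out_aux1 = torch.randn(2, 4, 8, 8)                      # peer branches, detached in training
out_aux2 = torch.randn(2, 4, 8, 8)
mask = torch.randint(0, 4, (2, 8, 8))                   # pseudo ground-truth labels

loss = joint_optimization(out_main, out_aux1, out_aux2, mask, kd_weight=0.2, kd_T=30)
loss.backward()                                         # gradients flow only into out_main
print(loss.item())
```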
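The class-wise Dice in utils/Metrics.py reduces to 2 * |intersection| / (|prediction| + |ground truth|) per class, with 1e-5 smoothing in numerator and denominator. A tiny worked example on hand-built one-hot tensors (shapes and values made up for illustration):

```
# Worked example of get_classwise_dice (utils/Metrics.py):
# 4 pixels, 2 classes, inputs already reshaped to [voxel_n, C].
import torch
from utils.Metrics import get_classwise_dice

pred = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])
gt   = torch.tensor([[1., 0.], [0., 1.], [0., 1.], [0., 1.]])
print(get_classwise_dice(pred, gt))
# class 0: 2*1/(2+1) = 0.667, class 1: 2*2/(2+3) = 0.8 (up to the 1e-5 smoothing)
```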