├── README.md ├── T-SC-STICL ├── README.md ├── finetune.py ├── moco │ ├── __init__.py │ ├── builder.py │ └── loader.py └── pretrain.py ├── T-SS-GLCNet ├── config │ ├── __init__.py │ ├── get_opt.py │ └── opt_road.py ├── data_example │ └── Potsdam │ │ ├── class_dict.txt │ │ ├── trainR1_RGBIR.txt │ │ ├── trainR1_lbl.txt │ │ ├── train_RGBIR.txt │ │ ├── train_RGBIR │ │ └── top_potsdam_2_10_RGBIR__36.tif │ │ ├── train_lbl.txt │ │ ├── train_lbl │ │ └── top_potsdam_2_10_RGBIR__36.tif │ │ ├── valR1500_RGBIR.txt │ │ ├── valR1500_lbl.txt │ │ ├── val_RGBIR.txt │ │ ├── val_RGBIR │ │ └── top_potsdam_2_13_RGBIR__26.tif │ │ ├── val_lbl.txt │ │ └── val_lbl │ │ └── top_potsdam_2_13_RGBIR__26.tif ├── dl_tools │ └── basictools │ │ ├── dlbasic.py │ │ ├── dldata.py │ │ ├── dlimage.py │ │ ├── dltrain.py │ │ └── fileop.py ├── evalute.py ├── main_ss.py ├── models │ ├── GLCNet.py │ ├── GLCNet_nodecoder.py │ ├── __init__.py │ ├── deeplab_utils │ │ ├── ResNet.py │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── sync_batchnorm │ │ │ ├── __init__.py │ │ │ ├── batchnorm.py │ │ │ ├── batchnorm_reimpl.py │ │ │ ├── comm.py │ │ │ ├── replicate.py │ │ │ └── unittest.py │ ├── v3p.py │ ├── v3p_SimCLR_encoder.py │ ├── v3p_decoder12_ft.py │ ├── v3p_encoder_ft.py │ ├── v3p_inpaiting_encoder.py │ ├── v3p_jigsaw_encoder.py │ ├── v3p_mocov2_encoder_true.py │ └── v3p_resnet_ft.py ├── permutations_1000.npy ├── readme ├── readme.md ├── utils │ ├── __init__.py │ ├── contrast_loss.py │ ├── data_for_seg.py │ ├── data_for_self_GLCNet.py │ ├── data_for_self_Jigsaw.py │ ├── data_for_self_contrast.py │ ├── data_for_self_inpaitting.py │ ├── loss.py │ ├── nt_xent.py │ ├── tools.py │ └── util.py └── utils_SS_pretexts.py ├── TOV_v1 ├── .gitignore ├── README.md ├── Readme ├── TOV_models │ └── Place the TOV pre-training model in this floder ├── classification │ ├── config │ │ ├── __init__.py │ │ ├── category.py │ │ └── default │ │ │ ├── __init__.py │ │ │ ├── opt_AID.py │ │ │ ├── opt_EuroSAT_RGB.py │ │ │ ├── opt_NR.py │ │ │ ├── opt_PatternNet.py │ │ │ ├── opt_RSD46.py │ │ │ ├── opt_TianGong2_RGB.py │ │ │ └── opt_UCMerced.py │ ├── dataset │ │ ├── __init__.py │ │ ├── classfy_data.py │ │ └── data_interface.py │ ├── main_cls.py │ ├── models │ │ ├── __init__.py │ │ ├── basic_model.py │ │ └── cls_net.py │ ├── tov_finetune_cls.bash │ └── utils │ │ ├── __init__.py │ │ ├── mycallbacks.py │ │ └── tools.py └── segmentation │ ├── config │ ├── __init__.py │ ├── category.py │ └── default │ │ ├── __init__.py │ │ ├── opt_CVPR_LandCover.py │ │ ├── opt_DLRSD.py │ │ └── opt_ISPRS_Postdam.py │ ├── dataset │ ├── __init__.py │ ├── data_interface.py │ ├── seg_data.py │ └── transforms.py │ ├── main_seg.py │ └── models │ ├── __init__.py │ ├── basic_model.py │ └── fcn.py ├── big_picture.png └── big_picture.svg /T-SC-STICL/README.md: -------------------------------------------------------------------------------- 1 | ## STICL: Spatial-temporal Invariant Contrastive Learning for Remote Sensing Scene Classification 2 | 3 | This is a PyTorch implementation of the [STICL](https://ieeexplore.ieee.org/document/9770815): 4 | ``` 5 | @ARTICLE{huang2022sticl, 6 | author={Huang, Haozhe and Mou, Zhongfeng and Li, Yunying and Li, Qiujun and Chen, Jie and Li, Haifeng}, 7 | journal={IEEE Geoscience and Remote Sensing Letters}, 8 | title={Spatial-Temporal Invariant Contrastive Learning for Remote Sensing Scene Classification}, 9 | year={2022}, 10 | volume={19}, 11 | number={}, 12 | pages={1-5}, 13 | doi={10.1109/LGRS.2022.3173419}} 14 | ``` 15 | 16 | ### Details 17 | This 
version of STICL is implemented on top of [MoCo v2](https://github.com/facebookresearch/moco); spatial-temporal invariant contrastive learning can also be plugged into other self-supervised learning methods. The STICL framework consists of two main parts: a pre-training phase and a fine-tuning phase.

#### **Pretraining**

For example, to pre-train a ResNet-50 model on unlabeled RSIs, run:
```
python pretrain.py \
-a resnet50 \
--lr 0.03 \
--batch-size 256 \
--epochs 300 \
--mlp \
--moco-t 0.2 \
--save_model 'mocov2_mix_bs256_300e_sti_10p.pth.tar' \
--cos \
--dist-url 'tcp://localhost:10000' --multiprocessing-distributed --world-size 1 --rank 0 \
--st_prob 0.1 \
--data [your unlabeled image folder with train and val folders]

```
`--st_prob` (between 0 and 1) sets the probability of applying spatial-temporal feature transfer, i.e., the strength of the augmentation.

#### **Finetuning**

Given a pre-trained model, to fine-tune on labeled RSIs, run:
```
python finetune.py \
-a resnet50 \
--lr 0.01 \
--batch-size 256 \
--method 'moco' \
--pretrained [your pre-trained model] \
--num_samples 20 \
--nclass 30 \
--data [your labeled image folder with train and val folders]
```
`--num_samples` sets how many samples per class are used for fine-tuning.
--------------------------------------------------------------------------------
/T-SC-STICL/moco/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
--------------------------------------------------------------------------------
/T-SC-STICL/moco/builder.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn as nn


class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        """
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not updated by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits, targets
        """

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
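The forward pass returns `logits` of shape Nx(1+K) with the single positive at column 0, so pretraining reduces to ordinary cross-entropy over those logits. A minimal driving sketch (hypothetical, not taken from this repo's pretrain.py; it assumes a CUDA machine with a `torch.distributed` process group already initialized, since the shuffle/gather helpers above require DDP):
```
import torch
import torchvision.models as tvmodels
from moco.builder import MoCo  # assumes you run from the T-SC-STICL directory

model = MoCo(tvmodels.resnet50, dim=128, K=65536, m=0.999, T=0.2, mlp=True).cuda()
criterion = torch.nn.CrossEntropyLoss().cuda()

im_q = torch.randn(16, 3, 224, 224).cuda()  # query views
im_k = torch.randn(16, 3, 224, 224).cuda()  # key views (possibly style-transferred)
logits, labels = model(im_q, im_k)          # labels are all zeros: positive sits at column 0
loss = criterion(logits, labels)
loss.backward()
```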
--------------------------------------------------------------------------------
/T-SC-STICL/moco/loader.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from PIL import ImageFilter
import random
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import ot
from torchvision import transforms
from torchvision.transforms.transforms import Resize

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

class TwoCropsTransform_sti:
    """Build a query/key pair from one image: the query is an augmented view of x1;
    with probability `prob`, the key is first re-styled towards the style image x2."""

    def __init__(self, base_transform, prob):
        self.base_transform = base_transform
        self.prob = prob

    def __call__(self, x1, x2):

        q = self.base_transform(x1)

        # Spatial-temporal feature transfer: with probability `prob`, transfer
        # the colour/style statistics of x2 onto x1 before augmenting the key view.
        is_sttrans = np.random.rand()
        if is_sttrans < self.prob:
            xt = ST_transf(x1, x2)
            k = self.base_transform(xt)
        else:
            k = self.base_transform(x1)

        return [q, k]


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x

def ST_transf(content_img, style_img):
    """Transfer the colour distribution of style_img onto content_img via
    optimal transport (EMD), keeping the spatial layout of content_img."""

    adjust_image = transforms.Compose([
        transforms.Resize([256, 256]),
    ])
    rand_seed = np.random.RandomState(42)

    content_img = adjust_image(content_img)
    style_img = adjust_image(style_img)

    I1 = np.array(content_img).astype(np.float64) / 256
    I2 = np.array(style_img).astype(np.float64) / 256
    X1 = I1.reshape((I1.shape[0] * I1.shape[1], I1.shape[2]))
    X2 = I2.reshape((I2.shape[0] * I2.shape[1], I2.shape[2]))
    # Sub-sample 1000 pixels from each image to keep the transport problem small
    idx1 = rand_seed.randint(X1.shape[0], size=(1000,))
    idx2 = rand_seed.randint(X2.shape[0], size=(1000,))
    Xs = X1[idx1, :]
    Xt = X2[idx2, :]

    # Fit an earth-mover's-distance transport plan between the two pixel clouds,
    # then map every content pixel through it
    ot_emd = ot.da.EMDTransport()
    ot_emd.fit(Xs=Xs, Xt=Xt)
    trans_Xs_emd = ot_emd.transform(Xs=X1, batch_size=1024)
    Image_emd = np.clip(trans_Xs_emd.reshape(I1.shape), 0, 1)
    Image_emd = Image_emd * 255
    Image_aug = Image.fromarray(Image_emd.astype('uint8'))

    return Image_aug

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def find_classes(dir):
    classes = os.listdir(dir)
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx, num_sample=0):
    images = []
    if num_sample > 0:
        for target in os.listdir(dir):
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue
            count = 0
            for filename in os.listdir(d):
                if is_image_file(filename) and count < num_sample:
                    path = '{0}/{1}'.format(target, filename)
                    item = (path, class_to_idx[target])
                    images.append(item)
                    count = count + 1
    else:
        for target in os.listdir(dir):
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue

            for filename in os.listdir(d):
                if is_image_file(filename):
                    path = '{0}/{1}'.format(target, filename)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images

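# --- Hypothetical usage sketch (not part of the original loader.py) ---
# Applies the optimal-transport colour transfer above to two local images;
# the file names below are placeholders.
def _demo_st_transf(content_path='content.jpg', style_path='style.jpg'):
    content = Image.open(content_path).convert('RGB')
    style = Image.open(style_path).convert('RGB')
    transferred = ST_transf(content, style)  # layout of `content`, colours of `style`
    transferred.save('st_transferred.jpg')
    return transferred
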
def default_loader(path):
    return Image.open(path).convert('RGB')

class ImageFolderLoader(data.Dataset):
    def __init__(self, root, transform=None,
                 target_transform=None,
                 loader=default_loader, num_sample=0, is_pretrain=False):

        if num_sample < 0:
            raise Exception('Error: num_sample must not be less than 0')

        if num_sample > 0:
            if is_pretrain:
                raise Exception('Error: pretrain mode does not support selecting a number of samples')
            else:
                classes, class_to_idx = find_classes(root)
                imgs = make_dataset(root, class_to_idx, num_sample)
        else:  # num_sample is 0
            classes, class_to_idx = find_classes(root)
            imgs = make_dataset(root, class_to_idx)

        self.root = root
        self.is_pretrain = is_pretrain
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):

        if self.is_pretrain:

            path, target = self.imgs[index]
            max_index = len(self.imgs)

            # Randomly pick another image to serve as the style sample
            index_aug = np.random.randint(0, max_index)
            path_aug, target_aug = self.imgs[index_aug]

            img = self.loader(os.path.join(self.root, path))
            img_aug = self.loader(os.path.join(self.root, path_aug))

            if self.transform is not None:
                img = self.transform(img, img_aug)
            if self.target_transform is not None:
                target = self.target_transform(target)
        else:

            path, target = self.imgs[index]
            img = self.loader(os.path.join(self.root, path))

            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)

        return img, target

class TwoCropsTransform:
    """Take two random crops of one image as the query and key.
    The second argument is ignored; it is kept for interface parity
    with TwoCropsTransform_sti."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x1, x2):

        q = self.base_transform(x1)
        k = self.base_transform(x1)
        return [q, k]
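
# --- Hypothetical wiring sketch (not part of the original loader.py) ---
# How pretrain.py presumably combines these pieces: a MoCo v2-style augmentation,
# the STICL two-crop transform, and the pretraining dataset. The data path and
# st_prob value are placeholders.
def _build_pretrain_dataset(train_dir='path/to/unlabeled/train', st_prob=0.1):
    augmentation = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # is_pretrain=True makes __getitem__ draw a random style image and call
    # transform(img, img_aug); st_prob mirrors the --st_prob flag in the README.
    return ImageFolderLoader(train_dir,
                             transform=TwoCropsTransform_sti(augmentation, prob=st_prob),
                             is_pretrain=True)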
--------------------------------------------------------------------------------
/T-SS-GLCNet/config/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/T-SS-GLCNet/config/get_opt.py:
--------------------------------------------------------------------------------
import platform
import os
import time
# import config as config_lib

TYPE2BAND = {'RGB': 3, 'RGBIR': 4, 'SAR': 1, 'TEN': 10, 'ALL': 12, 'MS': 13}  # 'ALL' for bigearth data; 'MS' for EuroSAT

from pathlib import Path

def unify_type(param, ptype=list, repeat=1):
    ''' Unify the type of param.

    Args:
        ptype: supports list or tuple
        repeat: the number of times param is repeated in the resulting list or tuple.
    '''
    if repeat == 1:
        if type(param) is not ptype:
            if ptype == list:
                param = [param]
            elif ptype == tuple:
                param = (param,)  # fixed: (param) is not a tuple, (param,) is
    elif repeat > 1:
        if type(param) is ptype and len(param) == repeat:
            return param
        elif type(param) is list:
            param = param * repeat
        else:
            param = [param] * repeat
            param = ptype(param)

    return param

def get_config(args=None):
    from .opt_road import DefaultConfig

    if args is not None:
        # note: DefaultConfig takes no constructor arguments; this branch is
        # unused by get_opt(), which always calls get_config() without args
        mOptions = DefaultConfig(args)
    else:
        mOptions = DefaultConfig()

    return mOptions
def get_opt(args=None):
    '''Get options by name, and may use args to update them.'''

    opts = get_config()
    if args is not None:
        # Extra

        # Use the ArgumentParser object to update the default configs
        for k, v in args.__dict__.items():
            if v is not None or not hasattr(opts, k):
                setattr(opts, k, v)

    if opts.n_channels is None:
        opts.n_channels = TYPE2BAND[opts.dtype]
    if opts.dataset_dir is None:
        opts.dataset_dir = opts.root
    if opts.class_dict_road is None:
        opts.class_dict_road = opts.dataset_dir + '/class_dict.txt'
    if opts.ckpt is None:
        opts.ckpt = opts.root + '/Model'
        Path(opts.ckpt).mkdir(parents=True, exist_ok=True)
    if opts.outputimage is None:
        opts.outputimage = opts.root + '/OutImage'
        Path(opts.outputimage).mkdir(parents=True, exist_ok=True)

    return opts



--------------------------------------------------------------------------------
/T-SS-GLCNet/config/opt_road.py:
--------------------------------------------------------------------------------
# coding:utf-8
'''
Default config for the cvpr_road dataset.
'''


class DefaultConfig(object):
    env = 'result'  # visdom environment
    root = None
    dataset_dir = None
    ckpt = None
    outputimage = None
    benchmark = True
    deterministic = False
    enabled = True
    arch = 'resnet50'
    non_blocking = True
    pin_memory = True
    quick_train = True
    n_channels = None

    nodes = 1
    ngpus_per_node = 1
    world_size = 1
    gpu = None
    # local_rank = -1

    pross_num = 16  # pross_num for GLCNet


    continue_training = False

    # Optimization related arguments
    use_gpu = True  # if use GPU
    ckpt_freq = 40
    save_log = True

    lr_decay = 0.98  # per epoch
    weight_decay = 1e-5  # L2 loss
    loss = ['MSELoss', 'BCEWithLogitsLoss', 'CrossEntropyLoss', 'FocalLoss', 'SegmentationMultiLosses']

    dtype = 'RGB'  # default
    bl_dtype = ['']
    # Data related arguments
    num_workers = 8  # number of data loading workers
    scnn_size = [32, 32]  # 56-448
    input_size = (256, 256)  # final input size of the network (random-crop uses this)
    crop_params = [256, 256, 128]  # [H, W, stride], only used by slide-crop
    crop_mode = 'slide'  # crop mode of val data, one of [random, slide]
    ont_hot = False  # whether the output of data_loader is one_hot type
    class_num = None
    print_freq = 20  # print info every N batches


    small_sample = False
    class_dict_road = None
    environ = '0'  # select gpu
    device_id = [0]  # select gpu
    batch_size = 16  # batch size
    start_epoch = 0  # used for continue_training
    cur_epoch = 0


opt = DefaultConfig()
--------------------------------------------------------------------------------
/T-SS-GLCNet/data_example/Potsdam/class_dict.txt:
-------------------------------------------------------------------------------- 1 | Class Name,B,G,R 2 | 0,0,0,0 3 | 低矮植被,0,255,255 4 | 其它,255,255,255 5 | 树木,0,255,0 6 | 建筑物,255,0,0 7 | 不透水面,0,0,255 8 | 汽车,255,255,0 -------------------------------------------------------------------------------- /T-SS-GLCNet/data_example/Potsdam/trainR1_RGBIR.txt: -------------------------------------------------------------------------------- 1 | top_potsdam_7_8_RGBIR__519.tif 2 | top_potsdam_4_11_RGBIR__51.tif 3 | top_potsdam_6_7_RGBIR__285.tif 4 | top_potsdam_5_12_RGBIR__182.tif 5 | top_potsdam_2_11_RGBIR__546.tif 6 | top_potsdam_4_11_RGBIR__472.tif 7 | top_potsdam_6_10_RGBIR__188.tif 8 | top_potsdam_3_10_RGBIR__255.tif 9 | top_potsdam_7_11_RGBIR__335.tif 10 | top_potsdam_2_12_RGBIR__543.tif 11 | top_potsdam_2_10_RGBIR__465.tif 12 | top_potsdam_3_10_RGBIR__553.tif 13 | top_potsdam_3_10_RGBIR__239.tif 14 | top_potsdam_6_12_RGBIR__38.tif 15 | top_potsdam_6_12_RGBIR__333.tif 16 | top_potsdam_2_12_RGBIR__281.tif 17 | top_potsdam_6_12_RGBIR__454.tif 18 | top_potsdam_7_11_RGBIR__514.tif 19 | top_potsdam_7_7_RGBIR__159.tif 20 | top_potsdam_7_12_RGBIR__366.tif 21 | top_potsdam_3_12_RGBIR__127.tif 22 | top_potsdam_6_10_RGBIR__207.tif 23 | top_potsdam_2_12_RGBIR__475.tif 24 | top_potsdam_6_7_RGBIR__218.tif 25 | top_potsdam_6_9_RGBIR__555.tif 26 | top_potsdam_6_9_RGBIR__347.tif 27 | top_potsdam_4_11_RGBIR__136.tif 28 | top_potsdam_6_11_RGBIR__107.tif 29 | top_potsdam_2_11_RGBIR__430.tif 30 | top_potsdam_6_12_RGBIR__537.tif 31 | top_potsdam_2_12_RGBIR__491.tif 32 | top_potsdam_6_12_RGBIR__243.tif 33 | top_potsdam_5_12_RGBIR__112.tif 34 | top_potsdam_2_11_RGBIR__345.tif 35 | top_potsdam_3_11_RGBIR__362.tif 36 | top_potsdam_7_9_RGBIR__164.tif 37 | top_potsdam_5_11_RGBIR__514.tif 38 | top_potsdam_4_12_RGBIR__561.tif 39 | top_potsdam_2_11_RGBIR__225.tif 40 | top_potsdam_3_10_RGBIR__5.tif 41 | top_potsdam_7_8_RGBIR__549.tif 42 | top_potsdam_3_10_RGBIR__386.tif 43 | top_potsdam_5_12_RGBIR__178.tif 44 | top_potsdam_2_12_RGBIR__548.tif 45 | top_potsdam_3_10_RGBIR__541.tif 46 | top_potsdam_5_11_RGBIR__118.tif 47 | top_potsdam_6_11_RGBIR__22.tif 48 | top_potsdam_6_8_RGBIR__135.tif 49 | top_potsdam_5_11_RGBIR__304.tif 50 | top_potsdam_6_11_RGBIR__222.tif 51 | top_potsdam_3_11_RGBIR__58.tif 52 | top_potsdam_6_11_RGBIR__289.tif 53 | top_potsdam_6_10_RGBIR__478.tif 54 | top_potsdam_3_12_RGBIR__308.tif 55 | top_potsdam_3_12_RGBIR__438.tif 56 | top_potsdam_6_10_RGBIR__20.tif 57 | top_potsdam_3_11_RGBIR__137.tif 58 | top_potsdam_7_11_RGBIR__228.tif 59 | top_potsdam_4_11_RGBIR__153.tif 60 | top_potsdam_7_7_RGBIR__567.tif 61 | top_potsdam_7_11_RGBIR__272.tif 62 | top_potsdam_7_10_RGBIR__486.tif 63 | top_potsdam_3_10_RGBIR__168.tif 64 | top_potsdam_2_11_RGBIR__132.tif 65 | top_potsdam_5_12_RGBIR__429.tif 66 | top_potsdam_4_10_RGBIR__481.tif 67 | top_potsdam_4_10_RGBIR__376.tif 68 | top_potsdam_4_10_RGBIR__342.tif 69 | top_potsdam_2_10_RGBIR__19.tif 70 | top_potsdam_7_12_RGBIR__150.tif 71 | top_potsdam_7_10_RGBIR__248.tif 72 | top_potsdam_6_7_RGBIR__489.tif 73 | top_potsdam_7_9_RGBIR__403.tif 74 | top_potsdam_2_11_RGBIR__346.tif 75 | top_potsdam_4_11_RGBIR__102.tif 76 | top_potsdam_5_12_RGBIR__338.tif 77 | top_potsdam_4_12_RGBIR__133.tif 78 | top_potsdam_3_12_RGBIR__375.tif 79 | top_potsdam_2_12_RGBIR__186.tif 80 | top_potsdam_7_7_RGBIR__452.tif 81 | top_potsdam_7_11_RGBIR__284.tif 82 | top_potsdam_7_7_RGBIR__42.tif 83 | top_potsdam_2_12_RGBIR__406.tif 84 | top_potsdam_7_9_RGBIR__573.tif 85 | top_potsdam_7_7_RGBIR__545.tif 86 | 
top_potsdam_6_8_RGBIR__524.tif 87 | top_potsdam_6_8_RGBIR__1.tif 88 | top_potsdam_5_11_RGBIR__397.tif 89 | top_potsdam_7_12_RGBIR__356.tif 90 | top_potsdam_2_11_RGBIR__149.tif 91 | top_potsdam_2_11_RGBIR__127.tif 92 | top_potsdam_5_12_RGBIR__52.tif 93 | top_potsdam_6_7_RGBIR__233.tif 94 | top_potsdam_4_12_RGBIR__215.tif 95 | top_potsdam_2_12_RGBIR__264.tif 96 | top_potsdam_7_8_RGBIR__74.tif 97 | top_potsdam_7_10_RGBIR__302.tif 98 | top_potsdam_2_11_RGBIR__259.tif 99 | top_potsdam_2_10_RGBIR__351.tif 100 | top_potsdam_3_10_RGBIR__286.tif 101 | top_potsdam_2_11_RGBIR__242.tif 102 | top_potsdam_6_8_RGBIR__12.tif 103 | top_potsdam_6_9_RGBIR__160.tif 104 | top_potsdam_5_12_RGBIR__253.tif 105 | top_potsdam_6_10_RGBIR__13.tif 106 | top_potsdam_3_10_RGBIR__415.tif 107 | top_potsdam_6_8_RGBIR__553.tif 108 | top_potsdam_3_12_RGBIR__523.tif 109 | top_potsdam_6_7_RGBIR__111.tif 110 | top_potsdam_6_10_RGBIR__114.tif 111 | top_potsdam_7_9_RGBIR__464.tif 112 | top_potsdam_6_7_RGBIR__253.tif 113 | top_potsdam_5_11_RGBIR__507.tif 114 | top_potsdam_6_10_RGBIR__526.tif 115 | top_potsdam_6_10_RGBIR__350.tif 116 | top_potsdam_6_8_RGBIR__60.tif 117 | top_potsdam_6_7_RGBIR__491.tif 118 | top_potsdam_7_7_RGBIR__329.tif 119 | top_potsdam_6_10_RGBIR__37.tif 120 | top_potsdam_3_11_RGBIR__527.tif 121 | top_potsdam_4_12_RGBIR__381.tif 122 | top_potsdam_3_11_RGBIR__281.tif 123 | top_potsdam_7_11_RGBIR__209.tif 124 | top_potsdam_6_8_RGBIR__431.tif 125 | top_potsdam_3_11_RGBIR__25.tif 126 | top_potsdam_7_12_RGBIR__116.tif 127 | top_potsdam_2_11_RGBIR__532.tif 128 | top_potsdam_5_10_RGBIR__229.tif 129 | top_potsdam_2_11_RGBIR__276.tif 130 | top_potsdam_3_10_RGBIR__498.tif 131 | top_potsdam_4_12_RGBIR__218.tif 132 | top_potsdam_2_11_RGBIR__461.tif 133 | top_potsdam_7_10_RGBIR__117.tif 134 | top_potsdam_4_12_RGBIR__423.tif 135 | top_potsdam_3_11_RGBIR__172.tif 136 | top_potsdam_3_10_RGBIR__57.tif 137 | top_potsdam_3_11_RGBIR__190.tif 138 | top_potsdam_6_11_RGBIR__142.tif 139 | -------------------------------------------------------------------------------- /T-SS-GLCNet/data_example/Potsdam/trainR1_lbl.txt: -------------------------------------------------------------------------------- 1 | top_potsdam_7_8_RGBIR__519.tif 2 | top_potsdam_4_11_RGBIR__51.tif 3 | top_potsdam_6_7_RGBIR__285.tif 4 | top_potsdam_5_12_RGBIR__182.tif 5 | top_potsdam_2_11_RGBIR__546.tif 6 | top_potsdam_4_11_RGBIR__472.tif 7 | top_potsdam_6_10_RGBIR__188.tif 8 | top_potsdam_3_10_RGBIR__255.tif 9 | top_potsdam_7_11_RGBIR__335.tif 10 | top_potsdam_2_12_RGBIR__543.tif 11 | top_potsdam_2_10_RGBIR__465.tif 12 | top_potsdam_3_10_RGBIR__553.tif 13 | top_potsdam_3_10_RGBIR__239.tif 14 | top_potsdam_6_12_RGBIR__38.tif 15 | top_potsdam_6_12_RGBIR__333.tif 16 | top_potsdam_2_12_RGBIR__281.tif 17 | top_potsdam_6_12_RGBIR__454.tif 18 | top_potsdam_7_11_RGBIR__514.tif 19 | top_potsdam_7_7_RGBIR__159.tif 20 | top_potsdam_7_12_RGBIR__366.tif 21 | top_potsdam_3_12_RGBIR__127.tif 22 | top_potsdam_6_10_RGBIR__207.tif 23 | top_potsdam_2_12_RGBIR__475.tif 24 | top_potsdam_6_7_RGBIR__218.tif 25 | top_potsdam_6_9_RGBIR__555.tif 26 | top_potsdam_6_9_RGBIR__347.tif 27 | top_potsdam_4_11_RGBIR__136.tif 28 | top_potsdam_6_11_RGBIR__107.tif 29 | top_potsdam_2_11_RGBIR__430.tif 30 | top_potsdam_6_12_RGBIR__537.tif 31 | top_potsdam_2_12_RGBIR__491.tif 32 | top_potsdam_6_12_RGBIR__243.tif 33 | top_potsdam_5_12_RGBIR__112.tif 34 | top_potsdam_2_11_RGBIR__345.tif 35 | top_potsdam_3_11_RGBIR__362.tif 36 | top_potsdam_7_9_RGBIR__164.tif 37 | top_potsdam_5_11_RGBIR__514.tif 38 | 
top_potsdam_4_12_RGBIR__561.tif 39 | top_potsdam_2_11_RGBIR__225.tif 40 | top_potsdam_3_10_RGBIR__5.tif 41 | top_potsdam_7_8_RGBIR__549.tif 42 | top_potsdam_3_10_RGBIR__386.tif 43 | top_potsdam_5_12_RGBIR__178.tif 44 | top_potsdam_2_12_RGBIR__548.tif 45 | top_potsdam_3_10_RGBIR__541.tif 46 | top_potsdam_5_11_RGBIR__118.tif 47 | top_potsdam_6_11_RGBIR__22.tif 48 | top_potsdam_6_8_RGBIR__135.tif 49 | top_potsdam_5_11_RGBIR__304.tif 50 | top_potsdam_6_11_RGBIR__222.tif 51 | top_potsdam_3_11_RGBIR__58.tif 52 | top_potsdam_6_11_RGBIR__289.tif 53 | top_potsdam_6_10_RGBIR__478.tif 54 | top_potsdam_3_12_RGBIR__308.tif 55 | top_potsdam_3_12_RGBIR__438.tif 56 | top_potsdam_6_10_RGBIR__20.tif 57 | top_potsdam_3_11_RGBIR__137.tif 58 | top_potsdam_7_11_RGBIR__228.tif 59 | top_potsdam_4_11_RGBIR__153.tif 60 | top_potsdam_7_7_RGBIR__567.tif 61 | top_potsdam_7_11_RGBIR__272.tif 62 | top_potsdam_7_10_RGBIR__486.tif 63 | top_potsdam_3_10_RGBIR__168.tif 64 | top_potsdam_2_11_RGBIR__132.tif 65 | top_potsdam_5_12_RGBIR__429.tif 66 | top_potsdam_4_10_RGBIR__481.tif 67 | top_potsdam_4_10_RGBIR__376.tif 68 | top_potsdam_4_10_RGBIR__342.tif 69 | top_potsdam_2_10_RGBIR__19.tif 70 | top_potsdam_7_12_RGBIR__150.tif 71 | top_potsdam_7_10_RGBIR__248.tif 72 | top_potsdam_6_7_RGBIR__489.tif 73 | top_potsdam_7_9_RGBIR__403.tif 74 | top_potsdam_2_11_RGBIR__346.tif 75 | top_potsdam_4_11_RGBIR__102.tif 76 | top_potsdam_5_12_RGBIR__338.tif 77 | top_potsdam_4_12_RGBIR__133.tif 78 | top_potsdam_3_12_RGBIR__375.tif 79 | top_potsdam_2_12_RGBIR__186.tif 80 | top_potsdam_7_7_RGBIR__452.tif 81 | top_potsdam_7_11_RGBIR__284.tif 82 | top_potsdam_7_7_RGBIR__42.tif 83 | top_potsdam_2_12_RGBIR__406.tif 84 | top_potsdam_7_9_RGBIR__573.tif 85 | top_potsdam_7_7_RGBIR__545.tif 86 | top_potsdam_6_8_RGBIR__524.tif 87 | top_potsdam_6_8_RGBIR__1.tif 88 | top_potsdam_5_11_RGBIR__397.tif 89 | top_potsdam_7_12_RGBIR__356.tif 90 | top_potsdam_2_11_RGBIR__149.tif 91 | top_potsdam_2_11_RGBIR__127.tif 92 | top_potsdam_5_12_RGBIR__52.tif 93 | top_potsdam_6_7_RGBIR__233.tif 94 | top_potsdam_4_12_RGBIR__215.tif 95 | top_potsdam_2_12_RGBIR__264.tif 96 | top_potsdam_7_8_RGBIR__74.tif 97 | top_potsdam_7_10_RGBIR__302.tif 98 | top_potsdam_2_11_RGBIR__259.tif 99 | top_potsdam_2_10_RGBIR__351.tif 100 | top_potsdam_3_10_RGBIR__286.tif 101 | top_potsdam_2_11_RGBIR__242.tif 102 | top_potsdam_6_8_RGBIR__12.tif 103 | top_potsdam_6_9_RGBIR__160.tif 104 | top_potsdam_5_12_RGBIR__253.tif 105 | top_potsdam_6_10_RGBIR__13.tif 106 | top_potsdam_3_10_RGBIR__415.tif 107 | top_potsdam_6_8_RGBIR__553.tif 108 | top_potsdam_3_12_RGBIR__523.tif 109 | top_potsdam_6_7_RGBIR__111.tif 110 | top_potsdam_6_10_RGBIR__114.tif 111 | top_potsdam_7_9_RGBIR__464.tif 112 | top_potsdam_6_7_RGBIR__253.tif 113 | top_potsdam_5_11_RGBIR__507.tif 114 | top_potsdam_6_10_RGBIR__526.tif 115 | top_potsdam_6_10_RGBIR__350.tif 116 | top_potsdam_6_8_RGBIR__60.tif 117 | top_potsdam_6_7_RGBIR__491.tif 118 | top_potsdam_7_7_RGBIR__329.tif 119 | top_potsdam_6_10_RGBIR__37.tif 120 | top_potsdam_3_11_RGBIR__527.tif 121 | top_potsdam_4_12_RGBIR__381.tif 122 | top_potsdam_3_11_RGBIR__281.tif 123 | top_potsdam_7_11_RGBIR__209.tif 124 | top_potsdam_6_8_RGBIR__431.tif 125 | top_potsdam_3_11_RGBIR__25.tif 126 | top_potsdam_7_12_RGBIR__116.tif 127 | top_potsdam_2_11_RGBIR__532.tif 128 | top_potsdam_5_10_RGBIR__229.tif 129 | top_potsdam_2_11_RGBIR__276.tif 130 | top_potsdam_3_10_RGBIR__498.tif 131 | top_potsdam_4_12_RGBIR__218.tif 132 | top_potsdam_2_11_RGBIR__461.tif 133 | top_potsdam_7_10_RGBIR__117.tif 134 | 
top_potsdam_4_12_RGBIR__423.tif
top_potsdam_3_11_RGBIR__172.tif
top_potsdam_3_10_RGBIR__57.tif
top_potsdam_3_11_RGBIR__190.tif
top_potsdam_6_11_RGBIR__142.tif
--------------------------------------------------------------------------------
/T-SS-GLCNet/data_example/Potsdam/train_RGBIR/top_potsdam_2_10_RGBIR__36.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeoX-Lab/G-RSIM/518efd3776d8e4937b0a4c34137a5bbda4b4f8a1/T-SS-GLCNet/data_example/Potsdam/train_RGBIR/top_potsdam_2_10_RGBIR__36.tif
--------------------------------------------------------------------------------
/T-SS-GLCNet/data_example/Potsdam/train_lbl/top_potsdam_2_10_RGBIR__36.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeoX-Lab/G-RSIM/518efd3776d8e4937b0a4c34137a5bbda4b4f8a1/T-SS-GLCNet/data_example/Potsdam/train_lbl/top_potsdam_2_10_RGBIR__36.tif
--------------------------------------------------------------------------------
/T-SS-GLCNet/data_example/Potsdam/val_RGBIR/top_potsdam_2_13_RGBIR__26.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeoX-Lab/G-RSIM/518efd3776d8e4937b0a4c34137a5bbda4b4f8a1/T-SS-GLCNet/data_example/Potsdam/val_RGBIR/top_potsdam_2_13_RGBIR__26.tif
--------------------------------------------------------------------------------
/T-SS-GLCNet/data_example/Potsdam/val_lbl/top_potsdam_2_13_RGBIR__26.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeoX-Lab/G-RSIM/518efd3776d8e4937b0a4c34137a5bbda4b4f8a1/T-SS-GLCNet/data_example/Potsdam/val_lbl/top_potsdam_2_13_RGBIR__26.tif
--------------------------------------------------------------------------------
/T-SS-GLCNet/dl_tools/basictools/dlbasic.py:
--------------------------------------------------------------------------------
'''
Basic deep-learning tools.
'''
# import sys
# import os
# import cv2
import numpy as np
# from scipy import misc
# import matplotlib.pyplot as plt
# import re


# **********************************************
# ********** Result visualization ************
# **********************************************
# def visualize_feature_map(FM, outdir, num=1, name=None):
#     '''
#     Parse the input feature map and do sample visualization.
#     Args:
#         input: Tensor of feature map.
#         outdir: Output directory for saving images.
#
#     '''
#     FM = np.array(FM)
#     if len(FM.shape) == 5:  # [11HWC] -> [1HWC]
#         FM = FM[0]
#     if len(FM.shape) == 4:  # [1HWC] -> [HWC]
#         FM = FM[0]

#     h, w, c = FM.shape
#     # Normalize to 0-1
#     tp_min, tp_max = np.min(FM), np.max(FM)
#     if tp_min < 0:
#         FM += abs(tp_min)
#     FM /= (tp_max - tp_min)

#     # Iterate filters
#     for i in range(c):
#         fig = plt.figure(figsize=(12, 12))
#         axes = fig.add_subplot(111)
#         img = FM[:, :, i]
#         # Toning color
#         axes.imshow(img, vmin=0, vmax=0.9, interpolation='bicubic', cmap='coolwarm')
#         # Remove any labels from the axes
#         # axes.set_xticks([])
#         # axes.set_yticks([])

#         # Save figure
#         if name is None:
#             # misc.imsave(outdir + '/%d_%.3d.png' % (num, i), FM[:, :, i])
#             plt.savefig(outdir + '/%d_%.3d.png' % (num, i), dpi=60, bbox_inches='tight')
#         else:
#             # misc.imsave(outdir + '/%s_%.3d.png' % (name, i), FM[:, :, i])
#             plt.savefig(outdir + '/%s_%.3d.png' % (name, i), dpi=60, bbox_inches='tight')
#         plt.close(fig)


if __name__ == '__main__':
    # in_dir = '/home/tao/Data/VOCdevkit2007/VOC2007/small/train'
    # rename_files(in_dir)
    # mkdir_of_dataset('/home/tao/Data/XM')
    pass
--------------------------------------------------------------------------------
/T-SS-GLCNet/dl_tools/basictools/dltrain.py:
--------------------------------------------------------------------------------
'''
Tools collection of operations during training.
The included functions are:
    def train_log()
    def draw_loss_curve()
    def memory_watcher()
    def Accuracy Evaluation [several functions]

Version 1.0 2018-04-02 22:44:32
by QiJi. Reference: https://github.com/GeorgeSeif/Semantic-Segmentation-Suite
Version 2.0 2018-10-29 09:10:41
by QiJi
TODO:
    1. The accuracy-evaluation functions have not been debugged yet.

'''
import datetime
import os
import sys

import numpy as np
# from matplotlib import pyplot as plt


# def train_log(X, f=None):
#     ''' Print with time, to the console or a file (f). '''
#     time_stamp = datetime.datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
#     if not f:
#         sys.stdout.write(time_stamp + " " + X)
#         sys.stdout.flush()
#     else:
#         f.write(time_stamp + " " + X)


# def draw_loss_curve(epochs, loss, path):
#     ''' Paint loss curve '''
#     fig = plt.figure(figsize=(12, 9))
#     ax1 = fig.add_subplot(111)
#     ax1.plot(range(epochs), loss)
#     ax1.set_title("Average loss vs epochs")
#     ax1.set_xlabel("Epoch")
#     ax1.set_ylabel("Current loss")
#     plt.savefig(path + '/loss_vs_epochs.png')
#     plt.close(fig)


def memory_watcher():
    ''' Compute the memory usage, for debugging '''
    import psutil
    pid = os.getpid()
    py = psutil.Process(pid)
    memoryUse = py.memory_info()[0] / 2.**30  # Memory use in GB
    print('Memory usage in GBs:', memoryUse)


# **********************************************
# *********** Accuracy Evaluation **************
# **********************************************
def compute_global_accuracy(pred, label):
    '''
    Compute the average segmentation accuracy across all classes.
    Input: [HW] or [HWC] label.
    '''
    count_mat = pred == label
    return np.sum(count_mat) / np.prod(count_mat.shape)


# def compute_class_accuracies(y_pred, y_true, num_classes):
#     ''' Compute the class-specific segmentation accuracy '''
#     # Only valid for a single image; multiple images need concatenation (when computing `total`)
#     w = y_true.shape[0]
#     h = y_true.shape[1]
#     flat_image = np.reshape(y_true, w * h)
#     total = []
#     for val in range(num_classes):
#         total.append((flat_image == val).sum())

#     count = [0.0] * num_classes
#     for i in range(w):
#         for j in range(h):
#             if y_pred[i, j] == y_true[i, j]:
#                 count[int(y_pred[i, j])] = count[int(y_pred[i, j])] + 1.0
#     # If there are no pixels from a certain class in the GT, it returns NAN
#     # because of divide by zero. Replace the NaNs with a 1.0.
#     accuracies = []
#     for i in range(len(total)):
#         if total[i] == 0:
#             accuracies.append(1.0)
#         else:
#             accuracies.append(count[i] / total[i])

#     return accuracies

def compute_class_accuracies(y_pred, y_true, num_classes, total):
    ''' Compute the class-specific segmentation accuracy '''
    # Only valid for a single image; multiple images need concatenation (when computing `total`)
    w = y_true.shape[0]
    h = y_true.shape[1]
    # flat_image = np.reshape(y_true, w * h)
    # total = []
    # for val in range(num_classes):
    #     total.append((flat_image == val).sum())

    count = [0.0] * num_classes
    for i in range(w):
        for j in range(h):
            if y_pred[i, j] == y_true[i, j]:
                count[int(y_pred[i, j])] = count[int(y_pred[i, j])] + 1.0
    # If there are no pixels from a certain class in the GT, division by zero
    # would give NaN; such classes get 0 here.
    accuracies = []
    for i in range(len(total)):
        if total[i] == 0:
            accuracies.append(0)  # c
        else:
            accuracies.append(count[i] / total[i])

    return accuracies


def precision(pred, label):
    '''
    Compute precision.
    TODO: Only for 2 classes now.
    '''
    TP = float(np.count_nonzero(pred * label))  # np.float was removed in NumPy >= 1.24
    FP = float(np.count_nonzero(pred * (label - 1)))
    prec = TP / (TP + FP)
    return prec

def recall(pred, label):
    '''
    Compute recall.
    TODO: Only for 2 classes now.
    '''
    TP = float(np.count_nonzero(pred * label))
    FN = float(np.count_nonzero((pred - 1) * label))
    rec = TP / (TP + FN)
    return rec


def f1score(pred, label):
    ''' Compute F1 score '''
    prec = precision(pred, label)
    rec = recall(pred, label)
    f1 = np.divide(2 * prec * rec, (prec + rec))
    return f1


def compute_class_iou(pred, gt, num_classes):
    '''
    Args:
        pred: Predicted label [HW].
        gt: Ground-truth label [HW].
    Return:
        Array of per-class IoU values.
    '''
    # If there are no pixels from a certain class in the GT, division by zero
    # would give NaN, so a small epsilon is added to the union.
    intersection = np.zeros(num_classes)
    union = np.zeros(num_classes)
    for i in range(num_classes):
        pred_i = pred == i
        label_i = gt == i
        # label_i.resize(pred.shape[0], pred.shape[1])  #
        pred_i = pred_i[:label_i.shape[0], :label_i.shape[1]]
        intersection[i] = float(np.sum(np.logical_and(label_i, pred_i)))
        union[i] = float(np.sum(np.logical_or(label_i, pred_i)) + 1e-8)
    class_iou = intersection / union
    return class_iou


def evaluate_segmentation(pred, gt, num_classes):
    '''
    Evaluate segmentation result.

    Args:
        pred: Predicted label [HW].
        gt: Ground-truth label [HW].
        num_classes: Number of classes.
    Returns:
        accuracy, class_accuracies, prec, rec, f1, iou
    '''
    accuracy = compute_global_accuracy(pred, gt)
    # Per-class pixel totals are required by compute_class_accuracies
    # (the original call omitted this argument, which raised a TypeError)
    total = [(gt == val).sum() for val in range(num_classes)]
    class_accuracies = compute_class_accuracies(pred, gt, num_classes, total)
    prec = precision(pred, gt)
    rec = recall(pred, gt)
    f1 = f1score(pred, gt)
    # compute_class_iou already returns per-class IoU values, so just average
    # them (the original unpacked the result into two variables, which failed)
    iou = np.mean(compute_class_iou(pred, gt, num_classes))
    return accuracy, class_accuracies, prec, rec, f1, iou


# **********************************************
# ***********                     **************
# **********************************************
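A quick numeric check of the evaluation helpers above (a hypothetical snippet, not part of the repository; it assumes dltrain.py is importable, e.g. with the repo root on PYTHONPATH):
```
import numpy as np
from dl_tools.basictools.dltrain import evaluate_segmentation

pred = np.array([[0, 1], [1, 1]])
gt = np.array([[0, 1], [0, 1]])
acc, class_acc, prec, rec, f1, iou = evaluate_segmentation(pred, gt, num_classes=2)
print(acc)   # 0.75 -- three of the four pixels match
print(prec)  # 0.666... -- two true positives, one false positive
print(rec)   # 1.0 -- no false negatives
```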
--------------------------------------------------------------------------------
/T-SS-GLCNet/dl_tools/basictools/fileop.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
'''
File tools.

Version 1.0 2018-10-25 16:36:55
by QiJi. Reference:
TODO:
    1. xxx

'''
# import sys
import os
import re


# **********************************************
# ***************                ****************
# **********************************************
def rename_file(file_name, ifID=0, addstr=None, extension=None):
    '''Rename a file.
    Args:
        file_name: The name/path of the file.
        ifID: 1 - only keep the number (ID) in the old name.
            Careful! If only keeping the ID, file_name can't be a path.
        addstr: An additional string added between the name and the extension.
        extension: Set a new extension (kind of image, such as: 'png').
    '''
    savename, extn = os.path.splitext(file_name)  # extn contains the '.'
    if ifID:
        # file_path = os.path.dirname(full_name)
        ID_nums = re.findall(r"\d+", savename)
        ID_str = str(ID_nums[0])
        for i in range(len(ID_nums)-1):
            ID_str += ('_'+(ID_nums[i+1]))
        savename = ID_str

    if addstr is not None:
        savename += '_' + addstr

    if extension is not None:
        extn = '.' + extension

    return savename + extn


def mkdir_of_dataset(data_dir):
    '''
    Create folders of DL datasets according to the standard structure:
    ├── "dataset_name"(data_dir)
    |   ├── Log: log the train details
    |   ├── Model
    |   |   ├── checkpoints
    |   ├── SAVE: save some checkpoints and results
    |   ├── BackUp: backup some data or code
    |   ├── train
    |   ├── train_labels
    |   ├── val
    |   ├── val_labels
    |   ├── test
    |   ├── test_labels
    '''
    dir_list = ['/Log', '/Model', '/Model/checkpoints',
                '/SAVE', '/BackUp', '/train', '/train_labels',
                '/val', '/val_labels', '/test', '/test_labels']
    for a_dir in dir_list:
        if not os.path.exists(data_dir+a_dir):
            os.mkdir(data_dir+a_dir)
            print('make %s' % (a_dir))


def mkdir_of_classifyresult(out_dir, class_list):
    '''
    Make dirs of (classification) results.
    Args:
        out_dir: Output dir.
        class_list: List of class names.
    '''
    for a_class in class_list:
        a_dir = out_dir + "/" + a_class
        if not os.path.exists(a_dir):  # fixed: the original created the dir only if it already existed
            os.mkdir(a_dir)


def filelist(floder_dir, ifPath=False, extension=None):
    '''
    Get the names (or full paths) of all files (with a specified extension)
    in floder_dir and return them as a list.

    Args:
        floder_dir: The dir of the folder.
        ifPath:
            True - Return full paths of files.
            False - Only return names of files. (Default)
        extension: Specify an extension to only get that kind of file name.

    Returns:
        namelist: Name (or path) list of all files (with the specified extension)
    '''
    namelist = sorted(os.listdir(floder_dir))

    if ifPath:
        for i in range(len(namelist)):
            namelist[i] = os.path.join(floder_dir, namelist[i])

    if extension is not None:
        n = len(namelist)-1  # original length of namelist
        for i in range(len(namelist)):
            if not namelist[n-i].endswith(extension):
                namelist.remove(namelist[n-i])  # discard files with other extensions

    return namelist


def filepath_to_name(full_name, extension=False):
    '''
    Takes an absolute file path and returns the name of the file with(out) the extension.
    '''
    file_name = os.path.basename(full_name)
    if not extension:  # if False then discard the extension
        file_name = os.path.splitext(file_name)[0]
    return file_name


# **********************************************
# ************ Main functions ******************
# **********************************************
def rename_files(input_dir, out_dir=None, num=1):
    ''' Rename all the files in a folder by number.
    Args:
        input_dir: Original folder directory.
        out_dir: Renamed files' output directory (optional).
        num: The starting number of the renamed files (default=1).
    '''
    # Log the rename record
    folder_dir = os.path.dirname(input_dir)
    folder_name = os.path.basename(input_dir)
    target = open(folder_dir+"/ReName Log_%s.txt" % folder_name, 'w')

    file_names = sorted(os.listdir(input_dir))
    extension = os.path.splitext(file_names[0])[1]  # Plan A
    if out_dir is None:
        out_dir = input_dir
    for name in file_names:
        newname = ("%.5d" % (num))+extension  # Plan A
        # newname = rename_file(name, ifID=1)  # Plan B
        # TODO: Chinese characters in file names may be garbled, same below
        os.rename(input_dir+'/'+name, input_dir+'/'+newname)
        target.write(name + '\tTo\t' + newname + '\n')
        num += 1

    target.close()
    print("Finished renaming.")


def main():
    pass


if __name__ == '__main__':
    # main()
    # mkdir_of_dataset('/home/tao/Data/RBDD')
    pass
--------------------------------------------------------------------------------
/T-SS-GLCNet/evalute.py:
--------------------------------------------------------------------------------
# from dltrain import fast_hist
import cv2
# from dldata import get_label_info
import numpy as np
import os
import scipy
# import sklearn
# from sklearn.metrics import confusion_matrix
from sklearn.metrics import confusion_matrix as c_matrix
# # path = 'D:/Data_Lib/Seg/MeiTB/1007and2008label'
# label_truth = cv2.imread(path + '/' + '2008test_5label_color.tif')
# label_pred = cv2.imread(path + '/' + '2008_superclass55.tif')
# hist = fast_hist(label_truth, label_pred, 5)

# class_names, label_values = get_label_info(root+'/class_dict.txt')
# class_num = len(class_names)-1


# def class_label(label, label_values):
#     '''
#     Convert RGB label to 2D [HW] array, each pixel value is the classified class key.
#     '''
#     semantic_map = np.zeros(label.shape[:2], label.dtype)
#     for i in range(len(label_values)):
#         equality = np.equal(label, label_values[i])
#         class_map = np.all(equality, axis=-1)
#         semantic_map[class_map] = i
#     return semantic_map


def hist(pred, truth, class_num):
    # The manual double loop below was superseded by sklearn's confusion_matrix
    hist_m = np.zeros([class_num, class_num], dtype=np.int64)
    # for i in range(truth.shape[0]):
    #     for j in range(truth.shape[1]):
    #         col = truth[i][j]
    #         row = pred[i][j]
    #         hist_m[row][col] += 1
    flat_pred = pred.flatten()
    flat_true = truth.flatten()
    label_class = [x for x in range(class_num)]

    hist_m = c_matrix(flat_true, flat_pred, labels=label_class)  # `labels` is keyword-only in recent sklearn
    return hist_m


def matrix(preddataset_path, truthdataset_path, class_Num):
    pred_names = sorted(os.listdir(preddataset_path))
    truth_names = sorted(os.listdir(truthdataset_path))
    matrix1 = np.zeros([class_Num, class_Num], dtype=np.int64)
    for i in range(len(pred_names)):
        pred = cv2.imread(preddataset_path + '/' + pred_names[i])
        truth = cv2.imread(truthdataset_path + '/' + truth_names[i])
        truth[truth == 255] = 0
        # truth[truth == 6] = 1
        # truth[truth == 7] = 2
        # truth[truth == 8] = 3
        # truth[truth == 9] = 4
        matrix1 += hist(pred[:, :, 0], truth[:, :, 0], class_Num)
    return matrix1


def get_scores(hist=None):
    """Returns accuracy score evaluation result.
        - Overall Acc
        - Class Acc
        - Mean Acc
    """
    # hist = self.confusion_matrix if hist is None else hist
    # Overall accuracy
    # hist1 = hist[:, 1:]
    # hist = hist1[1:, :]
    # acc_overall = np.diag(hist)[1:].sum() / (hist.sum()-np.diag(hist)[0]+1e-8)
    acc_overall = np.diag(hist).sum() / (hist.sum()+1e-8)
    # Class accuracy
    acc_cls = np.diag(hist) / (hist.sum(axis=0)+1e-8)  # acc per class
    recall = np.diag(hist) / (hist.sum(axis=1)+1e-8)
    F1_score = (2*acc_cls*recall)/(recall+acc_cls)
    # Class average accuracy
    acc_cls_avg = np.nanmean(acc_cls)
    # Kappa
    n = hist.sum()
    p0 = hist.diagonal().sum()
    p1 = hist.sum(0)
    p2 = hist.sum(1)
    kappa = float(n*p0-np.inner(p1, p2)) / float(n*n - np.inner(p1, p2) + 1e-8)

    # print('\n------Class Acc\n')
    # print(acc_cls)
    # print('\n------recall\n')
    # print(recall)
    # print('\n------F1_score\n')
    # print(F1_score)
    # print('\n------Hist\n')
    # print(hist)
    # print('\n------kappa')
    # print(kappa)
    # print('-----Overall Acc')
    # print(acc_overall)
    # print('-----Mean Acc\n')
    # print(acc_cls_avg)

    return (
        {
            "Hist": hist,  # confusion matrix
            "Kappa": kappa,
            "Overall Acc": acc_overall,

            "Class Acc": acc_cls,  # per-class accuracy
            "recall": recall,
            "F1_score": F1_score,
            "Mean Acc": acc_cls_avg,
        }  # Return as a dictionary
    )

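# --- Hypothetical usage sketch (not part of the original evalute.py) ---
# Accumulates a confusion matrix over a folder of predicted label maps and a
# folder of ground-truth label maps, then reports the derived scores. The
# directory paths and class count below are placeholders.
def _demo_eval(pred_dir='OutImage/pred_lbl',
               truth_dir='data_example/Potsdam/val_lbl', class_num=6):
    confusion = matrix(pred_dir, truth_dir, class_num)
    scores = get_scores(confusion)
    print(scores['Overall Acc'], scores['Kappa'])
    print(scores['F1_score'])  # one value per class
    return scores
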
# truth = class_label(label_truth, label_values)
# pred = class_label(label_pred, label_values)


# confusion_matrix
# root = '/project/ytwang/yzw/DeeplabAttASPP/Data'
# # datasetname_list = ['Deeplabv3+_45_pre1','Deeplabv3+_47_pre1','Deeplabv3+_48_pre1_1','Deeplabv3+_48_pre1']
# datasetname_list = ['Deeplabv3+_48_pre1']
# truthset_path = root + 'val_labels'
# class_num = 10
# for i in range(1):
#     # predset_path = root + '/' + datasetname_list[i]
#     confusion_matrix = matrix(predset_path, truthset_path, class_num)
#     print(confusion_matrix)
#     user = np.zeros([1, class_num], dtype=np.int64)
#     prod = np.zeros([1, class_num], dtype=np.int64)
#     for i in range(class_num):
#         for j in range(class_num):
#             a = confusion_matrix[i][j]
#             user[0][i] += a

#     # user = np.zeros([1, 5])
#     for i in range(class_num):
#         for j in range(class_num):
#             prod[0][i] += confusion_matrix[j][i]
#     # p:precision/prod_acc n:recall/user_acc
#     user_acc_recall = np.diag(confusion_matrix)/user
#     prod_acc_presion = np.diag(confusion_matrix)/prod
#     F1_score = (2*prod_acc_presion*user_acc_recall)/(user_acc_recall+prod_acc_presion)
#     print('|||||||||||||||------------------------------------------------------')
#     print('recall/r')
#     print(user_acc_recall)
#     print('class_precision/r')
#     print(prod_acc_presion)
#     print('F1_score/r')
#     print(F1_score)
#     # kappa = get_scores(confusion_matrix)
#     print(kappa)
#     print('------------------------------------------------------||||||||||||||||')
--------------------------------------------------------------------------------
/T-SS-GLCNet/models/__init__.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
'''
Abstract.

Version 1.0 2019-12-06 22:17:13 by QiJi
TODO:
    1. So far only DinkNet50's first_conv has been modified; the rest still needs updating.

'''
from . import v3p, v3p_SimCLR_encoder, GLCNet, v3p_inpaiting_encoder, v3p_jigsaw_encoder, v3p_mocov2_encoder_true


# from torchvision import models


def build_model(args):
    '''
    Args:
        net_num: a 6-digit number.
            The 1st digit denotes the arch of the main model:
                0 - Scene_Base
            The 2nd digit denotes the arch of the backbone:
                1 - resnet34
                2 - resnet50
                3 - resnet101
                4 - vgg16
                5 - googlenet
    '''
    if args.self_mode == 1:
        model = GLCNet.build_model(num_classes=64, in_channels=args.n_channels, pretrained=False, arch=args.arch, patch_num=args.patch_num,
                                   patch_size=args.patch_size, patch_out_channel=False, pross_num=args.pross_num)  # pretrained=False,
    elif args.self_mode == 11:
        model = GLCNet.build_model(num_classes=64, in_channels=args.n_channels, pretrained=False, arch=args.arch,
                                   patch_num=args.patch_num,
                                   patch_size=args.patch_size, patch_out_channel=False, noStyle=True,
                                   pross_num=args.pross_num)  # pretrained=False,
    elif args.self_mode == 12:
        model = GLCNet.build_model(num_classes=64, in_channels=args.n_channels, pretrained=False, arch=args.arch,
                                   patch_num=args.patch_num,
                                   patch_size=args.patch_size, patch_out_channel=False, noGlobal=True,
                                   pross_num=args.pross_num)  # pretrained=False,

    elif args.self_mode == 13:
        model = GLCNet.build_model(num_classes=64, in_channels=args.n_channels, pretrained=False, arch=args.arch,
                                   patch_num=args.patch_num,
                                   patch_size=args.patch_size, patch_out_channel=False, noLocal=True,
                                   pross_num=args.pross_num)  # pretrained=False,
    elif args.self_mode == 2:
        model = v3p_SimCLR_encoder.build_model(num_classes=64, in_channels=args.n_channels, pretrained=False,
                                               arch=args.arch)  # pretrained=False,
    elif args.self_mode == 3:
        m = 0.999
        K = 2048
        model = v3p_mocov2_encoder_true.build_model(num_classes=64, in_channels=args.n_channels, pretrained=False,
                                                    arch=args.arch, m=m,
                                                    K=K)  # pretrained=False,
    elif args.self_mode == 4:
        model = v3p_inpaiting_encoder.build_model(in_channels=args.n_channels, pretrained=False,
                                                  arch=args.arch)  # pretrained=False,
    elif args.self_mode == 5:
        model = v3p_jigsaw_encoder.build_model(num_classes=1000, in_channels=args.n_channels, pretrained=False,
                                               arch=args.arch)
    elif args.self_mode == 0:
        model = v3p.build_model(num_classes=args.class_num, in_channels=args.n_channels, pretrained=False,
                                arch=args.arch)  # num_classes1=50,

    if hasattr(model, 'model_name'):
        print(model.model_name)
    return model
--------------------------------------------------------------------------------
/T-SS-GLCNet/models/deeplab_utils/__init__.py:
--------------------------------------------------------------------------------
# import os
# import sys

# path = os.path.dirname(os.path.abspath(__file__))

# for py in [f[:-3] for f in os.listdir(path) if f.endswith('.py') and f != '__init__.py']:
#     mod = __import__('.'.join([__name__, py]), fromlist=[py])
#     classes = [getattr(mod, x) for x in dir(mod) if isinstance(getattr(mod, x), type)]
#     for cls in classes:
#         setattr(sys.modules[__name__], cls.__name__, cls)
--------------------------------------------------------------------------------
/T-SS-GLCNet/models/deeplab_utils/decoder.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time    : 2018/9/19 17:30
# @Author  : HLin
--------------------------------------------------------------------------------
/T-SS-GLCNet/models/deeplab_utils/__init__.py:
--------------------------------------------------------------------------------

# import os
# import sys

# path = os.path.dirname(os.path.abspath(__file__))

# for py in [f[:-3] for f in os.listdir(path) if f.endswith('.py') and f != '__init__.py']:
#     mod = __import__('.'.join([__name__, py]), fromlist=[py])
#     classes = [getattr(mod, x) for x in dir(mod) if isinstance(getattr(mod, x), type)]
#     for cls in classes:
#         setattr(sys.modules[__name__], cls.__name__, cls)


--------------------------------------------------------------------------------
/T-SS-GLCNet/models/deeplab_utils/decoder.py:
--------------------------------------------------------------------------------

# -*- coding: utf-8 -*-
# @Time    : 2018/9/19 17:30
# @Author  : HLin
# @Email   : linhua2017@ia.ac.cn
# @File    : decoder.py
# @Software: PyCharm

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torchsummary import summary
from .ResNet import resnet101  # the package provides ResNet.py (there is no ResNet101 module)
from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

import sys
sys.path.append(os.path.abspath('..'))

from .encoder import Encoder


class Decoder(nn.Module):
    def __init__(self, class_num, bn_momentum=0.1):
        super(Decoder, self).__init__()
        self.conv1 = nn.Conv2d(256, 48, kernel_size=1, bias=False)
        self.bn1 = SynchronizedBatchNorm2d(48, momentum=bn_momentum)
        self.relu = nn.ReLU()
        # self.conv2 = SeparableConv2d(304, 256, kernel_size=3)
        # self.conv3 = SeparableConv2d(256, 256, kernel_size=3)
        self.conv2 = nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False)
        self.bn2 = SynchronizedBatchNorm2d(256, momentum=bn_momentum)
        self.dropout2 = nn.Dropout(0.5)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)
        self.bn3 = SynchronizedBatchNorm2d(256, momentum=bn_momentum)
        self.dropout3 = nn.Dropout(0.1)
        self.conv4 = nn.Conv2d(256, class_num, kernel_size=1)

        self._init_weight()

    def forward(self, x, low_level_feature):
        low_level_feature = self.conv1(low_level_feature)
        low_level_feature = self.bn1(low_level_feature)
        low_level_feature = self.relu(low_level_feature)
        x_4 = F.interpolate(
            x,
            size=low_level_feature.size()[2:4],
            mode='bilinear',
            align_corners=True)
        x_4_cat = torch.cat((x_4, low_level_feature), dim=1)
        x_4_cat = self.conv2(x_4_cat)
        x_4_cat = self.bn2(x_4_cat)
        x_4_cat = self.relu(x_4_cat)
        x_4_cat = self.dropout2(x_4_cat)
        x_4_cat = self.conv3(x_4_cat)
        x_4_cat = self.bn3(x_4_cat)
        x_4_cat = self.relu(x_4_cat)
        x_4_cat = self.dropout3(x_4_cat)
        x_4_cat = self.conv4(x_4_cat)

        return x_4_cat

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class DeepLab(nn.Module):
    def __init__(self,
                 output_stride,
                 class_num,
                 pretrained,
                 bn_momentum=0.1,
                 freeze_bn=False):
        super(DeepLab, self).__init__()
        self.Resnet101 = resnet101(bn_momentum, pretrained)
        self.encoder = Encoder(bn_momentum, output_stride)
        self.decoder = Decoder(class_num, bn_momentum)
        if freeze_bn:
            self.freeze_bn()
            print("freeze batch normalization successfully!")

    def forward(self, input):
        x, low_level_features = self.Resnet101(input)

        x = self.encoder(x)
        predict = self.decoder(x, low_level_features)
        output = F.interpolate(
            predict,
            size=input.size()[2:4],
            mode='bilinear',
            align_corners=True)
        return output

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, SynchronizedBatchNorm2d):
                m.eval()


if __name__ == "__main__":
    model = DeepLab(
        output_stride=16, class_num=21, pretrained=False, freeze_bn=False)
    model.eval()
    # print(model)
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # model = model.to(device)
    # summary(model, (3, 513, 513))
    # for m in model.named_modules():
    for m in model.modules():
        if isinstance(m, SynchronizedBatchNorm2d):
            print(m)
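To make the decoder's channel bookkeeping concrete, a small shape check (sizes are illustrative; the 256 low-level channels match the ResNet-50/101 backbones used in this repo):

```python
import torch

from models.deeplab_utils.decoder import Decoder  # run from the T-SS-GLCNet root

decoder = Decoder(class_num=6)
x = torch.randn(2, 256, 32, 32)            # ASPP/encoder output
low_level = torch.randn(2, 256, 128, 128)  # early backbone features
out = decoder(x, low_level)                # conv1: 256->48, concat: 304, conv4: class_num
print(out.shape)                           # torch.Size([2, 6, 128, 128])
```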
--------------------------------------------------------------------------------
/T-SS-GLCNet/models/deeplab_utils/encoder.py:
--------------------------------------------------------------------------------

# -*- coding: utf-8 -*-
# @Time    : 2018/9/19 16:56
# @Author  : HLin
# @Email   : linhua2017@ia.ac.cn
# @File    : encoder.py
# @Software: PyCharm

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torchsummary import summary
# import torchvision
from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

import sys
sys.path.append(os.path.abspath('..'))


def _AsppConv(in_channels,
              out_channels,
              kernel_size,
              stride=1,
              padding=0,
              dilation=1,
              bn_momentum=0.1):
    asppconv = nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            bias=False),
        SynchronizedBatchNorm2d(out_channels, momentum=bn_momentum), nn.ReLU())
    return asppconv


class AsppModule(nn.Module):
    def __init__(self, bn_momentum=0.1, output_stride=16):
        super(AsppModule, self).__init__()

        # output_stride choice
        if output_stride == 16:
            atrous_rates = [0, 6, 12, 18]
        elif output_stride == 8:
            # Double the OS=16 rates elementwise; note that `2 * [0, 3, 5, 7]`
            # would repeat the Python list instead of scaling its elements.
            atrous_rates = [2 * rate for rate in [0, 6, 12, 18]]
        else:
            raise ValueError("output_stride must be 8 or 16!")
        # atrous_spatial_pyramid_pooling part
        self._atrous_convolution1 = _AsppConv(
            2048, 256, 1, 1, bn_momentum=bn_momentum)
        self._atrous_convolution2 = _AsppConv(
            2048, 256, 3, 1,
            padding=atrous_rates[1],
            dilation=atrous_rates[1],
            bn_momentum=bn_momentum)
        self._atrous_convolution3 = _AsppConv(
            2048, 256, 3, 1,
            padding=atrous_rates[2],
            dilation=atrous_rates[2],
            bn_momentum=bn_momentum)
        self._atrous_convolution4 = _AsppConv(
            2048, 256, 3, 1,
            padding=atrous_rates[3],
            dilation=atrous_rates[3],
            bn_momentum=bn_momentum)

        # image_pooling part
        self._image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(2048, 256, kernel_size=1, bias=False),
            SynchronizedBatchNorm2d(256, momentum=bn_momentum), nn.ReLU())

        self.__init_weight()

    def forward(self, input):
        input1 = self._atrous_convolution1(input)
        input2 = self._atrous_convolution2(input)
        input3 = self._atrous_convolution3(input)
        input4 = self._atrous_convolution4(input)
        input5 = self._image_pool(input)
        input5 = F.interpolate(
            input=input5,
            size=input4.size()[2:4],
            mode='bilinear',
            align_corners=True)

        return torch.cat((input1, input2, input3, input4, input5), dim=1)

    def __init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class Encoder(nn.Module):
    def __init__(self, bn_momentum=0.1, output_stride=16):
        super(Encoder, self).__init__()
        self.ASPP = AsppModule(
            bn_momentum=bn_momentum, output_stride=output_stride)
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)  # 1280 = 5 ASPP branches x 256
        self.bn1 = SynchronizedBatchNorm2d(256, momentum=bn_momentum)
        self.dropout = nn.Dropout(0.5)

        self.__init_weight()

def forward(self, input): 114 | input = self.ASPP(input) 115 | input = self.conv1(input) 116 | input = self.bn1(input) 117 | input = self.relu(input) 118 | input = self.dropout(input) 119 | return input 120 | 121 | def __init_weight(self): 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | torch.nn.init.kaiming_normal_(m.weight) 125 | elif isinstance(m, SynchronizedBatchNorm2d): 126 | m.weight.data.fill_(1) 127 | m.bias.data.zero_() 128 | 129 | 130 | if __name__ == "__main__": 131 | model = Encoder() 132 | model.eval() 133 | print(model) 134 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 135 | model = model.to(device) 136 | # summary(model, (3, 512, 512)) 137 | -------------------------------------------------------------------------------- /T-SS-GLCNet/models/deeplab_utils/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .batchnorm import patch_sync_batchnorm, convert_model 13 | from .replicate import DataParallelWithCallback, patch_replication_callback 14 | 15 | -------------------------------------------------------------------------------- /T-SS-GLCNet/models/deeplab_utils/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /T-SS-GLCNet/models/deeplab_utils/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 
58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /T-SS-GLCNet/models/deeplab_utils/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /T-SS-GLCNet/models/deeplab_utils/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /T-SS-GLCNet/models/v3p.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/9/19 17:30 3 | # @Author : HLin 4 | # @Email : linhua2017@ia.ac.cn 5 | # @File : decoder.py 6 | # @Software: PyCharm 7 | 8 | # import os 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from models.deeplab_utils import ResNet 13 | from models.deeplab_utils.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 14 | from models.deeplab_utils.encoder import Encoder 15 | 16 | 17 | 18 | class Decoder(nn.Module): 19 | def __init__(self, class_num, bn_momentum=0.1): 20 | super(Decoder, self).__init__() 21 | self.conv1 = nn.Conv2d(256, 48, kernel_size=1, bias=False) 22 | self.bn1 = SynchronizedBatchNorm2d(48, momentum=bn_momentum) 23 | self.relu = nn.ReLU() 24 | # self.conv2 = SeparableConv2d(304, 256, kernel_size=3) 25 | # self.conv3 = SeparableConv2d(256, 256, kernel_size=3) 26 | self.conv2 = nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False) 27 | self.bn2 = SynchronizedBatchNorm2d(256, momentum=bn_momentum) 28 | self.dropout2 = nn.Dropout(0.5) 29 | self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False) 30 | self.bn3 = SynchronizedBatchNorm2d(256, momentum=bn_momentum) 31 | self.dropout3 = nn.Dropout(0.1) 32 | self.conv4 = nn.Conv2d(256, class_num, kernel_size=1) 33 | 34 | self._init_weight() 35 | 36 | def forward(self, x, low_level_feature): 37 | low_level_feature = self.conv1(low_level_feature) 38 | low_level_feature = self.bn1(low_level_feature) 39 | low_level_feature = self.relu(low_level_feature) 40 | x_4 = F.interpolate( 41 | x, 42 | size=low_level_feature.size()[2:4], 43 | mode='bilinear', 44 | align_corners=True) 45 | x_4_cat = torch.cat((x_4, low_level_feature), dim=1) 46 | x_4_cat = self.conv2(x_4_cat) 47 | x_4_cat = self.bn2(x_4_cat) 48 | x_4_cat = self.relu(x_4_cat) 49 | x_4_cat = self.dropout2(x_4_cat) 50 | x_4_cat = self.conv3(x_4_cat) 51 | x_4_cat = self.bn3(x_4_cat) 52 | x_4_cat = self.relu(x_4_cat) 53 | x_4_cat = self.dropout3(x_4_cat) 54 | x_4_cat = self.conv4(x_4_cat) 55 | 56 | return x_4_cat 57 | 58 | def _init_weight(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.Conv2d): 61 | torch.nn.init.kaiming_normal_(m.weight) 62 | elif isinstance(m, SynchronizedBatchNorm2d): 63 | m.weight.data.fill_(1) 64 | m.bias.data.zero_() 65 | 66 | 67 | class DeepLab(nn.Module): 68 | def __init__(self, 69 | num_classes=2, 70 | in_channels=3, 71 | arch='resnet101', 72 | output_stride=16, 73 | bn_momentum=0.9, 74 | freeze_bn=False, 75 | pretrained=False, 76 | **kwargs): 77 | super(DeepLab, self).__init__(**kwargs) 78 | self.model_name = 'deeplabv3plus_' + arch 79 | 80 | # Setup arch 81 | if arch == 'resnet18': 82 | NotImplementedError('resnet18 backbone is not implemented yet.') 83 | elif arch == 'resnet34': 84 | 
NotImplementedError('resnet34 backbone is not implemented yet.') 85 | elif arch == 'resnet50': 86 | self.backbone = ResNet.resnet50(bn_momentum, pretrained) 87 | if in_channels != 3: 88 | self.backbone.conv1 = nn.Conv2d( 89 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 90 | elif arch == 'resnet101': 91 | self.backbone = ResNet.resnet101(bn_momentum, pretrained) 92 | if in_channels != 3: 93 | self.backbone.conv1 = nn.Conv2d( 94 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 95 | 96 | self.encoder = Encoder(bn_momentum, output_stride) 97 | self.decoder = Decoder(num_classes, bn_momentum) 98 | 99 | def forward(self, input): 100 | x, low_level_features = self.backbone(input) 101 | 102 | x = self.encoder(x) 103 | predict = self.decoder(x, low_level_features) 104 | output = F.interpolate( 105 | predict, 106 | size=input.size()[2:4], 107 | mode='bilinear', 108 | align_corners=True) 109 | return output 110 | 111 | def freeze_bn(self): 112 | for m in self.modules(): 113 | if isinstance(m, SynchronizedBatchNorm2d): 114 | m.eval() 115 | 116 | 117 | class Decoder_without_last_conv(nn.Module): 118 | def __init__(self, class_num, bn_momentum=0.1): 119 | super(Decoder_without_last_conv, self).__init__() 120 | self.conv1 = nn.Conv2d(256, 48, kernel_size=1, bias=False) 121 | self.bn1 = SynchronizedBatchNorm2d(48, momentum=bn_momentum) 122 | self.relu = nn.ReLU() 123 | # self.conv2 = SeparableConv2d(304, 256, kernel_size=3) 124 | # self.conv3 = SeparableConv2d(256, 256, kernel_size=3) 125 | self.conv2 = nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False) 126 | self.bn2 = SynchronizedBatchNorm2d(256, momentum=bn_momentum) 127 | # self.dropout2 = nn.Dropout(0.5) 128 | self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False) 129 | self.bn3 = SynchronizedBatchNorm2d(256, momentum=bn_momentum) 130 | # self.dropout3 = nn.Dropout(0.1) 131 | # self.conv4 = nn.Conv2d(256, class_num, kernel_size=1) 132 | 133 | self._init_weight() 134 | 135 | def forward(self, x, low_level_feature): 136 | low_level_feature = self.conv1(low_level_feature) 137 | low_level_feature = self.bn1(low_level_feature) 138 | low_level_feature = self.relu(low_level_feature) 139 | x_4 = F.interpolate( 140 | x, 141 | size=low_level_feature.size()[2:4], 142 | mode='bilinear', 143 | align_corners=True) 144 | x_4_cat = torch.cat((x_4, low_level_feature), dim=1) 145 | x_4_cat = self.conv2(x_4_cat) 146 | x_4_cat = self.bn2(x_4_cat) 147 | x_4_cat = self.relu(x_4_cat) 148 | # x_4_cat = self.dropout2(x_4_cat) 149 | x_4_cat = self.conv3(x_4_cat) 150 | x_4_cat = self.bn3(x_4_cat) 151 | x_4_cat = self.relu(x_4_cat) 152 | # x_4_cat = self.dropout3(x_4_cat) 153 | # x_4_cat = self.conv4(x_4_cat) 154 | 155 | return x_4_cat 156 | 157 | def _init_weight(self): 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | torch.nn.init.kaiming_normal_(m.weight) 161 | elif isinstance(m, SynchronizedBatchNorm2d): 162 | m.weight.data.fill_(1) 163 | m.bias.data.zero_() 164 | 165 | 166 | 167 | def build_model(num_classes=5, in_channels=3,pretrained=False,arch='resnet101'): 168 | model = DeepLab(num_classes=num_classes, in_channels=in_channels,pretrained=pretrained,arch=arch) 169 | return model 170 | 171 | if __name__ == "__main__": 172 | model = DeepLab( 173 | output_stride=16, class_num=21, pretrained=False, freeze_bn=False) 174 | model.eval() 175 | for m in model.modules(): 176 | if isinstance(m, SynchronizedBatchNorm2d): 177 | print(m) 178 | 
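All of the v3p_* variants below expose the same build_model(num_classes, in_channels, pretrained, arch) entry point, so the plain segmentation model above can be smoke-tested as follows (class count and tensor sizes are illustrative):

```python
import torch

from models import v3p  # run from the T-SS-GLCNet root

model = v3p.build_model(num_classes=6, in_channels=4, pretrained=False, arch='resnet50')
model.eval()  # SynchronizedBatchNorm2d then uses running statistics
with torch.no_grad():
    logits = model(torch.randn(1, 4, 256, 256))
print(logits.shape)  # torch.Size([1, 6, 256, 256]): per-pixel logits at input resolution
```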
-------------------------------------------------------------------------------- /T-SS-GLCNet/models/v3p_SimCLR_encoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/9/19 17:30 3 | # @Author : HLin 4 | # @Email : linhua2017@ia.ac.cn 5 | # @File : decoder.py 6 | # @Software: PyCharm 7 | 8 | # import os 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from models.deeplab_utils import ResNet as ResNet 13 | from models.deeplab_utils.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 14 | from models.deeplab_utils.encoder import Encoder 15 | 16 | 17 | 18 | 19 | class DeepLab(nn.Module): 20 | def __init__(self, 21 | num_classes=2, 22 | in_channels=3, 23 | arch='resnet101', 24 | output_stride=16, 25 | bn_momentum=0.9, 26 | freeze_bn=False, 27 | pretrained=False, 28 | **kwargs): 29 | super(DeepLab, self).__init__(**kwargs) 30 | self.model_name = 'deeplabv3plus_' + arch 31 | 32 | # Setup arch 33 | if arch == 'resnet18': 34 | NotImplementedError('resnet18 backbone is not implemented yet.') 35 | elif arch == 'resnet34': 36 | NotImplementedError('resnet34 backbone is not implemented yet.') 37 | elif arch == 'resnet50': 38 | self.backbone = ResNet.resnet50(bn_momentum, pretrained) 39 | if in_channels != 3: 40 | self.backbone.conv1 = nn.Conv2d( 41 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 42 | elif arch == 'resnet101': 43 | self.backbone = ResNet.resnet101(bn_momentum, pretrained) 44 | if in_channels != 3: 45 | self.backbone.conv1 = nn.Conv2d( 46 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 47 | 48 | self.encoder = Encoder(bn_momentum, output_stride) 49 | #self.decoder = Decoder(20,100, bn_momentum) 50 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 51 | # projection head 52 | ''' 53 | self.proj = nn.Sequential( 54 | nn.Conv2d(256, 256, 1, bias=False), 55 | nn.BatchNorm2d(256), 56 | nn.ReLU(inplace=True), 57 | nn.Conv2d(256, 10, 1, bias=True) 58 | ) 59 | ''' 60 | self.proj =nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256,num_classes)) 61 | def forward(self, input,input1): 62 | x ,_= self.backbone(input) 63 | #print(low_level_features.size()),56 64 | x = self.encoder(x) 65 | #print(x.size()),14 66 | x=self.avgpool(x) 67 | #print(x.size()) 68 | x = torch.flatten(x, 1) 69 | #print(x.size()) 70 | q=self.proj(x) 71 | x ,_= self.backbone(input1) 72 | 73 | x = self.encoder(x) 74 | x=self.avgpool(x) 75 | x = torch.flatten(x, 1) 76 | k=self.proj(x) 77 | #predict,predict1 = self.decoder(x1, low_level_features) 78 | 79 | return q,k 80 | 81 | def freeze_bn(self): 82 | for m in self.modules(): 83 | if isinstance(m, SynchronizedBatchNorm2d): 84 | m.eval() 85 | 86 | 87 | 88 | 89 | 90 | def build_model(num_classes=5, in_channels=3,pretrained=False,arch='resnet101'): 91 | model = DeepLab(num_classes=num_classes, in_channels=in_channels,pretrained=pretrained,arch=arch) 92 | return model 93 | 94 | if __name__ == "__main__": 95 | model = DeepLab( 96 | output_stride=16, class_num=21, pretrained=False, freeze_bn=False) 97 | model.eval() 98 | for m in model.modules(): 99 | if isinstance(m, SynchronizedBatchNorm2d): 100 | print(m) 101 | -------------------------------------------------------------------------------- /T-SS-GLCNet/models/v3p_decoder12_ft.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/9/19 17:30 3 | # @Author : HLin 4 | # @Email : 
linhua2017@ia.ac.cn 5 | # @File : decoder.py 6 | # @Software: PyCharm 7 | 8 | # import os 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from models.deeplab_utils import ResNet 13 | from models.deeplab_utils.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 14 | from models.deeplab_utils.encoder import Encoder 15 | 16 | 17 | 18 | class Decoder(nn.Module): 19 | def __init__(self, class_num, bn_momentum=0.1): 20 | super(Decoder, self).__init__() 21 | self.conv1 = nn.Conv2d(256, 48, kernel_size=1, bias=False) 22 | self.bn1 = SynchronizedBatchNorm2d(48, momentum=bn_momentum) 23 | self.relu = nn.ReLU() 24 | # self.conv2 = SeparableConv2d(304, 256, kernel_size=3) 25 | # self.conv3 = SeparableConv2d(256, 256, kernel_size=3) 26 | self.conv2 = nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False) 27 | self.bn2 = SynchronizedBatchNorm2d(256, momentum=bn_momentum) 28 | self.dropout2 = nn.Dropout(0.5) 29 | self.conv3t = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False) 30 | self.bn3t = SynchronizedBatchNorm2d(256, momentum=bn_momentum) 31 | self.dropout3 = nn.Dropout(0.1) 32 | self.conv4 = nn.Conv2d(256, class_num, kernel_size=1) 33 | 34 | self._init_weight() 35 | 36 | def forward(self, x, low_level_feature): 37 | low_level_feature = self.conv1(low_level_feature) 38 | low_level_feature = self.bn1(low_level_feature) 39 | low_level_feature = self.relu(low_level_feature) 40 | x_4 = F.interpolate( 41 | x, 42 | size=low_level_feature.size()[2:4], 43 | mode='bilinear', 44 | align_corners=True) 45 | x_4_cat = torch.cat((x_4, low_level_feature), dim=1) 46 | x_4_cat = self.conv2(x_4_cat) 47 | x_4_cat = self.bn2(x_4_cat) 48 | x_4_cat = self.relu(x_4_cat) 49 | x_4_cat = self.dropout2(x_4_cat) 50 | x_4_cat = self.conv3t(x_4_cat) 51 | x_4_cat = self.bn3t(x_4_cat) 52 | x_4_cat = self.relu(x_4_cat) 53 | x_4_cat = self.dropout3(x_4_cat) 54 | x_4_cat = self.conv4(x_4_cat) 55 | 56 | return x_4_cat 57 | 58 | def _init_weight(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.Conv2d): 61 | torch.nn.init.kaiming_normal_(m.weight) 62 | elif isinstance(m, SynchronizedBatchNorm2d): 63 | m.weight.data.fill_(1) 64 | m.bias.data.zero_() 65 | 66 | 67 | class DeepLab(nn.Module): 68 | def __init__(self, 69 | num_classes=2, 70 | in_channels=3, 71 | arch='resnet101', 72 | output_stride=16, 73 | bn_momentum=0.9, 74 | freeze_bn=False, 75 | pretrained=False, 76 | **kwargs): 77 | super(DeepLab, self).__init__(**kwargs) 78 | self.model_name = 'deeplabv3plus_' + arch 79 | 80 | # Setup arch 81 | if arch == 'resnet18': 82 | NotImplementedError('resnet18 backbone is not implemented yet.') 83 | elif arch == 'resnet34': 84 | NotImplementedError('resnet34 backbone is not implemented yet.') 85 | elif arch == 'resnet50': 86 | self.backbone = ResNet.resnet50(bn_momentum, pretrained) 87 | if in_channels != 3: 88 | self.backbone.conv1 = nn.Conv2d( 89 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 90 | elif arch == 'resnet101': 91 | self.backbone = ResNet.resnet101(bn_momentum, pretrained) 92 | if in_channels != 3: 93 | self.backbone.conv1 = nn.Conv2d( 94 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 95 | 96 | self.encoder = Encoder(bn_momentum, output_stride) 97 | self.decoder = Decoder(num_classes, bn_momentum) 98 | 99 | def forward(self, input): 100 | x, low_level_features = self.backbone(input) 101 | 102 | x = self.encoder(x) 103 | predict = self.decoder(x, low_level_features) 104 | output = F.interpolate( 105 | predict, 106 | 
size=input.size()[2:4], 107 | mode='bilinear', 108 | align_corners=True) 109 | return output 110 | 111 | def freeze_bn(self): 112 | for m in self.modules(): 113 | if isinstance(m, SynchronizedBatchNorm2d): 114 | m.eval() 115 | 116 | 117 | class Decoder_without_last_conv(nn.Module): 118 | def __init__(self, class_num, bn_momentum=0.1): 119 | super(Decoder_without_last_conv, self).__init__() 120 | self.conv1 = nn.Conv2d(256, 48, kernel_size=1, bias=False) 121 | self.bn1 = SynchronizedBatchNorm2d(48, momentum=bn_momentum) 122 | self.relu = nn.ReLU() 123 | # self.conv2 = SeparableConv2d(304, 256, kernel_size=3) 124 | # self.conv3 = SeparableConv2d(256, 256, kernel_size=3) 125 | self.conv2 = nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False) 126 | self.bn2 = SynchronizedBatchNorm2d(256, momentum=bn_momentum) 127 | # self.dropout2 = nn.Dropout(0.5) 128 | self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False) 129 | self.bn3 = SynchronizedBatchNorm2d(256, momentum=bn_momentum) 130 | # self.dropout3 = nn.Dropout(0.1) 131 | # self.conv4 = nn.Conv2d(256, class_num, kernel_size=1) 132 | 133 | self._init_weight() 134 | 135 | def forward(self, x, low_level_feature): 136 | low_level_feature = self.conv1(low_level_feature) 137 | low_level_feature = self.bn1(low_level_feature) 138 | low_level_feature = self.relu(low_level_feature) 139 | x_4 = F.interpolate( 140 | x, 141 | size=low_level_feature.size()[2:4], 142 | mode='bilinear', 143 | align_corners=True) 144 | x_4_cat = torch.cat((x_4, low_level_feature), dim=1) 145 | x_4_cat = self.conv2(x_4_cat) 146 | x_4_cat = self.bn2(x_4_cat) 147 | x_4_cat = self.relu(x_4_cat) 148 | # x_4_cat = self.dropout2(x_4_cat) 149 | x_4_cat = self.conv3(x_4_cat) 150 | x_4_cat = self.bn3(x_4_cat) 151 | x_4_cat = self.relu(x_4_cat) 152 | # x_4_cat = self.dropout3(x_4_cat) 153 | # x_4_cat = self.conv4(x_4_cat) 154 | 155 | return x_4_cat 156 | 157 | def _init_weight(self): 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | torch.nn.init.kaiming_normal_(m.weight) 161 | elif isinstance(m, SynchronizedBatchNorm2d): 162 | m.weight.data.fill_(1) 163 | m.bias.data.zero_() 164 | 165 | 166 | 167 | 168 | def build_model(num_classes=5, in_channels=3,pretrained=False,arch='resnet101'): 169 | model = DeepLab(num_classes=num_classes, in_channels=in_channels,pretrained=pretrained,arch=arch) 170 | return model 171 | 172 | if __name__ == "__main__": 173 | model = DeepLab( 174 | output_stride=16, class_num=21, pretrained=False, freeze_bn=False) 175 | model.eval() 176 | for m in model.modules(): 177 | if isinstance(m, SynchronizedBatchNorm2d): 178 | print(m) 179 | -------------------------------------------------------------------------------- /T-SS-GLCNet/models/v3p_encoder_ft.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/9/19 17:30 3 | # @Author : HLin 4 | # @Email : linhua2017@ia.ac.cn 5 | # @File : decoder.py 6 | # @Software: PyCharm 7 | 8 | # import os 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from models.deeplab_utils import ResNet 13 | from models.deeplab_utils.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 14 | from models.deeplab_utils.encoder import Encoder 15 | 16 | 17 | 18 | class Decoder(nn.Module): 19 | def __init__(self, class_num, bn_momentum=0.1): 20 | super(Decoder, self).__init__() 21 | self.conv1 = nn.Conv2d(256, 48, kernel_size=1, bias=False) 22 | self.bn1 = SynchronizedBatchNorm2d(48, 
momentum=bn_momentum) 23 | self.relu = nn.ReLU() 24 | # self.conv2 = SeparableConv2d(304, 256, kernel_size=3) 25 | # self.conv3 = SeparableConv2d(256, 256, kernel_size=3) 26 | self.conv2 = nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False) 27 | self.bn2 = SynchronizedBatchNorm2d(256, momentum=bn_momentum) 28 | self.dropout2 = nn.Dropout(0.5) 29 | self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False) 30 | self.bn3 = SynchronizedBatchNorm2d(256, momentum=bn_momentum) 31 | self.dropout3 = nn.Dropout(0.1) 32 | self.conv4 = nn.Conv2d(256, class_num, kernel_size=1) 33 | 34 | self._init_weight() 35 | 36 | def forward(self, x, low_level_feature): 37 | low_level_feature = self.conv1(low_level_feature) 38 | low_level_feature = self.bn1(low_level_feature) 39 | low_level_feature = self.relu(low_level_feature) 40 | x_4 = F.interpolate( 41 | x, 42 | size=low_level_feature.size()[2:4], 43 | mode='bilinear', 44 | align_corners=True) 45 | x_4_cat = torch.cat((x_4, low_level_feature), dim=1) 46 | x_4_cat = self.conv2(x_4_cat) 47 | x_4_cat = self.bn2(x_4_cat) 48 | x_4_cat = self.relu(x_4_cat) 49 | x_4_cat = self.dropout2(x_4_cat) 50 | x_4_cat = self.conv3(x_4_cat) 51 | x_4_cat = self.bn3(x_4_cat) 52 | x_4_cat = self.relu(x_4_cat) 53 | x_4_cat = self.dropout3(x_4_cat) 54 | x_4_cat = self.conv4(x_4_cat) 55 | 56 | return x_4_cat 57 | 58 | def _init_weight(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.Conv2d): 61 | torch.nn.init.kaiming_normal_(m.weight) 62 | elif isinstance(m, SynchronizedBatchNorm2d): 63 | m.weight.data.fill_(1) 64 | m.bias.data.zero_() 65 | 66 | 67 | class DeepLab(nn.Module): 68 | def __init__(self, 69 | num_classes=2, 70 | in_channels=3, 71 | arch='resnet101', 72 | output_stride=16, 73 | bn_momentum=0.9, 74 | freeze_bn=False, 75 | pretrained=False, 76 | **kwargs): 77 | super(DeepLab, self).__init__(**kwargs) 78 | self.model_name = 'deeplabv3plus_' + arch 79 | 80 | # Setup arch 81 | if arch == 'resnet18': 82 | NotImplementedError('resnet18 backbone is not implemented yet.') 83 | elif arch == 'resnet34': 84 | NotImplementedError('resnet34 backbone is not implemented yet.') 85 | elif arch == 'resnet50': 86 | self.backbone = ResNet.resnet50(bn_momentum, pretrained) 87 | if in_channels != 3: 88 | self.backbone.conv1 = nn.Conv2d( 89 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 90 | elif arch == 'resnet101': 91 | self.backbone = ResNet.resnet101(bn_momentum, pretrained) 92 | if in_channels != 3: 93 | self.backbone.conv1 = nn.Conv2d( 94 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 95 | 96 | self.encoder = Encoder(bn_momentum, output_stride) 97 | self.decoder_ft = Decoder(num_classes, bn_momentum) 98 | 99 | def forward(self, input): 100 | x, low_level_features = self.backbone(input) 101 | 102 | x = self.encoder(x) 103 | predict = self.decoder_ft(x, low_level_features) 104 | output = F.interpolate( 105 | predict, 106 | size=input.size()[2:4], 107 | mode='bilinear', 108 | align_corners=True) 109 | return output 110 | 111 | def freeze_bn(self): 112 | for m in self.modules(): 113 | if isinstance(m, SynchronizedBatchNorm2d): 114 | m.eval() 115 | 116 | 117 | 118 | 119 | 120 | def build_model(num_classes=5, in_channels=3,pretrained=False,arch='resnet101'): 121 | model = DeepLab(num_classes=num_classes, in_channels=in_channels,pretrained=pretrained,arch=arch) 122 | return model 123 | 124 | if __name__ == "__main__": 125 | model = DeepLab( 126 | output_stride=16, class_num=21, pretrained=False, freeze_bn=False) 127 | 
model.eval() 128 | for m in model.modules(): 129 | if isinstance(m, SynchronizedBatchNorm2d): 130 | print(m) 131 | -------------------------------------------------------------------------------- /T-SS-GLCNet/models/v3p_jigsaw_encoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/9/19 17:30 3 | # @Author : HLin 4 | # @Email : linhua2017@ia.ac.cn 5 | # @File : decoder.py 6 | # @Software: PyCharm 7 | 8 | # import os 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from models.deeplab_utils import ResNet 13 | from models.deeplab_utils.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 14 | from models.deeplab_utils.encoder import Encoder 15 | 16 | 17 | 18 | 19 | class DeepLab(nn.Module): 20 | def __init__(self, 21 | num_classes=2, 22 | in_channels=3, 23 | arch='resnet101', 24 | output_stride=16, 25 | bn_momentum=0.9, 26 | freeze_bn=False, 27 | pretrained=False,puzzle=3, 28 | **kwargs): 29 | super(DeepLab, self).__init__(**kwargs) 30 | self.model_name = 'deeplabv3plus_' + arch 31 | 32 | #num_classes=puzzle**2 33 | 34 | # Setup arch 35 | if arch == 'resnet18': 36 | NotImplementedError('resnet18 backbone is not implemented yet.') 37 | elif arch == 'resnet34': 38 | NotImplementedError('resnet34 backbone is not implemented yet.') 39 | elif arch == 'resnet50': 40 | self.backbone = ResNet.resnet50(bn_momentum, pretrained) 41 | if in_channels != 3: 42 | self.backbone.conv1 = nn.Conv2d( 43 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 44 | elif arch == 'resnet101': 45 | self.backbone = ResNet.resnet101(bn_momentum, pretrained) 46 | if in_channels != 3: 47 | self.backbone.conv1 = nn.Conv2d( 48 | in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 49 | 50 | self.encoder = Encoder(bn_momentum, output_stride) 51 | #self.decoder = Decoder(num_classes, bn_momentum) 52 | self.avgpool = nn.AdaptiveAvgPool2d((3, 3)) 53 | self.fc5 = nn.Sequential( 54 | nn.Linear(256 * puzzle * puzzle, 256), 55 | nn.ReLU(), 56 | nn.Dropout(p=0.1) 57 | ) 58 | self.fc6 = nn.Sequential( 59 | nn.Linear(256 * (puzzle**2), 1024), 60 | nn.ReLU(), 61 | nn.Dropout(p=0.1) 62 | ) 63 | 64 | self.classifier = nn.Sequential( 65 | nn.Linear(1024, num_classes) 66 | ) 67 | # projection head 68 | ''' 69 | self.proj = nn.Sequential( 70 | nn.Conv2d(256, 256, 1, bias=False), 71 | nn.BatchNorm2d(256), 72 | nn.ReLU(inplace=True), 73 | nn.Conv2d(256, 10, 1, bias=True) 74 | ) 75 | ''' 76 | self.proj =nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256,num_classes)) 77 | def forward(self, x): 78 | N, T, C, H, W = x.size() 79 | x = x.transpose(0, 1) 80 | 81 | x_list = [] 82 | for i in range(T): 83 | z,_ = self.backbone(x[i]) # 2x2 84 | z = self.encoder(z) 85 | z=self.avgpool(z) 86 | z = self.fc5(z.view(N, -1)) 87 | z = z.view([N, 1, -1]) 88 | x_list.append(z) 89 | 90 | x = torch.cat(x_list, 1) 91 | x = self.fc6(x.view(N, -1)) 92 | logist = self.classifier(x) 93 | return logist 94 | 95 | 96 | def freeze_bn(self): 97 | for m in self.modules(): 98 | if isinstance(m, SynchronizedBatchNorm2d): 99 | m.eval() 100 | 101 | 102 | 103 | 104 | def build_model(num_classes=5, in_channels=3,pretrained=False,arch='resnet101'): 105 | model = DeepLab(num_classes=num_classes, in_channels=in_channels,pretrained=pretrained,arch=arch) 106 | return model 107 | 108 | if __name__ == "__main__": 109 | model = DeepLab( 110 | output_stride=16, class_num=21, pretrained=False, freeze_bn=False) 111 | model.eval() 112 | 
for m in model.modules(): 113 | if isinstance(m, SynchronizedBatchNorm2d): 114 | print(m) 115 | -------------------------------------------------------------------------------- /T-SS-GLCNet/permutations_1000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/G-RSIM/518efd3776d8e4937b0a4c34137a5bbda4b4f8a1/T-SS-GLCNet/permutations_1000.npy -------------------------------------------------------------------------------- /T-SS-GLCNet/readme: -------------------------------------------------------------------------------- 1 | training 2 | python main_ss.py 3 | --ex_mode: 0-just do self-supervised learning(ssl) 4 | 1-ssl and fine-tune 5 | 2-fine-tune 6 | 3-just fine-tune val 7 | --self_mode: 1-'GLCNet', 2-'SimCLR', 3-'mocov2', 4-'inpaiting', 8 | 5-'jigsaw',11-'GLCNet_noStyle',12-'GLCNet_noGlobal',13-'GLCNet_noLocal' 9 | -------------------------------------------------------------------------------- /T-SS-GLCNet/readme.md: -------------------------------------------------------------------------------- 1 | # Global and Local Contrastive Self-Supervised Learning for Semantic Segmentation of HR Remote Sensing Images. 2 | ## Abstract 3 | 4 | 5 | Recently, supervised deep learning has achieved great success in remote sensing image (RSI) semantic segmentation. However, supervised learning for semantic segmentation requires a large number of labeled samples, which is difficult to obtain in the field of remote sensing. A new learning paradigm, self-supervised learning (SSL), can be used to solve such problems by pre-training a general model with a large number of unlabeled images and then fine-tuning it on a downstream task with very few labeled samples. Contrastive learning is a typical method of SSL that can learn general invariant features. However, most existing contrastive learning methods are designed for classification tasks to obtain an image-level representation, which may be suboptimal for semantic segmentation tasks requiring pixel-level discrimination. Therefore, we propose a global style and local matching contrastive learning network (GLCNet) for remote sensing image semantic segmentation. Specifically, 1) the global style contrastive learning module is used to better learn an image-level representation, as we consider that style features can better represent the overall image features. 2) The local features matching contrastive learning module is designed to learn representations of local regions, which is beneficial for semantic segmentation. We evaluate four RSI semantic segmentation datasets, and the experimental results show that our method mostly outperforms state-of-the-art self-supervised methods and the ImageNet pre-training method. Specifically, with 1\% annotation from the original dataset, our approach improves Kappa by 6\% on the ISPRS Potsdam dataset relative to the existing baseline. Moreover, our method outperforms supervised learning methods when there are some differences between the datasets of upstream tasks and downstream tasks. Our study promotes the development of self-supervised learning in the field of RSI semantic segmentation. Since SSL could directly learn the essential characteristics of data from unlabeled data, which is easy to obtain in the remote sensing field, this may be of great significance for tasks such as global mapping. 

You can read the paper at https://arxiv.org/abs/2106.10605.

## Dataset Directory Structure
-------
The expected file structure is as follows:

    $train_RGBIR/*.tif
    $train_lbl/*.tif
    $val_RGBIR/*.tif
    $val_lbl/*.tif
    train_RGBIR.txt
    trainR1_RGBIR.txt
    trainR1_lbl.txt
    val_RGBIR.txt
    val_lbl.txt

## Training
-------
To pre-train the model with our GLCNet and then fine-tune it, try the following command:
```
python main_ss.py root=./data_example/Potsdam \
    --ex_mode=1 --self_mode=1 \
    --self_max_epoch=400 --ft_max_epoch=150 \
    --self_data_name=train --ft_train_name=trainR1
```

## Citation
If our repo is useful to you, please cite our published paper as follows:

```
Bibtex
@article{Li2021GLCNet,
  title={Global and Local Contrastive Self-Supervised Learning for Semantic Segmentation of HR Remote Sensing Images},
  author={Li, Haifeng and Yi, Li and Zhang, Guo and Liu, Ruoyun and Huang, Haozhe and Zhu, Qing and Tao, Chao},
  journal={IEEE Transactions on Geoscience and Remote Sensing},
  DOI = {10.1109/TGRS.2022.3147513},
  year={2022},
  type = {Journal Article}
}

Endnote
%0 Journal Article
%A Li, Haifeng
%A Yi, Li
%A Zhang, Guo
%A Liu, Ruoyun
%A Huang, Haozhe
%A Zhu, Qing
%A Tao, Chao
%D 2022
%T Global and Local Contrastive Self-Supervised Learning for Semantic Segmentation of HR Remote Sensing Images
%B IEEE Transactions on Geoscience and Remote Sensing
%R DOI:10.1109/TGRS.2022.3147513
%! Global and Local Contrastive Self-Supervised Learning for Semantic Segmentation of HR Remote Sensing Images
```
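The *_RGBIR.txt / *_lbl.txt entries above are plain file lists. Assuming each line holds one file name from the matching folder (an assumption; check utils/data_for_seg.py for the exact convention), a minimal sketch to regenerate them:

```python
import os

root = './data_example/Potsdam'
for split in ('train_RGBIR', 'train_lbl', 'val_RGBIR', 'val_lbl'):
    # Write one file name per line for each split folder.
    names = sorted(os.listdir(os.path.join(root, split)))
    with open(os.path.join(root, split + '.txt'), 'w') as f:
        f.write('\n'.join(names))
```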
--------------------------------------------------------------------------------
/T-SS-GLCNet/utils/__init__.py:
--------------------------------------------------------------------------------

# -*- coding:utf-8 -*-
'''
Abstract.

Version 1.0 2019-12-06 22:17:13 by QiJi
TODO:
    1. So far only DinkNet50's first_conv has been adapted; the remaining models still need updating.

'''
from . import data_for_self_inpaitting, data_for_self_contrast, data_for_self_Jigsaw, data_for_self_GLCNet

# from torchvision import models


def get_self_dataset(args):
    '''
    Return the self-supervised training dataset matching `args.self_mode`:
        1/11/12/13 - GLCNet (and its ablations)
        2/3        - two-view contrastive data for SimCLR / MoCo v2
        4          - inpainting data
        5          - jigsaw data
    '''
    if args.self_mode == 1 or args.self_mode == 11 or args.self_mode == 12 or args.self_mode == 13:
        train_dataset = data_for_self_GLCNet.Train_Dataset(args.dataset_dir, args.self_data_name, args)
    elif args.self_mode == 2 or args.self_mode == 3:
        train_dataset = data_for_self_contrast.Train_Dataset(args.dataset_dir, args.self_data_name, args)
    elif args.self_mode == 4:
        train_dataset = data_for_self_inpaitting.Train_Dataset(args.dataset_dir, args.self_data_name, args)
    elif args.self_mode == 5:
        train_dataset = data_for_self_Jigsaw.Train_Dataset(args.dataset_dir, args.self_data_name, args)

    return train_dataset


--------------------------------------------------------------------------------
/T-SS-GLCNet/utils/contrast_loss.py:
--------------------------------------------------------------------------------

# -*- coding:utf-8 -*-

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
# import torch.autograd as autograd


def mask_type_transfer(mask):
    mask = mask.type(torch.bool)
    # mask = mask.type(torch.uint8)
    return mask


def get_pos_and_neg_mask(bs):
    '''Build the NT-Xent positive/negative masks for a 2*bs x 2*bs similarity matrix.'''
    zeros = torch.zeros((bs, bs), dtype=torch.uint8)
    eye = torch.eye(bs, dtype=torch.uint8)
    pos_mask = torch.cat([
        torch.cat([zeros, eye], dim=0), torch.cat([eye, zeros], dim=0),
    ], dim=1)
    neg_mask = _get_correlated_mask(bs)
    # (torch.ones(2*bs, 2*bs, dtype=torch.uint8) - torch.eye(2*bs, dtype=torch.uint8))
    pos_mask = mask_type_transfer(pos_mask)
    neg_mask = mask_type_transfer(neg_mask)
    return pos_mask, neg_mask
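For a concrete view of what these masks select, a tiny example with bs=2: in the concatenated 2*bs batch, view i is positive with view i+bs, and the negative mask keeps everything except self-pairs and positives:

```python
import torch

from utils.contrast_loss import get_pos_and_neg_mask  # run from the T-SS-GLCNet root

pos_mask, neg_mask = get_pos_and_neg_mask(2)
print(pos_mask.int())
# tensor([[0, 0, 1, 0],
#         [0, 0, 0, 1],
#         [1, 0, 0, 0],
#         [0, 1, 0, 0]], dtype=torch.int32)
print(neg_mask.int())  # 1 everywhere except the diagonal and the positive pairs
```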
class NTXentLoss(nn.Module):
    """NT-Xent (normalized temperature-scaled cross entropy) loss.

    Args:
        tau: The temperature parameter.
    """

    def __init__(self,
                 bs, gpu,
                 tau=1,
                 cos_sim=False,
                 use_gpu=True,
                 eps=1e-8):
        super(NTXentLoss, self).__init__()
        self.name = 'NTXentLoss_Org'
        self.tau = tau
        self.use_cos_sim = cos_sim
        self.gpu = gpu
        self.eps = eps
        self.bs = bs

        if cos_sim:
            self.cosine_similarity = nn.CosineSimilarity(dim=-1)
            self.name += '_CosSim'

        # Get pos and neg mask
        self.pos_mask, self.neg_mask = get_pos_and_neg_mask(bs)

        if use_gpu:
            self.pos_mask = self.pos_mask.cuda(gpu)
            self.neg_mask = self.neg_mask.cuda(gpu)
        print(self.name)

    def forward(self, zi, zj):
        '''
        zi, zj: the two batches of projected features, one per augmented view.
        '''
        zi, zj = F.normalize(zi, dim=1), F.normalize(zj, dim=1)
        bs = zi.shape[0]

        z_all = torch.cat([zi, zj], dim=0)  # [2*bs, d]: z_i then z_j
        # [2*bs, 2*bs] - pairwise similarity s_(i,j)
        if self.use_cos_sim:
            sim_mat = torch.exp(self.cosine_similarity(
                z_all.unsqueeze(1), z_all.unsqueeze(0)) / self.tau)
        else:
            sim_mat = torch.exp(torch.mm(z_all, z_all.t().contiguous()) / self.tau)
        # if bs != self.bs:
        #     pos_mask, neg_mask = get_pos_and_neg_mask(bs)
        #     pos_mask, neg_mask = pos_mask.cuda(self.gpu), neg_mask.cuda(self.gpu)
        #     sim_pos = sim_mat.masked_select(pos_mask).view(2 * bs).clone()
        #     # [2*bs, 2*bs-1]
        #     sim_neg = sim_mat.masked_select(neg_mask).view(2 * bs, -1)
        # else:

        # pos = torch.sum(sim_mat * self.pos_mask, 1)
        # neg = torch.sum(sim_mat * self.neg_mask, 1)
        # loss = -(torch.mean(torch.log(pos / (pos + neg))))
        sim_pos = sim_mat.masked_select(self.pos_mask).view(2 * bs).clone()
        # [2*bs, 2*bs-1]
        sim_neg = sim_mat.masked_select(self.neg_mask).view(2 * bs, -1)
        # Compute loss
        loss = (- torch.log(sim_pos / (sim_neg.sum(dim=-1) + self.eps))).mean()

        return loss


def _get_correlated_mask(batch_size):
    '''Mask that is 1 for every pair except self-pairs and the two augmented views.'''
    diag = np.eye(2 * batch_size)
    l1 = np.eye((2 * batch_size), 2 * batch_size, k=-batch_size)
    l2 = np.eye((2 * batch_size), 2 * batch_size, k=batch_size)
    mask = torch.from_numpy((diag + l1 + l2))
    mask = (1 - mask)  # .byte()
    return mask


def get_contrast_loss(name, **kwargs):
    if name == 'NTXentLoss':
        criterion = NTXentLoss

    return criterion(**kwargs)


def main():

    pass


if __name__ == '__main__':
    main()
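A minimal CPU sketch of using this loss; use_gpu=False keeps the masks on the CPU, and the 128-dimensional projections are arbitrary:

```python
import torch

from utils.contrast_loss import NTXentLoss  # run from the T-SS-GLCNet root

criterion = NTXentLoss(bs=8, gpu=None, tau=0.5, cos_sim=True, use_gpu=False)
zi = torch.randn(8, 128)  # projections of augmented view 1
zj = torch.randn(8, 128)  # projections of augmented view 2
loss = criterion(zi, zj)
print(float(loss))
```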
--------------------------------------------------------------------------------
/T-SS-GLCNet/utils/loss.py:
--------------------------------------------------------------------------------

# -*- coding:utf-8 -*-
'''
Some custom loss functions for PyTorch.

Version 1.0 2018-11-02 15:15:44

cross_entropy2d() and multi_scale_cross_entropy2d() are not written by me.
'''
from torch import nn
import os
import cv2
import torch
import numpy as np
# import torch.nn as nn
import torch.nn.functional as F
# import torch.autograd as autograd


class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, weight=None, ignore_index=None):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index
        self.bce_fn = nn.CrossEntropyLoss(weight=self.weight)
        # self.bce_fn = nn.BCEWithLogitsLoss(weight=self.weight)

    def forward(self, preds, labels):
        if self.ignore_index is not None:
            mask = labels != self.ignore_index
            labels = labels[mask]
            preds = preds[mask]

        logpt = -self.bce_fn(preds, labels)
        pt = torch.exp(logpt)
        loss = -((1 - pt) ** self.gamma) * self.alpha * logpt
        return loss


def cross_entropy2d(input, target, weight=None, size_average=True):
    n, c, h, w = input.size()
    nt, ht, wt = target.size()

    # Handle inconsistent size between input and target
    if h > ht and w > wt:  # upsample labels
        target = target.unsqueeze(1)
        target = F.upsample(target, size=(h, w), mode='nearest')
        target = target.squeeze(1)
    elif h < ht and w < wt:  # upsample images
        input = F.upsample(input, size=(ht, wt), mode='bilinear')
    elif h != ht and w != wt:
        raise Exception("Only support upsampling")

    log_p = F.log_softmax(input, dim=1)
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    log_p = log_p[target.view(-1, 1).repeat(1, c) >= 0]
    log_p = log_p.view(-1, c)

    mask = target >= 0
    target = target[mask]
    loss = F.nll_loss(
        log_p, target, ignore_index=250, weight=weight, size_average=False)
    if size_average:
        loss /= mask.data.sum()
    return loss


def multi_scale_cross_entropy2d(input,
                                target,
                                weight=None,
                                size_average=True,
                                scale_weight=None):
    # Auxiliary training for PSPNet [1.0, 0.4] and ICNet [1.0, 0.4, 0.16]
    if scale_weight is None:  # scale_weight: torch tensor type
        n_inp = len(input)
        scale = 0.4
        scale_weight = torch.pow(scale * torch.ones(n_inp),
                                 torch.arange(n_inp))

    loss = 0.0
    for i, inp in enumerate(input):
        loss = loss + scale_weight[i] * cross_entropy2d(
            input=inp, target=target, weight=weight, size_average=size_average)

    return loss


def jaccard_loss(input, target):
    '''Soft IoU loss. Note: not yet verified for more than 2 classes.
    Args:
        input  - net output tensor, one_hot [NCHW]
        target - gt label tensor of class indices, shape [N, 1, H, W]
    '''
    n, c, h, w = input.size()
    # nt, ht, wt = target.size()
    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))
    input = torch.sigmoid(input)
    # Expand target tensor dim (one-hot over 2 classes)
    target = torch.zeros(n, 2, h, w).scatter_(dim=1, index=target, value=1)
    intersection = input * target  # [NCHW]: equals input where they agree, 0 elsewhere
    union = input + target - intersection  # 1 where they agree, input elsewhere
    iou = intersection / union  # input/1 where they agree, 0 elsewhere
    return (intersection / union).sum() / (n * h * w)
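As a usage note for this file's losses: FocalLoss (defined at the top) consumes raw logits and integer class labels, just like nn.CrossEntropyLoss. A minimal sketch with illustrative shapes:

```python
import torch

from utils.loss import FocalLoss  # run from the T-SS-GLCNet root

criterion = FocalLoss(alpha=0.25, gamma=2)
preds = torch.randn(4, 6)            # logits for 4 samples, 6 classes
labels = torch.tensor([0, 2, 5, 1])  # ground-truth class indices
loss = criterion(preds, labels)
print(float(loss))
```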
100 | .format(input.size(0), target.size(0))) 101 | input = torch.sigmoid(input) 102 | # Expand target tensor dim 103 | target = torch.zeros(n, 2, h, w).scatter_(dim=1, index=target, value=1) 104 | intersection = input * target # [NCHW]; equals input where target == 1, 0 elsewhere 105 | # input1 = input.cpu().detach().numpy() 106 | # target1 = target.cpu().detach().numpy() 107 | union = input + target - intersection # equals 1 where target == 1, input elsewhere 108 | iou = intersection / union # equals input where target == 1, 0 elsewhere 109 | # iou1 = iou.cpu().detach().numpy() 110 | return (intersection / union).sum() / (n*h*w) # mean soft IoU; use (1 - value) as a minimizable loss 111 | 112 | 113 | class DiceLoss(nn.Module): 114 | def __init__(self): 115 | super(DiceLoss, self).__init__() 116 | 117 | def forward(self, input, target): 118 | N = target.size(0) 119 | smooth = 1 120 | 121 | input_flat = input.view(N, -1) 122 | target_flat = target.view(N, -1) 123 | 124 | intersection = input_flat * target_flat 125 | 126 | loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth) 127 | loss = 1 - loss.sum() / N 128 | return loss 129 | 130 | def main(): 131 | 132 | # Predicted values f(x): mock samples of the network output layer 133 | # input_tensor = torch.ones([3, 2, 5, 5], dtype=torch.float64) 134 | # tmp_mat = torch.ones([5, 5], dtype=torch.float64) 135 | # input_tensor[0, 0, :, :] = tmp_mat * 0.5 136 | # input_tensor[1, 1, :, :] = tmp_mat * 0.5 137 | # input_tensor[2, 1, :, :] = tmp_mat * 0.5 138 | # label = torch.argmax(input_tensor, 3) 139 | # print(label[0]) 140 | # print(label[1]) 141 | # print(label.size()) 142 | # [0.8, 0.2] * [1, 0]: 0.8 / (0.8+0.2 + 1 - 0.8) = 0.8 / 1.2 = 2/3 143 | # [0.4, 0.6] * [1, 0]: 0.4 / (2 - 0.4) = 0.4 / 1.6 = 1/4 144 | # [0.0, 1.0] * [0, 1]: 0 145 | 146 | # Ground-truth y 147 | # labels = torch.LongTensor([0, 1, 4, 7, 3, 2]).unsqueeze(1) 148 | # print(labels.size()) 149 | # one_hot = torch.zeros(6, 8).scatter_(dim=1, index=labels, value=1) 150 | # print(one_hot) 151 | 152 | # target_tensor = torch.ones([3, 5, 5], dtype=torch.int64).unsqueeze(1) 153 | # target_tensor = torch.zeros(3, 2, 5, 5).scatter_(1, target_tensor, 1) 154 | # print(target_tensor.size()) 155 | # J = input_tensor * target_tensor 156 | 157 | p = np.array([0.8, 0.2]) 158 | t = np.array([1, 0]) 159 | print() 160 | pass 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /T-SS-GLCNet/utils/nt_xent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | class NTXentLoss(torch.nn.Module): 6 | 7 | def __init__(self, device, batch_size, temperature, use_cosine_similarity): 8 | super(NTXentLoss, self).__init__() 9 | self.batch_size = batch_size 10 | self.temperature = temperature 11 | self.device = device 12 | self.softmax = torch.nn.Softmax(dim=-1) 13 | self.mask = self._get_correlated_mask().type(torch.bool)#bool,.byte() 14 | self.similarity_function = self._get_similarity_function(use_cosine_similarity) 15 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") 16 | 17 | def _get_similarity_function(self, use_cosine_similarity): 18 | if use_cosine_similarity: 19 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)  # compute cosine similarity 20 | return self._cosine_simililarity 21 | else: 22 | return self._dot_simililarity 23 | 24 | def _get_correlated_mask(self): 25 | diag = np.eye(2 * self.batch_size) 26 | l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size) 27 | l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size) 28
| mask = torch.from_numpy((diag + l1 + l2)) 29 | mask = (1 - mask).byte()#.type(torch) 30 | return mask.to(self.device) 31 | 32 | @staticmethod 33 | def _dot_simililarity(x, y): 34 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) 35 | # x shape: (N, 1, C) 36 | # y shape: (1, C, 2N) 37 | # v shape: (N, 2N) 38 | return v 39 | 40 | def _cosine_simililarity(self, x, y): 41 | # x shape: (2N, 1, C) 42 | # y shape: (1, 2N, C) 43 | # v shape: (N, 2N) 44 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 45 | return v 46 | 47 | def forward(self, zis, zjs): 48 | 49 | 50 | batch_size = zis.shape[0] 51 | representations = torch.cat([zis, zjs], dim=0) 52 | 53 | similarity_matrix = torch.exp(torch.mm(representations, representations.t().contiguous()) / self.temperature) 54 | 55 | #mask = (torch.ones_like(similarity_matrix) -torch.eye(2 * batch_size, device=similarity_matrix.device)).type(torch.bool)#.byte()#type(torch.bool)#uint8 56 | 57 | similarity_matrix = similarity_matrix.masked_select(self.mask).view(2 * batch_size, -1) 58 | 59 | pos_sim = torch.exp(torch.sum(zis * zjs, dim=-1) / self.temperature) 60 | 61 | pos_sim = torch.cat([pos_sim, pos_sim], dim=0) 62 | loss = (-torch.log(pos_sim / (similarity_matrix.sum(dim=-1) + 1e-8))).mean() 63 | 64 | return loss 65 | ''' 66 | representations = torch.cat([zjs, zis], dim=0) 67 | self.batch_size = representations.shape[0] // 2 68 | 69 | similarity_matrix = self.similarity_function(representations, representations) 70 | 71 | # filter out the scores from the positive samples 72 | l_pos = torch.diag(similarity_matrix, self.batch_size) 73 | r_pos = torch.diag(similarity_matrix, -self.batch_size) 74 | positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1) 75 | 76 | negatives = similarity_matrix[self._get_correlated_mask().type(torch.bool)].view(2 * self.batch_size, -1) 77 | 78 | logits = torch.cat((positives, negatives), dim=1) 79 | logits /= self.temperature 80 | 81 | labels = torch.zeros(2 * self.batch_size).to(self.device).long() 82 | loss = self.criterion(logits, labels) 83 | 84 | return loss / (2 * self.batch_size) 85 | ''' 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /T-SS-GLCNet/utils/tools.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Road extraction experiment tools. 3 | 4 | by QiJi 5 | TODO: 6 | 1. xxx 7 | 8 | ''' 9 | # import os 10 | import cv2 11 | import torch 12 | import numpy as np 13 | # from tqdm import tqdm 14 | from dl_tools.basictools.dldata import vote_combine 15 | # import re 16 | 17 | 18 | def net_predict1(net, image, opt, crop_info=None): 19 | ''' Run prediction with the network (one image at a time).
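Args:
    net - segmentation network
    image - [NCHW] tensor for a whole image, or [NLCHW] (L crops per image) when crop_info is given
    opt - options object providing use_gpu, input_size and crop_params
    crop_info - crop layout passed to vote_combine() to stitch crop predictions back together; None means predict the whole image
Returns:
    [HW] array of per-pixel class indices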
''' 20 | if crop_info is None: 21 | # predict a complete image 22 | if opt.use_gpu: 23 | image = image.cuda() 24 | output = net(image)[0] # [NCHW] -> [CHW] 25 | predict = np.argmax(output.cpu().detach().numpy(), 0) # [CHW] -> [HW] 26 | else: 27 | # predict the list of cropped images 28 | predict = [] 29 | image.transpose_(0, 1) # [NLCHW] -> [LNCHW] 30 | for input in image: 31 | if opt.use_gpu: 32 | input = input.cuda() # [NCHW](N=1) 33 | output = net(input)[0] # [NCHW] -> [CHW] 34 | # output = output[0] 35 | # tmp = np.mean(input, axis=1, dtype=input.type) 36 | # indx = np.argwhere(tmp == 0) 37 | # output[indx] = 0 38 | # output = output[0] 39 | output = np.transpose(output.cpu().detach().numpy(), (1, 2, 0)) # [CHW]->[HWC] 40 | if opt.input_size != opt.crop_params[:2]: 41 | output = cv2.resize(output, tuple(opt.crop_params[:2][::-1]), 1) 42 | predict.append(output) 43 | predict = vote_combine(predict, opt.crop_params, crop_info, 2) 44 | predict = np.argmax(predict, -1) # [HWC] -> [HW] 45 | 46 | return predict # [HW] array 47 | 48 | 49 | def net_predict2(net, image, opt, crop_info=None): 50 | ''' Run prediction with the network (one image at a time). ''' 51 | if crop_info is None: 52 | # predict a complete image 53 | if opt.use_gpu: 54 | image = image.cuda() 55 | output = net(image)[0] # [NCHW] -> [CHW] 56 | predict = np.argmax(output.cpu().detach().numpy(), 0) # [CHW] -> [HW] 57 | else: 58 | # predict the list of cropped images 59 | predict = [] 60 | image.transpose_(0, 1) # [NLCHW] -> [LNCHW] 61 | for input in image: 62 | if opt.use_gpu: 63 | input = input.cuda() # [NCHW](N=1) 64 | output = net(input)[0] # [NCHW] -> [CHW] 65 | output = output[0] 66 | # tmp = np.mean(input, axis=1, dtype=input.type) 67 | # indx = np.argwhere(tmp == 0) 68 | # output[indx] = 0 69 | # output = output[0] 70 | output = np.transpose(output.cpu().detach().numpy(), (1, 2, 0)) # [CHW]->[HWC] 71 | if opt.input_size != opt.crop_params[:2]: 72 | output = cv2.resize(output, tuple(opt.crop_params[:2][::-1]), 1) 73 | predict.append(output) 74 | predict = vote_combine(predict, opt.crop_params, crop_info, 2) 75 | predict = np.argmax(predict, -1) # [HWC] -> [HW] 76 | 77 | return predict # [HW] array 78 | 79 | 80 | def net_predict_enhance(net, image, opt, crop_info=None): 81 | ''' Run prediction with the network plus test-time tricks (one image at a time).
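The trick is 4-fold rotation test-time augmentation: the input is rotated by 0/90/180/270 degrees, each rotated copy is predicted, and the outputs are rotated back and summed before the final argmax.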
''' 82 | predict_list = [] 83 | if crop_info is None: 84 | # predict a complete image 85 | for i in range(4): 86 | input = torch.from_numpy(np.rot90(image, i, axes=(3, 2)).copy()) 87 | if opt.use_gpu: 88 | input = input.cuda() 89 | output = net(input)[0] # [NCHW] -> [CHW] 90 | output = output.cpu().detach().numpy() # Tensor -> array 91 | output = np.transpose(output, (1, 2, 0)) # [CHW]->[HWC] 92 | predict_list.append(np.rot90(output, i, axes=(0, 1))) # counter-clockwise rotation 93 | 94 | else: 95 | # predict the list of cropped images 96 | image = image.permute(1, 0, 2, 3, 4) # [NLCHW] -> [LNCHW]; permute is not in-place, so assign the result 97 | for i in range(4): 98 | predict = [] 99 | for img in image: 100 | input = torch.from_numpy(np.rot90(img, i, axes=(3, 2)).copy()) 101 | if opt.use_gpu: 102 | input = input.cuda() # [NCHW](N=1) 103 | output = net(input)[0] # [NCHW] -> [CHW] 104 | output = output.cpu().detach().numpy() 105 | output = np.transpose(output, (1, 2, 0)) # [CHW]->[HWC] 106 | if opt.input_size != opt.crop_params[:2]: 107 | output = cv2.resize(output, tuple(opt.crop_params[:2][::-1]), 1) 108 | predict.append(np.rot90(output, i, axes=(0, 1))) 109 | predict_list.append(vote_combine(predict, opt.crop_params, crop_info, 2)) 110 | 111 | predict = predict_list[0] 112 | for i in range(1, 4): 113 | predict += predict_list[i] 114 | return np.argmax(predict, -1) # [HWC] -> [HW] array 115 | 116 | 117 | def net_predict_enhance2(net, image, opt, crop_info=None): 118 | ''' Run prediction with the network plus test-time tricks (one image at a time). ''' 119 | predict_list = [] 120 | if crop_info is None: 121 | # predict a complete image 122 | for i in range(4): 123 | input = torch.from_numpy(np.rot90(image, i, axes=(3, 2)).copy()) 124 | if opt.use_gpu: 125 | input = input.cuda() 126 | output = net(input)[0] # [NCHW] -> [CHW] 127 | output = output.cpu().detach().numpy() # Tensor -> array 128 | output = np.transpose(output, (1, 2, 0)) # [CHW]->[HWC] 129 | predict_list.append(np.rot90(output, i, axes=(0, 1))) # counter-clockwise rotation 130 | 131 | else: 132 | # predict the list of cropped images 133 | image.transpose_(0, 1) # [NLCHW] -> [LNCHW] 134 | for i in range(4): 135 | predict = [] 136 | for img in image: 137 | input = torch.from_numpy(np.rot90(img, i, axes=(3, 2)).copy()) 138 | if opt.use_gpu: 139 | input = input.cuda() # [NCHW](N=1) 140 | output = net(input)[0] # [NCHW] -> [CHW] 141 | output = output.cpu().detach().numpy() 142 | output = np.transpose(output, (1, 2, 0)) # [CHW]->[HWC] 143 | if opt.input_size != opt.crop_params[:2]: 144 | output = cv2.resize(output, tuple(opt.crop_params[:2][::-1]), 1) 145 | predict.append(np.rot90(output, i, axes=(0, 1))) 146 | predict_list.append(vote_combine(predict, opt.crop_params, crop_info, 2)) 147 | 148 | predict = predict_list[0] 149 | for i in range(1, 4): 150 | predict += predict_list[i] 151 | return np.argmax(predict, -1) # [HWC] -> [HW] array 152 | -------------------------------------------------------------------------------- /T-SS-GLCNet/utils/util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.distributed as dist 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | def __init__(self): 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n
22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | 26 | def dist_collect(x): 27 | """ collect all tensor from all GPUs 28 | args: 29 | x: shape (mini_batch, ...) 30 | returns: 31 | shape (mini_batch * num_gpu, ...) 32 | """ 33 | x = x.contiguous() 34 | out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype) 35 | for _ in range(dist.get_world_size())] 36 | dist.all_gather(out_list, x) 37 | return torch.cat(out_list, dim=0) 38 | 39 | def reduce_tensor(tensor): 40 | rt = tensor.clone() 41 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 42 | rt /= dist.get_world_size() 43 | return rt 44 | 45 | 46 | -------------------------------------------------------------------------------- /TOV_v1/.gitignore: -------------------------------------------------------------------------------- 1 | # Some folder 2 | debug/ 3 | .vscode/ 4 | lightning_logs/ 5 | Log/ 6 | Cls_data/ 7 | Seg_data/ 8 | Task_models/ 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | # *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | -------------------------------------------------------------------------------- /TOV_v1/README.md: -------------------------------------------------------------------------------- 1 | # Project of TOV 1.0 2 | This is the project of the TOV paper: 3 | ``` 4 | [1] C. Tao, J. Qi, G. Zhang, Q. Zhu, W. Lu and H. Li, "TOV: The Original Vision Model for Optical Remote Sensing Image Understanding via Self-Supervised Learning," in IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing, vol. 16, pp. 4916-4930, 2023, doi: 10.1109/JSTARS.2023.3271312. 5 | 6 | @ARTICLE{10110958, 7 | author={Tao, Chao and Qi, Ji and Zhang, Guo and Zhu, Qing and Lu, Weipeng and Li, Haifeng}, 8 | journal={IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing}, 9 | title={TOV: The Original Vision Model for Optical Remote Sensing Image Understanding via Self-Supervised Learning}, 10 | year={2023}, 11 | volume={16}, 12 | number={}, 13 | pages={4916-4930}, 14 | doi={10.1109/JSTARS.2023.3271312} 15 | } 16 | ``` 17 | ## Abstract 18 | Are we on the right way for remote sensing image understanding (RSIU) when we train models in a supervised, data-dependent and task-dependent way, instead of following human vision in a label-free and task-independent way? We argue that a more desirable RSIU model should be trained with intrinsic structure from data rather than extrinsic human labels to realize generalizability across a wide range of RSIU tasks. According to this hypothesis, we proposed **T**he **O**riginal **V**ision model (TOV) for the remote sensing field. Trained on massive unlabeled optical data along a human-like self-supervised learning (SSL) path that goes from general knowledge to specialized knowledge, the TOV model can be easily adapted to various RSIU tasks, including scene classification, object detection, and semantic segmentation, and outperforms the dominant ImageNet supervised pretraining method as well as two recently proposed SSL pretraining methods on the majority of 12 publicly available benchmarks. Moreover, we analyze the influence of two key factors on the performance of building the TOV model for RSIU: the choice of data sampling method and the selection of learning paths during self-supervised optimization.
We believe that a general model trained in a label-free and task-independent way may be the next paradigm for RSIU, and we hope the insights distilled from this study can help foster the development of an original vision model for RSIU. 19 | 20 | 21 | ## Introduction 22 | **TOV**: **T**he **O**riginal **V**ision for Optical Remote Sensing Image Understanding 23 | 24 | > We argue that a more desirable remote sensing image understanding (RSIU) model should be trained with intrinsic structure from data rather than extrinsic human labels to realize generalizability across a wide range of RSIU tasks. According to this hypothesis, we define the original vision model, which serves as a general-purpose visual perception model for a wide range of RSIU tasks, and proposed a framework to build the first original vision model (TOV 1.0) in the remote sensing field. 25 | 26 | To foster the development of an original vision model for RSIU, in this project, we will release our pre-trained TOV model and related materials: 27 | - [x] Pretrained TOV model ([GoogleDrive](https://drive.google.com/drive/folders/14c0TnHFi1N_DC_egcoNWHCKX9C2pmmUR?usp=sharing) | [BaiduDrive](https://pan.baidu.com/s/1NHnuTbj7fVvCuUJXU9N5vQ?pwd=TOV1)) 28 | - [x] The benchmark datasets and codes for evaluation. 29 | - [ ] TOV-RS: the large-scale remote sensing image dataset constructed by the proposed method. ([BaiduDrive](https://pan.baidu.com/s/1VGvoi8UlgbBrFkWmWsORvQ?pwd=xy29)) 30 | - [ ] ... 31 | 32 | ## Using the TOV model for various downstream RSIU tasks 33 | TOV pre-trained models are expected to be placed in the `TOV_models` folder, e.g., `TOV_models/0102300000_22014162253_pretrain/TOV_v1_model_pretrained_on_TOV-RS-balanced_ep800.pth.tar` 34 | 35 | ### Classification 36 | #### Example: fine-tune pre-trained TOV model on AID 37 | ```bash 38 | # `--train_scale 5` means: use 5 samples per category for finetuning 39 | python classification/main_cls.py \ 40 | --dataset aid \ 41 | --train_scale 5 42 | # Other parameters (e.g., `learning_rate`) can directly use the default values provided in `classification/main_cls.py` 43 | ``` 44 | The scene classification datasets are expected to be placed in the `classification/Cls_data` directory, e.g., `classification/Cls_data/AID` 45 | 46 | ### Object detection 47 | We recommend using the powerful object detection framework, [MMDetection](https://github.com/open-mmlab/mmdetection).
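To reuse the TOV weights there, the checkpoint can first be exported as a plain backbone `state_dict`. Below is a minimal sketch, assuming the `.pth.tar` file stores its weights under a `state_dict` key and that the backbone keys carry a `features.` prefix (as in this repo's `ClsNet` wrapper and the `--map_keys '' 'features.'` mapping used by the fine-tuning scripts); if your checkpoint already holds raw ResNet keys, skip the filtering step. The output path is just an example:

```python
import torch

# Hypothetical paths -- point these at your own files
ckpt = torch.load(
    'TOV_models/0102300000_22014162253_pretrain/TOV_v1_model_pretrained_on_TOV-RS-balanced_ep800.pth.tar',
    map_location='cpu')
state = ckpt.get('state_dict', ckpt)  # fall back to top-level keys if there is no 'state_dict' entry

# Keep backbone weights only and strip the assumed 'features.' prefix
backbone = {k[len('features.'):]: v
            for k, v in state.items() if k.startswith('features.')}
torch.save({'state_dict': backbone}, 'TOV_models/tov_backbone_for_mmdet.pth')
```

The exported file can then be referenced from an MMDetection config, e.g., via `init_cfg=dict(type='Pretrained', checkpoint=...)` on the detector's backbone.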
48 | 49 | ### Semantic segmentation 50 | #### Example: fine-tune pre-trained TOV model on DLRSD 51 | ```bash 52 | # `--train_scale 0.01` means: use 1% of the training samples for finetuning 53 | python segmentation/main_seg.py \ 54 | --dataset dlrsd \ 55 | --train_scale 0.01 \ 56 | --batch_size 16 --gpus 2 57 | ``` 58 | The semantic segmentation datasets are expected to be placed in the `segmentation/Seg_data` directory, e.g., `segmentation/Seg_data/ISPRS_Postdam` 59 | -------------------------------------------------------------------------------- /TOV_v1/Readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /TOV_v1/TOV_models/Place the TOV pre-training model in this floder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/G-RSIM/518efd3776d8e4937b0a4c34137a5bbda4b4f8a1/TOV_v1/TOV_models/Place the TOV pre-training model in this floder -------------------------------------------------------------------------------- /TOV_v1/classification/config/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import importlib 4 | import sys 5 | import warnings 6 | from pytorch_lightning.utilities import rank_zero_only 7 | 8 | warnings.filterwarnings("ignore") 9 | 10 | TYPE2BAND = {'RGB': 3, 'NIR': 1, 'SAR': 1, 'TEN': 10, 'ALL': 12, 'MS': 13} # 'ALL' for Sentinel data; 'MS' for EuroSAT 11 | MODE_NAME = {1: 'train', 2: 'val', 3: 'test', 4: 'finetune', # 5: 'exp', 12 | 5: '', 6: '', 7: '', 13 | 8: '', 9: 'pretrain', 0: 'debug'} 14 | 15 | 16 | class __redirection__: 17 | def __init__(self, mode='console', file_path=None): 18 | assert mode in ['console', 'file', 'both'] 19 | 20 | self.mode = mode 21 | self.buff = '' 22 | self.__console__ = sys.stdout 23 | 24 | self.file = None 25 | if file_path is not None and mode != 'console': 26 | try: 27 | self.file = open(file_path, "w", buffering=1) 28 | except OSError: 29 | print('Fail to open log_file: {}'.format( 30 | file_path)) 31 | 32 | @rank_zero_only 33 | def write(self, output_stream): 34 | self.buff += output_stream 35 | if self.mode == 'console': 36 | self.to_console(output_stream) 37 | elif self.mode == 'file': 38 | self.to_file(output_stream) 39 | elif self.mode == 'both': 40 | self.to_console(output_stream) 41 | self.to_file(output_stream) 42 | 43 | @rank_zero_only 44 | def to_console(self, content): 45 | sys.stdout = self.__console__ 46 | print(content, end='') 47 | sys.stdout = self 48 | 49 | @rank_zero_only 50 | def to_file(self, content): 51 | if self.file is not None: 52 | sys.stdout = self.file 53 | print(content, end='') 54 | sys.stdout = self 55 | 56 | @rank_zero_only 57 | def all_to_console(self, flush=False): 58 | sys.stdout = self.__console__ 59 | print(self.buff, end='') 60 | sys.stdout = self 61 | 62 | @rank_zero_only 63 | def all_to_file(self, file_path=None, flush=True): 64 | if file_path is not None: 65 | self.open(file_path) 66 | if self.file is not None: 67 | sys.stdout = self.file 68 | print(self.buff, end='') 69 | sys.stdout = self 70 | # self.file.close() 71 | 72 | @rank_zero_only 73 | def open(self, file_path): 74 | try: 75 | self.file = open(file_path, "w", buffering=1) 76 | except OSError: 77 | print('Fail to open log_file: {}'.format( 78 | file_path)) 79 | 80 | @rank_zero_only 81 | def close(self): 82 | if self.file is not None: 83 | self.file.close() 84 | self.file = None 85 | 86 | @rank_zero_only
87 | def flush(self): 88 | self.buff = '' 89 | 90 | @rank_zero_only 91 | def reset(self): 92 | sys.stdout = self.__console__ 93 | 94 | 95 | def get_opt(name, args=None, redirection=False): 96 | '''Get options by name and current platform, and may use args to update them.''' 97 | 98 | get_config = importlib.import_module('config.default').get_config 99 | opts = get_config(name) 100 | if args is None: 101 | return opts # simple mode 102 | 103 | opts = preprocess_settings(opts, args) 104 | 105 | # Normalize the form of some parameters 106 | opts.dtype = unify_type(opts.dtype, list) 107 | opts.input_size = unify_type(opts.input_size, tuple, 2) 108 | for dt in opts.dtype: 109 | opts.in_channel += TYPE2BAND[dt] 110 | opts.mean = opts.mean if len(opts.mean) == opts.in_channel else opts.mean * opts.in_channel 111 | opts.std = opts.std if len(opts.std) == opts.in_channel else opts.std * opts.in_channel 112 | 113 | # Generate logging flags / info. 114 | opts.timestamp = time.strftime('%y%j%H%M%S', time.localtime(time.time())) 115 | mode_digits = len(str(opts.mode)) 116 | opts.mode_name = MODE_NAME[opts.mode // (10**(mode_digits-1))] 117 | if not hasattr(opts, 'net_suffix'): 118 | opts.net_suffix = opts.timestamp + '_' + opts.mode_name 119 | if opts.mode_name == 'finetune': 120 | assert opts.load_pretrained, '`load_pretrained` must be provided for finetune' 121 | expnum = 'temp' 122 | for s in opts.load_pretrained.split('_'): 123 | if len(s) == 11 and s.isdigit(): 124 | expnum = s 125 | break 126 | elif s in ['other']: 127 | expnum = s 128 | break 129 | opts.net_suffix = expnum + '_' + opts.net_suffix 130 | 131 | opts.exp_name = '{}_{}_{}'.format(opts.env, opts.model_name, opts.net_suffix) 132 | 133 | # Re-direction the print 134 | if redirection: 135 | print_mode = 'both' 136 | stream = __redirection__(print_mode) 137 | sys.stdout = stream 138 | 139 | print('Current Main Mode: {} - {}\n'.format(opts.mode, opts.exp_name)) 140 | 141 | # Basic prepare 142 | if opts.ckpt: 143 | if not os.path.exists(opts.ckpt): 144 | os.mkdir(opts.ckpt) 145 | 146 | if redirection: 147 | return opts, stream 148 | return opts 149 | 150 | 151 | def preprocess_settings(opts, args=None): 152 | 153 | # Combine user args with predefined opts 154 | if args is not None: 155 | for attr in opts.__dir__(): 156 | if attr.startswith('__') or attr.endswith('__'): 157 | continue 158 | value = getattr(opts, attr) 159 | if attr not in args or args.__dict__[attr] is None: 160 | setattr(args, attr, value) 161 | opts = args 162 | 163 | # Add some default settings 164 | if opts is not None: 165 | if not hasattr(opts, 'pj') or not opts.pj: 166 | opts.pj = opts.env 167 | 168 | if opts.max_epochs is None: 169 | opts.max_epochs = 3 170 | 171 | # Other 172 | if not hasattr(opts, 'full_train') or opts.full_train is None: 173 | opts.full_train = False 174 | 175 | return opts 176 | 177 | 178 | def unify_type(param, ptype=list, repeat=1): 179 | ''' Unify the type of param. 180 | 181 | Args: 182 | ptype: support list or tuple 183 | repeat: The times of repeating param in a list or tuple type. 
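    Examples (illustrative):
        unify_type(224, tuple, 2)  ->  (224, 224)
        unify_type('RGB', list)    ->  ['RGB']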
184 | ''' 185 | if repeat == 1: 186 | if type(param) is not ptype: 187 | if ptype == list: 188 | param = [param] 189 | elif ptype == tuple: 190 | param = (param,)  # one-element tuple; plain (param) would not be a tuple 191 | elif repeat > 1: 192 | if type(param) is ptype and len(param) == repeat: 193 | return param 194 | elif type(param) is list: 195 | param = param * repeat 196 | else: 197 | param = [param] * repeat 198 | param = ptype(param) 199 | 200 | return param 201 | 202 | 203 | def args_list2dict(params): 204 | if type(params) is dict: 205 | return params 206 | elif type(params) is list: 207 | assert len(params) % 2 == 0, 'Must be paired args' 208 | options = {} 209 | for i in range(0, len(params), 2): 210 | options[params[i]] = params[i+1] 211 | 212 | return options 213 | -------------------------------------------------------------------------------- /TOV_v1/classification/config/category.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | 4 | class NWPU_RESISC45(object): 5 | ''' Category details of NWPU_RESISC45 dataset. ''' 6 | plan_name = 'NWPU_RESISC45' 7 | table = { 8 | "airplane": 0, 9 | "airport": 1, 10 | "baseball_diamond": 2, 11 | "basketball_court": 3, 12 | "beach": 4, 13 | "bridge": 5, 14 | "chaparral": 6, 15 | "church": 7, 16 | "circular_farmland": 8, 17 | "cloud": 9, 18 | "commercial_area": 10, 19 | "dense_residential": 11, 20 | "desert": 12, 21 | "forest": 13, 22 | "freeway": 14, 23 | "golf_course": 15, 24 | "ground_track_field": 16, 25 | "harbor": 17, 26 | "industrial_area": 18, 27 | "intersection": 19, 28 | "island": 20, 29 | "lake": 21, 30 | "meadow": 22, 31 | "medium_residential": 23, 32 | "mobile_home_park": 24, 33 | "mountain": 25, 34 | "overpass": 26, 35 | "palace": 27, 36 | "parking_lot": 28, 37 | "railway": 29, 38 | "railway_station": 30, 39 | "rectangular_farmland": 31, 40 | "river": 32, 41 | "roundabout": 33, 42 | "runway": 34, 43 | "sea_ice": 35, 44 | "ship": 36, 45 | "snowberg": 37, 46 | "sparse_residential": 38, 47 | "stadium": 39, 48 | "storage_tank": 40, 49 | "tennis_court": 41, 50 | "terrace": 42, 51 | "thermal_power_station": 43, 52 | "wetland": 44, 53 | } 54 | 55 | names = [name for (name, _) in table.items()] 56 | inds = table.values() 57 | 58 | num = len(table) 59 | 60 | color_table = None 61 | bgr_table = None # BGR 62 | mapping = None 63 | 64 | 65 | class AID(object): 66 | ''' Category details of AID dataset. ''' 67 | plan_name = 'AID' 68 | table = { 69 | "Airport": 0, 70 | "BareLand": 1, 71 | "BaseballField": 2, 72 | "Beach": 3, 73 | "Bridge": 4, 74 | "Center": 5, 75 | "Church": 6, 76 | "Commercial": 7, 77 | "DenseResidential": 8, 78 | "Desert": 9, 79 | "Farmland": 10, 80 | "Forest": 11, 81 | "Industrial": 12, 82 | "Meadow": 13, 83 | "MediumResidential": 14, 84 | "Mountain": 15, 85 | "Park": 16, 86 | "Parking": 17, 87 | "Playground": 18, 88 | "Pond": 19, 89 | "Port": 20, 90 | "RailwayStation": 21, 91 | "Resort": 22, 92 | "River": 23, 93 | "School": 24, 94 | "SparseResidential": 25, 95 | "Square": 26, 96 | "Stadium": 27, 97 | "StorageTanks": 28, 98 | "Viaduct": 29, 99 | } 100 | 101 | names = [name for (name, _) in table.items()] 102 | inds = table.values() 103 | 104 | num = len(table) 105 | 106 | color_table = None 107 | bgr_table = None # BGR 108 | mapping = None 109 | 110 | 111 | class Unlabeled(object): 112 | ''' Unlabeled self-supervised dataset.
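All images share a single placeholder category ("Unlabeled" -> 0), since self-supervised pretraining uses no real labels.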
''' 113 | plan_name = 'SSD' 114 | table = { 115 | "Unlabeled": 0, 116 | } 117 | 118 | names = [name for (name, _) in table.items()] 119 | inds = table.values() 120 | 121 | num = len(table) 122 | 123 | color_table = None 124 | bgr_table = None # BGR 125 | mapping = None 126 | 127 | 128 | class TianGong2(object): 129 | plan_name = 'TianGong2' 130 | 131 | table = { 132 | "beach": 0, 133 | "circularfarmland": 1, 134 | "cloud": 2, 135 | "desert": 3, 136 | "forest": 4, 137 | "mountain": 5, 138 | "rectangularfarmland": 6, 139 | "residential": 7, 140 | "river": 8, 141 | "snowberg": 9, 142 | } 143 | 144 | names = [name for (name, _) in table.items()] 145 | inds = table.values() 146 | num = len(table) 147 | color_table = None 148 | bgr_table = None # BGR 149 | mapping = None 150 | 151 | 152 | class EuroSAT(object): 153 | plan_name = 'EuroSAT' 154 | 155 | table = { 156 | "Highway": 0, 157 | "Industrial": 1, 158 | "Pasture": 2, 159 | "PermanentCrop": 3, 160 | "Residential": 4, 161 | "River": 5, 162 | "SeaLake": 6, 163 | "AnnualCrop": 7, 164 | "Forest": 8, 165 | "HerbaceousVegetation": 9, 166 | } 167 | 168 | names = [name for (name, _) in table.items()] 169 | inds = table.values() 170 | num = len(table) 171 | color_table = None 172 | bgr_table = None # BGR 173 | mapping = None 174 | 175 | 176 | class WHU_RSD46(object): 177 | ''' Category details of WHU_RSD46 dataset. ''' 178 | plan_name = 'WHU_RSD46' 179 | table = { 180 | 'Airplane': 0, 181 | 'Airport': 1, 182 | 'Artificial dense forest land': 2, 183 | 'Artificial sparse forest land': 3, 184 | 'Bare land': 4, 185 | 'Basketball court': 5, 186 | 'Blue structured factory building': 6, 187 | 'Building': 7, 188 | 'Construction site': 8, 189 | 'Cross river bridge': 9, 190 | 'Crossroads': 10, 191 | 'Dense tall building': 11, 192 | 'Dock': 12, 193 | 'Fish pond': 13, 194 | 'Footbridge': 14, 195 | 'Graff': 15, # trench, ditch, channel 196 | 'Grassland': 16, 197 | 'Low scattered building': 17, 198 | 'Lrregular farmland': 18, 199 | 'Medium density scattered building': 19, 200 | 'Medium density structured building': 20, 201 | 'Natural dense forest land': 21, 202 | 'Natural sparse forest land': 22, 203 | 'Oil tank': 23, 204 | 'Overpass': 24, 205 | 'Parking lot': 25, 206 | 'Plastic greenhouse': 26, 207 | 'Playground': 27, 208 | 'Railway': 28, 209 | 'Red structured factory building': 29, 210 | 'Refinery': 30, 211 | 'Regular farmland': 31, 212 | 'Scattered blue roof factory building': 32, 213 | 'Scattered red roof factory building': 33, 214 | 'Sewage plant-type-one': 34, 215 | 'Sewage plant-type-two': 35, 216 | 'Ship': 36, 217 | 'Solar power station': 37, 218 | 'Sparse residential area': 38, 219 | 'Square': 39, 220 | 'Steal smelter': 40, 221 | 'Storage land': 41, 222 | 'Tennis court': 42, 223 | 'Thermal power plant': 43, 224 | 'Vegetable plot': 44, 225 | 'Water': 45 226 | } 227 | 228 | names = [name for (name, _) in table.items()] 229 | inds = table.values() 230 | 231 | num = len(table) 232 | 233 | color_table = None 234 | bgr_table = None # BGR 235 | mapping = None 236 | 237 | -------------------------------------------------------------------------------- /TOV_v1/classification/config/default/__init__.py: -------------------------------------------------------------------------------- 1 | def get_config(name, args=None): 2 | if name == 'nr': 3 | from .opt_NR import Config 4 | elif name == 'aid': 5 | from .opt_AID import Config 6 | elif name == 'rsd46': 7 | from .opt_RSD46 import Config 8 | elif name == 'ucm': 9 | from .opt_UCMerced import Config 10 | elif name == 'pnt': 11 | from
.opt_PatternNet import Config 12 | elif name == 'tg2rgb': 13 | from .opt_TianGong2_RGB import Config 14 | elif name == 'eurorgb': 15 | from .opt_EuroSAT_RGB import Config 16 | else: raise ValueError("Unknown config name: '{}'".format(name)) 17 | if args is not None: 18 | mOptions = Config(args) 19 | else: 20 | mOptions = Config() 21 | 22 | return mOptions 23 | 24 | # BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 25 | # sys.path.append(BASE_DIR+'/Tools/dltoos') 26 | -------------------------------------------------------------------------------- /TOV_v1/classification/config/default/opt_AID.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Default Config about AID dataset, classification mode 5 | ''' 6 | from config.category import AID 7 | 8 | 9 | class Config(object): 10 | ''' Fixed initialization Settings ''' 11 | # Path and file 12 | root = '' 13 | data_dir = root + 'classification/Cls_data/AID' 14 | ckpt = root + 'Task_models/AID' 15 | env = 'AID' # visdom environment 16 | 17 | # train_val_ratio = [0.09, 0.91] 18 | train_val_ratio = [0.8, 0.2] 19 | train_scale = 1 # Scale of training set reduction 20 | val_scale = 0.5 21 | # Model related arguments 22 | sepoch = 1 # use to continue from a checkpoint 23 | pretrained = False # use pre-train checkpoint 24 | 25 | # Optimization related arguments 26 | batch_size = 16 # batch size 27 | max_epochs = None # 16-[256,256] dataset only need 8~9 epoch 28 | ckpt_freq = 0 29 | learning_rate = 4e-4 # initial learning rate 30 | lr_scheduler = 'step' # per epoch 31 | lr_decay_rate = 0.5 32 | lr_decay_steps = 20 33 | warmup = 0 # if warmup > 0, use warmup strategy and end at warmup 34 | weight_decay = 1e-5 # L2 loss 35 | optimizer = ['adam', 'sgd', 'lars'][1] 36 | loss = ['CrossEntropyLoss', 'NTXentloss'][0] 37 | loss_weight = None 38 | 39 | # Data related arguments 40 | num_workers = 4 # number of data loading workers 41 | dtype = ['RGB'][0] 42 | bl_dtype = [''][0] 43 | in_channel = 0 44 | 45 | input_size = (224, 224) # final input size of network(random-crop use this) 46 | # crop_params = [256, 256, 256] # crop_params for val and pre 47 | 48 | # feature_dim = {1: 128, 2: 256, 3: 384, 4: 512} 49 | mean = [0.5, 0.5, 0.5] # BGR; these means should be in [0, 1] 50 | std = [0.5, 0.5, 0.5] # [0.12283102, 0.1269429, 0.15580289] 51 | 52 | category = AID() 53 | classes = category.names 54 | class_dict = category.table 55 | num_classes = len(classes) 56 | -------------------------------------------------------------------------------- /TOV_v1/classification/config/default/opt_EuroSAT_RGB.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Default Config about EuroSAT dataset, classification mode 5 | URL: https://github.com/phelber/eurosat 6 | ''' 7 | from config.category import EuroSAT 8 | 9 | 10 | class Config(object): 11 | ''' Fixed initialization Settings ''' 12 | # Path and file 13 | root = '' 14 | data_dir = root + 'classification/Cls_data/EuroSATRGB' 15 | ckpt = root + 'Task_models/EuroSAT' 16 | env = 'Euro' # visdom environment 17 | 18 | train_val_ratio = [0.8, 0.2] # train & val set is fixed 19 | train_scale = 1 # Scale of training set reduction 20 | val_scale = 0.2 21 | # Model related arguments 22 | sepoch = 1 # use to continue from a checkpoint 23 | pretrained = False # use pre-train checkpoint 24 | 25 | # Optimization related arguments 26 | batch_size = 64 # batch size 27 | max_epochs = None # 16-[256,256] dataset only need 8~9 epoch 28 |
ckpt_freq = 0 29 | learning_rate = 1e-3 # initial learning rate 30 | lr_decay_rate = 0.5 31 | lr_decay_steps = 20 32 | lr_scheduler = 'step' # per epoch 33 | warmup = 0 # if warmup > 0, use warmup strategy and end at warmup 34 | weight_decay = 1e-5 # L2 loss 35 | optimizer = ['adam', 'sgd', 'lars'][0] 36 | loss = ['CrossEntropyLoss', ][0] 37 | loss_weight = None 38 | 39 | # Data related arguments 40 | num_workers = 4 # number of data loading workers 41 | dtype = ['RGB'] 42 | bl_dtype = [''][0] 43 | in_channel = 0 44 | 45 | input_size = (64, 64) # final input size of network(random-crop use this) 46 | # crop_params = [256, 256, 256] # crop_params for val and pre 47 | 48 | # feature_dim = {1: 128, 2: 256, 3: 384, 4: 512} 49 | mean = [0.5] 50 | std = [0.5] 51 | 52 | category = EuroSAT() 53 | classes = category.names 54 | class_dict = category.table 55 | num_classes = len(classes) 56 | -------------------------------------------------------------------------------- /TOV_v1/classification/config/default/opt_NR.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Default Config about NWPU-RESISC45 dataset, classification mode 5 | ''' 6 | from config.category import NWPU_RESISC45 7 | 8 | 9 | class Config(object): 10 | ''' Fixed initialization Settings ''' 11 | # Path and file 12 | root = '' 13 | data_dir = root + 'classification/Cls_data/NWPU_RESISC45' 14 | ckpt = root + 'Task_models/NWPU_RESISC45' 15 | env = 'NR' # visdom environment 16 | 17 | train_val_ratio = [0.8, 0.2] 18 | train_scale = 1 # Scale of training set reduction 19 | val_scale = 0.2 20 | 21 | # Model related arguments 22 | sepoch = 1 # use to continue from a checkpoint 23 | pretrained = False # use pre-train checkpoint 24 | 25 | # Optimization related arguments 26 | batch_size = 16 # batch size 27 | max_epochs = None # 16-[256,256] dataset only need 8~9 epoch 28 | ckpt_freq = 0 29 | learning_rate = 4e-4 # initial learning rate 30 | lr_decay_rate = 0.5 31 | lr_decay_steps = 20 32 | lr_scheduler = 'step' # per epoch 33 | warmup = 0 # if warmup > 0, use warmup strategy and end at warmup 34 | weight_decay = 1e-5 # L2 loss 35 | optimizer = ['adam', 'sgd', 'lars'][0] 36 | loss = ['CrossEntropyLoss', 'NTXentloss'][0] 37 | loss_weight = None 38 | 39 | # Data related arguments 40 | num_workers = 4 # number of data loading workers 41 | dtype = ['RGB'][0] 42 | bl_dtype = [''][0] 43 | in_channel = 0 44 | 45 | input_size = (224, 224) # final input size of network(random-crop use this) 46 | # crop_params = [256, 256, 256] # crop_params for val and pre 47 | 48 | # feature_dim = {1: 128, 2: 256, 3: 384, 4: 512} 49 | mean = [0.5, 0.5, 0.5] # BGR; these means should be in [0, 1] 50 | std = [0.5, 0.5, 0.5] 51 | 52 | category = NWPU_RESISC45() 53 | classes = category.names 54 | class_dict = category.table 55 | num_classes = len(classes) 56 | -------------------------------------------------------------------------------- /TOV_v1/classification/config/default/opt_PatternNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Default Config about PatternNet dataset, classification mode 5 | ''' 6 | 7 | 8 | class Config(object): 9 | ''' Fixed initialization Settings ''' 10 | # Path and file 11 | root = '' 12 | data_dir = root + 'classification/Cls_data/PatternNet' 13 | ckpt = root + 'Task_models/PatternNet' 14 | env = 'PatternNet' # visdom environment 15 | 16 | train_val_ratio = [0.8, 0.2] 17 |
train_scale = 1 # Scale of training set reduction 18 | val_scale = 0.2 19 | 20 | # Model related arguments 21 | sepoch = 1 # use to continue from a checkpoint 22 | pretrained = False # use pre-train checkpoint 23 | 24 | # Optimization related arguments 25 | batch_size = 16 # batch size 26 | max_epochs = None # 16-[256,256] dataset only need 8~9 epoch 27 | ckpt_freq = 0 28 | learning_rate = 4e-4 # initial learning rate 29 | lr_decay_rate = 0.5 30 | lr_decay_steps = 20 31 | lr_scheduler = 'step' # per epoch 32 | warmup = 0 # if warmup > 0, use warmup strategy and end at warmup 33 | weight_decay = 1e-5 # L2 loss 34 | optimizer = ['adam', 'sgd', 'lars'][0] 35 | loss = ['CrossEntropyLoss', 'NTXentloss'][0] 36 | loss_weight = None 37 | 38 | # Data related arguments 39 | num_workers = 4 # number of data loading workers 40 | dtype = ['RGB'][0] 41 | bl_dtype = [''][0] 42 | in_channel = 0 43 | 44 | input_size = (224, 224) # final input size of network(random-crop use this) 45 | # crop_params = [256, 256, 256] # crop_params for val and pre 46 | 47 | # feature_dim = {1: 128, 2: 256, 3: 384, 4: 512} 48 | mean = [0.5, 0.5, 0.5] # BGR; these means should be in [0, 1] 49 | std = [0.5, 0.5, 0.5] 50 | 51 | category = None 52 | classes = None 53 | class_dict = None 54 | num_classes = 38 55 | -------------------------------------------------------------------------------- /TOV_v1/classification/config/default/opt_RSD46.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Default Config about WHU_RSD46 dataset, classification mode 5 | ''' 6 | from config.category import WHU_RSD46 7 | 8 | 9 | class Config(object): 10 | ''' Fixed initialization Settings ''' 11 | # Path and file 12 | root = '' 13 | data_dir = root + 'classification/Cls_data/WHU_RSD46' 14 | # data_dir = '/dev/shm/WHU_RSD46' 15 | ckpt = root + 'Task_models/WHU_RSD46' 16 | env = 'RSD46' # visdom environment 17 | 18 | train_val_ratio = [0, 0] # train & val set is fixed 19 | train_scale = 1 # Scale of training set reduction 20 | val_scale = 0.1 21 | # Model related arguments 22 | sepoch = 1 # use to continue from a checkpoint 23 | pretrained = False # use pre-train checkpoint 24 | 25 | # Optimization related arguments 26 | batch_size = 16 # batch size 27 | max_epochs = None # 16-[256,256] dataset only need 8~9 epoch 28 | ckpt_freq = 10 29 | learning_rate = 4e-4 # initial learning rate 30 | lr_scheduler = 'step' # per epoch 31 | lr_decay_rate = 0.5 32 | lr_decay_steps = 20 33 | warmup = 0 # if warmup > 0, use warmup strategy and end at warmup 34 | weight_decay = 1e-5 # L2 loss 35 | optimizer = ['adam', 'sgd', 'lars'][0] 36 | loss = ['CrossEntropyLoss', 'NTXentloss'][0] 37 | loss_weight = None 38 | 39 | # Data related arguments 40 | num_workers = 4 # number of data loading workers 41 | dtype = ['RGB'][0] 42 | bl_dtype = [''][0] 43 | in_channel = 0 44 | 45 | input_size = (224, 224) # final input size of network(random-crop use this) 46 | # crop_params = [256, 256, 256] # crop_params for val and pre 47 | 48 | # feature_dim = {1: 128, 2: 256, 3: 384, 4: 512} 49 | mean = [0.5, 0.5, 0.5] # BGR; these means should be in [0, 1] 50 | std = [0.5, 0.5, 0.5] # [0.12283102, 0.1269429, 0.15580289] 51 | 52 | category = WHU_RSD46() 53 | classes = category.names 54 | class_dict = category.table 55 | num_classes = len(classes) 56 | -------------------------------------------------------------------------------- /TOV_v1/classification/config/default/opt_TianGong2_RGB.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Default Config about TianGong2 dataset, classification mode 5 | URL: 6 | ''' 7 | from config.category import TianGong2 8 | 9 | 10 | class Config(object): 11 | ''' Fixed initialization Settings ''' 12 | # Path and file 13 | root = '' 14 | data_dir = root + 'classification/Cls_data/TianGong2/RGB' 15 | ckpt = root + 'Task_models/TianGong2' 16 | env = 'TianGong2' # visdom environment 17 | 18 | train_val_ratio = [0.8, 0.2] # train & val set is fixed 19 | train_scale = 1 # Scale of training set reduction 20 | val_scale = 0.5 21 | # Model related arguments 22 | sepoch = 1 # use to continue from a checkpoint 23 | pretrained = False # use pre-train checkpoint 24 | 25 | # Optimization related arguments 26 | batch_size = 64 # batch size 27 | max_epochs = None # 16-[256,256] dataset only need 8~9 epoch 28 | ckpt_freq = 0 29 | learning_rate = 1e-3 # initial learning rate 30 | lr_decay_rate = 0.5 31 | lr_decay_steps = 20 32 | lr_scheduler = 'step' # per epoch 33 | warmup = 0 # if warmup > 0, use warmup strategy and end at warmup 34 | weight_decay = 1e-5 # L2 loss 35 | optimizer = ['adam', 'sgd', 'lars'][0] 36 | loss = ['CrossEntropyLoss', ][0] 37 | loss_weight = None 38 | 39 | # Data related arguments 40 | num_workers = 4 # number of data loading workers 41 | dtype = ['RGB'] 42 | bl_dtype = [''][0] 43 | in_channel = 0 44 | 45 | input_size = (128, 128) # final input size of network(random-crop use this) 46 | # crop_params = [256, 256, 256] # crop_params for val and pre 47 | 48 | # feature_dim = {1: 128, 2: 256, 3: 384, 4: 512} 49 | mean = [0.5] 50 | std = [0.5] 51 | 52 | category = TianGong2() 53 | classes = category.names 54 | class_dict = category.table 55 | num_classes = len(classes) 56 | -------------------------------------------------------------------------------- /TOV_v1/classification/config/default/opt_UCMerced.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Default Config about UCMerced dataset, classification mode 5 | ''' 6 | 7 | 8 | class Config(object): 9 | ''' Fixed initialization Settings ''' 10 | # Path and file 11 | root = '' 12 | data_dir = root + 'classification/Cls_data/UCMerced_LandUse' 13 | ckpt = root + 'Task_models/UCMerced' 14 | env = 'UCMerced' # visdom environment 15 | 16 | train_val_ratio = [0.5, 0.5] 17 | train_scale = 1 # Scale of training set reduction 18 | val_scale = 0.5 19 | 20 | # Model related arguments 21 | sepoch = 1 # use to continue from a checkpoint 22 | pretrained = False # use pre-train checkpoint 23 | 24 | # Optimization related arguments 25 | batch_size = 16 # batch size 26 | max_epochs = None # 16-[256,256] dataset only need 8~9 epoch 27 | ckpt_freq = 0 28 | learning_rate = 4e-4 # initial learning rate 29 | lr_decay_rate = 0.5 30 | lr_decay_steps = 20 31 | lr_scheduler = 'step' # per epoch 32 | warmup = 0 # if warmup > 0, use warmup strategy and end at warmup 33 | weight_decay = 1e-5 # L2 loss 34 | optimizer = ['adam', 'sgd', 'lars'][0] 35 | loss = ['CrossEntropyLoss', 'NTXentloss'][0] 36 | loss_weight = None 37 | 38 | # Data related arguments 39 | num_workers = 4 # number of data loading workers 40 | dtype = ['RGB'][0] 41 | bl_dtype = [''][0] 42 | in_channel = 0 43 | 44 | input_size = (224, 224) # final input size of network(random-crop use this) 45 | # crop_params = [256, 256, 256] # crop_params for val and pre 46 | 47 | #
feature_dim = {1: 128, 2: 256, 3: 384, 4: 512} 48 | mean = [0.5, 0.5, 0.5] # BGR; these means should be in [0, 1] 49 | std = [0.5, 0.5, 0.5] 50 | 51 | category = None 52 | classes = None 53 | class_dict = None 54 | num_classes = 21 55 | -------------------------------------------------------------------------------- /TOV_v1/classification/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_interface import DInterface, datasets_init 2 | from .classfy_data import ClassfyData 3 | 4 | __all__ = [ 5 | 'DInterface', 'datasets_init', 'ClassfyData', 6 | ] 7 | -------------------------------------------------------------------------------- /TOV_v1/classification/dataset/classfy_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Dataset for Scene classification. 4 | ''' 5 | import time 6 | import numpy as np 7 | import cv2 8 | from PIL import Image 9 | from torch.utils import data 10 | # from .transform import get_cls_transform as get_transform 11 | 12 | 13 | def imread_with_cv(pth): 14 | ''' Only load RGB images data. ''' 15 | img = cv2.imread(pth, 1) 16 | img = img[:, :, ::-1] # BGR→RGB 17 | return img.copy() 18 | 19 | 20 | def imread_with_pil(pth): 21 | ''' Only load RGB images data. ''' 22 | with open(pth, 'rb') as f: 23 | img = Image.open(f) 24 | return img.convert('RGB') 25 | 26 | 27 | def shuffle_samples(samples, seed=None): 28 | if seed is None: 29 | seed = np.random.randint(100000) 30 | np.random.seed(seed) 31 | np.random.shuffle(samples) 32 | return samples 33 | 34 | 35 | class ClassfyData(data.Dataset): 36 | ''' 37 | Dataset loader for scene classification. 38 | Args: 39 | init_data - Dict of init data = {'train': [(img_path, lbl), ..], 'val': [...]} 40 | split - One of ['train', 'val', 'test'] 41 | ratio - (float) extra parameter to control the size of the dataset 42 | if ratio < 0.01, it stands for n-shot mode: 43 | ratio = 0.00n, where n is the number of samples per category 44 | transforms - a custom transform object may be passed in 45 | ''' 46 | def __init__(self, 47 | init_data, 48 | split='train', 49 | dtype=['RGB'], 50 | ratio=1, 51 | transforms=None, 52 | class_wise_sampling=True, 53 | data_seed=0, 54 | debug=False): 55 | # Set all input args as attributes 56 | self.__dict__.update(locals()) 57 | 58 | if len(dtype) == 1 and dtype[0] == 'RGB': 59 | self.dtype = 'RGB' 60 | self.imread_func = imread_with_cv 61 | else: 62 | self.dtype = 'other' 63 | # self.imread_func = imread_with_gdal 64 | 65 | # Collect all dataset files, divide into dict.
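        # Sampling rules: ratio <= 1 keeps round(n_samples * ratio) items (at least 1);
        # ratio > 1 keeps at most int(ratio) items. With class_wise_sampling the rule
        # is applied per category, otherwise to the pooled list of all samples.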
66 | tic = time.time() 67 | self.imgs, self.lbls = [], [] 68 | total_num = 0 69 | self.sample_statis = '' 70 | 71 | self.num_classes = len(init_data[split].keys()) 72 | if type(init_data) is dict: 73 | if class_wise_sampling: 74 | for (cls, samples) in init_data[split].items(): 75 | total_num += len(samples) 76 | if ratio <= 1: 77 | use_num = max(round(len(samples)*ratio), 1) 78 | else: 79 | use_num = min(int(ratio), len(samples)) 80 | self.sample_statis += '| %s | %d |_' % (cls, use_num) 81 | if use_num == 0: 82 | continue 83 | if data_seed: 84 | samples = shuffle_samples(samples, data_seed) 85 | for (pth, lbl) in samples[:use_num]: 86 | self.imgs.append(pth) 87 | self.lbls.append(lbl) 88 | else: 89 | categories = {} 90 | for (cls, samples) in init_data[split].items(): 91 | for (pth, lbl) in samples: 92 | self.imgs.append(pth) 93 | self.lbls.append(lbl) 94 | total_num += len(samples) 95 | categories[cls] = lbl 96 | if ratio <= 1: 97 | use_num = max(round(total_num*ratio), 1) 98 | else: 99 | use_num = min(int(ratio), total_num) 100 | self.shuffle_data(data_seed) # shuffle data with seed 101 | self.imgs = self.imgs[:use_num] 102 | # print(self.imgs) 103 | self.lbls = self.lbls[:use_num] 104 | for cls, lbl in categories.items(): 105 | cls_num = (np.array(self.lbls) == lbl).sum() 106 | self.sample_statis += '| %s | %d |_' % (cls, cls_num) 107 | self.use_num = len(self.imgs) 108 | 109 | print('{} set contains {} images, in {} categories.'.format( 110 | split, total_num, self.num_classes)) 111 | if split == 'train': 112 | print('*'*6, f'data seed = {data_seed}', '*'*6) 113 | print(self.sample_statis) 114 | print('Actual number of samples used = {}'.format(self.use_num)) 115 | 116 | def load_data_from_disk(self, index): 117 | img = self.imread_func(self.imgs[index]) 118 | lbl = np.array(self.lbls[index], dtype=np.int64) 119 | return img, lbl 120 | 121 | def shuffle_data(self, seed=None): 122 | if seed is None: 123 | seed = np.random.randint(100000) 124 | 125 | np.random.seed(seed) 126 | np.random.shuffle(self.imgs) 127 | np.random.seed(seed) 128 | np.random.shuffle(self.lbls) 129 | 130 | def __len__(self): 131 | return self.use_num 132 | 133 | def __getitem__(self, index): 134 | ''' Return one image(and label) per time. ''' 135 | img, lbl = self.load_data_from_disk(index) 136 | 137 | img = Image.fromarray(img) 138 | # sample = {'image': img, 'label': lbl, 'name': self.imgs[index]} 139 | if self.transforms is not None: 140 | img = self.transforms(img) 141 | 142 | return img, lbl # , self.imgs[index] 143 | -------------------------------------------------------------------------------- /TOV_v1/classification/models/cls_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | ''' 3 | Package the torchvision model. 4 | 5 | Version 1.0 2020 by QiJi 6 | ''' 7 | 8 | import torch 9 | import collections 10 | from torch import nn 11 | 12 | 13 | class ClsNet(nn.Module): 14 | """ Image-level classification network contains `features` and `classifier`. 
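    Forward pass: backbone `features` -> optional attention module -> global
    average pooling -> flatten -> linear `classifier` producing logits.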
15 | """ 16 | def __init__(self, backbone, num_classes=2, out_dim=None, 17 | enhance_features=[], attention_module=None, 18 | **kwargs): 19 | super().__init__() 20 | if hasattr(backbone, 'model_name'): 21 | bb_name = backbone.model_name 22 | else: 23 | bb_name = backbone._get_name().lower() 24 | self.model_name = 'ClsNet_' + bb_name 25 | 26 | if 'alexnet' in bb_name: 27 | self.features = backbone.features 28 | # self.out_dim = 256 29 | elif 'resnet' in bb_name or 'resnext' in bb_name: 30 | bb = nn.Sequential(collections.OrderedDict(list(backbone.named_children()))) 31 | self.features = bb[:8] 32 | # self.out_dim = 2048 33 | # if ('resnet18' in bb_name) or ('resnet34' in bb_name): 34 | # self.out_dim = 512 35 | elif 'vgg' in bb_name: 36 | self.features = backbone.features 37 | # self.out_dim = 512 38 | elif 'googlenet' in bb_name: 39 | backbone = nn.Sequential(collections.OrderedDict(dict(backbone.named_children()))) 40 | self.features = backbone[:16] 41 | # self.out_dim = 1024 42 | elif 'inception' in bb_name: 43 | pass 44 | 45 | # Loop layers and get the last conv channels to infer the backbone's output dim 46 | # (must run before the optional 1x1 conv below, which reads self.out_dim) 47 | for name, m in self.features.named_modules(): 48 | if isinstance(m, torch.nn.Conv2d): 49 | self.out_dim = m.out_channels 50 | 51 | if out_dim is not None: 52 | self.features = nn.Sequential( 53 | self.features, 54 | nn.Conv2d(self.out_dim, out_dim, 1), 55 | nn.BatchNorm2d(out_dim), nn.ReLU(True)) 56 | self.out_dim = out_dim 57 | 58 | self.am = None 59 | if attention_module is not None: 60 | raise NotImplementedError('attention is not supported in this version') 61 | 62 | self.gap = nn.AdaptiveAvgPool2d((1, 1)) 63 | self.classifier = nn.Linear(self.out_dim, num_classes) 64 | 65 | def forward(self, x): 66 | # 1. Get feature - [N,C,?,?] 67 | x = self.features(x) 68 | if self.am is not None: 69 | x = self.am(x) 70 | 71 | # 2. Global avg pooling to get a per-image feature 72 | feature = self.gap(x) # [N,C,1,1] 73 | 74 | # 3. Get logits 75 | x = torch.flatten(feature, 1) # [N,C,1,1] -> [N,C] 76 | logits = self.classifier(x) 77 | 78 | return logits 79 | 80 | -------------------------------------------------------------------------------- /TOV_v1/classification/tov_finetune_cls.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | 4 | 5 | NetNum=0102300000 6 | CP_dir='Task_models' 7 | pCP_dir='TOV_models' 8 | PJ="SSD" 9 | 10 | # * Finetune * 11 | f_epochs=200 12 | 13 | # *New* 14 | pDataSet='TOV-RS-blanced' 15 | pretrain_config=$pCP_dir'/'$NetNum'_'22061085439_pretrain # 16 | 17 | for fDataSet in "eurorgb" 18 | do 19 | note=$pDataSet"_fintune_on_"$fDataSet 20 | for train_scales in 5 20 50 100 21 | do 22 | python main_cls.py --mode 410 --pj $PJ --exp_note $note \ 23 | --dataset $fDataSet --input_size 64 --val_scale 1 --train_scale $train_scales \ 24 | --learning_rate 4e-3 --optimizer adam --lr_scheduler cosine --freeze_layers 'features' \ 25 | --model_name $NetNum --max_epochs $f_epochs --weight_decay 0 \ 26 | --ckpt $CP_dir --num_workers 4 --gpu 1 --batch_size 64 \ 27 | --super_prefetch --data_workers 4 \ 28 | --map_keys '' 'features.'
/TOV_v1/classification/tov_finetune_cls.bash:
--------------------------------------------------------------------------------
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0


NetNum=0102300000
CP_dir='Task_models'
pCP_dir='TOV_models'
PJ="SSD"

# * Finetune *
f_epochs=200

# *New*
pDataSet='TOV-RS-balanced'
pretrain_config=$pCP_dir'/'$NetNum'_'22061085439_pretrain  #

for fDataSet in "eurorgb"
do
    note=$pDataSet"_finetune_on_"$fDataSet
    for train_scales in 5 20 50 100
    do
        python main_cls.py --mode 410 --pj $PJ --exp_note $note \
            --dataset $fDataSet --input_size 64 --val_scale 1 --train_scale $train_scales \
            --learning_rate 4e-3 --optimizer adam --lr_scheduler cosine --freeze_layers 'features' \
            --model_name $NetNum --max_epochs $f_epochs --weight_decay 0 \
            --ckpt $CP_dir --num_workers 4 --gpu 1 --batch_size 64 \
            --super_prefetch --data_workers 4 \
            --map_keys '' 'features.' \
            --load_pretrained $pretrain_config
    done
done

note=$pDataSet"_finetune_on_tg2rgb"
for train_scales in 5 20 50 100
do
    python main_cls.py --mode 410 --pj $PJ --exp_note $note \
        --dataset "tg2rgb" --input_size 128 --val_scale 1 --train_scale $train_scales \
        --learning_rate 4e-3 --optimizer adam --lr_scheduler cosine --freeze_layers 'features' \
        --model_name $NetNum --max_epochs $f_epochs --weight_decay 0 \
        --ckpt $CP_dir --num_workers 4 --gpu 1 --batch_size 64 \
        --super_prefetch --data_workers 4 \
        --map_keys '' 'features.' \
        --load_pretrained $pretrain_config
done


for fDataSet in "aid" "nr" "rsd46" "pnt" "ucm"
do
    note=$pDataSet"_finetune_on_"$fDataSet
    for train_scales in 5 20 50 100
    do
        python main_cls.py --mode 410 --pj $PJ --exp_note $note \
            --dataset $fDataSet --input_size 224 --val_scale 1 --train_scale $train_scales \
            --learning_rate 4e-3 --optimizer adam --lr_scheduler cosine --freeze_layers 'features' \
            --model_name $NetNum --max_epochs $f_epochs --weight_decay 0 \
            --ckpt $CP_dir --num_workers 4 --gpu 1 --batch_size 64 \
            --super_prefetch --data_workers 4 \
            --map_keys '' 'features.' \
            --load_pretrained $pretrain_config
    done
done

# note="Imagenet_pretrain"
# for fDataSet in "eurorgb"
# do
#     note="Imagenet_pretrain_finetune_on_"$fDataSet
#     for train_scales in 5 20 50 100
#     do
#         python main_cls.py --mode 1 --pj $PJ --exp_note $note \
#             --dataset $fDataSet --input_size 64 --val_scale 1 --train_scale $train_scales \
#             --learning_rate 4e-3 --optimizer adam --lr_scheduler cosine --freeze_layers 'features' \
#             --model_name $NetNum --max_epochs $f_epochs --weight_decay 0 \
#             --ckpt $CP_dir --num_workers 4 --gpu 1 --batch_size 64 \
#             --super_prefetch --data_workers 4 \
#             --pretrained --load_pretrained ''  # ImageNet(SL)
#     done
# done
# for fDataSet in "nr"
# do
#     note="Imagenet_pretrain_finetune_on_"$fDataSet
#     for train_scales in 100
#     do
#         python main_cls.py --mode 1 --pj $PJ --exp_note $note \
#             --dataset $fDataSet --input_size 224 --val_scale 1 --train_scale $train_scales \
#             --learning_rate 4e-3 --optimizer adam --lr_scheduler cosine --freeze_layers 'features' \
#             --model_name $NetNum --max_epochs $f_epochs --weight_decay 0 \
#             --ckpt $CP_dir --num_workers 4 --gpu 1 --batch_size 64 \
#             --super_prefetch --data_workers 4 \
#             --pretrained --load_pretrained ''  # ImageNet(SL)
#     done
# done

# for fDataSet in "rsd46" "pnt" "ucm" "eurorgb"
# do
#     note="Imagenet_pretrain_finetune_on_"$fDataSet
#     for train_scales in 5 20 50 100
#     do
#         python main_cls.py --mode 1 --pj $PJ --exp_note $note \
#             --dataset $fDataSet --input_size 224 --val_scale 1 --train_scale $train_scales \
#             --learning_rate 4e-3 --optimizer adam --lr_scheduler cosine --freeze_layers 'features' \
#             --model_name $NetNum --max_epochs $f_epochs --weight_decay 0 \
#             --ckpt $CP_dir --num_workers 4 --gpu 1 --batch_size 64 \
#             --super_prefetch --data_workers 4 \
#             --pretrained --load_pretrained ''  # ImageNet(SL)
#     done
# done
--------------------------------------------------------------------------------
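
The `--map_keys '' 'features.'` pair suggests that checkpoint keys are renamed before loading, so a pretrained encoder's keys line up with `ClsNet.features`. A hedged sketch of such a remapping (the helper `remap_keys` and its exact semantics are assumptions, not the repository's actual `checkpoint_dict_mapping`):

```python
def remap_keys(state_dict, old_prefix, new_prefix):
    """Rename checkpoint keys: strip `old_prefix`, then prepend `new_prefix`."""
    out = {}
    for k, v in state_dict.items():
        if k.startswith(old_prefix):
            out[new_prefix + k[len(old_prefix):]] = v
    return out

# e.g. '' -> 'features.' turns 'layer4.0.conv1.weight'
# into 'features.layer4.0.conv1.weight'
```
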
/TOV_v1/classification/utils/__init__.py:
--------------------------------------------------------------------------------

from .tools import (
    load_pretrain_path_by_args, load_model_path_by_args,
    load_ckpt,
    checkpoint_dict_mapping, checkpoint_standardize,
)
from .mycallbacks import (
    MyLogger, MyDebugCallback, MySpeedUpCallback,
    get_checkpoint_callback
)

__all__ = [
    'MyLogger', 'MyDebugCallback', 'MySpeedUpCallback',
    'get_checkpoint_callback',
    'load_pretrain_path_by_args', 'load_model_path_by_args',
    'load_ckpt',
    'checkpoint_dict_mapping', 'checkpoint_standardize',
]


'''
def load_model_path(root=None, version=None, v_num=None, best=False):
    """ When best = True, return the best model's path in a directory
    by selecting the model with the largest epoch. If not, return
    the last model saved. You must provide at least one of the
    first three args.
    Args:
        root: The root directory of checkpoints. It can also be a
            model ckpt file; then the function will return it.
        version: The name of the version you are going to load.
        v_num: The version's number that you are going to load.
        best: Whether to return the best model.
    """
    def sort_by_epoch(path):
        name = path.stem
        epoch = int(name.split('-')[1].split('=')[1])
        return epoch

    def generate_root():
        if root is not None:
            return root
        elif version is not None:
            return str(Path('lightning_logs', version, 'checkpoints'))
        else:
            return str(
                Path('lightning_logs', f'version_{v_num}', 'checkpoints'))

    if root == version == v_num is None:
        return None

    root = generate_root()
    if Path(root).is_file():
        return root
    if best:
        files = [
            i for i in list(Path(root).iterdir()) if i.stem.startswith('best')
        ]
        files.sort(key=sort_by_epoch, reverse=True)
        res = str(files[0])
    else:
        res = str(Path(root) / 'last.ckpt')
    return res


def load_model_path_by_args(args):
    return load_model_path(root=args.load_dir,
                           version=args.load_ver,
                           v_num=args.load_v_num)

class MyTimeLogCallback(pl.Callback):

    def __init__(self, log_file=None):
        super().__init__()
        self.tics = {'total_epoch': 0}
        self.epoch_count = 0
        self.log_file = log_file

    def on_train_start(self, trainer, pl_module):
        if self.log_file is not None:
            try:
                self.log_file = open(self.log_file, 'w')
            except OSError:
                print('Fail to open log_file: {}'.format(
                    self.log_file))
                self.log_file = None

    def on_train_epoch_start(self, trainer, pl_module):
        self.tics['epoch'] = time.time()

    def on_train_epoch_end(self, trainer, pl_module, unused=None):
        self.epoch_count += 1
        epoch_time = time.time()-self.tics['epoch']
        m, s = divmod(epoch_time, 60)
        h, m = divmod(m, 60)
        print('epoch time = {:0>2.0f}:{:0>2.0f}:{:0>2.0f}'.format(h, m, s),
              file=self.log_file)
        self.tics['total_epoch'] += epoch_time

    def on_train_end(self, trainer, pl_module):
        avg_epoch_time = self.tics['total_epoch'] / self.epoch_count
        m, s = divmod(avg_epoch_time, 60)
        h, m = divmod(m, 60)
        print('avg epoch time = {:0>2.0f}:{:0>2.0f}:{:0>2.0f}'.format(h, m, s),
              file=self.log_file)
        if self.log_file:
            self.log_file.close()
'''
--------------------------------------------------------------------------------
/TOV_v1/segmentation/config/__init__.py:
--------------------------------------------------------------------------------
import os
import time
import importlib
import sys
import warnings
from pytorch_lightning.utilities import rank_zero_only

warnings.filterwarnings("ignore")

TYPE2BAND = {'RGB': 3, 'NIR': 1, 'SAR': 1, 'TEN': 10, 'ALL': 12, 'MS': 13}  # 'ALL' for Sentinel data; 'MS' for EuroSAT
MODE_NAME = {1: 'train', 2: 'val', 3: 'test', 4: 'finetune',  # 5: 'exp',
             5: '', 6: '', 7: '',
             8: '', 9: 'pretrain', 0: 'debug'}


class __redirection__:
    def __init__(self, mode='console', file_path=None):
        assert mode in ['console', 'file', 'both']

        self.mode = mode
        self.buff = ''
        self.__console__ = sys.stdout

        self.file = None
        if file_path is not None and mode != 'console':
            try:
                self.file = open(file_path, "w", buffering=1)
            except OSError:
                print('Fail to open log_file: {}'.format(
                    file_path))

    @rank_zero_only
    def write(self, output_stream):
        self.buff += output_stream
        if self.mode == 'console':
            self.to_console(output_stream)
        elif self.mode == 'file':
            self.to_file(output_stream)
        elif self.mode == 'both':
            self.to_console(output_stream)
            self.to_file(output_stream)

    @rank_zero_only
    def to_console(self, content):
        sys.stdout = self.__console__
        print(content, end='')
        sys.stdout = self

    @rank_zero_only
    def to_file(self, content):
        if self.file is not None:
            sys.stdout = self.file
            print(content, end='')
            sys.stdout = self

    @rank_zero_only
    def all_to_console(self, flush=False):
        sys.stdout = self.__console__
        print(self.buff, end='')
        sys.stdout = self

    @rank_zero_only
    def all_to_file(self, file_path=None, flush=True):
        if file_path is not None:
            self.open(file_path)
        if self.file is not None:
            sys.stdout = self.file
            print(self.buff, end='')
            sys.stdout = self
            # self.file.close()

    @rank_zero_only
    def open(self, file_path):
        try:
            self.file = open(file_path, "w", buffering=1)
        except OSError:
            print('Fail to open log_file: {}'.format(
                file_path))

    @rank_zero_only
    def close(self):
        if self.file is not None:
            self.file.close()
            self.file = None

    @rank_zero_only
    def flush(self):
        self.buff = ''

    @rank_zero_only
    def reset(self):
        sys.stdout = self.__console__
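
# Illustrative usage (not part of the module): route every print() to both
# the console and a log file, then restore stdout.
#   stream = __redirection__(mode='both', file_path='run.log')
#   sys.stdout = stream   # print() now goes through the buffer
#   print('hello')        # echoed to the console and written to run.log
#   stream.reset()        # restore the real stdout
#   stream.close()
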

def get_opt(name, args=None, redirection=False):
    '''Get options by name and current platform; optionally use `args` to update them.'''

    get_config = importlib.import_module('config.default').get_config
    opts = get_config(name)
    if args is None:
        return opts  # simple mode

    opts = preprocess_settings(opts, args)

    # Normalize the form of some parameters
    opts.dtype = unify_type(opts.dtype, list)
    opts.input_size = unify_type(opts.input_size, tuple, 2)
    for dt in opts.dtype:
        opts.in_channel += TYPE2BAND[dt]
    opts.mean = opts.mean if len(opts.mean) == opts.in_channel else opts.mean * opts.in_channel
    opts.std = opts.std if len(opts.std) == opts.in_channel else opts.std * opts.in_channel

    # Generate logging flags / info.
    opts.timestamp = time.strftime('%y%j%H%M%S', time.localtime(time.time()))
    mode_digits = len(str(opts.mode))
    opts.mode_name = MODE_NAME[opts.mode // (10**(mode_digits-1))]  # leading digit selects the mode
    if not hasattr(opts, 'net_suffix'):
        opts.net_suffix = opts.timestamp + '_' + opts.mode_name
    if opts.mode_name == 'finetune':
        assert opts.load_pretrained, '`load_pretrained` must be provided for finetune'
        expnum = 'temp'
        for s in opts.load_pretrained.split('_'):
            if len(s) == 11 and s.isdigit():
                expnum = s
                break
            elif s in ['other']:
                expnum = s
                break
        opts.net_suffix = expnum + '_' + opts.net_suffix
    opts.exp_name = '{}_{}_{}'.format(opts.env, opts.model_name, opts.net_suffix)

    # Redirect the print output
    if redirection:
        print_mode = 'both'
        stream = __redirection__(print_mode)
        sys.stdout = stream

    print('Current Main Mode: {} - {}\n'.format(opts.mode, opts.exp_name))

    # Basic prepare
    if opts.ckpt:
        if not os.path.exists(opts.ckpt):
            os.mkdir(opts.ckpt)

    if redirection:
        return opts, stream
    return opts


def preprocess_settings(opts, args=None):

    # Combine user args with predefined opts
    if args is not None:
        for attr in opts.__dir__():
            if attr.startswith('__') or attr.endswith('__'):
                continue
            value = getattr(opts, attr)
            if attr not in args or args.__dict__[attr] is None:
                setattr(args, attr, value)
        opts = args

    # Add some default settings
    if opts is not None:
        if not hasattr(opts, 'pj') or not opts.pj:
            opts.pj = opts.env

        if opts.max_epochs is None:
            opts.max_epochs = 3

        # Other
        if not hasattr(opts, 'full_train') or opts.full_train is None:
            opts.full_train = False

    return opts


def unify_type(param, ptype=list, repeat=1):
    ''' Unify the type of param.

    Args:
        ptype: supports list or tuple
        repeat: the number of times param is repeated in the list or tuple.
    '''
    if repeat == 1:
        if type(param) is not ptype:
            if ptype == list:
                param = [param]
            elif ptype == tuple:
                param = (param,)  # note the trailing comma: `(param)` is not a tuple
    elif repeat > 1:
        if type(param) is ptype and len(param) == repeat:
            return param
        elif type(param) is list:
            param = param * repeat
        else:
            param = [param] * repeat
        param = ptype(param)

    return param


def args_list2dict(params):
    if type(params) is dict:
        return params
    elif type(params) is list:
        assert len(params) % 2 == 0, 'Must be paired args'
        options = {}
        for i in range(0, len(params), 2):
            options[params[i]] = params[i+1]

        return options
--------------------------------------------------------------------------------
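
`unify_type` is what lets a config specify `input_size = 256` and have `get_opt` treat it as `(256, 256)`, while `in_channel` is accumulated from `TYPE2BAND`. For instance:

```python
assert unify_type(256, tuple, 2) == (256, 256)          # scalar -> repeated tuple
assert unify_type('RGB', list) == ['RGB']               # scalar -> singleton list
assert unify_type((256, 256), tuple, 2) == (256, 256)   # already in shape

# channel counting as done in get_opt: dtype ['RGB', 'NIR'] -> 3 + 1 = 4
assert sum(TYPE2BAND[dt] for dt in ['RGB', 'NIR']) == 4
```
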
/TOV_v1/segmentation/config/category.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-


class CVPR_LandCover(object):
    ''' Category details of the CVPR_LandCover dataset.

    | Scene/land-cover class | Label |
    | ----- | ----- |
    | Ignore | 0 |
    | Urban land | 1 |
    | Agriculture land | 2 |
    | Rangeland | 3 |
    | Forest land | 4 |
    | Water | 5 |
    | Barren land | 6 |
    | Unknown | 7 |

    '''
    plan_name = 'CVPR_LandCover'
    table = {
        'Urban land': 0,
        'Agriculture land': 1,
        'Rangeland': 2,
        'Forest land': 3,
        'Water': 4,
        'Barren land': 5,
        'Unknown': 6,
    }

    names = [name for (name, _) in table.items()]
    inds = table.values()

    num = len(table)

    color_table = [
        # [5, 5, 5],      # 0
        [0, 255, 255],    # 0
        [255, 255, 0],    # 1
        [255, 0, 255],    # 2
        [0, 255, 0],      # 3
        [0, 0, 255],      # 4
        [255, 255, 255],  # 5
        [0, 0, 0],        # 6
    ]  # RGB

    bgr_table = [tb[::-1] for tb in color_table]  # BGR
    mapping = None


class ISPRS(object):
    ''' Category details of the ISPRS dataset.

    | Scene/land-cover class | Label |
    | ----- | ----- |
    | BG (ignore) | 0 |
    | Impervious surfaces | 0 |
    | Building | 1 |
    | Low vegetation | 2 |
    | Tree | 3 |
    | Car | 4 |
    | Clutter/background | 5 |
    '''
    plan_name = 'ISPRS'
    names = [
        'Ignore',
        'Impervious surfaces',
        'Building',
        'Low vegetation',
        'Tree',
        'Car',
        'Clutter/background',
    ]
    table = {cls: i for (i, cls) in enumerate(names)}
    inds = list(range(len(names)))

    num = len(table)

    color_table = [
        [0, 0, 0],        # 0
        [255, 255, 255],  # 1
        [0, 0, 255],      # 2
        [0, 255, 255],    # 3
        [0, 255, 0],      # 4
        [255, 255, 0],    # 5
        [255, 0, 0],      # 6
    ]  # RGB

    bgr_table = [tb[::-1] for tb in color_table]  # BGR
    mapping = None
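
# Illustrative note (assuming numpy as np): the index<->palette pairing makes
# rendering a predicted index mask into an RGB image a single lookup, e.g.
#   palette = np.asarray(ISPRS.color_table, dtype=np.uint8)
#   rgb = palette[mask]   # HxW index mask -> HxWx3 RGB image
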

class DLRSD(object):
    ''' Category details of the DLRSD dataset. '''
    table = {
        "none": 0,
        "airplane": 1,
        "bare soil": 2,
        "buildings": 3,
        "cars": 4,
        "chaparral": 5,
        "court": 6,
        "dock": 7,
        "field": 8,
        "grass": 9,
        "mobile home": 10,
        "pavement": 11,
        "sand": 12,
        "sea": 13,
        "ship": 14,
        "tanks": 15,
        "trees": 16,
        "water": 17,
    }
    names = [name for (name, _) in table.items()]
    inds = table.values()

    num = len(table)

    color_table = [
        [0, 0, 0],
        [166, 202, 240],
        [128, 128, 0],
        [0, 0, 128],
        [255, 0, 0],
        [0, 128, 0],
        [128, 0, 0],
        [255, 233, 233],
        [160, 160, 164],
        [0, 128, 128],
        [90, 87, 255],
        [255, 255, 0],
        [255, 192, 0],
        [0, 0, 255],
        [255, 0, 192],
        [128, 0, 128],
        [0, 255, 0],
        [0, 255, 255],
    ]  # RGB

    bgr_table = [tb[::-1] for tb in color_table]  # BGR
    mapping = None
--------------------------------------------------------------------------------
/TOV_v1/segmentation/config/default/__init__.py:
--------------------------------------------------------------------------------
def get_config(name, args=None):
    if '+' in name:
        pass  # combined datasets: not implemented in this version
    elif 'cvprlc' in name:
        from .opt_CVPR_LandCover import Config
    elif 'dlrsd' in name:
        from .opt_DLRSD import Config
    elif 'isprspd' in name:
        from .opt_ISPRS_Postdam import Config
    else:
        # guard against silently falling through with `Config` undefined
        raise ValueError('Unknown dataset config name: {}'.format(name))

    if args is not None:
        mOptions = Config(args)
    else:
        mOptions = Config()

    return mOptions

# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# sys.path.append(BASE_DIR+'/Tools/dltoos')
--------------------------------------------------------------------------------
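
`get_config` dispatches on a substring of the dataset name, so callers such as `get_opt` in `config/__init__.py` can resolve a config class in one call. For example:

```python
opts = get_config('isprspd')   # -> opt_ISPRS_Postdam.Config()
print(opts.data_dir, opts.num_classes)
```
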
/TOV_v1/segmentation/config/default/opt_CVPR_LandCover.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
Default config for the CVPR_LandCover dataset (segmentation mode)
'''
from config.category import CVPR_LandCover


class Config(object):
    ''' Fixed initialization settings '''
    # Path and file
    root = ''
    data_dir = root + 'segmentation/Seg_data/CVPR_LandCover'
    ckpt = root + 'Task_models/CVPR_LandCover'
    env = 'CVPRLC'  # visdom environment

    train_val_ratio = None
    train_scale = 1  # scale of training set reduction
    val_scale = 0.1
    # Model related arguments
    sepoch = 1  # used to continue from a checkpoint
    pretrained = False  # use a pre-trained checkpoint

    # Optimization related arguments
    batch_size = 16  # batch size
    max_epochs = None  # a 16-[256,256] dataset only needs 8~9 epochs
    ckpt_freq = 0
    learning_rate = 4e-4  # initial learning rate
    lr_decay_rate = 0.5
    lr_decay_steps = 20
    lr_scheduler = 'cosine'  # per epoch
    warmup = 0  # if warmup > 0, use the warmup strategy and end it at `warmup`
    weight_decay = 1e-5  # L2 loss
    optimizer = ['adam', 'sgd', 'lars'][1]
    loss = ['CrossEntropyLoss'][0]
    loss_weight = None

    # Data related arguments
    num_workers = 4  # number of data loading workers
    dtype = ['RGB'][0]
    bl_dtype = [''][0]
    in_channel = 0

    input_size = (256, 256)  # final input size of the network (random crop uses this)
    # crop_params = [256, 256, 256]  # crop_params for val and pre

    # feature_dim = {1: 128, 2: 256, 3: 384, 4: 512}
    mean = [0.5, 0.5, 0.5]  # BGR; the mean values here should be in [0, 1]
    std = [0.5, 0.5, 0.5]  # [0.12283102, 0.1269429, 0.15580289]

    category = CVPR_LandCover()
    classes = category.names
    class_dict = category.table
    num_classes = len(classes)

    reduce_zero_label = False
    ignore_index = None
--------------------------------------------------------------------------------
/TOV_v1/segmentation/config/default/opt_DLRSD.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
Default config for the DLRSD dataset (segmentation mode)
'''
from config.category import DLRSD


class Config(object):
    ''' Fixed initialization settings '''
    # Path and file
    root = ''
    data_dir = root + 'segmentation/Seg_data/DLRSD'
    ckpt = root + 'Task_models/DLRSD'
    env = 'DLRSD'  # visdom environment

    train_val_ratio = None
    train_scale = 1  # scale of training set reduction
    val_scale = 0.5
    # Model related arguments
    sepoch = 1  # used to continue from a checkpoint
    pretrained = False  # use a pre-trained checkpoint

    # Optimization related arguments
    batch_size = 16  # batch size
    max_epochs = None  # a 16-[256,256] dataset only needs 8~9 epochs
    ckpt_freq = 0
    learning_rate = 4e-4  # initial learning rate
    lr_decay_rate = 0.5
    lr_decay_steps = 20
    lr_scheduler = 'cosine'  # per epoch
    warmup = 0  # if warmup > 0, use the warmup strategy and end it at `warmup`
    weight_decay = 1e-5  # L2 loss
    optimizer = ['adam', 'sgd', 'lars'][1]
    loss = ['CrossEntropyLoss'][0]
    loss_weight = None

    # Data related arguments
    num_workers = 4  # number of data loading workers
    dtype = ['RGB'][0]
    bl_dtype = [''][0]
    in_channel = 0

    input_size = (256, 256)  # final input size of the network (random crop uses this)
    # crop_params = [256, 256, 256]  # crop_params for val and pre

    # feature_dim = {1: 128, 2: 256, 3: 384, 4: 512}
    mean = [0.5, 0.5, 0.5]  # BGR; the mean values here should be in [0, 1]
    std = [0.5, 0.5, 0.5]  # [0.12283102, 0.1269429, 0.15580289]

    category = DLRSD()
    classes = category.names
    class_dict = category.table
    num_classes = len(classes)

    reduce_zero_label = False
    ignore_index = 0
--------------------------------------------------------------------------------
/TOV_v1/segmentation/config/default/opt_ISPRS_Postdam.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
Default config for the ISPRS_Postdam dataset (segmentation mode)
'''
from config.category import ISPRS


class Config(object):
    ''' Fixed initialization settings '''
    # Path and file
    root = ''
    data_dir = root + 'segmentation/Seg_data/ISPRS_Postdam'
    ckpt = root + 'Task_models/ISPRS_Postdam'
    env = 'ISPRS_Postdam'  # visdom environment

    train_val_ratio = None
    train_scale = 1  # scale of training set reduction
    val_scale = 0.1
    # Model related arguments
    sepoch = 1  # used to continue from a checkpoint
    pretrained = False  # use a pre-trained checkpoint

    # Optimization related arguments
    batch_size = 16  # batch size
    max_epochs = None  # a 16-[256,256] dataset only needs 8~9 epochs
    ckpt_freq = 0
    learning_rate = 4e-4  # initial learning rate
    lr_decay_rate = 0.5
    lr_decay_steps = 20
    lr_scheduler = 'cosine'  # per epoch
    warmup = 0  # if warmup > 0, use the warmup strategy and end it at `warmup`
    weight_decay = 1e-5  # L2 loss
    optimizer = ['adam', 'sgd', 'lars'][1]
    loss = ['CrossEntropyLoss'][0]
    loss_weight = None

    # Data related arguments
    num_workers = 4  # number of data loading workers
    dtype = ['RGB'][0]
    bl_dtype = [''][0]
    in_channel = 0

    input_size = (256, 256)  # final input size of the network (random crop uses this)

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]

    category = ISPRS()
    classes = category.names
    class_dict = category.table
    num_classes = len(classes)

    reduce_zero_label = False
    ignore_index = 0
--------------------------------------------------------------------------------
/TOV_v1/segmentation/dataset/__init__.py:
--------------------------------------------------------------------------------
from .data_interface import DInterface, datasets_init
from .seg_data import SegData

__all__ = [
    'DInterface', 'datasets_init',
    'SegData'
]
--------------------------------------------------------------------------------
/TOV_v1/segmentation/dataset/seg_data.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
Dataset for segmentation.
'''
import time
import numpy as np
import cv2
from torch.utils import data

TYPE2BAND = {'RGB': 3, 'NIR': 1, 'SAR': 1}


def imread_with_cv(pth):
    ''' Only load RGB image data. '''
    img = cv2.imread(pth, 1)
    img = img[:, :, ::-1]  # BGR -> RGB
    return img.copy()


class SegData(data.Dataset):
    '''
    Dataset loader for segmentation.
    Args:
        init_data - Dict of init data = {'train': {'label': [], 'RGB': []}, 'val': [...]}
        split - One of ['train', 'val', 'test']
        ratio - (float) extra parameter to control the number of samples used;
            if ratio < 0.01, it stands for n-shot mode:
            ratio = 0.00n, where n is the number of samples per category
        transforms - a custom transform object may be passed in
    '''

    def __init__(self,
                 init_data,
                 split='train',
                 dtype=['RGB'],
                 ratio=1,
                 transforms=None,
                 data_seed=0,
                 debug=False,
                 **kwargs):
        # Set all input args as attributes
        self.__dict__.update(locals())

        # Collect all dataset files, divided into a dict.
        tic = time.time()
        self.imgs, self.lbls = {}, []

        total_num = len(init_data[split]['image'][dtype[0]])
        if ratio <= 1:
            use_num = max(round(total_num*ratio), 1)
        else:
            use_num = min(int(ratio), total_num)

        if type(init_data) is dict:
            for dt in dtype:
                self.imgs[dt] = init_data[split]['image'][dt]
            self.lbls = init_data[split]['label']

            self.shuffle_data(data_seed)  # shuffle data with seed

            for dt in self.imgs.keys():
                self.imgs[dt] = self.imgs[dt][:use_num]
            self.lbls = self.lbls[:use_num]
            # print(self.imgs)
        elif type(init_data) is str:
            raise NotImplementedError()
        self.use_num = len(self.imgs[dtype[0]])

        print('{} set contains {} images.'.format(split, total_num))
        if split == 'train':
            print('*'*6, f'data seed = {data_seed}', '*'*6)
            print('Actual number of samples used = {}'.format(self.use_num))
            print(self.imgs[dtype[0]][:5])

        print('Time to collect data = {:.2f}s.'.format(time.time()-tic))

    def imread(self, index):
        images = []
        for dt in self.dtype:
            img = cv2.imread(self.imgs[dt][index], cv2.IMREAD_LOAD_GDAL)
            if dt == 'RGB':
                img = img[:, :, ::-1]
            elif TYPE2BAND[dt] == 1:
                img = np.expand_dims(img, axis=2)
            images.append(img)
        image = np.concatenate(images, axis=2)
        return image

    def load_data_from_disk(self, index):
        img = self.imread(index)
        lbl = None
        if self.lbls is not None:
            lbl = cv2.imread(self.lbls[index], cv2.IMREAD_LOAD_GDAL)
        return img, lbl

    def shuffle_data(self, seed=None):
        if seed is None:
            seed = np.random.randint(100000)

        for dt in self.imgs.keys():
            np.random.seed(seed)
            np.random.shuffle(self.imgs[dt])
        # shuffle the labels once, with the same seed, so they stay aligned
        # with every image list (re-shuffling them inside the loop above
        # would break the pairing when len(dtype) > 1)
        np.random.seed(seed)
        np.random.shuffle(self.lbls)

    def __len__(self):
        return self.use_num

    def __getitem__(self, index):
        ''' Return one image (and its label) at a time. '''
        img, lbl = self.load_data_from_disk(index)

        if self.transforms is not None:
            transformed = self.transforms(image=img, mask=lbl)
            img = transformed['image']
            lbl = transformed['mask']

        sample = {'image': img, 'label': lbl, 'name': self.imgs[self.dtype[0]][index]}
        return sample
--------------------------------------------------------------------------------
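
The `init_data` dict mirrors the docstring above: per split, an `'image'` dict keyed by data type plus a parallel `'label'` list. A minimal sketch (the file paths are illustrative; indexing the dataset reads them from disk):

```python
init_data = {
    'train': {
        'image': {'RGB': ['img_0001.tif', 'img_0002.tif']},
        'label': ['lbl_0001.tif', 'lbl_0002.tif'],
    },
}
ds = SegData(init_data, split='train', dtype=['RGB'], ratio=1)
sample = ds[0]  # {'image': ..., 'label': ..., 'name': ...}
```
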
/TOV_v1/segmentation/dataset/transforms.py:
--------------------------------------------------------------------------------
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2


def make_transform(args):
    base_transform = [
        A.Resize(args.img_size, args.img_size),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2()
    ]

    train_transform = []
    if args.Blur:
        train_transform.append(
            A.Blur(p=args.Blur))
    if args.ElasticTransform:
        train_transform.append(
            A.ElasticTransform(p=args.ElasticTransform))

    if args.CLAHE:
        train_transform.append(A.CLAHE(clip_limit=(1, 4),
                                       tile_grid_size=(8, 8),
                                       p=args.CLAHE
                                       ))
    if args.RandomBrightnessContrast:
        train_transform.append(
            A.RandomBrightnessContrast(brightness_limit=0.2,
                                       contrast_limit=0.2,
                                       brightness_by_max=True,
                                       p=args.RandomBrightnessContrast
                                       ))
    if args.HueSaturationValue:
        train_transform.append(A.HueSaturationValue(hue_shift_limit=20,
                                                    sat_shift_limit=30,
                                                    val_shift_limit=20,
                                                    p=args.HueSaturationValue
                                                    ))
    if args.RGBShift:
        train_transform.append(A.RGBShift(r_shift_limit=20,
                                          g_shift_limit=20,
                                          b_shift_limit=20,
                                          p=args.RGBShift
                                          ))
    if args.RandomGamma:
        train_transform.append(A.RandomGamma(gamma_limit=(80, 120),
                                             p=args.RandomGamma
                                             ))
    if args.HorizontalFlip:
        train_transform.append(A.HorizontalFlip(p=args.HorizontalFlip))

    if args.VerticalFlip:
        train_transform.append(A.VerticalFlip(p=args.VerticalFlip))

    if args.ShiftScaleRotate:
        train_transform.append(A.ShiftScaleRotate(shift_limit=0.2,
                                                  scale_limit=0.2,
                                                  rotate_limit=10,
                                                  border_mode=args.ShiftScaleRotateMode,
                                                  p=args.ShiftScaleRotate
                                                  ))
    if args.GridDistortion:
        train_transform.append(A.GridDistortion(num_steps=5,
                                                distort_limit=(-0.3, 0.3),
                                                p=args.GridDistortion
                                                ))
    if args.MotionBlur:
        train_transform.append(A.MotionBlur(blur_limit=(3, 7),
                                            p=args.MotionBlur
                                            ))
    if args.RandomResizedCrop:
        train_transform.append(A.RandomResizedCrop(height=args.img_size,
                                                   width=args.img_size,
                                                   scale=(0.4, 1.0),  # lower bound must be positive
                                                   ratio=(0.75, 1.3333333333333333),
                                                   p=args.RandomResizedCrop
                                                   ))
    if args.ImageCompression:
        train_transform.append(A.ImageCompression(quality_lower=99,
                                                  quality_upper=100,
                                                  p=args.ImageCompression
                                                  ))
    train_transform.extend(base_transform)

    train_transform = A.Compose(train_transform)
    test_transform = A.Compose(base_transform)

    return train_transform, test_transform
--------------------------------------------------------------------------------
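
`make_transform` reads each augmentation's probability straight off the args namespace, with 0 disabling it. A sketch of driving it without the CLI (the attribute set is inferred from the function body):

```python
import numpy as np
from types import SimpleNamespace

args = SimpleNamespace(
    img_size=256,
    Blur=0, ElasticTransform=0, CLAHE=0, RandomBrightnessContrast=0.5,
    HueSaturationValue=0, RGBShift=0, RandomGamma=0,
    HorizontalFlip=0.5, VerticalFlip=0.5,
    ShiftScaleRotate=0, ShiftScaleRotateMode=4,  # 4 = cv2.BORDER_REFLECT_101
    GridDistortion=0, MotionBlur=0, RandomResizedCrop=0, ImageCompression=0,
)
train_tf, test_tf = make_transform(args)

image = np.zeros((512, 512, 3), dtype=np.uint8)  # dummy image
mask = np.zeros((512, 512), dtype=np.uint8)      # dummy mask
out = train_tf(image=image, mask=mask)           # albumentations dict interface
img_tensor, lbl = out['image'], out['mask']
```
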
/TOV_v1/segmentation/models/fcn.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
'''
Package the torchvision model into an FCN.
'''

import torch
import collections
from torch import nn
from torch.nn import functional as F

from torchvision.models._utils import IntermediateLayerGetter


class FCNHead(nn.Sequential):
    def __init__(self, in_channels, channels):
        inter_channels = in_channels // 4
        layers = [
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1)
        ]

        super(FCNHead, self).__init__(*layers)


class FCNHead_simple(nn.Sequential):
    def __init__(self, in_channels, channels):
        layers = [
            nn.ReLU(),
            nn.Conv2d(in_channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        ]

        super(FCNHead_simple, self).__init__(*layers)


class UpBlock(nn.Sequential):
    def __init__(self, in_channels, channels):
        layers = [
            nn.Conv2d(in_channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            # nn.ReLU(),
        ]
        super(UpBlock, self).__init__(*layers)

    def forward(self, x):
        for mod in self:
            x = mod(x)
        return F.interpolate(
            x, scale_factor=2.0, mode='bilinear', align_corners=False)


class FCN(nn.Module):
    """ Fully convolutional network that contains `features` and `classifier`.
    """
    def __init__(self, backbone, num_classes=2, aux=None,
                 **kwargs):
        super().__init__()
        if hasattr(backbone, 'model_name'):
            bb_name = backbone.model_name
        else:
            bb_name = backbone._get_name().lower()
        self.model_name = 'FCN_' + bb_name

        if 'alexnet' in bb_name:
            self.features = backbone.features
            # self.out_dim = 256
        elif 'resnet' in bb_name or 'resnext' in bb_name:
            return_layers = {'layer4': 'out'}
            if aux:
                # the auxiliary head taps layer3: 1024 channels for
                # resnet50+, 256 channels for resnet18/34
                return_layers['layer3'] = 'aux'
                self.aux_dim = 1024
                if ('resnet18' in bb_name) or ('resnet34' in bb_name):
                    self.aux_dim = 256
            self.features = IntermediateLayerGetter(
                backbone, return_layers=return_layers)
        elif 'vgg' in bb_name:
            self.features = backbone.features
            # self.out_dim = 512
        elif 'googlenet' in bb_name:
            backbone = nn.Sequential(collections.OrderedDict(dict(backbone.named_children())))
            self.features = backbone[:16]
            # self.out_dim = 1024
        elif 'inception' in bb_name:
            raise NotImplementedError('inception backbone is not supported in this version')

        self.out_dim = None
        if hasattr(backbone, 'out_dim'):
            self.out_dim = backbone.out_dim
        else:
            # loop layers and get the last conv channels
            for name, m in self.features.named_modules():
                if isinstance(m, torch.nn.Conv2d):
                    self.out_dim = m.out_channels

        self.aux_classifier = None
        if aux:
            self.aux_classifier = FCNHead(self.aux_dim, num_classes)

        self.classifier = FCNHead(self.out_dim, num_classes)

    def forward(self, x):
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.features(x)

        result = collections.OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
            result["aux"] = x

        return result
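
# Quick shape check (illustrative), assuming a torchvision ResNet-50 backbone;
# aux=True taps layer3 for the auxiliary head:
#   fcn = FCN(torchvision.models.resnet50(), num_classes=6, aux=True)
#   out = fcn(torch.randn(2, 3, 256, 256))
#   out["out"].shape  # torch.Size([2, 6, 256, 256])
#   out["aux"].shape  # torch.Size([2, 6, 256, 256])
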

class FCN_d8(FCN):
    """ Fully convolutional network that contains `features` and `classifier`.
    """
    def __init__(self, backbone, num_classes=2, aux=None,
                 **kwargs):
        super(FCN_d8, self).__init__(backbone, num_classes, aux, **kwargs)

        self.model_name = self.model_name.replace('FCN', 'FCN_d8')

        if aux:
            self.aux_classifier = FCNHead(self.aux_dim, num_classes)

        self.uplayer1 = UpBlock(self.out_dim, self.out_dim//2)  # d4
        self.uplayer2 = UpBlock(self.out_dim//2, self.out_dim//4)  # d2

        self.classifier = FCNHead_simple(self.out_dim//4, num_classes)  # d1

    def forward(self, x):
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.features(x)

        result = collections.OrderedDict()
        x = features["out"]
        x = self.uplayer1(x)  # d4
        x = self.uplayer2(x)  # d2
        x = self.classifier(x)  # d1
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
            result["aux"] = x

        return result
--------------------------------------------------------------------------------
/big_picture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeoX-Lab/G-RSIM/518efd3776d8e4937b0a4c34137a5bbda4b4f8a1/big_picture.png
--------------------------------------------------------------------------------