├── fig
│   └── overall_networks.png
├── SSLRemoteSensing
│   ├── models
│   │   ├── builder.py
│   │   ├── backbone
│   │   │   ├── builder.py
│   │   │   ├── vgg.py
│   │   │   └── resnet.py
│   │   └── representation
│   │       └── vr_nets_inpainting_agr_examplar.py
│   ├── losses
│   │   ├── vr_losses.py
│   │   ├── builder.py
│   │   └── examplar_loss.py
│   ├── utils
│   │   ├── optims
│   │   │   └── builder.py
│   │   ├── path_utils.py
│   │   └── utils.py
│   ├── metric
│   │   └── time_metric.py
│   └── datasets
│       ├── transforms
│       │   └── representation
│       │       ├── builder.py
│       │       ├── agr_transforms.py
│       │       ├── inpainting_transforms.py
│       │       └── transforms.py
│       └── datasets
│           └── representation
│               └── vr_dataset_inpainting_agr.py
├── train.py
├── requirements.txt
├── readme.md
└── configs
    ├── vr_vgg16_inapinting_agr_examplar_cfg.py
    └── vr_resnet50_inapinting_agr_examplar_cfg.py
/fig/overall_networks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flyakon/SSLRemoteSensing/HEAD/fig/overall_networks.png -------------------------------------------------------------------------------- /SSLRemoteSensing/models/builder.py: -------------------------------------------------------------------------------- 1 | from .representation.vr_nets_inpainting_agr_examplar import VRNetsWithInpaintingAGRExamplar 2 | 3 | 4 | models_dict={ 5 | 'VRNetsWithInpaintingAGRExamplar':VRNetsWithInpaintingAGRExamplar 6 | } 7 | 8 | def builder_models(name='VRNetsWithInpainting',**kwargs): 9 | if name in models_dict.keys(): 10 | return models_dict[name](**kwargs) 11 | else: 12 | raise NotImplementedError('{0} not in available values.'.format(name)) -------------------------------------------------------------------------------- /SSLRemoteSensing/losses/vr_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class InpaintingLoss(object): 5 | 6 | def __init__(self,**kwargs): 7 | super(InpaintingLoss,self).__init__() 8 | self.criterion=nn.L1Loss(reduction='none') 9 | 10 | def __call__(self, logits,labels,attention_map): 11 | loss = self.criterion(logits, labels) 12 | loss = loss * attention_map 13 | count=torch.sum(attention_map)+1e-6 14 | loss = torch.sum(loss) / count 15 | return loss 16 | -------------------------------------------------------------------------------- /SSLRemoteSensing/losses/builder.py: -------------------------------------------------------------------------------- 1 | from .vr_losses import InpaintingLoss 2 | from .examplar_loss import ExamplarLoss 3 | 4 | 5 | 6 | import torch 7 | import torchvision 8 | import torch.nn as nn 9 | 10 | losses_dict={'CrossEntropyLoss':nn.CrossEntropyLoss,'InpaintingLoss':InpaintingLoss, 11 | 'ExamplarLoss':ExamplarLoss,'L1Loss':nn.L1Loss} 12 | 13 | 14 | def builder_loss(name='CrossEntropyLoss',**kwargs): 15 | 16 | if name in losses_dict.keys(): 17 | return losses_dict[name](**kwargs) 18 | else: 19 | raise NotImplementedError('{0} not in available values.'.format(name)) -------------------------------------------------------------------------------- /SSLRemoteSensing/models/backbone/builder.py: -------------------------------------------------------------------------------- 1 | from .resnet import get_resnet 2 | from .vgg import get_vgg 3 | import torch 4 | 5 | def build_backbone(name='resnet50',**kwargs): 6 | if name.startswith('resnet'): 7 | model=get_resnet(name,**kwargs) 8 | elif name.startswith('vgg'): 9 | model=get_vgg(name,**kwargs) 10 | else: 11 | raise NotImplementedError(r'''{0} is not an available value. 
\ 12 | Please choose one of the available values in 13 | [resnet18, resnet50, resnet101, resnet152, vgg11, vgg16]'''.format(name)) 14 | return model -------------------------------------------------------------------------------- /SSLRemoteSensing/utils/optims/builder.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | optim_dict={'Adam':optim.Adam, 4 | 'SGD':optim.SGD} 5 | lr_schedule_dict={'stepLR':optim.lr_scheduler.StepLR} 6 | def build_optim(name='Adam',**kwargs): 7 | if name in optim_dict.keys(): 8 | return optim_dict[name](**kwargs) 9 | else: 10 | raise NotImplementedError('{0} not in available values.'.format(name)) 11 | 12 | 13 | def build_lr_schedule(name='stepLR',**kwargs): 14 | if name in lr_schedule_dict.keys(): 15 | return lr_schedule_dict[name](**kwargs) 16 | else: 17 | raise NotImplementedError('{0} not in available values.'.format(name)) -------------------------------------------------------------------------------- /SSLRemoteSensing/utils/path_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def get_filename(file_path,is_suffix=True): 4 | file_name=file_path.replace('\\','/') 5 | file_name=file_name.split('/')[-1] 6 | if is_suffix: 7 | return file_name 8 | else: 9 | index=file_name.rfind('.') 10 | if index>0: 11 | return file_name[0:index] 12 | else: 13 | return file_name 14 | 15 | def get_parent_folder(file_path,with_root=False): 16 | 17 | file_path=file_path.replace('\\','/') 18 | 19 | if os.path.isdir(file_path): 20 | parent_folder=file_path 21 | else: 22 | index = file_path.rfind('/') 23 | parent_folder=file_path[0:index] 24 | if not with_root: 25 | return get_filename(parent_folder) 26 | return parent_folder 27 | 28 | if __name__=='__main__': 29 | path=r'G:\deep_learning\dataSet\UCMerced_LandUse\Images' 30 | print(get_parent_folder(path)) 31 | print(os.listdir(path)) -------------------------------------------------------------------------------- /SSLRemoteSensing/metric/time_metric.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author: Wenyuan Li 3 | @desc: Time Metric 4 | @date: 2020/5/17 5 | ''' 6 | import datetime 7 | 8 | class TimeMetric(object): 9 | 10 | def __init__(self): 11 | self.start_time=datetime.datetime.now() 12 | 13 | def start(self): 14 | self.start_time=datetime.datetime.now() 15 | def reset(self): 16 | self.start_time=datetime.datetime.now() 17 | 18 | 19 | def get_time_ms(self): 20 | self.end_time=datetime.datetime.now() 21 | ms_time=(self.end_time-self.start_time).seconds*1000+\ 22 | (self.end_time-self.start_time).microseconds/1000. 
23 | return ms_time 24 | def get_time(self): 25 | self.end_time = datetime.datetime.now() 26 | return (self.end_time-self.start_time).seconds 27 | 28 | def get_fps(self,toal_frames): 29 | total_time=self.get_time()+1e-6 30 | fps=toal_frames/total_time 31 | return fps -------------------------------------------------------------------------------- /SSLRemoteSensing/datasets/transforms/representation/builder.py: -------------------------------------------------------------------------------- 1 | from .transforms import Rotate,HorizontalFlip,VerticalFlip,ColorJitter,\ 2 | RandomCrop,Resize,Normal 3 | import torchvision 4 | from torchvision.transforms.transforms import RandomGrayscale 5 | transforms_dict={ 6 | 'RandomHorizontalFlip':HorizontalFlip, 7 | 'RandomVerticalFlip':VerticalFlip, 8 | 'Rotate':Rotate, 9 | 'ColorJitter':ColorJitter, 10 | 'RandomCrop':RandomCrop, 11 | 'ToTensor':torchvision.transforms.ToTensor, 12 | 'Resize':Resize, 13 | 'RandomGrayscale':torchvision.transforms.RandomGrayscale, 14 | 'Normal':Normal} 15 | 16 | 17 | def build_transforms(name='NoiseTransforms',**kwargs): 18 | if name in transforms_dict.keys(): 19 | return transforms_dict[name](**kwargs) 20 | else: 21 | raise NotImplementedError('name not in available values.'.format(name)) -------------------------------------------------------------------------------- /SSLRemoteSensing/datasets/transforms/representation/agr_transforms.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @anthor: Wenyuan Li 3 | @desc: Transforms for self-supervised 4 | @date: 2020/5/22 5 | ''' 6 | from .builder import build_transforms 7 | import torchvision 8 | import numpy as np 9 | 10 | class AGRTransforms(object): 11 | 12 | def __init__(self,transforms_cfg:dict,shortcut_cfg,**kwargs): 13 | 14 | self.transforms=[] 15 | for param in transforms_cfg.values(): 16 | self.transforms.append(build_transforms(**param)) 17 | 18 | shortcut_transforms=[] 19 | for param in shortcut_cfg.values(): 20 | shortcut_transforms.append(build_transforms(**param)) 21 | shortcut_transforms=torchvision.transforms.RandomOrder(shortcut_transforms) 22 | self.shortcut_transforms=torchvision.transforms.Compose([shortcut_transforms]) 23 | 24 | def forward(self,img): 25 | agr_label = np.random.randint(0, len(self.transforms), 1)[0] 26 | post_img=self.transforms[agr_label](img) 27 | post_img=self.shortcut_transforms(post_img) 28 | return post_img,agr_label -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import numpy as np 4 | import mmcv 5 | from SSLRemoteSensing.models.builder import builder_models 6 | from SSLRemoteSensing.utils import utils 7 | 8 | parse=argparse.ArgumentParser() 9 | # parse.add_argument('--config_file', 10 | # default=r'configs/vr_resnet50_inapinting_agr_cfg.py',type=str) 11 | parse.add_argument('--config_file',default=r'configs/vr_vgg16_inapinting_agr_examplar_cfg.py',type=str) 12 | # 13 | parse.add_argument('--checkpoints_path',default=None,type=str) 14 | parse.add_argument('--with_imagenet',default=None,type=utils.str2bool) 15 | parse.add_argument('--log_path',default=None,type=str) 16 | 17 | if __name__=='__main__': 18 | args = parse.parse_args() 19 | print(args) 20 | cfg = mmcv.Config.fromfile(args.config_file) 21 | if args.with_imagenet is not None: 22 | cfg['config']['backbone_cfg']['pretrained'] = args.with_imagenet 23 | 
models=builder_models(**cfg['config']) 24 | 25 | run_args={} 26 | 27 | models.run_train_interface(checkpoint_path=args.checkpoints_path, 28 | 29 | log_path=args.log_path) 30 | 31 | -------------------------------------------------------------------------------- /SSLRemoteSensing/losses/examplar_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.functional as F 6 | 7 | 8 | class ExamplarLoss(nn.Module): 9 | 10 | def __init__(self,batch_size,device): 11 | super(ExamplarLoss,self).__init__() 12 | self.batch_size=batch_size 13 | self.loss_mask=torch.zeros([self.batch_size*2,self.batch_size*2],device=device) 14 | for i in range(self.batch_size): 15 | self.loss_mask[2*i,2*i+1]=1 16 | self.loss_mask[2 * i+1, 2 * i] = 1 17 | N = 2 * self.batch_size 18 | self.mask=1-torch.eye(N,N,device=device) 19 | 20 | 21 | 22 | def forward(self,latent1,latent2,eplision=1): 23 | ''' 24 | :param latent1: 25 | :param latent2: 26 | :return: 27 | ''' 28 | latent=torch.stack([latent1,latent2],dim=1) 29 | 30 | N = 2*self.batch_size 31 | latent = latent.reshape((N,-1)) 32 | latent_i=torch.unsqueeze(latent,dim=-1).contiguous() #2N*L*1 33 | 34 | 35 | latent_i=latent_i.repeat((1,1,N)).contiguous() #2N*L*2N 36 | latent_j=torch.transpose(latent_i,0,2).contiguous()#2N*L*2N 37 | 38 | S=torch.sum(latent_i*latent_j,dim=1) #2N*2N 39 | norm_i=torch.norm(latent_i,p=2,dim=1) 40 | norm_j=torch.norm(latent_j,p=2,dim=1) 41 | S=S/(eplision*norm_i*norm_j) #2N*2N 42 | 43 | 44 | S=S*self.mask 45 | 46 | m=nn.LogSoftmax(dim=-1) 47 | loss=m(S) 48 | loss=torch.sum(loss*self.loss_mask)/N 49 | return -loss -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.1 2 | addict==2.2.1 3 | apex==0.1 4 | astunparse==1.6.3 5 | av==6.2.0 6 | beautifulsoup4==4.9.1 7 | cachetools==4.1.1 8 | certifi==2019.9.11 9 | cffi==1.12.3 10 | chainer==6.4.0 11 | chainercv==0.13.1 12 | chardet==3.0.4 13 | cloudpickle==1.2.2 14 | cycler==0.10.0 15 | Cython==0.29.14 16 | cytoolz==0.10.0 17 | dask==2.6.0 18 | decorator==4.4.0 19 | filelock==3.0.12 20 | future==0.18.2 21 | gast==0.3.3 22 | GDAL==3.1.2 23 | google-auth==1.21.1 24 | google-auth-oauthlib==0.4.1 25 | google-pasta==0.2.0 26 | grpcio==1.32.0 27 | h5py==2.10.0 28 | idna==2.8 29 | imagecodecs==2019.5.22 30 | imagecorruptions==1.1.0 31 | imageio==2.6.1 32 | joblib==0.14.1 33 | Keras-Preprocessing==1.1.2 34 | kiwisolver==1.1.0 35 | lxml==4.5.1 36 | Markdown==3.1.1 37 | matplotlib==3.1.1 38 | mkl-service==2.3.0 39 | mmcv==1.1.2 40 | networkx==2.4 41 | numpy==1.17.2 42 | oauthlib==3.1.0 43 | olefile==0.46 44 | opencv-python==4.1.1.26 45 | opt-einsum==3.3.0 46 | pandas==1.0.3 47 | Pillow==6.1.0 48 | protobuf==3.13.0 49 | pyasn1==0.4.8 50 | pyasn1-modules==0.2.8 51 | pyclipper==1.2.0 52 | pycocotools==2.0 53 | pycparser==2.19 54 | pydensecrf==1.0rc3 55 | pyparsing==2.4.2 56 | pyproj==2.6.1.post1 57 | python-dateutil==2.8.0 58 | pytz==2019.3 59 | PyWavelets==1.0.3 60 | pyyaml==5.1.2 61 | regex==2020.7.14 62 | requests==2.22.0 63 | requests-oauthlib==1.3.0 64 | rsa==4.6 65 | scikit-image==0.16.1 66 | scikit-learn==0.22.1 67 | scipy==1.2.0 68 | seaborn==0.10.0 69 | selenium==3.141.0 70 | six==1.12.0 71 | sklearn==0.0 72 | soupsieve==2.0.1 73 | tensorboard==2.3.0 74 | tensorboard-plugin-wit==1.7.0 75 | tensorflow==2.3.0 76 | 
tensorflow-estimator==2.3.0 77 | termcolor==1.1.0 78 | terminaltables==3.1.0 79 | tifffile==2019.7.26 80 | toolz==0.10.0 81 | torch==1.5.0+cu101 82 | torchvision==0.6.0+cu101 83 | tornado==6.0.3 84 | tqdm==4.42.1 85 | typing==3.6.6 86 | typing-extensions==3.6.6 87 | urllib3==1.25.6 88 | Werkzeug==0.16.0 89 | wincertstore==0.2 90 | wrapt==1.12.1 91 | xlrd==1.2.0 92 | xlwt==1.3.0 93 | yapf==0.30.0 94 | -------------------------------------------------------------------------------- /SSLRemoteSensing/datasets/datasets/representation/vr_dataset_inpainting_agr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @anthor: Wenyuan Li 3 | @desc: Datasets for self-supervised 4 | @date: 2020/5/15 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | import glob 10 | import os 11 | import numpy as np 12 | from skimage import io 13 | from skimage import util as sk_utils 14 | from PIL import Image 15 | from SSLRemoteSensing.datasets.transforms.representation import inpainting_transforms,builder 16 | import torch.utils.data as data_utils 17 | from SSLRemoteSensing.datasets.transforms.representation.agr_transforms import AGRTransforms 18 | 19 | class InpaintingAGRDataset(data_utils.Dataset): 20 | 21 | def __init__(self,data_path,data_format,inpainting_transforms_cfg,agr_transforms_cfg, 22 | pre_transforms_cfg,post_transforms_cfg, img_size=256): 23 | super(InpaintingAGRDataset, self).__init__() 24 | 25 | self.data_files=glob.glob(os.path.join(data_path,data_format)) 26 | 27 | self.img_size=img_size 28 | self.inpainting_transforms=inpainting_transforms.InpaintingTransforms(**inpainting_transforms_cfg) 29 | pre_transforms=[] 30 | for param in pre_transforms_cfg.values(): 31 | pre_transforms.append(builder.build_transforms(**param)) 32 | self.pre_transforms=torchvision.transforms.Compose(pre_transforms) 33 | 34 | post_transforms=[] 35 | for param in post_transforms_cfg.values(): 36 | post_transforms.append(builder.build_transforms(**param)) 37 | self.post_transforms=torchvision.transforms.Compose(post_transforms) 38 | 39 | self.agr_transforms =AGRTransforms(**agr_transforms_cfg) 40 | 41 | 42 | def __getitem__(self, item): 43 | 44 | img=Image.open(self.data_files[item]) 45 | img=self.pre_transforms(img) 46 | 47 | inpainting_label=img 48 | 49 | pre_img=img 50 | post_img,agr_label=self.agr_transforms.forward(img) 51 | data=img 52 | data=self.inpainting_transforms(data) 53 | data=self.post_transforms(data) 54 | inpainting_label=self.post_transforms(inpainting_label) 55 | inpainting_mask=torch.abs(inpainting_label-data) 56 | pre_img=self.post_transforms(pre_img) 57 | post_img=self.post_transforms(post_img) 58 | agr_label=torch.tensor(agr_label,dtype=torch.int64) 59 | return data,pre_img,post_img, inpainting_label,inpainting_mask,agr_label 60 | 61 | 62 | def __len__(self): 63 | return len(self.data_files) 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /SSLRemoteSensing/datasets/transforms/representation/inpainting_transforms.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @anthor: Wenyuan Li 3 | @desc: Transforms for self-supervised 4 | @date: 2020/5/15 5 | ''' 6 | import torch 7 | import cv2 8 | import torchvision 9 | import torchvision.transforms.functional as F 10 | import numpy as np 11 | from skimage import util as sk_utils 12 | import random 13 | from PIL import Image 14 | from SSLRemoteSensing.datasets.transforms.representation.transforms 
import ColorJitter,HorizontalFlip,VerticalFlip,Rotate 15 | import cv2 16 | class InpaintingTransforms(object): 17 | 18 | def __init__(self,min_cover_ratio=0.2,max_cover_ratio=1./3, 19 | brightness=0.3,contrast=(0.5,1.5),saturation=(0.5,1.5),hue=(-0.3,0.3)): 20 | self.min_conver_ratio=min_cover_ratio 21 | self.max_cover_ratio=max_cover_ratio 22 | self.colorJitter=ColorJitter(brightness=brightness, 23 | contrast=contrast,saturation=saturation,hue=hue) 24 | 25 | def __call__(self,img): 26 | return self.forward(img) 27 | 28 | def random_block(self,block): 29 | ''' 30 | 对block进行随机的操作 31 | :param block: [3,height,width] 32 | :return: 33 | ''' 34 | 35 | idx = np.random.randint(0, 4, (1,), dtype=np.int64)[0] 36 | if idx==0: 37 | 38 | block=np.flip(block,0) 39 | elif idx==1: 40 | block = np.flip(block, 1) 41 | else: 42 | k=np.random.randint(1, 4, (1,), dtype=int)[0] 43 | k=int(k) 44 | block=np.rot90(block,k) 45 | return block 46 | 47 | def forward(self,img): 48 | dtype=np.ndarray 49 | if isinstance(img,torch.Tensor): 50 | img=img.numpy() 51 | dtype=torch.Tensor 52 | elif isinstance(img,Image.Image): 53 | img=np.array(img) 54 | dtype=Image.Image 55 | ratio=random.random() 56 | ratio=self.min_conver_ratio+ratio*(self.max_cover_ratio-self.min_conver_ratio) 57 | img_height,img_width=img.shape[0:2] 58 | crop_height=int(img_height*ratio) 59 | crop_width=int(img_width*ratio) 60 | crop_size=min(crop_height,crop_width) 61 | x=np.random.randint(0,img_width-crop_size,1)[0] 62 | y=np.random.randint(0,img_height-crop_size,1)[0] 63 | 64 | img[y:y+crop_size,x:x+crop_size]=\ 65 | self.random_block(self.colorJitter(np.copy(img[y:y+crop_size,x:x+crop_size]))) 66 | if dtype==torch.Tensor: 67 | img=torch.from_numpy(img) 68 | elif dtype==Image.Image: 69 | img=Image.fromarray(img) 70 | return img 71 | 72 | 73 | if __name__=='__main__': 74 | import cv2 75 | 76 | img_files=r'F:\uvp_data\total\dior_data\00256_288_0.jpg' 77 | result_file=r'G:\other\inpainting.png' 78 | img=Image.open(img_files) 79 | img=F.resize(img,(256,256)) 80 | trans=InpaintingTransforms(min_cover_ratio=0.6,max_cover_ratio=0.7) 81 | result=trans.forward(img) 82 | result=np.array(result) 83 | cv2.imshow('result', result) 84 | cv2.waitKey() 85 | cv2.imwrite(result_file,result) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Overview. 2 | __Semantic Segmentation of Remote Sensing Images With Self-Supervised Multitask Representation Learning 3 | [Paper](https://ieeexplore.ieee.org/abstract/document/9460820)__ 4 | 5 | Existing deep learning-based remote sensing images semantic segmentation methods require large-scale labeled datasets. However, the annotation of segmentation datasets is often too time-consuming and expensive. To ease the burden of data annotation, self-supervised representation learning methods have emerged recently. However, the semantic segmentation methods need to learn both high-level and low-level features, but most of the existing self-supervised representation learning methods usually focus on one level, which affects the performance of semantic segmentation for remote sensing images. In order to solve this problem, we propose a self-supervised multitask representation learning method to capture effective visual representations of remote sensing images. We design three different pretext tasks and a triplet Siamese network to learn the high-level and low-level image features at the same time. 
The network can be trained without any labeled data, and the trained model can then be fine-tuned on an annotated segmentation dataset. We conduct experiments on the Potsdam and Vaihingen datasets and on the cloud/snow detection dataset Levir_CS to verify the effectiveness of our method. Experimental results show that our proposed method can effectively reduce the demand for labeled data and improve the performance of remote sensing semantic segmentation. Compared with recent state-of-the-art self-supervised representation learning methods and the most commonly used initialization methods (such as random initialization and ImageNet pretraining), our proposed method achieves the best results in most experiments, especially when only a small amount of training data is available. With only 10% to 50% of the labeled data, our method achieves performance comparable to random initialization. 6 | 7 | ![Overview](fig/overall_networks.png) 8 | 9 | In this repository, we implement the training of self-supervised multi-task representation learning for remote sensing images with PyTorch and provide pretrained models. With the code, you can also train on your own dataset by following the instructions below. 10 | 11 | 12 | 13 | # Requirements 14 | 15 | - Python 3.6.7 16 | 17 | - PyTorch 1.7.0 18 | 19 | - torchvision 0.6.0 20 | 21 | - CUDA 10.1 22 | 23 | See also [requirements.txt](requirements.txt). 24 | 25 | # Setup 26 | 27 | 1. Clone this repo. 28 | 29 | `git clone https://github.com/flyakon/SSLRemoteSensing.git` 30 | 31 | `cd SSLRemoteSensing` 32 | 33 | 2. Prepare the training data and put it into the specified folder, such as "../dataset/train_data". 34 | 35 | 3. Modify the config file [vr_vgg16_inapinting_agr_examplar_cfg.py](configs/vr_vgg16_inapinting_agr_examplar_cfg.py) to set the training parameters. 36 | 37 | Some important training parameters: 38 | 39 | ``` 40 | backbone_cfg: which network ("vgg16_bn" or "resnet50") to use as the backbone. 41 | inpainting_head_cfg, agr_head_cfg and examplar_head_cfg: network parameters of the heads for the different pretext tasks. 42 | train_cfg: parameters of the self-supervised representation learning stage. 43 | ``` 44 | 45 | 4. Set the "--config_file" option in [train.py](train.py) to the location of [vr_vgg16_inapinting_agr_examplar_cfg.py](configs/vr_vgg16_inapinting_agr_examplar_cfg.py) and run this file; a typical invocation is shown below. 
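For example, assuming the training images are placed under "../dataset/train_data" as configured above and the command is run from the repository root, self-supervised pretraining with the VGG-16 backbone can be started with:

`python train.py --config_file configs/vr_vgg16_inapinting_agr_examplar_cfg.py`

[train.py](train.py) also accepts `--checkpoints_path`, `--log_path` and `--with_imagenet`; `--with_imagenet` switches ImageNet pretraining of the backbone on or off before self-supervised training starts. Checkpoints and TensorBoard logs are written to the paths given in the config file.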
46 | 47 | # Citation 48 | 49 | If you find the code useful, please cite: 50 | 51 | `````` 52 | @ARTICLE{9460820, 53 | author={Li, Wenyuan and Chen, Hao and Shi, Zhenwei}, 54 | journal={IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing}, 55 | title={Semantic Segmentation of Remote Sensing Images With Self-Supervised Multitask Representation Learning}, 56 | year={2021}, 57 | volume={14}, 58 | number={}, 59 | pages={6438-6450}, 60 | doi={10.1109/JSTARS.2021.3090418}} 61 | `````` 62 | -------------------------------------------------------------------------------- /configs/vr_vgg16_inapinting_agr_examplar_cfg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Inpainting and Argument Recongnition configure file 3 | ''' 4 | img_size=256 5 | mode='train' 6 | config=dict( 7 | name='VRNetsWithInpaintingAGRExamplar', 8 | backbone_cfg=dict( 9 | name='vgg16_bn', 10 | num_classes=None, 11 | in_channels=3, 12 | pretrained=False, 13 | out_keys=('block2','block3','block4','block5','block6') 14 | ), 15 | inpainting_head_cfg=dict( 16 | name='InpaintingHead', 17 | in_channels=512, 18 | out_channels=3, 19 | img_size=img_size, 20 | feat_channels=(512,256,128,64,32) 21 | ), 22 | agr_head_cfg=dict( 23 | name='AGPHead', 24 | in_channels=512, 25 | num_classes=6, 26 | ), 27 | examplar_head_cfg=dict( 28 | name='ExamplarHead', 29 | in_channels=512, 30 | out_channels=512 31 | ), 32 | train_cfg=dict( 33 | batch_size=4, 34 | device='cuda:0', 35 | num_epoch=13, 36 | num_workers=4, 37 | train_data=dict( 38 | 39 | data_path=r'../dataset/train_data', 40 | data_format='*.jpg', 41 | img_size=img_size, 42 | inpainting_transforms_cfg=dict( 43 | min_cover_ratio=0.2,max_cover_ratio=1./3, 44 | brightness=0.3,contrast=(0.5,1.5),saturation=(0.5,1.5),hue=(-0.3,0.3) 45 | ), 46 | pre_transforms_cfg=dict( 47 | RandomHorizontalFlip=dict(name='RandomHorizontalFlip'), 48 | RandomVerticalFlip=dict(name='RandomVerticalFlip'), 49 | Rotate=dict(name='Rotate'), 50 | ), 51 | 52 | agr_transforms_cfg=dict( 53 | transforms_cfg=dict( 54 | RandomHorizontalFlip=dict(name='RandomHorizontalFlip',p=1.), 55 | RandomVerticalFlip=dict(name='RandomVerticalFlip',p=1.), 56 | Rotate_0=dict(name='Rotate',angle=0,p=1), 57 | Rotate_90=dict(name='Rotate',angle=90,p=1), 58 | Rotate_180=dict(name='Rotate',angle=180,p=1), 59 | Rotate_270=dict(name='Rotate',angle=270,p=1), 60 | # RandomCrop=dict(name='RandomCrop', crop_ratio_min=0.7, crop_ratio_max=0.95), 61 | # ColorJitter=dict(name='ColorJitter', brightness=0.3, contrast=(0.5, 1.5), saturation=(0.5, 1.5), 62 | # hue=(-0.3, 0.3)), 63 | ), 64 | shortcut_cfg=dict( 65 | RandomCrop=dict(name='RandomCrop', crop_ratio_min=0.7, crop_ratio_max=0.95), 66 | ColorJitter=dict(name='ColorJitter', brightness=0.3, contrast=(0.5, 1.5), saturation=(0.5, 1.5), 67 | hue=(-0.3, 0.3)), 68 | RandomGrayscale=dict(name='RandomGrayscale',p=0.5), 69 | # Normal=dict(name='Normal') 70 | ), 71 | ), 72 | post_transforms_cfg=dict( 73 | Resize=dict(name='Resize',size=(img_size,img_size)), 74 | ToTensor=dict(name='ToTensor') 75 | ), 76 | ), 77 | 78 | losses=dict( 79 | InpaintingLoss=dict(name='InpaintingLoss'), 80 | AGRLoss=dict(name='CrossEntropyLoss'), 81 | ExamplarLoss=dict(name='ExamplarLoss'), 82 | factors=[20,1,1] 83 | ), 84 | 85 | optimizer=dict( 86 | name='Adam', 87 | lr=0.0005 88 | ), 89 | 90 | checkpoints=dict( 91 | checkpoints_path=r'checkpoints/checkpoints_vgg16_inpainting_agr_examplar_total', 92 | save_step=1, 93 | ), 94 | lr_schedule=dict( 95 | name='stepLR', 
96 | step_size=2, 97 | gamma=0.95 98 | ), 99 | log=dict( 100 | log_path=r'log/log_vgg16_inpainting_agr_examplar_total', 101 | log_step=50, 102 | with_vis=False, 103 | vis_path=r'' 104 | ), 105 | ), 106 | ) 107 | 108 | -------------------------------------------------------------------------------- /configs/vr_resnet50_inapinting_agr_examplar_cfg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Inpainting and Argument Recongnition configure file 3 | ''' 4 | img_size=256 5 | mode='train' 6 | config=dict( 7 | name='VRNetsWithInpaintingAGRExamplar', 8 | backbone_cfg=dict( 9 | name='resnet50', 10 | num_classes=None, 11 | in_channels=3, 12 | pretrained=False, 13 | out_keys=('block2','block3','block4','block5') 14 | ), 15 | inpainting_head_cfg=dict( 16 | name='InpaintingHead', 17 | in_channels=2048, 18 | out_channels=3, 19 | img_size=img_size, 20 | feat_channels=(1024,512,256,64,32) 21 | ), 22 | agr_head_cfg=dict( 23 | name='AGPHead', 24 | in_channels=2048, 25 | num_classes=6, 26 | ), 27 | examplar_head_cfg=dict( 28 | name='ExamplarHead', 29 | in_channels=2048, 30 | out_channels=512 31 | ), 32 | train_cfg=dict( 33 | batch_size=4, 34 | device='cuda:0', 35 | num_epoch=13, 36 | num_workers=4, 37 | train_data=dict( 38 | 39 | data_path=r'../dataset/train_data', 40 | data_format='*.jpg', 41 | img_size=img_size, 42 | inpainting_transforms_cfg=dict( 43 | min_cover_ratio=0.2,max_cover_ratio=1./3, 44 | brightness=0.3,contrast=(0.5,1.5),saturation=(0.5,1.5),hue=(-0.3,0.3) 45 | ), 46 | pre_transforms_cfg=dict( 47 | RandomHorizontalFlip=dict(name='RandomHorizontalFlip'), 48 | RandomVerticalFlip=dict(name='RandomVerticalFlip'), 49 | Rotate=dict(name='Rotate'), 50 | ), 51 | 52 | agr_transforms_cfg=dict( 53 | transforms_cfg=dict( 54 | RandomHorizontalFlip=dict(name='RandomHorizontalFlip',p=1.), 55 | RandomVerticalFlip=dict(name='RandomVerticalFlip',p=1.), 56 | Rotate_0=dict(name='Rotate',angle=0,p=1), 57 | Rotate_90=dict(name='Rotate',angle=90,p=1), 58 | Rotate_180=dict(name='Rotate',angle=180,p=1), 59 | Rotate_270=dict(name='Rotate',angle=270,p=1), 60 | # RandomCrop=dict(name='RandomCrop', crop_ratio_min=0.7, crop_ratio_max=0.95), 61 | # ColorJitter=dict(name='ColorJitter', brightness=0.3, contrast=(0.5, 1.5), saturation=(0.5, 1.5), 62 | # hue=(-0.3, 0.3)), 63 | ), 64 | shortcut_cfg=dict( 65 | RandomCrop=dict(name='RandomCrop', crop_ratio_min=0.7, crop_ratio_max=0.95), 66 | ColorJitter=dict(name='ColorJitter', brightness=0.3, contrast=(0.5, 1.5), saturation=(0.5, 1.5), 67 | hue=(-0.3, 0.3)), 68 | RandomGrayscale=dict(name='RandomGrayscale',p=0.5), 69 | # Normal=dict(name='Normal') 70 | ), 71 | ), 72 | post_transforms_cfg=dict( 73 | Resize=dict(name='Resize',size=(img_size,img_size)), 74 | ToTensor=dict(name='ToTensor') 75 | ), 76 | ), 77 | 78 | losses=dict( 79 | InpaintingLoss=dict(name='InpaintingLoss'), 80 | AGRLoss=dict(name='CrossEntropyLoss'), 81 | ExamplarLoss=dict(name='ExamplarLoss'), 82 | factors=[20,1,1] 83 | ), 84 | 85 | optimizer=dict( 86 | name='Adam', 87 | lr=0.0005 88 | ), 89 | 90 | checkpoints=dict( 91 | checkpoints_path=r'checkpoints/checkpoints_resnet50_imagenet_inpainting_agr_examplar_total', 92 | save_step=1, 93 | ), 94 | lr_schedule=dict( 95 | name='stepLR', 96 | step_size=2, 97 | gamma=0.95 98 | ), 99 | log=dict( 100 | log_path=r'log/log_resnet50_inpainting_agr_examplar_total', 101 | log_step=50, 102 | with_vis=False, 103 | vis_path=r'' 104 | ), 105 | ), 106 | test_cfg=dict( 107 | batch_size=1, 108 | test_data=dict( 109 | data_path=r'', 110 | 
label_path=r'', 111 | data_format='*.jpg', 112 | label_format='*.png' 113 | ), 114 | 115 | checkpoints=dict( 116 | checkpoints_path=r'', 117 | ), 118 | log=dict( 119 | with_vis=True, 120 | vis_path=r'' 121 | ), 122 | ) 123 | ) 124 | 125 | -------------------------------------------------------------------------------- /SSLRemoteSensing/models/backbone/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | from torchvision.models.utils import load_state_dict_from_url 5 | from collections import OrderedDict 6 | 7 | class NamedSequential(nn.Sequential): 8 | 9 | def forward(self, input: torch.Tensor): 10 | result_dict=OrderedDict() 11 | i=2 12 | for name,module in self._modules.items(): 13 | # print(name,module) 14 | input = module(input) 15 | if type(module).__name__ =='MaxPool2d': 16 | key='block%d'%i 17 | result_dict[key]=input 18 | i+=1 19 | 20 | return input,result_dict 21 | 22 | 23 | 24 | 25 | model_urls = { 26 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 27 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 28 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 29 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 30 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 31 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 32 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 33 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 34 | } 35 | 36 | 37 | class VGG(nn.Module): 38 | 39 | def __init__(self, features, num_classes=1000,out_keys=None, init_weights=True): 40 | super(VGG, self).__init__() 41 | self.out_keys=out_keys 42 | self.features = features 43 | self.num_classes=num_classes 44 | if self.num_classes is not None: 45 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 46 | self.classifier = nn.Sequential( 47 | nn.Linear(512 * 7 * 7, 4096), 48 | nn.ReLU(True), 49 | nn.Dropout(), 50 | nn.Linear(4096, 4096), 51 | nn.ReLU(True), 52 | nn.Dropout(), 53 | nn.Linear(4096, num_classes), 54 | ) 55 | if init_weights: 56 | print('initialize_weights') 57 | self._initialize_weights() 58 | 59 | def forward(self, x): 60 | x,endpoints = self.features(x) 61 | if self.num_classes is not None: 62 | x = self.avgpool(x) 63 | x = torch.flatten(x, 1) 64 | x = self.classifier(x) 65 | if self.out_keys is None: 66 | endpoints = {} 67 | else: 68 | endpoints = {key: endpoints[key] for key in self.out_keys} 69 | return x,endpoints 70 | 71 | def _initialize_weights(self): 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 75 | if m.bias is not None: 76 | nn.init.constant_(m.bias, 0) 77 | elif isinstance(m, nn.BatchNorm2d): 78 | nn.init.constant_(m.weight, 1) 79 | nn.init.constant_(m.bias, 0) 80 | elif isinstance(m, nn.Linear): 81 | nn.init.normal_(m.weight, 0, 0.01) 82 | nn.init.constant_(m.bias, 0) 83 | 84 | 85 | def make_layers(cfg, batch_norm=False,in_channels = 3): 86 | layers = [] 87 | for v in cfg: 88 | if v == 'M': 89 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 90 | else: 91 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 92 | if batch_norm: 93 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 94 | else: 95 | layers += [conv2d, nn.ReLU(inplace=True)] 96 | in_channels = v 97 | return NamedSequential(*layers) 98 | 99 | 
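# NamedSequential (defined above) stores the feature map produced after every
# MaxPool2d layer under the keys 'block2', 'block3', ..., so for the 'D' (vgg16)
# configuration below the recorded endpoints are 'block2'-'block6'. These keys
# are what the out_keys tuples in the config files refer to, and they are the
# endpoints consumed as skip connections by the inpainting decoder.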
100 | cfgs = { 101 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 102 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 103 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 104 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 105 | } 106 | 107 | arch_cfg_dict={'vgg11':'A','vgg13':'B','vgg16':'D','vgg19':'E'} 108 | 109 | def _vgg(arch, cfg, batch_norm, pretrained, progress,in_channels,num_classes, **kwargs): 110 | if pretrained: 111 | kwargs['init_weights'] = False 112 | 113 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm,in_channels=in_channels),num_classes=num_classes, **kwargs) 114 | 115 | if pretrained: 116 | if batch_norm: 117 | name='%s_bn'%arch 118 | else: 119 | name=arch 120 | state_dict = load_state_dict_from_url(model_urls[name], 121 | progress=progress) 122 | 123 | if in_channels != 3: 124 | keys = state_dict.keys() 125 | keys = [x for x in keys if 'features.0' in x] 126 | for key in keys: 127 | del state_dict[key] 128 | if num_classes is None: 129 | keys = state_dict.keys() 130 | keys = [x for x in keys if 'classifier' in x] 131 | for key in keys: 132 | del state_dict[key] 133 | model.load_state_dict(state_dict) 134 | elif num_classes != 1000: 135 | keys = state_dict.keys() 136 | keys = [x for x in keys if 'classifier.6' in x] 137 | for key in keys: 138 | del state_dict[key] 139 | model.load_state_dict(state_dict, strict=False) 140 | else: 141 | model.load_state_dict(state_dict) 142 | return model 143 | 144 | 145 | def get_vgg(name='vgg16',pretrained=True,progress=True, 146 | num_classes=1000,out_keys=None,in_channels=3,**kwargs): 147 | ''' 148 | Get resnet model with name. 
149 | :param name: vgg model name 150 | :param pretrained: If True, returns a model pre-trained on ImageNet 151 | ''' 152 | 153 | if pretrained and num_classes !=1000: 154 | print('warning: num_class is not equal to 1000, which will cause some parameters to fail to load!') 155 | if pretrained and in_channels !=3: 156 | print('warning: in_channels is not equal to 3, which will cause some parameters to fail to load!') 157 | batch_norm=True if 'bn' in name else False 158 | if batch_norm: 159 | name=name.replace('_bn','') 160 | print('batchnorm:{0}'.format(batch_norm)) 161 | return _vgg(arch=name,cfg=arch_cfg_dict[name],batch_norm=batch_norm,pretrained=pretrained,progress=progress, 162 | num_classes=num_classes,out_keys=out_keys,in_channels=in_channels,**kwargs) 163 | 164 | if __name__=='__main__': 165 | model=get_vgg('vgg16_bn',pretrained=True,num_classes=None,in_channels=3, 166 | out_keys=('block2','block3','block4','block5','block6')) 167 | x=torch.rand([2,3,256,256]) 168 | result,endponits=model.forward(x) 169 | print(result.shape) 170 | print(endponits['block6'].shape) 171 | print(endponits['block5'].shape) 172 | print(endponits['block4'].shape) 173 | print(endponits['block3'].shape) 174 | print(endponits['block2'].shape) 175 | 176 | print(endponits.keys()) 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /SSLRemoteSensing/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import shutil 4 | import cv2 5 | from skimage import measure 6 | import numpy as np 7 | from skimage import io 8 | import argparse 9 | 10 | 11 | isprs_map={0:(255,255,255), 12 | 1:(0,0,255), 13 | 2:(0,255,255), 14 | 3:(0,255,0), 15 | 4:(255,255,0), 16 | 5:(255,0,0)} 17 | 18 | 19 | def voc_colormap(N=21): 20 | def bitget(val, idx): 21 | return ((val & (1 << idx)) != 0) 22 | 23 | cmap = np.zeros((N, 3), dtype=np.uint8) 24 | for i in range(N): 25 | r = g = b = 0 26 | c = i 27 | for j in range(8): 28 | r |= (bitget(c, 0) << 7 - j) 29 | g |= (bitget(c, 1) << 7 - j) 30 | b |= (bitget(c, 2) << 7 - j) 31 | c >>= 3 32 | cmap[i, :] = [r, g, b] 33 | return cmap 34 | 35 | VOC_COLOR_MAP=voc_colormap(21) 36 | 37 | def load_model(model_path,current_epoch=None,prefix='cub_model'): 38 | 39 | ''' 40 | 载入模型,默认model文件夹中有一个latest.pth文件 41 | :param state_dict: 42 | :param model_path: 43 | :return: 44 | ''' 45 | if os.path.isfile(model_path): 46 | model_file=model_path 47 | else: 48 | if current_epoch is None: 49 | model_file=os.path.join(model_path,'latest.pth') 50 | else: 51 | model_file = os.path.join(model_path, '%s_%d.pth'%(prefix,current_epoch)) 52 | if not os.path.exists(model_file): 53 | print('warning:%s does not exist!'%model_file) 54 | return None,0,0 55 | print('start to resume from %s' % model_file) 56 | 57 | state_dict=torch.load(model_file) 58 | 59 | try: 60 | glob_step=state_dict.pop('gobal_step') 61 | except KeyError: 62 | print('warning:glob_step not in state_dict.') 63 | glob_step=0 64 | try: 65 | epoch=state_dict.pop('epoch') 66 | except KeyError: 67 | print('glob_step not in state_dict.') 68 | epoch=0 69 | 70 | return state_dict,epoch+1,glob_step 71 | 72 | def save_model(model,model_path,epoch,global_step,prefix='cub_model',max_keep=10): 73 | 74 | if isinstance(model,torch.nn.Module): 75 | state_dict=model.state_dict() 76 | else: 77 | state_dict=model 78 | state_dict['epoch']=epoch 79 | state_dict['gobal_step']=global_step 80 | 81 | 
model_file=os.path.join(model_path,'%s_%d.pth'%(prefix,epoch)) 82 | torch.save(state_dict,model_file) 83 | shutil.copy(model_file,os.path.join(model_path,'latest.pth')) 84 | 85 | # if epoch>max_keep: 86 | # for i in range(0,epoch-max_keep): 87 | # model_file=os.path.join(model_path,'%s_%d.pth'%(prefix,epoch)) 88 | # if os.path.exists(model_file): 89 | # os.remove(model_file) 90 | 91 | def localization(img,thres=120): 92 | loc_list=[] 93 | condinate_set=set() 94 | label_mask=img>thres 95 | label_img = np.array(label_mask, np.uint8) 96 | kernel = np.ones((3, 3), np.uint8) 97 | label_img = cv2.morphologyEx(label_img, cv2.MORPH_CLOSE, kernel) 98 | label_mask=label_img>0 99 | label = measure.label(label_mask) 100 | props = measure.regionprops(label) 101 | # img=cv2.cvtColor(img,cv2.COLOR_GRAY2RGB) 102 | 103 | for prop in props: 104 | # if prop.area<=200: 105 | # continue 106 | bbox=prop.bbox 107 | if bbox in condinate_set: 108 | continue 109 | condinate_set.add(bbox) 110 | loc_list.append([bbox[1],bbox[0],bbox[3],bbox[2]]) 111 | loc_list=np.array(loc_list) 112 | return loc_list 113 | 114 | def calc_iou(prediction,gt): 115 | inter_xmin=np.maximum(prediction[:,0],gt[:,0]) 116 | inter_ymin=np.maximum(prediction[:,1],gt[:,1]) 117 | inter_xmax = np.minimum(prediction[:, 2], gt[:, 2]) 118 | inter_ymax = np.minimum(prediction[:, 3], gt[:, 3]) 119 | inter_height=np.maximum(inter_ymax-inter_ymin,0) 120 | inter_width=np.maximum(inter_xmax-inter_xmin,0) 121 | inter_area=inter_height*inter_width 122 | 123 | total_area=(prediction[:,3]-prediction[:,1])*(prediction[:,2]-prediction[:,0])+\ 124 | (gt[:,3]-gt[:,1])*(gt[:,2]-gt[:,0])-inter_area 125 | iou=inter_area/total_area 126 | return iou 127 | 128 | def rectangle(img,bndboxes,color): 129 | img_height, img_width, _ = img.shape 130 | img = cv2.rectangle(img, (int(bndboxes[0] * img_width), int(bndboxes[1] * img_height)), 131 | (int(bndboxes[2] * img_width), int(bndboxes[3] * img_height)), color, 2) 132 | return img 133 | 134 | def vis_info(img,gt_box,pred_box,result_path,file_name): 135 | 136 | img=img*255 137 | if isinstance(img,torch.Tensor): 138 | img=img.cpu().numpy() 139 | img=img.astype(np.uint8) 140 | img=np.transpose(img,[1,2,0]) 141 | img=cv2.cvtColor(img,cv2.COLOR_RGB2BGR) 142 | 143 | if isinstance(gt_box, torch.Tensor): 144 | gt_box=gt_box.cpu().numpy() 145 | img=rectangle(img,gt_box,[0,255,0]) 146 | if isinstance(pred_box,torch.Tensor): 147 | pred_box = pred_box.cpu().detach().numpy() 148 | img = rectangle(img, pred_box, [0, 0, 255]) 149 | 150 | cv2.imwrite(os.path.join(result_path,'{0}.jpg'.format(file_name)),img) 151 | 152 | def vis_fcn_result(img:torch.Tensor,label:torch.Tensor,result:torch.Tensor, 153 | result_path,file_name,as_binary=False): 154 | img=img.cpu().numpy() 155 | result=result.cpu().detach().numpy() 156 | img=img*255 157 | img=img.astype(np.uint8) 158 | img=np.transpose(img,(1,2,0)) 159 | 160 | # result=result[1]>0.5 161 | if as_binary: 162 | result=np.argmax(result,axis=0) 163 | result = result * 127 164 | # result=np.where(result[1]>0.5,255,0) 165 | else: 166 | result=result[1]*255 167 | result=result.astype(np.uint8) 168 | 169 | 170 | label=label.cpu().numpy()*127 171 | label=label.astype(np.uint8) 172 | 173 | io.imsave(os.path.join(result_path, '{0}_label.jpg'.format(file_name)), label) 174 | io.imsave(os.path.join(result_path,'{0}_img.jpg'.format(file_name)),img) 175 | io.imsave(os.path.join(result_path, '{0}_result.png'.format(file_name)), result) 176 | 177 | def vis_nap(img,block,idx_logits,result_path,global_step): 178 | 
img = img.cpu().detach().numpy() 179 | img=img[0] 180 | img = (img+1)/2 * 255 181 | img = img.astype(np.uint8) 182 | img = np.transpose(img, (1, 2, 0)) 183 | 184 | block = block.cpu().detach().numpy() 185 | block=block[0] 186 | block = (block+1)/2* 255 187 | block = block.astype(np.uint8) 188 | block = np.transpose(block, (1, 2, 0)) 189 | block_size,_,_=block.shape 190 | 191 | idx=idx_logits.cpu().detach().numpy() 192 | idx=np.argmax(idx,axis=-1)[0] 193 | 194 | row = np.mod(idx, 3) * block_size 195 | col = (idx // 3) * block_size 196 | recovery_img=np.copy(img) 197 | recovery_img[col:col+block_size,row:row+block_size,:]=block 198 | io.imsave(os.path.join(result_path, '{0}_img.jpg'.format(global_step)), img) 199 | io.imsave(os.path.join(result_path, '{0}_recovery.jpg'.format(global_step)), recovery_img) 200 | 201 | def vis_nap_argu(img,block,result_path,global_step): 202 | img = img.cpu().detach().numpy() 203 | img=img[0] 204 | img = (img+1)/2 * 255 205 | img = img.astype(np.uint8) 206 | img = np.transpose(img, (1, 2, 0)) 207 | 208 | block = block.cpu().detach().numpy() 209 | block=block[0] 210 | block = (block+1)/2* 255 211 | block = block.astype(np.uint8) 212 | block = np.transpose(block, (1, 2, 0)) 213 | 214 | recovery_img=block 215 | io.imsave(os.path.join(result_path, '{0}_img.jpg'.format(global_step)), img) 216 | io.imsave(os.path.join(result_path, '{0}_recovery.jpg'.format(global_step)), recovery_img) 217 | 218 | def vis_isprs_result(img,label,result,result_path,file_name): 219 | img = img.cpu().numpy() 220 | result = result.cpu().detach().numpy() 221 | img = img * 255 222 | img = img.astype(np.uint8) 223 | img = np.transpose(img, (1, 2, 0)) 224 | 225 | img_height,img_width,_=img.shape 226 | result = np.argmax(result, axis=0) 227 | result_map=-np.ones([img_height,img_width,3]) 228 | for i in range(6): 229 | result_map=np.where(result[:,:,np.newaxis]==i,isprs_map[i],result_map) 230 | assert (result_map==-1).any() == False 231 | result_map = result_map.astype(np.uint8) 232 | 233 | label = label.cpu().numpy() 234 | label_map=-np.ones([img_height,img_width,3]) 235 | for i in range(6): 236 | label_map = np.where(label[:, :, np.newaxis] == i, isprs_map[i], label_map) 237 | label_map = label_map.astype(np.uint8) 238 | assert (label_map == -1).any() == False 239 | io.imsave(os.path.join(result_path, '{0}_label.jpg'.format(file_name)), label_map) 240 | io.imsave(os.path.join(result_path, '{0}_img.jpg'.format(file_name)), img) 241 | io.imsave(os.path.join(result_path, '{0}_result.jpg'.format(file_name)), result_map) 242 | 243 | def vis_voc_result(img,label,result,result_path,file_name): 244 | img = img.cpu().numpy() 245 | result = result.cpu().detach().numpy() 246 | img = img * 255 247 | img = img.astype(np.uint8) 248 | img = np.transpose(img, (1, 2, 0)) 249 | 250 | img_height,img_width,_=img.shape 251 | result = np.argmax(result, axis=0) 252 | result_map=-np.ones([img_height,img_width,3]) 253 | for i in range(21): 254 | result_map=np.where(result[:,:,np.newaxis]==i,VOC_COLOR_MAP[i],result_map) 255 | assert (result_map==-1).any() == False 256 | result_map = result_map.astype(np.uint8) 257 | 258 | label = label.cpu().numpy() 259 | label_map=-np.ones([img_height,img_width,3]) 260 | for i in range(21): 261 | label_map = np.where(label[:, :, np.newaxis] == i, VOC_COLOR_MAP[i], label_map) 262 | label_map = label_map.astype(np.uint8) 263 | assert (label_map == -1).any() == False 264 | io.imsave(os.path.join(result_path, '{0}_label.jpg'.format(file_name)), label_map) 265 | 
io.imsave(os.path.join(result_path, '{0}_img.jpg'.format(file_name)), img) 266 | io.imsave(os.path.join(result_path, '{0}_result.jpg'.format(file_name)), result_map) 267 | 268 | 269 | def vis_agp_img(img,agu_img,result_path,global_step): 270 | img = img.cpu().detach().numpy() 271 | img=img[0] 272 | img = (img+1)/2 * 255 273 | img = img.astype(np.uint8) 274 | img = np.transpose(img, (1, 2, 0)) 275 | 276 | agu_img = agu_img.cpu().detach().numpy() 277 | agu_img=agu_img[0] 278 | agu_img = (agu_img+1)/2* 255 279 | agu_img = agu_img.astype(np.uint8) 280 | agu_img = np.transpose(agu_img, (1, 2, 0)) 281 | 282 | io.imsave(os.path.join(result_path, '{0}_img.jpg'.format(global_step)), img) 283 | io.imsave(os.path.join(result_path, '{0}_agu.jpg'.format(global_step)), agu_img) 284 | 285 | def str2bool(v): 286 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 287 | return True 288 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 289 | return False 290 | elif v.lower in ('none','null','-1'): 291 | return None 292 | else: 293 | raise argparse.ArgumentTypeError('Unsupported value encountered.') 294 | 295 | if __name__=='__main__': 296 | map=voc_colormap(21) 297 | print(map) -------------------------------------------------------------------------------- /SSLRemoteSensing/models/representation/vr_nets_inpainting_agr_examplar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @anthor: Wenyuan Li 3 | @desc: Networks for self-supervised 4 | @date: 2020/5/20 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | from ..backbone.builder import build_backbone 10 | from SSLRemoteSensing.datasets.datasets.representation.vr_dataset_inpainting_agr import InpaintingAGRDataset 11 | import datetime 12 | from torch.utils.tensorboard import SummaryWriter 13 | import os 14 | import torch.utils.data as data_utils 15 | import SSLRemoteSensing.utils.utils as utils 16 | from SSLRemoteSensing.utils.optims.builder import build_optim,build_lr_schedule 17 | from SSLRemoteSensing.losses.builder import builder_loss 18 | 19 | class VRNetsWithInpaintingAGRExamplar(nn.Module): 20 | 21 | def __init__(self,backbone_cfg:dict, 22 | inpainting_head_cfg:dict, 23 | agr_head_cfg:dict, 24 | examplar_head_cfg:dict, 25 | train_cfg:dict, 26 | **kwargs): 27 | super(VRNetsWithInpaintingAGRExamplar,self).__init__() 28 | self.backbone=build_backbone(**backbone_cfg) 29 | self.build_arch(inpainting_head_cfg) 30 | self.build_agr_arch(agr_head_cfg) 31 | self.build_examplar_arch(examplar_head_cfg) 32 | self.train_cfg=train_cfg 33 | 34 | 35 | def build_examplar_arch(self,examplar_cfg): 36 | in_channels = examplar_cfg['in_channels'] 37 | out_channels = examplar_cfg['out_channels'] 38 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 39 | self.softmax = nn.Softmax(dim=1) 40 | self.examplar_fc = nn.Sequential(*[ 41 | nn.Linear(in_channels, out_channels*2), 42 | nn.Sigmoid(), 43 | nn.Linear(2*out_channels, out_channels) 44 | ]) 45 | 46 | def build_agr_arch(self,agr_cfg): 47 | in_channels = agr_cfg['in_channels'] 48 | num_classes=agr_cfg['num_classes'] 49 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 50 | self.softmax = nn.Softmax(dim=1) 51 | self.agr_conv=nn.Sequential(*[ 52 | nn.Conv2d(in_channels*2,in_channels,kernel_size=3,padding=1,stride=1), 53 | nn.BatchNorm2d(in_channels), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1), 56 | nn.BatchNorm2d(in_channels), 57 | nn.ReLU(inplace=True) 58 | ]) 59 | self.agr_fc=nn.Sequential(*[ 60 | 
nn.Linear(in_channels,2*in_channels), 61 | nn.ReLU(inplace=True), 62 | nn.Dropout(), 63 | nn.Linear(2*in_channels,num_classes) 64 | ]) 65 | 66 | def build_arch(self,head_cfg): 67 | 68 | in_channels=head_cfg['in_channels'] 69 | out_channels=head_cfg['out_channels'] 70 | feat_channels=head_cfg['feat_channels'] 71 | 72 | self.trans_conv1=nn.Sequential(*[ 73 | nn.Conv2d(in_channels, feat_channels[0], kernel_size=3, stride=1, padding=1), 74 | nn.BatchNorm2d(feat_channels[0]), 75 | nn.ReLU(inplace=True)]) 76 | self.trans_conv2 = nn.Sequential(*[ 77 | nn.Conv2d(feat_channels[0], feat_channels[1], kernel_size=3, stride=1, padding=1), 78 | nn.BatchNorm2d(feat_channels[1]), 79 | nn.ReLU(inplace=True)]) 80 | self.trans_conv3 = nn.Sequential(*[ 81 | nn.Conv2d(feat_channels[1], feat_channels[2], kernel_size=3, stride=1, padding=1), 82 | nn.BatchNorm2d(feat_channels[2]), 83 | nn.ReLU(inplace=True)]) 84 | self.trans_conv4 = nn.Sequential(*[ 85 | nn.Conv2d(feat_channels[2], feat_channels[3], kernel_size=3, stride=1, padding=1), 86 | nn.BatchNorm2d(feat_channels[3]), 87 | nn.ReLU(inplace=True)]) 88 | self.trans_conv5 = nn.Sequential(*[ 89 | nn.Conv2d(feat_channels[3], feat_channels[4], kernel_size=3, stride=1, padding=1), 90 | nn.BatchNorm2d(feat_channels[4]), 91 | nn.ReLU(inplace=True)]) 92 | 93 | self.pred_conv=nn.Conv2d(feat_channels[4], out_channels, kernel_size=3, stride=1, padding=1) 94 | 95 | def forward_agr(self,pre_img,post_img): 96 | pre_logits,_=self.backbone(pre_img) 97 | post_logits,_ = self.backbone(post_img) 98 | x=torch.cat((pre_logits,post_logits),dim=1) 99 | 100 | x=self.agr_conv(x) 101 | x=self.avg_pool(x) 102 | x=torch.flatten(x,1) 103 | logits=self.agr_fc(x) 104 | return logits,pre_logits,post_logits 105 | 106 | 107 | def forward_inpainting(self, x, **kwargs): 108 | x,endpoints=self.backbone(x) 109 | 110 | x = torch.nn.functional.interpolate(x, align_corners=True, scale_factor=2, mode='bilinear') 111 | x = self.trans_conv1(x) 112 | if 'block5' in endpoints.keys(): 113 | x=x+endpoints['block5'] 114 | 115 | x = torch.nn.functional.interpolate(x, align_corners=True, scale_factor=2, mode='bilinear') 116 | x = self.trans_conv2(x) 117 | if 'block4' in endpoints.keys(): 118 | x = x + endpoints['block4'] 119 | 120 | x = torch.nn.functional.interpolate(x, align_corners=True, scale_factor=2, mode='bilinear') 121 | x = self.trans_conv3(x) 122 | if 'block3' in endpoints.keys(): 123 | x = x + endpoints['block3'] 124 | 125 | x = torch.nn.functional.interpolate(x, align_corners=True, scale_factor=2, mode='bilinear') 126 | x = self.trans_conv4(x) 127 | if 'block2' in endpoints.keys(): 128 | x = x + endpoints['block2'] 129 | 130 | x = torch.nn.functional.interpolate(x, align_corners=True, scale_factor=2, mode='bilinear') 131 | x = self.trans_conv5(x) 132 | if 'block1' in endpoints.keys(): 133 | x = x + endpoints['block1'] 134 | 135 | logits=self.pred_conv(x) 136 | 137 | return logits 138 | 139 | def forward_examplar(self,pre,post): 140 | pre=self.avg_pool(pre) 141 | pre=torch.flatten(pre,1) 142 | pre=self.examplar_fc(pre) 143 | 144 | post = self.avg_pool(post) 145 | post = torch.flatten(post, 1) 146 | post = self.examplar_fc(post) 147 | return pre,post 148 | 149 | def forward(self,inpainting_img,pre_img,post_img): 150 | inpainting_logits=self.forward_inpainting(inpainting_img) 151 | agr_logits,pre_logits,post_logits=self.forward_agr(pre_img,post_img) 152 | pre_logits,post_logits=self.forward_examplar(pre_logits,post_logits) 153 | return inpainting_logits,agr_logits,pre_logits,post_logits 154 | 155 | 
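    # Illustrative usage of the joint forward pass (with the vgg16 configuration:
    # img_size=256, agr num_classes=6, examplar out_channels=512):
    #
    #   model = builder_models(**cfg['config'])   # as done in train.py
    #   inpainting_logits, agr_logits, pre_logits, post_logits = \
    #       model(inpainting_img, pre_img, post_img)
    #
    # inpainting_logits has shape [N, 3, 256, 256] (reconstruction of the corrupted
    # image), agr_logits has shape [N, 6] (scores over the six geometric transforms
    # of the AGR task), and pre_logits/post_logits have shape [N, 512] (exemplar
    # embeddings of the augmented image pair fed to ExamplarLoss).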
156 | def run_train_interface(self,**kwargs): 157 | batch_size=self.train_cfg['batch_size'] 158 | device=self.train_cfg['device'] 159 | num_epoch=self.train_cfg['num_epoch'] 160 | num_workers=self.train_cfg['num_workers'] 161 | checkpoint_path=self.train_cfg['checkpoints']['checkpoints_path'] 162 | save_step=self.train_cfg['checkpoints']['save_step'] 163 | log_path=self.train_cfg['log']['log_path'] 164 | log_step=self.train_cfg['log']['log_step'] 165 | with_vis=self.train_cfg['log']['with_vis'] 166 | vis_path=self.train_cfg['log']['vis_path'] 167 | self.to(device) 168 | if not os.path.exists(checkpoint_path): 169 | os.makedirs(checkpoint_path) 170 | train_dataset=InpaintingAGRDataset(**self.train_cfg['train_data']) 171 | train_dataloader=data_utils.DataLoader(train_dataset,batch_size,shuffle=True,num_workers=num_workers, 172 | drop_last=True) 173 | inpainting_criterion=builder_loss(**self.train_cfg['losses']['InpaintingLoss']) 174 | agr_criterion=builder_loss(**self.train_cfg['losses']['AGRLoss']) 175 | examplar_criterion=builder_loss(**self.train_cfg['losses']['ExamplarLoss'], 176 | device=device,batch_size=batch_size) 177 | loss_factors=self.train_cfg['losses']['factors'] 178 | 179 | optimizer = build_optim(params=self.parameters(), **self.train_cfg['optimizer']) 180 | 181 | if 'lr_schedule' in self.train_cfg.keys(): 182 | lr_schedule=build_lr_schedule(optimizer=optimizer,**self.train_cfg['lr_schedule']) 183 | 184 | state_dict, current_epoch, global_step = utils.load_model(checkpoint_path) 185 | if state_dict is not None: 186 | print('resume from epoch %d global_step %d' % (current_epoch, global_step)) 187 | self.load_state_dict(state_dict, strict=True) 188 | 189 | summary = SummaryWriter(log_path) 190 | start_time = datetime.datetime.now() 191 | 192 | for epoch in range(current_epoch,num_epoch): 193 | for i, data in enumerate(train_dataloader): 194 | # data.to(device) 195 | global_step += 1 196 | input,pre_img,post_img,inpainting_label,attention_mask,agr_label=data 197 | input=input.to(device) 198 | pre_img=pre_img.to(device) 199 | post_img=post_img.to(device) 200 | inpainting_label=inpainting_label.to(device) 201 | attention_mask=attention_mask.to(device) 202 | agr_label=agr_label.to(device) 203 | 204 | self.train() 205 | outputs,agr_logits,pre_logits,post_logits = self.forward(input,pre_img,post_img) 206 | inpainting_loss = inpainting_criterion(outputs,inpainting_label,attention_mask) 207 | inpainting_loss=inpainting_loss*loss_factors[0] 208 | agr_loss=agr_criterion(agr_logits,agr_label)*loss_factors[1] 209 | examplar_loss=examplar_criterion(pre_logits,post_logits)*loss_factors[2] 210 | loss=inpainting_loss+agr_loss+examplar_loss 211 | optimizer.zero_grad() 212 | loss.backward() 213 | optimizer.step() 214 | 215 | if global_step % log_step == 0: 216 | end_time = datetime.datetime.now() 217 | total_time = ((end_time - start_time).seconds * 1000 + (end_time - start_time).microseconds / 1000) 218 | total_time = total_time / log_step / batch_size 219 | fps = 1 / total_time * 1000 220 | start_time = datetime.datetime.now() 221 | # if args.is_vis: 222 | # utils.vis_nap(noise_img, block_recovery, block_logits, 'train_result', global_step) 223 | print("[Epoch %d/%d] [Batch %d/%d] [inpainting loss:%f,agr loss:%f,examplar loss:%f,loss: %f] [fps:%f]" % ( 224 | epoch, num_epoch, i, len(train_dataloader), 225 | inpainting_loss.item(),agr_loss.item(),examplar_loss.item(),loss.item(), fps)) 226 | summary.add_scalar('inpainting loss', inpainting_loss, global_step) 227 | summary.add_scalar('agr loss', 
agr_loss, global_step) 228 | summary.add_scalar('examplar loss', examplar_loss, global_step) 229 | summary.add_scalar('total loss', loss, global_step) 230 | 231 | if 'lr_schedule' in self.train_cfg.keys(): 232 | lr_schedule.step(epoch=epoch) 233 | summary.add_scalar('learning_rate',optimizer.state_dict()['param_groups'][0]['lr'],global_step) 234 | if epoch % save_step == 0: 235 | print('save model') 236 | utils.save_model(self, checkpoint_path, 237 | epoch, global_step, max_keep=200) 238 | 239 | 240 | 241 | 242 | 243 | 244 | -------------------------------------------------------------------------------- /SSLRemoteSensing/datasets/transforms/representation/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data_utils 3 | import os 4 | import numpy as np 5 | from PIL import Image, ImageFilter 6 | import torchvision.transforms 7 | import random 8 | import torchvision.transforms.functional as F 9 | import cv2 10 | import sys 11 | import collections 12 | import numbers 13 | import datetime 14 | 15 | if sys.version_info < (3, 3): 16 | Sequence = collections.Sequence 17 | Iterable = collections.Iterable 18 | else: 19 | Sequence = collections.abc.Sequence 20 | Iterable = collections.abc.Iterable 21 | 22 | 23 | class VerticalFlip(): 24 | """Vertically flip the given PIL Image and bounding boxes randomly with a given probability. 25 | 26 | Args: 27 | p (float): probability of the image being flipped. Default value is 0.5 28 | """ 29 | 30 | def __init__(self, p=0.5,with_idx=False): 31 | self.p = p 32 | self.with_idx=with_idx 33 | 34 | def __call__(self, img): 35 | """ 36 | Args: 37 | img ( Image): Image to be flipped. 38 | Returns: 39 | PIL Image: Randomly flipped image. 40 | """ 41 | label=0 42 | if random.random() width, then image will be rescaled to 342 | (size * height / width, size) 343 | interpolation (int, optional): Desired interpolation. Default is 344 | ``PIL.Image.BILINEAR`` 345 | """ 346 | 347 | def __init__(self, size, interpolation=Image.BILINEAR): 348 | assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2) 349 | self.size = size 350 | self.interpolation = interpolation 351 | 352 | def __call__(self, img): 353 | """ 354 | Args: 355 | img (PIL Image): Image to be scaled. 356 | 357 | Returns: 358 | PIL Image: Rescaled image. 
359 | """ 360 | if isinstance(img,np.ndarray): 361 | img=Image.fromarray(img) 362 | elif isinstance(img,torch.Tensor): 363 | img=Image.fromarray(img.numpy()) 364 | 365 | img=F.resize(img, self.size, self.interpolation) 366 | return img 367 | 368 | 369 | def __repr__(self): 370 | 371 | return self.__class__.__name__ + '(size={0})'.format(self.size) 372 | 373 | class Normal(object): 374 | 375 | def __call__(self,img): 376 | return img -------------------------------------------------------------------------------- /SSLRemoteSensing/models/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | from torchvision.models.utils import load_state_dict_from_url 5 | 6 | model_urls = { 7 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 8 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 9 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 10 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 11 | } 12 | 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=dilation, groups=groups, bias=False, dilation=dilation) 19 | 20 | def conv1x1(in_planes, out_planes, stride=1): 21 | """1x1 convolution""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 28 | base_width=64, dilation=1, norm_layer=None): 29 | super(BasicBlock, self).__init__() 30 | if norm_layer is None: 31 | norm_layer = nn.BatchNorm2d 32 | if groups != 1 or base_width != 64: 33 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 34 | if dilation > 1: 35 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 36 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = norm_layer(planes) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3(planes, planes) 41 | self.bn2 = norm_layer(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | identity = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | identity = self.downsample(x) 57 | 58 | out += identity 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | 67 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 68 | base_width=64, dilation=1, norm_layer=None): 69 | super(Bottleneck, self).__init__() 70 | if norm_layer is None: 71 | norm_layer = nn.BatchNorm2d 72 | width = int(planes * (base_width / 64.)) * groups 73 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 74 | self.conv1 = conv1x1(inplanes, width) 75 | self.bn1 = norm_layer(width) 76 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 77 | self.bn2 = norm_layer(width) 78 | self.conv3 = conv1x1(width, planes * self.expansion) 79 | self.bn3 = norm_layer(planes * self.expansion) 80 | self.relu = nn.ReLU(inplace=True) 81 | self.downsample = 
downsample 82 | self.stride = stride 83 | 84 | def forward(self, x): 85 | identity = x 86 | 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv3(out) 96 | out = self.bn3(out) 97 | 98 | if self.downsample is not None: 99 | identity = self.downsample(x) 100 | 101 | out += identity 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | class ResNet(nn.Module): 107 | 108 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 109 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 110 | norm_layer=None,out_keys=None,in_channels=3): 111 | super(ResNet, self).__init__() 112 | if norm_layer is None: 113 | norm_layer = nn.BatchNorm2d 114 | self._norm_layer = norm_layer 115 | self.out_keys=out_keys 116 | self.num_classes=num_classes 117 | self.inplanes = 64 118 | self.dilation = 1 119 | if replace_stride_with_dilation is None: 120 | # each element in the tuple indicates if we should replace 121 | # the 2x2 stride with a dilated convolution instead 122 | replace_stride_with_dilation = [False, False, False] 123 | if len(replace_stride_with_dilation) != 3: 124 | raise ValueError("replace_stride_with_dilation should be None " 125 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 126 | self.groups = groups 127 | self.base_width = width_per_group 128 | self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, 129 | bias=False) 130 | self.bn1 = norm_layer(self.inplanes) 131 | self.relu = nn.ReLU(inplace=True) 132 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 133 | self.layer1 = self._make_layer(block, 64, layers[0]) 134 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 135 | dilate=replace_stride_with_dilation[0]) 136 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 137 | dilate=replace_stride_with_dilation[1]) 138 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 139 | dilate=replace_stride_with_dilation[2]) 140 | if self.num_classes is not None: 141 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 142 | self.fc = nn.Linear(512 * block.expansion, self.num_classes) 143 | 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 147 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 148 | nn.init.constant_(m.weight, 1) 149 | nn.init.constant_(m.bias, 0) 150 | 151 | # Zero-initialize the last BN in each residual branch, 152 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
153 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 154 | if zero_init_residual: 155 | for m in self.modules(): 156 | if isinstance(m, Bottleneck): 157 | nn.init.constant_(m.bn3.weight, 0) 158 | elif isinstance(m, BasicBlock): 159 | nn.init.constant_(m.bn2.weight, 0) 160 | 161 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 162 | norm_layer = self._norm_layer 163 | downsample = None 164 | previous_dilation = self.dilation 165 | if dilate: 166 | self.dilation *= stride 167 | stride = 1 168 | if stride != 1 or self.inplanes != planes * block.expansion: 169 | downsample = nn.Sequential( 170 | conv1x1(self.inplanes, planes * block.expansion, stride), 171 | norm_layer(planes * block.expansion), 172 | ) 173 | 174 | layers = [] 175 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 176 | self.base_width, previous_dilation, norm_layer)) 177 | self.inplanes = planes * block.expansion 178 | for _ in range(1, blocks): 179 | layers.append(block(self.inplanes, planes, groups=self.groups, 180 | base_width=self.base_width, dilation=self.dilation, 181 | norm_layer=norm_layer)) 182 | 183 | return nn.Sequential(*layers) 184 | 185 | def forward(self, x): 186 | endpoints=dict() 187 | x = self.conv1(x) 188 | x = self.bn1(x) 189 | x = self.relu(x) 190 | endpoints['block2'] = x 191 | x = self.maxpool(x) 192 | x = self.layer1(x) 193 | endpoints['block3'] = x 194 | x = self.layer2(x) 195 | endpoints['block4'] = x 196 | x = self.layer3(x) 197 | endpoints['block5'] = x 198 | x = self.layer4(x) 199 | endpoints['block6'] = x 200 | 201 | if self.num_classes is not None: 202 | x = self.avgpool(x) 203 | x = torch.flatten(x, 1) 204 | x = self.fc(x) 205 | if self.out_keys is None: 206 | endpoints={} 207 | else: 208 | endpoints={key:endpoints[key] for key in self.out_keys} 209 | return x,endpoints 210 | 211 | def _resnet(arch, block, layers, pretrained, progress,num_classes=1000,in_channels=3, **kwargs): 212 | model = ResNet(block, layers,num_classes,in_channels=in_channels, **kwargs) 213 | if pretrained: 214 | 215 | state_dict = load_state_dict_from_url(model_urls[arch], 216 | progress=progress) 217 | 218 | if in_channels !=3: 219 | keys = state_dict.keys() 220 | keys = [x for x in keys if 'conv1' in x] 221 | for key in keys: 222 | del state_dict[key] 223 | if num_classes is None: 224 | keys = state_dict.keys() 225 | keys = [x for x in keys if 'fc' in x] 226 | for key in keys: 227 | del state_dict[key] 228 | model.load_state_dict(state_dict) 229 | elif num_classes !=1000: 230 | keys = state_dict.keys() 231 | keys = [x for x in keys if 'fc' in x] 232 | for key in keys: 233 | del state_dict[key] 234 | model.load_state_dict(state_dict,strict=False) 235 | else: 236 | model.load_state_dict(state_dict) 237 | 238 | return model 239 | 240 | def _resnet18(name='resnet18',pretrained=True,progress=True, num_classes=1000,out_keys=None,**kwargs): 241 | r"""ResNet-18 model from 242 | `"Deep Residual Learning for Image Recognition" `_ 243 | 244 | Args: 245 | pretrained (bool): If True, returns a model pre-trained on ImageNet 246 | progress (bool): If True, displays a progress bar of the download to stderr 247 | """ 248 | return _resnet(name, BasicBlock, [2, 2, 2, 2], pretrained, progress, 249 | num_classes=num_classes,out_keys=out_keys, **kwargs) 250 | 251 | def _resnet50(name='resnet50',pretrained=False, progress=True,num_classes=1000,out_keys=None, **kwargs): 252 | r"""ResNet-50 model from 253 | `"Deep Residual Learning for Image Recognition" `_ 
254 | 
255 |     Args:
256 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
257 |         progress (bool): If True, displays a progress bar of the download to stderr
258 |     """
259 |     return _resnet(name, Bottleneck, [3, 4, 6, 3], pretrained, progress,
260 |                    num_classes=num_classes,out_keys=out_keys,
261 |                    **kwargs)
262 | 
263 | 
264 | def _resnet101(name='resnet101',pretrained=False, progress=True, num_classes=1000,out_keys=None,**kwargs):
265 |     r"""ResNet-101 model from
266 |     `"Deep Residual Learning for Image Recognition" `_
267 | 
268 |     Args:
269 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
270 |         progress (bool): If True, displays a progress bar of the download to stderr
271 |     """
272 |     return _resnet(name, Bottleneck, [3, 4, 23, 3], pretrained, progress,
273 |                    num_classes=num_classes, out_keys=out_keys,
274 |                    **kwargs)
275 | 
276 | 
277 | def _resnet152(name='resnet152',pretrained=False, progress=True,num_classes=1000,out_keys=None,**kwargs):
278 |     r"""ResNet-152 model from
279 |     `"Deep Residual Learning for Image Recognition" `_
280 | 
281 |     Args:
282 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
283 |         progress (bool): If True, displays a progress bar of the download to stderr
284 |     """
285 |     return _resnet(name, Bottleneck, [3, 8, 36, 3], pretrained, progress,
286 |                    num_classes=num_classes, out_keys=out_keys,
287 |                    **kwargs)
288 | 
289 | def get_resnet(name='resnet50',pretrained=True,progress=True,
290 |                num_classes=1000,out_keys=None,in_channels=3,**kwargs):
291 |     '''
292 |     Get resnet model with name.
293 |     :param name: resnet model name, optional values: [resnet18, resnet50, resnet101, resnet152]
294 |     :param pretrained: If True, returns a model pre-trained on ImageNet
295 |     '''
296 | 
297 |     if pretrained and num_classes !=1000:
298 |         print('warning: num_classes is not equal to 1000, which will cause some parameters to fail to load!')
299 |     if pretrained and in_channels !=3:
300 |         print('warning: in_channels is not equal to 3, which will cause some parameters to fail to load!')
301 | 
302 |     if name=='resnet18':
303 |         return _resnet18(name=name,pretrained=pretrained,progress=progress,
304 |                          num_classes=num_classes,out_keys=out_keys,in_channels=in_channels,**kwargs)
305 |     elif name=='resnet50':
306 |         return _resnet50(name=name,pretrained=pretrained,progress=progress,
307 |                          num_classes=num_classes,out_keys=out_keys,in_channels=in_channels,**kwargs)
308 |     elif name=='resnet101':
309 |         return _resnet101(name=name,pretrained=pretrained,progress=progress,
310 |                           num_classes=num_classes,out_keys=out_keys,in_channels=in_channels,**kwargs)
311 |     elif name=='resnet152':
312 |         return _resnet152(name=name,pretrained=pretrained,progress=progress,
313 |                           num_classes=num_classes,out_keys=out_keys,in_channels=in_channels,**kwargs)
314 |     else:
315 |         raise NotImplementedError(r'''{0} is not an available value.
316 |             Please choose one of the available values in
317 |             [resnet18, resnet50, resnet101, resnet152]'''.format(name))
318 | 
319 | if __name__=='__main__':
320 |     model=get_resnet('resnet50',pretrained=True,num_classes=None,in_channels=3,
321 |                      out_keys=('block2','block3','block4','block5','block6'))
322 |     x=torch.rand([2,3,256,256])
323 |     result,endpoints=model.forward(x)
324 |     print(result.shape)
325 |     print(endpoints['block6'].shape)
326 |     print(endpoints['block5'].shape)
327 |     print(endpoints['block4'].shape)
328 |     print(endpoints['block3'].shape)
329 |     print(endpoints['block2'].shape)
330 | 
331 |     print(endpoints.keys())
--------------------------------------------------------------------------------
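For orientation, the run_train_interface method shown earlier is driven entirely by the nested self.train_cfg dictionary. The sketch below reconstructs a minimal train_cfg from the keys that method actually reads; it is illustrative only — the concrete values, the train_data arguments forwarded to InpaintingAGRDataset, and the optimizer, scheduler and loss hyper-parameters are assumptions, not the contents of the real config files under configs/.

# Illustrative sketch of a train_cfg dict for run_train_interface.
# Key names mirror what the training loop reads; values are assumptions,
# not the repository's actual configuration.
train_cfg = dict(
    batch_size=16,
    device='cuda:0',
    num_epoch=100,
    num_workers=4,
    checkpoints=dict(
        checkpoints_path='checkpoints/vr_resnet50',   # created if it does not exist
        save_step=5,                                  # save a checkpoint every 5 epochs
    ),
    log=dict(
        log_path='logs/vr_resnet50',                  # TensorBoard SummaryWriter directory
        log_step=50,                                  # print/log every 50 iterations
        with_vis=False,
        vis_path='vis/vr_resnet50',
    ),
    train_data=dict(
        # forwarded verbatim to InpaintingAGRDataset(**train_data);
        # the argument name below is hypothetical
        data_path='data/train',
    ),
    losses=dict(
        # each entry is unpacked into builder_loss(**cfg); 'name' selects the loss class
        InpaintingLoss=dict(name='InpaintingLoss'),
        AGRLoss=dict(name='CrossEntropyLoss'),
        ExamplarLoss=dict(name='ExamplarLoss'),       # device and batch_size are added by the caller
        factors=[1.0, 1.0, 1.0],                      # weights for the inpainting, AGR and exemplar terms
    ),
    optimizer=dict(name='Adam', lr=1e-4),             # unpacked into build_optim(params=..., **optimizer)
    lr_schedule=dict(name='stepLR', step_size=30, gamma=0.1),  # unpacked into build_lr_schedule(optimizer=..., **lr_schedule)
)

A dictionary of this shape would then be supplied to the model as train_cfg, and calling run_train_interface() starts the self-supervised pre-training loop; the shipped files under configs/ are the authoritative versions of these settings.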