├── .idea
│   ├── .gitignore
│   ├── .name
│   ├── AMTNet.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   └── modules.xml
├── README.md
├── __init__.py
├── __pycache__
│   ├── data_utils.cpython-37.pyc
│   ├── data_utils.cpython-39.pyc
│   ├── train_options.cpython-39.pyc
│   ├── train_options_CLCD.cpython-39.pyc
│   ├── train_options_HRSCD.cpython-39.pyc
│   ├── train_options_LEVIR.cpython-39.pyc
│   └── train_options_WHU.cpython-39.pyc
├── data_utils.py
├── loss
│   ├── __pycache__
│   │   ├── losses.cpython-37.pyc
│   │   └── losses.cpython-39.pyc
│   └── losses.py
├── model
│   ├── MSCANet.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── backbone.cpython-39.pyc
│   │   ├── modules.cpython-39.pyc
│   │   └── network.cpython-39.pyc
│   ├── backbone.py
│   ├── dtcdscn.py
│   ├── modules.py
│   ├── network.py
│   ├── siamunet_conc.py
│   ├── siamunet_diff.py
│   └── unet.py
├── test.py
├── train.py
├── train_options.py
├── utils.py
└── visualize_results.py

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/.name:
--------------------------------------------------------------------------------
1 | losses.py
--------------------------------------------------------------------------------
/.idea/AMTNet.iml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AMTNet
2 | The PyTorch implementation of AMTNet from the paper ["An attention-based multiscale transformer network for remote sensing image
3 | change detection"](https://www.sciencedirect.com/science/article/abs/pii/S092427162300182X?CMX_ID=&SIS_ID=&dgcid=STMJ_AUTH_SERV_PUBLISHED&utm_acid=276849605&utm_campaign=STMJ_AUTH_SERV_PUBLISHED&utm_in=DM391842&utm_medium=email&utm_source=AC_), published in the ["ISPRS Journal of Photogrammetry and Remote Sensing"](https://www.sciencedirect.com/journal/isprs-journal-of-photogrammetry-and-remote-sensing).
4 | # Requirements
5 | * Python 3.9
6 | * PyTorch 1.12
7 | # Dataset
8 | * Download the [CLCD Dataset](https://pan.baidu.com/share/init?surl=Un-bVxUm1N9IHiDOXLLHlg&pwd=miu2)
9 | * Download the [HRSCD Dataset](https://ieee-dataport.org/open-access/hrscd-high-resolution-semantic-change-detection-dataset)
10 | * Download the [WHU-CD Dataset](http://gpcv.whu.edu.cn/data/building_dataset.html)
11 | * Download the [LEVIR-CD Dataset](http://chenhao.in/LEVIR/)
12 | ```
13 | Prepare the datasets in the following structure and set their paths in train_options.py
14 | ├─Train
15 | │  ├─time1
16 | │  ├─time2
17 | │  └─label
18 | ├─Test
19 | │  ├─time1
20 | │  ├─time2
21 | │  └─label
22 | ```
23 | # Train
24 | ```
25 | python train.py
26 | ```
27 | All the hyperparameters can be adjusted in train_options.py
28 | # Model zoo
29 | The trained models and their scores can be downloaded from [Baidu Cloud]().
30 | # Acknowledgments
31 | This code is heavily borrowed from [MSCANet](https://github.com/liumency/CropLand-CD) and [changer](https://github.com/likyoo/open-cd/tree/main).
32 | # Citation
33 | If you find this repo useful for your research, please consider citing the paper as follows:
34 | ```
35 | @article{liu2023attention,
36 |   title={An attention-based multiscale transformer network for remote sensing image change detection},
37 |   author={Liu, Wei and Lin, Yiyuan and Liu, Weijia and Yu, Yongtao and Li, Jonathan},
38 |   journal={ISPRS Journal of Photogrammetry and Remote Sensing},
39 |   volume={202},
40 |   pages={599--609},
41 |   year={2023},
42 |   publisher={Elsevier}
43 | }
44 | ```
45 |
46 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__init__.py
--------------------------------------------------------------------------------
/__pycache__/data_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/data_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/data_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/data_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/train_options.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/train_options.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/train_options_CLCD.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/train_options_CLCD.cpython-39.pyc
--------------------------------------------------------------------------------
/__pycache__/train_options_HRSCD.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/train_options_HRSCD.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/train_options_LEVIR.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/train_options_LEVIR.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/train_options_WHU.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/__pycache__/train_options_WHU.cpython-39.pyc -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | from os.path import join 3 | import torch 4 | from PIL import Image, ImageEnhance 5 | from torch.utils.data.dataset import Dataset 6 | from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize 7 | import numpy as np 8 | import torchvision.transforms as transforms 9 | import os 10 | import imageio 11 | 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in ['.tif','.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']) 15 | 16 | def calMetric_iou(predict, label): 17 | tp = np.sum(np.logical_and(predict == 1, label == 1)) 18 | fp = np.sum(predict==1) 19 | fn = np.sum(label == 1) 20 | return tp,fp+fn-tp 21 | 22 | 23 | def getDataList(img_path): 24 | dataline = open(img_path, 'r').readlines() 25 | datalist =[] 26 | for line in dataline: 27 | temp = line.strip('\n') 28 | datalist.append(temp) 29 | return datalist 30 | 31 | 32 | def make_one_hot(input, num_classes): 33 | """Convert class index tensor to one hot encoding tensor. 
34 | 35 | Args: 36 | input: A tensor of shape [N, 1, *] 37 | num_classes: An int of number of class 38 | Returns: 39 | A tensor of shape [N, num_classes, *] 40 | """ 41 | shape = np.array(input.shape) 42 | shape[1] = num_classes 43 | shape = tuple(shape) 44 | result = torch.zeros(shape) 45 | result = result.scatter_(1, input.cpu(), 1) 46 | return result 47 | 48 | 49 | def get_transform(convert=True, normalize=False): 50 | transform_list = [] 51 | if convert: 52 | transform_list += [transforms.ToTensor()] 53 | if normalize: 54 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 55 | (0.5, 0.5, 0.5))] 56 | return transforms.Compose(transform_list) 57 | 58 | 59 | 60 | class LoadDatasetFromFolder(Dataset): 61 | def __init__(self, args, hr1_path, hr2_path, lab_path): 62 | super(LoadDatasetFromFolder, self).__init__() 63 | # 获取图片列表 64 | datalist = [name for name in os.listdir(hr1_path) for item in args.suffix if 65 | os.path.splitext(name)[1] == item] 66 | 67 | self.hr1_filenames = [join(hr1_path, x) for x in datalist if is_image_file(x)] 68 | self.hr2_filenames = [join(hr2_path, x) for x in datalist if is_image_file(x)] 69 | self.lab_filenames = [join(lab_path, x) for x in datalist if is_image_file(x)] 70 | 71 | self.transform = get_transform(convert=True, normalize=True) # convert to tensor and normalize to [-1,1] 72 | self.label_transform = get_transform() # only convert to tensor 73 | 74 | def __getitem__(self, index): 75 | hr1_img = self.transform(Image.open(self.hr1_filenames[index]).convert('RGB')) 76 | # lr2_img = self.transform(Image.open(self.lr2_filenames[index]).convert('RGB')) 77 | hr2_img = self.transform(Image.open(self.hr2_filenames[index]).convert('RGB')) 78 | 79 | label = self.label_transform(Image.open(self.lab_filenames[index])) 80 | label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0) 81 | 82 | return hr1_img, hr2_img, label 83 | 84 | def __len__(self): 85 | return len(self.hr1_filenames) 86 | 87 | 88 | class TestDatasetFromFolder(Dataset): 89 | def __init__(self, args, Time1_dir, Time2_dir, Label_dir): 90 | super(TestDatasetFromFolder, self).__init__() 91 | 92 | datalist = [name for name in os.listdir(Time1_dir) for item in args.suffix if 93 | os.path.splitext(name)[1] == item] 94 | 95 | self.image1_filenames = [join(Time1_dir, x) for x in datalist if is_image_file(x)] 96 | self.image2_filenames = [join(Time2_dir, x) for x in datalist if is_image_file(x)] 97 | self.image3_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)] 98 | 99 | self.transform = get_transform(convert=True, normalize=True) # convert to tensor and normalize to [-1,1] 100 | self.label_transform = get_transform() 101 | 102 | def __getitem__(self, index): 103 | image1 = self.transform(Image.open(self.image1_filenames[index]).convert('RGB')) 104 | image2 = self.transform(Image.open(self.image2_filenames[index]).convert('RGB')) 105 | 106 | label = self.label_transform(Image.open(self.image3_filenames[index])) 107 | label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0) 108 | 109 | image_name = self.image1_filenames[index].split('/', -1) 110 | image_name = image_name[len(image_name)-1] 111 | 112 | return image1, image2, label, image_name 113 | 114 | def __len__(self): 115 | return len(self.image1_filenames) 116 | 117 | 118 | class trainImageAug(object): 119 | def __init__(self, crop = True, augment = True, angle = 30): 120 | self.crop =crop 121 | self.augment = augment 122 | self.angle = angle 123 | 124 | def __call__(self, image1, image2, mask): 125 | if self.crop: 126 
| w = np.random.randint(0,256) 127 | h = np.random.randint(0,256) 128 | box = (w, h, w+256, h+256) 129 | image1 = image1.crop(box) 130 | image2 = image2.crop(box) 131 | mask = mask.crop(box) 132 | if self.augment: 133 | prop = np.random.uniform(0, 1) 134 | if prop < 0.15: 135 | image1 = image1.transpose(Image.FLIP_LEFT_RIGHT) 136 | image2 = image2.transpose(Image.FLIP_LEFT_RIGHT) 137 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 138 | elif prop < 0.3: 139 | image1 = image1.transpose(Image.FLIP_TOP_BOTTOM) 140 | image2 = image2.transpose(Image.FLIP_TOP_BOTTOM) 141 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM) 142 | elif prop < 0.5: 143 | image1 = image1.rotate(transforms.RandomRotation.get_params([-self.angle, self.angle])) 144 | image2 = image2.rotate(transforms.RandomRotation.get_params([-self.angle, self.angle])) 145 | mask = mask.rotate(transforms.RandomRotation.get_params([-self.angle, self.angle])) 146 | 147 | return image1, image2, mask 148 | 149 | def get_transform(convert=True, normalize=False): 150 | transform_list = [] 151 | if convert: 152 | transform_list += [ 153 | transforms.ToTensor(), 154 | ] 155 | if normalize: 156 | transform_list += [ 157 | # transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), 158 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))] 159 | return transforms.Compose(transform_list) 160 | 161 | 162 | class DA_DatasetFromFolder(Dataset): 163 | def __init__(self, Image_dir1, Image_dir2, Label_dir, crop=True, augment = True, angle = 30): 164 | super(DA_DatasetFromFolder, self).__init__() 165 | # 获取图片列表 166 | datalist = os.listdir(Image_dir1) 167 | self.image_filenames1 = [join(Image_dir1, x) for x in datalist if is_image_file(x)] 168 | self.image_filenames2 = [join(Image_dir2, x) for x in datalist if is_image_file(x)] 169 | self.label_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)] 170 | self.data_augment = trainImageAug(crop=crop, augment = augment, angle=angle) 171 | self.img_transform = get_transform(convert=True, normalize=True) 172 | self.lab_transform = get_transform() 173 | 174 | def __getitem__(self, index): 175 | image1 = Image.open(self.image_filenames1[index]).convert('RGB') 176 | image2 = Image.open(self.image_filenames2[index]).convert('RGB') 177 | label = Image.open(self.label_filenames[index]) 178 | image1, image2, label = self.data_augment(image1, image2, label) 179 | image1, image2 = self.img_transform(image1), self.img_transform(image2) 180 | label = self.lab_transform(label) 181 | label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0) 182 | return image1, image2, label 183 | 184 | def __len__(self): 185 | return len(self.image_filenames1) -------------------------------------------------------------------------------- /loss/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/loss/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /loss/__pycache__/losses.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/loss/__pycache__/losses.cpython-39.pyc -------------------------------------------------------------------------------- /loss/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import torch.nn.functional as F 3 | from torch import nn 4 | 5 | class cross_entropy(nn.Module): 6 | def __init__(self, weight=None, reduction='mean',ignore_index=256): 7 | super(cross_entropy, self).__init__() 8 | self.weight = weight 9 | self.ignore_index =ignore_index 10 | self.reduction = reduction 11 | 12 | 13 | def forward(self,input, target): 14 | target = target.long() 15 | if target.dim() == 4: 16 | target = torch.squeeze(target, dim=1) 17 | if input.shape[-1] != target.shape[-1]: 18 | input = F.interpolate(input, size=target.shape[1:], mode='bilinear',align_corners=True) 19 | 20 | return F.cross_entropy(input=input, target=target, weight=self.weight, 21 | ignore_index=self.ignore_index, reduction=self.reduction) 22 | -------------------------------------------------------------------------------- /model/MSCANet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | 7 | class Residual(nn.Module): 8 | def __init__(self, fn): 9 | super().__init__() 10 | self.fn = fn 11 | def forward(self, x, **kwargs): 12 | return self.fn(x, **kwargs) + x 13 | 14 | 15 | class Residual2(nn.Module): 16 | def __init__(self, fn): 17 | super().__init__() 18 | self.fn = fn 19 | def forward(self, x, x2, **kwargs): 20 | return self.fn(x, x2, **kwargs) + x 21 | 22 | 23 | class PreNorm(nn.Module): 24 | def __init__(self, dim, fn): 25 | super().__init__() 26 | self.norm = nn.LayerNorm(dim) 27 | self.fn = fn 28 | def forward(self, x, **kwargs): 29 | return self.fn(self.norm(x), **kwargs) 30 | 31 | class FeedForward(nn.Module): 32 | def __init__(self, dim, hidden_dim, dropout = 0.): 33 | super().__init__() 34 | self.net = nn.Sequential( 35 | nn.Linear(dim, hidden_dim), 36 | nn.GELU(), 37 | nn.Dropout(dropout), 38 | nn.Linear(hidden_dim, dim), 39 | nn.Dropout(dropout) 40 | ) 41 | def forward(self, x): 42 | return self.net(x) 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 46 | super().__init__() 47 | inner_dim = dim_head * heads 48 | project_out = not (heads == 1 and dim_head == dim) 49 | 50 | self.heads = heads 51 | self.scale = dim_head ** -0.5 52 | 53 | self.attend = nn.Softmax(dim = -1) 54 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 55 | 56 | self.to_out = nn.Sequential( 57 | nn.Linear(inner_dim, dim), 58 | nn.Dropout(dropout) 59 | ) if project_out else nn.Identity() 60 | 61 | def forward(self, x): 62 | qkv = self.to_qkv(x).chunk(3, dim = -1) 63 | 64 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 65 | 66 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 67 | 68 | attn = self.attend(dots) 69 | 70 | out = torch.matmul(attn, v) 71 | 72 | out = rearrange(out, 'b h n d -> b n (h d)') 73 | out = self.to_out(out) 74 | 75 | return out 76 | 77 | class Transformer(nn.Module): 78 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 79 | super().__init__() 80 | self.layers = nn.ModuleList([]) 81 | for _ in range(depth): 82 | self.layers.append(nn.ModuleList([ 83 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 84 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 85 | ])) 86 | def forward(self, x): 87 | for attn, ff in self.layers: 88 | x = attn(x) + x 89 | x = ff(x) + x 90 | return x 91 | 92 | 93 | 94 | class Cross_Attention(nn.Module): 95 | def __init__(self, dim, heads = 8, dim_head 
= 64, dropout = 0., softmax=True): 96 | super().__init__() 97 | inner_dim = dim_head * heads 98 | self.heads = heads 99 | self.scale = dim ** -0.5 100 | 101 | self.softmax = softmax 102 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 103 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 104 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 105 | 106 | self.to_out = nn.Sequential( 107 | nn.Linear(inner_dim, dim), 108 | nn.Dropout(dropout) 109 | ) 110 | 111 | def forward(self, x, m, mask = None): 112 | 113 | b, n, _, h = *x.shape, self.heads 114 | q = self.to_q(x) 115 | k = self.to_k(m) 116 | v = self.to_v(m) 117 | 118 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v]) 119 | 120 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 121 | mask_value = -torch.finfo(dots.dtype).max 122 | 123 | if mask is not None: 124 | mask = F.pad(mask.flatten(1), (1, 0), value = True) 125 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 126 | mask = mask[:, None, :] * mask[:, :, None] 127 | dots.masked_fill_(~mask, mask_value) 128 | del mask 129 | 130 | if self.softmax: 131 | attn = dots.softmax(dim=-1) 132 | else: 133 | attn = dots 134 | 135 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 136 | out = rearrange(out, 'b h n d -> b n (h d)') 137 | out = self.to_out(out) 138 | 139 | return out 140 | 141 | class PreNorm2(nn.Module): 142 | def __init__(self, dim, fn): 143 | super().__init__() 144 | self.norm = nn.LayerNorm(dim) 145 | self.fn = fn 146 | def forward(self, x, x2, **kwargs): 147 | return self.fn(self.norm(x), self.norm(x2), **kwargs) 148 | 149 | 150 | class TransformerDecoder(nn.Module): 151 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True): 152 | super().__init__() 153 | self.layers = nn.ModuleList([]) 154 | for _ in range(depth): 155 | self.layers.append(nn.ModuleList([ 156 | Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads, 157 | dim_head = dim_head, dropout = dropout, 158 | softmax=softmax))), 159 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) 160 | ])) 161 | def forward(self, x, m, mask = None): 162 | """target(query), memory""" 163 | for attn, ff in self.layers: 164 | x = attn(x, m, mask = mask) 165 | x = ff(x) 166 | return x -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | from torch.autograd import Variable 8 | 9 | class Model(nn.Module): 10 | def __init__(self, args, ckp): 11 | super(Model, self).__init__() 12 | print('Making model...') 13 | 14 | self.scale = args.scale 15 | self.idx_scale = 0 16 | self.self_ensemble = args.self_ensemble 17 | self.chop = args.chop 18 | self.precision = args.precision 19 | self.cpu = args.cpu 20 | self.device = torch.device('cpu' if args.cpu else 'cuda') 21 | self.n_GPUs = args.n_GPUs 22 | self.save_models = args.save_models 23 | 24 | module = import_module('model.' 
+ args.model.lower()) 25 | self.model = module.make_model(args).to(self.device) 26 | if args.precision == 'half': self.model.half() 27 | 28 | if not args.cpu and args.n_GPUs > 1: 29 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 30 | 31 | self.load( 32 | ckp.dir, 33 | pre_train=args.pre_train, 34 | resume=args.resume, 35 | cpu=args.cpu 36 | ) 37 | print(self.model, file=ckp.log_file) 38 | 39 | def forward(self, x, scale, pos_mat): 40 | self.scale = scale 41 | target = self.get_model() 42 | if hasattr(target, 'set_scale'): 43 | target.set_scale(scale) 44 | 45 | if self.self_ensemble and not self.training: 46 | if self.chop: 47 | forward_function = self.forward_chop 48 | else: 49 | forward_function = self.model.forward 50 | 51 | return self.forward_x8(x, forward_function) 52 | elif self.chop and not self.training: 53 | return self.forward_chop(x,pos_mat) 54 | else: 55 | return self.model(x,pos_mat) 56 | 57 | def get_model(self): 58 | if self.n_GPUs <= 1 or self.cpu: 59 | return self.model 60 | else: 61 | return self.model.module 62 | 63 | def state_dict(self, **kwargs): 64 | target = self.get_model() 65 | return target.state_dict(**kwargs) 66 | 67 | def save(self, apath, epoch, is_best=False): 68 | target = self.get_model() 69 | torch.save( 70 | target.state_dict(), 71 | os.path.join(apath, 'model', 'model_latest.pt') 72 | ) 73 | if is_best: 74 | torch.save( 75 | target.state_dict(), 76 | os.path.join(apath, 'model', 'model_best.pt') 77 | ) 78 | 79 | if self.save_models: 80 | torch.save( 81 | target.state_dict(), 82 | os.path.join(apath, 'model', 'model_{}.pt'.format(epoch)) 83 | ) 84 | 85 | def load(self, apath, pre_train='.', resume=-1, cpu=False): 86 | if cpu: 87 | kwargs = {'map_location': lambda storage, loc: storage} 88 | else: 89 | kwargs = {} 90 | 91 | if resume == -1: 92 | self.get_model().load_state_dict( 93 | torch.load( 94 | os.path.join(apath, 'model', 'model_latest.pt'), 95 | **kwargs 96 | ), 97 | strict=False 98 | ) 99 | elif resume == 0: 100 | if pre_train != '.': 101 | print('Loading model from {}'.format(pre_train)) 102 | self.get_model().load_state_dict( 103 | torch.load(pre_train, **kwargs), 104 | strict=False 105 | ) 106 | print('load_model_mode=1') 107 | else: 108 | self.get_model().load_state_dict( 109 | torch.load( 110 | os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), 111 | **kwargs 112 | ), 113 | strict=False 114 | ) 115 | print('load_model_mode=2') 116 | 117 | def forward_chop(self, x, pos_mat, shave=10, min_size=160000): 118 | scale = self.scale[self.idx_scale] 119 | n_GPUs = min(self.n_GPUs, 4) 120 | b, c, h, w = x.size() 121 | h_half, w_half = h // 2, w // 2 122 | h_size, w_size = h_half + shave, w_half + shave 123 | lr_list = [ 124 | x[:, :, 0:h_size, 0:w_size], 125 | x[:, :, 0:h_size, (w - w_size):w], 126 | x[:, :, (h - h_size):h, 0:w_size], 127 | x[:, :, (h - h_size):h, (w - w_size):w]] 128 | 129 | if w_size * h_size < min_size: 130 | sr_list = [] 131 | for i in range(0, 4, n_GPUs): 132 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 133 | sr_batch = self.model(lr_batch, pos_mat) 134 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 135 | else: 136 | sr_list = [ 137 | self.forward_chop(patch, pos_mat, shave=shave, min_size=min_size) \ 138 | for patch in lr_list 139 | ] 140 | scale = math.ceil(scale) 141 | h, w = scale * h, scale * w 142 | h_half, w_half = scale * h_half, scale * w_half 143 | h_size, w_size = scale * h_size, scale * w_size 144 | shave *= scale 145 | 146 | output = x.new(b, c, h, w) 147 | output[:, :, 0:h_half, 
0:w_half] \ 148 | = sr_list[0][:, :, 0:h_half, 0:w_half] 149 | output[:, :, 0:h_half, w_half:w] \ 150 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 151 | output[:, :, h_half:h, 0:w_half] \ 152 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 153 | output[:, :, h_half:h, w_half:w] \ 154 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 155 | 156 | return output 157 | 158 | def forward_x8(self, x, forward_function): 159 | def _transform(v, op): 160 | if self.precision != 'single': v = v.float() 161 | 162 | v2np = v.data.cpu().numpy() 163 | if op == 'v': 164 | tfnp = v2np[:, :, :, ::-1].copy() 165 | elif op == 'h': 166 | tfnp = v2np[:, :, ::-1, :].copy() 167 | elif op == 't': 168 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 169 | 170 | ret = torch.Tensor(tfnp).to(self.device) 171 | if self.precision == 'half': ret = ret.half() 172 | 173 | return ret 174 | 175 | lr_list = [x] 176 | for tf in 'v', 'h', 't': 177 | lr_list.extend([_transform(t, tf) for t in lr_list]) 178 | 179 | sr_list = [forward_function(aug) for aug in lr_list] 180 | for i in range(len(sr_list)): 181 | if i > 3: 182 | sr_list[i] = _transform(sr_list[i], 't') 183 | if i % 4 > 1: 184 | sr_list[i] = _transform(sr_list[i], 'h') 185 | if (i % 4) % 2 == 1: 186 | sr_list[i] = _transform(sr_list[i], 'v') 187 | 188 | output_cat = torch.cat(sr_list, dim=0) 189 | output = output_cat.mean(dim=0, keepdim=True) 190 | 191 | return output 192 | 193 | -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/backbone.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/model/__pycache__/backbone.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/model/__pycache__/modules.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/network.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linyiyuan11/AMT_Net/a6dd0678fe1ae5a88fd3200bac9fa1bdbb7e8bc6/model/__pycache__/network.cpython-39.pyc -------------------------------------------------------------------------------- /model/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as model_zoo 5 | import math 6 | 7 | 8 | def conv3x3(in_planes, 
out_planes, stride=1, groups=1, dilation=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=dilation, groups=groups, bias=False, dilation=dilation) 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | } 20 | 21 | def ResNet34(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3): 22 | """ 23 | output, low_level_feat: 24 | 512, 64 25 | """ 26 | print(in_c) 27 | model = ResNet(BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c) 28 | if in_c != 3: 29 | pretrained = False 30 | if pretrained: 31 | model._load_pretrained_model(model_urls['resnet34']) 32 | return model 33 | 34 | 35 | def ResNet18(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3): 36 | """ 37 | output, low_level_feat: 38 | 512, 256, 128, 64, 64 39 | """ 40 | model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, in_c=in_c) 41 | if in_c !=3: 42 | pretrained=False 43 | if pretrained: 44 | print(122222222) 45 | model._load_pretrained_model(model_urls['resnet18']) 46 | return model 47 | 48 | 49 | def ResNet50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3): 50 | """ 51 | output, low_level_feat: 52 | 2048, 256 53 | """ 54 | model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c) 55 | if in_c !=3: 56 | pretrained=False 57 | if pretrained: 58 | model._load_pretrained_model(model_urls['resnet50']) 59 | return model 60 | 61 | 62 | class BasicBlock(nn.Module): 63 | expansion = 1 64 | 65 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 66 | super(BasicBlock, self).__init__() 67 | 68 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 69 | dilation=dilation, padding=dilation, bias=False) 70 | self.bn1 = BatchNorm(planes) 71 | self.relu = nn.ReLU(inplace=True) 72 | # self.do1 = nn.Dropout2d(p=0.2) 73 | 74 | self.conv2 = conv3x3(planes, planes) 75 | self.bn2 = BatchNorm(planes) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | identity = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | 89 | if self.downsample is not None: 90 | identity = self.downsample(x) 91 | 92 | out += identity 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | class SELayer(nn.Module): 98 | def __init__(self, channel, reduction=16): 99 | super(SELayer, self).__init__() 100 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 101 | self.fc = nn.Sequential( 102 | nn.Linear(channel, channel // reduction, bias=False), 103 | nn.ReLU(inplace=True), 104 | nn.Linear(channel // reduction, channel, bias=False), 105 | nn.Sigmoid() 106 | ) 107 | 108 | def forward(self, x): 109 | b, c, _, _ = x.size() 110 | y = self.avg_pool(x).view(b, c) 111 | y = self.fc(y).view(b, c, 1, 1) 112 | return x * y.expand_as(x) 113 | 114 | class Bottleneck(nn.Module): 115 | expansion = 4 116 | 117 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 118 | super(Bottleneck, self).__init__() 119 | self.conv1 = nn.Conv2d(inplanes, 
planes, kernel_size=1, bias=False) 120 | self.bn1 = BatchNorm(planes) 121 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 122 | dilation=dilation, padding=dilation, bias=False) 123 | self.bn2 = BatchNorm(planes) 124 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 125 | self.bn3 = BatchNorm(planes * 4) 126 | self.relu = nn.ReLU() 127 | self.downsample = downsample 128 | self.stride = stride 129 | self.dilation = dilation 130 | 131 | 132 | 133 | def forward(self, x): 134 | residual = x 135 | 136 | out = self.conv1(x) 137 | out = self.bn1(out) 138 | out = self.relu(out) 139 | 140 | out = self.conv2(out) 141 | out = self.bn2(out) 142 | out = self.relu(out) 143 | 144 | out = self.conv3(out) 145 | out = self.bn3(out) 146 | 147 | if self.downsample is not None: 148 | residual = self.downsample(x) 149 | 150 | out += residual 151 | out = self.relu(out) 152 | 153 | return out 154 | 155 | class PA(nn.Module): 156 | def __init__(self, inchan = 512, out_chan = 32): 157 | super().__init__() 158 | self.conv = nn.Conv2d(inchan, out_chan, kernel_size=3, padding=1, bias=False) 159 | self.bn = nn.BatchNorm2d(out_chan) 160 | self.re = nn.ReLU() 161 | self.do = nn.Dropout2d(0.2) 162 | 163 | self.pa_conv = nn.Conv2d(out_chan, out_chan, kernel_size=1, padding=0, groups=out_chan) 164 | self.sigmoid = nn.Sigmoid() 165 | 166 | def forward(self, x): 167 | x0 = self.conv(x) 168 | x = self.do(self.re(self.bn(x0))) 169 | return x0 *self.sigmoid(self.pa_conv(x)) 170 | 171 | 172 | class ResNet(nn.Module): 173 | 174 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, in_c=3): 175 | 176 | self.inplanes = 64 177 | self.in_c = in_c 178 | print('in_c: ',self.in_c) 179 | super(ResNet, self).__init__() 180 | blocks = [1, 2, 4] 181 | if output_stride == 32: 182 | strides = [1, 2, 2, 2] 183 | dilations = [1, 1, 1, 1] 184 | elif output_stride == 16: 185 | strides = [1, 2, 2, 1] 186 | dilations = [1, 1, 1, 2] 187 | elif output_stride == 8: 188 | strides = [1, 2, 1, 1] 189 | dilations = [1, 1, 2, 4] 190 | elif output_stride == 4: 191 | strides = [1, 1, 1, 1] 192 | dilations = [1, 2, 4, 8] 193 | else: 194 | raise NotImplementedError 195 | 196 | # Modules 197 | self.conv1 = nn.Conv2d(self.in_c, 64, kernel_size=7, stride=2, padding=3, 198 | bias=False) 199 | self.bn1 = BatchNorm(64) 200 | self.relu = nn.ReLU(inplace=True) 201 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 202 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 203 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 204 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 205 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 206 | self._init_weight() 207 | 208 | self.pos_s16 = PA(2048, 32) 209 | self.pos_s8 = PA(512, 32) 210 | self.pos_s4 = PA(256, 32) 211 | 212 | # self.pos_s16 = PA(512, 32) 213 | # self.pos_s8 = PA(128, 32) 214 | # self.pos_s4 = PA(64, 32) 215 | 216 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 217 | downsample = None 218 | if stride != 1 or self.inplanes != planes * block.expansion: 219 | downsample = nn.Sequential( 220 | nn.Conv2d(self.inplanes, planes * block.expansion, 221 | kernel_size=1, stride=stride, bias=False), 222 | BatchNorm(planes * 
block.expansion), 223 | ) 224 | 225 | layers = [] 226 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 227 | self.inplanes = planes * block.expansion 228 | for i in range(1, blocks): 229 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 230 | 231 | return nn.Sequential(*layers) 232 | 233 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 234 | downsample = None 235 | if stride != 1 or self.inplanes != planes * block.expansion: 236 | downsample = nn.Sequential( 237 | nn.Conv2d(self.inplanes, planes * block.expansion, 238 | kernel_size=1, stride=stride, bias=False), 239 | BatchNorm(planes * block.expansion), 240 | ) 241 | 242 | layers = [] 243 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 244 | downsample=downsample, BatchNorm=BatchNorm)) 245 | self.inplanes = planes * block.expansion 246 | for i in range(1, len(blocks)): 247 | layers.append(block(self.inplanes, planes, stride=1, 248 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 249 | 250 | return nn.Sequential(*layers) 251 | 252 | def forward(self, input): 253 | x = self.conv1(input) 254 | x = self.bn1(x) 255 | x = self.relu(x) 256 | x = self.maxpool(x) # | 4 257 | 258 | 259 | x = self.layer1(x) # | 4 260 | low_level_feat2 = x 261 | 262 | x = self.layer2(x) # | 8 263 | low_level_feat3 = x 264 | 265 | x = self.layer3(x) # | 16 266 | x = self.layer4(x) # | 16 267 | 268 | out_s16, out_s8, out_s4 = self.pos_s16(x), self.pos_s8(low_level_feat3), self.pos_s4(low_level_feat2) 269 | return out_s16, out_s8, out_s4 270 | 271 | 272 | def _init_weight(self): 273 | for m in self.modules(): 274 | if isinstance(m, nn.Conv2d): 275 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 276 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 277 | elif isinstance(m, nn.BatchNorm2d): 278 | m.weight.data.fill_(1) 279 | m.bias.data.zero_() 280 | 281 | def _load_pretrained_model(self, model_path): 282 | pretrain_dict = model_zoo.load_url(model_path) 283 | model_dict = {} 284 | state_dict = self.state_dict() 285 | for k, v in pretrain_dict.items(): 286 | if k in state_dict: 287 | model_dict[k] = v 288 | state_dict.update(model_dict) 289 | self.load_state_dict(state_dict) 290 | 291 | def build_backbone(backbone, output_stride, BatchNorm, in_c=3): 292 | if backbone == 'resnet50': 293 | return ResNet50(output_stride, BatchNorm, in_c=in_c) 294 | elif backbone == 'resnet34': 295 | return ResNet34(output_stride, BatchNorm, in_c=in_c) 296 | elif backbone == 'resnet18': 297 | return ResNet18(output_stride, BatchNorm, in_c=in_c) 298 | else: 299 | raise NotImplementedError 300 | 301 | -------------------------------------------------------------------------------- /model/dtcdscn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torchvision.models import ResNet 5 | import torch.nn.functional as F 6 | from functools import partial 7 | 8 | 9 | nonlinearity = partial(F.relu,inplace=True) 10 | 11 | class SELayer(nn.Module): 12 | def __init__(self, channel, reduction=16): 13 | super(SELayer, self).__init__() 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.fc = nn.Sequential( 16 | nn.Linear(channel, channel // reduction, bias=False), 17 | nn.ReLU(inplace=True), 18 | nn.Linear(channel // reduction, channel, bias=False), 19 | nn.Sigmoid() 20 | ) 21 | 22 | def forward(self, x): 23 | b, c, _, _ = x.size() 24 | y = self.avg_pool(x).view(b, c) 25 | y = self.fc(y).view(b, c, 1, 1) 26 | return x * y.expand_as(x) 27 | 28 | class Dblock_more_dilate(nn.Module): 29 | def __init__(self, channel): 30 | super(Dblock_more_dilate, self).__init__() 31 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) 32 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) 33 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) 34 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) 35 | self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16) 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | 41 | def forward(self, x): 42 | dilate1_out = nonlinearity(self.dilate1(x)) 43 | dilate2_out = nonlinearity(self.dilate2(dilate1_out)) 44 | dilate3_out = nonlinearity(self.dilate3(dilate2_out)) 45 | dilate4_out = nonlinearity(self.dilate4(dilate3_out)) 46 | dilate5_out = nonlinearity(self.dilate5(dilate4_out)) 47 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out + dilate5_out 48 | return out 49 | class Dblock(nn.Module): 50 | def __init__(self, channel): 51 | super(Dblock, self).__init__() 52 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) 53 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2) 54 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4) 55 | self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8) 56 | # self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16) 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv2d) or isinstance(m, 
nn.ConvTranspose2d): 59 | if m.bias is not None: 60 | m.bias.data.zero_() 61 | 62 | def forward(self, x): 63 | dilate1_out = nonlinearity(self.dilate1(x)) 64 | dilate2_out = nonlinearity(self.dilate2(dilate1_out)) 65 | dilate3_out = nonlinearity(self.dilate3(dilate2_out)) 66 | dilate4_out = nonlinearity(self.dilate4(dilate3_out)) 67 | # dilate5_out = nonlinearity(self.dilate5(dilate4_out)) 68 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out # + dilate5_out 69 | return out 70 | 71 | def conv3x3(in_planes, out_planes, stride=1): 72 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 73 | 74 | class SEBasicBlock(nn.Module): 75 | expansion = 1 76 | 77 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16): 78 | super(SEBasicBlock, self).__init__() 79 | self.conv1 = conv3x3(inplanes, planes, stride) 80 | self.bn1 = nn.BatchNorm2d(planes) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.conv2 = conv3x3(planes, planes, 1) 83 | self.bn2 = nn.BatchNorm2d(planes) 84 | self.se = SELayer(planes, reduction) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | residual = x 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.se(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | class DecoderBlock(nn.Module): 107 | def __init__(self, in_channels, n_filters): 108 | super(DecoderBlock,self).__init__() 109 | 110 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1) 111 | self.norm1 = nn.BatchNorm2d(in_channels // 4) 112 | self.relu1 = nonlinearity 113 | self.scse = SCSEBlock(in_channels // 4) 114 | 115 | self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1) 116 | self.norm2 = nn.BatchNorm2d(in_channels // 4) 117 | self.relu2 = nonlinearity 118 | 119 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1) 120 | self.norm3 = nn.BatchNorm2d(n_filters) 121 | self.relu3 = nonlinearity 122 | 123 | def forward(self, x): 124 | x = self.conv1(x) 125 | x = self.norm1(x) 126 | x = self.relu1(x) 127 | y = self.scse(x) 128 | x = x + y 129 | x = self.deconv2(x) 130 | x = self.norm2(x) 131 | x = self.relu2(x) 132 | x = self.conv3(x) 133 | x = self.norm3(x) 134 | x = self.relu3(x) 135 | return x 136 | 137 | class SCSEBlock(nn.Module): 138 | def __init__(self, channel, reduction=16): 139 | super(SCSEBlock, self).__init__() 140 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 141 | 142 | '''self.channel_excitation = nn.Sequential(nn.(channel, int(channel//reduction)), 143 | nn.ReLU(inplace=True), 144 | nn.Linear(int(channel//reduction), channel), 145 | nn.Sigmoid())''' 146 | self.channel_excitation = nn.Sequential(nn.Conv2d(channel, int(channel//reduction), kernel_size=1, 147 | stride=1, padding=0, bias=False), 148 | nn.ReLU(inplace=True), 149 | nn.Conv2d(int(channel // reduction), channel,kernel_size=1, 150 | stride=1, padding=0, bias=False), 151 | nn.Sigmoid()) 152 | 153 | self.spatial_se = nn.Sequential(nn.Conv2d(channel, 1, kernel_size=1, 154 | stride=1, padding=0, bias=False), 155 | nn.Sigmoid()) 156 | 157 | def forward(self, x): 158 | bahs, chs, _, _ = x.size() 159 | 160 | # Returns a new tensor with the same data as the self tensor but of a different size. 
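        # SCSE recalibration: the channel branch below global-average-pools x to (B, C, 1, 1),
        # gates it with two 1x1 convolutions and a sigmoid, and rescales x channel-wise;
        # the spatial branch maps x to a single-channel sigmoid map and rescales x pixel-wise.
        # The two recalibrated tensors are then summed by torch.add at the end of this method.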
161 | chn_se = self.avg_pool(x) 162 | chn_se = self.channel_excitation(chn_se) 163 | chn_se = torch.mul(x, chn_se) 164 | spa_se = self.spatial_se(x) 165 | spa_se = torch.mul(x, spa_se) 166 | return torch.add(chn_se, 1, spa_se) 167 | 168 | class CDNet_model(nn.Module): 169 | def __init__(self, in_channels=3, block=SEBasicBlock, layers=[3, 4, 6, 3], num_classes=7): 170 | super(CDNet_model, self).__init__() 171 | 172 | filters = [64, 128, 256, 512] 173 | self.inplanes = 64 174 | self.firstconv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, 175 | bias=False) 176 | self.firstbn = nn.BatchNorm2d(64) 177 | self.firstrelu = nn.ReLU(inplace=True) 178 | self.firstmaxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 179 | self.encoder1 = self._make_layer(block, 64, layers[0]) 180 | self.encoder2 = self._make_layer(block, 128, layers[1], stride=2) 181 | self.encoder3 = self._make_layer(block, 256, layers[2], stride=2) 182 | self.encoder4 = self._make_layer(block, 512, layers[3], stride=2) 183 | 184 | self.decoder4 = DecoderBlock(filters[3], filters[2]) 185 | self.decoder3 = DecoderBlock(filters[2], filters[1]) 186 | self.decoder2 = DecoderBlock(filters[1], filters[0]) 187 | self.decoder1 = DecoderBlock(filters[0], filters[0]) 188 | 189 | self.dblock_master = Dblock(512) 190 | self.dblock = Dblock(512) 191 | 192 | self.decoder4_master = DecoderBlock(filters[3], filters[2]) 193 | self.decoder3_master = DecoderBlock(filters[2], filters[1]) 194 | self.decoder2_master = DecoderBlock(filters[1], filters[0]) 195 | self.decoder1_master = DecoderBlock(filters[0], filters[0]) 196 | 197 | self.finaldeconv1_master = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) 198 | self.finalrelu1_master = nonlinearity 199 | self.finalconv2_master = nn.Conv2d(32, 32, 3, padding=1) 200 | self.finalrelu2_master = nonlinearity 201 | self.finalconv3_master = nn.Conv2d(32, 2, 3, padding=1) 202 | 203 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1) 204 | self.finalrelu1 = nonlinearity 205 | self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1) 206 | self.finalrelu2 = nonlinearity 207 | self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1) 208 | 209 | for m in self.modules(): 210 | if isinstance(m, nn.Conv2d): 211 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 212 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 213 | elif isinstance(m, nn.BatchNorm2d): 214 | m.weight.data.fill_(1) 215 | m.bias.data.zero_() 216 | 217 | def _make_layer(self, block, planes, blocks, stride=1): 218 | downsample = None 219 | if stride != 1 or self.inplanes != planes * block.expansion: 220 | downsample = nn.Sequential( 221 | nn.Conv2d(self.inplanes, planes * block.expansion, 222 | kernel_size=1, stride=stride, bias=False), 223 | nn.BatchNorm2d(planes * block.expansion), 224 | ) 225 | 226 | layers = [] 227 | layers.append(block(self.inplanes, planes, stride, downsample)) 228 | self.inplanes = planes * block.expansion 229 | for i in range(1, blocks): 230 | layers.append(block(self.inplanes, planes)) 231 | 232 | return nn.Sequential(*layers) 233 | 234 | def forward(self, x, y): 235 | # Encoder_1 236 | x = self.firstconv(x) 237 | x = self.firstbn(x) 238 | x = self.firstrelu(x) 239 | x = self.firstmaxpool(x) 240 | 241 | e1_x = self.encoder1(x) 242 | e2_x = self.encoder2(e1_x) 243 | e3_x = self.encoder3(e2_x) 244 | e4_x = self.encoder4(e3_x) 245 | 246 | # # Center_1 247 | e4_x_center = self.dblock(e4_x) 248 | 249 | # Decoder_1 250 | d4_x = self.decoder4(e4_x_center) + e3_x 251 | d3_x = self.decoder3(d4_x) + e2_x 252 | d2_x = self.decoder2(d3_x) + e1_x 253 | d1_x = self.decoder1(d2_x) 254 | 255 | out1 = self.finaldeconv1(d1_x) 256 | out1 = self.finalrelu1(out1) 257 | out1 = self.finalconv2(out1) 258 | out1 = self.finalrelu2(out1) 259 | out1 = self.finalconv3(out1) 260 | 261 | # Encoder_2 262 | y = self.firstconv(y) 263 | y = self.firstbn(y) 264 | y = self.firstrelu(y) 265 | y = self.firstmaxpool(y) 266 | 267 | e1_y = self.encoder1(y) 268 | e2_y = self.encoder2(e1_y) 269 | e3_y = self.encoder3(e2_y) 270 | e4_y = self.encoder4(e3_y) 271 | 272 | # # Center_2 273 | e4_y_center = self.dblock(e4_y) 274 | 275 | # Decoder_2 276 | d4_y = self.decoder4(e4_y_center) + e3_y 277 | d3_y = self.decoder3(d4_y) + e2_y 278 | d2_y = self.decoder2(d3_y) + e1_y 279 | d1_y = self.decoder1(d2_y) 280 | 281 | out2 = self.finaldeconv1(d1_y) 282 | out2 = self.finalrelu1(out2) 283 | out2 = self.finalconv2(out2) 284 | out2 = self.finalrelu2(out2) 285 | out2 = self.finalconv3(out2) 286 | 287 | # center_master 288 | e4 = self.dblock_master(e4_x - e4_y) 289 | # decoder_master 290 | d4 = self.decoder4_master(e4) + e3_x - e3_y 291 | d3 = self.decoder3_master(d4) + e2_x - e2_y 292 | d2 = self.decoder2_master(d3) + e1_x - e1_y 293 | d1 = self.decoder1_master(d2) 294 | 295 | out = self.finaldeconv1_master(d1) 296 | out = self.finalrelu1_master(out) 297 | out = self.finalconv2_master(out) 298 | out = self.finalrelu2_master(out) 299 | out = self.finalconv3_master(out) 300 | 301 | return out,out1,out2 302 | 303 | # return F.sigmoid(out),F.sigmoid(out1),F.sigmoid(out2) 304 | 305 | 306 | 307 | def CDNet34(in_channels, **kwargs): 308 | 309 | model = CDNet_model(in_channels, SEBasicBlock, [3, 4, 6, 3], **kwargs) 310 | 311 | return model -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from einops import rearrange 3 | from torch import nn 4 | import torch 5 | 6 | class Residual(nn.Module): 7 | def __init__(self, fn): 8 | super().__init__() 9 | self.fn = fn 10 | def forward(self, x, **kwargs): 11 | return self.fn(x, **kwargs) + x 12 | 13 | 14 | class Residual2(nn.Module): 15 | def __init__(self, fn): 16 | super().__init__() 17 | self.fn = fn 18 | def forward(self, x, x2, **kwargs): 19 | return 
self.fn(x, x2, **kwargs) + x 20 | 21 | 22 | class PreNorm(nn.Module): 23 | def __init__(self, dim, fn): 24 | super().__init__() 25 | self.norm = nn.LayerNorm(dim) 26 | self.fn = fn 27 | def forward(self, x, **kwargs): 28 | return self.fn(self.norm(x), **kwargs) 29 | 30 | class FeedForward(nn.Module): 31 | def __init__(self, dim, hidden_dim, dropout = 0.): 32 | super().__init__() 33 | self.net = nn.Sequential( 34 | nn.Linear(dim, hidden_dim), 35 | nn.GELU(), 36 | nn.Dropout(dropout), 37 | nn.Linear(hidden_dim, dim), 38 | nn.Dropout(dropout) 39 | ) 40 | def forward(self, x): 41 | return self.net(x) 42 | 43 | class Attention(nn.Module): 44 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 45 | super().__init__() 46 | inner_dim = dim_head * heads 47 | project_out = not (heads == 1 and dim_head == dim) 48 | 49 | self.heads = heads 50 | self.scale = dim_head ** -0.5 51 | 52 | self.attend = nn.Softmax(dim = -1) 53 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 54 | 55 | self.to_out = nn.Sequential( 56 | nn.Linear(inner_dim, dim), 57 | nn.Dropout(dropout) 58 | ) if project_out else nn.Identity() 59 | 60 | def forward(self, x): 61 | qkv = self.to_qkv(x).chunk(3, dim = -1) 62 | 63 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 64 | 65 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 66 | 67 | attn = self.attend(dots) 68 | 69 | out = torch.matmul(attn, v) 70 | 71 | out = rearrange(out, 'b h n d -> b n (h d)') 72 | out = self.to_out(out) 73 | 74 | return out 75 | 76 | class Transformer(nn.Module): 77 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 78 | super().__init__() 79 | self.layers = nn.ModuleList([]) 80 | for _ in range(depth): 81 | self.layers.append(nn.ModuleList([ 82 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 83 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 84 | ])) 85 | def forward(self, x): 86 | for attn, ff in self.layers: 87 | x = attn(x) + x 88 | x = ff(x) + x 89 | return x 90 | 91 | 92 | 93 | class Cross_Attention(nn.Module): 94 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., softmax=True): 95 | super().__init__() 96 | inner_dim = dim_head * heads 97 | self.heads = heads 98 | self.scale = dim ** -0.5 99 | 100 | self.softmax = softmax 101 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 102 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 103 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 104 | 105 | self.to_out = nn.Sequential( 106 | nn.Linear(inner_dim, dim), 107 | nn.Dropout(dropout) 108 | ) 109 | 110 | def forward(self, x, m, mask = None): 111 | 112 | b, n, _, h = *x.shape, self.heads 113 | q = self.to_q(x) 114 | k = self.to_k(m) 115 | v = self.to_v(m) 116 | 117 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q,k,v]) 118 | 119 | dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 120 | mask_value = -torch.finfo(dots.dtype).max 121 | 122 | if mask is not None: 123 | mask = F.pad(mask.flatten(1), (1, 0), value = True) 124 | assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions' 125 | mask = mask[:, None, :] * mask[:, :, None] 126 | dots.masked_fill_(~mask, mask_value) 127 | del mask 128 | 129 | if self.softmax: 130 | attn = dots.softmax(dim=-1) 131 | else: 132 | attn = dots 133 | 134 | out = torch.einsum('bhij,bhjd->bhid', attn, v) 135 | out = rearrange(out, 'b h n d -> b n (h d)') 136 | out = self.to_out(out) 137 | 138 | return out 139 | 140 | 
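# Illustrative usage sketch (not part of the original modules.py): a minimal shape check
# for Cross_Attention above and the TransformerDecoder defined below that wraps it,
# assuming the dim=32 features and 4 semantic tokens used by context_aggregator in
# model/network.py.
#
#   dec = TransformerDecoder(dim=32, depth=1, heads=8, dim_head=64,
#                            mlp_dim=64, dropout=0., softmax=True)
#   x = torch.randn(2, 64 * 64, 32)   # queries: a flattened (h*w) feature map
#   m = torch.randn(2, 4, 32)         # memory: compact tokens from a token encoder
#   out = dec(x, m)                   # -> torch.Size([2, 4096, 32])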
class PreNorm2(nn.Module): 141 | def __init__(self, dim, fn): 142 | super().__init__() 143 | self.norm = nn.LayerNorm(dim) 144 | self.fn = fn 145 | def forward(self, x, x2, **kwargs): 146 | return self.fn(self.norm(x), self.norm(x2), **kwargs) 147 | 148 | 149 | class TransformerDecoder(nn.Module): 150 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True): 151 | super().__init__() 152 | self.layers = nn.ModuleList([]) 153 | for _ in range(depth): 154 | self.layers.append(nn.ModuleList([ 155 | Residual2(PreNorm2(dim, Cross_Attention(dim, heads = heads, 156 | dim_head = dim_head, dropout = dropout, 157 | softmax=softmax))), 158 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) 159 | ])) 160 | def forward(self, x, m, mask = None): 161 | """target(query), memory""" 162 | for attn, ff in self.layers: 163 | x = attn(x, m, mask = mask) 164 | x = ff(x) 165 | return x -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from mmcv.runner import BaseModule 5 | from torch import Tensor 6 | from torch import Tensor, reshape, stack 7 | from .backbone import build_backbone 8 | from .modules import TransformerDecoder, Transformer 9 | from einops import rearrange 10 | from torch.nn import Upsample 11 | 12 | 13 | class SpatialAttention(nn.Module): 14 | def __init__(self, kernel_size=7): 15 | super(SpatialAttention, self).__init__() 16 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 17 | padding = 3 if kernel_size == 7 else 1 18 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) # 7,3 3,1 19 | self.sigmoid = nn.Sigmoid() 20 | def forward(self, x): 21 | avg_out = torch.mean(x, dim=1, keepdim=True) 22 | max_out, _ = torch.max(x, dim=1, keepdim=True) 23 | x = torch.cat([avg_out, max_out], dim=1) 24 | x = self.conv1(x) 25 | return self.sigmoid(x) 26 | 27 | class ChannelExchange(BaseModule): 28 | 29 | 30 | def __init__(self, p=2): 31 | super(ChannelExchange, self).__init__() 32 | self.p = p 33 | self.sam = SpatialAttention() 34 | def forward(self, x1, x2): 35 | N, c, h, w = x1.shape 36 | 37 | exchange_map = torch.arange(c) % self.p == 0 38 | exchange_mask = exchange_map.unsqueeze(0).expand((N, -1)) 39 | 40 | out_x1, out_x2 = torch.zeros_like(x1), torch.zeros_like(x2) 41 | out_x1[~exchange_mask, ...] = x1[~exchange_mask, ...] 42 | out_x2[~exchange_mask, ...] = x2[~exchange_mask, ...] 43 | out_x1[exchange_mask, ...] = x2[exchange_mask, ...] 44 | out_x2[exchange_mask, ...] = x1[exchange_mask, ...] 
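        # At this point every channel whose index is a multiple of p (half of the channels
        # for the default p=2) has been swapped between the two temporal features, while the
        # remaining channels keep their original values; note that self.sam defined in
        # __init__ is not applied in this forward pass.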
45 | 46 | return out_x1, out_x2 47 | 48 | 49 | class SpatialExchange(BaseModule): 50 | 51 | 52 | def __init__(self, p=2): 53 | super(SpatialExchange, self).__init__() 54 | self.p = p 55 | self.sam = SpatialAttention() 56 | def forward(self, x1, x2): 57 | N, c, h, w = x1.shape 58 | exchange_mask = torch.arange(w) % self.p == 0 59 | 60 | out_x1, out_x2 = torch.zeros_like(x1), torch.zeros_like(x2) 61 | out_x1[..., ~exchange_mask] = x1[..., ~exchange_mask] 62 | out_x2[..., ~exchange_mask] = x2[..., ~exchange_mask] 63 | out_x1[..., exchange_mask] = x2[..., exchange_mask] 64 | out_x2[..., exchange_mask] = x1[..., exchange_mask] 65 | 66 | return out_x1, out_x2 67 | 68 | class token_encoder(nn.Module): 69 | def __init__(self, in_chan = 32, token_len = 4, heads = 8): 70 | super(token_encoder, self).__init__() 71 | self.token_len = token_len 72 | self.conv_a = nn.Conv2d(in_chan, token_len, kernel_size=1, padding=0) 73 | self.pos_embedding = nn.Parameter(torch.randn(1, token_len, in_chan)) 74 | self.transformer = Transformer(dim=in_chan, depth=1, heads=heads, dim_head=64, mlp_dim=64, dropout=0) 75 | def forward(self, x): 76 | b, c, h, w = x.shape 77 | 78 | spatial_attention = self.conv_a(x) 79 | spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous() 80 | spatial_attention = torch.softmax(spatial_attention, dim=-1) 81 | x = x.view([b, c, -1]).contiguous() 82 | 83 | tokens = torch.einsum('bln, bcn->blc', spatial_attention, x) 84 | 85 | tokens += self.pos_embedding 86 | x = self.transformer(tokens) 87 | return x 88 | 89 | class token_decoder(nn.Module): 90 | def __init__(self, in_chan = 32, size = 32, heads = 8): 91 | super(token_decoder, self).__init__() 92 | self.pos_embedding_decoder = nn.Parameter(torch.randn(1, in_chan, size, size)) 93 | self.transformer_decoder = TransformerDecoder(dim=in_chan, depth=1, heads=heads, dim_head=True, mlp_dim=in_chan*2, dropout=0,softmax=in_chan) 94 | 95 | def forward(self, x, m): 96 | b, c, h, w = x.shape 97 | x = x + self.pos_embedding_decoder 98 | x = rearrange(x, 'b c h w -> b (h w) c') 99 | x = self.transformer_decoder(x, m) 100 | x = rearrange(x, 'b (h w) c -> b c h w', h=h) 101 | return x 102 | 103 | 104 | class context_aggregator(nn.Module): 105 | def __init__(self, in_chan=32, size=32): 106 | super(context_aggregator, self).__init__() 107 | self.token_encoder = token_encoder(in_chan=in_chan, token_len=4) 108 | self.token_decoder = token_decoder(in_chan = 32, size = size, heads = 8) 109 | def forward(self, feature): 110 | token = self.token_encoder(feature) 111 | out = self.token_decoder(feature, token) 112 | return out 113 | 114 | class Classifier(nn.Module): 115 | def __init__(self, in_chan=32, n_class=2): 116 | super(Classifier, self).__init__() 117 | self.head = nn.Sequential( 118 | nn.Conv2d(in_chan * 2, in_chan, kernel_size=3, padding=1, stride=1, bias=False), 119 | nn.BatchNorm2d(in_chan), 120 | nn.ReLU(), 121 | nn.Conv2d(in_chan, n_class, kernel_size=3, padding=1, stride=1)) 122 | def forward(self, x): 123 | x = self.head(x) 124 | return x 125 | 126 | class Conv_stride(nn.Module): 127 | def __init__(self,in_channels=32): 128 | super(Conv_stride, self).__init__() 129 | self.conv = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False) 130 | self.batchNorm = nn.BatchNorm2d(32, momentum=0.1) 131 | self.relu = nn.ReLU(inplace=True) 132 | def forward(self, x): 133 | x = self.conv(x) 134 | x = self.batchNorm(x) 135 | x = self.relu(x) 136 | return x 137 | 138 | class Upsample(nn.Module): 139 | def 
__init__(self,scale_factor=2): 140 | super(Upsample, self).__init__() 141 | self.conv = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False) 142 | self.batchNorm = nn.BatchNorm2d(32, momentum=0.1) 143 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode="bicubic", align_corners=True) 144 | def forward(self, x): 145 | x = self.conv(x) 146 | x = self.batchNorm(x) 147 | x = self.upsample(x) 148 | return x 149 | 150 | class CDNet(nn.Module): 151 | def __init__(self, backbone='resnet50', output_stride=16, img_size = 512, img_chan=3, chan_num = 32, n_class =2): 152 | super(CDNet, self).__init__() 153 | BatchNorm = nn.BatchNorm2d 154 | 155 | self.backbone = build_backbone(backbone, output_stride, BatchNorm, img_chan) 156 | 157 | self.CA_s16 = context_aggregator(in_chan=chan_num, size=img_size//16) 158 | self.CA_s8 = context_aggregator(in_chan=chan_num, size=img_size//8) 159 | self.CA_s4 = context_aggregator(in_chan=chan_num, size=img_size//4) 160 | 161 | self.conv_s8 = nn.Conv2d(chan_num*2, chan_num, kernel_size=3, padding=1) 162 | self.conv_s4 = nn.Conv2d(chan_num*3, chan_num, kernel_size=3, padding=1) 163 | 164 | self.upsamplex2 = nn.Upsample(scale_factor=2, mode="bicubic", align_corners=True) 165 | self.upsamplex4 = nn.Upsample(scale_factor=4, mode="bicubic", align_corners=True) 166 | 167 | self.EX = ChannelExchange() 168 | self.SE = SpatialExchange() 169 | 170 | self.liner1 = nn.Linear(1, 1) 171 | self.liner2 = nn.Linear(1, 1) 172 | self.relu = nn.ReLU(inplace=True) 173 | self.sigmoid = nn.Sigmoid() 174 | 175 | self.classifier1 = Classifier(n_class = n_class) 176 | self.classifier2 = Classifier(n_class = n_class) 177 | self.classifier3 = Classifier(n_class = n_class) 178 | 179 | self.sam = SpatialAttention() 180 | 181 | self.conv_stride_1_1 = Conv_stride() 182 | self.conv_stride_1_2 = Conv_stride() 183 | self.conv_stride_1_3 = Conv_stride() 184 | self.conv_stride_2_1 = Conv_stride() 185 | self.conv_stride_2_2 = Conv_stride() 186 | self.conv_stride_2_3 = Conv_stride() 187 | 188 | self.upsample_1_1 = Upsample() 189 | self.upsample_1_2 = Upsample() 190 | self.upsample_1_3 = Upsample(scale_factor=4) 191 | self.upsample_2_1 = Upsample() 192 | self.upsample_2_2 = Upsample() 193 | self.upsample_2_3 = Upsample(scale_factor=4) 194 | 195 | def forward(self, img1, img2): 196 | # CNN backbone, feature extractor 197 | out1_s16, out1_s8, out1_s4 = self.backbone(img1) 198 | out2_s16, out2_s8, out2_s4 = self.backbone(img2) 199 | 200 | out1_s16, out2_s16 = self.EX(out1_s16, out2_s16) 201 | out1_s8, out2_s8 = self.EX(out1_s8, out2_s8) 202 | out1_s4, out2_s4 = self.SE(out1_s4, out2_s4) 203 | 204 | out1_s16 = out1_s16*self.sam(out1_s16) 205 | out2_s16 = out2_s16 * self.sam(out2_s16) 206 | 207 | out1_s4 = out1_s4*self.sam(out1_s4) 208 | out2_s4 = out2_s4 * self.sam(out2_s4) 209 | 210 | out2_s8 = out2_s8*self.sam(out2_s8) 211 | out1_s8 = out1_s8 * self.sam(out1_s8) 212 | 213 | out1_s4_down_1 = self.conv_stride_1_1(out1_s4) 214 | out1_s4_down_2 = self.conv_stride_1_2(out1_s4_down_1) 215 | 216 | out1_s8_up = self.upsample_1_1(out1_s8) 217 | out1_s8_down = self.conv_stride_1_3(out1_s8) 218 | 219 | out1_s16_up_1 = self.upsample_1_2(out1_s16) 220 | out1_s16_up_2 = self.upsample_1_3(out1_s16) 221 | 222 | out1_s16 = out1_s16 + out1_s4_down_2 + out1_s8_down 223 | out1_s8 = out1_s8 + out1_s4_down_1 + out1_s16_up_1 224 | out1_s4 = out1_s4 + out1_s16_up_2 + out1_s8_up 225 | 226 | out2_s4_down_1 = self.conv_stride_2_1(out2_s4) 227 | out2_s4_down_2 = 
self.conv_stride_2_2(out2_s4_down_1) 228 | 229 | out2_s8_up = self.upsample_2_1(out2_s8) 230 | out2_s8_down = self.conv_stride_2_3(out2_s8) 231 | 232 | out2_s16_up_1 = self.upsample_2_2(out2_s16) 233 | out2_s16_up_2 = self.upsample_2_3(out2_s16) 234 | 235 | out2_s16 = out2_s16 + out2_s4_down_2 + out2_s8_down 236 | out2_s8 = out2_s8 + out2_s4_down_1 + out2_s16_up_1 237 | out2_s4 = out2_s4 + out2_s16_up_2 + out2_s8_up 238 | 239 | x1_s16 = self.CA_s16(out1_s16) 240 | x2_s16 = self.CA_s16(out2_s16) # [8,32,32,32] 241 | 242 | x1_s8 = self.CA_s8(out1_s8) 243 | x2_s8 = self.CA_s8(out2_s8) # [8,32,64,64] 244 | 245 | x1 = self.CA_s4(out1_s4) # [8,32,128,128] 246 | x2 = self.CA_s4(out2_s4) 247 | 248 | x_s16 = x1_s16 + x2_s16 249 | weight_s16 = F.adaptive_avg_pool2d(x_s16, (1, 1)) 250 | 251 | x_s8 = x1_s8 + x2_s8 252 | weight_s8 = F.adaptive_avg_pool2d(x_s8, (1, 1)) 253 | 254 | x_s = x1 + x2 255 | weight_sx = F.adaptive_avg_pool2d(x_s, (1, 1)) 256 | 257 | weight = weight_sx + weight_s8 + weight_s16 258 | 259 | weight = self.liner1(weight) 260 | weight = self.relu(weight) 261 | weight = self.liner2(weight) 262 | weight = self.sigmoid(weight) 263 | 264 | x1_s16 = weight * x1_s16 265 | x2_s16 = weight * x2_s16 266 | 267 | x1_s8 = weight * x1_s8 268 | x2_s8 = weight * x2_s8 269 | 270 | x1 = weight * x1 271 | x2 = weight * x2 272 | 273 | x16 = torch.cat([x1_s16, x2_s16], dim=1) 274 | x8 = torch.cat([x1_s8, x2_s8], dim=1) 275 | x = torch.cat([x1, x2], dim=1) 276 | 277 | x16 = F.interpolate(x16, size=img1.shape[2:], mode='bicubic', align_corners=True) 278 | x8 = F.interpolate(x8, size=img1.shape[2:], mode='bicubic', align_corners=True) 279 | x = F.interpolate(x, size=img1.shape[2:], mode='bicubic', align_corners=True) 280 | 281 | x = self.classifier3(x) 282 | x8 = self.classifier2(x8) 283 | x16 = self.classifier1(x16) 284 | 285 | return x, x8, x16 286 | 287 | def freeze_bn(self): 288 | for m in self.modules(): 289 | if isinstance(m, nn.BatchNorm2d): 290 | m.eval() 291 | -------------------------------------------------------------------------------- /model/siamunet_conc.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 
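# --- Editor's note (illustrative comment, not part of the original file) ---
# siamunet_conc.py, siamunet_diff.py and unet.py are the three fully
# convolutional baselines from the paper cited above, kept in this repo for
# comparison with CDNet. They share the same encoder; only the fusion differs:
# FC-Siam-conc (this file) concatenates the skip features of both temporal
# streams, FC-Siam-diff takes their absolute difference, and FC-EF (unet.py)
# concatenates the two images at the input of a single stream. The decoder
# widths reflect this: conv43d takes 384 channels here (128 upsampled + two
# 128-channel skips) but 256 in the other two variants.
# A minimal usage sketch, assuming RGB image pairs:
#
#   net = SiamUnet_conc(input_nbr=3, label_nbr=2)
#   t1 = torch.randn(1, 3, 256, 256)
#   t2 = torch.randn(1, 3, 256, 256)
#   log_probs = net(t1, t2)          # -> (1, 2, 256, 256), LogSoftmax output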
4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class SiamUnet_conc(nn.Module): 11 | """SiamUnet_conc segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(SiamUnet_conc, self).__init__() 15 | 16 | self.input_nbr = input_nbr 17 | 18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.2) 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.2) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.2) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.2) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.2) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.2) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.2) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.2) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.2) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.2) 51 | 52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 53 | 54 | self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1) 55 | self.bn43d = nn.BatchNorm2d(128) 56 | self.do43d = nn.Dropout2d(p=0.2) 57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 58 | self.bn42d = nn.BatchNorm2d(128) 59 | self.do42d = nn.Dropout2d(p=0.2) 60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 61 | self.bn41d = nn.BatchNorm2d(64) 62 | self.do41d = nn.Dropout2d(p=0.2) 63 | 64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 65 | 66 | self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1) 67 | self.bn33d = nn.BatchNorm2d(64) 68 | self.do33d = nn.Dropout2d(p=0.2) 69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 70 | self.bn32d = nn.BatchNorm2d(64) 71 | self.do32d = nn.Dropout2d(p=0.2) 72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 73 | self.bn31d = nn.BatchNorm2d(32) 74 | self.do31d = nn.Dropout2d(p=0.2) 75 | 76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 77 | 78 | self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1) 79 | self.bn22d = nn.BatchNorm2d(32) 80 | self.do22d = nn.Dropout2d(p=0.2) 81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 82 | self.bn21d = nn.BatchNorm2d(16) 83 | self.do21d = nn.Dropout2d(p=0.2) 84 | 85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 86 | 87 | self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1) 88 | self.bn12d = nn.BatchNorm2d(16) 89 | self.do12d = nn.Dropout2d(p=0.2) 90 | 
self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 91 | 92 | self.sm = nn.LogSoftmax(dim=1) 93 | 94 | def forward(self, x1, x2): 95 | 96 | """Forward method.""" 97 | # Stage 1 98 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 99 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 100 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 101 | 102 | 103 | # Stage 2 104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 105 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 106 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 107 | 108 | # Stage 3 109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 111 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 112 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 113 | 114 | # Stage 4 115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 117 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 118 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 119 | 120 | 121 | #################################################### 122 | # Stage 1 123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 126 | 127 | # Stage 2 128 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 129 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 130 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 131 | 132 | # Stage 3 133 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 134 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 135 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 136 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 137 | 138 | # Stage 4 139 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 140 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 141 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 142 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 143 | 144 | 145 | #################################################### 146 | # Stage 4d 147 | x4d = self.upconv4(x4p) 148 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 149 | x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1) 150 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 151 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 152 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 153 | 154 | # Stage 3d 155 | x3d = self.upconv3(x41d) 156 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 157 | x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1) 158 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 159 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 160 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 161 | 162 | # Stage 2d 163 | x2d = self.upconv2(x31d) 164 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 165 | x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1) 166 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 167 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 168 | 169 | # Stage 1d 170 | x1d = self.upconv1(x21d) 171 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 172 | x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1) 173 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 174 | x11d = self.conv11d(x12d) 175 | 
176 | return self.sm(x11d) 177 | 178 | 179 | -------------------------------------------------------------------------------- /model/siamunet_diff.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class SiamUnet_diff(nn.Module): 11 | """SiamUnet_diff segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(SiamUnet_diff, self).__init__() 15 | 16 | self.input_nbr = input_nbr 17 | 18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.2) 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.2) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.2) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.2) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.2) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.2) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.2) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.2) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.2) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.2) 51 | 52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 53 | 54 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 55 | self.bn43d = nn.BatchNorm2d(128) 56 | self.do43d = nn.Dropout2d(p=0.2) 57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 58 | self.bn42d = nn.BatchNorm2d(128) 59 | self.do42d = nn.Dropout2d(p=0.2) 60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 61 | self.bn41d = nn.BatchNorm2d(64) 62 | self.do41d = nn.Dropout2d(p=0.2) 63 | 64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 65 | 66 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 67 | self.bn33d = nn.BatchNorm2d(64) 68 | self.do33d = nn.Dropout2d(p=0.2) 69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 70 | self.bn32d = nn.BatchNorm2d(64) 71 | self.do32d = nn.Dropout2d(p=0.2) 72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 73 | self.bn31d = nn.BatchNorm2d(32) 74 | self.do31d = nn.Dropout2d(p=0.2) 75 | 76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 77 | 78 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 79 | 
self.bn22d = nn.BatchNorm2d(32) 80 | self.do22d = nn.Dropout2d(p=0.2) 81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 82 | self.bn21d = nn.BatchNorm2d(16) 83 | self.do21d = nn.Dropout2d(p=0.2) 84 | 85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 86 | 87 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 88 | self.bn12d = nn.BatchNorm2d(16) 89 | self.do12d = nn.Dropout2d(p=0.2) 90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 91 | 92 | self.sm = nn.LogSoftmax(dim=1) 93 | 94 | def forward(self, x1, x2): 95 | 96 | 97 | """Forward method.""" 98 | # Stage 1 99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 100 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 101 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 102 | 103 | 104 | # Stage 2 105 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 106 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 107 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 108 | 109 | # Stage 3 110 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 111 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 112 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 113 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 114 | 115 | # Stage 4 116 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 117 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 118 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 119 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 120 | 121 | #################################################### 122 | # Stage 1 123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 126 | 127 | 128 | # Stage 2 129 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 130 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 131 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 132 | 133 | # Stage 3 134 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 135 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 136 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 137 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 138 | 139 | # Stage 4 140 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 141 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 142 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 143 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 144 | 145 | 146 | 147 | # Stage 4d 148 | x4d = self.upconv4(x4p) 149 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 150 | x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1) 151 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 152 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 153 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 154 | 155 | # Stage 3d 156 | x3d = self.upconv3(x41d) 157 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 158 | x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1) 159 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 160 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 161 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 162 | 163 | # Stage 2d 164 | x2d = self.upconv2(x31d) 165 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 166 | x2d = 
torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1) 167 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 168 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 169 | 170 | # Stage 1d 171 | x1d = self.upconv1(x21d) 172 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 173 | x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1) 174 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 175 | x11d = self.conv11d(x12d) 176 | 177 | return self.sm(x11d) 178 | 179 | 180 | -------------------------------------------------------------------------------- /model/unet.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class Unet(nn.Module): 11 | """EF segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(Unet, self).__init__() 15 | 16 | self.input_nbr = input_nbr 17 | 18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.2) 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.2) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.2) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.2) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.2) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.2) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.2) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.2) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.2) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.2) 51 | 52 | 53 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 54 | 55 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 56 | self.bn43d = nn.BatchNorm2d(128) 57 | self.do43d = nn.Dropout2d(p=0.2) 58 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 59 | self.bn42d = nn.BatchNorm2d(128) 60 | self.do42d = nn.Dropout2d(p=0.2) 61 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 62 | self.bn41d = nn.BatchNorm2d(64) 63 | self.do41d = nn.Dropout2d(p=0.2) 64 | 65 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 66 | 67 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 68 | self.bn33d = nn.BatchNorm2d(64) 69 | self.do33d = nn.Dropout2d(p=0.2) 70 | self.conv32d = 
nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 71 | self.bn32d = nn.BatchNorm2d(64) 72 | self.do32d = nn.Dropout2d(p=0.2) 73 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 74 | self.bn31d = nn.BatchNorm2d(32) 75 | self.do31d = nn.Dropout2d(p=0.2) 76 | 77 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 78 | 79 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 80 | self.bn22d = nn.BatchNorm2d(32) 81 | self.do22d = nn.Dropout2d(p=0.2) 82 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 83 | self.bn21d = nn.BatchNorm2d(16) 84 | self.do21d = nn.Dropout2d(p=0.2) 85 | 86 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 87 | 88 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 89 | self.bn12d = nn.BatchNorm2d(16) 90 | self.do12d = nn.Dropout2d(p=0.2) 91 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 92 | 93 | self.sm = nn.LogSoftmax(dim=1) 94 | 95 | def forward(self, x1, x2): 96 | 97 | x = torch.cat((x1, x2), 1) 98 | 99 | """Forward method.""" 100 | # Stage 1 101 | x11 = self.do11(F.relu(self.bn11(self.conv11(x)))) 102 | x12 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 103 | x1p = F.max_pool2d(x12, kernel_size=2, stride=2) 104 | 105 | # Stage 2 106 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 107 | x22 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 108 | x2p = F.max_pool2d(x22, kernel_size=2, stride=2) 109 | 110 | # Stage 3 111 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 112 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 113 | x33 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 114 | x3p = F.max_pool2d(x33, kernel_size=2, stride=2) 115 | 116 | # Stage 4 117 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 118 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 119 | x43 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 120 | x4p = F.max_pool2d(x43, kernel_size=2, stride=2) 121 | 122 | 123 | # Stage 4d 124 | x4d = self.upconv4(x4p) 125 | pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2))) 126 | x4d = torch.cat((pad4(x4d), x43), 1) 127 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 128 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 129 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 130 | 131 | # Stage 3d 132 | x3d = self.upconv3(x41d) 133 | pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2))) 134 | x3d = torch.cat((pad3(x3d), x33), 1) 135 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 136 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 137 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 138 | 139 | # Stage 2d 140 | x2d = self.upconv2(x31d) 141 | pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2))) 142 | x2d = torch.cat((pad2(x2d), x22), 1) 143 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 144 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 145 | 146 | # Stage 1d 147 | x1d = self.upconv1(x21d) 148 | pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2))) 149 | x1d = torch.cat((pad1(x1d), x12), 1) 150 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 151 | x11d = self.conv11d(x12d) 152 | 153 | return self.sm(x11d) 154 | 155 | 156 | 
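# --- Editor's note (illustrative comment, not part of the original file) ---
# Unet above is the FC-EF (early fusion) baseline: forward() concatenates the
# two images along the channel axis before the first convolution, so
# `input_nbr` must count the channels of BOTH epochs. A minimal sketch,
# assuming RGB image pairs:
#
#   net = Unet(input_nbr=6, label_nbr=2)          # 3 + 3 input channels
#   t1 = torch.randn(1, 3, 256, 256)
#   t2 = torch.randn(1, 3, 256, 256)
#   log_probs = net(t1, t2)                       # -> (1, 2, 256, 256)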
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import torch.optim as optim 4 | import torch.utils.data 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | from data_utils import LoadDatasetFromFolder, DA_DatasetFromFolder, calMetric_iou,TestDatasetFromFolder 8 | import numpy as np 9 | import random 10 | from model.network import CDNet 11 | from train_options_HRSCD import parser 12 | import itertools 13 | from loss.losses import cross_entropy 14 | from ever import opt 15 | import time 16 | from collections import OrderedDict 17 | import ever as er 18 | import cv2 as cv 19 | import numpy as np 20 | import logging 21 | from PIL import Image 22 | import matplotlib.pyplot as plt 23 | 24 | args = parser.parse_args() 25 | os.environ["CUDA_VISIBLE_DEVICES"] = str(0) 26 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 27 | 28 | def seed_torch(seed=2022): 29 | random.seed(seed) 30 | os.environ['PYTHONHASHSEED'] = str(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | seed_torch(2022) 35 | 36 | COLOR_MAP = OrderedDict( 37 | Background=(255, 255, 255), 38 | Building=(255, 0, 0), 39 | ) 40 | 41 | def tes(mloss): 42 | CDNet.eval() 43 | with torch.no_grad(): 44 | test_bar = tqdm(test_loader) 45 | testing_results = {'batch_sizes': 0, 'IoU': 0} 46 | 47 | metric_op = er.metric.PixelMetric(2, logdir=None, logger=None) 48 | for hr_img1, hr_img2, label, name in test_bar: 49 | testing_results['batch_sizes'] += args.val_batchsize 50 | 51 | hr_img1 = hr_img1.to(device, dtype=torch.float) 52 | hr_img2 = hr_img2.to(device, dtype=torch.float) 53 | label = label.to(device, dtype=torch.float) 54 | 55 | label = torch.argmax(label, 1).unsqueeze(1).float() 56 | 57 | cd_map, _, _ = CDNet(hr_img1, hr_img2) 58 | cd_map = torch.argmax(cd_map, 1).unsqueeze(1).float() 59 | 60 | 61 | gt_value = (label > 0).float() 62 | prob = (cd_map > 0).float() 63 | prob = prob.cpu().detach().numpy() 64 | 65 | gt_value = gt_value.cpu().detach().numpy() 66 | gt_value = np.squeeze(gt_value) 67 | result = np.squeeze(prob) 68 | metric_op.forward(gt_value, result) 69 | re = metric_op.summary_all() 70 | 71 | CDNet.train() 72 | return mloss 73 | 74 | import cv2 as cv 75 | if __name__ == '__main__': 76 | mloss = 0 77 | 78 | test_set = TestDatasetFromFolder(args, args.hr1_test, args.hr2_test, args.lab_test) 79 | 80 | test_loader = DataLoader(dataset=test_set, num_workers=args.num_workers, batch_size=args.val_batchsize, shuffle=True) 81 | 82 | # define model 83 | CDNet = CDNet(img_size = args.img_size).to(device, dtype=torch.float) 84 | 85 | CDNet.load_state_dict(torch.load('/./best.pth'),strict=False) 86 | mloss = tes(mloss) 87 | 88 | 89 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import torch.optim as optim 4 | import torch.utils.data 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | from data_utils import DA_DatasetFromFolder, TestDatasetFromFolder, LoadDatasetFromFolder 8 | import numpy as np 9 | import random 10 | from model.network import CDNet 11 | from train_options import parser 12 | import itertools 13 | from loss.losses import cross_entropy 14 | import time 15 | from collections import OrderedDict 16 
| import ever as er 17 | 18 | args = parser.parse_args() 19 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | def seed_torch(seed=2022): 23 | random.seed(seed) 24 | os.environ['PYTHONHASHSEED'] = str(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | seed_torch(2022) 29 | 30 | COLOR_MAP = OrderedDict( 31 | Background=(255, 255, 255), 32 | Building=(255, 0, 0), 33 | ) 34 | 35 | def val(mloss): 36 | CDNet.eval() 37 | with torch.no_grad(): 38 | val_bar = tqdm(test_loader) 39 | valing_results = {'batch_sizes': 0, 'IoU': 0} 40 | 41 | metric_op = er.metric.PixelMetric(2, logdir=None, logger=None) 42 | for hr_img1, hr_img2, label, name in val_bar: 43 | 44 | valing_results['batch_sizes'] += args.val_batchsize 45 | hr_img1 = hr_img1.to(device, dtype=torch.float) 46 | hr_img2 = hr_img2.to(device, dtype=torch.float) 47 | label = label.to(device, dtype=torch.float) 48 | 49 | label = torch.argmax(label, 1).unsqueeze(1).float() 50 | 51 | cd_map, _, _ = CDNet(hr_img1, hr_img2) 52 | 53 | cd_map = torch.argmax(cd_map, 1).unsqueeze(1).float() 54 | 55 | gt_value = (label > 0).float() 56 | prob = (cd_map > 0).float() 57 | prob = prob.cpu().detach().numpy() 58 | 59 | gt_value = gt_value.cpu().detach().numpy() 60 | gt_value = np.squeeze(gt_value) 61 | result = np.squeeze(prob) 62 | metric_op.forward(gt_value, result) 63 | re = metric_op.summary_all() 64 | 65 | test_loss = re.rows[1][1] 66 | if test_loss > mloss or epoch == 1: 67 | torch.save(CDNet.state_dict(), args.model_dir +'_'+ str(test_loss) +'_best.pth') 68 | CDNet.train() 69 | return test_loss 70 | 71 | def train_epoch(): 72 | CDNet.train() 73 | for hr_img1, hr_img2, label in train_bar: 74 | running_results['batch_sizes'] += args.batchsize 75 | 76 | hr_img1 = hr_img1.to(device, dtype=torch.float) 77 | hr_img2 = hr_img2.to(device, dtype=torch.float) 78 | label = label.to(device, dtype=torch.float) 79 | label = torch.argmax(label, 1).unsqueeze(1).float() 80 | 81 | result1, result2, result3 = CDNet(hr_img1, hr_img2) 82 | 83 | CD_loss = CDcriterionCD(result1, label) + CDcriterionCD(result2, label) + CDcriterionCD(result3, label) 84 | 85 | CDNet.zero_grad() 86 | CD_loss.backward() 87 | optimizer.step() 88 | 89 | running_results['CD_loss'] += CD_loss.item() * args.batchsize 90 | 91 | train_bar.set_description( 92 | desc='[%d/%d] loss: %.4f' % ( 93 | epoch, args.num_epochs, 94 | running_results['CD_loss'] / running_results['batch_sizes'],)) 95 | 96 | 97 | if __name__ == '__main__': 98 | mloss = 0 99 | 100 | # load data 101 | train_set = DA_DatasetFromFolder(args.hr1_train, args.hr2_train, args.lab_train, crop=False) 102 | test_set = TestDatasetFromFolder(args, args.hr1_test, args.hr2_test, args.lab_test) 103 | 104 | train_loader = DataLoader(dataset=train_set, num_workers=args.num_workers, batch_size=args.batchsize, shuffle=True) 105 | test_loader = DataLoader(dataset=test_set, num_workers=args.num_workers, batch_size=args.val_batchsize, shuffle=True) 106 | 107 | # define model 108 | CDNet = CDNet(img_size = args.img_size).to(device, dtype=torch.float) 109 | 110 | # set optimization 111 | optimizer = optim.AdamW(itertools.chain(CDNet.parameters()), lr= args.lr, betas=(0.9, 0.999),weight_decay=0.01) 112 | CDcriterionCD = cross_entropy().to(device, dtype=torch.float) 113 | 114 | # training 115 | for epoch in range(1, args.num_epochs + 1): 116 | train_bar = tqdm(train_loader) 117 | running_results = {'batch_sizes': 0, 'SR_loss':0, 
'CD_loss':0, 'loss': 0 } 118 | train_epoch() 119 | mloss = val(mloss) 120 | 121 | 122 | -------------------------------------------------------------------------------- /train_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | file_train_path_time1 = '/./train/time1' 3 | file_train_path_time2 = '/./train/time2' 4 | file_train_path_label = '/./train/label' 5 | 6 | file_test_path_time1 = '/./test/time1' 7 | file_test_path_time2 = '/./test/time2' 8 | file_test_path_label = '/./test/label' 9 | 10 | #training options 11 | parser = argparse.ArgumentParser(description='Training Change Detection Network') 12 | 13 | # training parameters 14 | parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number') 15 | parser.add_argument('--batchsize', default=8, type=int, help='batchsize') 16 | parser.add_argument('--val_batchsize', default=8, type=int, help='batchsize for validation') 17 | parser.add_argument('--num_workers', default=24, type=int, help='num of workers') 18 | parser.add_argument('--n_class', default=2, type=int, help='number of class') 19 | parser.add_argument('--gpu_id', default="0", type=str, help='which gpu to run.') 20 | parser.add_argument('--suffix', default=['.png','.jpg','.tif'], type=list, help='the suffix of the image files.') 21 | parser.add_argument('--img_size', default=256, type=int, help='imagesize') 22 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') 23 | 24 | 25 | # path for loading data from folder 26 | parser.add_argument('--hr1_train', default= file_train_path_time1, type=str, help='image at t1 in train set') 27 | parser.add_argument('--hr2_train', default= file_train_path_time2, type=str, help='image at t2 in train set') 28 | parser.add_argument('--lab_train', default= file_train_path_label, type=str, help='label image in train set') 29 | # # """ 30 | 31 | parser.add_argument('--hr1_test', default= file_test_path_time1, type=str, help='image at t1 in test set') 32 | parser.add_argument('--hr2_test', default= file_test_path_time2, type=str, help='image at t2 in test set') 33 | parser.add_argument('--lab_test', default= file_test_path_label, type=str, help='label image in test set') 34 | 35 | # network saving 36 | parser.add_argument('--model_dir', default='/./', type=str, help='model save path') 37 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | import io 9 | import os 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | 18 | class SmoothedValue(object): 19 | """Track a series of values and provide access to smoothed values over a 20 | window or the global series average. 
21 | """ 22 | 23 | def __init__(self, window_size=20, fmt=None): 24 | if fmt is None: 25 | fmt = "{median:.4f} ({global_avg:.4f})" 26 | self.deque = deque(maxlen=window_size) 27 | self.total = 0.0 28 | self.count = 0 29 | self.fmt = fmt 30 | 31 | def update(self, value, n=1): 32 | self.deque.append(value) 33 | self.count += n 34 | self.total += value * n 35 | 36 | def synchronize_between_processes(self): 37 | """ 38 | Warning: does not synchronize the deque! 39 | """ 40 | if not is_dist_avail_and_initialized(): 41 | return 42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 43 | dist.barrier() 44 | dist.all_reduce(t) 45 | t = t.tolist() 46 | self.count = int(t[0]) 47 | self.total = t[1] 48 | 49 | @property 50 | def median(self): 51 | d = torch.tensor(list(self.deque)) 52 | return d.median().item() 53 | 54 | @property 55 | def avg(self): 56 | d = torch.tensor(list(self.deque), dtype=torch.float32) 57 | return d.mean().item() 58 | 59 | @property 60 | def global_avg(self): 61 | return self.total / self.count 62 | 63 | @property 64 | def max(self): 65 | return max(self.deque) 66 | 67 | @property 68 | def value(self): 69 | return self.deque[-1] 70 | 71 | def __str__(self): 72 | return self.fmt.format( 73 | median=self.median, 74 | avg=self.avg, 75 | global_avg=self.global_avg, 76 | max=self.max, 77 | value=self.value) 78 | 79 | 80 | class MetricLogger(object): 81 | def __init__(self, delimiter="\t"): 82 | self.meters = defaultdict(SmoothedValue) 83 | self.delimiter = delimiter 84 | 85 | def update(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | if isinstance(v, torch.Tensor): 88 | v = v.item() 89 | assert isinstance(v, (float, int)) 90 | self.meters[k].update(v) 91 | 92 | def __getattr__(self, attr): 93 | if attr in self.meters: 94 | return self.meters[attr] 95 | if attr in self.__dict__: 96 | return self.__dict__[attr] 97 | raise AttributeError("'{}' object has no attribute '{}'".format( 98 | type(self).__name__, attr)) 99 | 100 | def __str__(self): 101 | loss_str = [] 102 | for name, meter in self.meters.items(): 103 | loss_str.append( 104 | "{}: {}".format(name, str(meter)) 105 | ) 106 | return self.delimiter.join(loss_str) 107 | 108 | def synchronize_between_processes(self): 109 | for meter in self.meters.values(): 110 | meter.synchronize_between_processes() 111 | 112 | def add_meter(self, name, meter): 113 | self.meters[name] = meter 114 | 115 | def log_every(self, iterable, print_freq, header=None): 116 | i = 0 117 | if not header: 118 | header = '' 119 | start_time = time.time() 120 | end = time.time() 121 | iter_time = SmoothedValue(fmt='{avg:.4f}') 122 | data_time = SmoothedValue(fmt='{avg:.4f}') 123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 124 | log_msg = [ 125 | header, 126 | '[{0' + space_fmt + '}/{1}]', 127 | 'eta: {eta}', 128 | '{meters}', 129 | 'time: {time}', 130 | 'data: {data}' 131 | ] 132 | if torch.cuda.is_available(): 133 | log_msg.append('max mem: {memory:.0f}') 134 | log_msg = self.delimiter.join(log_msg) 135 | MB = 1024.0 * 1024.0 136 | for obj in iterable: 137 | data_time.update(time.time() - end) 138 | yield obj 139 | iter_time.update(time.time() - end) 140 | if i % print_freq == 0 or i == len(iterable) - 1: 141 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 143 | if torch.cuda.is_available(): 144 | print(log_msg.format( 145 | i, len(iterable), eta=eta_string, 146 | meters=str(self), 147 | time=str(iter_time), data=str(data_time), 148 | 
memory=torch.cuda.max_memory_allocated() / MB)) 149 | else: 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time))) 154 | i += 1 155 | end = time.time() 156 | total_time = time.time() - start_time 157 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 158 | print('{} Total time: {} ({:.4f} s / it)'.format( 159 | header, total_time_str, total_time / len(iterable))) 160 | 161 | 162 | def _load_checkpoint_for_ema(model_ema, checkpoint): 163 | """ 164 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 165 | """ 166 | mem_file = io.BytesIO() 167 | torch.save(checkpoint, mem_file) 168 | mem_file.seek(0) 169 | model_ema._load_checkpoint(mem_file) 170 | 171 | 172 | def setup_for_distributed(is_master): 173 | """ 174 | This function disables printing when not in master process 175 | """ 176 | import builtins as __builtin__ 177 | builtin_print = __builtin__.print 178 | 179 | def print(*args, **kwargs): 180 | force = kwargs.pop('force', False) 181 | if is_master or force: 182 | builtin_print(*args, **kwargs) 183 | 184 | __builtin__.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 218 | args.rank = int(os.environ["RANK"]) 219 | args.world_size = int(os.environ['WORLD_SIZE']) 220 | args.gpu = int(os.environ['LOCAL_RANK']) 221 | elif 'SLURM_PROCID' in os.environ: 222 | args.rank = int(os.environ['SLURM_PROCID']) 223 | args.gpu = args.rank % torch.cuda.device_count() 224 | else: 225 | print('Not using distributed mode') 226 | args.distributed = False 227 | return 228 | 229 | args.distributed = True 230 | 231 | torch.cuda.set_device(args.gpu) 232 | args.dist_backend = 'nccl' 233 | print('| distributed init (rank {}): {}'.format( 234 | args.rank, args.dist_url), flush=True) 235 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 236 | world_size=args.world_size, rank=args.rank) 237 | torch.distributed.barrier() 238 | setup_for_distributed(args.rank == 0) 239 | -------------------------------------------------------------------------------- /visualize_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | import sys 5 | import warnings 6 | from pathlib import Path 7 | 8 | import cv2 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | 13 | class CDVisualization(object): 14 | def __init__(self, policy=['compare_pixel', 'pixel']): 15 | """Change Detection Visualization 16 | 17 | Args: 18 | policy (list, optional): _description_. Defaults to ['compare_pixel', 'pixel']. 
19 | """ 20 | super().__init__() 21 | assert isinstance(policy, list) 22 | self.policy = policy 23 | self.num_classes = 2 24 | self.COLOR_MAP = {'0': (0, 0, 0), # black is TN 25 | '1': (0, 255, 0), # green is FP 误检 26 | '2': (255, 0, 0), # red is FN 漏检 27 | '3': (255, 255, 255)} # white is TP 28 | 29 | def read_and_check_label(self, file_name): 30 | img = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE) 31 | if np.max(img) >= self.num_classes: 32 | warnings.warn('Please make sure the range of the pixel value' \ 33 | 'in your pred or gt.') 34 | img[img < 128] = 0 35 | img[img >= 128] = 1 36 | return img 37 | 38 | def read_img(self, file_name): 39 | img = cv2.imread(file_name, cv2.IMREAD_UNCHANGED) 40 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 41 | return img 42 | 43 | def save_img(self, file_name, vis_res, imgs=None): 44 | dir_name = os.path.dirname(file_name) 45 | base_name, suffix = os.path.basename(file_name).split('.') 46 | os.makedirs(dir_name, exist_ok=True) 47 | 48 | # consistent with the original image 49 | vis_res = cv2.cvtColor(vis_res, cv2.COLOR_RGB2BGR) 50 | if imgs is not None: 51 | assert isinstance(imgs, list), '`imgs` must be a list.' 52 | if vis_res.shape != imgs[0].shape: 53 | vis_res = cv2.cvtColor(vis_res, cv2.COLOR_GRAY2RGB) 54 | for idx, img in enumerate(imgs): 55 | assert img.shape == vis_res.shape, '`img` and `vis_res` must be ' \ 56 | 'of the same shape.' 57 | cv2.imwrite(osp.join(dir_name, base_name + '_' + str(idx) + '.' + suffix), 58 | cv2.addWeighted(img, 1, vis_res, 0.25, 0.0)) 59 | else: 60 | cv2.imwrite(file_name, vis_res) 61 | 62 | def trainIdToColor(self, trainId): 63 | """convert label id to color 64 | 65 | Args: 66 | trainId (int): _description_ 67 | 68 | Returns: 69 | color (tuple) 70 | """ 71 | color = self.COLOR_MAP[str(trainId)] 72 | return color 73 | 74 | def gray2color(self, grayImage: np.ndarray, num_class: list): 75 | """convert label to color image 76 | 77 | Args: 78 | grayImage (np.ndarray): _description_ 79 | num_class (list): _description_ 80 | 81 | Returns: 82 | _type_: _description_ 83 | """ 84 | rgbImage = np.zeros((grayImage.shape[0], grayImage.shape[1], 3), dtype='uint8') 85 | for cls in num_class: 86 | row, col = np.where(grayImage == cls) 87 | if (len(row) == 0): 88 | continue 89 | color = self.trainIdToColor(cls) 90 | rgbImage[row, col] = color 91 | return rgbImage 92 | 93 | def res_pixel_visual(self, label): 94 | assert np.max(label) < self.num_classes, 'There exists the value of ' \ 95 | '`label` that is greater than `num_classes`' 96 | label_rgb = self.gray2color(label, num_class=list(range(self.num_classes))) 97 | return label_rgb 98 | 99 | def res_compare_pixel_visual(self, pred, gt): 100 | """visualize according to confusion matrix. 101 | 102 | Args: 103 | pred (_type_): _description_ 104 | gt (_type_): _description_ 105 | """ 106 | assert np.max(pred) < self.num_classes, 'There exists the value of ' \ 107 | '`pred` that is greater than `num_classes`' 108 | assert np.max(gt) < self.num_classes, 'There exists the value of ' \ 109 | '`gt` that is greater than `num_classes`' 110 | 111 | visual_ = self.num_classes * gt.astype(int) + pred.astype(int) 112 | visual_rgb = self.gray2color(visual_, num_class=list(range(self.num_classes ** 2))) 113 | return visual_rgb 114 | 115 | def res_compare_boundary_visual(self): 116 | pass 117 | 118 | def __call__(self, pred_path, gt_path, dst_path, imgs=None): 119 | # dst_prefix, dst_suffix = osp.abspath(dst_path).split('.') 120 | # dst_path = dst_prefix + '_{}.' 
+ dst_suffix 121 | file_name = osp.basename(dst_path) 122 | dst_path = osp.dirname(dst_path) 123 | 124 | pred = self.read_and_check_label(pred_path) 125 | gt = self.read_and_check_label(gt_path) 126 | if imgs is not None: 127 | assert isinstance(imgs, list), '`imgs` must be a list.' 128 | imgs = [self.read_img(p) for p in imgs] 129 | 130 | for pol in self.policy: 131 | dst_path_pol = osp.join(dst_path, pol) 132 | os.makedirs(dst_path_pol, exist_ok=True) 133 | dst_file = osp.join(dst_path_pol, file_name) 134 | if pol == 'compare_pixel': 135 | visual_map = self.res_compare_pixel_visual(pred, gt) 136 | self.save_img(dst_file, visual_map, None) 137 | elif pol == 'pixel': 138 | visual_map = self.res_pixel_visual(pred) 139 | self.save_img(dst_file, visual_map, imgs) 140 | else: 141 | raise ValueError(f'Invalid policy {pol}') 142 | 143 | 144 | if __name__ == '__main__': 145 | gt_dir = '/./label' 146 | pred_dir = '/./pre' 147 | dst_dir = '/./save' 148 | img_dir = ['/./test/time1', 149 | '/./time2'] 150 | 151 | 152 | file_name_list = os.listdir(gt_dir) 153 | 154 | for file_name in file_name_list: 155 | 156 | CDVisual = CDVisualization(policy=['compare_pixel']) 157 | CDVisual(osp.join(pred_dir, file_name), 158 | osp.join(gt_dir, file_name), 159 | osp.join(dst_dir, file_name), 160 | [osp.join(p, file_name) for p in img_dir]) 161 | 162 | --------------------------------------------------------------------------------
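Note that `test.py` only reports metrics through `ever`'s `PixelMetric` and does not write prediction images, while `visualize_results.py` expects a directory of prediction masks to compare against the labels. The sketch below is one possible glue step, not code shipped with the repo: it runs a trained `CDNet` over a folder of image pairs and saves binary masks that `CDVisualization` can read. The checkpoint path, folder layout, plain `/255` normalisation and the 0/255 mask convention are all assumptions; the real preprocessing lives in `data_utils.py`.
```
import os
import cv2
import numpy as np
import torch
from model.network import CDNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# build the model as in test.py and load a trained checkpoint (path is an example)
net = CDNet(img_size=256).to(device).eval()
net.load_state_dict(torch.load('best.pth', map_location=device), strict=False)

t1_dir, t2_dir, out_dir = 'test/time1', 'test/time2', 'pre'
os.makedirs(out_dir, exist_ok=True)

with torch.no_grad():
    for name in sorted(os.listdir(t1_dir)):
        # read the bi-temporal pair; BGR->RGB and a plain /255 scaling as a
        # stand-in for the normalisation done by data_utils.py
        pair = []
        for d in (t1_dir, t2_dir):
            img = cv2.cvtColor(cv2.imread(os.path.join(d, name)), cv2.COLOR_BGR2RGB)
            img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1))).float() / 255.0
            pair.append(img.unsqueeze(0).to(device))

        # CDNet returns three full-resolution change maps; test.py uses the first one
        cd_map, _, _ = net(*pair)
        pred = torch.argmax(cd_map, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)

        # save as 0/255 so CDVisualization.read_and_check_label() thresholds it back to {0, 1}
        cv2.imwrite(os.path.join(out_dir, name), pred * 255)
```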