├── README.md ├── __pycache__ ├── args_fusion.cpython-38.pyc ├── function.cpython-38.pyc ├── loss.cpython-38.pyc ├── net.cpython-38.pyc ├── t2t_vit.cpython-38.pyc ├── utils.cpython-38.pyc └── vit.cpython-38.pyc ├── args_fusion.py ├── data_loader.py ├── function.py ├── images ├── Test_ir │ ├── 1.bmp │ ├── 10.bmp │ ├── 2.bmp │ ├── 3.bmp │ ├── 4.bmp │ ├── 5.bmp │ ├── 6.bmp │ ├── 7.bmp │ ├── 8.bmp │ └── 9.bmp └── Test_vi │ ├── 1.bmp │ ├── 10.bmp │ ├── 2.bmp │ ├── 3.bmp │ ├── 4.bmp │ ├── 5.bmp │ ├── 6.bmp │ ├── 7.bmp │ ├── 8.bmp │ └── 9.bmp ├── ipt.py ├── loss.py ├── models └── Baiduyun ├── net.py ├── outputs ├── fusion_1.png ├── fusion_10.png ├── fusion_2.png ├── fusion_3.png ├── fusion_4.png ├── fusion_5.png ├── fusion_6.png ├── fusion_7.png ├── fusion_8.png └── fusion_9.png ├── pytorch_msssim ├── __init__.py └── __pycache__ │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-38.pyc ├── t2t_vit.py ├── test.py ├── train.py ├── utils.py ├── utils ├── data_vis.py └── dataset.py └── vit.py /README.md: -------------------------------------------------------------------------------- 1 | # TGFuse 2 | The code of TGFuse 3 | -------------------------------------------------------------------------------- /__pycache__/args_fusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/__pycache__/args_fusion.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/function.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/__pycache__/function.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/__pycache__/net.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/t2t_vit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/__pycache__/t2t_vit.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/vit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/__pycache__/vit.cpython-38.pyc -------------------------------------------------------------------------------- /args_fusion.py: 
-------------------------------------------------------------------------------- 1 | 2 | class args(): 3 | 4 | # training args 5 | g1 = 0.5 6 | epochs = 50#"number of training epochs, default is 2" 7 | batch_size = 16 #"batch size for training, default is 4"E:\database\KAIST-database 8 | # dataset = "/data/Disk_B/MSCOCO2014/train2014" 9 | dataset2 = "D:\\file\paper\dataset\dataset\\train\ir" 10 | train_num = 40000 11 | 12 | HEIGHT = 256 13 | WIDTH = 256 14 | save_model_dir = "D:\\file\paper\\new1\\fusion\models" #"path to folder where trained model will be saved." 15 | save_loss_dir = "D:\\file\paper\\new1\\fusion\models/loss" # "path to folder where trained model will be saved." 16 | 17 | image_size = 256 #"size of training images, default is 256 X 256" 18 | cuda = 1 #"set it to 1 for running on GPU, 0 for CPU" 19 | seed = 42 #"random seed for training" 20 | # ssim_weight = [1,10,100,1000,10000] 21 | ssim_path = ['1e0', '1e1', '1e2', '1e3', '1e4'] 22 | alpha = 0.5 23 | beta = 0.5 24 | gama = 1 25 | yita = 1 26 | deta = 1 27 | 28 | lr = 1e-4 #"learning rate, default is 0.001" 29 | lr_d = 1e-4 30 | # lr_light = 1e-4 # "learning rate, default is 0.001" 31 | log_interval = 10 #"number of images after which the training loss is logged, default is 500" 32 | # resume = "./models/pre/Epoch_8_iters_2500.model" 33 | resume = None 34 | # trans_model_path = "trans_model/VITB.pth" 35 | trans_model_path = None 36 | is_para = False 37 | 38 | # model_path_gray = None 39 | # model_path_edge = "./models/pre/network-bsds500.pytorch" 40 | # model_path_depth = "./models/pre/Best_model_period1.t7" 41 | 42 | 43 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # data loader 2 | from __future__ import print_function, division 3 | import glob 4 | import torch 5 | from skimage import io, transform, color 6 | import numpy as np 7 | import random 8 | from scipy.misc import imread, imsave, imresize 9 | from torchvision import transforms, utils 10 | from PIL import Image 11 | from os import listdir, mkdir, sep 12 | from os.path import join, exists, splitext 13 | import matplotlib as mpl 14 | 15 | #==========================dataset load========================== 16 | def list_images(directory): 17 | images = [] 18 | names = [] 19 | dir = listdir(directory) 20 | dir.sort() 21 | for file in dir: 22 | name = file.lower() 23 | if name.endswith('.png'): 24 | images.append(join(directory, file)) 25 | elif name.endswith('.jpg'): 26 | images.append(join(directory, file)) 27 | elif name.endswith('.jpeg'): 28 | images.append(join(directory, file)) 29 | name1 = name.split('.') 30 | names.append(name1[0]) 31 | return images 32 | 33 | def tensor_load_rgbimage(filename, size=None, scale=None, keep_asp=False): 34 | img = Image.open(filename).convert('RGB') 35 | if size is not None: 36 | if keep_asp: 37 | size2 = int(size * 1.0 / img.size[0] * img.size[1]) 38 | img = img.resize((size, size2), Image.ANTIALIAS) 39 | else: 40 | img = img.resize((size, size), Image.ANTIALIAS) 41 | 42 | elif scale is not None: 43 | img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS) 44 | img = np.array(img).transpose(2, 0, 1) 45 | img = torch.from_numpy(img).float() 46 | return img 47 | 48 | 49 | def tensor_save_rgbimage(tensor, filename, cuda=True): 50 | if cuda: 51 | img = tensor.cpu().clamp(0, 255).data[0].numpy() 52 | else: 53 | img = tensor.clamp(0, 255).numpy() 54 | img = 
img.transpose(1, 2, 0).astype('uint8') 55 | img = Image.fromarray(img) 56 | img.save(filename) 57 | 58 | def tensor_save_bgrimage(tensor, filename, cuda=False): 59 | (b, g, r) = torch.chunk(tensor, 3) 60 | tensor = torch.cat((r, g, b)) 61 | tensor_save_rgbimage(tensor, filename, cuda) 62 | 63 | # load training images 64 | def load_dataset(image_path, BATCH_SIZE, num_imgs=None): 65 | if num_imgs is None: 66 | num_imgs = len(image_path) 67 | original_imgs_path = image_path[:num_imgs] 68 | # random 69 | random.shuffle(original_imgs_path) 70 | mod = num_imgs % BATCH_SIZE 71 | print('BATCH SIZE %d.' % BATCH_SIZE) 72 | print('Train images number %d.' % num_imgs) 73 | print('Train images samples %s.' % str(num_imgs / BATCH_SIZE)) 74 | 75 | if mod > 0: 76 | print('Train set has been trimmed %d samples...\n' % mod) 77 | original_imgs_path = original_imgs_path[:-mod] 78 | batches = int(len(original_imgs_path) // BATCH_SIZE) 79 | return original_imgs_path, batches 80 | 81 | 82 | def get_image(path, height=256, width=256): 83 | image = Image.open(path).convert('RGB') 84 | 85 | if height is not None and width is not None: 86 | image = imresize(image, [height, width], interp='nearest') 87 | return image 88 | 89 | 90 | def get_train_images_auto(paths, height=256, width=256): 91 | if isinstance(paths, str): 92 | paths = [paths] 93 | images = [] 94 | for path in paths: 95 | image = get_image(path, height, width,) 96 | image = np.reshape(image, [image.shape[2], image.shape[0], image.shape[1]]) 97 | images.append(image) 98 | 99 | images = np.stack(images, axis=0) 100 | images = torch.from_numpy(images).float() 101 | return images 102 | 103 | # def get_test_images(path_con,path_sty, height=None, width=None): 104 | # ImageToTensor = transforms.Compose([transforms.ToTensor()]) 105 | # 106 | # cons = [] 107 | # stys = [] 108 | # 109 | # con = get_image(path_con, height, width) 110 | # sty = get_image(path_sty, height, width) 111 | # w_con, h_con = con.size 112 | # w_sty, h_sty = sty.size 113 | # w = w_con if w_con 0. else nn.Identity() 72 | self.norm2 = norm_layer(dim) 73 | ffn_hidden_dim = int(dim * ffn_ratio) 74 | self.ffn = Ffn(in_features=dim, hidden_features=ffn_hidden_dim, act_layer=act_layer, drop=drop) 75 | 76 | def forward(self, x): 77 | x = self.norm1(x) 78 | q, k, v = x , x , x 79 | x = x + self.attn(q, k, v) 80 | x = x + self.ffn(self.norm2(x)) 81 | return x 82 | 83 | 84 | class DecoderLayer(nn.Module): 85 | 86 | def __init__(self, dim, num_heads, ffn_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 87 | act_layer=nn.ReLU, norm_layer=nn.LayerNorm): 88 | super().__init__() 89 | self.norm1 = norm_layer(dim) 90 | self.attn1 = Attention( 91 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 92 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 93 | # self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 94 | self.norm2 = norm_layer(dim) 95 | self.attn2 = Attention( 96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 97 | self.norm3 = norm_layer(dim) 98 | ffn_hidden_dim = int(dim * ffn_ratio) 99 | self.ffn = Ffn(in_features=dim, hidden_features=ffn_hidden_dim, act_layer=act_layer, drop=drop) 100 | 101 | def forward(self, x): 102 | memory = x 103 | x = self.norm1(x) 104 | q, k, v = x, x, x 105 | x = x + self.attn1(q, k, v) 106 | x = self.norm2(x) 107 | q, k, v = x, memory, memory 108 | x = x + self.attn2(q, k, v) 109 | x = x + self.ffn(self.norm3(x)) 110 | return x 111 | 112 | 113 | class ResBlock(nn.Module): 114 | 115 | def __init__(self, channels): 116 | super(ResBlock, self).__init__() 117 | self.conv1 = nn.Conv2d(channels, channels, kernel_size=5, stride=1, 118 | padding=2, bias=False) 119 | # self.bn1 = nn.BatchNorm2d(channels) 120 | self.relu = nn.ReLU(inplace=True) 121 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=5, stride=1, 122 | padding=2, bias=False) 123 | # self.bn2 = nn.BatchNorm2d(channels) 124 | 125 | def forward(self, x): 126 | residual = x 127 | 128 | out = self.conv1(x) 129 | # out = self.bn1(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv2(out) 133 | # out = self.bn2(out) 134 | 135 | out += residual 136 | # out = self.relu(out) 137 | 138 | return out 139 | 140 | 141 | class Head(nn.Module): 142 | """ Head consisting of convolution layers 143 | Extract features from corrupted images, mapping N3HW images into NCHW feature map. 144 | """ 145 | 146 | def __init__(self, in_channels, out_channels): 147 | super(Head, self).__init__() 148 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, 149 | padding=1, bias=False) 150 | # self.bn1 = nn.BatchNorm2d(out_channels) if task_id in [0, 1, 5] else nn.Identity() 151 | # self.relu = nn.ReLU(inplace=True) 152 | self.resblock1 = ResBlock(out_channels) 153 | self.resblock2 = ResBlock(out_channels) 154 | 155 | def forward(self, x): 156 | out = self.conv1(x) 157 | # out = self.bn1(out) 158 | # out = self.relu(out) 159 | 160 | out = self.resblock1(out) 161 | out = self.resblock2(out) 162 | 163 | return out 164 | 165 | 166 | class PatchEmbed(nn.Module): 167 | """ Feature to Patch Embedding 168 | input : N C H W 169 | output: N num_patch P^2*C 170 | """ 171 | 172 | def __init__(self, patch_size=1, in_channels=64): 173 | super().__init__() 174 | self.patch_size = patch_size 175 | self.dim = self.patch_size ** 2 * in_channels 176 | 177 | def forward(self, x): 178 | N, C, H, W = ori_shape = x.shape 179 | 180 | p = self.patch_size 181 | num_patches = (H // p) * (W // p) 182 | out = torch.zeros((N, num_patches, self.dim)).to(x.device) 183 | # print(f"feature map size: {ori_shape}, embedding size: {out.shape}") 184 | i, j = 0, 0 185 | for k in range(num_patches): 186 | if i + p > W: 187 | i = 0 188 | j += p 189 | out[:, k, :] = x[:, :, i:i + p, j:j + p].flatten(1) 190 | i += p 191 | return out, ori_shape 192 | 193 | 194 | class DePatchEmbed(nn.Module): 195 | """ Patch Embedding to Feature 196 | input : N num_patch P^2*C 197 | output: N C H W 198 | """ 199 | 200 | def __init__(self, patch_size=1, in_channels=64): 201 | super().__init__() 202 | self.patch_size = patch_size 203 | self.num_patches = None 204 | self.dim = self.patch_size ** 2 * in_channels 205 | 206 | def forward(self, x, ori_shape): 207 | N, num_patches, dim = x.shape 208 | _, C, H, W = ori_shape 209 | p = self.patch_size 210 | out = 
torch.zeros(ori_shape).to(x.device) 211 | i, j = 0, 0 212 | for k in range(num_patches): 213 | if i + p > W: 214 | i = 0 215 | j += p 216 | out[:, :, i:i + p, j:j + p] = x[:, k, :].reshape(N, C, p, p) 217 | # out[:, k, :] = x[:, :, i:i+p, j:j+p].flatten(1) 218 | i += p 219 | return out 220 | 221 | 222 | class Tail(nn.Module): 223 | """ Tail consisting of convolution layers and pixel shuffle layers 224 | NCHW -> N3HW. 225 | """ 226 | 227 | def __init__(self, in_channels, out_channels): 228 | super(Tail, self).__init__() 229 | # assert 0 <= task_id <= 5 230 | # 0, 1 for noise 30, 50; 2, 3, 4 for sr x2, x3, x4, 5 for defog 231 | # upscale_map = [1, 1, 2, 3, 4, 1] 232 | # scale = upscale_map[task_id] 233 | m = [] 234 | # for SR task 235 | # if scale > 1: 236 | # m.append(nn.Conv2d(in_channels, in_channels * scale * scale, kernel_size=3, stride=1, 237 | # padding=1, bias=False)) 238 | # if (scale & (scale - 1)) == 0: 239 | # for _ in range(int(math.log(scale, 2))): 240 | # m.append(nn.PixelShuffle(2)) 241 | # elif scale == 3: 242 | # m.append(nn.PixelShuffle(3)) 243 | # else: 244 | # raise NameError("Only support x3 and x2^n SR") 245 | 246 | m.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, 247 | padding=1, bias=False)) 248 | self.m = nn.Sequential(*m) 249 | 250 | def forward(self, x): 251 | out = self.m(x) 252 | # print("task_id:", self.task_id) 253 | # print("shape of tail's output:", x.shape) 254 | # out = self.bn1(out) 255 | return out 256 | 257 | 258 | class ImageProcessingTransformer(nn.Module): 259 | """ Vision Transformer with support for patch or hybrid CNN input stage 260 | """ 261 | 262 | def __init__(self, patch_size=1, in_channels=1, mid_channels=3, num_classes=1000, depth=4, 263 | num_heads=4, ffn_ratio=2., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 264 | norm_layer=nn.LayerNorm): 265 | super(ImageProcessingTransformer, self).__init__() 266 | 267 | self.num_classes = num_classes 268 | self.embed_dim = patch_size * patch_size * mid_channels 269 | self.headsets = Head(in_channels, mid_channels) 270 | self.patch_embedding = PatchEmbed(patch_size=patch_size, in_channels=mid_channels) 271 | self.embed_dim = self.patch_embedding.dim 272 | if self.embed_dim % num_heads != 0: 273 | raise RuntimeError("Embedding dim must be devided by numbers of heads") 274 | 275 | # self.pos_embed = nn.Parameter(torch.zeros(1, (128 // patch_size) ** 2, self.embed_dim)) 276 | # self.task_embed = nn.Parameter(torch.zeros(6, 1, (48 // patch_size) ** 2, self.embed_dim)) 277 | self.encoder = nn.ModuleList([ 278 | EncoderLayer( 279 | dim=self.embed_dim, num_heads=num_heads, ffn_ratio=ffn_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 280 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer) 281 | for _ in range(depth)]) 282 | self.decoder = nn.ModuleList([ 283 | DecoderLayer( 284 | dim=self.embed_dim, num_heads=num_heads, ffn_ratio=ffn_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 285 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer) 286 | for _ in range(depth)]) 287 | # self.norm = norm_layer(self.embed_dim) 288 | 289 | self.de_patch_embedding = DePatchEmbed(patch_size=patch_size, in_channels=mid_channels) 290 | # tail 291 | self.tailsets = Tail(mid_channels, in_channels) 292 | 293 | #trunc_normal_(self.pos_embed, std=.02) 294 | self.apply(self._init_weights) 295 | 296 | # def set_task(self, task_id): 297 | # self.task_id = task_id 298 | 299 | def _init_weights(self, m): 300 | if isinstance(m, nn.Linear): 301 | trunc_normal_(m.weight, 
std=.02) 302 | if isinstance(m, nn.Linear) and m.bias is not None: 303 | nn.init.constant_(m.bias, 0) 304 | elif isinstance(m, nn.LayerNorm): 305 | nn.init.constant_(m.bias, 0) 306 | nn.init.constant_(m.weight, 1.0) 307 | 308 | def en(self, x): 309 | x = self.headsets(x) 310 | x, ori_shape = self.patch_embedding(x) 311 | # print("embedding shape:", x.shape) 312 | # print(x.device, self.pos_embed.device) 313 | for blk in self.encoder: 314 | x = blk(x) 315 | return x, ori_shape 316 | 317 | def de(self, x, ori_shape): 318 | for blk in self.decoder: 319 | x = blk(x) 320 | x = self.de_patch_embedding(x, ori_shape) 321 | x = self.tailsets(x) 322 | return x 323 | 324 | def forward(self, x): 325 | x, ori = self.en(x) 326 | out = self.de(x, ori) 327 | return out 328 | 329 | 330 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 331 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 332 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 333 | def norm_cdf(x): 334 | # Computes standard normal cumulative distribution function 335 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 336 | 337 | if (mean < a - 2 * std) or (mean > b + 2 * std): 338 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 339 | "The distribution of values may be incorrect.", 340 | stacklevel=2) 341 | 342 | with torch.no_grad(): 343 | # Values are generated by using a truncated uniform distribution and 344 | # then using the inverse CDF for the normal distribution. 345 | # Get upper and lower cdf values 346 | l = norm_cdf((a - mean) / std) 347 | u = norm_cdf((b - mean) / std) 348 | 349 | # Uniformly fill tensor with values from [l, u], then translate to 350 | # [2l-1, 2u-1]. 351 | tensor.uniform_(2 * l - 1, 2 * u - 1) 352 | 353 | # Use inverse cdf transform for normal distribution to get truncated 354 | # standard normal 355 | tensor.erfinv_() 356 | 357 | # Transform to proper mean, std 358 | tensor.mul_(std * math.sqrt(2.)) 359 | tensor.add_(mean) 360 | 361 | # Clamp to ensure it's in the proper range 362 | tensor.clamp_(min=a, max=b) 363 | return tensor 364 | 365 | 366 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 367 | # type: (Tensor, float, float, float, float) -> Tensor 368 | r"""Fills the input Tensor with values drawn from a truncated 369 | normal distribution. The values are effectively drawn from the 370 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 371 | with values outside :math:`[a, b]` redrawn until they are within 372 | the bounds. The method used for generating the random values works 373 | best when :math:`a \leq \text{mean} \leq b`. 
374 | Args: 375 | tensor: an n-dimensional `torch.Tensor` 376 | mean: the mean of the normal distribution 377 | std: the standard deviation of the normal distribution 378 | a: the minimum cutoff value 379 | b: the maximum cutoff value 380 | Examples: 381 | >>> w = torch.empty(3, 5) 382 | >>> nn.init.trunc_normal_(w) 383 | """ 384 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 385 | 386 | 387 | def ipt_base(): 388 | model = ImageProcessingTransformer( 389 | patch_size=4, depth=4, num_heads=4, ffn_ratio=2, qkv_bias=True, 390 | norm_layer=partial(nn.LayerNorm, eps=1e-6)) 391 | return model 392 | 393 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from math import exp 5 | import numpy as np 6 | import cv2 7 | import math 8 | 9 | 10 | def dis_loss_func(vis_output, fusion_output): 11 | # a = torch.mean(torch.square(vis_output - torch.Tensor(vis_output.shape).uniform_(0.7, 1.2))) 12 | return torch.mean(torch.square(vis_output - torch.Tensor(vis_output.shape).uniform_(0.7, 1.2).cuda())) + \ 13 | torch.mean(torch.square(fusion_output.cuda() - torch.Tensor(fusion_output.shape).uniform_(0, 0.3).cuda())) 14 | 15 | def loss_I(real_pair, fake_pair): 16 | batch_size = real_pair.size()[0] 17 | real_pair = 1 - real_pair 18 | real_pair = real_pair ** 2 19 | fake_pair = fake_pair ** 2 20 | real_pair = torch.sum(real_pair) 21 | fake_pair = torch.sum(fake_pair) 22 | return (real_pair + fake_pair) / batch_size 23 | 24 | def w_loss(img_ir): 25 | w1d = F.max_pool2d(img_ir, 2, 2) 26 | w2d = F.max_pool2d(w1d, 2, 2) 27 | w2u = F.upsample_bilinear(w2d, scale_factor=2) 28 | w_ir = F.upsample_bilinear(w2u, scale_factor=2) 29 | w_ir = F.softmax(w_ir, 0) 30 | w_vi = 1-w_ir 31 | return w_ir, w_vi 32 | 33 | 34 | def gaussian(window_size, sigma): 35 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 36 | return gauss/gauss.sum() 37 | 38 | 39 | def create_window(window_size, channel=1): 40 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) # sigma = 1.5 shape: [11, 1] 41 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) # unsqueeze()函数,增加维度 .t() 进行了转置 shape: [1, 1, 11, 11] 42 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() # window shape: [1,1, 11, 11] 43 | return window 44 | 45 | 46 | # 计算 ssim 损失函数 47 | def mssim(img1, img2, window_size=11): 48 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
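# The dynamic range L is fixed to 255 below, i.e. the inputs are assumed to be
# scaled to [0, 255]; C1 = (0.01*L)**2 and C2 = (0.03*L)**2 further down are the
# standard SSIM stabilisation constants.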
49 | 50 | max_val = 255 51 | min_val = 0 52 | L = max_val - min_val 53 | padd = window_size // 2 54 | 55 | 56 | (_, channel, height, width) = img1.size() 57 | 58 | # 滤波器窗口 59 | window = create_window(window_size, channel=channel).to(img1.device) 60 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 61 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 62 | 63 | mu1_sq = mu1.pow(2) 64 | mu2_sq = mu2.pow(2) 65 | mu1_mu2 = mu1 * mu2 66 | 67 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 68 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 69 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 70 | 71 | C1 = (0.01 * L) ** 2 72 | C2 = (0.03 * L) ** 2 73 | 74 | v1 = 2.0 * sigma12 + C2 75 | v2 = sigma1_sq + sigma2_sq + C2 76 | cs = torch.mean(v1 / v2) # contrast sensitivity 77 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 78 | ret = ssim_map 79 | return ret 80 | 81 | def mse(img1, img2, window_size=9): 82 | max_val = 255 83 | min_val = 0 84 | L = max_val - min_val 85 | padd = window_size // 2 86 | 87 | (_, channel, height, width) = img1.size() 88 | 89 | img1_f = F.unfold(img1, (window_size, window_size), padding=padd) 90 | img2_f = F.unfold(img2, (window_size, window_size), padding=padd) 91 | 92 | res = (img1_f - img2_f) ** 2 93 | 94 | res = torch.sum(res, dim=1, keepdim=True) / (window_size ** 2) 95 | 96 | res = F.fold(res, output_size=(256, 256), kernel_size=(1, 1)) 97 | return res 98 | 99 | 100 | # 方差计算 101 | def std(img, window_size=9): 102 | 103 | padd = window_size // 2 104 | (_, channel, height, width) = img.size() 105 | window = create_window(window_size, channel=channel).to(img.device) 106 | mu = F.conv2d(img, window, padding=padd, groups=channel) 107 | mu_sq = mu.pow(2) 108 | sigma1 = F.conv2d(img * img, window, padding=padd, groups=channel) - mu_sq 109 | 110 | return sigma1 111 | 112 | def sum(img, window_size=9): 113 | 114 | padd = window_size // 2 115 | (_, channel, height, width) = img.size() 116 | window = create_window(window_size, channel=channel).to(img.device) 117 | win1 = torch.ones_like(window) 118 | res = F.conv2d(img, win1, padding=padd, groups=channel) 119 | return res 120 | 121 | 122 | 123 | def final_ssim(img_ir, img_vis, img_fuse): 124 | 125 | ssim_ir = mssim(img_ir, img_fuse) 126 | ssim_vi = mssim(img_vis, img_fuse) 127 | 128 | # std_ir = std(img_ir) 129 | # std_vi = std(img_vis) 130 | std_ir = std(img_ir) 131 | std_vi = std(img_vis) 132 | 133 | zero = torch.zeros_like(std_ir) 134 | one = torch.ones_like(std_vi) 135 | 136 | # m = torch.mean(img_ir) 137 | # w_ir = torch.where(img_ir > m, one, zero) 138 | 139 | map1 = torch.where((std_ir - std_vi) > 0, one, zero) 140 | map2 = torch.where((std_ir - std_vi) >= 0, zero, one) 141 | 142 | ssim = map1 * ssim_ir + map2 * ssim_vi 143 | # ssim = ssim * w_ir 144 | return ssim.mean() 145 | 146 | def final_mse(img_ir, img_vis, img_fuse): 147 | mse_ir = mse(img_ir, img_fuse) 148 | mse_vi = mse(img_vis, img_fuse) 149 | 150 | std_ir = std(img_ir) 151 | std_vi = std(img_vis) 152 | # std_ir = sum(img_ir) 153 | # std_vi = sum(img_vis) 154 | 155 | zero = torch.zeros_like(std_ir) 156 | one = torch.ones_like(std_vi) 157 | 158 | m = torch.mean(img_ir) 159 | w_vi = torch.where(img_ir <= m, one, zero) 160 | 161 | map1 = torch.where((std_ir - std_vi) > 0, one, zero) 162 | map2 = torch.where((std_ir - std_vi) >= 0, zero, one) 163 | 164 | res = map1 * mse_ir + map2 * mse_vi 165 | res = res * w_vi 166 | 
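# At this point res holds, per pixel, the windowed MSE against whichever source
# image has the larger local variance (selected by map1/map2), additionally
# masked by w_vi so that only pixels at or below the mean infrared intensity
# contribute to the loss.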
return res.mean() 167 | 168 | 169 | 170 | if __name__ == '__main__': 171 | criterion = mssim 172 | input = torch.rand([1, 1, 64, 64]) 173 | output = torch.rand([1, 1, 64, 64]) 174 | img_fuse = torch.rand([1, 1, 64, 64]) 175 | uw = torch.Tensor(np.ones((11, 11), dtype=float)) / 11 176 | uw = uw.float().unsqueeze(0).unsqueeze(0) 177 | print(uw) 178 | input = input.cuda() 179 | output = output.cuda() 180 | img_fuse = img_fuse.cuda() 181 | ssim = final_ssim(input, output, img_fuse) 182 | print(ssim) 183 | -------------------------------------------------------------------------------- /models/Baiduyun: -------------------------------------------------------------------------------- 1 | 链接:https://pan.baidu.com/s/130DzITsfzKYppD1SYK03bQ 2 | 提取码:vwov 3 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from t2t_vit import Channel, Spatial 6 | from function import adaptive_instance_normalization 7 | 8 | 9 | # Convolution operation 10 | class ConvLayer(torch.nn.Module): 11 | def __init__(self, in_channels, out_channels, kernel_size, stride, is_last=False): 12 | super(ConvLayer, self).__init__() 13 | reflection_padding = int(np.floor(kernel_size / 2)) 14 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 15 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride) 16 | # self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 17 | # self.bn = nn.BatchNorm2d(out_channels) 18 | self.dropout = nn.Dropout2d(p=0.5) 19 | self.is_last = is_last 20 | 21 | def forward(self, x): 22 | x = x.cuda() 23 | 24 | # out = self.conv2d(x) 25 | # out = self.bn(out) 26 | out = self.reflection_pad(x) 27 | 28 | out = self.conv2d(out) 29 | 30 | if self.is_last is False: 31 | out = F.leaky_relu(out, inplace=True) 32 | return out 33 | 34 | class ResidualBlock(nn.Module): 35 | def __init__(self, channels): 36 | super(ResidualBlock, self).__init__() 37 | self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1) 38 | self.bn1 = nn.BatchNorm2d(channels, affine=True) 39 | self.relu = nn.ReLU() 40 | self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1) 41 | self.bn2 = nn.BatchNorm2d(channels, affine=True) 42 | 43 | def forward(self, x): 44 | residual = x 45 | out = self.relu(self.bn1(self.conv1(x))) 46 | out = self.bn2(self.conv2(out)) 47 | out = out + residual 48 | out = self.relu(out) 49 | return out 50 | 51 | 52 | class SELayer(nn.Module): 53 | def __init__(self, channel, reduction=16): 54 | super(SELayer, self).__init__() 55 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 56 | self.fc = nn.Sequential( 57 | nn.Linear(channel, channel // reduction, bias=False), 58 | nn.ReLU(inplace=True), 59 | nn.Linear(channel // reduction, channel, bias=False), 60 | nn.Sigmoid() 61 | ) 62 | 63 | def forward(self, x): 64 | b, c, _, _ = x.size() 65 | y = self.avg_pool(x).view(b, c) 66 | y = self.fc(y).view(b, c, 1, 1) 67 | return x * y.expand_as(x) 68 | 69 | 70 | class self_a(nn.Module): 71 | """ Self attention Layer""" 72 | 73 | def __init__(self, in_dim): 74 | super(self_a, self).__init__() 75 | self.chanel_in = in_dim 76 | # self.activation = activation 77 | 78 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 79 | self.value_conv_x = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 80 | 
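# A separate 1x1 value projection is kept for the second input stream; forward()
# builds a single attention map from query_conv(y) and key_conv(x) and applies
# it to both value branches, returning the gamma/beta-scaled attended features.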
self.value_conv_y = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 81 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 82 | 83 | self.s_conv = nn.Conv2d(in_channels=2*in_dim, out_channels=in_dim, kernel_size=1) 84 | self.softmax = nn.Softmax(dim=-1) 85 | self.softmax2 = nn.Softmax(dim=1) 86 | self.gamma = nn.Parameter(torch.zeros(1)) 87 | self.beta = nn.Parameter(torch.zeros(1)) 88 | 89 | def forward(self, x, y): 90 | """ 91 | inputs : 92 | x : input feature maps( B * C * W * H) 93 | returns : 94 | out : self attention value + input feature 95 | attention: B * N * N (N is Width*Height) 96 | """ 97 | m_batchsize, C, width, height = x.size() 98 | 99 | proj_query = self.query_conv(y).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B*N*C 100 | 101 | proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B*C*N 102 | 103 | energy = torch.bmm(proj_query, proj_key) # batch的matmul B*N*N 104 | attention = self.softmax(energy) # B * (N) * (N) 105 | 106 | proj_value_x = self.value_conv_x(x).view(m_batchsize, -1, width * height) # B * C * N 107 | proj_value_y = self.value_conv_y(y).view(m_batchsize, -1, width * height) 108 | 109 | out_x = torch.bmm(proj_value_x, attention.permute(0, 2, 1)) # B*C*N 110 | out_y = torch.bmm(proj_value_y, attention.permute(0, 2, 1)) 111 | 112 | out_x = out_x.view(m_batchsize, C, width, height) 113 | out_y = out_y.view(m_batchsize, C, width, height) 114 | 115 | x_att = self.gamma * out_x 116 | y_att = self.beta * out_y 117 | 118 | return x_att, y_att 119 | 120 | class Encoder(nn.Module): 121 | def __init__(self, in_channels, out_channels, kernel_size, stride): 122 | super().__init__() 123 | self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride) 124 | self.res = ResidualBlock(out_channels) 125 | self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, stride) 126 | 127 | def forward(self, x): 128 | x = self.conv1(x) 129 | x = self.res(x) 130 | x = self.conv2(x) 131 | return x 132 | 133 | class Decoder(nn.Module): 134 | def __init__(self, in_channels, out_channels, kernel_size, stride): 135 | super().__init__() 136 | self.conv1 = ConvLayer(in_channels, in_channels//2, kernel_size, stride) 137 | self.conv2 = ConvLayer(in_channels//2, out_channels, kernel_size, stride) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.conv2(x) 142 | return x 143 | 144 | 145 | #1 146 | class net(nn.Module): 147 | def __init__(self, input_nc=2, output_nc=1): 148 | super(net, self).__init__() 149 | kernel_size = 1 150 | stride = 1 151 | 152 | self.down1 = nn.AvgPool2d(2) 153 | self.down2 = nn.AvgPool2d(4) 154 | self.down3 = nn.AvgPool2d(8) 155 | 156 | self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 157 | self.up2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 158 | self.up3 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 159 | 160 | 161 | 162 | self.conv_in1 = ConvLayer(input_nc, input_nc, kernel_size, stride) 163 | self.conv_out = ConvLayer(64, 1, kernel_size, stride, is_last=True) 164 | # self.conv_t3 = ConvLayer(128, 64, kernel_size=1, stride=1) 165 | # self.conv_t2 = ConvLayer(64, 32, kernel_size=1, stride=1) 166 | # self.conv_t0 = ConvLayer(3, 3, kernel_size, stride) 167 | 168 | self.en0 = Encoder(2, 64, kernel_size, stride) 169 | self.en1 = Encoder(64, 64, kernel_size, stride) 170 | self.en2 = Encoder(64, 64, kernel_size, stride) 171 | self.en3 = Encoder(64, 64, kernel_size, stride) 172 | 173 | # self.de3 
= Decoder(96, 32, kernel_size, stride) 174 | # self.de2 = Decoder(48, 16, kernel_size, stride) 175 | # self.de1 = Decoder(19, 3, kernel_size, stride) 176 | # self.de0 = Decoder(3, 3, kernel_size, stride) 177 | 178 | # self.f1 = ConvLayer(6, 3, kernel_size, stride) 179 | # self.f2 = ConvLayer(32, 16, kernel_size, stride) 180 | # self.f3 = ConvLayer(64, 32, kernel_size, stride) 181 | 182 | # self.ctrans0 = Channel(size=256, embed_dim=128, patch_size=16, channel=3) 183 | # self.ctrans1 = Channel(size=128, embed_dim=128, patch_size=16, channel=16) 184 | # self.ctrans2 = Channel(size=64, embed_dim=128, patch_size=16, channel=32) 185 | self.ctrans3 = Channel(size=32, embed_dim=128, patch_size=16, channel=64) 186 | 187 | #self.strans0 = Spatial(size=256, embed_dim=128*2, patch_size=8, channel=3) 188 | #self.strans1 = Spatial(size=128, embed_dim=256*2, patch_size=8, channel=16) 189 | # self.strans2 = Spatial(size=256, embed_dim=512*2, patch_size=8, channel=32) 190 | self.strans3 = Spatial(size=256, embed_dim=1024*2, patch_size=4, channel=64) 191 | 192 | 193 | def en(self, vi, ir): 194 | f = torch.cat([vi, ir], dim=1) 195 | x = self.conv_in1(f) 196 | x0 = self.en0(x) 197 | x1 = self.en1(self.down1(x0)) 198 | x2 = self.en2(self.down1(x1)) 199 | x3 = self.en3(self.down1(x2)) 200 | 201 | return [x0, x1, x2, x3] 202 | 203 | # def de(self, f): 204 | # x0, x1, x2, x3 = f 205 | # o3 = self.de3(torch.cat([self.up1(x3), x2], dim=1)) 206 | # o2 = self.de2(torch.cat([self.up1(o3), x1], dim=1)) 207 | # o1 = self.de1(torch.cat([self.up1(o2), x0], dim=1)) 208 | # o0 = self.de0(o1) 209 | # out = self.conv_out1(o0) 210 | # return out 211 | 212 | def forward(self, vi, ir): 213 | # w = ir / (torch.max(ir) - torch.min(ir)) 214 | # f_pre = w * ir + (1-w) * vi 215 | f0 = torch.cat([vi, ir], dim=1) 216 | x = self.conv_in1(f0) 217 | x0 = self.en0(x) 218 | x1 = self.en1(self.down1(x0)) 219 | x2 = self.en2(self.down1(x1)) 220 | x3 = self.en3(self.down1(x2)) 221 | 222 | x3t = self.strans3(self.ctrans3(x3)) 223 | # x2r = self.ctrans2(x2) 224 | # x1r = self.ctrans1(x1) 225 | # x0r = self.ctrans0(x0) 226 | # x3m = torch.clamp(x3r, 0, 1) 227 | x3m = x3t 228 | x3r = x3 * x3m 229 | x2m = self.up1(x3m) 230 | x2r = x2 * x2m 231 | x1m = self.up1(x2m) + self.up2(x3m) 232 | x1r = x1 * x1m 233 | x0m = self.up1(x1m) + self.up2(x2m) + self.up3(x3m) 234 | x0r = x0 * x0m 235 | 236 | other =self.up3(x3r) + self.up2(x2r) + self.up1(x1r) + x0r 237 | f1 = self.conv_out(other) 238 | # out = self.conv_out(f1) 239 | 240 | return f1 241 | 242 | 243 | 244 | -------------------------------------------------------------------------------- /outputs/fusion_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/outputs/fusion_1.png -------------------------------------------------------------------------------- /outputs/fusion_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/outputs/fusion_10.png -------------------------------------------------------------------------------- /outputs/fusion_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/outputs/fusion_2.png -------------------------------------------------------------------------------- /outputs/fusion_3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/outputs/fusion_3.png -------------------------------------------------------------------------------- /outputs/fusion_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/outputs/fusion_4.png -------------------------------------------------------------------------------- /outputs/fusion_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/outputs/fusion_5.png -------------------------------------------------------------------------------- /outputs/fusion_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/outputs/fusion_6.png -------------------------------------------------------------------------------- /outputs/fusion_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/outputs/fusion_7.png -------------------------------------------------------------------------------- /outputs/fusion_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/outputs/fusion_8.png -------------------------------------------------------------------------------- /outputs/fusion_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/outputs/fusion_9.png -------------------------------------------------------------------------------- /pytorch_msssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from math import exp 4 | import numpy as np 5 | 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | 12 | def create_window(window_size, channel=1): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 16 | return window 17 | 18 | 19 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 20 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
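# When val_range is not supplied, the dynamic range is inferred from img1:
# max_val becomes 255 if any value exceeds 128 (uint8-style input) and 1
# otherwise, while min_val becomes -1 if values drop below -0.5 (tanh-style
# input) and 0 otherwise, giving L = max_val - min_val.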
21 | if val_range is None: 22 | if torch.max(img1) > 128: 23 | max_val = 255 24 | else: 25 | max_val = 1 26 | 27 | if torch.min(img1) < -0.5: 28 | min_val = -1 29 | else: 30 | min_val = 0 31 | L = max_val - min_val 32 | else: 33 | L = val_range 34 | 35 | padd = 0 36 | (_, channel, height, width) = img1.size() 37 | if window is None: 38 | real_size = min(window_size, height, width) 39 | window = create_window(real_size, channel=channel).to(img1.device) 40 | 41 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 42 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 43 | 44 | mu1_sq = mu1.pow(2) 45 | mu2_sq = mu2.pow(2) 46 | mu1_mu2 = mu1 * mu2 47 | 48 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 49 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 50 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 51 | 52 | C1 = (0.01 * L) ** 2 53 | C2 = (0.03 * L) ** 2 54 | 55 | v1 = 2.0 * sigma12 + C2 56 | v2 = sigma1_sq + sigma2_sq + C2 57 | cs = torch.mean(v1 / v2) # contrast sensitivity 58 | 59 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 60 | 61 | if size_average: 62 | ret = ssim_map.mean() 63 | else: 64 | ret = ssim_map.mean(1).mean(1).mean(1) 65 | 66 | if full: 67 | return ret, cs 68 | return ret 69 | 70 | 71 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): 72 | device = img1.device 73 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 74 | levels = weights.size()[0] 75 | mssim = [] 76 | mcs = [] 77 | for _ in range(levels): 78 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 79 | mssim.append(sim) 80 | mcs.append(cs) 81 | 82 | img1 = F.avg_pool2d(img1, (2, 2)) 83 | img2 = F.avg_pool2d(img2, (2, 2)) 84 | 85 | mssim = torch.stack(mssim) 86 | mcs = torch.stack(mcs) 87 | 88 | # Normalize (to avoid NaNs during training unstable models, not compliant with original definition) 89 | if normalize: 90 | mssim = (mssim + 1) / 2 91 | mcs = (mcs + 1) / 2 92 | 93 | pow1 = mcs ** weights 94 | pow2 = mssim ** weights 95 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 96 | output = torch.prod(pow1[:-1] * pow2[-1]) 97 | return output 98 | 99 | 100 | # Classes to re-use window 101 | class SSIM(torch.nn.Module): 102 | def __init__(self, window_size=11, size_average=True, val_range=None): 103 | super(SSIM, self).__init__() 104 | self.window_size = window_size 105 | self.size_average = size_average 106 | self.val_range = val_range 107 | 108 | # Assume 1 channel for SSIM 109 | self.channel = 1 110 | self.window = create_window(window_size) 111 | 112 | def forward(self, img1, img2): 113 | (_, channel, _, _) = img1.size() 114 | 115 | if channel == self.channel and self.window.dtype == img1.dtype: 116 | window = self.window 117 | else: 118 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 119 | self.window = window 120 | self.channel = channel 121 | 122 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 123 | 124 | class MSSSIM(torch.nn.Module): 125 | def __init__(self, window_size=11, size_average=True, channel=3): 126 | super(MSSSIM, self).__init__() 127 | self.window_size = window_size 128 | self.size_average = size_average 129 | self.channel = channel 130 | 131 | def forward(self, img1, img2): 132 | # TODO: store 
window between calls if possible 133 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) 134 | -------------------------------------------------------------------------------- /pytorch_msssim/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/pytorch_msssim/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_msssim/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongyuya/TGFuse/4a6edef1b1ce3ffff36f3d1b20ad3a5dac5a4a37/pytorch_msssim/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /t2t_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | # from token_transformer import Token_transformer 6 | # from transformer_block import Block, get_sinusoid_encoding, PositionEmbs, DePatchEmbed 7 | from einops import rearrange 8 | from einops.layers.torch import Rearrange 9 | import torch.nn.functional as F 10 | from timm.models.layers import DropPath 11 | 12 | 13 | class ConvLayer(torch.nn.Module): 14 | def __init__(self, in_channels, out_channels, kernel_size, stride, is_last=False): 15 | super(ConvLayer, self).__init__() 16 | reflection_padding = int(np.floor(kernel_size / 2)) 17 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 18 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride) 19 | # self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 20 | # self.bn = nn.BatchNorm2d(out_channels) 21 | self.dropout = nn.Dropout2d(p=0.5) 22 | self.is_last = is_last 23 | 24 | def forward(self, x): 25 | # x = x.cuda() 26 | 27 | # out = self.conv2d(x) 28 | # out = self.bn(out) 29 | out = self.reflection_pad(x) 30 | 31 | out = self.conv2d(out) 32 | 33 | if self.is_last is False: 34 | out = F.relu(out, inplace=True) 35 | return out 36 | 37 | class ResBlock(nn.Module): 38 | def __init__(self, channels): 39 | super(ResBlock, self).__init__() 40 | self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1) 41 | self.bn1 = nn.BatchNorm2d(channels, affine=True) 42 | self.relu = nn.ReLU() 43 | self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1) 44 | self.bn2 = nn.BatchNorm2d(channels, affine=True) 45 | 46 | def forward(self, x): 47 | residual = x 48 | out = self.relu(self.bn1(self.conv1(x))) 49 | out = self.bn2(self.conv2(out)) 50 | out = out + residual 51 | out = self.relu(out) 52 | return out 53 | 54 | class Mlp(nn.Module): 55 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 56 | super().__init__() 57 | out_features = out_features or in_features 58 | hidden_features = hidden_features or in_features 59 | self.fc1 = nn.Linear(in_features, hidden_features) 60 | self.act = act_layer() 61 | self.fc2 = nn.Linear(hidden_features, out_features) 62 | self.drop = nn.Dropout(drop) 63 | 64 | def forward(self, x): 65 | x = self.fc1(x) 66 | x = self.act(x) 67 | x = self.drop(x) 68 | x = self.fc2(x) 69 | x = self.drop(x) 70 | return x 71 | 72 | class Attention(nn.Module): 73 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 
proj_drop=0.): 74 | super().__init__() 75 | self.num_heads = num_heads 76 | head_dim = dim // num_heads 77 | 78 | self.scale = qk_scale or head_dim ** -0.5 79 | 80 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 81 | self.attn_drop = nn.Dropout(attn_drop) 82 | self.proj = nn.Linear(dim, dim) 83 | self.proj_drop = nn.Dropout(proj_drop) 84 | 85 | def forward(self, x): 86 | B, N, C = x.shape 87 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 88 | q, k, v = qkv[0], qkv[1], qkv[2] 89 | 90 | attn = (q @ k.transpose(-2, -1)) * self.scale 91 | attn = attn.softmax(dim=-1) 92 | attn = self.attn_drop(attn) 93 | 94 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 95 | x = self.proj(x) 96 | x = self.proj_drop(x) 97 | return x 98 | 99 | class Block(nn.Module): 100 | 101 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 102 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 103 | super().__init__() 104 | self.norm1 = norm_layer(dim) 105 | self.attn = Attention( 106 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 107 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 108 | self.norm2 = norm_layer(dim) 109 | mlp_hidden_dim = int(dim * mlp_ratio) 110 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 111 | 112 | def forward(self, x): 113 | x = x + self.drop_path(self.attn(self.norm1(x))) 114 | x = x + self.drop_path(self.mlp(self.norm2(x))) 115 | return x 116 | 117 | 118 | class C_DePatch(nn.Module): 119 | def __init__(self, channel=3, embed_dim=128, patch_size=16): 120 | self.patch_size = patch_size 121 | super().__init__() 122 | self.projection = nn.Sequential( 123 | nn.Linear(embed_dim, patch_size**2), 124 | ) 125 | 126 | def forward(self, x, ori): 127 | b, c, h, w = ori 128 | h_ = h // self.patch_size 129 | w_ = w // self.patch_size 130 | x = self.projection(x) 131 | x = rearrange(x, '(b h w) c (p1 p2) -> b c (h p1) (w p2)', h=h_, w=w_, p1=self.patch_size, p2=self.patch_size) 132 | return x 133 | 134 | #class C_DePatch(nn.Module): 135 | # def __init__(self, channel=3, embed_dim=128, patch_size=16): 136 | # self.patch_size = patch_size 137 | # super().__init__() 138 | # self.projection = nn.Sequential( 139 | # nn.Linear(embed_dim, patch_size**2), 140 | # ) 141 | # self.f = nn.Linear(channel, 1) 142 | # 143 | # def forward(self, x, ori): 144 | # b, c, h, w = ori 145 | # h_ = h // self.patch_size 146 | # w_ = w // self.patch_size 147 | # x = self.projection(x) 148 | # x = rearrange(x, '(b h w) c (p1 p2) -> (b h w) (p1 p2) c', h=h_, w=w_, p1=self.patch_size, p2=self.patch_size) 149 | # x = self.f(x) 150 | # x = rearrange(x, '(b h w) (p1 p2) c -> b c (h p1) (w p2)', h=h_, w=w_, p1=self.patch_size, p2=self.patch_size) 151 | # return x 152 | 153 | class S_DePatch(nn.Module): 154 | def __init__(self, channel=16, embed_dim=128, patch_size=16): 155 | self.patch_size = patch_size 156 | super().__init__() 157 | self.projection = nn.Sequential( 158 | nn.Linear(embed_dim, patch_size**2), 159 | ) 160 | 161 | def forward(self, x, ori): 162 | b, c, h, w = ori 163 | h_ = h // self.patch_size 164 | w_ = w // self.patch_size 165 | x = self.projection(x) 166 | x = rearrange(x, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=h_, w=w_, p1=self.patch_size, p2=self.patch_size) 167 | return x 168 | 169 | 170 | 171 | class encoder(nn.Module): 172 | def __init__(self, embed_dim=256, depth=4, 
173 | num_heads=4, mlp_ratio=2., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 174 | drop_path_rate=0., norm_layer=nn.LayerNorm): 175 | super().__init__() 176 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 177 | 178 | self.pos_drop = nn.Dropout(p=drop_rate) 179 | 180 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 181 | self.blocks = nn.ModuleList([ 182 | Block( 183 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 184 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 185 | for i in range(depth)]) 186 | self.norm = norm_layer(embed_dim) 187 | 188 | # Classifier head 189 | # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 190 | 191 | # trunc_normal_(self.cls_token, std=.02) 192 | 193 | def forward(self, x): 194 | 195 | # x = self.tokens_to_token(x) 196 | 197 | # x = self.pos_embedding(x, ori) 198 | 199 | for blk in self.blocks: 200 | x = blk(x) 201 | 202 | x = self.norm(x) 203 | return x 204 | 205 | 206 | 207 | class Channel(nn.Module): 208 | def __init__(self, size=224,embed_dim=128, depth=4, channel=16, 209 | num_heads=4, mlp_ratio=2., patch_size=16,qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 210 | drop_path_rate=0., norm_layer=nn.LayerNorm): 211 | super().__init__() 212 | 213 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 214 | self.embedding = nn.Sequential( 215 | Rearrange('b c (h p1) (w p2) -> (b h w) c (p1 p2)', p1=patch_size, p2=patch_size), 216 | nn.Linear(patch_size**2, embed_dim), 217 | ) 218 | # self.embedding = nn.Sequential( 219 | # nn.Conv2d(channel, channel, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size)), 220 | # Rearrange('b c h w -> b c (h w)'), 221 | # ) 222 | # self.linear = nn.Linear(embed_dim, embed_dim) 223 | 224 | self.pos_drop = nn.Dropout(p=drop_rate) 225 | 226 | self.en = encoder(embed_dim, depth, 227 | num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, 228 | drop_path_rate, norm_layer) 229 | 230 | self.depatch = C_DePatch(channel=channel, embed_dim=embed_dim, patch_size=patch_size) 231 | 232 | def forward(self, x): 233 | ori = x.shape 234 | x2_t = self.embedding(x) 235 | x2_t = self.pos_drop(x2_t) 236 | x2_t = self.en(x2_t) 237 | out = self.depatch(x2_t, ori) 238 | return out 239 | 240 | 241 | class Spatial(nn.Module): 242 | def __init__(self, size=256, embed_dim=128, depth=4, channel=16, 243 | num_heads=4, mlp_ratio=2., patch_size=16, qkv_bias=False, qk_scale=None, drop_rate=0., 244 | attn_drop_rate=0., 245 | drop_path_rate=0., norm_layer=nn.LayerNorm): 246 | super().__init__() 247 | 248 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 249 | self.embedding = nn.Sequential( 250 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size), 251 | nn.Linear(patch_size ** 2 * channel, embed_dim), 252 | ) 253 | # self.embedding = nn.Conv2d(in_chans, embed_dim*in_chans, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size)) 254 | # self.linear = nn.Linear(embed_dim, embed_dim) 255 | 256 | self.pos_drop = nn.Dropout(p=drop_rate) 257 | 258 | self.en = encoder(embed_dim, depth, 259 | num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, 260 | drop_path_rate, norm_layer) 261 | 262 | self.depatch = S_DePatch(channel=channel, 
embed_dim=embed_dim, patch_size=patch_size) 263 | 264 | def forward(self, x): 265 | ori = x.shape 266 | x2_t = self.embedding(x) 267 | x2_t = self.pos_drop(x2_t) 268 | x2_t = self.en(x2_t) 269 | out = self.depatch(x2_t, ori) 270 | return out 271 | 272 | 273 | 274 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # test phase 2 | 3 | import os 4 | 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 6 | # from skimage import io, transform 7 | import torch 8 | import torchvision 9 | from torch.autograd import Variable 10 | from function import adaptive_instance_normalization 11 | 12 | # from u2net_train import args 13 | import utils 14 | 15 | import numpy as np 16 | import time 17 | 18 | # from trans_net import ipt_base 19 | from net import net 20 | 21 | # normalize the predicted SOD probability map 22 | def load_model(path): 23 | # fuse_net = net() 24 | # pre_dict = torch.load(path) 25 | # new_pre = {} 26 | # for k, v in pre_dict.items(): 27 | # name = k[7:] 28 | # new_pre[name] = v 29 | # 30 | # fuse_net.load_state_dict(new_pre) 31 | 32 | fuse_net = net() 33 | fuse_net.load_state_dict(torch.load(path)) 34 | 35 | para = sum([np.prod(list(p.size())) for p in fuse_net.parameters()]) 36 | type_size = 4 37 | print('Model {} : params: {:4f}M'.format(fuse_net._get_name(), para * type_size / 1000 / 1000)) 38 | 39 | fuse_net.eval() 40 | fuse_net.cuda() 41 | 42 | return fuse_net 43 | 44 | 45 | def _generate_fusion_image(model, vi, ir): 46 | 47 | # vi_f, w_vi, vi_w = model.en_vi(vi) 48 | # ir_f, w_ir, ir_w = model.en_ir(ir) 49 | # fuse = model.fusion(vi_f, ir_f) 50 | # out = model.de(fuse) 51 | # vi_f = model.en_vi(vi) 52 | # ir_f = model.en_ir(ir) 53 | # 54 | # fusion = model.fusion(vi_f, ir_f) 55 | 56 | # en_img, ori_shape = model.en(vi) 57 | 58 | out = model(vi, ir) 59 | return out 60 | 61 | 62 | def run_demo(model, vi_path, ir_path, output_path_root, index): 63 | vi_img = utils.get_test_images(vi_path, height=None, width=None) 64 | ir_img = utils.get_test_images(ir_path, height=None, width=None) 65 | 66 | out = utils.get_image(vi_path, height=None, width=None) 67 | # dim = img_ir.shape 68 | 69 | vi_img = vi_img.cuda() 70 | ir_img = ir_img.cuda() 71 | vi_img = Variable(vi_img, requires_grad=False) 72 | ir_img = Variable(ir_img, requires_grad=False) 73 | # dimension = con_img.size() 74 | 75 | img_fusion = _generate_fusion_image(model, vi_img, ir_img) 76 | ############################ multi outputs ############################################## 77 | file_name = 'fusion_' + str(index) + '.png' 78 | output_path = output_path_root + file_name 79 | if torch.cuda.is_available(): 80 | img = img_fusion.cpu().clamp(0, 255).numpy() 81 | else: 82 | img = img_fusion.clamp(0, 255).numpy() 83 | img = img.astype('uint8') 84 | utils.save_images(output_path, img, out) 85 | # utils.save_images(output_path, img, out) 86 | print(output_path) 87 | 88 | 89 | def main(): 90 | vi_path = "images/Test_vi/" 91 | ir_path = "images/Test_ir/" 92 | # network_type = 'densefuse' 93 | 94 | output_path = './outputs/' 95 | # strategy_type = strategy_type_list[0] 96 | 97 | if os.path.exists(output_path) is False: 98 | os.mkdir(output_path) 99 | 100 | in_c = 1 101 | out_c = in_c 102 | model_path = "./models/Epoch_19_iters_2500.model" 103 | #model_path = "./models/Final_epoch_50.model" 104 | 105 | with torch.no_grad(): 106 | 107 | model = load_model(model_path) 108 | for i in range(10): 109 | index = i + 1 110 | 
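# The ten bundled test pairs are stored as images/Test_vi/<index>.bmp and
# images/Test_ir/<index>.bmp, so both input paths are built from the same index.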
visible_path = vi_path + str(index) + '.bmp' 111 | infrared_path = ir_path + str(index) + '.bmp' 112 | start = time.time() 113 | run_demo(model, visible_path, infrared_path, output_path, index) 114 | end = time.time() 115 | print('time:', end - start, 'S') 116 | print('Done......') 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Training DenseFuse network 2 | # auto-encoder 3 | 4 | import os 5 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 6 | from os.path import join 7 | # import sys 8 | import time 9 | import numpy as np 10 | # from tqdm import tqdm_notebook as tqdm 11 | from tqdm import tqdm, trange 12 | from time import sleep 13 | import scipy.io as scio 14 | import random 15 | import torch 16 | import torch.nn as nn 17 | from torch.optim import Adam 18 | from torch.autograd import Variable 19 | from torch.utils.tensorboard import SummaryWriter 20 | import utils 21 | from net import net 22 | from vit import VisionTransformer 23 | 24 | from args_fusion import args 25 | import pytorch_msssim 26 | from torchvision import transforms 27 | from loss import final_ssim, final_mse, dis_loss_func 28 | from function import Vgg16 29 | import torch.nn.functional as F 30 | # from aloss import a_ssim 31 | # from hloss import h_ssim 32 | # device = torch.device("cuda:0") 33 | 34 | 35 | def main(): 36 | # original_imgs_path = utils.list_images(args.dataset) 37 | original_imgs_path2 = utils.list_images(args.dataset2) 38 | train_num = args.train_num 39 | # original_imgs_path = original_imgs_path[:train_num] 40 | original_imgs_path2 = original_imgs_path2[:train_num] 41 | random.shuffle(original_imgs_path2) 42 | # for i in range(5): 43 | i = 2 44 | train(i, original_imgs_path2) 45 | 46 | 47 | def train(i, original_imgs_path): 48 | batch_size = args.batch_size 49 | 50 | in_c = 1 # 1 - gray; 3 - RGB 51 | if in_c == 1: 52 | img_model = 'L' 53 | else: 54 | img_model = 'RGB' 55 | # model = Generator() 56 | gen = net() 57 | dis1 = Vgg16() 58 | dis2 = Vgg16() 59 | # vgg = Vgg16() 60 | # pre_model = Pre() 61 | 62 | if args.trans_model_path is not None: 63 | pre_dict = torch.load(args.trans_model_path)['state_dict'] 64 | 65 | if args.resume is not None: 66 | print('Resuming, initializing using weight from {}.'.format(args.resume)) 67 | gen.load_state_dict(torch.load(args.resume)) 68 | print(gen) 69 | 70 | #optimizer = Adam(model.parameters(), args.lr) 71 | mse_loss = torch.nn.MSELoss() 72 | L1_loss = nn.L1Loss() 73 | # ssim_loss = final_ssim 74 | ssim_loss = pytorch_msssim.ssim 75 | bce_loss = nn.BCEWithLogitsLoss() 76 | writer = SummaryWriter('./log') 77 | 78 | 79 | if args.cuda: 80 | gen.cuda() 81 | dis1.cuda() 82 | dis2.cuda() 83 | # vgg.cuda() 84 | 85 | # vgg.eval() 86 | # dis1.eval() 87 | 88 | tbar = trange(args.epochs, ncols=150) 89 | print('Start training.....') 90 | 91 | # creating save path 92 | temp_path_model = os.path.join(args.save_model_dir, args.ssim_path[i]) 93 | if os.path.exists(temp_path_model) is False: 94 | os.mkdir(temp_path_model) 95 | 96 | temp_path_loss = os.path.join(args.save_loss_dir, args.ssim_path[i]) 97 | if os.path.exists(temp_path_loss) is False: 98 | os.mkdir(temp_path_loss) 99 | 100 | 101 | # Loss_con = [] 102 | Loss_gen = [] 103 | Loss_all = [] 104 | Loss_dis1 = [] 105 | Loss_dis2 = [] 106 | 107 | all_ssim_loss = 0 108 | all_gen_loss = 0. 109 | all_dis_loss1 = 0. 110 | all_dis_loss2 = 0. 
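# Reader note on the epoch/batch loop that follows (added commentary; the code itself is unchanged).
# Per batch, three updates are performed:
#   1) the generator `gen` is updated with an SSIM-based fusion loss,
#      1 - final_ssim(img_ir, img_vi, outputs); the MSE term is currently disabled (con_loss_temp = 0);
#   2) `dis1` (a Vgg16) is updated with an L1 loss between the first returned feature map
#      of a freshly fused output and that of the visible image;
#   3) `dis2` (a second Vgg16) is updated with an L1 loss between the third returned feature map
#      of another freshly fused output and that of the infrared image.
# As written, the three Adam optimizers are re-instantiated inside the batch loop, so their
# moment estimates do not carry over between iterations.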
111 | w_num = 0 112 | for e in tbar: 113 | print('Epoch %d.....' % e) 114 | # load training database 115 | image_set, batches = utils.load_dataset(original_imgs_path, batch_size) 116 | gen.train() 117 | count = 0 118 | 119 | # if e != 0: 120 | # args.lr = args.lr * 0.5 121 | # if args.lr < 2e-6: 122 | # args.lr = 2e-6 123 | 124 | for batch in range(batches): 125 | 126 | image_paths = image_set[batch * batch_size:(batch * batch_size + batch_size)] 127 | # directory1 = "/data/Disk_B/KAIST-RGBIR/visible" 128 | # directory2 = "/data/Disk_B/KAIST-RGBIR/lwir" 129 | directory1 = "D:\\file\paper\dataset\dataset\\train\\vi" 130 | directory2 = "D:\\file\paper\dataset\dataset\\train\ir" 131 | paths1 = [] 132 | paths2 = [] 133 | for path in image_paths: 134 | paths1.append(join(directory1, path)) 135 | paths2.append(join(directory2, path)) 136 | # paths = [] 137 | # for path in image_paths: 138 | # paths.append(join(args.dataset, path)) 139 | 140 | # img = utils.get_train_images_auto(paths, height=args.HEIGHT, width=args.WIDTH, mode=img_model) 141 | img_vi = utils.get_train_images_auto(paths1, height=args.HEIGHT, width=args.WIDTH, mode=img_model) 142 | img_ir = utils.get_train_images_auto(paths2, height=args.HEIGHT, width=args.WIDTH, mode=img_model) 143 | 144 | 145 | count += 1 146 | 147 | optimizer_G = Adam(gen.parameters(), args.lr) 148 | optimizer_G.zero_grad() 149 | 150 | optimizer_D1 = Adam(dis1.parameters(), args.lr_d) 151 | optimizer_D1.zero_grad() 152 | 153 | optimizer_D2 = Adam(dis2.parameters(), args.lr_d) 154 | optimizer_D2.zero_grad() 155 | 156 | 157 | if args.cuda: 158 | # img = img.cuda() 159 | img_vi = img_vi.cuda() 160 | img_ir = img_ir.cuda() 161 | 162 | outputs = gen(img_vi, img_ir) 163 | # resolution loss 164 | # img = Variable(img.data.clone(), requires_grad=False) 165 | 166 | con_loss_value = 0 167 | ssim_loss_value = 0 168 | 169 | ssim_loss_temp = 1 - final_ssim(img_ir, img_vi, outputs) 170 | # con_loss_temp = final_mse(img_ir, img_vi, outputs) 171 | con_loss_temp = 0 172 | 173 | 174 | con_loss_value += con_loss_temp 175 | ssim_loss_value += ssim_loss_temp 176 | 177 | _, c, h, w = outputs.size() 178 | con_loss_value /= len(outputs) 179 | ssim_loss_value /= len(outputs) 180 | 181 | # total loss 182 | gen_loss = ssim_loss_value + con_loss_value 183 | gen_loss.backward() 184 | optimizer_G.step() 185 | # scheduler.step() 186 | 187 | #------------------------------------------------------------------------------------------------------------------- 188 | vgg_out = dis1(gen(img_vi, img_ir))[0] 189 | vgg_vi = dis1(img_vi)[0] 190 | 191 | 192 | dis_loss1 = L1_loss(vgg_out, vgg_vi) 193 | 194 | dis_loss_value1 = 0 195 | dis_loss_temp1 = dis_loss1 196 | dis_loss_value1 += dis_loss_temp1 197 | 198 | dis_loss_value1 /= len(outputs) 199 | 200 | dis_loss_value1.backward() 201 | optimizer_D1.step() 202 | # ---------------------------------------------------------------------------------------------------------------- 203 | vgg_out = dis2(gen(img_vi, img_ir))[2] 204 | vgg_ir = dis2(img_ir)[2] 205 | dis_loss2 = L1_loss(vgg_out, vgg_ir) 206 | 207 | dis_loss_value2 = 0 208 | dis_loss_temp2 = dis_loss2 209 | dis_loss_value2 += dis_loss_temp2 210 | 211 | dis_loss_value2 /= len(outputs) 212 | 213 | dis_loss_value2.backward() 214 | optimizer_D2.step() 215 | 216 | # all_con_loss += con_loss_value.item() 217 | all_ssim_loss += ssim_loss_value.item() 218 | all_dis_loss1 += dis_loss_value1.item() 219 | all_dis_loss2 += dis_loss_value2.item() 220 | all_gen_loss = all_ssim_loss 221 | if (batch + 1) % 
args.log_interval == 0: 222 | mesg = "{}\tEpoch {}:[{}/{}] gen loss: {:.5f} dis_ir loss: {:.5f} dis_vi loss: {:.5f}".format( 223 | time.ctime(), e + 1, count, batches, 224 | all_gen_loss / args.log_interval, 225 | all_dis_loss1 / args.log_interval, 226 | all_dis_loss2 / args.log_interval 227 | #(all_con_loss + all_ssim_loss) / args.log_interval 228 | ) 229 | tbar.set_description(mesg) 230 | # tbar.close() 231 | 232 | # tqdm.write(mesg) 233 | 234 | # all_l = (all_con_loss + all_ssim_loss) / args.log_interval 235 | # Loss_con.append(all_con_loss / args.log_interval) 236 | # Loss_ssim.append(all_ssim_loss / args.log_interval) 237 | Loss_gen.append(all_ssim_loss / args.log_interval) 238 | Loss_dis1.append(all_dis_loss1 / args.log_interval) 239 | Loss_dis2.append(all_dis_loss2 / args.log_interval) 240 | # Loss_all.append((all_con_loss + all_ssim_loss) / args.log_interval) 241 | writer.add_scalar('gen', all_gen_loss / args.log_interval, w_num) 242 | writer.add_scalar('dis_ir', all_dis_loss1 / args.log_interval, w_num) 243 | writer.add_scalar('dis_vi', all_dis_loss2 / args.log_interval, w_num) 244 | # writer.add_scalar('loss_ssim', all_ssim_loss / args.log_interval, w_num) 245 | w_num += 1 246 | 247 | all_con_loss = 0. 248 | all_ssim_loss = 0. 249 | 250 | if (batch + 1) % (args.train_num//args.batch_size) == 0: 251 | # save model 252 | gen.eval() 253 | gen.cpu() 254 | save_model_filename = "Epoch_" + str(e) + "_iters_" + str(count) + ".model" 255 | save_model_path = os.path.join(args.save_model_dir, save_model_filename) 256 | torch.save(gen.state_dict(), save_model_path) 257 | gen.train() 258 | gen.cuda() 259 | tbar.set_description("\nCheckpoint, trained model saved at", save_model_path) 260 | 261 | gen.eval() 262 | gen.cpu() 263 | save_model_filename = "Final_epoch_" + str(args.epochs) + ".model" 264 | save_model_path = os.path.join(args.save_model_dir, save_model_filename) 265 | torch.save(gen.state_dict(), save_model_path) 266 | 267 | print("\nDone, trained model saved at", save_model_path) 268 | 269 | 270 | if __name__ == "__main__": 271 | main() 272 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import listdir, mkdir, sep 3 | import random 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.autograd import Variable 8 | from args_fusion import args 9 | # from scipy.misc import imread, imsave, imresize 10 | import cv2 11 | 12 | 13 | # import matplotlib as mpl 14 | from torchvision import transforms 15 | 16 | 17 | def list_images(directory): 18 | images = [] 19 | names = [] 20 | dir = listdir(directory) 21 | dir.sort() 22 | for file in dir: 23 | name = file.lower() 24 | if name.endswith('.png'): 25 | images.append( file) 26 | elif name.endswith('.jpg'): 27 | images.append(file) 28 | elif name.endswith('.jpeg'): 29 | images.append(file) 30 | name1 = name.split('.') 31 | names.append(name1[0]) 32 | return images 33 | 34 | 35 | def tensor_load_rgbimage(filename, size=None, scale=None, keep_asp=False): 36 | img = Image.open(filename).convert('RGB') 37 | if size is not None: 38 | if keep_asp: 39 | size2 = int(size * 1.0 / img.size[0] * img.size[1]) 40 | img = img.resize((size, size2), Image.ANTIALIAS) 41 | else: 42 | img = img.resize((size, size), Image.ANTIALIAS) 43 | 44 | elif scale is not None: 45 | img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS) 46 | img = 
np.array(img).transpose(2, 0, 1) 47 | img = torch.from_numpy(img).float() 48 | return img 49 | 50 | 51 | def tensor_save_rgbimage(tensor, filename, cuda=True): 52 | if cuda: 53 | # img = tensor.clone().cpu().clamp(0, 255).numpy() 54 | img = tensor.cpu().clamp(0, 255).data[0].numpy() 55 | else: 56 | # img = tensor.clone().clamp(0, 255).numpy() 57 | img = tensor.clamp(0, 255).numpy() 58 | img = img.transpose(1, 2, 0).astype('uint8') 59 | img = Image.fromarray(img) 60 | img.save(filename) 61 | 62 | 63 | def tensor_save_bgrimage(tensor, filename, cuda=False): 64 | (b, g, r) = torch.chunk(tensor, 3) 65 | tensor = torch.cat((r, g, b)) 66 | tensor_save_rgbimage(tensor, filename, cuda) 67 | 68 | 69 | def gram_matrix(y): 70 | (b, ch, h, w) = y.size() 71 | features = y.view(b, ch, w * h) 72 | features_t = features.transpose(1, 2) 73 | gram = features.bmm(features_t) / (ch * h * w) 74 | return gram 75 | 76 | 77 | def matSqrt(x): 78 | U,D,V = torch.svd(x) 79 | return U * (D.pow(0.5).diag()) * V.t() 80 | 81 | 82 | # load training images 83 | def load_dataset(image_path, BATCH_SIZE, num_imgs=None): 84 | if num_imgs is None: 85 | num_imgs = len(image_path) 86 | original_imgs_path = image_path[:num_imgs] 87 | # random 88 | random.shuffle(original_imgs_path) 89 | mod = num_imgs % BATCH_SIZE 90 | print('BATCH SIZE %d.' % BATCH_SIZE) 91 | print('Train images number %d.' % num_imgs) 92 | print('Train images samples %s.' % str(num_imgs / BATCH_SIZE)) 93 | 94 | if mod > 0: 95 | print('Train set has been trimmed %d samples...\n' % mod) 96 | original_imgs_path = original_imgs_path[:-mod] 97 | batches = int(len(original_imgs_path) // BATCH_SIZE) 98 | return original_imgs_path, batches 99 | 100 | 101 | def get_image(path, height=256, width=256, mode='L'): 102 | if mode == 'L': 103 | image = cv2.imread(path, 0) 104 | elif mode == 'RGB': 105 | image = cv2.cvtColor(path, cv2.COLOR_BGR2RGB) 106 | 107 | if height is not None and width is not None: 108 | image = cv2.resize(image, (width, height), interpolation=cv2.INTER_LINEAR) 109 | return image 110 | 111 | 112 | def get_train_images_auto(paths, height=256, width=256, mode='RGB'): 113 | if isinstance(paths, str): 114 | paths = [paths] 115 | images = [] 116 | for path in paths: 117 | image = get_image(path, height, width, mode=mode) 118 | if mode == 'L': 119 | image = np.reshape(image, [1, image.shape[0], image.shape[1]]) 120 | else: 121 | image = np.reshape(image, [image.shape[2], image.shape[0], image.shape[1]]) 122 | images.append(image) 123 | 124 | images = np.stack(images, axis=0) 125 | images = torch.from_numpy(images).float() 126 | return images 127 | 128 | # def get_test_images(paths, height=None, width=None, mode='L'): 129 | # ImageToTensor = transforms.Compose([transforms.ToTensor()]) 130 | # if isinstance(paths, str): 131 | # paths = [paths] 132 | # images = [] 133 | # for path in paths: 134 | # image = get_image(path, height, width, mode=mode) 135 | # w, d = image.shape[0], image.shape[1] 136 | # w = int(w / 32) * 32 137 | # d = int(d / 32) * 32 138 | # image = cv2.resize(image, [d, w]) 139 | # if mode == 'L': 140 | # image = np.reshape(image, [1, image.shape[0], image.shape[1]]) 141 | # else: 142 | # # test = ImageToTensor(image).numpy() 143 | # # shape = ImageToTensor(image).size() 144 | # image = ImageToTensor(image).float().numpy()*255 145 | # images.append(image) 146 | # images = np.stack(images, axis=0) 147 | # images = torch.from_numpy(images).float() 148 | # return images 149 | def get_test_images(paths, height=None, width=None, mode='L'): 150 | 
ImageToTensor = transforms.Compose([transforms.ToTensor()]) 151 | if isinstance(paths, str): 152 | paths = [paths] 153 | images = [] 154 | for path in paths: 155 | image = get_image(path, height, width, mode=mode) 156 | w, h = image.shape[0], image.shape[1] 157 | w_s = 256 - w % 256 158 | h_s = 256 - h % 256 159 | image = cv2.copyMakeBorder(image, 0, w_s, 0, h_s, cv2.BORDER_CONSTANT, 160 | value=128) 161 | if mode == 'L': 162 | image = np.reshape(image, [1, image.shape[0], image.shape[1]]) 163 | else: 164 | image = ImageToTensor(image).float().numpy()*255 165 | images.append(image) 166 | images = np.stack(images, axis=0) 167 | images = torch.from_numpy(images).float() 168 | return images 169 | 170 | def patch_test(paths, height=None, width=None, mode='L'): 171 | ImageToTensor = transforms.Compose([transforms.ToTensor()]) 172 | if isinstance(paths, str): 173 | paths = [paths] 174 | images = [] 175 | for path in paths: 176 | image = get_image(path, height, width, mode=mode) 177 | w, h = image.shape[0], image.shape[1] 178 | w_s = 256 - w % 256 179 | h_s = 256 - h % 256 180 | image = cv2.copyMakeBorder(image, 0, w_s, 0, h_s, cv2.BORDER_CONSTANT, value=128) 181 | nw = (w // 256 + 1) 182 | nh = (h // 256 + 1) 183 | crop = [] 184 | if mode == 'L': 185 | for j in range(nh): 186 | for i in range(nw): 187 | crop.append(image[i*256:(i+1)*256, j*256:(j+1)*256]) 188 | crop = np.stack(crop, axis=0) 189 | # image = np.reshape(image, [1, image.shape[0], image.shape[1]]) 190 | else: 191 | image = ImageToTensor(image).float().numpy() * 255 192 | images.append(crop) 193 | images = np.stack(images, axis=1) 194 | images = torch.from_numpy(images).float() 195 | return images 196 | 197 | 198 | # colormap 199 | # def colormap(): 200 | # return mpl.colors.LinearSegmentedColormap.from_list('cmap', ['#FFFFFF', '#98F5FF', '#00FF00', '#FFFF00','#FF0000', '#8B0000'], 256) 201 | 202 | 203 | def save_patch_images(path, data, out): 204 | 205 | if data.shape[1] == 1: 206 | data = data.reshape([data.shape[0], data.shape[2], data.shape[3]]) 207 | w, h = out.shape[0], out.shape[1] 208 | nw = (w // 256 + 1) 209 | nh = (h // 256 + 1) 210 | result = np.zeros((nw*256, nh*256)) 211 | num = 0 212 | for j in range(nh): 213 | for i in range(nw): 214 | result[i * 256:(i + 1) * 256, j * 256:(j + 1) * 256] = data[num] 215 | num += 1 216 | 217 | ori = result[0:w, 0:h] 218 | cv2.imwrite(path, ori) 219 | 220 | def save_images(path, data, out): 221 | w, h = out.shape[0], out.shape[1] 222 | if data.shape[1] == 1: 223 | data = data.reshape([data.shape[2], data.shape[3]]) 224 | ori = data[0:w, 0:h] 225 | cv2.imwrite(path, ori) 226 | -------------------------------------------------------------------------------- /utils/data_vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | def plot_img_and_mask(img, mask): 5 | classes = mask.shape[2] if len(mask.shape) > 2 else 1 6 | fig, ax = plt.subplots(1, classes + 1) 7 | ax[0].set_title('Input image') 8 | ax[0].imshow(img) 9 | if classes > 1: 10 | for i in range(classes): 11 | ax[i+1].set_title(f'Output mask (class {i+1})') 12 | ax[i+1].imshow(mask[:, :, i]) 13 | else: 14 | ax[1].set_title(f'Output mask') 15 | ax[1].imshow(mask) 16 | plt.xticks([]), plt.yticks([]) 17 | plt.show() 18 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | from os.path import splitext 2 | from os import 
listdir 3 | import numpy as np 4 | from glob import glob 5 | import torch 6 | from torch.utils.data import Dataset 7 | import logging 8 | from PIL import Image 9 | 10 | 11 | class BasicDataset(Dataset): 12 | def __init__(self, imgs_dir, masks_dir, scale=1): 13 | self.imgs_dir = imgs_dir 14 | self.masks_dir = masks_dir 15 | self.scale = scale 16 | assert 0 < scale <= 1, 'Scale must be between 0 and 1' 17 | 18 | self.ids = [splitext(file)[0] for file in listdir(imgs_dir) 19 | if not file.startswith('.')] 20 | logging.info(f'Creating dataset with {len(self.ids)} examples') 21 | 22 | def __len__(self): 23 | return len(self.ids) 24 | 25 | @classmethod 26 | def preprocess(cls, pil_img, scale): 27 | w, h = pil_img.size 28 | newW, newH = int(scale * w), int(scale * h) 29 | assert newW > 0 and newH > 0, 'Scale is too small' 30 | pil_img = pil_img.resize((newW, newH)) 31 | 32 | img_nd = np.array(pil_img) 33 | 34 | if len(img_nd.shape) == 2: 35 | img_nd = np.expand_dims(img_nd, axis=2) 36 | 37 | # HWC to CHW 38 | img_trans = img_nd.transpose((2, 0, 1)) 39 | if img_trans.max() > 1: 40 | img_trans = img_trans / 255 41 | 42 | return img_trans 43 | 44 | def __getitem__(self, i): 45 | idx = self.ids[i] 46 | mask_file = glob(self.masks_dir + idx + '*') 47 | img_file = glob(self.imgs_dir + idx + '*') 48 | 49 | assert len(mask_file) == 1, \ 50 | f'Either no mask or multiple masks found for the ID {idx}: {mask_file}' 51 | assert len(img_file) == 1, \ 52 | f'Either no image or multiple images found for the ID {idx}: {img_file}' 53 | mask = Image.open(mask_file[0]) 54 | img = Image.open(img_file[0]) 55 | 56 | assert img.size == mask.size, \ 57 | f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}' 58 | 59 | img = self.preprocess(img, self.scale) 60 | mask = self.preprocess(mask, self.scale) 61 | 62 | return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)} 63 | -------------------------------------------------------------------------------- /vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class PositionEmbs(nn.Module): 8 | def __init__(self, num_patches, emb_dim, dropout_rate=0.1): 9 | super(PositionEmbs, self).__init__() 10 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim)) 11 | if dropout_rate > 0: 12 | self.dropout = nn.Dropout(dropout_rate) 13 | else: 14 | self.dropout = None 15 | 16 | def forward(self, x): 17 | out = x + self.pos_embedding 18 | 19 | if self.dropout: 20 | out = self.dropout(out) 21 | 22 | return out 23 | 24 | 25 | class MlpBlock(nn.Module): 26 | """ Transformer Feed-Forward Block """ 27 | def __init__(self, in_dim, mlp_dim, out_dim, dropout_rate=0.1): 28 | super(MlpBlock, self).__init__() 29 | 30 | # init layers 31 | self.fc1 = nn.Linear(in_dim, mlp_dim) 32 | self.fc2 = nn.Linear(mlp_dim, out_dim) 33 | self.act = nn.GELU() 34 | if dropout_rate > 0.0: 35 | self.dropout1 = nn.Dropout(dropout_rate) 36 | self.dropout2 = nn.Dropout(dropout_rate) 37 | else: 38 | self.dropout1 = None 39 | self.dropout2 = None 40 | 41 | def forward(self, x): 42 | 43 | out = self.fc1(x) 44 | out = self.act(out) 45 | if self.dropout1: 46 | out = self.dropout1(out) 47 | 48 | out = self.fc2(out) 49 | out = self.dropout2(out) 50 | return out 51 | 52 | 53 | class LinearGeneral(nn.Module): 54 | def __init__(self, in_dim=(768,), feat_dim=(12, 64)): 55 | super(LinearGeneral, self).__init__() 56 | 57 
| self.weight = nn.Parameter(torch.randn(*in_dim, *feat_dim)) 58 | self.bias = nn.Parameter(torch.zeros(*feat_dim)) 59 | 60 | def forward(self, x, dims): 61 | a = torch.tensordot(x, self.weight, dims=dims) + self.bias 62 | return a 63 | 64 | 65 | class SelfAttention(nn.Module): 66 | def __init__(self, in_dim, heads=8, dropout_rate=0.1): 67 | super(SelfAttention, self).__init__() 68 | self.heads = heads 69 | self.head_dim = in_dim // heads 70 | self.scale = self.head_dim ** 0.5 71 | 72 | self.query = LinearGeneral((in_dim,), (self.heads, self.head_dim)) 73 | self.key = LinearGeneral((in_dim,), (self.heads, self.head_dim)) 74 | self.value = LinearGeneral((in_dim,), (self.heads, self.head_dim)) 75 | self.out = LinearGeneral((self.heads, self.head_dim), (in_dim,)) 76 | 77 | if dropout_rate > 0: 78 | self.dropout = nn.Dropout(dropout_rate) 79 | else: 80 | self.dropout = None 81 | 82 | def forward(self, x): 83 | b, n, _ = x.shape 84 | 85 | q = self.query(x, dims=([2], [0])) 86 | k = self.key(x, dims=([2], [0])) 87 | v = self.value(x, dims=([2], [0])) 88 | 89 | q = q.permute(0, 2, 1, 3) 90 | k = k.permute(0, 2, 1, 3) 91 | v = v.permute(0, 2, 1, 3) 92 | 93 | attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale 94 | attn_weights = F.softmax(attn_weights, dim=-1) 95 | out = torch.matmul(attn_weights, v) 96 | out = out.permute(0, 2, 1, 3) 97 | 98 | out = self.out(out, dims=([2, 3], [0, 1])) 99 | 100 | return out 101 | 102 | 103 | class EncoderBlock(nn.Module): 104 | def __init__(self, in_dim, mlp_dim, num_heads, dropout_rate=0.1, attn_dropout_rate=0.1): 105 | super(EncoderBlock, self).__init__() 106 | 107 | self.norm1 = nn.LayerNorm(in_dim) 108 | self.attn = SelfAttention(in_dim, heads=num_heads, dropout_rate=attn_dropout_rate) 109 | if dropout_rate > 0: 110 | self.dropout = nn.Dropout(dropout_rate) 111 | else: 112 | self.dropout = None 113 | self.norm2 = nn.LayerNorm(in_dim) 114 | self.mlp = MlpBlock(in_dim, mlp_dim, in_dim, dropout_rate) 115 | 116 | def forward(self, x): 117 | residual = x 118 | out = self.norm1(x) 119 | out = self.attn(out) 120 | if self.dropout: 121 | out = self.dropout(out) 122 | out += residual 123 | residual = out 124 | 125 | out = self.norm2(out) 126 | out = self.mlp(out) 127 | out += residual 128 | return out 129 | 130 | 131 | class Encoder(nn.Module): 132 | def __init__(self, num_patches, emb_dim, mlp_dim, num_layers=12, num_heads=12, dropout_rate=0.1, attn_dropout_rate=0.0): 133 | super(Encoder, self).__init__() 134 | 135 | # positional embedding 136 | self.pos_embedding = PositionEmbs(num_patches, emb_dim, dropout_rate) 137 | 138 | # encoder blocks 139 | in_dim = emb_dim 140 | self.encoder_layers = nn.ModuleList() 141 | for i in range(num_layers): 142 | layer = EncoderBlock(in_dim, mlp_dim, num_heads, dropout_rate, attn_dropout_rate) 143 | self.encoder_layers.append(layer) 144 | self.norm = nn.LayerNorm(in_dim) 145 | 146 | def forward(self, x): 147 | 148 | out = self.pos_embedding(x) 149 | 150 | for layer in self.encoder_layers: 151 | out = layer(out) 152 | 153 | out = self.norm(out) 154 | return out 155 | 156 | 157 | class VisionTransformer(nn.Module): 158 | """ Vision Transformer """ 159 | def __init__(self, 160 | image_size=(224, 224), 161 | patch_size=(16, 16), 162 | emb_dim=768, 163 | mlp_dim=3072, 164 | num_heads=12, 165 | num_layers=4, 166 | num_classes=1, 167 | attn_dropout_rate=0.0, 168 | dropout_rate=0.1, 169 | feat_dim=None): 170 | super(VisionTransformer, self).__init__() 171 | h, w = image_size 172 | 173 | # embedding layer 174 | fh, fw = 
patch_size
175 |         gh, gw = h // fh, w // fw
176 |         num_patches = gh * gw
177 |         self.embedding = nn.Conv2d(3, emb_dim, kernel_size=(fh, fw), stride=(fh, fw))
178 |         # class token
179 |         self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
180 | 
181 |         # transformer
182 |         self.transformer = Encoder(
183 |             num_patches=num_patches,
184 |             emb_dim=emb_dim,
185 |             mlp_dim=mlp_dim,
186 |             num_layers=num_layers,
187 |             num_heads=num_heads,
188 |             dropout_rate=dropout_rate,
189 |             attn_dropout_rate=attn_dropout_rate)
190 | 
191 |         # classifier
192 |         # self.classifier = nn.Linear(emb_dim, num_classes)
193 |         self.conv = nn.Conv2d(1, 3, 1, 1)
194 |         self.f = nn.Linear(emb_dim, num_classes)
195 | 
196 |     def forward(self, x):
197 |         x = self.conv(x)
198 |         emb = self.embedding(x)  # (n, c, gh, gw)
199 |         emb = emb.permute(0, 2, 3, 1)  # (n, gh, gw, c)
200 |         b, h, w, c = emb.shape
201 |         emb = emb.reshape(b, h * w, c)
202 | 
203 |         # prepend class token
204 |         cls_token = self.cls_token.repeat(b, 1, 1)
205 |         emb = torch.cat([cls_token, emb], dim=1)
206 | 
207 |         # transformer
208 |         feat = self.transformer(emb)
209 | 
210 |         # classifier
211 |         logits = self.f(feat[:, 0])
212 |         return logits
213 | 
214 | 
215 | # if __name__ == '__main__':
216 | #     model = VisionTransformer(num_layers=2)
217 | #     x = torch.randn((2, 3, 256, 256))
218 | #     out = model(x)
219 | #
220 | #     state_dict = model.state_dict()
221 | #
222 | #     for key, value in state_dict.items():
223 | #         print("{}: {}".format(key, value.shape))
224 | 
225 | 
226 | 
227 | 
228 | 
229 | 
230 | 
231 | 
232 | 
233 | 
234 | 
--------------------------------------------------------------------------------
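A minimal single-pair inference sketch, distilled from test.py above. It assumes a CUDA device and that the checkpoint referenced in test.py (./models/Epoch_19_iters_2500.model, distributed via the Baiduyun link under models/) has been downloaded; it reuses only helpers defined in this repository (net.net, utils.get_test_images, utils.get_image, utils.save_images) and is a reader aid, not a file in the repo.

import torch
import utils
from net import net

# Load the trained fusion network; weights must be downloaded separately.
model = net()
model.load_state_dict(torch.load("./models/Epoch_19_iters_2500.model"))
model.eval()
model.cuda()

# Grayscale test pair; get_test_images pads the bottom/right borders toward a
# multiple of 256 before stacking into a (1, 1, H_pad, W_pad) tensor.
vi = utils.get_test_images("images/Test_vi/1.bmp", height=None, width=None).cuda()
ir = utils.get_test_images("images/Test_ir/1.bmp", height=None, width=None).cuda()
# Unpadded visible image, kept only so save_images can crop the padding back off.
ref = utils.get_image("images/Test_vi/1.bmp", height=None, width=None)

with torch.no_grad():
    fused = model(vi, ir)

img = fused.cpu().clamp(0, 255).numpy().astype("uint8")
utils.save_images("./outputs/fusion_1.png", img, ref)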