├── .github └── workflows │ └── sync.yml ├── .gitignore ├── LICENSE ├── README.md ├── archive ├── __init__.py ├── fuse.py ├── model.py └── params.pth ├── exp ├── __init__.py ├── find_adjust │ ├── __init__.py │ └── find_adjust.py └── test_register │ ├── __init__.py │ ├── test_register.py │ └── weights │ ├── m-register.ckpt │ └── u-register.ckpt ├── lightning ├── __init__.py ├── auto_rf.py └── reco.py ├── modules ├── __init__.py ├── functions │ ├── __init__.py │ ├── integrate.py │ └── transformer.py ├── fuser.py ├── layers │ ├── __init__.py │ ├── conv_group.py │ ├── d_group.py │ ├── u_group.py │ └── u_net.py ├── m_register.py ├── random_adjust.py └── u_register.py ├── scripts ├── pred.py └── train.py └── utils ├── __init__.py ├── choose_images.py └── pretty_vars.py /.github/workflows/sync.yml: -------------------------------------------------------------------------------- 1 | name: Mirror to DUT DIMT 2 | 3 | on : [ push, delete, create ] 4 | 5 | jobs: 6 | git-mirror: 7 | runs-on: ubuntu-latest 8 | steps : 9 | - 10 | name: Configure Private Key 11 | env : 12 | SSH_PRIVATE_KEY: ${{ secrets.PRIVATE_KEY }} 13 | run : | 14 | mkdir -p ~/.ssh 15 | echo "$SSH_PRIVATE_KEY" > ~/.ssh/id_rsa 16 | chmod 600 ~/.ssh/id_rsa 17 | echo "StrictHostKeyChecking no" >> ~/.ssh/config 18 | - 19 | name: Push Mirror 20 | env : 21 | SOURCE_REPO : 'https://github.com/MisakiCoca/ReCoNet.git' 22 | DESTINATION_REPO: 'git@github.com:dlut-dimt/ReCoNet.git' 23 | run : | 24 | git clone --mirror "$SOURCE_REPO" && cd `basename "$SOURCE_REPO"` 25 | git remote set-url --push origin "$DESTINATION_REPO" 26 | git fetch -p origin 27 | git for-each-ref --format 'delete %(refname)' refs/pull | git update-ref --stdin 28 | git push --mirror 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # JetBarins 2 | .idea/* 3 | 4 | # macOS 5 | .DS_*/* 6 | 7 | # data 8 | data/* 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Zhanbo Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReCoNet
2 |
3 | ![visitors](https://visitor-badge.glitch.me/badge?page_id=MisakiCoca.ReCoNet)
4 |
5 | Zhanbo Huang, Jinyuan Liu, Xin Fan*, Risheng Liu, Wei Zhong, Zhongxuan Luo.
6 | **"Recurrent Correction Network for Fast and Efficient Multi-modality Image Fusion"**, European Conference on Computer
7 | Vision **(ECCV)**, 2022.
8 |
9 | ## Milestone
10 |
11 | In the near future, we will publish the following materials.
12 |
13 | * v0 [ECCV]: Fuse network (ReCo) with pre-trained parameters for generating the results in the paper. **Finished**
14 | * v1: A new script & architecture of ReCo+ for fast training & prediction. **In progress**
15 | * v1: Highly robust pre-trained parameters for ReCo+ based on realistic scene training. (We are currently collecting
16 | realistic data.)
17 |
18 | ## Update
19 |
20 | [2022-07-15] Train script for ReCo(v1) is available!
21 |
22 | [2022-07-13] Preview of micro-register is available!
23 |
24 | [2022-07-12] ReCo(v0) is available!
25 |
26 | ## Requirements
27 |
28 | * Python 3.10
29 | * PyTorch 1.12
30 | * TorchVision 0.13.0
31 | * PyTorch Lightning 0.8.5
32 | * Kornia 0.6.5
33 |
34 | ## Extended Experiments
35 |
36 | ### Generate fake visible images
37 |
38 | To generate fake visible images as described in our paper, you can refer to our companion
39 | repository [complex-deformation](https://github.com/MisakiCoca/complex-deformation), which is a component of
40 | this work.
41 |
42 | It shows how we deform an image and generate a restoration field that **approximates** the ground truth.
43 |
44 | ### Have a quick preview of our micro-register
45 |
46 | For a quick preview of our micro-register module, you can try training & prediction on
47 | the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset.
48 |
49 | Activate your conda environment and enter the folder `exp/test_register`.
50 |
51 | 1. To train the register yourself, you just need to run the following.
52 |
53 | ```shell
54 | export PYTHONPATH="${PYTHONPATH}:$RECO_ROOT"
55 | python test_register.py --backbone $BACKBONE --dst $DST
56 | ```
57 |
58 | `$RECO_ROOT` is the root path of the ReCo repository, e.g. `~/lab/reco`, and `$BACKBONE` denotes which architecture to
59 | use: `m` (micro) or `u` (unet).
60 |
61 | The script automatically downloads the MNIST dataset, trains the register, and saves the predictions in `$DST`.
62 |
63 | 2. If you just want to test the performance, we offer pre-trained parameters for both the `micro` and `unet` based registers.
64 |
65 | ```shell
66 | export PYTHONPATH="${PYTHONPATH}:$RECO_ROOT"
67 | python test_register.py --backbone $BACKBONE --dst $DST --only_pred
68 | ```
69 |
70 | The prediction results will be saved in `$DST`, and the patches from left to right are `moving`, `fixed` and `moved`,
71 | respectively.
72 |
73 | ## Get started (v0) (**Recommended for now**)
74 |
75 | 1. To use our ECCV-22 pre-trained parameters for fusion, you need to prepare your dataset in `$ROOT/data/$NAME`.
76 |
77 | ```
78 | $DATA (dataset name, like: tno)
79 | ├── ir
80 | ├── vi
81 | ```
82 |
83 | 2. Enter the archive folder (`cd archive`) and activate your conda environment (`conda activate $CONDA_ENV`).
84 |
85 | ```shell
86 | export PYTHONPATH="${PYTHONPATH}:$RECO_ROOT"
87 | python fuse.py --ir ../data/$DATA/ir --vi ../data/$DATA/vi --dst $SAVE_TO_WHERE
88 | ```
89 |
90 | 3.
Now, you will find the fusion results in `$SAVE_TO_WHERE`; the output folder is created automatically.
91 |
92 | ## Get started (v1) **Preview Version**
93 |
94 | **Only recommended if you intend to train ReCo+ yourself.**
95 |
96 | **Note: since the micro-register module is not yet stable, we recommend training only the fusion
97 | part for now.**
98 |
99 | 1. To train ReCo+ yourself with this script, you need to prepare your dataset in `$ROOT/data/$NAME`.
100 |
101 | ```
102 | $DATA (dataset name, like: tno)
103 | ├── ir
104 | ├── vi
105 | ├── iqa (new for v1, optional)
106 | │   ├── ir (information measurement for infrared images)
107 | │   ├── vi (information measurement for visible images)
108 | ├── meta (new for v1)
109 | │   ├── train.txt (which images are used for training)
110 | │   ├── val.txt (which images are used for validation)
111 | │   ├── pred.txt (which images are used for prediction)
112 | ```
113 |
114 | 2. Activate your conda environment `conda activate $CONDA_ENV`.
115 |
116 | ```shell
117 | # set project path for python
118 | export PYTHONPATH="${PYTHONPATH}:$RECO_ROOT"
119 | # only train the fuse part (ReCo) **currently recommended**
120 | python train.py --data data/$DATA --ckpt $CHECKPOINT_PATH --lr 1e-3
121 | # train registration and fuse (ReCo+)
122 | python train.py --register m --data data/$DATA --ckpt $CHECKPOINT_PATH --lr 1e-3 --deform $DEFORM_LEVEL
123 | ```
124 |
125 | The `$DEFORM_LEVEL` should be `easy`, `normal` or `hard`.
126 |
127 | ⚠️ Limitations: As mentioned in the paper, when the difference between mid-wave infrared and visible images in your
128 | dataset is too large, the register may not converge properly.
129 |
130 | 3. To generate fusion images with pre-trained parameters, just run the following.
131 |
132 | ```shell
133 | # set project path for python
134 | export PYTHONPATH="${PYTHONPATH}:$RECO_ROOT"
135 | # only the fuse part (ReCo) **currently recommended**
136 | python pred.py --data data/$DATA --ckpt $CHECKPOINT_PATH --dst $SAVE_TO_WHERE
137 | # registration & fuse (ReCo+)
138 | python pred.py --register m --data data/$DATA --ckpt $CHECKPOINT_PATH --dst $SAVE_TO_WHERE
139 | ```
140 |
--------------------------------------------------------------------------------
/archive/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/archive/__init__.py
--------------------------------------------------------------------------------
/archive/fuse.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pathlib
3 | import statistics
4 | import time
5 |
6 | import cv2
7 | import kornia
8 | import torch
9 | import torch.backends.cudnn
10 | from tqdm import tqdm
11 |
12 | from archive.model import Fuser
13 |
14 |
15 | class Fuse:
16 |     """
17 |     Fuse images with given args.
18 |     """
19 |
20 |     def __init__(self, checkpoint: pathlib.Path, loop_num: int = 3, dim: int = 64):
21 |         """
22 |         Init model and load pre-trained parameters.
23 | :param checkpoint: pre-trained model checkpoint 24 | :param loop_num: AFuse recurrent loop number, default: 3 25 | :param dim: AFuse feather number, default: 64 26 | """ 27 | 28 | # device [cuda or cpu] 29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 30 | self.device = device 31 | 32 | # load pre-trained network 33 | net = Fuser(loop_num=loop_num, feather_num=dim) 34 | net.load_state_dict(torch.load(str(checkpoint), map_location='cpu')) 35 | net.to(device) 36 | net.eval() 37 | self.net = net 38 | 39 | @torch.no_grad() 40 | def __call__(self, ir_path: pathlib.Path, vi_path: pathlib.Path, dst: pathlib.Path): 41 | """ 42 | Fuse image with infrared folder, visible folder and destination path. 43 | :param ir_path: infrared folder path 44 | :param vi_path: visible folder path 45 | :param dst: fused images destination path 46 | """ 47 | 48 | # src list 49 | ir_list = [x for x in ir_path.glob('*') if x.suffix in ['.bmp', '.jpg', '.png']] 50 | vi_list = [x for x in vi_path.glob('*') if x.suffix in ['.bmp', '.jpg', '.png']] 51 | 52 | # time record 53 | fuse_time = [] 54 | 55 | # fuse images 56 | src = tqdm(zip(ir_list, vi_list)) 57 | for ir_path, vi_path in src: 58 | "fuse one pair with src image path" 59 | 60 | # judge image pair 61 | assert ir_path.name == vi_path.name 62 | src.set_description(f'fuse {ir_path.name}') 63 | 64 | # read image with Tensor 65 | ir = self._imread(ir_path).unsqueeze(0) 66 | vi = self._imread(vi_path).unsqueeze(0) 67 | ir = ir.to(self.device) 68 | vi = vi.to(self.device) 69 | 70 | # network flow 71 | torch.cuda.synchronize() if str(self.device) == 'cuda' else None 72 | start = time.time() 73 | im_f, _, _ = self.net([ir, vi]) 74 | torch.cuda.synchronize() if str(self.device) == 'cuda' else None 75 | end = time.time() 76 | fuse_time.append(end - start) 77 | 78 | # save fusion image 79 | self._imsave(dst / ir_path.name, im_f[-1]) 80 | 81 | # analyze fuse time 82 | std = statistics.stdev(fuse_time[1:]) 83 | avg = statistics.mean(fuse_time[1:]) 84 | print(f'fuse std time: {std:.4f}(s)') 85 | print(f'fuse avg time: {avg:.4f}(s)') 86 | print('fps (equivalence): {:.4f}'.format(1. / avg)) 87 | 88 | @staticmethod 89 | def _imread(path: pathlib.Path, flags=cv2.IMREAD_GRAYSCALE) -> torch.Tensor: 90 | im_cv = cv2.imread(str(path), flags) 91 | im_ts = kornia.utils.image_to_tensor(im_cv / 255.0).type(torch.FloatTensor) 92 | return im_ts 93 | 94 | @staticmethod 95 | def _imsave(path: pathlib.Path, image: torch.Tensor): 96 | im_ts = image.squeeze().cpu() 97 | path.parent.mkdir(parents=True, exist_ok=True) 98 | im_cv = kornia.utils.tensor_to_image(im_ts) * 255. 
99 | cv2.imwrite(str(path), im_cv) 100 | 101 | 102 | def hyper_args(): 103 | """ 104 | get hyper parameters from args 105 | """ 106 | 107 | parser = argparse.ArgumentParser(description='ReCo(v0) fuse process') 108 | 109 | # dataset 110 | parser.add_argument('--ir', default='../data/tno/ir', help='infrared image folder') 111 | parser.add_argument('--vi', default='../data/tno/vi', help='visible image folder') 112 | parser.add_argument('--dst', default='../runs/archive', help='fuse image save folder') 113 | # checkpoint 114 | parser.add_argument('--cp', default='params.pth', help='weight checkpoint') 115 | # fuse network 116 | parser.add_argument('--loop', default=3, type=int, help='fuse loop time') 117 | parser.add_argument('--dim', default=64, type=int, help='fuse feather dim') 118 | 119 | args = parser.parse_args() 120 | return args 121 | 122 | 123 | if __name__ == '__main__': 124 | # hyper parameters 125 | args = hyper_args() 126 | 127 | f = Fuse(checkpoint=pathlib.Path(args.cp), loop_num=args.loop, dim=args.dim) 128 | f(pathlib.Path(args.ir), pathlib.Path(args.vi), pathlib.Path(args.dst)) 129 | -------------------------------------------------------------------------------- /archive/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Fuser(nn.Module): 6 | """ 7 | Fuse the two input images. 8 | """ 9 | 10 | def __init__(self, loop_num=3, feather_num=64, fine_tune=False): 11 | super().__init__() 12 | self.loop_num = loop_num 13 | self.fine_tune = fine_tune 14 | 15 | # attention layer 16 | self.att_a_conv = nn.Conv2d(2, 1, 3, padding=1, bias=False) 17 | self.att_b_conv = nn.Conv2d(2, 1, 3, padding=1, bias=False) 18 | 19 | # dilation conv layer 20 | self.dil_conv_1 = nn.Sequential(nn.Conv2d(3, feather_num, 3, 1, 1, 1), nn.BatchNorm2d(feather_num), nn.ReLU()) 21 | self.dil_conv_2 = nn.Sequential(nn.Conv2d(3, feather_num, 3, 1, 2, 2), nn.BatchNorm2d(feather_num), nn.ReLU()) 22 | self.dil_conv_3 = nn.Sequential(nn.Conv2d(3, feather_num, 3, 1, 3, 3), nn.BatchNorm2d(feather_num), nn.ReLU()) 23 | 24 | # fuse conv layer 25 | self.fus_conv = nn.Sequential(nn.Conv2d(3 * feather_num, 1, 3, padding=1), nn.BatchNorm2d(1), nn.Tanh()) 26 | 27 | def forward(self, im_p): 28 | """ 29 | :param im_p: image pair 30 | """ 31 | 32 | # unpack im_p 33 | im_a, im_b = im_p 34 | 35 | # recurrent sub network 36 | # generate f_0 with manual function 37 | im_f = [torch.max(im_a, im_b)] # init im_f_0 38 | att_a = [] 39 | att_b = [] 40 | 41 | # loop in sub network 42 | for e in range(self.loop_num): 43 | im_f_x, att_a_x, att_b_x = self._sub_forward(im_a, im_b, im_f[-1]) 44 | im_f.append(im_f_x) 45 | att_a.append(att_a_x) 46 | att_b.append(att_b_x) 47 | 48 | # return im_f, att list 49 | return im_f, att_a, att_b 50 | 51 | def _sub_forward(self, im_a, im_b, im_f): 52 | # attention 53 | att_a = self._attention(self.att_a_conv, im_a, im_f) 54 | att_b = self._attention(self.att_b_conv, im_b, im_f) 55 | att_a = att_a.detach() if self.fine_tune else att_a 56 | att_b = att_b.detach() if self.fine_tune else att_b 57 | 58 | # focus on attention 59 | im_a_att = im_a * att_a 60 | im_b_att = im_b * att_b 61 | 62 | # image concat 63 | im_cat = torch.cat([im_a_att, im_f, im_b_att], dim=1) 64 | im_cat = im_cat.detach() if self.fine_tune else im_cat 65 | 66 | # dilation 67 | dil_1 = self.dil_conv_1(im_cat) 68 | dil_2 = self.dil_conv_2(im_cat) 69 | dil_3 = self.dil_conv_3(im_cat) 70 | 71 | # feather concat 72 | f_cat = torch.cat([dil_1, 
dil_2, dil_3], dim=1) 73 | 74 | # fuse 75 | im_f_n = self.fus_conv(f_cat) 76 | 77 | return im_f_n, att_a, att_b 78 | 79 | @staticmethod 80 | def _attention(att_conv, im_x, im_f): 81 | x = torch.cat([im_x, im_f], dim=1) 82 | x_max, _ = torch.max(x, dim=1, keepdim=True) 83 | x_avg = torch.mean(x, dim=1, keepdim=True) 84 | x = torch.cat([x_max, x_avg], dim=1) 85 | x = att_conv(x) 86 | return torch.sigmoid(x) 87 | -------------------------------------------------------------------------------- /archive/params.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/archive/params.pth -------------------------------------------------------------------------------- /exp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/exp/__init__.py -------------------------------------------------------------------------------- /exp/find_adjust/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/exp/find_adjust/__init__.py -------------------------------------------------------------------------------- /exp/find_adjust/find_adjust.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | import torch 5 | from kornia import image_to_tensor 6 | from kornia.filters import canny 7 | from torch import Tensor 8 | from torchvision.transforms import Normalize 9 | from torchvision.utils import save_image 10 | 11 | 12 | def gray_read(img_path: str | Path) -> Tensor: 13 | img_n = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE) 14 | img_t = image_to_tensor(img_n).float() / 255 15 | return img_t 16 | 17 | 18 | def find_adjust(): 19 | # init normalize 20 | norm = Normalize(mean=0.44, std=0.27) 21 | # load source image 22 | ir = gray_read('../../data/tno/ir/A_028.bmp') 23 | vi = gray_read('../../data/tno/vi/A_028.bmp') 24 | # apply canny filter 25 | ir_mag, ir_e = canny(ir.unsqueeze(0)) 26 | vi_mag, vi_e = canny(vi.unsqueeze(0)) 27 | # magnitude max 28 | ir_max = torch.where(ir_mag > 0.1, 1, 0) 29 | vi_max = torch.where(vi_mag > 0.1, 1, 0) 30 | # output 31 | img = torch.hstack([x.squeeze() for x in [ir_mag, ir_e, ir_max, vi_mag, vi_e, vi_max]]) 32 | save_image(img, 'tmp.jpg') 33 | 34 | 35 | if __name__ == '__main__': 36 | find_adjust() 37 | -------------------------------------------------------------------------------- /exp/test_register/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/exp/test_register/__init__.py -------------------------------------------------------------------------------- /exp/test_register/test_register.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | from pathlib import Path 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from pytorch_lightning import Callback 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from torch import Tensor 9 | from torch.nn.functional import mse_loss 10 | from torch.utils.data import DataLoader 11 | from torchvision import transforms 12 | from 
torchvision.datasets import MNIST 13 | from torchvision.utils import save_image 14 | 15 | from modules.functions.integrate import integrate 16 | from modules.functions.transformer import transformer 17 | from modules.m_register import MRegister 18 | from modules.u_register import URegister 19 | 20 | 21 | class RegisterTest(pl.LightningModule): 22 | def __init__(self, args: Namespace): 23 | super().__init__() 24 | if args.backbone == 'm': 25 | print('Init with m-register') 26 | self.register = MRegister(in_c=2, dim=16, k_sizes=(3, 5, 7), use_bn=False, use_rs=True) 27 | elif args.backbone == 'u': 28 | print('Init with u-register') 29 | self.register = URegister(in_c=2, dim=16) 30 | else: 31 | assert NotImplemented, f'No match backbone: {args.backbone}' 32 | 33 | def training_step(self, batch, batch_idx): 34 | # batch: [b, 1, h, w] -> [b/2, 2, h, w] -> [b/2, (moving, fixed), h, w] 35 | img, _ = batch 36 | moving, fixed = torch.chunk(img, chunks=2, dim=0) 37 | # pred: moving & grid -> moved 38 | moved, locs = self.forward(moving, fixed) 39 | # loss function 40 | img_loss = mse_loss(moved, fixed) 41 | dx = torch.abs(moved[:, :, 1:, :] - moved[:, :, :-1, :]) 42 | dy = torch.abs(moved[:, :, :, 1:] - moved[:, :, :, :-1]) 43 | smo_loss = (torch.mean(dx * dx) + torch.mean(dy * dy)) / 2 44 | rig_loss = img_loss * 0.95 + smo_loss * 0.05 45 | return rig_loss 46 | 47 | def forward(self, moving: Tensor, fixed: Tensor) -> tuple[Tensor, Tensor]: 48 | # moving & fixed: [b, 1, h, w] 49 | # pred: moving & grid -> moved 50 | flow = self.register(torch.cat([moving, fixed], dim=1)) 51 | flow = integrate(n_step=7, flow=flow) 52 | moved, locs = transformer(moving, flow) 53 | # output 54 | return moved, locs 55 | 56 | def predict_step(self, batch, batch_idx, dataloader_idx=0) -> Tensor: 57 | # batch: [2, 1, h, w] -> [(moving, fixed), 1, h, w] 58 | img, _ = batch 59 | moving, fixed = torch.chunk(img, chunks=2, dim=0) 60 | # pred: moving & grid -> moved 61 | moved, locs = self.forward(moving, fixed) 62 | # output 63 | return torch.hstack([x.squeeze() for x in [moving, fixed, moved]]) 64 | 65 | def configure_optimizers(self): 66 | optimizer = torch.optim.AdamW(self.register.parameters(), lr=1e-3) 67 | return optimizer 68 | 69 | @staticmethod 70 | def add_model_specific_args(parent_parser): 71 | parser = parent_parser.add_argument_group('RegisterTest') 72 | parser.add_argument('--backbone', type=str, default='m') 73 | return parent_parser 74 | 75 | 76 | class SaveFigure(Callback): 77 | def __init__(self, dst: str | Path): 78 | super().__init__() 79 | self.dst = Path(dst) 80 | self.dst.mkdir(parents=True, exist_ok=True) 81 | 82 | def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 83 | rst = outputs 84 | save_image(rst, self.dst / f'{str(batch_idx).zfill(3)}.jpg') 85 | 86 | 87 | def main(): 88 | # args parser 89 | parser = ArgumentParser() 90 | 91 | # program level args 92 | parser.add_argument('--dst', type=str, default='tmp') 93 | parser.add_argument('--only_pred', action='store_true', help='use pre-trained parameters') 94 | parser.add_argument('--ckpt', type=str, default='', help='use your pre-trained parameters (in only_pred mode)') 95 | 96 | # model specific args 97 | parser = RegisterTest.add_model_specific_args(parser) 98 | 99 | # parse 100 | args = parser.parse_args() 101 | 102 | # fix seed 103 | pl.seed_everything(443) 104 | 105 | # model 106 | rt = RegisterTest(args) 107 | 108 | # dataset 109 | transform = transforms.Compose([transforms.ToTensor(), 
transforms.Resize(size=(32, 32))]) 110 | mnist = MNIST('./', train=True, download=True, transform=transform) 111 | 112 | # callbacks 113 | callbacks = [ModelCheckpoint(dirpath='checkpoints', every_n_train_steps=5), SaveFigure(dst=args.dst)] 114 | 115 | # lightning 116 | trainer = pl.Trainer(accelerator='gpu', devices=-1, callbacks=callbacks, max_epochs=10) 117 | 118 | # train 119 | if not args.only_pred: 120 | loader = DataLoader(mnist, batch_size=32, shuffle=True) 121 | trainer.fit(model=rt, train_dataloaders=loader) 122 | 123 | # predict 124 | loader = DataLoader(mnist, batch_size=2, shuffle=True) 125 | if args.only_pred: 126 | ckpt = f'weights/{args.backbone}-register.ckpt' if args.ckpt == '' else args.ckpt 127 | trainer.predict(model=rt, dataloaders=loader, ckpt_path=ckpt) 128 | else: 129 | trainer.predict(model=rt, dataloaders=loader) 130 | 131 | 132 | if __name__ == '__main__': 133 | main() 134 | -------------------------------------------------------------------------------- /exp/test_register/weights/m-register.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/exp/test_register/weights/m-register.ckpt -------------------------------------------------------------------------------- /exp/test_register/weights/u-register.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/exp/test_register/weights/u-register.ckpt -------------------------------------------------------------------------------- /lightning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/lightning/__init__.py -------------------------------------------------------------------------------- /lightning/auto_rf.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from pathlib import Path 3 | from typing import Literal, List 4 | 5 | import cv2 6 | import torch 7 | from kornia import image_to_tensor, create_meshgrid 8 | from torch import Tensor, nn 9 | from torch.utils.data import Dataset 10 | from torchvision.transforms import Resize, RandomResizedCrop 11 | 12 | from modules.random_adjust import RandomAdjust 13 | 14 | 15 | class AutoRF(Dataset): 16 | def __init__( 17 | self, 18 | root: str | Path, 19 | mode: str = Literal['train', 'val', 'pred'], 20 | level: str = Literal['none', 'easy', 'normal', 'hard'], 21 | iqa: bool = True, 22 | ): 23 | super().__init__() 24 | root = Path(root) 25 | self.root = root 26 | self.mode = mode 27 | self.iqa = iqa 28 | 29 | # filter 30 | list_f = root / 'meta' / f'{mode}.txt' 31 | assert list_f.exists(), f'find no meta file in path {str(list_f)}' 32 | t = list_f.read_text().splitlines() 33 | 34 | # get sample list 35 | types = ['.jpg', '.bmp', '.png'] 36 | samples = [x.name for x in sorted((root / 'ir').glob('*')) if x.suffix in types] 37 | samples = list(filter(lambda x: Path(x).stem in t, samples)) 38 | self.samples = samples 39 | 40 | # init complex transform 41 | match level: 42 | case 'easy': 43 | easy = {'transforms': 'ep', 'kernel_size': (143, 143), 'sigma': (32, 32), 'distortion_scale': 0.2} 44 | self.adjust = RandomAdjust(easy) 45 | case 'normal': 46 | normal = {'transforms': 'ep', 'kernel_size': (103, 103), 'sigma': (32, 32), 
'distortion_scale': 0.3} 47 | self.adjust = RandomAdjust(normal) 48 | case 'hard': 49 | hard = {'transforms': 'ep', 'kernel_size': (63, 63), 'sigma': (32, 32), 'distortion_scale': 0.4} 50 | self.adjust = RandomAdjust(hard) 51 | case _: 52 | self.adjust = None 53 | 54 | # init transform 55 | match mode: 56 | case 'train': 57 | self.transform = RandomResizedCrop(size=(320, 320)) 58 | case 'val': 59 | self.transform = Resize(size=(320, 320)) 60 | case _: 61 | self.transform = nn.Identity() 62 | 63 | def __len__(self) -> int: 64 | return len(self.samples) 65 | 66 | def __getitem__(self, index: int) -> dict: 67 | # find sample with index 68 | name = self.samples[index] 69 | 70 | # load infrared and visible 71 | x = self.gray_read(self.root / 'ir' / name) 72 | y = self.gray_read(self.root / 'vi' / name) 73 | 74 | # load information measurement 75 | if self.iqa: 76 | x_w = self.gray_read(self.root / 'iqa' / 'ir' / name) 77 | y_w = self.gray_read(self.root / 'iqa' / 'vi' / name) 78 | else: 79 | x_w, y_w = torch.ones_like(x), torch.ones_like(x) 80 | 81 | # transform (resize) 82 | t = torch.cat([x, y, x_w, y_w], dim=0) 83 | x, y, x_w, y_w = torch.chunk(self.transform(t), chunks=4, dim=0) 84 | 85 | # adjust (challenge simulations) (optional) 86 | h, w = y.size()[-2:] 87 | grid = create_meshgrid(h, w, device=y.device).to(y.dtype) 88 | if self.adjust is not None: 89 | y_t, params = self.adjust(y.unsqueeze(dim=0)) 90 | flow_gt = reduce(lambda i, j: i + j, [v for _, v in params.items()]) 91 | locs_gt = grid - flow_gt 92 | y_t.squeeze_(dim=0) # [1, 1, h, w] -> [1, h, w] 93 | else: 94 | y_t = y 95 | locs_gt = grid 96 | locs_gt.squeeze_(dim=0) # [1, h, w, 2] -> [h, w, 2] 97 | 98 | # merge data 99 | sample = {'name': name, 'ir': x, 'vi': y, 'ir_w': x_w, 'vi_w': y_w, 'vi_t': y_t, 'locs_gt': locs_gt} 100 | 101 | # return as except 102 | return sample 103 | 104 | @staticmethod 105 | def gray_read(img_path: str | Path) -> Tensor: 106 | img_n = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE) 107 | img_t = image_to_tensor(img_n).float() / 255 108 | return img_t 109 | 110 | @staticmethod 111 | def collate_fn(data: List[dict]) -> dict: 112 | # keys 113 | keys = data[0].keys() 114 | # merge 115 | new_data = {} 116 | for key in keys: 117 | k_data = [d[key] for d in data] 118 | new_data[key] = k_data if isinstance(k_data[0], str) else torch.stack(k_data) 119 | # return as expected 120 | return new_data 121 | -------------------------------------------------------------------------------- /lightning/reco.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from kornia.filters import canny 6 | from kornia.losses import ssim_loss 7 | from torch import Tensor 8 | from torch.nn.functional import mse_loss, l1_loss 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | 11 | from modules.functions.integrate import integrate 12 | from modules.functions.transformer import transformer 13 | from modules.fuser import Fuser 14 | from modules.m_register import MRegister 15 | from modules.u_register import URegister 16 | 17 | 18 | class ReCo(pl.LightningModule): 19 | def __init__(self, args: Namespace): 20 | super().__init__() 21 | self.save_hyperparameters() 22 | self.dim = args.dim 23 | 24 | # init register 25 | match args.register: 26 | case 'm': 27 | print('Init with m-register') 28 | self.register = MRegister(in_c=4, dim=16, k_sizes=(3, 5, 7), use_bn=False, use_rs=True) 29 | case 'u': 30 | 
print('Init with u-register') 31 | self.register = URegister(in_c=4, dim=16) 32 | case _: 33 | print('Turn off register') 34 | self.register = None 35 | 36 | # init fuser 37 | self.fuser = Fuser(depth=3, dim=self.dim, use_bn=False) 38 | 39 | # learning rate 40 | self.lr = args.lr 41 | 42 | # weight 43 | self.rf_weight = args.rf_weight 44 | self.r_weight, self.f_weight = args.r_weight, args.f_weight 45 | 46 | def training_step(self, batch, batch_idx): 47 | # infrared & visible: [b, 1, h, w] 48 | x, y = batch['ir'], batch['vi'] 49 | 50 | # register (optional): infrared (fixed) & visible (moving) -> y_m (moved) 51 | _, y_e = canny(y) 52 | if self.register is not None: 53 | y_t = batch['vi_t'] 54 | y_m, locs_pred, y_m_e = self.r_forward(moving=y_t, fixed=x) 55 | else: 56 | y_m, locs_pred, y_m_e = y, 0, y_e 57 | 58 | # fuser: infrared (infrared) & visible (visible) -> f (fusion) 59 | f = self.f_forward(ir=x, vi=y_m) 60 | 61 | # register loss (optional): 62 | if self.register is not None: 63 | # image loss: y_m_e (edges of moved) -> y (edges of visible) 64 | img_loss = mse_loss(y_m_e, y_e) 65 | self.log('reg/img', img_loss) 66 | # locs loss: locs_pred -> locs_gt 67 | locs_gt = batch['locs_gt'] 68 | locs_loss = mse_loss(locs_pred, locs_gt) 69 | self.log('reg/locs', locs_loss) 70 | # smooth loss: y_m_e (edges of moved) smooth 71 | dx = torch.abs(y_m_e[:, :, 1:, :] - y_m_e[:, :, :-1, :]) 72 | dy = torch.abs(y_m_e[:, :, :, 1:] - y_m_e[:, :, :, :-1]) 73 | smo_loss = (torch.mean(dx * dx) + torch.mean(dy * dy)) / 2 74 | self.log('reg/smooth', smo_loss) 75 | reg_loss = img_loss * self.r_weight[0] + locs_loss * self.r_weight[1] + smo_loss * self.r_weight[2] 76 | self.log('train/reg', reg_loss) 77 | else: 78 | reg_loss = 0 79 | 80 | # fuse loss with iqa (if iqa is disabled, x_w = y_w = 1) 81 | x_w, y_w = batch['ir_w'], batch['vi_w'] 82 | x_ssim = ssim_loss(f, x, window_size=11, reduction='none') 83 | y_ssim = ssim_loss(f, y, window_size=11, reduction='none') 84 | s_loss = x_ssim * x_w + y_ssim * y_w 85 | self.log('fus/ssim', s_loss.mean()) 86 | x_l1 = l1_loss(f, x, reduction='none') 87 | y_l1 = l1_loss(f, y, reduction='none') 88 | l_loss = x_l1 * x_w + y_l1 * y_w 89 | self.log('fus/l1', l_loss.mean()) 90 | fus_loss = self.f_weight[0] * s_loss + self.f_weight[1] * l_loss 91 | fus_loss = fus_loss.mean() 92 | self.log('train/fus', fus_loss) 93 | 94 | # final loss 95 | fin_loss = self.rf_weight[0] * reg_loss + self.rf_weight[1] * fus_loss 96 | self.log('train/fin', fin_loss) 97 | 98 | return fin_loss 99 | 100 | def validation_step(self, batch, batch_idx): 101 | # infrared & visible: [b, 1, h, w] 102 | x, y = batch['ir'], batch['vi'] 103 | 104 | # output 105 | o = [x, y] 106 | 107 | # register (optional): infrared (fixed) & visible (moving) -> y_m (moved) 108 | if self.register is not None: 109 | y_t = batch['vi_t'] 110 | y_m, _, _ = self.r_forward(moving=y_t, fixed=x) 111 | o += [y_t, y_m, y_m - y_t] 112 | else: 113 | y_m = y 114 | 115 | # fuser: ir (infrared) & vi (visible) -> f (fusion) 116 | f = self.f_forward(ir=x, vi=y_m) 117 | o += [f] 118 | 119 | # output 120 | o = torch.cat(o, dim=1) 121 | return o 122 | 123 | def predict_step(self, batch, batch_idx, dataloader_idx=0) -> [str, Tensor]: 124 | # infrared & visible (moving): [b, 1, h, w] 125 | x, y_t = batch['ir'], batch['vi'] 126 | 127 | # register (optional): infrared (fixed) & visible (moving) -> y_m (moved) 128 | if self.register is not None: 129 | y_m, _, _ = self.r_forward(moving=y_t, fixed=x) 130 | else: 131 | y_m = y_t 132 | 133 | # fuser: 
infrared (infrared) & visible (visible) -> f (fusion) 134 | f = self.f_forward(ir=x, vi=y_m) 135 | 136 | # output 137 | return batch['name'], f 138 | 139 | def r_forward(self, moving: Tensor, fixed: Tensor) -> tuple[Tensor, Tensor, Tensor]: 140 | # moving & fixed: [b, 1, h, w] 141 | # pred: moving & grid -> moved 142 | # apply transform 143 | moving_m, moving_e = canny(moving) 144 | fixed_m, fixed_e = canny(fixed) 145 | # predict flow 146 | flow = self.register(torch.cat([moving, fixed, moving_m, fixed_m], dim=1)) 147 | flow = integrate(n_step=7, flow=flow) 148 | moved, locs = transformer(moving, flow) 149 | moved_e, locs = transformer(moving_e, flow) 150 | return moved, locs, moved_e 151 | 152 | def f_forward(self, ir: Tensor, vi: Tensor) -> Tensor: 153 | # ir & vi: [b, 1, h, w] 154 | # pred: ir (infrared) & vi (visible) -> f (fusion) 155 | f = self.fuser(torch.cat([ir, vi], dim=1)) 156 | return f 157 | 158 | def configure_optimizers(self): 159 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 160 | scheduler = ReduceLROnPlateau(optimizer) 161 | return [optimizer], {'scheduler': scheduler, 'monitor': 'train/fin'} 162 | 163 | @staticmethod 164 | def add_model_specific_args(parent_parser): 165 | parser = parent_parser.add_argument_group('ReCo+') 166 | # reco 167 | parser.add_argument('--register', type=str, default='x', help='register (m: micro, u: u-net, x: none)') 168 | parser.add_argument('--dim', type=int, default=32, help='dimension in backbone (default: 16)') 169 | # optimizer 170 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)') 171 | # weights 172 | parser.add_argument('--rf_weight', nargs='+', type=float, help='balance in register & fuse') 173 | parser.add_argument('--r_weight', nargs='+', type=float, help='balance in register: img, locs, smooth') 174 | parser.add_argument('--f_weight', nargs='+', type=float, help='balance in fuse: ssim, l1') 175 | 176 | return parent_parser 177 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/modules/__init__.py -------------------------------------------------------------------------------- /modules/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/modules/functions/__init__.py -------------------------------------------------------------------------------- /modules/functions/integrate.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from modules.functions.transformer import transformer 4 | 5 | 6 | def integrate(n_step: int, flow: Tensor) -> Tensor: 7 | scale = 1.0 / (2 ** n_step) 8 | flow = flow * scale 9 | for _ in range(n_step): 10 | i_flow = flow.permute(0, 3, 1, 2) # [b, 2, h, w] 11 | o_flow = transformer(i_in=i_flow, flow=flow)[0].permute(0, 2, 3, 1) # [b, h, w, 2] 12 | flow = flow + o_flow 13 | return flow 14 | -------------------------------------------------------------------------------- /modules/functions/transformer.py: -------------------------------------------------------------------------------- 1 | from kornia import create_meshgrid 2 | from torch import Tensor 3 | from torch.nn import functional 4 | 5 | 6 | def transformer(i_in: 
Tensor, flow: Tensor) -> [Tensor, Tensor]: 7 | # create mesh grid: [1, h, w, 2] 8 | h, w = flow.size()[1:3] 9 | grid = create_meshgrid(height=h, width=w, normalized_coordinates=False, device=flow.device).to(flow.dtype) 10 | # new locations: [b, h, w, 2] 11 | locs = grid + flow 12 | # normalize 13 | locs[..., 0] = (locs[..., 0] / (w - 1) - 0.5) * 2 14 | locs[..., 1] = (locs[..., 1] / (h - 1) - 0.5) * 2 15 | # apply transform 16 | i_out = functional.grid_sample(i_in, locs, align_corners=True, mode='bilinear') 17 | # return moved image and flow 18 | return i_out, locs 19 | -------------------------------------------------------------------------------- /modules/fuser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | from modules.layers.d_group import DGroup 5 | 6 | 7 | class Fuser(nn.Module): 8 | def __init__(self, depth: int, dim: int, use_bn: bool): 9 | super().__init__() 10 | self.depth = depth 11 | 12 | # attention layer: [2] -> [1], [2] -> [1] 13 | self.att_a_conv = nn.Conv2d(2, 1, kernel_size=3, padding='same', bias=False) 14 | self.att_b_conv = nn.Conv2d(2, 1, kernel_size=3, padding='same', bias=False) 15 | 16 | # dilation fuse 17 | self.decoder = DGroup(in_c=3, out_c=1, dim=dim, k_size=3, use_bn=use_bn) 18 | 19 | def forward(self, i_in: Tensor, init_f: str = 'max', show_detail: bool = False): 20 | # recurrent subnetwork 21 | # generate f_0 with initial function 22 | i_1, i_2 = torch.chunk(i_in, chunks=2, dim=1) 23 | i_f = [torch.max(i_1, i_2) if init_f == 'max' else (i_1 + i_2) / 2] 24 | att_a, att_b = [], [] 25 | 26 | # loop in subnetwork 27 | for _ in range(self.depth): 28 | i_f_x, att_a_x, att_b_x = self._sub_forward(i_1, i_2, i_f[-1]) 29 | i_f.append(i_f_x), att_a.append(att_a_x), att_b.append(att_b_x) 30 | 31 | # return as expected 32 | return (i_f, att_a, att_b) if show_detail else i_f[-1] 33 | 34 | def _sub_forward(self, i_1: Tensor, i_2: Tensor, i_f: Tensor): 35 | # attention 36 | att_a = self._attention(self.att_a_conv, i_1, i_f) 37 | att_b = self._attention(self.att_b_conv, i_2, i_f) 38 | 39 | # focus on attention 40 | i_1_w = i_1 * att_a 41 | i_2_w = i_2 * att_b 42 | 43 | # dilation fuse 44 | i_in = torch.cat([i_1_w, i_f, i_2_w], dim=1) 45 | i_out = self.decoder(i_in) 46 | 47 | # return fusion result of current recurrence 48 | return i_out, att_a, att_b 49 | 50 | @staticmethod 51 | def _attention(att_conv, i_a, i_b): 52 | i_in = torch.cat([i_a, i_b], dim=1) 53 | i_max, _ = torch.max(i_in, dim=1, keepdim=True) 54 | i_avg = torch.mean(i_in, dim=1, keepdim=True) 55 | i_in = torch.cat([i_max, i_avg], dim=1) 56 | i_out = att_conv(i_in) 57 | return torch.sigmoid(i_out) 58 | -------------------------------------------------------------------------------- /modules/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/modules/layers/__init__.py -------------------------------------------------------------------------------- /modules/layers/conv_group.py: -------------------------------------------------------------------------------- 1 | from torch import nn, Tensor 2 | 3 | 4 | class ConvGroup(nn.Module): 5 | def __init__(self, conv: nn.Conv2d, use_bn: bool): 6 | super().__init__() 7 | 8 | # (Conv2d, BN, GELU) 9 | dim = conv.out_channels 10 | self.group = nn.Sequential( 11 | conv, 12 | nn.BatchNorm2d(dim) if use_bn else nn.Identity(), 13 | nn.GELU(), 14 
| ) 15 | 16 | def forward(self, x: Tensor) -> Tensor: 17 | return self.group(x) 18 | -------------------------------------------------------------------------------- /modules/layers/d_group.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tabulate import tabulate 3 | from torch import nn, Tensor 4 | 5 | from modules.layers.conv_group import ConvGroup 6 | 7 | 8 | class DGroup(nn.Module): 9 | """ 10 | [channels: dim, s] -> DGroup -> [channels: 1, s] 11 | """ 12 | 13 | def __init__(self, in_c: int, out_c: int, dim: int, k_size: int, use_bn: bool): 14 | super().__init__() 15 | 16 | # conv_d: [dim] -> [1] 17 | self.conv_d = nn.ModuleList([ 18 | ConvGroup(nn.Conv2d(in_c, dim, kernel_size=k_size, padding='same', dilation=(i + 1)), use_bn=use_bn) 19 | for i in range(3) 20 | ]) 21 | 22 | # conv_s: [3] -> [1] 23 | self.conv_s = nn.Sequential( 24 | nn.Conv2d(3 * dim, out_c, kernel_size=3, padding='same'), 25 | nn.Tanh(), 26 | ) 27 | 28 | def forward(self, x: Tensor) -> Tensor: 29 | f_in = x 30 | # conv_d 31 | f_x = [conv(f_in) for conv in self.conv_d] 32 | # suffix 33 | f_t = torch.cat(f_x, dim=1) 34 | f_out = self.conv_s(f_t) 35 | return f_out 36 | 37 | def __str__(self): 38 | table = [[n, p.mean(), p.grad.mean()] for n, p in self.named_parameters() if p.grad is not None] 39 | return tabulate(table, headers=['layer', 'weights', 'grad'], tablefmt='pretty') 40 | -------------------------------------------------------------------------------- /modules/layers/u_group.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tabulate import tabulate 3 | from torch import nn, Tensor 4 | 5 | from modules.layers.conv_group import ConvGroup 6 | 7 | 8 | class UGroup(nn.Module): 9 | """ 10 | [channels: in_c, s] -> UGroup -> [channels: out_c, s] 11 | """ 12 | 13 | def __init__(self, in_c: int, out_c: int, dim: int, k_size: int, use_bn: bool, use_rs: bool): 14 | super().__init__() 15 | 16 | # conv_1: [in_c, s] -> [d, s] 17 | self.conv_1 = ConvGroup(nn.Conv2d(in_c, dim, kernel_size=k_size, padding='same'), use_bn=use_bn) 18 | 19 | # down sample (pool): [d, s] -> [d, s/2] 20 | self.ds = nn.MaxPool2d(2, stride=2, ceil_mode=True) if use_rs else nn.Identity() 21 | 22 | # conv_2: [d, s/2] -> [d, s/2] 23 | self.conv_2 = ConvGroup(nn.Conv2d(dim, dim, kernel_size=k_size, padding='same'), use_bn=use_bn) 24 | 25 | # dilated conv: [d, s/2] -> [d, s/2] 26 | self.conv_dil = ConvGroup(nn.Conv2d(dim, dim, kernel_size=k_size, padding='same', dilation=2), use_bn=use_bn) 27 | 28 | # conv_3: [2d, s/2] -> [d, s/2] 29 | self.conv_3 = ConvGroup(nn.Conv2d(2 * dim, dim, kernel_size=k_size, padding='same'), use_bn=use_bn) 30 | 31 | # up sample: [d, s/2] -> [d, s] 32 | self.us = nn.Upsample(scale_factor=2, mode='bilinear') if use_rs else nn.Identity() 33 | 34 | # conv_4: [2d, s] -> [out_c, s] 35 | self.conv_4 = ConvGroup(nn.Conv2d(2 * dim, out_c, kernel_size=k_size, padding='same'), use_bn=use_bn) 36 | 37 | def forward(self, x: Tensor) -> Tensor: 38 | f_in = x 39 | # conv_1: [in_c, s] -> [d, s] 40 | f_1 = self.conv_1(f_in) 41 | # conv_2: [d, s/2] -> [d, s/2] 42 | f_t = self.ds(f_1) 43 | f_2 = self.conv_2(f_t) 44 | # conv_dil: [d, s/2] -> [d, s/2] 45 | f_d = self.conv_dil(f_2) 46 | # conv_3: [2d, s/2] -> [d, s/2] 47 | f_t = torch.cat((f_2, f_d), dim=1) 48 | f_3 = self.conv_3(f_t) 49 | # conv_4: [2d, s] -> [out_c, s] 50 | f_t = self.us(f_3) 51 | f_t = torch.cat([f_1, f_t], dim=1) 52 | f_out = self.conv_4(f_t) 53 | return f_out 
54 | 55 | def __str__(self): 56 | table = [[n, p.mean(), p.grad.mean()] for n, p in self.named_parameters() if p.grad is not None] 57 | return tabulate(table, headers=['layer', 'weights', 'grad'], tablefmt='pretty') 58 | -------------------------------------------------------------------------------- /modules/layers/u_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | from modules.layers.conv_group import ConvGroup 5 | 6 | 7 | class UNet(nn.Module): 8 | """ 9 | An u-net architecture. 10 | [channels: in_c, s] -> UNet -> [channels: out_c, s] 11 | 12 | default: 13 | encoder: [16, 32, 32, 32] 14 | decoder: [32, 32, 32, 32, 32, 16] 15 | """ 16 | 17 | def __init__(self, in_c: int, out_c: int): 18 | super().__init__() 19 | 20 | # encoder and decoder feature 21 | self.enc_c = [16, 32, 32, 32] 22 | self.dec_c = [32, 32, 32, 32, 32, 16, out_c] 23 | 24 | # upsample 25 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 26 | 27 | # configure encoder (down-sampling path) 28 | self.encoder = nn.ModuleList() 29 | prev_c = in_c 30 | for c in self.enc_c: 31 | # [size: s] -> [size: s/2] 32 | self.encoder.append(ConvGroup(nn.Conv2d(prev_c, c, kernel_size=3, stride=2, padding=1), use_bn=False)) 33 | prev_c = c 34 | 35 | # configure decoder (up-sampling path) 36 | rev_enc_c = list(reversed(self.enc_c)) 37 | self.decoder = nn.ModuleList() 38 | for i, c in enumerate(self.dec_c[:len(self.enc_c)]): 39 | tmp_c = prev_c + rev_enc_c[i] if i > 0 else prev_c 40 | self.decoder.append(ConvGroup(nn.Conv2d(tmp_c, c, kernel_size=3, padding='same'), use_bn=False)) 41 | prev_c = c 42 | 43 | # configure decoder suffix (no up-sampling) 44 | prev_c += in_c 45 | self.suffix = nn.ModuleList() 46 | for c in self.dec_c[len(self.enc_c):]: 47 | self.suffix.append(ConvGroup(nn.Conv2d(prev_c, c, kernel_size=3, padding='same'), use_bn=False)) 48 | prev_c = c 49 | 50 | def forward(self, x: Tensor) -> Tensor: 51 | # encoder 52 | f_in = [x] 53 | for layer in self.encoder: 54 | f_in.append(layer(f_in[-1])) 55 | 56 | # decoder: conv -> upsample -> concat 57 | f_x = f_in.pop() 58 | for layer in self.decoder: 59 | f_x = layer(f_x) 60 | f_x = self.upsample(f_x) 61 | f_x = torch.cat([f_x, f_in.pop()], dim=1) 62 | 63 | # suffix 64 | for layer in self.suffix: 65 | f_x = layer(f_x) 66 | 67 | return f_x 68 | -------------------------------------------------------------------------------- /modules/m_register.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | from modules.layers.conv_group import ConvGroup 5 | from modules.layers.d_group import DGroup 6 | from modules.layers.u_group import UGroup 7 | 8 | 9 | class MRegister(nn.Module): 10 | def __init__(self, in_c: int, dim: int, k_sizes: tuple, use_bn: bool, use_rs: bool): 11 | super().__init__() 12 | 13 | # stem: [2] -> [d] 14 | self.stem = ConvGroup(nn.Conv2d(in_c, dim, kernel_size=3, padding='same'), use_bn=use_bn) 15 | 16 | # group s: [d, 2d, 3d] -> [d] 17 | self.groups = nn.ModuleList( 18 | [ 19 | UGroup(in_c=(i + 1) * dim, out_c=dim, dim=dim, k_size=k_sizes[i], use_bn=use_bn, use_rs=use_rs) 20 | for i in range(3) 21 | ] 22 | ) 23 | 24 | # decoder: [3d] -> [2] 25 | self.decoder = DGroup(in_c=3 * dim, out_c=2, dim=dim, k_size=3, use_bn=use_bn) 26 | self.decoder.conv_s = nn.Conv2d(3 * dim, 2, kernel_size=3, padding='same') 27 | 28 | # init flow layer with small weights and bias 29 | 
self.decoder.conv_s.apply(self.init_weights) 30 | 31 | def forward(self, i_in: Tensor): 32 | # stem: [2] -> [d] 33 | f_x = self.stem(i_in) 34 | 35 | # group: [d, 2d, 3d] -> [3d] 36 | f_0 = self.groups[0](f_x) 37 | f_1 = self.groups[1](torch.cat([f_0, f_x], dim=1)) 38 | f_2 = self.groups[2](torch.cat([f_0, f_1, f_x], dim=1)) 39 | f_i = torch.cat([f_0, f_1, f_2], dim=1) # [b, 3d, h, w] 40 | 41 | # decoder: [3d] -> [2] 42 | flow = self.decoder(f_i).permute(0, 2, 3, 1) # [b, h, w, 2] 43 | 44 | # return middle vars during forward process 45 | return flow 46 | 47 | @staticmethod 48 | @torch.no_grad() 49 | def init_weights(m): 50 | if type(m) == nn.Conv2d: 51 | nn.init.normal_(m.weight, mean=0, std=1e-5) 52 | nn.init.zeros_(m.bias) 53 | -------------------------------------------------------------------------------- /modules/random_adjust.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia import create_meshgrid 3 | from kornia.augmentation import RandomPerspective, RandomElasticTransform 4 | from kornia.filters import get_gaussian_kernel2d, filter2d 5 | from kornia.geometry import normalize_homography, transform_points, get_perspective_transform 6 | from kornia.utils.helpers import _torch_inverse_cast 7 | from torch import nn, Tensor 8 | 9 | 10 | class RandomAdjust(nn.Module): 11 | def __init__(self, config: dict): 12 | super().__init__() 13 | self.config = config 14 | 15 | # elastic 16 | ks, sigma = config['kernel_size'], config['sigma'] 17 | re = RandomElasticTransform(kernel_size=ks, sigma=sigma, p=1) 18 | self.re = re 19 | 20 | # perspective 21 | ds = config['distortion_scale'] 22 | rp = RandomPerspective(distortion_scale=ds, p=1) 23 | self.rp = rp 24 | 25 | # swap params 26 | self.size = () 27 | self.device, self.dtype = torch.device('cpu'), torch.float 28 | 29 | def forward(self, x: Tensor) -> (Tensor, dict): 30 | # params 31 | self.size = x.size() 32 | self.device, self.dtype = x.device, x.dtype 33 | B, _, H, W = x.size() 34 | 35 | params = {} 36 | 37 | # elastic 38 | if 'e' in self.config['transforms']: 39 | x = self.re(x) 40 | noise = self.re._params['noise'].to(self.device) # [b, h, w, 2] 41 | disp_e = self.get_elastic_disp(noise) # [b, h, w, 2] 42 | # rebase 43 | # disp_e = disp_e.permute(0, 3, 1, 2) # [b, 2, h, w] 44 | # disp_e = self.re.apply_transform(disp_e, self.re._params) 45 | # disp_e = disp_e.permute(0, 2, 3, 1) # [b, h, w, 2] 46 | params |= {'de': disp_e} 47 | 48 | # perspective 49 | if 'p' in self.config['transforms']: 50 | # generate params 51 | self.rp(x) 52 | # fix end_points 53 | corner = self.rp._params['start_points'] 54 | self.rp._params['start_points'] = self.rp._params['end_points'] 55 | self.rp._params['end_points'] = corner 56 | # transform 57 | x = self.rp(x, params=self.rp._params) 58 | # calculate offset disp 59 | f, t = self.rp._params['start_points'].to(x), self.rp._params['end_points'].to(x) 60 | matrix = get_perspective_transform(t, f) # matrix end_points -> start_points 61 | disp_p = self.get_perspective_disp(matrix) # [b, h, w, 2] 62 | params |= {'dp': -disp_p} 63 | 64 | return x, params 65 | 66 | def get_perspective_disp(self, transform: Tensor) -> Tensor: 67 | # params 68 | B, _, H, W = self.size 69 | h_out, w_out = H, W 70 | 71 | # we normalize the 3x3 transformation matrix and convert to 3x4 72 | dst_norm_trans_src_norm = normalize_homography(transform, (H, W), (h_out, w_out)) # Bx3x3 73 | 74 | src_norm_trans_dst_norm = _torch_inverse_cast(dst_norm_trans_src_norm) # Bx3x3 75 | 76 | # 
this piece of code substitutes F.affine_grid since it does not support 3x3 77 | grid = create_meshgrid(h_out, w_out, normalized_coordinates=True, device=self.device).to(self.dtype) 78 | grid = grid.repeat(B, 1, 1, 1) 79 | disp = transform_points(src_norm_trans_dst_norm[:, None, None], grid) - grid # disp: infrared -> \bar{infrared} 80 | return disp 81 | 82 | def get_elastic_disp(self, noise: Tensor) -> Tensor: 83 | # params 84 | config = self.config 85 | ks, sigma = config['kernel_size'], config['sigma'] 86 | 87 | # Get Gaussian kernel for 'visible' and 'infrared' displacement 88 | kernel_x = get_gaussian_kernel2d(ks, sigma)[None] 89 | kernel_y = get_gaussian_kernel2d(ks, sigma)[None] 90 | 91 | # Convolve over a random displacement matrix and scale them with 'alpha' 92 | disp_x = noise[:, :1] 93 | disp_y = noise[:, 1:] 94 | 95 | disp_x = filter2d(disp_x, kernel=kernel_y, border_type="constant") 96 | disp_y = filter2d(disp_y, kernel=kernel_x, border_type="constant") 97 | 98 | # stack and normalize displacement 99 | disp = torch.cat([disp_x, disp_y], dim=1).permute(0, 2, 3, 1) # disp: infrared -> \bar{infrared} 100 | return disp 101 | -------------------------------------------------------------------------------- /modules/u_register.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | from modules.layers.u_net import UNet 5 | 6 | 7 | class URegister(nn.Module): 8 | def __init__(self, in_c: int, dim: int): 9 | super().__init__() 10 | 11 | # u-net core: [2] -> [16] 12 | self.unet = UNet(in_c=in_c, out_c=dim) 13 | 14 | # flow layer 15 | self.flow = nn.Conv2d(dim, 2, kernel_size=3, padding='same') 16 | 17 | # init flow layer 18 | self.flow.apply(self.init_weights) 19 | 20 | def forward(self, i_in: Tensor): 21 | # unet: [2] -> [d] 22 | f_x = self.unet(i_in) 23 | 24 | # flow: [b, d, h, w] -> [b, h, w, 2] 25 | flow = self.flow(f_x).permute(0, 2, 3, 1) 26 | 27 | # return middle vars during forward process 28 | return flow 29 | 30 | @staticmethod 31 | @torch.no_grad() 32 | def init_weights(m): 33 | if type(m) == nn.Conv2d: 34 | nn.init.normal_(m.weight, mean=0, std=1e-5) 35 | nn.init.zeros_(m.bias) 36 | -------------------------------------------------------------------------------- /scripts/pred.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | import pytorch_lightning as pl 5 | from pytorch_lightning import Callback 6 | from torch.utils.data import DataLoader 7 | from torchvision.utils import save_image 8 | 9 | from lightning.auto_rf import AutoRF 10 | from lightning.reco import ReCo 11 | 12 | 13 | class SaveFigure(Callback): 14 | def __init__(self, dst: str | Path): 15 | super().__init__() 16 | self.dst = Path(dst) 17 | self.dst.mkdir(parents=True, exist_ok=True) 18 | 19 | def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 20 | name, f = outputs 21 | save_image(f.squeeze(), self.dst / name[0]) 22 | 23 | 24 | def main(): 25 | # args parser 26 | parser = ArgumentParser() 27 | 28 | # program level args 29 | # lightning 30 | parser.add_argument('--ckpt', type=str, default='../weights/default-f.ckpt', help='checkpoint path') 31 | # auto rf 32 | parser.add_argument('--data', type=str, default='../data/tno', help='input data folder') 33 | parser.add_argument('--deform', type=str, default='none', help='random adjust level') 34 | # reco 35 | 
parser.add_argument('--dst', type=str, default='runs', help='output save folder') 36 | # cuda 37 | parser.add_argument('--no_cuda', action='store_true', help='disable cuda (for cpu and out of memory)') 38 | 39 | # model specific args 40 | parser = ReCo.add_model_specific_args(parser) 41 | 42 | # parse 43 | args = parser.parse_args() 44 | 45 | # fix seed 46 | pl.seed_everything(443) 47 | 48 | # model 49 | reco = ReCo(args) 50 | 51 | # dataloader 52 | dataset = AutoRF(root=args.data, mode='pred', level=args.deform) 53 | loader = DataLoader(dataset, collate_fn=AutoRF.collate_fn) 54 | 55 | # callbacks 56 | callbacks = [SaveFigure(dst=args.dst)] 57 | 58 | # lightning 59 | accelerator, devices, strategy = ('cpu', None, None) if args.no_cuda else ('gpu', -1, 'ddp') 60 | trainer = pl.Trainer(accelerator=accelerator, devices=devices, callbacks=callbacks, strategy=strategy) 61 | trainer.predict(model=reco, dataloaders=loader, ckpt_path=args.ckpt) 62 | 63 | 64 | if __name__ == '__main__': 65 | main() 66 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import wandb 6 | from pytorch_lightning import Callback 7 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 8 | from pytorch_lightning.loggers import WandbLogger 9 | from torch import Tensor 10 | from torch.utils.data import DataLoader 11 | 12 | from lightning.auto_rf import AutoRF 13 | from lightning.reco import ReCo 14 | from utils.pretty_vars import pretty_vars 15 | 16 | 17 | class LogImageCallback(Callback): 18 | def __init__(self, logger: WandbLogger, show_grad: bool = False): 19 | super().__init__() 20 | self.logger = logger 21 | self.show_grad = show_grad 22 | 23 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 24 | if batch_idx == 0: 25 | assert isinstance(outputs, Tensor) 26 | size = outputs.shape[1] 27 | imgs = [torch.clip(img[-1].squeeze(), min=0, max=1) for img in torch.chunk(outputs, chunks=size, dim=1)] 28 | captions = ['ir', 'vi', 'vi_t', 'vi_m', 'dif', 'f'] if size == 6 else ['ir', 'vi', 'f'] 29 | self.logger.log_image(key='sample', images=imgs, caption=captions) 30 | 31 | def on_before_zero_grad(self, trainer, pl_module, optimizer): 32 | if self.show_grad: 33 | print(pretty_vars(pl_module.register)) 34 | 35 | 36 | def main(): 37 | # args parser 38 | parser = ArgumentParser() 39 | 40 | # program level args 41 | # lightning 42 | parser.add_argument('--ckpt', type=str, default='checkpoints', help='checkpoints save folder') 43 | parser.add_argument('--show_grad', action='store_true', help='show grad before zero_grad') 44 | parser.add_argument('--seed', type=int, default=443, help='seed for random number') 45 | # wandb 46 | parser.add_argument('--key', type=str, help='wandb auth key') 47 | # auto rf 48 | parser.add_argument('--data', type=str, default='../data/tno', help='input data folder') 49 | parser.add_argument('--deform', type=str, default='none', help='random adjust level') 50 | # loader 51 | parser.add_argument('--bs', type=int, default=32, help='batch size') 52 | # cuda 53 | parser.add_argument('--no_cuda', action='store_true', help='disable cuda (for cpu and out of memory)') 54 | 55 | # model specific args 56 | parser = ReCo.add_model_specific_args(parser) 57 | 58 | # parse 59 | args = parser.parse_args() 60 | 61 | # fix 
seed 62 | pl.seed_everything(args.seed) 63 | 64 | # model 65 | reco = ReCo(args) 66 | 67 | # dataloader 68 | train_dataset = AutoRF(root=args.data, mode='train', level=args.deform) 69 | train_loader = DataLoader( 70 | train_dataset, batch_size=args.bs, 71 | shuffle=True, collate_fn=AutoRF.collate_fn, num_workers=72, 72 | ) 73 | val_dataset = AutoRF(root=args.data, mode='val', level=args.deform) 74 | val_loader = DataLoader( 75 | val_dataset, batch_size=1, 76 | collate_fn=AutoRF.collate_fn, num_workers=72, 77 | ) 78 | 79 | # logger 80 | wandb.login(key=args.key) 81 | logger = WandbLogger(project='reco') 82 | 83 | # callbacks 84 | callbacks = [ 85 | ModelCheckpoint(dirpath=args.ckpt, every_n_train_steps=10), 86 | LogImageCallback(logger=logger, show_grad=args.show_grad), 87 | LearningRateMonitor(logging_interval='step'), 88 | ] 89 | 90 | # lightning 91 | accelerator, devices, strategy = ('cpu', None, None) if args.no_cuda else ('gpu', -1, 'ddp') 92 | trainer = pl.Trainer( 93 | accelerator=accelerator, 94 | devices=devices, 95 | logger=logger, 96 | callbacks=callbacks, 97 | max_epochs=800, 98 | strategy=strategy, 99 | log_every_n_steps=5, 100 | ) 101 | trainer.fit(model=reco, train_dataloaders=train_loader, val_dataloaders=val_loader) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dlut-dimt/ReCoNet/0870f5f37852ace85850d9fd2d4570eb57682d22/utils/__init__.py -------------------------------------------------------------------------------- /utils/choose_images.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from pathlib import Path 3 | from typing import Literal 4 | 5 | import cv2 6 | import numpy 7 | 8 | 9 | def choose_images(root: str | Path, mode: str = Literal['train', 'val', 'pred']): 10 | root = Path(root) 11 | names = [x.name for x in sorted(root.glob('ir/*')) if x.suffix in ['.png', '.jpg', '.bmp']] 12 | save = [] 13 | for name in names: 14 | x = cv2.imread(str(root / 'ir' / name), cv2.IMREAD_GRAYSCALE) 15 | y = cv2.imread(str(root / 'vi' / name), cv2.IMREAD_GRAYSCALE) 16 | t = numpy.hstack([x, y]) 17 | cv2.imshow(name, t) 18 | if cv2.waitKey() == ord('s'): 19 | save.append(name.split('.')[0]) 20 | cv2.destroyWindow(name) 21 | meta = root / 'meta' 22 | meta.mkdir(parents=True, exist_ok=True) 23 | meta_f = meta / f'{mode}.txt' 24 | meta_f.write_text(reduce(lambda i, j: i + j, [t + '\n' for t in save])) 25 | 26 | 27 | if __name__ == '__main__': 28 | choose_images('data/road', mode='val') 29 | -------------------------------------------------------------------------------- /utils/pretty_vars.py: -------------------------------------------------------------------------------- 1 | from tabulate import tabulate 2 | from torch import nn 3 | 4 | 5 | def pretty_vars(module: nn.Module) -> str: 6 | table = [[n, p.mean(), p.grad.mean()] for n, p in module.named_parameters() if p.grad is not None] 7 | return tabulate(table, headers=['layer', 'weights', 'grad'], tablefmt='pretty') 8 | --------------------------------------------------------------------------------
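Since `utils/pretty_vars.py` closes the listing without a usage example, here is a minimal sketch (a hypothetical standalone script, not part of the repository) of how `pretty_vars` can be used to inspect per-layer weight and gradient statistics after a backward pass, mirroring how `scripts/train.py` calls it on the register inside `LogImageCallback.on_before_zero_grad`.

```python
# Hypothetical usage sketch for utils/pretty_vars.py (not part of the repository).
# Assumes the repository root is on PYTHONPATH, as set up by the README's export PYTHONPATH step.
import torch
from torch import nn

from utils.pretty_vars import pretty_vars

# throwaway model standing in for a register/fuser module
model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.GELU(), nn.Conv2d(8, 1, 3, padding=1))

# one forward/backward pass so that the parameters acquire gradients
x = torch.rand(2, 1, 32, 32)
loss = model(x).mean()
loss.backward()

# parameters only appear in the table once p.grad is not None
print(pretty_vars(model))
```

During training, passing `--show_grad` to `scripts/train.py` enables the same printout before each `zero_grad` call.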