├── model
│   ├── __init__.py
│   ├── loss.py
│   ├── builder.py
│   └── network.py
├── utils
│   ├── __init__.py
│   ├── torch_tps_upsample.py
│   ├── utils_op.py
│   ├── utils_transform.py
│   ├── flow_viz.py
│   ├── utils_module.py
│   └── torch_tps_transform.py
├── assets
│   └── teaser.jpg
├── scripts
│   ├── train.sh
│   ├── test.sh
│   └── test_portrait.sh
├── warmup_scheduler
│   ├── __init__.py
│   ├── run.py
│   └── scheduler.py
├── requirements.txt
├── LICENSE
├── Datasets
│   └── readme.md
├── README.md
├── dataset_loaders.py
├── test.py
├── test_portrait.py
└── train.py

/model/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KangLiao929/MOWA/HEAD/assets/teaser.jpg
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | python train.py -gpu '0, 1, 2, 3, 4, 5, 6, 7' -b 8 -m 'mowa' [--train_path=...]
--------------------------------------------------------------------------------
/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from warmup_scheduler.scheduler import GradualWarmupScheduler
3 | 
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | python test.py --gpu 0 --batch_size 1 --model_path '/checkpoint/' --method 'mowa' [--test_path=...]
--------------------------------------------------------------------------------
/scripts/test_portrait.sh:
--------------------------------------------------------------------------------
1 | python test_portrait.py --gpu 0 --batch_size 1 --model_path '/checkpoint/' --method 'mowa' [--test_path=...]
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==3.11.0 2 | imageio==2.33.1 3 | imgaug==0.4.0 4 | lpips==0.1.4 5 | Markdown==3.5.1 6 | matplotlib==3.7.4 7 | networkx==3.1 8 | numpy==1.24.4 9 | opencv-python==4.9.0.80 10 | opencv-python-headless==4.9.0.80 11 | packaging==23.2 12 | pandas==2.0.3 13 | pillow==10.2.0 14 | pytorch-fid==0.3.0 15 | pytorch-lightning==2.2.1 16 | PyYAML==6.0.1 17 | scikit-image==0.21.0 18 | scikit-learn==1.3.2 19 | scipy==1.10.1 20 | seaborn==0.13.1 21 | six==1.16.0 22 | tensorboard==2.14.0 23 | tensorboardX==2.6.2.2 24 | timm==0.9.12 25 | torch==2.1.2 26 | torchmetrics==1.3.2 27 | torchvision==0.16.2 28 | tqdm==4.66.1 29 | -------------------------------------------------------------------------------- /warmup_scheduler/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR 3 | from torch.optim.sgd import SGD 4 | 5 | from warmup_scheduler import GradualWarmupScheduler 6 | 7 | 8 | if __name__ == '__main__': 9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))] 10 | optim = SGD(model, 0.1) 11 | 12 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1) 13 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr) 14 | 15 | optim.zero_grad() 16 | optim.step() 17 | 18 | for epoch in range(1, 20): 19 | scheduler_warmup.step(epoch) 20 | print(epoch, optim.param_groups[0]['lr']) 21 | 22 | optim.step() 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2025 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. 36 | -------------------------------------------------------------------------------- /warmup_scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | 4 | 5 | class GradualWarmupScheduler(_LRScheduler): 6 | """ Gradually warm-up(increasing) learning rate in optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 12 | total_epoch: target learning rate is reached at total_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 17 | self.multiplier = multiplier 18 | if self.multiplier < 1.: 19 | raise ValueError('multiplier should be greater thant or equal to 1.') 20 | self.total_epoch = total_epoch 21 | self.after_scheduler = after_scheduler 22 | self.finished = False 23 | super(GradualWarmupScheduler, self).__init__(optimizer) 24 | 25 | def get_lr(self): 26 | if self.last_epoch > self.total_epoch: 27 | if self.after_scheduler: 28 | if not self.finished: 29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 30 | self.finished = True 31 | return self.after_scheduler.get_lr() 32 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 33 | 34 | if self.multiplier == 1.0: 35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 36 | else: 37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 38 | 39 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 40 | if epoch is None: 41 | epoch = self.last_epoch + 1 42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 43 | if self.last_epoch <= self.total_epoch: 44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in self.base_lrs] 45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 46 | param_group['lr'] = lr 47 | else: 48 | if epoch is None: 49 | self.after_scheduler.step(metrics, None) 50 | else: 51 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 52 | 53 | def step(self, epoch=None, metrics=None): 54 | if type(self.after_scheduler) != ReduceLROnPlateau: 55 | if self.finished and self.after_scheduler: 56 | if epoch is None: 57 | self.after_scheduler.step(None) 58 | else: 59 | self.after_scheduler.step(epoch - self.total_epoch) 60 | else: 61 | return super(GradualWarmupScheduler, self).step(epoch) 62 | else: 63 | self.step_ReduceLROnPlateau(metrics, epoch) -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def get_vgg19_FeatureMap(vgg_model, input_tensor, layer_index): 5 | mean = torch.tensor([0.485, 0.456, 0.406]).reshape((1,3,1,1)) 6 | std = torch.tensor([0.229, 0.224, 0.225]).reshape((1,3,1,1)) 7 | if torch.cuda.is_available(): 8 | mean = mean.cuda() 9 | std = std.cuda() 10 | 11 | vgg_input = (input_tensor - mean) / std 12 | for i in range(0, layer_index+1): 13 | if i == 0: 14 | x = vgg_model.module.features[0](vgg_input) 15 | else: 16 | x = vgg_model.module.features[i](x) 17 | return x 18 | 19 | def l_num_loss(img1, img2, l_num=1): 20 | return torch.mean(torch.abs((img1 - img2)**l_num)) 21 | 22 | def mask_flow_loss(flow, gt, task_ids, target_id=5): 23 | batch_size = flow.size(0) 24 | target_ids = torch.full((batch_size, 1), target_id, dtype=task_ids.dtype, device=task_ids.device) 25 | mask = task_ids == target_ids.squeeze(-1) 26 | flow_mask = flow * mask.view(batch_size, 1, 1, 1) 27 | gt_mask = gt * mask.view(batch_size, 1, 1, 1) 28 | return l_num_loss(flow_mask, gt_mask, 1) 29 | 30 | def cal_appearance_loss_sum(warp_list, gt, weights, eps=1e-7): 31 | num = len(warp_list) 32 | loss = [] 33 | for i in range(num): 34 | warp = warp_list[i] 35 | loss.append(l_num_loss(warp, gt, 1) + eps) 36 | 37 | return torch.sum(torch.stack([w * v for w, v in zip(weights, loss)])) 38 | 39 | def cal_perception_loss_sum(vgg_model, warp_list, gt, weights): 40 | num = len(warp_list) 41 | loss = [] 42 | for i in range(num): 43 | warp = warp_list[i] 44 | warp_feature = get_vgg19_FeatureMap(vgg_model, warp, 24) 45 | gt_feature = get_vgg19_FeatureMap(vgg_model, gt, 24) 46 | 47 | loss.append(l_num_loss(warp_feature, gt_feature, 2)) 48 | 49 | return torch.sum(torch.stack([w * v for w, v in zip(weights, loss)])) 50 | 51 | def cal_point_loss(pre, gt): 52 | criterion = nn.CrossEntropyLoss() 53 | return criterion(pre, gt.long().cuda()) 54 | 55 | def cal_inter_grid_loss_sum(mesh_list, tps_points, weights, eps=1e-8): 56 | num = len(mesh_list) 57 | loss = [] 58 | for i in range(num): 59 | mesh = mesh_list[i] 60 | grid_w = tps_points[i]-1 61 | grid_h = tps_points[i]-1 62 | w_edges = mesh[:,:,0:grid_w,:] - mesh[:,:,1:grid_w+1,:] 63 | 64 | w_norm1 = torch.sqrt(torch.sum(w_edges[:,:,0:grid_w-1,:] * w_edges[:,:,0:grid_w-1,:], 3) + eps) 65 | w_norm2 = torch.sqrt(torch.sum(w_edges[:,:,1:grid_w,:] * w_edges[:,:,1:grid_w,:], 3) + eps) 66 | cos_w = torch.sum(w_edges[:,:,0:grid_w-1,:] * w_edges[:,:,1:grid_w,:], 3) / (w_norm1 * w_norm2) 67 | delta_w_angle = 1 - cos_w 68 | 69 | h_edges = mesh[:,0:grid_h,:,:] - mesh[:,1:grid_h+1,:,:] 70 | 71 | h_norm1 = torch.sqrt(torch.sum(h_edges[:,0:grid_h-1,:,:] * 
h_edges[:,0:grid_h-1,:,:], 3) + eps) 72 | h_norm2 = torch.sqrt(torch.sum(h_edges[:,1:grid_h,:,:] * h_edges[:,1:grid_h,:,:], 3) + eps) 73 | cos_h = torch.sum(h_edges[:,0:grid_h-1,:,:] * h_edges[:,1:grid_h,:,:], 3) / (h_norm1 * h_norm2) 74 | delta_h_angle = 1 - cos_h 75 | 76 | loss.append(torch.mean(delta_w_angle) + torch.mean(delta_h_angle)) 77 | 78 | return torch.sum(torch.stack([w * v for w, v in zip(weights, loss)])) 79 | -------------------------------------------------------------------------------- /utils/torch_tps_upsample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def transformer(source, target, out_size): 5 | """ 6 | Thin Plate Spline Spatial Transformer Layer 7 | convert the TPS deformation into optical flows 8 | TPS control points are arranged in arbitrary positions given by `source`. 9 | source : float Tensor [num_batch, num_point, 2] 10 | The source position of the control points. 11 | target : float Tensor [num_batch, num_point, 2] 12 | The target position of the control points. 13 | out_size: tuple of two integers [height, width] 14 | The size of the output of the network (height, width) 15 | """ 16 | 17 | def _meshgrid(height, width, source): 18 | 19 | x_t = torch.matmul(torch.ones([height, 1]), torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 0)) 20 | y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1), torch.ones([1, width])) 21 | if torch.cuda.is_available(): 22 | x_t = x_t.cuda() 23 | y_t = y_t.cuda() 24 | 25 | x_t_flat = x_t.reshape([1, 1, -1]) 26 | y_t_flat = y_t.reshape([1, 1, -1]) 27 | 28 | num_batch = source.size()[0] 29 | px = torch.unsqueeze(source[:,:,0], 2) 30 | py = torch.unsqueeze(source[:,:,1], 2) 31 | if torch.cuda.is_available(): 32 | px = px.cuda() 33 | py = py.cuda() 34 | d2 = torch.square(x_t_flat - px) + torch.square(y_t_flat - py) 35 | r = d2 * torch.log(d2 + 1e-9) 36 | x_t_flat_g = x_t_flat.expand(num_batch, -1, -1) 37 | y_t_flat_g = y_t_flat.expand(num_batch, -1, -1) 38 | ones = torch.ones_like(x_t_flat_g) 39 | if torch.cuda.is_available(): 40 | ones = ones.cuda() 41 | 42 | grid = torch.cat((ones, x_t_flat_g, y_t_flat_g, r), 1) 43 | 44 | return grid 45 | 46 | def _transform(T, source, out_size): 47 | num_batch, *_ = T.size() 48 | 49 | out_height, out_width = out_size[0], out_size[1] 50 | grid = _meshgrid(out_height, out_width, source) 51 | 52 | T_g = torch.matmul(T, grid) 53 | x_s = T_g[:,0,:] 54 | y_s = T_g[:,1,:] 55 | 56 | flow_x = (x_s - grid[:,1,:])*(out_width/2) 57 | flow_y = (y_s - grid[:,2,:])*(out_height/2) 58 | 59 | flow = torch.stack([flow_x, flow_y], 1) 60 | flow = flow.reshape([num_batch, 2, out_height, out_width]) 61 | 62 | return flow 63 | 64 | 65 | def _solve_system(source, target): 66 | num_batch = source.size()[0] 67 | num_point = source.size()[1] 68 | 69 | np.set_printoptions(precision=8) 70 | 71 | ones = torch.ones(num_batch, num_point, 1).float() 72 | if torch.cuda.is_available(): 73 | ones = ones.cuda() 74 | p = torch.cat([ones, source], 2) 75 | 76 | p_1 = p.reshape([num_batch, -1, 1, 3]) 77 | p_2 = p.reshape([num_batch, 1, -1, 3]) 78 | d2 = torch.sum(torch.square(p_1-p_2), 3) 79 | r = d2 * torch.log(d2 + 1e-9) 80 | 81 | zeros = torch.zeros(num_batch, 3, 3).float() 82 | if torch.cuda.is_available(): 83 | zeros = zeros.cuda() 84 | W_0 = torch.cat((p, r), 2) 85 | W_1 = torch.cat((zeros, p.permute(0,2,1)), 2) 86 | W = torch.cat((W_0, W_1), 1) 87 | W_inv = torch.inverse(W.type(torch.float64)) 88 | 89 | zeros2 = 
torch.zeros(num_batch, 3, 2) 90 | if torch.cuda.is_available(): 91 | zeros2 = zeros2.cuda() 92 | tp = torch.cat((target, zeros2), 1) 93 | T = torch.matmul(W_inv, tp.type(torch.float64)) 94 | T = T.permute(0, 2, 1) 95 | 96 | return T.type(torch.float32) 97 | 98 | T = _solve_system(source, target) 99 | output = _transform(T, source, out_size) 100 | 101 | return output -------------------------------------------------------------------------------- /Datasets/readme.md: -------------------------------------------------------------------------------- 1 | # Guidance for the Practical Image Warping Datasets 2 | For the convenience of training/testing various image warping tasks in one project, we cleaned and arranged the mainstream warping datasets with unified structures and more visual assistance. Note that the ground truth of warping flow is only available in training dataset of the portrait correction task. For other tasks, we provide the pseudo flows (obtained by [RAFT](https://github.com/princeton-vl/RAFT)) for visualization. The download link and source of each task are shown as follows. 3 | 4 | 5 | | # | Type | Training Set | Testing Set | Source | 6 | |:--:|:---------------------:|:------------------------------------------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:| 7 | | 1 | Stitched Image | [Google Drive](https://drive.google.com/file/d/1vWaD3bbUd6TZ_wQlehBxVoYlwBxZdAwu/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1eiMwbJMCERbgZUuf6hQOhQ?pwd=fpa8) | [Google Drive](https://drive.google.com/file/d/1Ldqn4Q-mrrRibGNdO9t8cUVpH0YZWK7i/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1MOCOUUJGfqBss76nw-dDiw?pwd=d4du) | [Website](https://github.com/nie-lang/DeepRectangling) | 8 | | 2 | Rectified Wide-Angle Image | [Google Drive](https://drive.google.com/file/d/1Cxv97NybP5t8aBekm7XNvBP4i2CAhxUz/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1J7LeYenzbG4fKxMHePGNpg?pwd=cmc1) | [Google Drive](https://drive.google.com/file/d/1WzXSFQoLuqeAlAXR4gO5eBZhUJgz-aN_/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1dFlPT-iB-F-382ZSQx5z-Q?pwd=wdu8) | [Website](https://github.com/KangLiao929/RecRecNet) | 9 | | 3 | Unrolling Shutter Image | [Google Drive](https://drive.google.com/file/d/1r3B3BAmZjmy5PSzQZllIxE8mM6bSPFNj/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1dHxI8ILMg84dUE6icP3pdg?pwd=cqu3) | [Google Drive](https://drive.google.com/file/d/18Qtj_sp2cDWM3OZv0UYjf-5qpOBzv4QN/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1KnNs-S0BOpwvMer3epnndg?pwd=yvxw) | Self-constructed | 10 | | 4 | Rotated Image | [Google Drive](https://drive.google.com/file/d/1WJAvf5sG3JyLeVOi9ZrSBOcOqdnoURGG/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1ambt8n0Bc9tkMD6jdbagSg?pwd=2nkw) | [Google Drive](https://drive.google.com/file/d/1HdQdK6enbOebASL4Wie3kzkiQ9SUNmBH/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1Tr3SvOnKXZ4JfE9KYVmNfw?pwd=3ea8) | [Website](https://github.com/nie-lang/RotationCorrection) | 11 | | 5 | Fisheye Image | [Google Drive](https://drive.google.com/file/d/1yn_hlVyFRIt3yTBPsDNKKfx9U9j7vY01/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1tVbeSywPyceyeCIdhbGSAQ?pwd=1v6a) | [Google Drive](https://drive.google.com/file/d/1X1851eEB0gvzEJrOOw7uQdKDJ2MQZAaG/view?usp=sharing), 
[Baidu Netdisk](https://pan.baidu.com/s/1TwURM0jjLcnjm5v4zdtjCg?pwd=7w56) | [Website](https://github.com/uof1745-cmd/PCN) | 12 | | 6 | Portrait Photo | [Google Drive](https://drive.google.com/file/d/1Ng_dx2Y4v8Qjv4xVtWA6f2il3SklBIu3/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1AJi1X3wkgFiVrO7GhvSrNw?pwd=yidy) | [Google Drive](https://drive.google.com/file/d/18yY7O7Yi6TygnRG73eCTxkVtgMk5qY9k/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1FVCK57a7y8PoHfMDXxVmmA?pwd=ej81) | [Website](https://github.com/megvii-research/Portraits_Correction) | 13 | -------------------------------------------------------------------------------- /utils/utils_op.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | import cv2 6 | from imgaug import augmenters as iaa 7 | import os 8 | 9 | def set_gpu(args, distributed=False, rank=0): 10 | """ set parameter to gpu or ddp """ 11 | if args is None: 12 | return None 13 | if distributed and isinstance(args, torch.nn.Module): 14 | return DDP(args.cuda(), device_ids=[rank], output_device=rank, broadcast_buffers=True, find_unused_parameters=True) 15 | else: 16 | return args.cuda() 17 | 18 | def set_device(args, distributed=False, rank=0): 19 | """ set parameter to gpu or cpu """ 20 | if torch.cuda.is_available(): 21 | if isinstance(args, list): 22 | return (set_gpu(item, distributed, rank) for item in args) 23 | elif isinstance(args, dict): 24 | return {key:set_gpu(args[key], distributed, rank) for key in args} 25 | else: 26 | args = set_gpu(args, distributed, rank) 27 | return args 28 | 29 | def count_files(directory_path): 30 | for subdir, _, _ in os.walk(directory_path): 31 | if subdir != directory_path: 32 | return len([file for _, _, files in os.walk(subdir) for file in files]) 33 | return 0 34 | 35 | def draw_mesh_on_warp(warp, f_local, grid_h, grid_w): 36 | height = warp.shape[0] 37 | width = warp.shape[1] 38 | 39 | min_w = np.minimum(np.min(f_local[:,:,0]), 0).astype(np.int32) 40 | max_w = np.maximum(np.max(f_local[:,:,0]), width).astype(np.int32) 41 | min_h = np.minimum(np.min(f_local[:,:,1]), 0).astype(np.int32) 42 | max_h = np.maximum(np.max(f_local[:,:,1]), height).astype(np.int32) 43 | cw = max_w - min_w 44 | ch = max_h - min_h 45 | 46 | pic = np.ones([ch+10, cw+10, 3], np.int32)*255 47 | pic[0-min_h+5:0-min_h+height+5, 0-min_w+5:0-min_w+width+5, :] = warp 48 | 49 | warp = pic 50 | f_local[:,:,0] = f_local[:,:,0] - min_w+5 51 | f_local[:,:,1] = f_local[:,:,1] - min_h+5 52 | 53 | point_color = (0, 255, 0) 54 | thickness = 2 55 | lineType = 8 56 | num = 1 57 | for i in range(grid_h+1): 58 | for j in range(grid_w+1): 59 | num = num + 1 60 | if j == grid_w and i == grid_h: 61 | continue 62 | elif j == grid_w: 63 | cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType) 64 | elif i == grid_h: 65 | cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType) 66 | else : 67 | cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType) 68 | cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType) 69 | 70 | return warp 71 | 72 | def data_aug(img, gt): 73 | 
oplist = [] 74 | if random.random() > 0.5: 75 | oplist.append(iaa.GaussianBlur(sigma=(0.0, 1.0))) 76 | elif random.random() > 0.5: 77 | oplist.append(iaa.WithChannels(0, iaa.Add((1, 15)))) 78 | elif random.random() > 0.5: 79 | oplist.append(iaa.WithChannels(1, iaa.Add((1, 15)))) 80 | elif random.random() > 0.5: 81 | oplist.append(iaa.WithChannels(2, iaa.Add((1, 15)))) 82 | elif random.random() > 0.5: 83 | oplist.append(iaa.AdditiveGaussianNoise(scale=(0, 10))) 84 | elif random.random() > 0.5: 85 | oplist.append(iaa.Sharpen(alpha=0.15)) 86 | 87 | seq = iaa.Sequential(oplist) 88 | images_aug = seq.augment_images([img]) 89 | gt_aug = seq.augment_images([gt]) 90 | return images_aug[0], gt_aug[0] 91 | 92 | def get_weight_mask(mask, gt, pred, weight=10): 93 | mask = (mask * (weight - 1)) + 1 94 | gt = gt.mul(mask) 95 | pred = pred.mul(mask) 96 | return gt, pred 97 | 98 | def adjust_weight(epoch, total_epoch, weight): 99 | return (1 - 0.9 * (epoch / total_epoch)) * weight 100 | 101 | def flow2list(flow): 102 | h, w, c = flow.shape 103 | dirs_grid = [] 104 | for i in range(h): 105 | dirs_row = [] 106 | for j in range(w): 107 | dx, dy = flow[i, j, :] 108 | dx = np.round(dx) 109 | dy = np.round(dy) 110 | dirs_row.append([-int(dx), -int(dy)]) 111 | dirs_grid.append(dirs_row) 112 | return dirs_grid -------------------------------------------------------------------------------- /model/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils.torch_tps_upsample as torch_tps_upsample 3 | from utils.utils_transform import * 4 | import torch.nn.functional as F 5 | 6 | def build_model(net, input_tensor1, input_tensor2, mask_tensor, tps_points): 7 | """ 8 | input_tensor1: source image with original resolution 9 | input_tensor2: resized image with fixed resolution (256x256) 10 | .. 
input_tensor1 == input_tensor2 in training 11 | """ 12 | batch_size, _, img_h, img_w = input_tensor1.size() 13 | batch_size, _, input_size, input_size = input_tensor2.size() 14 | 15 | offset, flow, point_cls = net(input_tensor2, mask_tensor) 16 | head_num = len(offset) 17 | norm_rigid_mesh_list = [] 18 | norm_ori_mesh_list = [] 19 | output_tps_list = [] 20 | ori_mesh_list = [] 21 | tps2flow_list = [] 22 | 23 | for i in range(head_num): 24 | mesh_motion = offset[i].reshape(-1, tps_points[i], tps_points[i], 2) 25 | rigid_mesh = get_rigid_mesh(batch_size, input_size, input_size, tps_points[i]-1, tps_points[i]-1) 26 | ori_mesh = rigid_mesh + mesh_motion 27 | clamped_x = torch.clamp(ori_mesh[..., 0], min=0, max=input_size - 1) 28 | clamped_y = torch.clamp(ori_mesh[..., 1], min=0, max=input_size - 1) 29 | ori_mesh = torch.stack((clamped_x, clamped_y), dim=-1) 30 | 31 | norm_rigid_mesh = get_norm_mesh(rigid_mesh, input_size, input_size) 32 | norm_ori_mesh = get_norm_mesh(ori_mesh, input_size, input_size) 33 | tps2flow = torch_tps_upsample.transformer(norm_rigid_mesh, norm_ori_mesh, (img_h, img_w)) 34 | output_tps = resample_image(input_tensor1, tps2flow) 35 | 36 | norm_rigid_mesh_list.append(norm_rigid_mesh) 37 | norm_ori_mesh_list.append(norm_ori_mesh) 38 | output_tps_list.append(output_tps) 39 | tps2flow_list.append(tps2flow) 40 | ori_mesh_list.append(ori_mesh) 41 | 42 | tps_flow = tps2flow_list[-1] 43 | final_flow = flow + tps_flow 44 | output_flow = resample_image(output_tps_list[-1], flow) 45 | 46 | out_dict = {} 47 | out_dict.update(warp_tps=output_tps_list, warp_flow=output_flow, mesh=ori_mesh_list, 48 | flow1=flow, flow2=tps_flow, flow3=final_flow, point_cls=point_cls) 49 | return out_dict 50 | 51 | 52 | def build_model_test(net, input_tensor1, input_tensor2, mask_tensor, tps_points, resize_flow=False): 53 | """ 54 | input_tensor1: source image with original resolution 55 | input_tensor2: resized image with fixed resolution (256x256) 56 | .. 
input_tensor1 = input_tensor2 in training 57 | """ 58 | batch_size, _, img_h, img_w = input_tensor1.size() 59 | batch_size, _, input_size, input_size = input_tensor2.size() 60 | 61 | offset, flow, point_cls = net(input_tensor2, mask_tensor) 62 | head_num = len(offset) 63 | norm_rigid_mesh_list = [] 64 | norm_ori_mesh_list = [] 65 | output_tps_list = [] 66 | ori_mesh_list = [] 67 | tps2flow_list = [] 68 | for i in range(head_num): 69 | mesh_motion = offset[i].reshape(-1, tps_points[i], tps_points[i], 2) 70 | rigid_mesh = get_rigid_mesh(batch_size, input_size, input_size, tps_points[i]-1, tps_points[i]-1) 71 | ori_mesh = rigid_mesh + mesh_motion 72 | clamped_x = torch.clamp(ori_mesh[..., 0], min=0, max=input_size - 1) 73 | clamped_y = torch.clamp(ori_mesh[..., 1], min=0, max=input_size - 1) 74 | ori_mesh = torch.stack((clamped_x, clamped_y), dim=-1) 75 | 76 | norm_rigid_mesh = get_norm_mesh(rigid_mesh, input_size, input_size) 77 | norm_ori_mesh = get_norm_mesh(ori_mesh, input_size, input_size) 78 | tps2flow = torch_tps_upsample.transformer(norm_rigid_mesh, norm_ori_mesh, (img_h, img_w)) 79 | output_tps = resample_image_xy(input_tensor1, tps2flow) 80 | norm_rigid_mesh_list.append(norm_rigid_mesh) 81 | norm_ori_mesh_list.append(norm_ori_mesh) 82 | output_tps_list.append(output_tps) 83 | ori_mesh_list.append(ori_mesh) 84 | tps2flow_list.append(tps2flow) 85 | 86 | tps_flow = tps2flow_list[-1] 87 | if(resize_flow): 88 | flow = F.interpolate(flow, size=(img_h, img_w), mode='bilinear', align_corners=True) 89 | scale_H, scale_W = img_h / input_size, img_w / input_size 90 | flow[:, 0, :, :] *= scale_W 91 | flow[:, 1, :, :] *= scale_H 92 | 93 | final_flow = flow + tps_flow 94 | output_flow = resample_image_xy(output_tps_list[-1], flow) 95 | out_dict = {} 96 | out_dict.update(warp_tps=output_tps_list, warp_flow=output_flow, mesh=ori_mesh_list, 97 | flow1=flow, flow2=tps_flow, flow3=final_flow, point_cls=point_cls) 98 | return out_dict 99 | -------------------------------------------------------------------------------- /utils/utils_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import utils.torch_tps_transform as torch_tps_transform 5 | import utils.torch_tps_upsample as torch_tps_upsample 6 | 7 | def get_rigid_mesh(batch_size, height, width, grid_w, grid_h): 8 | ww = torch.matmul(torch.ones([grid_h+1, 1]), torch.unsqueeze(torch.linspace(0., float(width), grid_w+1), 0)) 9 | hh = torch.matmul(torch.unsqueeze(torch.linspace(0.0, float(height), grid_h+1), 1), torch.ones([1, grid_w+1])) 10 | if torch.cuda.is_available(): 11 | ww = ww.cuda() 12 | hh = hh.cuda() 13 | 14 | ori_pt = torch.cat((ww.unsqueeze(2), hh.unsqueeze(2)),2) 15 | ori_pt = ori_pt.unsqueeze(0).expand(batch_size, -1, -1, -1) 16 | 17 | return ori_pt 18 | 19 | def get_norm_mesh(mesh, height, width): 20 | batch_size = mesh.size()[0] 21 | mesh_w = mesh[...,0]*2./float(width) - 1. 22 | mesh_h = mesh[...,1]*2./float(height) - 1. 
23 | norm_mesh = torch.stack([mesh_w, mesh_h], 3) 24 | 25 | return norm_mesh.reshape([batch_size, -1, 2]) 26 | 27 | def transform_tps_fea(offset, input_tensor, grid_w, grid_h, dim, h, w): 28 | input_tensor = input_tensor.permute(0,2,1).view(-1, dim, h, w) 29 | batch_size, _, img_h, img_w = input_tensor.size() 30 | 31 | mesh_motion = offset.reshape(-1, grid_h+1, grid_w+1, 2) 32 | 33 | rigid_mesh = get_rigid_mesh(batch_size, img_h, img_w, grid_w, grid_h) 34 | ori_mesh = rigid_mesh + mesh_motion 35 | 36 | clamped_x = torch.clamp(ori_mesh[..., 0], min=0, max=img_h - 1) 37 | clamped_y = torch.clamp(ori_mesh[..., 1], min=0, max=img_w - 1) 38 | ori_mesh = torch.stack((clamped_x, clamped_y), dim=-1) 39 | 40 | norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w) 41 | norm_ori_mesh = get_norm_mesh(ori_mesh, img_h, img_w) 42 | 43 | output_tps = torch_tps_transform.transformer(input_tensor, norm_rigid_mesh, norm_ori_mesh, (img_h, img_w)) 44 | output_tps = output_tps.view(-1, dim, h*w).permute(0,2,1) 45 | 46 | return output_tps 47 | 48 | def upsample_tps(offset, grid_w, grid_h, out_h, out_w): 49 | if(grid_w+1 == out_w): 50 | return offset 51 | 52 | else: 53 | batch_size, *_ = offset.size() 54 | mesh_motion = offset.reshape(-1, grid_h+1, grid_w+1, 2) 55 | 56 | rigid_mesh = get_rigid_mesh(batch_size, out_h, out_w, grid_w, grid_h) 57 | ori_mesh = rigid_mesh + mesh_motion 58 | 59 | norm_rigid_mesh = get_norm_mesh(rigid_mesh, out_h, out_w) 60 | norm_ori_mesh = get_norm_mesh(ori_mesh, out_h, out_w) 61 | 62 | up_points = torch_tps_upsample.transformer(norm_rigid_mesh, norm_ori_mesh, (out_h, out_w)) 63 | out = up_points.permute(0, 2, 3, 1).view(-1, out_h*out_w, 2) 64 | 65 | return out 66 | 67 | def get_coordinate(shape, det_uv): 68 | b, _, w, h = shape 69 | uv_d = np.zeros([w, h, 2], np.float32) 70 | 71 | for i in range(0, w): 72 | for j in range(0, h): 73 | uv_d[i, j, 0] = j 74 | uv_d[i, j, 1] = i 75 | 76 | uv_d = np.expand_dims(uv_d.swapaxes(2, 1).swapaxes(1, 0), 0) 77 | uv_d = torch.from_numpy(uv_d).cuda() 78 | uv_d = uv_d.repeat(b, 1, 1, 1) 79 | 80 | det_uv = uv_d + det_uv 81 | return det_uv 82 | 83 | def uniform(shape, img_uv): 84 | b, _, w, h = shape 85 | x0 = (w - 1) / 2. 86 | 87 | img_nor = (img_uv - x0)/x0 88 | img_nor = img_nor.permute(0, 2, 3, 1) 89 | return img_nor 90 | 91 | def resample_image(feature, flow): 92 | img_uv = get_coordinate(feature.shape, flow) 93 | grid = uniform(feature.shape, img_uv) 94 | target_image = F.grid_sample(feature, grid) 95 | return target_image 96 | 97 | def get_coordinate_xy(shape, det_uv): 98 | b, _, h, w = shape 99 | uv_d = np.zeros([h, w, 2], np.float32) 100 | 101 | for j in range(0, h): 102 | for i in range(0, w): 103 | uv_d[j, i, 0] = i 104 | uv_d[j, i, 1] = j 105 | 106 | uv_d = np.expand_dims(uv_d.swapaxes(2, 1).swapaxes(1, 0), 0) 107 | uv_d = torch.from_numpy(uv_d).cuda() 108 | uv_d = uv_d.repeat(b, 1, 1, 1) 109 | det_uv = uv_d + det_uv 110 | return det_uv 111 | 112 | def uniform_xy(shape, uv): 113 | b, _, h, w = shape 114 | y0 = (h - 1) / 2. 115 | x0 = (w - 1) / 2. 
116 | 117 | nor = uv.clone() 118 | nor[:, 0, :, :] = (uv[:, 0, :, :] - x0) / x0 119 | nor[:, 1, :, :] = (uv[:, 1, :, :] - y0) / y0 120 | nor = nor.permute(0, 2, 3, 1) # b w h 2 121 | 122 | return nor 123 | 124 | def resample_image_xy(feature, flow): 125 | uv = get_coordinate_xy(feature.shape, flow) 126 | grid = uniform_xy(feature.shape, uv) 127 | target_image = F.grid_sample(feature, grid) 128 | return target_image -------------------------------------------------------------------------------- /utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /utils/utils_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class AddCoords(nn.Module): 7 | def __init__(self): 8 | super(AddCoords, self).__init__() 9 | 10 | def forward(self, input_tensor): 11 | batch_size, _, height, width = input_tensor.size() 12 | 13 | xx_channel = torch.arange(width).repeat(1, height, 1) 14 | yy_channel = torch.arange(height).repeat(1, width, 1).transpose(1, 2) 15 | 16 | xx_channel = xx_channel.float() / (width - 1) 17 | yy_channel = yy_channel.float() / (height - 1) 18 | 19 | xx_channel = xx_channel * 2 - 1 20 | yy_channel = yy_channel * 2 - 1 21 | 22 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1) 23 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1) 24 | 25 | input_tensor = torch.cat([ 26 | input_tensor, 27 | xx_channel.type_as(input_tensor), 28 | yy_channel.type_as(input_tensor)], dim=1) 29 | 30 | return input_tensor 31 | 32 | class CoordConv(nn.Module): 33 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 34 | super(CoordConv, self).__init__() 35 | self.addcoords = AddCoords() 36 | self.conv = nn.Conv2d(in_channels+2, out_channels, kernel_size, stride, padding) 37 | 38 | def forward(self, x): 39 | x = self.addcoords(x) 40 | x = self.conv(x) 41 | return x 42 | 43 | class MotionNet_Coord(nn.Module): 44 | def __init__(self, in_channel, out_channel, num): 45 | super(MotionNet_Coord, self).__init__() 46 | if(num==8): 47 | self.conv = nn.Sequential( 48 | CoordConv(in_channel, 
out_channel, kernel_size=3, stride=2, padding=1), 49 | ) 50 | if(num==10): 51 | self.conv = nn.Sequential( 52 | CoordConv(in_channel, 64, kernel_size=5, stride=1, padding=0), 53 | CoordConv(64, out_channel, kernel_size=3, stride=1, padding=0), 54 | ) 55 | if(num==12): 56 | self.conv = nn.Sequential( 57 | CoordConv(in_channel, out_channel, kernel_size=5, stride=1, padding=0), 58 | ) 59 | if(num==14): 60 | self.conv = nn.Sequential( 61 | CoordConv(in_channel, out_channel, kernel_size=3, stride=1, padding=0), 62 | ) 63 | if(num==16): 64 | self.conv = nn.Sequential( 65 | CoordConv(in_channel, out_channel, kernel_size=3, stride=1, padding=1), 66 | ) 67 | 68 | def forward(self, x): 69 | B, L, C = x.shape 70 | H = int(math.sqrt(L)) 71 | W = int(math.sqrt(L)) 72 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 73 | out = self.conv(x).flatten(2).transpose(1, 2).contiguous() 74 | return out 75 | 76 | class PointNet(nn.Module): 77 | def __init__(self, num_classes=4, grid_h=12, grid_w=12): 78 | super(PointNet, self).__init__() 79 | self.conv1 = nn.Conv1d(8, 256, 1) 80 | self.conv2 = nn.Conv1d(256, 256, 1) 81 | self.conv3 = nn.Conv1d(256, 512, 1) 82 | self.fc1 = nn.Linear(512, 512) 83 | self.fc2 = nn.Linear(512, 256) 84 | self.fc3 = nn.Linear(256, num_classes) 85 | self.dropout1 = nn.Dropout(p=0.3) 86 | self.dropout2 = nn.Dropout(p=0.3) 87 | self.h = grid_h 88 | self.w = grid_w 89 | self.fc_fea = nn.Linear(16*16, 6) 90 | 91 | def forward(self, pre, fea): 92 | '''pre-processing the predicted points''' 93 | x = pre.reshape(-1, 2, (self.h*self.w)) 94 | 95 | '''pre-processing the prompt features and form the superpoint''' 96 | fea = nn.MaxPool1d(fea.size(-1))(fea).squeeze(-1) 97 | fea = F.relu(self.fc_fea(fea)) 98 | fea = fea.unsqueeze(-1).repeat(1, 1, self.h*self.w) 99 | superpoint = torch.cat((x, fea), dim=1) 100 | 101 | '''learn the superpoints' features''' 102 | x = F.relu(self.conv1(superpoint)) 103 | x = F.relu(self.conv2(x)) 104 | x = F.relu(self.conv3(x)) 105 | x = nn.MaxPool1d(x.size(-1))(x) 106 | x = x.view(-1, 512) 107 | 108 | '''classification''' 109 | x = F.relu(self.fc1(x)) 110 | x = self.dropout1(x) 111 | x = F.relu(self.fc2(x)) 112 | x = self.dropout2(x) 113 | x = self.fc3(x) 114 | return x 115 | 116 | class PromptGenBlock(nn.Module): 117 | def __init__(self,prompt_dim=128,prompt_len=5,prompt_size = 96,lin_dim = 192): 118 | super(PromptGenBlock,self).__init__() 119 | self.prompt_param = nn.Parameter(torch.rand(1,prompt_len,prompt_dim,prompt_size,prompt_size)) 120 | self.linear_layer = nn.Linear(lin_dim,prompt_len) 121 | self.conv3x3 = nn.Conv2d(prompt_dim,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False) 122 | 123 | def forward(self,x): 124 | B, C, H, W = x.shape 125 | emb = x.mean(dim=(-2,-1)) 126 | prompt_weights = F.softmax(self.linear_layer(emb),dim=1) 127 | prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B,1,1,1,1,1).squeeze(1) 128 | prompt = torch.sum(prompt,dim=1) 129 | prompt = F.interpolate(prompt,(H,W),mode="bilinear") 130 | prompt = self.conv3x3(prompt) 131 | return prompt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MOWA: Multiple-in-One Image Warping Model 2 | 3 | ## Introduction 4 | This is the official implementation for [MOWA](https://arxiv.org/abs/2404.10716) (TPAMI 2025). 
5 | 6 | [Kang Liao](https://kangliao929.github.io/), [Zongsheng Yue](https://zsyoaoa.github.io/), [Zhonghua Wu](https://wu-zhonghua.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) 7 | 8 | S-Lab, Nanyang Technological University 9 | 10 | 11 |
12 | 13 |
14 | 15 | > ### Why MOWA? 16 | > MOWA is a practical multiple-in-one image warping framework, particularly in computational photography, where six distinct tasks are considered. Compared to previous works tailored to specific tasks, our method can solve various warping tasks from different camera models or manipulation spaces in a single framework. It also demonstrates an ability to generalize to novel scenarios, as evidenced in both cross-domain and zero-shot evaluations. 17 | > ### Features 18 | > * The first practical multiple-in-one image warping framework especially in the field of computational photography. 19 | > * We propose to mitigate the difficulty of multi-task learning by decoupling the motion estimation in both the region level and pixel level. 20 | > * A prompt learning module, guided by a lightweight point-based classifier, is designed to facilitate task-aware image warpings. 21 | > * We show that through multi-task learning, our framework develops a robust generalized warping strategy that gains improved performance across various tasks and even generalizes to unseen tasks. 22 | 23 | Check out more visual results and interactions [here](https://kangliao929.github.io/projects/mowa/). 24 | 25 | ## 📢 News 26 | - Our recent work **Puffin** can unify the camera-centric understanding (camera calibration, pose estimation) and generation (camera-controllable T2I and I2I generation) within a cohesive multimodal framework. It enables more precise understanding and generation performance by our proposed *thinking with camera*, and provides insights on the meaningful mutual effect among multimodal tasks. If you are interested in the camera-related 3D vision, photography, embodied AI, and spatial intelligence, please check out more details [here](https://kangliao929.github.io/projects/puffin/). 27 | - MOWA has been included in [AI Art Weekly #80](https://aiartweekly.com/issues/80). 28 | 29 | ## 📝 Changelog 30 | 31 | - [x] 2024.04.16: The paper of the arXiv version is online. 32 | - [x] 2024.06.30: Release the code and pre-trained model. 33 | - [x] 2025.05.01: MOWA has been accepted to IEEE TPAMI. 34 | - [ ] Release a demo for users to try MOWA online. 35 | - [ ] Release an interactive interface to drag the control points and perform customized warpings. 36 | 37 | ## Installation 38 | Using the virtual environment (conda) to run the code is recommended. 39 | ``` 40 | conda create -n mowa python=3.8.13 41 | conda activate mowa 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | ## Dataset 46 | We mainly explored six representative image warping tasks in this work. The datasets are derived/constructed from previous works. For the convenience of training and testing in one project, we cleaned and arranged these six types of datasets with unified structures and more visual assistance. Please refer to the category and download links in [Datasets](https://github.com/KangLiao929/MOWA/tree/main/Datasets). 47 | 48 | ## Pretrained Model 49 | Download the pretrained model from [Google Drive](https://drive.google.com/file/d/1fxQbD1TLoRnW8lG2a8KMinmD6Jlol8EX/view?usp=drive_link) or [Baidu Netdisk](https://pan.baidu.com/s/1swMZTkTm1iSYDGVsdepdBA?pwd=hcvy), and put it into the ```.\checkpoint``` folder. 
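For reference, the data loaders in `dataset_loaders.py` expect every task folder to share the same internal layout; a minimal sketch is given below (the sub-folder names come from the loaders, while the task root name is a placeholder). During training, the i-th path in the dataset list is assigned task ID i, and the portrait-specific branch (ground-truth flow and face mask) is only taken for task ID 5.
```
<task_root>/       # one folder per warping task (placeholder name)
├── input/         # source images
├── gt/            # target images
├── mask/          # warping masks
├── flow_npy/      # flow .npy files (ground truth for the portrait task, pseudo flows otherwise)
└── mask_face/     # face masks, read for the portrait task only
```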
50 | 51 | ## Testing 52 | ### Unified Warping and Evaluation on Public Benchmark 53 | Customize the paths of the checkpoint and test set, and run: 54 | ``` 55 | sh scripts/test.sh 56 | ``` 57 | The warped images and the intermediate results such as the control points and warping flow can be found in the ```.\results``` folder. The evaluated metrics such as PSNR and SSIM are also shown with the task ID. 58 | 59 | ### Specific Evaluation on Portrait Correction 60 | In the portrait correction task, the ground truth of warped image and flow is unavailable and thus the image quality metrics cannot be evaluated. Instead, the specific metric (ShapeAcc) regarding this task's purpose, i.e., correcting the face distortion, was presented. To reproduce the warping performance on portrait photos, customize the paths of checkpoint and test set, and run: 61 | ``` 62 | sh scripts/test_portrait.sh 63 | ``` 64 | The warped images can also be found in the test path. 65 | 66 | ## Training 67 | Customize the paths of all warping training datasets in a list, and run: 68 | ``` 69 | sh scripts/train.sh 70 | ``` 71 | ## Projects that use MOWA 72 | * MOWA-onnxrun (ONNX Runtime Implementation): [https://github.com/hpc203/MOWA-onnxrun](https://github.com/hpc203/MOWA-onnxrun) 73 | 74 | ## Demo 75 | TBD 76 | 77 | ## Acknowledgment 78 | The current version of **MOWA** is inspired by previous specific image warping works such as [RectanglingPano](https://people.csail.mit.edu/kaiming/publications/sig13pano.pdf), [DeepRectangling](https://github.com/nie-lang/DeepRectangling), [RecRecNet](https://github.com/KangLiao929/RecRecNet), [PCN](https://github.com/uof1745-cmd/PCN), [Deep_RS-HM](https://github.com/DavidYan2001/Deep_RS-HM), [SSPC](https://github.com/megvii-research/Portraits_Correction). 79 | 80 | ## Citation 81 | 82 | ```bibtex 83 | @article{liao2024mowa, 84 | title={MOWA: Multiple-in-One Image Warping Model}, 85 | author={Liao, Kang and Yue, Zongsheng and Wu, Zhonghua and Loy, Chen Change}, 86 | journal={arXiv preprint arXiv:2404.10716}, 87 | year={2024} 88 | } 89 | ``` 90 | 91 | ## Contact 92 | For any questions, feel free to email `kang.liao@ntu.edu.sg`. 93 | 94 | ## License 95 | This project is licensed under [NTU S-Lab License 1.0](LICENSE). 
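A minimal sketch of how `test.py` builds the per-task test loaders and resolves the checkpoint (the directories below are placeholders; see `test.py` for the full evaluation loop):
```
import glob
import torch
from torch.utils.data import DataLoader
from dataset_loaders import TestDatasetMask

# One test folder per task; its index in the list is passed to the loader as the task ID.
test_paths = ["<task0_dir>", "<task1_dir>"]   # placeholder directories

loaders = [DataLoader(TestDatasetMask(p, i), batch_size=1, num_workers=4,
                      shuffle=False, drop_last=False)
           for i, p in enumerate(test_paths)]

# test.py loads the last *.pth (after sorting) found under --model_path, reads the
# weights from the "model" entry and strips any DataParallel "module." prefix.
ckpts = sorted(glob.glob("checkpoint/*.pth"))
if ckpts:
    state_dict = torch.load(ckpts[-1])["model"]
    state_dict = {(k[7:] if k.startswith("module.") else k): v for k, v in state_dict.items()}
```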
96 | -------------------------------------------------------------------------------- /dataset_loaders.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import cv2, torch 4 | import os 5 | import glob 6 | from collections import OrderedDict 7 | from utils.utils_op import data_aug 8 | 9 | class TrainDataset(Dataset): 10 | def __init__(self, paths): 11 | 12 | self.width = 256 13 | self.height = 256 14 | self.prob = 0.5 15 | self.input_images = [] 16 | self.gt_images = [] 17 | self.masks = [] 18 | self.flows = [] 19 | self.task_id = [] 20 | for index, path in enumerate(paths): 21 | inputs = glob.glob(os.path.join(path, 'input/', '*.*')) 22 | gts = glob.glob(os.path.join(path, 'gt/', '*.*')) 23 | masks = glob.glob(os.path.join(path, 'mask/', '*.*')) 24 | flows = glob.glob(os.path.join(path, 'flow_npy/', '*.*')) 25 | inputs.sort() 26 | gts.sort() 27 | masks.sort() 28 | flows.sort() 29 | 30 | lens = len(inputs) 31 | index_array = [index] * lens 32 | self.task_id.extend(index_array) 33 | self.input_images.extend(inputs) 34 | self.gt_images.extend(gts) 35 | self.masks.extend(masks) 36 | self.flows.extend(flows) 37 | 38 | print("total dataset num: ", len(self.input_images)) 39 | 40 | def __getitem__(self, index): 41 | 42 | '''load images''' 43 | task_id = self.task_id[index] 44 | input_src = cv2.imread(self.input_images[index]) 45 | input_resized = cv2.resize(input_src, (self.width, self.height)) 46 | gt = cv2.imread(self.gt_images[index]) 47 | gt = cv2.resize(gt, (self.width, self.height)) 48 | input_resized, gt = data_aug(input_resized, gt) 49 | 50 | input_resized = input_resized.astype(dtype=np.float32) 51 | input_resized = input_resized / 255.0 52 | input_resized = np.transpose(input_resized, [2, 0, 1]) 53 | gt = gt.astype(dtype=np.float32) 54 | gt = gt / 255.0 55 | gt = np.transpose(gt, [2, 0, 1]) 56 | 57 | '''load mask''' 58 | mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE) 59 | mask = cv2.resize(mask, (self.width, self.height)) 60 | mask = np.expand_dims(mask, axis=-1) 61 | mask = mask.astype(dtype=np.float32) 62 | mask = mask / 255.0 63 | mask = np.transpose(mask, [2, 0, 1]) 64 | mask_tensor = torch.tensor(mask) 65 | 66 | input_tensor = torch.tensor(input_resized) 67 | gt_tensor = torch.tensor(gt) 68 | task_id_tensor = torch.tensor(task_id, dtype=torch.int64) 69 | 70 | '''load flow and face mask for the portrait task''' 71 | if(task_id == 5): 72 | flow = np.load(self.flows[index]) 73 | flow = flow.astype(dtype=np.float32) 74 | flow = np.transpose(flow, [2, 0, 1]) 75 | flow_tensor = torch.tensor(flow) 76 | 77 | face_mask_path = self.input_images[index].replace('/input/', '/mask_face/') 78 | facemask = cv2.imread(face_mask_path, 0) 79 | facemask = facemask.astype(dtype=np.float32) 80 | facemask = (facemask / 255.0) 81 | facemask = np.expand_dims(facemask, axis=-1) 82 | facemask = np.transpose(facemask, [2, 0, 1]) 83 | face_mask = torch.tensor(facemask) 84 | mask_sum = torch.sum(face_mask) 85 | weight = self.width * self.height / mask_sum - 1 86 | weight = torch.max(weight / 3, torch.ones(1)) 87 | face_weight = weight.unsqueeze(-1).unsqueeze(-1) 88 | else: 89 | flow_tensor = torch.zeros_like(input_tensor[0:2, :, :]) 90 | face_mask = torch.zeros_like(mask_tensor) 91 | face_weight = torch.mean(face_mask).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 92 | 93 | return (input_tensor, gt_tensor, mask_tensor, task_id_tensor, flow_tensor, face_mask, face_weight) 94 | 95 | def __len__(self): 96 
| return len(self.input_images) 97 | 98 | class TestDatasetMask(Dataset): 99 | def __init__(self, data_path, task_id): 100 | self.width = 256 101 | self.height = 256 102 | self.test_path = data_path 103 | self.datas = OrderedDict() 104 | self.task_id = task_id 105 | 106 | datas = glob.glob(os.path.join(self.test_path, '*')) 107 | for data in sorted(datas): 108 | data_name = data.split('/')[-1] 109 | if data_name == 'input' or data_name == 'gt' or data_name == 'mask': 110 | self.datas[data_name] = {} 111 | self.datas[data_name]['path'] = data 112 | self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.*')) 113 | self.datas[data_name]['image'].sort() 114 | 115 | def __getitem__(self, index): 116 | 117 | input = cv2.imread(self.datas['input']['image'][index]) 118 | input1 = input.astype(dtype=np.float32) 119 | input1 = input1 / 255.0 120 | input1 = np.transpose(input1, [2, 0, 1]) 121 | 122 | input2 = cv2.resize(input, (self.width, self.height)) 123 | input2 = input2.astype(dtype=np.float32) 124 | input2 = input2 / 255.0 125 | input2 = np.transpose(input2, [2, 0, 1]) 126 | 127 | gt = cv2.imread(self.datas['gt']['image'][index]) 128 | gt1 = gt.astype(dtype=np.float32) 129 | gt1 = gt1 / 255.0 130 | gt1 = np.transpose(gt1, [2, 0, 1]) 131 | 132 | gt2 = cv2.resize(gt, (self.width, self.height)) 133 | gt2 = gt2.astype(dtype=np.float32) 134 | gt2 = gt2 / 255.0 135 | gt2 = np.transpose(gt2, [2, 0, 1]) 136 | 137 | mask = cv2.imread(self.datas['mask']['image'][index], cv2.IMREAD_GRAYSCALE) 138 | mask = cv2.resize(mask, (self.width, self.height)) 139 | mask = np.expand_dims(mask, axis=-1) 140 | mask = mask.astype(dtype=np.float32) 141 | mask = mask / 255.0 142 | mask = np.transpose(mask, [2, 0, 1]) 143 | 144 | input1_tensor = torch.tensor(input1) 145 | input2_tensor = torch.tensor(input2) 146 | gt1_tensor = torch.tensor(gt1) 147 | gt2_tensor = torch.tensor(gt2) 148 | mask_tensor = torch.tensor(mask) 149 | task_id_tensor = torch.tensor(self.task_id, dtype=torch.int64) 150 | file_name = os.path.basename(self.datas['input']['image'][index]) 151 | file_name, _ = os.path.splitext(file_name) 152 | 153 | out_dict = {} 154 | out_dict.update(input1_tensor=input1_tensor, input2_tensor=input2_tensor, gt1_tensor=gt1_tensor, 155 | gt2_tensor=gt2_tensor, mask_tensor=mask_tensor, task_id_tensor=task_id_tensor, file_name=file_name) 156 | 157 | return out_dict 158 | 159 | def __len__(self): 160 | return len(self.datas['input']['image']) -------------------------------------------------------------------------------- /utils/torch_tps_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | import torch.nn.functional as F 5 | 6 | def transformer(U, source, target, out_size): 7 | """ 8 | Thin Plate Spline Spatial Transformer Layer 9 | TPS control points are arranged in arbitrary positions given by `source`. 10 | U : float Tensor [num_batch, height, width, num_channels]. 11 | Input Tensor. 12 | source : float Tensor [num_batch, num_point, 2] 13 | The source position of the control points. 14 | target : float Tensor [num_batch, num_point, 2] 15 | The target position of the control points. 
16 | out_size: tuple of two integers [height, width] 17 | The size of the output of the network (height, width) 18 | """ 19 | 20 | def _repeat(x, n_repeats): 21 | rep = torch.ones([n_repeats, ]).unsqueeze(0) 22 | rep = rep.int() 23 | x = x.int() 24 | 25 | x = torch.matmul(x.reshape([-1,1]), rep) 26 | return x.reshape([-1]) 27 | 28 | def _interpolate(im, x, y, out_size): 29 | num_batch, num_channels , height, width = im.size() 30 | 31 | height_f = height 32 | width_f = width 33 | out_height, out_width = out_size[0], out_size[1] 34 | 35 | zero = 0 36 | max_y = height - 1 37 | max_x = width - 1 38 | 39 | x = (x + 1.0)*(width_f) / 2.0 40 | y = (y + 1.0) * (height_f) / 2.0 41 | 42 | # sampling 43 | x0 = torch.floor(x).int() 44 | x1 = x0 + 1 45 | y0 = torch.floor(y).int() 46 | y1 = y0 + 1 47 | 48 | x0 = torch.clamp(x0, zero, max_x) 49 | x1 = torch.clamp(x1, zero, max_x) 50 | y0 = torch.clamp(y0, zero, max_y) 51 | y1 = torch.clamp(y1, zero, max_y) 52 | dim2 = torch.from_numpy( np.array(width) ) 53 | dim1 = torch.from_numpy( np.array(width * height) ) 54 | 55 | base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width) 56 | if torch.cuda.is_available(): 57 | dim2 = dim2.cuda() 58 | dim1 = dim1.cuda() 59 | y0 = y0.cuda() 60 | y1 = y1.cuda() 61 | x0 = x0.cuda() 62 | x1 = x1.cuda() 63 | base = base.cuda() 64 | base_y0 = base + y0 * dim2 65 | base_y1 = base + y1 * dim2 66 | idx_a = base_y0 + x0 67 | idx_b = base_y1 + x0 68 | idx_c = base_y0 + x1 69 | idx_d = base_y1 + x1 70 | 71 | # channels dim 72 | im = im.permute(0,2,3,1) 73 | im_flat = im.reshape([-1, num_channels]).float() 74 | 75 | 76 | idx_a = idx_a.unsqueeze(-1).long() 77 | idx_a = idx_a.expand(out_height * out_width * num_batch,num_channels) 78 | Ia = torch.gather(im_flat, 0, idx_a) 79 | 80 | idx_b = idx_b.unsqueeze(-1).long() 81 | idx_b = idx_b.expand(out_height * out_width * num_batch, num_channels) 82 | Ib = torch.gather(im_flat, 0, idx_b) 83 | 84 | idx_c = idx_c.unsqueeze(-1).long() 85 | idx_c = idx_c.expand(out_height * out_width * num_batch, num_channels) 86 | Ic = torch.gather(im_flat, 0, idx_c) 87 | 88 | idx_d = idx_d.unsqueeze(-1).long() 89 | idx_d = idx_d.expand(out_height * out_width * num_batch, num_channels) 90 | Id = torch.gather(im_flat, 0, idx_d) 91 | 92 | x0_f = x0.float() 93 | x1_f = x1.float() 94 | y0_f = y0.float() 95 | y1_f = y1.float() 96 | 97 | wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1) 98 | wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1) 99 | wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1) 100 | wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1) 101 | output = wa*Ia+wb*Ib+wc*Ic+wd*Id 102 | 103 | return output 104 | 105 | def _meshgrid(height, width, source): 106 | 107 | x_t = torch.matmul(torch.ones([height, 1]), torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 0)) 108 | y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1), torch.ones([1, width])) 109 | if torch.cuda.is_available(): 110 | x_t = x_t.cuda() 111 | y_t = y_t.cuda() 112 | 113 | x_t_flat = x_t.reshape([1, 1, -1]) 114 | y_t_flat = y_t.reshape([1, 1, -1]) 115 | 116 | num_batch = source.size()[0] 117 | px = torch.unsqueeze(source[:,:,0], 2) 118 | py = torch.unsqueeze(source[:,:,1], 2) 119 | if torch.cuda.is_available(): 120 | px = px.cuda() 121 | py = py.cuda() 122 | d2 = torch.square(x_t_flat - px) + torch.square(y_t_flat - py) 123 | r = d2 * torch.log(d2 + 1e-6) 124 | x_t_flat_g = x_t_flat.expand(num_batch, -1, -1) 125 | y_t_flat_g = y_t_flat.expand(num_batch, -1, -1) 126 | ones = torch.ones_like(x_t_flat_g) 
127 | if torch.cuda.is_available(): 128 | ones = ones.cuda() 129 | 130 | grid = torch.cat((ones, x_t_flat_g, y_t_flat_g, r), 1) 131 | 132 | return grid 133 | 134 | def _transform(T, source, input_dim, out_size): 135 | num_batch, num_channels, height, width = input_dim.size() 136 | 137 | out_height, out_width = out_size[0], out_size[1] 138 | grid = _meshgrid(out_height, out_width, source) 139 | 140 | # transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s) 141 | # [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w] 142 | T_g = torch.matmul(T, grid) 143 | x_s = T_g[:,0,:] 144 | y_s = T_g[:,1,:] 145 | x_s_flat = x_s.reshape([-1]) 146 | y_s_flat = y_s.reshape([-1]) 147 | 148 | input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat,out_size) 149 | 150 | output = input_transformed.reshape([num_batch, out_height, out_width, num_channels]) 151 | 152 | output = output.permute(0,3,1,2) 153 | return output 154 | 155 | def _solve_system(source, target): 156 | num_batch = source.size()[0] 157 | num_point = source.size()[1] 158 | 159 | ones = torch.ones(num_batch, num_point, 1).float() 160 | if torch.cuda.is_available(): 161 | ones = ones.cuda() 162 | p = torch.cat([ones, source], 2) 163 | 164 | p_1 = p.reshape([num_batch, -1, 1, 3]) 165 | p_2 = p.reshape([num_batch, 1, -1, 3]) 166 | d2 = torch.sum(torch.square(p_1-p_2), 3) 167 | 168 | r = d2 * torch.log(d2 + 1e-6) 169 | 170 | zeros = torch.zeros(num_batch, 3, 3).float() 171 | if torch.cuda.is_available(): 172 | zeros = zeros.cuda() 173 | W_0 = torch.cat((p, r), 2) 174 | W_1 = torch.cat((zeros, p.permute(0,2,1)), 2) 175 | W = torch.cat((W_0, W_1), 1) 176 | 177 | W_inv = torch.inverse(W.type(torch.float64)) 178 | 179 | zeros2 = torch.zeros(num_batch, 3, 2) 180 | if torch.cuda.is_available(): 181 | zeros2 = zeros2.cuda() 182 | tp = torch.cat((target, zeros2), 1) 183 | T = torch.matmul(W_inv, tp.type(torch.float64)) 184 | T = T.permute(0, 2, 1) 185 | 186 | return T.type(torch.float32) 187 | 188 | T = _solve_system(source, target) 189 | output = _transform(T, source, U, out_size) 190 | 191 | return output 192 | 193 | def interpolate_grid(points, scr_w, scr_h, dest_w, dest_h): 194 | sparse_control_points = points.reshape(-1, 2, scr_w, scr_h) 195 | dense_control_points = F.interpolate(sparse_control_points, size=(dest_h, dest_w), mode='bilinear', align_corners=True) 196 | interpolated_points = dense_control_points.reshape(-1, 2*dest_w*dest_h) 197 | return interpolated_points -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from model.builder import * 5 | from dataset_loaders import TestDatasetMask 6 | import os 7 | import glob 8 | import numpy as np 9 | import cv2 10 | from skimage.metrics import peak_signal_noise_ratio as psnr 11 | from skimage.metrics import structural_similarity as ssim 12 | from model.network import MOWA 13 | from collections import OrderedDict 14 | from utils import flow_viz 15 | from utils.utils_op import * 16 | 17 | def test(args): 18 | 19 | path = os.path.dirname(os.path.abspath(__file__)) 20 | IMG_DIR = os.path.join(path, 'results/', args.method, 'img/') 21 | MESH_DIR = os.path.join(path, 'results/', args.method, 'mesh/') 22 | RES_DIR = os.path.join(path, 'results/', args.method, 'res/') 23 | IMG_DIR_LIST = [os.path.join(IMG_DIR, str(i)) for i in range(len(args.test_path))] 24 | MESH_DIR_LIST = [os.path.join(MESH_DIR, 
str(i)) for i in range(len(args.test_path))] 25 | RES_DIR_LIST = [os.path.join(RES_DIR, str(i)) for i in range(len(args.test_path))] 26 | 27 | for i in range(len(IMG_DIR_LIST)): 28 | path = IMG_DIR_LIST[i] 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | for i in range(len(MESH_DIR_LIST)): 32 | path = MESH_DIR_LIST[i] 33 | if not os.path.exists(path): 34 | os.makedirs(path) 35 | for i in range(len(RES_DIR_LIST)): 36 | path = RES_DIR_LIST[i] 37 | if not os.path.exists(path): 38 | os.makedirs(path) 39 | 40 | test_loader_list = [DataLoader(dataset=TestDatasetMask(test_path, i), batch_size=1, num_workers=4, shuffle=False, drop_last=False) \ 41 | for i, test_path in enumerate(args.test_path)] 42 | 43 | '''define the network''' 44 | net = MOWA(img_size=args.input_size, tps_points=args.tps_points, embed_dim=args.embed_dim, win_size=args.win_size, 45 | token_projection=args.token_projection, token_mlp=args.token_mlp, depths=args.depths, 46 | prompt=args.prompt, task_classes=args.task_classes, head_num=args.head_num, shared_head=args.shared_head) 47 | 48 | if torch.cuda.is_available(): 49 | torch.cuda.set_device(args.gpu) 50 | device = torch.device('cuda:{}'.format(args.gpu)) 51 | net = net.to(device) 52 | 53 | '''load the existing models''' 54 | MODEL_DIR = args.model_path 55 | ckpt_list = glob.glob(MODEL_DIR + "/*.pth") 56 | ckpt_list.sort() 57 | if len(ckpt_list) != 0: 58 | model_path = ckpt_list[-1] 59 | checkpoint = torch.load(model_path) 60 | state_dict = checkpoint["model"] 61 | new_state_dict = OrderedDict() 62 | for k, v in state_dict.items(): 63 | name = k[7:] if 'module.' in k else k 64 | new_state_dict[name] = v 65 | 66 | net.load_state_dict(new_state_dict) 67 | print('load model from {}!'.format(model_path)) 68 | else: 69 | raise FileNotFoundError(f'No checkpoint found in directory {MODEL_DIR}!') 70 | 71 | print("##################start testing#######################") 72 | net.eval() 73 | test_num = len(test_loader_list) 74 | 75 | for index in range(test_num): 76 | print("Task ID:", index) 77 | NUM = count_files(args.test_path[index]) 78 | acc_temp = 0 79 | psnr_img = 0 80 | ssim_img = 0 81 | path_img = str(IMG_DIR_LIST[index]) + '/' 82 | path_grid = str(MESH_DIR_LIST[index]) + '/' 83 | path_res = str(RES_DIR_LIST[index]) + '/' 84 | 85 | with torch.no_grad(): 86 | net.eval() 87 | test_loader = test_loader_list[index] 88 | for i, outputs in enumerate(test_loader): 89 | 90 | input1_tensor = outputs['input1_tensor'].float() 91 | input2_tensor = outputs['input2_tensor'].float() 92 | gt1_tensor = outputs['gt1_tensor'].float() 93 | gt2_tensor = outputs['gt2_tensor'].float() 94 | mask_tensor = outputs['mask_tensor'].float() 95 | task_id_tensor = outputs['task_id_tensor'].float() 96 | file_name = outputs['file_name'][0] 97 | 98 | if torch.cuda.is_available(): 99 | input1_tensor = input1_tensor.cuda() 100 | input2_tensor = input2_tensor.cuda() 101 | gt1_tensor = gt1_tensor.cuda() 102 | gt2_tensor = gt2_tensor.cuda() 103 | mask_tensor = mask_tensor.cuda() 104 | task_id_tensor = task_id_tensor.cuda() 105 | 106 | ''' parsing the output ''' 107 | batch_out = build_model_test(net, input1_tensor, input2_tensor, mask_tensor, args.tps_points, resize_flow=True) 108 | warp_tps, warp_flow, mesh, flow1, flow2, flow3, point_cls = \ 109 | [batch_out[key] for key in ['warp_tps', 'warp_flow', 'mesh', 'flow1', 'flow2', 'flow3', 'point_cls']] 110 | 111 | ''' tensor to numpy and post-processing ''' 112 | _, c, ori_h, ori_w = input1_tensor.shape 113 | input_np2 = 
((input2_tensor[0])*255.0).cpu().detach().numpy().transpose(1,2,0).astype(np.uint8) 114 | gt_np1 = ((gt1_tensor[0])*255.0).cpu().detach().numpy().transpose(1,2,0).astype(np.uint8) 115 | gt_np2 = ((gt2_tensor[0])*255.0).cpu().detach().numpy().transpose(1,2,0).astype(np.uint8) 116 | gt_np2 = cv2.resize(gt_np2, (ori_w, ori_h)) 117 | 118 | warp_flow_np = ((warp_flow[0])*255.0).cpu().detach().numpy().transpose(1,2,0).astype(np.uint8) 119 | flow1 = (flow1[0]).cpu().detach().numpy().transpose(1,2,0) 120 | flow2 = (flow2[0]).cpu().detach().numpy().transpose(1,2,0) 121 | flow3 = (flow3[0]).cpu().detach().numpy().transpose(1,2,0) 122 | 123 | flow1 = flow_viz.flow_to_image(flow1) 124 | flow2 = flow_viz.flow_to_image(flow2) 125 | flow3 = flow_viz.flow_to_image(flow3) 126 | 127 | _, point_cls = torch.max(point_cls[0], 0) 128 | acc_temp += (point_cls == task_id_tensor[0]).float().mean().item() 129 | warp_tps_np = ((warp_tps[-1][0])*255.0).cpu().detach().numpy().transpose(1,2,0).astype(np.uint8) 130 | mesh_np = mesh[-1][0].cpu().detach().numpy() 131 | cv2.imwrite(path_img + file_name + "_mesh" + ".jpg", warp_tps_np) 132 | input_with_mesh = draw_mesh_on_warp(input_np2, mesh_np, args.tps_points[-1]-1, args.tps_points[-1]-1) 133 | cv2.imwrite(path_grid + file_name + "_mesh" + ".jpg", input_with_mesh) 134 | 135 | ''' calculate metrics ''' 136 | psnr_img += psnr(warp_flow_np, gt_np1, data_range=255) 137 | ssim_img += ssim(warp_flow_np, gt_np1, data_range=255, channel_axis=2) 138 | cv2.imwrite(path_img + file_name + "_flow1.jpg", flow1) 139 | cv2.imwrite(path_img + file_name + "_flow2.jpg", flow2) 140 | cv2.imwrite(path_img + file_name + "_flow3.jpg", flow3) 141 | cv2.imwrite(path_res + file_name + ".jpg", warp_flow_np) 142 | cv2.imwrite(path_img + file_name + "_flow.jpg", warp_flow_np) 143 | 144 | print(f"Validation PSNR: {round(psnr_img / NUM, 4)}, Validation SSIM: {round(ssim_img / NUM, 4)}, Validation Acc: {round(acc_temp / NUM, 4)}") 145 | 146 | 147 | 148 | if __name__=="__main__": 149 | 150 | parser = argparse.ArgumentParser() 151 | 152 | '''Implementation details''' 153 | parser.add_argument('--gpu', type=int, default=0) 154 | parser.add_argument('--batch_size', type=int, default=1) 155 | parser.add_argument('--model_path', type=str, default='model/') 156 | parser.add_argument('--method', type=str, default='method') 157 | 158 | '''Network details''' 159 | parser.add_argument('--input_size', type=int, default=256) 160 | parser.add_argument('--depths', nargs='+', type=int, default=[2, 2, 2, 2, 2, 2, 2, 2, 2], help='depths for transformer layers') 161 | parser.add_argument('--embed_dim', type=int, default=32) 162 | parser.add_argument('--win_size', type=int, default=8) 163 | parser.add_argument('--token_projection', type=str, default='linear') 164 | parser.add_argument('--token_mlp', type=str, default='leff') 165 | parser.add_argument('--prompt', type=bool, default=True) 166 | parser.add_argument('--task_classes', type=int, default=6) 167 | parser.add_argument('--tps_points', nargs='+', type=int, default=[10, 12, 14, 16], help='tps points for regression heads') 168 | parser.add_argument('--head_num', type=int, default=4) 169 | parser.add_argument('--shared_head', type=bool, default=False) 170 | 171 | '''Dataset settings''' 172 | parser.add_argument('--test_path', type=str, default=['/stitch/test/', '/wide-angle/test/', '/RS_Rec/test/', '/Rotation/test/', '/fisheye/test/', '/portrait/test/']) 173 | 174 | print('<==================== Testing ===================>\n') 175 | 176 | args = parser.parse_args() 177 
| print(args) 178 | test(args) -------------------------------------------------------------------------------- /test_portrait.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import torch 4 | from model.builder import * 5 | import os 6 | import glob 7 | import numpy as np 8 | import cv2 9 | from model.network import MOWA 10 | from collections import OrderedDict 11 | from utils.utils_op import * 12 | from tqdm import tqdm 13 | 14 | eps = 1e-6 15 | 16 | def estimation_flowmap(model, img, device, args): 17 | model.eval() 18 | img = cv2.resize(img, (256, 256)) 19 | img = img.astype(dtype=np.float32) 20 | img = img / 255.0 21 | img = np.transpose(img, [2, 0, 1]) 22 | img = torch.tensor(img) 23 | img = img.unsqueeze(0) 24 | mask = np.ones((256, 256), dtype=np.uint8) * 255 25 | mask = np.expand_dims(mask, axis=-1) 26 | mask = mask.astype(dtype=np.float32) 27 | mask = mask / 255.0 28 | mask = np.transpose(mask, [2, 0, 1]) 29 | mask = torch.tensor(mask) 30 | mask = mask.unsqueeze(0) 31 | with torch.no_grad(): 32 | img = img.to(device) 33 | mask = mask.to(device) 34 | batch_out = build_model_test(model, img, img, mask, args.tps_points) 35 | output = batch_out['flow3'] 36 | output = output.detach().cpu().squeeze(0).numpy() 37 | return output 38 | 39 | 40 | # ----------------------The computation process of face metric --------------------------------------- 41 | def compute_cosin_similarity(preds, gts): 42 | people_num = gts.shape[0] 43 | points_num = gts.shape[1] 44 | similarity_list = [] 45 | preds = preds.astype(np.float32) 46 | gts = gts.astype(np.float32) 47 | for people_index in range(people_num): 48 | # the index 63 of lmk is the center point of the face, that is, the tip of the nose 49 | pred_center = preds[people_index, 63, :] 50 | pred = preds[people_index, :, :] 51 | pred = pred - pred_center[None, :] 52 | gt_center = gts[people_index, 63, :] 53 | gt = gts[people_index, :, :] 54 | gt = gt - gt_center[None, :] 55 | 56 | dot = np.sum((pred * gt), axis=1) 57 | pred = np.sqrt(np.sum(pred * pred, axis=1)) 58 | gt = np.sqrt(np.sum(gt * gt, axis=1)) 59 | 60 | similarity_list_tmp = [] 61 | for i in range(points_num): 62 | if i != 63: 63 | similarity = (dot[i] / (pred[i] * gt[i] + eps)) 64 | similarity_list_tmp.append(similarity) 65 | 66 | similarity_list.append(np.mean(similarity_list_tmp)) 67 | 68 | return np.mean(similarity_list) 69 | 70 | 71 | # --------------------The normalization function ----------------------------------------------------- 72 | def normalization(x): 73 | return [(float(i) - min(x)) / float(max(x) - min(x) + eps) for i in x] 74 | 75 | 76 | # -------------------The computation process of line metric------------------------------------------- 77 | def compute_line_slope_difference(pred_line, gt_k): 78 | scores = [] 79 | for i in range(pred_line.shape[0] - 1): 80 | pk = (pred_line[i + 1, 1] - pred_line[i, 1]) / (pred_line[i + 1, 0] - pred_line[i, 0] + eps) 81 | score = np.abs(pk - gt_k) 82 | scores.append(score) 83 | scores_norm = normalization(scores) 84 | score = np.mean(scores_norm) 85 | score = 1 - score 86 | return score 87 | 88 | 89 | # -------------------------------Compute the out put flow map ------------------------------------------------- 90 | def compute_ori2shape_face_line_metric(model, oriimg_paths, device, args): 91 | line_all_sum_pred = [] 92 | face_all_sum_pred = [] 93 | 94 | for oriimg_path in tqdm(oriimg_paths): 95 | # Get the [Source image] 96 | ori_img = cv2.imread(oriimg_path) # 
Read the oriinal image 97 | ori_height, ori_width, _ = ori_img.shape # get the size of the oriinal image 98 | input = ori_img.copy() # get the image as the input of our model 99 | 100 | # Get the [flow map]""" 101 | pred = estimation_flowmap(model, input, device, args) 102 | pflow = pred.transpose(1, 2, 0) 103 | predflow_x, predflow_y = pflow[:, :, 0], pflow[:, :, 1] 104 | 105 | scale_x = ori_width / predflow_x.shape[1] 106 | scale_y = ori_height / predflow_x.shape[0] 107 | predflow_x = cv2.resize(predflow_x, (ori_width, ori_height)) * scale_x 108 | predflow_y = cv2.resize(predflow_y, (ori_width, ori_height)) * scale_y 109 | 110 | # Get the [predicted image]""" 111 | ys, xs = np.mgrid[:ori_height, :ori_width] 112 | mesh_x = predflow_x.astype("float32") + xs.astype("float32") 113 | mesh_y = predflow_y.astype("float32") + ys.astype("float32") 114 | pred_out = cv2.remap(input, mesh_x, mesh_y, cv2.INTER_LINEAR) 115 | cv2.imwrite(oriimg_path.replace(".jpg", "_pred.jpg"), pred_out) 116 | 117 | # Get the landmarks from the [gt image] 118 | stereo_lmk_file = open(oriimg_path.replace(".jpg", "_stereo_landmark.json")) 119 | stereo_lmk = np.array(json.load(stereo_lmk_file), dtype="float32") 120 | 121 | # Get the landmarks from the [source image] 122 | ori_lmk_file = open(oriimg_path.replace(".jpg", "_landmark.json")) 123 | ori_lmk = np.array(json.load(ori_lmk_file), dtype="float32") 124 | 125 | # Get the landmarks from the the pred out 126 | out_lmk = np.zeros_like(ori_lmk) 127 | for i in range(ori_lmk.shape[0]): 128 | for j in range(ori_lmk.shape[1]): 129 | x = ori_lmk[i, j, 0] 130 | y = ori_lmk[i, j, 1] 131 | if y < predflow_y.shape[0] and x < predflow_y.shape[1]: 132 | out_lmk[i, j, 0] = x - predflow_x[int(y), int(x)] 133 | out_lmk[i, j, 1] = y - predflow_y[int(y), int(x)] 134 | else: 135 | out_lmk[i, j, 0] = x 136 | out_lmk[i, j, 1] = y 137 | 138 | # Compute the face metric 139 | face_pred_sim = compute_cosin_similarity(out_lmk, stereo_lmk) 140 | face_all_sum_pred.append(face_pred_sim) 141 | stereo_lmk_file.close() 142 | ori_lmk_file.close() 143 | 144 | # Get the line from the [gt image] 145 | gt_line_file = oriimg_path.replace(".jpg", "_line_lines.json") 146 | lines = json.load(open(gt_line_file)) 147 | 148 | # Get the line from the [source image] 149 | ori_line_file = oriimg_path.replace(".jpg", "_lines.json") 150 | ori_lines = json.load(open(ori_line_file)) 151 | 152 | # Get the line from the pred out 153 | pred_ori2shape_lines = [] 154 | for index, ori_line in enumerate(ori_lines): 155 | ori_line = np.array(ori_line, dtype="float32") 156 | pred_ori2shape = np.zeros_like(ori_line) 157 | for i in range(ori_line.shape[0]): 158 | x = ori_line[i, 0] 159 | y = ori_line[i, 1] 160 | pred_ori2shape[i, 0] = x - predflow_x[int(y), int(x)] 161 | pred_ori2shape[i, 1] = y - predflow_y[int(y), int(x)] 162 | pred_ori2shape = pred_ori2shape.tolist() 163 | pred_ori2shape_lines.append(pred_ori2shape) 164 | 165 | # Compute the lines score 166 | line_pred_ori2shape_sum = [] 167 | for index, line in enumerate(lines): 168 | gt_line = np.array(line, dtype="float32") 169 | pred_ori2shape = np.array(pred_ori2shape_lines[index], dtype="float32") 170 | gt_k = (gt_line[1, 1] - gt_line[0, 1]) / (gt_line[1, 0] - gt_line[0, 0] + eps) 171 | pred_ori2shape_score = compute_line_slope_difference(pred_ori2shape, gt_k) 172 | line_pred_ori2shape_sum.append(pred_ori2shape_score) 173 | line_all_sum_pred.append(np.mean(line_pred_ori2shape_sum)) 174 | 175 | return np.mean(line_all_sum_pred) * 100, np.mean(face_all_sum_pred) * 100 176 
| 177 | def test(args): 178 | 179 | net = MOWA(img_size=args.input_size, tps_points=args.tps_points, embed_dim=args.embed_dim, win_size=args.win_size, 180 | token_projection=args.token_projection, token_mlp=args.token_mlp, depths=args.depths, 181 | prompt=args.prompt, task_classes=args.task_classes, head_num=args.head_num, shared_head=args.shared_head) 182 | 183 | if torch.cuda.is_available(): 184 | torch.cuda.set_device(args.gpu) 185 | device = torch.device('cuda:{}'.format(args.gpu)) 186 | net = net.to(device) 187 | 188 | MODEL_DIR = args.model_path 189 | ckpt_list = glob.glob(MODEL_DIR + "/*.pth") 190 | ckpt_list.sort() 191 | if len(ckpt_list) != 0: 192 | model_path = ckpt_list[-1] 193 | print(model_path) 194 | checkpoint = torch.load(model_path) 195 | state_dict = checkpoint["model"] 196 | new_state_dict = OrderedDict() 197 | for k, v in state_dict.items(): 198 | name = k[7:] if 'module.' in k else k 199 | new_state_dict[name] = v 200 | 201 | net.load_state_dict(new_state_dict) 202 | print('load model from {}!'.format(model_path)) 203 | else: 204 | raise FileNotFoundError(f'No checkpoint found in directory {MODEL_DIR}!') 205 | 206 | print("##################start testing#######################") 207 | net.eval() 208 | 209 | oriimg_paths = [] 210 | for root, _, files in os.walk(args.test_path): 211 | for file_name in files: 212 | if file_name.endswith(".jpg"): 213 | if "line" not in file_name and "stereo" not in file_name and "pred" not in file_name: 214 | oriimg_paths.append(os.path.join(root, file_name)) 215 | 216 | print("The number of images: :", len(oriimg_paths)) 217 | 218 | line_score, face_score = compute_ori2shape_face_line_metric(net, oriimg_paths, device, args) 219 | print("Line_score = {:.3f}, Face_score = {:.3f} ".format(line_score, face_score)) 220 | 221 | 222 | if __name__=="__main__": 223 | 224 | parser = argparse.ArgumentParser() 225 | 226 | '''Implementation details''' 227 | parser.add_argument('--gpu', type=int, default=0) 228 | parser.add_argument('--batch_size', type=int, default=1) 229 | parser.add_argument('--model_path', type=str, default='model/') 230 | parser.add_argument('--method', type=str, default='method') 231 | 232 | '''Network details''' 233 | parser.add_argument('--input_size', type=int, default=256) 234 | parser.add_argument('--depths', nargs='+', type=int, default=[2, 2, 2, 2, 2, 2, 2, 2, 2], help='depths for transformer layers') 235 | parser.add_argument('--embed_dim', type=int, default=32) 236 | parser.add_argument('--win_size', type=int, default=8) 237 | parser.add_argument('--token_projection', type=str, default='linear') 238 | parser.add_argument('--token_mlp', type=str, default='leff') 239 | parser.add_argument('--prompt', type=bool, default=True) 240 | parser.add_argument('--task_classes', type=int, default=6) 241 | parser.add_argument('--tps_points', nargs='+', type=int, default=[10, 12, 14, 16], help='tps points for regression heads') 242 | parser.add_argument('--head_num', type=int, default=4) 243 | parser.add_argument('--shared_head', type=bool, default=False) 244 | 245 | '''Dataset settings''' 246 | parser.add_argument('--test_path', type=str, default="/Dataset/FaceRec/test_4_3_all/") 247 | 248 | print('<==================== Testing ===================>\n') 249 | 250 | args = parser.parse_args() 251 | print(args) 252 | test(args) 253 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import 
torch 3 | from torch.utils.data import DataLoader 4 | import os 5 | import torch.optim as optim 6 | from torch.optim.lr_scheduler import StepLR 7 | from torch.utils.tensorboard import SummaryWriter 8 | from model.builder import * 9 | from model.network import MOWA 10 | from dataset_loaders import TrainDataset 11 | import glob 12 | from model.loss import * 13 | import torchvision.models as models 14 | import torch.multiprocessing as mp 15 | from torch.utils.data.distributed import DistributedSampler 16 | from torch.cuda.amp import autocast 17 | import torch.distributed as dist 18 | from warmup_scheduler import GradualWarmupScheduler 19 | from utils.utils_op import * 20 | torch.backends.cudnn.enabled = True 21 | 22 | def train(gpu, ngpus_per_node, args): 23 | 24 | """ threads running on each GPU """ 25 | if args.distributed: 26 | torch.cuda.set_device(int(gpu)) 27 | print('using GPU {} for training'.format(int(gpu))) 28 | torch.distributed.init_process_group(backend = 'nccl', 29 | init_method = 'tcp://127.0.0.1:' + args.port, 30 | world_size = ngpus_per_node, 31 | rank = gpu, 32 | group_name='mtorch' 33 | ) 34 | 35 | ''' folder settings''' 36 | path = os.path.dirname(os.path.abspath(__file__)) 37 | MODEL_DIR = os.path.join(path, 'checkpoint/', args.method) 38 | SUMMARY_DIR = os.path.join(path, 'summary/', args.method) 39 | if dist.get_rank() == 0: 40 | writer = SummaryWriter(log_dir=SUMMARY_DIR) 41 | if not os.path.exists(MODEL_DIR): 42 | os.makedirs(MODEL_DIR) 43 | if not os.path.exists(SUMMARY_DIR): 44 | os.makedirs(SUMMARY_DIR) 45 | 46 | ''' dataloader settings ''' 47 | train_dataset = TrainDataset(args.train_path) 48 | train_sampler = DistributedSampler(train_dataset, num_replicas=ngpus_per_node, rank=gpu) 49 | train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, num_workers=4, sampler=train_sampler, drop_last=True) 50 | 51 | ''' define the network and training scheduler''' 52 | net = MOWA(img_size=args.input_size, tps_points=args.tps_points, embed_dim=args.embed_dim, win_size=args.win_size, 53 | token_projection=args.token_projection, token_mlp=args.token_mlp, depths=args.depths, 54 | prompt=args.prompt, task_classes=args.task_classes, head_num=args.head_num, shared_head=args.shared_head) 55 | 56 | vgg_model = models.vgg19(pretrained=True) 57 | net = set_device(net, distributed=args.distributed) 58 | vgg_model = set_device(vgg_model, distributed=args.distributed) 59 | vgg_model.eval() 60 | 61 | optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.weight_decay) 62 | scaler = torch.cuda.amp.GradScaler() 63 | if args.warmup: 64 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_epoch-args.warmup_epochs, eta_min=args.eta_min) 65 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=args.warmup_epochs, after_scheduler=scheduler_cosine) 66 | scheduler.step() 67 | else: 68 | step = 50 69 | print("Using StepLR,step={}!".format(step)) 70 | scheduler = StepLR(optimizer, step_size=step, gamma=0.5) 71 | scheduler.step() 72 | 73 | ''' resume training or not''' 74 | ckpt_list = glob.glob(MODEL_DIR + "/*.pth") 75 | ckpt_list.sort() 76 | if len(ckpt_list) != 0: 77 | model_path = ckpt_list[-1] 78 | checkpoint = torch.load(model_path) 79 | net.load_state_dict(checkpoint['model']) 80 | optimizer.load_state_dict(checkpoint['optimizer']) 81 | start_epoch = checkpoint['epoch'] 82 | glob_iter = checkpoint['glob_iter'] 83 | scheduler.last_epoch = start_epoch 84 | print('load model from 
{}!'.format(model_path)) 85 | else: 86 | start_epoch = 0 87 | glob_iter = 0 88 | print('training from stratch!') 89 | 90 | ''' network training''' 91 | for epoch in range(start_epoch, args.max_epoch): 92 | train_sampler.set_epoch(epoch) 93 | net.train() 94 | total_loss_sigma = 0. 95 | appearance_loss_sigma = 0. 96 | perception_loss_sigma = 0. 97 | inter_grid_loss_sigma = 0. 98 | point_loss_sigma = 0. 99 | flow_loss_sigma = 0. 100 | 101 | for i, batch_value in enumerate(train_loader): 102 | input_tesnor = batch_value[0].float() 103 | gt_tesnor = batch_value[1].float() 104 | mask_tensor = batch_value[2].float() 105 | task_id_tensor = batch_value[3].float() 106 | flow_tensor = batch_value[4].float() 107 | face_mask = batch_value[5].float() 108 | face_weight = batch_value[6].float() 109 | 110 | if torch.cuda.is_available(): 111 | input_tesnor = set_device(input_tesnor, distributed=args.distributed) 112 | gt_tesnor = set_device(gt_tesnor, distributed=args.distributed) 113 | mask_tensor = set_device(mask_tensor, distributed=args.distributed) 114 | task_id_tensor = set_device(task_id_tensor, distributed=args.distributed) 115 | flow_tensor = set_device(flow_tensor, distributed=args.distributed) 116 | face_mask = set_device(face_mask, distributed=args.distributed) 117 | face_weight = set_device(face_weight, distributed=args.distributed) 118 | 119 | optimizer.zero_grad() 120 | with autocast(): 121 | batch_out = build_model(net.module, input_tesnor, input_tesnor, mask_tensor, args.tps_points) 122 | warp_tps, mesh, warp_flow, flow, point_cls = \ 123 | [batch_out[key] for key in ['warp_tps', 'mesh', 'warp_flow', 'flow3', 'point_cls']] 124 | 125 | ''' calculate losses''' 126 | inter_grid_loss = cal_inter_grid_loss_sum(mesh, args.tps_points, [1.0/args.head_num for _ in range(args.head_num)]) 127 | point_loss = cal_point_loss(point_cls, task_id_tensor) * 0.1 128 | face_weight_ad = adjust_weight(epoch, args.max_epoch, face_weight) 129 | flow_tensor, flow = get_weight_mask(face_mask, flow_tensor, flow, weight=face_weight_ad) 130 | flow_loss = mask_flow_loss(flow, flow_tensor, task_id_tensor) * 0.1 131 | 132 | if(epoch <= 10): 133 | appearance_loss = cal_appearance_loss_sum(warp_tps, gt_tesnor, args.img_weight1) 134 | perception_loss = cal_perception_loss_sum(vgg_model, warp_tps, gt_tesnor, args.img_weight1) 135 | total_loss = appearance_loss + perception_loss + inter_grid_loss + point_loss 136 | else: 137 | appearance_loss = cal_appearance_loss_sum(warp_tps+[warp_flow], gt_tesnor, args.img_weight1+args.img_weight2) 138 | perception_loss = cal_perception_loss_sum(vgg_model, warp_tps+[warp_flow], gt_tesnor, args.img_weight1+args.img_weight2) 139 | total_loss = appearance_loss + perception_loss + inter_grid_loss + flow_loss + point_loss 140 | 141 | scaler.scale(total_loss).backward() 142 | scaler.step(optimizer) 143 | scaler.update() 144 | 145 | total_loss_sigma += total_loss.item() 146 | appearance_loss_sigma += appearance_loss.item() 147 | perception_loss_sigma += perception_loss.item() 148 | inter_grid_loss_sigma += inter_grid_loss.item() 149 | point_loss_sigma += point_loss.item() 150 | flow_loss_sigma += flow_loss.item() 151 | 152 | ''' writting training logs ''' 153 | if i % args.print_interval == 0 and i != 0: 154 | if dist.get_rank() == 0: 155 | total_loss_average = total_loss_sigma / args.print_interval 156 | appearance_loss_average = appearance_loss_sigma/ args.print_interval 157 | perception_loss_average = perception_loss_sigma/ args.print_interval 158 | inter_grid_loss_average = 
inter_grid_loss_sigma/ args.print_interval 159 | point_loss_average = point_loss_sigma/ args.print_interval 160 | flow_loss_average = flow_loss_sigma/ args.print_interval 161 | 162 | total_loss_sigma = 0. 163 | appearance_loss_sigma = 0. 164 | perception_loss_sigma = 0. 165 | inter_grid_loss_sigma = 0. 166 | point_loss_sigma = 0. 167 | flow_loss_sigma = 0. 168 | 169 | print(f"Training: Epoch[{epoch + 1:0>3}/{args.max_epoch:0>3}] " 170 | f"Iteration[{i + 1:0>3}/{len(train_loader):0>3}] " 171 | f"Total Loss: {total_loss_average:.4f} " 172 | f"Appearance Loss: {appearance_loss_average:.4f} " 173 | f"Perception Loss: {perception_loss_average:.4f} " 174 | f"Point Loss: {point_loss_average:.4f} " 175 | f"Flow Loss: {flow_loss_average:.4f} " 176 | f"Inter-Grid Loss: {inter_grid_loss_average:.4f} " 177 | f"lr={optimizer.state_dict()['param_groups'][0]['lr']:.8f}") 178 | 179 | writer.add_image("input", (input_tesnor[0]), glob_iter) 180 | writer.add_image("rectangling", (warp_flow[0]), glob_iter) 181 | writer.add_image("gt", (gt_tesnor[0]), glob_iter) 182 | writer.add_scalar('lr', optimizer.state_dict()['param_groups'][0]['lr'], glob_iter) 183 | writer.add_scalar('total loss', total_loss_average, glob_iter) 184 | writer.add_scalar('appearance loss', appearance_loss_average, glob_iter) 185 | writer.add_scalar('perception loss', perception_loss_average, glob_iter) 186 | writer.add_scalar('inter-grid loss', inter_grid_loss_average, glob_iter) 187 | writer.add_scalar('point loss', point_loss_average, glob_iter) 188 | writer.add_scalar('flow loss', flow_loss_average, glob_iter) 189 | 190 | glob_iter += 1 191 | 192 | ''' save models ''' 193 | if ((epoch+1) % args.eva_interval == 0 or (epoch+1)==args.max_epoch): 194 | if dist.get_rank() == 0: 195 | filename ='epoch' + str(epoch+1).zfill(3) + '_model.pth' 196 | model_save_path = os.path.join(MODEL_DIR, filename) 197 | state = {'model': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch+1, "glob_iter": glob_iter} 198 | torch.save(state, model_save_path) 199 | 200 | scheduler.step() 201 | 202 | 203 | if __name__=="__main__": 204 | 205 | print('<==================== setting arguments ===================>\n') 206 | 207 | parser = argparse.ArgumentParser() 208 | '''Implementation details''' 209 | parser.add_argument('-gpu', '--gpu_ids', type=str, default='0') 210 | parser.add_argument('-b', '--batch_size', type=int, default=8) 211 | parser.add_argument('--max_epoch', type=int, default=300) 212 | parser.add_argument('-m', '--method', type=str, default='mowa') 213 | parser.add_argument('-P', '--port', default='21016', type=str) 214 | parser.add_argument('-d', '--distributed', type=bool, default=True) 215 | parser.add_argument('-w', '--warmup', type=bool, default=True) 216 | parser.add_argument('--warmup_epochs', type=int, default=3, help='epochs for warmup') 217 | parser.add_argument('--print_interval', type=int, default=160) 218 | parser.add_argument('--eva_interval', type=int, default=10) 219 | parser.add_argument("--lr", type=float, default=2e-4, help="start learning rate") 220 | parser.add_argument("--eta_min", type=float, default=1e-7, help="final learning rate") 221 | parser.add_argument("--weight_decay", type=float, default=0, help="weight decay of the optimizer") 222 | parser.add_argument('--img_weight1', nargs='+', type=float, default=[0.1, 0.1, 0.2, 0.5], help='weights for img loss (stage1)') 223 | parser.add_argument('--img_weight2', nargs='+', type=float, default=[0.5], help='weights for img loss (stage2)') 224 | 225 | '''Network 
details''' 226 | parser.add_argument('--input_size', type=int, default=256) 227 | parser.add_argument('--depths', nargs='+', type=int, default=[2, 2, 2, 2, 2, 2, 2, 2, 2], help='depths for transformer layers') 228 | parser.add_argument('--tps_points', nargs='+', type=int, default=[10, 12, 14, 16], help='tps points for regression heads') 229 | parser.add_argument('--embed_dim', type=int, default=32) 230 | parser.add_argument('--win_size', type=int, default=8) 231 | parser.add_argument('--token_projection', type=str, default='linear') 232 | parser.add_argument('--token_mlp', type=str, default='leff') 233 | parser.add_argument('--prompt', type=bool, default=True) 234 | parser.add_argument('--task_classes', type=int, default=6) 235 | parser.add_argument('--head_num', type=int, default=4) 236 | parser.add_argument('--shared_head', type=bool, default=False) 237 | 238 | '''Dataset settings''' 239 | parser.add_argument('--train_path', type=str, default=['/Dataset/pano-rectangling/train/', '/Dataset/wide-angle_rectangling/train/', 240 | '/Dataset/RS_Rec/RS_Rec/train/', '/Dataset/Rotation/train/', 241 | '/Dataset/fisheye/train/', '/Dataset/FaceRec/train/']) 242 | 243 | args = parser.parse_args() 244 | print(args) 245 | 246 | gpu_str = args.gpu_ids 247 | gpu_ids = [int(id) for id in args.gpu_ids.split(',')] 248 | num_gpus = len(gpu_ids) 249 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str 250 | print('export CUDA_VISIBLE_DEVICES={}'.format(gpu_str)) 251 | opt=0 252 | print('<==================== start training ===================>\n') 253 | mp.spawn(train, nprocs=num_gpus, args=(num_gpus, args)) 254 | 255 | print("################## end training #######################") -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | import math 8 | from utils.utils_module import * 9 | from utils.utils_transform import * 10 | 11 | class FastLeFF(nn.Module): 12 | 13 | def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop = 0.): 14 | super().__init__() 15 | 16 | from torch_dwconv import DepthwiseConv2d 17 | 18 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), 19 | act_layer()) 20 | self.dwconv = nn.Sequential(DepthwiseConv2d(hidden_dim, hidden_dim, kernel_size=3,stride=1,padding=1), 21 | act_layer()) 22 | self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim)) 23 | self.dim = dim 24 | self.hidden_dim = hidden_dim 25 | 26 | def forward(self, x): 27 | # bs x hw x c 28 | bs, hw, c = x.size() 29 | hh = int(math.sqrt(hw)) 30 | 31 | x = self.linear1(x) 32 | 33 | # spatial restore 34 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h = hh, w = hh) 35 | # bs,hidden_dim,32x32 36 | 37 | x = self.dwconv(x) 38 | 39 | # flaten 40 | x = rearrange(x, ' b c h w -> b (h w) c', h = hh, w = hh) 41 | 42 | x = self.linear2(x) 43 | 44 | return x 45 | 46 | class SELayer(nn.Module): 47 | def __init__(self, channel, reduction=16): 48 | super(SELayer, self).__init__() 49 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 50 | self.fc = nn.Sequential( 51 | nn.Linear(channel, channel // reduction, bias=False), 52 | nn.ReLU(inplace=True), 53 | nn.Linear(channel // reduction, channel, bias=False), 54 | nn.Sigmoid() 55 | ) 56 | 57 | def forward(self, x): # x: [B, N, C] 
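        # Squeeze-and-excitation over tokens: transpose to [B, C, N], average-pool the
        # token dimension into a per-channel descriptor, pass it through the FC
        # bottleneck with a sigmoid to get channel gates, rescale x by those gates,
        # and transpose back to [B, N, C].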
58 | x = torch.transpose(x, 1, 2) # [B, C, N] 59 | b, c, _ = x.size() 60 | y = self.avg_pool(x).view(b, c) 61 | y = self.fc(y).view(b, c, 1) 62 | x = x * y.expand_as(x) 63 | x = torch.transpose(x, 1, 2) # [B, N, C] 64 | return x 65 | 66 | class SepConv2d(torch.nn.Module): 67 | def __init__(self, 68 | in_channels, 69 | out_channels, 70 | kernel_size, 71 | stride=1, 72 | padding=0, 73 | dilation=1, act_layer=nn.ReLU): 74 | super(SepConv2d, self).__init__() 75 | self.depthwise = torch.nn.Conv2d(in_channels, 76 | in_channels, 77 | kernel_size=kernel_size, 78 | stride=stride, 79 | padding=padding, 80 | dilation=dilation, 81 | groups=in_channels) 82 | self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1) 83 | self.act_layer = act_layer() if act_layer is not None else nn.Identity() 84 | 85 | def forward(self, x): 86 | x = self.depthwise(x) 87 | x = self.act_layer(x) 88 | x = self.pointwise(x) 89 | return x 90 | 91 | class ConvProjection(nn.Module): 92 | def __init__(self, dim, heads=8, dim_head=64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1, dropout=0., 93 | last_stage=False, bias=True): 94 | super().__init__() 95 | 96 | inner_dim = dim_head * heads 97 | self.heads = heads 98 | pad = (kernel_size - q_stride) // 2 99 | self.to_q = SepConv2d(dim, inner_dim, kernel_size, q_stride, pad, bias) 100 | self.to_k = SepConv2d(dim, inner_dim, kernel_size, k_stride, pad, bias) 101 | self.to_v = SepConv2d(dim, inner_dim, kernel_size, v_stride, pad, bias) 102 | 103 | def forward(self, x, attn_kv=None): 104 | b, n, c, h = *x.shape, self.heads 105 | l = int(math.sqrt(n)) 106 | w = int(math.sqrt(n)) 107 | 108 | attn_kv = x if attn_kv is None else attn_kv 109 | x = rearrange(x, 'b (l w) c -> b c l w', l=l, w=w) 110 | attn_kv = rearrange(attn_kv, 'b (l w) c -> b c l w', l=l, w=w) 111 | # print(attn_kv) 112 | q = self.to_q(x) 113 | q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h) 114 | 115 | k = self.to_k(attn_kv) 116 | v = self.to_v(attn_kv) 117 | k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h) 118 | v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h) 119 | return q, k, v 120 | 121 | class LinearProjection(nn.Module): 122 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True): 123 | super().__init__() 124 | inner_dim = dim_head * heads 125 | self.heads = heads 126 | self.to_q = nn.Linear(dim, inner_dim, bias=bias) 127 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias) 128 | 129 | def forward(self, x, attn_kv=None): 130 | B_, N, C = x.shape 131 | attn_kv = x if attn_kv is None else attn_kv 132 | q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 133 | kv = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 134 | q = q[0] 135 | k, v = kv[0], kv[1] 136 | return q, k, v 137 | 138 | class LinearProjection_Concat_kv(nn.Module): 139 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True): 140 | super().__init__() 141 | inner_dim = dim_head * heads 142 | self.heads = heads 143 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=bias) 144 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias) 145 | 146 | def forward(self, x, attn_kv=None): 147 | B_, N, C = x.shape 148 | attn_kv = x if attn_kv is None else attn_kv 149 | qkv_dec = self.to_qkv(x).reshape(B_, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 150 | kv_enc = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 151 | q, k_d, v_d = qkv_dec[0], qkv_dec[1], qkv_dec[2] # make 
torchscript happy (cannot use tensor as tuple) 152 | k_e, v_e = kv_enc[0], kv_enc[1] 153 | k = torch.cat((k_d, k_e), dim=2) 154 | v = torch.cat((v_d, v_e), dim=2) 155 | return q, k, v 156 | 157 | class WindowAttention(nn.Module): 158 | def __init__(self, dim, win_size, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., 159 | proj_drop=0., se_layer=False): 160 | 161 | super().__init__() 162 | self.dim = dim 163 | self.win_size = win_size # Wh, Ww 164 | self.num_heads = num_heads 165 | head_dim = dim // num_heads 166 | self.scale = qk_scale or head_dim ** -0.5 167 | 168 | # define a parameter table of relative position bias 169 | self.relative_position_bias_table = nn.Parameter( 170 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 171 | 172 | # get pair-wise relative position index for each token inside the window 173 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] 174 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] 175 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 176 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 177 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 178 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 179 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 180 | relative_coords[:, :, 1] += self.win_size[1] - 1 181 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 182 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 183 | self.register_buffer("relative_position_index", relative_position_index) 184 | 185 | if token_projection == 'conv': 186 | self.qkv = ConvProjection(dim, num_heads, dim // num_heads, bias=qkv_bias) 187 | elif token_projection == 'linear_concat': 188 | self.qkv = LinearProjection_Concat_kv(dim, num_heads, dim // num_heads, bias=qkv_bias) 189 | else: 190 | self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias) 191 | 192 | self.attn_drop = nn.Dropout(attn_drop) 193 | self.proj = nn.Linear(dim, dim) 194 | self.se_layer = SELayer(dim) if se_layer else nn.Identity() 195 | self.proj_drop = nn.Dropout(proj_drop) 196 | 197 | trunc_normal_(self.relative_position_bias_table, std=.02) 198 | self.softmax = nn.Softmax(dim=-1) 199 | 200 | def forward(self, x, attn_kv=None, mask=None): 201 | B_, N, C = x.shape 202 | q, k, v = self.qkv(x, attn_kv) 203 | q = q * self.scale 204 | attn = (q @ k.transpose(-2, -1)) 205 | 206 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 207 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH 208 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 209 | ratio = attn.size(-1) // relative_position_bias.size(-1) 210 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio) 211 | 212 | attn = attn + relative_position_bias.unsqueeze(0) 213 | 214 | if mask is not None: 215 | nW = mask.shape[0] 216 | mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio) 217 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0) 218 | attn = attn.view(-1, self.num_heads, N, N * ratio) 219 | attn = self.softmax(attn) 220 | else: 221 | attn = self.softmax(attn) 222 | 223 | attn = self.attn_drop(attn) 224 | 225 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 226 | x = 
self.proj(x) 227 | x = self.se_layer(x) 228 | x = self.proj_drop(x) 229 | return x 230 | 231 | def extra_repr(self) -> str: 232 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}' 233 | 234 | class Attention(nn.Module): 235 | def __init__(self, dim,num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 236 | 237 | super().__init__() 238 | self.dim = dim 239 | self.num_heads = num_heads 240 | head_dim = dim // num_heads 241 | self.scale = qk_scale or head_dim ** -0.5 242 | 243 | self.qkv = LinearProjection(dim,num_heads,dim//num_heads,bias=qkv_bias) 244 | 245 | self.token_projection = token_projection 246 | self.attn_drop = nn.Dropout(attn_drop) 247 | self.proj = nn.Linear(dim, dim) 248 | self.proj_drop = nn.Dropout(proj_drop) 249 | 250 | self.softmax = nn.Softmax(dim=-1) 251 | 252 | def forward(self, x, attn_kv=None, mask=None): 253 | B_, N, C = x.shape 254 | q, k, v = self.qkv(x,attn_kv) 255 | q = q * self.scale 256 | attn = (q @ k.transpose(-2, -1)) 257 | 258 | if mask is not None: 259 | nW = mask.shape[0] 260 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 261 | attn = attn.view(-1, self.num_heads, N, N) 262 | attn = self.softmax(attn) 263 | else: 264 | attn = self.softmax(attn) 265 | 266 | attn = self.attn_drop(attn) 267 | 268 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 269 | x = self.proj(x) 270 | x = self.proj_drop(x) 271 | return x 272 | 273 | def extra_repr(self) -> str: 274 | return f'dim={self.dim}, num_heads={self.num_heads}' 275 | 276 | class Mlp(nn.Module): 277 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 278 | super().__init__() 279 | out_features = out_features or in_features 280 | hidden_features = hidden_features or in_features 281 | self.fc1 = nn.Linear(in_features, hidden_features) 282 | self.act = act_layer() 283 | self.fc2 = nn.Linear(hidden_features, out_features) 284 | self.drop = nn.Dropout(drop) 285 | 286 | def forward(self, x): 287 | x = self.fc1(x) 288 | x = self.act(x) 289 | x = self.drop(x) 290 | x = self.fc2(x) 291 | x = self.drop(x) 292 | return x 293 | 294 | 295 | class LeFF(nn.Module): 296 | def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0.): 297 | super().__init__() 298 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), 299 | act_layer()) 300 | self.dwconv = nn.Sequential( 301 | nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1), 302 | act_layer()) 303 | self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim)) 304 | self.drop = nn.Dropout(drop) 305 | 306 | def forward(self, x): 307 | # bs x hw x c 308 | bs, hw, c = x.size() 309 | hh = int(math.sqrt(hw)) 310 | 311 | x = self.linear1(x) 312 | 313 | # spatial restore 314 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h=hh, w=hh) 315 | # bs,hidden_dim,32x32 316 | 317 | x = self.dwconv(x) 318 | 319 | # flaten 320 | x = rearrange(x, ' b c h w -> b (h w) c', h=hh, w=hh) 321 | 322 | x = self.linear2(x) 323 | x = self.drop(x) 324 | 325 | return x 326 | 327 | def window_partition(x, win_size): 328 | B, H, W, C = x.shape 329 | x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) 330 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) 331 | return windows 332 | 333 | 334 | def window_reverse(windows, win_size, H, W): 335 | B = int(windows.shape[0] / (H * W / win_size / win_size)) 336 | x = windows.view(B, H // win_size, W // 
win_size, win_size, win_size, -1) 337 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 338 | return x 339 | 340 | class Downsample(nn.Module): 341 | def __init__(self, in_channel, out_channel): 342 | super(Downsample, self).__init__() 343 | self.conv = nn.Sequential( 344 | nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1), 345 | 346 | ) 347 | 348 | def forward(self, x): 349 | B, L, C = x.shape 350 | H = int(math.sqrt(L)) 351 | W = int(math.sqrt(L)) 352 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 353 | out = self.conv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C 354 | return out 355 | 356 | class Upsample(nn.Module): 357 | def __init__(self, in_channel, out_channel): 358 | super(Upsample, self).__init__() 359 | self.deconv = nn.Sequential( 360 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 361 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1), 362 | ) 363 | 364 | def forward(self, x): 365 | B, L, C = x.shape 366 | H = int(math.sqrt(L)) 367 | W = int(math.sqrt(L)) 368 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 369 | out = self.deconv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C 370 | return out 371 | 372 | class InputProj(nn.Module): 373 | def __init__(self, in_channel=3, out_channel=64, kernel_size=3, stride=1, norm_layer=None, act_layer=nn.LeakyReLU): 374 | super().__init__() 375 | self.proj = nn.Sequential( 376 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2), 377 | act_layer(inplace=True) 378 | ) 379 | if norm_layer is not None: 380 | self.norm = norm_layer(out_channel) 381 | else: 382 | self.norm = None 383 | 384 | def forward(self, x): 385 | B, C, H, W = x.shape 386 | x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C 387 | if self.norm is not None: 388 | x = self.norm(x) 389 | return x 390 | 391 | class OutputProj(nn.Module): 392 | def __init__(self, in_channel=64, out_channel=3, kernel_size=3, stride=1, task_classes=4, prompt=False): 393 | super().__init__() 394 | self.proj = nn.Sequential( 395 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2), 396 | ) 397 | self.prompt = prompt 398 | if(self.prompt): 399 | self.prompt_task = nn.Parameter(torch.randn(1, task_classes, in_channel, 1, 1)) 400 | self.proj_prompt = nn.Sequential( 401 | nn.Conv2d(2*in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2), 402 | ) 403 | 404 | def forward(self, x, prompt=None): 405 | B, L, C = x.shape 406 | H = int(math.sqrt(L)) 407 | W = int(math.sqrt(L)) 408 | x = x.transpose(1, 2).view(B, C, H, W) 409 | if(self.prompt): 410 | detask_prompts = prompt.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_task.unsqueeze(0).repeat(B,1,1,1,1,1).squeeze(1) 411 | detask_prompts = torch.mean(detask_prompts, dim = 1) 412 | detask_prompts = F.interpolate(detask_prompts,(H,W),mode="bilinear") 413 | combo = torch.concat([x, detask_prompts], dim=1) 414 | combo = self.proj_prompt(combo) 415 | else: 416 | combo = self.proj(x) 417 | 418 | return combo 419 | 420 | 421 | class LeWinTransformerBlock(nn.Module): 422 | def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0, 423 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 424 | act_layer=nn.GELU, norm_layer=nn.LayerNorm,token_projection='linear',token_mlp='leff', 425 | prompt=False, task_classes=4): 426 | super().__init__() 427 | self.dim = dim 428 | self.input_resolution 
= input_resolution 429 | self.num_heads = num_heads 430 | self.win_size = win_size 431 | self.shift_size = shift_size 432 | self.mlp_ratio = mlp_ratio 433 | self.token_mlp = token_mlp 434 | if min(self.input_resolution) <= self.win_size: 435 | self.shift_size = 0 436 | self.win_size = min(self.input_resolution) 437 | assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-win_size" 438 | 439 | if prompt: 440 | self.prompt_task = nn.Parameter(torch.randn(1, task_classes, 1, dim)) 441 | else: 442 | self.prompt_task = None 443 | 444 | self.norm1 = norm_layer(dim) 445 | self.attn = WindowAttention( 446 | dim, win_size=to_2tuple(self.win_size), num_heads=num_heads, 447 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, 448 | token_projection=token_projection) 449 | 450 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 451 | self.norm2 = norm_layer(dim) 452 | mlp_hidden_dim = int(dim * mlp_ratio) 453 | if token_mlp in ['ffn','mlp']: 454 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop) 455 | elif token_mlp=='leff': 456 | self.mlp = LeFF(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop) 457 | 458 | elif token_mlp=='fastleff': 459 | self.mlp = FastLeFF(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop) 460 | else: 461 | raise Exception("FFN error!") 462 | 463 | self.chnl_reduce = nn.Conv2d(2*dim,dim,kernel_size=1,bias=False) 464 | 465 | 466 | def with_pos_embed(self, tensor, pos): 467 | return tensor if pos is None else tensor + pos 468 | 469 | def extra_repr(self) -> str: 470 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 471 | f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio},modulator={self.modulator}" 472 | 473 | def forward(self, x, point_cls=None, mask=None): 474 | B, L, C = x.shape 475 | H = int(math.sqrt(L)) 476 | W = int(math.sqrt(L)) 477 | 478 | ## input mask 479 | if mask != None: 480 | input_mask = F.interpolate(mask, size=(H,W)).permute(0,2,3,1) 481 | input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1 482 | attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size 483 | attn_mask = attn_mask.unsqueeze(2)*attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size 484 | attn_mask = attn_mask.masked_fill(attn_mask!=0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 485 | else: 486 | attn_mask = None 487 | 488 | ## shift mask 489 | if self.shift_size > 0: 490 | # calculate attention mask for SW-MSA 491 | shift_mask = torch.zeros((1, H, W, 1)).type_as(x) 492 | h_slices = (slice(0, -self.win_size), 493 | slice(-self.win_size, -self.shift_size), 494 | slice(-self.shift_size, None)) 495 | w_slices = (slice(0, -self.win_size), 496 | slice(-self.win_size, -self.shift_size), 497 | slice(-self.shift_size, None)) 498 | cnt = 0 499 | for h in h_slices: 500 | for w in w_slices: 501 | shift_mask[:, h, w, :] = cnt 502 | cnt += 1 503 | shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1 504 | shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size 505 | shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(2) # nW, win_size*win_size, win_size*win_size 506 | shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(shift_attn_mask == 0, float(0.0)) 507 | attn_mask = 
attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask 508 | 509 | shortcut = x 510 | x = self.norm1(x) 511 | x = x.view(B, H, W, C) 512 | 513 | # cyclic shift 514 | if self.shift_size > 0: 515 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 516 | else: 517 | shifted_x = x 518 | 519 | # partition windows 520 | x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C 521 | x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C 522 | 523 | if self.prompt_task is not None: 524 | B_win, L_win, C_win = x_windows.shape 525 | M = B_win // B 526 | # (B, prompt_len, 1, 1) * (B, prompt_len, 1, prompt_dim) 527 | detask_prompts = point_cls.unsqueeze(-1).unsqueeze(-1) * self.prompt_task.unsqueeze(0).repeat(B,1,1,1,1).squeeze(1) 528 | # (B, 1, prompt_dim) 529 | detask_prompts = torch.mean(detask_prompts, dim=1) 530 | # (B, N, prompt_dim) 531 | detask_prompts = detask_prompts.unsqueeze(1).repeat(1, L_win, 1, 1).squeeze(2) 532 | # (B*M, N, prompt_dim) 533 | detask_prompts = detask_prompts.repeat(M, 1, 1) 534 | combo = torch.concat([x_windows, detask_prompts], dim=2) 535 | combo = combo.permute(0,2,1).view(-1, 2*C, self.win_size, self.win_size) 536 | combo = self.chnl_reduce(combo) 537 | wmsa_in = combo.view(-1, C, self.win_size * self.win_size).permute(0,2,1) 538 | 539 | else: 540 | wmsa_in = x_windows 541 | 542 | # W-MSA/SW-MSA 543 | attn_windows = self.attn(wmsa_in, mask=attn_mask) # nW*B, win_size*win_size, C 544 | 545 | # merge windows 546 | attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C) 547 | shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C 548 | 549 | # reverse cyclic shift 550 | if self.shift_size > 0: 551 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 552 | else: 553 | x = shifted_x 554 | x = x.view(B, H * W, C) 555 | 556 | # FFN 557 | x = shortcut + self.drop_path(x) 558 | x = x + self.drop_path(self.mlp(self.norm2(x))) 559 | del attn_mask 560 | return x 561 | 562 | class BasicUformerLayer(nn.Module): 563 | def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size, 564 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 565 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False, 566 | token_projection='linear',token_mlp='ffn', shift_flag=True, 567 | prompt=False, task_classes=4): 568 | 569 | super().__init__() 570 | self.dim = dim 571 | self.input_resolution = input_resolution 572 | self.depth = depth 573 | self.use_checkpoint = use_checkpoint 574 | if shift_flag: 575 | self.blocks = nn.ModuleList([ 576 | LeWinTransformerBlock(dim=dim, input_resolution=input_resolution, 577 | num_heads=num_heads, win_size=win_size, 578 | shift_size=0 if (i % 2 == 0) else win_size // 2, 579 | mlp_ratio=mlp_ratio, 580 | qkv_bias=qkv_bias, qk_scale=qk_scale, 581 | drop=drop, attn_drop=attn_drop, 582 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 583 | norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp, 584 | prompt=prompt, task_classes=task_classes) 585 | for i in range(depth)]) 586 | else: 587 | self.blocks = nn.ModuleList([ 588 | LeWinTransformerBlock(dim=dim, input_resolution=input_resolution, 589 | num_heads=num_heads, win_size=win_size, 590 | shift_size=0, 591 | mlp_ratio=mlp_ratio, 592 | qkv_bias=qkv_bias, qk_scale=qk_scale, 593 | drop=drop, attn_drop=attn_drop, 594 | drop_path=drop_path[i] if 
isinstance(drop_path, list) else drop_path, 595 | norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp, 596 | prompt=prompt, task_classes=task_classes) 597 | for i in range(depth)]) 598 | 599 | def extra_repr(self) -> str: 600 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 601 | 602 | def forward(self, x, point_cls=None, mask=None): 603 | for blk in self.blocks: 604 | if self.use_checkpoint: 605 | x = checkpoint.checkpoint(blk, x, point_cls) 606 | else: 607 | x = blk(x, point_cls, mask) 608 | return x 609 | 610 | class MOWA(nn.Module): 611 | def __init__(self, img_size=256, in_chans=4, 612 | embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2], 613 | win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None, 614 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, drop_motion_rate=0., attn_drop_motion_rate=0., 615 | norm_layer=nn.LayerNorm, patch_norm=True, 616 | use_checkpoint=False, token_projection='linear', token_mlp='ffn', se_layer=False, 617 | dowsample=Downsample, upsample=Upsample, motion_net=MotionNet_Coord, 618 | shift_flag=True, prompt=False, task_classes=4, tps_points=[8, 10, 12, 14], 619 | correct_encoder=False, head_num=8, shared_head=False, **kwargs): 620 | super().__init__() 621 | 622 | self.num_enc_layers = len(depths) // 2 623 | self.num_dec_layers = len(depths) // 2 624 | self.embed_dim = embed_dim 625 | self.patch_norm = patch_norm 626 | self.mlp_ratio = mlp_ratio 627 | self.token_projection = token_projection 628 | self.mlp = token_mlp 629 | self.win_size = win_size 630 | self.mini_size = img_size // (2 ** 4) 631 | self.down_size = 2 ** 4 632 | self.tps_points = tps_points 633 | self.correct_encoder = correct_encoder 634 | self.prompt = prompt 635 | self.head_num = head_num 636 | self.shared_head = shared_head 637 | 638 | self.pos_drop = nn.Dropout(p=drop_rate) 639 | enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))] 640 | conv_dpr = [drop_path_rate] * depths[4] 641 | dec_dpr = enc_dpr[::-1] 642 | 643 | '''Input/Output Projection''' 644 | self.input_proj = InputProj(in_channel=in_chans, out_channel=embed_dim, kernel_size=3, stride=1, 645 | act_layer=nn.LeakyReLU) 646 | self.output_proj = OutputProj(in_channel=2 * embed_dim, out_channel=2, kernel_size=3, stride=1, task_classes=task_classes, prompt=False) 647 | 648 | '''Encoder''' 649 | self.encoderlayer_0 = BasicUformerLayer(dim=embed_dim, 650 | output_dim=embed_dim, 651 | input_resolution=(img_size, 652 | img_size), 653 | depth=depths[0], 654 | num_heads=num_heads[0], 655 | win_size=win_size, 656 | mlp_ratio=self.mlp_ratio, 657 | qkv_bias=qkv_bias, qk_scale=qk_scale, 658 | drop=drop_rate, attn_drop=attn_drop_rate, 659 | drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])], 660 | norm_layer=norm_layer, 661 | use_checkpoint=use_checkpoint, 662 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) 663 | self.dowsample_0 = dowsample(embed_dim, embed_dim*2) 664 | 665 | self.encoderlayer_1 = BasicUformerLayer(dim=embed_dim*2, 666 | output_dim=embed_dim*2, 667 | input_resolution=(img_size // 2, 668 | img_size // 2), 669 | depth=depths[1], 670 | num_heads=num_heads[1], 671 | win_size=win_size, 672 | mlp_ratio=self.mlp_ratio, 673 | qkv_bias=qkv_bias, qk_scale=qk_scale, 674 | drop=drop_rate, attn_drop=attn_drop_rate, 675 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])], 676 | norm_layer=norm_layer, 677 | use_checkpoint=use_checkpoint, 678 | 
token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) 679 | self.dowsample_1 = dowsample(embed_dim*2, embed_dim*4) 680 | 681 | self.encoderlayer_2 = BasicUformerLayer(dim=embed_dim*4, 682 | output_dim=embed_dim*4, 683 | input_resolution=(img_size // (2 ** 2), 684 | img_size // (2 ** 2)), 685 | depth=depths[2], 686 | num_heads=num_heads[2], 687 | win_size=win_size, 688 | mlp_ratio=self.mlp_ratio, 689 | qkv_bias=qkv_bias, qk_scale=qk_scale, 690 | drop=drop_rate, attn_drop=attn_drop_rate, 691 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])], 692 | norm_layer=norm_layer, 693 | use_checkpoint=use_checkpoint, 694 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) 695 | self.dowsample_2 = dowsample(embed_dim*4, embed_dim*8) 696 | 697 | self.encoderlayer_3 = BasicUformerLayer(dim=embed_dim*8, 698 | output_dim=embed_dim*8, 699 | input_resolution=(img_size // (2 ** 3), 700 | img_size // (2 ** 3)), 701 | depth=depths[3], 702 | num_heads=num_heads[3], 703 | win_size=win_size, 704 | mlp_ratio=self.mlp_ratio, 705 | qkv_bias=qkv_bias, qk_scale=qk_scale, 706 | drop=drop_rate, attn_drop=attn_drop_rate, 707 | drop_path=enc_dpr[sum(depths[:3]):sum(depths[:4])], 708 | norm_layer=norm_layer, 709 | use_checkpoint=use_checkpoint, 710 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) 711 | self.dowsample_3 = dowsample(embed_dim*8, embed_dim*16) 712 | 713 | 714 | '''Bottleneck''' 715 | self.conv = BasicUformerLayer(dim=embed_dim*16, 716 | output_dim=embed_dim*16, 717 | input_resolution=(img_size // (2 ** 4), 718 | img_size // (2 ** 4)), 719 | depth=depths[4], 720 | num_heads=num_heads[4], 721 | win_size=win_size, 722 | mlp_ratio=self.mlp_ratio, 723 | qkv_bias=qkv_bias, qk_scale=qk_scale, 724 | drop=drop_rate, attn_drop=attn_drop_rate, 725 | drop_path=conv_dpr, 726 | norm_layer=norm_layer, 727 | use_checkpoint=use_checkpoint, 728 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) 729 | 730 | '''TPS regression heads''' 731 | if(self.shared_head): 732 | reg = BasicUformerLayer(dim=embed_dim*16, 733 | output_dim=embed_dim*16, 734 | input_resolution=(img_size // (2 ** 4), 735 | img_size // (2 ** 4)), 736 | depth=depths[4], 737 | num_heads=num_heads[4], 738 | win_size=win_size, 739 | mlp_ratio=self.mlp_ratio, 740 | qkv_bias=qkv_bias, qk_scale=qk_scale, 741 | drop=drop_motion_rate, attn_drop=attn_drop_motion_rate, 742 | drop_path=conv_dpr, 743 | norm_layer=norm_layer, 744 | use_checkpoint=use_checkpoint, 745 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) 746 | self.reg = nn.ModuleList([reg for _ in range(self.head_num)]) 747 | else: 748 | self.reg = nn.ModuleList([BasicUformerLayer(dim=embed_dim*16, 749 | output_dim=embed_dim*16, 750 | input_resolution=(img_size // (2 ** 4), 751 | img_size // (2 ** 4)), 752 | depth=depths[4], 753 | num_heads=num_heads[4], 754 | win_size=win_size, 755 | mlp_ratio=self.mlp_ratio, 756 | qkv_bias=qkv_bias, qk_scale=qk_scale, 757 | drop=drop_motion_rate, attn_drop=attn_drop_motion_rate, 758 | drop_path=conv_dpr, 759 | norm_layer=norm_layer, 760 | use_checkpoint=use_checkpoint, 761 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) for _ in range(self.head_num)]) 762 | 763 | 764 | self.motion = nn.ModuleList([motion_net(embed_dim*16, 2, tps_points[i]) for i in range(self.head_num)]) 765 | self.pointnet = PointNet(task_classes, tps_points[-1], tps_points[-1]) 766 | 767 | '''Decoder to Flow''' 768 | self.upsample_0 = 
upsample(embed_dim*16, embed_dim*8) 769 | self.decoderlayer_0 = BasicUformerLayer(dim=embed_dim*16, 770 | output_dim=embed_dim*16, 771 | input_resolution=(img_size // (2 ** 3), 772 | img_size // (2 ** 3)), 773 | depth=depths[5], 774 | num_heads=num_heads[5], 775 | win_size=win_size, 776 | mlp_ratio=self.mlp_ratio, 777 | qkv_bias=qkv_bias, qk_scale=qk_scale, 778 | drop=drop_rate, attn_drop=attn_drop_rate, 779 | drop_path=dec_dpr[:depths[5]], 780 | norm_layer=norm_layer, 781 | use_checkpoint=use_checkpoint, 782 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag, 783 | prompt=prompt, task_classes=task_classes) 784 | 785 | self.upsample_1 = upsample(embed_dim*16, embed_dim*4) 786 | self.decoderlayer_1 = BasicUformerLayer(dim=embed_dim*8, 787 | output_dim=embed_dim*8, 788 | input_resolution=(img_size // (2 ** 2), 789 | img_size // (2 ** 2)), 790 | depth=depths[6], 791 | num_heads=num_heads[6], 792 | win_size=win_size, 793 | mlp_ratio=self.mlp_ratio, 794 | qkv_bias=qkv_bias, qk_scale=qk_scale, 795 | drop=drop_rate, attn_drop=attn_drop_rate, 796 | drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])], 797 | norm_layer=norm_layer, 798 | use_checkpoint=use_checkpoint, 799 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag, 800 | prompt=prompt, task_classes=task_classes) 801 | 802 | self.upsample_2 = upsample(embed_dim*8, embed_dim*2) 803 | self.decoderlayer_2 = BasicUformerLayer(dim=embed_dim*4, 804 | output_dim=embed_dim*4, 805 | input_resolution=(img_size // 2, 806 | img_size // 2), 807 | depth=depths[7], 808 | num_heads=num_heads[7], 809 | win_size=win_size, 810 | mlp_ratio=self.mlp_ratio, 811 | qkv_bias=qkv_bias, qk_scale=qk_scale, 812 | drop=drop_rate, attn_drop=attn_drop_rate, 813 | drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])], 814 | norm_layer=norm_layer, 815 | use_checkpoint=use_checkpoint, 816 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag, 817 | prompt=prompt, task_classes=task_classes) 818 | 819 | self.upsample_3 = upsample(embed_dim*4, embed_dim) 820 | self.decoderlayer_3 = BasicUformerLayer(dim=embed_dim*2, 821 | output_dim=embed_dim*2, 822 | input_resolution=(img_size, 823 | img_size), 824 | depth=depths[8], 825 | num_heads=num_heads[8], 826 | win_size=win_size, 827 | mlp_ratio=self.mlp_ratio, 828 | qkv_bias=qkv_bias, qk_scale=qk_scale, 829 | drop=drop_rate, attn_drop=attn_drop_rate, 830 | drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])], 831 | norm_layer=norm_layer, 832 | use_checkpoint=use_checkpoint, 833 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag, 834 | prompt=prompt, task_classes=task_classes) 835 | 836 | self.apply(self._init_weights) 837 | 838 | def _init_weights(self, m): 839 | if isinstance(m, nn.Linear): 840 | trunc_normal_(m.weight, std=.02) 841 | if isinstance(m, nn.Linear) and m.bias is not None: 842 | nn.init.constant_(m.bias, 0) 843 | elif isinstance(m, nn.LayerNorm): 844 | nn.init.constant_(m.bias, 0) 845 | nn.init.constant_(m.weight, 1.0) 846 | 847 | @torch.jit.ignore 848 | def no_weight_decay(self): 849 | return {'absolute_pos_embed'} 850 | 851 | @torch.jit.ignore 852 | def no_weight_decay_keywords(self): 853 | return {'relative_position_bias_table'} 854 | 855 | def extra_repr(self) -> str: 856 | return f"embed_dim={self.embed_dim}, token_projection={self.token_projection}, token_mlp={self.mlp},win_size={self.win_size}" 857 | 858 | def forward(self, x, mask, feed_prompt=None): 859 | '''Input Projection''' 860 | x = torch.cat([x, mask], 1) 
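        # Forward pipeline, as implemented below: the concatenated image+mask tensor is
        # projected and passed through four encoder stages plus a bottleneck; a cascade of
        # regression heads then predicts TPS control-point motions coarse-to-fine, each head
        # operating on the bottleneck feature warped by the accumulated prediction, with
        # upsample_tps lifting the control points to the next head's grid; PointNet classifies
        # the task from the final control points, and its softmax serves as the prompt that
        # conditions the decoder, which regresses a residual flow on top of the TPS warp.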
861 | y = self.input_proj(x) 862 | y = self.pos_drop(y) 863 | 864 | '''Encoder''' 865 | conv0 = self.encoderlayer_0(y) 866 | pool0 = self.dowsample_0(conv0) 867 | conv1 = self.encoderlayer_1(pool0) 868 | pool1 = self.dowsample_1(conv1) 869 | conv2 = self.encoderlayer_2(pool1) 870 | pool2 = self.dowsample_2(conv2) 871 | conv3 = self.encoderlayer_3(pool2) 872 | pool3 = self.dowsample_3(conv3) 873 | conv4 = self.conv(pool3) 874 | 875 | '''TPS regression heads''' 876 | tps = [] 877 | tps_up = [] 878 | fea = [conv4] 879 | for i in range(self.head_num): 880 | conv_fea = fea[-1] 881 | pre = self.reg[i](conv_fea) 882 | pre = self.motion[i](pre) 883 | 884 | warp = transform_tps_fea(pre/self.down_size, conv_fea, self.tps_points[i]-1, self.tps_points[i]-1, 885 | self.embed_dim*16, self.mini_size, self.mini_size) 886 | fea.append(warp) 887 | if i==0: 888 | tps.append(pre) 889 | else: 890 | tps.append(pre + tps_up[-1]) 891 | if(i < self.head_num - 1): 892 | tps_up.append(upsample_tps(tps[i], self.tps_points[i]-1, self.tps_points[i]-1, self.tps_points[i+1], self.tps_points[i+1])) 893 | 894 | '''Point classification''' 895 | point_cls = self.pointnet(tps[-1], conv4) 896 | prompt = F.softmax(point_cls, dim=1) 897 | 898 | '''Decoder to residual flow''' 899 | up0 = self.upsample_0(fea[-1]) 900 | if(self.correct_encoder): 901 | conv3 = transform_tps_fea(tps[-1]/(self.down_size//2), conv3, self.tps_points[-1]-1, self.tps_points[-1]-1, 902 | self.embed_dim*8, self.mini_size*2, self.mini_size*2) 903 | deconv0 = torch.cat([up0, conv3], -1) 904 | deconv0 = self.decoderlayer_0(deconv0, prompt) 905 | 906 | up1 = self.upsample_1(deconv0) 907 | if(self.correct_encoder): 908 | conv2 = transform_tps_fea(tps[-1]/(self.down_size//4), conv2, self.tps_points[-1]-1, self.tps_points[-1]-1, 909 | self.embed_dim*4, self.mini_size*4, self.mini_size*4) 910 | deconv1 = torch.cat([up1, conv2], -1) 911 | deconv1 = self.decoderlayer_1(deconv1, prompt) 912 | 913 | up2 = self.upsample_2(deconv1) 914 | if(self.correct_encoder): 915 | conv1 = transform_tps_fea(tps[-1]/(self.down_size//8), conv1, self.tps_points[-1]-1, self.tps_points[-1]-1, 916 | self.embed_dim*2, self.mini_size*8, self.mini_size*8) 917 | deconv2 = torch.cat([up2, conv1], -1) 918 | deconv2 = self.decoderlayer_2(deconv2, prompt) 919 | 920 | up3 = self.upsample_3(deconv2) 921 | if(self.correct_encoder): 922 | conv0 = transform_tps_fea(tps[-1], conv0, self.tps_points[-1]-1, self.tps_points[-1]-1, 923 | self.embed_dim*1, self.mini_size*16, self.mini_size*16) 924 | deconv3 = torch.cat([up3, conv0], -1) 925 | deconv3 = self.decoderlayer_3(deconv3, prompt) 926 | 927 | '''Output flow''' 928 | flow = self.output_proj(deconv3) 929 | 930 | return tps, flow, point_cls --------------------------------------------------------------------------------
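For orientation, below is a minimal smoke-test sketch for the MOWA module defined above. It is not part of the repository: it assumes the rest of model/network.py (InputProj, OutputProj, Downsample, Upsample, MotionNet_Coord, PointNet, transform_tps_fea, upsample_tps and the window helpers) is importable as-is from the repo root, and it passes head_num=4 so that every regression head has a matching entry in the default tps_points=[8, 10, 12, 14]; the tensor shapes noted in the comments are indicative rather than guaranteed.

import torch
from model.network import MOWA

# Instantiate with defaults except head_num=4, chosen here to match len(tps_points).
model = MOWA(img_size=256, in_chans=4, embed_dim=32,
             tps_points=[8, 10, 12, 14], task_classes=4,
             head_num=4, prompt=True)
model.eval()

# forward(x, mask) concatenates a 3-channel image with a 1-channel mask to in_chans=4.
img = torch.randn(1, 3, 256, 256)
mask = torch.ones(1, 1, 256, 256)

with torch.no_grad():
    tps, flow, point_cls = model(img, mask)

print(len(tps))          # one TPS control-point prediction per regression head (4 here)
print(flow.shape)        # residual flow from OutputProj (2 channels, likely full resolution)
print(point_cls.shape)   # task classification logits from PointNet (batch, task_classes)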