├── model
│   ├── __init__.py
│   ├── loss.py
│   ├── builder.py
│   └── network.py
├── utils
│   ├── __init__.py
│   ├── torch_tps_upsample.py
│   ├── utils_op.py
│   ├── utils_transform.py
│   ├── flow_viz.py
│   ├── utils_module.py
│   └── torch_tps_transform.py
├── assets
│   └── teaser.jpg
├── scripts
│   ├── train.sh
│   ├── test.sh
│   └── test_portrait.sh
├── warmup_scheduler
│   ├── __init__.py
│   ├── run.py
│   └── scheduler.py
├── requirements.txt
├── LICENSE
├── Datasets
│   └── readme.md
├── README.md
├── dataset_loaders.py
├── test.py
├── test_portrait.py
└── train.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KangLiao929/MOWA/HEAD/assets/teaser.jpg
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | python train.py -gpu '0, 1, 2, 3, 4, 5, 6, 7' -b 8 -m 'mowa' [--train_path=...]
--------------------------------------------------------------------------------
/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from warmup_scheduler.scheduler import GradualWarmupScheduler
3 |
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | python test.py --gpu 0 --batch_size 1 --model_path '/checkpoint/' --method 'mowa' [--test_path=...]
--------------------------------------------------------------------------------
/scripts/test_portrait.sh:
--------------------------------------------------------------------------------
1 | python test_portrait.py --gpu 0 --batch_size 1 --model_path '/checkpoint/' --method 'mowa' [--test_path=...]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py==3.11.0
2 | imageio==2.33.1
3 | imgaug==0.4.0
4 | lpips==0.1.4
5 | Markdown==3.5.1
6 | matplotlib==3.7.4
7 | networkx==3.1
8 | numpy==1.24.4
9 | opencv-python==4.9.0.80
10 | opencv-python-headless==4.9.0.80
11 | packaging==23.2
12 | pandas==2.0.3
13 | pillow==10.2.0
14 | pytorch-fid==0.3.0
15 | pytorch-lightning==2.2.1
16 | PyYAML==6.0.1
17 | scikit-image==0.21.0
18 | scikit-learn==1.3.2
19 | scipy==1.10.1
20 | seaborn==0.13.1
21 | six==1.16.0
22 | tensorboard==2.14.0
23 | tensorboardX==2.6.2.2
24 | timm==0.9.12
25 | torch==2.1.2
26 | torchmetrics==1.3.2
27 | torchvision==0.16.2
28 | tqdm==4.66.1
29 |
--------------------------------------------------------------------------------
/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR
3 | from torch.optim.sgd import SGD
4 |
5 | from warmup_scheduler import GradualWarmupScheduler
6 |
7 |
8 | if __name__ == '__main__':
9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
10 | optim = SGD(model, 0.1)
11 |
12 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
13 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
14 |
15 | optim.zero_grad()
16 | optim.step()
17 |
18 | for epoch in range(1, 20):
19 | scheduler_warmup.step(epoch)
20 | print(epoch, optim.param_groups[0]['lr'])
21 |
22 | optim.step()
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | S-Lab License 1.0
2 |
3 | Copyright 2025 S-Lab
4 |
5 | Redistribution and use for non-commercial purpose in source and
6 | binary forms, with or without modification, are permitted provided
7 | that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright
10 | notice, this list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright
13 | notice, this list of conditions and the following disclaimer in
14 | the documentation and/or other materials provided with the
15 | distribution.
16 |
17 | 3. Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived
19 | from this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 |
33 | In the event that redistribution and/or use for commercial purpose in
34 | source or binary forms, with or without modification is required,
35 | please contact the contributor(s) of the work.
36 |
--------------------------------------------------------------------------------
/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 |
5 | class GradualWarmupScheduler(_LRScheduler):
6 |     """ Gradually warm up (increase) the learning rate in the optimizer.
7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 |
9 | Args:
10 | optimizer (Optimizer): Wrapped optimizer.
11 |         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. If multiplier = 1.0, lr starts from 0 and ends up at base_lr.
12 |         total_epoch: the target learning rate is reached gradually over total_epoch epochs
13 |         after_scheduler: the scheduler to use after total_epoch (e.g. ReduceLROnPlateau)
14 | """
15 |
16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
17 | self.multiplier = multiplier
18 | if self.multiplier < 1.:
19 |             raise ValueError('multiplier should be greater than or equal to 1.')
20 | self.total_epoch = total_epoch
21 | self.after_scheduler = after_scheduler
22 | self.finished = False
23 | super(GradualWarmupScheduler, self).__init__(optimizer)
24 |
25 | def get_lr(self):
26 | if self.last_epoch > self.total_epoch:
27 | if self.after_scheduler:
28 | if not self.finished:
29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 | self.finished = True
31 | return self.after_scheduler.get_lr()
32 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
33 |
34 | if self.multiplier == 1.0:
35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
36 | else:
37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
38 |
39 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
40 | if epoch is None:
41 | epoch = self.last_epoch + 1
42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
43 | if self.last_epoch <= self.total_epoch:
44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
46 | param_group['lr'] = lr
47 | else:
48 | if epoch is None:
49 | self.after_scheduler.step(metrics, None)
50 | else:
51 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
52 |
53 | def step(self, epoch=None, metrics=None):
54 | if type(self.after_scheduler) != ReduceLROnPlateau:
55 | if self.finished and self.after_scheduler:
56 | if epoch is None:
57 | self.after_scheduler.step(None)
58 | else:
59 | self.after_scheduler.step(epoch - self.total_epoch)
60 | else:
61 | return super(GradualWarmupScheduler, self).step(epoch)
62 | else:
63 | self.step_ReduceLROnPlateau(metrics, epoch)
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def get_vgg19_FeatureMap(vgg_model, input_tensor, layer_index):
5 | mean = torch.tensor([0.485, 0.456, 0.406]).reshape((1,3,1,1))
6 | std = torch.tensor([0.229, 0.224, 0.225]).reshape((1,3,1,1))
7 | if torch.cuda.is_available():
8 | mean = mean.cuda()
9 | std = std.cuda()
10 |
11 | vgg_input = (input_tensor - mean) / std
12 | for i in range(0, layer_index+1):
13 | if i == 0:
14 | x = vgg_model.module.features[0](vgg_input)
15 | else:
16 | x = vgg_model.module.features[i](x)
17 | return x
18 |
19 | def l_num_loss(img1, img2, l_num=1):
20 | return torch.mean(torch.abs((img1 - img2)**l_num))
21 |
22 | def mask_flow_loss(flow, gt, task_ids, target_id=5):
23 | batch_size = flow.size(0)
24 | target_ids = torch.full((batch_size, 1), target_id, dtype=task_ids.dtype, device=task_ids.device)
25 | mask = task_ids == target_ids.squeeze(-1)
26 | flow_mask = flow * mask.view(batch_size, 1, 1, 1)
27 | gt_mask = gt * mask.view(batch_size, 1, 1, 1)
28 | return l_num_loss(flow_mask, gt_mask, 1)
29 |
30 | def cal_appearance_loss_sum(warp_list, gt, weights, eps=1e-7):
31 | num = len(warp_list)
32 | loss = []
33 | for i in range(num):
34 | warp = warp_list[i]
35 | loss.append(l_num_loss(warp, gt, 1) + eps)
36 |
37 | return torch.sum(torch.stack([w * v for w, v in zip(weights, loss)]))
38 |
39 | def cal_perception_loss_sum(vgg_model, warp_list, gt, weights):
40 | num = len(warp_list)
41 | loss = []
42 | for i in range(num):
43 | warp = warp_list[i]
44 | warp_feature = get_vgg19_FeatureMap(vgg_model, warp, 24)
45 | gt_feature = get_vgg19_FeatureMap(vgg_model, gt, 24)
46 |
47 | loss.append(l_num_loss(warp_feature, gt_feature, 2))
48 |
49 | return torch.sum(torch.stack([w * v for w, v in zip(weights, loss)]))
50 |
51 | def cal_point_loss(pre, gt):
52 | criterion = nn.CrossEntropyLoss()
53 | return criterion(pre, gt.long().cuda())
54 |
55 | def cal_inter_grid_loss_sum(mesh_list, tps_points, weights, eps=1e-8):
56 | num = len(mesh_list)
57 | loss = []
58 | for i in range(num):
59 | mesh = mesh_list[i]
60 | grid_w = tps_points[i]-1
61 | grid_h = tps_points[i]-1
62 | w_edges = mesh[:,:,0:grid_w,:] - mesh[:,:,1:grid_w+1,:]
63 |
64 | w_norm1 = torch.sqrt(torch.sum(w_edges[:,:,0:grid_w-1,:] * w_edges[:,:,0:grid_w-1,:], 3) + eps)
65 | w_norm2 = torch.sqrt(torch.sum(w_edges[:,:,1:grid_w,:] * w_edges[:,:,1:grid_w,:], 3) + eps)
66 | cos_w = torch.sum(w_edges[:,:,0:grid_w-1,:] * w_edges[:,:,1:grid_w,:], 3) / (w_norm1 * w_norm2)
67 | delta_w_angle = 1 - cos_w
68 |
69 | h_edges = mesh[:,0:grid_h,:,:] - mesh[:,1:grid_h+1,:,:]
70 |
71 | h_norm1 = torch.sqrt(torch.sum(h_edges[:,0:grid_h-1,:,:] * h_edges[:,0:grid_h-1,:,:], 3) + eps)
72 | h_norm2 = torch.sqrt(torch.sum(h_edges[:,1:grid_h,:,:] * h_edges[:,1:grid_h,:,:], 3) + eps)
73 | cos_h = torch.sum(h_edges[:,0:grid_h-1,:,:] * h_edges[:,1:grid_h,:,:], 3) / (h_norm1 * h_norm2)
74 | delta_h_angle = 1 - cos_h
75 |
76 | loss.append(torch.mean(delta_w_angle) + torch.mean(delta_h_angle))
77 |
78 | return torch.sum(torch.stack([w * v for w, v in zip(weights, loss)]))
79 |
--------------------------------------------------------------------------------
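
A minimal usage sketch (not part of the repository; tensors and weights are hypothetical) of the weighted multi-head appearance loss above, which combines one warped output per TPS head against the ground truth:

```python
import torch
from model.loss import cal_appearance_loss_sum

gt = torch.rand(2, 3, 256, 256)                             # hypothetical ground-truth batch
warp_list = [torch.rand(2, 3, 256, 256) for _ in range(3)]  # one warped output per TPS head
weights = [1.0, 1.0, 1.0]                                   # hypothetical per-head weights
loss = cal_appearance_loss_sum(warp_list, gt, weights)      # weighted sum of per-head L1 losses
print(loss.item())
```
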
/utils/torch_tps_upsample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def transformer(source, target, out_size):
5 | """
6 | Thin Plate Spline Spatial Transformer Layer
7 | convert the TPS deformation into optical flows
8 | TPS control points are arranged in arbitrary positions given by `source`.
9 | source : float Tensor [num_batch, num_point, 2]
10 | The source position of the control points.
11 | target : float Tensor [num_batch, num_point, 2]
12 | The target position of the control points.
13 | out_size: tuple of two integers [height, width]
14 | The size of the output of the network (height, width)
15 | """
16 |
17 | def _meshgrid(height, width, source):
18 |
19 | x_t = torch.matmul(torch.ones([height, 1]), torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 0))
20 | y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1), torch.ones([1, width]))
21 | if torch.cuda.is_available():
22 | x_t = x_t.cuda()
23 | y_t = y_t.cuda()
24 |
25 | x_t_flat = x_t.reshape([1, 1, -1])
26 | y_t_flat = y_t.reshape([1, 1, -1])
27 |
28 | num_batch = source.size()[0]
29 | px = torch.unsqueeze(source[:,:,0], 2)
30 | py = torch.unsqueeze(source[:,:,1], 2)
31 | if torch.cuda.is_available():
32 | px = px.cuda()
33 | py = py.cuda()
34 | d2 = torch.square(x_t_flat - px) + torch.square(y_t_flat - py)
35 | r = d2 * torch.log(d2 + 1e-9)
36 | x_t_flat_g = x_t_flat.expand(num_batch, -1, -1)
37 | y_t_flat_g = y_t_flat.expand(num_batch, -1, -1)
38 | ones = torch.ones_like(x_t_flat_g)
39 | if torch.cuda.is_available():
40 | ones = ones.cuda()
41 |
42 | grid = torch.cat((ones, x_t_flat_g, y_t_flat_g, r), 1)
43 |
44 | return grid
45 |
46 | def _transform(T, source, out_size):
47 | num_batch, *_ = T.size()
48 |
49 | out_height, out_width = out_size[0], out_size[1]
50 | grid = _meshgrid(out_height, out_width, source)
51 |
52 | T_g = torch.matmul(T, grid)
53 | x_s = T_g[:,0,:]
54 | y_s = T_g[:,1,:]
55 |
56 | flow_x = (x_s - grid[:,1,:])*(out_width/2)
57 | flow_y = (y_s - grid[:,2,:])*(out_height/2)
58 |
59 | flow = torch.stack([flow_x, flow_y], 1)
60 | flow = flow.reshape([num_batch, 2, out_height, out_width])
61 |
62 | return flow
63 |
64 |
65 | def _solve_system(source, target):
66 | num_batch = source.size()[0]
67 | num_point = source.size()[1]
68 |
69 | np.set_printoptions(precision=8)
70 |
71 | ones = torch.ones(num_batch, num_point, 1).float()
72 | if torch.cuda.is_available():
73 | ones = ones.cuda()
74 | p = torch.cat([ones, source], 2)
75 |
76 | p_1 = p.reshape([num_batch, -1, 1, 3])
77 | p_2 = p.reshape([num_batch, 1, -1, 3])
78 | d2 = torch.sum(torch.square(p_1-p_2), 3)
79 | r = d2 * torch.log(d2 + 1e-9)
80 |
81 | zeros = torch.zeros(num_batch, 3, 3).float()
82 | if torch.cuda.is_available():
83 | zeros = zeros.cuda()
84 | W_0 = torch.cat((p, r), 2)
85 | W_1 = torch.cat((zeros, p.permute(0,2,1)), 2)
86 | W = torch.cat((W_0, W_1), 1)
87 | W_inv = torch.inverse(W.type(torch.float64))
88 |
89 | zeros2 = torch.zeros(num_batch, 3, 2)
90 | if torch.cuda.is_available():
91 | zeros2 = zeros2.cuda()
92 | tp = torch.cat((target, zeros2), 1)
93 | T = torch.matmul(W_inv, tp.type(torch.float64))
94 | T = T.permute(0, 2, 1)
95 |
96 | return T.type(torch.float32)
97 |
98 | T = _solve_system(source, target)
99 | output = _transform(T, source, out_size)
100 |
101 | return output
--------------------------------------------------------------------------------
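
A minimal usage sketch (not part of the repository; the control grid and perturbation are hypothetical) of converting TPS control points into a dense flow field with the transformer above. The control points are assumed to be in normalized [-1, 1] coordinates, as produced by get_norm_mesh() in utils/utils_transform.py:

```python
import torch
import utils.torch_tps_upsample as torch_tps_upsample

grid = 13                                                    # hypothetical 13x13 control grid
ys, xs = torch.meshgrid(torch.linspace(-1, 1, grid),
                        torch.linspace(-1, 1, grid), indexing='ij')
source = torch.stack([xs, ys], dim=-1).reshape(1, -1, 2)     # [1, 169, 2] regular control points
target = source + 0.02 * torch.randn_like(source)            # slightly perturbed control points
flow = torch_tps_upsample.transformer(source, target, (256, 256))
print(flow.shape)                                            # torch.Size([1, 2, 256, 256])
```
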
/Datasets/readme.md:
--------------------------------------------------------------------------------
1 | # Guidance for the Practical Image Warping Datasets
2 | For the convenience of training and testing various image warping tasks in one project, we cleaned and arranged the mainstream warping datasets with unified structures and additional visual aids. Note that the ground-truth warping flow is only available in the training dataset of the portrait correction task. For the other tasks, we provide pseudo flows (obtained by [RAFT](https://github.com/princeton-vl/RAFT)) for visualization. The download links and sources for each task are listed below.
3 |
4 |
5 | | # | Type | Training Set | Testing Set | Source |
6 | |:--:|:---------------------:|:------------------------------------------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:|
7 | | 1 | Stitched Image | [Google Drive](https://drive.google.com/file/d/1vWaD3bbUd6TZ_wQlehBxVoYlwBxZdAwu/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1eiMwbJMCERbgZUuf6hQOhQ?pwd=fpa8) | [Google Drive](https://drive.google.com/file/d/1Ldqn4Q-mrrRibGNdO9t8cUVpH0YZWK7i/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1MOCOUUJGfqBss76nw-dDiw?pwd=d4du) | [Website](https://github.com/nie-lang/DeepRectangling) |
8 | | 2 | Rectified Wide-Angle Image | [Google Drive](https://drive.google.com/file/d/1Cxv97NybP5t8aBekm7XNvBP4i2CAhxUz/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1J7LeYenzbG4fKxMHePGNpg?pwd=cmc1) | [Google Drive](https://drive.google.com/file/d/1WzXSFQoLuqeAlAXR4gO5eBZhUJgz-aN_/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1dFlPT-iB-F-382ZSQx5z-Q?pwd=wdu8) | [Website](https://github.com/KangLiao929/RecRecNet) |
9 | | 3 | Unrolling Shutter Image | [Google Drive](https://drive.google.com/file/d/1r3B3BAmZjmy5PSzQZllIxE8mM6bSPFNj/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1dHxI8ILMg84dUE6icP3pdg?pwd=cqu3) | [Google Drive](https://drive.google.com/file/d/18Qtj_sp2cDWM3OZv0UYjf-5qpOBzv4QN/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1KnNs-S0BOpwvMer3epnndg?pwd=yvxw) | Self-constructed |
10 | | 4 | Rotated Image | [Google Drive](https://drive.google.com/file/d/1WJAvf5sG3JyLeVOi9ZrSBOcOqdnoURGG/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1ambt8n0Bc9tkMD6jdbagSg?pwd=2nkw) | [Google Drive](https://drive.google.com/file/d/1HdQdK6enbOebASL4Wie3kzkiQ9SUNmBH/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1Tr3SvOnKXZ4JfE9KYVmNfw?pwd=3ea8) | [Website](https://github.com/nie-lang/RotationCorrection) |
11 | | 5 | Fisheye Image | [Google Drive](https://drive.google.com/file/d/1yn_hlVyFRIt3yTBPsDNKKfx9U9j7vY01/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1tVbeSywPyceyeCIdhbGSAQ?pwd=1v6a) | [Google Drive](https://drive.google.com/file/d/1X1851eEB0gvzEJrOOw7uQdKDJ2MQZAaG/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1TwURM0jjLcnjm5v4zdtjCg?pwd=7w56) | [Website](https://github.com/uof1745-cmd/PCN) |
12 | | 6 | Portrait Photo | [Google Drive](https://drive.google.com/file/d/1Ng_dx2Y4v8Qjv4xVtWA6f2il3SklBIu3/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1AJi1X3wkgFiVrO7GhvSrNw?pwd=yidy) | [Google Drive](https://drive.google.com/file/d/18yY7O7Yi6TygnRG73eCTxkVtgMk5qY9k/view?usp=sharing), [Baidu Netdisk](https://pan.baidu.com/s/1FVCK57a7y8PoHfMDXxVmmA?pwd=ej81) | [Website](https://github.com/megvii-research/Portraits_Correction) |
13 |
--------------------------------------------------------------------------------
/utils/utils_op.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from torch.nn.parallel import DistributedDataParallel as DDP
5 | import cv2
6 | from imgaug import augmenters as iaa
7 | import os
8 |
9 | def set_gpu(args, distributed=False, rank=0):
10 | """ set parameter to gpu or ddp """
11 | if args is None:
12 | return None
13 | if distributed and isinstance(args, torch.nn.Module):
14 | return DDP(args.cuda(), device_ids=[rank], output_device=rank, broadcast_buffers=True, find_unused_parameters=True)
15 | else:
16 | return args.cuda()
17 |
18 | def set_device(args, distributed=False, rank=0):
19 | """ set parameter to gpu or cpu """
20 | if torch.cuda.is_available():
21 | if isinstance(args, list):
22 | return (set_gpu(item, distributed, rank) for item in args)
23 | elif isinstance(args, dict):
24 | return {key:set_gpu(args[key], distributed, rank) for key in args}
25 | else:
26 | args = set_gpu(args, distributed, rank)
27 | return args
28 |
29 | def count_files(directory_path):
30 | for subdir, _, _ in os.walk(directory_path):
31 | if subdir != directory_path:
32 | return len([file for _, _, files in os.walk(subdir) for file in files])
33 | return 0
34 |
35 | def draw_mesh_on_warp(warp, f_local, grid_h, grid_w):
36 | height = warp.shape[0]
37 | width = warp.shape[1]
38 |
39 | min_w = np.minimum(np.min(f_local[:,:,0]), 0).astype(np.int32)
40 | max_w = np.maximum(np.max(f_local[:,:,0]), width).astype(np.int32)
41 | min_h = np.minimum(np.min(f_local[:,:,1]), 0).astype(np.int32)
42 | max_h = np.maximum(np.max(f_local[:,:,1]), height).astype(np.int32)
43 | cw = max_w - min_w
44 | ch = max_h - min_h
45 |
46 | pic = np.ones([ch+10, cw+10, 3], np.int32)*255
47 | pic[0-min_h+5:0-min_h+height+5, 0-min_w+5:0-min_w+width+5, :] = warp
48 |
49 | warp = pic
50 | f_local[:,:,0] = f_local[:,:,0] - min_w+5
51 | f_local[:,:,1] = f_local[:,:,1] - min_h+5
52 |
53 | point_color = (0, 255, 0)
54 | thickness = 2
55 | lineType = 8
56 | num = 1
57 | for i in range(grid_h+1):
58 | for j in range(grid_w+1):
59 | num = num + 1
60 | if j == grid_w and i == grid_h:
61 | continue
62 | elif j == grid_w:
63 | cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType)
64 | elif i == grid_h:
65 | cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType)
66 | else :
67 | cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i+1,j,0]), int(f_local[i+1,j,1])), point_color, thickness, lineType)
68 | cv2.line(warp, (int(f_local[i,j,0]), int(f_local[i,j,1])), (int(f_local[i,j+1,0]), int(f_local[i,j+1,1])), point_color, thickness, lineType)
69 |
70 | return warp
71 |
72 | def data_aug(img, gt):
73 | oplist = []
74 | if random.random() > 0.5:
75 | oplist.append(iaa.GaussianBlur(sigma=(0.0, 1.0)))
76 | elif random.random() > 0.5:
77 | oplist.append(iaa.WithChannels(0, iaa.Add((1, 15))))
78 | elif random.random() > 0.5:
79 | oplist.append(iaa.WithChannels(1, iaa.Add((1, 15))))
80 | elif random.random() > 0.5:
81 | oplist.append(iaa.WithChannels(2, iaa.Add((1, 15))))
82 | elif random.random() > 0.5:
83 | oplist.append(iaa.AdditiveGaussianNoise(scale=(0, 10)))
84 | elif random.random() > 0.5:
85 | oplist.append(iaa.Sharpen(alpha=0.15))
86 |
87 | seq = iaa.Sequential(oplist)
88 | images_aug = seq.augment_images([img])
89 | gt_aug = seq.augment_images([gt])
90 | return images_aug[0], gt_aug[0]
91 |
92 | def get_weight_mask(mask, gt, pred, weight=10):
93 | mask = (mask * (weight - 1)) + 1
94 | gt = gt.mul(mask)
95 | pred = pred.mul(mask)
96 | return gt, pred
97 |
98 | def adjust_weight(epoch, total_epoch, weight):
99 | return (1 - 0.9 * (epoch / total_epoch)) * weight
100 |
101 | def flow2list(flow):
102 | h, w, c = flow.shape
103 | dirs_grid = []
104 | for i in range(h):
105 | dirs_row = []
106 | for j in range(w):
107 | dx, dy = flow[i, j, :]
108 | dx = np.round(dx)
109 | dy = np.round(dy)
110 | dirs_row.append([-int(dx), -int(dy)])
111 | dirs_grid.append(dirs_row)
112 | return dirs_grid
--------------------------------------------------------------------------------
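
A minimal sketch (hypothetical uint8 arrays) of the photometric augmentation helper data_aug(), which the training loader uses to jitter the input/ground-truth pair:

```python
import numpy as np
from utils.utils_op import data_aug

img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)  # hypothetical input image
gt  = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)  # hypothetical ground-truth image
img_aug, gt_aug = data_aug(img, gt)                          # randomly blurred/color-shifted/noised copies
print(img_aug.shape, gt_aug.shape)
```
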
/model/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import utils.torch_tps_upsample as torch_tps_upsample
3 | from utils.utils_transform import *
4 | import torch.nn.functional as F
5 |
6 | def build_model(net, input_tensor1, input_tensor2, mask_tensor, tps_points):
7 | """
8 | input_tensor1: source image with original resolution
9 | input_tensor2: resized image with fixed resolution (256x256)
10 |     Note: input_tensor1 == input_tensor2 during training.
11 | """
12 | batch_size, _, img_h, img_w = input_tensor1.size()
13 | batch_size, _, input_size, input_size = input_tensor2.size()
14 |
15 | offset, flow, point_cls = net(input_tensor2, mask_tensor)
16 | head_num = len(offset)
17 | norm_rigid_mesh_list = []
18 | norm_ori_mesh_list = []
19 | output_tps_list = []
20 | ori_mesh_list = []
21 | tps2flow_list = []
22 |
23 | for i in range(head_num):
24 | mesh_motion = offset[i].reshape(-1, tps_points[i], tps_points[i], 2)
25 | rigid_mesh = get_rigid_mesh(batch_size, input_size, input_size, tps_points[i]-1, tps_points[i]-1)
26 | ori_mesh = rigid_mesh + mesh_motion
27 | clamped_x = torch.clamp(ori_mesh[..., 0], min=0, max=input_size - 1)
28 | clamped_y = torch.clamp(ori_mesh[..., 1], min=0, max=input_size - 1)
29 | ori_mesh = torch.stack((clamped_x, clamped_y), dim=-1)
30 |
31 | norm_rigid_mesh = get_norm_mesh(rigid_mesh, input_size, input_size)
32 | norm_ori_mesh = get_norm_mesh(ori_mesh, input_size, input_size)
33 | tps2flow = torch_tps_upsample.transformer(norm_rigid_mesh, norm_ori_mesh, (img_h, img_w))
34 | output_tps = resample_image(input_tensor1, tps2flow)
35 |
36 | norm_rigid_mesh_list.append(norm_rigid_mesh)
37 | norm_ori_mesh_list.append(norm_ori_mesh)
38 | output_tps_list.append(output_tps)
39 | tps2flow_list.append(tps2flow)
40 | ori_mesh_list.append(ori_mesh)
41 |
42 | tps_flow = tps2flow_list[-1]
43 | final_flow = flow + tps_flow
44 | output_flow = resample_image(output_tps_list[-1], flow)
45 |
46 | out_dict = {}
47 | out_dict.update(warp_tps=output_tps_list, warp_flow=output_flow, mesh=ori_mesh_list,
48 | flow1=flow, flow2=tps_flow, flow3=final_flow, point_cls=point_cls)
49 | return out_dict
50 |
51 |
52 | def build_model_test(net, input_tensor1, input_tensor2, mask_tensor, tps_points, resize_flow=False):
53 | """
54 | input_tensor1: source image with original resolution
55 | input_tensor2: resized image with fixed resolution (256x256)
56 |     Note: input_tensor1 == input_tensor2 during training.
57 | """
58 | batch_size, _, img_h, img_w = input_tensor1.size()
59 | batch_size, _, input_size, input_size = input_tensor2.size()
60 |
61 | offset, flow, point_cls = net(input_tensor2, mask_tensor)
62 | head_num = len(offset)
63 | norm_rigid_mesh_list = []
64 | norm_ori_mesh_list = []
65 | output_tps_list = []
66 | ori_mesh_list = []
67 | tps2flow_list = []
68 | for i in range(head_num):
69 | mesh_motion = offset[i].reshape(-1, tps_points[i], tps_points[i], 2)
70 | rigid_mesh = get_rigid_mesh(batch_size, input_size, input_size, tps_points[i]-1, tps_points[i]-1)
71 | ori_mesh = rigid_mesh + mesh_motion
72 | clamped_x = torch.clamp(ori_mesh[..., 0], min=0, max=input_size - 1)
73 | clamped_y = torch.clamp(ori_mesh[..., 1], min=0, max=input_size - 1)
74 | ori_mesh = torch.stack((clamped_x, clamped_y), dim=-1)
75 |
76 | norm_rigid_mesh = get_norm_mesh(rigid_mesh, input_size, input_size)
77 | norm_ori_mesh = get_norm_mesh(ori_mesh, input_size, input_size)
78 | tps2flow = torch_tps_upsample.transformer(norm_rigid_mesh, norm_ori_mesh, (img_h, img_w))
79 | output_tps = resample_image_xy(input_tensor1, tps2flow)
80 | norm_rigid_mesh_list.append(norm_rigid_mesh)
81 | norm_ori_mesh_list.append(norm_ori_mesh)
82 | output_tps_list.append(output_tps)
83 | ori_mesh_list.append(ori_mesh)
84 | tps2flow_list.append(tps2flow)
85 |
86 | tps_flow = tps2flow_list[-1]
87 | if(resize_flow):
88 | flow = F.interpolate(flow, size=(img_h, img_w), mode='bilinear', align_corners=True)
89 | scale_H, scale_W = img_h / input_size, img_w / input_size
90 | flow[:, 0, :, :] *= scale_W
91 | flow[:, 1, :, :] *= scale_H
92 |
93 | final_flow = flow + tps_flow
94 | output_flow = resample_image_xy(output_tps_list[-1], flow)
95 | out_dict = {}
96 | out_dict.update(warp_tps=output_tps_list, warp_flow=output_flow, mesh=ori_mesh_list,
97 | flow1=flow, flow2=tps_flow, flow3=final_flow, point_cls=point_cls)
98 | return out_dict
99 |
--------------------------------------------------------------------------------
/utils/utils_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import utils.torch_tps_transform as torch_tps_transform
5 | import utils.torch_tps_upsample as torch_tps_upsample
6 |
7 | def get_rigid_mesh(batch_size, height, width, grid_w, grid_h):
8 | ww = torch.matmul(torch.ones([grid_h+1, 1]), torch.unsqueeze(torch.linspace(0., float(width), grid_w+1), 0))
9 | hh = torch.matmul(torch.unsqueeze(torch.linspace(0.0, float(height), grid_h+1), 1), torch.ones([1, grid_w+1]))
10 | if torch.cuda.is_available():
11 | ww = ww.cuda()
12 | hh = hh.cuda()
13 |
14 | ori_pt = torch.cat((ww.unsqueeze(2), hh.unsqueeze(2)),2)
15 | ori_pt = ori_pt.unsqueeze(0).expand(batch_size, -1, -1, -1)
16 |
17 | return ori_pt
18 |
19 | def get_norm_mesh(mesh, height, width):
20 | batch_size = mesh.size()[0]
21 | mesh_w = mesh[...,0]*2./float(width) - 1.
22 | mesh_h = mesh[...,1]*2./float(height) - 1.
23 | norm_mesh = torch.stack([mesh_w, mesh_h], 3)
24 |
25 | return norm_mesh.reshape([batch_size, -1, 2])
26 |
27 | def transform_tps_fea(offset, input_tensor, grid_w, grid_h, dim, h, w):
28 | input_tensor = input_tensor.permute(0,2,1).view(-1, dim, h, w)
29 | batch_size, _, img_h, img_w = input_tensor.size()
30 |
31 | mesh_motion = offset.reshape(-1, grid_h+1, grid_w+1, 2)
32 |
33 | rigid_mesh = get_rigid_mesh(batch_size, img_h, img_w, grid_w, grid_h)
34 | ori_mesh = rigid_mesh + mesh_motion
35 |
36 | clamped_x = torch.clamp(ori_mesh[..., 0], min=0, max=img_h - 1)
37 | clamped_y = torch.clamp(ori_mesh[..., 1], min=0, max=img_w - 1)
38 | ori_mesh = torch.stack((clamped_x, clamped_y), dim=-1)
39 |
40 | norm_rigid_mesh = get_norm_mesh(rigid_mesh, img_h, img_w)
41 | norm_ori_mesh = get_norm_mesh(ori_mesh, img_h, img_w)
42 |
43 | output_tps = torch_tps_transform.transformer(input_tensor, norm_rigid_mesh, norm_ori_mesh, (img_h, img_w))
44 | output_tps = output_tps.view(-1, dim, h*w).permute(0,2,1)
45 |
46 | return output_tps
47 |
48 | def upsample_tps(offset, grid_w, grid_h, out_h, out_w):
49 | if(grid_w+1 == out_w):
50 | return offset
51 |
52 | else:
53 | batch_size, *_ = offset.size()
54 | mesh_motion = offset.reshape(-1, grid_h+1, grid_w+1, 2)
55 |
56 | rigid_mesh = get_rigid_mesh(batch_size, out_h, out_w, grid_w, grid_h)
57 | ori_mesh = rigid_mesh + mesh_motion
58 |
59 | norm_rigid_mesh = get_norm_mesh(rigid_mesh, out_h, out_w)
60 | norm_ori_mesh = get_norm_mesh(ori_mesh, out_h, out_w)
61 |
62 | up_points = torch_tps_upsample.transformer(norm_rigid_mesh, norm_ori_mesh, (out_h, out_w))
63 | out = up_points.permute(0, 2, 3, 1).view(-1, out_h*out_w, 2)
64 |
65 | return out
66 |
67 | def get_coordinate(shape, det_uv):
68 | b, _, w, h = shape
69 | uv_d = np.zeros([w, h, 2], np.float32)
70 |
71 | for i in range(0, w):
72 | for j in range(0, h):
73 | uv_d[i, j, 0] = j
74 | uv_d[i, j, 1] = i
75 |
76 | uv_d = np.expand_dims(uv_d.swapaxes(2, 1).swapaxes(1, 0), 0)
77 | uv_d = torch.from_numpy(uv_d).cuda()
78 | uv_d = uv_d.repeat(b, 1, 1, 1)
79 |
80 | det_uv = uv_d + det_uv
81 | return det_uv
82 |
83 | def uniform(shape, img_uv):
84 | b, _, w, h = shape
85 | x0 = (w - 1) / 2.
86 |
87 | img_nor = (img_uv - x0)/x0
88 | img_nor = img_nor.permute(0, 2, 3, 1)
89 | return img_nor
90 |
91 | def resample_image(feature, flow):
92 | img_uv = get_coordinate(feature.shape, flow)
93 | grid = uniform(feature.shape, img_uv)
94 | target_image = F.grid_sample(feature, grid)
95 | return target_image
96 |
97 | def get_coordinate_xy(shape, det_uv):
98 | b, _, h, w = shape
99 | uv_d = np.zeros([h, w, 2], np.float32)
100 |
101 | for j in range(0, h):
102 | for i in range(0, w):
103 | uv_d[j, i, 0] = i
104 | uv_d[j, i, 1] = j
105 |
106 | uv_d = np.expand_dims(uv_d.swapaxes(2, 1).swapaxes(1, 0), 0)
107 | uv_d = torch.from_numpy(uv_d).cuda()
108 | uv_d = uv_d.repeat(b, 1, 1, 1)
109 | det_uv = uv_d + det_uv
110 | return det_uv
111 |
112 | def uniform_xy(shape, uv):
113 | b, _, h, w = shape
114 | y0 = (h - 1) / 2.
115 | x0 = (w - 1) / 2.
116 |
117 | nor = uv.clone()
118 | nor[:, 0, :, :] = (uv[:, 0, :, :] - x0) / x0
119 | nor[:, 1, :, :] = (uv[:, 1, :, :] - y0) / y0
120 |     nor = nor.permute(0, 2, 3, 1) # b h w 2
121 |
122 | return nor
123 |
124 | def resample_image_xy(feature, flow):
125 | uv = get_coordinate_xy(feature.shape, flow)
126 | grid = uniform_xy(feature.shape, uv)
127 | target_image = F.grid_sample(feature, grid)
128 | return target_image
--------------------------------------------------------------------------------
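
A minimal sketch (not part of the repository) of resample_image(), which warps an image by a dense flow field; it assumes a CUDA device since the helper moves its coordinate grid to the GPU internally:

```python
import torch
from utils.utils_transform import resample_image

img  = torch.rand(1, 3, 256, 256).cuda()   # hypothetical image batch
flow = torch.zeros(1, 2, 256, 256).cuda()  # zero flow corresponds to an (approximate) identity warp
warped = resample_image(img, flow)         # close to img, up to grid_sample interpolation
print(warped.shape)                        # torch.Size([1, 3, 256, 256])
```
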
/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
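
A minimal usage sketch (hypothetical flow values) of the visualization entry point flow_to_image():

```python
import numpy as np
from utils.flow_viz import flow_to_image

flow = 5.0 * np.random.randn(256, 256, 2).astype(np.float32)  # hypothetical flow field [H, W, 2]
rgb = flow_to_image(flow)                                      # uint8 color-wheel visualization [H, W, 3]
print(rgb.shape, rgb.dtype)                                    # (256, 256, 3) uint8
```
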
/utils/utils_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | class AddCoords(nn.Module):
7 | def __init__(self):
8 | super(AddCoords, self).__init__()
9 |
10 | def forward(self, input_tensor):
11 | batch_size, _, height, width = input_tensor.size()
12 |
13 | xx_channel = torch.arange(width).repeat(1, height, 1)
14 | yy_channel = torch.arange(height).repeat(1, width, 1).transpose(1, 2)
15 |
16 | xx_channel = xx_channel.float() / (width - 1)
17 | yy_channel = yy_channel.float() / (height - 1)
18 |
19 | xx_channel = xx_channel * 2 - 1
20 | yy_channel = yy_channel * 2 - 1
21 |
22 | xx_channel = xx_channel.repeat(batch_size, 1, 1, 1)
23 | yy_channel = yy_channel.repeat(batch_size, 1, 1, 1)
24 |
25 | input_tensor = torch.cat([
26 | input_tensor,
27 | xx_channel.type_as(input_tensor),
28 | yy_channel.type_as(input_tensor)], dim=1)
29 |
30 | return input_tensor
31 |
32 | class CoordConv(nn.Module):
33 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
34 | super(CoordConv, self).__init__()
35 | self.addcoords = AddCoords()
36 | self.conv = nn.Conv2d(in_channels+2, out_channels, kernel_size, stride, padding)
37 |
38 | def forward(self, x):
39 | x = self.addcoords(x)
40 | x = self.conv(x)
41 | return x
42 |
43 | class MotionNet_Coord(nn.Module):
44 | def __init__(self, in_channel, out_channel, num):
45 | super(MotionNet_Coord, self).__init__()
46 | if(num==8):
47 | self.conv = nn.Sequential(
48 | CoordConv(in_channel, out_channel, kernel_size=3, stride=2, padding=1),
49 | )
50 | if(num==10):
51 | self.conv = nn.Sequential(
52 | CoordConv(in_channel, 64, kernel_size=5, stride=1, padding=0),
53 | CoordConv(64, out_channel, kernel_size=3, stride=1, padding=0),
54 | )
55 | if(num==12):
56 | self.conv = nn.Sequential(
57 | CoordConv(in_channel, out_channel, kernel_size=5, stride=1, padding=0),
58 | )
59 | if(num==14):
60 | self.conv = nn.Sequential(
61 | CoordConv(in_channel, out_channel, kernel_size=3, stride=1, padding=0),
62 | )
63 | if(num==16):
64 | self.conv = nn.Sequential(
65 | CoordConv(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
66 | )
67 |
68 | def forward(self, x):
69 | B, L, C = x.shape
70 | H = int(math.sqrt(L))
71 | W = int(math.sqrt(L))
72 | x = x.transpose(1, 2).contiguous().view(B, C, H, W)
73 | out = self.conv(x).flatten(2).transpose(1, 2).contiguous()
74 | return out
75 |
76 | class PointNet(nn.Module):
77 | def __init__(self, num_classes=4, grid_h=12, grid_w=12):
78 | super(PointNet, self).__init__()
79 | self.conv1 = nn.Conv1d(8, 256, 1)
80 | self.conv2 = nn.Conv1d(256, 256, 1)
81 | self.conv3 = nn.Conv1d(256, 512, 1)
82 | self.fc1 = nn.Linear(512, 512)
83 | self.fc2 = nn.Linear(512, 256)
84 | self.fc3 = nn.Linear(256, num_classes)
85 | self.dropout1 = nn.Dropout(p=0.3)
86 | self.dropout2 = nn.Dropout(p=0.3)
87 | self.h = grid_h
88 | self.w = grid_w
89 | self.fc_fea = nn.Linear(16*16, 6)
90 |
91 | def forward(self, pre, fea):
92 | '''pre-processing the predicted points'''
93 | x = pre.reshape(-1, 2, (self.h*self.w))
94 |
95 | '''pre-processing the prompt features and form the superpoint'''
96 | fea = nn.MaxPool1d(fea.size(-1))(fea).squeeze(-1)
97 | fea = F.relu(self.fc_fea(fea))
98 | fea = fea.unsqueeze(-1).repeat(1, 1, self.h*self.w)
99 | superpoint = torch.cat((x, fea), dim=1)
100 |
101 | '''learn the superpoints' features'''
102 | x = F.relu(self.conv1(superpoint))
103 | x = F.relu(self.conv2(x))
104 | x = F.relu(self.conv3(x))
105 | x = nn.MaxPool1d(x.size(-1))(x)
106 | x = x.view(-1, 512)
107 |
108 | '''classification'''
109 | x = F.relu(self.fc1(x))
110 | x = self.dropout1(x)
111 | x = F.relu(self.fc2(x))
112 | x = self.dropout2(x)
113 | x = self.fc3(x)
114 | return x
115 |
116 | class PromptGenBlock(nn.Module):
117 | def __init__(self,prompt_dim=128,prompt_len=5,prompt_size = 96,lin_dim = 192):
118 | super(PromptGenBlock,self).__init__()
119 | self.prompt_param = nn.Parameter(torch.rand(1,prompt_len,prompt_dim,prompt_size,prompt_size))
120 | self.linear_layer = nn.Linear(lin_dim,prompt_len)
121 | self.conv3x3 = nn.Conv2d(prompt_dim,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False)
122 |
123 | def forward(self,x):
124 | B, C, H, W = x.shape
125 | emb = x.mean(dim=(-2,-1))
126 | prompt_weights = F.softmax(self.linear_layer(emb),dim=1)
127 | prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B,1,1,1,1,1).squeeze(1)
128 | prompt = torch.sum(prompt,dim=1)
129 | prompt = F.interpolate(prompt,(H,W),mode="bilinear")
130 | prompt = self.conv3x3(prompt)
131 | return prompt
--------------------------------------------------------------------------------
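
A minimal usage sketch (hypothetical shapes) of PromptGenBlock, which turns a feature map into a task-aware prompt at the same spatial resolution; the channel count of the input must equal lin_dim:

```python
import torch
from utils.utils_module import PromptGenBlock

block = PromptGenBlock(prompt_dim=128, prompt_len=5, prompt_size=96, lin_dim=192)
feat = torch.rand(2, 192, 32, 32)   # hypothetical feature map (channels == lin_dim)
prompt = block(feat)                # learned prompts weighted, summed, resized, and refined by a 3x3 conv
print(prompt.shape)                 # torch.Size([2, 128, 32, 32])
```
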
/README.md:
--------------------------------------------------------------------------------
1 | # MOWA: Multiple-in-One Image Warping Model
2 |
3 | ## Introduction
4 | This is the official implementation for [MOWA](https://arxiv.org/abs/2404.10716) (TPAMI 2025).
5 |
6 | [Kang Liao](https://kangliao929.github.io/), [Zongsheng Yue](https://zsyoaoa.github.io/), [Zhonghua Wu](https://wu-zhonghua.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
7 |
8 | S-Lab, Nanyang Technological University
9 |
10 |
11 |
12 |
13 |
14 |
15 | > ### Why MOWA?
16 | > MOWA is a practical multiple-in-one image warping framework, particularly for computational photography, in which six distinct tasks are considered. Compared to previous works tailored to specific tasks, our method solves various warping tasks arising from different camera models or manipulation spaces within a single framework. It also generalizes to novel scenarios, as evidenced by both cross-domain and zero-shot evaluations.
17 | > ### Features
18 | > * The first practical multiple-in-one image warping framework, especially for computational photography.
19 | > * We propose to mitigate the difficulty of multi-task learning by decoupling the motion estimation at both the region level and the pixel level.
20 | > * A prompt learning module, guided by a lightweight point-based classifier, is designed to facilitate task-aware image warping.
21 | > * We show that, through multi-task learning, our framework develops a robust generalized warping strategy that achieves improved performance across various tasks and even generalizes to unseen tasks.
22 |
23 | Check out more visual results and interactions [here](https://kangliao929.github.io/projects/mowa/).
24 |
25 | ## 📢 News
26 | - Our recent work **Puffin** unifies camera-centric understanding (camera calibration, pose estimation) and generation (camera-controllable T2I and I2I generation) within a cohesive multimodal framework. It achieves more precise understanding and generation through our proposed *thinking with camera*, and offers insights into the mutual benefits among multimodal tasks. If you are interested in camera-related 3D vision, photography, embodied AI, and spatial intelligence, please check out more details [here](https://kangliao929.github.io/projects/puffin/).
27 | - MOWA has been included in [AI Art Weekly #80](https://aiartweekly.com/issues/80).
28 |
29 | ## 📝 Changelog
30 |
31 | - [x] 2024.04.16: The arXiv version of the paper is online.
32 | - [x] 2024.06.30: Release the code and pre-trained model.
33 | - [x] 2025.05.01: MOWA has been accepted to IEEE TPAMI.
34 | - [ ] Release a demo for users to try MOWA online.
35 | - [ ] Release an interactive interface to drag the control points and perform customized warpings.
36 |
37 | ## Installation
38 | Using a virtual environment (conda) to run the code is recommended.
39 | ```
40 | conda create -n mowa python=3.8.13
41 | conda activate mowa
42 | pip install -r requirements.txt
43 | ```
44 |
45 | ## Dataset
46 | We mainly explore six representative image warping tasks in this work. The datasets are derived or constructed from previous works. For the convenience of training and testing in one project, we cleaned and arranged these six types of datasets with unified structures and additional visual aids. Please refer to the categories and download links in [Datasets](https://github.com/KangLiao929/MOWA/tree/main/Datasets).
47 |
48 | ## Pretrained Model
49 | Download the pretrained model from [Google Drive](https://drive.google.com/file/d/1fxQbD1TLoRnW8lG2a8KMinmD6Jlol8EX/view?usp=drive_link) or [Baidu Netdisk](https://pan.baidu.com/s/1swMZTkTm1iSYDGVsdepdBA?pwd=hcvy), and put it into the ```./checkpoint``` folder.
50 |
51 | ## Testing
52 | ### Unified Warping and Evaluation on Public Benchmark
53 | Customize the paths of the checkpoint and test set, and run:
54 | ```
55 | sh scripts/test.sh
56 | ```
57 | The warped images and intermediate results such as the control points and warping flows can be found in the ```./results``` folder. Evaluation metrics such as PSNR and SSIM are also reported along with the task ID.
58 |
59 | ### Specific Evaluation on Portrait Correction
60 | In the portrait correction task, the ground-truth warped image and flow are unavailable, so image quality metrics cannot be evaluated. Instead, we report ShapeAcc, a metric specific to this task's purpose, i.e., correcting face distortion. To reproduce the warping performance on portrait photos, customize the paths of the checkpoint and test set, and run:
61 | ```
62 | sh scripts/test_portrait.sh
63 | ```
64 | The warped images can also be found in the test path.
65 |
66 | ## Training
67 | Customize the paths of all warping training datasets in a list, and run:
68 | ```
69 | sh scripts/train.sh
70 | ```
71 | ## Projects that use MOWA
72 | * MOWA-onnxrun (ONNX Runtime Implementation): [https://github.com/hpc203/MOWA-onnxrun](https://github.com/hpc203/MOWA-onnxrun)
73 |
74 | ## Demo
75 | TBD
76 |
77 | ## Acknowledgment
78 | The current version of **MOWA** is inspired by previous task-specific image warping works such as [RectanglingPano](https://people.csail.mit.edu/kaiming/publications/sig13pano.pdf), [DeepRectangling](https://github.com/nie-lang/DeepRectangling), [RecRecNet](https://github.com/KangLiao929/RecRecNet), [PCN](https://github.com/uof1745-cmd/PCN), [Deep_RS-HM](https://github.com/DavidYan2001/Deep_RS-HM), and [SSPC](https://github.com/megvii-research/Portraits_Correction).
79 |
80 | ## Citation
81 |
82 | ```bibtex
83 | @article{liao2024mowa,
84 | title={MOWA: Multiple-in-One Image Warping Model},
85 | author={Liao, Kang and Yue, Zongsheng and Wu, Zhonghua and Loy, Chen Change},
86 | journal={arXiv preprint arXiv:2404.10716},
87 | year={2024}
88 | }
89 | ```
90 |
91 | ## Contact
92 | For any questions, feel free to email `kang.liao@ntu.edu.sg`.
93 |
94 | ## License
95 | This project is licensed under [NTU S-Lab License 1.0](LICENSE).
96 |
--------------------------------------------------------------------------------
/dataset_loaders.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import numpy as np
3 | import cv2, torch
4 | import os
5 | import glob
6 | from collections import OrderedDict
7 | from utils.utils_op import data_aug
8 |
9 | class TrainDataset(Dataset):
10 | def __init__(self, paths):
11 |
12 | self.width = 256
13 | self.height = 256
14 | self.prob = 0.5
15 | self.input_images = []
16 | self.gt_images = []
17 | self.masks = []
18 | self.flows = []
19 | self.task_id = []
20 | for index, path in enumerate(paths):
21 | inputs = glob.glob(os.path.join(path, 'input/', '*.*'))
22 | gts = glob.glob(os.path.join(path, 'gt/', '*.*'))
23 | masks = glob.glob(os.path.join(path, 'mask/', '*.*'))
24 | flows = glob.glob(os.path.join(path, 'flow_npy/', '*.*'))
25 | inputs.sort()
26 | gts.sort()
27 | masks.sort()
28 | flows.sort()
29 |
30 | lens = len(inputs)
31 | index_array = [index] * lens
32 | self.task_id.extend(index_array)
33 | self.input_images.extend(inputs)
34 | self.gt_images.extend(gts)
35 | self.masks.extend(masks)
36 | self.flows.extend(flows)
37 |
38 | print("total dataset num: ", len(self.input_images))
39 |
40 | def __getitem__(self, index):
41 |
42 | '''load images'''
43 | task_id = self.task_id[index]
44 | input_src = cv2.imread(self.input_images[index])
45 | input_resized = cv2.resize(input_src, (self.width, self.height))
46 | gt = cv2.imread(self.gt_images[index])
47 | gt = cv2.resize(gt, (self.width, self.height))
48 | input_resized, gt = data_aug(input_resized, gt)
49 |
50 | input_resized = input_resized.astype(dtype=np.float32)
51 | input_resized = input_resized / 255.0
52 | input_resized = np.transpose(input_resized, [2, 0, 1])
53 | gt = gt.astype(dtype=np.float32)
54 | gt = gt / 255.0
55 | gt = np.transpose(gt, [2, 0, 1])
56 |
57 | '''load mask'''
58 | mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
59 | mask = cv2.resize(mask, (self.width, self.height))
60 | mask = np.expand_dims(mask, axis=-1)
61 | mask = mask.astype(dtype=np.float32)
62 | mask = mask / 255.0
63 | mask = np.transpose(mask, [2, 0, 1])
64 | mask_tensor = torch.tensor(mask)
65 |
66 | input_tensor = torch.tensor(input_resized)
67 | gt_tensor = torch.tensor(gt)
68 | task_id_tensor = torch.tensor(task_id, dtype=torch.int64)
69 |
70 | '''load flow and face mask for the portrait task'''
71 | if(task_id == 5):
72 | flow = np.load(self.flows[index])
73 | flow = flow.astype(dtype=np.float32)
74 | flow = np.transpose(flow, [2, 0, 1])
75 | flow_tensor = torch.tensor(flow)
76 |
77 | face_mask_path = self.input_images[index].replace('/input/', '/mask_face/')
78 | facemask = cv2.imread(face_mask_path, 0)
79 | facemask = facemask.astype(dtype=np.float32)
80 | facemask = (facemask / 255.0)
81 | facemask = np.expand_dims(facemask, axis=-1)
82 | facemask = np.transpose(facemask, [2, 0, 1])
83 | face_mask = torch.tensor(facemask)
84 | mask_sum = torch.sum(face_mask)
85 | weight = self.width * self.height / mask_sum - 1
86 | weight = torch.max(weight / 3, torch.ones(1))
87 | face_weight = weight.unsqueeze(-1).unsqueeze(-1)
88 | else:
89 | flow_tensor = torch.zeros_like(input_tensor[0:2, :, :])
90 | face_mask = torch.zeros_like(mask_tensor)
91 | face_weight = torch.mean(face_mask).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
92 |
93 | return (input_tensor, gt_tensor, mask_tensor, task_id_tensor, flow_tensor, face_mask, face_weight)
94 |
95 | def __len__(self):
96 | return len(self.input_images)
97 |
98 | class TestDatasetMask(Dataset):
99 | def __init__(self, data_path, task_id):
100 | self.width = 256
101 | self.height = 256
102 | self.test_path = data_path
103 | self.datas = OrderedDict()
104 | self.task_id = task_id
105 |
106 | datas = glob.glob(os.path.join(self.test_path, '*'))
107 | for data in sorted(datas):
108 | data_name = data.split('/')[-1]
109 | if data_name == 'input' or data_name == 'gt' or data_name == 'mask':
110 | self.datas[data_name] = {}
111 | self.datas[data_name]['path'] = data
112 | self.datas[data_name]['image'] = glob.glob(os.path.join(data, '*.*'))
113 | self.datas[data_name]['image'].sort()
114 |
115 | def __getitem__(self, index):
116 |
117 | input = cv2.imread(self.datas['input']['image'][index])
118 | input1 = input.astype(dtype=np.float32)
119 | input1 = input1 / 255.0
120 | input1 = np.transpose(input1, [2, 0, 1])
121 |
122 | input2 = cv2.resize(input, (self.width, self.height))
123 | input2 = input2.astype(dtype=np.float32)
124 | input2 = input2 / 255.0
125 | input2 = np.transpose(input2, [2, 0, 1])
126 |
127 | gt = cv2.imread(self.datas['gt']['image'][index])
128 | gt1 = gt.astype(dtype=np.float32)
129 | gt1 = gt1 / 255.0
130 | gt1 = np.transpose(gt1, [2, 0, 1])
131 |
132 | gt2 = cv2.resize(gt, (self.width, self.height))
133 | gt2 = gt2.astype(dtype=np.float32)
134 | gt2 = gt2 / 255.0
135 | gt2 = np.transpose(gt2, [2, 0, 1])
136 |
137 | mask = cv2.imread(self.datas['mask']['image'][index], cv2.IMREAD_GRAYSCALE)
138 | mask = cv2.resize(mask, (self.width, self.height))
139 | mask = np.expand_dims(mask, axis=-1)
140 | mask = mask.astype(dtype=np.float32)
141 | mask = mask / 255.0
142 | mask = np.transpose(mask, [2, 0, 1])
143 |
144 | input1_tensor = torch.tensor(input1)
145 | input2_tensor = torch.tensor(input2)
146 | gt1_tensor = torch.tensor(gt1)
147 | gt2_tensor = torch.tensor(gt2)
148 | mask_tensor = torch.tensor(mask)
149 | task_id_tensor = torch.tensor(self.task_id, dtype=torch.int64)
150 | file_name = os.path.basename(self.datas['input']['image'][index])
151 | file_name, _ = os.path.splitext(file_name)
152 |
153 | out_dict = {}
154 | out_dict.update(input1_tensor=input1_tensor, input2_tensor=input2_tensor, gt1_tensor=gt1_tensor,
155 | gt2_tensor=gt2_tensor, mask_tensor=mask_tensor, task_id_tensor=task_id_tensor, file_name=file_name)
156 |
157 | return out_dict
158 |
159 | def __len__(self):
160 | return len(self.datas['input']['image'])
--------------------------------------------------------------------------------
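
A minimal sketch (hypothetical dataset roots) of constructing the training loader. Each root is expected to contain input/, gt/ and mask/ folders (plus flow_npy/ and mask_face/ for the portrait task); task IDs follow the list order, with the portrait task expected at index 5:

```python
from torch.utils.data import DataLoader
from dataset_loaders import TrainDataset

train_paths = [f'/data/warping/task{i}/training' for i in range(6)]  # hypothetical paths; portrait last
dataset = TrainDataset(train_paths)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
input_t, gt_t, mask_t, task_id, flow_t, face_mask, face_weight = next(iter(loader))
```
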
/utils/torch_tps_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 | import torch.nn.functional as F
5 |
6 | def transformer(U, source, target, out_size):
7 | """
8 | Thin Plate Spline Spatial Transformer Layer
9 | TPS control points are arranged in arbitrary positions given by `source`.
10 | U : float Tensor [num_batch, height, width, num_channels].
11 | Input Tensor.
12 | source : float Tensor [num_batch, num_point, 2]
13 | The source position of the control points.
14 | target : float Tensor [num_batch, num_point, 2]
15 | The target position of the control points.
16 | out_size: tuple of two integers [height, width]
17 | The size of the output of the network (height, width)
18 | """
19 |
20 | def _repeat(x, n_repeats):
21 | rep = torch.ones([n_repeats, ]).unsqueeze(0)
22 | rep = rep.int()
23 | x = x.int()
24 |
25 | x = torch.matmul(x.reshape([-1,1]), rep)
26 | return x.reshape([-1])
27 |
28 | def _interpolate(im, x, y, out_size):
29 | num_batch, num_channels , height, width = im.size()
30 |
31 | height_f = height
32 | width_f = width
33 | out_height, out_width = out_size[0], out_size[1]
34 |
35 | zero = 0
36 | max_y = height - 1
37 | max_x = width - 1
38 |
39 | x = (x + 1.0)*(width_f) / 2.0
40 | y = (y + 1.0) * (height_f) / 2.0
41 |
42 | # sampling
43 | x0 = torch.floor(x).int()
44 | x1 = x0 + 1
45 | y0 = torch.floor(y).int()
46 | y1 = y0 + 1
47 |
48 | x0 = torch.clamp(x0, zero, max_x)
49 | x1 = torch.clamp(x1, zero, max_x)
50 | y0 = torch.clamp(y0, zero, max_y)
51 | y1 = torch.clamp(y1, zero, max_y)
52 | dim2 = torch.from_numpy( np.array(width) )
53 | dim1 = torch.from_numpy( np.array(width * height) )
54 |
55 | base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width)
56 | if torch.cuda.is_available():
57 | dim2 = dim2.cuda()
58 | dim1 = dim1.cuda()
59 | y0 = y0.cuda()
60 | y1 = y1.cuda()
61 | x0 = x0.cuda()
62 | x1 = x1.cuda()
63 | base = base.cuda()
64 | base_y0 = base + y0 * dim2
65 | base_y1 = base + y1 * dim2
66 | idx_a = base_y0 + x0
67 | idx_b = base_y1 + x0
68 | idx_c = base_y0 + x1
69 | idx_d = base_y1 + x1
70 |
71 | # channels dim
72 | im = im.permute(0,2,3,1)
73 | im_flat = im.reshape([-1, num_channels]).float()
74 |
75 |
76 | idx_a = idx_a.unsqueeze(-1).long()
77 | idx_a = idx_a.expand(out_height * out_width * num_batch,num_channels)
78 | Ia = torch.gather(im_flat, 0, idx_a)
79 |
80 | idx_b = idx_b.unsqueeze(-1).long()
81 | idx_b = idx_b.expand(out_height * out_width * num_batch, num_channels)
82 | Ib = torch.gather(im_flat, 0, idx_b)
83 |
84 | idx_c = idx_c.unsqueeze(-1).long()
85 | idx_c = idx_c.expand(out_height * out_width * num_batch, num_channels)
86 | Ic = torch.gather(im_flat, 0, idx_c)
87 |
88 | idx_d = idx_d.unsqueeze(-1).long()
89 | idx_d = idx_d.expand(out_height * out_width * num_batch, num_channels)
90 | Id = torch.gather(im_flat, 0, idx_d)
91 |
92 | x0_f = x0.float()
93 | x1_f = x1.float()
94 | y0_f = y0.float()
95 | y1_f = y1.float()
96 |
97 | wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
98 | wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
99 | wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
100 | wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
101 | output = wa*Ia+wb*Ib+wc*Ic+wd*Id
102 |
103 | return output
104 |
105 | def _meshgrid(height, width, source):
106 |
107 | x_t = torch.matmul(torch.ones([height, 1]), torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 0))
108 | y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1), torch.ones([1, width]))
109 | if torch.cuda.is_available():
110 | x_t = x_t.cuda()
111 | y_t = y_t.cuda()
112 |
113 | x_t_flat = x_t.reshape([1, 1, -1])
114 | y_t_flat = y_t.reshape([1, 1, -1])
115 |
116 | num_batch = source.size()[0]
117 | px = torch.unsqueeze(source[:,:,0], 2)
118 | py = torch.unsqueeze(source[:,:,1], 2)
119 | if torch.cuda.is_available():
120 | px = px.cuda()
121 | py = py.cuda()
122 | d2 = torch.square(x_t_flat - px) + torch.square(y_t_flat - py)
123 | r = d2 * torch.log(d2 + 1e-6)
124 | x_t_flat_g = x_t_flat.expand(num_batch, -1, -1)
125 | y_t_flat_g = y_t_flat.expand(num_batch, -1, -1)
126 | ones = torch.ones_like(x_t_flat_g)
127 | if torch.cuda.is_available():
128 | ones = ones.cuda()
129 |
130 | grid = torch.cat((ones, x_t_flat_g, y_t_flat_g, r), 1)
131 |
132 | return grid
133 |
134 | def _transform(T, source, input_dim, out_size):
135 | num_batch, num_channels, height, width = input_dim.size()
136 |
137 | out_height, out_width = out_size[0], out_size[1]
138 | grid = _meshgrid(out_height, out_width, source)
139 |
140 | # transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s)
141 | # [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w]
142 | T_g = torch.matmul(T, grid)
143 | x_s = T_g[:,0,:]
144 | y_s = T_g[:,1,:]
145 | x_s_flat = x_s.reshape([-1])
146 | y_s_flat = y_s.reshape([-1])
147 |
148 | input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat,out_size)
149 |
150 | output = input_transformed.reshape([num_batch, out_height, out_width, num_channels])
151 |
152 | output = output.permute(0,3,1,2)
153 | return output
154 |
155 | def _solve_system(source, target):
156 | num_batch = source.size()[0]
157 | num_point = source.size()[1]
158 |
159 | ones = torch.ones(num_batch, num_point, 1).float()
160 | if torch.cuda.is_available():
161 | ones = ones.cuda()
162 | p = torch.cat([ones, source], 2)
163 |
164 | p_1 = p.reshape([num_batch, -1, 1, 3])
165 | p_2 = p.reshape([num_batch, 1, -1, 3])
166 | d2 = torch.sum(torch.square(p_1-p_2), 3)
167 |
168 | r = d2 * torch.log(d2 + 1e-6)
169 |
170 | zeros = torch.zeros(num_batch, 3, 3).float()
171 | if torch.cuda.is_available():
172 | zeros = zeros.cuda()
173 | W_0 = torch.cat((p, r), 2)
174 | W_1 = torch.cat((zeros, p.permute(0,2,1)), 2)
175 | W = torch.cat((W_0, W_1), 1)
176 |
177 | W_inv = torch.inverse(W.type(torch.float64))
178 |
179 | zeros2 = torch.zeros(num_batch, 3, 2)
180 | if torch.cuda.is_available():
181 | zeros2 = zeros2.cuda()
182 | tp = torch.cat((target, zeros2), 1)
183 | T = torch.matmul(W_inv, tp.type(torch.float64))
184 | T = T.permute(0, 2, 1)
185 |
186 | return T.type(torch.float32)
187 |
188 | T = _solve_system(source, target)
189 | output = _transform(T, source, U, out_size)
190 |
191 | return output
192 |
193 | def interpolate_grid(points, scr_w, scr_h, dest_w, dest_h):
194 | sparse_control_points = points.reshape(-1, 2, scr_w, scr_h)
195 | dense_control_points = F.interpolate(sparse_control_points, size=(dest_h, dest_w), mode='bilinear', align_corners=True)
196 | interpolated_points = dense_control_points.reshape(-1, 2*dest_w*dest_h)
197 | return interpolated_points
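198 |
199 | # Minimal, self-contained usage sketch (illustrative only, not part of the training/testing pipeline):
200 | # warp a random image with a slightly perturbed 4x4 control-point grid; coordinates are normalized
201 | # to [-1, 1], matching the convention assumed by _interpolate above.
202 | if __name__ == '__main__':
203 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'  # the helpers move intermediates to CUDA when available
204 |     U = torch.rand(1, 3, 256, 256, device=device)             # [num_batch, num_channels, height, width]
205 |     axis = torch.linspace(-1.0, 1.0, 4, device=device)
206 |     yy, xx = torch.meshgrid(axis, axis, indexing='ij')
207 |     source = torch.stack([xx.reshape(-1), yy.reshape(-1)], dim=1).unsqueeze(0)  # [1, 16, 2]
208 |     target = source + 0.05 * torch.randn_like(source)         # perturbed control points
209 |     warped = transformer(U, source, target, (256, 256))
210 |     print(warped.shape)                                        # torch.Size([1, 3, 256, 256])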
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from model.builder import *
5 | from dataset_loaders import TestDatasetMask
6 | import os
7 | import glob
8 | import numpy as np
9 | import cv2
10 | from skimage.metrics import peak_signal_noise_ratio as psnr
11 | from skimage.metrics import structural_similarity as ssim
12 | from model.network import MOWA
13 | from collections import OrderedDict
14 | from utils import flow_viz
15 | from utils.utils_op import *
16 |
17 | def test(args):
18 |
19 | path = os.path.dirname(os.path.abspath(__file__))
20 | IMG_DIR = os.path.join(path, 'results/', args.method, 'img/')
21 | MESH_DIR = os.path.join(path, 'results/', args.method, 'mesh/')
22 | RES_DIR = os.path.join(path, 'results/', args.method, 'res/')
23 | IMG_DIR_LIST = [os.path.join(IMG_DIR, str(i)) for i in range(len(args.test_path))]
24 | MESH_DIR_LIST = [os.path.join(MESH_DIR, str(i)) for i in range(len(args.test_path))]
25 | RES_DIR_LIST = [os.path.join(RES_DIR, str(i)) for i in range(len(args.test_path))]
26 |
27 | for i in range(len(IMG_DIR_LIST)):
28 | path = IMG_DIR_LIST[i]
29 | if not os.path.exists(path):
30 | os.makedirs(path)
31 | for i in range(len(MESH_DIR_LIST)):
32 | path = MESH_DIR_LIST[i]
33 | if not os.path.exists(path):
34 | os.makedirs(path)
35 | for i in range(len(RES_DIR_LIST)):
36 | path = RES_DIR_LIST[i]
37 | if not os.path.exists(path):
38 | os.makedirs(path)
39 |
40 | test_loader_list = [DataLoader(dataset=TestDatasetMask(test_path, i), batch_size=1, num_workers=4, shuffle=False, drop_last=False) \
41 | for i, test_path in enumerate(args.test_path)]
42 |
43 | '''define the network'''
44 | net = MOWA(img_size=args.input_size, tps_points=args.tps_points, embed_dim=args.embed_dim, win_size=args.win_size,
45 | token_projection=args.token_projection, token_mlp=args.token_mlp, depths=args.depths,
46 | prompt=args.prompt, task_classes=args.task_classes, head_num=args.head_num, shared_head=args.shared_head)
47 |
48 | if torch.cuda.is_available():
49 | torch.cuda.set_device(args.gpu)
50 | device = torch.device('cuda:{}'.format(args.gpu))
51 | net = net.to(device)
52 |
53 | '''load the existing models'''
54 | MODEL_DIR = args.model_path
55 | ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
56 | ckpt_list.sort()
57 | if len(ckpt_list) != 0:
58 | model_path = ckpt_list[-1]
59 | checkpoint = torch.load(model_path)
60 | state_dict = checkpoint["model"]
61 | new_state_dict = OrderedDict()
62 | for k, v in state_dict.items():
63 | name = k[7:] if 'module.' in k else k
64 | new_state_dict[name] = v
65 |
66 | net.load_state_dict(new_state_dict)
67 | print('load model from {}!'.format(model_path))
68 | else:
69 | raise FileNotFoundError(f'No checkpoint found in directory {MODEL_DIR}!')
70 |
71 | print("##################start testing#######################")
72 | net.eval()
73 | test_num = len(test_loader_list)
74 |
75 | for index in range(test_num):
76 | print("Task ID:", index)
77 | NUM = count_files(args.test_path[index])
78 | acc_temp = 0
79 | psnr_img = 0
80 | ssim_img = 0
81 | path_img = str(IMG_DIR_LIST[index]) + '/'
82 | path_grid = str(MESH_DIR_LIST[index]) + '/'
83 | path_res = str(RES_DIR_LIST[index]) + '/'
84 |
85 | with torch.no_grad():
86 | net.eval()
87 | test_loader = test_loader_list[index]
88 | for i, outputs in enumerate(test_loader):
89 |
90 | input1_tensor = outputs['input1_tensor'].float()
91 | input2_tensor = outputs['input2_tensor'].float()
92 | gt1_tensor = outputs['gt1_tensor'].float()
93 | gt2_tensor = outputs['gt2_tensor'].float()
94 | mask_tensor = outputs['mask_tensor'].float()
95 | task_id_tensor = outputs['task_id_tensor'].float()
96 | file_name = outputs['file_name'][0]
97 |
98 | if torch.cuda.is_available():
99 | input1_tensor = input1_tensor.cuda()
100 | input2_tensor = input2_tensor.cuda()
101 | gt1_tensor = gt1_tensor.cuda()
102 | gt2_tensor = gt2_tensor.cuda()
103 | mask_tensor = mask_tensor.cuda()
104 | task_id_tensor = task_id_tensor.cuda()
105 |
106 | ''' parsing the output '''
107 | batch_out = build_model_test(net, input1_tensor, input2_tensor, mask_tensor, args.tps_points, resize_flow=True)
108 | warp_tps, warp_flow, mesh, flow1, flow2, flow3, point_cls = \
109 | [batch_out[key] for key in ['warp_tps', 'warp_flow', 'mesh', 'flow1', 'flow2', 'flow3', 'point_cls']]
110 |
111 | ''' tensor to numpy and post-processing '''
112 | _, c, ori_h, ori_w = input1_tensor.shape
113 | input_np2 = ((input2_tensor[0])*255.0).cpu().detach().numpy().transpose(1,2,0).astype(np.uint8)
114 | gt_np1 = ((gt1_tensor[0])*255.0).cpu().detach().numpy().transpose(1,2,0).astype(np.uint8)
115 | gt_np2 = ((gt2_tensor[0])*255.0).cpu().detach().numpy().transpose(1,2,0).astype(np.uint8)
116 | gt_np2 = cv2.resize(gt_np2, (ori_w, ori_h))
117 |
118 | warp_flow_np = ((warp_flow[0])*255.0).cpu().detach().numpy().transpose(1,2,0).astype(np.uint8)
119 | flow1 = (flow1[0]).cpu().detach().numpy().transpose(1,2,0)
120 | flow2 = (flow2[0]).cpu().detach().numpy().transpose(1,2,0)
121 | flow3 = (flow3[0]).cpu().detach().numpy().transpose(1,2,0)
122 |
123 | flow1 = flow_viz.flow_to_image(flow1)
124 | flow2 = flow_viz.flow_to_image(flow2)
125 | flow3 = flow_viz.flow_to_image(flow3)
126 |
127 | _, point_cls = torch.max(point_cls[0], 0)
128 | acc_temp += (point_cls == task_id_tensor[0]).float().mean().item()
129 | warp_tps_np = ((warp_tps[-1][0])*255.0).cpu().detach().numpy().transpose(1,2,0).astype(np.uint8)
130 | mesh_np = mesh[-1][0].cpu().detach().numpy()
131 | cv2.imwrite(path_img + file_name + "_mesh" + ".jpg", warp_tps_np)
132 | input_with_mesh = draw_mesh_on_warp(input_np2, mesh_np, args.tps_points[-1]-1, args.tps_points[-1]-1)
133 | cv2.imwrite(path_grid + file_name + "_mesh" + ".jpg", input_with_mesh)
134 |
135 | ''' calculate metrics '''
136 | psnr_img += psnr(warp_flow_np, gt_np1, data_range=255)
137 | ssim_img += ssim(warp_flow_np, gt_np1, data_range=255, channel_axis=2)
138 | cv2.imwrite(path_img + file_name + "_flow1.jpg", flow1)
139 | cv2.imwrite(path_img + file_name + "_flow2.jpg", flow2)
140 | cv2.imwrite(path_img + file_name + "_flow3.jpg", flow3)
141 | cv2.imwrite(path_res + file_name + ".jpg", warp_flow_np)
142 | cv2.imwrite(path_img + file_name + "_flow.jpg", warp_flow_np)
143 |
144 | print(f"Validation PSNR: {round(psnr_img / NUM, 4)}, Validation SSIM: {round(ssim_img / NUM, 4)}, Validation Acc: {round(acc_temp / NUM, 4)}")
145 |
146 |
147 |
148 | if __name__=="__main__":
149 |
150 | parser = argparse.ArgumentParser()
151 |
152 | '''Implementation details'''
153 | parser.add_argument('--gpu', type=int, default=0)
154 | parser.add_argument('--batch_size', type=int, default=1)
155 | parser.add_argument('--model_path', type=str, default='model/')
156 | parser.add_argument('--method', type=str, default='method')
157 |
158 | '''Network details'''
159 | parser.add_argument('--input_size', type=int, default=256)
160 | parser.add_argument('--depths', nargs='+', type=int, default=[2, 2, 2, 2, 2, 2, 2, 2, 2], help='depths for transformer layers')
161 | parser.add_argument('--embed_dim', type=int, default=32)
162 | parser.add_argument('--win_size', type=int, default=8)
163 | parser.add_argument('--token_projection', type=str, default='linear')
164 | parser.add_argument('--token_mlp', type=str, default='leff')
165 | parser.add_argument('--prompt', type=bool, default=True)
166 | parser.add_argument('--task_classes', type=int, default=6)
167 | parser.add_argument('--tps_points', nargs='+', type=int, default=[10, 12, 14, 16], help='tps points for regression heads')
168 | parser.add_argument('--head_num', type=int, default=4)
169 | parser.add_argument('--shared_head', type=bool, default=False)
170 |
171 | '''Dataset settings'''
172 |     parser.add_argument('--test_path', nargs='+', type=str, default=['/stitch/test/', '/wide-angle/test/', '/RS_Rec/test/', '/Rotation/test/', '/fisheye/test/', '/portrait/test/'])
173 |
174 | print('<==================== Testing ===================>\n')
175 |
176 | args = parser.parse_args()
177 | print(args)
178 | test(args)
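179 |
180 | # Notes (derived from the code above): for each entry in --test_path, results are written under
181 | #   results/<method>/img/<task_id>/   TPS-warped images and flow visualizations
182 | #   results/<method>/mesh/<task_id>/  inputs overlaid with the regressed TPS mesh
183 | #   results/<method>/res/<task_id>/   flow-warped outputs used for the PSNR/SSIM metrics
184 | # and the loaded checkpoint is the lexicographically last *.pth file found under --model_path.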
--------------------------------------------------------------------------------
/test_portrait.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import torch
4 | from model.builder import *
5 | import os
6 | import glob
7 | import numpy as np
8 | import cv2
9 | from model.network import MOWA
10 | from collections import OrderedDict
11 | from utils.utils_op import *
12 | from tqdm import tqdm
13 |
14 | eps = 1e-6
15 |
16 | def estimation_flowmap(model, img, device, args):
17 | model.eval()
18 | img = cv2.resize(img, (256, 256))
19 | img = img.astype(dtype=np.float32)
20 | img = img / 255.0
21 | img = np.transpose(img, [2, 0, 1])
22 | img = torch.tensor(img)
23 | img = img.unsqueeze(0)
24 | mask = np.ones((256, 256), dtype=np.uint8) * 255
25 | mask = np.expand_dims(mask, axis=-1)
26 | mask = mask.astype(dtype=np.float32)
27 | mask = mask / 255.0
28 | mask = np.transpose(mask, [2, 0, 1])
29 | mask = torch.tensor(mask)
30 | mask = mask.unsqueeze(0)
31 | with torch.no_grad():
32 | img = img.to(device)
33 | mask = mask.to(device)
34 | batch_out = build_model_test(model, img, img, mask, args.tps_points)
35 | output = batch_out['flow3']
36 | output = output.detach().cpu().squeeze(0).numpy()
37 | return output
38 |
39 |
40 | # ----------------------The computation process of face metric ---------------------------------------
41 | def compute_cosine_similarity(preds, gts):
42 | people_num = gts.shape[0]
43 | points_num = gts.shape[1]
44 | similarity_list = []
45 | preds = preds.astype(np.float32)
46 | gts = gts.astype(np.float32)
47 | for people_index in range(people_num):
48 |         # landmark index 63 is the center point of the face, i.e., the tip of the nose
49 | pred_center = preds[people_index, 63, :]
50 | pred = preds[people_index, :, :]
51 | pred = pred - pred_center[None, :]
52 | gt_center = gts[people_index, 63, :]
53 | gt = gts[people_index, :, :]
54 | gt = gt - gt_center[None, :]
55 |
56 | dot = np.sum((pred * gt), axis=1)
57 | pred = np.sqrt(np.sum(pred * pred, axis=1))
58 | gt = np.sqrt(np.sum(gt * gt, axis=1))
59 |
60 | similarity_list_tmp = []
61 | for i in range(points_num):
62 | if i != 63:
63 | similarity = (dot[i] / (pred[i] * gt[i] + eps))
64 | similarity_list_tmp.append(similarity)
65 |
66 | similarity_list.append(np.mean(similarity_list_tmp))
67 |
68 | return np.mean(similarity_list)
69 |
70 |
71 | # --------------------The normalization function -----------------------------------------------------
72 | def normalization(x):
73 | return [(float(i) - min(x)) / float(max(x) - min(x) + eps) for i in x]
74 |
75 |
76 | # -------------------The computation process of line metric-------------------------------------------
77 | def compute_line_slope_difference(pred_line, gt_k):
78 | scores = []
79 | for i in range(pred_line.shape[0] - 1):
80 | pk = (pred_line[i + 1, 1] - pred_line[i, 1]) / (pred_line[i + 1, 0] - pred_line[i, 0] + eps)
81 | score = np.abs(pk - gt_k)
82 | scores.append(score)
83 | scores_norm = normalization(scores)
84 | score = np.mean(scores_norm)
85 | score = 1 - score
86 | return score
87 |
88 |
89 | # ------------------------- Compute the face and line metrics from the predicted flow map --------------------------
90 | def compute_ori2shape_face_line_metric(model, oriimg_paths, device, args):
91 | line_all_sum_pred = []
92 | face_all_sum_pred = []
93 |
94 | for oriimg_path in tqdm(oriimg_paths):
95 | # Get the [Source image]
96 |         ori_img = cv2.imread(oriimg_path) # Read the original image
97 |         ori_height, ori_width, _ = ori_img.shape # get the size of the original image
98 | input = ori_img.copy() # get the image as the input of our model
99 |
100 |         # Get the [flow map]
101 | pred = estimation_flowmap(model, input, device, args)
102 | pflow = pred.transpose(1, 2, 0)
103 | predflow_x, predflow_y = pflow[:, :, 0], pflow[:, :, 1]
104 |
105 | scale_x = ori_width / predflow_x.shape[1]
106 | scale_y = ori_height / predflow_x.shape[0]
107 | predflow_x = cv2.resize(predflow_x, (ori_width, ori_height)) * scale_x
108 | predflow_y = cv2.resize(predflow_y, (ori_width, ori_height)) * scale_y
109 |
110 |         # Get the [predicted image]
111 | ys, xs = np.mgrid[:ori_height, :ori_width]
112 | mesh_x = predflow_x.astype("float32") + xs.astype("float32")
113 | mesh_y = predflow_y.astype("float32") + ys.astype("float32")
114 | pred_out = cv2.remap(input, mesh_x, mesh_y, cv2.INTER_LINEAR)
115 | cv2.imwrite(oriimg_path.replace(".jpg", "_pred.jpg"), pred_out)
116 |
117 | # Get the landmarks from the [gt image]
118 | stereo_lmk_file = open(oriimg_path.replace(".jpg", "_stereo_landmark.json"))
119 | stereo_lmk = np.array(json.load(stereo_lmk_file), dtype="float32")
120 |
121 | # Get the landmarks from the [source image]
122 | ori_lmk_file = open(oriimg_path.replace(".jpg", "_landmark.json"))
123 | ori_lmk = np.array(json.load(ori_lmk_file), dtype="float32")
124 |
125 |         # Get the landmarks from the predicted output
126 | out_lmk = np.zeros_like(ori_lmk)
127 | for i in range(ori_lmk.shape[0]):
128 | for j in range(ori_lmk.shape[1]):
129 | x = ori_lmk[i, j, 0]
130 | y = ori_lmk[i, j, 1]
131 | if y < predflow_y.shape[0] and x < predflow_y.shape[1]:
132 | out_lmk[i, j, 0] = x - predflow_x[int(y), int(x)]
133 | out_lmk[i, j, 1] = y - predflow_y[int(y), int(x)]
134 | else:
135 | out_lmk[i, j, 0] = x
136 | out_lmk[i, j, 1] = y
137 |
138 | # Compute the face metric
139 |         face_pred_sim = compute_cosine_similarity(out_lmk, stereo_lmk)
140 | face_all_sum_pred.append(face_pred_sim)
141 | stereo_lmk_file.close()
142 | ori_lmk_file.close()
143 |
144 | # Get the line from the [gt image]
145 | gt_line_file = oriimg_path.replace(".jpg", "_line_lines.json")
146 | lines = json.load(open(gt_line_file))
147 |
148 | # Get the line from the [source image]
149 | ori_line_file = oriimg_path.replace(".jpg", "_lines.json")
150 | ori_lines = json.load(open(ori_line_file))
151 |
152 | # Get the line from the pred out
153 | pred_ori2shape_lines = []
154 | for index, ori_line in enumerate(ori_lines):
155 | ori_line = np.array(ori_line, dtype="float32")
156 | pred_ori2shape = np.zeros_like(ori_line)
157 | for i in range(ori_line.shape[0]):
158 | x = ori_line[i, 0]
159 | y = ori_line[i, 1]
160 | pred_ori2shape[i, 0] = x - predflow_x[int(y), int(x)]
161 | pred_ori2shape[i, 1] = y - predflow_y[int(y), int(x)]
162 | pred_ori2shape = pred_ori2shape.tolist()
163 | pred_ori2shape_lines.append(pred_ori2shape)
164 |
165 | # Compute the lines score
166 | line_pred_ori2shape_sum = []
167 | for index, line in enumerate(lines):
168 | gt_line = np.array(line, dtype="float32")
169 | pred_ori2shape = np.array(pred_ori2shape_lines[index], dtype="float32")
170 | gt_k = (gt_line[1, 1] - gt_line[0, 1]) / (gt_line[1, 0] - gt_line[0, 0] + eps)
171 | pred_ori2shape_score = compute_line_slope_difference(pred_ori2shape, gt_k)
172 | line_pred_ori2shape_sum.append(pred_ori2shape_score)
173 | line_all_sum_pred.append(np.mean(line_pred_ori2shape_sum))
174 |
175 | return np.mean(line_all_sum_pred) * 100, np.mean(face_all_sum_pred) * 100
176 |
177 | def test(args):
178 |
179 | net = MOWA(img_size=args.input_size, tps_points=args.tps_points, embed_dim=args.embed_dim, win_size=args.win_size,
180 | token_projection=args.token_projection, token_mlp=args.token_mlp, depths=args.depths,
181 | prompt=args.prompt, task_classes=args.task_classes, head_num=args.head_num, shared_head=args.shared_head)
182 |
183 | if torch.cuda.is_available():
184 | torch.cuda.set_device(args.gpu)
185 | device = torch.device('cuda:{}'.format(args.gpu))
186 | net = net.to(device)
187 |
188 | MODEL_DIR = args.model_path
189 | ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
190 | ckpt_list.sort()
191 | if len(ckpt_list) != 0:
192 | model_path = ckpt_list[-1]
193 | print(model_path)
194 | checkpoint = torch.load(model_path)
195 | state_dict = checkpoint["model"]
196 | new_state_dict = OrderedDict()
197 | for k, v in state_dict.items():
198 | name = k[7:] if 'module.' in k else k
199 | new_state_dict[name] = v
200 |
201 | net.load_state_dict(new_state_dict)
202 | print('load model from {}!'.format(model_path))
203 | else:
204 | raise FileNotFoundError(f'No checkpoint found in directory {MODEL_DIR}!')
205 |
206 | print("##################start testing#######################")
207 | net.eval()
208 |
209 | oriimg_paths = []
210 | for root, _, files in os.walk(args.test_path):
211 | for file_name in files:
212 | if file_name.endswith(".jpg"):
213 | if "line" not in file_name and "stereo" not in file_name and "pred" not in file_name:
214 | oriimg_paths.append(os.path.join(root, file_name))
215 |
216 | print("The number of images: :", len(oriimg_paths))
217 |
218 | line_score, face_score = compute_ori2shape_face_line_metric(net, oriimg_paths, device, args)
219 | print("Line_score = {:.3f}, Face_score = {:.3f} ".format(line_score, face_score))
220 |
221 |
222 | if __name__=="__main__":
223 |
224 | parser = argparse.ArgumentParser()
225 |
226 | '''Implementation details'''
227 | parser.add_argument('--gpu', type=int, default=0)
228 | parser.add_argument('--batch_size', type=int, default=1)
229 | parser.add_argument('--model_path', type=str, default='model/')
230 | parser.add_argument('--method', type=str, default='method')
231 |
232 | '''Network details'''
233 | parser.add_argument('--input_size', type=int, default=256)
234 | parser.add_argument('--depths', nargs='+', type=int, default=[2, 2, 2, 2, 2, 2, 2, 2, 2], help='depths for transformer layers')
235 | parser.add_argument('--embed_dim', type=int, default=32)
236 | parser.add_argument('--win_size', type=int, default=8)
237 | parser.add_argument('--token_projection', type=str, default='linear')
238 | parser.add_argument('--token_mlp', type=str, default='leff')
239 | parser.add_argument('--prompt', type=bool, default=True)
240 | parser.add_argument('--task_classes', type=int, default=6)
241 | parser.add_argument('--tps_points', nargs='+', type=int, default=[10, 12, 14, 16], help='tps points for regression heads')
242 | parser.add_argument('--head_num', type=int, default=4)
243 | parser.add_argument('--shared_head', type=bool, default=False)
244 |
245 | '''Dataset settings'''
246 | parser.add_argument('--test_path', type=str, default="/Dataset/FaceRec/test_4_3_all/")
247 |
248 | print('<==================== Testing ===================>\n')
249 |
250 | args = parser.parse_args()
251 | print(args)
252 | test(args)
253 |
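254 |
255 | # Notes (derived from the code above): every test image xxx.jpg is expected to have sidecar
256 | # annotations in the same folder:
257 | #   xxx_landmark.json         face landmarks on the source (distorted) image
258 | #   xxx_stereo_landmark.json  face landmarks on the ground-truth ("stereo") image
259 | #   xxx_lines.json            line annotations on the source image
260 | #   xxx_line_lines.json       line annotations on the ground-truth image
261 | # The script writes xxx_pred.jpg next to each input and reports line/face scores scaled by 100.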
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torch.utils.data import DataLoader
4 | import os
5 | import torch.optim as optim
6 | from torch.optim.lr_scheduler import StepLR
7 | from torch.utils.tensorboard import SummaryWriter
8 | from model.builder import *
9 | from model.network import MOWA
10 | from dataset_loaders import TrainDataset
11 | import glob
12 | from model.loss import *
13 | import torchvision.models as models
14 | import torch.multiprocessing as mp
15 | from torch.utils.data.distributed import DistributedSampler
16 | from torch.cuda.amp import autocast
17 | import torch.distributed as dist
18 | from warmup_scheduler import GradualWarmupScheduler
19 | from utils.utils_op import *
20 | torch.backends.cudnn.enabled = True
21 |
22 | def train(gpu, ngpus_per_node, args):
23 |
24 | """ threads running on each GPU """
25 | if args.distributed:
26 | torch.cuda.set_device(int(gpu))
27 | print('using GPU {} for training'.format(int(gpu)))
28 | torch.distributed.init_process_group(backend = 'nccl',
29 | init_method = 'tcp://127.0.0.1:' + args.port,
30 | world_size = ngpus_per_node,
31 | rank = gpu,
32 | group_name='mtorch'
33 | )
34 |
35 | ''' folder settings'''
36 | path = os.path.dirname(os.path.abspath(__file__))
37 | MODEL_DIR = os.path.join(path, 'checkpoint/', args.method)
38 | SUMMARY_DIR = os.path.join(path, 'summary/', args.method)
39 | if dist.get_rank() == 0:
40 | writer = SummaryWriter(log_dir=SUMMARY_DIR)
41 | if not os.path.exists(MODEL_DIR):
42 | os.makedirs(MODEL_DIR)
43 | if not os.path.exists(SUMMARY_DIR):
44 | os.makedirs(SUMMARY_DIR)
45 |
46 | ''' dataloader settings '''
47 | train_dataset = TrainDataset(args.train_path)
48 | train_sampler = DistributedSampler(train_dataset, num_replicas=ngpus_per_node, rank=gpu)
49 | train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, num_workers=4, sampler=train_sampler, drop_last=True)
50 |
51 | ''' define the network and training scheduler'''
52 | net = MOWA(img_size=args.input_size, tps_points=args.tps_points, embed_dim=args.embed_dim, win_size=args.win_size,
53 | token_projection=args.token_projection, token_mlp=args.token_mlp, depths=args.depths,
54 | prompt=args.prompt, task_classes=args.task_classes, head_num=args.head_num, shared_head=args.shared_head)
55 |
56 | vgg_model = models.vgg19(pretrained=True)
57 | net = set_device(net, distributed=args.distributed)
58 | vgg_model = set_device(vgg_model, distributed=args.distributed)
59 | vgg_model.eval()
60 |
61 | optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.weight_decay)
62 | scaler = torch.cuda.amp.GradScaler()
63 | if args.warmup:
64 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_epoch-args.warmup_epochs, eta_min=args.eta_min)
65 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=args.warmup_epochs, after_scheduler=scheduler_cosine)
66 | scheduler.step()
67 | else:
68 | step = 50
69 | print("Using StepLR,step={}!".format(step))
70 | scheduler = StepLR(optimizer, step_size=step, gamma=0.5)
71 | scheduler.step()
72 |
73 | ''' resume training or not'''
74 | ckpt_list = glob.glob(MODEL_DIR + "/*.pth")
75 | ckpt_list.sort()
76 | if len(ckpt_list) != 0:
77 | model_path = ckpt_list[-1]
78 | checkpoint = torch.load(model_path)
79 | net.load_state_dict(checkpoint['model'])
80 | optimizer.load_state_dict(checkpoint['optimizer'])
81 | start_epoch = checkpoint['epoch']
82 | glob_iter = checkpoint['glob_iter']
83 | scheduler.last_epoch = start_epoch
84 | print('load model from {}!'.format(model_path))
85 | else:
86 | start_epoch = 0
87 | glob_iter = 0
88 |         print('training from scratch!')
89 |
90 | ''' network training'''
91 | for epoch in range(start_epoch, args.max_epoch):
92 | train_sampler.set_epoch(epoch)
93 | net.train()
94 | total_loss_sigma = 0.
95 | appearance_loss_sigma = 0.
96 | perception_loss_sigma = 0.
97 | inter_grid_loss_sigma = 0.
98 | point_loss_sigma = 0.
99 | flow_loss_sigma = 0.
100 |
101 | for i, batch_value in enumerate(train_loader):
102 |             input_tensor = batch_value[0].float()
103 |             gt_tensor = batch_value[1].float()
104 | mask_tensor = batch_value[2].float()
105 | task_id_tensor = batch_value[3].float()
106 | flow_tensor = batch_value[4].float()
107 | face_mask = batch_value[5].float()
108 | face_weight = batch_value[6].float()
109 |
110 | if torch.cuda.is_available():
111 |                 input_tensor = set_device(input_tensor, distributed=args.distributed)
112 |                 gt_tensor = set_device(gt_tensor, distributed=args.distributed)
113 | mask_tensor = set_device(mask_tensor, distributed=args.distributed)
114 | task_id_tensor = set_device(task_id_tensor, distributed=args.distributed)
115 | flow_tensor = set_device(flow_tensor, distributed=args.distributed)
116 | face_mask = set_device(face_mask, distributed=args.distributed)
117 | face_weight = set_device(face_weight, distributed=args.distributed)
118 |
119 | optimizer.zero_grad()
120 | with autocast():
121 |                 batch_out = build_model(net.module, input_tensor, input_tensor, mask_tensor, args.tps_points)
122 | warp_tps, mesh, warp_flow, flow, point_cls = \
123 | [batch_out[key] for key in ['warp_tps', 'mesh', 'warp_flow', 'flow3', 'point_cls']]
124 |
125 | ''' calculate losses'''
126 | inter_grid_loss = cal_inter_grid_loss_sum(mesh, args.tps_points, [1.0/args.head_num for _ in range(args.head_num)])
127 | point_loss = cal_point_loss(point_cls, task_id_tensor) * 0.1
128 | face_weight_ad = adjust_weight(epoch, args.max_epoch, face_weight)
129 | flow_tensor, flow = get_weight_mask(face_mask, flow_tensor, flow, weight=face_weight_ad)
130 | flow_loss = mask_flow_loss(flow, flow_tensor, task_id_tensor) * 0.1
131 |
132 | if(epoch <= 10):
133 |                     appearance_loss = cal_appearance_loss_sum(warp_tps, gt_tensor, args.img_weight1)
134 |                     perception_loss = cal_perception_loss_sum(vgg_model, warp_tps, gt_tensor, args.img_weight1)
135 | total_loss = appearance_loss + perception_loss + inter_grid_loss + point_loss
136 | else:
137 |                     appearance_loss = cal_appearance_loss_sum(warp_tps+[warp_flow], gt_tensor, args.img_weight1+args.img_weight2)
138 |                     perception_loss = cal_perception_loss_sum(vgg_model, warp_tps+[warp_flow], gt_tensor, args.img_weight1+args.img_weight2)
139 | total_loss = appearance_loss + perception_loss + inter_grid_loss + flow_loss + point_loss
140 |
141 | scaler.scale(total_loss).backward()
142 | scaler.step(optimizer)
143 | scaler.update()
144 |
145 | total_loss_sigma += total_loss.item()
146 | appearance_loss_sigma += appearance_loss.item()
147 | perception_loss_sigma += perception_loss.item()
148 | inter_grid_loss_sigma += inter_grid_loss.item()
149 | point_loss_sigma += point_loss.item()
150 | flow_loss_sigma += flow_loss.item()
151 |
152 |             ''' writing training logs '''
153 | if i % args.print_interval == 0 and i != 0:
154 | if dist.get_rank() == 0:
155 | total_loss_average = total_loss_sigma / args.print_interval
156 | appearance_loss_average = appearance_loss_sigma/ args.print_interval
157 | perception_loss_average = perception_loss_sigma/ args.print_interval
158 | inter_grid_loss_average = inter_grid_loss_sigma/ args.print_interval
159 | point_loss_average = point_loss_sigma/ args.print_interval
160 | flow_loss_average = flow_loss_sigma/ args.print_interval
161 |
162 | total_loss_sigma = 0.
163 | appearance_loss_sigma = 0.
164 | perception_loss_sigma = 0.
165 | inter_grid_loss_sigma = 0.
166 | point_loss_sigma = 0.
167 | flow_loss_sigma = 0.
168 |
169 | print(f"Training: Epoch[{epoch + 1:0>3}/{args.max_epoch:0>3}] "
170 | f"Iteration[{i + 1:0>3}/{len(train_loader):0>3}] "
171 | f"Total Loss: {total_loss_average:.4f} "
172 | f"Appearance Loss: {appearance_loss_average:.4f} "
173 | f"Perception Loss: {perception_loss_average:.4f} "
174 | f"Point Loss: {point_loss_average:.4f} "
175 | f"Flow Loss: {flow_loss_average:.4f} "
176 | f"Inter-Grid Loss: {inter_grid_loss_average:.4f} "
177 | f"lr={optimizer.state_dict()['param_groups'][0]['lr']:.8f}")
178 |
179 | writer.add_image("input", (input_tesnor[0]), glob_iter)
180 | writer.add_image("rectangling", (warp_flow[0]), glob_iter)
181 | writer.add_image("gt", (gt_tesnor[0]), glob_iter)
182 | writer.add_scalar('lr', optimizer.state_dict()['param_groups'][0]['lr'], glob_iter)
183 | writer.add_scalar('total loss', total_loss_average, glob_iter)
184 | writer.add_scalar('appearance loss', appearance_loss_average, glob_iter)
185 | writer.add_scalar('perception loss', perception_loss_average, glob_iter)
186 | writer.add_scalar('inter-grid loss', inter_grid_loss_average, glob_iter)
187 | writer.add_scalar('point loss', point_loss_average, glob_iter)
188 | writer.add_scalar('flow loss', flow_loss_average, glob_iter)
189 |
190 | glob_iter += 1
191 |
192 | ''' save models '''
193 | if ((epoch+1) % args.eva_interval == 0 or (epoch+1)==args.max_epoch):
194 | if dist.get_rank() == 0:
195 | filename ='epoch' + str(epoch+1).zfill(3) + '_model.pth'
196 | model_save_path = os.path.join(MODEL_DIR, filename)
197 | state = {'model': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch+1, "glob_iter": glob_iter}
198 | torch.save(state, model_save_path)
199 |
200 | scheduler.step()
201 |
202 |
203 | if __name__=="__main__":
204 |
205 | print('<==================== setting arguments ===================>\n')
206 |
207 | parser = argparse.ArgumentParser()
208 | '''Implementation details'''
209 | parser.add_argument('-gpu', '--gpu_ids', type=str, default='0')
210 | parser.add_argument('-b', '--batch_size', type=int, default=8)
211 | parser.add_argument('--max_epoch', type=int, default=300)
212 | parser.add_argument('-m', '--method', type=str, default='mowa')
213 | parser.add_argument('-P', '--port', default='21016', type=str)
214 | parser.add_argument('-d', '--distributed', type=bool, default=True)
215 | parser.add_argument('-w', '--warmup', type=bool, default=True)
216 | parser.add_argument('--warmup_epochs', type=int, default=3, help='epochs for warmup')
217 | parser.add_argument('--print_interval', type=int, default=160)
218 | parser.add_argument('--eva_interval', type=int, default=10)
219 | parser.add_argument("--lr", type=float, default=2e-4, help="start learning rate")
220 | parser.add_argument("--eta_min", type=float, default=1e-7, help="final learning rate")
221 | parser.add_argument("--weight_decay", type=float, default=0, help="weight decay of the optimizer")
222 | parser.add_argument('--img_weight1', nargs='+', type=float, default=[0.1, 0.1, 0.2, 0.5], help='weights for img loss (stage1)')
223 | parser.add_argument('--img_weight2', nargs='+', type=float, default=[0.5], help='weights for img loss (stage2)')
224 |
225 | '''Network details'''
226 | parser.add_argument('--input_size', type=int, default=256)
227 | parser.add_argument('--depths', nargs='+', type=int, default=[2, 2, 2, 2, 2, 2, 2, 2, 2], help='depths for transformer layers')
228 | parser.add_argument('--tps_points', nargs='+', type=int, default=[10, 12, 14, 16], help='tps points for regression heads')
229 | parser.add_argument('--embed_dim', type=int, default=32)
230 | parser.add_argument('--win_size', type=int, default=8)
231 | parser.add_argument('--token_projection', type=str, default='linear')
232 | parser.add_argument('--token_mlp', type=str, default='leff')
233 | parser.add_argument('--prompt', type=bool, default=True)
234 | parser.add_argument('--task_classes', type=int, default=6)
235 | parser.add_argument('--head_num', type=int, default=4)
236 | parser.add_argument('--shared_head', type=bool, default=False)
237 |
238 | '''Dataset settings'''
239 |     parser.add_argument('--train_path', nargs='+', type=str, default=['/Dataset/pano-rectangling/train/', '/Dataset/wide-angle_rectangling/train/',
240 | '/Dataset/RS_Rec/RS_Rec/train/', '/Dataset/Rotation/train/',
241 | '/Dataset/fisheye/train/', '/Dataset/FaceRec/train/'])
242 |
243 | args = parser.parse_args()
244 | print(args)
245 |
246 | gpu_str = args.gpu_ids
247 | gpu_ids = [int(id) for id in args.gpu_ids.split(',')]
248 | num_gpus = len(gpu_ids)
249 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str
250 | print('export CUDA_VISIBLE_DEVICES={}'.format(gpu_str))
251 | opt=0
252 | print('<==================== start training ===================>\n')
253 | mp.spawn(train, nprocs=num_gpus, args=(num_gpus, args))
254 |
255 | print("################## end training #######################")
--------------------------------------------------------------------------------
/model/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5 | import torch.nn.functional as F
6 | from einops import rearrange, repeat
7 | import math
8 | from utils.utils_module import *
9 | from utils.utils_transform import *
10 |
11 | class FastLeFF(nn.Module):
12 |
13 | def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop = 0.):
14 | super().__init__()
15 |
16 | from torch_dwconv import DepthwiseConv2d
17 |
18 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim),
19 | act_layer())
20 | self.dwconv = nn.Sequential(DepthwiseConv2d(hidden_dim, hidden_dim, kernel_size=3,stride=1,padding=1),
21 | act_layer())
22 | self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
23 | self.dim = dim
24 | self.hidden_dim = hidden_dim
25 |
26 | def forward(self, x):
27 | # bs x hw x c
28 | bs, hw, c = x.size()
29 | hh = int(math.sqrt(hw))
30 |
31 | x = self.linear1(x)
32 |
33 | # spatial restore
34 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h = hh, w = hh)
35 | # bs,hidden_dim,32x32
36 |
37 | x = self.dwconv(x)
38 |
39 |         # flatten
40 | x = rearrange(x, ' b c h w -> b (h w) c', h = hh, w = hh)
41 |
42 | x = self.linear2(x)
43 |
44 | return x
45 |
46 | class SELayer(nn.Module):
47 | def __init__(self, channel, reduction=16):
48 | super(SELayer, self).__init__()
49 | self.avg_pool = nn.AdaptiveAvgPool1d(1)
50 | self.fc = nn.Sequential(
51 | nn.Linear(channel, channel // reduction, bias=False),
52 | nn.ReLU(inplace=True),
53 | nn.Linear(channel // reduction, channel, bias=False),
54 | nn.Sigmoid()
55 | )
56 |
57 | def forward(self, x): # x: [B, N, C]
58 | x = torch.transpose(x, 1, 2) # [B, C, N]
59 | b, c, _ = x.size()
60 | y = self.avg_pool(x).view(b, c)
61 | y = self.fc(y).view(b, c, 1)
62 | x = x * y.expand_as(x)
63 | x = torch.transpose(x, 1, 2) # [B, N, C]
64 | return x
65 |
66 | class SepConv2d(torch.nn.Module):
67 | def __init__(self,
68 | in_channels,
69 | out_channels,
70 | kernel_size,
71 | stride=1,
72 | padding=0,
73 | dilation=1, act_layer=nn.ReLU):
74 | super(SepConv2d, self).__init__()
75 | self.depthwise = torch.nn.Conv2d(in_channels,
76 | in_channels,
77 | kernel_size=kernel_size,
78 | stride=stride,
79 | padding=padding,
80 | dilation=dilation,
81 | groups=in_channels)
82 | self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)
83 | self.act_layer = act_layer() if act_layer is not None else nn.Identity()
84 |
85 | def forward(self, x):
86 | x = self.depthwise(x)
87 | x = self.act_layer(x)
88 | x = self.pointwise(x)
89 | return x
90 |
91 | class ConvProjection(nn.Module):
92 | def __init__(self, dim, heads=8, dim_head=64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1, dropout=0.,
93 | last_stage=False, bias=True):
94 | super().__init__()
95 |
96 | inner_dim = dim_head * heads
97 | self.heads = heads
98 | pad = (kernel_size - q_stride) // 2
99 | self.to_q = SepConv2d(dim, inner_dim, kernel_size, q_stride, pad, bias)
100 | self.to_k = SepConv2d(dim, inner_dim, kernel_size, k_stride, pad, bias)
101 | self.to_v = SepConv2d(dim, inner_dim, kernel_size, v_stride, pad, bias)
102 |
103 | def forward(self, x, attn_kv=None):
104 | b, n, c, h = *x.shape, self.heads
105 | l = int(math.sqrt(n))
106 | w = int(math.sqrt(n))
107 |
108 | attn_kv = x if attn_kv is None else attn_kv
109 | x = rearrange(x, 'b (l w) c -> b c l w', l=l, w=w)
110 | attn_kv = rearrange(attn_kv, 'b (l w) c -> b c l w', l=l, w=w)
111 | # print(attn_kv)
112 | q = self.to_q(x)
113 | q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h)
114 |
115 | k = self.to_k(attn_kv)
116 | v = self.to_v(attn_kv)
117 | k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h)
118 | v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h)
119 | return q, k, v
120 |
121 | class LinearProjection(nn.Module):
122 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True):
123 | super().__init__()
124 | inner_dim = dim_head * heads
125 | self.heads = heads
126 | self.to_q = nn.Linear(dim, inner_dim, bias=bias)
127 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias)
128 |
129 | def forward(self, x, attn_kv=None):
130 | B_, N, C = x.shape
131 | attn_kv = x if attn_kv is None else attn_kv
132 | q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
133 | kv = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
134 | q = q[0]
135 | k, v = kv[0], kv[1]
136 | return q, k, v
137 |
138 | class LinearProjection_Concat_kv(nn.Module):
139 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True):
140 | super().__init__()
141 | inner_dim = dim_head * heads
142 | self.heads = heads
143 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=bias)
144 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias)
145 |
146 | def forward(self, x, attn_kv=None):
147 | B_, N, C = x.shape
148 | attn_kv = x if attn_kv is None else attn_kv
149 | qkv_dec = self.to_qkv(x).reshape(B_, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
150 | kv_enc = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
151 | q, k_d, v_d = qkv_dec[0], qkv_dec[1], qkv_dec[2] # make torchscript happy (cannot use tensor as tuple)
152 | k_e, v_e = kv_enc[0], kv_enc[1]
153 | k = torch.cat((k_d, k_e), dim=2)
154 | v = torch.cat((v_d, v_e), dim=2)
155 | return q, k, v
156 |
157 | class WindowAttention(nn.Module):
158 | def __init__(self, dim, win_size, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0.,
159 | proj_drop=0., se_layer=False):
160 |
161 | super().__init__()
162 | self.dim = dim
163 | self.win_size = win_size # Wh, Ww
164 | self.num_heads = num_heads
165 | head_dim = dim // num_heads
166 | self.scale = qk_scale or head_dim ** -0.5
167 |
168 | # define a parameter table of relative position bias
169 | self.relative_position_bias_table = nn.Parameter(
170 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
171 |
172 | # get pair-wise relative position index for each token inside the window
173 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1]
174 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1]
175 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
176 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
177 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
178 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
179 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0
180 | relative_coords[:, :, 1] += self.win_size[1] - 1
181 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
182 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
183 | self.register_buffer("relative_position_index", relative_position_index)
184 |
185 | if token_projection == 'conv':
186 | self.qkv = ConvProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
187 | elif token_projection == 'linear_concat':
188 | self.qkv = LinearProjection_Concat_kv(dim, num_heads, dim // num_heads, bias=qkv_bias)
189 | else:
190 | self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
191 |
192 | self.attn_drop = nn.Dropout(attn_drop)
193 | self.proj = nn.Linear(dim, dim)
194 | self.se_layer = SELayer(dim) if se_layer else nn.Identity()
195 | self.proj_drop = nn.Dropout(proj_drop)
196 |
197 | trunc_normal_(self.relative_position_bias_table, std=.02)
198 | self.softmax = nn.Softmax(dim=-1)
199 |
200 | def forward(self, x, attn_kv=None, mask=None):
201 | B_, N, C = x.shape
202 | q, k, v = self.qkv(x, attn_kv)
203 | q = q * self.scale
204 | attn = (q @ k.transpose(-2, -1))
205 |
206 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
207 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH
208 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
209 | ratio = attn.size(-1) // relative_position_bias.size(-1)
210 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
211 |
212 | attn = attn + relative_position_bias.unsqueeze(0)
213 |
214 | if mask is not None:
215 | nW = mask.shape[0]
216 | mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
217 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
218 | attn = attn.view(-1, self.num_heads, N, N * ratio)
219 | attn = self.softmax(attn)
220 | else:
221 | attn = self.softmax(attn)
222 |
223 | attn = self.attn_drop(attn)
224 |
225 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
226 | x = self.proj(x)
227 | x = self.se_layer(x)
228 | x = self.proj_drop(x)
229 | return x
230 |
231 | def extra_repr(self) -> str:
232 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}'
233 |
234 | class Attention(nn.Module):
235 | def __init__(self, dim,num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
236 |
237 | super().__init__()
238 | self.dim = dim
239 | self.num_heads = num_heads
240 | head_dim = dim // num_heads
241 | self.scale = qk_scale or head_dim ** -0.5
242 |
243 | self.qkv = LinearProjection(dim,num_heads,dim//num_heads,bias=qkv_bias)
244 |
245 | self.token_projection = token_projection
246 | self.attn_drop = nn.Dropout(attn_drop)
247 | self.proj = nn.Linear(dim, dim)
248 | self.proj_drop = nn.Dropout(proj_drop)
249 |
250 | self.softmax = nn.Softmax(dim=-1)
251 |
252 | def forward(self, x, attn_kv=None, mask=None):
253 | B_, N, C = x.shape
254 | q, k, v = self.qkv(x,attn_kv)
255 | q = q * self.scale
256 | attn = (q @ k.transpose(-2, -1))
257 |
258 | if mask is not None:
259 | nW = mask.shape[0]
260 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
261 | attn = attn.view(-1, self.num_heads, N, N)
262 | attn = self.softmax(attn)
263 | else:
264 | attn = self.softmax(attn)
265 |
266 | attn = self.attn_drop(attn)
267 |
268 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
269 | x = self.proj(x)
270 | x = self.proj_drop(x)
271 | return x
272 |
273 | def extra_repr(self) -> str:
274 | return f'dim={self.dim}, num_heads={self.num_heads}'
275 |
276 | class Mlp(nn.Module):
277 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
278 | super().__init__()
279 | out_features = out_features or in_features
280 | hidden_features = hidden_features or in_features
281 | self.fc1 = nn.Linear(in_features, hidden_features)
282 | self.act = act_layer()
283 | self.fc2 = nn.Linear(hidden_features, out_features)
284 | self.drop = nn.Dropout(drop)
285 |
286 | def forward(self, x):
287 | x = self.fc1(x)
288 | x = self.act(x)
289 | x = self.drop(x)
290 | x = self.fc2(x)
291 | x = self.drop(x)
292 | return x
293 |
294 |
295 | class LeFF(nn.Module):
296 | def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0.):
297 | super().__init__()
298 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim),
299 | act_layer())
300 | self.dwconv = nn.Sequential(
301 | nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1),
302 | act_layer())
303 | self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
304 | self.drop = nn.Dropout(drop)
305 |
306 | def forward(self, x):
307 | # bs x hw x c
308 | bs, hw, c = x.size()
309 | hh = int(math.sqrt(hw))
310 |
311 | x = self.linear1(x)
312 |
313 | # spatial restore
314 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h=hh, w=hh)
315 | # bs,hidden_dim,32x32
316 |
317 | x = self.dwconv(x)
318 |
319 |         # flatten
320 | x = rearrange(x, ' b c h w -> b (h w) c', h=hh, w=hh)
321 |
322 | x = self.linear2(x)
323 | x = self.drop(x)
324 |
325 | return x
326 |
327 | def window_partition(x, win_size):
328 | B, H, W, C = x.shape
329 | x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
330 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)
331 | return windows
332 |
333 |
334 | def window_reverse(windows, win_size, H, W):
335 | B = int(windows.shape[0] / (H * W / win_size / win_size))
336 | x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
337 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
338 | return x
339 |
340 | class Downsample(nn.Module):
341 | def __init__(self, in_channel, out_channel):
342 | super(Downsample, self).__init__()
343 | self.conv = nn.Sequential(
344 | nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1),
345 |
346 | )
347 |
348 | def forward(self, x):
349 | B, L, C = x.shape
350 | H = int(math.sqrt(L))
351 | W = int(math.sqrt(L))
352 | x = x.transpose(1, 2).contiguous().view(B, C, H, W)
353 | out = self.conv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
354 | return out
355 |
356 | class Upsample(nn.Module):
357 | def __init__(self, in_channel, out_channel):
358 | super(Upsample, self).__init__()
359 | self.deconv = nn.Sequential(
360 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
361 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
362 | )
363 |
364 | def forward(self, x):
365 | B, L, C = x.shape
366 | H = int(math.sqrt(L))
367 | W = int(math.sqrt(L))
368 | x = x.transpose(1, 2).contiguous().view(B, C, H, W)
369 | out = self.deconv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
370 | return out
371 |
372 | class InputProj(nn.Module):
373 | def __init__(self, in_channel=3, out_channel=64, kernel_size=3, stride=1, norm_layer=None, act_layer=nn.LeakyReLU):
374 | super().__init__()
375 | self.proj = nn.Sequential(
376 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2),
377 | act_layer(inplace=True)
378 | )
379 | if norm_layer is not None:
380 | self.norm = norm_layer(out_channel)
381 | else:
382 | self.norm = None
383 |
384 | def forward(self, x):
385 | B, C, H, W = x.shape
386 | x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
387 | if self.norm is not None:
388 | x = self.norm(x)
389 | return x
390 |
391 | class OutputProj(nn.Module):
392 | def __init__(self, in_channel=64, out_channel=3, kernel_size=3, stride=1, task_classes=4, prompt=False):
393 | super().__init__()
394 | self.proj = nn.Sequential(
395 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2),
396 | )
397 | self.prompt = prompt
398 | if(self.prompt):
399 | self.prompt_task = nn.Parameter(torch.randn(1, task_classes, in_channel, 1, 1))
400 | self.proj_prompt = nn.Sequential(
401 | nn.Conv2d(2*in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2),
402 | )
403 |
404 | def forward(self, x, prompt=None):
405 | B, L, C = x.shape
406 | H = int(math.sqrt(L))
407 | W = int(math.sqrt(L))
408 | x = x.transpose(1, 2).view(B, C, H, W)
409 | if(self.prompt):
410 | detask_prompts = prompt.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_task.unsqueeze(0).repeat(B,1,1,1,1,1).squeeze(1)
411 | detask_prompts = torch.mean(detask_prompts, dim = 1)
412 | detask_prompts = F.interpolate(detask_prompts,(H,W),mode="bilinear")
413 | combo = torch.concat([x, detask_prompts], dim=1)
414 | combo = self.proj_prompt(combo)
415 | else:
416 | combo = self.proj(x)
417 |
418 | return combo
419 |
420 |
421 | class LeWinTransformerBlock(nn.Module):
422 | def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0,
423 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
424 | act_layer=nn.GELU, norm_layer=nn.LayerNorm,token_projection='linear',token_mlp='leff',
425 | prompt=False, task_classes=4):
426 | super().__init__()
427 | self.dim = dim
428 | self.input_resolution = input_resolution
429 | self.num_heads = num_heads
430 | self.win_size = win_size
431 | self.shift_size = shift_size
432 | self.mlp_ratio = mlp_ratio
433 | self.token_mlp = token_mlp
434 | if min(self.input_resolution) <= self.win_size:
435 | self.shift_size = 0
436 | self.win_size = min(self.input_resolution)
437 |         assert 0 <= self.shift_size < self.win_size, "shift_size must be in 0-win_size"
438 |
439 | if prompt:
440 | self.prompt_task = nn.Parameter(torch.randn(1, task_classes, 1, dim))
441 | else:
442 | self.prompt_task = None
443 |
444 | self.norm1 = norm_layer(dim)
445 | self.attn = WindowAttention(
446 | dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
447 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
448 | token_projection=token_projection)
449 |
450 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
451 | self.norm2 = norm_layer(dim)
452 | mlp_hidden_dim = int(dim * mlp_ratio)
453 | if token_mlp in ['ffn','mlp']:
454 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop)
455 | elif token_mlp=='leff':
456 | self.mlp = LeFF(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop)
457 |
458 | elif token_mlp=='fastleff':
459 | self.mlp = FastLeFF(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop)
460 | else:
461 | raise Exception("FFN error!")
462 |
463 | self.chnl_reduce = nn.Conv2d(2*dim,dim,kernel_size=1,bias=False)
464 |
465 |
466 | def with_pos_embed(self, tensor, pos):
467 | return tensor if pos is None else tensor + pos
468 |
469 | def extra_repr(self) -> str:
470 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
471 | f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio},modulator={self.modulator}"
472 |
473 | def forward(self, x, point_cls=None, mask=None):
474 | B, L, C = x.shape
475 | H = int(math.sqrt(L))
476 | W = int(math.sqrt(L))
477 |
478 | ## input mask
479 |         if mask is not None:
480 | input_mask = F.interpolate(mask, size=(H,W)).permute(0,2,3,1)
481 | input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1
482 | attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
483 | attn_mask = attn_mask.unsqueeze(2)*attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size
484 | attn_mask = attn_mask.masked_fill(attn_mask!=0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
485 | else:
486 | attn_mask = None
487 |
488 | ## shift mask
489 | if self.shift_size > 0:
490 | # calculate attention mask for SW-MSA
491 | shift_mask = torch.zeros((1, H, W, 1)).type_as(x)
492 | h_slices = (slice(0, -self.win_size),
493 | slice(-self.win_size, -self.shift_size),
494 | slice(-self.shift_size, None))
495 | w_slices = (slice(0, -self.win_size),
496 | slice(-self.win_size, -self.shift_size),
497 | slice(-self.shift_size, None))
498 | cnt = 0
499 | for h in h_slices:
500 | for w in w_slices:
501 | shift_mask[:, h, w, :] = cnt
502 | cnt += 1
503 | shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1
504 | shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
505 | shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(2) # nW, win_size*win_size, win_size*win_size
506 | shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(shift_attn_mask == 0, float(0.0))
507 | attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask
508 |
509 | shortcut = x
510 | x = self.norm1(x)
511 | x = x.view(B, H, W, C)
512 |
513 | # cyclic shift
514 | if self.shift_size > 0:
515 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
516 | else:
517 | shifted_x = x
518 |
519 | # partition windows
520 | x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C
521 | x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C
522 |
523 | if self.prompt_task is not None:
524 | B_win, L_win, C_win = x_windows.shape
525 | M = B_win // B
526 | # (B, prompt_len, 1, 1) * (B, prompt_len, 1, prompt_dim)
527 | detask_prompts = point_cls.unsqueeze(-1).unsqueeze(-1) * self.prompt_task.unsqueeze(0).repeat(B,1,1,1,1).squeeze(1)
528 | # (B, 1, prompt_dim)
529 | detask_prompts = torch.mean(detask_prompts, dim=1)
530 | # (B, N, prompt_dim)
531 | detask_prompts = detask_prompts.unsqueeze(1).repeat(1, L_win, 1, 1).squeeze(2)
532 | # (B*M, N, prompt_dim)
533 | detask_prompts = detask_prompts.repeat(M, 1, 1)
534 | combo = torch.concat([x_windows, detask_prompts], dim=2)
535 | combo = combo.permute(0,2,1).view(-1, 2*C, self.win_size, self.win_size)
536 | combo = self.chnl_reduce(combo)
537 | wmsa_in = combo.view(-1, C, self.win_size * self.win_size).permute(0,2,1)
538 |
539 | else:
540 | wmsa_in = x_windows
541 |
542 | # W-MSA/SW-MSA
543 | attn_windows = self.attn(wmsa_in, mask=attn_mask) # nW*B, win_size*win_size, C
544 |
545 | # merge windows
546 | attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
547 | shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C
548 |
549 | # reverse cyclic shift
550 | if self.shift_size > 0:
551 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
552 | else:
553 | x = shifted_x
554 | x = x.view(B, H * W, C)
555 |
556 | # FFN
557 | x = shortcut + self.drop_path(x)
558 | x = x + self.drop_path(self.mlp(self.norm2(x)))
559 | del attn_mask
560 | return x
561 |
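# ---------------------------------------------------------------------------
# Hedged, standalone sketch (not referenced by the model, and the helper name
# _sw_msa_mask_sketch does not exist in the repository): it re-derives the
# SW-MSA shift mask built in LeWinTransformerBlock.forward above on a tiny
# 8x8 map with win_size=4 and shift_size=2, where the intermediate values are
# easy to inspect. The three h/w slices label nine rectangular regions; token
# pairs that share a shifted window but come from different regions receive an
# additive bias of -100, which zeroes their attention weight after softmax.
def _sw_msa_mask_sketch(H=8, W=8, win_size=4, shift_size=2):
    region = torch.zeros((1, H, W, 1))
    h_slices = (slice(0, -win_size), slice(-win_size, -shift_size), slice(-shift_size, None))
    w_slices = (slice(0, -win_size), slice(-win_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            region[:, h, w, :] = cnt          # labels 0..8, one per region
            cnt += 1
    # window_partition is the helper already used in the block above
    windows = window_partition(region, win_size).view(-1, win_size * win_size)  # (nW, win_size*win_size)
    diff = windows.unsqueeze(1) - windows.unsqueeze(2)                          # 0 where the labels agree
    return diff.masked_fill(diff != 0, float(-100.0)).masked_fill(diff == 0, float(0.0))
# The first (top-left) window lies entirely inside one region, so its mask slice
# is all zeros; windows that straddle several regions contain -100 entries.
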
562 | class BasicUformerLayer(nn.Module):
563 | def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size,
564 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
565 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
566 | token_projection='linear',token_mlp='ffn', shift_flag=True,
567 | prompt=False, task_classes=4):
568 |
569 | super().__init__()
570 | self.dim = dim
571 | self.input_resolution = input_resolution
572 | self.depth = depth
573 | self.use_checkpoint = use_checkpoint
574 | if shift_flag:
575 | self.blocks = nn.ModuleList([
576 | LeWinTransformerBlock(dim=dim, input_resolution=input_resolution,
577 | num_heads=num_heads, win_size=win_size,
578 | shift_size=0 if (i % 2 == 0) else win_size // 2,
579 | mlp_ratio=mlp_ratio,
580 | qkv_bias=qkv_bias, qk_scale=qk_scale,
581 | drop=drop, attn_drop=attn_drop,
582 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
583 | norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp,
584 | prompt=prompt, task_classes=task_classes)
585 | for i in range(depth)])
586 | else:
587 | self.blocks = nn.ModuleList([
588 | LeWinTransformerBlock(dim=dim, input_resolution=input_resolution,
589 | num_heads=num_heads, win_size=win_size,
590 | shift_size=0,
591 | mlp_ratio=mlp_ratio,
592 | qkv_bias=qkv_bias, qk_scale=qk_scale,
593 | drop=drop, attn_drop=attn_drop,
594 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
595 | norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp,
596 | prompt=prompt, task_classes=task_classes)
597 | for i in range(depth)])
598 |
599 | def extra_repr(self) -> str:
600 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
601 |
602 | def forward(self, x, point_cls=None, mask=None):
603 | for blk in self.blocks:
604 | if self.use_checkpoint:
605 | x = checkpoint.checkpoint(blk, x, point_cls)
606 | else:
607 | x = blk(x, point_cls, mask)
608 | return x
609 |
610 | class MOWA(nn.Module):
611 | def __init__(self, img_size=256, in_chans=4,
612 | embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
613 | win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
614 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, drop_motion_rate=0., attn_drop_motion_rate=0.,
615 | norm_layer=nn.LayerNorm, patch_norm=True,
616 | use_checkpoint=False, token_projection='linear', token_mlp='ffn', se_layer=False,
617 | dowsample=Downsample, upsample=Upsample, motion_net=MotionNet_Coord,
618 | shift_flag=True, prompt=False, task_classes=4, tps_points=[8, 10, 12, 14],
619 | correct_encoder=False, head_num=8, shared_head=False, **kwargs):
620 | super().__init__()
621 |
622 | self.num_enc_layers = len(depths) // 2
623 | self.num_dec_layers = len(depths) // 2
624 | self.embed_dim = embed_dim
625 | self.patch_norm = patch_norm
626 | self.mlp_ratio = mlp_ratio
627 | self.token_projection = token_projection
628 | self.mlp = token_mlp
629 | self.win_size = win_size
630 | self.mini_size = img_size // (2 ** 4)
631 | self.down_size = 2 ** 4
632 | self.tps_points = tps_points
633 | self.correct_encoder = correct_encoder
634 | self.prompt = prompt
635 | self.head_num = head_num
636 | self.shared_head = shared_head
637 |
638 | self.pos_drop = nn.Dropout(p=drop_rate)
639 | enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))]
640 | conv_dpr = [drop_path_rate] * depths[4]
641 | dec_dpr = enc_dpr[::-1]
642 |
643 | '''Input/Output Projection'''
644 | self.input_proj = InputProj(in_channel=in_chans, out_channel=embed_dim, kernel_size=3, stride=1,
645 | act_layer=nn.LeakyReLU)
646 | self.output_proj = OutputProj(in_channel=2 * embed_dim, out_channel=2, kernel_size=3, stride=1, task_classes=task_classes, prompt=False)
647 |
648 | '''Encoder'''
649 | self.encoderlayer_0 = BasicUformerLayer(dim=embed_dim,
650 | output_dim=embed_dim,
651 | input_resolution=(img_size,
652 | img_size),
653 | depth=depths[0],
654 | num_heads=num_heads[0],
655 | win_size=win_size,
656 | mlp_ratio=self.mlp_ratio,
657 | qkv_bias=qkv_bias, qk_scale=qk_scale,
658 | drop=drop_rate, attn_drop=attn_drop_rate,
659 | drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])],
660 | norm_layer=norm_layer,
661 | use_checkpoint=use_checkpoint,
662 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag)
663 | self.dowsample_0 = dowsample(embed_dim, embed_dim*2)
664 |
665 | self.encoderlayer_1 = BasicUformerLayer(dim=embed_dim*2,
666 | output_dim=embed_dim*2,
667 | input_resolution=(img_size // 2,
668 | img_size // 2),
669 | depth=depths[1],
670 | num_heads=num_heads[1],
671 | win_size=win_size,
672 | mlp_ratio=self.mlp_ratio,
673 | qkv_bias=qkv_bias, qk_scale=qk_scale,
674 | drop=drop_rate, attn_drop=attn_drop_rate,
675 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
676 | norm_layer=norm_layer,
677 | use_checkpoint=use_checkpoint,
678 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag)
679 | self.dowsample_1 = dowsample(embed_dim*2, embed_dim*4)
680 |
681 | self.encoderlayer_2 = BasicUformerLayer(dim=embed_dim*4,
682 | output_dim=embed_dim*4,
683 | input_resolution=(img_size // (2 ** 2),
684 | img_size // (2 ** 2)),
685 | depth=depths[2],
686 | num_heads=num_heads[2],
687 | win_size=win_size,
688 | mlp_ratio=self.mlp_ratio,
689 | qkv_bias=qkv_bias, qk_scale=qk_scale,
690 | drop=drop_rate, attn_drop=attn_drop_rate,
691 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
692 | norm_layer=norm_layer,
693 | use_checkpoint=use_checkpoint,
694 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag)
695 | self.dowsample_2 = dowsample(embed_dim*4, embed_dim*8)
696 |
697 | self.encoderlayer_3 = BasicUformerLayer(dim=embed_dim*8,
698 | output_dim=embed_dim*8,
699 | input_resolution=(img_size // (2 ** 3),
700 | img_size // (2 ** 3)),
701 | depth=depths[3],
702 | num_heads=num_heads[3],
703 | win_size=win_size,
704 | mlp_ratio=self.mlp_ratio,
705 | qkv_bias=qkv_bias, qk_scale=qk_scale,
706 | drop=drop_rate, attn_drop=attn_drop_rate,
707 | drop_path=enc_dpr[sum(depths[:3]):sum(depths[:4])],
708 | norm_layer=norm_layer,
709 | use_checkpoint=use_checkpoint,
710 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag)
711 | self.dowsample_3 = dowsample(embed_dim*8, embed_dim*16)
712 |
713 |
714 | '''Bottleneck'''
715 | self.conv = BasicUformerLayer(dim=embed_dim*16,
716 | output_dim=embed_dim*16,
717 | input_resolution=(img_size // (2 ** 4),
718 | img_size // (2 ** 4)),
719 | depth=depths[4],
720 | num_heads=num_heads[4],
721 | win_size=win_size,
722 | mlp_ratio=self.mlp_ratio,
723 | qkv_bias=qkv_bias, qk_scale=qk_scale,
724 | drop=drop_rate, attn_drop=attn_drop_rate,
725 | drop_path=conv_dpr,
726 | norm_layer=norm_layer,
727 | use_checkpoint=use_checkpoint,
728 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag)
729 |
730 | '''TPS regression heads'''
731 |         if self.shared_head:
732 | reg = BasicUformerLayer(dim=embed_dim*16,
733 | output_dim=embed_dim*16,
734 | input_resolution=(img_size // (2 ** 4),
735 | img_size // (2 ** 4)),
736 | depth=depths[4],
737 | num_heads=num_heads[4],
738 | win_size=win_size,
739 | mlp_ratio=self.mlp_ratio,
740 | qkv_bias=qkv_bias, qk_scale=qk_scale,
741 | drop=drop_motion_rate, attn_drop=attn_drop_motion_rate,
742 | drop_path=conv_dpr,
743 | norm_layer=norm_layer,
744 | use_checkpoint=use_checkpoint,
745 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag)
746 |             self.reg = nn.ModuleList([reg for _ in range(self.head_num)])  # the same module instance is repeated, so all regression heads share weights
747 | else:
748 | self.reg = nn.ModuleList([BasicUformerLayer(dim=embed_dim*16,
749 | output_dim=embed_dim*16,
750 | input_resolution=(img_size // (2 ** 4),
751 | img_size // (2 ** 4)),
752 | depth=depths[4],
753 | num_heads=num_heads[4],
754 | win_size=win_size,
755 | mlp_ratio=self.mlp_ratio,
756 | qkv_bias=qkv_bias, qk_scale=qk_scale,
757 | drop=drop_motion_rate, attn_drop=attn_drop_motion_rate,
758 | drop_path=conv_dpr,
759 | norm_layer=norm_layer,
760 | use_checkpoint=use_checkpoint,
761 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) for _ in range(self.head_num)])
762 |
763 |
764 | self.motion = nn.ModuleList([motion_net(embed_dim*16, 2, tps_points[i]) for i in range(self.head_num)])
765 | self.pointnet = PointNet(task_classes, tps_points[-1], tps_points[-1])
766 |
767 | '''Decoder to Flow'''
768 | self.upsample_0 = upsample(embed_dim*16, embed_dim*8)
769 | self.decoderlayer_0 = BasicUformerLayer(dim=embed_dim*16,
770 | output_dim=embed_dim*16,
771 | input_resolution=(img_size // (2 ** 3),
772 | img_size // (2 ** 3)),
773 | depth=depths[5],
774 | num_heads=num_heads[5],
775 | win_size=win_size,
776 | mlp_ratio=self.mlp_ratio,
777 | qkv_bias=qkv_bias, qk_scale=qk_scale,
778 | drop=drop_rate, attn_drop=attn_drop_rate,
779 | drop_path=dec_dpr[:depths[5]],
780 | norm_layer=norm_layer,
781 | use_checkpoint=use_checkpoint,
782 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,
783 | prompt=prompt, task_classes=task_classes)
784 |
785 | self.upsample_1 = upsample(embed_dim*16, embed_dim*4)
786 | self.decoderlayer_1 = BasicUformerLayer(dim=embed_dim*8,
787 | output_dim=embed_dim*8,
788 | input_resolution=(img_size // (2 ** 2),
789 | img_size // (2 ** 2)),
790 | depth=depths[6],
791 | num_heads=num_heads[6],
792 | win_size=win_size,
793 | mlp_ratio=self.mlp_ratio,
794 | qkv_bias=qkv_bias, qk_scale=qk_scale,
795 | drop=drop_rate, attn_drop=attn_drop_rate,
796 | drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])],
797 | norm_layer=norm_layer,
798 | use_checkpoint=use_checkpoint,
799 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,
800 | prompt=prompt, task_classes=task_classes)
801 |
802 | self.upsample_2 = upsample(embed_dim*8, embed_dim*2)
803 | self.decoderlayer_2 = BasicUformerLayer(dim=embed_dim*4,
804 | output_dim=embed_dim*4,
805 | input_resolution=(img_size // 2,
806 | img_size // 2),
807 | depth=depths[7],
808 | num_heads=num_heads[7],
809 | win_size=win_size,
810 | mlp_ratio=self.mlp_ratio,
811 | qkv_bias=qkv_bias, qk_scale=qk_scale,
812 | drop=drop_rate, attn_drop=attn_drop_rate,
813 | drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])],
814 | norm_layer=norm_layer,
815 | use_checkpoint=use_checkpoint,
816 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,
817 | prompt=prompt, task_classes=task_classes)
818 |
819 | self.upsample_3 = upsample(embed_dim*4, embed_dim)
820 | self.decoderlayer_3 = BasicUformerLayer(dim=embed_dim*2,
821 | output_dim=embed_dim*2,
822 | input_resolution=(img_size,
823 | img_size),
824 | depth=depths[8],
825 | num_heads=num_heads[8],
826 | win_size=win_size,
827 | mlp_ratio=self.mlp_ratio,
828 | qkv_bias=qkv_bias, qk_scale=qk_scale,
829 | drop=drop_rate, attn_drop=attn_drop_rate,
830 | drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])],
831 | norm_layer=norm_layer,
832 | use_checkpoint=use_checkpoint,
833 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,
834 | prompt=prompt, task_classes=task_classes)
835 |
836 | self.apply(self._init_weights)
837 |
838 | def _init_weights(self, m):
839 | if isinstance(m, nn.Linear):
840 | trunc_normal_(m.weight, std=.02)
841 | if isinstance(m, nn.Linear) and m.bias is not None:
842 | nn.init.constant_(m.bias, 0)
843 | elif isinstance(m, nn.LayerNorm):
844 | nn.init.constant_(m.bias, 0)
845 | nn.init.constant_(m.weight, 1.0)
846 |
847 | @torch.jit.ignore
848 | def no_weight_decay(self):
849 | return {'absolute_pos_embed'}
850 |
851 | @torch.jit.ignore
852 | def no_weight_decay_keywords(self):
853 | return {'relative_position_bias_table'}
854 |
855 | def extra_repr(self) -> str:
856 | return f"embed_dim={self.embed_dim}, token_projection={self.token_projection}, token_mlp={self.mlp},win_size={self.win_size}"
857 |
858 | def forward(self, x, mask, feed_prompt=None):
859 | '''Input Projection'''
860 | x = torch.cat([x, mask], 1)
861 | y = self.input_proj(x)
862 | y = self.pos_drop(y)
863 |
864 | '''Encoder'''
865 | conv0 = self.encoderlayer_0(y)
866 | pool0 = self.dowsample_0(conv0)
867 | conv1 = self.encoderlayer_1(pool0)
868 | pool1 = self.dowsample_1(conv1)
869 | conv2 = self.encoderlayer_2(pool1)
870 | pool2 = self.dowsample_2(conv2)
871 | conv3 = self.encoderlayer_3(pool2)
872 | pool3 = self.dowsample_3(conv3)
873 | conv4 = self.conv(pool3)
874 |
875 | '''TPS regression heads'''
876 | tps = []
877 | tps_up = []
878 | fea = [conv4]
879 | for i in range(self.head_num):
880 | conv_fea = fea[-1]
881 | pre = self.reg[i](conv_fea)
882 | pre = self.motion[i](pre)
883 |
884 | warp = transform_tps_fea(pre/self.down_size, conv_fea, self.tps_points[i]-1, self.tps_points[i]-1,
885 | self.embed_dim*16, self.mini_size, self.mini_size)
886 | fea.append(warp)
887 |             if i == 0:
888 |                 tps.append(pre)
889 |             else:
890 |                 tps.append(pre + tps_up[-1])  # each head predicts a residual on top of the upsampled previous estimate
891 |             if i < self.head_num - 1:  # lift the running estimate to the next head's finer control-point grid
892 |                 tps_up.append(upsample_tps(tps[i], self.tps_points[i]-1, self.tps_points[i]-1, self.tps_points[i+1], self.tps_points[i+1]))
893 |
894 | '''Point classification'''
895 | point_cls = self.pointnet(tps[-1], conv4)
896 |         prompt = F.softmax(point_cls, dim=1)  # task probabilities, passed to the decoder blocks as the prompt
897 |
898 | '''Decoder to residual flow'''
899 | up0 = self.upsample_0(fea[-1])
900 |         if self.correct_encoder:
901 | conv3 = transform_tps_fea(tps[-1]/(self.down_size//2), conv3, self.tps_points[-1]-1, self.tps_points[-1]-1,
902 | self.embed_dim*8, self.mini_size*2, self.mini_size*2)
903 | deconv0 = torch.cat([up0, conv3], -1)
904 | deconv0 = self.decoderlayer_0(deconv0, prompt)
905 |
906 | up1 = self.upsample_1(deconv0)
907 |         if self.correct_encoder:
908 | conv2 = transform_tps_fea(tps[-1]/(self.down_size//4), conv2, self.tps_points[-1]-1, self.tps_points[-1]-1,
909 | self.embed_dim*4, self.mini_size*4, self.mini_size*4)
910 | deconv1 = torch.cat([up1, conv2], -1)
911 | deconv1 = self.decoderlayer_1(deconv1, prompt)
912 |
913 | up2 = self.upsample_2(deconv1)
914 |         if self.correct_encoder:
915 | conv1 = transform_tps_fea(tps[-1]/(self.down_size//8), conv1, self.tps_points[-1]-1, self.tps_points[-1]-1,
916 | self.embed_dim*2, self.mini_size*8, self.mini_size*8)
917 | deconv2 = torch.cat([up2, conv1], -1)
918 | deconv2 = self.decoderlayer_2(deconv2, prompt)
919 |
920 | up3 = self.upsample_3(deconv2)
921 |         if self.correct_encoder:
922 | conv0 = transform_tps_fea(tps[-1], conv0, self.tps_points[-1]-1, self.tps_points[-1]-1,
923 | self.embed_dim*1, self.mini_size*16, self.mini_size*16)
924 | deconv3 = torch.cat([up3, conv0], -1)
925 | deconv3 = self.decoderlayer_3(deconv3, prompt)
926 |
927 | '''Output flow'''
928 | flow = self.output_proj(deconv3)
929 |
930 | return tps, flow, point_cls
--------------------------------------------------------------------------------
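For orientation, a minimal, hedged sketch of how the network dumped above might be driven end to end. It assumes MOWA is importable from model/network.py as shown; head_num=4 is passed explicitly so that it matches len(tps_points) (the default head_num=8 would index past the four default TPS resolutions), and the tensor sizes and variable names are illustrative rather than taken from the repository's train/test scripts.

import torch
from model.network import MOWA  # import path assumed from the repository layout

# Four TPS heads (8/10/12/14 control points per side); head_num must equal len(tps_points).
net = MOWA(img_size=256, in_chans=4, prompt=True, task_classes=4,
           tps_points=[8, 10, 12, 14], head_num=4)
net.eval()

img = torch.randn(1, 3, 256, 256)   # RGB input
msk = torch.ones(1, 1, 256, 256)    # validity mask; forward() concatenates it to the image (3 + 1 = in_chans)

with torch.no_grad():
    tps, flow, point_cls = net(img, msk)

# tps       : list of per-head control-point motions, coarse to fine (cumulative after the first head)
# flow      : residual flow from the decoder's 2-channel output projection
# point_cls : task logits from PointNet; their softmax is used internally as the decoder prompt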