├── .gitignore
├── README.md
├── data
│   ├── __init__.py
│   ├── data_train_gym_lidar_aerial
│   │   └── ground_truth_gym_lidar.csv
│   ├── data_train_qsdjt_lidar_aerial
│   │   └── ground_truth_qsdjt_lidar.csv
│   ├── data_train_qsdjt_lidar_sat
│   │   └── ground_truth_qsdjt_lidar.csv
│   ├── data_train_qsdjt_stereo_aerial
│   │   └── ground_truth_qsdjt_lidar.csv
│   ├── data_train_qsdjt_stereo_sat
│   │   └── ground_truth_qsdjt_lidar.csv
│   ├── data_utils.py
│   ├── dataset.py
│   ├── dataset_DPCN.py
│   ├── ortho
│   │   └── ground_truth_gym_lidar.csv
│   └── simulation.py
├── detect.py
├── fft
│   ├── __init__.py
│   ├── dft_test.py
│   ├── fft_demo.py
│   ├── imreg_test.py
│   └── test.py
├── images_for_readme
│   ├── Result1.png
│   ├── Result2.png
│   └── simulation.png
├── ipynb
│   ├── pytorch_fcn.ipynb
│   ├── pytorch_resnet18_unet.ipynb
│   └── pytorch_unet.ipynb
├── log_polar
│   ├── __init__.py
│   ├── log_polar.py
│   ├── polar.py
│   ├── polarizeLayer.py
│   └── tf_test.py
├── phase_correlation
│   ├── __init__.py
│   ├── phase_corr.py
│   └── tmp.dot
├── requirements.txt
├── trainDPCN.py
├── unet
│   ├── LICENSE
│   ├── README.md
│   ├── __init__.py
│   ├── helper.py
│   ├── loss.py
│   └── pytorch_DPCN.py
├── utils
│   ├── __init__.py
│   ├── detect_utils.py
│   ├── train_utils.py
│   ├── utils.py
│   └── validate_utils.py
└── validate.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.jpg
2 | *.pth
3 | *.zip
4 | *.pt
5 | *.mav-lab
6 | *.png
7 | *.DS_Store
8 | *.pyc
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://arxiv.org/abs/2008.09474)
2 | [](https://www.youtube.com/watch?v=_xPoHFf_8yI)
3 | [](https://opensource.org/licenses/MIT)
4 | # DPCN: Deep Phase Correlation for End-to-End Heterogeneous Sensor Measurements Matching
5 |
6 | This is the official repository for DPCN, accompanying the paper "[Deep Phase Correlation for End-to-End Heterogeneous Sensor Measurements Matching](https://proceedings.mlr.press/v155/chen21g.html)", accepted and presented at CoRL 2020.
7 |
8 |
9 |
10 | ## Dependencies
11 |
12 | There are a few dependencies required to run the code. They are listed below:
13 |
14 | ### System Environments:
15 |
16 | `Python >= 3.5`
17 |
18 | `CUDA 10.1`
19 |
20 | `CuDnn`
21 |
22 | ### Pip dependencies:
23 |
24 | `Pytorch 1.5.0`
25 |
26 | `Torchvision 0.6.0`
27 |
28 | `Kornia 0.3.1`
29 |
30 | `Graphviz`
31 |
32 | `Opencv`
33 |
34 | `Scipy`
35 |
36 | `Matplotlib`
37 |
38 | `Pandas`
39 |
40 | `TensorboardX`
41 |
42 |
43 |
44 | You can install these dependencies by changing your current directory to this DPCN directory and running:
45 |
46 | `pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html` (if using pip),
47 |
48 | or install the dependencies manually if you prefer.
49 |
50 | We suggest setting up the whole environment in Conda or virtualenv so that your system installation is not affected.
51 |
52 | We are also working on an upgrade for compatibility with the latest package versions (e.g. torch==1.7.1, CUDA 11, and kornia 0.4.1); coming soon!
53 |
54 |
55 |
56 | ### Pre-trained Models:
57 |
58 | ***Pre-trained Models can be downloaded from this link:*** https://drive.google.com/file/d/1GZ8hz3cfaBP7F7KEdQK7M_WcwOM1GF6z/view?usp=sharing
59 |
60 |
61 |
62 |
63 |
64 | ## How to Run The Code
65 |
66 | ### Step 0
67 |
68 | Before training or validating on the simulation dataset, choose one of three dataset types: Homogeneous, Heterogeneous, or Dynamic Obstacles. Each is illustrated below:
69 |
70 | 
71 |
72 | **By default, the code runs on the Heterogeneous dataset.** You can change this by modifying the following lines in `./data/simulation.py`.
73 |
74 | 1. To choose between heterogeneous and homogeneous images, modify lines 78-80:
75 |
76 | `# for heterogeneous image, comment out if you want them to be homogeneous
77 | rot = cv2.GaussianBlur(rot, (9,9), 13)
78 | kernel = np.array([[-1,-1,-1],[-1,9,-1],[-1,-1,-1]])
79 | rot = cv2.filter2D(rot, -1, kernel)`
80 |
81 | 2. To enable or disable dynamic obstacles, modify lines 83-85:
82 |
83 | `# for dynamic obstacles, comment out if you dont want any dynamic obstacles
84 | arr[0,] = add_plus(arr[0], *plus_location4)
85 | arr = np.reshape(arr, (1, height, width)).astype(np.float32)`
86 |
87 |
88 |
89 | ### Training
90 |
91 | To train the DPCN network from scratch on the simulation dataset, run:
92 |
93 | `python trainDPCN.py --simulation`
94 |
95 | By default, this trains the network on the heterogeneous dataset with a batch size of 2 on the GPU. You can change these settings with the arguments below:
96 |
97 | | Arguments | What it will trigger | Default |
98 | | ------------------- | ------------------------------------------------------------ | --------------------------- |
99 | | --save_path | The path to save the checkpoint of every epoch | ./checkpoints/ |
100 | | --simulation | Train on a randomly generated simulation dataset | False |
101 | | --cpu | Use the CPU for training | False |
102 | | --load_pretrained | Whether to load a pretrained model for fine-tuning | False |
103 | | --load_path | The path to load a pretrained checkpoint from | ./checkpoints/checkpoint.pt |
104 | | --load_optimizer | When using a pretrained model, whether to also load its optimizer | False |
105 | | --pretrained_mode | Three options: 'all' for loading rotation and translation; 'rot' for loading only rotation; 'trans' for loading only translation | All |
106 | | --use_dsnt | When enabled, the loss is calculated via DSNT and MSELoss; otherwise a CELoss is used | False |
107 | | --batch_size_train | The batch size for training | 2 |
108 | | --batch_size_val | The batch size for validation | 2 |
109 | | --train_writer_path | Where to write the training log | ./checkpoints/log/train/ |
110 | | --val_writer_path | Where to write the validation log | ./checkpoints/log/val/ |
111 |
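The argument table above maps naturally onto an `argparse` parser. The following is a minimal sketch only, not the actual parser in `trainDPCN.py`: it assumes that the boolean options are plain `store_true` flags, that the table's `All` default means the `'all'` option, and the function name `build_train_parser` is hypothetical.

```python
import argparse


def build_train_parser():
    """Hypothetical parser mirroring the documented training arguments."""
    parser = argparse.ArgumentParser(description="Train DPCN (illustrative sketch)")
    # Paths, with the defaults listed in the table
    parser.add_argument("--save_path", type=str, default="./checkpoints/")
    parser.add_argument("--load_path", type=str, default="./checkpoints/checkpoint.pt")
    parser.add_argument("--train_writer_path", type=str, default="./checkpoints/log/train/")
    parser.add_argument("--val_writer_path", type=str, default="./checkpoints/log/val/")
    # Boolean switches (assumed to be store_true flags, all False by default)
    parser.add_argument("--simulation", action="store_true")
    parser.add_argument("--cpu", action="store_true")
    parser.add_argument("--load_pretrained", action="store_true")
    parser.add_argument("--load_optimizer", action="store_true")
    parser.add_argument("--use_dsnt", action="store_true")
    # Which parts of a pretrained checkpoint to load
    parser.add_argument("--pretrained_mode", choices=["all", "rot", "trans"], default="all")
    # Batch sizes
    parser.add_argument("--batch_size_train", type=int, default=2)
    parser.add_argument("--batch_size_val", type=int, default=2)
    return parser


# Corresponds to: python trainDPCN.py --simulation --batch_size_train 4
args = build_train_parser().parse_args(["--simulation", "--batch_size_train", "4"])
print(args.simulation, args.batch_size_train, args.pretrained_mode)
```

Options that are not passed keep the defaults from the table, so `--simulation` alone reproduces the default training setup.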
112 |
113 |
114 |
115 |
116 | ### Validating
117 |
118 | To run validation only, on the randomly generated simulation dataset, use the command below that matches the dataset type you chose in **Step 0**.
119 |
120 | For **Homogeneous** sets:
121 |
122 | `python validate.py --simulation --only_valid --load_path=./checkpoints/checkpoint_simulation_homo.pt`
123 |
124 | For **Heterogeneous** sets:
125 |
126 | `python validate.py --simulation --only_valid --load_path=./checkpoints/checkpoint_simulation_hetero.pt`
127 |
128 | For **Dynamic** Obstacle sets:
129 |
130 | `python validate.py --simulation --only_valid --load_path=./checkpoints/checkpoint_simulation_dynamic.pt`
131 |
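The three commands above differ only in the checkpoint they load. As a small sketch (the helper `validation_command` and the dictionary keys are hypothetical; the checkpoint paths and flags are copied from the commands above), the mapping could be expressed as:

```python
# Checkpoint paths exactly as they appear in the validation commands.
CHECKPOINTS = {
    "homogeneous": "./checkpoints/checkpoint_simulation_homo.pt",
    "heterogeneous": "./checkpoints/checkpoint_simulation_hetero.pt",
    "dynamic": "./checkpoints/checkpoint_simulation_dynamic.pt",
}


def validation_command(dataset_type):
    """Return the argv list for validating on the given simulation dataset type."""
    load_path = CHECKPOINTS[dataset_type]  # raises KeyError for an unknown type
    return [
        "python", "validate.py",
        "--simulation", "--only_valid",
        "--load_path=" + load_path,
    ]


cmd = validation_command("heterogeneous")
print(" ".join(cmd))
# To launch it from the repository root, one option is:
#   import subprocess; subprocess.run(cmd, check=True)
```

The checkpoint should match the dataset type configured in `./data/simulation.py`, since the simulation data is generated on the fly.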
132 |
133 |
134 | ***Again, Pre-trained Models can be downloaded from this link:*** https://drive.google.com/file/d/1GZ8hz3cfaBP7F7KEdQK7M_WcwOM1GF6z/view?usp=sharing
135 |
136 |
137 |
138 | Similarly, the following options are available when running validation:
139 |
140 | | Arguments | What it will trigger | Default |
141 | | ----------------- | ------------------------------------------------------------ | --------------------------- |
142 | | --only_valid | Required when running validation on its own | False |
143 | | --simulation | Validate on a randomly generated simulation dataset | False |
144 | | --cpu | Use the CPU for validation | False |
145 | | --load_path | The path to load a pretrained checkpoint from | ./checkpoints/checkpoint.pt |
146 | | --use_dsnt | When enabled, the loss is calculated via DSNT and MSELoss; otherwise a CELoss is used | False |
147 | | --batch_size_val | The batch size for validation | 2 |
148 | | --val_writer_path | Where to write the validation log | ./checkpoints/log/val/ |
149 |
150 |
151 |
152 | ### Detecting and Inferring:
153 |
154 | This repository also provides a single-pair detection script so that you can see the result of DPCN directly. A few demos are given in the `./demo` directory, including image pairs from both the simulation dataset and the Aero-Ground Dataset. Edit `detect.py` to select the image pair and the corresponding pre-trained model, then run:
155 |
156 | `python detect.py`
157 |
158 | The results should look something like this:
159 |
160 |
161 |
162 | 
163 |
164 | 
165 |
166 | ### Citation
167 | If our source code helps your project, please cite the following:
168 | ```bibtex
169 | @InProceedings{chen2020deep,
170 |   title = {Deep Phase Correlation for End-to-End Heterogeneous Sensor Measurements Matching},
171 |   author = {Chen, Zexi and Xu, Xuecheng and Wang, Yue and Xiong, Rong},
172 |   booktitle = {Proceedings of the 2020 Conference on Robot Learning},
173 |   pages = {2359--2375},
174 |   year = {2021},
175 |   volume = {155},
176 |   series = {Proceedings of Machine Learning Research},
177 |   month = {16--18 Nov},
178 |   publisher = {PMLR},
179 |   pdf = {https://proceedings.mlr.press/v155/chen21g/chen21g.pdf},
180 |   url = {https://proceedings.mlr.press/v155/chen21g.html},
181 | }
182 | ```
183 |
184 |
185 |
186 |
187 |
188 | ## HAVE FUN with the CODE!!!!
189 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | from torchvision import transforms, utils
3 | import pandas as pd
4 | from pandas import Series,DataFrame
5 | import numpy as np
6 | import torch
7 | import os
8 | import cv2
9 | from PIL import Image
10 | import matplotlib.pyplot as plt
11 | import sys
12 | import kornia  # needed by default_loader when change_scale=True
13 | sys.path.append(os.path.abspath(".."))
14 | from utils.utils import *
15 |
16 | def default_loader(path, resize_shape, change_scale = False):
17 |     trans = transforms.Compose([
18 |         transforms.ToTensor(),
19 |         # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # imagenet
20 |     ])
21 |     image = cv2.imread(path)
22 |     image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
23 |     (h_original, w_original) = image.shape
24 |     image = cv2.resize(image, dsize=(resize_shape,resize_shape), interpolation=cv2.INTER_CUBIC)
25 |     angle = (np.random.rand()-0.5) * 0.
26 |     angle += 180.
27 |     angle %= 360
28 |     angle -= 180.
29 |     (h, w) = image.shape
30 |     (cX, cY) = (w//2, h//2)
31 |
32 |     t_x = np.random.rand() * 0.
33 |     t_y = np.random.rand() * 0.
34 |     translation = np.array((t_y, t_x))
35 |
36 |     # arr = arr[0,]
37 |     # rot = ndii.rotate(arr, angle)
38 |
39 |     # N = np.float32([[1,0,t_x],[0,1,t_y]])
40 |     # image = cv2.warpAffine(image, N, (w, h))
41 |     # M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
42 |     # image = cv2.warpAffine(image, M, (w, h))
43 |     # image = cv2.resize(image, (h, w), interpolation=cv2.INTER_CUBIC)
44 |
45 |     np_image_data = np.asarray(image)
46 |     image_tensor = trans(np_image_data)
47 |     scaling_factor = 1
48 |     if change_scale:
49 |         center = torch.ones(1,2)
50 |         center[:, 0] = h // 2
51 |         center[:, 1] = w // 2
52 |         scaling_factor = torch.tensor(np.random.rand()*0.2+1)
53 |         angle_source = torch.ones(1) * 0.
54 |         scale_source = torch.ones(1) * scaling_factor
55 |         image_tensor = image_tensor.unsqueeze(0)
56 |         rot_mat = kornia.get_rotation_matrix2d(center, angle_source, scale_source)
57 |         image_tensor = kornia.warp_affine(image_tensor, rot_mat, dsize=(h, w))
58 |         image_tensor = image_tensor.squeeze(0)
59 |     # image = Image.open(path)
60 |     # image = image.convert("1")
61 |     # # image.show()
62 |     # image = image.resize((128,128))
63 |     # image_tensor = trans(image)
64 |     return image_tensor, angle, translation, scaling_factor, h_original, w_original
65 |
66 | def get_gt_tensor(this_gt, size):
67 |     this_gt = this_gt +180
68 |     gt_tensor_self = torch.zeros(size,size)
69 |     angle_convert = this_gt*size/360
70 |     angle_index = angle_convert//1 + (angle_convert%1+0.5)//1
71 |     if angle_index.long() == size:
72 |         angle_index = size-1
73 |         gt_tensor_self[angle_index,0] = 1
74 |     else:
75 |         gt_tensor_self[angle_index.long(),0] = 1
76 |     # print("angle_index", angle_index)
77 |
78 |     return gt_tensor_self
79 |
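The index arithmetic in `get_gt_tensor` shifts the angle from [-180, 180] degrees to [0, 360], scales it to the tensor size, rounds to the nearest row, and clamps the one value that would fall off the end. A minimal pure-Python sketch of the same mapping, for illustration only (the name `angle_to_bin` is hypothetical and not part of the repository):

```python
import math


def angle_to_bin(angle_deg, size):
    """Map an angle in [-180, 180] degrees to a row index in [0, size - 1]."""
    shifted = angle_deg + 180.0        # now in [0, 360]
    scaled = shifted * size / 360.0    # now in [0, size]
    index = math.floor(scaled + 0.5)   # round half up, matching x//1 + (x%1 + 0.5)//1
    return min(index, size - 1)        # +180 degrees would land on `size`; clamp it


for angle in (-180.0, 0.0, 90.0, 180.0):
    print(angle, angle_to_bin(angle, 256))
```

With `size=256`, each row covers 360/256, roughly 1.4 degrees, which bounds the resolution of a rotation label encoded this way.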
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | from torchvision import transforms, datasets, models
3 | import numpy as np
4 | import torch
5 | import sys
6 | import data.simulation as simulation
7 | from data.data_utils import *
8 |
9 |
10 | class SimDataset(Dataset):
11 |     def __init__(self, count, transform=None):
12 |         self.input_images, self.rotate_image, self.gt_rot, self.gt_trans = simulation.generate_random_data(256, 256, count=count)
13 |         self.transform = transform
14 |
15 |     def __len__(self):
16 |         return len(self.input_images)
17 |
18 |     def __getitem__(self, idx):
19 |         image = self.input_images[idx]
20 |         rot = self.rotate_image[idx]
21 |         gt_rot = self.gt_rot[idx]
22 |         gt_trans = self.gt_trans[idx]
23 |         gt_scale = torch.tensor(1.)
24 |
25 |         if self.transform:
26 |             image = self.transform(image)
27 |             rot = self.transform(rot)
28 |         gt_rots = torch.tensor(gt_rot)
29 |         gt_scale = torch.tensor(1.)
30 |         gt_trans = torch.tensor(gt_trans)
31 |         # print("gt = ", gt)
32 |         # print("gt tensor = ", gt_tensor)
33 |
34 |         return [image, rot, gt_rots, gt_scale, gt_trans]
35 |
36 | def generate_dataloader(batch_size):
37 |     # use the same transformations for train/val in this example
38 |     trans = transforms.Compose([
39 |         transforms.ToTensor(),
40 |         # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # imagenet
41 |     ])
42 |
43 |     train_set = SimDataset(2000, transform = trans)
44 |     val_set = SimDataset(1000, transform = trans)
45 |
46 |     image_datasets = {
47 |         'train': train_set, 'val': val_set
48 |     }
49 |
50 |     dataloaders = {
51 |         'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
52 |         'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
53 |     }
54 |     return dataloaders
--------------------------------------------------------------------------------
/data/dataset_DPCN.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | from torchvision import transforms, utils
3 | import pandas as pd
4 | from pandas import Series,DataFrame
5 | import numpy as np
6 | import torch
7 | import os
8 | import cv2
9 | from PIL import Image
10 | import matplotlib.pyplot as plt
11 | import sys
12 | import os
13 |
14 | sys.path.append(os.path.abspath(".."))
15 | from utils.utils import *
16 | from data.data_utils import *
17 |
18 | # datacsv = pd.read_csv("./data_train/ground_truth.csv")
19 |
20 | # template_list = datacsv["name"].values
21 | # template_list = [i+".jpg" for i in template_list]
22 | # template_list = [os.path.join("./data_train/ground/",i) for i in template_list ]
23 | # template_train_list = template_list[:5000]
24 | # template_val_list = template_list[:5000]
25 |
26 | # source_list = datacsv["name"].values
27 | # source_list = [i+".jpg" for i in source_list]
28 | # source_list = [os.path.join("./data_train/aerial/",i) for i in source_list ]
29 | # source_train_list = source_list[:5000]
30 | # source_val_list = source_list[:5000]
31 |
32 | # ground_truth_list = datacsv["rotation"].values
33 | # ground_truth_train_list = ground_truth_list[5000:]
34 | # ground_truth_val_list = ground_truth_list[:5000]
35 | # ground_truth_train_list = torch.from_numpy(ground_truth_train_list)
36 | # ground_truth_val_list = torch.from_numpy(ground_truth_val_list)
37 |
38 | # trans = transforms.Compose([
39 | # transforms.ToTensor(),
40 | # # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # imagenet
41 | # ])
42 |
43 |
44 | class AeroGroundDataset_train(Dataset):
45 |     def __init__(self, template_train_list, source_train_list, gt_rot_train_list, gt_scale_train_list, gt_x_train_list, gt_y_train_list, loader=default_loader):
46 |         self.template_path_list = template_train_list #this is a list of the path to the template image
47 |         self.source_path_list = source_train_list
48 |         self.gt_rot_list = gt_rot_train_list
49 |         self.gt_scale_list = gt_scale_train_list
50 |         self.loader = loader
51 |         self.gt_trans_x = gt_x_train_list
52 |         self.gt_trans_y = gt_y_train_list
53 |
54 |     # add x y theta
55 |     def __len__(self):
56 |         return len(self.template_path_list)
57 |
58 |     def __getitem__(self, index):
59 |         # print(np.shape(this_source))
60 |         this_template_path = self.template_path_list[index]
61 |         this_source_path = self.source_path_list[index]
62 |         rot_gt = self.gt_rot_list[index]
63 |         scale_gt = self.gt_scale_list[index]
64 |         trans_x = self.gt_trans_x[index]
65 |         trans_y = self.gt_trans_y[index]
66 |
67 |         this_template, _, _, _, _,_ = self.loader(this_template_path, resize_shape=256)
68 |         this_source, _, _, scaling_factor, h_original, w_original = self.loader(this_source_path, resize_shape=256, change_scale=True)
69 |         # print("this gt =", rot_gt)
70 |         # gt_tensor = get_gt_tensor(rot_gt, this_template.size(1))
71 |         # print("gt_tensor =", gt_tensor)
72 |         # rot_gt += angle_t + angle_s
73 |         # rot_gt += 180.
74 |         # rot_gt %= 360
75 |         # rot_gt -= 180.
76 |         rot_gt = torch.tensor(rot_gt)
77 |         scale_gt = torch.tensor(scale_gt) * scaling_factor
78 |
79 |         trans = np.array((trans_y/(h_original/this_template.size(1)), trans_x/(w_original/this_template.size(1))))
80 |
81 |         gt_trans = torch.tensor(trans)
82 |         # add x y theta
83 |         return [this_template, this_source, rot_gt, scale_gt, gt_trans]
84 |
85 | class AeroGroundDataset_val(Dataset):
86 |     def __init__(self, template_val_list, source_val_list, gt_rot_val_list, gt_scale_val_list, gt_x_val_list, gt_y_val_list, loader=default_loader):
87 |         self.template_path_list = template_val_list #this is a list of the path to the template image
88 |         self.source_path_list = source_val_list
89 |         self.gt_rot_list = gt_rot_val_list
90 |         self.gt_scale_list = gt_scale_val_list
91 |         self.loader = loader
92 |         self.gt_trans_x = gt_x_val_list
93 |         self.gt_trans_y = gt_y_val_list
94 |
95 |     def __len__(self):
96 |         return len(self.template_path_list)
97 |
98 |     def __getitem__(self, index):
99 |
100 |         this_template_path = self.template_path_list[index]
101 |         this_source_path = self.source_path_list[index]
102 |         rot_gt = self.gt_rot_list[index]
103 |         scale_gt = self.gt_scale_list[index]
104 |         trans_x = self.gt_trans_x[index]
105 |         trans_y = self.gt_trans_y[index]
106 |         this_template, _, _, _, _, _ = self.loader(this_template_path, resize_shape=256)
107 |         this_source, _, _, scaling_factor, h_original, w_original = self.loader(this_source_path, resize_shape=256, change_scale=True)
108 |
109 |         # gt_tensor = get_gt_tensor(rot_gt, this_template.size(1))
110 |
111 |         # rot_gt += angle_t + angle_s
112 |         # rot_gt += 180.
113 |         # rot_gt %= 360
114 |         # rot_gt -= 180.
115 |         rot_gt = torch.tensor(rot_gt)
116 |         scale_gt = torch.tensor(scale_gt) * scaling_factor
117 |
118 |         # trans_x = (torch.sign(-trans_x) + 1) / 2 * 256 + trans_x
119 |         # trans_y = (torch.sign(-trans_y) + 1) / 2 * 256 + trans_y
120 |         trans = np.array((trans_y/(h_original/this_template.size(1)), trans_x/(w_original/this_template.size(1))))
121 |         gt_trans = torch.tensor(trans)
122 |
123 |         return [this_template, this_source, rot_gt, scale_gt, gt_trans]
124 |
125 |
126 | def DPCNdataloader(batch_size):
127 |     # use the same transformations for train/val in this example
128 |     path = "./data"
129 |     datacsv = pd.read_csv(path + "/data_train_qsdjt_stereo_sat/ground_truth_qsdjt_lidar.csv")
130 |     train_upper = 6000
131 |     val_num = 2000
132 |     val_upper = train_upper+val_num
133 |
134 |     template_list = datacsv["name"].values
135 |     template_list = [i+".jpg" for i in template_list]
136 |     template_list = [os.path.join(path + "/data_train_qsdjt_stereo_sat/ground/",i) for i in template_list ]
137 |     template_train_list = template_list[:train_upper]
138 |     template_val_list = template_list[train_upper:val_upper]
139 |
140 |     source_list = datacsv["name"].values
141 |     source_list = [i+".jpg" for i in source_list]
142 |     source_list = [os.path.join(path + "/data_train_qsdjt_stereo_sat/aerial/",i) for i in source_list ]
143 |     source_train_list = source_list[:train_upper]
144 |     source_val_list = source_list[train_upper:val_upper]
145 |
146 |     gt_rot_list = datacsv["rotation"].values
147 |     gt_rot_train_list = gt_rot_list[:train_upper]
148 |     gt_rot_val_list = gt_rot_list[train_upper:val_upper]
149 |     gt_rot_train_list = torch.from_numpy(gt_rot_train_list)
150 |     gt_rot_val_list = torch.from_numpy(gt_rot_val_list)
151 |
152 |     gt_scale_list = datacsv["rotation"].values * 0 +1.0
153 |     gt_scale_train_list = gt_scale_list[:train_upper]
154 |     gt_scale_val_list = gt_scale_list[train_upper:val_upper]
155 |     gt_scale_train_list = torch.from_numpy(gt_scale_train_list)
156 |     gt_scale_val_list = torch.from_numpy(gt_scale_val_list)
157 |
158 |     gt_x_list = datacsv["shift_x"].values
159 |     gt_x_train_list = gt_x_list[:train_upper]
160 |     gt_x_val_list = gt_x_list[train_upper:val_upper]
161 |     gt_x_train_list = torch.from_numpy(gt_x_train_list)
162 |     gt_x_val_list = torch.from_numpy(gt_x_val_list)
163 |
164 |     gt_y_list = datacsv["shift_y"].values
165 |     gt_y_train_list = gt_y_list[:train_upper]
166 |     gt_y_val_list = gt_y_list[train_upper:val_upper]
167 |     gt_y_train_list = torch.from_numpy(gt_y_train_list)
168 |     gt_y_val_list = torch.from_numpy(gt_y_val_list)
169 |
170 |     train_set = AeroGroundDataset_train(template_train_list, source_train_list, gt_rot_train_list, gt_scale_train_list, gt_x_train_list, gt_y_train_list)
171 |     val_set = AeroGroundDataset_val(template_val_list, source_val_list, gt_rot_val_list, gt_scale_val_list, gt_x_val_list, gt_y_val_list)
172 |
173 |     image_datasets = {
174 |         'train': train_set, 'val': val_set
175 |     }
176 |
177 |
178 |     dataloaders = {
179 |         'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
180 |         'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
181 |     }
182 |     return dataloaders
183 |
--------------------------------------------------------------------------------
/data/ortho/ground_truth_gym_lidar.csv:
--------------------------------------------------------------------------------
1 | name,rotation,shift_x,shift_y
2 | 0_,172.4768822,37,-28
3 | 28_,172.4768822,24,-43
4 |
--------------------------------------------------------------------------------
/data/simulation.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 | try:
5 |     import scipy.ndimage.interpolation as ndii
6 | except ImportError:
7 |     import ndimage.interpolation as ndii
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | def generate_random_data(height, width, count):
12 |     x, y, gt, trans = zip(*[generate_img_and_rot_img(height, width) for i in range(0, count)])
13 |
14 |     X = np.asarray(x) * 255
15 |     X = X.repeat(1, axis=1).transpose([0, 2, 3, 1]).astype(np.uint8)
16 |     Y = np.asarray(y) * 100
17 |     Y = Y.repeat(1, axis=1).transpose([0, 2, 3, 1]).astype(np.uint8)
18 |
19 |     return X, Y, gt, trans
20 |
21 | def generate_img_and_rot_img(height, width):
22 |     shape = (height, width)
23 |
24 |     triangle_location = get_random_location(*shape)
25 |     triangle_location1 = get_random_location(*shape)
26 |     triangle_location2 = get_random_location(*shape)
27 |     circle_location1 = get_random_location(*shape, zoom=0.7)
28 |     circle_location2 = get_random_location(*shape, zoom=0.5)
29 |     circle_location3 = get_random_location(*shape, zoom=0.9)
30 |     mesh_location = get_random_location(*shape)
31 |     square_location = get_random_location(*shape, zoom=0.8)
32 |     plus_location = get_random_location(*shape, zoom=1.2)
33 |     plus_location1 = get_random_location(*shape, zoom=1.2)
34 |     plus_location2 = get_random_location(*shape, zoom=1.2)
35 |     plus_location3 = get_random_location(*shape, zoom=1.2)
36 |     plus_location4 = get_random_location(*shape, zoom=1.2)
37 |
38 |     # Create input image
39 |     arr = np.zeros(shape, dtype=bool)
40 |     arr = add_triangle(arr, *triangle_location)
41 |     arr = add_triangle(arr, *triangle_location1)
42 |     arr = add_triangle(arr, *triangle_location2)
43 |     arr = add_circle(arr, *circle_location1)
44 |     arr = add_circle(arr, *circle_location2, fill=True)
45 |     arr = add_circle(arr, *circle_location3)
46 |     arr = add_mesh_square(arr, *mesh_location)
47 |     arr = add_filled_square(arr, *square_location)
48 |     arr = add_plus(arr, *plus_location)
49 |     arr = add_plus(arr, *plus_location1)
50 |     arr = add_plus(arr, *plus_location2)
51 |     arr = add_plus(arr, *plus_location3)
52 |     arr = np.reshape(arr, (1, height, width)).astype(np.float32)
53 |
54 |     angle = np.random.rand() * 180
55 |     t_x = (np.random.rand()-0.5) * 50.
56 |     t_y = (np.random.rand()-0.5) * 50.
57 |     trans = np.array((t_y, t_x))
58 |
59 |     if angle < -180.0:
60 |         angle = angle + 360.0
61 |     elif angle > 180.0:
62 |         angle = angle - 360.0
63 |
64 |     # arr = arr[0,]
65 |     # rot = ndii.rotate(arr, angle)
66 |
67 |     (_, h, w) = arr.shape
68 |     (cX, cY) = (w//2, h//2)
69 |     rot = arr[0,]
70 |
71 |     N = np.float32([[1,0,t_x],[0,1,t_y]])
72 |     rot = cv2.warpAffine(rot, N, (w, h))
73 |
74 |     M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
75 |     rot = cv2.warpAffine(rot, M, (w, h))
76 |     rot = cv2.resize(rot, (h, w), interpolation=cv2.INTER_CUBIC)
77 |     # for heterogeneous image, comment out if you want them to be homogeneous
78 |     rot = cv2.GaussianBlur(rot, (9,9), 13)
79 |     kernel = np.array([[-1,-1,-1],[-1,9,-1],[-1,-1,-1]])
80 |     rot = cv2.filter2D(rot, -1, kernel)
81 |
82 |     rot = rot[np.newaxis, :]
83 |     # for dynamic obstacles, comment out if you dont want any dynamic obstacles
84 |     # arr[0,] = add_plus(arr[0], *plus_location4)
85 |     # arr = np.reshape(arr, (1, height, width)).astype(np.float32)
86 |
87 |     return arr, rot, angle, trans
88 |
89 | def add_square(arr, x, y, size):
90 |     s = int(size / 2)
91 |     arr[x-s,y-s:y+s] = True
92 |     arr[x+s,y-s:y+s] = True
93 |     arr[x-s:x+s,y-s] = True
94 |     arr[x-s:x+s,y+s] = True
95 |
96 |     return arr
97 |
98 | def add_filled_square(arr, x, y, size):
99 |     s = int(size / 2)
100 |
101 |     xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
102 |
103 |     return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, yy > y - s, yy < y + s]))
104 |
105 | def logical_and(arrays):
106 |     new_array = np.ones(arrays[0].shape, dtype=bool)
107 |     for a in arrays:
108 |         new_array = np.logical_and(new_array, a)
109 |
110 |     return new_array
111 |
112 | def add_mesh_square(arr, x, y, size):
113 |     s = int(size / 2)
114 |
115 |     xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
116 |
117 |     return np.logical_or(arr, logical_and([xx > x - s, xx < x + s, xx % 2 == 1, yy > y - s, yy < y + s, yy % 2 == 1]))
118 |
119 | def add_triangle(arr, x, y, size):
120 |     s = int(size / 2)
121 |
122 |     triangle = np.tril(np.ones((size, size), dtype=bool))
123 |
124 |     arr[x-s:x-s+triangle.shape[0],y-s:y-s+triangle.shape[1]] = triangle
125 |
126 |     return arr
127 |
128 | def add_circle(arr, x, y, size, fill=False):
129 |     xx, yy = np.mgrid[:arr.shape[0], :arr.shape[1]]
130 |     circle = np.sqrt((xx - x) ** 2 + (yy - y) ** 2)
131 |     new_arr = np.logical_or(arr, np.logical_and(circle < size, circle >= size * 0.7 if not fill else True))
132 |
133 |     return new_arr
134 |
135 | def add_plus(arr, x, y, size):
136 |     s = int(size / 2)
137 |     arr[x-1:x+1,y-s:y+s] = True
138 |     arr[x-s:x+s,y-1:y+1] = True
139 |
140 |     return arr
141 |
142 | def get_random_location(width, height, zoom=1.0):
143 |     x = int(width * random.uniform(0.22, 0.78))
144 |     y = int(height * random.uniform(0.22, 0.78))
145 |
146 |     size = int(min(width, height) * random.uniform(0.06, 0.12) * zoom)
147 |
148 |     return (x, y, size)
--------------------------------------------------------------------------------
/detect.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch.nn.functional as F
3 | import torch
4 | import torch.optim as optim
5 | import torch.nn as nn
6 | from torch.optim import lr_scheduler
7 | import time
8 | import copy
9 | from unet.pytorch_DPCN import FFT2, UNet, LogPolar, PhaseCorr, Corr2Softmax
10 | from data.dataset_DPCN import *
11 | import numpy as np
12 | import shutil
13 | from utils.utils import *
14 | import kornia
15 | from data.dataset import *
16 | from utils.detect_utils import *
17 |
18 |
19 | def detect_model(template_path, source_path, model_template, model_source, model_corr2softmax,\
20 |                  model_trans_template, model_trans_source, model_trans_corr2softmax):
21 |     batch_size_inner = 1
22 |
23 |     since = time.time()
24 |
25 |     # Each epoch has a training and validation phase
26 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
27 |     # device = torch.device("cpu")
28 |     phase = "val"
29 |
30 |     model_template.eval() # Set model to evaluate mode
31 |     model_source.eval()
32 |     model_corr2softmax.eval()
33 |     model_trans_template.eval()
34 |     model_trans_source.eval()
35 |     model_trans_corr2softmax.eval()
36 |     with torch.no_grad():
37 |
38 |         iters = 0
39 |         acc = 0.
40 |         template, _, _, _, _, _ = default_loader(template_path, 256)
41 |         source, _, _, _, _, _ = default_loader(source_path, 256)
42 |         template = template.to(device)
43 |         source = source.to(device)
44 |         template = template.unsqueeze(0)
45 |         template = template.permute(1,0,2,3)
46 |         source = source.unsqueeze(0)
47 |         source = source.permute(1,0,2,3)
48 |
49 |         iters += 1
50 |         since = time.time()
51 |         rotation_cal, scale_cal = detect_rot_scale(template, source,\
52 |                                     model_template, model_source, model_corr2softmax, device )
53 |         tranformation_y, tranformation_x, image_aligned, source_rotated = detect_translation(template, source, rotation_cal, scale_cal, \
54 |                                     model_trans_template, model_trans_source, model_trans_corr2softmax, device)
55 |         time_elapsed = time.time() - since
56 |         # print('in detection time {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
57 |         print("in detection time", time_elapsed)
58 |         plot_and_save_result(template[0,:,:], source[0,:,:], source_rotated[0,:,:], image_aligned)
59 |
60 |
61 |
62 |
63 |
64 | checkpoint_path = "./checkpoints/checkpoint_simulation_hetero.pt"
65 | template_path = "./demo/temp_1.png"
66 | source_path = "./demo/src_1.png"
67 |
68 | load_pretrained =True
69 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
70 | print("The devices that the code is running on:", device)
71 | # device = torch.device("cpu")
72 | batch_size = 1
73 | num_class = 1
74 | start_epoch = 0
75 | model_template = UNet(num_class).to(device)
76 | model_source = UNet(num_class).to(device)
77 | model_corr2softmax = Corr2Softmax(200., 0.).to(device)
78 | model_trans_template = UNet(num_class).to(device)
79 | model_trans_source = UNet(num_class).to(device)
80 | model_trans_corr2softmax = Corr2Softmax(200., 0.).to(device)
81 |
82 | optimizer_ft_temp = optim.Adam(filter(lambda p: p.requires_grad, model_template.parameters()), lr=2e-4)
83 | optimizer_ft_src = optim.Adam(filter(lambda p: p.requires_grad, model_source.parameters()), lr=2e-4)
84 | optimizer_c2s = optim.Adam(filter(lambda p: p.requires_grad, model_corr2softmax.parameters()), lr=1e-1)
85 | optimizer_trans_ft_temp = optim.Adam(filter(lambda p: p.requires_grad, model_trans_template.parameters()), lr=2e-4)
86 | optimizer_trans_ft_src = optim.Adam(filter(lambda p: p.requires_grad, model_trans_source.parameters()), lr=2e-4)
87 | optimizer_trans_c2s = optim.Adam(filter(lambda p: p.requires_grad, model_trans_corr2softmax.parameters()), lr=1e-1)
88 |
89 | if load_pretrained:
90 |     model_template, model_source, model_corr2softmax, model_trans_template, model_trans_source, model_trans_corr2softmax,\
91 |     _, _, _, _, _, _,\
92 |     start_epoch = load_checkpoint(\
93 |         checkpoint_path, model_template, model_source, model_corr2softmax, model_trans_template, model_trans_source, model_trans_corr2softmax,\
94 |         optimizer_ft_temp, optimizer_ft_src, optimizer_c2s, optimizer_trans_ft_temp, optimizer_trans_ft_src, optimizer_trans_c2s, device)
95 |
96 | detect_model(template_path, source_path, model_template, model_source, model_corr2softmax, model_trans_template, model_trans_source, model_trans_corr2softmax)
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
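The detection script above measures elapsed time with `time.time()` around the detection calls. A minimal sketch of a reusable timing helper follows. The name `timed` and its `sync` parameter are hypothetical and not part of this repository. When the models run on a CUDA device, kernels are launched asynchronously, so one might pass `torch.cuda.synchronize` as `sync` to make the measured interval cover the actual GPU work; this is an assumption about usage, not something the script currently does.

```python
import time


def timed(fn, *args, sync=None, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds).

    `sync` is an optional zero-argument callable (hypothetical parameter)
    that is invoked before starting and before stopping the clock, e.g. a
    device synchronisation function.
    """
    if sync is not None:
        sync()  # make sure earlier asynchronous work is not counted
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()  # wait for the work started by fn to finish
    elapsed = time.perf_counter() - start
    return result, elapsed


# Usage on a plain Python function (no device synchronisation needed).
total, seconds = timed(sum, [1, 2, 3])
print("result:", total, "elapsed seconds:", seconds)
```

`time.perf_counter()` is used instead of `time.time()` because it is a monotonic clock intended for measuring short durations.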
--------------------------------------------------------------------------------
/fft/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |
--------------------------------------------------------------------------------
/fft/dft_test.py:
--------------------------------------------------------------------------------
1 | """
2 | FFT based image registration. --- main functions
3 | """
4 |
5 | from __future__ import division, print_function
6 |
7 | import math
8 |
9 | import numpy as np
10 | try:
11 | import pyfftw.interfaces.numpy_fft as fft
12 | except ImportError:
13 | import numpy.fft as fft
14 | import scipy.ndimage as ndii
15 |
16 | import imreg_dft.utils as utils
17 |
18 |
19 | def _logpolar_filter(shape):
20 | """
21 | Make a radial cosine filter for the logpolar transform.
22 | This filter suppresses low frequencies and completely removes
23 | the zero freq.
24 | """
25 | yy = np.linspace(- np.pi / 2., np.pi / 2., shape[0])[:, np.newaxis]
26 | xx = np.linspace(- np.pi / 2., np.pi / 2., shape[1])[np.newaxis, :]
27 | # Suppressing low spatial frequencies is a must when using log-polar
28 | # transform. The scale stuff is poorly reflected with low freqs.
29 | rads = np.sqrt(yy ** 2 + xx ** 2)
30 | filt = 1.0 - np.cos(rads) ** 2
31 | # vvv This doesn't really matter, very high freqs are not too usable anyway
32 | filt[np.abs(rads) > np.pi / 2] = 1
33 | return filt
34 |
35 |
36 | def _get_pcorr_shape(shape):
37 | ret = (int(max(shape) * 1.0),) * 2
38 | return ret
39 |
40 |
41 | def _get_ang_scale(ims, bgval, exponent='inf', constraints=None, reports=None):
42 | """
43 | Given two images, return their scale and angle difference.
44 | Args:
45 | ims (2-tuple-like of 2D ndarrays): The images
46 | bgval: We also pad here in the :func:`map_coordinates`
47 | exponent (float or 'inf'): The exponent stuff, see :func:`similarity`
48 | constraints (dict, optional)
49 | reports (optional)
50 | Returns:
51 | tuple: Scale, angle. Describes the relationship of
52 | the subject image to the first one.
53 | """
54 | assert len(ims) == 2, \
55 | "Only two images are supported as input"
56 | shape = ims[0].shape
57 |
58 | ims_apod = [utils._apodize(im) for im in ims]
59 | dfts = [fft.fftshift(abs(fft.fft2(im/255))) for im in ims]
60 | # imshow(dfts[0]*1000,dfts[0]*1000,dfts[0]*1000)
61 | # imshow(ims_apod[0], ims_apod[0], ims_apod[0])
62 |
63 | filt = _logpolar_filter(shape)
64 | dfts = [dft * filt for dft in dfts]
65 | # imshow(dfts[0]*1000,dfts[0]*1000,dfts[0]*1000)
66 |
67 | # High-pass filtering used to be here, but we have moved it to a higher
68 | # level interface
69 |
70 | pcorr_shape = _get_pcorr_shape(shape)
71 | log_base = _get_log_base(shape, pcorr_shape[1])
72 |
73 | stuffs = [_logpolar(np.abs(dft), pcorr_shape, log_base)
74 | for dft in dfts]
75 | # print(stuffs[0])
76 | (arg_ang, arg_rad), success = _phase_correlation(
77 | stuffs[0], stuffs[1],
78 | utils.argmax_angscale, log_base, exponent, constraints, reports)
79 |
80 | angle = -np.pi * arg_ang / float(pcorr_shape[0])
81 | angle = np.rad2deg(angle)
82 | angle = utils.wrap_angle(angle, 360)
83 | scale = log_base ** arg_rad
84 |
85 | angle = - angle
86 | scale = 1.0 / scale
87 |
88 | if reports is not None:
89 | reports["shape"] = filt.shape
90 | reports["base"] = log_base
91 |
92 | if reports.show("spectra"):
93 | reports["dfts_filt"] = dfts
94 | if reports.show("inputs"):
95 | reports["ims_filt"] = [fft.ifft2(np.fft.ifftshift(dft))
96 | for dft in dfts]
97 | if reports.show("logpolar"):
98 | reports["logpolars"] = stuffs
99 |
100 | if reports.show("scale_angle"):
101 | reports["amas-result-raw"] = (arg_ang, arg_rad)
102 | reports["amas-result"] = (scale, angle)
103 | reports["amas-success"] = success
104 | extent_el = pcorr_shape[1] / 2.0
105 | reports["amas-extent"] = (
106 | log_base ** (-extent_el), log_base ** extent_el,
107 | -90, 90
108 | )
109 |
110 | if not 0.5 < scale < 2:
111 | raise ValueError(
112 | "Images are not compatible. Scale change %g too big to be true."
113 | % scale)
114 |
115 | return scale, angle
116 |
117 |
118 | def translation(im0, im1, filter_pcorr=0, odds=1, constraints=None,
119 | reports=None):
120 | """
121 | Return translation vector to register images.
122 | It tells how to translate the im1 to get im0.
123 | Args:
124 | im0 (2D numpy array): The first (template) image
125 | im1 (2D numpy array): The second (subject) image
126 | filter_pcorr (int): Radius of the minimum spectrum filter
127 | for translation detection, use the filter when detection fails.
128 | Values > 3 are likely not useful.
129 | constraints (dict or None): Specify preference of sought values.
130 | For more detailed documentation, refer to :func:`similarity`.
131 | The only difference is that here, only keys ``tx`` and/or ``ty``
132 | (i.e. both or any of them or none of them) are used.
133 | odds (float): The greater the odds are, the higher is the preference
134 | of the angle + 180 over the original angle. Odds of -1 are the same
135 | as infinity.
136 | The value 1 is neutral, the converse of 2 is 1 / 2 etc.
137 | Returns:
138 | dict: Contains following keys: ``angle``, ``tvec`` (Y, X),
139 | and ``success``.
140 | """
141 | angle = 0
142 | report_one = report_two = None
143 | if reports is not None and reports.show("translation"):
144 | report_one = reports.copy_empty()
145 | report_two = reports.copy_empty()
146 |
147 | # We estimate translation for the original image...
148 | tvec, succ = _translation(im0, im1, filter_pcorr, constraints, report_one)
149 | # ... and for the 180-degrees rotated image (the rotation estimation
150 | # doesn't distinguish rotation of x vs x + 180deg).
151 | tvec2, succ2 = _translation(im0, utils.rot180(im1), filter_pcorr,
152 | constraints, report_two)
153 |
154 | pick_rotated = False
155 | if succ2 * odds > succ or odds == -1:
156 | pick_rotated = True
157 |
158 | if reports is not None and reports.show("translation"):
159 | reports["t0-orig"] = report_one["amt-orig"]
160 | reports["t0-postproc"] = report_one["amt-postproc"]
161 | reports["t0-success"] = succ
162 | reports["t0-tvec"] = tuple(tvec)
163 |
164 | reports["t1-orig"] = report_two["amt-orig"]
165 | reports["t1-postproc"] = report_two["amt-postproc"]
166 | reports["t1-success"] = succ2
167 | reports["t1-tvec"] = tuple(tvec2)
168 |
169 | if reports is not None and reports.show("transformed"):
170 | toapp = [
171 | transform_img(utils.rot180(im1), tvec=tvec2, mode="wrap", order=3),
172 | transform_img(im1, tvec=tvec, mode="wrap", order=3),
173 | ]
174 | if pick_rotated:
175 | toapp = toapp[::-1]
176 | reports["after_tform"].extend(toapp)
177 |
178 | if pick_rotated:
179 | tvec = tvec2
180 | succ = succ2
181 | angle = angle + 180
182 |
183 | ret = dict(tvec=tvec, success=succ, angle=angle)
184 | return ret
185 |
186 |
187 | def _get_precision(shape, scale=1):
188 | """
189 | Given the parameters of the log-polar transform, get width of the interval
190 | where the correct values are.
191 | Args:
192 | shape (tuple): Shape of images
193 | scale (float): The scale difference (precision varies)
194 | """
195 | pcorr_shape = _get_pcorr_shape(shape)
196 | log_base = _get_log_base(shape, pcorr_shape[1])
197 | # * 0.5 <= max deviation is half of the step
198 | # * 0.25 <= we got subpixel precision now and 0.5 / 2 == 0.25
199 | # scale: Scale deviation depends on the scale value
200 | Dscale = scale * (log_base - 1) * 0.25
201 | # angle: Angle deviation is constant
202 | Dangle = 180.0 / pcorr_shape[0] * 0.25
203 | return Dangle, Dscale
204 |
205 |
206 | def _similarity(im0, im1, numiter=1, order=3, constraints=None,
207 | filter_pcorr=0, exponent='inf', bgval=None, reports=None):
208 | """
209 | This function takes some input and returns mutual rotation, scale
210 | and translation.
211 | It does these things during the process:
212 | * Handles constraints correctly (defaults etc.).
213 | * Performs angle-scale determination iteratively.
214 | This involves keeping constraints in sync.
215 | * Performs translation determination.
216 | * Calculates precision.
217 | Returns:
218 | Dictionary with results.
219 | """
220 | if bgval is None:
221 | bgval = utils.get_borderval(im1, 5)
222 |
223 | shape = im0.shape
224 | if shape != im1.shape:
225 | raise ValueError("Images must have same shapes.")
226 | elif im0.ndim != 2:
227 | raise ValueError("Images must be 2-dimensional.")
228 |
229 | # We are going to iterate and refine the scale and angle estimates
230 | scale = 1.0
231 | angle = 0.0
232 | im2 = im1
233 |
234 | constraints_default = dict(angle=[0, None], scale=[1, None])
235 | if constraints is None:
236 | constraints = constraints_default
237 |
238 | # We guard against case when caller passes only one constraint key.
239 | # Now, the provided ones just replace defaults.
240 | constraints_default.update(constraints)
241 | constraints = constraints_default
242 |
243 | # During iterations, we have to work with constraints too.
244 | # So we make the copy in order to leave the original intact
245 | constraints_dynamic = constraints.copy()
246 | constraints_dynamic["scale"] = list(constraints["scale"])
247 | constraints_dynamic["angle"] = list(constraints["angle"])
248 |
249 | if reports is not None and reports.show("transformed"):
250 | reports["after_tform"] = [im2.copy()]
251 |
252 | for ii in range(numiter):
253 | newscale, newangle = _get_ang_scale([im0, im2], bgval, exponent,
254 | constraints_dynamic, reports)
255 | print("angle, scale",newangle, newscale)
256 | scale = scale*newscale
257 | angle = angle+newangle
258 |
259 | constraints_dynamic["scale"][0] = constraints_dynamic["scale"][0]/newscale
260 | constraints_dynamic["angle"][0] = constraints_dynamic["angle"][0]-newangle
261 | im2 = transform_img(im1, scale, angle, bgval=bgval, order=order)
262 |
263 | if reports is not None and reports.show("transformed"):
264 | reports["after_tform"].append(im2.copy())
265 |
266 | # Here we check how likely the 180-degree turn is
267 | target, stdev = constraints.get("angle", (0, None))
268 | odds = _get_odds(angle, target, stdev)
269 | # now we can use pcorr to guess the translation
270 | res = translation(im0, im2, filter_pcorr, odds,
271 | constraints, reports)
272 | print("odd")
273 |
274 | # print("newangle, newscale",angle, scale)
275 | # The log-polar transform may have got the angle wrong by 180 degrees.
276 | # The phase correlation can help us to correct that
277 | angle = angle + res["angle"]
278 | res["angle"] = utils.wrap_angle(angle, 360)
279 |
280 | # don't know what it does, but it alters the scale a little bit
281 | # scale = (im1.shape[1] - 1) / (int(im1.shape[1] / scale) - 1)
282 |
283 | Dangle, Dscale = _get_precision(shape, scale)
284 |
285 | res["scale"] = scale
286 | res["Dscale"] = Dscale
287 | res["Dangle"] = Dangle
288 | # 0.25 because we go subpixel now
289 | res["Dt"] = 0.25
290 |
291 | return res
292 |
293 |
294 | def similarity(im0, im1, numiter=1, order=3, constraints=None,
295 | filter_pcorr=0, exponent='inf', reports=None):
296 | """
297 | Return similarity transformed image im1 and transformation parameters.
298 | Transformation parameters are: isotropic scale factor, rotation angle (in
299 | degrees), and translation vector.
300 | A similarity transformation is an affine transformation with isotropic
301 | scale and without shear.
302 | Args:
303 | im0 (2D numpy array): The first (template) image
304 | im1 (2D numpy array): The second (subject) image
305 | numiter (int): How many times to iterate when determining scale and
306 | rotation
307 | order (int): Order of approximation (when doing transformations). 1 =
308 | linear, 3 = cubic etc.
309 | filter_pcorr (int): Radius of a spectrum filter for translation
310 | detection
311 | exponent (float or 'inf'): The exponent value used during processing.
312 | Refer to the docs for a thorough explanation. Generally, pass "inf"
313 | when feeling conservative. Otherwise, experiment, values below 5
314 | are not even supposed to work.
315 | constraints (dict or None): Specify preference of sought values.
316 | Pass None (default) for no constraints, otherwise pass a dict with
317 | keys ``angle``, ``scale``, ``tx`` and/or ``ty`` (i.e. you can pass
318 | all, some of them or none of them, all is fine). The value of a key
319 | is supposed to be a mutable 2-tuple (e.g. a list), where the first
320 | value is related to the constraint center and the second one to
321 | softness of the constraint (the higher is the number,
322 | the more soft a constraint is).
323 | More specifically, constraints may be regarded as weights
324 | in form of a shifted Gaussian curve.
325 | However, for precise meaning of keys and values,
326 | see the documentation section :ref:`constraints`.
327 | Names of dictionary keys map to names of command-line arguments.
328 | Returns:
329 | dict: Contains following keys: ``scale``, ``angle``, ``tvec`` (Y, X),
330 | ``success`` and ``timg`` (the transformed subject image)
331 | .. note:: There are limitations
332 | * Scale change must be less than 2.
333 | * No subpixel precision (but you can use *resampling* to get
334 | around this).
335 | """
336 | bgval = utils.get_borderval(im1, 5)
337 |
338 | res = _similarity(im0, im1, numiter, order, constraints,
339 | filter_pcorr, exponent, bgval, reports)
340 |
341 | im2 = transform_img_dict(im1, res, bgval, order)
342 | # Order of mask should be always 1 - higher values produce strange results.
343 | imask = transform_img_dict(np.ones_like(im1), res, 0, 1)
344 | # This removes some weird artifacts
345 | imask[imask > 0.8] = 1.0
346 |
347 | # Framing here = just blending the im2 with its BG according to the mask
348 | im3 = utils.frame_img(im2, imask, 10)
349 |
350 | res["timg"] = im3
351 | return res
352 |
353 |
354 | def _get_odds(angle, target, stdev):
355 | """
356 | Determine whether we are more likely to choose the angle, or angle + 180
357 | Args:
358 | angle (float, degrees): The base angle.
359 | target (float, degrees): The angle we think is the right one.
360 | Typically, we take this from constraints.
361 | stdev (float, degrees): The relevance of the target value.
362 | Also typically taken from constraints.
363 | Return:
364 | float: The greater the odds are, the higher is the preference
365 | of the angle + 180 over the original angle. Odds of -1 are the same
366 | as infinity.
367 | """
368 | ret = 1
369 | if stdev is not None:
370 | diffs = [abs(utils.wrap_angle(ang, 360))
371 | for ang in (target - angle, target - angle + 180)]
372 | odds0, odds1 = 0, 0
373 | if stdev > 0:
374 | odds0, odds1 = [np.exp(- diff ** 2 / stdev ** 2) for diff in diffs]
375 | if odds0 == 0 and odds1 > 0:
376 | # -1 is treated as infinity in _translation
377 | ret = -1
378 | elif stdev == 0 or (odds0 == 0 and odds1 == 0):
379 | ret = -1
380 | if diffs[0] < diffs[1]:
381 | ret = 0
382 | else:
383 | ret = odds1 / odds0
384 | return ret
385 |
386 |
387 | def _translation(im0, im1, filter_pcorr=0, constraints=None, reports=None):
388 | """
389 | The plain wrapper for translation phase correlation, no big deal.
390 | """
391 | # Apodization and pcorr don't play along
392 | # im0, im1 = [utils._apodize(im, ratio=1) for im in (im0, im1)]
393 | ret, succ = _phase_correlation(
394 | im0, im1,
395 | utils.argmax_translation, filter_pcorr, constraints, reports)
396 | return ret, succ
397 |
398 |
399 | def _phase_correlation(im0, im1, callback=None, *args):
400 | """
401 | Computes phase correlation between im0 and im1
402 | Args:
403 | im0
404 | im1
405 | callback (function): Process the cross-power spectrum (i.e. choose
406 | coordinates of the best element, usually of the highest one).
407 | Defaults to :func:`imreg_dft.utils._argmax2D`
408 | Returns:
409 | tuple: The translation vector (Y, X). Translation vector of (0, 0)
410 | means that the two images match.
411 | """
412 | if callback is None:
413 | callback = utils._argmax2D
414 |
415 | # TODO: Implement some form of high-pass filtering of PHASE correlation
416 | f0, f1 = [fft.fft2(arr) for arr in (im0, im1)]
417 | print(im0)
418 | # spectrum can be filtered (already),
419 | # so we have to take precaution against dividing by 0
420 | eps = abs(f1).max() * 1e-15
421 |
422 | # cps == cross-power spectrum of im0 and im1
423 | cps = abs(fft.ifft2((f0 * f1.conjugate()) / (abs(f0) * abs(f1) + eps)))
424 |
425 | #scps = shifted cps
426 | scps = fft.fftshift(cps)
427 | # imshow(scps*1000,scps*1000,scps*1000)
428 | imshow(scps,scps, scps)
429 | #scps = cps
430 |
431 | (t0, t1), success = callback(scps, *args)
432 | ret = np.array((t0, t1))
433 |
434 | # _compensate_fftshift is not appropriate here, this is OK.
435 | # t0 -= f0.shape[0] // 2
436 | # t1 -= f0.shape[1] // 2
437 | print("ret_oir",ret)
438 |
439 | ret = ret-np.array(f0.shape, int) // 2
440 | print("ret",ret)
441 | return ret, success
442 |
443 |
444 | def transform_img_dict(img, tdict, bgval=None, order=1, invert=False):
445 | """
446 | Wrapper of :func:`transform_img`, works well with the :func:`similarity`
447 | output.
448 | Args:
449 | img
450 | tdict (dictionary): Transformation dictionary --- supposed to contain
451 | keys "scale", "angle" and "tvec"
452 | bgval
453 | order
454 | invert (bool): Whether to perform inverse transformation --- doesn't
455 | work very well with the translation.
456 | Returns:
457 | np.ndarray: .. seealso:: :func:`transform_img`
458 | """
459 | scale = tdict["scale"]
460 | angle = tdict["angle"]
461 | tvec = np.array(tdict["tvec"])
462 | if invert:
463 | scale = 1.0 / scale
464 | angle = -angle
465 | tvec = -tvec
466 | res = transform_img(img, scale, angle, tvec, bgval=bgval, order=order)
467 | return res
468 |
469 |
470 | def transform_img(img, scale=1.0, angle=0.0, tvec=(0, 0),
471 | mode="constant", bgval=None, order=1):
472 | """
473 | Return the image transformed by the given scale, rotation and translation.
474 | Args:
475 | img (2D or 3D numpy array): What will be transformed.
476 | If a 3D array is passed, it is treated in a manner in which RGB
477 | images are supposed to be handled - i.e. assume that coordinates
478 | are (Y, X, channels).
479 | Complex images are handled in a way that treats separately
480 | the real and imaginary parts.
481 | scale (float): The scale factor (scale > 1.0 means zooming in)
482 | angle (float): Degrees of rotation (clock-wise)
483 | tvec (2-tuple): Pixel translation vector, Y and X component.
484 | mode (string): The transformation mode (refer to e.g.
485 | :func:`scipy.ndimage.shift` and its kwarg ``mode``).
486 | bgval (float): Shade of the background (filling during transformations)
487 | If None is passed, :func:`imreg_dft.utils.get_borderval` with
488 | radius of 5 is used to get it.
489 | order (int): Order of approximation (when doing transformations). 1 =
490 | linear, 3 = cubic etc. Linear works surprisingly well.
491 | Returns:
492 | np.ndarray: The transformed img, embedded back into an array
493 | of the same shape as the input.
494 | """
495 | if img.ndim == 3:
496 | # A bloody painful special case of RGB images
497 | ret = np.empty_like(img)
498 | for idx in range(img.shape[2]):
499 | sli = (slice(None), slice(None), idx)
500 | ret[sli] = transform_img(img[sli], scale, angle, tvec,
501 | mode, bgval, order)
502 | return ret
503 | elif np.iscomplexobj(img):
504 | decomposed = np.empty(img.shape + (2,), float)
505 | decomposed[:, :, 0] = img.real
506 | decomposed[:, :, 1] = img.imag
507 | # The bgval makes little sense now, as we decompose the image
508 | res = transform_img(decomposed, scale, angle, tvec, mode, None, order)
509 | ret = res[:, :, 0] + 1j * res[:, :, 1]
510 | return ret
511 |
512 | if bgval is None:
513 | bgval = utils.get_borderval(img)
514 |
515 | bigshape = np.round(np.array(img.shape) * 1.2).astype(int)
516 | bg = np.zeros(bigshape, img.dtype) + bgval
517 |
518 | dest0 = utils.embed_to(bg, img.copy())
519 | # TODO: We have problems with complex numbers
520 | # that are not supported by zoom(), rotate() or shift()
521 | if scale != 1.0:
522 | dest0 = ndii.zoom(dest0, scale, order=order, mode=mode, cval=bgval)
523 | if angle != 0.0:
524 | dest0 = ndii.rotate(dest0, angle, order=order, mode=mode, cval=bgval)
525 |
526 | if tvec[0] != 0 or tvec[1] != 0:
527 | dest0 = ndii.shift(dest0, tvec, order=order, mode=mode, cval=bgval)
528 |
529 | bg = np.zeros_like(img) + bgval
530 | dest = utils.embed_to(bg, dest0)
531 | return dest
532 |
533 |
534 | def similarity_matrix(scale, angle, vector):
535 | """
536 | Return homogeneous transformation matrix from similarity parameters.
537 | Transformation parameters are: isotropic scale factor, rotation angle (in
538 | degrees), and translation vector (of size 2).
539 | The order of transformations is: scale, rotate, translate.
540 | """
541 | raise NotImplementedError("We have no idea what this is supposed to do")
542 | m_scale = np.diag([scale, scale, 1.0])
543 | m_rot = np.identity(3)
544 | angle = math.radians(angle)
545 | m_rot[0, 0] = math.cos(angle)
546 | m_rot[1, 1] = math.cos(angle)
547 | m_rot[0, 1] = -math.sin(angle)
548 | m_rot[1, 0] = math.sin(angle)
549 | m_transl = np.identity(3)
550 | m_transl[:2, 2] = vector
551 | return np.dot(m_transl, np.dot(m_rot, m_scale))
552 |
553 |
554 | EXCESS_CONST = 1.1
555 |
556 |
557 | def _get_log_base(shape, new_r):
558 | """
559 | Basically common functionality of :func:`_logpolar`
560 | and :func:`_get_ang_scale`
561 | This value can be considered fixed, if you want to mess with the logpolar
562 | transform, mess with the shape.
563 | Args:
564 | shape: Shape of the original image.
565 | new_r (float): The r-size of the log-polar transform array dimension.
566 | Returns:
567 | float: Base of the log-polar transform.
568 | The following holds:
569 | :math:`log\_base = \exp( \ln [ \mathit{spectrum\_dim} ] / \mathit{logpolar\_scale\_dim} )`,
570 | or the equivalent :math:`log\_base^{\mathit{logpolar\_scale\_dim}} = \mathit{spectrum\_dim}`.
571 | """
572 | # The highest radius we have to accommodate is 'old_r',
573 | # However, we cut some parts out as only a thin part of the spectra has
574 | # these high frequencies
575 | old_r = shape[0] * EXCESS_CONST
576 | # We want the radius, so we divide the diameter by two.
577 | old_r /= 2.0
578 | # we have at most 'new_r' of space.
579 | log_base = np.exp(np.log(old_r) / new_r)
580 | return log_base
581 |
582 |
583 | def _logpolar(image, shape, log_base, bgval=None):
584 | """
585 | Return log-polar transformed image
586 | Takes into account anisotropy of the freq spectrum
587 | of rectangular images
588 | Args:
589 | image: The image to be transformed
590 | shape: Shape of the transformed image
591 | log_base: Parameter of the transformation, get it via
592 | :func:`_get_log_base`
593 | bgval: The background value. If None, use the 1st percentile of the image.
594 | Returns:
595 | The transformed image
596 | """
597 | if bgval is None:
598 | bgval = np.percentile(image, 1)
599 | imshape = np.array(image.shape)
600 | center = imshape[0] / 2.0, imshape[1] / 2.0
601 | # 0 .. pi = only half of the spectrum is used
602 | theta = utils._get_angles(shape)
603 | radius_x = utils._get_lograd(shape, log_base)
604 | radius_y = radius_x.copy()
605 | ellipse_coef = imshape[0] / float(imshape[1])
606 | # We have to acknowledge that the frequency spectrum can be deformed
607 | # if the image aspect ratio is not 1.0
608 | # The image is x-thin, so we acknowledge that the frequency spectra
609 | # scale in x is shrunk.
610 | # radius_x /= ellipse_coef
611 |
612 | y = radius_y * np.sin(theta) + center[0]
613 | x = radius_x * np.cos(theta) + center[1]
614 | output = np.empty_like(y)
615 | ndii.map_coordinates(image, [y, x], output=output, order=3,
616 | mode="constant", cval=bgval)
617 | # print(radius_x)
618 | # imshow(image, image, image)
619 | # imshow(output, output, output)
620 | # print(output)
621 | return output
622 |
623 |
624 | def imshow(im0, im1, im2, cmap=None, fig=None, **kwargs):
625 | """
626 | Plot images using matplotlib.
627 | Opens a new figure with four subplots:
628 | ::
629 | +----------------------+---------------------+
630 | |                      |                     |
631 | |   <template image>   |   <subject image>   |
632 | |                      |                     |
633 | +----------------------+---------------------+
634 | | <difference between  |                     |
635 | |  template and the    |<transformed subject>|
636 | | transformed subject> |                     |
637 | +----------------------+---------------------+
638 | Args:
639 | im0 (np.ndarray): The template image
640 | im1 (np.ndarray): The subject image
641 | im2: The transformed subject --- it is supposed to match the template
642 | cmap (optional): colormap
643 | fig (optional): The figure you would like to have this plotted on
644 | Returns:
645 | matplotlib figure: The figure with subplots
646 | """
647 | from matplotlib import pyplot
648 |
649 | if fig is None:
650 | fig = pyplot.figure()
651 | if cmap is None:
652 | cmap = 'coolwarm'
653 | # We do the difference between the template and the result now
654 | # To increase the contrast of the difference, we norm images according
655 | # to their near-maximums
656 | norm = np.percentile(np.abs(im2), 99.5) / np.percentile(np.abs(im0), 99.5)
657 | # Divide by zero is OK here
658 | phase_norm = np.median(np.angle(im2 / im0) % (2 * np.pi))
659 | if phase_norm != 0:
660 | norm *= np.exp(1j * phase_norm)
661 | im3 = abs(im2 - im0 * norm)
662 | pl0 = fig.add_subplot(221)
663 | pl0.imshow(im0.real, cmap, **kwargs)
664 | pl0.grid()
665 | share = dict(sharex=pl0, sharey=pl0)
666 | pl = fig.add_subplot(222, **share)
667 | pl.imshow(im1.real, cmap, **kwargs)
668 | pl.grid()
669 | pl = fig.add_subplot(223, **share)
670 | pl.imshow(im3, cmap, **kwargs)
671 | pl.grid()
672 | pl = fig.add_subplot(224, **share)
673 | pl.imshow(im2.real, cmap, **kwargs)
674 | pl.grid()
675 | return fig
676 |
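The translation step in `_phase_correlation` above can be illustrated in isolation. The following is a minimal NumPy-only sketch, assuming integer circular shifts and a plain argmax in place of the `utils.argmax_translation` callback; the function name `phase_correlation_shift` is hypothetical. It normalises the cross-power spectrum, shifts the zero offset to the array centre with `fftshift`, and subtracts half the shape from the peak position, mirroring the last lines of `_phase_correlation`.

```python
import numpy as np


def phase_correlation_shift(im0, im1):
    """Return the integer (dy, dx) that, applied to im1, reproduces im0.

    Sketch only: no filtering, no constraints, no subpixel refinement.
    """
    f0 = np.fft.fft2(im0)
    f1 = np.fft.fft2(im1)
    # Guard against division by zero, as in the code above.
    eps = np.abs(f1).max() * 1e-15
    # Normalised cross-power spectrum, transformed back to the image domain.
    cps = np.abs(np.fft.ifft2((f0 * np.conj(f1)) / (np.abs(f0) * np.abs(f1) + eps)))
    # Put the zero-shift position at the centre of the array.
    scps = np.fft.fftshift(cps)
    peak = np.unravel_index(np.argmax(scps), scps.shape)
    return np.array(peak) - np.array(scps.shape) // 2


# Synthetic check: shift a random image circularly by a known amount.
rng = np.random.default_rng(0)
template = rng.random((64, 64))
subject = np.roll(template, shift=(-5, 7), axis=(0, 1))

shift = phase_correlation_shift(template, subject)
restored = np.roll(subject, shift=tuple(shift.tolist()), axis=(0, 1))
print("estimated shift (dy, dx):", shift)
print("subject realigned to template:", np.allclose(restored, template))
```

The returned vector is the shift that moves the subject back onto the template, which matches the convention stated in the `translation` docstring ("how to translate the im1 to get im0").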
--------------------------------------------------------------------------------
/fft/fft_demo.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | # from imreg_test import *
5 | #from imreg_dft import imreg
6 | #import imreg
7 | from dft_test import _logpolar, similarity, imshow, _get_log_base
8 |
9 | template = cv2.imread("./1.jpg",0)
10 | source = cv2.imread("./1.jpg",0)
11 | # template = cv2.cvtColor(template, cv2.COLOR_RGB2GRAY)
12 | # source = cv2.cvtColor(source, cv2.COLOR_RGB2GRAY)
13 |
14 | row, col = template.shape  # NumPy shape is (rows, cols)
15 | center = (col // 2, row // 2)
16 |
17 | t_x = -50.
18 | t_y = -50.
19 |
20 |
21 | M = cv2.getRotationMatrix2D(center, -160, 0.8)
22 | source = cv2.warpAffine(source, M, (col, row))
23 |
24 | N = np.float32([[1,0,t_x],[0,1,t_y]])
25 | source = cv2.warpAffine(source, N, (col, row))
26 | source = cv2.resize(source, (col, row), interpolation=cv2.INTER_CUBIC)
27 | #logbase = _get_log_base(template.shape, template.shape)
28 | ##print(logbase)
29 |
30 | # im2, scale, angle, (t0, t1) = similarity(template, source)
31 | # print(scale, angle)
32 | # imshow(template, source, im2)
33 |
34 | # im2 = _logpolar(template, template.shape, 1.00878256)
35 | result = similarity(source, template)
36 | # im2 = template[0:300,0:1000]
37 | # cv2.imshow("1", im2)
38 | # cv2.waitKey(0)
39 | # imshow(im2, im2, im2)
40 |
41 | print(result['angle'], result['scale'], result)
42 | # imshow(source, template, result['timg'])
43 | # plt.show()
44 |
45 |
46 |
47 |
48 |
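The demo imports `_get_log_base` but only uses it in a commented-out line. As a worked example of how the log-polar base relates to the recovered scale and angle, here is a minimal sketch that restates the arithmetic from `_get_log_base` and `_get_ang_scale` in `dft_test.py`. The names `log_base_for` and `peak_to_scale_angle` are hypothetical, and the angle wrapping done by `utils.wrap_angle` is deliberately left out.

```python
import numpy as np

EXCESS_CONST = 1.1  # same constant as in dft_test.py


def log_base_for(shape, new_r):
    """Base of the log-polar transform for an image of the given shape."""
    old_r = shape[0] * EXCESS_CONST / 2.0  # largest radius to cover
    return np.exp(np.log(old_r) / new_r)


def peak_to_scale_angle(arg_ang, arg_rad, pcorr_shape, log_base):
    """Convert a correlation peak offset into (scale, angle in degrees).

    Follows the sign conventions of _get_ang_scale, without angle wrapping.
    """
    angle = np.rad2deg(-np.pi * arg_ang / float(pcorr_shape[0]))
    scale = log_base ** arg_rad
    return 1.0 / scale, -angle


shape = (256, 256)
pcorr_shape = (256, 256)
log_base = log_base_for(shape, pcorr_shape[1])

# A peak offset of 64 rows out of 256 corresponds to a quarter of 180 degrees.
scale, angle = peak_to_scale_angle(64, 0, pcorr_shape, log_base)
print("log base:", log_base)
print("scale:", scale, "angle (deg):", angle)
```

By construction, raising `log_base` to the power `new_r` gives back the covered radius `shape[0] * EXCESS_CONST / 2`, which is the relation stated in the `_get_log_base` docstring.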
--------------------------------------------------------------------------------
/fft/imreg_test.py:
--------------------------------------------------------------------------------
1 | # imreg.py
2 |
3 | # Copyright (c) 2011-2020, Christoph Gohlke
4 | # All rights reserved.
5 | #
6 | # Redistribution and use in source and binary forms, with or without
7 | # modification, are permitted provided that the following conditions are met:
8 | #
9 | # 1. Redistributions of source code must retain the above copyright notice,
10 | # this list of conditions and the following disclaimer.
11 | #
12 | # 2. Redistributions in binary form must reproduce the above copyright notice,
13 | # this list of conditions and the following disclaimer in the documentation
14 | # and/or other materials provided with the distribution.
15 | #
16 | # 3. Neither the name of the copyright holder nor the names of its
17 | # contributors may be used to endorse or promote products derived from
18 | # this software without specific prior written permission.
19 | #
20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
24 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
25 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
26 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
27 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
28 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
29 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
30 | # POSSIBILITY OF SUCH DAMAGE.
31 |
32 | """FFT based image registration.
33 | Imreg is a Python library that implements an FFT-based technique for
34 | translation, rotation and scale-invariant image registration [1].
35 | :Author:
36 | `Christoph Gohlke `_
37 | :Organization:
38 | Laboratory for Fluorescence Dynamics, University of California, Irvine
39 | :License: BSD 3-Clause
40 | :Version: 2020.1.1
41 | Requirements
42 | ------------
43 | * `CPython >= 3.6 `_
44 | * `Numpy 1.14 `_
45 | * `Scipy 1.3 `_
46 | * `Matplotlib 3.1 `_ (optional for plotting)
47 | Notes
48 | -----
49 | Imreg is no longer being actively developed.
50 | This implementation is mainly for educational purposes.
51 | An improved version is being developed at https://github.com/matejak/imreg_dft.
52 | References
53 | ----------
54 | 1. An FFT-based technique for translation, rotation and scale-invariant
55 | image registration. BS Reddy, BN Chatterji.
56 | IEEE Transactions on Image Processing, 5, 1266-1271, 1996
57 | 2. An IDL/ENVI implementation of the FFT-based algorithm for automatic
58 | image registration. H Xie, N Hicks, GR Keller, H Huang, V Kreinovich.
59 | Computers & Geosciences, 29, 1045-1055, 2003.
60 | 3. Image Registration Using Adaptive Polar Transform. R Matungka, YF Zheng,
61 | RL Ewing. IEEE Transactions on Image Processing, 18(10), 2009.
62 | Examples
63 | --------
64 | >>> im0 = imread('t400')
65 | >>> im1 = imread('Tr19s1.3')
66 | >>> im2, scale, angle, (t0, t1) = similarity(im0, im1)
67 | >>> imshow(im0, im1, im2)
68 | >>> im0 = imread('t350380ori')
69 | >>> im1 = imread('t350380shf')
70 | >>> t0, t1 = translation(im0, im1)
71 | >>> t0, t1
72 | (20, 50)
73 | """
74 |
75 | __version__ = '2020.1.1'
76 |
77 | __all__ = (
78 | 'translation', 'similarity', 'similarity_matrix', 'logpolar', 'highpass',
79 | 'imread', 'imshow', 'ndii'
80 | )
81 |
82 | import math
83 | import cv2
84 | import numpy
85 | from numpy.fft import fft2, ifft2, fftshift
86 |
87 | try:
88 | import scipy.ndimage.interpolation as ndii
89 | except ImportError:
90 | import ndimage.interpolation as ndii
91 |
92 |
93 | def translation(im0, im1):
94 | """Return translation vector to register images."""
95 | shape = im0.shape
96 | f0 = fft2(im0)
97 | f1 = fft2(im1)
98 | ir = abs(ifft2((f0 * f1.conjugate()) / (abs(f0) * abs(f1))))
99 | t0, t1 = numpy.unravel_index(numpy.argmax(ir), shape)
100 | if t0 > shape[0] // 2:
101 | t0 = t0 - shape[0]
102 | if t1 > shape[1] // 2:
103 | t1 = t1 - shape[1]
104 | return [t0, t1]
105 |
106 | def similarity(im0, im1):
107 | """Return similarity transformed image im1 and transformation parameters.
108 | Transformation parameters are: isotropic scale factor, rotation angle (in
109 | degrees), and translation vector.
110 | A similarity transformation is an affine transformation with isotropic
111 | scale and without shear.
112 | Limitations:
113 | Image shapes must be equal and square.
114 | All image areas must have same scale, rotation, and shift.
115 | Scale change must be less than 1.8.
116 | No subpixel precision.
117 | """
118 | if im0.shape != im1.shape:
119 | raise ValueError('images must have same shapes')
120 | if len(im0.shape) != 2:
121 | raise ValueError('images must be 2 dimensional')
122 |
123 | f0 = fftshift(abs(fft2(im0)))
124 | f1 = fftshift(abs(fft2(im1)))
125 |
126 | h = highpass(f0.shape)
127 | f0 *= h
128 | f1 *= h
129 | del h
130 |
131 | f0, log_base = logpolar(f0)
132 | f1, log_base = logpolar(f1)
133 |
134 | f0 = fft2(f0)
135 | f1 = fft2(f1)
136 | eps=1e-10
137 | r0 = abs(f0) * abs(f1)
138 | ir = abs(ifft2((f0 * f1.conjugate()) / (r0 + eps)))
139 | ir = fftshift(ir)
140 |
141 | i0, i1 = numpy.unravel_index(numpy.argmax(ir), ir.shape)
142 | # i0 -= f0.shape[0] // 2
143 | # i1 -= f0.shape[1] // 2
144 | print(i0, i1)
145 | angle = -180.0 * i0 / ir.shape[0]
146 | scale = log_base ** i1
147 | print(angle, scale)
148 | if scale > 1.8:
149 | ir = abs(ifft2((f1 * f0.conjugate()) / (r0 + eps)))
150 | ir = fftshift(ir)
151 | print("***********************")
152 | i0, i1 = numpy.unravel_index(numpy.argmax(ir), ir.shape)
153 | i0 = i0-ir.shape[0] // 2
154 | i1 = i1-ir.shape[1] // 2
155 | # imshow(ir*10000,ir*10000,ir*10000)
156 | print(i0, i1)
157 |
158 | angle = 180.0 * i0 / ir.shape[0]
159 | scale = 1.0 / (log_base ** i1)
160 | if scale > 1.8:
161 | raise ValueError('images are not compatible. Scale change > 1.8')
162 |
163 | # if angle < -90.0:
164 | # angle += 180.0
165 | # elif angle > 90.0:
166 | # angle -= 180.0
167 |
168 | print(angle, scale)
169 |
170 | im2 = ndii.zoom(im1, 1.0/scale)
171 | im2 = ndii.rotate(im2, -angle)
172 | if im2.shape < im0.shape:
173 | t = numpy.zeros_like(im0)
174 | t[:im2.shape[0], :im2.shape[1]] = im2
175 | im2 = t
176 | elif im2.shape > im0.shape:
177 | im2 = im2[:im0.shape[0], :im0.shape[1]]
178 |
179 | f0 = fft2(im0)
180 | f1 = fft2(im2)
181 | ir = abs(ifft2((f0 * f1.conjugate()) / (abs(f0) * abs(f1))))
182 | t0, t1 = numpy.unravel_index(numpy.argmax(ir), ir.shape)
183 |
184 | f2_rot = numpy.rot90(f1,2)
185 | f2_rot = f2_rot[:im0.shape[0], :im0.shape[1]]
186 | ir_rot = abs(ifft2((f0 * f2_rot.conjugate()) / (abs(f0) * abs(f2_rot))))
187 | t0_rot, t1_rot = numpy.unravel_index(numpy.argmax(ir_rot), ir_rot.shape)
188 |
189 | print("compare",ir[t0,t1],ir_rot[t0_rot,t1_rot])
190 | if(ir[t0,t1] < ir_rot[t0_rot,t1_rot]):
191 | angle = angle + 180
192 | im2 = numpy.rot90(im2, 2) # k counts quarter turns, so k=2 is a 180 degree rotation
193 |
194 | if t0 > f0.shape[0] // 2:
195 | t0 = t0-f0.shape[0]
196 | if t1 > f0.shape[1] // 2:
197 | t1 = t1-f0.shape[1]
198 |
199 | im2 = ndii.shift(im2, [t0, t1])
200 |
201 | # correct parameters for ndimage's internal processing
202 | if angle > 0.0:
203 | d = int(int(im1.shape[1] / scale) * math.sin(math.radians(angle)))
204 | t0, t1 = t1, d+t0
205 | elif angle < 0.0:
206 | d = int(int(im1.shape[0] / scale) * math.sin(math.radians(angle)))
207 | t0, t1 = d+t1, d+t0
208 | scale = (im1.shape[1] - 1) / (int(im1.shape[1] / scale) - 1)
209 |
210 | if angle < -180.0:
211 | angle = angle+360.0
212 | elif angle > 180.0:
213 | angle = angle-360.0
214 |
215 | return im2, scale, angle, [-t0, -t1]
216 |
217 |
218 | def similarity_matrix(scale, angle, vector):
219 | """Return homogeneous transformation matrix from similarity parameters.
220 | Transformation parameters are: isotropic scale factor, rotation angle
221 | (in degrees), and translation vector (of size 2).
222 | The order of transformations is: scale, rotate, translate.
223 | """
224 | S = numpy.diag([scale, scale, 1.0])
225 | R = numpy.identity(3)
226 | angle = math.radians(angle)
227 | R[0, 0] = math.cos(angle)
228 | R[1, 1] = math.cos(angle)
229 | R[0, 1] = -math.sin(angle)
230 | R[1, 0] = math.sin(angle)
231 | T = numpy.identity(3)
232 | T[:2, 2] = vector
233 | return numpy.dot(T, numpy.dot(R, S))
234 |
235 |
236 | def logpolar(image, angles=None, radii=None):
237 | """Return log-polar transformed image and log base."""
238 | shape = image.shape
239 | center = shape[0] / 2, shape[1] / 2
240 | if angles is None:
241 | angles = shape[0]
242 | if radii is None:
243 | radii = shape[1]
244 | theta = numpy.empty((angles, radii), dtype='float64')
245 | theta.T[:] = numpy.linspace(0, numpy.pi, angles, endpoint=False) * -1.0
246 | # d = radii
247 | d = numpy.hypot(shape[0] - center[0], shape[1] - center[1])
248 | log_base = 10.0 ** (math.log10(d) / (radii))
249 | radius = numpy.empty_like(theta)
250 | radius[:] = numpy.power(log_base,
251 | numpy.arange(radii, dtype='float64')) - 1.0
252 | x = radius * numpy.sin(theta) + center[0]
253 | y = radius * numpy.cos(theta) + center[1]
254 | output = numpy.empty_like(x)
255 | ndii.map_coordinates(image, [x, y], output=output)
256 | return output, log_base
257 |
258 |
259 | def highpass(shape):
260 | """Return highpass filter to be multiplied with fourier transform."""
261 | x = numpy.outer(
262 | numpy.cos(numpy.linspace(-math.pi/2.0, math.pi/2.0, shape[0])),
263 | numpy.cos(numpy.linspace(-math.pi/2.0, math.pi/2.0, shape[1])))
264 | return (1.0 - x) * (2.0 - x)
265 |
266 |
267 | def imread(fname, norm=True):
268 | """Return image data from img&hdr uint8 files."""
269 | with open(fname + '.hdr', 'r') as fh:
270 | hdr = fh.readlines()
271 | img = numpy.fromfile(fname + '.img', numpy.uint8, -1)
272 | img.shape = int(hdr[4].split()[-1]), int(hdr[3].split()[-1])
273 | if norm:
274 | img = img.astype('float64')
275 | img = img/255.0
276 | return img
277 |
278 |
279 | def imshow(im0, im1, im2, im3=None, cmap=None, **kwargs):
280 | """Plot images using matplotlib."""
281 | from matplotlib import pyplot
282 |
283 | if im3 is None:
284 | im3 = abs(im2 - im0)
285 | pyplot.subplot(221)
286 | pyplot.imshow(im0, cmap, **kwargs)
287 | pyplot.subplot(222)
288 | pyplot.imshow(im1, cmap, **kwargs)
289 | pyplot.subplot(223)
290 | pyplot.imshow(im3, cmap, **kwargs)
291 | pyplot.subplot(224)
292 | pyplot.imshow(im2, cmap, **kwargs)
293 | pyplot.show()
294 |
295 |
296 | if __name__ == '__main__':
297 | import os
298 | import doctest
299 |
300 | try:
301 | os.chdir('data')
302 | except Exception:
303 | pass
304 | doctest.testmod()
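The translation step used by `translation()` and at the end of `similarity()` can be sketched on synthetic data. This is a minimal illustration rather than the module's own code: the helper name `phase_correlation_shift`, the `eps` guard, and the random test image are assumptions introduced here.

```python
import numpy as np


def phase_correlation_shift(im0, im1, eps=1e-10):
    """Return the integer [row, col] shift s such that im0 ~= np.roll(im1, s)."""
    f0 = np.fft.fft2(im0)
    f1 = np.fft.fft2(im1)
    # Normalised cross-power spectrum: only the phase difference is kept.
    cross = f0 * np.conj(f1)
    cross /= np.abs(cross) + eps
    corr = np.abs(np.fft.ifft2(cross))
    peak = np.unravel_index(np.argmax(corr), corr.shape)
    # Peaks past the midpoint correspond to negative shifts (FFT wrap-around).
    return [int(p) - n if p > n // 2 else int(p)
            for p, n in zip(peak, corr.shape)]


rng = np.random.default_rng(0)
base = rng.random((64, 48))
moved = np.roll(base, shift=(5, -7), axis=(0, 1))  # circular shift, known ground truth
print(phase_correlation_shift(moved, base))
```

Each axis is unwrapped against its own length, which is the detail that matters when the row and column peaks are handled separately.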
--------------------------------------------------------------------------------
/fft/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import misc
3 | import cv2
4 | from argparse import ArgumentParser
5 |
6 | # a and b are 2-D numpy arrays of the same shape
7 | def phase_correlation(a, b):
8 | G_a = np.fft.fft2(a)
9 | G_b = np.fft.fft2(b)
10 | conj_b = np.conjugate(G_b)
11 | R = G_a*conj_b
12 | R /= np.absolute(R) + 1e-15 # guard against division by zero where the cross-power spectrum vanishes
13 | r = np.fft.ifft2(R).real
14 | return r
15 |
16 | def main():
17 |
18 | parser = ArgumentParser(description="Set parameters phase correlation calculation")
19 |
20 | parser.add_argument("infile1", metavar="in1", help="input image 1")
21 | parser.add_argument("infile2", metavar="in2", help="input image 2")
22 | parser.add_argument("outfile", metavar="out", help="output image file name")
23 |
24 | args = parser.parse_args()
25 |
26 | infile1 = open(args.infile1)
27 | infile2 = open(args.infile2)
28 | outfile = args.outfile
29 | newfile = open(outfile, 'w')
30 |
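31 | road1 = cv2.imread("./4.png", cv2.IMREAD_GRAYSCALE) # single channel, so fft2 runs over (H, W) and not (W, C)
32 | road2 = cv2.imread("./5.png", cv2.IMREAD_GRAYSCALE)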
33 | result = phase_correlation(road1, road2)
34 | cv2.imshow("s",result*1000)
35 | cv2.waitKey(0)
36 | infile1.close()
37 | infile2.close()
38 | newfile.close()
39 |
40 | if __name__=="__main__":
41 | main()
42 |
--------------------------------------------------------------------------------
/images_for_readme/Result1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZJU-Robotics-Lab/DPCN/c3e94beba3d3cf97f9ddfa24b77f309f5517d8b9/images_for_readme/Result1.png
--------------------------------------------------------------------------------
/images_for_readme/Result2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZJU-Robotics-Lab/DPCN/c3e94beba3d3cf97f9ddfa24b77f309f5517d8b9/images_for_readme/Result2.png
--------------------------------------------------------------------------------
/images_for_readme/simulation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZJU-Robotics-Lab/DPCN/c3e94beba3d3cf97f9ddfa24b77f309f5517d8b9/images_for_readme/simulation.png
--------------------------------------------------------------------------------
/log_polar/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |
--------------------------------------------------------------------------------
/log_polar/log_polar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.autograd import Variable
5 | import cv2
6 | import math
7 | from torch.autograd.gradcheck import gradcheck
8 | from torchvision import transforms, utils
9 | import matplotlib.pyplot as plt
10 |
11 | def polar_transformer(U, out_size, device, log=True, radius_factor=0.707):
12 | """Polar Transformer Layer
13 |
14 | Based on https://github.com/tensorflow/models/blob/master/transformer/spatial_transformer.py.
15 | _repeat(), _interpolate() are exactly the same;
16 | the polar transform implementation is in _transform()
17 |
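18 | Args:
19 | U: input batch of shape [B, H, W, C]; out_size: (out_height, out_width); device: torch device used for sampling and for the output
20 | log (bool): log-polar if True; else linear polar
21 | radius_factor (float): maxR / Width (only used by the linear polar branch)
22 | """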
23 | def _repeat(x, n_repeats):
24 | rep = torch.ones(n_repeats)
25 | rep = rep.unsqueeze(0) # unsqueeze is not in-place; shape becomes (1, n_repeats)
26 | x = torch.reshape(x, (-1, 1))
27 | x = x * rep
28 | return torch.reshape(x, [-1])
29 |
30 | def _interpolate(im, x, y, out_size): # im [B,H,W,C]
31 | # constants
32 | x = x.to(device)
33 | y = y.to(device)
34 | num_batch = im.shape[0]
35 | height = im.shape[1]
36 | width = im.shape[2]
37 | channels = im.shape[3]
38 | height_f = height
39 | width_f = width
40 |
41 | x = x.double()
42 | y = y.double()
43 | out_height = out_size[0]
44 | out_width = out_size[1]
45 | zero = torch.zeros([])
46 | max_y = im.shape[1] - 1
47 | max_x = im.shape[2] - 1
48 |
49 | # do sampling
50 | x0 = torch.floor(x).long()
51 | x1 = x0 + 1
52 | y0 = torch.floor(y).long()
53 | y1 = y0 + 1
54 |
55 | x0 = torch.clamp(x0, 0, max_x)
56 | x1 = torch.clamp(x1, 0, max_x)
57 | y0 = torch.clamp(y0, 0, max_y)
58 | y1 = torch.clamp(y1, 0, max_y)
59 | dim2 = width
60 | dim1 = width*height
61 |
62 | base = _repeat(torch.arange(num_batch, dtype=torch.long)*dim1, out_height*out_width)
63 | base = base.long()
64 | base = base.to(device)
65 | base_y0 = base + y0*dim2
66 | base_y1 = base + y1*dim2
67 | idx_a = base_y0 + x0
68 | idx_b = base_y1 + x0
69 | idx_c = base_y0 + x1
70 | idx_d = base_y1 + x1
71 |
72 |
73 | # use indices to lookup pixels in the flat image and restore
74 | # channels dim
75 | im_flat = torch.reshape(im, [-1, channels])
76 | im_flat = im_flat.clone().float().to(device)
77 |
78 | Ia = im_flat.gather(0, idx_a.unsqueeze(1))
79 | Ib = im_flat.gather(0, idx_b.unsqueeze(1))
80 | Ic = im_flat.gather(0, idx_c.unsqueeze(1))
81 | Id = im_flat.gather(0, idx_d.unsqueeze(1))
82 |
83 | # Ia = im_flat[idx_a].to(device)
84 | # Ib = im_flat[idx_b].to(device)
85 | # Ic = im_flat[idx_c].to(device)
86 | # Id = im_flat[idx_d].to(device)
87 |
88 | # and finally calculate interpolated values
89 | x0_f = x0.double()
90 | x1_f = x1.double()
91 | y0_f = y0.double()
92 | y1_f = y1.double()
93 | # print(((x1_f-x) * (y1_f-y)).shape)
94 | # print("-------------")
95 | wa = ((x1_f-x) * (y1_f-y)).unsqueeze(1)
96 | wb = ((x1_f-x) * (y-y0_f)).unsqueeze(1)
97 | wc = ((x-x0_f) * (y1_f-y)).unsqueeze(1)
98 | wd = ((x-x0_f) * (y-y0_f)).unsqueeze(1)
99 |
100 | # output = Ia + Ib + Ic + Id
101 | output = wa*Ia + wb*Ib + wc*Ic + wd*Id
102 | return output
103 |
104 | def _meshgrid(height, width):
105 | x_t = torch.ones([height, 1]) * torch.linspace(0.0, 1.0 * width-1, width).unsqueeze(1).permute(1, 0)
106 | y_t = torch.linspace(0.0, 1.0, height).unsqueeze(1) * torch.ones([1, width])
107 |
108 | x_t_flat = torch.reshape(x_t, (1, -1))
109 | y_t_flat = torch.reshape(y_t, (1, -1))
110 | grid = torch.cat((x_t_flat, y_t_flat), 0)
111 |
112 | return grid
113 |
114 | def _transform(input_dim, out_size):
115 | # radius_factor = torch.sqrt(torch.tensor(2.))/2.
116 | num_batch = input_dim.shape[0] # input [B,H,W,C]
117 | num_channels = input_dim.shape[3]
118 |
119 | out_height = out_size[0]
120 | out_width = out_size[1]
121 | grid = _meshgrid(out_height, out_width) # (2, WxH)
122 | grid = grid.unsqueeze(0)
123 | grid = torch.reshape(grid, [-1])
124 | grid = grid.repeat(num_batch)
125 | grid = torch.reshape(grid, [num_batch, 2, -1]) # (B,2,WxH)
126 |
127 | ## here we do the polar/log-polar transform
128 | W = torch.tensor(input_dim.shape[1], dtype = torch.double)
129 | # W = input_dim.shape[1].float()
130 | maxR = W*radius_factor
131 |
132 | # if radius is from 1 to W/2; log R is from 0 to log(W/2)
133 | # we map the -1 to +1 grid to log R
134 | # then remap to 0 to 1
135 | EXCESS_CONST = 1.1
136 |
137 | logbase = torch.exp(torch.log(W*EXCESS_CONST/2) / W) #10. ** (torch.log10(maxR) / W)
138 | #torch.exp(torch.log(W*EXCESS_CONST/2) / W) #
139 | # get radius in pix
140 | if log:
141 | # min=1, max=maxR
142 | r_s = torch.pow(logbase, grid[:, 0, :])
143 | else:
144 | # min=1, max=maxR
145 | r_s = 1 + (grid[:, 0, :] + 1)/2*(maxR-1)
146 |
147 | # one angle per output row; theta covers the half circle [0, -pi)
148 | theta = np.linspace(0., np.pi, out_height, endpoint=False) * -1.0
149 | t_s = torch.from_numpy(theta).unsqueeze(1) * torch.ones([1, out_width])
150 | t_s = torch.reshape(t_s, (1, -1))
151 |
152 | # use + theta[:, 0] to deal with origin
153 | x_s = r_s*torch.cos(t_s) + (W /2)
154 | y_s = r_s*torch.sin(t_s) + (W /2)
155 |
156 | x_s_flat = torch.reshape(x_s, [-1])
157 | y_s_flat = torch.reshape(y_s, [-1])
158 |
159 | input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, out_size)
160 | output = torch.reshape(input_transformed, [num_batch, out_height, out_width, num_channels]).to(device)
161 | return output, logbase
162 |
163 | output, logbase = _transform(U, out_size)
164 | return [output, logbase]
165 |
166 | #### Debug
167 | # def meshgrid(height, width):
168 | # x_t = torch.ones([height, 1]) * torch.linspace(0.0, 1.0 * width-1, width).unsqueeze(1).permute(1, 0)
169 | # y_t = torch.linspace(0.0, 1.0, height).unsqueeze(1) * torch.ones([1, width])
170 |
171 | # x_t_flat = torch.reshape(x_t, (1, -1))
172 | # y_t_flat = torch.reshape(y_t, (1, -1))
173 | # grid = torch.cat((x_t_flat, y_t_flat), 0)
174 |
175 | # return grid
176 |
177 | # def transform(input_dim, out_size):
178 | # radius_factor = torch.sqrt(torch.tensor(2.))/2.
179 | # log = True
180 | # num_batch = input_dim.shape[0] # input [B,W,H,C]
181 | # num_channels = input_dim.shape[3]
182 | # # theta = torch.reshape(theta, (-1, 2))
183 | # # theta = theta.float()
184 |
185 | # out_height = out_size[0]
186 | # out_width = out_size[1]
187 | # grid = meshgrid(out_height, out_width) # (2, WxH)
188 | # grid = grid.unsqueeze(0)
189 | # grid = torch.reshape(grid, [-1])
190 | # grid = grid.repeat(num_batch)
191 | # grid = torch.reshape(grid, [num_batch, 2, -1]) # (B,2,WxH)
192 |
193 | # ## here we do the polar/log-polar transform
194 | # W = torch.tensor(input_dim.shape[1], dtype = torch.double)
195 | # # W = input_dim.shape[1].float()
196 | # maxR = W*radius_factor
197 |
198 | # # if radius is from 1 to W/2; log R is from 0 to log(W/2)
199 | # # we map the -1 to +1 grid to log R
200 | # # then remap to 0 to 1
201 |
202 | # EXCESS_CONST = 1.1
203 |
204 | # logbase = torch.exp(torch.log(W*EXCESS_CONST/2) / W) #10. ** (torch.log10(maxR) / W)
205 |
206 | # # get radius in pix
207 | # if log:
208 | # # min=1, max=maxR
209 | # r_s = torch.pow(logbase, grid[:, 0, :])-1
210 | # else:
211 | # # min=1, max=maxR
212 | # r_s = 1 + (grid[:, 0, :] + 1)/2*(maxR-1)
213 | # # convert it to [0, 2maxR/W]
214 | # # r_s = (r_s - 1) / (maxR - 1) * 2 * maxR / W
215 | # # y is from -1 to 1; theta is from 0 to 2pi
216 | # theta = np.linspace(0, np.pi, W, endpoint=False) * -1.0
217 | # t_s = torch.from_numpy(theta).unsqueeze(1) * torch.ones([1, out_width])
218 | # t_s = torch.reshape(t_s, (1, -1))
219 |
220 | # # use + theta[:, 0] to deal with origin
221 | # x_s = r_s*torch.cos(t_s) + (W /2)#+ theta[:, 0, np.newaxis] # x
222 | # y_s = r_s*torch.sin(t_s) + (W /2) #+ theta[:, 1, np.newaxis]
223 |
224 | # x_s_flat = torch.reshape(x_s, [-1])
225 | # y_s_flat = torch.reshape(y_s, [-1])
226 |
227 | # input_transformed = interpolate(input_dim, x_s_flat, y_s_flat, out_size)
228 | # output = torch.reshape(input_transformed, [num_batch, out_height, out_width, num_channels])
229 | # return output
230 |
231 | # def repeat(x, n_repeats):
232 | # rep = torch.ones(n_repeats)
233 | # rep.unsqueeze(0)
234 | # x = torch.reshape(x, (-1, 1))
235 | # x = x * rep
236 | # return torch.reshape(x, [-1])
237 |
238 | # def interpolate(im, x, y, out_size): # im [B,H,W,C]
239 | # # constants
240 | # num_batch = im.shape[0]
241 | # height = im.shape[1]
242 | # width = im.shape[2]
243 | # channels = im.shape[3]
244 | # height_f = torch.DoubleTensor(height)
245 | # width_f = torch.DoubleTensor(width)
246 |
247 | # x = x.double()
248 | # y = y.double()
249 | # out_height = out_size[0]
250 | # out_width = out_size[1]
251 | # zero = torch.zeros([])
252 | # max_y = im.shape[1] - 1
253 | # max_x = im.shape[2] - 1
254 |
255 | # # do sampling
256 | # x0 = torch.floor(x).long()
257 | # x1 = x0 + 1
258 | # y0 = torch.floor(y).long()
259 | # y1 = y0 + 1
260 | # x0 = torch.clamp(x0, zero, max_x)
261 | # x1 = torch.clamp(x1, zero, max_x)
262 | # y0 = torch.clamp(y0, zero, max_y)
263 | # y1 = torch.clamp(y1, zero, max_y)
264 | # dim2 = width
265 | # dim1 = width*height
266 | # base = repeat(torch.range(0, num_batch-1, dtype=int)*dim1, out_height*out_width)
267 | # base = base.long()
268 |
269 | # base_y0 = base + y0*dim2
270 | # base_y1 = base + y1*dim2
271 | # idx_a = base_y0 + x0
272 | # idx_b = base_y1 + x0
273 | # idx_c = base_y0 + x1
274 | # idx_d = base_y1 + x1
275 |
276 | # # use indices to lookup pixels in the flat image and restore
277 | # # channels dim
278 | # im_flat = torch.reshape(im, [-1, channels])
279 | # im_flat = im_flat.clone().double()
280 |
281 | # Ia = im_flat.gather(0, idx_a.unsqueeze(1))
282 | # # Ia = im_flat[idx_a-1].to(device)
283 | # Ib = im_flat.gather(0, idx_b.unsqueeze(1))
284 | # # Ib = im_flat[idx_b-1].to(device)
285 | # Ic = im_flat.gather(0, idx_c.unsqueeze(1))
286 | # # Ic = im_flat[idx_c-1].to(device)
287 | # Id = im_flat.gather(0, idx_d.unsqueeze(1))
288 | # # Id = im_flat[idx_d-1].to(device)
289 |
290 | # # and finally calculate interpolated values
291 | # x0_f = x0.double()
292 | # x1_f = x1.double()
293 | # y0_f = y0.double()
294 | # y1_f = y1.double()
295 | # wa = ((x1_f-x) * (y1_f-y)).unsqueeze(1)
296 | # wb = ((x1_f-x) * (y-y0_f)).unsqueeze(1)
297 | # wc = ((x-x0_f) * (y1_f-y)).unsqueeze(1)
298 | # wd = ((x-x0_f) * (y-y0_f)).unsqueeze(1)
299 | # # output = torch.add([wa*Ia, wb*Ib, wc*Ic, wd*Id])
300 | # output = wa*Ia + wb*Ib + wc*Ic + wd*Id
301 | # return output
302 |
303 | # def imshow(tensor, title=None):
304 | # unloader = transforms.ToPILImage()
305 | # image = tensor.cpu().clone() # we clone the tensor to not do changes on it
306 | # image = image.squeeze(0) # remove the fake batch dimension
307 | # image = unloader(image)
308 | # plt.imshow(image)
309 | # plt.show()
310 |
311 | # trans = transforms.Compose([
312 | # transforms.ToTensor(),
313 | # # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # imagenet
314 | # ])
315 |
316 | # input = cv2.imread("./1.jpg", 0)
317 | # # # input = polar(input)
318 | # # # cv2.namedWindow("Image")
319 | # # # cv2.imshow("Image",input)
320 | # # # cv2.waitKey (0)
321 |
322 | # input = trans(input)
323 | # # imshow(input)
324 | # input = input.unsqueeze(0)
325 | # input = input.permute(0,2,3,1)
326 |
327 | # # print(np.shape(input))
328 | # # input = torch.randn((1,128,128,1), dtype=torch.double, requires_grad=True)
329 | # # input = torch.ones(10,128,128,3).requires_grad
330 |
331 | # # test = gradcheck(transform, input, eps=1e-6, atol=1e-4)
332 | # # print("Are the gradients correct: ", test)
333 |
334 | # output = transform(input, [677,677])
335 |
336 | # output_show = output.permute(0,3,1,2)
337 | # output_show = output_show.squeeze(0)
338 | # # print(output_show.shape)
339 | # imshow(output_show.float())
340 | # # output_show = output_show.squeeze(0)
341 |
342 | # # print(output_show)
343 |
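The log base chosen in `_transform()` satisfies `logbase ** W == W * EXCESS_CONST / 2`, so the last output column samples a radius slightly beyond half the image width. The sketch below reproduces that relation and shows how a correlation-peak offset maps back to rotation and scale, following the formulas used by the imreg-based `similarity()` earlier in this dump. The function names are hypothetical, and the sign of the angle and the direction of the scale depend on which image is treated as the template, so treat both as assumptions.

```python
import math


def logpolar_base(width, excess=1.1):
    """Log base such that base ** width == width * excess / 2."""
    return math.exp(math.log(width * excess / 2.0) / width)


def peak_to_rotation_scale(d_row, d_col, n_angles, logbase):
    """Convert a peak offset in the log-polar correlation map to (degrees, scale).

    Rows sample half a circle (180 degrees) in n_angles steps; columns sample
    the radius on a logarithmic axis with the given base.
    """
    angle_deg = 180.0 * d_row / n_angles
    scale = 1.0 / (logbase ** d_col)
    return angle_deg, scale


base = logpolar_base(256)
print(base ** 256)                               # close to 256 * 1.1 / 2
print(peak_to_rotation_scale(64, 0, 256, base))  # a quarter of the rows, no radial offset
```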
--------------------------------------------------------------------------------
/log_polar/polar.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def get_pixel_value(img, x, y):
4 | """
5 | Utility function to get pixel value for coordinate
6 | vectors x and y from a 4D tensor image.
7 | Input
8 | -----
9 | - img: tensor of shape (B, H, W, C)
10 | - x: integer tensor of shape (B, H, W)
11 | - y: integer tensor of shape (B, H, W)
12 | Returns
13 | -------
14 | - output: tensor of shape (B, H, W, C)
15 | """
16 | shape = tf.shape(x)
17 | batch_size = shape[0]
18 | height = shape[1]
19 | width = shape[2]
20 |
21 | batch_idx = tf.range(0, batch_size)
22 | batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
23 | b = tf.tile(batch_idx, (1, height, width))
24 |
25 | indices = tf.stack([b, y, x], 3)
26 |
27 | return tf.gather_nd(img, indices)
28 |
29 |
30 | def affine_grid_generator(height, width, theta):
31 | """
32 | This function returns a sampling grid, which when
33 | used with the bilinear sampler on the input feature
34 | map, will create an output feature map that is an
35 | affine transformation [1] of the input feature map.
36 | Input
37 | -----
38 | - height: desired height of grid/output. Used
39 | to downsample or upsample.
40 | - width: desired width of grid/output. Used
41 | to downsample or upsample.
42 | - theta: affine transform matrices of shape (num_batch, 2, 3).
43 | For each image in the batch, we have 6 theta parameters of
44 | the form (2x3) that define the affine transformation T.
45 | Returns
46 | -------
47 | - normalized grid (-1, 1) of shape (num_batch, 2, H, W).
48 | The 2nd dimension has 2 components: (x, y) which are the
49 | sampling points of the original image for each point in the
50 | target image.
51 | Note
52 | ----
53 | [1]: the affine transformation allows cropping, translation,
54 | and isotropic scaling.
55 | """
56 | num_batch = tf.shape(theta)[0]
57 |
58 | # create normalized 2D grid
59 | x = tf.linspace(-1.0, 1.0, width)
60 | y = tf.linspace(-1.0, 1.0, height)
61 | x_t, y_t = tf.meshgrid(x, y)
62 |
63 | # flatten
64 | x_t_flat = tf.reshape(x_t, [-1])
65 | y_t_flat = tf.reshape(y_t, [-1])
66 |
67 | # reshape to [x_t, y_t , 1] - (homogeneous form)
68 | ones = tf.ones_like(x_t_flat)
69 | sampling_grid = tf.stack([x_t_flat, y_t_flat, ones])
70 |
71 | # repeat grid num_batch times
72 | sampling_grid = tf.expand_dims(sampling_grid, axis=0)
73 | sampling_grid = tf.tile(sampling_grid, tf.stack([num_batch, 1, 1]))
74 |
75 | # cast to float32 (required for matmul)
76 | theta = tf.cast(theta, 'float32')
77 | sampling_grid = tf.cast(sampling_grid, 'float32')
78 |
79 | # transform the sampling grid - batch multiply
80 | batch_grids = tf.matmul(theta, sampling_grid)
81 | # batch grid has shape (num_batch, 2, H*W)
82 |
83 | # reshape to (num_batch, 2, H, W)
84 | batch_grids = tf.reshape(batch_grids, [num_batch, 2, height, width])
85 |
86 | return batch_grids
87 |
88 |
89 | def bilinear_sampler(img, x, y):
90 | """
91 | Performs bilinear sampling of the input images according to the
92 | normalized coordinates provided by the sampling grid. Note that
93 | the sampling is done identically for each channel of the input.
94 | To test if the function works properly, output image should be
95 | identical to input image when theta is initialized to identity
96 | transform.
97 | Input
98 | -----
99 | - img: batch of images in (B, H, W, C) layout.
100 | - grid: x, y which is the output of affine_grid_generator.
101 | Returns
102 | -------
103 | - out: interpolated images according to grids. Same size as grid.
104 | """
105 | H = tf.shape(img)[1]
106 | W = tf.shape(img)[2]
107 | max_y = tf.cast(H - 1, 'int32')
108 | max_x = tf.cast(W - 1, 'int32')
109 | zero = tf.zeros([], dtype='int32')
110 |
111 | # rescale x and y from [-1, 1] to [0, W-2] / [0, H-2], so that x0 + 1 and y0 + 1 stay inside the image
112 | x = tf.cast(x, 'float32')
113 | y = tf.cast(y, 'float32')
114 | x = 0.5 * ((x + 1.0) * tf.cast(max_x-1, 'float32'))
115 | y = 0.5 * ((y + 1.0) * tf.cast(max_y-1, 'float32'))
116 |
117 | # grab 4 nearest corner points for each (x_i, y_i)
118 | x0 = tf.cast(tf.floor(x), 'int32')
119 | x1 = x0 + 1
120 | y0 = tf.cast(tf.floor(y), 'int32')
121 | y1 = y0 + 1
122 |
123 | # clip to range [0, H-1/W-1] to not violate img boundaries
124 | x0 = tf.clip_by_value(x0, zero, max_x)
125 | x1 = tf.clip_by_value(x1, zero, max_x)
126 | y0 = tf.clip_by_value(y0, zero, max_y)
127 | y1 = tf.clip_by_value(y1, zero, max_y)
128 |
129 | # get pixel value at corner coords
130 | Ia = get_pixel_value(img, x0, y0)
131 | Ib = get_pixel_value(img, x0, y1)
132 | Ic = get_pixel_value(img, x1, y0)
133 | Id = get_pixel_value(img, x1, y1)
134 |
135 | # recast as float for delta calculation
136 | x0 = tf.cast(x0, 'float32')
137 | x1 = tf.cast(x1, 'float32')
138 | y0 = tf.cast(y0, 'float32')
139 | y1 = tf.cast(y1, 'float32')
140 |
141 | # calculate deltas
142 | wa = (x1-x) * (y1-y)
143 | wb = (x1-x) * (y-y0)
144 | wc = (x-x0) * (y1-y)
145 | wd = (x-x0) * (y-y0)
146 |
147 | # add dimension for addition
148 | wa = tf.expand_dims(wa, axis=3)
149 | wb = tf.expand_dims(wb, axis=3)
150 | wc = tf.expand_dims(wc, axis=3)
151 | wd = tf.expand_dims(wd, axis=3)
152 |
153 | # compute output
154 | out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
155 |
156 | return out
157 |
158 |
159 | # import torch
160 | # import cv2
161 | # import torch.nn.functional as F
162 | # import matplotlib.pyplot as plt
163 | # import numpy as np
164 | # import torchvision.models as models
165 | # import torchvision.transforms as transforms
166 | # import torch.nn as nn
167 | # import torch
168 |
169 | # import matplotlib
170 | # import matplotlib.pyplot as plt
171 |
172 | # theta = torch.Tensor([[0.707,0.707,0],[-0.707,0.707,0]]).unsqueeze(dim=0)
173 | # img = cv2.imread('1.jpg',cv2.IMREAD_GRAYSCALE)
174 | # plt.subplot(2,1,1)
175 | # plt.imshow(img,cmap='gray')
176 | # plt.axis('off')
177 | # img = torch.Tensor(img).unsqueeze(0).unsqueeze(0)
178 | # grid = F.affine_grid(theta,size=img.shape)
179 | # print(np.shape(grid))
180 | # print(grid)
181 | # new_img_PIL = transforms.ToPILImage()(grid).convert('RGB')
182 | # new_img_PIL.show()
183 |
184 | # output = F.grid_sample(img,grid)[0].numpy().transpose(1,2,0).squeeze()
185 | # plt.subplot(2,1,2)
186 | # plt.imshow(output,cmap='gray')
187 | # plt.axis('off')
188 | # plt.show()
189 |
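The four-corner weighting in `bilinear_sampler()` can be checked on a tiny array without TensorFlow. The following is a NumPy sketch under stated assumptions: it works on a single 2-D image with pixel coordinates (not the normalized [-1, 1] grid), and the name `bilinear_sample` is introduced here for illustration.

```python
import numpy as np


def bilinear_sample(img, x, y):
    """Sample a 2-D image at float pixel coordinates (x = column, y = row)."""
    h, w = img.shape
    # Top-left corner, clipped so the bottom-right corner stays in bounds.
    x0 = np.clip(np.floor(x).astype(int), 0, w - 2)
    y0 = np.clip(np.floor(y).astype(int), 0, h - 2)
    x1, y1 = x0 + 1, y0 + 1
    # Fractional position inside the cell, limited to [0, 1] at the borders.
    fx = np.clip(x - x0, 0.0, 1.0)
    fy = np.clip(y - y0, 0.0, 1.0)
    return ((1 - fx) * (1 - fy) * img[y0, x0]
            + (1 - fx) * fy * img[y1, x0]
            + fx * (1 - fy) * img[y0, x1]
            + fx * fy * img[y1, x1])


img = np.arange(12, dtype=float).reshape(3, 4)  # img[r, c] = 4 * r + c
print(bilinear_sample(img, np.array([1.5, 3.0]), np.array([0.5, 2.0])))
```

Because the test image is linear in both coordinates, the interpolated value equals `4 * y + x` at any in-range point, which makes mistakes in the weights easy to spot.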
--------------------------------------------------------------------------------
/log_polar/polarizeLayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | image = torch.randn(4,1,28,28)
3 | polar = torch.randn(4,3,28,28)
4 | center = torch.FloatTensor([image.shape[3]/2,image.shape[2]/2])
5 |
6 | for batch in range(image.shape[0]):
7 | for y in range(image.shape[2]):
8 | for x in range(image.shape[3]):
9 | Cart_coord = torch.FloatTensor([x - center[0], -(y - center[1])])
10 | rho = torch.mul(Cart_coord,Cart_coord).sum(-1).sqrt().log()
11 | theta = torch.atan2(Cart_coord[1],Cart_coord[0])
12 | polar[batch,0,y,x] = image[batch,0,y,x]
13 | polar[batch,1,y,x] = rho
14 | polar[batch,2,y,x] = theta
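The triple loop above fills one pixel at a time, and the log-radius is negative infinity at the centre pixel. A vectorized variant is sketched below. It uses NumPy in place of torch to keep the example dependency-light; the function name `polar_channels` and the `eps` offset inside the logarithm are assumptions added here, not part of the original script.

```python
import numpy as np


def polar_channels(image, eps=1e-6):
    """Map (B, 1, H, W) intensities to (B, 3, H, W): intensity, log-radius, angle."""
    b, _, h, w = image.shape
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    dx = xs - w / 2.0        # x offset from the image centre
    dy = -(ys - h / 2.0)     # flip y so that it points upwards
    rho = np.log(np.hypot(dx, dy) + eps)  # eps keeps the centre pixel finite
    theta = np.arctan2(dy, dx)
    out = np.empty((b, 3, h, w), dtype=np.float64)
    out[:, 0] = image[:, 0]
    out[:, 1] = rho          # broadcast over the batch dimension
    out[:, 2] = theta
    return out


image = np.random.default_rng(0).standard_normal((4, 1, 28, 28))
polar = polar_channels(image)
print(polar.shape)
```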
--------------------------------------------------------------------------------
/log_polar/tf_test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tflearn.layers.normalization import batch_normalization
4 | import tflearn
5 | from tflearn.layers.conv import conv_2d
6 |
7 | import numpy as np
8 |
9 |
10 | def _interpolate(im, x, y, out_size):
11 | with tf.variable_scope('_interpolate'):
12 | # constants
13 | num_batch = tf.shape(im)[0]
14 | height = tf.shape(im)[1]
15 | width = tf.shape(im)[2]
16 | channels = tf.shape(im)[3]
17 |
18 | x = tf.cast(x, 'float32')
19 | y = tf.cast(y, 'float32')
20 | height_f = tf.cast(height, 'float32')
21 | width_f = tf.cast(width, 'float32')
22 | out_height = out_size[0]
23 | out_width = out_size[1]
24 | zero = tf.zeros([], dtype='int32')
25 | max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
26 | max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
27 |
28 | # scale indices from [-1, 1] to [0, width/height]
29 | x = (x + 1.0)*(width_f) / 2.0
30 | y = (y + 1.0)*(height_f) / 2.0
31 |
32 | # do sampling
33 | x0 = tf.cast(tf.floor(x), 'int32')
34 | x1 = x0 + 1
35 | y0 = tf.cast(tf.floor(y), 'int32')
36 | y1 = y0 + 1
37 |
38 | x0 = tf.clip_by_value(x0, zero, max_x)
39 | x1 = tf.clip_by_value(x1, zero, max_x)
40 | y0 = tf.clip_by_value(y0, zero, max_y)
41 | y1 = tf.clip_by_value(y1, zero, max_y)
42 | dim2 = width
43 | dim1 = width*height
44 | base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
45 | base_y0 = base + y0*dim2
46 | base_y1 = base + y1*dim2
47 | idx_a = base_y0 + x0
48 | idx_b = base_y1 + x0
49 | idx_c = base_y0 + x1
50 | idx_d = base_y1 + x1
51 |
52 | # use indices to lookup pixels in the flat image and restore
53 | # channels dim
54 | im_flat = tf.reshape(im, tf.stack([-1, channels]))
55 | im_flat = tf.cast(im_flat, 'float32')
56 | Ia = tf.gather(im_flat, idx_a)
57 | Ib = tf.gather(im_flat, idx_b)
58 | Ic = tf.gather(im_flat, idx_c)
59 | Id = tf.gather(im_flat, idx_d)
60 |
61 | # and finally calculate interpolated values
62 | x0_f = tf.cast(x0, 'float32')
63 | x1_f = tf.cast(x1, 'float32')
64 | y0_f = tf.cast(y0, 'float32')
65 | y1_f = tf.cast(y1, 'float32')
66 | wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1)
67 | wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1)
68 | wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1)
69 | wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1)
70 | print(wa.shape)
71 | print(Ia.shape)
72 | output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
73 | return output
74 |
75 | def _meshgrid(height, width):
76 | x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
77 | tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
78 | y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
79 | tf.ones(shape=tf.stack([1, width])))
80 |
81 | x_t_flat = tf.reshape(x_t, (1, -1))
82 | y_t_flat = tf.reshape(y_t, (1, -1))
83 | grid = tf.concat(axis=0, values=[x_t_flat, y_t_flat])
84 |
85 | return grid
86 |
87 | grid = _meshgrid(600, 400)
88 | #print(np.shape(grid))
89 | #print(grid)
90 |
91 | grid = tf.expand_dims(grid, 0)
92 | grid = tf.reshape(grid, [-1])
93 | grid = tf.tile(grid, tf.stack([10]))
94 | grid = tf.reshape(grid, tf.stack([10, 2, -1]))
95 |
96 | sess = tf.Session()
97 | #print(sess.run(grid))
98 |
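The grid construction exercised at the end of this script (build a normalized meshgrid, flatten it, and repeat it once per batch element) can be reproduced in NumPy to check the resulting shapes. This is a sketch under assumptions: `batched_grid` is a hypothetical name, and the small sizes are chosen only to keep the output readable.

```python
import numpy as np


def batched_grid(height, width, num_batch):
    """Return a (num_batch, 2, height * width) array of normalized (x, y) coordinates."""
    xs = np.linspace(-1.0, 1.0, width)
    ys = np.linspace(-1.0, 1.0, height)
    x_t, y_t = np.meshgrid(xs, ys)               # each has shape (height, width)
    grid = np.stack([x_t.ravel(), y_t.ravel()])  # (2, height * width), row-major
    # Repeat the same grid for every element of the batch.
    return np.broadcast_to(grid, (num_batch, 2, height * width)).copy()


g = batched_grid(3, 4, 2)
print(g.shape)
print(g[0, 0, :4])  # x coordinates along the first row
```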
--------------------------------------------------------------------------------
/phase_correlation/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |
--------------------------------------------------------------------------------
/phase_correlation/phase_corr.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.abspath(".."))
4 |
5 | import cv2
6 | import math
7 | import torch
8 | import kornia
9 | import numpy as np
10 | import torch.nn as nn
11 | from numpy.fft import fft2, ifft2, fftshift
12 | import torch.nn.functional as F
13 | from torch.autograd import Variable
14 | from torchvision import transforms, utils
15 | import matplotlib.pyplot as plt
16 | from utils.utils import *
17 | from log_polar.log_polar import *
18 |
19 | def phase_corr(a, b, device, logbase, modelc2s, trans=False):
20 |     # a: template; b: source
21 |     # imshow(a.squeeze(0).float())
22 |     G_a = torch.rfft(a, 2, onesided=False)
23 |     G_b = torch.rfft(b, 2, onesided=False)
24 |     eps = 1e-15
25 |
26 |     real_a = G_a[:, :, :, 0]
27 |     real_b = G_b[:, :, :, 0]
28 |     imag_a = G_a[:, :, :, 1]
29 |     imag_b = G_b[:, :, :, 1]
30 |
31 |     # compute conj(a) * b (cross-power spectrum); shape=[B,H,W,2], last dim = (real, imag)
32 |     R = torch.FloatTensor(G_a.shape[0], G_a.shape[1], G_a.shape[2], 2).to(device)
33 |     R[:, :, :, 0] = real_a * real_b + imag_a * imag_b
34 |     R[:, :, :, 1] = real_a * imag_b - real_b * imag_a
35 |
36 |     r0 = torch.sqrt(real_a ** 2 + imag_a ** 2 + eps) * torch.sqrt(real_b ** 2 + imag_b ** 2 + eps)
37 |     R[:, :, :, 0] = R[:, :, :, 0].clone()/(r0 + eps).to(device)
38 |     R[:, :, :, 1] = R[:, :, :, 1].clone()/(r0 + eps).to(device)
39 |
40 |     r = torch.ifft(R, 2)
41 |     r_real = r[:, :, :, 0]
42 |     r_imag = r[:, :, :, 1]
43 |     r = torch.sqrt(r_real ** 2 + r_imag ** 2 + eps)
44 |     r = fftshift2d(r)
45 |     if trans:  # zero a 60-pixel border of the correlation map
46 |         r[:, 0:60, :] = 0.
47 |         r[:, G_a.shape[1]-60:G_a.shape[1], :] = 0.
48 |         r[:, :, 0:60] = 0.
49 |         r[:, :, G_a.shape[2]-60:G_a.shape[2]] = 0.
50 |     # imshow(r[0,:,:])
51 |     # plt.show()
52 |     # feed the result of phase correlation to the NET
53 |     softargmax_input = modelc2s(r.clone())
54 |     # reduce the output to per-angle and per-scale marginals
55 |     angle_resize_out_tensor = torch.sum(softargmax_input.clone(), 2, keepdim=False)
56 |     scale_resize_out_tensor = torch.sum(softargmax_input.clone(), 1, keepdim=False)
57 |     # get the argmax of the angle and the scale
58 |     angle_out_tensor = torch.argmax(angle_resize_out_tensor.clone().detach(), dim=-1)
59 |     scale_out_tensor = torch.argmax(scale_resize_out_tensor.clone().detach(), dim=-1)
60 |
61 |     # calculate angle
62 |     angle = angle_out_tensor*180.00/r.shape[1]
63 |     for batch_num in range(angle.shape[0]):
64 |         if angle[batch_num].item() > 90:
65 |             angle[batch_num] -= 90.00
66 |         else:
67 |             angle[batch_num] += 90.00
68 |     # compute the softmax in case it is needed
69 |     softmax_result = softmax2d(softargmax_input, device)
70 |     # imshow(softmax_result[0,:,:])
71 |     # plt.show()
72 |     # calculate scale
73 |     logbase = logbase.to(device)
74 |
75 |     sca_f = scale_out_tensor.clone() - r.shape[2] // 2
76 |     scale = 1 / torch.pow(logbase, sca_f.float())  # logbase ** sca_f
77 |
78 |     return [angle, scale, softmax_result, r]
79 |
80 | def highpass(shape):
81 |     """Return highpass filter to be multiplied with Fourier transform."""
82 |     i1 = torch.cos(torch.linspace(-np.pi/2.0, np.pi/2.0, shape[0]))
83 |     i2 = torch.cos(torch.linspace(-np.pi/2.0, np.pi/2.0, shape[1]))
84 |     x = torch.einsum('i,j->ij', i1, i2)
85 |     return (1.0 - x) * (1.0 - x)
86 |
87 | def logpolar_filter(shape):
88 |     """
89 |     Make a radial cosine filter for the logpolar transform.
90 |     This filter suppresses low frequencies and completely removes
91 |     the zero freq.
92 |     """
93 |     yy = np.linspace(- np.pi / 2., np.pi / 2., shape[0])[:, np.newaxis]
94 |     xx = np.linspace(- np.pi / 2., np.pi / 2., shape[1])[np.newaxis, :]
95 |     # Suppressing low spatial frequencies is a must when using log-polar
96 |     # transform. Scale is poorly reflected in low frequencies.
97 |     rads = np.sqrt(yy ** 2 + xx ** 2)
98 |     filt = 1.0 - np.cos(rads) ** 2
99 |     # vvv This doesn't really matter, very high freqs are not too usable anyway
100 |     filt[np.abs(rads) > np.pi / 2] = 1
101 |     filt = torch.from_numpy(filt)
102 |     return filt
103 |
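
As a quick check of what the docstring of `logpolar_filter` claims, the sketch below restates the same filter in plain NumPy (without the final `torch.from_numpy`) and evaluates it on a small grid. The name `radial_cosine_filter` is hypothetical and only used here for illustration. One caveat that follows from the `linspace` sampling: the centre sample sits exactly at zero frequency only when the dimension is odd, so only then is the minimum exactly 0.

```python
import numpy as np

def radial_cosine_filter(shape):
    """NumPy-only restatement of the radial cosine filter (illustrative)."""
    yy = np.linspace(-np.pi / 2.0, np.pi / 2.0, shape[0])[:, np.newaxis]
    xx = np.linspace(-np.pi / 2.0, np.pi / 2.0, shape[1])[np.newaxis, :]
    rads = np.sqrt(yy ** 2 + xx ** 2)
    # 1 - cos^2 is 0 at radius 0 and rises towards 1 at radius pi/2.
    filt = 1.0 - np.cos(rads) ** 2
    # Outside the inscribed circle the filter is set to pass everything.
    filt[rads > np.pi / 2.0] = 1.0
    return filt

filt = radial_cosine_filter((5, 5))
center_value = filt[2, 2]  # zero-frequency sample of the odd-sized grid
corner_value = filt[0, 0]  # radius > pi/2, so the filter passes it unchanged
```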
104 | class LogPolar(nn.Module):
105 |     def __init__(self, out_size, device):
106 |         super(LogPolar, self).__init__()
107 |         self.out_size = out_size
108 |         self.device = device
109 |
110 |     def forward(self, input):
111 |         return polar_transformer(input, self.out_size, self.device)
112 |
113 |
114 | class PhaseCorr(nn.Module):
115 |     def __init__(self, device, logbase, modelc2s, trans=False):
116 |         super(PhaseCorr, self).__init__()
117 |         self.device = device
118 |         self.logbase = logbase
119 |         self.modelc2s = modelc2s  # required by phase_corr()
120 |         self.trans = trans
121 |
122 |     def forward(self, template, source):
123 |         return phase_corr(template, source, self.device, self.logbase, self.modelc2s, trans=self.trans)
124 | ##############################################
125 | # grad check
126 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
127 |
128 | # x = torch.randn((1, 4, 3, 1), requires_grad=True)
129 | # y = torch.ones((1, 4, 3, 1),requires_grad=True)
130 |
131 | # logpolar_layer = LogPolar((x.shape[1], x.shape[2]), device)
132 | # x_logpolar, log_base = logpolar_layer(x)
133 | # y_logpolar, log_base = logpolar_layer(y)
134 |
135 | # x_logpolar = x_logpolar.squeeze(-1)
136 | # y_logpolar = y_logpolar.squeeze(-1)
137 | # y_logpolar.retain_grad()
138 |
139 | # phase_corr_layer_rs = PhaseCorr(device, log_base)
140 | # angle, scale, _ = phase_corr_layer_rs(x_logpolar, y_logpolar)
141 |
142 | # rx = torch.ifft(x, 2)
143 | # r_real = rx[:, :, :, 0]
144 | # r_real.retain_grad()
145 | # r_imag = rx[:, :, :, 1]
146 | # r = torch.sqrt(r_real ** 2 + r_imag ** 2)
147 | # r = fftshift2d(r)
148 | # r.sum().backward()
149 | # print(r_real.grad)
150 |
151 |
152 | # loss = nn.L1Loss()
153 | # loss = loss(x_logpolar)
154 | # print(x.device)
155 | # x_logpolar.sum().backward()
156 | # angle.backward()
157 | # print(y_logpolar.grad)
158 |
159 | # ###############################################
160 | # # overall check
161 |
162 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
163 |
164 | # trans = transforms.Compose([
165 | # transforms.ToTensor(),
166 | # ])
167 |
168 | # # read images
169 | # template = cv2.imread("./1_-70.jpg", 0)
170 | # source = cv2.imread("./7_33.jpg", 0)
171 | # t_x = 0.
172 | # t_y = 0.
173 | # # gt rot and scale
174 | # col, row = template.shape
175 | # center = (col // 2, row // 2)
176 |
177 | # N = np.float32([[1,0,t_x],[0,1,t_y]])
178 | # source = cv2.warpAffine(source, N, (col, row))
179 | # # M = cv2.getRotationMatrix2D(center, 0., 1.)
180 | # # source = cv2.warpAffine(source, M, (col, row))
181 |
182 |
183 | # template = trans(template)
184 | # source = trans(source)
185 | # # template = template.unsqueeze(0)
186 | # # source = source.unsqueeze(0)
187 | # # imshow(template)
188 | # # plt.show()
189 | # # imshow(source)
190 | # # plt.show()
191 |
192 | # # [B,H,W,C]
193 | # template = template.unsqueeze(0)
194 | # template = template.permute(0,2,3,1)
195 | # source = source.unsqueeze(0)
196 | # source = source.permute(0,2,3,1)
197 | # template_img = template.squeeze(-1)
198 | # source_img = source.squeeze(-1)
199 |
200 | # # fft
201 | # template = torch.rfft(template_img, 2, onesided=False)
202 | # source = torch.rfft(source_img, 2, onesided=False)
203 |
204 | # # fftshift
205 | # template_r = template[:, :, :, 0]
206 | # template_i = template[:, :, :, 1]
207 | # template = torch.sqrt(template_r ** 2 + template_i ** 2)
208 |
209 | # source_r = source[:, :, :, 0]
210 | # source_i = source[:, :, :, 1]
211 | # source = torch.sqrt(source_r ** 2 + source_i ** 2)
212 |
213 | # template = fftshift2d(template)
214 | # source = fftshift2d(source) # [B,H,W]
215 |
216 | # # highpass filter
217 | # h = logpolar_filter((source.shape[1],source.shape[2]))#highpass((source.shape[1],source.shape[2])) # [H,W]
218 | # template = template.squeeze(0) * h
219 | # source = source.squeeze(0) * h
220 |
221 | # # print(template)
222 | # # imshow(template.squeeze(0))
223 | # # change size
224 | # template = template.unsqueeze(-1)
225 | # source = source.unsqueeze(-1)
226 | # template = template.unsqueeze(0)
227 | # source = source.unsqueeze(0)
228 |
229 | # # log_polar
230 | # template, logbase = polar_transformer(template, (source.shape[1], source.shape[2]), device)
231 | # source, logbase = polar_transformer(source, (source.shape[1], source.shape[2]), device)
232 |
233 | # source = source.squeeze(-1)
234 | # template = template.squeeze(-1)
235 | # # imshow(template.squeeze(0))
236 |
237 | # # phase corr
238 | # rot, scale, _,_,_ = phase_corr(template, source, device, logbase, trans=False)
239 |
240 | # # angle = -rot * math.pi/180
241 | # center = torch.ones(1,2).to(device)
242 | # center[:, 0] = col // 2
243 | # center[:, 1] = row // 2
244 | # rot = torch.zeros(1).to(device)
245 | # rot_mat = kornia.get_rotation_matrix2d(center, -rot, 1/scale)
246 | # _, h, w = source_img.shape
247 | # new_source_img = kornia.warp_affine(source_img.unsqueeze(1).to(device), rot_mat, dsize=(h, w))
248 |
249 | # # theta = torch.tensor([
250 | # # [math.cos(angle), math.sin(-angle), 0],
251 | # # [math.sin(angle), math.cos(angle), 0]
252 | # # ], dtype=torch.float)
253 | # # grid = F.affine_grid(theta.unsqueeze(0), source_img.unsqueeze(0).size())
254 | # # output = F.grid_sample(source_img.unsqueeze(0), grid)
255 |
256 | # # theta = torch.tensor([
257 | # # [scale, 0., 0],
258 | # # [0., scale, 0]
259 | # # ], dtype=torch.float)
260 | # # grid = F.affine_grid(theta.unsqueeze(0), source_img.unsqueeze(0).size())
261 | # # output = F.grid_sample(output[0].unsqueeze(0), grid)
262 | # # new_source_img = output[0].to(device)
263 | # # imshow(source_img)
264 |
265 | # # imshow(new_source_img.squeeze(1))
266 | # # imshow(template_img)
267 |
268 | # t0, t1, success, _, r = phase_corr(template_img.to(device), new_source_img.squeeze(1), device, logbase, trans=True)
269 | # imshow(r)
270 | # plt.show()
271 | # print("success", success)
272 | # # if success == 1:
273 | # # rot += 180
274 | # # print("rot+= 180")
275 |
276 | # rot += success.squeeze(0)*180
277 |
278 | # if rot < -180.0:
279 | # rot += 360.0
280 | # elif rot > 180.0:
281 | # rot -= 360.0
282 |
283 | # print(rot)
284 | # print(scale)
285 | # print(t0,t1)
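
The core of `phase_corr` above is classical phase correlation: normalize the cross-power spectrum of the two inputs to unit magnitude, take the inverse FFT, and read the displacement off the peak. The following is a minimal NumPy sketch of that idea for a pure circular translation. It is not the repository's implementation: it has no learned `modelc2s` stage and no log-polar step, and the name `phase_correlation_shift` is hypothetical.

```python
import numpy as np

def phase_correlation_shift(template, source, eps=1e-15):
    """Estimate the circular shift (dy, dx) such that
    source == np.roll(template, (dy, dx), axis=(0, 1))."""
    G_a = np.fft.fft2(template)
    G_b = np.fft.fft2(source)
    # Cross-power spectrum conj(G_a) * G_b, normalized to keep only the phase.
    R = np.conj(G_a) * G_b
    R = R / (np.abs(R) + eps)
    # The inverse FFT of a pure phase ramp is a single peak at the shift.
    r = np.fft.fftshift(np.abs(np.fft.ifft2(R)))
    peak = np.unravel_index(np.argmax(r), r.shape)
    center = (r.shape[0] // 2, r.shape[1] // 2)
    return int(peak[0] - center[0]), int(peak[1] - center[1])

rng = np.random.default_rng(0)
template = rng.random((64, 64))
source = np.roll(template, shift=(3, -5), axis=(0, 1))
shift = phase_correlation_shift(template, source)
```

In this sketch, shifts are recovered unambiguously only within half the image size in each direction, because the correlation surface is periodic.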
--------------------------------------------------------------------------------
/phase_correlation/tmp.dot:
--------------------------------------------------------------------------------
1 | digraph {
2 | graph [size="12,12"]
3 | node [align=left fontsize=12 height=0.2 ranksep=0.1 shape=box style=filled]
4 | 139620185182544 [label=CopySlices fillcolor=white]
5 | 139620268556624 -> 139620185182544
6 | 139620227597648 -> 139620185182544
7 | 139620227597648 [label=ViewBackward fillcolor=white]
8 | 139620268560144 -> 139620227597648
9 | 139620268560144 [label=SelectBackward fillcolor=white]
10 | 139620268556624 -> 139620268560144
11 | 139620268556624 [label=DivBackward0 fillcolor=red]
12 | 139619904498512 -> 139620268556624
13 | 139619904498512 [label=MulBackward0 fillcolor=red]
14 | 139619904498640 -> 139619904498512
15 | 139619904498640 [label=SelectBackward fillcolor=white]
16 | 139619904498768 -> 139619904498640
17 | 139619904498768 [label=SliceBackward fillcolor=white]
18 | 139619904498896 -> 139619904498768
19 | 139619904498896 [label=StackBackward fillcolor=white]
20 | 139619904499088 -> 139619904498896
21 | 139619904499408 -> 139619904498896
22 | 139619904499408 [label=SumBackward1 fillcolor=white]
23 | 139619904499280 -> 139619904499408
24 | 139619904499280 [label=MulBackward0 fillcolor=red]
25 | 139619904499536 -> 139619904499280
26 | 139619904499536 [label=MulBackward0 fillcolor=red]
27 | 139619904540752 -> 139619904499536
28 | 139619904540752 [label=SoftmaxBackward fillcolor=white]
29 | 139619904540944 -> 139619904540752
30 | 139619904540944 [label=MulBackward0 fillcolor=red]
31 | 139619904541136 -> 139619904540944
32 | 139619904541136 [label=ViewBackward fillcolor=white]
33 | 139619904541328 -> 139619904541136
34 | 139619904541328 [label=CatBackward fillcolor=white]
35 | 139619904541520 -> 139619904541328
36 | 139619904541840 -> 139619904541328
37 | 139619904541840 [label=SliceBackward fillcolor=white]
38 | 139619904541712 -> 139619904541840
39 | 139619904541712 [label=SliceBackward fillcolor=white]
40 | 139619904541968 -> 139619904541712
41 | 139619904541968 [label=SliceBackward fillcolor=white]
42 | 139619904542096 -> 139619904541968
43 | 139619904542096 [label=CatBackward fillcolor=white]
44 | 139619904542288 -> 139619904542096
45 | 139619904542608 -> 139619904542096
46 | 139619904542608 [label=SliceBackward fillcolor=white]
47 | 139619904542480 -> 139619904542608
48 | 139619904542480 [label=SliceBackward fillcolor=white]
49 | 139619904542800 -> 139619904542480
50 | 139619904542800 [label=SliceBackward fillcolor=white]
51 | 139619904542992 -> 139619904542800
52 | 139619904542992 [label=SqrtBackward fillcolor=red]
53 | 139619904543120 -> 139619904542992
54 | 139619904543120 [label=AddBackward0 fillcolor=red]
55 | 139620227950416 -> 139619904543120
56 | 139619904543312 -> 139619904543120
57 | 139619904543312 [label=PowBackward0 fillcolor=red]
58 | 139619904543504 -> 139619904543312
59 | 139619904543504 [label=SelectBackward fillcolor=red]
60 | 139619904543696 -> 139619904543504
61 | 139619904543696 [label=SliceBackward fillcolor=red]
62 | 139619904543888 -> 139619904543696
63 | 139619904543888 [label=SliceBackward fillcolor=red]
64 | 139619904544080 -> 139619904543888
65 | 139619904544080 [label=SliceBackward fillcolor=red]
66 | 139619904544272 -> 139619904544080
67 | 139619904544272 [label=FftWithSizeBackward fillcolor=red]
68 | 139619904544464 -> 139619904544272
69 | 139619904544464 [label=CopySlices fillcolor=red]
70 | 139619904544656 -> 139619904544464
71 | 139619904557328 -> 139619904544464
72 | 139619904557328 [label=ExpandBackward fillcolor=red]
73 | 139619904557200 -> 139619904557328
74 | 139619904557200 [label=ViewBackward fillcolor=red]
75 | 139619904557456 -> 139619904557200
76 | 139619904557456 [label=AsStridedBackward fillcolor=red]
77 | 139619904544656 -> 139619904557456
78 | 139619904544656 [label=CopySlices fillcolor=red]
79 | 139619904557776 -> 139619904544656
80 | 139619904558032 -> 139619904544656
81 | 139619904558032 [label=AddBackward0 fillcolor=red]
82 | 139619904557904 -> 139619904558032
83 | 139619904557904 [label=MulBackward0 fillcolor=red]
84 | 139619904558160 -> 139619904557904
85 | 139619904558416 -> 139619904557904
86 | 139619904558416 [label=SqrtBackward fillcolor=red]
87 | 139619904558288 -> 139619904558416
88 | 139619904558288 [label=AddBackward0 fillcolor=red]
89 | 139619904558480 -> 139619904558288
90 | 139619904558800 -> 139619904558288
91 | 139619904558800 [label=PowBackward0 fillcolor=red]
92 | 139619904558672 -> 139619904558800
93 | 139619904558672 [label=SelectBackward fillcolor=red]
94 | 139619904558928 -> 139619904558672
95 | 139619904558928 [label=SliceBackward fillcolor=red]
96 | 139619904559120 -> 139619904558928
97 | 139619904559120 [label=SliceBackward fillcolor=red]
98 | 139620187330960 -> 139619904559120
99 | 139620187330960 [label=SliceBackward fillcolor=red]
100 | 139619904559504 -> 139620187330960
101 | 139619904559504 [label=FftWithSizeBackward fillcolor=red]
102 | 139619904559568 -> 139619904559504
103 | 139619904559568 [label=SqueezeBackward1 fillcolor=red]
104 | 139619904559760 -> 139619904559568
105 | 139619904559760 [label=ViewBackward fillcolor=red]
106 | 139619904559952 -> 139619904559760
107 | 139619904559952 [label=AddBackward0 fillcolor=red]
108 | 139619904560144 -> 139619904559952
109 | 139619904560464 -> 139619904559952
110 | 139619904560464 [label=MulBackward0 fillcolor=red]
111 | 139619904560336 -> 139619904560464
112 | 139619904560336 [label=CopyBackwards fillcolor=red]
113 | 139619904560592 -> 139619904560336
114 | 139619904560592 [label=IndexBackward fillcolor=red]
115 | 139619904560784 -> 139619904560592
116 | 139619904560784 [label=CopyBackwards fillcolor=red]
117 | 139619904560976 -> 139619904560784
118 | 139619904560976 [label=ViewBackward fillcolor=red]
119 | 139619904573520 -> 139619904560976
120 | 139619904573520 [label=UnsqueezeBackward0 fillcolor=red]
121 | 139619904573712 -> 139619904573520
122 | 139619904573712 [label=UnsqueezeBackward0 fillcolor=red]
123 | 139619904573840 -> 139619904573712
124 | 139619904573840 [label=MulBackward0 fillcolor=red]
125 | 139619904574032 -> 139619904573840
126 | 139619904574032 [label=SqueezeBackward1 fillcolor=red]
127 | 139619904574224 -> 139619904574032
128 | 139619904574224 [label=CatBackward fillcolor=red]
129 | 139619904574416 -> 139619904574224
130 | 139619904574736 -> 139619904574224
131 | 139619904574736 [label=SliceBackward fillcolor=red]
132 | 139619904574608 -> 139619904574736
133 | 139619904574608 [label=SliceBackward fillcolor=red]
134 | 139619904574864 -> 139619904574608
135 | 139619904574864 [label=SliceBackward fillcolor=red]
136 | 139619904575056 -> 139619904574864
137 | 139619904575056 [label=CatBackward fillcolor=red]
138 | 139619904575248 -> 139619904575056
139 | 139619904575440 -> 139619904575056
140 | 139619904575440 [label=SliceBackward fillcolor=red]
141 | 139619904575376 -> 139619904575440
142 | 139619904575376 [label=SliceBackward fillcolor=red]
143 | 139619904575568 -> 139619904575376
144 | 139619904575568 [label=SliceBackward fillcolor=red]
145 | 139619904575760 -> 139619904575568
146 | 139619904575760 [label=SqrtBackward fillcolor=red]
147 | 139619904575952 -> 139619904575760
148 | 139619904575952 [label=AddBackward0 fillcolor=red]
149 | 139619904576144 -> 139619904575952
150 | 139619904576400 -> 139619904575952
151 | 139619904576400 [label=PowBackward0 fillcolor=red]
152 | 139619904576272 -> 139619904576400
153 | 139619904576272 [label=SelectBackward fillcolor=red]
154 | 139619904576528 -> 139619904576272
155 | 139619904576528 [label=SliceBackward fillcolor=red]
156 | 139619904576720 -> 139619904576528
157 | 139619904576720 [label=SliceBackward fillcolor=red]
158 | 139619904576912 -> 139619904576720
159 | 139619904576912 [label=SliceBackward fillcolor=red]
160 | 139619904577104 -> 139619904576912
161 | 139619904577104 [label=FftWithSizeBackward fillcolor=red]
162 | 139619904577296 -> 139619904577104
163 | 139619904577296 [label=SqueezeBackward1 fillcolor=red]
164 | 139619904577488 -> 139619904577296
165 | 139619904577488 [label=PermuteBackward fillcolor=red]
166 | 139619904594128 -> 139619904577488
167 | 139619904594128 [label=UnsqueezeBackward0 fillcolor=red]
168 | 139619904203824 -> 139619904594128
169 | 139619904203824 [label="Variable
170 | (1, 4, 3)" fillcolor=lightblue]
171 | 139619904576144 [label=PowBackward0 fillcolor=red]
172 | 139619904594576 -> 139619904576144
173 | 139619904594576 [label=SelectBackward fillcolor=red]
174 | 139619904594640 -> 139619904594576
175 | 139619904594640 [label=SliceBackward fillcolor=red]
176 | 139619904594832 -> 139619904594640
177 | 139619904594832 [label=SliceBackward fillcolor=red]
178 | 139619904595024 -> 139619904594832
179 | 139619904595024 [label=SliceBackward fillcolor=red]
180 | 139619904577104 -> 139619904595024
181 | 139619904575248 [label=SliceBackward fillcolor=red]
182 | 139619904595408 -> 139619904575248
183 | 139619904595408 [label=SliceBackward fillcolor=red]
184 | 139619904595536 -> 139619904595408
185 | 139619904595536 [label=SliceBackward fillcolor=red]
186 | 139619904575760 -> 139619904595536
187 | 139619904574416 [label=SliceBackward fillcolor=white]
188 | 139619904595920 -> 139619904574416
189 | 139619904595920 [label=SliceBackward fillcolor=white]
190 | 139619904596048 -> 139619904595920
191 | 139619904596048 [label=SliceBackward fillcolor=white]
192 | 139619904575056 -> 139619904596048
193 | 139619904560144 [label=AddBackward0 fillcolor=red]
194 | 139619904596432 -> 139619904560144
195 | 139619904596688 -> 139619904560144
196 | 139619904596688 [label=MulBackward0 fillcolor=red]
197 | 139619904596560 -> 139619904596688
198 | 139619904596560 [label=CopyBackwards fillcolor=red]
199 | 139619904596816 -> 139619904596560
200 | 139619904596816 [label=IndexBackward fillcolor=red]
201 | 139619904560784 -> 139619904596816
202 | 139619904596432 [label=AddBackward0 fillcolor=red]
203 | 139619904597136 -> 139619904596432
204 | 139619904597392 -> 139619904596432
205 | 139619904597392 [label=MulBackward0 fillcolor=red]
206 | 139619904597264 -> 139619904597392
207 | 139619904597264 [label=CopyBackwards fillcolor=red]
208 | 139619904597520 -> 139619904597264
209 | 139619904597520 [label=IndexBackward fillcolor=red]
210 | 139619904560784 -> 139619904597520
211 | 139619904597136 [label=MulBackward0 fillcolor=red]
212 | 139619904597904 -> 139619904597136
213 | 139619904597904 [label=CopyBackwards fillcolor=red]
214 | 139619904610384 -> 139619904597904
215 | 139619904610384 [label=IndexBackward fillcolor=red]
216 | 139619904560784 -> 139619904610384
217 | 139619904558480 [label=PowBackward0 fillcolor=red]
218 | 139619904610768 -> 139619904558480
219 | 139619904610768 [label=SelectBackward fillcolor=red]
220 | 139619904610896 -> 139619904610768
221 | 139619904610896 [label=SliceBackward fillcolor=red]
222 | 139619904611088 -> 139619904610896
223 | 139619904611088 [label=SliceBackward fillcolor=red]
224 | 139619904611216 -> 139619904611088
225 | 139619904611216 [label=SliceBackward fillcolor=red]
226 | 139619904559504 -> 139619904611216
227 | 139619904558160 [label=SqrtBackward fillcolor=red]
228 | 139619904611600 -> 139619904558160
229 | 139619904611600 [label=AddBackward0 fillcolor=red]
230 | 139619904611728 -> 139619904611600
231 | 139619904611984 -> 139619904611600
232 | 139619904611984 [label=PowBackward0 fillcolor=red]
233 | 139619904611856 -> 139619904611984
234 | 139619904611856 [label=SelectBackward fillcolor=red]
235 | 139619904612112 -> 139619904611856
236 | 139619904612112 [label=SliceBackward fillcolor=red]
237 | 139619904612304 -> 139619904612112
238 | 139619904612304 [label=SliceBackward fillcolor=red]
239 | 139619904612496 -> 139619904612304
240 | 139619904612496 [label=SliceBackward fillcolor=red]
241 | 139619904612688 -> 139619904612496
242 | 139619904612688 [label=FftWithSizeBackward fillcolor=red]
243 | 139619904612880 -> 139619904612688
244 | 139619904612880 [label=SqueezeBackward1 fillcolor=red]
245 | 139619904613008 -> 139619904612880
246 | 139619904613008 [label=ViewBackward fillcolor=red]
247 | 139619904613200 -> 139619904613008
248 | 139619904613200 [label=AddBackward0 fillcolor=red]
249 | 139619904613392 -> 139619904613200
250 | 139619904613648 -> 139619904613200
251 | 139619904613648 [label=MulBackward0 fillcolor=red]
252 | 139619904613520 -> 139619904613648
253 | 139619904613520 [label=CopyBackwards fillcolor=red]
254 | 139619904613776 -> 139619904613520
255 | 139619904613776 [label=IndexBackward fillcolor=red]
256 | 139619904613968 -> 139619904613776
257 | 139619904613968 [label=CopyBackwards fillcolor=red]
258 | 139619904614160 -> 139619904613968
259 | 139619904614160 [label=ViewBackward fillcolor=red]
260 | 139619904614352 -> 139619904614160
261 | 139619904614352 [label=UnsqueezeBackward0 fillcolor=red]
262 | 139619904626896 -> 139619904614352
263 | 139619904626896 [label=UnsqueezeBackward0 fillcolor=red]
264 | 139619904627088 -> 139619904626896
265 | 139619904627088 [label=MulBackward0 fillcolor=red]
266 | 139619904627280 -> 139619904627088
267 | 139619904627280 [label=SqueezeBackward1 fillcolor=red]
268 | 139619904627408 -> 139619904627280
269 | 139619904627408 [label=CatBackward fillcolor=red]
270 | 139619904627600 -> 139619904627408
271 | 139619904627920 -> 139619904627408
272 | 139619904627920 [label=SliceBackward fillcolor=red]
273 | 139619904627792 -> 139619904627920
274 | 139619904627792 [label=SliceBackward fillcolor=red]
275 | 139619904628048 -> 139619904627792
276 | 139619904628048 [label=SliceBackward fillcolor=red]
277 | 139619904628240 -> 139619904628048
278 | 139619904628240 [label=CatBackward fillcolor=red]
279 | 139619904628432 -> 139619904628240
280 | 139619904628688 -> 139619904628240
281 | 139619904628688 [label=SliceBackward fillcolor=red]
282 | 139619904628560 -> 139619904628688
283 | 139619904628560 [label=SliceBackward fillcolor=red]
284 | 139619904628816 -> 139619904628560
285 | 139619904628816 [label=SliceBackward fillcolor=red]
286 | 139619904629008 -> 139619904628816
287 | 139619904629008 [label=SqrtBackward fillcolor=red]
288 | 139619904629200 -> 139619904629008
289 | 139619904629200 [label=AddBackward0 fillcolor=red]
290 | 139619904629392 -> 139619904629200
291 | 139619904629712 -> 139619904629200
292 | 139619904629712 [label=PowBackward0 fillcolor=red]
293 | 139619904629584 -> 139619904629712
294 | 139619904629584 [label=SelectBackward fillcolor=red]
295 | 139619904629840 -> 139619904629584
296 | 139619904629840 [label=SliceBackward fillcolor=red]
297 | 139619904630032 -> 139619904629840
298 | 139619904630032 [label=SliceBackward fillcolor=red]
299 | 139619904630224 -> 139619904630032
300 | 139619904630224 [label=SliceBackward fillcolor=red]
301 | 139619904630416 -> 139619904630224
302 | 139619904630416 [label=FftWithSizeBackward fillcolor=red]
303 | 139619904630608 -> 139619904630416
304 | 139619904630608 [label=SqueezeBackward1 fillcolor=red]
305 | 139619904639056 -> 139619904630608
306 | 139619904639056 [label=PermuteBackward fillcolor=red]
307 | 139619904639248 -> 139619904639056
308 | 139619904639248 [label=UnsqueezeBackward0 fillcolor=red]
309 | 139619904204064 -> 139619904639248
310 | 139619904204064 [label="Variable
311 | (1, 4, 3)" fillcolor=lightblue]
312 | 139619904629392 [label=PowBackward0 fillcolor=red]
313 | 139619904639824 -> 139619904629392
314 | 139619904639824 [label=SelectBackward fillcolor=red]
315 | 139619904639952 -> 139619904639824
316 | 139619904639952 [label=SliceBackward fillcolor=red]
317 | 139619904640144 -> 139619904639952
318 | 139619904640144 [label=SliceBackward fillcolor=red]
319 | 139619904640272 -> 139619904640144
320 | 139619904640272 [label=SliceBackward fillcolor=red]
321 | 139619904630416 -> 139619904640272
322 | 139619904628432 [label=SliceBackward fillcolor=red]
323 | 139619904640656 -> 139619904628432
324 | 139619904640656 [label=SliceBackward fillcolor=red]
325 | 139619904640784 -> 139619904640656
326 | 139619904640784 [label=SliceBackward fillcolor=red]
327 | 139619904629008 -> 139619904640784
328 | 139619904627600 [label=SliceBackward fillcolor=white]
329 | 139619904641168 -> 139619904627600
330 | 139619904641168 [label=SliceBackward fillcolor=white]
331 | 139619904641296 -> 139619904641168
332 | 139619904641296 [label=SliceBackward fillcolor=white]
333 | 139619904628240 -> 139619904641296
334 | 139619904613392 [label=AddBackward0 fillcolor=red]
335 | 139620186772880 -> 139619904613392
336 | 139619904641616 -> 139619904613392
337 | 139619904641616 [label=MulBackward0 fillcolor=red]
338 | 139619904641744 -> 139619904641616
339 | 139619904641744 [label=CopyBackwards fillcolor=red]
340 | 139619904641872 -> 139619904641744
341 | 139619904641872 [label=IndexBackward fillcolor=red]
342 | 139619904613968 -> 139619904641872
343 | 139620186772880 [label=AddBackward0 fillcolor=red]
344 | 139619904642192 -> 139620186772880
345 | 139619904642384 -> 139620186772880
346 | 139619904642384 [label=MulBackward0 fillcolor=red]
347 | 139619904642320 -> 139619904642384
348 | 139619904642320 [label=CopyBackwards fillcolor=red]
349 | 139619904642512 -> 139619904642320
350 | 139619904642512 [label=IndexBackward fillcolor=red]
351 | 139619904613968 -> 139619904642512
352 | 139619904642192 [label=MulBackward0 fillcolor=red]
353 | 139619904642896 -> 139619904642192
354 | 139619904642896 [label=CopyBackwards fillcolor=red]
355 | 139619904643024 -> 139619904642896
356 | 139619904643024 [label=IndexBackward fillcolor=red]
357 | 139619904613968 -> 139619904643024
358 | 139619904611728 [label=PowBackward0 fillcolor=red]
359 | 139619904659856 -> 139619904611728
360 | 139619904659856 [label=SelectBackward fillcolor=red]
361 | 139619904659920 -> 139619904659856
362 | 139619904659920 [label=SliceBackward fillcolor=red]
363 | 139619904660112 -> 139619904659920
364 | 139619904660112 [label=SliceBackward fillcolor=red]
365 | 139619904660304 -> 139619904660112
366 | 139619904660304 [label=SliceBackward fillcolor=red]
367 | 139619904612688 -> 139619904660304
368 | 139619904557776 [label=CopySlices fillcolor=red]
369 | 139619904660688 -> 139619904557776
370 | 139619904660944 -> 139619904557776
371 | 139619904660944 [label=ExpandBackward fillcolor=red]
372 | 139619904660816 -> 139619904660944
373 | 139619904660816 [label=ViewBackward fillcolor=red]
374 | 139619904661072 -> 139619904660816
375 | 139619904661072 [label=AsStridedBackward fillcolor=red]
376 | 139619904660688 -> 139619904661072
377 | 139619904660688 [label=CopySlices fillcolor=red]
378 | 139619904661456 -> 139619904660688
379 | 139619904661712 -> 139619904660688
380 | 139619904661712 [label=AddBackward0 fillcolor=red]
381 | 139619904557904 -> 139619904661712
382 | 139619904661456 [label=CopySlices fillcolor=red]
383 | 139619904661840 -> 139619904661456
384 | 139619904662032 -> 139619904661456
385 | 139619904662032 [label=ExpandBackward fillcolor=red]
386 | 139619904661968 -> 139619904662032
387 | 139619904661968 [label=ViewBackward fillcolor=red]
388 | 139619904662160 -> 139619904661968
389 | 139619904662160 [label=SubBackward0 fillcolor=red]
390 | 139619904662288 -> 139619904662160
391 | 139619904662608 -> 139619904662160
392 | 139619904662608 [label=MulBackward0 fillcolor=red]
393 | 139619904659856 -> 139619904662608
394 | 139619904558672 -> 139619904662608
395 | 139619904662288 [label=MulBackward0 fillcolor=red]
396 | 139619904610768 -> 139619904662288
397 | 139619904611856 -> 139619904662288
398 | 139619904661840 [label=CopySlices fillcolor=red]
399 | 139619904662800 -> 139619904661840
400 | 139619904662800 [label=ExpandBackward fillcolor=red]
401 | 139619904662928 -> 139619904662800
402 | 139619904662928 [label=ViewBackward fillcolor=red]
403 | 139619904663120 -> 139619904662928
404 | 139619904663120 [label=AddBackward0 fillcolor=red]
405 | 139619904663312 -> 139619904663120
406 | 139619904675984 -> 139619904663120
407 | 139619904675984 [label=MulBackward0 fillcolor=red]
408 | 139619904611856 -> 139619904675984
409 | 139619904558672 -> 139619904675984
410 | 139619904663312 [label=MulBackward0 fillcolor=red]
411 | 139619904659856 -> 139619904663312
412 | 139619904610768 -> 139619904663312
413 | 139620227950416 [label=PowBackward0 fillcolor=red]
414 | 139619904676496 -> 139620227950416
415 | 139619904676496 [label=SelectBackward fillcolor=red]
416 | 139619904676368 -> 139619904676496
417 | 139619904676368 [label=SliceBackward fillcolor=red]
418 | 139619904676624 -> 139619904676368
419 | 139619904676624 [label=SliceBackward fillcolor=red]
420 | 139619904676816 -> 139619904676624
421 | 139619904676816 [label=SliceBackward fillcolor=red]
422 | 139619904544272 -> 139619904676816
423 | 139619904542288 [label=SliceBackward fillcolor=white]
424 | 139619904676240 -> 139619904542288
425 | 139619904676240 [label=SliceBackward fillcolor=white]
426 | 139619904677264 -> 139619904676240
427 | 139619904677264 [label=SliceBackward fillcolor=white]
428 | 139619904542992 -> 139619904677264
429 | 139619904541520 [label=SliceBackward fillcolor=white]
430 | 139619904677584 -> 139619904541520
431 | 139619904677584 [label=SliceBackward fillcolor=white]
432 | 139619904677712 -> 139619904677584
433 | 139619904677712 [label=SliceBackward fillcolor=white]
434 | 139619904542096 -> 139619904677712
435 | 139619904499088 [label=SumBackward1 fillcolor=white]
436 | 139619904678096 -> 139619904499088
437 | 139619904678096 [label=MulBackward0 fillcolor=red]
438 | 139619904678224 -> 139619904678096
439 | 139619904678224 [label=MulBackward0 fillcolor=red]
440 | 139619904540752 -> 139619904678224
441 | }
442 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.5.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
2 | torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
3 | kornia
4 | graphviz
5 | opencv-python
6 | scipy
7 | matplotlib
8 | pandas
9 | tensorboardX
10 |
--------------------------------------------------------------------------------
/trainDPCN.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import kornia
4 | import time
5 | import copy
6 | import shutil
7 | import numpy as np
8 | import torch.nn as nn
9 | from graphviz import Digraph
10 | from torch.optim import lr_scheduler
11 | from collections import defaultdict
12 | import torch.nn.functional as F
13 | from unet.loss import dice_loss
14 | import torch.optim as optim
15 | from data.dataset import *
16 | from unet.pytorch_DPCN import FFT2, UNet, LogPolar, PhaseCorr, Corr2Softmax
17 | from data.dataset_DPCN import *
18 | from tensorboardX import SummaryWriter
19 | from utils.utils import *
20 | from utils.train_utils import *
21 | from validate import val_model
22 | import argparse
23 |
24 |
25 | # command-line options for configuring training
26 | parser = argparse.ArgumentParser(description="DPCN Network Training")
27 |
28 | parser.add_argument('--cpu', action='store_true', default=False, help="Run the training on the CPU instead of the GPU")
29 | parser.add_argument('--save_path', type=str, default="./checkpoints/", help="The path to save the checkpoint of every epoch")
30 | parser.add_argument('--simulation', action='store_true', default=False, help="The training will be applied on a randomly generated simulation dataset")
31 | parser.add_argument('--load_pretrained', action='store_true', default=False, help="Choose whether to use a pretrained model to fine tune")
32 | parser.add_argument('--load_path', type=str, default="./checkpoints/checkpoint.pt", help="The path to load a pretrained checkpoint")
33 | parser.add_argument('--load_optimizer', action='store_true', default=False, help="When using a pretrained model, also load its optimizer state")
34 | parser.add_argument('--pretrained_mode', type=str, default="all", help="Three options: 'all' for loading rotation and translation; 'rot' for loading only rotation; 'trans' for loading only translation")
35 | parser.add_argument('--use_dsnt', action='store_true', default=False, help="When enabled, the loss is calculated via DSNT and MSELoss; otherwise a CELoss is used")
36 | parser.add_argument('--batch_size_train', type=int, default=2, help="The batch size of training")
37 | parser.add_argument('--batch_size_val', type=int, default=2, help="The batch size of validation")
38 | parser.add_argument('--train_writer_path', type=str, default="./checkpoints/log/train/", help="Where to write the Log of training")
39 | parser.add_argument('--val_writer_path', type=str, default="./checkpoints/log/val/", help="Where to write the Log of validation")
40 | args = parser.parse_args()
41 |
42 | writer = SummaryWriter(log_dir=args.train_writer_path)
43 | writer_val = SummaryWriter(log_dir=args.val_writer_path)
44 | np.set_printoptions(threshold=np.inf)
45 |
46 |
47 | def train_model(model_template, model_source, model_corr2softmax, model_trans_template, model_trans_source, model_trans_corr2softmax,\
48 |                 optimizer_temp, optimizer_src, optimizer_c2s, optimizer_trans_temp, optimizer_trans_src, optimizer_trans_c2s,\
49 |                 scheduler_temp, scheduler_src, scheduler_trans_temp, scheduler_trans_src,\
50 |                 save_path, start_epoch, num_epochs=25):
51 |     best_loss = 1e10
52 |     iters = 0
53 |
54 |     for epoch in range(start_epoch , start_epoch + num_epochs):
55 |         print('Epoch {}/{}'.format(epoch, start_epoch + num_epochs - 1))
56 |         print('-' * 10)
57 |
58 |         since = time.time()
59 |
60 |         # Each epoch has a training and validation phase
61 |         for phase in ['train', 'val']:
62 |             if phase == 'train':
63 |
64 |                 for param_group in optimizer_temp.param_groups:
65 |                     print("LR", param_group['lr'])
66 |
67 |                 model_template.train() # Set model to training mode
68 |                 model_source.train()
69 |                 model_corr2softmax.train()
70 |                 model_trans_template.train()
71 |                 model_trans_source.train()
72 |                 model_trans_corr2softmax.train()
73 |             else:
74 |                 model_template.eval() # Set model to evaluate mode
75 |                 model_source.eval()
76 |                 model_corr2softmax.eval()
77 |                 model_trans_template.eval()
78 |                 model_trans_source.eval()
79 |                 model_trans_corr2softmax.eval()
80 |
81 |             metrics = defaultdict(float)
82 |             epoch_samples = 0
83 |
84 |             if phase == 'train':
85 |                 for template, source, groundTruth_number, scale_gt, gt_trans in dataloader(batch_size)[phase]:
86 |                     iters = iters + 1
87 |                     template = template.to(device)
88 |                     source = source.to(device)
89 |                     torch.autograd.set_detect_anomaly(True)
90 |
91 |                     # zero the parameter gradients
92 |                     optimizer_temp.zero_grad()
93 |                     optimizer_src.zero_grad()
94 |                     optimizer_c2s.zero_grad()
95 |                     optimizer_trans_temp.zero_grad()
96 |                     optimizer_trans_src.zero_grad()
97 |                     optimizer_trans_c2s.zero_grad()
98 |
99 |                     # forward
100 |                     loss_rot, loss_scale, loss_l1_rot, loss_mse_rot, loss_l1_scale, loss_mse_scale, template_visual_rot, source_visual_rot \
101 |                         = train_rot_scale(template, source, groundTruth_number.clone(), scale_gt.clone(),\
102 |                                           model_template, model_source, model_corr2softmax, phase, device )
103 |                     loss_y, loss_x, total_loss, loss_l1_x,loss_l1_y,loss_mse_x, loss_mse_y, template_visual_trans, source_visual_trans \
104 |                         = train_translation(template, source, groundTruth_number.clone(), scale_gt.clone(), gt_trans, \
105 |                                             model_trans_template, model_trans_source, model_trans_corr2softmax, phase, dsnt, device)
106 |
107 |
108 |                     # backward + optimize only if in training phase:
109 |                     if phase == 'train':
110 |                         # print(iters)
111 |                         with torch.autograd.detect_anomaly():
112 |                             total_loss.backward(retain_graph=False)
113 |                             loss_rot.backward(retain_graph=True)
114 |                             # loss_l1_rot.backward(retain_graph=False)
115 |                             # loss_scale.backward(retain_graph=True)
116 |                             # loss_x.backward(retain_graph=True)
117 |                             # loss_y.backward(retain_graph=True)
118 |                             optimizer_temp.step()
119 |                             optimizer_src.step()
120 |                             optimizer_c2s.step()
121 |                             optimizer_trans_temp.step()
122 |                             optimizer_trans_src.step()
123 |                             optimizer_trans_c2s.step()
124 |                     writer.add_scalar('LOSS ROTATION', loss_rot.detach().cpu().numpy(), iters)
125 |                     writer.add_scalar('LOSS SCALE', loss_scale.detach().cpu().numpy(), iters)
126 |                     writer.add_scalar('LOSS X', loss_x.detach().cpu().numpy(), iters)
127 |                     writer.add_scalar('LOSS Y', loss_y.detach().cpu().numpy(), iters)
128 |
129 |                     writer.add_scalar('LOSS ROTATION L1', loss_l1_rot.item(), iters)
130 |                     writer.add_scalar('LOSS ROTATION MSE', loss_mse_rot.item(), iters)
131 |                     writer.add_scalar('LOSS SCALE L1', loss_l1_scale.item(), iters)
132 |                     writer.add_scalar('LOSS SCALE MSE', loss_mse_scale.item(), iters)
133 |
134 |                     writer.add_scalar('LOSS X L1', loss_l1_x.item(), iters)
135 |                     writer.add_scalar('LOSS X MSE', loss_mse_x.item(), iters)
136 |                     writer.add_scalar('LOSS Y L1', loss_l1_y.item(), iters)
137 |                     writer.add_scalar('LOSS Y MSE', loss_mse_y.item(), iters)
138 |
139 |                     writer.add_image("temp_input", template[0,:,:].cpu(), iters)
140 |                     writer.add_image("src_input", source[0,:,:].cpu(), iters)
141 |                     writer.add_image("unet_temp_rot", template_visual_rot[0,:,:].cpu(), iters)
142 |                     writer.add_image("unet_src_rot", source_visual_rot[0,:,:].cpu(), iters)
143 |                     writer.add_image("unet_temp_trans", template_visual_trans[0,:,:].cpu(), iters)
144 |                     writer.add_image("unet_src_trans", source_visual_trans[0,:,:].cpu(), iters)
145 |                     # writer.add_image("fft_temp", template_fft_visual[0,:,:].detach().cpu(), iters)
146 |                     # writer.add_image("fft_src", source_fft_visual[0,:,:].detach().cpu(), iters)
147 |                     # writer.add_image("logpolar_temp", template_logpolar_visual[0,:,:].cpu(), iters)
148 |                     # writer.add_image("logpolar_src", source_logpolar_visual[0,:,:].cpu(), iters)
149 |                     # writer.add_image("new", new_source_img[0,:,:].cpu())
150 |
151 |                     # statistics
152 |                     epoch_samples = epoch_samples + template.size(0)
153 |
154 |
155 |             checkpoint = {'epoch': epoch + 1,
156 |                           'state_dict_temp': model_template.state_dict(),
157 |                           'optimizer_temp': optimizer_temp.state_dict(),
158 |                           'state_dict_src': model_source.state_dict(),
159 |                           'optimizer_src': optimizer_src.state_dict(),
160 |                           'state_dict_c2s': model_corr2softmax.state_dict(),
161 |                           'optimizer_c2s': optimizer_c2s.state_dict(),
162 |                           'state_dict_trans_temp': model_trans_template.state_dict(),
163 |                           'optimizer_trans_temp': optimizer_trans_temp.state_dict(),
164 |                           'state_dict_trans_src': model_trans_source.state_dict(),
165 |                           'optimizer_trans_src': optimizer_trans_src.state_dict(),
166 |                           'state_dict_trans_c2s': model_trans_corr2softmax.state_dict(),
167 |                           'optimizer_trans_c2s': optimizer_trans_c2s.state_dict()}
168 |
169 |             if phase == 'val':
170 |                 print("in val")
171 |                 loss_list = val_model(model_template, model_source, model_corr2softmax,\
172 |                                       model_trans_template, model_trans_source, model_trans_corr2softmax,\
173 |                                       writer_val, iters, dsnt, dataloader, batch_size_val, device, epoch)
174 |                 epoch_loss = np.mean(loss_list)
175 |                 print("epoch_loss", epoch_loss)
176 |                 print("best_loss", best_loss)
177 |                 # print("accuracy = ", acc)
178 |                 if epoch_loss < best_loss:
179 |                     is_best = True
180 |                     best_loss = epoch_loss
181 |                 else:
182 |                     is_best = False
183 |                 save_checkpoint(checkpoint, is_best, save_path)
184 |
185 |
186 |         time_elapsed = time.time() - since
187 |         print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
188 |
189 |
190 |         scheduler_temp.step()
191 |         scheduler_src.step()
192 |         scheduler_trans_temp.step()
193 |         scheduler_trans_src.step()
194 |
195 |     print('Best val loss: {:4f}'.format(best_loss))
196 |
197 |     return model_template, model_source
198 |
199 |
200 |
201 | save_path = args.save_path
202 | checkpoint_path = args.load_path
203 | load_pretrained = args.load_pretrained
204 | load_optimizer = args.load_optimizer
205 | simulation = args.simulation
206 | dsnt = args.use_dsnt
207 | load_pretrained_mode = args.pretrained_mode
208 | batch_size = args.batch_size_train
209 | batch_size_val = args.batch_size_val
210 | dataloader = generate_dataloader if simulation else DPCNdataloader
211 | device = torch.device("cuda:0" if not args.cpu else "cpu")
212 | print("The devices that the code is running on:", device)
213 | print("batch size is ",batch_size)
214 |
215 |
216 | # to create models for rotations and translations for source images and template images
217 | num_class = 1
218 | start_epoch = 0
219 | model_template = UNet(num_class).to(device)
220 | model_source = UNet(num_class).to(device)
221 | model_corr2softmax = Corr2Softmax(200., 0.).to(device)
222 | model_trans_template = UNet(num_class).to(device)
223 | model_trans_source = UNet(num_class).to(device)
224 | model_trans_corr2softmax = Corr2Softmax(11.72, 0.).to(device)
225 |
226 |
227 | optimizer_ft_temp = optim.Adam(filter(lambda p: p.requires_grad, model_template.parameters()), lr=4e-3)
228 | optimizer_ft_src = optim.Adam(filter(lambda p: p.requires_grad, model_source.parameters()), lr=4e-3)
229 | optimizer_c2s = optim.Adam(filter(lambda p: p.requires_grad, model_corr2softmax.parameters()), lr=1e-1)
230 | optimizer_trans_ft_temp = optim.AdamW(filter(lambda p: p.requires_grad, model_trans_template.parameters()), lr=4e-3)
231 | optimizer_trans_ft_src = optim.AdamW(filter(lambda p: p.requires_grad, model_trans_source.parameters()), lr=4e-3)
232 | optimizer_trans_c2s = optim.AdamW(filter(lambda p: p.requires_grad, model_trans_corr2softmax.parameters()), lr=5e-2)
233 |
234 | exp_lr_scheduler_temp = lr_scheduler.StepLR(optimizer_ft_temp, step_size=1, gamma=0.8)
235 | exp_lr_scheduler_src = lr_scheduler.StepLR(optimizer_ft_src, step_size=1, gamma=0.8)
236 | exp_lr_scheduler_trans_temp = lr_scheduler.StepLR(optimizer_trans_ft_temp, step_size=1, gamma=0.8)
237 | exp_lr_scheduler_trans_src = lr_scheduler.StepLR(optimizer_trans_ft_src, step_size=1, gamma=0.8)
238 |
239 |
240 | # load pretrained model based on the input pretrained mode
241 | if load_pretrained:
242 |     if load_pretrained_mode == 'all':
243 |         if load_optimizer:
244 |             model_template, model_source, model_corr2softmax, model_trans_template, model_trans_source, model_trans_corr2softmax,\
245 |             optimizer_ft_temp, optimizer_ft_src, optimizer_c2s, optimizer_trans_ft_temp, optimizer_trans_ft_src, optimizer_trans_c2s,\
246 |             start_epoch = load_checkpoint(\
247 |                 checkpoint_path, model_template, model_source, model_corr2softmax, model_trans_template, model_trans_source, model_trans_corr2softmax,\
248 |                 optimizer_ft_temp, optimizer_ft_src, optimizer_c2s, optimizer_trans_ft_temp, optimizer_trans_ft_src, optimizer_trans_c2s, device)
249 |         else:
250 |             model_template, model_source, model_corr2softmax, model_trans_template, model_trans_source, model_trans_corr2softmax,\
251 |             _, _, _, _, _, _,\
252 |             start_epoch = load_checkpoint(\
253 |                 checkpoint_path, model_template, model_source, model_corr2softmax, model_trans_template, model_trans_source, model_trans_corr2softmax,\
254 |                 optimizer_ft_temp, optimizer_ft_src, optimizer_c2s, optimizer_trans_ft_temp, optimizer_trans_ft_src, optimizer_trans_c2s, device)
255 |
256 |     if load_pretrained_mode == 'trans':
257 |         model_trans_template, model_trans_source,\
258 |         start_epoch = load_trans_checkpoint(\
259 |             checkpoint_path, model_trans_template, model_trans_source,\
260 |             device)
261 |     if load_pretrained_mode == 'rot':
262 |         model_template, model_source, model_corr2softmax,\
263 |         optimizer_ft_temp, optimizer_ft_src, optimizer_c2s = load_rot_checkpoint(\
264 |             checkpoint_path, model_template, model_source, model_corr2softmax,\
265 |             optimizer_ft_temp, optimizer_ft_src, optimizer_c2s, device)
266 |
267 | model_template, model_source = train_model(model_template, model_source, model_corr2softmax, model_trans_template, model_trans_source, model_trans_corr2softmax,\
268 | optimizer_ft_temp, optimizer_ft_src, optimizer_c2s, optimizer_trans_ft_temp, optimizer_trans_ft_src, optimizer_trans_c2s,\
269 | exp_lr_scheduler_temp, exp_lr_scheduler_src, exp_lr_scheduler_trans_temp, exp_lr_scheduler_trans_src,\
270 | save_path, start_epoch, num_epochs=700)
271 |
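The four U-Net optimizers above are wrapped in `StepLR(step_size=1, gamma=0.8)`, so their learning rate is multiplied by 0.8 after every epoch. The sketch below only reproduces that arithmetic in plain Python to show how quickly the rate shrinks over a long run. `lr_at_epoch` is a hypothetical helper, not part of the repository, and it assumes training starts from the base rate of 4e-3 with exactly one scheduler step per epoch.

```python
# Arithmetic behind StepLR(step_size=1, gamma=0.8); no PyTorch needed.
BASE_LR = 4e-3   # initial learning rate of the U-Net optimizers in trainDPCN.py
GAMMA = 0.8      # decay factor applied once per epoch


def lr_at_epoch(epoch, base_lr=BASE_LR, gamma=GAMMA):
    """Learning rate in effect during `epoch` (0-based), assuming one decay step per epoch."""
    return base_lr * gamma ** epoch


if __name__ == "__main__":
    for epoch in (0, 1, 10, 50, 100):
        print("epoch {:3d}: lr = {:.3e}".format(epoch, lr_at_epoch(epoch)))
```

The two `Corr2Softmax` optimizers have no scheduler in `trainDPCN.py`, so their learning rates stay at their initial values.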
--------------------------------------------------------------------------------
/unet/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Naoto Usuyama
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/unet/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |
--------------------------------------------------------------------------------
/unet/helper.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from functools import reduce
3 | import numpy as np
4 | import itertools
5 |
6 | def plot_img_array(img_array, ncol=3):
7 |     nrow = len(img_array) // ncol
8 |
9 |     f, plots = plt.subplots(nrow, ncol, sharex='all', sharey='all', squeeze=False, figsize=(ncol * 4, nrow * 4))  # squeeze=False keeps plots 2-D for a single row
10 |
11 |     for i in range(len(img_array)):
12 |         plots[i // ncol, i % ncol]
13 |         plots[i // ncol, i % ncol].imshow(img_array[i])
14 |
15 | def plot_side_by_side(img_arrays):
16 |     flatten_list = reduce(lambda x,y: x+y, zip(*img_arrays))
17 |
18 |     plot_img_array(np.array(flatten_list), ncol=len(img_arrays))
19 |
20 | def plot_errors(results_dict, title):
21 |     markers = itertools.cycle(('+', 'x', 'o'))
22 |
23 |     plt.title('{}'.format(title))
24 |
25 |     for label, result in sorted(results_dict.items()):
26 |         plt.plot(result, marker=next(markers), label=label)
27 |         plt.ylabel('dice_coef')
28 |         plt.xlabel('epoch')
29 |         plt.legend(loc=3, bbox_to_anchor=(1, 0))
30 |
31 |     plt.show()
32 |
33 | def masks_to_colorimg(masks):
34 |     colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75), (101, 172, 228),(56, 34, 132), (160, 194, 56)])
35 |
36 |     colorimg = np.ones((masks.shape[1], masks.shape[2], 3), dtype=np.float32) * 255
37 |     channels, height, width = masks.shape
38 |
39 |     for y in range(height):
40 |         for x in range(width):
41 |             selected_colors = colors[masks[:,y,x] > 0.5]
42 |
43 |             if len(selected_colors) > 0:
44 |                 colorimg[y,x,:] = np.mean(selected_colors, axis=0)
45 |
46 |     return colorimg.astype(np.uint8)
--------------------------------------------------------------------------------
/unet/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def dice_loss(pred, target, smooth = 1.):
5 |     pred = pred.contiguous()
6 |     target = target.contiguous()
7 |
8 |     intersection = (pred * target).sum(dim=2).sum(dim=2)
9 |
10 |     loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
11 |
12 |     return loss.mean()
13 |
--------------------------------------------------------------------------------
/unet/pytorch_DPCN.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.abspath(".."))
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | from phase_correlation.phase_corr import phase_corr
8 | from log_polar.log_polar import polar_transformer
9 | from utils.utils import *
10 |
11 |
12 |
13 | class LogPolar(nn.Module):
14 |     def __init__(self, out_size, device):
15 |         super(LogPolar, self).__init__()
16 |         self.out_size = out_size
17 |         self.device = device
18 |
19 |     def forward(self, input):
20 |         return polar_transformer(input, self.out_size, self.device)
21 |
22 |
23 | class PhaseCorr(nn.Module):
24 |     def __init__(self, device, logbase, modelc2s, trans=False):
25 |         super(PhaseCorr, self).__init__()
26 |         self.device = device
27 |         self.logbase = logbase
28 |         self.trans = trans
29 |         self.modelc2s = modelc2s
30 |
31 |     def forward(self, template, source):
32 |         return phase_corr(template, source, self.device, self.logbase, self.modelc2s, trans=self.trans)
33 |
34 | class FFT2(nn.Module):
35 |     def __init__(self, device):
36 |         super(FFT2, self).__init__()
37 |         self.device = device
38 |
39 |     def forward(self, input):
40 |         median_output = torch.rfft(input, 2, onesided=False)  # legacy API: present in the pinned torch 1.5.0, removed in torch >= 1.8
41 |         median_output_r = median_output[:, :, :, 0]
42 |         median_output_i = median_output[:, :, :, 1]
43 |         # print("median_output r", median_output_r)
44 |         # print("median_output i", median_output_i)
45 |         output = torch.sqrt(median_output_r ** 2 + median_output_i ** 2 + 1e-15)
46 |         # output = median_outputW_r
47 |         output = fftshift2d(output)
48 |         # h = logpolar_filter((output.shape[1],output.shape[2]), self.device)
49 |         # output = output.squeeze(0) * h
50 |         # output = output.unsqueeze(-1)
51 |         output = output.unsqueeze(-1)
52 |         return output
53 |
54 | def double_conv(in_channels, out_channels):
55 |     return nn.Sequential(
56 |         nn.Conv2d(in_channels, out_channels, 3, padding=1),
57 |         nn.BatchNorm2d(out_channels),
58 |         nn.ReLU(inplace=True),
59 |         nn.Conv2d(out_channels, out_channels, 3, padding=1),
60 |         nn.BatchNorm2d(out_channels),
61 |         nn.ReLU(inplace=True)
62 |     )
63 |
64 |
65 | class UNet(nn.Module):
66 |
67 |     def __init__(self, n_class):
68 |         super().__init__()
69 |
70 |         self.dconv_down1 = double_conv(1, 64)
71 |         self.dconv_down2 = double_conv(64, 128)
72 |         self.dconv_down3 = double_conv(128, 256)
73 |         self.dconv_down4 = double_conv(256, 512)
74 |
75 |         self.maxpool = nn.MaxPool2d(2)
76 |         self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
77 |
78 |
79 |         self.dconv_up3 = double_conv(256 + 512, 256)
80 |         self.dconv_up2 = double_conv(128 + 256, 128)
81 |         self.dconv_up1 = double_conv(128 + 64, 64)
82 |
83 |         self.conv_last = nn.Conv2d(64, n_class, 1)
84 |
85 |
86 |     def forward(self, x):
87 |         conv1 = self.dconv_down1(x)
88 |         x = self.maxpool(conv1)
89 |
90 |         conv2 = self.dconv_down2(x)
91 |         x = self.maxpool(conv2)
92 |
93 |         conv3 = self.dconv_down3(x)
94 |         x = self.maxpool(conv3)
95 |
96 |         x = self.dconv_down4(x)
97 |
98 |         x = self.upsample(x)
99 |         x = torch.cat([x, conv3], dim=1)
100 |
101 |         x = self.dconv_up3(x)
102 |         x = self.upsample(x)
103 |         x = torch.cat([x, conv2], dim=1)
104 |
105 |         x = self.dconv_up2(x)
106 |         x = self.upsample(x)
107 |         x = torch.cat([x, conv1], dim=1)
108 |
109 |         x = self.dconv_up1(x)
110 |
111 |         out = self.conv_last(x)
112 |
113 |         return out
114 |
115 | class Corr2Softmax(nn.Module):
116 |
117 |     def __init__(self, weight, bias):
118 |
119 |         super(Corr2Softmax, self).__init__()
120 |         softmax_w = torch.tensor((weight), requires_grad=True)
121 |         softmax_b = torch.tensor((bias), requires_grad=True)
122 |         self.softmax_w = torch.nn.Parameter(softmax_w)
123 |         self.softmax_b = torch.nn.Parameter(softmax_b)
124 |         self.register_parameter("softmax_w",self.softmax_w)
125 |         self.register_parameter("softmax_b",self.softmax_b)
126 |     def forward(self, x):
127 |         x1 = self.softmax_w*x + self.softmax_b
128 |         # print("w = ",self.softmax_w, "b = ",self.softmax_b)
129 |         # x1 = 1000. * x
130 |         return x1
131 |
132 |
133 |
134 |
135 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |
--------------------------------------------------------------------------------
/utils/detect_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import kornia
3 | import time
4 | import copy
5 | import shutil
6 | import numpy as np
7 | import torch.nn as nn
8 | from graphviz import Digraph
9 | from torch.optim import lr_scheduler
10 | from collections import defaultdict
11 | import torch.nn.functional as F
12 | from unet.loss import dice_loss
13 | import torch.optim as optim
14 | from data.dataset import *
15 | from unet.pytorch_DPCN import FFT2, UNet, LogPolar, PhaseCorr, Corr2Softmax
16 | from data.dataset_DPCN import *
17 | from tensorboardX import SummaryWriter
18 | from utils.utils import *
19 | def detect_rot_scale(template_rot, source_rot, model_template_rot, model_source_rot, model_corr2softmax_rot, device):
20 |     print(" ")
21 |     print(" DETECTING ROTATION AND SCALE")
22 |     print(" ")
23 |     template_unet_rot = model_template_rot(template_rot)
24 |     source_unet_rot = model_source_rot(source_rot)
25 |
26 |     # for tensorboard visualize
27 |     template_visual_rot = template_unet_rot
28 |     source_visual_rot = source_unet_rot
29 |
30 |     # print(np.shape(template_unet_rot))
31 |     # imshow(template_unet_rot)
32 |     # convert to [B,H,W,C]
33 |     template_unet_rot = template_unet_rot.permute(0,2,3,1)
34 |     source_unet_rot = source_unet_rot.permute(0,2,3,1)
35 |
36 |     template_unet_rot = template_unet_rot.squeeze(-1)
37 |     source_unet_rot = source_unet_rot.squeeze(-1)
38 |
39 |     fft_layer = FFT2(device)
40 |     template_fft = fft_layer(template_unet_rot)
41 |     source_fft = fft_layer(source_unet_rot) # [B,H,W,1]
42 |
43 |     h = logpolar_filter((source_fft.shape[1],source_fft.shape[2]), device)#highpass((source.shape[1],source.shape[2])) # [H,W]
44 |     template_fft = template_fft.squeeze(-1) * h
45 |     source_fft = source_fft.squeeze(-1) * h
46 |
47 |     template_fft = template_fft.unsqueeze(-1)
48 |     source_fft = source_fft.unsqueeze(-1)
49 |
50 |     # for tensorboard visualize
51 |     template_fft_visual = template_fft.permute(0,3,1,2)
52 |     source_fft_visual = source_fft.permute(0,3,1,2)
53 |
54 |     logpolar_layer = LogPolar((template_fft.shape[1], template_fft.shape[2]), device)
55 |     template_logpolar, logbase_rot = logpolar_layer(template_fft)
56 |     source_logpolar, logbase_rot = logpolar_layer(source_fft)
57 |
58 |     # for tensorboard visualize
59 |     template_logpolar_visual = template_logpolar.permute(0,3,1,2)
60 |     source_logpolar_visual = source_logpolar.permute(0,3,1,2)
61 |
62 |     template_logpolar = template_logpolar.squeeze(-1)
63 |     source_logpolar = source_logpolar.squeeze(-1)
64 |     phase_corr_layer_rs = PhaseCorr(device, logbase_rot, model_corr2softmax_rot)
65 |     rotation_cal, scale_cal, softmax_result_rot, corr_result_rot = phase_corr_layer_rs(template_logpolar, source_logpolar)
66 |
67 |
68 |
69 |     # use phasecorr result
70 |
71 |
72 |     print("rotation =", rotation_cal)
73 |     print("scale =", scale_cal)
74 |     # print("gt_angle ", gt_angle)
75 |
76 |     # # flatten the tensor:
77 |     # b_loss,h_loss,w_loss = groundTruth.shape
78 |     # groundTruth = groundTruth.reshape(b_loss,h_loss*w_loss)
79 |     # softmax_final = softmax_final.reshape(b_loss,h_loss*w_loss)
80 |
81 |
82 |     # set the loss function:
83 |     # compute_loss = torch.nn.KLDivLoss(reduction="sum").to(device)
84 |     # compute_loss = torch.nn.BCEWithLogitsLoss(reduction="sum").to(device)
85 |     compute_loss_rot = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
86 |     # compute_loss = torch.nn.MSELoss()
87 |     # compute_loss=torch.nn.L1Loss()
88 |
89 |     return rotation_cal, scale_cal
90 |
91 | def detect_translation(template_trans, source_trans, rotation, scale, model_template_trans, model_source_trans, model_corr2softmax_trans, device ):
92 |     print(" ")
93 |     print(" DETECTING TRANSLATION")
94 |     print(" ")
95 |
96 |
97 |     # for AGDataset
98 |     b, c, h, w = source_trans.shape
99 |     center = torch.ones(b,2).to(device)
100 |     center[:, 0] = h // 2
101 |     center[:, 1] = w // 2
102 |     angle_rot = torch.ones(b).to(device) * (-rotation.to(device))
103 |     scale_rot = torch.ones(b).to(device) * (1/scale.to(device))
104 |     rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
105 |     source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))
106 |     # imshow(template_trans[0,:,:])
107 |     # time.sleep(2)
108 |     # imshow(source_trans[0,:,:])
109 |     # time.sleep(2)
110 |
111 |     # imshow(template,"temp")
112 |     # imshow(source, "src")
113 |
114 |     template_unet_trans = model_template_trans(template_trans)
115 |     source_unet_trans = model_source_trans(source_trans)
116 |
117 |     # for tensorboard visualize
118 |     template_visual_trans = template_unet_trans
119 |     source_visual_trans = source_unet_trans
120 |
121 |     template_unet_trans = template_unet_trans.permute(0,2,3,1)
122 |     source_unet_trans = source_unet_trans.permute(0,2,3,1)
123 |
124 |     template_unet_trans = template_unet_trans.squeeze(-1)
125 |     source_unet_trans = source_unet_trans.squeeze(-1)
126 |
127 |     (b, h, w) = template_unet_trans.shape
128 |     logbase_trans = torch.tensor(1.)
129 |     phase_corr_layer_xy = PhaseCorr(device, logbase_trans, model_corr2softmax_trans)
130 |     t0, t1, softmax_result_trans, corr_result_trans = phase_corr_layer_xy(template_unet_trans.to(device), source_unet_trans.to(device))
131 |
132 |     # use phasecorr result
133 |
134 |     corr_final_trans = corr_result_trans.clone()
135 |     # corr_visual = corr_final_trans.unsqueeze(-1)
136 |     # corr_visual = corr_visual.permute(0,3,1,2)
137 |     corr_y = torch.sum(corr_final_trans.clone(), 2, keepdim=False)
138 |     # corr_2d = corr_final_trans.clone().reshape(b, h*w)
139 |     # corr_2d = model_corr2softmax(corr_2d)
140 |     corr_y = model_corr2softmax_trans(corr_y)
141 |     input_c = nn.functional.softmax(corr_y.clone(), dim=-1)
142 |     indices_c = np.linspace(0, 1, 256)
143 |     indices_c = torch.tensor(np.reshape(indices_c, (-1, 256))).to(device)
144 |     transformation_y = torch.sum((256 - 1) * input_c * indices_c, dim=-1)
145 |     # transformation_y = torch.argmax(corr_y, dim=-1)
146 |
147 |     corr_x = torch.sum(corr_final_trans.clone(), 1, keepdim=False)
148 |     # corr_final_trans = corr_final_trans.reshape(b, h*w)
149 |     corr_x = model_corr2softmax_trans(corr_x)
150 |     input_r = nn.functional.softmax(corr_x.clone(), dim=-1)
151 |     indices_r = np.linspace(0, 1, 256)
152 |     indices_r = torch.tensor(np.reshape(indices_r, (-1, 256))).to(device)
153 |     # transformation_x = torch.argmax(corr_x, dim=-1)
154 |     transformation_x = torch.sum((256 - 1) * input_r * indices_r, dim=-1)
155 |
156 |     print("trans x", transformation_x)
157 |     print("trans y", transformation_y)
158 |
159 |     trans_mat_affine = torch.Tensor([[[1.0,0.0,transformation_x-128.0],[0.0,1.0,transformation_y-128.0]]]).to(device)
160 |     template_trans = kornia.warp_affine(template_trans.to(device), trans_mat_affine, dsize=(h, w))
161 |     image_aligned = align_image(template_trans[0,:,:], source_trans[0,:,:])
162 |     # imshow(template_trans[0,:,:])
163 |     # time.sleep(2)
164 |     # imshow(source_trans[0,:,:])
165 |     # time.sleep(2)
166 |
167 |     return transformation_y, transformation_x, image_aligned, source_trans
168 |
169 |
170 |
171 |
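`detect_translation` turns each 1-D correlation profile into a sub-pixel position by taking a softmax and summing `(256 - 1) * p * linspace(0, 1, 256)`, which is the expected index under that distribution. Below is a minimal pure-Python sketch of that soft-argmax, assuming a scalar `weight` and `bias` that play the role of `Corr2Softmax`. The helper names `softmax` and `soft_argmax` are illustrative and not part of the repository.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def soft_argmax(scores, weight=1.0, bias=0.0):
    """Expected index under softmax(weight * scores + bias)."""
    probs = softmax([weight * s + bias for s in scores])
    # (n - 1) * linspace(0, 1, n)[i] equals i, so this matches the tensor expression above.
    return sum(i * p for i, p in enumerate(probs))


if __name__ == "__main__":
    profile = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]  # toy correlation peak at index 3
    for w in (1.0, 10.0, 200.0):
        print("weight = {:6.1f} -> position = {:.4f}".format(w, soft_argmax(profile, weight=w)))
```

A small `weight` leaves the estimate pulled toward the centre of the profile, while a large one approaches a hard argmax and stays differentiable. That may be why `trainDPCN.py` initialises the two `Corr2Softmax` modules with weights of 200 and 11.72 and then lets the optimizers tune them.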
--------------------------------------------------------------------------------
/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import kornia
3 | import time
4 | import copy
5 | import shutil
6 | import numpy as np
7 | import torch.nn as nn
8 | from graphviz import Digraph
9 | from torch.optim import lr_scheduler
10 | from collections import defaultdict
11 | import torch.nn.functional as F
12 | from unet.loss import dice_loss
13 | import torch.optim as optim
14 | from data.dataset import *
15 | from unet.pytorch_DPCN import FFT2, UNet, LogPolar, PhaseCorr, Corr2Softmax
16 | from data.dataset_DPCN import *
17 | from tensorboardX import SummaryWriter
18 | from utils.utils import *
19 | from validate import val_model
20 | def train_rot_scale(template_rot, source_rot, groundTruth_rot, groundTruth_scale, model_template_rot, model_source_rot, model_corr2softmax_rot, phase, device):
21 |     print(" ")
22 |     print(" TRAINING ROTATION AND SCALE")
23 |     print(" ")
24 |
25 |     with torch.set_grad_enabled(phase == 'train'):
26 |         template_unet_rot = model_template_rot(template_rot)
27 |         source_unet_rot = model_source_rot(source_rot)
28 |
29 |         # for tensorboard visualize
30 |         template_visual_rot = template_unet_rot
31 |         source_visual_rot = source_unet_rot
32 |
33 |         # print(np.shape(template_unet_rot))
34 |         # imshow(template_unet_rot)
35 |         # convert to [B,H,W,C]
36 |         template_unet_rot = template_unet_rot.permute(0,2,3,1)
37 |         source_unet_rot = source_unet_rot.permute(0,2,3,1)
38 |
39 |         template_unet_rot = template_unet_rot.squeeze(-1)
40 |         source_unet_rot = source_unet_rot.squeeze(-1)
41 |
42 |         fft_layer = FFT2(device)
43 |         template_fft = fft_layer(template_unet_rot)
44 |         source_fft = fft_layer(source_unet_rot) # [B,H,W,1]
45 |
46 |         h = logpolar_filter((source_fft.shape[1],source_fft.shape[2]), device)#highpass((source.shape[1],source.shape[2])) # [H,W]
47 |         template_fft = template_fft.squeeze(-1) * h
48 |         source_fft = source_fft.squeeze(-1) * h
49 |
50 |         template_fft = template_fft.unsqueeze(-1)
51 |         source_fft = source_fft.unsqueeze(-1)
52 |
53 |         # for tensorboard visualize
54 |         template_fft_visual = template_fft.permute(0,3,1,2)
55 |         source_fft_visual = source_fft.permute(0,3,1,2)
56 |
57 |         logpolar_layer = LogPolar((template_fft.shape[1], template_fft.shape[2]), device)
58 |         template_logpolar, logbase_rot = logpolar_layer(template_fft)
59 |         source_logpolar, logbase_rot = logpolar_layer(source_fft)
60 |
61 |         # for tensorboard visualize
62 |         template_logpolar_visual = template_logpolar.permute(0,3,1,2)
63 |         source_logpolar_visual = source_logpolar.permute(0,3,1,2)
64 |
65 |         template_logpolar = template_logpolar.squeeze(-1)
66 |         source_logpolar = source_logpolar.squeeze(-1)
67 |         source_logpolar.retain_grad()
68 |         phase_corr_layer_rs = PhaseCorr(device, logbase_rot, model_corr2softmax_rot)
69 |         rotation_cal, scale_cal, softmax_result_rot, corr_result_rot = phase_corr_layer_rs(template_logpolar, source_logpolar)
70 |         # print("logbaseeeeeeee", logbase_rot)
71 |         # use softmax
72 |
73 |         # softmax_final = softmax_result.clone()
74 |
75 |         # for batch_num in range(softmax_result.shape[0]):
76 |         #     lower, upper = softmax_result[batch_num].clone().chunk(2,0)
77 |         #     if success[batch_num] == 1:
78 |         #         softmax_final[batch_num] = torch.cat((upper,lower),0)
79 |         #     else:
80 |         #         softmax_final[batch_num] = softmax_result[batch_num].clone()
81 |
82 |         # softmax_final = torch.sum(softmax_final, 2, keepdim=False)
83 |
84 |
85 |
86 |         # use phasecorr result
87 |
88 |         corr_final_rot = corr_result_rot.clone()
89 |         corr_visual_rot = corr_final_rot.unsqueeze(-1)
90 |         corr_visual_rot = corr_visual_rot.permute(0,3,1,2)
91 |
92 |         corr_final_rot = torch.sum(corr_final_rot, 2, keepdim=False)
93 |
94 |         corr_final_rot = model_corr2softmax_rot(corr_final_rot)
95 |         input_rot = nn.functional.softmax(corr_final_rot.clone(), dim=-1)
96 |         indice_rot = np.linspace(0, 1, 256)
97 |         indice_rot = torch.tensor(np.reshape(indice_rot, (-1, 256))).to(device)
98 | rot_exp = torch.sum((256 - 1) * input_rot * indice_rot, dim=-1)
99 |
100 | # groundTruth = groundTruth.to(device)
101 |
102 |
103 |
104 |
105 | corr_final_scale = corr_result_rot.clone()
106 | corr_final_scale = torch.sum(corr_final_scale,1,keepdim=False)
107 | corr_final_scale = model_corr2softmax_rot(corr_final_scale)
108 |
109 | # consider angle and scale as the loss
110 |
111 | groundTruth_rot = groundTruth_rot.to(device)
112 | gt_number = groundTruth_rot.clone()
113 | gt_angle = GT_angle_convert(gt_number,256)
114 | gt_angle = gt_angle.to(device)
115 |
116 | groundTruth_scale = groundTruth_scale.to(device)
117 | gt_scale = GT_scale_convert(groundTruth_scale.clone(), logbase_rot, 256)
118 | gt_scale = gt_scale.to(device)
119 |
120 |
121 |
122 |
123 |
124 | ACC_rot = (1-(rotation_cal-groundTruth_rot).abs()/(groundTruth_rot+0.00000000000000001)).mean()
125 | if ACC_rot <= 0:
126 | ACC_rot = torch.Tensor([0.5])
127 | ACC_scale = (1-(scale_cal-groundTruth_scale).abs()/(groundTruth_scale+0.00000000000000001)).mean()
128 | if ACC_scale <= 0:
129 | ACC_scale = torch.Tensor([0.5])
130 |
131 |
132 |
133 |
134 | print("rotation =", rotation_cal)
135 |
136 | print("gt_rot =", groundTruth_rot, "\n")
137 | print("scale =", scale_cal)
138 | print("gt_scale =", groundTruth_scale, "\n")
139 |
140 | # print("gt_angle ", gt_angle)
141 | print("ACC_rot = ",ACC_rot.item()*100,"%")
142 | print("ACC_scale = ",ACC_scale.item()*100,"%")
143 |
144 |
145 | # # flatten the tensor:
146 | # b_loss,h_loss,w_loss = groundTruth.shape
147 | # groundTruth = groundTruth.reshape(b_loss,h_loss*w_loss)
148 | # softmax_final = softmax_final.reshape(b_loss,h_loss*w_loss)
149 |
150 |
151 | # set the loss function:
152 | # compute_loss = torch.nn.KLDivLoss(reduction="sum").to(device)
153 | # compute_loss = torch.nn.BCEWithLogitsLoss(reduction="sum").to(device)
154 | compute_loss_rot = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
155 | compute_loss_scale = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
156 | compute_mse = torch.nn.MSELoss()
157 | compute_l1=torch.nn.L1Loss().to(device)
158 | # compute_loss = torch.nn.MSELoss()
159 | # compute_loss=torch.nn.L1Loss()
160 |
161 | loss_rot = compute_loss_rot(corr_final_rot,gt_angle)
162 | loss_scale = compute_loss_scale(corr_final_scale,gt_scale)
163 | loss_l1_rot = compute_l1(rot_exp, groundTruth_rot)
164 | loss_l1_scale = compute_l1(scale_cal, groundTruth_scale)
165 | loss_mse_rot = compute_mse(rotation_cal, groundTruth_rot)
166 | loss_mse_scale = compute_mse(scale_cal, groundTruth_scale)
167 | print("loss rot ==", loss_rot)
168 | print("loss scale ==", loss_scale)
169 | return loss_rot, loss_scale, loss_l1_rot, loss_mse_rot, loss_l1_scale, loss_mse_scale, template_visual_rot, source_visual_rot
170 |
171 | def train_translation(template_trans, source_trans, groundTruth_number, scale_gt, gt_trans, model_template_trans, model_source_trans, model_corr2softmax_trans, phase, dsnt, device ):
172 | print(" ")
173 | print(" TRAINING TRANSLATION")
174 | print(" ")
175 | with torch.set_grad_enabled(phase == 'train'):
176 | # # for toy dataset
177 | # b, c, h, w = source_trans.shape
178 | # center = torch.ones(b,2).to(device)
179 | # center[:, 0] = h // 2
180 | # center[:, 1] = w // 2
181 | # angle_rot = torch.ones(b).to(device) * (-groundTruth_number.to(device))
182 | # scale_rot = torch.ones(b).to(device)
183 | # rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
184 | # source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))
185 |
186 | # for AGDataset
187 | b, c, h, w = source_trans.shape
188 | center = torch.ones(b,2).to(device)
189 | center[:, 0] = h // 2
190 | center[:, 1] = w // 2
191 | angle_rot = torch.ones(b).to(device) * (-groundTruth_number.to(device))
192 | scale_rot = torch.ones(b).to(device) / scale_gt.to(device)
193 | rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
194 | source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))
195 |
196 |
197 | # imshow(template,"temp")
198 | # imshow(source, "src")
199 |
200 | template_unet_trans = model_template_trans(template_trans)
201 | source_unet_trans = model_source_trans(source_trans)
202 |
203 | # for tensorboard visualize
204 | template_visual_trans = template_unet_trans
205 | source_visual_trans = source_unet_trans
206 |
207 | template_unet_trans = template_unet_trans.permute(0,2,3,1)
208 | source_unet_trans = source_unet_trans.permute(0,2,3,1)
209 |
210 | template_unet_trans = template_unet_trans.squeeze(-1)
211 | source_unet_trans = source_unet_trans.squeeze(-1)
212 |
213 | (b, h, w) = template_unet_trans.shape
214 | logbase_trans = torch.tensor(1.)
215 | phase_corr_layer_xy = PhaseCorr(device, logbase_trans, model_corr2softmax_trans, trans=True)
216 | t0, t1, softmax_result_trans, corr_result_trans = phase_corr_layer_xy(template_unet_trans.to(device), source_unet_trans.to(device))
217 |
218 | # use phasecorr result
219 | if not dsnt:
220 |
221 | corr_final_trans = corr_result_trans.clone()
222 | # corr_visual = corr_final_trans.unsqueeze(-1)
223 | # corr_visual = corr_visual.permute(0,3,1,2)
224 | corr_y = torch.sum(corr_final_trans.clone(), 2, keepdim=False)
225 | # corr_2d = corr_final_trans.clone().reshape(b, h*w)
226 | # corr_2d = model_corr2softmax(corr_2d)
227 | corr_y = model_corr2softmax_trans(corr_y)
228 | input_c = nn.functional.softmax(corr_y.clone(), dim=-1)
229 | indices_c = np.linspace(0, 1, 256)
230 | indices_c = torch.tensor(np.reshape(indices_c, (-1, 256))).to(device)
231 | tranformation_y = torch.sum((256 - 1) * input_c * indices_c, dim=-1)
232 | # tranformation_y = torch.argmax(corr_y, dim=-1)
233 |
234 | corr_x = torch.sum(corr_final_trans.clone(), 1, keepdim=False)
235 | # corr_final_trans = corr_final_trans.reshape(b, h*w)
236 | corr_x = model_corr2softmax_trans(corr_x)
237 | input_r = nn.functional.softmax(corr_x.clone(), dim=-1)
238 | indices_r = np.linspace(0, 1, 256)
239 | indices_r = torch.tensor(np.reshape(indices_r, (-1, 256))).to(device)
240 | # tranformation_x = torch.argmax(corr_x, dim=-1)
241 | tranformation_x = torch.sum((256 - 1) * input_r * indices_r, dim=-1)
242 |
243 | # only consider angle as the loss
244 | # softmax_result = torch.sum(corr_result.clone(), 2, keepdim=False)
245 | # softmax_final = softmax_result.clone()
246 | # # softmax_visual = softmax_final.unsqueeze(-1)
247 | # # softmax_visual = softmax_visual.permute(0,3,1,2)
248 |
249 | # softmax_final = softmax_final.reshape(b, h*w)
250 | # softmax_final = model_corr2softmax(softmax_final.clone())
251 | gt_trans_orig = gt_trans.clone().to(device)
252 |
253 | # print("err_true = ",err_true.item()*100,"%")
254 |
255 | gt_trans_convert = GT_trans_convert(gt_trans_orig, [256, 256])
256 | gt_trans_convert_y = gt_trans_convert[:,0]
257 | gt_trans_convert_x = gt_trans_convert[:,1]
258 |
259 | print("trans x", tranformation_x)
260 | print("gt_convert x", gt_trans_convert_x, "\n")
261 |
262 |
263 | print("trans y", tranformation_y)
264 | print("gt_convert y", gt_trans_convert_y,"\n")
265 |
266 |
267 |
268 | # set the loss function:
269 | # compute_loss = torch.nn.KLDivLoss(reduction="sum").to(device)
270 | # compute_loss = torch.nn.BCEWithLogitsLoss(reduction="sum").to(device)
271 | compute_loss_y = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
272 | compute_loss_x = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
273 | # mse_loss = torch.nn.MSELoss(reduce=True)
274 | # compute_loss = torch.nn.NLLLoss()
275 | # compute_l1loss_a = torch.nn.L1Loss()
276 | # compute_l1loss_x = torch.nn.L1Loss()
277 | compute_mse = torch.nn.MSELoss()
278 | compute_l1=torch.nn.L1Loss()
279 |
280 | # mse_loss = mse_loss(rotation_cal.float(), groundTruth_number.float())
281 | loss_l1_x = compute_l1(tranformation_x, gt_trans_convert_x)
282 | loss_l1_y = compute_l1(tranformation_y, gt_trans_convert_y)
283 | loss_mse_x = compute_mse(tranformation_x, gt_trans_convert_x)
284 | loss_mse_y = compute_mse(tranformation_y, gt_trans_convert_y)
285 | loss_x = compute_loss_x(corr_x, gt_trans_convert_x)
286 | loss_y = compute_loss_y(corr_y, gt_trans_convert_y)
287 | total_loss = loss_x + loss_y + loss_l1_x + loss_l1_y
288 | return loss_y, loss_x, total_loss, loss_l1_x,loss_l1_y,loss_mse_x, loss_mse_y, template_visual_trans, source_visual_trans
289 | else:
290 | corr_final_trans = corr_result_trans.clone()
291 | # corr_visual = corr_final_trans.unsqueeze(-1)
292 | # corr_visual = corr_visual.permute(0,3,1,2)
293 | corr_y = torch.sum(corr_final_trans.clone(), 2, keepdim=False)
294 | # corr_2d = corr_final_trans.clone().reshape(b, h*w)
295 | # corr_2d = model_corr2softmax(corr_2d)
296 | corr_y = model_corr2softmax_trans(corr_y)
297 |
298 | corr_x = torch.sum(corr_final_trans.clone(), 1, keepdim=False)
299 | # corr_final_trans = corr_final_trans.reshape(b, h*w)
300 | corr_x = model_corr2softmax_trans(corr_x)
301 |
302 | corr_mat_dsnt_trans = corr_result_trans.clone().unsqueeze(-1)
303 | corr_mat_dsnt_trans_final = model_corr2softmax_trans(corr_mat_dsnt_trans)
304 | corr_mat_dsnt_trans_final = kornia.spatial_softmax2d(corr_mat_dsnt_trans_final)
305 | coors_trans = kornia.spatial_expectation2d(corr_mat_dsnt_trans_final,False)
306 | tranformation_x = coors_trans[:,0,0]
307 | tranformation_y = coors_trans[:,0,1]
308 |
309 | # only consider angle as the loss
310 | # softmax_result = torch.sum(corr_result.clone(), 2, keepdim=False)
311 | # softmax_final = softmax_result.clone()
312 | # # softmax_visual = softmax_final.unsqueeze(-1)
313 | # # softmax_visual = softmax_visual.permute(0,3,1,2)
314 |
315 | # softmax_final = softmax_final.reshape(b, h*w)
316 | # softmax_final = model_corr2softmax(softmax_final.clone())
317 | gt_trans_orig = gt_trans.clone().to(device)
318 |
319 | # print("err_true = ",err_true.item()*100,"%")
320 |
321 | gt_trans_convert = GT_trans_convert(gt_trans_orig, [256, 256])
322 | gt_trans_convert_y = gt_trans_convert[:,0]
323 | gt_trans_convert_x = gt_trans_convert[:,1]
324 |
325 | print("trans x", tranformation_x)
326 | print("gt_convert x", gt_trans_convert_x, "\n")
327 |
328 |
329 | print("trans y", tranformation_y)
330 | print("gt_convert y", gt_trans_convert_y,"\n")
331 |
332 |
333 |
334 | # set the loss function:
335 | # compute_loss = torch.nn.KLDivLoss(reduction="sum").to(device)
336 | # compute_loss = torch.nn.BCEWithLogitsLoss(reduction="sum").to(device)
337 | compute_loss_y = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
338 | compute_loss_x = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
339 | # mse_loss = torch.nn.MSELoss(reduce=True)
340 | # compute_loss = torch.nn.NLLLoss()
341 | # compute_l1loss_a = torch.nn.L1Loss()
342 | # compute_l1loss_x = torch.nn.L1Loss()
343 | compute_mse = torch.nn.MSELoss()
344 | compute_l1=torch.nn.L1Loss()
345 |
346 | # mse_loss = mse_loss(rotation_cal.float(), groundTruth_number.float())
347 | loss_l1_x = compute_l1(tranformation_x, gt_trans_convert_x)
348 | loss_l1_y = compute_l1(tranformation_y, gt_trans_convert_y)
349 | loss_mse_x = compute_mse(tranformation_x, gt_trans_convert_x.type(torch.FloatTensor).to(device))
350 | loss_mse_y = compute_mse(tranformation_y, gt_trans_convert_y.type(torch.FloatTensor).to(device))
351 | loss_x = compute_loss_x(corr_x, gt_trans_convert_x)
352 | loss_y = compute_loss_y(corr_y, gt_trans_convert_y)
353 | total_loss = 0.001*(loss_mse_x + loss_mse_y)
354 | return loss_y, loss_x, total_loss, loss_l1_x,loss_l1_y,loss_mse_x, loss_mse_y, template_visual_trans, source_visual_trans
355 |
356 |
357 |
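Both training functions above turn a 1-D correlation profile into a sub-bin estimate by taking the softmax-weighted mean of the bin indices (`rot_exp`, `tranformation_x`, `tranformation_y`). The sketch below reproduces that computation with the standard library only, so it runs without PyTorch. The function name `soft_argmax_1d` and the sample values are illustrative assumptions and are not part of the repository.

```python
import math


def soft_argmax_1d(logits):
    """Return the softmax-weighted mean index of a 1-D sequence of logits."""
    if not logits:
        raise ValueError("logits must be non-empty")
    peak = max(logits)  # subtract the maximum for numerical stability
    weights = [math.exp(v - peak) for v in logits]
    total = sum(weights)
    # Expected index: sum_i i * softmax(logits)_i
    return sum(i * w for i, w in enumerate(weights)) / total


if __name__ == "__main__":
    flat = [0.0] * 5                 # uniform profile: estimate sits at the centre
    peaked = [0.0, 0.0, 50.0, 0.0]   # sharp peak: estimate sits at the peak index
    print(soft_argmax_1d(flat))
    print(soft_argmax_1d(peaked))
```

Unlike `argmax`, this estimate is differentiable with respect to the logits, which is presumably why the training code uses it for the L1 and MSE terms.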
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.abspath("../unet"))
4 | from torchvision import transforms, utils
5 | import matplotlib.pyplot as plt
6 | from collections import defaultdict
7 | import torch.nn.functional as F
8 | from unet.loss import dice_loss
9 | import torch
10 | import torch.optim as optim
11 | import torch.nn as nn
12 | from torch.autograd import Variable
13 | from torch.optim import lr_scheduler
14 | import time
15 | import copy
16 | import numpy as np
17 | import shutil
18 | import math
19 | from PIL import Image
20 | import kornia
21 | import cv2
22 | from scipy.ndimage import gaussian_filter  # required by heatmap_imshow
23 |
24 | def logpolar_filter(shape, device):
25 |     """
26 |     Make a radial cosine filter for the log-polar transform.
27 |     This filter suppresses low frequencies and completely removes
28 |     the zero frequency.
29 |     """
30 |     yy = np.linspace(-np.pi / 2., np.pi / 2., shape[0])[:, np.newaxis]
31 |     xx = np.linspace(-np.pi / 2., np.pi / 2., shape[1])[np.newaxis, :]
32 |     # Suppressing low spatial frequencies is a must when using the log-polar
33 |     # transform: scale changes are poorly reflected in the low frequencies.
34 |     rads = np.sqrt(yy ** 2 + xx ** 2)
35 |     filt = 1.0 - np.cos(rads) ** 2
36 |     # This doesn't really matter; very high frequencies are not too usable anyway.
37 |     filt[np.abs(rads) > np.pi / 2] = 1
38 |     filt = torch.from_numpy(filt).to(device)
39 |     return filt
40 |
41 | def roll_n(X, axis, n):
42 | # print("x")
43 | # print(X)
44 |
45 | f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim()))
46 | b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim()))
47 | front = X[f_idx]
48 | back = X[b_idx]
49 | return torch.cat([back, front], axis)
50 |
51 | def fftshift2d(x):
52 | for dim in range(1, len(x.size())):
53 | n_shift = x.size(dim)//2
54 | if x.size(dim) % 2 != 0:
55 | n_shift = n_shift + 1 # for odd-sized images
56 | x = roll_n(x, axis=dim, n=n_shift)
57 | return x # last dim=2 (real&imag)
58 |
59 | def batch_fftshift2d(x):
60 | real, imag = torch.unbind(x, -1)
61 | for dim in range(1, len(real.size())):
62 | n_shift = real.size(dim)//2
63 | if real.size(dim) % 2 != 0:
64 | n_shift = n_shift+1 # for odd-sized images
65 | real = roll_n(real, axis=dim, n=n_shift)
66 | imag = roll_n(imag, axis=dim, n=n_shift)
67 | return torch.stack((real, imag), -1) # last dim=2 (real&imag)
68 |
69 | def softargmax(input, device, beta=100):
70 |     *_, h, w = input.shape
71 |
72 |     # NOTE: `beta` is unused; the sharpening factor is hard-coded below.
73 |     input = input.squeeze(0)
74 |     input = input.reshape(1, h * w)
75 |     input = input * 6000
76 |     # softmax is the overflow-safe form of exp(input) / sum(exp(input))
77 |     weights = nn.functional.softmax(input, dim=-1).to(device)
78 |     result = torch.sum(weights * torch.arange(0, h * w).to(device))
79 |     # result is a flat row-major index: row = result // w, col = result % w
80 |     return result
81 |
82 | def softargmax2d(input, device, beta=10000):
83 | *_, h, w = input.shape
84 |
85 | input_orig = input.reshape(*_, h * w)
86 | # print(torch.max(input_orig))
87 | # print(torch.min(input_orig))
88 | beta_t = 100000. / torch.max(input_orig).to(device)
89 | # print(input_orig)
90 | # print(beta * input_orig)
91 | input_d = nn.functional.softmax(beta_t * input_orig, dim=-1)
92 | input_orig.retain_grad()
93 | # print(torch.argmax(input))
94 | # print(torch.max(input_d))
95 | # print(torch.min(input_d))
96 | # print(torch.sum(input))
97 |
98 | indices_c, indices_r = np.meshgrid(
99 | np.linspace(0, 1, w),
100 | np.linspace(0, 1, h),
101 | indexing='xy'
102 | )
103 |
104 | indices_r = torch.tensor(np.reshape(indices_r, (-1, h * w))).to(device)
105 | indices_c = torch.tensor(np.reshape(indices_c, (-1, h * w))).to(device)
106 |
107 | result_r = torch.sum((h - 1) * input_d * indices_r, dim=-1)
108 | result_c = torch.sum((w - 1) * input_d * indices_c, dim=-1)
109 |
110 | result = torch.stack([result_r, result_c], dim=-1)
111 | # result.sum().backward(retain_graph=True)
112 | # print(input_orig.grad)
113 |
114 | return result
115 |
116 | def softmax2d(input, device, beta=10000):
117 | *_, h, w = input.shape
118 |
119 | input_orig = input.reshape(*_, h * w)
120 | beta_t = 100. / torch.max(input_orig).to(device)  # NOTE: unused; a fixed factor of 1000 is applied below
121 | input_d = nn.functional.softmax(1000 * input_orig, dim=-1)
122 | soft_r = input_d.reshape(*_,h,w)
123 | # soft_r.retain_grad()
124 | # print("softmax grad =======", soft_r.grad)
125 | return soft_r
126 |
127 | def GT_angle_convert(this_gt,size):
128 |     for batch_num in range(this_gt.shape[0]):
129 |         if this_gt[batch_num] > 90:  # shift the angle (degrees) by 90; modifies this_gt in place
130 |             this_gt[batch_num] = this_gt[batch_num].clone() - 90
131 |         else:
132 |             this_gt[batch_num] = this_gt[batch_num].clone() + 90
133 |         this_gt[batch_num] = this_gt[batch_num].clone()*size/180  # 180 degrees span `size` bins
134 |         this_gt[batch_num] = this_gt[batch_num].clone()//1 + (this_gt[batch_num].clone()%1+0.5)//1  # round half up
135 |         if this_gt[batch_num].long() == size:  # fold the top edge into the last bin
136 |             this_gt[batch_num] = this_gt[batch_num] - 1
137 |     return this_gt.long()
138 | def GT_scale_convert(scale_gt,logbase,size):
139 |     for batch_num in range(scale_gt.shape[0]):
140 |         scale_gt[batch_num] = torch.log10(1/scale_gt[batch_num].clone())/torch.log10(logbase)+128.  # log_logbase(1/scale), centred on bin 128
141 |     return scale_gt.long()
142 |
143 | def GT_trans_convert(this_trans, size):
144 | this_trans = (this_trans.clone() + size[0] // 2)
145 | # gt_converted = this_trans[:,1] * size[0] + this_trans[:,0]
146 | # gt_converted = this_trans[:,0]
147 | gt_converted = this_trans
148 | # # create a gt for kldivloss
149 | # kldiv_gt = torch.zeros(this_trans.clone().shape[0],size[0],size[1])
150 | # gauss_blur = kornia.filters.GaussianBlur2d((5, 5), (5, 5))
151 | # for batch_num in range(this_trans.clone().shape[0]):
152 | # kldiv_gt[batch_num, this_trans.clone()[batch_num,0].long(), this_trans.clone()[batch_num,1].long()] = 1
153 |
154 | # kldiv_gt = torch.unsqueeze(kldiv_gt.clone(), dim = 0)
155 | # kldiv_gt = kldiv_gt.permute(1,0,2,3)
156 | # kldiv_gt = gauss_blur(kldiv_gt.clone())
157 | # kldiv_gt = kldiv_gt.permute(1,0,2,3)
158 | # kldiv_gt = torch.squeeze(kldiv_gt.clone(), dim = 0)
159 | # (b, h, w) = kldiv_gt.shape
160 | # kldiv_gt = kldiv_gt.clone().reshape(b, h*w)
161 | # # Create GT for Pooling data
162 | # gt_pooling = torch.floor(this_trans.clone()/4)
163 | return gt_converted.long()
164 |
165 | def calc_loss(pred, target, metrics, bce_weight=0.5):
166 | bce = F.binary_cross_entropy_with_logits(pred, target)
167 |
168 | pred = torch.sigmoid(pred)
169 | dice = dice_loss(pred, target)
170 |
171 | loss = bce * bce_weight + dice * (1 - bce_weight)
172 |
173 | metrics['bce'] = metrics['bce']+bce.data.cpu().numpy() * target.size(0)
174 | metrics['dice'] = metrics['dice']+dice.data.cpu().numpy() * target.size(0)
175 | metrics['loss'] = metrics['loss']+loss.data.cpu().numpy() * target.size(0)
176 |
177 | return loss
178 |
179 | def print_metrics(metrics, epoch_samples, phase):
180 |     outputs = []
181 |     for k in metrics.keys():
182 |         outputs.append("{}: {:.4f}".format(k, metrics[k] / epoch_samples))
183 |
184 |     print("{}: {}".format(phase, ", ".join(outputs)))
185 | def imshow(tensor, title=None):
186 |     image = tensor.cpu().detach().numpy()  # detach from the graph and convert to a NumPy array
187 |     image = image.squeeze(0)  # remove the fake batch dimension
188 |     plt.imshow(image, cmap="gray")
189 |     if title is not None:
190 |         plt.title(title)
191 |     plt.pause(0.001)  # pause a bit so that plots are updated
192 | def heatmap_imshow(tensor, title=None):
193 |     image = tensor.cpu().detach().numpy()  # detach from the graph and convert to a NumPy array
194 |     image = gaussian_filter(image, sigma=5, mode='nearest')  # smooth before plotting
195 |     plt.imshow(image, cmap="jet", interpolation="hamming")
196 |     plt.colorbar()
197 |     if title is not None:
198 |         plt.title(title)
199 |     plt.pause(0.001)  # pause a bit so that plots are updated
200 | def align_image(template, source):
201 | template = template.cpu().detach().numpy()
202 | source = source.cpu().detach().numpy()
203 | template = template.squeeze(0) # remove the fake batch dimension
204 | source = source.squeeze(0) # remove the fake batch dimension
205 | dst = cv2.addWeighted(template, 1, source, 0.6, 0)
206 | # plt.imshow(dst, cmap="gray")
207 | # plt.show() # pause a bit so that plots are updated
208 | return dst
209 |
210 | def plot_and_save_result(template, source, rotated, dst):
211 | template = template.cpu().detach().numpy()
212 | source = source.cpu().detach().numpy()
213 | template = template.squeeze(0) # remove the fake batch dimension
214 | source = source.squeeze(0) # remove the fake batch dimension
215 | rotated = rotated.cpu().detach().numpy()
216 | rotated = rotated.squeeze(0)
217 |
218 |
219 | result = plt.figure()
220 | result_t = result.add_subplot(1,4,1)
221 | result_t.set_title("Template")
222 | result_t.imshow(template, cmap="gray").axes.get_xaxis().set_visible(False)
223 | result_t.imshow(template, cmap="gray").axes.get_yaxis().set_visible(False)
224 |
225 | result_s = result.add_subplot(1,4,2)
226 | result_s.set_title("Source")
227 | result_s.imshow(source, cmap="gray").axes.get_xaxis().set_visible(False)
228 | result_s.imshow(source, cmap="gray").axes.get_yaxis().set_visible(False)
229 |
230 | result_r = result.add_subplot(1,4,3)
231 | result_r.set_title("Rotated Source")
232 | result_r.imshow(rotated, cmap="gray").axes.get_xaxis().set_visible(False)
233 | result_r.imshow(rotated, cmap="gray").axes.get_yaxis().set_visible(False)
234 |
235 | result_d = result.add_subplot(1,4,4)
236 | result_d.set_title("Destination")
237 | result_d.imshow(dst, cmap="gray").axes.get_xaxis().set_visible(False)
238 | result_d.imshow(dst, cmap="gray").axes.get_yaxis().set_visible(False)
239 | plt.savefig("Result.png")
240 | plt.show()
241 |
242 |
243 |
244 | def save_checkpoint(state, is_best, checkpoint_dir):
245 |     file_path = os.path.join(checkpoint_dir, 'checkpoint.pt')
246 |     torch.save(state, file_path)
247 |     if is_best:
248 |         best_fpath = os.path.join(checkpoint_dir, 'best_model.pt')
249 |         shutil.copyfile(file_path, best_fpath)
250 |
251 |
252 | def load_checkpoint(checkpoint_fpath, model_template, model_source, model_c2s, model_trans_template, model_trans_source, model_trans_c2s,\
253 | optimizer_temp, optimizer_src, optimizer_c2s, optimizer_trans_temp, optimizer_trans_src, optimizer_trans_c2s, device):
254 |
255 | if (device == torch.device('cpu')):
256 | print("using cpu")
257 | checkpoint = torch.load(checkpoint_fpath, map_location=torch.device('cpu'))
258 | else:
259 | checkpoint = torch.load(checkpoint_fpath, map_location=device)
260 |
261 |
262 | model_template.load_state_dict(checkpoint['state_dict_temp'])
263 | model_source.load_state_dict(checkpoint['state_dict_src'])
264 | model_c2s.load_state_dict(checkpoint['state_dict_c2s'])
265 | optimizer_temp.load_state_dict(checkpoint['optimizer_temp'])
266 | optimizer_src.load_state_dict(checkpoint['optimizer_src'])
267 | optimizer_c2s.load_state_dict(checkpoint['optimizer_c2s'])
268 |
269 | model_trans_template.load_state_dict(checkpoint['state_dict_trans_temp'])
270 | model_trans_source.load_state_dict(checkpoint['state_dict_trans_src'])
271 | model_trans_c2s.load_state_dict(checkpoint['state_dict_trans_c2s'])
272 | optimizer_trans_temp.load_state_dict(checkpoint['optimizer_trans_temp'])
273 | optimizer_trans_src.load_state_dict(checkpoint['optimizer_trans_src'])
274 | optimizer_trans_c2s.load_state_dict(checkpoint['optimizer_trans_c2s'])
275 |
276 | return model_template, model_source, model_c2s, model_trans_template, model_trans_source, model_trans_c2s,\
277 | optimizer_temp, optimizer_src, optimizer_c2s, optimizer_trans_temp, optimizer_trans_src, optimizer_trans_c2s, \
278 | checkpoint['epoch']
279 |
280 | def load_trans_checkpoint(checkpoint_fpath, model_trans_template, model_trans_source,\
281 | device):
282 |
283 | if (device == torch.device('cpu')):
284 | print("using cpu")
285 | checkpoint = torch.load(checkpoint_fpath, map_location=torch.device('cpu'))
286 | else:
287 | checkpoint = torch.load(checkpoint_fpath)
288 |
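289 | model_trans_template.load_state_dict(checkpoint['state_dict_t_src'])  # NOTE: template model reads the '_src' key; confirm against the code that saves the checkpoint
290 | model_trans_source.load_state_dict(checkpoint['state_dict_t_temp'])  # NOTE: source model reads the '_temp' key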
291 |
292 | return model_trans_template, model_trans_source,\
293 | checkpoint['epoch']
294 |
295 | def load_rot_checkpoint(checkpoint_rpath, model_template, model_source, model_c2s,\
296 | optimizer_temp, optimizer_src, optimizer_c2s, device):
297 |
298 | if (device == torch.device('cpu')):
299 | print("using cpu")
300 | checkpoint = torch.load(checkpoint_rpath, map_location=torch.device('cpu'))
301 | else:
302 | checkpoint = torch.load(checkpoint_rpath)
303 | model_template.load_state_dict(checkpoint['state_dict_temp'])
304 | model_source.load_state_dict(checkpoint['state_dict_src'])
305 | model_c2s.load_state_dict(checkpoint['state_dict_c2s'])
306 | optimizer_temp.load_state_dict(checkpoint['optimizer_temp'])
307 | optimizer_src.load_state_dict(checkpoint['optimizer_src'])
308 | optimizer_c2s.load_state_dict(checkpoint['optimizer_c2s'])
309 |
310 | return model_template, model_source, model_c2s, optimizer_temp, optimizer_src, optimizer_c2s
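`GT_angle_convert` and `GT_scale_convert` map a ground-truth rotation (in degrees) and a scale factor to the class index used as the cross-entropy target. A scalar, standard-library sketch of the same arithmetic follows. The names `angle_to_bin` and `scale_to_bin` are hypothetical, and the log base 1.02 in the usage lines is an arbitrary illustrative value; in the repository the base is returned by the `LogPolar` layer.

```python
import math


def angle_to_bin(angle_deg, size=256):
    """Map an angle in degrees to a bin index in [0, size - 1]."""
    # Shift by 90 degrees, mirroring the branch in GT_angle_convert.
    shifted = angle_deg - 90 if angle_deg > 90 else angle_deg + 90
    # Spread 180 degrees over `size` bins and round half up.
    index = math.floor(shifted * size / 180 + 0.5)
    # The top edge (index == size) is folded into the last bin.
    return size - 1 if index == size else index


def scale_to_bin(scale, logbase, size=256):
    """Map a scale factor to a bin index centred on size // 2."""
    # log_logbase(1 / scale), offset so that scale == 1 lands on the centre bin.
    # Assumption: the centre is size // 2; the original hard-codes 128.
    return int(math.log(1.0 / scale) / math.log(logbase) + size // 2)


if __name__ == "__main__":
    print(angle_to_bin(0.0), angle_to_bin(45.0), angle_to_bin(90.0))
    print(scale_to_bin(1.0, 1.02), scale_to_bin(0.5, 1.02))
```

An identity transform (0 degrees, scale 1) therefore maps to the centre bin on both axes, and scale factors below 1 land above the centre.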
--------------------------------------------------------------------------------
/utils/validate_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import kornia
3 | import time
4 | import copy
5 | import shutil
6 | import numpy as np
7 | import torch.nn as nn
8 | from graphviz import Digraph
9 | from torch.optim import lr_scheduler
10 | from collections import defaultdict
11 | import torch.nn.functional as F
12 | from unet.loss import dice_loss
13 | import torch.optim as optim
14 | from data.dataset import *
15 | from unet.pytorch_DPCN import FFT2, UNet, LogPolar, PhaseCorr, Corr2Softmax
16 | from data.dataset_DPCN import *
17 | from tensorboardX import SummaryWriter
18 | from utils.utils import *
19 | import matplotlib.pyplot as plt
20 |
21 |
22 | def validate_rot_scale(template_rot, source_rot, groundTruth_rot, groundTruth_scale, model_template_rot, model_source_rot, model_corr2softmax_rot, device):
23 | print(" ")
24 | print(" VALIDATING ROTATION AND SCALE")
25 | print(" ")
26 | # imshow(template_rot[0,:,:,:])
27 | # plt.show()
28 | # imshow(source_rot[0,:,:,:])
29 | # plt.show()
30 | template_unet_rot = model_template_rot(template_rot)
31 | source_unet_rot = model_source_rot(source_rot)
32 | # imshow(template_unet_rot[0,:,:,:])
33 | # plt.show()
34 | # imshow(source_unet_rot[0,:,:,:])
35 | # plt.show()
36 | # for tensorboard visualize
37 | template_visual_rot = template_unet_rot
38 | source_visual_rot = source_unet_rot
39 |
40 | # print(np.shape(template_unet_rot))
41 | # imshow(template_unet_rot)
42 | # convert to [B,H,W,C]
43 | template_unet_rot = template_unet_rot.permute(0,2,3,1)
44 | source_unet_rot = source_unet_rot.permute(0,2,3,1)
45 |
46 | template_unet_rot = template_unet_rot.squeeze(-1)
47 | source_unet_rot = source_unet_rot.squeeze(-1)
48 |
49 | fft_layer = FFT2(device)
50 | template_fft = fft_layer(template_unet_rot)
51 | source_fft = fft_layer(source_unet_rot) # [B,H,W,1]
52 |
53 | h = logpolar_filter((source_fft.shape[1],source_fft.shape[2]), device)#highpass((source.shape[1],source.shape[2])) # [H,W]
54 | template_fft = template_fft.squeeze(-1) * h
55 | source_fft = source_fft.squeeze(-1) * h
56 |
57 | template_fft = template_fft.unsqueeze(-1)
58 | source_fft = source_fft.unsqueeze(-1)
59 |
60 | # for tensorboard visualize
61 | template_fft_visual = template_fft.permute(0,3,1,2)
62 | source_fft_visual = source_fft.permute(0,3,1,2)
63 |
64 | logpolar_layer = LogPolar((template_fft.shape[1], template_fft.shape[2]), device)
65 | template_logpolar, logbase_rot = logpolar_layer(template_fft)
66 | source_logpolar, logbase_rot = logpolar_layer(source_fft)
67 |
68 | # for tensorboard visualize
69 | template_logpolar_visual = template_logpolar.permute(0,3,1,2)
70 | source_logpolar_visual = source_logpolar.permute(0,3,1,2)
71 |
72 | template_logpolar = template_logpolar.squeeze(-1)
73 | source_logpolar = source_logpolar.squeeze(-1)
74 | print(template_logpolar_visual.shape)
75 | print(source_logpolar_visual.shape)
76 | # imshow(template_logpolar_visual.int()[0,:,:,:])
77 | # plt.show()
78 | # imshow(source_logpolar_visual.int()[0,:,:,:])
79 | # plt.show()
80 | # source_logpolar.retain_grad()
81 | phase_corr_layer_rs = PhaseCorr(device, logbase_rot, model_corr2softmax_rot)
82 | rotation_cal, scale_cal, softmax_result_rot, corr_result_rot = phase_corr_layer_rs(template_logpolar, source_logpolar)
83 |
84 | # use softmax
85 |
86 | # softmax_final = softmax_result.clone()
87 |
88 | # for batch_num in range(softmax_result.shape[0]):
89 | # lower, upper = softmax_result[batch_num].clone().chunk(2,0)
90 | # if success[batch_num] == 1:
91 | # softmax_final[batch_num] = torch.cat((upper,lower),0)
92 | # else:
93 | # softmax_final[batch_num] = softmax_result[batch_num].clone()
94 |
95 | # softmax_final = torch.sum(softmax_final, 2, keepdim=False)
96 |
97 |
98 |
99 | # use phasecorr result
100 |
101 | corr_final_rot = corr_result_rot.clone()
102 | corr_visual_rot = corr_final_rot.unsqueeze(-1)
103 | corr_visual_rot = corr_visual_rot.permute(0,3,1,2)
104 |
105 | corr_final_rot = torch.sum(corr_final_rot, 2, keepdim=False)
106 |
107 | corr_final_rot = model_corr2softmax_rot(corr_final_rot)
108 |
109 | corr_final_scale = corr_result_rot.clone()
110 | corr_final_scale = torch.sum(corr_final_scale,1,keepdim=False)
111 | corr_final_scale = model_corr2softmax_rot(corr_final_scale)
112 |
113 | # groundTruth = groundTruth.to(device)
114 |
115 |
116 |
117 | # consider angle and scale as the loss
118 |
119 | # groundTruth = torch.sum(groundTruth, 2, keepdim=False)
120 |
121 | groundTruth_rot = groundTruth_rot.to(device)
122 | gt_number = groundTruth_rot.clone()
123 | gt_angle = GT_angle_convert(gt_number,256)
124 | gt_angle = gt_angle.to(device)
125 |
126 | groundTruth_scale = groundTruth_scale.to(device)
127 | gt_scale = GT_scale_convert(groundTruth_scale.clone(), logbase_rot, 256)
128 | gt_scale = gt_scale.to(device)
129 |
130 |
131 | # ACC_rot = (1-(rotation_cal-groundTruth_rot).abs()/(groundTruth_rot+0.00000000000000001)).mean()
132 | # if ACC_rot <= 0:
133 | # ACC_rot = torch.Tensor([0.5])
134 | # ACC_scale = (1-(scale_cal-groundTruth_scale).abs()/(groundTruth_scale+0.00000000000000001)).mean()
135 | # if ACC_scale <= 0:
136 | # ACC_scale = torch.Tensor([0.5])
137 |
138 | print("rotation =", rotation_cal)
139 | print("rotation_gt =", groundTruth_rot, "\n")
140 | print("scale =", scale_cal)
141 | print("scale_gt =", groundTruth_scale,"\n")
142 | # print("gt_angle ", gt_angle)
143 | # print("ACC_rot = ",ACC_rot.item()*100,"%")
144 | # print("ACC_scale = ",ACC_scale.item()*100,"%")
145 |
146 |
147 | # # flatten the tensor:
148 | # b_loss,h_loss,w_loss = groundTruth.shape
149 | # groundTruth = groundTruth.reshape(b_loss,h_loss*w_loss)
150 | # softmax_final = softmax_final.reshape(b_loss,h_loss*w_loss)
151 |
152 |
153 | # set the loss function:
154 | # compute_loss = torch.nn.KLDivLoss(reduction="sum").to(device)
155 | # compute_loss = torch.nn.BCEWithLogitsLoss(reduction="sum").to(device)
156 | compute_loss_rot = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
157 | compute_loss_scale = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
158 | compute_mse = torch.nn.MSELoss()
159 | compute_l1 = torch.nn.L1Loss()
160 |
161 | loss_rot = compute_loss_rot(corr_final_rot,gt_angle)
162 | loss_scale = compute_loss_scale(corr_final_scale,gt_scale)
163 | loss_l1_rot = compute_l1(rotation_cal, groundTruth_rot)
164 | loss_l1_scale = compute_l1(scale_cal, groundTruth_scale)
165 | loss_mse_rot = compute_mse(rotation_cal, groundTruth_rot)
166 | loss_mse_scale = compute_mse(scale_cal, groundTruth_scale)
167 |
168 | print("Rotation L1", loss_l1_rot)
169 | print("Rotation mse", loss_mse_rot,"\n")
170 | print("Scale L1", loss_l1_scale)
171 | print("Scale mse", loss_mse_scale,"\n")
172 |
173 | print("loss rot ==", loss_rot)
174 | return loss_rot, loss_scale, scale_cal, loss_l1_rot, loss_mse_rot, loss_l1_scale, loss_mse_scale
175 |
176 | def validate_translation(template_trans, source_trans, groundTruth_number, scale_gt, gt_trans, model_template_trans, model_source_trans, model_corr2softmax_trans, acc_x, acc_y, dsnt, device):
177 | print(" ")
178 | print(" VALIDATING TRANSLATION")
179 | print(" ")
180 | # # for toy dataset
181 | # b, c, h, w = source_trans.shape
182 | # center = torch.ones(b,2).to(device)
183 | # center[:, 0] = h // 2
184 | # center[:, 1] = w // 2
185 | # angle_rot = torch.ones(b).to(device) * (-groundTruth_number.to(device))
186 | # scale_rot = torch.ones(b).to(device)
187 | # rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
188 | # source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))
189 |
190 | # for AGDataset
191 | since = time.time()
192 | b, c, h, w = source_trans.shape
193 | center = torch.ones(b,2).to(device)
194 | center[:, 0] = h // 2
195 | center[:, 1] = w // 2
196 | angle_rot = torch.ones(b).to(device) * (-groundTruth_number.to(device))
197 | scale_rot = torch.ones(b).to(device) / scale_gt.to(device)
198 | # scale_rot = torch.ones(b).to(device) / torch.Tensor([1.1]).to(device)
199 | rot_mat = kornia.get_rotation_matrix2d(center, angle_rot, scale_rot)
200 | source_trans = kornia.warp_affine(source_trans.to(device), rot_mat, dsize=(h, w))
201 | # imshow(template_trans[0,:,:,:])
202 | # plt.show()
203 | # imshow(source_trans[0,:,:,:])
204 | # plt.show()
205 |
206 | # imshow(template,"temp")
207 | # imshow(source, "src")
208 |
209 | template_unet_trans = model_template_trans(template_trans)
210 | source_unet_trans = model_source_trans(source_trans)
211 | # imshow(template_unet_trans[0,:,:,:])
212 | # plt.show()
213 | # imshow(source_unet_trans[0,:,:,:])
214 | # plt.show()
215 |
216 | # for tensorboard visualize
217 | template_visual_trans = template_unet_trans
218 | source_visual_trans = source_unet_trans
219 |
220 | template_unet_trans = template_unet_trans.permute(0,2,3,1)
221 | source_unet_trans = source_unet_trans.permute(0,2,3,1)
222 |
223 | template_unet_trans = template_unet_trans.squeeze(-1)
224 | source_unet_trans = source_unet_trans.squeeze(-1)
225 |
226 | (b, h, w) = template_unet_trans.shape
227 | logbase_trans = torch.tensor(1.)
228 | phase_corr_layer_xy = PhaseCorr(device, logbase_trans, model_corr2softmax_trans, trans=True)
229 | t0, t1, softmax_result_trans, corr_result_trans = phase_corr_layer_xy(template_unet_trans.to(device), source_unet_trans.to(device))
230 |
231 |
232 | if not dsnt:
233 | # use phasecorr result
234 |
235 | corr_final_trans = corr_result_trans.clone()
236 | # corr_visual = corr_final_trans.unsqueeze(-1)
237 | # corr_visual = corr_visual.permute(0,3,1,2)
238 | corr_y = torch.sum(corr_final_trans.clone(), 2, keepdim=False)
239 | # corr_2d = corr_final_trans.clone().reshape(b, h*w)
240 | # corr_2d = model_corr2softmax(corr_2d)
241 | corr_y = model_corr2softmax_trans(corr_y)
242 | input_c = nn.functional.softmax(corr_y.clone(), dim=-1)
243 | indices_c = np.linspace(0, 1, 256)
244 | indices_c = torch.tensor(np.reshape(indices_c, (-1, 256))).to(device)
245 | tranformation_y = torch.sum((256 - 1) * input_c * indices_c, dim=-1)
246 | tranformation_y_show = torch.argmax(corr_y, dim=-1)
247 |
248 | corr_x = torch.sum(corr_final_trans.clone(), 1, keepdim=False)
249 | # corr_final_trans = corr_final_trans.reshape(b, h*w)
250 | corr_x = model_corr2softmax_trans(corr_x)
251 | input_r = nn.functional.softmax(corr_x.clone(), dim=-1)
252 | indices_r = np.linspace(0, 1, 256)
253 | indices_r = torch.tensor(np.reshape(indices_r, (-1, 256))).to(device)
254 | tranformation_x_show = torch.argmax(corr_x, dim=-1)
255 | tranformation_x = torch.sum((256 - 1) * input_r * indices_r, dim=-1)
256 | time_elapsed = time.time() - since
257 | print("time elapsed", time_elapsed)
258 | print('in val time {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
259 |
260 | # only consider angle as the loss
261 | # softmax_result = torch.sum(corr_result.clone(), 2, keepdim=False)
262 | # softmax_final = softmax_result.clone()
263 | # # softmax_visual = softmax_final.unsqueeze(-1)
264 | # # softmax_visual = softmax_visual.permute(0,3,1,2)
265 |
266 | # softmax_final = softmax_final.reshape(b, h*w)
267 | # softmax_final = model_corr2softmax(softmax_final.clone())
268 | gt_trans_orig = gt_trans.clone().to(device)
269 |
270 | # err_true = (1-((t0-gt_trans_orig).abs()/(gt_trans_orig+1e-15))).mean()
271 | # if err_true <= 0:
272 | # err_true = torch.Tensor([0.5])
273 | # print("err_true = ",err_true.item()*100,"%")
274 |
275 | gt_trans_convert = GT_trans_convert(gt_trans_orig, [256, 256])
276 | gt_trans_convert_y = gt_trans_convert[:,0]
277 | gt_trans_convert_x = gt_trans_convert[:,1]
278 |
279 | print("trans x", tranformation_x_show)
280 | print("trans y", tranformation_y_show)
281 | print("gt_convert x", gt_trans_convert_x)
282 | print("gt_convert y", gt_trans_convert_y)
283 | # err_y = (1-(abs(tranformation_y_show-gt_trans_convert_y))/(gt_trans_convert_y+1e-10)).mean()
284 | # if err_y <= 0:
285 | # err_y = torch.Tensor([0.5])
286 | # print("err y", err_y.item()*100,"%")
287 |
288 |
289 | # err_x = (1-(abs(tranformation_x_show-gt_trans_convert_x))/(gt_trans_convert_x+1e-10)).mean()
290 | # if err_x <= 0:
291 | # err_x = torch.Tensor([0.5])
292 | # print("err x", err_x.item()*100,"%")
293 |
294 | arg_x = []
295 | corr_top5 = corr_x.clone().detach().cpu()
296 | for i in range(1):
297 | max_ind = torch.argmax(corr_top5, dim=-1)
298 | arg_x.append(max_ind)
299 | for batch_num in range(b):
300 | corr_top5[batch_num, max_ind[batch_num]] = 0.
301 |
302 |
303 | for batch_num in range(b):
304 | min_x = 100.
305 | for i in range(1):
306 | if abs(gt_trans_convert_x[batch_num].float() - torch.round(arg_x[i][batch_num].float())) < min_x:
307 | min_x = abs(gt_trans_convert_x[batch_num].float() - torch.round(arg_x[i][batch_num].float()))
308 |
309 | for class_id in range(20):
310 | for batch_num in range(b):
311 | if min_x <= class_id:
312 | acc_x[class_id] += 1
313 |
314 |
315 | arg_y = []
316 | corr_top5 = corr_y.clone().detach().cpu()
317 | for i in range(1):
318 | max_ind = torch.argmax(corr_top5, dim=-1)
319 | arg_y.append(max_ind)
320 | for batch_num in range(b):
321 | corr_top5[batch_num, max_ind[batch_num]] = 0.
322 |
323 |
324 | for batch_num in range(b):
325 | min_y = 100.
326 | for i in range(1):
327 | if abs(gt_trans_convert_y[batch_num] - torch.round(arg_y[i][batch_num].float())) < min_y:
328 | min_y = abs(gt_trans_convert_y[batch_num] - torch.round(arg_y[i][batch_num].float()))
329 |
330 | for class_id in range(20):
331 | for batch_num in range(b):
332 | if min_y <= class_id:
333 | acc_y[class_id] += 1
334 |
335 |
336 | # set the loss function:
337 | # compute_loss = torch.nn.KLDivLoss(reduction="sum").to(device)
338 | # compute_loss = torch.nn.BCEWithLogitsLoss(reduction="sum").to(device)
339 | compute_loss_y = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
340 | compute_loss_x = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
341 | # mse_loss = torch.nn.MSELoss(reduce=True)
342 | # compute_loss = torch.nn.NLLLoss()
343 | # compute_l1loss_a = torch.nn.L1Loss()
344 | # compute_l1loss_x = torch.nn.L1Loss()
345 | # compute_l1loss_y = torch.nn.L1Loss()
346 | compute_mse = torch.nn.MSELoss()
347 | compute_l1 = torch.nn.L1Loss()
348 |
349 | # mse_loss = mse_loss(rotation_cal.float(), groundTruth_number.float())
350 | # l1loss = compute_l1loss(rotation_cal, groundTruth_number)
351 | loss_l1_x = compute_l1(tranformation_x, gt_trans_convert_x)
352 | loss_l1_y = compute_l1(tranformation_y, gt_trans_convert_y)
353 | loss_mse_x = compute_mse(tranformation_x, gt_trans_convert_x)
354 | loss_mse_y = compute_mse(tranformation_y, gt_trans_convert_y)
355 | loss_x = compute_loss_x(corr_x, gt_trans_convert_x)
356 | loss_y = compute_loss_y(corr_y, gt_trans_convert_y)
357 | total_loss = loss_x + loss_y + loss_l1_x + loss_l1_y
358 | return loss_y, loss_x, total_loss, loss_l1_x, loss_l1_y, loss_mse_x, loss_mse_y
359 |
360 | else:
361 | corr_final_trans = corr_result_trans.clone()
362 | # corr_visual = corr_final_trans.unsqueeze(-1)
363 | # corr_visual = corr_visual.permute(0,3,1,2)
364 | corr_y = torch.sum(corr_final_trans.clone(), 2, keepdim=False)
365 | # corr_2d = corr_final_trans.clone().reshape(b, h*w)
366 | # corr_2d = model_corr2softmax(corr_2d)
367 | corr_y = model_corr2softmax_trans(corr_y)
368 |
369 | corr_x = torch.sum(corr_final_trans.clone(), 1, keepdim=False)
370 | corr_mat_dsnt_trans = corr_result_trans.clone().unsqueeze(-1)
371 | corr_mat_dsnt_trans_final = model_corr2softmax_trans(corr_mat_dsnt_trans)
372 | corr_mat_dsnt_trans_final = kornia.spatial_softmax2d(corr_mat_dsnt_trans_final)
373 | coors_trans = kornia.spatial_expectation2d(corr_mat_dsnt_trans_final,False)
374 | tranformation_x = coors_trans[:,0,0]
375 | tranformation_y = coors_trans[:,0,1]
376 |
377 | # only consider angle as the loss
378 | # softmax_result = torch.sum(corr_result.clone(), 2, keepdim=False)
379 | # softmax_final = softmax_result.clone()
380 | # # softmax_visual = softmax_final.unsqueeze(-1)
381 | # # softmax_visual = softmax_visual.permute(0,3,1,2)
382 |
383 | # softmax_final = softmax_final.reshape(b, h*w)
384 | # softmax_final = model_corr2softmax(softmax_final.clone())
385 | gt_trans_orig = gt_trans.clone().to(device)
386 |
387 | # err_true = (1-((t0-gt_trans_orig).abs()/(gt_trans_orig+1e-15))).mean()
388 | # if err_true <= 0:
389 | # err_true = torch.Tensor([0.5])
390 | # print("err_true = ",err_true.item()*100,"%")
391 |
392 | gt_trans_convert = GT_trans_convert(gt_trans_orig, [256, 256])
393 | gt_trans_convert_y = gt_trans_convert[:,0]
394 | gt_trans_convert_x = gt_trans_convert[:,1]
395 |
396 | print("trans x", tranformation_x)
397 | print("trans y", tranformation_y)
398 | print("gt_convert x", gt_trans_convert_x)
399 | print("gt_convert y", gt_trans_convert_y)
400 | # err_y = (1-(abs(tranformation_y_show-gt_trans_convert_y))/(gt_trans_convert_y+1e-10)).mean()
401 | # if err_y <= 0:
402 | # err_y = torch.Tensor([0.5])
403 | # print("err y", err_y.item()*100,"%")
404 |
405 |
406 | # err_x = (1-(abs(tranformation_x_show-gt_trans_convert_x))/(gt_trans_convert_x+1e-10)).mean()
407 | # if err_x <= 0:
408 | # err_x = torch.Tensor([0.5])
409 | # print("err x", err_x.item()*100,"%")
410 |
411 | arg_x = []
412 | corr_top5 = corr_x.clone().detach().cpu()
413 | for i in range(1):
414 | max_ind = torch.argmax(corr_top5, dim=-1)
415 | arg_x.append(max_ind)
416 | for batch_num in range(b):
417 | corr_top5[batch_num, max_ind[batch_num]] = 0.
418 |
419 |
420 | for batch_num in range(b):
421 | min_x = 100.
422 | for i in range(1):
423 | if abs(gt_trans_convert_x[batch_num].float() - torch.round(arg_x[i][batch_num].float())) < min_x:
424 | min_x = abs(gt_trans_convert_x[batch_num].float() - torch.round(arg_x[i][batch_num].float()))
425 |
426 | for class_id in range(20):
427 | for batch_num in range(b):
428 | if min_x <= class_id:
429 | acc_x[class_id] += 1
430 |
431 |
432 | arg_y = []
433 | corr_top5 = corr_y.clone().detach().cpu()
434 | for i in range(1):
435 | max_ind = torch.argmax(corr_top5, dim=-1)
436 | arg_y.append(max_ind)
437 | for batch_num in range(b):
438 | corr_top5[batch_num, max_ind[batch_num]] = 0.
439 |
440 |
441 | for batch_num in range(b):
442 | min_y = 100.
443 | for i in range(1):
444 | if abs(gt_trans_convert_y[batch_num] - torch.round(arg_y[i][batch_num].float())) < min_y:
445 | min_y = abs(gt_trans_convert_y[batch_num] - torch.round(arg_y[i][batch_num].float()))
446 |
447 | for class_id in range(20):
448 | for batch_num in range(b):
449 | if min_y <= class_id:
450 | acc_y[class_id] += 1
451 |
452 |
453 | # set the loss function:
454 | # compute_loss = torch.nn.KLDivLoss(reduction="sum").to(device)
455 | # compute_loss = torch.nn.BCEWithLogitsLoss(reduction="sum").to(device)
456 | compute_loss_y = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
457 | compute_loss_x = torch.nn.CrossEntropyLoss(reduction="sum").to(device)
458 | # mse_loss = torch.nn.MSELoss(reduce=True)
459 | # compute_loss = torch.nn.NLLLoss()
460 | # compute_l1loss_a = torch.nn.L1Loss()
461 | # compute_l1loss_x = torch.nn.L1Loss()
462 | # compute_l1loss_y = torch.nn.L1Loss()
463 | compute_mse = torch.nn.MSELoss()
464 | compute_l1 = torch.nn.L1Loss()
465 |
466 | # mse_loss = mse_loss(rotation_cal.float(), groundTruth_number.float())
467 | # l1loss = compute_l1loss(rotation_cal, groundTruth_number)
468 | loss_l1_x = compute_l1(tranformation_x, gt_trans_convert_x)
469 | loss_l1_y = compute_l1(tranformation_y, gt_trans_convert_y)
470 | loss_mse_x = compute_mse(tranformation_x, gt_trans_convert_x.type(torch.FloatTensor).to(device))
471 | loss_mse_y = compute_mse(tranformation_y, gt_trans_convert_y.type(torch.FloatTensor).to(device))
472 | loss_x = compute_loss_x(corr_x, gt_trans_convert_x)
473 | loss_y = compute_loss_y(corr_y, gt_trans_convert_y)
474 | total_loss = loss_mse_x + loss_mse_y
475 | return loss_y, loss_x, total_loss, loss_l1_x, loss_l1_y, loss_mse_x, loss_mse_y
476 |
477 |
478 |
479 |
--------------------------------------------------------------------------------
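The non-DSNT branch of `validate_translation` above sums the correlation map along one axis and reads the shift as an expectation: softmax weights multiplied by `(256 - 1) * linspace(0, 1, 256)`, which is simply the bin index. The following is a minimal pure-Python sketch of that soft-argmax idea, not the repository's implementation; `soft_argmax_1d` and the toy profile are hypothetical names introduced for illustration.

```python
import math


def soft_argmax_1d(scores):
    """Return the softmax-weighted expected index of a 1-D score profile."""
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return sum(i * w for i, w in enumerate(weights)) / total


# Toy correlation profile with a sharp peak at bin 3 (hypothetical data).
profile = [0.0, 0.0, 0.0, 50.0, 0.0, 0.0]
estimate = soft_argmax_1d(profile)  # differentiable, sub-bin estimate
hard = max(range(len(profile)), key=profile.__getitem__)  # plain argmax
```

When the profile is sharply peaked, the expectation lands close to the hard `argmax`; a flat profile pulls it toward the middle bin, which is presumably why the scores pass through `Corr2Softmax` before the softmax.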
/validate.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch.nn.functional as F
3 | import torch
4 | import torch.optim as optim
5 | import torch.nn as nn
6 | from torch.optim import lr_scheduler
7 | import time
8 | import copy
9 | from unet.pytorch_DPCN import FFT2, UNet, LogPolar, PhaseCorr, Corr2Softmax
10 | from data.dataset_DPCN import *
11 | import numpy as np
12 | import shutil
13 | from utils.utils import *
14 | import kornia
15 | from data.dataset import *
16 | from utils.validate_utils import *
17 | import argparse
18 |
19 |
20 | def val_model(model_template, model_source, model_corr2softmax,\
21 | model_trans_template, model_trans_source, model_trans_corr2softmax, \
22 | writer_val, iters, dsnt, dataloader, batch_size_val, device, epoch):
23 |
24 | # for the use of visualizing the validation properly on the tensorboard
25 | iters -= 500
26 | phase = "val"
27 | loss_list = []
28 | rot_list = []
29 | model_template.eval() # Set model to evaluate mode
30 | model_source.eval()
31 | model_corr2softmax.eval()
32 | model_trans_template.eval()
33 | model_trans_source.eval()
34 | model_trans_corr2softmax.eval()
35 | acc_x = np.zeros(20)
36 | acc_y = np.zeros(20)
37 | acc = 0.
38 |
39 | with torch.no_grad():
40 |
41 | for template, source, groundTruth_number, gt_scale, gt_trans in dataloader(batch_size_val)[phase]:
42 | template = template.to(device)
43 | source = source.to(device)
44 | iters += 1
45 | # imshow(template[0,:,:])
46 | # plt.show()
47 | # imshow(source[0,:,:])
48 | # plt.show()
49 | # print("gtSCALE~~~~",gt_scale)
50 | loss_rot, loss_scale, scale_cal, loss_l1_rot, loss_mse_rot, loss_l1_scale, loss_mse_scale \
51 | = validate_rot_scale(template.clone(), source.clone(), groundTruth_number.clone(), gt_scale.clone(),\
52 | model_template, model_source, model_corr2softmax, device )
53 | loss_y, loss_x, total_loss, loss_l1_x,loss_l1_y,loss_mse_x, loss_mse_y \
54 | = validate_translation(template.clone(), source.clone(), groundTruth_number.clone(), gt_scale.clone(), gt_trans.clone(), \
55 | model_trans_template, model_trans_source, model_trans_corr2softmax,acc_x, acc_y, dsnt, device)
56 |
57 |
58 | # loss = compute_loss(corr_final, gt_angle)
59 | total_rs_loss = loss_rot + loss_scale
60 | loss_list.append(total_rs_loss.tolist())
61 | writer_val.add_scalar('LOSS ROTATION', loss_rot.detach().cpu().numpy(), iters)
62 | writer_val.add_scalar('LOSS SCALE', loss_scale.detach().cpu().numpy(), iters)
63 | writer_val.add_scalar('LOSS X', loss_x.detach().cpu().numpy(), iters)
64 | writer_val.add_scalar('LOSS Y', loss_y.detach().cpu().numpy(), iters)
65 |
66 | writer_val.add_scalar('LOSS ROTATION L1', loss_l1_rot.item(), iters)
67 | writer_val.add_scalar('LOSS ROTATION MSE', loss_mse_rot.item(), iters)
68 | writer_val.add_scalar('LOSS SCALE L1', loss_l1_scale.item(), iters)
69 | writer_val.add_scalar('LOSS SCALE MSE', loss_mse_scale.item(), iters)
70 |
71 | writer_val.add_scalar('LOSS X L1', loss_l1_x.item(), iters)
72 | writer_val.add_scalar('LOSS X MSE', loss_mse_x.item(), iters)
73 | writer_val.add_scalar('LOSS Y L1', loss_l1_y.item(), iters)
74 | writer_val.add_scalar('LOSS Y MSE', loss_mse_y.item(), iters)
75 |
76 | X = np.linspace(0, 19, 20)
77 | fig = plt.figure()
78 | plt.bar(X,acc_x/1000)
79 | plt.xlabel("X-axis")
80 | plt.ylabel("Y-axis")
81 |
82 | plt.savefig("./checkpoints/barChart/x/"+ str(epoch) + "_toy_barChartX_top1.jpg")
83 |
84 | Y = np.linspace(0, 19, 20)
85 | fig = plt.figure()
86 | plt.bar(Y,acc_y/1000)
87 | plt.xlabel("X-axis")
88 | plt.ylabel("Y-axis")
89 |
90 | plt.savefig("./checkpoints/barChart/y/"+ str(epoch) + "_toy_barChartY_top1.jpg")
91 | return loss_list
92 |
93 |
94 | # Parse command-line arguments for validation
95 | parser_val = argparse.ArgumentParser(description="DPCN Network Validation")
96 | parser_val.add_argument('--only_valid', action='store_true', default=False)
97 | parser_val.add_argument('--cpu', action='store_true', default=False)
98 | parser_val.add_argument('--load_path', type=str, default="./checkpoints/checkpoint.pt")
99 | parser_val.add_argument('--simulation', action='store_true', default=False)
100 | parser_val.add_argument('--use_dsnt', action='store_true', default=False)
101 | parser_val.add_argument('--batch_size_val', type=int, default=2)
102 | parser_val.add_argument('--val_writer_path', type=str, default="./checkpoints/log/val/")
103 | args_val = parser_val.parse_args()
104 |
105 | if args_val.only_valid:
106 | epoch = 1
107 | checkpoint_path = args_val.load_path
108 | device = torch.device("cuda:0" if not args_val.cpu else "cpu")
109 | print("The device that the code is running on:", device)
110 | writer_val = SummaryWriter(log_dir=args_val.val_writer_path)
111 | batch_size_val = args_val.batch_size_val
112 | dataloader = generate_dataloader if args_val.simulation else DPCNdataloader
113 | dsnt = args_val.use_dsnt
114 |
115 |
116 | num_class = 1
117 | start_epoch = 0
118 | iters = 0
119 |
120 |
121 | # create a shell model for checkpoint loader to load into
122 | model_template = UNet(num_class).to(device)
123 | model_source = UNet(num_class).to(device)
124 | model_corr2softmax = Corr2Softmax(200., 0.).to(device)
125 | model_trans_template = UNet(num_class).to(device)
126 | model_trans_source = UNet(num_class).to(device)
127 | model_trans_corr2softmax = Corr2Softmax(11.72, 0.).to(device)
128 |
129 | optimizer_ft_temp = optim.Adam(filter(lambda p: p.requires_grad, model_template.parameters()), lr=2e-4)
130 | optimizer_ft_src = optim.Adam(filter(lambda p: p.requires_grad, model_source.parameters()), lr=2e-4)
131 | optimizer_c2s = optim.Adam(filter(lambda p: p.requires_grad, model_corr2softmax.parameters()), lr=1e-1)
132 | optimizer_trans_ft_temp = optim.Adam(filter(lambda p: p.requires_grad, model_trans_template.parameters()), lr=2e-4)
133 | optimizer_trans_ft_src = optim.Adam(filter(lambda p: p.requires_grad, model_trans_source.parameters()), lr=2e-4)
134 | optimizer_trans_c2s = optim.Adam(filter(lambda p: p.requires_grad, model_trans_corr2softmax.parameters()), lr=1e-1)
135 |
136 |
137 | # load checkpoint
138 | model_template, model_source, model_corr2softmax, model_trans_template, model_trans_source, model_trans_corr2softmax,\
139 | optimizer_ft_temp, optimizer_ft_src, optimizer_c2s, optimizer_trans_ft_temp, optimizer_trans_ft_src, optimizer_trans_c2s,\
140 | start_epoch = load_checkpoint(\
141 | checkpoint_path, model_template, model_source, model_corr2softmax, model_trans_template, model_trans_source, model_trans_corr2softmax,\
142 | optimizer_ft_temp, optimizer_ft_src, optimizer_c2s, optimizer_trans_ft_temp, optimizer_trans_ft_src, optimizer_trans_c2s, device)
143 |
144 | # Entering the main loop of validation
145 | loss_list = val_model(model_template, model_source, model_corr2softmax, \
146 | model_trans_template, model_trans_source, model_trans_corr2softmax, \
147 | writer_val, iters, dsnt, dataloader, batch_size_val, device, epoch)
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
--------------------------------------------------------------------------------
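`val_model` above plots `acc_x` and `acc_y`, which tally how many predictions fall within 0 to 19 bins of the ground truth. The sketch below shows one way to compute such a thresholded-accuracy tally per sample. It is an illustration under assumptions, not the repository's code; `accuracy_within_thresholds` is a hypothetical helper and the sample values are made up.

```python
def accuracy_within_thresholds(predicted, ground_truth, num_thresholds=20):
    """Count, for each threshold t in [0, num_thresholds), samples with |error| <= t."""
    counts = [0] * num_thresholds
    for pred, gt in zip(predicted, ground_truth):
        error = abs(pred - gt)
        for threshold in range(num_thresholds):
            if error <= threshold:
                counts[threshold] += 1
    return counts


# Hypothetical predicted and ground-truth bins for three samples.
predicted = [10, 42, 100]
ground_truth = [10, 40, 130]
counts = accuracy_within_thresholds(predicted, ground_truth)
fractions = [c / len(predicted) for c in counts]  # normalize by sample count
```

Dividing by the actual number of samples, rather than a fixed constant such as the `1000` used for the bar charts, keeps the plotted values a true fraction regardless of validation set size.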