├── README.md ├── __init__.py ├── config.py ├── data ├── __init__.py ├── csv │ ├── pf-pascal │ │ ├── eval.csv │ │ ├── test.csv │ │ └── train.csv │ ├── pf-willow │ │ └── test.csv │ └── tss │ │ └── data.csv ├── pf_pascal.py ├── pf_willow.py └── tss.py ├── eval.py ├── eval.sh ├── geotnf ├── __init__.py ├── flow.py ├── point_tnf.py └── transformation.py ├── img └── teaser.png ├── model ├── __init__.py ├── loss.py └── network.py ├── parser ├── __init__.py └── parser.py ├── requirements.txt ├── train-coseg.py ├── train.py ├── train.sh ├── trained_models └── .gitignore └── util ├── __init__.py ├── dataloader.py ├── eval_util.py ├── normalize.py ├── py_util.py ├── torch_util.py ├── train_test_fn.py └── util.py /README.md: -------------------------------------------------------------------------------- 1 | # Show, Match and Segment: Joint Weakly Supervised Learning of Semantic Matching and Object Co-segmentation 2 | 3 | This repository contains the source code for the paper Show, Match and Segment: Joint Weakly Supervised Learning of Semantic Matching and Object Co-segmentation. 4 | 5 | 6 | 7 | ## Abstract 8 | We present an approach for jointly matching and segmenting object instances of the same category within a collection of images. In contrast to existing algorithms that tackle the tasks of semantic matching and object co-segmentation in isolation, our method exploits the complementary nature of the two tasks. The key insights of our method are two-fold. First, the estimated dense correspondence fields from semantic matching provide supervision for object co-segmentation by enforcing consistency between the predicted masks from a pair of images. Second, the predicted object masks from object co-segmentation in turn allow us to reduce the adverse effects due to background clutters for improving semantic matching. Our model is end-to-end trainable and does not require supervision from manually annotated correspondences and object masks. 
We validate the efficacy of our approach on five benchmark datasets: TSS, Internet, PF-PASCAL, PF-WILLOW, and SPair-71k, and show that our algorithm performs favorably against the state-of-the-art methods on both semantic matching and object co-segmentation tasks. 9 | 10 | ## Citation 11 | If you find our code useful, please consider citing our work using the following bibtex: 12 | ``` 13 | @article{MaCoSNet, 14 | title={Show, Match and Segment: Joint Weakly Supervised Learning of Semantic Matching and Object Co-segmentation}, 15 | author={Chen, Yun-Chun and Lin, Yen-Yu and Yang, Ming-Hsuan and Huang, Jia-Bin}, 16 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence (PAMI)}, 17 | year={2020} 18 | } 19 | 20 | @inproceedings{WeakMatchNet, 21 | title={Deep Semantic Matching with Foreground Detection and Cycle-Consistency}, 22 | author={Chen, Yun-Chun and Huang, Po-Hsiang and Yu, Li-Yu and Huang, Jia-Bin and Yang, Ming-Hsuan and Lin, Yen-Yu}, 23 | booktitle={Asian Conference on Computer Vision (ACCV)}, 24 | year={2018} 25 | } 26 | ``` 27 | 28 | ## Environment 29 | - Install Anaconda Python3.7 30 | - This code is tested on NVIDIA V100 GPU with 16GB memory 31 | 32 | ``` 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ## Dataset 37 | - Please download the [PF-PASCAL](http://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset-PASCAL.zip), [PF-WILLOW](http://www.di.ens.fr/willow/research/proposalflow/dataset/PF-dataset.zip), [SPair-71k](http://cvlab.postech.ac.kr/research/SPair-71k/), [TSS](https://drive.google.com/file/d/0B-VxeI7PlJE1U3FyTGVpbUFtcjg/view?usp=sharing), and [Internet](http://people.csail.mit.edu/mrub/ObjectDiscovery/ObjectDiscovery-data.zip) datasets 38 | - Please modify the variable `DATASET_DIR` in `config.py` 39 | - Please modify the variable `CSV_DIR` in `config.py` 40 | 41 | 42 | 43 | ## Training 44 | - You may determine which dataset to be the `training set` by changing the $DATASET variable in train.sh 45 | - You 
may change the $BATCH_SIZE variable in `train.sh` to a suitable value based on the GPU memory 46 | - The trained model will be saved under the `trained_models` folder 47 | 48 | ``` 49 | sh train.sh 50 | ``` 51 | 52 | 53 | ## Evaluation 54 | - You may determine which dataset to be evaluated by changing the $DATASET variable in eval.sh 55 | - You may change the $BATCH_SIZE variable in `eval.sh` to a suitable value based on the GPU memory 56 | 57 | ``` 58 | sh eval.sh 59 | ``` 60 | 61 | ## Acknowledgement 62 | - This code is heavily borrowed from [Rocco et al.](https://github.com/ignacio-rocco/weakalign) 63 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunChunChen/MaCoSNet-pytorch/697ea3a824c5bc9bbca6fac8d33367a8b9def4d1/__init__.py -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | THRESHOLD = 2 4 | 5 | NUM_OF_COORD = 100 6 | 7 | DATASET_DIR = '/work/cascades/ycchen918/PAMI/datasets' 8 | 9 | PF_PASCAL_DIR = os.path.join(DATASET_DIR, 'proposal-flow-pascal') 10 | 11 | PF_WILLOW_DIR = os.path.join(DATASET_DIR, 'proposal-flow-willow') 12 | 13 | TSS_DIR = os.path.join(DATASET_DIR, 'TSS_CVPR2016') 14 | 15 | INTERNET_DIR = os.path.join(DATASET_DIR, 'Internet') 16 | 17 | 18 | CSV_DIR = '/home/ycchen918/Project/MaCoSNet-pytorch/data/csv' 19 | 20 | PF_PASCAL_TRAIN_DATA = os.path.join(CSV_DIR, 'pf-pascal', 'train.csv') 21 | PF_PASCAL_EVAL_DATA = os.path.join(CSV_DIR, 'pf-pascal', 'eval.csv') 22 | PF_PASCAL_TEST_DATA = os.path.join(CSV_DIR, 'pf-pascal', 'test.csv') 23 | 24 | PF_WILLOW_TEST_DATA = os.path.join(CSV_DIR, 'pf-willow', 'test.csv') 25 | 26 | TSS_TRAIN_DATA = os.path.join(CSV_DIR, 'tss', 'data.csv') 27 | TSS_EVAL_DATA = os.path.join(CSV_DIR, 'tss', 'data.csv') 28 
class PFPascal(Dataset):
    """
    PF-PASCAL training dataset.

    Each item is an image pair described by one row of the csv file.

    Args:
        csv_file (string): Path to the csv file. Columns 0 and 1 hold the
            paths of image A and image B (joined onto
            ``training_image_path``), column 2 is returned untouched under
            the ``'set'`` key and column 3 is an integer flag that selects a
            horizontal flip of both images.
        training_image_path (string): Directory with the images.
        dataset_size (int, optional): If given, only the first
            ``dataset_size`` rows of the csv file are used.
        output_size (2-tuple): Desired output size as (height, width).
        transform (callable): Transformation for post-processing the training
            pair (eg. image normalization).
        random_crop (bool): If True, every image is randomly cropped before
            being resized.
    """

    def __init__(self,
                 csv_file=config.PF_PASCAL_TRAIN_DATA,
                 training_image_path=config.PF_PASCAL_DIR,
                 dataset_size=None,
                 output_size=(240,240),
                 transform=None,
                 random_crop=False):

        self.random_crop = random_crop
        self.out_h, self.out_w = output_size
        self.train_data = pd.read_csv(csv_file)

        if dataset_size is not None:
            dataset_size = min(dataset_size, len(self.train_data))
            self.train_data = self.train_data.iloc[0:dataset_size, :]

        self.img_A_names = self.train_data.iloc[:, 0]
        self.img_B_names = self.train_data.iloc[:, 1]
        # `.values` replaces `as_matrix()`, which was removed in pandas 1.0
        self.set = self.train_data.iloc[:, 2].values
        self.flip = self.train_data.iloc[:, 3].values.astype('int')
        self.training_image_path = training_image_path
        self.transform = transform
        # no cuda as dataset is called from CPU threads in dataloader and produces conflict
        self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda=False)

    def __len__(self):
        return len(self.img_A_names)

    def __getitem__(self, idx):
        """Return the idx-th pair as a dict of resized images and original sizes."""
        # both images of a pair share the same flip flag
        image_A, im_size_A = self.get_image(self.img_A_names, idx, self.flip[idx])
        image_B, im_size_B = self.get_image(self.img_B_names, idx, self.flip[idx])

        sample = {
            'source_image': image_A,
            'target_image': image_B,
            'source_im_size': im_size_A,
            'target_im_size': im_size_B,
            'set': self.set[idx]
        }

        if self.transform:
            sample = self.transform(sample)

        return sample

    def get_image(self, img_name_list, idx, flip):
        """
        Load, optionally crop/flip, and resize one image.

        Returns:
            (image, im_size): the resized image as a float tensor with the
            batch dimension removed, and the pre-resize ``image.shape`` as a
            float tensor.
        """
        img_name = os.path.join(self.training_image_path, img_name_list.iloc[idx])
        image = io.imread(img_name)

        # if grayscale convert to 3-channel image
        if image.ndim == 2:
            image = np.repeat(np.expand_dims(image, 2), axis=2, repeats=3)

        # do random crop: each border is moved inwards by less than a quarter
        # of the image side (integer bounds; same values as the float version)
        if self.random_crop:
            h, w, c = image.shape
            top = np.random.randint(h // 4)
            bottom = 3 * h // 4 + np.random.randint(h // 4)
            left = np.random.randint(w // 4)
            right = 3 * w // 4 + np.random.randint(w // 4)
            image = image[top:bottom, left:right, :]

        # flip horizontally if needed
        if flip:
            image = np.flip(image, 1)

        # get image size (after crop, before resize)
        im_size = np.asarray(image.shape)

        # convert to torch Variable of shape [1, C, H, W]
        image = np.expand_dims(image.transpose((2, 0, 1)), 0)
        image = torch.Tensor(image.astype(np.float32))
        image_var = Variable(image, requires_grad=False)

        # Resize image using bilinear sampling with identity affine tnf
        image = self.affineTnf(image_var).data.squeeze(0)

        im_size = torch.Tensor(im_size.astype(np.float32))

        return (image, im_size)


class PFPascalVal(Dataset):
    """
    PF-PASCAL eval/test dataset.

    Each item is an image pair together with its annotated keypoints and the
    reference length used for PCK.

    Args:
        csv_file (string): Path to the csv file. Columns 0 and 1 hold the
            image paths, column 2 the integer category, columns 3-4 the
            ';'-separated X / Y keypoints of image A and the remaining
            columns those of image B.
        dataset_path (string): Directory with the images.
        output_size (2-tuple): Desired output size as (height, width).
        transform (callable): Transformation for post-processing the pair
            (eg. image normalization).
        category (int, optional): If given, keep only pairs of this category.
        mode (string): 'test' switches ``csv_file`` to
            ``config.PF_PASCAL_TEST_DATA``; any other value keeps ``csv_file``.
        pck_procedure (string): 'pf' (L_pck is the larger side of the
            keypoint bounding box in image A) or 'scnet' (keypoints and image
            sizes are rescaled to 224x224 and L_pck is 224).

    Raises:
        ValueError: from ``__getitem__`` if ``pck_procedure`` is neither
            'pf' nor 'scnet'.
    """

    # keypoint arrays are padded with -1 up to this many points
    MAX_POINTS = 20

    def __init__(self,
                 csv_file=config.PF_PASCAL_EVAL_DATA,
                 dataset_path=config.PF_PASCAL_DIR,
                 output_size=(240,240),
                 transform=None,
                 category=None,
                 mode='eval',
                 pck_procedure='scnet'):

        if mode == 'test':
            csv_file = config.PF_PASCAL_TEST_DATA

        self.category_names = [ 'aeroplane', 'bicycle', 'bird', 'boat',
                                'bottle', 'bus', 'car', 'cat', 'chair',
                                'cow', 'diningtable', 'dog', 'horse',
                                'motorbike', 'person', 'pottedplant',
                                'sheep', 'sofa', 'train', 'tvmonitor' ]

        self.out_h, self.out_w = output_size
        self.pairs = pd.read_csv(csv_file)
        # `.values` replaces `as_matrix()`, which was removed in pandas 1.0
        self.category = self.pairs.iloc[:, 2].values.astype('int')
        if category is not None:
            cat_idx = np.nonzero(self.category == category)[0]
            self.category = self.category[cat_idx]
            self.pairs = self.pairs.iloc[cat_idx, :]
        self.img_A_names = self.pairs.iloc[:, 0]
        self.img_B_names = self.pairs.iloc[:, 1]
        self.point_A_coords = self.pairs.iloc[:, 3:5]
        self.point_B_coords = self.pairs.iloc[:, 5:]
        self.dataset_path = dataset_path
        self.transform = transform
        # no cuda as dataset is called from CPU threads in dataloader and produces conflict
        self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda=False)
        self.pck_procedure = pck_procedure

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the idx-th pair with keypoints and the PCK reference length."""
        # get pre-processed images
        image_A, im_size_A = self.get_image(self.img_A_names, idx)
        image_B, im_size_B = self.get_image(self.img_B_names, idx)

        # get pre-processed point coords
        point_A_coords = self.get_points(self.point_A_coords, idx)
        point_B_coords = self.get_points(self.point_B_coords, idx)

        # number of real (non-padding) keypoints; padding entries are -1
        N_pts = int(torch.sum(torch.ne(point_A_coords[0, :], -1)))

        if self.pck_procedure == 'pf':
            # L_pck is the larger side of the keypoint bounding box in image A
            valid_A = point_A_coords[:, :N_pts]
            L_pck = torch.max(valid_A.max(1)[0] - valid_A.min(1)[0]).reshape(1)

        elif self.pck_procedure == 'scnet':
            # modification to follow the evaluation procedure of SCNet:
            # X is rescaled by the image width, Y by the image height
            point_A_coords[0, 0:N_pts] = point_A_coords[0, 0:N_pts] * 224 / im_size_A[1]
            point_A_coords[1, 0:N_pts] = point_A_coords[1, 0:N_pts] * 224 / im_size_A[0]

            point_B_coords[0, 0:N_pts] = point_B_coords[0, 0:N_pts] * 224 / im_size_B[1]
            point_B_coords[1, 0:N_pts] = point_B_coords[1, 0:N_pts] * 224 / im_size_B[0]

            im_size_A[0:2] = torch.FloatTensor([224, 224])
            im_size_B[0:2] = torch.FloatTensor([224, 224])

            L_pck = torch.FloatTensor([224.0])

        else:
            # previously this fell through and failed later on an unbound L_pck
            raise ValueError("Unknown pck_procedure: %r (expected 'pf' or 'scnet')"
                             % (self.pck_procedure,))

        sample = {
            'source_image': image_A,
            'target_image': image_B,
            'source_im_size': im_size_A,
            'target_im_size': im_size_B,
            'source_points': point_A_coords,
            'target_points': point_B_coords,
            'L_pck': L_pck
        }

        if self.transform:
            sample = self.transform(sample)

        return sample

    def get_image(self, img_name_list, idx):
        """
        Load and resize one image.

        Returns:
            (image, im_size): the resized image as a float tensor with the
            batch dimension removed, and the pre-resize ``image.shape`` as a
            float tensor.
        """
        img_name = os.path.join(self.dataset_path, img_name_list.iloc[idx])
        image = io.imread(img_name)

        # if grayscale convert to 3-channel image (same handling as the
        # training dataset; the transpose below needs a 3-D array)
        if image.ndim == 2:
            image = np.repeat(np.expand_dims(image, 2), axis=2, repeats=3)

        # get image size
        im_size = np.asarray(image.shape)

        # convert to torch Variable of shape [1, C, H, W]
        image = np.expand_dims(image.transpose((2, 0, 1)), 0)
        image = torch.Tensor(image.astype(np.float32))
        image_var = Variable(image, requires_grad=False)

        # Resize image using bilinear sampling with identity affine tnf
        image = self.affineTnf(image_var).data.squeeze(0)

        im_size = torch.Tensor(im_size.astype(np.float32))

        return (image, im_size)

    def get_points(self, point_coords_list, idx):
        """
        Parse the ';'-separated keypoints of row ``idx``.

        Returns:
            Float tensor of shape [2, MAX_POINTS]; row 0 is X, row 1 is Y,
            unused trailing entries are -1.
        """
        X = np.fromstring(point_coords_list.iloc[idx, 0], sep=';')
        Y = np.fromstring(point_coords_list.iloc[idx, 1], sep=';')
        n_max = self.MAX_POINTS
        Xpad = -np.ones(n_max)
        Xpad[:len(X)] = X
        Ypad = -np.ones(n_max)
        # slice by len(Y): the original used len(X) for both rows
        Ypad[:len(Y)] = Y
        point_coords = np.concatenate((Xpad.reshape(1, n_max), Ypad.reshape(1, n_max)), axis=0)

        # make arrays float tensor for subsequent processing
        point_coords = torch.Tensor(point_coords.astype(np.float32))
        return point_coords
image normalization) 25 | 26 | """ 27 | 28 | def __init__(self, 29 | csv_file=config.PF_WILLOW_TEST_DATA, 30 | dataset_path=config.PF_WILLOW_DIR, 31 | output_size=(240,240), 32 | transform=None, 33 | category=None, 34 | pck_procedure='scnet'): 35 | 36 | self.category_names = [ 'car(G)', 'car(M)', 'car(S)', 'duck(S)', 37 | 'motorbike(G)', 'motorbike(M)', 'motorbike(S)', 38 | 'winebottle(M)', 'winebottle(wC)', 'winebottle(woC)' ] 39 | 40 | self.out_h, self.out_w = output_size 41 | self.pairs = pd.read_csv(csv_file) 42 | self.category = self.pairs.iloc[:,2].as_matrix().astype('int') 43 | if category is not None: 44 | cat_idx = np.nonzero(self.category==category)[0] 45 | self.category=self.category[cat_idx] 46 | self.pairs=self.pairs.iloc[cat_idx,:] 47 | self.img_A_names = self.pairs.iloc[:,0] 48 | self.img_B_names = self.pairs.iloc[:,1] 49 | self.point_A_coords = self.pairs.iloc[:, 3:5] 50 | self.point_B_coords = self.pairs.iloc[:, 5:] 51 | self.dataset_path = dataset_path 52 | self.transform = transform 53 | # no cuda as dataset is called from CPU threads in dataloader and produces confilct 54 | self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda = False) 55 | self.pck_procedure = pck_procedure 56 | 57 | def __len__(self): 58 | return len(self.pairs) 59 | 60 | def __getitem__(self, idx): 61 | # get pre-processed images 62 | image_A,im_size_A = self.get_image(self.img_A_names,idx) 63 | image_B,im_size_B = self.get_image(self.img_B_names,idx) 64 | 65 | # get pre-processed point coords 66 | point_A_coords = self.get_points(self.point_A_coords,idx) 67 | point_B_coords = self.get_points(self.point_B_coords,idx) 68 | 69 | # compute PCK reference length L_pck (equal to max bounding box side in image_A) 70 | N_pts = torch.sum(torch.ne(point_A_coords[0,:],-1)) 71 | 72 | if self.pck_procedure=='pf': 73 | L_pck = torch.FloatTensor([torch.max(point_A_coords[:,:N_pts].max(1)[0]-point_A_coords[:,:N_pts].min(1)[0])]) 74 | 75 | elif 
self.pck_procedure=='scnet': 76 | #modification to follow the evaluation procedure of SCNet 77 | point_A_coords[0,0:N_pts]=point_A_coords[0,0:N_pts]*224/im_size_A[1] 78 | point_A_coords[1,0:N_pts]=point_A_coords[1,0:N_pts]*224/im_size_A[0] 79 | 80 | point_B_coords[0,0:N_pts]=point_B_coords[0,0:N_pts]*224/im_size_B[1] 81 | point_B_coords[1,0:N_pts]=point_B_coords[1,0:N_pts]*224/im_size_B[0] 82 | 83 | im_size_A[0:2]=torch.FloatTensor([224,224]) 84 | im_size_B[0:2]=torch.FloatTensor([224,224]) 85 | 86 | L_pck = torch.FloatTensor([224.0]) 87 | 88 | sample = { 89 | 'source_image': image_A, 90 | 'target_image': image_B, 91 | 'source_im_size': im_size_A, 92 | 'target_im_size': im_size_B, 93 | 'source_points': point_A_coords, 94 | 'target_points': point_B_coords, 95 | 'L_pck': L_pck 96 | } 97 | 98 | if self.transform: 99 | sample = self.transform(sample) 100 | 101 | return sample 102 | 103 | def get_image(self, img_name_list, idx): 104 | img_name = os.path.join(self.dataset_path, img_name_list.iloc[idx]) 105 | image = io.imread(img_name) 106 | #image = cv2.imread(img_name) 107 | 108 | # get image size 109 | im_size = np.asarray(image.shape) 110 | 111 | # convert to torch Variable 112 | image = np.expand_dims(image.transpose((2,0,1)),0) 113 | image = torch.Tensor(image.astype(np.float32)) 114 | image_var = Variable(image,requires_grad=False) 115 | 116 | # Resize image using bilinear sampling with identity affine tnf 117 | image = self.affineTnf(image_var).data.squeeze(0) 118 | 119 | im_size = torch.Tensor(im_size.astype(np.float32)) 120 | 121 | return (image, im_size) 122 | 123 | def get_points(self,point_coords_list,idx): 124 | X=np.fromstring(point_coords_list.iloc[idx,0],sep=';') 125 | Y=np.fromstring(point_coords_list.iloc[idx,1],sep=';') 126 | Xpad = -np.ones(20); Xpad[:len(X)]=X 127 | Ypad = -np.ones(20); Ypad[:len(X)]=Y 128 | point_coords = np.concatenate((Xpad.reshape(1,20),Ypad.reshape(1,20)),axis=0) 129 | 130 | # make arrays float tensor for subsequent processing 
class TSS(Dataset):
    """
    TSS training dataset.

    Each item is an image pair (A, B) from one csv row plus a third image C
    drawn at random from the images that share the pair's category.

    Args:
        csv_file (string): Path to the csv file. Columns 0 and 1 hold the
            image paths (joined onto ``dataset_path``), column 3 an integer
            flip flag for image A and column 4 the integer pair category.
        dataset_path (string): Directory with the images.
        output_size (2-tuple): Desired output size as (height, width).
        transform (callable): Transformation for post-processing the sample.
        random_crop (bool): If True, every image is randomly cropped before
            being resized.
    """

    def __init__(self,
                 csv_file=config.TSS_TRAIN_DATA,
                 dataset_path=config.TSS_DIR,
                 output_size=(240,240),
                 transform=None,
                 random_crop=False):

        self.random_crop = random_crop
        self.out_h, self.out_w = output_size
        self.train_data = pd.read_csv(csv_file)
        self.img_A_names = self.train_data.iloc[:, 0]
        self.img_B_names = self.train_data.iloc[:, 1]
        # `.values` replaces `as_matrix()`, which was removed in pandas 1.0
        self.flip_img_A = self.train_data.iloc[:, 3].values.astype('int')
        self.pair_category = self.train_data.iloc[:, 4].values.astype('int')
        self.dataset_path = dataset_path
        self.transform = transform
        # no cuda as dataset is called from CPU threads in dataloader and produces conflict
        self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda=False)

        self.hash_table = self.set_dict()

    def set_dict(self):
        """
        Group the image names by pair category.

        Returns:
            dict mapping each category to the list of distinct image names
            (from either column) seen with it, in first-seen order.
        """
        hash_table = dict()
        seen = dict()  # per-category set, so membership tests stay O(1)
        rows = zip(self.img_A_names, self.img_B_names, self.pair_category)
        for name_A, name_B, class_label in rows:
            names = hash_table.setdefault(class_label, [])
            known = seen.setdefault(class_label, set())
            for name in (name_A, name_B):
                if name not in known:
                    known.add(name)
                    names.append(name)
        return hash_table

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, idx):
        """Return images A, B and a random same-category image C with their sizes."""
        # local import: the module never imported `random`, so the original
        # `random.randint` call raised NameError
        import random

        class_label = self.pair_category[idx]
        name_C = random.choice(self.hash_table[class_label])

        # get pre-processed images (only image A honours the flip flag)
        image_A, im_size_A = self.get_image(self.img_A_names.iloc[idx], self.flip_img_A[idx])
        image_B, im_size_B = self.get_image(self.img_B_names.iloc[idx])
        image_C, im_size_C = self.get_image(name_C)

        sample = {
            'image_A': image_A,
            'image_B': image_B,
            'image_C': image_C,
            'image_A_size': im_size_A,
            'image_B_size': im_size_B,
            'image_C_size': im_size_C,
            'set': class_label
        }

        if self.transform:
            sample = self.transform(sample)

        return sample

    def get_image(self, img_name, flip=False):
        """
        Load, optionally crop/flip, and resize one image.

        Returns:
            (image, im_size): the resized image as a float tensor with the
            batch dimension removed, and the pre-resize ``image.shape`` as a
            float tensor.
        """
        img_name = os.path.join(self.dataset_path, img_name)
        image = io.imread(img_name)

        # if grayscale convert to 3-channel image
        if image.ndim == 2:
            image = np.repeat(np.expand_dims(image, 2), axis=2, repeats=3)

        # do random crop: each border is moved inwards by less than a quarter
        # of the image side (integer bounds; same values as the float version)
        if self.random_crop:
            h, w, c = image.shape
            top = np.random.randint(h // 4)
            bottom = 3 * h // 4 + np.random.randint(h // 4)
            left = np.random.randint(w // 4)
            right = 3 * w // 4 + np.random.randint(w // 4)
            image = image[top:bottom, left:right, :]

        # flip horizontally if needed
        if flip:
            image = np.flip(image, 1)

        # get image size (after crop, before resize)
        im_size = np.asarray(image.shape)

        # convert to torch Variable of shape [1, C, H, W]
        image = np.expand_dims(image.transpose((2, 0, 1)), 0)
        image = torch.Tensor(image.astype(np.float32))
        image_var = Variable(image, requires_grad=False)

        # Resize image using bilinear sampling with identity affine tnf
        image = self.affineTnf(image_var).data.squeeze(0)

        im_size = torch.Tensor(im_size.astype(np.float32))

        return (image, im_size)


class TSSVal(Dataset):
    """
    TSS evaluation dataset.

    Each item is an image pair plus the relative path of its ground-truth
    flow file.

    Args:
        csv_file (string): Path to the csv file. Columns 0 and 1 hold the
            image paths, column 2 the integer flow direction, column 3 an
            integer flip flag for image A and column 4 the integer pair
            category.
        dataset_path (string): Directory with the images.
        output_size (2-tuple): Desired output size as (height, width).
        transform (callable): Transformation for post-processing the sample.
    """

    def __init__(self,
                 csv_file=config.TSS_EVAL_DATA,
                 dataset_path=config.TSS_DIR,
                 output_size=(240,240),
                 transform=None):

        self.out_h, self.out_w = output_size
        self.pairs = pd.read_csv(csv_file)
        self.img_A_names = self.pairs.iloc[:, 0]
        self.img_B_names = self.pairs.iloc[:, 1]
        # `.values` replaces `as_matrix()`, which was removed in pandas 1.0
        self.flow_direction = self.pairs.iloc[:, 2].values.astype('int')
        self.flip_img_A = self.pairs.iloc[:, 3].values.astype('int')
        self.pair_category = self.pairs.iloc[:, 4].values.astype('int')
        self.dataset_path = dataset_path
        self.transform = transform
        # no cuda as dataset is called from CPU threads in dataloader and produces conflict
        self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda=False)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return images A and B, their sizes, the flow path and the category."""
        # get pre-processed images (only image A honours the flip flag)
        image_A, im_size_A = self.get_image(self.img_A_names.iloc[idx], self.flip_img_A[idx])
        image_B, im_size_B = self.get_image(self.img_B_names.iloc[idx])

        # get flow output path
        flow_path = self.get_GT_flow_relative_path(idx)

        sample = {
            'image_A': image_A,
            'image_B': image_B,
            'image_A_size': im_size_A,
            'image_B_size': im_size_B,
            'flow_path': flow_path,
            # the original referenced an undefined `class_label` here
            'set': self.pair_category[idx]
        }

        if self.transform:
            sample = self.transform(sample)

        return sample

    def get_image(self, img_name, flip=False):
        """
        Load, optionally flip, and resize one image.

        Returns:
            (image, im_size): the resized image as a float tensor with the
            batch dimension removed, and the pre-resize ``image.shape`` as a
            float tensor.
        """
        img_name = os.path.join(self.dataset_path, img_name)
        image = io.imread(img_name)

        # if grayscale convert to 3-channel image
        if image.ndim == 2:
            image = np.repeat(np.expand_dims(image, 2), axis=2, repeats=3)

        # flip horizontally if needed
        if flip:
            image = np.flip(image, 1)

        # get image size
        im_size = np.asarray(image.shape)

        # convert to torch Variable of shape [1, C, H, W]
        image = np.expand_dims(image.transpose((2, 0, 1)), 0)
        image = torch.Tensor(image.astype(np.float32))
        image_var = Variable(image, requires_grad=False)

        # Resize image using bilinear sampling with identity affine tnf
        image = self.affineTnf(image_var).data.squeeze(0)

        im_size = torch.Tensor(im_size.astype(np.float32))

        return (image, im_size)

    def get_GT_flow_relative_path(self, idx):
        """Return '<folder of image A>/flow<direction>.flo' for row ``idx``."""
        img_folder = os.path.dirname(self.img_A_names.iloc[idx])
        flow_dir = self.flow_direction[idx]
        flow_file = 'flow' + str(flow_dir) + '.flo'
        flow_file_path = os.path.join(img_folder, flow_file)

        return flow_file_path
# magic number stored as the first float32 of every .flo file
_FLO_MAGIC = 202021.25


def read_flo_file(filename, verbose=False):
    """
    Read from .flo optical flow file (Middlebury format)

    :param filename: name of the flow file
    :param verbose: if True, print the size of the flow that is read
    :return: optical flow data as a float32 array of shape (h, w, 2)
    :raises TypeError: if the magic number is wrong or the file is shorter
        than its header announces

    adapted from https://github.com/liruoteng/OpticalFlowToolkit/
    """
    # `with` closes the file even when the header check raises
    with open(filename, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        if magic.size != 1 or magic[0] != _FLO_MAGIC:
            raise TypeError('Magic number incorrect. Invalid .flo file')

        # header stores width first, then height
        dims = np.fromfile(f, np.int32, count=2)
        if dims.size != 2 or dims[0] < 0 or dims[1] < 0:
            raise TypeError('Invalid .flo file: bad width/height header')
        w, h = int(dims[0]), int(dims[1])

        if verbose:
            print("Reading %d x %d flow file in .flo format" % (h, w))
        data2d = np.fromfile(f, np.float32, count=2 * w * h)

    # np.resize used to repeat the data silently when the file was truncated
    if data2d.size != 2 * w * h:
        raise TypeError('Invalid .flo file: expected %d values, found %d'
                        % (2 * w * h, data2d.size))

    # reshape data into 3D array (rows, columns, channels)
    return data2d.reshape((h, w, 2))


def write_flo_file(flow, filename):
    """
    Write optical flow in Middlebury .flo format

    :param flow: optical flow map; its first two dimensions are written to
        the header as height and width
    :param filename: optical flow file path to be saved
    :return: None

    from https://github.com/liruoteng/OpticalFlowToolkit/
    """
    # forcing conversion to float32 precision
    flow = flow.astype(np.float32)
    (height, width) = flow.shape[0:2]
    magic = np.array([_FLO_MAGIC], dtype=np.float32)
    w = np.array([width], dtype=np.int32)
    h = np.array([height], dtype=np.int32)
    # `with` closes the file even when a write fails
    with open(filename, 'wb') as f:
        magic.tofile(f)
        w.tofile(f)
        h.tofile(f)
        flow.tofile(f)


def warp_image(image, flow):
    """
    Warp image (np.ndarray, shape=[h_src,w_src,3]) with flow (np.ndarray, shape=[h_tgt,w_tgt,2])

    :return: warped image as a uint8 array in height, width, channel order
    """
    h_src, w_src = image.shape[0], image.shape[1]
    sampling_grid_torch = np_flow_to_th_sampling_grid(flow, h_src, w_src)
    # HWC -> CHW, plus a leading batch dimension
    image_torch = Variable(torch.FloatTensor(image.astype(np.float32)).transpose(1,2).transpose(0,1).unsqueeze(0))
    # NOTE(review): the grid is built with normalize_axis; on PyTorch >= 1.3
    # the default `align_corners` of grid_sample changed - confirm that the
    # default still matches that convention.
    warped_image_torch = F.grid_sample(image_torch, sampling_grid_torch)
    warped_image = warped_image_torch.data.squeeze(0).transpose(0,1).transpose(1,2).numpy().astype(np.uint8)
    return warped_image


def np_flow_to_th_sampling_grid(flow, h_src, w_src, use_cuda=False):
    """
    Convert a flow field into a normalized sampling grid.

    :param flow: np.ndarray of shape [h_tgt, w_tgt, 2]; channel 0 is the x
        displacement and channel 1 the y displacement
    :param h_src: height of the source image
    :param w_src: width of the source image
    :param use_cuda: if True, move the grid to the GPU
    :return: sampling grid of shape [1, h_tgt, w_tgt, 2]
    """
    h_tgt, w_tgt = flow.shape[0], flow.shape[1]
    # pixel coordinates are 1-based
    grid_x, grid_y = np.meshgrid(range(1, w_tgt + 1), range(1, h_tgt + 1))
    disp_x = flow[:, :, 0]
    disp_y = flow[:, :, 1]
    # source location = target pixel + displacement
    source_x = grid_x + disp_x
    source_y = grid_y + disp_y
    source_x_norm = normalize_axis(source_x, w_src)
    source_y_norm = normalize_axis(source_y, h_src)
    sampling_grid = np.concatenate((np.expand_dims(source_x_norm, 2),
                                    np.expand_dims(source_y_norm, 2)), 2)
    sampling_grid_torch = Variable(torch.FloatTensor(sampling_grid).unsqueeze(0))
    if use_cuda:
        sampling_grid_torch = sampling_grid_torch.cuda()
    return sampling_grid_torch
np.meshgrid(range(1,w_tgt+1),range(1,h_tgt+1)) 118 | disp_x=source_x-grid_x 119 | disp_y=source_y-grid_y 120 | # apply mask 121 | disp_x = disp_x*in_bound_mask+1e10*(1-in_bound_mask) 122 | disp_y = disp_y*in_bound_mask+1e10*(1-in_bound_mask) 123 | flow = np.concatenate((np.expand_dims(disp_x,2),np.expand_dims(disp_y,2)),2) 124 | return flow 125 | 126 | -------------------------------------------------------------------------------- /geotnf/point_tnf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | from geotnf.transformation import TpsGridGen 5 | 6 | def normalize_axis(x,L): 7 | return (x-1-(L-1)/2)*2/(L-1) 8 | 9 | def unnormalize_axis(x,L): 10 | return x*(L-1)/2+1+(L-1)/2 11 | 12 | class PointTnf(object): 13 | """ 14 | 15 | Class with functions for transforming a set of points with affine/tps transformations 16 | 17 | """ 18 | def __init__(self, tps_grid_size=3, tps_reg_factor=0, use_cuda=True): 19 | self.use_cuda=use_cuda 20 | self.tpsTnf = TpsGridGen(grid_size=tps_grid_size, 21 | reg_factor=tps_reg_factor, 22 | use_cuda=self.use_cuda) 23 | 24 | def tpsPointTnf(self,theta,points): 25 | # points are expected in [B,2,N], where first row is X and second row is Y 26 | # reshape points for applying Tps transformation 27 | points=points.unsqueeze(3).transpose(1,3) 28 | # apply transformation 29 | warped_points = self.tpsTnf.apply_transformation(theta,points) 30 | # undo reshaping 31 | warped_points=warped_points.transpose(3,1).squeeze(3) 32 | return warped_points 33 | 34 | def affPointTnf(self,theta,points): 35 | theta_mat = theta.view(-1,2,3) 36 | warped_points = torch.bmm(theta_mat[:,:,:2],points) 37 | warped_points += theta_mat[:,:,2].unsqueeze(2).expand_as(warped_points) 38 | return warped_points 39 | 40 | def PointsToUnitCoords(P,im_size): 41 | h,w = im_size[:,0],im_size[:,1] 42 | P_norm = P.clone() 43 | # normalize Y 44 | P_norm[:,0,:] = 
normalize_axis(P[:,0,:],w.unsqueeze(1).expand_as(P[:,0,:])) 45 | # normalize X 46 | P_norm[:,1,:] = normalize_axis(P[:,1,:],h.unsqueeze(1).expand_as(P[:,1,:])) 47 | return P_norm 48 | 49 | def PointsToPixelCoords(P,im_size): 50 | h,w = im_size[:,0],im_size[:,1] 51 | P_norm = P.clone() 52 | # normalize Y 53 | P_norm[:,0,:] = unnormalize_axis(P[:,0,:],w.unsqueeze(1).expand_as(P[:,0,:])) 54 | # normalize X 55 | P_norm[:,1,:] = unnormalize_axis(P[:,1,:],h.unsqueeze(1).expand_as(P[:,1,:])) 56 | return P_norm -------------------------------------------------------------------------------- /geotnf/transformation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | from skimage import io 5 | import pandas as pd 6 | import numpy as np 7 | import torch 8 | from torch.nn.modules.module import Module 9 | from torch.utils.data import Dataset 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | from util.torch_util import expand_dim 13 | 14 | class ComposedGeometricTnf(object): 15 | """ 16 | 17 | Composed geometric transfromation (affine+tps) 18 | 19 | """ 20 | def __init__(self, tps_grid_size=3, tps_reg_factor=0, out_h=240, out_w=240, 21 | offset_factor=1.0, 22 | padding_crop_factor=None, 23 | use_cuda=True): 24 | 25 | self.padding_crop_factor=padding_crop_factor 26 | 27 | self.affTnf = GeometricTnf(out_h=out_h,out_w=out_w, 28 | geometric_model='affine', 29 | offset_factor=offset_factor if padding_crop_factor is None else padding_crop_factor, 30 | use_cuda=use_cuda) 31 | 32 | self.tpsTnf = GeometricTnf(out_h=out_h,out_w=out_w, 33 | geometric_model='tps', 34 | tps_grid_size=tps_grid_size, 35 | tps_reg_factor=tps_reg_factor, 36 | offset_factor=offset_factor if padding_crop_factor is None else 1.0, 37 | use_cuda=use_cuda) 38 | 39 | def __call__(self, image_batch, theta_aff, theta_aff_tps, use_cuda=True): 40 | 41 | sampling_grid_aff = 
self.affTnf(image_batch=None, 42 | theta_batch=theta_aff.view(-1,2,3), 43 | return_sampling_grid=True, 44 | return_warped_image=False) 45 | 46 | sampling_grid_aff_tps = self.tpsTnf(image_batch=None, 47 | theta_batch=theta_aff_tps, 48 | return_sampling_grid=True, 49 | return_warped_image=False) 50 | 51 | if self.padding_crop_factor is not None: 52 | sampling_grid_aff_tps = sampling_grid_aff_tps*self.padding_crop_factor; 53 | 54 | # put 1e10 value in region out of bounds of sampling_grid_aff 55 | in_bound_mask_aff = ((sampling_grid_aff[:,:,:,0]>-1) * (sampling_grid_aff[:,:,:,0]<1) * (sampling_grid_aff[:,:,:,1]>-1) * (sampling_grid_aff[:,:,:,1]<1)).unsqueeze(3) 56 | in_bound_mask_aff = in_bound_mask_aff.expand_as(sampling_grid_aff) 57 | sampling_grid_aff = torch.mul(in_bound_mask_aff.float(),sampling_grid_aff) 58 | sampling_grid_aff = torch.add((in_bound_mask_aff.float()-1)*(1e10),sampling_grid_aff) 59 | 60 | # compose transformations 61 | sampling_grid_aff_tps_comp = F.grid_sample(sampling_grid_aff.transpose(2,3).transpose(1,2), sampling_grid_aff_tps).transpose(1,2).transpose(2,3) 62 | 63 | # put 1e10 value in region out of bounds of sampling_grid_aff_tps_comp 64 | in_bound_mask_aff_tps=((sampling_grid_aff_tps[:,:,:,0]>-1) * (sampling_grid_aff_tps[:,:,:,0]<1) * (sampling_grid_aff_tps[:,:,:,1]>-1) * (sampling_grid_aff_tps[:,:,:,1]<1)).unsqueeze(3) 65 | in_bound_mask_aff_tps=in_bound_mask_aff_tps.expand_as(sampling_grid_aff_tps_comp) 66 | sampling_grid_aff_tps_comp=torch.mul(in_bound_mask_aff_tps.float(),sampling_grid_aff_tps_comp) 67 | sampling_grid_aff_tps_comp = torch.add((in_bound_mask_aff_tps.float()-1)*(1e10),sampling_grid_aff_tps_comp) 68 | 69 | # sample transformed image 70 | warped_image_batch = F.grid_sample(image_batch, sampling_grid_aff_tps_comp) 71 | 72 | return warped_image_batch 73 | 74 | class GeometricTnf(object): 75 | """ 76 | 77 | Geometric transfromation to an image batch (wrapped in a PyTorch Variable) 78 | ( can be used with no transformation to 
perform bilinear resizing ) 79 | 80 | """ 81 | def __init__(self, geometric_model='affine', tps_grid_size=3, tps_reg_factor=0, out_h=240, out_w=240, offset_factor=None, use_cuda=True): 82 | self.out_h = out_h 83 | self.out_w = out_w 84 | self.geometric_model = geometric_model 85 | self.use_cuda = use_cuda 86 | self.offset_factor = offset_factor 87 | 88 | if geometric_model=='affine' and offset_factor is None: 89 | self.gridGen = AffineGridGen(out_h=out_h, out_w=out_w, use_cuda=use_cuda) 90 | elif geometric_model=='affine' and offset_factor is not None: 91 | self.gridGen = AffineGridGenV2(out_h=out_h, out_w=out_w, use_cuda=use_cuda) 92 | elif geometric_model=='tps': 93 | self.gridGen = TpsGridGen(out_h=out_h, out_w=out_w, grid_size=tps_grid_size, 94 | reg_factor=tps_reg_factor, use_cuda=use_cuda) 95 | if offset_factor is not None: 96 | self.gridGen.grid_X=self.gridGen.grid_X/offset_factor 97 | self.gridGen.grid_Y=self.gridGen.grid_Y/offset_factor 98 | 99 | self.theta_identity = torch.Tensor(np.expand_dims(np.array([[1,0,0],[0,1,0]]),0).astype(np.float32)) 100 | if use_cuda: 101 | self.theta_identity = self.theta_identity.cuda() 102 | 103 | def __call__(self, image_batch, theta_batch=None, out_h=None, out_w=None, return_warped_image=True, return_sampling_grid=False, padding_factor=1.0, crop_factor=1.0): 104 | if image_batch is None: 105 | b=1 106 | else: 107 | b=image_batch.size(0) 108 | if theta_batch is None: 109 | theta_batch = self.theta_identity 110 | theta_batch = theta_batch.expand(b,2,3).contiguous() 111 | theta_batch = Variable(theta_batch,requires_grad=False) 112 | 113 | # check if output dimensions have been specified at call time and have changed 114 | if (out_h is not None and out_w is not None) and (out_h!=self.out_h or out_w!=self.out_w): 115 | if self.geometric_model=='affine': 116 | gridGen = AffineGridGen(out_h, out_w) 117 | elif self.geometric_model=='tps': 118 | gridGen = TpsGridGen(out_h, out_w, use_cuda=self.use_cuda) 119 | else: 120 | gridGen = 
self.gridGen 121 | 122 | sampling_grid = gridGen(theta_batch) 123 | 124 | # rescale grid according to crop_factor and padding_factor 125 | if padding_factor != 1 or crop_factor !=1: 126 | sampling_grid = sampling_grid*(padding_factor*crop_factor) 127 | # rescale grid according to offset_factor 128 | if self.offset_factor is not None: 129 | sampling_grid = sampling_grid*self.offset_factor 130 | 131 | if return_sampling_grid and not return_warped_image: 132 | return sampling_grid 133 | 134 | # sample transformed image 135 | warped_image_batch = F.grid_sample(image_batch, sampling_grid) 136 | 137 | if return_sampling_grid and return_warped_image: 138 | return (warped_image_batch,sampling_grid) 139 | 140 | return warped_image_batch 141 | 142 | 143 | 144 | class SynthPairTnf(object): 145 | """ 146 | 147 | Generate a synthetically warped training pair using an affine transformation. 148 | 149 | """ 150 | def __init__(self, use_cuda=True, supervision='strong', geometric_model='affine', crop_factor=9/16, output_size=(240,240), padding_factor = 0.5): 151 | assert isinstance(use_cuda, (bool)) 152 | assert isinstance(crop_factor, (float)) 153 | assert isinstance(output_size, (tuple)) 154 | assert isinstance(padding_factor, (float)) 155 | self.supervision=supervision 156 | self.use_cuda=use_cuda 157 | self.crop_factor = crop_factor 158 | self.padding_factor = padding_factor 159 | self.out_h, self.out_w = output_size 160 | self.rescalingTnf = GeometricTnf('affine', out_h=self.out_h, out_w=self.out_w, 161 | use_cuda = self.use_cuda) 162 | self.geometricTnf = GeometricTnf(geometric_model, out_h=self.out_h, out_w=self.out_w, 163 | use_cuda = self.use_cuda) 164 | 165 | 166 | def __call__(self, batch): 167 | image_batch, theta_batch = batch['image'], batch['theta'] 168 | if self.use_cuda: 169 | image_batch = image_batch.cuda() 170 | theta_batch = theta_batch.cuda() 171 | 172 | b, c, h, w = image_batch.size() 173 | 174 | # generate symmetrically padded image for bigger sampling 
region 175 | image_batch = self.symmetricImagePad(image_batch,self.padding_factor) 176 | 177 | # convert to variables 178 | image_batch = Variable(image_batch,requires_grad=False) 179 | theta_batch = Variable(theta_batch,requires_grad=False) 180 | 181 | # get cropped image 182 | cropped_image_batch = self.rescalingTnf(image_batch=image_batch, 183 | theta_batch=None, 184 | padding_factor=self.padding_factor, 185 | crop_factor=self.crop_factor) # Identity is used as no theta given 186 | # get transformed image 187 | warped_image_batch = self.geometricTnf(image_batch=image_batch, 188 | theta_batch=theta_batch, 189 | padding_factor=self.padding_factor, 190 | crop_factor=self.crop_factor) # Identity is used as no theta given 191 | 192 | if self.supervision=='strong': 193 | return {'source_image': cropped_image_batch, 'target_image': warped_image_batch, 'theta_GT': theta_batch} 194 | 195 | elif self.supervision=='weak': 196 | pos_batch_idx = torch.LongTensor(range(int(b/2))) 197 | neg_batch_idx = torch.LongTensor(range(int(b/2),b)) 198 | if self.use_cuda: 199 | pos_batch_idx = pos_batch_idx.cuda() 200 | neg_batch_idx = neg_batch_idx.cuda() 201 | source_image = torch.cat((torch.index_select(cropped_image_batch,0,pos_batch_idx), 202 | torch.index_select(cropped_image_batch,0,pos_batch_idx)),0) 203 | target_image = torch.cat((torch.index_select(warped_image_batch,0,pos_batch_idx), 204 | torch.index_select(cropped_image_batch,0,neg_batch_idx)),0) 205 | return {'source_image': source_image, 'target_image': target_image, 'theta_GT': theta_batch} 206 | 207 | def symmetricImagePad(self, image_batch, padding_factor): 208 | b, c, h, w = image_batch.size() 209 | pad_h, pad_w = int(h*padding_factor), int(w*padding_factor) 210 | idx_pad_left = torch.LongTensor(range(pad_w-1,-1,-1)) 211 | idx_pad_right = torch.LongTensor(range(w-1,w-pad_w-1,-1)) 212 | idx_pad_top = torch.LongTensor(range(pad_h-1,-1,-1)) 213 | idx_pad_bottom = torch.LongTensor(range(h-1,h-pad_h-1,-1)) 214 | if 
self.use_cuda: 215 | idx_pad_left = idx_pad_left.cuda() 216 | idx_pad_right = idx_pad_right.cuda() 217 | idx_pad_top = idx_pad_top.cuda() 218 | idx_pad_bottom = idx_pad_bottom.cuda() 219 | image_batch = torch.cat((image_batch.index_select(3,idx_pad_left),image_batch, 220 | image_batch.index_select(3,idx_pad_right)),3) 221 | image_batch = torch.cat((image_batch.index_select(2,idx_pad_top),image_batch, 222 | image_batch.index_select(2,idx_pad_bottom)),2) 223 | return image_batch 224 | 225 | class SynthTwoStageTnf(SynthPairTnf): 226 | def __init__(self, use_cuda=True, crop_factor=9/16, output_size=(240,240), padding_factor = 0.5): 227 | super().__init__(use_cuda=use_cuda) 228 | # self.aff_reorder_idx=torch.LongTensor([3,2,5,1,0,4]) 229 | self.geometricTnf = ComposedGeometricTnf(padding_crop_factor=padding_factor*crop_factor,use_cuda=self.use_cuda) 230 | 231 | def __call__(self, batch): 232 | image_batch, theta_batch = batch['image'], batch['theta'] 233 | # theta_aff=torch.index_select(theta_batch[:,:6],1,self.aff_reorder_idx) 234 | theta_aff=theta_batch[:,:6].contiguous() 235 | theta_tps=theta_batch[:,6:] 236 | 237 | if self.use_cuda: 238 | image_batch = image_batch.cuda() 239 | theta_aff = theta_aff.cuda() 240 | theta_tps = theta_tps.cuda() 241 | 242 | b, c, h, w = image_batch.size() 243 | 244 | # generate symmetrically padded image for bigger sampling region 245 | image_batch = self.symmetricImagePad(image_batch,self.padding_factor) 246 | 247 | # convert to variables 248 | image_batch = Variable(image_batch,requires_grad=False) 249 | theta_aff = Variable(theta_aff,requires_grad=False) 250 | theta_tps = Variable(theta_tps,requires_grad=False) 251 | 252 | # get cropped image 253 | cropped_image_batch = self.rescalingTnf(image_batch=image_batch, 254 | theta_batch=None, 255 | padding_factor=self.padding_factor, 256 | crop_factor=self.crop_factor) # Identity is used as no theta given 257 | # get transformed image 258 | warped_image_batch = 
self.geometricTnf(image_batch=image_batch, 259 | theta_aff=theta_aff, 260 | theta_aff_tps=theta_tps) 261 | 262 | return {'source_image': cropped_image_batch, 'target_image': warped_image_batch, 'theta_GT_aff': theta_aff, 'theta_GT_tps': theta_tps} 263 | 264 | class SynthTwoStageTwoPairTnf(SynthPairTnf): 265 | def __init__(self, use_cuda=True, crop_factor=9/16, output_size=(240,240), padding_factor = 0.5): 266 | super().__init__(use_cuda=use_cuda) 267 | # self.aff_reorder_idx=torch.LongTensor([3,2,5,1,0,4]) 268 | self.geometricTnf = ComposedGeometricTnf(padding_crop_factor=padding_factor*crop_factor,use_cuda=self.use_cuda) 269 | self.affTnf = GeometricTnf(geometric_model='affine', out_h=self.out_h, out_w=self.out_w, 270 | use_cuda = self.use_cuda) 271 | self.tpsTnf = GeometricTnf(geometric_model='tps', out_h=self.out_h, out_w=self.out_w, 272 | use_cuda = self.use_cuda) 273 | 274 | def __call__(self, batch): 275 | image_batch, theta_batch = batch['image'], batch['theta'] 276 | theta_aff=theta_batch[:,:6].contiguous() 277 | # theta_aff=torch.index_select(theta_batch[:,:6],1,self.aff_reorder_idx) 278 | theta_tps=theta_batch[:,6:] 279 | 280 | if self.use_cuda: 281 | image_batch = image_batch.cuda() 282 | theta_aff = theta_aff.cuda() 283 | theta_tps = theta_tps.cuda() 284 | 285 | b, c, h, w = image_batch.size() 286 | 287 | # generate symmetrically padded image for bigger sampling region 288 | image_batch = self.symmetricImagePad(image_batch,self.padding_factor) 289 | 290 | # convert to variables 291 | image_batch = Variable(image_batch,requires_grad=False) 292 | theta_aff = Variable(theta_aff,requires_grad=False) 293 | theta_tps = Variable(theta_tps,requires_grad=False) 294 | 295 | # get cropped image 296 | cropped_image_batch = self.rescalingTnf(image_batch=image_batch, 297 | theta_batch=None, 298 | padding_factor=self.padding_factor, 299 | crop_factor=self.crop_factor) # Identity is used as no theta given 300 | # get transformed image 301 | target_image_tps = 
self.geometricTnf(image_batch=image_batch, 302 | theta_aff=theta_aff, 303 | theta_aff_tps=theta_tps) 304 | 305 | target_image_aff = self.affTnf(image_batch=image_batch, 306 | theta_batch=theta_aff, 307 | padding_factor=self.padding_factor, 308 | crop_factor=self.crop_factor) 309 | 310 | source_image_tps = self.affTnf(image_batch=cropped_image_batch, 311 | theta_batch=theta_aff, 312 | padding_factor=1.0, 313 | crop_factor=1.0) 314 | 315 | return {'source_image_aff': cropped_image_batch, 316 | 'target_image_aff': target_image_aff, 317 | 'source_image_tps': source_image_tps, 318 | 'target_image_tps': target_image_tps, 319 | 'theta_GT_aff': theta_aff, 320 | 'theta_GT_tps': theta_tps} 321 | 322 | class SynthTwoPairTnf(SynthPairTnf): 323 | def __init__(self, use_cuda=True, crop_factor=9/16, output_size=(240,240), padding_factor = 0.5): 324 | super().__init__(use_cuda=use_cuda) 325 | # self.aff_reorder_idx=torch.LongTensor([3,2,5,1,0,4]) 326 | self.affTnf = GeometricTnf(geometric_model='affine', out_h=self.out_h, out_w=self.out_w, 327 | use_cuda = self.use_cuda) 328 | self.tpsTnf = GeometricTnf(geometric_model='tps', out_h=self.out_h, out_w=self.out_w, 329 | use_cuda = self.use_cuda) 330 | 331 | def __call__(self, batch): 332 | image_batch, theta_batch = batch['image'], batch['theta'] 333 | # theta_aff=torch.index_select(theta_batch[:,:6],1,self.aff_reorder_idx) 334 | theta_aff=theta_batch[:,:6].contiguous() 335 | theta_tps=theta_batch[:,6:] 336 | 337 | if self.use_cuda: 338 | image_batch = image_batch.cuda() 339 | theta_aff = theta_aff.cuda() 340 | theta_tps = theta_tps.cuda() 341 | 342 | b, c, h, w = image_batch.size() 343 | 344 | # generate symmetrically padded image for bigger sampling region 345 | image_batch = self.symmetricImagePad(image_batch,self.padding_factor) 346 | 347 | # convert to variables 348 | image_batch = Variable(image_batch,requires_grad=False) 349 | theta_aff = Variable(theta_aff,requires_grad=False) 350 | theta_tps = 
Variable(theta_tps,requires_grad=False) 351 | 352 | # get cropped image 353 | cropped_image_batch = self.rescalingTnf(image_batch=image_batch, 354 | theta_batch=None, 355 | padding_factor=self.padding_factor, 356 | crop_factor=self.crop_factor) # Identity is used as no theta given 357 | # get transformed image 358 | warped_image_aff = self.affTnf(image_batch=image_batch, 359 | theta_batch=theta_aff, 360 | padding_factor=self.padding_factor, 361 | crop_factor=self.crop_factor) 362 | 363 | warped_image_tps = self.tpsTnf(image_batch=image_batch, 364 | theta_batch=theta_tps, 365 | padding_factor=self.padding_factor, 366 | crop_factor=self.crop_factor) 367 | 368 | return {'source_image': cropped_image_batch, 'target_image_aff': warped_image_aff, 'target_image_tps': warped_image_tps, 'theta_GT_aff': theta_aff, 'theta_GT_tps': theta_tps} 369 | 370 | 371 | class AffineGridGen(Module): 372 | def __init__(self, out_h=240, out_w=240, out_ch = 3, use_cuda=True): 373 | super(AffineGridGen, self).__init__() 374 | self.out_h = out_h 375 | self.out_w = out_w 376 | self.out_ch = out_ch 377 | 378 | def forward(self, theta): 379 | b=theta.size()[0] 380 | if not theta.size()==(b,2,3): 381 | theta = theta.view(-1,2,3) 382 | theta = theta.contiguous() 383 | batch_size = theta.size()[0] 384 | out_size = torch.Size((batch_size,self.out_ch,self.out_h,self.out_w)) 385 | return F.affine_grid(theta, out_size) 386 | 387 | class AffineGridGenV2(Module): 388 | def __init__(self, out_h=240, out_w=240, use_cuda=True): 389 | super(AffineGridGenV2, self).__init__() 390 | self.out_h, self.out_w = out_h, out_w 391 | self.use_cuda = use_cuda 392 | 393 | # create grid in numpy 394 | # self.grid = np.zeros( [self.out_h, self.out_w, 3], dtype=np.float32) 395 | # sampling grid with dim-0 coords (Y) 396 | self.grid_X,self.grid_Y = np.meshgrid(np.linspace(-1,1,out_w),np.linspace(-1,1,out_h)) 397 | # grid_X,grid_Y: size [1,H,W,1,1] 398 | self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3) 
399 | self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3) 400 | self.grid_X = Variable(self.grid_X,requires_grad=False) 401 | self.grid_Y = Variable(self.grid_Y,requires_grad=False) 402 | if use_cuda: 403 | self.grid_X = self.grid_X.cuda() 404 | self.grid_Y = self.grid_Y.cuda() 405 | 406 | def forward(self, theta): 407 | b=theta.size(0) 408 | if not theta.size()==(b,6): 409 | theta = theta.view(b,6) 410 | theta = theta.contiguous() 411 | 412 | t0=theta[:,0].unsqueeze(1).unsqueeze(2).unsqueeze(3) 413 | t1=theta[:,1].unsqueeze(1).unsqueeze(2).unsqueeze(3) 414 | t2=theta[:,2].unsqueeze(1).unsqueeze(2).unsqueeze(3) 415 | t3=theta[:,3].unsqueeze(1).unsqueeze(2).unsqueeze(3) 416 | t4=theta[:,4].unsqueeze(1).unsqueeze(2).unsqueeze(3) 417 | t5=theta[:,5].unsqueeze(1).unsqueeze(2).unsqueeze(3) 418 | X = expand_dim(self.grid_X,0,b) 419 | Y = expand_dim(self.grid_Y,0,b) 420 | Xp = X*t0 + Y*t1 + t2 421 | Yp = X*t3 + Y*t4 + t5 422 | 423 | return torch.cat((Xp,Yp),3) 424 | 425 | class TpsGridGen(Module): 426 | def __init__(self, out_h=240, out_w=240, use_regular_grid=True, grid_size=3, reg_factor=0, use_cuda=True): 427 | super(TpsGridGen, self).__init__() 428 | self.out_h, self.out_w = out_h, out_w 429 | self.reg_factor = reg_factor 430 | self.use_cuda = use_cuda 431 | 432 | # create grid in numpy 433 | # self.grid = np.zeros( [self.out_h, self.out_w, 3], dtype=np.float32) 434 | # sampling grid with dim-0 coords (Y) 435 | self.grid_X,self.grid_Y = np.meshgrid(np.linspace(-1,1,out_w),np.linspace(-1,1,out_h)) 436 | # grid_X,grid_Y: size [1,H,W,1,1] 437 | self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3) 438 | self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3) 439 | self.grid_X = Variable(self.grid_X,requires_grad=False) 440 | self.grid_Y = Variable(self.grid_Y,requires_grad=False) 441 | if use_cuda: 442 | self.grid_X = self.grid_X.cuda() 443 | self.grid_Y = self.grid_Y.cuda() 444 | 445 | # initialize regular grid for 
control points P_i 446 | if use_regular_grid: 447 | axis_coords = np.linspace(-1,1,grid_size) 448 | self.N = grid_size*grid_size 449 | P_Y,P_X = np.meshgrid(axis_coords,axis_coords) 450 | P_X = np.reshape(P_X,(-1,1)) # size (N,1) 451 | P_Y = np.reshape(P_Y,(-1,1)) # size (N,1) 452 | P_X = torch.FloatTensor(P_X) 453 | P_Y = torch.FloatTensor(P_Y) 454 | self.Li = Variable(self.compute_L_inverse(P_X,P_Y).unsqueeze(0),requires_grad=False) 455 | self.P_X = P_X.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0,4) 456 | self.P_Y = P_Y.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0,4) 457 | self.P_X = Variable(self.P_X,requires_grad=False) 458 | self.P_Y = Variable(self.P_Y,requires_grad=False) 459 | if use_cuda: 460 | self.P_X = self.P_X.cuda() 461 | self.P_Y = self.P_Y.cuda() 462 | 463 | 464 | def forward(self, theta): 465 | warped_grid = self.apply_transformation(theta,torch.cat((self.grid_X,self.grid_Y),3)) 466 | 467 | return warped_grid 468 | 469 | def compute_L_inverse(self,X,Y): 470 | N = X.size()[0] # num of points (along dim 0) 471 | # construct matrix K 472 | Xmat = X.expand(N,N) 473 | Ymat = Y.expand(N,N) 474 | P_dist_squared = torch.pow(Xmat-Xmat.transpose(0,1),2)+torch.pow(Ymat-Ymat.transpose(0,1),2) 475 | P_dist_squared[P_dist_squared==0]=1 # make diagonal 1 to avoid NaN in log computation 476 | K = torch.mul(P_dist_squared,torch.log(P_dist_squared)) 477 | if self.reg_factor != 0: 478 | K+=torch.eye(K.size(0),K.size(1))*self.reg_factor 479 | # construct matrix L 480 | O = torch.FloatTensor(N,1).fill_(1) 481 | Z = torch.FloatTensor(3,3).fill_(0) 482 | P = torch.cat((O,X,Y),1) 483 | L = torch.cat((torch.cat((K,P),1),torch.cat((P.transpose(0,1),Z),1)),0) 484 | Li = torch.inverse(L) 485 | if self.use_cuda: 486 | Li = Li.cuda() 487 | return Li 488 | 489 | def apply_transformation(self,theta,points): 490 | if theta.dim()==2: 491 | theta = theta.unsqueeze(2).unsqueeze(3) 492 | # points should be in the [B,H,W,2] format, 493 | # where points[:,:,:,0] are the X 
coords 494 | # and points[:,:,:,1] are the Y coords 495 | 496 | # input are the corresponding control points P_i 497 | batch_size = theta.size()[0] 498 | # split theta into point coordinates 499 | Q_X=theta[:,:self.N,:,:].squeeze(3) 500 | Q_Y=theta[:,self.N:,:,:].squeeze(3) 501 | 502 | # get spatial dimensions of points 503 | points_b = points.size()[0] 504 | points_h = points.size()[1] 505 | points_w = points.size()[2] 506 | 507 | # repeat pre-defined control points along spatial dimensions of points to be transformed 508 | P_X = self.P_X.expand((1,points_h,points_w,1,self.N)) 509 | P_Y = self.P_Y.expand((1,points_h,points_w,1,self.N)) 510 | 511 | # compute weigths for non-linear part 512 | W_X = torch.bmm(self.Li[:,:self.N,:self.N].expand((batch_size,self.N,self.N)),Q_X) 513 | W_Y = torch.bmm(self.Li[:,:self.N,:self.N].expand((batch_size,self.N,self.N)),Q_Y) 514 | # reshape 515 | # W_X,W,Y: size [B,H,W,1,N] 516 | W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 517 | W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 518 | # compute weights for affine part 519 | A_X = torch.bmm(self.Li[:,self.N:,:self.N].expand((batch_size,3,self.N)),Q_X) 520 | A_Y = torch.bmm(self.Li[:,self.N:,:self.N].expand((batch_size,3,self.N)),Q_Y) 521 | # reshape 522 | # A_X,A,Y: size [B,H,W,1,3] 523 | A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 524 | A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1) 525 | 526 | # compute distance P_i - (grid_X,grid_Y) 527 | # grid is expanded in point dim 4, but not in batch dim 0, as points P_X,P_Y are fixed for all batch 528 | points_X_for_summation = points[:,:,:,0].unsqueeze(3).unsqueeze(4).expand(points[:,:,:,0].size()+(1,self.N)) 529 | points_Y_for_summation = points[:,:,:,1].unsqueeze(3).unsqueeze(4).expand(points[:,:,:,1].size()+(1,self.N)) 530 | 531 | if points_b==1: 532 | delta_X = points_X_for_summation-P_X 
533 | delta_Y = points_Y_for_summation-P_Y 534 | else: 535 | # use expanded P_X,P_Y in batch dimension 536 | delta_X = points_X_for_summation-P_X.expand_as(points_X_for_summation) 537 | delta_Y = points_Y_for_summation-P_Y.expand_as(points_Y_for_summation) 538 | 539 | dist_squared = torch.pow(delta_X,2)+torch.pow(delta_Y,2) 540 | # U: size [1,H,W,1,N] 541 | dist_squared[dist_squared==0]=1 # avoid NaN in log computation 542 | U = torch.mul(dist_squared,torch.log(dist_squared)) 543 | 544 | # expand grid in batch dimension if necessary 545 | points_X_batch = points[:,:,:,0].unsqueeze(3) 546 | points_Y_batch = points[:,:,:,1].unsqueeze(3) 547 | if points_b==1: 548 | points_X_batch = points_X_batch.expand((batch_size,)+points_X_batch.size()[1:]) 549 | points_Y_batch = points_Y_batch.expand((batch_size,)+points_Y_batch.size()[1:]) 550 | 551 | points_X_prime = A_X[:,:,:,:,0]+ \ 552 | torch.mul(A_X[:,:,:,:,1],points_X_batch) + \ 553 | torch.mul(A_X[:,:,:,:,2],points_Y_batch) + \ 554 | torch.sum(torch.mul(W_X,U.expand_as(W_X)),4) 555 | 556 | points_Y_prime = A_Y[:,:,:,:,0]+ \ 557 | torch.mul(A_Y[:,:,:,:,1],points_X_batch) + \ 558 | torch.mul(A_Y[:,:,:,:,2],points_Y_batch) + \ 559 | torch.sum(torch.mul(W_Y,U.expand_as(W_Y)),4) 560 | 561 | return torch.cat((points_X_prime,points_Y_prime),3) 562 | -------------------------------------------------------------------------------- /img/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunChunChen/MaCoSNet-pytorch/697ea3a824c5bc9bbca6fac8d33367a8b9def4d1/img/teaser.png -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunChunChen/MaCoSNet-pytorch/697ea3a824c5bc9bbca6fac8d33367a8b9def4d1/model/__init__.py -------------------------------------------------------------------------------- 
/model/loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.models as models 6 | from torch.autograd import Variable 7 | from geotnf.point_tnf import PointTnf, PointsToUnitCoords, PointsToPixelCoords 8 | from scipy.ndimage.morphology import binary_dilation, generate_binary_structure 9 | from util.torch_util import expand_dim 10 | from geotnf.transformation import GeometricTnf, ComposedGeometricTnf 11 | import torch.nn.functional as F 12 | import scipy.signal 13 | import config 14 | 15 | 16 | class CycleLoss(nn.Module): 17 | 18 | def __init__(self, 19 | image_size=240, 20 | transform='affine', 21 | use_cuda=True): 22 | 23 | super(CycleLoss, self).__init__() 24 | 25 | self.pointTnf = PointTnf(use_cuda=use_cuda) 26 | self.transform = transform 27 | 28 | self.coord = [] 29 | for i in range(config.NUM_OF_COORD): 30 | for j in range(config.NUM_OF_COORD): 31 | xx = [] 32 | xx.append(float(i) * image_size / config.NUM_OF_COORD) 33 | xx.append(float(j) * image_size / config.NUM_OF_COORD) 34 | self.coord.append(xx) 35 | self.coord = np.expand_dims(np.array(self.coord).transpose(), axis=0) 36 | self.coord = torch.from_numpy(self.coord).float() 37 | 38 | if use_cuda: 39 | self.coord = self.coord.cuda() 40 | 41 | def forward(self, theta_forward, theta_backward): 42 | batch = theta_forward.size()[0] 43 | b,h,w = self.coord.size() 44 | coord = Variable(self.coord.expand(batch, h, w)) 45 | 46 | img_size = Variable(torch.FloatTensor([[240, 240, 1]])).cuda() 47 | 48 | forward_norm = PointsToUnitCoords(coord, img_size) 49 | forward_norm = self.pointTnf.affPointTnf(theta_forward, forward_norm) 50 | forward_coord = PointsToPixelCoords(forward_norm, img_size) 51 | 52 | backward_norm = PointsToUnitCoords(forward_coord, img_size) 53 | backward_norm = self.pointTnf.affPointTnf(theta_backward, backward_norm) 54 | 
backward_coord = PointsToPixelCoords(backward_norm, img_size) 55 | 56 | loss = (torch.dist(coord, backward_coord, p=2) ** 2) / (config.NUM_OF_COORD * config.NUM_OF_COORD) / batch 57 | 58 | return loss 59 | 60 | 61 | class TransLoss(nn.Module): 62 | 63 | def __init__(self, 64 | transform='affine', 65 | use_cuda=True): 66 | 67 | super(TransLoss, self).__init__() 68 | 69 | self.pointTnf = PointTnf(use_cuda=use_cuda) 70 | self.transform = transform 71 | 72 | self.coord = [] 73 | for i in range(config.NUM_OF_COORD): 74 | for j in range(config.NUM_OF_COORD): 75 | xx = [] 76 | xx.append(float(i)) 77 | xx.append(float(j)) 78 | self.coord.append(xx) 79 | self.coord = np.expand_dims(np.array(self.coord).transpose(), axis=0) 80 | self.coord = torch.from_numpy(self.coord).float() 81 | 82 | if use_cuda: 83 | self.coord = self.coord.cuda() 84 | 85 | def forward(self, theta_A, theta_B, theta_C): 86 | batch = theta_A.size()[0] 87 | b,h,w = self.coord.size() 88 | self.coord = Variable(self.coord.expand(batch, h, w)) 89 | 90 | img_size = Variable(torch.FloatTensor([[240, 240, 1]])).cuda() 91 | 92 | A_norm = PointsToUnitCoords(self.coord, img_size) 93 | A_norm = self.pointTnf.affPointTnf(theta_A, A_norm) 94 | A_coord = PointsToPixelCoords(A_norm, img_size) 95 | 96 | B_norm = PointsToUnitCoords(A_coord, img_size) 97 | B_norm = self.pointTnf.affPointTnf(theta_B, B_norm) 98 | B_coord = PointsToPixelCoords(B_norm, img_size) 99 | 100 | C_norm = PointsToUnitCoords(B_coord, img_size) 101 | C_norm = self.pointTnf.affPointTnf(theta_C, C_norm) 102 | C_coord = PointsToPixelCoords(C_norm, img_size) 103 | 104 | loss = (torch.dist(self.coord, C_coord, p=2) ** 2) / (config.NUM_OF_COORD * config.NUM_OF_COORD) / batch 105 | 106 | return loss 107 | 108 | 109 | class CosegLoss(nn.Module): 110 | 111 | def __init__(self, 112 | threshold=config.THRESHOLD, 113 | use_cuda=True): 114 | 115 | super(CosegLoss, self).__init__() 116 | 117 | self.threshold = threshold 118 | 119 | self.affTnf = 
class TaskLoss(nn.Module):
    """Consistency loss between each predicted mask and the other image's warped mask.

    Both mask logits are squashed with a sigmoid.  Mask A is warped with
    ``aff_AB`` and mask B with ``aff_BA``; each mask is then compared, via
    binary cross-entropy, with the other mask's warp.
    """

    def __init__(self,
                 out_h=240,
                 out_w=240,
                 use_cuda=True):
        """Create the affine warper producing ``out_h`` x ``out_w`` outputs."""
        super(TaskLoss, self).__init__()

        self.affTnf = GeometricTnf(geometric_model='affine',
                                   out_h=out_h,
                                   out_w=out_w,
                                   use_cuda=use_cuda)

    def forward(self, theta, mask_dict):
        """Return the task-consistency loss.

        Args:
            theta: mapping with keys ``'aff_AB'`` and ``'aff_BA'``.
            mask_dict: mapping with keys ``'mask_A'`` and ``'mask_B'`` holding
                4-D mask logits (unpacked below as batch, channel, h, w).

        Returns:
            Scalar tensor: mean of the A-side and B-side losses.
        """
        aff_AB = theta['aff_AB']
        aff_BA = theta['aff_BA']

        # torch.sigmoid replaces the deprecated F.sigmoid; values are identical.
        mask_A = torch.sigmoid(mask_dict['mask_A'])
        mask_B = torch.sigmoid(mask_dict['mask_B'])

        batch, _, h, w = mask_A.size()

        mask_Awrp_B = self.affTnf(mask_A, aff_AB)
        mask_Bwrp_A = self.affTnf(mask_B, aff_BA)

        # binary_cross_entropy is used with its default reduction and the sum
        # of the two terms is additionally divided by h*w and by the batch
        # size; this scaling is kept so existing loss weights stay valid.
        loss_A = (F.binary_cross_entropy(mask_A, mask_Bwrp_A)
                  + F.binary_cross_entropy(1.0 - mask_A, 1.0 - mask_Bwrp_A)) / (h * w) / batch
        loss_B = (F.binary_cross_entropy(mask_B, mask_Awrp_B)
                  + F.binary_cross_entropy(1.0 - mask_B, 1.0 - mask_Awrp_B)) / (h * w) / batch

        loss = (loss_A + loss_B) / 2.0

        return loss
class TpsMatchScore(AffMatchScore):
    """Match score for an affine transformation composed with a TPS warp.

    Reuses the identity mask built by ``AffMatchScore`` and warps it with a
    ``ComposedGeometricTnf`` instead of the plain affine warper.
    """

    def __init__(self,
                 tps_grid_size=3,
                 tps_reg_factor=0,
                 h_matches=15,
                 w_matches=15,
                 use_conv_filter=False,
                 dilation_filter=None,
                 use_cuda=True,
                 seg_mask=False,
                 normalize_inlier_count=False,
                 offset_factor=227/210):
        """Set up the parent mask and the composed (affine + TPS) warper."""
        super(TpsMatchScore, self).__init__(h_matches=h_matches,
                                            w_matches=w_matches,
                                            use_conv_filter=use_conv_filter,
                                            dilation_filter=dilation_filter,
                                            use_cuda=use_cuda,
                                            seg_mask=seg_mask,
                                            normalize_inlier_count=normalize_inlier_count,
                                            offset_factor=offset_factor)

        self.compGeometricTnf = ComposedGeometricTnf(tps_grid_size=tps_grid_size,
                                                     tps_reg_factor=tps_reg_factor,
                                                     out_h=h_matches,
                                                     out_w=w_matches,
                                                     offset_factor=offset_factor,
                                                     use_cuda=use_cuda)

    @staticmethod
    def _normalize(mask, epsilon=1e-5):
        """Divide ``mask`` by the per-sample sum of ``mask + epsilon``."""
        total = torch.sum(torch.sum(torch.sum(mask + epsilon, 3), 2), 1)
        total = total.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand_as(mask)
        return torch.div(mask, total)

    def forward(self, theta_aff, theta_aff_tps, matches, seg_mask=None, return_outliers=False):
        """Score ``matches`` against the warped identity mask.

        Args:
            theta_aff: affine parameters; first dimension is the batch size.
            theta_aff_tps: TPS parameters applied after the affine warp.
            matches: match tensor multiplied element-wise with the mask.
            seg_mask: segmentation weights, used only when the instance was
                built with ``seg_mask=True``.
            return_outliers: also score the complement of the identity mask.

        Returns:
            ``score``, or ``(score, score_outliers)`` if ``return_outliers``.
        """
        batch_size = theta_aff.size()[0]
        mask = self.compGeometricTnf(image_batch=expand_dim(self.mask_id, 0, batch_size),
                                     theta_aff=theta_aff,
                                     theta_aff_tps=theta_aff_tps)
        if return_outliers:
            mask_outliers = self.compGeometricTnf(image_batch=expand_dim(1.0 - self.mask_id, 0, batch_size),
                                                  theta_aff=theta_aff,
                                                  theta_aff_tps=theta_aff_tps)
        if self.normalize:
            mask = self._normalize(mask)
            if return_outliers:
                # Normalise the outlier mask itself.  The previous code divided
                # the inlier mask here, so score_outliers was computed from the
                # wrong tensor (AffMatchScore already did this correctly).
                mask_outliers = self._normalize(mask_outliers)

        if self.seg_mask:
            score = torch.sum(torch.sum(torch.mul(torch.sum(torch.mul(mask, matches), 1), seg_mask), 2), 1)
        else:
            score = torch.sum(torch.sum(torch.sum(torch.mul(mask, matches), 3), 2), 1)

        if return_outliers:
            score_outliers = torch.sum(torch.sum(torch.sum(torch.mul(mask_outliers, matches), 3), 2), 1)
            return (score, score_outliers)

        return score
def __init__(self, 345 | geometric_model='affine', 346 | use_cuda=True, 347 | grid_size=20): 348 | 349 | super(GridLoss, self).__init__() 350 | 351 | self.geometric_model = geometric_model 352 | 353 | # define virtual grid of points to be transformed 354 | axis_coords = np.linspace(-1,1,grid_size) 355 | 356 | self.N = grid_size * grid_size 357 | 358 | X,Y = np.meshgrid(axis_coords, axis_coords) 359 | X = np.reshape(X,(1,1,self.N)) 360 | Y = np.reshape(Y,(1,1,self.N)) 361 | P = np.concatenate((X,Y),1) 362 | self.P = Variable(torch.FloatTensor(P), requires_grad=False) 363 | 364 | self.pointTnf = PointTnf(use_cuda=use_cuda) 365 | 366 | if use_cuda: 367 | self.P = self.P.cuda(); 368 | 369 | def forward(self, theta, theta_GT): 370 | 371 | # expand grid according to batch size 372 | batch_size = theta.size()[0] 373 | P = self.P.expand(batch_size,2,self.N) 374 | 375 | # compute transformed grid points using estimated and GT tnfs 376 | if self.geometric_model == 'affine': 377 | P_prime = self.pointTnf.affPointTnf(theta, P) 378 | P_prime_GT = self.pointTnf.affPointTnf(theta_GT, P) 379 | 380 | elif self.geometric_model == 'tps': 381 | P_prime = self.pointTnf.tpsPointTnf(theta.unsqueeze(2).unsqueeze(3),P) 382 | P_prime_GT = self.pointTnf.tpsPointTnf(theta_GT, P) 383 | 384 | # compute MSE loss on transformed grid points 385 | loss = torch.sum(torch.pow(P_prime - P_prime_GT,2),1) 386 | loss = torch.mean(loss) 387 | 388 | return loss 389 | -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torchvision.models as models 6 | import numpy as np 7 | from geotnf.transformation import GeometricTnf 8 | 9 | 10 | def featureL2Norm(feature): 11 | epsilon = 1e-6 12 | norm = 
class FeatureExtraction(torch.nn.Module):
    """Truncated, optionally frozen, ImageNet-pretrained ResNet-101 backbone.

    Supported ``feature_extraction_cnn`` values:

    * ``'resnet101'`` (``'resnet-101'`` is accepted as an alias): the network
      is cut after ``last_layer`` (default ``'layer3'``).
    * ``'resnet101_v2'``: the last three children of the torchvision model are
      dropped (presumably layer4, avgpool and fc -- verify against torchvision).
    """

    # Order in which the ResNet stages are chained for the 'resnet101' variant.
    _RESNET_LAYERS = ('conv1', 'bn1', 'relu', 'maxpool',
                      'layer1', 'layer2', 'layer3', 'layer4')

    def __init__(self,
                 train_fe=False,
                 feature_extraction_cnn='resnet-101',
                 normalization=True,
                 last_layer='',
                 use_cuda=True):
        """Build the backbone.

        Args:
            train_fe: if False, every backbone parameter is frozen.
            feature_extraction_cnn: architecture name, see the class docstring.
            normalization: L2-normalise the output features in ``forward``.
            last_layer: last ResNet stage to keep ('' means ``'layer3'``);
                only used by the ``'resnet101'`` variant.
            use_cuda: move the backbone to the current CUDA device.

        Raises:
            ValueError: if ``feature_extraction_cnn`` is not supported, or
                ``last_layer`` is not a known stage name.
        """
        super(FeatureExtraction, self).__init__()

        self.normalization = normalization

        # The constructor's own default is spelled 'resnet-101', which used to
        # match no branch and left self.model undefined.
        if feature_extraction_cnn == 'resnet-101':
            feature_extraction_cnn = 'resnet101'

        if feature_extraction_cnn == 'resnet101':
            resnet = models.resnet101(pretrained=True)
            if last_layer == '':
                last_layer = 'layer3'
            last_layer_idx = self._RESNET_LAYERS.index(last_layer)
            kept = self._RESNET_LAYERS[:last_layer_idx + 1]
            self.model = nn.Sequential(*[getattr(resnet, name) for name in kept])
        elif feature_extraction_cnn == 'resnet101_v2':
            resnet = models.resnet101(pretrained=True)
            self.model = nn.Sequential(*list(resnet.children())[:-3])
        else:
            # Fail early and clearly instead of with an AttributeError on
            # self.model further down.
            raise ValueError(
                "Unsupported feature_extraction_cnn {!r}; expected 'resnet101' "
                "or 'resnet101_v2'".format(feature_extraction_cnn))

        if not train_fe:
            # freeze parameters
            for param in self.model.parameters():
                param.requires_grad = False

        if use_cuda:
            self.model = self.model.cuda()

    def forward(self, image_batch):
        """Return backbone features, L2-normalised when ``normalization`` is set."""
        features = self.model(image_batch)
        if self.normalization:
            features = featureL2Norm(features)
        return features
class FeatureRegression(nn.Module):
    """Regress transformation parameters from a correlation volume.

    A stack of un-padded ``Conv2d`` (+ optional ``BatchNorm2d``) + ``ReLU``
    layers is followed by one linear layer producing ``output_dim`` values.
    The first convolution expects ``feature_size ** 2`` input channels.
    """

    def __init__(self,
                 output_dim=6,
                 use_cuda=True,
                 batch_normalization=True,
                 kernel_sizes=(7, 5),
                 channels=(128, 64),
                 feature_size=15):
        """Build the regressor.

        Args:
            output_dim: number of regressed parameters.
            use_cuda: move the layers to the current CUDA device.
            batch_normalization: insert ``BatchNorm2d`` after each convolution.
            kernel_sizes: kernel size of each convolution.
            channels: output channels of each convolution.
            feature_size: spatial side of the input map; its square is the
                number of input channels.

        Raises:
            ValueError: if ``kernel_sizes`` is empty or the convolutions would
                shrink the feature map to nothing.
        """
        super(FeatureRegression, self).__init__()

        if len(kernel_sizes) == 0:
            raise ValueError('kernel_sizes must contain at least one entry')

        nn_modules = []
        ch_in = feature_size * feature_size
        spatial = feature_size
        for i, k_size in enumerate(kernel_sizes):
            ch_out = channels[i]
            nn_modules.append(nn.Conv2d(ch_in, ch_out, kernel_size=k_size, padding=0))
            if batch_normalization:
                nn_modules.append(nn.BatchNorm2d(ch_out))
            nn_modules.append(nn.ReLU(inplace=True))
            ch_in = ch_out
            # padding=0, stride=1: every convolution trims k_size - 1 pixels.
            spatial -= k_size - 1

        if spatial < 1:
            raise ValueError(
                'kernel_sizes {} are too large for feature_size {}'.format(
                    list(kernel_sizes), feature_size))

        self.conv = nn.Sequential(*nn_modules)
        # The input size used to be ch_out * k_size * k_size, which is only
        # right when the final map side equals the last kernel size (true for
        # the default 15 / [7, 5] setup, where both are 5).  Using the real
        # output side gives the same layer for the defaults and a working one
        # for every other configuration.
        self.linear = nn.Linear(ch_out * spatial * spatial, output_dim)

        if use_cuda:
            self.conv.cuda()
            self.linear.cuda()

    def forward(self, x):
        """Return a ``(batch, output_dim)`` tensor of regressed parameters."""
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
209 | 210 | if self.return_correlation: 211 | return (theta,correlation) 212 | else: 213 | return theta 214 | 215 | 216 | class WeakMatchNet(CNNGeometric): 217 | 218 | def __init__(self, 219 | fr_feature_size=15, 220 | fr_kernel_sizes=[7,5], 221 | fr_channels=[128,64], 222 | feature_extraction_cnn='vgg', 223 | feature_extraction_last_layer='', 224 | return_correlation=False, 225 | normalize_features=True, 226 | normalize_matches=True, 227 | batch_normalization=True, 228 | train_fe=False, 229 | use_cuda=True, 230 | s1_output_dim=6, 231 | s2_output_dim=18): 232 | 233 | super(WeakMatchNet, self).__init__(output_dim=s1_output_dim, 234 | fr_feature_size=fr_feature_size, 235 | fr_kernel_sizes=fr_kernel_sizes, 236 | fr_channels=fr_channels, 237 | feature_extraction_cnn=feature_extraction_cnn, 238 | feature_extraction_last_layer=feature_extraction_last_layer, 239 | return_correlation=return_correlation, 240 | normalize_features=normalize_features, 241 | normalize_matches=normalize_matches, 242 | batch_normalization=batch_normalization, 243 | train_fe=train_fe, 244 | use_cuda=use_cuda) 245 | 246 | if s1_output_dim==6: 247 | self.geoTnf = GeometricTnf(geometric_model='affine', 248 | use_cuda=use_cuda) 249 | else: 250 | tps_grid_size = np.sqrt(s2_output_dim/2) 251 | self.geoTnf = GeometricTnf(geometric_model='tps', 252 | tps_grid_size=tps_grid_size, 253 | use_cuda=use_cuda) 254 | 255 | self.FeatureRegression2 = FeatureRegression(output_dim=s2_output_dim, 256 | use_cuda=use_cuda, 257 | feature_size=fr_feature_size, 258 | kernel_sizes=fr_kernel_sizes, 259 | channels=fr_channels, 260 | batch_normalization=batch_normalization) 261 | 262 | def forward(self, batch, training=False): 263 | 264 | if not training: 265 | """ Affine """ 266 | img_A, img_B = batch['image_A'], batch['image_B'] 267 | f_A, f_B = self.get_feature2(img_A, img_B) 268 | corr_AB = self.get_corr1(f_A, f_B) 269 | aff_AB = self.get_affine1(corr_AB) 270 | 271 | """ Tps """ 272 | img_Awrp = self.warp_image(img_A, 
aff_AB) 273 | f_Awrp = self.get_feature1(img_Awrp) 274 | corr_Awrp_B = self.get_corr1(f_Awrp, f_B) 275 | tps_Awrp_B = self.get_tps1(corr_Awrp_B) 276 | 277 | return aff_AB, tps_Awrp_B 278 | 279 | else: 280 | """ Affine """ 281 | img_A, img_B, img_C = batch['image_A'], batch['image_B'], batch['image_C'] 282 | 283 | f_A, f_B, f_C = self.get_feature3(img_A, img_B, img_C) 284 | 285 | corr_AB, corr_BA = self.get_corr2(f_A, f_B) 286 | corr_BC, corr_CB = self.get_corr2(f_B, f_C) 287 | corr_CA, corr_AC = self.get_corr2(f_C, f_A) 288 | 289 | aff_AB, aff_BA = self.get_affine2(corr_AB, corr_BA) 290 | aff_BC, aff_CB = self.get_affine2(corr_BC, corr_CB) 291 | aff_CA, aff_AC = self.get_affine2(corr_CA, corr_AC) 292 | 293 | aff_dict = { 294 | 'aff_AB': aff_AB, 295 | 'aff_BA': aff_BA, 296 | 'aff_BC': aff_BC, 297 | 'aff_CB': aff_CB, 298 | 'aff_CA': aff_CA, 299 | 'aff_AC': aff_AC, 300 | } 301 | 302 | corr_dict = { 303 | 'corr_AB': corr_AB, 304 | 'corr_BA': corr_BA, 305 | 'corr_BC': corr_BC, 306 | 'corr_CB': corr_CB, 307 | 'corr_CA': corr_CA, 308 | 'corr_AC': corr_AC, 309 | } 310 | 311 | """ Tps """ 312 | img_Awrp_B = self.warp_image(img_A, aff_AB) # should better align img_B 313 | img_Bwrp_A = self.warp_image(img_B, aff_BA) # should better align img_A 314 | 315 | img_Bwrp_C = self.warp_image(img_B, aff_BC) # should better align img_C 316 | img_Cwrp_B = self.warp_image(img_C, aff_CB) # should better align img_B 317 | 318 | img_Cwrp_A = self.warp_image(img_C, aff_CA) # should better align img_A 319 | img_Awrp_C = self.warp_image(img_A, aff_AC) # should better align img_C 320 | 321 | f_Awrp_B, f_Bwrp_A = self.get_feature2(img_Awrp_B, img_Bwrp_A) 322 | f_Bwrp_C, f_Cwrp_B = self.get_feature2(img_Bwrp_C, img_Cwrp_B) 323 | f_Cwrp_A, f_Awrp_C = self.get_feature2(img_Cwrp_A, img_Awrp_C) 324 | 325 | corr_Awrp_B = self.get_corr1(f_Awrp_B, f_B) 326 | corr_Bwrp_A = self.get_corr1(f_Bwrp_A, f_A) 327 | 328 | corr_Bwrp_C = self.get_corr1(f_Bwrp_C, f_C) 329 | corr_Cwrp_B = self.get_corr1(f_Cwrp_B, 
f_B) 330 | 331 | corr_Awrp_C = self.get_corr1(f_Awrp_C, f_C) 332 | corr_Cwrp_A = self.get_corr1(f_Cwrp_A, f_A) 333 | 334 | tps_Awrp_B, tps_Bwrp_A = self.get_tps2(corr_Awrp_B, corr_Bwrp_A) 335 | tps_Bwrp_C, tps_Cwrp_B = self.get_tps2(corr_Bwrp_C, corr_Cwrp_B) 336 | tps_Cwrp_A, tps_Awrp_C = self.get_tps2(corr_Cwrp_A, corr_Awrp_C) 337 | 338 | tps_dict = { 339 | 'tps_Awrp_B': tps_Awrp_B, 340 | 'tps_Bwrp_A': tps_Bwrp_A, 341 | 'tps_Bwrp_C': tps_Bwrp_C, 342 | 'tps_Cwrp_B': tps_Cwrp_B, 343 | 'tps_Cwrp_A': tps_Cwrp_A, 344 | 'tps_Awrp_C': tps_Awrp_C, 345 | } 346 | 347 | return aff_dict, tps_dict, corr_dict 348 | 349 | def get_feature1(self, img_A): 350 | f_A = self.FeatureExtraction(img_A) 351 | return f_A 352 | 353 | def get_feature2(self, img_A, img_B): 354 | f_A = self.get_feature1(img_A) 355 | f_B = self.get_feature1(img_B) 356 | return f_A, f_B 357 | 358 | def get_feature3(self, img_A, img_B, img_C): 359 | f_A = self.get_feature1(img_A) 360 | f_B = self.get_feature1(img_B) 361 | f_C = self.get_feature1(img_C) 362 | return f_A, f_B, f_C 363 | 364 | def get_corr1(self, f_A, f_B): 365 | corr_AB = self.FeatureCorrelation(f_A, f_B) # (h_A x w_A) x h_B x w_B -> for T_AB 366 | return corr_AB 367 | 368 | def get_corr2(self, f_A, f_B): 369 | corr_AB = self.get_corr1(f_A, f_B) 370 | corr_BA = self.get_corr1(f_B, f_A) 371 | return corr_AB, corr_BA 372 | 373 | def get_affine1(self, corr_AB): 374 | aff_AB = self.FeatureRegression(corr_AB) # warp img_A to align img_B -> T_AB 375 | return aff_AB 376 | 377 | def get_affine2(self, corr_AB, corr_BA): 378 | aff_AB = self.get_affine1(corr_AB) 379 | aff_BA = self.get_affine1(corr_BA) 380 | return aff_AB, aff_BA 381 | 382 | def get_tps1(self, corr_AB): 383 | tps_AB = self.FeatureRegression2(corr_AB) 384 | return tps_AB 385 | 386 | def get_tps2(self, corr_AB, corr_BA): 387 | tps_AB = self.FeatureRegression2(corr_AB) 388 | tps_BA = self.FeatureRegression2(corr_BA) 389 | return tps_AB, tps_BA 390 | 391 | def warp_image(self, img, theta): 392 | 
def conv(in_channel,
         out_channel,
         kernel_size=3,
         stride=1,
         dilation=1,
         bias=False,
         transposed=False):
    """Create a 2-D convolution or a bilinear-initialised transposed convolution.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: side of the square kernel.
        stride: convolution stride.
        dilation: kernel dilation.
        bias: add a bias term; when present it is initialised to zero.
        transposed: build ``nn.ConvTranspose2d`` (padding=1, output_padding=1)
            whose kernels are all set to the same separable interpolation
            kernel divided by ``in_channel``; otherwise build ``nn.Conv2d``
            padded with ``(kernel_size + 2 * (dilation - 1)) // 2``.

    Returns:
        The constructed layer.
    """
    if transposed:
        layer = nn.ConvTranspose2d(in_channel,
                                   out_channel,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=1,
                                   output_padding=1,
                                   dilation=dilation,
                                   bias=bias)

        # Same selection as the former `odd and stride - 1 or stride - 0.5`
        # expression, including its fall-through to `stride - 0.5` when
        # `stride - 1` is 0.
        # NOTE(review): that fall-through looks accidental -- confirm the
        # intended centre for stride == 1.
        if kernel_size % 2 == 1 and stride - 1:
            center = stride - 1
        else:
            center = stride - 0.5

        w = torch.empty(kernel_size, kernel_size)
        for y in range(kernel_size):
            for x in range(kernel_size):
                w[y, x] = (1 - abs((x - center) / stride)) * (1 - abs((y - center) / stride))

        # ConvTranspose2d stores its weight as (in_channel, out_channel, kH, kW).
        # Tiling as (out_channel, in_channel, ...) made copy_ fail whenever
        # in_channel != out_channel.
        layer.weight.data.copy_(w.div(in_channel).repeat(in_channel, out_channel, 1, 1))
    else:
        padding = (kernel_size + 2 * (dilation - 1)) // 2
        layer = nn.Conv2d(in_channel,
                          out_channel,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding,
                          dilation=dilation,
                          bias=bias)
    if bias:
        # constant_ is the in-place replacement for the deprecated nn.init.constant.
        nn.init.constant_(layer.bias, 0)
    return layer
class WeakCosegNet(CNNGeometric):
    """Co-segmentation network: shared backbone, cross-correlation and a mask decoder.

    For an image pair, the deepest collected feature map of each image is
    concatenated (along channels) with the correlation volume computed
    against the other image, then decoded into a mask.
    """

    def __init__(self,
                 train_fe=False,
                 normalize_features=True,
                 normalize_matches=True,
                 feature_extraction_cnn='vgg',
                 feature_extraction_last_layer='',
                 use_cuda=True):
        """Build the backbone, correlation layer and decoder.

        Args:
            train_fe: fine-tune the feature extractor.
            normalize_features: L2-normalise extracted features.
            normalize_matches: normalise the correlation volume.
            feature_extraction_cnn: backbone name forwarded to the parent.
            feature_extraction_last_layer: last backbone stage to keep.
            use_cuda: move the sub-modules to the current CUDA device.
        """
        # Forward the backbone settings to the parent.  Before, only use_cuda
        # was passed, so the parent first tried to build its own default
        # backbone (which this file's FeatureExtraction cannot construct) and
        # the extractor and correlation layer were then created a second time
        # here.  The parent builds both with exactly these arguments.
        super(WeakCosegNet, self).__init__(feature_extraction_cnn=feature_extraction_cnn,
                                           feature_extraction_last_layer=feature_extraction_last_layer,
                                           normalize_features=normalize_features,
                                           normalize_matches=normalize_matches,
                                           train_fe=train_fe,
                                           use_cuda=use_cuda)

        self.Decoder = Decoder()

        if use_cuda:
            self.Decoder = self.Decoder.cuda()

    def forward(self, batch):
        """Predict mask logits for both images of a pair.

        Args:
            batch: mapping with keys ``'image_A'`` and ``'image_B'``.

        Returns:
            Dict with keys ``'mask_A'`` and ``'mask_B'``.
        """
        img_A, img_B = batch['image_A'], batch['image_B']

        # Fixed: the method is named get_feature, and the second result was
        # bound to a misspelled name (festures_B), raising NameError below.
        features_A = self.get_feature(img_A)
        features_B = self.get_feature(img_B)

        f_A = features_A[-1]
        f_B = features_B[-1]

        corr_AB, corr_BA = self.get_corr(f_A, f_B)

        C_A = torch.cat((f_A, corr_BA), dim=1)
        C_B = torch.cat((f_B, corr_AB), dim=1)

        mask_A = self.get_mask(C_A, features_A)
        mask_B = self.get_mask(C_B, features_B)

        mask_dict = {
            'mask_A': mask_A,
            'mask_B': mask_B,
        }

        return mask_dict

    def get_feature(self, data):
        """Run the backbone stage by stage and collect intermediate outputs.

        Outputs of the child modules named '2', '4', '5' and '6' are returned
        in that order.  The backbone's own forward() is bypassed, so no
        feature normalisation is applied here.
        """
        x = data
        collected = []
        for idx, module in self.FeatureExtraction.model.named_children():
            x = module(x)
            if idx in ('2', '4', '5', '6'):
                collected.append(x)
        return collected

    def get_corr(self, f_A, f_B):
        """Return the correlation volumes in both directions (A->B, B->A)."""
        corr_AB = self.FeatureCorrelation(f_A, f_B)  # (h_A x w_A) x h_B x w_B
        corr_BA = self.FeatureCorrelation(f_B, f_A)  # (h_B x w_B) x h_A x w_A
        return corr_AB, corr_BA

    def get_mask(self, C, features):
        """Decode the concatenated tensor ``C`` into mask logits."""
        mask = self.Decoder(C, features)
        return mask
help='gpu id') 33 | base_params.add_argument('--num-workers', type=int, default=8, help='number of workers') 34 | 35 | 36 | def add_dataset_parameters(self): 37 | dataset_params = self.parser.add_argument_group('dataset') 38 | # Image pair dataset parameters for train/val 39 | dataset_params.add_argument('--categories', nargs='+', type=int, default=0, help='indices of categories for training/eval') 40 | # Eval dataset parameters for early stopping 41 | dataset_params.add_argument('--eval-dataset', type=str, default='pf-pascal', help='Validation dataset used for early stopping') 42 | dataset_params.add_argument('--pck-alpha', type=float, default=0.1, help='pck margin factor alpha') 43 | dataset_params.add_argument('--eval-metric', type=str, default='pck', help='pck/distance') 44 | # Random synth dataset parameters 45 | dataset_params.add_argument('--random-crop', type=str_to_bool, nargs='?', const=True, default=True, help='use random crop augmentation') 46 | 47 | 48 | def add_train_parameters(self): 49 | train_params = self.parser.add_argument_group('train') 50 | # Optimization parameters 51 | train_params.add_argument('--lr', type=float, default=0.001, help='learning rate') 52 | train_params.add_argument('--momentum', type=float, default=0.9, help='momentum constant') 53 | train_params.add_argument('--num-epochs', type=int, default=10, help='number of training epochs') 54 | train_params.add_argument('--batch-size', type=int, default=16, help='training batch size') 55 | train_params.add_argument('--weight-decay', type=float, default=0, help='weight decay constant') 56 | train_params.add_argument('--seed', type=int, default=1, help='Pseudo-RNG seed') 57 | train_params.add_argument('--geometric-model', type=str, default='affine', help='geometric model to be regressed at output: affine or tps') 58 | # Trained model parameters 59 | train_params.add_argument('--result-model-dir', type=str, default='trained_models', help='path to trained models folder') 60 | # Dataset 
name (used for loading defaults) 61 | train_params.add_argument('--training-dataset', type=str, default='pf-pascal', help='dataset to use for training') 62 | # Parts of model to train 63 | train_params.add_argument('--train-fe', type=str_to_bool, nargs='?', const=True, default=True, help='Train feature extraction') 64 | train_params.add_argument('--train-fr', type=str_to_bool, nargs='?', const=True, default=True, help='Train feature regressor') 65 | train_params.add_argument('--train-bn', type=str_to_bool, nargs='?', const=True, default=True, help='train batch-norm layers') 66 | train_params.add_argument('--fe-finetune-params', nargs='+', type=str, default=[''], help='String indicating the F.Ext params to finetune') 67 | train_params.add_argument('--update-bn-buffers', type=str_to_bool, nargs='?', const=True, default=False, help='Update batch norm running mean and std') 68 | train_params.add_argument('--self-correlation', type=str_to_bool, nargs='?', const=True, default=True, help='Compute self correlation') 69 | train_params.add_argument('--seg-mask', type=str_to_bool, nargs='?', const=True, default=True, help='Use segmentation mask') 70 | 71 | 72 | def add_loss_parameters(self): 73 | loss_params = self.parser.add_argument_group('loss') 74 | # Parameters of weak loss 75 | loss_params.add_argument('--tps-grid-size', type=int, default=3, help='tps grid size') 76 | loss_params.add_argument('--tps-reg-factor', type=float, default=0.2, help='tps regularization factor') 77 | loss_params.add_argument('--normalize-inlier-count', type=str_to_bool, nargs='?', const=True, default=True) 78 | loss_params.add_argument('--dilation-filter', type=int, default=0, help='type of dilation filter: 0:no filter;1:4-neighs;2:8-neighs') 79 | loss_params.add_argument('--use-conv-filter', type=str_to_bool, nargs='?', const=True, default=False, help='use conv filter instead of dilation') 80 | 81 | def add_losses_parameters(self): 82 | losses_params = 
self.parser.add_argument_group('loss-param') 83 | # Loss parameters 84 | losses_params.add_argument('--match-loss', type=str_to_bool, nargs='?', const=True, default=False, help='use foreground guided matching loss?') 85 | losses_params.add_argument('--w-match', type=float, default=0.0, help='weight for foreground guided matching loss') 86 | losses_params.add_argument('--cycle-loss', type=str_to_bool, nargs='?', const=True, default=False, help='use cycle consistency loss?') 87 | losses_params.add_argument('--w-cycle', type=float, default=0.0, help='weight for cycle consistency loss') 88 | losses_params.add_argument('--trans-loss', type=str_to_bool, nargs='?', const=True, default=False, help='use transitive consistency loss?') 89 | losses_params.add_argument('--w-trans', type=float, default=0.0, help='weight for transitive consistency loss') 90 | losses_params.add_argument('--coseg-loss', type=str_to_bool, nargs='?', const=True, default=False, help='use perceptual contrastive loss?') 91 | losses_params.add_argument('--w-coseg', type=float, default=0.0, help='weight for perceptual contrastive loss') 92 | losses_params.add_argument('--task-loss', type=str_to_bool, nargs='?', const=True, default=False, help='use task consistency loss?') 93 | losses_params.add_argument('--w-task', type=float, default=0.0, help='weight for task consistency loss') 94 | losses_params.add_argument('--grid-loss', type=str_to_bool, nargs='?', const=True, default=False, help='use transformed grid loss?') 95 | losses_params.add_argument('--w-grid', type=float, default=0.0, help='weight for transformed grid loss') 96 | 97 | 98 | def add_eval_parameters(self): 99 | eval_params = self.parser.add_argument_group('eval') 100 | # Evaluation parameters 101 | eval_params.add_argument('--eval-dataset', type=str, default='pf-pascal', help='pf/caltech/tss') 102 | eval_params.add_argument('--flow-output-dir', type=str, default='results/', help='flow output dir') 103 | eval_params.add_argument('--pck-alpha', 
type=float, default=0.1, help='pck margin factor alpha') 104 | eval_params.add_argument('--eval-metric', type=str, default='pck', help='pck/distance') 105 | eval_params.add_argument('--tps-reg-factor', type=float, default=0.0, help='regularisation factor for tps tnf') 106 | eval_params.add_argument('--batch-size', type=int, default=16, help='training batch size') 107 | eval_params.add_argument('--log-dir', type=str, default='', help='log directory') 108 | eval_params.add_argument('--self-correlation', type=str_to_bool, nargs='?', const=True, default=False, help='Compute self correlation') 109 | 110 | 111 | def add_cnn_model_parameters(self): 112 | model_params = self.parser.add_argument_group('model') 113 | # Model parameters 114 | model_params.add_argument('--feature-extraction-cnn', type=str, default='resnet101', help='feature extraction CNN model architecture: vgg/resnet101') 115 | model_params.add_argument('--feature-extraction-last-layer', type=str, default='', help='feature extraction CNN last layer') 116 | model_params.add_argument('--fr-feature-size', type=int, default=15, help='image input size') 117 | model_params.add_argument('--fr-kernel-sizes', nargs='+', type=int, default=[7,5], help='kernels sizes in feat.reg. conv layers') 118 | model_params.add_argument('--fr-channels', nargs='+', type=int, default=[128,64], help='channels in feat. reg. 
def parse(self, arg_str=None):
    """Parse the arguments and also return them bucketed by argument group.

    Args:
        arg_str: optional string of command-line style arguments; it is
            split on whitespace before parsing. When ``None``, argparse
            reads the arguments of the running process instead.

    Returns:
        A tuple ``(args, arg_groups)``: ``args`` is the parsed namespace and
        ``arg_groups`` maps each group title to a dict of
        ``{destination name: parsed value}`` for the options in that group.
    """
    # parse_args(None) is what argparse does when called without arguments.
    argv = None if arg_str is None else arg_str.split()
    args = self.parser.parse_args(argv)

    arg_groups = {
        group.title: {
            action.dest: getattr(args, action.dest, None)
            for action in group._group_actions
        }
        for group in self.parser._action_groups
    }

    return (args, arg_groups)
def print_loss(epoch, idx, num, loss_dict):
    """Print a one-line progress report with the co-segmentation loss.

    Args:
        epoch: current epoch number.
        idx: index of the current batch within the epoch.
        num: total number of batches in the epoch (must be non-zero).
        loss_dict: mapping holding the current loss value under 'coseg'.
    """
    # The percentage used to be computed from `batch_idx`, a name that does
    # not exist in this function (it is a local of process_epoch), so every
    # call raised NameError. The batch index is passed in as `idx`.
    print_string = 'Epoch: {} [{}/{} ({:.0f}%)]'.format(epoch, idx, num, 100. * idx / num)
    print_string += ' coseg: {:.6f}'.format(loss_dict['coseg'])
    print(print_string)
    return
train_loader, batch_tnf) 114 | 115 | ''' 116 | model.eval() 117 | 118 | eval_stats = compute_metric(args.eval_metric, model, eval_data, eval_loader, batch_tnf, args) 119 | eval_pck = np.mean(eval_stats['aff_tps'][args.eval_metric][eval_idx]) 120 | 121 | is_best = eval_pck > best_eval_pck 122 | 123 | if eval_pck > best_eval_pck: 124 | best_eval_pck = eval_pck 125 | best_epoch = epoch 126 | 127 | print('eval: {:.3f}'.format(eval_pck), 128 | 'best eval: {:.3f}'.format(best_eval_pck), 129 | 'best epoch: {}'.format(best_epoch)) 130 | 131 | """ Early stopping """ 132 | if eval_pck < (best_eval_pck - 0.05): 133 | break 134 | ''' 135 | 136 | save_model(args, model, is_best) 137 | 138 | 139 | if __name__ == '__main__': 140 | 141 | main() 142 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | from util.util import init_model, init_model_optim 8 | from util.util import init_train_data, init_eval_data 9 | from util.util import save_model 10 | 11 | from util.eval_util import compute_metric 12 | from util.torch_util import BatchTensorToVars 13 | from parser.parser import ArgumentParser 14 | import config 15 | 16 | 17 | args, arg_groups = ArgumentParser(mode='train').parse() 18 | 19 | if not os.path.exists(args.result_model_dir): 20 | os.makedirs(args.result_model_dir) 21 | 22 | torch.cuda.set_device(args.gpu) 23 | use_cuda = torch.cuda.is_available() 24 | 25 | torch.manual_seed(args.seed) 26 | if use_cuda: 27 | torch.cuda.manual_seed(args.seed) 28 | np.random.seed(args.seed) 29 | 30 | 31 | if args.match_loss: 32 | from model.loss import AffMatchScore, TpsMatchScore 33 | AffMatch = AffMatchScore(**arg_groups['loss'], seg_mask=args.seg_mask) 34 | TpsMatch = TpsMatchScore(use_cuda=use_cuda, **arg_groups['loss'], 
def gen_mask(corr_dict):
    """Build masks from the correlation tensors by maximizing over dim 1.

    Args:
        corr_dict: mapping with tensors under 'corr_AB' and 'corr_BA'.

    Returns:
        Dict with 'mask_AB' and 'mask_BA'; each is the element-wise maximum
        of the matching correlation tensor along dimension 1, with that
        dimension kept (size 1).
    """
    def _max_over_dim1(corr):
        # torch.max with `dim` returns (values, indices); only values are used.
        values, _ = torch.max(corr, dim=1, keepdim=True)
        return values

    return {
        'mask_AB': _max_over_dim1(corr_dict['corr_AB']),
        'mask_BA': _max_over_dim1(corr_dict['corr_BA']),
    }
def loss_task(aff_dict, mask_dict):
    """Evaluate the task consistency loss via the module-level ``Task`` object.

    Args:
        aff_dict: affine transformation estimates produced by the model.
        mask_dict: masks forwarded unchanged to ``Task``.

    Returns:
        Whatever ``Task(aff_dict, mask_dict)`` returns.
    """
    # Previously called Task(aff_dict, mask_loss); `mask_loss` is not defined
    # anywhere, so the call raised NameError. The masks are the `mask_dict`
    # parameter.
    task_loss = Task(aff_dict, mask_dict)
    return task_loss


def print_loss(epoch, idx, num, loss_dict):
    """Print a one-line progress report with every enabled loss term.

    A loss is reported only when its ``args.<name>_loss`` switch is set.

    Args:
        epoch: current epoch number.
        idx: index of the current batch within the epoch.
        num: total number of batches in the epoch (must be non-zero).
        loss_dict: mapping from loss name ('match', 'cycle', 'trans',
            'coseg', 'task') to its current value.
    """
    # The percentage used to be computed from `batch_idx`, which is a local
    # of process_epoch and undefined here (NameError on every call). The
    # batch index is passed in as `idx`.
    print_string = 'Epoch: {} [{}/{} ({:.0f}%)]'.format(epoch, idx, num, 100. * idx / num)
    if args.match_loss:
        print_string += ' match: {:.6f}'.format(loss_dict['match'])
    if args.cycle_loss:
        print_string += ' cycle: {:.6f}'.format(loss_dict['cycle'])
    if args.trans_loss:
        print_string += ' trans: {:.6f}'.format(loss_dict['trans'])
    if args.coseg_loss:
        print_string += ' coseg: {:.6f}'.format(loss_dict['coseg'])
    if args.task_loss:
        print_string += ' task: {:.6f}'.format(loss_dict['task'])
    print(print_string)
    return
return 196 | 197 | 198 | def main(): 199 | 200 | """ Initialize model """ 201 | model = init_model(args, arg_groups, use_cuda) 202 | 203 | 204 | """ Initialize dataloader """ 205 | train_data, train_loader = init_train_data(args) 206 | 207 | eval_data, eval_loader = init_eval_data(args) 208 | 209 | 210 | """ Initialize optimizer """ 211 | model_opt = init_model_optim(args, model) 212 | 213 | batch_tnf = BatchTensorToVars(use_cuda=use_cuda) 214 | 215 | """ Evaluate initial condition """ 216 | eval_categories = np.array(range(20)) + 1 217 | eval_flag = np.zeros(len(eval_data)) 218 | for i in range(len(eval_data)): 219 | eval_flag[i] = sum(eval_categories == eval_data.category[i]) 220 | eval_idx = np.flatnonzero(eval_flag) 221 | 222 | model.eval() 223 | 224 | eval_stats = compute_metric(args.eval_metric, model, eval_data, eval_loader, batch_tnf, args) 225 | best_eval_pck = np.mean(eval_stats['aff_tps'][args.eval_metric][eval_idx]) 226 | 227 | 228 | best_epoch = 1 229 | """ Start training """ 230 | for epoch in range(1, args.num_epochs+1): 231 | 232 | model.eval() 233 | 234 | process_epoch(epoch, model, model_opt, train_loader, batch_tnf) 235 | 236 | model.eval() 237 | 238 | eval_stats = compute_metric(args.eval_metric, model, eval_data, eval_loader, batch_tnf, args) 239 | eval_pck = np.mean(eval_stats['aff_tps'][args.eval_metric][eval_idx]) 240 | 241 | is_best = eval_pck > best_eval_pck 242 | 243 | if eval_pck > best_eval_pck: 244 | best_eval_pck = eval_pck 245 | best_epoch = epoch 246 | 247 | print('eval: {:.3f}'.format(eval_pck), 248 | 'best eval: {:.3f}'.format(best_eval_pck), 249 | 'best epoch: {}'.format(best_epoch)) 250 | 251 | """ Early stopping """ 252 | if eval_pck < (best_eval_pck - 0.05): 253 | break 254 | 255 | save_model(args, model, is_best) 256 | 257 | 258 | if __name__ == '__main__': 259 | 260 | main() 261 | -------------------------------------------------------------------------------- /train.sh: 
-------------------------------------------------------------------------------- 1 | MODEL_DIR=trained_models 2 | 3 | MODEL_PATH=$MODEL_DIR/weakalign.pth.tar 4 | 5 | GPU=0 6 | NUM_WORKERS=4 7 | BATCH_SIZE=4 8 | LEARNING_RATE=5e-8 9 | EPOCH=60 10 | 11 | DATASET=pf-pascal 12 | 13 | MATCH_LOSS=True 14 | CYCLE_LOSS=True 15 | TRANS_LOSS=False 16 | COSEG_LOSS=False 17 | TASK_LOSS=False 18 | 19 | W_MATCH=1.0 20 | W_CYCLE=1.0 21 | W_TRANS=0.0 22 | W_COSEG=0.0 23 | W_TASK=0.0 24 | 25 | python train.py \ 26 | --model $MODEL_PATH \ 27 | --training-dataset $DATASET \ 28 | --num-epochs $EPOCH \ 29 | --lr $LEARNING_RATE \ 30 | --gpu $GPU \ 31 | --num-workers $NUM_WORKERS \ 32 | --batch-size $BATCH_SIZE \ 33 | --result-model-dir $MODEL_DIR \ 34 | --match-loss $MATCH_LOSS \ 35 | --cycle-loss $CYCLE_LOSS \ 36 | --trans-loss $TRANS_LOSS \ 37 | --coseg-loss $COSEG_LOSS \ 38 | --task-loss $TASK_LOSS \ 39 | --w-match $W_MATCH \ 40 | --w-cycle $W_CYCLE \ 41 | --w-trans $W_TRANS \ 42 | --w-coseg $W_COSEG \ 43 | --w-task $W_TASK \ 44 | -------------------------------------------------------------------------------- /trained_models/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunChunChen/MaCoSNet-pytorch/697ea3a824c5bc9bbca6fac8d33367a8b9def4d1/util/__init__.py -------------------------------------------------------------------------------- /util/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as multiprocessing 3 | from torch.utils.data.sampler import SequentialSampler, RandomSampler, BatchSampler 4 | import collections 5 | import sys 6 | import 
class ExceptionWrapper(object):
    """Carries an exception's type and formatted traceback text.

    Used to hand an error raised in one worker/thread over to another.
    """

    def __init__(self, exc_info):
        # exc_info is a (type, value, traceback) triple such as the one
        # returned by sys.exc_info().
        formatted_lines = traceback.format_exception(*exc_info)
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(formatted_lines)
def default_collate(batch):
    """Puts each data field into a tensor with outer dimension batch size.

    Dispatches on the type of ``batch[0]``:
      * tensors are stacked along a new first dimension (directly into
        shared memory when ``_use_shared_memory`` is set);
      * numpy arrays are converted with ``torch.from_numpy`` and stacked;
        numpy scalars become a tensor chosen through ``numpy_type_map``;
      * Python ints / floats become a LongTensor / DoubleTensor;
      * strings are returned unchanged;
      * mappings are collated key by key, sequences position by position.

    Raises:
        TypeError: if ``batch[0]`` is none of the supported types.
    """
    # The container ABCs live in collections.abc; the aliases on the
    # `collections` module itself were removed in Python 3.10. Python 2 only
    # has them on `collections`, hence the fallback.
    try:
        from collections import abc as container_abcs
    except ImportError:
        import collections as container_abcs

    if torch.is_tensor(batch[0]):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif type(batch[0]).__module__ == 'numpy':
        elem = batch[0]
        if type(elem).__name__ == 'ndarray':
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
                     .format(type(batch[0]))))


def pin_memory_batch(batch):
    """Recursively pin every tensor found in ``batch``.

    Tensors are replaced by ``tensor.pin_memory()``; mappings and sequences
    are rebuilt (as dict / list) with their elements processed the same way;
    strings and any other object are returned unchanged.
    """
    # See default_collate: the ABCs must come from collections.abc on
    # Python 3.10+, with a Python 2 fallback.
    try:
        from collections import abc as container_abcs
    except ImportError:
        import collections as container_abcs

    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, container_abcs.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, container_abcs.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch
class DataLoader(object):
    """
    Data loader that pairs a dataset with a (batch) sampler and hands out
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): number of samples per batch (default: 1).
        shuffle (bool, optional): ``True`` to draw samples in random order
            (default: False).
        sampler (Sampler, optional): strategy used to draw samples from the
            dataset. If given, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like ``sampler`` but yields a whole
            batch of indices at a time. Mutually exclusive with batch_size,
            shuffle, sampler, and drop_last.
        num_workers (int, optional): number of subprocesses used for loading;
            0 loads the data in the main process (default: 0).
        collate_fn (callable, optional): merges a list of samples into a
            mini-batch.
        pin_memory (bool, optional): if ``True``, tensors are copied into
            pinned memory before being returned.
        drop_last (bool, optional): ``True`` drops the last batch when the
            dataset size is not divisible by the batch size; otherwise that
            last batch is simply smaller (default: False).
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False):
        # Reject contradictory sampling options before building anything.
        has_batch_sampler = batch_sampler is not None
        if has_batch_sampler and (batch_size > 1 or shuffle or sampler is not None or drop_last):
            raise ValueError('batch_sampler is mutually exclusive with '
                             'batch_size, shuffle, sampler, and drop_last')
        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if not has_batch_sampler:
            if sampler is None:
                # Default sampling strategy follows the shuffle flag.
                sampler_cls = RandomSampler if shuffle else SequentialSampler
                sampler = sampler_cls(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        # Every iteration request gets a fresh iterator object.
        return DataLoaderIter(self)

    def __len__(self):
        # Number of batches, as reported by the batch sampler.
        return len(self.batch_sampler)
def compute_metric(metric, model, dataset, dataloader, batch_tnf, args):
    """Run the model over ``dataloader`` and accumulate an evaluation metric.

    Args:
        metric: 'pck' or 'flow'.
        model: callable returning ``(aff_dict, tps_dict, _)`` for a batch;
            it is switched to eval mode before each forward pass.
        dataset: dataset being evaluated; its length sizes the result arrays.
        dataloader: iterable of batches over ``dataset``.
        batch_tnf: callable applied to every batch before the forward pass.
        args: parsed options; ``args.batch_size`` is used to locate each
            batch inside the result arrays.

    Returns:
        ``stats`` with ``stats['aff_tps'][metric]`` holding an ``(N, 1)``
        array that the per-batch metric function fills in. For 'pck' a
        summary is also printed.

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    # Validate up front: an unknown metric used to leave `metrics` and
    # `metric_fun` unbound and fail later with UnboundLocalError.
    if metric == 'pck':
        metrics = ['pck']
        metric_fun = pck_metric
    elif metric == 'flow':
        metrics = ['flow']

        def metric_fun(batch, batch_start_idx, theta_aff, theta_aff_tps, stats, args):
            # flow_metrics takes an extra `theta_tps` positional between
            # theta_aff and theta_aff_tps; the shared call below omitted it
            # and raised TypeError. No TPS-only estimate exists here, so
            # None is passed (flow_metrics tests it with `is not None`).
            return flow_metrics(batch, batch_start_idx, theta_aff, None,
                                theta_aff_tps, stats, args)
    else:
        raise ValueError("unsupported metric {!r}; expected 'pck' or 'flow'".format(metric))

    N = len(dataset)
    stats = {}
    stats['aff_tps'] = {}

    for key in stats.keys():
        for metric_name in metrics:
            stats[key][metric_name] = np.zeros((N, 1))

    for i, batch in enumerate(dataloader):
        batch = batch_tnf(batch)

        # Offset of this batch's first sample within the result arrays.
        batch_start_idx = args.batch_size * i

        model.eval()

        aff_dict, tps_dict, _ = model(batch)

        theta_aff = aff_dict['aff_AB']
        theta_aff_tps = tps_dict['tps_Awrp_B']

        stats = metric_fun(batch, batch_start_idx, theta_aff, theta_aff_tps, stats, args)

    if metric == 'flow':
        print('Flow files have been saved to ' + args.flow_output_dir)
        return stats

    print('\n')

    for key in stats.keys():
        print('=== Results {} ==='.format(key))
        for metric_name in metrics:
            if isinstance(dataset, PFPascalVal):
                # Per-category breakdown; category ids start at 1.
                N_cat = int(np.max(dataset.category))
                for c in range(N_cat):
                    cat_idx = np.nonzero(dataset.category == c + 1)[0]
                    print(dataset.category_names[c].ljust(15) + ': ',
                          '{:.2%}'.format(np.mean(stats[key][metric_name][cat_idx])))

            results = stats[key][metric_name]
            # Entries equal to -1 or NaN are excluded from the mean.
            good_idx = np.flatnonzero((results != -1) * ~np.isnan(results))
            print('Total: ' + str(results.size))
            print('Valid: ' + str(good_idx.size))
            filtered_results = results[good_idx]
            print(metric_name + ':', '{:.2%}'.format(np.mean(filtered_results)))

        print('\n')

    return stats
def mean_dist(source_points, warped_points, L_pck):
    """Return, per batch element, the mean keypoint distance divided by L_pck.

    Args:
        source_points: tensor indexed as [batch, coordinate, point].
        warped_points: tensor of the same layout holding the warped points.
        L_pck: per-batch-element normalization length.

    Returns:
        1-D tensor with one mean normalized Euclidean distance per batch
        element.
    """
    n_items = source_points.size(0)
    dist = torch.zeros((n_items))
    for b in range(n_items):
        src = source_points[b, :]
        wrp = warped_points[b, :]
        # Count the points whose two coordinates both differ from -1; only
        # the first `n_valid` columns are compared (this presumes the -1
        # padding sits at the end -- TODO confirm against the datasets).
        is_valid = torch.ne(src[0, :], -1) * torch.ne(src[1, :], -1)
        n_valid = torch.sum(is_valid)
        offset = src[:, :n_valid] - wrp[:, :n_valid]
        euclid = torch.pow(torch.sum(torch.pow(offset, 2), 0), 0.5)
        scale = L_pck[b].expand_as(euclid)
        dist[b] = torch.mean(torch.div(euclid, scale))
    return dist
def pck_metric(batch, batch_start_idx, theta_aff, theta_aff_tps, stats, args, use_cuda=True):
    """Evaluate PCK for the TPS-then-affine warp and record it in ``stats``.

    The target key-points are converted to unit coordinates, warped with the
    TPS parameters ``theta_aff_tps`` followed by the affine parameters
    ``theta_aff``, converted to pixel coordinates of the source image and
    compared with the source key-points through ``pck``.

    Args:
        batch: dict providing ``source_im_size``, ``target_im_size``,
            ``source_points``, ``target_points`` and ``L_pck``.
        batch_start_idx: row of ``stats['aff_tps']['pck']`` where this batch starts.
        theta_aff: affine parameters passed to ``PointTnf.affPointTnf``.
        theta_aff_tps: TPS parameters passed to ``PointTnf.tpsPointTnf``.
        stats: nested dict whose ``['aff_tps']['pck']`` array is updated in place.
        args: namespace providing ``pck_alpha`` and ``tps_reg_factor``.
        use_cuda: forwarded to ``PointTnf``.

    Returns:
        The same ``stats`` object.
    """
    point_tnf = PointTnf(use_cuda=use_cuda,
                         tps_reg_factor=args.tps_reg_factor)

    # target points -> unit coordinates -> TPS warp -> affine warp
    tgt_points_unit = PointsToUnitCoords(batch['target_points'], batch['target_im_size'])
    warped_unit = point_tnf.tpsPointTnf(theta_aff_tps, tgt_points_unit)
    warped_unit = point_tnf.affPointTnf(theta_aff, warped_unit)
    warped_pixels = PointsToPixelCoords(warped_unit, batch['source_im_size'])

    n_pairs = batch['source_im_size'].size(0)
    rows = range(batch_start_idx, batch_start_idx + n_pairs)

    scores = pck(batch['source_points'].data,
                 warped_pixels.data,
                 batch['L_pck'].data,
                 args.pck_alpha)
    stats['aff_tps']['pck'][rows] = scores.unsqueeze(1).cpu().numpy()

    return stats
def intersection_over_union(warped_mask,target_mask):
    """Return the IoU of two mask batches, weighted by target part size.

    Both masks are binarised at 0.5.  Areas are summed over dimensions 2 and
    3, so each ``[i, j]`` slice is one part; its IoU is weighted by that
    part's share of all target foreground pixels.
    """
    tgt = target_mask.data.gt(0.5)
    wrp = warped_mask.data.gt(0.5)
    tgt_area = tgt.float().sum(dim=(2, 3), keepdim=True)
    relative_part_weight = tgt_area / tgt.float().sum()
    intersection = (wrp & tgt).float().sum(dim=(2, 3), keepdim=True)
    union = (wrp | tgt).float().sum(dim=(2, 3), keepdim=True)
    return torch.sum(relative_part_weight * (intersection / union))


def label_transfer_accuracy(warped_mask,target_mask):
    """Return the fraction of elements whose 0.5-thresholded labels agree."""
    agree = warped_mask.data.gt(0.5) == target_mask.data.gt(0.5)
    return torch.mean(agree.double())


def localization_error(source_mask_np, target_mask_np, flow_np):
    """Return the mean L1 gap between box-relative positions linked by a flow.

    Args:
        source_mask_np: 2-D source foreground mask (thresholded at 0.5).
        target_mask_np: 2-D target foreground mask (thresholded at 0.5).
        flow_np: array indexed ``[row, col, 0/1]`` holding the x/y offset
            added to each 1-based target pixel position.

    Returns:
        Mean of ``|dx| + |dy|`` over target pixels whose displaced position
        falls inside the source image, or NaN when no pixel does.

    Raises:
        ValueError: if either mask has no foreground pixel (see ``obj_ptr``).
    """
    h_tgt, w_tgt = target_mask_np.shape[0], target_mask_np.shape[1]
    h_src, w_src = source_mask_np.shape[0], source_mask_np.shape[1]

    # 1-based pixel positions in the target image and where the flow sends them
    x1, y1 = np.meshgrid(range(1, w_tgt + 1), range(1, h_tgt + 1))
    x2 = x1 + flow_np[:, :, 0]
    y2 = y1 + flow_np[:, :, 1]

    in_bound = (x2 >= 1) & (x2 <= w_src) & (y2 >= 1) & (y2 <= h_src)
    row, col = np.where(in_bound)
    if row.size == 0:
        # nothing lands inside the source image; avoid the empty-mean warning
        return np.nan

    # ``np.int`` was removed from NumPy; the builtin ``int`` is equivalent
    row_1 = y1[row, col].astype(int) - 1
    col_1 = x1[row, col].astype(int) - 1
    row_2 = y2[row, col].astype(int) - 1
    col_2 = x2[row, col].astype(int) - 1

    target_loc_x, target_loc_y = obj_ptr(target_mask_np)
    source_loc_x, source_loc_y = obj_ptr(source_mask_np)
    dx = target_loc_x[row_1, col_1] - source_loc_x[row_2, col_2]
    dy = target_loc_y[row_1, col_1] - source_loc_y[row_2, col_2]

    return np.mean(np.abs(dx) + np.abs(dy))


def obj_ptr(mask):
    """Return x/y coordinate images normalised by the foreground bounding box.

    The function name is kept from the DSP code.  Pixel positions are counted
    from 1, while the box origin comes from 0-based indices of ``mask > 0.5``.

    Raises:
        ValueError: if ``mask`` contains no value above 0.5.
    """
    h, w = mask.shape[0], mask.shape[1]
    y, x = np.where(mask > 0.5)
    if x.size == 0:
        raise ValueError('obj_ptr: mask has no foreground pixel (no value > 0.5)')
    left, right = np.min(x), np.max(x)
    top, bottom = np.min(y), np.max(y)
    fg_width = right - left + 1
    fg_height = bottom - top + 1
    x_image, y_image = np.meshgrid(range(1, w + 1), range(1, h + 1))
    x_image = (x_image - left) / fg_width
    y_image = (y_image - top) / fg_height
    return (x_image, y_image)
class NormalizeImage(object):
    """Normalize tensor images stored in a sample dictionary.

    Args:
        image_keys (list): dictionary keys of the images to be normalized.
        normalizeRange (bool): if True, each image is first divided by 255.0.
    """

    def __init__(self,image_keys,normalizeRange=True):
        self.image_keys = image_keys
        self.normalizeRange = normalizeRange
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def __call__(self, sample):
        """Normalize ``sample[key]`` for every configured key and return ``sample``."""
        for key in self.image_keys:
            image = sample[key]
            if self.normalizeRange:
                # out-of-place: do not modify the caller's tensor, and avoid
                # the in-place float division that integer tensors reject
                image = image / 255.0
            sample[key] = self.normalize(image)
        return sample


def normalize_image(image, forward=True, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Apply or undo per-channel mean/std normalisation.

    Args:
        image: tensor of rank 3 or rank 4 whose size the ``(len(mean), 1, 1)``
            statistics can be expanded to.
        forward: if True compute ``(image - mean) / std``; otherwise compute
            ``image * std + mean``.
        mean: per-channel means.
        std: per-channel standard deviations.

    Returns:
        A new tensor with the same size as ``image``.

    Raises:
        ValueError: if ``image`` is neither rank 3 nor rank 4.
    """
    im_size = image.size()
    if len(im_size) not in (3, 4):
        raise ValueError(
            'normalize_image expects a 3-D or 4-D tensor, got {} dimensions'.format(len(im_size)))

    # Build the statistics on the image's own device; expand() prepends the
    # batch dimension for rank-4 input.
    mean_t = torch.as_tensor(mean, dtype=torch.float32, device=image.device)
    std_t = torch.as_tensor(std, dtype=torch.float32, device=image.device)
    mean_t = mean_t.view(-1, 1, 1).expand(im_size)
    std_t = std_t.view(-1, 1, 1).expand(im_size)

    if forward:
        return image.sub(mean_t).div(std_t)
    return image.mul(std_t).add(mean_t)
def collate_custom(batch):
    """Collate a list of samples without stacking non-tensor values.

    * A list of mappings becomes one dict whose values are collated
      recursively, key by key.
    * A list of tensors is passed to ``default_collate``.
    * Anything else (e.g. lists of differently sized annotations) is
      returned unchanged.
    """
    # ``collections.Mapping`` no longer exists on Python >= 3.10
    from collections.abc import Mapping

    first = batch[0]
    if isinstance(first, Mapping):
        return {key: collate_custom([d[key] for d in batch]) for key in first}
    if torch.is_tensor(first):
        return default_collate(batch)
    return batch


class BatchTensorToVars(object):
    """Wrap every tensor of a batch dict in a ``Variable`` (moved to GPU if requested)."""

    def __init__(self, use_cuda=True):
        self.use_cuda = use_cuda

    def __call__(self, batch):
        """Return a new dict; non-tensor values are passed through untouched."""
        batch_var = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                var = Variable(value, requires_grad=False)
                batch_var[key] = var.cuda() if self.use_cuda else var
            else:
                batch_var[key] = value
        return batch_var


def Softmax1D(x,dim):
    """Return the softmax of ``x`` along ``dim`` without modifying ``x``.

    The per-slice maximum is subtracted before exponentiation for numerical
    stability; the subtraction is out-of-place so the caller's tensor is
    left intact.
    """
    shifted = x - torch.max(x, dim, keepdim=True)[0]
    exp_x = torch.exp(shifted)
    return exp_x / torch.sum(exp_x, dim, keepdim=True)


def save_checkpoint(state, is_best, file):
    """Save ``state`` to ``file`` and, if ``is_best``, copy it to ``best_<name>``."""
    model_dir = dirname(file)
    model_fn = basename(file)
    # create the target directory when one is given and it is missing
    if model_dir != '' and not exists(model_dir):
        makedirs(model_dir)
    torch.save(state, file)
    if is_best:
        shutil.copyfile(file, join(model_dir, 'best_' + model_fn))


def str_to_bool(v):
    """Parse a yes/no style string into a bool.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognised spelling.
    """
    # imported here because the module itself does not import argparse
    import argparse

    value = v.lower()
    if value in ('yes', 'true', 't', 'y', '1'):
        return True
    if value in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def expand_dim(tensor,dim,desired_dim_len):
    """Expand ``tensor`` along ``dim`` to ``desired_dim_len`` (no data copy)."""
    sz = list(tensor.size())
    sz[dim] = desired_dim_len
    return tensor.expand(tuple(sz))
model.train() 7 | train_loss = 0 8 | for batch_idx, batch in enumerate(dataloader): 9 | optimizer.zero_grad() 10 | tnf_batch = pair_generation_tnf(batch) 11 | theta = model(tnf_batch) 12 | loss = loss_fn(theta,tnf_batch['theta_GT']) 13 | loss.backward() 14 | optimizer.step() 15 | train_loss += loss.data.cpu().numpy()[0] 16 | if batch_idx % log_interval == 0: 17 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\t\tLoss: {:.6f}'.format( 18 | epoch, batch_idx , len(dataloader), 19 | 100. * batch_idx / len(dataloader), loss.data[0])) 20 | train_loss /= len(dataloader) 21 | print('Train set: Average loss: {:.4f}'.format(train_loss)) 22 | return train_loss 23 | 24 | def test_fun_strong(model,loss_fn,dataloader,pair_generation_tnf,use_cuda=True): 25 | model.eval() 26 | test_loss = 0 27 | for batch_idx, batch in enumerate(dataloader): 28 | tnf_batch = pair_generation_tnf(batch) 29 | theta = model(tnf_batch) 30 | loss = loss_fn(theta,tnf_batch['theta_GT']) 31 | test_loss += loss.data.cpu().numpy()[0] 32 | 33 | test_loss /= len(dataloader) 34 | print('Test set: Average loss: {:.4f}'.format(test_loss)) 35 | return test_loss 36 | 37 | 38 | def train_fun_weak(epoch,model,loss_fn,optimizer,dataloader,dataloader_neg,batch_tnf,use_cuda=True,log_interval=50,triplet=False,tps_grid_regularity_loss=0): 39 | tgrl = TpsGridRegularityLoss(use_cuda=use_cuda) 40 | model.train() 41 | train_loss = 0 42 | if dataloader_neg is not None: 43 | dataloader_neg_iter=iter(dataloader_neg) 44 | for batch_idx, batch in enumerate(dataloader): 45 | optimizer.zero_grad() 46 | batch = batch_tnf(batch) 47 | if dataloader_neg is not None and triplet==False: 48 | batch_neg = next(dataloader_neg_iter) 49 | batch_neg = batch_tnf(batch_neg) 50 | theta_pos,corr_pos,theta_neg,corr_neg = model(batch, batch_neg) 51 | inliers_pos = loss_fn(theta_pos,corr_pos) 52 | inliers_neg = loss_fn(theta_neg,corr_neg) 53 | loss = torch.sum(inliers_neg - inliers_pos) 54 | elif dataloader_neg is None and triplet==False: 55 | theta,corr = 
model(batch) 56 | loss = loss_fn(theta,corr) 57 | elif dataloader_neg is None and triplet==True: 58 | f_A = model.FeatureExtraction(batch['source_image']) 59 | f_B = model.FeatureExtraction(batch['source_image']) 60 | f_N = model.FeatureExtraction(batch['negative_image']) 61 | corr_pos = model.FeatureCorrelation(f_A,f_B) 62 | corr_neg = model.FeatureCorrelation(f_A,f_N) 63 | theta_pos = model.FeatureRegression(corr_pos) 64 | theta_neg = model.FeatureRegression(corr_neg) 65 | inliers_pos = loss_fn(theta_pos,corr_pos) 66 | inliers_neg = loss_fn(theta_neg,corr_neg) 67 | loss = torch.sum(inliers_neg - inliers_pos) 68 | if tps_grid_regularity_loss != 0: 69 | loss = loss + tps_grid_regularity_loss*tgrl(theta_pos) 70 | 71 | loss.backward() 72 | optimizer.step() 73 | train_loss += loss.data.cpu().numpy()[0] 74 | print_train_progress(log_interval,batch_idx,len(dataloader),epoch,loss.data[0]) 75 | train_loss /= len(dataloader) 76 | print('Train set: Average loss: {:.4f}'.format(train_loss)) 77 | return train_loss 78 | 79 | 80 | def test_fun_weak(model,loss_fn,dataloader,dataloader_neg,batch_tnf,use_cuda=True,triplet=False,tps_grid_regularity_loss=0): 81 | model.eval() 82 | test_loss = 0 83 | if dataloader_neg is not None: 84 | dataloader_neg_iter=iter(dataloader_neg) 85 | for batch_idx, batch in enumerate(dataloader): 86 | batch = batch_tnf(batch) 87 | if dataloader_neg is not None: 88 | batch_neg = next(dataloader_neg_iter) 89 | batch_neg = batch_tnf(batch_neg) 90 | theta_pos,corr_pos,theta_neg,corr_neg = model(batch, batch_neg) 91 | inliers_pos = loss_fn(theta_pos,corr_pos) 92 | inliers_neg = loss_fn(theta_neg,corr_neg) 93 | loss = torch.sum(inliers_neg - inliers_pos) 94 | elif dataloader_neg is None and triplet==False: 95 | theta,corr = model(batch) 96 | loss = loss_fn(theta,corr) 97 | elif dataloader_neg is None and triplet==True: 98 | theta_pos,corr_pos,theta_neg,corr_neg = model(batch, triplet=True) 99 | inliers_pos = loss_fn(theta_pos,corr_pos) 100 | inliers_neg = 
loss_fn(theta_neg,corr_neg) 101 | loss = torch.sum(inliers_neg - inliers_pos) 102 | test_loss += loss.data.cpu().numpy()[0] 103 | 104 | test_loss /= len(dataloader) 105 | print('Test set: Average loss: {:.4f}'.format(test_loss)) 106 | return test_loss 107 | 108 | def print_train_progress(log_interval,batch_idx,num_batches,epoch,loss_value): 109 | if batch_idx % log_interval == 0: 110 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\t\tLoss: {:.6f}'.format( 111 | epoch, batch_idx , num_batches, 112 | 100. * batch_idx / num_batches, loss_value)) -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | 5 | import config 6 | 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | from torchvision import transforms 11 | from torch.autograd import Variable, grad 12 | 13 | from model.network import WeakMatchNet 14 | from model.network import WeakCosegNet 15 | #from model.network import MaCoSNet 16 | 17 | from data.pf_pascal import PFPascal 18 | from data.pf_pascal import PFPascalVal 19 | from data.pf_willow import PFWillow 20 | from data.tss import TSS 21 | from data.tss import TSSVal 22 | from data.internet import Internet 23 | from data.internet import InternetVal 24 | 25 | try: 26 | from dataloader import DataLoader 27 | except ImportError: 28 | from util.dataloader import DataLoader 29 | 30 | try: 31 | from normalize import NormalizeImage 32 | except ImportError: 33 | from util.normalize import NormalizeImage 34 | 35 | 36 | def init_model(args, arg_groups, use_cuda=True, mode='train'): 37 | if args.model_type == 'match': 38 | model = init_match_model(args, arg_groups, use_cuda, mode) 39 | elif args.model_type == 'coseg': 40 | model = init_coseg_model(args, arg_groups, use_cuda, mode) 41 | else: # joint 42 | model = init_joint_model(args, arg_groups, use_cuda, mode) 
def _copy_prefixed_weights(module, checkpoint, prefix):
    """Copy checkpoint tensors named ``prefix + <name>`` into ``module``.

    Each entry is looked up in ``checkpoint['state_dict']`` first and, on a
    ``KeyError``, directly in ``checkpoint``.
    """
    for name, tensor in module.state_dict().items():
        key = prefix + name
        try:
            source = checkpoint['state_dict'][key]
        except KeyError:
            source = checkpoint[key]
        tensor.copy_(source)


def init_match_model(args, arg_groups, use_cuda=True, mode='train'):
    """Build a ``WeakMatchNet``, optionally load weights and set trainability.

    Args:
        args: namespace providing ``model`` (checkpoint path, falsy to skip
            loading), ``train_fe``, ``fe_finetune_params``, ``train_fr`` and
            ``train_bn``.
        arg_groups: dict whose ``'model'`` entry is forwarded as keyword
            arguments to ``WeakMatchNet``.
        use_cuda: forwarded to ``WeakMatchNet``.
        mode: ``requires_grad`` flags are only touched when this is ``'train'``.

    Returns:
        The constructed model.
    """
    model = WeakMatchNet(use_cuda=use_cuda,
                         **arg_groups['model'])

    if args.model:
        # load on CPU regardless of where the checkpoint was saved
        checkpoint = torch.load(args.model, map_location=lambda storage, loc: storage)
        _copy_prefixed_weights(model.FeatureExtraction, checkpoint, 'FeatureExtraction.')
        _copy_prefixed_weights(model.FeatureRegression, checkpoint, 'FeatureRegression.')
        _copy_prefixed_weights(model.FeatureRegression2, checkpoint, 'FeatureRegression2.')

    if mode == 'train':
        for name, param in model.FeatureExtraction.named_parameters():
            param.requires_grad = False
            # fine-tune only parameters whose name contains a selected pattern
            if args.train_fe and any(x in name for x in args.fe_finetune_params):
                param.requires_grad = args.train_bn if 'bn' in name else True

        for regressor in (model.FeatureRegression, model.FeatureRegression2):
            for name, param in regressor.named_parameters():
                param.requires_grad = args.train_fr
                if args.train_fr and 'bn' in name:
                    param.requires_grad = args.train_bn

    return model
def _build_loader(dataset, args, shuffle):
    """Wrap ``dataset`` in a ``DataLoader`` configured from ``args``."""
    return DataLoader(dataset,
                      batch_size=args.batch_size,
                      shuffle=shuffle,
                      num_workers=args.num_workers,
                      pin_memory=True)


def init_train_data(args):
    """Return ``(dataset, shuffled loader)`` for ``args.training_dataset``.

    ``'pf-pascal'`` and ``'tss'`` select their datasets; any other value
    falls back to ``Internet``.
    """
    transform = NormalizeImage(['image_A', 'image_B', 'image_C'])
    if args.training_dataset == 'pf-pascal':
        dataset = PFPascal(transform=transform, random_crop=True)
    elif args.training_dataset == 'tss':
        dataset = TSS(transform=transform, random_crop=True)
    else:  # internet
        dataset = Internet(transform=transform, random_crop=True)

    return dataset, _build_loader(dataset, args, shuffle=True)


def init_eval_data(args):
    """Return ``(dataset, loader)`` of the validation split for ``args.training_dataset``."""
    transform = NormalizeImage(['image_A', 'image_B'])
    if args.training_dataset == 'pf-pascal':
        dataset = PFPascalVal(transform=transform, mode='eval')
    elif args.training_dataset == 'tss':
        # ``elif`` (not ``if``): otherwise the pf-pascal choice above is
        # always overwritten by the Internet fallback below
        dataset = TSSVal(transform=transform)
    else:  # Internet
        dataset = InternetVal(transform=transform)

    return dataset, _build_loader(dataset, args, shuffle=False)


def init_test_data(args):
    """Return ``(dataset, loader)`` of the test split for ``args.eval_dataset``."""
    transform = NormalizeImage(['image_A', 'image_B'])
    if args.eval_dataset == 'pf-pascal':
        dataset = PFPascalVal(transform=transform, mode='test')
    elif args.eval_dataset == 'pf-willow':
        dataset = PFWillow(transform=transform)
    elif args.eval_dataset == 'tss':
        dataset = TSSVal(transform=transform)
    else:  # Internet
        dataset = InternetVal(transform=transform)

    return dataset, _build_loader(dataset, args, shuffle=False)


def init_model_optim(args, model):
    """Return an Adam optimiser over the parameters of ``model`` that require grad.

    Learning rate and weight decay come from ``args.lr`` and
    ``args.weight_decay``; gradients are zeroed before returning.
    """
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    model_opt = optim.Adam(trainable,
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    model_opt.zero_grad()

    return model_opt


def save_model(args, model, is_best):
    """Save ``model.state_dict()`` under ``args.result_model_dir``.

    The file name encodes the loss weights ``w_match``, ``w_cycle``,
    ``w_trans``, ``w_coseg`` and ``w_task``.  When ``is_best`` is true the
    file is additionally copied to ``best_<name>``.  The directory is
    created if it does not exist yet.
    """
    print('Saving model...')

    model_name = 'match_{}_cycle_{}_trans_{}_coseg_{}_task_{}.pth.tar'.format(
        args.w_match, args.w_cycle, args.w_trans, args.w_coseg, args.w_task)

    # torch.save does not create missing directories
    os.makedirs(args.result_model_dir, exist_ok=True)
    model_path = os.path.join(args.result_model_dir, model_name)

    torch.save(model.state_dict(), model_path)

    if is_best:
        best_model_path = os.path.join(args.result_model_dir, 'best_{}'.format(model_name))
        shutil.copyfile(model_path, best_model_path)

    return