├── .gitignore ├── images ├── transform_test.png ├── basic_test_image_var.png ├── icstn_test_image_var.png ├── stn_test_image_mean.png ├── stn_test_image_var.png ├── basic_alignment_sample.png ├── basic_test_image_mean.png ├── icstn_test_image_mean.png ├── stn_alignment_samples.png ├── basic_alignment_samples.png └── icstn_alignment_samples.png ├── requirements.txt ├── experiments ├── base_stn_model │ └── params.json ├── base_icstn_model │ └── params.json └── base_basic_model │ └── params.json ├── data_loader.py ├── search_hyperparams.py ├── readme.md ├── utils.py ├── model.py ├── evaluate.py ├── train.py └── vision_transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | # exclude data 4 | data 5 | 6 | # virtual env 7 | .env 8 | -------------------------------------------------------------------------------- /images/transform_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/transform_test.png -------------------------------------------------------------------------------- /images/basic_test_image_var.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/basic_test_image_var.png -------------------------------------------------------------------------------- /images/icstn_test_image_var.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/icstn_test_image_var.png -------------------------------------------------------------------------------- /images/stn_test_image_mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/stn_test_image_mean.png -------------------------------------------------------------------------------- /images/stn_test_image_var.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/stn_test_image_var.png -------------------------------------------------------------------------------- /images/basic_alignment_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/basic_alignment_sample.png -------------------------------------------------------------------------------- /images/basic_test_image_mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/basic_test_image_mean.png -------------------------------------------------------------------------------- /images/icstn_test_image_mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/icstn_test_image_mean.png -------------------------------------------------------------------------------- /images/stn_alignment_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/stn_alignment_samples.png 
-------------------------------------------------------------------------------- /images/basic_alignment_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/basic_alignment_samples.png -------------------------------------------------------------------------------- /images/icstn_alignment_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/icstn_alignment_samples.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | kiwisolver==1.0.1 3 | matplotlib==3.0.1 4 | numpy==1.15.3 5 | Pillow==5.3.0 6 | protobuf==3.6.1 7 | pyparsing==2.3.0 8 | python-dateutil==2.7.5 9 | pytz==2018.7 10 | six==1.11.0 11 | tensorboardX==1.4 12 | torch==0.4.1 13 | torchvision==0.2.1 14 | tqdm==4.28.1 15 | -------------------------------------------------------------------------------- /experiments/base_stn_model/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "stn_module": "STNModule", 3 | "icstn_steps": 4, 4 | "data_dir": "./data", 5 | "batch_size": 128, 6 | "transformer_lr": 1e-3, 7 | "clf_lr": 1e-3, 8 | "lr_step": 1, 9 | "lr_gamma": 1, 10 | "n_epochs": 10, 11 | "save_summary_steps": 10000, 12 | "mini_data": false 13 | } 14 | 15 | -------------------------------------------------------------------------------- /experiments/base_icstn_model/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "stn_module": "ICSTNModule", 3 | "icstn_steps": 4, 4 | "data_dir": "./data", 5 | "batch_size": 128, 6 | "transformer_lr": 5e-4, 7 | "clf_lr": 1e-3, 8 | "lr_step": 1, 9 | "lr_gamma": 1, 10 | "n_epochs": 10, 11 | "save_summary_steps": 10000, 12 | "mini_data": false 13 | } 14 | 15 | -------------------------------------------------------------------------------- /experiments/base_basic_model/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "stn_module": "BasicSTNModule", 3 | "icstn_steps": 4, 4 | "data_dir": "./data", 5 | "batch_size": 128, 6 | "transformer_lr": 1e-3, 7 | "clf_lr": 1e-3, 8 | "lr_step": 1, 9 | "lr_gamma": 1, 10 | "n_epochs": 10, 11 | "save_summary_steps": 10000, 12 | "mini_data": false 13 | } 14 | 15 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision.datasets import MNIST 4 | import torchvision.transforms as T 5 | 6 | 7 | def fetch_dataloader(params, train=True, mini_size=128): 8 | 9 | # load dataset and init in the dataloader 10 | transforms = T.Compose([T.ToTensor()]) 11 | dataset = MNIST(root=params.data_dir, train=train, download=False, transform=transforms) 12 | 13 | if params.dict.get('mini_data'): 14 | if train: 15 | dataset.train_data = dataset.train_data[:mini_size] 16 | dataset.train_labels = dataset.train_labels[:mini_size] 17 | else: 18 | dataset.test_data = dataset.test_data[:mini_size] 19 | dataset.test_labels = dataset.test_labels[:mini_size] 20 | 21 | if params.dict.get('mini_ones'): 22 | if train: 23 | labels = dataset.train_labels[:2000] 
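            # keep only the digit-1 images from this 2000-example slice: build a boolean mask over the labels, then take the first mini_size matches for both labels and data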
24 | mask = labels==1 25 | dataset.train_labels = labels[mask][:mini_size] 26 | dataset.train_data = dataset.train_data[:2000][mask][:mini_size] 27 | else: 28 | labels = dataset.test_labels[:2000] 29 | mask = labels==1 30 | dataset.test_labels = labels[mask][:mini_size] 31 | dataset.test_data = dataset.test_data[:2000][mask][:mini_size] 32 | 33 | kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() and params.device.type is 'cuda' else {} 34 | 35 | return DataLoader(dataset, batch_size=params.batch_size, shuffle=True, drop_last=True, **kwargs) 36 | 37 | 38 | -------------------------------------------------------------------------------- /search_hyperparams.py: -------------------------------------------------------------------------------- 1 | """ Perform hyperparameter search """ 2 | 3 | import os 4 | import sys 5 | import json 6 | import argparse 7 | from copy import deepcopy 8 | from subprocess import check_call 9 | 10 | import torch 11 | import utils 12 | 13 | 14 | PYTHON = sys.executable 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--parent_dir', default='experiments', help='Directory containing hyperparams.json to setup a model.') 18 | parser.add_argument('--data_dir', default='./data', help='Directory containing the dataset') 19 | parser.add_argument('--cuda', type=int, help='Which cuda device to use') 20 | 21 | 22 | def launch_training_job(parent_dir, data_dir, job_name, params): 23 | """ launch training of the model with a set of hyperparameters in parent_dir/job_name """ 24 | 25 | # create new filder in parent_dir with unique name 'job_name' 26 | output_dir = os.path.join(parent_dir, job_name) 27 | if not os.path.exists(output_dir): 28 | os.mkdir(output_dir) 29 | 30 | # write params in a json file 31 | json_path = os.path.join(output_dir, 'params.json') 32 | params.save(json_path) 33 | 34 | print('Launching training job with parameters:') 35 | print(params) 36 | 37 | # launch training with this config 38 | if params.device is 'cpu': 39 | cmd = '{python} train.py --output_dir={output_dir}'.format( 40 | python=PYTHON, output_dir=output_dir) 41 | else: 42 | cmd = '{python} train.py --output_dir={output_dir} --cuda={device}'.format( 43 | python=PYTHON, output_dir=output_dir, device=int(params.device.split(':')[1])) 44 | 45 | 46 | print(cmd) 47 | 48 | check_call(cmd, shell=True) 49 | 50 | 51 | if __name__ == '__main__': 52 | # load the references parameters from parent_dir json file 53 | args = parser.parse_args() 54 | 55 | json_path = os.path.join(args.parent_dir, 'hyperparams.json') 56 | assert os.path.isfile(json_path), 'No json configuration file found at {}'.format(json_path) 57 | hyperparams = utils.Params(json_path) 58 | 59 | json_path = os.path.join(args.parent_dir, 'base_params.json') 60 | assert os.path.isfile(json_path), 'No json configuration file found at {}'.format(json_path) 61 | base_params = utils.Params(json_path) 62 | 63 | # set the static parameters 64 | for param, values in hyperparams.dict.items(): 65 | if isinstance(values, list): 66 | continue 67 | base_params.dict[param] = values 68 | 69 | base_params.device = 'cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda else 'cpu' 70 | 71 | # loop through the hyperparameter lists 72 | for param, values in hyperparams.dict.items(): 73 | if isinstance(values, list): 74 | for v in values: 75 | params = deepcopy(base_params) 76 | # modify the parameter value to that in hyperparms 77 | params.dict[param] = v 78 | 79 | # launch job with unique name 80 | 
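                # e.g. an entry {"transformer_lr": [1e-3, 5e-4]} in hyperparams.json (illustrative values) launches two jobs named transformer_lr_0.001 and transformer_lr_0.0005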
job_name = '{}_{}'.format(param, v) 81 | launch_training_job(args.parent_dir, args.data_dir, job_name, params) 82 | 83 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Spatial Transformer Networks 2 | 3 | Reimplementations of: 4 | * [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025) 5 | * [Inverse Compositional Spatial Transformer Networks](https://chenhsuanlin.bitbucket.io/inverse-compositional-STN/paper.pdf) 6 | 7 | Although implementations already exist, this one focuses on simplicity and 8 | ease of understanding of the vision transforms and the model. 9 | 10 | ## Results 11 | 12 | During training, random homography perturbations are applied to each image in the minibatch. The perturbations are composed of rotation, translation, shear, and projection components, with the parameters of each component sampled from uniform(-1,1) and scaled by a 0.25 multiplicative factor. 13 | 14 | Example homography perturbation:
15 | ![example perturbation](images/transform_test.png) 16 | 17 | ### Test set accuracy: 18 | 19 | | Model | Accuracy | Training params | 20 | | ----- | -------- | ----- | 21 | | Basic affine STN | 91.59% | 10 epochs at learning rate 1e-3 (classifier and transformer) | 22 | | Homography STN | 93.30% | 10 epochs at learning rate 1e-3 (classifier and transformer) | 23 | | Homography ICSTN | 97.67% | 10 epochs at learning rate 1e-3 (classifier) and 5e-4 (transformer) | 24 | 25 | 26 | ### Sample alignment results: 27 | 28 | #### Basic affine STN 29 | 30 | | Image | Samples | 31 | | --- | --- | 32 | | original
perturbed
transformed | ![basic](images/basic_alignment_samples.png) | 33 | 34 | #### Homography STN 35 | 36 | | Image | Samples | 37 | | --- | --- | 38 | | original
perturbed
transformed | ![stn](images/stn_alignment_samples.png) | 39 | 40 | 41 | #### Homography ICSTN 42 | 43 | | Image | Samples | 44 | | --- | --- | 45 | | original
perturbed
transformed | ![icstn](images/icstn_alignment_samples.png) | 46 | 47 | 48 | ### Mean and variance of the aligned results (cf Lin ICSTN paper) 49 | 50 | #### Mean image 51 | | Image | Basic affine STN | Homography STN | Homography ICSTN | 52 | | --- | ---------------- | -------------- | ---------------- | 53 | | original
perturbed
transformed | ![basic](images/basic_test_image_mean.png) | ![stn](images/stn_test_image_mean.png) | ![icstn](images/icstn_test_image_mean.png) | 54 | 55 | #### Variance 56 | | Image | Basic affine STN | Homography STN | Homography ICSTN | 57 | | --- | ---------------- | -------------- | ---------------- | 58 | | original
perturbed
transformed | ![basic](images/basic_test_image_var.png) | ![stn](images/stn_test_image_var.png) | ![icstn](images/icstn_test_image_var.png) | 59 | 60 | 61 | ## Usage 62 | 63 | To train model: 64 | ``` 65 | python train.py --output_dir=[path to params.json] 66 | --restore_file=[path to .pt checkpoint if resuming training] 67 | --cuda=[cuda device id] 68 | ``` 69 | `params.json` provides training parameters and specifies which spatial transformer module to use: 70 | 1. `BasicSTNModule` -- affine transform localization network 71 | 2. `STNModule` -- homography transform localization network 72 | 3. `ICSTNModule` -- homography transform localization netwokr (cf Lin, 73 | ICSTN paper) 74 | 75 | To evaluate and visualize results: 76 | ``` 77 | python evaluate.py --output_dir=[path to params.json] 78 | --restore_file=[path to .pt checkpoint] 79 | --cuda=[cuda device id] 80 | ``` 81 | 82 | ## Dependencies 83 | * python 3.6 84 | * pytorch 0.4 85 | * torchvision 86 | * tensorboardX 87 | * numpy 88 | * matplotlib 89 | * tqdm 90 | 91 | 92 | 93 | ## Useful resources 94 | * https://github.com/chenhsuanlin/inverse-compositional-STN 95 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import json 4 | from datetime import datetime 5 | import torch 6 | 7 | from tensorboardX import SummaryWriter 8 | 9 | 10 | 11 | 12 | def set_writer(log_dir, comment=''): 13 | """ setup a tensorboardx summarywriter """ 14 | # current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 15 | # log_dir = os.path.join(log_path, current_time + comment) 16 | writer = SummaryWriter(log_dir=log_dir) 17 | return writer 18 | 19 | 20 | def save_checkpoint(state, is_best, checkpoint, quiet=False): 21 | """ saves model and training params at checkpoint + 'last.pt'; if is_best also saves checkpoint + 'best.pt' 22 | 23 | args 24 | state -- dict; with keys model_state_dict, optimizer_state_dict, epoch, scheduler_state_dict, etc 25 | is_best -- bool; true if best model seen so far 26 | checkpoint -- str; folder where params are to be saved 27 | """ 28 | 29 | filepath = os.path.join(checkpoint, 'state_checkpoint.pt') 30 | if not os.path.exists(checkpoint): 31 | if not quiet: 32 | print('Checkpoint directory does not exist Making directory {}'.format(checkpoint)) 33 | os.mkdir(checkpoint) 34 | 35 | torch.save(state, filepath) 36 | 37 | if is_best: 38 | shutil.copyfile(filepath, os.path.join(checkpoint, 'best_state_checkpoint.pt')) 39 | 40 | if not quiet: 41 | print('Checkpoint saved.') 42 | 43 | 44 | def load_checkpoint(checkpoint, model, optimizer=None, scheduler=None, best_metric=None): 45 | """ loads model state_dict from filepath; if optimizer and lr_scheduler provided also loads them 46 | 47 | args 48 | checkpoint -- string of filename 49 | model -- torch nn.Module model 50 | optimizer -- torch.optim instance to resume from checkpoint 51 | lr_scheduler -- torch.optim.lr_scheduler instance to resume from checkpoint 52 | """ 53 | 54 | if not os.path.exists(checkpoint): 55 | raise('File does not exist {}'.format(checkpoint)) 56 | 57 | checkpoint = torch.load(checkpoint) 58 | model.load_state_dict(checkpoint['model_state_dict']) 59 | 60 | if optimizer: 61 | try: 62 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 63 | except KeyError: 64 | print('No optimizer state dict in checkpoint file') 65 | 66 | if best_metric: 67 | try: 68 | best_metric = checkpoint['best_val_acc'] 
69 | except KeyError: 70 | print('No best validation accuracy recorded in checkpoint file.') 71 | 72 | if scheduler: 73 | try: 74 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 75 | except KeyError: 76 | print('No lr scheduler state dict in checkpoint file') 77 | 78 | return checkpoint['epoch'] 79 | 80 | 81 | # -------------------- 82 | # Containers 83 | # -------------------- 84 | 85 | class RunningAverage: 86 | """ a class to maintain the running average of a quantity 87 | 88 | example: 89 | ``` 90 | loss_avg = RunningAverage() 91 | loss_avg.update(2) 92 | loss_avg.update(4) 93 | loss_avg() = 3 94 | ``` 95 | """ 96 | 97 | def __init__(self): 98 | self.steps = 0 99 | self.total = 0 100 | 101 | def __call__(self): 102 | return self.total/float(self.steps) 103 | 104 | def update(self, val): 105 | self.steps += 1 106 | self.total += val 107 | 108 | 109 | 110 | class Params: 111 | """ class that loads hyperparams from json file. 112 | 113 | example: 114 | ``` 115 | params = Params(json_path) 116 | print(params.learning_rate) 117 | params.learning_rate = 0.5 118 | ``` 119 | """ 120 | 121 | def __init__(self, json_path): 122 | with open(json_path, 'r') as f: 123 | params = json.load(f) 124 | self.__dict__.update(params) 125 | self.__dict__['output_dir'] = os.path.dirname(json_path) 126 | 127 | def save(self, json_path): 128 | with open(json_path, 'w') as f: 129 | json.dump(self.__dict__, f, indent=4) 130 | 131 | def update(self, json_path): 132 | """ loads params from json file """ 133 | with open(json_path, 'r') as f: 134 | params = json.load(f) 135 | self.__dict__.update(params) 136 | 137 | @property 138 | def dict(self): 139 | """ gives dict-like access to Params instances by `params.dict['learning_rate']` """ 140 | return self.__dict__ 141 | 142 | def __repr__(self): 143 | out = '' 144 | for k, v in self.__dict__.items(): 145 | out += k + ': ' + str(v) + '\n' 146 | return out 147 | 148 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from vision_transforms import apply_transform_to_batch, vec_to_perpective_matrix 6 | 7 | 8 | # -------------------- 9 | # Model helpers 10 | # -------------------- 11 | 12 | class Flatten(nn.Module): 13 | def forward(self, x): 14 | return x.view(x.shape[0],-1) 15 | 16 | def initialize(model, std=0.1): 17 | for p in model.parameters(): 18 | p.data.normal_(0,std) 19 | 20 | # init last linear layer of the transformer at 0 21 | model.transformer.net[-1].weight.data.zero_() 22 | model.transformer.net[-1].bias.data.copy_(torch.eye(3).flatten()[:model.transformer.net[-1].out_features]) 23 | # NOTE: this initialization the last layer of the transformer layer to identity here means the apply_tranform function should not 24 | # add an identity matrix when converting coordinates 25 | 26 | 27 | # -------------------- 28 | # Model components 29 | # -------------------- 30 | 31 | class BasicSTNModule(nn.Module): 32 | """ pytorch builtin affine transform """ 33 | def __init__(self, params, out_dim=6): 34 | super().__init__() 35 | self.net = nn.Sequential(nn.Conv2d(1, 4, kernel_size=7), # (N, 1, 28, 28) > (N, 4, 22, 22) 36 | nn.ReLU(True), 37 | nn.Conv2d(4, 8, kernel_size=7), # (N, 4, 20, 20) > (N, 8, 16, 16) 38 | nn.MaxPool2d(2, stride=2), # (N, 8, 18, 18) > (N, 8, 8, 8) 39 | nn.ReLU(True), 40 | Flatten(), 41 | nn.Linear(8**3, 48), 42 | 
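                                 # the pooled feature map is 8 channels x 8 x 8 spatial positions = 512 = 8**3 input features for the localization head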
nn.ReLU(True), 43 | nn.Linear(48, out_dim)) 44 | 45 | def forward(self, x, P_init): 46 | x = apply_transform_to_batch(x, P_init) 47 | theta = self.net(x).view(-1,2,3) 48 | grid = F.affine_grid(theta, x.size()) 49 | return F.grid_sample(x, grid) 50 | 51 | 52 | class STNModule(BasicSTNModule): 53 | """ homography stn """ 54 | def __init__(self, params, out_dim=8): 55 | super().__init__(params, out_dim) 56 | 57 | def forward(self, x, P_init): 58 | # apply the perturbation matrix to the minibatch of image tensors 59 | x = apply_transform_to_batch(x, P_init) 60 | # predict the transformation to approximate 61 | p = self.net(x) 62 | # convert to matrix 63 | P_net = vec_to_perpective_matrix(p) 64 | # apply to the original image 65 | return apply_transform_to_batch(x, P_net) 66 | 67 | 68 | class ICSTNModule(STNModule): 69 | """ inverse compositional stn cf Lin, Lucey ICSTN paper """ 70 | def __init__(self, params): 71 | super().__init__(params) 72 | self.icstn_steps = params.icstn_steps 73 | 74 | def forward(self, x, P_init): 75 | P = P_init 76 | # apply spatial transform recurrently for n_steps 77 | for i in range(self.icstn_steps): 78 | # apply the perturbation matrix to the minibatch of image tensors 79 | transformed_x = apply_transform_to_batch(x, P) 80 | # predict the trasnform 81 | p = self.net(transformed_x) 82 | # convert to matrix 83 | P_net = vec_to_perpective_matrix(p) 84 | # compose transform with previous 85 | P = P @ P_net # compose on the left; apply_transform_to_batch takes the composite transform and right multiplies by xy_hom 86 | # apply the final composite transform to the original image 87 | return apply_transform_to_batch(x, P) 88 | 89 | 90 | class ClassifierModule(nn.Module): 91 | def __init__(self, out_dim=10): 92 | super().__init__() 93 | self.net = nn.Sequential(nn.Conv2d(1, 3, kernel_size=9), # (N, 1, 28, 28) > (N, 3, 20, 20) 94 | nn.ReLU(True), 95 | Flatten(), 96 | nn.Linear(3*20*20, out_dim)) 97 | 98 | def forward(self, x): 99 | return self.net(x) 100 | 101 | 102 | # -------------------- 103 | # Model 104 | # -------------------- 105 | 106 | class STN(nn.Module): 107 | def __init__(self, transformer_module, params): 108 | super().__init__() 109 | self.transformer = transformer_module(params) 110 | self.clf = ClassifierModule() 111 | 112 | def forward(self, x, P): 113 | # take minibatch of image tensors x and geometric transform P 114 | x = self.transformer(x, P) 115 | # return the output of the transformer and the output of the classifier 116 | return x, self.clf(x) 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pprint 4 | from tqdm import tqdm 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torchvision.utils import make_grid, save_image 9 | 10 | import model 11 | from data_loader import fetch_dataloader 12 | from vision_transforms import gen_random_perspective_transform, apply_transform_to_batch 13 | import utils 14 | 15 | 16 | parser = argparse.ArgumentParser(description='Evaluate a model') 17 | parser.add_argument('--output_dir', help='Directory containing params.json and weights') 18 | parser.add_argument('--restore_file', help='Name of the file in containing weights to load') 19 | parser.add_argument('--cuda', type=int, help='Which cuda device to use') 20 | 21 | 22 | @torch.no_grad() 23 | def visualize_sample(model, dataset, writer, params, step, 
n_samples=20): 24 | model.eval() 25 | 26 | sample = torch.stack([dataset[i][0] for i in range(n_samples)], dim=0).to(params.device) 27 | 28 | P = gen_random_perspective_transform(params)[:n_samples] 29 | perturbed_sample = apply_transform_to_batch(sample, P) 30 | transformed_sample, scores = model(sample, P) 31 | 32 | perturbed_sample = perturbed_sample.view(n_samples, 1, 28, 28) 33 | transformed_sample = transformed_sample.view(n_samples, 1, 28, 28) 34 | 35 | sample = torch.cat([sample, perturbed_sample, transformed_sample], dim=0) 36 | sample = make_grid(sample.cpu(), nrow=n_samples, normalize=True, padding=1, pad_value=1) 37 | 38 | if writer: 39 | writer.add_image('sample', sample, step) 40 | 41 | save_image(sample, os.path.join(params.output_dir, 'samples__orig_perturbed_transformed' + (step!=None)*'_step_{}'.format(step) + '.png')) 42 | 43 | 44 | @torch.no_grad() 45 | def evaluate(model, dataloader, writer, params): 46 | model.eval() 47 | 48 | # init trackers 49 | accuracy = [] 50 | labels = [] 51 | original = [] 52 | perturbed = [] 53 | transformed = [] 54 | 55 | with tqdm(total=len(dataloader), desc='eval') as pbar: 56 | for i, (im_batch, labels_batch) in enumerate(dataloader): 57 | im_batch = im_batch.to(params.device) 58 | 59 | # get a random transformation and run through the batch 60 | P = gen_random_perspective_transform(params) 61 | 62 | transformed_batch, scores = model(im_batch, P) 63 | log_probs = F.log_softmax(scores, dim=1) 64 | 65 | # get predictions and calculate accuracy 66 | _, pred = torch.max(log_probs.cpu(), dim=1) 67 | accuracy.append(pred.eq(labels_batch.view_as(pred)).sum().item() / im_batch.shape[0]) 68 | 69 | 70 | # record to compute mean image with variance for original, perturbed, and transformed image (cf Lin, Lucey ICSTN paper) 71 | labels.append(labels_batch) 72 | original.append(im_batch) 73 | perturbed.append(apply_transform_to_batch(im_batch, P)) 74 | transformed.append(transformed_batch) 75 | 76 | avg_accuracy = sum(accuracy) / len(accuracy) 77 | pbar.set_postfix(accuracy='{:.5f}'.format(avg_accuracy)) 78 | pbar.update() 79 | 80 | labels = torch.cat(labels, dim=0) 81 | unique_labels = torch.unique(labels, sorted=True) 82 | original = torch.cat(original, dim=0) 83 | perturbed = torch.cat(perturbed, dim=0) 84 | transformed = torch.cat(transformed, dim=0) 85 | 86 | # compute mean image with variance for original, perturbed, and transformed image for each digit (cf Lin, Lucey ICSTN paper) 87 | image = torch.stack([original, perturbed, transformed], dim=0) # (3, len(data), C, H, W) 88 | mean_image = [make_grid(torch.mean(image[:, labels==i, ...], dim=1).cpu(), nrow=1) for i in unique_labels] 89 | var_image = [make_grid(torch.var(image[:, labels==i, ...], dim=1).cpu(), nrow=1) for i in unique_labels] 90 | var_image = make_grid(var_image, nrow=len(unique_labels)) 91 | 92 | # save mean and var image 93 | save_image(mean_image, os.path.join(params.output_dir, 'test_image_mean.png'), nrow=len(unique_labels)) 94 | save_image(var_image, os.path.join(params.output_dir, 'test_image_var.png'), nrow=len(unique_labels), normalize=True) 95 | 96 | # save accuracy 97 | with open(os.path.join(params.output_dir, 'eval_accuracy.txt'), 'w') as f: 98 | f.write('Mean evaluation accuracy {:.3f}'.format(avg_accuracy)) 99 | 100 | return avg_accuracy 101 | 102 | 103 | 104 | 105 | if __name__ == '__main__': 106 | args = parser.parse_args() 107 | 108 | # load params 109 | json_path = os.path.join(args.output_dir, 'params.json') 110 | assert os.path.isfile(json_path), 'No json 
configuration file found at {}'.format(json_path) 111 | params = utils.Params(json_path) 112 | 113 | # check output folder exist and if it is rel path 114 | if not os.path.isdir(params.output_dir): 115 | os.mkdir(params.output_dir) 116 | 117 | writer = utils.set_writer(params.output_dir) 118 | 119 | params.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda else 'cpu') 120 | 121 | # set random seed 122 | torch.manual_seed(11052018) 123 | if params.device.type is 'cuda': torch.cuda.manual_seed(11052018) 124 | 125 | # input 126 | dataloader = fetch_dataloader(params, train=False) 127 | 128 | # load model 129 | model = model.STN(getattr(model, params.stn_module), params).to(params.device) 130 | utils.load_checkpoint(args.restore_file, model) 131 | 132 | # run inference 133 | print('\nEvaluating with model:\n', model) 134 | print('\n.. and parameters:\n', pprint.pformat(params)) 135 | accuracy = evaluate(model, dataloader, writer, params) 136 | visualize_sample(model, dataloader.dataset, writer, params, None) 137 | print('Evaluation accuracy: {:.5f}'.format(accuracy)) 138 | 139 | writer.close() 140 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | from tqdm import tqdm 6 | import pprint 7 | 8 | import model 9 | from model import initialize 10 | from data_loader import fetch_dataloader 11 | from evaluate import evaluate, visualize_sample 12 | from vision_transforms import gen_random_perspective_transform, apply_transform_to_batch 13 | import utils 14 | 15 | 16 | parser = argparse.ArgumentParser(description='Train a model') 17 | parser.add_argument('--output_dir', help='Directory containing params.json and weights') 18 | parser.add_argument('--restore_file', help='Name of the file containing weights to load') 19 | parser.add_argument('--cuda', type=int, help='Which cuda device to use') 20 | 21 | 22 | def train_epoch(model, dataloader, loss_fn, optimizer, writer, params, epoch): 23 | model.train() 24 | 25 | loss_avg = utils.RunningAverage() 26 | loss_history = [] 27 | best_loss = float('inf') 28 | vis_counter = 0 29 | samples = {} 30 | lrs = [optimizer.param_groups[i]['lr'] for i in range(len(optimizer.param_groups))] 31 | 32 | with tqdm(total=len(dataloader), desc='epoch {} of {}. 
lr: [{:.0e}, {:.0e}]'.format(epoch + 1, params.n_epochs, *lrs)) as pbar: 33 | for i, (train_batch, labels_batch) in enumerate(dataloader): 34 | # move to gpu if available 35 | train_batch = train_batch.to(params.device) 36 | labels_batch = labels_batch.to(params.device) 37 | 38 | P = gen_random_perspective_transform(params) 39 | 40 | transformed_train_batch, scores = model(train_batch, P) 41 | 42 | loss = loss_fn(scores, labels_batch) 43 | 44 | optimizer.zero_grad() 45 | loss.backward() 46 | optimizer.step() 47 | 48 | 49 | # update trackers 50 | loss_avg.update(loss.item()) 51 | pbar.set_postfix(loss='{:.5f}'.format(loss_avg())) 52 | pbar.update() 53 | 54 | # write summary 55 | if i % params.save_summary_steps == 0: 56 | writer.add_scalar('loss', loss.item(), epoch*(i+1)) 57 | loss_history.append(loss.item()) 58 | 59 | return loss_history 60 | 61 | 62 | def train_and_evaluate(model, train_dataloader, val_dataloader, loss_fn, optimizer, scheduler, writer, params): 63 | 64 | best_loss = float('inf') 65 | start_epoch = 0 66 | 67 | if params.restore_file: 68 | print('Restoring parameters from {}'.format(params.restore_file)) 69 | start_epoch = utils.load_checkpoint(params.restore_file, model, optimizer, scheduler, best_loss) 70 | params.n_epochs += start_epoch - 1 71 | print('Resuming training from epoch {}'.format(start_epoch)) 72 | 73 | for epoch in range(start_epoch, params.n_epochs): 74 | scheduler.step() 75 | 76 | loss_history = train_epoch(model, train_dataloader, loss_fn, optimizer, writer, params, epoch) 77 | 78 | # snapshot at end of epoch 79 | is_best = sum(loss_history[:1000])/1000 < best_loss 80 | if is_best: best_loss = sum(loss_history[:1000])/1000 81 | utils.save_checkpoint({'epoch': epoch + 1, 82 | 'best_loss': best_loss, 83 | 'model_state_dict': model.state_dict(), 84 | 'optimizer_state_dict': optimizer.state_dict(), 85 | 'scheduler_state_dict': scheduler.state_dict()}, 86 | is_best=False, 87 | checkpoint=params.output_dir, 88 | quiet=True) 89 | 90 | # visualize 91 | visualize_sample(model, val_dataloader.dataset, writer, params, epoch+1) 92 | 93 | # evalutate and visualize 94 | val_accuracy = evaluate(model, val_dataloader, writer, params) 95 | 96 | # record val accuracy 97 | writer.add_scalar('val_accuracy', val_accuracy, epoch+1) 98 | 99 | 100 | if __name__ == '__main__': 101 | args = parser.parse_args() 102 | 103 | json_path = os.path.join(args.output_dir, 'params.json') 104 | assert os.path.isfile(json_path), 'No json configuration file found at {}'.format(json_path) 105 | params = utils.Params(json_path) 106 | 107 | params.restore_file = args.restore_file 108 | 109 | # check output folder exist and if it is rel path 110 | if not os.path.isdir(params.output_dir): 111 | os.mkdir(params.output_dir) 112 | 113 | writer = utils.set_writer(params.output_dir if args.restore_file is None else os.path.dirname(args.restore_file)) 114 | 115 | params.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda else 'cpu') 116 | 117 | # set random seed 118 | torch.manual_seed(11052018) 119 | if params.device.type is 'cuda': torch.cuda.manual_seed(11052018) 120 | 121 | # input 122 | train_dataloader = fetch_dataloader(params, train=True) 123 | val_dataloader = fetch_dataloader(params, train=False) 124 | 125 | # construct model 126 | # dims out (pytorch affine grid requires 2x3 matrix output; else perspective transform requires 8) 127 | model = model.STN(getattr(model, params.stn_module), params).to(params.device) 128 | # initialize 129 | 
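    # initialize() (see model.py) draws all weights from N(0, 0.1^2), then zeroes the transformer's final linear layer and copies the flattened identity into its bias, so the localization network starts out predicting the identity transform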
initialize(model) 130 | capacity = sum(p.numel() for p in model.parameters()) 131 | 132 | loss_fn = torch.nn.CrossEntropyLoss().to(params.device) 133 | optimizer = torch.optim.Adam([ 134 | {'params': model.transformer.parameters(), 'lr': params.transformer_lr}, 135 | {'params': model.clf.parameters(), 'lr': params.clf_lr}]) 136 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, params.lr_step, params.lr_gamma) 137 | 138 | # train and eval 139 | print('\nStarting training with model (capacity {}):\n'.format(capacity), model) 140 | print('\nParameters:\n', pprint.pformat(params)) 141 | train_and_evaluate(model, train_dataloader, val_dataloader, loss_fn, optimizer, scheduler, writer, params) 142 | 143 | writer.close() 144 | 145 | 146 | -------------------------------------------------------------------------------- /vision_transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def vec_to_perpective_matrix(vec): 9 | # vec rep of the perspective transform has 8 dof; so add 1 for the bottom right of the perspective matrix; 10 | # note network is initialized to transformer layer bias = [1, 0, 0, 0, 1, 0] so no need to add an identity matrix here 11 | out = torch.cat((vec, torch.ones((vec.shape[0],1), dtype=vec.dtype, device=vec.device)), dim=1).reshape(vec.shape[0], -1) 12 | return out.view(-1,3,3) 13 | 14 | 15 | def gen_random_perspective_transform(params): 16 | """ generate a batch of 3x3 homography matrices by composing rotation, translation, shear, and projection matrices, 17 | where each samples components from a uniform(-1,1) * multiplicative_factor 18 | """ 19 | 20 | batch_size = params.batch_size 21 | 22 | # debugging 23 | if params.dict.get('identity_transform_only'): 24 | return torch.eye(3).repeat(batch_size, 1, 1).to(params.device) 25 | 26 | 27 | I = torch.eye(3).repeat(batch_size, 1, 1) 28 | uniform = torch.distributions.Uniform(-1,1) 29 | factor = 0.25 30 | c = copy.deepcopy 31 | 32 | # rotation component 33 | a = math.pi / 6 * uniform.sample((batch_size,)) 34 | R = c(I) 35 | R[:, 0, 0] = torch.cos(a) 36 | R[:, 0, 1] = - torch.sin(a) 37 | R[:, 1, 0] = torch.sin(a) 38 | R[:, 1, 1] = torch.cos(a) 39 | R.to(params.device) 40 | 41 | # translation component 42 | tx = factor * uniform.sample((batch_size,)) 43 | ty = factor * uniform.sample((batch_size,)) 44 | T = c(I) 45 | T[:, 0, 2] = tx 46 | T[:, 1, 2] = ty 47 | T.to(params.device) 48 | 49 | # shear component 50 | sx = factor * uniform.sample((batch_size,)) 51 | sy = factor * uniform.sample((batch_size,)) 52 | A = c(I) 53 | A[:, 0, 1] = sx 54 | A[:, 1, 0] = sy 55 | A.to(params.device) 56 | 57 | # projective component 58 | px = uniform.sample((batch_size,)) 59 | py = uniform.sample((batch_size,)) 60 | P = c(I) 61 | P[:, 2, 0] = px 62 | P[:, 2, 1] = py 63 | P.to(params.device) 64 | 65 | # compose the homography 66 | H = R @ T @ P @ A 67 | 68 | return H 69 | 70 | 71 | def apply_transform_to_batch(im_batch_tensor, transform_tensor): 72 | """ apply a geometric transform to a batch of image tensors 73 | args 74 | im_batch_tensor -- torch float tensor of shape (N, C, H, W) 75 | transform_tensor -- torch float tensor of shape (1, 3, 3) 76 | 77 | returns 78 | transformed_batch_tensor -- torch float tensor of shape (N, C, H, W) 79 | """ 80 | N, C, H, W = im_batch_tensor.shape 81 | device = im_batch_tensor.device 82 | 83 | # torch.nn.functional.grid_sample takes a grid in [-1,1] and interpolates; 84 | 
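    # each sampling location (x, y) is lifted to homogeneous coordinates (x, y, 1), mapped through the 3x3 transform to (x', y', w), and divided by w below to recover the inhomogeneous grid point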
# construct grid in homogeneous coordinates 85 | x, y = torch.meshgrid([torch.linspace(-1, 1, H), torch.linspace(-1, 1, W)]) 86 | x, y = x.flatten(), y.flatten() 87 | xy_hom = torch.stack([x, y, torch.ones(x.shape[0])], dim=0).unsqueeze(0).to(device) 88 | 89 | # tansform the [-1,1] homogeneous coords 90 | xy_transformed = transform_tensor.matmul(xy_hom) # (N, 3, 3) matmul (N, 3, H*W) > (N, 3, H*W) 91 | # convert to inhomogeneous coords -- cf Szeliski eq. 2.21 92 | 93 | grid = xy_transformed[:,:2,:] / (xy_transformed[:,2,:].unsqueeze(1) + 1e-9) 94 | grid = grid.permute(0,2,1).reshape(-1, H, W, 2) # (N, H, W, 2); cf torch.functional.grid_sample 95 | grid = grid.expand(N, *grid.shape[1:]) # expand to minibatch 96 | 97 | transformed_batch = F.grid_sample(im_batch_tensor, grid, mode='bilinear') 98 | transformed_batch.transpose_(3,2) 99 | 100 | return transformed_batch 101 | 102 | 103 | 104 | 105 | # -------------------- 106 | # Test 107 | # -------------------- 108 | 109 | def test_get_random_perspective_transform(): 110 | import matplotlib 111 | matplotlib.use('TkAgg') 112 | import numpy as np 113 | import matplotlib.pyplot as plt 114 | from unittest.mock import Mock 115 | 116 | np.random.seed(6) 117 | 118 | im = np.zeros((30,30)) 119 | im[10:20,10:20] = 1 120 | im[20,20] = 1 121 | 122 | imt = np.array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 123 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 124 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 125 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 126 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 127 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 128 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 130 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 131 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 132 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 18, 133 | 18, 18, 126, 136, 175, 26, 166, 255, 247, 127, 0, 0, 0, 0], 134 | [ 0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253, 135 | 253, 253, 253, 253, 225, 172, 253, 242, 195, 64, 0, 0, 0, 0], 136 | [ 0, 0, 0, 0, 0, 0, 0, 49, 238, 253, 253, 253, 253, 253, 137 | 253, 253, 253, 251, 93, 82, 82, 56, 39, 0, 0, 0, 0, 0], 138 | [ 0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253, 253, 253, 139 | 198, 182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 140 | [ 0, 0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253, 205, 141 | 11, 0, 43, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 142 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 1, 154, 253, 90, 143 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 144 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 145 | 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 146 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 147 | 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 148 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 149 | 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 150 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 151 | 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0], 152 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 153 | 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0], 154 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 155 | 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0], 156 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 157 | 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0], 158 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 159 | 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0], 160 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 161 | 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0], 162 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 163 | 
253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0], 164 | [ 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 165 | 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 166 | [ 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195, 167 | 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 168 | [ 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 169 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 170 | [ 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 171 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 172 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 173 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 174 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 175 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 176 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 177 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) 178 | 179 | 180 | 181 | # get transform 182 | params = Mock() 183 | params.batch_size = 1 184 | params.dict = {'identity_transform_only': False} 185 | params.device = torch.device('cpu') 186 | H = gen_random_perspective_transform(params) 187 | 188 | im = im[np.newaxis, np.newaxis, ...] 189 | im = torch.FloatTensor(im) 190 | im_transformed = apply_transform_to_batch(im, H) 191 | 192 | imt = imt[np.newaxis, np.newaxis, ...] 193 | imt = torch.FloatTensor(imt) 194 | imt_transformed = apply_transform_to_batch(imt, H) 195 | 196 | fig, axs = plt.subplots(2,2) 197 | 198 | axs[0,0].imshow(im.squeeze().numpy(), cmap='gray') 199 | axs[0,1].imshow(im_transformed.squeeze().numpy(), cmap='gray') 200 | 201 | axs[1,0].imshow(imt.squeeze().numpy(), cmap='gray') 202 | axs[1,1].imshow(imt_transformed.squeeze().numpy(), cmap='gray') 203 | 204 | for ax in plt.gcf().axes: 205 | ax.axis('off') 206 | plt.tight_layout() 207 | plt.savefig('images/transform_test.png') 208 | plt.close() 209 | 210 | 211 | if __name__ == '__main__': 212 | test_get_random_perspective_transform() 213 | 214 | 215 | --------------------------------------------------------------------------------
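As a quick sanity check of how the pieces above fit together, here is a minimal smoke-test sketch (not part of the repository): it stubs the `Params` object with `unittest.mock.Mock`, the same trick `test_get_random_perspective_transform` uses, builds the homography ICSTN on CPU, and pushes one random MNIST-sized batch through the model.

```
import torch
from unittest.mock import Mock

import model
from vision_transforms import gen_random_perspective_transform

# stub the Params object with only the fields the modules below actually read
params = Mock()
params.batch_size = 4
params.icstn_steps = 4
params.dict = {'identity_transform_only': False}
params.device = torch.device('cpu')

# homography ICSTN with the transformer head initialized to the identity transform
net = model.STN(model.ICSTNModule, params)
model.initialize(net)

x = torch.rand(params.batch_size, 1, 28, 28)   # dummy MNIST-sized batch
P = gen_random_perspective_transform(params)   # (N, 3, 3) random homographies
aligned, scores = net(x, P)
print(aligned.shape, scores.shape)             # torch.Size([4, 1, 28, 28]) torch.Size([4, 10])
```

With the identity initialization from `initialize`, the first forward pass effectively applies only the random perturbation `P`, and training then learns to undo it.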