├── .gitignore
├── images
│   ├── transform_test.png
│   ├── basic_test_image_var.png
│   ├── icstn_test_image_var.png
│   ├── stn_test_image_mean.png
│   ├── stn_test_image_var.png
│   ├── basic_alignment_sample.png
│   ├── basic_test_image_mean.png
│   ├── icstn_test_image_mean.png
│   ├── stn_alignment_samples.png
│   ├── basic_alignment_samples.png
│   └── icstn_alignment_samples.png
├── requirements.txt
├── experiments
│   ├── base_stn_model
│   │   └── params.json
│   ├── base_icstn_model
│   │   └── params.json
│   └── base_basic_model
│       └── params.json
├── data_loader.py
├── search_hyperparams.py
├── readme.md
├── utils.py
├── model.py
├── evaluate.py
├── train.py
└── vision_transforms.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
3 | # exclude data
4 | data
5 |
6 | # virtual env
7 | .env
8 |
--------------------------------------------------------------------------------
/images/transform_test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/transform_test.png
--------------------------------------------------------------------------------
/images/basic_test_image_var.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/basic_test_image_var.png
--------------------------------------------------------------------------------
/images/icstn_test_image_var.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/icstn_test_image_var.png
--------------------------------------------------------------------------------
/images/stn_test_image_mean.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/stn_test_image_mean.png
--------------------------------------------------------------------------------
/images/stn_test_image_var.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/stn_test_image_var.png
--------------------------------------------------------------------------------
/images/basic_alignment_sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/basic_alignment_sample.png
--------------------------------------------------------------------------------
/images/basic_test_image_mean.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/basic_test_image_mean.png
--------------------------------------------------------------------------------
/images/icstn_test_image_mean.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/icstn_test_image_mean.png
--------------------------------------------------------------------------------
/images/stn_alignment_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/stn_alignment_samples.png
--------------------------------------------------------------------------------
/images/basic_alignment_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/basic_alignment_samples.png
--------------------------------------------------------------------------------
/images/icstn_alignment_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/spatial_transformer/HEAD/images/icstn_alignment_samples.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler==0.10.0
2 | kiwisolver==1.0.1
3 | matplotlib==3.0.1
4 | numpy==1.15.3
5 | Pillow==5.3.0
6 | protobuf==3.6.1
7 | pyparsing==2.3.0
8 | python-dateutil==2.7.5
9 | pytz==2018.7
10 | six==1.11.0
11 | tensorboardX==1.4
12 | torch==0.4.1
13 | torchvision==0.2.1
14 | tqdm==4.28.1
15 |
--------------------------------------------------------------------------------
/experiments/base_stn_model/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "stn_module": "STNModule",
3 | "icstn_steps": 4,
4 | "data_dir": "./data",
5 | "batch_size": 128,
6 | "transformer_lr": 1e-3,
7 | "clf_lr": 1e-3,
8 | "lr_step": 1,
9 | "lr_gamma": 1,
10 | "n_epochs": 10,
11 | "save_summary_steps": 10000,
12 | "mini_data": false
13 | }
14 |
15 |
--------------------------------------------------------------------------------
/experiments/base_icstn_model/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "stn_module": "ICSTNModule",
3 | "icstn_steps": 4,
4 | "data_dir": "./data",
5 | "batch_size": 128,
6 | "transformer_lr": 5e-4,
7 | "clf_lr": 1e-3,
8 | "lr_step": 1,
9 | "lr_gamma": 1,
10 | "n_epochs": 10,
11 | "save_summary_steps": 10000,
12 | "mini_data": false
13 | }
14 |
15 |
--------------------------------------------------------------------------------
/experiments/base_basic_model/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "stn_module": "BasicSTNModule",
3 | "icstn_steps": 4,
4 | "data_dir": "./data",
5 | "batch_size": 128,
6 | "transformer_lr": 1e-3,
7 | "clf_lr": 1e-3,
8 | "lr_step": 1,
9 | "lr_gamma": 1,
10 | "n_epochs": 10,
11 | "save_summary_steps": 10000,
12 | "mini_data": false
13 | }
14 |
15 |
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torchvision.datasets import MNIST
4 | import torchvision.transforms as T
5 |
6 |
7 | def fetch_dataloader(params, train=True, mini_size=128):
8 |
9 | # load the dataset and wrap it in a dataloader
10 | transforms = T.Compose([T.ToTensor()])
11 | dataset = MNIST(root=params.data_dir, train=train, download=False, transform=transforms)
12 |
13 | if params.dict.get('mini_data'):
14 | if train:
15 | dataset.train_data = dataset.train_data[:mini_size]
16 | dataset.train_labels = dataset.train_labels[:mini_size]
17 | else:
18 | dataset.test_data = dataset.test_data[:mini_size]
19 | dataset.test_labels = dataset.test_labels[:mini_size]
20 |
21 | if params.dict.get('mini_ones'):
22 | if train:
23 | labels = dataset.train_labels[:2000]
24 | mask = labels==1
25 | dataset.train_labels = labels[mask][:mini_size]
26 | dataset.train_data = dataset.train_data[:2000][mask][:mini_size]
27 | else:
28 | labels = dataset.test_labels[:2000]
29 | mask = labels==1
30 | dataset.test_labels = labels[mask][:mini_size]
31 | dataset.test_data = dataset.test_data[:2000][mask][:mini_size]
32 |
33 | kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() and params.device.type == 'cuda' else {}
34 |
35 | return DataLoader(dataset, batch_size=params.batch_size, shuffle=True, drop_last=True, **kwargs)
36 |
37 |
38 |
--------------------------------------------------------------------------------
/search_hyperparams.py:
--------------------------------------------------------------------------------
1 | """ Perform hyperparameter search """
2 |
3 | import os
4 | import sys
5 | import json
6 | import argparse
7 | from copy import deepcopy
8 | from subprocess import check_call
9 |
10 | import torch
11 | import utils
12 |
13 |
14 | PYTHON = sys.executable
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--parent_dir', default='experiments', help='Directory containing hyperparams.json to setup a model.')
18 | parser.add_argument('--data_dir', default='./data', help='Directory containing the dataset')
19 | parser.add_argument('--cuda', type=int, help='Which cuda device to use')
20 |
21 |
22 | def launch_training_job(parent_dir, data_dir, job_name, params):
23 | """ launch training of the model with a set of hyperparameters in parent_dir/job_name """
24 |
25 | # create a new folder in parent_dir with unique name 'job_name'
26 | output_dir = os.path.join(parent_dir, job_name)
27 | if not os.path.exists(output_dir):
28 | os.mkdir(output_dir)
29 |
30 | # write params in a json file
31 | json_path = os.path.join(output_dir, 'params.json')
32 | params.save(json_path)
33 |
34 | print('Launching training job with parameters:')
35 | print(params)
36 |
37 | # launch training with this config
38 | if params.device == 'cpu':
39 | cmd = '{python} train.py --output_dir={output_dir}'.format(
40 | python=PYTHON, output_dir=output_dir)
41 | else:
42 | cmd = '{python} train.py --output_dir={output_dir} --cuda={device}'.format(
43 | python=PYTHON, output_dir=output_dir, device=int(params.device.split(':')[1]))
44 |
45 |
46 | print(cmd)
47 |
48 | check_call(cmd, shell=True)
49 |
50 |
51 | if __name__ == '__main__':
52 | # load the reference parameters from the parent_dir json files
53 | args = parser.parse_args()
54 |
55 | json_path = os.path.join(args.parent_dir, 'hyperparams.json')
56 | assert os.path.isfile(json_path), 'No json configuration file found at {}'.format(json_path)
57 | hyperparams = utils.Params(json_path)
58 |
59 | json_path = os.path.join(args.parent_dir, 'base_params.json')
60 | assert os.path.isfile(json_path), 'No json configuration file found at {}'.format(json_path)
61 | base_params = utils.Params(json_path)
62 |
63 | # set the static parameters
64 | for param, values in hyperparams.dict.items():
65 | if isinstance(values, list):
66 | continue
67 | base_params.dict[param] = values
68 |
69 | base_params.device = 'cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu'
70 |
71 | # loop through the hyperparameter lists
72 | for param, values in hyperparams.dict.items():
73 | if isinstance(values, list):
74 | for v in values:
75 | params = deepcopy(base_params)
76 | # modify the parameter value to that in hyperparams
77 | params.dict[param] = v
78 |
79 | # launch job with unique name
80 | job_name = '{}_{}'.format(param, v)
81 | launch_training_job(args.parent_dir, args.data_dir, job_name, params)
82 |
83 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Spatial Transformer Networks
2 |
3 | Reimplementations of:
4 | * [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025)
5 | * [Inverse Compositional Spatial Transformer Networks](https://chenhsuanlin.bitbucket.io/inverse-compositional-STN/paper.pdf)
6 |
7 | Although implementations already exist, this one focuses on simplicity and
8 | ease of understanding of the vision transforms and the model.
9 |
10 | ## Results
11 |
12 | During training, random homography perturbations are applied to each image in the minibatch. The perturbations are composed of component transformations (rotation, translation, shear, projection), with the parameters of each sampled from a uniform(-1, 1) distribution scaled by a multiplicative factor of 0.25.
13 |
14 | Example homography perturbation:
15 | ![transform test](images/transform_test.png)
16 |
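A minimal sketch of how such a perturbation can be composed (this mirrors `gen_random_perspective_transform` in `vision_transforms.py`; the helper name `random_homography` is just for illustration, and note that in the repo the rotation angle is drawn from ±pi/6 and the projective terms are not scaled by the 0.25 factor):

```python
import math
import torch

def random_homography(batch_size, factor=0.25):
    # identity template for each component transform
    I = torch.eye(3).repeat(batch_size, 1, 1)
    u = torch.distributions.Uniform(-1, 1)

    # rotation by an angle in (-pi/6, pi/6)
    a = math.pi / 6 * u.sample((batch_size,))
    R = I.clone()
    R[:, 0, 0], R[:, 0, 1] = torch.cos(a), -torch.sin(a)
    R[:, 1, 0], R[:, 1, 1] = torch.sin(a), torch.cos(a)

    # translation and shear, scaled by the multiplicative factor
    T = I.clone()
    T[:, 0, 2], T[:, 1, 2] = factor * u.sample((batch_size,)), factor * u.sample((batch_size,))
    A = I.clone()
    A[:, 0, 1], A[:, 1, 0] = factor * u.sample((batch_size,)), factor * u.sample((batch_size,))

    # projective component (left unscaled in the repo implementation)
    P = I.clone()
    P[:, 2, 0], P[:, 2, 1] = u.sample((batch_size,)), u.sample((batch_size,))

    # composite homography, shape (batch_size, 3, 3)
    return R @ T @ P @ A
```
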
17 | ### Test set accuracy:
18 |
19 | | Model | Accuracy | Training params |
20 | | ----- | -------- | ----- |
21 | | Basic affine STN | 91.59% | 10 epochs at learning rate 1e-3 (classifier and transformer)|
22 | | Homography STN | 93.30% | 10 epochs at learning rate 1e-3 (classifier and transformer) |
23 | | Homography ICSTN | 97.67% | 10 epochs at learning rate 1e-3 (classifier) and 5e-4 (transformer) |
24 |
25 |
26 | ### Sample alignment results:
27 |
28 | #### Basic affine STN
29 |
30 | | Image | Samples |
31 | | --- | --- |
32 | | original <br> perturbed <br> transformed | ![basic alignment samples](images/basic_alignment_samples.png) |
33 |
34 | #### Homography STN
35 |
36 | | Image | Samples |
37 | | --- | --- |
38 | | original <br> perturbed <br> transformed | ![stn alignment samples](images/stn_alignment_samples.png) |
39 |
40 |
41 | #### Homography ICSTN
42 |
43 | | Image | Samples |
44 | | --- | --- |
45 | | original <br> perturbed <br> transformed | ![icstn alignment samples](images/icstn_alignment_samples.png) |
46 |
47 |
48 | ### Mean and variance of the aligned results (cf Lin ICSTN paper)
49 |
50 | #### Mean image
51 | | Image | Basic affine STN | Homography STN | Homography ICSTN |
52 | | --- | ---------------- | -------------- | ---------------- |
53 | | original <br> perturbed <br> transformed | ![basic mean](images/basic_test_image_mean.png) | ![stn mean](images/stn_test_image_mean.png) | ![icstn mean](images/icstn_test_image_mean.png) |
54 |
55 | #### Variance
56 | | Image | Basic affine STN | Homography STN | Homography ICSTN |
57 | | --- | ---------------- | -------------- | ---------------- |
58 | | original <br> perturbed <br> transformed | ![basic var](images/basic_test_image_var.png) | ![stn var](images/stn_test_image_var.png) | ![icstn var](images/icstn_test_image_var.png) |
59 |
60 |
61 | ## Usage
62 |
63 | To train model:
64 | ```
65 | python train.py --output_dir=[path to params.json]
66 | --restore_file=[path to .pt checkpoint if resuming training]
67 | --cuda=[cuda device id]
68 | ```
69 | `params.json` provides training parameters and specifies which spatial transformer module to use (a short selection sketch follows the list):
70 | 1. `BasicSTNModule` -- affine transform localization network
71 | 2. `STNModule` -- homography transform localization network
72 | 3. `ICSTNModule` -- homography transform localization network (cf Lin,
73 | ICSTN paper)
74 |
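As a rough sketch of how that selection works (this mirrors the model construction in `train.py` and `evaluate.py`):

```python
import model
import utils

# params.json is loaded into a Params object whose keys become attributes,
# e.g. params.stn_module == "ICSTNModule"
params = utils.Params('experiments/base_icstn_model/params.json')

# look up the transformer class by name and wrap it together with the classifier head
stn = model.STN(getattr(model, params.stn_module), params)
```
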
75 | To evaluate and visualize results:
76 | ```
77 | python evaluate.py --output_dir=[path to params.json]
78 | --restore_file=[path to .pt checkpoint]
79 | --cuda=[cuda device id]
80 | ```
81 |
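For example, to evaluate the base ICSTN configuration against the checkpoint that training writes into its output directory (`state_checkpoint.pt`, per `utils.save_checkpoint`), something like:
```
python evaluate.py --output_dir=experiments/base_icstn_model --restore_file=experiments/base_icstn_model/state_checkpoint.pt
```
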
82 | ## Dependencies
83 | * python 3.6
84 | * pytorch 0.4
85 | * torchvision
86 | * tensorboardX
87 | * numpy
88 | * matplotlib
89 | * tqdm
90 |
91 |
92 |
93 | ## Useful resources
94 | * https://github.com/chenhsuanlin/inverse-compositional-STN
95 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import json
4 | from datetime import datetime
5 | import torch
6 |
7 | from tensorboardX import SummaryWriter
8 |
9 |
10 |
11 |
12 | def set_writer(log_dir, comment=''):
13 | """ setup a tensorboardx summarywriter """
14 | # current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
15 | # log_dir = os.path.join(log_path, current_time + comment)
16 | writer = SummaryWriter(log_dir=log_dir)
17 | return writer
18 |
19 |
20 | def save_checkpoint(state, is_best, checkpoint, quiet=False):
21 | """ saves model and training params at checkpoint + 'last.pt'; if is_best also saves checkpoint + 'best.pt'
22 |
23 | args
24 | state -- dict; with keys model_state_dict, optimizer_state_dict, epoch, scheduler_state_dict, etc
25 | is_best -- bool; true if best model seen so far
26 | checkpoint -- str; folder where params are to be saved
27 | """
28 |
29 | filepath = os.path.join(checkpoint, 'state_checkpoint.pt')
30 | if not os.path.exists(checkpoint):
31 | if not quiet:
32 | print('Checkpoint directory does not exist. Making directory {}'.format(checkpoint))
33 | os.mkdir(checkpoint)
34 |
35 | torch.save(state, filepath)
36 |
37 | if is_best:
38 | shutil.copyfile(filepath, os.path.join(checkpoint, 'best_state_checkpoint.pt'))
39 |
40 | if not quiet:
41 | print('Checkpoint saved.')
42 |
43 |
44 | def load_checkpoint(checkpoint, model, optimizer=None, scheduler=None, best_metric=None):
45 | """ loads model state_dict from filepath; if optimizer and lr_scheduler provided also loads them
46 |
47 | args
48 | checkpoint -- string of filename
49 | model -- torch nn.Module model
50 | optimizer -- torch.optim instance to resume from checkpoint
51 | scheduler -- torch.optim.lr_scheduler instance to resume from checkpoint
52 | """
53 |
54 | if not os.path.exists(checkpoint):
55 | raise FileNotFoundError('File does not exist {}'.format(checkpoint))
56 |
57 | checkpoint = torch.load(checkpoint)
58 | model.load_state_dict(checkpoint['model_state_dict'])
59 |
60 | if optimizer:
61 | try:
62 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
63 | except KeyError:
64 | print('No optimizer state dict in checkpoint file')
65 |
66 | if best_metric:
67 | try:
68 | best_metric = checkpoint['best_val_acc']
69 | except KeyError:
70 | print('No best validation accuracy recorded in checkpoint file.')
71 |
72 | if scheduler:
73 | try:
74 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
75 | except KeyError:
76 | print('No lr scheduler state dict in checkpoint file')
77 |
78 | return checkpoint['epoch']
79 |
80 |
81 | # --------------------
82 | # Containers
83 | # --------------------
84 |
85 | class RunningAverage:
86 | """ a class to maintain the running average of a quantity
87 |
88 | example:
89 | ```
90 | loss_avg = RunningAverage()
91 | loss_avg.update(2)
92 | loss_avg.update(4)
93 | loss_avg()  # returns 3.0
94 | ```
95 | """
96 |
97 | def __init__(self):
98 | self.steps = 0
99 | self.total = 0
100 |
101 | def __call__(self):
102 | return self.total/float(self.steps)
103 |
104 | def update(self, val):
105 | self.steps += 1
106 | self.total += val
107 |
108 |
109 |
110 | class Params:
111 | """ class that loads hyperparams from json file.
112 |
113 | example:
114 | ```
115 | params = Params(json_path)
116 | print(params.learning_rate)
117 | params.learning_rate = 0.5
118 | ```
119 | """
120 |
121 | def __init__(self, json_path):
122 | with open(json_path, 'r') as f:
123 | params = json.load(f)
124 | self.__dict__.update(params)
125 | self.__dict__['output_dir'] = os.path.dirname(json_path)
126 |
127 | def save(self, json_path):
128 | with open(json_path, 'w') as f:
129 | json.dump(self.__dict__, f, indent=4)
130 |
131 | def update(self, json_path):
132 | """ loads params from json file """
133 | with open(json_path, 'r') as f:
134 | params = json.load(f)
135 | self.__dict__.update(params)
136 |
137 | @property
138 | def dict(self):
139 | """ gives dict-like access to Params instances by `params.dict['learning_rate']` """
140 | return self.__dict__
141 |
142 | def __repr__(self):
143 | out = ''
144 | for k, v in self.__dict__.items():
145 | out += k + ': ' + str(v) + '\n'
146 | return out
147 |
148 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from vision_transforms import apply_transform_to_batch, vec_to_perspective_matrix
6 |
7 |
8 | # --------------------
9 | # Model helpers
10 | # --------------------
11 |
12 | class Flatten(nn.Module):
13 | def forward(self, x):
14 | return x.view(x.shape[0],-1)
15 |
16 | def initialize(model, std=0.1):
17 | for p in model.parameters():
18 | p.data.normal_(0,std)
19 |
20 | # init last linear layer of the transformer at 0
21 | model.transformer.net[-1].weight.data.zero_()
22 | model.transformer.net[-1].bias.data.copy_(torch.eye(3).flatten()[:model.transformer.net[-1].out_features])
23 | # NOTE: initializing the last layer of the transformer to the identity transform here means apply_transform_to_batch should not
24 | # add an identity matrix when converting coordinates
25 |
26 |
27 | # --------------------
28 | # Model components
29 | # --------------------
30 |
31 | class BasicSTNModule(nn.Module):
32 | """ pytorch builtin affine transform """
33 | def __init__(self, params, out_dim=6):
34 | super().__init__()
35 | self.net = nn.Sequential(nn.Conv2d(1, 4, kernel_size=7), # (N, 1, 28, 28) > (N, 4, 22, 22)
36 | nn.ReLU(True),
37 | nn.Conv2d(4, 8, kernel_size=7), # (N, 4, 22, 22) > (N, 8, 16, 16)
38 | nn.MaxPool2d(2, stride=2), # (N, 8, 16, 16) > (N, 8, 8, 8)
39 | nn.ReLU(True),
40 | Flatten(),
41 | nn.Linear(8**3, 48),
42 | nn.ReLU(True),
43 | nn.Linear(48, out_dim))
44 |
45 | def forward(self, x, P_init):
46 | x = apply_transform_to_batch(x, P_init)
47 | theta = self.net(x).view(-1,2,3)
48 | grid = F.affine_grid(theta, x.size())
49 | return F.grid_sample(x, grid)
50 |
51 |
52 | class STNModule(BasicSTNModule):
53 | """ homography stn """
54 | def __init__(self, params, out_dim=8):
55 | super().__init__(params, out_dim)
56 |
57 | def forward(self, x, P_init):
58 | # apply the perturbation matrix to the minibatch of image tensors
59 | x = apply_transform_to_batch(x, P_init)
60 | # predict the transformation to approximate
61 | p = self.net(x)
62 | # convert to matrix
63 | P_net = vec_to_perspective_matrix(p)
64 | # apply to the original image
65 | return apply_transform_to_batch(x, P_net)
66 |
67 |
68 | class ICSTNModule(STNModule):
69 | """ inverse compositional stn cf Lin, Lucey ICSTN paper """
70 | def __init__(self, params):
71 | super().__init__(params)
72 | self.icstn_steps = params.icstn_steps
73 |
74 | def forward(self, x, P_init):
75 | P = P_init
76 | # apply spatial transform recurrently for n_steps
77 | for i in range(self.icstn_steps):
78 | # apply the perturbation matrix to the minibatch of image tensors
79 | transformed_x = apply_transform_to_batch(x, P)
80 | # predict the transform
81 | p = self.net(transformed_x)
82 | # convert to matrix
83 | P_net = vec_to_perspective_matrix(p)
84 | # compose transform with previous
85 | P = P @ P_net # compose on the left; apply_transform_to_batch takes the composite transform and right multiplies by xy_hom
86 | # apply the final composite transform to the original image
87 | return apply_transform_to_batch(x, P)
88 |
89 |
90 | class ClassifierModule(nn.Module):
91 | def __init__(self, out_dim=10):
92 | super().__init__()
93 | self.net = nn.Sequential(nn.Conv2d(1, 3, kernel_size=9), # (N, 1, 28, 28) > (N, 3, 20, 20)
94 | nn.ReLU(True),
95 | Flatten(),
96 | nn.Linear(3*20*20, out_dim))
97 |
98 | def forward(self, x):
99 | return self.net(x)
100 |
101 |
102 | # --------------------
103 | # Model
104 | # --------------------
105 |
106 | class STN(nn.Module):
107 | def __init__(self, transformer_module, params):
108 | super().__init__()
109 | self.transformer = transformer_module(params)
110 | self.clf = ClassifierModule()
111 |
112 | def forward(self, x, P):
113 | # take minibatch of image tensors x and geometric transform P
114 | x = self.transformer(x, P)
115 | # return the output of the transformer and the output of the classifier
116 | return x, self.clf(x)
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pprint
4 | from tqdm import tqdm
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torchvision.utils import make_grid, save_image
9 |
10 | import model
11 | from data_loader import fetch_dataloader
12 | from vision_transforms import gen_random_perspective_transform, apply_transform_to_batch
13 | import utils
14 |
15 |
16 | parser = argparse.ArgumentParser(description='Evaluate a model')
17 | parser.add_argument('--output_dir', help='Directory containing params.json and weights')
18 | parser.add_argument('--restore_file', help='Name of the file containing weights to load')
19 | parser.add_argument('--cuda', type=int, help='Which cuda device to use')
20 |
21 |
22 | @torch.no_grad()
23 | def visualize_sample(model, dataset, writer, params, step, n_samples=20):
24 | model.eval()
25 |
26 | sample = torch.stack([dataset[i][0] for i in range(n_samples)], dim=0).to(params.device)
27 |
28 | P = gen_random_perspective_transform(params)[:n_samples]
29 | perturbed_sample = apply_transform_to_batch(sample, P)
30 | transformed_sample, scores = model(sample, P)
31 |
32 | perturbed_sample = perturbed_sample.view(n_samples, 1, 28, 28)
33 | transformed_sample = transformed_sample.view(n_samples, 1, 28, 28)
34 |
35 | sample = torch.cat([sample, perturbed_sample, transformed_sample], dim=0)
36 | sample = make_grid(sample.cpu(), nrow=n_samples, normalize=True, padding=1, pad_value=1)
37 |
38 | if writer:
39 | writer.add_image('sample', sample, step)
40 |
41 | save_image(sample, os.path.join(params.output_dir, 'samples__orig_perturbed_transformed' + ('_step_{}'.format(step) if step is not None else '') + '.png'))
42 |
43 |
44 | @torch.no_grad()
45 | def evaluate(model, dataloader, writer, params):
46 | model.eval()
47 |
48 | # init trackers
49 | accuracy = []
50 | labels = []
51 | original = []
52 | perturbed = []
53 | transformed = []
54 |
55 | with tqdm(total=len(dataloader), desc='eval') as pbar:
56 | for i, (im_batch, labels_batch) in enumerate(dataloader):
57 | im_batch = im_batch.to(params.device)
58 |
59 | # get a random transformation and run through the batch
60 | P = gen_random_perspective_transform(params)
61 |
62 | transformed_batch, scores = model(im_batch, P)
63 | log_probs = F.log_softmax(scores, dim=1)
64 |
65 | # get predictions and calculate accuracy
66 | _, pred = torch.max(log_probs.cpu(), dim=1)
67 | accuracy.append(pred.eq(labels_batch.view_as(pred)).sum().item() / im_batch.shape[0])
68 |
69 |
70 | # record to compute mean image with variance for original, perturbed, and transformed image (cf Lin, Lucey ICSTN paper)
71 | labels.append(labels_batch)
72 | original.append(im_batch)
73 | perturbed.append(apply_transform_to_batch(im_batch, P))
74 | transformed.append(transformed_batch)
75 |
76 | avg_accuracy = sum(accuracy) / len(accuracy)
77 | pbar.set_postfix(accuracy='{:.5f}'.format(avg_accuracy))
78 | pbar.update()
79 |
80 | labels = torch.cat(labels, dim=0)
81 | unique_labels = torch.unique(labels, sorted=True)
82 | original = torch.cat(original, dim=0)
83 | perturbed = torch.cat(perturbed, dim=0)
84 | transformed = torch.cat(transformed, dim=0)
85 |
86 | # compute mean image with variance for original, perturbed, and transformed image for each digit (cf Lin, Lucey ICSTN paper)
87 | image = torch.stack([original, perturbed, transformed], dim=0) # (3, len(data), C, H, W)
88 | mean_image = [make_grid(torch.mean(image[:, labels==i, ...], dim=1).cpu(), nrow=1) for i in unique_labels]
89 | var_image = [make_grid(torch.var(image[:, labels==i, ...], dim=1).cpu(), nrow=1) for i in unique_labels]
90 | var_image = make_grid(var_image, nrow=len(unique_labels))
91 |
92 | # save mean and var image
93 | save_image(mean_image, os.path.join(params.output_dir, 'test_image_mean.png'), nrow=len(unique_labels))
94 | save_image(var_image, os.path.join(params.output_dir, 'test_image_var.png'), nrow=len(unique_labels), normalize=True)
95 |
96 | # save accuracy
97 | with open(os.path.join(params.output_dir, 'eval_accuracy.txt'), 'w') as f:
98 | f.write('Mean evaluation accuracy {:.3f}'.format(avg_accuracy))
99 |
100 | return avg_accuracy
101 |
102 |
103 |
104 |
105 | if __name__ == '__main__':
106 | args = parser.parse_args()
107 |
108 | # load params
109 | json_path = os.path.join(args.output_dir, 'params.json')
110 | assert os.path.isfile(json_path), 'No json configuration file found at {}'.format(json_path)
111 | params = utils.Params(json_path)
112 |
113 | # create the output folder if it does not exist
114 | if not os.path.isdir(params.output_dir):
115 | os.mkdir(params.output_dir)
116 |
117 | writer = utils.set_writer(params.output_dir)
118 |
119 | params.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu')
120 |
121 | # set random seed
122 | torch.manual_seed(11052018)
123 | if params.device.type == 'cuda': torch.cuda.manual_seed(11052018)
124 |
125 | # input
126 | dataloader = fetch_dataloader(params, train=False)
127 |
128 | # load model
129 | model = model.STN(getattr(model, params.stn_module), params).to(params.device)
130 | utils.load_checkpoint(args.restore_file, model)
131 |
132 | # run inference
133 | print('\nEvaluating with model:\n', model)
134 | print('\n.. and parameters:\n', pprint.pformat(params))
135 | accuracy = evaluate(model, dataloader, writer, params)
136 | visualize_sample(model, dataloader.dataset, writer, params, None)
137 | print('Evaluation accuracy: {:.5f}'.format(accuracy))
138 |
139 | writer.close()
140 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 | from tqdm import tqdm
6 | import pprint
7 |
8 | import model
9 | from model import initialize
10 | from data_loader import fetch_dataloader
11 | from evaluate import evaluate, visualize_sample
12 | from vision_transforms import gen_random_perspective_transform, apply_transform_to_batch
13 | import utils
14 |
15 |
16 | parser = argparse.ArgumentParser(description='Train a model')
17 | parser.add_argument('--output_dir', help='Directory containing params.json and weights')
18 | parser.add_argument('--restore_file', help='Name of the file containing weights to load')
19 | parser.add_argument('--cuda', type=int, help='Which cuda device to use')
20 |
21 |
22 | def train_epoch(model, dataloader, loss_fn, optimizer, writer, params, epoch):
23 | model.train()
24 |
25 | loss_avg = utils.RunningAverage()
26 | loss_history = []
27 | best_loss = float('inf')
28 | vis_counter = 0
29 | samples = {}
30 | lrs = [optimizer.param_groups[i]['lr'] for i in range(len(optimizer.param_groups))]
31 |
32 | with tqdm(total=len(dataloader), desc='epoch {} of {}. lr: [{:.0e}, {:.0e}]'.format(epoch + 1, params.n_epochs, *lrs)) as pbar:
33 | for i, (train_batch, labels_batch) in enumerate(dataloader):
34 | # move to gpu if available
35 | train_batch = train_batch.to(params.device)
36 | labels_batch = labels_batch.to(params.device)
37 |
38 | P = gen_random_perspective_transform(params)
39 |
40 | transformed_train_batch, scores = model(train_batch, P)
41 |
42 | loss = loss_fn(scores, labels_batch)
43 |
44 | optimizer.zero_grad()
45 | loss.backward()
46 | optimizer.step()
47 |
48 |
49 | # update trackers
50 | loss_avg.update(loss.item())
51 | pbar.set_postfix(loss='{:.5f}'.format(loss_avg()))
52 | pbar.update()
53 |
54 | # write summary
55 | if i % params.save_summary_steps == 0:
56 | writer.add_scalar('loss', loss.item(), epoch * len(dataloader) + i)
57 | loss_history.append(loss.item())
58 |
59 | return loss_history
60 |
61 |
62 | def train_and_evaluate(model, train_dataloader, val_dataloader, loss_fn, optimizer, scheduler, writer, params):
63 |
64 | best_loss = float('inf')
65 | start_epoch = 0
66 |
67 | if params.restore_file:
68 | print('Restoring parameters from {}'.format(params.restore_file))
69 | start_epoch = utils.load_checkpoint(params.restore_file, model, optimizer, scheduler, best_loss)
70 | params.n_epochs += start_epoch - 1
71 | print('Resuming training from epoch {}'.format(start_epoch))
72 |
73 | for epoch in range(start_epoch, params.n_epochs):
74 | scheduler.step()
75 |
76 | loss_history = train_epoch(model, train_dataloader, loss_fn, optimizer, writer, params, epoch)
77 |
78 | # snapshot at end of epoch
79 | is_best = sum(loss_history[:1000]) / max(len(loss_history[:1000]), 1) < best_loss
80 | if is_best: best_loss = sum(loss_history[:1000]) / max(len(loss_history[:1000]), 1)
81 | utils.save_checkpoint({'epoch': epoch + 1,
82 | 'best_loss': best_loss,
83 | 'model_state_dict': model.state_dict(),
84 | 'optimizer_state_dict': optimizer.state_dict(),
85 | 'scheduler_state_dict': scheduler.state_dict()},
86 | is_best=False,
87 | checkpoint=params.output_dir,
88 | quiet=True)
89 |
90 | # visualize
91 | visualize_sample(model, val_dataloader.dataset, writer, params, epoch+1)
92 |
93 | # evaluate on the validation set
94 | val_accuracy = evaluate(model, val_dataloader, writer, params)
95 |
96 | # record val accuracy
97 | writer.add_scalar('val_accuracy', val_accuracy, epoch+1)
98 |
99 |
100 | if __name__ == '__main__':
101 | args = parser.parse_args()
102 |
103 | json_path = os.path.join(args.output_dir, 'params.json')
104 | assert os.path.isfile(json_path), 'No json configuration file found at {}'.format(json_path)
105 | params = utils.Params(json_path)
106 |
107 | params.restore_file = args.restore_file
108 |
109 | # create the output folder if it does not exist
110 | if not os.path.isdir(params.output_dir):
111 | os.mkdir(params.output_dir)
112 |
113 | writer = utils.set_writer(params.output_dir if args.restore_file is None else os.path.dirname(args.restore_file))
114 |
115 | params.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu')
116 |
117 | # set random seed
118 | torch.manual_seed(11052018)
119 | if params.device.type == 'cuda': torch.cuda.manual_seed(11052018)
120 |
121 | # input
122 | train_dataloader = fetch_dataloader(params, train=True)
123 | val_dataloader = fetch_dataloader(params, train=False)
124 |
125 | # construct model
126 | # output dims: pytorch affine_grid requires a 2x3 matrix (6 params); the perspective transform requires 8 params
127 | model = model.STN(getattr(model, params.stn_module), params).to(params.device)
128 | # initialize
129 | initialize(model)
130 | capacity = sum(p.numel() for p in model.parameters())
131 |
132 | loss_fn = torch.nn.CrossEntropyLoss().to(params.device)
133 | optimizer = torch.optim.Adam([
134 | {'params': model.transformer.parameters(), 'lr': params.transformer_lr},
135 | {'params': model.clf.parameters(), 'lr': params.clf_lr}])
136 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, params.lr_step, params.lr_gamma)
137 |
138 | # train and eval
139 | print('\nStarting training with model (capacity {}):\n'.format(capacity), model)
140 | print('\nParameters:\n', pprint.pformat(params))
141 | train_and_evaluate(model, train_dataloader, val_dataloader, loss_fn, optimizer, scheduler, writer, params)
142 |
143 | writer.close()
144 |
145 |
146 |
--------------------------------------------------------------------------------
/vision_transforms.py:
--------------------------------------------------------------------------------
1 | import math
2 | import copy
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | def vec_to_perspective_matrix(vec):
9 | # vec rep of the perspective transform has 8 dof, so add 1 for the bottom right of the perspective matrix;
10 | # note the transformer's last layer is initialized to the flattened identity transform, so no need to add an identity matrix here
11 | out = torch.cat((vec, torch.ones((vec.shape[0],1), dtype=vec.dtype, device=vec.device)), dim=1).reshape(vec.shape[0], -1)
12 | return out.view(-1,3,3)
13 |
14 |
15 | def gen_random_perspective_transform(params):
16 | """ generate a batch of 3x3 homography matrices by composing rotation, translation, shear, and projection matrices,
17 | where each samples components from a uniform(-1,1) * multiplicative_factor
18 | """
19 |
20 | batch_size = params.batch_size
21 |
22 | # debugging
23 | if params.dict.get('identity_transform_only'):
24 | return torch.eye(3).repeat(batch_size, 1, 1).to(params.device)
25 |
26 |
27 | I = torch.eye(3).repeat(batch_size, 1, 1)
28 | uniform = torch.distributions.Uniform(-1,1)
29 | factor = 0.25
30 | c = copy.deepcopy
31 |
32 | # rotation component
33 | a = math.pi / 6 * uniform.sample((batch_size,))
34 | R = c(I)
35 | R[:, 0, 0] = torch.cos(a)
36 | R[:, 0, 1] = - torch.sin(a)
37 | R[:, 1, 0] = torch.sin(a)
38 | R[:, 1, 1] = torch.cos(a)
39 | R = R.to(params.device)
40 |
41 | # translation component
42 | tx = factor * uniform.sample((batch_size,))
43 | ty = factor * uniform.sample((batch_size,))
44 | T = c(I)
45 | T[:, 0, 2] = tx
46 | T[:, 1, 2] = ty
47 | T = T.to(params.device)
48 |
49 | # shear component
50 | sx = factor * uniform.sample((batch_size,))
51 | sy = factor * uniform.sample((batch_size,))
52 | A = c(I)
53 | A[:, 0, 1] = sx
54 | A[:, 1, 0] = sy
55 | A = A.to(params.device)
56 |
57 | # projective component
58 | px = uniform.sample((batch_size,))
59 | py = uniform.sample((batch_size,))
60 | P = c(I)
61 | P[:, 2, 0] = px
62 | P[:, 2, 1] = py
63 | P = P.to(params.device)
64 |
65 | # compose the homography
66 | H = R @ T @ P @ A
67 |
68 | return H
69 |
70 |
71 | def apply_transform_to_batch(im_batch_tensor, transform_tensor):
72 | """ apply a geometric transform to a batch of image tensors
73 | args
74 | im_batch_tensor -- torch float tensor of shape (N, C, H, W)
75 | transform_tensor -- torch float tensor of shape (N, 3, 3) (or (1, 3, 3), broadcast over the batch)
76 |
77 | returns
78 | transformed_batch_tensor -- torch float tensor of shape (N, C, H, W)
79 | """
80 | N, C, H, W = im_batch_tensor.shape
81 | device = im_batch_tensor.device
82 |
83 | # torch.nn.functional.grid_sample takes a grid in [-1,1] and interpolates;
84 | # construct grid in homogeneous coordinates
85 | x, y = torch.meshgrid([torch.linspace(-1, 1, H), torch.linspace(-1, 1, W)])
86 | x, y = x.flatten(), y.flatten()
87 | xy_hom = torch.stack([x, y, torch.ones(x.shape[0])], dim=0).unsqueeze(0).to(device)
88 |
89 | # transform the [-1,1] homogeneous coords
90 | xy_transformed = transform_tensor.matmul(xy_hom) # (N, 3, 3) matmul (1, 3, H*W) > (N, 3, H*W)
91 | # convert to inhomogeneous coords -- cf Szeliski eq. 2.21
92 |
93 | grid = xy_transformed[:,:2,:] / (xy_transformed[:,2,:].unsqueeze(1) + 1e-9)
94 | grid = grid.permute(0,2,1).reshape(-1, H, W, 2) # (N, H, W, 2); cf torch.functional.grid_sample
95 | grid = grid.expand(N, *grid.shape[1:]) # expand to minibatch
96 |
97 | transformed_batch = F.grid_sample(im_batch_tensor, grid, mode='bilinear')
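# note: torch.meshgrid above puts the x coordinate along the image height axis, so the sampled output
# comes back with its spatial dims swapped; the transpose below restores (H, W) order (relies on square inputs, H == W)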
98 | transformed_batch.transpose_(3,2)
99 |
100 | return transformed_batch
101 |
102 |
103 |
104 |
105 | # --------------------
106 | # Test
107 | # --------------------
108 |
109 | def test_get_random_perspective_transform():
110 | import matplotlib
111 | matplotlib.use('TkAgg')
112 | import numpy as np
113 | import matplotlib.pyplot as plt
114 | from unittest.mock import Mock
115 |
116 | np.random.seed(6)
117 |
118 | im = np.zeros((30,30))
119 | im[10:20,10:20] = 1
120 | im[20,20] = 1
121 |
122 | imt = np.array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
123 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
124 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
125 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
126 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
127 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
128 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
129 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
130 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
131 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
132 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 18,
133 | 18, 18, 126, 136, 175, 26, 166, 255, 247, 127, 0, 0, 0, 0],
134 | [ 0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253,
135 | 253, 253, 253, 253, 225, 172, 253, 242, 195, 64, 0, 0, 0, 0],
136 | [ 0, 0, 0, 0, 0, 0, 0, 49, 238, 253, 253, 253, 253, 253,
137 | 253, 253, 253, 251, 93, 82, 82, 56, 39, 0, 0, 0, 0, 0],
138 | [ 0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253, 253, 253,
139 | 198, 182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
140 | [ 0, 0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253, 205,
141 | 11, 0, 43, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
142 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 1, 154, 253, 90,
143 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
144 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190,
145 | 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
146 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253,
147 | 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
148 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241,
149 | 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
150 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81,
151 | 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0],
152 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
153 | 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0],
154 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
155 | 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0],
156 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
157 | 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0],
158 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
159 | 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0],
160 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148,
161 | 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0],
162 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253,
163 | 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0],
164 | [ 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253,
165 | 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
166 | [ 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195,
167 | 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
168 | [ 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11,
169 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
170 | [ 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0,
171 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
172 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
173 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
174 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
175 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
176 | [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
177 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
178 |
179 |
180 |
181 | # get transform
182 | params = Mock()
183 | params.batch_size = 1
184 | params.dict = {'identity_transform_only': False}
185 | params.device = torch.device('cpu')
186 | H = gen_random_perspective_transform(params)
187 |
188 | im = im[np.newaxis, np.newaxis, ...]
189 | im = torch.FloatTensor(im)
190 | im_transformed = apply_transform_to_batch(im, H)
191 |
192 | imt = imt[np.newaxis, np.newaxis, ...]
193 | imt = torch.FloatTensor(imt)
194 | imt_transformed = apply_transform_to_batch(imt, H)
195 |
196 | fig, axs = plt.subplots(2,2)
197 |
198 | axs[0,0].imshow(im.squeeze().numpy(), cmap='gray')
199 | axs[0,1].imshow(im_transformed.squeeze().numpy(), cmap='gray')
200 |
201 | axs[1,0].imshow(imt.squeeze().numpy(), cmap='gray')
202 | axs[1,1].imshow(imt_transformed.squeeze().numpy(), cmap='gray')
203 |
204 | for ax in plt.gcf().axes:
205 | ax.axis('off')
206 | plt.tight_layout()
207 | plt.savefig('images/transform_test.png')
208 | plt.close()
209 |
210 |
211 | if __name__ == '__main__':
212 | test_get_random_perspective_transform()
213 |
214 |
215 |
--------------------------------------------------------------------------------