├── .editorconfig
├── .gitignore
├── LICENSE
├── MNIST-pytorch
│   ├── data.py
│   ├── graph.py
│   ├── options.py
│   ├── train.py
│   ├── util.py
│   └── warp.py
├── MNIST-tensorflow
│   ├── data.py
│   ├── graph.py
│   ├── options.py
│   ├── train.py
│   ├── util.py
│   └── warp.py
├── README.md
└── traffic-sign-tensorflow
    ├── data.py
    ├── graph.py
    ├── options.py
    ├── train.py
    ├── util.py
    └── warp.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | end_of_line = lf
5 | insert_final_newline = true
6 | indent_style = tab
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 |
10 | [*.md]
11 | trim_trailing_whitespace = false
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Chen-Hsuan Lin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MNIST-pytorch/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.linalg
3 | import os,time
4 | import torch
5 | import torchvision
6 |
7 | import warp,util
8 |
9 | # load MNIST data
10 | def loadMNIST(opt,path):
11 | os.makedirs(path,exist_ok=True)
12 | trainDataset = torchvision.datasets.MNIST(path,train=True,download=True)
13 | testDataset = torchvision.datasets.MNIST(path,train=False,download=True)
14 | trainData,testData = {},{}
15 | trainData["image"] = torch.tensor([np.array(sample[0])/255.0 for sample in trainDataset],dtype=torch.float32)
16 | testData["image"] = torch.tensor([np.array(sample[0])/255.0 for sample in testDataset],dtype=torch.float32)
17 | trainData["label"] = torch.tensor([sample[1] for sample in trainDataset])
18 | testData["label"] = torch.tensor([sample[1] for sample in testDataset])
19 | return trainData,testData
20 |
21 | # generate random warp perturbations
22 | def genPerturbations(opt):
23 | X = np.tile(opt.canon4pts[:,0],[opt.batchSize,1])
24 | Y = np.tile(opt.canon4pts[:,1],[opt.batchSize,1])
25 | O = np.zeros([opt.batchSize,4],dtype=np.float32)
26 | I = np.ones([opt.batchSize,4],dtype=np.float32)
27 | dX = np.random.randn(opt.batchSize,4)*opt.pertScale \
28 | +np.random.randn(opt.batchSize,1)*opt.transScale
29 | dY = np.random.randn(opt.batchSize,4)*opt.pertScale \
30 | +np.random.randn(opt.batchSize,1)*opt.transScale
31 | dX,dY = dX.astype(np.float32),dY.astype(np.float32)
32 | # fit warp parameters to generated displacements
33 | if opt.warpType=="homography":
34 | A = np.concatenate([np.stack([X,Y,I,O,O,O,-X*(X+dX),-Y*(X+dX)],axis=-1),
35 | np.stack([O,O,O,X,Y,I,-X*(Y+dY),-Y*(Y+dY)],axis=-1)],axis=1)
36 | b = np.expand_dims(np.concatenate([X+dX,Y+dY],axis=1),axis=-1)
37 | pPert = np.matmul(np.linalg.inv(A),b).squeeze()
38 | pPert -= np.array([1,0,0,0,1,0,0,0])
39 | else:
40 | if opt.warpType=="translation":
41 | J = np.concatenate([np.stack([I,O],axis=-1),
42 | np.stack([O,I],axis=-1)],axis=1)
43 | if opt.warpType=="similarity":
44 | J = np.concatenate([np.stack([X,Y,I,O],axis=-1),
45 | np.stack([-Y,X,O,I],axis=-1)],axis=1)
46 | if opt.warpType=="affine":
47 | J = np.concatenate([np.stack([X,Y,I,O,O,O],axis=-1),
48 | np.stack([O,O,O,X,Y,I],axis=-1)],axis=1)
49 | dXY = np.expand_dims(np.concatenate([dX,dY],axis=1),axis=-1)
50 | Jtransp = np.transpose(J,axes=[0,2,1])
51 | pPert = np.matmul(np.linalg.inv(np.matmul(Jtransp,J)),np.matmul(Jtransp,dXY)).squeeze()
52 | pInit = torch.from_numpy(pPert).cuda()
53 | return pInit
54 |
55 | # make training batch
56 | def makeBatch(opt,data):
57 | N = len(data["image"])
58 | randIdx = np.random.randint(N,size=[opt.batchSize])
59 | batch = {
60 | "image": data["image"][randIdx].cuda(),
61 | "label": data["label"][randIdx].cuda(),
62 | }
63 | return batch
64 |
65 | # evaluation on test set
66 | def evalTest(opt,data,geometric,classifier):
67 | geometric.eval()
68 | classifier.eval()
69 | N = len(data["image"])
70 | batchN = int(np.ceil(N/opt.batchSize))
71 | warped = [{},{}]
72 | count = 0
73 | for b in range(batchN):
74 | # use some dummy data (0) as batch filler if necessary
75 | if b!=batchN-1:
76 | realIdx = np.arange(opt.batchSize*b,opt.batchSize*(b+1))
77 | else:
78 | realIdx = np.arange(opt.batchSize*b,N)
79 | idx = np.zeros([opt.batchSize],dtype=int)
80 | idx[:len(realIdx)] = realIdx
81 | # make training batch
82 | image = data["image"][idx].cuda()
83 | label = data["label"][idx].cuda()
84 | image.data.unsqueeze_(dim=1)
85 | # generate perturbation
86 | pInit = genPerturbations(opt)
87 | pInitMtrx = warp.vec2mtrx(opt,pInit)
88 | imagePert = warp.transformImage(opt,image,pInitMtrx)
89 | imageWarpAll = geometric(opt,image,pInit) if opt.netType=="IC-STN" else geometric(opt,imagePert)
90 | imageWarp = imageWarpAll[-1]
91 | output = classifier(opt,imageWarp)
92 | _,pred = output.max(dim=1)
93 | count += int((pred==label)[:len(realIdx)].sum().cpu().numpy()) # count only the real (non-filler) entries
94 | if opt.netType=="STN" or opt.netType=="IC-STN":
95 | imgPert = imagePert.detach().cpu().numpy()
96 | imgWarp = imageWarp.detach().cpu().numpy()
97 | for i in range(len(realIdx)):
98 | l = data["label"][idx[i]].item()
99 | if l not in warped[0]: warped[0][l] = []
100 | if l not in warped[1]: warped[1][l] = []
101 | warped[0][l].append(imgPert[i])
102 | warped[1][l].append(imgWarp[i])
103 | accuracy = float(count)/N
104 | if opt.netType=="STN" or opt.netType=="IC-STN":
105 | mean = [np.array([np.mean(warped[0][l],axis=0) for l in warped[0]]),
106 | np.array([np.mean(warped[1][l],axis=0) for l in warped[1]])]
107 | var = [np.array([np.var(warped[0][l],axis=0) for l in warped[0]]),
108 | np.array([np.var(warped[1][l],axis=0) for l in warped[1]])]
109 | else: mean,var = None,None
110 | geometric.train()
111 | classifier.train()
112 | return accuracy,mean,var
113 |
--------------------------------------------------------------------------------
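
For reference, a minimal standalone sketch (not a file of this repository) of the least-squares fit that genPerturbations performs in its non-homography branch, written out for the affine case with constants mirroring the defaults in options.py:

import numpy as np

batchSize,pertScale,transScale = 2,0.25,0.25
canon4pts = np.array([[-1,-1],[-1,1],[1,1],[1,-1]],dtype=np.float32)
X = np.tile(canon4pts[:,0],[batchSize,1])                     # canonical corner x-coordinates
Y = np.tile(canon4pts[:,1],[batchSize,1])                     # canonical corner y-coordinates
O,I = np.zeros([batchSize,4],np.float32),np.ones([batchSize,4],np.float32)
# random per-corner jitter plus a shared random translation
dX = np.random.randn(batchSize,4)*pertScale+np.random.randn(batchSize,1)*transScale
dY = np.random.randn(batchSize,4)*pertScale+np.random.randn(batchSize,1)*transScale
# Jacobian of the affine warp w.r.t. its 6 parameters, stacked over the 4 corners
J = np.concatenate([np.stack([X,Y,I,O,O,O],axis=-1),
                    np.stack([O,O,O,X,Y,I],axis=-1)],axis=1)  # (batchSize,8,6)
dXY = np.concatenate([dX,dY],axis=1)[...,None]                # (batchSize,8,1)
# normal equations p = (J^T J)^-1 J^T dXY, one solve per batch element
Jt = np.transpose(J,axes=[0,2,1])
pPert = np.linalg.solve(Jt@J,Jt@dXY).squeeze(-1)              # (batchSize,6)
print(pPert.shape)

genPerturbations itself uses np.linalg.inv followed by np.matmul, which is algebraically the same solve.
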
/MNIST-pytorch/graph.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import time
4 | import data,warp,util
5 |
6 | # build classification network (full CNN, used when no geometric predictor is learned)
7 | class FullCNN(torch.nn.Module):
8 | def __init__(self,opt):
9 | super(FullCNN,self).__init__()
10 | self.inDim = 1
11 | def conv2Layer(outDim):
12 | conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[3,3],stride=1,padding=0)
13 | self.inDim = outDim
14 | return conv
15 | def linearLayer(outDim):
16 | fc = torch.nn.Linear(self.inDim,outDim)
17 | self.inDim = outDim
18 | return fc
19 | def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
20 | self.conv2Layers = torch.nn.Sequential(
21 | conv2Layer(3),torch.nn.ReLU(True),
22 | conv2Layer(6),torch.nn.ReLU(True),maxpoolLayer(),
23 | conv2Layer(9),torch.nn.ReLU(True),
24 | conv2Layer(12),torch.nn.ReLU(True)
25 | )
26 | self.inDim *= 8**2
27 | self.linearLayers = torch.nn.Sequential(
28 | linearLayer(48),torch.nn.ReLU(True),
29 | linearLayer(opt.labelN)
30 | )
31 | initialize(opt,self,opt.stdC)
32 | def forward(self,opt,image):
33 | feat = image
34 | feat = self.conv2Layers(feat).reshape(opt.batchSize,-1)
35 | feat = self.linearLayers(feat)
36 | output = feat
37 | return output
38 |
39 | # build classification network (smaller CNN, used together with STN/IC-STN)
40 | class CNN(torch.nn.Module):
41 | def __init__(self,opt):
42 | super(CNN,self).__init__()
43 | self.inDim = 1
44 | def conv2Layer(outDim):
45 | conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[9,9],stride=1,padding=0)
46 | self.inDim = outDim
47 | return conv
48 | def linearLayer(outDim):
49 | fc = torch.nn.Linear(self.inDim,outDim)
50 | self.inDim = outDim
51 | return fc
52 | def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
53 | self.conv2Layers = torch.nn.Sequential(
54 | conv2Layer(3),torch.nn.ReLU(True)
55 | )
56 | self.inDim *= 20**2
57 | self.linearLayers = torch.nn.Sequential(
58 | linearLayer(opt.labelN)
59 | )
60 | initialize(opt,self,opt.stdC)
61 | def forward(self,opt,image):
62 | feat = image
63 | feat = self.conv2Layers(feat).reshape(opt.batchSize,-1)
64 | feat = self.linearLayers(feat)
65 | output = feat
66 | return output
67 |
68 | # an identity class to skip geometric predictors
69 | class Identity(torch.nn.Module):
70 | def __init__(self): super(Identity,self).__init__()
71 | def forward(self,opt,feat): return [feat]
72 |
73 | # build Spatial Transformer Network
74 | class STN(torch.nn.Module):
75 | def __init__(self,opt):
76 | super(STN,self).__init__()
77 | self.inDim = 1
78 | def conv2Layer(outDim):
79 | conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[7,7],stride=1,padding=0)
80 | self.inDim = outDim
81 | return conv
82 | def linearLayer(outDim):
83 | fc = torch.nn.Linear(self.inDim,outDim)
84 | self.inDim = outDim
85 | return fc
86 | def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
87 | self.conv2Layers = torch.nn.Sequential(
88 | conv2Layer(4),torch.nn.ReLU(True),
89 | conv2Layer(8),torch.nn.ReLU(True),maxpoolLayer()
90 | )
91 | self.inDim *= 8**2
92 | self.linearLayers = torch.nn.Sequential(
93 | linearLayer(48),torch.nn.ReLU(True),
94 | linearLayer(opt.warpDim)
95 | )
96 | initialize(opt,self,opt.stdGP,last0=True)
97 | def forward(self,opt,image):
98 | imageWarpAll = [image]
99 | feat = image
100 | feat = self.conv2Layers(feat).reshape(opt.batchSize,-1)
101 | feat = self.linearLayers(feat)
102 | p = feat
103 | pMtrx = warp.vec2mtrx(opt,p)
104 | imageWarp = warp.transformImage(opt,image,pMtrx)
105 | imageWarpAll.append(imageWarp)
106 | return imageWarpAll
107 |
108 | # build Inverse Compositional STN
109 | class ICSTN(torch.nn.Module):
110 | def __init__(self,opt):
111 | super(ICSTN,self).__init__()
112 | self.inDim = 1
113 | def conv2Layer(outDim):
114 | conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[7,7],stride=1,padding=0)
115 | self.inDim = outDim
116 | return conv
117 | def linearLayer(outDim):
118 | fc = torch.nn.Linear(self.inDim,outDim)
119 | self.inDim = outDim
120 | return fc
121 | def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2)
122 | self.conv2Layers = torch.nn.Sequential(
123 | conv2Layer(4),torch.nn.ReLU(True),
124 | conv2Layer(8),torch.nn.ReLU(True),maxpoolLayer()
125 | )
126 | self.inDim *= 8**2
127 | self.linearLayers = torch.nn.Sequential(
128 | linearLayer(48),torch.nn.ReLU(True),
129 | linearLayer(opt.warpDim)
130 | )
131 | initialize(opt,self,opt.stdGP,last0=True)
132 | def forward(self,opt,image,p):
133 | imageWarpAll = []
134 | for l in range(opt.warpN):
135 | pMtrx = warp.vec2mtrx(opt,p)
136 | imageWarp = warp.transformImage(opt,image,pMtrx)
137 | imageWarpAll.append(imageWarp)
138 | feat = imageWarp
139 | feat = self.conv2Layers(feat).reshape(opt.batchSize,-1)
140 | feat = self.linearLayers(feat)
141 | dp = feat
142 | p = warp.compose(opt,p,dp)
143 | pMtrx = warp.vec2mtrx(opt,p)
144 | imageWarp = warp.transformImage(opt,image,pMtrx)
145 | imageWarpAll.append(imageWarp)
146 | return imageWarpAll
147 |
148 | # initialize weights/biases
149 | def initialize(opt,model,stddev,last0=False):
150 | for m in model.conv2Layers:
151 | if isinstance(m,torch.nn.Conv2d):
152 | m.weight.data.normal_(0,stddev)
153 | m.bias.data.normal_(0,stddev)
154 | for m in model.linearLayers:
155 | if isinstance(m,torch.nn.Linear):
156 | if last0 and m is model.linearLayers[-1]:
157 | m.weight.data.zero_()
158 | m.bias.data.zero_()
159 | else:
160 | m.weight.data.normal_(0,stddev)
161 | m.bias.data.normal_(0,stddev)
162 |
--------------------------------------------------------------------------------
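
The ICSTN module above implements the inverse-compositional recurrence; written out as equations (a sketch, with M(.) denoting warp.vec2mtrx and W(I;p) denoting warp.transformImage):

\Delta p_k = f_\theta\big(\mathcal{W}(I;\,p_k)\big), \qquad
M(p_{k+1}) \propto M(\Delta p_k)\,M(p_k) \quad (\text{normalized so } [M]_{3,3}=1), \qquad k = 0,\dots,\texttt{warpN}-1

Here f_\theta is the shared convolutional/linear predictor, p_0 is the input perturbation pInit, and only the final warp W(I;p_warpN) reaches the classifier; every iteration resamples the original image, so resampling error is not accumulated across warps.
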
/MNIST-pytorch/options.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import warp
4 | import util
5 | import torch
6 |
7 | def set(training):
8 |
9 | # parse input arguments
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument("netType", choices=["CNN","STN","IC-STN"], help="type of network")
12 | parser.add_argument("--group", default="0", help="name for group")
13 | parser.add_argument("--model", default="test", help="name for model instance")
14 | parser.add_argument("--size", default="28x28", help="image resolution")
15 | parser.add_argument("--warpType", default="homography", help="type of warp function on images",
16 | choices=["translation","similarity","affine","homography"])
17 | parser.add_argument("--warpN", type=int, default=4, help="number of recurrent transformations (for IC-STN)")
18 | parser.add_argument("--stdC", type=float, default=0.1, help="initialization stddev (classification network)")
19 | parser.add_argument("--stdGP", type=float, default=0.1, help="initialization stddev (geometric predictor)")
20 | parser.add_argument("--pertScale", type=float, default=0.25, help="initial perturbation scale")
21 | parser.add_argument("--transScale", type=float, default=0.25, help="initial translation scale")
22 | if training: # training
23 | parser.add_argument("--port", type=int, default=8097, help="port number for visdom visualization")
24 | parser.add_argument("--batchSize", type=int, default=100, help="batch size for SGD")
25 | parser.add_argument("--lrC", type=float, default=1e-2, help="learning rate (classification network)")
26 | parser.add_argument("--lrGP", type=float, default=None, help="learning rate (geometric predictor)")
27 | parser.add_argument("--lrDecay", type=float, default=1.0, help="learning rate decay")
28 | parser.add_argument("--lrStep", type=int, default=100000, help="learning rate decay step size")
29 | parser.add_argument("--fromIt", type=int, default=0, help="resume training from iteration number")
30 | parser.add_argument("--toIt", type=int, default=500000, help="run training to iteration number")
31 | else: # evaluation
32 | parser.add_argument("--batchSize", type=int, default=1, help="batch size for evaluation")
33 | opt = parser.parse_args()
34 |
35 | if opt.lrGP is None: opt.lrGP = 0 if opt.netType=="CNN" else \
36 | 1e-2 if opt.netType=="STN" else \
37 | 1e-4 if opt.netType=="IC-STN" else None
38 |
39 | # --- below are automatically set ---
40 | assert(torch.cuda.is_available()) # support only training on GPU for now
41 | torch.set_default_tensor_type("torch.cuda.FloatTensor")
42 | opt.training = training
43 | opt.H,opt.W = [int(x) for x in opt.size.split("x")]
44 | opt.visBlockSize = int(np.floor(np.sqrt(opt.batchSize)))
45 | opt.warpDim = 2 if opt.warpType == "translation" else \
46 | 4 if opt.warpType == "similarity" else \
47 | 6 if opt.warpType == "affine" else \
48 | 8 if opt.warpType == "homography" else None
49 | opt.labelN = 10
50 | opt.canon4pts = np.array([[-1,-1],[-1,1],[1,1],[1,-1]],dtype=np.float32)
51 | opt.image4pts = np.array([[0,0],[0,opt.H-1],[opt.W-1,opt.H-1],[opt.W-1,0]],dtype=np.float32)
52 | opt.refMtrx = np.eye(3).astype(np.float32)
53 | if opt.netType=="STN": opt.warpN = 1
54 |
55 | print("({0}) {1}".format(
56 | util.toGreen("{0}".format(opt.group)),
57 | util.toGreen("{0}".format(opt.model))))
58 | print("------------------------------------------")
59 | print("network type: {0}, recurrent warps: {1}".format(
60 | util.toYellow("{0}".format(opt.netType)),
61 | util.toYellow("{0}".format(opt.warpN if opt.netType=="IC-STN" else "X"))))
62 | print("batch size: {0}, image size: {1}x{2}".format(
63 | util.toYellow("{0}".format(opt.batchSize)),
64 | util.toYellow("{0}".format(opt.H)),
65 | util.toYellow("{0}".format(opt.W))))
66 | print("warpScale: (pert) {0} (trans) {1}".format(
67 | util.toYellow("{0}".format(opt.pertScale)),
68 | util.toYellow("{0}".format(opt.transScale))))
69 | if training:
70 | print("[geometric predictor] stddev={0}, lr={1}".format(
71 | util.toYellow("{0:.0e}".format(opt.stdGP)),
72 | util.toYellow("{0:.0e}".format(opt.lrGP))))
73 | print("[classification network] stddev={0}, lr={1}".format(
74 | util.toYellow("{0:.0e}".format(opt.stdC)),
75 | util.toYellow("{0:.0e}".format(opt.lrC))))
76 | print("------------------------------------------")
77 | if training:
78 | print(util.toMagenta("training model ({0}) {1}...".format(opt.group,opt.model)))
79 |
80 | return opt
81 |
--------------------------------------------------------------------------------
/MNIST-pytorch/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time,os,sys
3 | import argparse
4 | import util
5 |
6 | print(util.toYellow("======================================================="))
7 | print(util.toYellow("train.py (training on MNIST)"))
8 | print(util.toYellow("======================================================="))
9 |
10 | import torch
11 | import data,graph,warp,util
12 | import options
13 |
14 | print(util.toMagenta("setting configurations..."))
15 | opt = options.set(training=True)
16 |
17 | # create directories for model output
18 | util.mkdir("models_{0}".format(opt.group))
19 |
20 | print(util.toMagenta("building network..."))
21 | with torch.cuda.device(0):
22 | # ------ build network ------
23 | if opt.netType=="CNN":
24 | geometric = graph.Identity()
25 | classifier = graph.FullCNN(opt)
26 | elif opt.netType=="STN":
27 | geometric = graph.STN(opt)
28 | classifier = graph.CNN(opt)
29 | elif opt.netType=="IC-STN":
30 | geometric = graph.ICSTN(opt)
31 | classifier = graph.CNN(opt)
32 | # ------ define loss ------
33 | loss = torch.nn.CrossEntropyLoss()
34 | # ------ optimizer ------
35 | optimList = [{ "params": geometric.parameters(), "lr": opt.lrGP },
36 | { "params": classifier.parameters(), "lr": opt.lrC }]
37 | optim = torch.optim.SGD(optimList)
38 |
39 | # load data
40 | print(util.toMagenta("loading MNIST dataset..."))
41 | trainData,testData = data.loadMNIST(opt,"data")
42 |
43 | # visdom visualizer
44 | vis = util.Visdom(opt)
45 |
46 | print(util.toYellow("======= TRAINING START ======="))
47 | timeStart = time.time()
48 | # start session
49 | with torch.cuda.device(0):
50 | geometric.train()
51 | classifier.train()
52 | if opt.fromIt!=0:
53 | util.restoreModel(opt,geometric,classifier,opt.fromIt)
54 | print(util.toMagenta("resuming from iteration {0}...".format(opt.fromIt)))
55 | print(util.toMagenta("start training..."))
56 |
57 | # training loop
58 | for i in range(opt.fromIt,opt.toIt):
59 | lrGP,lrC = opt.lrGP*opt.lrDecay**(i//opt.lrStep),opt.lrC*opt.lrDecay**(i//opt.lrStep)
60 | optim.param_groups[0]["lr"],optim.param_groups[1]["lr"] = lrGP,lrC # apply decayed learning rates (geometric predictor, classifier)
61 | # make training batch
62 | batch = data.makeBatch(opt,trainData)
63 | image = batch["image"].unsqueeze(dim=1)
64 | label = batch["label"]
65 | # generate perturbation
66 | pInit = data.genPerturbations(opt)
67 | pInitMtrx = warp.vec2mtrx(opt,pInit)
68 | # forward/backprop through network
69 | optim.zero_grad()
70 | imagePert = warp.transformImage(opt,image,pInitMtrx)
71 | imageWarpAll = geometric(opt,image,pInit) if opt.netType=="IC-STN" else geometric(opt,imagePert)
72 | imageWarp = imageWarpAll[-1]
73 | output = classifier(opt,imageWarp)
74 | train_loss = loss(output,label)
75 | train_loss.backward()
76 | # run one step
77 | optim.step()
78 | if (i+1)%100==0:
79 | print("it. {0}/{1} lr={3}(GP),{4}(C), loss={5}, time={2}"
80 | .format(util.toCyan("{0}".format(i+1)),
81 | opt.toIt,
82 | util.toGreen("{0:.2f}".format(time.time()-timeStart)),
83 | util.toYellow("{0:.0e}".format(lrGP)),
84 | util.toYellow("{0:.0e}".format(lrC)),
85 | util.toRed("{0:.4f}".format(train_loss))))
86 | if (i+1)%200==0: vis.trainLoss(opt,i+1,train_loss)
87 | if (i+1)%1000==0:
88 | # evaluate on test set
89 | testAcc,testMean,testVar = data.evalTest(opt,testData,geometric,classifier)
90 | testError = (1-testAcc)*100
91 | vis.testLoss(opt,i+1,testError)
92 | if opt.netType=="STN" or opt.netType=="IC-STN":
93 | vis.meanVar(opt,testMean,testVar)
94 | if (i+1)%10000==0:
95 | util.saveModel(opt,geometric,classifier,i+1)
96 | print(util.toGreen("model saved: {0}/{1}, it.{2}".format(opt.group,opt.model,i+1)))
97 |
98 | print(util.toYellow("======= TRAINING DONE ======="))
99 |
--------------------------------------------------------------------------------
/MNIST-pytorch/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.misc
3 | import torch
4 | import os
5 | import termcolor
6 | import visdom
7 |
8 | def mkdir(path):
9 | if not os.path.exists(path): os.mkdir(path)
10 | def imread(fname):
11 | return scipy.misc.imread(fname)/255.0
12 | def imsave(fname,array):
13 | scipy.misc.toimage(array,cmin=0.0,cmax=1.0).save(fname)
14 |
15 | # convert to colored strings
16 | def toRed(content): return termcolor.colored(content,"red",attrs=["bold"])
17 | def toGreen(content): return termcolor.colored(content,"green",attrs=["bold"])
18 | def toBlue(content): return termcolor.colored(content,"blue",attrs=["bold"])
19 | def toCyan(content): return termcolor.colored(content,"cyan",attrs=["bold"])
20 | def toYellow(content): return termcolor.colored(content,"yellow",attrs=["bold"])
21 | def toMagenta(content): return termcolor.colored(content,"magenta",attrs=["bold"])
22 |
23 | # restore model
24 | def restoreModel(opt,geometric,classifier,it):
25 | geometric.load_state_dict(torch.load("models_{0}/{1}_it{2}_GP.npy".format(opt.group,opt.model,it)))
26 | classifier.load_state_dict(torch.load("models_{0}/{1}_it{2}_C.npy".format(opt.group,opt.model,it)))
27 | # save model
28 | def saveModel(opt,geometric,classifier,it):
29 | torch.save(geometric.state_dict(),"models_{0}/{1}_it{2}_GP.npy".format(opt.group,opt.model,it))
30 | torch.save(classifier.state_dict(),"models_{0}/{1}_it{2}_C.npy".format(opt.group,opt.model,it))
31 |
32 | class Visdom():
33 | def __init__(self,opt):
34 | self.vis = visdom.Visdom(port=opt.port,use_incoming_socket=False)
35 | self.trainLossInit = True
36 | self.testLossInit = True
37 | self.meanVarInit = True
38 | def tileImages(self,opt,images,H,W,HN,WN):
39 | assert(len(images)==HN*WN)
40 | images = images.reshape([HN,WN,-1,H,W])
41 | images = [list(i) for i in images]
42 | imageBlocks = np.concatenate([np.concatenate(row,axis=2) for row in images],axis=1)
43 | return imageBlocks
44 | def trainLoss(self,opt,it,loss):
45 | loss = float(loss.detach().cpu().numpy())
46 | if self.trainLossInit:
47 | self.vis.line(Y=np.array([loss]),X=np.array([it]),win="{0}_trainloss".format(opt.model),
48 | opts={ "title": "{0} (TRAIN_loss)".format(opt.model) })
49 | self.trainLossInit = False
50 | else: self.vis.line(Y=np.array([loss]),X=np.array([it]),win=opt.model+"_trainloss",update="append")
51 | def testLoss(self,opt,it,loss):
52 | if self.testLossInit:
53 | self.vis.line(Y=np.array([loss]),X=np.array([it]),win="{0}_testloss".format(opt.model),
54 | opts={ "title": "{0} (TEST_error)".format(opt.model) })
55 | self.testLossInit = False
56 | else: self.vis.line(Y=np.array([loss]),X=np.array([it]),win=opt.model+"_testloss",update="append")
57 | def meanVar(self,opt,mean,var):
58 | mean = [self.tileImages(opt,m,opt.H,opt.W,1,10) for m in mean]
59 | var = [self.tileImages(opt,v,opt.H,opt.W,1,10)*3 for v in var]
60 | self.vis.image(mean[0].clip(0,1),win="{0}_meaninit".format(opt.model), opts={ "title": "{0} (TEST_mean_init)".format(opt.model) })
61 | self.vis.image(mean[1].clip(0,1),win="{0}_meanwarped".format(opt.model), opts={ "title": "{0} (TEST_mean_warped)".format(opt.model) })
62 | self.vis.image(var[0].clip(0,1),win="{0}_varinit".format(opt.model), opts={ "title": "{0} (TEST_var_init)".format(opt.model) })
63 | self.vis.image(var[1].clip(0,1),win="{0}_varwarped".format(opt.model), opts={ "title": "{0} (TEST_var_warped)".format(opt.model) })
64 |
65 |
--------------------------------------------------------------------------------
/MNIST-pytorch/warp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.linalg
3 | import torch
4 |
5 | import util
6 |
7 | # fit (affine) warp between two sets of points
8 | def fit(Xsrc,Xdst):
9 | ptsN = len(Xsrc)
10 | X,Y,U,V,O,I = Xsrc[:,0],Xsrc[:,1],Xdst[:,0],Xdst[:,1],np.zeros([ptsN]),np.ones([ptsN])
11 | A = np.concatenate((np.stack([X,Y,I,O,O,O],axis=1),
12 | np.stack([O,O,O,X,Y,I],axis=1)),axis=0)
13 | b = np.concatenate((U,V),axis=0)
14 | p1,p2,p3,p4,p5,p6 = scipy.linalg.lstsq(A,b)[0].squeeze()
15 | pMtrx = np.array([[p1,p2,p3],[p4,p5,p6],[0,0,1]],dtype=np.float32)
16 | return pMtrx
17 |
18 | # compute composition of warp parameters
19 | def compose(opt,p,dp):
20 | pMtrx = vec2mtrx(opt,p)
21 | dpMtrx = vec2mtrx(opt,dp)
22 | pMtrxNew = dpMtrx.matmul(pMtrx)
23 | pMtrxNew = pMtrxNew/pMtrxNew[:,2:3,2:3]
24 | pNew = mtrx2vec(opt,pMtrxNew)
25 | return pNew
26 |
27 | # compute inverse of warp parameters
28 | def inverse(opt,p):
29 | pMtrx = vec2mtrx(opt,p)
30 | pInvMtrx = pMtrx.inverse()
31 | pInv = mtrx2vec(opt,pInvMtrx)
32 | return pInv
33 |
34 | # convert warp parameters to matrix
35 | def vec2mtrx(opt,p):
36 | O = torch.zeros(opt.batchSize,dtype=torch.float32).cuda()
37 | I = torch.ones(opt.batchSize,dtype=torch.float32).cuda()
38 | if opt.warpType=="translation":
39 | tx,ty = torch.unbind(p,dim=1)
40 | pMtrx = torch.stack([torch.stack([I,O,tx],dim=-1),
41 | torch.stack([O,I,ty],dim=-1),
42 | torch.stack([O,O,I],dim=-1)],dim=1)
43 | if opt.warpType=="similarity":
44 | pc,ps,tx,ty = torch.unbind(p,dim=1)
45 | pMtrx = torch.stack([torch.stack([I+pc,-ps,tx],dim=-1),
46 | torch.stack([ps,I+pc,ty],dim=-1),
47 | torch.stack([O,O,I],dim=-1)],dim=1)
48 | if opt.warpType=="affine":
49 | p1,p2,p3,p4,p5,p6 = torch.unbind(p,dim=1)
50 | pMtrx = torch.stack([torch.stack([I+p1,p2,p3],dim=-1),
51 | torch.stack([p4,I+p5,p6],dim=-1),
52 | torch.stack([O,O,I],dim=-1)],dim=1)
53 | if opt.warpType=="homography":
54 | p1,p2,p3,p4,p5,p6,p7,p8 = torch.unbind(p,dim=1)
55 | pMtrx = torch.stack([torch.stack([I+p1,p2,p3],dim=-1),
56 | torch.stack([p4,I+p5,p6],dim=-1),
57 | torch.stack([p7,p8,I],dim=-1)],dim=1)
58 | return pMtrx
59 |
60 | # convert warp matrix to parameters
61 | def mtrx2vec(opt,pMtrx):
62 | [row0,row1,row2] = torch.unbind(pMtrx,dim=1)
63 | [e00,e01,e02] = torch.unbind(row0,dim=1)
64 | [e10,e11,e12] = torch.unbind(row1,dim=1)
65 | [e20,e21,e22] = torch.unbind(row2,dim=1)
66 | if opt.warpType=="translation": p = torch.stack([e02,e12],dim=1)
67 | if opt.warpType=="similarity": p = torch.stack([e00-1,e10,e02,e12],dim=1)
68 | if opt.warpType=="affine": p = torch.stack([e00-1,e01,e02,e10,e11-1,e12],dim=1)
69 | if opt.warpType=="homography": p = torch.stack([e00-1,e01,e02,e10,e11-1,e12,e20,e21],dim=1)
70 | return p
71 |
72 | # warp the image
73 | def transformImage(opt,image,pMtrx):
74 | refMtrx = torch.from_numpy(opt.refMtrx).cuda()
75 | refMtrx = refMtrx.repeat(opt.batchSize,1,1)
76 | transMtrx = refMtrx.matmul(pMtrx)
77 | # warp the canonical coordinates
78 | X,Y = np.meshgrid(np.linspace(-1,1,opt.W),np.linspace(-1,1,opt.H))
79 | X,Y = X.flatten(),Y.flatten()
80 | XYhom = np.stack([X,Y,np.ones_like(X)],axis=1).T
81 | XYhom = np.tile(XYhom,[opt.batchSize,1,1]).astype(np.float32)
82 | XYhom = torch.from_numpy(XYhom).cuda()
83 | XYwarpHom = transMtrx.matmul(XYhom)
84 | XwarpHom,YwarpHom,ZwarpHom = torch.unbind(XYwarpHom,dim=1)
85 | Xwarp = (XwarpHom/(ZwarpHom+1e-8)).reshape(opt.batchSize,opt.H,opt.W)
86 | Ywarp = (YwarpHom/(ZwarpHom+1e-8)).reshape(opt.batchSize,opt.H,opt.W)
87 | grid = torch.stack([Xwarp,Ywarp],dim=-1)
88 | # sampling with bilinear interpolation
89 | imageWarp = torch.nn.functional.grid_sample(image,grid,mode="bilinear")
90 | return imageWarp
91 |
--------------------------------------------------------------------------------
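
vec2mtrx above builds every warp as an identity-plus-residual matrix; the same construction in matrix form (a sketch, using the entries of the parameter vector p):

M_{\text{translation}} = \begin{bmatrix} 1 & 0 & t_x \\ 0 & 1 & t_y \\ 0 & 0 & 1 \end{bmatrix}, \quad
M_{\text{similarity}} = \begin{bmatrix} 1+p_c & -p_s & t_x \\ p_s & 1+p_c & t_y \\ 0 & 0 & 1 \end{bmatrix}, \quad
M_{\text{homography}} = \begin{bmatrix} 1+p_1 & p_2 & p_3 \\ p_4 & 1+p_5 & p_6 \\ p_7 & p_8 & 1 \end{bmatrix}

The affine case is the homography form with p_7 = p_8 = 0. A zero parameter vector therefore corresponds to the identity warp, which is why the last layer of the geometric predictors is initialized to zero (last0=True in graph.initialize).
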
/MNIST-tensorflow/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.linalg
3 | import os,time
4 | import tensorflow as tf
5 |
6 | import warp
7 |
8 | # load MNIST data
9 | def loadMNIST(fname):
10 | if not os.path.exists(fname):
11 | # download and preprocess MNIST dataset
12 | from tensorflow.examples.tutorials.mnist import input_data
13 | mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
14 | trainData,validData,testData = {},{},{}
15 | trainData["image"] = mnist.train.images.reshape([-1,28,28]).astype(np.float32)
16 | validData["image"] = mnist.validation.images.reshape([-1,28,28]).astype(np.float32)
17 | testData["image"] = mnist.test.images.reshape([-1,28,28]).astype(np.float32)
18 | trainData["label"] = np.argmax(mnist.train.labels.astype(np.float32),axis=1)
19 | validData["label"] = np.argmax(mnist.validation.labels.astype(np.float32),axis=1)
20 | testData["label"] = np.argmax(mnist.test.labels.astype(np.float32),axis=1)
21 | os.makedirs(os.path.dirname(fname),exist_ok=True)
22 | np.savez(fname,train=trainData,valid=validData,test=testData)
23 | os.system("rm -rf MNIST_data")
24 | MNIST = np.load(fname)
25 | trainData = MNIST["train"].item()
26 | validData = MNIST["valid"].item()
27 | testData = MNIST["test"].item()
28 | return trainData,validData,testData
29 |
30 | # generate random warp perturbations
31 | def genPerturbations(opt):
32 | with tf.name_scope("genPerturbations"):
33 | X = np.tile(opt.canon4pts[:,0],[opt.batchSize,1])
34 | Y = np.tile(opt.canon4pts[:,1],[opt.batchSize,1])
35 | dX = tf.random_normal([opt.batchSize,4])*opt.pertScale \
36 | +tf.random_normal([opt.batchSize,1])*opt.transScale
37 | dY = tf.random_normal([opt.batchSize,4])*opt.pertScale \
38 | +tf.random_normal([opt.batchSize,1])*opt.transScale
39 | O = np.zeros([opt.batchSize,4],dtype=np.float32)
40 | I = np.ones([opt.batchSize,4],dtype=np.float32)
41 | # fit warp parameters to generated displacements
42 | if opt.warpType=="homography":
43 | A = tf.concat([tf.stack([X,Y,I,O,O,O,-X*(X+dX),-Y*(X+dX)],axis=-1),
44 | tf.stack([O,O,O,X,Y,I,-X*(Y+dY),-Y*(Y+dY)],axis=-1)],1)
45 | b = tf.expand_dims(tf.concat([X+dX,Y+dY],1),-1)
46 | pPert = tf.matrix_solve(A,b)[:,:,0]
47 | pPert -= tf.to_float([[1,0,0,0,1,0,0,0]])
48 | else:
49 | if opt.warpType=="translation":
50 | J = np.concatenate([np.stack([I,O],axis=-1),
51 | np.stack([O,I],axis=-1)],axis=1)
52 | if opt.warpType=="similarity":
53 | J = np.concatenate([np.stack([X,Y,I,O],axis=-1),
54 | np.stack([-Y,X,O,I],axis=-1)],axis=1)
55 | if opt.warpType=="affine":
56 | J = np.concatenate([np.stack([X,Y,I,O,O,O],axis=-1),
57 | np.stack([O,O,O,X,Y,I],axis=-1)],axis=1)
58 | dXY = tf.expand_dims(tf.concat([dX,dY],1),-1)
59 | pPert = tf.matrix_solve_ls(J,dXY)[:,:,0]
60 | return pPert
61 |
62 | # make training batch
63 | def makeBatch(opt,data,PH):
64 | N = len(data["image"])
65 | randIdx = np.random.randint(N,size=[opt.batchSize])
66 | # put data in placeholders
67 | [image,label] = PH
68 | batch = {
69 | image: data["image"][randIdx],
70 | label: data["label"][randIdx],
71 | }
72 | return batch
73 |
74 | # evaluation on test set
75 | def evalTest(opt,sess,data,PH,prediction,imagesEval=[]):
76 | N = len(data["image"])
77 | # put data in placeholders
78 | [image,label] = PH
79 | batchN = int(np.ceil(N/opt.batchSize))
80 | warped = [{},{}]
81 | count = 0
82 | for b in range(batchN):
83 | # use some dummy data (0) as batch filler if necessary
84 | if b!=batchN-1:
85 | realIdx = np.arange(opt.batchSize*b,opt.batchSize*(b+1))
86 | else:
87 | realIdx = np.arange(opt.batchSize*b,N)
88 | idx = np.zeros([opt.batchSize],dtype=int)
89 | idx[:len(realIdx)] = realIdx
90 | batch = {
91 | image: data["image"][idx],
92 | label: data["label"][idx],
93 | }
94 | evalList = sess.run([prediction]+imagesEval,feed_dict=batch)
95 | pred = evalList[0]
96 | count += pred[:len(realIdx)].sum()
97 | if opt.netType=="STN" or opt.netType=="IC-STN":
98 | imgs = evalList[1:]
99 | for i in range(len(realIdx)):
100 | l = data["label"][idx[i]]
101 | if l not in warped[0]: warped[0][l] = []
102 | if l not in warped[1]: warped[1][l] = []
103 | warped[0][l].append(imgs[0][i])
104 | warped[1][l].append(imgs[1][i])
105 | accuracy = float(count)/N
106 | if opt.netType=="STN" or opt.netType=="IC-STN":
107 | mean = [np.array([np.mean(warped[0][l],axis=0) for l in warped[0]]),
108 | np.array([np.mean(warped[1][l],axis=0) for l in warped[1]])]
109 | var = [np.array([np.var(warped[0][l],axis=0) for l in warped[0]]),
110 | np.array([np.var(warped[1][l],axis=0) for l in warped[1]])]
111 | else: mean,var = None,None
112 | return accuracy,mean,var
113 |
--------------------------------------------------------------------------------
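
In the homography branch of genPerturbations above, each perturbed corner (x,y) -> (x+dx,y+dy) contributes two linear constraints on the 8 homography entries h; a sketch of the DLT rows the code stacks:

\begin{bmatrix} x & y & 1 & 0 & 0 & 0 & -x(x{+}dx) & -y(x{+}dx) \\ 0 & 0 & 0 & x & y & 1 & -x(y{+}dy) & -y(y{+}dy) \end{bmatrix} h = \begin{bmatrix} x+dx \\ y+dy \end{bmatrix}

Stacking the four canonical corners yields the 8x8 system A h = b solved with tf.matrix_solve; subtracting [1,0,0,0,1,0,0,0] then converts h into the identity-offset parameterization expected by warp.vec2mtrx.
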
/MNIST-tensorflow/graph.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import time
4 | import data,warp,util
5 |
6 | # build classification network (full CNN, used when no geometric predictor is learned)
7 | def fullCNN(opt,image):
8 | def conv2Layer(opt,feat,outDim):
9 | weight,bias = createVariable(opt,[3,3,int(feat.shape[-1]),outDim],stddev=opt.stdC)
10 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias
11 | return conv
12 | def linearLayer(opt,feat,outDim):
13 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=opt.stdC)
14 | fc = tf.matmul(feat,weight)+bias
15 | return fc
16 | with tf.variable_scope("classifier"):
17 | feat = image
18 | with tf.variable_scope("conv1"):
19 | feat = conv2Layer(opt,feat,3)
20 | feat = tf.nn.relu(feat)
21 | with tf.variable_scope("conv2"):
22 | feat = conv2Layer(opt,feat,6)
23 | feat = tf.nn.relu(feat)
24 | feat = tf.nn.max_pool(feat,ksize=[1,2,2,1],strides=[1,2,2,1],padding="VALID")
25 | with tf.variable_scope("conv3"):
26 | feat = conv2Layer(opt,feat,9)
27 | feat = tf.nn.relu(feat)
28 | with tf.variable_scope("conv4"):
29 | feat = conv2Layer(opt,feat,12)
30 | feat = tf.nn.relu(feat)
31 | feat = tf.reshape(feat,[opt.batchSize,-1])
32 | with tf.variable_scope("fc5"):
33 | feat = linearLayer(opt,feat,48)
34 | feat = tf.nn.relu(feat)
35 | with tf.variable_scope("fc6"):
36 | feat = linearLayer(opt,feat,opt.labelN)
37 | output = feat
38 | return output
39 |
40 | # build classification network (smaller CNN, used together with STN/IC-STN)
41 | def CNN(opt,image):
42 | def conv2Layer(opt,feat,outDim):
43 | weight,bias = createVariable(opt,[9,9,int(feat.shape[-1]),outDim],stddev=opt.stdC)
44 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias
45 | return conv
46 | def linearLayer(opt,feat,outDim):
47 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=opt.stdC)
48 | fc = tf.matmul(feat,weight)+bias
49 | return fc
50 | with tf.variable_scope("classifier"):
51 | feat = image
52 | with tf.variable_scope("conv1"):
53 | feat = conv2Layer(opt,feat,3)
54 | feat = tf.nn.relu(feat)
55 | feat = tf.reshape(feat,[opt.batchSize,-1])
56 | with tf.variable_scope("fc2"):
57 | feat = linearLayer(opt,feat,opt.labelN)
58 | output = feat
59 | return output
60 |
61 | # build Spatial Transformer Network
62 | def STN(opt,image):
63 | def conv2Layer(opt,feat,outDim):
64 | weight,bias = createVariable(opt,[7,7,int(feat.shape[-1]),outDim],stddev=opt.stdGP)
65 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias
66 | return conv
67 | def linearLayer(opt,feat,outDim,final=False):
68 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=0.0 if final else opt.stdGP)
69 | fc = tf.matmul(feat,weight)+bias
70 | return fc
71 | imageWarpAll = [image]
72 | with tf.variable_scope("geometric"):
73 | feat = image
74 | with tf.variable_scope("conv1"):
75 | feat = conv2Layer(opt,feat,4)
76 | feat = tf.nn.relu(feat)
77 | with tf.variable_scope("conv2"):
78 | feat = conv2Layer(opt,feat,8)
79 | feat = tf.nn.relu(feat)
80 | feat = tf.nn.max_pool(feat,ksize=[1,2,2,1],strides=[1,2,2,1],padding="VALID")
81 | feat = tf.reshape(feat,[opt.batchSize,-1])
82 | with tf.variable_scope("fc3"):
83 | feat = linearLayer(opt,feat,48)
84 | feat = tf.nn.relu(feat)
85 | with tf.variable_scope("fc4"):
86 | feat = linearLayer(opt,feat,opt.warpDim,final=True)
87 | p = feat
88 | pMtrx = warp.vec2mtrx(opt,p)
89 | imageWarp = warp.transformImage(opt,image,pMtrx)
90 | imageWarpAll.append(imageWarp)
91 | return imageWarpAll
92 |
93 | # build Inverse Compositional STN
94 | def ICSTN(opt,image,p):
95 | def conv2Layer(opt,feat,outDim):
96 | weight,bias = createVariable(opt,[7,7,int(feat.shape[-1]),outDim],stddev=opt.stdGP)
97 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias
98 | return conv
99 | def linearLayer(opt,feat,outDim,final=False):
100 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=0.0 if final else opt.stdGP)
101 | fc = tf.matmul(feat,weight)+bias
102 | return fc
103 | imageWarpAll = []
104 | for l in range(opt.warpN):
105 | with tf.variable_scope("geometric",reuse=l>0):
106 | pMtrx = warp.vec2mtrx(opt,p)
107 | imageWarp = warp.transformImage(opt,image,pMtrx)
108 | imageWarpAll.append(imageWarp)
109 | feat = imageWarp
110 | with tf.variable_scope("conv1"):
111 | feat = conv2Layer(opt,feat,4)
112 | feat = tf.nn.relu(feat)
113 | with tf.variable_scope("conv2"):
114 | feat = conv2Layer(opt,feat,8)
115 | feat = tf.nn.relu(feat)
116 | feat = tf.nn.max_pool(feat,ksize=[1,2,2,1],strides=[1,2,2,1],padding="VALID")
117 | feat = tf.reshape(feat,[opt.batchSize,-1])
118 | with tf.variable_scope("fc3"):
119 | feat = linearLayer(opt,feat,48)
120 | feat = tf.nn.relu(feat)
121 | with tf.variable_scope("fc4"):
122 | feat = linearLayer(opt,feat,opt.warpDim,final=True)
123 | dp = feat
124 | p = warp.compose(opt,p,dp)
125 | pMtrx = warp.vec2mtrx(opt,p)
126 | imageWarp = warp.transformImage(opt,image,pMtrx)
127 | imageWarpAll.append(imageWarp)
128 | return imageWarpAll
129 |
130 | # auxiliary function for creating weight and bias
131 | def createVariable(opt,weightShape,biasShape=None,stddev=None):
132 | if biasShape is None: biasShape = [weightShape[-1]]
133 | weight = tf.get_variable("weight",shape=weightShape,dtype=tf.float32,
134 | initializer=tf.random_normal_initializer(stddev=stddev))
135 | bias = tf.get_variable("bias",shape=biasShape,dtype=tf.float32,
136 | initializer=tf.random_normal_initializer(stddev=stddev))
137 | return weight,bias
138 |
--------------------------------------------------------------------------------
/MNIST-tensorflow/options.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import warp
4 | import util
5 |
6 | def set(training):
7 |
8 | # parse input arguments
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("netType", choices=["CNN","STN","IC-STN"], help="type of network")
11 | parser.add_argument("--group", default="0", help="name for group")
12 | parser.add_argument("--model", default="test", help="name for model instance")
13 | parser.add_argument("--size", default="28x28", help="image resolution")
14 | parser.add_argument("--warpType", default="homography", help="type of warp function on images",
15 | choices=["translation","similarity","affine","homography"])
16 | parser.add_argument("--warpN", type=int, default=4, help="number of recurrent transformations (for IC-STN)")
17 | parser.add_argument("--stdC", type=float, default=0.1, help="initialization stddev (classification network)")
18 | parser.add_argument("--stdGP", type=float, default=0.1, help="initialization stddev (geometric predictor)")
19 | parser.add_argument("--pertScale", type=float, default=0.25, help="initial perturbation scale")
20 | parser.add_argument("--transScale", type=float, default=0.25, help="initial translation scale")
21 | if training: # training
22 | parser.add_argument("--batchSize", type=int, default=100, help="batch size for SGD")
23 | parser.add_argument("--lrC", type=float, default=1e-2, help="learning rate (classification network)")
24 | parser.add_argument("--lrCdecay", type=float, default=1.0, help="learning rate decay (classification network)")
25 | parser.add_argument("--lrCstep", type=int, default=100000, help="learning rate decay step size (classification network)")
26 | parser.add_argument("--lrGP", type=float, default=None, help="learning rate (geometric predictor)")
27 | parser.add_argument("--lrGPdecay", type=float, default=1.0, help="learning rate decay (geometric predictor)")
28 | parser.add_argument("--lrGPstep", type=int, default=100000, help="learning rate decay step size (geometric predictor)")
29 | parser.add_argument("--fromIt", type=int, default=0, help="resume training from iteration number")
30 | parser.add_argument("--toIt", type=int, default=500000, help="run training to iteration number")
31 | else: # evaluation
32 | parser.add_argument("--batchSize", type=int, default=1, help="batch size for evaluation")
33 | opt = parser.parse_args()
34 |
35 | if opt.lrGP is None: opt.lrGP = 0 if opt.netType=="CNN" else \
36 | 1e-2 if opt.netType=="STN" else \
37 | 1e-4 if opt.netType=="IC-STN" else None
38 |
39 | # --- below are automatically set ---
40 | opt.training = training
41 | opt.H,opt.W = [int(x) for x in opt.size.split("x")]
42 | opt.visBlockSize = int(np.floor(np.sqrt(opt.batchSize)))
43 | opt.warpDim = 2 if opt.warpType == "translation" else \
44 | 4 if opt.warpType == "similarity" else \
45 | 6 if opt.warpType == "affine" else \
46 | 8 if opt.warpType == "homography" else None
47 | opt.labelN = 10
48 | opt.canon4pts = np.array([[-1,-1],[-1,1],[1,1],[1,-1]],dtype=np.float32)
49 | opt.image4pts = np.array([[0,0],[0,opt.H-1],[opt.W-1,opt.H-1],[opt.W-1,0]],dtype=np.float32)
50 | opt.refMtrx = warp.fit(Xsrc=opt.canon4pts,Xdst=opt.image4pts)
51 | if opt.netType=="STN": opt.warpN = 1
52 |
53 | print("({0}) {1}".format(
54 | util.toGreen("{0}".format(opt.group)),
55 | util.toGreen("{0}".format(opt.model))))
56 | print("------------------------------------------")
57 | print("network type: {0}, recurrent warps: {1}".format(
58 | util.toYellow("{0}".format(opt.netType)),
59 | util.toYellow("{0}".format(opt.warpN if opt.netType=="IC-STN" else "X"))))
60 | print("batch size: {0}, image size: {1}x{2}".format(
61 | util.toYellow("{0}".format(opt.batchSize)),
62 | util.toYellow("{0}".format(opt.H)),
63 | util.toYellow("{0}".format(opt.W))))
64 | print("warpScale: (pert) {0} (trans) {1}".format(
65 | util.toYellow("{0}".format(opt.pertScale)),
66 | util.toYellow("{0}".format(opt.transScale))))
67 | if training:
68 | print("[geometric predictor] stddev={0}, lr={1}".format(
69 | util.toYellow("{0:.0e}".format(opt.stdGP)),
70 | util.toYellow("{0:.0e}".format(opt.lrGP))))
71 | print("[classification network] stddev={0}, lr={1}".format(
72 | util.toYellow("{0:.0e}".format(opt.stdC)),
73 | util.toYellow("{0:.0e}".format(opt.lrC))))
74 | print("------------------------------------------")
75 | if training:
76 | print(util.toMagenta("training model ({0}) {1}...".format(opt.group,opt.model)))
77 |
78 | return opt
79 |
--------------------------------------------------------------------------------
/MNIST-tensorflow/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time,os,sys
3 | import argparse
4 | import util
5 |
6 | print(util.toYellow("======================================================="))
7 | print(util.toYellow("train.py (training on MNIST)"))
8 | print(util.toYellow("======================================================="))
9 |
10 | import tensorflow as tf
11 | import data,graph,warp,util
12 | import options
13 |
14 | print(util.toMagenta("setting configurations..."))
15 | opt = options.set(training=True)
16 |
17 | # create directories for model output
18 | util.mkdir("models_{0}".format(opt.group))
19 |
20 | print(util.toMagenta("building graph..."))
21 | tf.reset_default_graph()
22 | # build graph
23 | with tf.device("/gpu:0"):
24 | # ------ define input data ------
25 | image = tf.placeholder(tf.float32,shape=[opt.batchSize,opt.H,opt.W])
26 | label = tf.placeholder(tf.int64,shape=[opt.batchSize])
27 | PH = [image,label]
28 | # ------ generate perturbation ------
29 | pInit = data.genPerturbations(opt)
30 | pInitMtrx = warp.vec2mtrx(opt,pInit)
31 | # ------ build network ------
32 | image = tf.expand_dims(image,axis=-1)
33 | imagePert = warp.transformImage(opt,image,pInitMtrx)
34 | if opt.netType=="CNN":
35 | output = graph.fullCNN(opt,imagePert)
36 | elif opt.netType=="STN":
37 | imageWarpAll = graph.STN(opt,imagePert)
38 | imageWarp = imageWarpAll[-1]
39 | output = graph.CNN(opt,imageWarp)
40 | elif opt.netType=="IC-STN":
41 | imageWarpAll = graph.ICSTN(opt,image,pInit)
42 | imageWarp = imageWarpAll[-1]
43 | output = graph.CNN(opt,imageWarp)
44 | softmax = tf.nn.softmax(output)
45 | labelOnehot = tf.one_hot(label,opt.labelN)
46 | prediction = tf.equal(tf.argmax(softmax,1),label)
47 | # ------ define loss ------
48 | softmaxLoss = tf.nn.softmax_cross_entropy_with_logits(logits=output,labels=labelOnehot)
49 | loss = tf.reduce_mean(softmaxLoss)
50 | # ------ optimizer ------
51 | lrGP_PH,lrC_PH = tf.placeholder(tf.float32,shape=[]),tf.placeholder(tf.float32,shape=[])
52 | optim = util.setOptimizer(opt,loss,lrGP_PH,lrC_PH)
53 | # ------ generate summaries ------
54 | summaryImageTrain = []
55 | summaryImageTest = []
56 | if opt.netType=="STN" or opt.netType=="IC-STN":
57 | for l in range(opt.warpN+1):
58 | summaryImageTrain.append(util.imageSummary(opt,imageWarpAll[l],"TRAIN_warp{0}".format(l),opt.H,opt.W))
59 | summaryImageTest.append(util.imageSummary(opt,imageWarpAll[l],"TEST_warp{0}".format(l),opt.H,opt.W))
60 | summaryImageTrain = tf.summary.merge(summaryImageTrain)
61 | summaryImageTest = tf.summary.merge(summaryImageTest)
62 | summaryLossTrain = tf.summary.scalar("TRAIN_loss",loss)
63 | testErrorPH = tf.placeholder(tf.float32,shape=[])
64 | testImagePH = tf.placeholder(tf.float32,shape=[opt.labelN,opt.H,opt.W,1])
65 | summaryErrorTest = tf.summary.scalar("TEST_error",testErrorPH)
66 | if opt.netType=="STN" or opt.netType=="IC-STN":
67 | summaryMeanTest0 = util.imageSummaryMeanVar(opt,testImagePH,"TEST_mean_init",opt.H,opt.W)
68 | summaryMeanTest1 = util.imageSummaryMeanVar(opt,testImagePH,"TEST_mean_warped",opt.H,opt.W)
69 | summaryVarTest0 = util.imageSummaryMeanVar(opt,testImagePH*3,"TEST_var_init",opt.H,opt.W)
70 | summaryVarTest1 = util.imageSummaryMeanVar(opt,testImagePH*3,"TEST_var_warped",opt.H,opt.W)
71 |
72 | # load data
73 | print(util.toMagenta("loading MNIST dataset..."))
74 | trainData,validData,testData = data.loadMNIST("data/MNIST.npz")
75 |
76 | # prepare model saver/summary writer
77 | saver = tf.train.Saver(max_to_keep=20)
78 | summaryWriter = tf.summary.FileWriter("summary_{0}/{1}".format(opt.group,opt.model))
79 |
80 | print(util.toYellow("======= TRAINING START ======="))
81 | timeStart = time.time()
82 | # start session
83 | tfConfig = tf.ConfigProto(allow_soft_placement=True)
84 | tfConfig.gpu_options.allow_growth = True
85 | with tf.Session(config=tfConfig) as sess:
86 | sess.run(tf.global_variables_initializer())
87 | summaryWriter.add_graph(sess.graph)
88 | if opt.fromIt!=0:
89 | util.restoreModel(opt,sess,saver,opt.fromIt)
90 | print(util.toMagenta("resuming from iteration {0}...".format(opt.fromIt)))
91 | print(util.toMagenta("start training..."))
92 |
93 | # training loop
94 | for i in range(opt.fromIt,opt.toIt):
95 | lrGP = opt.lrGP*opt.lrGPdecay**(i//opt.lrGPstep)
96 | lrC = opt.lrC*opt.lrCdecay**(i//opt.lrCstep)
97 | # make training batch
98 | batch = data.makeBatch(opt,trainData,PH)
99 | batch[lrGP_PH] = lrGP
100 | batch[lrC_PH] = lrC
101 | # run one step
102 | _,l = sess.run([optim,loss],feed_dict=batch)
103 | if (i+1)%100==0:
104 | print("it. {0}/{1} lr={3}(GP),{4}(C), loss={5}, time={2}"
105 | .format(util.toCyan("{0}".format(i+1)),
106 | opt.toIt,
107 | util.toGreen("{0:.2f}".format(time.time()-timeStart)),
108 | util.toYellow("{0:.0e}".format(lrGP)),
109 | util.toYellow("{0:.0e}".format(lrC)),
110 | util.toRed("{0:.4f}".format(l))))
111 | if (i+1)%100==0:
112 | summaryWriter.add_summary(sess.run(summaryLossTrain,feed_dict=batch),i+1)
113 | if (i+1)%500==0 and (opt.netType=="STN" or opt.netType=="IC-STN"):
114 | summaryWriter.add_summary(sess.run(summaryImageTrain,feed_dict=batch),i+1)
115 | summaryWriter.add_summary(sess.run(summaryImageTest,feed_dict=batch),i+1)
116 | if (i+1)%1000==0:
117 | # evaluate on test set
118 | if opt.netType=="STN" or opt.netType=="IC-STN":
119 | testAcc,testMean,testVar = data.evalTest(opt,sess,testData,PH,prediction,imagesEval=[imagePert,imageWarp])
120 | else:
121 | testAcc,_,_ = data.evalTest(opt,sess,testData,PH,prediction)
122 | testError = (1-testAcc)*100
123 | summaryWriter.add_summary(sess.run(summaryErrorTest,feed_dict={testErrorPH:testError}),i+1)
124 | if opt.netType=="STN" or opt.netType=="IC-STN":
125 | summaryWriter.add_summary(sess.run(summaryMeanTest0,feed_dict={testImagePH:testMean[0]}),i+1)
126 | summaryWriter.add_summary(sess.run(summaryMeanTest1,feed_dict={testImagePH:testMean[1]}),i+1)
127 | summaryWriter.add_summary(sess.run(summaryVarTest0,feed_dict={testImagePH:testVar[0]}),i+1)
128 | summaryWriter.add_summary(sess.run(summaryVarTest1,feed_dict={testImagePH:testVar[1]}),i+1)
129 | if (i+1)%10000==0:
130 | util.saveModel(opt,sess,saver,i+1)
131 | print(util.toGreen("model saved: {0}/{1}, it.{2}".format(opt.group,opt.model,i+1)))
132 |
133 | print(util.toYellow("======= TRAINING DONE ======="))
134 |
--------------------------------------------------------------------------------
/MNIST-tensorflow/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.misc
3 | import tensorflow as tf
4 | import os
5 | import termcolor
6 |
7 | def mkdir(path):
8 | if not os.path.exists(path): os.mkdir(path)
9 | def imread(fname):
10 | return scipy.misc.imread(fname)/255.0
11 | def imsave(fname,array):
12 | scipy.misc.toimage(array,cmin=0.0,cmax=1.0).save(fname)
13 |
14 | # convert to colored strings
15 | def toRed(content): return termcolor.colored(content,"red",attrs=["bold"])
16 | def toGreen(content): return termcolor.colored(content,"green",attrs=["bold"])
17 | def toBlue(content): return termcolor.colored(content,"blue",attrs=["bold"])
18 | def toCyan(content): return termcolor.colored(content,"cyan",attrs=["bold"])
19 | def toYellow(content): return termcolor.colored(content,"yellow",attrs=["bold"])
20 | def toMagenta(content): return termcolor.colored(content,"magenta",attrs=["bold"])
21 |
22 | # make image summary from image batch
23 | def imageSummary(opt,image,tag,H,W):
24 | blockSize = opt.visBlockSize
25 | imageOne = tf.batch_to_space(image[:blockSize**2],crops=[[0,0],[0,0]],block_size=blockSize)
26 | imagePermute = tf.reshape(imageOne,[H,blockSize,W,blockSize,-1])
27 | imageTransp = tf.transpose(imagePermute,[1,0,3,2,4])
28 | imageBlocks = tf.reshape(imageTransp,[1,H*blockSize,W*blockSize,-1])
29 | imageBlocks = tf.cast(imageBlocks*255,tf.uint8)
30 | summary = tf.summary.image(tag,imageBlocks)
31 | return summary
32 |
33 | # make image summary from image batch (mean/variance)
34 | def imageSummaryMeanVar(opt,image,tag,H,W):
35 | imageOne = tf.batch_to_space_nd(image,crops=[[0,0],[0,0]],block_shape=[1,10])
36 | imagePermute = tf.reshape(imageOne,[H,1,W,10,-1])
37 | imageTransp = tf.transpose(imagePermute,[1,0,3,2,4])
38 | imageBlocks = tf.reshape(imageTransp,[1,H*1,W*10,-1])
39 | imageBlocks = tf.cast(imageBlocks*255,tf.uint8)
40 | summary = tf.summary.image(tag,imageBlocks)
41 | return summary
42 |
43 | # set optimizer for different learning rates
44 | def setOptimizer(opt,loss,lrGP,lrC):
45 | varsGP = [v for v in tf.global_variables() if "geometric" in v.name]
46 | varsC = [v for v in tf.global_variables() if "classifier" in v.name]
47 | gradC = tf.gradients(loss,varsC)
48 | optimC = tf.train.GradientDescentOptimizer(lrC).apply_gradients(zip(gradC,varsC))
49 | if len(varsGP)>0:
50 | gradGP = tf.gradients(loss,varsGP)
51 | optimGP = tf.train.GradientDescentOptimizer(lrGP).apply_gradients(zip(gradGP,varsGP))
52 | optim = tf.group(optimC,optimGP)
53 | else:
54 | optim = optimC
55 | return optim
56 |
57 | # restore model
58 | def restoreModel(opt,sess,saver,it):
59 | saver.restore(sess,"models_{0}/{1}_it{2}.ckpt".format(opt.group,opt.model,it))
60 | # save model
61 | def saveModel(opt,sess,saver,it):
62 | saver.save(sess,"models_{0}/{1}_it{2}.ckpt".format(opt.group,opt.model,it))
63 |
64 |
--------------------------------------------------------------------------------
/MNIST-tensorflow/warp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.linalg
3 | import tensorflow as tf
4 |
5 | # fit (affine) warp between two sets of points
6 | def fit(Xsrc,Xdst):
7 | ptsN = len(Xsrc)
8 | X,Y,U,V,O,I = Xsrc[:,0],Xsrc[:,1],Xdst[:,0],Xdst[:,1],np.zeros([ptsN]),np.ones([ptsN])
9 | A = np.concatenate((np.stack([X,Y,I,O,O,O],axis=1),
10 | np.stack([O,O,O,X,Y,I],axis=1)),axis=0)
11 | b = np.concatenate((U,V),axis=0)
12 | p1,p2,p3,p4,p5,p6 = scipy.linalg.lstsq(A,b)[0].squeeze()
13 | pMtrx = np.array([[p1,p2,p3],[p4,p5,p6],[0,0,1]],dtype=np.float32)
14 | return pMtrx
15 |
16 | # compute composition of warp parameters
17 | def compose(opt,p,dp):
18 | with tf.name_scope("compose"):
19 | pMtrx = vec2mtrx(opt,p)
20 | dpMtrx = vec2mtrx(opt,dp)
21 | pMtrxNew = tf.matmul(dpMtrx,pMtrx)
22 | pMtrxNew /= pMtrxNew[:,2:3,2:3]
23 | pNew = mtrx2vec(opt,pMtrxNew)
24 | return pNew
25 |
26 | # compute inverse of warp parameters
27 | def inverse(opt,p):
28 | with tf.name_scope("inverse"):
29 | pMtrx = vec2mtrx(opt,p)
30 | pInvMtrx = tf.matrix_inverse(pMtrx)
31 | pInv = mtrx2vec(opt,pInvMtrx)
32 | return pInv
33 |
34 | # convert warp parameters to matrix
35 | def vec2mtrx(opt,p):
36 | with tf.name_scope("vec2mtrx"):
37 | O = tf.zeros([opt.batchSize])
38 | I = tf.ones([opt.batchSize])
39 | if opt.warpType=="translation":
40 | tx,ty = tf.unstack(p,axis=1)
41 | pMtrx = tf.transpose(tf.stack([[I,O,tx],[O,I,ty],[O,O,I]]),perm=[2,0,1])
42 | if opt.warpType=="similarity":
43 | pc,ps,tx,ty = tf.unstack(p,axis=1)
44 | pMtrx = tf.transpose(tf.stack([[I+pc,-ps,tx],[ps,I+pc,ty],[O,O,I]]),perm=[2,0,1])
45 | if opt.warpType=="affine":
46 | p1,p2,p3,p4,p5,p6 = tf.unstack(p,axis=1)
47 | pMtrx = tf.transpose(tf.stack([[I+p1,p2,p3],[p4,I+p5,p6],[O,O,I]]),perm=[2,0,1])
48 | if opt.warpType=="homography":
49 | p1,p2,p3,p4,p5,p6,p7,p8 = tf.unstack(p,axis=1)
50 | pMtrx = tf.transpose(tf.stack([[I+p1,p2,p3],[p4,I+p5,p6],[p7,p8,I]]),perm=[2,0,1])
51 | return pMtrx
52 |
53 | # convert warp matrix to parameters
54 | def mtrx2vec(opt,pMtrx):
55 | with tf.name_scope("mtrx2vec"):
56 | [row0,row1,row2] = tf.unstack(pMtrx,axis=1)
57 | [e00,e01,e02] = tf.unstack(row0,axis=1)
58 | [e10,e11,e12] = tf.unstack(row1,axis=1)
59 | [e20,e21,e22] = tf.unstack(row2,axis=1)
60 | if opt.warpType=="translation": p = tf.stack([e02,e12],axis=1)
61 | if opt.warpType=="similarity": p = tf.stack([e00-1,e10,e02,e12],axis=1)
62 | if opt.warpType=="affine": p = tf.stack([e00-1,e01,e02,e10,e11-1,e12],axis=1)
63 | if opt.warpType=="homography": p = tf.stack([e00-1,e01,e02,e10,e11-1,e12,e20,e21],axis=1)
64 | return p
65 |
66 | # warp the image
67 | def transformImage(opt,image,pMtrx):
68 | with tf.name_scope("transformImage"):
69 | refMtrx = tf.tile(tf.expand_dims(opt.refMtrx,axis=0),[opt.batchSize,1,1])
70 | transMtrx = tf.matmul(refMtrx,pMtrx)
71 | # warp the canonical coordinates
72 | X,Y = np.meshgrid(np.linspace(-1,1,opt.W),np.linspace(-1,1,opt.H))
73 | X,Y = X.flatten(),Y.flatten()
74 | XYhom = np.stack([X,Y,np.ones_like(X)],axis=1).T
75 | XYhom = np.tile(XYhom,[opt.batchSize,1,1]).astype(np.float32)
76 | XYwarpHom = tf.matmul(transMtrx,XYhom)
77 | XwarpHom,YwarpHom,ZwarpHom = tf.unstack(XYwarpHom,axis=1)
78 | Xwarp = tf.reshape(XwarpHom/(ZwarpHom+1e-8),[opt.batchSize,opt.H,opt.W])
79 | Ywarp = tf.reshape(YwarpHom/(ZwarpHom+1e-8),[opt.batchSize,opt.H,opt.W])
80 | # get the integer sampling coordinates
81 | Xfloor,Xceil = tf.floor(Xwarp),tf.ceil(Xwarp)
82 | Yfloor,Yceil = tf.floor(Ywarp),tf.ceil(Ywarp)
83 | XfloorInt,XceilInt = tf.to_int32(Xfloor),tf.to_int32(Xceil)
84 | YfloorInt,YceilInt = tf.to_int32(Yfloor),tf.to_int32(Yceil)
85 | imageIdx = np.tile(np.arange(opt.batchSize).reshape([opt.batchSize,1,1]),[1,opt.H,opt.W])
86 | imageVec = tf.reshape(image,[-1,int(image.shape[-1])])
87 | imageVecOut = tf.concat([imageVec,tf.zeros([1,int(image.shape[-1])])],axis=0)
88 | idxUL = (imageIdx*opt.H+YfloorInt)*opt.W+XfloorInt
89 | idxUR = (imageIdx*opt.H+YfloorInt)*opt.W+XceilInt
90 | idxBL = (imageIdx*opt.H+YceilInt)*opt.W+XfloorInt
91 | idxBR = (imageIdx*opt.H+YceilInt)*opt.W+XceilInt
92 | idxOutside = tf.fill([opt.batchSize,opt.H,opt.W],opt.batchSize*opt.H*opt.W)
93 | def insideImage(Xint,Yint):
94 | return (Xint>=0)&(Xint