├── .editorconfig ├── .gitignore ├── LICENSE ├── MNIST-pytorch ├── data.py ├── graph.py ├── options.py ├── train.py ├── util.py └── warp.py ├── MNIST-tensorflow ├── data.py ├── graph.py ├── options.py ├── train.py ├── util.py └── warp.py ├── README.md └── traffic-sign-tensorflow ├── data.py ├── graph.py ├── options.py ├── train.py ├── util.py └── warp.py /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | insert_final_newline = true 6 | indent_style = tab 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | 10 | [*.md] 11 | trim_trailing_whitespace = false 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Chen-Hsuan Lin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MNIST-pytorch/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg 3 | import os,time 4 | import torch 5 | import torchvision 6 | 7 | import warp,util 8 | 9 | # load MNIST data 10 | def loadMNIST(opt,path): 11 | os.makedirs(path,exist_ok=True) 12 | trainDataset = torchvision.datasets.MNIST(path,train=True,download=True) 13 | testDataset = torchvision.datasets.MNIST(path,train=False,download=True) 14 | trainData,testData = {},{} 15 | trainData["image"] = torch.tensor([np.array(sample[0])/255.0 for sample in trainDataset],dtype=torch.float32) 16 | testData["image"] = torch.tensor([np.array(sample[0])/255.0 for sample in testDataset],dtype=torch.float32) 17 | trainData["label"] = torch.tensor([sample[1] for sample in trainDataset]) 18 | testData["label"] = torch.tensor([sample[1] for sample in testDataset]) 19 | return trainData,testData 20 | 21 | # generate training batch 22 | def genPerturbations(opt): 23 | X = np.tile(opt.canon4pts[:,0],[opt.batchSize,1]) 24 | Y = np.tile(opt.canon4pts[:,1],[opt.batchSize,1]) 25 | O = np.zeros([opt.batchSize,4],dtype=np.float32) 26 | I = np.ones([opt.batchSize,4],dtype=np.float32) 27 | dX = np.random.randn(opt.batchSize,4)*opt.pertScale \ 28 | +np.random.randn(opt.batchSize,1)*opt.transScale 29 | dY = np.random.randn(opt.batchSize,4)*opt.pertScale \ 30 | +np.random.randn(opt.batchSize,1)*opt.transScale 31 | dX,dY = dX.astype(np.float32),dY.astype(np.float32) 32 | # fit warp parameters to generated displacements 33 | if opt.warpType=="homography": 34 | A = 
np.concatenate([np.stack([X,Y,I,O,O,O,-X*(X+dX),-Y*(X+dX)],axis=-1), 35 | np.stack([O,O,O,X,Y,I,-X*(Y+dY),-Y*(Y+dY)],axis=-1)],axis=1) 36 | b = np.expand_dims(np.concatenate([X+dX,Y+dY],axis=1),axis=-1) 37 | pPert = np.matmul(np.linalg.inv(A),b).squeeze() 38 | pPert -= np.array([1,0,0,0,1,0,0,0]) 39 | else: 40 | if opt.warpType=="translation": 41 | J = np.concatenate([np.stack([I,O],axis=-1), 42 | np.stack([O,I],axis=-1)],axis=1) 43 | if opt.warpType=="similarity": 44 | J = np.concatenate([np.stack([X,Y,I,O],axis=-1), 45 | np.stack([-Y,X,O,I],axis=-1)],axis=1) 46 | if opt.warpType=="affine": 47 | J = np.concatenate([np.stack([X,Y,I,O,O,O],axis=-1), 48 | np.stack([O,O,O,X,Y,I],axis=-1)],axis=1) 49 | dXY = np.expand_dims(np.concatenate([dX,dY],axis=1),axis=-1) 50 | Jtransp = np.transpose(J,axes=[0,2,1]) 51 | pPert = np.matmul(np.linalg.inv(np.matmul(Jtransp,J)),np.matmul(Jtransp,dXY)).squeeze() 52 | pInit = torch.from_numpy(pPert).cuda() 53 | return pInit 54 | 55 | # make training batch 56 | def makeBatch(opt,data): 57 | N = len(data["image"]) 58 | randIdx = np.random.randint(N,size=[opt.batchSize]) 59 | batch = { 60 | "image": data["image"][randIdx].cuda(), 61 | "label": data["label"][randIdx].cuda(), 62 | } 63 | return batch 64 | 65 | # evaluation on test set 66 | def evalTest(opt,data,geometric,classifier): 67 | geometric.eval() 68 | classifier.eval() 69 | N = len(data["image"]) 70 | batchN = int(np.ceil(N/opt.batchSize)) 71 | warped = [{},{}] 72 | count = 0 73 | for b in range(batchN): 74 | # use some dummy data (0) as batch filler if necessary 75 | if b!=batchN-1: 76 | realIdx = np.arange(opt.batchSize*b,opt.batchSize*(b+1)) 77 | else: 78 | realIdx = np.arange(opt.batchSize*b,N) 79 | idx = np.zeros([opt.batchSize],dtype=int) 80 | idx[:len(realIdx)] = realIdx 81 | # make training batch 82 | image = data["image"][idx].cuda() 83 | label = data["label"][idx].cuda() 84 | image.data.unsqueeze_(dim=1) 85 | # generate perturbation 86 | pInit = genPerturbations(opt) 87 | 
pInitMtrx = warp.vec2mtrx(opt,pInit) 88 | imagePert = warp.transformImage(opt,image,pInitMtrx) 89 | imageWarpAll = geometric(opt,image,pInit) if opt.netType=="IC-STN" else geometric(opt,imagePert) 90 | imageWarp = imageWarpAll[-1] 91 | output = classifier(opt,imageWarp) 92 | _,pred = output.max(dim=1) 93 | count += int((pred==label).sum().cpu().numpy()) 94 | if opt.netType=="STN" or opt.netType=="IC-STN": 95 | imgPert = imagePert.detach().cpu().numpy() 96 | imgWarp = imageWarp.detach().cpu().numpy() 97 | for i in range(len(realIdx)): 98 | l = data["label"][idx[i]].item() 99 | if l not in warped[0]: warped[0][l] = [] 100 | if l not in warped[1]: warped[1][l] = [] 101 | warped[0][l].append(imgPert[i]) 102 | warped[1][l].append(imgWarp[i]) 103 | accuracy = float(count)/N 104 | if opt.netType=="STN" or opt.netType=="IC-STN": 105 | mean = [np.array([np.mean(warped[0][l],axis=0) for l in warped[0]]), 106 | np.array([np.mean(warped[1][l],axis=0) for l in warped[1]])] 107 | var = [np.array([np.var(warped[0][l],axis=0) for l in warped[0]]), 108 | np.array([np.var(warped[1][l],axis=0) for l in warped[1]])] 109 | else: mean,var = None,None 110 | geometric.train() 111 | classifier.train() 112 | return accuracy,mean,var 113 | -------------------------------------------------------------------------------- /MNIST-pytorch/graph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import time 4 | import data,warp,util 5 | 6 | # build classification network 7 | class FullCNN(torch.nn.Module): 8 | def __init__(self,opt): 9 | super(FullCNN,self).__init__() 10 | self.inDim = 1 11 | def conv2Layer(outDim): 12 | conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[3,3],stride=1,padding=0) 13 | self.inDim = outDim 14 | return conv 15 | def linearLayer(outDim): 16 | fc = torch.nn.Linear(self.inDim,outDim) 17 | self.inDim = outDim 18 | return fc 19 | def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2) 20 
| self.conv2Layers = torch.nn.Sequential( 21 | conv2Layer(3),torch.nn.ReLU(True), 22 | conv2Layer(6),torch.nn.ReLU(True),maxpoolLayer(), 23 | conv2Layer(9),torch.nn.ReLU(True), 24 | conv2Layer(12),torch.nn.ReLU(True) 25 | ) 26 | self.inDim *= 8**2 27 | self.linearLayers = torch.nn.Sequential( 28 | linearLayer(48),torch.nn.ReLU(True), 29 | linearLayer(opt.labelN) 30 | ) 31 | initialize(opt,self,opt.stdC) 32 | def forward(self,opt,image): 33 | feat = image 34 | feat = self.conv2Layers(feat).reshape(opt.batchSize,-1) 35 | feat = self.linearLayers(feat) 36 | output = feat 37 | return output 38 | 39 | # build classification network 40 | class CNN(torch.nn.Module): 41 | def __init__(self,opt): 42 | super(CNN,self).__init__() 43 | self.inDim = 1 44 | def conv2Layer(outDim): 45 | conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[9,9],stride=1,padding=0) 46 | self.inDim = outDim 47 | return conv 48 | def linearLayer(outDim): 49 | fc = torch.nn.Linear(self.inDim,outDim) 50 | self.inDim = outDim 51 | return fc 52 | def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2) 53 | self.conv2Layers = torch.nn.Sequential( 54 | conv2Layer(3),torch.nn.ReLU(True) 55 | ) 56 | self.inDim *= 20**2 57 | self.linearLayers = torch.nn.Sequential( 58 | linearLayer(opt.labelN) 59 | ) 60 | initialize(opt,self,opt.stdC) 61 | def forward(self,opt,image): 62 | feat = image 63 | feat = self.conv2Layers(feat).reshape(opt.batchSize,-1) 64 | feat = self.linearLayers(feat) 65 | output = feat 66 | return output 67 | 68 | # an identity class to skip geometric predictors 69 | class Identity(torch.nn.Module): 70 | def __init__(self): super(Identity,self).__init__() 71 | def forward(self,opt,feat): return [feat] 72 | 73 | # build Spatial Transformer Network 74 | class STN(torch.nn.Module): 75 | def __init__(self,opt): 76 | super(STN,self).__init__() 77 | self.inDim = 1 78 | def conv2Layer(outDim): 79 | conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[7,7],stride=1,padding=0) 80 | self.inDim = 
outDim 81 | return conv 82 | def linearLayer(outDim): 83 | fc = torch.nn.Linear(self.inDim,outDim) 84 | self.inDim = outDim 85 | return fc 86 | def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2) 87 | self.conv2Layers = torch.nn.Sequential( 88 | conv2Layer(4),torch.nn.ReLU(True), 89 | conv2Layer(8),torch.nn.ReLU(True),maxpoolLayer() 90 | ) 91 | self.inDim *= 8**2 92 | self.linearLayers = torch.nn.Sequential( 93 | linearLayer(48),torch.nn.ReLU(True), 94 | linearLayer(opt.warpDim) 95 | ) 96 | initialize(opt,self,opt.stdGP,last0=True) 97 | def forward(self,opt,image): 98 | imageWarpAll = [image] 99 | feat = image 100 | feat = self.conv2Layers(feat).reshape(opt.batchSize,-1) 101 | feat = self.linearLayers(feat) 102 | p = feat 103 | pMtrx = warp.vec2mtrx(opt,p) 104 | imageWarp = warp.transformImage(opt,image,pMtrx) 105 | imageWarpAll.append(imageWarp) 106 | return imageWarpAll 107 | 108 | # build Inverse Compositional STN 109 | class ICSTN(torch.nn.Module): 110 | def __init__(self,opt): 111 | super(ICSTN,self).__init__() 112 | self.inDim = 1 113 | def conv2Layer(outDim): 114 | conv = torch.nn.Conv2d(self.inDim,outDim,kernel_size=[7,7],stride=1,padding=0) 115 | self.inDim = outDim 116 | return conv 117 | def linearLayer(outDim): 118 | fc = torch.nn.Linear(self.inDim,outDim) 119 | self.inDim = outDim 120 | return fc 121 | def maxpoolLayer(): return torch.nn.MaxPool2d([2,2],stride=2) 122 | self.conv2Layers = torch.nn.Sequential( 123 | conv2Layer(4),torch.nn.ReLU(True), 124 | conv2Layer(8),torch.nn.ReLU(True),maxpoolLayer() 125 | ) 126 | self.inDim *= 8**2 127 | self.linearLayers = torch.nn.Sequential( 128 | linearLayer(48),torch.nn.ReLU(True), 129 | linearLayer(opt.warpDim) 130 | ) 131 | initialize(opt,self,opt.stdGP,last0=True) 132 | def forward(self,opt,image,p): 133 | imageWarpAll = [] 134 | for l in range(opt.warpN): 135 | pMtrx = warp.vec2mtrx(opt,p) 136 | imageWarp = warp.transformImage(opt,image,pMtrx) 137 | imageWarpAll.append(imageWarp) 138 | feat = 
imageWarp 139 | feat = self.conv2Layers(feat).reshape(opt.batchSize,-1) 140 | feat = self.linearLayers(feat) 141 | dp = feat 142 | p = warp.compose(opt,p,dp) 143 | pMtrx = warp.vec2mtrx(opt,p) 144 | imageWarp = warp.transformImage(opt,image,pMtrx) 145 | imageWarpAll.append(imageWarp) 146 | return imageWarpAll 147 | 148 | # initialize weights/biases 149 | def initialize(opt,model,stddev,last0=False): 150 | for m in model.conv2Layers: 151 | if isinstance(m,torch.nn.Conv2d): 152 | m.weight.data.normal_(0,stddev) 153 | m.bias.data.normal_(0,stddev) 154 | for m in model.linearLayers: 155 | if isinstance(m,torch.nn.Linear): 156 | if last0 and m is model.linearLayers[-1]: 157 | m.weight.data.zero_() 158 | m.bias.data.zero_() 159 | else: 160 | m.weight.data.normal_(0,stddev) 161 | m.bias.data.normal_(0,stddev) 162 | -------------------------------------------------------------------------------- /MNIST-pytorch/options.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import warp 4 | import util 5 | import torch 6 | 7 | def set(training): 8 | 9 | # parse input arguments 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("netType", choices=["CNN","STN","IC-STN"], help="type of network") 12 | parser.add_argument("--group", default="0", help="name for group") 13 | parser.add_argument("--model", default="test", help="name for model instance") 14 | parser.add_argument("--size", default="28x28", help="image resolution") 15 | parser.add_argument("--warpType", default="homography", help="type of warp function on images", 16 | choices=["translation","similarity","affine","homography"]) 17 | parser.add_argument("--warpN", type=int, default=4, help="number of recurrent transformations (for IC-STN)") 18 | parser.add_argument("--stdC", type=float, default=0.1, help="initialization stddev (classification network)") 19 | parser.add_argument("--stdGP", type=float, default=0.1, 
help="initialization stddev (geometric predictor)") 20 | parser.add_argument("--pertScale", type=float, default=0.25, help="initial perturbation scale") 21 | parser.add_argument("--transScale", type=float, default=0.25, help="initial translation scale") 22 | if training: # training 23 | parser.add_argument("--port", type=int, default=8097, help="port number for visdom visualization") 24 | parser.add_argument("--batchSize", type=int, default=100, help="batch size for SGD") 25 | parser.add_argument("--lrC", type=float, default=1e-2, help="learning rate (classification network)") 26 | parser.add_argument("--lrGP", type=float, default=None, help="learning rate (geometric predictor)") 27 | parser.add_argument("--lrDecay", type=float, default=1.0, help="learning rate decay") 28 | parser.add_argument("--lrStep", type=int, default=100000, help="learning rate decay step size") 29 | parser.add_argument("--fromIt", type=int, default=0, help="resume training from iteration number") 30 | parser.add_argument("--toIt", type=int, default=500000, help="run training to iteration number") 31 | else: # evaluation 32 | parser.add_argument("--batchSize", type=int, default=1, help="batch size for evaluation") 33 | opt = parser.parse_args() 34 | 35 | if opt.lrGP is None: opt.lrGP = 0 if opt.netType=="CNN" else \ 36 | 1e-2 if opt.netType=="STN" else \ 37 | 1e-4 if opt.netType=="IC-STN" else None 38 | 39 | # --- below are automatically set --- 40 | assert(torch.cuda.is_available()) # support only training on GPU for now 41 | torch.set_default_tensor_type("torch.cuda.FloatTensor") 42 | opt.training = training 43 | opt.H,opt.W = [int(x) for x in opt.size.split("x")] 44 | opt.visBlockSize = int(np.floor(np.sqrt(opt.batchSize))) 45 | opt.warpDim = 2 if opt.warpType == "translation" else \ 46 | 4 if opt.warpType == "similarity" else \ 47 | 6 if opt.warpType == "affine" else \ 48 | 8 if opt.warpType == "homography" else None 49 | opt.labelN = 10 50 | opt.canon4pts = 
np.array([[-1,-1],[-1,1],[1,1],[1,-1]],dtype=np.float32) 51 | opt.image4pts = np.array([[0,0],[0,opt.H-1],[opt.W-1,opt.H-1],[opt.W-1,0]],dtype=np.float32) 52 | opt.refMtrx = np.eye(3).astype(np.float32) 53 | if opt.netType=="STN": opt.warpN = 1 54 | 55 | print("({0}) {1}".format( 56 | util.toGreen("{0}".format(opt.group)), 57 | util.toGreen("{0}".format(opt.model)))) 58 | print("------------------------------------------") 59 | print("network type: {0}, recurrent warps: {1}".format( 60 | util.toYellow("{0}".format(opt.netType)), 61 | util.toYellow("{0}".format(opt.warpN if opt.netType=="IC-STN" else "X")))) 62 | print("batch size: {0}, image size: {1}x{2}".format( 63 | util.toYellow("{0}".format(opt.batchSize)), 64 | util.toYellow("{0}".format(opt.H)), 65 | util.toYellow("{0}".format(opt.W)))) 66 | print("warpScale: (pert) {0} (trans) {1}".format( 67 | util.toYellow("{0}".format(opt.pertScale)), 68 | util.toYellow("{0}".format(opt.transScale)))) 69 | if training: 70 | print("[geometric predictor] stddev={0}, lr={1}".format( 71 | util.toYellow("{0:.0e}".format(opt.stdGP)), 72 | util.toYellow("{0:.0e}".format(opt.lrGP)))) 73 | print("[classification network] stddev={0}, lr={1}".format( 74 | util.toYellow("{0:.0e}".format(opt.stdC)), 75 | util.toYellow("{0:.0e}".format(opt.lrC)))) 76 | print("------------------------------------------") 77 | if training: 78 | print(util.toMagenta("training model ({0}) {1}...".format(opt.group,opt.model))) 79 | 80 | return opt 81 | -------------------------------------------------------------------------------- /MNIST-pytorch/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time,os,sys 3 | import argparse 4 | import util 5 | 6 | print(util.toYellow("=======================================================")) 7 | print(util.toYellow("train.py (training on MNIST)")) 8 | print(util.toYellow("=======================================================")) 9 | 10 | import 
torch 11 | import data,graph,warp,util 12 | import options 13 | 14 | print(util.toMagenta("setting configurations...")) 15 | opt = options.set(training=True) 16 | 17 | # create directories for model output 18 | util.mkdir("models_{0}".format(opt.group)) 19 | 20 | print(util.toMagenta("building network...")) 21 | with torch.cuda.device(0): 22 | # ------ build network ------ 23 | if opt.netType=="CNN": 24 | geometric = graph.Identity() 25 | classifier = graph.FullCNN(opt) 26 | elif opt.netType=="STN": 27 | geometric = graph.STN(opt) 28 | classifier = graph.CNN(opt) 29 | elif opt.netType=="IC-STN": 30 | geometric = graph.ICSTN(opt) 31 | classifier = graph.CNN(opt) 32 | # ------ define loss ------ 33 | loss = torch.nn.CrossEntropyLoss() 34 | # ------ optimizer ------ 35 | optimList = [{ "params": geometric.parameters(), "lr": opt.lrGP }, 36 | { "params": classifier.parameters(), "lr": opt.lrC }] 37 | optim = torch.optim.SGD(optimList) 38 | 39 | # load data 40 | print(util.toMagenta("loading MNIST dataset...")) 41 | trainData,testData = data.loadMNIST(opt,"data") 42 | 43 | # visdom visualizer 44 | vis = util.Visdom(opt) 45 | 46 | print(util.toYellow("======= TRAINING START =======")) 47 | timeStart = time.time() 48 | # start session 49 | with torch.cuda.device(0): 50 | geometric.train() 51 | classifier.train() 52 | if opt.fromIt!=0: 53 | util.restoreModel(opt,geometric,classifier,opt.fromIt) 54 | print(util.toMagenta("resuming from iteration {0}...".format(opt.fromIt))) 55 | print(util.toMagenta("start training...")) 56 | 57 | # training loop 58 | for i in range(opt.fromIt,opt.toIt): 59 | lrGP = opt.lrGP*opt.lrDecay**(i//opt.lrStep) 60 | lrC = opt.lrC*opt.lrDecay**(i//opt.lrStep) 61 | # make training batch 62 | batch = data.makeBatch(opt,trainData) 63 | image = batch["image"].unsqueeze(dim=1) 64 | label = batch["label"] 65 | # generate perturbation 66 | pInit = data.genPerturbations(opt) 67 | pInitMtrx = warp.vec2mtrx(opt,pInit) 68 | # forward/backprop through network 
69 | optim.zero_grad() 70 | imagePert = warp.transformImage(opt,image,pInitMtrx) 71 | imageWarpAll = geometric(opt,image,pInit) if opt.netType=="IC-STN" else geometric(opt,imagePert) 72 | imageWarp = imageWarpAll[-1] 73 | output = classifier(opt,imageWarp) 74 | train_loss = loss(output,label) 75 | train_loss.backward() 76 | # run one step 77 | optim.step() 78 | if (i+1)%100==0: 79 | print("it. {0}/{1} lr={3}(GP),{4}(C), loss={5}, time={2}" 80 | .format(util.toCyan("{0}".format(i+1)), 81 | opt.toIt, 82 | util.toGreen("{0:.2f}".format(time.time()-timeStart)), 83 | util.toYellow("{0:.0e}".format(lrGP)), 84 | util.toYellow("{0:.0e}".format(lrC)), 85 | util.toRed("{0:.4f}".format(train_loss)))) 86 | if (i+1)%200==0: vis.trainLoss(opt,i+1,train_loss) 87 | if (i+1)%1000==0: 88 | # evaluate on test set 89 | testAcc,testMean,testVar = data.evalTest(opt,testData,geometric,classifier) 90 | testError = (1-testAcc)*100 91 | vis.testLoss(opt,i+1,testError) 92 | if opt.netType=="STN" or opt.netType=="IC-STN": 93 | vis.meanVar(opt,testMean,testVar) 94 | if (i+1)%10000==0: 95 | util.saveModel(opt,geometric,classifier,i+1) 96 | print(util.toGreen("model saved: {0}/{1}, it.{2}".format(opt.group,opt.model,i+1))) 97 | 98 | print(util.toYellow("======= TRAINING DONE =======")) 99 | -------------------------------------------------------------------------------- /MNIST-pytorch/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import torch 4 | import os 5 | import termcolor 6 | import visdom 7 | 8 | def mkdir(path): 9 | if not os.path.exists(path): os.mkdir(path) 10 | def imread(fname): 11 | return scipy.misc.imread(fname)/255.0 12 | def imsave(fname,array): 13 | scipy.misc.toimage(array,cmin=0.0,cmax=1.0).save(fname) 14 | 15 | # convert to colored strings 16 | def toRed(content): return termcolor.colored(content,"red",attrs=["bold"]) 17 | def toGreen(content): return 
termcolor.colored(content,"green",attrs=["bold"]) 18 | def toBlue(content): return termcolor.colored(content,"blue",attrs=["bold"]) 19 | def toCyan(content): return termcolor.colored(content,"cyan",attrs=["bold"]) 20 | def toYellow(content): return termcolor.colored(content,"yellow",attrs=["bold"]) 21 | def toMagenta(content): return termcolor.colored(content,"magenta",attrs=["bold"]) 22 | 23 | # restore model 24 | def restoreModel(opt,geometric,classifier,it): 25 | geometric.load_state_dict(torch.load("models_{0}/{1}_it{2}_GP.npy".format(opt.group,opt.model,it))) 26 | classifier.load_state_dict(torch.load("models_{0}/{1}_it{2}_C.npy".format(opt.group,opt.model,it))) 27 | # save model 28 | def saveModel(opt,geometric,classifier,it): 29 | torch.save(geometric.state_dict(),"models_{0}/{1}_it{2}_GP.npy".format(opt.group,opt.model,it)) 30 | torch.save(classifier.state_dict(),"models_{0}/{1}_it{2}_C.npy".format(opt.group,opt.model,it)) 31 | 32 | class Visdom(): 33 | def __init__(self,opt): 34 | self.vis = visdom.Visdom(port=opt.port,use_incoming_socket=False) 35 | self.trainLossInit = True 36 | self.testLossInit = True 37 | self.meanVarInit = True 38 | def tileImages(self,opt,images,H,W,HN,WN): 39 | assert(len(images)==HN*WN) 40 | images = images.reshape([HN,WN,-1,H,W]) 41 | images = [list(i) for i in images] 42 | imageBlocks = np.concatenate([np.concatenate(row,axis=2) for row in images],axis=1) 43 | return imageBlocks 44 | def trainLoss(self,opt,it,loss): 45 | loss = float(loss.detach().cpu().numpy()) 46 | if self.trainLossInit: 47 | self.vis.line(Y=np.array([loss]),X=np.array([it]),win="{0}_trainloss".format(opt.model), 48 | opts={ "title": "{0} (TRAIN_loss)".format(opt.model) }) 49 | self.trainLossInit = False 50 | else: self.vis.line(Y=np.array([loss]),X=np.array([it]),win=opt.model+"_trainloss",update="append") 51 | def testLoss(self,opt,it,loss): 52 | if self.testLossInit: 53 | 
self.vis.line(Y=np.array([loss]),X=np.array([it]),win="{0}_testloss".format(opt.model), 54 | opts={ "title": "{0} (TEST_error)".format(opt.model) }) 55 | self.testLossInit = False 56 | else: self.vis.line(Y=np.array([loss]),X=np.array([it]),win=opt.model+"_testloss",update="append") 57 | def meanVar(self,opt,mean,var): 58 | mean = [self.tileImages(opt,m,opt.H,opt.W,1,10) for m in mean] 59 | var = [self.tileImages(opt,v,opt.H,opt.W,1,10)*3 for v in var] 60 | self.vis.image(mean[0].clip(0,1),win="{0}_meaninit".format(opt.model), opts={ "title": "{0} (TEST_mean_init)".format(opt.model) }) 61 | self.vis.image(mean[1].clip(0,1),win="{0}_meanwarped".format(opt.model), opts={ "title": "{0} (TEST_mean_warped)".format(opt.model) }) 62 | self.vis.image(var[0].clip(0,1),win="{0}_varinit".format(opt.model), opts={ "title": "{0} (TEST_var_init)".format(opt.model) }) 63 | self.vis.image(var[1].clip(0,1),win="{0}_varwarped".format(opt.model), opts={ "title": "{0} (TEST_var_warped)".format(opt.model) }) 64 | 65 | -------------------------------------------------------------------------------- /MNIST-pytorch/warp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg 3 | import torch 4 | 5 | import util 6 | 7 | # fit (affine) warp between two sets of points 8 | def fit(Xsrc,Xdst): 9 | ptsN = len(Xsrc) 10 | X,Y,U,V,O,I = Xsrc[:,0],Xsrc[:,1],Xdst[:,0],Xdst[:,1],np.zeros([ptsN]),np.ones([ptsN]) 11 | A = np.concatenate((np.stack([X,Y,I,O,O,O],axis=1), 12 | np.stack([O,O,O,X,Y,I],axis=1)),axis=0) 13 | b = np.concatenate((U,V),axis=0) 14 | p1,p2,p3,p4,p5,p6 = scipy.linalg.lstsq(A,b)[0].squeeze() 15 | pMtrx = np.array([[p1,p2,p3],[p4,p5,p6],[0,0,1]],dtype=torch.float32) 16 | return pMtrx 17 | 18 | # compute composition of warp parameters 19 | def compose(opt,p,dp): 20 | pMtrx = vec2mtrx(opt,p) 21 | dpMtrx = vec2mtrx(opt,dp) 22 | pMtrxNew = dpMtrx.matmul(pMtrx) 23 | pMtrxNew = pMtrxNew/pMtrxNew[:,2:3,2:3] 24 | 
pNew = mtrx2vec(opt,pMtrxNew) 25 | return pNew 26 | 27 | # compute inverse of warp parameters 28 | def inverse(opt,p): 29 | pMtrx = vec2mtrx(opt,p) 30 | pInvMtrx = pMtrx.inverse() 31 | pInv = mtrx2vec(opt,pInvMtrx) 32 | return pInv 33 | 34 | # convert warp parameters to matrix 35 | def vec2mtrx(opt,p): 36 | O = torch.zeros(opt.batchSize,dtype=torch.float32).cuda() 37 | I = torch.ones(opt.batchSize,dtype=torch.float32).cuda() 38 | if opt.warpType=="translation": 39 | tx,ty = torch.unbind(p,dim=1) 40 | pMtrx = torch.stack([torch.stack([I,O,tx],dim=-1), 41 | torch.stack([O,I,ty],dim=-1), 42 | torch.stack([O,O,I],dim=-1)],dim=1) 43 | if opt.warpType=="similarity": 44 | pc,ps,tx,ty = torch.unbind(p,dim=1) 45 | pMtrx = torch.stack([torch.stack([I+pc,-ps,tx],dim=-1), 46 | torch.stack([ps,I+pc,ty],dim=-1), 47 | torch.stack([O,O,I],dim=-1)],dim=1) 48 | if opt.warpType=="affine": 49 | p1,p2,p3,p4,p5,p6 = torch.unbind(p,dim=1) 50 | pMtrx = torch.stack([torch.stack([I+p1,p2,p3],dim=-1), 51 | torch.stack([p4,I+p5,p6],dim=-1), 52 | torch.stack([O,O,I],dim=-1)],dim=1) 53 | if opt.warpType=="homography": 54 | p1,p2,p3,p4,p5,p6,p7,p8 = torch.unbind(p,dim=1) 55 | pMtrx = torch.stack([torch.stack([I+p1,p2,p3],dim=-1), 56 | torch.stack([p4,I+p5,p6],dim=-1), 57 | torch.stack([p7,p8,I],dim=-1)],dim=1) 58 | return pMtrx 59 | 60 | # convert warp matrix to parameters 61 | def mtrx2vec(opt,pMtrx): 62 | [row0,row1,row2] = torch.unbind(pMtrx,dim=1) 63 | [e00,e01,e02] = torch.unbind(row0,dim=1) 64 | [e10,e11,e12] = torch.unbind(row1,dim=1) 65 | [e20,e21,e22] = torch.unbind(row2,dim=1) 66 | if opt.warpType=="translation": p = torch.stack([e02,e12],dim=1) 67 | if opt.warpType=="similarity": p = torch.stack([e00-1,e10,e02,e12],dim=1) 68 | if opt.warpType=="affine": p = torch.stack([e00-1,e01,e02,e10,e11-1,e12],dim=1) 69 | if opt.warpType=="homography": p = torch.stack([e00-1,e01,e02,e10,e11-1,e12,e20,e21],dim=1) 70 | return p 71 | 72 | # warp the image 73 | def transformImage(opt,image,pMtrx): 74 
| refMtrx = torch.from_numpy(opt.refMtrx).cuda() 75 | refMtrx = refMtrx.repeat(opt.batchSize,1,1) 76 | transMtrx = refMtrx.matmul(pMtrx) 77 | # warp the canonical coordinates 78 | X,Y = np.meshgrid(np.linspace(-1,1,opt.W),np.linspace(-1,1,opt.H)) 79 | X,Y = X.flatten(),Y.flatten() 80 | XYhom = np.stack([X,Y,np.ones_like(X)],axis=1).T 81 | XYhom = np.tile(XYhom,[opt.batchSize,1,1]).astype(np.float32) 82 | XYhom = torch.from_numpy(XYhom).cuda() 83 | XYwarpHom = transMtrx.matmul(XYhom) 84 | XwarpHom,YwarpHom,ZwarpHom = torch.unbind(XYwarpHom,dim=1) 85 | Xwarp = (XwarpHom/(ZwarpHom+1e-8)).reshape(opt.batchSize,opt.H,opt.W) 86 | Ywarp = (YwarpHom/(ZwarpHom+1e-8)).reshape(opt.batchSize,opt.H,opt.W) 87 | grid = torch.stack([Xwarp,Ywarp],dim=-1) 88 | # sampling with bilinear interpolation 89 | imageWarp = torch.nn.functional.grid_sample(image,grid,mode="bilinear") 90 | return imageWarp 91 | -------------------------------------------------------------------------------- /MNIST-tensorflow/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg 3 | import os,time 4 | import tensorflow as tf 5 | 6 | import warp 7 | 8 | # load MNIST data 9 | def loadMNIST(fname): 10 | if not os.path.exists(fname): 11 | # download and preprocess MNIST dataset 12 | from tensorflow.examples.tutorials.mnist import input_data 13 | mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) 14 | trainData,validData,testData = {},{},{} 15 | trainData["image"] = mnist.train.images.reshape([-1,28,28]).astype(np.float32) 16 | validData["image"] = mnist.validation.images.reshape([-1,28,28]).astype(np.float32) 17 | testData["image"] = mnist.test.images.reshape([-1,28,28]).astype(np.float32) 18 | trainData["label"] = np.argmax(mnist.train.labels.astype(np.float32),axis=1) 19 | validData["label"] = np.argmax(mnist.validation.labels.astype(np.float32),axis=1) 20 | testData["label"] = 
np.argmax(mnist.test.labels.astype(np.float32),axis=1) 21 | os.makedirs(os.path.dirname(fname)) 22 | np.savez(fname,train=trainData,valid=validData,test=testData) 23 | os.system("rm -rf MNIST_data") 24 | MNIST = np.load(fname) 25 | trainData = MNIST["train"].item() 26 | validData = MNIST["valid"].item() 27 | testData = MNIST["test"].item() 28 | return trainData,validData,testData 29 | 30 | # generate training batch 31 | def genPerturbations(opt): 32 | with tf.name_scope("genPerturbations"): 33 | X = np.tile(opt.canon4pts[:,0],[opt.batchSize,1]) 34 | Y = np.tile(opt.canon4pts[:,1],[opt.batchSize,1]) 35 | dX = tf.random_normal([opt.batchSize,4])*opt.pertScale \ 36 | +tf.random_normal([opt.batchSize,1])*opt.transScale 37 | dY = tf.random_normal([opt.batchSize,4])*opt.pertScale \ 38 | +tf.random_normal([opt.batchSize,1])*opt.transScale 39 | O = np.zeros([opt.batchSize,4],dtype=np.float32) 40 | I = np.ones([opt.batchSize,4],dtype=np.float32) 41 | # fit warp parameters to generated displacements 42 | if opt.warpType=="homography": 43 | A = tf.concat([tf.stack([X,Y,I,O,O,O,-X*(X+dX),-Y*(X+dX)],axis=-1), 44 | tf.stack([O,O,O,X,Y,I,-X*(Y+dY),-Y*(Y+dY)],axis=-1)],1) 45 | b = tf.expand_dims(tf.concat([X+dX,Y+dY],1),-1) 46 | pPert = tf.matrix_solve(A,b)[:,:,0] 47 | pPert -= tf.to_float([[1,0,0,0,1,0,0,0]]) 48 | else: 49 | if opt.warpType=="translation": 50 | J = np.concatenate([np.stack([I,O],axis=-1), 51 | np.stack([O,I],axis=-1)],axis=1) 52 | if opt.warpType=="similarity": 53 | J = np.concatenate([np.stack([X,Y,I,O],axis=-1), 54 | np.stack([-Y,X,O,I],axis=-1)],axis=1) 55 | if opt.warpType=="affine": 56 | J = np.concatenate([np.stack([X,Y,I,O,O,O],axis=-1), 57 | np.stack([O,O,O,X,Y,I],axis=-1)],axis=1) 58 | dXY = tf.expand_dims(tf.concat([dX,dY],1),-1) 59 | pPert = tf.matrix_solve_ls(J,dXY)[:,:,0] 60 | return pPert 61 | 62 | # make training batch 63 | def makeBatch(opt,data,PH): 64 | N = len(data["image"]) 65 | randIdx = np.random.randint(N,size=[opt.batchSize]) 66 | # put 
data in placeholders 67 | [image,label] = PH 68 | batch = { 69 | image: data["image"][randIdx], 70 | label: data["label"][randIdx], 71 | } 72 | return batch 73 | 74 | # evaluation on test set 75 | def evalTest(opt,sess,data,PH,prediction,imagesEval=[]): 76 | N = len(data["image"]) 77 | # put data in placeholders 78 | [image,label] = PH 79 | batchN = int(np.ceil(N/opt.batchSize)) 80 | warped = [{},{}] 81 | count = 0 82 | for b in range(batchN): 83 | # use some dummy data (0) as batch filler if necessary 84 | if b!=batchN-1: 85 | realIdx = np.arange(opt.batchSize*b,opt.batchSize*(b+1)) 86 | else: 87 | realIdx = np.arange(opt.batchSize*b,N) 88 | idx = np.zeros([opt.batchSize],dtype=int) 89 | idx[:len(realIdx)] = realIdx 90 | batch = { 91 | image: data["image"][idx], 92 | label: data["label"][idx], 93 | } 94 | evalList = sess.run([prediction]+imagesEval,feed_dict=batch) 95 | pred = evalList[0] 96 | count += pred[:len(realIdx)].sum() 97 | if opt.netType=="STN" or opt.netType=="IC-STN": 98 | imgs = evalList[1:] 99 | for i in range(len(realIdx)): 100 | l = data["label"][idx[i]] 101 | if l not in warped[0]: warped[0][l] = [] 102 | if l not in warped[1]: warped[1][l] = [] 103 | warped[0][l].append(imgs[0][i]) 104 | warped[1][l].append(imgs[1][i]) 105 | accuracy = float(count)/N 106 | if opt.netType=="STN" or opt.netType=="IC-STN": 107 | mean = [np.array([np.mean(warped[0][l],axis=0) for l in warped[0]]), 108 | np.array([np.mean(warped[1][l],axis=0) for l in warped[1]])] 109 | var = [np.array([np.var(warped[0][l],axis=0) for l in warped[0]]), 110 | np.array([np.var(warped[1][l],axis=0) for l in warped[1]])] 111 | else: mean,var = None,None 112 | return accuracy,mean,var 113 | -------------------------------------------------------------------------------- /MNIST-tensorflow/graph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import time 4 | import data,warp,util 5 | 6 | # build 
classification network 7 | def fullCNN(opt,image): 8 | def conv2Layer(opt,feat,outDim): 9 | weight,bias = createVariable(opt,[3,3,int(feat.shape[-1]),outDim],stddev=opt.stdC) 10 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias 11 | return conv 12 | def linearLayer(opt,feat,outDim): 13 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=opt.stdC) 14 | fc = tf.matmul(feat,weight)+bias 15 | return fc 16 | with tf.variable_scope("classifier"): 17 | feat = image 18 | with tf.variable_scope("conv1"): 19 | feat = conv2Layer(opt,feat,3) 20 | feat = tf.nn.relu(feat) 21 | with tf.variable_scope("conv2"): 22 | feat = conv2Layer(opt,feat,6) 23 | feat = tf.nn.relu(feat) 24 | feat = tf.nn.max_pool(feat,ksize=[1,2,2,1],strides=[1,2,2,1],padding="VALID") 25 | with tf.variable_scope("conv3"): 26 | feat = conv2Layer(opt,feat,9) 27 | feat = tf.nn.relu(feat) 28 | with tf.variable_scope("conv4"): 29 | feat = conv2Layer(opt,feat,12) 30 | feat = tf.nn.relu(feat) 31 | feat = tf.reshape(feat,[opt.batchSize,-1]) 32 | with tf.variable_scope("fc5"): 33 | feat = linearLayer(opt,feat,48) 34 | feat = tf.nn.relu(feat) 35 | with tf.variable_scope("fc6"): 36 | feat = linearLayer(opt,feat,opt.labelN) 37 | output = feat 38 | return output 39 | 40 | # build classification network 41 | def CNN(opt,image): 42 | def conv2Layer(opt,feat,outDim): 43 | weight,bias = createVariable(opt,[9,9,int(feat.shape[-1]),outDim],stddev=opt.stdC) 44 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias 45 | return conv 46 | def linearLayer(opt,feat,outDim): 47 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=opt.stdC) 48 | fc = tf.matmul(feat,weight)+bias 49 | return fc 50 | with tf.variable_scope("classifier"): 51 | feat = image 52 | with tf.variable_scope("conv1"): 53 | feat = conv2Layer(opt,feat,3) 54 | feat = tf.nn.relu(feat) 55 | feat = tf.reshape(feat,[opt.batchSize,-1]) 56 | with tf.variable_scope("fc2"): 57 | feat = 
linearLayer(opt,feat,opt.labelN) 58 | output = feat 59 | return output 60 | 61 | # build Spatial Transformer Network 62 | def STN(opt,image): 63 | def conv2Layer(opt,feat,outDim): 64 | weight,bias = createVariable(opt,[7,7,int(feat.shape[-1]),outDim],stddev=opt.stdGP) 65 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias 66 | return conv 67 | def linearLayer(opt,feat,outDim,final=False): 68 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=0.0 if final else opt.stdGP) 69 | fc = tf.matmul(feat,weight)+bias 70 | return fc 71 | imageWarpAll = [image] 72 | with tf.variable_scope("geometric"): 73 | feat = image 74 | with tf.variable_scope("conv1"): 75 | feat = conv2Layer(opt,feat,4) 76 | feat = tf.nn.relu(feat) 77 | with tf.variable_scope("conv2"): 78 | feat = conv2Layer(opt,feat,8) 79 | feat = tf.nn.relu(feat) 80 | feat = tf.nn.max_pool(feat,ksize=[1,2,2,1],strides=[1,2,2,1],padding="VALID") 81 | feat = tf.reshape(feat,[opt.batchSize,-1]) 82 | with tf.variable_scope("fc3"): 83 | feat = linearLayer(opt,feat,48) 84 | feat = tf.nn.relu(feat) 85 | with tf.variable_scope("fc4"): 86 | feat = linearLayer(opt,feat,opt.warpDim,final=True) 87 | p = feat 88 | pMtrx = warp.vec2mtrx(opt,p) 89 | imageWarp = warp.transformImage(opt,image,pMtrx) 90 | imageWarpAll.append(imageWarp) 91 | return imageWarpAll 92 | 93 | # build Inverse Compositional STN 94 | def ICSTN(opt,image,p): 95 | def conv2Layer(opt,feat,outDim): 96 | weight,bias = createVariable(opt,[7,7,int(feat.shape[-1]),outDim],stddev=opt.stdGP) 97 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias 98 | return conv 99 | def linearLayer(opt,feat,outDim,final=False): 100 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=0.0 if final else opt.stdGP) 101 | fc = tf.matmul(feat,weight)+bias 102 | return fc 103 | imageWarpAll = [] 104 | for l in range(opt.warpN): 105 | with tf.variable_scope("geometric",reuse=l>0): 106 | pMtrx = warp.vec2mtrx(opt,p) 
107 | imageWarp = warp.transformImage(opt,image,pMtrx) 108 | imageWarpAll.append(imageWarp) 109 | feat = imageWarp 110 | with tf.variable_scope("conv1"): 111 | feat = conv2Layer(opt,feat,4) 112 | feat = tf.nn.relu(feat) 113 | with tf.variable_scope("conv2"): 114 | feat = conv2Layer(opt,feat,8) 115 | feat = tf.nn.relu(feat) 116 | feat = tf.nn.max_pool(feat,ksize=[1,2,2,1],strides=[1,2,2,1],padding="VALID") 117 | feat = tf.reshape(feat,[opt.batchSize,-1]) 118 | with tf.variable_scope("fc3"): 119 | feat = linearLayer(opt,feat,48) 120 | feat = tf.nn.relu(feat) 121 | with tf.variable_scope("fc4"): 122 | feat = linearLayer(opt,feat,opt.warpDim,final=True) 123 | dp = feat 124 | p = warp.compose(opt,p,dp) 125 | pMtrx = warp.vec2mtrx(opt,p) 126 | imageWarp = warp.transformImage(opt,image,pMtrx) 127 | imageWarpAll.append(imageWarp) 128 | return imageWarpAll 129 | 130 | # auxiliary function for creating weight and bias 131 | def createVariable(opt,weightShape,biasShape=None,stddev=None): 132 | if biasShape is None: biasShape = [weightShape[-1]] 133 | weight = tf.get_variable("weight",shape=weightShape,dtype=tf.float32, 134 | initializer=tf.random_normal_initializer(stddev=stddev)) 135 | bias = tf.get_variable("bias",shape=biasShape,dtype=tf.float32, 136 | initializer=tf.random_normal_initializer(stddev=stddev)) 137 | return weight,bias 138 | -------------------------------------------------------------------------------- /MNIST-tensorflow/options.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import warp 4 | import util 5 | 6 | def set(training): 7 | 8 | # parse input arguments 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("netType", choices=["CNN","STN","IC-STN"], help="type of network") 11 | parser.add_argument("--group", default="0", help="name for group") 12 | parser.add_argument("--model", default="test", help="name for model instance") 13 | 
parser.add_argument("--size", default="28x28", help="image resolution") 14 | parser.add_argument("--warpType", default="homography", help="type of warp function on images", 15 | choices=["translation","similarity","affine","homography"]) 16 | parser.add_argument("--warpN", type=int, default=4, help="number of recurrent transformations (for IC-STN)") 17 | parser.add_argument("--stdC", type=float, default=0.1, help="initialization stddev (classification network)") 18 | parser.add_argument("--stdGP", type=float, default=0.1, help="initialization stddev (geometric predictor)") 19 | parser.add_argument("--pertScale", type=float, default=0.25, help="initial perturbation scale") 20 | parser.add_argument("--transScale", type=float, default=0.25, help="initial translation scale") 21 | if training: # training 22 | parser.add_argument("--batchSize", type=int, default=100, help="batch size for SGD") 23 | parser.add_argument("--lrC", type=float, default=1e-2, help="learning rate (classification network)") 24 | parser.add_argument("--lrCdecay", type=float, default=1.0, help="learning rate decay (classification network)") 25 | parser.add_argument("--lrCstep", type=int, default=100000, help="learning rate decay step size (classification network)") 26 | parser.add_argument("--lrGP", type=float, default=None, help="learning rate (geometric predictor)") 27 | parser.add_argument("--lrGPdecay", type=float, default=1.0, help="learning rate decay (geometric predictor)") 28 | parser.add_argument("--lrGPstep", type=int, default=100000, help="learning rate decay step size (geometric predictor)") 29 | parser.add_argument("--fromIt", type=int, default=0, help="resume training from iteration number") 30 | parser.add_argument("--toIt", type=int, default=500000, help="run training to iteration number") 31 | else: # evaluation 32 | parser.add_argument("--batchSize", type=int, default=1, help="batch size for evaluation") 33 | opt = parser.parse_args() 34 | 35 | if opt.lrGP is None: opt.lrGP = 0 if 
opt.netType=="CNN" else \ 36 | 1e-2 if opt.netType=="STN" else \ 37 | 1e-4 if opt.netType=="IC-STN" else None 38 | 39 | # --- below are automatically set --- 40 | opt.training = training 41 | opt.H,opt.W = [int(x) for x in opt.size.split("x")] 42 | opt.visBlockSize = int(np.floor(np.sqrt(opt.batchSize))) 43 | opt.warpDim = 2 if opt.warpType == "translation" else \ 44 | 4 if opt.warpType == "similarity" else \ 45 | 6 if opt.warpType == "affine" else \ 46 | 8 if opt.warpType == "homography" else None 47 | opt.labelN = 10 48 | opt.canon4pts = np.array([[-1,-1],[-1,1],[1,1],[1,-1]],dtype=np.float32) 49 | opt.image4pts = np.array([[0,0],[0,opt.H-1],[opt.W-1,opt.H-1],[opt.W-1,0]],dtype=np.float32) 50 | opt.refMtrx = warp.fit(Xsrc=opt.canon4pts,Xdst=opt.image4pts) 51 | if opt.netType=="STN": opt.warpN = 1 52 | 53 | print("({0}) {1}".format( 54 | util.toGreen("{0}".format(opt.group)), 55 | util.toGreen("{0}".format(opt.model)))) 56 | print("------------------------------------------") 57 | print("network type: {0}, recurrent warps: {1}".format( 58 | util.toYellow("{0}".format(opt.netType)), 59 | util.toYellow("{0}".format(opt.warpN if opt.netType=="IC-STN" else "X")))) 60 | print("batch size: {0}, image size: {1}x{2}".format( 61 | util.toYellow("{0}".format(opt.batchSize)), 62 | util.toYellow("{0}".format(opt.H)), 63 | util.toYellow("{0}".format(opt.W)))) 64 | print("warpScale: (pert) {0} (trans) {1}".format( 65 | util.toYellow("{0}".format(opt.pertScale)), 66 | util.toYellow("{0}".format(opt.transScale)))) 67 | if training: 68 | print("[geometric predictor] stddev={0}, lr={1}".format( 69 | util.toYellow("{0:.0e}".format(opt.stdGP)), 70 | util.toYellow("{0:.0e}".format(opt.lrGP)))) 71 | print("[classification network] stddev={0}, lr={1}".format( 72 | util.toYellow("{0:.0e}".format(opt.stdC)), 73 | util.toYellow("{0:.0e}".format(opt.lrC)))) 74 | print("------------------------------------------") 75 | if training: 76 | print(util.toMagenta("training model ({0}) 
{1}...".format(opt.group,opt.model))) 77 | 78 | return opt 79 | -------------------------------------------------------------------------------- /MNIST-tensorflow/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time,os,sys 3 | import argparse 4 | import util 5 | 6 | print(util.toYellow("=======================================================")) 7 | print(util.toYellow("train.py (training on MNIST)")) 8 | print(util.toYellow("=======================================================")) 9 | 10 | import tensorflow as tf 11 | import data,graph,warp,util 12 | import options 13 | 14 | print(util.toMagenta("setting configurations...")) 15 | opt = options.set(training=True) 16 | 17 | # create directories for model output 18 | util.mkdir("models_{0}".format(opt.group)) 19 | 20 | print(util.toMagenta("building graph...")) 21 | tf.reset_default_graph() 22 | # build graph 23 | with tf.device("/gpu:0"): 24 | # ------ define input data ------ 25 | image = tf.placeholder(tf.float32,shape=[opt.batchSize,opt.H,opt.W]) 26 | label = tf.placeholder(tf.int64,shape=[opt.batchSize]) 27 | PH = [image,label] 28 | # ------ generate perturbation ------ 29 | pInit = data.genPerturbations(opt) 30 | pInitMtrx = warp.vec2mtrx(opt,pInit) 31 | # ------ build network ------ 32 | image = tf.expand_dims(image,axis=-1) 33 | imagePert = warp.transformImage(opt,image,pInitMtrx) 34 | if opt.netType=="CNN": 35 | output = graph.fullCNN(opt,imagePert) 36 | elif opt.netType=="STN": 37 | imageWarpAll = graph.STN(opt,imagePert) 38 | imageWarp = imageWarpAll[-1] 39 | output = graph.CNN(opt,imageWarp) 40 | elif opt.netType=="IC-STN": 41 | imageWarpAll = graph.ICSTN(opt,image,pInit) 42 | imageWarp = imageWarpAll[-1] 43 | output = graph.CNN(opt,imageWarp) 44 | softmax = tf.nn.softmax(output) 45 | labelOnehot = tf.one_hot(label,opt.labelN) 46 | prediction = tf.equal(tf.argmax(softmax,1),label) 47 | # ------ define loss ------ 48 | 
softmaxLoss = tf.nn.softmax_cross_entropy_with_logits(logits=output,labels=labelOnehot) 49 | loss = tf.reduce_mean(softmaxLoss) 50 | # ------ optimizer ------ 51 | lrGP_PH,lrC_PH = tf.placeholder(tf.float32,shape=[]),tf.placeholder(tf.float32,shape=[]) 52 | optim = util.setOptimizer(opt,loss,lrGP_PH,lrC_PH) 53 | # ------ generate summaries ------ 54 | summaryImageTrain = [] 55 | summaryImageTest = [] 56 | if opt.netType=="STN" or opt.netType=="IC-STN": 57 | for l in range(opt.warpN+1): 58 | summaryImageTrain.append(util.imageSummary(opt,imageWarpAll[l],"TRAIN_warp{0}".format(l),opt.H,opt.W)) 59 | summaryImageTest.append(util.imageSummary(opt,imageWarpAll[l],"TEST_warp{0}".format(l),opt.H,opt.W)) 60 | summaryImageTrain = tf.summary.merge(summaryImageTrain) 61 | summaryImageTest = tf.summary.merge(summaryImageTest) 62 | summaryLossTrain = tf.summary.scalar("TRAIN_loss",loss) 63 | testErrorPH = tf.placeholder(tf.float32,shape=[]) 64 | testImagePH = tf.placeholder(tf.float32,shape=[opt.labelN,opt.H,opt.W,1]) 65 | summaryErrorTest = tf.summary.scalar("TEST_error",testErrorPH) 66 | if opt.netType=="STN" or opt.netType=="IC-STN": 67 | summaryMeanTest0 = util.imageSummaryMeanVar(opt,testImagePH,"TEST_mean_init",opt.H,opt.W) 68 | summaryMeanTest1 = util.imageSummaryMeanVar(opt,testImagePH,"TEST_mean_warped",opt.H,opt.W) 69 | summaryVarTest0 = util.imageSummaryMeanVar(opt,testImagePH*3,"TEST_var_init",opt.H,opt.W) 70 | summaryVarTest1 = util.imageSummaryMeanVar(opt,testImagePH*3,"TEST_var_warped",opt.H,opt.W) 71 | 72 | # load data 73 | print(util.toMagenta("loading MNIST dataset...")) 74 | trainData,validData,testData = data.loadMNIST("data/MNIST.npz") 75 | 76 | # prepare model saver/summary writer 77 | saver = tf.train.Saver(max_to_keep=20) 78 | summaryWriter = tf.summary.FileWriter("summary_{0}/{1}".format(opt.group,opt.model)) 79 | 80 | print(util.toYellow("======= TRAINING START =======")) 81 | timeStart = time.time() 82 | # start session 83 | tfConfig = 
tf.ConfigProto(allow_soft_placement=True) 84 | tfConfig.gpu_options.allow_growth = True 85 | with tf.Session(config=tfConfig) as sess: 86 | sess.run(tf.global_variables_initializer()) 87 | summaryWriter.add_graph(sess.graph) 88 | if opt.fromIt!=0: 89 | util.restoreModel(opt,sess,saver,opt.fromIt) 90 | print(util.toMagenta("resuming from iteration {0}...".format(opt.fromIt))) 91 | print(util.toMagenta("start training...")) 92 | 93 | # training loop 94 | for i in range(opt.fromIt,opt.toIt): 95 | lrGP = opt.lrGP*opt.lrGPdecay**(i//opt.lrGPstep) 96 | lrC = opt.lrC*opt.lrCdecay**(i//opt.lrCstep) 97 | # make training batch 98 | batch = data.makeBatch(opt,trainData,PH) 99 | batch[lrGP_PH] = lrGP 100 | batch[lrC_PH] = lrC 101 | # run one step 102 | _,l = sess.run([optim,loss],feed_dict=batch) 103 | if (i+1)%100==0: 104 | print("it. {0}/{1} lr={3}(GP),{4}(C), loss={5}, time={2}" 105 | .format(util.toCyan("{0}".format(i+1)), 106 | opt.toIt, 107 | util.toGreen("{0:.2f}".format(time.time()-timeStart)), 108 | util.toYellow("{0:.0e}".format(lrGP)), 109 | util.toYellow("{0:.0e}".format(lrC)), 110 | util.toRed("{0:.4f}".format(l)))) 111 | if (i+1)%100==0: 112 | summaryWriter.add_summary(sess.run(summaryLossTrain,feed_dict=batch),i+1) 113 | if (i+1)%500==0 and (opt.netType=="STN" or opt.netType=="IC-STN"): 114 | summaryWriter.add_summary(sess.run(summaryImageTrain,feed_dict=batch),i+1) 115 | summaryWriter.add_summary(sess.run(summaryImageTest,feed_dict=batch),i+1) 116 | if (i+1)%1000==0: 117 | # evaluate on test set 118 | if opt.netType=="STN" or opt.netType=="IC-STN": 119 | testAcc,testMean,testVar = data.evalTest(opt,sess,testData,PH,prediction,imagesEval=[imagePert,imageWarp]) 120 | else: 121 | testAcc,_,_ = data.evalTest(opt,sess,testData,PH,prediction) 122 | testError = (1-testAcc)*100 123 | summaryWriter.add_summary(sess.run(summaryErrorTest,feed_dict={testErrorPH:testError}),i+1) 124 | if opt.netType=="STN" or opt.netType=="IC-STN": 125 | 
summaryWriter.add_summary(sess.run(summaryMeanTest0,feed_dict={testImagePH:testMean[0]}),i+1) 126 | summaryWriter.add_summary(sess.run(summaryMeanTest1,feed_dict={testImagePH:testMean[1]}),i+1) 127 | summaryWriter.add_summary(sess.run(summaryVarTest0,feed_dict={testImagePH:testVar[0]}),i+1) 128 | summaryWriter.add_summary(sess.run(summaryVarTest1,feed_dict={testImagePH:testVar[1]}),i+1) 129 | if (i+1)%10000==0: 130 | util.saveModel(opt,sess,saver,i+1) 131 | print(util.toGreen("model saved: {0}/{1}, it.{2}".format(opt.group,opt.model,i+1))) 132 | 133 | print(util.toYellow("======= TRAINING DONE =======")) 134 | -------------------------------------------------------------------------------- /MNIST-tensorflow/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import tensorflow as tf 4 | import os 5 | import termcolor 6 | 7 | def mkdir(path): 8 | if not os.path.exists(path): os.mkdir(path) 9 | def imread(fname): 10 | return scipy.misc.imread(fname)/255.0 11 | def imsave(fname,array): 12 | scipy.misc.toimage(array,cmin=0.0,cmax=1.0).save(fname) 13 | 14 | # convert to colored strings 15 | def toRed(content): return termcolor.colored(content,"red",attrs=["bold"]) 16 | def toGreen(content): return termcolor.colored(content,"green",attrs=["bold"]) 17 | def toBlue(content): return termcolor.colored(content,"blue",attrs=["bold"]) 18 | def toCyan(content): return termcolor.colored(content,"cyan",attrs=["bold"]) 19 | def toYellow(content): return termcolor.colored(content,"yellow",attrs=["bold"]) 20 | def toMagenta(content): return termcolor.colored(content,"magenta",attrs=["bold"]) 21 | 22 | # make image summary from image batch 23 | def imageSummary(opt,image,tag,H,W): 24 | blockSize = opt.visBlockSize 25 | imageOne = tf.batch_to_space(image[:blockSize**2],crops=[[0,0],[0,0]],block_size=blockSize) 26 | imagePermute = tf.reshape(imageOne,[H,blockSize,W,blockSize,-1]) 27 | imageTransp = 
tf.transpose(imagePermute,[1,0,3,2,4]) 28 | imageBlocks = tf.reshape(imageTransp,[1,H*blockSize,W*blockSize,-1]) 29 | imageBlocks = tf.cast(imageBlocks*255,tf.uint8) 30 | summary = tf.summary.image(tag,imageBlocks) 31 | return summary 32 | 33 | # make image summary from image batch (mean/variance) 34 | def imageSummaryMeanVar(opt,image,tag,H,W): 35 | imageOne = tf.batch_to_space_nd(image,crops=[[0,0],[0,0]],block_shape=[1,10]) 36 | imagePermute = tf.reshape(imageOne,[H,1,W,10,-1]) 37 | imageTransp = tf.transpose(imagePermute,[1,0,3,2,4]) 38 | imageBlocks = tf.reshape(imageTransp,[1,H*1,W*10,-1]) 39 | imageBlocks = tf.cast(imageBlocks*255,tf.uint8) 40 | summary = tf.summary.image(tag,imageBlocks) 41 | return summary 42 | 43 | # set optimizer for different learning rates 44 | def setOptimizer(opt,loss,lrGP,lrC): 45 | varsGP = [v for v in tf.global_variables() if "geometric" in v.name] 46 | varsC = [v for v in tf.global_variables() if "classifier" in v.name] 47 | gradC = tf.gradients(loss,varsC) 48 | optimC = tf.train.GradientDescentOptimizer(lrC).apply_gradients(zip(gradC,varsC)) 49 | if len(varsGP)>0: 50 | gradGP = tf.gradients(loss,varsGP) 51 | optimGP = tf.train.GradientDescentOptimizer(lrGP).apply_gradients(zip(gradGP,varsGP)) 52 | optim = tf.group(optimC,optimGP) 53 | else: 54 | optim = optimC 55 | return optim 56 | 57 | # restore model 58 | def restoreModel(opt,sess,saver,it): 59 | saver.restore(sess,"models_{0}/{1}_it{2}.ckpt".format(opt.group,opt.model,it,opt.warpN)) 60 | # save model 61 | def saveModel(opt,sess,saver,it): 62 | saver.save(sess,"models_{0}/{1}_it{2}.ckpt".format(opt.group,opt.model,it,opt.warpN)) 63 | 64 | -------------------------------------------------------------------------------- /MNIST-tensorflow/warp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg 3 | import tensorflow as tf 4 | 5 | # fit (affine) warp between two sets of points 6 | def fit(Xsrc,Xdst): 
7 | ptsN = len(Xsrc) 8 | X,Y,U,V,O,I = Xsrc[:,0],Xsrc[:,1],Xdst[:,0],Xdst[:,1],np.zeros([ptsN]),np.ones([ptsN]) 9 | A = np.concatenate((np.stack([X,Y,I,O,O,O],axis=1), 10 | np.stack([O,O,O,X,Y,I],axis=1)),axis=0) 11 | b = np.concatenate((U,V),axis=0) 12 | p1,p2,p3,p4,p5,p6 = scipy.linalg.lstsq(A,b)[0].squeeze() 13 | pMtrx = np.array([[p1,p2,p3],[p4,p5,p6],[0,0,1]],dtype=np.float32) 14 | return pMtrx 15 | 16 | # compute composition of warp parameters 17 | def compose(opt,p,dp): 18 | with tf.name_scope("compose"): 19 | pMtrx = vec2mtrx(opt,p) 20 | dpMtrx = vec2mtrx(opt,dp) 21 | pMtrxNew = tf.matmul(dpMtrx,pMtrx) 22 | pMtrxNew /= pMtrxNew[:,2:3,2:3] 23 | pNew = mtrx2vec(opt,pMtrxNew) 24 | return pNew 25 | 26 | # compute inverse of warp parameters 27 | def inverse(opt,p): 28 | with tf.name_scope("inverse"): 29 | pMtrx = vec2mtrx(opt,p) 30 | pInvMtrx = tf.matrix_inverse(pMtrx) 31 | pInv = mtrx2vec(opt,pInvMtrx) 32 | return pInv 33 | 34 | # convert warp parameters to matrix 35 | def vec2mtrx(opt,p): 36 | with tf.name_scope("vec2mtrx"): 37 | O = tf.zeros([opt.batchSize]) 38 | I = tf.ones([opt.batchSize]) 39 | if opt.warpType=="translation": 40 | tx,ty = tf.unstack(p,axis=1) 41 | pMtrx = tf.transpose(tf.stack([[I,O,tx],[O,I,ty],[O,O,I]]),perm=[2,0,1]) 42 | if opt.warpType=="similarity": 43 | pc,ps,tx,ty = tf.unstack(p,axis=1) 44 | pMtrx = tf.transpose(tf.stack([[I+pc,-ps,tx],[ps,I+pc,ty],[O,O,I]]),perm=[2,0,1]) 45 | if opt.warpType=="affine": 46 | p1,p2,p3,p4,p5,p6,p7,p8 = tf.unstack(p,axis=1) 47 | pMtrx = tf.transpose(tf.stack([[I+p1,p2,p3],[p4,I+p5,p6],[O,O,I]]),perm=[2,0,1]) 48 | if opt.warpType=="homography": 49 | p1,p2,p3,p4,p5,p6,p7,p8 = tf.unstack(p,axis=1) 50 | pMtrx = tf.transpose(tf.stack([[I+p1,p2,p3],[p4,I+p5,p6],[p7,p8,I]]),perm=[2,0,1]) 51 | return pMtrx 52 | 53 | # convert warp matrix to parameters 54 | def mtrx2vec(opt,pMtrx): 55 | with tf.name_scope("mtrx2vec"): 56 | [row0,row1,row2] = tf.unstack(pMtrx,axis=1) 57 | [e00,e01,e02] = tf.unstack(row0,axis=1) 
58 | [e10,e11,e12] = tf.unstack(row1,axis=1) 59 | [e20,e21,e22] = tf.unstack(row2,axis=1) 60 | if opt.warpType=="translation": p = tf.stack([e02,e12],axis=1) 61 | if opt.warpType=="similarity": p = tf.stack([e00-1,e10,e02,e12],axis=1) 62 | if opt.warpType=="affine": p = tf.stack([e00-1,e01,e02,e10,e11-1,e12],axis=1) 63 | if opt.warpType=="homography": p = tf.stack([e00-1,e01,e02,e10,e11-1,e12,e20,e21],axis=1) 64 | return p 65 | 66 | # warp the image 67 | def transformImage(opt,image,pMtrx): 68 | with tf.name_scope("transformImage"): 69 | refMtrx = tf.tile(tf.expand_dims(opt.refMtrx,axis=0),[opt.batchSize,1,1]) 70 | transMtrx = tf.matmul(refMtrx,pMtrx) 71 | # warp the canonical coordinates 72 | X,Y = np.meshgrid(np.linspace(-1,1,opt.W),np.linspace(-1,1,opt.H)) 73 | X,Y = X.flatten(),Y.flatten() 74 | XYhom = np.stack([X,Y,np.ones_like(X)],axis=1).T 75 | XYhom = np.tile(XYhom,[opt.batchSize,1,1]).astype(np.float32) 76 | XYwarpHom = tf.matmul(transMtrx,XYhom) 77 | XwarpHom,YwarpHom,ZwarpHom = tf.unstack(XYwarpHom,axis=1) 78 | Xwarp = tf.reshape(XwarpHom/(ZwarpHom+1e-8),[opt.batchSize,opt.H,opt.W]) 79 | Ywarp = tf.reshape(YwarpHom/(ZwarpHom+1e-8),[opt.batchSize,opt.H,opt.W]) 80 | # get the integer sampling coordinates 81 | Xfloor,Xceil = tf.floor(Xwarp),tf.ceil(Xwarp) 82 | Yfloor,Yceil = tf.floor(Ywarp),tf.ceil(Ywarp) 83 | XfloorInt,XceilInt = tf.to_int32(Xfloor),tf.to_int32(Xceil) 84 | YfloorInt,YceilInt = tf.to_int32(Yfloor),tf.to_int32(Yceil) 85 | imageIdx = np.tile(np.arange(opt.batchSize).reshape([opt.batchSize,1,1]),[1,opt.H,opt.W]) 86 | imageVec = tf.reshape(image,[-1,int(image.shape[-1])]) 87 | imageVecOut = tf.concat([imageVec,tf.zeros([1,int(image.shape[-1])])],axis=0) 88 | idxUL = (imageIdx*opt.H+YfloorInt)*opt.W+XfloorInt 89 | idxUR = (imageIdx*opt.H+YfloorInt)*opt.W+XceilInt 90 | idxBL = (imageIdx*opt.H+YceilInt)*opt.W+XfloorInt 91 | idxBR = (imageIdx*opt.H+YceilInt)*opt.W+XceilInt 92 | idxOutside = 
tf.fill([opt.batchSize,opt.H,opt.W],opt.batchSize*opt.H*opt.W) 93 | def insideImage(Xint,Yint): 94 | return (Xint>=0)&(Xint=0)&(Yint

12 | 13 | We provide TensorFlow code for the following experiments: 14 | - MNIST classification 15 | - traffic sign classification 16 | 17 | **[NEW!]** The PyTorch implementation of the MNIST experiment is now up! 18 | 19 | -------------------------------------- 20 | 21 | ## TensorFlow 22 | 23 | ### Prerequisites 24 | This code is developed with Python3 (`python3`) but it is also compatible with Python2.7 (`python`). TensorFlow r1.0+ is required. The dependencies can be installed by running 25 | ``` 26 | pip3 install --upgrade numpy scipy termcolor matplotlib tensorflow-gpu 27 | ``` 28 | If you're using Python2.7, use `pip2` instead; if you don't have sudo access, add the `--user` flag. 29 | 30 | ### Running the code 31 | The training code can be executed via the command 32 | ``` 33 | python3 train.py NET_TYPE [(options)] 34 | ``` 35 | `NET_TYPE` should be one of the following: 36 | 1. `CNN` - standard convolutional neural network 37 | 2. `STN` - Spatial Transformer Network (STN) 38 | 3. `IC-STN` - Inverse Compositional Spatial Transformer Network (IC-STN) 39 | 40 | The list of optional arguments can be found by executing `python3 train.py --help`. 41 | The default training settings in this released code are slightly different from those in the paper; they are more stable and optimize the networks better. 42 | 43 | When the code is run for the first time, the datasets will be automatically downloaded and preprocessed. 44 | The checkpoints are saved in the automatically created directory `models_GROUP`; summaries are saved in `summary_GROUP` (where `GROUP` is the value of the `--group` option, `0` by default). 45 | 46 | ### Visualizing the results 47 | We've included code to visualize the training over TensorBoard. To execute, run 48 | ``` 49 | tensorboard --logdir=summary_GROUP --port=6006 50 | ``` 51 | 52 | We provide three types of data visualization: 53 | 1. **SCALARS**: training/test error over iterations 54 | 2. **IMAGES**: alignment results and mean/variance appearances 55 | 3.
**GRAPH**: network architecture 56 | 57 | -------------------------------------- 58 | 59 | ## PyTorch 60 | 61 | The PyTorch version of the code is still under active development. Its training speed is currently slower than that of the TensorFlow version. Suggestions for improvements are welcome! :) 62 | 63 | ### Prerequisites 64 | This code is developed with Python3 (`python3`). It has not been tested with Python2.7 yet. PyTorch 0.2.0+ is required. Please see http://pytorch.org/ for installation instructions. 65 | Visdom is also required; it can be installed by running 66 | ``` 67 | pip3 install --upgrade visdom 68 | ``` 69 | If you don't have sudo access, add the `--user` flag. 70 | 71 | ### Running the code 72 | First, start a Visdom server by running 73 | ``` 74 | python3 -m visdom.server -port=7000 75 | ``` 76 | The training code can be executed via the command (using the same port number) 77 | ``` 78 | python3 train.py NET_TYPE --port=7000 [(options)] 79 | ``` 80 | `NET_TYPE` should be one of the following: 81 | 1. `CNN` - standard convolutional neural network 82 | 2. `STN` - Spatial Transformer Network (STN) 83 | 3. `IC-STN` - Inverse Compositional Spatial Transformer Network (IC-STN) 84 | 85 | The list of optional arguments can be found by executing `python3 train.py --help`. 86 | The default training settings in this released code are slightly different from those in the paper; they are more stable and optimize the networks better. 87 | 88 | When the code is run for the first time, the datasets will be automatically downloaded and preprocessed. 89 | The checkpoints are saved in the automatically created directory `model_GROUP`; summaries are saved in `summary_GROUP`. 90 | 91 | ### Visualizing the results 92 | We provide three types of data visualization on Visdom: 93 | 1. Training/test error over iterations 94 | 2.
Alignment results and mean/variance appearances 95 | 96 | -------------------------------------- 97 | 98 | If you find our code useful for your research, please cite 99 | ``` 100 | @inproceedings{lin2017inverse, 101 | title={Inverse Compositional Spatial Transformer Networks}, 102 | author={Lin, Chen-Hsuan and Lucey, Simon}, 103 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition ({CVPR})}, 104 | year={2017} 105 | } 106 | ``` 107 | 108 | Please contact me (chlin@cmu.edu) if you have any questions! 109 | 110 | 111 | -------------------------------------------------------------------------------- /traffic-sign-tensorflow/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg,scipy.misc 3 | import os,time 4 | import tensorflow as tf 5 | import matplotlib.pyplot as plt 6 | import csv 7 | 8 | import warp 9 | 10 | # load GTSRB data 11 | def loadGTSRB(opt,fname): 12 | if not os.path.exists(fname): 13 | # download and preprocess GTSRB dataset 14 | os.makedirs(os.path.dirname(fname)) 15 | os.system("wget -O data/GTSRB_Final_Training_Images.zip http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Training_Images.zip") 16 | os.system("wget -O data/GTSRB_Final_Test_Images.zip http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Test_Images.zip") 17 | os.system("wget -O data/GTSRB_Final_Test_GT.zip http://benchmark.ini.rub.de/Dataset/GTSRB_Final_Test_GT.zip") 18 | os.system("cd data && unzip GTSRB_Final_Training_Images.zip") 19 | os.system("cd data && unzip GTSRB_Final_Test_Images.zip") 20 | os.system("cd data && unzip GTSRB_Final_Test_GT.zip") 21 | # training data 22 | print("preparing training data...") 23 | images,bboxes,labels = [],[],[] 24 | for c in range(43): 25 | prefix = "data/GTSRB/Final_Training/Images/{0:05d}".format(c) 26 | with open("{0}/GT-{1:05d}.csv".format(prefix,c)) as file: 27 | reader = csv.reader(file,delimiter=";") 28 | next(reader) 29 | for line in reader: 30 | 
img = plt.imread(prefix+"/"+line[0]) 31 | rawH,rawW = img.shape[0],img.shape[1] 32 | scaleH,scaleW = float(opt.fullH)/rawH,float(opt.fullW)/rawW 33 | imgResize = scipy.misc.imresize(img,(opt.fullH,opt.fullW,3)) 34 | images.append(imgResize) 35 | bboxes.append([float(line[3])*scaleW,float(line[4])*scaleH, 36 | float(line[5])*scaleW,float(line[6])*scaleH]) 37 | labels.append(int(line[7])) 38 | trainData = { 39 | "image": np.array(images), 40 | "bbox": np.array(bboxes), 41 | "label": np.array(labels) 42 | } 43 | # test data 44 | print("preparing test data...") 45 | images,bboxes,labels = [],[],[] 46 | prefix = "data/GTSRB/Final_Test/Images/" 47 | with open("data/GT-final_test.csv") as file: 48 | reader = csv.reader(file,delimiter=";") 49 | next(reader) 50 | for line in reader: 51 | img = plt.imread(prefix+"/"+line[0]) 52 | rawH,rawW = img.shape[0],img.shape[1] 53 | scaleH,scaleW = float(opt.fullH)/rawH,float(opt.fullW)/rawW 54 | imgResize = scipy.misc.imresize(img,(opt.fullH,opt.fullW,3)) 55 | images.append(imgResize) 56 | bboxes.append([float(line[3])*scaleW,float(line[4])*scaleH, 57 | float(line[5])*scaleW,float(line[6])*scaleH]) 58 | labels.append(int(line[7])) 59 | testData = { 60 | "image": np.array(images), 61 | "bbox": np.array(bboxes), 62 | "label": np.array(labels) 63 | } 64 | np.savez(fname,train=trainData,test=testData) 65 | os.system("rm -rf data/*.zip") 66 | GTSRB = np.load(fname) 67 | trainData = GTSRB["train"].item() 68 | testData = GTSRB["test"].item() 69 | return trainData,testData 70 | 71 | # generate training batch 72 | def genPerturbations(opt): 73 | with tf.name_scope("genPerturbations"): 74 | X = np.tile(opt.canon4pts[:,0],[opt.batchSize,1]) 75 | Y = np.tile(opt.canon4pts[:,1],[opt.batchSize,1]) 76 | dX = tf.random_normal([opt.batchSize,4])*opt.pertScale \ 77 | +tf.random_normal([opt.batchSize,1])*opt.transScale 78 | dY = tf.random_normal([opt.batchSize,4])*opt.pertScale \ 79 | +tf.random_normal([opt.batchSize,1])*opt.transScale 80 | O = 
np.zeros([opt.batchSize,4],dtype=np.float32) 81 | I = np.ones([opt.batchSize,4],dtype=np.float32) 82 | # fit warp parameters to generated displacements 83 | if opt.warpType=="homography": 84 | A = tf.concat([tf.stack([X,Y,I,O,O,O,-X*(X+dX),-Y*(X+dX)],axis=-1), 85 | tf.stack([O,O,O,X,Y,I,-X*(Y+dY),-Y*(Y+dY)],axis=-1)],1) 86 | b = tf.expand_dims(tf.concat([X+dX,Y+dY],1),-1) 87 | pPert = tf.matrix_solve(A,b)[:,:,0] 88 | pPert -= tf.to_float([[1,0,0,0,1,0,0,0]]) 89 | else: 90 | if opt.warpType=="translation": 91 | J = np.concatenate([np.stack([I,O],axis=-1), 92 | np.stack([O,I],axis=-1)],axis=1) 93 | if opt.warpType=="similarity": 94 | J = np.concatenate([np.stack([X,Y,I,O],axis=-1), 95 | np.stack([-Y,X,O,I],axis=-1)],axis=1) 96 | if opt.warpType=="affine": 97 | J = np.concatenate([np.stack([X,Y,I,O,O,O],axis=-1), 98 | np.stack([O,O,O,X,Y,I],axis=-1)],axis=1) 99 | dXY = tf.expand_dims(tf.concat([dX,dY],1),-1) 100 | pPert = tf.matrix_solve_ls(J,dXY)[:,:,0] 101 | return pPert 102 | 103 | # make training batch 104 | def makeBatch(opt,data,PH): 105 | N = len(data["image"]) 106 | randIdx = np.random.randint(N,size=[opt.batchSize]) 107 | # put data in placeholders 108 | [image,label] = PH 109 | batch = { 110 | image: data["image"][randIdx]/255.0, 111 | label: data["label"][randIdx], 112 | } 113 | return batch 114 | 115 | # evaluation on test set 116 | def evalTest(opt,sess,data,PH,prediction,imagesEval=[]): 117 | N = len(data["image"]) 118 | # put data in placeholders 119 | [image,label] = PH 120 | batchN = int(np.ceil(N/opt.batchSize)) 121 | warped = [{},{}] 122 | count = 0 123 | for b in range(batchN): 124 | # use some dummy data (0) as batch filler if necessary 125 | if b!=batchN-1: 126 | realIdx = np.arange(opt.batchSize*b,opt.batchSize*(b+1)) 127 | else: 128 | realIdx = np.arange(opt.batchSize*b,N) 129 | idx = np.zeros([opt.batchSize],dtype=int) 130 | idx[:len(realIdx)] = realIdx 131 | batch = { 132 | image: data["image"][idx]/255.0, 133 | label: data["label"][idx], 134 
| } 135 | evalList = sess.run([prediction]+imagesEval,feed_dict=batch) 136 | pred = evalList[0] 137 | count += pred[:len(realIdx)].sum() 138 | if len(imagesEval)>0: 139 | imgs = evalList[1:] 140 | for i in range(len(realIdx)): 141 | if data["label"][idx[i]] not in warped[0]: warped[0][data["label"][idx[i]]] = [] 142 | if data["label"][idx[i]] not in warped[1]: warped[1][data["label"][idx[i]]] = [] 143 | warped[0][data["label"][idx[i]]].append(imgs[0][i]) 144 | warped[1][data["label"][idx[i]]].append(imgs[1][i]) 145 | accuracy = float(count)/N 146 | if len(imagesEval)>0: 147 | mean = [np.array([np.mean(warped[0][l],axis=0) for l in warped[0]]), 148 | np.array([np.mean(warped[1][l],axis=0) for l in warped[1]])] 149 | var = [np.array([np.var(warped[0][l],axis=0) for l in warped[0]]), 150 | np.array([np.var(warped[1][l],axis=0) for l in warped[1]])] 151 | else: mean,var = None,None 152 | return accuracy,mean,var 153 | -------------------------------------------------------------------------------- /traffic-sign-tensorflow/graph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import time 4 | import data,warp,util 5 | 6 | # build classification network 7 | def fullCNN(opt,image): 8 | def conv2Layer(opt,feat,outDim): 9 | weight,bias = createVariable(opt,[7,7,int(feat.shape[-1]),outDim],stddev=opt.stdC) 10 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias 11 | return conv 12 | def linearLayer(opt,feat,outDim): 13 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=opt.stdC) 14 | fc = tf.matmul(feat,weight)+bias 15 | return fc 16 | with tf.variable_scope("classifier"): 17 | feat = image 18 | with tf.variable_scope("conv1"): 19 | feat = conv2Layer(opt,feat,6) 20 | feat = tf.nn.relu(feat) 21 | with tf.variable_scope("conv2"): 22 | feat = conv2Layer(opt,feat,12) 23 | feat = tf.nn.relu(feat) 24 | feat = 
tf.nn.max_pool(feat,ksize=[1,2,2,1],strides=[1,2,2,1],padding="VALID") 25 | with tf.variable_scope("conv3"): 26 | feat = conv2Layer(opt,feat,24) 27 | feat = tf.nn.relu(feat) 28 | feat = tf.reshape(feat,[opt.batchSize,-1]) 29 | with tf.variable_scope("fc4"): 30 | feat = linearLayer(opt,feat,200) 31 | feat = tf.nn.relu(feat) 32 | with tf.variable_scope("fc5"): 33 | feat = linearLayer(opt,feat,opt.labelN) 34 | output = feat 35 | return output 36 | 37 | # build classification network 38 | def CNN(opt,image): 39 | def conv2Layer(opt,feat,outDim): 40 | weight,bias = createVariable(opt,[7,7,int(feat.shape[-1]),outDim],stddev=opt.stdC) 41 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias 42 | return conv 43 | def linearLayer(opt,feat,outDim): 44 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=opt.stdC) 45 | fc = tf.matmul(feat,weight)+bias 46 | return fc 47 | with tf.variable_scope("classifier"): 48 | feat = image 49 | with tf.variable_scope("conv1"): 50 | feat = conv2Layer(opt,feat,6) 51 | feat = tf.nn.relu(feat) 52 | with tf.variable_scope("conv2"): 53 | feat = conv2Layer(opt,feat,12) 54 | feat = tf.nn.relu(feat) 55 | feat = tf.nn.max_pool(feat,ksize=[1,2,2,1],strides=[1,2,2,1],padding="VALID") 56 | feat = tf.reshape(feat,[opt.batchSize,-1]) 57 | with tf.variable_scope("fc3"): 58 | feat = linearLayer(opt,feat,opt.labelN) 59 | output = feat 60 | return output 61 | 62 | # build Spatial Transformer Network 63 | def STN(opt,image): 64 | def conv2Layer(opt,feat,outDim): 65 | weight,bias = createVariable(opt,[7,7,int(feat.shape[-1]),outDim],stddev=opt.stdGP) 66 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias 67 | return conv 68 | def linearLayer(opt,feat,outDim): 69 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=opt.stdGP) 70 | fc = tf.matmul(feat,weight)+bias 71 | return fc 72 | imageWarpAll = [image] 73 | with tf.variable_scope("geometric"): 74 | feat = image 75 | with 
tf.variable_scope("conv1"): 76 | feat = conv2Layer(opt,feat,6) 77 | feat = tf.nn.relu(feat) 78 | with tf.variable_scope("conv2"): 79 | feat = conv2Layer(opt,feat,24) 80 | feat = tf.nn.relu(feat) 81 | feat = tf.reshape(feat,[opt.batchSize,-1]) 82 | with tf.variable_scope("fc3"): 83 | feat = linearLayer(opt,feat,opt.warpDim) 84 | p = feat 85 | pMtrx = warp.vec2mtrx(opt,p) 86 | imageWarp = warp.transformImage(opt,image,pMtrx) 87 | imageWarpAll.append(imageWarp) 88 | return imageWarpAll 89 | 90 | # build Inverse Compositional STN 91 | def ICSTN(opt,imageFull,p): 92 | def conv2Layer(opt,feat,outDim): 93 | weight,bias = createVariable(opt,[7,7,int(feat.shape[-1]),outDim],stddev=opt.stdGP) 94 | conv = tf.nn.conv2d(feat,weight,strides=[1,1,1,1],padding="VALID")+bias 95 | return conv 96 | def linearLayer(opt,feat,outDim): 97 | weight,bias = createVariable(opt,[int(feat.shape[-1]),outDim],stddev=opt.stdGP) 98 | fc = tf.matmul(feat,weight)+bias 99 | return fc 100 | imageWarpAll = [] 101 | for l in range(opt.warpN): 102 | with tf.variable_scope("geometric",reuse=l>0): 103 | pMtrx = warp.vec2mtrx(opt,p) 104 | imageWarp = warp.transformCropImage(opt,imageFull,pMtrx) 105 | imageWarpAll.append(imageWarp) 106 | feat = imageWarp 107 | with tf.variable_scope("conv1"): 108 | feat = conv2Layer(opt,feat,6) 109 | feat = tf.nn.relu(feat) 110 | with tf.variable_scope("conv2"): 111 | feat = conv2Layer(opt,feat,24) 112 | feat = tf.nn.relu(feat) 113 | feat = tf.reshape(feat,[opt.batchSize,-1]) 114 | with tf.variable_scope("fc3"): 115 | feat = linearLayer(opt,feat,opt.warpDim) 116 | dp = feat 117 | p = warp.compose(opt,p,dp) 118 | pMtrx = warp.vec2mtrx(opt,p) 119 | imageWarp = warp.transformCropImage(opt,imageFull,pMtrx) 120 | imageWarpAll.append(imageWarp) 121 | return imageWarpAll 122 | 123 | # auxiliary function for creating weight and bias 124 | def createVariable(opt,weightShape,biasShape=None,stddev=None): 125 | if biasShape is None: biasShape = [weightShape[-1]] 126 | weight = 
tf.get_variable("weight",shape=weightShape,dtype=tf.float32, 127 | initializer=tf.random_normal_initializer(stddev=stddev)) 128 | bias = tf.get_variable("bias",shape=biasShape,dtype=tf.float32, 129 | initializer=tf.random_normal_initializer(stddev=stddev)) 130 | return weight,bias 131 | -------------------------------------------------------------------------------- /traffic-sign-tensorflow/options.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import warp 4 | import util 5 | 6 | def set(training): 7 | 8 | # parse input arguments 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("netType", choices=["CNN","STN","IC-STN"], help="type of network") 11 | parser.add_argument("--group", default="0", help="name for group") 12 | parser.add_argument("--model", default="test", help="name for model instance") 13 | parser.add_argument("--size", default="36x36", help="image resolution") 14 | parser.add_argument("--sizeFull", default="50x50", help="full image resolution") 15 | parser.add_argument("--warpType", default="homography", help="type of warp function on images", 16 | choices=["translation","similarity","affine","homography"]) 17 | parser.add_argument("--warpN", type=int, default=4, help="number of recurrent transformations (for IC-STN)") 18 | parser.add_argument("--stdC", type=float, default=0.01, help="initialization stddev (classification network)") 19 | parser.add_argument("--stdGP", type=float, default=0.001, help="initialization stddev (geometric predictor)") 20 | parser.add_argument("--pertScale", type=float, default=0.25, help="initial perturbation scale") 21 | parser.add_argument("--transScale", type=float, default=0.25, help="initial translation scale") 22 | if training: # training 23 | parser.add_argument("--batchSize", type=int, default=100, help="batch size for SGD") 24 | parser.add_argument("--lrC", type=float, default=1e-2, help="learning rate (classification 
network)") 25 | parser.add_argument("--lrCdecay", type=float, default=0.1, help="learning rate decay (classification network)") 26 | parser.add_argument("--lrCstep", type=int, default=500000, help="learning rate decay step size (classification network)") 27 | parser.add_argument("--lrGP", type=float, default=None, help="learning rate (geometric predictor)") 28 | parser.add_argument("--lrGPdecay", type=float, default=0.1, help="learning rate decay (geometric predictor)") 29 | parser.add_argument("--lrGPstep", type=int, default=500000, help="learning rate decay step size (geometric predictor)") 30 | parser.add_argument("--fromIt", type=int, default=0, help="resume training from iteration number") 31 | parser.add_argument("--toIt", type=int, default=1000000,help="run training to iteration number") 32 | else: # evaluation 33 | parser.add_argument("--batchSize", type=int, default=1, help="batch size for evaluation") 34 | opt = parser.parse_args() 35 | 36 | if opt.lrGP is None: opt.lrGP = 0 if opt.netType=="CNN" else \ 37 | 1e-3 if opt.netType=="STN" else \ 38 | 3e-5 if opt.netType=="IC-STN" else None 39 | 40 | # --- below are automatically set --- 41 | opt.training = training 42 | opt.H,opt.W = [int(x) for x in opt.size.split("x")] 43 | opt.fullH,opt.fullW = [int(x) for x in opt.sizeFull.split("x")] 44 | opt.visBlockSize = int(np.floor(np.sqrt(opt.batchSize))) 45 | opt.warpDim = 2 if opt.warpType == "translation" else \ 46 | 4 if opt.warpType == "similarity" else \ 47 | 6 if opt.warpType == "affine" else \ 48 | 8 if opt.warpType == "homography" else None 49 | opt.labelN = 43 50 | opt.canon4pts = np.array([[-1,-1],[-1,1],[1,1],[1,-1]],dtype=np.float32) 51 | opt.image4pts = np.array([[0,0],[0,opt.H-1],[opt.W-1,opt.H-1],[opt.W-1,0]],dtype=np.float32) 52 | opt.bbox = [int(opt.fullW/2-opt.W/2),int(opt.fullH/2-opt.H/2),int(opt.fullW/2+opt.W/2),int(opt.fullH/2+opt.H/2)] 53 | opt.bbox4pts = np.array([[opt.bbox[0],opt.bbox[1]],[opt.bbox[0],opt.bbox[3]], 54 | 
[opt.bbox[2],opt.bbox[3]],[opt.bbox[2],opt.bbox[1]]],dtype=np.float32) 55 | opt.refMtrx = warp.fit(Xsrc=opt.canon4pts,Xdst=opt.image4pts) 56 | opt.bboxRefMtrx = warp.fit(Xsrc=opt.canon4pts,Xdst=opt.bbox4pts) 57 | if opt.netType=="STN": opt.warpN = 1 58 | 59 | print("({0}) {1}".format( 60 | util.toGreen("{0}".format(opt.group)), 61 | util.toGreen("{0}".format(opt.model)))) 62 | print("------------------------------------------") 63 | print("network type: {0}, recurrent warps: {1}".format( 64 | util.toYellow("{0}".format(opt.netType)), 65 | util.toYellow("{0}".format(opt.warpN if opt.netType=="IC-STN" else "X")))) 66 | print("batch size: {0}, image size: {1}x{2}".format( 67 | util.toYellow("{0}".format(opt.batchSize)), 68 | util.toYellow("{0}".format(opt.H)), 69 | util.toYellow("{0}".format(opt.W)))) 70 | print("warpScale: (pert) {0} (trans) {1}".format( 71 | util.toYellow("{0}".format(opt.pertScale)), 72 | util.toYellow("{0}".format(opt.transScale)))) 73 | if training: 74 | print("[geometric predictor] stddev={0}, lr={1}".format( 75 | util.toYellow("{0:.0e}".format(opt.stdGP)), 76 | util.toYellow("{0:.0e}".format(opt.lrGP)))) 77 | print("[classification network] stddev={0}, lr={1}".format( 78 | util.toYellow("{0:.0e}".format(opt.stdC)), 79 | util.toYellow("{0:.0e}".format(opt.lrC)))) 80 | print("------------------------------------------") 81 | if training: 82 | print(util.toMagenta("training model ({0}) {1}...".format(opt.group,opt.model))) 83 | 84 | return opt 85 | -------------------------------------------------------------------------------- /traffic-sign-tensorflow/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time,os,sys 3 | import argparse 4 | import util 5 | 6 | print(util.toYellow("=======================================================")) 7 | print(util.toYellow("train.py (training on MNIST)")) 8 | print(util.toYellow("=======================================================")) 
9 | 10 | import tensorflow as tf 11 | import data,graph,warp,util 12 | import options 13 | 14 | print(util.toMagenta("setting configurations...")) 15 | opt = options.set(training=True) 16 | 17 | # create directories for model output 18 | util.mkdir("models_{0}".format(opt.group)) 19 | 20 | print(util.toMagenta("building graph...")) 21 | tf.reset_default_graph() 22 | # build graph 23 | with tf.device("/gpu:0"): 24 | # ------ define input data ------ 25 | imageFull = tf.placeholder(tf.float32,shape=[opt.batchSize,opt.fullH,opt.fullW,3]) 26 | imageMean,imageVar = tf.nn.moments(imageFull,axes=[1,2],keep_dims=True) 27 | imageFullNormalize = (imageFull-imageMean)/tf.sqrt(imageVar) 28 | label = tf.placeholder(tf.int64,shape=[opt.batchSize]) 29 | PH = [imageFull,label] 30 | # ------ generate perturbation ------ 31 | pInit = data.genPerturbations(opt) 32 | pInitMtrx = warp.vec2mtrx(opt,pInit) 33 | # ------ build network ------ 34 | imagePert = warp.transformCropImage(opt,imageFullNormalize,pInitMtrx) 35 | imagePertRescale = imagePert*tf.sqrt(imageVar)+imageMean 36 | if opt.netType=="CNN": 37 | output = graph.fullCNN(opt,imagePert) 38 | elif opt.netType=="STN": 39 | imageWarpAll = graph.STN(opt,imagePert) 40 | imageWarp = imageWarpAll[-1] 41 | output = graph.CNN(opt,imageWarp) 42 | imageWarpRescale = imageWarp*tf.sqrt(imageVar)+imageMean 43 | elif opt.netType=="IC-STN": 44 | imageWarpAll = graph.ICSTN(opt,imageFullNormalize,pInit) 45 | imageWarp = imageWarpAll[-1] 46 | output = graph.CNN(opt,imageWarp) 47 | imageWarpRescale = imageWarp*tf.sqrt(imageVar)+imageMean 48 | softmax = tf.nn.softmax(output) 49 | labelOnehot = tf.one_hot(label,opt.labelN) 50 | prediction = tf.equal(tf.argmax(softmax,1),label) 51 | # ------ define loss ------ 52 | softmaxLoss = tf.nn.softmax_cross_entropy_with_logits(logits=output,labels=labelOnehot) 53 | loss = tf.reduce_mean(softmaxLoss) 54 | # ------ optimizer ------ 55 | lrGP_PH,lrC_PH = 
tf.placeholder(tf.float32,shape=[]),tf.placeholder(tf.float32,shape=[]) 56 | optim = util.setOptimizer(opt,loss,lrGP_PH,lrC_PH) 57 | # ------ generate summaries ------ 58 | summaryImageTrain = [] 59 | summaryImageTest = [] 60 | if opt.netType=="STN" or opt.netType=="IC-STN": 61 | for l in range(opt.warpN+1): 62 | summaryImageTrain.append(util.imageSummary(opt,imageWarpAll[l]*tf.sqrt(imageVar)+imageMean,"TRAIN_warp{0}".format(l),opt.H,opt.W)) 63 | summaryImageTest.append(util.imageSummary(opt,imageWarpAll[l]*tf.sqrt(imageVar)+imageMean,"TEST_warp{0}".format(l),opt.H,opt.W)) 64 | summaryImageTrain = tf.summary.merge(summaryImageTrain) 65 | summaryImageTest = tf.summary.merge(summaryImageTest) 66 | summaryLossTrain = tf.summary.scalar("TRAIN_loss",loss) 67 | testErrorPH = tf.placeholder(tf.float32,shape=[]) 68 | testImagePH = tf.placeholder(tf.float32,shape=[opt.labelN,opt.H,opt.W,3]) 69 | summaryErrorTest = tf.summary.scalar("TEST_error",testErrorPH) 70 | if opt.netType=="STN" or opt.netType=="IC-STN": 71 | summaryMeanTest0 = util.imageSummaryMeanVar(opt,testImagePH,"TEST_mean_init",opt.H,opt.W) 72 | summaryMeanTest1 = util.imageSummaryMeanVar(opt,testImagePH,"TEST_mean_warped",opt.H,opt.W) 73 | summaryVarTest0 = util.imageSummaryMeanVar(opt,testImagePH,"TEST_var_init",opt.H,opt.W) 74 | summaryVarTest1 = util.imageSummaryMeanVar(opt,testImagePH,"TEST_var_warped",opt.H,opt.W) 75 | 76 | # load data 77 | print(util.toMagenta("loading GTSRB dataset...")) 78 | trainData,testData = data.loadGTSRB(opt,"data/GTSRB.npz") 79 | 80 | # prepare model saver/summary writer 81 | saver = tf.train.Saver(max_to_keep=20) 82 | summaryWriter = tf.summary.FileWriter("summary_{0}/{1}".format(opt.group,opt.model)) 83 | 84 | print(util.toYellow("======= TRAINING START =======")) 85 | timeStart = time.time() 86 | # start session 87 | tfConfig = tf.ConfigProto(allow_soft_placement=True) 88 | tfConfig.gpu_options.allow_growth = True 89 | with tf.Session(config=tfConfig) as sess: 90 | 
sess.run(tf.global_variables_initializer()) 91 | summaryWriter.add_graph(sess.graph) 92 | if opt.fromIt!=0: 93 | util.restoreModel(opt,sess,saver,opt.fromIt) 94 | print(util.toMagenta("resuming from iteration {0}...".format(opt.fromIt))) 95 | print(util.toMagenta("start training...")) 96 | 97 | # training loop 98 | for i in range(opt.fromIt,opt.toIt): 99 | lrGP = opt.lrGP*opt.lrGPdecay**(i//opt.lrGPstep) 100 | lrC = opt.lrC*opt.lrCdecay**(i//opt.lrCstep) 101 | # make training batch 102 | batch = data.makeBatch(opt,trainData,PH) 103 | batch[lrGP_PH] = lrGP 104 | batch[lrC_PH] = lrC 105 | # run one step 106 | _,l = sess.run([optim,loss],feed_dict=batch) 107 | if (i+1)%100==0: 108 | print("it. {0}/{1} lr={3}(GP),{4}(C), loss={5}, time={2}" 109 | .format(util.toCyan("{0}".format(i+1)), 110 | opt.toIt, 111 | util.toGreen("{0:.2f}".format(time.time()-timeStart)), 112 | util.toYellow("{0:.0e}".format(lrGP)), 113 | util.toYellow("{0:.0e}".format(lrC)), 114 | util.toRed("{0:.4f}".format(l)))) 115 | if (i+1)%100==0: 116 | summaryWriter.add_summary(sess.run(summaryLossTrain,feed_dict=batch),i+1) 117 | if (i+1)%500==0 and (opt.netType=="STN" or opt.netType=="IC-STN"): 118 | summaryWriter.add_summary(sess.run(summaryImageTrain,feed_dict=batch),i+1) 119 | summaryWriter.add_summary(sess.run(summaryImageTest,feed_dict=batch),i+1) 120 | if (i+1)%1000==0: 121 | # evaluate on test set 122 | if opt.netType=="STN" or opt.netType=="IC-STN": 123 | testAcc,testMean,testVar = data.evalTest(opt,sess,testData,PH,prediction,imagesEval=[imagePert,imageWarp]) 124 | else: 125 | testAcc,_,_ = data.evalTest(opt,sess,testData,PH,prediction) 126 | testError = (1-testAcc)*100 127 | summaryWriter.add_summary(sess.run(summaryErrorTest,feed_dict={testErrorPH:testError}),i+1) 128 | if opt.netType=="STN" or opt.netType=="IC-STN": 129 | summaryWriter.add_summary(sess.run(summaryMeanTest0,feed_dict={testImagePH:testMean[0]}),i+1) 130 | 
summaryWriter.add_summary(sess.run(summaryMeanTest1,feed_dict={testImagePH:testMean[1]}),i+1) 131 | summaryWriter.add_summary(sess.run(summaryVarTest0,feed_dict={testImagePH:testVar[0]}),i+1) 132 | summaryWriter.add_summary(sess.run(summaryVarTest1,feed_dict={testImagePH:testVar[1]}),i+1) 133 | if (i+1)%10000==0: 134 | util.saveModel(opt,sess,saver,i+1) 135 | print(util.toGreen("model saved: {0}/{1}, it.{2}".format(opt.group,opt.model,i+1))) 136 | 137 | print(util.toYellow("======= TRAINING DONE =======")) 138 | -------------------------------------------------------------------------------- /traffic-sign-tensorflow/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import tensorflow as tf 4 | import os 5 | import termcolor 6 | 7 | def mkdir(path): 8 | if not os.path.exists(path): os.mkdir(path) 9 | def imread(fname): 10 | return scipy.misc.imread(fname)/255.0 11 | def imsave(fname,array): 12 | scipy.misc.toimage(array,cmin=0.0,cmax=1.0).save(fname) 13 | 14 | # convert to colored strings 15 | def toRed(content): return termcolor.colored(content,"red",attrs=["bold"]) 16 | def toGreen(content): return termcolor.colored(content,"green",attrs=["bold"]) 17 | def toBlue(content): return termcolor.colored(content,"blue",attrs=["bold"]) 18 | def toCyan(content): return termcolor.colored(content,"cyan",attrs=["bold"]) 19 | def toYellow(content): return termcolor.colored(content,"yellow",attrs=["bold"]) 20 | def toMagenta(content): return termcolor.colored(content,"magenta",attrs=["bold"]) 21 | 22 | # make image summary from image batch 23 | def imageSummary(opt,image,tag,H,W): 24 | blockSize = opt.visBlockSize 25 | imageOne = tf.batch_to_space(image[:blockSize**2],crops=[[0,0],[0,0]],block_size=blockSize) 26 | imagePermute = tf.reshape(imageOne,[H,blockSize,W,blockSize,-1]) 27 | imageTransp = tf.transpose(imagePermute,[1,0,3,2,4]) 28 | imageBlocks = 
tf.reshape(imageTransp,[1,H*blockSize,W*blockSize,-1]) 29 | imageBlocks = tf.cast(imageBlocks*255,tf.uint8) 30 | summary = tf.summary.image(tag,imageBlocks) 31 | return summary 32 | 33 | # make image summary from image batch (mean/variance) 34 | def imageSummaryMeanVar(opt,image,tag,H,W): 35 | image = tf.concat([image,np.zeros([2,H,W,3])],axis=0) 36 | imageOne = tf.batch_to_space_nd(image,crops=[[0,0],[0,0]],block_shape=[5,9]) 37 | imagePermute = tf.reshape(imageOne,[H,5,W,9,-1]) 38 | imageTransp = tf.transpose(imagePermute,[1,0,3,2,4]) 39 | imageBlocks = tf.reshape(imageTransp,[1,H*5,W*9,-1]) 40 | # imageBlocks = tf.cast(imageBlocks*255,tf.uint8) 41 | summary = tf.summary.image(tag,imageBlocks) 42 | return summary 43 | 44 | # set optimizer for different learning rates 45 | def setOptimizer(opt,loss,lrGP,lrC): 46 | varsGP = [v for v in tf.global_variables() if "geometric" in v.name] 47 | varsC = [v for v in tf.global_variables() if "classifier" in v.name] 48 | gradC = tf.gradients(loss,varsC) 49 | optimC = tf.train.GradientDescentOptimizer(lrC).apply_gradients(zip(gradC,varsC)) 50 | if len(varsGP)>0: 51 | gradGP = tf.gradients(loss,varsGP) 52 | optimGP = tf.train.GradientDescentOptimizer(lrGP).apply_gradients(zip(gradGP,varsGP)) 53 | optim = tf.group(optimC,optimGP) 54 | else: 55 | optim = optimC 56 | return optim 57 | 58 | # restore model 59 | def restoreModel(opt,sess,saver,it): 60 | saver.restore(sess,"models_{0}/{1}_it{2}.ckpt".format(opt.group,opt.model,it,opt.warpN)) 61 | # save model 62 | def saveModel(opt,sess,saver,it): 63 | saver.save(sess,"models_{0}/{1}_it{2}.ckpt".format(opt.group,opt.model,it,opt.warpN)) 64 | 65 | -------------------------------------------------------------------------------- /traffic-sign-tensorflow/warp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg 3 | import tensorflow as tf 4 | 5 | # fit (affine) warp between two sets of points 6 | def 
fit(Xsrc,Xdst): 7 | ptsN = len(Xsrc) 8 | X,Y,U,V,O,I = Xsrc[:,0],Xsrc[:,1],Xdst[:,0],Xdst[:,1],np.zeros([ptsN]),np.ones([ptsN]) 9 | A = np.concatenate((np.stack([X,Y,I,O,O,O],axis=1), 10 | np.stack([O,O,O,X,Y,I],axis=1)),axis=0) 11 | b = np.concatenate((U,V),axis=0) 12 | p1,p2,p3,p4,p5,p6 = scipy.linalg.lstsq(A,b)[0].squeeze() 13 | pMtrx = np.array([[p1,p2,p3],[p4,p5,p6],[0,0,1]],dtype=np.float32) 14 | return pMtrx 15 | 16 | # compute composition of warp parameters 17 | def compose(opt,p,dp): 18 | with tf.name_scope("compose"): 19 | pMtrx = vec2mtrx(opt,p) 20 | dpMtrx = vec2mtrx(opt,dp) 21 | pMtrxNew = tf.matmul(dpMtrx,pMtrx) 22 | pMtrxNew /= pMtrxNew[:,2:3,2:3] 23 | pNew = mtrx2vec(opt,pMtrxNew) 24 | return pNew 25 | 26 | # compute inverse of warp parameters 27 | def inverse(opt,p): 28 | with tf.name_scope("inverse"): 29 | pMtrx = vec2mtrx(opt,p) 30 | pInvMtrx = tf.matrix_inverse(pMtrx) 31 | pInv = mtrx2vec(opt,pInvMtrx) 32 | return pInv 33 | 34 | # convert warp parameters to matrix 35 | def vec2mtrx(opt,p): 36 | with tf.name_scope("vec2mtrx"): 37 | O = tf.zeros([opt.batchSize]) 38 | I = tf.ones([opt.batchSize]) 39 | if opt.warpType=="translation": 40 | tx,ty = tf.unstack(p,axis=1) 41 | pMtrx = tf.transpose(tf.stack([[I,O,tx],[O,I,ty],[O,O,I]]),perm=[2,0,1]) 42 | if opt.warpType=="similarity": 43 | pc,ps,tx,ty = tf.unstack(p,axis=1) 44 | pMtrx = tf.transpose(tf.stack([[I+pc,-ps,tx],[ps,I+pc,ty],[O,O,I]]),perm=[2,0,1]) 45 | if opt.warpType=="affine": 46 | p1,p2,p3,p4,p5,p6,p7,p8 = tf.unstack(p,axis=1) 47 | pMtrx = tf.transpose(tf.stack([[I+p1,p2,p3],[p4,I+p5,p6],[O,O,I]]),perm=[2,0,1]) 48 | if opt.warpType=="homography": 49 | p1,p2,p3,p4,p5,p6,p7,p8 = tf.unstack(p,axis=1) 50 | pMtrx = tf.transpose(tf.stack([[I+p1,p2,p3],[p4,I+p5,p6],[p7,p8,I]]),perm=[2,0,1]) 51 | return pMtrx 52 | 53 | # convert warp matrix to parameters 54 | def mtrx2vec(opt,pMtrx): 55 | with tf.name_scope("mtrx2vec"): 56 | [row0,row1,row2] = tf.unstack(pMtrx,axis=1) 57 | [e00,e01,e02] = 
tf.unstack(row0,axis=1) 58 | [e10,e11,e12] = tf.unstack(row1,axis=1) 59 | [e20,e21,e22] = tf.unstack(row2,axis=1) 60 | if opt.warpType=="translation": p = tf.stack([e02,e12],axis=1) 61 | if opt.warpType=="similarity": p = tf.stack([e00-1,e10,e02,e12],axis=1) 62 | if opt.warpType=="affine": p = tf.stack([e00-1,e01,e02,e10,e11-1,e12],axis=1) 63 | if opt.warpType=="homography": p = tf.stack([e00-1,e01,e02,e10,e11-1,e12,e20,e21],axis=1) 64 | return p 65 | 66 | # warp the image 67 | def transformImage(opt,image,pMtrx): 68 | with tf.name_scope("transformImage"): 69 | refMtrx = tf.tile(tf.expand_dims(opt.refMtrx,axis=0),[opt.batchSize,1,1]) 70 | transMtrx = tf.matmul(refMtrx,pMtrx) 71 | # warp the canonical coordinates 72 | X,Y = np.meshgrid(np.linspace(-1,1,opt.W),np.linspace(-1,1,opt.H)) 73 | X,Y = X.flatten(),Y.flatten() 74 | XYhom = np.stack([X,Y,np.ones_like(X)],axis=1).T 75 | XYhom = np.tile(XYhom,[opt.batchSize,1,1]).astype(np.float32) 76 | XYwarpHom = tf.matmul(transMtrx,XYhom) 77 | XwarpHom,YwarpHom,ZwarpHom = tf.unstack(XYwarpHom,axis=1) 78 | Xwarp = tf.reshape(XwarpHom/(ZwarpHom+1e-8),[opt.batchSize,opt.H,opt.W]) 79 | Ywarp = tf.reshape(YwarpHom/(ZwarpHom+1e-8),[opt.batchSize,opt.H,opt.W]) 80 | # get the integer sampling coordinates 81 | Xfloor,Xceil = tf.floor(Xwarp),tf.ceil(Xwarp) 82 | Yfloor,Yceil = tf.floor(Ywarp),tf.ceil(Ywarp) 83 | XfloorInt,XceilInt = tf.to_int32(Xfloor),tf.to_int32(Xceil) 84 | YfloorInt,YceilInt = tf.to_int32(Yfloor),tf.to_int32(Yceil) 85 | imageIdx = np.tile(np.arange(opt.batchSize).reshape([opt.batchSize,1,1]),[1,opt.H,opt.W]) 86 | imageVec = tf.reshape(image,[-1,int(image.shape[-1])]) 87 | imageVecOut = tf.concat([imageVec,tf.zeros([1,int(image.shape[-1])])],axis=0) 88 | idxUL = (imageIdx*opt.H+YfloorInt)*opt.W+XfloorInt 89 | idxUR = (imageIdx*opt.H+YfloorInt)*opt.W+XceilInt 90 | idxBL = (imageIdx*opt.H+YceilInt)*opt.W+XfloorInt 91 | idxBR = (imageIdx*opt.H+YceilInt)*opt.W+XceilInt 92 | idxOutside = 
tf.fill([opt.batchSize,opt.H,opt.W],opt.batchSize*opt.H*opt.W) 93 | def insideImage(Xint,Yint): 94 | return (Xint>=0)&(Xint=0)&(Yint=0)&(Xint=0)&(Yint