├── .gitignore ├── README.md ├── __init__.py ├── doxy_mypycaffe.config ├── matconvnet_caffe.py ├── my_examples.py ├── my_exp_config.py ├── my_exp_config_old.py ├── my_netspec.py ├── my_pycaffe.py ├── my_pycaffe_io.py ├── my_pycaffe_tests.py ├── my_pycaffe_utils.py ├── other_utils.py ├── pycaffe_config.py ├── quick_things.py ├── rot_utils.py ├── test_bench ├── __init__.py ├── test_mysolver.py └── test_sql.py └── vis_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swo 3 | *.swp 4 | debug-files/* 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository provides utilities for conveniently defining and running deep learning experiments using Caffe. Functions in this repository are especially useful for performing parameter sweeps, visualizing and recording results or debugging the training process of nets. 2 | 3 | Note: This README is being constantly updated and currently covers only a few functions provided as part of pycaffe-utils. 4 | 5 | Dependencies 6 | ------------- 7 | This is not an exhaustive list. 8 | ``` 9 | git clone https://github.com/pulkitag/pyhelper_fns.git 10 | sudo apt-get install liblmdb-dev 11 | sudo pip install lmdb 12 | ``` 13 | 14 | 15 | Setting up a Caffe Experiment 16 | ---------------------------------- 17 | 18 | There are three main classes of parameters needed to define an experiment: 19 | - What data is to be used (i.e. images/labels) (called dPrms or data parameters) 20 | - What should be the structure of the network (called nPrms or network parameters) 21 | - How should the learning proceed (called sPrms or solver parameters) 22 | 23 | Details of different experiments (specified by different parameters) are stored in SQL database. 
The SQL database stores an automatically generated hash string for each parameter setting; this hash is used to automatically generate and name the files that are used to run 24 | the experiment. 25 | 26 | #### Specifying dPrms 27 | type: EasyDict 28 | 29 | The minimal definition of dPrms is given below: 30 |

 31 | from easydict import EasyDict as edict
 32 | dPrms     =   edict()
 33 | dPrms['expStr'] = 'demo-experiment' #The name of the experiment
 34 | dPrms.paths     = edict() #The paths that will be used
 35 | dPrms.paths.exp    = edict() #Paths for storing experiment files
 36 | dPrms.paths.exp.dr = '/directory/for/storing/experiment/files'
 37 | dPrms.paths.exp.snapshot    = edict() #Snapshot paths live under paths.exp (read as paths.exp.snapshot.dr)
 38 | dPrms.paths.exp.snapshot.dr = '/directory/for/storing/snapshots'
 39 | 
40 | 41 | #### Specifying nPrms 42 | type: EasyDict 43 | 44 | The minimal definition is defined in module my_exp_config in function get_default_net_prms. 45 | 46 | Custom nPrms should be defined as following: 47 | (To be updated soon). 48 | 49 | #### Specifying sPrms 50 | type: EasyDict 51 | 52 | To be updated soon. 53 | 54 | 55 | 56 | 57 | 58 | 59 | Debugging a Caffe Experiment 60 | ------------------------------------------------------------------------- 61 | 62 | If a deep network is not training, it is instructive to look at how the parameters, gradients and feature values of different layers change with iterations. It is easy to log, 63 | - The parameter values 64 | - The parameter update values (i.e. gradients) 65 | - The feature values 66 | 67 | of all the blobs in the net using the following code snippet 68 | 69 | ```python 70 | import my_pycaffe as mp 71 | #Define the solver using caffe style solver prototxt 72 | sol = mp.MySolver.from_file(solver_prototxt) 73 | #Number of iterations after which parameters should be saved to log file 74 | saveIter = 1000 75 | maxIter = 200000 76 | #Name of the log file 77 | logFile = 'log.pkl' 78 | for r in range(0, maxIter, saveIter): 79 | #Train for saveIter iterations 80 | sol.solve(saveIter) 81 | #Save the log file 82 | sol.dump_to_file(logFile) 83 | ``` 84 | The logged values can be easily plotted using, `sol.plot()` 85 | 86 | To restore from a solver state 87 | ```python 88 | #fName : the .solverstate file name 89 | #restoreIter: the iteration from which the log should be restored. 90 | sol.restore(fName, restoreIter) 91 | ``` 92 | 93 | 94 | Creating a Siamese prototxt file for Caffe 95 | ------------------------- 96 |

 97 | import my_pycaffe_utils as mpu
 98 | fName = 'deploy.prototxt'
 99 | pDef  = mpu.ProtoDef(fName)
100 | #Make a siamese protodef by duplicating layers between 'conv1' and 'conv5', leave
101 | #other layers as such.
102 | siameseDef = pDef.get_siamese('conv1', 'conv5')
103 | #Save the siamese file
104 | siameseDef.write('siamese.prototxt')
105 | 
106 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from pyhelper_fns import my_sqlite 2 | -------------------------------------------------------------------------------- /matconvnet_caffe.py: -------------------------------------------------------------------------------- 1 | import h5py as h5 2 | import numpy as np 3 | import other_utils as ou 4 | import my_pycaffe_utils as mpu 5 | import caffe 6 | import collections as co 7 | import pickle 8 | 9 | class MatConvNetModel: 10 | def __init__(self, inFile): 11 | self.dat_ = h5.File(inFile, 'r') 12 | 13 | def ref_to_str(self, ref): 14 | secRefs = self.dat_['#refs#'][ref][:] 15 | ch = [] 16 | #print secRefs 17 | for sec in secRefs: 18 | if not sec == 0: 19 | #If the reference is not empty 20 | ch.append(ou.ints_to_str(self.dat_['#refs#'][sec[0]])) 21 | return ch 22 | 23 | def make_caffe_layer(self, lNum): 24 | #Get the name 25 | nameRef = self.dat_['net']['layers']['name'][lNum][0] 26 | name = ou.ints_to_str(self.dat_['#refs#'][nameRef][:]) 27 | #Get the inputs 28 | ipRef = self.dat_['net']['layers']['inputs'][lNum][0] 29 | ipNames = self.ref_to_str(ipRef) 30 | #Get the parameter names 31 | pmRef = self.dat_['net']['layers']['params'][lNum][0] 32 | pmNames = self.ref_to_str(pmRef) 33 | #Get the output names 34 | opRef = self.dat_['net']['layers']['outputs'][lNum][0] 35 | opNames = self.ref_to_str(opRef) 36 | #Get the layer type 37 | tpRef = self.dat_['net']['layers']['type'][lNum][0] 38 | lType = ou.ints_to_str(self.dat_['#refs#'][tpRef][:]) 39 | #Get the layer params 40 | lpRef = self.dat_['net']['layers']['block'][lNum][0] 41 | lParam = self.dat_['#refs#'][lpRef] 42 | assert (lType[0:5] == 'dagnn') 43 | lType = lType[6:] 44 | 45 | if lType == 'Conv': 46 | paramW = {'name': pmNames[0]} 47 | paramB = {'name': pmNames[1]} 48 | pDupKey = mpu.make_key('param', ['param']) 49 | lDef = 
mpu.get_layerdef_for_proto('Convolution', name, ipNames[0], 50 | **{'num_output': int(lParam['size'][3][0]), 51 | 'param': paramW, pDupKey: paramB, 52 | 'kernel_size': int(lParam['size'][0][0]), 53 | 'stride': int(lParam['stride'][0][0]), 54 | 'pad': int(lParam['pad'][0][0])}) 55 | 56 | elif lType == 'ReLU': 57 | lDef = mpu.get_layerdef_for_proto(lType, name, ipNames[0], 58 | **{'top': opNames[0]}) 59 | 60 | elif lType == 'Pooling': 61 | poolType = lParam['method'][0] 62 | if poolType == 'max': 63 | poolType = 'MAX' 64 | elif poolType == 'avg': 65 | poolType = 'AVE' 66 | lDef = mpu.get_layerdef_for_proto(lType, name, ipNames[0], 67 | **{'top': opNames[0], 'kernel_size': int(lParam['poolSize'][0][0]), 68 | 'stride': int(lParam['stride'][0][0]), 'pad': int(lParam['pad'][0][0]), 69 | 'pool': poolType}) 70 | 71 | elif lType == 'LRN': 72 | N, kappa, alpha, beta = lParam['param'][0][0], lParam['param'][1][0],\ 73 | lParam['param'][2][0], lParam['param'][3][0] 74 | lDef = mpu.get_layerdef_for_proto(lType, name, ipNames[0], 75 | **{'top': opNames[0], 76 | 'local_size': int(N), 77 | 'alpha': N * alpha, 78 | 'beta' : beta, 79 | 'k' : kappa}) 80 | 81 | elif lType == 'Concat': 82 | lDef = mpu.get_layerdef_for_proto(lType, name, ipNames[0], 83 | **{'bottom2': ipNames[1:], 84 | 'concat_dim': 1, 85 | 'top': opNames[0]}) 86 | 87 | elif lType == 'Loss': 88 | lossType = ou.ints_to_str(lParam['loss']) 89 | if lossType == 'pdist': 90 | p = lParam['p'][0][0] 91 | if p == 2: 92 | lossName = 'EuclideanLoss' 93 | else: 94 | raise Exception('Loss type %s not recognized' % lossType) 95 | else: 96 | raise Exception('Loss type %s not recognized' % lossType) 97 | lDef = mpu.get_layerdef_for_proto(lossName, name, ipNames[0], 98 | **{'bottom2': ipNames[1]}) 99 | 100 | elif lType == 'gaussRender': 101 | lDef = mpu.get_layerdef_for_proto(lType, name, ipNames[0], 102 | **{'top': opNames[0], 103 | 'K': lParam['K'][0][0], 'T': lParam['T'][0][0], 104 | 'sigma': lParam['sigma'][0][0], 'imgSz': 
int(lParam['img_size'][0][0])}) 105 | 106 | else: 107 | raise Exception('Layer Type %s not recognized, %d' % (lType, lNum)) 108 | return lDef 109 | 110 | #Convert the model to Caffe 111 | def to_caffe(self, ipLayers=[], layerOrder=[]): 112 | ''' 113 | Caffe doesnot support DAGs but MatConvNet does. layerOrder allows some matconvnet 114 | nets to expressed as caffe nets by moving the order of layers so as to allow caffe 115 | to read the generated prototxt file. 116 | ''' 117 | pDef = mpu.ProtoDef() 118 | caffeLayers = co.OrderedDict() 119 | for lNum in range(len(self.dat_['net']['layers']['name'])): 120 | cl = self.make_caffe_layer(lNum) 121 | caffeLayers[cl['name'][1:-1]] = cl 122 | #Add input layers if needed 123 | for ipl in ipLayers: 124 | pDef.add_layer(ipl['name'][1:-1], ipl) 125 | #Add the ordered layers first 126 | for l in layerOrder: 127 | pDef.add_layer(l, caffeLayers[l]) 128 | del caffeLayers[l] 129 | for key, cl in caffeLayers.iteritems(): 130 | pDef.add_layer(key, cl) 131 | return pDef 132 | 133 | ## 134 | def save_caffe_model(self, 135 | outName='/work4/pulkitag-code/code/ief/IEF/models/ief-googlenet-dec2015', **kwargs): 136 | #caffe prototxt 137 | defFile = outName + '.prototxt' 138 | #caffe model 139 | modelFile = outName + '.caffemodel' 140 | #the meta data 141 | metaFile = outName + '-meta.pkl' 142 | 143 | #obtain prototxt from matconvnet and write to disk 144 | pDef = self.to_caffe(**kwargs) 145 | pDef.write(defFile) 146 | 147 | #Store th weights 148 | net = caffe.Net(defFile, caffe.TEST) 149 | #List the parameter names of all the matconvnet params 150 | matPrmNames = [] 151 | for p in range(len(self.dat_['net']['params']['name'])): 152 | prmRef = self.dat_['net']['params']['name'][p][0] 153 | matPrmNames.append(ou.ints_to_str(self.dat_['#refs#'][prmRef][:])) 154 | #Name of caffe params 155 | paramKeys = net.params.keys() 156 | for k in paramKeys: 157 | for i in range(2): 158 | prm = pDef.get_layer_property(k, 'param', propNum=i) 159 | prmName = 
prm['name'][1:-1] 160 | idx = matPrmNames.index(prmName) 161 | valRef = self.dat_['net']['params']['value'][idx][0] 162 | vals = np.array(self.dat_['#refs#'][valRef]) 163 | if i==0: 164 | vals = vals.transpose((0,1,3,2)) 165 | print (k, i, net.params[k][i].data.shape, vals.shape) 166 | net.params[k][i].data[...] = vals.reshape(net.params[k][i].data.shape) 167 | net.save(modelFile) 168 | 169 | #Store meta information 170 | seedPose = np.array(self.dat_['params']['seed_pose']) 171 | mxStpNrm = np.array(self.dat_['params']['MAX_STEP_NORM'])[0][0] 172 | pickle.dump({'seedPose': seedPose, 'mxStepNorm': mxStpNrm}, 173 | open(metaFile, 'w')) 174 | 175 | ## 176 | # Convert matconvnet network into a caffemodel 177 | def matconvnet_dag_to_caffemodel(inFile, outFile): 178 | dat = h5.File(inFile, 'r') 179 | 180 | 181 | ## test the conversion 182 | def test_convert(): 183 | fName = '/work4/pulkitag-code/code/ief/IEF/models/new_models/models/new-model.mat' 184 | outName = 'try.prototxt' 185 | model = MatConvNetModel(fName) 186 | imgLayer = mpu.get_layerdef_for_proto('DeployData', 'image', None, 187 | **{'ipDims': [1, 3, 224, 224]}) 188 | kpLayer = mpu.get_layerdef_for_proto('DeployData', 'kp_pos', None, 189 | **{'ipDims': [1, 17, 2, 1]}) 190 | lbLayer = mpu.get_layerdef_for_proto('DeployData', 'label', None, 191 | **{'ipDims': [1, 16, 2, 1]}) 192 | pdef = model.save_caffe_model(ipLayers=[imgLayer, kpLayer, lbLayer], layerOrder=['render1', 'concat1']) 193 | #pdef = model.to_caffe(ipLayers=[imgLayer, kpLayer, lbLayer], layerOrder=['render1', 'concat1']) 194 | #pdef.write(outName) 195 | #net = caffe.Net(outName, caffe.TEST) 196 | #return pdef, net 197 | 198 | -------------------------------------------------------------------------------- /my_examples.py: -------------------------------------------------------------------------------- 1 | import my_exp_config as mec 2 | from easydict import EasyDict as edict 3 | from os import path as osp 4 | 5 | ####### EXAMPLE 1 - CONFIGURING AN 
MNIST EXPERIMENT ######### 6 | ## 7 | #Define the experiment, snapshot and other required paths 8 | def get_mnist_paths(): 9 | paths = edict() 10 | #Path to store experiment details 11 | paths.exp = edict() 12 | paths.exp.dr = './test_data/mnist/exp' 13 | #Paths to store snapshot details 14 | paths.exp.snapshot = edict() 15 | paths.exp.snapshot.dr = './test_data/mnist/snapshots' 16 | return paths 17 | 18 | ## 19 | #Define any parameters that may influence the experiment details 20 | def get_mnist_prms(): 21 | prms = edict() 22 | prms['expStr'] = 'mnist' 23 | prms.paths = get_mnist_paths() 24 | return prms 25 | 26 | ## 27 | #Setup a scratch experiment 28 | def setup_experiment(): 29 | prms = get_mnist_prms() 30 | nwPrms = {'netName': 'MyNet', 31 | 'baseNetDefProto': 'trainval.prototxt'} 32 | cPrms = mec.get_caffe_prms(mec.get_default_net_prms, nwPrms, 33 | mec.get_default_solver_prms, 34 | baseDefDir='./test_data/mnist') 35 | exp = mec.CaffeSolverExperiment(prms, cPrms) 36 | exp.make() 37 | return exp 38 | 39 | ####### END OF EXAMPLE 1 ################### 40 | 41 | ####### Example 2 - FINETUNING MNIST EXPERIMENT ########### 42 | def setup_experiment_finetune(): 43 | prms = get_mnist_prms() 44 | preTrainNet = './test_data/mnist/mnist-test_iter_4000.caffemodel' 45 | #preTrainNet = None 46 | baseDefDir ='./test_data/mnist' 47 | nwPrms = {'netName': 'MyNet', 48 | 'baseNetDefProto': osp.join(baseDefDir, 'trainval.prototxt'), 49 | 'preTrainNet': preTrainNet} 50 | cPrms = mec.get_caffe_prms(mec.get_default_net_prms, nwPrms, 51 | mec.get_default_solver_prms) 52 | exp = mec.CaffeSolverExperiment(prms, cPrms) 53 | exp.make() 54 | return exp 55 | 56 | ####### END OF EXAMPLE 2 ################### 57 | -------------------------------------------------------------------------------- /my_exp_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import my_pycaffe as mp 5 | 
import my_pycaffe_utils as mpu 6 | from easydict import EasyDict as edict 7 | import copy 8 | import other_utils as ou 9 | import pickle 10 | import my_sqlite as msq 11 | 12 | REAL_PATH = os.path.dirname(os.path.realpath(__file__)) 13 | DEF_DB = osp.join(REAL_PATH, 'test_data/default-exp-db.sqlite') 14 | 15 | def get_sql_id(dbFile, dArgs, ignoreKeys=[]): 16 | sql = msq.SqDb(dbFile) 17 | #try: 18 | #print ('Ignore KEYS: ', ignoreKeys) 19 | sql.fetch(dArgs, ignoreKeys=ignoreKeys) 20 | idName = sql.get_id(dArgs, ignoreKeys=ignoreKeys) 21 | #except: 22 | # sql.close() 23 | # raise Exception('Error in fetching a name from database') 24 | sql.close() 25 | return idName 26 | 27 | 28 | def get_default_net_prms(dbFile=DEF_DB, **kwargs): 29 | dArgs = edict() 30 | #Name of the net which will be constructed 31 | dArgs.netName = 'alexnet' 32 | #For layers below lrAbove, learning rate is set to 0 33 | dArgs.lrAbove = None 34 | #If weights from a pretrained net are to be used 35 | dArgs.preTrainNet = None 36 | #The base proto from which net will be constructed 37 | dArgs.baseNetDefProto = None 38 | #Batch size 39 | dArgs.batchSize = None 40 | #runNum 41 | dArgs.runNum = 0 42 | dArgs = mpu.get_defaults(kwargs, dArgs, False) 43 | dArgs.expStr = get_sql_id(dbFile, dArgs) 44 | return dArgs 45 | 46 | 47 | def get_siamese_net_prms(dbFile=DEF_DB, **kwargs): 48 | dArgs = get_default_net_prms(dbFile) 49 | del dArgs['expStr'] 50 | #Layers at which the nets are to be concatenated 51 | dArgs.concatLayer = 'fc6' 52 | #If dropouts should be used in the concatenation layer 53 | dArgs.concatDrop = False 54 | #Number of filters in concatenation layer 55 | dArgs.concatSz = None 56 | #If an extra FC layer needs to be added 57 | dArgs.extraFc = None 58 | dArgs = mpu.get_defaults(kwargs, dArgs, False) 59 | dArgs.expStr = get_sql_id(dbFile, dArgs) 60 | return dArgs 61 | 62 | 63 | def get_siamese_window_net_prms(dbFile=DEF_DB, **kwargs): 64 | dArgs = get_siamese_net_prms(dbFile) 65 | del 
dArgs['expStr'] 66 | #Size of input image 67 | dArgs.imSz = 227 68 | #If random cropping is to be used 69 | dArgs.randCrop = False 70 | #If gray scale images need to be used 71 | dArgs.isGray = False 72 | dArgs = mpu.get_defaults(kwargs, dArgs, False) 73 | dArgs.expStr = get_sql_id(dbFile, dArgs) 74 | return dArgs 75 | 76 | ''' 77 | Defining get_custom_net_prms() 78 | def get_custom_net_prms(**kwargs): 79 | dArgs = get_your_favorite_prms() 80 | del dArgs['expStr'] 81 | ##DEFINE NEW PROPERTIES## 82 | dArgs.myNew = value 83 | ################ 84 | dArgs = mpu.get_defaults(kwargs, dArgs, False) 85 | dArgs.expStr = get_sql_id(dbFile, dArgs) 86 | return dArgs 87 | ''' 88 | 89 | ## 90 | # Parameters that specify the learning 91 | def get_default_solver_prms(dbFile=DEF_DB, **kwargs): 92 | ''' 93 | Refer to caffe.proto for a description of the 94 | variables. 95 | ''' 96 | dArgs = edict() 97 | dArgs.baseSolDefFile = None 98 | dArgs.iter_size = 1 99 | dArgs.max_iter = 250000 100 | dArgs.base_lr = 0.001 101 | dArgs.lr_policy = 'step' 102 | dArgs.stepsize = 20000 103 | dArgs.gamma = 0.5 104 | dArgs.weight_decay = 0.0005 105 | dArgs.clip_gradients = -1 106 | #Momentum 107 | dArgs.momentum = 0.9 108 | #Other 109 | dArgs.regularization_type = 'L2' 110 | dArgs.random_seed = -1 111 | #Testing info 112 | dArgs.test_iter = 100 113 | dArgs.test_interval = 1000 114 | dArgs.snapshot = 2000 115 | dArgs.display = 20 116 | #Update parameters 117 | dArgs = mpu.get_defaults(kwargs, dArgs, False) 118 | dArgs.expStr = 'solprms' + get_sql_id(dbFile, dArgs, 119 | ignoreKeys=['test_iter', 'test_interval', 120 | 'snapshot', 'display']) 121 | return dArgs 122 | 123 | ## 124 | #Make Solver 125 | def get_solver_def(solPrms): 126 | solPrms = copy.deepcopy(solPrms) 127 | del solPrms['expStr'] 128 | if solPrms['baseSolDefFile'] is not None: 129 | solDef = mpu.Solver.from_file(solPrms) 130 | else: 131 | del solPrms['baseSolDefFile'] 132 | solDef = mpu.make_solver(**solPrms) 133 | return solDef 134 | 135 
| ## 136 | #Make Solver 137 | def get_net_def(dPrms, nwPrms): 138 | ''' 139 | dPrms : data parameters 140 | nwPrms: parameters that define the net 141 | ''' 142 | if nwPrms.baseNetDefProto is None: 143 | return None 144 | else: 145 | netDef = mpu.ProtoDef(nwPrms.baseNetDefProto) 146 | return netDef 147 | 148 | 149 | def get_caffe_prms(nwFn=None, nwPrms={}, solFn=None, 150 | solPrms={}, resumeIter=None): 151 | ''' 152 | nwFn, nwPrms : nwPrms are passed as input to nwFn to generate the 153 | complete set of network parameters. 154 | solFn, solPrms: solPrms are passes as input ro solFn to generate the 155 | complete set of solver parameters. 156 | resumeIter : if the experiment needs to be resumed from a previously 157 | stored number of iterations. 158 | ''' 159 | if nwFn is None: 160 | nwFn = get_default_net_prms 161 | if solFn is None: 162 | solFn = get_default_solver_prms 163 | nwPrms = nwFn(**nwPrms) 164 | solPrms = solFn(**solPrms) 165 | cPrms = edict() 166 | cPrms.nwPrms = copy.deepcopy(nwPrms) 167 | cPrms.lrPrms = copy.deepcopy(solPrms) 168 | cPrms.resumeIter = resumeIter 169 | expStr = osp.join(cPrms.nwPrms.expStr, cPrms.lrPrms.expStr) + '/' 170 | cPrms.expStr = expStr 171 | return cPrms 172 | 173 | ## 174 | # Programatically make a Caffe Experiment. 
175 | class CaffeSolverExperiment: 176 | def __init__(self, dPrms, cPrms, 177 | netDefFn=get_net_def, solverDefFn=get_solver_def, 178 | isLog=True, addFiles=None): 179 | ''' 180 | dPrms: dict containing key 'expStr' and 'paths' 181 | contains dataset specific parameters 182 | cPrms: dict contraining 'expStr', 'resumeIter', 'nwPrms' 183 | contains net and sovler specific parameters 184 | isLog: if logging data should be recorded 185 | addFiles: if additional files need to be stored - not implemented yet 186 | ''' 187 | self.isLog_ = isLog 188 | dataExpName = dPrms['expStr'] 189 | caffeExpName = cPrms['expStr'] 190 | expDirPrefix = dPrms.paths.exp.dr 191 | snapDirPrefix = dPrms.paths.exp.snapshot.dr 192 | #Relevant directories. 193 | self.dPrms_ = copy.deepcopy(dPrms) 194 | self.cPrms_ = copy.deepcopy(cPrms) 195 | self.dirs_ = {} 196 | self.dirs_['exp'] = osp.join(expDirPrefix, dataExpName) 197 | self.dirs_['snap'] = osp.join(snapDirPrefix, dataExpName) 198 | self.resumeIter_ = cPrms.resumeIter 199 | self.runNum_ = cPrms.nwPrms.runNum 200 | self.preTrainNet_ = cPrms.nwPrms.preTrainNet 201 | 202 | solverFile = caffeExpName + '_solver.prototxt' 203 | defFile = caffeExpName + '_netdef.prototxt' 204 | defDeployFile = caffeExpName + '_netdef_deploy.prototxt' 205 | defRecFile = caffeExpName + '_netdef_reconstruct.prototxt' 206 | logFile = caffeExpName + '_log.pkl' 207 | snapPrefix = caffeExpName + '_caffenet_run%d' % self.runNum_ 208 | 209 | self.files_ = {} 210 | self.files_['solver'] = osp.join(self.dirs_['exp'], solverFile) 211 | self.files_['netdef'] = osp.join(self.dirs_['exp'], defFile) 212 | self.files_['netdefDeploy'] = osp.join(self.dirs_['exp'], defDeployFile) 213 | self.files_['netdefRec'] = osp.join(self.dirs_['exp'], defRecFile) 214 | self.files_['log'] = osp.join(self.dirs_['exp'], logFile) 215 | #snapshot 216 | self.files_['snap'] = osp.join(snapDirPrefix, dataExpName, 217 | snapPrefix + '_iter_%d.caffemodel') 218 | self.snapPrefix_ = '"%s"' % 
osp.join(snapDirPrefix, dataExpName, snapPrefix) 219 | self.snapPrefix_ = ou.chunk_filename(self.snapPrefix_, maxLen=242) 220 | #Chunk all the filnames if needed 221 | for key in self.files_.keys(): 222 | self.files_[key] = ou.chunk_filename(self.files_[key]) 223 | 224 | #Store the solver and the net definition files 225 | self.expFile_ = edict() 226 | self.expFile_.solDef_ = solverDefFn(cPrms.lrPrms) 227 | self.setup_solver() 228 | self.expFile_.netDef_ = netDefFn(dPrms, cPrms.nwPrms) 229 | #Other class parameters 230 | self.solver_ = None 231 | self.expMake_ = False 232 | 233 | def setup_solver(self): 234 | self.expFile_.solDef_.add_property('device_id', 0) 235 | self.expFile_.solDef_.set_property('net', '"%s"' % self.files_['netdef']) 236 | self.expFile_.solDef_.set_property('snapshot_prefix', self.snapPrefix_) 237 | 238 | ## 239 | def del_layer(self, layerName): 240 | self.expFile_.netDef_.del_layer(layerName) 241 | 242 | ## 243 | def del_all_layers_above(self, layerName): 244 | self.expFile_.netDef_.del_all_layers_above(layerName) 245 | 246 | ## Get layer property 247 | def get_layer_property(self, layerName, propName, **kwargs): 248 | return self.expFile_.netDef_.get_layer_property(layerName, propName, **kwargs) 249 | 250 | ## Set the property. 251 | def set_layer_property(self, layerName, propName, value, **kwargs): 252 | self.expFile_.netDef_.set_layer_property(layerName, propName, value, **kwargs) 253 | 254 | ## 255 | def add_layer(self, layerName, layer, phase): 256 | self.expFile_.netDef_.add_layer(layerName, layer, phase) 257 | ## 258 | # Get the layerNames from the type of the layer. 259 | def get_layernames_from_type(self, layerType, phase='TRAIN'): 260 | return self.expFile_.netDef_.get_layernames_from_type(layerType, phase=phase) 261 | 262 | ## 263 | def get_snapshot_name(self, numIter=10000, getSolverFile=False): 264 | ''' 265 | Find the name with which models are being stored. 
266 | ''' 267 | assert self.expFile_.solDef_ is not None, 'Solver has not been formed' 268 | snapshot = self.expFile_.solDef_.get_property('snapshot_prefix') 269 | #_iter_%d.caffemodel is added by caffe while snapshotting. 270 | snapshotName = snapshot[1:-1] + '_iter_%d.caffemodel' % numIter 271 | snapshotName = ou.chunk_filename(snapshotName) 272 | #solver file 273 | solverName = snapshot[1:-1] + '_iter_%d.solverstate' % numIter 274 | solverName = ou.chunk_filename(solverName) 275 | if getSolverFile: 276 | return solverName 277 | else: 278 | return snapshotName 279 | 280 | 281 | ## Only finetune the layers that are above ( and including) layerName 282 | def finetune_above(self, layerName): 283 | self.expFile_.netDef_.set_no_learning_until(layerName) 284 | 285 | ## All layernames 286 | def get_all_layernames(self, phase='TRAIN'): 287 | return self.expFile_.netDef_.get_all_layernames(phase=phase) 288 | 289 | ## Get the top name of the last layer 290 | def get_last_top_name(self): 291 | return self.expFile_.netDef_.get_last_top_name() 292 | 293 | ##Setup the network 294 | def setup_net(self, **kwargs): 295 | if self.net_ is None: 296 | self.make(**kwargs) 297 | snapName = self.get_snapshot_name(kwargs['modelIter']) 298 | self.net_ = mp.MyNet(self.files_['netdef'], snapName) 299 | 300 | ##Get weights from a layer. 301 | def get_weights(self, layerName, **kwargs): 302 | self.setup_net(**kwargs) 303 | return self.net_.net.params[layerName][0].data 304 | 305 | 306 | # Make the experiment. 
307 | def make(self, deviceId=0, dumpLogFreq=1000): 308 | ''' 309 | deviceId: the gpu on which to run 310 | ''' 311 | assert self.expFile_.solDef_ is not None, 'SolverDef has not been formed' 312 | assert self.expFile_.netDef_ is not None, 'NetDef has not been formed' 313 | self.expFile_.solDef_.set_property('device_id', deviceId) 314 | #Write the solver and the netdef file 315 | if not osp.exists(self.dirs_['exp']): 316 | os.makedirs(self.dirs_['exp']) 317 | if not osp.exists(self.dirs_['snap']): 318 | os.makedirs(self.dirs_['snap']) 319 | self.expFile_.netDef_.write(self.files_['netdef']) 320 | self.expFile_.solDef_.write(self.files_['solver']) 321 | #Create the solver 322 | self.solver_ = mp.MySolver.from_file(self.files_['solver'], 323 | dumpLogFreq=dumpLogFreq, logFile=self.files_['log'], 324 | isLog=self.isLog_) 325 | 326 | if self.resumeIter_ is not None: 327 | solverStateFile = self.get_snapshot_name(self.resumeIter_, getSolverFile=True) 328 | assert osp.exists(solverStateFile), '%s not present' % solverStateFile 329 | self.solver_.restore(solverStateFile) 330 | elif self.preTrainNet_ is not None: 331 | assert (self.resumeIter_ is None) 332 | self.solver_.copy_weights(self.preTrainNet_) 333 | self.expMake_ = True 334 | 335 | ## Make the deploy file. 
336 | def make_deploy(self, dataLayerNames, imSz, **kwargs): 337 | self.deployProto_ = mpu.ProtoDef.deploy_from_proto(self.expFile_.netDef_, 338 | dataLayerNames=dataLayerNames, imSz=imSz, **kwargs) 339 | self.deployProto_.write(self.files_['netdefDeploy']) 340 | 341 | ## Get the deploy proto 342 | def get_deploy_proto(self): 343 | if not(osp.exists(self.files_['netdefDeploy'])): 344 | self.make_deploy, 345 | return self.deployProto_ 346 | 347 | ## Get deploy file 348 | def get_deploy_file(self): 349 | return self.files_['netdefDeploy'] 350 | 351 | ## 352 | # Run the experiment 353 | def run(self, recFreq=20): 354 | if not self.expMake_: 355 | print ('Make the experiment using exp.make(), before running, returning') 356 | return 357 | self.solver_.solve() 358 | 359 | ## 360 | #Visualize the log 361 | def vis_log(self): 362 | self.solver_.read_log_from_file() 363 | self.solver_.plot() 364 | 365 | def get_test_accuracy(self): 366 | print ('NOT IMPLEMENTED YET') 367 | 368 | 369 | def get_default_caffe_prms(deviceId=1): 370 | nwPrms = get_nw_prms() 371 | lrPrms = get_lr_prms() 372 | cPrms = get_caffe_prms(nwPrms, lrPrms, deviceId=deviceId) 373 | return cPrms 374 | 375 | def get_experiment_object(prms, cPrms): 376 | #Legacy support 377 | if prms['paths'].has_key('exp'): 378 | expDir = prms.paths.exp.dr 379 | else: 380 | expDir = prms['paths']['expDir'] 381 | caffeExp = mpu.CaffeExperiment(prms['expName'], cPrms['expStr'], 382 | expDir, prms.paths.exp.snapshot.dr, 383 | deviceId=cPrms['deviceId']) 384 | return caffeExp 385 | 386 | 387 | 388 | -------------------------------------------------------------------------------- /my_exp_config_old.py: -------------------------------------------------------------------------------- 1 | def get_caffe_prms_old(nwPrms, lrPrms, finePrms=None, 2 | isScratch=True, deviceId=1, 3 | runNum=0, resumeIter=0): 4 | caffePrms = edict() 5 | caffePrms.deviceId = deviceId 6 | caffePrms.isScratch = isScratch 7 | caffePrms.nwPrms = 
copy.deepcopy(nwPrms) 8 | caffePrms.lrPrms = copy.deepcopy(lrPrms) 9 | caffePrms.finePrms = copy.deepcopy(finePrms) 10 | caffePrms.resumeIter = resumeIter 11 | 12 | expStr = nwPrms.expStr + '/' + lrPrms.expStr 13 | if finePrms is not None: 14 | expStr = expStr + '/' + finePrms.expStr 15 | if runNum > 0: 16 | expStr = expStr + '_run%d' % runNum 17 | caffePrms['expStr'] = expStr 18 | caffePrms['solver'] = lrPrms.solver 19 | return caffePrms 20 | 21 | 22 | 23 | ## 24 | # Parameters required to specify the n/w architecture 25 | def get_nw_prms(isHashStr=False, **kwargs): 26 | dArgs = edict() 27 | dArgs.netName = 'alexnet' 28 | dArgs.concatLayer = 'fc6' 29 | dArgs.concatDrop = False 30 | dArgs.contextPad = 0 31 | dArgs.imSz = 227 32 | dArgs.imgntMean = True 33 | dArgs.maxJitter = 0 34 | dArgs.randCrop = False 35 | dArgs.lossWeight = 1.0 36 | dArgs.multiLossProto = None 37 | dArgs.ptchStreamNum = 256 38 | dArgs.poseStreamNum = 256 39 | dArgs.isGray = False 40 | dArgs.isPythonLayer = False 41 | dArgs.extraFc = None 42 | dArgs.numFc5 = None 43 | dArgs.numConv4 = None 44 | dArgs.numCommonFc = None 45 | dArgs.lrAbove = None 46 | dArgs = mpu.get_defaults(kwargs, dArgs) 47 | if dArgs.numFc5 is not None: 48 | assert(dArgs.concatLayer=='fc5') 49 | expStr = 'net-%s_cnct-%s_cnctDrp%d_contPad%d_imSz%d_imgntMean%d_jit%d'\ 50 | %(dArgs.netName, dArgs.concatLayer, dArgs.concatDrop, 51 | dArgs.contextPad, 52 | dArgs.imSz, dArgs.imgntMean, dArgs.maxJitter) 53 | if dArgs.numFc5 is not None: 54 | expStr = '%s_numFc5-%d' % (expStr, dArgs.numFc5) 55 | if dArgs.numConv4 is not None: 56 | expStr = '%s_numConv4-%d' % (expStr, dArgs.numConv4) 57 | if dArgs.numCommonFc is not None: 58 | expStr = '%s_numCommonFc-%d' % (expStr, dArgs.numCommonFc) 59 | if dArgs.randCrop: 60 | expStr = '%s_randCrp%d' % (expStr, dArgs.randCrop) 61 | if not(dArgs.lossWeight==1.0): 62 | if type(dArgs.lossWeight)== list: 63 | lStr = '' 64 | for i,l in enumerate(dArgs.lossWeight): 65 | lStr = lStr + 'lw%d-%.1f_' % (i,l) 
66 | lStr = lStr[0:-1] 67 | print lStr 68 | expStr = '%s_%s' % (expStr, lStr) 69 | else: 70 | assert isinstance(dArgs.lossWeight, (int, long, float)) 71 | expStr = '%s_lw%.1f' % (expStr, dArgs.lossWeight) 72 | if dArgs.multiLossProto is not None: 73 | expStr = '%s_mlpr%s-posn%d-ptsn%d' % (expStr, 74 | dArgs.multiLossProto, dArgs.poseStreamNum, dArgs.ptchStreamNum) 75 | if dArgs.isGray: 76 | expStr = '%s_grayIm' % expStr 77 | if dArgs.isPythonLayer: 78 | expStr = '%s_pylayers' % expStr 79 | if dArgs.extraFc is not None: 80 | expStr = '%s_extraFc%d' % (expStr, dArgs.extraFc) 81 | if dArgs.lrAbove is not None: 82 | expStr = '%s_lrAbove-%s' % (expStr, dArgs.lrAbove) 83 | if not isHashStr: 84 | dArgs.expStr = expStr 85 | else: 86 | dArgs.expStr = 'nwPrms-%s' % ou.hash_dict_str(dArgs) 87 | return dArgs 88 | 89 | ## 90 | # Parameters that specify the learning 91 | def get_lr_prms(isHashStr=False, **kwargs): 92 | dArgs = edict() 93 | dArgs.batchsize = 128 94 | dArgs.stepsize = 20000 95 | dArgs.base_lr = 0.001 96 | dArgs.max_iter = 250000 97 | dArgs.gamma = 0.5 98 | dArgs.weight_decay = 0.0005 99 | dArgs.clip_gradients = -1 100 | dArgs.debug_info = False 101 | dArgs = mpu.get_defaults(kwargs, dArgs) 102 | #Make the solver 103 | debugStr = '%s' % dArgs.debug_info 104 | debugStr = debugStr.lower() 105 | del dArgs['debug_info'] 106 | solArgs = edict({'test_iter': 100, 'test_interval': 1000, 107 | 'snapshot': 2000, 108 | 'debug_info': debugStr}) 109 | #print dArgs.keys() 110 | expStr = 'batchSz%d_stepSz%.0e_blr%.5f_mxItr%.1e_gamma%.2f_wdecay%.6f'\ 111 | % (dArgs.batchsize, dArgs.stepsize, dArgs.base_lr, 112 | dArgs.max_iter, dArgs.gamma, dArgs.weight_decay) 113 | if not(dArgs.clip_gradients==-1): 114 | expStr = '%s_gradClip%.1f' % (expStr, dArgs.clip_gradients) 115 | if not isHashStr: 116 | dArgs.expStr = expStr 117 | else: 118 | dArgs.expStr = 'lrPrms-%s' % ou.hash_dict_str(dArgs) 119 | for k in dArgs.keys(): 120 | if k in ['batchsize', 'expStr']: 121 | continue 122 | 
solArgs[k] = copy.deepcopy(dArgs[k]) 123 | 124 | dArgs.solver = mpu.make_solver(**solArgs) 125 | return dArgs 126 | 127 | ## 128 | # Parameters for fine-tuning 129 | def get_finetune_prms(isHashStr=False, **kwargs): 130 | ''' 131 | sourceModelIter: The number of model iterations of the source model to consider 132 | fine_max_iter : The maximum iterations to which the target model should be trained. 133 | lrAbove : If learning is to be performed some layer. 134 | fine_base_lr : The base learning rate for finetuning. 135 | fineRunNum : The run num for the finetuning. 136 | fineNumData : The amount of data to be used for the finetuning. 137 | fineMaxLayer : The maximum layer of the source n/w that should be considered. 138 | ''' 139 | dArgs = edict() 140 | dArgs.base_lr = 0.001 141 | dArgs.runNum = 1 142 | dArgs.maxLayer = None 143 | dArgs.lrAbove = None 144 | dArgs.dataset = 'sun' 145 | dArgs.maxIter = 40000 146 | dArgs.extraFc = False 147 | dArgs.extraFcDrop = False 148 | dArgs.sourceModelIter = 60000 149 | dArgs = mpu.get_defaults(kwargs, dArgs) 150 | return dArgs 151 | 152 | 153 | def get_caffe_prms(nwPrms, lrPrms, finePrms=None, 154 | isScratch=True, deviceId=1, 155 | runNum=0, resumeIter=0): 156 | caffePrms = edict() 157 | caffePrms.deviceId = deviceId 158 | caffePrms.isScratch = isScratch 159 | caffePrms.nwPrms = copy.deepcopy(nwPrms) 160 | caffePrms.lrPrms = copy.deepcopy(lrPrms) 161 | caffePrms.finePrms = copy.deepcopy(finePrms) 162 | caffePrms.resumeIter = resumeIter 163 | 164 | expStr = nwPrms.expStr + '/' + lrPrms.expStr 165 | if finePrms is not None: 166 | expStr = expStr + '/' + finePrms.expStr 167 | if runNum > 0: 168 | expStr = expStr + '_run%d' % runNum 169 | caffePrms['expStr'] = expStr 170 | caffePrms['solver'] = lrPrms.solver 171 | return caffePrms 172 | 173 | -------------------------------------------------------------------------------- /my_netspec.py: -------------------------------------------------------------------------------- 1 | ## 
# @package my_netspec
# This was developed independently of caffe's netspec in 2014.
# This may be phased out in favor of caffe's netspec in the near
# future.
#

import my_pycaffe_utils as mpu
import os


def make_def_proto(nw, isSiamese=True,
                   baseFileStr='split_im.prototxt', getStreamTopNames=False):
    '''
    Build an mpu.ProtoDef from a compact network description.

    nw               : iterable of (layerType, layerParams) pairs; layerParams is a
                       dict whose entries are forwarded as keyword arguments to
                       mpu.get_(siamese_)layerdef_for_proto.
    isSiamese        : if True, every layer before the first 'Concat' is built with
                       mpu.get_siamese_layerdef_for_proto (one definition per
                       stream); the Concat and all later layers are built once.
    baseFileStr      : prototxt used to initialise the ProtoDef.
    getStreamTopNames: if True, also return the names of the last layer of each
                       stream (both None when isSiamese is False).

    Returns protoDef, or (protoDef, top1Name, top2Name) if getStreamTopNames.
    The layerParams dicts inside nw are not modified.
    '''
    protoDef = mpu.ProtoDef(baseFileStr)

    lastTop = 'data'           # blob the next layer reads from
    siameseFlag = isSiamese    # switched off at the first Concat layer
    stream1, stream2 = [], []  # layer defs of the two siamese streams
    mainStream = []            # layer defs built once (Concat onwards)

    nameGen = mpu.LayerNameGenerator()
    for lType, lParam in nw:
        # Work on a copy so that filling in defaults (e.g. 'bottom2' below)
        # does not leak into the caller's network description.
        lParam = dict(lParam)
        lName = nameGen.next_name(lType)
        # Layers that should not be copied while finetuning need a different name.
        if 'nameDiff' in lParam:
            lName = lName + '-%s' % lParam['nameDiff']
        if lType == 'Concat':
            siameseFlag = False
            if 'bottom2' not in lParam:
                # Presumably the top of the second ('_p') stream -- confirm in mpu.
                lParam['bottom2'] = lastTop + '_p'

        if siameseFlag:
            lDef, lsDef = mpu.get_siamese_layerdef_for_proto(lType, lName, lastTop,
                                                             **lParam)
            stream1.append(lDef)
            stream2.append(lsDef)
        else:
            mainStream.append(mpu.get_layerdef_for_proto(lType, lName, lastTop,
                                                         **lParam))

        if 'shareBottomWithNext' in lParam:
            # The next layer keeps reading lastTop instead of this layer's output.
            assert lParam['shareBottomWithNext']
        else:
            lastTop = lName

    # Stream 1 first, then stream 2, then the layers built once.
    for lDef in stream1 + stream2 + mainStream:
        # The stored name carries one leading and one trailing character
        # (presumably quotes) which are stripped here.
        protoDef.add_layer(lDef['name'][1:-1], lDef)

    if not getStreamTopNames:
        return protoDef
    if isSiamese:
        top1Name = stream1[-1]['name'][1:-1]
        top2Name = stream2[-1]['name'][1:-1]
    else:
        top1Name, top2Name = None, None
    return protoDef, top1Name, top2Name


##
# Generates a string to represent the n/w name
def nw2name(nw, getLayerNames=False):
    '''
    Summarise a network description as a compact string.

    nw           : iterable of (layerType, layerParams) pairs.
    getLayerNames: if True, additionally return the generated name of every
                   layer (with a '-<nameDiff>' suffix when 'nameDiff' is given).

    Only Convolution, InnerProduct, Pooling, Concat, Dropout, Sigmoid and
    RandomNoise layers contribute a piece to the summary; pieces are joined
    with '_'.
    Returns nwName, or (nwName, layerNames) if getLayerNames.
    '''
    nameGen = mpu.LayerNameGenerator()
    pieces = []
    layerNames = []
    for lType, prm in nw:
        base = nameGen.next_name(lType)
        if 'nameDiff' in prm:
            layerNames.append(base + '-%s' % prm['nameDiff'])
        else:
            layerNames.append(base)

        if lType == 'Convolution':
            pieces.append(base + '-%d' % prm['num_output'] +
                          'sz%d-st%d' % (prm['kernel_size'], prm['stride']))
        elif lType == 'InnerProduct':
            pieces.append(base + '-%d' % prm['num_output'])
        elif lType == 'Pooling':
            pieces.append(base + '-sz%d-st%d' % (prm['kernel_size'], prm['stride']))
        elif lType in ('Concat', 'Dropout', 'Sigmoid'):
            pieces.append(base)
        elif lType == 'RandomNoise':
            # NOTE(review): the test is on 'adaptive_sigma' but the value printed
            # is 'adaptive_factor' -- confirm both keys are always set together.
            if 'adaptive_sigma' in prm:
                pieces.append(base + '-asig%.2f' % prm['adaptive_factor'])
            else:
                pieces.append(base + '-sig%.2f' % prm['sigma'])
        # Every other layer type is left out of the summary.

    nwName = '_'.join(pieces)
    return (nwName, layerNames) if getLayerNames else nwName


##
# This is highly hand-engineered to suit my current needs for the ICCV submission.
110 | def nw2name_small(nw, isLatexName=False): 111 | nameGen = mpu.LayerNameGenerator() 112 | nwName = [] 113 | latexName = [] 114 | for l in nw: 115 | lType, lParam = l 116 | lName = '' 117 | latName = '' 118 | if lType in ['Convolution']: 119 | lName = 'C%d_k%d' % (lParam['num_output'], lParam['kernel_size']) 120 | latName = 'C%d' % (lParam['num_output']) 121 | nwName.append(lName) 122 | latexName.append(latName) 123 | elif lType in ['InnerProduct']: 124 | lName = 'F%d' % lParam['num_output'] 125 | nwName.append(lName) 126 | latexName.append(latName) 127 | elif lType in ['Pooling']: 128 | lName = lName + 'P' 129 | nwName.append(lName) 130 | latexName.append(lName) 131 | elif lType in ['Sigmoid']: 132 | lName = lName + 'S' 133 | nwName.append(lName) 134 | latexName.append(lName) 135 | elif lType in ['Concat']: 136 | break 137 | else: 138 | pass 139 | nwName = ''.join(s + '-' for s in nwName) 140 | nwName = nwName[:-1] 141 | 142 | latexName = ''.join(s + '-' for s in latexName) 143 | latexName = latexName[:-1] 144 | 145 | if isLatexName: 146 | return nwName, latexName 147 | else: 148 | return nwName 149 | 150 | 151 | -------------------------------------------------------------------------------- /my_pycaffe.py: -------------------------------------------------------------------------------- 1 | ## @package my_pycaffe 2 | # Major Wrappers. 
#

import numpy as np
import caffe
import pdb
import matplotlib.pyplot as plt
import os
from os import path as osp
from six import string_types
import copy
from easydict import EasyDict as edict
import my_pycaffe_utils as mpu
import my_pycaffe_io as mpio
import pickle
import collections as co
import time
try:
    import h5py
except Exception:
    # Optional dependency: keep best-effort behaviour, but do not swallow
    # KeyboardInterrupt/SystemExit like a bare except would.
    print ('WARNING: h5py not found, some functions may not work')


class layerSz:
    '''
    Receptive-field bookkeeping for a single conv/pool layer.

    Square images and filters are assumed. Call set_prev_prms() (first layer)
    or prev_prms() (later layers), then compute().
    '''

    def __init__(self, stride, filterSz):
        self.imSzPrev = []        # image size at the previous layer (not used by compute)
        self.stride = stride      # stride with which the filters are applied
        self.filterSz = filterSz  # size of the filters
        self.stridePixPrev = []   # stride, in image pixels, of the previous layer's units
        self.pixelSzPrev = []     # receptive field, in image pixels, of the previous layer's units
        # Set by compute()
        self.pixelSz = []         # receptive field of this layer's units in the original image
        self.stridePix = []       # stride of this layer's units in the original image

    def prev_prms(self, prevLayer):
        '''Use an already computed layerSz as the previous layer.'''
        self.set_prev_prms(prevLayer.stridePix, prevLayer.pixelSz)

    def set_prev_prms(self, stridePixPrev, pixelSzPrev):
        '''Set the previous layer's stride and receptive field (image pixels).'''
        self.stridePixPrev = stridePixPrev
        self.pixelSzPrev = pixelSzPrev

    def compute(self):
        '''Compute pixelSz and stridePix of this layer.'''
        # Each additional filter tap widens the receptive field by one
        # (image-pixel) stride of the previous layer.
        self.pixelSz = self.pixelSzPrev + (self.filterSz - 1) * self.stridePixPrev
        self.stridePix = self.stride * self.stridePixPrev


##
# Calculate the receptive field size and the stride of the Alex-Net
def calculate_size():
    '''Print receptive field size and stride (in image pixels) of AlexNet layers.'''
    # (name, stride, filter size) of the AlexNet conv/pool stack, in order.
    arch = [('conv1', 4, 11), ('pool1', 2, 3),
            ('conv2', 1, 5), ('pool2', 2, 3),
            ('conv3', 1, 3),
            ('conv4', 1, 3),
            ('conv5', 1, 3), ('pool5', 2, 3)]
    sizes = {}
    prev = None
    for name, stride, filterSz in arch:
        lyr = layerSz(stride, filterSz)
        if prev is None:
            # The input image: unit stride and unit receptive field.
            lyr.set_prev_prms(1, 1)
        else:
            lyr.prev_prms(prev)
        lyr.compute()
        sizes[name] = lyr
        prev = lyr

    for name in ['pool1', 'pool2', 'conv3', 'conv4', 'pool5']:
        print ('%s: Receptive: %d, Stride: %d ' % (name.capitalize(),
               sizes[name].pixelSz, sizes[name].stridePix))


##
# Find the Layer Type
def find_layer(lines):
    '''
    Return the type of the first layer described in lines.

    lines: prototxt lines starting at a layer definition. The value of the first
           line containing 'type' is returned (e.g. 'CONVOLUTION'). Returns []
           if no such line exists (kept for backward compatibility).
    '''
    for l in lines:
        if 'type' in l:
            _, layerName = l.split()
            # Stop at the first match: callers pass the rest of the file, so
            # scanning on would report the type of the last layer instead.
            return layerName
    return []


##
# Find the Layer Name
def find_layer_name(lines):
    '''
    Return (layerName, topName) of the layer whose definition starts at lines[0].

    lines[0] must be the opening line of the layer (its second token is '{').
    Only 'name'/'top' entries at the layer's own nesting level are used, and
    scanning stops at the matching closing brace. The first and last character
    (the quotes) are stripped from the values; topName is None if absent.
    '''
    assert lines[0].split()[1] == '{', 'Something is wrong'
    layerName, topName = None, None
    depth = 1
    for l in lines[1:]:
        if '{' in l:
            depth += 1
        if '}' in l:
            depth -= 1
        if depth == 0:
            break
        if 'name' in l and depth == 1:
            _, layerName = l.split()
            layerName = layerName[1:-1]
        if 'top' in l and depth == 1:
            _, topName = l.split()
            topName = topName[1:-1]

    assert layerName is not None, 'no name of a layer found'
    return layerName, topName


##
# Converts definition file of a network into siamese network.
def netdef2siamese(defFile, outFile):
    '''
    Convert an old-style ('layers { ... }') network definition into a siamese one.

    The output holds two copies of the network. Stream 1 is the original.
    Stream 2 is a copy in which the value on every line containing 'name',
    'top' or 'bottom' gets a '_p' suffix. CONVOLUTION and INNER_PRODUCT layers
    in both streams get identical 'param' entries (<name>_w, <name>_b) so the
    streams share weights. DATA layers are only written for stream 1.
    '''
    with open(defFile, 'r') as fid:
        lines = fid.readlines()

    stream1, stream2 = [], []
    layerFlag = 0     # brace depth inside a layer with parameters (0 = not in one)
    paramName = None
    for i, l in enumerate(lines):
        added = False
        if 'layers' in l:
            # Only layers with learnable parameters need shared params.
            if find_layer(lines[i:]) in ['CONVOLUTION', 'INNER_PRODUCT']:
                layerFlag = 1

        # Rename layers/blobs of the second stream.
        if 'bottom' in l or 'top' in l or 'name' in l:
            stream1.append(l)
            pre, suf = l.split()
            stream2.append(pre + ' ' + suf[0:-1] + '_p"' + '\n')
            added = True

        # Remember the name used for the shared parameters.
        if layerFlag > 0 and 'name' in l:
            _, paramName = l.split()
            paramName = paramName[1:-1]

        # The 'layers {' line opened the layer and was counted above.
        if layerFlag > 0 and '{' in l and 'layers' not in l:
            layerFlag += 1

        if '}' in l and layerFlag > 0:
            layerFlag -= 1
            if layerFlag == 0:
                # The layer closes on this line: add the shared params first.
                for stream in (stream1, stream2):
                    stream.append('\t param: "%s" \n' % (paramName + '_w'))
                    stream.append('\t param: "%s" \n' % (paramName + '_b'))

        if not added:
            stream1.append(l)
            stream2.append(l)

    with open(outFile, 'w') as outFid:
        # First stream of the siamese net.
        outFid.writelines(stream1)

        # Second stream, without its DATA layers.
        skipFlag = False
        depth = 0
        for i, l in enumerate(stream2):
            if 'layers' in l and find_layer(stream2[i:]) in ['DATA']:
                skipFlag = True
                depth = 1
            # Count nested blocks (e.g. data_param { }) so that skipping only
            # ends at the DATA layer's own closing brace.
            if depth > 0 and '{' in l and 'layers' not in l:
                depth += 1

            if not skipFlag:
                outFid.write(l)

            if '}' in l and depth > 0:
                depth -= 1
                if depth == 0:
                    skipFlag = False


##
# Get Model and Mean file for a model.
#
# Returns:
#   the model file - the .caffemodel with the weights
#   the mean file of the imagenet data (None for lenet)
def get_model_mean_file(netName='vgg', modelDir='/data1/pulkitag/caffe_models/'):
    '''
    netName : one of 'alex', 'alex_deploy', 'bvlcAlexNet', 'vgg', 'lenet'.
    modelDir: directory holding the ImageNet models and mean files. It is
              concatenated directly with file names, so keep the trailing '/'.
              The lenet snapshot path is absolute and does not use it.

    Returns (modelFile, imMeanFile), or None (after printing a message) when
    netName is not recognized.
    '''
    bvlcDir = modelDir + 'bvlc_reference/'
    alexFiles = (modelDir + 'caffe_imagenet_train_iter_310000',
                 modelDir + 'ilsvrc2012_mean.binaryproto')
    netFiles = {
        'alex': alexFiles,
        'alex_deploy': alexFiles,
        'bvlcAlexNet': (bvlcDir + 'bvlc_reference_caffenet.caffemodel',
                        bvlcDir + 'imagenet_mean.binaryproto'),
        'vgg': (modelDir + 'VGG_ILSVRC_19_layers.caffemodel',
                modelDir + 'ilsvrc2012_mean.binaryproto'),
        'lenet': ('/data1/pulkitag/mnist/snapshots/lenet_iter_20000.caffemodel',
                  None),
    }
    if netName not in netFiles:
        print ('netName not recognized')
        return None
    return netFiles[netName]


##
# Get the architecture definition file.
233 | def get_layer_def_files(netName='vgg', layerName='pool4'): 234 | ''' 235 | Returns 236 | the architecture definition file of the network uptil layer layerName 237 | ''' 238 | modelDir = '/data1/pulkitag/caffe_models/' 239 | bvlcDir = modelDir + 'bvlc_reference/' 240 | if netName=='vgg': 241 | defFile = modelDir + 'layer_def_files/vgg_19_%s.prototxt' % layerName 242 | elif netName == 'alex_deploy': 243 | if layerName is not None: 244 | defFile = bvlcDir + 'caffenet_deploy_%s.prototxt' % layerName 245 | else: 246 | defFile = bvlcDir + 'caffenet_deploy.prototxt' 247 | else: 248 | print 'Cannont get files for networks other than VGG' 249 | return defFile 250 | 251 | 252 | ## 253 | # Get the shape of input blob from the defFile 254 | def get_input_blob_shape(defFile): 255 | blobShape = [] 256 | with open(defFile,'r') as f: 257 | lines = f.readlines() 258 | ipMode = False 259 | for l in lines: 260 | if 'input:' in l: 261 | ipMode = True 262 | if ipMode and 'input_dim:' in l: 263 | ips = l.split() 264 | blobShape.append(int(ips[1])) 265 | return blobShape 266 | 267 | 268 | 269 | 270 | class MyNet: 271 | def __init__(self, defFile, modelFile=None, isGPU=True, testMode=True, deviceId=None): 272 | self.defFile_ = defFile 273 | self.modelFile_ = modelFile 274 | self.testMode_ = testMode 275 | self.set_mode(isGPU, deviceId=deviceId) 276 | self.setup_network() 277 | self.transformer = {} 278 | 279 | 280 | def setup_network(self): 281 | if self.testMode_: 282 | if not self.modelFile_ is None: 283 | self.net = caffe.Net(self.defFile_, self.modelFile_, caffe.TEST) 284 | else: 285 | self.net = caffe.Net(self.defFile_, caffe.TEST) 286 | else: 287 | if not self.modelFile_ is None: 288 | self.net = caffe.Net(self.defFile_, self.modelFile_, caffe.TRAIN) 289 | else: 290 | self.net = caffe.Net(self.defFile_, caffe.TRAIN) 291 | self.batchSz = self.get_batchsz() 292 | 293 | 294 | def set_mode(self, isGPU=True, deviceId=None): 295 | if isGPU: 296 | caffe.set_mode_gpu() 297 | else: 298 
| caffe.set_mode_cpu() 299 | if deviceId is not None: 300 | caffe.set_device(deviceId) 301 | 302 | 303 | def get_batchsz(self): 304 | if len(self.net.inputs) > 0: 305 | return self.net.blobs[self.net.inputs[0]].num 306 | else: 307 | return None 308 | 309 | def get_blob_shape(self, blobName): 310 | assert blobName in self.net.blobs.keys(), 'Blob Name is not present in the net' 311 | blob = self.net.blobs[blobName] 312 | return blob.num, blob.channels, blob.height, blob.width 313 | 314 | 315 | def set_preprocess(self, ipName='data',chSwap=(2,1,0), meanDat=None, 316 | imageDims=None, isBlobFormat=False, rawScale=None, cropDims=None, 317 | noTransform=False, numCh=3): 318 | ''' 319 | isBlobFormat: if the images are already coming in blobFormat or not. 320 | ipName : the blob for which the pre-processing parameters need to be set. 321 | meanDat : the mean which needs to subtracted 322 | imageDims : the size of the images as H * W * K where K is the number of channels 323 | cropDims : the size to which the image needs to be cropped. 324 | if None - then it is automatically determined 325 | this behavior is undesirable for some deploy prototxts 326 | noTransform: if no transform needs to be applied 327 | numCh : number of channels 328 | ''' 329 | if chSwap is not None: 330 | assert len(chSwap) == numCh, 'Number of channels mismatch (%d, %d)'\ 331 | %(len(chSwap), numCh) 332 | if noTransform: 333 | self.transformer[ipName] = None 334 | return 335 | self.transformer[ipName] = caffe.io.Transformer({ipName: self.net.blobs[ipName].data.shape}) 336 | #Note blobFormat will be so used that finally the image will need to be flipped. 337 | self.transformer[ipName].set_transpose(ipName, (2,0,1)) 338 | 339 | if isBlobFormat: 340 | assert chSwap is None, 'With Blob format chSwap should be none' 341 | 342 | if chSwap is not None: 343 | #Required for eg RGB to BGR conversion. 
344 | print (ipName, chSwap) 345 | self.transformer[ipName].set_channel_swap(ipName, chSwap) 346 | 347 | if rawScale is not None: 348 | self.transformer[ipName].set_raw_scale(ipName, rawScale) 349 | 350 | #Crop Dimensions 351 | ipDims = np.array(self.net.blobs[ipName].data.shape) 352 | if cropDims is not None: 353 | assert len(cropDims)==2, 'Length of cropDims needs to be corrected' 354 | self.cropDims = np.array(cropDims) 355 | else: 356 | self.cropDims = ipDims[2:] 357 | self.isBlobFormat = isBlobFormat 358 | if imageDims is None: 359 | imageDims = np.array([ipDims[2], ipDims[3], ipDims[1]]) 360 | else: 361 | assert len(imageDims)==3 362 | imageDims = np.array([imageDims[0], imageDims[1], imageDims[2]]) 363 | self.imageDims = imageDims 364 | self.get_crop_dims() 365 | 366 | #Mean Subtraction 367 | if not meanDat is None: 368 | isTuple = False 369 | if isinstance(meanDat, string_types): 370 | meanDat = mpio.read_mean(meanDat) 371 | elif type(meanDat) == tuple: 372 | meanDat = np.array(meanDat).reshape(numCh,1,1) 373 | meanDat = meanDat * (np.ones((numCh, self.crop[2] - self.crop[0],\ 374 | self.crop[3]-self.crop[1])).astype(np.float32)) 375 | isTuple = True 376 | _,h,w = meanDat.shape 377 | assert self.imageDims[0]<=h and self.imageDims[1]<=w,\ 378 | 'imageDims must match mean Image size, (h,w), (imH, imW): (%d, %d), (%d,%d)'\ 379 | % (h,w,self.imageDims[0],self.imageDims[1]) 380 | if not isTuple: 381 | meanDat = meanDat[:, self.crop[0]:self.crop[2], self.crop[1]:self.crop[3]] 382 | self.transformer[ipName].set_mean(ipName, meanDat) 383 | 384 | 385 | def get_crop_dims(self): 386 | # Take center crop. 
387 | center = np.array(self.imageDims[0:2]) / 2.0 388 | crop = np.tile(center, (1, 2))[0] + np.concatenate([ 389 | -self.cropDims / 2.0, 390 | self.cropDims / 2.0 391 | ]) 392 | self.crop = crop 393 | 394 | 395 | def resize_batch(self, ims): 396 | if ims.shape[0] > self.batchSz: 397 | assert False, "More input images than the batch sz" 398 | if ims.shape[0] == self.batchSz: 399 | return ims 400 | print "Adding Zero Images to fix the batch, Size: %d" % ims.shape[0] 401 | N,ch,h,w = ims.shape 402 | imZ = np.zeros((self.batchSz - N, ch, h, w)) 403 | ims = np.concatenate((ims, imZ)) 404 | return ims 405 | 406 | 407 | def preprocess_batch(self, ims, ipName='data'): 408 | ''' 409 | ims: iterator over H * W * K sized images (K - number of channels) or K * H * W format. 410 | ''' 411 | #The image necessary needs to be float - otherwise caffe.io.resize fucks up. 412 | assert ipName in self.transformer.keys() 413 | ims = ims.astype(np.float32) 414 | if self.transformer[ipName] is None: 415 | ims = self.resize_batch(ims) 416 | return ims.astype(np.float32) 417 | 418 | if np.max(ims)<=1.0: 419 | print "There maybe issues with image scaling. The maximum pixel value is 1.0 and not 255.0" 420 | 421 | im_ = np.zeros((len(ims), 422 | self.imageDims[0], self.imageDims[1], self.imageDims[2]), 423 | dtype=np.float32) 424 | #Convert to normal image format if required. 
425 | if self.isBlobFormat: 426 | ims = np.transpose(ims, (0,2,3,1)) 427 | 428 | #Resize the images 429 | h, w = ims.shape[1], ims.shape[2] 430 | for ix, in_ in enumerate(ims): 431 | if h==self.imageDims[0] and w==self.imageDims[1]: 432 | im_[ix] = np.copy(in_) 433 | else: 434 | #print (in_.shape, self.imageDims) 435 | im_[ix] = caffe.io.resize_image(in_, self.imageDims[0:2]) 436 | 437 | #Required cropping 438 | im_ = im_[:,self.crop[0]:self.crop[2], self.crop[1]:self.crop[3],:] 439 | #Applying the preprocessing 440 | caffe_in = np.zeros(np.array(im_.shape)[[0,3,1,2]], dtype=np.float32) 441 | for ix, in_ in enumerate(im_): 442 | caffe_in[ix] = self.transformer[ipName].preprocess(ipName, in_) 443 | 444 | #Resize the batch appropriately 445 | caffe_in = self.resize_batch(caffe_in) 446 | return caffe_in 447 | 448 | 449 | def deprocess_batch(self, caffeIn, ipName='data'): 450 | ''' 451 | ims: iterator over H * W * K sized images (K - number of channels) 452 | ''' 453 | #Applying the deprocessing 454 | im_ = np.zeros(np.array(caffeIn.shape)[[0,2,3,1]], dtype=np.float32) 455 | for ix, in_ in enumerate(caffeIn): 456 | im_[ix] = self.transformer[ipName].deprocess(ipName, in_) 457 | 458 | ims = np.zeros((len(im_), 459 | self.imageDims[0], self.imageDims[1], im_[0].shape[2]), 460 | dtype=np.float32) 461 | #Resize the images 462 | for ix, in_ in enumerate(im_): 463 | ims[ix] = caffe.io.resize_image(in_, self.imageDims) 464 | 465 | return ims 466 | 467 | 468 | ## 469 | #Core function for running forward and backward passes. 470 | def _run_forward_backward_all(self, runType, blobs=None, noInputs=False, diffs=None, **kwargs): 471 | ''' 472 | runType : 'forward_all' 473 | 'forward_backward_all' 474 | blobs : The blobs to extract in the forward_all pass 475 | noInputs: Set to true when there are no input blobs. 476 | diffs : the blobs for which the gradient needs to be extracted. 
477 | kwargs : A dictionary where each input blob has associated data 478 | ''' 479 | if not noInputs: 480 | if kwargs: 481 | if (set(kwargs.keys()) != set(self.transformer.keys())): 482 | raise Exception('Data Transformer has not been set for all input blobs') 483 | #Just pass all the inputs 484 | procData = {} 485 | N = self.batchSz 486 | for in_, data in kwargs.iteritems(): 487 | N = data.shape[0] #The first dimension must be equivalent of batchSz 488 | procData[in_] = self.preprocess_batch(data, ipName=in_) 489 | 490 | if runType == 'forward_all': 491 | ops = self.net.forward_all(blobs=blobs, **procData) 492 | elif runType == 'forward': 493 | ops = self.net.forward(blobs=blobs, **procData) 494 | elif runType == 'backward': 495 | ops = self.net.backward(diff=diff, **procData) 496 | elif runType == 'forward_backward_all': 497 | ops, opDiff = self.net.forward_backward_all(blobs=blobs, diffs=diffs, **procData) 498 | #Resize diffs in the right size 499 | for opd_, data in ops.iteritems(): 500 | opDiff[opd_] = data[0:N] 501 | else: 502 | raise Exception('runType %s not recognized' % runType) 503 | #Resize data in the right size 504 | for op_, data in ops.iteritems(): 505 | if data.ndim==0: 506 | continue 507 | #print (op_, data.shape) 508 | ops[op_] = data[0:N] 509 | else: 510 | raise Exception('No Input data specified.') 511 | else: 512 | if runType in ['forward_all', 'forward']: 513 | ops = self.net.forward(blobs=blobs) 514 | elif runType in ['backward']: 515 | ops = self.net.backward(diffs=diffs) 516 | elif runType in ['forward_backward_all']: 517 | ops, opDiff = self.net.forward_backward_all(blobs=blobs, diffs=diffs) 518 | else: 519 | raise Exception('runType %s not recognized' % runType) 520 | 521 | if runType in ['forward', 'forward_all', 'backward']: 522 | return copy.deepcopy(ops) 523 | else: 524 | return copy.deepcopy(ops), copy.deepcopy(opDiff) 525 | 526 | 527 | def forward_all(self, blobs=None, noInputs=False, **kwargs): 528 | ''' 529 | See 
_run_forward_backward_all 530 | ''' 531 | return self._run_forward_backward_all(runType='forward_all', blobs=blobs, 532 | noInputs=noInputs, **kwargs) 533 | 534 | 535 | def forward_backward_all(self, blobs=None, noInputs=False, **kwargs): 536 | return self._run_forward_backward_all(runType='forward_backward_all', blobs=blobs, 537 | noInputs=noInputs, **kwargs) 538 | 539 | 540 | def forward(self, blobs=None, noInputs=False, **kwargs): 541 | return self._run_forward_backward_all(runType='forward', blobs=blobs, 542 | noInputs=noInputs, **kwargs) 543 | 544 | 545 | def backward(self, diffs=None, noInputs=False, **kwargs): 546 | return self._run_forward_backward_all(runType='backward', diffs=diffs, 547 | noInputs=noInputs, **kwargs) 548 | 549 | 550 | def vis_weights(self, blobName, blobNum=0, ax=None, titleName=None, isFc=False, 551 | h=None, w=None, returnData=False, chSt=0, chEn=3): 552 | assert blobName in self.net.params, 'BlobName not found' 553 | dat = copy.deepcopy(self.net.params[blobName][blobNum].data) 554 | if isFc: 555 | dat = dat.transpose((2,3,0,1)) 556 | print dat.shape 557 | assert dat.shape[2]==1 and dat.shape[3]==1 558 | ch = dat.shape[1] 559 | assert np.sqrt(ch)*np.sqrt(ch)==ch, 'Cannot transform to filter' 560 | h,w = int(np.sqrt(ch)), int(np.sqrt(ch)) 561 | dat = np.reshape(dat,(dat.shape[0],h,w,1)) 562 | print dat.shape 563 | weights = vis_square(dat, ax=ax, titleName=titleName, returnData=returnData) 564 | else: 565 | if h is None and w is None: 566 | weights = vis_square(dat.transpose(0,2,3,1), ax=ax, titleName=titleName, 567 | returnData=returnData, chSt=chSt, chEn=chEn) 568 | else: 569 | weights = vis_rect(dat.transpose(0,2,3,1), h, w, ax=ax, titleName=titleName, returnData=returnData) 570 | 571 | if returnData: 572 | return weights 573 | 574 | 575 | class CaffeNetLogger(object): 576 | def __init__(self, logFile='default_log.pkl', phases=['train', 'test']): 577 | self.logFile_ = logFile 578 | self.phase_ = phases 579 | self.featVals = 
co.OrderedDict() 580 | self.paramVals = co.OrderedDict() 581 | self.paramUpdate = co.OrderedDict() 582 | self.blobNames_ = co.OrderedDict() 583 | self.paramNames_ = co.OrderedDict() 584 | #number of parameters for instance with weight, bias = 2 585 | self.numParam_ = co.OrderedDict() 586 | self.layerNames_ = co.OrderedDict() 587 | for ph in self.phase_: 588 | self.featVals[ph] = edict() 589 | self.paramVals[ph] = [edict(), edict()] 590 | self.paramUpdate[ph] = [edict(), edict()] 591 | self.plotSetup_ = False 592 | #For recording the iterations at which data was recorded 593 | self.recIter_ = co.OrderedDict() 594 | for ph in self.phase_: 595 | self.recIter_[ph] = [] 596 | self.isRead_ = False 597 | 598 | @classmethod 599 | def from_net(cls, net, logFile='default_log.pkl', phases=['train', 'test']): 600 | self = cls(logFile, phases) 601 | for ph in self.phase_: 602 | self.layerNames_[ph] = [l for l in net[ph]._layer_names] 603 | self.paramNames_[ph] = net[ph].params.keys() 604 | self.blobNames_[ph] = net[ph].blobs.keys() 605 | #Blobs 606 | for i,b in enumerate(self.blobNames_[ph]): 607 | self.featVals[ph][b] = [] 608 | #Params 609 | for p in self.paramNames_[ph]: 610 | self.numParam_[p] = len(net[ph].params[p]) 611 | print (p, self.numParam_[p]) 612 | for i in range(self.numParam_[p]): 613 | self.paramVals[ph][i][p] = [] 614 | self.paramUpdate[ph][i][p] = [] 615 | return self 616 | 617 | #Read the logging data from file 618 | def read(self, fName=None, maxIter=None): 619 | ''' 620 | fName : File from which values need to be read 621 | maxIter: read until maxIter 622 | ''' 623 | self.isRead_ = True 624 | if fName is None and osp.exists(self.logFile_): 625 | fName = self.logFile_ 626 | else: 627 | print ('%s doesnot exist, please specify a log file name' % self.logFile_) 628 | return 629 | print ('Loading log from: %s' % fName) 630 | data = pickle.load(open(fName, 'r')) 631 | self.recIter_ = data['recIter'] 632 | for ph in self.phase_: 633 | if maxIter is not None: 634 
| idx = np.where(self.recIter_[ph] <= maxIter)[0][-1] 635 | idx = min(idx + 1, len(self.recIter_[ph])) 636 | else: 637 | idx = len(self.recIter_[ph]) 638 | self.recIter_[ph] = self.recIter_[ph][0:idx] 639 | self.blobNames_[ph] = data[ph]['blobs'].keys() 640 | for k, b in enumerate(data[ph]['blobs'].keys()): 641 | self.featVals[ph][b] = data[ph]['blobs'][b][0:idx] 642 | self.paramNames_[ph] = data[ph]['params'].keys() 643 | for k, p in enumerate(data[ph]['params'].keys()): 644 | #for i in range(self.numParam_[p]): 645 | for i in range(1): 646 | self.paramVals[ph][i][p] = data[ph]['params'][p][i][0:idx] 647 | self.paramUpdate[ph][i][p] = data[ph]['paramsUpdate'][p][i][0:idx] 648 | 649 | #Internal funciton for defining axes 650 | def _get_axes(self, titleNames, figTitle): 651 | numSub = np.ceil(np.sqrt(self.maxPerFigure_)) 652 | N = len(titleNames) 653 | allAx = [] 654 | count = 0 655 | for fn in range(int(np.ceil(float(N)/self.maxPerFigure_))): 656 | #Given a figure 657 | fig = plt.figure() 658 | fig.suptitle(figTitle) 659 | ax = [] 660 | en = min(N, count + self.maxPerFigure_) 661 | for i,tit in enumerate(titleNames[count:en]): 662 | ax.append((fig, fig.add_subplot(numSub, numSub, i+1))) 663 | ax[i][1].set_title(tit) 664 | count += self.maxPerFigure_ 665 | allAx = allAx + ax 666 | return allAx 667 | 668 | #Setup for plotting 669 | def setup_plots(self): 670 | plt.close('all') 671 | plt.ion() 672 | self.maxPerFigure_ = 16 673 | self.axBlobs_ = co.OrderedDict() 674 | self.axParamValW_ = co.OrderedDict() 675 | self.axParamDeltaW_ = co.OrderedDict() 676 | for ph in self.phase_: 677 | self.axBlobs_[ph] = self._get_axes(self.blobNames_[ph], 678 | '%s-Feature Values' % ph) 679 | self.axParamValW_[ph] = self._get_axes(self.paramNames_[ph], 680 | '%s-Parameter Values' % ph) 681 | self.axParamDeltaW_[ph] = self._get_axes(self.paramNames_[ph], 682 | '%s-Parameter Updates' % ph) 683 | self.plotSetup_ = True 684 | 685 | #Plot the log 686 | def plot(self): 687 | if not 
self.isRead_: 688 | self.read() 689 | if not self.plotSetup_: 690 | self.setup_plots() 691 | plt.ion() 692 | for ph in self.phase_: 693 | for i,bn in enumerate(self.blobNames_[ph]): 694 | fig, ax = self.axBlobs_[ph][i] 695 | plt.figure(fig.number) 696 | print (ph, bn, len(self.recIter_[ph]), len(self.featVals[ph][bn])) 697 | ax.plot(np.array(self.recIter_[ph]), self.featVals[ph][bn]) 698 | plt.draw() 699 | plt.show() 700 | for i,pn in enumerate(self.paramNames_[ph]): 701 | #The parameters 702 | fig, ax = self.axParamValW_[ph][i] 703 | plt.figure(fig.number) 704 | ax.plot(np.array(self.recIter_[ph]), self.paramVals[ph][0][pn]) 705 | #The delta in parameters 706 | fig, ax = self.axParamDeltaW_[ph][i] 707 | plt.figure(fig.number) 708 | ax.plot(np.array(self.recIter_[ph]), self.paramUpdate[ph][0][pn]) 709 | plt.draw() 710 | plt.show() 711 | 712 | #Only plot the losses 713 | def plot_loss(self, ax=None, layerNames=[], ylim=None): 714 | if not self.isRead_: 715 | self.read() 716 | plt.ion() 717 | if ax is None: 718 | fig = plt.figure() 719 | ax = fig.add_subplot(111) 720 | colors = ['r', 'b'] 721 | for p, ph in enumerate(self.phase_): 722 | for i, bn in enumerate(self.blobNames_[ph]): 723 | if len(layerNames) == 0 and 'loss' in bn: 724 | ax.plot(np.array(self.recIter_[ph][1:]), self.featVals[ph][bn][1:], colors[p]) 725 | ax.set_title(bn) 726 | elif len(layerNames) >0 and layerNames[0] in bn: 727 | ax.plot(np.array(self.recIter_[ph]), self.featVals[ph][bn], colors[p]) 728 | ax.set_title(bn) 729 | if ylim is not None: 730 | ax.set_ylim(ylim) 731 | plt.show() 732 | plt.draw() 733 | return ax 734 | 735 | class MySolver(object): 736 | def __init__(self): 737 | self.solver_ = None 738 | self.phase_ = ['train', 'test'] 739 | self.isLog_ = True 740 | 741 | def __del__(self): 742 | del self.solver_ 743 | del self.net_ 744 | 745 | @classmethod 746 | def from_file(cls, solFile, recFreq=20, dumpLogFreq=None, 747 | logFile='default_log.pkl', isLog=True): 748 | ''' 749 | solFile : 
solver prototxt from which to load the net 750 | recFreq : the frequency of recording 751 | dumpLogFreq: the frequency with which dumpLog should be noted 752 | ''' 753 | self = cls() 754 | self.solFile_ = solFile 755 | self.logFile_ = logFile 756 | self.recFreq_ = recFreq 757 | self.dumpLogFreq_= dumpLogFreq 758 | self.isLog_ = isLog 759 | self.setup_solver() 760 | self.plotSetup_ = False 761 | return self 762 | 763 | ## 764 | #setup the solver 765 | def setup_solver(self): 766 | self.solDef_ = mpu.SolverDef.from_file(self.solFile_) 767 | self.maxIter_ = int(self.solDef_.get_property('max_iter')) 768 | self.testInterval_ = int(self.solDef_.get_property('test_interval')) 769 | if self.solDef_.has_property('solver_mode'): 770 | solverMode = self.solDef_.get_property('solver_mode') 771 | else: 772 | solverMode = 'GPU' 773 | if solverMode == 'GPU': 774 | if self.solDef_.has_property('device_id'): 775 | device = int(self.solDef_.get_property('device_id')) 776 | else: 777 | device = 0 778 | print ('GPU Mode, setting device %d' % device) 779 | caffe.set_device(device) 780 | caffe.set_mode_gpu() 781 | else: 782 | print ('CPU Mode') 783 | caffe.set_mode_cpu() 784 | 785 | self.solver_ = caffe.SGDSolver(self.solFile_) 786 | self.net_ = co.OrderedDict() 787 | self.net_[self.phase_[0]] = self.solver_.net 788 | self.net_[self.phase_[1]] = self.solver_.test_nets[0] 789 | if len(self.solver_.test_nets) > 1: 790 | print (' ##### WARNING - THERE ARE MORE THAN ONE TEST-NETS, FEATURE VALS\ 791 | FOR TEST NETS > 1 WILL NOT BE RECORDED #################') 792 | ip = raw_input('ARE YOU SURE YOU WANT TO CONTINUE(y/n)?') 793 | if ip == 'n': 794 | raise Exception('Quitting') 795 | self.log_ = CaffeNetLogger.from_net(self.net_, self.logFile_, self.phase_) 796 | 797 | ## 798 | #Restore the solver from a previous state 799 | def restore(self, fName, restoreIter=None): 800 | ''' 801 | fName: the name of the file from which solver needs to be resumed. 
802 | ''' 803 | self.solver_.restore(fName) 804 | #if osp.exists(self.logFile_): 805 | #self.log_.read(self.logFile_, maxIter=restoreIter) 806 | self.log_.read(maxIter=restoreIter) 807 | 808 | ## 809 | #Copy weights from a net file 810 | def copy_weights(self, fName): 811 | ''' 812 | fName: the name of the file from which weights need to be copied. 813 | ''' 814 | self.solver_.copy_trained_layers_from_netfile(fName) 815 | 816 | ## 817 | # Solve 818 | def solve(self, numSteps=None): 819 | if numSteps is None: 820 | numSteps = self.maxIter_ 821 | for i in range(numSteps): 822 | if self.isLog_: 823 | if np.mod(self.solver_.iter, self.recFreq_)==0: 824 | self.record_feats_params(phases=['train']) 825 | self.log_.recIter_['train'].append(self.solver_.iter) 826 | if np.mod(self.solver_.iter, self.testInterval_)==0: 827 | self.record_feats_params(phases=['test']) 828 | self.log_.recIter_['test'].append(self.solver_.iter) 829 | self.solver_.step(1) 830 | if self.isLog_: 831 | if np.mod(self.solver_.iter, self.dumpLogFreq_)==0: 832 | self.dump_to_file() 833 | ## 834 | #Record the data 835 | def record_feats_params(self, phases=None): 836 | t1 = time.time() 837 | print ('RECORDING FEATURE STATS') 838 | if phases is None: 839 | phases = self.phase_ 840 | for ph in phases: 841 | for b in self.log_.blobNames_[ph]: 842 | self.log_.featVals[ph][b].append(np.mean(np.abs(self.net_[ph].blobs[b].data))) 843 | for p in self.log_.paramNames_[ph]: 844 | for i in range(self.log_.numParam_[p]): 845 | #print (ph, p, i) 846 | dat = np.mean(np.abs(self.net_[ph].params[p][i].data)) 847 | dif = np.mean(np.abs(self.net_[ph].params[p][i].diff)) 848 | self.log_.paramVals[ph][i][p].append(dat) 849 | self.log_.paramUpdate[ph][i][p].append(dif) 850 | t = time.time() - t1 851 | print ('$$$$$$$$$$$$ TIME TO RECORD %f' % t) 852 | 853 | ## 854 | #Dump the data to the file 855 | def dump_to_file(self): 856 | t1 = time.time() 857 | print ('SOLVER DUMPING TO FILE') 858 | data = co.OrderedDict() 859 | for 
ph in self.phase_: 860 | data[ph] = co.OrderedDict() 861 | data[ph]['blobs'] = co.OrderedDict() 862 | for b in self.log_.blobNames_[ph]: 863 | data[ph]['blobs'][b] = self.log_.featVals[ph][b] 864 | data[ph]['params'] = co.OrderedDict() 865 | data[ph]['paramsUpdate'] = co.OrderedDict() 866 | for p in self.log_.paramNames_[ph]: 867 | data[ph]['params'][p] = [] 868 | data[ph]['paramsUpdate'][p] = [] 869 | for i in range(self.log_.numParam_[p]): 870 | data[ph]['params'][p].append(self.log_.paramVals[ph][i][p]) 871 | data[ph]['paramsUpdate'][p].append(self.log_.paramVals[ph][i][p]) 872 | data['recFreq'] = self.recFreq_ 873 | data['recIter'] = self.log_.recIter_ 874 | data['numParam'] = self.log_.numParam_ 875 | pickle.dump(data, open(self.logFile_, 'w')) 876 | t = time.time() - t1 877 | print ('$$$$$$$$$$$$ TIME TO DUMP %f' % t) 878 | 879 | # Return pointer to layer 880 | def get_layer_pointer(self, layerName, ph='train'): 881 | assert layerName in self.log_.layerNames_[ph], 'layer not found' 882 | index = self.log_.layerNames_[ph].index(layerName) 883 | return self.net_[ph].layers[index] 884 | 885 | 886 | #Read from log file 887 | def read_log_from_file(self, **kwargs): 888 | self.log_.read(**kwargs) 889 | 890 | #Plot the log 891 | def plot(self): 892 | self.log_.plot() 893 | 894 | ## 895 | # Visualize filters 896 | def vis_square(data, padsize=1, padval=0, ax=None, titleName=None, returnData=False, 897 | chSt=0, chEn=3): 898 | ''' 899 | data is numFitlers * height * width or numFilters * height * width * channels 900 | ''' 901 | if data.ndim == 4: 902 | data = data[:,:,:,chSt:chEn] 903 | 904 | data -= data.min() 905 | data /= data.max() 906 | 907 | # force the number of filters to be square 908 | n = int(np.ceil(np.sqrt(data.shape[0]))) 909 | padding = ((0, n ** 2 - data.shape[0]), (0, padsize), (0, padsize)) + ((0, 0),) * (data.ndim - 3) 910 | data = np.pad(data, padding, mode='constant', constant_values=(padval, padval)) 911 | 912 | # tile the filters into an image 
913 | data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1))) 914 | data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:]) 915 | 916 | if titleName is None: 917 | titleName = '' 918 | 919 | data = data.squeeze() 920 | if ax is not None: 921 | ax.imshow(data, interpolation='none') 922 | ax.set_title(titleName) 923 | else: 924 | plt.imshow(data, interpolation='none') 925 | plt.title(titleName) 926 | 927 | if returnData: 928 | return data 929 | 930 | 931 | #Make rectangular filters 932 | def vis_rect(data, h, w, padsize=1, padval=0, ax=None, titleName=None, returnData=False): 933 | ''' 934 | data is numFitlers * height * width or numFilters * height * width * channels 935 | ''' 936 | data -= data.min() 937 | data /= data.max() 938 | 939 | padding = ((0, h * w - data.shape[0]), (0, padsize), (0, padsize)) + ((0, 0),) * (data.ndim - 3) 940 | data = np.pad(data, padding, mode='constant', constant_values=(padval, padval)) 941 | 942 | # tile the filters into an image 943 | data = data.reshape((h, w) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1))) 944 | data = data.reshape((h * data.shape[1], w * data.shape[3]) + data.shape[4:]) 945 | 946 | if titleName is None: 947 | titleName = '' 948 | data = data.squeeze() 949 | if ax is not None: 950 | ax.imshow(data, interpolation='none') 951 | ax.set_title(titleName) 952 | else: 953 | plt.imshow(data) 954 | plt.title(titleName, interpolation='none') 955 | 956 | if returnData: 957 | return data 958 | 959 | 960 | 961 | def setup_prototypical_network(netName='vgg', layerName='pool4'): 962 | ''' 963 | Sets up a network in a configuration in which I commonly use it. 
964 | ''' 965 | modelFile, meanFile = get_model_mean_file(netName) 966 | defFile = get_layer_def_files(netName, layerName=layerName) 967 | meanDat = mpio.read_mean(meanFile) 968 | net = MyNet(defFile, modelFile) 969 | net.set_preprocess(ipName='data', meanDat=meanDat, imageDims=(256,256,3)) 970 | return net 971 | 972 | 973 | ''' 974 | def get_features(net, im, layerName=None, ipLayerName='data'): 975 | dataBlob = net.blobs['data'] 976 | batchSz = dataBlob.num 977 | assert im.ndim == 4 978 | N,nc,h,w = im.shape 979 | assert h == dataBlob.height and w==dataBlob.width 980 | 981 | if not layerName==None: 982 | assert layerName in net.blobs.keys() 983 | layerName = [layerName] 984 | outName = layerName[0] 985 | else: 986 | outName = net.outputs[0] 987 | layerName = [] 988 | 989 | print layerName 990 | imBatch = np.zeros((batchSz,nc,h,w)) 991 | outFeats = {} 992 | outBlob = net.blobs[outName] 993 | outFeats = np.zeros((N, outBlob.channels, outBlob.height, outBlob.width)) 994 | 995 | for i in range(0,N,batchSz): 996 | st = i 997 | en = min(N, st + batchSz) 998 | l = en - st 999 | imBatch[0:l,:,:,:] = np.copy(im[st:en]) 1000 | dataLayer = {ipLayerName:imBatch} 1001 | feats = net.forward(blobs=layerName, start=None, end=None, **dataLayer) 1002 | outFeats[st:en] = feats[outName][0:l] 1003 | 1004 | return outFeats 1005 | 1006 | ''' 1007 | 1008 | 1009 | def compute_error(gtLabels, prLabels, errType='classify'): 1010 | N, lblSz = gtLabels.shape 1011 | res = [] 1012 | assert prLabels.shape[0] == N and prLabels.shape[1] == lblSz 1013 | if errType == 'classify': 1014 | assert lblSz == 1 1015 | cls = np.unique(gtLabels) 1016 | cls = np.sort(cls) 1017 | nCl = cls.shape[0] 1018 | confMat = np.zeros((nCl, nCl)) 1019 | for i in range(nCl): 1020 | for j in range(nCl): 1021 | confMat[i,j] = float(np.sum(np.bitwise_and((gtLabels == cls[i]),(prLabels == cls[j]))))/(np.sum(gtLabels == cls[i])) 1022 | res = confMat 1023 | else: 1024 | print "Error type not recognized" 1025 | raise 1026 | 
return res 1027 | 1028 | 1029 | def feats_2_labels(feats, lblType, maskLastLabel=False): 1030 | #feats are assumed to be numEx * featDims 1031 | labels = [] 1032 | if lblType in ['uniform20', 'kmedoids30_20']: 1033 | r,c = feats.shape 1034 | if maskLastLabel: 1035 | feats = feats[0:r,0:c-1] 1036 | labels = np.argmax(feats, axis=1) 1037 | labels = labels.reshape((r,1)) 1038 | else: 1039 | print "UNrecognized lblType" 1040 | raise 1041 | return labels 1042 | 1043 | 1044 | def save_images(ims, gtLb, pdLb, svFileStr, stCount=0, isSiamese=False): 1045 | ''' 1046 | Saves the images 1047 | ims: N * nCh * H * W 1048 | gtLb: Ground Truth Label 1049 | pdLb: Predicted Label 1050 | svFileStr: Path should contain (%s, %d) - which will be filled in by correct/incorrect and count 1051 | ''' 1052 | N = ims.shape[0] 1053 | ims = ims.transpose((0,2,3,1)) 1054 | fig = plt.figure() 1055 | for i in range(N): 1056 | im = ims[i] 1057 | plt.title('Gt-Label: %d, Predicted-Label: %d' %(gtLb[i], pdLb[i])) 1058 | gl, pl = gtLb[i], pdLb[i] 1059 | if gl==pl: 1060 | fStr = 'correct' 1061 | else: 1062 | fStr = 'mistake' 1063 | if isSiamese: 1064 | im1 = im[:,:,0:3] 1065 | im2 = im[:,:,3:] 1066 | im1 = im1[:,:,[2,1,0]] 1067 | im2 = im2[:,:,[2,1,0]] 1068 | plt.subplot(1,2,1) 1069 | plt.imshow(im1) 1070 | plt.subplot(1,2,2) 1071 | plt.imshow(im2) 1072 | fName = svFileStr % (fStr, i + stCount) 1073 | if not os.path.exists(os.path.dirname(fName)): 1074 | os.makedirs(os.path.dirname(fName)) 1075 | print fName 1076 | plt.savefig(fName) 1077 | 1078 | ''' 1079 | def test_network_siamese_h5(imH5File=[], lbH5File=[], netFile=[], defFile=[], imSz=128, cropSz=112, nCh=3, outLblSz=1, meanFile=[], ipLayerName='data', lblType='uniform20',outFeatSz=20, maskLastLabel=False, db=None, svImg=False, svImFileStr=None, deviceId=None): 1080 | #defFile: Architecture prototxt 1081 | #netFile : The model weights 1082 | #maskLastLabel: In some cases it is we may need to compute the error bt ignoring the last label 1083 | # 
for example in det - where the last class might be the backgroud class 1084 | #db: instead of h5File, provide a dbReader 1085 | 1086 | isBlobFormat = True 1087 | if db is None: 1088 | isBlobFormat = False 1089 | print imH5File, lbH5File 1090 | imFid = h5py.File(imH5File,'r') 1091 | lbFid = h5py.File(lbH5File,'r') 1092 | ims1 = imFid['images1/'] 1093 | ims2 = imFid['images2/'] 1094 | lbls = lbFid['labels/'] 1095 | 1096 | #Get Sizes 1097 | imSzSq = imSz * imSz 1098 | assert(ims1.shape[0] % imSzSq == 0 and ims2.shape[0] % imSzSq ==0) 1099 | N = ims1.shape[0]/(imSzSq * nCh) 1100 | assert(lbls.shape[0] % N == 0) 1101 | lblSz = outLblSz 1102 | 1103 | #Get the mean 1104 | imMean = [] 1105 | if not meanFile == []: 1106 | imMean = mpio.read_mean(meanFile) 1107 | 1108 | #Initialize network 1109 | net = MyNet(defFile, netFile, deviceId=deviceId) 1110 | net.set_preprocess(chSwap=None, meanDat=imMean,imageDims=(imSz, imSz, 2*nCh), isBlobFormat=isBlobFormat, ipName='data') 1111 | 1112 | #Initialize variables 1113 | batchSz = net.get_batchsz() 1114 | ims = np.zeros((batchSz, 2 * nCh, imSz, imSz)) 1115 | count = 0 1116 | imCount = 0 1117 | 1118 | if db is None: 1119 | labels = np.zeros((N, lblSz)) 1120 | gtLabels = np.zeros((N, lblSz)) 1121 | #Loop through the images 1122 | for i in np.arange(0,N,batchSz): 1123 | st = i * nCh * imSzSq 1124 | en = min(N, i + batchSz) * nCh * imSzSq 1125 | numIm = min(N, i + batchSz) - i 1126 | ims[0:batchSz] = 0 1127 | ims[0:numIm,0:nCh,:,:] = ims1[st:en].reshape((numIm,nCh,imSz,imSz)) 1128 | ims[0:numIm,nCh:2*nCh,:,:] = ims2[st:en].reshape((numIm,nCh,imSz,imSz)) 1129 | imsPrep = prepare_image(ims, cropSz, imMean) 1130 | predFeat = get_features(net, imsPrep, ipLayerName=ipLayerName) 1131 | predFeat = predFeat[0:numIm] 1132 | print numIm 1133 | try: 1134 | labels[i : i + numIm, :] = feats_2_labels(predFeat.reshape((numIm,outFeatSz)), lblType, maskLastLabel=maskLastLabel)[0:numIm] 1135 | gtLabels[i : i + numIm, : ] = (lbls[i * lblSz : (i+numIm) * 
lblSz]).reshape(numIm, lblSz) 1136 | except ValueError: 1137 | print "Value Error found" 1138 | pdb.set_trace() 1139 | else: 1140 | labels, gtLabels = [], [] 1141 | runFlag = True 1142 | while runFlag: 1143 | count = count + 1 1144 | print "Processing Batch: ", count 1145 | dat, lbl = db.read_batch(batchSz) 1146 | N = dat.shape[0] 1147 | print N 1148 | if N < batchSz: 1149 | runFlag = False 1150 | batchDat = net.preprocess_batch(dat, ipName='data') 1151 | dataLayer = {} 1152 | dataLayer[ipLayerName] = batchDat 1153 | feats = net.net.forward(**dataLayer) 1154 | feats = feats[feats.keys()[0]][0:N] 1155 | gtLabels.append(lbl) 1156 | predLabels = feats_2_labels(feats.reshape((N,outFeatSz)), lblType) 1157 | labels.append(predLabels) 1158 | if svImg: 1159 | save_images(dat, lbl, predLabels, svImFileStr, stCount=imCount, isSiamese=True) 1160 | imCount = imCount + N 1161 | labels = np.concatenate(labels) 1162 | gtLabels = np.concatenate(gtLabels) 1163 | 1164 | confMat = compute_error(gtLabels, labels, 'classify') 1165 | return confMat, labels, gtLabels 1166 | 1167 | ''' 1168 | 1169 | def read_mean_txt(fileName): 1170 | with open(fileName,'r') as f: 1171 | l = f.readlines() 1172 | mn = [float(i) for i in l] 1173 | mn = np.array(mn) 1174 | return mn 1175 | -------------------------------------------------------------------------------- /my_pycaffe_io.py: -------------------------------------------------------------------------------- 1 | ## @package my_pycaffe_io 2 | # IO operations. 
#

try:
    import h5py as h5
except Exception:
    #h5py is optional; only some functions need it
    print ('WARNING: h5py not found, some functions may not work')
import numpy as np
import my_pycaffe as mp
import caffe
import pdb
import os
import lmdb
import shutil
import scipy.misc as scm
import scipy.io as sio
import copy
from pycaffe_config import cfg
from os import path as osp
import other_utils as ou

if not cfg.IS_EC2:
    #import matlab.engine as men
    MATLAB_PATH = '/work4/pulkitag-code/pkgs/caffe-v2-2/matlab/caffe'
else:
    MATLAB_PATH = ''


def read_mean(protoFileName):
    '''
    Reads mean from the protoFile: parses a serialized caffe BlobProto and
    returns it as a squeezed numpy array.
    '''
    #Serialized protobufs are bytes: read in binary mode
    with open(protoFileName, 'rb') as fid:
        ss = fid.read()
    vec = caffe.io.caffe_pb2.BlobProto()
    vec.ParseFromString(ss)
    mn = caffe.io.blobproto_to_array(vec)
    mn = np.squeeze(mn)
    return mn


##
# Write array as a proto
def write_proto(arr, outFile):
    '''
    Writes the array as a protofile (a serialized caffe BlobProto).
    '''
    blobProto = caffe.io.array_to_blobproto(arr)
    ss = blobProto.SerializeToString()
    #Binary mode, and the file is closed even if the write fails
    with open(outFile, 'wb') as fid:
        fid.write(ss)


##
# Convert the mean to be useful for siamese network.
def mean2siamese_mean(inFile, outFile, isGray=False):
    '''
    Duplicate a mean image along the channel axis (siamese inputs stack two
    images) and write it as a 1 * 2ch * H * W proto.

    inFile : mean proto to read
    outFile: proto file to write
    isGray : treat the mean as a single-channel H * W array
    '''
    mn = mp.read_mean(inFile)
    if isGray:
        mn = mn.reshape((1, mn.shape[0], mn.shape[1]))
    mn = np.concatenate((mn, mn))
    dType = mn.dtype
    mn = mn.reshape((1, mn.shape[0], mn.shape[1], mn.shape[2]))
    print ('New mean shape: %s %s' % (mn.shape, dType))
    write_proto(mn, outFile)

##
# Convert the siamese mean to be the mean
def siamese_mean2mean(inFile, outFile):
    '''
    Keep the first half of the channels of a siamese mean and write it as a
    1 * ch * H * W proto. Refuses to overwrite an existing outFile.
    '''
    assert not os.path.exists(outFile), '%s already exists' % outFile
    mn = mp.read_mean(inFile)
    ch = mn.shape[0]
    assert np.mod(ch, 2) == 0
    #Integer division: "/" yields a float (invalid slice index) on Python 3
    ch = ch // 2
    print ('New number of channels: %d' % ch)
    newMn = mn[0:ch].reshape(1, ch, mn.shape[1], mn.shape[2])
    write_proto(newMn.astype(mn.dtype), outFile)

##
# Convert to grayscale, mimics the matlab function
def rgb2gray(rgb):
    '''Weighted sum of the first three channels of the last axis.'''
    return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])

##
# Convert the mean grayscale mean
def mean2graymean(inFile, outFile):
    '''
    Convert a 3 * H * W mean into a 1 * 1 * H * W grayscale mean proto.
    Refuses to overwrite an existing outFile.
    '''
    assert not os.path.exists(outFile), '%s already exists' % outFile
    mn = mp.read_mean(inFile)
    dType = mn.dtype
    ch = mn.shape[0]
    assert ch == 3
    mn = rgb2gray(mn.transpose((1, 2, 0))).reshape((1, 1, mn.shape[1], mn.shape[2]))
    print ('New mean shape: %s %s' % (mn.shape, dType))
    write_proto(mn.astype(dType), outFile)


##
# Resize the mean to a different size
def resize_mean(inFile, outFile, imSz):
    '''Resize a ch * H * W mean to ch * imSz * imSz and write it as a proto.'''
    mn = mp.read_mean(inFile)
    dType = mn.dtype
    ch, rows, cols = mn.shape
    mn = mn.transpose((1, 2, 0))
    mn = scm.imresize(mn, (imSz, imSz)).transpose((2, 0, 1)).reshape((1, ch, imSz, imSz))
    write_proto(mn.astype(dType), outFile)

# NOTE: disabled legacy implementation kept for reference (string, not executed)
'''
def ims2hdf5(im, labels, batchSz, batchPath, isColor=True, batchStNum=1, isUInt8=True, scale=None, newLabels=False):
    #Converts an image dataset into hdf5
    h5SrcFile = os.path.join(batchPath, 'h5source.txt')
    strFid = open(h5SrcFile, 'w')

    dType = im.dtype
    if isUInt8:
        assert im.dtype==np.uint8, 'Images should be in uint8'
        h5DType = 'u1'
    else:
        assert im.dtype==np.float32, 'Images can either be uint8 or float32'
        h5DType = 'f'

    if scale is not None:
        im = im * scale

    if isColor:
        assert im.ndim ==4
        N,ch,h,w = im.shape
        assert ch==3, 'Color images must have 3 channels'
    else:
        assert im.ndim ==3
        N,h,w = im.shape
        im = np.reshape(im,(N,1,h,w))
        ch = 1

    count = batchStNum
    for i in range(0,N,batchSz):
        st = i
        en = min(N, st + batchSz)
        if st + batchSz > N:
            break
        h5File = os.path.join(batchPath, 'batch%d.h5' % count)
        h5Fid = h5.File(h5File, 'w')
        imBatch = np.zeros((N, ch, h, w), dType)
        imH5 = h5Fid.create_dataset('/data',(batchSz, ch, h, w), dtype=h5DType)
        imH5[0:batchSz] = im[st:en]
        if newLabels:
            lbH5 = h5Fid.create_dataset('/label', (batchSz,), dtype='f')
            lbH5[0:batchSz] = labels[st:en].reshape((batchSz,))
        else:
            lbH5 = h5Fid.create_dataset('/label', (batchSz,1,1,1), dtype='f')
            lbH5[0:batchSz] = labels[st:en].reshape((batchSz,1,1,1))
        h5Fid.close()
        strFid.write('%s \n' % h5File)
        count += 1
    strFid.close()
'''

class DbSaver:
    '''
    Store (image, label) pairs as caffe Datum records in an LMDB.
    An existing database at dbName is deleted first.
    '''
    def __init__(self, dbName, isLMDB=True):
        #isLMDB is accepted for API compatibility; only LMDB is implemented
        if os.path.exists(dbName):
            print ('%s already existed, but not anymore ..removing..' % dbName)
            shutil.rmtree(dbName)
        self.db = lmdb.open(dbName, map_size=int(1e12))
        self.count = 0

    def __del__(self):
        self.db.close()

    def add_batch(self, ims, labels=None, imAsFloat=False, svIdx=None):
        '''
        Assumes ims are numEx * ch * h * w
        labels   : integer array, one label per image (all zeros when None)
        imAsFloat: when False, images are cast to uint8 before storing
        svIdx    : Allows one to store the images randomly - explicit integer
                   keys; otherwise consecutive keys starting at self.count.
        Keys are zero-padded 10-digit decimal strings.
        '''
        if labels is not None:
            #np.int / np.long no longer exist in NumPy; accept any integer dtype
            assert np.issubdtype(labels.dtype, np.integer), 'labels must be integers'
        else:
            labels = np.zeros((ims.shape[0],), dtype=int)

        if svIdx is not None:
            keys = svIdx
        else:
            keys = range(self.count, self.count + ims.shape[0])

        self.txn = self.db.begin(write=True)
        try:
            for idx, im, lb in zip(keys, ims, labels):
                if not imAsFloat:
                    im = im.astype(np.uint8)
                imDat = caffe.io.array_to_datum(im, label=int(lb))
                #LMDB keys are bytes (a no-op for ASCII str on Python 2)
                self.txn.put('{:0>10d}'.format(idx).encode('ascii'),
                             imDat.SerializeToString())
        except Exception:
            #Do not leave a half-written write transaction open
            self.txn.abort()
            raise
        self.txn.commit()
        self.count = self.count + ims.shape[0]

    def close(self):
        self.db.close()



class DoubleDbSaver:
    '''
    Useful for example when storing images and labels in two different dbs
    '''
    def __init__(self, dbName1, dbName2, isLMDB=True):
        self.dbs_ = []
        self.dbs_.append(DbSaver(dbName1, isLMDB=isLMDB))
        self.dbs_.append(DbSaver(dbName2, isLMDB=isLMDB))

    def __del__(self):
        for db in self.dbs_:
            db.__del__()

    def close(self):
        for db in self.dbs_:
            db.close()

    def add_batch(self, ims, labels=(None, None), imAsFloat=(False, False), svIdx=(None, None)):
        '''Each argument is a pair; element i is forwarded to the i-th db.'''
        for i, db in enumerate(self.dbs_):
            db.add_batch(ims[i], labels[i], imAsFloat=imAsFloat[i], svIdx=svIdx[i])


class DbReader:
    '''
    Sequential reader of caffe Datum records stored in an LMDB.
    nextValid_ is True while the cursor points at a record not yet returned
    by read_next.
    '''
    def __init__(self, dbName, isLMDB=True, readahead=True, wrapAround=False):
        '''
        wrapAround: False - return None, None if end of file is reached
                    True  - move to the first element
        '''
        #For large LMDB set readahead to be False
        self.db_ = lmdb.open(dbName, readonly=True, readahead=readahead)
        self.txn_ = self.db_.begin(write=False)
        self.cursor_ = self.txn_.cursor()
        self.nextValid_ = True
        self.wrap_ = wrapAround
        self.cursor_.first()

    def __del__(self):
        #self.txn_.commit()
        self.db_.close()

    #Maintain the appropriate variables
    def _maintain(self):
        '''When wrapping, rewind to the first record once the end is reached.'''
        if self.wrap_:
            if not self.nextValid_:
                print ('Going to first element of lmdb')
                self.cursor_.first()
                self.nextValid_ = True

    #Get the current key
    def get_key(self):
        '''Key of the record read_next will return next, or None at the end.'''
        #The condition used to be inverted (key only returned when invalid)
        if self.nextValid_:
            return self.cursor_.key()
        return None

    #Get all keys
    def get_key_all(self):
        '''Return all keys in order and rewind the reader to the first record.'''
        keys = []
        #first() is False for an empty db; previously '' was appended then
        isNext = self.cursor_.first()
        while isNext:
            keys.append(self.cursor_.key())
            isNext = self.cursor_.next()
        #The cursor is rewound, so reading restarts from the first record
        self.nextValid_ = self.cursor_.first()
        return keys

    def read_key(self, key):
        '''Return (data, label) of the record stored under key.'''
        dat = self.cursor_.get(key)
        datum = caffe.io.caffe_pb2.Datum()
        datStr = datum.FromString(dat)
        data = caffe.io.datum_to_array(datStr)
        label = datStr.label
        return data, label


    def read_next(self):
        '''Return (data, label) of the next record, or (None, None) at the end.'''
        if not self.nextValid_:
            return None, None
        else:
            key, dat = self.cursor_.item()
            datum = caffe.io.caffe_pb2.Datum()
            datStr = datum.FromString(dat)
            data = caffe.io.datum_to_array(datStr)
            label = datStr.label
            self.nextValid_ = self.cursor_.next()
            self._maintain()
            return data, label

    #Read a batch of elements
    def read_batch(self, batchSz):
        '''
        Read up to batchSz records. Returns (N * ch * h * w data, N * 1
        labels), or (None, None) if no record could be read.
        '''
        data, label = [], []
        for _ in range(batchSz):
            dat, lb = self.read_next()
            if dat is None:
                break
            ch, h, w = dat.shape
            data.append(np.reshape(dat, (1, ch, h, w)))
            label.append(lb)
        if len(data) == 0:
            return None, None
        data = np.concatenate(data)
        label = np.array(label).reshape((len(label), 1))
        return data, label

    def get_label_stats(self, maxLabels):
        '''Histogram of labels (0..maxLabels-1) over the remaining records.'''
        countArr = np.zeros((maxLabels,))
        #With wrapping enabled read_next never signals the end
        wrap = self.wrap_
        self.wrap_ = False
        try:
            while True:
                _, lb = self.read_next()
                if lb is None:
                    break
                countArr[lb] += 1
        finally:
            self.wrap_ = wrap
        return countArr

    #Get number of elements
    def get_count(self):
        return int(self.db_.stat()['entries'])

    #Skip one element
    def skip(self):
        isNext = self.cursor_.next()
        if not isNext:
            self.cursor_.first()
        self._maintain()

    #Skip in reverse
    def skip_reverse(self):
        isPrev = self.cursor_.prev()
        #Prev skip will not be possible if we are the first element
        if not isPrev:
            self.cursor_.last()
        self._maintain()

    #Compute the mean of the data
    def compute_mean(self):
        '''Mean datum array over all records (ValueError for an empty db).'''
        self.nextValid_ = self.cursor_.first()
        wrap = self.wrap_
        #Disable wrapping so the loop below terminates
        self.wrap_ = False
        im, _ = self.read_next()
        if im is None:
            self.wrap_ = wrap
            raise ValueError('Cannot compute the mean of an empty database')
        mu = np.zeros(im.shape)
        mu[...] = im[...]
        N = 1
        while True:
            im, _ = self.read_next()
            if im is None:
                break
            mu += im
            N += 1
            if np.mod(N, 1000) == 1:
                print ('Processed %d images' % N)
        mu = mu / float(N)
        self.wrap_ = wrap
        #A wrapping reader goes back to the first record
        self._maintain()
        return mu

    #close
    def close(self):
        self.txn_.commit()
        self.db_.close()


class SiameseDbReader(DbReader):
    def get_next_pair(self, flipColor=True):
        '''
        Split the next record's channels into two H * W * (ch/2) images.
        flipColor: reverse the channel order ([2,1,0]) of both images.
        '''
        imDat, label = self.read_next()
        ch, h, w = imDat.shape
        assert np.mod(ch, 2) == 0
        ch = ch // 2
        imDat = np.transpose(imDat, (1, 2, 0))
        im1 = imDat[:, :, 0:ch]
        im2 = imDat[:, :, ch:2 * ch]
        if flipColor:
            im1 = im1[:, :, [2, 1, 0]]
            im2 = im2[:, :, [2, 1, 0]]
        return im1, im2, label

##
# Read two LMDBs simultaneosuly
class DoubleDbReader(object):
    def __init__(self, dbNames, isLMDB=True, readahead=True,
                 wrapAround=False, isMulti=False):
        '''
        wrapAround: False - return None, None if end of file is reached
                    True  - move to the first element
        isMulti   : False - read only two dbs (flag for backward compatibility)
                    True  - read from arbitrary number of dbs
        '''
        #For large LMDB set readahead to be False
        self.dbs_ = []
        self.isMulti_ = isMulti
        for d in dbNames:
            self.dbs_.append(DbReader(d, isLMDB=isLMDB, readahead=readahead,
                                      wrapAround=wrapAround))

    def __del__(self):
        for db in self.dbs_:
            db.__del__()

    #Check that all the DBs have the exact same set of keys
    def check_key_consistency(self, softOnLength=False):
        '''
        checks that all the LMDBs have the same set of keys
        softOnLength: True - if the LMDBs have different lengths, compare
                      only the first min-length keys.
        Returns (isConsistent, numKeysCompared); (False, None) when lengths
        differ and softOnLength is False.
        '''
        keyList = [db.get_key_all() for db in self.dbs_]
        if len(keyList) == 0:
            return True, 0
        numKeys = [len(keys) for keys in keyList]
        #Previously a key list was compared with an int, so this was always False
        sameLength = all(n == numKeys[0] for n in numKeys)
        if not sameLength and not softOnLength:
            return False, None
        nK = min(numKeys)
        for keys in keyList[1:]:
            if keys[:nK] != keyList[0][:nK]:
                return False, nK
        return True, nK

    def read_key(self, keys):
        '''Read keys[i] from the i-th db; returns the list of data.'''
        data = []
        for db, key in zip(self.dbs_, keys):
            dat, _ = db.read_key(key)
            data.append(dat)
        return data

    #Read a common key from all the dbs
    def read_common_key(self, key):
        data = []
        for db in self.dbs_:
            dat, _ = db.read_key(key)
            data.append(dat)
        return data

    def read_next(self):
        data = []
        for db in self.dbs_:
            dat, _ = db.read_next()
            data.append(dat)
        if self.isMulti_:
            return data
        else:
            return data[0], data[1]

    def read_batch(self, batchSz):
        '''
        Read a batch of data from every db: a list when isMulti_, otherwise a
        pair (consistent with read_next / read_batch_data_label).
        '''
        data = []
        for db in self.dbs_:
            dat, _ = db.read_batch(batchSz)
            data.append(dat)
        if self.isMulti_:
            return data
        return data[0], data[1]

    def read_batch_data_label(self, batchSz):
        data, label = [], []
        for db in self.dbs_:
            dat, lb = db.read_batch(batchSz)
            data.append(dat)
            label.append(lb)
        if self.isMulti_:
            return data, label
        else:
            return data[0], data[1], label[0], label[1]

    def close(self):
        for db in self.dbs_:
            db.close()

##
# Read multiple LMDBs simultaneosuly
class MultiDbReader(DoubleDbReader):
    def __init__(self, dbNames, isLMDB=True, readahead=True,
                 wrapAround=False):
        DoubleDbReader.__init__(self, dbNames, isLMDB=isLMDB,
                                readahead=readahead, wrapAround=wrapAround, isMulti=True)

##
# For reading generic window reader.
class GenericWindowReader:
    '''
    Sequential reader for "GenericDataLayer" window files (as written by
    GenericWindowWriter). Layout parsed here:
        # GenericDataLayer
        <number of examples>
        <number of images per example>
        <label size>
    then, per example: "# <index>", one line per image, one label line.
    '''
    def __init__(self, fileName):
        self.fid_ = open(fileName, 'r')
        line = self.fid_.readline()
        assert(line.split()[1] == 'GenericDataLayer')
        self.num_ = int(self.fid_.readline())
        self.numIm_ = int(self.fid_.readline())
        self.lblSz_ = int(self.fid_.readline())
        self.count_ = 0
        self.open_ = True

    def read_next(self):
        '''
        Return (list of raw image lines incl. '\\n', 1 * lblSz label array),
        or (None, None) once all examples have been read.
        '''
        if self.count_ == self.num_:
            print ('All lines already read')
            return None, None
        count = int(self.fid_.readline().split()[1])
        assert count == self.count_
        self.count_ += 1
        imDat = [self.fid_.readline() for _ in range(self.numIm_)]
        lbls = self.fid_.readline().split()
        lbls = np.array([float(l) for l in lbls]).reshape(1, self.lblSz_)
        return imDat, lbls

    #Get the processed images and labels
    def read_next_processed(self, rootFolder, returnName=False):
        '''
        Read the next example and return the cropped images and its labels.

        rootFolder: folder the image names are relative to
        returnName: also return the full image paths and suggested output
                    names "<name>-x1-y1-x2-y2<ext>"
        Crops are im[y1:y2, x1:x2] (grayscale and colour images both work).
        Returns None for every output once the file is exhausted.
        '''
        imDat, lbls = self.read_next()
        if imDat is None:
            if returnName:
                return None, None, None, None
            return None, None
        ims = []
        imNames, outNames = [], []
        for l in imDat:
            imName, ch, h, w, x1, y1, x2, y2 = l.strip().split()
            imName = osp.join(rootFolder, imName)
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            im = scm.imread(imName)
            #NOTE(review): the writer clamps x2 to w-1 (inclusive coords?) but
            #this slice excludes x2/y2 - confirm the intended convention.
            ims.append(im[y1:y2, x1:x2])
            imNames.append(imName)
            #Generate an outprefix that maybe used to save the images
            _, fName = osp.split(imName)
            ext = fName[-4:]
            outNames.append(fName[:-4] + '-%d-%d-%d-%d%s' % (x1, y1, x2, y2, ext))
        if returnName:
            return ims, lbls[0], imNames, outNames
        else:
            return ims, lbls[0]

    def get_all_labels(self):
        '''Read the remaining examples and stack their labels (N * lblSz).'''
        lbls = []
        while True:
            _, lbl = self.read_next()
            if lbl is None:
                break
            lbls.append(lbl)
        return np.concatenate(lbls)

    def is_open(self):
        return self.open_

    def is_eof(self):
        return self.count_ >= self.num_

    def close(self):
        self.fid_.close()
        self.open_ = False

    #Save image crops
    def save_crops(self, rootFolder, tgtDir, numIm=None):
        '''
        rootFolder: the root folder for the window file
        tgtDir    : the directory where the images should be saved
        numIm     : stop after this many examples (all remaining when None)
        '''
        count = 0
        ou.mkdir(tgtDir)
        #Checking eof first also handles files with no (remaining) examples
        while not self.is_eof():
            ims, _, _, oNames = self.read_next_processed(rootFolder,
                                                         returnName=True)
            for im, oName in zip(ims, oNames):
                scm.imsave(osp.join(tgtDir, oName), im)
            count += 1
            if numIm is not None and count >= numIm:
                break



##
# For writing generic window file layers.
class GenericWindowWriter:
    def __init__(self, fileName, numEx, numImgPerEx, lblSz):
        '''
        fileName   : the file to write to.
        numEx      : the number of examples
        numImgPerEx: the number of images per example
        lblSz      : the size of the labels
        The file is only created on the first write (see init_write).
        '''
        self.file_ = fileName
        self.num_ = numEx
        self.numIm_ = numImgPerEx
        self.lblSz_ = lblSz
        self.count_ = 0 #The number of examples written.

        dirName = os.path.dirname(fileName)
        if len(dirName) > 0 and not os.path.exists(dirName):
            os.makedirs(dirName)
        self.initWrite_ = False

        #If image and labels are being stacked
        self.imStack_ = []
        self.lbStack_ = []

    #Start writing
    def init_write(self):
        '''Open the file and write the header (idempotent).'''
        if self.initWrite_:
            return
        self.fid_ = open(self.file_, 'w')
        self.fid_.write('# GenericDataLayer\n')
        self.fid_.write('%d\n' % self.num_) #Num Examples.
        self.fid_.write('%d\n' % self.numIm_) #Num Images per Example.
        self.fid_.write('%d\n' % self.lblSz_) #Num Labels
        self.initWrite_ = True

    ##
    # Private Helper function for writing the images for the WindowFile
    def write_image_line_(self, imgName, imgSz, bbox):
        '''
        imgSz: channels * height * width
        bbox : x1, y1, x2, y2 - clamped to [0, w-1] x [0, h-1]
        '''
        ch, h, w = imgSz
        x1, y1, x2, y2 = bbox
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(x2, w - 1)
        y2 = min(y2, h - 1)
        self.fid_.write('%s %d %d %d %d %d %d %d\n' % (imgName,
                        ch, h, w, x1, y1, x2, y2))

    ##
    def write(self, lbl, *args):
        '''
        Write one example: "# <index>", one line per image, one label line.

        lbl : labels, written with '%f'; entries beyond lblSz are ignored
        args: exactly numIm_ entries, each either a ready-made image line
              (str ending in '\\n', e.g. read from another window file) or
              an (imgName, imgSz, bbox) tuple.
        The file is closed automatically after num_ examples.
        '''
        assert len(args) == self.numIm_,\
            'Wrong input arguments: (%d v/s %d)' % (len(args), self.numIm_)
        #Make sure the writing has been intialized
        if not self.initWrite_:
            self.init_write()
        #Start writing the current stuff
        self.fid_.write('# %d\n' % self.count_)
        #Write the images
        for arg in args:
            if isinstance(arg, str):
                self.fid_.write(arg)
            else:
                imName, imSz, bbox = arg
                self.write_image_line_(imName, imSz, bbox)

        #Write the label (zip keeps at most lblSz_ values)
        lbVals = ['%f' % lb for lb, _ in zip(lbl, range(self.lblSz_))]
        self.fid_.write(' '.join(lbVals) + '\n')
        self.count_ += 1

        if self.count_ == self.num_:
            self.close()

    ##
    #Instead of writing, just stack
    def push_to_stack(self, lbl, *args):
        assert len(args) == self.numIm_,\
            'Wrong input arguments: (%d v/s %d)' % (len(args), self.numIm_)
        self.imStack_.append(args)
        self.lbStack_.append(lbl)

    ##
    #Write the stack
    def write_stack(self, rndState=None, rndSeed=None):
        '''
        Write every stacked example, optionally in a random order given by
        rndState or by a RandomState seeded with rndSeed, then close the file.
        '''
        if rndSeed is not None:
            rndState = np.random.RandomState(rndSeed)
        N = len(self.imStack_)
        assert N == len(self.lbStack_)

        if rndState is None:
            perm = range(N)
        else:
            perm = rndState.permutation(N)
        ims = [self.imStack_[p] for p in perm]
        lbs = [self.lbStack_[p] for p in perm]
        self.num_ = N
        for n in range(N):
            #NOTE(review): only the FIRST pushed argument is unpacked here, so
            #this matches write() only if callers push one list of numIm_
            #image entries; confirm against callers before changing.
            self.write(lbs[n], *(ims[n][0]))
        self.close()

    ##
    def close(self):
        #Nothing was opened if no example has been written yet
        if self.initWrite_:
            self.fid_.close()


##
# For writing sqbox window file layers.
class SqBoxWindowWriter:
    def __init__(self, fileName, numEx):
        '''
        fileName : the file to write to.
        numEx    : the number of examples
        The format actually written (header, then per example):
        # SqBoxWindowDataLayer
        NUM_EX
        # ExNum
        NUM_OBJ
        IMG_NAME CH H W
        OBJ1_X OBJ1_Y BBOX1_X BBOX1_Y BBOX1_SZ
        ..
        .
        - x1, y1 for object position
        - xc, yc for the center of desired bbox
        - sqSz the length of the desired bbox
          intended as the ratio of the largest imgSz / sqSz
          NOTE(review): it is written with %d, so it is truncated to an int
        '''
        self.file_ = fileName
        self.num_ = numEx
        self.count_ = 0 #The number of examples written.

        dirName = os.path.dirname(fileName)
        #A bare file name has no directory part; os.makedirs('') would fail
        if len(dirName) > 0 and not os.path.exists(dirName):
            os.makedirs(dirName)

        self.fid_ = open(self.file_, 'w')
        self.fid_.write('# SqBoxWindowDataLayer\n')
        self.fid_.write('%d\n' % self.num_) #Num Examples.

    ##
    # Private Helper function for writing the images for the WindowFile
    def write_image_line_(self, imgName, imgSz, numObj):
        '''
        imgSz : channels * height * width
        numObj: number of objects in the image (written before the image line)
        '''
        ch, h, w = imgSz
        self.fid_.write('%d\n' % numObj)
        self.fid_.write('%s %d %d %d\n' % (imgName,
                        ch, h, w))

    ##
    def write(self, *args):
        '''
        args: (imName, imSz, objPos, bboxPos, bboxSz) where objPos/bboxPos
              are sequences of (x, y) pairs and bboxSz one size per object.
        The file is closed automatically after num_ examples.
        '''
        self.fid_.write('# %d\n' % self.count_)
        #Write the images
        imName, imSz, objPos, bboxPos, bboxSz = args
        numObj = len(objPos)
        self.write_image_line_(imName, imSz, numObj)
        for i in range(numObj):
            xObjPos, yObjPos = objPos[i]
            xBbxPos, yBbxPos = bboxPos[i]
            self.fid_.write('%d %d %d %d %d\n' % (xObjPos, yObjPos, xBbxPos, yBbxPos, bboxSz[i]))
        self.count_ += 1
        if self.count_ == self.num_:
            self.close()

    ##
    def close(self):
        self.fid_.close()

##
# For reading generic window reader.
class SqBoxWindowReader:
  '''
  Reader for the files written by SqBoxWindowWriter.
  Expected layout:
    # SqBoxWindowDataLayer
    NUM_EX
    # EX_NUM
    NUM_OBJ
    IMG_NAME CH H W
    OBJ_X OBJ_Y BBOX_X BBOX_Y BBOX_SZ     (repeated NUM_OBJ times)
    # EX_NUM
    ...
  '''
  def __init__(self, fileName):
    self.fid_ = open(fileName,'r')
    line = self.fid_.readline()
    assert(line.split()[1] == 'SqBoxWindowDataLayer')
    self.num_ = int(self.fid_.readline())
    self.count_ = 0

  def read_next(self):
    '''
    Read the next example.
    Returns (imName, objPos, bbxPos, bbxSz):
      imName: image name (the image line without its CH H W fields)
      objPos: list of [x, y] object positions
      bbxPos: list of [x, y] box positions
      bbxSz : list of box sizes
    Once all NUM_EX examples have been read, prints a message and returns
    (None, None, None, None) so callers can always unpack four values.
    '''
    if self.count_ == self.num_:
      print ("All lines already read")
      return None, None, None, None
    count = int(self.fid_.readline().split()[1])
    assert count == self.count_
    self.count_ += 1
    #The number of boxes precedes the image line.
    numBox = int(self.fid_.readline())
    #Image line is "IMG_NAME CH H W"; split from the right so that image
    #names containing spaces are kept intact.
    imName = self.fid_.readline().strip().rsplit(None, 3)[0]
    #Read all the boxes
    objPos, bbxPos, bbxSz = [], [], []
    for _ in range(numBox):
      ox, oy, bx, by, bs = [int(l) for l in self.fid_.readline().split()]
      objPos.append([ox, oy])
      bbxPos.append([bx, by])
      bbxSz.append(bs)
    return imName, objPos, bbxPos, bbxSz

  def get_all_labels(self):
    '''
    Read all remaining examples and stack their per-object annotations.
    Returns an int array of shape (numBoxes, 5) whose rows are
    [objX, objY, bbxX, bbxY, bbxSz] (shape (0, 5) if there are none).
    '''
    rows = []
    while True:
      imName, objPos, bbxPos, bbxSz = self.read_next()
      if imName is None:
        break
      for (ox, oy), (bx, by), bs in zip(objPos, bbxPos, bbxSz):
        rows.append([ox, oy, bx, by, bs])
    return np.array(rows, dtype=np.int64).reshape((-1, 5))

  def close(self):
    self.fid_.close()

def red_col_sel(col):
  '''
  Colour selector usable as `colsel` in PCDReader.read: keeps points whose
  first (red) channel is not below 0.6.
  '''
  return not (col[0] < 0.6)

#READ PCD Files
class PCDReader(object):
  '''
  Point-cloud container holding parallel arrays x_, y_, z_ (coordinates)
  and c_ (3 colour channels). Build it from an ASCII .pcd file (default
  constructor), from an array of points (from_pts) or from a database
  (from_db).
  '''
  def __init__(self, fName=None, keepNaN=False, subsample=None, colsel=None, fromFile=True):
    self.fName = fName
    self.ax_ = None
    if fromFile:
      self.read(keepNaN=keepNaN,subsample=subsample, colsel=colsel)

  @classmethod
  def from_pts(cls, pts, subsample=0.2):
    '''
    pts      : array whose columns 0-2 are x, y, z and 3-5 the colour
    subsample: fraction of points kept (random subset); None keeps all
    '''
    self = cls(None, fromFile=False)
    N = pts.shape[0]
    pts = pts.copy()
    if subsample is not None:
      perm = np.random.permutation(N)
      perm = perm[0:int(subsample*N)]
      pts = pts[perm]
    self.x_ = pts[:,0]
    self.y_ = pts[:,1]
    self.z_ = pts[:,2]
    self.c_ = pts[:,3:6]
    return self

  @classmethod
  def from_db(cls, dbPath):
    self = cls(None, fromFile=False)
    self.db_ = DbReader(dbPath)
    return self

  def read_next(self, subSample=None):
    '''
    Load the next frame from the database into x_, y_, z_, c_, keeping
    every subSample-th row and column. Returns None when the DB is
    exhausted and True otherwise.
    '''
    dat,_ = self.db_.read_next()
    if dat is None:
      return None
    dat = dat.transpose((1,0,2))
    if subSample is None:
      subSample=1
    dat = dat[::subSample, ::subSample]
    self.x_ = dat[:,:,0]
    self.y_ = dat[:,:,1]
    self.z_ = dat[:,:,2]
    self.c_ = dat[:,:,3:6]
    return True

  def get_mask(self):
    '''
    Set mask_ to the points whose z is not NaN and exceeds -0.255.
    NOTE(review): the 0.255 offset is a magic threshold - confirm meaning.
    '''
    nanMask = np.isnan(self.z_)
    d = self.z_.copy()
    d[nanMask] = -10
    d = d + 0.255
    self.mask_ = d > 0

  def to_rgbd(self):
    '''
    Return (im, d): colours scaled to uint8 and depth where masked z in
    [0, 0.1) is mapped to [0, 255] and unmasked points are 0.
    '''
    self.get_mask()
    im = (self.c_.copy() * 255).astype(np.uint8)
    d = self.z_.copy()
    d[~self.mask_] = 0.0
    assert np.max(d) < 0.1, np.max(d)
    d[self.mask_] = 255 * (d[self.mask_]/0.1)
    d = d.astype(np.uint8)
    print (np.min(d), np.max(d))
    return im, d

  def get_masked_pts(self):
    '''Return an (N, 6) float32 array [10*x, 10*y, 10*z, c0, c1, c2] of the masked points.'''
    self.get_mask()
    N = np.sum(self.mask_)
    pts = np.zeros((N,6), np.float32)
    pts[:,0] = 10*self.x_[self.mask_].reshape(N,)
    pts[:,1] = 10*self.y_[self.mask_].reshape(N,)
    pts[:,2] = 10*self.z_[self.mask_].reshape(N,)
    pts[:,3:6] = self.c_[self.mask_,0:3].reshape(N,3)
    return pts

  def save_rgbd(self, dirName=''):
    '''
    Save every remaining DB frame as imNNNNNN.png / dpNNNNNN.png
    (1-based numbering) inside dirName (current directory by default).
    '''
    count = 0
    imName = osp.join(dirName, 'im%06d.png')
    dpName = osp.join(dirName, 'dp%06d.png')
    #mkdir('') would fail, so only create an explicitly requested directory.
    if dirName:
      ou.mkdir(dirName)
    while True:
      isExist = self.read_next()
      if isExist is None:
        break
      im, d = self.to_rgbd()
      print ('%s %s' % (im.shape, d.shape))
      scm.imsave(imName % (count+1), im)
      scm.imsave(dpName % (count+1), d)
      count += 1

  def plot_next_rgbd(self):
    '''Read the next DB frame and show its colour and depth images side by side.'''
    import matplotlib.pyplot as plt
    self.read_next()
    im, d = self.to_rgbd()
    if self.ax_ is None:
      self.ax_ = []
      plt.ion()
      fig = plt.figure()
      self.ax_.append(fig.add_subplot(121))
      self.ax_.append(fig.add_subplot(122))
    self.ax_[0].imshow(im)
    self.ax_[1].imshow(d)
    plt.draw()
    plt.show()

  def read(self, keepNaN=False, subsample=None, colsel=None):
    '''
    Parse the ASCII point file self.fName.
    keepNaN  : keep points with NaN coordinates (coloured blue)
    subsample: keep only every subsample-th point
    colsel   : optional predicate on the RGB colour; False drops the point
    '''
    with open(self.fName, 'r') as fid:
      lines = fid.readlines()
    #Header line 9 holds "POINTS <N>" and data start after an 11-line
    #header (presumably the ASCII PCD v0.7 header - confirm).
    N = int(lines[9].strip().split()[1])
    lines = lines[11:]
    assert len(lines)==N
    self.x_ = np.zeros((N,), np.float32)
    self.y_ = np.zeros((N,), np.float32)
    self.z_ = np.zeros((N,), np.float32)
    self.c_ = np.zeros((N,3), np.float32)
    count = 0
    for i,l in enumerate(lines):
      if subsample is not None:
        if not np.mod(i, subsample) == 0:
          continue
      x, y, z, rgb = l.strip().split()
      nanVal = False
      if x=='nan' or y=='nan' or z=='nan':
        nanVal = True
      if not keepNaN and nanVal:
        continue
      #The packed colour float is reinterpreted byte-wise as uint8.
      col = np.array([np.float64(rgb)], np.float32)
      col = (col.view(np.uint8)[0:3])/256.0
      #to rgb
      col = np.array((col[2], col[1], col[0]))
      if nanVal:
        col = np.array((0,0,1.0))
      if colsel is not None:
        isValid = colsel(col)
        if not isValid:
          continue
      self.x_[count] = float(x)
      self.y_[count] = float(y)
      self.z_[count] = float(z)
      self.c_[count] = col
      count += 1
    self.x_ = self.x_[0:count]
    self.y_ = self.y_[0:count]
    self.z_ = self.z_[0:count]
    self.c_ = self.c_[0:count]
    self.N_ = count

  def get_rgb_im(self, height=480, width=640):
    '''
    Arrange the point colours into a (height, width, 3) uint8 image.
    Assumes c_ lists an organised cloud in row-major order with at least
    height*width points (480 x 640 by default).
    '''
    numPix = height * width
    return (255 * self.c_[0:numPix]).reshape((height, width, 3)).astype(np.uint8)

  def matplot(self, ax=None):
    '''Scatter-plot the points in 3D using their own colours.'''
    import matplotlib.pyplot as plt
    if ax is None:
      from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection
      plt.ion()
      fig = plt.figure()
      ax = fig.add_subplot(111, projection='3d')
    ax.scatter(self.x_, self.y_, self.z_, c=tuple(self.c_))
    plt.draw()
    plt.show()


def save_lmdb_images(ims, dbFileName, labels=None, asFloat=False):
  '''
  Write images to an LMDB as serialized caffe Datum protos.
  ims       : array of shape numEx * ch * h * w
  dbFileName: LMDB path
  labels    : optional integer array of length numEx (defaults to zeros)
  asFloat   : if False, images are cast to uint8 before serialisation
  Keys are the zero-padded example index, e.g. '0000000003'.
  '''
  N,_,_,_ = ims.shape
  if labels is not None:
    #np.int / np.long no longer exist in recent NumPy; check the dtype kind.
    assert np.issubdtype(np.asarray(labels).dtype, np.integer)
  else:
    labels = np.zeros((N,), dtype=np.int64)

  db = lmdb.open(dbFileName, map_size=int(1e12))
  try:
    with db.begin(write=True) as txn:
      for (idx, im) in enumerate(ims):
        if not asFloat:
          im = im.astype(np.uint8)
        imDat = caffe.io.array_to_datum(im, label=int(labels[idx]))
        txn.put('{:0>10d}'.format(idx).encode('ascii'), imDat.SerializeToString())
  finally:
    db.close()

##
# Save the weights in a form that will be used by matlab function
# swap_weights to generate files useful for matconvnet.
def save_weights_for_matconvnet(net, outName, matlabRefFile=None):
  '''
  net          : Instance of my_pycaffe.MyNet
  outName      : The matlab file which needs to store parameters.
  matlabRefFile: matconvnet model used as reference. When given,
                 _mat_to_matconvnet rewrites outName into matconvnet
                 format; when None only the intermediate .mat file is
                 written (with refFile = '').
  For each layer the .mat file gets '<layer>_w' (weights transposed to
  (h, w, ch, N) and flattened column-major into 1 x num) and '<layer>_b'
  (biases reshaped to 1 x N). The first layer is flipped from BGR to RGB.
  '''
  params = {}
  for (count,key) in enumerate(net.net.params.keys()):
    blob = net.net.params[key]
    wKey = key + '_w'
    bKey = key + '_b'
    params[wKey] = copy.deepcopy(blob[0].data)
    params[bKey] = copy.deepcopy(blob[1].data)
    N,ch,h,w = params[wKey].shape
    print ('%s %s %d' % (params[wKey].shape, params[bKey].shape, N))
    num = N * ch * h * w
    if count==0:
      print ('Converting BGR filters to RGB filters')
      assert ch==3, 'The code is hacked as MatConvNet works with RGB format instead of BGR'
      params[wKey] = params[wKey][:,[2,1,0],:,:]
    params[wKey] = params[wKey].transpose((2,3,1,0)).reshape(1,num,order='F')
    if N==1 and ch==1:
      #Hacky way of finding a FC layer
      N = h
    params[bKey] = params[bKey].reshape((1,N))
  params['refFile'] = matlabRefFile if matlabRefFile is not None else ''
  sio.savemat(outName, params)
  #Passing None to MATLAB fails, so conversion only runs with a reference.
  if matlabRefFile is not None:
    _mat_to_matconvnet(matlabRefFile, outName, outName)


##
# Convert matconvnet network into a caffemodel
def matconvnet_to_caffemodel(inFile, outFile, defFile=None):
  '''
  Relies on a matlab helper function which converts a matconvnet model
  into an appropriate format.
  inFile : .mat file with 'weights', 'biases' and 'names' entries
  outFile: path of the caffemodel to save
  defFile: deploy prototxt describing the net; defaults to the VGG-S
           definition that used to be hard-coded.
  '''
  if defFile is None:
    #defFile = '/data1/pulkitag/caffe_models/bvlc_reference/caffenet_deploy.prototxt'
    #defFile = '/work4/pulkitag-code/code/ief/models/vgg_16_base.prototxt'
    defFile = '/work4/pulkitag-code/code/ief/models/vgg_s.prototxt'
  net = caffe.Net(defFile, caffe.TEST)

  #Load the weights
  dat = sio.loadmat(inFile, squeeze_me=True)
  weights = dat['weights']
  biases = dat['biases']
  names = dat['names']

  for count, (n, weight, bias) in enumerate(zip(names, weights, biases)):
    print (n)
    if 'conv' in n:
      weight = weight.transpose((3,2,0,1))
    elif 'fc' in n:
      if weight.ndim==4:
        weight = weight.transpose((3,2,0,1))
        num, ch, h, wd = weight.shape
        weight = weight.reshape((1,1,num,ch*h*wd))
      elif weight.ndim==1:
        #Happens because a trailing singleton dimension cannot be kept in
        #matlab; then there is a single output and a scalar bias.
        assert np.ndim(bias) == 0
        weight = weight.reshape((1,len(weight)))
      else:
        weight = weight.transpose((1,0))

    if np.ndim(bias) == 0:
      #i.e. scalar bias (squeezed by loadmat)
      bias = bias * np.ones((1,1,1,1))
    else:
      bias = bias.reshape((1,1,1,len(bias)))
    if count == 0:
      #RGB to BGR flip for the first layer channels
      weight[:,0:3,:,:] = weight[:,[2,1,0],:,:]
    net.params[n][0].data[...] = weight
    net.params[n][1].data[...] = bias
  #Save the network
  print (outFile)
  net.save(outFile)



#Converts the weights stored in a .mat file into format
#for matconvnet.
def _mat_to_matconvnet(srcFile, targetFile, outFile):
  '''
  srcFile   : Provides the matconvnet format (reference model)
  targetFile: Provides the n/w weights in .mat format
              obtained from first part of save_weights_for_matconvnet()
  outFile   : Where should the weights be saved.
  Runs the matlab function swap_weights_matconvnet (on MATLAB_PATH) in a
  fresh MATLAB engine, which is always shut down afterwards.
  '''
  #The engine start was commented out, leaving `meng` undefined.
  #NOTE(review): assumes the module imports the MATLAB engine as `men`,
  #as the previously commented-out line implies - confirm.
  meng = men.start_matlab()
  try:
    _ = meng.addpath(MATLAB_PATH, nargout=1)
    meng.swap_weights_matconvnet(srcFile, targetFile, outFile, nargout=0)
  finally:
    meng.exit()


##
# Test the conversion of matconvnet into caffemodel
def test_convert():
  '''Convert a fixed .mat file into a caffemodel (hard-coded paths).'''
  inFile = '/data1/pulkitag/others/tmp.mat'
  outFile = '/data1/pulkitag/others/tmp.caffemodel'
  #inFile = '/data1/pulkitag/others/alex-matconvnet.mat'
  #outFile = '/data1/pulkitag/others/alex-matconvnet.caffemodel'
  matconvnet_to_caffemodel(inFile, outFile)

def vis_convert():
  '''Visualise the conv1 filters of the converted AlexNet model.'''
  defFile = '/data1/pulkitag/caffe_models/bvlc_reference/caffenet_deploy.prototxt'
  modelFile = '/data1/pulkitag/others/alex-matconvnet.caffemodel'
  net = mp.MyNet(defFile, modelFile)
  net.vis_weights('conv1')


def test_convert_features():
  '''
  Run the converted model on reference images stored in a .mat file and
  stop in the debugger for inspection. Returns the forward outputs.
  '''
  defFile = '/data1/pulkitag/caffe_models/bvlc_reference/caffenet_deploy.prototxt'
  netFile = '/data1/pulkitag/others/tmp.caffemodel'
  #netFile = '/data1/pulkitag/others/alex-matconvnet.caffemodel'
  net = mp.MyNet(defFile, netFile)

  net.set_preprocess(isBlobFormat=True, chSwap=None)
  imData = sio.loadmat('/data1/pulkitag/others/ref_imdata.mat',squeeze_me=True)
  imData = imData['im']
  #matlab (h, w, ch, N) -> blob (N, ch, h, w), then RGB -> BGR
  imData = imData.transpose((3,2,0,1))
  imData = imData[:,[2,1,0],:,:]
  imData = imData[0:10]

  op = net.forward_all(blobs=['fc8','conv1','data','conv2','conv5','fc6','fc7'],**{'data': imData})
  pdb.set_trace()
  return op

# ------------------------------------------------------------------------------
# /my_pycaffe_tests.py:
# ------------------------------------------------------------------------------
## @package my_pycaffe_tests
# Unit Testing functions.
#

import my_pycaffe as mp
import my_pycaffe_utils as mpu
import numpy as np
import pdb
import os
try:
  import h5py
except ImportError:
  print ('WARNING: h5py not found, some functions may not work')

##
# Test code for Zeiler-Fergus Saliency.
def test_zf_saliency(dataSet='mnist', stride=2, patchSz=5):
  '''
  Compute Zeiler-Fergus occlusion saliency with mpu.zf_saliency.
  dataSet       : 'mnist' - LeNet on the first batch of MNIST test digits;
                  anything else - bvlcAlexNet on one ILSVRC12 image.
  stride/patchSz: forwarded to mpu.zf_saliency.
  Returns (data, imSal, pdLabels, gtLabels); pdLabels is the argmax of
  the network scores.
  '''
  if dataSet=='mnist':
    defFile = '/work4/pulkitag-code/pkgs/caffe-v2-2/modelFiles/mnist/hdf5_test/lenet.prototxt'
    modelFile,_ = mp.get_model_mean_file('lenet')
    net = mp.MyNet(defFile, modelFile, isGPU=False)
    N = net.get_batchsz()
    net.set_preprocess(chSwap=None, imageDims=(28,28,1), isBlobFormat=True)

    h5File = '/data1/pulkitag/mnist/h5store/test/batch1.h5'
    #Left open: gtLabels below is read from this file.
    fid = h5py.File(h5File,'r')
    data = fid['data']
    data = data[0:N]

    #Do the saliency
    imSal, score = mpu.zf_saliency(net, data, 10, 'ip2', patchSz=patchSz, stride=stride)
    gtLabels = fid['label']
  else:
    netName = 'bvlcAlexNet'
    opLayer = 'fc8'
    defFile = mp.get_layer_def_files(netName, layerName=opLayer)
    modelFile, meanFile = mp.get_model_mean_file(netName)
    net = mp.MyNet(defFile, modelFile)
    net.set_preprocess(imageDims=(256,256,3), meanDat=meanFile, rawScale=255, isBlobFormat=True)

    ilDat = mpu.ILSVRC12Reader()
    ilDat.set_count(2)
    data,gtLabels,syn,words = ilDat.read()
    #Add a batch axis and move the last axis (presumably channels) to axis 1.
    data = data.reshape((1,data.shape[0],data.shape[1],data.shape[2]))
    data = data.transpose((0,3,1,2))
    print (data.shape)
    imSal, score = mpu.zf_saliency(net, data, 1000, opLayer, patchSz=patchSz, stride=stride)

  pdLabels = np.argmax(score.squeeze(), axis=1)
  return data, imSal, pdLabels, gtLabels


##
# Test Reading the protoFile
def test_get_proto_param():
  '''Parse <pythonTest>/test_conv_param.txt with mpu.get_proto_param and return the result.'''
  paths = mpu.get_caffe_paths()
  testFile = os.path.join(paths['pythonTest'], 'test_conv_param.txt')
  with open(testFile, 'r') as fid:
    lines = fid.readlines()
  params = mpu.get_proto_param(lines)
  return params
# ------------------------------------------------------------------------------
# /other_utils.py:
# ------------------------------------------------------------------------------
## @package other_utils
# Miscellaneous Util Functions
#
import numpy as np
#import scipy.misc as scm
import matplotlib.pyplot as plt
import copy
import os
from os import path as osp
import collections as co
import pdb
from functools import reduce
from easydict import EasyDict as edict

##
# Get the defaults
def get_defaults(setArgs, defArgs, defOnly=True):
  '''
  Overwrite entries of defArgs with deep copies of the values in setArgs.
  setArgs: user supplied settings
  defArgs: defaults; modified in place and also returned
  defOnly: if True every key of setArgs must exist in defArgs
           (AssertionError otherwise); if False unknown keys are ignored.
  '''
  for key in setArgs.keys():
    if defOnly:
      assert key in defArgs, 'Key not found: %s' % key
    if key in defArgs:
      defArgs[key] = copy.deepcopy(setArgs[key])
  return defArgs



##
# Verify if all the keys are present recursively in the dict
def verify_recursive_key(data, keyNames, verifyOnly=False):
  '''
  data      : dict like data['a']['b']['c']...['l']
  keyNames  : list of keys
  verifyOnly: if True, don't raise - return False when a key is missing
              or an intermediate value is not a dict.
  Returns True when data[k0][k1]...[kn] exists; otherwise raises
  AssertionError (verifyOnly=False) or returns False (verifyOnly=True).
  '''
  assert isinstance(keyNames, list), 'keyNames is required to be a list'
  dat = data
  for i, key in enumerate(keyNames):
    if i > 0 and not isinstance(dat, dict):
      if verifyOnly:
        return False
      raise AssertionError('Wrong Keys')
    if key not in dat:
      if verifyOnly:
        return False
      if i == 0:
        raise AssertionError('%s not present' % key)
      raise AssertionError('%s key not present' % key)
    dat = dat[key]
  return True

##
# Set the value of a recursive key.
def set_recursive_key(data, keyNames, val):
  '''
  Set data[k0]...[kn] = val; all keys must already exist
  (verify_recursive_key raises AssertionError otherwise).
  '''
  if verify_recursive_key(data, keyNames):
    dat = reduce(lambda dat, key: dat[key], keyNames[:-1], data)
    dat[keyNames[-1]] = val
  else:
    raise Exception('Keys not present')


def add_recursive_key(data, keyNames, val):
  '''Set data[k0]...[kn] = val; the parent keys must exist, the last key may be new.'''
  #isKey = verify_recursive_key(data, keyNames)
  #assert not isKey, 'key is already present'
  dat = reduce(lambda dat, key: dat[key], keyNames[:-1], data)
  dat[keyNames[-1]] = val


##
# Delete the recursive key
def del_recursive_key(data, keyNames):
  '''Delete data[k0]...[kn]; raises AssertionError if any key is missing.'''
  if verify_recursive_key(data, keyNames):
    dat = reduce(lambda dat, key: dat[key], keyNames[:-1], data)
    del dat[keyNames[-1]]
  else:
    raise Exception('Keys not present')



##
# Get the item from a recursive key
def get_item_recursive_key(data, keyNames, verifyOnly=False):
  '''
  Return data[k0]...[kn]. With verifyOnly=True a missing key prints a
  message and returns None instead of raising AssertionError.
  '''
  if verify_recursive_key(data, keyNames, verifyOnly=verifyOnly):
    dat = reduce(lambda dat, key: dat[key], keyNames[:-1], data)
    return dat[keyNames[-1]]
  else:
    print ('Not found: %s' % (keyNames,))
    return None

##
# Find the path to the key in a recursive dictionary.
def find_path_key(data, keyName):
  '''
  Returns path (list of keys, starting at the top level of data) to the
  first key of name keyName found by a depth-first search, or [] if none.
  if keyName is a list - [k1, k2 ..kp] then find in data[k1][k2]...[kp-1]
  the key kp; the returned path then starts with k1, ..., kp-1.
  '''
  if not isinstance(data, dict):
    return []
  prevKey = []
  #Find if all keys except the last one exist or not.
  if isinstance(keyName, list):
    for key in keyName[0:-1]:
      if not isinstance(data, dict) or key not in data:
        return []
      data = data[key]
    if not isinstance(data, dict):
      return []
    prevKey = keyName[0:-1]
    keyName = keyName[-1]

  if keyName in data:
    #Include the prefix so the path is usable from the top-level dict.
    return prevKey + [keyName]
  for key in data.keys():
    pathFound = find_path_key(data[key], keyName)
    if len(pathFound) > 0:
      return prevKey + [key] + pathFound
  return []

##
# Find an item in dict
# keyName should be a string or an list of a single name.
def get_item_dict(data, keyName):
  '''Return the value at find_path_key(data, keyName), or None if not found.'''
  keyPath = find_path_key(data, keyName)
  if len(keyPath)==0:
    return None
  else:
    return get_item_recursive_key(data, keyPath)

##
#Find the key to the item in a dict
def find_keyofitem(data, item):
  '''
  Return the key path (list) to a value equal to item, recursing into
  dict/edict values; if several match, the last one visited wins.
  Returns None if no value matches.
  '''
  keyName = None
  for k in data.keys():
    tp = type(data[k])
    if tp == dict or tp ==edict:
      keyPath = find_keyofitem(data[k], item)
      if keyPath is None:
        continue
      else:
        keyName = [k] + keyPath
    else:
      if data[k] == item:
        keyName = [k]
  return keyName

##
# Read the image
def read_image(imName, color=True, isBGR=False, imSz=None):
  '''
  color: True - if a gray scale image is encountered convert into color
  isBGR: (with color=True) reverse the channel order
  imSz : if given (int), resize the image to imSz x imSz
  '''
  im = plt.imread(imName)
  if color:
    if im.ndim==2:
      print ('Converting grayscale image into color image')
      im = np.tile(im.reshape(im.shape[0], im.shape[1],1),(1,1,3))
    if isBGR:
      im = im[:,:,[2,1,0]]
  #Resize if needed
  if imSz is not None:
    assert isinstance(imSz,int)
    #The module-level scipy.misc import is disabled, so import on demand.
    import scipy.misc as scm
    im = scm.imresize(im, (imSz,imSz))
  return im


##
# Crop the image
def crop_im(im, bbox, **kwargs):
  '''
  The bounding box is assumed to be in the form (xmin, ymin, xmax, ymax);
  it is clipped to the image extent.
  kwargs:
    cropType: 'resize'  - crop bbox and resize to imSz x imSz
              'contPad' - grow bbox by contPad pixels per side (clipped),
                          then crop and resize to imSz x imSz
    imSz    : Size of the image required
    contPad : padding in pixels (cropType='contPad' only)
  Raises Exception for an unrecognized cropType.
  '''
  cropType = kwargs['cropType']
  imSz = kwargs['imSz']
  x1,y1,x2,y2 = bbox
  x1 = max(0, x1)
  y1 = max(0, y1)
  x2 = min(im.shape[1], x2)
  y2 = min(im.shape[0], y2)
  #elif (not a second if): 'resize' used to fall into the final else.
  if cropType=='resize':
    pass
  elif cropType=='contPad':
    contPad = kwargs['contPad']
    x1 = max(0, x1 - contPad)
    y1 = max(0, y1 - contPad)
    x2 = min(im.shape[1], x2 + contPad)
    y2 = min(im.shape[0], y2 + contPad)
  else:
    raise Exception('Unrecognized crop type')
  import scipy.misc as scm
  imBox = im[y1:y2, x1:x2]
  imBox = scm.imresize(imBox, (imSz, imSz))
  return imBox

##
# Read and crop the image.
def read_crop_im(imName, bbox, **kwargs):
  '''Read imName (optionally with color=kwargs['color']) and crop it with crop_im.'''
  if 'color' in kwargs:
    im = read_image(imName, color=kwargs['color'])
  else:
    im = read_image(imName)
  return crop_im(im, bbox, **kwargs)


##
# Makes a table from dict
def make_table(keyOrder=None, colWidth=15, sep=None, **kwargs):
  '''
  kwargs should contain keys and lists as the values.
  Each key is printed as a column (all lists must have the same length).
  keyOrder: column order (defaults to the order of kwargs)
  colWidth: minimum width of each column
  sep     : None (space aligned, dashed rule under the header),
            'csv' (comma) or 'tab' (tab separated)
  Returns the list of printed lines.
  '''
  if keyOrder is None:
    keyOrder = list(kwargs.keys())
  if sep is None:
    sepStr = ''
  elif sep == 'csv':
    sepStr = ','
  elif sep == 'tab':
    sepStr = '\t'
  else:
    raise ValueError('Unrecognized separator: %s' % sep)

  L = 0
  for i,key in enumerate(keyOrder):
    if i==0:
      L = len(kwargs[key])
    else:
      assert L == len(kwargs[key]), 'Wrong length for %s' % key

  N = len(keyOrder)
  formatStr = ("{:<%d} " % colWidth) + sepStr
  lines = []
  lines.append(''.join(formatStr.format(k) for k in keyOrder) + '\n')
  #The rule is only meaningful for the space-aligned layout.
  if sep is None:
    lines.append('-' * colWidth * N + '\n')

  for i in range(L):
    line = ''
    for key in keyOrder:
      val = kwargs[key][i]
      if isinstance(val, int):
        fStr = '%d' + sepStr
      elif type(val) in [float, np.float32, np.float64]:
        fStr = '%.1f' + sepStr
      else:
        fStr = '%s' + sepStr
      line = line + formatStr.format(fStr % val)
    lines.append(line + '\n')

  for l in lines:
    print (l)
  return lines


#I will make the rows.
def make_table_rows(**kwargs):
  '''
  Print one row per keyword argument: the key followed by its values
  (ints as %d, numpy floats as %.3f, everything else as %s), with a
  dashed rule after the first row.
  '''
  #Find the maximum length of the key.
  maxKeyLen = max([len(key) for key in kwargs] + [0])
  keyStr = "{:<%d} " % (maxKeyLen + 15)
  formatStr = "{:<15} "
  lines = []
  for count, (key, val) in enumerate(kwargs.items()):
    line = keyStr.format('%s' % key)
    for v in val:
      if isinstance(v, int):
        fStr = '%d'
      elif isinstance(v, (np.float32, np.float64)):
        fStr = '%.3f'
      else:
        fStr = '%s'
      line = line + formatStr.format(fStr % v)
    lines.append(line + '\n')
    if count == 0:
      lines.append('-' * 100 + '\n')

  for l in lines:
    print (l)

##
# In a recursive dictionary - subselect a few fields
# while vary others. For eg d['vr1']['a1']['b1'], d['vr2']['a1']['b2'], d['vr3']['a2']['b3']
# Now I might be interested only in values such that the second field is fixed to 'a1'
# So that I get the output as d['vr1']['b1'], d['vr2']['b2']
def conditional_select(data, fields, reduceFn=None):
  '''
  data     : dict
  fields   : fields (a list)
             [None,'a1',None] means that keep the second field fixed to 'a1',
             but consider all values of other fields.
  reduceFn : Typically the dict would store an array
             reductionFn can be any function to reduce this array to
             a quantity of interest like mean etc.
  Returns a (deep-copied) OrderedDict without the fixed levels; if a fixed
  key is absent the result is an empty OrderedDict.
  '''
  #All fields consumed: data is a leaf (reached through a trailing None).
  if len(fields) == 0:
    if reduceFn is None:
      return copy.deepcopy(data)
    return copy.deepcopy(reduceFn(data))
  newData = co.OrderedDict()
  for key in data.keys():
    if fields[0] is None:
      #Chose all the keys
      newData[key] = conditional_select(data[key], fields[1:], reduceFn=reduceFn)
    else:
      if key == fields[0]:
        if len(fields) > 1:
          newData = conditional_select(data[key], fields[1:], reduceFn=reduceFn)
        else:
          if reduceFn is None:
            newData = copy.deepcopy(data[key])
          else:
            newData = copy.deepcopy(reduceFn(data[key]))
        return newData
  return newData

##
# Count the things.
def count_unique(arr, maxVal=None):
  '''
  Count occurrences of each value in arr: of its unique values (sorted)
  or, if maxVal is given, of every value in 0..maxVal. Returns floats.
  '''
  if maxVal is None:
    elms = np.unique(arr)
  else:
    elms = range(maxVal+1)
  count = np.zeros((len(elms),))
  for i,e in enumerate(elms):
    count[i] = np.sum(arr==e)
  return count

##
# Create dir
def create_dir(dirName):
  '''Create dirName (and parents) if it does not exist.'''
  if not os.path.exists(dirName):
    os.makedirs(dirName)

##
#Private function for chunking a path
def _chunk_path(fName, N):
  '''Split a single path component into '/'-joined pieces of at most N characters.'''
  assert '/' not in fName
  return '/'.join(fName[i:i+N] for i in range(0, len(fName), N)) if fName else fName

##
# chunk filenames
def chunk_filename(fName, maxLen=255):
  '''
  if any of the names is maxLen characters or longer then the file cannot
  be stored, so such components are chunked into sub-directories.
  The parent directory of the resulting name is created.
  Returns the new file name.
  '''
  newSplits = []
  for s in fName.split('/'):
    if len(s)>=maxLen:
      newSplits.append(_chunk_path(s, maxLen-1))
    else:
      newSplits.append(s)
  newName = '/'.join(newSplits)
  dirName = os.path.dirname(newName)
  #A bare file name has no directory; makedirs('') would raise.
  if dirName:
    create_dir(dirName)
  return newName

##
# Hash a dictionary into string
# NOTE(review): the lines below are a flattened, hard-wrapped dump of several
#   Python 2 source files with their original line numbers embedded as "NNN |":
#   the tail of a utility module (hash_dict_str, mkdir, make_python_param_str,
#   ints_to_str; presumably other_utils.py given the tree order), all of
#   pycaffe_config.py, all of quick_things.py, and rot_utils.py lines 1-173.
# NOTE(review): rot_utils.py text is missing after line 150 ("assert theta>=0
#   and theta0:" is followed directly by "169 |"), so angle_axis_to_rotmat and
#   the next function (presumably rotmat_to_angle_axis, which later code calls
#   and which ends by returning theta, v) are incomplete here. Restore them
#   from the upstream repository before relying on or editing this span.
# NOTE(review): hash_dict_str builds its key from the built-in hash() of a
#   frozenset of (key, value) items, and the keys are strings. That value is only
#   reproducible across interpreter runs when string hashing is not randomized
#   (the Python 2 default, or a fixed PYTHONHASHSEED on Python 3). Otherwise,
#   hashes persisted by one run may not be regenerated by another.
# NOTE(review): get_rot_angle catches every exception from linalg.logm with a
#   bare except and then calls pdb.set_trace(), which blocks non-interactive runs.
379 | def hash_dict_str(d, ignoreKeys=[]): 380 | d = copy.deepcopy(d) 381 | oKeys = [] 382 | for k in ignoreKeys: 383 | if k in d.keys(): 384 | del d[k] 385 | for k,v in d.iteritems(): 386 | if type(v) in [bool, int, float, str, type(None)]: 387 | continue 388 | else: 389 | assert type(v) in [dict, edict, co.OrderedDict],\ 390 | 'Type not recognized %s, for this type different results for different runs of\ 391 | hashing can be obtained, therefore the exception' % v 392 | oKeys.append(k) 393 | hStr = [] 394 | for k in oKeys: 395 | hStr.append('-%s' % hash_dict_str({k: d[k]})) 396 | del d[k] 397 | hStr = ''.join('%s' % s for s in hStr) 398 | return '%d%s' % (hash(frozenset(d.items())), hStr) 399 | 400 | ## 401 | # 402 | def mkdir(fName): 403 | if not osp.exists(fName): 404 | os.makedirs(fName) 405 | 406 | ## 407 | #Make parameter string for python layers 408 | def make_python_param_str(params, ignoreKeys=['expStr']): 409 | paramStr = '' 410 | for k,v in params.iteritems(): 411 | if k in ignoreKeys: 412 | continue 413 | if type(v) == bool: 414 | if v: 415 | paramStr = paramStr + ' --%s' % k 416 | else: 417 | paramStr = paramStr + ' --no-%s' % k 418 | else: 419 | paramStr = paramStr + ' --%s %s' % (k,v) 420 | return paramStr 421 | 422 | ## 423 | #Convert a list of ints into a string 424 | def ints_to_str(ints): 425 | ch = ''.join(str(unichr(i)) for i in ints) 426 | return ch 427 | -------------------------------------------------------------------------------- /pycaffe_config.py: -------------------------------------------------------------------------------- 1 | ## @package pycaffe_config 2 | # Specify the configurations 3 | # 4 | 5 | import socket 6 | from easydict import EasyDict as edict 7 | from os import path as osp 8 | 9 | cfg = edict() 10 | cfg.HOSTNAME = socket.gethostname() 11 | if cfg.HOSTNAME in ['anakin', 'vader', 'spock', 'poseidon']: 12 | cfg.IS_EC2 = False 13 | cfg.CAFFE_PATH = '/work4/pulkitag-code/pkgs/caffe-v2-3' 14 | #cfg.CAFFE_PATH = 
'/work4/pulkitag-code/pkgs/caffe-jeff-dec15' 15 | cfg.STREETVIEW_CODE_PATH = '/work4/pulkitag-code/code/projStreetView' 16 | cfg.STREETVIEW_DATA_MAIN = '/data0' 17 | cfg.STREETVIEW_DATA_READ_IM = cfg.STREETVIEW_DATA_MAIN 18 | #Billiards Path 19 | cfg.BILLIARDS_DATA_MAIN = '/data1' 20 | cfg.DATA0 = '/data0' 21 | cfg.DATA1 = '/data1' 22 | cfg.BILLIARDS_CODE_PATH = '/work4/pulkitag-code' 23 | #Caffe Model Path 24 | cfg.CAFFE_MODEL_PATH = '/data1/pulkitag/caffe_models/' 25 | else: 26 | cfg.IS_EC2 = True 27 | if osp.exists('/home-2/pagrawal'): 28 | cfg.STREETVIEW_CODE_PATH = '/home-2/pagrawal/code/streetview' 29 | cfg.CAFFE_PATH = '/home-2/pagrawal/pkgs/caffe-v2-3' 30 | else: 31 | cfg.STREETVIEW_CODE_PATH = '/home/ubuntu/code/streetview' 32 | cfg.CAFFE_PATH = '/home/ubuntu/caffe-v2-3' 33 | 34 | if osp.exists('/data0'): 35 | cfg.STREETVIEW_DATA_MAIN = '/data0' 36 | cfg.STREETVIEW_DATA_READ_IM = cfg.STREETVIEW_DATA_MAIN 37 | #BILLIARDS PATH 38 | cfg.BILLIARDS_DATA_MAIN = '/data1' 39 | cfg.BILLIARDS_CODE_PATH = '/work4/pulkitag-code' 40 | cfg.DATA0 = '/data0' 41 | cfg.DATA1 = '/data1' 42 | else: 43 | cfg.STREETVIEW_DATA_MAIN = '/puresan/shared' 44 | cfg.STREETVIEW_DATA_READ_IM = '/dev/shm' 45 | #BILLIARDA PATH 46 | cfg.BILLIARDS_DATA_MAIN = '/puresan/shared' 47 | cfg.BILLIARDS_CODE_PATH = '/home-2/pagrawal' 48 | cfg.DATA0 = '/dev/shm' 49 | cfg.DATA1 = '/dev/shm' 50 | #Caffe Model Path 51 | cfg.CAFFE_MODEL_PATH = osp.join(cfg.DATA0, 'pulkitag/caffe_models/') 52 | -------------------------------------------------------------------------------- /quick_things.py: -------------------------------------------------------------------------------- 1 | ## @package my_pycaffe 2 | # Some quick and dirty functions 3 | # 4 | 5 | import my_pycaffe as mp 6 | import my_pycaffe_utils as mpu 7 | from os import path as osp 8 | import caffe 9 | 10 | ## 11 | #Save alexnet weights stored uptil various levels 12 | def save_alexnet_levels(): 13 | maxLayer = ['conv1', 'conv2', 'conv3', 'conv4', 
'conv5', 'fc6'] 14 | modelDir = '/data1/pulkitag/caffe_models/bvlc_reference' 15 | defFile = osp.join(modelDir, 'caffenet_deploy.prototxt') 16 | oDefFile = osp.join(modelDir, 'alexnet_levels', 'caffenet_deploy_%s.prototxt') 17 | modelFile = osp.join(modelDir, 'bvlc_reference_caffenet.caffemodel') 18 | oModelFile = osp.join(modelDir, 'alexnet_levels', 19 | 'bvlc_reference_caffenet_%s.caffemodel') 20 | for l in maxLayer: 21 | print (l) 22 | dFile = mpu.ProtoDef(defFile=defFile) 23 | dFile.del_all_layers_above(l) 24 | dFile.write(oDefFile % l) 25 | net = caffe.Net((oDefFile % l), modelFile, caffe.TEST) 26 | net.save(oModelFile % l) 27 | -------------------------------------------------------------------------------- /rot_utils.py: -------------------------------------------------------------------------------- 1 | ## @package rot_utils 2 | # Util functions for dealing with rotations 3 | # 4 | 5 | import scipy.io as sio 6 | import numpy as np 7 | from scipy import linalg as linalg 8 | import sys, os 9 | import pdb 10 | import math 11 | from mpl_toolkits.mplot3d import Axes3D 12 | 13 | _FLOAT_EPS_4 = np.finfo(float).eps * 4.0 14 | 15 | ## 16 | #Convert degrees to radians 17 | def deg2rad(dg): 18 | dg = np.mod(dg, 360) 19 | if dg > 180: 20 | dg = -(360 - dg) 21 | return ((np.pi)/180.) 
* dg 22 | 23 | def get_rot_angle(view1, view2): 24 | try: 25 | viewDiff = linalg.logm(np.dot(view2, np.transpose(view1))) 26 | except: 27 | print "Error Encountered" 28 | pdb.set_trace() 29 | 30 | viewDiff = linalg.norm(viewDiff, ord='fro') 31 | assert not any(np.isnan(viewDiff.flatten())) 32 | assert not any(np.isinf(viewDiff.flatten())) 33 | angle = viewDiff/np.sqrt(2) 34 | return angle 35 | 36 | 37 | def get_cluster_assignments(x, centers): 38 | N = x.shape[0] 39 | nCl = centers.shape[0] 40 | distMat = np.inf * np.ones((nCl,N)) 41 | 42 | for c in range(nCl): 43 | for i in range(N): 44 | distMat[c,i] = get_rot_angle(centers[c], x[i]) 45 | 46 | assert not any(np.isinf(distMat.flatten())) 47 | assert not any(np.isnan(distMat.flatten())) 48 | 49 | assgn = np.argmin(distMat, axis=0) 50 | minDist = np.amin(distMat, axis=0) 51 | meanDist = np.mean(minDist) 52 | assert all(minDist.flatten()>=0) 53 | return assgn, meanDist 54 | 55 | 56 | def karcher_mean(x, tol=0.01): 57 | ''' 58 | Determined the Karcher mean of rotations 59 | Implementation from Algorithm 1, Rotation Averaging, Hartley et al, IJCV 2013 60 | ''' 61 | R = x[0] 62 | N = x.shape[0] 63 | normDeltaR = np.inf 64 | itr = 0 65 | while True: 66 | #Estimate the delta rotation between the current center and all points 67 | deltaR = np.zeros((3,3)) 68 | oldNorm = normDeltaR 69 | for i in range(N): 70 | deltaR += linalg.logm(np.dot(np.transpose(R),x[i])) 71 | deltaR = deltaR / N 72 | normDeltaR = linalg.norm(deltaR, ord='fro')/np.sqrt(2) 73 | 74 | if oldNorm - normDeltaR < tol: 75 | break 76 | 77 | R = np.dot(R, linalg.expm(deltaR)) 78 | #print itr 79 | itr += 1 80 | 81 | return R 82 | 83 | 84 | def estimate_clusters(x, assgn, nCl): 85 | clusters = np.zeros((nCl,3,3)) 86 | for c in range(nCl): 87 | pointSet = x[assgn==c] 88 | clusters[c] = karcher_mean(pointSet) 89 | 90 | return clusters 91 | 92 | 93 | def cluster_rotmats(x,nCl=2,tol=0.01): 94 | ''' 95 | x : numMats * 3 * 3 96 | nCl: number of clusters 97 | tol: 
tolerance when to stop, it is basically if the reduction in mean error goes below this point 98 | ''' 99 | assert x.shape[1]==x.shape[2]==3 100 | N = x.shape[0] 101 | 102 | #Randomly chose some points as initial cluster centers 103 | perm = np.random.permutation(N) 104 | centers = x[perm[0:nCl]] 105 | assgn, dist = get_cluster_assignments(x, centers) 106 | print "Initial Mean Distance is: %f" % dist 107 | 108 | itr = 0 109 | clusterFlag = True 110 | while clusterFlag: 111 | itr += 1 112 | prevAssgn = np.copy(assgn) 113 | prevDist = dist 114 | #Find the new centers 115 | centers = estimate_clusters(x, assgn, nCl) 116 | #Find the new assgn 117 | assgn,dist = get_cluster_assignments(x, centers) 118 | 119 | print "iteration: %d, mean distance: %f" % (itr,dist) 120 | 121 | if prevDist - dist < tol: 122 | print "Desired tolerance achieved" 123 | clusterFlag = False 124 | 125 | if all(assgn==prevAssgn): 126 | print "Assignments didnot change in this iteration, hence converged" 127 | clusterFlag = False 128 | 129 | return assgn, centers 130 | 131 | 132 | def axis_to_skewsym(v): 133 | ''' 134 | Converts an axis into a skew symmetric matrix format. 135 | ''' 136 | v = v/np.linalg.norm(v) 137 | vHat = np.zeros((3,3)) 138 | vHat[0,1], vHat[0,2] = -v[2],v[1] 139 | vHat[1,0], vHat[1,2] = v[2],-v[0] 140 | vHat[2,0], vHat[2,1] = -v[1],v[0] 141 | 142 | return vHat 143 | 144 | 145 | def angle_axis_to_rotmat(theta, v): 146 | ''' 147 | Given the axis v, and a rotation theta - convert it into rotation matrix 148 | theta needs to be in radian 149 | ''' 150 | assert theta>=0 and theta0: 169 | v = v/theta 170 | return theta, v 171 | 172 | ## 173 | # Convert Euler matrices into a rotation matrix.
def euler2mat(z=0, y=0, x=0, isRadian=True):
    ''' Return matrix for rotations around z, y and x axes

    Uses the z, then y, then x convention above

    Parameters
    ----------
    z : scalar
       Rotation angle around z-axis (performed first)
    y : scalar
       Rotation angle around y-axis
    x : scalar
       Rotation angle around x-axis (performed last)
    isRadian : bool, optional
       If True (default) the angles are in radians; if False they are in
       degrees and are converted to radians before use.

    Returns
    -------
    M : array shape (3,3)
       Rotation matrix giving same rotation as for given angles
       (M = Rx . Ry . Rz, so the z rotation is applied first).

    Raises
    ------
    AssertionError
       If any angle, once in radians, lies outside [-pi, pi).

    Examples
    --------
    >>> zrot = 1.3 # radians
    >>> yrot = -0.1
    >>> xrot = 0.2
    >>> M = euler2mat(zrot, yrot, xrot)
    >>> M.shape == (3, 3)
    True

    The output rotation matrix is equal to the composition of the
    individual rotations

    >>> M1 = euler2mat(zrot)
    >>> M2 = euler2mat(0, yrot)
    >>> M3 = euler2mat(0, 0, xrot)
    >>> composed_M = np.dot(M3, np.dot(M2, M1))
    >>> np.allclose(M, composed_M)
    True

    You can specify rotations by named arguments

    >>> np.all(M3 == euler2mat(x=xrot))
    True

    When applying M to a vector, the vector should be a column vector to the
    right of M.  If the right hand side is a 2D array rather than a
    vector, then each column of the 2D array represents a vector.

    >>> vec = np.array([1, 0, 0]).reshape((3,1))
    >>> v2 = np.dot(M, vec)
    >>> vecs = np.array([[1, 0, 0],[0, 1, 0]]).T # giving 3x2 array
    >>> vecs2 = np.dot(M, vecs)

    Rotations are counter-clockwise.

    >>> zred = np.dot(euler2mat(z=np.pi/2), np.eye(3))
    >>> np.allclose(zred, [[0, -1, 0],[1, 0, 0], [0, 0, 1]])
    True
    >>> yred = np.dot(euler2mat(y=np.pi/2), np.eye(3))
    >>> np.allclose(yred, [[0, 0, 1],[0, 1, 0], [-1, 0, 0]])
    True
    >>> xred = np.dot(euler2mat(x=np.pi/2), np.eye(3))
    >>> np.allclose(xred, [[1, 0, 0],[0, 0, -1], [0, 1, 0]])
    True

    Notes
    -----
    The direction of rotation is given by the right-hand rule (orient
    the thumb of the right hand along the axis around which the rotation
    occurs, with the end of the thumb at the positive end of the axis;
    curl your fingers; the direction your fingers curl is the direction
    of rotation).  Therefore, the rotations are counterclockwise if
    looking along the axis of rotation from positive to negative.
    '''
    # reduce is a builtin only on Python 2; functools.reduce works on 2.6+ and 3.
    from functools import reduce

    if not isRadian:
        z = ((np.pi)/180.) * z
        y = ((np.pi)/180.) * y
        x = ((np.pi)/180.) * x
    assert z>=(-np.pi) and z < np.pi, 'Inapprorpriate z: %f' % z
    assert y>=(-np.pi) and y < np.pi, 'Inapprorpriate y: %f' % y
    assert x>=(-np.pi) and x < np.pi, 'Inapprorpriate x: %f' % x

    # Only non-zero angles contribute a factor; zero angles are skipped.
    Ms = []
    if z:
        cosz = math.cos(z)
        sinz = math.sin(z)
        Ms.append(np.array(
            [[cosz, -sinz, 0],
             [sinz, cosz, 0],
             [0, 0, 1]]))
    if y:
        cosy = math.cos(y)
        siny = math.sin(y)
        Ms.append(np.array(
            [[cosy, 0, siny],
             [0, 1, 0],
             [-siny, 0, cosy]]))
    if x:
        cosx = math.cos(x)
        sinx = math.sin(x)
        Ms.append(np.array(
            [[1, 0, 0],
             [0, cosx, -sinx],
             [0, sinx, cosx]]))
    if Ms:
        # Reverse so the product is Rx . Ry . Rz (z applied first).
        return reduce(np.dot, Ms[::-1])
    return np.eye(3)


def mat2euler(M, cy_thresh=None, seq='zyx'):
    '''
    Taken From: http://afni.nimh.nih.gov/pub/dist/src/pkundu/meica.libs/nibabel/eulerangles.py
    Discover Euler angle vector from 3x3 matrix

    Uses the conventions above.

    Parameters
    ----------
    M : array-like, shape (3,3)
    cy_thresh : None or scalar, optional
       threshold below which to give up on straightforward arctan for
       estimating x rotation.  If None (default), estimate from
       precision of input.
    seq : {'zyx', 'xyz'}, optional
       'zyx' (default) inverts euler2mat, i.e. M = Rx . Ry . Rz.
       'xyz' assumes M = Rz . Ry . Rx.

    Returns
    -------
    z : scalar
    y : scalar
    x : scalar
       Rotations in radians around z, y, x axes, respectively

    Raises
    ------
    Exception
       If `seq` is not one of the supported sequences.

    Notes
    -----
    If there was no numerical error, the routine could be derived using
    Sympy expression for z then y then x rotation matrix, which is::

      [                       cos(y)*cos(z),                       -cos(y)*sin(z),         sin(y)],
      [cos(x)*sin(z) + cos(z)*sin(x)*sin(y), cos(x)*cos(z) - sin(x)*sin(y)*sin(z), -cos(y)*sin(x)],
      [sin(x)*sin(z) - cos(x)*cos(z)*sin(y), cos(z)*sin(x) + cos(x)*sin(y)*sin(z),  cos(x)*cos(y)]

    with the obvious derivations for z, y, and x

       z = atan2(-r12, r11)
       y = asin(r13)
       x = atan2(-r23, r33)

    for x,y,z order
       y = asin(-r31)
       x = atan2(r32, r33)
       z = atan2(r21, r11)

    Problems arise when cos(y) is close to zero, because both of::

       z = atan2(cos(y)*sin(z), cos(y)*cos(z))
       x = atan2(cos(y)*sin(x), cos(x)*cos(y))

    will be close to atan2(0, 0), and highly unstable.

    The ``cy`` fix for numerical instability below is from: *Graphics
    Gems IV*, Paul Heckbert (editor), Academic Press, 1994, ISBN:
    0123361559.  Specifically it comes from EulerAngles.c by Ken
    Shoemake, and deals with the case where cos(y) is close to zero:

    See: http://www.graphicsgems.org/

    The code appears to be licensed (from the website) as "can be used
    without restrictions".
    '''
    M = np.asarray(M)
    if cy_thresh is None:
        try:
            cy_thresh = np.finfo(M.dtype).eps * 4
        except ValueError:
            cy_thresh = _FLOAT_EPS_4
    r11, r12, r13, r21, r22, r23, r31, r32, r33 = M.flat
    if seq == 'zyx':
        # r33 = cos(x)*cos(y), r23 = -cos(y)*sin(x)  =>  cy = |cos(y)|
        cy = math.sqrt(r33*r33 + r23*r23)
        if cy > cy_thresh:  # cos(y) not close to zero, standard form
            z = math.atan2(-r12, r11)  # atan2(cos(y)*sin(z), cos(y)*cos(z))
            y = math.atan2(r13, cy)    # atan2(sin(y), cy)
            x = math.atan2(-r23, r33)  # atan2(cos(y)*sin(x), cos(x)*cos(y))
        else:  # cos(y) (close to) zero, so x -> 0.0 (see above)
            # so r21 -> sin(z), r22 -> cos(z) and
            z = math.atan2(r21, r22)
            y = math.atan2(r13, cy)  # atan2(sin(y), cy)
            x = 0.0
    elif seq == 'xyz':
        # For M = Rz . Ry . Rx: r31 = -sin(y), r32 = cos(y)*sin(x),
        # r33 = cos(y)*cos(x)  =>  cy = |cos(y)|.
        cy = math.sqrt(r33*r33 + r32*r32)
        if cy > cy_thresh:
            y = math.atan2(-r31, cy)
            x = math.atan2(r32, r33)
            z = math.atan2(r21, r11)
        else:
            # Gimbal lock (y = +/-pi/2): only x - z (or x + z) is determined,
            # so z is fixed to 0. With z = 0: r12 = sin(y)*sin(x) and
            # r13 = sin(y)*cos(x).
            z = 0.0
            if r31 < 0:
                y = np.pi/2
                x = math.atan2(r12, r13)
            else:
                y = -np.pi/2
                x = math.atan2(-r12, -r13)
    else:
        raise Exception('Sequence not recognized')
    return z, y, x


def euler2quat(z=0, y=0, x=0, isRadian=True):
    ''' Return quaternion corresponding to these Euler angles

    Uses the z, then y, then x convention above

    Parameters
    ----------
    z : scalar
       Rotation angle around z-axis (performed first)
    y : scalar
       Rotation angle around y-axis
    x : scalar
       Rotation angle around x-axis (performed last)
    isRadian : bool, optional
       If True (default) the angles are in radians; if False they are in
       degrees and are converted to radians before use.

    Returns
    -------
    quat : array shape (4,)
       Quaternion in w, x, y z (real, then vector) format

    Notes
    -----
    We can derive this formula in Sympy using:

    1. Formula giving quaternion corresponding to rotation of theta radians
       about arbitrary axis:
       http://mathworld.wolfram.com/EulerParameters.html
    2. Generated formulae from 1.) for quaternions corresponding to
       theta radians rotations about ``x, y, z`` axes
    3. Apply quaternion multiplication formula -
       http://en.wikipedia.org/wiki/Quaternions#Hamilton_product - to
       formulae from 2.) to give formula for combined rotations.
    '''
    if not isRadian:
        z = ((np.pi)/180.) * z
        y = ((np.pi)/180.) * y
        x = ((np.pi)/180.) * x
    # Quaternion components use half-angles.
    z = z/2.0
    y = y/2.0
    x = x/2.0
    cz = math.cos(z)
    sz = math.sin(z)
    cy = math.cos(y)
    sy = math.sin(y)
    cx = math.cos(x)
    sx = math.sin(x)
    return np.array([
        cx*cy*cz - sx*sy*sz,
        cx*sy*sz + cy*cz*sx,
        cx*cz*sy - sx*cy*sz,
        cx*cy*sz + sx*cz*sy])


def quat2euler(q):
    ''' Return Euler angles corresponding to quaternion `q`

    Parameters
    ----------
    q : 4 element sequence
       w, x, y, z of quaternion

    Returns
    -------
    z : scalar
       Rotation angle in radians around z-axis (performed first)
    y : scalar
       Rotation angle in radians around y-axis
    x : scalar
       Rotation angle in radians around x-axis (performed last)

    Notes
    -----
    It's possible to reduce the amount of calculation a little, by
    combining parts of the ``quat2mat`` and ``mat2euler`` functions, but
    the reduction in computation is small, and the code repetition is
    large.
    '''
    # delayed import to avoid cyclic dependencies
    import nibabel.quaternions as nq
    return mat2euler(nq.quat2mat(q))


def plot_rotmats(rotMats, isInteractive=True):
    '''
    Draw each rotation matrix as a 3-D arrow (axis scaled by angle) from the
    origin.

    rotMats      : array indexed as rotMats[i] -> 3x3 rotation matrix
    isInteractive: selects the 'tkagg' backend if True, 'Agg' otherwise
    '''
    import matplotlib
    matplotlib.use('tkagg' if isInteractive else 'Agg')
    import matplotlib.pyplot as plt

    N = rotMats.shape[0]
    plt.ion()
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    xpos, ypos, zpos = np.zeros((N,1)), np.zeros((N,1)), np.zeros((N,1))
    vx, vy, vz = [], [], []

    for i in range(N):
        theta, v = rotmat_to_angle_axis(rotMats[i])
        v = theta * v
        vx.append(v[0])
        vy.append(v[1])
        vz.append(v[2])

    ax.quiver(xpos, ypos, zpos, vx, vy, vz)
    plt.show()
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_zlim(-1, 1)


def generate_random_rotmats(numMat = 100, thetaRange=np.pi/4, thetaFixed=False):
    '''
    Generate numMat rotation matrices with angles drawn uniformly from
    [0, thetaRange).

    Despite its name, thetaFixed selects how the axis is chosen:
      False - one random axis (np.random.random(3)) shared by all matrices
      True  - a fresh random unit axis for every matrix
    '''
    rotMats = np.zeros((numMat,3,3))

    if not thetaFixed:
        #Randomly generate an axis for rotation matrix
        v = np.random.random(3)
        for i in range(numMat):
            theta = thetaRange * np.random.random()
            rotMats[i] = angle_axis_to_rotmat(theta, v)
    else:
        for i in range(numMat):
            v = np.random.randn(3)
            v = v/linalg.norm(v)
            theta = thetaRange * np.random.random()
            rotMats[i] = angle_axis_to_rotmat(theta, v)

    return rotMats


def test_clustering():
    '''
    For testing clustering:
    Randomly generate some data, cluster it and save it to a .mat file
    (test_clustering.mat). Using matlab I will then visualize it.
    Visualizing in python is being a pain.
    '''
    N = 1000
    nCl = 3

    #Generate the data using nCl different axes.
    dat = np.zeros((N,3,3))
    idx = np.linspace(0,N,nCl+1).astype('int')
    for i in range(nCl):
        dat[idx[i]:idx[i+1]] = generate_random_rotmats(idx[i+1]-idx[i],thetaFixed=True)

    assgn, centersMat = cluster_rotmats(dat,nCl)

    # Store each rotation as axis * angle for visualization.
    points = np.zeros((N,3))
    for i in range(N):
        theta, points[i] = rotmat_to_angle_axis(dat[i])
        points[i] = theta*points[i]

    centers = np.zeros((nCl,3))
    for i in range(nCl):
        theta, centers[i] = rotmat_to_angle_axis(centersMat[i])
        centers[i] = theta*centers[i]

    sio.savemat('test_clustering.mat',{'assgn':assgn,'centers':centers,'points':points})


# --------------------------------------------------------------------------------
# /test_bench/__init__.py
# --------------------------------------------------------------------------------
# The dump lists only this link for the file:
# https://raw.githubusercontent.com/pulkitag/pycaffe-utils/102578eb28d6ce96a431d65858102e062d05ca6f/test_bench/__init__.py

# --------------------------------------------------------------------------------
# /test_bench/test_mysolver.py
# --------------------------------------------------------------------------------
import my_pycaffe as mp
import my_exp_config as mec

# NOTE(review): cfg, dbFile, kwargs, mpu, dArgs, get_data_prms,
# process_net_prms, make_net_def and isRun are neither defined nor imported in
# this file, so it cannot run as written; it appears to be an unfinished stub.
REAL_PATH = cfg.REAL_PATH
DEF_DB = cfg.DEF_DB % ('default', '%s')


#See if the solver is able to load a pretrained net automatically
def test_solver_load_pretrain():
    #Get net prms
    nPrms = mec.get_default_net_prms(dbFile, **kwargs)
    del nPrms['expStr']
    nPrms.baseNetDefProto = 'doublefc-v1_window_fc6'
    nPrms = mpu.get_defaults(kwargs, nPrms, False)
    nPrms['expStr'] = mec.get_sql_id(dbFile, dArgs, ignoreKeys=['ncpu'])

    dPrms = get_data_prms()
    nwFn = process_net_prms
    ncpu = 0
    nwArgs = {'ncpu': ncpu, 'lrAbove': None, 'preTrainNet':None}
    solFn = mec.get_default_solver_prms
    solArgs = {'dbFile': DEF_DB % 'sol', 'clip_gradients': 10}
    cPrms = mec.get_caffe_prms(nwFn=nwFn, nwPrms=nwArgs,
                               solFn=solFn, solPrms=solArgs)
    exp = mec.CaffeSolverExperiment(dPrms, cPrms,
                                    netDefFn=make_net_def, isLog=True)
    if isRun:
        exp.make()
        exp.run()
    return exp


# --------------------------------------------------------------------------------
# /test_bench/test_sql.py
# --------------------------------------------------------------------------------
import my_sqlite as msq
import sys
import subprocess
from easydict import EasyDict as edict


def prms1():
    p = edict()
    p.a = None
    p.b = 1
    return p


def prms2():
    p = prms1()
    p.c = None
    p.d = 'check'
    return p


def prms3():
    p = prms2()
    p.e = 34.56
    p.f = None
    return p


def test1():
    # Fetching the superset p3 must not register its subsets p1/p2.
    dbFile = 'test_data/test-sql.sqlite'
    db = msq.SqDb(dbFile)
    p1, p2, p3 = prms1(), prms2(), prms3()
    db.fetch(p3)
    idx = db.get_id(p1)
    assert idx is None
    idx = db.get_id(p2)
    assert idx is None
    p3Idx = db.get_id(p3)
    assert p3Idx is not None
    db.fetch(p1)
    idx = db.get_id(p3)
    assert p3Idx==idx
    idx = db.get_id(p2)
    assert idx is None
    idx = db.get_id(p1)
    assert idx is not None
    subprocess.check_call(['rm %s' % dbFile],shell=True)


def test2():
    # Ids must survive closing and reopening the database.
    dbFile = 'test_data/test-sql.sqlite'
    db = msq.SqDb(dbFile)
    p1, p2, p3 = prms1(), prms2(), prms3()
    db.fetch(p2)
    idx = db.get_id(p1)
    assert idx is None
    p2Idx = db.get_id(p2)
    assert p2Idx is not None
    p3Idx = db.get_id(p3)
    assert p3Idx is None
    db.close()
    del db
    db = msq.SqDb(dbFile)
    db.fetch(p2)
    db.fetch(p1)
    idx = db.get_id(p2)
    assert idx == p2Idx
    idx = db.get_id(p3)
    assert idx is None
    idx = db.get_id(p1)
    assert idx is not None
    db.fetch(p1)
    print (db._get({}))
    subprocess.check_call(['rm %s' % dbFile],shell=True)
# --------------------------------------------------------------------------------
# /vis_utils.py
# --------------------------------------------------------------------------------
## @package vis_utils
# Miscellaneous Functions for visualizations
#
import numpy as np
import scipy.misc as scm
import matplotlib.pyplot as plt
import copy
import os
import pdb
import time
try:
    import caffe
    import my_pycaffe_utils as mpu
    import my_pycaffe as mp
    import my_pycaffe_io as mpio
except Exception:
    # Caffe is optional: the plotting helpers below do not need it.
    print ('CAFFE PACKAGES CANNOT BE LOADED')
import scipy
import scipy.optimize
from matplotlib import gridspec

TMP_DATA_DIR = '/data1/pulkitag/others/caffe_tmp_data/'


##
#Plot n images
def plot_n_ims(ims, fig=None, titleStr='', figTitle='',
               axTitles = None, subPlotShape=None,
               isBlobFormat=False, chSwap=None, trOrder=None,
               showType=None):
    '''
    Plot a list of images in a grid of subplots on one figure.

    ims         : list of images (copied; the caller's arrays are not modified)
    fig         : figure to draw into (created if None); it is cleared first
    titleStr    : unused; kept for backward compatibility
    figTitle    : figure-level title, skipped when empty
    axTitles    : optional per-image titles
    subPlotShape: (rows, cols) of the grid; defaults to the smallest square
                  grid that holds all images
    isBlobFormat: Caffe stores images as ch x h x w
                  True - convert the images into h x w x ch format
    chSwap      : optional channel order applied after any transpose
                  (e.g. [2, 1, 0] to swap BGR and RGB)
    trOrder     : If certain transpose order of channels is to be used
                  overrides isBlobFormat
    showType    : per-image list of 'imshow' or 'matshow' (by default imshow)

    Returns the list of subplot axes.
    '''
    ims = list(copy.deepcopy(ims))
    if trOrder is not None:
        ims = [im.transpose(trOrder) for im in ims]
    elif isBlobFormat:
        ims = [im.transpose((1, 2, 0)) for im in ims]
    if chSwap is not None:
        ims = [im[:, :, chSwap] for im in ims]
    plt.ion()
    if fig is None:
        fig = plt.figure()
    plt.figure(fig.number)
    plt.clf()
    if subPlotShape is None:
        # add_subplot needs integer grid dimensions; np.ceil returns a float.
        N = int(np.ceil(np.sqrt(len(ims))))
        subPlotShape = (N, N)
    ax = []
    for i in range(len(ims)):
        shp = tuple(subPlotShape) + (i + 1,)
        aa = fig.add_subplot(*shp)
        aa.autoscale(False)
        ax.append(aa)

    if showType is None:
        showType = ['imshow'] * len(ims)
    else:
        assert len(showType) == len(ims)

    for i, im in enumerate(ims):
        ax[i].set_ylim(im.shape[0], 0)
        ax[i].set_xlim(0, im.shape[1])
        if showType[i] == 'imshow':
            ax[i].imshow(im.astype(np.uint8))
        elif showType[i] == 'matshow':
            res = ax[i].matshow(im)
            plt.colorbar(res, ax=ax[i])
        ax[i].axis('off')
        if axTitles is not None:
            ax[i].set_title(axTitles[i])
    if len(figTitle) > 0:
        fig.suptitle(figTitle)
    plt.show()
    return ax


def plot_pairs(im1, im2, **kwargs):
    '''
    Plot two images side by side; kwargs are forwarded to plot_n_ims.
    '''
    return plot_n_ims([im1, im2], subPlotShape=(1,2), **kwargs)


##
#Plot pairs of images from an iterator_fun
def plot_pairs_iterfun(ifun, **kwargs):
    '''
    ifun : iteration function; each call returns a pair (im1, im2)
    kwargs: look at input arguments for plot_pairs

    Press Enter for the next pair, or type 'q' to stop.
    '''
    plt.ion()
    fig = plt.figure()
    pltFlag = True
    while pltFlag:
        im1, im2 = ifun()
        plot_pairs(im1, im2, fig=fig, **kwargs)
        ip = raw_input('Press Enter for next pair')
        if ip == 'q':
            pltFlag = False


class MyAnimation(object):
    '''
    Show frames produced by a callback in one matplotlib image axes, redrawing
    only that axes (blitting) for each frame.
    '''
    def __init__(self, vis_func, frames=100, fps=20, height=200, width=200, fargs=None):
        '''
        vis_func     : called as vis_func(frameIdx, *args); returns an image, or a
                       tuple (image, is_stop) where a true is_stop ends run() early
        frames       : maximum number of frames shown by run()
        fps          : target frame rate; run() sleeps 1/fps seconds per frame
        height, width: size of the blank placeholder image shown initially
        fargs        : default extra positional arguments for vis_func
        '''
        self.frames = frames
        self.vis_func = vis_func
        self.vis_func_args = [] if fargs is None else fargs
        self.fps = fps
        self.fig, self.ax = plt.subplots(1,1)
        plt.show(block=False)
        self.bg = self.fig.canvas.copy_from_bbox(self.ax.bbox)
        im = np.zeros((height, width, 3)).astype(np.uint8)
        self.image_obj = self.ax.imshow(im)
        self.fig.canvas.draw()

    def __del__(self):
        plt.close(self.fig)

    def run(self, fargs=None):
        '''
        Display up to self.frames frames. Non-empty fargs override the
        arguments given at construction time.
        '''
        if fargs is None or len(fargs) == 0:
            func_args = self.vis_func_args
        else:
            func_args = fargs
        time_diff = 1.0 / self.fps
        for i in range(self.frames):
            op = self.vis_func(i, *func_args)
            if type(op) == tuple:
                im, is_stop = op
            else:
                im = op
                is_stop = False
            self._display(im)
            time.sleep(time_diff)
            if is_stop:
                break

    def _display(self, pixels):
        # Restore the saved background and redraw only the image artist.
        self.image_obj.set_data(pixels)
        self.fig.canvas.restore_region(self.bg)
        self.ax.draw_artist(self.image_obj)
        self.fig.canvas.blit(self.ax.bbox)


def draw_square_on_im(im, sq, width=4, col='w'):
    '''
    Draw the outline of an axis-aligned box into `im` in place and return it.

    im   : image indexed as im[row, column, channel] (3 channels assumed for
           the 'w' and 'r' colours -- TODO confirm for other inputs)
    sq   : (x1, y1, x2, y2); x indexes columns and y indexes rows
    width: edge thickness in pixels, centred on each edge
    col  : 'w' (white), 'r' (red) or a value assignable to im[..., :]
    '''
    x1, y1, x2, y2 = sq
    h = im.shape[0]
    w = im.shape[1]
    if col == 'w':
        col = (255 * np.ones((1,1,3))).astype(np.uint8)
    elif col == 'r':
        col = np.zeros((1,1,3))
        col[0,0,0] = 255
        col = col.astype(np.uint8)
    #Top Line
    im[max(0,int(y1-width/2)):min(h, y1+int(width/2)),x1:x2,:] = col
    #Bottom line
    im[max(0,int(y2-width/2)):min(h, y2+int(width/2)),x1:x2,:] = col
    #Left line (column range is clamped to the image width)
    im[y1:y2, max(0,int(x1-width/2)):min(w, x1+int(width/2))] = col
    #Right line
    im[y1:y2, max(0,int(x2-width/2)):min(w, x2+int(width/2))] = col
    return im


##
# Visualize GenericWindowDataLayer file
def vis_generic_window_data(protoDef, numLabels, layerName='window_data', phase='TEST',
                            maxVis=100):
    '''
    Step through batches of a generic_window_data layer and show each sample
    as a pair of images with its labels in the figure title. Press Enter to
    advance.

    protoDef : an mpu.ProtoDef, or an argument accepted by mpu.ProtoDef
    numLabels: The number of labels.
    layerName: The name of the generic_window_data layer
    phase    : 'TRAIN' builds the net in TRAIN mode, anything else in TEST
    maxVis   : number of batches to step through

    Side effect: writes a temporary prototxt into TMP_DATA_DIR.
    '''
    #Just write the data part of the file.
    if not isinstance(protoDef, mpu.ProtoDef):
        protoDef = mpu.ProtoDef(protoDef)
    protoDef.del_all_layers_above(layerName)
    randInt = np.random.randint(int(1e+10))
    outProto = os.path.join(TMP_DATA_DIR, 'gn_window_%d.prototxt' % randInt)
    protoDef.write(outProto)
    #Extract the name of the data and the label blobs.
    # [1:-1] presumably strips the quotes around prototxt string values.
    dataName = protoDef.get_layer_property(layerName, 'top', propNum=0)[1:-1]
    labelName = protoDef.get_layer_property(layerName, 'top', propNum=1)[1:-1]
    crpSize = int(protoDef.get_layer_property(layerName, ['crop_size']))
    mnFile = protoDef.get_layer_property(layerName, ['mean_file'])[1:-1]
    mnDat = mpio.read_mean(mnFile)
    ch,nr,nc = mnDat.shape
    # Centre-crop the mean to the network crop size.
    xMn = int((nr - crpSize)/2)
    mnDat = mnDat[:,xMn:xMn+crpSize,xMn:xMn+crpSize]
    print(mnDat.shape)

    #Create a network
    if phase=='TRAIN':
        net = caffe.Net(outProto, caffe.TRAIN)
    else:
        net = caffe.Net(outProto, caffe.TEST)

    lblStr = ''.join('lb-%d: %s, ' % (i,'%.2f') for i in range(numLabels))
    figDt = plt.figure()
    plt.ion()
    for i in range(maxVis):
        allDat = net.forward([dataName,labelName])
        imData = allDat[dataName] + mnDat
        lblDat = allDat[labelName]
        batchSz = imData.shape[0]
        for b in range(batchSz):
            #Plot network data: channels 0-2 and 3-5, BGR -> RGB.
            im1 = imData[b,0:3].transpose((1,2,0))
            im2 = imData[b,3:6].transpose((1,2,0))
            im1 = im1[:,:,[2,1,0]]
            im2 = im2[:,:,[2,1,0]]
            lb = lblDat[b].squeeze()
            lbStr = lblStr % tuple(lb)
            # plot_pairs only takes the two images positionally.
            plot_pairs(im1, im2, fig=figDt, figTitle=lbStr)
            raw_input()


def rec_fun_grad(x, myNet, blobDat, blobLbl, shp, lamda):
    '''
    Objective for fmin_l_bfgs_b. Consider one batch a time.

    x      : flattened image batch, reshaped to `shp` for the network
    myNet  : object whose net_ attribute wraps the reconstruction net
    blobDat, blobLbl: arrays passed to set_input_arrays
    lamda  : weight of the L2 penalty 0.5 * lamda * ||x||^2

    Returns (loss, flattened gradient with respect to x).
    '''
    #Put the data
    myNet.net_.net.set_input_arrays(blobDat, blobLbl)
    print(shp)
    #Get the Error
    feats, diffs = myNet.net_.forward_backward_all(blobs=['loss'],
                                                   diffs=['data'],data=x.reshape(shp))
    grad = diffs['data'] + lamda * x.reshape(shp)
    batchLoss = feats['loss'][0] + 0.5 * lamda * np.dot(x,x)

    grad = grad.flatten()
    return batchLoss, grad


def reconstruct_optimal_input(exp, modelIter, im, recLayer='conv1',
                              imH=101, imW=101, cropH=101, cropW=101, channels=3,
                              meanFile=None, lamda=1e-8, batchSz=1, **kwargs):
    '''
    Reconstruct an input image from the `recLayer` features of `im` with
    L-BFGS-B, minimising the reconstruction net's loss plus an L2 image
    penalty (see rec_fun_grad).

    exp      : experiment object (deep-copied); must provide expFile_.netDef_,
               files_['netdefRec'] and get_snapshot_name()
    modelIter: snapshot iteration whose weights are used
    im       : one image (3-d) or a batch (4-d); the code below reads the
               channel count from axis 3, so (N, H, W, C) is assumed
               -- TODO confirm
    recLayer : layer whose features are reconstructed
    lamda    : weight of the L2 image penalty
    kwargs   : forwarded to network setup and prototxt generation

    Side effect: writes the reconstruction prototxt to exp.files_['netdefRec'].

    Returns (imRec, imIp): the reconstruction with shape (batchSz, C, H, W)
    and the batched input image that was fed to the network.
    '''
    # Local import: EasyDict is only used as a light attribute holder below.
    from easydict import EasyDict as edict

    exp = copy.deepcopy(exp)
    kwargs['delAbove'] = recLayer

    #Setup the original network
    origNet = mpu.CaffeTest.from_caffe_exp(exp)
    origNet.setup_network(opNames=recLayer, imH=imH, imW=imW, cropH=cropH, cropW=cropW,
                          modelIterations=modelIter, batchSz=batchSz,
                          isAccuracyTest=False, meanFile=meanFile, **kwargs)

    #Shape of the layer to be reconstructed
    blob = origNet.net_.net.blobs[recLayer]
    blobLbl = np.zeros((blob.num, 1, 1, 1)).astype('float32')
    recShape = (blob.num, blob.channels, blob.height, blob.width)

    #Get the initial layer features
    print("Extracting Initial Features")
    if im.ndim == 3:
        imIp = im.reshape((batchSz,) + im.shape)
    else:
        imIp = im
    feats = origNet.net_.forward_all(blobs=[recLayer], data=imIp)
    blobDat = feats[recLayer]

    #Get the net for reconstructions
    recProto = mpu.ProtoDef.recproto_from_proto(exp.expFile_.netDef_, featDims=recShape,
                                                imSz=[[channels, cropH, cropW]],
                                                batchSz=batchSz, **kwargs)
    recProto.write(exp.files_['netdefRec'])
    recModel = exp.get_snapshot_name(modelIter)
    recNet = edict()
    recNet.net_ = mp.MyNet(exp.files_['netdefRec'], recModel, caffe.TRAIN)
    recNet.net_.net.set_force_backward(recLayer)

    #Start the reconstruction from a random image
    ch, h, w = imIp.shape[3], imIp.shape[1], imIp.shape[2]
    imRec = 255*np.random.random((batchSz,ch,h,w)).astype('float32')
    print('%s %s %s' % (imRec.shape, blobDat.shape, blobLbl.shape))
    sol = scipy.optimize.fmin_l_bfgs_b(rec_fun_grad, imRec.flatten(),
                                       args=[recNet, blobDat, blobLbl, imRec.shape, lamda],
                                       maxfun=1000, factr=1e+7, pgtol=1e-07,
                                       iprint=0, disp=1)

    imRec = np.reshape(sol[0], (batchSz, ch, h, w))
    # Ground-truth counterpart of imRec: the batched input fed to the network.
    return imRec, imIp
# --------------------------------------------------------------------------------