├── .gitignore
├── README.md
├── caffemodel_to_t7.lua
└── convert_torch.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Convert caffe to pytorch / torch

+ [x] Convert caffe model to pytorch model
+ [x] Convert caffe model to torch model
+ [x] Convert torch model to pytorch model

* I have tested this on vgg16, and it behaves well on classification tasks. I can't guarantee that it performs well on other tasks (such as object detection and semantic segmentation). You can try it and modify the code according to the error messages. If your caffe model contains components that are not handled yet, you should add the corresponding branches to the code.


## [Install torch](http://torch.ch/docs/getting-started.html#_)

## [Install loadcaffe](https://github.com/szagoruyko/loadcaffe)

## Convert caffe to torch
* Change the paths in `caffemodel_to_t7.lua` to your own paths.

* Put the `.prototxt` and `.caffemodel` files in the same folder.

* You will get the `vgg16_torch.t7` file.

```bash
th caffemodel_to_t7.lua
```

## Convert torch to pytorch

```bash
python convert_torch.py -m vgg16_torch.t7
```
Two files will be created: `vgg16_torch.py` and `vgg16_torch.pth`.


## Load the .pth model in python
* Make sure the `vgg16_torch.py` and `vgg16_torch.pth` files are in the same folder as your python workspace.
* `import vgg16_torch` imports the model structure from `vgg16_torch.py`.
* `model.load_state_dict` loads the weights from `vgg16_torch.pth` into the model structure.
```python
import torch

import vgg16_torch

model = vgg16_torch.vgg16_torch
model.load_state_dict(torch.load('vgg16_torch.pth'))
model.eval()
...
```
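* As a quick sanity check, the sketch below runs one forward pass on a random image-shaped array. It is an illustration only (not part of this repo) and assumes the usual Caffe VGG-16 preprocessing: BGR channel order, 0-255 pixel range, per-channel mean subtraction. Match it to whatever your `.prototxt` actually expects.
```python
import numpy as np
import torch
from torch.autograd import Variable

import vgg16_torch

model = vgg16_torch.vgg16_torch
model.load_state_dict(torch.load('vgg16_torch.pth'))
model.eval()

# Stand-in for a real image: HWC, RGB, values in 0-255.
img = np.random.uniform(0, 255, (224, 224, 3)).astype(np.float32)
# Caffe-style preprocessing: RGB -> BGR, subtract the per-channel means.
img = img[:, :, ::-1] - np.array([103.939, 116.779, 123.68], dtype=np.float32)
# HWC -> CHW, add a batch dimension; .copy() makes the array contiguous.
x = Variable(torch.from_numpy(img.transpose(2, 0, 1).copy()).unsqueeze(0))

out = model(x)           # shape (1, 1000): class scores, or probabilities
                         # if your prototxt ends in a SoftMax layer
print(out.data.max(1))   # top-1 value and class index
```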
# Acknowledgement
* The caffe to torch code is modified from [https://github.com/jcjohnson/pytorch-vgg](https://github.com/jcjohnson/pytorch-vgg)

* The torch to pytorch code is borrowed from [https://github.com/clcarwin/convert_torch_to_pytorch](https://github.com/clcarwin/convert_torch_to_pytorch)

--------------------------------------------------------------------------------
/caffemodel_to_t7.lua:
--------------------------------------------------------------------------------
require 'loadcaffe'
require 'xlua'
require 'optim'

-- modify the paths

prototxt = '/home/fanq15/convert_caffe_to_pytorch/deploy_vgg16.prototxt'
binary = '/home/fanq15/convert_caffe_to_pytorch/vgg16.caffemodel'

net = loadcaffe.load(prototxt, binary, 'cudnn')
net = net:float() -- essential; see https://github.com/clcarwin/convert_torch_to_pytorch/issues/8
print(net)

torch.save('/home/fanq15/convert_caffe_to_pytorch/vgg16_torch.t7', net)

--------------------------------------------------------------------------------
/convert_torch.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.serialization import load_lua

import numpy as np
import os
import math
from functools import reduce

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        # result is a list of Variables [Variable1, Variable2, ...]
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        # result is a single Variable
        return reduce(self.lambda_func,self.forward_prepare(input))
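
# Illustration (not part of the original script) of how the Lambda wrappers
# behave, assuming conv_a and conv_b are arbitrary nn.Module instances:
#
#   flatten = Lambda(lambda x: x.view(x.size(0), -1))           # function on input
#   fanout  = LambdaMap(lambda x: x, conv_a, conv_b)            # -> [conv_a(x), conv_b(x)]
#   summed  = LambdaReduce(lambda x, y: x + y, conv_a, conv_b)  # -> conv_a(x) + conv_b(x)
#
# These mirror torch's ConcatTable / CAddTable containers in the conversion below.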

def copy_param(m,n):
    if m.weight is not None: n.weight.data.copy_(m.weight)
    if m.bias is not None: n.bias.data.copy_(m.bias)
    if hasattr(n,'running_mean'): n.running_mean.copy_(m.running_mean)
    if hasattr(n,'running_var'): n.running_var.copy_(m.running_var)

def add_submodule(seq, *args):
    for n in args:
        seq.add_module(str(len(seq._modules)),n)

def lua_recursive_model(module,seq):
    for m in module.modules:
        name = type(m).__name__
        real = m
        if name == 'TorchObject':
            name = m._typename.replace('cudnn.','')
            m = m._obj

        if name == 'SpatialConvolution':
            if not hasattr(m,'groups'): m.groups=1
            n = nn.Conv2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,bias=(m.bias is not None))
            copy_param(m,n)
            add_submodule(seq,n)
        elif name == 'SpatialBatchNormalization':
            n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
            copy_param(m,n)
            add_submodule(seq,n)
        elif name == 'ReLU':
            n = nn.ReLU()
            add_submodule(seq,n)
        elif name == 'SpatialMaxPooling':
            n = nn.MaxPool2d((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),ceil_mode=m.ceil_mode)
            add_submodule(seq,n)
        elif name == 'SpatialAveragePooling':
            n = nn.AvgPool2d((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),ceil_mode=m.ceil_mode)
            add_submodule(seq,n)
        elif name == 'SpatialUpSamplingNearest':
            n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor)
            add_submodule(seq,n)
        elif name == 'View':
            n = Lambda(lambda x: x.view(x.size(0),-1))
            add_submodule(seq,n)
        elif name == 'Linear':
            # Linear in pytorch only accepts 2D input
            n1 = Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )
            n2 = nn.Linear(m.weight.size(1),m.weight.size(0),bias=(m.bias is not None))
            copy_param(m,n2)
            n = nn.Sequential(n1,n2)
            add_submodule(seq,n)
        elif name == 'Dropout':
            m.inplace = False
            n = nn.Dropout(m.p)
            add_submodule(seq,n)
        elif name == 'SoftMax':
            n = nn.Softmax()
            add_submodule(seq,n)
        elif name == 'Identity':
            n = Lambda(lambda x: x) # do nothing
            add_submodule(seq,n)
        elif name == 'SpatialFullConvolution':
            n = nn.ConvTranspose2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH))
            add_submodule(seq,n)
        elif name == 'SpatialReplicationPadding':
            n = nn.ReplicationPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b))
            add_submodule(seq,n)
        elif name == 'SpatialReflectionPadding':
            n = nn.ReflectionPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b))
            add_submodule(seq,n)
        elif name == 'Copy':
            n = Lambda(lambda x: x) # do nothing
            add_submodule(seq,n)
        elif name == 'Narrow':
            n = Lambda(lambda x,a=(m.dimension,m.index,m.length): x.narrow(*a))
            add_submodule(seq,n)
        elif name == 'SpatialCrossMapLRN':
            lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size,m.alpha,m.beta,m.k)
            n = Lambda(lambda x,lrn=lrn: Variable(lrn.forward(x.data)))
            add_submodule(seq,n)
        elif name == 'Sequential':
            n = nn.Sequential()
            lua_recursive_model(m,n)
            add_submodule(seq,n)
        elif name == 'ConcatTable': # output is a list
            n = LambdaMap(lambda x: x)
            lua_recursive_model(m,n)
            add_submodule(seq,n)
        elif name == 'CAddTable': # input is a list
            n = LambdaReduce(lambda x,y: x+y)
            add_submodule(seq,n)
        elif name == 'Concat':
            dim = m.dimension
            n = LambdaReduce(lambda x,y,dim=dim: torch.cat((x,y),dim))
            lua_recursive_model(m,n)
            add_submodule(seq,n)
        elif name == 'TorchObject':
            print('Not Implement',name,real._typename)
        else:
            print('Not Implement',name)
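
# Sketch (not in the original script) of how lua_recursive_model is driven:
# load a legacy .t7 network and fill an empty nn.Sequential with the
# equivalent PyTorch modules, weights included.
#
#   net = load_lua('vgg16_torch.t7', unknown_classes=True)
#   seq = nn.Sequential()
#   lua_recursive_model(net, seq)   # seq now mirrors the torch model
#   torch.save(seq.state_dict(), 'vgg16_torch.pth')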

def lua_recursive_source(module):
    s = []
    for m in module.modules:
        name = type(m).__name__
        real = m
        if name == 'TorchObject':
            name = m._typename.replace('cudnn.','')
            m = m._obj

        if name == 'SpatialConvolution':
            if not hasattr(m,'groups'): m.groups=1
            s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane,
                m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,m.bias is not None)]
        elif name == 'SpatialBatchNormalization':
            s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
        elif name == 'ReLU':
            s += ['nn.ReLU()']
        elif name == 'SpatialMaxPooling':
            s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),m.ceil_mode)]
        elif name == 'SpatialAveragePooling':
            s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),m.ceil_mode)]
        elif name == 'SpatialUpSamplingNearest':
            s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)]
        elif name == 'View':
            s += ['Lambda(lambda x: x.view(x.size(0),-1)), # View']
        elif name == 'Linear':
            s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
            s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1),m.weight.size(0),(m.bias is not None))
            s += ['nn.Sequential({},{}),#Linear'.format(s1,s2)]
        elif name == 'Dropout':
            s += ['nn.Dropout({})'.format(m.p)]
        elif name == 'SoftMax':
            s += ['nn.Softmax()']
        elif name == 'Identity':
            s += ['Lambda(lambda x: x), # Identity']
        elif name == 'SpatialFullConvolution':
            s += ['nn.ConvTranspose2d({},{},{},{},{})'.format(m.nInputPlane,
                m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH))]
        elif name == 'SpatialReplicationPadding':
            s += ['nn.ReplicationPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))]
        elif name == 'SpatialReflectionPadding':
            s += ['nn.ReflectionPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))]
        elif name == 'Copy':
            s += ['Lambda(lambda x: x), # Copy']
        elif name == 'Narrow':
            s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension,m.index,m.length))]
        elif name == 'SpatialCrossMapLRN':
            lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format((m.size,m.alpha,m.beta,m.k))
            s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(lrn)]
        elif name == 'Sequential':
            s += ['nn.Sequential( # Sequential']
            s += lua_recursive_source(m)
            s += [')']
        elif name == 'ConcatTable':
            s += ['LambdaMap(lambda x: x, # ConcatTable']
            s += lua_recursive_source(m)
            s += [')']
        elif name == 'CAddTable':
            s += ['LambdaReduce(lambda x,y: x+y), # CAddTable']
        elif name == 'Concat':
            dim = m.dimension
            s += ['LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(m.dimension)]
            s += lua_recursive_source(m)
            s += [')']
        else:
            # note: the string must be wrapped in a list, otherwise += would
            # extend s character by character
            s += ['# ' + name + ' Not Implement,\n']
    s = map(lambda x: '\t{}'.format(x),s)
    return s
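
# Example (abridged, my own illustration) of the source lua_recursive_source
# emits for a VGG-style .t7 file, after simplify_source (defined next) strips
# the default arguments:
#
#   vgg16_torch = nn.Sequential( # Sequential
#       nn.Conv2d(3,64,(3, 3),(1, 1),(1, 1)),
#       nn.ReLU(),
#       ...
#       nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(25088,4096)), # Linear
#       ...
#   )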

def simplify_source(s):
    s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',bias=True),#Conv2d',')'),s)
    s = map(lambda x: x.replace('),#Conv2d',')'),s)
    s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d',')'),s)
    s = map(lambda x: x.replace('),#BatchNorm2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace('),#MaxPool2d',')'),s)
    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d',')'),s)
    s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d',')'),s)
    s = map(lambda x: x.replace(',bias=True)),#Linear',')), # Linear'),s)
    s = map(lambda x: x.replace(')),#Linear',')), # Linear'),s)

    s = map(lambda x: '{},\n'.format(x),s)
    s = map(lambda x: x[1:],s)
    s = reduce(lambda x,y: x+y, s)
    return s

def torch_to_pytorch(t7_filename,outputname=None):
    model = load_lua(t7_filename,unknown_classes=True)
    if type(model).__name__=='hashable_uniq_dict': model=model.model
    model.gradInput = None
    slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
from torch.autograd import Variable
from functools import reduce

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))
'''
    varname = t7_filename.replace('.t7','').replace('.','_').replace('-','_')
    s = '{}\n\n{} = {}'.format(header,varname,s[:-2])

    if outputname is None: outputname=varname
    with open(outputname+'.py', "w") as pyfile:
        pyfile.write(s)

    n = nn.Sequential()
    lua_recursive_model(model,n)
    torch.save(n.state_dict(),outputname+'.pth')


parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch')
parser.add_argument('--model','-m', type=str, required=True,
                    help='torch model file in t7 format')
parser.add_argument('--output', '-o', type=str, default=None,
                    help='output file name prefix, xxx.py xxx.pth')
args = parser.parse_args()

torch_to_pytorch(args.model,args.output)
--------------------------------------------------------------------------------