├── CODE_LICENSE.txt
├── LSUV.py
├── README.md
├── example.py
└── imgs
    ├── cafe.png
    ├── cat.png
    ├── dum.png
    ├── face.png
    ├── fox.png
    ├── girl.png
    ├── grand.png
    ├── mag.png
    ├── pkk.png
    └── vin.png

/CODE_LICENSE.txt:
--------------------------------------------------------------------------------
Copyright (C) 2017, Dmytro Mishkin
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the
   distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/LSUV.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import numpy as np
import torch
import torch.nn.init
import torch.nn as nn

# Global state shared between the forward hooks and LSUVinit.
gg = {}
gg['hook_position'] = 0
gg['total_fc_conv_layers'] = 0
gg['done_counter'] = -1
gg['hook'] = None
gg['act_dict'] = {}
gg['counter_to_apply_correction'] = 0
gg['correction_needed'] = False
gg['current_coef'] = 1.0
gg['current_bias'] = 0.0


# Orthonormal init code is taken from Lasagne:
# https://github.com/Lasagne/Lasagne/blob/master/lasagne/init.py
def svd_orthonormal(w):
    shape = w.shape
    if len(shape) < 2:
        raise RuntimeError("Only shapes of length 2 or more are supported.")
    flat_shape = (shape[0], np.prod(shape[1:]))
    a = np.random.normal(0.0, 1.0, flat_shape)
    u, _, v = np.linalg.svd(a, full_matrices=False)
    q = u if u.shape == flat_shape else v
    q = q.reshape(shape)
    return q.astype(np.float32)


def store_activations(self, input, output):
    # Forward hook: remember the activations of the layer currently being processed.
    gg['act_dict'] = output.data.cpu().numpy()
    return


def add_current_hook(m):
    # Attach the hook to the first Conv2d/Linear layer that has not been processed yet.
    if gg['hook'] is not None:
        return
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        if gg['hook_position'] > gg['done_counter']:
            gg['hook'] = m.register_forward_hook(store_activations)
        else:
            # Already done, skip this layer.
            gg['hook_position'] += 1
    return


def count_conv_fc_layers(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        gg['total_fc_conv_layers'] += 1
    return


def remove_hooks(hooks):
    for h in hooks:
        h.remove()
    return


def orthogonal_weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        w_ortho = svd_orthonormal(m.weight.data.cpu().numpy())
        m.weight.data = torch.from_numpy(w_ortho)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    return


def apply_weights_correction(m):
    # Rescale the weights (and shift the bias) of the layer the hook is attached to.
    if gg['hook'] is None:
        return
    if not gg['correction_needed']:
        return
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        if gg['counter_to_apply_correction'] < gg['hook_position']:
            gg['counter_to_apply_correction'] += 1
        else:
            if hasattr(m, 'weight'):
                m.weight.data *= float(gg['current_coef'])
                gg['correction_needed'] = False
            if m.bias is not None:
                m.bias.data += float(gg['current_bias'])
    return


def LSUVinit(model, data, needed_std=1.0, std_tol=0.1, max_attempts=10,
             do_orthonorm=True, needed_mean=0., cuda=False, verbose=True):
    """Layer-sequential unit-variance init: rescale every Conv2d/Linear layer until
    its activations on `data` have roughly the requested mean and std."""
    # The device is taken from the input batch; the `cuda` argument is overridden.
    cuda = data.is_cuda
    gg['total_fc_conv_layers'] = 0
    gg['done_counter'] = 0
    gg['hook_position'] = 0
    gg['hook'] = None
    model.eval()
    if cuda:
        model = model.cuda()
        data = data.cuda()
    else:
        model = model.cpu()
        data = data.cpu()
    if verbose: print('Starting LSUV')
    model.apply(count_conv_fc_layers)
    if verbose: print('Total layers to process:', gg['total_fc_conv_layers'])
    with torch.no_grad():
        if do_orthonorm:
            model.apply(orthogonal_weights_init)
            if verbose: print('Orthonorm done')
        if cuda:
            # orthogonal_weights_init assigns CPU tensors, so move the model back.
            model = model.cuda()
        for layer_idx in range(gg['total_fc_conv_layers']):
            if verbose: print(layer_idx)
            model.apply(add_current_hook)
            out = model(data)
            current_std = gg['act_dict'].std()
            current_mean = gg['act_dict'].mean()
            if verbose: print('std at layer', layer_idx, '=', current_std)
            attempts = 0
            while np.abs(current_std - needed_std) > std_tol:
                gg['current_coef'] = needed_std / (current_std + 1e-8)
                gg['current_bias'] = needed_mean - current_mean * gg['current_coef']
                gg['correction_needed'] = True
                model.apply(apply_weights_correction)
                if cuda:
                    model = model.cuda()
                out = model(data)
                current_std = gg['act_dict'].std()
                current_mean = gg['act_dict'].mean()
                if verbose: print('std at layer', layer_idx, '=', current_std, 'mean =', current_mean)
                attempts += 1
                if attempts > max_attempts:
                    if verbose: print('Cannot converge in', max_attempts, 'iterations')
                    break
            if gg['hook'] is not None:
                gg['hook'].remove()
            gg['done_counter'] += 1
            gg['counter_to_apply_correction'] = 0
            gg['hook_position'] = 0
            gg['hook'] = None
            if verbose: print('finish at layer', layer_idx)
    if verbose: print('LSUV init done!')
    if not cuda:
        model = model.cpu()
    return model
--------------------------------------------------------------------------------
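For intuition, the sketch below isolates the per-layer step that `LSUVinit` performs through its forward hooks: run a forward pass, measure the standard deviation of one layer's output, and rescale that layer's weights by `needed_std / current_std` until the deviation falls within `std_tol`. This is an illustrative toy on a single `nn.Linear` layer, not part of LSUV.py; the names `layer` and `x` and the sizes used are hypothetical.

```python
import torch
import torch.nn as nn

# Toy single-layer version of the LSUV step (illustrative only; LSUVinit does
# this for every Conv2d/Linear module in the model via forward hooks).
torch.manual_seed(0)
layer = nn.Linear(256, 128)
x = torch.randn(64, 256)                      # a batch of inputs to this layer

needed_std, std_tol, max_attempts = 1.0, 0.1, 10
with torch.no_grad():
    for attempt in range(max_attempts):
        current_std = layer(x).std().item()   # std of this layer's activations
        if abs(current_std - needed_std) < std_tol:
            break
        # Scale the weights so the output variance moves towards the target.
        layer.weight.data *= needed_std / (current_std + 1e-8)
    print('final activation std:', layer(x).std().item())
```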
/README.md:
--------------------------------------------------------------------------------
# Layer-sequential unit-variance (LSUV) initialization for PyTorch

# NEW repo: [ducha-aiki/lsuv](https://github.com/ducha-aiki/lsuv)

```
pip install lsuv
```

This is sample code for LSUV initialization (with optional orthonormal pre-initialization), implemented as a Python script for the PyTorch framework.

Usage:

    from LSUV import LSUVinit
    ...
    model = LSUVinit(model, data)

See a detailed example in [example.py](example.py).

LSUV initialization is described in:

Mishkin, D. and Matas, J. (2015). All you need is a good init. ICLR 2016. [arXiv:1511.06422](http://arxiv.org/abs/1511.06422).

Original Caffe implementation: [https://github.com/ducha-aiki/LSUVinit](https://github.com/ducha-aiki/LSUVinit)

Torch re-implementation: [https://github.com/yobibyte/torch-lsuv](https://github.com/yobibyte/torch-lsuv)

Keras implementation: [https://github.com/ducha-aiki/LSUV-keras](https://github.com/ducha-aiki/LSUV-keras)

**New!** Thinc re-implementation: [LSUV-thinc](https://github.com/explosion/thinc/blob/e653dd3dfe91f8572e2001c8943dbd9b9401768b/thinc/neural/_lsuv.py)
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
import torch
import torchvision.models as models
import numpy as np
from LSUV import LSUVinit
import sys
import os

# Path to a local OpenCV build; adjust or remove if cv2 is installed system-wide.
sys.path.insert(0, '/home/ubuntu/dev/opencv-3.1/build/lib')
import cv2

# Load and resize the sample images into an (N, 3, 224, 224) batch.
images_to_process = []
for img_fname in os.listdir('imgs'):
    img = cv2.imread('imgs/' + img_fname)
    if img is not None:
        print(img.shape)
        images_to_process.append(np.transpose(cv2.resize(img, (224, 224)), (2, 0, 1)))

data = np.array(images_to_process).astype(np.float32)
data = torch.from_numpy(data)

model = models.densenet121(pretrained=False)
model = LSUVinit(model, data, needed_std=1.0, std_tol=0.1, max_attempts=10,
                 needed_mean=0., do_orthonorm=False)
--------------------------------------------------------------------------------
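If the sample images (or OpenCV) are not available, the same call can be exercised with a random batch of the shape that example.py builds. This is a hypothetical minimal variant, not part of the repository; the batch size of 8 is arbitrary.

```python
import torch
import torchvision.models as models
from LSUV import LSUVinit

# Random stand-in for the (N, 3, 224, 224) image batch assembled in example.py.
data = torch.randn(8, 3, 224, 224)
model = models.densenet121(pretrained=False)
model = LSUVinit(model, data, needed_std=1.0, std_tol=0.1, max_attempts=10,
                 needed_mean=0., do_orthonorm=False)
```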
/imgs/cafe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ducha-aiki/LSUV-pytorch/60e9fee5f15e0d4edd1f4d9f0587c7149c1974ed/imgs/cafe.png
--------------------------------------------------------------------------------
/imgs/cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ducha-aiki/LSUV-pytorch/60e9fee5f15e0d4edd1f4d9f0587c7149c1974ed/imgs/cat.png
--------------------------------------------------------------------------------
/imgs/dum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ducha-aiki/LSUV-pytorch/60e9fee5f15e0d4edd1f4d9f0587c7149c1974ed/imgs/dum.png
--------------------------------------------------------------------------------
/imgs/face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ducha-aiki/LSUV-pytorch/60e9fee5f15e0d4edd1f4d9f0587c7149c1974ed/imgs/face.png
--------------------------------------------------------------------------------
/imgs/fox.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ducha-aiki/LSUV-pytorch/60e9fee5f15e0d4edd1f4d9f0587c7149c1974ed/imgs/fox.png
--------------------------------------------------------------------------------
/imgs/girl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ducha-aiki/LSUV-pytorch/60e9fee5f15e0d4edd1f4d9f0587c7149c1974ed/imgs/girl.png
--------------------------------------------------------------------------------
/imgs/grand.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ducha-aiki/LSUV-pytorch/60e9fee5f15e0d4edd1f4d9f0587c7149c1974ed/imgs/grand.png
--------------------------------------------------------------------------------
/imgs/mag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ducha-aiki/LSUV-pytorch/60e9fee5f15e0d4edd1f4d9f0587c7149c1974ed/imgs/mag.png
--------------------------------------------------------------------------------
/imgs/pkk.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ducha-aiki/LSUV-pytorch/60e9fee5f15e0d4edd1f4d9f0587c7149c1974ed/imgs/pkk.png
--------------------------------------------------------------------------------
/imgs/vin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ducha-aiki/LSUV-pytorch/60e9fee5f15e0d4edd1f4d9f0587c7149c1974ed/imgs/vin.png
--------------------------------------------------------------------------------