├── LICENSE ├── README.md ├── img ├── model.jpg └── readme.md └── svf.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 zechao-li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SVF-pytorch 2 | 3 | This is the official pytorch implementation of [Singular Value Fine-tuning: Few-shot Segmentation requires Few-parameters Fine-tuning](https://arxiv.org/pdf/2206.06122.pdf) in NeurIPS 2022. 4 | 5 | Authors: Yanpeng Sun^, Qiang Chen^, Xiangyu He^, Jian Wang, Haocheng Feng, Junyu Han, Errui Ding, Jian Cheng, [Zechao Li](https://zechao-li.github.io/), Jingdong Wang 6 | 7 |
8 |
9 |
10 | 11 | 12 | ## Usage 13 | 14 | This tool can not only decompose and rebuild the model, but also decompose and rebuild a layer individually. 15 | 16 | ```python 17 | from . import svf 18 | import torchvision.models as models 19 | 20 | model = models.resnet18(pretrained=True) 21 | model = svf.resolver(model, 22 | global_low_rank_ratio=1.0, # no need to change 23 | skip_1x1=False, # we will decompose 1x1 conv layers 24 | skip_3x3=False # we will decompose 3x3 conv layers 25 | ) 26 | ``` 27 | 28 | 29 | ## Pipeline: 30 | 31 | We use a full-rank model as an input, then factorize the original model and return a low-rank model. 32 | 33 | - Previous Convolution Layer 34 | 35 | ```python 36 | conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 37 | ``` 38 | 39 | - Replaced by 40 | 41 | ```python 42 | class SVD_Conv2d(nn.Module): 43 | """Kernel Number first SVD Conv2d 44 | """ 45 | 46 | def __init__(self, in_channels, out_channels, kernel_size, 47 | stride, padding, dilation, groups, bias, 48 | padding_mode='zeros', device=None, dtype=None, 49 | rank=1): 50 | super(SVD_Conv2d, self).__init__() 51 | factory_kwargs = {'device': device, 'dtype': dtype} 52 | self.conv_U = nn.Conv2d(rank, out_channels, (1, 1), (1, 1), 0, (1, 1), 1, bias) 53 | self.conv_V = nn.Conv2d(in_channels, rank, kernel_size, stride, padding, dilation, groups, False) 54 | self.vector_S = nn.Parameter(torch.empty((1, rank, 1, 1), **factory_kwargs)) 55 | 56 | def forward(self, x): 57 | x = self.conv_V(x) 58 | x = x.mul(self.vector_S) 59 | output = self.conv_U(x) 60 | return output 61 | 62 | ``` 63 | ## Usage in FSS model: 64 | First, decompose and rebuild all layers in the backbone. 
65 | 66 | ```python 67 | if args.svf: 68 | self.layer0 = svf.resolver(self.layer0, global_low_rank_ratio=1.0, skip_1x1=False, skip_3x3=False) 69 | self.layer1 = svf.resolver(self.layer1, global_low_rank_ratio=1.0, skip_1x1=False, skip_3x3=False) 70 | self.layer2 = svf.resolver(self.layer2, global_low_rank_ratio=1.0, skip_1x1=False, skip_3x3=False) 71 | self.layer3 = svf.resolver(self.layer3, global_low_rank_ratio=1.0, skip_1x1=False, skip_3x3=False) 72 | self.layer4 = svf.resolver(self.layer4, global_low_rank_ratio=1.0, skip_1x1=False, skip_3x3=False) 73 | ``` 74 | Then, set up the new model freezing strategy. 75 | ```python 76 | def svf_modules(self, model): 77 | for param in model.layer0.parameters(): 78 | param.requires_grad = False 79 | for param in model.layer1.parameters(): 80 | param.requires_grad = False 81 | for name, param in model.layer2.named_parameters(): 82 | param.requires_grad = False 83 | if 'vector_S' in name: 84 | param.requires_grad = True 85 | for name, param in model.layer3.named_parameters(): 86 | param.requires_grad = False 87 | if 'vector_S' in name: 88 | param.requires_grad = True 89 | for name, param in model.layer4.named_parameters(): 90 | param.requires_grad = False 91 | if 'vector_S' in name: 92 | param.requires_grad = True 93 | ``` 94 | -------------------------------------------------------------------------------- /img/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zechao-li/SVF-pytorch/2801557df10606e42653b19866f3361a1baf0ba6/img/model.jpg -------------------------------------------------------------------------------- /img/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /svf.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import inspect 3 | from math import floor 4 | 
import torch
import torch.nn as nn


def d_nsvd(matrix, rank=1):
    """Return the leading ``rank`` SVD factors of a 2-D ``matrix``.

    Args:
        matrix: 2-D tensor to factorise.
        rank: Number of leading singular triplets to keep.

    Returns:
        Tuple ``(U, S, V)`` where ``U`` holds the first ``rank`` left singular
        vectors as columns, ``S`` the first ``rank`` singular values and ``V``
        the first ``rank`` right singular vectors as *rows* (already
        transposed). The singular values are not folded into ``U`` or ``V``.
    """
    U, S, V = torch.svd(matrix)
    return U[:, :rank], S[:rank], torch.transpose(V[:, :rank], 0, 1)


class SVD_Conv2d(nn.Module):
    """Kernel Number first SVD Conv2d.

    Computes ``conv_U(conv_V(x) * vector_S)``: ``conv_V`` keeps the spatial
    kernel and maps the input to ``rank`` channels, ``vector_S`` rescales each
    of those channels, and the 1x1 ``conv_U`` maps them to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, groups, bias,
                 padding_mode='zeros', device=None, dtype=None,
                 rank=1):
        super(SVD_Conv2d, self).__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.conv_U = nn.Conv2d(rank, out_channels, (1, 1), (1, 1), 0, (1, 1), 1, bias)
        # The spatial convolution must keep the padding mode of the layer it
        # stands in for; previously the argument was accepted but ignored.
        self.conv_V = nn.Conv2d(in_channels, rank, kernel_size, stride, padding,
                                dilation, groups, False, padding_mode=padding_mode)
        # Ones (instead of uninitialised memory) keep a freshly built module
        # usable before any decomposed singular values are copied in.
        self.vector_S = nn.Parameter(torch.ones((1, rank, 1, 1), **factory_kwargs))

    def forward(self, x):
        x = self.conv_V(x)
        x = x.mul(self.vector_S)
        output = self.conv_U(x)
        return output


class SVD_Linear(nn.Module):
    """Linear layer factorised as ``fc_U(fc_V(x) * vector_S)``."""

    def __init__(self, in_features, out_features, bias, device=None, dtype=None, rank=1):
        super(SVD_Linear, self).__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.fc_V = nn.Linear(in_features, rank, False)
        # Ones (instead of uninitialised memory), see SVD_Conv2d.
        self.vector_S = nn.Parameter(torch.ones((1, rank), **factory_kwargs))
        self.fc_U = nn.Linear(rank, out_features, bias)

    def forward(self, x):
        x = self.fc_V(x)
        x = x.mul(self.vector_S)
        output = self.fc_U(x)
        return output


full2low_mapping_n = {
    nn.Conv2d: SVD_Conv2d,
    nn.Linear: SVD_Linear
}

# Values used for constructor arguments that cannot be read back from the
# instance ``__dict__``. Looked up by exact name so that ``padding_mode`` is
# not swallowed by the ``padding`` entry.
_INIT_ARG_FALLBACKS = {
    'stride': 1,
    'padding': 0,
    'dilation': 1,
    'groups': 1,
    'padding_mode': 'zeros',
}


def _collect_init_kwargs(module):
    """Rebuild the constructor keyword arguments of ``module`` from its state."""
    attrs = vars(module)
    kwargs = {}
    # inspect.signature on the class lists the parameters of ``__init__``.
    for name in inspect.signature(type(module)).parameters:
        if name in attrs:
            kwargs[name] = attrs[name]
        elif 'bias' in name:
            # ``bias`` lives in ``_parameters``, not ``__dict__``; the
            # constructor only needs to know whether one exists.
            kwargs[name] = getattr(module, 'bias', None) is not None
        else:
            kwargs[name] = _INIT_ARG_FALLBACKS.get(name)
    return kwargs


def _load_factors(new_module, old_type, factors, bias_tensor):
    """Copy the ``[U, V, S]`` factors (and optional bias) into ``new_module``."""
    if old_type is nn.Conv2d:
        first, second = new_module.conv_V, new_module.conv_U
    elif old_type is nn.Linear:
        first, second = new_module.fc_V, new_module.fc_U
    else:
        # Unknown layer type: nothing we know how to fill in.
        return
    with torch.no_grad():
        first.weight.copy_(factors[1])
        second.weight.copy_(factors[0])
        new_module.vector_S.copy_(factors[2])
        if bias_tensor is not None:
            second.bias.copy_(bias_tensor)


def replace_fullrank_with_lowrank(model, full2low_mapping=None, layer_rank=None,
                                  lowrank_param_dict=None, module_name=""):
    """Recursively replace original full-rank ops with low-rank ops.

    Args:
        model: Module whose children are rewritten in place.
        full2low_mapping: Maps a full-rank layer type to its low-rank class.
            ``None`` or an empty mapping leaves ``model`` untouched.
        layer_rank: Maps dotted module names to the rank to use; only layers
            named here are replaced.
        lowrank_param_dict: Maps dotted module names to ``[U, V, S]`` factors.
        module_name: Dotted name of ``model`` inside the root model (used to
            build the names of its children during recursion).

    Returns:
        ``model`` itself.

    Raises:
        KeyError: If a ``nn.Conv2d`` / ``nn.Linear`` layer selected by
            ``layer_rank`` has no entry in ``lowrank_param_dict``.
    """
    if not full2low_mapping:
        return model
    layer_rank = {} if layer_rank is None else layer_rank
    lowrank_param_dict = {} if lowrank_param_dict is None else lowrank_param_dict

    # Snapshot the children: entries are re-assigned while walking them.
    for sub_module_name, old_module in list(model._modules.items()):
        if old_module is None:
            continue
        current_module_name = sub_module_name if module_name == "" else \
            module_name + "." + sub_module_name

        if len(old_module._modules) > 0:
            # Container: descend, leaves are handled in the recursive call.
            replace_fullrank_with_lowrank(old_module,
                                          full2low_mapping,
                                          layer_rank,
                                          lowrank_param_dict,
                                          current_module_name)
            continue

        old_type = type(old_module)
        if old_type not in full2low_mapping or current_module_name not in layer_rank:
            continue

        factors = None
        if old_type in (nn.Conv2d, nn.Linear):
            # Fail before touching the model if the decomposition is missing.
            factors = lowrank_param_dict[current_module_name]

        kwargs = _collect_init_kwargs(old_module)
        kwargs['rank'] = layer_rank[current_module_name]
        new_module = full2low_mapping[old_type](**kwargs)
        new_module.train(old_module.training)

        bias_tensor = None
        if kwargs.get('bias'):
            bias_tensor = old_module.bias.data

        if factors is not None:
            # nn.Linear used to be replaced without copying its factors,
            # leaving the new layer with random weights.
            _load_factors(new_module, old_type, factors, bias_tensor)
        setattr(model, sub_module_name, new_module)
    return model


class DatafreeSVD(object):
    """Plan and apply an SVD factorisation of the Conv2d / Linear layers.

    Construction selects the layers and their ranks (``layer_rank``),
    ``decompose_layers`` computes the factors and
    ``reconstruct_lowrank_network`` builds the rewritten copy of the model.
    """

    def __init__(self, model, global_rank_ratio=1.0,
                 excluded_layers=None, customized_layer_rank_ratio=None,
                 skip_1x1=True, skip_3x3=True):
        """Select the layers to decompose and choose a rank for each.

        Args:
            model: Model to analyse. NOTE: it is moved to the CPU in place.
            global_rank_ratio: Fraction of the full rank to keep. A ratio of
                1.0 or more keeps the exact full rank.
            excluded_layers: Dotted module names that must not be touched.
            customized_layer_rank_ratio: Optional per-layer ratios, keyed by
                dotted module name, overriding ``global_rank_ratio``.
            skip_1x1: Skip Linear layers and 1x1 convolutions.
            skip_3x3: Skip convolutions with a spatial kernel.
        """
        # class-independent initialization
        super(DatafreeSVD, self).__init__()
        self.model = model
        self.layer_rank = {}
        # A data-parallel wrapper prefixes every key with "module" and exposes
        # the real network as ``.module``; require both so that a model with
        # no parameters or an unrelated "module..." child does not crash.
        first_key = next(iter(model.state_dict()), "")
        wrapped = getattr(model, "module", None)
        model_data_parallel = str(first_key).startswith('module') and \
            isinstance(wrapped, nn.Module)
        self.model_cpu = (wrapped if model_data_parallel else self.model).to("cpu")
        self.rank_base = 4
        self.global_rank_ratio = global_rank_ratio
        self.excluded_layers = [] if excluded_layers is None else excluded_layers
        self.customized_layer_rank_ratio = {} if customized_layer_rank_ratio is None \
            else customized_layer_rank_ratio
        self.skip_1x1 = skip_1x1
        self.skip_3x3 = skip_3x3

        self.param_lowrank_decomp_dict = {}
        registered_param_op = [nn.Conv2d, nn.Linear]

        # Attribute kept for backward compatibility; it is a one-shot generator.
        self.model_named_modules = self.model_cpu.named_modules()
        for m_name, m in self.model_named_modules:
            if type(m) not in registered_param_op or m_name in self.excluded_layers:
                continue
            if type(m) is nn.Conv2d and m.groups != 1:
                # U * S * V with a grouped first convolution does not reproduce
                # a grouped convolution, so such layers are left untouched.
                continue

            # Squeezing drops singleton dims: a 1x1 conv looks like a matrix,
            # and anything that is neither 2-D nor 4-D afterwards is skipped.
            tensor_shape = m.weight.data.squeeze().shape
            if len(tensor_shape) == 2:
                full_rank = min(tensor_shape[0], tensor_shape[1])
                is_1x1 = True
            elif len(tensor_shape) == 4:
                full_rank = min(
                    tensor_shape[0], tensor_shape[1] * tensor_shape[2] * tensor_shape[3])
                is_1x1 = tensor_shape[2] == 1 and tensor_shape[3] == 1
            else:
                continue

            if is_1x1 and self.skip_1x1:
                continue
            if not is_1x1 and self.skip_3x3:
                continue

            ratio = self.customized_layer_rank_ratio.get(m_name, self.global_rank_ratio)
            if ratio >= 1.0:
                # Rounding down to a multiple of ``rank_base`` would silently
                # truncate layers whose full rank is not such a multiple.
                low_rank = full_rank
            else:
                low_rank = round_to_nearest(full_rank,
                                            ratio=ratio,
                                            base_number=self.rank_base,
                                            allow_rank_eq1=True)

            self.layer_rank[m_name] = low_rank

    def decompose_layers(self):
        """Compute ``[U, V, S]`` for every selected layer.

        Results are stored in ``param_lowrank_decomp_dict`` keyed by module
        name, already reshaped to the weight shapes of the low-rank layers.
        """
        self.model_named_modules = self.model_cpu.named_modules()
        for m_name, m in self.model_named_modules:
            if m_name not in self.layer_rank:
                continue
            rank = self.layer_rank[m_name]
            weights = m.weight.data
            tensor_shape = weights.shape
            if len(tensor_shape) == 1:
                self.layer_rank[m_name] = 1
                continue
            elif len(tensor_shape) == 2:
                U, S, V = d_nsvd(weights, rank)
                self.param_lowrank_decomp_dict[m_name] = [
                    U, V, S.reshape(1, rank)]
            elif len(tensor_shape) == 4:
                # Flatten each output filter into a row before factorising.
                U, S, V = d_nsvd(weights.reshape(tensor_shape[0], -1), rank)
                self.param_lowrank_decomp_dict[m_name] = [
                    U.reshape(tensor_shape[0], rank, 1, 1),
                    V.reshape(rank, tensor_shape[1], tensor_shape[2], tensor_shape[3]),
                    S.reshape(1, rank, 1, 1)
                ]

    def reconstruct_lowrank_network(self):
        """Return a deep copy of the CPU model with the selected layers replaced."""
        self.low_rank_model_cpu = copy.deepcopy(self.model_cpu)
        self.low_rank_model_cpu = replace_fullrank_with_lowrank(
            self.low_rank_model_cpu,
            full2low_mapping=full2low_mapping_n,
            layer_rank=self.layer_rank,
            lowrank_param_dict=self.param_lowrank_decomp_dict,
            module_name=""
        )
        return self.low_rank_model_cpu


def round_to_nearest(n, ratio=1.0, base_number=4, allow_rank_eq1=False):
    """Scale ``n`` by ``ratio`` and round down to a multiple of ``base_number``.

    The result is clamped to ``[1, n]``. When the clamped value is 1 and
    ``allow_rank_eq1`` is false, ``n`` is returned instead.
    """
    rank = floor(floor(n * ratio) / base_number) * base_number
    rank = min(max(rank, 1), n)
    if rank == 1 and not allow_rank_eq1:
        return n
    return rank


def resolver(
        model,
        global_low_rank_ratio=1.0,
        excluded_layers=None,
        customized_layers_low_rank_ratio=None,
        skip_1x1=False,
        skip_3x3=False
):
    """Decompose ``model`` and return the rebuilt low-rank copy (on the CPU).

    Args:
        model: Model to factorise. NOTE: it is moved to the CPU in place.
        global_low_rank_ratio: Fraction of each layer's full rank to keep.
        excluded_layers: Dotted module names to leave untouched.
        customized_layers_low_rank_ratio: Optional per-layer ratio overrides.
        skip_1x1: Leave Linear layers and 1x1 convolutions untouched.
        skip_3x3: Leave convolutions with a spatial kernel untouched.
    """
    lowrank_resolver = DatafreeSVD(model,
                                   global_rank_ratio=global_low_rank_ratio,
                                   excluded_layers=excluded_layers,
                                   customized_layer_rank_ratio=customized_layers_low_rank_ratio,
                                   skip_1x1=skip_1x1,
                                   skip_3x3=skip_3x3)
    lowrank_resolver.decompose_layers()
    lowrank_cpu_model = lowrank_resolver.reconstruct_lowrank_network()
    return lowrank_cpu_model


if __name__ == "__main__":
    # The original referenced an undefined ``FSS_model`` placeholder; a tiny
    # stand-in network keeps the script runnable as a smoke test.
    origin_model = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 4, 1),
    )
    final_model = resolver(origin_model)
    print(final_model)