├── LICENSE
├── README.md
├── img
├── model.jpg
└── readme.md
└── svf.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 zechao-li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SVF-pytorch
2 |
3 | This is the official PyTorch implementation of [Singular Value Fine-tuning: Few-shot Segmentation requires Few-parameters Fine-tuning](https://arxiv.org/pdf/2206.06122.pdf), published at NeurIPS 2022.
4 |
5 | Authors: Yanpeng Sun^, Qiang Chen^, Xiangyu He^, Jian Wang, Haocheng Feng, Junyu Han, Errui Ding, Jian Cheng, [Zechao Li](https://zechao-li.github.io/), Jingdong Wang
6 |
7 |
8 |

9 |
10 |
11 |
12 | ## Usage
13 |
14 | This tool can not only decompose and rebuild the model, but also decompose and rebuild a layer individually.
15 |
16 | ```python
17 | from . import svf
18 | import torchvision.models as models
19 |
20 | model = models.resnet18(pretrained=True)
21 | model = svf.resolver(model,
22 | global_low_rank_ratio=1.0, # no need to change
23 | skip_1x1=False, # we will decompose 1x1 conv layers
24 | skip_3x3=False # we will decompose 3x3 conv layers
25 | )
26 | ```
27 |
28 |
29 | ## Pipeline:
30 |
31 | We use a full-rank model as an input, then factorize the original model and return a low-rank model.
32 |
33 | - Previous Convolution Layer
34 |
35 | ```python
36 | conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
37 | ```
38 |
39 | - Replaced by
40 |
41 | ```python
42 | class SVD_Conv2d(nn.Module):
43 | """Kernel Number first SVD Conv2d
44 | """
45 |
46 | def __init__(self, in_channels, out_channels, kernel_size,
47 | stride, padding, dilation, groups, bias,
48 | padding_mode='zeros', device=None, dtype=None,
49 | rank=1):
50 | super(SVD_Conv2d, self).__init__()
51 | factory_kwargs = {'device': device, 'dtype': dtype}
52 | self.conv_U = nn.Conv2d(rank, out_channels, (1, 1), (1, 1), 0, (1, 1), 1, bias)
53 | self.conv_V = nn.Conv2d(in_channels, rank, kernel_size, stride, padding, dilation, groups, False)
54 | self.vector_S = nn.Parameter(torch.empty((1, rank, 1, 1), **factory_kwargs))
55 |
56 | def forward(self, x):
57 | x = self.conv_V(x)
58 | x = x.mul(self.vector_S)
59 | output = self.conv_U(x)
60 | return output
61 |
62 | ```
63 | ## Usage in FSS model:
64 | First, decompose and rebuild all layers in the backbone.
65 |
66 | ```python
67 | if args.svf:
68 | self.layer0 = svf.resolver(self.layer0, global_low_rank_ratio=1.0, skip_1x1=False, skip_3x3=False)
69 | self.layer1 = svf.resolver(self.layer1, global_low_rank_ratio=1.0, skip_1x1=False, skip_3x3=False)
70 | self.layer2 = svf.resolver(self.layer2, global_low_rank_ratio=1.0, skip_1x1=False, skip_3x3=False)
71 | self.layer3 = svf.resolver(self.layer3, global_low_rank_ratio=1.0, skip_1x1=False, skip_3x3=False)
72 | self.layer4 = svf.resolver(self.layer4, global_low_rank_ratio=1.0, skip_1x1=False, skip_3x3=False)
73 | ```
74 | Then, set up the new model freezing strategy.
75 | ```python
76 | def svf_modules(self, model):
77 | for param in model.layer0.parameters():
78 | param.requires_grad = False
79 | for param in model.layer1.parameters():
80 | param.requires_grad = False
81 | for name, param in model.layer2.named_parameters():
82 | param.requires_grad = False
83 | if 'vector_S' in name:
84 | param.requires_grad = True
85 | for name, param in model.layer3.named_parameters():
86 | param.requires_grad = False
87 | if 'vector_S' in name:
88 | param.requires_grad = True
89 | for name, param in model.layer4.named_parameters():
90 | param.requires_grad = False
91 | if 'vector_S' in name:
92 | param.requires_grad = True
93 | ```
94 |
--------------------------------------------------------------------------------
/img/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zechao-li/SVF-pytorch/2801557df10606e42653b19866f3361a1baf0ba6/img/model.jpg
--------------------------------------------------------------------------------
/img/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/svf.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import inspect
3 | from math import floor
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
def d_nsvd(matrix, rank=1):
    """Return the factors of a rank-truncated SVD of a 2-D matrix.

    Args:
        matrix: 2-D tensor of shape ``(m, n)``.
        rank: Number of leading singular triplets to keep.

    Returns:
        A tuple ``(U, S, V)`` where ``U`` has shape ``(m, rank)``, ``S`` has
        shape ``(rank,)`` and ``V`` has shape ``(rank, n)``, so that
        ``U @ diag(S) @ V`` approximates ``matrix``.  The singular values are
        returned separately and are NOT folded into ``U`` or ``V``.
    """
    # torch.svd is deprecated; torch.linalg.svd with full_matrices=False is the
    # reduced-SVD replacement and already returns the transposed right factor.
    U, S, Vh = torch.linalg.svd(matrix, full_matrices=False)
    return U[:, :rank], S[:rank], Vh[:rank, :]
16 |
17 |
class SVD_Conv2d(nn.Module):
    """Conv2d factorized as ``conv_U(vector_S * conv_V(x))``.

    * ``conv_V`` keeps the original kernel size, stride, padding, dilation,
      groups and padding mode, maps ``in_channels`` to ``rank`` channels and
      has no bias.
    * ``vector_S`` is a learnable ``(1, rank, 1, 1)`` scale applied per
      intermediate channel.
    * ``conv_U`` is a 1x1 convolution from ``rank`` to ``out_channels`` and
      carries the bias, if any.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode, device, dtype: Same meaning as for
            ``nn.Conv2d``.
        rank: Number of intermediate channels between ``conv_V`` and
            ``conv_U``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, groups, bias,
                 padding_mode='zeros', device=None, dtype=None,
                 rank=1):
        super(SVD_Conv2d, self).__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        # device/dtype are forwarded to every sub-layer so the three
        # parameters of the factorization always live together.
        self.conv_U = nn.Conv2d(rank, out_channels, (1, 1), (1, 1), 0, (1, 1), 1, bias,
                                **factory_kwargs)
        # padding_mode only matters for the spatial convolution; conv_U is a
        # 1x1 convolution without padding.
        self.conv_V = nn.Conv2d(in_channels, rank, kernel_size, stride, padding, dilation,
                                groups, False, padding_mode, **factory_kwargs)
        # Initialized to ones (identity scaling) instead of uninitialized
        # memory, so a freshly built module is well-defined before any
        # decomposed weights are copied in.
        self.vector_S = nn.Parameter(torch.ones((1, rank, 1, 1), **factory_kwargs))

    def forward(self, x):
        """Apply ``conv_V``, scale by ``vector_S``, then apply ``conv_U``."""
        x = self.conv_V(x)
        x = x.mul(self.vector_S)
        output = self.conv_U(x)
        return output
37 |
class SVD_Linear(nn.Module):
    """Linear layer factorized as ``fc_U(vector_S * fc_V(x))``.

    * ``fc_V`` maps ``in_features`` to ``rank`` features and has no bias.
    * ``vector_S`` is a learnable ``(1, rank)`` scale applied per
      intermediate feature.
    * ``fc_U`` maps ``rank`` features to ``out_features`` and carries the
      bias, if any.

    Args:
        in_features, out_features, bias, device, dtype: Same meaning as for
            ``nn.Linear``.
        rank: Number of intermediate features between ``fc_V`` and ``fc_U``.
    """

    def __init__(self, in_features, out_features, bias, device=None, dtype=None, rank=1):
        super(SVD_Linear, self).__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        # device/dtype are forwarded to every sub-layer so the three
        # parameters of the factorization always live together.
        self.fc_V = nn.Linear(in_features, rank, False, **factory_kwargs)
        # Initialized to ones (identity scaling) instead of uninitialized
        # memory, so a freshly built module is well-defined before any
        # decomposed weights are copied in.
        self.vector_S = nn.Parameter(torch.ones((1, rank), **factory_kwargs))
        self.fc_U = nn.Linear(rank, out_features, bias, **factory_kwargs)

    def forward(self, x):
        """Apply ``fc_V``, scale by ``vector_S``, then apply ``fc_U``."""
        x = self.fc_V(x)
        x = x.mul(self.vector_S)
        output = self.fc_U(x)
        return output
52 |
53 |
# Maps each supported full-rank layer type to the factorized module class that
# replaces it when a low-rank network is rebuilt.
full2low_mapping_n = {
    nn.Conv2d: SVD_Conv2d,
    nn.Linear: SVD_Linear
}
58 |
59 |
# Attribute names of the (U, V) sub-layers inside the replacement module, keyed
# by the type of the layer being replaced.
_LOWRANK_FACTOR_ATTRS = {
    nn.Conv2d: ("conv_U", "conv_V"),
    nn.Linear: ("fc_U", "fc_V"),
}


def _infer_init_kwargs(module):
    """Build constructor kwargs for ``type(module)`` from a live module.

    Values are read from the instance ``__dict__`` when an attribute with the
    parameter's name exists; otherwise a fallback is chosen from the
    parameter name (``bias`` -> whether the module has a bias, common conv
    options -> their neutral value, anything else -> ``None``).
    """
    attrs = module.__dict__
    kwargs = {}
    for param in inspect.signature(type(module)).parameters.values():
        name = param.name
        if name in attrs:
            kwargs[name] = attrs[name]
        elif 'bias' in name:
            kwargs[name] = getattr(module, 'bias', None) is not None
        elif 'stride' in name:
            kwargs[name] = 1
        # 'padding_mode' must be tested before 'padding': the latter is a
        # substring of the former and would otherwise shadow it.
        elif 'padding_mode' in name:
            kwargs[name] = 'zeros'
        elif 'padding' in name:
            kwargs[name] = 0
        elif 'dilation' in name:
            kwargs[name] = 1
        elif 'groups' in name:
            kwargs[name] = 1
        else:
            kwargs[name] = None
    return kwargs


def replace_fullrank_with_lowrank(model, full2low_mapping=None, layer_rank=None,
                                  lowrank_param_dict=None, module_name=""):
    """Recursively replace original full-rank ops with low-rank ops.

    The replacement is done in place on ``model``.  Only leaf modules (modules
    without children) are candidates.

    Args:
        model: Module whose children are inspected and possibly replaced.
        full2low_mapping: Dict mapping a layer type to the callable that builds
            its replacement.  ``None`` or an empty dict leaves ``model``
            untouched.
        layer_rank: Dict mapping a dotted module name to the rank of its
            replacement.  Only layers listed here are replaced.
        lowrank_param_dict: Dict mapping a dotted module name to
            ``[U, V, S]`` tensors that are copied into the replacement of
            ``nn.Conv2d`` and ``nn.Linear`` layers.
        module_name: Dotted name of ``model`` inside the root module; used as
            prefix for the names of its children.

    Returns:
        ``model`` itself.

    Raises:
        KeyError: If an ``nn.Conv2d`` or ``nn.Linear`` layer is listed in
            ``layer_rank`` but has no entry in ``lowrank_param_dict``.
    """
    # Truthiness covers both None and the empty mapping without calling len()
    # on None.
    if not full2low_mapping:
        return model
    layer_rank = {} if layer_rank is None else layer_rank
    lowrank_param_dict = {} if lowrank_param_dict is None else lowrank_param_dict

    # Iterate over a snapshot because children are re-assigned inside the loop.
    for sub_module_name, old_module in list(model._modules.items()):
        if old_module is None:
            continue
        current_module_name = sub_module_name if module_name == "" else \
            module_name + "." + sub_module_name

        # has children
        if len(old_module._modules) > 0:
            replace_fullrank_with_lowrank(old_module,
                                          full2low_mapping,
                                          layer_rank,
                                          lowrank_param_dict,
                                          current_module_name)
            continue

        old_type = type(old_module)
        if old_type not in full2low_mapping or current_module_name not in layer_rank:
            continue

        _kwargs = _infer_init_kwargs(old_module)
        _kwargs['rank'] = layer_rank[current_module_name]
        new_module = full2low_mapping[old_type](**_kwargs)

        factor_attrs = _LOWRANK_FACTOR_ATTRS.get(old_type)
        if factor_attrs is not None:
            # Look the factors up before touching the model so a missing
            # entry does not leave a half-initialized layer behind.
            param_list = lowrank_param_dict[current_module_name]
            layer_U = getattr(new_module, factor_attrs[0])
            layer_V = getattr(new_module, factor_attrs[1])
            old_bias = getattr(old_module, 'bias', None)
            with torch.no_grad():
                layer_V.weight.copy_(param_list[1])
                layer_U.weight.copy_(param_list[0])
                new_module.vector_S.copy_(param_list[2])
                # The bias belongs to the last stage of the factorization.
                if old_bias is not None and layer_U.bias is not None:
                    layer_U.bias.copy_(old_bias)

        setattr(model, sub_module_name, new_module)
    return model
127 |
128 |
class DatafreeSVD(object):
    """Plan and apply a data-free SVD factorization of Conv2d/Linear layers.

    Construction selects which layers are factorized and with which rank
    (``layer_rank``).  ``decompose_layers`` computes the truncated SVD factors
    of the selected weights, and ``reconstruct_lowrank_network`` returns a
    deep copy of the model in which the selected layers are replaced.
    """

    def __init__(self, model, global_rank_ratio=1.0,
                 excluded_layers=None, customized_layer_rank_ratio=None, skip_1x1=True, skip_3x3=True):
        """Select the layers to factorize and compute their target ranks.

        Args:
            model: Network to factorize.  It is moved to the CPU.  If its
                first state-dict key starts with ``module`` and it exposes a
                ``module`` attribute, the wrapped ``model.module`` is used.
            global_rank_ratio: Fraction of the full rank kept for every layer
                without a customized ratio.
            excluded_layers: Dotted module names that must not be factorized.
            customized_layer_rank_ratio: Dict mapping a dotted module name to
                the rank ratio used for that layer instead of
                ``global_rank_ratio``.
            skip_1x1: Skip layers whose squeezed weight is 2-D (Linear layers
                and 1x1 convolutions).
            skip_3x3: Skip convolutions whose squeezed weight stays 4-D
                (spatial kernels).
        """
        # class-independent initialization
        super(DatafreeSVD, self).__init__()
        self.model = model
        self.layer_rank = {}
        state_keys = list(model.state_dict().keys())
        # A parameter-less model has no keys; and a first key that merely
        # starts with "module" is only a wrapper hint if `.module` exists.
        model_data_parallel = (
            len(state_keys) > 0
            and str(state_keys[0]).startswith('module')
            and hasattr(model, 'module')
        )
        self.model_cpu = self.model.module.to(
            "cpu") if model_data_parallel else self.model.to("cpu")
        self.model_named_modules = self.model_cpu.named_modules()
        self.rank_base = 4
        self.global_rank_ratio = global_rank_ratio
        self.excluded_layers = [] if excluded_layers is None else excluded_layers
        self.customized_layer_rank_ratio = {} if customized_layer_rank_ratio is None \
            else customized_layer_rank_ratio
        self.skip_1x1 = skip_1x1
        self.skip_3x3 = skip_3x3

        self.param_lowrank_decomp_dict = {}
        registered_param_op = [nn.Conv2d, nn.Linear]

        for m_name, m in self.model_named_modules:
            if type(m) not in registered_param_op or m_name in self.excluded_layers:
                continue
            # For groups > 1 each output channel only sees its own input
            # group; a grouped conv_V followed by a dense 1x1 conv_U mixes
            # groups and does not reproduce the layer, so leave it untouched.
            if type(m) is nn.Conv2d and m.groups != 1:
                continue

            weights_tensor = m.weight.data
            # The layer kind is inferred from the squeezed weight shape.
            tensor_shape = weights_tensor.squeeze().shape
            param_1x1 = False
            param_3x3 = False
            depthwise_conv = False
            if len(tensor_shape) == 2:
                full_rank = min(tensor_shape[0], tensor_shape[1])
                param_1x1 = True
            elif len(tensor_shape) == 4:
                full_rank = min(
                    tensor_shape[0], tensor_shape[1] * tensor_shape[2] * tensor_shape[3])
                if tensor_shape[2] == 1 and tensor_shape[3] == 1:
                    param_1x1 = True
                else:
                    param_3x3 = True
            else:
                # Any other squeezed rank (a singleton channel or kernel
                # dimension) is never factorized.
                full_rank = 1
                depthwise_conv = True

            if self.skip_1x1 and param_1x1:
                continue
            if self.skip_3x3 and param_3x3:
                continue
            if depthwise_conv:
                continue

            # A per-layer ratio, when given, overrides the global one.
            rank_ratio = self.customized_layer_rank_ratio.get(
                m_name, self.global_rank_ratio)
            low_rank = round_to_nearest(full_rank,
                                        ratio=rank_ratio,
                                        base_number=self.rank_base,
                                        allow_rank_eq1=True)

            self.layer_rank[m_name] = low_rank

    def decompose_layers(self):
        """Compute ``[U, V, S]`` for every selected layer.

        Results are stored in ``param_lowrank_decomp_dict`` under the layer's
        dotted name, already reshaped to the weight shapes of the replacement
        layers.
        """
        self.model_named_modules = self.model_cpu.named_modules()
        for m_name, m in self.model_named_modules:
            if m_name not in self.layer_rank:
                continue
            rank = self.layer_rank[m_name]
            tensor_shape = m.weight.data.shape
            if len(tensor_shape) == 1:
                self.layer_rank[m_name] = 1
                continue
            elif len(tensor_shape) == 2:
                weights_matrix = m.weight.data
                U, S, V = d_nsvd(weights_matrix, rank)
                self.param_lowrank_decomp_dict[m_name] = [
                    U, V, S.reshape(1, rank)]
            elif len(tensor_shape) == 4:
                # Flatten every kernel into one row: (out, in * kh * kw).
                weights_matrix = m.weight.data.reshape(tensor_shape[0], -1)
                U, S, V = d_nsvd(weights_matrix, rank)
                self.param_lowrank_decomp_dict[m_name] = [
                    U.reshape(tensor_shape[0], rank, 1, 1),
                    V.reshape(
                        rank, tensor_shape[1], tensor_shape[2], tensor_shape[3]),
                    S.reshape(1, rank, 1, 1)
                ]

    def reconstruct_lowrank_network(self):
        """Return a deep copy of the CPU model with the selected layers replaced."""
        self.low_rank_model_cpu = copy.deepcopy(self.model_cpu)
        self.low_rank_model_cpu = replace_fullrank_with_lowrank(
            self.low_rank_model_cpu,
            full2low_mapping=full2low_mapping_n,
            layer_rank=self.layer_rank,
            lowrank_param_dict=self.param_lowrank_decomp_dict,
            module_name=""
        )
        return self.low_rank_model_cpu
225 |
def round_to_nearest(n, ratio=1.0, base_number=4, allow_rank_eq1=False):
    """Scale ``n`` by ``ratio`` and snap the result down to a multiple of ``base_number``.

    The snapped value is clamped to ``[1, n]``.  When the clamped value is 1
    and ``allow_rank_eq1`` is false, ``n`` is returned instead.

    >>> round_to_nearest(27)
    24
    >>> round_to_nearest(3, allow_rank_eq1=True)
    1
    >>> round_to_nearest(3)
    3
    """
    scaled = floor(n * ratio)
    snapped = floor(scaled / base_number) * base_number
    clamped = min(max(snapped, 1), n)
    if clamped == 1 and not allow_rank_eq1:
        return n
    return clamped
232 |
def resolver(
        model,
        global_low_rank_ratio=1.0,
        excluded_layers=None,
        customized_layers_low_rank_ratio=None,
        skip_1x1=False,
        skip_3x3=False
):
    """Factorize ``model`` with ``DatafreeSVD`` and return the rebuilt network.

    Args:
        model: Network to factorize.
        global_low_rank_ratio: Forwarded as ``global_rank_ratio``.
        excluded_layers: Dotted module names to leave untouched; ``None``
            means no exclusions.
        customized_layers_low_rank_ratio: Forwarded as
            ``customized_layer_rank_ratio``; ``None`` means an empty dict.
        skip_1x1: Forwarded to ``DatafreeSVD``.
        skip_3x3: Forwarded to ``DatafreeSVD``.

    Returns:
        The network produced by ``reconstruct_lowrank_network``.
    """
    # Fresh containers per call instead of shared mutable defaults.
    lowrank_resolver = DatafreeSVD(
        model,
        global_rank_ratio=global_low_rank_ratio,
        excluded_layers=[] if excluded_layers is None else excluded_layers,
        customized_layer_rank_ratio={} if customized_layers_low_rank_ratio is None
        else customized_layers_low_rank_ratio,
        skip_1x1=skip_1x1,
        skip_3x3=skip_3x3)
    lowrank_resolver.decompose_layers()
    lowrank_cpu_model = lowrank_resolver.reconstruct_lowrank_network()
    return lowrank_cpu_model
250 |
251 |
if __name__ == "__main__":
    # Smoke test on a small stand-in network: both convolutions are
    # factorized at full rank, so the outputs should agree up to float error.
    torch.manual_seed(0)
    origin_model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 16, kernel_size=1),
    )
    origin_model.eval()
    final_model = resolver(origin_model)
    final_model.eval()

    sample = torch.randn(1, 3, 16, 16)
    with torch.no_grad():
        max_abs_diff = (origin_model(sample) - final_model(sample)).abs().max().item()
    print(final_model)
    print("max |original - factorized| = {:.3e}".format(max_abs_diff))
--------------------------------------------------------------------------------