├── LICENSE ├── README.md ├── cal_bit.py ├── cal_bit_huffman.py ├── cifar10_models ├── __init__.py ├── data.py ├── densenet.py ├── googlenet.py ├── inception.py ├── mobilenetv2.py ├── resnet.py ├── resnet_orig.py └── vgg.py ├── clib ├── __init__.py ├── cos_stat │ ├── __init__.py │ ├── cos_stat.cpp │ ├── cos_stat.cu │ ├── cos_stat.py │ └── setup.py ├── lin_stat │ ├── __init__.py │ ├── lin_stat.cpp │ ├── lin_stat.cu │ ├── lin_stat.py │ └── setup.py └── tri_stat │ ├── __init__.py │ ├── setup.py │ ├── tri_stat.cpp │ ├── tri_stat.cu │ └── tri_stat.py ├── common ├── __init__.py └── tools.py ├── compress ├── __init__.py ├── huffman.py └── range_coder.py ├── data_new ├── __init__.py ├── data.py └── imagenet.py ├── model ├── __init__.py ├── mnasnet.py ├── mobilenetv2.py ├── regnet.py ├── resnet.py └── vision_transformer.py ├── quant ├── __init__.py ├── linspace_centric.py └── soft_quant.py ├── requirements.txt ├── run.sh ├── sparse └── soft_sparse.py ├── train.txt ├── train ├── __init__.py ├── logger.py ├── ops.py ├── test_baseline.py ├── train_final.py └── train_yolo.py └── transform ├── edge_scale.py ├── exp.py ├── log.py ├── ms.py └── scale.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # L^2 compression 2 | 3 | ## Introduction 4 | 5 | Deep neural networks have delivered remarkable performance and have been widely used in various visual tasks. However, their huge size causes significant inconvenience for transmission and storage. This work proposes a unified post-training model size compression method that combines lossy and lossless compression. 
6 | 7 | ## Requirements 8 | 9 | * torch 10 | 11 | * torchvision 12 | 13 | * constriction 14 | 15 | * ninja 16 | 17 | * matplotlib 18 | 19 | * timm 20 | 21 | ````shell 22 | pip install -r requirements.txt 23 | ```` 24 | 25 | ## Usage 26 | 27 | ### train 28 | 29 | * change the model checkpoints paths in train/train_final 30 | 31 | ```` python 32 | state_path = { 33 | "resnet18": "", 34 | "resnet50": "", 35 | "mobilenetv2": "", 36 | "mnasnet": "", 37 | "regnetx_600m": "", 38 | "regnetx_3200m": "" 39 | } 40 | ```` 41 | 42 | * change dataset in train/train_final 43 | 44 | ````python 45 | train_dataset = ImageNetDataset(ROOTDIR + '/train/', 'train.txt', train_transform) 46 | test_dataset = ImageNetDataset(ROOTDIR + 'val/', METADIR + 'val.txt', val_transform) 47 | ```` 48 | 49 | * the calibration dataset we provide is train.txt in the repo 50 | 51 | * examples are in run.sh 52 | 53 | ````shell 54 | # example ResNet18 55 | python -u -m train.train_final \ 56 | --model_name resnet18 \ 57 | --lambda_r 1e-6 \ 58 | --lambda_kd 1.0 \ 59 | --weight_transform edgescale \ 60 | --bias_transform scale \ 61 | --transform_iter 300 \ 62 | --transform_lr 0.0001 \ 63 | --reconstruct_iter 1000 \ 64 | --reconstruct_lr 5e-06 \ 65 | --resolution 64 \ 66 | --diffkernel cos \ 67 | --log_path ./log \ 68 | --target_CR 10.0 \ 69 | --run_name resnet_test 70 | ```` 71 | 72 | ## encode and decode 73 | 74 | refer to cal_bit.py -------------------------------------------------------------------------------- /cal_bit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from compress.range_coder import compress_matrix_flatten, compress_matrix_flatten_new 4 | from common.tools import get_np_size 5 | import numpy as np 6 | import constriction 7 | import time 8 | import torch 9 | 10 | log_dirs = [""] 11 | 12 | def encode(pkl_path): 13 | pkl_file = pkl_path 14 | dict = pickle.load(open(pkl_file, "rb")) 15 | max_epoch = None 16 | 17 | 
def decode(model_path):
    """Rebuild a state dict from the per-tensor ``.bin`` files written by ``encode``.

    For every entry ``name`` of the checkpoint at ``model_path`` three files
    are read from the current working directory:
    ``compressed_<name>.bin`` (range-coded bitstream),
    ``symbol_<name>.bin`` (symbol table) and
    ``probabilities_<name>.bin`` (categorical model).  The checkpoint itself
    is only used for the entry names and the shape of each tensor.

    Args:
        model_path: Path of a checkpoint loadable with ``torch.load`` whose
            items map parameter names to tensors (or arrays) of the target
            shapes.

    Returns:
        dict: ``{name: torch.Tensor}`` with the decoded values reshaped to the
        shape of the corresponding checkpoint entry.
    """
    # NOTE(security): torch.load unpickles its input; only call this on
    # checkpoints from a trusted source.
    reference_state = torch.load(model_path)
    new_dict = {}
    for name, value in reference_state.items():
        # The range decoder consumes 32-bit words, so the bitstream must not be
        # read with np.fromfile's default float64 dtype.
        quant_compressed = np.fromfile("compressed_" + name + ".bin", dtype=np.uint32)
        # NOTE(review): the dtypes of the symbol table and the probabilities are
        # whatever compress_matrix_flatten_new produced; float64 (the np.fromfile
        # default) is assumed here -- confirm against the encoder.
        quant_symbol = np.fromfile("symbol_" + name + ".bin")
        probabilities = np.fromfile("probabilities_" + name + ".bin")

        decoder = constriction.stream.queue.RangeDecoder(quant_compressed)
        probabilities_model = constriction.stream.model.Categorical(probabilities)

        # Number of symbols to decode.  Derived from the shape so it works for
        # torch tensors (where ``.size`` is a method) and numpy arrays alike.
        symbol_count = int(np.prod(tuple(value.shape)))
        indices = decoder.decode(probabilities_model, symbol_count)

        # Map symbol indices back to quantized values exactly once, then
        # restore the original layout.
        decoded = quant_symbol[indices].reshape(tuple(value.shape))
        new_dict[name] = torch.from_numpy(decoded)
    return new_dict
class CIFAR10Data:
    """Factory for the CIFAR-10 train / validation / test ``DataLoader`` objects.

    Args:
        root: Directory handed to ``torchvision.datasets.CIFAR10``.  No
            ``download`` argument is passed, so the dataset is presumably
            expected to exist there already.
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation/test loader.
        train_workers: Worker processes of the training loader.
        val_workers: Worker processes of the validation/test loader.

    All defaults reproduce the values that used to be hard-coded.
    """

    def __init__(
        self,
        root="./data",
        train_batch_size=64,
        val_batch_size=32,
        train_workers=4,
        val_workers=2,
    ):
        super().__init__()
        # Per-channel statistics used by T.Normalize in both pipelines.
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2471, 0.2435, 0.2616)
        self.root = root
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.train_workers = train_workers
        self.val_workers = val_workers

    def _normalize_steps(self):
        """Return the tensor conversion + normalization steps shared by all splits."""
        return [T.ToTensor(), T.Normalize(self.mean, self.std)]

    def train_dataloader(self):
        """Return a shuffled training loader with random-crop / flip augmentation."""
        transform = T.Compose(
            [T.RandomCrop(32, padding=4), T.RandomHorizontalFlip()]
            + self._normalize_steps()
        )
        dataset = CIFAR10(root=self.root, train=True, transform=transform)
        return DataLoader(
            dataset,
            batch_size=self.train_batch_size,
            num_workers=self.train_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return the (unshuffled, un-augmented) loader over the held-out split."""
        transform = T.Compose(self._normalize_steps())
        dataset = CIFAR10(root=self.root, train=False, transform=transform)
        # NOTE(review): drop_last=True discards the final partial batch, so
        # metrics are computed on slightly fewer samples than the split holds
        # whenever its size is not a multiple of the batch size.  Kept as-is
        # to preserve existing numbers.
        return DataLoader(
            dataset,
            batch_size=self.val_batch_size,
            num_workers=self.val_workers,
            drop_last=True,
            pin_memory=True,
        )

    def test_dataloader(self):
        """Return the same loader as :meth:`val_dataloader`."""
        return self.val_dataloader()
import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | __all__ = ["DenseNet", "densenet121", "densenet169", "densenet161"] 9 | 10 | 11 | class _DenseLayer(nn.Sequential): 12 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 13 | super(_DenseLayer, self).__init__() 14 | self.add_module("norm1", nn.BatchNorm2d(num_input_features)), 15 | self.add_module("relu1", nn.ReLU(inplace=True)), 16 | self.add_module( 17 | "conv1", 18 | nn.Conv2d( 19 | num_input_features, 20 | bn_size * growth_rate, 21 | kernel_size=1, 22 | stride=1, 23 | bias=False, 24 | ), 25 | ), 26 | self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate)), 27 | self.add_module("relu2", nn.ReLU(inplace=True)), 28 | self.add_module( 29 | "conv2", 30 | nn.Conv2d( 31 | bn_size * growth_rate, 32 | growth_rate, 33 | kernel_size=3, 34 | stride=1, 35 | padding=1, 36 | bias=False, 37 | ), 38 | ), 39 | self.drop_rate = drop_rate 40 | 41 | def forward(self, x): 42 | new_features = super(_DenseLayer, self).forward(x) 43 | if self.drop_rate > 0: 44 | new_features = F.dropout( 45 | new_features, p=self.drop_rate, training=self.training 46 | ) 47 | return torch.cat([x, new_features], 1) 48 | 49 | 50 | class _DenseBlock(nn.Sequential): 51 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 52 | super(_DenseBlock, self).__init__() 53 | for i in range(num_layers): 54 | layer = _DenseLayer( 55 | num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate 56 | ) 57 | self.add_module("denselayer%d" % (i + 1), layer) 58 | 59 | 60 | class _Transition(nn.Sequential): 61 | def __init__(self, num_input_features, num_output_features): 62 | super(_Transition, self).__init__() 63 | self.add_module("norm", nn.BatchNorm2d(num_input_features)) 64 | self.add_module("relu", nn.ReLU(inplace=True)) 65 | self.add_module( 66 | "conv", 67 | nn.Conv2d( 68 | num_input_features, 69 | num_output_features, 70 | kernel_size=1, 71 | stride=1, 72 | bias=False, 73 | 
class DenseNet(nn.Module):
    r"""DenseNet-BC classifier with a CIFAR-style stem, after
    `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_.

    Args:
        growth_rate (int): feature maps added by every dense layer (`k` in the paper)
        block_config (sequence of ints): number of dense layers in each dense block
        num_init_features (int): output channels of the first convolution
        bn_size (int): bottleneck multiplier; each bottleneck has
            ``bn_size * growth_rate`` channels
        drop_rate (float): dropout probability applied after each dense layer
        num_classes (int): size of the classification output
    """

    def __init__(
        self,
        growth_rate=32,
        block_config=(6, 12, 24, 16),
        num_init_features=64,
        bn_size=4,
        drop_rate=0,
        num_classes=10,
    ):
        super().__init__()

        # Stem.  Compared with the ImageNet variant the first convolution is
        # 3x3 / stride 1 / padding 1 instead of 7x7 / stride 2 / padding 3.
        stem = nn.Sequential()
        stem.add_module(
            "conv0",
            nn.Conv2d(
                3,
                num_init_features,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
        )
        stem.add_module("norm0", nn.BatchNorm2d(num_init_features))
        stem.add_module("relu0", nn.ReLU(inplace=True))
        stem.add_module("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        self.features = stem

        # Dense blocks, each but the last followed by a channel-halving transition.
        channels = num_init_features
        final_stage = len(block_config) - 1
        for stage, depth in enumerate(block_config):
            self.features.add_module(
                "denseblock%d" % (stage + 1),
                _DenseBlock(
                    num_layers=depth,
                    num_input_features=channels,
                    bn_size=bn_size,
                    growth_rate=growth_rate,
                    drop_rate=drop_rate,
                ),
            )
            channels = channels + depth * growth_rate
            if stage == final_stage:
                continue
            halved = channels // 2
            self.features.add_module(
                "transition%d" % (stage + 1),
                _Transition(num_input_features=channels, num_output_features=halved),
            )
            channels = halved

        # Closing batch norm and the linear classifier head.
        self.features.add_module("norm5", nn.BatchNorm2d(channels))
        self.classifier = nn.Linear(channels, num_classes)

        # Weight initialisation, applied in module-traversal order.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        feature_map = self.features(x)
        activated = F.relu(feature_map, inplace=True)
        pooled = F.adaptive_avg_pool2d(activated, (1, 1))
        flattened = pooled.view(feature_map.size(0), -1)
        return self.classifier(flattened)
def densenet169(pretrained=False, progress=True, device="cpu", **kwargs):
    r"""Build the DenseNet-169 configuration from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_.

    Args:
        pretrained (bool): If True, ``_densenet`` loads
            ``state_dicts/densenet169.pt`` from next to this module
        progress (bool): Forwarded to ``_densenet``, which does not use it
        device: ``map_location`` used when the weights are loaded
        **kwargs: Extra keyword arguments for the ``DenseNet`` constructor
    """
    return _densenet(
        arch="densenet169",
        growth_rate=32,
        block_config=(6, 12, 32, 32),
        num_init_features=64,
        pretrained=pretrained,
        progress=progress,
        device=device,
        **kwargs
    )
class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) with a single-convolution stem for CIFAR-10.

    Args:
        num_classes (int): size of the classification output
        aux_logits (bool): build the two auxiliary classifiers and, in
            training mode, return their logits next to the main ones
        transform_input (bool): re-scale each input channel before the stem
            (see :meth:`_rescale_input`)
    """

    def __init__(self, num_classes=10, aux_logits=False, transform_input=False):
        super().__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem: one 3x3 / stride 1 / padding 1 convolution straight to 192 channels.
        self.conv1 = BasicConv2d(3, 192, kernel_size=3, stride=1, padding=1)

        self._add_inceptions(
            ("inception3a", (192, 64, 96, 128, 16, 32, 32)),
            ("inception3b", (256, 128, 128, 192, 32, 96, 64)),
        )
        self.maxpool3 = nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=False)

        self._add_inceptions(
            ("inception4a", (480, 192, 96, 208, 16, 48, 64)),
            ("inception4b", (512, 160, 112, 224, 24, 64, 64)),
            ("inception4c", (512, 128, 128, 256, 24, 64, 64)),
            ("inception4d", (512, 112, 144, 288, 32, 64, 64)),
            ("inception4e", (528, 256, 160, 320, 32, 128, 128)),
        )
        self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=False)

        self._add_inceptions(
            ("inception5a", (832, 256, 160, 320, 32, 128, 128)),
            ("inception5b", (832, 384, 192, 384, 48, 128, 128)),
        )

        if aux_logits:
            # Attached after inception4a (512 ch) and inception4d (528 ch).
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

    def _add_inceptions(self, *specs):
        """Register one ``Inception`` module per ``(attribute_name, ctor_args)`` pair, in order."""
        for attr_name, ctor_args in specs:
            setattr(self, attr_name, Inception(*ctor_args))

    @staticmethod
    def _rescale_input(x):
        """Apply the per-channel affine re-scaling used when ``transform_input`` is set."""
        channel_stats = ((0.229, 0.485), (0.224, 0.456), (0.225, 0.406))
        rescaled = tuple(
            torch.unsqueeze(x[:, idx], 1) * (std / 0.5) + (mean - 0.5) / 0.5
            for idx, (std, mean) in enumerate(channel_stats)
        )
        return torch.cat(rescaled, 1)

    def forward(self, x):
        """Return the logits, or ``(logits, aux_logits2, aux_logits1)`` when
        training with auxiliary heads enabled.

        Spatial sizes in the comments assume N x 3 x 32 x 32 input -- TODO confirm.
        """
        if self.transform_input:
            x = self._rescale_input(x)

        x = self.conv1(x)  # N x 192 x 32 x 32

        x = self.inception3a(x)  # 256 channels
        x = self.inception3b(x)  # 480 channels
        x = self.maxpool3(x)  # N x 480 x 16 x 16

        x = self.inception4a(x)  # 512 channels
        if self.training and self.aux_logits:
            aux1 = self.aux1(x)

        x = self.inception4b(x)  # 512 channels
        x = self.inception4c(x)  # 512 channels
        x = self.inception4d(x)  # 528 channels
        if self.training and self.aux_logits:
            aux2 = self.aux2(x)

        x = self.inception4e(x)  # 832 channels
        x = self.maxpool4(x)  # N x 832 x 8 x 8

        x = self.inception5a(x)  # 832 channels
        x = self.inception5b(x)  # 1024 channels

        x = self.avgpool(x)  # N x 1024 x 1 x 1
        x = x.view(x.size(0), -1)  # N x 1024
        x = self.fc(self.dropout(x))  # N x num_classes

        if self.training and self.aux_logits:
            return _GoogLeNetOuputs(x, aux2, aux1)
        return x
def inception_v3(pretrained=False, progress=True, device="cpu", **kwargs):
    r"""Build the Inception v3 network defined in this module.

    Architecture from "Rethinking the Inception Architecture for Computer
    Vision". The ``Inception3`` class in this file uses a reduced, stride-1
    stem annotated for CIFAR-10.

    Args:
        pretrained (bool): If True, load weights from the bundled
            ``state_dicts/inception_v3.pt`` file next to this module
            (nothing is downloaded).
        progress (bool): Unused; kept for signature compatibility.
        device: ``map_location`` passed to ``torch.load`` when loading weights.
        **kwargs: Forwarded to ``Inception3`` (``num_classes``, ``aux_logits``,
            ``transform_input``). Previously these were silently dropped.

    Returns:
        The constructed ``Inception3`` module.
    """
    model = Inception3(**kwargs)
    if pretrained:
        # Resolve relative to this file; os.path.join also behaves when
        # dirname(__file__) is empty, where plain concatenation would yield
        # "/state_dicts/...".
        weights_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "state_dicts",
            "inception_v3.pt",
        )
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)
    return model
nn.Linear(2048, num_classes) 67 | 68 | # for m in self.modules(): 69 | # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 70 | # import scipy.stats as stats 71 | # stddev = m.stddev if hasattr(m, 'stddev') else 0.1 72 | # X = stats.truncnorm(-2, 2, scale=stddev) 73 | # values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 74 | # values = values.view(m.weight.size()) 75 | # with torch.no_grad(): 76 | # m.weight.copy_(values) 77 | # elif isinstance(m, nn.BatchNorm2d): 78 | # nn.init.constant_(m.weight, 1) 79 | # nn.init.constant_(m.bias, 0) 80 | 81 | def forward(self, x): 82 | if self.transform_input: 83 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 84 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 85 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 86 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 87 | # N x 3 x 299 x 299 88 | x = self.Conv2d_1a_3x3(x) 89 | 90 | # CIFAR10 91 | # N x 32 x 149 x 149 92 | # x = self.Conv2d_2a_3x3(x) 93 | # N x 32 x 147 x 147 94 | # x = self.Conv2d_2b_3x3(x) 95 | # N x 64 x 147 x 147 96 | # x = F.max_pool2d(x, kernel_size=3, stride=2) 97 | # N x 64 x 73 x 73 98 | # x = self.Conv2d_3b_1x1(x) 99 | # N x 80 x 73 x 73 100 | # x = self.Conv2d_4a_3x3(x) 101 | # N x 192 x 71 x 71 102 | # x = F.max_pool2d(x, kernel_size=3, stride=2) 103 | # N x 192 x 35 x 35 104 | x = self.Mixed_5b(x) 105 | # N x 256 x 35 x 35 106 | x = self.Mixed_5c(x) 107 | # N x 288 x 35 x 35 108 | x = self.Mixed_5d(x) 109 | # N x 288 x 35 x 35 110 | x = self.Mixed_6a(x) 111 | # N x 768 x 17 x 17 112 | x = self.Mixed_6b(x) 113 | # N x 768 x 17 x 17 114 | x = self.Mixed_6c(x) 115 | # N x 768 x 17 x 17 116 | x = self.Mixed_6d(x) 117 | # N x 768 x 17 x 17 118 | x = self.Mixed_6e(x) 119 | # N x 768 x 17 x 17 120 | if self.training and self.aux_logits: 121 | aux = self.AuxLogits(x) 122 | # N x 768 x 17 x 17 123 | x = self.Mixed_7a(x) 124 | # N x 1280 x 8 x 8 125 | x = 
class InceptionA(nn.Module):
    """Inception block with 1x1, 5x5, double-3x3 and average-pooled branches.

    Every branch preserves the spatial size, so the results are concatenated
    along the channel axis into ``64 + 64 + 96 + pool_features`` channels.

    Args:
        in_channels: channel count of the incoming feature map.
        pool_features: output channels of the 1x1 conv on the pooled branch.
    """

    def __init__(self, in_channels, pool_features):
        super().__init__()
        # 1x1 branch.
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        # 1x1 reduction followed by a padded 5x5.
        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        # 1x1 reduction followed by two padded 3x3 convolutions.
        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        # 1x1 projection applied after average pooling in forward().
        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate them on dim 1."""
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_dbl = self.branch3x3dbl_3(
            self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        )
        out_pool = self.branch_pool(
            F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        )
        return torch.cat((out_1x1, out_5x5, out_dbl, out_pool), 1)
class InceptionC(nn.Module):
    """Inception block built from factorised 7x7 convolutions (1x7 and 7x1).

    Four size-preserving branches, each ending in 192 channels, are
    concatenated into a 768-channel output.

    Args:
        in_channels: channel count of the incoming feature map.
        channels_7x7: width of the intermediate layers in both 7x7 branches.
    """

    def __init__(self, in_channels, channels_7x7):
        super().__init__()
        mid = channels_7x7
        # Shared keyword sets for the two factorised kernel orientations.
        row = {"kernel_size": (1, 7), "padding": (0, 3)}
        col = {"kernel_size": (7, 1), "padding": (3, 0)}

        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        self.branch7x7_1 = BasicConv2d(in_channels, mid, kernel_size=1)
        self.branch7x7_2 = BasicConv2d(mid, mid, **row)
        self.branch7x7_3 = BasicConv2d(mid, 192, **col)

        self.branch7x7dbl_1 = BasicConv2d(in_channels, mid, kernel_size=1)
        self.branch7x7dbl_2 = BasicConv2d(mid, mid, **col)
        self.branch7x7dbl_3 = BasicConv2d(mid, mid, **row)
        self.branch7x7dbl_4 = BasicConv2d(mid, mid, **col)
        self.branch7x7dbl_5 = BasicConv2d(mid, 192, **row)

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate them on dim 1."""
        out_1x1 = self.branch1x1(x)

        out_7x7 = x
        for stage in (self.branch7x7_1, self.branch7x7_2, self.branch7x7_3):
            out_7x7 = stage(out_7x7)

        out_dbl = x
        for stage in (
            self.branch7x7dbl_1,
            self.branch7x7dbl_2,
            self.branch7x7dbl_3,
            self.branch7x7dbl_4,
            self.branch7x7dbl_5,
        ):
            out_dbl = stage(out_dbl)

        out_pool = self.branch_pool(
            F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        )
        return torch.cat((out_1x1, out_7x7, out_dbl, out_pool), 1)
class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Args:
        in_channels: channel count of the feature map fed to the head.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.conv1 = BasicConv2d(128, 768, kernel_size=5)
        self.fc = nn.Linear(768, num_classes)
        # Per-module stddev hints; presumably consumed by a truncated-normal
        # initialiser (the only visible reader is commented out in Inception3).
        self.conv1.stddev = 0.01
        self.fc.stddev = 0.001

    def forward(self, x):
        """Pool, apply both convolutions, global-pool and classify."""
        # 5x5 average pooling with stride 3 shrinks the map before the convs.
        pooled = F.avg_pool2d(x, kernel_size=5, stride=3)
        # 1x1 conv to 128 channels, then an unpadded 5x5 conv to 768 channels.
        feats = self.conv1(self.conv0(pooled))
        # Collapse whatever spatial extent remains to 1x1 and flatten: N x 768.
        feats = torch.flatten(F.adaptive_avg_pool2d(feats, (1, 1)), 1)
        # N x num_classes
        return self.fc(feats)
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone and classifier with CIFAR-10 stride adjustments.

    The stem convolution and the second stage both use stride 1 (stride 2 in
    the reference layout), as marked in the stage table below.

    Args:
        num_classes: number of output logits.
        width_mult: multiplier applied to every stage's channel count; the
            final feature width is only scaled when the multiplier exceeds 1.
    """

    def __init__(self, num_classes=10, width_mult=1.0):
        super().__init__()

        # (expand ratio t, output channels c, repeats n, first-block stride s)
        stage_specs = (
            (1, 16, 1, 1),
            (6, 24, 2, 1),  # Stride 2 -> 1 for CIFAR-10
            (6, 32, 3, 2),
            (6, 64, 4, 2),
            (6, 96, 3, 1),
            (6, 160, 3, 2),
            (6, 320, 1, 1),
        )

        width_in = int(32 * width_mult)
        self.last_channel = int(1280 * max(1.0, width_mult))

        # Stem (CIFAR10: stride 2 -> 1).
        stack = [ConvBNReLU(3, width_in, stride=1)]

        # Inverted residual stages: only the first block of a stage may stride.
        for expand, channels, repeats, first_stride in stage_specs:
            width_out = int(channels * width_mult)
            for idx in range(repeats):
                block_stride = first_stride if idx == 0 else 1
                stack.append(
                    InvertedResidual(
                        width_in, width_out, block_stride, expand_ratio=expand
                    )
                )
                width_in = width_out

        # Final 1x1 expansion to the classifier width.
        stack.append(ConvBNReLU(width_in, self.last_channel, kernel_size=1))
        self.features = nn.Sequential(*stack)

        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        self._init_weights()

    def _init_weights(self):
        """Initialise conv, batch-norm and linear layers in module order."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Extract features, average over both spatial axes, and classify."""
        feats = self.features(x)
        pooled = feats.mean([2, 3])
        return self.classifier(pooled)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation."""
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,  # padding tracks dilation so stride-1 keeps the size
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
class Bottleneck(nn.Module):
    """Three-layer residual block: 1x1 reduce, 3x3, 1x1 expand.

    The output has ``planes * expansion`` channels. Any striding happens in the
    3x3 convolution; ``downsample`` (if given) is applied to the input so the
    shortcut matches the main branch.

    Args:
        inplanes: input channel count.
        planes: base width; the block outputs ``planes * 4`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the shortcut input.
        groups: group count for the 3x3 convolution.
        base_width: per-group width; 64 means the inner width equals ``planes``.
        dilation: dilation of the 3x3 convolution.
        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner = int(planes * (base_width / 64.0)) * groups
        outer = planes * self.expansion

        # Both self.conv2 and self.downsample downsample the input when stride != 1.
        self.conv1 = conv1x1(inplanes, inner)
        self.bn1 = make_norm(inner)
        self.conv2 = conv3x3(inner, inner, stride, groups, dilation)
        self.bn2 = make_norm(inner)
        self.conv3 = conv1x1(inner, outer)
        self.bn3 = make_norm(outer)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the residual branch, add the shortcut, and apply ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
zero_init_residual=False, 140 | groups=1, 141 | width_per_group=64, 142 | replace_stride_with_dilation=None, 143 | norm_layer=None, 144 | ): 145 | super(ResNet, self).__init__() 146 | if norm_layer is None: 147 | norm_layer = nn.BatchNorm2d 148 | self._norm_layer = norm_layer 149 | 150 | self.inplanes = 64 151 | self.dilation = 1 152 | if replace_stride_with_dilation is None: 153 | # each element in the tuple indicates if we should replace 154 | # the 2x2 stride with a dilated convolution instead 155 | replace_stride_with_dilation = [False, False, False] 156 | if len(replace_stride_with_dilation) != 3: 157 | raise ValueError( 158 | "replace_stride_with_dilation should be None " 159 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 160 | ) 161 | self.groups = groups 162 | self.base_width = width_per_group 163 | 164 | # CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1 165 | self.conv1 = nn.Conv2d( 166 | 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False 167 | ) 168 | # END 169 | 170 | self.bn1 = norm_layer(self.inplanes) 171 | self.relu = nn.ReLU(inplace=True) 172 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 173 | self.layer1 = self._make_layer(block, 64, layers[0]) 174 | self.layer2 = self._make_layer( 175 | block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 176 | ) 177 | self.layer3 = self._make_layer( 178 | block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 179 | ) 180 | self.layer4 = self._make_layer( 181 | block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] 182 | ) 183 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 184 | self.fc = nn.Linear(512 * block.expansion, num_classes) 185 | 186 | for m in self.modules(): 187 | if isinstance(m, nn.Conv2d): 188 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 189 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 190 | nn.init.constant_(m.weight, 1) 191 | 
def _resnet(arch, block, layers, pretrained, progress, device, **kwargs):
    """Build a ``ResNet`` and optionally load its bundled checkpoint.

    Args:
        arch: architecture name; the checkpoint is ``state_dicts/<arch>.pt``
            next to this module.
        block: residual block class passed to ``ResNet``.
        layers: number of blocks in each of the four stages.
        pretrained: if True, load the local checkpoint (nothing is downloaded).
        progress: unused; kept for signature compatibility.
        device: ``map_location`` passed to ``torch.load``.
        **kwargs: forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` module.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        # Resolve relative to this file; os.path.join also behaves when
        # dirname(__file__) is empty, where plain concatenation would yield
        # "/state_dicts/...".
        weights_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "state_dicts",
            arch + ".pt",
        )
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)
    return model
class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block used by the CIFAR-10 ResNet.

    Args:
        in_planes: input channel count.
        planes: output channel count.
        stride: stride of the first convolution (and of the shortcut).
        option: how the shortcut is adapted when the shape changes.
            ``"A"`` is parameter-free: spatial subsampling by ``stride`` plus
            zero-padding of the channel axis (the variant the source comments
            attribute to the CIFAR-10 ResNet paper).
            ``"B"`` is a 1x1 convolution followed by batch norm.

    Raises:
        ValueError: if the shortcut must change shape and ``option`` is not
            ``"A"``/``"B"``, or if option ``"A"`` is asked to reduce channels.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option="A"):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the block changes resolution or width.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == "A":
                extra = planes - in_planes
                if extra < 0:
                    raise ValueError(
                        "option 'A' shortcut cannot reduce channels "
                        "({} -> {})".format(in_planes, planes)
                    )
                # Split the missing channels across both ends. For the usual
                # doubling (planes == 2 * in_planes) this is planes // 4 on
                # each side, matching the previous hard-coded padding.
                pad_front = extra // 2
                pad_back = extra - pad_front
                # Subsample by the block's own stride rather than a fixed 2,
                # so the shortcut always matches the main branch's resolution.
                self.shortcut = LambdaLayer(
                    lambda x: F.pad(
                        x[:, :, ::stride, ::stride],
                        (0, 0, 0, 0, pad_front, pad_back),
                        "constant",
                        0,
                    )
                )
            elif option == "B":
                self.shortcut = nn.Sequential(
                    nn.Conv2d(
                        in_planes,
                        self.expansion * planes,
                        kernel_size=1,
                        stride=stride,
                        bias=False,
                    ),
                    nn.BatchNorm2d(self.expansion * planes),
                )
            else:
                # Previously an unknown option silently kept the identity
                # shortcut and failed later in forward() with a shape error.
                raise ValueError(
                    "unknown shortcut option {!r}; expected 'A' or 'B'".format(option)
                )

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the shortcut, then apply ReLU."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
map_location=device 106 | ) 107 | net.load_state_dict(state_dict) 108 | return net 109 | -------------------------------------------------------------------------------- /cifar10_models/vgg.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | __all__ = [ 7 | "VGG", 8 | "vgg11_bn", 9 | "vgg13_bn", 10 | "vgg16_bn", 11 | "vgg19_bn", 12 | ] 13 | 14 | 15 | class VGG(nn.Module): 16 | def __init__(self, features, num_classes=10, init_weights=True): 17 | super(VGG, self).__init__() 18 | self.features = features 19 | # CIFAR 10 (7, 7) to (1, 1) 20 | # self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 21 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 22 | 23 | self.classifier = nn.Sequential( 24 | nn.Linear(512 * 1 * 1, 4096), 25 | # nn.Linear(512 * 7 * 7, 4096), 26 | nn.ReLU(True), 27 | nn.Dropout(), 28 | nn.Linear(4096, 4096), 29 | nn.ReLU(True), 30 | nn.Dropout(), 31 | nn.Linear(4096, num_classes), 32 | ) 33 | if init_weights: 34 | self._initialize_weights() 35 | 36 | def forward(self, x): 37 | x = self.features(x) 38 | x = self.avgpool(x) 39 | x = x.view(x.size(0), -1) 40 | x = self.classifier(x) 41 | return x 42 | 43 | def _initialize_weights(self): 44 | for m in self.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 47 | if m.bias is not None: 48 | nn.init.constant_(m.bias, 0) 49 | elif isinstance(m, nn.BatchNorm2d): 50 | nn.init.constant_(m.weight, 1) 51 | nn.init.constant_(m.bias, 0) 52 | elif isinstance(m, nn.Linear): 53 | nn.init.normal_(m.weight, 0, 0.01) 54 | nn.init.constant_(m.bias, 0) 55 | 56 | 57 | def make_layers(cfg, batch_norm=False): 58 | layers = [] 59 | in_channels = 3 60 | for v in cfg: 61 | if v == "M": 62 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 63 | else: 64 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 65 | if batch_norm: 66 | layers += [conv2d, 
nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 67 | else: 68 | layers += [conv2d, nn.ReLU(inplace=True)] 69 | in_channels = v 70 | return nn.Sequential(*layers) 71 | 72 | 73 | cfgs = { 74 | "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 75 | "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 76 | "D": [ 77 | 64, 78 | 64, 79 | "M", 80 | 128, 81 | 128, 82 | "M", 83 | 256, 84 | 256, 85 | 256, 86 | "M", 87 | 512, 88 | 512, 89 | 512, 90 | "M", 91 | 512, 92 | 512, 93 | 512, 94 | "M", 95 | ], 96 | "E": [ 97 | 64, 98 | 64, 99 | "M", 100 | 128, 101 | 128, 102 | "M", 103 | 256, 104 | 256, 105 | 256, 106 | 256, 107 | "M", 108 | 512, 109 | 512, 110 | 512, 111 | 512, 112 | "M", 113 | 512, 114 | 512, 115 | 512, 116 | 512, 117 | "M", 118 | ], 119 | } 120 | 121 | 122 | def _vgg(arch, cfg, batch_norm, pretrained, progress, device, **kwargs): 123 | if pretrained: 124 | kwargs["init_weights"] = False 125 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 126 | if pretrained: 127 | script_dir = os.path.dirname(__file__) 128 | state_dict = torch.load( 129 | script_dir + "/state_dicts/" + arch + ".pt", map_location=device 130 | ) 131 | model.load_state_dict(state_dict) 132 | return model 133 | 134 | 135 | def vgg11_bn(pretrained=False, progress=True, device="cpu", **kwargs): 136 | """VGG 11-layer model (configuration "A") with batch normalization 137 | 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on ImageNet 140 | progress (bool): If True, displays a progress bar of the download to stderr 141 | """ 142 | return _vgg("vgg11_bn", "A", True, pretrained, progress, device, **kwargs) 143 | 144 | 145 | def vgg13_bn(pretrained=False, progress=True, device="cpu", **kwargs): 146 | """VGG 13-layer model (configuration "B") with batch normalization 147 | 148 | Args: 149 | pretrained (bool): If True, returns a model pre-trained on ImageNet 150 | progress (bool): If True, displays a progress bar of the 
download to stderr 151 | """ 152 | return _vgg("vgg13_bn", "B", True, pretrained, progress, device, **kwargs) 153 | 154 | 155 | def vgg16_bn(pretrained=False, progress=True, device="cpu", **kwargs): 156 | """VGG 16-layer model (configuration "D") with batch normalization 157 | 158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | progress (bool): If True, displays a progress bar of the download to stderr 161 | """ 162 | return _vgg("vgg16_bn", "D", True, pretrained, progress, device, **kwargs) 163 | 164 | 165 | def vgg19_bn(pretrained=False, progress=True, device="cpu", **kwargs): 166 | """VGG 19-layer model (configuration 'E') with batch normalization 167 | 168 | Args: 169 | pretrained (bool): If True, returns a model pre-trained on ImageNet 170 | progress (bool): If True, displays a progress bar of the download to stderr 171 | """ 172 | return _vgg("vgg19_bn", "E", True, pretrained, progress, device, **kwargs) 173 | -------------------------------------------------------------------------------- /clib/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /clib/cos_stat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/L2_Compression/4d28b141930399cfb4990a63a3fae85c23a1b1e5/clib/cos_stat/__init__.py -------------------------------------------------------------------------------- /clib/cos_stat/cos_stat.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | void CosStatKernelLauncher(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f); 5 | void CosStatGradKernelLauncher(const int m , const int n, const float* matrix_data, const 
float* queries_data, const float delta, const float* grad_cdf_f, float* res_grad_matrix_f); 6 | 7 | void cos_stat_forward_cuda( 8 | const at::Tensor sorted_matrix, 9 | const at::Tensor queries, 10 | const float delta, 11 | at::Tensor rescdf_i, 12 | at::Tensor rescdf_f, 13 | at::Tensor respdf_f) 14 | { 15 | CosStatKernelLauncher(sorted_matrix.size(0), queries.size(0), sorted_matrix.data_ptr(), queries.data_ptr(), delta, rescdf_i.data_ptr(), rescdf_f.data_ptr(), respdf_f.data_ptr()); 16 | } 17 | 18 | void cos_stat_backward_cuda( 19 | const at::Tensor matrix, 20 | const at::Tensor queries, 21 | const float delta, 22 | at::Tensor grad_cdf_f, 23 | at::Tensor res_grad_matrix_f) 24 | { 25 | CosStatGradKernelLauncher(matrix.size(0), queries.size(0), matrix.data_ptr(), queries.data_ptr(), delta, grad_cdf_f.data_ptr(), res_grad_matrix_f.data_ptr()); 26 | } 27 | 28 | int lower_bound(const float *array, int size, float key) 29 | { 30 | int first = 0, len = size; 31 | int half, middle; 32 | 33 | while(len > 0){ 34 | half = len >> 1; 35 | middle = first + half; 36 | if(array[middle] < key){ 37 | first = middle + 1; 38 | len = len - half - 1; 39 | } 40 | else{ 41 | len = half; 42 | } 43 | } 44 | return first; 45 | } 46 | 47 | void cos_stat(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f) 48 | { 49 | for (int i = 0; i < n; i++) { 50 | int start = lower_bound(sorted_matrix_data, m, queries_data[i] - delta); 51 | rescdf_i[i] += start; 52 | for (int j = start; j < m; j++) { 53 | if (queries_data[i] > sorted_matrix_data[j] + delta) { 54 | rescdf_f[i] += 1; 55 | } 56 | else if (queries_data[i] < sorted_matrix_data[j] - delta){ 57 | break; 58 | } 59 | else{ 60 | rescdf_f[i] += -0.5 * cos(((queries_data[i] - sorted_matrix_data[j] + delta) * M_PI / delta) / 2) + 0.5; 61 | respdf_f[i] += ((M_PI/delta) / 4) * sin(((queries_data[i] - sorted_matrix_data[j] + delta) * M_PI / delta) / 2); 
62 | } 63 | } 64 | } 65 | } 66 | 67 | void cos_stat_back(const int m , const int n, const float* matrix_data, const float* queries_data, const float delta, const float* grad_cdf_f, float* res_grad_matrix_f) 68 | { 69 | for (int j = 0; j < m; j++) { 70 | int start = lower_bound(queries_data, n, matrix_data[j] - delta); 71 | for (int i = start; i < n; i++) { 72 | if (matrix_data[j] > queries_data[i] + delta){ 73 | continue; 74 | } 75 | else if (matrix_data[j] < queries_data[i] - delta){ 76 | break; 77 | } 78 | else{ 79 | res_grad_matrix_f[j] += -grad_cdf_f[i] * ((M_PI/delta) / 4) * sin(((queries_data[i] - matrix_data[j] + delta) * M_PI / delta) / 2); 80 | 81 | } 82 | } 83 | } 84 | } 85 | 86 | 87 | 88 | 89 | void cos_stat_backward( 90 | const at::Tensor matrix, 91 | const at::Tensor queries, 92 | const float delta, 93 | at::Tensor grad_cdf_f, 94 | at::Tensor res_grad_matrix_f) 95 | { 96 | 97 | const int m = matrix.size(0); 98 | const int n = queries.size(0); 99 | 100 | const float* matrix_data = matrix.data_ptr(); 101 | const float* queries_data = queries.data_ptr(); 102 | float* grad_cdf_f_data = grad_cdf_f.data_ptr(); 103 | float* res_grad_matrix_f_data = res_grad_matrix_f.data_ptr(); 104 | 105 | cos_stat_back(m, n, matrix_data, queries_data, delta, grad_cdf_f_data, res_grad_matrix_f_data); 106 | } 107 | 108 | void cos_stat_forward( 109 | const at::Tensor sorted_matrix, 110 | const at::Tensor queries, 111 | const float delta, 112 | at::Tensor rescdf_i, 113 | at::Tensor rescdf_f, 114 | at::Tensor respdf_f) 115 | { 116 | 117 | const int m = sorted_matrix.size(0); 118 | const int n = queries.size(0); 119 | 120 | const float* sorted_matrix_data = sorted_matrix.data_ptr(); 121 | const float* queries_data = queries.data_ptr(); 122 | int* rescdf_i_data = rescdf_i.data_ptr(); 123 | float* rescdf_f_data = rescdf_f.data_ptr(); 124 | float* respdf_f_data = respdf_f.data_ptr(); 125 | 126 | cos_stat(m, n, sorted_matrix_data, queries_data, delta, rescdf_i_data, rescdf_f_data, 
respdf_f_data); 127 | } 128 | 129 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 130 | m.def("forward", &cos_stat_forward, "cos_stat forward"); 131 | m.def("backward", &cos_stat_backward, "cos_stat backward"); 132 | m.def("forward_cuda", &cos_stat_forward_cuda, "cos_stat forward (CUDA)"); 133 | m.def("backward_cuda", &cos_stat_backward_cuda, "cos_stat backward (CUDA)"); 134 | } 135 | -------------------------------------------------------------------------------- /clib/cos_stat/cos_stat.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | 6 | __device__ 7 | int lower_bound_cu(const float *array, int size, float key) 8 | { 9 | int first = 0, len = size; 10 | int half, middle; 11 | 12 | while(len > 0){ 13 | half = len >> 1; 14 | middle = first + half; 15 | if(array[middle] < key){ 16 | first = middle + 1; 17 | len = len - half - 1; 18 | } 19 | else{ 20 | len = half; 21 | } 22 | } 23 | return first; 24 | } 25 | 26 | __global__ 27 | void CosStatKernel(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f) 28 | { 29 | const int batch=1024; 30 | __shared__ float buf[batch]; 31 | for (int k2=0;k2 buf[j] + delta) { 45 | rescdf_i_i += 1; 46 | } 47 | else if (query_data_i < buf[j] - delta){ 48 | break; 49 | } 50 | else{ 51 | rescdf_f_i += -0.5 * cos(((query_data_i - buf[j] + delta) * M_PI / delta) / 2) + 0.5; 52 | respdf_f_i += ((M_PI/delta) / 4) * sin(((query_data_i - buf[j] + delta) * M_PI / delta) / 2); 53 | } 54 | } 55 | rescdf_i[i] += rescdf_i_i; 56 | rescdf_f[i] += rescdf_f_i; 57 | respdf_f[i] += respdf_f_i; 58 | } 59 | __syncthreads(); 60 | 61 | } 62 | } 63 | 64 | __global__ 65 | void CosStatGradKernel(const int m , const int n, const float* matrix_data, const float* queries_data, const float delta, const float* grad_cdf_f, float* res_grad_matrix_f) 66 | { 67 | const int batch=512; 68 | 
__shared__ float buf1[batch]; 69 | __shared__ float buf2[batch]; 70 | for (int k2=0;k2 buf1[i] + delta){ 83 | continue; 84 | } 85 | else if (matrix_data_j < buf1[i] - delta){ 86 | break; 87 | } 88 | else{ 89 | res_grad_matrix_f_j += -buf2[i] * ((M_PI/delta) / 4) * sin(((buf1[i] - matrix_data_j + delta) * M_PI / delta) / 2); 90 | } 91 | } 92 | res_grad_matrix_f[j] += res_grad_matrix_f_j; 93 | } 94 | __syncthreads(); 95 | } 96 | 97 | } 98 | 99 | void CosStatKernelLauncher(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f) 100 | { 101 | CosStatKernel<<>>(m, n, sorted_matrix_data, queries_data, delta, rescdf_i, rescdf_f, respdf_f); 102 | 103 | // cudaError_t err = cudaGetLastError(); 104 | // if (err != cudaSuccess) 105 | // printf("error in cosstat Output: %s\n", cudaGetErrorString(err)); 106 | } 107 | 108 | void CosStatGradKernelLauncher(const int m , const int n, const float* matrix_data, const float* queries_data, const float delta, const float* grad_cdf_f, float* res_grad_matrix_f) 109 | { 110 | CosStatGradKernel<<>>(m, n, matrix_data, queries_data, delta, grad_cdf_f, res_grad_matrix_f); 111 | 112 | // cudaError_t err = cudaGetLastError(); 113 | // if (err != cudaSuccess) 114 | // printf("error in cosstat Output: %s\n", cudaGetErrorString(err)); 115 | } 116 | 117 | -------------------------------------------------------------------------------- /clib/cos_stat/cos_stat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.cpp_extension import load 5 | 6 | 7 | stat = load(name="stat", 8 | sources=[os.path.join(os.path.split(os.path.abspath(__file__))[0], "cos_stat.cpp"), 9 | os.path.join(os.path.split(os.path.abspath(__file__))[0], "cos_stat.cu")]) 10 | print("stat module loaded") 11 | class CosearStatFun(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, 
matrix_flatten, queries_flatten, delta): 14 | matrix_sort_ind = torch.argsort(matrix_flatten) 15 | queries_sort_ind = torch.argsort(queries_flatten) 16 | # matrix_sort_ind = torch.arange(len(matrix_flatten), device=queries_flatten.device) 17 | # queries_sort_ind = torch.arange(len(matrix_flatten), device=queries_flatten.device) 18 | # batchsize, n, _ = xyz1.size() 19 | # _, m, _ = xyz2.size() 20 | sorted_matrix_flatten = matrix_flatten[matrix_sort_ind].contiguous() 21 | queries_flatten = queries_flatten.contiguous() 22 | # dist1 = torch.zeros(batchsize, n) 23 | # dist2 = torch.zeros(batchsize, m) 24 | # 25 | # idx1 = torch.zeros(batchsize, n, dtype=torch.int) 26 | # idx2 = torch.zeros(batchsize, m, dtype=torch.int) 27 | 28 | rescdf_i = torch.zeros(queries_flatten.shape[0], dtype=torch.int, device=queries_flatten.device) 29 | rescdf_f = torch.zeros(queries_flatten.shape[0], dtype=torch.float, device=queries_flatten.device) 30 | respdf_f = torch.zeros(queries_flatten.shape[0], dtype=torch.float, device=queries_flatten.device) 31 | if not queries_flatten.is_cuda: 32 | stat.forward(sorted_matrix_flatten, queries_flatten, delta if isinstance(delta, float) else delta.item(), rescdf_i, rescdf_f, respdf_f) 33 | else: 34 | stat.forward_cuda(sorted_matrix_flatten, queries_flatten, delta if isinstance(delta, float) else delta.item(), rescdf_i, rescdf_f, respdf_f) 35 | ctx.save_for_backward(matrix_flatten, queries_flatten, torch.tensor(delta, device=queries_flatten.device) if isinstance(delta, float) else delta, queries_sort_ind, respdf_f) 36 | return rescdf_i, rescdf_f 37 | 38 | @staticmethod 39 | def backward(ctx, grad_rescdf_i, grad_rescdf_f): 40 | matrix_flatten, queries_flatten, delta, queries_sort_ind, respdf_f = ctx.saved_tensors 41 | delta = delta.item() 42 | grad_matrix_f = torch.zeros(matrix_flatten.shape[0], dtype=torch.float, device=matrix_flatten.device) 43 | if not grad_rescdf_f.is_cuda: 44 | stat.backward(matrix_flatten.contiguous(), 
queries_flatten[queries_sort_ind].contiguous(), delta, grad_rescdf_f[queries_sort_ind].contiguous(), grad_matrix_f) 45 | else: 46 | stat.backward_cuda(matrix_flatten.contiguous(), queries_flatten[queries_sort_ind].contiguous(), delta, 47 | grad_rescdf_f[queries_sort_ind].contiguous(), grad_matrix_f) 48 | 49 | grad_queries = grad_rescdf_f.flatten() * respdf_f 50 | # print("+++++++++++++++++++++++") 51 | # print(grad_matrix_f.sum()) 52 | # print(grad_queries.sum()) 53 | # print("-----------------------") 54 | 55 | return grad_matrix_f, grad_queries, None 56 | 57 | 58 | class Cosear_Stat(torch.nn.Module): 59 | def __init__(self, delta=None, resolution=None): 60 | super().__init__() 61 | self.delta = delta 62 | self.resolution = resolution 63 | def forward(self, matrix, queries): 64 | if self.resolution: 65 | delta = ((matrix.max() - matrix.min()) / self.resolution).detach() 66 | else: 67 | delta = self.delta 68 | matrix_flatten = matrix.flatten() 69 | queries_flatten = queries.flatten() 70 | rescdf_i, rescdf_f = CosearStatFun.apply(matrix_flatten, queries_flatten, delta) 71 | return rescdf_i.view(queries.shape), rescdf_f.view(queries.shape) 72 | 73 | if __name__ == '__main__': 74 | import argparse 75 | 76 | tmp = (torch.arange(10, dtype=torch.float)/2).cuda().requires_grad_(True) 77 | test = (torch.tensor([5.1])/2).cuda().requires_grad_(True) 78 | rescdf_i, rescdf_f = CosearStatFun.apply(tmp, test, 1.0/2) 79 | rescdf_f.sum().backward() 80 | import pdb 81 | pdb.set_trace() 82 | -------------------------------------------------------------------------------- /clib/cos_stat/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='CosStat', 6 | ext_modules=[ 7 | CUDAExtension('CosStat', [ 8 | "/".join(__file__.split('/')[:-1] + ['cos_stat.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['cos_stat.cu']), 10 | 
]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /clib/lin_stat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/L2_Compression/4d28b141930399cfb4990a63a3fae85c23a1b1e5/clib/lin_stat/__init__.py -------------------------------------------------------------------------------- /clib/lin_stat/lin_stat.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | void LinStatKernelLauncher(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f); 5 | void LinStatGradKernelLauncher(const int m , const int n, const float* matrix_data, const float* queries_data, const float delta, const float* grad_cdf_f, float* res_grad_matrix_f); 6 | 7 | void lin_stat_forward_cuda( 8 | const at::Tensor sorted_matrix, 9 | const at::Tensor queries, 10 | const float delta, 11 | at::Tensor rescdf_i, 12 | at::Tensor rescdf_f, 13 | at::Tensor respdf_f) 14 | { 15 | LinStatKernelLauncher(sorted_matrix.size(0), queries.size(0), sorted_matrix.data_ptr(), queries.data_ptr(), delta, rescdf_i.data_ptr(), rescdf_f.data_ptr(), respdf_f.data_ptr()); 16 | } 17 | 18 | void lin_stat_backward_cuda( 19 | const at::Tensor matrix, 20 | const at::Tensor queries, 21 | const float delta, 22 | at::Tensor grad_cdf_f, 23 | at::Tensor res_grad_matrix_f) 24 | { 25 | LinStatGradKernelLauncher(matrix.size(0), queries.size(0), matrix.data_ptr(), queries.data_ptr(), delta, grad_cdf_f.data_ptr(), res_grad_matrix_f.data_ptr()); 26 | } 27 | 28 | int lower_bound(const float *array, int size, float key) 29 | { 30 | int first = 0, len = size; 31 | int half, middle; 32 | 33 | while(len > 0){ 34 | half = len >> 1; 35 | middle = first + half; 36 | if(array[middle] < key){ 
37 | first = middle + 1; 38 | len = len - half - 1; 39 | } 40 | else{ 41 | len = half; 42 | } 43 | } 44 | return first; 45 | } 46 | 47 | void lin_stat(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f) 48 | { 49 | for (int i = 0; i < n; i++) { 50 | int start = lower_bound(sorted_matrix_data, m, queries_data[i] - delta); 51 | rescdf_i[i] += start; 52 | for (int j = start; j < m; j++) { 53 | if (queries_data[i] > sorted_matrix_data[j] + delta) { 54 | rescdf_f[i] += 1; 55 | } 56 | else if (queries_data[i] < sorted_matrix_data[j] - delta){ 57 | break; 58 | } 59 | else{ 60 | rescdf_f[i] += ((queries_data[i] - sorted_matrix_data[j] + delta) / delta) / 2; 61 | respdf_f[i] += (1 / delta) / 2; 62 | } 63 | } 64 | } 65 | } 66 | 67 | void lin_stat_back(const int m , const int n, const float* matrix_data, const float* queries_data, const float delta, const float* grad_cdf_f, float* res_grad_matrix_f) 68 | { 69 | for (int j = 0; j < m; j++) { 70 | int start = lower_bound(queries_data, n, matrix_data[j] - delta); 71 | for (int i = start; i < n; i++) { 72 | if (matrix_data[j] > queries_data[i] + delta){ 73 | continue; 74 | } 75 | else if (matrix_data[j] < queries_data[i] - delta){ 76 | break; 77 | } 78 | else{ 79 | res_grad_matrix_f[j] += -(grad_cdf_f[i] / delta) / 2; 80 | } 81 | } 82 | } 83 | } 84 | 85 | 86 | 87 | 88 | void lin_stat_backward( 89 | const at::Tensor matrix, 90 | const at::Tensor queries, 91 | const float delta, 92 | at::Tensor grad_cdf_f, 93 | at::Tensor res_grad_matrix_f) 94 | { 95 | 96 | const int m = matrix.size(0); 97 | const int n = queries.size(0); 98 | 99 | const float* matrix_data = matrix.data_ptr(); 100 | const float* queries_data = queries.data_ptr(); 101 | float* grad_cdf_f_data = grad_cdf_f.data_ptr(); 102 | float* res_grad_matrix_f_data = res_grad_matrix_f.data_ptr(); 103 | 104 | lin_stat_back(m, n, matrix_data, queries_data, delta, grad_cdf_f_data, 
res_grad_matrix_f_data); 105 | } 106 | 107 | void lin_stat_forward( 108 | const at::Tensor sorted_matrix, 109 | const at::Tensor queries, 110 | const float delta, 111 | at::Tensor rescdf_i, 112 | at::Tensor rescdf_f, 113 | at::Tensor respdf_f) 114 | { 115 | 116 | const int m = sorted_matrix.size(0); 117 | const int n = queries.size(0); 118 | 119 | const float* sorted_matrix_data = sorted_matrix.data_ptr(); 120 | const float* queries_data = queries.data_ptr(); 121 | int* rescdf_i_data = rescdf_i.data_ptr(); 122 | float* rescdf_f_data = rescdf_f.data_ptr(); 123 | float* respdf_f_data = respdf_f.data_ptr(); 124 | 125 | lin_stat(m, n, sorted_matrix_data, queries_data, delta, rescdf_i_data, rescdf_f_data, respdf_f_data); 126 | } 127 | 128 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 129 | m.def("forward", &lin_stat_forward, "lin_stat forward"); 130 | m.def("backward", &lin_stat_backward, "lin_stat backward"); 131 | m.def("forward_cuda", &lin_stat_forward_cuda, "lin_stat forward (CUDA)"); 132 | m.def("backward_cuda", &lin_stat_backward_cuda, "lin_stat backward (CUDA)"); 133 | } 134 | -------------------------------------------------------------------------------- /clib/lin_stat/lin_stat.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | 6 | __device__ 7 | int lower_bound_cu(const float *array, int size, float key) 8 | { 9 | int first = 0, len = size; 10 | int half, middle; 11 | 12 | while(len > 0){ 13 | half = len >> 1; 14 | middle = first + half; 15 | if(array[middle] < key){ 16 | first = middle + 1; 17 | len = len - half - 1; 18 | } 19 | else{ 20 | len = half; 21 | } 22 | } 23 | return first; 24 | } 25 | 26 | __global__ 27 | void LinStatKernel(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f) 28 | { 29 | const int batch=1024; 30 | __shared__ float buf[batch]; 31 | for (int k2=0;k2 
buf[j] + delta) { 45 | rescdf_i_i += 1; 46 | } 47 | else if (query_data_i < buf[j] - delta){ 48 | break; 49 | } 50 | else{ 51 | rescdf_f_i += ((query_data_i - buf[j] + delta) / delta) / 2; 52 | respdf_f_i += (1 / delta) / 2; 53 | } 54 | } 55 | rescdf_i[i] += rescdf_i_i; 56 | rescdf_f[i] += rescdf_f_i; 57 | respdf_f[i] += respdf_f_i; 58 | } 59 | __syncthreads(); 60 | 61 | } 62 | } 63 | 64 | __global__ 65 | void LinStatGradKernel(const int m , const int n, const float* matrix_data, const float* queries_data, const float delta, const float* grad_cdf_f, float* res_grad_matrix_f) 66 | { 67 | const int batch=512; 68 | __shared__ float buf1[batch]; 69 | __shared__ float buf2[batch]; 70 | for (int k2=0;k2 buf1[i] + delta){ 83 | continue; 84 | } 85 | else if (matrix_data_j < buf1[i] - delta){ 86 | break; 87 | } 88 | else{ 89 | res_grad_matrix_f_j += -(buf2[i] / delta) / 2; 90 | } 91 | } 92 | res_grad_matrix_f[j] += res_grad_matrix_f_j; 93 | } 94 | __syncthreads(); 95 | } 96 | 97 | } 98 | 99 | void LinStatKernelLauncher(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f) 100 | { 101 | LinStatKernel<<>>(m, n, sorted_matrix_data, queries_data, delta, rescdf_i, rescdf_f, respdf_f); 102 | 103 | // cudaError_t err = cudaGetLastError(); 104 | // if (err != cudaSuccess) 105 | // printf("error in linstat Output: %s\n", cudaGetErrorString(err)); 106 | } 107 | 108 | void LinStatGradKernelLauncher(const int m , const int n, const float* matrix_data, const float* queries_data, const float delta, const float* grad_cdf_f, float* res_grad_matrix_f) 109 | { 110 | LinStatGradKernel<<>>(m, n, matrix_data, queries_data, delta, grad_cdf_f, res_grad_matrix_f); 111 | 112 | // cudaError_t err = cudaGetLastError(); 113 | // if (err != cudaSuccess) 114 | // printf("error in linstat Output: %s\n", cudaGetErrorString(err)); 115 | } 116 | 117 | 
-------------------------------------------------------------------------------- /clib/lin_stat/lin_stat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.cpp_extension import load 5 | 6 | 7 | stat = load(name="stat", 8 | sources=[os.path.join(os.path.split(os.path.abspath(__file__))[0], "lin_stat.cpp"), 9 | os.path.join(os.path.split(os.path.abspath(__file__))[0], "lin_stat.cu")]) 10 | print("stat module loaded") 11 | class LinearStatFun(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, matrix_flatten, queries_flatten, delta): 14 | matrix_sort_ind = torch.argsort(matrix_flatten) 15 | queries_sort_ind = torch.argsort(queries_flatten) 16 | # matrix_sort_ind = torch.arange(len(matrix_flatten), device=queries_flatten.device) 17 | # queries_sort_ind = torch.arange(len(matrix_flatten), device=queries_flatten.device) 18 | # batchsize, n, _ = xyz1.size() 19 | # _, m, _ = xyz2.size() 20 | sorted_matrix_flatten = matrix_flatten[matrix_sort_ind].contiguous() 21 | queries_flatten = queries_flatten.contiguous() 22 | # dist1 = torch.zeros(batchsize, n) 23 | # dist2 = torch.zeros(batchsize, m) 24 | # 25 | # idx1 = torch.zeros(batchsize, n, dtype=torch.int) 26 | # idx2 = torch.zeros(batchsize, m, dtype=torch.int) 27 | 28 | rescdf_i = torch.zeros(queries_flatten.shape[0], dtype=torch.int, device=queries_flatten.device) 29 | rescdf_f = torch.zeros(queries_flatten.shape[0], dtype=torch.float, device=queries_flatten.device) 30 | respdf_f = torch.zeros(queries_flatten.shape[0], dtype=torch.float, device=queries_flatten.device) 31 | if not queries_flatten.is_cuda: 32 | stat.forward(sorted_matrix_flatten, queries_flatten, delta if isinstance(delta, float) else delta.item(), rescdf_i, rescdf_f, respdf_f) 33 | else: 34 | stat.forward_cuda(sorted_matrix_flatten, queries_flatten, delta if isinstance(delta, float) else delta.item(), rescdf_i, rescdf_f, respdf_f) 35 | 
ctx.save_for_backward(matrix_flatten, queries_flatten, torch.tensor(delta, device=queries_flatten.device) if isinstance(delta, float) else delta, queries_sort_ind, respdf_f) 36 | return rescdf_i, rescdf_f 37 | 38 | @staticmethod 39 | def backward(ctx, grad_rescdf_i, grad_rescdf_f): 40 | matrix_flatten, queries_flatten, delta, queries_sort_ind, respdf_f = ctx.saved_tensors 41 | delta = delta.item() 42 | grad_matrix_f = torch.zeros(matrix_flatten.shape[0], dtype=torch.float, device=matrix_flatten.device) 43 | if not grad_rescdf_f.is_cuda: 44 | stat.backward(matrix_flatten.contiguous(), queries_flatten[queries_sort_ind].contiguous(), delta, grad_rescdf_f[queries_sort_ind].contiguous(), grad_matrix_f) 45 | else: 46 | stat.backward_cuda(matrix_flatten.contiguous(), queries_flatten[queries_sort_ind].contiguous(), delta, 47 | grad_rescdf_f[queries_sort_ind].contiguous(), grad_matrix_f) 48 | 49 | grad_queries = grad_rescdf_f.flatten() * respdf_f 50 | # print("+++++++++++++++++++++++") 51 | # print(grad_matrix_f.sum()) 52 | # print(grad_queries.sum()) 53 | # print("-----------------------") 54 | 55 | return grad_matrix_f, grad_queries, None 56 | 57 | 58 | class Linear_Stat(torch.nn.Module): 59 | def __init__(self, delta=None, resolution=None): 60 | super().__init__() 61 | self.delta = delta 62 | self.resolution = resolution 63 | def forward(self, matrix, queries): 64 | if self.resolution: 65 | delta = ((matrix.max() - matrix.min()) / self.resolution).detach() 66 | else: 67 | delta = self.delta 68 | matrix_flatten = matrix.flatten() 69 | queries_flatten = queries.flatten() 70 | rescdf_i, rescdf_f = LinearStatFun.apply(matrix_flatten, queries_flatten, delta) 71 | return rescdf_i.view(queries.shape), rescdf_f.view(queries.shape) 72 | 73 | if __name__ == '__main__': 74 | import argparse 75 | 76 | tmp = (torch.arange(10, dtype=torch.float)/2).cuda().requires_grad_(True) 77 | test = (torch.tensor([5.1])/2).cuda().requires_grad_(True) 78 | rescdf_i, rescdf_f = 
LinearStatFun.apply(tmp, test, 1.0/2) 79 | rescdf_f.sum().backward() 80 | import pdb 81 | pdb.set_trace() 82 | -------------------------------------------------------------------------------- /clib/lin_stat/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='LinStat', 6 | ext_modules=[ 7 | CUDAExtension('LinStat', [ 8 | "/".join(__file__.split('/')[:-1] + ['lin_stat.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['lin_stat.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /clib/tri_stat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/L2_Compression/4d28b141930399cfb4990a63a3fae85c23a1b1e5/clib/tri_stat/__init__.py -------------------------------------------------------------------------------- /clib/tri_stat/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='TriStat', 6 | ext_modules=[ 7 | CUDAExtension('TriStat', [ 8 | "/".join(__file__.split('/')[:-1] + ['tri_stat.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['tri_stat.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /clib/tri_stat/tri_stat.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | void TriStatKernelLauncher(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f); 5 | void 
TriStatGradKernelLauncher(const int m , const int n, const float* matrix_data, const float* queries_data, const float delta, const float* grad_cdf_f, float* res_grad_matrix_f); 6 | 7 | void tri_stat_forward_cuda( 8 | const at::Tensor sorted_matrix, 9 | const at::Tensor queries, 10 | const float delta, 11 | at::Tensor rescdf_i, 12 | at::Tensor rescdf_f, 13 | at::Tensor respdf_f) 14 | { 15 | TriStatKernelLauncher(sorted_matrix.size(0), queries.size(0), sorted_matrix.data_ptr(), queries.data_ptr(), delta, rescdf_i.data_ptr(), rescdf_f.data_ptr(), respdf_f.data_ptr()); 16 | } 17 | 18 | void tri_stat_backward_cuda( 19 | const at::Tensor matrix, 20 | const at::Tensor queries, 21 | const float delta, 22 | at::Tensor grad_cdf_f, 23 | at::Tensor res_grad_matrix_f) 24 | { 25 | TriStatGradKernelLauncher(matrix.size(0), queries.size(0), matrix.data_ptr(), queries.data_ptr(), delta, grad_cdf_f.data_ptr(), res_grad_matrix_f.data_ptr()); 26 | } 27 | 28 | int lower_bound(const float *array, int size, float key) 29 | { 30 | int first = 0, len = size; 31 | int half, middle; 32 | 33 | while(len > 0){ 34 | half = len >> 1; 35 | middle = first + half; 36 | if(array[middle] < key){ 37 | first = middle + 1; 38 | len = len - half - 1; 39 | } 40 | else{ 41 | len = half; 42 | } 43 | } 44 | return first; 45 | } 46 | 47 | void tri_stat(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f) 48 | { 49 | for (int i = 0; i < n; i++) { 50 | int start = lower_bound(sorted_matrix_data, m, queries_data[i] - delta); 51 | rescdf_i[i] += start; 52 | for (int j = start; j < m; j++) { 53 | if (queries_data[i] > sorted_matrix_data[j] + delta) { 54 | rescdf_f[i] += 1; 55 | } 56 | else if (queries_data[i] < sorted_matrix_data[j] - delta){ 57 | break; 58 | } 59 | else{ 60 | if (queries_data[i] < sorted_matrix_data[j]){ 61 | rescdf_f[i] += ((queries_data[i] - sorted_matrix_data[j] + delta) * (queries_data[i] - 
sorted_matrix_data[j] + delta) / delta / delta) / 2; 62 | respdf_f[i] += (queries_data[i] - sorted_matrix_data[j] + delta) / delta / delta; 63 | } 64 | else{ 65 | rescdf_f[i] += 1 - ((sorted_matrix_data[j] - queries_data[i] + delta) * (sorted_matrix_data[j] - queries_data[i] + delta) / delta / delta) / 2; 66 | respdf_f[i] += (sorted_matrix_data[j] - queries_data[i] + delta) / delta / delta; 67 | } 68 | 69 | } 70 | } 71 | } 72 | } 73 | 74 | void tri_stat_back(const int m , const int n, const float* matrix_data, const float* queries_data, const float delta, const float* grad_cdf_f, float* res_grad_matrix_f) 75 | { 76 | for (int j = 0; j < m; j++) { 77 | int start = lower_bound(queries_data, n, matrix_data[j] - delta); 78 | for (int i = start; i < n; i++) { 79 | if (matrix_data[j] > queries_data[i] + delta){ 80 | continue; 81 | } 82 | else if (matrix_data[j] < queries_data[i] - delta){ 83 | break; 84 | } 85 | else{ 86 | if (queries_data[i] < matrix_data[j]){ 87 | res_grad_matrix_f[j] += -(grad_cdf_f[i]) * (queries_data[i] - matrix_data[j] + delta) / delta / delta; 88 | } 89 | else{ 90 | res_grad_matrix_f[j] += -(grad_cdf_f[i]) * (matrix_data[j] - queries_data[i] + delta) / delta / delta; 91 | } 92 | 93 | } 94 | } 95 | } 96 | } 97 | 98 | 99 | 100 | 101 | void tri_stat_backward( 102 | const at::Tensor matrix, 103 | const at::Tensor queries, 104 | const float delta, 105 | at::Tensor grad_cdf_f, 106 | at::Tensor res_grad_matrix_f) 107 | { 108 | 109 | const int m = matrix.size(0); 110 | const int n = queries.size(0); 111 | 112 | const float* matrix_data = matrix.data_ptr(); 113 | const float* queries_data = queries.data_ptr(); 114 | float* grad_cdf_f_data = grad_cdf_f.data_ptr(); 115 | float* res_grad_matrix_f_data = res_grad_matrix_f.data_ptr(); 116 | 117 | tri_stat_back(m, n, matrix_data, queries_data, delta, grad_cdf_f_data, res_grad_matrix_f_data); 118 | } 119 | 120 | void tri_stat_forward( 121 | const at::Tensor sorted_matrix, 122 | const at::Tensor queries, 123 | 
const float delta, 124 | at::Tensor rescdf_i, 125 | at::Tensor rescdf_f, 126 | at::Tensor respdf_f) 127 | { 128 | 129 | const int m = sorted_matrix.size(0); 130 | const int n = queries.size(0); 131 | 132 | const float* sorted_matrix_data = sorted_matrix.data_ptr(); 133 | const float* queries_data = queries.data_ptr(); 134 | int* rescdf_i_data = rescdf_i.data_ptr(); 135 | float* rescdf_f_data = rescdf_f.data_ptr(); 136 | float* respdf_f_data = respdf_f.data_ptr(); 137 | 138 | tri_stat(m, n, sorted_matrix_data, queries_data, delta, rescdf_i_data, rescdf_f_data, respdf_f_data); 139 | } 140 | 141 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 142 | m.def("forward", &tri_stat_forward, "tri_stat forward"); 143 | m.def("backward", &tri_stat_backward, "tri_stat backward"); 144 | m.def("forward_cuda", &tri_stat_forward_cuda, "tri_stat forward (CUDA)"); 145 | m.def("backward_cuda", &tri_stat_backward_cuda, "tri_stat backward (CUDA)"); 146 | } 147 | -------------------------------------------------------------------------------- /clib/tri_stat/tri_stat.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | 6 | __device__ 7 | int lower_bound_cu(const float *array, int size, float key) 8 | { 9 | int first = 0, len = size; 10 | int half, middle; 11 | 12 | while(len > 0){ 13 | half = len >> 1; 14 | middle = first + half; 15 | if(array[middle] < key){ 16 | first = middle + 1; 17 | len = len - half - 1; 18 | } 19 | else{ 20 | len = half; 21 | } 22 | } 23 | return first; 24 | } 25 | 26 | __global__ 27 | void TriStatKernel(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f) 28 | { 29 | const int batch=1024; 30 | __shared__ float buf[batch]; 31 | for (int k2=0;k2 buf[j] + delta) { 45 | rescdf_i_i += 1; 46 | } 47 | else if (query_data_i < buf[j] - delta){ 48 | break; 49 | } 50 | else{ 51 | if (query_data_i < 
buf[j]){ 52 | rescdf_f_i += ((query_data_i - buf[j] + delta) * (query_data_i - buf[j] + delta) / delta / delta) / 2; 53 | respdf_f_i += (query_data_i - buf[j] + delta) / delta / delta; 54 | } 55 | else{ 56 | rescdf_f_i += 1 - ((buf[j] - query_data_i + delta) * (buf[j] - query_data_i + delta) / delta / delta) / 2; 57 | respdf_f_i += (buf[j] - query_data_i + delta) / delta / delta; 58 | } 59 | } 60 | } 61 | rescdf_i[i] += rescdf_i_i; 62 | rescdf_f[i] += rescdf_f_i; 63 | respdf_f[i] += respdf_f_i; 64 | } 65 | __syncthreads(); 66 | 67 | } 68 | } 69 | 70 | __global__ 71 | void TriStatGradKernel(const int m , const int n, const float* matrix_data, const float* queries_data, const float delta, const float* grad_cdf_f, float* res_grad_matrix_f) 72 | { 73 | const int batch=512; 74 | __shared__ float buf1[batch]; 75 | __shared__ float buf2[batch]; 76 | for (int k2=0;k2 buf1[i] + delta){ 89 | continue; 90 | } 91 | else if (matrix_data_j < buf1[i] - delta){ 92 | break; 93 | } 94 | else{ 95 | if (buf1[i] < matrix_data_j){ 96 | res_grad_matrix_f_j += -(buf2[i]) * (buf1[i] - matrix_data_j + delta) / delta / delta; 97 | } 98 | else{ 99 | res_grad_matrix_f_j += -(buf2[i]) * (matrix_data_j - buf1[i] + delta) / delta / delta; 100 | } 101 | } 102 | } 103 | res_grad_matrix_f[j] += res_grad_matrix_f_j; 104 | } 105 | __syncthreads(); 106 | } 107 | 108 | } 109 | 110 | void TriStatKernelLauncher(const int m , const int n, const float* sorted_matrix_data, const float* queries_data, const float delta, int* rescdf_i, float* rescdf_f, float* respdf_f) 111 | { 112 | TriStatKernel<<>>(m, n, sorted_matrix_data, queries_data, delta, rescdf_i, rescdf_f, respdf_f); 113 | 114 | // cudaError_t err = cudaGetLastError(); 115 | // if (err != cudaSuccess) 116 | // printf("error in tristat Output: %s\n", cudaGetErrorString(err)); 117 | } 118 | 119 | void TriStatGradKernelLauncher(const int m , const int n, const float* matrix_data, const float* queries_data, const float delta, const float* grad_cdf_f, 
float* res_grad_matrix_f) 120 | { 121 | TriStatGradKernel<<>>(m, n, matrix_data, queries_data, delta, grad_cdf_f, res_grad_matrix_f); 122 | 123 | // cudaError_t err = cudaGetLastError(); 124 | // if (err != cudaSuccess) 125 | // printf("error in tristat Output: %s\n", cudaGetErrorString(err)); 126 | } 127 | 128 | -------------------------------------------------------------------------------- /clib/tri_stat/tri_stat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.cpp_extension import load 5 | 6 | 7 | stat = load(name="stat", 8 | sources=[os.path.join(os.path.split(os.path.abspath(__file__))[0], "tri_stat.cpp"), 9 | os.path.join(os.path.split(os.path.abspath(__file__))[0], "tri_stat.cu")]) 10 | print("stat module loaded") 11 | class TriearStatFun(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, matrix_flatten, queries_flatten, delta): 14 | matrix_sort_ind = torch.argsort(matrix_flatten) 15 | queries_sort_ind = torch.argsort(queries_flatten) 16 | # matrix_sort_ind = torch.arange(len(matrix_flatten), device=queries_flatten.device) 17 | # queries_sort_ind = torch.arange(len(matrix_flatten), device=queries_flatten.device) 18 | # batchsize, n, _ = xyz1.size() 19 | # _, m, _ = xyz2.size() 20 | sorted_matrix_flatten = matrix_flatten[matrix_sort_ind].contiguous() 21 | queries_flatten = queries_flatten.contiguous() 22 | # dist1 = torch.zeros(batchsize, n) 23 | # dist2 = torch.zeros(batchsize, m) 24 | # 25 | # idx1 = torch.zeros(batchsize, n, dtype=torch.int) 26 | # idx2 = torch.zeros(batchsize, m, dtype=torch.int) 27 | 28 | rescdf_i = torch.zeros(queries_flatten.shape[0], dtype=torch.int, device=queries_flatten.device) 29 | rescdf_f = torch.zeros(queries_flatten.shape[0], dtype=torch.float, device=queries_flatten.device) 30 | respdf_f = torch.zeros(queries_flatten.shape[0], dtype=torch.float, device=queries_flatten.device) 31 | if not queries_flatten.is_cuda: 32 | 
stat.forward(sorted_matrix_flatten, queries_flatten, delta if isinstance(delta, float) else delta.item(), rescdf_i, rescdf_f, respdf_f) 33 | else: 34 | stat.forward_cuda(sorted_matrix_flatten, queries_flatten, delta if isinstance(delta, float) else delta.item(), rescdf_i, rescdf_f, respdf_f) 35 | ctx.save_for_backward(matrix_flatten, queries_flatten, torch.tensor(delta, device=queries_flatten.device) if isinstance(delta, float) else delta, queries_sort_ind, respdf_f) 36 | return rescdf_i, rescdf_f 37 | 38 | @staticmethod 39 | def backward(ctx, grad_rescdf_i, grad_rescdf_f): 40 | matrix_flatten, queries_flatten, delta, queries_sort_ind, respdf_f = ctx.saved_tensors 41 | delta = delta.item() 42 | grad_matrix_f = torch.zeros(matrix_flatten.shape[0], dtype=torch.float, device=matrix_flatten.device) 43 | if not grad_rescdf_f.is_cuda: 44 | stat.backward(matrix_flatten.contiguous(), queries_flatten[queries_sort_ind].contiguous(), delta, grad_rescdf_f[queries_sort_ind].contiguous(), grad_matrix_f) 45 | else: 46 | stat.backward_cuda(matrix_flatten.contiguous(), queries_flatten[queries_sort_ind].contiguous(), delta, 47 | grad_rescdf_f[queries_sort_ind].contiguous(), grad_matrix_f) 48 | 49 | grad_queries = grad_rescdf_f.flatten() * respdf_f 50 | # print("+++++++++++++++++++++++") 51 | # print(grad_matrix_f.sum()) 52 | # print(grad_queries.sum()) 53 | # print("-----------------------") 54 | 55 | return grad_matrix_f, grad_queries, None 56 | 57 | 58 | class Triear_Stat(torch.nn.Module): 59 | def __init__(self, delta=None, resolution=None): 60 | super().__init__() 61 | self.delta = delta 62 | self.resolution = resolution 63 | def forward(self, matrix, queries): 64 | if self.resolution: 65 | delta = ((matrix.max() - matrix.min()) / self.resolution).detach() 66 | else: 67 | delta = self.delta 68 | matrix_flatten = matrix.flatten() 69 | queries_flatten = queries.flatten() 70 | rescdf_i, rescdf_f = TriearStatFun.apply(matrix_flatten, queries_flatten, delta) 71 | return 
rescdf_i.view(queries.shape), rescdf_f.view(queries.shape) 72 | 73 | if __name__ == '__main__': 74 | import argparse 75 | 76 | tmp = (torch.arange(10, dtype=torch.float)).requires_grad_(True) 77 | test = (torch.tensor([5.1])).requires_grad_(True) 78 | rescdf_i, rescdf_f = TriearStatFun.apply(tmp, test, 1.0) 79 | rescdf_f.sum().backward() 80 | import pdb 81 | pdb.set_trace() 82 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/L2_Compression/4d28b141930399cfb4990a63a3fae85c23a1b1e5/common/__init__.py -------------------------------------------------------------------------------- /common/tools.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch 3 | 4 | def get_np_size(x): 5 | return x.size * x.itemsize 6 | 7 | def myabs(x): 8 | return torch.where(x==0, x, torch.abs(x)) 9 | 10 | def mysign(x): 11 | return torch.where(x == 0, torch.ones_like(x), torch.sign(x)) 12 | 13 | def to_np(x): 14 | return x.detach().cpu().numpy() -------------------------------------------------------------------------------- /compress/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/L2_Compression/4d28b141930399cfb4990a63a3fae85c23a1b1e5/compress/__init__.py -------------------------------------------------------------------------------- /compress/huffman.py: -------------------------------------------------------------------------------- 1 | import constriction 2 | import numpy as np 3 | from common.tools import get_np_size 4 | 5 | def compress_matrix_flatten(matrix): 6 | ''' 7 | :param matrix: np.array 8 | :return compressed, symtable 9 | ''' 10 | matrix = matrix.flatten() 11 | unique, unique_indices, unique_inverse, unique_counts = np.unique(matrix, 
return_index=True, return_inverse=True, return_counts=True, axis=None) 12 | message = unique_inverse.astype(np.int32) 13 | probabilities = unique_counts.astype(np.float64) / np.sum(unique_counts).astype(np.float64) 14 | encoder = constriction.symbol.QueueEncoder() 15 | encoder_codebook = constriction.symbol.huffman.EncoderHuffmanTree(probabilities) 16 | for symbol in message: 17 | encoder.encode_symbol(symbol, encoder_codebook) 18 | compressed, bitrate = encoder.get_compressed() 19 | # encoder.encode(message, probabilities_model) 20 | # compressed = encoder.get_compressed() 21 | return bitrate, unique 22 | 23 | if __name__ == '__main__': 24 | matrix = np.random.randint(low=0, high=255, size=(128, 128)).astype(np.float32) 25 | compressed, symtable = compress_matrix_flatten(matrix) 26 | print(get_np_size(compressed[0]), get_np_size(symtable)) 27 | -------------------------------------------------------------------------------- /compress/range_coder.py: -------------------------------------------------------------------------------- 1 | import constriction 2 | import numpy as np 3 | from common.tools import get_np_size 4 | 5 | def compress_matrix_flatten(matrix): 6 | ''' 7 | :param matrix: np.array 8 | :return compressed, symtable 9 | ''' 10 | matrix = matrix.flatten() 11 | unique, unique_indices, unique_inverse, unique_counts = np.unique(matrix, return_index=True, return_inverse=True, return_counts=True, axis=None) 12 | if unique.shape[0] == 1: 13 | return unique.astype(np.uint8), unique 14 | message = unique_inverse.astype(np.int32) 15 | probabilities = unique_counts.astype(np.float64) / np.sum(unique_counts).astype(np.float64) 16 | probabilities_model = constriction.stream.model.Categorical(probabilities) 17 | encoder = constriction.stream.queue.RangeEncoder() 18 | encoder.encode(message, probabilities_model) 19 | compressed = encoder.get_compressed() 20 | return compressed, unique 21 | 22 | def compress_matrix_flatten_new(matrix): 23 | ''' 24 | :param matrix: 
np.array 25 | :return compressed, symtable 26 | ''' 27 | matrix = matrix.flatten() 28 | unique, unique_indices, unique_inverse, unique_counts = np.unique(matrix, return_index=True, return_inverse=True, return_counts=True, axis=None) 29 | if unique.shape[0] == 1: 30 | return unique.astype(np.uint8), unique, 0 31 | message = unique_inverse.astype(np.int32) 32 | probabilities = unique_counts.astype(np.float64) / np.sum(unique_counts).astype(np.float64) 33 | probabilities_model = constriction.stream.model.Categorical(probabilities) 34 | encoder = constriction.stream.queue.RangeEncoder() 35 | encoder.encode(message, probabilities_model) 36 | compressed = encoder.get_compressed() 37 | return compressed, unique, probabilities 38 | 39 | if __name__ == '__main__': 40 | matrix = np.random.randint(low=0, high=255, size=(128, 128)).astype(np.float32) 41 | compressed, symtable = compress_matrix_flatten(matrix) 42 | print(get_np_size(compressed), get_np_size(symtable)) 43 | -------------------------------------------------------------------------------- /data_new/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelTC/L2_Compression/4d28b141930399cfb4990a63a3fae85c23a1b1e5/data_new/__init__.py -------------------------------------------------------------------------------- /data_new/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.transforms as T 4 | from torchvision.datasets import ImageFolder, ImageNet 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class Imagenet(): 9 | def __init__(self) -> None: 10 | super().__init__() 11 | self.mean = (0.485, 0.456, 0.406) 12 | self.std = (0.229, 0.224, 0.225) 13 | 14 | def train_dataloader(self): 15 | transform = T.Compose( 16 | [ 17 | T.RandomResizedCrop(224), 18 | T.RandomHorizontalFlip(), 19 | T.ToTensor(), 20 | T.Normalize(self.mean, self.std), 21 | ] 22 
| ) 23 | dataset = ImageFolder(root="", transform=transform) 24 | dataloader = DataLoader( 25 | dataset, 26 | batch_size=64, 27 | num_workers=2, 28 | shuffle=True, 29 | drop_last=True, 30 | pin_memory=True, 31 | ) 32 | return dataloader 33 | 34 | def val_dataloader(self): 35 | transform = T.Compose( 36 | [ 37 | T.Resize(256), 38 | T.CenterCrop(224), 39 | T.ToTensor(), 40 | T.Normalize(self.mean, self.std), 41 | ] 42 | ) 43 | dataset = ImageFolder(root="", transform=transform) 44 | dataloader = DataLoader( 45 | dataset, 46 | batch_size=1, 47 | num_workers=4, 48 | drop_last=False, 49 | pin_memory=False, 50 | ) 51 | return dataloader 52 | 53 | def test_dataloader(self): 54 | return self.val_dataloader() -------------------------------------------------------------------------------- /data_new/imagenet.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import torchvision.transforms as T 5 | # import cv2 6 | from torch.utils.data import Dataset 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | class ImageNetDataset(Dataset): 12 | """ 13 | ImageNet Dataset. 
14 | """ 15 | def __init__(self, root_dir, meta_file, transform): 16 | 17 | self.root_dir = root_dir 18 | self.meta_file = meta_file 19 | self.transform = transform 20 | 21 | with open(meta_file) as f: 22 | lines = f.readlines() 23 | 24 | self.num = len(lines) 25 | self.metas = [] 26 | for line in lines: 27 | filename, label = line.rstrip().split() 28 | self.metas.append({'filename': filename, 'label': label}) 29 | 30 | 31 | def __len__(self): 32 | return self.num 33 | 34 | def _load_meta(self, idx): 35 | 36 | return self.metas[idx] 37 | 38 | def __getitem__(self, idx): 39 | curr_meta = self._load_meta(idx) 40 | filename = os.path.join(self.root_dir, curr_meta['filename']) 41 | label = int(curr_meta['label']) 42 | with Image.open(filename) as img: 43 | img = img.convert('RGB') 44 | if self.transform is not None: 45 | img = self.transform(img) 46 | return img, label 47 | # item = { 48 | # 'image': img, 49 | # 'label': label, 50 | # 'image_id': idx, 51 | # 'filename': filename 52 | # } 53 | # return item 54 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .resnet import BasicBlock, Bottleneck, resnet18, resnet50 # noqa: F401 4 | from .regnet import ResBottleneckBlock, regnetx_600m, regnetx_3200m # noqa: F401 5 | from .mobilenetv2 import InvertedResidual, mobilenetv2 # noqa: F401 6 | from .mnasnet import _InvertedResidual, mnasnet # noqa: F401 7 | from .vision_transformer import * # noqa: F403, F401 8 | from .vision_transformer import MultiHeadAttention, FeedForward, Encoder1DBlock, ViTEmbedding, ViTHead # noqa: F401 9 | # from qdrop.quantization.quantized_module import QuantizedLayer, QuantizedBlock, Quantizer # noqa: F401 10 | 11 | 12 | # class QuantBasicBlock(QuantizedBlock): 13 | # """ 14 | # Implementation of Quantized BasicBlock used in ResNet-18 and ResNet-34. 
15 | # """ 16 | # def __init__(self, org_module: BasicBlock, w_qconfig, a_qconfig, qoutput=True): 17 | # super().__init__() 18 | # self.qoutput = qoutput 19 | # self.conv1_relu = QuantizedLayer(org_module.conv1, org_module.relu1, w_qconfig, a_qconfig) 20 | # self.conv2 = QuantizedLayer(org_module.conv2, None, w_qconfig, a_qconfig, qoutput=False) 21 | # if org_module.downsample is None: 22 | # self.downsample = None 23 | # else: 24 | # self.downsample = QuantizedLayer(org_module.downsample[0], None, w_qconfig, a_qconfig, qoutput=False) 25 | # self.activation = org_module.relu2 26 | # if self.qoutput: 27 | # self.block_post_act_fake_quantize = Quantizer(None, a_qconfig) 28 | 29 | # def forward(self, x): 30 | # residual = x if self.downsample is None else self.downsample(x) 31 | # out = self.conv1_relu(x) 32 | # out = self.conv2(out) 33 | # out += residual 34 | # out = self.activation(out) 35 | # if self.qoutput: 36 | # out = self.block_post_act_fake_quantize(out) 37 | # return out 38 | 39 | 40 | # class QuantBottleneck(QuantizedBlock): 41 | # """ 42 | # Implementation of Quantized Bottleneck Block used in ResNet-50, -101 and -152. 
43 | # """ 44 | # def __init__(self, org_module: Bottleneck, w_qconfig, a_qconfig, qoutput=True): 45 | # super().__init__() 46 | # self.qoutput = qoutput 47 | # self.conv1_relu = QuantizedLayer(org_module.conv1, org_module.relu1, w_qconfig, a_qconfig) 48 | # self.conv2_relu = QuantizedLayer(org_module.conv2, org_module.relu2, w_qconfig, a_qconfig) 49 | # self.conv3 = QuantizedLayer(org_module.conv3, None, w_qconfig, a_qconfig, qoutput=False) 50 | 51 | # if org_module.downsample is None: 52 | # self.downsample = None 53 | # else: 54 | # self.downsample = QuantizedLayer(org_module.downsample[0], None, w_qconfig, a_qconfig, qoutput=False) 55 | # self.activation = org_module.relu3 56 | # if self.qoutput: 57 | # self.block_post_act_fake_quantize = Quantizer(None, a_qconfig) 58 | 59 | # def forward(self, x): 60 | # residual = x if self.downsample is None else self.downsample(x) 61 | # out = self.conv1_relu(x) 62 | # out = self.conv2_relu(out) 63 | # out = self.conv3(out) 64 | # out += residual 65 | # out = self.activation(out) 66 | # if self.qoutput: 67 | # out = self.block_post_act_fake_quantize(out) 68 | # return out 69 | 70 | 71 | # class QuantResBottleneckBlock(QuantizedBlock): 72 | # """ 73 | # Implementation of Quantized Bottleneck Block used in RegNetX (no SE module). 
74 | # """ 75 | # def __init__(self, org_module: ResBottleneckBlock, w_qconfig, a_qconfig, qoutput=True): 76 | # super().__init__() 77 | # self.qoutput = qoutput 78 | # self.conv1_relu = QuantizedLayer(org_module.f.a, org_module.f.a_relu, w_qconfig, a_qconfig) 79 | # self.conv2_relu = QuantizedLayer(org_module.f.b, org_module.f.b_relu, w_qconfig, a_qconfig) 80 | # self.conv3 = QuantizedLayer(org_module.f.c, None, w_qconfig, a_qconfig, qoutput=False) 81 | # if org_module.proj_block: 82 | # self.downsample = QuantizedLayer(org_module.proj, None, w_qconfig, a_qconfig, qoutput=False) 83 | # else: 84 | # self.downsample = None 85 | # self.activation = org_module.relu 86 | # if self.qoutput: 87 | # self.block_post_act_fake_quantize = Quantizer(None, a_qconfig) 88 | 89 | # def forward(self, x): 90 | # residual = x if self.downsample is None else self.downsample(x) 91 | # out = self.conv1_relu(x) 92 | # out = self.conv2_relu(out) 93 | # out = self.conv3(out) 94 | # out += residual 95 | # out = self.activation(out) 96 | # if self.qoutput: 97 | # out = self.block_post_act_fake_quantize(out) 98 | # return out 99 | 100 | 101 | # class QuantInvertedResidual(QuantizedBlock): 102 | # """ 103 | # Implementation of Quantized Inverted Residual Block used in MobileNetV2. 104 | # Inverted Residual does not have activation function. 
105 | # """ 106 | # def __init__(self, org_module: InvertedResidual, w_qconfig, a_qconfig, qoutput=True): 107 | # super().__init__() 108 | # self.qoutput = qoutput 109 | # self.use_res_connect = org_module.use_res_connect 110 | # if org_module.expand_ratio == 1: 111 | # self.conv = nn.Sequential( 112 | # QuantizedLayer(org_module.conv[0], org_module.conv[2], w_qconfig, a_qconfig), 113 | # QuantizedLayer(org_module.conv[3], None, w_qconfig, a_qconfig, qoutput=False), 114 | # ) 115 | # else: 116 | # self.conv = nn.Sequential( 117 | # QuantizedLayer(org_module.conv[0], org_module.conv[2], w_qconfig, a_qconfig), 118 | # QuantizedLayer(org_module.conv[3], org_module.conv[5], w_qconfig, a_qconfig), 119 | # QuantizedLayer(org_module.conv[6], None, w_qconfig, a_qconfig, qoutput=False), 120 | # ) 121 | # if self.qoutput: 122 | # self.block_post_act_fake_quantize = Quantizer(None, a_qconfig) 123 | 124 | # def forward(self, x): 125 | # if self.use_res_connect: 126 | # out = x + self.conv(x) 127 | # else: 128 | # out = self.conv(x) 129 | # if self.qoutput: 130 | # out = self.block_post_act_fake_quantize(out) 131 | # return out 132 | 133 | 134 | # class _QuantInvertedResidual(QuantizedBlock): 135 | # # mnasnet 136 | # def __init__(self, org_module: _InvertedResidual, w_qconfig, a_qconfig, qoutput=True): 137 | # super().__init__() 138 | # self.qoutput = qoutput 139 | # self.apply_residual = org_module.apply_residual 140 | # self.conv = nn.Sequential( 141 | # QuantizedLayer(org_module.layers[0], org_module.layers[2], w_qconfig, a_qconfig), 142 | # QuantizedLayer(org_module.layers[3], org_module.layers[5], w_qconfig, a_qconfig), 143 | # QuantizedLayer(org_module.layers[6], None, w_qconfig, a_qconfig, qoutput=False), 144 | # ) 145 | # if self.qoutput: 146 | # self.block_post_act_fake_quantize = Quantizer(None, a_qconfig) 147 | 148 | # def forward(self, x): 149 | # if self.apply_residual: 150 | # out = x + self.conv(x) 151 | # else: 152 | # out = self.conv(x) 153 | # if 
self.qoutput: 154 | # out = self.block_post_act_fake_quantize(out) 155 | # return out 156 | 157 | 158 | # class QuantViTEmbedding(QuantizedBlock): 159 | # # just remove the output quantization, because we do it after add and layernorm 160 | # def __init__(self, org_module: ViTEmbedding, w_qconfig, a_qconfig, qoutput=False): 161 | # super().__init__() 162 | # self.patch_embedding = QuantizedLayer(org_module.patch_embedding, None, w_qconfig, a_qconfig, qoutput=False) 163 | # self.cls_type = org_module.cls_type 164 | # if self.cls_type == 'token': 165 | # self.pos_embedding = org_module.pos_embedding 166 | # self.cls_token = org_module.cls_token 167 | # elif self.cls_type == 'gap': 168 | # self.pos_embedding = org_module.pos_embedding 169 | 170 | # def forward(self, x): 171 | # x = self.patch_embedding(x) 172 | # x = x.flatten(2).transpose(1, 2) 173 | # # x shape: [B, N, K] 174 | 175 | # if self.cls_type == 'token': 176 | # cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) 177 | # x = torch.cat((cls_tokens, x), dim=1) 178 | # # x shape: [B, N+1, K] 179 | 180 | # x += self.pos_embedding 181 | # return x 182 | 183 | 184 | # class QuantViTHead(QuantizedBlock): 185 | # # just do the input quantization here, because there is layernorm and pool 186 | # def __init__(self, org_module: ViTEmbedding, w_qconfig, a_qconfig, qoutput=False): 187 | # super().__init__() 188 | # self.pool_post_act_fake_quantize = Quantizer(None, a_qconfig) 189 | # self.classifier = QuantizedLayer(org_module.classifier, None, w_qconfig, a_qconfig, qoutput=qoutput) 190 | # self.cls_type = org_module.cls_type 191 | 192 | # def forward(self, x): 193 | # if self.cls_type == 'token': 194 | # x = x[:, 0] 195 | # elif self.cls_type == 'gap': 196 | # x = torch.mean(x, dim=2, keepdim=False) 197 | # x = self.pool_post_act_fake_quantize(x) 198 | # return self.classifier(x) 199 | 200 | 201 | # class QuantViTMHA(QuantizedBlock): 202 | 203 | # def __init__(self, org_module: MultiHeadAttention, w_qconfig, 
a_qconfig, qoutput=True): 204 | # super().__init__() 205 | # self.heads = org_module.heads 206 | # self.scale = org_module.scale 207 | # self.to_qkv = QuantizedLayer(org_module.to_qkv, None, w_qconfig, a_qconfig, qoutput=False) 208 | # self.q_post_act_fake_quantize = Quantizer(None, a_qconfig) 209 | # self.k_post_act_fake_quantize = Quantizer(None, a_qconfig) 210 | # self.attn_post_act_fake_quantize = Quantizer(None, a_qconfig) 211 | # self.v_post_act_fake_quantize = Quantizer(None, a_qconfig) 212 | # self.context_post_act_fake_quantize = Quantizer(None, a_qconfig) 213 | # self.to_out = QuantizedLayer(org_module.to_out, None, w_qconfig, a_qconfig, qoutput=False) 214 | # self.qoutput = qoutput 215 | # if qoutput: 216 | # # in fact, not here, because there is also layernorm 217 | # self.attn_output_post_act_fake_quantize = Quantizer(None, a_qconfig) 218 | 219 | # def forward(self, x): 220 | # B, N, C = x.shape 221 | # qkv = self.to_qkv(x).reshape(B, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 222 | # q, k, v = qkv[0], qkv[1], qkv[2] 223 | # q = self.q_post_act_fake_quantize(q) 224 | # k = self.k_post_act_fake_quantize(k) 225 | # v = self.v_post_act_fake_quantize(v) 226 | # attn = (q @ k.transpose(-2, -1)) * self.scale 227 | # attn = attn.softmax(dim=-1) 228 | # attn = self.attn_post_act_fake_quantize(attn) 229 | # x = (attn @ v).transpose(1, 2).reshape(B, N, C) 230 | # x = self.context_post_act_fake_quantize(x) 231 | # x = self.to_out(x) 232 | # if self.qoutput: 233 | # x = self.attn_output_post_act_fake_quantize(x) 234 | # return x 235 | 236 | 237 | # class QuantViTFFN(QuantizedBlock): 238 | # def __init__(self, org_module: FeedForward, w_qconfig, a_qconfig, qoutput=True): 239 | # super().__init__() 240 | # self.mlp1_act = QuantizedLayer(org_module.mlp1, org_module.act, w_qconfig, a_qconfig, True) 241 | # self.mlp2 = QuantizedLayer(org_module.mlp2, None, w_qconfig, a_qconfig, qoutput=False) 242 | # self.qoutput = qoutput 243 | # if qoutput: 244 | # 
self.mlp2_post_act_fake_quantize = Quantizer(None, a_qconfig) 245 | 246 | # def forward(self, x): 247 | # x = self.mlp1_act(x) 248 | # x = self.mlp2(x) 249 | # if self.qoutput: 250 | # x = self.mlp2_post_act_fake_quantize() 251 | # return x 252 | 253 | 254 | # class QuantEncoder1DBlock(QuantizedBlock): 255 | # def __init__(self, org_module: Encoder1DBlock, w_qconfig, a_qconfig, qoutput=False): 256 | # super().__init__() 257 | # self.norm1 = org_module.norm1 258 | # self.norm1_post_act_fake_quantize = Quantizer(None, a_qconfig) 259 | # self.attention = QuantViTMHA(org_module.attention, w_qconfig, a_qconfig, False) 260 | # self.norm2 = org_module.norm2 261 | # self.norm2_post_act_fake_quantize = Quantizer(None, a_qconfig) 262 | # self.feedforward = QuantViTFFN(org_module.feedforward, w_qconfig, a_qconfig, False) 263 | 264 | # def forward(self, x): 265 | # residual = x 266 | # x = self.norm1(x) 267 | # x = self.norm1_post_act_fake_quantize(x) 268 | # x = self.attention(x) 269 | # x = x + residual 270 | # y = self.norm2(x) 271 | # y = self.norm2_post_act_fake_quantize(y) 272 | # y = self.feedforward(y) 273 | # y += x 274 | # return y 275 | 276 | 277 | # specials = { 278 | # BasicBlock: QuantBasicBlock, 279 | # Bottleneck: QuantBottleneck, 280 | # ResBottleneckBlock: QuantResBottleneckBlock, 281 | # InvertedResidual: QuantInvertedResidual, 282 | # _InvertedResidual: _QuantInvertedResidual, 283 | # Encoder1DBlock: QuantEncoder1DBlock, 284 | # ViTHead: QuantViTHead, 285 | # ViTEmbedding: QuantViTEmbedding 286 | # } 287 | 288 | 289 | # def load_model(config): 290 | # config['kwargs'] = config.get('kwargs', dict()) 291 | # model = eval(config['type'])(**config['kwargs']) 292 | # checkpoint = torch.load(config.path, map_location='cpu') 293 | # if config.type == 'mobilenetv2': 294 | # checkpoint = checkpoint['model'] 295 | # model.load_state_dict(checkpoint) 296 | # return model 297 | -------------------------------------------------------------------------------- 
/model/mnasnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class _InvertedResidual(nn.Module): 6 | 7 | def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor): 8 | super(_InvertedResidual, self).__init__() 9 | assert stride in [1, 2] 10 | assert kernel_size in [3, 5] 11 | mid_ch = in_ch * expansion_factor 12 | self.apply_residual = (in_ch == out_ch and stride == 1) 13 | self.layers = nn.Sequential( 14 | # Pointwise 15 | nn.Conv2d(in_ch, mid_ch, 1, bias=False), 16 | BN(mid_ch), 17 | nn.ReLU(inplace=True), 18 | # Depthwise 19 | nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, 20 | stride=stride, groups=mid_ch, bias=False), 21 | BN(mid_ch), 22 | nn.ReLU(inplace=True), 23 | # Linear pointwise. Note that there's no activation. 24 | nn.Conv2d(mid_ch, out_ch, 1, bias=False), 25 | BN(out_ch)) 26 | 27 | def forward(self, input): 28 | if self.apply_residual: 29 | return self.layers(input) + input 30 | else: 31 | return self.layers(input) 32 | 33 | 34 | def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats): 35 | """ Creates a stack of inverted residuals. """ 36 | assert repeats >= 1 37 | # First one has no skip, because feature map size changes. 38 | first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor) 39 | remaining = [] 40 | for _ in range(1, repeats): 41 | remaining.append( 42 | _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor)) 43 | return nn.Sequential(first, *remaining) 44 | 45 | 46 | def _round_to_multiple_of(val, divisor, round_up_bias=0.9): 47 | """ Asymmetric rounding to make `val` divisible by `divisor`. With default 48 | bias, will round up, unless the number is no more than 10% greater than the 49 | smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. 
""" 50 | assert 0.0 < round_up_bias < 1.0 51 | new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) 52 | return new_val if new_val >= round_up_bias * val else new_val + divisor 53 | 54 | 55 | def _get_depths(scale): 56 | """ Scales tensor depths as in reference MobileNet code, prefers rouding up 57 | rather than down. """ 58 | depths = [32, 16, 24, 40, 80, 96, 192, 320] 59 | return [_round_to_multiple_of(depth * scale, 8) for depth in depths] 60 | 61 | 62 | class MNASNet(torch.nn.Module): 63 | # Version 2 adds depth scaling in the initial stages of the network. 64 | _version = 2 65 | 66 | def __init__(self, scale=2.0, num_classes=1000, dropout=0.0): 67 | super(MNASNet, self).__init__() 68 | 69 | global BN 70 | BN = nn.BatchNorm2d 71 | 72 | assert scale > 0.0 73 | self.scale = scale 74 | self.num_classes = num_classes 75 | depths = _get_depths(scale) 76 | layers = [ 77 | # First layer: regular conv. 78 | nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False), 79 | BN(depths[0]), 80 | nn.ReLU(inplace=True), 81 | # Depthwise separable, no skip. 82 | nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, 83 | groups=depths[0], bias=False), 84 | BN(depths[0]), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d(depths[0], depths[1], 1, 87 | padding=0, stride=1, bias=False), 88 | BN(depths[1]), 89 | # MNASNet blocks: stacks of inverted residuals. 90 | _stack(depths[1], depths[2], 3, 2, 3, 3), 91 | _stack(depths[2], depths[3], 5, 2, 3, 3), 92 | _stack(depths[3], depths[4], 5, 2, 6, 3), 93 | _stack(depths[4], depths[5], 3, 1, 6, 2), 94 | _stack(depths[5], depths[6], 5, 2, 6, 4), 95 | _stack(depths[6], depths[7], 3, 1, 6, 1), 96 | # Final mapping to classifier input. 
def conv_bn(inp, oup, stride):
    """Return a 3x3 conv (padding 1, no bias) followed by BatchNorm2d and ReLU6.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the convolution.
    """
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1,
                     bias=False)
    norm = nn.BatchNorm2d(oup)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, norm, activation)
class MobileNetV2(nn.Module):
    """MobileNetV2 classifier.

    The network is a 3x3 stem convolution, a sequence of inverted-residual
    stages, a 1x1 head convolution, spatial mean pooling and a linear
    classifier.

    Args:
        num_classes: size of the classifier output.
        input_size: expected input resolution; it is only checked to be a
            multiple of 32 and is not otherwise used.
        width_mult: multiplier applied to the channel count of the stem and of
            every stage. The 1280-channel head is widened when
            ``width_mult > 1.0`` but never shrunk.
        dropout: dropout probability applied before the final linear layer.
    """

    def __init__(self, num_classes=1000, input_size=224, width_mult=1., dropout=0.0):
        super().__init__()
        assert input_size % 32 == 0

        # One row per stage:
        # (expansion ratio t, base channels c, repeats n, first-block stride s)
        stage_settings = (
            (1, 16, 1, 1),
            (6, 24, 2, 2),
            (6, 32, 3, 2),
            (6, 64, 4, 2),
            (6, 96, 3, 1),
            (6, 160, 3, 2),
            (6, 320, 1, 1),
        )

        in_ch = int(32 * width_mult)
        self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280

        # Stem.
        layers = [conv_bn(3, in_ch, 2)]
        # Inverted-residual stages: only the first block of a stage strides.
        for t, c, n, s in stage_settings:
            out_ch = int(c * width_mult)
            for idx in range(n):
                block_stride = s if idx == 0 else 1
                layers.append(
                    InvertedResidual(in_ch, out_ch, block_stride, expand_ratio=t))
                in_ch = out_ch
        # 1x1 head feeding the classifier.
        layers.append(conv_1x1_bn(in_ch, self.last_channel))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.last_channel, num_classes),
        )

        self._initialize_weights()

    def forward(self, x):
        """Return class scores for a batch of images."""
        feature_map = self.features(x)
        # Averaging over dims 2 and 3 pools away the two trailing (spatial)
        # dimensions of the feature map.
        pooled = feature_map.mean([2, 3])
        return self.classifier(pooled)

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # Normal init scaled by fan-out (kernel area * out channels).
                kernel_h, kernel_w = module.kernel_size
                fan_out = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()
132 | """ 133 | model = MobileNetV2(**kwargs) 134 | return model 135 | -------------------------------------------------------------------------------- /model/regnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import math 4 | 5 | regnetX_200M_config = {'WA': 36.44, 'W0': 24, 'WM': 2.49, 'DEPTH': 13, 'GROUP_W': 8, 'SE_ON': False} 6 | regnetX_400M_config = {'WA': 24.48, 'W0': 24, 'WM': 2.54, 'DEPTH': 22, 'GROUP_W': 16, 'SE_ON': False} 7 | regnetX_600M_config = {'WA': 36.97, 'W0': 48, 'WM': 2.24, 'DEPTH': 16, 'GROUP_W': 24, 'SE_ON': False} 8 | regnetX_800M_config = {'WA': 35.73, 'W0': 56, 'WM': 2.28, 'DEPTH': 16, 'GROUP_W': 16, 'SE_ON': False} 9 | regnetX_1600M_config = {'WA': 34.01, 'W0': 80, 'WM': 2.25, 'DEPTH': 18, 'GROUP_W': 24, 'SE_ON': False} 10 | regnetX_3200M_config = {'WA': 26.31, 'W0': 88, 'WM': 2.25, 'DEPTH': 25, 'GROUP_W': 48, 'SE_ON': False} 11 | regnetX_4000M_config = {'WA': 38.65, 'W0': 96, 'WM': 2.43, 'DEPTH': 23, 'GROUP_W': 40, 'SE_ON': False} 12 | regnetX_6400M_config = {'WA': 60.83, 'W0': 184, 'WM': 2.07, 'DEPTH': 17, 'GROUP_W': 56, 'SE_ON': False} 13 | regnetY_200M_config = {'WA': 36.44, 'W0': 24, 'WM': 2.49, 'DEPTH': 13, 'GROUP_W': 8, 'SE_ON': True} 14 | regnetY_400M_config = {'WA': 27.89, 'W0': 48, 'WM': 2.09, 'DEPTH': 16, 'GROUP_W': 8, 'SE_ON': True} 15 | regnetY_600M_config = {'WA': 32.54, 'W0': 48, 'WM': 2.32, 'DEPTH': 15, 'GROUP_W': 16, 'SE_ON': True} 16 | regnetY_800M_config = {'WA': 38.84, 'W0': 56, 'WM': 2.4, 'DEPTH': 14, 'GROUP_W': 16, 'SE_ON': True} 17 | regnetY_1600M_config = {'WA': 20.71, 'W0': 48, 'WM': 2.65, 'DEPTH': 27, 'GROUP_W': 24, 'SE_ON': True} 18 | regnetY_3200M_config = {'WA': 42.63, 'W0': 80, 'WM': 2.66, 'DEPTH': 21, 'GROUP_W': 24, 'SE_ON': True} 19 | regnetY_4000M_config = {'WA': 31.41, 'W0': 96, 'WM': 2.24, 'DEPTH': 22, 'GROUP_W': 64, 'SE_ON': True} 20 | regnetY_6400M_config = {'WA': 33.22, 'W0': 112, 'WM': 2.27, 'DEPTH': 25, 
class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block.

    The input is pooled to 1x1, passed through a two-layer 1x1-conv
    bottleneck ending in a sigmoid, and the resulting gate multiplies the
    input.

    Args:
        w_in: number of channels of the input (and of the produced gate).
        w_se: number of channels inside the bottleneck.
    """

    def __init__(self, w_in, w_se):
        super().__init__()
        self._construct(w_in, w_se)

    def _construct(self, w_in, w_se):
        # Squeeze: global average pool down to a 1x1 map.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Excite: reduce -> ReLU -> expand -> sigmoid, all as 1x1 convs.
        reduce_conv = nn.Conv2d(w_in, w_se, kernel_size=1, bias=True)
        expand_conv = nn.Conv2d(w_se, w_in, kernel_size=1, bias=True)
        self.f_ex = nn.Sequential(
            reduce_conv,
            nn.ReLU(inplace=True),
            expand_conv,
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale ``x`` by the gate computed from its pooled summary."""
        gate = self.f_ex(self.avg_pool(x))
        return x * gate
class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: relu(shortcut(x) + F(x)).

    ``F`` is a ``BottleneckTransform``. The shortcut is the identity unless
    the block changes the channel count or the stride, in which case it is a
    strided 1x1 convolution followed by a norm layer.

    Args:
        w_in: number of input channels.
        w_out: number of output channels.
        stride: stride of the block.
        bm: bottleneck multiplier forwarded to the transform.
        gw: group width forwarded to the transform.
        se_r: squeeze-and-excitation ratio forwarded to the transform
            (``None`` disables SE there).
    """

    def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
        super().__init__()
        self._construct(w_in, w_out, stride, bm, gw, se_r)

    def _add_skip_proj(self, w_in, w_out, stride):
        """Create the projection shortcut (1x1 conv + norm)."""
        self.proj = nn.Conv2d(w_in, w_out, kernel_size=1, stride=stride,
                              padding=0, bias=False)
        self.bn = BN(w_out)

    def _construct(self, w_in, w_out, stride, bm, gw, se_r):
        # A projection is required whenever the output shape differs from the
        # input shape; otherwise the input itself is the shortcut.
        needs_projection = not (w_in == w_out and stride == 1)
        self.proj_block = needs_projection
        if needs_projection:
            self._add_skip_proj(w_in, w_out, stride)
        self.f = BottleneckTransform(w_in, w_out, stride, bm, gw, se_r)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        """Return the activated sum of the shortcut and the transform."""
        shortcut = self.bn(self.proj(x)) if self.proj_block else x
        return self.relu(shortcut + self.f(x))
class AnyStage(nn.Module):
    """AnyNet stage: ``d`` blocks applied one after another.

    Blocks are registered as children named ``b1`` ... ``b<d>``. Only the
    first block receives ``w_in`` and ``stride``; every later block maps
    ``w_out`` to ``w_out`` with stride 1.

    Args:
        w_in: input width of the stage.
        w_out: output width of every block in the stage.
        stride: stride of the first block.
        d: number of blocks.
        block_fun: callable invoked as
            ``block_fun(w_in, w_out, stride, bm, gw, se_r)`` to build a block.
        bm, gw, se_r: passed through unchanged to ``block_fun``.
    """

    def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        super().__init__()
        self._construct(w_in, w_out, stride, d, block_fun, bm, gw, se_r)

    def _construct(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
        block_w_in, block_stride = w_in, stride
        for position in range(1, d + 1):
            new_block = block_fun(block_w_in, w_out, block_stride, bm, gw, se_r)
            self.add_module("b{}".format(position), new_block)
            # Everything after the first block keeps width and resolution.
            block_w_in, block_stride = w_out, 1

    def forward(self, x):
        """Run the input through the child blocks in registration order."""
        out = x
        for stage_block in self.children():
            out = stage_block(out)
        return out
def get_stages_from_blocks(ws, rs):
    """Collapse per-block values into per-stage widths and depths.

    A new stage starts at every block whose ``(w, r)`` pair differs from the
    previous block's pair.

    Args:
        ws: list with one width per block.
        rs: list with one secondary per-block value, compared alongside
            ``ws`` to detect stage boundaries.

    Returns:
        ``(s_ws, s_ds)``: the width of each stage and the number of blocks in
        each stage.
    """
    # Shift each list by one against itself, padding with a 0 sentinel, so
    # every block is compared with its predecessor. The extra trailing entry
    # marks the end of the last stage.
    current_ws, previous_ws = ws + [0], [0] + ws
    current_rs, previous_rs = rs + [0], [0] + rs
    is_boundary = []
    for w, w_prev, r, r_prev in zip(current_ws, previous_ws,
                                    current_rs, previous_rs):
        is_boundary.append(not (w == w_prev and r == r_prev))

    stage_widths = [ws[i] for i, flag in enumerate(is_boundary[:-1]) if flag]
    boundary_positions = [i for i, flag in enumerate(is_boundary) if flag]
    # Distance between consecutive boundaries == number of blocks per stage.
    stage_depths = np.diff(boundary_positions).tolist()
    return stage_widths, stage_depths
class RegNet(AnyNet):
    """RegNet model built from a RegNet configuration dict.

    Based on the paper "Designing Network Design Spaces".

    Args:
        cfg: dict with the keys ``'WA'``, ``'W0'``, ``'WM'``, ``'DEPTH'``,
            ``'GROUP_W'`` and ``'SE_ON'`` (see the ``regnet*_config``
            constants in this module).
        bn: accepted for call compatibility only; this implementation never
            reads it.
        num_classes: size of the classifier output. Defaults to 1000, the
            value that was previously hard-coded.
    """

    def __init__(self, cfg, bn=None, num_classes=1000):
        # Per-block widths from the RegNet parameterisation.
        b_ws, num_s, _, _ = generate_regnet(
            cfg['WA'], cfg['W0'], cfg['WM'], cfg['DEPTH']
        )
        # Convert to per-stage widths and depths.
        ws, ds = get_stages_from_blocks(b_ws, b_ws)
        # Every stage shares the group width, bottleneck multiplier 1 and
        # stride 2.
        gws = [cfg['GROUP_W'] for _ in range(num_s)]
        bms = [1 for _ in range(num_s)]
        # Make stage widths compatible with the group widths.
        ws, gws = adjust_ws_gs_comp(ws, bms, gws)
        ss = [2 for _ in range(num_s)]
        # SE is used with ratio 0.25 when the config enables it.
        se_r = 0.25 if cfg['SE_ON'] else None

        super(RegNet, self).__init__(
            stem_w=32,
            ss=ss,
            ds=ds,
            ws=ws,
            bms=bms,
            gws=gws,
            se_r=se_r,
            nc=num_classes,
        )
def regnety_400m(**kwargs):
    """Construct a RegNet-Y model under 400M FLOPs.

    Keyword arguments are forwarded to ``RegNet`` together with
    ``regnetY_400M_config``.
    """
    return RegNet(regnetY_400M_config, **kwargs)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of convolution groups.
        dilation: dilation rate; also used as the padding.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.

    The block outputs ``planes * expansion`` channels. The 3x3 convolution
    carries the stride, groups and dilation.

    Args:
        inplanes: number of input channels.
        planes: base channel count of the block.
        stride: stride applied by the 3x3 convolution.
        downsample: optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
        groups: number of groups of the 3x3 convolution.
        base_width: width per group; the inner width is
            ``int(planes * (base_width / 64.)) * groups``.
        dilation: dilation of the 3x3 convolution.
        norm_layer: callable building a norm layer from a channel count.
            When ``None``, the module-level ``BN`` is used if it is defined,
            otherwise ``nn.BatchNorm2d``.
    """

    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            # ``BN`` only becomes a module global once a ``ResNet`` has been
            # constructed; fall back to the stock batch norm so the block can
            # be built standalone instead of raising NameError.
            norm_layer = globals().get('BN', nn.BatchNorm2d)
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)

        return out
replace_stride_with_dilation is None: 134 | # each element in the tuple indicates if we should replace 135 | # the 2x2 stride with a dilated convolution instead 136 | replace_stride_with_dilation = [False, False, False] 137 | if len(replace_stride_with_dilation) != 3: 138 | raise ValueError("replace_stride_with_dilation should be None " 139 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 140 | self.groups = groups 141 | self.base_width = width_per_group 142 | 143 | if self.deep_stem: 144 | self.conv1 = nn.Sequential( 145 | nn.Conv2d(3, 32, kernel_size=3, stride=2, 146 | padding=1, bias=False), 147 | norm_layer(32), 148 | nn.ReLU(inplace=True), 149 | nn.Conv2d(32, 32, kernel_size=3, stride=1, 150 | padding=1, bias=False), 151 | norm_layer(32), 152 | nn.ReLU(inplace=True), 153 | nn.Conv2d(32, 64, kernel_size=3, stride=1, 154 | padding=1, bias=False), 155 | ) 156 | else: 157 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, 158 | stride=2, padding=3, bias=False) 159 | 160 | self.bn1 = norm_layer(self.inplanes) 161 | self.relu = nn.ReLU(inplace=True) 162 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 163 | self.layer1 = self._make_layer(block, 64, layers[0]) 164 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 165 | dilate=replace_stride_with_dilation[0]) 166 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 167 | dilate=replace_stride_with_dilation[1]) 168 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 169 | dilate=replace_stride_with_dilation[2]) 170 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 171 | self.fc = nn.Linear(512 * block.expansion, num_classes) 172 | 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | nn.init.kaiming_normal_( 176 | m.weight, mode='fan_out', nonlinearity='relu') 177 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 178 | nn.init.constant_(m.weight, 1) 179 | nn.init.constant_(m.bias, 0) 180 | 181 | # Zero-initialize the last BN 
in each residual branch, 182 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 183 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 184 | if zero_init_residual: 185 | for m in self.modules(): 186 | if isinstance(m, Bottleneck): 187 | nn.init.constant_(m.bn3.weight, 0) 188 | elif isinstance(m, BasicBlock): 189 | nn.init.constant_(m.bn2.weight, 0) 190 | 191 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 192 | norm_layer = self._norm_layer 193 | downsample = None 194 | previous_dilation = self.dilation 195 | if dilate: 196 | self.dilation *= stride 197 | stride = 1 198 | if stride != 1 or self.inplanes != planes * block.expansion: 199 | if self.avg_down: 200 | downsample = nn.Sequential( 201 | nn.AvgPool2d(stride, stride=stride, 202 | ceil_mode=True, count_include_pad=False), 203 | conv1x1(self.inplanes, planes * block.expansion), 204 | norm_layer(planes * block.expansion), 205 | ) 206 | else: 207 | downsample = nn.Sequential( 208 | conv1x1(self.inplanes, planes * block.expansion, stride), 209 | norm_layer(planes * block.expansion), 210 | ) 211 | 212 | layers = [] 213 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 214 | self.base_width, previous_dilation, norm_layer)) 215 | self.inplanes = planes * block.expansion 216 | for _ in range(1, blocks): 217 | layers.append(block(self.inplanes, planes, groups=self.groups, 218 | base_width=self.base_width, dilation=self.dilation, 219 | norm_layer=norm_layer)) 220 | 221 | return nn.Sequential(*layers) 222 | 223 | def _forward_impl(self, x): 224 | # See note [TorchScript super()] 225 | x = self.conv1(x) 226 | x = self.bn1(x) 227 | x = self.relu(x) 228 | x = self.maxpool(x) 229 | x = self.layer1(x) 230 | x = self.layer2(x) 231 | x = self.layer3(x) 232 | x = self.layer4(x) 233 | 234 | x = self.avgpool(x) 235 | x = torch.flatten(x, 1) 236 | x = self.fc(x) 237 | 238 | return x 239 | 240 | 
def forward(self, x):
    """Forward pass of ``ResNet``; delegates to ``_forward_impl``."""
    return self._forward_impl(x)


def resnet18(**kwargs):
    """Build a ResNet-18 (``BasicBlock``, stage depths 2-2-2-2)."""
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34(**kwargs):
    """Build a ResNet-34 (``BasicBlock``, stage depths 3-4-6-3)."""
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50(**kwargs):
    """Build a ResNet-50 (``Bottleneck``, stage depths 3-4-6-3)."""
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    """Build a ResNet-101 (``Bottleneck``, stage depths 3-4-23-3)."""
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def resnet152(**kwargs):
    """Build a ResNet-152 (``Bottleneck``, stage depths 3-8-36-3)."""
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)


def resnext50_32x4d(**kwargs):
    """Build a ResNeXt-50; ``groups`` and ``width_per_group`` are forced to 32 / 4."""
    kwargs.update(groups=32, width_per_group=4)
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnext101_32x8d(**kwargs):
    """Build a ResNeXt-101; ``groups`` and ``width_per_group`` are forced to 32 / 8."""
    kwargs.update(groups=32, width_per_group=8)
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def wide_resnet50_2(**kwargs):
    """Build a wide ResNet-50; ``width_per_group`` is forced to ``64 * 2``."""
    kwargs.update(width_per_group=64 * 2)
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def wide_resnet101_2(**kwargs):
    """Build a wide ResNet-101; ``width_per_group`` is forced to ``64 * 2``."""
    kwargs.update(width_per_group=64 * 2)
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
def cal_centric(scale, clip_min, clip_max):
    """Return an evenly spaced grid of quantization centres.

    The grid has step ``scale`` and starts one step below ``clip_min``; its
    length is derived from ``(clip_max - clip_min) / scale``.

    Args:
        scale: tensor holding the step size (its device is used for the grid).
        clip_min: lower clipping value.
        clip_max: upper clipping value.
    """
    span_steps = ((clip_max - clip_min) / scale).item()
    steps = torch.arange(-1, span_steps + 2, 1, device=scale.device)
    return steps * scale + clip_min


def distance_metric(matrix, centric):
    """Return the pairwise absolute differences ``|matrix[i] - centric[j]|``."""
    return torch.abs(matrix[:, None] - centric[None, :])


def get_cdf(matrix, x, resolution):
    """Evaluate a histogram-based, piecewise-linear CDF of ``matrix`` at ``x``.

    ``matrix`` is binned with width ``(max - min) / resolution``. For each
    query the un-normalised CDF is returned in two parts: the count of all
    bins entirely below the query, and the linearly interpolated share of the
    bin the query falls into.

    Args:
        matrix: tensor whose empirical distribution is estimated.
        x: tensor of query points.
        resolution: number of bins spanning the value range of ``matrix``.

    Returns:
        ``(cdf_int, cdf_frac, total)`` where ``cdf_int`` (long) and
        ``cdf_frac`` (same dtype as ``x``) have the shape of ``x`` and
        ``total`` is the number of histogram entries.
    """
    # NOTE(review): a constant ``matrix`` gives interval == 0 and the division
    # below is then undefined -- callers presumably never pass one; confirm.
    interval = (matrix.max() - matrix.min()) / resolution
    bins = torch.round(matrix / interval).flatten().long()
    first_bin = bins.min()
    bins = bins - first_bin
    last_bin = bins.max()

    # Slot 0 is an empty guard bin so that the CDF starts at zero.
    counts = torch.zeros(last_bin + 2, device=matrix.device, dtype=torch.float32)
    spline_pmf = torch.scatter_add(
        counts, 0, bins + 1, torch.ones_like(bins).float()).long()
    # spline_x[j] is the right edge of histogram slot j.
    spline_x = (torch.arange(start=-1, end=last_bin + 1, device=matrix.device,
                             dtype=torch.float32) + first_bin) * interval + 0.5 * interval
    spline_cdf = torch.cumsum(spline_pmf, dim=0)
    total = spline_pmf.sum()

    knot = torch.searchsorted(spline_x, x)
    below = knot - 1 < 0
    above = knot >= spline_cdf.shape[0]
    inside = ~(below | above)

    res_cdf_int = torch.ones_like(x).long()
    res_cdf_int[below] = 0
    res_cdf_int[above] = total
    res_cdf_int[inside] = spline_cdf[knot[inside] - 1]

    # Fraction of the containing bin that lies below the query point.
    barycentric = (x[inside] - spline_x[knot[inside] - 1]) / interval
    res_cdf_float = torch.zeros_like(x)
    res_cdf_float[inside] = barycentric * spline_pmf[knot[inside]].float()

    return res_cdf_int.view(x.shape), res_cdf_float.view(x.shape), total


def get_cdf_new(matrix, x, resolution):
    """Evaluate the CDF of ``matrix`` at ``x`` with the module-level ``linstat`` op.

    Returns ``(cdf_int, cdf_frac, total)`` with ``total = matrix.numel()``.
    Note that this overwrites ``linstat.resolution``.
    """
    linstat.resolution = resolution
    rescdf_i, rescdf_f = linstat(matrix, x)
    return rescdf_i, rescdf_f, torch.tensor(matrix.numel(), device=matrix.device)


def aun(matrix, delta):
    """Add uniform noise drawn from ``[-delta / 2, delta / 2)`` to ``matrix``."""
    return matrix + (torch.rand_like(matrix) - 0.5) * delta


def ste(matrix, scale):
    """Round ``matrix`` to multiples of ``scale`` with a straight-through gradient."""
    if torch.all(matrix == 0):
        return torch.zeros_like(matrix)
    scaled = matrix / scale
    # Forward value is round(scaled); the gradient is that of ``scaled``.
    rounded = torch.round(scaled) - scaled.detach() + scaled
    return rounded * scale


def quant(matrix, scale):
    """Round ``matrix`` to the nearest multiple of ``scale`` (no gradient trick)."""
    if torch.all(matrix == 0):
        return torch.zeros_like(matrix)
    return torch.round(matrix / scale) * scale


def get_bitrate_quant(matrix, x_tilde, delta, resolution):
    """Estimate the total code length (bits) of ``x_tilde`` under ``matrix``'s CDF.

    Each symbol gets the probability mass of ``[x - delta/2, x + delta/2]``
    from :func:`get_cdf`; the mean code length is scaled by the original
    element count. Tensors with more than ``128 * 128`` elements are
    randomly subsampled (1/256) first.

    Returns:
        A scalar tensor, or ``zeros_like(x_tilde)`` when ``x_tilde`` is all zero.

    Raises:
        ValueError: if any symbol receives zero probability mass.
    """
    if torch.all(x_tilde == 0):
        return torch.zeros_like(x_tilde)

    original_num = matrix.numel()
    if original_num > 128 * 128:
        sample_num = original_num // 256
        sample_idx = torch.randperm(original_num)[:sample_num]
        matrix = matrix.flatten()[sample_idx]
        x_tilde = x_tilde.flatten()[sample_idx]

    probi1, probf1, corr1 = get_cdf(matrix, x_tilde + 0.5 * delta, resolution=resolution)
    probi2, probf2, _ = get_cdf(matrix, x_tilde - 0.5 * delta, resolution=resolution)
    prob = (probi1 - probi2).float() + (probf1 - probf2)

    empty = torch.sum(prob == 0).item()
    if empty:
        raise ValueError(
            "{} quantized value(s) received zero probability mass".format(empty))

    return (-(torch.log2(prob) - torch.log2(corr1.float()))).mean() * original_num
def get_bitrate_quant_new(matrix, x_tilde, delta, resolution):
    """Estimate the total code length (bits) of ``x_tilde`` using ``get_cdf_new``.

    Same scheme as ``get_bitrate_quant`` but backed by ``get_cdf_new`` and
    with a 1/128 subsampling rate for tensors above ``128 * 128`` elements.

    Returns:
        A scalar tensor, or ``zeros_like(x_tilde)`` when ``x_tilde`` is all zero.

    Raises:
        ValueError: if the log-probability of any symbol is not finite.
            (Previously this dropped into ``pdb``, which cannot work in a
            non-interactive training job.)
    """
    if torch.all(x_tilde == 0):
        return torch.zeros_like(x_tilde)

    original_num = matrix.numel()
    if original_num > 128 * 128:
        sample_num = original_num // 128
        sample_idx = torch.randperm(original_num)[:sample_num]
        matrix = matrix.flatten()[sample_idx]
        x_tilde = x_tilde.flatten()[sample_idx]

    probi1, probf1, corr1 = get_cdf_new(matrix, x_tilde + 0.5 * delta, resolution=resolution)
    probi2, probf2, _ = get_cdf_new(matrix, x_tilde - 0.5 * delta, resolution=resolution)
    prob = (probi1 - probi2).float() + (probf1 - probf2)

    log_prob = torch.log2(prob)
    bad = torch.sum(~log_prob.isfinite()).item()
    if bad:
        raise ValueError(
            "{} quantized value(s) have a non-finite log-probability "
            "({} with zero mass)".format(bad, torch.sum(prob == 0).item()))

    return (-(log_prob - torch.log2(corr1.float()))).mean() * original_num


def knn_soft_quant(matrix, centric, temprature=1, K=5):
    """Softly assign every element of ``matrix`` to its ``K`` nearest centres.

    Args:
        matrix: tensor of any shape.
        centric: 1-D tensor of quantization centres.
        temprature: accepted for interface compatibility; currently unused.
        K: number of nearest centres per element (capped at ``len(centric)``).

    Returns:
        ``(quant_res, bitrate)``, both shaped like ``matrix``: the
        softmax-weighted centre value and the softmax-weighted code length
        (``-log2`` of the soft usage frequency of each centre).
    """
    K = min(K, centric.shape[0])
    # The original nearest-neighbour lookup was commented out, leaving ``idx``
    # undefined; select the K closest centres directly. This materialises an
    # (N, C) distance matrix.
    all_dists = torch.abs(matrix.flatten()[:, None] - centric[None, :])
    dists, idx = torch.topk(all_dists, K, dim=-1, largest=False)

    soft_one_hot = F.softmax(-dists, dim=-1)

    # Accumulate the soft assignment mass received by every centre.
    prob_per_centric = torch.scatter_add(
        torch.zeros_like(centric), 0, idx.flatten(), soft_one_hot.flatten())
    prob_per_centric = prob_per_centric / prob_per_centric.sum()
    bitrate_per_centric = -(torch.log2(prob_per_centric))

    quant_res = torch.sum(soft_one_hot * centric[idx], dim=1, keepdim=False).view(matrix.shape)
    bitrate = torch.sum(soft_one_hot * bitrate_per_centric[idx], dim=1, keepdim=False).view(matrix.shape)
    return quant_res, bitrate


def gumbel_soft_quant(matrix, centric, temprature=1):
    """Soft-quantize ``matrix`` onto ``centric`` with Gumbel-softmax weights.

    :param matrix: any shape
    :param centric: 1D tensor
    :param temprature: Gumbel-softmax temperature ``tau``
    :return: quantization (shaped like ``matrix``) and the per-element,
        per-centre weighted code length of shape ``(matrix.numel(), len(centric))``
    """
    logits = -distance_metric(matrix.flatten(), centric)
    soft_one_hot = F.gumbel_softmax(logits, tau=temprature, hard=False, eps=1e-10, dim=-1)
    prob_per_centric = torch.sum(soft_one_hot, dim=0, keepdim=True)
    prob_per_centric = prob_per_centric / prob_per_centric.sum()
    bitrate_per_centric = -(torch.log2(prob_per_centric))
    quant_res = torch.sum(soft_one_hot * centric[None, :], dim=1, keepdim=False).view(matrix.shape)
    return quant_res, soft_one_hot * bitrate_per_centric
def get_sigma(iter, std):
    """Return the soft-mask width ``sigma * std`` for iteration ``iter``.

    ``sigma`` is ``1e-1`` for the first 100 iterations, then restarts at
    ``1e-3`` and grows linearly back to ``1e-1`` over the next 100
    iterations, after which it stays at ``1e-1``.
    """
    # NOTE(review): the schedule drops to 1e-3 at iteration 100 and ramps *up*
    # again; an annealing schedule would normally decay -- confirm intent.
    warmup_iters = 100
    ramp_iters = 100
    sigma_hi, sigma_lo = 1e-1, 1e-3

    step = iter - warmup_iters
    if 0 <= step < ramp_iters:
        sigma = (1 - step / ramp_iters) * sigma_lo + step / ramp_iters * sigma_hi
    else:
        sigma = sigma_hi

    return sigma * std


def annealed_temperature(t, r, ub, lb=1e-8, backend=np, scheme='exp', **kwargs):
    """
    Return the temperature at time step t, based on a chosen annealing schedule.
    :param t: step/iteration number
    :param r: decay strength
    :param ub: maximum/init temperature
    :param lb: small const like 1e-8 to prevent numerical issue when temperature gets too close to 0
    :param backend: array module providing ``exp``/``minimum``/``maximum`` (e.g. np),
        or None to compute with plain Python scalars
    :param scheme: 'exp', 'exp0' or 'linear'
    :param kwargs: optional ``t0``, the number of initial iterations (default 700)
        used by the 'exp0' and 'linear' schemes
    :return: the temperature clipped to ``[lb, ub]``
    :raises NotImplementedError: for an unknown ``scheme``
    """
    import math  # local import: only needed for the backend=None path

    default_t0 = 700
    # The clipping code below supports backend=None, so the exponential
    # schemes must not require ``backend.exp`` in that case.
    exp = math.exp if backend is None else backend.exp

    if scheme == 'exp':
        tau = exp(-r * t)
    elif scheme == 'exp0':
        # Modified version of above that fixes temperature at ub for initial t0 iterations
        t0 = kwargs.get('t0', default_t0)
        tau = ub * exp(-r * (t - t0))
    elif scheme == 'linear':
        # Cool temperature linearly from ub after the initial t0 iterations
        t0 = kwargs.get('t0', default_t0)
        tau = -r * (t - t0) + ub
    else:
        raise NotImplementedError

    if backend is None:
        return min(max(tau, lb), ub)
    return backend.minimum(backend.maximum(tau, lb), ub)


def soft_sparse1(matrix, edge, iter, std):
    """Apply a sigmoid soft mask that suppresses entries with ``|matrix| < |edge|``.

    The mask is ``sigmoid((|matrix| - |edge|) / sigma)`` with
    ``sigma = get_sigma(iter, std)``.

    Returns:
        ``(mask * matrix, mask)``.
    """
    sigma = get_sigma(iter, std)
    mask = torch.sigmoid((torch.abs(matrix) - torch.abs(edge)) / sigma)
    return mask * matrix, mask
def ste(matrix):
    """Round ``matrix`` to integers while passing the gradient straight through."""
    rounded = torch.round(matrix)
    return rounded - matrix.detach() + matrix


class BinaryQuantize_m(Function):
    """Binarize to ``(sign(x) + 1) / 2`` with a custom, one-sided gradient."""

    @staticmethod
    def forward(ctx, input):
        """Return ``(sign(input) + 1) / 2`` and remember it for ``backward``."""
        binary = torch.sign(input).add(1).div(2)
        ctx.save_for_backward(binary)
        return binary

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-|grad|`` where the saved output truncates to 0 and 0 where it is 1."""
        state = ctx.saved_tensors[0].long()
        flipped = torch.where(state == 0, -torch.abs(grad_output), grad_output)
        return flipped.masked_fill(state == 1, 0)  # TODO


class BinaryQuantize(Function):
    """Binarize to ``(sign(x) + 1) / 2`` with an identity (straight-through) gradient."""

    @staticmethod
    def forward(ctx, input):
        """Return ``(sign(input) + 1) / 2``."""
        binary = torch.sign(input).add(1).div(2)
        ctx.save_for_backward(binary)
        return binary

    @staticmethod
    def backward(ctx, grad_output):
        """Pass a copy of the incoming gradient through unchanged."""
        return grad_output.clone()


class FLIPGrad(Function):
    """Identity in the forward pass; gradient becomes ``-sign(original) * |grad|``."""

    @staticmethod
    def forward(ctx, input, original):
        """Return ``input`` untouched and keep ``original`` for ``backward``."""
        ctx.save_for_backward(original)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Return the sign-flipped gradient for ``input`` and ``None`` for ``original``."""
        (original,) = ctx.saved_tensors
        return -torch.sign(original) * torch.abs(grad_output), None


def soft_sparse(matrix, edge):
    """Zero the entries of ``matrix`` with ``|matrix| <= |edge|`` but keep their gradient path.

    Returns:
        ``(sparse_matrix, mask)`` where ``mask`` is the float keep-mask.
        Pruned entries have value zero yet stay connected to ``matrix``
        and ``edge`` in the autograd graph.
    """
    width = 2 * edge
    scaled = matrix / width
    keep = torch.abs(scaled) > 0.5
    # Value 0, gradient of ``scaled``.
    zero_with_grad = (-scaled.detach() + scaled) * width
    return torch.where(keep, matrix, zero_with_grad), keep.float()
def sparse(matrix, edge):
    """Hard-threshold ``matrix``: keep entries whose magnitude exceeds ``|edge|``.

    Returns:
        ``(matrix * keep, keep)`` where ``keep`` is the boolean mask
        ``|matrix| > |edge|``.
    """
    keep = matrix.abs() > edge.abs()
    return matrix * keep, keep
def _interval_mass(cdf_fn, matrix, upper, lower, resolution):
    """Return ``(CDF(upper) - CDF(lower), total)`` from a split integer/fractional CDF."""
    int_hi, frac_hi, total = cdf_fn(matrix, upper, resolution=resolution)
    int_lo, frac_lo, _ = cdf_fn(matrix, lower, resolution=resolution)
    return (int_hi - int_lo).float() + (frac_hi - frac_lo), total


def _code_length(mass, total):
    """Return ``-log2(mass / total)`` for the strictly positive entries of ``mass``."""
    return -(torch.log2(mass[mass > 0]) - torch.log2(total.float()))


def get_bitrate(matrix, quant, scale, edge, resolution):
    """Estimate the code length (bits) of a sparsified, quantized tensor via ``get_cdf``.

    Symbols are split into four groups, each costed with the probability
    mass of its interval under the empirical CDF of ``matrix``:
    zeros (``[-|edge|, |edge|]``), the smallest positive and negative
    quantized values (from ``|edge|`` to ``x_min + scale / 2``), and all
    remaining values (``x +- scale / 2``).

    Args:
        matrix: reference tensor whose distribution is estimated.
        quant: quantized tensor (detached before use).
        scale: quantization step.
        edge: sparsification threshold (tensor).
        resolution: histogram resolution passed to ``get_cdf``.

    Returns:
        A one-element tensor with the estimated number of bits.
    """
    flat = quant.detach().flatten()
    zero_cnt = flat.numel() - flat.count_nonzero()
    nonzero = flat[flat != 0]

    if nonzero.numel() > 0:
        # Find the quantized magnitude closest to edge and set those symbols aside.
        x_min = nonzero.abs().min()
        is_plus = nonzero == x_min
        is_minus = nonzero == -x_min
        plus_cnt = is_plus.sum()
        minus_cnt = is_minus.sum()
        x = nonzero[~(is_plus | is_minus)]
    else:
        # All-zero input: previously ``min()`` of an empty tensor raised here.
        x = nonzero

    edge_abs = edge.abs()
    # Mass of the zero symbol.
    zero_mass, total = _interval_mass(get_cdf, matrix, edge_abs, -edge_abs, resolution)

    if x.shape[0] == 0:
        return _code_length(zero_mass, total) * (zero_mass / total.float()) * matrix.numel()

    half_step = 0.5 * scale
    edge_fixed = edge_abs.detach()
    # Mass of values that are neither zero nor the smallest quantized value.
    rest_mass, _ = _interval_mass(get_cdf, matrix, x + half_step, x - half_step, resolution)
    # Mass between edge and the smallest quantized value, on each side.
    plus_mass, _ = _interval_mass(get_cdf, matrix, x_min + half_step, edge_fixed, resolution)
    minus_mass, _ = _interval_mass(get_cdf, matrix, -edge_fixed, -x_min - half_step, resolution)

    prob = _code_length(zero_mass, total) * zero_cnt
    prob += _code_length(rest_mass, total).sum()[None]
    if plus_mass.item() > 0:
        prob += _code_length(plus_mass, total) * plus_cnt
    if minus_mass.item() > 0:
        prob += _code_length(minus_mass, total) * minus_cnt
    return prob


def get_bitrate_new(matrix, quant, scale, edge, resolution):
    """Estimate the code length (bits) of ``quant`` via ``get_cdf_new``, with subsampling.

    Same symbol grouping as :func:`get_bitrate`, except that the cost of the
    "remaining" symbols is the mean code length of a (possibly subsampled)
    set multiplied by their exact count. Tensors above ``128 * 128``
    elements are subsampled at 1/256.

    Unlike the previous implementation this never writes into ``quant`` and
    always treats ``quant`` element-wise, also when it is multi-dimensional
    and not subsampled.

    Returns:
        A one-element tensor, or ``zeros_like(quant)`` when ``quant`` is all zero.
    """
    if torch.all(quant == 0):
        return torch.zeros_like(quant)

    total_cnt = quant.numel()
    flat_quant = quant.flatten()
    if matrix.numel() > 128 * 128:
        sample_num = matrix.numel() // 256
        sample_idx = torch.randperm(matrix.numel())[:sample_num]
        matrix = matrix.flatten()[sample_idx]
        x_tilde = flat_quant[sample_idx]
    else:
        x_tilde = flat_quant

    # Exact symbol counts come from the full tensor, not the sample.
    zero_cnt = total_cnt - flat_quant.count_nonzero()
    nonzero = flat_quant[flat_quant != 0]
    # Find the quantized magnitude closest to edge and set those symbols aside.
    x_min = nonzero.abs().min()
    plus_cnt = (nonzero == x_min).sum()
    minus_cnt = (nonzero == -x_min).sum()

    remaining = (x_tilde != 0) & (x_tilde != x_min) & (x_tilde != -x_min)
    x = x_tilde[remaining]

    edge_abs = edge.abs()
    # Mass of the zero symbol.
    zero_mass, total = _interval_mass(get_cdf_new, matrix, edge_abs, -edge_abs, resolution)

    if x.shape[0] == 0:
        return _code_length(zero_mass, total) * (zero_mass / total.float()) * matrix.numel()

    half_step = 0.5 * scale
    edge_fixed = edge_abs.detach()
    # Mass of values that are neither zero nor the smallest quantized value.
    rest_mass, _ = _interval_mass(get_cdf_new, matrix, x + half_step, x - half_step, resolution)
    # Mass between edge and the smallest quantized value, on each side.
    plus_mass, _ = _interval_mass(get_cdf_new, matrix, x_min + half_step, edge_fixed, resolution)
    minus_mass, _ = _interval_mass(get_cdf_new, matrix, -edge_fixed, -x_min - half_step, resolution)

    prob = _code_length(zero_mass, total) * zero_cnt
    prob += _code_length(rest_mass, total).mean() * (total_cnt - zero_cnt - plus_cnt - minus_cnt)
    if plus_mass.item() > 0:
        prob += _code_length(plus_mass, total) * plus_cnt
    if minus_mass.item() > 0:
        prob += _code_length(minus_mass, total) * minus_cnt
    return prob


def get_bitrate_sparse(matrix, edge, resolution):
    """Return the code length (bits) of the zero symbol, i.e. of ``[-|edge|, |edge|]``."""
    zero_mass, total = _interval_mass(get_cdf, matrix, edge.abs(), -edge.abs(), resolution)
    return _code_length(zero_mass, total)
class Logger():
    """Collect training/test statistics for one run and write them below ``root/run_name``."""

    def __init__(self, code_path, root, run_name=None):
        """Create the run directory and copy ``code_path`` into it.

        Args:
            code_path: file copied into the run directory for reference.
            root: parent directory of all runs.
            run_name: sub-directory name; defaults to the current local time.
        """
        if run_name is None:
            run_name = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())
        self.root = os.path.join(root, run_name)
        os.makedirs(self.root, exist_ok=True)
        self.edge_dict = {}
        self.scale_dict = {}
        self.loss_dict = {}
        self.value_dict = {}
        self.history_state_train = {}
        self.history_state_test = {"special": {}}
        shutil.copy(code_path, self.root)

    def vis_hist(self, tensor_dict, title):
        '''
        Save overlaid histograms of several tensors to ``hist/<title>.png``.

        Args:
            tensor_dict: dict of (name: Tensors)
            title: plot title, also used as the file name
        '''
        np_list = [(name, tensor.detach().cpu().numpy().flatten())
                   for name, tensor in tensor_dict.items()]

        r_max = -99999999.0
        r_min = 99999999.0
        for _, ndarray in np_list:
            r_max = max(r_max, ndarray.max())
            r_min = min(r_min, ndarray.min())
        bins = np.linspace(r_min, r_max, 100)
        pyplot.cla()
        for name, ndarray in np_list:
            pyplot.hist(ndarray, bins, alpha=0.5, label=name)
        pyplot.legend(loc='upper right')
        pyplot.title(title)
        save_dir = os.path.join(self.root, "hist")
        os.makedirs(save_dir, exist_ok=True)
        pyplot.savefig(os.path.join(save_dir, "{}.png".format(title)))

    @staticmethod
    def _append(store, value, name):
        """Append ``value.item()`` to the list stored under ``name`` in ``store``."""
        store.setdefault(name, []).append(value.item())

    def record_edge(self, value, name):
        """Record one scalar of the ``edge`` series ``name``."""
        self._append(self.edge_dict, value, name)

    def record_scale(self, value, name):
        """Record one scalar of the ``scale`` series ``name``."""
        self._append(self.scale_dict, value, name)

    def record_value(self, value, name):
        """Record one scalar of the value series ``name``; ``None`` is ignored."""
        if value is None:
            return
        self._append(self.value_dict, value, name)

    def record_loss(self, value, name):
        """Record one scalar of the loss series ``name``."""
        self._append(self.loss_dict, value, name)

    def _plot_curves(self, series, suffix):
        """Plot every series in ``series`` to ``curve/<name>_<suffix>.png``."""
        save_dir = os.path.join(self.root, "curve")
        os.makedirs(save_dir, exist_ok=True)
        for name, data in series.items():
            pyplot.cla()
            x = np.arange(1, len(data) + 1)
            pyplot.plot(x, np.array(data), 'o-')
            pyplot.title(name)
            pyplot.savefig(os.path.join(save_dir, "{}_{}.png".format(name, suffix)))

    def log_curve_edge(self):
        """Plot all recorded edge series."""
        self._plot_curves(self.edge_dict, "edge")

    def log_curve_scale(self):
        """Plot all recorded scale series."""
        self._plot_curves(self.scale_dict, "scale")

    def log_curve_loss(self):
        """Plot all recorded loss series."""
        self._plot_curves(self.loss_dict, "loss")

    def log_curve_value(self):
        """Plot all recorded value series."""
        self._plot_curves(self.value_dict, "value")

    def log(self, string):
        """Print ``string`` (file logging is currently disabled)."""
        print(string)

    def _dump(self, obj, filename):
        """Pickle ``obj`` into the run directory, closing the file afterwards."""
        with open(os.path.join(self.root, filename), "wb") as handle:
            pickle.dump(obj, handle)

    def record_train_state(self, state, name):
        """Append the ``trans_param`` values of every sub-state to the training history.

        ``state`` maps a key to a dict containing ``"trans_param"``; history
        entries are stored under ``"<name>_<key>"``.
        """
        for key, substate in state.items():
            outkey = "{}_{}".format(name, key)
            params = self.history_state_train.setdefault(outkey, {}).setdefault("trans_param", {})
            for tk, tv in substate["trans_param"].items():
                params.setdefault(tk, []).append(to_np(tv))

    def save_train_state(self):
        """Write the training history to ``train_state.pkl``."""
        self._dump(self.history_state_train, "train_state.pkl")

    def save_test_state(self, ):
        """Write the test history to ``test_state.pkl``."""
        self._dump(self.history_state_test, "test_state.pkl")

    def record_test_value(self, name, value):
        """Store a single named test result under the ``"special"`` entry."""
        self.history_state_test["special"][name] = value

    def record_test_state(self, state, name):
        """Append and log ``trans_param``, ``np_quant`` and symbol count of every sub-state.

        Each sub-state must provide ``"trans_param"``, ``"np_quant"`` and
        ``"test_cnt_symbol"``.
        """
        for key, substate in state.items():
            outkey = "{}_{}".format(name, key)
            entry = self.history_state_test.setdefault(outkey, {})
            params = entry.setdefault("trans_param", {})
            for tk, tv in substate["trans_param"].items():
                np_tv = to_np(tv)
                params.setdefault(tk, []).append(np_tv)
                self.log("[{}][{}]: {}".format(outkey, tk, np_tv))
            entry.setdefault("np_quant", []).append(substate["np_quant"])
            self.log("[{}] symbols: {}, equal_bit: {}".format(
                outkey, substate["test_cnt_symbol"], math.log2(substate["test_cnt_symbol"])))


class QConv2d(nn.Module):
    # NOTE(review): this class looks unfinished and still cannot be
    # instantiated: ``fill_()`` is called without a value, ``max`` is bound to
    # the method object (never called, never used), and ``forward`` relies on
    # ``weight_fake_sparse``, ``weight``, ``bias`` and ``_conv_forward``,
    # none of which are defined here. The intended initial scale is not
    # recoverable from this file -- confirm with the author before use.
    def __init__(
            self,
            conv: nn.Module
    ):
        # Required before any nn.Parameter can be assigned to the module.
        super().__init__()
        self.scale = nn.Parameter(FloatTensor(1), requires_grad=True)

        max = conv.weight.max
        self.scale.data.fill_()

    def forward(self, input: Tensor) -> Tensor:
        sparse_weight = self.weight_fake_sparse(self.weight)
        return self._conv_forward(input, sparse_weight, self.bias)
@staticmethod 18 | def decode_param(code, trans_param): 19 | edge = trans_param["edge"] 20 | scale = trans_param["scale"] 21 | code_sign = torch.sign(code) 22 | # res = torch.zeros_like(code) 23 | reserve_mask = torch.abs(code) > 0.5 24 | sparse = (code * (2 * torch.abs(edge))) 25 | reserve = (code_sign * (torch.abs(edge) + (torch.abs(code) - 0.5) * torch.abs(scale))) 26 | return torch.where(reserve_mask, reserve, sparse) 27 | 28 | @staticmethod 29 | def get_init_trans_param(param): 30 | trans_param = {} 31 | device = param.device 32 | trans_param["edge"] = torch.tensor((param.max() - param.min()) / 256, device=device).requires_grad_(True) 33 | trans_param["scale"] = torch.tensor((param.max() - param.min()) / 256, device=device).requires_grad_(True) 34 | return trans_param 35 | 36 | @staticmethod 37 | def get_trainable_list(trans_param): 38 | return [trans_param["edge"], trans_param["scale"]] -------------------------------------------------------------------------------- /transform/exp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from common.tools import mysign, myabs 3 | 4 | class Exp_T(): 5 | @staticmethod 6 | def encode_param(param, trans_param): 7 | shift = trans_param["shift"] 8 | scale = trans_param["scale"] 9 | inner_scale = trans_param["inner_scale"] 10 | return mysign(param) * (torch.exp(myabs(param) / inner_scale) + shift) / scale 11 | @staticmethod 12 | def decode_param(code, trans_param): 13 | shift = trans_param["shift"] 14 | scale = trans_param["scale"] 15 | inner_scale = trans_param["inner_scale"] 16 | return mysign(code) * torch.log(myabs(code) * scale - shift) * inner_scale 17 | 18 | @staticmethod 19 | def get_init_trans_param(param): 20 | trans_param = {} 21 | device = param.device 22 | trans_param["shift"] = torch.tensor(-1.0, device=device).requires_grad_(True) 23 | trans_param["inner_scale"] = torch.tensor((param.abs().max() / 0.69314718056), device=device).requires_grad_(True) 24 | 
trans_param["scale"] = torch.tensor(1.0 / 64, device=device).requires_grad_(True) 25 | return trans_param 26 | 27 | @staticmethod 28 | def get_trainable_list(trans_param): 29 | return [trans_param["inner_scale"], trans_param["scale"]] -------------------------------------------------------------------------------- /transform/log.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from common.tools import mysign, myabs 3 | 4 | class Log_T(): 5 | @staticmethod 6 | def encode_param(param, trans_param): 7 | shift = trans_param["shift"] 8 | scale = trans_param["scale"] 9 | inner_scale = trans_param["inner_scale"] 10 | return mysign(param) * torch.log(shift + myabs(param) / inner_scale) / scale 11 | @staticmethod 12 | def decode_param(code, trans_param): 13 | shift = trans_param["shift"] 14 | scale = trans_param["scale"] 15 | inner_scale = trans_param["inner_scale"] 16 | return mysign(code) * (torch.exp(myabs(code) * scale) - shift) * inner_scale 17 | 18 | @staticmethod 19 | def get_init_trans_param(param): 20 | trans_param = {} 21 | device = param.device 22 | trans_param["shift"] = torch.tensor(1.0, device=device).requires_grad_(True) 23 | trans_param["inner_scale"] = torch.tensor((param.abs().max() / 1.718281828459045), device=device).requires_grad_(True) 24 | trans_param["scale"] = torch.tensor(1.0 / 64, device=device).requires_grad_(True) 25 | return trans_param 26 | 27 | @staticmethod 28 | def get_trainable_list(trans_param): 29 | return [trans_param["inner_scale"], trans_param["scale"]] -------------------------------------------------------------------------------- /transform/ms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from common.tools import mysign, myabs 3 | 4 | class MS_T(): 5 | @staticmethod 6 | def encode_param(param, trans_param): 7 | param_range = trans_param["param_range"] 8 | scale = trans_param["scale"] 9 | 10 | assert 
param_range.shape[0] + 1 == scale.shape[0] 11 | param_sign = mysign(param) 12 | res = torch.zeros_like(param) 13 | filled = torch.zeros_like(param).bool() 14 | base_last = 0 15 | range_last = 0 16 | for i in range(len(param_range)): 17 | mask = (myabs(param) < param_range[i]) & (~filled) 18 | res[mask] = (base_last + (myabs(param) - range_last) / myabs(scale[i]))[mask] 19 | filled = filled | mask 20 | base_last += ((param_range[i] - range_last) / myabs(scale[i])) 21 | range_last = param_range[i] 22 | res[~filled] = (base_last + (myabs(param) - range_last) / myabs(scale[-1]))[~filled] 23 | return res * param_sign 24 | 25 | @staticmethod 26 | def decode_param(code, trans_param): 27 | param_range = trans_param["param_range"] 28 | scale = trans_param["scale"] 29 | 30 | assert param_range.shape[0] + 1 == scale.shape[0] 31 | code_sign = mysign(code) 32 | res = torch.zeros_like(code) 33 | filled = torch.zeros_like(code).bool() 34 | base_last = 0 35 | range_last = 0 36 | for i in range(len(param_range)): 37 | base_now = (base_last + (param_range[i] - range_last) / scale[i]) 38 | mask = (myabs(code) < base_now) & (~filled) 39 | res[mask] = (range_last + (myabs(code) - base_last) * scale[i])[mask] 40 | filled = filled | mask 41 | base_last = base_now 42 | range_last = param_range[i] 43 | 44 | res[~filled] = (range_last + (myabs(code) - base_last) * scale[-1])[~filled] 45 | return res * code_sign 46 | 47 | @staticmethod 48 | def get_init_trans_param(param): 49 | trans_param = {} 50 | device = param.device 51 | NUM_LIN = 5 52 | trans_param["scale"] = torch.full((NUM_LIN,), (param.max() - param.min()) / 256, device=device).requires_grad_(True) 53 | trans_param["param_range"] = (torch.arange(1, NUM_LIN, device=device, dtype=torch.float32) * (param.abs().max() / NUM_LIN)).detach().requires_grad_(False) 54 | return trans_param 55 | 56 | @staticmethod 57 | def get_trainable_list(trans_param): 58 | return [trans_param["scale"],] 
# --------------------------------------------------------------------------------
# /transform/scale.py:
# --------------------------------------------------------------------------------
import torch


class Scale_T:
    """Uniform linear transform: a code is the parameter divided by one scale.

    Every method is static; the transform state lives in the ``trans_param``
    dict, whose only entry is ``"scale"``.
    """

    @staticmethod
    def encode_param(param, trans_param):
        """Return ``param / trans_param["scale"]``."""
        return param / trans_param["scale"]

    @staticmethod
    def decode_param(code, trans_param):
        """Return ``code * trans_param["scale"]`` (inverse of ``encode_param``)."""
        return code * trans_param["scale"]

    @staticmethod
    def get_init_trans_param(param):
        """Build the initial transform parameters for ``param``.

        The scale starts at ``(param.max() - param.min()) / 256``.

        Note: when every element of ``param`` is equal the initial scale is 0,
        and ``encode_param`` then divides by zero.

        Returns:
            dict with key ``"scale"``: a 0-dim leaf tensor on ``param``'s
            device with ``requires_grad=True``.
        """
        span = (param.max() - param.min()) / 256
        # detach().clone() is the documented equivalent of torch.tensor(span)
        # for a tensor input, without the copy-construct UserWarning. The
        # result is a fresh leaf, cut off from param's autograd graph.
        scale = span.detach().clone().requires_grad_(True)
        return {"scale": scale}

    @staticmethod
    def get_trainable_list(trans_param):
        """Return the list of tensors to optimise (only the scale)."""
        return [trans_param["scale"]]
# --------------------------------------------------------------------------------