├── tests
├── __init__.py
├── unit
│   ├── __init__.py
│   └── test_matrix_block_slices.py
├── integration
│   ├── __init__.py
│   ├── test_LinearRegression.py
│   ├── test_FFNN.py
│   ├── test_resnet.py
│   ├── test_FeatherNet.py
│   └── test_resnet_main.py
├── performance
│   ├── __init__.py
│   ├── test_mm_lazy_eval.py
│   └── test_latency.py
└── exploratory
│   ├── linear_hook.py
│   ├── bin_packing.py
│   ├── grad_fns.py
│   ├── logger_test.py
│   ├── LinearReg.py
│   └── hashable_params.py
├── feathermap
├── __init__.py
├── models
│   ├── __init__.py
│   ├── credit.txt
│   ├── lenet.py
│   ├── vgg.py
│   ├── mobilenet.py
│   ├── mobilenetv2.py
│   ├── googlenet.py
│   ├── resnext.py
│   ├── dpn.py
│   ├── densenet.py
│   ├── shufflenet.py
│   ├── senet.py
│   ├── preact_resnet.py
│   ├── pnasnet.py
│   ├── resnet.py
│   ├── regnet.py
│   ├── shufflenetv2.py
│   └── efficientnet.py
├── dataloader.py
├── utils.py
├── train.py
└── feathernet.py
├── references
├── smh_1.png
├── smh_2.png
├── smh_3.png
├── resnet34_acc_latency.png
└── resnet34_acc_latency_og.png
├── environment.yaml
├── setup.py
├── Dockerfile
├── LICENSE
├── TODO.md
├── .gitignore
└── README.md
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/feathermap/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/feathermap/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/tests/performance/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/feathermap/models/credit.txt:
--------------------------------------------------------------------------------
Models attributed to https://github.com/kuangliu/pytorch-cifar !
--------------------------------------------------------------------------------
/references/smh_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phelps-matthew/FeatherMap/HEAD/references/smh_1.png
--------------------------------------------------------------------------------
/references/smh_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phelps-matthew/FeatherMap/HEAD/references/smh_2.png
--------------------------------------------------------------------------------
/references/smh_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phelps-matthew/FeatherMap/HEAD/references/smh_3.png
--------------------------------------------------------------------------------
/references/resnet34_acc_latency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phelps-matthew/FeatherMap/HEAD/references/resnet34_acc_latency.png
--------------------------------------------------------------------------------
/references/resnet34_acc_latency_og.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phelps-matthew/FeatherMap/HEAD/references/resnet34_acc_latency_og.png
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
name: /home/mgp/Insight/project/FeatherMap/env
channels:
  - defaults
dependencies:
  - numpy
  - pandas
  - packaging
  - pytorch
  - torchvision
  - matplotlib
prefix: /home/mgp/Insight/project/FeatherMap/env
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(
    name="feathermap",
    version="1.0",
    install_requires=[
        "torch",
        "torchvision",
        "packaging",
        "numpy",
        "pandas",
        "matplotlib",
    ],
)
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# set base image
FROM python:3.8

# copy FeatherMap package into container
COPY setup.py .
COPY feathermap ./feathermap

# install feathermap package
RUN pip3 install -e .
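
# Example usage (a sketch; the image tag and the train.py flags shown here are
# illustrative assumptions -- see feathermap/train.py for the real arguments):
#   docker build -t feathermap .
#   docker run --rm feathermap --epochs 30 --compress 0.5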

# command to run on container start
ENTRYPOINT ["python", "feathermap/train.py"]
--------------------------------------------------------------------------------
/tests/exploratory/linear_hook.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


model = nn.Sequential(nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU())
x = torch.randn(1, 1, 32, 32)


def prehook(module, inputs):
    print("Prehook: ", module)


def posthook(module, inputs, outputs):
    print("Posthook: ", module)


for module in model.modules():
    module.register_forward_pre_hook(prehook)
    module.register_forward_hook(posthook)

print(x, model(x))
--------------------------------------------------------------------------------
/feathermap/models/lenet.py:
--------------------------------------------------------------------------------
'''LeNet in PyTorch.'''
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out
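

# Unlike the other model files in this package, lenet.py ships without a
# test() helper. A minimal smoke test in the same style as the others
# (a sketch added here, not part of the upstream pytorch-cifar code):
def test():
    import torch  # not imported at module level in this file
    net = LeNet()
    x = torch.randn(1, 3, 32, 32)
    y = net(x)
    print(y.size())

# test()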
--------------------------------------------------------------------------------
/tests/exploratory/bin_packing.py:
--------------------------------------------------------------------------------
from rectpack import newPacker
from feathermap.models.resnet import ResNet, ResidualBlock
from feathermap.feathernet import FeatherNet
import torch.nn as nn


base_model = ResNet(ResidualBlock, [2, 2, 2])
model = FeatherNet(base_model, exclude=(nn.BatchNorm2d,))


packer = newPacker()

for name, tensor in model.get_WandB():
    if tensor.dim() > 2:
        tensor = tensor.flatten(end_dim=-2)
    if tensor.dim() == 1:
        tensor = tensor.view(-1, 1)
    rect_obj = (tensor.size(1), tensor.size(0), name)
    packer.add_rect(*rect_obj)
    print(rect_obj)

bin_factor = 10
packer.add_bin(bin_factor * model.size_n, bin_factor * model.size_n)
packer.pack()

for r in packer.rect_list():
    print(r)
--------------------------------------------------------------------------------
/tests/exploratory/grad_fns.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

V1 = Parameter(torch.randn(3, 3, requires_grad=True))
V2 = Parameter(torch.randn(3, 3, requires_grad=True))
W = torch.randn(2, 2)
bias = torch.zeros(2)


def update(V, W):
    # Fill W element by element from the product V1 @ V2.T; the views share
    # storage with the originals, so W is mutated in place.
    V = torch.matmul(V1, V2.transpose(0, 1))
    i = 0
    V = V.view(-1, 1)
    W = W.view(-1, 1)
    for j in range(len(W)):
        W[j] = V[i]
        i += 1


def forward(x, W, bias):
    return F.linear(x, W, bias)


print("V1 {}".format(V1))
print("W {}".format(W))
update(V1, W)
print("V1 {}".format(V1))
print("W {}".format(W))


x = torch.randn(2)
g = torch.ones(2)
print(x)
print(forward(x, W, bias).norm())
y = forward(x, W, bias)
print(y)
print(y.reshape(-1, 1))
# cross_entropy expects class indices as the target, not a one-hot float tensor
loss = F.cross_entropy(y.reshape(1, -1), torch.tensor([1]))
print(loss)

forward(x, W, bias).backward(g)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Matthew Phelps

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/tests/exploratory/logger_test.py:
--------------------------------------------------------------------------------
import logging
import sys
from mod1 import mod1fn


def set_logger(filepath):
    logging.basicConfig(
        filename=str(filepath),
        filemode="w",  # will rewrite on each run
        level=logging.DEBUG,
        format="[%(asctime)s] %(levelname)s - %(message)s",
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
    handler.setFormatter(formatter)
    root = logging.getLogger()  # basicConfig() returns None; fetch the root logger
    root.addHandler(handler)
    return root


logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(
    filename=str("testlog2.log"),
    filemode="w",  # will rewrite on each run
    format="[%(asctime)s] %(levelname)s - %(message)s",
)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
# logging.setFormatter(formatter)
logger.info("adsf")
print("from main..")

mod1fn()

# root = logging.getLogger()
# root.setLevel(logging.DEBUG)


# root.info("asdf")
--------------------------------------------------------------------------------
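The module-level block in logger_test.py above wires a stream handler up by hand and leaves set_logger unused. For comparison, a compact standard-library pattern that sends the same format to both a file and stdout in one call (a sketch, independent of this repo):

import logging
import sys


def get_logger(filepath):
    fmt = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
    logger = logging.getLogger("feathermap")
    logger.setLevel(logging.DEBUG)
    for handler in (logging.FileHandler(filepath, mode="w"),
                    logging.StreamHandler(sys.stdout)):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger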
/tests/performance/test_mm_lazy_eval.py:
--------------------------------------------------------------------------------
import torch
from feathermap.utils import timed
from math import sqrt


dim_in = 2 ** 14
dim_out = 2 ** 4
A = torch.randn(dim_in, dim_out)
B = torch.randn(dim_out, dim_in)
C = torch.rand(dim_in, dim_in)
D = torch.rand(dim_in, dim_in)
E = torch.rand(1, dim_out)
F = torch.rand(dim_out, dim_in)
G = torch.rand(int(sqrt(dim_in)), int(sqrt(dim_in)))
H = torch.rand(int(sqrt(dim_in)), int(sqrt(dim_in)))


@timed
def mam(a, b):
    for _ in range(10000):
        out = torch.mm(a, b)
    return out


def loop(a, b):
    for i in range(a.size(0)):
        for j in range(b.size(1)):
            yield a[i, :] @ b[:, j]


def loop2(a, b):
    for i in range(a.size(0)):
        for j in range(b.size(1)):
            yield 1


def tmm(a, b):
    c = torch.mm(a, b).view(-1, 1)
    return iter(c)


@timed
def run(c, dim_in):
    d = torch.empty(dim_in ** 2)
    for i in range(d.numel()):
        d[i] = next(c)


mam(E, F)  # about 23% faster
mam(G, H)


# run(loop(A, B), dim_in)  # 739
# run(loop2(A, B), dim_in)  # 254
# run(tmm(A, B), dim_in)  # 289
--------------------------------------------------------------------------------
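feathermap/utils.py is not included in this dump, so the @timed decorator used above is only referenced. It presumably resembles the following (an assumption -- a plain wall-clock timing wrapper, not the repo's actual implementation):

import time
from functools import wraps


def timed(fn):
    """Print the wall-clock duration of each call to fn."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        print("{}: {:.4f} s".format(fn.__name__, time.perf_counter() - start))
        return result
    return wrapper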
/feathermap/models/vgg.py:
--------------------------------------------------------------------------------
'''VGG11/13/16/19 in PyTorch.'''
import torch
import torch.nn as nn


cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)


def test():
    net = VGG('VGG11')
    x = torch.randn(2, 3, 32, 32)
    y = net(x)
    print(y.size())

# test()
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
# To Do
- [X] Create iterator over params with nn.Module specific exclusion capability
- [X] In `unregister_params`, need to initialize to proper tensor size
- [X] Find global weight matrix dimensions
- [X] Establish initialization method
- [X] Test FFNN
- [X] Git branching
- [X] Recheck initialization method
- [X] Handle BatchNorm2d `fan_in`
- [X] Handle train() and eval() appropriately
- [X] argparse for jobs
- [X] Test ResNet
- [X] Docker containerization locally
- [X] Add logging for training times and accuracy
- [X] Add validation dataloader
- [X] Compute weights 'on the fly' from variable pool
- [X] forward() prehook and posthook
- [X] global method
- [X] create deploy() mode (method)
- [ ] test on-the-fly calcs, ensuring same accuracy results
- [X] obtain minimum allowable compression
- [ ] deploy() on GPU
- [ ] ~~REST or GraphQL API (or microservice)~~
- [X] GPU EC2 P2 check
- [X] Find best `num_workers`; `num_workers=1` suggested
- [X] GPU EC2 P3 check
- [X] Hyperparams optimization based on param size
- [X] Validation Set
- [X] Early stopping w/ checkpoints
- [X] Learning rate scheduler?
- [ ] ~~MLFlow~~
- [ ] Toggle logging and printing or allow both
- [X] Verbose mode
- [ ] Change verbose to debug mode
- [X] Docker on EC2 (no GPU)
- [ ] Docker on EC2 (GPU)
- [X] Compare timings of training and evaluation vs compression
- [ ] ~~S3 ?~~
- [ ] Setup Travis CI
- [X] Rename papers in ./references
- [ ] Analyze residual blocks
- [ ] Reorganize structure
- [ ] Add docstrings
- [ ] Public/private methods and attributes
- [X] Create Readme
- [ ] API usage
- [ ] Sphinx documentation
- [ ] PyPI publish
- [ ] Conda publish
--------------------------------------------------------------------------------
/tests/exploratory/LinearReg.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt


# Hyper-parameters
input_size = 1
output_size = 1
num_epochs = 60
learning_rate = 0.001

# Toy dataset
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
                    [9.779], [6.182], [7.59], [2.167], [7.042],
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)

y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
                    [3.366], [2.596], [2.53], [1.221], [2.827],
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

# Linear regression model
model = nn.Linear(input_size, output_size)

inputs = torch.from_numpy(x_train)
print(model.weight)
print(inputs)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    # Convert numpy arrays to torch tensors
    inputs = torch.from_numpy(x_train)
    targets = torch.from_numpy(y_train)

    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 5 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

# Plot the graph
predicted = model(torch.from_numpy(x_train)).detach().numpy()
plt.plot(x_train, y_train, 'ro', label='Original data')
plt.plot(x_train, predicted, label='Fitted line')
plt.legend()
plt.show()

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')
--------------------------------------------------------------------------------
/feathermap/models/mobilenet.py:
--------------------------------------------------------------------------------
'''MobileNet in PyTorch.

See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    '''Depthwise conv + Pointwise conv'''
    def __init__(self, in_planes, out_planes, stride=1):
        super(Block, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False)
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out


class MobileNet(nn.Module):
    # (128,2) means conv planes=128, conv stride=2; by default conv stride=1
    cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024]

    def __init__(self, num_classes=10):
        super(MobileNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        self.linear = nn.Linear(1024, num_classes)

    def _make_layers(self, in_planes):
        layers = []
        for x in self.cfg:
            out_planes = x if isinstance(x, int) else x[0]
            stride = 1 if isinstance(x, int) else x[1]
            layers.append(Block(in_planes, out_planes, stride))
            in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = F.avg_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def test():
    net = MobileNet()
    x = torch.randn(1, 3, 32, 32)
    y = net(x)
    print(y.size())

# test()
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# MP
/envs
envs/
data/
FeatherMap/feathermap/data/
feathermap/data/
feathermap/train/data/
feathermap/train/checkpoint/
FeatherMap/tests/analysis/
FeatherMap/tests/test_streamlit.py
FeatherMap/tests/generator_test.py
ec2setup*
*.ckpt

#*.log
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
#*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/tests/integration/test_LinearRegression.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
from math import ceil, sqrt
from torch.nn import Parameter
import torch.nn.functional as F


class LinReg(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.weight = Parameter(torch.Tensor(output_size, input_size))
        self.bias = Parameter(torch.Tensor(output_size))

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        return out


class LinRegHash(nn.Module):
    def __init__(self, input_size, output_size, compress=1.0):
        super().__init__()
        self.weight = torch.Tensor(output_size, input_size)
        self.bias = torch.Tensor(output_size)
        self.compress = compress
        self.size_n = ceil(sqrt(input_size * output_size + output_size))
        self.size_m = ceil((self.compress * self.size_n) / 2)
        self.V1 = Parameter(torch.Tensor(self.size_n, self.size_m))
        self.V2 = Parameter(torch.Tensor(self.size_m, self.size_n))

        self.norm_V()

    def norm_V(self):
        k = sqrt(12) / 2 * self.size_m ** (-1 / 4)
        torch.nn.init.uniform_(self.V1, -k, k)
        torch.nn.init.uniform_(self.V2, -k, k)

    def WtoV(self):
        self.V = torch.matmul(self.V1, self.V2)
        V = self.V.view(-1, 1)
        i = 0
        for kind in ("weight", "bias"):
            v = getattr(self, kind)
            j = v.numel()
            w = V[i : i + j].reshape(v.size())
            setattr(self, kind, w)
            i += j
        # i = 0
        # V = self.V.view(-1, 1)
        # v = self.weight.view(-1, 1)
        # for j in range(len(v)):
        #     v[j] = V[i]
        #     i += 1

    def forward(self, x):
        self.WtoV()
        out = F.linear(x, self.weight, self.bias)
        return out


def train(model, criterion, optimizer, x, y):
    # Forward
    loss = criterion(model(x), y)

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


def main():
    torch.manual_seed(42)
    X = torch.Tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
    Y = 3.0 * X + torch.randn(X.size()) * 0.33

    model = LinRegHash(3, 3, compress=0.1)
    # model = LinReg(3, 3)
    loss = torch.nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    batch_size = 1
    epochs = 20

    for i in range(epochs):
        cost = 0.0
        num_batches = len(X) // batch_size
        for k in range(num_batches):
            start, end = k * batch_size, (k + 1) * batch_size
            cost += train(model, loss, optimizer, X[start:end], Y[start:end])
        print("Epoch = %d, cost = %s" % (i + 1, cost / num_batches))

    print(list(model.named_parameters()))


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
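To make the sizing in LinRegHash concrete: with input_size=3, output_size=3, and compress=0.1 as in main(), size_n = ceil(sqrt(3*3 + 3)) = ceil(sqrt(12)) = 4 and size_m = ceil(0.1 * 4 / 2) = 1, so V1 is 4x1 and V2 is 1x4 -- eight trainable values standing in for the twelve weight-and-bias entries. A quick check of that arithmetic (a sketch):

from math import ceil, sqrt

input_size, output_size, compress = 3, 3, 0.1
size_n = ceil(sqrt(input_size * output_size + output_size))  # 4
size_m = ceil((compress * size_n) / 2)                       # 1
assert 2 * (size_n * size_m) == 8                            # entries in V1 and V2
assert input_size * output_size + output_size == 12          # entries they replace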
/feathermap/models/mobilenetv2.py:
--------------------------------------------------------------------------------
'''MobileNetV2 in PyTorch.

See the paper "Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    '''expand + depthwise + pointwise'''
    def __init__(self, in_planes, out_planes, expansion, stride):
        super(Block, self).__init__()
        self.stride = stride

        planes = expansion * in_planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.shortcut = nn.Sequential()
        if stride == 1 and in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = out + self.shortcut(x) if self.stride == 1 else out
        return out


class MobileNetV2(nn.Module):
    # (expansion, out_planes, num_blocks, stride)
    cfg = [(1, 16, 1, 1),
           (6, 24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10
           (6, 32, 3, 2),
           (6, 64, 4, 2),
           (6, 96, 3, 1),
           (6, 160, 3, 2),
           (6, 320, 1, 1)]

    def __init__(self, num_classes=10):
        super(MobileNetV2, self).__init__()
        # NOTE: change conv1 stride 2 -> 1 for CIFAR10
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        self.linear = nn.Linear(1280, num_classes)

    def _make_layers(self, in_planes):
        layers = []
        for expansion, out_planes, num_blocks, stride in self.cfg:
            strides = [stride] + [1]*(num_blocks-1)
            for stride in strides:
                layers.append(Block(in_planes, out_planes, expansion, stride))
                in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = F.relu(self.bn2(self.conv2(out)))
        # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def test():
    net = MobileNetV2()
    x = torch.randn(2, 3, 32, 32)
    y = net(x)
    print(y.size())

# test()
--------------------------------------------------------------------------------
/feathermap/models/googlenet.py:
--------------------------------------------------------------------------------
'''GoogLeNet with PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class Inception(nn.Module):
    def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
        super(Inception, self).__init__()
        # 1x1 conv branch
        self.b1 = nn.Sequential(
            nn.Conv2d(in_planes, n1x1, kernel_size=1),
            nn.BatchNorm2d(n1x1),
            nn.ReLU(True),
        )

        # 1x1 conv -> 3x3 conv branch
        self.b2 = nn.Sequential(
            nn.Conv2d(in_planes, n3x3red, kernel_size=1),
            nn.BatchNorm2d(n3x3red),
            nn.ReLU(True),
            nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1),
            nn.BatchNorm2d(n3x3),
            nn.ReLU(True),
        )

        # 1x1 conv -> 5x5 conv branch
        self.b3 = nn.Sequential(
            nn.Conv2d(in_planes, n5x5red, kernel_size=1),
            nn.BatchNorm2d(n5x5red),
            nn.ReLU(True),
            nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(n5x5),
            nn.ReLU(True),
            nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(n5x5),
            nn.ReLU(True),
        )

        # 3x3 pool -> 1x1 conv branch
        self.b4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_planes, pool_planes, kernel_size=1),
            nn.BatchNorm2d(pool_planes),
            nn.ReLU(True),
        )

    def forward(self, x):
        y1 = self.b1(x)
        y2 = self.b2(x)
        y3 = self.b3(x)
        y4 = self.b4(x)
        return torch.cat([y1, y2, y3, y4], 1)


class GoogLeNet(nn.Module):
    def __init__(self):
        super(GoogLeNet, self).__init__()
        self.pre_layers = nn.Sequential(
            nn.Conv2d(3, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )

        self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
        self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
        self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
        self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

        self.avgpool = nn.AvgPool2d(8, stride=1)
        self.linear = nn.Linear(1024, 10)

    def forward(self, x):
        out = self.pre_layers(x)
        out = self.a3(out)
        out = self.b3(out)
        out = self.maxpool(out)
        out = self.a4(out)
        out = self.b4(out)
        out = self.c4(out)
        out = self.d4(out)
        out = self.e4(out)
        out = self.maxpool(out)
        out = self.a5(out)
        out = self.b5(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def test():
    net = GoogLeNet()
    x = torch.randn(1, 3, 32, 32)
    y = net(x)
    print(y.size())

# test()
--------------------------------------------------------------------------------
/tests/integration/test_FFNN.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from feathermap.feathernet import FeatherNet


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
input_size = 784
hidden_size = 500
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data',
                                          train=False,
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out


base_model = NeuralNet(input_size, hidden_size, num_classes).to(device)
# model = FeatherNet(base_model, compress=0.5)
model = base_model
# a = [print(name) for name, v in model.get_WandB()]

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Move tensors to the configured device
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        # loss.backward(retain_graph=True)
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# Test the model
# In the test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')
--------------------------------------------------------------------------------
/tests/integration/test_resnet.py:
--------------------------------------------------------------------------------
# ---------------------------------------------------------------------------- #
# An implementation of https://arxiv.org/pdf/1512.03385.pdf                    #
# See section 4.2 for the model architecture on CIFAR-10                       #
# Some part of the code was referenced from                                    #
# https://github.com/yunjey/pytorch-tutorial and                               #
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py   #
# ---------------------------------------------------------------------------- #

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from feathermap.feathernet import FeatherNet

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=False)

# Residual block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

# ResNet
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


base_model = ResNet(ResidualBlock, [2, 2, 2]).to(device)
model = FeatherNet(base_model, exclude=(nn.BatchNorm2d,), compress=0.25)
a = [print(name + '.' + kind, type(module)) for name, module, kind in model.get_WandB_modules()]
--------------------------------------------------------------------------------
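The two lines that close test_resnet.py above are the whole compression story: wrap the base network, then inspect what remains trainable. A short extension makes the effect measurable (a sketch; since wrapping unregisters the base model's parameters in place, the base count is taken on a fresh instance, and model.parameters() is assumed to hold only the V1/V2 pool plus any excluded layers, as in tests/integration/test_FeatherNet.py below):

fresh = ResNet(ResidualBlock, [2, 2, 2])
base_params = sum(p.numel() for p in fresh.parameters())
hashed_params = sum(p.numel() for p in model.parameters())
print("base: {}  feathernet: {}  ratio: {:.3f}".format(
    base_params, hashed_params, hashed_params / base_params))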
/feathermap/models/resnext.py:
--------------------------------------------------------------------------------
'''ResNeXt in PyTorch.

See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    '''Grouped convolution block.'''
    expansion = 2

    def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):
        super(Block, self).__init__()
        group_width = cardinality * bottleneck_width
        self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(group_width)
        self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(group_width)
        self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*group_width)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*group_width:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*group_width)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNeXt(nn.Module):
    def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10):
        super(ResNeXt, self).__init__()
        self.cardinality = cardinality
        self.bottleneck_width = bottleneck_width
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(num_blocks[0], 1)
        self.layer2 = self._make_layer(num_blocks[1], 2)
        self.layer3 = self._make_layer(num_blocks[2], 2)
        # self.layer4 = self._make_layer(num_blocks[3], 2)
        self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes)

    def _make_layer(self, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride))
            self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width
        # Double bottleneck_width after each stage.
        self.bottleneck_width *= 2
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # out = self.layer4(out)
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNeXt29_2x64d():
    return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64)

def ResNeXt29_4x64d():
    return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64)

def ResNeXt29_8x64d():
    return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64)

def ResNeXt29_32x4d():
    return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4)

def test_resnext():
    net = ResNeXt29_2x64d()
    x = torch.randn(1, 3, 32, 32)
    y = net(x)
    print(y.size())

# test_resnext()
--------------------------------------------------------------------------------
/feathermap/models/dpn.py:
--------------------------------------------------------------------------------
'''Dual Path Networks in PyTorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer):
        super(Bottleneck, self).__init__()
        self.out_planes = out_planes
        self.dense_depth = dense_depth

        self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False)
        self.bn2 = nn.BatchNorm2d(in_planes)
        self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes+dense_depth)

        self.shortcut = nn.Sequential()
        if first_layer:
            self.shortcut = nn.Sequential(
                nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes+dense_depth)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        x = self.shortcut(x)
        d = self.out_planes
        out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1)
        out = F.relu(out)
        return out


class DPN(nn.Module):
    def __init__(self, cfg):
        super(DPN, self).__init__()
        in_planes, out_planes = cfg['in_planes'], cfg['out_planes']
        num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth']

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.last_planes = 64
        self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1)
        self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2)
        self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2)
        self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2)
        self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10)

    def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for i, stride in enumerate(strides):
            layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0))
            self.last_planes = out_planes + (i+2) * dense_depth
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def DPN26():
    cfg = {
        'in_planes': (96,192,384,768),
        'out_planes': (256,512,1024,2048),
        'num_blocks': (2,2,2,2),
        'dense_depth': (16,32,24,128)
    }
    return DPN(cfg)

def DPN92():
    cfg = {
        'in_planes': (96,192,384,768),
        'out_planes': (256,512,1024,2048),
        'num_blocks': (3,4,20,3),
        'dense_depth': (16,32,24,128)
    }
    return DPN(cfg)


def test():
    net = DPN92()
    x = torch.randn(1, 3, 32, 32)
    y = net(x)
    print(y)

# test()
--------------------------------------------------------------------------------
/feathermap/models/densenet.py:
--------------------------------------------------------------------------------
'''DenseNet in PyTorch.'''
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    def __init__(self, in_planes, growth_rate):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(4*growth_rate)
        self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        out = torch.cat([out, x], 1)
        return out


class Transition(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(Transition, self).__init__()
        self.bn = nn.BatchNorm2d(in_planes)
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.conv(F.relu(self.bn(x)))
        out = F.avg_pool2d(out, 2)
        return out


class DenseNet(nn.Module):
    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
        super(DenseNet, self).__init__()
        self.growth_rate = growth_rate

        num_planes = 2*growth_rate
        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)

        self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
        num_planes += nblocks[0]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans1 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
        num_planes += nblocks[1]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans2 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
        num_planes += nblocks[2]*growth_rate
        out_planes = int(math.floor(num_planes*reduction))
        self.trans3 = Transition(num_planes, out_planes)
        num_planes = out_planes

        self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
        num_planes += nblocks[3]*growth_rate

        self.bn = nn.BatchNorm2d(num_planes)
        self.linear = nn.Linear(num_planes, num_classes)

    def _make_dense_layers(self, block, in_planes, nblock):
        layers = []
        for i in range(nblock):
            layers.append(block(in_planes, self.growth_rate))
            in_planes += self.growth_rate
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.trans3(self.dense3(out))
        out = self.dense4(out)
        out = F.avg_pool2d(F.relu(self.bn(out)), 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def DenseNet121():
    return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32)

def DenseNet169():
    return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32)

def DenseNet201():
    return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32)

def DenseNet161():
    return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48)

def densenet_cifar():
    return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)

def test():
    net = densenet_cifar()
    x = torch.randn(1, 3, 32, 32)
    y = net(x)
    print(y)

# test()
--------------------------------------------------------------------------------
/feathermap/models/shufflenet.py:
--------------------------------------------------------------------------------
'''ShuffleNet in PyTorch.

See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class ShuffleBlock(nn.Module):
    def __init__(self, groups):
        super(ShuffleBlock, self).__init__()
        self.groups = groups

    def forward(self, x):
        '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
        N, C, H, W = x.size()
        g = self.groups
        return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)


class Bottleneck(nn.Module):
    def __init__(self, in_planes, out_planes, stride, groups):
        super(Bottleneck, self).__init__()
        self.stride = stride

        mid_planes = out_planes // 4  # integer division; conv channel counts must be ints
        g = 1 if in_planes == 24 else groups
        self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.shuffle1 = ShuffleBlock(groups=g)
        self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_planes)
        self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.shortcut = nn.Sequential()
        if stride == 2:
            self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.shuffle1(out)
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        res = self.shortcut(x)
        out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out + res)
        return out


class ShuffleNet(nn.Module):
    def __init__(self, cfg):
        super(ShuffleNet, self).__init__()
        out_planes = cfg['out_planes']
        num_blocks = cfg['num_blocks']
        groups = cfg['groups']

        self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.in_planes = 24
        self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
        self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
        self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
        self.linear = nn.Linear(out_planes[2], 10)

    def _make_layer(self, out_planes, num_blocks, groups):
        layers = []
        for i in range(num_blocks):
            stride = 2 if i == 0 else 1
            cat_planes = self.in_planes if i == 0 else 0
            layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups))
            self.in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ShuffleNetG2():
    cfg = {
        'out_planes': [200,400,800],
        'num_blocks': [4,8,4],
        'groups': 2
    }
    return ShuffleNet(cfg)

def ShuffleNetG3():
    cfg = {
        'out_planes': [240,480,960],
        'num_blocks': [4,8,4],
        'groups': 3
    }
    return ShuffleNet(cfg)


def test():
    net = ShuffleNetG2()
    x = torch.randn(1, 3, 32, 32)
    y = net(x)
    print(y)

# test()
--------------------------------------------------------------------------------
/tests/performance/test_latency.py:
--------------------------------------------------------------------------------
"""Evaluate CIFAR10 latency with PyTorch.
Deploy mode is only supported on CPU.
"""
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import argparse
from feathermap.models.resnet import ResNet34
from feathermap.feathernet import FeatherNet
from feathermap.dataloader import get_test_loader
from timeit import default_timer as timer

parser = argparse.ArgumentParser(
    description="PyTorch CIFAR10 Training",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "--epochs",
    type=int,
    default=1,
    help="Number of epochs to evaluate over test set",
    metavar="",
)
parser.add_argument(
    "--compress",
    type=float,
    default=0,
    help="Compression rate. Set to zero for base model",
    metavar="",
)
parser.add_argument(
    "--constrain",
    action="store_true",
    default=False,
    help="Constrain to per layer caching",
)
parser.add_argument(
    "--num-workers",
    type=int,
    default=2,
    help="Number of dataloader processing threads. Try adjusting for faster training",
    metavar="",
)
parser.add_argument(
    "--data-dir",
    type=str,
    default="./data/",
    help="Path to store CIFAR10 data",
    metavar="",
)
parser.add_argument(
    "--pin-memory",
    type=bool,
    default=False,
    help="Pin GPU memory",
    metavar="",
)
parser.add_argument(
    "--cudabench",
    type=bool,
    default=False,
    help="Set cudnn.benchmark to true for static model and input",
    metavar="",
)
parser.add_argument(
    "--cpu",
    action="store_true",
    default=False,
    help="Use CPU",
)
parser.add_argument(
    "--deploy",
    action="store_true",
    default=False,
    help="Calculate weights on the fly in eval mode",
)
parser.add_argument(
    "--v",
    action="store_true",
    default=False,
    help="Verbose",
)
args = parser.parse_args()


# Build Model
print("==> Building model..")
base_model = ResNet34()
if args.compress:
    model = FeatherNet(
        base_model,
        exclude=(nn.BatchNorm2d,),
        compress=args.compress,
        constrain=args.constrain,
        verbose=args.v,
    )
else:
    model = base_model

# Enable GPU support
print("==> Preparing device..")
if torch.cuda.is_available() and not args.cpu:
    print("Utilizing", torch.cuda.device_count(), "GPU(s)!")
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    DEV = torch.device("cuda:0")
    cuda_kwargs = {"num_workers": args.num_workers, "pin_memory": args.pin_memory}
    cudnn.benchmark = args.cudabench
else:
    print("Utilizing CPU!")
    DEV = torch.device("cpu")
    cuda_kwargs = {}
model.to(DEV)
if args.deploy:
    model.deploy()
else:
    model.eval()

# Create dataloaders
print("==> Preparing data..")
test_loader = get_test_loader(data_dir=args.data_dir, batch_size=100, **cuda_kwargs)

# Benchmark latency
print("==> Evaluating test set..")


def test(epoch):
    start = timer()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs = inputs.to(DEV)
            outputs = model(inputs)
    end = timer()
    return end - start


dt = []
for epoch in range(0, args.epochs):
    dt.append(test(epoch))

# 10k images in test set; result is fps
avg = args.epochs * 10000 / sum(dt)
print("Average fps: {:.4f}".format(avg))
print("Average latency: {:.4f}".format(1 / avg))
--------------------------------------------------------------------------------
/feathermap/dataloader.py:
--------------------------------------------------------------------------------
"""
Create train, valid, test iterators for CIFAR-10.
Some code implemented from
https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb
"""

import torch
import numpy as np
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler


def get_train_valid_loader(
    data_dir,
    batch_size=128,
    augment=True,
    random_seed=42,
    valid_size=0.1,
    shuffle=True,
    num_workers=2,
    pin_memory=False,
):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-10 dataset
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
valid_size should be in the range [0, 1]." 29 | assert (valid_size >= 0) and (valid_size <= 1), error_msg 30 | 31 | normalize = transforms.Normalize( 32 | mean=[0.4914, 0.4822, 0.4465], 33 | std=[0.2023, 0.1994, 0.2010], 34 | ) 35 | 36 | # define transforms 37 | valid_transform = transforms.Compose( 38 | [ 39 | transforms.ToTensor(), 40 | normalize, 41 | ] 42 | ) 43 | if augment: 44 | train_transform = transforms.Compose( 45 | [ 46 | transforms.RandomCrop(32, padding=4), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.ToTensor(), 49 | normalize, 50 | ] 51 | ) 52 | else: 53 | train_transform = transforms.Compose( 54 | [ 55 | transforms.ToTensor(), 56 | normalize, 57 | ] 58 | ) 59 | 60 | # load the dataset 61 | train_dataset = datasets.CIFAR10( 62 | root=data_dir, 63 | train=True, 64 | download=True, 65 | transform=train_transform, 66 | ) 67 | 68 | valid_dataset = datasets.CIFAR10( 69 | root=data_dir, 70 | train=True, 71 | download=True, 72 | transform=valid_transform, 73 | ) 74 | 75 | num_train = len(train_dataset) 76 | indices = list(range(num_train)) 77 | split = int(np.floor(valid_size * num_train)) 78 | 79 | if shuffle: 80 | np.random.seed(random_seed) 81 | np.random.shuffle(indices) 82 | 83 | train_idx, valid_idx = indices[split:], indices[:split] 84 | train_sampler = SubsetRandomSampler(train_idx) 85 | valid_sampler = SubsetRandomSampler(valid_idx) 86 | 87 | train_loader = torch.utils.data.DataLoader( 88 | train_dataset, 89 | batch_size=batch_size, 90 | sampler=train_sampler, 91 | num_workers=num_workers, 92 | pin_memory=pin_memory, 93 | ) 94 | valid_loader = torch.utils.data.DataLoader( 95 | valid_dataset, 96 | batch_size=batch_size, 97 | sampler=valid_sampler, 98 | num_workers=num_workers, 99 | pin_memory=pin_memory, 100 | ) 101 | 102 | return (train_loader, valid_loader) 103 | 104 | 105 | def get_test_loader( 106 | data_dir, batch_size=100, shuffle=False, num_workers=2, pin_memory=False 107 | ): 108 | """ 109 | Utility function for loading and returning a multi-process 110 | test iterator over the CIFAR-10 dataset. 
111 | """ 112 | normalize = transforms.Normalize( 113 | mean=[0.4914, 0.4822, 0.4465], 114 | std=[0.2023, 0.1994, 0.2010], 115 | ) 116 | 117 | # define transform 118 | transform = transforms.Compose( 119 | [ 120 | transforms.ToTensor(), 121 | normalize, 122 | ] 123 | ) 124 | 125 | dataset = datasets.CIFAR10( 126 | root=data_dir, 127 | train=False, 128 | download=True, 129 | transform=transform, 130 | ) 131 | 132 | data_loader = torch.utils.data.DataLoader( 133 | dataset, 134 | batch_size=batch_size, 135 | shuffle=shuffle, 136 | num_workers=num_workers, 137 | pin_memory=pin_memory, 138 | ) 139 | 140 | return data_loader 141 | 142 | 143 | # For reference 144 | label_names = [ 145 | "airplane", 146 | "automobile", 147 | "bird", 148 | "cat", 149 | "deer", 150 | "dog", 151 | "frog", 152 | "horse", 153 | "ship", 154 | "truck", 155 | ] 156 | -------------------------------------------------------------------------------- /tests/integration/test_FeatherNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Parameter 4 | from feathermap.resnet import ResNet, ResidualBlock 5 | from torch import Tensor 6 | from typing import Iterator, Tuple 7 | from math import ceil, sqrt 8 | 9 | 10 | class FeatherNet(nn.Module): 11 | """Implementation of structured multihashing for model compression""" 12 | 13 | def __init__( 14 | self, module: nn.Module, compress: float = 1, exclude: tuple = () 15 | ) -> None: 16 | super().__init__() 17 | self.module = module 18 | self.compress = compress 19 | self.exclude = exclude 20 | self.unregister_params() 21 | self.size_n = ceil(sqrt(self.num_WandB())) 22 | self.size_m = ceil((self.compress * self.size_n) / 2) 23 | self.V1 = Parameter(torch.randn(self.size_n, self.size_m)) 24 | self.V2 = Parameter(torch.randn(self.size_n, self.size_m)) 25 | self.V = torch.matmul(self.V1, self.V2.transpose(0, 1)) 26 | 27 | def WandBtoV(self): 28 | i, j = 0, 0 29 | V = self.V.view(-1, 1) 30 | for name, v in self.get_WandB(): 31 | v = v.view(-1, 1) 32 | for j in range(len(v)): 33 | v[j] = V[i] 34 | i += 1 35 | 36 | def num_WandB(self) -> int: 37 | """Return total number of weights and biases""" 38 | return sum(v.numel() for name, v in self.get_WandB()) 39 | 40 | def get_WandB(self) -> Iterator[Tuple[str, Tensor]]: 41 | for name, module, kind in self.get_WandB_modules(): 42 | yield name + "." + kind, getattr(module, kind) 43 | 44 | def get_WandB_modules(self) -> Iterator[Tuple[str, nn.Module, str]]: 45 | """Helper function to return weight and bias modules in order""" 46 | for name, module in self.named_modules(): 47 | try: 48 | if isinstance(module, self.exclude): 49 | continue 50 | if getattr(module, "weight") is not None: 51 | yield name, module, "weight" 52 | if getattr(module, "bias") is not None: 53 | yield name, module, "bias" 54 | except nn.modules.module.ModuleAttributeError: 55 | pass 56 | 57 | def unregister_params(self) -> None: 58 | """Delete params, set attributes as Tensors of prior data""" 59 | for name, module, kind in self.get_WandB_modules(): 60 | try: 61 | data = module._parameters[kind].data 62 | del module._parameters[kind] 63 | print( 64 | "Parameter unregistered, assigned to type Tensor: {}".format( 65 | name + "." + kind 66 | ) 67 | ) 68 | setattr(module, kind, data) 69 | except KeyError: 70 | print( 71 | "{} is already registered as {}".format( 72 | name + "." 
+ kind, type(getattr(module, kind)) 73 | ) 74 | ) 75 | 76 | def forward(self, x: Tensor) -> Tensor: 77 | self.WandBtoV() 78 | return self.module(x) 79 | 80 | 81 | def main(): 82 | # Device configuration 83 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 84 | 85 | def linear_test(): 86 | lmodel = nn.Linear(2, 4).to(device) 87 | flmodel = FeatherNet(lmodel, compress=0.5) 88 | print(flmodel.num_WandB(), flmodel.size_n, flmodel.size_m) 89 | print("V1: {}".format(flmodel.V1)) 90 | print("V2: {}".format(flmodel.V2)) 91 | print("V: {}".format(flmodel.V)) 92 | flmodel.WandBtoV() 93 | [print(name, v) for name, v in flmodel.named_parameters()] 94 | 95 | def res_test(): 96 | rmodel = ResNet(ResidualBlock, [2, 2, 2]).to(device) 97 | frmodel = FeatherNet(rmodel).to(device) 98 | print(frmodel.num_WandB(), frmodel.size_n, frmodel.size_m) 99 | print("V1: {}".format(frmodel.V1)) 100 | print("V2: {}".format(frmodel.V2)) 101 | print("V: {}".format(frmodel.V)) 102 | frmodel.WandBtoV() 103 | [print(name, v) for name, v in frmodel.named_parameters()] 104 | 105 | linear_test() 106 | 107 | 108 | if __name__ == "__main__": 109 | try: 110 | main() 111 | except KeyboardInterrupt: 112 | exit() 113 | -------------------------------------------------------------------------------- /feathermap/models/senet.py: -------------------------------------------------------------------------------- 1 | '''SENet in PyTorch. 2 | 3 | SENet won the ILSVRC 2017 image classification challenge. Paper: "Squeeze-and-Excitation Networks" (arXiv:1709.01507). 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_planes != planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm2d(planes) 23 | ) 24 | 25 | # SE layers 26 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear 27 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | 33 | # Squeeze 34 | w = F.avg_pool2d(out, out.size(2)) 35 | w = F.relu(self.fc1(w)) 36 | w = F.sigmoid(self.fc2(w)) 37 | # Excitation 38 | out = out * w # New broadcasting feature from v0.2! 
39 | 40 | out += self.shortcut(x) 41 | out = F.relu(out) 42 | return out 43 | 44 | 45 | class PreActBlock(nn.Module): 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(PreActBlock, self).__init__() 48 | self.bn1 = nn.BatchNorm2d(in_planes) 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 52 | 53 | if stride != 1 or in_planes != planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 56 | ) 57 | 58 | # SE layers 59 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 60 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(x)) 64 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 65 | out = self.conv1(out) 66 | out = self.conv2(F.relu(self.bn2(out))) 67 | 68 | # Squeeze 69 | w = F.avg_pool2d(out, out.size(2)) 70 | w = F.relu(self.fc1(w)) 71 | w = F.sigmoid(self.fc2(w)) 72 | # Excitation 73 | out = out * w 74 | 75 | out += shortcut 76 | return out 77 | 78 | 79 | class SENet(nn.Module): 80 | def __init__(self, block, num_blocks, num_classes=10): 81 | super(SENet, self).__init__() 82 | self.in_planes = 64 83 | 84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(64) 86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 88 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 90 | self.linear = nn.Linear(512, num_classes) 91 | 92 | def _make_layer(self, block, planes, num_blocks, stride): 93 | strides = [stride] + [1]*(num_blocks-1) 94 | layers = [] 95 | for stride in strides: 96 | layers.append(block(self.in_planes, planes, stride)) 97 | self.in_planes = planes 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = F.avg_pool2d(out, 4) 107 | out = out.view(out.size(0), -1) 108 | out = self.linear(out) 109 | return out 110 | 111 | 112 | def SENet18(): 113 | return SENet(PreActBlock, [2,2,2,2]) 114 | 115 | 116 | def test(): 117 | net = SENet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /feathermap/models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class PreActBlock(nn.Module): 13 | '''Pre-activation version of the BasicBlock.''' 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | '''Pre-activation version of the original Bottleneck module.''' 39 | expansion = 4 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBottleneck, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(x)) 57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 58 | out = self.conv1(out) 59 | out = self.conv2(F.relu(self.bn2(out))) 60 | out = self.conv3(F.relu(self.bn3(out))) 61 | out += shortcut 62 | return out 63 | 64 | 65 | class PreActResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(PreActResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def PreActResNet18(): 98 | return PreActResNet(PreActBlock, [2,2,2,2]) 99 | 100 | def PreActResNet34(): 101 | return PreActResNet(PreActBlock, [3,4,6,3]) 102 | 103 | def PreActResNet50(): 104 | return 
PreActResNet(PreActBottleneck, [3,4,6,3]) 105 | 106 | def PreActResNet101(): 107 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 108 | 109 | def PreActResNet152(): 110 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 111 | 112 | 113 | def test(): 114 | net = PreActResNet18() 115 | y = net((torch.randn(1,3,32,32))) 116 | print(y.size()) 117 | 118 | # test() 119 | -------------------------------------------------------------------------------- /feathermap/models/pnasnet.py: -------------------------------------------------------------------------------- 1 | '''PNASNet in PyTorch. 2 | 3 | Paper: Progressive Neural Architecture Search 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class SepConv(nn.Module): 11 | '''Separable Convolution.''' 12 | def __init__(self, in_planes, out_planes, kernel_size, stride): 13 | super(SepConv, self).__init__() 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, 15 | kernel_size, stride, 16 | padding=(kernel_size-1)//2, 17 | bias=False, groups=in_planes) 18 | self.bn1 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | return self.bn1(self.conv1(x)) 22 | 23 | 24 | class CellA(nn.Module): 25 | def __init__(self, in_planes, out_planes, stride=1): 26 | super(CellA, self).__init__() 27 | self.stride = stride 28 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 29 | if stride==2: 30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn1 = nn.BatchNorm2d(out_planes) 32 | 33 | def forward(self, x): 34 | y1 = self.sep_conv1(x) 35 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 36 | if self.stride==2: 37 | y2 = self.bn1(self.conv1(y2)) 38 | return F.relu(y1+y2) 39 | 40 | class CellB(nn.Module): 41 | def __init__(self, in_planes, out_planes, stride=1): 42 | super(CellB, self).__init__() 43 | self.stride = stride 44 | # Left branch 45 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 46 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) 47 | # Right branch 48 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) 49 | if stride==2: 50 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 51 | self.bn1 = nn.BatchNorm2d(out_planes) 52 | # Reduce channels 53 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 54 | self.bn2 = nn.BatchNorm2d(out_planes) 55 | 56 | def forward(self, x): 57 | # Left branch 58 | y1 = self.sep_conv1(x) 59 | y2 = self.sep_conv2(x) 60 | # Right branch 61 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 62 | if self.stride==2: 63 | y3 = self.bn1(self.conv1(y3)) 64 | y4 = self.sep_conv3(x) 65 | # Concat & reduce channels 66 | b1 = F.relu(y1+y2) 67 | b2 = F.relu(y3+y4) 68 | y = torch.cat([b1,b2], 1) 69 | return F.relu(self.bn2(self.conv2(y))) 70 | 71 | class PNASNet(nn.Module): 72 | def __init__(self, cell_type, num_cells, num_planes): 73 | super(PNASNet, self).__init__() 74 | self.in_planes = num_planes 75 | self.cell_type = cell_type 76 | 77 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(num_planes) 79 | 80 | self.layer1 = self._make_layer(num_planes, num_cells=6) 81 | self.layer2 = self._downsample(num_planes*2) 82 | self.layer3 = self._make_layer(num_planes*2, num_cells=6) 83 | self.layer4 = 
self._downsample(num_planes*4) 84 | self.layer5 = self._make_layer(num_planes*4, num_cells=6) 85 | 86 | self.linear = nn.Linear(num_planes*4, 10) 87 | 88 | def _make_layer(self, planes, num_cells): 89 | layers = [] 90 | for _ in range(num_cells): 91 | layers.append(self.cell_type(self.in_planes, planes, stride=1)) 92 | self.in_planes = planes 93 | return nn.Sequential(*layers) 94 | 95 | def _downsample(self, planes): 96 | layer = self.cell_type(self.in_planes, planes, stride=2) 97 | self.in_planes = planes 98 | return layer 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = self.layer5(out) 107 | out = F.avg_pool2d(out, 8) 108 | out = self.linear(out.view(out.size(0), -1)) 109 | return out 110 | 111 | 112 | def PNASNetA(): 113 | return PNASNet(CellA, num_cells=6, num_planes=44) 114 | 115 | def PNASNetB(): 116 | return PNASNet(CellB, num_cells=6, num_planes=32) 117 | 118 | 119 | def test(): 120 | net = PNASNetB() 121 | x = torch.randn(1,3,32,32) 122 | y = net(x) 123 | print(y) 124 | 125 | # test() 126 | -------------------------------------------------------------------------------- /feathermap/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d( 20 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 23 | stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, 30 | kernel_size=1, stride=stride, bias=False), 31 | nn.BatchNorm2d(self.expansion*planes) 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(Bottleneck, self).__init__() 47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 50 | stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * 53 | planes, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion*planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion*planes, 60 | kernel_size=1, stride=stride, bias=False), 61 | nn.BatchNorm2d(self.expansion*planes) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(self.conv1(x))) 66 | out = 
F.relu(self.bn2(self.conv2(out))) 67 | out = self.bn3(self.conv3(out)) 68 | out += self.shortcut(x) 69 | out = F.relu(out) 70 | return out 71 | 72 | 73 | class ResNet(nn.Module): 74 | def __init__(self, block, num_blocks, num_classes=10): 75 | super(ResNet, self).__init__() 76 | self.in_planes = 64 77 | 78 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 79 | stride=1, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(64) 81 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 82 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 83 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 84 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 85 | self.linear = nn.Linear(512*block.expansion, num_classes) 86 | 87 | def _make_layer(self, block, planes, num_blocks, stride): 88 | strides = [stride] + [1]*(num_blocks-1) 89 | layers = [] 90 | for stride in strides: 91 | layers.append(block(self.in_planes, planes, stride)) 92 | self.in_planes = planes * block.expansion 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | out = F.relu(self.bn1(self.conv1(x))) 97 | out = self.layer1(out) 98 | out = self.layer2(out) 99 | out = self.layer3(out) 100 | out = self.layer4(out) 101 | out = F.avg_pool2d(out, 4) 102 | out = out.view(out.size(0), -1) 103 | out = self.linear(out) 104 | return out 105 | 106 | 107 | def ResNet18(): 108 | return ResNet(BasicBlock, [2, 2, 2, 2]) 109 | 110 | 111 | def ResNet34(): 112 | return ResNet(BasicBlock, [3, 4, 6, 3]) 113 | 114 | 115 | def ResNet50(): 116 | return ResNet(Bottleneck, [3, 4, 6, 3]) 117 | 118 | 119 | def ResNet101(): 120 | return ResNet(Bottleneck, [3, 4, 23, 3]) 121 | 122 | 123 | def ResNet152(): 124 | return ResNet(Bottleneck, [3, 8, 36, 3]) 125 | 126 | 127 | def test(): 128 | net = ResNet18() 129 | y = net(torch.randn(1, 3, 32, 32)) 130 | print(y.size()) 131 | 132 | # test() 133 | -------------------------------------------------------------------------------- /feathermap/utils.py: -------------------------------------------------------------------------------- 1 | """Some helper functions for FeatherMap, including: 2 | - progress_bar: mimics xlua.progress. 3 | - timed: decorator for timing functions 4 | - get_block_rows: Get complete rows from range within matrix 5 | """ 6 | from timeit import default_timer as timer 7 | import os 8 | import sys 9 | import time 10 | from math import ceil 11 | from typing import List 12 | 13 | 14 | _, term_width = os.popen("stty size", "r").read().split() 15 | term_width = int(term_width) 16 | 17 | TOTAL_BAR_LENGTH = 40.0 18 | last_time = time.time() 19 | begin_time = last_time 20 | 21 | 22 | # current = batch idx, total = len(dataloader) 23 | def progress_bar(current, total, msg=None): 24 | global last_time, begin_time 25 | if current == 0: 26 | begin_time = time.time() # Reset for new bar. 
27 | 28 | cur_len = int(TOTAL_BAR_LENGTH * current / total) 29 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 30 | 31 | sys.stdout.write(" [") 32 | for i in range(cur_len): 33 | sys.stdout.write("=") 34 | sys.stdout.write(">") 35 | for i in range(rest_len): 36 | sys.stdout.write(".") 37 | sys.stdout.write("]") 38 | 39 | cur_time = time.time() 40 | step_time = cur_time - last_time 41 | last_time = cur_time 42 | tot_time = cur_time - begin_time 43 | 44 | L = [] 45 | L.append(" Step: {:<4}".format(format_time(step_time))) 46 | L.append(" | Tot: {:<8}".format(format_time(tot_time))) 47 | if msg: 48 | L.append(" | " + msg) 49 | 50 | msg = "".join(L) 51 | sys.stdout.write(msg) 52 | for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3): 53 | sys.stdout.write(" ") 54 | 55 | # Go back to the center of the bar. 56 | for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2): 57 | sys.stdout.write("\b") 58 | sys.stdout.write(" {:>3}/{:<3} ".format(current + 1, total)) 59 | 60 | if current < total - 1: 61 | sys.stdout.write("\r") 62 | else: 63 | sys.stdout.write("\n") 64 | sys.stdout.flush() 65 | 66 | 67 | def format_time(seconds): 68 | days = int(seconds / 3600 / 24) 69 | seconds = seconds - days * 3600 * 24 70 | hours = int(seconds / 3600) 71 | seconds = seconds - hours * 3600 72 | minutes = int(seconds / 60) 73 | seconds = seconds - minutes * 60 74 | secondsf = int(seconds) 75 | seconds = seconds - secondsf 76 | millis = int(seconds * 1000) 77 | 78 | f = "" 79 | i = 1 80 | if days > 0: 81 | f += str(days) + "D" 82 | i += 1 83 | if hours > 0 and i <= 2: 84 | f += str(hours) + "h" 85 | i += 1 86 | if minutes > 0 and i <= 2: 87 | f += str(minutes) + "m" 88 | i += 1 89 | if secondsf > 0 and i <= 2: 90 | f += str(secondsf) + "s" 91 | i += 1 92 | if millis > 0 and i <= 2: 93 | f += str(millis) + "ms" 94 | i += 1 95 | if f == "": 96 | f = "0ms" 97 | return f 98 | 99 | 100 | def timed(method): 101 | def time_me(*args, **kw): 102 | start = timer() 103 | result = method(*args, **kw) 104 | end = timer() 105 | print("{!r} duration (secs): {:.4f}".format(method.__name__, end - start)) 106 | return result 107 | 108 | return time_me 109 | 110 | 111 | def get_block_rows(i1: int, j1: int, i2: int, j2: int, n: int) -> List[int]: 112 | """Return range of full (complete) rows from an (n x n) matrix, starting from row, col 113 | [i1, j1] and ending at [i2, j2]. E.g. 114 | 115 | | _ x x x | | x x x | 116 | | x x x x | ------> + | x x x x | 117 | | x x x x | | x x x x | + 118 | | x x _ _ | | x x | 119 | 120 | Necessary to make the most use of vectorized matrix multiplication. Sequentially 121 | calculating V[i, j] = V1[i, :] @ V2[:, j] leads to large latency. 
122 | """ 123 | row = [] 124 | # Handle all cases for j1=0 125 | if j1 == 0: 126 | # All row(s) complete 127 | if j2 == n - 1: 128 | row.extend([i1, i2 + 1]) 129 | return row 130 | # First row complete, last row incomplete 131 | elif i2 > i1: 132 | row.extend([i1, i2]) 133 | return row 134 | # First row incomplete (from right), no additional rows 135 | else: 136 | return row 137 | # First row incomplete (from left), last row complete 138 | if j2 == n - 1 and i2 > i1: 139 | row.extend([i1 + 1, i2 + 1]) 140 | return row 141 | # First row incomplete, last row incomplete; has at least one full row 142 | if i2 - i1 > 1: 143 | row.extend([i1 + 1, i2]) 144 | return row 145 | # First row incomplete, second row incomplete 146 | return row 147 | -------------------------------------------------------------------------------- /feathermap/models/regnet.py: -------------------------------------------------------------------------------- 1 | '''RegNet in PyTorch. 2 | 3 | Paper: "Designing Network Design Spaces". 4 | 5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class SE(nn.Module): 13 | '''Squeeze-and-Excitation block.''' 14 | 15 | def __init__(self, in_planes, se_planes): 16 | super(SE, self).__init__() 17 | self.se1 = nn.Conv2d(in_planes, se_planes, kernel_size=1, bias=True) 18 | self.se2 = nn.Conv2d(se_planes, in_planes, kernel_size=1, bias=True) 19 | 20 | def forward(self, x): 21 | out = F.adaptive_avg_pool2d(x, (1, 1)) 22 | out = F.relu(self.se1(out)) 23 | out = self.se2(out).sigmoid() 24 | out = x * out 25 | return out 26 | 27 | 28 | class Block(nn.Module): 29 | def __init__(self, w_in, w_out, stride, group_width, bottleneck_ratio, se_ratio): 30 | super(Block, self).__init__() 31 | # 1x1 32 | w_b = int(round(w_out * bottleneck_ratio)) 33 | self.conv1 = nn.Conv2d(w_in, w_b, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(w_b) 35 | # 3x3 36 | num_groups = w_b // group_width 37 | self.conv2 = nn.Conv2d(w_b, w_b, kernel_size=3, 38 | stride=stride, padding=1, groups=num_groups, bias=False) 39 | self.bn2 = nn.BatchNorm2d(w_b) 40 | # se 41 | self.with_se = se_ratio > 0 42 | if self.with_se: 43 | w_se = int(round(w_in * se_ratio)) 44 | self.se = SE(w_b, w_se) 45 | # 1x1 46 | self.conv3 = nn.Conv2d(w_b, w_out, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(w_out) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or w_in != w_out: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(w_in, w_out, 53 | kernel_size=1, stride=stride, bias=False), 54 | nn.BatchNorm2d(w_out) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(self.conv1(x))) 59 | out = F.relu(self.bn2(self.conv2(out))) 60 | if self.with_se: 61 | out = self.se(out) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | out = F.relu(out) 65 | return out 66 | 67 | 68 | class RegNet(nn.Module): 69 | def __init__(self, cfg, num_classes=10): 70 | super(RegNet, self).__init__() 71 | self.cfg = cfg 72 | self.in_planes = 64 73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 74 | stride=1, padding=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(64) 76 | self.layer1 = self._make_layer(0) 77 | self.layer2 = self._make_layer(1) 78 | self.layer3 = self._make_layer(2) 79 | self.layer4 = self._make_layer(3) 80 | self.linear = nn.Linear(self.cfg['widths'][-1], num_classes) 81 | 82 | def _make_layer(self, idx): 83 | depth = 
self.cfg['depths'][idx] 84 | width = self.cfg['widths'][idx] 85 | stride = self.cfg['strides'][idx] 86 | group_width = self.cfg['group_width'] 87 | bottleneck_ratio = self.cfg['bottleneck_ratio'] 88 | se_ratio = self.cfg['se_ratio'] 89 | 90 | layers = [] 91 | for i in range(depth): 92 | s = stride if i == 0 else 1 93 | layers.append(Block(self.in_planes, width, 94 | s, group_width, bottleneck_ratio, se_ratio)) 95 | self.in_planes = width 96 | return nn.Sequential(*layers) 97 | 98 | def forward(self, x): 99 | out = F.relu(self.bn1(self.conv1(x))) 100 | out = self.layer1(out) 101 | out = self.layer2(out) 102 | out = self.layer3(out) 103 | out = self.layer4(out) 104 | out = F.adaptive_avg_pool2d(out, (1, 1)) 105 | out = out.view(out.size(0), -1) 106 | out = self.linear(out) 107 | return out 108 | 109 | 110 | def RegNetX_200MF(): 111 | cfg = { 112 | 'depths': [1, 1, 4, 7], 113 | 'widths': [24, 56, 152, 368], 114 | 'strides': [1, 1, 2, 2], 115 | 'group_width': 8, 116 | 'bottleneck_ratio': 1, 117 | 'se_ratio': 0, 118 | } 119 | return RegNet(cfg) 120 | 121 | 122 | def RegNetX_400MF(): 123 | cfg = { 124 | 'depths': [1, 2, 7, 12], 125 | 'widths': [32, 64, 160, 384], 126 | 'strides': [1, 1, 2, 2], 127 | 'group_width': 16, 128 | 'bottleneck_ratio': 1, 129 | 'se_ratio': 0, 130 | } 131 | return RegNet(cfg) 132 | 133 | 134 | def RegNetY_400MF(): 135 | cfg = { 136 | 'depths': [1, 2, 7, 12], 137 | 'widths': [32, 64, 160, 384], 138 | 'strides': [1, 1, 2, 2], 139 | 'group_width': 16, 140 | 'bottleneck_ratio': 1, 141 | 'se_ratio': 0.25, 142 | } 143 | return RegNet(cfg) 144 | 145 | 146 | def test(): 147 | net = RegNetX_200MF() 148 | print(net) 149 | x = torch.randn(2, 3, 32, 32) 150 | y = net(x) 151 | print(y.shape) 152 | 153 | 154 | if __name__ == '__main__': 155 | test() 156 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🕊 FeatherMap 2 | 3 | ## What is FeatherMap? 4 | FeatherMap is a tool that compresses deep neural networks. Centered around computer vision models, it implements the Google Research paper [Structured Multi-Hashing for Model Compression (CVPR 2020)](https://openaccess.thecvf.com/content_CVPR_2020/papers/Eban_Structured_Multi-Hashing_for_Model_Compression_CVPR_2020_paper.pdf). Taking the form of a Python package, the tool takes a user-defined PyTorch model and compresses it to a desired factor without modification to the underlying architecture. Using its simple API, FeatherMap can easily be applied across a broad array of models. 5 | 6 | ## Table of Contents 7 | * [Installation](#installation) 8 | * [Usage](#usage) 9 | + [General Usage](#general-usage) 10 | + [Training](#training) 11 | + [Deployment](#deployment) 12 | * [Results](#results) 13 | * [What is Structured Multi-Hashing?](#what-is-structured-multi-hashing) 14 | 15 | ## Installation 16 | * Clone into directory 17 | ``` 18 | git clone https://github.com/phelps-matthew/FeatherMap.git 19 | cd FeatherMap 20 | ``` 21 | * Pip Install 22 | ``` 23 | pip install -e . 24 | ``` 25 | * Conda Install 26 | ``` 27 | conda develop . 
28 | ``` 29 | ## Usage 30 | ``` 31 | ├── FeatherMap 32 | │   ├── feathermap 33 | │   │   ├── dataloader.py # CIFAR10 train, valid, and test data loading 34 | │   │   ├── feathernet.py # Module for implementing compression 35 | │   │   ├── train.py # Training script (argparse) 36 | │   │   ├── utils.py 37 | │   │   ├── models # Sample computer vision models to compress 38 | │   │   │   ├── densenet.py 39 | │   │   │   ├── efficientnet.py 40 | │   │   │   ├── mobilenetv2.py 41 | │   │   │   ├── resnet.py 42 | │   │   │   └── ... 43 | ``` 44 | ### General Usage 45 | To compress a model such as ResNet-34, import the model from `feathermap/models/` and simply wrap it with the `FeatherNet` module, initializing with the desired compression. One can then proceed with forward and backward passes as normal, as well as `state_dict` loading and saving. 46 | ```python 47 | from feathermap.models.resnet import ResNet34 48 | from feathermap.feathernet import FeatherNet 49 | 50 | base_model = ResNet34() 51 | model = FeatherNet(base_model, compress=0.10) 52 | 53 | # Forward pass ... 54 | y = model(x) 55 | loss = criterion(y, target) 56 | ... 57 | 58 | # Backward and optimize ... 59 | loss.backward() 60 | optimizer.step() 61 | ... 62 | ``` 63 | See `feathermap/models/` for a zoo of available CV models to compress. 64 | ### Training 65 | Models are trained on CIFAR-10 using `feathermap/train.py` (defaults to training ResNet-34). List all argument options with the `--help` flag. 66 | ```bash 67 | python train.py --compress 0.1 68 | ``` 69 | 70 | ### Deployment 71 | Upon defining your `FeatherNet` model, switch to deploy mode to calculate weights on the fly (see [What is Structured Multi-Hashing?](#what-is-structured-multi-hashing)). 72 | ```python 73 | base_model = ResNet34() 74 | model = FeatherNet(base_model, compress=0.10) 75 | model.deploy() 76 | ``` 77 | 78 | ## Results 79 | Below are results for a ResNet-34 architecture, trained and tested on CIFAR-10. Latency was benchmarked on CPU (AWS c5a.8xlarge), iterating over 30k images with a batch size of 100. For context: ResNet-34 can be compressed to 2% of its original size while still achieving over 90% accuracy (a 5-point drop from the base model), at the cost of only a 4% increase in inference time. 80 |
![ResNet-34 accuracy and latency vs. compression](references/resnet34_acc_latency.png)
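To reproduce a benchmark like the 2% point above, the latency script under `tests/performance/` can be run against a compressed model. A representative invocation (flags as defined in `test_latency.py`; three passes over the 10k-image test set gives the 30k images quoted above):
```bash
python tests/performance/test_latency.py --cpu --deploy --compress 0.02 --epochs 3
```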
81 | 82 | 83 | 84 | ## What is Structured Multi-Hashing? 85 | There are two main concepts behind structured multi-hashing. The first concept is to take the weights of each *layer*, flatten them, and tile them into a single square matrix. This *global weight matrix* represents the weights of the entire network. 86 |
![Step 1: flatten and tile layer weights into a single global weight matrix](references/smh_1.png)
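As a rough sketch of this tiling step (a hypothetical helper, mirroring how `FeatherNet` sizes its global matrix via `ceil(sqrt(num_params))`):
```python
import torch.nn as nn
from math import ceil, sqrt

def global_matrix_side(model: nn.Module) -> int:
    """Side length n of the square matrix that can hold every weight and bias."""
    total = sum(p.numel() for p in model.parameters())
    return ceil(sqrt(total))

# e.g. a network with ~21M parameters tiles into a roughly 4600 x 4600 square
```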
87 | The second concept is pure linear algebra: if we matrix-multiply a pair of columns (n x 2) by a pair of rows (2 x n), we obtain a square (n x n) matrix. 88 |
![Step 2: a pair of columns times a pair of rows yields a square matrix](references/smh_2.png)
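In code, this second concept is a single matrix product — a sketch with hypothetical sizes:
```python
import torch

n = 4600                  # side of the global weight matrix
V1 = torch.randn(n, 2)    # two columns
V2 = torch.randn(2, n)    # two rows
V = V1 @ V2               # (n x 2) @ (2 x n) -> an (n x n) square matrix
assert V.shape == (n, n)  # n*n values generated from only 4*n parameters
```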
89 | Putting these two ideas together, we can implement structured multi-hashing. Here's how it works: 90 | 91 | 1. Let the tunable parameters describing the entire network be two columns (n x 2) and two rows (2 x n) 92 | 2. Matrix-multiply the columns by the rows to obtain a square matrix of size (n x n) 93 | 3. Map each element of that product matrix to an element of the *global weight matrix* 94 | 95 | Putting it all together, we have this process (a minimal code sketch follows the figure below). 96 |
![The full structured multi-hashing process](references/smh_3.png)
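Below is a minimal end-to-end sketch of the mapping — an illustration of the idea only, not the actual `FeatherNet` internals (see `feathermap/feathernet.py` for those); the function name and the in-place copy are assumptions of this sketch:
```python
import torch
import torch.nn as nn
from math import ceil, sqrt

def structured_multihash(model: nn.Module, m: int = 2):
    """Regenerate every weight in `model` from the 2*m*n entries of V1 and V2."""
    n = ceil(sqrt(sum(p.numel() for p in model.parameters())))
    V1 = nn.Parameter(torch.randn(n, m))  # m columns
    V2 = nn.Parameter(torch.randn(m, n))  # m rows
    V = (V1 @ V2).reshape(-1)             # flattened (n x n) product matrix
    offset = 0
    with torch.no_grad():
        for p in model.parameters():      # map slices of V onto each layer
            p.copy_(V[offset : offset + p.numel()].view(p.shape))
            offset += p.numel()
    return V1, V2                         # the only tunable parameters left
```
Training then updates only `V1` and `V2`, with the copy re-run before each forward pass — the same role played by `FeatherNet.WandBtoV` during training.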
97 | 98 | What we have effectively done with this mapping is a reduction in the number of *tunable parameters* from n^2 to 4n, thus achieving the desired compression! 99 | 100 | Additional Remarks: 101 | - To obtain a target compression factor, generalize the shared dimension of the rows and columns from 2 to m, so that we begin with a total of 2mn tunable parameters. The compression factor is then 2mn/n^2 = 2m/n; by varying m, one can achieve varying levels of compression (e.g. `compress=0.1` corresponds to m = 0.05n). 102 | - For practical deployment, in order to constrain RAM consumption, each weight must be calculated 'on the fly' during the forward pass. Such additional calculations induce latency overhead; however, the 'structured' nature of this multi-hashing approach embraces memory locality, and I have found that for small compression factors the overhead is minimal (see [Results](#results)). 103 | - An earlier reference to compression via hashing: [Compressing Neural Networks with the Hashing Trick](https://arxiv.org/abs/1504.04788) 104 | 105 | 106 | -------------------------------------------------------------------------------- /feathermap/models/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNetV2 in PyTorch. 2 | 3 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ShuffleBlock(nn.Module): 11 | def __init__(self, groups=2): 12 | super(ShuffleBlock, self).__init__() 13 | self.groups = groups 14 | 15 | def forward(self, x): 16 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]''' 17 | N, C, H, W = x.size() 18 | g = self.groups 19 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 20 | 21 | 22 | class SplitBlock(nn.Module): 23 | def __init__(self, ratio): 24 | super(SplitBlock, self).__init__() 25 | self.ratio = ratio 26 | 27 | def forward(self, x): 28 | c = int(x.size(1) * self.ratio) 29 | return x[:, :c, :, :], x[:, c:, :, :] 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | def __init__(self, in_channels, split_ratio=0.5): 34 | super(BasicBlock, self).__init__() 35 | self.split = SplitBlock(split_ratio) 36 | in_channels = int(in_channels * split_ratio) 37 | self.conv1 = nn.Conv2d(in_channels, in_channels, 38 | kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(in_channels) 40 | self.conv2 = nn.Conv2d(in_channels, in_channels, 41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 42 | self.bn2 = nn.BatchNorm2d(in_channels) 43 | self.conv3 = nn.Conv2d(in_channels, in_channels, 44 | kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(in_channels) 46 | self.shuffle = ShuffleBlock() 47 | 48 | def forward(self, x): 49 | x1, x2 = self.split(x) 50 | out = F.relu(self.bn1(self.conv1(x2))) 51 | out = self.bn2(self.conv2(out)) 52 | out = F.relu(self.bn3(self.conv3(out))) 53 | out = torch.cat([x1, out], 1) 54 | out = self.shuffle(out) 55 | return out 56 | 57 | 58 | class DownBlock(nn.Module): 59 | def __init__(self, in_channels, out_channels): 60 | super(DownBlock, self).__init__() 61 | mid_channels = out_channels // 2 62 | # left 63 | self.conv1 = nn.Conv2d(in_channels, in_channels, 64 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) 65 | self.bn1 = nn.BatchNorm2d(in_channels) 66 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 67 | kernel_size=1, bias=False) 68 | self.bn2 = 
nn.BatchNorm2d(mid_channels) 69 | # right 70 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 71 | kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(mid_channels) 73 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 74 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) 75 | self.bn4 = nn.BatchNorm2d(mid_channels) 76 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 77 | kernel_size=1, bias=False) 78 | self.bn5 = nn.BatchNorm2d(mid_channels) 79 | 80 | self.shuffle = ShuffleBlock() 81 | 82 | def forward(self, x): 83 | # left 84 | out1 = self.bn1(self.conv1(x)) 85 | out1 = F.relu(self.bn2(self.conv2(out1))) 86 | # right 87 | out2 = F.relu(self.bn3(self.conv3(x))) 88 | out2 = self.bn4(self.conv4(out2)) 89 | out2 = F.relu(self.bn5(self.conv5(out2))) 90 | # concat 91 | out = torch.cat([out1, out2], 1) 92 | out = self.shuffle(out) 93 | return out 94 | 95 | 96 | class ShuffleNetV2(nn.Module): 97 | def __init__(self, net_size): 98 | super(ShuffleNetV2, self).__init__() 99 | out_channels = configs[net_size]['out_channels'] 100 | num_blocks = configs[net_size]['num_blocks'] 101 | 102 | self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 103 | stride=1, padding=1, bias=False) 104 | self.bn1 = nn.BatchNorm2d(24) 105 | self.in_channels = 24 106 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 107 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 108 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 109 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 110 | kernel_size=1, stride=1, padding=0, bias=False) 111 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 112 | self.linear = nn.Linear(out_channels[3], 10) 113 | 114 | def _make_layer(self, out_channels, num_blocks): 115 | layers = [DownBlock(self.in_channels, out_channels)] 116 | for i in range(num_blocks): 117 | layers.append(BasicBlock(out_channels)) 118 | self.in_channels = out_channels 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | out = F.relu(self.bn1(self.conv1(x))) 123 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 124 | out = self.layer1(out) 125 | out = self.layer2(out) 126 | out = self.layer3(out) 127 | out = F.relu(self.bn2(self.conv2(out))) 128 | out = F.avg_pool2d(out, 4) 129 | out = out.view(out.size(0), -1) 130 | out = self.linear(out) 131 | return out 132 | 133 | 134 | configs = { 135 | 0.5: { 136 | 'out_channels': (48, 96, 192, 1024), 137 | 'num_blocks': (3, 7, 3) 138 | }, 139 | 140 | 1: { 141 | 'out_channels': (116, 232, 464, 1024), 142 | 'num_blocks': (3, 7, 3) 143 | }, 144 | 1.5: { 145 | 'out_channels': (176, 352, 704, 1024), 146 | 'num_blocks': (3, 7, 3) 147 | }, 148 | 2: { 149 | 'out_channels': (224, 488, 976, 2048), 150 | 'num_blocks': (3, 7, 3) 151 | } 152 | } 153 | 154 | 155 | def test(): 156 | net = ShuffleNetV2(net_size=0.5) 157 | x = torch.randn(3, 3, 32, 32) 158 | y = net(x) 159 | print(y.shape) 160 | 161 | 162 | # test() 163 | -------------------------------------------------------------------------------- /tests/unit/test_matrix_block_slices.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from feathermap.utils import get_block_rows 4 | from feathermap.feathernet import LoadLayer 5 | 6 | 7 | class TestBlockRows(unittest.TestCase): 8 | """ 9 | Cover all possible matrix slice scenarios and test 10 | `feathermap.utils.get_block_rows` and `feathermap.feathernet.LoadLayer._get_operands` 11 | """ 12 | 13 | def setUp(self): 14 | """ Set up fixture 
covering index ranges and expected operand returns """ 15 | m = 1 16 | self.n = 3342 17 | self.V1 = torch.Tensor( 18 | [list(range(self.n * q, self.n * (q + 1))) for q in range(m)] 19 | ).reshape(self.n, m) 20 | self.V2 = torch.Tensor( 21 | [list(range(m * r, m * (r + 1))) for r in range(self.n)] 22 | ).reshape(m, self.n) 23 | self.V = torch.Tensor( 24 | [list(range(self.n * q, self.n * (q + 1))) for q in range(self.n)] 25 | ) 26 | 27 | # row1, col1, row2, col2 28 | self.idxs = [ 29 | # All one row complete 30 | (42, 0, 42, self.n - 1), 31 | # All two rows complete 32 | (652, 0, 653, self.n - 1), 33 | # All 10 rows complete 34 | (652, 0, 662, self.n - 1), 35 | # First row complete, second incomplete 36 | (0, 0, 1, self.n - 100), 37 | # First row complete, last row incomplete 38 | (652, 0, 829, 105), 39 | # First row incomplete (from right), no additional rows 40 | (1, 0, 1, 105), 41 | # First row incomplete (from left), no additional rows 42 | (123, 123, 123, self.n - 1), 43 | # First row incomplete (from left), second row incomplete 44 | (222, 1816, 223, 2018), 45 | # First row incomplete (from left), last row incomplete 46 | (652, 1816, 829, 105), 47 | # First row incomplete (from left), second row complete 48 | (3000, 1816, 3001, self.n - 1), 49 | # First row incomplete (from left), last row complete 50 | (3000, 1816, 3091, self.n - 1), 51 | ] 52 | 53 | # expected operands from covered index ranges above 54 | self.operands = [ 55 | {"top": (self.V1[42, :], self.V2[:, 0 : 3341 + 1])}, 56 | {"block": (self.V1[range(*[652, 654]), :], self.V2)}, 57 | {"block": (self.V1[range(*[652, 663]), :], self.V2)}, 58 | { 59 | "block": (self.V1[range(*[0, 1]), :], self.V2), 60 | "bottom": (self.V1[1, :], self.V2[:, : 3242 + 1]), 61 | }, 62 | { 63 | "block": (self.V1[range(*[652, 829]), :], self.V2), 64 | "bottom": (self.V1[829, :], self.V2[:, : 105 + 1]), 65 | }, 66 | {"top": (self.V1[1, :], self.V2[:, 0 : 105 + 1])}, 67 | {"top": (self.V1[123, :], self.V2[:, 123 : 3341 + 1])}, 68 | { 69 | "top": (self.V1[222, :], self.V2[:, 1816:]), 70 | "bottom": (self.V1[223, :], self.V2[:, : 2018 + 1]), 71 | }, 72 | { 73 | "block": (self.V1[range(*[653, 829]), :], self.V2), 74 | "top": (self.V1[652, :], self.V2[:, 1816:]), 75 | "bottom": (self.V1[829, :], self.V2[:, : 105 + 1]), 76 | }, 77 | { 78 | "block": (self.V1[range(*[3001, 3002]), :], self.V2), 79 | "top": (self.V1[3000, :], self.V2[:, 1816:]), 80 | }, 81 | { 82 | "block": (self.V1[range(*[3001, 3092]), :], self.V2), 83 | "top": (self.V1[3000, :], self.V2[:, 1816:]), 84 | }, 85 | ] 86 | 87 | def test_get_block_rows(self): 88 | """ Test block row range """ 89 | # Correct block row range 90 | res = [ 91 | [42, 43], 92 | [652, 654], 93 | [652, 663], 94 | [0, 1], 95 | [652, 829], 96 | [], 97 | [], 98 | [], 99 | [653, 829], 100 | [3001, 3002], 101 | [3001, 3092], 102 | ] 103 | 104 | for i, idx in enumerate(self.idxs): 105 | with self.subTest(): 106 | self.assertEqual(get_block_rows(*idx, self.n), res[i]) 107 | 108 | @unittest.skip("operand keys") 109 | def test_get_operands_keys(self): 110 | """ Test operand keys (as sequence) """ 111 | for i, idx in enumerate(self.idxs): 112 | with self.subTest(name=i): 113 | load_layer_operands = LoadLayer._get_operands( 114 | self.V1, self.V2, *idx, self.n 115 | ) 116 | self.assertSequenceEqual( 117 | load_layer_operands.keys(), self.operands[i].keys() 118 | ) 119 | 120 | @unittest.skip("operand value lengths") 121 | def test_get_operands_value_len(self): 122 | """ Test length of operand tensors """ 123 | for i, idx in 
enumerate(self.idxs): 124 | load_layer_operands = LoadLayer._get_operands( 125 | self.V1, self.V2, *idx, self.n 126 | ) 127 | for key in load_layer_operands: 128 | for tensor_idx in (0, 1): 129 | with self.subTest(name=(i, key, tensor_idx)): 130 | self.assertEqual( 131 | len(load_layer_operands[key][tensor_idx]), 132 | len(self.operands[i][key][tensor_idx]), 133 | ) 134 | 135 | def test_get_operands_values(self): 136 | """ Test values of operand tensors """ 137 | for i, idx in enumerate(self.idxs): 138 | load_layer_operands = LoadLayer._get_operands( 139 | self.V1, self.V2, *idx, self.n 140 | ) 141 | for key in load_layer_operands: 142 | for tensor_idx in (0, 1): 143 | with self.subTest(name=(i, key, tensor_idx)): 144 | self.assertTrue( 145 | torch.equal( 146 | load_layer_operands[key][tensor_idx], 147 | self.operands[i][key][tensor_idx], 148 | ) 149 | ) 150 | 151 | 152 | if __name__ == "__main__": 153 | unittest.main() 154 | -------------------------------------------------------------------------------- /feathermap/models/efficientnet.py: -------------------------------------------------------------------------------- 1 | '''EfficientNet in PyTorch. 2 | 3 | Paper: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks". 4 | 5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | def swish(x): 13 | return x * x.sigmoid() 14 | 15 | 16 | def drop_connect(x, drop_ratio): 17 | keep_ratio = 1.0 - drop_ratio 18 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) 19 | mask.bernoulli_(keep_ratio) 20 | x.div_(keep_ratio) 21 | x.mul_(mask) 22 | return x 23 | 24 | 25 | class SE(nn.Module): 26 | '''Squeeze-and-Excitation block with Swish.''' 27 | 28 | def __init__(self, in_channels, se_channels): 29 | super(SE, self).__init__() 30 | self.se1 = nn.Conv2d(in_channels, se_channels, 31 | kernel_size=1, bias=True) 32 | self.se2 = nn.Conv2d(se_channels, in_channels, 33 | kernel_size=1, bias=True) 34 | 35 | def forward(self, x): 36 | out = F.adaptive_avg_pool2d(x, (1, 1)) 37 | out = swish(self.se1(out)) 38 | out = self.se2(out).sigmoid() 39 | out = x * out 40 | return out 41 | 42 | 43 | class Block(nn.Module): 44 | '''expansion + depthwise + pointwise + squeeze-excitation''' 45 | 46 | def __init__(self, 47 | in_channels, 48 | out_channels, 49 | kernel_size, 50 | stride, 51 | expand_ratio=1, 52 | se_ratio=0., 53 | drop_rate=0.): 54 | super(Block, self).__init__() 55 | self.stride = stride 56 | self.drop_rate = drop_rate 57 | self.expand_ratio = expand_ratio 58 | 59 | # Expansion 60 | channels = expand_ratio * in_channels 61 | self.conv1 = nn.Conv2d(in_channels, 62 | channels, 63 | kernel_size=1, 64 | stride=1, 65 | padding=0, 66 | bias=False) 67 | self.bn1 = nn.BatchNorm2d(channels) 68 | 69 | # Depthwise conv 70 | self.conv2 = nn.Conv2d(channels, 71 | channels, 72 | kernel_size=kernel_size, 73 | stride=stride, 74 | padding=(1 if kernel_size == 3 else 2), 75 | groups=channels, 76 | bias=False) 77 | self.bn2 = nn.BatchNorm2d(channels) 78 | 79 | # SE layers 80 | se_channels = int(in_channels * se_ratio) 81 | self.se = SE(channels, se_channels) 82 | 83 | # Output 84 | self.conv3 = nn.Conv2d(channels, 85 | out_channels, 86 | kernel_size=1, 87 | stride=1, 88 | padding=0, 89 | bias=False) 90 | self.bn3 = nn.BatchNorm2d(out_channels) 91 | 92 | # Skip connection if in and out shapes are the same (MV-V2 style) 93 | 
self.has_skip = (stride == 1) and (in_channels == out_channels) 94 | 95 | def forward(self, x): 96 | out = x if self.expand_ratio == 1 else swish(self.bn1(self.conv1(x))) 97 | out = swish(self.bn2(self.conv2(out))) 98 | out = self.se(out) 99 | out = self.bn3(self.conv3(out)) 100 | if self.has_skip: 101 | if self.training and self.drop_rate > 0: 102 | out = drop_connect(out, self.drop_rate) 103 | out = out + x 104 | return out 105 | 106 | 107 | class EfficientNet(nn.Module): 108 | def __init__(self, cfg, num_classes=10): 109 | super(EfficientNet, self).__init__() 110 | self.cfg = cfg 111 | self.conv1 = nn.Conv2d(3, 112 | 32, 113 | kernel_size=3, 114 | stride=1, 115 | padding=1, 116 | bias=False) 117 | self.bn1 = nn.BatchNorm2d(32) 118 | self.layers = self._make_layers(in_channels=32) 119 | self.linear = nn.Linear(cfg['out_channels'][-1], num_classes) 120 | 121 | def _make_layers(self, in_channels): 122 | layers = [] 123 | cfg = [self.cfg[k] for k in ['expansion', 'out_channels', 'num_blocks', 'kernel_size', 124 | 'stride']] 125 | b = 0 126 | blocks = sum(self.cfg['num_blocks']) 127 | for expansion, out_channels, num_blocks, kernel_size, stride in zip(*cfg): 128 | strides = [stride] + [1] * (num_blocks - 1) 129 | for stride in strides: 130 | drop_rate = self.cfg['drop_connect_rate'] * b / blocks 131 | layers.append( 132 | Block(in_channels, 133 | out_channels, 134 | kernel_size, 135 | stride, 136 | expansion, 137 | se_ratio=0.25, 138 | drop_rate=drop_rate)) 139 | in_channels = out_channels 140 | return nn.Sequential(*layers) 141 | 142 | def forward(self, x): 143 | out = swish(self.bn1(self.conv1(x))) 144 | out = self.layers(out) 145 | out = F.adaptive_avg_pool2d(out, 1) 146 | out = out.view(out.size(0), -1) 147 | dropout_rate = self.cfg['dropout_rate'] 148 | if self.training and dropout_rate > 0: 149 | out = F.dropout(out, p=dropout_rate) 150 | out = self.linear(out) 151 | return out 152 | 153 | 154 | def EfficientNetB0(): 155 | cfg = { 156 | 'num_blocks': [1, 2, 2, 3, 3, 4, 1], 157 | 'expansion': [1, 6, 6, 6, 6, 6, 6], 158 | 'out_channels': [16, 24, 40, 80, 112, 192, 320], 159 | 'kernel_size': [3, 3, 5, 3, 5, 5, 3], 160 | 'stride': [1, 2, 2, 2, 1, 2, 1], 161 | 'dropout_rate': 0.2, 162 | 'drop_connect_rate': 0.2, 163 | } 164 | return EfficientNet(cfg) 165 | 166 | 167 | def test(): 168 | net = EfficientNetB0() 169 | x = torch.randn(2, 3, 32, 32) 170 | y = net(x) 171 | print(y.shape) 172 | 173 | 174 | if __name__ == '__main__': 175 | test() 176 | -------------------------------------------------------------------------------- /tests/integration/test_resnet_main.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------------------------------------------- # 2 | # An implementation of https://arxiv.org/pdf/1512.03385.pdf # 3 | # See section 4.2 for the model architecture on CIFAR-10 # 4 | # Some part of the code was referenced from # 5 | # https://github.com/yunjey/pytorch-tutorial and # 6 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py # 7 | # ---------------------------------------------------------------------------- # 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torchvision 12 | import torchvision.transforms as transforms 13 | from feathermap.feathernet import FeatherNet 14 | 15 | 16 | # Device configuration 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | # Hyper-parameters 20 | num_epochs = 1 21 | batch_size = 100 22 | learning_rate = 
0.001 23 | 24 | # Image preprocessing modules 25 | transform = transforms.Compose([ 26 | transforms.Pad(4), 27 | transforms.RandomHorizontalFlip(), 28 | transforms.RandomCrop(32), 29 | transforms.ToTensor()]) 30 | 31 | # CIFAR-10 dataset 32 | train_dataset = torchvision.datasets.CIFAR10(root='../../data/', 33 | train=True, 34 | transform=transform, 35 | download=True) 36 | 37 | test_dataset = torchvision.datasets.CIFAR10(root='../../data/', 38 | train=False, 39 | transform=transforms.ToTensor()) 40 | 41 | # Data loader 42 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 43 | batch_size=batch_size, 44 | shuffle=True) 45 | 46 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 47 | batch_size=batch_size, 48 | shuffle=False) 49 | 50 | # 3x3 convolution 51 | def conv3x3(in_channels, out_channels, stride=1): 52 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, 53 | stride=stride, padding=1, bias=False) 54 | 55 | # Residual block 56 | class ResidualBlock(nn.Module): 57 | def __init__(self, in_channels, out_channels, stride=1, downsample=None): 58 | super(ResidualBlock, self).__init__() 59 | self.conv1 = conv3x3(in_channels, out_channels, stride) 60 | self.bn1 = nn.BatchNorm2d(out_channels) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.conv2 = conv3x3(out_channels, out_channels) 63 | self.bn2 = nn.BatchNorm2d(out_channels) 64 | self.downsample = downsample 65 | 66 | def forward(self, x): 67 | residual = x 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | if self.downsample: 74 | residual = self.downsample(x) 75 | out += residual 76 | out = self.relu(out) 77 | return out 78 | 79 | # ResNet 80 | class ResNet(nn.Module): 81 | def __init__(self, block, layers, num_classes=10): 82 | super(ResNet, self).__init__() 83 | self.in_channels = 16 84 | self.conv = conv3x3(3, 16) 85 | self.bn = nn.BatchNorm2d(16) 86 | self.relu = nn.ReLU(inplace=True) 87 | self.layer1 = self.make_layer(block, 16, layers[0]) 88 | self.layer2 = self.make_layer(block, 32, layers[1], 2) 89 | self.layer3 = self.make_layer(block, 64, layers[2], 2) 90 | self.avg_pool = nn.AvgPool2d(8) 91 | self.fc = nn.Linear(64, num_classes) 92 | 93 | def make_layer(self, block, out_channels, blocks, stride=1): 94 | downsample = None 95 | if (stride != 1) or (self.in_channels != out_channels): 96 | downsample = nn.Sequential( 97 | conv3x3(self.in_channels, out_channels, stride=stride), 98 | nn.BatchNorm2d(out_channels)) 99 | layers = [] 100 | layers.append(block(self.in_channels, out_channels, stride, downsample)) 101 | self.in_channels = out_channels 102 | for i in range(1, blocks): 103 | layers.append(block(out_channels, out_channels)) 104 | return nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | out = self.conv(x) 108 | out = self.bn(out) 109 | out = self.relu(out) 110 | out = self.layer1(out) 111 | out = self.layer2(out) 112 | out = self.layer3(out) 113 | out = self.avg_pool(out) 114 | out = out.view(out.size(0), -1) 115 | out = self.fc(out) 116 | return out 117 | 118 | base_model = ResNet(ResidualBlock, [2, 2, 2]) 119 | model = FeatherNet(base_model, exclude=(nn.BatchNorm2d), compress=1).to(device) 120 | 121 | #model = base_model 122 | 123 | # Loss and optimizer 124 | criterion = nn.CrossEntropyLoss() 125 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 126 | 127 | # For updating learning rate 128 | def update_lr(optimizer, lr): 129 | for param_group in optimizer.param_groups: 130 | 
130 |         param_group['lr'] = lr
131 |
132 | # Train the model
133 | total_step = len(train_loader)
134 | curr_lr = learning_rate
135 | for epoch in range(num_epochs):
136 |     for i, (images, labels) in enumerate(train_loader):
137 |         images = images.to(device)
138 |         labels = labels.to(device)
139 |
140 |         # Forward pass
141 |         outputs = model(images)
142 |         loss = criterion(outputs, labels)
143 |
144 |         # Backward and optimize
145 |         optimizer.zero_grad()
146 |         loss.backward()
147 |         optimizer.step()
148 |
149 |         if (i+1) % 100 == 0:
150 |             print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
151 |                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
152 |
153 |     # Decay learning rate
154 |     if (epoch+1) % 20 == 0:
155 |         curr_lr /= 3
156 |         update_lr(optimizer, curr_lr)
157 |
158 | # Test the model
159 | model.eval()
160 | with torch.no_grad():
161 |     correct = 0
162 |     total = 0
163 |     for images, labels in test_loader:
164 |         images = images.to(device)
165 |         labels = labels.to(device)
166 |         outputs = model(images)
167 |         _, predicted = torch.max(outputs.data, 1)
168 |         total += labels.size(0)
169 |         correct += (predicted == labels).sum().item()
170 |
171 |     print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))
172 |
173 | # Save the model checkpoint
174 | torch.save(model.state_dict(), 'resnet.ckpt')
175 |
--------------------------------------------------------------------------------
/tests/exploratory/hashable_params.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation testing for hashable parameters function.
3 | """
4 | # Many imports derived from PyTorch source
5 | import torch
6 | from torch.nn.modules import Module
7 | import torch.nn as nn
8 | from torch.nn import Parameter
9 | from src.resnet import ResNet, ResidualBlock  # NOTE: stale path; 'src' does not exist in the current tree
10 | from torch import Tensor, device, dtype
11 | from collections import OrderedDict, namedtuple
12 | from typing import (
13 |     Union,
14 |     Tuple,
15 |     Any,
16 |     Callable,
17 |     Iterator,
18 |     Set,
19 |     Optional,
20 |     overload,
21 |     TypeVar,
22 |     Mapping,
23 |     Dict,
24 | )
25 |
26 |
27 | # Device configuration
28 | my_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29 |
30 | model = ResNet(ResidualBlock, [2, 2, 2]).to(my_device)
31 |
32 | # print(type(model.parameters()))
33 | # print(model.parameters().__next__())
34 | # print(next(model.named_parameters()))
35 |
36 | test_params = list(model.parameters())
37 | test_named_params = dict(model.named_parameters())
38 |
39 | # print(*test_named_params.keys(), sep="\n")
40 | # print(*model.named_modules(), sep='\n')
41 | l = nn.Linear(2, 2)
42 | net1 = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
43 | net = nn.Sequential(l, l, net1)
44 | # print(*dict(net.parameters()), sep='\n')
45 |
46 | for name, module in model.named_modules():
47 |     # print(name, type(module))
48 |     continue
49 |
50 | # print(l._parameters.items())
51 |
52 |
53 | def _named_members_subset(
54 |     self,
55 |     get_members_fn,
56 |     prefix="",
57 |     exclude: tuple = (),
58 |     recurse=True,
59 | ):
60 |     r"""Helper method for yielding various names + members of modules."""
61 |     memo = set()
62 |     modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
63 |     for module_prefix, module in modules:
64 |         if isinstance(module, exclude):
65 |             continue
66 |         # members from _parameters.items() are odict_items (possibly empty)
67 |         members = get_members_fn(module)
68 |         for k, v in members:
69 |             if v is None or v in memo:
70 |                 continue
71 |             memo.add(v)
72 |             name = module_prefix + ("." if module_prefix else "") + k
73 |             yield name, v
74 |
75 |
76 | def named_parameters_subset(
77 |     self,
78 |     prefix: str = "",
79 |     exclude: tuple = (nn.BatchNorm2d,),
80 |     recurse: bool = True,
81 | ) -> Iterator[Tuple[str, Tensor]]:
82 |     r"""Returns an iterator over module parameters, yielding both the
83 |     name of the parameter as well as the parameter itself.
84 |     Args:
85 |         prefix (str): prefix to prepend to all parameter names.
86 |         recurse (bool): if True, then yields parameters of this module
87 |             and all submodules. Otherwise, yields only parameters that
88 |             are direct members of this module.
89 |     Yields:
90 |         (string, Parameter): Tuple containing the name and parameter
91 |     Example::
92 |         >>> for name, param in self.named_parameters():
93 |         >>>     if name in ['bias']:
94 |         >>>         print(param.size())
95 |     """
96 |     gen = _named_members_subset(
97 |         self,
98 |         lambda module: module._parameters.items(),
99 |         prefix=prefix,
100 |         exclude=exclude,
101 |         recurse=recurse,
102 |     )
103 |     for elem in gen:
104 |         yield elem
105 |
106 |
107 | def parameters_subset(
108 |     self, exclude: tuple = (nn.BatchNorm2d,), recurse: bool = True
109 | ) -> Iterator[Parameter]:
110 |     r"""Returns an iterator over module parameters.
111 |     This is typically passed to an optimizer.
112 |     Args:
113 |         recurse (bool): if True, then yields parameters of this module
114 |             and all submodules. Otherwise, yields only parameters that
115 |             are direct members of this module.
116 |     Yields:
117 |         Parameter: module parameter
118 |     Example::
119 |         >>> for param in model.parameters():
120 |         >>>     print(type(param), param.size())
121 |         (20L,)
122 |         (20L, 1L, 5L, 5L)
123 |     """
124 |     for name, param in named_parameters_subset(
125 |         self, exclude=exclude, recurse=recurse
126 |     ):
127 |         yield param
128 |
129 |
130 | [
131 |     print(name)
132 |     for name, v in named_parameters_subset(model, exclude=(nn.BatchNorm2d,))
133 | ]
134 | print("-" * 20)
135 | print(*dict(model.named_parameters()).keys(), sep="\n")
136 |
137 |
138 | # ---------------------------------------------------------------------------- #
139 | #                                PyTorch methods                                #
140 | # ---------------------------------------------------------------------------- #
141 |
142 | # parameters(named_parameters(_named_members(named_modules(named_modules))))
143 | # net._parameters only exists for explicitly defined layers
144 | # (i.e. empty for nn.Sequential or nn.ResNet)
145 | # for a given module, attribute lookup falls back to nn.Module.__getattr__
146 |
147 |
148 | def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
149 |     for name, param in self.named_parameters(recurse=recurse):
150 |         yield param
151 |
152 |
153 | def named_parameters(
154 |     self, prefix: str = "", recurse: bool = True
155 | ) -> Iterator[Tuple[str, Tensor]]:
156 |     r"""Returns an iterator over module parameters, yielding both the
157 |     name of the parameter as well as the parameter itself.
158 |     Args:
159 |         prefix (str): prefix to prepend to all parameter names.
160 | """ 161 | gen = self._named_members( 162 | lambda module: module._parameters.items(), 163 | prefix=prefix, 164 | recurse=recurse, 165 | ) 166 | for elem in gen: 167 | yield elem 168 | 169 | 170 | def _named_members(self, get_members_fn, prefix="", recurse=True): 171 | r"""Helper method for yielding various names + members of modules.""" 172 | memo = set() 173 | # iterator 174 | modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)] 175 | for module_prefix, module in modules: 176 | # members from _parameters.items() are odict_items (possibly empty) 177 | members = get_members_fn(module) 178 | for k, v in members: 179 | if v is None or v in memo: 180 | continue 181 | memo.add(v) 182 | name = module_prefix + ("." if module_prefix else "") + k 183 | yield name, v 184 | 185 | 186 | def named_modules(self, memo: Optional[Set["Module"]] = None, prefix: str = ""): 187 | r"""Returns an iterator over all modules in the network, yielding 188 | both the name of the module as well as the module itself. 189 | Yields: 190 | (string, Module): Tuple of name and module 191 | Note: 192 | Duplicate modules are returned only once. In the following 193 | example, ``l`` will be returned only once. 194 | Example:: 195 | >>> l = nn.Linear(2, 2) 196 | >>> net = nn.Sequential(l, l) 197 | >>> for idx, m in enumerate(net.named_modules()): 198 | print(idx, '->', m) 199 | 0 -> ('', Sequential( 200 | (0): Linear(in_features=2, out_features=2, bias=True) 201 | (1): Linear(in_features=2, out_features=2, bias=True) 202 | )) 203 | 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) 204 | """ 205 | # non-duplication should not be issue; weights are same in ex. 206 | 207 | if memo is None: 208 | memo = set() 209 | if self not in memo: 210 | memo.add(self) 211 | yield prefix, self 212 | for name, module in self._modules.items(): 213 | if module is None: 214 | continue 215 | submodule_prefix = prefix + ("." if prefix else "") + name 216 | for m in module.named_modules(memo, submodule_prefix): 217 | yield m 218 | -------------------------------------------------------------------------------- /feathermap/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train compressed feathermap/models on CIFAR10. 
3 | - Progress bar inspired by https://github.com/kuangliu/pytorch-cifar
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.backends.cudnn as cudnn
9 | from torch.optim.lr_scheduler import MultiStepLR
10 | from packaging import version
11 | import os
12 | import argparse
13 | from feathermap.utils import progress_bar
14 | from feathermap.models.resnet import ResNet34
15 | from feathermap.feathernet import FeatherNet
16 | from feathermap.dataloader import get_train_valid_loader, get_test_loader
17 |
18 |
19 | def main():
20 |     """Perform training, validation, and testing, with checkpoint loading and saving"""
21 |
22 |     # Build model
23 |     print("==> Building model..")
24 |     base_model = ResNet34()
25 |     if args.compress:
26 |         model = FeatherNet(
27 |             base_model,
28 |             compress=args.compress,
29 |         )
30 |     else:
31 |         if args.lr != 0.1:
32 |             print("Warning: suggest setting the base-model learning rate to 0.1")
33 |         model = base_model
34 |
35 |     # Enable GPU support
36 |     print("==> Setting up device..")
37 |     if torch.cuda.is_available():
38 |         print("Utilizing", torch.cuda.device_count(), "GPU(s)!")
39 |         if torch.cuda.device_count() > 1:
40 |             model = nn.DataParallel(model)
41 |         DEV = torch.device("cuda:0")
42 |         cuda_kwargs = {"num_workers": args.num_workers, "pin_memory": True}
43 |         cudnn.benchmark = True
44 |     else:
45 |         print("Utilizing CPU!")
46 |         DEV = torch.device("cpu")
47 |         cuda_kwargs = {}
48 |     model.to(DEV)
49 |
50 |     # Create dataloaders
51 |     print("==> Preparing data..")
52 |     train_loader, valid_loader = get_train_valid_loader(
53 |         data_dir=args.data_dir,
54 |         batch_size=args.batch_size,
55 |         valid_size=args.valid_size,
56 |         **cuda_kwargs
57 |     )
58 |     test_loader = get_test_loader(data_dir=args.data_dir, **cuda_kwargs)
59 |
60 |     best_acc = 0  # best validation accuracy
61 |     start_epoch = 0  # start from epoch 0 or last checkpoint epoch
62 |     save_display = False
63 |
64 |     # Load checkpoint
65 |     if args.resume:
66 |         print("==> Resuming from checkpoint..")
67 |         assert os.path.isdir("checkpoint"), "Error: no checkpoint directory found!"
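        # The checkpoint dict is written by validate() below with keys "model", "acc", and "epoch"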
68 |         checkpoint = torch.load("./checkpoint/" + args.ckpt_name)
69 |         model.load_state_dict(checkpoint["model"])
70 |         best_acc = checkpoint["acc"]
71 |         start_epoch = checkpoint["epoch"]
72 |
73 |     # Initialize optimizers and loss fn
74 |     criterion = nn.CrossEntropyLoss()
75 |     optimizer = optim.SGD(
76 |         model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4
77 |     )
78 |     scheduler = MultiStepLR(optimizer, milestones=[100, 200], gamma=0.1)
79 |
80 |     def train(epoch: int) -> None:
81 |         """Train on CIFAR10 per epoch"""
82 |         # maintain backward compatibility; get_last_lr requires PyTorch >= 1.4
83 |         last_lr = (
84 |             scheduler.get_last_lr()[0]
85 |             if version.parse(torch.__version__) >= version.parse("1.4")
86 |             else scheduler.get_lr()[0]
87 |         )
88 |         print(
89 |             "\nEpoch: {} | Compression: {:.2f} | lr: {:<6}".format(
90 |                 epoch, args.compress, last_lr
91 |             )
92 |         )
93 |         model.train()
94 |         train_loss = 0
95 |         correct = 0
96 |         total = 0
97 |         for batch_idx, (inputs, targets) in enumerate(train_loader):
98 |             inputs, targets = inputs.to(DEV), targets.to(DEV)
99 |             optimizer.zero_grad()
100 |             outputs = model(inputs)
101 |             loss = criterion(outputs, targets)
102 |             loss.backward()
103 |             optimizer.step()
104 |
105 |             train_loss += loss.item()
106 |             _, predicted = outputs.max(1)
107 |             total += targets.size(0)
108 |             correct += predicted.eq(targets).sum().item()
109 |
110 |             progress_bar(
111 |                 batch_idx,
112 |                 len(train_loader),
113 |                 "Loss: {:.3f} | Acc: {:.3f}% ({}/{})".format(
114 |                     train_loss / (batch_idx + 1),
115 |                     100.0 * correct / total,
116 |                     correct,
117 |                     total,
118 |                 ),
119 |             )
120 |
121 |     # Validation
122 |     def validate(epoch: int) -> None:
123 |         """Validate on CIFAR10 per epoch. Save best accuracy for checkpoint storing"""
124 |         nonlocal best_acc
125 |         nonlocal save_display
126 |         model.eval()
127 |         valid_loss = 0
128 |         correct = 0
129 |         total = 0
130 |         with torch.no_grad():
131 |             for batch_idx, (inputs, targets) in enumerate(valid_loader):
132 |                 inputs, targets = inputs.to(DEV), targets.to(DEV)
133 |                 outputs = model(inputs)
134 |                 loss = criterion(outputs, targets)
135 |
136 |                 valid_loss += loss.item()
137 |                 _, predicted = outputs.max(1)
138 |                 total += targets.size(0)
139 |                 correct += predicted.eq(targets).sum().item()
140 |
141 |                 progress_bar(
142 |                     batch_idx,
143 |                     len(valid_loader),
144 |                     "Loss: {:.3f} | Acc: {:.3f}% ({}/{})".format(
145 |                         valid_loss / (batch_idx + 1),
146 |                         100.0 * correct / total,
147 |                         correct,
148 |                         total,
149 |                     ),
150 |                 )
151 |
152 |         # Save checkpoint.
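        # save_display tells the outer epoch loop to print "Saving.." after this epoch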
153 |         acc = 100.0 * correct / total
154 |         save_display = acc > best_acc
155 |         if acc > best_acc:
156 |             state = {
157 |                 "model": model.state_dict(),
158 |                 "acc": acc,
159 |                 "epoch": epoch,
160 |             }
161 |             if not os.path.isdir("checkpoint"):
162 |                 os.mkdir("checkpoint")
163 |             torch.save(state, "./checkpoint/" + args.ckpt_name)
164 |             best_acc = acc
165 |
166 |     # Testing
167 |     def test(epoch: int) -> None:
168 |         """Test on CIFAR10 per epoch."""
169 |         model.eval()
170 |         test_loss = 0
171 |         correct = 0
172 |         total = 0
173 |         with torch.no_grad():
174 |             for batch_idx, (inputs, targets) in enumerate(test_loader):
175 |                 inputs, targets = inputs.to(DEV), targets.to(DEV)
176 |                 outputs = model(inputs)
177 |                 loss = criterion(outputs, targets)
178 |
179 |                 test_loss += loss.item()
180 |                 _, predicted = outputs.max(1)
181 |                 total += targets.size(0)
182 |                 correct += predicted.eq(targets).sum().item()
183 |
184 |                 progress_bar(
185 |                     batch_idx,
186 |                     len(test_loader),
187 |                     "Loss: {:.3f} | Acc: {:.3f}% ({}/{})".format(
188 |                         test_loss / (batch_idx + 1),
189 |                         100.0 * correct / total,
190 |                         correct,
191 |                         total,
192 |                     ),
193 |                 )
194 |
195 |     # Train up to 300 epochs
196 |     # *Displays* concurrent performance on the validation and test sets while training,
197 |     # but strictly uses the validation set to determine early stopping
198 |     print("==> Initiate Training..")
199 |     for epoch in range(start_epoch, 300):
200 |         train(epoch)
201 |         validate(epoch)
202 |         test(epoch)
203 |         if save_display:
204 |             print("Saving..")
205 |         scheduler.step()
206 |
207 |
208 | if __name__ == "__main__":
209 |     try:
210 |         parser = argparse.ArgumentParser(
211 |             description="PyTorch CIFAR10 training with Structured Multi-Hashing compression",
212 |             formatter_class=argparse.ArgumentDefaultsHelpFormatter,
213 |         )
214 |         parser.add_argument(
215 |             "--compress",
216 |             type=float,
217 |             default=0.5,
218 |             help="Compression rate. Set to zero for base model",
219 |             metavar="",
220 |         )
221 |         parser.add_argument(
222 |             "--resume", "-r", action="store_true", help="Resume from checkpoint"
223 |         )
224 |         parser.add_argument(
225 |             "--ckpt-name",
226 |             type=str,
227 |             default="ckpt.pth",
228 |             help="Name of checkpoint",
229 |             metavar="",
230 |         )
231 |         parser.add_argument(
232 |             "--lr",
233 |             default=0.01,
234 |             type=float,
235 |             help="Learning rate. Set to 0.1 for base model (uncompressed) training.",
236 |             metavar="",
237 |         )
238 |         parser.add_argument(
239 |             "--batch-size", type=int, default=128, help="Mini-batch size", metavar=""
240 |         )
241 |         parser.add_argument(
242 |             "--valid-size",
243 |             type=float,
244 |             default=0.1,
245 |             help="Validation set size as a fraction of the training set",
246 |             metavar="",
247 |         )
248 |         parser.add_argument(
249 |             "--num-workers",
250 |             type=int,
251 |             default=2,
252 |             help="Number of dataloader worker processes. Try adjusting for faster training",
253 |             metavar="",
254 |         )
255 |         parser.add_argument(
256 |             "--data-dir",
257 |             type=str,
258 |             default="./data/",
259 |             help="Path to store CIFAR10 data",
260 |             metavar="",
261 |         )
262 |         args = parser.parse_args()
263 |         main()
264 |     except KeyboardInterrupt:
265 |         exit()
266 |
--------------------------------------------------------------------------------
/feathermap/feathernet.py:
--------------------------------------------------------------------------------
1 | """Provides a compression wrapper for user-defined PyTorch models.
2 | - Computes weights based on structured multi-hashing
3 | - Activates forward pre and post hooks for weight layer caching
4 | - Overloads appropriate nn.Module methods
5 | """
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import Parameter
9 | from torch import Tensor
10 | from typing import Iterator, Tuple, List, Dict, Callable
11 | from math import ceil, sqrt
12 | from feathermap.utils import get_block_rows
13 | import copy
14 | from torch.utils.hooks import RemovableHandle
15 |
16 |
17 | class LoadLayer:
18 |     """Forward prehook callable for inner layers. Loads weights and biases from V1 and V2,
19 |     calculating them on the fly. Must be as optimized as possible for low latency."""
20 |
21 |     def __init__(
22 |         self,
23 |         name: str,
24 |         module: nn.Module,
25 |         V1: Tensor,
26 |         V2: Tensor,
27 |         size_n: int,
28 |         offset: int,
29 |         verbose: bool = False,
30 |     ):
31 |         self._module = module
32 |         self._name = name
33 |         self._verbose = verbose
34 |         self._V1 = V1
35 |         self._V2 = V2
36 |         self._offset = offset
37 |         self._size_n = size_n
38 |         self._w_size = module.weight.size()
39 |         self._w_num = module.weight.numel()
40 |         self._w_p = module.weight_p  # weight scaler
41 |         self._w_ops = list(
42 |             self._get_operands(
43 |                 self._V1, self._V2, *self.__get_index_range(), self._size_n
44 |             ).values()
45 |         )
46 |         self._bias = module.bias is not None
47 |         self._b_ops = None
48 |         if self._bias:
49 |             self._b_size = module.bias.size()
50 |             self._b_num = module.bias.numel()
51 |             self._b_p = module.bias_p  # bias scaler
52 |             self._b_ops = list(
53 |                 self._get_operands(
54 |                     self._V1, self._V2, *self.__get_index_range(bias=True), self._size_n
55 |                 ).values()
56 |             )
57 |
58 |     def __get_index_range(self, bias: bool = False) -> Tuple[int, int, int, int]:
59 |         """Return global weight or bias index range associated with given layer"""
60 |         i1, j1 = divmod(self._offset + 1, self._size_n)
61 |         if bias:
62 |             i2, j2 = divmod(self._offset + self._b_num, self._size_n)
63 |         else:
64 |             i2, j2 = divmod(self._offset + self._w_num, self._size_n)
65 |         return (i1, j1, i2, j2)
66 |
67 |     @staticmethod
68 |     def _get_operands(
69 |         V1: Tensor, V2: Tensor, i1: int, j1: int, i2: int, j2: int, n: int
70 |     ) -> dict:
71 |         """Return dictionary of operands representing complete slices of V1.dot(V2) from
72 |         range [i1, j1] to [i2, j2]. Matrix product maps to the underlying matrix of size
73 |         (n x n)"""
74 |         ops = {}
75 |         block_rows = get_block_rows(i1, j1, i2, j2, n)
76 |         # Only one row, return whether complete or incomplete
77 |         if i2 - i1 == 0:
78 |             ops["top"] = (V1[i1, :], V2[:, j1 : j2 + 1])
79 |             return ops
80 |         # Has block rows
81 |         if block_rows:
82 |             ops["block"] = (V1[range(*block_rows), :], V2)
83 |             # First row incomplete (from left)
84 |             if i1 < block_rows[0]:
85 |                 ops["top"] = (V1[i1, :], V2[:, j1:])
86 |             # Last row incomplete
87 |             if i2 > (block_rows[1] - 1):
88 |                 ops["bottom"] = (V1[i2, :], V2[:, : j2 + 1])
89 |             return ops
90 |         # Two rows, no blocks
91 |         else:
92 |             ops["top"] = (V1[i1, :], V2[:, j1:])
93 |             ops["bottom"] = (V1[i2, :], V2[:, : j2 + 1])
94 |             return ops
95 |
96 |     def __mm_map(self, matrices: List[Tensor]) -> Tensor:
97 |         """Helper for matrix multiplication including the scale parameter. The bias scaler
98 |         equals the weight scaler (both derive from the weight's fan_in), so _w_p is used
99 |         for both weights and biases."""
100 |         return self._w_p * torch.matmul(*matrices).view(-1, 1)
101 |
102 |     def __call__(self, module: nn.Module, inputs: Tensor):
103 |         if self._verbose:
104 |             print("prehook activated: {} {}".format(self._name, self._module))
105 |
106 |         # Load weights
107 |         if len(self._w_ops) == 1:
108 |             module.weight = self.__mm_map(*self._w_ops).reshape(self._w_size)
109 |         else:
110 |             m_products = tuple(map(self.__mm_map, self._w_ops))
111 |             module.weight = torch.cat(m_products).reshape(self._w_size)
112 |
113 |         # Load biases
114 |         if self._bias:
115 |             if len(self._b_ops) == 1:
116 |                 module.bias = self.__mm_map(*self._b_ops).reshape(self._b_size)
117 |             else:
118 |                 m_products = tuple(map(self.__mm_map, self._b_ops))
119 |                 module.bias = torch.cat(m_products).reshape(self._b_size)
120 |
121 |
122 | class UnloadLayer:
123 |     """Forward posthook callable class with verbose switch"""
124 |
125 |     verbose = False
126 |
127 |     @classmethod
128 |     def __call__(cls, module: nn.Module, inputs: Tensor, outputs: Tensor):
129 |         if UnloadLayer.verbose:
130 |             print("posthook activated: {}".format(module))
131 |         # Unload weights and biases
132 |         module.weight = None
133 |         module.bias = None
134 |
135 |
136 | class FeatherNet(nn.Module):
137 |     """
138 |     Compresses user-defined PyTorch models based on structured multi-hashing.
139 |
140 |     Calculates the matrix product V1 * V2 = V and maps each element of V to the global
141 |     weight matrix. The sizes of V1 and V2 are determined by the compression rate. See
142 |     README.md for an overview of structured multi-hashing.
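
    Example::
        >>> base_model = ResNet34()
        >>> model = FeatherNet(base_model, compress=0.5)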
141 | """ 142 | 143 | 144 | def __init__( 145 | self, 146 | module: nn.Module, 147 | compress: float = 0.5, 148 | exclude: tuple = (nn.BatchNorm2d), 149 | clone: bool = True, 150 | verbose: bool = False, 151 | ) -> None: 152 | super().__init__() 153 | self.module = copy.deepcopy(module) if clone else module 154 | self._verbose = verbose 155 | self._exclude = exclude 156 | self._prehooks = None 157 | self._posthooks = None 158 | self._prehook_callables = None 159 | self._posthook_callable = None 160 | 161 | # Find max compression range 162 | self._max_compress = self.get_max_compression() 163 | self.compress = compress 164 | 165 | # Unregister module Parameters, create scaler attributes, set weights 166 | # as tensors of prior data 167 | self.__unregister_params() 168 | 169 | self._size_n = ceil(sqrt(self.get_num_WandB())) 170 | self._size_m = ceil((self.compress * self._size_n) / 2) 171 | self._V1 = Parameter(torch.Tensor(self._size_n, self._size_m)) 172 | self._V2 = Parameter(torch.Tensor(self._size_m, self._size_n)) 173 | self._V = None 174 | 175 | # Normalize V1 and V2 176 | self.__norm_V() 177 | 178 | def __register_hooks(self) -> None: 179 | """Register forward pre and post hooks""" 180 | prehooks, posthooks, prehook_callables = [], [], [] 181 | offset = -1 182 | for name, module in self._get_WorB_modules(): 183 | # Create callable prehook object; see LoadLayer; update running V.view(-1, 1) index 184 | prehook_callable = LoadLayer( 185 | name, module, self._V1, self._V2, self._size_n, offset, self._verbose 186 | ) 187 | offset += prehook_callable._w_num 188 | if getattr(module, "bias", None) is not None: 189 | offset += prehook_callable._b_num 190 | 191 | # Create callable posthook object 192 | posthook_callable = UnloadLayer() 193 | UnloadLayer.verbose = self._verbose 194 | 195 | # Register hooks 196 | prehook_handle = module.register_forward_pre_hook(prehook_callable) 197 | posthook_handle = module.register_forward_hook(posthook_callable) 198 | 199 | # Collect removable handles 200 | prehooks.append(prehook_handle) 201 | posthooks.append(posthook_handle) 202 | prehook_callables.append(prehook_callable) 203 | 204 | # Pass handles into attributes 205 | self._prehooks = prehooks 206 | self._posthooks = posthooks 207 | self._prehook_callables = prehook_callables 208 | self._posthook_callable = posthook_callable 209 | 210 | def __unregister_hooks(self, hooks: RemovableHandle) -> None: 211 | """Unregister forward pre and post hooks""" 212 | # Remove hooks 213 | if hooks is not None: 214 | for hook in hooks: 215 | hook.remove() 216 | 217 | def __unregister_params(self) -> None: 218 | """Delete params, set attributes as Tensors of prior data, 219 | register new params to scale weights and biases""" 220 | # fan_in will fail on BatchNorm2d.weight 221 | for name, module, kind in self._get_WandB_modules(): 222 | try: 223 | data = module._parameters[kind].data 224 | if kind == "weight": 225 | fan_in = torch.nn.init._calculate_correct_fan(data, "fan_in") 226 | # get bias fan_in from corresponding weight 227 | else: 228 | fan_in = torch.nn.init._calculate_correct_fan( 229 | getattr(module, "weight"), "fan_in" 230 | ) 231 | # Delete from parameter list to avoid loading into state dict 232 | del module._parameters[kind] 233 | scaler = 1 / sqrt(3 * fan_in) 234 | setattr(module, kind, data) 235 | # Add scale parameter to each weight or bias 236 | module.register_parameter( 237 | kind + "_p", Parameter(torch.Tensor([scaler])) 238 | ) 239 | if self._verbose: 240 | print( 241 | "Parameter unregistered, 
242 |                         "Parameter unregistered, assigned to type Tensor: {}".format(
243 |                             name + "." + kind
244 |                         )
245 |                     )
246 |             except KeyError:
247 |                 print(
248 |                     "{} is already registered as {}".format(
249 |                         name + "." + kind, type(getattr(module, kind))
250 |                     )
251 |                 )
252 |             except ValueError:
253 |                 print(
254 |                     "Check module exclusion list. Note: cannot calculate fan_in "
255 |                     "for BatchNorm2d layers."
256 |                 )
257 |
258 |     def __map_V_to_WandB(self) -> None:
259 |         """Calculate V = V1*V2 and allocate to all weights and biases"""
260 |         self._V = torch.matmul(self._V1, self._V2)
261 |         V = self._V.view(-1, 1)  # V.is_contiguous() = True
262 |         i = 0
263 |         for name, module, kind in self._get_WandB_modules():
264 |             v = getattr(module, kind)
265 |             j = v.numel()  # elements in weight or bias
266 |             v_new = V[i : i + j].reshape(v.size())  # confirmed contiguous
267 |
268 |             # Scaler Parameter, e.g. nn.Linear.weight_p
269 |             scaler = getattr(module, kind + "_p")
270 |             # Update weights and biases, point to elems in V
271 |             setattr(module, kind, scaler * v_new)
272 |             i += j
273 |
274 |     def __clear_WandB(self) -> None:
275 |         """Set weights and biases as empty tensors"""
276 |         for name, module, kind in self._get_WandB_modules():
277 |             tensor_size = getattr(module, kind).size()
278 |             setattr(module, kind, torch.empty(tensor_size))
279 |
280 |     def __norm_V(self) -> None:
281 |         """Normalize global weight matrix. Currently implemented only for uniform
282 |         initializations"""
283 |         # sigma = M**(-1/4); bound follows from the uniform distribution
284 |         bound = sqrt(12) / 2 * (self._size_m ** (-1 / 4))
285 |         torch.nn.init.uniform_(self._V1, -bound, bound)
286 |         torch.nn.init.uniform_(self._V2, -bound, bound)
287 |
288 |     def _get_WandB(self) -> Iterator[Tuple[str, Tensor]]:
289 |         """Helper function to return weight AND bias attributes in order"""
290 |         for name, module, kind in self._get_WandB_modules():
291 |             yield name + "." + kind, getattr(module, kind)
292 |
293 |     def _get_WandB_modules(self) -> Iterator[Tuple[str, nn.Module, str]]:
294 |         """Helper function to return weight AND bias modules in order.
295 |         Adheres to the `self._exclude` list"""
296 |         for name, module in self.named_modules():
297 |             if isinstance(module, self._exclude):
298 |                 continue
299 |             if getattr(module, "weight", None) is not None:
300 |                 yield name, module, "weight"
301 |             if getattr(module, "bias", None) is not None:
302 |                 yield name, module, "bias"
303 |
304 |     def _get_WorB_modules(self) -> Iterator[Tuple[str, nn.Module]]:
305 |         """Helper function to return weight OR bias modules in order.
306 |         Adheres to the `self._exclude` list"""
307 |         for name, module in self.named_modules():
308 |             if isinstance(module, (self._exclude, FeatherNet)):
309 |                 continue
310 |             if getattr(module, "weight", None) is not None:
311 |                 yield name, module
312 |
313 |     def get_max_compression(self) -> float:
314 |         """Calculate maximum compression rate based on largest layer size"""
315 |         max_layer_size, _ = self.get_max_num_WandB()
316 |         return max_layer_size / self.get_num_WandB()
317 |
318 |     def get_max_num_WandB(self) -> Tuple[int, nn.Module]:
319 |         """Return size of largest layer's weights + biases"""
320 |         w, b, size = 0, 0, 0
321 |         layer = None
322 |         for name, module in self._get_WorB_modules():
323 |             if module.bias is not None:
324 |                 b = module.bias.numel()
325 |             else:
326 |                 b = 0
327 |             if module.weight is not None:
328 |                 w = module.weight.numel()
329 |             else:
330 |                 w = 0
331 |             if w + b > size:
332 |                 size = w + b
333 |                 layer = module
334 |         return size, layer
335 |
336 |     def get_num_WandB(self) -> int:
337 |         """Return total number of weights and biases"""
338 |         return sum(v.numel() for name, v in self._get_WandB())
339 |
340 |     def load_state_dict(self, *args, **kwargs) -> Dict:
341 |         """Update weights and biases from stored V1, V2 values"""
342 |         out = nn.Module.load_state_dict(self, *args, **kwargs)
343 |         self.__map_V_to_WandB()
344 |         return out
345 |
346 |     def train(self, mode: bool = True) -> nn.Module:
347 |         """Remove forward hooks, load weights and biases.
348 |         Note: `self.eval()` calls `self.train(False)`"""
349 |         self.training = mode
350 |         self.__map_V_to_WandB()
351 |         return nn.Module.train(self.module, mode)
352 |
353 |     def deploy(self, mode: bool = True) -> None:
354 |         """Whether in train or eval mode, activate deploy mode, i.e. weight layer
355 |         caching. Note: `self.eval()` calls `self.train(False)`"""
356 |         if mode:
357 |             nn.Module.train(self.module, mode=False)
358 |             self.training = False
359 |             # Clear V weight matrix
360 |             self._V = None
361 |             # Add forward hooks
362 |             self.__register_hooks()
363 |             # Clear weights and biases
364 |             self.__clear_WandB()
365 |
366 |         # Remove forward hooks
367 |         else:
368 |             self.__unregister_hooks(self._prehooks)
369 |             self.__unregister_hooks(self._posthooks)
370 |             self.__map_V_to_WandB()
371 |
372 |     def forward(self, x: Tensor) -> Tensor:
373 |         if self.training:
374 |             self.__map_V_to_WandB()
375 |         output = self.module(x)
376 |         if self._verbose:
377 |             print("\tIn Model: input size", x.size(), "output size", output.size())
378 |         return output
379 |
380 |
381 | def tests():
382 |     from feathermap.models.resnet import ResNet34, ResNet18
383 |
384 |     # Device configuration
385 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
386 |
387 |     def linear_test():
388 |         lmodel = nn.Linear(2, 4).to(device)
389 |         flmodel = FeatherNet(lmodel, compress=0.5)
390 |         # Access the name-mangled private method to force V = V1*V2 into weights/biases
391 |         flmodel._FeatherNet__map_V_to_WandB()
392 |         print(flmodel.get_num_WandB(), flmodel._size_n, flmodel._size_m)
393 |         print("V: {}".format(flmodel._V))
394 |         [print(name, v) for name, v in flmodel.named_parameters()]
395 |
396 |     def res_test():
397 |         def pic_gen():
398 |             for i in range(2):
399 |                 yield torch.randn([1, 3, 32, 32]).to(device)
400 |         base_model = ResNet34().to(device)
401 |         model = FeatherNet(base_model, compress=1.0, verbose=True).to(device)
402 |         for name, module, kind in model._get_WandB_modules():
403 |             p = getattr(module, kind)
404 |             print(name, kind, p.size(), p.numel())
405 |         with torch.no_grad():
406 |             for x in pic_gen():
407 |                 model(x)
408 |         model.deploy()
409 |         with torch.no_grad():
410 |             for x in pic_gen():
411 |                 model(x)
412 |
413 |     res_test()
414 |
415 |
416 | if __name__ == "__main__":
417 |     try:
418 |         tests()
419 |     except KeyboardInterrupt:
420 |         exit()
421 |
--------------------------------------------------------------------------------
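Taken together, the intended workflow is: wrap a model in FeatherNet, train it as usual (gradients flow to V1 and V2), then call deploy() for hook-based, layer-by-layer weight materialization at inference. A minimal sketch distilled from feathermap/train.py and the tests above; the optimizer choice, batch shapes, and dummy data here are illustrative, not prescribed by the package:

import torch
import torch.nn as nn
from feathermap.models.resnet import ResNet34
from feathermap.feathernet import FeatherNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FeatherNet(ResNet34(), compress=0.5).to(device)  # V1, V2 hold the trainable weights

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

x = torch.randn(8, 3, 32, 32, device=device)   # dummy CIFAR10-shaped batch (illustrative)
y = torch.randint(0, 10, (8,), device=device)  # dummy labels (illustrative)

model.train()                    # maps V = V1 @ V2 into the wrapped layers
loss = criterion(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

model.deploy()                   # register pre/post hooks; weights computed per layer
with torch.no_grad():
    preds = model(x).argmax(dim=1)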