├── .gitignore ├── README.md ├── lbcnn_model.py ├── main.py ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | *.pt 4 | *.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Pytorch implementation of CVPR'17 - Local Binary Convolutional Neural Networks 2 | 3 | Juefei-Xu, F., Naresh Boddeti, V., & Savvides, M. (2017). Local binary convolutional neural networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 19-28). 4 | 5 | * paper link: http://xujuefei.com/lbcnn.html 6 | * original Torch (Lua) repository: https://github.com/juefeix/lbcnn.torch 7 | 8 | Training even MNIST with parameters, stated in the original repository, is incredibly slow. Here is an example of training a toy model -- "2 x {BatchNorm2d(8) -> ConvLBP(8, 16, 3) -> Conv(16, 8, 1)} -> FC(200) -> FC(50) -> FC(10)" -- on MNIST: 9 | 10 | ``` 11 | Epoch 0/5: 100%|██████████| 235/235 [00:06<00:00, 37.74it/s] 12 | Epoch 0 train accuracy: 0.948 13 | Epoch 1/5: 100%|██████████| 235/235 [00:05<00:00, 41.98it/s] 14 | Epoch 1 train accuracy: 0.962 15 | Epoch 2/5: 100%|██████████| 235/235 [00:05<00:00, 42.01it/s] 16 | Epoch 2 train accuracy: 0.969 17 | Epoch 3/5: 100%|██████████| 235/235 [00:05<00:00, 42.04it/s] 18 | Epoch 3 train accuracy: 0.971 19 | Epoch 4/5: 100%|██████████| 235/235 [00:05<00:00, 41.84it/s] 20 | Epoch 4 train accuracy: 0.971 21 | Finished Training. Total training time: 41 sec 22 | Full forward pass: 100%|██████████| 40/40 [00:00<00:00, 100.42it/s] 23 | MNIST test accuracy: 0.974 24 | ``` 25 | -------------------------------------------------------------------------------- /lbcnn_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvLBP(nn.Conv2d): 7 | def __init__(self, in_channels, out_channels, kernel_size=3, sparsity=0.5): 8 | super().__init__(in_channels, out_channels, kernel_size, padding=1, bias=False) 9 | weights = next(self.parameters()) 10 | matrix_proba = torch.FloatTensor(weights.data.shape).fill_(0.5) 11 | binary_weights = torch.bernoulli(matrix_proba) * 2 - 1 12 | mask_inactive = torch.rand(matrix_proba.shape) > sparsity 13 | binary_weights.masked_fill_(mask_inactive, 0) 14 | weights.data = binary_weights 15 | weights.requires_grad_(False) 16 | 17 | 18 | class BlockLBP(nn.Module): 19 | 20 | def __init__(self, numChannels, numWeights, sparsity=0.5): 21 | super().__init__() 22 | self.batch_norm = nn.BatchNorm2d(numChannels) 23 | self.conv_lbp = ConvLBP(numChannels, numWeights, kernel_size=3, sparsity=sparsity) 24 | self.conv_1x1 = nn.Conv2d(numWeights, numChannels, kernel_size=1) 25 | 26 | def forward(self, x): 27 | residual = x 28 | x = self.batch_norm(x) 29 | x = F.relu(self.conv_lbp(x)) 30 | x = self.conv_1x1(x) 31 | x.add_(residual) 32 | return x 33 | 34 | 35 | class Lbcnn(nn.Module): 36 | def __init__(self, nInputPlane=1, numChannels=8, numWeights=16, full=50, depth=2, sparsity=0.5): 37 | super().__init__() 38 | 39 | self.preprocess_block = nn.Sequential( 40 | nn.Conv2d(nInputPlane, numChannels, kernel_size=3, padding=1), 41 | nn.BatchNorm2d(numChannels), 42 | nn.ReLU(inplace=True) 43 | ) 44 | 45 | chain = [BlockLBP(numChannels, numWeights, sparsity) for i in range(depth)] 46 | self.chained_blocks = nn.Sequential(*chain) 47 | self.pool = nn.AvgPool2d(kernel_size=5, stride=5) 48 | 49 | self.dropout = nn.Dropout(0.5) 50 | self.fc1 = nn.Linear(numChannels * 5 * 5, full) 51 | self.fc2 = nn.Linear(full, 10) 52 | 53 | def forward(self, x): 54 | x = self.preprocess_block(x) 55 | x = self.chained_blocks(x) 56 | x = self.pool(x) 57 | x = x.view(x.shape[0], -1) 58 | x = self.fc1(self.dropout(x)) 59 | x = F.relu(x) 60 | x = self.fc2(self.dropout(x)) 61 | return x 62 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.optim.lr_scheduler 8 | import torch.utils.data 9 | from tqdm import tqdm 10 | 11 | from lbcnn_model import Lbcnn 12 | from utils import calc_accuracy, get_mnist_loader 13 | 14 | MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models', 'lbcnn_best.pt') 15 | 16 | 17 | def test(model=None): 18 | if model is None: 19 | assert os.path.exists(MODEL_PATH), "Train a model first" 20 | lbcnn_depth, state_dict = torch.load(MODEL_PATH) 21 | model = Lbcnn(depth=lbcnn_depth) 22 | model.load_state_dict(state_dict) 23 | loader = get_mnist_loader(train=False) 24 | accuracy = calc_accuracy(model, loader=loader, verbose=True) 25 | print("MNIST test accuracy: {:.3f}".format(accuracy)) 26 | 27 | 28 | def train(n_epochs=50, lbcnn_depth=2, learning_rate=1e-2, momentum=0.9, weight_decay=1e-4, lr_scheduler_step=5): 29 | start = time.time() 30 | models_dir = os.path.dirname(MODEL_PATH) 31 | if not os.path.exists(models_dir): 32 | os.makedirs(models_dir) 33 | 34 | train_loader = get_mnist_loader(train=True) 35 | test_loader = get_mnist_loader(train=False) 36 | model = Lbcnn(depth=lbcnn_depth) 37 | use_cuda = torch.cuda.is_available() 38 | if use_cuda: 39 | model = model.cuda() 40 | best_accuracy = 0. 41 | criterion = nn.CrossEntropyLoss() 42 | optimizer = optim.SGD(filter(lambda param: param.requires_grad, model.parameters()), lr=learning_rate, 43 | momentum=momentum, weight_decay=weight_decay, nesterov=True) 44 | 45 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_scheduler_step) 46 | 47 | for epoch in range(n_epochs): 48 | for batch_id, (inputs, labels) in enumerate( 49 | tqdm(train_loader, desc="Epoch {}/{}".format(epoch, n_epochs))): 50 | if use_cuda: 51 | inputs = inputs.cuda() 52 | labels = labels.cuda() 53 | optimizer.zero_grad() 54 | outputs = model(inputs) 55 | loss = criterion(outputs, labels) 56 | loss.backward() 57 | optimizer.step() 58 | accuracy_train = calc_accuracy(model, loader=train_loader) 59 | accuracy_test = calc_accuracy(model, loader=test_loader) 60 | print("Epoch {} accuracy: train={:.3f}, test={:.3f}".format(epoch, accuracy_train, accuracy_test)) 61 | if accuracy_train > best_accuracy: 62 | best_accuracy = accuracy_train 63 | torch.save((lbcnn_depth, model.state_dict()), MODEL_PATH) 64 | scheduler.step(epoch=epoch) 65 | train_duration_sec = int(time.time() - start) 66 | print('Finished Training. Total training time: {} sec'.format(train_duration_sec)) 67 | 68 | 69 | if __name__ == '__main__': 70 | # train includes test phase at each epoch 71 | train(n_epochs=5) 72 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch==1.4.0 2 | torchvision==0.5.0 3 | tqdm -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | from tqdm import tqdm 6 | 7 | 8 | def get_mnist_loader(train=True, batch_size=256): 9 | """ 10 | :param train: train or test fold? 11 | :param batch_size: batch size, int 12 | :return: MNIST loader 13 | """ 14 | transform = transforms.Compose( 15 | [transforms.ToTensor(), 16 | transforms.Normalize(mean=(0.1307,), std=(0.3081,))]) 17 | data_set = torchvision.datasets.MNIST(root='./data', train=train, 18 | download=True, transform=transform) 19 | loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, 20 | shuffle=train, num_workers=4) 21 | return loader 22 | 23 | 24 | def calc_accuracy(model, loader, verbose=False): 25 | """ 26 | :param model: model network 27 | :param loader: torch.utils.data.DataLoader 28 | :param verbose: show progress bar, bool 29 | :return accuracy, float 30 | """ 31 | mode_saved = model.training 32 | model.train(False) 33 | use_cuda = torch.cuda.is_available() 34 | if use_cuda: 35 | model.cuda() 36 | outputs_full = [] 37 | labels_full = [] 38 | for inputs, labels in tqdm(iter(loader), desc="Full forward pass", total=len(loader), disable=not verbose): 39 | if use_cuda: 40 | inputs = inputs.cuda() 41 | labels = labels.cuda() 42 | with torch.no_grad(): 43 | outputs_batch = model(inputs) 44 | outputs_full.append(outputs_batch) 45 | labels_full.append(labels) 46 | model.train(mode_saved) 47 | outputs_full = torch.cat(outputs_full, dim=0) 48 | labels_full = torch.cat(labels_full, dim=0) 49 | _, labels_predicted = torch.max(outputs_full.data, dim=1) 50 | accuracy = torch.sum(labels_full == labels_predicted).item() / float(len(labels_full)) 51 | return accuracy 52 | --------------------------------------------------------------------------------