├── mhbn.png ├── __pycache__ ├── mhbn.cpython-36.pyc ├── mvcnn.cpython-36.pyc ├── util.cpython-36.pyc └── custom_dataset.cpython-36.pyc ├── util.py ├── README.md ├── custom_dataset.py ├── mhbn.py ├── main.py ├── resnet.py └── main.ipynb /mhbn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyuan24/MHBNN-PyTorch/HEAD/mhbn.png -------------------------------------------------------------------------------- /__pycache__/mhbn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyuan24/MHBNN-PyTorch/HEAD/__pycache__/mhbn.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/mvcnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyuan24/MHBNN-PyTorch/HEAD/__pycache__/mvcnn.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyuan24/MHBNN-PyTorch/HEAD/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/custom_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyuan24/MHBNN-PyTorch/HEAD/__pycache__/custom_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | def logEpoch(logger, model, epoch, loss, accuracy): 5 | # 1. 
Log scalar values (scalar summary) 6 | info = {'loss': loss.item(), 'accuracy': accuracy.item()} 7 | 8 | for tag, value in info.items(): 9 | logger.scalar_summary(tag, value, epoch) 10 | 11 | # 2. Log values and gradients of the parameters (histogram summary) 12 | for tag, value in model.named_parameters(): 13 | tag = tag.replace('.', '/') 14 | logger.histo_summary(tag, value.data.cpu().numpy(), epoch) 15 | logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), epoch) 16 | 17 | # 3. Log training images (image summary) 18 | #info = {'images': images.view(-1, 28, 28)[:10].cpu().numpy()} 19 | 20 | #for tag, images in info.items(): 21 | #logger.image_summary(tag, images, epoch) 22 | 23 | def save_checkpoint(state, model, resnet=None, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 24 | if resnet: 25 | filepath = os.path.join(checkpoint, model + str(resnet) + '_' + filename) 26 | else: 27 | filepath = os.path.join(checkpoint, model + '_' + filename) 28 | torch.save(state, filepath) 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MHBNN-PyTorch 2 | 3 | A PyTorch implementation of Multi-view Harmonized Bilinear Network for 3D Object Recognition (MHBN) inspired by [Tan Yu et al.](http://openaccess.thecvf.com/content_cvpr_2018/html/Yu_Multi-View_Harmonized_Bilinear_CVPR_2018_paper.html). 4 | 5 | In this paper, the 3D object recognition problem is converted to a multi-view 2D image classification problem. For each 3D object, there are multiple images taken from different views. 
6 | 7 | ![](https://github.com/LiyuanLacfo/MHBNN-PyTorch/blob/master/mhbn.png) 8 | 9 | ### Dependencies 10 | 11 | * torch 0.4.1 12 | * torchvision 13 | * numpy 14 | 15 | ### Dataset 16 | 17 | * ModelNet CAD data can be found at [Princeton](http://modelnet.cs.princeton.edu/) 18 | * ModelNet40 12-view png images can be downloaded at [Google Drive](https://drive.google.com/file/d/0B4v2jR3WsindMUE3N2xiLVpyLW8/view?usp=sharing) 19 | * You can also create your own png dataset with [BlenderPhong](https://github.com/WeiTang114/BlenderPhong) 20 | 21 | ### Train the model 22 | 23 | ``` 24 | python main.py --data 25 | ``` 26 | 27 | ### Special Thanks 28 | 29 | I referred to [RBirkeland](https://github.com/RBirkeland/MVCNN-PyTorch) for some code. 30 | 31 | -------------------------------------------------------------------------------- /custom_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | import os 3 | from PIL import Image 4 | 5 | class MultiViewDataSet(Dataset): 6 | 7 | def find_classes(self, dir): 8 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 9 | classes.sort() 10 | class_to_idx = {classes[i]: i for i in range(len(classes))} 11 | 12 | return classes, class_to_idx 13 | 14 | def __init__(self, root, data_type, transform=None, target_transform=None): 15 | self.x = [] 16 | self.y = [] 17 | self.root = root 18 | 19 | self.classes, self.class_to_idx = self.find_classes(root) 20 | 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | # root /