├── __init__.py ├── models ├── __init__.py ├── mvcnn.py └── resnet.py ├── LICENSE ├── util.py ├── README.md ├── .gitignore ├── custom_dataset.py ├── logger.py └── controller.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 René Birkeland 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | def logEpoch(logger, model, epoch, loss, accuracy): 5 | # 1. Log scalar values (scalar summary) 6 | info = {'loss': loss.item(), 'accuracy': accuracy.item()} 7 | 8 | for tag, value in info.items(): 9 | logger.scalar_summary(tag, value, epoch) 10 | 11 | # 2. Log values and gradients of the parameters (histogram summary) 12 | for tag, value in model.named_parameters(): 13 | tag = tag.replace('.', '/') 14 | logger.histo_summary(tag, value.data.cpu().numpy(), epoch) 15 | logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), epoch) 16 | 17 | # 3. Log training images (image summary) 18 | #info = {'images': images.view(-1, 28, 28)[:10].cpu().numpy()} 19 | 20 | #for tag, images in info.items(): 21 | #logger.image_summary(tag, images, epoch) 22 | 23 | def save_checkpoint(state, model, resnet=None, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 24 | if resnet: 25 | filepath = os.path.join(checkpoint, model + str(resnet) + '_' + filename) 26 | else: 27 | filepath = os.path.join(checkpoint, model + '_' + filename) 28 | torch.save(state, filepath) 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MVCNN-PyTorch 2 | ## Multi-View CNN built on ResNet/AlexNet to classify 3D objects 3 | A PyTorch implementation of MVCNN using ResNet, inspired by the paper by [Hang Su](http://vis-www.cs.umass.edu/mvcnn/docs/su15mvcnn.pdf). 4 | MVCNN uses multiple 2D images of 3D objects to classify them. You can use the provided dataset or create your own. 5 | 6 | Also check out my [RotationNet](https://github.com/RBirkeland/RotationNet) implementation whitch outperforms MVCNN (Under construction). 
7 | 8 | ![MVCNN](https://preview.ibb.co/eKcJHy/687474703a2f2f7669732d7777772e63732e756d6173732e6564752f6d76636e6e2f696d616765732f6d76636e6e2e706e67.png) 9 | 10 | ### Dependencies 11 | * torch 12 | * torchvision 13 | * numpy 14 | * tensorflow (for logging) 15 | 16 | ### Dataset 17 | ModelNet40 12-view PNG dataset can be downloaded from [Google Drive](https://drive.google.com/file/d/0B4v2jR3WsindMUE3N2xiLVpyLW8/view). 18 | 19 | You can also create your own 2D dataset from 3D objects (.obj, .stl, and .off), using [BlenderPhong](https://github.com/WeiTang114/BlenderPhong). 20 | 21 | ### Setup 22 | ```bash 23 | mkdir checkpoint 24 | mkdir logs 25 | ``` 26 | 27 | ### Train 28 | To start training, simply point to the path of the downloaded dataset. All the other settings are optional. 29 | 30 | ``` 31 | python controller.py <path to dataset> [--depth N] [--model MODEL] [--epochs N] [-b N] 32 | [--lr LR] [--momentum M] [--lr-decay-freq W] 33 | [--lr-decay W] [--print-freq N] [-r PATH] [--pretrained] 34 | ``` 35 | 36 | To resume from a checkpoint, use the -r flag together with the path to the checkpoint file. 37 | 38 | ### Tensorboard 39 | To view the training logs, run: 40 | ``` 41 | tensorboard --logdir='logs' --port=6006 42 | ``` 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | *.png 3 | *.tar 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /custom_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | import os 3 | from PIL import Image 4 | 5 | class MultiViewDataSet(Dataset): 6 | 7 | def find_classes(self, dir): 8 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 9 | classes.sort() 10 | class_to_idx = {classes[i]: i for i in range(len(classes))} 11 | 12 | return classes, class_to_idx 13 | 14 | def __init__(self, root, data_type, transform=None, target_transform=None): 15 | self.x = [] 16 | self.y = [] 17 | self.root = root 18 | 19 | self.classes, self.class_to_idx = self.find_classes(root) 20 | 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | 24 | 
# root /