├── .gitignore ├── LICENSE ├── README.md ├── autoencoder ├── options.py └── train.py ├── data ├── augmentation.py ├── build_som │ ├── save_som.ipynb │ └── util │ │ ├── potential_field.py │ │ └── som.py ├── modelnet_shrec_loader.py ├── sampler_matlab │ ├── RemoveWhiteSpace.m │ ├── pc_generator.m │ ├── pc_generator_test.m │ ├── read_obj.m │ ├── sampler.m │ └── visualization.m └── shapenet_loader.py ├── modelnet ├── options.py └── train.py ├── models ├── autoencoder.py ├── classifier.py ├── index_max_ext │ ├── index_max.cpp │ ├── index_max_cuda.cu │ └── setup.py ├── layers.py ├── losses.py ├── networks.py ├── operations.py └── segmenter.py ├── part-seg ├── options.py └── train.py ├── shrec16 ├── options.py ├── test.py └── train.py └── util ├── __init__.py ├── html.py ├── potential_field.py ├── som.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | __pycache__/ 3 | .idea/ 4 | *.pyc 5 | *.pth 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Carson Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # SO-Net 3 | **SO-Net: Self-Organizing Network for Point Cloud Analysis.** CVPR 2018, Salt Lake City, USA 4 | Jiaxin Li, Ben M. Chen, Gim Hee Lee, National University of Singapore 5 | 6 | 7 | 8 | ## Introduction 9 | SO-Net is a deep network architecture that processes 2D/3D point clouds. It enables various applications including but not limited to classification, shape retrieval, segmentation, and reconstruction. The arXiv version of SO-Net can be found [here](https://arxiv.org/abs/1803.04249). 10 | ``` 11 | @article{li2018sonet, 12 | title={SO-Net: Self-Organizing Network for Point Cloud Analysis}, 13 | author={Li, Jiaxin and Chen, Ben M and Lee, Gim Hee}, 14 | journal={arXiv preprint arXiv:1803.04249}, 15 | year={2018} 16 | } 17 | ``` 18 | Inspired by the Self-Organizing Map (SOM), SO-Net performs dimensionality reduction on point clouds and extracts features based on the SOM nodes, with a theoretical guarantee of invariance to point order. SO-Net explicitly models the spatial distribution of points and provides precise control of the receptive field overlap. 
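The order-invariance property can be illustrated with a small NumPy sketch of one batch SOM update. This is a simplified illustration, not the repository's `som.py`: the function name `batch_som_step`, the grid-distance Gaussian neighborhood, and the fixed `sigma` are assumptions made for the example. Each node is moved to a weighted mean over all points, so shuffling the points should change the result only by floating-point rounding.

```python
import numpy as np


def batch_som_step(points, nodes, rows, cols, sigma=0.4):
    """One batch SOM update (illustrative; names and defaults are assumptions).

    points: (N, 3) array of input points.
    nodes:  (rows * cols, 3) array of current SOM node positions.
    Returns the updated (rows * cols, 3) node positions.
    """
    # Squared distance from every point to every node: shape (N, M).
    d2 = ((points[:, None, :] - nodes[None, :, :]) ** 2).sum(axis=2)
    winner = d2.argmin(axis=1)  # nearest node index for each point

    # 2D grid coordinates of each node, used by the neighborhood function.
    grid = np.stack(np.unravel_index(np.arange(rows * cols), (rows, cols)), axis=1)
    g2 = ((grid[:, None, :] - grid[None, :, :]) ** 2).sum(axis=2)  # (M, M)
    h = np.exp(-g2 / (2.0 * sigma ** 2))  # neighborhood weight between nodes

    # Weight of every point for every node: shape (N, M).
    w = h[winner]
    # Each node becomes the weighted mean of all points.
    return (w.T @ points) / w.sum(axis=0)[:, None]


rng = np.random.default_rng(0)
points = rng.uniform(-1, 1, size=(256, 3))
nodes = rng.uniform(-1, 1, size=(16, 3))

updated = batch_som_step(points, nodes, rows=4, cols=4)
shuffled = batch_som_step(points[rng.permutation(256)], nodes, rows=4, cols=4)
print(np.allclose(updated, shuffled))
```

An online SOM, which updates the nodes one point at a time, would not have this property, which is why a batch formulation is the relevant one here.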
19 | 20 | This repository releases code for 4 applications: 21 | * Classification - ModelNet 40/10, MNIST dataset 22 | * Shape Retrieval - SHREC 2016 dataset 23 | * Part Segmentation - ShapeNetPart dataset 24 | * Auto-encoder - ModelNet 40/10, SHREC 2016, ShapeNetPart 25 | 26 | 27 | ## Installation 28 | Requirements: 29 | - Python 3 30 | - [PyTorch 0.4 or higher](http://pytorch.org/) 31 | - [Faiss](https://github.com/facebookresearch/faiss) 32 | - [visdom](https://github.com/facebookresearch/visdom) 33 | - Compile the custom CUDA code: 34 | ``` 35 | cd models/index_max_ext 36 | python3 setup.py install 37 | ``` 38 | 39 | Optional dependency: 40 | - Faiss [GPU support](https://github.com/facebookresearch/faiss/blob/master/INSTALL.md) - required by auto-encoder 41 | 42 | ## Dataset 43 | For [ModelNet40/10](https://1drv.ms/u/s!ApbTjxa06z9CgQfKl99yUDHL_wHs) and [ShapeNetPart](https://1drv.ms/u/s!ApbTjxa06z9CgQnl-Qm6KI3Ywbe1), we use the pre-processed dataset provided by [PointNet++](https://github.com/charlesq34/pointnet2) of Charles R. Qi. For SHREC2016, we sampled points uniformly from the original `*.obj` files. The MATLAB code that performs the sampling is provided in `data/sampler_matlab/`. 44 | 45 | In SO-Net, SOM training can be decoupled from network training and run as data pre-processing, so we further process the datasets by generating a SOM for each point cloud. The code for batch-SOM training can be found in `data/build_som/`. 46 | 47 | In addition, our prepared datasets can be found in [Google Drive](https://drive.google.com/open?id=184MbflF_RbDX9MyML3hid7OxsYJ8oQQ7): MNIST, ModelNet, ShapeNetPart, SHREC2016. 48 | 49 | ## Usage 50 | ### Configuration 51 | The 4 applications share the same SO-Net architecture, which is implemented in `models/`. Typically each task has its own folder like `modelnet/`, `part-seg/` that contains its own configuration `options.py`, training script `train.py` and testing script `test.py`. 
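Since `autoencoder/options.py` builds its options with `argparse`, a value can also be supplied on the command line instead of being edited in the file. The sketch below is a minimal illustration using a hypothetical three-option subset; the `--dataset` and `--classes` defaults are copied from `autoencoder/options.py`, while the `--dataroot` path is only a placeholder.

```python
import argparse

# Hypothetical subset of the options declared in autoencoder/options.py.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='shapenet')
parser.add_argument('--dataroot', default='/path/to/dataset')  # placeholder path
parser.add_argument('--classes', type=int, default=40)

# Same effect as running: python3 train.py --dataset modelnet --classes 10
opt = parser.parse_args(['--dataset', 'modelnet', '--classes', '10'])
print(opt.dataset, opt.classes, opt.dataroot)
```

Options that are not passed keep their defaults, so only the values that differ from `options.py` need to appear on the command line.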
52 | 53 | To run these tasks, you may need to set the dataset type and path in `options.py`, by changing the default values of `--dataset` and `--dataroot`. 54 | ### Visualization 55 | We use visdom for visualization. Various loss values and the reconstructed point clouds (in auto-encoder) are plotted in real-time. Start the visdom server before training; otherwise warnings/errors are reported, although they do not affect the training process. 56 | ``` 57 | python3 -m visdom.server 58 | ``` 59 | The visualization results can be viewed in a browser at: 60 | ``` 61 | http://localhost:8097 62 | ``` 63 | ### Application - Classification 64 | Point cloud classification can be done on the ModelNet40/10 and SHREC2016 datasets. Besides setting `--dataset` and `--dataroot`, `--classes` should be set to the desired class number, i.e., 55 for SHREC2016, 40 for ModelNet40 and 10 for ModelNet10. 65 | ``` 66 | cd modelnet/ 67 | python3 train.py 68 | ``` 69 | ### Application - Shape Retrieval 70 | The training of shape retrieval is the same as classification, but in the testing phase, the score vector (length 55 for SHREC2016) is regarded as the feature vector. We calculate the L2 feature distance between each shape in the test set and all shapes in the same predicted category from the test set (including itself). The corresponding retrieval list is constructed by sorting these shapes according to the feature distances. 71 | ``` 72 | cd shrec16/ 73 | python3 train.py 74 | ``` 75 | ### Application - Part Segmentation 76 | Segmentation is formulated as a per-point classification problem. 77 | ``` 78 | cd part-seg/ 79 | python3 train.py 80 | ``` 81 | ### Application - Auto-encoder 82 | An input point cloud is compressed into a feature vector, based on which a point cloud is reconstructed to minimize the Chamfer loss. Supports ModelNet, ShapeNetPart, SHREC2016. 
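For readers unfamiliar with the Chamfer loss, the NumPy sketch below shows one common definition: the mean squared distance from each point to its nearest neighbor in the other set, summed over both directions. It is an illustration only. The function name `chamfer_distance` is hypothetical, and the loss in `models/losses.py` may differ in normalization or in using unsquared distances.

```python
import numpy as np


def chamfer_distance(a, b):
    """Symmetric Chamfer distance between point sets a (N, 3) and b (M, 3).

    Illustrative definition using squared Euclidean distances; this is an
    assumption, not necessarily the exact form used by the repository.
    """
    # Pairwise squared distances, shape (N, M).
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=2)
    a_to_b = d2.min(axis=1).mean()  # each point of a to its nearest point in b
    b_to_a = d2.min(axis=0).mean()  # each point of b to its nearest point in a
    return a_to_b + b_to_a


rng = np.random.default_rng(0)
cloud = rng.uniform(-1, 1, size=(128, 3))
noisy = cloud + 0.01 * rng.standard_normal(cloud.shape)

print(chamfer_distance(cloud, cloud))  # identical sets give zero
print(chamfer_distance(cloud, noisy))  # small positive value
```

The two point sets do not need the same number of points, which matches the auto-encoder options, where the input and output point counts (`--input_pc_num`, `--output_pc_num`) are set independently.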
83 | ``` 84 | cd autoencoder/ 85 | python3 train.py 86 | ``` 87 | 88 | ## License 89 | This repository is released under the MIT License (see the LICENSE file for details). -------------------------------------------------------------------------------- /autoencoder/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | 7 | class Options(): 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser() 10 | self.initialized = False 11 | 12 | def initialize(self): 13 | self.parser.add_argument('--gpu_id', type=int, default=0, help='gpu ids: e.g. 0, 1. -1 is no GPU') 14 | 15 | self.parser.add_argument('--dataset', type=str, default='shapenet', help='modelnet / shrec / shapenet') 16 | self.parser.add_argument('--dataroot', default='/ssd/dataset/shapenetcore_partanno_segmentation_benchmark_v0_normal//', help='path to the dataset root') 17 | self.parser.add_argument('--classes', type=int, default=40, help='ModelNet40 or ModelNet10') 18 | self.parser.add_argument('--name', type=str, default='train', help='name of the experiment. 
It decides where to store samples and models') 19 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 20 | 21 | self.parser.add_argument('--batch_size', type=int, default=8, help='input batch size') 22 | self.parser.add_argument('--input_pc_num', type=int, default=1024, help='# of input points') 23 | self.parser.add_argument('--surface_normal', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True, help='use surface normal in the pc input (true / false)') 24 | self.parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data') 25 | 26 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 27 | self.parser.add_argument('--display_id', type=int, default=200, help='window id of the web display') 28 | 29 | self.parser.add_argument('--output_pc_num', type=int, default=1280, help='# of output points') 30 | self.parser.add_argument('--output_fc_pc_num', type=int, default=256, help='# of fc decoder output points') 31 | self.parser.add_argument('--output_conv_pc_num', type=int, default=1024, help='# of conv decoder output points') 32 | 33 | self.parser.add_argument('--feature_num', type=int, default=1024, help='length of encoded feature') 34 | self.parser.add_argument('--activation', type=str, default='relu', help='activation function: relu, elu') 35 | self.parser.add_argument('--normalization', type=str, default='batch', help='normalization function: batch, instance') 36 | 37 | self.parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 38 | self.parser.add_argument('--dropout', type=float, default=0.5, help='dropout probability') 39 | self.parser.add_argument('--node_num', type=int, default=64, help='som node number') 40 | self.parser.add_argument('--k', type=int, default=3, help='knn search') 41 | 42 | self.parser.add_argument('--som_k', type=int, default=9, help='k nearest neighbor of SOM nodes searching on SOM nodes') 43 | self.parser.add_argument('--som_k_type', 
type=str, default='avg', help='avg / center') 44 | 45 | self.parser.add_argument('--random_pc_dropout_lower_limit', type=float, default=1, help='keep ratio lower limit') 46 | self.parser.add_argument('--bn_momentum', type=float, default=0.1, help='normalization momentum, typically 0.1. Equal to (1-m) in TF') 47 | self.parser.add_argument('--bn_momentum_decay_step', type=int, default=None, help='BN momentum decay step. e.g, 0.5->0.01.') 48 | self.parser.add_argument('--bn_momentum_decay', type=float, default=0.6, help='BN momentum decay step. e.g, 0.5->0.01.') 49 | 50 | self.initialized = True 51 | 52 | def parse(self): 53 | if not self.initialized: 54 | self.initialize() 55 | self.opt = self.parser.parse_args() 56 | 57 | self.opt.device = torch.device("cuda:%d" % (self.opt.gpu_id) if torch.cuda.is_available() else "cpu") 58 | # torch.cuda.set_device(self.opt.gpu_id) 59 | 60 | args = vars(self.opt) 61 | 62 | print('------------ Options -------------') 63 | for k, v in sorted(args.items()): 64 | print('%s: %s' % (str(k), str(v))) 65 | print('-------------- End ----------------') 66 | 67 | # save to the disk 68 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 69 | util.mkdirs(expr_dir) 70 | file_name = os.path.join(expr_dir, 'opt.txt') 71 | with open(file_name, 'wt') as opt_file: 72 | opt_file.write('------------ Options -------------\n') 73 | for k, v in sorted(args.items()): 74 | opt_file.write('%s: %s\n' % (str(k), str(v))) 75 | opt_file.write('-------------- End ----------------\n') 76 | return self.opt 77 | -------------------------------------------------------------------------------- /autoencoder/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | import numpy as np 4 | import math 5 | 6 | from options import Options 7 | opt = Options().parse() # set CUDA_VISIBLE_DEVICES before import torch 8 | 9 | import torch 10 | import torchvision 11 | from torch.autograd import 
Variable 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | import random 16 | import numpy as np 17 | 18 | from models.autoencoder import Model 19 | from data.modelnet_shrec_loader import ModelNet_Shrec_Loader 20 | from data.shapenet_loader import ShapeNetLoader 21 | from util.visualizer import Visualizer 22 | 23 | 24 | if __name__=='__main__': 25 | if opt.dataset=='modelnet' or opt.dataset=='shrec': 26 | trainset = ModelNet_Shrec_Loader(opt.dataroot, 'train', opt) 27 | dataset_size = len(trainset) 28 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.nThreads) 29 | print('#training point clouds = %d' % len(trainset)) 30 | 31 | testset = ModelNet_Shrec_Loader(opt.dataroot, 'test', opt) 32 | testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.nThreads) 33 | elif opt.dataset=='shapenet': 34 | trainset = ShapeNetLoader(opt.dataroot, 'train', opt) 35 | dataset_size = len(trainset) 36 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.nThreads) 37 | print('#training point clouds = %d' % len(trainset)) 38 | 39 | testset = ShapeNetLoader(opt.dataroot, 'test', opt) 40 | testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.nThreads) 41 | else: 42 | raise Exception('Dataset error.') 43 | 44 | model = Model(opt) 45 | 46 | visualizer = Visualizer(opt) 47 | 48 | best_loss = 99 49 | for epoch in range(601): 50 | 51 | epoch_iter = 0 52 | for i, data in enumerate(trainloader): 53 | iter_start_time = time.time() 54 | epoch_iter += opt.batch_size 55 | 56 | if opt.dataset=='modelnet' or opt.dataset=='shrec': 57 | input_pc, input_sn, input_label, input_node, input_node_knn_I = data 58 | model.set_input(input_pc, input_sn, input_label, input_node, input_node_knn_I) 59 | elif opt.dataset=='shapenet': 60 | 
input_pc, input_sn, input_label, input_seg, input_node, input_node_knn_I = data 61 | model.set_input(input_pc, input_sn, input_label, input_node, input_node_knn_I) 62 | 63 | model.optimize() 64 | 65 | if i % 100 == 0: 66 | # print/plot errors 67 | t = (time.time() - iter_start_time) / opt.batch_size 68 | 69 | errors = model.get_current_errors() 70 | 71 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 72 | visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors) 73 | 74 | # print(model.autoencoder.encoder.feature) 75 | visuals = model.get_current_visuals() 76 | visualizer.display_current_results(visuals, epoch, i) 77 | 78 | # test network 79 | if epoch >= 0 and epoch%1==0: 80 | batch_amount = 0 81 | model.test_loss.data.zero_() 82 | for i, data in enumerate(testloader): 83 | if opt.dataset == 'modelnet' or opt.dataset=='shrec': 84 | input_pc, input_sn, input_label, input_node, input_node_knn_I = data 85 | model.set_input(input_pc, input_sn, input_label, input_node, input_node_knn_I) 86 | elif opt.dataset == 'shapenet': 87 | input_pc, input_sn, input_label, input_seg, input_node, input_node_knn_I = data 88 | model.set_input(input_pc, input_sn, input_label, input_node, input_node_knn_I) 89 | model.test_model() 90 | 91 | batch_amount += input_label.size()[0] 92 | 93 | # # accumulate loss 94 | model.test_loss += model.loss_chamfer.detach() * input_label.size()[0] 95 | 96 | model.test_loss /= batch_amount 97 | if model.test_loss.item() < best_loss: 98 | best_loss = model.test_loss.item() 99 | print('Tested network. 
So far lowest loss: %f' % best_loss ) 100 | 101 | # learning rate decay 102 | if epoch%20==0 and epoch>0: 103 | model.update_learning_rate(0.5) 104 | 105 | # save network 106 | if epoch%1==0 and epoch>0: 107 | print("Saving network...") 108 | model.save_network(model.encoder, 'encoder', '%d_%f' % (epoch, model.test_loss.item()), opt.gpu_id) 109 | model.save_network(model.decoder, 'decoder', '%d_%f' % (epoch, model.test_loss.item()), opt.gpu_id) 110 | 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /data/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numbers 3 | import os 4 | import os.path 5 | import numpy as np 6 | import struct 7 | import math 8 | 9 | import torch 10 | import torchvision 11 | import matplotlib.pyplot as plt 12 | import h5py 13 | import faiss 14 | 15 | 16 | def rotate_point_cloud_90(data): 17 | """ Randomly rotate the point clouds by a multiple of 90 degrees to augment the dataset 18 | rotation is per shape based along up direction 19 | Input: 20 | Nx3 array, original point clouds 21 | Return: 22 | Nx3 array, rotated point clouds 23 | """ 24 | rotated_data = np.zeros(data.shape, dtype=np.float32) 25 | 26 | rotation_angle = np.random.randint(low=0, high=4) * (np.pi/2.0) 27 | cosval = np.cos(rotation_angle) 28 | sinval = np.sin(rotation_angle) 29 | rotation_matrix = np.array([[cosval, 0, sinval], 30 | [0, 1, 0], 31 | [-sinval, 0, cosval]]) 32 | rotated_data = np.dot(data.reshape((-1, 3)), rotation_matrix) 33 | 34 | return rotated_data 35 | 36 | 37 | def rotate_point_cloud(data): 38 | """ Randomly rotate the point clouds to augment the dataset 39 | rotation is per shape based along up direction 40 | Input: 41 | Nx3 array, original point clouds 42 | Return: 43 | Nx3 array, rotated point clouds 44 | """ 45 | rotated_data = np.zeros(data.shape, dtype=np.float32) 46 | 47 | rotation_angle = np.random.uniform() * 2 * np.pi 48 | cosval = 
np.cos(rotation_angle) 49 | sinval = np.sin(rotation_angle) 50 | rotation_matrix = np.array([[cosval, 0, sinval], 51 | [0, 1, 0], 52 | [-sinval, 0, cosval]]) 53 | rotated_data = np.dot(data.reshape((-1, 3)), rotation_matrix) 54 | 55 | return rotated_data 56 | 57 | 58 | def rotate_point_cloud_with_normal_som(pc, surface_normal, som): 59 | """ Randomly rotate the point cloud, its surface normals and its SOM nodes to augment the dataset 60 | rotation is per shape based along up direction 61 | Input: 62 | pc (Nx3), surface_normal (Nx3) and som (node_numx3) arrays 63 | Return: 64 | rotated pc, surface_normal and som arrays, same shapes as the inputs 65 | """ 66 | 67 | rotation_angle = np.random.uniform() * 2 * np.pi 68 | # rotation_angle = np.random.randint(low=0, high=12) * (2*np.pi / 12.0) 69 | cosval = np.cos(rotation_angle) 70 | sinval = np.sin(rotation_angle) 71 | rotation_matrix = np.array([[cosval, 0, sinval], 72 | [0, 1, 0], 73 | [-sinval, 0, cosval]]) 74 | 75 | rotated_pc = np.dot(pc, rotation_matrix) 76 | rotated_surface_normal = np.dot(surface_normal, rotation_matrix) 77 | rotated_som = np.dot(som, rotation_matrix) 78 | 79 | return rotated_pc, rotated_surface_normal, rotated_som 80 | 81 | 82 | def rotate_perturbation_point_cloud(data, angle_sigma=0.06, angle_clip=0.18): 83 | """ Randomly perturb the point clouds by small rotations 84 | Input: 85 | Nx3 array, original point clouds 86 | Return: 87 | Nx3 array, rotated point clouds 88 | """ 89 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 90 | Rx = np.array([[1,0,0], 91 | [0,np.cos(angles[0]),-np.sin(angles[0])], 92 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 93 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 94 | [0,1,0], 95 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 96 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 97 | [np.sin(angles[2]),np.cos(angles[2]),0], 98 | [0,0,1]]) 99 | R = np.dot(Rz, np.dot(Ry,Rx)) 100 | 101 | rotated_data = np.dot(data, R) 102 | 103 | return rotated_data 104 | 105 | 106 | def rotate_perturbation_point_cloud_with_normal_som(pc, 
surface_normal, som, angle_sigma=0.06, angle_clip=0.18): 107 | """ Randomly perturb the point clouds by small rotations 108 | Input: 109 | Nx3 array, original point clouds 110 | Return: 111 | Nx3 array, rotated point clouds 112 | """ 113 | 114 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 115 | Rx = np.array([[1,0,0], 116 | [0,np.cos(angles[0]),-np.sin(angles[0])], 117 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 118 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 119 | [0,1,0], 120 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 121 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 122 | [np.sin(angles[2]),np.cos(angles[2]),0], 123 | [0,0,1]]) 124 | R = np.dot(Rz, np.dot(Ry,Rx)) 125 | 126 | rotated_pc = np.dot(pc, R) 127 | rotated_surface_normal = np.dot(surface_normal, R) 128 | rotated_som = np.dot(som, R) 129 | 130 | return rotated_pc, rotated_surface_normal, rotated_som 131 | 132 | 133 | def jitter_point_cloud(data, sigma=0.01, clip=0.05): 134 | """ Randomly jitter points. jittering is per point. 
135 | Input: 136 | Nx3 array, original point clouds 137 | Return: 138 | Nx3 array, jittered point clouds 139 | """ 140 | N, C = data.shape 141 | assert(clip > 0) 142 | jittered_data = np.clip(sigma * np.random.randn(N, C), -1*clip, clip) 143 | jittered_data += data 144 | return jittered_data 145 | -------------------------------------------------------------------------------- /data/build_som/save_som.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import random\n", 12 | "import numbers\n", 13 | "import os\n", 14 | "import os.path\n", 15 | "import numpy as np\n", 16 | "import struct\n", 17 | "import math\n", 18 | "\n", 19 | "import torch\n", 20 | "import torchvision\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "import h5py\n", 23 | "import json\n", 24 | "\n", 25 | "from util import som\n", 26 | "\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "from mpl_toolkits.mplot3d import Axes3D\n", 29 | "%matplotlib qt5 \n" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "collapsed": true 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "# for train and val\n", 41 | "def som_saver_shrec2016(root, rows, cols, gpu_ids, output_root):\n", 42 | " som_builder = som.SOM(rows, cols, 3, gpu_ids)\n", 43 | " \n", 44 | " folder_list = os.listdir(root)\n", 45 | " for i, folder in enumerate(folder_list):\n", 46 | " file_list = os.listdir(os.path.join(root, folder))\n", 47 | " for j, file in enumerate(file_list):\n", 48 | " if file[-3:] == 'txt':\n", 49 | " data = np.loadtxt(os.path.join(root, folder, file))\n", 50 | " pc_np = data[:, 0:3]\n", 51 | " sn_np = data[:, 3:6]\n", 52 | " \n", 53 | " pc_np_sampled = pc_np[np.random.choice(pc_np.shape[0], 4096, replace=False), :]\n", 54 | " pc = 
torch.from_numpy(pc_np_sampled.transpose().astype(np.float32)).cuda() # 3xN tensor\n", 55 | " som_builder.optimize(pc)\n", 56 | " som_node_np = som_builder.node.cpu().numpy().transpose().astype(np.float32) # node_numx3\n", 57 | "\n", 58 | " npz_file = os.path.join(output_root, file[0:-4]+'.npz')\n", 59 | " np.savez(npz_file, pc=pc_np, sn=sn_np, som_node=som_node_np)\n", 60 | "\n", 61 | " if j%100==0:\n", 62 | " print('%s, %s' % (folder, file))\n", 63 | "\n", 64 | "# print(pc_np.shape)\n", 65 | "# print(som_node_np.shape)\n", 66 | "\n", 67 | "# x_np = pc_np\n", 68 | "# node_np = som_node_np\n", 69 | "# fig = plt.figure()\n", 70 | "# ax = Axes3D(fig)\n", 71 | "# ax.scatter(x_np[:,0].tolist(), x_np[:,1].tolist(), x_np[:,2].tolist(), s=1)\n", 72 | "# ax.scatter(node_np[:,0].tolist(), node_np[:,1].tolist(), node_np[:,2].tolist(), s=6, c='r')\n", 73 | "# plt.show()\n", 74 | "\n", 75 | "# if j>10:\n", 76 | "# break\n", 77 | "# break" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": { 84 | "collapsed": true, 85 | "scrolled": true 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "rows, cols = 8, 8\n", 90 | "som_saver_shrec2016('/ssd/dataset/SHREC2016/obj_txt/train', rows, cols, True, '/ssd/dataset/SHREC2016/%dx%d/train'%(rows, cols))\n", 91 | "som_saver_shrec2016('/ssd/dataset/SHREC2016/obj_txt/val', rows, cols, True, '/ssd/dataset/SHREC2016/%dx%d/val'%(rows, cols))" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": { 98 | "collapsed": true 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "# for test set\n", 103 | "def som_saver_shrec2016(root, rows, cols, gpu_ids, output_root):\n", 104 | " som_builder = som.SOM(rows, cols, 3, gpu_ids)\n", 105 | "\n", 106 | " file_list = os.listdir(root)\n", 107 | " for j, file in enumerate(file_list):\n", 108 | " if file[-3:] == 'txt':\n", 109 | " data = np.loadtxt(os.path.join(root, file))\n", 110 | " pc_np = data[:, 0:3]\n", 111 | " sn_np = 
data[:, 3:6]\n", 112 | "\n", 113 | " pc_np_sampled = pc_np[np.random.choice(pc_np.shape[0], 4096, replace=False), :]\n", 114 | " pc = torch.from_numpy(pc_np_sampled.transpose().astype(np.float32)).cuda() # 3xN tensor\n", 115 | " som_builder.optimize(pc)\n", 116 | " som_node_np = som_builder.node.cpu().numpy().transpose().astype(np.float32) # node_numx3\n", 117 | "\n", 118 | " npz_file = os.path.join(output_root, file[0:-4]+'.npz')\n", 119 | " np.savez(npz_file, pc=pc_np, sn=sn_np, som_node=som_node_np)\n", 120 | "\n", 121 | " if j%100==0:\n", 122 | " print('%s' % (file))\n", 123 | "\n", 124 | "# print(pc_np.shape)\n", 125 | "# print(som_node_np.shape)\n", 126 | "\n", 127 | "# x_np = pc_np\n", 128 | "# node_np = som_node_np\n", 129 | "# fig = plt.figure()\n", 130 | "# ax = Axes3D(fig)\n", 131 | "# ax.scatter(x_np[:,0].tolist(), x_np[:,1].tolist(), x_np[:,2].tolist(), s=1)\n", 132 | "# ax.scatter(node_np[:,0].tolist(), node_np[:,1].tolist(), node_np[:,2].tolist(), s=6, c='r')\n", 133 | "# plt.show()\n", 134 | "\n", 135 | "# if j>10:\n", 136 | "# break\n", 137 | "# break" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": { 144 | "collapsed": true 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "rows, cols = 8, 8\n", 149 | "som_saver_shrec2016('/ssd/dataset/SHREC2016/obj_txt/test_allinone', rows, cols, True, '/ssd/dataset/SHREC2016/%dx%d/test'%(rows, cols))" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": { 156 | "collapsed": true 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "file = '/ssd/dataset/SHREC2016/8x8/train/model_013435.npz'\n", 161 | "data = np.load(file)\n", 162 | "pc_np = data['pc']\n", 163 | "sn_np = data['sn']\n", 164 | "som_node_np = data['som_node']\n", 165 | "\n", 166 | "print(pc_np)\n", 167 | "print(sn_np)\n", 168 | "print(som_node_np)\n", 169 | "\n", 170 | "x_np = pc_np\n", 171 | "node_np = som_node_np\n", 172 | "fig = 
plt.figure()\n", 173 | "ax = Axes3D(fig)\n", 174 | "ax.scatter(x_np[:,0].tolist(), x_np[:,1].tolist(), x_np[:,2].tolist(), s=1)\n", 175 | "ax.scatter(node_np[:,0].tolist(), node_np[:,1].tolist(), node_np[:,2].tolist(), s=6, c='r')\n", 176 | "plt.show()" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": { 183 | "collapsed": true 184 | }, 185 | "outputs": [], 186 | "source": [ 187 | "file_list = os.listdir('/ssd/dataset/SHREC2016/8x8/test')\n", 188 | "file_list.sort()\n", 189 | "f = open('test.txt', 'w')\n", 190 | "for file in file_list:\n", 191 | " f.write('%s\\n' % file[6:-4])\n", 192 | "f.close()" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": { 199 | "collapsed": true 200 | }, 201 | "outputs": [], 202 | "source": [] 203 | } 204 | ], 205 | "metadata": { 206 | "kernelspec": { 207 | "display_name": "Python 3", 208 | "language": "python", 209 | "name": "python3" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": "text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.5.2" 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 2 226 | } 227 | -------------------------------------------------------------------------------- /data/build_som/util/potential_field.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numbers 3 | import os 4 | import os.path 5 | import numpy as np 6 | import struct 7 | import math 8 | import time 9 | 10 | 11 | class PotentialField: 12 | def __init__(self, node_num, dim): 13 | self.node_num = node_num 14 | self.dim = dim 15 | np.random.seed(2017) 16 | self.node = np.random.rand(self.node_num, self.dim) * 2 - 1 17 | np.random.seed() 18 | 19 | self.learning_rate = 0.01 20 | 21 | def 
node_force(self, src, dst): 22 | # return the force from src to dst 23 | f = dst - src 24 | f_norm = np.linalg.norm(f) + 0.00001 25 | f = f / f_norm / f_norm ** 2 26 | return f 27 | 28 | def wall_force(self, dst): 29 | f = np.zeros(self.dim) 30 | for i in range(self.dim): 31 | x = dst[i] 32 | # no force if far away 33 | if math.fabs(x) < 0.01: 34 | continue 35 | 36 | f_tmp = np.zeros(self.dim) 37 | f_tmp[i] = -1 * x * self.node_num/1.5 38 | f = f + f_tmp 39 | return f 40 | 41 | def get_total_node_force(self): 42 | force = np.zeros((self.node_num, self.dim)) 43 | for j in range(self.node_num): 44 | dst = self.node[j] 45 | for k in range(self.node_num): 46 | force[j] += self.node_force(self.node[k], dst) 47 | return force 48 | 49 | def get_total_wall_force(self): 50 | force = np.zeros((self.node_num, self.dim)) 51 | for j in range(self.node_num): 52 | dst = self.node[j] 53 | force[j] += self.wall_force(dst) 54 | return force 55 | 56 | def optimize(self): 57 | for i in range(100): 58 | learning_rate = self.learning_rate 59 | 60 | # cumulate the force 61 | force = np.zeros((self.node_num, self.dim)) 62 | for j in range(self.node_num): 63 | dst = self.node[j] 64 | force[j] += self.wall_force(dst) 65 | 66 | for k in range(self.node_num): 67 | force[j] += self.node_force(self.node[k], dst) 68 | 69 | # apply the force 70 | self.node += force * learning_rate 71 | 72 | self.reorder() 73 | 74 | def reorder(self): 75 | node_ordered = self.node[self.node[:, 0].argsort()] 76 | 77 | rows = int(math.sqrt(self.node_num)) 78 | cols = rows 79 | node_ordered = node_ordered.reshape((rows, cols, self.dim)) 80 | for i in range(rows): 81 | node_row = node_ordered[i] 82 | node_row = node_row[node_row[:, 1].argsort()] 83 | node_ordered[i] = node_row 84 | node_ordered = node_ordered.reshape((self.node_num, self.dim)) 85 | 86 | self.node = node_ordered 87 | 88 | -------------------------------------------------------------------------------- /data/build_som/util/som.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | import numbers 3 | import os 4 | import os.path 5 | import numpy as np 6 | import struct 7 | import math 8 | import time 9 | import gc 10 | 11 | import torch 12 | import torchvision 13 | 14 | from . import potential_field 15 | 16 | 17 | class SOM(): 18 | def __init__(self, rows=4, cols=4, dim=3, gpu_id=-1): 19 | ''' 20 | Cannot be used inside a dataloader, because the dataloader keeps only 1 class instance. So this should be used offline, 21 | to save the SOM result into a numpy array. 22 | :param rows: 23 | :param cols: 24 | :param dim: 25 | :param gpu_id: 26 | ''' 27 | self.rows = rows 28 | self.cols = cols 29 | self.dim = dim 30 | self.node_num = rows * cols 31 | 32 | self.sigma = 0.4 33 | self.learning_rate = 0.5 34 | self.max_iteration = 60 35 | 36 | self.gpu_id = gpu_id 37 | self.device = torch.device("cuda:%d" % gpu_id if gpu_id >= 0 else "cpu") 38 | 39 | # node: Cx(rowsxcols), tensor 40 | self.node = torch.FloatTensor(self.dim, self.rows * self.cols).zero_() 41 | self.node_idx_list = torch.from_numpy(np.arange(self.rows * self.cols).astype(np.float32)) 42 | self.init_weighting_matrix = torch.FloatTensor(self.node_num, self.rows, self.cols) # node_numxrowsxcols 43 | if self.gpu_id >= 0: 44 | self.node = self.node.to(self.device) 45 | self.node_idx_list = self.node_idx_list.to(self.device) 46 | self.init_weighting_matrix = self.init_weighting_matrix.to(self.device) 47 | 48 | self.get_init_weighting_matrix() 49 | 50 | # initialize the node by potential field 51 | pf = potential_field.PotentialField(self.node_num, self.dim) 52 | pf.optimize() 53 | self.node_init_value = torch.from_numpy(pf.node.transpose().astype(np.float32)) 54 | 55 | def node_init(self): 56 | self.node.copy_(self.node_init_value) 57 | 58 | def get_init_weighting_matrix(self): 59 | ''' 60 | Get the initial weighting matrix; later weighting matrices are derived from it. 
61 | ''' 62 | for idx in range(self.rows * self.cols): 63 | (i, j) = self.idx2multi(idx) 64 | self.init_weighting_matrix[idx, :] = self.gaussian((i, j), self.sigma) 65 | if self.gpu_id >= 0: 66 | self.init_weighting_matrix = self.init_weighting_matrix.to(self.device) 67 | 68 | def get_weighting_matrix(self, sigma): 69 | scale = 1.0 / ((sigma / self.sigma) ** 2) 70 | weighting_matrix = torch.exp(torch.log(self.init_weighting_matrix) * scale) 71 | return weighting_matrix 72 | 73 | def gaussian(self, c, sigma): 74 | """Returns a Gaussian centered in c""" 75 | d = 2 * np.pi * sigma * sigma 76 | ax = np.exp(-np.power(np.arange(self.rows) - c[0], 2) / d) 77 | ay = np.exp(-np.power(np.arange(self.cols) - c[1], 2) / d) 78 | return torch.from_numpy(np.outer(ax, ay).astype(np.float32)) 79 | 80 | def idx2multi(self, i): 81 | return (i // self.cols, i % self.cols) 82 | 83 | def query(self, x): 84 | ''' 85 | :param x: input data CxN tensor 86 | :return: mask: Nxnode_num 87 | ''' 88 | # expand as CxNxnode_num 89 | node = self.node.unsqueeze(1).expand(x.size(0), x.size(1), self.rows * self.cols) 90 | x_expanded = x.unsqueeze(2).expand_as(node) 91 | 92 | # calcuate difference between x and each node 93 | diff = x_expanded - node # CxNxnode_num 94 | diff_norm = (diff ** 2).sum(dim=0) # Nxnode_num 95 | 96 | # find the nearest neighbor 97 | _, min_idx = torch.min(diff_norm, dim=1) # N 98 | min_idx_expanded = min_idx.unsqueeze(1).expand(min_idx.size()[0], self.rows * self.cols).float() # Nxnode_num 99 | 100 | node_idx_list = self.node_idx_list.unsqueeze(0).expand_as(min_idx_expanded) # Nxnode_num 101 | mask = torch.eq(min_idx_expanded, node_idx_list).float() # Nxnode_num 102 | mask_row_max, _ = torch.max(mask, dim=0) # node_num, this indicates whether the node has nearby x 103 | 104 | return mask, mask_row_max 105 | 106 | def batch_update(self, x, iteration): 107 | # x is CxN tensor, C==self.dim, W=1 108 | assert (x.size()[0] == self.dim) 109 | 110 | # get learning_rate and sigma 111 
| learning_rate = self.learning_rate / (1 + 2 * iteration / self.max_iteration) 112 | sigma = self.sigma / (1 + 2 * iteration / self.max_iteration) 113 | 114 | # expand as CxNxnode_num 115 | node = self.node.unsqueeze(1).expand(x.size(0), x.size(1), self.rows * self.cols) 116 | x_expanded = x.unsqueeze(2).expand_as(node) 117 | 118 | # calcuate difference between x and each node 119 | diff = x_expanded - node # CxNxnode_num 120 | diff_norm = (diff ** 2).sum(dim=0) # Nxnode_num 121 | 122 | # find the nearest neighbor 123 | _, min_idx = torch.min(diff_norm, dim=1) # N 124 | min_idx_expanded = min_idx.unsqueeze(1).expand(min_idx.size()[0], self.rows * self.cols).float() # Nxnode_num 125 | 126 | node_idx_list = self.node_idx_list.unsqueeze(0).expand_as(min_idx_expanded) # Nxnode_num 127 | mask = torch.eq(min_idx_expanded, node_idx_list).float() # Nxnode_num 128 | mask_row_sum = torch.sum(mask, dim=0) + 0.00001 # node_num 129 | mask_row_max, _ = torch.max(mask, dim=0) # node_num, this indicates whether the node has nearby x 130 | 131 | # calculate the mean x for each node 132 | x_expanded_masked = x_expanded * mask.unsqueeze(0).expand_as(x_expanded) # CxNxnode_num 133 | x_expanded_masked_sum = torch.sum(x_expanded_masked, dim=1) # Cxnode_num 134 | x_expanded_mask_mean = x_expanded_masked_sum / mask_row_sum.unsqueeze(0).expand_as( 135 | x_expanded_masked_sum) # Cxnode_num 136 | 137 | # each x_expanded_mask_mean (in total node_num vectors) will calculate its diff with all nodes 138 | # multiply the mask_row_max, so that the isolated node won't be pulled to the center 139 | x_expanded_mask_mean_expanded = x_expanded_mask_mean.unsqueeze(2).expand(self.dim, self.rows * self.cols, 140 | self.rows * self.cols) # Cxnode_numxnode_num 141 | node_expanded_transposed = self.node.unsqueeze(1).expand_as(x_expanded_mask_mean_expanded) # .transpose(1,2) 142 | diff_masked_mean = x_expanded_mask_mean_expanded - node_expanded_transposed # Cxnode_numxnode_num 143 | diff_masked_mean = 
diff_masked_mean * mask_row_max.unsqueeze(1).unsqueeze(0).expand_as(diff_masked_mean) 144 | 145 | # compute the neighrbor weighting 146 | # weighting_matrix = torch.FloatTensor(self.rows*self.cols, self.rows, self.cols) # node_numxrowsxcols 147 | # for idx in range(self.rows*self.cols): 148 | # (i,j) = self.idx2multi(idx) 149 | # weighting_matrix[idx,:] = self.gaussian((i,j), sigma) 150 | # if self.gpu_id >= 0: 151 | # weighting_matrix = weighting_matrix.to(self.device) 152 | # compute the neighrbor weighting using pre-computed matrix 153 | weighting_matrix = self.get_weighting_matrix(sigma) # node_numxrowsxcols 154 | 155 | # compute the update 156 | weighting_matrix = weighting_matrix.unsqueeze(0).expand(self.dim, self.node_num, self.rows, 157 | self.cols) # Cxnode_numxrowsxcols 158 | diff_masked_mean_matrix_view = diff_masked_mean.view(self.dim, self.node_num, self.rows, self.cols) 159 | delta = diff_masked_mean_matrix_view * weighting_matrix * learning_rate # Cxnode_numxrowsxcols 160 | delta = delta.sum(dim=1) 161 | 162 | # apply the update 163 | node_matrix_view = self.node.view(self.dim, self.rows, self.cols) # Cxrowsxcols 164 | node_matrix_view += delta 165 | 166 | # print(self.node) 167 | 168 | def optimize(self, x): 169 | self.node_init() 170 | for iter in range(int(self.max_iteration / 3)): 171 | self.batch_update(x, 0) 172 | for iter in range(self.max_iteration): 173 | self.batch_update(x, iter) 174 | 175 | 176 | class BatchSOM(): 177 | def __init__(self, rows=4, cols=4, dim=3, gpu_id=0, batch_size=10): 178 | self.rows = rows 179 | self.cols = cols 180 | self.dim = dim 181 | self.node_num = rows * cols 182 | 183 | self.sigma = 0.4 184 | self.learning_rate = 0.5 185 | self.max_iteration = 30 186 | 187 | self.gpu_id = gpu_id 188 | self.device = torch.device("cuda:%d" % gpu_id) 189 | self.batch_size = batch_size 190 | 191 | # node: BxCx(rowsxcols), tensor 192 | self.node = torch.FloatTensor(self.batch_size, self.dim, self.rows * self.cols).zero_() 193 | 
self.node_idx_list = torch.from_numpy(np.arange(self.node_num).astype(np.int64)) # node_num LongTensor 194 | self.init_weighting_matrix = torch.FloatTensor(self.node_num, self.rows, self.cols) # node_numxrowsxcols 195 | if self.gpu_id >= 0: 196 | self.node = self.node.to(self.device) 197 | self.node_idx_list = self.node_idx_list.to(self.device) 198 | 199 | # get initial weighting matrix 200 | self.get_init_weighting_matrix() 201 | 202 | # initialize the node by potential field 203 | pf = potential_field.PotentialField(self.node_num, self.dim) 204 | pf.optimize() 205 | self.node_init_value = torch.from_numpy(pf.node.transpose().astype(np.float32)) 206 | 207 | def node_init(self, batch_size): 208 | self.batch_size = batch_size 209 | self.node.resize_(self.batch_size, self.dim, self.node_num) 210 | self.node.copy_(torch.unsqueeze(self.node_init_value, dim=0).expand_as(self.node)) 211 | 212 | def gaussian(self, c, sigma): 213 | """Returns a Gaussian centered in c""" 214 | d = 2 * np.pi * sigma * sigma 215 | ax = np.exp(-np.power(np.arange(self.rows) - c[0], 2) / d) 216 | ay = np.exp(-np.power(np.arange(self.cols) - c[1], 2) / d) 217 | return torch.from_numpy(np.outer(ax, ay).astype(np.float32)) 218 | 219 | def get_init_weighting_matrix(self): 220 | ''' 221 | get the initial weighting matrix, later the weighting matrix wil base on the init. 
222 | ''' 223 | for idx in range(self.rows * self.cols): 224 | (i, j) = self.idx2multi(idx) 225 | self.init_weighting_matrix[idx, :] = self.gaussian((i, j), self.sigma) 226 | if self.gpu_id >= 0: 227 | self.init_weighting_matrix = self.init_weighting_matrix.to(self.device) 228 | 229 | def get_weighting_matrix(self, sigma): 230 | scale = 1.0 / ((sigma / self.sigma) ** 2) 231 | weighting_matrix = torch.exp(torch.log(self.init_weighting_matrix) * scale) 232 | return weighting_matrix 233 | 234 | def idx2multi(self, i): 235 | return (i // self.cols, i % self.cols) 236 | 237 | def query_topk(self, x, k): 238 | ''' 239 | :param x: input data BxCxN tensor 240 | :param k: number of nearest nodes to find for each point 241 | :return: mask: BxkNxnode_num for k = 2 or 3 (BxNxnode_numxk otherwise), mask_row_max, min_idx 242 | ''' 243 | 244 | # expand as BxCxNxnode_num 245 | node = self.node.unsqueeze(2).expand(x.size(0), x.size(1), x.size(2), self.rows * self.cols) 246 | x_expanded = x.unsqueeze(3).expand_as(node) 247 | 248 | # calculate difference between x and each node 249 | diff = x_expanded - node # BxCxNxnode_num 250 | diff_norm = (diff ** 2).sum(dim=1) # BxNxnode_num 251 | 252 | # find the k nearest nodes 253 | _, min_idx = torch.topk(diff_norm, k=k, dim=2, largest=False, sorted=False) # BxNxk 254 | min_idx_expanded = min_idx.unsqueeze(2).expand(min_idx.size()[0], min_idx.size()[1], self.rows * self.cols, 255 | k) # BxNxnode_numxk 256 | 257 | node_idx_list = self.node_idx_list.unsqueeze(0).unsqueeze(0).unsqueeze(3).expand_as( 258 | min_idx_expanded).long() # BxNxnode_numxk 259 | mask = torch.eq(min_idx_expanded, node_idx_list).float() # BxNxnode_numxk 260 | # mask = torch.sum(mask, dim=3) # BxNxnode_num 261 | if k == 2: 262 | mask = torch.cat((mask[..., 0], mask[..., 1]), dim=1) # BxkNxnode_num 263 | min_idx = torch.cat((min_idx[..., 0], min_idx[..., 1]), dim=1) # BxkN 264 | elif k == 3: 265 | mask = torch.cat((mask[..., 0], mask[..., 1], mask[..., 2]), dim=1) # BxkNxnode_num 266 | min_idx = torch.cat((min_idx[..., 0], min_idx[..., 1], min_idx[..., 2]), dim=1) # BxkN 267 | 
mask_row_max, _ = torch.max(mask, dim=1) # Bxnode_num, this indicates whether the node has nearby x 268 | 269 | return mask, mask_row_max, min_idx 270 | 271 | def query(self, x): 272 | ''' 273 | :param x: input data BxCxN tensor 274 | :return: mask: BxNxnode_num, mask_row_max: Bxnode_num 275 | ''' 276 | # expand as BxCxNxnode_num 277 | node = self.node.unsqueeze(2).expand(x.size(0), x.size(1), x.size(2), self.rows * self.cols) 278 | x_expanded = x.unsqueeze(3).expand_as(node) 279 | 280 | # calculate difference between x and each node 281 | diff = x_expanded - node # BxCxNxnode_num 282 | diff_norm = (diff ** 2).sum(dim=1) # BxNxnode_num 283 | 284 | # find the nearest neighbor 285 | _, min_idx = torch.min(diff_norm, dim=2) # BxN 286 | min_idx_expanded = min_idx.unsqueeze(2).expand(min_idx.size()[0], min_idx.size()[1], 287 | self.rows * self.cols) # BxNxnode_num 288 | 289 | node_idx_list = self.node_idx_list.unsqueeze(0).unsqueeze(0).expand_as(min_idx_expanded).long() # BxNxnode_num 290 | mask = torch.eq(min_idx_expanded, node_idx_list).float() # BxNxnode_num 291 | mask_row_max, _ = torch.max(mask, dim=1) # Bxnode_num, this indicates whether the node has nearby x 292 | 293 | return mask, mask_row_max 294 | 295 | def batch_update(self, x, learning_rate, sigma): 296 | # x is BxCxN tensor, C==self.dim 297 | assert (x.size()[1] == self.dim) 298 | assert (x.size()[0] == self.batch_size) 299 | 300 | # expand as BxCxNxnode_num 301 | node = self.node.unsqueeze(2).expand(x.size(0), x.size(1), x.size(2), self.rows * self.cols) 302 | x_expanded = x.unsqueeze(3).expand_as(node) 303 | 304 | # calculate difference between x and each node 305 | diff = x_expanded - node # BxCxNxnode_num 306 | diff_norm = (diff ** 2).sum(dim=1) # BxNxnode_num 307 | 308 | # find the nearest neighbor 309 | _, min_idx = torch.min(diff_norm, dim=2) # BxN 310 | min_idx_expanded = min_idx.unsqueeze(2).expand(min_idx.size()[0], min_idx.size()[1], 311 | self.rows * self.cols) # BxNxnode_num 312 | 313 | node_idx_list = 
self.node_idx_list.unsqueeze(0).unsqueeze(0).expand_as(min_idx_expanded).long() # BxNxnode_num 314 | mask = torch.eq(min_idx_expanded, node_idx_list).float() # BxNxnode_num 315 | mask_row_sum = torch.sum(mask, dim=1) + 0.00001 # Bxnode_num 316 | mask_row_max, _ = torch.max(mask, dim=1) # Bxnode_num, this indicates whether the node has nearby x 317 | 318 | # calculate the mean x for each node 319 | x_expanded_masked = x_expanded * mask.unsqueeze(1).expand_as(x_expanded) # BxCxNxnode_num 320 | x_expanded_masked_sum = torch.sum(x_expanded_masked, dim=2) # BxCxnode_num 321 | x_expanded_mask_mean = x_expanded_masked_sum / mask_row_sum.unsqueeze(1).expand_as( 322 | x_expanded_masked_sum) # Cxnode_num 323 | 324 | # each x_expanded_mask_mean (in total node_num vectors) will calculate its diff with all nodes 325 | # multiply the mask_row_max, so that the isolated node won't be pulled to the center 326 | x_expanded_mask_mean_expanded = x_expanded_mask_mean.unsqueeze(3).expand(self.batch_size, self.dim, 327 | self.rows * self.cols, 328 | self.rows * self.cols) # BxCxnode_numxnode_num 329 | node_expanded_transposed = self.node.unsqueeze(2).expand_as(x_expanded_mask_mean_expanded) # .transpose(1,2) 330 | diff_masked_mean = x_expanded_mask_mean_expanded - node_expanded_transposed # BxCxnode_numxnode_num 331 | diff_masked_mean = diff_masked_mean * mask_row_max.unsqueeze(2).unsqueeze(1).expand_as(diff_masked_mean) 332 | 333 | # compute the neighrbor weighting using pre-computed matrix 334 | weighting_matrix = self.get_weighting_matrix(sigma) 335 | 336 | # expand weighting_matrix to be batch, Bxnode_numxrowsxcols 337 | weighting_matrix = weighting_matrix.unsqueeze(0).expand(self.batch_size, self.rows * self.cols, self.rows, 338 | self.cols) 339 | 340 | # compute the update 341 | weighting_matrix = weighting_matrix.unsqueeze(1).expand(self.batch_size, self.dim, self.rows * self.cols, 342 | self.rows, self.cols) # BxCxnode_numxrowsxcols 343 | diff_masked_mean_matrix_view = 
diff_masked_mean.view(self.batch_size, self.dim, self.rows * self.cols, 344 | self.rows, self.cols) 345 | delta = diff_masked_mean_matrix_view * weighting_matrix * learning_rate # BxCxnode_numxrowsxcols 346 | delta = delta.sum(dim=2) 347 | 348 | # apply the update 349 | node_matrix_view = self.node.view(self.batch_size, self.dim, self.rows, self.cols) # BxCxrowsxcols 350 | node_matrix_view += delta 351 | 352 | # print(self.node) 353 | # print(delta.max()) 354 | 355 | def optimize(self, x): 356 | self.node_init(x.size()[0]) 357 | for iter in range(int(self.max_iteration / 3)): 358 | # warm-up: constant learning_rate and sigma 359 | learning_rate = self.learning_rate 360 | sigma = self.sigma 361 | self.batch_update(x, learning_rate, sigma) 362 | for iter in range(self.max_iteration): 363 | # decay learning_rate and sigma with the iteration count 364 | learning_rate = self.learning_rate / (1 + iter / self.max_iteration) 365 | sigma = self.sigma / (1 + iter / self.max_iteration) 366 | self.batch_update(x, learning_rate, sigma) 367 | 368 | 369 | -------------------------------------------------------------------------------- /data/modelnet_shrec_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | import random 4 | import numbers 5 | import os 6 | import os.path 7 | import numpy as np 8 | import struct 9 | import math 10 | 11 | import torch 12 | import torchvision 13 | import matplotlib.pyplot as plt 14 | import h5py 15 | import faiss 16 | 17 | from .augmentation import * 18 | 19 | 20 | # Read numpy array data and label from h5_filename 21 | def load_h5(h5_filename): 22 | f = h5py.File(h5_filename, 'r') 23 | data = f['data'][:] 24 | label = f['label'][:] 25 | return (data, label) 26 | 27 | 28 | def make_dataset_modelnet40_10k(root, mode, opt): 29 | dataset = [] 30 | rows = round(math.sqrt(opt.node_num)) 31 | cols = rows 32 | 33 | f = open(os.path.join(root, 'modelnet%d_shape_names.txt' % opt.classes)) 34 | shape_list = 
[str.rstrip() for str in f.readlines()] 35 | f.close() 36 | 37 | if 'train' == mode: 38 | f = open(os.path.join(root, 'modelnet%d_train.txt' % opt.classes), 'r') 39 | lines = [str.rstrip() for str in f.readlines()] 40 | f.close() 41 | elif 'test' == mode: 42 | f = open(os.path.join(root, 'modelnet%d_test.txt' % opt.classes), 'r') 43 | lines = [str.rstrip() for str in f.readlines()] 44 | f.close() 45 | else: 46 | raise Exception('Network mode error.') 47 | 48 | for i, name in enumerate(lines): 49 | # locate the folder name 50 | folder = name[0:-5] 51 | file_name = name 52 | 53 | # get the label 54 | label = shape_list.index(folder) 55 | 56 | # som node locations 57 | som_nodes_folder = '%dx%d_som_nodes' % (rows, cols) 58 | 59 | item = (os.path.join(root, folder, file_name + '.npy'), 60 | label, 61 | os.path.join(root, som_nodes_folder, folder, file_name + '.npy')) 62 | dataset.append(item) 63 | 64 | return dataset 65 | 66 | 67 | def make_dataset_shrec2016(root, mode, opt): 68 | rows = round(math.sqrt(opt.node_num)) 69 | cols = rows 70 | dataset = [] 71 | 72 | # load category txt 73 | f = open(os.path.join(root, 'category.txt'), 'r') 74 | category_list = [str.rstrip() for str in f.readlines()] 75 | f.close() 76 | 77 | if 'train'==mode: 78 | f = open(os.path.join(root, 'train.txt'), 'r') 79 | lines = [str.rstrip() for str in f.readlines()] 80 | f.close() 81 | elif 'val'==mode: 82 | f = open(os.path.join(root, 'val.txt'), 'r') 83 | lines = [str.rstrip() for str in f.readlines()] 84 | f.close() 85 | elif 'test'==mode: 86 | f = open(os.path.join(root, 'test.txt'), 'r') 87 | lines = [str.rstrip() for str in f.readlines()] 88 | f.close() 89 | else: 90 | raise Exception('Network mode error.') 91 | 92 | if 'train' == mode or 'val' == mode: 93 | for line in lines: 94 | line_split = [x.strip() for x in line.split(',')] 95 | name, category = line_split[0], line_split[1] 96 | 97 | npz_file = os.path.join(root, '%dx%d'%(rows,cols), mode, 'model_'+name+'.npz') 98 | try: 99 | 
category = category_list.index(category) 100 | except ValueError: 101 | continue 102 | 103 | item = (npz_file, category) 104 | dataset.append(item) 105 | elif 'test' == mode: 106 | for line in lines: 107 | name, category = line, int(line) % 55 108 | npz_file = os.path.join(root, '%dx%d'%(rows,cols), mode, 'model_'+name+'.npz') 109 | 110 | item = (npz_file, category) 111 | dataset.append(item) 112 | 113 | return dataset 114 | 115 | 116 | class KNNBuilder: 117 | def __init__(self, k): 118 | self.k = k 119 | self.dimension = 3 120 | 121 | def build_nn_index(self, database): 122 | ''' 123 | :param database: numpy array of Nx3 124 | :return: Faiss index, in CPU 125 | ''' 126 | index = faiss.IndexFlatL2(self.dimension) # dimension is 3 127 | index.add(database) 128 | return index 129 | 130 | def search_nn(self, index, query, k): 131 | ''' 132 | :param index: Faiss index 133 | :param query: numpy array of Nx3 134 | :return: D: numpy array of Nxk 135 | I: numpy array of Nxk 136 | ''' 137 | D, I = index.search(query, k) 138 | return D, I 139 | 140 | def self_build_search(self, x): 141 | ''' 142 | 143 | :param x: numpy array of Nxd 144 | :return: D: numpy array of Nxk 145 | I: numpy array of Nxk 146 | ''' 147 | x = np.ascontiguousarray(x, dtype=np.float32) 148 | index = self.build_nn_index(x) 149 | D, I = self.search_nn(index, x, self.k) 150 | return D, I 151 | 152 | 153 | class FarthestSampler: 154 | def __init__(self): 155 | pass 156 | 157 | def calc_distances(self, p0, points): 158 | return ((p0 - points) ** 2).sum(axis=1) 159 | 160 | def sample(self, pts, k): 161 | farthest_pts = np.zeros((k, 3)) 162 | farthest_pts[0] = pts[np.random.randint(len(pts))] 163 | distances = self.calc_distances(farthest_pts[0], pts) 164 | for i in range(1, k): 165 | farthest_pts[i] = pts[np.argmax(distances)] 166 | distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts)) 167 | return farthest_pts 168 | 169 | 170 | class ModelNet_Shrec_Loader(data.Dataset): 171 | def 
__init__(self, root, mode, opt): 172 | super(ModelNet_Shrec_Loader, self).__init__() 173 | self.root = root 174 | self.opt = opt 175 | self.mode = mode 176 | 177 | if self.opt.dataset == 'modelnet': 178 | self.dataset = make_dataset_modelnet40_10k(self.root, mode, opt) 179 | elif self.opt.dataset == 'shrec': 180 | self.dataset = make_dataset_shrec2016(self.root, mode, opt) 181 | else: 182 | raise Exception('Dataset incorrect.') 183 | 184 | # kNN search on SOM nodes 185 | self.knn_builder = KNNBuilder(self.opt.som_k) 186 | 187 | # farthest point sample 188 | self.fathest_sampler = FarthestSampler() 189 | 190 | def __len__(self): 191 | return len(self.dataset) 192 | 193 | def __getitem__(self, index): 194 | if self.opt.dataset == 'modelnet': 195 | pc_np_file, class_id, som_node_np_file = self.dataset[index] 196 | 197 | data = np.load(pc_np_file) 198 | data = data[np.random.choice(data.shape[0], self.opt.input_pc_num, replace=False), :] 199 | 200 | pc_np = data[:, 0:3] # Nx3 201 | surface_normal_np = data[:, 3:6] # Nx3 202 | som_node_np = np.load(som_node_np_file) # node_numx3 203 | elif self.opt.dataset == 'shrec': 204 | npz_file, class_id = self.dataset[index] 205 | data = np.load(npz_file) 206 | 207 | pc_np = data['pc'] 208 | surface_normal_np = data['sn'] 209 | som_node_np = data['som_node'] 210 | 211 | # random choice 212 | choice_idx = np.random.choice(pc_np.shape[0], self.opt.input_pc_num, replace=False) 213 | pc_np = pc_np[choice_idx, :] 214 | surface_normal_np = surface_normal_np[choice_idx, :] 215 | else: 216 | raise Exception('Dataset incorrect.') 217 | 218 | # augmentation 219 | if self.mode == 'train': 220 | # rotate by 0/90/180/270 degree over z axis 221 | # pc_np = rotate_point_cloud_90(pc_np) 222 | # som_node_np = rotate_point_cloud_90(som_node_np) 223 | 224 | # rotation perturbation, pc and som should follow the same rotation, surface normal rotation is unclear 225 | if self.opt.rot_horizontal: 226 | pc_np, surface_normal_np, som_node_np = 
rotate_point_cloud_with_normal_som(pc_np, surface_normal_np, som_node_np) 227 | if self.opt.rot_perturbation: 228 | pc_np, surface_normal_np, som_node_np = rotate_perturbation_point_cloud_with_normal_som(pc_np, surface_normal_np, som_node_np) 229 | 230 | # random jittering 231 | pc_np = jitter_point_cloud(pc_np) 232 | surface_normal_np = jitter_point_cloud(surface_normal_np) 233 | som_node_np = jitter_point_cloud(som_node_np, sigma=0.04, clip=0.1) 234 | 235 | # random scale 236 | scale = np.random.uniform(low=0.8, high=1.2) 237 | pc_np = pc_np * scale 238 | som_node_np = som_node_np * scale 239 | surface_normal_np = surface_normal_np * scale 240 | 241 | # random shift 242 | if self.opt.translation_perturbation: 243 | shift = np.random.uniform(-0.1, 0.1, (1,3)) 244 | pc_np += shift 245 | som_node_np += shift 246 | 247 | # convert to tensor 248 | pc = torch.from_numpy(pc_np.transpose().astype(np.float32)) # 3xN 249 | 250 | # surface normal 251 | surface_normal = torch.from_numpy(surface_normal_np.transpose().astype(np.float32)) # 3xN 252 | 253 | # som 254 | som_node = torch.from_numpy(som_node_np.transpose().astype(np.float32)) # 3xnode_num 255 | 256 | # kNN search: som -> som 257 | if self.opt.som_k >= 2: 258 | D, I = self.knn_builder.self_build_search(som_node_np) 259 | som_knn_I = torch.from_numpy(I.astype(np.int64)) # node_num x som_k 260 | else: 261 | som_knn_I = torch.from_numpy(np.arange(start=0, stop=self.opt.node_num, dtype=np.int64).reshape((self.opt.node_num, 1))) # node_num x 1 262 | 263 | # print(som_node_np) 264 | # print(D) 265 | # print(I) 266 | # assert False 267 | 268 | if self.opt.dataset == 'shrec': 269 | return pc, surface_normal, class_id, som_node, som_knn_I, index 270 | else: 271 | return pc, surface_normal, class_id, som_node, som_knn_I 272 | 273 | 274 | if __name__=="__main__": 275 | # dataset = make_dataset_modelnet40('/ssd/dataset/modelnet40_ply_hdf5_2048/', True) 276 | # print(len(dataset)) 277 | # print(dataset[0]) 278 | 279 | 280 | 
class VirtualOpt(): 281 | def __init__(self): 282 | self.load_all_data = False 283 | self.input_pc_num = 5000 284 | self.batch_size = 8 285 | self.dataset = 'modelnet' 286 | self.node_num = 64 287 | self.classes = 10 288 | self.som_k = 9 289 | self.rot_horizontal = self.rot_perturbation = self.translation_perturbation = False 290 | opt = VirtualOpt() 291 | trainset = ModelNet_Shrec_Loader('/ssd/dataset/modelnet40-normal_numpy/', 'train', opt) 292 | print('---') 293 | print(len(trainset)) 294 | print(trainset[0]) 295 | 296 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=False, num_workers=4) 297 | -------------------------------------------------------------------------------- /data/sampler_matlab/RemoveWhiteSpace.m: -------------------------------------------------------------------------------- 1 | function u_out = RemoveWhiteSpace(u_in, varargin) 2 | % February 2nd, 2012, By Reza Farrahi Moghaddam, Synchromedia Lab, ETS, Montreal, Canada 3 | % 4 | % RemoveWhiteSpace function removes white spaces around an image. 5 | % 6 | % Syntax: 7 | % 1. For an image: u_out = RemoveWhiteSpace(u_in) 8 | % 9 | % 2. For an image file, to write the result on the same file: RemoveWhiteSpace([], 'file', input_filename) 10 | % 11 | % 3. 
For an image file, to make a new output file: RemoveWhiteSpace([], 'file', input_filename, 'output', output_filename) 12 | % 13 | 14 | % get the arguments 15 | [it_is_a_file_flag, input_filename, output_filename] = check_the_argin_infile(nargin, varargin{:}); 16 | 17 | % 18 | if (it_is_a_file_flag) 19 | [u_in, map] = imread(input_filename); 20 | if (numel(map) ~= 0) 21 | u_in = ind2rgb(u_in, map); 22 | end 23 | end 24 | u_in = mat2gray_infile(u_in); 25 | [xm ym zm] = size(u_in); 26 | 27 | % 28 | if (zm == 3) 29 | % u_gray = rgb2gray(u_in); 30 | u_gray = mean(u_in, 3); 31 | else 32 | u_gray = mean(u_in, 3); 33 | end 34 | 35 | % 36 | u_white_mask = u_gray > 0.99; 37 | u_white_mask_hori = reshape(sum(u_white_mask, 1), [], 1); 38 | u_white_mask_vert = sum(u_white_mask, 2); 39 | % 40 | u_white_mask_hori(u_white_mask_hori < xm / 2) = 0; 41 | u_white_mask_vert(u_white_mask_vert < ym / 2) = 0; 42 | u_white_mask_hori_diff = diff(u_white_mask_hori); 43 | u_white_mask_vert_diff = diff(u_white_mask_vert); 44 | [~, boundingbox_hori] = findpeaks(abs(u_white_mask_hori_diff)); 45 | [~, boundingbox_vert] = findpeaks(abs(diff(u_white_mask_vert))); 46 | if (numel(boundingbox_hori) == 0) 47 | boundingbox_hori = [0 ym]; 48 | elseif (numel(boundingbox_hori) == 1) 49 | if (boundingbox_hori > ym / 2) 50 | boundingbox_hori = [0, boundingbox_hori]; 51 | 52 | else 53 | boundingbox_hori = [boundingbox_hori, ym]; 54 | end 55 | else 56 | boundingbox_hori = boundingbox_hori([1, end]); 57 | boundingbox_hori = boundingbox_hori .* [- u_white_mask_hori_diff(boundingbox_hori(1)) > 0; u_white_mask_hori_diff(boundingbox_hori(2)) > 0]; 58 | end 59 | if (numel(boundingbox_vert) == 0) 60 | boundingbox_vert = [0 xm]; 61 | elseif (numel(boundingbox_vert) == 1) 62 | if (boundingbox_vert > xm / 2) 63 | boundingbox_vert = [0, boundingbox_vert]; 64 | else 65 | boundingbox_vert = [boundingbox_vert, xm]; 66 | end 67 | else 68 | boundingbox_vert = boundingbox_vert([1, end]); 69 | boundingbox_vert = 
boundingbox_vert .* [- u_white_mask_vert_diff(boundingbox_vert(1)) > 0; u_white_mask_vert_diff(boundingbox_vert(2)) > 0]; 70 | end 71 | boundingbox_hori(1) = boundingbox_hori(1) + 1; 72 | boundingbox_vert(1) = boundingbox_vert(1) + 1; 73 | if (boundingbox_hori(2) == 0) 74 | boundingbox_hori(2) = ym; 75 | end 76 | if (boundingbox_vert(2) == 0) 77 | boundingbox_vert(2) = xm; 78 | end 79 | % 80 | u_out = u_in(boundingbox_vert(1) : boundingbox_vert(2), boundingbox_hori(1) : boundingbox_hori(2), :); 81 | 82 | % 83 | if (it_is_a_file_flag) 84 | imwrite(u_out, output_filename); 85 | end 86 | 87 | end 88 | 89 | 90 | function [it_is_a_file_flag, input_filename, output_filename] = check_the_argin_infile(nargin, varargin) 91 | % 120202: Reza 92 | 93 | % % 94 | it_is_a_file_flag = false; 95 | input_filename = ''; 96 | output_filename = ''; 97 | 98 | % 99 | default_fields = {'file', 'output'}; 100 | for temp_label = 1 : 2 : (nargin - 1) 101 | parameter_name = varargin{temp_label}; 102 | parameter_val = varargin{temp_label + 1}; 103 | matched_field = find(strcmpi(parameter_name, default_fields)); 104 | if isempty(matched_field) 105 | error('Error: Unknown argument: %s.\n', parameter_name); 106 | else 107 | switch(matched_field) 108 | case 1 % input_filename 109 | input_filename = parameter_val; 110 | output_filename = parameter_val; 111 | it_is_a_file_flag = true; 112 | case 2 % output_filename 113 | output_filename = parameter_val; 114 | it_is_a_file_flag = true; 115 | end 116 | end 117 | end 118 | 119 | % 120 | if (numel(output_filename) > 0)&&(numel(input_filename) == 0) 121 | input_filename = output_filename; 122 | end 123 | 124 | end 125 | 126 | function u_in = mat2gray_infile(u_in) 127 | % 120212: Reza 128 | 129 | % 130 | u_in = double(u_in); 131 | u_in_min = min(u_in(:)); 132 | u_in_max = max(u_in(:)); 133 | 134 | % 135 | if (u_in_max == u_in_min) 136 | u_in(:) = 1; 137 | else 138 | u_in = (u_in - u_in_min) ./ (u_in_max - u_in_min); 139 | end 140 | 141 | end 142 | 143 | 
144 | %{ 145 | u_in = mat2gray(imread('test.png')); 146 | subfigure(2, 2, [2 2]), imshow(u_in, 'InitialMagnification','fit'); 147 | subfigure(2, 2, [1 1]), imshow(RemoveWhiteSpace(u_in), 'InitialMagnification','fit'); 148 | %} 149 | 150 | %{ 151 | RemoveWhiteSpace([], 'file', 'test.png', 'output', 'test_out.png'); 152 | %} 153 | 154 | %{ 155 | RemoveWhiteSpace([], 'file', 'test.png'); 156 | %} 157 | -------------------------------------------------------------------------------- /data/sampler_matlab/pc_generator.m: -------------------------------------------------------------------------------- 1 | root = '/ssd/dataset/SHREC2016/train/'; 2 | N = 10000; 3 | 4 | root_content = dir(root); 5 | root_content = root_content(3:end); 6 | folder_list = {}; 7 | 8 | for i=1:1:length(root_content) 9 | folder_name = root_content(i).name; 10 | folder_list{i} = folder_name; 11 | 12 | folder_content = dir([root, folder_name]); 13 | folder_content = folder_content(3:end); 14 | 15 | parfor j=1:1:length(folder_content) 16 | if folder_content(j).bytes > 100 && strcmp(folder_content(j).name(end-2:end), 'obj') 17 | obj_file = [root, folder_name, '/', folder_content(j).name]; 18 | [pc, pc_normal] = sampler(obj_file, N); 19 | 20 | % write to txt 21 | dlmwrite([obj_file(1:end-3), 'txt'], [pc, pc_normal], 'delimiter', ' '); 22 | end 23 | end 24 | 25 | % scatter3(pc(:,1), pc(:,2), pc(:,3), 50, pc_normal, 'Marker', '.'); 26 | end 27 | 28 | disp(folder_list) -------------------------------------------------------------------------------- /data/sampler_matlab/pc_generator_test.m: -------------------------------------------------------------------------------- 1 | root = '/ssd/dataset/SHREC2016/obj_txt/test_allinone/'; 2 | N = 10000; 3 | 4 | folder_content = dir(root); 5 | folder_content = folder_content(3:end); 6 | 7 | parfor j=1:1:length(folder_content) 8 | if folder_content(j).bytes > 100 && strcmp(folder_content(j).name(end-2:end), 'obj') 9 | obj_file = [root, folder_content(j).name]; 10 | 
[pc, pc_normal] = sampler(obj_file, N); 11 | 12 | % write to txt 13 | dlmwrite([obj_file(1:end-3), 'txt'], [pc, pc_normal], 'delimiter', ' '); 14 | end 15 | end 16 | 17 | % scatter3(pc(:,1), pc(:,2), pc(:,3), 50, pc_normal, 'Marker', '.'); 18 | 19 | -------------------------------------------------------------------------------- /data/sampler_matlab/read_obj.m: -------------------------------------------------------------------------------- 1 | function [vertex, faces] = read_obj(file) 2 | 3 | % get vertex and faces 4 | text = fileread(file); 5 | text = regexprep(text, '/[0-9]*', ''); 6 | data = textscan(text, '%c %f %f %f%*c'); 7 | 8 | type = data{1}; 9 | values = [data{2}, data{3}, data{4}]; 10 | 11 | v_idx = find(type=='v'); 12 | vertex = values(v_idx, :); 13 | 14 | f_idx = find(type=='f'); 15 | faces = values(f_idx, :); 16 | 17 | end -------------------------------------------------------------------------------- /data/sampler_matlab/sampler.m: -------------------------------------------------------------------------------- 1 | function [pc, pc_normal] = sampler(file, N) 2 | 3 | %% get vertex and faces 4 | [vertex, faces] = read_obj(file); 5 | 6 | %% get triangle area and face normal 7 | triangles_a = vertex(faces(:,1), :); % Fx3 8 | triangles_b = vertex(faces(:,2), :); % Fx3 9 | triangles_c = vertex(faces(:,3), :); % Fx3 10 | 11 | a_b = triangles_b - triangles_a; % Fx3 12 | a_c = triangles_c - triangles_a; % Fx3 13 | 14 | a_bxa_c = cross(a_b, a_c, 2); % Fx3 15 | tmp = sqrt(sum(a_bxa_c.^2, 2)); % Fx1 16 | areas = 0.5 * tmp; % Fx1 17 | normals = a_bxa_c ./ tmp; % Fx3 18 | 19 | %% weighted random sample 20 | sampled_tri_idx = randsample(length(areas), N, true, areas); 21 | 22 | sampled_a = triangles_a(sampled_tri_idx, :); % Nx3 23 | sampled_b = triangles_b(sampled_tri_idx, :); % Nx3 24 | sampled_c = triangles_c(sampled_tri_idx, :); % Nx3 25 | 26 | pc_normal = normals(sampled_tri_idx, :); % Nx3 27 | 28 | %% sample points 29 | u = rand(N, 1); % Nx1 30 | v = rand(N, 
1); 31 | invalid = u + v > 1; 32 | u(invalid) = 1 - u(invalid); 33 | v(invalid) = 1 - v(invalid); 34 | 35 | pc = sampled_a + u .* (sampled_b - sampled_a) + v .* (sampled_c - sampled_a); % A + u*(B-A) + v*(C-A) 36 | 37 | end 38 | 39 | % toc; 40 | % scatter3(pc(:,1), pc(:,2), pc(:,3), 50, pc_normal, 'Marker', '.') 41 | -------------------------------------------------------------------------------- /data/sampler_matlab/visualization.m: -------------------------------------------------------------------------------- 1 | %close all; 2 | clear all; 3 | 4 | query = '000437'; 5 | list_file = ['/ssd/tmp/test_normal/', query]; 6 | 7 | font_size = 18; 8 | fig_size = [8, 6]; 9 | 10 | f = fopen(list_file); 11 | data = textscan(f, '%s %f'); 12 | shape_list = data{1}; 13 | fclose(f); 14 | 15 | for i=1:1:length(shape_list) 16 | candidate = shape_list{i,1}; 17 | obj_file = ['/ssd/dataset/SHREC2016/obj_txt/test_allinone/model_', candidate, '.obj']; 18 | 19 | [vertex, faces] = read_obj(obj_file); 20 | fig = figure('Visible', 'Off'); 21 | trisurf(faces, vertex(:,1), vertex(:,2), vertex(:,3), ... 22 | 'FaceColor', [0.8,0.8,0.8], 'EdgeColor', 'none', 'FaceLighting', 'flat', ... 
23 | 'AmbientStrength', 0.5, 'SpecularColorReflectance', 1); 24 | colormap(gray) 25 | light('Position',[-0.4 0.2 0.9], 'Style', 'infinite') 26 | 27 | lim_max = max(max([xlim;ylim;zlim])); 28 | lim_min = min(min([xlim;ylim;zlim])); 29 | xlim([lim_min, lim_max]); 30 | ylim([lim_min, lim_max]); 31 | zlim([lim_min, lim_max]); 32 | 33 | axis off 34 | set(fig, 'Units', 'Inches', 'Position', [0, 0, fig_size(1), fig_size(2)], 'PaperUnits', 'Inches', 'PaperSize', [fig_size(1), fig_size(2)]); 35 | set(gcf, 'PaperUnits', 'Inches'); 36 | set(gcf, 'PaperPosition', [0, 0, fig_size(1), fig_size(2)]); 37 | saveas(fig, ['visualization/', query, '_', num2str(i), '_', candidate, '.png'], 'png'); 38 | 39 | %% crop the white edges 40 | RemoveWhiteSpace([], 'file', ['visualization/', query, '_', num2str(i), '_', candidate, '.png']); 41 | 42 | if i>=6 43 | break; 44 | end 45 | end 46 | 47 | %close all; 48 | clear all; 49 | 50 | 51 | % 'FaceColor', [0.7,0.7,0.7] 52 | % gouraud 53 | -------------------------------------------------------------------------------- /data/shapenet_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | import random 4 | import numbers 5 | import os 6 | import os.path 7 | import numpy as np 8 | import struct 9 | import math 10 | 11 | import torch 12 | import torchvision 13 | import matplotlib.pyplot as plt 14 | import h5py 15 | import json 16 | 17 | import faiss # required by KNNBuilder (faiss.IndexFlatL2) 18 | from mpl_toolkits.mplot3d import Axes3D 19 | 20 | from .augmentation import * 21 | 22 | 23 | def load_h5_data_label_seg(h5_filename): 24 | f = h5py.File(h5_filename, 'r') 25 | data = f['data'][:] 26 | label = f['label'][:] 27 | seg = f['pid'][:] 28 | return (data, label, seg) 29 | 30 | 31 | def make_dataset_shapenet_normal(root, mode): 32 | if mode == 'train': 33 | f = open(os.path.join(root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') 34 | file_name_list = json.load(f) 35 | f.close() 36 | 
elif mode == 'test': 37 | f = open(os.path.join(root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') 38 | file_name_list = json.load(f) 39 | f.close() 40 | else: 41 | raise Exception('Mode should be train/test.') 42 | 43 | return file_name_list 44 | 45 | 46 | class KNNBuilder: 47 | def __init__(self, k): 48 | self.k = k 49 | self.dimension = 3 50 | 51 | def build_nn_index(self, database): 52 | ''' 53 | :param database: numpy array of Nx3 54 | :return: Faiss index, in CPU 55 | ''' 56 | index = faiss.IndexFlatL2(self.dimension) # dimension is 3 57 | index.add(database) 58 | return index 59 | 60 | def search_nn(self, index, query, k): 61 | ''' 62 | :param index: Faiss index 63 | :param query: numpy array of Nx3 64 | :return: D: numpy array of Nxk 65 | I: numpy array of Nxk 66 | ''' 67 | D, I = index.search(query, k) 68 | return D, I 69 | 70 | def self_build_search(self, x): 71 | ''' 72 | 73 | :param x: numpy array of Nxd 74 | :return: D: numpy array of Nxk 75 | I: numpy array of Nxk 76 | ''' 77 | x = np.ascontiguousarray(x, dtype=np.float32) 78 | index = self.build_nn_index(x) 79 | D, I = self.search_nn(index, x, self.k) 80 | return D, I 81 | 82 | 83 | class FarthestSampler: 84 | def __init__(self): 85 | pass 86 | 87 | def calc_distances(self, p0, points): 88 | return ((p0 - points) ** 2).sum(axis=1) 89 | 90 | def sample(self, pts, k): 91 | farthest_pts = np.zeros((k, 3)) 92 | farthest_pts[0] = pts[np.random.randint(len(pts))] 93 | distances = self.calc_distances(farthest_pts[0], pts) 94 | for i in range(1, k): 95 | farthest_pts[i] = pts[np.argmax(distances)] 96 | distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts)) 97 | return farthest_pts 98 | 99 | 100 | class ShapeNetLoader(data.Dataset): 101 | def __init__(self, root, mode, opt): 102 | super(ShapeNetLoader, self).__init__() 103 | self.root = root 104 | self.opt = opt 105 | self.mode = mode 106 | 107 | self.node_num = opt.node_num 108 | self.rows = round(math.sqrt(self.node_num)) 
109 | self.cols = self.rows 110 | 111 | self.dataset = make_dataset_shapenet_normal(self.root, self.mode) 112 | # ensure there is no batch-1 batch 113 | if len(self.dataset) % self.opt.batch_size == 1: 114 | self.dataset.pop() 115 | 116 | # load the folder-category txt 117 | self.categories = ['Airplane', 'Bag', 'Cap', 'Car', 'Chair', 'Earphone', 'Guitar', 'Knife', 'Lamp', 'Laptop', 118 | 'Motorbike', 'Mug', 'Pistol', 'Rocket', 'Skateboard', 'Table'] 119 | self.folders = ['02691156', '02773838', '02954340', '02958343', '03001627', '03261776', '03467517', '03624134', 120 | '03636649', '03642806', '03790512', '03797390', '03948459', '04099429', '04225987', '04379243'] 121 | 122 | # kNN search on SOM nodes 123 | self.knn_builder = KNNBuilder(self.opt.som_k) 124 | 125 | # farthest point sample 126 | self.fathest_sampler = FarthestSampler() 127 | 128 | def __len__(self): 129 | return len(self.dataset) 130 | 131 | def __getitem__(self, index): 132 | # pointnet++ dataset 133 | file = self.dataset[index][11:] 134 | data = np.load(os.path.join(self.root, file + '_%dx%d.npz' % (self.rows, self.cols))) 135 | pc_np = data['pc'] 136 | sn_np = data['sn'] 137 | seg_np = data['part_label'] 138 | som_node_np = data['som_node'] 139 | label = self.folders.index(file[0:8]) 140 | assert(label >= 0) 141 | 142 | if self.opt.input_pc_num < pc_np.shape[0]: 143 | chosen_idx = np.random.choice(pc_np.shape[0], self.opt.input_pc_num, replace=False) 144 | pc_np = pc_np[chosen_idx, :] 145 | sn_np = sn_np[chosen_idx, :] 146 | seg_np = seg_np[chosen_idx] 147 | else: 148 | chosen_idx = np.random.choice(pc_np.shape[0], self.opt.input_pc_num-pc_np.shape[0], replace=True) 149 | pc_np_redundent = pc_np[chosen_idx, :] 150 | sn_np_redundent = sn_np[chosen_idx, :] 151 | seg_np_redundent = seg_np[chosen_idx] 152 | pc_np = np.concatenate((pc_np, pc_np_redundent), axis=0) 153 | sn_np = np.concatenate((sn_np, sn_np_redundent), axis=0) 154 | seg_np = np.concatenate((seg_np, seg_np_redundent), axis=0) 155 | 156 
| # augmentation 157 | if self.mode == 'train': 158 | # rotate by random degree over model z (point coordinate y) axis 159 | # pc_np = rotate_point_cloud(pc_np) 160 | # som_node_np = rotate_point_cloud(som_node_np) 161 | 162 | # rotate by 0/90/180/270 degree over model z (point coordinate y) axis 163 | # pc_np = rotate_point_cloud_90(pc_np) 164 | # som_node_np = rotate_point_cloud_90(som_node_np) 165 | 166 | # random jittering 167 | pc_np = jitter_point_cloud(pc_np) 168 | sn_np = jitter_point_cloud(sn_np) 169 | som_node_np = jitter_point_cloud(som_node_np, sigma=0.04, clip=0.1) 170 | 171 | # random scale 172 | scale = np.random.uniform(low=0.8, high=1.2) 173 | pc_np = pc_np * scale 174 | sn_np = sn_np * scale 175 | som_node_np = som_node_np * scale 176 | 177 | # random shift 178 | # shift = np.random.uniform(-0.1, 0.1, (1,3)) 179 | # pc_np += shift 180 | # som_node_np += shift 181 | 182 | # convert to tensor 183 | pc = torch.from_numpy(pc_np.transpose().astype(np.float32)) # 3xN 184 | sn = torch.from_numpy(sn_np.transpose().astype(np.float32)) # 3xN 185 | seg = torch.from_numpy(seg_np.astype(np.int64)) # N 186 | 187 | # som 188 | som_node = torch.from_numpy(som_node_np.transpose().astype(np.float32)) # 3xnode_num 189 | 190 | # kNN search: som -> som 191 | if self.opt.som_k >= 2: 192 | D, I = self.knn_builder.self_build_search(som_node_np) 193 | som_knn_I = torch.from_numpy(I.astype(np.int64)) # node_num x som_k 194 | else: 195 | som_knn_I = torch.from_numpy(np.arange(start=0, stop=self.opt.node_num, dtype=np.int64).reshape( 196 | (self.opt.node_num, 1))) # node_num x 1 197 | 198 | return pc, sn, label, seg, som_node, som_knn_I 199 | 200 | 201 | 202 | if __name__=="__main__": 203 | # dataset = make_dataset_modelnet40('/ssd/dataset/modelnet40_ply_hdf5_2048/', True) 204 | # print(len(dataset)) 205 | # print(dataset[0]) 206 | 207 | 208 | class VirtualOpt(): 209 | def __init__(self): 210 | self.load_all_data = False 211 | self.input_pc_num = 8000 212 | self.batch_size = 
20 213 | self.node_num = 49; self.som_k = 9 # som_k is read by ShapeNetLoader.__init__; 9 mirrors the --som_k default 214 | opt = VirtualOpt() 215 | trainset = ShapeNetLoader('/ssd/dataset/shapenet_part_seg_hdf5_data/', 'train', opt) 216 | print(len(trainset)) 217 | pc, sn, label, seg, som_node, som_knn_I = trainset[10] 218 | 219 | print(label) 220 | print(seg) 221 | 222 | x_np = pc.numpy().transpose() 223 | node_np = som_node.numpy().transpose() 224 | fig = plt.figure() 225 | ax = Axes3D(fig) 226 | ax.scatter(x_np[:, 0].tolist(), x_np[:, 1].tolist(), x_np[:, 2].tolist(), s=1) 227 | ax.scatter(node_np[:, 0].tolist(), node_np[:, 1].tolist(), node_np[:, 2].tolist(), s=6, c='r') 228 | plt.show() 229 | 230 | -------------------------------------------------------------------------------- /modelnet/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | 7 | class Options(): 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser() 10 | self.initialized = False 11 | 12 | def initialize(self): 13 | self.parser.add_argument('--gpu_id', type=int, default=1, help='gpu id: e.g. 0, 1, 2. -1 is no GPU') 14 | 15 | self.parser.add_argument('--dataset', type=str, default='modelnet', help='modelnet / shrec / shapenet') 16 | self.parser.add_argument('--dataroot', default='/ssd/jiaxin/datasets/modelnet40-normal_numpy/', help='path to images & laser point clouds') 17 | self.parser.add_argument('--classes', type=int, default=10, help='ModelNet40 or ModelNet10') 18 | self.parser.add_argument('--name', type=str, default='train', help='name of the experiment. 
It decides where to store samples and models') 19 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 20 | 21 | self.parser.add_argument('--batch_size', type=int, default=8, help='input batch size') 22 | self.parser.add_argument('--input_pc_num', type=int, default=5000, help='# of input points') 23 | self.parser.add_argument('--surface_normal', type=bool, default=True, help='use surface normal in the pc input') 24 | self.parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data') 25 | 26 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 27 | self.parser.add_argument('--display_id', type=int, default=200, help='window id of the web display') 28 | 29 | self.parser.add_argument('--feature_num', type=int, default=1024, help='length of encoded feature') 30 | self.parser.add_argument('--activation', type=str, default='relu', help='activation function: relu, elu') 31 | self.parser.add_argument('--normalization', type=str, default='batch', help='normalization function: batch, instance') 32 | 33 | self.parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 34 | self.parser.add_argument('--dropout', type=float, default=0.7, help='probability of an element to be zeroed') 35 | self.parser.add_argument('--node_num', type=int, default=64, help='som node number') 36 | self.parser.add_argument('--k', type=int, default=3, help='k nearest neighbor') 37 | self.parser.add_argument('--pretrain', type=str, default=None, help='pre-trained encoder dict path') 38 | self.parser.add_argument('--pretrain_lr_ratio', type=float, default=1, help='learning rate ratio between pretrained encoder and classifier') 39 | 40 | self.parser.add_argument('--som_k', type=int, default=9, help='k nearest neighbor of SOM nodes searching on SOM nodes') 41 | self.parser.add_argument('--som_k_type', type=str, default='avg', help='avg / center') 42 | 43 | 
self.parser.add_argument('--random_pc_dropout_lower_limit', type=float, default=1, help='keep ratio lower limit') 44 | self.parser.add_argument('--bn_momentum', type=float, default=0.1, help='normalization momentum, typically 0.1. Equal to (1-m) in TF') 45 | self.parser.add_argument('--bn_momentum_decay_step', type=int, default=None, help='number of epochs between BN momentum decays; None disables the decay') 46 | self.parser.add_argument('--bn_momentum_decay', type=float, default=0.6, help='multiplicative BN momentum decay factor applied at every decay step') 47 | 48 | 49 | self.parser.add_argument('--rot_horizontal', type=bool, default=False, help='Rotation augmentation around vertical axis.') 50 | self.parser.add_argument('--rot_perturbation', type=bool, default=False, help='Small rotation augmentation around 3 axes.') 51 | self.parser.add_argument('--translation_perturbation', type=bool, default=False, help='Small translation augmentation along 3 axes.') 52 | 53 | self.initialized = True 54 | 55 | def parse(self): 56 | if not self.initialized: 57 | self.initialize() 58 | self.opt = self.parser.parse_args() 59 | 60 | self.opt.device = torch.device("cuda:%d"%(self.opt.gpu_id) if (self.opt.gpu_id >= 0 and torch.cuda.is_available()) else "cpu") 61 | # torch.cuda.set_device(self.opt.gpu_id) 62 | 63 | args = vars(self.opt) 64 | 65 | print('------------ Options -------------') 66 | for k, v in sorted(args.items()): 67 | print('%s: %s' % (str(k), str(v))) 68 | print('-------------- End ----------------') 69 | 70 | # save to the disk 71 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 72 | util.mkdirs(expr_dir) 73 | file_name = os.path.join(expr_dir, 'opt.txt') 74 | with open(file_name, 'wt') as opt_file: 75 | opt_file.write('------------ Options -------------\n') 76 | for k, v in sorted(args.items()): 77 | opt_file.write('%s: %s\n' % (str(k), str(v))) 78 | opt_file.write('-------------- End ----------------\n') 79 | return self.opt 80 | -------------------------------------------------------------------------------- 
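Note on the boolean flags in /modelnet/options.py: argparse passes the raw command-line string to the callable given as `type`, and `bool('False')` evaluates to `True`, so with `type=bool` a call such as `--surface_normal False` would most likely leave the option enabled. The sketch below shows one common workaround. `str2bool` and `build_parser` are hypothetical helpers introduced only for illustration; they are not part of the repository, and only two of the flags are reproduced.

```python
import argparse


def str2bool(value):
    """Turn a command-line string into a bool (hypothetical helper, not in the repo)."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)


def build_parser():
    # Illustration only: two of the boolean flags from Options.initialize().
    parser = argparse.ArgumentParser()
    parser.add_argument('--surface_normal', type=str2bool, default=True,
                        help='use surface normal in the pc input')
    parser.add_argument('--rot_horizontal', type=str2bool, default=False,
                        help='Rotation augmentation around vertical axis.')
    return parser


if __name__ == '__main__':
    # Parse an explicit argument list instead of sys.argv for the demonstration.
    demo = build_parser().parse_args(['--surface_normal', 'False'])
    print(demo.surface_normal)
```

With this converter, `--surface_normal False` parses to `False`, while omitting the flag keeps the default. Swapping `type=bool` for such a helper would not change the defaults of the existing options.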
/modelnet/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | import numpy as np 4 | import math 5 | 6 | from modelnet.options import Options 7 | opt = Options().parse() # set CUDA_VISIBLE_DEVICES before import torch 8 | 9 | import torch 10 | import torchvision 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import random 15 | import numpy as np 16 | 17 | from models.classifier import Model 18 | from data.modelnet_shrec_loader import ModelNet_Shrec_Loader 19 | from util.visualizer import Visualizer 20 | 21 | 22 | if __name__=='__main__': 23 | trainset = ModelNet_Shrec_Loader(opt.dataroot, 'train', opt) 24 | dataset_size = len(trainset) 25 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.nThreads) 26 | print('#training point clouds = %d' % len(trainset)) 27 | 28 | testset = ModelNet_Shrec_Loader(opt.dataroot, 'test', opt) 29 | testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.nThreads) 30 | 31 | # create model, optionally load pre-trained model 32 | model = Model(opt) 33 | if opt.pretrain is not None: 34 | model.encoder.load_state_dict(torch.load(opt.pretrain)) 35 | ############################# automation for ModelNet10 / 40 configuration #################### 36 | if opt.classes == 10: 37 | opt.dropout = opt.dropout + 0.1 38 | ############################# automation for ModelNet10 / 40 configuration #################### 39 | 40 | visualizer = Visualizer(opt) 41 | 42 | best_accuracy = 0 43 | for epoch in range(301): 44 | 45 | epoch_iter = 0 46 | for i, data in enumerate(trainloader): 47 | iter_start_time = time.time() 48 | epoch_iter += opt.batch_size 49 | 50 | input_pc, input_sn, input_label, input_node, input_node_knn_I = data 51 | model.set_input(input_pc, input_sn, input_label, input_node, input_node_knn_I) 52 | 53 | 
model.optimize(epoch=epoch) 54 | 55 | if i % 200 == 0: 56 | # print/plot errors 57 | t = (time.time() - iter_start_time) / opt.batch_size 58 | 59 | errors = model.get_current_errors() 60 | 61 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 62 | visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors) 63 | 64 | # print(model.autoencoder.encoder.feature) 65 | # visuals = model.get_current_visuals() 66 | # visualizer.display_current_results(visuals, epoch, i) 67 | 68 | # test network 69 | if epoch >= 0 and epoch%1==0: 70 | batch_amount = 0 71 | model.test_loss.data.zero_() 72 | model.test_accuracy.data.zero_() 73 | for i, data in enumerate(testloader): 74 | input_pc, input_sn, input_label, input_node, input_node_knn_I = data 75 | model.set_input(input_pc, input_sn, input_label, input_node, input_node_knn_I) 76 | model.test_model() 77 | 78 | batch_amount += input_label.size()[0] 79 | 80 | # # accumulate loss 81 | model.test_loss += model.loss.detach() * input_label.size()[0] 82 | 83 | # # accumulate accuracy 84 | _, predicted_idx = torch.max(model.score.data, dim=1, keepdim=False) 85 | correct_mask = torch.eq(predicted_idx, model.input_label).float() 86 | test_accuracy = torch.mean(correct_mask).cpu() 87 | model.test_accuracy += test_accuracy * input_label.size()[0] 88 | 89 | model.test_loss /= batch_amount 90 | model.test_accuracy /= batch_amount 91 | if model.test_accuracy.item() > best_accuracy: 92 | best_accuracy = model.test_accuracy.item() 93 | print('Tested network. 
So far best: %f' % best_accuracy) 94 | 95 | # save network 96 | if opt.classes == 10: 97 | saving_acc_threshold = 0.930 98 | else: 99 | saving_acc_threshold = 0.918 100 | if model.test_accuracy.item() > saving_acc_threshold: 101 | print("Saving network...") 102 | model.save_network(model.encoder, 'encoder', '%d_%f' % (epoch, model.test_accuracy.item()), opt.gpu_id) 103 | model.save_network(model.classifier, 'classifier', '%d_%f' % (epoch, model.test_accuracy.item()), opt.gpu_id) 104 | 105 | # learning rate decay 106 | if opt.classes == 10: 107 | lr_decay_step = 40 108 | else: 109 | lr_decay_step = 20 110 | if epoch%lr_decay_step==0 and epoch > 0: 111 | model.update_learning_rate(0.5) 112 | # batch normalization momentum decay: 113 | next_epoch = epoch + 1 114 | if (opt.bn_momentum_decay_step is not None) and (next_epoch >= 1) and ( 115 | next_epoch % opt.bn_momentum_decay_step == 0): 116 | current_bn_momentum = opt.bn_momentum * ( 117 | opt.bn_momentum_decay ** (next_epoch // opt.bn_momentum_decay_step)) 118 | print('BN momentum updated to: %f' % current_bn_momentum) 119 | 120 | # save network 121 | # if epoch%20==0 and epoch>0: 122 | # print("Saving network...") 123 | # model.save_network(model.classifier, 'cls', '%d' % epoch, opt.gpu_id) 124 | 125 | 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | import math 6 | from collections import OrderedDict 7 | import os 8 | import random 9 | 10 | from . import networks 11 | from . 
import losses 12 | 13 | class Model(): 14 | def __init__(self, opt): 15 | self.opt = opt 16 | 17 | self.old_lr = opt.lr 18 | 19 | self.encoder = networks.Encoder(opt) 20 | self.decoder = networks.Decoder(opt) 21 | self.chamfer_criteria = losses.ChamferLoss(opt) 22 | if self.opt.gpu_id >= 0: 23 | self.encoder = self.encoder.to(self.opt.device) 24 | self.decoder = self.decoder.to(self.opt.device) 25 | self.chamfer_criteria = self.chamfer_criteria.to(self.opt.device) 26 | 27 | self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), 28 | lr=self.opt.lr, 29 | betas=(0.9, 0.999)) 30 | self.optimizer_decoder = torch.optim.Adam(self.decoder.parameters(), 31 | lr=self.opt.lr, 32 | betas=(0.9, 0.999)) 33 | 34 | # place holder for GPU tensors 35 | self.input_pc = torch.FloatTensor(self.opt.batch_size, 3, self.opt.input_pc_num).uniform_() 36 | self.input_sn = torch.FloatTensor(self.opt.batch_size, 3, self.opt.input_pc_num).uniform_() 37 | self.input_label = torch.LongTensor(self.opt.batch_size).fill_(1) 38 | self.input_node = torch.FloatTensor(self.opt.batch_size, 3, self.opt.node_num) 39 | self.input_node_knn_I = torch.LongTensor(self.opt.batch_size, self.opt.node_num, self.opt.som_k) 40 | 41 | # record the test loss and accuracy 42 | self.test_loss = torch.FloatTensor([0]) 43 | 44 | if self.opt.gpu_id >= 0: 45 | self.input_pc = self.input_pc.to(self.opt.device) 46 | self.input_sn = self.input_sn.to(self.opt.device) 47 | self.input_label = self.input_label.to(self.opt.device) 48 | self.input_node = self.input_node.to(self.opt.device) 49 | self.input_node_knn_I = self.input_node_knn_I.to(self.opt.device) 50 | self.test_loss = self.test_loss.to(self.opt.device) 51 | 52 | def set_input(self, input_pc, input_sn, input_label, input_node, input_node_knn_I): 53 | self.input_pc.resize_(input_pc.size()).copy_(input_pc) 54 | self.input_sn.resize_(input_sn.size()).copy_(input_sn) 55 | self.input_label.resize_(input_label.size()).copy_(input_label) 56 | 
self.input_node.resize_(input_node.size()).copy_(input_node) 57 | self.input_node_knn_I.resize_(input_node_knn_I.size()).copy_(input_node_knn_I) 58 | self.pc = self.input_pc.detach() 59 | self.sn = self.input_sn.detach() 60 | self.label = self.input_label.detach() 61 | 62 | def forward(self, is_train=False, epoch=None): 63 | self.feature = self.encoder(self.pc, self.sn, self.input_node, self.input_node_knn_I, is_train, epoch) # Bx1024 64 | self.predicted_pc = self.decoder(self.feature) 65 | 66 | def optimize(self, epoch=None): 67 | # random point dropout 68 | if self.opt.random_pc_dropout_lower_limit < 0.99: 69 | dropout_keep_ratio = random.uniform(self.opt.random_pc_dropout_lower_limit, 1.0) 70 | resulting_pc_num = round(dropout_keep_ratio * self.opt.input_pc_num) 71 | chosen_indices = np.random.choice(self.opt.input_pc_num, resulting_pc_num, replace=False) 72 | chosen_indices_tensor = torch.from_numpy(chosen_indices).to(self.opt.device) 73 | self.pc = torch.index_select(self.pc, dim=2, index=chosen_indices_tensor) 74 | self.sn = torch.index_select(self.sn, dim=2, index=chosen_indices_tensor) 75 | 76 | self.encoder.train() 77 | self.decoder.train() 78 | self.forward(is_train=True, epoch=epoch) 79 | 80 | self.encoder.zero_grad() 81 | self.decoder.zero_grad() 82 | 83 | if self.opt.output_conv_pc_num > 0: 84 | # loss for second last conv pyramid # 32x32 85 | self.loss_chamfer_conv5 = self.chamfer_criteria(self.decoder.conv_pc5, self.pc) 86 | 87 | # loss for third last conv pyramid # 16x16 88 | self.loss_chamfer_conv4 = self.chamfer_criteria(self.decoder.conv_pc4, self.pc) 89 | 90 | # loss for the last pyramid, i.e., the final pc 91 | self.loss_chamfer = self.chamfer_criteria(self.predicted_pc, self.pc) 92 | 93 | if self.opt.output_conv_pc_num == 1024: 94 | self.loss = self.loss_chamfer + self.loss_chamfer_conv4 95 | elif self.opt.output_conv_pc_num == 4096: 96 | self.loss = self.loss_chamfer + self.loss_chamfer_conv5 + self.loss_chamfer_conv4 97 | else: 98 | 
self.loss = self.loss_chamfer 99 | 100 | self.loss.backward() 101 | 102 | self.optimizer_encoder.step() 103 | self.optimizer_decoder.step() 104 | 105 | def test_model(self): 106 | self.encoder.eval() 107 | self.decoder.eval() 108 | self.forward(is_train=False) 109 | 110 | if self.opt.output_conv_pc_num > 0: 111 | # loss for second last conv pyramid # 32x32 112 | if self.opt.output_conv_pc_num == 4096: 113 | self.loss_chamfer_conv5 = self.chamfer_criteria(self.decoder.conv_pc5, self.pc) 114 | 115 | # loss for third last conv pyramid # 16x16 116 | self.loss_chamfer_conv4 = self.chamfer_criteria(self.decoder.conv_pc4, self.pc) 117 | 118 | # loss for the last pyramid, i.e., the final pc 119 | self.loss_chamfer = self.chamfer_criteria(self.predicted_pc, self.pc) 120 | 121 | if self.opt.output_conv_pc_num == 1024: 122 | self.loss = self.loss_chamfer + self.loss_chamfer_conv4 123 | elif self.opt.output_conv_pc_num == 4096: 124 | self.loss = self.loss_chamfer + self.loss_chamfer_conv5 + self.loss_chamfer_conv4 125 | elif self.opt.output_conv_pc_num == 0: 126 | self.loss = self.loss_chamfer 127 | 128 | # visualization with visdom 129 | def get_current_visuals(self): 130 | # display only one instance of pc/img 131 | input_pc_np = self.input_pc[0].cpu().numpy() 132 | predicted_pc_np = self.predicted_pc.cpu().data[0].numpy() 133 | 134 | return OrderedDict([('input_pc', input_pc_np),('predicted_pc', predicted_pc_np)]) 135 | 136 | def get_current_errors(self): 137 | return OrderedDict([ 138 | ('total', self.loss_chamfer.item()), 139 | ('forward', self.chamfer_criteria.forward_loss.item()), 140 | ('backward', self.chamfer_criteria.backward_loss.item()), 141 | ('test_loss', self.test_loss.item()) 142 | ]) 143 | 144 | def save_network(self, network, network_label, epoch_label, gpu_id): 145 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 146 | save_path = os.path.join(self.opt.checkpoints_dir, save_filename) 147 | torch.save(network.cpu().state_dict(), save_path) 
148 | if gpu_id>=0 and torch.cuda.is_available(): 149 | # torch.cuda.device(gpu_id) 150 | network.to(self.opt.device) 151 | 152 | def update_learning_rate(self, ratio): 153 | # encoder + decoder 154 | lr = self.old_lr * ratio 155 | for param_group in self.optimizer_encoder.param_groups: 156 | param_group['lr'] = lr 157 | for param_group in self.optimizer_decoder.param_groups: 158 | param_group['lr'] = lr 159 | print('update encoder-decoder learning rate: %f -> %f' % (self.old_lr, lr)) 160 | self.old_lr = lr 161 | 162 | -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | import math 6 | from collections import OrderedDict 7 | import os 8 | import sys 9 | import random 10 | 11 | from . import networks 12 | from . import losses 13 | 14 | 15 | class Model(): 16 | def __init__(self, opt): 17 | self.opt = opt 18 | 19 | self.encoder = networks.Encoder(opt) 20 | self.classifier = networks.Classifier(opt) 21 | 22 | # learning rate_control 23 | if self.opt.pretrain is not None: 24 | self.old_lr_encoder = self.opt.lr * self.opt.pretrain_lr_ratio 25 | else: 26 | self.old_lr_encoder = self.opt.lr 27 | self.old_lr_classifier = self.opt.lr 28 | 29 | self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), 30 | lr=self.old_lr_encoder, 31 | betas=(0.9, 0.999), 32 | weight_decay=0) 33 | self.optimizer_classifier = torch.optim.Adam(self.classifier.parameters(), 34 | lr= self.old_lr_classifier, 35 | betas=(0.9, 0.999), 36 | weight_decay=0) 37 | 38 | self.softmax_criteria = nn.CrossEntropyLoss() 39 | if self.opt.gpu_id >= 0: 40 | self.encoder = self.encoder.to(self.opt.device) 41 | self.classifier = self.classifier.to(self.opt.device) 42 | self.softmax_criteria = self.softmax_criteria.to(self.opt.device) 43 | 44 | # place holder for 
GPU tensors 45 | self.input_pc = torch.FloatTensor(self.opt.batch_size, 3, self.opt.input_pc_num).uniform_() 46 | self.input_sn = torch.FloatTensor(self.opt.batch_size, 3, self.opt.input_pc_num).uniform_() 47 | self.input_label = torch.LongTensor(self.opt.batch_size).fill_(1) 48 | self.input_node = torch.FloatTensor(self.opt.batch_size, 3, self.opt.node_num) 49 | self.input_node_knn_I = torch.LongTensor(self.opt.batch_size, self.opt.node_num, self.opt.som_k) 50 | 51 | # record the test loss and accuracy 52 | self.test_loss = torch.tensor([0], dtype=torch.float32, requires_grad=False) 53 | self.test_accuracy = torch.tensor([0], dtype=torch.float32, requires_grad=False) 54 | 55 | if self.opt.gpu_id >= 0: 56 | self.input_pc = self.input_pc.to(self.opt.device) 57 | self.input_sn = self.input_sn.to(self.opt.device) 58 | self.input_label = self.input_label.to(self.opt.device) 59 | self.input_node = self.input_node.to(self.opt.device) 60 | self.input_node_knn_I = self.input_node_knn_I.to(self.opt.device) 61 | self.test_loss = self.test_loss.to(self.opt.device) 62 | # self.test_accuracy = self.test_accuracy.to(self.opt.device) 63 | 64 | def set_input(self, input_pc, input_sn, input_label, input_node, input_node_knn_I): 65 | self.input_pc.resize_(input_pc.size()).copy_(input_pc) 66 | self.input_sn.resize_(input_sn.size()).copy_(input_sn) 67 | self.input_label.resize_(input_label.size()).copy_(input_label) 68 | self.input_node.resize_(input_node.size()).copy_(input_node) 69 | self.input_node_knn_I.resize_(input_node_knn_I.size()).copy_(input_node_knn_I) 70 | self.pc = self.input_pc.detach() 71 | self.sn = self.input_sn.detach() 72 | self.label = self.input_label.detach() 73 | 74 | def forward(self, is_train=False, epoch=None): 75 | self.feature = self.encoder(self.pc, self.sn, self.input_node, self.input_node_knn_I, is_train, epoch) # Bx1024 76 | self.score = self.classifier(self.feature, epoch) 77 | 78 | def optimize(self, epoch=None): 79 | # random point dropout 80 | if 
self.opt.random_pc_dropout_lower_limit < 0.99: 81 | dropout_keep_ratio = random.uniform(self.opt.random_pc_dropout_lower_limit, 1.0) 82 | resulting_pc_num = round(dropout_keep_ratio*self.opt.input_pc_num) 83 | chosen_indices = np.random.choice(self.opt.input_pc_num, resulting_pc_num, replace=False) 84 | chosen_indices_tensor = torch.from_numpy(chosen_indices).to(self.opt.device) 85 | self.pc = torch.index_select(self.pc, dim=2, index=chosen_indices_tensor) 86 | self.sn = torch.index_select(self.sn, dim=2, index=chosen_indices_tensor) 87 | 88 | self.encoder.train() 89 | self.classifier.train() 90 | self.forward(is_train=True, epoch=epoch) 91 | 92 | self.encoder.zero_grad() 93 | self.classifier.zero_grad() 94 | 95 | self.loss = self.softmax_criteria(self.score, self.label) 96 | self.loss.backward() 97 | 98 | self.optimizer_encoder.step() 99 | self.optimizer_classifier.step() 100 | 101 | def test_model(self): 102 | self.encoder.eval() 103 | self.classifier.eval() 104 | self.forward(is_train=False) 105 | self.loss = self.softmax_criteria(self.score, self.label) 106 | 107 | # visualization with visdom 108 | def get_current_visuals(self): 109 | # display only one instance of pc/img 110 | input_pc_np = self.input_pc[0].cpu().numpy() 111 | 112 | return OrderedDict([('input_pc', input_pc_np)]) 113 | 114 | def get_current_errors(self): 115 | # get the accuracy 116 | _, predicted_idx = torch.max(self.score.data, dim=1, keepdim=False) 117 | correct_mask = torch.eq(predicted_idx, self.input_label).float() 118 | train_accuracy = torch.mean(correct_mask) 119 | 120 | return OrderedDict([ 121 | ('train_loss', self.loss.item()), 122 | ('train_accuracy', train_accuracy.item()), 123 | ('test_loss', self.test_loss.item()), 124 | ('test_accuracy', self.test_accuracy.item()) 125 | ]) 126 | 127 | def save_network(self, network, network_label, epoch_label, gpu_id): 128 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 129 | save_path = os.path.join(self.opt.checkpoints_dir, 
save_filename) 130 | torch.save(network.cpu().state_dict(), save_path) 131 | if gpu_id>=0 and torch.cuda.is_available(): 132 | # torch.cuda.device(gpu_id) 133 | network.to(self.opt.device) 134 | 135 | def update_learning_rate(self, ratio): 136 | lr_clip = 0.00001 137 | 138 | # encoder 139 | lr_encoder = self.old_lr_encoder * ratio 140 | if lr_encoder < lr_clip: 141 | lr_encoder = lr_clip 142 | for param_group in self.optimizer_encoder.param_groups: 143 | param_group['lr'] = lr_encoder 144 | print('update encoder learning rate: %f -> %f' % (self.old_lr_encoder, lr_encoder)) 145 | self.old_lr_encoder = lr_encoder 146 | 147 | # classifier 148 | lr_classifier = self.old_lr_classifier * ratio 149 | if lr_classifier < lr_clip: 150 | lr_classifier = lr_clip 151 | for param_group in self.optimizer_classifier.param_groups: 152 | param_group['lr'] = lr_classifier 153 | print('update classifier learning rate: %f -> %f' % (self.old_lr_classifier, lr_classifier)) 154 | self.old_lr_classifier = lr_classifier -------------------------------------------------------------------------------- /models/index_max_ext/index_max.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | // cpu operations ------------------------------- 8 | void max_forward_worker(torch::TensorAccessor* p_data_a, 9 | torch::TensorAccessor* p_index_a, 10 | torch::TensorAccessor* p_max_idx_a, 11 | torch::TensorAccessor* p_max_val_a, 12 | int c_begin, int c_end) { 13 | int B = p_data_a->size(0); 14 | // int C = p_data_a->size(1); 15 | int N = p_data_a->size(2); 16 | // int K = p_max_idx_a->size(2); 17 | 18 | // thread is on C channel 19 | for (int b=0; b (*p_max_val_a)[b][c][k]) { 25 | (*p_max_val_a)[b][c][k] = data_point; 26 | (*p_max_idx_a)[b][c][k] = n; 27 | } 28 | } 29 | } 30 | } 31 | } 32 | 33 | torch::Tensor index_max_forward_pthread_cpu(const torch::Tensor data, 34 | const torch::Tensor index, 35 | const int K, 
36 | const int thread_num) { 37 | int B = data.size(0); 38 | int C = data.size(1); 39 | // int N = data.size(2); 40 | 41 | auto max_idx = torch::zeros({B, C, K}, torch::TensorOptions().dtype(torch::kInt32)); 42 | auto max_val = torch::ones({B, C, K}, torch::TensorOptions().dtype(torch::kFloat32)) * -1000.0; 43 | 44 | // use accessor 45 | auto data_a = data.accessor(); 46 | auto index_a = index.accessor(); 47 | auto max_idx_a = max_idx.accessor(); 48 | auto max_val_a = max_val.accessor(); 49 | 50 | // multi thread for loop, divide on C channel 51 | std::thread thread_pool[thread_num]; 52 | int c_interval = int(C) / int(thread_num); 53 | for (int t=0;t(); 93 | auto index_a = index.accessor(); 94 | auto max_idx_a = max_idx.accessor(); 95 | auto max_val_a = max_val.accessor(); 96 | 97 | // single thread for loop 98 | for (int b=0; b max_val_a[b][c][k]) { 104 | max_val_a[b][c][k] = data_point; 105 | max_idx_a[b][c][k] = n; 106 | } 107 | } 108 | } 109 | } 110 | 111 | return max_idx; 112 | } 113 | // cpu operations ------------------------------- 114 | 115 | 116 | 117 | 118 | // cuda operations ------------------------------ 119 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor/variable") 120 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 121 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 122 | 123 | // declare the functions in .cu file 124 | torch::Tensor index_max_forward_cuda(const torch::Tensor data, 125 | const torch::Tensor index, 126 | const int K); 127 | 128 | torch::Tensor index_max_forward_cuda_shared_mem(const torch::Tensor data, 129 | const torch::Tensor index, 130 | const int K); 131 | 132 | torch::Tensor index_max_forward_cuda_wrapper(const torch::Tensor data, 133 | const torch::Tensor index, 134 | const int K){ 135 | CHECK_INPUT(data); 136 | CHECK_INPUT(index); 137 | 138 | return index_max_forward_cuda(data, index, K); 139 | } 140 | 141 | torch::Tensor 
index_max_forward_cuda_wrapper_shared_mem(const torch::Tensor data,
142 |         const torch::Tensor index,
143 |         const int K){
144 |     CHECK_INPUT(data);
145 |     CHECK_INPUT(index);
146 |
147 |     return index_max_forward_cuda_shared_mem(data, index, K);
148 | }
149 | // cuda operations ------------------------------
150 |
151 |
152 |
153 |
154 | PYBIND11_MODULE(index_max, m) {
155 |     m.def("forward_cpu", &index_max_forward_cpu, "CPU single thread");
156 |     m.def("forward_multi_thread_cpu", &index_max_forward_pthread_cpu, "CPU multi-thread");
157 |     m.def("forward_cuda", &index_max_forward_cuda_wrapper, "CUDA code without shared memory");
158 |     m.def("forward_cuda_shared_mem", &index_max_forward_cuda_wrapper_shared_mem, "CUDA code with shared memory");
159 | }
160 |
--------------------------------------------------------------------------------
/models/index_max_ext/index_max_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <torch/extension.h>
3 |
4 | #include <cuda.h>
5 | #include <cuda_runtime.h>
6 |
7 | #include <vector>
8 |
9 |
10 | __global__ void index_max_forward_cuda_kernel(const float* __restrict__ data,
11 |         const int* __restrict__ index,
12 |         int* __restrict__ max_idx,
13 |         float* __restrict__ max_val,
14 |         const int B, const int C, const int N, const int K){
15 |     int b = threadIdx.x;
16 |     int c = blockIdx.x;
17 |
18 |     for(int n=0;n<N;++n){
19 |         int k = index[b*N+n];  // index is BxN
20 |         float data_point = data[b*C*N+c*N+n];  // data is BxCxN
21 |         if (data_point > max_val[b*C*K+c*K+k]){
22 |             max_val[b*C*K+c*K+k] = data_point;
23 |             max_idx[b*C*K+c*K+k] = n;
24 |         }
25 |     }
26 | }
27 |
28 |
29 |
30 | __global__ void index_max_forward_cuda_kernel_shared_mem(const float* __restrict__ data,
31 |         const int* __restrict__ index,
32 |         int* __restrict__ max_idx,
33 |         float* __restrict__ max_val,
34 |         const int B, const int C, const int N, const int K){
35 |     int b = threadIdx.x;
36 |     int c = blockIdx.x;
37 |
38 |     extern __shared__ float max_val_shared[];
39 |     for (int i=0;i<K;++i){
40 |         max_val_shared[b*K+i] = -1000;  // same sentinel as the host-side max_val
41 |     }
42 |
43 |     for(int n=0;n<N;++n){
44 |         int k = index[b*N+n];
45 |         float data_point = data[b*C*N+c*N+n];
46 |         if (data_point > max_val_shared[b*K+k]){
47 |             max_val_shared[b*K+k] = data_point;
48 |             max_idx[b*C*K+c*K+k] = n;
49 |         }
50 |     }
51 |     //
52 |     //
__syncthreads();
53 |
54 |     // for(int n=0;n<N;++n){
55 |     //     int k = index[b*N+n];
56 |     //     float data_point = data[b*C*N+c*N+n];
57 |     //     if (data_point > max_val[b*C*K+c*K+k]){
58 |     //         max_val[b*C*K+c*K+k] = data_point;
59 |     //         max_idx[b*C*K+c*K+k] = n;
60 |     //     }
61 |     // }
62 | }
63 |
64 |
65 |
66 | torch::Tensor index_max_forward_cuda(const torch::Tensor data, const torch::Tensor index, const int K){
67 |     int B = data.size(0);
68 |     int C = data.size(1);
69 |     int N = data.size(2);
70 |
71 |     auto device_idx = data.device().index();
72 |     auto max_idx = torch::zeros({B, C, K}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, device_idx));
73 |     auto max_val = torch::ones({B, C, K}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, device_idx)) * -1000.0;
74 |
75 |     index_max_forward_cuda_kernel<<<C, B>>>(data.data<float>(),
76 |             index.data<int>(),
77 |             max_idx.data<int>(),
78 |             max_val.data<float>(),
79 |             B, C, N, K);
80 |
81 |     return max_idx;
82 | }
83 |
84 | torch::Tensor index_max_forward_cuda_shared_mem(const torch::Tensor data, const torch::Tensor index, const int K){
85 |     int B = data.size(0);
86 |     int C = data.size(1);
87 |     int N = data.size(2);
88 |
89 |     auto device_idx = data.device().index();
90 |     auto max_idx = torch::zeros({B, C, K}, torch::TensorOptions({torch::kCUDA, device_idx}).dtype(torch::kInt32));
91 |     auto max_val = torch::ones({B, C, K}, torch::TensorOptions({torch::kCUDA, device_idx}).dtype(torch::kFloat32)) * -1000.0;
92 |
93 |     index_max_forward_cuda_kernel_shared_mem<<<C, B, B*K*sizeof(float)>>>(data.data<float>(),
94 |             index.data<int>(),
95 |             max_idx.data<int>(),
96 |             max_val.data<float>(),
97 |             B, C, N, K);
98 |
99 |     return max_idx;
100 | }
101 |
--------------------------------------------------------------------------------
/models/index_max_ext/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | import torch
3 | from setuptools import setup
4 | from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension
5 |
6 | # setup(name='index_max',
7 | #       ext_modules=[CppExtension('index_max', ['index_max.cpp'])],
8 | #
cmdclass={'build_ext': BuildExtension}) 9 | 10 | # setuptools.Extension( 11 | # name='index_max', 12 | # sources=['index_max.cpp'], 13 | # include_dirs=torch.utils.cpp_extension.include_paths(), 14 | # language='c++') 15 | 16 | setup(name='index_max', 17 | ext_modules=[CUDAExtension('index_max', ['index_max.cpp', 'index_max_cuda.cu'])], 18 | cmdclass={'build_ext': BuildExtension}) -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.nn.modules.batchnorm import _BatchNorm 6 | import numpy as np 7 | import math 8 | import torch.utils.model_zoo as model_zoo 9 | import time 10 | 11 | from . import operations 12 | 13 | 14 | class Swish(nn.Module): 15 | def __init__(self): 16 | super(Swish, self).__init__() 17 | 18 | def forward(self, x): 19 | return x * torch.sigmoid(x) 20 | 21 | 22 | class MyBatchNorm1d(_BatchNorm): 23 | r"""Applies Batch Normalization over a 2d or 3d input that is seen as a 24 | mini-batch. 25 | .. math:: 26 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 27 | The mean and standard-deviation are calculated per-dimension over 28 | the mini-batches and gamma and beta are learnable parameter vectors 29 | of size C (where C is the input size). 30 | During training, this layer keeps a running estimate of its computed mean 31 | and variance. The running sum is kept with a default momentum of 0.1. 32 | During evaluation, this running mean/variance is used for normalization. 
33 | Because the BatchNorm is done over the `C` dimension, computing statistics 34 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 35 | Args: 36 | num_features: num_features from an expected input of size 37 | `batch_size x num_features [x width]` 38 | eps: a value added to the denominator for numerical stability. 39 | Default: 1e-5 40 | momentum: the value used for the running_mean and running_var 41 | computation. Default: 0.1 42 | affine: a boolean value that when set to ``True``, gives the layer learnable 43 | affine parameters. Default: ``True`` 44 | Shape: 45 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 46 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 47 | """ 48 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, momentum_decay_step=None, momentum_decay=1): 49 | super(MyBatchNorm1d, self).__init__(num_features, eps, momentum, affine) 50 | self.momentum_decay_step = momentum_decay_step 51 | self.momentum_decay = momentum_decay 52 | self.momentum_original = self.momentum 53 | 54 | def _check_input_dim(self, input): 55 | if input.dim() != 2 and input.dim() != 3: 56 | raise ValueError('expected 2D or 3D input (got {}D input)' 57 | .format(input.dim())) 58 | super(MyBatchNorm1d, self)._check_input_dim(input) 59 | 60 | def forward(self, input, epoch=None): 61 | if (epoch is not None) and (epoch >= 1) and (self.momentum_decay_step is not None) and (self.momentum_decay_step > 0): 62 | # perform momentum decay 63 | self.momentum = self.momentum_original * (self.momentum_decay**(epoch//self.momentum_decay_step)) 64 | if self.momentum < 0.01: 65 | self.momentum = 0.01 66 | 67 | 68 | return F.batch_norm( 69 | input, self.running_mean, self.running_var, self.weight, self.bias, 70 | self.training, self.momentum, self.eps) 71 | 72 | 73 | class MyBatchNorm2d(_BatchNorm): 74 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 75 | of 3d inputs 76 | .. 
math:: 77 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 78 | The mean and standard-deviation are calculated per-dimension over 79 | the mini-batches and gamma and beta are learnable parameter vectors 80 | of size C (where C is the input size). 81 | During training, this layer keeps a running estimate of its computed mean 82 | and variance. The running sum is kept with a default momentum of 0.1. 83 | During evaluation, this running mean/variance is used for normalization. 84 | Because the BatchNorm is done over the `C` dimension, computing statistics 85 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 86 | Args: 87 | num_features: num_features from an expected input of 88 | size batch_size x num_features x height x width 89 | eps: a value added to the denominator for numerical stability. 90 | Default: 1e-5 91 | momentum: the value used for the running_mean and running_var 92 | computation. Default: 0.1 93 | affine: a boolean value that when set to ``True``, gives the layer learnable 94 | affine parameters. 
Default: ``True`` 95 | Shape: 96 | - Input: :math:`(N, C, H, W)` 97 | - Output: :math:`(N, C, H, W)` (same shape as input) 98 | """ 99 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, momentum_decay_step=None, momentum_decay=1): 100 | super(MyBatchNorm2d, self).__init__(num_features, eps, momentum, affine) 101 | self.momentum_decay_step = momentum_decay_step 102 | self.momentum_decay = momentum_decay 103 | self.momentum_original = self.momentum 104 | 105 | def _check_input_dim(self, input): 106 | if input.dim() != 4: 107 | raise ValueError('expected 4D input (got {}D input)' 108 | .format(input.dim())) 109 | super(MyBatchNorm2d, self)._check_input_dim(input) 110 | 111 | def forward(self, input, epoch=None): 112 | if (epoch is not None) and (epoch >= 1) and (self.momentum_decay_step is not None) and (self.momentum_decay_step > 0): 113 | # perform momentum decay 114 | self.momentum = self.momentum_original * (self.momentum_decay**(epoch//self.momentum_decay_step)) 115 | if self.momentum < 0.01: 116 | self.momentum = 0.01 117 | 118 | return F.batch_norm( 119 | input, self.running_mean, self.running_var, self.weight, self.bias, 120 | self.training, self.momentum, self.eps) 121 | 122 | 123 | class MyLinear(nn.Module): 124 | def __init__(self, in_features, out_features, activation=None, normalization=None, momentum=0.1, bn_momentum_decay_step=None, bn_momentum_decay=1): 125 | super(MyLinear, self).__init__() 126 | self.activation = activation 127 | self.normalization = normalization 128 | 129 | self.linear = nn.Linear(in_features, out_features, bias=True) 130 | if self.normalization == 'batch': 131 | self.norm = MyBatchNorm1d(out_features, momentum=momentum, affine=True, momentum_decay_step=bn_momentum_decay_step, momentum_decay=bn_momentum_decay) 132 | elif self.normalization == 'instance': 133 | self.norm = nn.InstanceNorm1d(out_features, momentum=momentum, affine=True) 134 | if self.activation == 'relu': 135 | self.act = nn.ReLU() 136 | elif 
'elu' == activation: 137 | self.act = nn.ELU(alpha=1.0) 138 | elif 'swish' == self.activation: 139 | self.act = Swish() 140 | elif 'leakyrelu' == self.activation: 141 | self.act = nn.LeakyReLU(0.1) 142 | 143 | self.weight_init() 144 | 145 | def weight_init(self): 146 | for m in self.modules(): 147 | if isinstance(m, nn.Linear) : 148 | n = m.in_features 149 | m.weight.data.normal_(0, math.sqrt(2. / n)) 150 | if m.bias is not None: 151 | m.bias.data.fill_(0) 152 | elif isinstance(m, MyBatchNorm1d) or isinstance(m, nn.InstanceNorm1d): 153 | m.weight.data.fill_(1) 154 | m.bias.data.zero_() 155 | 156 | def forward(self, x, epoch=None): 157 | x = self.linear(x) 158 | if self.normalization=='batch': 159 | x = self.norm(x, epoch) 160 | elif self.normalization is not None: 161 | x = self.norm(x) 162 | 163 | if self.activation is not None: 164 | x = self.act(x) 165 | 166 | return x 167 | 168 | 169 | class MyConv2d(nn.Module): 170 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, activation=None, momentum=0.1, normalization=None, bn_momentum_decay_step=None, bn_momentum_decay=1): 171 | super(MyConv2d, self).__init__() 172 | self.activation = activation 173 | self.normalization = normalization 174 | 175 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 176 | if self.normalization == 'batch': 177 | self.norm = MyBatchNorm2d(out_channels, momentum=momentum, affine=True, momentum_decay_step=bn_momentum_decay_step, momentum_decay=bn_momentum_decay) 178 | elif self.normalization == 'instance': 179 | self.norm = nn.InstanceNorm2d(out_channels, momentum=momentum, affine=True) 180 | if self.activation == 'relu': 181 | self.act = nn.ReLU() 182 | elif self.activation == 'elu': 183 | self.act = nn.ELU(alpha=1.0) 184 | elif 'swish' == self.activation: 185 | self.act = Swish() 186 | elif 'leakyrelu' == self.activation: 187 | self.act = nn.LeakyReLU(0.1) 188 | 189 | self.weight_init() 190 | 191 | def 
weight_init(self): 192 | for m in self.modules(): 193 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): 194 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 195 | m.weight.data.normal_(0, math.sqrt(2. / n)) 196 | if m.bias is not None: 197 | m.bias.data.fill_(0) 198 | elif isinstance(m, MyBatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 199 | m.weight.data.fill_(1) 200 | m.bias.data.zero_() 201 | 202 | def forward(self, x, epoch=None): 203 | x = self.conv(x) 204 | if self.normalization=='batch': 205 | x = self.norm(x, epoch) 206 | elif self.normalization is not None: 207 | x = self.norm(x) 208 | 209 | if self.activation is not None: 210 | x = self.act(x) 211 | return x 212 | 213 | 214 | class UpConv(nn.Module): 215 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, output_padding=0, bias=True, activation=None, normalization=None): 216 | super(UpConv, self).__init__() 217 | self.activation = activation 218 | self.normalization = normalization 219 | 220 | self.up_sample = nn.Upsample(scale_factor=2) 221 | self.conv = MyConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, activation=activation, normalization=normalization) 222 | 223 | self.weight_init() 224 | 225 | def weight_init(self): 226 | for m in self.modules(): 227 | if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d): 228 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 229 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 230 | if m.bias is not None: 231 | m.bias.data.fill_(0.001) 232 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 233 | m.weight.data.fill_(1) 234 | m.bias.data.zero_() 235 | 236 | def forward(self, x): 237 | x = self.up_sample(x) 238 | x = self.conv(x) 239 | 240 | return x 241 | 242 | 243 | class EquivariantLayer(nn.Module): 244 | def __init__(self, num_in_channels, num_out_channels, activation='relu', normalization=None, momentum=0.1, bn_momentum_decay_step=None, bn_momentum_decay=1): 245 | super(EquivariantLayer, self).__init__() 246 | 247 | self.num_in_channels = num_in_channels 248 | self.num_out_channels = num_out_channels 249 | self.activation = activation 250 | self.normalization = normalization 251 | 252 | self.conv = nn.Conv1d(self.num_in_channels, self.num_out_channels, kernel_size=1, stride=1, padding=0) 253 | 254 | if 'batch' == self.normalization: 255 | self.norm = MyBatchNorm1d(self.num_out_channels, momentum=momentum, affine=True, momentum_decay_step=bn_momentum_decay_step, momentum_decay=bn_momentum_decay) 256 | elif 'instance' == self.normalization: 257 | self.norm = nn.InstanceNorm1d(self.num_out_channels, momentum=momentum, affine=True) 258 | 259 | if 'relu' == self.activation: 260 | self.act = nn.ReLU() 261 | elif 'elu' == self.activation: 262 | self.act = nn.ELU(alpha=1.0) 263 | elif 'swish' == self.activation: 264 | self.act = Swish() 265 | elif 'leakyrelu' == self.activation: 266 | self.act = nn.LeakyReLU(0.1) 267 | 268 | 269 | self.weight_init() 270 | 271 | def weight_init(self): 272 | for m in self.modules(): 273 | if isinstance(m, nn.Conv1d): 274 | n = m.kernel_size[0] * m.in_channels 275 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 276 | if m.bias is not None: 277 | m.bias.data.fill_(0) 278 | elif isinstance(m, MyBatchNorm1d) or isinstance(m, nn.InstanceNorm1d): 279 | m.weight.data.fill_(1) 280 | m.bias.data.zero_() 281 | 282 | def forward(self, x, epoch=None): 283 | # x is NxK, x_max is 1xK 284 | # x_max, _ = torch.max(x, 0, keepdim=True) 285 | # y = self.conv(x - x_max.expand_as(x)) 286 | y = self.conv(x) 287 | 288 | if self.normalization=='batch': 289 | y = self.norm(y, epoch) 290 | elif self.normalization is not None: 291 | y = self.norm(y) 292 | 293 | if self.activation is not None: 294 | y = self.act(y) 295 | 296 | return y 297 | 298 | 299 | class KNNModule(nn.Module): 300 | def __init__(self, in_channels, out_channels_list, activation, normalization, momentum=0.1, 301 | bn_momentum_decay_step=None, bn_momentum_decay=1): 302 | super(KNNModule, self).__init__() 303 | 304 | self.layers = nn.ModuleList() 305 | previous_out_channels = in_channels 306 | for c_out in out_channels_list: 307 | self.layers.append(MyConv2d(previous_out_channels, c_out, kernel_size=1, stride=1, padding=0, bias=True, 308 | activation=activation, normalization=normalization, 309 | momentum=momentum, bn_momentum_decay_step=bn_momentum_decay_step, 310 | bn_momentum_decay=bn_momentum_decay)) 311 | previous_out_channels = c_out 312 | 313 | def forward(self, coordinate, x, precomputed_knn_I, K, center_type, epoch=None): 314 | ''' 315 | 316 | :param coordinate: Bx3xM Variable 317 | :param x: BxCxM Variable 318 | :param precomputed_knn_I: BxMxK' 319 | :param K: K neighbors 320 | :param center_type: 'center' or 'avg' 321 | :return: 322 | ''' 323 | # 0. compute knn 324 | # 1. for each node, calculate the center of its k neighborhood 325 | # 2. normalize nodes with the corresponding center 326 | # 3. fc for these normalized points 327 | # 4. 
maxpool for each neighborhood 328 | 329 | coordinate_tensor = coordinate.data # Bx3xM 330 | if precomputed_knn_I is not None: 331 | assert precomputed_knn_I.size()[2] >= K 332 | knn_I = precomputed_knn_I[:, :, 0:K] 333 | else: 334 | coordinate_Mx1 = coordinate_tensor.unsqueeze(3) # Bx3xMx1 335 | coordinate_1xM = coordinate_tensor.unsqueeze(2) # Bx3x1xM 336 | norm = torch.sum((coordinate_Mx1 - coordinate_1xM) ** 2, dim=1) # BxMxM, each row corresponds to each coordinate - other coordinates 337 | knn_D, knn_I = torch.topk(norm, k=K, dim=2, largest=False, sorted=True) # BxMxK 338 | 339 | # debug 340 | # print(knn_D[0]) 341 | # print(knn_I[0]) 342 | # assert False 343 | 344 | # get gpu_id 345 | device_index = x.device.index 346 | neighbors = operations.knn_gather_wrapper(coordinate_tensor, knn_I) # Bx3xMxK 347 | if center_type == 'avg': 348 | neighbors_center = torch.mean(neighbors, dim=3, keepdim=True) # Bx3xMx1 349 | elif center_type == 'center': 350 | neighbors_center = coordinate_tensor.unsqueeze(3) # Bx3xMx1 351 | neighbors_decentered = (neighbors - neighbors_center).detach() 352 | neighbors_center = neighbors_center.squeeze(3).detach() 353 | 354 | # debug 355 | # print(neighbors[0, 0]) 356 | # print(neighbors_avg[0, 0]) 357 | # print(neighbors_decentered[0, 0]) 358 | # assert False 359 | 360 | x_neighbors = operations.knn_gather_by_indexing(x, knn_I) # BxCxMxK 361 | x_augmented = torch.cat((neighbors_decentered, x_neighbors), dim=1) # Bx(3+C)xMxK 362 | 363 | for layer in self.layers: 364 | x_augmented = layer(x_augmented, epoch) 365 | feature, _ = torch.max(x_augmented, dim=3, keepdim=False) 366 | 367 | return neighbors_center, feature 368 | 369 | 370 | class PointNet(nn.Module): 371 | def __init__(self, in_channels, out_channels_list, activation, normalization, momentum=0.1, bn_momentum_decay_step=None, bn_momentum_decay=1): 372 | super(PointNet, self).__init__() 373 | 374 | self.layers = nn.ModuleList() 375 | previous_out_channels = in_channels 376 | for i, 
c_out in enumerate(out_channels_list): 377 | if i != len(out_channels_list)-1: 378 | self.layers.append(EquivariantLayer(previous_out_channels, c_out, activation, normalization, 379 | momentum, bn_momentum_decay_step, bn_momentum_decay)) 380 | else: 381 | self.layers.append(EquivariantLayer(previous_out_channels, c_out, None, None)) 382 | previous_out_channels = c_out 383 | 384 | def forward(self, x, epoch=None): 385 | for layer in self.layers: 386 | x = layer(x, epoch) 387 | return x 388 | 389 | 390 | class PointResNet(nn.Module): 391 | def __init__(self, in_channels, out_channels_list, activation, normalization, momentum=0.1, bn_momentum_decay_step=None, bn_momentum_decay=1): 392 | ''' 393 | in -> out[0] 394 | out[0] -> out[1] ---- 395 | out[1] -> out[2] | 396 | ... ... | 397 | out[k-2]+out[1] -> out[k-1] <--- 398 | :param in_channels: 399 | :param out_channels_list: 400 | :param activation: 401 | :param normalization: 402 | :param momentum: 403 | :param bn_momentum_decay_step: 404 | :param bn_momentum_decay: 405 | ''' 406 | super(PointResNet, self).__init__() 407 | self.out_channels_list = out_channels_list 408 | 409 | self.layers = nn.ModuleList() 410 | previous_out_channels = in_channels 411 | for i, c_out in enumerate(out_channels_list): 412 | if i != len(out_channels_list)-1: 413 | self.layers.append(EquivariantLayer(previous_out_channels, c_out, activation, normalization, 414 | momentum, bn_momentum_decay_step, bn_momentum_decay)) 415 | else: 416 | self.layers.append(EquivariantLayer(previous_out_channels+out_channels_list[0], c_out, None, None)) 417 | previous_out_channels = c_out 418 | 419 | def forward(self, x, epoch=None): 420 | ''' 421 | :param x: BxCxN 422 | :param epoch: None or number of epoch, for BN decay. 
423 |         :return: BxC_outxN, where C_out = out_channels_list[-1]
424 |         '''
425 |         layer0_out = self.layers[0](x, epoch)  # BxCxN
426 |         for l in range(1, len(self.out_channels_list)-1):
427 |             if l == 1:
428 |                 x_tmp = self.layers[l](layer0_out, epoch)
429 |             else:
430 |                 x_tmp = self.layers[l](x_tmp, epoch)
431 |         layer_final_out = self.layers[len(self.out_channels_list)-1](torch.cat((layer0_out, x_tmp), dim=1), epoch)
432 |         return layer_final_out
433 |
--------------------------------------------------------------------------------
/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import numpy as np
5 | import math
6 | import torch.utils.model_zoo as model_zoo
7 | import time
8 | import torch.nn.functional as F
9 | import faiss
10 |
11 | import json
12 | import os
13 | import os.path
14 | from collections import OrderedDict
15 |
16 |
17 | def robust_norm(var):
18 |     '''
19 |     :param var: Variable of BxCxHxW
20 |     :return: L2 norm taken over dim 2 (H), BxCxW; 1e-8 is added under the square root for numerical stability
21 |     '''
22 |     result = ((var**2).sum(dim=2) + 1e-8).sqrt()
23 |     # result = (var ** 2).sum(dim=2)
24 |
25 |     # try to make the points less dense, caused by the backward loss
26 |     # result = result.clamp(min=7e-3, max=None)
27 |     return result
28 |
29 |
30 | class CrossEntropyLossSeg(nn.Module):
31 |     def __init__(self, weight=None, size_average=True):
32 |         super(CrossEntropyLossSeg, self).__init__()
33 |         self.nll_loss = nn.NLLLoss(weight, size_average)
34 |
35 |     def forward(self, inputs, targets):
36 |         '''
37 |         :param inputs: BxclassxN
38 |         :param targets: BxN
39 |         :return:
40 |         '''
41 |         inputs = inputs.unsqueeze(3)
42 |         targets = targets.unsqueeze(2)
43 |         return self.nll_loss(F.log_softmax(inputs, dim=1), targets)
44 |
45 |
46 | def visualize_pc_seg(score, seg, label, visualizer, opt, input_pc, batch_num):
47 |     # display only one instance of pc/img
48 |     input_pc_np = input_pc.cpu().numpy().transpose()  # Nx3
49 |     pc_color_np = np.ones(input_pc_np.shape,
dtype=int) # Nx3 50 | gt_pc_color_np = np.ones(input_pc_np.shape, dtype=int) # Nx3 51 | 52 | # construct color map 53 | _, predicted_seg = torch.max(score, dim=0, keepdim=False) # 50xN -> N 54 | predicted_seg_np = predicted_seg.cpu().numpy() # N 55 | gt_seg_np = seg.cpu().numpy() # N 56 | 57 | color_map_file = os.path.join(opt.dataroot, 'part_color_mapping.json') 58 | color_map = json.load(open(color_map_file, 'r')) 59 | color_map_np = np.fabs((np.asarray(color_map) * 255)).astype(int) # 50x3 60 | 61 | for i in range(input_pc_np.shape[0]): 62 | pc_color_np[i] = color_map_np[predicted_seg_np[i]] 63 | gt_pc_color_np[i] = color_map_np[gt_seg_np[i]] 64 | if gt_seg_np[i] == 49: 65 | gt_pc_color_np[i] = np.asarray([1, 1, 1]).astype(int) 66 | 67 | dict = OrderedDict([('pc_colored_predicted', [input_pc_np, pc_color_np]), 68 | ('pc_colored_gt', [input_pc_np, gt_pc_color_np])]) 69 | 70 | visualizer.display_current_results(dict, 1, 1) 71 | 72 | 73 | def compute_iou_np_array(score, seg, label, visualizer, opt, input_pc): 74 | part_label = [ 75 | [0, 1, 2, 3], 76 | [4, 5], 77 | [6, 7], 78 | [8, 9, 10, 11], 79 | [12, 13, 14, 15], 80 | [16, 17, 18], 81 | [19, 20, 21], 82 | [22, 23], 83 | [24, 25, 26, 27], 84 | [28, 29], 85 | [30, 31, 32, 33, 34, 35], 86 | [36, 37], 87 | [38, 39, 40], 88 | [41, 42, 43], 89 | [44, 45, 46], 90 | [47, 48, 49] 91 | ] 92 | 93 | _, seg_predicted = torch.max(score, dim=1) # BxN 94 | 95 | iou_batch = [] 96 | for i in range(score.size()[0]): 97 | iou_pc = [] 98 | for part in part_label[label[i]]: 99 | gt = seg[i] == part 100 | predict = seg_predicted[i] == part 101 | 102 | intersection = (gt.int() + predict.int()) == 2 103 | union = (gt.int() + predict.int()) >= 1 104 | 105 | if union.sum() == 0: 106 | iou_part = 1.0 107 | else: 108 | iou_part = intersection.int().sum().item() / (union.int().sum().item() + 0.0001) 109 | 110 | iou_pc.append(iou_part) 111 | 112 | iou_batch.append(np.asarray(iou_pc).mean()) 113 | 114 | iou_np = np.asarray(iou_batch) 115 | 116 
| return iou_np 117 | 118 | 119 | def compute_iou(score, seg, label, visualizer, opt, input_pc): 120 | ''' 121 | :param score: BxCxN tensor 122 | :param seg: BxN tensor 123 | :return: 124 | ''' 125 | 126 | part_label = [ 127 | [0, 1, 2, 3], 128 | [4, 5], 129 | [6, 7], 130 | [8, 9, 10, 11], 131 | [12, 13, 14, 15], 132 | [16, 17, 18], 133 | [19, 20, 21], 134 | [22, 23], 135 | [24, 25, 26, 27], 136 | [28, 29], 137 | [30, 31, 32, 33, 34, 35], 138 | [36, 37], 139 | [38, 39, 40], 140 | [41, 42, 43], 141 | [44, 45, 46], 142 | [47, 48, 49] 143 | ] 144 | 145 | _, seg_predicted = torch.max(score, dim=1) # BxN 146 | 147 | iou_batch = [] 148 | vis_flag = False 149 | for i in range(score.size()[0]): 150 | iou_pc = [] 151 | for part in part_label[label[i]]: 152 | gt = seg[i] == part 153 | predict = seg_predicted[i] == part 154 | 155 | intersection = (gt.int() + predict.int()) == 2 156 | union = (gt.int() + predict.int()) >= 1 157 | 158 | # print(intersection) 159 | # print(union) 160 | # assert False 161 | 162 | if union.sum() == 0: 163 | iou_part = 1.0 164 | else: 165 | iou_part = intersection.int().sum().item() / (union.int().sum().item() + 0.0001) 166 | 167 | # debug to see what happened 168 | # if iou_part < 0.1: 169 | # print(part) 170 | # print('predict:') 171 | # print(predict.nonzero()) 172 | # print('gt') 173 | # print(gt.nonzero()) 174 | # vis_flag = True 175 | 176 | iou_pc.append(iou_part) 177 | 178 | # debug to see what happened 179 | if vis_flag: 180 | print('============') 181 | print(iou_pc) 182 | print(label[i]) 183 | visualize_pc_seg(score[i], seg[i], label[i], visualizer, opt, input_pc[i], i) 184 | 185 | iou_batch.append(np.asarray(iou_pc).mean()) 186 | 187 | iou = np.asarray(iou_batch).mean() 188 | 189 | return iou 190 | 191 | 192 | class ChamferLoss(nn.Module): 193 | def __init__(self, opt): 194 | super(ChamferLoss, self).__init__() 195 | self.opt = opt 196 | self.dimension = 3 197 | self.k = 1 198 | 199 | # we need only a StandardGpuResources per GPU 200 | 
        self.res = faiss.StandardGpuResources()
201 |         self.res.setTempMemoryFraction(0.1)
202 |         self.flat_config = faiss.GpuIndexFlatConfig()
203 |         self.flat_config.device = opt.gpu_id
204 |
205 |         # placeholder
206 |         self.forward_loss = torch.FloatTensor([0])
207 |         self.backward_loss = torch.FloatTensor([0])
208 |
209 |     def build_nn_index(self, database):
210 |         '''
211 |         :param database: numpy array of Nx3
212 |         :return: Faiss flat L2 index, built on the CPU and then moved to the GPU given by opt.gpu_id
213 |         '''
214 |         # index = faiss.GpuIndexFlatL2(self.res, self.dimension, self.flat_config)  # dimension is 3
215 |         index_cpu = faiss.IndexFlatL2(self.dimension)
216 |         index = faiss.index_cpu_to_gpu(self.res, self.opt.gpu_id, index_cpu)
217 |         index.add(database)
218 |         return index
219 |
220 |     def search_nn(self, index, query, k):
221 |         '''
222 |         :param index: Faiss index
223 |         :param query: numpy array of Nx3
224 |         :return: D: Variable of Nxk, type FloatTensor, on opt.device when opt.gpu_id >= 0
225 |                  I: Variable of Nxk, type LongTensor, on opt.device when opt.gpu_id >= 0
226 |         '''
227 |         D, I = index.search(query, k)
228 |
229 |         D_var = torch.from_numpy(np.ascontiguousarray(D))
230 |         I_var = torch.from_numpy(np.ascontiguousarray(I).astype(np.int64))
231 |         if self.opt.gpu_id >= 0:
232 |             D_var = D_var.to(self.opt.device)
233 |             I_var = I_var.to(self.opt.device)
234 |
235 |         return D_var, I_var
236 |
237 |     def forward(self, predict_pc, gt_pc):
238 |         '''
239 |         :param predict_pc: Bx3xM Variable in GPU
240 |         :param gt_pc: Bx3xN Variable in GPU
241 |         :return: scalar loss, the sum of the forward (predict -> gt) and backward (gt -> predict) terms
242 |         '''
243 |
244 |         predict_pc_size = predict_pc.size()
245 |         gt_pc_size = gt_pc.size()
246 |
247 |         predict_pc_np = np.ascontiguousarray(torch.transpose(predict_pc.data.clone(), 1, 2).cpu().numpy())  # BxMx3
248 |         gt_pc_np = np.ascontiguousarray(torch.transpose(gt_pc.data.clone(), 1, 2).cpu().numpy())  # BxNx3
249 |
250 |         # selected_gt: Bxkx3xM
251 |         selected_gt_by_predict = torch.FloatTensor(predict_pc_size[0], self.k, predict_pc_size[1], predict_pc_size[2])
252 |         # selected_predict: Bxkx3xN
253 |         selected_predict_by_gt =
torch.FloatTensor(gt_pc_size[0], self.k, gt_pc_size[1], gt_pc_size[2]) 254 | 255 | if self.opt.gpu_id >= 0: 256 | selected_gt_by_predict = selected_gt_by_predict.to(self.opt.device) 257 | selected_predict_by_gt = selected_predict_by_gt.to(self.opt.device) 258 | 259 | # process each batch independently. 260 | for i in range(predict_pc_np.shape[0]): 261 | index_predict = self.build_nn_index(predict_pc_np[i]) 262 | index_gt = self.build_nn_index(gt_pc_np[i]) 263 | 264 | # database is gt_pc, predict_pc -> gt_pc ----------------------------------------------------------- 265 | _, I_var = self.search_nn(index_gt, predict_pc_np[i], self.k) 266 | 267 | # process nearest k neighbors 268 | for k in range(self.k): 269 | selected_gt_by_predict[i,k,...] = gt_pc[i].index_select(1, I_var[:,k]) 270 | 271 | # database is predict_pc, gt_pc -> predict_pc ------------------------------------------------------- 272 | _, I_var = self.search_nn(index_predict, gt_pc_np[i], self.k) 273 | 274 | # process nearest k neighbors 275 | for k in range(self.k): 276 | selected_predict_by_gt[i,k,...] 
= predict_pc[i].index_select(1, I_var[:,k]) 277 | 278 | # compute loss =================================================== 279 | # selected_gt(Bxkx3xM) vs predict_pc(Bx3xM) 280 | forward_loss_element = robust_norm(selected_gt_by_predict-predict_pc.unsqueeze(1).expand_as(selected_gt_by_predict)) 281 | self.forward_loss = forward_loss_element.mean() 282 | self.forward_loss_array = forward_loss_element.mean(dim=1).mean(dim=1) 283 | 284 | # selected_predict(Bxkx3xN) vs gt_pc(Bx3xN) 285 | backward_loss_element = robust_norm(selected_predict_by_gt - gt_pc.unsqueeze(1).expand_as(selected_predict_by_gt)) # BxkxN 286 | self.backward_loss = backward_loss_element.mean() 287 | self.backward_loss_array = backward_loss_element.mean(dim=1).mean(dim=1) 288 | 289 | self.loss_array = self.forward_loss_array + self.backward_loss_array 290 | return self.forward_loss + self.backward_loss # + self.sparsity_loss 291 | 292 | def __call__(self, predict_pc, gt_pc): 293 | # start_time = time.time() 294 | loss = self.forward(predict_pc, gt_pc) 295 | # print(time.time()-start_time) 296 | return loss -------------------------------------------------------------------------------- /models/operations.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import torch.multiprocessing as mp 9 | import threading 10 | import ctypes 11 | 12 | 13 | # generalized batch size 14 | CUDA_SHARED_MEM_DIM_X = 24 15 | # size of SOM 16 | CUDA_SHARED_MEM_DIM_Y = 512 17 | 18 | 19 | def knn_gather_wrapper(som_node, som_node_knn_I): 20 | ''' 21 | 22 | :param som_node: Bx3xN 23 | :param som_node_knn_I: BxNxK 24 | :param som_node_neighbors: Bx3xNxK 25 | :return: 26 | ''' 27 | B = som_node.size()[0] 28 | C = som_node.size()[1] 29 | N = som_node.size()[2] 30 | K = som_node_knn_I.size()[2] 31 | assert C==3 or C==2 32 | 33 | som_node_neighbors = 
knn_gather_by_indexing(som_node, som_node_knn_I) 34 | 35 | return som_node_neighbors 36 | 37 | 38 | def knn_gather_by_indexing(som_node, som_node_knn_I): 39 | ''' 40 | 41 | :param som_node: BxCxN 42 | :param som_node_knn_I: BxNxK 43 | :param som_node_neighbors: BxCxNxK 44 | :return: 45 | ''' 46 | B = som_node.size()[0] 47 | C = som_node.size()[1] 48 | N = som_node.size()[2] 49 | K = som_node_knn_I.size()[2] 50 | 51 | som_node_knn_I = som_node_knn_I.unsqueeze(1).expand(B, C, N, K).contiguous().view(B, C, N*K) 52 | som_node_neighbors = torch.gather(som_node, dim=2, index=som_node_knn_I).view(B, C, N, K) 53 | 54 | return som_node_neighbors 55 | -------------------------------------------------------------------------------- /models/segmenter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | import math 6 | from collections import OrderedDict 7 | import os 8 | import os.path 9 | import json 10 | 11 | from . import networks 12 | from . 
import losses 13 | 14 | class Model(): 15 | def __init__(self, opt): 16 | self.opt = opt 17 | 18 | self.encoder = networks.Encoder(opt) 19 | self.segmenter = networks.Segmenter(opt) 20 | 21 | self.softmax_segmenter = losses.CrossEntropyLossSeg() 22 | if self.opt.gpu_id >= 0: 23 | self.encoder = self.encoder.to(self.opt.device) 24 | self.segmenter = self.segmenter.to(self.opt.device) 25 | self.softmax_segmenter = self.softmax_segmenter.to(self.opt.device) 26 | 27 | # learning rate_control 28 | if self.opt.pretrain is not None: 29 | self.old_lr_encoder = self.opt.lr * self.opt.pretrain_lr_ratio 30 | else: 31 | self.old_lr_encoder = self.opt.lr 32 | self.old_lr_segmenter = self.opt.lr 33 | 34 | self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), 35 | lr=self.old_lr_encoder, 36 | betas=(0.9, 0.999), 37 | weight_decay=0) 38 | self.optimizer_segmenter = torch.optim.Adam(self.segmenter.parameters(), 39 | lr=self.old_lr_segmenter, 40 | betas=(0.9, 0.999), 41 | weight_decay=0) 42 | 43 | # place holder for GPU tensors 44 | self.input_pc = torch.FloatTensor(self.opt.batch_size, 3, self.opt.input_pc_num).uniform_() 45 | self.input_sn = torch.FloatTensor(self.opt.batch_size, 3, self.opt.input_pc_num).uniform_() 46 | self.input_label = torch.LongTensor(self.opt.batch_size).fill_(1) 47 | self.input_seg = torch.LongTensor(self.opt.batch_size, 50).fill_(1) 48 | self.input_node = torch.FloatTensor(self.opt.batch_size, 3, self.opt.node_num) 49 | self.input_node_knn_I = torch.LongTensor(self.opt.batch_size, self.opt.node_num, self.opt.som_k) 50 | 51 | # record the test loss and accuracy 52 | self.test_loss_segmenter = torch.FloatTensor([0]) 53 | self.test_accuracy_segmenter = torch.FloatTensor([0]) 54 | self.test_iou = torch.FloatTensor([0]) 55 | 56 | if self.opt.gpu_id >= 0: 57 | self.input_pc = self.input_pc.to(self.opt.device) 58 | self.input_sn = self.input_sn.to(self.opt.device) 59 | self.input_label = self.input_label.to(self.opt.device) 60 | self.input_seg = 
self.input_seg.to(self.opt.device) 61 | self.input_node = self.input_node.to(self.opt.device) 62 | self.input_node_knn_I = self.input_node_knn_I.to(self.opt.device) 63 | self.test_loss_segmenter = self.test_loss_segmenter.to(self.opt.device) 64 | self.test_accuracy_segmenter = self.test_accuracy_segmenter.to(self.opt.device) 65 | 66 | 67 | def set_input(self, input_pc, input_sn, input_label, input_seg, input_node, input_node_knn_I): 68 | self.input_pc.resize_(input_pc.size()).copy_(input_pc) 69 | self.input_sn.resize_(input_sn.size()).copy_(input_sn) 70 | self.input_label.resize_(input_label.size()).copy_(input_label) 71 | self.input_seg.resize_(input_seg.size()).copy_(input_seg) 72 | self.input_node.resize_(input_node.size()).copy_(input_node) 73 | self.input_node_knn_I.resize_(input_node_knn_I.size()).copy_(input_node_knn_I) 74 | self.pc = self.input_pc.detach() 75 | self.sn = self.input_sn.detach() 76 | self.seg = self.input_seg.detach() 77 | self.label = self.input_label.detach() 78 | 79 | def forward(self, is_train=False, epoch=None): 80 | # ------------------------------------------------------------------ 81 | self.feature = self.encoder(self.pc, self.sn, self.input_node, self.input_node_knn_I, is_train, epoch) 82 | 83 | batch_size = self.feature.size()[0] 84 | feature_num = self.feature.size()[1] 85 | N = self.pc.size()[2] 86 | 87 | # ------------------------------------------------------------------ 88 | k = self.opt.k 89 | # BxkNxnode_num -> BxkN, tensor 90 | _, mask_max_idx = torch.max(self.encoder.mask, dim=2, keepdim=False) # BxkN 91 | mask_max_idx = mask_max_idx.unsqueeze(1) # Bx1xkN 92 | mask_max_idx_384 = mask_max_idx.expand(batch_size, 384, k*N).detach() 93 | mask_max_idx_512 = mask_max_idx.expand(batch_size, 512, k*N).detach() 94 | mask_max_idx_fn = mask_max_idx.expand(batch_size, feature_num, k * N).detach() 95 | 96 | feature_max_first_pn_out = torch.gather(self.encoder.first_pn_out_masked_max , dim=2, index=mask_max_idx_384) # Bx384xnode_num -> 
Bx384xkN 97 | feature_max_knn_feature_1 = torch.gather(self.encoder.knn_feature_1, dim=2, index=mask_max_idx_512) # Bx512xnode_num -> Bx512xkN 98 | feature_max_final_pn_out = torch.gather(self.encoder.final_pn_out, dim=2, index=mask_max_idx_fn) # Bx1024xnode_num -> Bx1024xkN 99 | 100 | self.score_segmenter = self.segmenter(self.encoder.x_decentered, 101 | self.pc, 102 | self.encoder.centers, 103 | self.sn, 104 | self.input_label, 105 | self.encoder.first_pn_out, 106 | feature_max_first_pn_out, 107 | feature_max_knn_feature_1, 108 | feature_max_final_pn_out, 109 | self.feature) 110 | 111 | def optimize(self, epoch=None): 112 | self.encoder.train() 113 | self.segmenter.train() 114 | self.forward(is_train=True, epoch=epoch) 115 | 116 | self.encoder.zero_grad() 117 | self.segmenter.zero_grad() 118 | 119 | self.loss_segmenter = self.softmax_segmenter(self.score_segmenter, self.seg) 120 | self.loss_segmenter.backward() 121 | 122 | self.optimizer_encoder.step() 123 | self.optimizer_segmenter.step() 124 | 125 | def test_model(self): 126 | self.encoder.eval() 127 | self.segmenter.eval() 128 | self.forward(is_train=False) 129 | 130 | # self.loss_classifier = self.softmax_classifier(self.score_classifier, self.label) 131 | self.loss_segmenter = self.softmax_segmenter(self.score_segmenter, self.seg) 132 | self.loss = self.loss_segmenter 133 | 134 | # visualization with visdom 135 | def get_current_visuals(self): 136 | # display only one instance of pc/img 137 | input_pc_np = self.input_pc[0].cpu().numpy().transpose() # Nx3 138 | pc_color_np = np.zeros(input_pc_np.shape) # Nx3 139 | gt_pc_color_np = np.zeros(input_pc_np.shape) # Nx3 140 | 141 | # construct color map 142 | _, predicted_seg = torch.max(self.score_segmenter.data[0], dim=0, keepdim=False) # 50xN -> N 143 | predicted_seg_np = predicted_seg.cpu().numpy() # N 144 | gt_seg_np = self.seg.data[0].cpu().numpy() # N 145 | 146 | color_map_file = os.path.join(self.opt.dataroot, 'part_color_mapping.json') 147 | color_map = 
json.load(open(color_map_file, 'r')) 148 | color_map_np = np.rint(np.asarray(color_map)*255).astype(np.int32) # 50x3, round before casting to int 149 | 150 | for i in range(input_pc_np.shape[0]): 151 | pc_color_np[i] = color_map_np[predicted_seg_np[i]] 152 | gt_pc_color_np[i] = color_map_np[gt_seg_np[i]] 153 | 154 | return OrderedDict([('pc_colored_predicted', [input_pc_np, pc_color_np]), 155 | ('pc_colored_gt', [input_pc_np, gt_pc_color_np])]) 156 | 157 | def get_current_errors(self): 158 | # self.score_segmenter: BxclassesxN 159 | _, predicted_seg = torch.max(self.score_segmenter.data, dim=1, keepdim=False) 160 | correct_mask = torch.eq(predicted_seg, self.input_seg).float() 161 | train_accuracy_segmenter = torch.mean(correct_mask) 162 | 163 | return OrderedDict([ 164 | ('train_loss_seg', self.loss_segmenter.item()), 165 | ('train_accuracy_seg', train_accuracy_segmenter.item()), 166 | ('test_loss_seg', self.test_loss_segmenter.item()), 167 | ('test_acc_seg', self.test_accuracy_segmenter.item()), 168 | ('test_iou', self.test_iou.item()) 169 | ]) 170 | 171 | def save_network(self, network, network_label, epoch_label, gpu_id): 172 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 173 | save_path = os.path.join(self.opt.checkpoints_dir, save_filename) 174 | torch.save(network.cpu().state_dict(), save_path) 175 | if gpu_id >= 0 and torch.cuda.is_available(): 176 | # torch.cuda.device(gpu_id) 177 | network.to(self.opt.device) 178 | 179 | def update_learning_rate(self, ratio): 180 | # encoder 181 | lr_encoder = self.old_lr_encoder * ratio 182 | for param_group in self.optimizer_encoder.param_groups: 183 | param_group['lr'] = lr_encoder 184 | print('update encoder learning rate: %f -> %f' % (self.old_lr_encoder, lr_encoder)) 185 | self.old_lr_encoder = lr_encoder 186 | 187 | # segmentation 188 | lr_segmenter = self.old_lr_segmenter * ratio 189 | for param_group in self.optimizer_segmenter.param_groups: 190 | param_group['lr'] = lr_segmenter 191 | print('update segmenter learning rate: %f 
-> %f' % (self.old_lr_segmenter, lr_segmenter)) 192 | self.old_lr_segmenter = lr_segmenter 193 | -------------------------------------------------------------------------------- /part-seg/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | class Options(): 7 | def __init__(self): 8 | self.parser = argparse.ArgumentParser() 9 | self.initialized = False 10 | 11 | def initialize(self): 12 | self.parser.add_argument('--gpu_id', type=int, default=0, help='gpu id: e.g. 0, 1. -1 is no GPU') 13 | 14 | self.parser.add_argument('--dataset', type=str, default='shapenet', help='shapenet') 15 | self.parser.add_argument('--dataroot', default='/ssd/dataset/shapenetcore_partanno_segmentation_benchmark_v0_normal/', help='path to the dataset root') 16 | self.parser.add_argument('--classes', type=int, default=50, help='number of part labels, 50 for ShapeNet part segmentation') 17 | self.parser.add_argument('--name', type=str, default='train', help='name of the experiment. 
It decides where to store samples and models') 18 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 19 | 20 | self.parser.add_argument('--batch_size', type=int, default=8, help='input batch size') 21 | self.parser.add_argument('--input_pc_num', type=int, default=1024, help='# of input points') 22 | self.parser.add_argument('--surface_normal', type=bool, default=True, help='use surface normal in the pc input') 23 | self.parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data') 24 | 25 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 26 | self.parser.add_argument('--display_id', type=int, default=200, help='window id of the web display') 27 | 28 | self.parser.add_argument('--feature_num', type=int, default=1024, help='length of encoded feature') 29 | self.parser.add_argument('--activation', type=str, default='relu', help='activation function: relu, elu') 30 | self.parser.add_argument('--normalization', type=str, default='batch', help='normalization function: batch, instance') 31 | 32 | self.parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 33 | self.parser.add_argument('--dropout', type=float, default=0.6, help='probability of an element to be zeroed') 34 | self.parser.add_argument('--node_num', type=int, default=64, help='som node number') 35 | self.parser.add_argument('--k', type=int, default=3, help='k nearest neighbor') 36 | # '/ssd/open-source/so-net-full/autoencoder/checkpoints/save/shapenetpart/183_0.034180_net_encoder.pth' 37 | self.parser.add_argument('--pretrain', type=str, default=None, help='pre-trained encoder dict path') 38 | self.parser.add_argument('--pretrain_lr_ratio', type=float, default=1, help='learning rate ratio between pretrained encoder and classifier') 39 | 40 | self.parser.add_argument('--som_k', type=int, default=9, help='k nearest neighbor of SOM nodes searching on SOM 
nodes') 41 | self.parser.add_argument('--som_k_type', type=str, default='center', help='avg / center') 42 | 43 | self.parser.add_argument('--random_pc_dropout_lower_limit', type=float, default=1, help='keep ratio lower limit') 44 | self.parser.add_argument('--bn_momentum', type=float, default=0.1, help='normalization momentum, typically 0.1. Equal to (1-m) in TF') 45 | self.parser.add_argument('--bn_momentum_decay_step', type=int, default=None, help='BN momentum decay step, in epochs. None disables the decay.') 46 | self.parser.add_argument('--bn_momentum_decay', type=float, default=0.6, help='BN momentum decay rate: the factor multiplied into the momentum at every decay step.') 47 | 48 | 49 | self.initialized = True 50 | 51 | def parse(self): 52 | if not self.initialized: 53 | self.initialize() 54 | self.opt = self.parser.parse_args() 55 | 56 | self.opt.device = torch.device("cuda:%d" % (self.opt.gpu_id) if (self.opt.gpu_id >= 0 and torch.cuda.is_available()) else "cpu") # gpu_id -1 selects the CPU 57 | # torch.cuda.set_device(self.opt.gpu_id) 58 | 59 | args = vars(self.opt) 60 | 61 | print('------------ Options -------------') 62 | for k, v in sorted(args.items()): 63 | print('%s: %s' % (str(k), str(v))) 64 | print('-------------- End ----------------') 65 | 66 | # save to the disk 67 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 68 | util.mkdirs(expr_dir) 69 | file_name = os.path.join(expr_dir, 'opt.txt') 70 | with open(file_name, 'wt') as opt_file: 71 | opt_file.write('------------ Options -------------\n') 72 | for k, v in sorted(args.items()): 73 | opt_file.write('%s: %s\n' % (str(k), str(v))) 74 | opt_file.write('-------------- End ----------------\n') 75 | return self.opt 76 | -------------------------------------------------------------------------------- /part-seg/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | import numpy as np 4 | import math 5 | 6 | from options import Options 7 | opt = Options().parse() # set CUDA_VISIBLE_DEVICES before import torch 8 | 9 | import torch 
10 | import torchvision 11 | from torch.autograd import Variable 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | import random 16 | import numpy as np 17 | 18 | from models import losses 19 | from models.segmenter import Model 20 | from data.shapenet_loader import ShapeNetLoader 21 | from util.visualizer import Visualizer 22 | 23 | 24 | if __name__=='__main__': 25 | trainset = ShapeNetLoader(opt.dataroot, 'train', opt) 26 | dataset_size = len(trainset) 27 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.nThreads) 28 | print('#training point clouds = %d' % len(trainset)) 29 | 30 | testset = ShapeNetLoader(opt.dataroot, 'test', opt) 31 | testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.nThreads) 32 | 33 | visualizer = Visualizer(opt) 34 | 35 | # create model, optionally load pre-trained model 36 | model = Model(opt) 37 | if opt.pretrain is not None: 38 | model.encoder.load_state_dict(torch.load(opt.pretrain)) 39 | 40 | # load pre-trained model 41 | # folder = 'checkpoints/' 42 | # model_epoch = '2' 43 | # model_acc = '0.914946' 44 | # model.encoder.load_state_dict(torch.load(folder + model_epoch + '_' + model_acc + '_net_encoder.pth')) 45 | # model.segmenter.load_state_dict(torch.load(folder + model_epoch + '_' + model_acc + '_net_segmenter.pth')) 46 | 47 | best_iou = 0 48 | for epoch in range(601): 49 | 50 | epoch_iter = 0 51 | for i, data in enumerate(trainloader): 52 | iter_start_time = time.time() 53 | epoch_iter += opt.batch_size 54 | 55 | input_pc, input_sn, input_label, input_seg, input_node, input_node_knn_I = data 56 | model.set_input(input_pc, input_sn, input_label, input_seg, input_node, input_node_knn_I) 57 | 58 | model.optimize() 59 | 60 | if i % 100 == 0: 61 | # print/plot errors 62 | t = (time.time() - iter_start_time) / opt.batch_size 63 | 64 | errors = model.get_current_errors() 
65 | 66 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 67 | visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors) 68 | 69 | # print(model.autoencoder.encoder.feature) 70 | # visuals = model.get_current_visuals() 71 | # visualizer.display_current_results(visuals, epoch, i) 72 | 73 | # test network 74 | if epoch >= 0 and epoch%1==0: 75 | batch_amount = 0 76 | model.test_loss_segmenter.data.zero_() 77 | model.test_accuracy_segmenter.data.zero_() 78 | model.test_iou.data.zero_() 79 | for i, data in enumerate(testloader): 80 | input_pc, input_sn, input_label, input_seg, input_node, input_node_knn_I = data 81 | model.set_input(input_pc, input_sn, input_label, input_seg, input_node, input_node_knn_I) 82 | model.test_model() 83 | 84 | batch_amount += input_label.size()[0] 85 | 86 | # # accumulate loss 87 | model.test_loss_segmenter += model.loss_segmenter.detach() * input_label.size()[0] 88 | 89 | _, predicted_seg = torch.max(model.score_segmenter.data, dim=1, keepdim=False) 90 | correct_mask = torch.eq(predicted_seg, model.input_seg).float() 91 | test_accuracy_segmenter = torch.mean(correct_mask) 92 | model.test_accuracy_segmenter += test_accuracy_segmenter * input_label.size()[0] 93 | 94 | # segmentation iou 95 | test_iou_batch = losses.compute_iou(model.score_segmenter.cpu().data, model.input_seg.cpu().data, model.input_label.cpu().data, visualizer, opt, input_pc.cpu().data) 96 | model.test_iou += test_iou_batch * input_label.size()[0] 97 | 98 | # print(test_iou_batch) 99 | # print(model.score_segmenter.size()) 100 | 101 | print(batch_amount) 102 | model.test_loss_segmenter /= batch_amount 103 | model.test_accuracy_segmenter /= batch_amount 104 | model.test_iou /= batch_amount 105 | if model.test_iou.item() > best_iou: 106 | best_iou = model.test_iou.item() 107 | print('Tested network. 
So far best segmentation: %f' % (best_iou) ) 108 | 109 | # save network 110 | if model.test_iou.item() > 0.835: 111 | print("Saving network...") 112 | model.save_network(model.encoder, 'encoder', '%d_%f' % (epoch, model.test_iou.item()), opt.gpu_id) 113 | model.save_network(model.segmenter, 'segmenter', '%d_%f' % (epoch, model.test_iou.item()), opt.gpu_id) 114 | 115 | # learning rate decay 116 | if epoch%30==0 and epoch>0: 117 | model.update_learning_rate(0.5) 118 | 119 | # save network 120 | # if epoch%20==0 and epoch>0: 121 | # print("Saving network...") 122 | # model.save_network(model.classifier, 'cls', '%d' % epoch, opt.gpu_id) 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /shrec16/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | 7 | class Options(): 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser() 10 | self.initialized = False 11 | 12 | def initialize(self): 13 | self.parser.add_argument('--gpu_id', type=int, default=0, help='gpu id: e.g. 0, 1, 2. -1 is no GPU') 14 | 15 | self.parser.add_argument('--dataset', type=str, default='shrec', help='modelnet / shrec') 16 | self.parser.add_argument('--dataroot', default='/ssd/jiaxin/datasets/SHREC2016/', help='path to the dataset root') 17 | self.parser.add_argument('--classes', type=int, default=55, help='number of classes: 55 for SHREC16, 40 or 10 for ModelNet') 18 | self.parser.add_argument('--name', type=str, default='train', help='name of the experiment. 
It decides where to store samples and models') 19 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 20 | 21 | self.parser.add_argument('--batch_size', type=int, default=8, help='input batch size') 22 | self.parser.add_argument('--input_pc_num', type=int, default=5000, help='# of input points') 23 | self.parser.add_argument('--surface_normal', type=bool, default=True, help='use surface normal in the pc input') 24 | self.parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data') 25 | 26 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 27 | self.parser.add_argument('--display_id', type=int, default=200, help='window id of the web display') 28 | 29 | self.parser.add_argument('--feature_num', type=int, default=1024, help='length of encoded feature') 30 | self.parser.add_argument('--activation', type=str, default='relu', help='activation function: relu, elu') 31 | self.parser.add_argument('--normalization', type=str, default='batch', help='normalization function: batch, instance') 32 | 33 | self.parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 34 | self.parser.add_argument('--dropout', type=float, default=0.6, help='probability of an element to be zeroed') 35 | self.parser.add_argument('--node_num', type=int, default=64, help='som node number') 36 | self.parser.add_argument('--k', type=int, default=3, help='k nearest neighbor') 37 | self.parser.add_argument('--pretrain', type=str, default=None, help='pre-trained encoder dict path') 38 | self.parser.add_argument('--pretrain_lr_ratio', type=float, default=1, help='learning rate ratio between pretrained encoder and classifier') 39 | 40 | self.parser.add_argument('--som_k', type=int, default=0, help='k nearest neighbor of SOM nodes searching on SOM nodes') 41 | self.parser.add_argument('--som_k_type', type=str, default='avg', help='avg / center') 42 | 43 | 
self.parser.add_argument('--random_pc_dropout_lower_limit', type=float, default=1, help='keep ratio lower limit') 44 | self.parser.add_argument('--bn_momentum', type=float, default=0.1, help='normalization momentum, typically 0.1. Equal to (1-m) in TF') 45 | self.parser.add_argument('--bn_momentum_decay_step', type=int, default=None, help='BN momentum decay step, in epochs. None disables the decay.') 46 | self.parser.add_argument('--bn_momentum_decay', type=float, default=0.6, help='BN momentum decay rate: the factor multiplied into the momentum at every decay step.') 47 | 48 | self.parser.add_argument('--rot_horizontal', type=bool, default=False, help='Rotation augmentation around vertical axis.') 49 | self.parser.add_argument('--rot_perturbation', type=bool, default=False, help='Small rotation augmentation around 3 axes.') 50 | self.parser.add_argument('--translation_perturbation', type=bool, default=False, help='Small translation augmentation along 3 axes.') 51 | 52 | 53 | 54 | self.initialized = True 55 | 56 | def parse(self): 57 | if not self.initialized: 58 | self.initialize() 59 | self.opt = self.parser.parse_args() 60 | 61 | self.opt.device = torch.device("cuda:%d" % (self.opt.gpu_id) if (self.opt.gpu_id >= 0 and torch.cuda.is_available()) else "cpu") # gpu_id -1 selects the CPU 62 | # torch.cuda.set_device(self.opt.gpu_id) 63 | 64 | args = vars(self.opt) 65 | 66 | print('------------ Options -------------') 67 | for k, v in sorted(args.items()): 68 | print('%s: %s' % (str(k), str(v))) 69 | print('-------------- End ----------------') 70 | 71 | # save to the disk 72 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 73 | util.mkdirs(expr_dir) 74 | file_name = os.path.join(expr_dir, 'opt.txt') 75 | with open(file_name, 'wt') as opt_file: 76 | opt_file.write('------------ Options -------------\n') 77 | for k, v in sorted(args.items()): 78 | opt_file.write('%s: %s\n' % (str(k), str(v))) 79 | opt_file.write('-------------- End ----------------\n') 80 | return self.opt 81 | --------------------------------------------------------------------------------
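Note on the boolean options in shrec16/options.py (`--surface_normal`, `--rot_horizontal`, `--rot_perturbation`, `--translation_perturbation`): they are declared with `type=bool`, and argparse passes the raw command-line string to `bool()`, so any non-empty value, including `False`, parses as `True`. Below is a minimal sketch of one common workaround. `str2bool` and `build_parser` are hypothetical helpers introduced only for illustration; they are not part of this repository.

```python
import argparse


def str2bool(value):
    """Hypothetical helper: map common true/false spellings to a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)


def build_parser():
    """Hypothetical parser holding one flag named after an option above."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--rot_horizontal', type=str2bool, default=False,
                        help='Rotation augmentation around vertical axis.')
    return parser


if __name__ == '__main__':
    parser = build_parser()
    # With type=bool the first call would yield True; str2bool yields False.
    print(parser.parse_args(['--rot_horizontal', 'False']).rot_horizontal)
    print(parser.parse_args(['--rot_horizontal', 'true']).rot_horizontal)
    print(parser.parse_args([]).rot_horizontal)
```

Until something like this is adopted, the safest way to keep a boolean option off is to omit the flag and rely on its default.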
/shrec16/test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | import numpy as np 4 | import math 5 | 6 | from shrec16.options import Options 7 | opt = Options().parse() # set CUDA_VISIBLE_DEVICES before import torch 8 | 9 | import torch 10 | import torchvision 11 | from torch.autograd import Variable 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | import random 16 | import numpy as np 17 | import os 18 | 19 | from models.classifier import Model 20 | from data.modelnet_shrec_loader import ModelNet_Shrec_Loader 21 | from util.visualizer import Visualizer 22 | 23 | 24 | if __name__=='__main__': 25 | testset = ModelNet_Shrec_Loader(opt.dataroot, 'test', opt) 26 | testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.nThreads) 27 | print('#testing point clouds = %d' % len(testset)) 28 | 29 | # create model, optionally load pre-trained model 30 | model = Model(opt) 31 | model.encoder.load_state_dict(torch.load('/ssd/jiaxin/SO-Net/shrec16/checkpoints/0_0.748621_net_encoder.pth')) 32 | model.classifier.load_state_dict(torch.load('/ssd/jiaxin/SO-Net/shrec16/checkpoints/0_0.748621_net_classifier.pth')) 33 | output_folder = '/ssd/tmp/retrieval' 34 | 35 | model.encoder.eval() 36 | model.classifier.eval() 37 | 38 | visualizer = Visualizer(opt) 39 | softmax_layer = nn.Softmax2d() 40 | 41 | batch_amount = 0 42 | feature_map = torch.FloatTensor(len(testset), 55).cuda().zero_() # Nx55 43 | predicted_labels = torch.LongTensor(len(testset)).cuda().zero_() # N 44 | model_name_ids = torch.LongTensor(len(testset)).cuda().zero_() # N 45 | for i, data in enumerate(testloader): 46 | input_pc, input_sn, input_label, input_node, input_node_knn_I, input_model_name_id = data 47 | # input_pc, input_sn, input_model_name_id, input_node = data 48 | model.set_input(input_pc, input_sn, input_model_name_id, 
input_node,input_node_knn_I) 49 | model.forward() 50 | 51 | batch_size = input_model_name_id.size()[0] 52 | 53 | # feature_map[batch_amount:batch_amount+batch_size] = softmax_layer(model.score.unsqueeze(2).unsqueeze(3)).squeeze(3).squeeze(2).data 54 | feature_map[batch_amount:batch_amount+batch_size] = model.score.data 55 | 56 | _, predicted_idx = torch.max(model.score.data, dim=1, keepdim=False) 57 | predicted_labels[batch_amount:batch_amount+batch_size] = predicted_idx 58 | 59 | model_name_ids[batch_amount:batch_amount+batch_size] = input_model_name_id 60 | 61 | batch_amount += batch_size 62 | 63 | print(feature_map.size()) 64 | print(predicted_labels.size()) 65 | print(model_name_ids.size()) 66 | print(feature_map) 67 | 68 | # calculate neighbors 69 | for i in range(len(testset)): 70 | # find instances that have the same predicted label 71 | feature = feature_map[i] # 55 72 | label = predicted_labels[i] # scalar 73 | 74 | mask = torch.eq(predicted_labels, label) # N 75 | same_label_indices = torch.nonzero(mask).squeeze(1) # K 76 | 77 | # print(same_label_indices) 78 | 79 | feature_selected = feature_map[same_label_indices] # Kx55 80 | model_name_id_selected = model_name_ids[same_label_indices] # K 81 | 82 | distance = torch.norm(feature.unsqueeze(0)-feature_selected, p=2, dim=1) # Kx55 -> K 83 | sorted_distance, indices = torch.sort(distance) # avoid shadowing the built-in sorted() 84 | 85 | nn_model_name_id = model_name_id_selected[indices].cpu().numpy() 86 | nn_distance = sorted_distance.cpu().numpy() 87 | 88 | # print(indices) 89 | # print(nn_model_name_id) 90 | # print(nn_distance) 91 | 92 | # write to file 93 | model_name = '%06d' % model_name_ids[i].item() 94 | nn_result = np.transpose(np.vstack((nn_model_name_id, nn_distance))) 95 | 96 | if nn_result.shape[0]<=1000: 97 | np.savetxt(os.path.join(output_folder, model_name), nn_result, fmt='%06d %f', delimiter=' ') 98 | else: 99 | np.savetxt(os.path.join(output_folder, model_name), nn_result[0:1000,:], fmt='%06d %f', delimiter=' ') 100 | 
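Note on the retrieval loop in shrec16/test.py: for each query shape it keeps the shapes that share the query's predicted label, ranks them by L2 distance between score vectors, and writes at most 1000 of them. The sketch below mirrors that ranking on toy data in pure Python. `rank_same_label` and its arguments are hypothetical names chosen for illustration, and plain lists stand in for the CUDA tensors used in the script.

```python
import math


def rank_same_label(features, labels, ids, query, max_results=1000):
    """Return (id, distance) pairs for items sharing the query's label,
    sorted by ascending L2 distance to the query's feature vector."""
    candidates = []
    for feat, label, item_id in zip(features, labels, ids):
        if label != labels[query]:
            continue  # only items with the same predicted label are ranked
        dist = math.sqrt(sum((a - b) ** 2
                             for a, b in zip(features[query], feat)))
        candidates.append((item_id, dist))
    candidates.sort(key=lambda pair: pair[1])
    return candidates[:max_results]


if __name__ == '__main__':
    # Toy score vectors, predicted labels and model ids (made up for the demo).
    features = [[0.0, 0.0], [3.0, 4.0], [1.0, 0.0], [0.0, 1.0]]
    labels = [7, 7, 3, 7]
    ids = [100, 101, 102, 103]
    for item_id, dist in rank_same_label(features, labels, ids, query=0):
        print('%06d %f' % (item_id, dist))
```

As in the script, the query itself appears first in its own result list, at distance zero.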
-------------------------------------------------------------------------------- /shrec16/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | import numpy as np 4 | import math 5 | 6 | from shrec16.options import Options 7 | opt = Options().parse() # set CUDA_VISIBLE_DEVICES before import torch 8 | 9 | import torch 10 | import torchvision 11 | from torch.autograd import Variable 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | import random 16 | import numpy as np 17 | 18 | from models.classifier import Model 19 | from data.modelnet_shrec_loader import ModelNet_Shrec_Loader 20 | from util.visualizer import Visualizer 21 | 22 | 23 | if __name__=='__main__': 24 | trainset = ModelNet_Shrec_Loader(opt.dataroot, 'train', opt) 25 | dataset_size = len(trainset) 26 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.nThreads) 27 | print('#training point clouds = %d' % len(trainset)) 28 | 29 | testset = ModelNet_Shrec_Loader(opt.dataroot, 'val', opt) 30 | testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.nThreads) 31 | 32 | # create model, optionally load pre-trained model 33 | model = Model(opt) 34 | if opt.pretrain is not None: 35 | model.encoder.load_state_dict(torch.load(opt.pretrain)) 36 | ############################# automation for ModelNet10 / 40 configuration #################### 37 | # if opt.classes == 10: 38 | # opt.lr = opt.lr * 0.1 39 | # opt.dropout = opt.dropout + 0.1 40 | ############################# automation for ModelNet10 / 40 configuration #################### 41 | 42 | visualizer = Visualizer(opt) 43 | 44 | best_accuracy = 0 45 | for epoch in range(201): 46 | 47 | epoch_iter = 0 48 | for i, data in enumerate(trainloader): 49 | iter_start_time = time.time() 50 | epoch_iter += opt.batch_size 51 | 52 | input_pc, 
input_sn, input_label, input_node, input_node_knn_I = data 53 | model.set_input(input_pc, input_sn, input_label, input_node, input_node_knn_I) 54 | 55 | model.optimize(epoch=epoch) 56 | 57 | if i % 600 == 0: 58 | # print/plot errors 59 | t = (time.time() - iter_start_time) / opt.batch_size 60 | 61 | errors = model.get_current_errors() 62 | 63 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 64 | visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors) 65 | 66 | # print(model.autoencoder.encoder.feature) 67 | # visuals = model.get_current_visuals() 68 | # visualizer.display_current_results(visuals, epoch, i) 69 | 70 | # test network 71 | if epoch >= 0 and epoch%1==0: 72 | batch_amount = 0 73 | model.test_loss.data.zero_() 74 | model.test_accuracy.data.zero_() 75 | for i, data in enumerate(testloader): 76 | input_pc, input_sn, input_label, input_node, input_node_knn_I = data 77 | model.set_input(input_pc, input_sn, input_label, input_node, input_node_knn_I) 78 | model.test_model() 79 | 80 | batch_amount += input_label.size()[0] 81 | 82 | # # accumulate loss 83 | model.test_loss += model.loss.detach() * input_label.size()[0] 84 | 85 | # # accumulate accuracy 86 | _, predicted_idx = torch.max(model.score.data, dim=1, keepdim=False) 87 | correct_mask = torch.eq(predicted_idx, model.input_label).float() 88 | test_accuracy = torch.mean(correct_mask).cpu() 89 | model.test_accuracy += test_accuracy * input_label.size()[0] 90 | 91 | model.test_loss /= batch_amount 92 | model.test_accuracy /= batch_amount 93 | if model.test_accuracy.item() > best_accuracy: 94 | best_accuracy = model.test_accuracy.item() 95 | print('Tested network. 
So far best: %f' % (best_accuracy) ) 96 | 97 | # save network 98 | saving_acc_threshold = 0.0 99 | if model.test_accuracy.item() > saving_acc_threshold: 100 | print("Saving network...") 101 | model.save_network(model.encoder, 'encoder', '%d_%f' % (epoch, model.test_accuracy.item()), opt.gpu_id) 102 | model.save_network(model.classifier, 'classifier', '%d_%f' % (epoch, model.test_accuracy.item()), opt.gpu_id) 103 | 104 | # learning rate decay 105 | if opt.classes == 10: 106 | lr_decay_step = 40 107 | else: 108 | lr_decay_step = 20 109 | if epoch%lr_decay_step==0 and epoch>0: 110 | model.update_learning_rate(0.5) 111 | # batch normalization momentum decay: 112 | next_epoch = epoch + 1 113 | if (opt.bn_momentum_decay_step is not None) and (next_epoch >= 1) and ( 114 | next_epoch % opt.bn_momentum_decay_step == 0): 115 | current_bn_momentum = opt.bn_momentum * ( 116 | opt.bn_momentum_decay ** (next_epoch // opt.bn_momentum_decay_step)) 117 | print('BN momentum updated to: %f' % current_bn_momentum) 118 | 119 | # save network 120 | # if epoch%20==0 and epoch>0: 121 | # print("Saving network...") 122 | # model.save_network(model.classifier, 'cls', '%d' % epoch, opt.gpu_id) 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lijx10/SO-Net/cbff352ce7a0138d9e719447c05d46f675d964c9/util/__init__.py -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | 
os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | # print(self.img_dir) 16 | 17 | self.doc = dominate.document(title=title) 18 | if reflesh > 0: 19 | with self.doc.head: 20 | meta(http_equiv="refresh", content=str(reflesh)) # the HTML attribute value is "refresh"; the parameter keeps its original name for callers 21 | 22 | def get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, str): 26 | with self.doc: 27 | h3(str) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, width=400): 34 | self.add_table() 35 | with self.t: 36 | with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with a(href=os.path.join('images', link)): 41 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 42 | br() 43 | p(txt) 44 | 45 | def save(self): 46 | html_file = '%s/index.html' % self.web_dir 47 | with open(html_file, 'wt') as f: # closes the file even if render() raises 48 | f.write(self.doc.render()) 49 | 50 | 51 | 52 | if __name__ == '__main__': 53 | html = HTML('web/', 'test_html') 54 | html.add_header('hello world') 55 | 56 | ims = [] 57 | txts = [] 58 | links = [] 59 | for n in range(4): 60 | ims.append('image_%d.png' % n) 61 | txts.append('text_%d' % n) 62 | links.append('image_%d.png' % n) 63 | html.add_images(ims, txts, links) 64 | html.save() 65 | -------------------------------------------------------------------------------- /util/potential_field.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numbers 3 | import os 4 | import os.path 5 | import numpy as np 6 | import struct 7 | import math 8 | import time 9 | 10 | 11 | class PotentialField: 12 | def __init__(self, node_num, dim): 13 | self.node_num = node_num 14 | self.dim = dim 15 | np.random.seed(2017) 16 | self.node = np.random.rand(self.node_num, self.dim) * 2 - 1 17 |
np.random.seed() 18 | 19 | self.learning_rate = 0.01 20 | 21 | def node_force(self, src, dst): 22 | # return the force from src to dst 23 | f = dst - src 24 | f_norm = np.linalg.norm(f) + 0.00001 25 | f = f / f_norm / f_norm ** 2 26 | return f 27 | 28 | def wall_force(self, dst): 29 | f = np.zeros(self.dim) 30 | for i in range(self.dim): 31 | x = dst[i] 32 | # no force if far away 33 | if math.fabs(x) < 0.01: 34 | continue 35 | 36 | f_tmp = np.zeros(self.dim) 37 | f_tmp[i] = -1 * x * self.node_num/1.5 38 | f = f + f_tmp 39 | return f 40 | 41 | def get_total_node_force(self): 42 | force = np.zeros((self.node_num, self.dim)) 43 | for j in range(self.node_num): 44 | dst = self.node[j] 45 | for k in range(self.node_num): 46 | force[j] += self.node_force(self.node[k], dst) 47 | return force 48 | 49 | def get_total_wall_force(self): 50 | force = np.zeros((self.node_num, self.dim)) 51 | for j in range(self.node_num): 52 | dst = self.node[j] 53 | force[j] += self.wall_force(dst) 54 | return force 55 | 56 | def optimize(self): 57 | for i in range(100): 58 | learning_rate = self.learning_rate 59 | 60 | # cumulate the force 61 | force = np.zeros((self.node_num, self.dim)) 62 | for j in range(self.node_num): 63 | dst = self.node[j] 64 | force[j] += self.wall_force(dst) 65 | 66 | for k in range(self.node_num): 67 | force[j] += self.node_force(self.node[k], dst) 68 | 69 | # apply the force 70 | self.node += force * learning_rate 71 | 72 | self.reorder() 73 | 74 | def reorder(self): 75 | node_ordered = self.node[self.node[:, 0].argsort()] 76 | 77 | rows = int(math.sqrt(self.node_num)) 78 | cols = rows 79 | node_ordered = node_ordered.reshape((rows, cols, self.dim)) 80 | for i in range(rows): 81 | node_row = node_ordered[i] 82 | node_row = node_row[node_row[:, 1].argsort()] 83 | node_ordered[i] = node_row 84 | node_ordered = node_ordered.reshape((self.node_num, self.dim)) 85 | 86 | self.node = node_ordered 87 | 88 | 
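The nested Python loops in `PotentialField.optimize` above accumulate, for every node, a wall force plus an inverse-square repulsion from every other node. The sketch below expresses the same force with NumPy broadcasting. It is not part of the repository: the function name `total_force` is hypothetical, and it assumes the node array has shape `(node_num, dim)` as in `PotentialField.node`.

```python
import numpy as np


def total_force(node):
    """Wall force plus pairwise repulsion for a (node_num, dim) array.

    Hypothetical helper mirroring PotentialField.wall_force and
    PotentialField.node_force without explicit Python loops.
    """
    node_num = node.shape[0]

    # diff[j, k] = node[j] - node[k], i.e. the direction pushing j away from k
    diff = node[:, None, :] - node[None, :, :]            # (node_num, node_num, dim)
    dist = np.linalg.norm(diff, axis=2) + 0.00001          # same epsilon as node_force
    repulsion = (diff / dist[..., None] ** 3).sum(axis=1)  # f / |f| / |f|^2, summed over k

    # wall force: pull each coordinate toward 0, skipped when |x| < 0.01
    wall = np.where(np.abs(node) < 0.01, 0.0, -node * node_num / 1.5)

    return repulsion + wall
```

One step of the original update would then read `node += total_force(node) * learning_rate`. The pairwise array costs `node_num * node_num * dim` floats of memory, which is small for the grid sizes used here but worth keeping in mind for large grids.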
-------------------------------------------------------------------------------- /util/som.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numbers 3 | import os 4 | import os.path 5 | import numpy as np 6 | import struct 7 | import math 8 | import time 9 | import gc 10 | 11 | import torch 12 | import torchvision 13 | 14 | from . import potential_field 15 | 16 | 17 | class SOM(): 18 | def __init__(self, rows=4, cols=4, dim=3, gpu_id=-1): 19 | ''' 20 | Cannot be used inside a dataloader, because the dataloader keeps only one class instance. Use this offline 21 | and save the SOM result as a numpy array. 22 | :param rows: 23 | :param cols: 24 | :param dim: 25 | :param gpu_id: 26 | ''' 27 | self.rows = rows 28 | self.cols = cols 29 | self.dim = dim 30 | self.node_num = rows * cols 31 | 32 | self.sigma = 0.4 33 | self.learning_rate = 0.5 34 | self.max_iteration = 60 35 | 36 | self.gpu_id = gpu_id 37 | self.device = torch.device("cuda:%d" % gpu_id if (gpu_id >= 0 and torch.cuda.is_available()) else "cpu") # self.device was used below but never defined 38 | # node: Cx(rowsxcols), tensor 39 | self.node = torch.FloatTensor(self.dim, self.rows * self.cols).zero_() 40 | self.node_idx_list = torch.from_numpy(np.arange(self.rows * self.cols).astype(np.float32)) 41 | self.init_weighting_matrix = torch.FloatTensor(self.node_num, self.rows, self.cols) # node_numxrowsxcols 42 | if self.gpu_id >= 0: 43 | self.node = self.node.to(self.device) 44 | self.node_idx_list = self.node_idx_list.to(self.device) 45 | self.init_weighting_matrix = self.init_weighting_matrix.to(self.device) 46 | 47 | self.get_init_weighting_matrix() 48 | 49 | # initialize the node by potential field 50 | pf = potential_field.PotentialField(self.node_num, self.dim) 51 | pf.optimize() 52 | self.node_init_value = torch.from_numpy(pf.node.transpose().astype(np.float32)) 53 | 54 | def node_init(self): 55 | self.node.copy_(self.node_init_value) 56 | 57 | def get_init_weighting_matrix(self): 58 | ''' 59 | Get the initial weighting matrix; later weighting matrices are derived from it.
60 | ''' 61 | for idx in range(self.rows * self.cols): 62 | (i, j) = self.idx2multi(idx) 63 | self.init_weighting_matrix[idx, :] = self.gaussian((i, j), self.sigma) 64 | if self.gpu_id >= 0: 65 | self.init_weighting_matrix = self.init_weighting_matrix.to(self.device) 66 | 67 | def get_weighting_matrix(self, sigma): 68 | scale = 1.0 / ((sigma / self.sigma) ** 2) 69 | weighting_matrix = torch.exp(torch.log(self.init_weighting_matrix) * scale) 70 | return weighting_matrix 71 | 72 | def gaussian(self, c, sigma): 73 | """Returns a Gaussian centered in c""" 74 | d = 2 * np.pi * sigma * sigma 75 | ax = np.exp(-np.power(np.arange(self.rows) - c[0], 2) / d) 76 | ay = np.exp(-np.power(np.arange(self.cols) - c[1], 2) / d) 77 | return torch.from_numpy(np.outer(ax, ay).astype(np.float32)) 78 | 79 | def idx2multi(self, i): 80 | return (i // self.cols, i % self.cols) 81 | 82 | def query(self, x): 83 | ''' 84 | :param x: input data CxN tensor 85 | :return: mask: Nxnode_num 86 | ''' 87 | # expand as CxNxnode_num 88 | node = self.node.unsqueeze(1).expand(x.size(0), x.size(1), self.rows * self.cols) 89 | x_expanded = x.unsqueeze(2).expand_as(node) 90 | 91 | # calcuate difference between x and each node 92 | diff = x_expanded - node # CxNxnode_num 93 | diff_norm = (diff ** 2).sum(dim=0) # Nxnode_num 94 | 95 | # find the nearest neighbor 96 | _, min_idx = torch.min(diff_norm, dim=1) # N 97 | min_idx_expanded = min_idx.unsqueeze(1).expand(min_idx.size()[0], self.rows * self.cols).float() # Nxnode_num 98 | 99 | node_idx_list = self.node_idx_list.unsqueeze(0).expand_as(min_idx_expanded) # Nxnode_num 100 | mask = torch.eq(min_idx_expanded, node_idx_list).float() # Nxnode_num 101 | mask_row_max, _ = torch.max(mask, dim=0) # node_num, this indicates whether the node has nearby x 102 | 103 | return mask, mask_row_max 104 | 105 | def batch_update(self, x, iteration): 106 | # x is CxN tensor, C==self.dim, W=1 107 | assert (x.size()[0] == self.dim) 108 | 109 | # get learning_rate and sigma 110 | 
learning_rate = self.learning_rate / (1 + 2 * iteration / self.max_iteration) 111 | sigma = self.sigma / (1 + 2 * iteration / self.max_iteration) 112 | 113 | # expand as CxNxnode_num 114 | node = self.node.unsqueeze(1).expand(x.size(0), x.size(1), self.rows * self.cols) 115 | x_expanded = x.unsqueeze(2).expand_as(node) 116 | 117 | # calcuate difference between x and each node 118 | diff = x_expanded - node # CxNxnode_num 119 | diff_norm = (diff ** 2).sum(dim=0) # Nxnode_num 120 | 121 | # find the nearest neighbor 122 | _, min_idx = torch.min(diff_norm, dim=1) # N 123 | min_idx_expanded = min_idx.unsqueeze(1).expand(min_idx.size()[0], self.rows * self.cols).float() # Nxnode_num 124 | 125 | node_idx_list = self.node_idx_list.unsqueeze(0).expand_as(min_idx_expanded) # Nxnode_num 126 | mask = torch.eq(min_idx_expanded, node_idx_list).float() # Nxnode_num 127 | mask_row_sum = torch.sum(mask, dim=0) + 0.00001 # node_num 128 | mask_row_max, _ = torch.max(mask, dim=0) # node_num, this indicates whether the node has nearby x 129 | 130 | # calculate the mean x for each node 131 | x_expanded_masked = x_expanded * mask.unsqueeze(0).expand_as(x_expanded) # CxNxnode_num 132 | x_expanded_masked_sum = torch.sum(x_expanded_masked, dim=1) # Cxnode_num 133 | x_expanded_mask_mean = x_expanded_masked_sum / mask_row_sum.unsqueeze(0).expand_as( 134 | x_expanded_masked_sum) # Cxnode_num 135 | 136 | # each x_expanded_mask_mean (in total node_num vectors) will calculate its diff with all nodes 137 | # multiply the mask_row_max, so that the isolated node won't be pulled to the center 138 | x_expanded_mask_mean_expanded = x_expanded_mask_mean.unsqueeze(2).expand(self.dim, self.rows * self.cols, 139 | self.rows * self.cols) # Cxnode_numxnode_num 140 | node_expanded_transposed = self.node.unsqueeze(1).expand_as(x_expanded_mask_mean_expanded) # .transpose(1,2) 141 | diff_masked_mean = x_expanded_mask_mean_expanded - node_expanded_transposed # Cxnode_numxnode_num 142 | diff_masked_mean = 
diff_masked_mean * mask_row_max.unsqueeze(1).unsqueeze(0).expand_as(diff_masked_mean) 143 | 144 | # compute the neighrbor weighting 145 | # weighting_matrix = torch.FloatTensor(self.rows*self.cols, self.rows, self.cols) # node_numxrowsxcols 146 | # for idx in range(self.rows*self.cols): 147 | # (i,j) = self.idx2multi(idx) 148 | # weighting_matrix[idx,:] = self.gaussian((i,j), sigma) 149 | # if self.gpu_id >= 0: 150 | # weighting_matrix = weighting_matrix.to(self.device) 151 | # compute the neighrbor weighting using pre-computed matrix 152 | weighting_matrix = self.get_weighting_matrix(sigma) # node_numxrowsxcols 153 | 154 | # compute the update 155 | weighting_matrix = weighting_matrix.unsqueeze(0).expand(self.dim, self.node_num, self.rows, 156 | self.cols) # Cxnode_numxrowsxcols 157 | diff_masked_mean_matrix_view = diff_masked_mean.view(self.dim, self.node_num, self.rows, self.cols) 158 | delta = diff_masked_mean_matrix_view * weighting_matrix * learning_rate # Cxnode_numxrowsxcols 159 | delta = delta.sum(dim=1) 160 | 161 | # apply the update 162 | node_matrix_view = self.node.view(self.dim, self.rows, self.cols) # Cxrowsxcols 163 | node_matrix_view += delta 164 | 165 | # print(self.node) 166 | 167 | def optimize(self, x): 168 | self.node_init() 169 | for iter in range(int(self.max_iteration / 3)): 170 | self.batch_update(x, 0) 171 | for iter in range(self.max_iteration): 172 | self.batch_update(x, iter) 173 | 174 | 175 | class BatchSOM(): 176 | def __init__(self, rows=4, cols=4, dim=3, gpu_id=None, batch_size=10): 177 | self.rows = rows 178 | self.cols = cols 179 | self.dim = dim 180 | self.node_num = rows * cols 181 | 182 | self.sigma = 0.4 183 | self.learning_rate = 0.5 184 | self.max_iteration = 60 185 | 186 | self.gpu_id = gpu_id 187 | assert gpu_id >= 0 188 | self.device = torch.device("cuda:%d"%(gpu_id) if torch.cuda.is_available() else "cpu") 189 | self.batch_size = batch_size 190 | 191 | # node: BxCx(rowsxcols), tensor 192 | self.node = 
torch.FloatTensor(self.batch_size, self.dim, self.rows * self.cols).zero_() 193 | self.node_idx_list = torch.from_numpy(np.arange(self.node_num).astype(np.int64)) # node_num LongTensor 194 | self.init_weighting_matrix = torch.FloatTensor(self.node_num, self.rows, self.cols) # node_numxrowsxcols 195 | if self.gpu_id >= 0: 196 | self.node = self.node.to(self.device) 197 | self.node_idx_list = self.node_idx_list.to(self.device) 198 | 199 | # get initial weighting matrix 200 | self.get_init_weighting_matrix() 201 | 202 | # initialize the node by potential field 203 | pf = potential_field.PotentialField(self.node_num, self.dim) 204 | pf.optimize() 205 | self.node_init_value = torch.from_numpy(pf.node.transpose().astype(np.float32)) 206 | 207 | def node_init(self, batch_size): 208 | self.batch_size = batch_size 209 | self.node.resize_(self.batch_size, self.dim, self.node_num) 210 | self.node.copy_(torch.unsqueeze(self.node_init_value, dim=0).expand_as(self.node)) 211 | 212 | def gaussian(self, c, sigma): 213 | """Returns a Gaussian centered in c""" 214 | d = 2 * np.pi * sigma * sigma 215 | ax = np.exp(-np.power(np.arange(self.rows) - c[0], 2) / d) 216 | ay = np.exp(-np.power(np.arange(self.cols) - c[1], 2) / d) 217 | return torch.from_numpy(np.outer(ax, ay).astype(np.float32)) 218 | 219 | def get_init_weighting_matrix(self): 220 | ''' 221 | get the initial weighting matrix, later the weighting matrix wil base on the init. 
222 | ''' 223 | for idx in range(self.rows * self.cols): 224 | (i, j) = self.idx2multi(idx) 225 | self.init_weighting_matrix[idx, :] = self.gaussian((i, j), self.sigma) 226 | if self.gpu_id >= 0: 227 | self.init_weighting_matrix = self.init_weighting_matrix.to(self.device) 228 | 229 | def get_weighting_matrix(self, sigma): 230 | scale = 1.0 / ((sigma / self.sigma) ** 2) 231 | weighting_matrix = torch.exp(torch.log(self.init_weighting_matrix) * scale) 232 | return weighting_matrix 233 | 234 | def idx2multi(self, i): 235 | return (i // self.cols, i % self.cols) 236 | 237 | def query_topk(self, x, k): 238 | ''' 239 | :param x: input data BxCxN tensor 240 | :param k: topk 241 | :return: mask: Nxnode_num 242 | ''' 243 | 244 | # expand as BxCxNxnode_num 245 | node = self.node.unsqueeze(2).expand(x.size(0), x.size(1), x.size(2), self.rows * self.cols) 246 | x_expanded = x.unsqueeze(3).expand_as(node) 247 | 248 | # calcuate difference between x and each node 249 | diff = x_expanded - node # BxCxNxnode_num 250 | diff_norm = (diff ** 2).sum(dim=1) # BxNxnode_num 251 | 252 | # find the nearest neighbor 253 | _, min_idx = torch.topk(diff_norm, k=k, dim=2, largest=False, sorted=False) # BxNxk 254 | min_idx_expanded = min_idx.unsqueeze(2).expand(min_idx.size()[0], min_idx.size()[1], self.rows * self.cols, 255 | k) # BxNxnode_numxk 256 | 257 | node_idx_list = self.node_idx_list.unsqueeze(0).unsqueeze(0).unsqueeze(3).expand_as(min_idx_expanded).long() # BxNxnode_numxk 258 | mask = torch.eq(min_idx_expanded, node_idx_list).int() # BxNxnode_numxk 259 | # mask = torch.sum(mask, dim=3) # BxNxnode_num 260 | 261 | mask_list, min_idx_list = [], [] 262 | for i in range(k): 263 | mask_list.append(mask[..., i]) 264 | min_idx_list.append(min_idx[..., i]) 265 | mask = torch.cat(tuple(mask_list), dim=1) # BxkNxnode_num 266 | min_idx = torch.cat(tuple(min_idx_list), dim=1) # BxkN 267 | mask_row_max, _ = torch.max(mask, dim=1) # Bxnode_num, this indicates whether the node has nearby x 268 | 269 
| return mask, mask_row_max, min_idx 270 | 271 | def query(self, x): 272 | ''' 273 | :param x: input data BxCxN tensor 274 | :return: mask: BxNxnode_num, mask_row_max: Bxnode_num 275 | ''' 276 | # expand as BxCxNxnode_num 277 | node = self.node.unsqueeze(2).expand(x.size(0), x.size(1), x.size(2), self.rows * self.cols) 278 | x_expanded = x.unsqueeze(3).expand_as(node) 279 | 280 | # calculate the squared distance between x and each node 281 | diff = x_expanded - node # BxCxNxnode_num 282 | diff_norm = (diff ** 2).sum(dim=1) # BxNxnode_num 283 | 284 | # find the nearest neighbor 285 | _, min_idx = torch.min(diff_norm, dim=2) # BxN 286 | min_idx_expanded = min_idx.unsqueeze(2).expand(min_idx.size()[0], min_idx.size()[1], 287 | self.rows * self.cols) # BxNxnode_num 288 | 289 | node_idx_list = self.node_idx_list.unsqueeze(0).unsqueeze(0).expand_as(min_idx_expanded).long() # BxNxnode_num 290 | mask = torch.eq(min_idx_expanded, node_idx_list).float() # BxNxnode_num 291 | mask_row_max, _ = torch.max(mask, dim=1) # Bxnode_num, 1 if at least one point is assigned to the node 292 | 293 | return mask, mask_row_max 294 | 295 | def batch_update(self, x, learning_rate, sigma): 296 | # x is BxCxN tensor, C==self.dim, W=1 297 | assert (x.size()[1] == self.dim) 298 | assert (x.size()[0] == self.batch_size) 299 | 300 | # expand as BxCxNxnode_num 301 | node = self.node.unsqueeze(2).expand(x.size(0), x.size(1), x.size(2), self.rows * self.cols) 302 | x_expanded = x.unsqueeze(3).expand_as(node) 303 | 304 | # calculate the squared distance between x and each node 305 | diff = x_expanded - node # BxCxNxnode_num 306 | diff_norm = (diff ** 2).sum(dim=1) # BxNxnode_num 307 | 308 | # find the nearest neighbor 309 | _, min_idx = torch.min(diff_norm, dim=2) # BxN 310 | min_idx_expanded = min_idx.unsqueeze(2).expand(min_idx.size()[0], min_idx.size()[1], 311 | self.rows * self.cols) # BxNxnode_num 312 | 313 | node_idx_list = self.node_idx_list.unsqueeze(0).unsqueeze(0).expand_as(min_idx_expanded).long() # BxNxnode_num 314 | mask =
torch.eq(min_idx_expanded, node_idx_list).float() # BxNxnode_num 315 | mask_row_sum = torch.sum(mask, dim=1) + 0.00001 # Bxnode_num 316 | mask_row_max, _ = torch.max(mask, dim=1) # Bxnode_num, this indicates whether the node has nearby x 317 | 318 | # calculate the mean x for each node 319 | x_expanded_masked = x_expanded * mask.unsqueeze(1).expand_as(x_expanded) # BxCxNxnode_num 320 | x_expanded_masked_sum = torch.sum(x_expanded_masked, dim=2) # BxCxnode_num 321 | x_expanded_mask_mean = x_expanded_masked_sum / mask_row_sum.unsqueeze(1).expand_as( 322 | x_expanded_masked_sum) # Cxnode_num 323 | 324 | # each x_expanded_mask_mean (in total node_num vectors) will calculate its diff with all nodes 325 | # multiply the mask_row_max, so that the isolated node won't be pulled to the center 326 | x_expanded_mask_mean_expanded = x_expanded_mask_mean.unsqueeze(3).expand(self.batch_size, self.dim, 327 | self.rows * self.cols, 328 | self.rows * self.cols) # BxCxnode_numxnode_num 329 | node_expanded_transposed = self.node.unsqueeze(2).expand_as(x_expanded_mask_mean_expanded) # .transpose(1,2) 330 | diff_masked_mean = x_expanded_mask_mean_expanded - node_expanded_transposed # BxCxnode_numxnode_num 331 | diff_masked_mean = diff_masked_mean * mask_row_max.unsqueeze(2).unsqueeze(1).expand_as(diff_masked_mean) 332 | 333 | # compute the neighrbor weighting using pre-computed matrix 334 | weighting_matrix = self.get_weighting_matrix(sigma) 335 | 336 | # expand weighting_matrix to be batch, Bxnode_numxrowsxcols 337 | weighting_matrix = weighting_matrix.unsqueeze(0).expand(self.batch_size, self.rows * self.cols, self.rows, 338 | self.cols) 339 | 340 | # compute the update 341 | weighting_matrix = weighting_matrix.unsqueeze(1).expand(self.batch_size, self.dim, self.rows * self.cols, 342 | self.rows, self.cols) # BxCxnode_numxrowsxcols 343 | diff_masked_mean_matrix_view = diff_masked_mean.view(self.batch_size, self.dim, self.rows * self.cols, 344 | self.rows, self.cols) 345 | delta = 
diff_masked_mean_matrix_view * weighting_matrix * learning_rate # BxCxnode_numxrowsxcols 346 | delta = delta.sum(dim=2) 347 | 348 | # apply the update 349 | node_matrix_view = self.node.view(self.batch_size, self.dim, self.rows, self.cols) # BxCxrowsxcols 350 | node_matrix_view += delta 351 | 352 | # print(self.node) 353 | # print(delta.max()) 354 | 355 | def optimize(self, x): 356 | self.node_init(x.size()[0]) 357 | for iter in range(int(self.max_iteration / 3)): 358 | # get learning_rate and sigma 359 | learning_rate = self.learning_rate 360 | sigma = self.sigma 361 | self.batch_update(x, learning_rate, sigma) 362 | for iter in range(self.max_iteration): 363 | # get learning_rate and sigma 364 | learning_rate = self.learning_rate / (1 + 2*iter / self.max_iteration) 365 | sigma = self.sigma / (1 + 2*iter / self.max_iteration) 366 | self.batch_update(x, learning_rate, sigma) 367 | 368 | 369 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torchvision 4 | import numpy as np 5 | from PIL import Image 6 | import inspect, re 7 | import numpy as np 8 | import os 9 | import collections 10 | 11 | # Converts a Tensor into a Numpy array 12 | # |imtype|: the desired type of the converted numpy array 13 | # Li Jiaxin, because of the tanh() of the generator, here this function assumes that the output is in [-1,1] 14 | # but actually for depth estimation, this function is used only for input rgb, hence is the inverse of ([0-1]-0.5)/0.5 15 | # in this function, CxHxW -> HxWxC 16 | def tensor2im(image_tensor, imtype=np.uint8): 17 | image_numpy = image_tensor[0].cpu().float().numpy() 18 | image_numpy = (np.transpose(image_numpy, (1,2,0)) * 0.5 + 0.5) * 255.0 19 | return image_numpy.astype(imtype) 20 | 21 | # Li Jiaxin, for test images 22 | def tensor2grid_im(image_tensor): 23 | grid = 
torchvision.utils.make_grid(image_tensor, nrow=5, normalize=False) 24 | grid = (grid.cpu().float().numpy().transpose((1,2,0)) * 0.5 + 0.5) * 255.0 25 | return grid.astype(np.uint8) 26 | 27 | # Li Jiaxin, define a function to convert log depth to uint8 img 28 | # the log depth is single channel tensor ranges around [-10, 10] 29 | def log_depth2im(image_tensor): 30 | image_numpy = image_tensor[0].cpu().float().numpy() 31 | minimum = np.amin(image_numpy) 32 | maximum = np.amax(image_numpy) 33 | image_numpy = (np.transpose(image_numpy, (1,2,0)) - minimum) / (maximum-minimum) * 255 34 | return image_numpy.astype(np.uint8).repeat(3,2) 35 | 36 | def log_depth2grid_im(image_tensor): 37 | # the clamp is according to the data processing 38 | image_tensor = image_tensor.clamp(-0.84, 2.11) 39 | grid = torchvision.utils.make_grid(image_tensor, nrow=5, normalize=True) 40 | grid = grid.cpu().float().numpy().transpose((1, 2, 0)) * 255 41 | return grid.astype(np.uint8) 42 | 43 | def diagnose_network(net, name='network'): 44 | mean = 0.0 45 | count = 0 46 | for param in net.parameters(): 47 | if param.grad is not None: 48 | mean += torch.mean(torch.abs(param.grad.data)) 49 | count += 1 50 | if count > 0: 51 | mean = mean / count 52 | print(name) 53 | print(mean) 54 | 55 | 56 | def save_image(image_numpy, image_path): 57 | image_pil = Image.fromarray(image_numpy) 58 | image_pil.save(image_path) 59 | 60 | def info(object, spacing=10, collapse=1): 61 | """Print methods and doc strings. 
62 | Takes module, class, list, dictionary, or string.""" 63 | methodList = [e for e in dir(object) if callable(getattr(object, e))] # collections.Callable was removed in Python 3.10 64 | processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s) 65 | print( "\n".join(["%s %s" % 66 | (method.ljust(spacing), 67 | processFunc(str(getattr(object, method).__doc__))) 68 | for method in methodList]) ) 69 | 70 | def varname(p): 71 | for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]: 72 | m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line) 73 | if m: 74 | return m.group(1) 75 | 76 | def print_numpy(x, val=True, shp=False): 77 | x = x.astype(np.float64) 78 | if shp: 79 | print('shape,', x.shape) 80 | if val: 81 | x = x.flatten() 82 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 83 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 84 | 85 | 86 | def mkdirs(paths): 87 | if isinstance(paths, list) and not isinstance(paths, str): 88 | for path in paths: 89 | mkdir(path) 90 | else: 91 | mkdir(paths) 92 | 93 | 94 | def mkdir(path): 95 | if not os.path.exists(path): 96 | os.makedirs(path) 97 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | from . import util 6 | from . import html 7 | 8 | class Visualizer(): 9 | def __init__(self, opt): 10 | # self.opt = opt 11 | self.display_id = opt.display_id 12 | self.use_html = 0 13 | self.win_size = opt.display_winsize 14 | self.name = opt.name 15 | if self.display_id > 0: 16 | import visdom 17 | self.vis = visdom.Visdom() 18 | 19 | if self.use_html: 20 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 21 | self.img_dir = os.path.join(self.web_dir, 'images') 22 | print('create web directory %s...'
% self.web_dir) 23 | util.mkdirs([self.web_dir, self.img_dir]) 24 | 25 | 26 | # |visuals|: dictionary of images to display or save 27 | def display_current_results(self, visuals, epoch, iter=0): 28 | if self.display_id > 0: # show images in the browser 29 | idx = 1 30 | for label, item in visuals.items(): 31 | if 'pc' in label: 32 | self.vis.scatter(np.transpose(item), 33 | Y=None, 34 | opts=dict(title=label, markersize=0.5), 35 | win=self.display_id + idx) 36 | elif 'img' in label: 37 | # the transpose: HxWxC -> CxHxW 38 | self.vis.image(np.transpose(item, (2,0,1)), opts=dict(title=label), 39 | win=self.display_id + idx) 40 | idx += 1 41 | 42 | if self.use_html: # save images to a html file 43 | for label, image_numpy in visuals.items(): 44 | img_path = os.path.join(self.img_dir, 'epoch%.3d-%d_%s.png' % (epoch, iter, label)) 45 | util.save_image(image_numpy, img_path) 46 | # update website 47 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) 48 | for n in range(epoch, 0, -1): 49 | webpage.add_header('epoch [%d]' % n) 50 | ims = [] 51 | txts = [] 52 | links = [] 53 | 54 | for label, image_numpy in visuals.items(): 55 | img_path = 'epoch%.3d-%d_%s.png' % (n, iter, label) 56 | ims.append(img_path) 57 | txts.append(label) 58 | links.append(img_path) 59 | webpage.add_images(ims, txts, links, width=self.win_size) 60 | webpage.save() 61 | 62 | # errors: dictionary of error labels and values 63 | def plot_current_errors(self, epoch, counter_ratio, opt, errors): 64 | if not hasattr(self, 'plot_data'): 65 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} 66 | self.plot_data['X'].append(epoch + counter_ratio) 67 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) 68 | self.vis.line( 69 | X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), 70 | Y=np.array(self.plot_data['Y']), 71 | opts={ 72 | 'title': self.name + ' loss over time', 73 | 'legend': self.plot_data['legend'], 74 | 
'xlabel': 'epoch', 75 | 'ylabel': 'loss'}, 76 | win=self.display_id) 77 | 78 | # errors: same format as |errors| of plotCurrentErrors 79 | def print_current_errors(self, epoch, i, errors, t): 80 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 81 | for k, v in errors.items(): 82 | message += '%s: %.3f ' % (k, v) 83 | 84 | print(message) 85 | 86 | # save image to the disk 87 | def save_images(self, webpage, visuals, image_path): 88 | image_dir = webpage.get_image_dir() 89 | short_path = ntpath.basename(image_path[0]) 90 | name = os.path.splitext(short_path)[0] 91 | 92 | webpage.add_header(name) 93 | ims = [] 94 | txts = [] 95 | links = [] 96 | 97 | for label, image_numpy in visuals.items(): 98 | image_name = '%s_%s.png' % (name, label) 99 | save_path = os.path.join(image_dir, image_name) 100 | util.save_image(image_numpy, save_path) 101 | 102 | ims.append(image_name) 103 | txts.append(label) 104 | links.append(image_name) 105 | webpage.add_images(ims, txts, links, width=self.win_size) 106 | --------------------------------------------------------------------------------
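`Visualizer.plot_current_errors` above keeps one growing list of x positions and one list of per-series error rows, then reshapes them so that `X` and `Y` have matching `(n_samples, n_series)` shapes before they are passed to `vis.line`. The sketch below isolates that bookkeeping so the shapes can be checked without a running visdom server. The function name `accumulate` is hypothetical and is not part of the repository.

```python
import numpy as np


def accumulate(plot_data, epoch, counter_ratio, errors):
    """Append one sample and return (X, Y) laid out as in plot_current_errors.

    plot_data is a dict with keys 'X', 'Y' and 'legend', as built in the
    original method; errors maps each legend entry to a scalar.
    """
    plot_data['X'].append(epoch + counter_ratio)
    plot_data['Y'].append([errors[k] for k in plot_data['legend']])

    n_series = len(plot_data['legend'])
    # repeat the x positions once per series so X matches Y column for column
    X = np.stack([np.array(plot_data['X'])] * n_series, 1)  # (n_samples, n_series)
    Y = np.array(plot_data['Y'])                            # (n_samples, n_series)
    return X, Y
```

Note that the original method calls `self.vis.line` unconditionally, while `self.vis` is only created when `display_id > 0`, so a caller that disables the display would likely need to skip plotting as well.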