├── .gitmodules ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── configs ├── point_vox_lidar_template.yaml ├── point_vox_template.yaml ├── point_within_format.yaml └── point_within_lidar_template.yaml ├── criterions ├── __init__.py └── nce_loss_moco.py ├── data ├── README.md ├── redwood │ ├── README.md │ └── extract_pointcloud.py ├── scannet │ ├── README.md │ ├── SensorData.py │ ├── extract_pointcloud.py │ └── reader.py └── waymo │ ├── README.md │ └── extract_pointcloud.py ├── datasets ├── __init__.py ├── collators │ ├── __init__.py │ ├── point_moco_collator.py │ ├── point_vox_moco_collator.py │ ├── point_vox_moco_lidar_collator.py │ └── vox_moco_collator.py ├── depth_dataset.py └── transforms │ ├── augment3d.py │ ├── transforms.py │ └── voxelizer.py ├── imgs └── method.jpg ├── main.py ├── models ├── __init__.py ├── base_ssl3d_model.py └── trunks │ ├── __init__.py │ ├── mlp.py │ ├── pointnet.py │ ├── pointnet2_backbone.py │ ├── smlp.py │ ├── spconv │ ├── lib │ │ └── math_functions.py │ └── models │ │ ├── __init__.py │ │ ├── conditional_random_fields.py │ │ ├── model.py │ │ ├── modules │ │ ├── __init__.py │ │ ├── common.py │ │ ├── resnet_block.py │ │ └── senet_block.py │ │ ├── res16unet.py │ │ ├── resnet.py │ │ ├── resunet.py │ │ └── wrapper.py │ ├── spconv_backbone.py │ └── spconv_unet.py ├── requirements.txt ├── scripts ├── multinode-wrapper.py ├── pretrain_node1.sh ├── pretrain_node4.sh └── singlenode-wrapper.py ├── third_party └── pointnet2 │ ├── _ext_src │ ├── include │ │ ├── ball_query.h │ │ ├── cuda_utils.h │ │ ├── group_points.h │ │ ├── interpolate.h │ │ ├── sampling.h │ │ └── utils.h │ └── src │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── bindings.cpp │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ ├── sampling.cpp │ │ └── sampling_gpu.cu │ ├── pointnet2_modules.py │ ├── pointnet2_test.py │ ├── pointnet2_utils.py │ ├── pytorch_utils.py │ └── setup.py └── utils ├── __init__.py ├── logger.py ├── main_utils.py └── metrics_utils.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/OpenPCDet"] 2 | path = third_party/OpenPCDet 3 | url = https://github.com/zaiweizhang/OpenPCDet 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | In the context of this project, we do not expect pull requests. 4 | If you find a bug, or would like to suggest an improvement, please open an issue. 
5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Supervised Pretraining of 3D Features on any Point-Cloud 2 | This code provides a PyTorch implementation and pretrained models for **DepthContrast**, as described in the paper [Self-Supervised Pretraining of 3D Features on any Point-Cloud](http://arxiv.org/abs/2101.02691). 3 | 4 |
5 | ![DepthContrast Pipeline](./imgs/method.jpg) 6 |
7 | 8 | DepthContrast is an easy-to-implement self-supervised method that works across model architectures, input data formats, indoor/outdoor 3D, and single/multi-view 3D data. 9 | Similarly to 2D contrastive approaches, DepthContrast learns representations by comparing transformations of a 3D pointcloud/voxel. It does not require any multi-view information between frames, such as point-to-point correspondences. This makes our framework generalize to any 3D pointcloud or voxel input. 10 | DepthContrast pretrains high-capacity models for 3D recognition tasks and leverages large-scale 3D data. It shows state-of-the-art performance on detection and segmentation benchmarks, outperforming all prior work on detection. 11 | 12 | # Model Zoo 13 | 14 | We release our PointNet++ and MinkowskiEngine UNet models pretrained with DepthContrast with the hope that other researchers might also benefit from these pretrained backbones. Due to licensing issues, models pretrained on Waymo cannot be released. For the PointnetMSG and Spconv-UNet models, we encourage researchers to train these models themselves using the provided scripts. 15 | 16 | We first provide PointNet++ models with different sizes. 17 | | network | epochs | batch-size | ScanNet Det with VoteNet | url | args | 18 | |-------------------|---------------------|---------------------|--------------------|--------------------|--------------------| 19 | | PointNet++-1x | 150 | 1024 | 61.9 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet1x/checkpoint-ep150.pth.tar) | [config](./configs/point_within_format.yaml) | 20 | | PointNet++-2x | 200 | 1024 | 63.3 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet2x/checkpoint-ep200.pth.tar) | [config](./configs/point_within_format.yaml) | 21 | | PointNet++-3x | 150 | 1024 | 64.1 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet3x/checkpoint-ep150.pth.tar) | [config](./configs/point_within_format.yaml) | 22 | | PointNet++-4x | 100 | 1024 | 63.8 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet4x/checkpoint-ep100.pth.tar) | [config](./configs/point_within_format.yaml) | 23 | 24 | The ScanNet detection evaluation metric is mAP at IoU=0.25. You need to set the scale parameter in the config file to match the model size. 25 | 26 | We provide the jointly trained models here, at different training epochs. We use the epoch-400 model to generate the results reported in the paper.
27 | 28 | | Backbone | epochs | batch-size | url | args | 29 | |-------------------|-------------------|---------------------|--------------------|--------------------| 30 | | PointNet++ & MinkowskiEngine UNet | 300 | 1024 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet_unet_joint/checkpoint-ep300.pth.tar) | [config](./configs/point_vox_template.yaml) | 31 | | PointNet++ & MinkowskiEngine UNet | 400 | 1024 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet_unet_joint/checkpoint-ep400.pth.tar) | [config](./configs/point_vox_template.yaml) | 32 | | PointNet++ & MinkowskiEngine UNet | 500 | 1024 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet_unet_joint/checkpoint-ep500.pth.tar) | [config](./configs/point_vox_template.yaml) | 33 | | PointNet++ & MinkowskiEngine UNet | 600 | 1024 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet_unet_joint/checkpoint-ep600.pth.tar) | [config](./configs/point_vox_template.yaml) | 34 | | PointNet++ & MinkowskiEngine UNet | 700 | 1024 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet_unet_joint/checkpoint-ep700.pth.tar) | [config](./configs/point_vox_template.yaml) | 35 | 36 | # Running DepthContrast unsupervised training 37 | 38 | ## Requirements 39 | You can use requirements.txt to set up the environment. 40 | First, clone the repository and install the PointNet++ modules: 41 | ``` 42 | git clone --recursive https://github.com/facebookresearch/DepthContrast.git 43 | cd DepthContrast/third_party/pointnet2 44 | python setup.py install 45 | ``` 46 | Then install all other packages: 47 | ``` 48 | pip install -r requirements.txt 49 | ``` 50 | or 51 | ``` 52 | conda install --file requirements.txt 53 | ``` 54 | 55 | For the voxel representation, you have to install MinkowskiEngine. Please see [here](https://github.com/chrischoy/SpatioTemporalSegmentation) for instructions on how to install it. 56 | 57 | For lidar point cloud pretraining, we use models from [OpenPCDet](https://github.com/open-mmlab/OpenPCDet), included as a git submodule in the third_party folder. To install OpenPCDet, you need to install [spconv](https://github.com/traveller59/spconv), which is somewhat difficult to install and may not be compatible with MinkowskiEngine. We therefore suggest using a separate conda environment for lidar point cloud pretraining. 58 | 59 | ## Singlenode training 60 | DepthContrast is very simple to implement and experiment with. 61 | 62 | To experiment with it on one GPU, or for debugging, you can run: 63 | ``` 64 | python main.py /path/to/cfg/file 65 | ``` 66 | 67 | For actual training, please use the distributed trainer. 68 | For multi-GPU training on one node, you can run: 69 | ``` 70 | python main.py /path/to/cfg_file --multiprocessing-distributed --world-size 1 --rank 0 --ngpus number_of_gpus 71 | ``` 72 | To run it with just one GPU, set --ngpus to 1. 73 | To submit it to a Slurm node, you can use ./scripts/pretrain_node1.sh. For hyper-parameter tuning, please change the config files. 74 | 75 | ## Multinode training 76 | Distributed training is available via Slurm. We provide several [SBATCH scripts](./scripts) to reproduce our results. 77 | For example, to train DepthContrast on 4 nodes and 32 GPUs with a batch size of 1024, run: 78 | ``` 79 | sbatch ./scripts/pretrain_node4.sh /path/to/cfg_file 80 | ``` 81 | Note that you might need to remove the copyright header from the sbatch file to launch it.
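Whether you pretrain yourself or download a Model Zoo checkpoint, it can be useful to inspect the saved weights with plain PyTorch before finetuning. The snippet below is a minimal sketch, not a script shipped with this repository; the `model`/`state_dict` wrapper keys and the `module.` prefix are assumptions about the checkpoint layout, so print the keys of your file and adjust the filtering accordingly.
```
import torch

ckpt = torch.load("checkpoint-ep400.pth.tar", map_location="cpu")
# Assumed wrapper keys; fall back to the raw dict if neither is present.
state = ckpt.get("model", ckpt.get("state_dict", ckpt))

backbone = {}
for name, tensor in state.items():
    # Strip a possible DistributedDataParallel prefix before reusing the weights.
    backbone[name.replace("module.", "")] = tensor

print("Loaded", len(backbone), "tensors")
torch.save(backbone, "depthcontrast_backbone.pth")
```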
82 | 83 | # Evaluating models 84 | For votenet finetuning, please checkout this [repo](https://github.com/zaiweizhang/votenet) for more details. 85 | 86 | For H3DNet finetuning, please checkout this [repo](https://github.com/zaiweizhang/H3DNet) for more details. 87 | 88 | For voxel scene segmentation task finetuning, please checkout this [repo](https://github.com/zaiweizhang/SpatioTemporalSegmentation) for more details. 89 | 90 | For lidar point cloud object detection task finetuning, please checkout this [repo](https://github.com/zaiweizhang/OpenPCDet) for more details. 91 | 92 | # Common Issues 93 | For help or issues using DepthContrast, please submit a GitHub issue. 94 | 95 | ## License 96 | See the [LICENSE](LICENSE) file for more details. 97 | 98 | ## Citation 99 | If you find this repository useful in your research, please cite: 100 | 101 | ``` 102 | @inproceedings{zhang_depth_contrast, 103 | title={Self-Supervised Pretraining of 3D Features on any Point-Cloud}, 104 | author={Zhang, Zaiwei and Girdhar, Rohit and Joulin, Armand and Misra, Ishan}, 105 | journal={arXiv preprint arXiv:2101.02691}, 106 | year={2021} 107 | } 108 | ``` 109 | -------------------------------------------------------------------------------- /configs/point_vox_lidar_template.yaml: -------------------------------------------------------------------------------- 1 | resume: false 2 | test_only: false 3 | num_workers: 10 4 | 5 | required_devices: 1 6 | no_test: false 7 | debug: false 8 | log2tb: true 9 | allow_double_bs: false 10 | seed: 0 11 | distributed: false 12 | test_freq: 10 13 | print_freq: 5 14 | 15 | dataset: 16 | DATASET_NAMES: [waymo] 17 | DATA_PATHS: ['/path/to/waymo.npy'] 18 | BATCHSIZE_PER_REPLICA: 32 19 | LABEL_TYPE: sample_index 20 | DATA_TYPE: point_vox 21 | Lidar: True 22 | VOX: True 23 | POINT_TRANSFORMS: 24 | - name: randomcuboidLidar 25 | crop: 0.5 26 | npoints: 10000 27 | randcrop: 1.0 28 | aspect: 0.75 29 | - name: randomdrop 30 | crop: 0.2 31 | - name: RandomFlipLidar 32 | - name: RandomRotateLidar 33 | - name: RandomScaleLidar 34 | - name: ToTensorLidar 35 | COLLATE_FUNCTION: "point_vox_moco_collator" 36 | INPUT_KEY_NAMES: ["points", "points_moco", "vox", "vox_moco"] 37 | DROP_LAST: True 38 | 39 | optimizer: 40 | name: "sgd" 41 | weight_decay: 0.0001 42 | momentum: 0.9 43 | nesterov: False 44 | num_epochs: 1000 45 | lr: 46 | name: "cosine" 47 | base_lr: 0.12 48 | final_lr: 0.00012 49 | 50 | model: 51 | name: "PointnetMSG_UnetV2" 52 | model_dir: "checkpoints/pointnetMSG_UnetV2_within_format" 53 | model_input: ["points", "points_moco", "vox", "vox_moco"] 54 | model_feature: [["fp2"], ["fp2"], ["conv4"], ["conv4"]] 55 | Lidar: True 56 | arch_point: "pointnet_msg" 57 | args_point: 58 | use_mlp: True 59 | mlp_dim: [128, 128, 128] 60 | arch_vox: "UNetV2" 61 | args_vox: 62 | use_mlp: True 63 | mlp_dim: [128, 128, 128] 64 | 65 | loss: 66 | name: "NCELossMoco" 67 | args: 68 | two_domain: True 69 | LOSS_TYPE: NPID 70 | OTHER_INPUT: True 71 | within_format_weight0: 1.0 72 | within_format_weight1: 1.0 73 | across_format_weight0: 0.0 74 | across_format_weight1: 0.0 75 | NCE_LOSS: 76 | NORM_EMBEDDING: True 77 | TEMPERATURE: 0.1 78 | LOSS_TYPE: cross_entropy 79 | NUM_NEGATIVES: 65536 80 | EMBEDDING_DIM: 128 81 | -------------------------------------------------------------------------------- /configs/point_vox_template.yaml: -------------------------------------------------------------------------------- 1 | resume: false 2 | test_only: false 3 | num_workers: 10 4 | 5 | required_devices: 1 6 | no_test: 
false 7 | debug: false 8 | log2tb: true 9 | allow_double_bs: false 10 | seed: 0 11 | distributed: false 12 | test_freq: 2 13 | print_freq: 5 14 | 15 | dataset: 16 | DATASET_NAMES: [scannet] 17 | DATA_PATHS: ['/path/to/datalist.npy'] 18 | BATCHSIZE_PER_REPLICA: 32 19 | LABEL_TYPE: sample_index 20 | DATA_TYPE: point_vox 21 | VOX: True 22 | POINT_TRANSFORMS: 23 | - name: randomcuboid 24 | crop: 0.5 25 | npoints: 10000 26 | randcrop: 1.0 27 | aspect: 0.75 28 | - name: randomdrop 29 | crop: 0.2 30 | - name: multiscale 31 | - name: RandomFlip 32 | - name: RandomRotateAll 33 | - name: RandomScale 34 | - name: ToTensor 35 | COLLATE_FUNCTION: "point_vox_moco_collator" 36 | INPUT_KEY_NAMES: ["points", "points_moco", "vox", "vox_moco"] 37 | DROP_LAST: True 38 | 39 | optimizer: 40 | name: "sgd" 41 | weight_decay: 0.0001 42 | momentum: 0.9 43 | nesterov: False 44 | num_epochs: 1000 45 | lr: 46 | name: "cosine" 47 | base_lr: 0.12 48 | final_lr: 0.00012 49 | 50 | model: 51 | name: "Pointnet1X_Unet256" 52 | model_dir: "checkpoints/pointnet1x_Unet256/" 53 | model_input: ["points", "points_moco", "vox", "vox_moco"] 54 | model_feature: [["fp2"], ["fp2"], ["plane7"], ["plane7"]] 55 | arch_point: "pointnet" 56 | args_point: 57 | scale: 1 58 | use_mlp: True 59 | mlp_dim: [512, 512, 128] 60 | arch_vox: "unet" 61 | args_vox: 62 | use_mlp: True 63 | mlp_dim: [256, 256, 128] 64 | 65 | loss: 66 | name: "NCELossMoco" 67 | args: 68 | two_domain: True 69 | LOSS_TYPE: NPID,CMC 70 | OTHER_INPUT: True 71 | within_format_weight0: 0.5 72 | within_format_weight1: 0.5 73 | across_format_weight0: 0.5 74 | across_format_weight1: 0.5 75 | NCE_LOSS: 76 | NORM_EMBEDDING: True 77 | TEMPERATURE: 0.1 78 | LOSS_TYPE: cross_entropy 79 | NUM_NEGATIVES: 131072 80 | EMBEDDING_DIM: 128 81 | -------------------------------------------------------------------------------- /configs/point_within_format.yaml: -------------------------------------------------------------------------------- 1 | resume: false 2 | test_only: false 3 | num_workers: 10 4 | 5 | required_devices: 1 6 | no_test: false 7 | debug: false 8 | log2tb: true 9 | allow_double_bs: false 10 | seed: 0 11 | distributed: false 12 | test_freq: 10 13 | print_freq: 5 14 | 15 | dataset: 16 | DATASET_NAMES: [scannet] 17 | DATA_PATHS: ['/path/to/datalist.npy'] 18 | DATA_LIMIT: -1 19 | BATCHSIZE_PER_REPLICA: 32 20 | LABEL_TYPE: sample_index 21 | DATA_TYPE: points 22 | VOX: False 23 | POINT_TRANSFORMS: 24 | - name: randomcuboid 25 | crop: 0.5 26 | npoints: 10000 27 | randcrop: 1.0 28 | aspect: 0.75 29 | - name: randomdrop 30 | crop: 0.2 31 | - name: multiscale 32 | - name: RandomFlip 33 | - name: RandomRotateAll 34 | - name: RandomScale 35 | - name: ToTensor 36 | COLLATE_FUNCTION: "point_moco_collator" 37 | INPUT_KEY_NAMES: ["points", "points_moco"] 38 | DROP_LAST: True 39 | 40 | optimizer: 41 | name: "sgd" 42 | weight_decay: 0.0001 43 | momentum: 0.9 44 | nesterov: False 45 | num_epochs: 1000 46 | lr: 47 | name: "cosine" 48 | base_lr: 0.12 49 | final_lr: 0.00012 50 | 51 | model: 52 | name: "Pointnet1X" 53 | model_dir: "checkpoints/pointnet1x_within_format" 54 | model_input: ["points", "points_moco"] 55 | model_feature: [["fp2"], ["fp2"]] 56 | arch_point: "pointnet" 57 | args_point: 58 | scale: 1 59 | use_mlp: True 60 | mlp_dim: [512, 512, 128] 61 | loss: 62 | name: "NCELossMoco" 63 | args: 64 | two_domain: True 65 | LOSS_TYPE: NPID 66 | OTHER_INPUT: False 67 | within_format_weight0: 1.0 68 | within_format_weight1: 0.0 69 | across_format_weight0: 0.0 70 | across_format_weight1: 0.0 71 | 
NCE_LOSS: 72 | NORM_EMBEDDING: True 73 | TEMPERATURE: 0.1 74 | LOSS_TYPE: cross_entropy 75 | NUM_NEGATIVES: 131072 76 | EMBEDDING_DIM: 128 77 | -------------------------------------------------------------------------------- /configs/point_within_lidar_template.yaml: -------------------------------------------------------------------------------- 1 | resume: false 2 | test_only: false 3 | num_workers: 10 4 | 5 | required_devices: 1 6 | no_test: false 7 | debug: false 8 | log2tb: true 9 | allow_double_bs: false 10 | seed: 0 11 | distributed: false 12 | test_freq: 10 13 | print_freq: 5 14 | 15 | dataset: 16 | DATASET_NAMES: [waymo] 17 | DATA_PATHS: ['/path/to/waymo.npy'] 18 | BATCHSIZE_PER_REPLICA: 8 19 | LABEL_TYPE: sample_index 20 | DATA_TYPE: points 21 | Lidar: True 22 | VOX: False 23 | POINT_TRANSFORMS: 24 | - name: randomcuboidLidar 25 | crop: 0.5 26 | npoints: 10000 27 | randcrop: 1.0 28 | aspect: 0.75 29 | - name: randomdrop 30 | crop: 0.2 31 | - name: RandomFlipLidar 32 | - name: RandomRotateLidar 33 | - name: RandomScaleLidar 34 | - name: ToTensorLidar 35 | COLLATE_FUNCTION: "point_moco_collator" 36 | INPUT_KEY_NAMES: ["points", "points_moco"] 37 | DROP_LAST: True 38 | 39 | optimizer: 40 | name: "sgd" 41 | weight_decay: 0.0001 42 | momentum: 0.9 43 | nesterov: False 44 | num_epochs: 1000 45 | lr: 46 | name: "cosine" 47 | base_lr: 0.12 48 | final_lr: 0.00012 49 | 50 | model: 51 | name: "PointnetMSG" 52 | model_dir: "checkpoints/pointnetMSG_within_format" 53 | model_input: ["points", "points_moco"] 54 | model_feature: [["fp2"], ["fp2"]] 55 | Lidar: True 56 | VOX: False 57 | arch_point: "pointnet_msg" 58 | args_point: 59 | use_mlp: True 60 | mlp_dim: [128, 128, 128] 61 | 62 | loss: 63 | name: "NCELossMoco" 64 | args: 65 | two_domain: True 66 | LOSS_TYPE: NPID 67 | OTHER_INPUT: False 68 | within_format_weight0: 1.0 69 | within_format_weight1: 0.0 70 | across_format_weight0: 0.0 71 | across_format_weight1: 0.0 72 | NCE_LOSS: 73 | NORM_EMBEDDING: True 74 | TEMPERATURE: 0.1 75 | LOSS_TYPE: cross_entropy 76 | NUM_NEGATIVES: 65536 77 | EMBEDDING_DIM: 128 78 | -------------------------------------------------------------------------------- /criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .nce_loss_moco import * 8 | -------------------------------------------------------------------------------- /criterions/nce_loss_moco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | import logging 9 | import math 10 | import pprint 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | 16 | # utils 17 | @torch.no_grad() 18 | def concat_all_gather(tensor): 19 | """ 20 | Performs all_gather operation on the provided tensors. 21 | *** Warning ***: torch.distributed.all_gather has no gradient. 
22 | """ 23 | if not (torch.distributed.is_initialized()): 24 | return tensor 25 | 26 | tensors_gather = [torch.ones_like(tensor) 27 | for _ in range(torch.distributed.get_world_size())] 28 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 29 | 30 | output = torch.cat(tensors_gather, dim=0) 31 | return output 32 | 33 | class NCELossMoco(nn.Module): 34 | """ 35 | Distributed version of the NCE loss. It performs an "all_gather" to gather 36 | the allocated buffers like memory no a single gpu. For this, Pytorch distributed 37 | backend is used. If using NCCL, one must ensure that all the buffer are on GPU. 38 | This class supports training using both NCE and CrossEntropy (InfoNCE). 39 | """ 40 | 41 | def __init__(self, config): 42 | super(NCELossMoco, self).__init__() 43 | 44 | assert config["NCE_LOSS"]["LOSS_TYPE"] in [ 45 | "cross_entropy", 46 | ], f"Supported types are cross_entropy." 47 | 48 | self.loss_type = config["NCE_LOSS"]["LOSS_TYPE"] 49 | self.loss_list = config["LOSS_TYPE"].split(",") 50 | self.other_queue = config["OTHER_INPUT"] 51 | 52 | self.npid0_w = float(config["within_format_weight0"]) 53 | self.npid1_w = float(config["within_format_weight1"]) 54 | self.cmc0_w = float(config["across_format_weight0"]) 55 | self.cmc1_w = float(config["across_format_weight1"]) 56 | 57 | self.K = int(config["NCE_LOSS"]["NUM_NEGATIVES"]) 58 | self.dim = int(config["NCE_LOSS"]["EMBEDDING_DIM"]) 59 | self.T = float(config["NCE_LOSS"]["TEMPERATURE"]) 60 | 61 | self.register_buffer("queue", torch.randn(self.dim, self.K)) 62 | self.queue = nn.functional.normalize(self.queue, dim=0) 63 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 64 | 65 | if self.other_queue: 66 | self.register_buffer("queue_other", torch.randn(self.dim, self.K)) 67 | self.queue_other = nn.functional.normalize(self.queue_other, dim=0) 68 | self.register_buffer("queue_other_ptr", torch.zeros(1, dtype=torch.long)) 69 | 70 | # cross-entropy loss. Also called InfoNCE 71 | self.xe_criterion = nn.CrossEntropyLoss() 72 | 73 | # other constants 74 | self.normalize_embedding = config["NCE_LOSS"]["NORM_EMBEDDING"] 75 | 76 | @classmethod 77 | def from_config(cls, config): 78 | return cls(config) 79 | 80 | @torch.no_grad() 81 | def _dequeue_and_enqueue(self, keys, okeys=None): 82 | # gather keys before updating queue 83 | keys = concat_all_gather(keys) 84 | 85 | batch_size = keys.shape[0] 86 | 87 | ptr = int(self.queue_ptr) 88 | assert self.K % batch_size == 0 # for simplicity 89 | 90 | # replace the keys at ptr (dequeue and enqueue) 91 | self.queue[:, ptr:ptr + batch_size] = torch.transpose(keys, 0, 1) 92 | ptr = (ptr + batch_size) % self.K # move pointer 93 | 94 | self.queue_ptr[0] = ptr 95 | 96 | if self.other_queue: 97 | # gather keys before updating queue 98 | okeys = concat_all_gather(okeys) 99 | 100 | other_ptr = int(self.queue_other_ptr) 101 | 102 | # replace the keys at ptr (dequeue and enqueue) 103 | self.queue_other[:, other_ptr:other_ptr + batch_size] = torch.transpose(okeys, 0, 1)#okeys.T 104 | other_ptr = (other_ptr + batch_size) % self.K # move pointer 105 | 106 | self.queue_other_ptr[0] = other_ptr 107 | 108 | def forward(self, output): 109 | assert isinstance( 110 | output, list 111 | ), "Model output should be a list of tensors. 
Got Type {}".format(type(output)) 112 | 113 | if self.normalize_embedding: 114 | normalized_output1 = nn.functional.normalize(output[0], dim=1, p=2) 115 | normalized_output2 = nn.functional.normalize(output[1], dim=1, p=2) 116 | if self.other_queue: 117 | normalized_output3 = nn.functional.normalize(output[2], dim=1, p=2) 118 | normalized_output4 = nn.functional.normalize(output[3], dim=1, p=2) 119 | 120 | # positive logits: Nx1 121 | l_pos = torch.einsum('nc,nc->n', [normalized_output1, normalized_output2]).unsqueeze(-1) 122 | 123 | # negative logits: NxK 124 | l_neg = torch.einsum('nc,ck->nk', [normalized_output1, self.queue.clone().detach()]) 125 | 126 | # logits: Nx(1+K) 127 | logits = torch.cat([l_pos, l_neg], dim=1) 128 | 129 | # apply temperature 130 | logits /= self.T 131 | 132 | if self.other_queue: 133 | 134 | l_pos_p2i = torch.einsum('nc,nc->n', [normalized_output1, normalized_output4]).unsqueeze(-1) 135 | l_neg_p2i = torch.einsum('nc,ck->nk', [normalized_output1, self.queue_other.clone().detach()]) 136 | logits_p2i = torch.cat([l_pos_p2i, l_neg_p2i], dim=1) 137 | logits_p2i /= self.T 138 | 139 | 140 | l_pos_i2p = torch.einsum('nc,nc->n', [normalized_output3, normalized_output2]).unsqueeze(-1) 141 | l_neg_i2p = torch.einsum('nc,ck->nk', [normalized_output3, self.queue.clone().detach()]) 142 | logits_i2p = torch.cat([l_pos_i2p, l_neg_i2p], dim=1) 143 | logits_i2p /= self.T 144 | 145 | 146 | l_pos_other = torch.einsum('nc,nc->n', [normalized_output3, normalized_output4]).unsqueeze(-1) 147 | l_neg_other = torch.einsum('nc,ck->nk', [normalized_output3, self.queue_other.clone().detach()]) 148 | logits_other = torch.cat([l_pos_other, l_neg_other], dim=1) 149 | logits_other /= (self.T) 150 | 151 | if self.other_queue: 152 | self._dequeue_and_enqueue(normalized_output2, okeys=normalized_output4) 153 | else: 154 | self._dequeue_and_enqueue(normalized_output2) 155 | 156 | 157 | labels = torch.zeros( 158 | logits.shape[0], device=logits.device, dtype=torch.int64 159 | ) 160 | 161 | loss_npid = self.xe_criterion(torch.squeeze(logits), labels) 162 | 163 | loss_npid_other = torch.tensor(0) 164 | loss_cmc_p2i = torch.tensor(0) 165 | loss_cmc_i2p = torch.tensor(0) 166 | 167 | if self.other_queue: 168 | loss_cmc_p2i = self.xe_criterion(torch.squeeze(logits_p2i), labels) 169 | loss_cmc_i2p = self.xe_criterion(torch.squeeze(logits_i2p), labels) 170 | loss_npid_other = self.xe_criterion(torch.squeeze(logits_other), labels) 171 | 172 | curr_loss = 0 173 | for ltype in self.loss_list: 174 | if ltype == "CMC": 175 | curr_loss += loss_cmc_p2i * self.cmc0_w + loss_cmc_i2p * self.cmc1_w 176 | elif ltype == "NPID": 177 | curr_loss += loss_npid * self.npid0_w 178 | curr_loss += loss_npid_other * self.npid1_w 179 | else: 180 | curr_loss = 0 181 | curr_loss += loss_npid * self.npid0_w 182 | 183 | loss = curr_loss 184 | 185 | return loss, [loss_npid, loss_npid_other, loss_cmc_p2i, loss_cmc_i2p] 186 | 187 | def __repr__(self): 188 | repr_dict = { 189 | "name": self._get_name(), 190 | "loss_type": self.loss_type, 191 | } 192 | return pprint.pformat(repr_dict, indent=2) 193 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ### Prepare Dataset 2 | 3 | We use scannet and redwood for Pointnet++ and MinkowskiEngine UNet pretraining. 4 | We use waymo for PointnetMSG and Spconv-UNet model pretraining. 
5 | 6 | Please see the specific instructions under each folder for how to generate the training data. 7 | 8 | Once you have generated the training data, the framework just takes a .npy file which consists of a list of paths to the extracted pointclouds: 9 | 10 | [/path/to/pt1, /path/to/pt2, path/to/pt3, ..... ] 11 | 12 | -------------------------------------------------------------------------------- /data/redwood/README.md: -------------------------------------------------------------------------------- 1 | ### Prepare Dataset 2 | 3 | 1. Download Redwood data [HERE](https://github.com/intel-isl/redwood-3dscan). 4 | 5 | 2. Extract the pointclouds of the desired classes using extract_pointcloud.py 6 | 7 | python extract_pointcloud.py /path/to/extracted_data /path/to/extracted_pointcloud_visualization /path/to/extracted_pointclouds redwood_datalist.npy 8 | 9 | The visualizations are optional. 10 | 11 | 3. In our experiment, we use the following 10 classes: 12 | 13 | car, chair, table, bench, bicycle, plant, playground, sculpture, sign, trash_container 14 | 15 | 4. You will need to concatenate the datalist files if you use data from multiple categories. 16 | -------------------------------------------------------------------------------- /data/redwood/extract_pointcloud.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import open3d as o3d 8 | import cv2 9 | import sys 10 | import numpy as np 11 | from glob import glob 12 | import os 13 | import zipfile 14 | 15 | def pc2obj(pc, filepath='test.obj'): 16 | pc = pc.T 17 | nverts = pc.shape[1] 18 | with open(filepath, 'w') as f: 19 | f.write("# OBJ file\n") 20 | for v in range(nverts): 21 | f.write("v %.4f %.4f %.4f\n" % (pc[0,v],pc[1,v],pc[2,v])) 22 | 23 | nump = 50000 ### Number of points in the point clouds 24 | scenelist = glob(sys.argv[1]+"*.zip") ### Path to the zip files 25 | datalist = [] 26 | 27 | for scene in scenelist: 28 | if os.path.exists(sys.argv[3]+scene.split("/")[-1].split(".")[0]): 29 | continue 30 | 31 | os.system("rm -rf test/") 32 | with zipfile.ZipFile(scene, 'r') as zip_ref: 33 | zip_ref.extractall("test") 34 | 35 | os.system("ls test/rgb/ > rgblist") 36 | os.system("ls test/depth/ > depthlist") 37 | 38 | rgb_path = {} 39 | depth_path = {} 40 | rgblist = open("rgblist", "r") 41 | depthlist = open("depthlist", "r") 42 | for line in rgblist: 43 | seqname = line.split("-")[0] 44 | if seqname in rgb_path: 45 | print (line) 46 | else: 47 | rgb_path[seqname] = line.strip("\n") 48 | 49 | for line in depthlist: 50 | seqname = line.split("-")[0] 51 | if seqname in depth_path: 52 | print (line) 53 | else: 54 | depth_path[seqname] = line.strip("\n") 55 | 56 | intrin_cam = o3d.camera.PinholeCameraIntrinsic() 57 | #print (scene) 58 | if not os.path.exists(sys.argv[2]+scene.split("/")[-1].split(".")[0]): 59 | os.mkdir(sys.argv[2]+scene.split("/")[-1].split(".")[0]) 60 | if not os.path.exists(sys.argv[3]+scene.split("/")[-1].split(".")[0]): 61 | os.mkdir(sys.argv[3]+scene.split("/")[-1].split(".")[0]) 62 | 63 | success = 0 64 | 65 | if len(rgb_path) <= len(depth_path): 66 | framelist = rgb_path.keys() 67 | else: 68 | framelist = depth_path.keys() 69 | counter = 15 70 | for frame in framelist: 71 | counter += 1 72 | if counter < 15: ### Set the sampling rate here 73 | continue 74 | 
depth = "test/depth/"+depth_path[frame] 75 | 76 | rgbpath = "test/rgb/"+rgb_path[frame]#scene+"/matterport_color_images/"+room_name+"_i%d_%d.jpg" % (cam_num, frame_num) 77 | 78 | depth_im = cv2.imread(depth, -1) 79 | 80 | try: 81 | o3d_depth = o3d.geometry.Image(depth_im) 82 | rgb_im = cv2.imread(rgbpath) 83 | o3d_rgb = o3d.geometry.Image(rgb_im) 84 | o3d_rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(o3d_rgb, o3d_depth, depth_scale=1000.0, depth_trunc=1000.0, convert_rgb_to_intensity=False) 85 | except: 86 | print (frame) 87 | continue 88 | 89 | if (depth_im.shape[1] != 640) or (depth_im.shape[0] != 480): 90 | print (frame) 91 | continue 92 | intrin_cam.set_intrinsics(width=depth_im.shape[1], height=depth_im.shape[0], fx=525.0, fy=525.0, cx=319.5, cy=239.5) 93 | 94 | pts = o3d.geometry.PointCloud.create_from_rgbd_image(o3d_rgbd, intrin_cam, np.eye(4)) 95 | 96 | if len(np.array(pts.points)) < 100: 97 | continue 98 | 99 | if len(np.array(pts.points)) >= nump: 100 | sel_idx = np.random.choice(len(np.array(pts.points)), nump, replace=False) 101 | else: 102 | sel_idx = np.random.choice(len(np.array(pts.points)), nump, replace=True) 103 | temp = np.array(pts.points)[sel_idx] 104 | 105 | if np.isnan(np.sum(temp)): 106 | continue 107 | 108 | color_points = np.array(pts.colors)[sel_idx] 109 | color_points[:,[0,1,2]] = color_points[:,[2,1,0]] 110 | 111 | pts.points = o3d.utility.Vector3dVector(temp) 112 | pts.colors = o3d.utility.Vector3dVector(color_points) 113 | data = np.concatenate([temp,color_points], axis=1) 114 | 115 | o3d.io.write_point_cloud(sys.argv[2]+scene.split("/")[-1].split(".")[0]+"/"+frame+".ply", pts) 116 | np.save(sys.argv[3]+scene.split("/")[-1].split(".")[0]+"/"+frame+".npy", data) 117 | 118 | datalist.append(os.path.abspath(sys.argv[3]+scene.split("/")[-1].split(".")[0]+"/"+frame+".npy")) 119 | 120 | counter = 0 121 | success += 1 122 | 123 | print (success) 124 | 125 | np.save(sys.argv[4], datalist) 126 | -------------------------------------------------------------------------------- /data/scannet/README.md: -------------------------------------------------------------------------------- 1 | ### Prepare Dataset 2 | 3 | 1. Download ScanNet data [HERE](https://github.com/ScanNet/ScanNet). 4 | 5 | 2. Please use python2.7 to run the following command to extract depth images, color images, and camera intrinsics: 6 | 7 | ``` 8 | python2.7 reader.py --scans_path path/to/scannet_v2/scans/ --output_path path/to/extracted_data/ --export_depth_images --export_color_images --export_poses --export_intrinsics 9 | ``` 10 | 11 | 2. Extract point clouds using extract_pointcloud.py 12 | 13 | ``` 14 | python extract_pointcloud.py path/to/extracted_data/depth/ path/to/extracted_pointcloud_visualization/ path/to/extracted_pointclouds/ scannet_datalist.npy 15 | ``` 16 | 17 | The visualizations are optional. -------------------------------------------------------------------------------- /data/scannet/SensorData.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | """ 8 | Code borrowed from: https://github.com/ScanNet/ScanNet 9 | """ 10 | import os, struct 11 | import numpy as np 12 | import zlib 13 | import imageio 14 | import cv2 15 | 16 | COMPRESSION_TYPE_COLOR = {-1:'unknown', 0:'raw', 1:'png', 2:'jpeg'} 17 | COMPRESSION_TYPE_DEPTH = {-1:'unknown', 0:'raw_ushort', 1:'zlib_ushort', 2:'occi_ushort'} 18 | 19 | class RGBDFrame(): 20 | 21 | def load(self, file_handle): 22 | self.camera_to_world = np.asarray(struct.unpack('f'*16, file_handle.read(16*4)), dtype=np.float32).reshape(4, 4) 23 | self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0] 24 | self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0] 25 | self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0] 26 | self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0] 27 | self.color_data = ''.join(struct.unpack('c'*self.color_size_bytes, file_handle.read(self.color_size_bytes))) 28 | self.depth_data = ''.join(struct.unpack('c'*self.depth_size_bytes, file_handle.read(self.depth_size_bytes))) 29 | 30 | 31 | def decompress_depth(self, compression_type): 32 | if compression_type == 'zlib_ushort': 33 | return self.decompress_depth_zlib() 34 | else: 35 | raise 36 | 37 | 38 | def decompress_depth_zlib(self): 39 | return zlib.decompress(self.depth_data) 40 | 41 | 42 | def decompress_color(self, compression_type): 43 | if compression_type == 'jpeg': 44 | return self.decompress_color_jpeg() 45 | else: 46 | raise 47 | 48 | 49 | def decompress_color_jpeg(self): 50 | return imageio.imread(self.color_data) 51 | 52 | 53 | class SensorData: 54 | 55 | def __init__(self, filename): 56 | self.version = 4 57 | self.load(filename) 58 | 59 | 60 | def load(self, filename): 61 | with open(filename, 'rb') as f: 62 | version = struct.unpack('I', f.read(4))[0] 63 | assert self.version == version 64 | strlen = struct.unpack('Q', f.read(8))[0] 65 | self.sensor_name = ''.join(struct.unpack('c'*strlen, f.read(strlen))) 66 | self.intrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 67 | self.extrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 68 | self.intrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 69 | self.extrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 70 | self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack('i', f.read(4))[0]] 71 | self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack('i', f.read(4))[0]] 72 | self.color_width = struct.unpack('I', f.read(4))[0] 73 | self.color_height = struct.unpack('I', f.read(4))[0] 74 | self.depth_width = struct.unpack('I', f.read(4))[0] 75 | self.depth_height = struct.unpack('I', f.read(4))[0] 76 | self.depth_shift = struct.unpack('f', f.read(4))[0] 77 | num_frames = struct.unpack('Q', f.read(8))[0] 78 | self.frames = [] 79 | for i in range(num_frames): 80 | frame = RGBDFrame() 81 | frame.load(f) 82 | self.frames.append(frame) 83 | 84 | 85 | def export_depth_images(self, output_path, image_size=None, frame_skip=1): 86 | if not os.path.exists(output_path): 87 | os.makedirs(output_path) 88 | print ('exporting', len(self.frames)//frame_skip, ' depth frames to', output_path) 89 | for f in range(0, len(self.frames), frame_skip): 90 | depth_data = self.frames[f].decompress_depth(self.depth_compression_type) 91 | depth = np.fromstring(depth_data, dtype=np.uint16).reshape(self.depth_height, self.depth_width) 92 | if image_size 
is not None: 93 | depth = cv2.resize(depth, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST) 94 | imageio.imwrite(os.path.join(output_path, str(f) + '.png'), depth) 95 | 96 | 97 | def export_color_images(self, output_path, image_size=None, frame_skip=1): 98 | if not os.path.exists(output_path): 99 | os.makedirs(output_path) 100 | print ('exporting', len(self.frames)//frame_skip, 'color frames to', output_path) 101 | for f in range(0, len(self.frames), frame_skip): 102 | color = self.frames[f].decompress_color(self.color_compression_type) 103 | if image_size is not None: 104 | color = cv2.resize(color, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST) 105 | imageio.imwrite(os.path.join(output_path, str(f) + '.jpg'), color) 106 | 107 | 108 | def save_mat_to_file(self, matrix, filename): 109 | with open(filename, 'w') as f: 110 | for line in matrix: 111 | np.savetxt(f, line[np.newaxis], fmt='%f') 112 | 113 | 114 | def export_poses(self, output_path, frame_skip=1): 115 | if not os.path.exists(output_path): 116 | os.makedirs(output_path) 117 | print ('exporting', len(self.frames)//frame_skip, 'camera poses to', output_path) 118 | for f in range(0, len(self.frames), frame_skip): 119 | self.save_mat_to_file(self.frames[f].camera_to_world, os.path.join(output_path, str(f) + '.txt')) 120 | 121 | 122 | def export_intrinsics(self, output_path): 123 | if not os.path.exists(output_path): 124 | os.makedirs(output_path) 125 | print ('exporting camera intrinsics to', output_path) 126 | self.save_mat_to_file(self.intrinsic_color, os.path.join(output_path, 'intrinsic_color.txt')) 127 | self.save_mat_to_file(self.extrinsic_color, os.path.join(output_path, 'extrinsic_color.txt')) 128 | self.save_mat_to_file(self.intrinsic_depth, os.path.join(output_path, 'intrinsic_depth.txt')) 129 | self.save_mat_to_file(self.extrinsic_depth, os.path.join(output_path, 'extrinsic_depth.txt')) 130 | -------------------------------------------------------------------------------- /data/scannet/extract_pointcloud.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import open3d as o3d 8 | import cv2 9 | import sys 10 | import numpy as np 11 | from glob import glob 12 | import os 13 | 14 | def pc2obj(pc, filepath='test.obj'): 15 | pc = pc.T 16 | nverts = pc.shape[1] 17 | with open(filepath, 'w') as f: 18 | f.write("# OBJ file\n") 19 | for v in range(nverts): 20 | f.write("v %.4f %.4f %.4f\n" % (pc[0,v],pc[1,v],pc[2,v])) 21 | 22 | nump = 50000 ### Number of points from the depth scans 23 | scenelist = glob(sys.argv[1]+"*") ### Input path to the data 24 | 25 | if not os.path.exists(sys.argv[2]): 26 | os.mkdir(sys.argv[2]) 27 | if not os.path.exists(sys.argv[3]): 28 | os.mkdir(sys.argv[3]) 29 | 30 | datalist = [] 31 | for scene in scenelist: 32 | framelist = glob(scene+"/*") 33 | intrinsic = np.loadtxt(scene.replace("depth", "intrinsic")+"/intrinsic_depth.txt") 34 | intrin_cam = o3d.camera.PinholeCameraIntrinsic() 35 | print (scene) 36 | if not os.path.exists(sys.argv[2]+scene.split("/")[-1]): 37 | os.mkdir(sys.argv[2]+scene.split("/")[-1]) 38 | if not os.path.exists(sys.argv[3]+scene.split("/")[-1]): 39 | os.mkdir(sys.argv[3]+scene.split("/")[-1]) 40 | 41 | success = 0 42 | counter = 10 43 | for fileidx in range(len(framelist)): 44 | counter += 1 45 | if counter < 10: 46 | continue 47 | frame = scene+"/%d.png"%fileidx 48 | depth = frame 49 | rgbpath = frame.replace("depth", "color").replace(".png", ".jpg") 50 | 51 | depth_im = cv2.imread(depth, -1) 52 | try: 53 | o3d_depth = o3d.geometry.Image(depth_im) 54 | rgb_im = cv2.resize(cv2.imread(rgbpath), (depth_im.shape[1], depth_im.shape[0])) 55 | o3d_rgb = o3d.geometry.Image(rgb_im) 56 | o3d_rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(o3d_rgb, o3d_depth, depth_scale=1000.0, depth_trunc=1000.0, convert_rgb_to_intensity=False) 57 | except: 58 | print (frame) 59 | continue 60 | 61 | intrin_cam.set_intrinsics(width=depth_im.shape[1], height=depth_im.shape[0], fx=intrinsic[1,1], fy=intrinsic[0,0], cx=intrinsic[1,2], cy=intrinsic[0,2]) 62 | 63 | pts = o3d.geometry.PointCloud.create_from_rgbd_image(o3d_rgbd, intrin_cam, np.eye(4)) 64 | 65 | if len(np.array(pts.points)) >= nump: 66 | sel_idx = np.random.choice(len(np.array(pts.points)), nump, replace=False) 67 | else: 68 | sel_idx = np.random.choice(len(np.array(pts.points)), nump, replace=True) 69 | temp = np.array(pts.points)[sel_idx] 70 | 71 | color_points = np.array(pts.colors)[sel_idx] 72 | color_points[:,[0,1,2]] = color_points[:,[2,1,0]] 73 | 74 | pts.points = o3d.utility.Vector3dVector(temp) 75 | pts.colors = o3d.utility.Vector3dVector(color_points) 76 | data = np.concatenate([temp,color_points], axis=1) 77 | 78 | o3d.io.write_point_cloud(sys.argv[2]+scene.split("/")[-1]+"/"+frame.split("/")[-1].split(".")[0]+".ply", pts) 79 | np.save(sys.argv[3]+scene.split("/")[-1]+"/"+frame.split("/")[-1].split(".")[0]+".npy", data) 80 | datalist.append(os.path.abspath(sys.argv[3]+scene.split("/")[-1]+"/"+frame.split("/")[-1].split(".")[0]+".npy")) 81 | counter = 0 82 | success += 1 83 | print (success) 84 | 85 | np.save(sys.argv[4], datalist) 86 | -------------------------------------------------------------------------------- /data/scannet/reader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import argparse 8 | import os, sys 9 | 10 | from SensorData import SensorData 11 | 12 | from glob import glob 13 | 14 | # params 15 | parser = argparse.ArgumentParser() 16 | # data paths 17 | parser.add_argument('--scans_path', required=True, help='path to scans folder') 18 | parser.add_argument('--output_path', required=True, help='path to output folder') 19 | parser.add_argument('--export_depth_images', dest='export_depth_images', action='store_true') 20 | parser.add_argument('--export_color_images', dest='export_color_images', action='store_true') 21 | parser.add_argument('--export_poses', dest='export_poses', action='store_true') 22 | parser.add_argument('--export_intrinsics', dest='export_intrinsics', action='store_true') 23 | parser.set_defaults(export_depth_images=False, export_color_images=False, export_poses=False, export_intrinsics=False) 24 | 25 | opt = parser.parse_args() 26 | print(opt) 27 | 28 | 29 | def main(): 30 | scans = glob(opt.scans_path+"/*") 31 | for scan in scans: 32 | scenename = scan.split("/")[-1] 33 | filename = os.path.join(scan, scenename+".sens") 34 | if not os.path.exists(opt.output_path): 35 | os.makedirs(opt.output_path) 36 | os.makedirs(os.path.join(opt.output_path, 'depth')) 37 | os.makedirs(os.path.join(opt.output_path, 'color')) 38 | os.makedirs(os.path.join(opt.output_path, 'pose')) 39 | os.makedirs(os.path.join(opt.output_path, 'intrinsic')) 40 | # load the data 41 | sys.stdout.write('loading %s...' % filename) 42 | sd = SensorData(filename) 43 | sys.stdout.write('loaded!\n') 44 | if opt.export_depth_images: 45 | sd.export_depth_images(os.path.join(opt.output_path, 'depth', scenename)) 46 | if opt.export_color_images: 47 | sd.export_color_images(os.path.join(opt.output_path, 'color', scenename)) 48 | if opt.export_poses: 49 | sd.export_poses(os.path.join(opt.output_path, 'pose', scenename)) 50 | if opt.export_intrinsics: 51 | sd.export_intrinsics(os.path.join(opt.output_path, 'intrinsic', scenename)) 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /data/waymo/README.md: -------------------------------------------------------------------------------- 1 | ### Prepare Dataset 2 | 3 | 1. Download Waymo data [HERE](https://waymo.com/open/). We suggest to use [this](https://github.com/RalphMao/Waymo-Dataset-Tool) tool for batch downloading. 4 | 5 | 2. Please install this [tool](https://github.com/waymo-research/waymo-open-dataset) for data preprocessing. 6 | 7 | 3. Create a data_folder for extracting point clouds. Use extract_pointcloud.py to extract point clouds. 8 | python extract_pointcloud.py /path/to/downloaded_segments /path/to/data_folder waymo.npy 9 | -------------------------------------------------------------------------------- /data/waymo/extract_pointcloud.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import sys 8 | import os 9 | import tensorflow as tf 10 | import numpy as np 11 | import glob 12 | from open3d import * 13 | 14 | tf.enable_eager_execution() 15 | 16 | from waymo_open_dataset.utils import frame_utils 17 | from waymo_open_dataset import dataset_pb2 as open_dataset 18 | 19 | 20 | debug = False 21 | K = 5 22 | SCALE_FACTOR = 2 23 | segments = glob.glob(sys.argv[1]+"/*") 24 | datalist = [] 25 | 26 | def pc2obj(pc, filepath='test.obj'): 27 | pc = pc.T 28 | nverts = pc.shape[1] 29 | with open(filepath, 'w') as f: 30 | f.write("# OBJ file\n") 31 | for v in range(nverts): 32 | f.write("v %.4f %.4f %.4f\n" % (pc[0,v],pc[1,v],pc[2,v])) 33 | 34 | def extract(i): 35 | FILENAME = segments[i] 36 | run = FILENAME.split('segment-')[-1].split('.')[0] 37 | out_base_dir = sys.argv[2]+'/%s/' % run 38 | 39 | if not os.path.exists(out_base_dir): 40 | os.makedirs(out_base_dir) 41 | 42 | dataset = tf.data.TFRecordDataset(FILENAME, compression_type='') 43 | print(FILENAME) 44 | pc, pc_c = [], [] 45 | camID2extr_v2c = {} 46 | camID2intr = {} 47 | 48 | all_static_pc = [] 49 | for frame_cnt, data in enumerate(dataset): 50 | if frame_cnt % 2 != 0: continue ### Set the sampling rate here 51 | 52 | print('frame ', frame_cnt) 53 | frame = open_dataset.Frame() 54 | frame.ParseFromString(bytearray(data.numpy())) 55 | 56 | extr_laser2v = np.array(frame.context.laser_calibrations[0].extrinsic.transform).reshape(4, 4) 57 | extr_v2w = np.array(frame.pose.transform).reshape(4, 4) 58 | 59 | if frame_cnt == 0: 60 | for k in range(len(frame.context.camera_calibrations)): 61 | cameraID = frame.context.camera_calibrations[k].name 62 | extr_c2v =\ 63 | np.array(frame.context.camera_calibrations[k].extrinsic.transform).reshape(4, 4) 64 | extr_v2c = np.linalg.inv(extr_c2v) 65 | camID2extr_v2c[frame.images[k].name] = extr_v2c 66 | fx = frame.context.camera_calibrations[k].intrinsic[0] 67 | fy = frame.context.camera_calibrations[k].intrinsic[1] 68 | cx = frame.context.camera_calibrations[k].intrinsic[2] 69 | cy = frame.context.camera_calibrations[k].intrinsic[3] 70 | k1 = frame.context.camera_calibrations[k].intrinsic[4] 71 | k2 = frame.context.camera_calibrations[k].intrinsic[5] 72 | p1 = frame.context.camera_calibrations[k].intrinsic[6] 73 | p2 = frame.context.camera_calibrations[k].intrinsic[7] 74 | k3 = frame.context.camera_calibrations[k].intrinsic[8] 75 | camID2intr[frame.images[k].name] = np.array([[fx, 0, cx], 76 | [0, fy, cy], 77 | [0, 0, 1]]) 78 | 79 | 80 | # lidar point cloud 81 | 82 | (range_images, camera_projections, range_image_top_pose) = \ 83 | frame_utils.parse_range_image_and_camera_projection(frame) 84 | points, cp_points = frame_utils.convert_range_image_to_point_cloud( 85 | frame, 86 | range_images, 87 | camera_projections, 88 | range_image_top_pose) 89 | 90 | 91 | points_all = np.concatenate(points, axis=0) 92 | np.save('%s/frame_%03d.npy' % (out_base_dir, frame_cnt), points_all) 93 | datalist.append(os.path.abspath('%s/frame_%03d.npy' % (out_base_dir, frame_cnt))) 94 | 95 | if __name__ == '__main__': 96 | for i in range(len(segments)): 97 | extract(i) 98 | np.save(sys.argv[3], datalist) 99 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 
5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | import logging 9 | 10 | import torch 11 | from datasets.collators import get_collator 12 | from datasets.depth_dataset import DepthContrastDataset 13 | from torch.utils.data import DataLoader 14 | 15 | 16 | __all__ = ["DepthContrastDataset", "get_data_files"] 17 | 18 | 19 | def build_dataset(cfg): 20 | dataset = DepthContrastDataset(cfg) 21 | return dataset 22 | 23 | 24 | def print_sampler_config(data_sampler): 25 | sampler_cfg = { 26 | "num_replicas": data_sampler.num_replicas, 27 | "rank": data_sampler.rank, 28 | "epoch": data_sampler.epoch, 29 | "num_samples": data_sampler.num_samples, 30 | "total_size": data_sampler.total_size, 31 | "shuffle": data_sampler.shuffle, 32 | } 33 | if hasattr(data_sampler, "start_iter"): 34 | sampler_cfg["start_iter"] = data_sampler.start_iter 35 | if hasattr(data_sampler, "batch_size"): 36 | sampler_cfg["batch_size"] = data_sampler.batch_size 37 | logging.info("Distributed Sampler config:\n{}".format(sampler_cfg)) 38 | 39 | 40 | def get_loader(dataset, dataset_config, num_dataloader_workers, pin_memory): 41 | data_sampler = None 42 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 43 | assert torch.distributed.is_initialized(), "Torch distributed isn't initalized" 44 | data_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 45 | logging.info("Created the Distributed Sampler....") 46 | print_sampler_config(data_sampler) 47 | else: 48 | logging.warning( 49 | "Distributed trainer not initialized. Not using the sampler and data will NOT be shuffled" # NOQA 50 | ) 51 | collate_function = get_collator(dataset_config["COLLATE_FUNCTION"]) 52 | dataloader = DataLoader( 53 | dataset=dataset, 54 | num_workers=num_dataloader_workers, 55 | pin_memory=pin_memory, 56 | shuffle=False, 57 | batch_size=dataset_config["BATCHSIZE_PER_REPLICA"], 58 | collate_fn=collate_function, 59 | sampler=data_sampler, 60 | drop_last=dataset_config["DROP_LAST"], 61 | ) 62 | return dataloader 63 | -------------------------------------------------------------------------------- /datasets/collators/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | from datasets.collators.point_moco_collator import point_moco_collator 10 | try: 11 | from datasets.collators.vox_moco_collator import vox_moco_collator 12 | from datasets.collators.point_vox_moco_collator import point_vox_moco_collator 13 | except: 14 | print ("Cannot import minkowski engine. 
Try spconv next") 15 | from datasets.collators.point_vox_moco_lidar_collator import point_vox_moco_collator 16 | from torch.utils.data.dataloader import default_collate 17 | 18 | 19 | COLLATORS_MAP = { 20 | "default": default_collate, 21 | "point_moco_collator": point_moco_collator, 22 | "point_vox_moco_collator": point_vox_moco_collator, 23 | } 24 | 25 | 26 | def get_collator(name): 27 | assert name in COLLATORS_MAP, "Unknown collator" 28 | return COLLATORS_MAP[name] 29 | 30 | 31 | __all__ = ["get_collator"] 32 | -------------------------------------------------------------------------------- /datasets/collators/point_moco_collator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import torch 10 | 11 | def point_moco_collator(batch): 12 | batch_size = len(batch) 13 | 14 | data_point = [x["data"] for x in batch] 15 | data_moco = [x["data_moco"] for x in batch] 16 | # labels are repeated N+1 times but they are the same 17 | labels = [x["label"][0] for x in batch] 18 | labels = torch.LongTensor(labels).squeeze() 19 | 20 | # data valid is repeated N+1 times but they are the same 21 | data_valid = torch.BoolTensor([x["data_valid"][0] for x in batch]) 22 | 23 | points_moco = torch.stack([data_moco[i][0] for i in range(batch_size)]) 24 | points = torch.stack([data_point[i][0] for i in range(batch_size)]) 25 | 26 | output_batch = { 27 | "points": points, 28 | "points_moco": points_moco, 29 | "label": labels, 30 | "data_valid": data_valid, 31 | } 32 | 33 | return output_batch 34 | -------------------------------------------------------------------------------- /datasets/collators/point_vox_moco_collator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | import torch 10 | 11 | from datasets.transforms import transforms 12 | 13 | import numpy as np 14 | 15 | collate_fn = transforms.cfl_collate_fn_factory(0) 16 | 17 | def point_vox_moco_collator(batch): 18 | batch_size = len(batch) 19 | 20 | point = [x["data"] for x in batch] 21 | point_moco = [x["data_moco"] for x in batch] 22 | vox = [x["vox"] for x in batch] 23 | vox_moco = [x["vox_moco"] for x in batch] 24 | # labels are repeated N+1 times but they are the same 25 | labels = [x["label"][0] for x in batch] 26 | labels = torch.LongTensor(labels).squeeze() 27 | 28 | # data valid is repeated N+1 times but they are the same 29 | data_valid = torch.BoolTensor([x["data_valid"][0] for x in batch]) 30 | 31 | points_moco = torch.stack([point_moco[i][0] for i in range(batch_size)]) 32 | points = torch.stack([point[i][0] for i in range(batch_size)]) 33 | 34 | vox_moco = collate_fn([vox_moco[i][0] for i in range(batch_size)]) 35 | vox = collate_fn([vox[i][0] for i in range(batch_size)]) 36 | 37 | output_batch = { 38 | "points": points, 39 | "points_moco": points_moco, 40 | "vox": vox, 41 | "vox_moco": vox_moco, 42 | "label": labels, 43 | "data_valid": data_valid, 44 | } 45 | 46 | return output_batch 47 | -------------------------------------------------------------------------------- /datasets/collators/point_vox_moco_lidar_collator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import torch 10 | 11 | import numpy as np 12 | 13 | def point_vox_moco_collator(batch): 14 | batch_size = len(batch) 15 | 16 | point = [x["data"] for x in batch] 17 | point_moco = [x["data_moco"] for x in batch] 18 | vox = [x["vox"] for x in batch] 19 | vox_moco = [x["vox_moco"] for x in batch] 20 | 21 | labels = [x["label"][0] for x in batch] 22 | labels = torch.LongTensor(labels).squeeze() 23 | 24 | # data valid is repeated N+1 times but they are the same 25 | data_valid = torch.BoolTensor([x["data_valid"][0] for x in batch]) 26 | 27 | points_moco = torch.stack([point_moco[i][0] for i in range(batch_size)]) 28 | points = torch.stack([point[i][0] for i in range(batch_size)]) 29 | 30 | vox_data = {"voxels":[], "voxel_coords":[], "voxel_num_points":[]} 31 | counter = 0 32 | for data in vox: 33 | temp = data[0] 34 | voxels_shape = temp["voxels"].shape 35 | voxel_num_points_shape = temp["voxel_num_points"].shape 36 | voxel_coords_shape = temp["voxel_coords"].shape 37 | for key,val in temp.items(): 38 | if key in ['voxels', 'voxel_num_points']: 39 | if len(vox_data[key]) > 0: 40 | vox_data[key] = np.concatenate([vox_data[key], val], axis=0) 41 | else: 42 | vox_data[key] = val 43 | elif key == 'voxel_coords': 44 | coor = np.pad(val, ((0, 0), (1, 0)), mode='constant', constant_values=counter) 45 | if len(vox_data[key]) > 0: 46 | vox_data[key] = np.concatenate([vox_data[key], coor], axis=0) 47 | else: 48 | vox_data[key] = coor 49 | counter += 1 50 | 51 | vox_moco_data = {"voxels":[], "voxel_coords":[], "voxel_num_points":[]} 52 | counter = 0 53 | for data in vox_moco: 54 | temp = data[0] 55 | voxels_shape = temp["voxels"].shape 56 | voxel_num_points_shape = temp["voxel_num_points"].shape 57 | voxel_coords_shape = temp["voxel_coords"].shape 58 | for key,val in temp.items(): 59 | if key in ['voxels', 'voxel_num_points']: 60 | if 
len(vox_moco_data[key]) > 0: 61 | vox_moco_data[key] = np.concatenate([vox_moco_data[key], val], axis=0) 62 | else: 63 | vox_moco_data[key] = val 64 | elif key in 'voxel_coords': 65 | coor = np.pad(val, ((0, 0), (1, 0)), mode='constant', constant_values=counter) 66 | 67 | if len(vox_moco_data[key]) > 0: 68 | vox_moco_data[key] = np.concatenate([vox_moco_data[key], coor], axis=0) 69 | else: 70 | vox_moco_data[key] = coor 71 | counter += 1 72 | 73 | vox_data = {k:torch.from_numpy(vox_data[k]) for k in vox_data} 74 | vox_moco_data = {k:torch.from_numpy(vox_moco_data[k]) for k in vox_moco_data} 75 | 76 | output_batch = { 77 | "points": points, 78 | "points_moco": points_moco, 79 | "vox": vox_data, 80 | "vox_moco": vox_moco_data, 81 | "label": labels, 82 | "data_valid": data_valid, 83 | } 84 | 85 | return output_batch 86 | -------------------------------------------------------------------------------- /datasets/collators/vox_moco_collator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import torch 10 | 11 | from datasets.transforms import transforms 12 | 13 | import numpy as np 14 | 15 | collate_fn = transforms.cfl_collate_fn_factory(0) 16 | 17 | def vox_moco_collator(batch): 18 | batch_size = len(batch) 19 | 20 | data_point = [x["data"] for x in batch] 21 | data_moco = [x["data_moco"] for x in batch] 22 | # labels are repeated N+1 times but they are the same 23 | labels = [int(x["label"][0]) for x in batch] 24 | labels = torch.LongTensor(labels).squeeze() 25 | 26 | # data valid is repeated N+1 times but they are the same 27 | data_valid = torch.BoolTensor([x["data_valid"][0] for x in batch]) 28 | 29 | vox_moco = collate_fn([data_moco[i][0] for i in range(batch_size)]) 30 | vox = collate_fn([data_point[i][0] for i in range(batch_size)]) 31 | 32 | output_batch = { 33 | "vox": vox, 34 | "vox_moco": vox_moco, 35 | "label": labels, 36 | "data_valid": data_valid, 37 | } 38 | 39 | return output_batch 40 | -------------------------------------------------------------------------------- /datasets/transforms/voxelizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import collections 8 | 9 | import numpy as np 10 | import MinkowskiEngine as ME 11 | from scipy.linalg import expm, norm 12 | 13 | 14 | # Rotation matrix along axis with angle theta 15 | def M(axis, theta): 16 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta)) 17 | 18 | 19 | class Voxelizer: 20 | 21 | def __init__(self, 22 | voxel_size=1, 23 | clip_bound=None, 24 | use_augmentation=False, 25 | scale_augmentation_bound=None, 26 | rotation_augmentation_bound=None, 27 | translation_augmentation_ratio_bound=None, 28 | ignore_label=255): 29 | """ 30 | Args: 31 | voxel_size: side length of a voxel 32 | clip_bound: boundary of the voxelizer. Points outside the bound will be deleted 33 | expects either None or an array like ((-100, 100), (-100, 100), (-100, 100)). 
34 | scale_augmentation_bound: None or (0.9, 1.1) 35 | rotation_augmentation_bound: None or ((np.pi / 6, np.pi / 6), None, None) for 3 axis. 36 | Use random order of x, y, z to prevent bias. 37 | translation_augmentation_bound: ((-5, 5), (0, 0), (-10, 10)) 38 | ignore_label: label assigned for ignore (not a training label). 39 | """ 40 | self.voxel_size = voxel_size 41 | self.clip_bound = clip_bound 42 | self.ignore_label = ignore_label 43 | 44 | # Augmentation 45 | self.use_augmentation = use_augmentation 46 | self.scale_augmentation_bound = scale_augmentation_bound 47 | self.rotation_augmentation_bound = rotation_augmentation_bound 48 | self.translation_augmentation_ratio_bound = translation_augmentation_ratio_bound 49 | 50 | def get_transformation_matrix(self): 51 | voxelization_matrix, rotation_matrix = np.eye(4), np.eye(4) 52 | # Get clip boundary from config or pointcloud. 53 | # Get inner clip bound to crop from. 54 | 55 | # Transform pointcloud coordinate to voxel coordinate. 56 | # 1. Random rotation 57 | rot_mat = np.eye(3) 58 | if self.use_augmentation and self.rotation_augmentation_bound is not None: 59 | if isinstance(self.rotation_augmentation_bound, collections.Iterable): 60 | rot_mats = [] 61 | for axis_ind, rot_bound in enumerate(self.rotation_augmentation_bound): 62 | theta = 0 63 | axis = np.zeros(3) 64 | axis[axis_ind] = 1 65 | if rot_bound is not None: 66 | theta = np.random.uniform(*rot_bound) 67 | rot_mats.append(M(axis, theta)) 68 | # Use random order 69 | np.random.shuffle(rot_mats) 70 | rot_mat = rot_mats[0] @ rot_mats[1] @ rot_mats[2] 71 | else: 72 | raise ValueError() 73 | rotation_matrix[:3, :3] = rot_mat 74 | # 2. Scale and translate to the voxel space. 75 | scale = 1 / self.voxel_size 76 | if self.use_augmentation and self.scale_augmentation_bound is not None: 77 | scale *= np.random.uniform(*self.scale_augmentation_bound) 78 | np.fill_diagonal(voxelization_matrix[:3, :3], scale) 79 | # Get final transformation matrix. 
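# Downstream, voxelize() composes these two matrices as M_r @ M_v (voxel scaling applied first, then rotation)
# when use_augmentation is on, and uses M_v alone otherwise, before flooring the transformed coordinates.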
80 | return voxelization_matrix, rotation_matrix 81 | 82 | def clip(self, coords, center=None, trans_aug_ratio=None): 83 | bound_min = np.min(coords, 0).astype(float) 84 | bound_max = np.max(coords, 0).astype(float) 85 | bound_size = bound_max - bound_min 86 | if center is None: 87 | center = bound_min + bound_size * 0.5 88 | if trans_aug_ratio is not None: 89 | trans = np.multiply(trans_aug_ratio, bound_size) 90 | center += trans 91 | lim = self.clip_bound 92 | 93 | if isinstance(self.clip_bound, (int, float)): 94 | if bound_size.max() < self.clip_bound: 95 | return None 96 | else: 97 | clip_inds = ((coords[:, 0] >= (-lim + center[0])) & \ 98 | (coords[:, 0] < (lim + center[0])) & \ 99 | (coords[:, 1] >= (-lim + center[1])) & \ 100 | (coords[:, 1] < (lim + center[1])) & \ 101 | (coords[:, 2] >= (-lim + center[2])) & \ 102 | (coords[:, 2] < (lim + center[2]))) 103 | return clip_inds 104 | 105 | # Clip points outside the limit 106 | clip_inds = ((coords[:, 0] >= (lim[0][0] + center[0])) & \ 107 | (coords[:, 0] < (lim[0][1] + center[0])) & \ 108 | (coords[:, 1] >= (lim[1][0] + center[1])) & \ 109 | (coords[:, 1] < (lim[1][1] + center[1])) & \ 110 | (coords[:, 2] >= (lim[2][0] + center[2])) & \ 111 | (coords[:, 2] < (lim[2][1] + center[2]))) 112 | return clip_inds 113 | 114 | def voxelize(self, coords, feats, labels, center=None): 115 | assert coords.shape[1] == 3 and coords.shape[0] == feats.shape[0] and coords.shape[0] 116 | if self.clip_bound is not None: 117 | trans_aug_ratio = np.zeros(3) 118 | if self.use_augmentation and self.translation_augmentation_ratio_bound is not None: 119 | for axis_ind, trans_ratio_bound in enumerate(self.translation_augmentation_ratio_bound): 120 | trans_aug_ratio[axis_ind] = np.random.uniform(*trans_ratio_bound) 121 | 122 | clip_inds = self.clip(coords, center, trans_aug_ratio) 123 | if clip_inds is not None: 124 | coords, feats = coords[clip_inds], feats[clip_inds] 125 | if labels is not None: 126 | labels = labels[clip_inds] 127 | 128 | # Get rotation and scale 129 | M_v, M_r = self.get_transformation_matrix() 130 | # Apply transformations 131 | rigid_transformation = M_v 132 | if self.use_augmentation: 133 | rigid_transformation = M_r @ rigid_transformation 134 | 135 | homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype))) 136 | coords_aug = np.floor(homo_coords @ rigid_transformation.T[:, :3]) 137 | 138 | # key = self.hash(coords_aug) # floor happens by astype(np.uint64) 139 | coords_aug, feats, labels = ME.utils.sparse_quantize( 140 | coords_aug, feats, labels=labels, ignore_label=self.ignore_label) 141 | 142 | return coords_aug, feats, labels, rigid_transformation.flatten() 143 | 144 | def voxelize_temporal(self, 145 | coords_t, 146 | feats_t, 147 | labels_t, 148 | centers=None, 149 | return_transformation=False): 150 | # Legacy code, remove 151 | if centers is None: 152 | centers = [ 153 | None, 154 | ] * len(coords_t) 155 | coords_tc, feats_tc, labels_tc, transformation_tc = [], [], [], [] 156 | 157 | # ######################### Data Augmentation ############################# 158 | # Get rotation and scale 159 | M_v, M_r = self.get_transformation_matrix() 160 | # Apply transformations 161 | rigid_transformation = M_v 162 | if self.use_augmentation: 163 | rigid_transformation = M_r @ rigid_transformation 164 | # ######################### Voxelization ############################# 165 | # Voxelize coords 166 | for coords, feats, labels, center in zip(coords_t, feats_t, labels_t, centers): 167 | 168 | 
################################### 169 | # Clip the data if bound exists 170 | if self.clip_bound is not None: 171 | trans_aug_ratio = np.zeros(3) 172 | if self.use_augmentation and self.translation_augmentation_ratio_bound is not None: 173 | for axis_ind, trans_ratio_bound in enumerate(self.translation_augmentation_ratio_bound): 174 | trans_aug_ratio[axis_ind] = np.random.uniform(*trans_ratio_bound) 175 | 176 | clip_inds = self.clip(coords, center, trans_aug_ratio) 177 | if clip_inds is not None: 178 | coords, feats = coords[clip_inds], feats[clip_inds] 179 | if labels is not None: 180 | labels = labels[clip_inds] 181 | ################################### 182 | 183 | homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype))) 184 | coords_aug = np.floor(homo_coords @ rigid_transformation.T)[:, :3] 185 | 186 | coords_aug, feats, labels = ME.utils.sparse_quantize( 187 | coords_aug, feats, labels=labels, ignore_label=self.ignore_label) 188 | 189 | coords_tc.append(coords_aug) 190 | feats_tc.append(feats) 191 | labels_tc.append(labels) 192 | transformation_tc.append(rigid_transformation.flatten()) 193 | 194 | return_args = [coords_tc, feats_tc, labels_tc] 195 | if return_transformation: 196 | return_args.append(transformation_tc) 197 | 198 | return tuple(return_args) 199 | 200 | 201 | def test(): 202 | N = 16575 203 | coords = np.random.rand(N, 3) * 10 204 | feats = np.random.rand(N, 4) 205 | labels = np.floor(np.random.rand(N) * 3) 206 | coords[:3] = 0 207 | labels[:3] = 2 208 | voxelizer = Voxelizer() 209 | print(voxelizer.voxelize(coords, feats, labels)) 210 | 211 | 212 | if __name__ == '__main__': 213 | test() 214 | -------------------------------------------------------------------------------- /imgs/method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/DepthContrast/b8257890c94f7c58aeb5cefeb91af031692611d6/imgs/method.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
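# Illustrative single-node launch (cfg is a YAML config such as those under configs/):
#   python main.py configs/point_within_format.yaml --multiprocessing-distributed --ngpus 8 --world-size 1 --rank 0
# The scripts/ directory provides wrapper scripts for single- and multi-node pretraining.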
6 | # 7 | 8 | import argparse 9 | import os 10 | import random 11 | import time 12 | import warnings 13 | import yaml 14 | 15 | import torch 16 | import torch.nn.parallel 17 | import torch.backends.cudnn as cudnn 18 | import torch.optim 19 | 20 | from torch.multiprocessing import Pool, Process, set_start_method 21 | try: 22 | set_start_method('spawn') 23 | except RuntimeError: 24 | pass 25 | 26 | import torch.multiprocessing as mp 27 | 28 | import utils.logger 29 | from utils import main_utils 30 | 31 | parser = argparse.ArgumentParser(description='PyTorch Self Supervised Training in 3D') 32 | 33 | parser.add_argument('cfg', help='model directory') 34 | parser.add_argument('--quiet', action='store_true') 35 | 36 | parser.add_argument('--world-size', default=-1, type=int, 37 | help='number of nodes for distributed training') 38 | parser.add_argument('--rank', default=-1, type=int, 39 | help='node rank for distributed training') 40 | parser.add_argument('--dist-url', default='tcp://localhost:15475', type=str, 41 | help='url used to set up distributed training') 42 | parser.add_argument('--dist-backend', default='nccl', type=str, 43 | help='distributed backend') 44 | parser.add_argument('--seed', default=None, type=int, 45 | help='seed for initializing training. ') 46 | parser.add_argument('--gpu', default=0, type=int, 47 | help='GPU id to use.') 48 | parser.add_argument('--ngpus', default=8, type=int, 49 | help='number of GPUs to use.') 50 | parser.add_argument('--multiprocessing-distributed', action='store_true', 51 | help='Use multi-processing distributed training to launch ' 52 | 'N processes per node, which has N GPUs. This is the ' 53 | 'fastest way to use PyTorch for either single node or ' 54 | 'multi node data parallel training') 55 | 56 | 57 | def main(): 58 | args = parser.parse_args() 59 | cfg = yaml.safe_load(open(args.cfg)) 60 | 61 | if args.seed is not None: 62 | random.seed(args.seed) 63 | torch.manual_seed(args.seed) 64 | cudnn.deterministic = True 65 | warnings.warn('You have chosen to seed training. ' 66 | 'This will turn on the CUDNN deterministic setting, ' 67 | 'which can slow down your training considerably! 
' 68 | 'You may see unexpected behavior when restarting ' 69 | 'from checkpoints.') 70 | 71 | ngpus_per_node = args.ngpus 72 | if args.multiprocessing_distributed: 73 | args.world_size = ngpus_per_node * args.world_size 74 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args, cfg)) 75 | else: 76 | # Simply call main_worker function 77 | main_worker(args.gpu, ngpus_per_node, args, cfg) 78 | 79 | def main_worker(gpu, ngpus, args, cfg): 80 | args.gpu = gpu 81 | ngpus_per_node = ngpus 82 | 83 | # Setup environment 84 | args = main_utils.initialize_distributed_backend(args, ngpus_per_node) ### Use other method instead 85 | logger, tb_writter, model_dir = main_utils.prep_environment(args, cfg) 86 | 87 | # Define model 88 | model = main_utils.build_model(cfg['model'], logger) 89 | model, args = main_utils.distribute_model_to_cuda(model, args) 90 | 91 | # Define dataloaders 92 | train_loader = main_utils.build_dataloaders(cfg['dataset'], cfg['num_workers'], args.multiprocessing_distributed, logger) 93 | 94 | # Define criterion 95 | train_criterion = main_utils.build_criterion(cfg['loss'], logger=logger) 96 | train_criterion = train_criterion.cuda() 97 | 98 | # Define optimizer 99 | optimizer, scheduler = main_utils.build_optimizer( 100 | params=list(model.parameters())+list(train_criterion.parameters()), 101 | cfg=cfg['optimizer'], 102 | logger=logger) 103 | ckp_manager = main_utils.CheckpointManager(model_dir, rank=args.rank, dist=args.multiprocessing_distributed) 104 | 105 | # Optionally resume from a checkpoint 106 | start_epoch, end_epoch = 0, cfg['optimizer']['num_epochs'] 107 | if cfg['resume']: 108 | if ckp_manager.checkpoint_exists(last=True): 109 | start_epoch = ckp_manager.restore(restore_last=True, model=model, optimizer=optimizer, train_criterion=train_criterion) 110 | scheduler.step(start_epoch) 111 | logger.add_line("Checkpoint loaded: '{}' (epoch {})".format(ckp_manager.last_checkpoint_fn(), start_epoch)) 112 | else: 113 | logger.add_line("No checkpoint found at '{}'".format(ckp_manager.last_checkpoint_fn())) 114 | 115 | cudnn.benchmark = True 116 | 117 | ############################ TRAIN ######################################### 118 | test_freq = cfg['test_freq'] if 'test_freq' in cfg else 1 119 | for epoch in range(start_epoch, end_epoch): 120 | if (epoch % 10) == 0: 121 | ckp_manager.save(epoch, model=model, train_criterion=train_criterion, optimizer=optimizer, filename='checkpoint-ep{}.pth.tar'.format(epoch)) 122 | 123 | if args.multiprocessing_distributed: 124 | train_loader.sampler.set_epoch(epoch) 125 | 126 | # Train for one epoch 127 | logger.add_line('='*30 + ' Epoch {} '.format(epoch) + '='*30) 128 | logger.add_line('LR: {}'.format(scheduler.get_lr())) 129 | run_phase('train', train_loader, model, optimizer, train_criterion, epoch, args, cfg, logger, tb_writter) 130 | scheduler.step(epoch) 131 | 132 | if ((epoch % test_freq) == 0) or (epoch == end_epoch - 1): 133 | ckp_manager.save(epoch+1, model=model, optimizer=optimizer, train_criterion=train_criterion) 134 | 135 | 136 | def run_phase(phase, loader, model, optimizer, criterion, epoch, args, cfg, logger, tb_writter): 137 | from utils import metrics_utils 138 | logger.add_line('\n{}: Epoch {}'.format(phase, epoch)) 139 | batch_time = metrics_utils.AverageMeter('Time', ':6.3f', window_size=100) 140 | data_time = metrics_utils.AverageMeter('Data', ':6.3f', window_size=100) 141 | loss_meter = metrics_utils.AverageMeter('Loss', ':.3e') 142 | loss_meter_npid1 = metrics_utils.AverageMeter('Loss_npid1', 
':.3e') 143 | loss_meter_npid2 = metrics_utils.AverageMeter('Loss_npid2', ':.3e') 144 | loss_meter_cmc1 = metrics_utils.AverageMeter('Loss_cmc1', ':.3e') 145 | loss_meter_cmc2 = metrics_utils.AverageMeter('Loss_cmc2', ':.3e') 146 | progress = utils.logger.ProgressMeter(len(loader), [batch_time, data_time, loss_meter, loss_meter_npid1, loss_meter_npid2, loss_meter_cmc1, loss_meter_cmc2], phase=phase, epoch=epoch, logger=logger, tb_writter=tb_writter) 147 | 148 | # switch to train mode 149 | model.train(phase == 'train') 150 | 151 | end = time.time() 152 | device = args.gpu if args.gpu is not None else 0 153 | for i, sample in enumerate(loader): 154 | # measure data loading time 155 | data_time.update(time.time() - end) 156 | 157 | if phase == 'train': 158 | embedding = model(sample) 159 | else: 160 | with torch.no_grad(): 161 | embedding = model(sample) 162 | 163 | # compute loss 164 | loss, loss_debug = criterion(embedding) 165 | loss_meter.update(loss.item(), embedding[0].size(0)) 166 | loss_meter_npid1.update(loss_debug[0].item(), embedding[0].size(0)) 167 | loss_meter_npid2.update(loss_debug[1].item(), embedding[0].size(0)) 168 | loss_meter_cmc1.update(loss_debug[2].item(), embedding[0].size(0)) 169 | loss_meter_cmc2.update(loss_debug[3].item(), embedding[0].size(0)) 170 | 171 | # compute gradient and do SGD step during training 172 | if phase == 'train': 173 | optimizer.zero_grad() 174 | loss.backward() 175 | optimizer.step() 176 | 177 | # measure elapsed time 178 | batch_time.update(time.time() - end) 179 | end = time.time() 180 | 181 | # print to terminal and tensorboard 182 | step = epoch * len(loader) + i 183 | if (i+1) % cfg['print_freq'] == 0 or i == 0 or i+1 == len(loader): 184 | progress.display(i+1) 185 | 186 | # Sync metrics across all GPUs and print final averages 187 | if args.multiprocessing_distributed: 188 | progress.synchronize_meters(args.gpu) 189 | progress.display(len(loader)*args.world_size) 190 | 191 | if tb_writter is not None: 192 | for meter in progress.meters: 193 | tb_writter.add_scalar('{}-epoch/{}'.format(phase, meter.name), meter.avg, epoch) 194 | 195 | 196 | if __name__ == '__main__': 197 | main() 198 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from models.base_ssl3d_model import BaseSSLMultiInputOutputModel 9 | 10 | def build_model(model_config, logger): 11 | return BaseSSLMultiInputOutputModel(model_config, logger) 12 | 13 | 14 | __all__ = ["BaseSSLMultiInputOutputModel", "build_model"] 15 | -------------------------------------------------------------------------------- /models/trunks/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree.
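# The backbone imports below are wrapped in try/except on purpose: MinkowskiEngine, spconv and OpenPCDet are
# optional dependencies, so any trunk whose dependency is missing simply resolves to None in the TRUNKS registry.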
7 | # 8 | try: 9 | from models.trunks.pointnet import PointNet 10 | except: 11 | PointNet = None 12 | 13 | try: 14 | from models.trunks.spconv.models.res16unet import Res16UNet34 15 | except: 16 | Res16UNet34 = None 17 | 18 | try: 19 | from models.trunks.pointnet2_backbone import PointNet2MSG 20 | from models.trunks.spconv_unet import UNetV2_concat as UNetV2 21 | except: 22 | PointNet2MSG = None 23 | UNetV2 = None 24 | 25 | TRUNKS = { 26 | "pointnet": PointNet, 27 | "unet": Res16UNet34, 28 | "pointnet_msg": PointNet2MSG, 29 | "UNetV2": UNetV2, 30 | } 31 | 32 | 33 | __all__ = ["TRUNKS"] 34 | -------------------------------------------------------------------------------- /models/trunks/mlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import torch.nn as nn 10 | 11 | class MLP(nn.Module): 12 | def __init__( 13 | self, dims, use_bn=False, use_relu=True, use_dropout=False, use_bias=True 14 | ): 15 | super().__init__() 16 | layers = [] 17 | last_dim = dims[0] 18 | counter = 1 19 | for dim in dims[1:]: 20 | layers.append(nn.Linear(last_dim, dim, bias=use_bias)) 21 | counter += 1 22 | if use_bn: 23 | layers.append( 24 | nn.BatchNorm1d( 25 | dim, 26 | eps=1e-5, 27 | momentum=0.1, 28 | ) 29 | ) 30 | if (counter < len(dims)) and use_relu: 31 | layers.append(nn.ReLU(inplace=True)) 32 | last_dim = dim 33 | if use_dropout: 34 | layers.append(nn.Dropout()) 35 | self.clf = nn.Sequential(*layers) 36 | 37 | def forward(self, batch): 38 | out = self.clf(batch) 39 | return out 40 | -------------------------------------------------------------------------------- /models/trunks/pointnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import sys 11 | import os 12 | 13 | from models.trunks.mlp import MLP 14 | 15 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | ROOT_DIR = os.path.dirname(ROOT_DIR) 17 | ROOT_DIR = os.path.dirname(ROOT_DIR) 18 | sys.path.append(os.path.join(ROOT_DIR, 'third_party', 'pointnet2')) 19 | 20 | from pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule 21 | 22 | class PointNet(nn.Module): 23 | r""" 24 | Backbone network for point cloud feature learning. 25 | Based on Pointnet++ single-scale grouping network. 26 | 27 | Parameters 28 | ---------- 29 | input_feature_dim: int 30 | Number of input channels in the feature descriptor for each point. 31 | e.g. 3 for RGB. 
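    Illustrative usage (the mlp_dim values are placeholders, not the released training config):
        net = PointNet(input_feature_dim=0, use_mlp=True, mlp_dim=[512, 512, 128]).cuda()
        feats = net(points)  # points: (B, N, 3) CUDA float tensor
        # feats is a list with one pooled (and optionally MLP-projected) vector per entry in out_feat_keys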
32 | """ 33 | def __init__(self, input_feature_dim=0, scale=1, use_mlp=False, mlp_dim=None): 34 | super().__init__() 35 | 36 | self.use_mlp = use_mlp 37 | self.sa1 = PointnetSAModuleVotes( 38 | npoint=2048, 39 | radius=0.2, 40 | nsample=64, 41 | mlp=[input_feature_dim, 64*scale, 64*scale, 128*scale], 42 | use_xyz=True, 43 | normalize_xyz=True 44 | ) 45 | 46 | self.sa2 = PointnetSAModuleVotes( 47 | npoint=1024, 48 | radius=0.4, 49 | nsample=32, 50 | mlp=[128*scale, 128*scale, 128*scale, 256*scale], 51 | use_xyz=True, 52 | normalize_xyz=True 53 | ) 54 | 55 | self.sa3 = PointnetSAModuleVotes( 56 | npoint=512, 57 | radius=0.8, 58 | nsample=16, 59 | mlp=[256*scale, 128*scale, 128*scale, 256*scale], 60 | use_xyz=True, 61 | normalize_xyz=True 62 | ) 63 | 64 | self.sa4 = PointnetSAModuleVotes( 65 | npoint=256, 66 | radius=1.2, 67 | nsample=16, 68 | mlp=[256*scale, 128*scale, 128*scale, 256*scale], 69 | use_xyz=True, 70 | normalize_xyz=True 71 | ) 72 | 73 | if scale == 1: 74 | self.fp1 = PointnetFPModule(mlp=[256+256,512,512]) 75 | self.fp2 = PointnetFPModule(mlp=[512+256,512,512]) 76 | else: 77 | self.fp1 = PointnetFPModule(mlp=[256*scale+256*scale,256*scale,256*scale]) 78 | self.fp2 = PointnetFPModule(mlp=[256*scale+256*scale,256*scale,256*scale]) 79 | 80 | if use_mlp: 81 | self.head = MLP(mlp_dim) 82 | 83 | self.all_feat_names = [ 84 | "sa1", 85 | "sa2", 86 | "sa3", 87 | "sa4", 88 | "fp1", 89 | "fp2", 90 | ] 91 | 92 | def _break_up_pc(self, pc): 93 | xyz = pc[..., 0:3].contiguous() 94 | features = ( 95 | pc[..., 3:].transpose(1, 2).contiguous() 96 | if pc.size(-1) > 3 else None 97 | ) 98 | 99 | return xyz, features 100 | 101 | def forward(self, pointcloud: torch.cuda.FloatTensor, out_feat_keys=["fp2"]): 102 | r""" 103 | Forward pass of the network 104 | 105 | Parameters 106 | ---------- 107 | pointcloud: Variable(torch.cuda.FloatTensor) 108 | (B, N, 3 + input_feature_dim) tensor 109 | Point cloud to run predicts on 110 | Each point in the point-cloud MUST 111 | be formated as (x, y, z, features...) 
112 | 113 | Returns 114 | ---------- 115 | end_points: {XXX_xyz, XXX_features, XXX_inds} 116 | XXX_xyz: float32 Tensor of shape (B,K,3) 117 | XXX_features: float32 Tensor of shape (B,K,D) 118 | XXX-inds: int64 Tensor of shape (B,K) values in [0,N-1] 119 | """ 120 | batch_size = pointcloud.shape[0] 121 | 122 | xyz, features = self._break_up_pc(pointcloud) 123 | 124 | features = None ### Do not use other info for now 125 | 126 | end_points = {} 127 | 128 | # --------- 4 SET ABSTRACTION LAYERS --------- 129 | xyz, features, fps_inds = self.sa1(xyz, features) 130 | end_points['sa1_inds'] = fps_inds 131 | end_points['sa1_xyz'] = xyz 132 | end_points['sa1_features'] = features 133 | 134 | xyz, features, fps_inds = self.sa2(xyz, features) # this fps_inds is just 0,1,...,1023 135 | end_points['sa2_inds'] = fps_inds 136 | end_points['sa2_xyz'] = xyz 137 | end_points['sa2_features'] = features 138 | 139 | xyz, features, fps_inds = self.sa3(xyz, features) # this fps_inds is just 0,1,...,511 140 | end_points['sa3_xyz'] = xyz 141 | end_points['sa3_features'] = features 142 | 143 | xyz, features, fps_inds = self.sa4(xyz, features) # this fps_inds is just 0,1,...,255 144 | end_points['sa4_xyz'] = xyz 145 | end_points['sa4_features'] = features 146 | 147 | # --------- 2 FEATURE UPSAMPLING LAYERS -------- 148 | features = self.fp1(end_points['sa3_xyz'], end_points['sa4_xyz'], end_points['sa3_features'], end_points['sa4_features']) 149 | end_points['fp1_features'] = features 150 | features = self.fp2(end_points['sa2_xyz'], end_points['sa3_xyz'], end_points['sa2_features'], features) 151 | end_points['fp2_features'] = features 152 | end_points['fp2_xyz'] = end_points['sa2_xyz'] 153 | num_seed = end_points['fp2_xyz'].shape[1] 154 | end_points['fp2_inds'] = end_points['sa1_inds'][:,0:num_seed] # indices among the entire input point clouds 155 | 156 | out_feats = [None] * len(out_feat_keys) 157 | 158 | final_feat = [] 159 | for key in out_feat_keys: 160 | feat = end_points[key+"_features"] 161 | org_feat = end_points[key+"_features"] 162 | nump = feat.shape[-1] 163 | feat = torch.squeeze(F.max_pool1d(feat, nump)) 164 | ### Apply the head here 165 | if self.use_mlp: 166 | feat = self.head(feat) 167 | out_feats[out_feat_keys.index(key)] = feat 168 | 169 | return out_feats 170 | 171 | 172 | if __name__=='__main__': 173 | backbone_net = PointNet(input_feature_dim=0).cuda() # smoke test for the PointNet class defined above 174 | print(backbone_net) 175 | backbone_net.eval() 176 | out = backbone_net(torch.rand(16,20000,3).cuda()) 177 | for feat in out: # forward() returns a list with one pooled feature per requested key 178 | print(feat.shape) 179 | -------------------------------------------------------------------------------- /models/trunks/pointnet2_backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
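# PointNet++ multi-scale-grouping encoder/decoder built on the OpenPCDet pointnet2 ops; the SA_CONFIG and
# FP_MLPS tables below fix the architecture, and forward() max-pools the per-point 'fp2' feature map into one
# vector per scene, optionally projecting it with the MLP head when use_mlp is set.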
6 | # 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import os 12 | import sys 13 | 14 | from models.trunks.mlp import MLP 15 | 16 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 17 | ROOT_DIR = os.path.dirname(ROOT_DIR) 18 | ROOT_DIR = os.path.dirname(ROOT_DIR) 19 | sys.path.append(os.path.join(ROOT_DIR, 'third_party', 'OpenPCDet', "pcdet")) 20 | 21 | from ops.pointnet2.pointnet2_batch import pointnet2_modules 22 | 23 | class PointNet2MSG(nn.Module): 24 | def __init__(self, use_mlp=False, mlp_dim=None): 25 | super().__init__() 26 | 27 | input_channels = 4 28 | 29 | self.SA_modules = nn.ModuleList() 30 | channel_in = input_channels - 3 31 | 32 | self.num_points_each_layer = [] 33 | skip_channel_list = [input_channels - 3] 34 | SA_CONFIG = {'NPOINTS': [4096, 1024, 256, 64], 'RADIUS': [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]], 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32]], 'MLPS': [[[16, 16, 32], [32, 32, 64]], [[64, 64, 128], [64, 96, 128]], [[128, 196, 256], [128, 196, 256]], [[256, 256, 512], [256, 384, 512]]]} 35 | 36 | FP_MLPS = [[128, 128], [256, 256], [512, 512], [512, 512]] 37 | 38 | for k in range(SA_CONFIG["NPOINTS"].__len__()): 39 | mlps = SA_CONFIG["MLPS"][k].copy() 40 | channel_out = 0 41 | for idx in range(mlps.__len__()): 42 | mlps[idx] = [channel_in] + mlps[idx] 43 | channel_out += mlps[idx][-1] 44 | 45 | self.SA_modules.append( 46 | pointnet2_modules.PointnetSAModuleMSG( 47 | npoint=SA_CONFIG["NPOINTS"][k], 48 | radii=SA_CONFIG["RADIUS"][k], 49 | nsamples=SA_CONFIG["NSAMPLE"][k], 50 | mlps=mlps, 51 | use_xyz=True, 52 | ) 53 | ) 54 | skip_channel_list.append(channel_out) 55 | channel_in = channel_out 56 | 57 | self.FP_modules = nn.ModuleList() 58 | 59 | for k in range(FP_MLPS.__len__()): 60 | pre_channel = FP_MLPS[k + 1][-1] if k + 1 < len(FP_MLPS) else channel_out 61 | self.FP_modules.append( 62 | pointnet2_modules.PointnetFPModule( 63 | mlp=[pre_channel + skip_channel_list[k]] + FP_MLPS[k] 64 | ) 65 | ) 66 | 67 | self.num_point_features = FP_MLPS[0][-1] 68 | 69 | self.all_feat_names = [ 70 | "fp2", 71 | ] 72 | 73 | self.use_mlp = use_mlp # always set, so forward() can test it even when no MLP head is attached 74 | if use_mlp: 75 | self.head = MLP(mlp_dim) 76 | 77 | 78 | def break_up_pc(self, pc): 79 | #batch_idx = pc[:, 0] 80 | xyz = pc[:, :, 0:3].contiguous() 81 | features = (pc[:, :, 3:].contiguous() if pc.size(-1) > 3 else None) 82 | return xyz, features 83 | 84 | def forward(self, pointcloud: torch.cuda.FloatTensor, out_feat_keys=None): 85 | """ 86 | Args: 87 | batch_dict: 88 | batch_size: int 89 | vfe_features: (num_voxels, C) 90 | points: (num_points, 4 + C), [batch_idx, x, y, z, ...]
91 | Returns: 92 | batch_dict: 93 | encoded_spconv_tensor: sparse tensor 94 | point_features: (N, C) 95 | """ 96 | batch_size = pointcloud.shape[0] 97 | points = pointcloud 98 | xyz, features = self.break_up_pc(points) 99 | 100 | features = features.view(batch_size, -1, features.shape[-1]).permute(0, 2, 1).contiguous() if features is not None else None 101 | l_xyz, l_features = [xyz], [features] 102 | for i in range(len(self.SA_modules)): 103 | assert l_features[i].is_contiguous() 104 | li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) 105 | assert li_features.is_contiguous() 106 | l_xyz.append(li_xyz) 107 | l_features.append(li_features) 108 | 109 | for i in range(-1, -(len(self.FP_modules) + 1), -1): 110 | assert l_features[i - 1].is_contiguous() 111 | assert l_features[i].is_contiguous() 112 | l_features[i - 1] = self.FP_modules[i]( 113 | l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i] 114 | ) # (B, C, N) 115 | assert l_features[i - 1].is_contiguous() 116 | 117 | point_features = l_features[0] 118 | 119 | end_points = {} 120 | end_points['fp2_features'] = point_features 121 | 122 | out_feats = [None] * len(out_feat_keys) 123 | for key in out_feat_keys: 124 | feat = end_points[key+"_features"] 125 | nump = feat.shape[-1] 126 | 127 | feat = torch.squeeze(F.max_pool1d(feat, nump)) 128 | if self.use_mlp: 129 | feat = self.head(feat) 130 | out_feats[out_feat_keys.index(key)] = feat 131 | 132 | return out_feats 133 | -------------------------------------------------------------------------------- /models/trunks/smlp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import torch.nn as nn 10 | import MinkowskiEngine as ME 11 | 12 | class SMLP(nn.Module): 13 | def __init__( 14 | self, dims, use_bn=False, use_relu=False, use_dropout=False, use_bias=True 15 | ): 16 | super().__init__() 17 | layers = [] 18 | last_dim = dims[0] 19 | counter = 1 20 | for dim in dims[1:]: 21 | layers.append(ME.MinkowskiLinear(last_dim, dim, bias=use_bias)) 22 | counter += 1 23 | if use_bn: 24 | layers.append( 25 | ME.MinkowskiBatchNorm( 26 | dim, 27 | eps=1e-5, 28 | momentum=0.1, 29 | ) 30 | ) 31 | if (counter < len(dims)) and use_relu: 32 | layers.append(ME.MinkowskiReLU(inplace=True)) 33 | last_dim = dim 34 | if use_dropout: 35 | layers.append(ME.MinkowskiDropout()) 36 | self.clf = nn.Sequential(*layers) 37 | 38 | def forward(self, batch): 39 | out = self.clf(batch) 40 | return out 41 | -------------------------------------------------------------------------------- /models/trunks/spconv/lib/math_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from scipy.sparse import csr_matrix 7 | import torch 8 | 9 | 10 | class SparseMM(torch.autograd.Function): 11 | """ 12 | Sparse x dense matrix multiplication with autograd support.
13 | Implementation by Soumith Chintala: 14 | https://discuss.pytorch.org/t/ 15 | does-pytorch-support-autograd-on-sparse-matrix/6156/7 16 | """ 17 | 18 | def forward(self, matrix1, matrix2): 19 | self.save_for_backward(matrix1, matrix2) 20 | return torch.mm(matrix1, matrix2) 21 | 22 | def backward(self, grad_output): 23 | matrix1, matrix2 = self.saved_tensors 24 | grad_matrix1 = grad_matrix2 = None 25 | 26 | if self.needs_input_grad[0]: 27 | grad_matrix1 = torch.mm(grad_output, matrix2.t()) 28 | 29 | if self.needs_input_grad[1]: 30 | grad_matrix2 = torch.mm(matrix1.t(), grad_output) 31 | 32 | return grad_matrix1, grad_matrix2 33 | 34 | 35 | def sparse_float_tensor(values, indices, size=None): 36 | """ 37 | Return a torch sparse matrix give values and indices (row_ind, col_ind). 38 | If the size is an integer, return a square matrix with side size. 39 | If the size is a torch.Size, use it to initialize the out tensor. 40 | If none, the size is inferred. 41 | """ 42 | indices = torch.stack(indices).int() 43 | sargs = [indices, values.float()] 44 | if size is not None: 45 | # Use the provided size 46 | if isinstance(size, int): 47 | size = torch.Size((size, size)) 48 | sargs.append(size) 49 | if values.is_cuda: 50 | return torch.cuda.sparse.FloatTensor(*sargs) 51 | else: 52 | return torch.sparse.FloatTensor(*sargs) 53 | 54 | 55 | def diags(values, size=None): 56 | values = values.view(-1) 57 | n = values.nelement() 58 | size = torch.Size((n, n)) 59 | indices = (torch.arange(0, n), torch.arange(0, n)) 60 | return sparse_float_tensor(values, indices, size) 61 | 62 | 63 | def sparse_to_csr_matrix(tensor): 64 | tensor = tensor.cpu() 65 | inds = tensor._indices().numpy() 66 | vals = tensor._values().numpy() 67 | return csr_matrix((vals, (inds[0], inds[1])), shape=[s for s in tensor.shape]) 68 | 69 | 70 | def csr_matrix_to_sparse(mat): 71 | row_ind, col_ind = mat.nonzero() 72 | return sparse_float_tensor( 73 | torch.from_numpy(mat.data), 74 | (torch.from_numpy(row_ind), torch.from_numpy(col_ind)), 75 | size=torch.Size(mat.shape)) 76 | -------------------------------------------------------------------------------- /models/trunks/spconv/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from models.trunks.spconv.models import resunet 7 | from models.trunks.spconv.models import res16unet 8 | 9 | # from models.trilateral_crf import TrilateralCRF 10 | from models.trunks.spconv.models.conditional_random_fields import BilateralCRF, TrilateralCRF 11 | 12 | MODELS = [] 13 | 14 | 15 | def add_models(module): 16 | MODELS.extend([getattr(module, a) for a in dir(module) if 'Net' in a]) 17 | 18 | 19 | add_models(resunet) 20 | add_models(res16unet) 21 | 22 | WRAPPERS = [BilateralCRF, TrilateralCRF] 23 | 24 | 25 | def get_models(): 26 | '''Returns a tuple of sample models.''' 27 | return MODELS 28 | 29 | 30 | def get_wrappers(): 31 | return WRAPPERS 32 | 33 | 34 | def load_model(name): 35 | '''Creates and returns an instance of the model given its class name. 36 | ''' 37 | # Find the model class from its name 38 | all_models = get_models() 39 | mdict = {model.__name__: model for model in all_models} 40 | if name not in mdict: 41 | print('Invalid model index. 
Options are:') 42 | # Display a list of valid model names 43 | for model in all_models: 44 | print('\t* {}'.format(model.__name__)) 45 | return None 46 | NetClass = mdict[name] 47 | 48 | return NetClass 49 | 50 | 51 | def load_wrapper(name): 52 | '''Creates and returns an instance of the model given its class name. 53 | ''' 54 | # Find the model class from its name 55 | all_wrappers = get_wrappers() 56 | mdict = {wrapper.__name__: wrapper for wrapper in all_wrappers} 57 | if name not in mdict: 58 | print('Invalid wrapper index. Options are:') 59 | # Display a list of valid model names 60 | for wrapper in all_wrappers: 61 | print('\t* {}'.format(wrapper.__name__)) 62 | return None 63 | WrapperClass = mdict[name] 64 | 65 | return WrapperClass 66 | -------------------------------------------------------------------------------- /models/trunks/spconv/models/conditional_random_fields.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | from MinkowskiEngine import SparseTensor, MinkowskiConvolution, MinkowskiConvolutionFunction, convert_to_int_tensor 11 | from MinkowskiEngine import convert_region_type as me_convert_region_type 12 | 13 | from models.trunks.spconv.models.model import HighDimensionalModel 14 | from models.trunks.spconv.models.wrapper import Wrapper 15 | from models.trunks.spconv.lib.math_functions import SparseMM 16 | from models.trunks.spconv.models.modules.common import convert_region_type 17 | 18 | 19 | class MeanField(HighDimensionalModel): 20 | """ 21 | Abstract class for the bilateral and trilateral meanfield 22 | """ 23 | OUT_PIXEL_DIST = 1 24 | 25 | # To use the model, must call initialize_coords before forward pass. 
26 | # Once data is processed, call clear to reset the model before calling 27 | # initialize_coords 28 | def __init__(self, nchannels, spatial_sigma, chromatic_sigma, meanfield_iterations, is_temporal, 29 | config, **kwargs): 30 | D = 7 if is_temporal else 6 31 | self.is_temporal = is_temporal 32 | # Setup metadata 33 | super(MeanField, self).__init__(nchannels, nchannels, config, D=D) 34 | 35 | self.spatial_sigma = spatial_sigma 36 | self.chromatic_sigma = chromatic_sigma 37 | # temporal sigma is 1 38 | self.meanfield_iterations = meanfield_iterations 39 | 40 | self.pixel_dist = 1 41 | self.stride = 1 42 | self.dilation = 1 43 | 44 | conv = MinkowskiConvolution( 45 | nchannels, 46 | nchannels, 47 | kernel_size=config.wrapper_kernel_size, 48 | has_bias=False, 49 | region_type=convert_region_type(config.wrapper_region_type), 50 | dimension=D) 51 | 52 | # Create a region_offset 53 | self.region_type_, self.region_offset_, _ = me_convert_region_type( 54 | conv.region_type, 1, conv.kernel_size, conv.up_stride, conv.dilation, conv.region_offset, 55 | conv.axis_types, conv.dimension) 56 | 57 | # Check whether the mapping is required 58 | self.requires_mapping = False 59 | self.conv = conv 60 | self.kernel = conv.kernel 61 | self.convs = {} 62 | self.softmaxes = {} 63 | for i in range(self.meanfield_iterations): 64 | self.softmaxes[i] = nn.Softmax(dim=1) 65 | self.convs[i] = MinkowskiConvolutionFunction() 66 | 67 | def initialize_coords(self, model, in_coords, in_color): 68 | if torch.prod(convert_to_int_tensor(model.OUT_PIXEL_DIST, model.D)) != 1: 69 | self.requires_mapping = True 70 | 71 | out_coords = model.get_coords(model.OUT_PIXEL_DIST) 72 | out_color = model.permute_feature(in_color, model.OUT_PIXEL_DIST).int() 73 | 74 | # Tri/Bi-lateral grid 75 | out_tri_coords = torch.cat( 76 | [ 77 | (torch.floor(out_coords[:, :3].float() / self.spatial_sigma)).int(), 78 | (torch.floor(out_color.float() / self.chromatic_sigma)).int(), 79 | out_coords[:, 3:] # (time and) batch 80 | ], 81 | dim=1) 82 | orig_tri_coords = torch.cat( 83 | [ 84 | (torch.floor(in_coords[:, :3].float() / self.spatial_sigma)).int(), 85 | (torch.floor(in_color.float() / self.chromatic_sigma)).int(), 86 | in_coords[:, 3:] # (time and) batch 87 | ], 88 | dim=1) 89 | 90 | crf_tri_coords = torch.cat((out_tri_coords, orig_tri_coords), dim=0) 91 | 92 | # Create a trilateral Grid 93 | # super(MeanField, self).initialize_coords_with_duplicates(crf_tri_coords) 94 | 95 | # Create Sparse matrix mappings to/from the CRF coords 96 | in_cols = self.get_index_map(out_tri_coords, 1) 97 | self.in_mapping = torch.sparse.FloatTensor( 98 | torch.stack((in_cols.long(), torch.arange(in_cols.size(0), out=torch.LongTensor()))), 99 | torch.ones(in_cols.size(0)), torch.Size((self.n_rows, in_cols.size(0)))) 100 | 101 | out_cols = self.get_index_map(orig_tri_coords, 1) 102 | self.out_mapping = torch.sparse.FloatTensor( 103 | torch.stack((torch.arange(out_cols.size(0), out=torch.LongTensor()), out_cols.long())), 104 | torch.ones(out_cols.size(0)), torch.Size((out_cols.size(0), self.n_rows))) 105 | 106 | if self.config.is_cuda: 107 | self.in_mapping, self.out_mapping = self.in_mapping.cuda(), self.out_mapping.cuda() 108 | 109 | else: 110 | self.requires_mapping = False 111 | 112 | out_coords = in_coords 113 | out_color = in_color 114 | crf_tri_coords = torch.cat( 115 | [ 116 | (torch.floor(in_coords[:, :3].float() / self.spatial_sigma)).int(), 117 | (torch.floor(in_color.float() / self.chromatic_sigma)).int(), 118 | in_coords[:, 3:], # (time and) batch 119 | 
], 120 | dim=1) 121 | 122 | return crf_tri_coords 123 | 124 | def forward(self, x): 125 | xf = x.F 126 | if self.requires_mapping: 127 | # Map the network output to CRF input 128 | xf = SparseMM()(Variable(self.in_mapping), xf) 129 | 130 | out = xf 131 | for i in range(self.meanfield_iterations): # Meanfield iteration 132 | # Normalization 133 | out = self.softmaxes[i](out) 134 | # Pairwise potential 135 | out = self.convs[i].apply(out, self.conv.kernel, x.pixel_dist, self.conv.stride, 136 | self.conv.kernel_size, self.conv.dilation, self.region_type_, 137 | self.region_offset_, x.coords_key, x.coords_key, x.coords_man) 138 | # Add unary 139 | out += xf 140 | 141 | if self.requires_mapping: 142 | # Map the CRF output to the origianl space 143 | out = SparseMM()(Variable(self.out_mapping), out) 144 | 145 | return SparseTensor(out, coords_key=x.coords_key, coords_manager=x.coords_man) 146 | 147 | 148 | class BilateralCRF(Wrapper): 149 | OUT_PIXEL_DIST = 1 150 | 151 | def initialize_filter(self, NetClass, in_nchannel, out_nchannel, config): 152 | self.model = NetClass(in_nchannel, out_nchannel, config) 153 | self.filter = MeanField( 154 | out_nchannel, 155 | spatial_sigma=config.crf_spatial_sigma, 156 | chromatic_sigma=config.crf_chromatic_sigma, 157 | meanfield_iterations=config.meanfield_iterations, 158 | is_temporal=False, 159 | config=config) 160 | 161 | 162 | class TrilateralCRF(Wrapper): 163 | OUT_PIXEL_DIST = 1 164 | 165 | def initialize_filter(self, NetClass, in_nchannel, out_nchannel, config): 166 | self.model = NetClass(in_nchannel, out_nchannel, config) 167 | self.filter = MeanField( 168 | out_nchannel, 169 | spatial_sigma=config.crf_spatial_sigma, 170 | chromatic_sigma=config.crf_chromatic_sigma, 171 | meanfield_iterations=config.meanfield_iterations, 172 | is_temporal=True, 173 | config=config) 174 | -------------------------------------------------------------------------------- /models/trunks/spconv/models/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from MinkowskiEngine import MinkowskiNetwork 7 | 8 | 9 | class Model(MinkowskiNetwork): 10 | """ 11 | Base network for all sparse convnet 12 | 13 | By default, all networks are segmentation networks. 14 | """ 15 | OUT_PIXEL_DIST = -1 16 | 17 | def __init__(self, in_channels, out_channels, D, **kwargs): 18 | super(Model, self).__init__(D) 19 | self.in_channels = in_channels 20 | self.out_channels = out_channels 21 | #self.config = config 22 | 23 | 24 | class HighDimensionalModel(Model): 25 | """ 26 | Base network for all spatio (temporal) chromatic sparse convnet 27 | """ 28 | 29 | def __init__(self, in_channels, out_channels, config, D, **kwargs): 30 | assert D > 4, "Num dimension smaller than 5" 31 | super(HighDimensionalModel, self).__init__(in_channels, out_channels, config, D, **kwargs) 32 | -------------------------------------------------------------------------------- /models/trunks/spconv/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /models/trunks/spconv/models/modules/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import collections 8 | from enum import Enum 9 | import torch.nn as nn 10 | 11 | import MinkowskiEngine as ME 12 | 13 | 14 | class NormType(Enum): 15 | BATCH_NORM = 0 16 | INSTANCE_NORM = 1 17 | INSTANCE_BATCH_NORM = 2 18 | 19 | 20 | def get_norm(norm_type, n_channels, D, bn_momentum=0.1): 21 | if norm_type == NormType.BATCH_NORM: 22 | return ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum) 23 | elif norm_type == NormType.INSTANCE_NORM: 24 | return ME.MinkowskiInstanceNorm(n_channels) 25 | elif norm_type == NormType.INSTANCE_BATCH_NORM: 26 | return nn.Sequential( 27 | ME.MinkowskiInstanceNorm(n_channels), 28 | ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum)) 29 | else: 30 | raise ValueError(f'Norm type: {norm_type} not supported') 31 | 32 | 33 | class ConvType(Enum): 34 | """ 35 | Define the kernel region type 36 | """ 37 | HYPERCUBE = 0, 'HYPERCUBE' 38 | SPATIAL_HYPERCUBE = 1, 'SPATIAL_HYPERCUBE' 39 | SPATIO_TEMPORAL_HYPERCUBE = 2, 'SPATIO_TEMPORAL_HYPERCUBE' 40 | HYPERCROSS = 3, 'HYPERCROSS' 41 | SPATIAL_HYPERCROSS = 4, 'SPATIAL_HYPERCROSS' 42 | SPATIO_TEMPORAL_HYPERCROSS = 5, 'SPATIO_TEMPORAL_HYPERCROSS' 43 | SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS = 6, 'SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS ' 44 | 45 | def __new__(cls, value, name): 46 | member = object.__new__(cls) 47 | member._value_ = value 48 | member.fullname = name 49 | return member 50 | 51 | def __int__(self): 52 | return self.value 53 | 54 | 55 | # Covert the ConvType var to a RegionType var 56 | conv_to_region_type = { 57 | # kernel_size = [k, k, k, 1] 58 | ConvType.HYPERCUBE: ME.RegionType.HYPERCUBE, 59 | ConvType.SPATIAL_HYPERCUBE: ME.RegionType.HYPERCUBE, 60 | ConvType.SPATIO_TEMPORAL_HYPERCUBE: ME.RegionType.HYPERCUBE, 61 | ConvType.HYPERCROSS: ME.RegionType.HYPERCROSS, 62 | ConvType.SPATIAL_HYPERCROSS: ME.RegionType.HYPERCROSS, 63 | ConvType.SPATIO_TEMPORAL_HYPERCROSS: ME.RegionType.HYPERCROSS, 64 | ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: ME.RegionType.HYBRID 65 | } 66 | 67 | int_to_region_type = {m.value: m for m in ME.RegionType} 68 | 69 | 70 | def convert_region_type(region_type): 71 | """ 72 | Convert the integer region_type to the corresponding RegionType enum object. 
73 | """ 74 | return int_to_region_type[region_type] 75 | 76 | 77 | def convert_conv_type(conv_type, kernel_size, D): 78 | assert isinstance(conv_type, ConvType), "conv_type must be of ConvType" 79 | region_type = conv_to_region_type[conv_type] 80 | axis_types = None 81 | if conv_type == ConvType.SPATIAL_HYPERCUBE: 82 | # No temporal convolution 83 | if isinstance(kernel_size, collections.Sequence): 84 | kernel_size = kernel_size[:3] 85 | else: 86 | kernel_size = [ 87 | kernel_size, 88 | ] * 3 89 | if D == 4: 90 | kernel_size.append(1) 91 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCUBE: 92 | # conv_type conversion already handled 93 | assert D == 4 94 | elif conv_type == ConvType.HYPERCUBE: 95 | # conv_type conversion already handled 96 | pass 97 | elif conv_type == ConvType.SPATIAL_HYPERCROSS: 98 | if isinstance(kernel_size, collections.Sequence): 99 | kernel_size = kernel_size[:3] 100 | else: 101 | kernel_size = [ 102 | kernel_size, 103 | ] * 3 104 | if D == 4: 105 | kernel_size.append(1) 106 | elif conv_type == ConvType.HYPERCROSS: 107 | # conv_type conversion already handled 108 | pass 109 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCROSS: 110 | # conv_type conversion already handled 111 | assert D == 4 112 | elif conv_type == ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: 113 | # Define the CUBIC conv kernel for spatial dims and CROSS conv for temp dim 114 | axis_types = [ 115 | ME.RegionType.HYPERCUBE, 116 | ] * 3 117 | if D == 4: 118 | axis_types.append(ME.RegionType.HYPERCROSS) 119 | return region_type, axis_types, kernel_size 120 | 121 | 122 | def conv(in_planes, 123 | out_planes, 124 | kernel_size, 125 | stride=1, 126 | dilation=1, 127 | bias=False, 128 | conv_type=ConvType.HYPERCUBE, 129 | D=-1): 130 | assert D > 0, 'Dimension must be a positive integer' 131 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 132 | kernel_generator = ME.KernelGenerator( 133 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D) 134 | 135 | return ME.MinkowskiConvolution( 136 | in_channels=in_planes, 137 | out_channels=out_planes, 138 | kernel_size=kernel_size, 139 | stride=stride, 140 | dilation=dilation, 141 | has_bias=bias, 142 | kernel_generator=kernel_generator, 143 | dimension=D) 144 | 145 | 146 | def conv_tr(in_planes, 147 | out_planes, 148 | kernel_size, 149 | upsample_stride=1, 150 | dilation=1, 151 | bias=False, 152 | conv_type=ConvType.HYPERCUBE, 153 | D=-1): 154 | assert D > 0, 'Dimension must be a positive integer' 155 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 156 | kernel_generator = ME.KernelGenerator( 157 | kernel_size, 158 | upsample_stride, 159 | dilation, 160 | region_type=region_type, 161 | axis_types=axis_types, 162 | dimension=D) 163 | 164 | return ME.MinkowskiConvolutionTranspose( 165 | in_channels=in_planes, 166 | out_channels=out_planes, 167 | kernel_size=kernel_size, 168 | stride=upsample_stride, 169 | dilation=dilation, 170 | has_bias=bias, 171 | kernel_generator=kernel_generator, 172 | dimension=D) 173 | 174 | 175 | def avg_pool(kernel_size, 176 | stride=1, 177 | dilation=1, 178 | conv_type=ConvType.HYPERCUBE, 179 | in_coords_key=None, 180 | D=-1): 181 | assert D > 0, 'Dimension must be a positive integer' 182 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 183 | kernel_generator = ME.KernelGenerator( 184 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D) 185 | 
186 | return ME.MinkowskiAvgPooling( 187 | kernel_size=kernel_size, 188 | stride=stride, 189 | dilation=dilation, 190 | kernel_generator=kernel_generator, 191 | dimension=D) 192 | 193 | 194 | def avg_unpool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): 195 | assert D > 0, 'Dimension must be a positive integer' 196 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 197 | kernel_generator = ME.KernelGenerator( 198 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D) 199 | 200 | return ME.MinkowskiAvgUnpooling( 201 | kernel_size=kernel_size, 202 | stride=stride, 203 | dilation=dilation, 204 | kernel_generator=kernel_generator, 205 | dimension=D) 206 | 207 | 208 | def sum_pool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): 209 | assert D > 0, 'Dimension must be a positive integer' 210 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 211 | kernel_generator = ME.KernelGenerator( 212 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D) 213 | 214 | return ME.MinkowskiSumPooling( 215 | kernel_size=kernel_size, 216 | stride=stride, 217 | dilation=dilation, 218 | kernel_generator=kernel_generator, 219 | dimension=D) 220 | -------------------------------------------------------------------------------- /models/trunks/spconv/models/modules/resnet_block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import torch.nn as nn 7 | 8 | from models.trunks.spconv.models.modules.common import ConvType, NormType, get_norm, conv 9 | 10 | from MinkowskiEngine import MinkowskiReLU 11 | 12 | 13 | class BasicBlockBase(nn.Module): 14 | expansion = 1 15 | NORM_TYPE = NormType.BATCH_NORM 16 | 17 | def __init__(self, 18 | inplanes, 19 | planes, 20 | stride=1, 21 | dilation=1, 22 | downsample=None, 23 | conv_type=ConvType.HYPERCUBE, 24 | bn_momentum=0.1, 25 | D=3): 26 | super(BasicBlockBase, self).__init__() 27 | 28 | self.conv1 = conv( 29 | inplanes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D) 30 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 31 | self.conv2 = conv( 32 | planes, 33 | planes, 34 | kernel_size=3, 35 | stride=1, 36 | dilation=dilation, 37 | bias=False, 38 | conv_type=conv_type, 39 | D=D) 40 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 41 | self.relu = MinkowskiReLU(inplace=True) 42 | self.downsample = downsample 43 | 44 | def forward(self, x): 45 | residual = x 46 | 47 | out = self.conv1(x) 48 | out = self.norm1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.norm2(out) 53 | 54 | if self.downsample is not None: 55 | residual = self.downsample(x) 56 | 57 | out += residual 58 | out = self.relu(out) 59 | 60 | return out 61 | 62 | 63 | class BasicBlock(BasicBlockBase): 64 | NORM_TYPE = NormType.BATCH_NORM 65 | 66 | 67 | class BasicBlockIN(BasicBlockBase): 68 | NORM_TYPE = NormType.INSTANCE_NORM 69 | 70 | 71 | class BasicBlockINBN(BasicBlockBase): 72 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM 73 | 74 | 75 | class BottleneckBase(nn.Module): 76 | expansion = 4 77 | NORM_TYPE = NormType.BATCH_NORM 78 | 79 | def __init__(self, 80 | 
inplanes, 81 | planes, 82 | stride=1, 83 | dilation=1, 84 | downsample=None, 85 | conv_type=ConvType.HYPERCUBE, 86 | bn_momentum=0.1, 87 | D=3): 88 | super(BottleneckBase, self).__init__() 89 | self.conv1 = conv(inplanes, planes, kernel_size=1, D=D) 90 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 91 | 92 | self.conv2 = conv( 93 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D) 94 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 95 | 96 | self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, D=D) 97 | self.norm3 = get_norm(self.NORM_TYPE, planes * self.expansion, D, bn_momentum=bn_momentum) 98 | 99 | self.relu = MinkowskiReLU(inplace=True) 100 | self.downsample = downsample 101 | 102 | def forward(self, x): 103 | residual = x 104 | 105 | out = self.conv1(x) 106 | out = self.norm1(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv2(out) 110 | out = self.norm2(out) 111 | out = self.relu(out) 112 | 113 | out = self.conv3(out) 114 | out = self.norm3(out) 115 | 116 | if self.downsample is not None: 117 | residual = self.downsample(x) 118 | 119 | out += residual 120 | out = self.relu(out) 121 | 122 | return out 123 | 124 | 125 | class Bottleneck(BottleneckBase): 126 | NORM_TYPE = NormType.BATCH_NORM 127 | 128 | 129 | class BottleneckIN(BottleneckBase): 130 | NORM_TYPE = NormType.INSTANCE_NORM 131 | 132 | 133 | class BottleneckINBN(BottleneckBase): 134 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM 135 | -------------------------------------------------------------------------------- /models/trunks/spconv/models/modules/senet_block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
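# Squeeze-and-Excitation variants of the sparse ResNet blocks: SELayer global-pools a sparse tensor, gates its
# channels through a reduction-16 bottleneck, and broadcast-multiplies the gate back onto the input features.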
6 | import torch.nn as nn 7 | 8 | import MinkowskiEngine as ME 9 | 10 | from models.trunks.spconv.models.modules.common import ConvType, NormType 11 | from models.trunks.spconv.models.modules.resnet_block import BasicBlock, Bottleneck 12 | 13 | 14 | class SELayer(nn.Module): 15 | 16 | def __init__(self, channel, reduction=16, D=-1): 17 | # Global coords does not require coords_key 18 | super(SELayer, self).__init__() 19 | self.fc = nn.Sequential( 20 | ME.MinkowskiLinear(channel, channel // reduction), ME.MinkowskiReLU(inplace=True), 21 | ME.MinkowskiLinear(channel // reduction, channel), ME.MinkowskiSigmoid()) 22 | self.pooling = ME.MinkowskiGlobalPooling(dimension=D) 23 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication(dimension=D) 24 | 25 | def forward(self, x): 26 | y = self.pooling(x) 27 | y = self.fc(y) 28 | return self.broadcast_mul(x, y) 29 | 30 | 31 | class SEBasicBlock(BasicBlock): 32 | 33 | def __init__(self, 34 | inplanes, 35 | planes, 36 | stride=1, 37 | dilation=1, 38 | downsample=None, 39 | conv_type=ConvType.HYPERCUBE, 40 | reduction=16, 41 | D=-1): 42 | super(SEBasicBlock, self).__init__( 43 | inplanes, 44 | planes, 45 | stride=stride, 46 | dilation=dilation, 47 | downsample=downsample, 48 | conv_type=conv_type, 49 | D=D) 50 | self.se = SELayer(planes, reduction=reduction, D=D) 51 | 52 | def forward(self, x): 53 | residual = x 54 | 55 | out = self.conv1(x) 56 | out = self.norm1(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv2(out) 60 | out = self.norm2(out) 61 | out = self.se(out) 62 | 63 | if self.downsample is not None: 64 | residual = self.downsample(x) 65 | 66 | out += residual 67 | out = self.relu(out) 68 | 69 | return out 70 | 71 | 72 | class SEBasicBlockSN(SEBasicBlock): 73 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM 74 | 75 | 76 | class SEBasicBlockIN(SEBasicBlock): 77 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 78 | 79 | 80 | class SEBasicBlockLN(SEBasicBlock): 81 | NORM_TYPE = NormType.SPARSE_LAYER_NORM 82 | 83 | 84 | class SEBottleneck(Bottleneck): 85 | 86 | def __init__(self, 87 | inplanes, 88 | planes, 89 | stride=1, 90 | dilation=1, 91 | downsample=None, 92 | conv_type=ConvType.HYPERCUBE, 93 | D=3, 94 | reduction=16): 95 | super(SEBottleneck, self).__init__( 96 | inplanes, 97 | planes, 98 | stride=stride, 99 | dilation=dilation, 100 | downsample=downsample, 101 | conv_type=conv_type, 102 | D=D) 103 | self.se = SELayer(planes * self.expansion, reduction=reduction, D=D) 104 | 105 | def forward(self, x): 106 | residual = x 107 | 108 | out = self.conv1(x) 109 | out = self.norm1(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv2(out) 113 | out = self.norm2(out) 114 | out = self.relu(out) 115 | 116 | out = self.conv3(out) 117 | out = self.norm3(out) 118 | out = self.se(out) 119 | 120 | if self.downsample is not None: 121 | residual = self.downsample(x) 122 | 123 | out += residual 124 | out = self.relu(out) 125 | 126 | return out 127 | 128 | 129 | class SEBottleneckSN(SEBottleneck): 130 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM 131 | 132 | 133 | class SEBottleneckIN(SEBottleneck): 134 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM 135 | 136 | 137 | class SEBottleneckLN(SEBottleneck): 138 | NORM_TYPE = NormType.SPARSE_LAYER_NORM 139 | -------------------------------------------------------------------------------- /models/trunks/spconv/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
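# ResNetBase below is the sparse / spatio-temporal ResNet trunk: a stem conv
# (conv1_kernel_size) + batch norm + ReLU + stride-2 sum_pool, four residual
# stages built by _make_layer (each entered with stride 2), and a final 1x1
# convolution to out_channels, for an overall output stride of OUT_PIXEL_DIST = 32.
# The ResNet14/18/34/50/101 variants change only BLOCK and LAYERS, and the ST*
# variants switch to 4-D convolutions with a hypercross temporal kernel.
# Note: network_initialization reads dilations, bn_momentum and conv1_kernel_size
# from a name `config` that is not among __init__'s parameters in this snapshot
# (STResNetBase also passes a config argument positionally), so a config object
# would need to be threaded through before these classes could be instantiated.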
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import torch.nn as nn 7 | 8 | import MinkowskiEngine as ME 9 | 10 | from models.trunks.spconv.models.model import Model 11 | from models.trunks.spconv.models.modules.common import ConvType, NormType, get_norm, conv, sum_pool 12 | from models.trunks.spconv.models.modules.resnet_block import BasicBlock, Bottleneck 13 | 14 | 15 | class ResNetBase(Model): 16 | BLOCK = None 17 | LAYERS = () 18 | INIT_DIM = 64 19 | PLANES = (64, 128, 256, 512) 20 | OUT_PIXEL_DIST = 32 21 | HAS_LAST_BLOCK = False 22 | CONV_TYPE = ConvType.HYPERCUBE 23 | 24 | def __init__(self, in_channels=1, out_channels=1, D=3, **kwargs): 25 | assert self.BLOCK is not None 26 | assert self.OUT_PIXEL_DIST > 0 27 | 28 | super(ResNetBase, self).__init__(in_channels, out_channels, D, **kwargs) 29 | 30 | self.network_initialization(in_channels, out_channels, D) 31 | self.weight_initialization() 32 | 33 | def network_initialization(self, in_channels, out_channels, D): 34 | 35 | def space_n_time_m(n, m): 36 | return n if D == 3 else [n, n, n, m] 37 | 38 | if D == 4: 39 | self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1) 40 | 41 | dilations = config.dilations 42 | bn_momentum = config.bn_momentum 43 | self.inplanes = self.INIT_DIM 44 | self.conv1 = conv( 45 | in_channels, 46 | self.inplanes, 47 | kernel_size=space_n_time_m(config.conv1_kernel_size, 1), 48 | stride=1, 49 | D=D) 50 | 51 | self.bn1 = get_norm(NormType.BATCH_NORM, self.inplanes, D=self.D, bn_momentum=bn_momentum) 52 | self.relu = ME.MinkowskiReLU(inplace=True) 53 | self.pool = sum_pool(kernel_size=space_n_time_m(2, 1), stride=space_n_time_m(2, 1), D=D) 54 | 55 | self.layer1 = self._make_layer( 56 | self.BLOCK, 57 | self.PLANES[0], 58 | self.LAYERS[0], 59 | stride=space_n_time_m(2, 1), 60 | dilation=space_n_time_m(dilations[0], 1)) 61 | self.layer2 = self._make_layer( 62 | self.BLOCK, 63 | self.PLANES[1], 64 | self.LAYERS[1], 65 | stride=space_n_time_m(2, 1), 66 | dilation=space_n_time_m(dilations[1], 1)) 67 | self.layer3 = self._make_layer( 68 | self.BLOCK, 69 | self.PLANES[2], 70 | self.LAYERS[2], 71 | stride=space_n_time_m(2, 1), 72 | dilation=space_n_time_m(dilations[2], 1)) 73 | self.layer4 = self._make_layer( 74 | self.BLOCK, 75 | self.PLANES[3], 76 | self.LAYERS[3], 77 | stride=space_n_time_m(2, 1), 78 | dilation=space_n_time_m(dilations[3], 1)) 79 | 80 | self.final = conv( 81 | self.PLANES[3] * self.BLOCK.expansion, out_channels, kernel_size=1, bias=True, D=D) 82 | 83 | def weight_initialization(self): 84 | for m in self.modules(): 85 | if isinstance(m, ME.MinkowskiBatchNorm): 86 | nn.init.constant_(m.bn.weight, 1) 87 | nn.init.constant_(m.bn.bias, 0) 88 | 89 | def _make_layer(self, 90 | block, 91 | planes, 92 | blocks, 93 | stride=1, 94 | dilation=1, 95 | norm_type=NormType.BATCH_NORM, 96 | bn_momentum=0.1): 97 | downsample = None 98 | if stride != 1 or self.inplanes != planes * block.expansion: 99 | downsample = nn.Sequential( 100 | conv( 101 | self.inplanes, 102 | planes * block.expansion, 103 | kernel_size=1, 104 | stride=stride, 105 | bias=False, 106 | D=self.D), 107 | get_norm(norm_type, planes * block.expansion, D=self.D, bn_momentum=bn_momentum), 108 | ) 109 | layers = [] 110 | layers.append( 111 | block( 112 | self.inplanes, 113 | planes, 114 | stride=stride, 115 | dilation=dilation, 116 | downsample=downsample, 117 | conv_type=self.CONV_TYPE, 118 | D=self.D)) 119 | self.inplanes = planes * block.expansion 120 | 
for i in range(1, blocks): 121 | layers.append( 122 | block( 123 | self.inplanes, 124 | planes, 125 | stride=1, 126 | dilation=dilation, 127 | conv_type=self.CONV_TYPE, 128 | D=self.D)) 129 | 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): 133 | x = self.conv1(x) 134 | x = self.bn1(x) 135 | x = self.relu(x) 136 | x = self.pool(x) 137 | 138 | x = self.layer1(x) 139 | x = self.layer2(x) 140 | x = self.layer3(x) 141 | x = self.layer4(x) 142 | 143 | x = self.final(x) 144 | return x 145 | 146 | 147 | class ResNet14(ResNetBase): 148 | BLOCK = BasicBlock 149 | LAYERS = (1, 1, 1, 1) 150 | 151 | 152 | class ResNet18(ResNetBase): 153 | BLOCK = BasicBlock 154 | LAYERS = (2, 2, 2, 2) 155 | 156 | 157 | class ResNet34(ResNetBase): 158 | BLOCK = BasicBlock 159 | LAYERS = (3, 4, 6, 3) 160 | 161 | 162 | class ResNet50(ResNetBase): 163 | BLOCK = Bottleneck 164 | LAYERS = (3, 4, 6, 3) 165 | 166 | 167 | class ResNet101(ResNetBase): 168 | BLOCK = Bottleneck 169 | LAYERS = (3, 4, 23, 3) 170 | 171 | 172 | class STResNetBase(ResNetBase): 173 | 174 | CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS 175 | 176 | def __init__(self, in_channels, out_channels, config, D=4, **kwargs): 177 | super(STResNetBase, self).__init__(in_channels, out_channels, config, D, **kwargs) 178 | 179 | 180 | class STResNet14(STResNetBase, ResNet14): 181 | pass 182 | 183 | 184 | class STResNet18(STResNetBase, ResNet18): 185 | pass 186 | 187 | 188 | class STResNet34(STResNetBase, ResNet34): 189 | pass 190 | 191 | 192 | class STResNet50(STResNetBase, ResNet50): 193 | pass 194 | 195 | 196 | class STResNet101(STResNetBase, ResNet101): 197 | pass 198 | 199 | 200 | class STResTesseractNetBase(STResNetBase): 201 | CONV_TYPE = ConvType.HYPERCUBE 202 | 203 | 204 | class STResTesseractNet14(STResTesseractNetBase, STResNet14): 205 | pass 206 | 207 | 208 | class STResTesseractNet18(STResTesseractNetBase, STResNet18): 209 | pass 210 | 211 | 212 | class STResTesseractNet34(STResTesseractNetBase, STResNet34): 213 | pass 214 | 215 | 216 | class STResTesseractNet50(STResTesseractNetBase, STResNet50): 217 | pass 218 | 219 | 220 | class STResTesseractNet101(STResTesseractNetBase, STResNet101): 221 | pass 222 | -------------------------------------------------------------------------------- /models/trunks/spconv/models/wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import random 7 | from torch.nn import Module 8 | 9 | from MinkowskiEngine import SparseTensor 10 | 11 | 12 | class Wrapper(Module): 13 | """ 14 | Wrapper for the segmentation networks. 
15 | """ 16 | OUT_PIXEL_DIST = -1 17 | 18 | def __init__(self, NetClass, in_nchannel, out_nchannel, config): 19 | super(Wrapper, self).__init__() 20 | self.initialize_filter(NetClass, in_nchannel, out_nchannel, config) 21 | 22 | def initialize_filter(self, NetClass, in_nchannel, out_nchannel, config): 23 | raise NotImplementedError('Must initialize a model and a filter') 24 | 25 | def forward(self, x, coords, colors=None): 26 | soutput = self.model(x) 27 | 28 | # During training, make the network invariant to the filter 29 | if not self.training or random.random() < 0.5: 30 | # Filter requires the model to finish the forward pass 31 | wrapper_coords = self.filter.initialize_coords(self.model, coords, colors) 32 | finput = SparseTensor(soutput.F, wrapper_coords) 33 | soutput = self.filter(finput) 34 | 35 | return soutput 36 | -------------------------------------------------------------------------------- /models/trunks/spconv_backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from functools import partial 8 | 9 | import spconv 10 | import torch.nn as nn 11 | 12 | 13 | def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0, 14 | conv_type='subm', norm_fn=None): 15 | 16 | if conv_type == 'subm': 17 | conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, bias=False, indice_key=indice_key) 18 | elif conv_type == 'spconv': 19 | conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, 20 | bias=False, indice_key=indice_key) 21 | elif conv_type == 'inverseconv': 22 | conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, indice_key=indice_key, bias=False) 23 | else: 24 | raise NotImplementedError 25 | 26 | m = spconv.SparseSequential( 27 | conv, 28 | norm_fn(out_channels), 29 | nn.ReLU(), 30 | ) 31 | 32 | return m 33 | 34 | 35 | class SparseBasicBlock(spconv.SparseModule): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride=1, norm_fn=None, downsample=None, indice_key=None): 39 | super(SparseBasicBlock, self).__init__() 40 | 41 | assert norm_fn is not None 42 | bias = norm_fn is not None 43 | self.conv1 = spconv.SubMConv3d( 44 | inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key 45 | ) 46 | self.bn1 = norm_fn(planes) 47 | self.relu = nn.ReLU() 48 | self.conv2 = spconv.SubMConv3d( 49 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key 50 | ) 51 | self.bn2 = norm_fn(planes) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | identity = x 57 | 58 | out = self.conv1(x) 59 | out.features = self.bn1(out.features) 60 | out.features = self.relu(out.features) 61 | 62 | out = self.conv2(out) 63 | out.features = self.bn2(out.features) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out.features += identity.features 69 | out.features = self.relu(out.features) 70 | 71 | return out 72 | 73 | 74 | class VoxelBackBone8x(nn.Module): 75 | def __init__(self, model_cfg, input_channels, grid_size, **kwargs): 76 | super().__init__() 77 | self.model_cfg = model_cfg 78 | norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01) 79 | 80 | self.sparse_shape = 
grid_size[::-1] + [1, 0, 0] 81 | 82 | self.conv_input = spconv.SparseSequential( 83 | spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'), 84 | norm_fn(16), 85 | nn.ReLU(), 86 | ) 87 | block = post_act_block 88 | 89 | self.conv1 = spconv.SparseSequential( 90 | block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1'), 91 | ) 92 | 93 | self.conv2 = spconv.SparseSequential( 94 | # [1600, 1408, 41] <- [800, 704, 21] 95 | block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'), 96 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), 97 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), 98 | ) 99 | 100 | self.conv3 = spconv.SparseSequential( 101 | # [800, 704, 21] <- [400, 352, 11] 102 | block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'), 103 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), 104 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), 105 | ) 106 | 107 | self.conv4 = spconv.SparseSequential( 108 | # [400, 352, 11] <- [200, 176, 5] 109 | block(64, 64, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'), 110 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), 111 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), 112 | ) 113 | 114 | last_pad = 0 115 | last_pad = self.model_cfg.get('last_pad', last_pad) 116 | self.conv_out = spconv.SparseSequential( 117 | # [200, 150, 5] -> [200, 150, 2] 118 | spconv.SparseConv3d(64, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad, 119 | bias=False, indice_key='spconv_down2'), 120 | norm_fn(128), 121 | nn.ReLU(), 122 | ) 123 | self.num_point_features = 128 124 | 125 | def forward(self, batch_dict): 126 | """ 127 | Args: 128 | batch_dict: 129 | batch_size: int 130 | vfe_features: (num_voxels, C) 131 | voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx] 132 | Returns: 133 | batch_dict: 134 | encoded_spconv_tensor: sparse tensor 135 | """ 136 | voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords'] 137 | batch_size = batch_dict['batch_size'] 138 | input_sp_tensor = spconv.SparseConvTensor( 139 | features=voxel_features, 140 | indices=voxel_coords.int(), 141 | spatial_shape=self.sparse_shape, 142 | batch_size=batch_size 143 | ) 144 | 145 | x = self.conv_input(input_sp_tensor) 146 | 147 | x_conv1 = self.conv1(x) 148 | x_conv2 = self.conv2(x_conv1) 149 | x_conv3 = self.conv3(x_conv2) 150 | x_conv4 = self.conv4(x_conv3) 151 | 152 | # for detection head 153 | # [200, 176, 5] -> [200, 176, 2] 154 | out = self.conv_out(x_conv4) 155 | 156 | batch_dict.update({ 157 | 'encoded_spconv_tensor': out, 158 | 'encoded_spconv_tensor_stride': 8 159 | }) 160 | batch_dict.update({ 161 | 'multi_scale_3d_features': { 162 | 'x_conv1': x_conv1, 163 | 'x_conv2': x_conv2, 164 | 'x_conv3': x_conv3, 165 | 'x_conv4': x_conv4, 166 | } 167 | }) 168 | 169 | return batch_dict 170 | 171 | 172 | class VoxelResBackBone8x(nn.Module): 173 | def __init__(self, model_cfg, input_channels, grid_size, **kwargs): 174 | super().__init__() 175 | self.model_cfg = model_cfg 176 | norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01) 177 | 178 | self.sparse_shape = grid_size[::-1] + [1, 0, 0] 179 | 180 | self.conv_input = spconv.SparseSequential( 181 | spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'), 182 | norm_fn(16), 183 | nn.ReLU(), 184 | ) 185 | block = 
post_act_block 186 | 187 | self.conv1 = spconv.SparseSequential( 188 | SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='res1'), 189 | SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='res1'), 190 | ) 191 | 192 | self.conv2 = spconv.SparseSequential( 193 | # [1600, 1408, 41] <- [800, 704, 21] 194 | block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'), 195 | SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res2'), 196 | SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res2'), 197 | ) 198 | 199 | self.conv3 = spconv.SparseSequential( 200 | # [800, 704, 21] <- [400, 352, 11] 201 | block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'), 202 | SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res3'), 203 | SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res3'), 204 | ) 205 | 206 | self.conv4 = spconv.SparseSequential( 207 | # [400, 352, 11] <- [200, 176, 5] 208 | block(64, 128, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'), 209 | SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res4'), 210 | SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res4'), 211 | ) 212 | 213 | last_pad = 0 214 | last_pad = self.model_cfg.get('last_pad', last_pad) 215 | self.conv_out = spconv.SparseSequential( 216 | # [200, 150, 5] -> [200, 150, 2] 217 | spconv.SparseConv3d(128, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad, 218 | bias=False, indice_key='spconv_down2'), 219 | norm_fn(128), 220 | nn.ReLU(), 221 | ) 222 | self.num_point_features = 128 223 | 224 | def forward(self, batch_dict): 225 | """ 226 | Args: 227 | batch_dict: 228 | batch_size: int 229 | vfe_features: (num_voxels, C) 230 | voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx] 231 | Returns: 232 | batch_dict: 233 | encoded_spconv_tensor: sparse tensor 234 | """ 235 | voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords'] 236 | batch_size = batch_dict['batch_size'] 237 | input_sp_tensor = spconv.SparseConvTensor( 238 | features=voxel_features, 239 | indices=voxel_coords.int(), 240 | spatial_shape=self.sparse_shape, 241 | batch_size=batch_size 242 | ) 243 | x = self.conv_input(input_sp_tensor) 244 | 245 | x_conv1 = self.conv1(x) 246 | x_conv2 = self.conv2(x_conv1) 247 | x_conv3 = self.conv3(x_conv2) 248 | x_conv4 = self.conv4(x_conv3) 249 | 250 | # for detection head 251 | # [200, 176, 5] -> [200, 176, 2] 252 | out = self.conv_out(x_conv4) 253 | 254 | batch_dict.update({ 255 | 'encoded_spconv_tensor': out, 256 | 'encoded_spconv_tensor_stride': 8 257 | }) 258 | batch_dict.update({ 259 | 'multi_scale_3d_features': { 260 | 'x_conv1': x_conv1, 261 | 'x_conv2': x_conv2, 262 | 'x_conv3': x_conv3, 263 | 'x_conv4': x_conv4, 264 | } 265 | }) 266 | 267 | return batch_dict 268 | -------------------------------------------------------------------------------- /models/trunks/spconv_unet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
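# UNetV2_concat below is the sparse-convolution UNet trunk for voxelized lidar
# point clouds. Each voxel's input feature is the mean of its raw points
# (input_channels = 4, typically x, y, z plus intensity), an spconv encoder
# (conv_input, then conv1..conv4 with stride-2 sparse convolutions) shrinks the
# grid, and UR_block_forward decoder blocks with inverse convolutions bring it
# back to full resolution. For every sample, the decoder outputs at all four
# scales are max-pooled and concatenated into one global descriptor
# (64 + 32 + 16 + 16 = 128 dims with the channel widths used below), which the
# optional MLP head projects for the contrastive loss.
# With point_cloud_range [0, -75, -3, 75, 75, 3] and voxel_size [0.1, 0.1, 0.2]
# the grid is (75 - 0)/0.1 x (75 + 75)/0.1 x (3 + 3)/0.2 = 750 x 1500 x 30 voxels
# (x, y, z), so sparse_shape = grid_size[::-1] + [1, 0, 0] = [31, 1500, 750].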
6 | # 7 | from functools import partial 8 | 9 | import spconv 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from models.trunks.mlp import MLP 15 | 16 | from .spconv_backbone import post_act_block 17 | 18 | class SparseBasicBlock(spconv.SparseModule): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None, indice_key=None, norm_fn=None): 22 | super(SparseBasicBlock, self).__init__() 23 | self.conv1 = spconv.SubMConv3d( 24 | inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False, indice_key=indice_key 25 | ) 26 | self.bn1 = norm_fn(planes) 27 | self.relu = nn.ReLU() 28 | self.conv2 = spconv.SubMConv3d( 29 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False, indice_key=indice_key 30 | ) 31 | self.bn2 = norm_fn(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | identity = x.features 37 | 38 | assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim() 39 | 40 | out = self.conv1(x) 41 | out.features = self.bn1(out.features) 42 | out.features = self.relu(out.features) 43 | 44 | out = self.conv2(out) 45 | out.features = self.bn2(out.features) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out.features += identity 51 | out.features = self.relu(out.features) 52 | 53 | return out 54 | 55 | import numpy as np 56 | 57 | class UNetV2_concat(nn.Module): 58 | """ 59 | Sparse Convolution based UNet for point-wise feature learning. 60 | Reference Paper: https://arxiv.org/abs/1907.03670 (Shaoshuai Shi, et. al) 61 | From Points to Parts: 3D Object Detection from Point Cloud with Part-aware and Part-aggregation Network 62 | """ 63 | def __init__(self, use_mlp=False, mlp_dim=None): 64 | super().__init__() 65 | 66 | input_channels = 4 67 | voxel_size = [0.1, 0.1, 0.2] 68 | point_cloud_range = np.array([ 0. , -75. , -3. , 75.0, 75. , 3. 
], dtype=np.float32) 69 | 70 | grid_size = (point_cloud_range[3:6] - point_cloud_range[0:3]) / np.array(voxel_size) 71 | grid_size = np.round(grid_size).astype(np.int64) 72 | model_cfg = {'NAME': 'UNetV2', 'RETURN_ENCODED_TENSOR': False} 73 | 74 | self.sparse_shape = grid_size[::-1] + [1, 0, 0] 75 | self.voxel_size = voxel_size 76 | self.point_cloud_range = point_cloud_range 77 | 78 | norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01) 79 | 80 | self.conv_input = spconv.SparseSequential( 81 | spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'), 82 | norm_fn(16), 83 | nn.ReLU(), 84 | ) 85 | block = post_act_block 86 | 87 | self.conv1 = spconv.SparseSequential( 88 | block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1'), 89 | ) 90 | 91 | self.conv2 = spconv.SparseSequential( 92 | # [1600, 1408, 41] <- [800, 704, 21] 93 | block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'), 94 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), 95 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), 96 | ) 97 | 98 | self.conv3 = spconv.SparseSequential( 99 | # [800, 704, 21] <- [400, 352, 11] 100 | block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'), 101 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), 102 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), 103 | ) 104 | 105 | self.conv4 = spconv.SparseSequential( 106 | # [400, 352, 11] <- [200, 176, 5] 107 | block(64, 64, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'), 108 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), 109 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), 110 | ) 111 | 112 | self.conv_out = None 113 | 114 | # decoder 115 | # [400, 352, 11] <- [200, 176, 5] 116 | self.conv_up_t4 = SparseBasicBlock(64, 64, indice_key='subm4', norm_fn=norm_fn) 117 | self.conv_up_m4 = block(128, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4') 118 | self.inv_conv4 = block(64, 64, 3, norm_fn=norm_fn, indice_key='spconv4', conv_type='inverseconv') 119 | 120 | # [800, 704, 21] <- [400, 352, 11] 121 | self.conv_up_t3 = SparseBasicBlock(64, 64, indice_key='subm3', norm_fn=norm_fn) 122 | self.conv_up_m3 = block(128, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3') 123 | self.inv_conv3 = block(64, 32, 3, norm_fn=norm_fn, indice_key='spconv3', conv_type='inverseconv') 124 | 125 | # [1600, 1408, 41] <- [800, 704, 21] 126 | self.conv_up_t2 = SparseBasicBlock(32, 32, indice_key='subm2', norm_fn=norm_fn) 127 | self.conv_up_m2 = block(64, 32, 3, norm_fn=norm_fn, indice_key='subm2') 128 | self.inv_conv2 = block(32, 16, 3, norm_fn=norm_fn, indice_key='spconv2', conv_type='inverseconv') 129 | 130 | # [1600, 1408, 41] <- [1600, 1408, 41] 131 | self.conv_up_t1 = SparseBasicBlock(16, 16, indice_key='subm1', norm_fn=norm_fn) 132 | self.conv_up_m1 = block(32, 16, 3, norm_fn=norm_fn, indice_key='subm1') 133 | 134 | self.conv5 = spconv.SparseSequential( 135 | block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1') 136 | ) 137 | self.num_point_features = 16 138 | 139 | self.all_feat_names = [ 140 | "conv4", 141 | ] 142 | 143 | if use_mlp: 144 | self.use_mlp = True 145 | self.head = MLP(mlp_dim) 146 | 147 | def UR_block_forward(self, x_lateral, x_bottom, conv_t, conv_m, conv_inv): 148 | x_trans = conv_t(x_lateral) 149 | x = x_trans 150 | x.features = torch.cat((x_bottom.features, 
x_trans.features), dim=1) 151 | x_m = conv_m(x) 152 | x = self.channel_reduction(x, x_m.features.shape[1]) 153 | x.features = x_m.features + x.features 154 | x = conv_inv(x) 155 | return x 156 | 157 | @staticmethod 158 | def channel_reduction(x, out_channels): 159 | """ 160 | Args: 161 | x: x.features (N, C1) 162 | out_channels: C2 163 | 164 | Returns: 165 | 166 | """ 167 | features = x.features 168 | n, in_channels = features.shape 169 | assert (in_channels % out_channels == 0) and (in_channels >= out_channels) 170 | 171 | x.features = features.view(n, out_channels, -1).sum(dim=2) 172 | return x 173 | 174 | def forward(self, x, out_feat_keys=None): 175 | """ 176 | Args: 177 | batch_dict: 178 | batch_size: int 179 | vfe_features: (num_voxels, C) 180 | voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx] 181 | Returns: 182 | batch_dict: 183 | encoded_spconv_tensor: sparse tensor 184 | point_features: (N, C) 185 | """ 186 | ### Pre processing 187 | voxel_features, voxel_num_points = x['voxels'], x['voxel_num_points'] 188 | points_mean = voxel_features[:, :, :].sum(dim=1, keepdim=False) 189 | normalizer = torch.clamp_min(voxel_num_points.view(-1, 1), min=1.0).type_as(voxel_features) 190 | points_mean = points_mean / normalizer 191 | voxel_features = points_mean.contiguous() 192 | 193 | temp = x['voxel_coords'].detach().cpu().numpy() 194 | 195 | batch_size = len(np.unique(temp[:,0])) 196 | voxel_coords = x['voxel_coords'] 197 | 198 | input_sp_tensor = spconv.SparseConvTensor( 199 | features=voxel_features.float(), 200 | indices=voxel_coords.int(), 201 | spatial_shape=self.sparse_shape, 202 | batch_size=batch_size 203 | ) 204 | x = self.conv_input(input_sp_tensor) 205 | 206 | x_conv1 = self.conv1(x) 207 | x_conv2 = self.conv2(x_conv1) 208 | x_conv3 = self.conv3(x_conv2) 209 | x_conv4 = self.conv4(x_conv3) 210 | 211 | if self.conv_out is not None: 212 | # for detection head 213 | # [200, 176, 5] -> [200, 176, 2] 214 | out = self.conv_out(x_conv4) 215 | #batch_dict['encoded_spconv_tensor'] = out 216 | #batch_dict['encoded_spconv_tensor_stride'] = 8 217 | 218 | # for segmentation head 219 | # [400, 352, 11] <- [200, 176, 5] 220 | x_up4 = self.UR_block_forward(x_conv4, x_conv4, self.conv_up_t4, self.conv_up_m4, self.inv_conv4) 221 | # [800, 704, 21] <- [400, 352, 11] 222 | x_up3 = self.UR_block_forward(x_conv3, x_up4, self.conv_up_t3, self.conv_up_m3, self.inv_conv3) 223 | # [1600, 1408, 41] <- [800, 704, 21] 224 | x_up2 = self.UR_block_forward(x_conv2, x_up3, self.conv_up_t2, self.conv_up_m2, self.inv_conv2) 225 | # [1600, 1408, 41] <- [1600, 1408, 41] 226 | x_up1 = self.UR_block_forward(x_conv1, x_up2, self.conv_up_t1, self.conv_up_m1, self.conv5) 227 | 228 | end_points = {} 229 | 230 | end_points['conv4_features'] = [x_up4.features, x_up3.features, x_up2.features, x_up1.features]#.view(batch_size, -1, 64).permute(0, 2, 1).contiguous() 231 | end_points['indice'] = [x_up4.indices, x_up3.indices, x_up2.indices, x_up1.indices] 232 | 233 | out_feats = [None] * len(out_feat_keys) 234 | 235 | for key in out_feat_keys: 236 | feat = end_points[key+"_features"] 237 | 238 | featlist = [] 239 | for i in range(batch_size): 240 | tempfeat = [] 241 | for idx in range(len(end_points['indice'])): 242 | temp_idx = end_points['indice'][idx][:,0] == i 243 | temp_f = end_points['conv4_features'][idx][temp_idx].unsqueeze(0).permute(0, 2, 1).contiguous() 244 | tempfeat.append(F.max_pool1d(temp_f, temp_f.shape[-1]).squeeze(-1)) 245 | featlist.append(torch.cat(tempfeat, -1)) 246 | feat = torch.cat(featlist, 
0) 247 | if self.use_mlp: 248 | feat = self.head(feat) 249 | out_feats[out_feat_keys.index(key)] = feat ### Just use smlp 250 | return out_feats 251 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | absl-py=0.9.0=py36_0 5 | alabaster=0.7.10=py36h306e16b_0 6 | anaconda=custom=py36hbbc8b67_0 7 | anaconda-client=1.6.9=py36_0 8 | anaconda-project=0.8.2=py36h44fb852_0 9 | asn1crypto=0.24.0=py36_0 10 | astroid=1.6.1=py36_0 11 | attrs=17.4.0=py36_0 12 | babel=2.5.3=py36_0 13 | backports=1.0=py36hfa02d7e_1 14 | backports.functools_lru_cache=1.5=py36_0 15 | backports.shutil_get_terminal_size=1.0.0=py36hfea85ff_2 16 | backports.weakref=1.0rc1=py36_1 17 | beautifulsoup4=4.6.0=py36h49b8c8c_1 18 | bitarray=0.8.1=py36h14c3975_1 19 | blas=1.0=mkl 20 | bleach=3.3.0 21 | blinker=1.4=py_0 22 | boto=2.48.0=py36h6e4cd66_1 23 | boto3=1.6.7=py_0 24 | botocore=1.9.7=py_0 25 | bz2file=0.98=py36_0 26 | bzip2=1.0.6=h9a117a8_4 27 | ca-certificates=2017.08.26=h1d4fec5_0 28 | cairo=1.14.12=h77bcde2_0 29 | certifi=2019.11.28=py36_1 30 | cffi=1.11.4=py36h9745a5d_0 31 | chardet=3.0.4=py36h0f667ec_1 32 | click=6.7=py36h5253387_0 33 | cloudpickle=0.5.2=py36_1 34 | clyent=1.2.2=py36h7e57e65_1 35 | cmake=3.9.4=h142f0e9_0 36 | colorama=0.3.9=py36h489cec4_0 37 | contextlib2=0.5.5=py36h6c84a62_0 38 | cryptography=2.1.4=py36hd09be54_0 39 | cuda90=1.0=h6433d27_0 40 | cudatoolkit=10.1.243=h6bb024c_0 41 | curl=7.58.0=h84994c4_0 42 | cycler=0.10.0=py36_0 43 | cymem=1.31.2=py36_0 44 | cython=0.27.3=py36h1860423_0 45 | cytoolz=0.8.2=py36_0 46 | dask-core=0.16.1=py36_0 47 | dbus=1.10.22=h3b5a359_0 48 | decorator=4.2.1=py36_0 49 | dill=0.2.7.1=py36_0 50 | distributed=1.20.2=py36_0 51 | docutils=0.14=py36hb0f60f5_0 52 | easydict=1.9=py_0 53 | entrypoints=0.2.3=py36h1aec115_2 54 | et_xmlfile=1.0.1=py36hd6bccc3_0 55 | expat=2.2.4=hc00ebd1_1 56 | fastcache=1.0.2=py36h14c3975_2 57 | ffmpeg=3.4=h7264315_0 58 | filelock=2.0.13=py36h646ffb5_0 59 | flask 60 | flask-cors=3.0.3=py36h2d857d3_0 61 | fontconfig=2.12.6=h49f89f6_0 62 | freetype=2.8=h52ed37b_0 63 | ftfy=4.4.2=py36_0 64 | future=0.16.0=py36_1 65 | get_terminal_size=1.0.0=haa9412d_0 66 | gevent=1.2.2=py36h2fe25dc_0 67 | gflags=2.2.1=hf484d3e_0 68 | glib=2.53.6=hc861d11_1 69 | glob2=0.6=py36he249c77_0 70 | glog=0.3.5=hf484d3e_1 71 | gmp=6.1.2=h6c8ec71_1 72 | gmpy2=2.0.8=py36hc8893dd_2 73 | graphite2=1.3.10=hf63cedd_1 74 | greenlet=0.4.12=py36h2d503a6_0 75 | gst-plugins-base=1.12.4=h33fb286_0 76 | gstreamer=1.12.4=hb53b477_0 77 | harfbuzz=1.7.4=hc5b324e_0 78 | hdf5=1.10.1=h9caa474_1 79 | heapdict=1.0.0=py36_2 80 | html5lib=0.9999999=py36_0 81 | icc_rt=2018.0.3=intel_0 82 | icu=58.2=h211956c_0 83 | idna=2.6=py36h82fb2a8_1 84 | imageio=2.9.0=py_0 85 | imagesize=0.7.1=py36h52d8127_0 86 | intel-openmp=2018.0.0=hc7b2577_8 87 | intelpython=2018.0.3=0 88 | ipykernel=4.8.0=py36_0 89 | ipython=6.2.1=py36_1 90 | ipython_genutils=0.2.0=py36hb52b0d5_0 91 | ipywidgets=7.1.1=py36_0 92 | isort=4.2.15=py36had401c0_0 93 | itsdangerous=0.24=py36h93cc618_1 94 | jasper=1.900.1=hd497a04_4 95 | jbig=2.1=hdba287a_0 96 | jdcal=1.3=py36h4c697fb_0 97 | jedi=0.11.1=py36_0 98 | jinja2=2.11.3 99 | jmespath=0.9.3=py36_0 100 | joblib=0.11=py36_0 101 | jpeg=9b=habf39ab_1 102 | jsonschema=2.6.0=py36h006f8b5_0 103 | jupyter=1.0.0=py36_0 104 | jupyter_client=5.2.2=py36_0 
105 | jupyter_console=5.2.0=py36he59e554_1 106 | jupyter_core=4.4.0=py36h7c827e3_0 107 | jupyterlab=0.31.5=py36_0 108 | jupyterlab_launcher=0.10.2=py36_0 109 | kiwisolver=1.0.1=py36hf484d3e_0 110 | lazy-object-proxy=1.3.1=py36h10fcdad_0 111 | libcurl=7.58.0=h1ad7b7a_0 112 | libedit=3.1.20181209=hc058e9b_0 113 | libffi=3.2.1=h4deb6c0_3 114 | libgcc=7.2.0=h69d50b8_2 115 | libgcc-ng=9.1.0=hdf63c60_0 116 | libgfortran=3.0.0=1 117 | libgfortran-ng=7.2.0=h9f7466a_2 118 | libgpuarray=0.7.5=0 119 | libiconv=1.15=0 120 | libopenblas=0.2.20=h9ac9557_7 121 | libopus=1.2.1=hb9ed12e_0 122 | libpng=1.6.37=hbc83047_0 123 | libprotobuf=3.6.0=hdbcaa40_0 124 | libsodium=1.0.15=hf101ebd_0 125 | libssh2=1.8.0=h9cfc8f7_4 126 | libstdcxx-ng=7.2.0=h7a57d05_2 127 | libtiff=4.0.9=h28f6b97_0 128 | libtool=2.4.6=h544aabb_3 129 | libuv=1.20.3=h14c3975_0 130 | libvpx=1.6.1=h888fd40_0 131 | libxcb=1.13=h1bed415_1 132 | libxml2=2.9.9=hea5a465_1 133 | libxslt=1.1.32=h1312cb7_0 134 | llvmlite=0.21.0=py36ha241eea_0 135 | lmdb=0.9.21=hf484d3e_1 136 | locket=0.2.0=py36h787c0ad_1 137 | lxml=4.6.3 138 | lzo=2.10=h49e0be7_2 139 | magma-cuda90=2.3.0=1 140 | mako=1.0.7=py36_0 141 | markdown=2.6.11=py_0 142 | markupsafe=1.0=py36hd9260cd_1 143 | mccabe=0.6.1=py36h5ad9710_1 144 | mistune=0.8.3=py36_0 145 | mkl=2018.0.3=1 146 | mkl-include=2018.0.3=1 147 | mkl-service=1.1.2=py36h17a0993_4 148 | mkl_fft=1.0.4=py36h4414c95_1 149 | mkl_random=1.0.1=py36h4414c95_1 150 | mkldnn=0.14.0=0 151 | mpc=1.0.3=hec55b23_5 152 | mpfr=3.1.5=h11a74b3_2 153 | mpmath=1.0.0=py36hfeacd6b_2 154 | msgpack-python=0.4.8=py36_0 155 | multipledispatch=0.4.9=py36h41da3fb_0 156 | murmurhash=0.28.0=py36_0 157 | nbconvert=5.3.1=py36hb41ffb7_0 158 | nbformat=4.4.0=py36h31c9010_0 159 | nccl2=1.0=0 160 | ncurses=6.1=hf484d3e_0 161 | networkx=2.1=py36_0 162 | ninja=1.8.2=py36h6bb024c_1 163 | nose=1.3.7=py36hcdf7029_2 164 | notebook=6.1.5=py36_0 165 | numpy=1.14.3=py36hcd700cb_1 166 | numpy-base=1.14.3=py36h9be14a7_1 167 | numpydoc=0.7.0=py36h18f165f_0 168 | oauthlib=2.0.6=py_0 169 | olefile=0.45.1=py36_0 170 | openblas=0.2.19=0 171 | opencv=3.3.1=py36h0a11808_0 172 | openmp=2018.0.3=intel_0 173 | openpyxl=2.4.10=py36_0 174 | openssl=1.0.2u=h7b6447c_0 175 | packaging=16.8=py36ha668100_1 176 | pandas=0.23.4=py36h04863e7_0 177 | pandoc=1.19.2.1=hea2e7c5_1 178 | pandocfilters=1.4.2=py36ha6701b7_1 179 | pango=1.41.0=hd475d92_0 180 | parso=0.1.1=py36h35f843b_0 181 | partd=0.3.8=py36h36fd896_0 182 | patchelf=0.9=hf79760b_2 183 | path.py=10.5=py36h55ceabb_0 184 | pathlib2=2.3.0=py36h49efa8e_0 185 | pcre=8.42=h439df22_0 186 | pep8=1.7.1=py36_0 187 | pexpect=4.3.1=py36_0 188 | pickleshare=0.7.4=py36h63277f8_0 189 | pillow=8.2.0 190 | pip=19.3.1=py36_0 191 | pixman=0.34.0=hceecf20_3 192 | pkginfo=1.4.1=py36h215d178_1 193 | plac=0.9.6=py36_0 194 | pluggy=0.6.0=py36hb689045_0 195 | ply=3.10=py36hed35086_0 196 | plyfile=0.7.2=pyh9f0ad1d_0 197 | preshed=1.0.0=py36_0 198 | prompt_toolkit=1.0.15=py36h17d85b1_0 199 | protobuf=3.6.0=py36hf484d3e_0 200 | psutil 201 | ptyprocess=0.5.2=py36h69acd42_0 202 | py 203 | pycodestyle=2.3.1=py36hf609f19_0 204 | pycosat=0.6.3=py36h0a5515d_0 205 | pycparser=2.18=py36hf9f622e_1 206 | pycrypto 207 | pycurl=7.43.0.1=py36hb7f436b_0 208 | pyflakes=1.6.0=py36h7bd6a15_0 209 | pygments=2.7.4 210 | pyjwt=1.5.3=py_0 211 | pylint=1.8.2=py36_0 212 | pyodbc=4.0.22=py36hf484d3e_0 213 | pyopenssl=17.5.0=py36h20ba746_0 214 | pyparsing=2.4.4=py_0 215 | pysocks=1.6.7=py36hd97a5b1_1 216 | pytest=3.3.2=py36_0 217 | python=3.6.6=hc3d631a_0 218 | 
python-crfsuite=0.9.2=py36_0 219 | python-dateutil=2.8.1=py_0 220 | python-lmdb=0.92=py36_0 221 | pytorch=1.5.0=py3.6_cuda10.1.243_cudnn7.6.3_0 222 | pytz=2019.3=py_0 223 | pyyaml=5.4 224 | pyzmq=16.0.2=py36h3b0cf96_2 225 | qt=5.6.2=h974d657_12 226 | qtawesome=0.4.4=py36h609ed8c_0 227 | qtconsole=4.3.1=py36h8f73b5b_0 228 | qtpy=1.3.1=py36h3691cc8_0 229 | readline=7.0=h7b6447c_5 230 | regex=2017.11.09=py36_0 231 | requests 232 | requests-oauthlib=0.8.0=py36_1 233 | rhash=1.3.6=hb7f436b_0 234 | rope=0.10.7=py36h147e2ec_0 235 | ruamel_yaml=0.15.35=py36h14c3975_1 236 | s3transfer=0.1.13=py36_0 237 | scikit-learn=0.19.1=py36h7aa7ec6_0 238 | scipy=1.0.0=py36hbf646e7_0 239 | send2trash=1.4.2=py36_0 240 | setuptools=41.6.0=py36_0 241 | simplegeneric=0.8.1=py36_2 242 | simplejson=3.14.0=py36h14c3975_0 243 | singledispatch=3.4.0.3=py36h7a266c3_0 244 | sip=4.19.8=py36hf484d3e_0 245 | six=1.13.0=py36_0 246 | smart_open=1.5.6=py36_1 247 | snowballstemmer=1.2.1=py36h6febd40_0 248 | sortedcollections=0.5.3=py36h3c761f9_0 249 | sortedcontainers=1.5.9=py36_0 250 | sphinx=1.6.6=py36_0 251 | sphinxcontrib=1.0=py36h6d0f590_1 252 | sphinxcontrib-websupport=1.0.1=py36hb5cb234_1 253 | spyder=3.2.6=py36_0 254 | sqlalchemy 255 | sqlite=3.30.1=h7b6447c_0 256 | sympy=1.1.1=py36hc6d1c1c_0 257 | tbb=2018.0.4=h6bb024c_1 258 | tbb4py=2018.0.4=py36h6bb024c_1 259 | tblib=1.3.2=py36h34cf8b6_0 260 | tensorboardx=2.1=py_0 261 | termcolor=1.1.0=py36_1 262 | terminado=0.8.1=py36_1 263 | testpath=0.3.1=py36h8cadb63_0 264 | tk=8.6.8=hbc83047_0 265 | toolz=0.9.0=py36_0 266 | torchvision=0.6.0=py36_cu101 267 | tornado=6.0.3=py36h7b6447c_0 268 | tqdm=4.19.7=py_0 269 | traitlets=4.3.2=py36h674d592_0 270 | twython=3.6.0=py36_0 271 | typing=3.6.2=py36h7da032a_0 272 | ujson=1.35=py36_0 273 | unicodecsv=0.14.1=py36ha668878_0 274 | unixodbc=2.3.4=hc36303a_1 275 | urllib3 276 | wcwidth=0.1.7=py36hdf4376a_0 277 | webencodings=0.5.1=py36h800622e_1 278 | werkzeug 279 | wheel=0.33.6=py36_0 280 | widgetsnbextension=3.1.0=py36_0 281 | wrapt=1.10.11=py36h28b7045_0 282 | x264=20131217=3 283 | xlrd=1.1.0=py36h1db9f0c_1 284 | xlsxwriter=1.0.2=py36h3de1aca_0 285 | xlwt=1.3.0=py36h7b00a1f_0 286 | xz=5.2.4=h14c3975_4 287 | yaml=0.1.7=had09818_2 288 | zeromq=4.2.2=hbedb6e5_2 289 | zict=0.1.3=py36h3a3bf81_0 290 | zlib=1.2.11=h7b6447c_3 291 | -------------------------------------------------------------------------------- /scripts/multinode-wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import sys, os 8 | 9 | num_nodes = int(os.environ['SLURM_NNODES']) 10 | node_id = int(os.environ['SLURM_NODEID']) 11 | node0 = 'learnfair' + os.environ['SLURM_NODELIST'][10:14] 12 | cmd = 'python {script} {cfg} --dist-url tcp://{node0}:1234 --dist-backend nccl --multiprocessing-distributed --world-size {ws} --rank {rank}'.format( 13 | script=sys.argv[1], cfg=sys.argv[2], node0=node0, ws=num_nodes, rank=node_id) 14 | os.system(cmd) 15 | -------------------------------------------------------------------------------- /scripts/pretrain_node1.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
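# Usage: submit this script with a pretraining config as its only argument, e.g.
#   sbatch scripts/pretrain_node1.sh configs/point_within_format.yaml
# (the config path is shown only as an example). SLURM reserves one node with
# 8 GPUs, and srun launches scripts/singlenode-wrapper.py, which reads
# SLURM_NNODES / SLURM_NODEID / SLURM_NODELIST to build the tcp:// rendezvous
# address and then runs main.py with --multiprocessing-distributed,
# --world-size and --rank. The wrapper derives the master host by slicing a
# 'learnfair' node name, which is specific to the cluster this was written for
# and will need adapting on other clusters.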
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #! /bin/bash 7 | #SBATCH --job-name=DepthContrast 8 | #SBATCH --nodes=1 9 | #SBATCH --gres=gpu:8 10 | #SBATCH --cpus-per-task=80 11 | #SBATCH --mem=400G 12 | #SBATCH --time=72:00:00 13 | #SBATCH --partition=dev 14 | #SBATCH --comment="test" 15 | #SBATCH --constraint=volta32gb 16 | 17 | #SBATCH --signal=B:USR1@60 18 | #SBATCH --open-mode=append 19 | 20 | EXPERIMENT_PATH="./checkpoints/testlog" 21 | mkdir -p $EXPERIMENT_PATH 22 | 23 | export PYTHONPATH=$PWD:$PYTHONPATH 24 | 25 | srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --label python scripts/singlenode-wrapper.py main.py $1 26 | -------------------------------------------------------------------------------- /scripts/pretrain_node4.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #! /bin/bash 7 | #SBATCH --job-name=DepthContrast 8 | #SBATCH --job-name=DepthContrast 9 | #SBATCH --nodes=4 10 | #SBATCH --gres=gpu:8 11 | #SBATCH --cpus-per-task=80 12 | #SBATCH --mem=400G 13 | #SBATCH --time=72:00:00 14 | #SBATCH --partition=dev 15 | #SBATCH --comment="depthcontrast" 16 | #SBATCH --constraint=volta32gb 17 | 18 | #SBATCH --signal=B:USR1@60 19 | #SBATCH --open-mode=append 20 | 21 | LOG_PATH=$2 22 | mkdir -p $LOG_PATH 23 | 24 | export PYTHONPATH=$PWD:$PYTHONPATH 25 | 26 | srun --output=${LOG_PATH}/%j.out --error=${LOG_PATH}/%j.err --label python scripts/multinode-wrapper.py main.py $1 27 | -------------------------------------------------------------------------------- /scripts/singlenode-wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import sys, os 8 | 9 | num_nodes = int(os.environ['SLURM_NNODES']) 10 | node_id = int(os.environ['SLURM_NODEID']) 11 | node0 = 'learnfair' + os.environ['SLURM_NODELIST'][9:13] 12 | cmd = 'python {script} {cfg} --dist-url tcp://{node0}:1234 --dist-backend nccl --multiprocessing-distributed --world-size {ws} --rank {rank}'.format( 13 | script=sys.argv[1], cfg=sys.argv[2], node0=node0, ws=num_nodes, rank=node_id) 14 | os.system(cmd) 15 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 10 | const int nsample); 11 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
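// opt_n_threads below rounds the work size down to the nearest power of two and
// clamps it to [1, TOTAL_THREADS]; e.g. opt_n_threads(300): log2(300) ~= 8.2, so
// pow_2 = 8 and the launch uses 1 << 8 = 256 threads. opt_block_config gives the
// x dimension its full budget first and caps y at TOTAL_THREADS / x_threads.
// CUDA_CHECK_ERRORS() is placed after every kernel launch in the .cu files and
// aborts with the failing file/line if the launch reported an error.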
2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #ifndef _CUDA_UTILS_H 7 | #define _CUDA_UTILS_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | #define TOTAL_THREADS 512 19 | 20 | inline int opt_n_threads(int work_size) { 21 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 22 | 23 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 24 | } 25 | 26 | inline dim3 opt_block_config(int x, int y) { 27 | const int x_threads = opt_n_threads(x); 28 | const int y_threads = 29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 30 | dim3 block_config(x_threads, y_threads, 1); 31 | 32 | return block_config; 33 | } 34 | 35 | #define CUDA_CHECK_ERRORS() \ 36 | do { \ 37 | cudaError_t err = cudaGetLastError(); \ 38 | if (cudaSuccess != err) { \ 39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 41 | __FILE__); \ 42 | exit(-1); \ 43 | } \ 44 | } while (0) 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 13 | at::Tensor weight); 14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 15 | at::Tensor weight, const int m); 16 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 12 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
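// These macros front every exported op in the .cpp bindings: CHECK_CONTIGUOUS and
// CHECK_IS_FLOAT / CHECK_IS_INT validate memory layout and dtype up front, while
// CHECK_CUDA is applied to companion tensors only once the leading tensor is
// known to live on the GPU; CPU inputs fall through to the
// AT_ASSERT(false, "CPU not supported") branches instead.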
2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) \ 11 | do { \ 12 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_INT(x) \ 21 | do { \ 22 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 23 | #x " must be an int tensor"); \ 24 | } while (0) 25 | 26 | #define CHECK_IS_FLOAT(x) \ 27 | do { \ 28 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 29 | #x " must be a float tensor"); \ 30 | } while (0) 31 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "ball_query.h" 7 | #include "utils.h" 8 | 9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 10 | int nsample, const float *new_xyz, 11 | const float *xyz, int *idx); 12 | 13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 14 | const int nsample) { 15 | CHECK_CONTIGUOUS(new_xyz); 16 | CHECK_CONTIGUOUS(xyz); 17 | CHECK_IS_FLOAT(new_xyz); 18 | CHECK_IS_FLOAT(xyz); 19 | 20 | if (new_xyz.is_cuda()) { 21 | CHECK_CUDA(xyz); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.is_cuda()) { 29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, nsample, new_xyz.data(), 31 | xyz.data(), idx.data()); 32 | } else { 33 | AT_ASSERT(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
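// query_ball_point_kernel assigns one thread block per batch element; each thread
// strides over the m query centers and, for each one, scans all n points, keeping
// up to nsample indices whose squared distance is below radius^2. On the first
// match it back-fills every slot of idx[j] with that index, so a center with fewer
// than nsample neighbours returns its first neighbour repeated rather than
// zero-padding that would silently alias point 0.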
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 13 | // output: idx(b, m, nsample) 14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 15 | int nsample, 16 | const float *__restrict__ new_xyz, 17 | const float *__restrict__ xyz, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | xyz += batch_index * n * 3; 21 | new_xyz += batch_index * m * 3; 22 | idx += m * nsample * batch_index; 23 | 24 | int index = threadIdx.x; 25 | int stride = blockDim.x; 26 | 27 | float radius2 = radius * radius; 28 | for (int j = index; j < m; j += stride) { 29 | float new_x = new_xyz[j * 3 + 0]; 30 | float new_y = new_xyz[j * 3 + 1]; 31 | float new_z = new_xyz[j * 3 + 2]; 32 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 33 | float x = xyz[k * 3 + 0]; 34 | float y = xyz[k * 3 + 1]; 35 | float z = xyz[k * 3 + 2]; 36 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 37 | (new_z - z) * (new_z - z); 38 | if (d2 < radius2) { 39 | if (cnt == 0) { 40 | for (int l = 0; l < nsample; ++l) { 41 | idx[j * nsample + l] = k; 42 | } 43 | } 44 | idx[j * nsample + cnt] = k; 45 | ++cnt; 46 | } 47 | } 48 | } 49 | } 50 | 51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 52 | int nsample, const float *new_xyz, 53 | const float *xyz, int *idx) { 54 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 55 | query_ball_point_kernel<<>>( 56 | b, n, m, radius, nsample, new_xyz, xyz, idx); 57 | 58 | CUDA_CHECK_ERRORS(); 59 | } 60 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "ball_query.h" 7 | #include "group_points.h" 8 | #include "interpolate.h" 9 | #include "sampling.h" 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("gather_points", &gather_points); 13 | m.def("gather_points_grad", &gather_points_grad); 14 | m.def("furthest_point_sampling", &furthest_point_sampling); 15 | 16 | m.def("three_nn", &three_nn); 17 | m.def("three_interpolate", &three_interpolate); 18 | m.def("three_interpolate_grad", &three_interpolate_grad); 19 | 20 | m.def("ball_query", &ball_query); 21 | 22 | m.def("group_points", &group_points); 23 | m.def("group_points_grad", &group_points_grad); 24 | } 25 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
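// group_points gathers the neighbourhood features selected by ball_query: points
// of shape (B, C, N) indexed by idx of shape (B, npoints, nsample) produce an
// output of shape (B, C, npoints, nsample); group_points_grad scatters gradients
// back into a (B, C, N) tensor. Both entry points allocate a zero-filled output
// and dispatch to the CUDA wrappers; the CPU path is unimplemented.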
5 | 6 | #include "group_points.h" 7 | #include "utils.h" 8 | 9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 10 | const float *points, const int *idx, 11 | float *out); 12 | 13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 14 | int nsample, const float *grad_out, 15 | const int *idx, float *grad_points); 16 | 17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 18 | CHECK_CONTIGUOUS(points); 19 | CHECK_CONTIGUOUS(idx); 20 | CHECK_IS_FLOAT(points); 21 | CHECK_IS_INT(idx); 22 | 23 | if (points.is_cuda()) { 24 | CHECK_CUDA(idx); 25 | } 26 | 27 | at::Tensor output = 28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 29 | at::device(points.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (points.is_cuda()) { 32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 33 | idx.size(1), idx.size(2), points.data(), 34 | idx.data(), output.data()); 35 | } else { 36 | AT_ASSERT(false, "CPU not supported"); 37 | } 38 | 39 | return output; 40 | } 41 | 42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 43 | CHECK_CONTIGUOUS(grad_out); 44 | CHECK_CONTIGUOUS(idx); 45 | CHECK_IS_FLOAT(grad_out); 46 | CHECK_IS_INT(idx); 47 | 48 | if (grad_out.is_cuda()) { 49 | CHECK_CUDA(idx); 50 | } 51 | 52 | at::Tensor output = 53 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 54 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 55 | 56 | if (grad_out.is_cuda()) { 57 | group_points_grad_kernel_wrapper( 58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 59 | grad_out.data(), idx.data(), output.data()); 60 | } else { 61 | AT_ASSERT(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
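// Both kernels here run one thread block per batch element, with the block's
// threads striding over the c * npoints output entries and copying the nsample
// values of each group. The backward kernel accumulates with atomicAdd because
// several (feature, group, sample) slots can reference the same source point.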
5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, npoints, nsample) 12 | // output: out(b, c, npoints, nsample) 13 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 14 | int nsample, 15 | const float *__restrict__ points, 16 | const int *__restrict__ idx, 17 | float *__restrict__ out) { 18 | int batch_index = blockIdx.x; 19 | points += batch_index * n * c; 20 | idx += batch_index * npoints * nsample; 21 | out += batch_index * npoints * nsample * c; 22 | 23 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 24 | const int stride = blockDim.y * blockDim.x; 25 | for (int i = index; i < c * npoints; i += stride) { 26 | const int l = i / npoints; 27 | const int j = i % npoints; 28 | for (int k = 0; k < nsample; ++k) { 29 | int ii = idx[j * nsample + k]; 30 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 31 | } 32 | } 33 | } 34 | 35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 36 | const float *points, const int *idx, 37 | float *out) { 38 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 39 | 40 | group_points_kernel<<>>( 41 | b, c, n, npoints, nsample, points, idx, out); 42 | 43 | CUDA_CHECK_ERRORS(); 44 | } 45 | 46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 47 | // output: grad_points(b, c, n) 48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 49 | int nsample, 50 | const float *__restrict__ grad_out, 51 | const int *__restrict__ idx, 52 | float *__restrict__ grad_points) { 53 | int batch_index = blockIdx.x; 54 | grad_out += batch_index * npoints * nsample * c; 55 | idx += batch_index * npoints * nsample; 56 | grad_points += batch_index * n * c; 57 | 58 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 59 | const int stride = blockDim.y * blockDim.x; 60 | for (int i = index; i < c * npoints; i += stride) { 61 | const int l = i / npoints; 62 | const int j = i % npoints; 63 | for (int k = 0; k < nsample; ++k) { 64 | int ii = idx[j * nsample + k]; 65 | atomicAdd(grad_points + l * n + ii, 66 | grad_out[(l * npoints + j) * nsample + k]); 67 | } 68 | } 69 | } 70 | 71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 72 | int nsample, const float *grad_out, 73 | const int *idx, float *grad_points) { 74 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 75 | 76 | group_points_grad_kernel<<>>( 77 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 78 | 79 | CUDA_CHECK_ERRORS(); 80 | } 81 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
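// three_nn returns, for every "unknown" query point, the squared distances and
// indices of its three nearest "known" points; three_interpolate blends the known
// features with caller-supplied per-point weights, and three_interpolate_grad
// routes gradients back to the m known points. The weights themselves are not
// computed here; in PointNet++-style feature propagation they are typically the
// normalized inverse distances derived from dist2.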
5 |
6 | #include "interpolate.h"
7 | #include "utils.h"
8 |
9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
10 |                              const float *known, float *dist2, int *idx);
11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
12 |                                       const float *points, const int *idx,
13 |                                       const float *weight, float *out);
14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
15 |                                            const float *grad_out,
16 |                                            const int *idx, const float *weight,
17 |                                            float *grad_points);
18 |
19 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
20 |   CHECK_CONTIGUOUS(unknowns);
21 |   CHECK_CONTIGUOUS(knows);
22 |   CHECK_IS_FLOAT(unknowns);
23 |   CHECK_IS_FLOAT(knows);
24 |
25 |   if (unknowns.is_cuda()) {
26 |     CHECK_CUDA(knows);
27 |   }
28 |
29 |   at::Tensor idx =
30 |       torch::zeros({unknowns.size(0), unknowns.size(1), 3},
31 |                    at::device(unknowns.device()).dtype(at::ScalarType::Int));
32 |   at::Tensor dist2 =
33 |       torch::zeros({unknowns.size(0), unknowns.size(1), 3},
34 |                    at::device(unknowns.device()).dtype(at::ScalarType::Float));
35 |
36 |   if (unknowns.is_cuda()) {
37 |     three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
38 |                             unknowns.data<float>(), knows.data<float>(),
39 |                             dist2.data<float>(), idx.data<int>());
40 |   } else {
41 |     AT_ASSERT(false, "CPU not supported");
42 |   }
43 |
44 |   return {dist2, idx};
45 | }
46 |
47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
48 |                              at::Tensor weight) {
49 |   CHECK_CONTIGUOUS(points);
50 |   CHECK_CONTIGUOUS(idx);
51 |   CHECK_CONTIGUOUS(weight);
52 |   CHECK_IS_FLOAT(points);
53 |   CHECK_IS_INT(idx);
54 |   CHECK_IS_FLOAT(weight);
55 |
56 |   if (points.is_cuda()) {
57 |     CHECK_CUDA(idx);
58 |     CHECK_CUDA(weight);
59 |   }
60 |
61 |   at::Tensor output =
62 |       torch::zeros({points.size(0), points.size(1), idx.size(1)},
63 |                    at::device(points.device()).dtype(at::ScalarType::Float));
64 |
65 |   if (points.is_cuda()) {
66 |     three_interpolate_kernel_wrapper(
67 |         points.size(0), points.size(1), points.size(2), idx.size(1),
68 |         points.data<float>(), idx.data<int>(), weight.data<float>(),
69 |         output.data<float>());
70 |   } else {
71 |     AT_ASSERT(false, "CPU not supported");
72 |   }
73 |
74 |   return output;
75 | }
76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
77 |                                   at::Tensor weight, const int m) {
78 |   CHECK_CONTIGUOUS(grad_out);
79 |   CHECK_CONTIGUOUS(idx);
80 |   CHECK_CONTIGUOUS(weight);
81 |   CHECK_IS_FLOAT(grad_out);
82 |   CHECK_IS_INT(idx);
83 |   CHECK_IS_FLOAT(weight);
84 |
85 |   if (grad_out.is_cuda()) {
86 |     CHECK_CUDA(idx);
87 |     CHECK_CUDA(weight);
88 |   }
89 |
90 |   at::Tensor output =
91 |       torch::zeros({grad_out.size(0), grad_out.size(1), m},
92 |                    at::device(grad_out.device()).dtype(at::ScalarType::Float));
93 |
94 |   if (grad_out.is_cuda()) {
95 |     three_interpolate_grad_kernel_wrapper(
96 |         grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
97 |         grad_out.data<float>(), idx.data<int>(), weight.data<float>(),
98 |         output.data<float>());
99 |   } else {
100 |     AT_ASSERT(false, "CPU not supported");
101 |   }
102 |
103 |   return output;
104 | }
105 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
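The wrappers above return, for every query point in unknowns, the squared distances to and indices of its three nearest neighbors in knows; three_interpolate then blends three source feature columns with the given weights. A dense PyTorch sketch of the same semantics (for intuition only; it materializes the full (B, n, m) distance matrix, which the CUDA kernels below avoid):

import torch

def three_nn_reference(unknown: torch.Tensor, known: torch.Tensor):
    """unknown: (B, n, 3), known: (B, m, 3) -> dist2, idx, both (B, n, 3)."""
    d2 = torch.cdist(unknown, known) ** 2                       # (B, n, m) squared distances
    dist2, idx = torch.topk(d2, k=3, dim=-1, largest=False)     # three smallest, ascending
    return dist2, idx.int()

def three_interpolate_reference(points: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor):
    """points: (B, C, m), idx: (B, n, 3) int, weight: (B, n, 3) -> out: (B, C, n)."""
    B, C, m = points.shape
    n = idx.shape[1]
    flat = idx.unsqueeze(1).expand(B, C, n, 3).reshape(B, C, n * 3).long()
    gathered = torch.gather(points, 2, flat).reshape(B, C, n, 3)
    return (gathered * weight.unsqueeze(1)).sum(dim=-1)         # weighted sum over the 3 neighbors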
5 |
6 | #include <math.h>
7 | #include <stdio.h>
8 | #include <stdlib.h>
9 |
10 | #include "cuda_utils.h"
11 |
12 | // input: unknown(b, n, 3) known(b, m, 3)
13 | // output: dist2(b, n, 3), idx(b, n, 3)
14 | __global__ void three_nn_kernel(int b, int n, int m,
15 |                                 const float *__restrict__ unknown,
16 |                                 const float *__restrict__ known,
17 |                                 float *__restrict__ dist2,
18 |                                 int *__restrict__ idx) {
19 |   int batch_index = blockIdx.x;
20 |   unknown += batch_index * n * 3;
21 |   known += batch_index * m * 3;
22 |   dist2 += batch_index * n * 3;
23 |   idx += batch_index * n * 3;
24 |
25 |   int index = threadIdx.x;
26 |   int stride = blockDim.x;
27 |   for (int j = index; j < n; j += stride) {
28 |     float ux = unknown[j * 3 + 0];
29 |     float uy = unknown[j * 3 + 1];
30 |     float uz = unknown[j * 3 + 2];
31 |
32 |     double best1 = 1e40, best2 = 1e40, best3 = 1e40;
33 |     int besti1 = 0, besti2 = 0, besti3 = 0;
34 |     for (int k = 0; k < m; ++k) {
35 |       float x = known[k * 3 + 0];
36 |       float y = known[k * 3 + 1];
37 |       float z = known[k * 3 + 2];
38 |       float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
39 |       if (d < best1) {
40 |         best3 = best2;
41 |         besti3 = besti2;
42 |         best2 = best1;
43 |         besti2 = besti1;
44 |         best1 = d;
45 |         besti1 = k;
46 |       } else if (d < best2) {
47 |         best3 = best2;
48 |         besti3 = besti2;
49 |         best2 = d;
50 |         besti2 = k;
51 |       } else if (d < best3) {
52 |         best3 = d;
53 |         besti3 = k;
54 |       }
55 |     }
56 |     dist2[j * 3 + 0] = best1;
57 |     dist2[j * 3 + 1] = best2;
58 |     dist2[j * 3 + 2] = best3;
59 |
60 |     idx[j * 3 + 0] = besti1;
61 |     idx[j * 3 + 1] = besti2;
62 |     idx[j * 3 + 2] = besti3;
63 |   }
64 | }
65 |
66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
67 |                              const float *known, float *dist2, int *idx) {
68 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
69 |   three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
70 |                                                       dist2, idx);
71 |
72 |   CUDA_CHECK_ERRORS();
73 | }
74 |
75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
76 | // output: out(b, c, n)
77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
78 |                                          const float *__restrict__ points,
79 |                                          const int *__restrict__ idx,
80 |                                          const float *__restrict__ weight,
81 |                                          float *__restrict__ out) {
82 |   int batch_index = blockIdx.x;
83 |   points += batch_index * m * c;
84 |
85 |   idx += batch_index * n * 3;
86 |   weight += batch_index * n * 3;
87 |
88 |   out += batch_index * n * c;
89 |
90 |   const int index = threadIdx.y * blockDim.x + threadIdx.x;
91 |   const int stride = blockDim.y * blockDim.x;
92 |   for (int i = index; i < c * n; i += stride) {
93 |     const int l = i / n;
94 |     const int j = i % n;
95 |     float w1 = weight[j * 3 + 0];
96 |     float w2 = weight[j * 3 + 1];
97 |     float w3 = weight[j * 3 + 2];
98 |
99 |     int i1 = idx[j * 3 + 0];
100 |     int i2 = idx[j * 3 + 1];
101 |     int i3 = idx[j * 3 + 2];
102 |
103 |     out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
104 |              points[l * m + i3] * w3;
105 |   }
106 | }
107 |
108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
109 |                                       const float *points, const int *idx,
110 |                                       const float *weight, float *out) {
111 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
112 |   three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
113 |       b, c, m, n, points, idx, weight, out);
114 |
115 |   CUDA_CHECK_ERRORS();
116 | }
117 |
118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
119 | // output: grad_points(b, c, m)
120 |
121 | __global__ void three_interpolate_grad_kernel(
122 |     int b, int c, int n, int m, const float *__restrict__ grad_out,
123 |     const int *__restrict__ idx, const float *__restrict__ weight,
124 |     float *__restrict__ grad_points) {
125 |   int batch_index = blockIdx.x;
126 |   grad_out += batch_index * n * c;
127 |   idx += batch_index * n * 3;
128 |   weight += batch_index * n * 3;
129 |   grad_points += batch_index * m * c;
130 |
131 |   const int index = threadIdx.y * blockDim.x + threadIdx.x;
132 |   const int stride = blockDim.y * blockDim.x;
133 |   for (int i = index; i < c * n; i += stride) {
134 |     const int l = i / n;
135 |     const int j = i % n;
136 |     float w1 = weight[j * 3 + 0];
137 |     float w2 = weight[j * 3 + 1];
138 |     float w3 = weight[j * 3 + 2];
139 |
140 |     int i1 = idx[j * 3 + 0];
141 |     int i2 = idx[j * 3 + 1];
142 |     int i3 = idx[j * 3 + 2];
143 |
144 |     atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
145 |     atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
146 |     atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
147 |   }
148 | }
149 |
150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
151 |                                            const float *grad_out,
152 |                                            const int *idx, const float *weight,
153 |                                            float *grad_points) {
154 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
155 |   three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
156 |       b, c, n, m, grad_out, idx, weight, grad_points);
157 |
158 |   CUDA_CHECK_ERRORS();
159 | }
160 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "sampling.h"
7 | #include "utils.h"
8 |
9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
10 |                                   const float *points, const int *idx,
11 |                                   float *out);
12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
13 |                                        const float *grad_out, const int *idx,
14 |                                        float *grad_points);
15 |
16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
17 |                                             const float *dataset, float *temp,
18 |                                             int *idxs);
19 |
20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
21 |   CHECK_CONTIGUOUS(points);
22 |   CHECK_CONTIGUOUS(idx);
23 |   CHECK_IS_FLOAT(points);
24 |   CHECK_IS_INT(idx);
25 |
26 |   if (points.is_cuda()) {
27 |     CHECK_CUDA(idx);
28 |   }
29 |
30 |   at::Tensor output =
31 |       torch::zeros({points.size(0), points.size(1), idx.size(1)},
32 |                    at::device(points.device()).dtype(at::ScalarType::Float));
33 |
34 |   if (points.is_cuda()) {
35 |     gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
36 |                                  idx.size(1), points.data<float>(),
37 |                                  idx.data<int>(), output.data<float>());
38 |   } else {
39 |     AT_ASSERT(false, "CPU not supported");
40 |   }
41 |
42 |   return output;
43 | }
44 |
45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
46 |                               const int n) {
47 |   CHECK_CONTIGUOUS(grad_out);
48 |   CHECK_CONTIGUOUS(idx);
49 |   CHECK_IS_FLOAT(grad_out);
50 |   CHECK_IS_INT(idx);
51 |
52 |   if (grad_out.is_cuda()) {
53 |     CHECK_CUDA(idx);
54 |   }
55 |
56 |   at::Tensor output =
57 |       torch::zeros({grad_out.size(0), grad_out.size(1), n},
58 |                    at::device(grad_out.device()).dtype(at::ScalarType::Float));
59 |
60 |   if (grad_out.is_cuda()) {
61 |     gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
62 |                                       idx.size(1), grad_out.data<float>(),
63 |                                       idx.data<int>(), output.data<float>());
64 |   } else {
65 |     AT_ASSERT(false, "CPU not supported");
66 |   }
67 |
68 |   return output;
69 | }
70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
71 |   CHECK_CONTIGUOUS(points);
72 |   CHECK_IS_FLOAT(points);
73 |
74 |   at::Tensor output =
75 |       torch::zeros({points.size(0), nsamples},
76 |                    at::device(points.device()).dtype(at::ScalarType::Int));
77 |
78 |   at::Tensor tmp =
79 |       torch::full({points.size(0), points.size(1)}, 1e10,
80 |                   at::device(points.device()).dtype(at::ScalarType::Float));
81 |
82 |   if (points.is_cuda()) {
83 |     furthest_point_sampling_kernel_wrapper(
84 |         points.size(0), points.size(1), nsamples, points.data<float>(),
85 |         tmp.data<float>(), output.data<int>());
86 |   } else {
87 |     AT_ASSERT(false, "CPU not supported");
88 |   }
89 |
90 |   return output;
91 | }
92 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/src/sampling_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include <stdio.h>
7 | #include <stdlib.h>
8 |
9 | #include "cuda_utils.h"
10 |
11 | // input: points(b, c, n) idx(b, m)
12 | // output: out(b, c, m)
13 | __global__ void gather_points_kernel(int b, int c, int n, int m,
14 |                                      const float *__restrict__ points,
15 |                                      const int *__restrict__ idx,
16 |                                      float *__restrict__ out) {
17 |   for (int i = blockIdx.x; i < b; i += gridDim.x) {
18 |     for (int l = blockIdx.y; l < c; l += gridDim.y) {
19 |       for (int j = threadIdx.x; j < m; j += blockDim.x) {
20 |         int a = idx[i * m + j];
21 |         out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
22 |       }
23 |     }
24 |   }
25 | }
26 |
27 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
28 |                                   const float *points, const int *idx,
29 |                                   float *out) {
30 |   gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
31 |                          at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
32 |                                                              points, idx, out);
33 |
34 |   CUDA_CHECK_ERRORS();
35 | }
36 |
37 | // input: grad_out(b, c, m) idx(b, m)
38 | // output: grad_points(b, c, n)
39 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
40 |                                           const float *__restrict__ grad_out,
41 |                                           const int *__restrict__ idx,
42 |                                           float *__restrict__ grad_points) {
43 |   for (int i = blockIdx.x; i < b; i += gridDim.x) {
44 |     for (int l = blockIdx.y; l < c; l += gridDim.y) {
45 |       for (int j = threadIdx.x; j < m; j += blockDim.x) {
46 |         int a = idx[i * m + j];
47 |         atomicAdd(grad_points + (i * c + l) * n + a,
48 |                   grad_out[(i * c + l) * m + j]);
49 |       }
50 |     }
51 |   }
52 | }
53 |
54 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
55 |                                        const float *grad_out, const int *idx,
56 |                                        float *grad_points) {
57 |   gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
58 |                               at::cuda::getCurrentCUDAStream()>>>(
59 |       b, c, n, npoints, grad_out, idx, grad_points);
60 |
61 |   CUDA_CHECK_ERRORS();
62 | }
63 |
64 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
65 |                          int idx1, int idx2) {
66 |   const float v1 = dists[idx1], v2 = dists[idx2];
67 |   const int i1 = dists_i[idx1], i2 = dists_i[idx2];
68 |   dists[idx1] = max(v1, v2);
69 |   dists_i[idx1] = v2 > v1 ? i2 : i1;
70 | }
71 |
72 | // Input dataset: (b, n, 3), tmp: (b, n)
73 | // Ouput idxs (b, m)
74 | template <unsigned int block_size>
75 | __global__ void furthest_point_sampling_kernel(
76 |     int b, int n, int m, const float *__restrict__ dataset,
77 |     float *__restrict__ temp, int *__restrict__ idxs) {
78 |   if (m <= 0) return;
79 |   __shared__ float dists[block_size];
80 |   __shared__ int dists_i[block_size];
81 |
82 |   int batch_index = blockIdx.x;
83 |   dataset += batch_index * n * 3;
84 |   temp += batch_index * n;
85 |   idxs += batch_index * m;
86 |
87 |   int tid = threadIdx.x;
88 |   const int stride = block_size;
89 |
90 |   int old = 0;
91 |   if (threadIdx.x == 0) idxs[0] = old;
92 |
93 |   __syncthreads();
94 |   for (int j = 1; j < m; j++) {
95 |     int besti = 0;
96 |     float best = -1;
97 |     float x1 = dataset[old * 3 + 0];
98 |     float y1 = dataset[old * 3 + 1];
99 |     float z1 = dataset[old * 3 + 2];
100 |     for (int k = tid; k < n; k += stride) {
101 |       float x2, y2, z2;
102 |       x2 = dataset[k * 3 + 0];
103 |       y2 = dataset[k * 3 + 1];
104 |       z2 = dataset[k * 3 + 2];
105 |       float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
106 |       if (mag <= 1e-3) continue;
107 |
108 |       float d =
109 |           (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
110 |
111 |       float d2 = min(d, temp[k]);
112 |       temp[k] = d2;
113 |       besti = d2 > best ? k : besti;
114 |       best = d2 > best ? d2 : best;
115 |     }
116 |     dists[tid] = best;
117 |     dists_i[tid] = besti;
118 |     __syncthreads();
119 |
120 |     if (block_size >= 512) {
121 |       if (tid < 256) {
122 |         __update(dists, dists_i, tid, tid + 256);
123 |       }
124 |       __syncthreads();
125 |     }
126 |     if (block_size >= 256) {
127 |       if (tid < 128) {
128 |         __update(dists, dists_i, tid, tid + 128);
129 |       }
130 |       __syncthreads();
131 |     }
132 |     if (block_size >= 128) {
133 |       if (tid < 64) {
134 |         __update(dists, dists_i, tid, tid + 64);
135 |       }
136 |       __syncthreads();
137 |     }
138 |     if (block_size >= 64) {
139 |       if (tid < 32) {
140 |         __update(dists, dists_i, tid, tid + 32);
141 |       }
142 |       __syncthreads();
143 |     }
144 |     if (block_size >= 32) {
145 |       if (tid < 16) {
146 |         __update(dists, dists_i, tid, tid + 16);
147 |       }
148 |       __syncthreads();
149 |     }
150 |     if (block_size >= 16) {
151 |       if (tid < 8) {
152 |         __update(dists, dists_i, tid, tid + 8);
153 |       }
154 |       __syncthreads();
155 |     }
156 |     if (block_size >= 8) {
157 |       if (tid < 4) {
158 |         __update(dists, dists_i, tid, tid + 4);
159 |       }
160 |       __syncthreads();
161 |     }
162 |     if (block_size >= 4) {
163 |       if (tid < 2) {
164 |         __update(dists, dists_i, tid, tid + 2);
165 |       }
166 |       __syncthreads();
167 |     }
168 |     if (block_size >= 2) {
169 |       if (tid < 1) {
170 |         __update(dists, dists_i, tid, tid + 1);
171 |       }
172 |       __syncthreads();
173 |     }
174 |
175 |     old = dists_i[0];
176 |     if (tid == 0) idxs[j] = old;
177 |   }
178 | }
179 |
180 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
181 |                                             const float *dataset, float *temp,
182 |                                             int *idxs) {
183 |   unsigned int n_threads = opt_n_threads(n);
184 |
185 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
186 |
187 |   switch (n_threads) {
188 |     case 512:
189 |       furthest_point_sampling_kernel<512>
190 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
191 |       break;
192 |     case 256:
193 |       furthest_point_sampling_kernel<256>
194 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
195 |       break;
196 |     case 128:
197 |       furthest_point_sampling_kernel<128>
198 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
199 |       break;
200 |     case 64:
201 |       furthest_point_sampling_kernel<64>
202 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
203 |       break;
204 |     case 32:
205 |       furthest_point_sampling_kernel<32>
206 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
207 |       break;
208 |     case 16:
209 |       furthest_point_sampling_kernel<16>
210 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
211 |       break;
212 |     case 8:
213 |       furthest_point_sampling_kernel<8>
214 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
215 |       break;
216 |     case 4:
217 |       furthest_point_sampling_kernel<4>
218 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
219 |       break;
220 |     case 2:
221 |       furthest_point_sampling_kernel<2>
222 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
223 |       break;
224 |     case 1:
225 |       furthest_point_sampling_kernel<1>
226 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
227 |       break;
228 |     default:
229 |       furthest_point_sampling_kernel<512>
230 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
231 |   }
232 |
233 |   CUDA_CHECK_ERRORS();
234 | }
235 |
--------------------------------------------------------------------------------
/third_party/pointnet2/pointnet2_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | ''' Testing customized ops. '''
7 |
8 | import torch
9 | from torch.autograd import gradcheck
10 | import numpy as np
11 |
12 | import os
13 | import sys
14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
15 | sys.path.append(BASE_DIR)
16 | import pointnet2_utils
17 |
18 | def test_interpolation_grad():
19 |     batch_size = 1
20 |     feat_dim = 2
21 |     m = 4
22 |     feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda()
23 |
24 |     def interpolate_func(inputs):
25 |         idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda()
26 |         weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda()
27 |         interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight)
28 |         return interpolated_feats
29 |
30 |     assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1))
31 |
32 | if __name__=='__main__':
33 |     test_interpolation_grad()
34 |
--------------------------------------------------------------------------------
/third_party/pointnet2/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
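For reference, the furthest-point-sampling kernel above greedily grows a subset that starts at index 0 and repeatedly adds the point farthest from the points already selected. A plain-PyTorch sketch of the same selection rule (illustration only; the kernel additionally skips near-zero padding points and writes int32 indices):

import torch

def furthest_point_sampling_reference(xyz: torch.Tensor, m: int) -> torch.Tensor:
    """xyz: (B, N, 3) -> idx: (B, m), greedy farthest-point subset seeded at index 0."""
    B, N, _ = xyz.shape
    idx = torch.zeros(B, m, dtype=torch.long, device=xyz.device)
    dist = torch.full((B, N), float("inf"), device=xyz.device)   # distance to current subset
    farthest = torch.zeros(B, dtype=torch.long, device=xyz.device)
    batch = torch.arange(B, device=xyz.device)
    for j in range(m):
        idx[:, j] = farthest
        centroid = xyz[batch, farthest].unsqueeze(1)             # (B, 1, 3) newest point
        dist = torch.minimum(dist, ((xyz - centroid) ** 2).sum(-1))
        farthest = dist.argmax(dim=-1)                           # next point: farthest from subset
    return idx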
5 | 6 | ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | import torch 8 | import torch.nn as nn 9 | from typing import List, Tuple 10 | 11 | class SharedMLP(nn.Sequential): 12 | 13 | def __init__( 14 | self, 15 | args: List[int], 16 | *, 17 | bn: bool = False, 18 | activation=nn.ReLU(inplace=True), 19 | preact: bool = False, 20 | first: bool = False, 21 | name: str = "" 22 | ): 23 | super().__init__() 24 | 25 | for i in range(len(args) - 1): 26 | self.add_module( 27 | name + 'layer{}'.format(i), 28 | Conv2d( 29 | args[i], 30 | args[i + 1], 31 | bn=(not first or not preact or (i != 0)) and bn, 32 | activation=activation 33 | if (not first or not preact or (i != 0)) else None, 34 | preact=preact 35 | ) 36 | ) 37 | 38 | 39 | class _BNBase(nn.Sequential): 40 | 41 | def __init__(self, in_size, batch_norm=None, name=""): 42 | super().__init__() 43 | self.add_module(name + "bn", batch_norm(in_size)) 44 | 45 | nn.init.constant_(self[0].weight, 1.0) 46 | nn.init.constant_(self[0].bias, 0) 47 | 48 | 49 | class BatchNorm1d(_BNBase): 50 | 51 | def __init__(self, in_size: int, *, name: str = ""): 52 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 53 | 54 | 55 | class BatchNorm2d(_BNBase): 56 | 57 | def __init__(self, in_size: int, name: str = ""): 58 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 59 | 60 | 61 | class BatchNorm3d(_BNBase): 62 | 63 | def __init__(self, in_size: int, name: str = ""): 64 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 65 | 66 | 67 | class _ConvBase(nn.Sequential): 68 | 69 | def __init__( 70 | self, 71 | in_size, 72 | out_size, 73 | kernel_size, 74 | stride, 75 | padding, 76 | activation, 77 | bn, 78 | init, 79 | conv=None, 80 | batch_norm=None, 81 | bias=True, 82 | preact=False, 83 | name="" 84 | ): 85 | super().__init__() 86 | 87 | bias = bias and (not bn) 88 | conv_unit = conv( 89 | in_size, 90 | out_size, 91 | kernel_size=kernel_size, 92 | stride=stride, 93 | padding=padding, 94 | bias=bias 95 | ) 96 | init(conv_unit.weight) 97 | if bias: 98 | nn.init.constant_(conv_unit.bias, 0) 99 | 100 | if bn: 101 | if not preact: 102 | bn_unit = batch_norm(out_size) 103 | else: 104 | bn_unit = batch_norm(in_size) 105 | 106 | if preact: 107 | if bn: 108 | self.add_module(name + 'bn', bn_unit) 109 | 110 | if activation is not None: 111 | self.add_module(name + 'activation', activation) 112 | 113 | self.add_module(name + 'conv', conv_unit) 114 | 115 | if not preact: 116 | if bn: 117 | self.add_module(name + 'bn', bn_unit) 118 | 119 | if activation is not None: 120 | self.add_module(name + 'activation', activation) 121 | 122 | 123 | class Conv1d(_ConvBase): 124 | 125 | def __init__( 126 | self, 127 | in_size: int, 128 | out_size: int, 129 | *, 130 | kernel_size: int = 1, 131 | stride: int = 1, 132 | padding: int = 0, 133 | activation=nn.ReLU(inplace=True), 134 | bn: bool = False, 135 | init=nn.init.kaiming_normal_, 136 | bias: bool = True, 137 | preact: bool = False, 138 | name: str = "" 139 | ): 140 | super().__init__( 141 | in_size, 142 | out_size, 143 | kernel_size, 144 | stride, 145 | padding, 146 | activation, 147 | bn, 148 | init, 149 | conv=nn.Conv1d, 150 | batch_norm=BatchNorm1d, 151 | bias=bias, 152 | preact=preact, 153 | name=name 154 | ) 155 | 156 | 157 | class Conv2d(_ConvBase): 158 | 159 | def __init__( 160 | self, 161 | in_size: int, 162 | out_size: int, 163 | *, 164 | kernel_size: Tuple[int, int] = (1, 1), 165 | stride: Tuple[int, int] = (1, 1), 166 | padding: Tuple[int, int] = (0, 
0), 167 | activation=nn.ReLU(inplace=True), 168 | bn: bool = False, 169 | init=nn.init.kaiming_normal_, 170 | bias: bool = True, 171 | preact: bool = False, 172 | name: str = "" 173 | ): 174 | super().__init__( 175 | in_size, 176 | out_size, 177 | kernel_size, 178 | stride, 179 | padding, 180 | activation, 181 | bn, 182 | init, 183 | conv=nn.Conv2d, 184 | batch_norm=BatchNorm2d, 185 | bias=bias, 186 | preact=preact, 187 | name=name 188 | ) 189 | 190 | 191 | class Conv3d(_ConvBase): 192 | 193 | def __init__( 194 | self, 195 | in_size: int, 196 | out_size: int, 197 | *, 198 | kernel_size: Tuple[int, int, int] = (1, 1, 1), 199 | stride: Tuple[int, int, int] = (1, 1, 1), 200 | padding: Tuple[int, int, int] = (0, 0, 0), 201 | activation=nn.ReLU(inplace=True), 202 | bn: bool = False, 203 | init=nn.init.kaiming_normal_, 204 | bias: bool = True, 205 | preact: bool = False, 206 | name: str = "" 207 | ): 208 | super().__init__( 209 | in_size, 210 | out_size, 211 | kernel_size, 212 | stride, 213 | padding, 214 | activation, 215 | bn, 216 | init, 217 | conv=nn.Conv3d, 218 | batch_norm=BatchNorm3d, 219 | bias=bias, 220 | preact=preact, 221 | name=name 222 | ) 223 | 224 | 225 | class FC(nn.Sequential): 226 | 227 | def __init__( 228 | self, 229 | in_size: int, 230 | out_size: int, 231 | *, 232 | activation=nn.ReLU(inplace=True), 233 | bn: bool = False, 234 | init=None, 235 | preact: bool = False, 236 | name: str = "" 237 | ): 238 | super().__init__() 239 | 240 | fc = nn.Linear(in_size, out_size, bias=not bn) 241 | if init is not None: 242 | init(fc.weight) 243 | if not bn: 244 | nn.init.constant_(fc.bias, 0) 245 | 246 | if preact: 247 | if bn: 248 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 249 | 250 | if activation is not None: 251 | self.add_module(name + 'activation', activation) 252 | 253 | self.add_module(name + 'fc', fc) 254 | 255 | if not preact: 256 | if bn: 257 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 258 | 259 | if activation is not None: 260 | self.add_module(name + 'activation', activation) 261 | 262 | def set_bn_momentum_default(bn_momentum): 263 | 264 | def fn(m): 265 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 266 | m.momentum = bn_momentum 267 | 268 | return fn 269 | 270 | 271 | class BNMomentumScheduler(object): 272 | 273 | def __init__( 274 | self, model, bn_lambda, last_epoch=-1, 275 | setter=set_bn_momentum_default 276 | ): 277 | if not isinstance(model, nn.Module): 278 | raise RuntimeError( 279 | "Class '{}' is not a PyTorch nn Module".format( 280 | type(model).__name__ 281 | ) 282 | ) 283 | 284 | self.model = model 285 | self.setter = setter 286 | self.lmbd = bn_lambda 287 | 288 | self.step(last_epoch + 1) 289 | self.last_epoch = last_epoch 290 | 291 | def step(self, epoch=None): 292 | if epoch is None: 293 | epoch = self.last_epoch + 1 294 | 295 | self.last_epoch = epoch 296 | self.model.apply(self.setter(self.lmbd(epoch))) 297 | 298 | 299 | -------------------------------------------------------------------------------- /third_party/pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
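As a usage illustration of the helpers defined in pytorch_utils.py above (module import path and shapes are assumptions for the example, not code from the repository): SharedMLP stacks 1x1 Conv2d + BatchNorm + ReLU stages over per-point feature maps, and BNMomentumScheduler decays BatchNorm momentum as training progresses.

import torch

# Assuming this file is importable as pytorch_utils from the pointnet2 directory.
import pytorch_utils as pt_utils

feats = torch.randn(2, 3, 1024, 16)                 # (B, C_in, npoint, nsample)
mlp = pt_utils.SharedMLP([3, 64, 128], bn=True)     # two 1x1 Conv2d + BN + ReLU stages
out = mlp(feats)                                    # -> (2, 128, 1024, 16)

# Decay BatchNorm momentum over epochs; step() applies the new momentum to every BN layer.
bn_scheduler = pt_utils.BNMomentumScheduler(
    mlp, bn_lambda=lambda epoch: max(0.5 * 0.5 ** (epoch // 20), 0.01))
for epoch in range(3):
    bn_scheduler.step(epoch)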
5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | import os.path as osp 10 | 11 | this_dir = osp.dirname(osp.abspath(__file__)) 12 | 13 | _ext_src_root = "_ext_src" 14 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 15 | "{}/src/*.cu".format(_ext_src_root) 16 | ) 17 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 18 | 19 | setup( 20 | name='pointnet2', 21 | ext_modules=[ 22 | CUDAExtension( 23 | name='pointnet2._ext', 24 | sources=_ext_sources, 25 | extra_compile_args={ 26 | "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 27 | "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 28 | }, 29 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")], 30 | ) 31 | ], 32 | cmdclass={ 33 | 'build_ext': BuildExtension 34 | } 35 | ) 36 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/DepthContrast/b8257890c94f7c58aeb5cefeb91af031692611d6/utils/__init__.py -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import datetime 9 | import sys 10 | 11 | import torch 12 | from torch import distributed as dist 13 | 14 | 15 | class Logger(object): 16 | def __init__(self, quiet=False, log_fn=None, rank=0, prefix=""): 17 | self.rank = rank if rank is not None else 0 18 | self.quiet = quiet 19 | self.log_fn = log_fn 20 | 21 | self.prefix = "" 22 | if prefix: 23 | self.prefix = prefix + ' | ' 24 | 25 | self.file_pointers = [] 26 | if self.rank == 0: 27 | if self.quiet: 28 | open(log_fn, 'w').close() 29 | 30 | def add_line(self, content): 31 | if self.rank == 0: 32 | msg = self.prefix+content 33 | if self.quiet: 34 | fp = open(self.log_fn, 'a') 35 | fp.write(msg+'\n') 36 | fp.flush() 37 | fp.close() 38 | else: 39 | print(msg) 40 | sys.stdout.flush() 41 | 42 | 43 | class ProgressMeter(object): 44 | def __init__(self, num_batches, meters, phase, epoch=None, logger=None, tb_writter=None): 45 | self.batches_per_epoch = num_batches 46 | self.batch_fmtstr = self._get_batch_fmtstr(epoch, num_batches) 47 | self.meters = meters 48 | self.phase = phase 49 | self.epoch = epoch 50 | self.logger = logger 51 | self.tb_writter = tb_writter 52 | 53 | def display(self, batch): 54 | step = self.epoch * self.batches_per_epoch + batch 55 | date = str(datetime.datetime.now()) 56 | entries = ['{} | {} {}'.format(date, self.phase, self.batch_fmtstr.format(batch))] 57 | entries += [str(meter) for meter in self.meters] 58 | if self.logger is None: 59 | print('\t'.join(entries)) 60 | else: 61 | self.logger.add_line('\t'.join(entries)) 62 | 63 | if self.tb_writter is not None: 64 | for meter in self.meters: 65 | self.tb_writter.add_scalar('{}-batch/{}'.format(self.phase, meter.name), meter.val, step) 66 | 67 | def _get_batch_fmtstr(self, epoch, num_batches): 68 | num_digits = len(str(num_batches // 1)) 69 | fmt = '{:' + str(num_digits) + 'd}' 70 | epoch_str = '[{}]'.format(epoch) if epoch is not None else '' 71 | return epoch_str+'[' + fmt + '/' 
+ fmt.format(num_batches) + ']' 72 | 73 | def synchronize_meters(self, cur_gpu): 74 | metrics = torch.tensor([m.avg for m in self.meters]).cuda(cur_gpu) 75 | metrics_gather = [torch.ones_like(metrics) for _ in range(dist.get_world_size())] 76 | dist.all_gather(metrics_gather, metrics) 77 | 78 | metrics = torch.stack(metrics_gather).mean(0).cpu().numpy() 79 | for meter, m in zip(self.meters, metrics): 80 | meter.avg = m 81 | -------------------------------------------------------------------------------- /utils/metrics_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | from collections import deque 10 | 11 | 12 | def accuracy(output, target, topk=(1,)): 13 | """Computes the accuracy over the k top predictions for the specified values of k""" 14 | with torch.no_grad(): 15 | maxk = max(topk) 16 | batch_size = target.size(0) 17 | 18 | _, pred = output.topk(maxk, 1, True, True) 19 | pred = pred.t() 20 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 21 | 22 | res = [] 23 | for k in topk: 24 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 25 | res.append(correct_k.mul_(100.0 / batch_size)) 26 | return res 27 | 28 | 29 | class AverageMeter(object): 30 | """Computes and stores the average and current value""" 31 | def __init__(self, name, fmt=':f', window_size=0): 32 | self.name = name 33 | self.fmt = fmt 34 | self.window_size = window_size 35 | self.reset() 36 | 37 | def reset(self): 38 | if self.window_size > 0: 39 | self.q = deque(maxlen=self.window_size) 40 | self.val = 0 41 | self.avg = 0 42 | self.sum = 0 43 | self.count = 0 44 | 45 | def update(self, val, n=1): 46 | self.val = val 47 | if self.window_size > 0: 48 | self.q.append((val, n)) 49 | self.count = sum([n for v, n in self.q]) 50 | self.sum = sum([v * n for v, n in self.q]) 51 | else: 52 | self.sum += val * n 53 | self.count += n 54 | self.avg = self.sum / self.count 55 | 56 | def __str__(self): 57 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 58 | return fmtstr.format(**self.__dict__) 59 | 60 | 61 | --------------------------------------------------------------------------------
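To illustrate how the metric helpers above fit together during an evaluation loop (the tensors below are random placeholders, not repository data):

import torch

# Assuming the repository root is on PYTHONPATH.
from utils.metrics_utils import AverageMeter, accuracy

logits = torch.randn(8, 10)                 # (batch, classes)
target = torch.randint(0, 10, (8,))
acc1, acc5 = accuracy(logits, target, topk=(1, 5))   # top-1 / top-5 accuracy in percent

top1 = AverageMeter('Acc@1', ':6.2f')       # running average over the epoch
top1.update(acc1.item(), n=logits.size(0))
print(top1)                                 # e.g. "Acc@1  12.50 ( 12.50)"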