├── .gitmodules
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs
│   ├── point_vox_lidar_template.yaml
│   ├── point_vox_template.yaml
│   ├── point_within_format.yaml
│   └── point_within_lidar_template.yaml
├── criterions
│   ├── __init__.py
│   └── nce_loss_moco.py
├── data
│   ├── README.md
│   ├── redwood
│   │   ├── README.md
│   │   └── extract_pointcloud.py
│   ├── scannet
│   │   ├── README.md
│   │   ├── SensorData.py
│   │   ├── extract_pointcloud.py
│   │   └── reader.py
│   └── waymo
│       ├── README.md
│       └── extract_pointcloud.py
├── datasets
│   ├── __init__.py
│   ├── collators
│   │   ├── __init__.py
│   │   ├── point_moco_collator.py
│   │   ├── point_vox_moco_collator.py
│   │   ├── point_vox_moco_lidar_collator.py
│   │   └── vox_moco_collator.py
│   ├── depth_dataset.py
│   └── transforms
│       ├── augment3d.py
│       ├── transforms.py
│       └── voxelizer.py
├── imgs
│   └── method.jpg
├── main.py
├── models
│   ├── __init__.py
│   ├── base_ssl3d_model.py
│   └── trunks
│       ├── __init__.py
│       ├── mlp.py
│       ├── pointnet.py
│       ├── pointnet2_backbone.py
│       ├── smlp.py
│       ├── spconv
│       │   ├── lib
│       │   │   └── math_functions.py
│       │   └── models
│       │       ├── __init__.py
│       │       ├── conditional_random_fields.py
│       │       ├── model.py
│       │       ├── modules
│       │       │   ├── __init__.py
│       │       │   ├── common.py
│       │       │   ├── resnet_block.py
│       │       │   └── senet_block.py
│       │       ├── res16unet.py
│       │       ├── resnet.py
│       │       ├── resunet.py
│       │       └── wrapper.py
│       ├── spconv_backbone.py
│       └── spconv_unet.py
├── requirements.txt
├── scripts
│   ├── multinode-wrapper.py
│   ├── pretrain_node1.sh
│   ├── pretrain_node4.sh
│   └── singlenode-wrapper.py
├── third_party
│   └── pointnet2
│       ├── _ext_src
│       │   ├── include
│       │   │   ├── ball_query.h
│       │   │   ├── cuda_utils.h
│       │   │   ├── group_points.h
│       │   │   ├── interpolate.h
│       │   │   ├── sampling.h
│       │   │   └── utils.h
│       │   └── src
│       │       ├── ball_query.cpp
│       │       ├── ball_query_gpu.cu
│       │       ├── bindings.cpp
│       │       ├── group_points.cpp
│       │       ├── group_points_gpu.cu
│       │       ├── interpolate.cpp
│       │       ├── interpolate_gpu.cu
│       │       ├── sampling.cpp
│       │       └── sampling_gpu.cu
│       ├── pointnet2_modules.py
│       ├── pointnet2_test.py
│       ├── pointnet2_utils.py
│       ├── pytorch_utils.py
│       └── setup.py
└── utils
    ├── __init__.py
    ├── logger.py
    ├── main_utils.py
    └── metrics_utils.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/OpenPCDet"]
2 | path = third_party/OpenPCDet
3 | url = https://github.com/zaiweizhang/OpenPCDet
4 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
4 | Please read the [full text](https://code.fb.com/codeofconduct/)
5 | so that you can understand what actions will and will not be tolerated.
6 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | In the context of this project, we do not expect pull requests.
4 | If you find a bug, or would like to suggest an improvement, please open an issue.
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Supervised Pretraining of 3D Features on any Point-Cloud
2 | This code provides a PyTorch implementation and pretrained models for **DepthContrast**, as described in the paper [Self-Supervised Pretraining of 3D Features on any Point-Cloud](http://arxiv.org/abs/2101.02691).
3 |
4 |
5 | ![DepthContrast method overview](./imgs/method.jpg)
6 |
7 |
8 | DepthContrast is an easy-to-implement self-supervised method that works across model architectures, input data formats, indoor/outdoor 3D, and single-/multi-view 3D data.
9 | Similar to 2D contrastive approaches, DepthContrast learns representations by comparing transformations of a 3D point cloud/voxel. It does not require any multi-view information between frames, such as point-to-point correspondences, which lets the framework generalize to any 3D point cloud or voxel input.
10 | DepthContrast pretrains high-capacity models for 3D recognition tasks and leverages large-scale 3D data. It achieves state-of-the-art performance on detection and segmentation benchmarks, outperforming all prior work on detection.
11 |
12 | # Model Zoo
13 |
14 | We release our PointNet++ and MinkowskiEngine UNet models pretrained with DepthContrast in the hope that other researchers may also benefit from these pretrained backbones. Due to licensing issues, models pretrained on Waymo cannot be released. For the PointnetMSG and Spconv-UNet models, we encourage researchers to train them themselves using the provided scripts.
15 |
16 | We first provide PointNet++ models with different sizes.
17 | | network | epochs | batch-size | ScanNet Det with VoteNet | url | args |
18 | |-------------------|---------------------|---------------------|--------------------|--------------------|--------------------|
19 | | PointNet++-1x | 150 | 1024 | 61.9 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet1x/checkpoint-ep150.pth.tar) | [config](./configs/point_within_format.yaml) |
20 | | PointNet++-2x | 200 | 1024 | 63.3 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet2x/checkpoint-ep200.pth.tar) | [config](./configs/point_within_format.yaml) |
21 | | PointNet++-3x | 150 | 1024 | 64.1 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet3x/checkpoint-ep150.pth.tar) | [config](./configs/point_within_format.yaml) |
22 | | PointNet++-4x | 100 | 1024 | 63.8 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet4x/checkpoint-ep100.pth.tar) | [config](./configs/point_within_format.yaml) |
23 |
24 | The ScanNet detection evaluation metric is mAP at IoU=0.25. All of these models share the same config file; set its scale parameter to match the model size you want to use.
25 |
26 | We also provide the jointly trained models here, saved at different epochs. The epoch-400 checkpoint was used to generate the results reported in the paper.
27 |
28 | | Backbone | epochs | batch-size | url | args |
29 | |-------------------|-------------------|---------------------|--------------------|--------------------|
30 | | PointNet++ & MinkowskiEngine UNet | 300 | 1024 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet_unet_joint/checkpoint-ep300.pth.tar) | [config](./configs/point_vox_template.yaml) |
31 | | PointNet++ & MinkowskiEngine UNet | 400 | 1024 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet_unet_joint/checkpoint-ep400.pth.tar) | [config](./configs/point_vox_template.yaml) |
32 | | PointNet++ & MinkowskiEngine UNet | 500 | 1024 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet_unet_joint/checkpoint-ep500.pth.tar) | [config](./configs/point_vox_template.yaml) |
33 | | PointNet++ & MinkowskiEngine UNet | 600 | 1024 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet_unet_joint/checkpoint-ep600.pth.tar) | [config](./configs/point_vox_template.yaml) |
34 | | PointNet++ & MinkowskiEngine UNet | 700 | 1024 | [model](https://dl.fbaipublicfiles.com/DepthContrast/pointnet_unet_joint/checkpoint-ep700.pth.tar) | [config](./configs/point_vox_template.yaml) |
35 |
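
To sanity-check a downloaded checkpoint before finetuning, a minimal sketch such as the following can be used. The "model" key and the `strict=False` warm-start are assumptions about the checkpoint layout, so print the keys first and adapt as needed:

```
import torch

# Inspect a released DepthContrast checkpoint (key names are assumptions).
ckpt = torch.load("checkpoint-ep400.pth.tar", map_location="cpu")
print(ckpt.keys())

# Hypothetical: pull out the backbone weights and warm-start a downstream model,
# ignoring any heads that do not match.
state_dict = ckpt.get("model", ckpt)
# downstream_backbone.load_state_dict(state_dict, strict=False)
```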
36 | # Running DepthContrast unsupervised training
37 |
38 | ## Requirements
39 | You can use requirements.txt to set up the environment.
40 | First clone the repository (with its submodules) and install the PointNet++ modules:
41 | ```
42 | git clone --recursive https://github.com/facebookresearch/DepthContrast.git
43 | cd DepthContrast/third_party/pointnet2
44 | python setup.py install
45 | ```
46 | Then install all other packages:
47 | ```
48 | pip install -r requirements.txt
49 | ```
50 | or
51 | ```
52 | conda install --file requirements.txt
53 | ```
54 |
55 | For the voxel representation, you also have to install MinkowskiEngine. Please see [here](https://github.com/chrischoy/SpatioTemporalSegmentation) for instructions on how to install it.
56 |
57 | For lidar point cloud pretraining, we use models from [OpenPCDet](https://github.com/open-mmlab/OpenPCDet), which is included as a submodule in the third_party folder. To install OpenPCDet, you need to install [spconv](https://github.com/traveller59/spconv), which can be difficult to build and may not be compatible with MinkowskiEngine. We therefore suggest using a separate conda environment for lidar point cloud pretraining.
58 |
59 | ## Singlenode training
60 | DepthContrast is simple to implement and experiment with.
61 |
62 | To experiment with it on a single GPU (e.g. for debugging), you can run:
63 | ```
64 | python main.py /path/to/cfg/file
65 | ```
66 |
67 | For the actual training, please use the distributed trainer.
68 | For multi-GPU training on one node, you can run:
69 | ```
70 | python main.py /path/to/cfg_file --multiprocessing-distributed --world-size 1 --rank 0 --ngpus number_of_gpus
71 | ```
72 | To run it with just one GPU, set --ngpus to 1.
73 | To submit a job to a Slurm node, use ./scripts/pretrain_node1.sh. For hyperparameter tuning, edit the config files.
74 |
75 | ## Multinode training
76 | Distributed training is available via Slurm. We provide several [SBATCH scripts](./scripts) to reproduce our results.
77 | For example, to train DepthContrast on 4 nodes and 32 GPUs with a batch size of 1024 run:
78 | ```
79 | sbatch ./scripts/pretrain_node4.sh /path/to/cfg_file
80 | ```
81 | Note that you might need to remove the copyright header from the sbatch file to launch it.
82 |
83 | # Evaluating models
84 | For VoteNet finetuning, please check out this [repo](https://github.com/zaiweizhang/votenet) for more details.
85 |
86 | For H3DNet finetuning, please check out this [repo](https://github.com/zaiweizhang/H3DNet) for more details.
87 |
88 | For voxel scene segmentation finetuning, please check out this [repo](https://github.com/zaiweizhang/SpatioTemporalSegmentation) for more details.
89 |
90 | For lidar point cloud object detection finetuning, please check out this [repo](https://github.com/zaiweizhang/OpenPCDet) for more details.
91 |
92 | # Common Issues
93 | For help or issues using DepthContrast, please submit a GitHub issue.
94 |
95 | ## License
96 | See the [LICENSE](LICENSE) file for more details.
97 |
98 | ## Citation
99 | If you find this repository useful in your research, please cite:
100 |
101 | ```
102 | @article{zhang_depth_contrast,
103 | title={Self-Supervised Pretraining of 3D Features on any Point-Cloud},
104 | author={Zhang, Zaiwei and Girdhar, Rohit and Joulin, Armand and Misra, Ishan},
105 | journal={arXiv preprint arXiv:2101.02691},
106 | year={2021}
107 | }
108 | ```
109 |
--------------------------------------------------------------------------------
/configs/point_vox_lidar_template.yaml:
--------------------------------------------------------------------------------
1 | resume: false
2 | test_only: false
3 | num_workers: 10
4 |
5 | required_devices: 1
6 | no_test: false
7 | debug: false
8 | log2tb: true
9 | allow_double_bs: false
10 | seed: 0
11 | distributed: false
12 | test_freq: 10
13 | print_freq: 5
14 |
15 | dataset:
16 | DATASET_NAMES: [waymo]
17 | DATA_PATHS: ['/path/to/waymo.npy']
18 | BATCHSIZE_PER_REPLICA: 32
19 | LABEL_TYPE: sample_index
20 | DATA_TYPE: point_vox
21 | Lidar: True
22 | VOX: True
23 | POINT_TRANSFORMS:
24 | - name: randomcuboidLidar
25 | crop: 0.5
26 | npoints: 10000
27 | randcrop: 1.0
28 | aspect: 0.75
29 | - name: randomdrop
30 | crop: 0.2
31 | - name: RandomFlipLidar
32 | - name: RandomRotateLidar
33 | - name: RandomScaleLidar
34 | - name: ToTensorLidar
35 | COLLATE_FUNCTION: "point_vox_moco_collator"
36 | INPUT_KEY_NAMES: ["points", "points_moco", "vox", "vox_moco"]
37 | DROP_LAST: True
38 |
39 | optimizer:
40 | name: "sgd"
41 | weight_decay: 0.0001
42 | momentum: 0.9
43 | nesterov: False
44 | num_epochs: 1000
45 | lr:
46 | name: "cosine"
47 | base_lr: 0.12
48 | final_lr: 0.00012
49 |
50 | model:
51 | name: "PointnetMSG_UnetV2"
52 | model_dir: "checkpoints/pointnetMSG_UnetV2_within_format"
53 | model_input: ["points", "points_moco", "vox", "vox_moco"]
54 | model_feature: [["fp2"], ["fp2"], ["conv4"], ["conv4"]]
55 | Lidar: True
56 | arch_point: "pointnet_msg"
57 | args_point:
58 | use_mlp: True
59 | mlp_dim: [128, 128, 128]
60 | arch_vox: "UNetV2"
61 | args_vox:
62 | use_mlp: True
63 | mlp_dim: [128, 128, 128]
64 |
65 | loss:
66 | name: "NCELossMoco"
67 | args:
68 | two_domain: True
69 | LOSS_TYPE: NPID
70 | OTHER_INPUT: True
71 | within_format_weight0: 1.0
72 | within_format_weight1: 1.0
73 | across_format_weight0: 0.0
74 | across_format_weight1: 0.0
75 | NCE_LOSS:
76 | NORM_EMBEDDING: True
77 | TEMPERATURE: 0.1
78 | LOSS_TYPE: cross_entropy
79 | NUM_NEGATIVES: 65536
80 | EMBEDDING_DIM: 128
81 |
--------------------------------------------------------------------------------
/configs/point_vox_template.yaml:
--------------------------------------------------------------------------------
1 | resume: false
2 | test_only: false
3 | num_workers: 10
4 |
5 | required_devices: 1
6 | no_test: false
7 | debug: false
8 | log2tb: true
9 | allow_double_bs: false
10 | seed: 0
11 | distributed: false
12 | test_freq: 2
13 | print_freq: 5
14 |
15 | dataset:
16 | DATASET_NAMES: [scannet]
17 | DATA_PATHS: ['/path/to/datalist.npy']
18 | BATCHSIZE_PER_REPLICA: 32
19 | LABEL_TYPE: sample_index
20 | DATA_TYPE: point_vox
21 | VOX: True
22 | POINT_TRANSFORMS:
23 | - name: randomcuboid
24 | crop: 0.5
25 | npoints: 10000
26 | randcrop: 1.0
27 | aspect: 0.75
28 | - name: randomdrop
29 | crop: 0.2
30 | - name: multiscale
31 | - name: RandomFlip
32 | - name: RandomRotateAll
33 | - name: RandomScale
34 | - name: ToTensor
35 | COLLATE_FUNCTION: "point_vox_moco_collator"
36 | INPUT_KEY_NAMES: ["points", "points_moco", "vox", "vox_moco"]
37 | DROP_LAST: True
38 |
39 | optimizer:
40 | name: "sgd"
41 | weight_decay: 0.0001
42 | momentum: 0.9
43 | nesterov: False
44 | num_epochs: 1000
45 | lr:
46 | name: "cosine"
47 | base_lr: 0.12
48 | final_lr: 0.00012
49 |
50 | model:
51 | name: "Pointnet1X_Unet256"
52 | model_dir: "checkpoints/pointnet1x_Unet256/"
53 | model_input: ["points", "points_moco", "vox", "vox_moco"]
54 | model_feature: [["fp2"], ["fp2"], ["plane7"], ["plane7"]]
55 | arch_point: "pointnet"
56 | args_point:
57 | scale: 1
58 | use_mlp: True
59 | mlp_dim: [512, 512, 128]
60 | arch_vox: "unet"
61 | args_vox:
62 | use_mlp: True
63 | mlp_dim: [256, 256, 128]
64 |
65 | loss:
66 | name: "NCELossMoco"
67 | args:
68 | two_domain: True
69 | LOSS_TYPE: NPID,CMC
70 | OTHER_INPUT: True
71 | within_format_weight0: 0.5
72 | within_format_weight1: 0.5
73 | across_format_weight0: 0.5
74 | across_format_weight1: 0.5
75 | NCE_LOSS:
76 | NORM_EMBEDDING: True
77 | TEMPERATURE: 0.1
78 | LOSS_TYPE: cross_entropy
79 | NUM_NEGATIVES: 131072
80 | EMBEDDING_DIM: 128
81 |
--------------------------------------------------------------------------------
/configs/point_within_format.yaml:
--------------------------------------------------------------------------------
1 | resume: false
2 | test_only: false
3 | num_workers: 10
4 |
5 | required_devices: 1
6 | no_test: false
7 | debug: false
8 | log2tb: true
9 | allow_double_bs: false
10 | seed: 0
11 | distributed: false
12 | test_freq: 10
13 | print_freq: 5
14 |
15 | dataset:
16 | DATASET_NAMES: [scannet]
17 | DATA_PATHS: ['/path/to/datalist.npy']
18 | DATA_LIMIT: -1
19 | BATCHSIZE_PER_REPLICA: 32
20 | LABEL_TYPE: sample_index
21 | DATA_TYPE: points
22 | VOX: False
23 | POINT_TRANSFORMS:
24 | - name: randomcuboid
25 | crop: 0.5
26 | npoints: 10000
27 | randcrop: 1.0
28 | aspect: 0.75
29 | - name: randomdrop
30 | crop: 0.2
31 | - name: multiscale
32 | - name: RandomFlip
33 | - name: RandomRotateAll
34 | - name: RandomScale
35 | - name: ToTensor
36 | COLLATE_FUNCTION: "point_moco_collator"
37 | INPUT_KEY_NAMES: ["points", "points_moco"]
38 | DROP_LAST: True
39 |
40 | optimizer:
41 | name: "sgd"
42 | weight_decay: 0.0001
43 | momentum: 0.9
44 | nesterov: False
45 | num_epochs: 1000
46 | lr:
47 | name: "cosine"
48 | base_lr: 0.12
49 | final_lr: 0.00012
50 |
51 | model:
52 | name: "Pointnet1X"
53 | model_dir: "checkpoints/pointnet1x_within_format"
54 | model_input: ["points", "points_moco"]
55 | model_feature: [["fp2"], ["fp2"]]
56 | arch_point: "pointnet"
57 | args_point:
58 | scale: 1
59 | use_mlp: True
60 | mlp_dim: [512, 512, 128]
61 | loss:
62 | name: "NCELossMoco"
63 | args:
64 | two_domain: True
65 | LOSS_TYPE: NPID
66 | OTHER_INPUT: False
67 | within_format_weight0: 1.0
68 | within_format_weight1: 0.0
69 | across_format_weight0: 0.0
70 | across_format_weight1: 0.0
71 | NCE_LOSS:
72 | NORM_EMBEDDING: True
73 | TEMPERATURE: 0.1
74 | LOSS_TYPE: cross_entropy
75 | NUM_NEGATIVES: 131072
76 | EMBEDDING_DIM: 128
77 |
--------------------------------------------------------------------------------
/configs/point_within_lidar_template.yaml:
--------------------------------------------------------------------------------
1 | resume: false
2 | test_only: false
3 | num_workers: 10
4 |
5 | required_devices: 1
6 | no_test: false
7 | debug: false
8 | log2tb: true
9 | allow_double_bs: false
10 | seed: 0
11 | distributed: false
12 | test_freq: 10
13 | print_freq: 5
14 |
15 | dataset:
16 | DATASET_NAMES: [waymo]
17 | DATA_PATHS: ['/path/to/waymo.npy']
18 | BATCHSIZE_PER_REPLICA: 8
19 | LABEL_TYPE: sample_index
20 | DATA_TYPE: points
21 | Lidar: True
22 | VOX: False
23 | POINT_TRANSFORMS:
24 | - name: randomcuboidLidar
25 | crop: 0.5
26 | npoints: 10000
27 | randcrop: 1.0
28 | aspect: 0.75
29 | - name: randomdrop
30 | crop: 0.2
31 | - name: RandomFlipLidar
32 | - name: RandomRotateLidar
33 | - name: RandomScaleLidar
34 | - name: ToTensorLidar
35 | COLLATE_FUNCTION: "point_moco_collator"
36 | INPUT_KEY_NAMES: ["points", "points_moco"]
37 | DROP_LAST: True
38 |
39 | optimizer:
40 | name: "sgd"
41 | weight_decay: 0.0001
42 | momentum: 0.9
43 | nesterov: False
44 | num_epochs: 1000
45 | lr:
46 | name: "cosine"
47 | base_lr: 0.12
48 | final_lr: 0.00012
49 |
50 | model:
51 | name: "PointnetMSG"
52 | model_dir: "checkpoints/pointnetMSG_within_format"
53 | model_input: ["points", "points_moco"]
54 | model_feature: [["fp2"], ["fp2"]]
55 | Lidar: True
56 | VOX: False
57 | arch_point: "pointnet_msg"
58 | args_point:
59 | use_mlp: True
60 | mlp_dim: [128, 128, 128]
61 |
62 | loss:
63 | name: "NCELossMoco"
64 | args:
65 | two_domain: True
66 | LOSS_TYPE: NPID
67 | OTHER_INPUT: False
68 | within_format_weight0: 1.0
69 | within_format_weight1: 0.0
70 | across_format_weight0: 0.0
71 | across_format_weight1: 0.0
72 | NCE_LOSS:
73 | NORM_EMBEDDING: True
74 | TEMPERATURE: 0.1
75 | LOSS_TYPE: cross_entropy
76 | NUM_NEGATIVES: 65536
77 | EMBEDDING_DIM: 128
78 |
--------------------------------------------------------------------------------
/criterions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | from .nce_loss_moco import *
8 |
--------------------------------------------------------------------------------
/criterions/nce_loss_moco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 | import logging
9 | import math
10 | import pprint
11 |
12 | import numpy as np
13 | import torch
14 | from torch import nn
15 |
16 | # utils
17 | @torch.no_grad()
18 | def concat_all_gather(tensor):
19 | """
20 | Performs all_gather operation on the provided tensors.
21 | *** Warning ***: torch.distributed.all_gather has no gradient.
22 | """
23 | if not (torch.distributed.is_initialized()):
24 | return tensor
25 |
26 | tensors_gather = [torch.ones_like(tensor)
27 | for _ in range(torch.distributed.get_world_size())]
28 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
29 |
30 | output = torch.cat(tensors_gather, dim=0)
31 | return output
32 |
33 | class NCELossMoco(nn.Module):
34 | """
35 | Distributed version of the NCE loss. It performs an "all_gather" to gather
36 | the allocated buffers (such as the memory bank) onto a single GPU. For this, the PyTorch
37 | distributed backend is used. If using NCCL, one must ensure that all the buffers are on the GPU.
38 | This class supports training using both NCE and CrossEntropy (InfoNCE).
39 | """
40 |
41 | def __init__(self, config):
42 | super(NCELossMoco, self).__init__()
43 |
44 | assert config["NCE_LOSS"]["LOSS_TYPE"] in [
45 | "cross_entropy",
46 | ], f"Supported types are cross_entropy."
47 |
48 | self.loss_type = config["NCE_LOSS"]["LOSS_TYPE"]
49 | self.loss_list = config["LOSS_TYPE"].split(",")
50 | self.other_queue = config["OTHER_INPUT"]
51 |
52 | self.npid0_w = float(config["within_format_weight0"])
53 | self.npid1_w = float(config["within_format_weight1"])
54 | self.cmc0_w = float(config["across_format_weight0"])
55 | self.cmc1_w = float(config["across_format_weight1"])
56 |
57 | self.K = int(config["NCE_LOSS"]["NUM_NEGATIVES"])
58 | self.dim = int(config["NCE_LOSS"]["EMBEDDING_DIM"])
59 | self.T = float(config["NCE_LOSS"]["TEMPERATURE"])
60 |
61 | self.register_buffer("queue", torch.randn(self.dim, self.K))
62 | self.queue = nn.functional.normalize(self.queue, dim=0)
63 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
64 |
65 | if self.other_queue:
66 | self.register_buffer("queue_other", torch.randn(self.dim, self.K))
67 | self.queue_other = nn.functional.normalize(self.queue_other, dim=0)
68 | self.register_buffer("queue_other_ptr", torch.zeros(1, dtype=torch.long))
69 |
70 | # cross-entropy loss. Also called InfoNCE
71 | self.xe_criterion = nn.CrossEntropyLoss()
72 |
73 | # other constants
74 | self.normalize_embedding = config["NCE_LOSS"]["NORM_EMBEDDING"]
75 |
76 | @classmethod
77 | def from_config(cls, config):
78 | return cls(config)
79 |
80 | @torch.no_grad()
81 | def _dequeue_and_enqueue(self, keys, okeys=None):
82 | # gather keys before updating queue
83 | keys = concat_all_gather(keys)
84 |
85 | batch_size = keys.shape[0]
86 |
87 | ptr = int(self.queue_ptr)
88 | assert self.K % batch_size == 0 # for simplicity
89 |
90 | # replace the keys at ptr (dequeue and enqueue)
91 | self.queue[:, ptr:ptr + batch_size] = torch.transpose(keys, 0, 1)
92 | ptr = (ptr + batch_size) % self.K # move pointer
93 |
94 | self.queue_ptr[0] = ptr
95 |
96 | if self.other_queue:
97 | # gather keys before updating queue
98 | okeys = concat_all_gather(okeys)
99 |
100 | other_ptr = int(self.queue_other_ptr)
101 |
102 | # replace the keys at ptr (dequeue and enqueue)
103 | self.queue_other[:, other_ptr:other_ptr + batch_size] = torch.transpose(okeys, 0, 1)#okeys.T
104 | other_ptr = (other_ptr + batch_size) % self.K # move pointer
105 |
106 | self.queue_other_ptr[0] = other_ptr
107 |
108 | def forward(self, output):
109 | assert isinstance(
110 | output, list
111 | ), "Model output should be a list of tensors. Got Type {}".format(type(output))
112 |
113 | if self.normalize_embedding:
114 | normalized_output1 = nn.functional.normalize(output[0], dim=1, p=2)
115 | normalized_output2 = nn.functional.normalize(output[1], dim=1, p=2)
116 | if self.other_queue:
117 | normalized_output3 = nn.functional.normalize(output[2], dim=1, p=2)
118 | normalized_output4 = nn.functional.normalize(output[3], dim=1, p=2)
119 |
120 | # positive logits: Nx1
121 | l_pos = torch.einsum('nc,nc->n', [normalized_output1, normalized_output2]).unsqueeze(-1)
122 |
123 | # negative logits: NxK
124 | l_neg = torch.einsum('nc,ck->nk', [normalized_output1, self.queue.clone().detach()])
125 |
126 | # logits: Nx(1+K)
127 | logits = torch.cat([l_pos, l_neg], dim=1)
128 |
129 | # apply temperature
130 | logits /= self.T
131 |
132 | if self.other_queue:
133 |
134 | l_pos_p2i = torch.einsum('nc,nc->n', [normalized_output1, normalized_output4]).unsqueeze(-1)
135 | l_neg_p2i = torch.einsum('nc,ck->nk', [normalized_output1, self.queue_other.clone().detach()])
136 | logits_p2i = torch.cat([l_pos_p2i, l_neg_p2i], dim=1)
137 | logits_p2i /= self.T
138 |
139 |
140 | l_pos_i2p = torch.einsum('nc,nc->n', [normalized_output3, normalized_output2]).unsqueeze(-1)
141 | l_neg_i2p = torch.einsum('nc,ck->nk', [normalized_output3, self.queue.clone().detach()])
142 | logits_i2p = torch.cat([l_pos_i2p, l_neg_i2p], dim=1)
143 | logits_i2p /= self.T
144 |
145 |
146 | l_pos_other = torch.einsum('nc,nc->n', [normalized_output3, normalized_output4]).unsqueeze(-1)
147 | l_neg_other = torch.einsum('nc,ck->nk', [normalized_output3, self.queue_other.clone().detach()])
148 | logits_other = torch.cat([l_pos_other, l_neg_other], dim=1)
149 | logits_other /= (self.T)
150 |
151 | if self.other_queue:
152 | self._dequeue_and_enqueue(normalized_output2, okeys=normalized_output4)
153 | else:
154 | self._dequeue_and_enqueue(normalized_output2)
155 |
156 |
157 | labels = torch.zeros(
158 | logits.shape[0], device=logits.device, dtype=torch.int64
159 | )
160 |
161 | loss_npid = self.xe_criterion(torch.squeeze(logits), labels)
162 |
163 | loss_npid_other = torch.tensor(0)
164 | loss_cmc_p2i = torch.tensor(0)
165 | loss_cmc_i2p = torch.tensor(0)
166 |
167 | if self.other_queue:
168 | loss_cmc_p2i = self.xe_criterion(torch.squeeze(logits_p2i), labels)
169 | loss_cmc_i2p = self.xe_criterion(torch.squeeze(logits_i2p), labels)
170 | loss_npid_other = self.xe_criterion(torch.squeeze(logits_other), labels)
171 |
172 | curr_loss = 0
173 | for ltype in self.loss_list:
174 | if ltype == "CMC":
175 | curr_loss += loss_cmc_p2i * self.cmc0_w + loss_cmc_i2p * self.cmc1_w
176 | elif ltype == "NPID":
177 | curr_loss += loss_npid * self.npid0_w
178 | curr_loss += loss_npid_other * self.npid1_w
179 | else:
180 | curr_loss = 0
181 | curr_loss += loss_npid * self.npid0_w
182 |
183 | loss = curr_loss
184 |
185 | return loss, [loss_npid, loss_npid_other, loss_cmc_p2i, loss_cmc_i2p]
186 |
187 | def __repr__(self):
188 | repr_dict = {
189 | "name": self._get_name(),
190 | "loss_type": self.loss_type,
191 | }
192 | return pprint.pformat(repr_dict, indent=2)
193 |
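
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): a single-process
# smoke test of NCELossMoco mirroring configs/point_within_format.yaml, with
# the negative queue shrunk from 131072 to 4096 entries so it runs on CPU.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_config = {
        "LOSS_TYPE": "NPID",
        "OTHER_INPUT": False,
        "within_format_weight0": 1.0,
        "within_format_weight1": 0.0,
        "across_format_weight0": 0.0,
        "across_format_weight1": 0.0,
        "NCE_LOSS": {
            "LOSS_TYPE": "cross_entropy",
            "NORM_EMBEDDING": True,
            "TEMPERATURE": 0.1,
            "NUM_NEGATIVES": 4096,   # reduced from 131072 for the demo
            "EMBEDDING_DIM": 128,
        },
    }
    criterion = NCELossMoco(demo_config)
    query = torch.randn(8, 128)  # embeddings from the online encoder
    key = torch.randn(8, 128)    # embeddings from the momentum encoder
    loss, loss_terms = criterion([query, key])
    print(loss.item(), [float(t) for t in loss_terms])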
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ### Prepare Dataset
2 |
3 | We use scannet and redwood for Pointnet++ and MinkowskiEngine UNet pretraining.
4 | We use waymo for PointnetMSG and Spconv-UNet model pretraining.
5 |
6 | Please see the specific instructions under each folder for how to generate the training data.
7 |
8 | Once you have generated the training data, the framework just takes a .npy file which consists of a list of paths to the extracted pointclouds:
9 |
10 | [/path/to/pt1, /path/to/pt2, path/to/pt3, ..... ]
11 |
12 |
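
As an illustration, such a datalist can be generated with NumPy; the glob pattern below is a placeholder for wherever extract_pointcloud.py wrote the per-frame .npy files:

```
import glob
import numpy as np

# Placeholder path: point this at the folder produced by extract_pointcloud.py.
paths = sorted(glob.glob("/path/to/extracted_pointclouds/**/*.npy", recursive=True))
np.save("datalist.npy", paths)  # reference this file via DATA_PATHS in the configs
```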
--------------------------------------------------------------------------------
/data/redwood/README.md:
--------------------------------------------------------------------------------
1 | ### Prepare Dataset
2 |
3 | 1. Download Redwood data [HERE](https://github.com/intel-isl/redwood-3dscan).
4 |
5 | 2. Extract the pointclouds of the desired classes using extract_pointcloud.py
6 |
7 | python extract_pointcloud.py /path/to/extracted_data /path/to/extracted_pointcloud_visualization /path/to/extracted_pointclouds redwood_datalist.npy
8 |
9 | The visualizations are optional.
10 |
11 | 3. In our experiment, we use the following 10 classes:
12 |
13 | car, chair, table, bench, bicycle, plant, playground, sculpture, sign, trash_container
14 |
15 | 4. You will need to concatenate the datalist files if you use data from multiple categories.
16 |
--------------------------------------------------------------------------------
/data/redwood/extract_pointcloud.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import open3d as o3d
8 | import cv2
9 | import sys
10 | import numpy as np
11 | from glob import glob
12 | import os
13 | import zipfile
14 |
15 | def pc2obj(pc, filepath='test.obj'):
16 | pc = pc.T
17 | nverts = pc.shape[1]
18 | with open(filepath, 'w') as f:
19 | f.write("# OBJ file\n")
20 | for v in range(nverts):
21 | f.write("v %.4f %.4f %.4f\n" % (pc[0,v],pc[1,v],pc[2,v]))
22 |
23 | nump = 50000 ### Number of points in the point clouds
24 | scenelist = glob(sys.argv[1]+"*.zip") ### Path to the zip files
25 | datalist = []
26 |
27 | for scene in scenelist:
28 | if os.path.exists(sys.argv[3]+scene.split("/")[-1].split(".")[0]):
29 | continue
30 |
31 | os.system("rm -rf test/")
32 | with zipfile.ZipFile(scene, 'r') as zip_ref:
33 | zip_ref.extractall("test")
34 |
35 | os.system("ls test/rgb/ > rgblist")
36 | os.system("ls test/depth/ > depthlist")
37 |
38 | rgb_path = {}
39 | depth_path = {}
40 | rgblist = open("rgblist", "r")
41 | depthlist = open("depthlist", "r")
42 | for line in rgblist:
43 | seqname = line.split("-")[0]
44 | if seqname in rgb_path:
45 | print (line)
46 | else:
47 | rgb_path[seqname] = line.strip("\n")
48 |
49 | for line in depthlist:
50 | seqname = line.split("-")[0]
51 | if seqname in depth_path:
52 | print (line)
53 | else:
54 | depth_path[seqname] = line.strip("\n")
55 |
56 | intrin_cam = o3d.camera.PinholeCameraIntrinsic()
57 | #print (scene)
58 | if not os.path.exists(sys.argv[2]+scene.split("/")[-1].split(".")[0]):
59 | os.mkdir(sys.argv[2]+scene.split("/")[-1].split(".")[0])
60 | if not os.path.exists(sys.argv[3]+scene.split("/")[-1].split(".")[0]):
61 | os.mkdir(sys.argv[3]+scene.split("/")[-1].split(".")[0])
62 |
63 | success = 0
64 |
65 | if len(rgb_path) <= len(depth_path):
66 | framelist = rgb_path.keys()
67 | else:
68 | framelist = depth_path.keys()
69 | counter = 15
70 | for frame in framelist:
71 | counter += 1
72 | if counter < 15: ### Set the sampling rate here
73 | continue
74 | depth = "test/depth/"+depth_path[frame]
75 |
76 | rgbpath = "test/rgb/"+rgb_path[frame]#scene+"/matterport_color_images/"+room_name+"_i%d_%d.jpg" % (cam_num, frame_num)
77 |
78 | depth_im = cv2.imread(depth, -1)
79 |
80 | try:
81 | o3d_depth = o3d.geometry.Image(depth_im)
82 | rgb_im = cv2.imread(rgbpath)
83 | o3d_rgb = o3d.geometry.Image(rgb_im)
84 | o3d_rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(o3d_rgb, o3d_depth, depth_scale=1000.0, depth_trunc=1000.0, convert_rgb_to_intensity=False)
85 | except:
86 | print (frame)
87 | continue
88 |
89 | if (depth_im.shape[1] != 640) or (depth_im.shape[0] != 480):
90 | print (frame)
91 | continue
92 | intrin_cam.set_intrinsics(width=depth_im.shape[1], height=depth_im.shape[0], fx=525.0, fy=525.0, cx=319.5, cy=239.5)
93 |
94 | pts = o3d.geometry.PointCloud.create_from_rgbd_image(o3d_rgbd, intrin_cam, np.eye(4))
95 |
96 | if len(np.array(pts.points)) < 100:
97 | continue
98 |
99 | if len(np.array(pts.points)) >= nump:
100 | sel_idx = np.random.choice(len(np.array(pts.points)), nump, replace=False)
101 | else:
102 | sel_idx = np.random.choice(len(np.array(pts.points)), nump, replace=True)
103 | temp = np.array(pts.points)[sel_idx]
104 |
105 | if np.isnan(np.sum(temp)):
106 | continue
107 |
108 | color_points = np.array(pts.colors)[sel_idx]
109 | color_points[:,[0,1,2]] = color_points[:,[2,1,0]]
110 |
111 | pts.points = o3d.utility.Vector3dVector(temp)
112 | pts.colors = o3d.utility.Vector3dVector(color_points)
113 | data = np.concatenate([temp,color_points], axis=1)
114 |
115 | o3d.io.write_point_cloud(sys.argv[2]+scene.split("/")[-1].split(".")[0]+"/"+frame+".ply", pts)
116 | np.save(sys.argv[3]+scene.split("/")[-1].split(".")[0]+"/"+frame+".npy", data)
117 |
118 | datalist.append(os.path.abspath(sys.argv[3]+scene.split("/")[-1].split(".")[0]+"/"+frame+".npy"))
119 |
120 | counter = 0
121 | success += 1
122 |
123 | print (success)
124 |
125 | np.save(sys.argv[4], datalist)
126 |
--------------------------------------------------------------------------------
/data/scannet/README.md:
--------------------------------------------------------------------------------
1 | ### Prepare Dataset
2 |
3 | 1. Download ScanNet data [HERE](https://github.com/ScanNet/ScanNet).
4 |
5 | 2. Please use python2.7 to run the following command to extract depth images, color images, and camera intrinsics:
6 |
7 | ```
8 | python2.7 reader.py --scans_path path/to/scannet_v2/scans/ --output_path path/to/extracted_data/ --export_depth_images --export_color_images --export_poses --export_intrinsics
9 | ```
10 |
11 | 3. Extract point clouds using extract_pointcloud.py
12 |
13 | ```
14 | python extract_pointcloud.py path/to/extracted_data/depth/ path/to/extracted_pointcloud_visualization/ path/to/extracted_pointclouds/ scannet_datalist.npy
15 | ```
16 |
17 | The visualizations are optional.
--------------------------------------------------------------------------------
/data/scannet/SensorData.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Code borrowed from: https://github.com/ScanNet/ScanNet
9 | """
10 | import os, struct
11 | import numpy as np
12 | import zlib
13 | import imageio
14 | import cv2
15 |
16 | COMPRESSION_TYPE_COLOR = {-1:'unknown', 0:'raw', 1:'png', 2:'jpeg'}
17 | COMPRESSION_TYPE_DEPTH = {-1:'unknown', 0:'raw_ushort', 1:'zlib_ushort', 2:'occi_ushort'}
18 |
19 | class RGBDFrame():
20 |
21 | def load(self, file_handle):
22 | self.camera_to_world = np.asarray(struct.unpack('f'*16, file_handle.read(16*4)), dtype=np.float32).reshape(4, 4)
23 | self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0]
24 | self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0]
25 | self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0]
26 | self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0]
27 | self.color_data = ''.join(struct.unpack('c'*self.color_size_bytes, file_handle.read(self.color_size_bytes)))
28 | self.depth_data = ''.join(struct.unpack('c'*self.depth_size_bytes, file_handle.read(self.depth_size_bytes)))
29 |
30 |
31 | def decompress_depth(self, compression_type):
32 | if compression_type == 'zlib_ushort':
33 | return self.decompress_depth_zlib()
34 | else:
35 | raise
36 |
37 |
38 | def decompress_depth_zlib(self):
39 | return zlib.decompress(self.depth_data)
40 |
41 |
42 | def decompress_color(self, compression_type):
43 | if compression_type == 'jpeg':
44 | return self.decompress_color_jpeg()
45 | else:
46 | raise
47 |
48 |
49 | def decompress_color_jpeg(self):
50 | return imageio.imread(self.color_data)
51 |
52 |
53 | class SensorData:
54 |
55 | def __init__(self, filename):
56 | self.version = 4
57 | self.load(filename)
58 |
59 |
60 | def load(self, filename):
61 | with open(filename, 'rb') as f:
62 | version = struct.unpack('I', f.read(4))[0]
63 | assert self.version == version
64 | strlen = struct.unpack('Q', f.read(8))[0]
65 | self.sensor_name = ''.join(struct.unpack('c'*strlen, f.read(strlen)))
66 | self.intrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4)
67 | self.extrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4)
68 | self.intrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4)
69 | self.extrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4)
70 | self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack('i', f.read(4))[0]]
71 | self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack('i', f.read(4))[0]]
72 | self.color_width = struct.unpack('I', f.read(4))[0]
73 | self.color_height = struct.unpack('I', f.read(4))[0]
74 | self.depth_width = struct.unpack('I', f.read(4))[0]
75 | self.depth_height = struct.unpack('I', f.read(4))[0]
76 | self.depth_shift = struct.unpack('f', f.read(4))[0]
77 | num_frames = struct.unpack('Q', f.read(8))[0]
78 | self.frames = []
79 | for i in range(num_frames):
80 | frame = RGBDFrame()
81 | frame.load(f)
82 | self.frames.append(frame)
83 |
84 |
85 | def export_depth_images(self, output_path, image_size=None, frame_skip=1):
86 | if not os.path.exists(output_path):
87 | os.makedirs(output_path)
88 | print ('exporting', len(self.frames)//frame_skip, ' depth frames to', output_path)
89 | for f in range(0, len(self.frames), frame_skip):
90 | depth_data = self.frames[f].decompress_depth(self.depth_compression_type)
91 | depth = np.fromstring(depth_data, dtype=np.uint16).reshape(self.depth_height, self.depth_width)
92 | if image_size is not None:
93 | depth = cv2.resize(depth, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST)
94 | imageio.imwrite(os.path.join(output_path, str(f) + '.png'), depth)
95 |
96 |
97 | def export_color_images(self, output_path, image_size=None, frame_skip=1):
98 | if not os.path.exists(output_path):
99 | os.makedirs(output_path)
100 | print ('exporting', len(self.frames)//frame_skip, 'color frames to', output_path)
101 | for f in range(0, len(self.frames), frame_skip):
102 | color = self.frames[f].decompress_color(self.color_compression_type)
103 | if image_size is not None:
104 | color = cv2.resize(color, (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST)
105 | imageio.imwrite(os.path.join(output_path, str(f) + '.jpg'), color)
106 |
107 |
108 | def save_mat_to_file(self, matrix, filename):
109 | with open(filename, 'w') as f:
110 | for line in matrix:
111 | np.savetxt(f, line[np.newaxis], fmt='%f')
112 |
113 |
114 | def export_poses(self, output_path, frame_skip=1):
115 | if not os.path.exists(output_path):
116 | os.makedirs(output_path)
117 | print ('exporting', len(self.frames)//frame_skip, 'camera poses to', output_path)
118 | for f in range(0, len(self.frames), frame_skip):
119 | self.save_mat_to_file(self.frames[f].camera_to_world, os.path.join(output_path, str(f) + '.txt'))
120 |
121 |
122 | def export_intrinsics(self, output_path):
123 | if not os.path.exists(output_path):
124 | os.makedirs(output_path)
125 | print ('exporting camera intrinsics to', output_path)
126 | self.save_mat_to_file(self.intrinsic_color, os.path.join(output_path, 'intrinsic_color.txt'))
127 | self.save_mat_to_file(self.extrinsic_color, os.path.join(output_path, 'extrinsic_color.txt'))
128 | self.save_mat_to_file(self.intrinsic_depth, os.path.join(output_path, 'intrinsic_depth.txt'))
129 | self.save_mat_to_file(self.extrinsic_depth, os.path.join(output_path, 'extrinsic_depth.txt'))
130 |
--------------------------------------------------------------------------------
/data/scannet/extract_pointcloud.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import open3d as o3d
8 | import cv2
9 | import sys
10 | import numpy as np
11 | from glob import glob
12 | import os
13 |
14 | def pc2obj(pc, filepath='test.obj'):
15 | pc = pc.T
16 | nverts = pc.shape[1]
17 | with open(filepath, 'w') as f:
18 | f.write("# OBJ file\n")
19 | for v in range(nverts):
20 | f.write("v %.4f %.4f %.4f\n" % (pc[0,v],pc[1,v],pc[2,v]))
21 |
22 | nump = 50000 ### Number of points from the depth scans
23 | scenelist = glob(sys.argv[1]+"*") ### Input path to the data
24 |
25 | if not os.path.exists(sys.argv[2]):
26 | os.mkdir(sys.argv[2])
27 | if not os.path.exists(sys.argv[3]):
28 | os.mkdir(sys.argv[3])
29 |
30 | datalist = []
31 | for scene in scenelist:
32 | framelist = glob(scene+"/*")
33 | intrinsic = np.loadtxt(scene.replace("depth", "intrinsic")+"/intrinsic_depth.txt")
34 | intrin_cam = o3d.camera.PinholeCameraIntrinsic()
35 | print (scene)
36 | if not os.path.exists(sys.argv[2]+scene.split("/")[-1]):
37 | os.mkdir(sys.argv[2]+scene.split("/")[-1])
38 | if not os.path.exists(sys.argv[3]+scene.split("/")[-1]):
39 | os.mkdir(sys.argv[3]+scene.split("/")[-1])
40 |
41 | success = 0
42 | counter = 10
43 | for fileidx in range(len(framelist)):
44 | counter += 1
45 | if counter < 10:
46 | continue
47 | frame = scene+"/%d.png"%fileidx
48 | depth = frame
49 | rgbpath = frame.replace("depth", "color").replace(".png", ".jpg")
50 |
51 | depth_im = cv2.imread(depth, -1)
52 | try:
53 | o3d_depth = o3d.geometry.Image(depth_im)
54 | rgb_im = cv2.resize(cv2.imread(rgbpath), (depth_im.shape[1], depth_im.shape[0]))
55 | o3d_rgb = o3d.geometry.Image(rgb_im)
56 | o3d_rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(o3d_rgb, o3d_depth, depth_scale=1000.0, depth_trunc=1000.0, convert_rgb_to_intensity=False)
57 | except:
58 | print (frame)
59 | continue
60 |
61 | intrin_cam.set_intrinsics(width=depth_im.shape[1], height=depth_im.shape[0], fx=intrinsic[1,1], fy=intrinsic[0,0], cx=intrinsic[1,2], cy=intrinsic[0,2])
62 |
63 | pts = o3d.geometry.PointCloud.create_from_rgbd_image(o3d_rgbd, intrin_cam, np.eye(4))
64 |
65 | if len(np.array(pts.points)) >= nump:
66 | sel_idx = np.random.choice(len(np.array(pts.points)), nump, replace=False)
67 | else:
68 | sel_idx = np.random.choice(len(np.array(pts.points)), nump, replace=True)
69 | temp = np.array(pts.points)[sel_idx]
70 |
71 | color_points = np.array(pts.colors)[sel_idx]
72 | color_points[:,[0,1,2]] = color_points[:,[2,1,0]]
73 |
74 | pts.points = o3d.utility.Vector3dVector(temp)
75 | pts.colors = o3d.utility.Vector3dVector(color_points)
76 | data = np.concatenate([temp,color_points], axis=1)
77 |
78 | o3d.io.write_point_cloud(sys.argv[2]+scene.split("/")[-1]+"/"+frame.split("/")[-1].split(".")[0]+".ply", pts)
79 | np.save(sys.argv[3]+scene.split("/")[-1]+"/"+frame.split("/")[-1].split(".")[0]+".npy", data)
80 | datalist.append(os.path.abspath(sys.argv[3]+scene.split("/")[-1]+"/"+frame.split("/")[-1].split(".")[0]+".npy"))
81 | counter = 0
82 | success += 1
83 | print (success)
84 |
85 | np.save(sys.argv[4], datalist)
86 |
--------------------------------------------------------------------------------
/data/scannet/reader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import os, sys
9 |
10 | from SensorData import SensorData
11 |
12 | from glob import glob
13 |
14 | # params
15 | parser = argparse.ArgumentParser()
16 | # data paths
17 | parser.add_argument('--scans_path', required=True, help='path to scans folder')
18 | parser.add_argument('--output_path', required=True, help='path to output folder')
19 | parser.add_argument('--export_depth_images', dest='export_depth_images', action='store_true')
20 | parser.add_argument('--export_color_images', dest='export_color_images', action='store_true')
21 | parser.add_argument('--export_poses', dest='export_poses', action='store_true')
22 | parser.add_argument('--export_intrinsics', dest='export_intrinsics', action='store_true')
23 | parser.set_defaults(export_depth_images=False, export_color_images=False, export_poses=False, export_intrinsics=False)
24 |
25 | opt = parser.parse_args()
26 | print(opt)
27 |
28 |
29 | def main():
30 | scans = glob(opt.scans_path+"/*")
31 | for scan in scans:
32 | scenename = scan.split("/")[-1]
33 | filename = os.path.join(scan, scenename+".sens")
34 | if not os.path.exists(opt.output_path):
35 | os.makedirs(opt.output_path)
36 | os.makedirs(os.path.join(opt.output_path, 'depth'))
37 | os.makedirs(os.path.join(opt.output_path, 'color'))
38 | os.makedirs(os.path.join(opt.output_path, 'pose'))
39 | os.makedirs(os.path.join(opt.output_path, 'intrinsic'))
40 | # load the data
41 | sys.stdout.write('loading %s...' % filename)
42 | sd = SensorData(filename)
43 | sys.stdout.write('loaded!\n')
44 | if opt.export_depth_images:
45 | sd.export_depth_images(os.path.join(opt.output_path, 'depth', scenename))
46 | if opt.export_color_images:
47 | sd.export_color_images(os.path.join(opt.output_path, 'color', scenename))
48 | if opt.export_poses:
49 | sd.export_poses(os.path.join(opt.output_path, 'pose', scenename))
50 | if opt.export_intrinsics:
51 | sd.export_intrinsics(os.path.join(opt.output_path, 'intrinsic', scenename))
52 |
53 |
54 | if __name__ == '__main__':
55 | main()
56 |
--------------------------------------------------------------------------------
/data/waymo/README.md:
--------------------------------------------------------------------------------
1 | ### Prepare Dataset
2 |
3 | 1. Download Waymo data [HERE](https://waymo.com/open/). We suggest using [this](https://github.com/RalphMao/Waymo-Dataset-Tool) tool for batch downloading.
4 |
5 | 2. Please install this [tool](https://github.com/waymo-research/waymo-open-dataset) for data preprocessing.
6 |
7 | 3. Create a data_folder for the extracted point clouds, then use extract_pointcloud.py to extract them:
8 | python extract_pointcloud.py /path/to/downloaded_segments /path/to/data_folder waymo.npy
9 |
--------------------------------------------------------------------------------
/data/waymo/extract_pointcloud.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import sys
8 | import os
9 | import tensorflow as tf
10 | import numpy as np
11 | import glob
12 | from open3d import *
13 |
14 | tf.enable_eager_execution()
15 |
16 | from waymo_open_dataset.utils import frame_utils
17 | from waymo_open_dataset import dataset_pb2 as open_dataset
18 |
19 |
20 | debug = False
21 | K = 5
22 | SCALE_FACTOR = 2
23 | segments = glob.glob(sys.argv[1]+"/*")
24 | datalist = []
25 |
26 | def pc2obj(pc, filepath='test.obj'):
27 | pc = pc.T
28 | nverts = pc.shape[1]
29 | with open(filepath, 'w') as f:
30 | f.write("# OBJ file\n")
31 | for v in range(nverts):
32 | f.write("v %.4f %.4f %.4f\n" % (pc[0,v],pc[1,v],pc[2,v]))
33 |
34 | def extract(i):
35 | FILENAME = segments[i]
36 | run = FILENAME.split('segment-')[-1].split('.')[0]
37 | out_base_dir = sys.argv[2]+'/%s/' % run
38 |
39 | if not os.path.exists(out_base_dir):
40 | os.makedirs(out_base_dir)
41 |
42 | dataset = tf.data.TFRecordDataset(FILENAME, compression_type='')
43 | print(FILENAME)
44 | pc, pc_c = [], []
45 | camID2extr_v2c = {}
46 | camID2intr = {}
47 |
48 | all_static_pc = []
49 | for frame_cnt, data in enumerate(dataset):
50 | if frame_cnt % 2 != 0: continue ### Set the sampling rate here
51 |
52 | print('frame ', frame_cnt)
53 | frame = open_dataset.Frame()
54 | frame.ParseFromString(bytearray(data.numpy()))
55 |
56 | extr_laser2v = np.array(frame.context.laser_calibrations[0].extrinsic.transform).reshape(4, 4)
57 | extr_v2w = np.array(frame.pose.transform).reshape(4, 4)
58 |
59 | if frame_cnt == 0:
60 | for k in range(len(frame.context.camera_calibrations)):
61 | cameraID = frame.context.camera_calibrations[k].name
62 | extr_c2v =\
63 | np.array(frame.context.camera_calibrations[k].extrinsic.transform).reshape(4, 4)
64 | extr_v2c = np.linalg.inv(extr_c2v)
65 | camID2extr_v2c[frame.images[k].name] = extr_v2c
66 | fx = frame.context.camera_calibrations[k].intrinsic[0]
67 | fy = frame.context.camera_calibrations[k].intrinsic[1]
68 | cx = frame.context.camera_calibrations[k].intrinsic[2]
69 | cy = frame.context.camera_calibrations[k].intrinsic[3]
70 | k1 = frame.context.camera_calibrations[k].intrinsic[4]
71 | k2 = frame.context.camera_calibrations[k].intrinsic[5]
72 | p1 = frame.context.camera_calibrations[k].intrinsic[6]
73 | p2 = frame.context.camera_calibrations[k].intrinsic[7]
74 | k3 = frame.context.camera_calibrations[k].intrinsic[8]
75 | camID2intr[frame.images[k].name] = np.array([[fx, 0, cx],
76 | [0, fy, cy],
77 | [0, 0, 1]])
78 |
79 |
80 | # lidar point cloud
81 |
82 | (range_images, camera_projections, range_image_top_pose) = \
83 | frame_utils.parse_range_image_and_camera_projection(frame)
84 | points, cp_points = frame_utils.convert_range_image_to_point_cloud(
85 | frame,
86 | range_images,
87 | camera_projections,
88 | range_image_top_pose)
89 |
90 |
91 | points_all = np.concatenate(points, axis=0)
92 | np.save('%s/frame_%03d.npy' % (out_base_dir, frame_cnt), points_all)
93 | datalist.append(os.path.abspath('%s/frame_%03d.npy' % (out_base_dir, frame_cnt)))
94 |
95 | if __name__ == '__main__':
96 | for i in range(len(segments)):
97 | extract(i)
98 | np.save(sys.argv[3], datalist)
99 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 | import logging
9 |
10 | import torch
11 | from datasets.collators import get_collator
12 | from datasets.depth_dataset import DepthContrastDataset
13 | from torch.utils.data import DataLoader
14 |
15 |
16 | __all__ = ["DepthContrastDataset", "get_data_files"]
17 |
18 |
19 | def build_dataset(cfg):
20 | dataset = DepthContrastDataset(cfg)
21 | return dataset
22 |
23 |
24 | def print_sampler_config(data_sampler):
25 | sampler_cfg = {
26 | "num_replicas": data_sampler.num_replicas,
27 | "rank": data_sampler.rank,
28 | "epoch": data_sampler.epoch,
29 | "num_samples": data_sampler.num_samples,
30 | "total_size": data_sampler.total_size,
31 | "shuffle": data_sampler.shuffle,
32 | }
33 | if hasattr(data_sampler, "start_iter"):
34 | sampler_cfg["start_iter"] = data_sampler.start_iter
35 | if hasattr(data_sampler, "batch_size"):
36 | sampler_cfg["batch_size"] = data_sampler.batch_size
37 | logging.info("Distributed Sampler config:\n{}".format(sampler_cfg))
38 |
39 |
40 | def get_loader(dataset, dataset_config, num_dataloader_workers, pin_memory):
41 | data_sampler = None
42 | if torch.distributed.is_available() and torch.distributed.is_initialized():
43 | assert torch.distributed.is_initialized(), "Torch distributed isn't initalized"
44 | data_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
45 | logging.info("Created the Distributed Sampler....")
46 | print_sampler_config(data_sampler)
47 | else:
48 | logging.warning(
49 | "Distributed trainer not initialized. Not using the sampler and data will NOT be shuffled" # NOQA
50 | )
51 | collate_function = get_collator(dataset_config["COLLATE_FUNCTION"])
52 | dataloader = DataLoader(
53 | dataset=dataset,
54 | num_workers=num_dataloader_workers,
55 | pin_memory=pin_memory,
56 | shuffle=False,
57 | batch_size=dataset_config["BATCHSIZE_PER_REPLICA"],
58 | collate_fn=collate_function,
59 | sampler=data_sampler,
60 | drop_last=dataset_config["DROP_LAST"],
61 | )
62 | return dataloader
63 |
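
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): exercise
# get_loader with a toy map-style dataset and the "default" collator so no
# point-cloud data is needed. Real runs build a DepthContrastDataset from the
# YAML config via main.py and use one of the MoCo collators instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _ToyDataset(torch.utils.data.Dataset):
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {"feature": torch.randn(4), "label": idx}

    toy_cfg = {"COLLATE_FUNCTION": "default", "BATCHSIZE_PER_REPLICA": 4, "DROP_LAST": True}
    loader = get_loader(_ToyDataset(), toy_cfg, num_dataloader_workers=0, pin_memory=False)
    for batch in loader:
        print(batch["feature"].shape, batch["label"])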
--------------------------------------------------------------------------------
/datasets/collators/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | from datasets.collators.point_moco_collator import point_moco_collator
10 | try:
11 | from datasets.collators.vox_moco_collator import vox_moco_collator
12 | from datasets.collators.point_vox_moco_collator import point_vox_moco_collator
13 | except:
14 | print ("Cannot import minkowski engine. Try spconv next")
15 | from datasets.collators.point_vox_moco_lidar_collator import point_vox_moco_collator
16 | from torch.utils.data.dataloader import default_collate
17 |
18 |
19 | COLLATORS_MAP = {
20 | "default": default_collate,
21 | "point_moco_collator": point_moco_collator,
22 | "point_vox_moco_collator": point_vox_moco_collator,
23 | }
24 |
25 |
26 | def get_collator(name):
27 | assert name in COLLATORS_MAP, "Unknown collator"
28 | return COLLATORS_MAP[name]
29 |
30 |
31 | __all__ = ["get_collator"]
32 |
--------------------------------------------------------------------------------
/datasets/collators/point_moco_collator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import torch
10 |
11 | def point_moco_collator(batch):
12 | batch_size = len(batch)
13 |
14 | data_point = [x["data"] for x in batch]
15 | data_moco = [x["data_moco"] for x in batch]
16 | # labels are repeated N+1 times but they are the same
17 | labels = [x["label"][0] for x in batch]
18 | labels = torch.LongTensor(labels).squeeze()
19 |
20 | # data valid is repeated N+1 times but they are the same
21 | data_valid = torch.BoolTensor([x["data_valid"][0] for x in batch])
22 |
23 | points_moco = torch.stack([data_moco[i][0] for i in range(batch_size)])
24 | points = torch.stack([data_point[i][0] for i in range(batch_size)])
25 |
26 | output_batch = {
27 | "points": points,
28 | "points_moco": points_moco,
29 | "label": labels,
30 | "data_valid": data_valid,
31 | }
32 |
33 | return output_batch
34 |
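
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): feed the collator
# two toy samples shaped like the dataset output it indexes above. Real samples
# come from DepthContrastDataset, so the 10000 x 6 (xyz + rgb) shape here is
# only an assumption for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample = {
        "data": [torch.randn(10000, 6)],       # first augmented view
        "data_moco": [torch.randn(10000, 6)],  # second view, for the momentum branch
        "label": [0],
        "data_valid": [True],
    }
    batch = point_moco_collator([sample, sample])
    print(batch["points"].shape, batch["points_moco"].shape)  # both [2, 10000, 6]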
--------------------------------------------------------------------------------
/datasets/collators/point_vox_moco_collator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import torch
10 |
11 | from datasets.transforms import transforms
12 |
13 | import numpy as np
14 |
15 | collate_fn = transforms.cfl_collate_fn_factory(0)
16 |
17 | def point_vox_moco_collator(batch):
18 | batch_size = len(batch)
19 |
20 | point = [x["data"] for x in batch]
21 | point_moco = [x["data_moco"] for x in batch]
22 | vox = [x["vox"] for x in batch]
23 | vox_moco = [x["vox_moco"] for x in batch]
24 | # labels are repeated N+1 times but they are the same
25 | labels = [x["label"][0] for x in batch]
26 | labels = torch.LongTensor(labels).squeeze()
27 |
28 | # data valid is repeated N+1 times but they are the same
29 | data_valid = torch.BoolTensor([x["data_valid"][0] for x in batch])
30 |
31 | points_moco = torch.stack([point_moco[i][0] for i in range(batch_size)])
32 | points = torch.stack([point[i][0] for i in range(batch_size)])
33 |
34 | vox_moco = collate_fn([vox_moco[i][0] for i in range(batch_size)])
35 | vox = collate_fn([vox[i][0] for i in range(batch_size)])
36 |
37 | output_batch = {
38 | "points": points,
39 | "points_moco": points_moco,
40 | "vox": vox,
41 | "vox_moco": vox_moco,
42 | "label": labels,
43 | "data_valid": data_valid,
44 | }
45 |
46 | return output_batch
47 |
--------------------------------------------------------------------------------
/datasets/collators/point_vox_moco_lidar_collator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import torch
10 |
11 | import numpy as np
12 |
13 | def point_vox_moco_collator(batch):
14 | batch_size = len(batch)
15 |
16 | point = [x["data"] for x in batch]
17 | point_moco = [x["data_moco"] for x in batch]
18 | vox = [x["vox"] for x in batch]
19 | vox_moco = [x["vox_moco"] for x in batch]
20 |
21 | labels = [x["label"][0] for x in batch]
22 | labels = torch.LongTensor(labels).squeeze()
23 |
24 | # data valid is repeated N+1 times but they are the same
25 | data_valid = torch.BoolTensor([x["data_valid"][0] for x in batch])
26 |
27 | points_moco = torch.stack([point_moco[i][0] for i in range(batch_size)])
28 | points = torch.stack([point[i][0] for i in range(batch_size)])
29 |
30 | vox_data = {"voxels":[], "voxel_coords":[], "voxel_num_points":[]}
31 | counter = 0
32 | for data in vox:
33 | temp = data[0]
34 | voxels_shape = temp["voxels"].shape
35 | voxel_num_points_shape = temp["voxel_num_points"].shape
36 | voxel_coords_shape = temp["voxel_coords"].shape
37 | for key,val in temp.items():
38 | if key in ['voxels', 'voxel_num_points']:
39 | if len(vox_data[key]) > 0:
40 | vox_data[key] = np.concatenate([vox_data[key], val], axis=0)
41 | else:
42 | vox_data[key] = val
43 | elif key == 'voxel_coords':
44 | coor = np.pad(val, ((0, 0), (1, 0)), mode='constant', constant_values=counter)
45 | if len(vox_data[key]) > 0:
46 | vox_data[key] = np.concatenate([vox_data[key], coor], axis=0)
47 | else:
48 | vox_data[key] = coor
49 | counter += 1
50 |
51 | vox_moco_data = {"voxels":[], "voxel_coords":[], "voxel_num_points":[]}
52 | counter = 0
53 | for data in vox_moco:
54 | temp = data[0]
55 | voxels_shape = temp["voxels"].shape
56 | voxel_num_points_shape = temp["voxel_num_points"].shape
57 | voxel_coords_shape = temp["voxel_coords"].shape
58 | for key,val in temp.items():
59 | if key in ['voxels', 'voxel_num_points']:
60 | if len(vox_moco_data[key]) > 0:
61 | vox_moco_data[key] = np.concatenate([vox_moco_data[key], val], axis=0)
62 | else:
63 | vox_moco_data[key] = val
64 |             elif key == 'voxel_coords':
65 | coor = np.pad(val, ((0, 0), (1, 0)), mode='constant', constant_values=counter)
66 |
67 | if len(vox_moco_data[key]) > 0:
68 | vox_moco_data[key] = np.concatenate([vox_moco_data[key], coor], axis=0)
69 | else:
70 | vox_moco_data[key] = coor
71 | counter += 1
72 |
73 | vox_data = {k:torch.from_numpy(vox_data[k]) for k in vox_data}
74 | vox_moco_data = {k:torch.from_numpy(vox_moco_data[k]) for k in vox_moco_data}
75 |
76 | output_batch = {
77 | "points": points,
78 | "points_moco": points_moco,
79 | "vox": vox_data,
80 | "vox_moco": vox_moco_data,
81 | "label": labels,
82 | "data_valid": data_valid,
83 | }
84 |
85 | return output_batch
86 |
--------------------------------------------------------------------------------
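The lidar collator above merges per-sample voxel dictionaries by concatenating "voxels" and "voxel_num_points" and by prepending each sample's batch index to its "voxel_coords". A minimal NumPy sketch of that padding step, with made-up values:

import numpy as np

# Hypothetical voxel coordinates for one sample: three voxels with (z, y, x) indices.
voxel_coords = np.array([[0, 1, 2],
                         [3, 4, 5],
                         [6, 7, 8]])

# Prepend the sample's position in the batch (here 1) as an extra leading column,
# mirroring np.pad(val, ((0, 0), (1, 0)), constant_values=counter) in the collator.
batch_idx = 1
padded = np.pad(voxel_coords, ((0, 0), (1, 0)), mode='constant', constant_values=batch_idx)
print(padded)
# [[1 0 1 2]
#  [1 3 4 5]
#  [1 6 7 8]]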
/datasets/collators/vox_moco_collator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import torch
10 |
11 | from datasets.transforms import transforms
12 |
13 | import numpy as np
14 |
15 | collate_fn = transforms.cfl_collate_fn_factory(0)
16 |
17 | def vox_moco_collator(batch):
18 | batch_size = len(batch)
19 |
20 | data_point = [x["data"] for x in batch]
21 | data_moco = [x["data_moco"] for x in batch]
22 | # labels are repeated N+1 times but they are the same
23 | labels = [int(x["label"][0]) for x in batch]
24 | labels = torch.LongTensor(labels).squeeze()
25 |
26 | # data valid is repeated N+1 times but they are the same
27 | data_valid = torch.BoolTensor([x["data_valid"][0] for x in batch])
28 |
29 | vox_moco = collate_fn([data_moco[i][0] for i in range(batch_size)])
30 | vox = collate_fn([data_point[i][0] for i in range(batch_size)])
31 |
32 | output_batch = {
33 | "vox": vox,
34 | "vox_moco": vox_moco,
35 | "label": labels,
36 | "data_valid": data_valid,
37 | }
38 |
39 | return output_batch
40 |
--------------------------------------------------------------------------------
/datasets/transforms/voxelizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import collections.abc
8 |
9 | import numpy as np
10 | import MinkowskiEngine as ME
11 | from scipy.linalg import expm, norm
12 |
13 |
14 | # Rotation matrix about the given axis by angle theta (Rodrigues' formula via the matrix exponential)
15 | def M(axis, theta):
16 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta))
17 |
18 |
19 | class Voxelizer:
20 |
21 | def __init__(self,
22 | voxel_size=1,
23 | clip_bound=None,
24 | use_augmentation=False,
25 | scale_augmentation_bound=None,
26 | rotation_augmentation_bound=None,
27 | translation_augmentation_ratio_bound=None,
28 | ignore_label=255):
29 | """
30 | Args:
31 | voxel_size: side length of a voxel
32 | clip_bound: boundary of the voxelizer. Points outside the bound will be deleted
33 | expects either None or an array like ((-100, 100), (-100, 100), (-100, 100)).
34 | scale_augmentation_bound: None or (0.9, 1.1)
35 | rotation_augmentation_bound: None or ((np.pi / 6, np.pi / 6), None, None) for 3 axis.
36 | Use random order of x, y, z to prevent bias.
37 |         translation_augmentation_ratio_bound: ((-5, 5), (0, 0), (-10, 10))
38 | ignore_label: label assigned for ignore (not a training label).
39 | """
40 | self.voxel_size = voxel_size
41 | self.clip_bound = clip_bound
42 | self.ignore_label = ignore_label
43 |
44 | # Augmentation
45 | self.use_augmentation = use_augmentation
46 | self.scale_augmentation_bound = scale_augmentation_bound
47 | self.rotation_augmentation_bound = rotation_augmentation_bound
48 | self.translation_augmentation_ratio_bound = translation_augmentation_ratio_bound
49 |
50 | def get_transformation_matrix(self):
51 | voxelization_matrix, rotation_matrix = np.eye(4), np.eye(4)
52 | # Get clip boundary from config or pointcloud.
53 | # Get inner clip bound to crop from.
54 |
55 | # Transform pointcloud coordinate to voxel coordinate.
56 | # 1. Random rotation
57 | rot_mat = np.eye(3)
58 | if self.use_augmentation and self.rotation_augmentation_bound is not None:
59 |             if isinstance(self.rotation_augmentation_bound, collections.abc.Iterable):
60 | rot_mats = []
61 | for axis_ind, rot_bound in enumerate(self.rotation_augmentation_bound):
62 | theta = 0
63 | axis = np.zeros(3)
64 | axis[axis_ind] = 1
65 | if rot_bound is not None:
66 | theta = np.random.uniform(*rot_bound)
67 | rot_mats.append(M(axis, theta))
68 | # Use random order
69 | np.random.shuffle(rot_mats)
70 | rot_mat = rot_mats[0] @ rot_mats[1] @ rot_mats[2]
71 | else:
72 | raise ValueError()
73 | rotation_matrix[:3, :3] = rot_mat
74 | # 2. Scale and translate to the voxel space.
75 | scale = 1 / self.voxel_size
76 | if self.use_augmentation and self.scale_augmentation_bound is not None:
77 | scale *= np.random.uniform(*self.scale_augmentation_bound)
78 | np.fill_diagonal(voxelization_matrix[:3, :3], scale)
79 | # Get final transformation matrix.
80 | return voxelization_matrix, rotation_matrix
81 |
82 | def clip(self, coords, center=None, trans_aug_ratio=None):
83 | bound_min = np.min(coords, 0).astype(float)
84 | bound_max = np.max(coords, 0).astype(float)
85 | bound_size = bound_max - bound_min
86 | if center is None:
87 | center = bound_min + bound_size * 0.5
88 | if trans_aug_ratio is not None:
89 | trans = np.multiply(trans_aug_ratio, bound_size)
90 | center += trans
91 | lim = self.clip_bound
92 |
93 | if isinstance(self.clip_bound, (int, float)):
94 | if bound_size.max() < self.clip_bound:
95 | return None
96 | else:
97 | clip_inds = ((coords[:, 0] >= (-lim + center[0])) & \
98 | (coords[:, 0] < (lim + center[0])) & \
99 | (coords[:, 1] >= (-lim + center[1])) & \
100 | (coords[:, 1] < (lim + center[1])) & \
101 | (coords[:, 2] >= (-lim + center[2])) & \
102 | (coords[:, 2] < (lim + center[2])))
103 | return clip_inds
104 |
105 | # Clip points outside the limit
106 | clip_inds = ((coords[:, 0] >= (lim[0][0] + center[0])) & \
107 | (coords[:, 0] < (lim[0][1] + center[0])) & \
108 | (coords[:, 1] >= (lim[1][0] + center[1])) & \
109 | (coords[:, 1] < (lim[1][1] + center[1])) & \
110 | (coords[:, 2] >= (lim[2][0] + center[2])) & \
111 | (coords[:, 2] < (lim[2][1] + center[2])))
112 | return clip_inds
113 |
114 | def voxelize(self, coords, feats, labels, center=None):
115 | assert coords.shape[1] == 3 and coords.shape[0] == feats.shape[0] and coords.shape[0]
116 | if self.clip_bound is not None:
117 | trans_aug_ratio = np.zeros(3)
118 | if self.use_augmentation and self.translation_augmentation_ratio_bound is not None:
119 | for axis_ind, trans_ratio_bound in enumerate(self.translation_augmentation_ratio_bound):
120 | trans_aug_ratio[axis_ind] = np.random.uniform(*trans_ratio_bound)
121 |
122 | clip_inds = self.clip(coords, center, trans_aug_ratio)
123 | if clip_inds is not None:
124 | coords, feats = coords[clip_inds], feats[clip_inds]
125 | if labels is not None:
126 | labels = labels[clip_inds]
127 |
128 | # Get rotation and scale
129 | M_v, M_r = self.get_transformation_matrix()
130 | # Apply transformations
131 | rigid_transformation = M_v
132 | if self.use_augmentation:
133 | rigid_transformation = M_r @ rigid_transformation
134 |
135 | homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype)))
136 | coords_aug = np.floor(homo_coords @ rigid_transformation.T[:, :3])
137 |
138 | # key = self.hash(coords_aug) # floor happens by astype(np.uint64)
139 | coords_aug, feats, labels = ME.utils.sparse_quantize(
140 | coords_aug, feats, labels=labels, ignore_label=self.ignore_label)
141 |
142 | return coords_aug, feats, labels, rigid_transformation.flatten()
143 |
144 | def voxelize_temporal(self,
145 | coords_t,
146 | feats_t,
147 | labels_t,
148 | centers=None,
149 | return_transformation=False):
150 | # Legacy code, remove
151 | if centers is None:
152 | centers = [
153 | None,
154 | ] * len(coords_t)
155 | coords_tc, feats_tc, labels_tc, transformation_tc = [], [], [], []
156 |
157 | # ######################### Data Augmentation #############################
158 | # Get rotation and scale
159 | M_v, M_r = self.get_transformation_matrix()
160 | # Apply transformations
161 | rigid_transformation = M_v
162 | if self.use_augmentation:
163 | rigid_transformation = M_r @ rigid_transformation
164 | # ######################### Voxelization #############################
165 | # Voxelize coords
166 | for coords, feats, labels, center in zip(coords_t, feats_t, labels_t, centers):
167 |
168 | ###################################
169 | # Clip the data if bound exists
170 | if self.clip_bound is not None:
171 | trans_aug_ratio = np.zeros(3)
172 | if self.use_augmentation and self.translation_augmentation_ratio_bound is not None:
173 | for axis_ind, trans_ratio_bound in enumerate(self.translation_augmentation_ratio_bound):
174 | trans_aug_ratio[axis_ind] = np.random.uniform(*trans_ratio_bound)
175 |
176 | clip_inds = self.clip(coords, center, trans_aug_ratio)
177 | if clip_inds is not None:
178 | coords, feats = coords[clip_inds], feats[clip_inds]
179 | if labels is not None:
180 | labels = labels[clip_inds]
181 | ###################################
182 |
183 | homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype)))
184 | coords_aug = np.floor(homo_coords @ rigid_transformation.T)[:, :3]
185 |
186 | coords_aug, feats, labels = ME.utils.sparse_quantize(
187 | coords_aug, feats, labels=labels, ignore_label=self.ignore_label)
188 |
189 | coords_tc.append(coords_aug)
190 | feats_tc.append(feats)
191 | labels_tc.append(labels)
192 | transformation_tc.append(rigid_transformation.flatten())
193 |
194 | return_args = [coords_tc, feats_tc, labels_tc]
195 | if return_transformation:
196 | return_args.append(transformation_tc)
197 |
198 | return tuple(return_args)
199 |
200 |
201 | def test():
202 | N = 16575
203 | coords = np.random.rand(N, 3) * 10
204 | feats = np.random.rand(N, 4)
205 | labels = np.floor(np.random.rand(N) * 3)
206 | coords[:3] = 0
207 | labels[:3] = 2
208 | voxelizer = Voxelizer()
209 | print(voxelizer.voxelize(coords, feats, labels))
210 |
211 |
212 | if __name__ == '__main__':
213 | test()
214 |
--------------------------------------------------------------------------------
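A minimal usage sketch of the Voxelizer above. The import path follows this repo's layout, the augmentation bounds are illustrative (not taken from the shipped configs), and ME.utils.sparse_quantize is assumed to be the MinkowskiEngine 0.4.x API that this file targets:

import numpy as np
from datasets.transforms.voxelizer import Voxelizer

voxelizer = Voxelizer(
    voxel_size=0.05,
    clip_bound=None,
    use_augmentation=True,
    scale_augmentation_bound=(0.9, 1.1),
    rotation_augmentation_bound=((-np.pi / 6, np.pi / 6), None, None),
    translation_augmentation_ratio_bound=None,
    ignore_label=255,
)

coords = np.random.rand(2048, 3) * 10    # xyz in metres
feats = np.random.rand(2048, 3)          # e.g. RGB
labels = np.zeros(2048, dtype=np.int32)
coords_v, feats_v, labels_v, transform = voxelizer.voxelize(coords, feats, labels)
print(coords_v.shape, feats_v.shape, transform.shape)  # quantized coords, feats, flattened 4x4 transform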
/imgs/method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/DepthContrast/b8257890c94f7c58aeb5cefeb91af031692611d6/imgs/method.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import argparse
9 | import os
10 | import random
11 | import time
12 | import warnings
13 | import yaml
14 |
15 | import torch
16 | import torch.nn.parallel
17 | import torch.backends.cudnn as cudnn
18 | import torch.optim
19 |
20 | from torch.multiprocessing import Pool, Process, set_start_method
21 | try:
22 | set_start_method('spawn')
23 | except RuntimeError:
24 | pass
25 |
26 | import torch.multiprocessing as mp
27 |
28 | import utils.logger
29 | from utils import main_utils
30 |
31 | parser = argparse.ArgumentParser(description='PyTorch Self Supervised Training in 3D')
32 |
33 | parser.add_argument('cfg', help='path to the training config (YAML) file')
34 | parser.add_argument('--quiet', action='store_true')
35 |
36 | parser.add_argument('--world-size', default=-1, type=int,
37 | help='number of nodes for distributed training')
38 | parser.add_argument('--rank', default=-1, type=int,
39 | help='node rank for distributed training')
40 | parser.add_argument('--dist-url', default='tcp://localhost:15475', type=str,
41 | help='url used to set up distributed training')
42 | parser.add_argument('--dist-backend', default='nccl', type=str,
43 | help='distributed backend')
44 | parser.add_argument('--seed', default=None, type=int,
45 | help='seed for initializing training. ')
46 | parser.add_argument('--gpu', default=0, type=int,
47 | help='GPU id to use.')
48 | parser.add_argument('--ngpus', default=8, type=int,
49 | help='number of GPUs to use.')
50 | parser.add_argument('--multiprocessing-distributed', action='store_true',
51 | help='Use multi-processing distributed training to launch '
52 | 'N processes per node, which has N GPUs. This is the '
53 | 'fastest way to use PyTorch for either single node or '
54 | 'multi node data parallel training')
55 |
56 |
57 | def main():
58 | args = parser.parse_args()
59 | cfg = yaml.safe_load(open(args.cfg))
60 |
61 | if args.seed is not None:
62 | random.seed(args.seed)
63 | torch.manual_seed(args.seed)
64 | cudnn.deterministic = True
65 | warnings.warn('You have chosen to seed training. '
66 | 'This will turn on the CUDNN deterministic setting, '
67 | 'which can slow down your training considerably! '
68 | 'You may see unexpected behavior when restarting '
69 | 'from checkpoints.')
70 |
71 | ngpus_per_node = args.ngpus
72 | if args.multiprocessing_distributed:
73 | args.world_size = ngpus_per_node * args.world_size
74 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args, cfg))
75 | else:
76 | # Simply call main_worker function
77 | main_worker(args.gpu, ngpus_per_node, args, cfg)
78 |
79 | def main_worker(gpu, ngpus, args, cfg):
80 | args.gpu = gpu
81 | ngpus_per_node = ngpus
82 |
83 | # Setup environment
84 | args = main_utils.initialize_distributed_backend(args, ngpus_per_node) ### Use other method instead
85 | logger, tb_writter, model_dir = main_utils.prep_environment(args, cfg)
86 |
87 | # Define model
88 | model = main_utils.build_model(cfg['model'], logger)
89 | model, args = main_utils.distribute_model_to_cuda(model, args)
90 |
91 | # Define dataloaders
92 | train_loader = main_utils.build_dataloaders(cfg['dataset'], cfg['num_workers'], args.multiprocessing_distributed, logger)
93 |
94 | # Define criterion
95 | train_criterion = main_utils.build_criterion(cfg['loss'], logger=logger)
96 | train_criterion = train_criterion.cuda()
97 |
98 | # Define optimizer
99 | optimizer, scheduler = main_utils.build_optimizer(
100 | params=list(model.parameters())+list(train_criterion.parameters()),
101 | cfg=cfg['optimizer'],
102 | logger=logger)
103 | ckp_manager = main_utils.CheckpointManager(model_dir, rank=args.rank, dist=args.multiprocessing_distributed)
104 |
105 | # Optionally resume from a checkpoint
106 | start_epoch, end_epoch = 0, cfg['optimizer']['num_epochs']
107 | if cfg['resume']:
108 | if ckp_manager.checkpoint_exists(last=True):
109 | start_epoch = ckp_manager.restore(restore_last=True, model=model, optimizer=optimizer, train_criterion=train_criterion)
110 | scheduler.step(start_epoch)
111 | logger.add_line("Checkpoint loaded: '{}' (epoch {})".format(ckp_manager.last_checkpoint_fn(), start_epoch))
112 | else:
113 | logger.add_line("No checkpoint found at '{}'".format(ckp_manager.last_checkpoint_fn()))
114 |
115 | cudnn.benchmark = True
116 |
117 | ############################ TRAIN #########################################
118 | test_freq = cfg['test_freq'] if 'test_freq' in cfg else 1
119 | for epoch in range(start_epoch, end_epoch):
120 | if (epoch % 10) == 0:
121 | ckp_manager.save(epoch, model=model, train_criterion=train_criterion, optimizer=optimizer, filename='checkpoint-ep{}.pth.tar'.format(epoch))
122 |
123 | if args.multiprocessing_distributed:
124 | train_loader.sampler.set_epoch(epoch)
125 |
126 | # Train for one epoch
127 | logger.add_line('='*30 + ' Epoch {} '.format(epoch) + '='*30)
128 | logger.add_line('LR: {}'.format(scheduler.get_lr()))
129 | run_phase('train', train_loader, model, optimizer, train_criterion, epoch, args, cfg, logger, tb_writter)
130 | scheduler.step(epoch)
131 |
132 | if ((epoch % test_freq) == 0) or (epoch == end_epoch - 1):
133 | ckp_manager.save(epoch+1, model=model, optimizer=optimizer, train_criterion=train_criterion)
134 |
135 |
136 | def run_phase(phase, loader, model, optimizer, criterion, epoch, args, cfg, logger, tb_writter):
137 | from utils import metrics_utils
138 | logger.add_line('\n{}: Epoch {}'.format(phase, epoch))
139 | batch_time = metrics_utils.AverageMeter('Time', ':6.3f', window_size=100)
140 | data_time = metrics_utils.AverageMeter('Data', ':6.3f', window_size=100)
141 | loss_meter = metrics_utils.AverageMeter('Loss', ':.3e')
142 | loss_meter_npid1 = metrics_utils.AverageMeter('Loss_npid1', ':.3e')
143 | loss_meter_npid2 = metrics_utils.AverageMeter('Loss_npid2', ':.3e')
144 | loss_meter_cmc1 = metrics_utils.AverageMeter('Loss_cmc1', ':.3e')
145 | loss_meter_cmc2 = metrics_utils.AverageMeter('Loss_cmc2', ':.3e')
146 | progress = utils.logger.ProgressMeter(len(loader), [batch_time, data_time, loss_meter, loss_meter_npid1, loss_meter_npid2, loss_meter_cmc1, loss_meter_cmc2], phase=phase, epoch=epoch, logger=logger, tb_writter=tb_writter)
147 |
148 | # switch to train mode
149 | model.train(phase == 'train')
150 |
151 | end = time.time()
152 | device = args.gpu if args.gpu is not None else 0
153 | for i, sample in enumerate(loader):
154 | # measure data loading time
155 | data_time.update(time.time() - end)
156 |
157 | if phase == 'train':
158 | embedding = model(sample)
159 | else:
160 | with torch.no_grad():
161 | embedding = model(sample)
162 |
163 | # compute loss
164 | loss, loss_debug = criterion(embedding)
165 | loss_meter.update(loss.item(), embedding[0].size(0))
166 | loss_meter_npid1.update(loss_debug[0].item(), embedding[0].size(0))
167 | loss_meter_npid2.update(loss_debug[1].item(), embedding[0].size(0))
168 | loss_meter_cmc1.update(loss_debug[2].item(), embedding[0].size(0))
169 | loss_meter_cmc2.update(loss_debug[3].item(), embedding[0].size(0))
170 |
171 | # compute gradient and do SGD step during training
172 | if phase == 'train':
173 | optimizer.zero_grad()
174 | loss.backward()
175 | optimizer.step()
176 |
177 | # measure elapsed time
178 | batch_time.update(time.time() - end)
179 | end = time.time()
180 |
181 | # print to terminal and tensorboard
182 | step = epoch * len(loader) + i
183 | if (i+1) % cfg['print_freq'] == 0 or i == 0 or i+1 == len(loader):
184 | progress.display(i+1)
185 |
186 | # Sync metrics across all GPUs and print final averages
187 | if args.multiprocessing_distributed:
188 | progress.synchronize_meters(args.gpu)
189 | progress.display(len(loader)*args.world_size)
190 |
191 | if tb_writter is not None:
192 | for meter in progress.meters:
193 | tb_writter.add_scalar('{}-epoch/{}'.format(phase, meter.name), meter.avg, epoch)
194 |
195 |
196 | if __name__ == '__main__':
197 | main()
198 |
--------------------------------------------------------------------------------
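For reference, main() and main_worker() above read only a handful of top-level config keys. A sketch of that structure as a Python dict (the real configs are the YAML files under configs/ and carry many more model- and dataset-specific fields):

cfg = {
    "num_workers": 8,                  # dataloader workers, passed to build_dataloaders
    "resume": True,                    # restore the last checkpoint if one exists
    "test_freq": 1,                    # optional; defaults to 1 in the epoch loop
    "print_freq": 10,                  # logging interval inside run_phase
    "model": {},                       # passed to main_utils.build_model
    "dataset": {},                     # passed to main_utils.build_dataloaders
    "loss": {},                        # passed to main_utils.build_criterion
    "optimizer": {"num_epochs": 100},  # passed to build_optimizer; num_epochs bounds the loop
}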
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from models.base_ssl3d_model import BaseSSLMultiInputOutputModel
9 |
10 | def build_model(model_config, logger):
11 | return BaseSSLMultiInputOutputModel(model_config, logger)
12 |
13 |
14 | __all__ = ["BaseSSLMultiInputOutputModel", "build_model"]
15 |
--------------------------------------------------------------------------------
/models/trunks/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 | try:
9 | from models.trunks.pointnet import PointNet
10 | except:
11 | PointNet = None
12 |
13 | try:
14 | from models.trunks.spconv.models.res16unet import Res16UNet34
15 | except:
16 | Res16UNet34 = None
17 |
18 | try:
19 | from models.trunks.pointnet2_backbone import PointNet2MSG
20 | from models.trunks.spconv_unet import UNetV2_concat as UNetV2
21 | except:
22 | PointNet2MSG = None
23 | UNetV2 = None
24 |
25 | TRUNKS = {
26 | "pointnet": PointNet,
27 | "unet": Res16UNet34,
28 | "pointnet_msg": PointNet2MSG,
29 | "UNetV2": UNetV2,
30 | }
31 |
32 |
33 | __all__ = ["TRUNKS"]
34 |
--------------------------------------------------------------------------------
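A sketch of selecting a backbone from the TRUNKS registry above (the real call sites live in base_ssl3d_model.py; the constructor arguments here are illustrative):

from models.trunks import TRUNKS

trunk_cls = TRUNKS["pointnet"]
if trunk_cls is None:
    # Entries are None when their optional dependencies failed to import,
    # e.g. when the pointnet2 CUDA extension has not been built.
    raise ImportError("PointNet trunk unavailable")

trunk = trunk_cls(use_mlp=True, mlp_dim=[512, 512, 128])  # illustrative projection-head dims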
/models/trunks/mlp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 |
9 | import torch.nn as nn
10 |
11 | class MLP(nn.Module):
12 | def __init__(
13 | self, dims, use_bn=False, use_relu=True, use_dropout=False, use_bias=True
14 | ):
15 | super().__init__()
16 | layers = []
17 | last_dim = dims[0]
18 | counter = 1
19 | for dim in dims[1:]:
20 | layers.append(nn.Linear(last_dim, dim, bias=use_bias))
21 | counter += 1
22 | if use_bn:
23 | layers.append(
24 | nn.BatchNorm1d(
25 | dim,
26 | eps=1e-5,
27 | momentum=0.1,
28 | )
29 | )
30 | if (counter < len(dims)) and use_relu:
31 | layers.append(nn.ReLU(inplace=True))
32 | last_dim = dim
33 | if use_dropout:
34 | layers.append(nn.Dropout())
35 | self.clf = nn.Sequential(*layers)
36 |
37 | def forward(self, batch):
38 | out = self.clf(batch)
39 | return out
40 |
--------------------------------------------------------------------------------
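A quick sketch of the MLP above used as a projection head; the dimensions are illustrative:

import torch
from models.trunks.mlp import MLP

head = MLP([512, 512, 128], use_bn=True, use_relu=True)  # Linear-BN-ReLU, then Linear-BN
x = torch.randn(4, 512)                                  # a batch of 4 pooled trunk features
z = head(x)
print(z.shape)                                           # torch.Size([4, 128])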
/models/trunks/pointnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import numpy as np
10 | import sys
11 | import os
12 |
13 | from models.trunks.mlp import MLP
14 |
15 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
16 | ROOT_DIR = os.path.dirname(ROOT_DIR)
17 | ROOT_DIR = os.path.dirname(ROOT_DIR)
18 | sys.path.append(os.path.join(ROOT_DIR, 'third_party', 'pointnet2'))
19 |
20 | from pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule
21 |
22 | class PointNet(nn.Module):
23 | r"""
24 | Backbone network for point cloud feature learning.
25 | Based on Pointnet++ single-scale grouping network.
26 |
27 | Parameters
28 | ----------
29 | input_feature_dim: int
30 | Number of input channels in the feature descriptor for each point.
31 | e.g. 3 for RGB.
32 | """
33 | def __init__(self, input_feature_dim=0, scale=1, use_mlp=False, mlp_dim=None):
34 | super().__init__()
35 |
36 | self.use_mlp = use_mlp
37 | self.sa1 = PointnetSAModuleVotes(
38 | npoint=2048,
39 | radius=0.2,
40 | nsample=64,
41 | mlp=[input_feature_dim, 64*scale, 64*scale, 128*scale],
42 | use_xyz=True,
43 | normalize_xyz=True
44 | )
45 |
46 | self.sa2 = PointnetSAModuleVotes(
47 | npoint=1024,
48 | radius=0.4,
49 | nsample=32,
50 | mlp=[128*scale, 128*scale, 128*scale, 256*scale],
51 | use_xyz=True,
52 | normalize_xyz=True
53 | )
54 |
55 | self.sa3 = PointnetSAModuleVotes(
56 | npoint=512,
57 | radius=0.8,
58 | nsample=16,
59 | mlp=[256*scale, 128*scale, 128*scale, 256*scale],
60 | use_xyz=True,
61 | normalize_xyz=True
62 | )
63 |
64 | self.sa4 = PointnetSAModuleVotes(
65 | npoint=256,
66 | radius=1.2,
67 | nsample=16,
68 | mlp=[256*scale, 128*scale, 128*scale, 256*scale],
69 | use_xyz=True,
70 | normalize_xyz=True
71 | )
72 |
73 | if scale == 1:
74 | self.fp1 = PointnetFPModule(mlp=[256+256,512,512])
75 | self.fp2 = PointnetFPModule(mlp=[512+256,512,512])
76 | else:
77 | self.fp1 = PointnetFPModule(mlp=[256*scale+256*scale,256*scale,256*scale])
78 | self.fp2 = PointnetFPModule(mlp=[256*scale+256*scale,256*scale,256*scale])
79 |
80 | if use_mlp:
81 | self.head = MLP(mlp_dim)
82 |
83 | self.all_feat_names = [
84 | "sa1",
85 | "sa2",
86 | "sa3",
87 | "sa4",
88 | "fp1",
89 | "fp2",
90 | ]
91 |
92 | def _break_up_pc(self, pc):
93 | xyz = pc[..., 0:3].contiguous()
94 | features = (
95 | pc[..., 3:].transpose(1, 2).contiguous()
96 | if pc.size(-1) > 3 else None
97 | )
98 |
99 | return xyz, features
100 |
101 | def forward(self, pointcloud: torch.cuda.FloatTensor, out_feat_keys=["fp2"]):
102 | r"""
103 | Forward pass of the network
104 |
105 | Parameters
106 | ----------
107 | pointcloud: Variable(torch.cuda.FloatTensor)
108 | (B, N, 3 + input_feature_dim) tensor
109 |             Point cloud to run predictions on
110 |             Each point in the point-cloud MUST
111 |             be formatted as (x, y, z, features...)
112 |
113 | Returns
114 | ----------
115 | end_points: {XXX_xyz, XXX_features, XXX_inds}
116 | XXX_xyz: float32 Tensor of shape (B,K,3)
117 | XXX_features: float32 Tensor of shape (B,K,D)
118 |             XXX_inds: int64 Tensor of shape (B,K) with values in [0,N-1]
119 | """
120 | batch_size = pointcloud.shape[0]
121 |
122 | xyz, features = self._break_up_pc(pointcloud)
123 |
124 | features = None ### Do not use other info for now
125 |
126 | end_points = {}
127 |
128 | # --------- 4 SET ABSTRACTION LAYERS ---------
129 | xyz, features, fps_inds = self.sa1(xyz, features)
130 | end_points['sa1_inds'] = fps_inds
131 | end_points['sa1_xyz'] = xyz
132 | end_points['sa1_features'] = features
133 |
134 | xyz, features, fps_inds = self.sa2(xyz, features) # this fps_inds is just 0,1,...,1023
135 | end_points['sa2_inds'] = fps_inds
136 | end_points['sa2_xyz'] = xyz
137 | end_points['sa2_features'] = features
138 |
139 | xyz, features, fps_inds = self.sa3(xyz, features) # this fps_inds is just 0,1,...,511
140 | end_points['sa3_xyz'] = xyz
141 | end_points['sa3_features'] = features
142 |
143 | xyz, features, fps_inds = self.sa4(xyz, features) # this fps_inds is just 0,1,...,255
144 | end_points['sa4_xyz'] = xyz
145 | end_points['sa4_features'] = features
146 |
147 | # --------- 2 FEATURE UPSAMPLING LAYERS --------
148 | features = self.fp1(end_points['sa3_xyz'], end_points['sa4_xyz'], end_points['sa3_features'], end_points['sa4_features'])
149 | end_points['fp1_features'] = features
150 | features = self.fp2(end_points['sa2_xyz'], end_points['sa3_xyz'], end_points['sa2_features'], features)
151 | end_points['fp2_features'] = features
152 | end_points['fp2_xyz'] = end_points['sa2_xyz']
153 | num_seed = end_points['fp2_xyz'].shape[1]
154 | end_points['fp2_inds'] = end_points['sa1_inds'][:,0:num_seed] # indices among the entire input point clouds
155 |
156 | out_feats = [None] * len(out_feat_keys)
157 |
158 | final_feat = []
159 | for key in out_feat_keys:
160 | feat = end_points[key+"_features"]
161 | org_feat = end_points[key+"_features"]
162 | nump = feat.shape[-1]
163 | feat = torch.squeeze(F.max_pool1d(feat, nump))
164 | ### Apply the head here
165 | if self.use_mlp:
166 | feat = self.head(feat)
167 | out_feats[out_feat_keys.index(key)] = feat
168 |
169 | return out_feats
170 |
171 |
172 | if __name__=='__main__':
173 |     backbone_net = PointNet(input_feature_dim=3).cuda()
174 |     print(backbone_net)
175 |     backbone_net.eval()
176 |     out = backbone_net(torch.rand(16, 20000, 6).cuda())
177 |     # forward returns a list of pooled features, one per key in out_feat_keys
178 |     for feat in out:
179 |         print(feat.shape)
180 |
--------------------------------------------------------------------------------
/models/trunks/pointnet2_backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | import os
12 | import sys
13 |
14 | from models.trunks.mlp import MLP
15 |
16 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
17 | ROOT_DIR = os.path.dirname(ROOT_DIR)
18 | ROOT_DIR = os.path.dirname(ROOT_DIR)
19 | sys.path.append(os.path.join(ROOT_DIR, 'third_party', 'OpenPCDet', "pcdet"))
20 |
21 | from ops.pointnet2.pointnet2_batch import pointnet2_modules
22 |
23 | class PointNet2MSG(nn.Module):
24 | def __init__(self, use_mlp=False, mlp_dim=None):
25 | super().__init__()
26 |
27 | input_channels = 4
28 |
29 | self.SA_modules = nn.ModuleList()
30 | channel_in = input_channels - 3
31 |
32 | self.num_points_each_layer = []
33 | skip_channel_list = [input_channels - 3]
34 | SA_CONFIG = {'NPOINTS': [4096, 1024, 256, 64], 'RADIUS': [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]], 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32]], 'MLPS': [[[16, 16, 32], [32, 32, 64]], [[64, 64, 128], [64, 96, 128]], [[128, 196, 256], [128, 196, 256]], [[256, 256, 512], [256, 384, 512]]]}
35 |
36 | FP_MLPS = [[128, 128], [256, 256], [512, 512], [512, 512]]
37 |
38 | for k in range(SA_CONFIG["NPOINTS"].__len__()):
39 | mlps = SA_CONFIG["MLPS"][k].copy()
40 | channel_out = 0
41 | for idx in range(mlps.__len__()):
42 | mlps[idx] = [channel_in] + mlps[idx]
43 | channel_out += mlps[idx][-1]
44 |
45 | self.SA_modules.append(
46 | pointnet2_modules.PointnetSAModuleMSG(
47 | npoint=SA_CONFIG["NPOINTS"][k],
48 | radii=SA_CONFIG["RADIUS"][k],
49 | nsamples=SA_CONFIG["NSAMPLE"][k],
50 | mlps=mlps,
51 | use_xyz=True,
52 | )
53 | )
54 | skip_channel_list.append(channel_out)
55 | channel_in = channel_out
56 |
57 | self.FP_modules = nn.ModuleList()
58 |
59 | for k in range(FP_MLPS.__len__()):
60 | pre_channel = FP_MLPS[k + 1][-1] if k + 1 < len(FP_MLPS) else channel_out
61 | self.FP_modules.append(
62 | pointnet2_modules.PointnetFPModule(
63 | mlp=[pre_channel + skip_channel_list[k]] + FP_MLPS[k]
64 | )
65 | )
66 |
67 | self.num_point_features = FP_MLPS[0][-1]
68 |
69 | self.all_feat_names = [
70 | "fp2",
71 | ]
72 |
73 |         self.use_mlp = use_mlp
74 |         if use_mlp:
75 |             self.head = MLP(mlp_dim)
76 |
77 |
78 | def break_up_pc(self, pc):
79 | #batch_idx = pc[:, 0]
80 | xyz = pc[:, :, 0:3].contiguous()
81 | features = (pc[:, :, 3:].contiguous() if pc.size(-1) > 3 else None)
82 | return xyz, features
83 |
84 |     def forward(self, pointcloud: torch.cuda.FloatTensor, out_feat_keys=["fp2"]):
85 | """
86 | Args:
87 | batch_dict:
88 | batch_size: int
89 | vfe_features: (num_voxels, C)
90 | points: (num_points, 4 + C), [batch_idx, x, y, z, ...]
91 | Returns:
92 | batch_dict:
93 | encoded_spconv_tensor: sparse tensor
94 | point_features: (N, C)
95 | """
96 | batch_size = pointcloud.shape[0]
97 | points = pointcloud
98 | xyz, features = self.break_up_pc(points)
99 |
100 | features = features.view(batch_size, -1, features.shape[-1]).permute(0, 2, 1).contiguous() if features is not None else None
101 | l_xyz, l_features = [xyz], [features]
102 | for i in range(len(self.SA_modules)):
103 | assert l_features[i].is_contiguous()
104 | li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
105 | assert li_features.is_contiguous()
106 | l_xyz.append(li_xyz)
107 | l_features.append(li_features)
108 |
109 | for i in range(-1, -(len(self.FP_modules) + 1), -1):
110 | assert l_features[i - 1].is_contiguous()
111 | assert l_features[i].is_contiguous()
112 | l_features[i - 1] = self.FP_modules[i](
113 | l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
114 | ) # (B, C, N)
115 | assert l_features[i - 1].is_contiguous()
116 |
117 | point_features = l_features[0]
118 |
119 | end_points = {}
120 | end_points['fp2_features'] = point_features
121 |
122 | out_feats = [None] * len(out_feat_keys)
123 | for key in out_feat_keys:
124 | feat = end_points[key+"_features"]
125 | nump = feat.shape[-1]
126 |
127 | feat = torch.squeeze(F.max_pool1d(feat, nump))
128 | if self.use_mlp:
129 | feat = self.head(feat)
130 | out_feats[out_feat_keys.index(key)] = feat
131 |
132 | return out_feats
133 |
--------------------------------------------------------------------------------
/models/trunks/smlp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 |
9 | import torch.nn as nn
10 | import MinkowskiEngine as ME
11 |
12 | class SMLP(nn.Module):
13 | def __init__(
14 | self, dims, use_bn=False, use_relu=False, use_dropout=False, use_bias=True
15 | ):
16 | super().__init__()
17 | layers = []
18 | last_dim = dims[0]
19 | counter = 1
20 | for dim in dims[1:]:
21 | layers.append(ME.MinkowskiLinear(last_dim, dim, bias=use_bias))
22 | counter += 1
23 | if use_bn:
24 | layers.append(
25 | ME.MinkowskiBatchNorm(
26 | dim,
27 | eps=1e-5,
28 | momentum=0.1,
29 | )
30 | )
31 | if (counter < len(dims)) and use_relu:
32 | layers.append(ME.MinkowskiReLU(inplace=True))
33 | last_dim = dim
34 | if use_dropout:
35 |             layers.append(ME.MinkowskiDropout())
36 | self.clf = nn.Sequential(*layers)
37 |
38 | def forward(self, batch):
39 | out = self.clf(batch)
40 | return out
41 |
--------------------------------------------------------------------------------
/models/trunks/spconv/lib/math_functions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from scipy.sparse import csr_matrix
7 | import torch
8 |
9 |
10 | class SparseMM(torch.autograd.Function):
11 | """
12 | Sparse x dense matrix multiplication with autograd support.
13 | Implementation by Soumith Chintala:
14 | https://discuss.pytorch.org/t/
15 | does-pytorch-support-autograd-on-sparse-matrix/6156/7
16 | """
17 |
18 | def forward(self, matrix1, matrix2):
19 | self.save_for_backward(matrix1, matrix2)
20 | return torch.mm(matrix1, matrix2)
21 |
22 | def backward(self, grad_output):
23 | matrix1, matrix2 = self.saved_tensors
24 | grad_matrix1 = grad_matrix2 = None
25 |
26 | if self.needs_input_grad[0]:
27 | grad_matrix1 = torch.mm(grad_output, matrix2.t())
28 |
29 | if self.needs_input_grad[1]:
30 | grad_matrix2 = torch.mm(matrix1.t(), grad_output)
31 |
32 | return grad_matrix1, grad_matrix2
33 |
34 |
35 | def sparse_float_tensor(values, indices, size=None):
36 | """
37 |     Return a torch sparse matrix given values and indices (row_ind, col_ind).
38 | If the size is an integer, return a square matrix with side size.
39 | If the size is a torch.Size, use it to initialize the out tensor.
40 | If none, the size is inferred.
41 | """
42 |     indices = torch.stack(indices).long()  # sparse tensor indices must be int64
43 | sargs = [indices, values.float()]
44 | if size is not None:
45 | # Use the provided size
46 | if isinstance(size, int):
47 | size = torch.Size((size, size))
48 | sargs.append(size)
49 | if values.is_cuda:
50 | return torch.cuda.sparse.FloatTensor(*sargs)
51 | else:
52 | return torch.sparse.FloatTensor(*sargs)
53 |
54 |
55 | def diags(values, size=None):
56 | values = values.view(-1)
57 | n = values.nelement()
58 | size = torch.Size((n, n))
59 | indices = (torch.arange(0, n), torch.arange(0, n))
60 | return sparse_float_tensor(values, indices, size)
61 |
62 |
63 | def sparse_to_csr_matrix(tensor):
64 | tensor = tensor.cpu()
65 | inds = tensor._indices().numpy()
66 | vals = tensor._values().numpy()
67 | return csr_matrix((vals, (inds[0], inds[1])), shape=[s for s in tensor.shape])
68 |
69 |
70 | def csr_matrix_to_sparse(mat):
71 | row_ind, col_ind = mat.nonzero()
72 | return sparse_float_tensor(
73 | torch.from_numpy(mat.data),
74 | (torch.from_numpy(row_ind), torch.from_numpy(col_ind)),
75 | size=torch.Size(mat.shape))
76 |
--------------------------------------------------------------------------------
/models/trunks/spconv/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from models.trunks.spconv.models import resunet
7 | from models.trunks.spconv.models import res16unet
8 |
9 | # from models.trilateral_crf import TrilateralCRF
10 | from models.trunks.spconv.models.conditional_random_fields import BilateralCRF, TrilateralCRF
11 |
12 | MODELS = []
13 |
14 |
15 | def add_models(module):
16 | MODELS.extend([getattr(module, a) for a in dir(module) if 'Net' in a])
17 |
18 |
19 | add_models(resunet)
20 | add_models(res16unet)
21 |
22 | WRAPPERS = [BilateralCRF, TrilateralCRF]
23 |
24 |
25 | def get_models():
26 |     '''Returns the list of available model classes.'''
27 | return MODELS
28 |
29 |
30 | def get_wrappers():
31 | return WRAPPERS
32 |
33 |
34 | def load_model(name):
35 |     '''Returns the model class given its class name.
36 | '''
37 | # Find the model class from its name
38 | all_models = get_models()
39 | mdict = {model.__name__: model for model in all_models}
40 | if name not in mdict:
41 |         print('Invalid model name. Options are:')
42 | # Display a list of valid model names
43 | for model in all_models:
44 | print('\t* {}'.format(model.__name__))
45 | return None
46 | NetClass = mdict[name]
47 |
48 | return NetClass
49 |
50 |
51 | def load_wrapper(name):
52 |     '''Returns the wrapper class given its class name.
53 | '''
54 | # Find the model class from its name
55 | all_wrappers = get_wrappers()
56 | mdict = {wrapper.__name__: wrapper for wrapper in all_wrappers}
57 | if name not in mdict:
58 |         print('Invalid wrapper name. Options are:')
59 | # Display a list of valid model names
60 | for wrapper in all_wrappers:
61 | print('\t* {}'.format(wrapper.__name__))
62 | return None
63 | WrapperClass = mdict[name]
64 |
65 | return WrapperClass
66 |
--------------------------------------------------------------------------------
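A short usage sketch of the registry above. Note that load_model and load_wrapper return classes rather than instances; Res16UNet34 is one of the networks registered through add_models, and MinkowskiEngine must be installed for the imports to succeed:

from models.trunks.spconv.models import load_model

NetClass = load_model('Res16UNet34')   # the class, or None if the name is unknown
if NetClass is not None:
    print(NetClass.__name__)           # instantiate with whatever arguments the chosen NetClass defines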
/models/trunks/spconv/models/conditional_random_fields.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 |
10 | from MinkowskiEngine import SparseTensor, MinkowskiConvolution, MinkowskiConvolutionFunction, convert_to_int_tensor
11 | from MinkowskiEngine import convert_region_type as me_convert_region_type
12 |
13 | from models.trunks.spconv.models.model import HighDimensionalModel
14 | from models.trunks.spconv.models.wrapper import Wrapper
15 | from models.trunks.spconv.lib.math_functions import SparseMM
16 | from models.trunks.spconv.models.modules.common import convert_region_type
17 |
18 |
19 | class MeanField(HighDimensionalModel):
20 | """
21 | Abstract class for the bilateral and trilateral meanfield
22 | """
23 | OUT_PIXEL_DIST = 1
24 |
25 | # To use the model, must call initialize_coords before forward pass.
26 | # Once data is processed, call clear to reset the model before calling
27 | # initialize_coords
28 | def __init__(self, nchannels, spatial_sigma, chromatic_sigma, meanfield_iterations, is_temporal,
29 | config, **kwargs):
30 | D = 7 if is_temporal else 6
31 | self.is_temporal = is_temporal
32 | # Setup metadata
33 | super(MeanField, self).__init__(nchannels, nchannels, config, D=D)
34 |
35 | self.spatial_sigma = spatial_sigma
36 | self.chromatic_sigma = chromatic_sigma
37 | # temporal sigma is 1
38 | self.meanfield_iterations = meanfield_iterations
39 |
40 | self.pixel_dist = 1
41 | self.stride = 1
42 | self.dilation = 1
43 |
44 | conv = MinkowskiConvolution(
45 | nchannels,
46 | nchannels,
47 | kernel_size=config.wrapper_kernel_size,
48 | has_bias=False,
49 | region_type=convert_region_type(config.wrapper_region_type),
50 | dimension=D)
51 |
52 | # Create a region_offset
53 | self.region_type_, self.region_offset_, _ = me_convert_region_type(
54 | conv.region_type, 1, conv.kernel_size, conv.up_stride, conv.dilation, conv.region_offset,
55 | conv.axis_types, conv.dimension)
56 |
57 | # Check whether the mapping is required
58 | self.requires_mapping = False
59 | self.conv = conv
60 | self.kernel = conv.kernel
61 | self.convs = {}
62 | self.softmaxes = {}
63 | for i in range(self.meanfield_iterations):
64 | self.softmaxes[i] = nn.Softmax(dim=1)
65 | self.convs[i] = MinkowskiConvolutionFunction()
66 |
67 | def initialize_coords(self, model, in_coords, in_color):
68 | if torch.prod(convert_to_int_tensor(model.OUT_PIXEL_DIST, model.D)) != 1:
69 | self.requires_mapping = True
70 |
71 | out_coords = model.get_coords(model.OUT_PIXEL_DIST)
72 | out_color = model.permute_feature(in_color, model.OUT_PIXEL_DIST).int()
73 |
74 | # Tri/Bi-lateral grid
75 | out_tri_coords = torch.cat(
76 | [
77 | (torch.floor(out_coords[:, :3].float() / self.spatial_sigma)).int(),
78 | (torch.floor(out_color.float() / self.chromatic_sigma)).int(),
79 | out_coords[:, 3:] # (time and) batch
80 | ],
81 | dim=1)
82 | orig_tri_coords = torch.cat(
83 | [
84 | (torch.floor(in_coords[:, :3].float() / self.spatial_sigma)).int(),
85 | (torch.floor(in_color.float() / self.chromatic_sigma)).int(),
86 | in_coords[:, 3:] # (time and) batch
87 | ],
88 | dim=1)
89 |
90 | crf_tri_coords = torch.cat((out_tri_coords, orig_tri_coords), dim=0)
91 |
92 | # Create a trilateral Grid
93 | # super(MeanField, self).initialize_coords_with_duplicates(crf_tri_coords)
94 |
95 | # Create Sparse matrix mappings to/from the CRF coords
96 | in_cols = self.get_index_map(out_tri_coords, 1)
97 | self.in_mapping = torch.sparse.FloatTensor(
98 | torch.stack((in_cols.long(), torch.arange(in_cols.size(0), out=torch.LongTensor()))),
99 | torch.ones(in_cols.size(0)), torch.Size((self.n_rows, in_cols.size(0))))
100 |
101 | out_cols = self.get_index_map(orig_tri_coords, 1)
102 | self.out_mapping = torch.sparse.FloatTensor(
103 | torch.stack((torch.arange(out_cols.size(0), out=torch.LongTensor()), out_cols.long())),
104 | torch.ones(out_cols.size(0)), torch.Size((out_cols.size(0), self.n_rows)))
105 |
106 | if self.config.is_cuda:
107 | self.in_mapping, self.out_mapping = self.in_mapping.cuda(), self.out_mapping.cuda()
108 |
109 | else:
110 | self.requires_mapping = False
111 |
112 | out_coords = in_coords
113 | out_color = in_color
114 | crf_tri_coords = torch.cat(
115 | [
116 | (torch.floor(in_coords[:, :3].float() / self.spatial_sigma)).int(),
117 | (torch.floor(in_color.float() / self.chromatic_sigma)).int(),
118 | in_coords[:, 3:], # (time and) batch
119 | ],
120 | dim=1)
121 |
122 | return crf_tri_coords
123 |
124 | def forward(self, x):
125 | xf = x.F
126 | if self.requires_mapping:
127 | # Map the network output to CRF input
128 | xf = SparseMM()(Variable(self.in_mapping), xf)
129 |
130 | out = xf
131 | for i in range(self.meanfield_iterations): # Meanfield iteration
132 | # Normalization
133 | out = self.softmaxes[i](out)
134 | # Pairwise potential
135 | out = self.convs[i].apply(out, self.conv.kernel, x.pixel_dist, self.conv.stride,
136 | self.conv.kernel_size, self.conv.dilation, self.region_type_,
137 | self.region_offset_, x.coords_key, x.coords_key, x.coords_man)
138 | # Add unary
139 | out += xf
140 |
141 | if self.requires_mapping:
142 |             # Map the CRF output to the original space
143 | out = SparseMM()(Variable(self.out_mapping), out)
144 |
145 | return SparseTensor(out, coords_key=x.coords_key, coords_manager=x.coords_man)
146 |
147 |
148 | class BilateralCRF(Wrapper):
149 | OUT_PIXEL_DIST = 1
150 |
151 | def initialize_filter(self, NetClass, in_nchannel, out_nchannel, config):
152 | self.model = NetClass(in_nchannel, out_nchannel, config)
153 | self.filter = MeanField(
154 | out_nchannel,
155 | spatial_sigma=config.crf_spatial_sigma,
156 | chromatic_sigma=config.crf_chromatic_sigma,
157 | meanfield_iterations=config.meanfield_iterations,
158 | is_temporal=False,
159 | config=config)
160 |
161 |
162 | class TrilateralCRF(Wrapper):
163 | OUT_PIXEL_DIST = 1
164 |
165 | def initialize_filter(self, NetClass, in_nchannel, out_nchannel, config):
166 | self.model = NetClass(in_nchannel, out_nchannel, config)
167 | self.filter = MeanField(
168 | out_nchannel,
169 | spatial_sigma=config.crf_spatial_sigma,
170 | chromatic_sigma=config.crf_chromatic_sigma,
171 | meanfield_iterations=config.meanfield_iterations,
172 | is_temporal=True,
173 | config=config)
174 |
--------------------------------------------------------------------------------
/models/trunks/spconv/models/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from MinkowskiEngine import MinkowskiNetwork
7 |
8 |
9 | class Model(MinkowskiNetwork):
10 | """
11 | Base network for all sparse convnet
12 |
13 | By default, all networks are segmentation networks.
14 | """
15 | OUT_PIXEL_DIST = -1
16 |
17 | def __init__(self, in_channels, out_channels, D, **kwargs):
18 | super(Model, self).__init__(D)
19 | self.in_channels = in_channels
20 | self.out_channels = out_channels
21 | #self.config = config
22 |
23 |
24 | class HighDimensionalModel(Model):
25 | """
26 | Base network for all spatio (temporal) chromatic sparse convnet
27 | """
28 |
29 | def __init__(self, in_channels, out_channels, config, D, **kwargs):
30 |         assert D > 4, "Number of dimensions must be greater than 4"
31 |         super(HighDimensionalModel, self).__init__(in_channels, out_channels, D, **kwargs)
32 |         self.config = config
--------------------------------------------------------------------------------
/models/trunks/spconv/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/models/trunks/spconv/models/modules/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import collections.abc
8 | from enum import Enum
9 | import torch.nn as nn
10 |
11 | import MinkowskiEngine as ME
12 |
13 |
14 | class NormType(Enum):
15 | BATCH_NORM = 0
16 | INSTANCE_NORM = 1
17 | INSTANCE_BATCH_NORM = 2
18 |
19 |
20 | def get_norm(norm_type, n_channels, D, bn_momentum=0.1):
21 | if norm_type == NormType.BATCH_NORM:
22 | return ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum)
23 | elif norm_type == NormType.INSTANCE_NORM:
24 | return ME.MinkowskiInstanceNorm(n_channels)
25 | elif norm_type == NormType.INSTANCE_BATCH_NORM:
26 | return nn.Sequential(
27 | ME.MinkowskiInstanceNorm(n_channels),
28 | ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum))
29 | else:
30 | raise ValueError(f'Norm type: {norm_type} not supported')
31 |
32 |
33 | class ConvType(Enum):
34 | """
35 | Define the kernel region type
36 | """
37 | HYPERCUBE = 0, 'HYPERCUBE'
38 | SPATIAL_HYPERCUBE = 1, 'SPATIAL_HYPERCUBE'
39 | SPATIO_TEMPORAL_HYPERCUBE = 2, 'SPATIO_TEMPORAL_HYPERCUBE'
40 | HYPERCROSS = 3, 'HYPERCROSS'
41 | SPATIAL_HYPERCROSS = 4, 'SPATIAL_HYPERCROSS'
42 | SPATIO_TEMPORAL_HYPERCROSS = 5, 'SPATIO_TEMPORAL_HYPERCROSS'
43 |     SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS = 6, 'SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS'
44 |
45 | def __new__(cls, value, name):
46 | member = object.__new__(cls)
47 | member._value_ = value
48 | member.fullname = name
49 | return member
50 |
51 | def __int__(self):
52 | return self.value
53 |
54 |
55 | # Convert the ConvType var to a RegionType var
56 | conv_to_region_type = {
57 | # kernel_size = [k, k, k, 1]
58 | ConvType.HYPERCUBE: ME.RegionType.HYPERCUBE,
59 | ConvType.SPATIAL_HYPERCUBE: ME.RegionType.HYPERCUBE,
60 | ConvType.SPATIO_TEMPORAL_HYPERCUBE: ME.RegionType.HYPERCUBE,
61 | ConvType.HYPERCROSS: ME.RegionType.HYPERCROSS,
62 | ConvType.SPATIAL_HYPERCROSS: ME.RegionType.HYPERCROSS,
63 | ConvType.SPATIO_TEMPORAL_HYPERCROSS: ME.RegionType.HYPERCROSS,
64 | ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: ME.RegionType.HYBRID
65 | }
66 |
67 | int_to_region_type = {m.value: m for m in ME.RegionType}
68 |
69 |
70 | def convert_region_type(region_type):
71 | """
72 | Convert the integer region_type to the corresponding RegionType enum object.
73 | """
74 | return int_to_region_type[region_type]
75 |
76 |
77 | def convert_conv_type(conv_type, kernel_size, D):
78 | assert isinstance(conv_type, ConvType), "conv_type must be of ConvType"
79 | region_type = conv_to_region_type[conv_type]
80 | axis_types = None
81 | if conv_type == ConvType.SPATIAL_HYPERCUBE:
82 | # No temporal convolution
83 |         if isinstance(kernel_size, collections.abc.Sequence):
84 | kernel_size = kernel_size[:3]
85 | else:
86 | kernel_size = [
87 | kernel_size,
88 | ] * 3
89 | if D == 4:
90 | kernel_size.append(1)
91 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCUBE:
92 | # conv_type conversion already handled
93 | assert D == 4
94 | elif conv_type == ConvType.HYPERCUBE:
95 | # conv_type conversion already handled
96 | pass
97 | elif conv_type == ConvType.SPATIAL_HYPERCROSS:
98 |         if isinstance(kernel_size, collections.abc.Sequence):
99 | kernel_size = kernel_size[:3]
100 | else:
101 | kernel_size = [
102 | kernel_size,
103 | ] * 3
104 | if D == 4:
105 | kernel_size.append(1)
106 | elif conv_type == ConvType.HYPERCROSS:
107 | # conv_type conversion already handled
108 | pass
109 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCROSS:
110 | # conv_type conversion already handled
111 | assert D == 4
112 | elif conv_type == ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS:
113 | # Define the CUBIC conv kernel for spatial dims and CROSS conv for temp dim
114 | axis_types = [
115 | ME.RegionType.HYPERCUBE,
116 | ] * 3
117 | if D == 4:
118 | axis_types.append(ME.RegionType.HYPERCROSS)
119 | return region_type, axis_types, kernel_size
120 |
121 |
122 | def conv(in_planes,
123 | out_planes,
124 | kernel_size,
125 | stride=1,
126 | dilation=1,
127 | bias=False,
128 | conv_type=ConvType.HYPERCUBE,
129 | D=-1):
130 | assert D > 0, 'Dimension must be a positive integer'
131 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D)
132 | kernel_generator = ME.KernelGenerator(
133 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D)
134 |
135 | return ME.MinkowskiConvolution(
136 | in_channels=in_planes,
137 | out_channels=out_planes,
138 | kernel_size=kernel_size,
139 | stride=stride,
140 | dilation=dilation,
141 | has_bias=bias,
142 | kernel_generator=kernel_generator,
143 | dimension=D)
144 |
145 |
146 | def conv_tr(in_planes,
147 | out_planes,
148 | kernel_size,
149 | upsample_stride=1,
150 | dilation=1,
151 | bias=False,
152 | conv_type=ConvType.HYPERCUBE,
153 | D=-1):
154 | assert D > 0, 'Dimension must be a positive integer'
155 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D)
156 | kernel_generator = ME.KernelGenerator(
157 | kernel_size,
158 | upsample_stride,
159 | dilation,
160 | region_type=region_type,
161 | axis_types=axis_types,
162 | dimension=D)
163 |
164 | return ME.MinkowskiConvolutionTranspose(
165 | in_channels=in_planes,
166 | out_channels=out_planes,
167 | kernel_size=kernel_size,
168 | stride=upsample_stride,
169 | dilation=dilation,
170 | has_bias=bias,
171 | kernel_generator=kernel_generator,
172 | dimension=D)
173 |
174 |
175 | def avg_pool(kernel_size,
176 | stride=1,
177 | dilation=1,
178 | conv_type=ConvType.HYPERCUBE,
179 | in_coords_key=None,
180 | D=-1):
181 | assert D > 0, 'Dimension must be a positive integer'
182 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D)
183 | kernel_generator = ME.KernelGenerator(
184 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D)
185 |
186 | return ME.MinkowskiAvgPooling(
187 | kernel_size=kernel_size,
188 | stride=stride,
189 | dilation=dilation,
190 | kernel_generator=kernel_generator,
191 | dimension=D)
192 |
193 |
194 | def avg_unpool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1):
195 | assert D > 0, 'Dimension must be a positive integer'
196 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D)
197 | kernel_generator = ME.KernelGenerator(
198 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D)
199 |
200 | return ME.MinkowskiAvgUnpooling(
201 | kernel_size=kernel_size,
202 | stride=stride,
203 | dilation=dilation,
204 | kernel_generator=kernel_generator,
205 | dimension=D)
206 |
207 |
208 | def sum_pool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1):
209 | assert D > 0, 'Dimension must be a positive integer'
210 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D)
211 | kernel_generator = ME.KernelGenerator(
212 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D)
213 |
214 | return ME.MinkowskiSumPooling(
215 | kernel_size=kernel_size,
216 | stride=stride,
217 | dilation=dilation,
218 | kernel_generator=kernel_generator,
219 | dimension=D)
220 |
--------------------------------------------------------------------------------
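A sketch of the helpers above building a sparse 3D convolution plus its norm layer, as the residual blocks in resnet_block.py do; this assumes the MinkowskiEngine 0.4.x API that this code targets:

from models.trunks.spconv.models.modules.common import ConvType, NormType, conv, get_norm

c = conv(in_planes=32, out_planes=64, kernel_size=3, stride=2,
         conv_type=ConvType.HYPERCUBE, D=3)            # 3x3x3 hypercube kernel
bn = get_norm(NormType.BATCH_NORM, 64, D=3, bn_momentum=0.1)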
/models/trunks/spconv/models/modules/resnet_block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import torch.nn as nn
7 |
8 | from models.trunks.spconv.models.modules.common import ConvType, NormType, get_norm, conv
9 |
10 | from MinkowskiEngine import MinkowskiReLU
11 |
12 |
13 | class BasicBlockBase(nn.Module):
14 | expansion = 1
15 | NORM_TYPE = NormType.BATCH_NORM
16 |
17 | def __init__(self,
18 | inplanes,
19 | planes,
20 | stride=1,
21 | dilation=1,
22 | downsample=None,
23 | conv_type=ConvType.HYPERCUBE,
24 | bn_momentum=0.1,
25 | D=3):
26 | super(BasicBlockBase, self).__init__()
27 |
28 | self.conv1 = conv(
29 | inplanes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D)
30 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
31 | self.conv2 = conv(
32 | planes,
33 | planes,
34 | kernel_size=3,
35 | stride=1,
36 | dilation=dilation,
37 | bias=False,
38 | conv_type=conv_type,
39 | D=D)
40 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
41 | self.relu = MinkowskiReLU(inplace=True)
42 | self.downsample = downsample
43 |
44 | def forward(self, x):
45 | residual = x
46 |
47 | out = self.conv1(x)
48 | out = self.norm1(out)
49 | out = self.relu(out)
50 |
51 | out = self.conv2(out)
52 | out = self.norm2(out)
53 |
54 | if self.downsample is not None:
55 | residual = self.downsample(x)
56 |
57 | out += residual
58 | out = self.relu(out)
59 |
60 | return out
61 |
62 |
63 | class BasicBlock(BasicBlockBase):
64 | NORM_TYPE = NormType.BATCH_NORM
65 |
66 |
67 | class BasicBlockIN(BasicBlockBase):
68 | NORM_TYPE = NormType.INSTANCE_NORM
69 |
70 |
71 | class BasicBlockINBN(BasicBlockBase):
72 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM
73 |
74 |
75 | class BottleneckBase(nn.Module):
76 | expansion = 4
77 | NORM_TYPE = NormType.BATCH_NORM
78 |
79 | def __init__(self,
80 | inplanes,
81 | planes,
82 | stride=1,
83 | dilation=1,
84 | downsample=None,
85 | conv_type=ConvType.HYPERCUBE,
86 | bn_momentum=0.1,
87 | D=3):
88 | super(BottleneckBase, self).__init__()
89 | self.conv1 = conv(inplanes, planes, kernel_size=1, D=D)
90 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
91 |
92 | self.conv2 = conv(
93 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D)
94 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
95 |
96 | self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, D=D)
97 | self.norm3 = get_norm(self.NORM_TYPE, planes * self.expansion, D, bn_momentum=bn_momentum)
98 |
99 | self.relu = MinkowskiReLU(inplace=True)
100 | self.downsample = downsample
101 |
102 | def forward(self, x):
103 | residual = x
104 |
105 | out = self.conv1(x)
106 | out = self.norm1(out)
107 | out = self.relu(out)
108 |
109 | out = self.conv2(out)
110 | out = self.norm2(out)
111 | out = self.relu(out)
112 |
113 | out = self.conv3(out)
114 | out = self.norm3(out)
115 |
116 | if self.downsample is not None:
117 | residual = self.downsample(x)
118 |
119 | out += residual
120 | out = self.relu(out)
121 |
122 | return out
123 |
124 |
125 | class Bottleneck(BottleneckBase):
126 | NORM_TYPE = NormType.BATCH_NORM
127 |
128 |
129 | class BottleneckIN(BottleneckBase):
130 | NORM_TYPE = NormType.INSTANCE_NORM
131 |
132 |
133 | class BottleneckINBN(BottleneckBase):
134 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM
135 |
--------------------------------------------------------------------------------
/models/trunks/spconv/models/modules/senet_block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import torch.nn as nn
7 |
8 | import MinkowskiEngine as ME
9 |
10 | from models.trunks.spconv.models.modules.common import ConvType, NormType
11 | from models.trunks.spconv.models.modules.resnet_block import BasicBlock, Bottleneck
12 |
13 |
14 | class SELayer(nn.Module):
15 |
16 | def __init__(self, channel, reduction=16, D=-1):
17 | # Global coords does not require coords_key
18 | super(SELayer, self).__init__()
19 | self.fc = nn.Sequential(
20 | ME.MinkowskiLinear(channel, channel // reduction), ME.MinkowskiReLU(inplace=True),
21 | ME.MinkowskiLinear(channel // reduction, channel), ME.MinkowskiSigmoid())
22 | self.pooling = ME.MinkowskiGlobalPooling(dimension=D)
23 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication(dimension=D)
24 |
25 | def forward(self, x):
26 | y = self.pooling(x)
27 | y = self.fc(y)
28 | return self.broadcast_mul(x, y)
29 |
30 |
31 | class SEBasicBlock(BasicBlock):
32 |
33 | def __init__(self,
34 | inplanes,
35 | planes,
36 | stride=1,
37 | dilation=1,
38 | downsample=None,
39 | conv_type=ConvType.HYPERCUBE,
40 | reduction=16,
41 | D=-1):
42 | super(SEBasicBlock, self).__init__(
43 | inplanes,
44 | planes,
45 | stride=stride,
46 | dilation=dilation,
47 | downsample=downsample,
48 | conv_type=conv_type,
49 | D=D)
50 | self.se = SELayer(planes, reduction=reduction, D=D)
51 |
52 | def forward(self, x):
53 | residual = x
54 |
55 | out = self.conv1(x)
56 | out = self.norm1(out)
57 | out = self.relu(out)
58 |
59 | out = self.conv2(out)
60 | out = self.norm2(out)
61 | out = self.se(out)
62 |
63 | if self.downsample is not None:
64 | residual = self.downsample(x)
65 |
66 | out += residual
67 | out = self.relu(out)
68 |
69 | return out
70 |
71 |
72 | class SEBasicBlockSN(SEBasicBlock):
73 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM
74 |
75 |
76 | class SEBasicBlockIN(SEBasicBlock):
77 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
78 |
79 |
80 | class SEBasicBlockLN(SEBasicBlock):
81 | NORM_TYPE = NormType.SPARSE_LAYER_NORM
82 |
83 |
84 | class SEBottleneck(Bottleneck):
85 |
86 | def __init__(self,
87 | inplanes,
88 | planes,
89 | stride=1,
90 | dilation=1,
91 | downsample=None,
92 | conv_type=ConvType.HYPERCUBE,
93 | D=3,
94 | reduction=16):
95 | super(SEBottleneck, self).__init__(
96 | inplanes,
97 | planes,
98 | stride=stride,
99 | dilation=dilation,
100 | downsample=downsample,
101 | conv_type=conv_type,
102 | D=D)
103 | self.se = SELayer(planes * self.expansion, reduction=reduction, D=D)
104 |
105 | def forward(self, x):
106 | residual = x
107 |
108 | out = self.conv1(x)
109 | out = self.norm1(out)
110 | out = self.relu(out)
111 |
112 | out = self.conv2(out)
113 | out = self.norm2(out)
114 | out = self.relu(out)
115 |
116 | out = self.conv3(out)
117 | out = self.norm3(out)
118 | out = self.se(out)
119 |
120 | if self.downsample is not None:
121 | residual = self.downsample(x)
122 |
123 | out += residual
124 | out = self.relu(out)
125 |
126 | return out
127 |
128 |
129 | class SEBottleneckSN(SEBottleneck):
130 | NORM_TYPE = NormType.SPARSE_SWITCH_NORM
131 |
132 |
133 | class SEBottleneckIN(SEBottleneck):
134 | NORM_TYPE = NormType.SPARSE_INSTANCE_NORM
135 |
136 |
137 | class SEBottleneckLN(SEBottleneck):
138 | NORM_TYPE = NormType.SPARSE_LAYER_NORM
139 |
--------------------------------------------------------------------------------
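
The SELayer above is squeeze-and-excitation on sparse tensors: a global pooling squeezes each point cloud to one descriptor, two MinkowskiLinear layers with a ReLU and a sigmoid turn it into per-channel gates, and the gates are broadcast-multiplied back onto every voxel before the residual addition in the SE blocks. For intuition, the same gating restated on a dense tensor (a minimal sketch in plain PyTorch; the class name and shapes are illustrative, not part of this repository):

import torch
import torch.nn as nn

class DenseSELayer(nn.Module):
    """Dense analogue of SELayer: global pool -> bottleneck MLP -> per-channel gates."""
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel), nn.Sigmoid())

    def forward(self, x):            # x: (batch, channel, num_points)
        y = x.mean(dim=2)            # squeeze: average over all points
        y = self.fc(y)               # excite: gates in [0, 1], one per channel
        return x * y.unsqueeze(-1)   # rescale every point's features
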
/models/trunks/spconv/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import torch.nn as nn
7 |
8 | import MinkowskiEngine as ME
9 |
10 | from models.trunks.spconv.models.model import Model
11 | from models.trunks.spconv.models.modules.common import ConvType, NormType, get_norm, conv, sum_pool
12 | from models.trunks.spconv.models.modules.resnet_block import BasicBlock, Bottleneck
13 |
14 |
15 | class ResNetBase(Model):
16 | BLOCK = None
17 | LAYERS = ()
18 | INIT_DIM = 64
19 | PLANES = (64, 128, 256, 512)
20 | OUT_PIXEL_DIST = 32
21 | HAS_LAST_BLOCK = False
22 | CONV_TYPE = ConvType.HYPERCUBE
23 |
 24 |   def __init__(self, in_channels=1, out_channels=1, config=None, D=3, **kwargs):
25 | assert self.BLOCK is not None
26 | assert self.OUT_PIXEL_DIST > 0
27 |
28 | super(ResNetBase, self).__init__(in_channels, out_channels, D, **kwargs)
29 |
 30 |     self.network_initialization(in_channels, out_channels, config, D)
31 | self.weight_initialization()
32 |
 33 |   def network_initialization(self, in_channels, out_channels, config, D):
34 |
35 | def space_n_time_m(n, m):
36 | return n if D == 3 else [n, n, n, m]
37 |
38 | if D == 4:
39 | self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)
40 |
41 | dilations = config.dilations
42 | bn_momentum = config.bn_momentum
43 | self.inplanes = self.INIT_DIM
44 | self.conv1 = conv(
45 | in_channels,
46 | self.inplanes,
47 | kernel_size=space_n_time_m(config.conv1_kernel_size, 1),
48 | stride=1,
49 | D=D)
50 |
51 | self.bn1 = get_norm(NormType.BATCH_NORM, self.inplanes, D=self.D, bn_momentum=bn_momentum)
52 | self.relu = ME.MinkowskiReLU(inplace=True)
53 | self.pool = sum_pool(kernel_size=space_n_time_m(2, 1), stride=space_n_time_m(2, 1), D=D)
54 |
55 | self.layer1 = self._make_layer(
56 | self.BLOCK,
57 | self.PLANES[0],
58 | self.LAYERS[0],
59 | stride=space_n_time_m(2, 1),
60 | dilation=space_n_time_m(dilations[0], 1))
61 | self.layer2 = self._make_layer(
62 | self.BLOCK,
63 | self.PLANES[1],
64 | self.LAYERS[1],
65 | stride=space_n_time_m(2, 1),
66 | dilation=space_n_time_m(dilations[1], 1))
67 | self.layer3 = self._make_layer(
68 | self.BLOCK,
69 | self.PLANES[2],
70 | self.LAYERS[2],
71 | stride=space_n_time_m(2, 1),
72 | dilation=space_n_time_m(dilations[2], 1))
73 | self.layer4 = self._make_layer(
74 | self.BLOCK,
75 | self.PLANES[3],
76 | self.LAYERS[3],
77 | stride=space_n_time_m(2, 1),
78 | dilation=space_n_time_m(dilations[3], 1))
79 |
80 | self.final = conv(
81 | self.PLANES[3] * self.BLOCK.expansion, out_channels, kernel_size=1, bias=True, D=D)
82 |
83 | def weight_initialization(self):
84 | for m in self.modules():
85 | if isinstance(m, ME.MinkowskiBatchNorm):
86 | nn.init.constant_(m.bn.weight, 1)
87 | nn.init.constant_(m.bn.bias, 0)
88 |
89 | def _make_layer(self,
90 | block,
91 | planes,
92 | blocks,
93 | stride=1,
94 | dilation=1,
95 | norm_type=NormType.BATCH_NORM,
96 | bn_momentum=0.1):
97 | downsample = None
98 | if stride != 1 or self.inplanes != planes * block.expansion:
99 | downsample = nn.Sequential(
100 | conv(
101 | self.inplanes,
102 | planes * block.expansion,
103 | kernel_size=1,
104 | stride=stride,
105 | bias=False,
106 | D=self.D),
107 | get_norm(norm_type, planes * block.expansion, D=self.D, bn_momentum=bn_momentum),
108 | )
109 | layers = []
110 | layers.append(
111 | block(
112 | self.inplanes,
113 | planes,
114 | stride=stride,
115 | dilation=dilation,
116 | downsample=downsample,
117 | conv_type=self.CONV_TYPE,
118 | D=self.D))
119 | self.inplanes = planes * block.expansion
120 | for i in range(1, blocks):
121 | layers.append(
122 | block(
123 | self.inplanes,
124 | planes,
125 | stride=1,
126 | dilation=dilation,
127 | conv_type=self.CONV_TYPE,
128 | D=self.D))
129 |
130 | return nn.Sequential(*layers)
131 |
132 | def forward(self, x):
133 | x = self.conv1(x)
134 | x = self.bn1(x)
135 | x = self.relu(x)
136 | x = self.pool(x)
137 |
138 | x = self.layer1(x)
139 | x = self.layer2(x)
140 | x = self.layer3(x)
141 | x = self.layer4(x)
142 |
143 | x = self.final(x)
144 | return x
145 |
146 |
147 | class ResNet14(ResNetBase):
148 | BLOCK = BasicBlock
149 | LAYERS = (1, 1, 1, 1)
150 |
151 |
152 | class ResNet18(ResNetBase):
153 | BLOCK = BasicBlock
154 | LAYERS = (2, 2, 2, 2)
155 |
156 |
157 | class ResNet34(ResNetBase):
158 | BLOCK = BasicBlock
159 | LAYERS = (3, 4, 6, 3)
160 |
161 |
162 | class ResNet50(ResNetBase):
163 | BLOCK = Bottleneck
164 | LAYERS = (3, 4, 6, 3)
165 |
166 |
167 | class ResNet101(ResNetBase):
168 | BLOCK = Bottleneck
169 | LAYERS = (3, 4, 23, 3)
170 |
171 |
172 | class STResNetBase(ResNetBase):
173 |
174 | CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS
175 |
176 | def __init__(self, in_channels, out_channels, config, D=4, **kwargs):
177 | super(STResNetBase, self).__init__(in_channels, out_channels, config, D, **kwargs)
178 |
179 |
180 | class STResNet14(STResNetBase, ResNet14):
181 | pass
182 |
183 |
184 | class STResNet18(STResNetBase, ResNet18):
185 | pass
186 |
187 |
188 | class STResNet34(STResNetBase, ResNet34):
189 | pass
190 |
191 |
192 | class STResNet50(STResNetBase, ResNet50):
193 | pass
194 |
195 |
196 | class STResNet101(STResNetBase, ResNet101):
197 | pass
198 |
199 |
200 | class STResTesseractNetBase(STResNetBase):
201 | CONV_TYPE = ConvType.HYPERCUBE
202 |
203 |
204 | class STResTesseractNet14(STResTesseractNetBase, STResNet14):
205 | pass
206 |
207 |
208 | class STResTesseractNet18(STResTesseractNetBase, STResNet18):
209 | pass
210 |
211 |
212 | class STResTesseractNet34(STResTesseractNetBase, STResNet34):
213 | pass
214 |
215 |
216 | class STResTesseractNet50(STResTesseractNetBase, STResNet50):
217 | pass
218 |
219 |
220 | class STResTesseractNet101(STResTesseractNetBase, STResNet101):
221 | pass
222 |
--------------------------------------------------------------------------------
/models/trunks/spconv/models/wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import random
7 | from torch.nn import Module
8 |
9 | from MinkowskiEngine import SparseTensor
10 |
11 |
12 | class Wrapper(Module):
13 | """
14 | Wrapper for the segmentation networks.
15 | """
16 | OUT_PIXEL_DIST = -1
17 |
18 | def __init__(self, NetClass, in_nchannel, out_nchannel, config):
19 | super(Wrapper, self).__init__()
20 | self.initialize_filter(NetClass, in_nchannel, out_nchannel, config)
21 |
22 | def initialize_filter(self, NetClass, in_nchannel, out_nchannel, config):
23 | raise NotImplementedError('Must initialize a model and a filter')
24 |
25 | def forward(self, x, coords, colors=None):
26 | soutput = self.model(x)
27 |
28 | # During training, make the network invariant to the filter
29 | if not self.training or random.random() < 0.5:
30 | # Filter requires the model to finish the forward pass
31 | wrapper_coords = self.filter.initialize_coords(self.model, coords, colors)
32 | finput = SparseTensor(soutput.F, wrapper_coords)
33 | soutput = self.filter(finput)
34 |
35 | return soutput
36 |
--------------------------------------------------------------------------------
/models/trunks/spconv_backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | from functools import partial
8 |
9 | import spconv
10 | import torch.nn as nn
11 |
12 |
13 | def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0,
14 | conv_type='subm', norm_fn=None):
15 |
16 | if conv_type == 'subm':
17 | conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, bias=False, indice_key=indice_key)
18 | elif conv_type == 'spconv':
19 | conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
20 | bias=False, indice_key=indice_key)
21 | elif conv_type == 'inverseconv':
22 | conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, indice_key=indice_key, bias=False)
23 | else:
24 | raise NotImplementedError
25 |
26 | m = spconv.SparseSequential(
27 | conv,
28 | norm_fn(out_channels),
29 | nn.ReLU(),
30 | )
31 |
32 | return m
33 |
34 |
35 | class SparseBasicBlock(spconv.SparseModule):
36 | expansion = 1
37 |
38 | def __init__(self, inplanes, planes, stride=1, norm_fn=None, downsample=None, indice_key=None):
39 | super(SparseBasicBlock, self).__init__()
40 |
41 | assert norm_fn is not None
42 | bias = norm_fn is not None
43 | self.conv1 = spconv.SubMConv3d(
44 | inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key
45 | )
46 | self.bn1 = norm_fn(planes)
47 | self.relu = nn.ReLU()
48 | self.conv2 = spconv.SubMConv3d(
49 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key
50 | )
51 | self.bn2 = norm_fn(planes)
52 | self.downsample = downsample
53 | self.stride = stride
54 |
55 | def forward(self, x):
56 | identity = x
57 |
58 | out = self.conv1(x)
59 | out.features = self.bn1(out.features)
60 | out.features = self.relu(out.features)
61 |
62 | out = self.conv2(out)
63 | out.features = self.bn2(out.features)
64 |
65 | if self.downsample is not None:
66 | identity = self.downsample(x)
67 |
68 | out.features += identity.features
69 | out.features = self.relu(out.features)
70 |
71 | return out
72 |
73 |
74 | class VoxelBackBone8x(nn.Module):
75 | def __init__(self, model_cfg, input_channels, grid_size, **kwargs):
76 | super().__init__()
77 | self.model_cfg = model_cfg
78 | norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
79 |
80 | self.sparse_shape = grid_size[::-1] + [1, 0, 0]
81 |
82 | self.conv_input = spconv.SparseSequential(
83 | spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'),
84 | norm_fn(16),
85 | nn.ReLU(),
86 | )
87 | block = post_act_block
88 |
89 | self.conv1 = spconv.SparseSequential(
90 | block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1'),
91 | )
92 |
93 | self.conv2 = spconv.SparseSequential(
94 | # [1600, 1408, 41] <- [800, 704, 21]
95 | block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'),
96 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'),
97 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'),
98 | )
99 |
100 | self.conv3 = spconv.SparseSequential(
101 | # [800, 704, 21] <- [400, 352, 11]
102 | block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'),
103 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'),
104 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'),
105 | )
106 |
107 | self.conv4 = spconv.SparseSequential(
108 | # [400, 352, 11] <- [200, 176, 5]
109 | block(64, 64, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'),
110 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
111 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
112 | )
113 |
114 | last_pad = 0
115 | last_pad = self.model_cfg.get('last_pad', last_pad)
116 | self.conv_out = spconv.SparseSequential(
117 | # [200, 150, 5] -> [200, 150, 2]
118 | spconv.SparseConv3d(64, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad,
119 | bias=False, indice_key='spconv_down2'),
120 | norm_fn(128),
121 | nn.ReLU(),
122 | )
123 | self.num_point_features = 128
124 |
125 | def forward(self, batch_dict):
126 | """
127 | Args:
128 | batch_dict:
129 | batch_size: int
130 |                 voxel_features: (num_voxels, C)
131 | voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx]
132 | Returns:
133 | batch_dict:
134 | encoded_spconv_tensor: sparse tensor
135 | """
136 | voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords']
137 | batch_size = batch_dict['batch_size']
138 | input_sp_tensor = spconv.SparseConvTensor(
139 | features=voxel_features,
140 | indices=voxel_coords.int(),
141 | spatial_shape=self.sparse_shape,
142 | batch_size=batch_size
143 | )
144 |
145 | x = self.conv_input(input_sp_tensor)
146 |
147 | x_conv1 = self.conv1(x)
148 | x_conv2 = self.conv2(x_conv1)
149 | x_conv3 = self.conv3(x_conv2)
150 | x_conv4 = self.conv4(x_conv3)
151 |
152 | # for detection head
153 | # [200, 176, 5] -> [200, 176, 2]
154 | out = self.conv_out(x_conv4)
155 |
156 | batch_dict.update({
157 | 'encoded_spconv_tensor': out,
158 | 'encoded_spconv_tensor_stride': 8
159 | })
160 | batch_dict.update({
161 | 'multi_scale_3d_features': {
162 | 'x_conv1': x_conv1,
163 | 'x_conv2': x_conv2,
164 | 'x_conv3': x_conv3,
165 | 'x_conv4': x_conv4,
166 | }
167 | })
168 |
169 | return batch_dict
170 |
171 |
172 | class VoxelResBackBone8x(nn.Module):
173 | def __init__(self, model_cfg, input_channels, grid_size, **kwargs):
174 | super().__init__()
175 | self.model_cfg = model_cfg
176 | norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
177 |
178 | self.sparse_shape = grid_size[::-1] + [1, 0, 0]
179 |
180 | self.conv_input = spconv.SparseSequential(
181 | spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'),
182 | norm_fn(16),
183 | nn.ReLU(),
184 | )
185 | block = post_act_block
186 |
187 | self.conv1 = spconv.SparseSequential(
188 | SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='res1'),
189 | SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='res1'),
190 | )
191 |
192 | self.conv2 = spconv.SparseSequential(
193 | # [1600, 1408, 41] <- [800, 704, 21]
194 | block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'),
195 | SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res2'),
196 | SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res2'),
197 | )
198 |
199 | self.conv3 = spconv.SparseSequential(
200 | # [800, 704, 21] <- [400, 352, 11]
201 | block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'),
202 | SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res3'),
203 | SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res3'),
204 | )
205 |
206 | self.conv4 = spconv.SparseSequential(
207 | # [400, 352, 11] <- [200, 176, 5]
208 | block(64, 128, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'),
209 | SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res4'),
210 | SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res4'),
211 | )
212 |
213 | last_pad = 0
214 | last_pad = self.model_cfg.get('last_pad', last_pad)
215 | self.conv_out = spconv.SparseSequential(
216 | # [200, 150, 5] -> [200, 150, 2]
217 | spconv.SparseConv3d(128, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad,
218 | bias=False, indice_key='spconv_down2'),
219 | norm_fn(128),
220 | nn.ReLU(),
221 | )
222 | self.num_point_features = 128
223 |
224 | def forward(self, batch_dict):
225 | """
226 | Args:
227 | batch_dict:
228 | batch_size: int
229 |                 voxel_features: (num_voxels, C)
230 | voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx]
231 | Returns:
232 | batch_dict:
233 | encoded_spconv_tensor: sparse tensor
234 | """
235 | voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords']
236 | batch_size = batch_dict['batch_size']
237 | input_sp_tensor = spconv.SparseConvTensor(
238 | features=voxel_features,
239 | indices=voxel_coords.int(),
240 | spatial_shape=self.sparse_shape,
241 | batch_size=batch_size
242 | )
243 | x = self.conv_input(input_sp_tensor)
244 |
245 | x_conv1 = self.conv1(x)
246 | x_conv2 = self.conv2(x_conv1)
247 | x_conv3 = self.conv3(x_conv2)
248 | x_conv4 = self.conv4(x_conv3)
249 |
250 | # for detection head
251 | # [200, 176, 5] -> [200, 176, 2]
252 | out = self.conv_out(x_conv4)
253 |
254 | batch_dict.update({
255 | 'encoded_spconv_tensor': out,
256 | 'encoded_spconv_tensor_stride': 8
257 | })
258 | batch_dict.update({
259 | 'multi_scale_3d_features': {
260 | 'x_conv1': x_conv1,
261 | 'x_conv2': x_conv2,
262 | 'x_conv3': x_conv3,
263 | 'x_conv4': x_conv4,
264 | }
265 | })
266 |
267 | return batch_dict
268 |
--------------------------------------------------------------------------------
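
Both backbones above take a batch_dict holding voxel_features of shape (num_voxels, C) and voxel_coords of shape (num_voxels, 4) ordered [batch_idx, z_idx, y_idx, x_idx], and return the same dict extended with the stride-8 encoded_spconv_tensor plus the four multi_scale_3d_features. A minimal construction sketch, assuming spconv and a CUDA device are available (the grid, ranges, and random inputs below are illustrative, not values used by this repository):

import numpy as np
import torch
from models.trunks.spconv_backbone import VoxelBackBone8x

# Illustrative KITTI-like voxel grid: gives the [41, 1600, 1408] sparse shape in the comments above.
point_cloud_range = np.array([0., -40., -3., 70.4, 40., 1.], dtype=np.float32)
voxel_size = np.array([0.05, 0.05, 0.1], dtype=np.float32)
grid_size = np.round((point_cloud_range[3:6] - point_cloud_range[0:3]) / voxel_size).astype(np.int64)

backbone = VoxelBackBone8x(model_cfg={}, input_channels=4, grid_size=grid_size).cuda()

num_voxels = 1000
batch_dict = {
    'batch_size': 2,
    'voxel_features': torch.randn(num_voxels, 4).cuda(),
    'voxel_coords': torch.stack([                      # [batch_idx, z_idx, y_idx, x_idx]
        torch.randint(0, 2, (num_voxels,)),
        torch.randint(0, int(grid_size[2]), (num_voxels,)),
        torch.randint(0, int(grid_size[1]), (num_voxels,)),
        torch.randint(0, int(grid_size[0]), (num_voxels,)),
    ], dim=1).cuda(),
}
out = backbone(batch_dict)
print(out['encoded_spconv_tensor'].spatial_shape, out['encoded_spconv_tensor_stride'])
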
/models/trunks/spconv_unet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | from functools import partial
8 |
9 | import spconv
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from models.trunks.mlp import MLP
15 |
16 | from .spconv_backbone import post_act_block
17 |
18 | class SparseBasicBlock(spconv.SparseModule):
19 | expansion = 1
20 |
21 | def __init__(self, inplanes, planes, stride=1, downsample=None, indice_key=None, norm_fn=None):
22 | super(SparseBasicBlock, self).__init__()
23 | self.conv1 = spconv.SubMConv3d(
24 | inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False, indice_key=indice_key
25 | )
26 | self.bn1 = norm_fn(planes)
27 | self.relu = nn.ReLU()
28 | self.conv2 = spconv.SubMConv3d(
29 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False, indice_key=indice_key
30 | )
31 | self.bn2 = norm_fn(planes)
32 | self.downsample = downsample
33 | self.stride = stride
34 |
35 | def forward(self, x):
36 | identity = x.features
37 |
38 | assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim()
39 |
40 | out = self.conv1(x)
41 | out.features = self.bn1(out.features)
42 | out.features = self.relu(out.features)
43 |
44 | out = self.conv2(out)
45 | out.features = self.bn2(out.features)
46 |
47 | if self.downsample is not None:
48 | identity = self.downsample(x)
49 |
50 | out.features += identity
51 | out.features = self.relu(out.features)
52 |
53 | return out
54 |
55 | import numpy as np
56 |
57 | class UNetV2_concat(nn.Module):
58 | """
59 | Sparse Convolution based UNet for point-wise feature learning.
 60 |     Reference Paper: https://arxiv.org/abs/1907.03670 (Shaoshuai Shi et al.)
61 | From Points to Parts: 3D Object Detection from Point Cloud with Part-aware and Part-aggregation Network
62 | """
63 | def __init__(self, use_mlp=False, mlp_dim=None):
64 | super().__init__()
65 |
66 | input_channels = 4
67 | voxel_size = [0.1, 0.1, 0.2]
68 | point_cloud_range = np.array([ 0. , -75. , -3. , 75.0, 75. , 3. ], dtype=np.float32)
69 |
70 | grid_size = (point_cloud_range[3:6] - point_cloud_range[0:3]) / np.array(voxel_size)
71 | grid_size = np.round(grid_size).astype(np.int64)
72 | model_cfg = {'NAME': 'UNetV2', 'RETURN_ENCODED_TENSOR': False}
73 |
74 | self.sparse_shape = grid_size[::-1] + [1, 0, 0]
75 | self.voxel_size = voxel_size
76 | self.point_cloud_range = point_cloud_range
77 |
78 | norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
79 |
80 | self.conv_input = spconv.SparseSequential(
81 | spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'),
82 | norm_fn(16),
83 | nn.ReLU(),
84 | )
85 | block = post_act_block
86 |
87 | self.conv1 = spconv.SparseSequential(
88 | block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1'),
89 | )
90 |
91 | self.conv2 = spconv.SparseSequential(
92 | # [1600, 1408, 41] <- [800, 704, 21]
93 | block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'),
94 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'),
95 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'),
96 | )
97 |
98 | self.conv3 = spconv.SparseSequential(
99 | # [800, 704, 21] <- [400, 352, 11]
100 | block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'),
101 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'),
102 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'),
103 | )
104 |
105 | self.conv4 = spconv.SparseSequential(
106 | # [400, 352, 11] <- [200, 176, 5]
107 | block(64, 64, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'),
108 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
109 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
110 | )
111 |
112 | self.conv_out = None
113 |
114 | # decoder
115 | # [400, 352, 11] <- [200, 176, 5]
116 | self.conv_up_t4 = SparseBasicBlock(64, 64, indice_key='subm4', norm_fn=norm_fn)
117 | self.conv_up_m4 = block(128, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4')
118 | self.inv_conv4 = block(64, 64, 3, norm_fn=norm_fn, indice_key='spconv4', conv_type='inverseconv')
119 |
120 | # [800, 704, 21] <- [400, 352, 11]
121 | self.conv_up_t3 = SparseBasicBlock(64, 64, indice_key='subm3', norm_fn=norm_fn)
122 | self.conv_up_m3 = block(128, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3')
123 | self.inv_conv3 = block(64, 32, 3, norm_fn=norm_fn, indice_key='spconv3', conv_type='inverseconv')
124 |
125 | # [1600, 1408, 41] <- [800, 704, 21]
126 | self.conv_up_t2 = SparseBasicBlock(32, 32, indice_key='subm2', norm_fn=norm_fn)
127 | self.conv_up_m2 = block(64, 32, 3, norm_fn=norm_fn, indice_key='subm2')
128 | self.inv_conv2 = block(32, 16, 3, norm_fn=norm_fn, indice_key='spconv2', conv_type='inverseconv')
129 |
130 | # [1600, 1408, 41] <- [1600, 1408, 41]
131 | self.conv_up_t1 = SparseBasicBlock(16, 16, indice_key='subm1', norm_fn=norm_fn)
132 | self.conv_up_m1 = block(32, 16, 3, norm_fn=norm_fn, indice_key='subm1')
133 |
134 | self.conv5 = spconv.SparseSequential(
135 | block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1')
136 | )
137 | self.num_point_features = 16
138 |
139 | self.all_feat_names = [
140 | "conv4",
141 | ]
142 |
143 |         self.use_mlp = use_mlp
144 |         if use_mlp:
145 |             self.head = MLP(mlp_dim)
146 |
147 | def UR_block_forward(self, x_lateral, x_bottom, conv_t, conv_m, conv_inv):
148 | x_trans = conv_t(x_lateral)
149 | x = x_trans
150 | x.features = torch.cat((x_bottom.features, x_trans.features), dim=1)
151 | x_m = conv_m(x)
152 | x = self.channel_reduction(x, x_m.features.shape[1])
153 | x.features = x_m.features + x.features
154 | x = conv_inv(x)
155 | return x
156 |
157 | @staticmethod
158 | def channel_reduction(x, out_channels):
159 | """
160 | Args:
161 | x: x.features (N, C1)
162 | out_channels: C2
163 |
164 |         Returns:
165 |             x: same sparse tensor with x.features reduced to (N, C2) by summing groups of C1 // C2 channels
166 | """
167 | features = x.features
168 | n, in_channels = features.shape
169 | assert (in_channels % out_channels == 0) and (in_channels >= out_channels)
170 |
171 | x.features = features.view(n, out_channels, -1).sum(dim=2)
172 | return x
173 |
174 | def forward(self, x, out_feat_keys=None):
175 |         """
176 |         Args:
177 |             x: dict with
178 |                 voxels: (num_voxels, max_points_per_voxel, C)
179 |                 voxel_num_points: (num_voxels,)
180 |                 voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx]
181 |             out_feat_keys: names of the features to return, e.g. ["conv4"]
182 |         Returns:
183 |             out_feats: list with one pooled feature tensor of shape
184 |                 (batch_size, feat_dim) per requested key
185 |         """
186 | ### Pre processing
187 | voxel_features, voxel_num_points = x['voxels'], x['voxel_num_points']
188 | points_mean = voxel_features[:, :, :].sum(dim=1, keepdim=False)
189 | normalizer = torch.clamp_min(voxel_num_points.view(-1, 1), min=1.0).type_as(voxel_features)
190 | points_mean = points_mean / normalizer
191 | voxel_features = points_mean.contiguous()
192 |
193 | temp = x['voxel_coords'].detach().cpu().numpy()
194 |
195 | batch_size = len(np.unique(temp[:,0]))
196 | voxel_coords = x['voxel_coords']
197 |
198 | input_sp_tensor = spconv.SparseConvTensor(
199 | features=voxel_features.float(),
200 | indices=voxel_coords.int(),
201 | spatial_shape=self.sparse_shape,
202 | batch_size=batch_size
203 | )
204 | x = self.conv_input(input_sp_tensor)
205 |
206 | x_conv1 = self.conv1(x)
207 | x_conv2 = self.conv2(x_conv1)
208 | x_conv3 = self.conv3(x_conv2)
209 | x_conv4 = self.conv4(x_conv3)
210 |
211 | if self.conv_out is not None:
212 | # for detection head
213 | # [200, 176, 5] -> [200, 176, 2]
214 | out = self.conv_out(x_conv4)
215 | #batch_dict['encoded_spconv_tensor'] = out
216 | #batch_dict['encoded_spconv_tensor_stride'] = 8
217 |
218 | # for segmentation head
219 | # [400, 352, 11] <- [200, 176, 5]
220 | x_up4 = self.UR_block_forward(x_conv4, x_conv4, self.conv_up_t4, self.conv_up_m4, self.inv_conv4)
221 | # [800, 704, 21] <- [400, 352, 11]
222 | x_up3 = self.UR_block_forward(x_conv3, x_up4, self.conv_up_t3, self.conv_up_m3, self.inv_conv3)
223 | # [1600, 1408, 41] <- [800, 704, 21]
224 | x_up2 = self.UR_block_forward(x_conv2, x_up3, self.conv_up_t2, self.conv_up_m2, self.inv_conv2)
225 | # [1600, 1408, 41] <- [1600, 1408, 41]
226 | x_up1 = self.UR_block_forward(x_conv1, x_up2, self.conv_up_t1, self.conv_up_m1, self.conv5)
227 |
228 | end_points = {}
229 |
230 | end_points['conv4_features'] = [x_up4.features, x_up3.features, x_up2.features, x_up1.features]#.view(batch_size, -1, 64).permute(0, 2, 1).contiguous()
231 | end_points['indice'] = [x_up4.indices, x_up3.indices, x_up2.indices, x_up1.indices]
232 |
233 | out_feats = [None] * len(out_feat_keys)
234 |
235 | for key in out_feat_keys:
236 | feat = end_points[key+"_features"]
237 |
238 | featlist = []
239 | for i in range(batch_size):
240 | tempfeat = []
241 | for idx in range(len(end_points['indice'])):
242 | temp_idx = end_points['indice'][idx][:,0] == i
243 | temp_f = end_points['conv4_features'][idx][temp_idx].unsqueeze(0).permute(0, 2, 1).contiguous()
244 | tempfeat.append(F.max_pool1d(temp_f, temp_f.shape[-1]).squeeze(-1))
245 | featlist.append(torch.cat(tempfeat, -1))
246 | feat = torch.cat(featlist, 0)
247 | if self.use_mlp:
248 | feat = self.head(feat)
249 | out_feats[out_feat_keys.index(key)] = feat ### Just use smlp
250 | return out_feats
251 |
--------------------------------------------------------------------------------
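
UNetV2_concat expects a dict with voxels of shape (num_voxels, max_points_per_voxel, 4), voxel_num_points of shape (num_voxels,), and voxel_coords of shape (num_voxels, 4) with the batch index first; the raw points in each voxel are mean-pooled into one feature vector, run through the sparse encoder/decoder, and for out_feat_keys=["conv4"] the four decoder stages are max-pooled per point cloud and concatenated into a single vector per sample. A minimal call sketch under those assumptions (spconv and a CUDA device required; all sizes below are illustrative and must stay inside the hard-coded grid from __init__):

import torch
from models.trunks.spconv_unet import UNetV2_concat

trunk = UNetV2_concat(use_mlp=False).cuda()

num_voxels, max_pts = 2048, 5
batch = {
    'voxels': torch.randn(num_voxels, max_pts, 4).cuda(),
    'voxel_num_points': torch.randint(1, max_pts + 1, (num_voxels,)).float().cuda(),
    'voxel_coords': torch.stack([                    # [batch_idx, z_idx, y_idx, x_idx]
        torch.randint(0, 2, (num_voxels,)),          # a batch of 2 point clouds
        torch.randint(0, 30, (num_voxels,)),
        torch.randint(0, 1500, (num_voxels,)),
        torch.randint(0, 750, (num_voxels,)),
    ], dim=1).cuda(),
}
feats = trunk(batch, out_feat_keys=['conv4'])        # list with one (2, 128) tensor: 64+32+16+16 channels
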
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | absl-py=0.9.0=py36_0
5 | alabaster=0.7.10=py36h306e16b_0
6 | anaconda=custom=py36hbbc8b67_0
7 | anaconda-client=1.6.9=py36_0
8 | anaconda-project=0.8.2=py36h44fb852_0
9 | asn1crypto=0.24.0=py36_0
10 | astroid=1.6.1=py36_0
11 | attrs=17.4.0=py36_0
12 | babel=2.5.3=py36_0
13 | backports=1.0=py36hfa02d7e_1
14 | backports.functools_lru_cache=1.5=py36_0
15 | backports.shutil_get_terminal_size=1.0.0=py36hfea85ff_2
16 | backports.weakref=1.0rc1=py36_1
17 | beautifulsoup4=4.6.0=py36h49b8c8c_1
18 | bitarray=0.8.1=py36h14c3975_1
19 | blas=1.0=mkl
20 | bleach=3.3.0
21 | blinker=1.4=py_0
22 | boto=2.48.0=py36h6e4cd66_1
23 | boto3=1.6.7=py_0
24 | botocore=1.9.7=py_0
25 | bz2file=0.98=py36_0
26 | bzip2=1.0.6=h9a117a8_4
27 | ca-certificates=2017.08.26=h1d4fec5_0
28 | cairo=1.14.12=h77bcde2_0
29 | certifi=2019.11.28=py36_1
30 | cffi=1.11.4=py36h9745a5d_0
31 | chardet=3.0.4=py36h0f667ec_1
32 | click=6.7=py36h5253387_0
33 | cloudpickle=0.5.2=py36_1
34 | clyent=1.2.2=py36h7e57e65_1
35 | cmake=3.9.4=h142f0e9_0
36 | colorama=0.3.9=py36h489cec4_0
37 | contextlib2=0.5.5=py36h6c84a62_0
38 | cryptography=2.1.4=py36hd09be54_0
39 | cuda90=1.0=h6433d27_0
40 | cudatoolkit=10.1.243=h6bb024c_0
41 | curl=7.58.0=h84994c4_0
42 | cycler=0.10.0=py36_0
43 | cymem=1.31.2=py36_0
44 | cython=0.27.3=py36h1860423_0
45 | cytoolz=0.8.2=py36_0
46 | dask-core=0.16.1=py36_0
47 | dbus=1.10.22=h3b5a359_0
48 | decorator=4.2.1=py36_0
49 | dill=0.2.7.1=py36_0
50 | distributed=1.20.2=py36_0
51 | docutils=0.14=py36hb0f60f5_0
52 | easydict=1.9=py_0
53 | entrypoints=0.2.3=py36h1aec115_2
54 | et_xmlfile=1.0.1=py36hd6bccc3_0
55 | expat=2.2.4=hc00ebd1_1
56 | fastcache=1.0.2=py36h14c3975_2
57 | ffmpeg=3.4=h7264315_0
58 | filelock=2.0.13=py36h646ffb5_0
59 | flask
60 | flask-cors=3.0.3=py36h2d857d3_0
61 | fontconfig=2.12.6=h49f89f6_0
62 | freetype=2.8=h52ed37b_0
63 | ftfy=4.4.2=py36_0
64 | future=0.16.0=py36_1
65 | get_terminal_size=1.0.0=haa9412d_0
66 | gevent=1.2.2=py36h2fe25dc_0
67 | gflags=2.2.1=hf484d3e_0
68 | glib=2.53.6=hc861d11_1
69 | glob2=0.6=py36he249c77_0
70 | glog=0.3.5=hf484d3e_1
71 | gmp=6.1.2=h6c8ec71_1
72 | gmpy2=2.0.8=py36hc8893dd_2
73 | graphite2=1.3.10=hf63cedd_1
74 | greenlet=0.4.12=py36h2d503a6_0
75 | gst-plugins-base=1.12.4=h33fb286_0
76 | gstreamer=1.12.4=hb53b477_0
77 | harfbuzz=1.7.4=hc5b324e_0
78 | hdf5=1.10.1=h9caa474_1
79 | heapdict=1.0.0=py36_2
80 | html5lib=0.9999999=py36_0
81 | icc_rt=2018.0.3=intel_0
82 | icu=58.2=h211956c_0
83 | idna=2.6=py36h82fb2a8_1
84 | imageio=2.9.0=py_0
85 | imagesize=0.7.1=py36h52d8127_0
86 | intel-openmp=2018.0.0=hc7b2577_8
87 | intelpython=2018.0.3=0
88 | ipykernel=4.8.0=py36_0
89 | ipython=6.2.1=py36_1
90 | ipython_genutils=0.2.0=py36hb52b0d5_0
91 | ipywidgets=7.1.1=py36_0
92 | isort=4.2.15=py36had401c0_0
93 | itsdangerous=0.24=py36h93cc618_1
94 | jasper=1.900.1=hd497a04_4
95 | jbig=2.1=hdba287a_0
96 | jdcal=1.3=py36h4c697fb_0
97 | jedi=0.11.1=py36_0
98 | jinja2=2.11.3
99 | jmespath=0.9.3=py36_0
100 | joblib=0.11=py36_0
101 | jpeg=9b=habf39ab_1
102 | jsonschema=2.6.0=py36h006f8b5_0
103 | jupyter=1.0.0=py36_0
104 | jupyter_client=5.2.2=py36_0
105 | jupyter_console=5.2.0=py36he59e554_1
106 | jupyter_core=4.4.0=py36h7c827e3_0
107 | jupyterlab=0.31.5=py36_0
108 | jupyterlab_launcher=0.10.2=py36_0
109 | kiwisolver=1.0.1=py36hf484d3e_0
110 | lazy-object-proxy=1.3.1=py36h10fcdad_0
111 | libcurl=7.58.0=h1ad7b7a_0
112 | libedit=3.1.20181209=hc058e9b_0
113 | libffi=3.2.1=h4deb6c0_3
114 | libgcc=7.2.0=h69d50b8_2
115 | libgcc-ng=9.1.0=hdf63c60_0
116 | libgfortran=3.0.0=1
117 | libgfortran-ng=7.2.0=h9f7466a_2
118 | libgpuarray=0.7.5=0
119 | libiconv=1.15=0
120 | libopenblas=0.2.20=h9ac9557_7
121 | libopus=1.2.1=hb9ed12e_0
122 | libpng=1.6.37=hbc83047_0
123 | libprotobuf=3.6.0=hdbcaa40_0
124 | libsodium=1.0.15=hf101ebd_0
125 | libssh2=1.8.0=h9cfc8f7_4
126 | libstdcxx-ng=7.2.0=h7a57d05_2
127 | libtiff=4.0.9=h28f6b97_0
128 | libtool=2.4.6=h544aabb_3
129 | libuv=1.20.3=h14c3975_0
130 | libvpx=1.6.1=h888fd40_0
131 | libxcb=1.13=h1bed415_1
132 | libxml2=2.9.9=hea5a465_1
133 | libxslt=1.1.32=h1312cb7_0
134 | llvmlite=0.21.0=py36ha241eea_0
135 | lmdb=0.9.21=hf484d3e_1
136 | locket=0.2.0=py36h787c0ad_1
137 | lxml=4.6.3
138 | lzo=2.10=h49e0be7_2
139 | magma-cuda90=2.3.0=1
140 | mako=1.0.7=py36_0
141 | markdown=2.6.11=py_0
142 | markupsafe=1.0=py36hd9260cd_1
143 | mccabe=0.6.1=py36h5ad9710_1
144 | mistune=0.8.3=py36_0
145 | mkl=2018.0.3=1
146 | mkl-include=2018.0.3=1
147 | mkl-service=1.1.2=py36h17a0993_4
148 | mkl_fft=1.0.4=py36h4414c95_1
149 | mkl_random=1.0.1=py36h4414c95_1
150 | mkldnn=0.14.0=0
151 | mpc=1.0.3=hec55b23_5
152 | mpfr=3.1.5=h11a74b3_2
153 | mpmath=1.0.0=py36hfeacd6b_2
154 | msgpack-python=0.4.8=py36_0
155 | multipledispatch=0.4.9=py36h41da3fb_0
156 | murmurhash=0.28.0=py36_0
157 | nbconvert=5.3.1=py36hb41ffb7_0
158 | nbformat=4.4.0=py36h31c9010_0
159 | nccl2=1.0=0
160 | ncurses=6.1=hf484d3e_0
161 | networkx=2.1=py36_0
162 | ninja=1.8.2=py36h6bb024c_1
163 | nose=1.3.7=py36hcdf7029_2
164 | notebook=6.1.5=py36_0
165 | numpy=1.14.3=py36hcd700cb_1
166 | numpy-base=1.14.3=py36h9be14a7_1
167 | numpydoc=0.7.0=py36h18f165f_0
168 | oauthlib=2.0.6=py_0
169 | olefile=0.45.1=py36_0
170 | openblas=0.2.19=0
171 | opencv=3.3.1=py36h0a11808_0
172 | openmp=2018.0.3=intel_0
173 | openpyxl=2.4.10=py36_0
174 | openssl=1.0.2u=h7b6447c_0
175 | packaging=16.8=py36ha668100_1
176 | pandas=0.23.4=py36h04863e7_0
177 | pandoc=1.19.2.1=hea2e7c5_1
178 | pandocfilters=1.4.2=py36ha6701b7_1
179 | pango=1.41.0=hd475d92_0
180 | parso=0.1.1=py36h35f843b_0
181 | partd=0.3.8=py36h36fd896_0
182 | patchelf=0.9=hf79760b_2
183 | path.py=10.5=py36h55ceabb_0
184 | pathlib2=2.3.0=py36h49efa8e_0
185 | pcre=8.42=h439df22_0
186 | pep8=1.7.1=py36_0
187 | pexpect=4.3.1=py36_0
188 | pickleshare=0.7.4=py36h63277f8_0
189 | pillow=8.2.0
190 | pip=19.3.1=py36_0
191 | pixman=0.34.0=hceecf20_3
192 | pkginfo=1.4.1=py36h215d178_1
193 | plac=0.9.6=py36_0
194 | pluggy=0.6.0=py36hb689045_0
195 | ply=3.10=py36hed35086_0
196 | plyfile=0.7.2=pyh9f0ad1d_0
197 | preshed=1.0.0=py36_0
198 | prompt_toolkit=1.0.15=py36h17d85b1_0
199 | protobuf=3.6.0=py36hf484d3e_0
200 | psutil
201 | ptyprocess=0.5.2=py36h69acd42_0
202 | py
203 | pycodestyle=2.3.1=py36hf609f19_0
204 | pycosat=0.6.3=py36h0a5515d_0
205 | pycparser=2.18=py36hf9f622e_1
206 | pycrypto
207 | pycurl=7.43.0.1=py36hb7f436b_0
208 | pyflakes=1.6.0=py36h7bd6a15_0
209 | pygments=2.7.4
210 | pyjwt=1.5.3=py_0
211 | pylint=1.8.2=py36_0
212 | pyodbc=4.0.22=py36hf484d3e_0
213 | pyopenssl=17.5.0=py36h20ba746_0
214 | pyparsing=2.4.4=py_0
215 | pysocks=1.6.7=py36hd97a5b1_1
216 | pytest=3.3.2=py36_0
217 | python=3.6.6=hc3d631a_0
218 | python-crfsuite=0.9.2=py36_0
219 | python-dateutil=2.8.1=py_0
220 | python-lmdb=0.92=py36_0
221 | pytorch=1.5.0=py3.6_cuda10.1.243_cudnn7.6.3_0
222 | pytz=2019.3=py_0
223 | pyyaml=5.4
224 | pyzmq=16.0.2=py36h3b0cf96_2
225 | qt=5.6.2=h974d657_12
226 | qtawesome=0.4.4=py36h609ed8c_0
227 | qtconsole=4.3.1=py36h8f73b5b_0
228 | qtpy=1.3.1=py36h3691cc8_0
229 | readline=7.0=h7b6447c_5
230 | regex=2017.11.09=py36_0
231 | requests
232 | requests-oauthlib=0.8.0=py36_1
233 | rhash=1.3.6=hb7f436b_0
234 | rope=0.10.7=py36h147e2ec_0
235 | ruamel_yaml=0.15.35=py36h14c3975_1
236 | s3transfer=0.1.13=py36_0
237 | scikit-learn=0.19.1=py36h7aa7ec6_0
238 | scipy=1.0.0=py36hbf646e7_0
239 | send2trash=1.4.2=py36_0
240 | setuptools=41.6.0=py36_0
241 | simplegeneric=0.8.1=py36_2
242 | simplejson=3.14.0=py36h14c3975_0
243 | singledispatch=3.4.0.3=py36h7a266c3_0
244 | sip=4.19.8=py36hf484d3e_0
245 | six=1.13.0=py36_0
246 | smart_open=1.5.6=py36_1
247 | snowballstemmer=1.2.1=py36h6febd40_0
248 | sortedcollections=0.5.3=py36h3c761f9_0
249 | sortedcontainers=1.5.9=py36_0
250 | sphinx=1.6.6=py36_0
251 | sphinxcontrib=1.0=py36h6d0f590_1
252 | sphinxcontrib-websupport=1.0.1=py36hb5cb234_1
253 | spyder=3.2.6=py36_0
254 | sqlalchemy
255 | sqlite=3.30.1=h7b6447c_0
256 | sympy=1.1.1=py36hc6d1c1c_0
257 | tbb=2018.0.4=h6bb024c_1
258 | tbb4py=2018.0.4=py36h6bb024c_1
259 | tblib=1.3.2=py36h34cf8b6_0
260 | tensorboardx=2.1=py_0
261 | termcolor=1.1.0=py36_1
262 | terminado=0.8.1=py36_1
263 | testpath=0.3.1=py36h8cadb63_0
264 | tk=8.6.8=hbc83047_0
265 | toolz=0.9.0=py36_0
266 | torchvision=0.6.0=py36_cu101
267 | tornado=6.0.3=py36h7b6447c_0
268 | tqdm=4.19.7=py_0
269 | traitlets=4.3.2=py36h674d592_0
270 | twython=3.6.0=py36_0
271 | typing=3.6.2=py36h7da032a_0
272 | ujson=1.35=py36_0
273 | unicodecsv=0.14.1=py36ha668878_0
274 | unixodbc=2.3.4=hc36303a_1
275 | urllib3
276 | wcwidth=0.1.7=py36hdf4376a_0
277 | webencodings=0.5.1=py36h800622e_1
278 | werkzeug
279 | wheel=0.33.6=py36_0
280 | widgetsnbextension=3.1.0=py36_0
281 | wrapt=1.10.11=py36h28b7045_0
282 | x264=20131217=3
283 | xlrd=1.1.0=py36h1db9f0c_1
284 | xlsxwriter=1.0.2=py36h3de1aca_0
285 | xlwt=1.3.0=py36h7b00a1f_0
286 | xz=5.2.4=h14c3975_4
287 | yaml=0.1.7=had09818_2
288 | zeromq=4.2.2=hbedb6e5_2
289 | zict=0.1.3=py36h3a3bf81_0
290 | zlib=1.2.11=h7b6447c_3
291 |
--------------------------------------------------------------------------------
/scripts/multinode-wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import sys, os
8 |
9 | num_nodes = int(os.environ['SLURM_NNODES'])
10 | node_id = int(os.environ['SLURM_NODEID'])
11 | node0 = 'learnfair' + os.environ['SLURM_NODELIST'][10:14]  # first node id from a nodelist like "learnfair[1234,...]"
12 | cmd = 'python {script} {cfg} --dist-url tcp://{node0}:1234 --dist-backend nccl --multiprocessing-distributed --world-size {ws} --rank {rank}'.format(
13 | script=sys.argv[1], cfg=sys.argv[2], node0=node0, ws=num_nodes, rank=node_id)
14 | os.system(cmd)
15 |
--------------------------------------------------------------------------------
/scripts/pretrain_node1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #SBATCH --job-name=DepthContrast
8 | #SBATCH --nodes=1
9 | #SBATCH --gres=gpu:8
10 | #SBATCH --cpus-per-task=80
11 | #SBATCH --mem=400G
12 | #SBATCH --time=72:00:00
13 | #SBATCH --partition=dev
14 | #SBATCH --comment="test"
15 | #SBATCH --constraint=volta32gb
16 |
17 | #SBATCH --signal=B:USR1@60
18 | #SBATCH --open-mode=append
19 |
20 | EXPERIMENT_PATH="./checkpoints/testlog"
21 | mkdir -p $EXPERIMENT_PATH
22 |
23 | export PYTHONPATH=$PWD:$PYTHONPATH
24 |
25 | srun --output=${EXPERIMENT_PATH}/%j.out --error=${EXPERIMENT_PATH}/%j.err --label python scripts/singlenode-wrapper.py main.py $1
26 |
--------------------------------------------------------------------------------
/scripts/pretrain_node4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | #SBATCH --job-name=DepthContrast
9 | #SBATCH --nodes=4
10 | #SBATCH --gres=gpu:8
11 | #SBATCH --cpus-per-task=80
12 | #SBATCH --mem=400G
13 | #SBATCH --time=72:00:00
14 | #SBATCH --partition=dev
15 | #SBATCH --comment="depthcontrast"
16 | #SBATCH --constraint=volta32gb
17 |
18 | #SBATCH --signal=B:USR1@60
19 | #SBATCH --open-mode=append
20 |
21 | LOG_PATH=$2
22 | mkdir -p $LOG_PATH
23 |
24 | export PYTHONPATH=$PWD:$PYTHONPATH
25 |
26 | srun --output=${LOG_PATH}/%j.out --error=${LOG_PATH}/%j.err --label python scripts/multinode-wrapper.py main.py $1
27 |
--------------------------------------------------------------------------------
/scripts/singlenode-wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import sys, os
8 |
9 | num_nodes = int(os.environ['SLURM_NNODES'])
10 | node_id = int(os.environ['SLURM_NODEID'])
11 | node0 = 'learnfair' + os.environ['SLURM_NODELIST'][9:13]  # node id from a single-node list like "learnfair1234"
12 | cmd = 'python {script} {cfg} --dist-url tcp://{node0}:1234 --dist-backend nccl --multiprocessing-distributed --world-size {ws} --rank {rank}'.format(
13 | script=sys.argv[1], cfg=sys.argv[2], node0=node0, ws=num_nodes, rank=node_id)
14 | os.system(cmd)
15 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/include/ball_query.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include <torch/extension.h>
8 |
9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
10 | const int nsample);
11 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/include/cuda_utils.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #ifndef _CUDA_UTILS_H
7 | #define _CUDA_UTILS_H
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <cmath>
12 | 
13 | #include <cuda.h>
14 | #include <cuda_runtime.h>
15 | 
16 | #include <vector>
17 |
18 | #define TOTAL_THREADS 512
19 |
20 | inline int opt_n_threads(int work_size) {
21 |   const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
22 |
23 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
24 | }
25 |
26 | inline dim3 opt_block_config(int x, int y) {
27 | const int x_threads = opt_n_threads(x);
28 | const int y_threads =
29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
30 | dim3 block_config(x_threads, y_threads, 1);
31 |
32 | return block_config;
33 | }
34 |
35 | #define CUDA_CHECK_ERRORS() \
36 | do { \
37 | cudaError_t err = cudaGetLastError(); \
38 | if (cudaSuccess != err) { \
39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
41 | __FILE__); \
42 | exit(-1); \
43 | } \
44 | } while (0)
45 |
46 | #endif
47 |
--------------------------------------------------------------------------------
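
opt_n_threads returns the largest power of two that does not exceed work_size, clamped to [1, TOTAL_THREADS]; opt_block_config applies the same rule to the x axis and gives the y axis whatever budget remains under the cap. The same computation restated in Python (an illustrative re-statement, not code from this repository):

import math

TOTAL_THREADS = 512

def opt_n_threads(work_size: int) -> int:
    # Largest power of two <= work_size, clamped to [1, TOTAL_THREADS].
    pow_2 = int(math.log(work_size) / math.log(2.0))
    return max(min(1 << pow_2, TOTAL_THREADS), 1)

assert opt_n_threads(300) == 256
assert opt_n_threads(3) == 2
assert opt_n_threads(5000) == TOTAL_THREADS
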
/third_party/pointnet2/_ext_src/include/group_points.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include <torch/extension.h>
8 |
9 | at::Tensor group_points(at::Tensor points, at::Tensor idx);
10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
11 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/include/interpolate.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 |
8 | #include <torch/extension.h>
9 | #include <vector>
10 | 
11 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
13 | at::Tensor weight);
14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
15 | at::Tensor weight, const int m);
16 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/include/sampling.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include <torch/extension.h>
8 |
9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx);
10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
12 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/include/utils.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include <ATen/cuda/CUDAContext.h>
8 | #include <torch/extension.h>
9 |
10 | #define CHECK_CUDA(x) \
11 | do { \
12 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \
13 | } while (0)
14 |
15 | #define CHECK_CONTIGUOUS(x) \
16 | do { \
17 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \
18 | } while (0)
19 |
20 | #define CHECK_IS_INT(x) \
21 | do { \
22 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \
23 | #x " must be an int tensor"); \
24 | } while (0)
25 |
26 | #define CHECK_IS_FLOAT(x) \
27 | do { \
28 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \
29 | #x " must be a float tensor"); \
30 | } while (0)
31 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/src/ball_query.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "ball_query.h"
7 | #include "utils.h"
8 |
9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
10 | int nsample, const float *new_xyz,
11 | const float *xyz, int *idx);
12 |
13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
14 | const int nsample) {
15 | CHECK_CONTIGUOUS(new_xyz);
16 | CHECK_CONTIGUOUS(xyz);
17 | CHECK_IS_FLOAT(new_xyz);
18 | CHECK_IS_FLOAT(xyz);
19 |
20 | if (new_xyz.is_cuda()) {
21 | CHECK_CUDA(xyz);
22 | }
23 |
24 | at::Tensor idx =
25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int));
27 |
28 | if (new_xyz.is_cuda()) {
29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
30 |                                     radius, nsample, new_xyz.data<float>(),
31 |                                     xyz.data<float>(), idx.data<int>());
32 | } else {
33 | AT_ASSERT(false, "CPU not supported");
34 | }
35 |
36 | return idx;
37 | }
38 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/src/ball_query_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include <math.h>
7 | #include <stdio.h>
8 | #include <stdlib.h>
9 |
10 | #include "cuda_utils.h"
11 |
12 | // input: new_xyz(b, m, 3) xyz(b, n, 3)
13 | // output: idx(b, m, nsample)
14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
15 | int nsample,
16 | const float *__restrict__ new_xyz,
17 | const float *__restrict__ xyz,
18 | int *__restrict__ idx) {
19 | int batch_index = blockIdx.x;
20 | xyz += batch_index * n * 3;
21 | new_xyz += batch_index * m * 3;
22 | idx += m * nsample * batch_index;
23 |
24 | int index = threadIdx.x;
25 | int stride = blockDim.x;
26 |
27 | float radius2 = radius * radius;
28 | for (int j = index; j < m; j += stride) {
29 | float new_x = new_xyz[j * 3 + 0];
30 | float new_y = new_xyz[j * 3 + 1];
31 | float new_z = new_xyz[j * 3 + 2];
32 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
33 | float x = xyz[k * 3 + 0];
34 | float y = xyz[k * 3 + 1];
35 | float z = xyz[k * 3 + 2];
36 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
37 | (new_z - z) * (new_z - z);
38 | if (d2 < radius2) {
39 | if (cnt == 0) {
40 | for (int l = 0; l < nsample; ++l) {
41 | idx[j * nsample + l] = k;
42 | }
43 | }
44 | idx[j * nsample + cnt] = k;
45 | ++cnt;
46 | }
47 | }
48 | }
49 | }
50 |
51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
52 | int nsample, const float *new_xyz,
53 | const float *xyz, int *idx) {
54 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
55 |   query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
56 | b, n, m, radius, nsample, new_xyz, xyz, idx);
57 |
58 | CUDA_CHECK_ERRORS();
59 | }
60 |
--------------------------------------------------------------------------------
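
For every query point in new_xyz, the kernel above collects up to nsample indices of points in xyz that lie within radius; when the first neighbour is found it is written into all nsample slots, so queries with fewer than nsample neighbours come back padded with that first index rather than with zeros. The same per-batch logic in NumPy (an illustrative restatement, not part of the extension):

import numpy as np

def ball_query_single(new_xyz, xyz, radius, nsample):
    """new_xyz: (m, 3) queries, xyz: (n, 3) points -> (m, nsample) int indices."""
    idx = np.zeros((new_xyz.shape[0], nsample), dtype=np.int64)
    for j, q in enumerate(new_xyz):
        cnt = 0
        for k, p in enumerate(xyz):
            if cnt >= nsample:
                break
            if np.sum((q - p) ** 2) < radius ** 2:
                if cnt == 0:
                    idx[j, :] = k   # pad every slot with the first hit
                idx[j, cnt] = k
                cnt += 1
    return idx
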
/third_party/pointnet2/_ext_src/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "ball_query.h"
7 | #include "group_points.h"
8 | #include "interpolate.h"
9 | #include "sampling.h"
10 |
11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
12 | m.def("gather_points", &gather_points);
13 | m.def("gather_points_grad", &gather_points_grad);
14 | m.def("furthest_point_sampling", &furthest_point_sampling);
15 |
16 | m.def("three_nn", &three_nn);
17 | m.def("three_interpolate", &three_interpolate);
18 | m.def("three_interpolate_grad", &three_interpolate_grad);
19 |
20 | m.def("ball_query", &ball_query);
21 |
22 | m.def("group_points", &group_points);
23 | m.def("group_points_grad", &group_points_grad);
24 | }
25 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/src/group_points.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "group_points.h"
7 | #include "utils.h"
8 |
9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
10 | const float *points, const int *idx,
11 | float *out);
12 |
13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
14 | int nsample, const float *grad_out,
15 | const int *idx, float *grad_points);
16 |
17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) {
18 | CHECK_CONTIGUOUS(points);
19 | CHECK_CONTIGUOUS(idx);
20 | CHECK_IS_FLOAT(points);
21 | CHECK_IS_INT(idx);
22 |
23 | if (points.is_cuda()) {
24 | CHECK_CUDA(idx);
25 | }
26 |
27 | at::Tensor output =
28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
29 | at::device(points.device()).dtype(at::ScalarType::Float));
30 |
31 | if (points.is_cuda()) {
32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
33 |                                 idx.size(1), idx.size(2), points.data<float>(),
34 |                                 idx.data<int>(), output.data<float>());
35 | } else {
36 | AT_ASSERT(false, "CPU not supported");
37 | }
38 |
39 | return output;
40 | }
41 |
42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
43 | CHECK_CONTIGUOUS(grad_out);
44 | CHECK_CONTIGUOUS(idx);
45 | CHECK_IS_FLOAT(grad_out);
46 | CHECK_IS_INT(idx);
47 |
48 | if (grad_out.is_cuda()) {
49 | CHECK_CUDA(idx);
50 | }
51 |
52 | at::Tensor output =
53 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
54 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
55 |
56 | if (grad_out.is_cuda()) {
57 | group_points_grad_kernel_wrapper(
58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
59 |         grad_out.data<float>(), idx.data<int>(), output.data<float>());
60 | } else {
61 | AT_ASSERT(false, "CPU not supported");
62 | }
63 |
64 | return output;
65 | }
66 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include <stdio.h>
7 | #include <stdlib.h>
8 |
9 | #include "cuda_utils.h"
10 |
11 | // input: points(b, c, n) idx(b, npoints, nsample)
12 | // output: out(b, c, npoints, nsample)
13 | __global__ void group_points_kernel(int b, int c, int n, int npoints,
14 | int nsample,
15 | const float *__restrict__ points,
16 | const int *__restrict__ idx,
17 | float *__restrict__ out) {
18 | int batch_index = blockIdx.x;
19 | points += batch_index * n * c;
20 | idx += batch_index * npoints * nsample;
21 | out += batch_index * npoints * nsample * c;
22 |
23 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
24 | const int stride = blockDim.y * blockDim.x;
25 | for (int i = index; i < c * npoints; i += stride) {
26 | const int l = i / npoints;
27 | const int j = i % npoints;
28 | for (int k = 0; k < nsample; ++k) {
29 | int ii = idx[j * nsample + k];
30 | out[(l * npoints + j) * nsample + k] = points[l * n + ii];
31 | }
32 | }
33 | }
34 |
35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
36 | const float *points, const int *idx,
37 | float *out) {
38 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
39 |
40 |   group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
41 | b, c, n, npoints, nsample, points, idx, out);
42 |
43 | CUDA_CHECK_ERRORS();
44 | }
45 |
46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
47 | // output: grad_points(b, c, n)
48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
49 | int nsample,
50 | const float *__restrict__ grad_out,
51 | const int *__restrict__ idx,
52 | float *__restrict__ grad_points) {
53 | int batch_index = blockIdx.x;
54 | grad_out += batch_index * npoints * nsample * c;
55 | idx += batch_index * npoints * nsample;
56 | grad_points += batch_index * n * c;
57 |
58 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
59 | const int stride = blockDim.y * blockDim.x;
60 | for (int i = index; i < c * npoints; i += stride) {
61 | const int l = i / npoints;
62 | const int j = i % npoints;
63 | for (int k = 0; k < nsample; ++k) {
64 | int ii = idx[j * nsample + k];
65 | atomicAdd(grad_points + l * n + ii,
66 | grad_out[(l * npoints + j) * nsample + k]);
67 | }
68 | }
69 | }
70 |
71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
72 | int nsample, const float *grad_out,
73 | const int *idx, float *grad_points) {
74 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
75 |
76 |   group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
77 | b, c, n, npoints, nsample, grad_out, idx, grad_points);
78 |
79 | CUDA_CHECK_ERRORS();
80 | }
81 |
--------------------------------------------------------------------------------
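As a hedged illustration (not part of the source tree), the gather performed by `group_points` above can be restated in a few lines of plain PyTorch. This reference version mirrors the `(b, c, n) × (b, npoints, nsample) → (b, c, npoints, nsample)` contract documented in the kernel comments and can be used to sanity-check the CUDA path.

```python
import torch

def group_points_reference(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """points: (B, C, N) float, idx: (B, npoints, nsample) int.
    Returns out with out[b, c, j, k] = points[b, c, idx[b, j, k]]."""
    B, C, N = points.shape
    _, npoints, nsample = idx.shape
    flat = idx.reshape(B, 1, npoints * nsample).expand(B, C, -1).long()
    return torch.gather(points, 2, flat).reshape(B, C, npoints, nsample)

# small CPU example; indices must be < N
pts = torch.randn(2, 4, 16)
idx = torch.randint(0, 16, (2, 8, 3), dtype=torch.int32)
out = group_points_reference(pts, idx)   # (2, 4, 8, 3)
```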
/third_party/pointnet2/_ext_src/src/interpolate.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "interpolate.h"
7 | #include "utils.h"
8 |
9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
10 | const float *known, float *dist2, int *idx);
11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
12 | const float *points, const int *idx,
13 | const float *weight, float *out);
14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
15 | const float *grad_out,
16 | const int *idx, const float *weight,
17 | float *grad_points);
18 |
19 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
20 | CHECK_CONTIGUOUS(unknowns);
21 | CHECK_CONTIGUOUS(knows);
22 | CHECK_IS_FLOAT(unknowns);
23 | CHECK_IS_FLOAT(knows);
24 |
25 | if (unknowns.is_cuda()) {
26 | CHECK_CUDA(knows);
27 | }
28 |
29 | at::Tensor idx =
30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
31 | at::device(unknowns.device()).dtype(at::ScalarType::Int));
32 | at::Tensor dist2 =
33 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
34 | at::device(unknowns.device()).dtype(at::ScalarType::Float));
35 |
36 | if (unknowns.is_cuda()) {
37 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
38 |                             unknowns.data<float>(), knows.data<float>(),
39 |                             dist2.data<float>(), idx.data<int>());
40 | } else {
41 | AT_ASSERT(false, "CPU not supported");
42 | }
43 |
44 | return {dist2, idx};
45 | }
46 |
47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
48 | at::Tensor weight) {
49 | CHECK_CONTIGUOUS(points);
50 | CHECK_CONTIGUOUS(idx);
51 | CHECK_CONTIGUOUS(weight);
52 | CHECK_IS_FLOAT(points);
53 | CHECK_IS_INT(idx);
54 | CHECK_IS_FLOAT(weight);
55 |
56 | if (points.is_cuda()) {
57 | CHECK_CUDA(idx);
58 | CHECK_CUDA(weight);
59 | }
60 |
61 | at::Tensor output =
62 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
63 | at::device(points.device()).dtype(at::ScalarType::Float));
64 |
65 | if (points.is_cuda()) {
66 | three_interpolate_kernel_wrapper(
67 | points.size(0), points.size(1), points.size(2), idx.size(1),
68 |         points.data<float>(), idx.data<int>(), weight.data<float>(),
69 |         output.data<float>());
70 | } else {
71 | AT_ASSERT(false, "CPU not supported");
72 | }
73 |
74 | return output;
75 | }
76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
77 | at::Tensor weight, const int m) {
78 | CHECK_CONTIGUOUS(grad_out);
79 | CHECK_CONTIGUOUS(idx);
80 | CHECK_CONTIGUOUS(weight);
81 | CHECK_IS_FLOAT(grad_out);
82 | CHECK_IS_INT(idx);
83 | CHECK_IS_FLOAT(weight);
84 |
85 | if (grad_out.is_cuda()) {
86 | CHECK_CUDA(idx);
87 | CHECK_CUDA(weight);
88 | }
89 |
90 | at::Tensor output =
91 | torch::zeros({grad_out.size(0), grad_out.size(1), m},
92 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
93 |
94 | if (grad_out.is_cuda()) {
95 | three_interpolate_grad_kernel_wrapper(
96 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
97 |         grad_out.data<float>(), idx.data<int>(), weight.data<float>(),
98 |         output.data<float>());
99 | } else {
100 | AT_ASSERT(false, "CPU not supported");
101 | }
102 |
103 | return output;
104 | }
105 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include <math.h>
7 | #include <stdio.h>
8 | #include <stdlib.h>
9 |
10 | #include "cuda_utils.h"
11 |
12 | // input: unknown(b, n, 3) known(b, m, 3)
13 | // output: dist2(b, n, 3), idx(b, n, 3)
14 | __global__ void three_nn_kernel(int b, int n, int m,
15 | const float *__restrict__ unknown,
16 | const float *__restrict__ known,
17 | float *__restrict__ dist2,
18 | int *__restrict__ idx) {
19 | int batch_index = blockIdx.x;
20 | unknown += batch_index * n * 3;
21 | known += batch_index * m * 3;
22 | dist2 += batch_index * n * 3;
23 | idx += batch_index * n * 3;
24 |
25 | int index = threadIdx.x;
26 | int stride = blockDim.x;
27 | for (int j = index; j < n; j += stride) {
28 | float ux = unknown[j * 3 + 0];
29 | float uy = unknown[j * 3 + 1];
30 | float uz = unknown[j * 3 + 2];
31 |
32 | double best1 = 1e40, best2 = 1e40, best3 = 1e40;
33 | int besti1 = 0, besti2 = 0, besti3 = 0;
34 | for (int k = 0; k < m; ++k) {
35 | float x = known[k * 3 + 0];
36 | float y = known[k * 3 + 1];
37 | float z = known[k * 3 + 2];
38 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
39 | if (d < best1) {
40 | best3 = best2;
41 | besti3 = besti2;
42 | best2 = best1;
43 | besti2 = besti1;
44 | best1 = d;
45 | besti1 = k;
46 | } else if (d < best2) {
47 | best3 = best2;
48 | besti3 = besti2;
49 | best2 = d;
50 | besti2 = k;
51 | } else if (d < best3) {
52 | best3 = d;
53 | besti3 = k;
54 | }
55 | }
56 | dist2[j * 3 + 0] = best1;
57 | dist2[j * 3 + 1] = best2;
58 | dist2[j * 3 + 2] = best3;
59 |
60 | idx[j * 3 + 0] = besti1;
61 | idx[j * 3 + 1] = besti2;
62 | idx[j * 3 + 2] = besti3;
63 | }
64 | }
65 |
66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
67 | const float *known, float *dist2, int *idx) {
68 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
69 |   three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
70 | dist2, idx);
71 |
72 | CUDA_CHECK_ERRORS();
73 | }
74 |
75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
76 | // output: out(b, c, n)
77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
78 | const float *__restrict__ points,
79 | const int *__restrict__ idx,
80 | const float *__restrict__ weight,
81 | float *__restrict__ out) {
82 | int batch_index = blockIdx.x;
83 | points += batch_index * m * c;
84 |
85 | idx += batch_index * n * 3;
86 | weight += batch_index * n * 3;
87 |
88 | out += batch_index * n * c;
89 |
90 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
91 | const int stride = blockDim.y * blockDim.x;
92 | for (int i = index; i < c * n; i += stride) {
93 | const int l = i / n;
94 | const int j = i % n;
95 | float w1 = weight[j * 3 + 0];
96 | float w2 = weight[j * 3 + 1];
97 | float w3 = weight[j * 3 + 2];
98 |
99 | int i1 = idx[j * 3 + 0];
100 | int i2 = idx[j * 3 + 1];
101 | int i3 = idx[j * 3 + 2];
102 |
103 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
104 | points[l * m + i3] * w3;
105 | }
106 | }
107 |
108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
109 | const float *points, const int *idx,
110 | const float *weight, float *out) {
111 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
112 |   three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
113 | b, c, m, n, points, idx, weight, out);
114 |
115 | CUDA_CHECK_ERRORS();
116 | }
117 |
118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
119 | // output: grad_points(b, c, m)
120 |
121 | __global__ void three_interpolate_grad_kernel(
122 | int b, int c, int n, int m, const float *__restrict__ grad_out,
123 | const int *__restrict__ idx, const float *__restrict__ weight,
124 | float *__restrict__ grad_points) {
125 | int batch_index = blockIdx.x;
126 | grad_out += batch_index * n * c;
127 | idx += batch_index * n * 3;
128 | weight += batch_index * n * 3;
129 | grad_points += batch_index * m * c;
130 |
131 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
132 | const int stride = blockDim.y * blockDim.x;
133 | for (int i = index; i < c * n; i += stride) {
134 | const int l = i / n;
135 | const int j = i % n;
136 | float w1 = weight[j * 3 + 0];
137 | float w2 = weight[j * 3 + 1];
138 | float w3 = weight[j * 3 + 2];
139 |
140 | int i1 = idx[j * 3 + 0];
141 | int i2 = idx[j * 3 + 1];
142 | int i3 = idx[j * 3 + 2];
143 |
144 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
145 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
146 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
147 | }
148 | }
149 |
150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
151 | const float *grad_out,
152 | const int *idx, const float *weight,
153 | float *grad_points) {
154 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
155 |   three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
156 | b, c, n, m, grad_out, idx, weight, grad_points);
157 |
158 | CUDA_CHECK_ERRORS();
159 | }
160 |
--------------------------------------------------------------------------------
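A hedged usage sketch, not in the repository: in PointNet++-style feature propagation, `three_nn` supplies the three nearest known points per query and the interpolation weights are typically inverse-distance weights normalised to sum to one. The exact weighting used elsewhere in this codebase may differ, so treat this purely as an illustration of how the two ops compose; `pointnet2._ext` is the extension module name from `setup.py`, and all tensors must be contiguous CUDA tensors.

```python
import torch
import pointnet2._ext as _ext  # assumes the extension has been built

def propagate_features(unknown_xyz, known_xyz, known_feats, eps=1e-8):
    """unknown_xyz: (B, n, 3), known_xyz: (B, m, 3), known_feats: (B, C, m) -> (B, C, n)."""
    dist2, idx = _ext.three_nn(unknown_xyz, known_xyz)     # both (B, n, 3)
    weight = 1.0 / (dist2 + eps)                           # inverse (squared) distance weights
    weight = weight / weight.sum(dim=2, keepdim=True)      # normalise over the 3 neighbours
    return _ext.three_interpolate(known_feats, idx, weight)
```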
/third_party/pointnet2/_ext_src/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "sampling.h"
7 | #include "utils.h"
8 |
9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
10 | const float *points, const int *idx,
11 | float *out);
12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
13 | const float *grad_out, const int *idx,
14 | float *grad_points);
15 |
16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
17 | const float *dataset, float *temp,
18 | int *idxs);
19 |
20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
21 | CHECK_CONTIGUOUS(points);
22 | CHECK_CONTIGUOUS(idx);
23 | CHECK_IS_FLOAT(points);
24 | CHECK_IS_INT(idx);
25 |
26 | if (points.is_cuda()) {
27 | CHECK_CUDA(idx);
28 | }
29 |
30 | at::Tensor output =
31 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
32 | at::device(points.device()).dtype(at::ScalarType::Float));
33 |
34 | if (points.is_cuda()) {
35 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
36 |                                  idx.size(1), points.data<float>(),
37 |                                  idx.data<int>(), output.data<float>());
38 | } else {
39 | AT_ASSERT(false, "CPU not supported");
40 | }
41 |
42 | return output;
43 | }
44 |
45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
46 | const int n) {
47 | CHECK_CONTIGUOUS(grad_out);
48 | CHECK_CONTIGUOUS(idx);
49 | CHECK_IS_FLOAT(grad_out);
50 | CHECK_IS_INT(idx);
51 |
52 | if (grad_out.is_cuda()) {
53 | CHECK_CUDA(idx);
54 | }
55 |
56 | at::Tensor output =
57 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
58 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
59 |
60 | if (grad_out.is_cuda()) {
61 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
62 |                                       idx.size(1), grad_out.data<float>(),
63 |                                       idx.data<int>(), output.data<float>());
64 | } else {
65 | AT_ASSERT(false, "CPU not supported");
66 | }
67 |
68 | return output;
69 | }
70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
71 | CHECK_CONTIGUOUS(points);
72 | CHECK_IS_FLOAT(points);
73 |
74 | at::Tensor output =
75 | torch::zeros({points.size(0), nsamples},
76 | at::device(points.device()).dtype(at::ScalarType::Int));
77 |
78 | at::Tensor tmp =
79 | torch::full({points.size(0), points.size(1)}, 1e10,
80 | at::device(points.device()).dtype(at::ScalarType::Float));
81 |
82 | if (points.is_cuda()) {
83 | furthest_point_sampling_kernel_wrapper(
84 |         points.size(0), points.size(1), nsamples, points.data<float>(),
85 |         tmp.data<float>(), output.data<int>());
86 | } else {
87 | AT_ASSERT(false, "CPU not supported");
88 | }
89 |
90 | return output;
91 | }
92 |
--------------------------------------------------------------------------------
/third_party/pointnet2/_ext_src/src/sampling_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include <stdio.h>
7 | #include <stdlib.h>
8 |
9 | #include "cuda_utils.h"
10 |
11 | // input: points(b, c, n) idx(b, m)
12 | // output: out(b, c, m)
13 | __global__ void gather_points_kernel(int b, int c, int n, int m,
14 | const float *__restrict__ points,
15 | const int *__restrict__ idx,
16 | float *__restrict__ out) {
17 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
18 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
19 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
20 | int a = idx[i * m + j];
21 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
22 | }
23 | }
24 | }
25 | }
26 |
27 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
28 | const float *points, const int *idx,
29 | float *out) {
30 |   gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
31 |                          at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
32 |                                                              points, idx, out);
33 |
34 | CUDA_CHECK_ERRORS();
35 | }
36 |
37 | // input: grad_out(b, c, m) idx(b, m)
38 | // output: grad_points(b, c, n)
39 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
40 | const float *__restrict__ grad_out,
41 | const int *__restrict__ idx,
42 | float *__restrict__ grad_points) {
43 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
44 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
45 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
46 | int a = idx[i * m + j];
47 | atomicAdd(grad_points + (i * c + l) * n + a,
48 | grad_out[(i * c + l) * m + j]);
49 | }
50 | }
51 | }
52 | }
53 |
54 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
55 | const float *grad_out, const int *idx,
56 | float *grad_points) {
57 |   gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
58 |                               at::cuda::getCurrentCUDAStream()>>>(
59 |       b, c, n, npoints, grad_out, idx, grad_points);
60 |
61 | CUDA_CHECK_ERRORS();
62 | }
63 |
64 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
65 | int idx1, int idx2) {
66 | const float v1 = dists[idx1], v2 = dists[idx2];
67 | const int i1 = dists_i[idx1], i2 = dists_i[idx2];
68 | dists[idx1] = max(v1, v2);
69 | dists_i[idx1] = v2 > v1 ? i2 : i1;
70 | }
71 |
72 | // Input dataset: (b, n, 3), tmp: (b, n)
73 | // Output idxs (b, m)
74 | template <unsigned int block_size>
75 | __global__ void furthest_point_sampling_kernel(
76 | int b, int n, int m, const float *__restrict__ dataset,
77 | float *__restrict__ temp, int *__restrict__ idxs) {
78 | if (m <= 0) return;
79 | __shared__ float dists[block_size];
80 | __shared__ int dists_i[block_size];
81 |
82 | int batch_index = blockIdx.x;
83 | dataset += batch_index * n * 3;
84 | temp += batch_index * n;
85 | idxs += batch_index * m;
86 |
87 | int tid = threadIdx.x;
88 | const int stride = block_size;
89 |
90 | int old = 0;
91 | if (threadIdx.x == 0) idxs[0] = old;
92 |
93 | __syncthreads();
94 | for (int j = 1; j < m; j++) {
95 | int besti = 0;
96 | float best = -1;
97 | float x1 = dataset[old * 3 + 0];
98 | float y1 = dataset[old * 3 + 1];
99 | float z1 = dataset[old * 3 + 2];
100 | for (int k = tid; k < n; k += stride) {
101 | float x2, y2, z2;
102 | x2 = dataset[k * 3 + 0];
103 | y2 = dataset[k * 3 + 1];
104 | z2 = dataset[k * 3 + 2];
105 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
106 | if (mag <= 1e-3) continue;
107 |
108 | float d =
109 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
110 |
111 | float d2 = min(d, temp[k]);
112 | temp[k] = d2;
113 | besti = d2 > best ? k : besti;
114 | best = d2 > best ? d2 : best;
115 | }
116 | dists[tid] = best;
117 | dists_i[tid] = besti;
118 | __syncthreads();
119 |
120 | if (block_size >= 512) {
121 | if (tid < 256) {
122 | __update(dists, dists_i, tid, tid + 256);
123 | }
124 | __syncthreads();
125 | }
126 | if (block_size >= 256) {
127 | if (tid < 128) {
128 | __update(dists, dists_i, tid, tid + 128);
129 | }
130 | __syncthreads();
131 | }
132 | if (block_size >= 128) {
133 | if (tid < 64) {
134 | __update(dists, dists_i, tid, tid + 64);
135 | }
136 | __syncthreads();
137 | }
138 | if (block_size >= 64) {
139 | if (tid < 32) {
140 | __update(dists, dists_i, tid, tid + 32);
141 | }
142 | __syncthreads();
143 | }
144 | if (block_size >= 32) {
145 | if (tid < 16) {
146 | __update(dists, dists_i, tid, tid + 16);
147 | }
148 | __syncthreads();
149 | }
150 | if (block_size >= 16) {
151 | if (tid < 8) {
152 | __update(dists, dists_i, tid, tid + 8);
153 | }
154 | __syncthreads();
155 | }
156 | if (block_size >= 8) {
157 | if (tid < 4) {
158 | __update(dists, dists_i, tid, tid + 4);
159 | }
160 | __syncthreads();
161 | }
162 | if (block_size >= 4) {
163 | if (tid < 2) {
164 | __update(dists, dists_i, tid, tid + 2);
165 | }
166 | __syncthreads();
167 | }
168 | if (block_size >= 2) {
169 | if (tid < 1) {
170 | __update(dists, dists_i, tid, tid + 1);
171 | }
172 | __syncthreads();
173 | }
174 |
175 | old = dists_i[0];
176 | if (tid == 0) idxs[j] = old;
177 | }
178 | }
179 |
180 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
181 | const float *dataset, float *temp,
182 | int *idxs) {
183 | unsigned int n_threads = opt_n_threads(n);
184 |
185 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
186 |
187 | switch (n_threads) {
188 | case 512:
189 | furthest_point_sampling_kernel<512>
190 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
191 | break;
192 | case 256:
193 | furthest_point_sampling_kernel<256>
194 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
195 | break;
196 | case 128:
197 | furthest_point_sampling_kernel<128>
198 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
199 | break;
200 | case 64:
201 | furthest_point_sampling_kernel<64>
202 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
203 | break;
204 | case 32:
205 | furthest_point_sampling_kernel<32>
206 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
207 | break;
208 | case 16:
209 | furthest_point_sampling_kernel<16>
210 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
211 | break;
212 | case 8:
213 | furthest_point_sampling_kernel<8>
214 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
215 | break;
216 | case 4:
217 | furthest_point_sampling_kernel<4>
218 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
219 | break;
220 | case 2:
221 | furthest_point_sampling_kernel<2>
222 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
223 | break;
224 | case 1:
225 | furthest_point_sampling_kernel<1>
226 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
227 | break;
228 | default:
229 | furthest_point_sampling_kernel<512>
230 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
231 | }
232 |
233 | CUDA_CHECK_ERRORS();
234 | }
235 |
--------------------------------------------------------------------------------
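For readability, here is a hedged, unoptimised PyTorch restatement (not part of the repository) of the farthest-point-sampling loop that the templated kernel above parallelises. It omits the kernel's small-magnitude skip (`mag <= 1e-3`) and is only meant to make the selection logic easy to follow.

```python
import torch

def furthest_point_sampling_reference(xyz: torch.Tensor, m: int) -> torch.Tensor:
    """xyz: (B, N, 3); returns (B, m) indices of an approximately farthest-point subset."""
    B, N, _ = xyz.shape
    idx = torch.zeros(B, m, dtype=torch.long, device=xyz.device)
    nearest = torch.full((B, N), 1e10, device=xyz.device)            # distance to the chosen set so far
    farthest = torch.zeros(B, dtype=torch.long, device=xyz.device)   # start from point 0, like the kernel
    for j in range(m):
        idx[:, j] = farthest
        centre = xyz[torch.arange(B), farthest].unsqueeze(1)          # (B, 1, 3) newest pick
        d = ((xyz - centre) ** 2).sum(-1)                             # squared distance to newest pick
        nearest = torch.minimum(nearest, d)                           # distance to the whole chosen set
        farthest = nearest.argmax(dim=1)                              # next pick: farthest from the set
    return idx
```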
/third_party/pointnet2/pointnet2_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | ''' Testing customized ops. '''
7 |
8 | import torch
9 | from torch.autograd import gradcheck
10 | import numpy as np
11 |
12 | import os
13 | import sys
14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
15 | sys.path.append(BASE_DIR)
16 | import pointnet2_utils
17 |
18 | def test_interpolation_grad():
19 | batch_size = 1
20 | feat_dim = 2
21 | m = 4
22 | feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda()
23 |
24 | def interpolate_func(inputs):
25 | idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda()
26 | weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda()
27 | interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight)
28 | return interpolated_feats
29 |
30 | assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1))
31 |
32 | if __name__=='__main__':
33 | test_interpolation_grad()
34 |
--------------------------------------------------------------------------------
/third_party/pointnet2/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch '''
7 | import torch
8 | import torch.nn as nn
9 | from typing import List, Tuple
10 |
11 | class SharedMLP(nn.Sequential):
12 |
13 | def __init__(
14 | self,
15 | args: List[int],
16 | *,
17 | bn: bool = False,
18 | activation=nn.ReLU(inplace=True),
19 | preact: bool = False,
20 | first: bool = False,
21 | name: str = ""
22 | ):
23 | super().__init__()
24 |
25 | for i in range(len(args) - 1):
26 | self.add_module(
27 | name + 'layer{}'.format(i),
28 | Conv2d(
29 | args[i],
30 | args[i + 1],
31 | bn=(not first or not preact or (i != 0)) and bn,
32 | activation=activation
33 | if (not first or not preact or (i != 0)) else None,
34 | preact=preact
35 | )
36 | )
37 |
38 |
39 | class _BNBase(nn.Sequential):
40 |
41 | def __init__(self, in_size, batch_norm=None, name=""):
42 | super().__init__()
43 | self.add_module(name + "bn", batch_norm(in_size))
44 |
45 | nn.init.constant_(self[0].weight, 1.0)
46 | nn.init.constant_(self[0].bias, 0)
47 |
48 |
49 | class BatchNorm1d(_BNBase):
50 |
51 | def __init__(self, in_size: int, *, name: str = ""):
52 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
53 |
54 |
55 | class BatchNorm2d(_BNBase):
56 |
57 | def __init__(self, in_size: int, name: str = ""):
58 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
59 |
60 |
61 | class BatchNorm3d(_BNBase):
62 |
63 | def __init__(self, in_size: int, name: str = ""):
64 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name)
65 |
66 |
67 | class _ConvBase(nn.Sequential):
68 |
69 | def __init__(
70 | self,
71 | in_size,
72 | out_size,
73 | kernel_size,
74 | stride,
75 | padding,
76 | activation,
77 | bn,
78 | init,
79 | conv=None,
80 | batch_norm=None,
81 | bias=True,
82 | preact=False,
83 | name=""
84 | ):
85 | super().__init__()
86 |
87 | bias = bias and (not bn)
88 | conv_unit = conv(
89 | in_size,
90 | out_size,
91 | kernel_size=kernel_size,
92 | stride=stride,
93 | padding=padding,
94 | bias=bias
95 | )
96 | init(conv_unit.weight)
97 | if bias:
98 | nn.init.constant_(conv_unit.bias, 0)
99 |
100 | if bn:
101 | if not preact:
102 | bn_unit = batch_norm(out_size)
103 | else:
104 | bn_unit = batch_norm(in_size)
105 |
106 | if preact:
107 | if bn:
108 | self.add_module(name + 'bn', bn_unit)
109 |
110 | if activation is not None:
111 | self.add_module(name + 'activation', activation)
112 |
113 | self.add_module(name + 'conv', conv_unit)
114 |
115 | if not preact:
116 | if bn:
117 | self.add_module(name + 'bn', bn_unit)
118 |
119 | if activation is not None:
120 | self.add_module(name + 'activation', activation)
121 |
122 |
123 | class Conv1d(_ConvBase):
124 |
125 | def __init__(
126 | self,
127 | in_size: int,
128 | out_size: int,
129 | *,
130 | kernel_size: int = 1,
131 | stride: int = 1,
132 | padding: int = 0,
133 | activation=nn.ReLU(inplace=True),
134 | bn: bool = False,
135 | init=nn.init.kaiming_normal_,
136 | bias: bool = True,
137 | preact: bool = False,
138 | name: str = ""
139 | ):
140 | super().__init__(
141 | in_size,
142 | out_size,
143 | kernel_size,
144 | stride,
145 | padding,
146 | activation,
147 | bn,
148 | init,
149 | conv=nn.Conv1d,
150 | batch_norm=BatchNorm1d,
151 | bias=bias,
152 | preact=preact,
153 | name=name
154 | )
155 |
156 |
157 | class Conv2d(_ConvBase):
158 |
159 | def __init__(
160 | self,
161 | in_size: int,
162 | out_size: int,
163 | *,
164 | kernel_size: Tuple[int, int] = (1, 1),
165 | stride: Tuple[int, int] = (1, 1),
166 | padding: Tuple[int, int] = (0, 0),
167 | activation=nn.ReLU(inplace=True),
168 | bn: bool = False,
169 | init=nn.init.kaiming_normal_,
170 | bias: bool = True,
171 | preact: bool = False,
172 | name: str = ""
173 | ):
174 | super().__init__(
175 | in_size,
176 | out_size,
177 | kernel_size,
178 | stride,
179 | padding,
180 | activation,
181 | bn,
182 | init,
183 | conv=nn.Conv2d,
184 | batch_norm=BatchNorm2d,
185 | bias=bias,
186 | preact=preact,
187 | name=name
188 | )
189 |
190 |
191 | class Conv3d(_ConvBase):
192 |
193 | def __init__(
194 | self,
195 | in_size: int,
196 | out_size: int,
197 | *,
198 | kernel_size: Tuple[int, int, int] = (1, 1, 1),
199 | stride: Tuple[int, int, int] = (1, 1, 1),
200 | padding: Tuple[int, int, int] = (0, 0, 0),
201 | activation=nn.ReLU(inplace=True),
202 | bn: bool = False,
203 | init=nn.init.kaiming_normal_,
204 | bias: bool = True,
205 | preact: bool = False,
206 | name: str = ""
207 | ):
208 | super().__init__(
209 | in_size,
210 | out_size,
211 | kernel_size,
212 | stride,
213 | padding,
214 | activation,
215 | bn,
216 | init,
217 | conv=nn.Conv3d,
218 | batch_norm=BatchNorm3d,
219 | bias=bias,
220 | preact=preact,
221 | name=name
222 | )
223 |
224 |
225 | class FC(nn.Sequential):
226 |
227 | def __init__(
228 | self,
229 | in_size: int,
230 | out_size: int,
231 | *,
232 | activation=nn.ReLU(inplace=True),
233 | bn: bool = False,
234 | init=None,
235 | preact: bool = False,
236 | name: str = ""
237 | ):
238 | super().__init__()
239 |
240 | fc = nn.Linear(in_size, out_size, bias=not bn)
241 | if init is not None:
242 | init(fc.weight)
243 | if not bn:
244 | nn.init.constant_(fc.bias, 0)
245 |
246 | if preact:
247 | if bn:
248 | self.add_module(name + 'bn', BatchNorm1d(in_size))
249 |
250 | if activation is not None:
251 | self.add_module(name + 'activation', activation)
252 |
253 | self.add_module(name + 'fc', fc)
254 |
255 | if not preact:
256 | if bn:
257 | self.add_module(name + 'bn', BatchNorm1d(out_size))
258 |
259 | if activation is not None:
260 | self.add_module(name + 'activation', activation)
261 |
262 | def set_bn_momentum_default(bn_momentum):
263 |
264 | def fn(m):
265 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
266 | m.momentum = bn_momentum
267 |
268 | return fn
269 |
270 |
271 | class BNMomentumScheduler(object):
272 |
273 | def __init__(
274 | self, model, bn_lambda, last_epoch=-1,
275 | setter=set_bn_momentum_default
276 | ):
277 | if not isinstance(model, nn.Module):
278 | raise RuntimeError(
279 | "Class '{}' is not a PyTorch nn Module".format(
280 | type(model).__name__
281 | )
282 | )
283 |
284 | self.model = model
285 | self.setter = setter
286 | self.lmbd = bn_lambda
287 |
288 | self.step(last_epoch + 1)
289 | self.last_epoch = last_epoch
290 |
291 | def step(self, epoch=None):
292 | if epoch is None:
293 | epoch = self.last_epoch + 1
294 |
295 | self.last_epoch = epoch
296 | self.model.apply(self.setter(self.lmbd(epoch)))
297 |
298 |
299 |
--------------------------------------------------------------------------------
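A small, hedged usage sketch (not part of the file above): `SharedMLP` stacks 1x1 `Conv2d` blocks over grouped features of shape `(B, C, npoint, nsample)`, and `BNMomentumScheduler` applies an epoch-dependent momentum to every BatchNorm layer of a module. The decay schedule below is an arbitrary example, not the one used by this project.

```python
import torch
from pytorch_utils import SharedMLP, BNMomentumScheduler  # this file, imported as a local module

mlp = SharedMLP([3, 64, 128], bn=True)          # Conv2d(3->64) + Conv2d(64->128), each with BN + ReLU
feats = torch.randn(2, 3, 512, 32)              # (B, C_in, npoint, nsample) grouped point features
out = mlp(feats)                                # -> (2, 128, 512, 32)

bn_lambda = lambda epoch: max(0.5 * (0.5 ** (epoch // 20)), 0.01)   # example decay only
scheduler = BNMomentumScheduler(mlp, bn_lambda)
scheduler.step(5)                               # sets momentum on every BN layer inside `mlp`
```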
/third_party/pointnet2/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from setuptools import setup
7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
8 | import glob
9 | import os.path as osp
10 |
11 | this_dir = osp.dirname(osp.abspath(__file__))
12 |
13 | _ext_src_root = "_ext_src"
14 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob(
15 | "{}/src/*.cu".format(_ext_src_root)
16 | )
17 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root))
18 |
19 | setup(
20 | name='pointnet2',
21 | ext_modules=[
22 | CUDAExtension(
23 | name='pointnet2._ext',
24 | sources=_ext_sources,
25 | extra_compile_args={
26 | "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))],
27 | "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))],
28 | },
29 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
30 | )
31 | ],
32 | cmdclass={
33 | 'build_ext': BuildExtension
34 | }
35 | )
36 |
--------------------------------------------------------------------------------
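As a hedged aside (not part of the repository): the extension declared above is normally built with `python setup.py install` (or `pip install .`) from `third_party/pointnet2/`. For quick iteration on the kernels, the same sources can typically also be JIT-compiled with `torch.utils.cpp_extension.load`, as sketched below; the module name `pointnet2_ext_jit` is an illustrative assumption.

```python
import glob
import os.path as osp
from torch.utils.cpp_extension import load

_ext_src_root = osp.join(osp.dirname(osp.abspath(__file__)), "_ext_src")
_ext = load(
    name="pointnet2_ext_jit",   # arbitrary name for the JIT-built module
    sources=glob.glob(osp.join(_ext_src_root, "src", "*.cpp"))
    + glob.glob(osp.join(_ext_src_root, "src", "*.cu")),
    extra_include_paths=[osp.join(_ext_src_root, "include")],
    verbose=True,
)
```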
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/DepthContrast/b8257890c94f7c58aeb5cefeb91af031692611d6/utils/__init__.py
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import datetime
9 | import sys
10 |
11 | import torch
12 | from torch import distributed as dist
13 |
14 |
15 | class Logger(object):
16 | def __init__(self, quiet=False, log_fn=None, rank=0, prefix=""):
17 | self.rank = rank if rank is not None else 0
18 | self.quiet = quiet
19 | self.log_fn = log_fn
20 |
21 | self.prefix = ""
22 | if prefix:
23 | self.prefix = prefix + ' | '
24 |
25 | self.file_pointers = []
26 | if self.rank == 0:
27 | if self.quiet:
28 | open(log_fn, 'w').close()
29 |
30 | def add_line(self, content):
31 | if self.rank == 0:
32 | msg = self.prefix+content
33 | if self.quiet:
34 | fp = open(self.log_fn, 'a')
35 | fp.write(msg+'\n')
36 | fp.flush()
37 | fp.close()
38 | else:
39 | print(msg)
40 | sys.stdout.flush()
41 |
42 |
43 | class ProgressMeter(object):
44 | def __init__(self, num_batches, meters, phase, epoch=None, logger=None, tb_writter=None):
45 | self.batches_per_epoch = num_batches
46 | self.batch_fmtstr = self._get_batch_fmtstr(epoch, num_batches)
47 | self.meters = meters
48 | self.phase = phase
49 | self.epoch = epoch
50 | self.logger = logger
51 | self.tb_writter = tb_writter
52 |
53 | def display(self, batch):
54 | step = self.epoch * self.batches_per_epoch + batch
55 | date = str(datetime.datetime.now())
56 | entries = ['{} | {} {}'.format(date, self.phase, self.batch_fmtstr.format(batch))]
57 | entries += [str(meter) for meter in self.meters]
58 | if self.logger is None:
59 | print('\t'.join(entries))
60 | else:
61 | self.logger.add_line('\t'.join(entries))
62 |
63 | if self.tb_writter is not None:
64 | for meter in self.meters:
65 | self.tb_writter.add_scalar('{}-batch/{}'.format(self.phase, meter.name), meter.val, step)
66 |
67 | def _get_batch_fmtstr(self, epoch, num_batches):
68 | num_digits = len(str(num_batches // 1))
69 | fmt = '{:' + str(num_digits) + 'd}'
70 | epoch_str = '[{}]'.format(epoch) if epoch is not None else ''
71 | return epoch_str+'[' + fmt + '/' + fmt.format(num_batches) + ']'
72 |
73 | def synchronize_meters(self, cur_gpu):
74 | metrics = torch.tensor([m.avg for m in self.meters]).cuda(cur_gpu)
75 | metrics_gather = [torch.ones_like(metrics) for _ in range(dist.get_world_size())]
76 | dist.all_gather(metrics_gather, metrics)
77 |
78 | metrics = torch.stack(metrics_gather).mean(0).cpu().numpy()
79 | for meter, m in zip(self.meters, metrics):
80 | meter.avg = m
81 |
--------------------------------------------------------------------------------
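A brief, hedged usage sketch (not part of the file above): on rank 0 the `Logger` either prints to stdout or, when `quiet=True`, truncates the given log file on construction and appends to it afterwards, while non-zero ranks stay silent. The file path is an arbitrary example.

```python
from utils.logger import Logger

logger = Logger(quiet=True, log_fn="train.log", rank=0, prefix="pretrain")
logger.add_line("starting epoch 0")   # writes "pretrain | starting epoch 0" to train.log
```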
/utils/metrics_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | import torch
9 | from collections import deque
10 |
11 |
12 | def accuracy(output, target, topk=(1,)):
13 | """Computes the accuracy over the k top predictions for the specified values of k"""
14 | with torch.no_grad():
15 | maxk = max(topk)
16 | batch_size = target.size(0)
17 |
18 | _, pred = output.topk(maxk, 1, True, True)
19 | pred = pred.t()
20 | correct = pred.eq(target.view(1, -1).expand_as(pred))
21 |
22 | res = []
23 | for k in topk:
24 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
25 | res.append(correct_k.mul_(100.0 / batch_size))
26 | return res
27 |
28 |
29 | class AverageMeter(object):
30 | """Computes and stores the average and current value"""
31 | def __init__(self, name, fmt=':f', window_size=0):
32 | self.name = name
33 | self.fmt = fmt
34 | self.window_size = window_size
35 | self.reset()
36 |
37 | def reset(self):
38 | if self.window_size > 0:
39 | self.q = deque(maxlen=self.window_size)
40 | self.val = 0
41 | self.avg = 0
42 | self.sum = 0
43 | self.count = 0
44 |
45 | def update(self, val, n=1):
46 | self.val = val
47 | if self.window_size > 0:
48 | self.q.append((val, n))
49 | self.count = sum([n for v, n in self.q])
50 | self.sum = sum([v * n for v, n in self.q])
51 | else:
52 | self.sum += val * n
53 | self.count += n
54 | self.avg = self.sum / self.count
55 |
56 | def __str__(self):
57 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
58 | return fmtstr.format(**self.__dict__)
59 |
60 |
61 |
--------------------------------------------------------------------------------
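Finally, a hedged usage sketch (not part of the file above) showing how `accuracy` and `AverageMeter` are typically combined in a training loop; the numbers are illustrative.

```python
import torch
from utils.metrics_utils import accuracy, AverageMeter

top1 = AverageMeter('Acc@1', ':6.2f')
logits = torch.randn(8, 10)                      # (batch, num_classes) model outputs
target = torch.randint(0, 10, (8,))              # ground-truth labels
acc1, acc5 = accuracy(logits, target, topk=(1, 5))
top1.update(acc1.item(), n=logits.size(0))
print(top1)                                      # e.g. "Acc@1  12.50 ( 12.50)"
```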