├── .gitignore
├── Dockerfile
├── README.md
├── config
├── kitti_eval_ours.prototxt
├── kitti_eval_ours_moredata.prototxt
├── kitti_train_ours.prototxt
└── kitti_train_ours_moredata.prototxt
├── demo
├── framework.png
├── output.gif
├── pointcov.png
├── traj+pointcov.png
└── traj.png
├── doc
└── train.md
├── evaluate.py
├── freeze.yml
├── rslo
├── builder
│ ├── __init__.py
│ ├── dataset_builder.py
│ ├── input_reader_builder.py
│ ├── losses_builder.py
│ ├── lr_scheduler_builder.py
│ ├── optimizer_builder.py
│ ├── preprocess_builder.py
│ ├── second_builder.py
│ └── voxel_builder.py
├── core
│ ├── __init__.py
│ └── losses.py
├── data
│ ├── __init__.py
│ ├── dataset.py
│ ├── kitti_common.py
│ ├── kitti_dataset_crossnorm_hdf5.py
│ ├── kitti_dataset_hdf5.py
│ └── preprocess.py
├── layers
│ ├── MaskConv.py
│ ├── SparseConv.py
│ ├── common.py
│ ├── confidence.py
│ ├── normalization.py
│ ├── se_module.py
│ └── svd.py
├── models
│ ├── custom_resnet_spc.py
│ ├── middle.py
│ ├── odom_pred.py
│ ├── odom_pred_base.py
│ ├── voxel_encoder.py
│ └── voxel_odom_net.py
├── protos
│ ├── __init__.py
│ ├── complile.sh
│ ├── input_reader.proto
│ ├── input_reader_pb2.py
│ ├── losses.proto
│ ├── losses.proto.bak
│ ├── losses_pb2.py
│ ├── model.proto
│ ├── model_pb2.py
│ ├── optimizer.proto
│ ├── optimizer_pb2.py
│ ├── pipeline.proto
│ ├── pipeline_pb2.py
│ ├── preprocess.proto
│ ├── preprocess_pb2.py
│ ├── sampler.proto
│ ├── sampler_pb2.py
│ ├── second.proto
│ ├── second_pb2.py
│ ├── similarity.proto
│ ├── similarity_pb2.py
│ ├── target.proto.bak
│ ├── train.proto
│ ├── train_pb2.py
│ ├── voxel_generator.proto
│ └── voxel_generator_pb2.py
├── torchplus
│ ├── __init__.py
│ ├── metrics.py
│ ├── nn
│ │ ├── __init__.py
│ │ ├── functional.py
│ │ └── modules
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ └── normalization.py
│ ├── ops
│ │ ├── __init__.py
│ │ └── array_ops.py
│ ├── tools.py
│ └── train
│ │ ├── __init__.py
│ │ ├── checkpoint.py
│ │ ├── common.py
│ │ ├── fastai_optim.py
│ │ ├── learning_schedules.py
│ │ ├── learning_schedules_fastai.py
│ │ └── optim.py
└── utils
│ ├── __init__.py
│ ├── check.py
│ ├── config_tool.py
│ ├── config_tool
│ ├── __init__.py
│ └── train.py
│ ├── distributed_utils.py
│ ├── find.py
│ ├── geometric.py
│ ├── kitti_evaluation.py
│ ├── loader.py
│ ├── log_tool.py
│ ├── math.py
│ ├── pose_utils.py
│ ├── pose_utils_np.py
│ ├── progress_bar.py
│ ├── singleton.py
│ ├── timer.py
│ ├── util.py
│ └── visualization.py
├── script
├── create_hdf5.py
├── create_hdf5_crossnormal.py
└── eval_ours.sh
├── thirdparty
└── chamfer_distance
│ ├── __init__.py
│ ├── chamfer_distance.cpp
│ ├── chamfer_distance.cu
│ └── chamfer_distance.py
└── train_hdf5.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python
4 |
5 | ### Python ###
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
103 | __pypackages__/
104 |
105 | # Celery stuff
106 | celerybeat-schedule
107 | celerybeat.pid
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 |
136 | # Pyre type checker
137 | .pyre/
138 |
139 | # pytype static type analyzer
140 | .pytype/
141 |
142 | # Cython debug symbols
143 | cython_debug/
144 |
145 | ### VisualStudioCode ###
146 | .vscode/*
147 | !.vscode/settings.json
148 | !.vscode/tasks.json
149 | !.vscode/launch.json
150 | !.vscode/extensions.json
151 | *.code-workspace
152 |
153 | # Local History for Visual Studio Code
154 | .history/
155 |
156 | ### VisualStudioCode Patch ###
157 | # Ignore all local history of files
158 | .history
159 | .ionide
160 |
161 | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
162 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | #FROM ubuntu:18.04
2 | # FROM nvidia/cuda:9.2-devel-ubuntu18.04
3 | FROM nvidia/cuda:9.2-cudnn7-devel-ubuntu18.04
4 |
5 | # Dependencies for glvnd and X11.
6 | RUN apt-get update \
7 | && apt-get install -y -qq --no-install-recommends \
8 | libglvnd0 \
9 | libgl1 \
10 | libglx0 \
11 | libegl1 \
12 | libxext6 \
13 | libx11-6 \
14 | && rm -rf /var/lib/apt/lists/*
15 | # Env vars for the nvidia-container-runtime.
16 | ENV NVIDIA_VISIBLE_DEVICES all
17 | ENV NVIDIA_DRIVER_CAPABILITIES graphics,utility,compute
18 |
19 | #env vars for cuda
20 | ENV CUDA_HOME /usr/local/cuda
21 |
22 | #install miniconda
23 | RUN apt-get update --fix-missing && \
24 | apt-get install -y wget bzip2 ca-certificates curl git && \
25 | apt-get clean && \
26 | rm -rf /var/lib/apt/lists/*
27 |
28 | RUN wget --quiet https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh -O ~/miniconda.sh && \
29 | /bin/bash ~/miniconda.sh -b -p /opt/miniconda3 && \
30 | rm ~/miniconda.sh && \
31 | /opt/miniconda3/bin/conda clean -tipsy && \
32 | ln -s /opt/miniconda3/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
33 | echo ". /opt/miniconda3/etc/profile.d/conda.sh" >> ~/.bashrc && \
34 | echo "conda activate base" >> ~/.bashrc && \
35 | echo "conda deactivate && conda activate py37" >> ~/.bashrc
36 |
37 | #https://blog.csdn.net/Mao_Jonah/article/details/89502380
38 | COPY freeze.yml freeze.yml
39 | RUN /opt/miniconda3/bin/conda env create -n py37 -f freeze.yml
40 | #install pytorch
41 | # RUN /opt/miniconda3/bin/conda install pytorch==1.5.1 torchvision==0.6.1 cudatoolkit=9.2 -c pytorch -n py37
42 | # RUN /opt/miniconda3/bin/conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=9.2 -c pytorch -n py37
43 | RUN /opt/miniconda3/bin/conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=9.2 -c pytorch -n py37
44 |
45 | #install cmake 3.13
46 | RUN apt-get update && apt-get install -y software-properties-common
47 | RUN wget -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc | apt-key add -
48 | RUN apt-add-repository 'deb https://apt.kitware.com/ubuntu/ bionic main' && apt-get update && apt-get install -y cmake \
49 | python3-distutils python-dev python3-dev \
50 | libboost-all-dev
51 |
52 |
53 |
54 |
55 | #install spconv
56 |
57 | WORKDIR /tmp/unique_for_spconv
58 | RUN git clone --recursive https://github.com/DecaYale/spconv_plus.git
59 | WORKDIR /tmp/unique_for_spconv/spconv_plus
60 | RUN /opt/miniconda3/envs/py37/bin/python setup.py bdist_wheel
61 | RUN cd ./dist && /opt/miniconda3/envs/py37/bin/pip3 install *.whl
62 |
63 |
64 |
65 |
66 | WORKDIR /tmp/
67 | COPY config.jupyter.tar config.jupyter.tar
68 | RUN tar -xvf config.jupyter.tar -C /root/
69 |
70 |
71 | # RUN add-apt-repository ppa:ubuntu-toolchain-r/test && apt-get update
72 | # RUN apt-get install -y gcc-5 g++-5
73 | # RUN ls /usr/bin/ | grep gcc
74 | # # RUN ls /usr/bin/ | grep g++
75 | # RUN mv /usr/bin/gcc /usr/bin/gcc.bak && ln -s /usr/bin/gcc-5 /usr/bin/gcc && gcc --version
76 | # RUN mv /usr/bin/g++ /usr/bin/g++.bak && ln -s /usr/bin/g++-5 /usr/bin/g++ && g++ --version
77 |
78 | #install apex
79 | # RUN . ~/.bashrc && conda activate py37 && git clone https://github.com/NVIDIA/apex.git \
80 | # RUN . /opt/miniconda3/etc/profile.d/conda.sh && conda init bash
81 | # RUN ls /opt/miniconda3/envs/py37/bin/ | grep pip
82 | # RUN git clone https://github.com/NVIDIA/apex.git \
83 | # && cd apex && git reset --hard f3a960f80244cf9e80558ab30f7f7e8cbf03c0a0 \
84 | # && /opt/miniconda3/envs/py37/bin/python setup.py install --cuda_ext --cpp_ext
85 | ENV TORCH_CUDA_ARCH_LIST "6.0 6.2 7.0 7.2"
86 | # make sure we don't overwrite some existing directory called "apex"
87 | WORKDIR /tmp/unique_for_apex
88 | # uninstall Apex if present, twice to make absolutely sure :)
89 | RUN /opt/miniconda3/envs/py37/bin/pip3 uninstall -y apex || :
90 | RUN /opt/miniconda3/envs/py37/bin/pip3 uninstall -y apex || :
91 | # SHA is something the user can touch to force recreation of this Docker layer,
92 | # and therefore force cloning of the latest version of Apex
93 | RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git
94 | WORKDIR /tmp/unique_for_apex/apex
95 | RUN git checkout f3a960f80244cf9e80558ab30f7f7e8cbf03c0a0
96 | # RUN /opt/miniconda3/envs/py37/bin/pip3 install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
97 | RUN /opt/miniconda3/envs/py37/bin/pip3 install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
98 |
99 |
100 | #other pkgs
101 | RUN apt-get update \
102 | && apt-get install -y -qq --no-install-recommends \
103 | cmake build-essential vim xvfb unzip tmux psmisc \
104 | libx11-dev libassimp-dev \
105 | mesa-common-dev freeglut3-dev \
106 | && apt-get clean \
107 | && rm -rf /var/lib/apt/lists/*
108 |
109 | #create some directories
110 | RUN mkdir -p /home/yxu/Projects/ && ln -s /mnt/workspace/Works /home/yxu/Projects/Works \
111 | && mkdir -p /DATA/yxu/ && ln -s /mnt/workspace/datasets/ /DATA/yxu/LINEMOD_DEEPIM \
112 | && ln -s /mnt/workspace/datasets/LINEMOD/ /DATA/yxu/LINEMOD \
113 | && ln -s /mnt/workspace/datasets/BOP_LINEMOD/ /DATA/yxu/BOP_LINEMOD \
114 | && mkdir -p /mnt/lustre/xuyan2/ \
115 | && ln -s /home/yxu/datasets/ /mnt/lustre/xuyan2/datasets
116 |
117 | EXPOSE 8887 8888 8889 10000 10001 10002
118 | WORKDIR /
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RSLO
2 | [[Paper]](https://scholar.google.com/scholar?hl=zh-CN&as_sdt=0%2C5&q=Robust+Self-supervised+LiDAR+Odometry+via+Representative+Structure+Discovery+and+3D+Inherent+Error+Distribution+Modeling&btnG=)
3 |
4 | The code for the paper **Robust Self-supervised LiDAR Odometry via Representative Structure Discovery and 3D Inherent Error Modeling**, accepted by IEEE Robotics and Automation Letters (RA-L), 2021.
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | ## Framework
14 | The self-supervised two-frame odometry network contains three main modules: the Geometric Unit Feature Encoding module, the Geometric Unit Transformation Estimation module, and the Ego-motion Voting module.
15 |
16 |
17 |
18 |
19 |
20 |
21 | ## Estimated Trajectories and Point Covariance Estimations
22 | Left: a comparison of the trajectories estimated by our method and by other competitive baselines. Right: a visualization of our estimated point covariances.
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 | ## Installation
33 | As the dependencies are complex, a Dockerfile has been provided. You need to install [docker](https://docs.docker.com/get-docker/) and [nvidia-docker2](https://github.com/NVIDIA/nvidia-docker) first, then build the Docker image and start a container with the following commands:
34 |
35 | ```
36 | cd RSLO
37 | sudo docker build -t rslo .
38 | sudo docker run -it --runtime=nvidia --ipc=host --volume="HOST_VOLUME_YOU_WANT_TO_MAP:DOCKER_VOLUME" -e DISPLAY=$DISPLAY -e QT_X11_NO_MITSHM=1 rslo bash
39 |
40 | ```
41 |
42 | ## Data Preparation
43 | You need to download the [KITTI odometry dataset](http://www.cvlibs.net/datasets/kitti/eval_odometry.php) and unzip it into the directory structure below.
44 | ```
45 | ./kitti/dataset
46 | |──sequences
47 | | ├── 00/
48 | | | ├── calib.txt
49 | | │ ├── velodyne/
50 | | | | ├── 000000.bin
51 | | | | ├── 000001.bin
52 | | | | └── ...
53 | | ├── 01/
54 | | | ...
55 | | └── 21/
56 | └──poses
57 | |──00.txt
58 | |──01.txt
59 | | ...
60 | └──10.txt
61 |
62 | ```
63 | Then, create hdf5 data with
64 | ```
65 | python script/create_hdf5.py ./kitti/dataset ./kitti/dataset/all.h5
66 | ```
67 |
68 | ## Test with the Pretrained Models
69 | The models trained on the KITTI dataset have been uploaded to [OneDrive](https://1drv.ms/u/s!AgP7bY0L6pvta-AeCK1tFxJrn-8?e=1hYWzy). You can download them and put them into the directory "weights" for testing.
70 |
71 | ```
72 | export PYTHONPATH="$PROJECT_ROOT_PATH:$PYTHONPATH"
73 | export PYTHONPATH="$PROJECT_ROOT_PATH/rslo:$PYTHONPATH"
74 | python -u $PROJECT_ROOT_PATH/evaluate.py multi_proc_eval \
75 | --config_path $PROJECT_ROOT_PATH/config/kitti_eval_ours.prototxt \
76 | --model_dir ./outputs/ \
77 | --use_dist True \
78 | --gpus_per_node 1 \
79 | --use_apex True \
80 | --world_size 1 \
81 | --dist_port 20000 \
82 | --pretrained_path $PROJECT_ROOT_PATH/weights/ours.tckpt \
83 | --refine False \
84 | ```
85 | Note that you need to specify the PROJECT_ROOT_PATH, i.e. the absolute directory of the project folder "RSLO" and modify the path to the created data, i.e. all.h5, in the configuration file kitti_eval_ours.prototxt before running the above commands. A bash script "script/eval_ours.sh" is provided for reference.
86 |
87 | ## Training from Scratch
88 | A basic training script is shown below. You can increase the number of GPUs, i.e. the variable "GPUs", according to your available resources. Generally, larger batch sizes produce more stable training and better final performance.
89 |
90 |
91 | ```
92 | export PYTHONPATH="$PROJECT_ROOT_PATH:$PYTHONPATH"
93 | export PYTHONPATH="$PROJECT_ROOT_PATH/rslo:$PYTHONPATH"
94 | GPUs=1 # the number of gpus you use
95 | python -u $PROJECT_ROOT_PATH/train_hdf5.py multi_proc_train \
96 | --config_path $PROJECT_ROOT_PATH/config/kitti_train_ours.prototxt \
97 | --model_dir ./outputs/ \
98 | --use_dist True \
99 | --gpus_per_node $GPUs \
100 | --use_apex True \
101 | --world_size $GPUs \
102 | --dist_port 20000 \
103 | --refine False \
104 |
105 | ```
106 |
107 |
108 |
109 |
110 |
111 | ## Acknowledgments
112 | We thank the authors of the open-source codebases [spconv](https://github.com/traveller59/spconv) and [second](https://github.com/traveller59/second.pytorch).
113 |
114 | ## Citation
115 | To cite our paper
116 | ```
117 | @article{xu2022robust,
118 | title={Robust Self-supervised LiDAR Odometry via Representative Structure Discovery and 3D Inherent Error Modeling},
119 | author={Xu, Yan and Lin, Junyi and Shi, Jianping and Zhang, Guofeng and Wang, Xiaogang and Li, Hongsheng},
120 | journal={IEEE Robotics and Automation Letters},
121 | year={2021},
122 | publisher={IEEE}
123 | }
124 | ```
125 | ```
126 | @inproceedings{xu2020selfvoxelo,
127 | title = {SelfVoxeLO: Self-supervised LiDAR Odometry with Voxel-based Deep Neural Networks},
128 | author = {Yan Xu and Zhaoyang Huang and Kwan{-}Yee Lin and Xinge Zhu and Jianping Shi and Hujun Bao and Guofeng Zhang and Hongsheng Li},
129 | booktitle = {4th Conference on Robot Learning, CoRL 2020, 16-18 November 2020, Virtual Event / Cambridge, MA, {USA}},
130 | volume = {155},
131 | pages = {115--125},
132 | publisher = {{PMLR}},
133 | year = {2020},
134 | }
135 | ```
136 |
137 | ## TODO List and ETA
138 | - [x] Inference code and pretrained models (9/10/2022)
139 | - [x] Training code (10/12/2022)
140 | - [ ] Code cleaning and refactor
141 |
142 |
143 |
144 |
145 |
--------------------------------------------------------------------------------
/config/kitti_eval_ours.prototxt:
--------------------------------------------------------------------------------
1 | model: {
2 | second: {
3 | use_GN: false
4 | icp_iter: 2
5 | network_class_name: "UnVoxelOdomNetICP3"
6 | voxel_generator {
7 | point_cloud_range : [-70.4, -38.4, -3, 70.4, 38.4, 5] # [x0,y0,z0, x1,y1,z1]
8 | voxel_size : [0.1, 0.1, 0.2]
9 | max_number_of_points_per_voxel : 10
10 | block_factor :1
11 | block_size : 8
12 | height_threshold : -1 #0.05
13 | }
14 |
15 | voxel_feature_extractor: {
16 | module_class_name: "SimpleVoxel_XYZINormalC"
17 | num_filters: [16]
18 | with_distance: false
19 | num_input_features: 7
20 | not_use_norm: true
21 |
22 | }
23 | middle_feature_extractor: {
24 | module_class_name: "SpMiddleFHDWithCov2_3"
25 | downsample_factor: 8
26 | num_input_features: 7
27 | bn_type:"None"
28 | use_leakyReLU: true
29 | }
30 |
31 | odom_predictor:{
32 | module_class_name : "UNRResNetOdomPredEncDecSVDTempMask"
33 | num_input_features :128
34 | layer_nums: [3,5,5]
35 | layer_strides:[2,2,2]
36 | num_filters: [128, 128, 256]
37 | upsample_strides:[2,2,2]#[1,2,4]
38 | num_upsample_filters:[128,64,64]
39 | pool_size : 1
40 | pool_type: "avg_pool"
41 | cycle_constraint : true
42 | # not_use_norm: false
43 | bn_type:"SyncBN"
44 | pred_pyramid_motion: true
45 | # use_sparse_conv: false
46 | conv_type:"mask_conv"
47 | odom_format: "rx+t" #"r(x+t)"
48 | dense_predict: true
49 | dropout: 0.0000000000000000000001
50 | conf_type: "softmax"#"linear"
51 | use_deep_supervision:true
52 | use_svd: false
53 | }
54 |
55 | loss: {
56 | pyloss_exp_w_base:0.5
57 | rotation_loss{
58 | loss_type: "AdaptiveWeightedL2"
59 | weight: 1,
60 | init_alpha: -2.5
61 | }
62 | translation_loss{
63 | loss_type: "AdaptiveWeightedL2",
64 | weight: 1,
65 | init_alpha: 0
66 | }
67 | consistency_loss{
68 | loss_type: "Aleat5_1ChamferL2NormalWeightedALLSVDLoss"
69 | weight: 1,
70 | penalize_ratio: 0.97
71 | norm: false
72 | pred_downsample_ratio: 1
73 | reg_weight: 0.005#0.0005
74 | sph_weight:1
75 | }
76 | }
77 | num_point_features: 7 # model's num point feature should be independent of dataset
78 | # Outputs
79 | use_sigmoid_score: true
80 | encode_background_as_zeros: true
81 | }
82 | }
83 |
84 |
85 | eval_input_reader: {
86 | dataset: {
87 | dataset_class_name: "KittiDatasetHDF5"
88 | # dataset_class_name: "KittiDataset"
89 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/eval_data_info4.pkl" #TODO:
90 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5"
91 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5"
92 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset"
93 | seq_length: 2
94 | skip: 1
95 | step:1
96 | random_skip: false
97 | }
98 | batch_size: 1
99 | preprocess: {
100 | max_number_of_voxels: 40000
101 | shuffle_points: false
102 | num_workers: 1
103 | anchor_area_threshold: -1
104 | remove_environment: false
105 | }
106 | }
107 |
--------------------------------------------------------------------------------
/config/kitti_eval_ours_moredata.prototxt:
--------------------------------------------------------------------------------
1 | model: {
2 | second: {
3 | use_GN: false#true
4 | icp_iter: 2
5 | # sync_bn: true
6 | network_class_name: "UnVoxelOdomNetICP3"#"UnVoxelOdomNetICP"#"UnVoxelOdomNetICP2"#"VoxelOdomNet"
7 | voxel_generator {
8 | # point_cloud_range : [0, -40, -3, 70.4, 40, 8] # [x0,y0,z0, x1,y1,z1]
9 | #point_cloud_range : [-68.8, -40, -3, 68.8, 40, 5] # [x0,y0,z0, x1,y1,z1]
10 | point_cloud_range : [-70.4, -38.4, -3, 70.4, 38.4, 5] # [x0,y0,z0, x1,y1,z1]
11 | # point_cloud_range : [0, -32.0, -3, 52.8, 32.0, 1]
12 | # voxel_size : [0.05, 0.05, 0.2]
13 | voxel_size : [0.1, 0.1, 0.2]
14 | max_number_of_points_per_voxel : 10#5
15 | block_factor :1
16 | block_size : 8
17 | height_threshold : -1#0.05
18 | }
19 |
20 | voxel_feature_extractor: {
21 | module_class_name: "SimpleVoxel_XYZINormalC"#"SimpleVoxel"
22 | num_filters: [16]
23 | with_distance: false
24 | num_input_features: 7#8
25 | not_use_norm: true
26 |
27 | }
28 | middle_feature_extractor: {
29 | module_class_name: "SpMiddleFHDWithCov2_3"#"SpMiddleFHDWithConf4_1"#"SpMiddleFHD"
30 | # num_filters_down1: [] # protobuf doesn't support empty lists.
31 | # num_filters_down2: []
32 | downsample_factor: 8
33 | num_input_features: 7#8
34 | # not_use_norm: true
35 | bn_type:"None"
36 | #not_use_norm: false
37 | use_leakyReLU: true
38 | }
39 |
40 | odom_predictor:{
41 | module_class_name : "UNRResNetOdomPredEncDecSVDTempMask" #"UNRResNetOdomPredEncDecSVD"#"RResNetOdomPredEncDec"
42 | num_input_features :128
43 | layer_nums: [3,5,5]
44 | layer_strides:[2,2,2]
45 | num_filters: [128, 128, 256]
46 | upsample_strides:[2,2,2]#[1,2,4]
47 | num_upsample_filters:[128,64,64] #[256, 256, 256]
48 | pool_size : 1
49 | pool_type: "avg_pool"
50 | cycle_constraint : true
51 | # not_use_norm: false
52 | bn_type:"SyncBN"
53 | pred_pyramid_motion: true
54 | # use_sparse_conv: false
55 | conv_type:"mask_conv"
56 | odom_format: "rx+t"#"r(x+t)"
57 | dense_predict: true
58 | dropout: 0.0000000000000000000001#0.2
59 | conf_type: "softmax"#"linear"
60 | use_deep_supervision:true
61 | use_svd: false
62 | }
63 |
64 | loss: {
65 | pyloss_exp_w_base:0.5
66 | rotation_loss{
67 | loss_type: "AdaptiveWeightedL2" #"AdaptiveWeightedL2RMatrixLoss"# #"AdaptiveWeightedL2",
68 | weight: 1#0,
69 | init_alpha: -2.5
70 | }
71 | translation_loss{
72 | loss_type: "AdaptiveWeightedL2",
73 | weight: 1,
74 | init_alpha: 0
75 | }
76 | consistency_loss{
77 | loss_type: "Aleat5_1ChamferL2NormalWeightedALLSVDLoss"#"ChamferL2NormalWeightedALLSVDLoss"#"ChamferL2NormalWeightedALLLoss"#"ChamferL2Loss"#"CosineDistance",
78 | weight: 1,#10#1,
79 | penalize_ratio: 0.97#0.99#0.9999#0.95#0.9999
80 | norm: false
81 | pred_downsample_ratio: 1
82 | reg_weight: 0.005#0.0005
83 |
84 | sph_weight:1
85 | # sample_block_size: [0.5,1,1]
86 | }
87 | }
88 | num_point_features: 7#8#4 # model's num point feature should be independent of dataset
89 | # Outputs
90 | use_sigmoid_score: true
91 | encode_background_as_zeros: true
92 |
93 |
94 | # Loss
95 |
96 |
97 | # Postprocess
98 | post_center_limit_range: [0, -40, -2.2, 70.4, 40, 0.8]
99 |
100 | }
101 | }
102 |
103 | train_input_reader: {
104 | dataset: {
105 | dataset_class_name: "KittiDatasetHDF5"
106 | # dataset_class_name: "KittiDataset"
107 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_data_info7.pkl"
108 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5"
109 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5"
110 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset"
111 | seq_length: 3
112 | skip: 1
113 | step:1
114 | random_skip: false#true
115 | }
116 |
117 | batch_size: 1#2#3#1#3 #8
118 | preprocess: {
119 | max_number_of_voxels: 40000#17000
120 | shuffle_points: false#true
121 | num_workers: 2#2
122 | # groundtruth_localization_noise_std: [1.0, 1.0, 0.5]
123 | # groundtruth_rotation_uniform_noise: [-0.3141592654, 0.3141592654]
124 | # groundtruth_rotation_uniform_noise: [-1.57, 1.57]
125 | # groundtruth_rotation_uniform_noise: [-0.78539816, 0.78539816]
126 | # global_rotation_uniform_noise: [-0.78539816, 0.78539816]
127 | # global_scaling_uniform_noise: [0.95, 1.05]
128 | # global_random_rotation_range_per_object: [0, 0] # pi/4 ~ 3pi/4
129 | # global_translate_noise_std: [0, 0, 0]
130 | # anchor_area_threshold: -1
131 | remove_points_after_sample: true
132 | groundtruth_points_drop_percentage: 0.0
133 | groundtruth_drop_max_keep_points: 15
134 | # remove_unknown_examples: false
135 | # sample_importance: 1.0
136 | random_flip_x: false
137 | random_flip_y: true#false
138 | # remove_environment: false
139 | # downsample_voxel_sizes: [0.05,0.1,0.2,0.4]
140 | downsample_voxel_sizes: [0.1] #[0.1,0.2,0.4, 0.8]
141 |
142 | }
143 | }
144 |
145 | train_config: {
146 | optimizer: {
147 | adam_optimizer: {
148 | learning_rate: {
149 | one_cycle: {
150 | lr_max: 0.8e-3#1e-3#2e-3
151 | moms: [0.95, 0.85]
152 | div_factor: 10.0
153 | pct_start: 0.05
154 | }
155 | }
156 | weight_decay: 1e-5#0.001
157 | }
158 | fixed_weight_decay: true
159 | use_moving_average: false
160 | }
161 | #optimizer: {
162 | # adam_optimizer: {
163 | # learning_rate: {
164 | # exponential_decay: {
165 | # initial_learning_rate: 0.002#0.008#0.002
166 | # decay_length: 0.05#0.1#0.1
167 | # decay_factor: 0.8#0.8
168 | # staircase: True
169 | # }
170 | # }
171 | # weight_decay: 1e-6#0.0001
172 | # }
173 | # fixed_weight_decay: false
174 | # use_moving_average: false
175 | #}
176 | # steps: 99040 # 1238 * 120
177 | # s: 49520 # 619 * 80
178 | # steps: 30950 # 619 * 80
179 | # steps_per_eval: 3095 # 619 * 5
180 | steps: 200000#200000#42500#170000#12750#23200*20/8/4 #23200 # 464 * 50
181 | steps_per_eval: 4000#425#1700#850#637#23200/8#2320 # 619 * 5
182 |
183 | # save_checkpoints_secs : 1800 # half hour
184 | # save_summary_steps : 10
185 | enable_mixed_precision: false
186 | loss_scale_factor: -1
187 | clear_metrics_every_epoch: true
188 | }
189 |
190 | eval_input_reader: {
191 | dataset: {
192 | dataset_class_name: "KittiDatasetHDF5"
193 | # dataset_class_name: "KittiDataset"
194 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/eval_data_info4.pkl" #TODO:
195 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5"
196 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5"
197 | # kitti_info_path: "/media/yy/960evo/datasets/kitti/kitti_infos_test.pkl"
198 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset"
199 | seq_length: 2
200 | skip: 1
201 | step:1
202 | random_skip: false
203 | }
204 | batch_size: 1
205 | preprocess: {
206 | max_number_of_voxels: 40000
207 | shuffle_points: false
208 | num_workers: 1
209 | anchor_area_threshold: -1
210 | remove_environment: false
211 | }
212 | }
213 | eval_train_input_reader: {
214 | dataset: {
215 | dataset_class_name: "KittiDatasetHDF5"
216 | # dataset_class_name: "KittiDataset"
217 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/eval_train_data_info.pkl" #TODO:
218 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5"
219 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5"
220 | # kitti_info_path: "/media/yy/960evo/datasets/kitti/kitti_infos_test.pkl"
221 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset"
222 | seq_length: 2
223 | skip:1
224 | step:1
225 | random_skip: false
226 | }
227 | batch_size: 1
228 | preprocess: {
229 | max_number_of_voxels: 40000
230 | shuffle_points: false
231 | num_workers: 1
232 | }
233 | }
--------------------------------------------------------------------------------
/config/kitti_train_ours.prototxt:
--------------------------------------------------------------------------------
1 | model: {
2 | second: {
3 | use_GN: false#true
4 | icp_iter: 2
5 | # sync_bn: true
6 | network_class_name: "UnVoxelOdomNetICP3"#"UnVoxelOdomNetICP"#"UnVoxelOdomNetICP2"#"VoxelOdomNet"
7 | voxel_generator {
8 | # point_cloud_range : [0, -40, -3, 70.4, 40, 8] # [x0,y0,z0, x1,y1,z1]
9 | #point_cloud_range : [-68.8, -40, -3, 68.8, 40, 5] # [x0,y0,z0, x1,y1,z1]
10 | point_cloud_range : [-70.4, -38.4, -3, 70.4, 38.4, 5] # [x0,y0,z0, x1,y1,z1]
11 | # point_cloud_range : [0, -32.0, -3, 52.8, 32.0, 1]
12 | # voxel_size : [0.05, 0.05, 0.2]
13 | voxel_size : [0.1, 0.1, 0.2]
14 | max_number_of_points_per_voxel : 10#5
15 | block_factor :1
16 | block_size : 8
17 | height_threshold : -1#0.05
18 | }
19 |
20 | voxel_feature_extractor: {
21 | module_class_name: "SimpleVoxel_XYZINormalC"#"SimpleVoxel"
22 | num_filters: [16]
23 | with_distance: false
24 | num_input_features: 7#8
25 | not_use_norm: true
26 |
27 | }
28 | middle_feature_extractor: {
29 | module_class_name: "SpMiddleFHDWithCov2_3"#"SpMiddleFHDWithConf4_1"#"SpMiddleFHD"
30 | # num_filters_down1: [] # protobuf doesn't support empty lists.
31 | # num_filters_down2: []
32 | downsample_factor: 8
33 | num_input_features: 7#8
34 | # not_use_norm: true
35 | bn_type:"None"
36 | #not_use_norm: false
37 | use_leakyReLU: true
38 | }
39 |
40 | odom_predictor:{
41 | module_class_name : "UNRResNetOdomPredEncDecSVDTempMask" #"UNRResNetOdomPredEncDecSVD"#"RResNetOdomPredEncDec"
42 | num_input_features :128
43 | layer_nums: [3,5,5]
44 | layer_strides:[2,2,2]
45 | num_filters: [128, 128, 256]
46 | upsample_strides:[2,2,2]#[1,2,4]
47 | num_upsample_filters:[128,64,64] #[256, 256, 256]
48 | pool_size : 1
49 | pool_type: "avg_pool"
50 | cycle_constraint : true
51 | # not_use_norm: false
52 | bn_type:"SyncBN"
53 | pred_pyramid_motion: true
54 | # use_sparse_conv: false
55 | conv_type:"mask_conv"
56 | odom_format: "rx+t"#"r(x+t)"
57 | dense_predict: true
58 | dropout: 0.0000000000000000000001#0.2
59 | conf_type: "softmax"#"linear"
60 | use_deep_supervision:true
61 | use_svd: false
62 | }
63 |
64 | loss: {
65 | pyloss_exp_w_base:0.5
66 | rotation_loss{
67 | loss_type: "AdaptiveWeightedL2" #"AdaptiveWeightedL2RMatrixLoss"# #"AdaptiveWeightedL2",
68 | weight: 1#0,
69 | init_alpha: -2.5
70 | }
71 | translation_loss{
72 | loss_type: "AdaptiveWeightedL2",
73 | weight: 1,
74 | init_alpha: 0
75 | }
76 | consistency_loss{
77 | loss_type: "Aleat5_1ChamferL2NormalWeightedALLSVDLoss"#"ChamferL2NormalWeightedALLSVDLoss"#"ChamferL2NormalWeightedALLLoss"#"ChamferL2Loss"#"CosineDistance",
78 | weight: 1,#10#1,
79 | penalize_ratio: 0.97#0.99#0.9999#0.95#0.9999
80 | norm: false
81 | pred_downsample_ratio: 1
82 | reg_weight: 0.005#0.0005
83 |
84 | sph_weight:1
85 | # sample_block_size: [0.5,1,1]
86 | }
87 | }
88 | num_point_features: 7#8#4 # model's num point feature should be independent of dataset
89 | # Outputs
90 | use_sigmoid_score: true
91 | encode_background_as_zeros: true
92 |
93 |
94 | # Loss
95 |
96 |
97 | # Postprocess
98 | post_center_limit_range: [0, -40, -2.2, 70.4, 40, 0.8]
99 |
100 | }
101 | }
102 |
103 | train_input_reader: {
104 | dataset: {
105 | dataset_class_name: "KittiDatasetHDF5"
106 | # dataset_class_name: "KittiDataset"
107 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_data_info7.pkl"
108 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5"
109 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5"
110 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset"
111 | seq_length: 3
112 | skip: 1
113 | step:1
114 | random_skip: false#true
115 | }
116 |
117 | batch_size: 1#2#3#1#3 #8
118 | preprocess: {
119 | max_number_of_voxels: 40000#17000
120 | shuffle_points: false#true
121 | num_workers: 2#2
122 | # groundtruth_localization_noise_std: [1.0, 1.0, 0.5]
123 | # groundtruth_rotation_uniform_noise: [-0.3141592654, 0.3141592654]
124 | # groundtruth_rotation_uniform_noise: [-1.57, 1.57]
125 | # groundtruth_rotation_uniform_noise: [-0.78539816, 0.78539816]
126 | # global_rotation_uniform_noise: [-0.78539816, 0.78539816]
127 | # global_scaling_uniform_noise: [0.95, 1.05]
128 | # global_random_rotation_range_per_object: [0, 0] # pi/4 ~ 3pi/4
129 | # global_translate_noise_std: [0, 0, 0]
130 | # anchor_area_threshold: -1
131 | remove_points_after_sample: true
132 | groundtruth_points_drop_percentage: 0.0
133 | groundtruth_drop_max_keep_points: 15
134 | # remove_unknown_examples: false
135 | # sample_importance: 1.0
136 | random_flip_x: false
137 | random_flip_y: true#false
138 | # remove_environment: false
139 | # downsample_voxel_sizes: [0.05,0.1,0.2,0.4]
140 | downsample_voxel_sizes: [0.1] #[0.1,0.2,0.4, 0.8]
141 |
142 | }
143 | }
144 |
145 | train_config: {
146 | optimizer: {
147 | adam_optimizer: {
148 | learning_rate: {
149 | one_cycle: {
150 | lr_max: 0.8e-3#1e-3#2e-3
151 | moms: [0.95, 0.85]
152 | div_factor: 10.0
153 | pct_start: 0.05
154 | }
155 | }
156 | weight_decay: 1e-5#0.001
157 | }
158 | fixed_weight_decay: true
159 | use_moving_average: false
160 | }
161 | #optimizer: {
162 | # adam_optimizer: {
163 | # learning_rate: {
164 | # exponential_decay: {
165 | # initial_learning_rate: 0.002#0.008#0.002
166 | # decay_length: 0.05#0.1#0.1
167 | # decay_factor: 0.8#0.8
168 | # staircase: True
169 | # }
170 | # }
171 | # weight_decay: 1e-6#0.0001
172 | # }
173 | # fixed_weight_decay: false
174 | # use_moving_average: false
175 | #}
176 | # steps: 99040 # 1238 * 120
177 | # s: 49520 # 619 * 80
178 | # steps: 30950 # 619 * 80
179 | # steps_per_eval: 3095 # 619 * 5
180 | steps: 200000#200000#42500#170000#12750#23200*20/8/4 #23200 # 464 * 50
181 | steps_per_eval: 4000#425#1700#850#637#23200/8#2320 # 619 * 5
182 |
183 | # save_checkpoints_secs : 1800 # half hour
184 | # save_summary_steps : 10
185 | enable_mixed_precision: false
186 | loss_scale_factor: -1
187 | clear_metrics_every_epoch: true
188 | }
189 |
190 | eval_input_reader: {
191 | dataset: {
192 | dataset_class_name: "KittiDatasetHDF5"
193 | # dataset_class_name: "KittiDataset"
194 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/eval_data_info4.pkl" #TODO:
195 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5"
196 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5"
197 | # kitti_info_path: "/media/yy/960evo/datasets/kitti/kitti_infos_test.pkl"
198 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset"
199 | seq_length: 2
200 | skip: 1
201 | step:1
202 | random_skip: false
203 | }
204 | batch_size: 1
205 | preprocess: {
206 | max_number_of_voxels: 40000
207 | shuffle_points: false
208 | num_workers: 1
209 | anchor_area_threshold: -1
210 | remove_environment: false
211 | }
212 | }
213 | eval_train_input_reader: {
214 | dataset: {
215 | dataset_class_name: "KittiDatasetHDF5"
216 | # dataset_class_name: "KittiDataset"
217 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/eval_train_data_info.pkl" #TODO:
218 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5"
219 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5"
220 | # kitti_info_path: "/media/yy/960evo/datasets/kitti/kitti_infos_test.pkl"
221 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset"
222 | seq_length: 2
223 | skip:1
224 | step:1
225 | random_skip: false
226 | }
227 | batch_size: 1
228 | preprocess: {
229 | max_number_of_voxels: 40000
230 | shuffle_points: false
231 | num_workers: 1
232 | }
233 | }
--------------------------------------------------------------------------------
/demo/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/demo/framework.png
--------------------------------------------------------------------------------
/demo/output.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/demo/output.gif
--------------------------------------------------------------------------------
/demo/pointcov.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/demo/pointcov.png
--------------------------------------------------------------------------------
/demo/traj+pointcov.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/demo/traj+pointcov.png
--------------------------------------------------------------------------------
/demo/traj.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/demo/traj.png
--------------------------------------------------------------------------------
/doc/train.md:
--------------------------------------------------------------------------------
1 | ## Coming Soon!
--------------------------------------------------------------------------------
/rslo/builder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/rslo/builder/__init__.py
--------------------------------------------------------------------------------
/rslo/builder/dataset_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Input reader builder.
16 |
17 | Creates data sources for DetectionModels from an InputReader config. See
18 | input_reader.proto for options.
19 |
20 | Note: If users wish to also use their own InputReaders with the Object
21 | Detection configuration framework, they should define their own builder function
22 | that wraps the build function.
23 | """
24 |
25 | from rslo.protos import input_reader_pb2
26 | from rslo.data.dataset import get_dataset_class
27 | from rslo.data.preprocess import prep_pointcloud
28 | import numpy as np
29 | from functools import partial
30 | # from rslo.utils.config_tool import get_downsample_factor
31 |
32 |
def build(input_reader_config,
          model_config,
          training,
          voxel_generator,
          multi_gpu=False,
          use_dist=False,
          split=None,
          use_hdf5=False
          ):
    """Build a dataset object from an InputReader config.

    A preprocessing callable is created by binding `prep_pointcloud` to the
    values found in the config, and is handed to the dataset class selected
    by `dataset_class_name`.

    Args:
        input_reader_config: An input_reader_pb2.InputReader message; its
            `preprocess` and `dataset` sub-messages are read here.
        model_config: Model config; `num_point_features` and
            `odom_predictor.pred_pyramid_motion` are read from it.
        training: Forwarded to `prep_pointcloud` as `training`.
        voxel_generator: Forwarded to `prep_pointcloud` as `voxel_generator`.
        multi_gpu: Forwarded to `prep_pointcloud`.
        use_dist: Forwarded to `prep_pointcloud`.
        split: Forwarded to the dataset constructor.
        use_hdf5: Accepted for backward compatibility; not used here.

    Returns:
        An instance of the dataset class resolved by `get_dataset_class`.

    Raises:
        ValueError: If `input_reader_config` is not an InputReader message.
    """
    if not isinstance(input_reader_config, input_reader_pb2.InputReader):
        raise ValueError('input_reader_config not of type '
                         'input_reader_pb2.InputReader.')
    prep_cfg = input_reader_config.preprocess
    dataset_cfg = input_reader_config.dataset
    num_point_features = model_config.num_point_features

    dataset_cls = get_dataset_class(dataset_cfg.dataset_class_name)

    prep_func = partial(
        prep_pointcloud,
        root_path=dataset_cfg.kitti_root_path,
        voxel_generator=voxel_generator,
        training=training,
        max_voxels=prep_cfg.max_number_of_voxels,
        shuffle_points=prep_cfg.shuffle_points,
        # Taken from the model config rather than from the dataset class.
        num_point_features=num_point_features,
        multi_gpu=multi_gpu,
        use_dist=use_dist,
        min_points_in_gt=prep_cfg.min_num_of_points_in_gt,
        random_flip_x=prep_cfg.random_flip_x,
        random_flip_y=prep_cfg.random_flip_y,
        rand_aug_ratio=prep_cfg.random_aug_ratio,
        sample_importance=prep_cfg.sample_importance,
        rand_rotation_eps=prep_cfg.rand_rotation_eps,
        rand_translation_eps=prep_cfg.rand_translation_eps,
        # Driven by a model option, not by the preprocess config.
        gen_tq_map=model_config.odom_predictor.pred_pyramid_motion,
        do_pre_transform=prep_cfg.do_pre_transform,
        cubic_tq_map=prep_cfg.cubic_tq_map,
        downsample_voxel_sizes=list(prep_cfg.downsample_voxel_sizes)
    )

    dataset = dataset_cls(
        info_path=dataset_cfg.kitti_info_path,
        root_path=dataset_cfg.kitti_root_path,
        seq_length=dataset_cfg.seq_length,
        skip=dataset_cfg.skip,
        random_skip=dataset_cfg.random_skip,
        prep_func=prep_func,
        step=dataset_cfg.step,
        num_point_features=num_point_features,
        split=split,
    )

    return dataset
112 |
--------------------------------------------------------------------------------
/rslo/builder/input_reader_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Input reader builder.
16 |
17 | Creates data sources for DetectionModels from an InputReader config. See
18 | input_reader.proto for options.
19 |
20 | Note: If users wish to also use their own InputReaders with the Object
21 | Detection configuration framework, they should define their own builder function
22 | that wraps the build function.
23 | """
24 |
25 | from torch.utils.data import Dataset
26 |
27 | from rslo.builder import dataset_builder
28 | from rslo.protos import input_reader_pb2
29 |
30 |
class DatasetWrapper(Dataset):
    """Adapter exposing a project dataset through the PyTorch `Dataset` API.

    Length queries and item access are delegated unchanged to the wrapped
    object.
    """

    def __init__(self, dataset):
        # Only a reference is stored; nothing is copied or validated.
        self._dataset = dataset

    def __len__(self):
        """Return the length reported by the wrapped dataset."""
        wrapped = self._dataset
        return len(wrapped)

    def __getitem__(self, idx):
        """Return the wrapped dataset's item at `idx`."""
        wrapped = self._dataset
        return wrapped[idx]

    @property
    def dataset(self):
        """The underlying dataset object given to the constructor."""
        return self._dataset
47 |
48 |
def build(input_reader_config,
          model_config,
          training,
          voxel_generator,
          multi_gpu=False,
          use_dist=False,
          split=None,
          ) -> DatasetWrapper:
    """Build the configured dataset and wrap it for use with PyTorch.

    Args:
        input_reader_config: An input_reader_pb2.InputReader message.
        model_config: Forwarded to `dataset_builder.build`.
        training: Forwarded to `dataset_builder.build`.
        voxel_generator: Forwarded to `dataset_builder.build`.
        multi_gpu: Forwarded to `dataset_builder.build`.
        use_dist: Forwarded to `dataset_builder.build`.
        split: Forwarded to `dataset_builder.build`.

    Returns:
        A `DatasetWrapper` around the dataset produced by
        `dataset_builder.build`.

    Raises:
        ValueError: If `input_reader_config` is not an InputReader message.
    """
    config_is_valid = isinstance(
        input_reader_config, input_reader_pb2.InputReader)
    if not config_is_valid:
        raise ValueError('input_reader_config not of type '
                         'input_reader_pb2.InputReader.')

    forwarded_options = dict(multi_gpu=multi_gpu, use_dist=use_dist, split=split)
    inner_dataset = dataset_builder.build(
        input_reader_config,
        model_config,
        training,
        voxel_generator,
        **forwarded_options
    )
    return DatasetWrapper(inner_dataset)
87 |
--------------------------------------------------------------------------------
/rslo/builder/losses_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """A function to build localization and classification losses from config."""
17 |
18 | from rslo.core import losses
19 | # from rslo.core.ghm_loss import GHMCLoss, GHMRLoss
20 | from rslo.protos import losses_pb2
21 |
22 |
def build(loss_config):
    """Construct the loss objects described by `loss_config`.

    The pyramid rotation/translation losses reuse the top-level rotation and
    translation losses when their own `loss_type` is the empty string.

    Args:
        loss_config: A losses_pb2.Loss object.

    Returns:
        A 3-tuple `(rigid_transform_loss, py_rigid_transform_loss,
        consistency_loss)` when `rigid_transform_loss.weight` is non-zero,
        otherwise a 5-tuple `(rotation_loss, translation_loss,
        pyramid_rotation_loss, pyramid_translation_loss, consistency_loss)`.
    """
    rotation_loss_func = _build_rotation_loss(loss_config.rotation_loss)
    translation_loss_func = _build_translation_loss(loss_config.translation_loss)

    # Fall back to the shared loss objects unless a pyramid-specific type is set.
    pyramid_rotation_loss = rotation_loss_func
    if loss_config.pyramid_rotation_loss.loss_type != '':
        pyramid_rotation_loss = _build_rotation_loss(
            loss_config.pyramid_rotation_loss)

    pyramid_translation_loss = translation_loss_func
    if loss_config.pyramid_translation_loss.loss_type != '':
        pyramid_translation_loss = _build_translation_loss(
            loss_config.pyramid_translation_loss)

    consistency_loss = _build_consistency_loss(loss_config.consistency_loss)

    rigid_cfg = loss_config.rigid_transform_loss
    if not (rigid_cfg.weight != 0):
        return (rotation_loss_func, translation_loss_func,
                pyramid_rotation_loss, pyramid_translation_loss,
                consistency_loss)

    rigid_transform_loss = losses.RigidTransformLoss(
        rotation_loss_func, translation_loss_func,
        focal_gamma=rigid_cfg.focal_gamma)
    py_rigid_transform_loss = losses.RigidTransformLoss(
        pyramid_rotation_loss, pyramid_translation_loss,
        focal_gamma=rigid_cfg.focal_gamma)
    return (rigid_transform_loss, py_rigid_transform_loss, consistency_loss)
65 |
66 |
def _build_translation_loss(loss_config):
    """Build the translation loss selected by `loss_config.loss_type`.

    Supported types are 'L2' and 'AdaptiveWeightedL2'.

    Args:
        loss_config: A losses_pb2.TranslationLoss object. For
            'AdaptiveWeightedL2', a non-positive `balance_scale` is replaced
            by 1 on the config object itself.

    Returns:
        The loss object for the configured type.

    Raises:
        ValueError: If `loss_config` has the wrong message type, or its
            `loss_type` is not one of the supported values.
    """
    if not isinstance(loss_config, losses_pb2.TranslationLoss):
        raise ValueError(
            'loss_config not of type losses_pb2.TranslationLoss.')

    loss_type = loss_config.loss_type
    loss_weight = loss_config.weight
    focal_gamma = loss_config.focal_gamma

    if loss_type == 'L2':
        return losses.L2Loss(loss_weight)

    if loss_type == 'AdaptiveWeightedL2':
        # A non-positive scale is treated as "not set" and defaults to 1; the
        # value is written back to the config, as it always has been.
        if loss_config.balance_scale <= 0:
            loss_config.balance_scale = 1
        assert loss_config.balance_scale > 0
        return losses.AdaptiveWeightedL2Loss(
            loss_config.init_alpha,
            learn_alpha=not loss_config.not_learn_alpha,
            loss_weight=loss_weight,
            focal_gamma=focal_gamma,
            balance_scale=loss_config.balance_scale)

    # Name the offending value: the config may be non-empty but misspelled.
    raise ValueError(
        'Unsupported translation loss type: %r.' % (loss_type,))
96 |
97 |
def _build_rotation_loss(loss_config):
    """Build the rotation loss selected by `loss_config.loss_type`.

    Supported types are 'L2', 'AdaptiveWeightedL2' and
    'AdaptiveWeightedL2RMatrixLoss'.

    Args:
        loss_config: A losses_pb2.RotaionLoss object (the message name is
            spelled this way in the generated module). For
            'AdaptiveWeightedL2', a non-positive `balance_scale` is replaced
            by 1 on the config object itself.

    Returns:
        The loss object for the configured type.

    Raises:
        ValueError: If `loss_config` has the wrong message type, or its
            `loss_type` is not one of the supported values.
    """
    if not isinstance(loss_config, losses_pb2.RotaionLoss):
        raise ValueError(
            'loss_config not of type losses_pb2.RotaionLoss.')

    loss_type = loss_config.loss_type
    loss_weight = loss_config.weight
    focal_gamma = loss_config.focal_gamma

    if loss_type == 'L2':
        return losses.L2Loss(loss_weight)

    if loss_type == 'AdaptiveWeightedL2':
        # A non-positive scale is treated as "not set" and defaults to 1; the
        # value is written back to the config, as it always has been.
        if loss_config.balance_scale <= 0:
            loss_config.balance_scale = 1
        assert loss_config.balance_scale > 0
        return losses.AdaptiveWeightedL2Loss(
            loss_config.init_alpha,
            learn_alpha=not loss_config.not_learn_alpha,
            loss_weight=loss_weight,
            focal_gamma=focal_gamma,
            balance_scale=loss_config.balance_scale)

    if loss_type == 'AdaptiveWeightedL2RMatrixLoss':
        return losses.AdaptiveWeightedL2RMatrixLoss(
            loss_config.init_alpha,
            learn_alpha=not loss_config.not_learn_alpha,
            loss_weight=loss_weight,
            focal_gamma=focal_gamma)

    # Name the offending value: the config may be non-empty but misspelled.
    raise ValueError(
        'Unsupported rotation loss type: %r.' % (loss_type,))
130 |
131 |
def _build_consistency_loss(loss_config):
    """Build the consistency loss selected by `loss_config.loss_type`.

    Supported types are 'AdaptiveWeightedL2' and
    'Aleat5_1ChamferL2NormalWeightedALLSVDLoss'. Any other value prints a
    warning and yields no loss.

    Args:
        loss_config: A losses_pb2.ConsistencyLoss object. For
            'AdaptiveWeightedL2', a non-positive `balance_scale` is replaced
            by 1 on the config object itself, matching the rotation and
            translation builders.

    Returns:
        The loss object for the configured type, or None when the type is
        not recognised.

    Raises:
        ValueError: If `loss_config` is not a ConsistencyLoss message.
        AssertionError: If a required positive field of the Chamfer loss
            config is not greater than zero.
    """
    if not isinstance(loss_config, losses_pb2.ConsistencyLoss):
        raise ValueError(
            'loss_config not of type losses_pb2.ConsistencyLoss.')

    loss_type = loss_config.loss_type
    loss_weight = loss_config.weight

    if loss_type == 'AdaptiveWeightedL2':
        # Same fallback as the rotation/translation builders: a non-positive
        # scale means "not set" and defaults to 1.
        if loss_config.balance_scale <= 0:
            loss_config.balance_scale = 1
        return losses.AdaptiveWeightedL2Loss(
            loss_config.init_alpha,
            learn_alpha=not loss_config.not_learn_alpha,
            loss_weight=loss_weight,
            focal_gamma=loss_config.focal_gamma,
            balance_scale=loss_config.balance_scale)
    elif loss_type == 'Aleat5_1ChamferL2NormalWeightedALLSVDLoss':
        assert loss_config.penalize_ratio > 0
        assert loss_config.pred_downsample_ratio > 0
        assert loss_config.reg_weight > 0
        assert loss_config.sph_weight > 0
        return losses.Aleat5_1ChamferL2NormalWeightedALLSVDLoss(
            loss_weight=loss_weight,
            penalize_ratio=loss_config.penalize_ratio,
            sample_block_size=loss_config.sample_block_size,
            norm=loss_config.norm,
            pred_downsample_ratio=loss_config.pred_downsample_ratio,
            reg_weight=loss_config.reg_weight,
            sph_weight=loss_config.sph_weight)
    else:
        # Deliberately non-fatal: a missing consistency loss is tolerated.
        print('Warning: Empty loss config.')
        return None
151 |
152 |
--------------------------------------------------------------------------------
/rslo/builder/lr_scheduler_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Functions to build DetectionModel training optimizers."""
17 |
18 | from torchplus.train import learning_schedules_fastai as lsf
19 | import torch
20 | import numpy as np
21 |
def build(optimizer_config, optimizer, total_step):
    """Create the learning-rate scheduler described by `optimizer_config`.

    Note that the scheduler must be given an optimizer that has already been
    restored.

    Args:
        optimizer_config: An Optimizer proto message; the active member of
            its 'optimizer' oneof supplies the `learning_rate` sub-config.
        optimizer: The optimizer the scheduler will drive.
        total_step: Forwarded to the scheduler as `total_step`.

    Returns:
        The scheduler built by `_create_learning_rate_scheduler`.

    Raises:
        ValueError: If the 'optimizer' oneof is unset or names an optimizer
            this builder does not handle, or if the learning-rate type is
            unsupported.
    """
    optimizer_type = optimizer_config.WhichOneof('optimizer')

    supported_optimizers = (
        'rms_prop_optimizer',
        'momentum_optimizer',
        'adam_optimizer',
    )
    # Fail with the documented ValueError instead of falling through to an
    # unbound local when the oneof is unset or unknown.
    if optimizer_type not in supported_optimizers:
        raise ValueError('Optimizer %s not supported.' % optimizer_type)

    # All supported optimizers carry their schedule in `learning_rate`.
    config = getattr(optimizer_config, optimizer_type)
    lr_scheduler = _create_learning_rate_scheduler(
        config.learning_rate, optimizer, total_step=total_step)

    return lr_scheduler
53 |
54 |
55 | def _create_learning_rate_scheduler(learning_rate_config, optimizer, total_step):
56 | """Create optimizer learning rate scheduler based on config.
57 |
58 | Args:
59 | learning_rate_config: A LearningRate proto message.
60 |
61 | Returns:
62 | A learning rate.
63 |
64 | Raises:
65 | ValueError: when using an unsupported input data type.
66 | """
67 | lr_scheduler = None
68 | learning_rate_type = learning_rate_config.WhichOneof('learning_rate')
69 | if learning_rate_type == 'multi_phase':
70 | config = learning_rate_config.multi_phase
71 | lr_phases = []
72 | mom_phases = []
73 | for phase_cfg in config.phases:
74 | lr_phases.append((phase_cfg.start, phase_cfg.lambda_func))
75 | mom_phases.append(
76 | (phase_cfg.start, phase_cfg.momentum_lambda_func))
77 | lr_scheduler = lsf.LRSchedulerStep(
78 | optimizer, total_step, lr_phases, mom_phases)
79 |
80 |
81 |
82 | if learning_rate_type == 'one_cycle':
83 | config = learning_rate_config.one_cycle
84 |
85 | if len(config.lr_maxs)>1:
86 | assert(len(config.lr_maxs)==4 )
87 | lr_max=[]
88 | # for i in range(len(config.lr_maxs)):
89 | # lr_max += [config.lr_maxs[i]]*optimizer.param_segs[i]
90 |
91 | lr_max = np.array(list(config.lr_maxs) )
92 | else:
93 | lr_max = config.lr_max
94 |
95 | lr_scheduler = lsf.OneCycle(
96 | optimizer, total_step, lr_max, list(config.moms), config.div_factor, config.pct_start)
97 | if learning_rate_type == 'exponential_decay':
98 | config = learning_rate_config.exponential_decay
99 | lr_scheduler = lsf.ExponentialDecay(
100 | optimizer, total_step, config.initial_learning_rate, config.decay_length, config.decay_factor, config.staircase)
101 | if learning_rate_type == 'exponential_decay_warmup':
102 | config = learning_rate_config.exponential_decay_warmup
103 | lr_scheduler = lsf.ExponentialDecayWarmup(
104 | optimizer, total_step, config.initial_learning_rate, config.decay_length, config.decay_factor, config.div_factor,
105 | config.pct_start, config.staircase)
106 | if learning_rate_type == 'manual_stepping':
107 | config = learning_rate_config.manual_stepping
108 | lr_scheduler = lsf.ManualStepping(
109 | optimizer, total_step, list(config.boundaries), list(config.rates))
110 |
111 | if lr_scheduler is None:
112 | raise ValueError('Learning_rate %s not supported.' %
113 | learning_rate_type)
114 |
115 | return lr_scheduler
116 |
--------------------------------------------------------------------------------
/rslo/builder/optimizer_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Functions to build DetectionModel training optimizers."""
16 |
17 | from torchplus.train import learning_schedules
18 | from torchplus.train import optim
19 | import torch
20 | from torch import nn
21 | from torchplus.train.fastai_optim import OptimWrapper, FastAIMixedOptim
22 | from functools import partial
23 |
24 |
def children(m: nn.Module):
    "Return the direct child modules of `m` as a list."
    return [child for child in m.children()]
28 |
29 |
def num_children(m: nn.Module) -> int:
    "Return how many direct child modules `m` has."
    direct_children = children(m)
    return len(direct_children)
33 |
34 | # return a list of smallest modules dy
35 |
36 |
def flatten_model(m):
    """Collect the modules of `m` that have no children, depth first.

    Returns an empty list for None, `[m]` when `m` itself has no children,
    and otherwise the concatenation of the results for each direct child.
    """
    if m is None:
        return []
    if not num_children(m):
        return [m]

    leaves = []
    for sub_module in m.children():
        leaves = leaves + flatten_model(sub_module)
    return leaves
42 |
43 |
44 | # def get_layer_groups(m): return [nn.Sequential(*flatten_model(m))]
def get_layer_groups(m):
    """Return a one-element list holding an `nn.ModuleList` of `m`'s flattened modules."""
    flattened = flatten_model(m)
    return [nn.ModuleList(flattened)]
46 |
47 |
def get_voxeLO_net_layer_groups(net):
    """Split `net` into four layer groups for the optimizer.

    The groups are, in order: the voxel feature extractor, the middle feature
    extractor, the odometry predictor, and the five loss modules taken
    together. Each entry is whatever `get_layer_groups` returns for that part.
    """
    network_parts = (
        net.voxel_feature_extractor,
        net.middle_feature_extractor,
        net.odom_predictor,
    )
    groups = [get_layer_groups(part) for part in network_parts]

    # The loss modules are bundled so they end up in one shared group.
    loss_modules = nn.Sequential(
        net._rotation_loss,
        net._translation_loss,
        net._pyramid_rotation_loss,
        net._pyramid_translation_loss,
        net._consistency_loss,
    )
    groups.append(get_layer_groups(loss_modules))

    return groups
66 |
67 |
def build(optimizer_config, net, name=None, mixed=False, loss_scale=512.0):
    """Create the optimizer described by `optimizer_config` for `net`.

    Args:
        optimizer_config: An Optimizer proto message; the active member of
            its 'optimizer' oneof selects the torch optimizer class.
        net: The network whose layer groups (from
            `get_voxeLO_net_layer_groups`) are handed to the optimizer.
        name: Optional name stored on the optimizer; defaults to the
            optimizer type string.
        mixed: Accepted for backward compatibility; not used here.
        loss_scale: Accepted for backward compatibility; not used here.

    Returns:
        The `OptimWrapper` instance with its `name` attribute set.

    Raises:
        ValueError: If the optimizer type is unset or unsupported, or if
            `use_moving_average` is requested.
    """
    optimizer_type = optimizer_config.WhichOneof('optimizer')
    optimizer_func = None
    config = None

    if optimizer_type == 'rms_prop_optimizer':
        config = optimizer_config.rms_prop_optimizer
        optimizer_func = partial(
            torch.optim.RMSprop,
            alpha=config.decay,
            momentum=config.momentum_optimizer_value,
            eps=config.epsilon)
    elif optimizer_type == 'momentum_optimizer':
        config = optimizer_config.momentum_optimizer
        # torch.optim.SGD takes no `eps` argument; passing one made this
        # branch fail as soon as the optimizer was constructed.
        optimizer_func = partial(
            torch.optim.SGD,
            momentum=config.momentum_optimizer_value)
    elif optimizer_type == 'adam_optimizer':
        config = optimizer_config.adam_optimizer
        if optimizer_config.fixed_weight_decay:
            optimizer_func = partial(
                torch.optim.Adam, betas=(0.9, 0.99), amsgrad=config.amsgrad)
        else:
            # regular adam
            optimizer_func = partial(
                torch.optim.Adam, amsgrad=config.amsgrad)

    # Reject unknown types before `optimizer_func`/`config` are used, so the
    # caller gets the intended ValueError rather than an unbound-name error.
    if optimizer_func is None:
        raise ValueError('Optimizer %s not supported.' % optimizer_type)

    optimizer = OptimWrapper.create(
        optimizer_func,
        3e-3,
        get_voxeLO_net_layer_groups(net),
        wd=config.weight_decay,
        true_wd=optimizer_config.fixed_weight_decay,
        bn_wd=True)
    print(hasattr(optimizer, "_amp_stash"), '_amp_stash')

    if optimizer_config.use_moving_average:
        raise ValueError('torch don\'t support moving average')
    if name is None:
        # assign a name to optimizer for checkpoint system
        optimizer.name = optimizer_type
    else:
        optimizer.name = name
    return optimizer
129 |
--------------------------------------------------------------------------------
/rslo/builder/preprocess_builder.py:
--------------------------------------------------------------------------------
1 | import second.core.preprocess as prep
2 |
def build_db_preprocess(db_prep_config):
    """Create the database preprocessing step selected in `db_prep_config`.

    The active member of the 'database_preprocessing_step' oneof decides
    which filter object is constructed.

    Raises:
        ValueError: If the oneof is unset or names an unhandled step.
    """
    step_name = db_prep_config.WhichOneof('database_preprocessing_step')

    if step_name == 'filter_by_difficulty':
        difficulty_cfg = db_prep_config.filter_by_difficulty
        removed = list(difficulty_cfg.removed_difficulties)
        return prep.DBFilterByDifficulty(removed)

    if step_name == 'filter_by_min_num_points':
        min_points_cfg = db_prep_config.filter_by_min_num_points
        pairs = dict(min_points_cfg.min_num_point_pairs)
        return prep.DBFilterByMinNumPoint(pairs)

    raise ValueError("unknown database prep type")
14 |
15 |
--------------------------------------------------------------------------------
/rslo/builder/second_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 yanyan. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """VoxelNet builder.
16 | """
17 |
18 | from rslo.protos import second_pb2
19 | from rslo.builder import losses_builder
20 | from rslo.models.voxel_odom_net import get_voxelnet_class
21 | import rslo.models.voxel_odom_net
22 | # import rslo.models.voxel_odom_net_self_icp_enc
23 | # import rslo.models.voxel_odom_net_self_icp3_enc
24 | # import rslo.models.voxel_odom_net_self_icp_enc2
25 |
def build(model_cfg: second_pb2.VoxelNet, voxel_generator,
          measure_time=False, testing=False):
    """Instantiate the odometry network described by ``model_cfg``.

    Args:
        model_cfg: a ``second_pb2.VoxelNet`` message.
        voxel_generator: object exposing ``grid_size`` (a numpy array, since
            ``[::-1].tolist()`` is applied to it) and ``point_cloud_range``.
        measure_time: forwarded unchanged to the network constructor.
        testing: forwarded unchanged to the network constructor.

    Returns:
        The network object created by the class registered under
        ``model_cfg.network_class_name``.

    Raises:
        ValueError: if ``model_cfg`` is not a ``second_pb2.VoxelNet``.
    """
    if not isinstance(model_cfg, second_pb2.VoxelNet):
        raise ValueError('model_cfg not of type ' 'second_pb2.VoxelNet.')

    # Short aliases for the three sub-configurations read below.
    vfe_cfg = model_cfg.voxel_feature_extractor
    middle_cfg = model_cfg.middle_feature_extractor
    odom_cfg = model_cfg.odom_predictor

    vfe_num_filters = list(vfe_cfg.num_filters)
    vfe_with_distance = vfe_cfg.with_distance

    # Reversed grid size framed by a leading 1 and the last VFE filter count
    # (the original comment gives the layout as [1, z, y, x, C]).
    reversed_grid = voxel_generator.grid_size[::-1].tolist()
    dense_shape = [1] + reversed_grid + [vfe_num_filters[-1]]
    pc_range = voxel_generator.point_cloud_range
    print(dense_shape, '!!!', flush=True)

    num_input_features = model_cfg.num_point_features

    # The loss builder returns either 3 losses (rigid-transform formulation)
    # or 5 losses (separate rotation / translation formulation).
    losses = losses_builder.build(model_cfg.loss)
    if len(losses) == 3:
        rigid_loss, py_rigid_loss, consistency_loss = losses
        rot_loss = trans_loss = py_rot_loss = py_trans_loss = None
    else:
        rot_loss, trans_loss, py_rot_loss, py_trans_loss, consistency_loss = losses
        rigid_loss = py_rigid_loss = None

    net_cls = get_voxelnet_class(model_cfg.network_class_name)
    return net_cls(
        dense_shape,
        pc_range=pc_range,
        vfe_class_name=vfe_cfg.module_class_name,
        vfe_num_filters=vfe_num_filters,
        middle_class_name=middle_cfg.module_class_name,
        middle_num_input_features=middle_cfg.num_input_features,
        middle_num_filters_d1=list(middle_cfg.num_filters_down1),
        middle_num_filters_d2=list(middle_cfg.num_filters_down2),
        middle_use_leakyReLU=middle_cfg.use_leakyReLU,
        middle_relu_type=middle_cfg.relu_type,
        odom_class_name=odom_cfg.module_class_name,
        odom_num_input_features=odom_cfg.num_input_features,
        odom_layer_nums=odom_cfg.layer_nums,
        odom_layer_strides=odom_cfg.layer_strides,
        odom_num_filters=odom_cfg.num_filters,
        odom_upsample_strides=odom_cfg.upsample_strides,
        odom_num_upsample_filters=odom_cfg.num_upsample_filters,
        odom_pooling_size=odom_cfg.pool_size,
        odom_pooling_type=odom_cfg.pool_type,
        odom_cycle_constraint=odom_cfg.cycle_constraint,
        odom_conv_type=odom_cfg.conv_type,
        odom_format=odom_cfg.odom_format,
        odom_pred_pyramid_motion=odom_cfg.pred_pyramid_motion,
        odom_use_deep_supervision=odom_cfg.use_deep_supervision,
        # The proto stores negated flags; invert them for the constructor.
        odom_use_loss_mask=not odom_cfg.not_use_loss_mask,
        odom_use_dynamic_mask=odom_cfg.use_dynamic_mask,
        odom_dense_predict=odom_cfg.dense_predict,
        odom_use_corr=odom_cfg.use_corr,
        odom_dropout=odom_cfg.dropout,
        odom_conf_type=odom_cfg.conf_type,
        odom_use_SPGN=odom_cfg.use_SPGN,
        odom_use_leakyReLU=odom_cfg.use_leakyReLU,
        vfe_use_norm=not vfe_cfg.not_use_norm,
        middle_bn_type=middle_cfg.bn_type,
        odom_bn_type=odom_cfg.bn_type,
        odom_enc_use_norm=not odom_cfg.not_use_enc_norm,
        odom_dropout_input=odom_cfg.dropout_input,
        # A group count below 1 in the config is raised to 1.
        odom_first_conv_groups=max(1, odom_cfg.first_conv_groups),
        odom_use_se=odom_cfg.odom_use_se,
        odom_use_sa=odom_cfg.odom_use_sa,
        odom_use_svd=odom_cfg.use_svd,
        odom_cubic_pred_height=odom_cfg.cubic_pred_height,
        freeze_bn=model_cfg.freeze_bn,
        freeze_bn_affine=model_cfg.freeze_bn_affine,
        freeze_bn_start_step=model_cfg.freeze_bn_start_step,
        use_GN=model_cfg.use_GN,
        num_input_features=num_input_features,
        encode_background_as_zeros=model_cfg.encode_background_as_zeros,
        with_distance=vfe_with_distance,
        rotation_loss=rot_loss,
        translation_loss=trans_loss,
        pyramid_rotation_loss=py_rot_loss,
        pyramid_translation_loss=py_trans_loss,
        rigid_transform_loss=rigid_loss,
        pyramid_rigid_transform_loss=py_rigid_loss,
        consistency_loss=consistency_loss,
        measure_time=measure_time,
        voxel_generator=voxel_generator,
        pyloss_exp_w_base=model_cfg.loss.pyloss_exp_w_base,
        testing=testing,
        icp_iter=model_cfg.icp_iter,
    )
--------------------------------------------------------------------------------
/rslo/builder/voxel_builder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # from spconv.utils import VoxelGeneratorV2, VoxelGenerator
4 | from spconv.utils import VoxelGenerator
5 | from rslo.protos import voxel_generator_pb2
6 |
7 |
8 | # def build(voxel_config):
9 | # """Builds a tensor dictionary based on the InputReader config.
10 |
11 | # Args:
12 | # input_reader_config: A input_reader_pb2.InputReader object.
13 |
14 | # Returns:
15 | # A tensor dict based on the input_reader_config.
16 |
17 | # Raises:
18 | # ValueError: On invalid input reader proto.
19 | # ValueError: If no input paths are specified.
20 | # """
21 | # if not isinstance(voxel_config, (voxel_generator_pb2.VoxelGenerator)):
22 | # raise ValueError('input_reader_config not of type '
23 | # 'input_reader_pb2.InputReader.')
24 | # voxel_generator = VoxelGeneratorV2(
25 | # voxel_size=list(voxel_config.voxel_size),
26 | # point_cloud_range=list(voxel_config.point_cloud_range),
27 | # max_num_points=voxel_config.max_number_of_points_per_voxel,
28 | # max_voxels=20000,
29 | # full_mean=voxel_config.full_empty_part_with_mean,
30 | # block_filtering=voxel_config.block_filtering,
31 | # block_factor=voxel_config.block_factor,
32 | # block_size=voxel_config.block_size,
33 | # height_threshold=voxel_config.height_threshold)
34 | # return voxel_generator
35 |
class _VoxelGenerator(VoxelGenerator):
    """``VoxelGenerator`` variant that returns a dict from ``generate``.

    It also exposes ``grid_size`` recomputed from the point-cloud range and
    the voxel size.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are handed to the base class untouched.
        super().__init__(*args, **kwargs)

    @property
    def grid_size(self):
        """Number of voxels per axis as an ``int64`` numpy array."""
        bounds = np.array(self.point_cloud_range)
        cell = np.array(self.voxel_size)
        lower, upper = bounds[:3], bounds[3:]
        # Round to the nearest integer before casting.
        return np.round((upper - lower) / cell).astype(np.int64)

    def generate(self, points, max_voxels=None):
        """Voxelise ``points`` and label the first three results by name.

        Returns:
            dict with keys ``"voxels"``, ``"coordinates"`` and
            ``"num_points_per_voxel"`` holding elements 0, 1 and 2 of the
            base-class result.
        """
        raw = super().generate(points, max_voxels)
        field_names = ("voxels", "coordinates", "num_points_per_voxel")
        return {field: raw[pos] for pos, field in enumerate(field_names)}
55 |
56 |
def build(voxel_config):
    """Build a voxel generator from a ``VoxelGenerator`` proto config.

    ``voxel_config`` is modified in place before the generator is created:
    block filtering is always switched on, ``block_factor`` is raised to at
    least 1, a non-positive ``block_size`` becomes 8 and a zero
    ``height_threshold`` becomes 0.2.

    Args:
        voxel_config: a ``voxel_generator_pb2.VoxelGenerator`` message.

    Returns:
        A ``_VoxelGenerator`` created with ``max_voxels=20000`` and
        ``full_mean=False``.

    Raises:
        ValueError: if ``voxel_config`` is not a
            ``voxel_generator_pb2.VoxelGenerator``.
    """
    if not isinstance(voxel_config, voxel_generator_pb2.VoxelGenerator):
        raise ValueError('voxel_config not of type '
                         'voxel_generator_pb2.VoxelGenerator.')

    # Block filtering is forced on regardless of what the config requested,
    # so its companion settings always need usable values.
    voxel_config.block_filtering = True
    voxel_config.block_factor = max(1, voxel_config.block_factor)
    if voxel_config.block_size <= 0:
        voxel_config.block_size = 8
    if voxel_config.height_threshold == 0:
        voxel_config.height_threshold = 0.2

    return _VoxelGenerator(
        voxel_size=list(voxel_config.voxel_size),
        point_cloud_range=list(voxel_config.point_cloud_range),
        max_num_points=voxel_config.max_number_of_points_per_voxel,
        max_voxels=20000,
        # ``full_empty_part_with_mean`` from the config is deliberately not used.
        full_mean=False,
        block_filtering=voxel_config.block_filtering,
        block_factor=voxel_config.block_factor,
        block_size=voxel_config.block_size,
        height_threshold=voxel_config.height_threshold,
    )
96 |
--------------------------------------------------------------------------------
/rslo/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/rslo/core/__init__.py
--------------------------------------------------------------------------------
/rslo/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import kitti_dataset_hdf5
2 | from . import kitti_dataset_crossnorm_hdf5
3 | # from . import kitti_dataset_tmp
4 |
--------------------------------------------------------------------------------
/rslo/data/dataset.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import pickle
3 | import time
4 | from functools import partial
5 |
6 | import numpy as np
7 | import torch
8 | import rslo.utils.pose_utils as tch_p
9 |
10 | # from rslo.core import box_np_ops
11 | # from rslo.core import preprocess as prep
12 | from rslo.data import kitti_common as kitti
13 |
# Module-wide registry: dataset name -> dataset class.
REGISTERED_DATASET_CLASSES = {}


def register_dataset(cls, name=None):
    """Add ``cls`` to the dataset registry and return it.

    Because the class is returned unchanged, the function can be used as a
    class decorator. The registry key is ``name`` when given, otherwise
    ``cls.__name__``. Registering a key twice trips an assertion.
    """
    key = cls.__name__ if name is None else name
    assert key not in REGISTERED_DATASET_CLASSES, f"exist class: {REGISTERED_DATASET_CLASSES}"
    REGISTERED_DATASET_CLASSES[key] = cls
    return cls


def get_dataset_class(name):
    """Return the class registered under ``name``; assert that it exists."""
    assert name in REGISTERED_DATASET_CLASSES, f"available class: {REGISTERED_DATASET_CLASSES}"
    return REGISTERED_DATASET_CLASSES[name]
30 |
31 |
class Dataset:
    """Abstract dataset interface; every method must be overridden."""

    # Placeholder value; concrete datasets are expected to override it.
    NumPointFeatures = -1

    def __getitem__(self, index):
        """Return the example at ``index`` (not implemented here)."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of examples (not implemented here)."""
        raise NotImplementedError

    def get_sensor_data(self, query):
        """Return sensor data for ``query`` (not implemented here)."""
        raise NotImplementedError

    def evaluation(self, dt_annos, output_dir):
        """Dataset must provide a evaluation function to evaluate model."""
        raise NotImplementedError
49 |
50 |
51 |
def generate_pointwise_local_transformation_tch(tq, spatial_size, origin_loc, voxel_size, inv_trans_factor=-1, ):
    """Expand one global pose into a per-cell map of local poses.

    Grid cell ``(i, j, k)`` (row, column, depth) is placed at

        x = (j - origin_loc[0]) * voxel_size[0]
        y = (origin_loc[1] - i) * voxel_size[1]
        z = (k - origin_loc[2]) * voxel_size[2]

    so x grows with the column index and y shrinks with the row index.

    Args:
        tq: 1-D torch tensor; ``tq[:3]`` is the translation and ``tq[3:]``
            the quaternion (it is broadcast against a 4-channel map, so a
            7-vector is expected).
        spatial_size: ``(size_x, size_y)`` or ``(size_x, size_y, size_z)``;
            with two entries ``size_z`` is 1.
        origin_loc: indexable with three entries, the grid index of the origin.
        voxel_size: indexable with three entries, the cell size per axis.
        inv_trans_factor: when > 0, the x/y cell coordinates are rescaled by
            ``inv_trans_factor / (||xy|| + 0.1) ** 2``.

    Returns:
        Tensor laid out as ``(7, size_z, size_y, size_x)`` and then passed
        through ``squeeze()``, so every singleton axis is dropped.

    Raises:
        ValueError: if ``spatial_size`` has neither 2 nor 3 entries.
    """
    device = tq.device
    dtype = tq.dtype
    t_g = tq[:3]
    q_g = tq[3:]

    if len(spatial_size) == 2:
        size_x, size_y = spatial_size
        size_z = 1
    elif len(spatial_size) == 3:
        size_x, size_y, size_z = spatial_size
    else:
        raise ValueError(
            f"spatial_size must have 2 or 3 entries, got {len(spatial_size)}")

    # Index grids of shape (size_y, size_x, size_z).
    iv, jv, kv = torch.meshgrid(
        torch.arange(size_y, device=device),
        torch.arange(size_x, device=device),
        torch.arange(size_z, device=device))

    xv = (jv - origin_loc[0]) * voxel_size[0]
    yv = (-iv + origin_loc[1]) * voxel_size[1]
    zv = (kv - origin_loc[2]) * voxel_size[2]

    # N x 3 cell coordinates, flattened in (y, x, z) order.
    xyzv = torch.stack([xv, yv, zv], dim=-1).reshape([-1, 3]).to(dtype=dtype)

    if inv_trans_factor > 0:
        # torch.norm keeps the computation on the tensor's device and matches
        # from_pointwise_local_transformation_tch.
        xyzv[:, :2] = inv_trans_factor / \
            (torch.norm(xyzv[:, :2], dim=1, keepdim=True) + 0.1) ** 2 * xyzv[:, :2]

    # Rotate the offset between the global translation and each cell by the
    # inverse quaternion, then shift back by the cell position.
    t_l = tch_p.rotate_vec_by_q(
        t=t_g[None, ...] - xyzv,
        q=tch_p.qinv(q_g[None, ...]).repeat([xyzv.shape[0], 1])) + xyzv

    t_map = t_l.reshape([size_y, size_x, size_z, 3])
    # Every cell carries the same (global) quaternion.
    q_map = torch.ones([size_y, size_x, size_z, 4], dtype=dtype, device=device) * q_g

    # (y, x, z, c) -> (c, z, y, x), then drop singleton axes.
    tq_map = torch.cat([t_map, q_map], dim=-1).permute(3, 2, 0, 1).squeeze()
    return tq_map
117 |
118 |
119 |
120 |
def from_pointwise_local_transformation_tch(tq_map, pc_range, inv_trans_factor=-1):
    """Convert a map of per-cell local poses into per-cell global poses.

    Args:
        tq_map: 4-D tensor ``(batch, 7, size_y, size_x)``; channels 0-2 are
            the local translation, channels 3-6 the quaternion.
        pc_range: numpy array with six entries; the code reads
            ``[0:3]`` as lower and ``[3:6]`` as upper bounds.
        inv_trans_factor: when > 0, the x/y cell coordinates are rescaled by
            ``inv_trans_factor / (||xy|| + 0.1) ** 2``.

    Returns:
        Contiguous tensor ``(batch, 7, size_y, size_x)`` holding the global
        translation followed by the L2-normalised quaternion.
    """
    device = tq_map.device
    dtype = tq_map.dtype
    in_shape = tq_map.shape

    spatial_size = in_shape[2:]
    if len(spatial_size) == 2:
        # Treat a 2-D map as a volume with one z slice.
        spatial_size = [1] + list(spatial_size)

    # Reverse (z, y, x) into (x, y, z) to line up with pc_range.
    grid_size = torch.from_numpy(
        np.array(list(spatial_size[::-1]))).to(device=device, dtype=dtype)
    pc_range = torch.from_numpy(pc_range).to(device=device, dtype=dtype)
    voxel_size = (pc_range[3:] - pc_range[:3]) / grid_size

    # Grid index of the metric origin; x and z are measured from the lower
    # bound, y from the upper bound (the y axis is flipped).
    origin_loc = (
        (0 - pc_range[0]) / (pc_range[3] - pc_range[0]) * (grid_size[0]),
        (pc_range[4] - 0) / (pc_range[4] - pc_range[1]) * (grid_size[1]),
        (0 - pc_range[2]) / (pc_range[5] - pc_range[2]) * grid_size[2],
    )

    assert len(in_shape) == 4

    # (b, c, y, x) -> (b, y, x, c) -> rows of 7 values.
    flat_tq = tq_map.permute([0, 2, 3, 1]).contiguous().view(-1, 7)
    t_l = flat_tq[:, :3]
    q_l = flat_tq[:, 3:]

    size_z, size_y, size_x = spatial_size
    row_idx, col_idx, depth_idx = torch.meshgrid([
        torch.arange(size_y, dtype=dtype, device=device),
        torch.arange(size_x, dtype=dtype, device=device),
        torch.arange(size_z, dtype=dtype, device=device)])

    xv = (col_idx - origin_loc[0]) * voxel_size[0]
    yv = (-row_idx + origin_loc[1]) * voxel_size[1]
    zv = (depth_idx - origin_loc[2]) * voxel_size[2]

    # N x 3 cell coordinates for one sample.
    xyzv = torch.stack([xv, yv, zv], dim=-1).reshape([-1, 3])
    if inv_trans_factor > 0:
        xy_norm = torch.norm(xyzv[:, :2], dim=1, keepdim=True)
        xyzv[:, :2] = inv_trans_factor / (xy_norm + 0.1) ** 2 * xyzv[:, :2]

    # Repeat the coordinates for every sample in the batch.
    xyzv = torch.cat([xyzv] * in_shape[0], dim=0)

    t_g = tch_p.rotate_vec_by_q(t=(t_l - xyzv), q=q_l) + xyzv

    out_spatial = (in_shape[0], in_shape[2], in_shape[3])
    t_map_g = t_g.view(*out_spatial, 3)
    q_map_g = torch.nn.functional.normalize(q_l.view(*out_spatial, 4), dim=-1)

    # (b, y, x, c) -> (b, c, y, x)
    return torch.cat([t_map_g, q_map_g], dim=-1).permute([0, 3, 1, 2]).contiguous()
--------------------------------------------------------------------------------
/rslo/layers/MaskConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more modules.

    Args:
        nets: a module or a ``list`` of modules; ``None`` entries are skipped.
            Anything that is not a ``list`` is treated as a single module.
        requires_grad: value written to each parameter's ``requires_grad``.
    """
    modules = nets if isinstance(nets, list) else [nets]
    for module in modules:
        if module is None:
            continue
        for weight in module.parameters():
            weight.requires_grad = requires_grad
12 |
class MaskMaxPool2d(nn.MaxPool2d):
    """Max pooling that also accepts a ``(tensor, mask)`` pair.

    For a list/tuple input only element 0 is pooled; element 1 is returned
    unchanged alongside it as a tuple. Any other input is pooled directly.
    """

    def forward(self, x):
        if not isinstance(x, (list, tuple)):
            return super().forward(x)
        pooled = super().forward(x[0])
        return pooled, x[1]
19 |
class MaskConv(nn.Module):
    """Bias-free 2-D convolution that propagates a validity mask.

    The mask goes through either a max-pool (``max_pool_mask=True``) or an
    all-ones frozen convolution with the same kernel, stride and padding as
    the feature convolution. In the convolution case the result is divided
    by its per-sample maximum.

    ``bias`` is only recorded in ``use_bias``; the feature convolution is
    always built with ``bias=False``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, max_pool_mask=True, groups=1):
        super().__init__()

        self.out_channels = out_channels
        self.use_bias = bias
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, bias=False,
            padding=padding, groups=groups)
        self.max_pool_mask = max_pool_mask

        if max_pool_mask:
            self.mask_pool = nn.MaxPool2d(
                kernel_size, stride=stride, padding=padding)
            self.normalize_const = 1
        else:
            # Frozen all-ones kernel: sums the mask values inside each window.
            self.mask_pool = nn.Conv2d(
                1, 1, kernel_size, stride, bias=False, padding=padding)
            set_requires_grad(self.mask_pool, requires_grad=False)
            nn.init.constant_(self.mask_pool.weight, 1)
            print("conv_mask")

    def _conv(self, tensor, binary_mask):
        """Convolve ``tensor`` and propagate ``binary_mask``.

        Returns:
            ``(features, mask)``; the mask is detached in place.
        """
        features = self.conv1(tensor)
        mask = self.mask_pool(binary_mask)
        if not self.max_pool_mask:
            # Per-sample maximum, reshaped so it broadcasts over (C, H, W).
            batch = mask.shape[0]
            peak, _ = torch.max(mask.view(batch, -1), dim=-1, keepdim=True)
            self.normalize_const = peak.view(batch, 1, 1, 1)

        mask /= self.normalize_const
        return features, mask.detach_()

    def forward(self, x):
        """Accept a tensor or a ``[tensor, mask]`` pair; return ``[tensor, mask]``.

        Without an explicit mask, a position counts as valid when the sum of
        absolute values over its channels is non-zero.
        """
        if isinstance(x, (list, tuple)):
            tensor, binary_mask = x
        else:
            tensor = x
            binary_mask = (torch.sum(x.abs(), dim=1, keepdim=True) != 0).float().detach()

        tensor, binary_mask = self._conv(tensor, binary_mask)
        return [tensor, binary_mask]
74 |
75 |
class MaskConvTranspose2d(nn.ConvTranspose2d):
    """Transposed convolution that carries a validity mask alongside.

    The constructor arguments are those of ``nn.ConvTranspose2d`` and are
    all forwarded to it.

    The mask is returned detached but otherwise unchanged: it is NOT
    resampled, so with ``stride > 1`` its spatial size differs from that of
    the output tensor.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True, dilation=1):
        # Forward the caller's settings; hard-coded defaults here would
        # silently discard them.
        super(MaskConvTranspose2d, self).__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, output_padding=output_padding,
            groups=groups, bias=bias, dilation=dilation)

    def forward(self, x):
        """Accept a tensor or a ``[tensor, mask]`` pair; return ``(tensor, mask)``.

        Without an explicit mask, a position counts as valid when the sum of
        absolute values over its channels is non-zero (same rule as
        ``MaskConv``), so channels of opposite sign cannot cancel out.
        """
        if not isinstance(x, (list, tuple)):
            x = [x, (torch.sum(x.abs(), dim=1, keepdim=True) != 0).float().detach()]
        tensor, binary_mask = x
        return super(MaskConvTranspose2d, self).forward(tensor), binary_mask.detach()
87 |
--------------------------------------------------------------------------------
/rslo/layers/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
def set_requires_grad(nets, requires_grad=False):
    """Switch gradient tracking for all parameters of the given module(s).

    ``nets`` is either a ``list`` of modules or a single module; any
    non-list value is wrapped in a one-element list. ``None`` entries are
    ignored.
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in (candidate for candidate in nets if candidate is not None):
        for param in net.parameters():
            param.requires_grad = requires_grad
12 |
class ExpandDims(nn.Module):
    """Reshape a tensor by replacing one axis with several axes.

    Axis ``dim`` is removed from the shape and the sizes in ``insert_dims``
    are spliced in at that position.
    """

    def __init__(self, insert_dims, dim):
        super().__init__()
        self.insert_dims = insert_dims
        self.dim = dim

    def _new_shape(self, old_shape):
        """Return ``old_shape`` as a list with axis ``dim`` replaced."""
        shape = list(old_shape)
        shape.pop(self.dim)
        shape[self.dim:self.dim] = self.insert_dims
        return shape

    def forward(self, x):
        return x.reshape(self._new_shape(x.shape))
27 |
class ELUPlus(nn.ELU):
    """ELU shifted up by ``1 + 1e-6``."""

    def forward(self, x):
        activated = super().forward(x)
        return activated + 1 + 1e-6
31 |
32 |
class EXP(nn.Module):
    """Exponential activation with a softly bounded exponent.

    The input is squashed with ``tanh`` so its magnitude stays below
    ``log(truncated_threshold)`` before ``exp`` is applied; ``1e-6`` is then
    added to the result.
    """

    def __init__(self, truncated_threshold=1e20):
        super().__init__()
        self.truncated_threshold = truncated_threshold
        # Largest exponent magnitude that can reach torch.exp.
        self.truncated_point = np.log(self.truncated_threshold)

    def forward(self, x):
        limit = self.truncated_point
        bounded = torch.tanh(x / limit) * limit
        return torch.exp(bounded) + 1e-6
43 |
44 |
class ParameterLayer(nn.Module):
    """Module wrapping a single parameter; calling it returns that parameter.

    Args:
        init_value: tensor used to create the parameter.
        requires_grad: whether the parameter is trainable.
    """

    def __init__(self, init_value, requires_grad=True):
        super().__init__()
        self.param = nn.Parameter(init_value, requires_grad=requires_grad)

    def forward(self, *input):
        # Any arguments are accepted and ignored.
        return self.param
53 |
54 |
class MaskPropagator(nn.Module):
    """Propagate a (soft) validity mask through one pooling step.

    With ``max_pool_mask=True`` the mask is max-pooled. Otherwise it is
    summed over each window with a frozen all-ones convolution and divided
    by the number of non-zero entries in that window, i.e. averaged over the
    non-empty positions. Windows without any non-zero entry yield 0.

    ``in_channels``, ``out_channels``, ``bias`` and ``groups`` are accepted
    but unused; the mask is always treated as single-channel.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, max_pool_mask=True, groups=1):
        super(MaskPropagator, self).__init__()
        self.max_pool_mask = max_pool_mask
        if max_pool_mask:
            self.mask_pool = nn.MaxPool2d(
                kernel_size, stride=stride, padding=padding)
        else:
            self.mask_pool = nn.Conv2d(1, 1,
                                       kernel_size, stride, bias=False, padding=padding)
            # The kernel is a fixed window sum: freeze it and fill it with ones.
            self.mask_pool.weight.requires_grad_(False)
            nn.init.constant_(self.mask_pool.weight, 1)
        # Not used by the computation; kept so the attribute still exists.
        self.normalize_const = 1

    def _conv(self, binary_mask):
        """Return the propagated mask, detached from the autograd graph."""
        mask = self.mask_pool(binary_mask)
        if not self.max_pool_mask:
            # Count the non-zero entries per window. Clamping to 1 avoids
            # 0/0 where a window is empty (the window sum is 0 there too).
            support = self.mask_pool((binary_mask > 0).to(binary_mask.dtype))
            mask = mask / support.clamp(min=1)
        return mask.detach()

    def forward(self, mask):
        return self._conv(mask)
87 |
class Dropout2dGivenMask(nn.Module):
    """Dropout of whole slices along one axis, with a reusable mask.

    Unlike ``nn.Dropout2d`` the kept values are not rescaled, and the module
    behaves the same in training and evaluation mode.

    Args:
        p: probability of dropping a slice.
        dim: axis along which slices are dropped (1 = channels).
    """

    def __init__(self, p, dim=1):
        super(Dropout2dGivenMask, self).__init__()

        self.p = p
        self.dim = dim

    def forward(self, input, mask=None):
        """Zero out the entries of ``input`` where ``mask`` is not positive.

        Args:
            input: tensor to mask.
            mask: optional mask that can be expanded to ``input``'s shape.
                When omitted, one Bernoulli(1 - p) sample is drawn per index
                of axis ``dim`` and shared across all other axes.

        Returns:
            ``(masked_input, mask)`` with the mask expanded to ``input``'s
            shape, so it can be passed to a later call.
        """
        if mask is None:
            num_slices = input.shape[self.dim]
            keep_prob = torch.ones(num_slices, device=input.device) * (1 - self.p)
            # Singleton on every axis except ``dim`` so the sample broadcasts.
            view_shape = [1] * input.dim()
            view_shape[self.dim] = num_slices
            mask = torch.bernoulli(keep_prob).view(view_shape)

        mask = mask.expand_as(input)
        input = torch.where(mask > 0, input, torch.zeros_like(input))

        return input, mask
105 |
class PCMaskGenerator(nn.Module):
    """Chain of ``MaskPropagator`` layers returning every intermediate mask.

    Args:
        kernel_sizes: one kernel size per layer, or a single value shared by
            all layers.
        strides: list/tuple with one stride per layer; its length sets the
            number of layers.
        paddings: one padding per layer, or a single shared value.
        max_pool_mask: forwarded to each ``MaskPropagator``.
        apply_func: optional callable applied to each mask before it is
            collected; the un-transformed mask is what feeds the next layer.
    """

    def __init__(self, kernel_sizes, strides, paddings, max_pool_mask=True, apply_func=None):
        super().__init__()
        assert isinstance(strides, (list, tuple))
        num_layers = len(strides)

        # Broadcast scalar settings to one entry per layer.
        if not isinstance(kernel_sizes, (list, tuple)):
            kernel_sizes = [kernel_sizes] * num_layers
        if not isinstance(paddings, (list, tuple)):
            paddings = [paddings] * num_layers

        self.layers = nn.ModuleList([
            MaskPropagator(1, 1, kernel, stride, pad, max_pool_mask=max_pool_mask)
            for kernel, stride, pad in zip(kernel_sizes, strides, paddings)
        ])
        self._apply_func = apply_func

    def forward(self, x):
        """Return the list of masks produced after each layer."""
        collected = []
        for layer in self.layers:
            x = layer(x)
            collected.append(x if self._apply_func is None else self._apply_func(x))
        return collected
135 |
--------------------------------------------------------------------------------
/rslo/layers/confidence.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class ConfidenceModule(nn.Module):
    """Turn the output of ``conf_model`` into confidence weights.

    Args:
        conf_model: module (or callable) producing the raw logits.
        conf_type: ``'linear'`` or ``'softmax'``.
            * ``'linear'``: ``(elu(logit) + 1 + 1e-12) * (mask + 1e-12)``.
            * ``'softmax'``: masked-out logits are replaced by -1000 and a
              softmax is taken over all axes after the first two, flattened.
    """

    def __init__(self, conf_model, conf_type='softmax'):
        super().__init__()
        assert conf_type in ['linear', 'softmax']
        self.conf_model = conf_model
        self.conf_type = conf_type
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, extra_mask=None, temperature=1, return_logit=False):
        """Compute confidences for ``x``.

        Args:
            x: input passed to ``conf_model``.
            extra_mask: optional mask; entries > 0 are treated as valid.
                Defaults to all ones with the shape of ``x`` (so it must be
                broadcastable against the logits — TODO confirm for models
                that change the channel count).
            temperature: divisor applied to the logits before the softmax;
                ignored for ``conf_type='linear'``.
            return_logit: when true, also return the raw ``conf_model`` output.

        Returns:
            ``conf`` or ``(conf, logit)``; ``conf`` has the logits' shape.
        """
        if extra_mask is None:
            extra_mask = torch.ones_like(x)

        # Computed once up front so it is available to both branches and to
        # the ``return_logit`` path.
        logit = self.conf_model(x)

        if self.conf_type == 'linear':
            conf = (F.elu(logit) + 1 + 1e-12) * (extra_mask + 1e-12)
        elif self.conf_type == 'softmax':
            # A large negative logit gives masked positions ~0 probability.
            conf = torch.where(extra_mask > 0,
                               logit, torch.full_like(logit, -1000))

            conf_shape = conf.shape
            # Flatten everything after (batch, channel) for the softmax.
            conf = conf.reshape(*conf.shape[0:2], -1)
            conf = self.softmax(conf / temperature)
            conf = conf.reshape(*conf_shape)

        if return_logit:
            return conf, logit
        return conf
--------------------------------------------------------------------------------
/rslo/layers/se_module.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 |
4 |
class SELayer(nn.Module):
    """Channel gating: global average pool -> two bias-free linears -> sigmoid.

    Args:
        channel: number of input channels.
        reduction: the hidden width is ``channel // reduction``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale each channel of a 4-D input by its learned gate."""
        batch, channels, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)
21 |
22 |
class SpatialAttentionLayer(nn.Module):
    """Spatial gating with a single 1x1 convolution followed by a sigmoid.

    ``reduction`` is accepted but unused.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        to_single_map = nn.Conv2d(channel, 1, kernel_size=1, padding=0)
        self.attention = nn.Sequential(to_single_map, nn.Sigmoid())

    def forward(self, x):
        """Multiply ``x`` by a one-channel attention map shared by all channels."""
        gate = self.attention(x)
        return x * gate.expand_as(x)
34 |
class SpatialAttentionLayerV2(nn.Module):
    """Spatial gating from three stacked 3x3 convolutions and a sigmoid.

    The convolutions map ``channel -> channel // 2 -> channel -> 1``; the
    middle one uses dilation 2. Their paddings keep the spatial size.
    ``reduction`` is accepted but unused.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        half = channel // 2
        stages = [
            nn.Conv2d(channel, half, kernel_size=3, padding=1, dilation=1),
            nn.Conv2d(half, channel, kernel_size=3, padding=2, dilation=2),
            nn.Conv2d(channel, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*stages)

    def forward(self, x):
        """Multiply ``x`` by a one-channel attention map shared by all channels."""
        gate = self.attention(x)
        return x * gate.expand_as(x)
class SpatialAttentionLayerV3(nn.Module):
    """Spatial gating computed by a small down/up-sampling branch.

    The branch halves the resolution (stride-2 conv), applies one more conv,
    upsamples by 2, and its output is concatenated with the input before a
    1x1 conv + sigmoid produces the one-channel attention map.

    Args:
        channel: number of input channels.
        reduction: accepted but unused.
        LayerNorm: normalisation layer class, called with a channel count.
    """

    def __init__(self, channel, reduction=16, LayerNorm=nn.InstanceNorm2d):
        super(SpatialAttentionLayerV3, self).__init__()
        print(f"SpatialAttentionLayerV3: LayerNone={LayerNorm}")
        self.conv1 = nn.Sequential(
            nn.Conv2d(channel, 2*channel, kernel_size=3, stride=2, padding=1, ),
            LayerNorm(2*channel),
            nn.LeakyReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(2*channel, 2*channel, kernel_size=3, stride=1, padding=1, ),
            LayerNorm(2*channel),
            nn.LeakyReLU()
        )
        self.deconv1 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(2*channel, channel, kernel_size=3, stride=1, padding=1),
            LayerNorm(channel),
            nn.LeakyReLU()
        )
        # Input is the branch output (channel) concatenated with x (channel).
        self.conv3 = nn.Sequential(
            nn.Conv2d(2*channel, 1, kernel_size=1, stride=1, padding=0, ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Multiply ``x`` by the attention map; the output has ``x``'s shape."""
        x1 = self.conv1(x)
        x1 = self.conv2(x1)
        x1 = self.deconv1(x1)
        # For odd heights/widths the stride-2 conv followed by x2 upsampling
        # comes out one pixel larger than the input; crop so the
        # concatenation below is valid. This is a no-op for even sizes.
        x1 = x1[..., :x.shape[-2], :x.shape[-1]]
        x1 = torch.cat([x1, x], dim=1)
        y = self.conv3(x1)

        return x * y.expand_as(x)
--------------------------------------------------------------------------------
/rslo/layers/svd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import apex.amp as amp
4 |
5 |
6 |
class SVDHead(nn.Module):
    """Closed-form rigid alignment of two point sets via SVD.

    ``args`` is accepted but unused.
    """

    def __init__(self, args=None):
        super(SVDHead, self).__init__()
        # diag(1, 1, -1): flips the last singular direction when the SVD
        # solution comes out as a reflection (determinant < 0).
        self.reflect = nn.Parameter(torch.eye(3), requires_grad=False)
        self.reflect[2, 2] = -1

    @amp.float_function
    def forward(self, src, tgt, weight=None):
        '''
        src: Bx3xN
        tgt: Bx3xN
        weight: BxN, optional per-point weights applied to the centred
            source points when the 3x3 cross-covariance is formed. The
            centroids themselves are plain (unweighted) means.

        Returns (R, t) with R: Bx3x3 and t: Bx3. With (R0, t0) the solution
        fitted from src to tgt, the returned values are R = R0^T and
        t = -R0^T @ t0 (the original comment states "Xs = R@Xt + t").
        '''
        batch_size = src.size(0)

        src_mean = src.mean(dim=2, keepdim=True)  # Bx3x1
        tgt_mean = tgt.mean(dim=2, keepdim=True)  # Bx3x1
        src_centered = src - src_mean  # Bx3xN
        tgt_centered = tgt - tgt_mean  # Bx3xN

        if weight is not None:
            src_centered = src_centered * weight[:, None, :]
        # Bx3x3 cross-covariance.
        H = torch.matmul(src_centered, tgt_centered.transpose(2, 1).contiguous())

        rotations = []
        # One SVD per sample, iterating over the batch dimension.
        for i in range(batch_size):
            u, _, v = torch.svd(H[i])
            r = torch.matmul(v, u.transpose(1, 0).contiguous())
            if torch.det(r) < 0:
                # Reuse the factors already computed instead of running the
                # SVD a second time.
                v = torch.matmul(v, self.reflect.to(device=v.device, dtype=v.dtype))
                r = torch.matmul(v, u.transpose(1, 0).contiguous())  # r = v @ u^T
            rotations.append(r)

        R = torch.stack(rotations, dim=0)

        t = torch.matmul(-R, src_mean) + tgt_mean
        # Return the inverse of the fitted transform.
        R = R.transpose(-1, -2).contiguous()
        t = -R @ t

        return R, t.view(batch_size, 3)
65 |
--------------------------------------------------------------------------------
/rslo/protos/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/rslo/protos/__init__.py
--------------------------------------------------------------------------------
/rslo/protos/complile.sh:
--------------------------------------------------------------------------------
# Regenerate the Python protobuf modules for every .proto file under
# rslo/protos. protoc is run from the repository root so that the
# "rslo/protos/..." import paths inside the .proto files resolve.
#
# The root is located relative to this script, so the script works from any
# working directory; abort if it cannot be entered rather than running
# find/protoc somewhere unexpected.
cd "$(dirname "$0")/../.." || exit 1
# -I already runs one protoc per input line; combining it with -n1 is
# rejected as mutually exclusive by GNU xargs.
find . -path './rslo/protos/*.proto' | xargs -I {} protoc --proto_path . --python_out=. {}
3 |
4 | # find ls | grep ".proto" | xargs -n1 -I {} protoc --proto_path . --python_out=. {}
--------------------------------------------------------------------------------
/rslo/protos/input_reader.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 | import "rslo/protos/preprocess.proto";
5 | import "rslo/protos/sampler.proto";
6 |
7 | message InputReader { // settings for one data reader: batch size, dataset source and preprocessing options
8 | uint32 batch_size = 1;
9 |
10 | message Dataset { // dataset paths/class name plus sequence options (seq_length, skip, random_skip, step)
11 | string kitti_info_path = 1;
12 | string kitti_root_path = 2;
13 | string dataset_class_name = 3; // support KittiDataset and NuScenesDataset
14 | int32 seq_length=4;
15 | int32 skip=5;
16 | bool random_skip=6;
17 | int32 step=7;
18 |
19 | }
20 | Dataset dataset = 2;
21 | message Preprocess { // augmentation, sampling and voxel-related preprocessing switches
22 | bool shuffle_points = 1;
23 | uint32 max_number_of_voxels = 2;
24 | repeated float groundtruth_localization_noise_std = 3;
25 | repeated float groundtruth_rotation_uniform_noise = 4;
26 | repeated float global_rotation_uniform_noise = 5;
27 | repeated float global_scaling_uniform_noise = 6;
28 | repeated float global_translate_noise_std = 7;
29 | bool remove_unknown_examples = 8;
30 | uint32 num_workers = 9;
31 | float anchor_area_threshold = 10;
32 | bool remove_points_after_sample = 11;
33 | float groundtruth_points_drop_percentage = 12;
34 | uint32 groundtruth_drop_max_keep_points = 13;
35 | bool remove_environment = 14;
36 | repeated float global_random_rotation_range_per_object = 15;
37 | repeated DatabasePreprocessingStep database_prep_steps = 16;
38 | Sampler database_sampler = 17;
39 | bool use_group_id = 18; // this will enable group sample and noise
40 | int64 min_num_of_points_in_gt = 19; // gt boxes containing fewer points than this will be ignored.
41 | bool random_flip_x = 20;
42 | bool random_flip_y = 21;
43 | float sample_importance = 22;
44 | float rand_rotation_eps=23;
45 | float rand_translation_eps=24;
46 | float random_aug_ratio=25;
47 | float do_pre_transform=26; // NOTE(review): typed float although the name reads like an on/off flag - confirm intended type
48 | bool cubic_tq_map=27;
49 | repeated float downsample_voxel_sizes=28;
50 | }
51 | Preprocess preprocess = 3;
52 | uint32 max_num_epochs = 4; // deprecated
53 | uint32 prefetch_size = 5; // deprecated
54 |
55 | float review_cycle=6;
56 | }
57 |
--------------------------------------------------------------------------------
/rslo/protos/losses.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 |
5 | // Message for configuring the rotation, translation, pyramid, consistency
6 | // and rigid-transform losses used for training. See core/losses.py
7 | // for details
8 | message Loss {
9 | RotaionLoss rotation_loss = 1;
10 | TranslationLoss translation_loss = 2;
11 | RotaionLoss pyramid_rotation_loss=3;
12 | TranslationLoss pyramid_translation_loss=4;
13 | ConsistencyLoss consistency_loss=5;
14 | float pyloss_exp_w_base=6;
15 | RigidTransformLoss rigid_transform_loss= 7;
16 | // float rotation_weight = 4;
17 |
18 | // float translation_weight = 5;
19 |
20 | }
21 |
22 | message RigidTransformLoss { // only weight (2) and focal_gamma (5) are active; the declarations for numbers 1, 3 and 4 are commented out
23 | // string loss_type =1;
24 | float weight = 2 ;
25 | // float init_alpha= 3;
26 | // bool not_learn_alpha=4;
27 | float focal_gamma=5;
28 | }
29 | message RotaionLoss { // NOTE(review): "Rotaion" looks like a misspelling of "Rotation"; renaming would also rename the generated class, so it is left as is
30 | // oneof {
31 | // WeightedL2LocalizationLoss weighted_l2 = 1;
32 | // WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2;
33 | // WeightedGHMLocalizationLoss weighted_ghm = 3;
34 | // }
35 |
36 | string loss_type =1;
37 | float weight = 2 ;
38 | float init_alpha= 3;
39 | bool not_learn_alpha=4;
40 | float focal_gamma=5;
41 | float balance_scale=6;
42 | }
43 | message TranslationLoss { // declares the same six fields (names, types, numbers) as RotaionLoss
44 |
45 | string loss_type =1;
46 | float weight = 2 ;
47 | float init_alpha= 3;
48 | bool not_learn_alpha=4;
49 | float focal_gamma=5;
50 | float balance_scale=6;
51 | }
52 | message ConsistencyLoss{ // fields 1-5 match RotaionLoss/TranslationLoss; fields 6-11 exist only in this message
53 |
54 | string loss_type =1;
55 | float weight = 2 ;
56 | float init_alpha= 3;
57 | bool not_learn_alpha=4;
58 | float focal_gamma=5;
59 | float penalize_ratio=6;
60 | repeated float sample_block_size=7;
61 | bool norm=8;
62 | float pred_downsample_ratio=9;
63 | float reg_weight=10;
64 | float sph_weight=11;
65 | }
66 |
--------------------------------------------------------------------------------
/rslo/protos/losses.proto.bak:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 |
5 | // Message for configuring the localization loss, classification loss and hard
6 | // example miner used for training object detection models. See core/losses.py
7 | // for details
8 | message Loss {
9 | // Localization loss to use.
10 | LocalizationLoss localization_loss = 1;
11 |
12 | // Classification loss to use.
13 | ClassificationLoss classification_loss = 2;
14 |
15 | // If not left to default, applies hard example mining.
16 | HardExampleMiner hard_example_miner = 3;
17 |
18 | // Classification loss weight.
19 | float classification_weight = 4;
20 |
21 | // Localization loss weight.
22 | float localization_weight = 5;
23 |
24 | }
25 |
26 | // Configuration for bounding box localization loss function.
27 | message LocalizationLoss {
28 | oneof localization_loss {
29 | WeightedL2LocalizationLoss weighted_l2 = 1;
30 | WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2;
31 | WeightedGHMLocalizationLoss weighted_ghm = 3;
32 | }
33 | bool encode_rad_error_by_sin = 4;
34 | }
35 |
36 | // L2 location loss: 0.5 * ||weight * (a - b)|| ^ 2
37 | message WeightedL2LocalizationLoss {
38 | // DEPRECATED, do not use.
39 | // Output loss per anchor.
40 | bool anchorwise_output = 1;
41 | repeated float code_weight = 2;
42 | }
43 |
44 | // SmoothL1 (Huber) location loss: .5 * x ^ 2 if |x| < 1 else |x| - .5
45 | message WeightedSmoothL1LocalizationLoss {
46 | // DEPRECATED, do not use.
47 | // Output loss per anchor.
48 | bool anchorwise_output = 1;
49 | float sigma = 2;
50 | repeated float code_weight = 3;
51 | }
52 | message WeightedGHMLocalizationLoss {
53 | // DEPRECATED, do not use.
54 | // Output loss per anchor.
55 | bool anchorwise_output = 1;
56 | float mu = 2;
57 | int32 bins = 3;
58 | float momentum = 4;
59 | repeated float code_weight = 5;
60 | }
61 |
62 |
63 | // Configuration for class prediction loss function.
64 | message ClassificationLoss {
65 | oneof classification_loss {
66 | WeightedSigmoidClassificationLoss weighted_sigmoid = 1;
67 | WeightedSoftmaxClassificationLoss weighted_softmax = 2;
68 | BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3;
69 | SigmoidFocalClassificationLoss weighted_sigmoid_focal = 4;
70 | SoftmaxFocalClassificationLoss weighted_softmax_focal = 5;
71 | GHMClassificationLoss weighted_ghm = 6;
72 | }
73 | }
74 |
75 | // Classification loss using a sigmoid function over class predictions.
76 | message WeightedSigmoidClassificationLoss {
77 | // DEPRECATED, do not use.
78 | // Output loss per anchor.
79 | bool anchorwise_output = 1;
80 | }
81 |
82 | // Sigmoid Focal cross entropy loss as described in
83 | // https://arxiv.org/abs/1708.02002
84 | message SigmoidFocalClassificationLoss {
85 | // DEPRECATED, do not use.
86 | bool anchorwise_output = 1;
87 | // modulating factor for the loss.
88 | float gamma = 2;
89 | // alpha weighting factor for the loss.
90 | float alpha = 3;
91 | }
92 | // Softmax Focal cross entropy loss as described in
93 | // https://arxiv.org/abs/1708.02002
94 | message SoftmaxFocalClassificationLoss {
95 | // DEPRECATED, do not use.
96 | bool anchorwise_output = 1;
97 | // modulating factor for the loss.
98 | float gamma = 2;
99 | // alpha weighting factor for the loss.
100 | float alpha = 3;
101 | }
102 | message GHMClassificationLoss {
103 | bool anchorwise_output = 1;
104 | int32 bins = 2;
105 | float momentum = 3;
106 | }
107 | // Classification loss using a softmax function over class predictions.
108 | message WeightedSoftmaxClassificationLoss {
109 | // DEPRECATED, do not use.
110 | // Output loss per anchor.
111 | bool anchorwise_output = 1;
112 | // Scale logit (input) value before calculating softmax classification loss.
113 | // Typically used for softmax distillation.
114 | float logit_scale = 2;
115 | }
116 |
117 | // Classification loss using a sigmoid function over the class prediction with
118 | // the highest prediction score.
119 | message BootstrappedSigmoidClassificationLoss {
120 | // Interpolation weight between 0 and 1.
121 | float alpha = 1;
122 |
123 | // Whether hard bootstrapping should be used or not. If true, will only use
124 | // one class favored by model. Otherwise, will use all predicted class
125 | // probabilities.
126 | bool hard_bootstrap = 2;
127 |
128 | // DEPRECATED, do not use.
129 | // Output loss per anchor.
130 | bool anchorwise_output = 3;
131 | }
132 |
133 | // Configuration for hard example miner.
134 | message HardExampleMiner {
135 | // Maximum number of hard examples to be selected per image (prior to
136 | // enforcing max negative to positive ratio constraint). If set to 0,
137 | // all examples obtained after NMS are considered.
138 | int32 num_hard_examples = 1;
139 |
140 | // Minimum intersection over union for an example to be discarded during NMS.
141 | float iou_threshold = 2;
142 |
143 | // Whether to use classification losses ('cls', default), localization losses
144 | // ('loc') or both losses ('both'). In the case of 'both', cls_loss_weight and
145 | // loc_loss_weight are used to compute weighted sum of the two losses.
146 | enum LossType { // NOTE(review): BOTH is the zero value, so an unset loss_type reads as BOTH under proto3, not as classification
147 | BOTH = 0;
148 | CLASSIFICATION = 1;
149 | LOCALIZATION = 2;
150 | }
151 | LossType loss_type = 3;
152 |
153 | // Maximum number of negatives to retain for each positive anchor. If
154 | // max_negatives_per_positive is 0 no prespecified negative:positive ratio is
155 | // enforced.
156 | int32 max_negatives_per_positive = 4;
157 |
158 | // Minimum number of negative anchors to sample for a given image. Setting
159 | // this to a positive number samples negatives in an image without any
160 | // positive anchors and thus not bias the model towards having at least one
161 | // detection per image.
162 | int32 min_negatives_per_image = 5;
163 | }
164 |
--------------------------------------------------------------------------------
/rslo/protos/model.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 | import "rslo/protos/second.proto";
5 | message DetectionModel{ // top-level model configuration wrapper
6 | oneof model { // VoxelNet (imported from second.proto) is the only alternative declared
7 | VoxelNet second = 1;
8 | }
9 | }
10 |
--------------------------------------------------------------------------------
/rslo/protos/model_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: rslo/protos/model.proto
4 |
5 | import sys
6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | from rslo.protos import second_pb2 as rslo_dot_protos_dot_second__pb2
17 |
18 |
19 | DESCRIPTOR = _descriptor.FileDescriptor(
20 | name='rslo/protos/model.proto',
21 | package='second.protos',
22 | syntax='proto3',
23 | serialized_options=None,
24 | serialized_pb=_b('\n\x17rslo/protos/model.proto\x12\rsecond.protos\x1a\x18rslo/protos/second.proto\"D\n\x0e\x44\x65tectionModel\x12)\n\x06second\x18\x01 \x01(\x0b\x32\x17.second.protos.VoxelNetH\x00\x42\x07\n\x05modelb\x06proto3')
25 | ,
26 | dependencies=[rslo_dot_protos_dot_second__pb2.DESCRIPTOR,])
27 |
28 |
29 |
30 |
31 | _DETECTIONMODEL = _descriptor.Descriptor(
32 | name='DetectionModel',
33 | full_name='second.protos.DetectionModel',
34 | filename=None,
35 | file=DESCRIPTOR,
36 | containing_type=None,
37 | fields=[
38 | _descriptor.FieldDescriptor(
39 | name='second', full_name='second.protos.DetectionModel.second', index=0,
40 | number=1, type=11, cpp_type=10, label=1,
41 | has_default_value=False, default_value=None,
42 | message_type=None, enum_type=None, containing_type=None,
43 | is_extension=False, extension_scope=None,
44 | serialized_options=None, file=DESCRIPTOR),
45 | ],
46 | extensions=[
47 | ],
48 | nested_types=[],
49 | enum_types=[
50 | ],
51 | serialized_options=None,
52 | is_extendable=False,
53 | syntax='proto3',
54 | extension_ranges=[],
55 | oneofs=[
56 | _descriptor.OneofDescriptor(
57 | name='model', full_name='second.protos.DetectionModel.model',
58 | index=0, containing_type=None, fields=[]),
59 | ],
60 | serialized_start=68,
61 | serialized_end=136,
62 | )
63 |
64 | _DETECTIONMODEL.fields_by_name['second'].message_type = rslo_dot_protos_dot_second__pb2._VOXELNET
65 | _DETECTIONMODEL.oneofs_by_name['model'].fields.append(
66 | _DETECTIONMODEL.fields_by_name['second'])
67 | _DETECTIONMODEL.fields_by_name['second'].containing_oneof = _DETECTIONMODEL.oneofs_by_name['model']
68 | DESCRIPTOR.message_types_by_name['DetectionModel'] = _DETECTIONMODEL
69 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
70 |
71 | DetectionModel = _reflection.GeneratedProtocolMessageType('DetectionModel', (_message.Message,), {
72 | 'DESCRIPTOR' : _DETECTIONMODEL,
73 | '__module__' : 'rslo.protos.model_pb2'
74 | # @@protoc_insertion_point(class_scope:second.protos.DetectionModel)
75 | })
76 | _sym_db.RegisterMessage(DetectionModel)
77 |
78 |
79 | # @@protoc_insertion_point(module_scope)
80 |
--------------------------------------------------------------------------------
/rslo/protos/optimizer.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 |
5 | // Messages for configuring the optimizing strategy for training object
6 | // detection models.
7 |
8 | // Top level optimizer message.
9 | message Optimizer {
10 | oneof optimizer {
11 | RMSPropOptimizer rms_prop_optimizer = 1;
12 | MomentumOptimizer momentum_optimizer = 2;
13 | AdamOptimizer adam_optimizer = 3;
14 | }
15 | bool use_moving_average = 4;
16 | float moving_average_decay = 5;
17 | bool fixed_weight_decay = 6; // i.e. AdamW
18 | }
19 |
20 | // Configuration message for the RMSPropOptimizer
21 | // See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
22 | message RMSPropOptimizer {
23 | LearningRate learning_rate = 1;
24 | float momentum_optimizer_value = 2;
25 | float decay = 3;
26 | float epsilon = 4;
27 | float weight_decay = 5;
28 | }
29 |
30 | // Configuration message for the MomentumOptimizer
31 | // See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
32 | message MomentumOptimizer {
33 | LearningRate learning_rate = 1;
34 | float momentum_optimizer_value = 2;
35 | float weight_decay = 3;
36 | }
37 |
38 | // Configuration message for the AdamOptimizer
39 | // See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
40 | message AdamOptimizer {
41 | LearningRate learning_rate = 1;
42 | float weight_decay = 2;
43 | bool amsgrad = 3;
44 | }
45 |
46 | message LearningRate { // selects the learning-rate schedule
47 | oneof learning_rate { // at most one of the schedules below is set at a time
48 | MultiPhase multi_phase = 1;
49 | OneCycle one_cycle = 2;
50 | ExponentialDecay exponential_decay = 3;
51 | ManualStepping manual_stepping = 4;
52 | ExponentialDecayWarmup exponential_decay_warmup=5;
53 | }
54 | }
55 |
56 | message LearningRatePhase {
57 | float start = 1;
58 | string lambda_func = 2;
59 | string momentum_lambda_func = 3;
60 | }
61 |
62 | message MultiPhase {
63 | repeated LearningRatePhase phases = 1;
64 | }
65 |
66 | message OneCycle {
67 | float lr_max = 1;
68 | repeated float moms = 2;
69 | float div_factor = 3;
70 | float pct_start = 4;
71 | repeated float lr_maxs = 5;
72 |
73 | }
74 |
75 | /*
76 | ExponentialDecay example:
77 | initial_learning_rate = 0.001
78 | decay_length = 0.1
79 | decay_factor = 0.8
80 | staircase = True
81 | detail:
82 | progress 0%~10%, lr=0.001
83 | progress 10%~20%, lr=0.001 * 0.8
84 | progress 20%~30%, lr=0.001 * 0.8 * 0.8
85 | ......
86 | */
87 |
88 |
89 | message ExponentialDecay {
90 | float initial_learning_rate = 1;
91 | float decay_length = 2; // must be in range (0, 1)
92 | float decay_factor = 3;
93 | bool staircase = 4;
94 | }
95 |
96 | message ExponentialDecayWarmup { // fields 1-4 match ExponentialDecay; div_factor and pct_start are additional
97 | float initial_learning_rate = 1;
98 | float decay_length = 2; // must be in range (0, 1)
99 | float decay_factor = 3;
100 | bool staircase = 4;
101 | float div_factor=6; // field number 5 is not used in this message
102 | float pct_start=7;
103 | }
104 |
105 | /*
106 | ManualStepping example:
107 | boundaries = [0.8, 0.9]
108 | rates = [0.001, 0.002, 0.003]
109 | detail:
110 | progress 0%~80%, lr=0.001
111 | progress 80%~90%, lr=0.002
112 | progress 90%~100%, lr=0.003
113 | */
114 |
115 | message ManualStepping {
116 | repeated float boundaries = 1; // must be in range (0, 1)
117 | repeated float rates = 2; // the example preceding this message lists one more rate than boundaries
118 | }
--------------------------------------------------------------------------------
/rslo/protos/pipeline.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 |
5 | import "rslo/protos/input_reader.proto";
6 | import "rslo/protos/model.proto";
7 | import "rslo/protos/train.proto";
8 | // Convenience message for configuring a training and eval pipeline. Allows all
9 | // of the pipeline parameters to be configured from one file.
10 | message TrainEvalPipelineConfig {
11 | DetectionModel model = 1;
12 | InputReader train_input_reader = 2;
13 | TrainConfig train_config = 3;
14 | InputReader eval_input_reader = 4;
15 | InputReader eval_train_input_reader = 5;
16 | }
17 |
18 |
--------------------------------------------------------------------------------
/rslo/protos/pipeline_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: rslo/protos/pipeline.proto
4 |
5 | import sys
6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | from rslo.protos import input_reader_pb2 as rslo_dot_protos_dot_input__reader__pb2
17 | from rslo.protos import model_pb2 as rslo_dot_protos_dot_model__pb2
18 | from rslo.protos import train_pb2 as rslo_dot_protos_dot_train__pb2
19 |
20 |
21 | DESCRIPTOR = _descriptor.FileDescriptor(
22 | name='rslo/protos/pipeline.proto',
23 | package='second.protos',
24 | syntax='proto3',
25 | serialized_options=None,
26 | serialized_pb=_b('\n\x1arslo/protos/pipeline.proto\x12\rsecond.protos\x1a\x1erslo/protos/input_reader.proto\x1a\x17rslo/protos/model.proto\x1a\x17rslo/protos/train.proto\"\xa5\x02\n\x17TrainEvalPipelineConfig\x12,\n\x05model\x18\x01 \x01(\x0b\x32\x1d.second.protos.DetectionModel\x12\x36\n\x12train_input_reader\x18\x02 \x01(\x0b\x32\x1a.second.protos.InputReader\x12\x30\n\x0ctrain_config\x18\x03 \x01(\x0b\x32\x1a.second.protos.TrainConfig\x12\x35\n\x11\x65val_input_reader\x18\x04 \x01(\x0b\x32\x1a.second.protos.InputReader\x12;\n\x17\x65val_train_input_reader\x18\x05 \x01(\x0b\x32\x1a.second.protos.InputReaderb\x06proto3')
27 | ,
28 | dependencies=[rslo_dot_protos_dot_input__reader__pb2.DESCRIPTOR,rslo_dot_protos_dot_model__pb2.DESCRIPTOR,rslo_dot_protos_dot_train__pb2.DESCRIPTOR,])
29 |
30 |
31 |
32 |
33 | _TRAINEVALPIPELINECONFIG = _descriptor.Descriptor(
34 | name='TrainEvalPipelineConfig',
35 | full_name='second.protos.TrainEvalPipelineConfig',
36 | filename=None,
37 | file=DESCRIPTOR,
38 | containing_type=None,
39 | fields=[
40 | _descriptor.FieldDescriptor(
41 | name='model', full_name='second.protos.TrainEvalPipelineConfig.model', index=0,
42 | number=1, type=11, cpp_type=10, label=1,
43 | has_default_value=False, default_value=None,
44 | message_type=None, enum_type=None, containing_type=None,
45 | is_extension=False, extension_scope=None,
46 | serialized_options=None, file=DESCRIPTOR),
47 | _descriptor.FieldDescriptor(
48 | name='train_input_reader', full_name='second.protos.TrainEvalPipelineConfig.train_input_reader', index=1,
49 | number=2, type=11, cpp_type=10, label=1,
50 | has_default_value=False, default_value=None,
51 | message_type=None, enum_type=None, containing_type=None,
52 | is_extension=False, extension_scope=None,
53 | serialized_options=None, file=DESCRIPTOR),
54 | _descriptor.FieldDescriptor(
55 | name='train_config', full_name='second.protos.TrainEvalPipelineConfig.train_config', index=2,
56 | number=3, type=11, cpp_type=10, label=1,
57 | has_default_value=False, default_value=None,
58 | message_type=None, enum_type=None, containing_type=None,
59 | is_extension=False, extension_scope=None,
60 | serialized_options=None, file=DESCRIPTOR),
61 | _descriptor.FieldDescriptor(
62 | name='eval_input_reader', full_name='second.protos.TrainEvalPipelineConfig.eval_input_reader', index=3,
63 | number=4, type=11, cpp_type=10, label=1,
64 | has_default_value=False, default_value=None,
65 | message_type=None, enum_type=None, containing_type=None,
66 | is_extension=False, extension_scope=None,
67 | serialized_options=None, file=DESCRIPTOR),
68 | _descriptor.FieldDescriptor(
69 | name='eval_train_input_reader', full_name='second.protos.TrainEvalPipelineConfig.eval_train_input_reader', index=4,
70 | number=5, type=11, cpp_type=10, label=1,
71 | has_default_value=False, default_value=None,
72 | message_type=None, enum_type=None, containing_type=None,
73 | is_extension=False, extension_scope=None,
74 | serialized_options=None, file=DESCRIPTOR),
75 | ],
76 | extensions=[
77 | ],
78 | nested_types=[],
79 | enum_types=[
80 | ],
81 | serialized_options=None,
82 | is_extendable=False,
83 | syntax='proto3',
84 | extension_ranges=[],
85 | oneofs=[
86 | ],
87 | serialized_start=128,
88 | serialized_end=421,
89 | )
90 |
91 | _TRAINEVALPIPELINECONFIG.fields_by_name['model'].message_type = rslo_dot_protos_dot_model__pb2._DETECTIONMODEL
92 | _TRAINEVALPIPELINECONFIG.fields_by_name['train_input_reader'].message_type = rslo_dot_protos_dot_input__reader__pb2._INPUTREADER
93 | _TRAINEVALPIPELINECONFIG.fields_by_name['train_config'].message_type = rslo_dot_protos_dot_train__pb2._TRAINCONFIG
94 | _TRAINEVALPIPELINECONFIG.fields_by_name['eval_input_reader'].message_type = rslo_dot_protos_dot_input__reader__pb2._INPUTREADER
95 | _TRAINEVALPIPELINECONFIG.fields_by_name['eval_train_input_reader'].message_type = rslo_dot_protos_dot_input__reader__pb2._INPUTREADER
96 | DESCRIPTOR.message_types_by_name['TrainEvalPipelineConfig'] = _TRAINEVALPIPELINECONFIG
97 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
98 |
99 | TrainEvalPipelineConfig = _reflection.GeneratedProtocolMessageType('TrainEvalPipelineConfig', (_message.Message,), {
100 | 'DESCRIPTOR' : _TRAINEVALPIPELINECONFIG,
101 | '__module__' : 'rslo.protos.pipeline_pb2'
102 | # @@protoc_insertion_point(class_scope:second.protos.TrainEvalPipelineConfig)
103 | })
104 | _sym_db.RegisterMessage(TrainEvalPipelineConfig)
105 |
106 |
107 | # @@protoc_insertion_point(module_scope)
108 |
--------------------------------------------------------------------------------
/rslo/protos/preprocess.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 |
5 | message DatabasePreprocessingStep {
6 | oneof database_preprocessing_step {
7 | DBFilterByDifficulty filter_by_difficulty = 1;
8 | DBFilterByMinNumPointInGroundTruth filter_by_min_num_points = 2;
9 | }
10 | }
11 |
12 | message DBFilterByDifficulty{
13 | repeated int32 removed_difficulties = 1;
14 | }
15 |
16 | message DBFilterByMinNumPointInGroundTruth{ // database filter configured through a single map field
17 | map<string, uint32> min_num_point_pairs = 1; // a map field needs explicit key/value types; NOTE(review): string -> uint32 is assumed, mirroring the map entry described in sampler_pb2.py - confirm against preprocess_pb2.py
18 | }
19 |
--------------------------------------------------------------------------------
/rslo/protos/sampler.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 | import "rslo/protos/preprocess.proto";
5 |
6 | message Group{ // one sample group: a single name -> max-number map (meaning taken from the field name)
7 | map<string, uint32> name_to_max_num = 1; // string key, uint32 value, matching NameToMaxNumEntry in the generated sampler_pb2.py
8 | }
9 |
10 | message Sampler{
11 | string database_info_path = 1;
12 | repeated Group sample_groups = 2;
13 | repeated DatabasePreprocessingStep database_prep_steps = 3;
14 | repeated float global_random_rotation_range_per_object = 4;
15 | float rate = 5;
16 | }
17 |
--------------------------------------------------------------------------------
/rslo/protos/sampler_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: rslo/protos/sampler.proto
4 |
5 | import sys
6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | from rslo.protos import preprocess_pb2 as rslo_dot_protos_dot_preprocess__pb2
17 |
18 |
19 | DESCRIPTOR = _descriptor.FileDescriptor(
20 | name='rslo/protos/sampler.proto',
21 | package='second.protos',
22 | syntax='proto3',
23 | serialized_options=None,
24 | serialized_pb=_b('\n\x19rslo/protos/sampler.proto\x12\rsecond.protos\x1a\x1crslo/protos/preprocess.proto\"}\n\x05Group\x12?\n\x0fname_to_max_num\x18\x01 \x03(\x0b\x32&.second.protos.Group.NameToMaxNumEntry\x1a\x33\n\x11NameToMaxNumEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\xd8\x01\n\x07Sampler\x12\x1a\n\x12\x64\x61tabase_info_path\x18\x01 \x01(\t\x12+\n\rsample_groups\x18\x02 \x03(\x0b\x32\x14.second.protos.Group\x12\x45\n\x13\x64\x61tabase_prep_steps\x18\x03 \x03(\x0b\x32(.second.protos.DatabasePreprocessingStep\x12/\n\'global_random_rotation_range_per_object\x18\x04 \x03(\x02\x12\x0c\n\x04rate\x18\x05 \x01(\x02\x62\x06proto3')
25 | ,
26 | dependencies=[rslo_dot_protos_dot_preprocess__pb2.DESCRIPTOR,])
27 |
28 |
29 |
30 |
31 | _GROUP_NAMETOMAXNUMENTRY = _descriptor.Descriptor(
32 | name='NameToMaxNumEntry',
33 | full_name='second.protos.Group.NameToMaxNumEntry',
34 | filename=None,
35 | file=DESCRIPTOR,
36 | containing_type=None,
37 | fields=[
38 | _descriptor.FieldDescriptor(
39 | name='key', full_name='second.protos.Group.NameToMaxNumEntry.key', index=0,
40 | number=1, type=9, cpp_type=9, label=1,
41 | has_default_value=False, default_value=_b("").decode('utf-8'),
42 | message_type=None, enum_type=None, containing_type=None,
43 | is_extension=False, extension_scope=None,
44 | serialized_options=None, file=DESCRIPTOR),
45 | _descriptor.FieldDescriptor(
46 | name='value', full_name='second.protos.Group.NameToMaxNumEntry.value', index=1,
47 | number=2, type=13, cpp_type=3, label=1,
48 | has_default_value=False, default_value=0,
49 | message_type=None, enum_type=None, containing_type=None,
50 | is_extension=False, extension_scope=None,
51 | serialized_options=None, file=DESCRIPTOR),
52 | ],
53 | extensions=[
54 | ],
55 | nested_types=[],
56 | enum_types=[
57 | ],
58 | serialized_options=_b('8\001'),
59 | is_extendable=False,
60 | syntax='proto3',
61 | extension_ranges=[],
62 | oneofs=[
63 | ],
64 | serialized_start=148,
65 | serialized_end=199,
66 | )
67 |
68 | _GROUP = _descriptor.Descriptor(
69 | name='Group',
70 | full_name='second.protos.Group',
71 | filename=None,
72 | file=DESCRIPTOR,
73 | containing_type=None,
74 | fields=[
75 | _descriptor.FieldDescriptor(
76 | name='name_to_max_num', full_name='second.protos.Group.name_to_max_num', index=0,
77 | number=1, type=11, cpp_type=10, label=3,
78 | has_default_value=False, default_value=[],
79 | message_type=None, enum_type=None, containing_type=None,
80 | is_extension=False, extension_scope=None,
81 | serialized_options=None, file=DESCRIPTOR),
82 | ],
83 | extensions=[
84 | ],
85 | nested_types=[_GROUP_NAMETOMAXNUMENTRY, ],
86 | enum_types=[
87 | ],
88 | serialized_options=None,
89 | is_extendable=False,
90 | syntax='proto3',
91 | extension_ranges=[],
92 | oneofs=[
93 | ],
94 | serialized_start=74,
95 | serialized_end=199,
96 | )
97 |
98 |
99 | _SAMPLER = _descriptor.Descriptor(
100 | name='Sampler',
101 | full_name='second.protos.Sampler',
102 | filename=None,
103 | file=DESCRIPTOR,
104 | containing_type=None,
105 | fields=[
106 | _descriptor.FieldDescriptor(
107 | name='database_info_path', full_name='second.protos.Sampler.database_info_path', index=0,
108 | number=1, type=9, cpp_type=9, label=1,
109 | has_default_value=False, default_value=_b("").decode('utf-8'),
110 | message_type=None, enum_type=None, containing_type=None,
111 | is_extension=False, extension_scope=None,
112 | serialized_options=None, file=DESCRIPTOR),
113 | _descriptor.FieldDescriptor(
114 | name='sample_groups', full_name='second.protos.Sampler.sample_groups', index=1,
115 | number=2, type=11, cpp_type=10, label=3,
116 | has_default_value=False, default_value=[],
117 | message_type=None, enum_type=None, containing_type=None,
118 | is_extension=False, extension_scope=None,
119 | serialized_options=None, file=DESCRIPTOR),
120 | _descriptor.FieldDescriptor(
121 | name='database_prep_steps', full_name='second.protos.Sampler.database_prep_steps', index=2,
122 | number=3, type=11, cpp_type=10, label=3,
123 | has_default_value=False, default_value=[],
124 | message_type=None, enum_type=None, containing_type=None,
125 | is_extension=False, extension_scope=None,
126 | serialized_options=None, file=DESCRIPTOR),
127 | _descriptor.FieldDescriptor(
128 | name='global_random_rotation_range_per_object', full_name='second.protos.Sampler.global_random_rotation_range_per_object', index=3,
129 | number=4, type=2, cpp_type=6, label=3,
130 | has_default_value=False, default_value=[],
131 | message_type=None, enum_type=None, containing_type=None,
132 | is_extension=False, extension_scope=None,
133 | serialized_options=None, file=DESCRIPTOR),
134 | _descriptor.FieldDescriptor(
135 | name='rate', full_name='second.protos.Sampler.rate', index=4,
136 | number=5, type=2, cpp_type=6, label=1,
137 | has_default_value=False, default_value=float(0),
138 | message_type=None, enum_type=None, containing_type=None,
139 | is_extension=False, extension_scope=None,
140 | serialized_options=None, file=DESCRIPTOR),
141 | ],
142 | extensions=[
143 | ],
144 | nested_types=[],
145 | enum_types=[
146 | ],
147 | serialized_options=None,
148 | is_extendable=False,
149 | syntax='proto3',
150 | extension_ranges=[],
151 | oneofs=[
152 | ],
153 | serialized_start=202,
154 | serialized_end=418,
155 | )
156 |
157 | _GROUP_NAMETOMAXNUMENTRY.containing_type = _GROUP
158 | _GROUP.fields_by_name['name_to_max_num'].message_type = _GROUP_NAMETOMAXNUMENTRY
159 | _SAMPLER.fields_by_name['sample_groups'].message_type = _GROUP
160 | _SAMPLER.fields_by_name['database_prep_steps'].message_type = rslo_dot_protos_dot_preprocess__pb2._DATABASEPREPROCESSINGSTEP
161 | DESCRIPTOR.message_types_by_name['Group'] = _GROUP
162 | DESCRIPTOR.message_types_by_name['Sampler'] = _SAMPLER
163 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
164 |
165 | Group = _reflection.GeneratedProtocolMessageType('Group', (_message.Message,), {
166 |
167 | 'NameToMaxNumEntry' : _reflection.GeneratedProtocolMessageType('NameToMaxNumEntry', (_message.Message,), {
168 | 'DESCRIPTOR' : _GROUP_NAMETOMAXNUMENTRY,
169 | '__module__' : 'rslo.protos.sampler_pb2'
170 | # @@protoc_insertion_point(class_scope:second.protos.Group.NameToMaxNumEntry)
171 | })
172 | ,
173 | 'DESCRIPTOR' : _GROUP,
174 | '__module__' : 'rslo.protos.sampler_pb2'
175 | # @@protoc_insertion_point(class_scope:second.protos.Group)
176 | })
177 | _sym_db.RegisterMessage(Group)
178 | _sym_db.RegisterMessage(Group.NameToMaxNumEntry)
179 |
180 | Sampler = _reflection.GeneratedProtocolMessageType('Sampler', (_message.Message,), {
181 | 'DESCRIPTOR' : _SAMPLER,
182 | '__module__' : 'rslo.protos.sampler_pb2'
183 | # @@protoc_insertion_point(class_scope:second.protos.Sampler)
184 | })
185 | _sym_db.RegisterMessage(Sampler)
186 |
187 |
188 | _GROUP_NAMETOMAXNUMENTRY._options = None
189 | # @@protoc_insertion_point(module_scope)
190 |
--------------------------------------------------------------------------------
/rslo/protos/second.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 | import "rslo/protos/losses.proto";
5 | // import "rslo/protos/box_coder.proto";
6 | // import "rslo/protos/target.proto";
7 | import "rslo/protos/voxel_generator.proto";
8 |
// Top-level network configuration. It carries the network class name, the
// voxel generator settings, one nested config message per sub-module
// (voxel feature extractor, middle feature extractor, odometry predictor),
// the loss config and the batch-norm freezing options.
message VoxelNet {
  string network_class_name = 1;
  // VoxelGenerator is declared in rslo/protos/voxel_generator.proto.
  VoxelGenerator voxel_generator = 2;

  // Settings for the voxel feature extractor sub-module.
  message VoxelFeatureExtractor {
    string module_class_name = 1;
    repeated int32 num_filters = 2;
    bool with_distance = 3;
    int32 num_input_features = 4;
    bool not_use_norm = 5;
  }
  VoxelFeatureExtractor voxel_feature_extractor = 3;

  // Settings for the middle feature extractor sub-module.
  // Field number 6 is currently unused (its field is commented out below).
  message MiddleFeatureExtractor {
    string module_class_name = 1;
    repeated int32 num_filters_down1 = 2;
    repeated int32 num_filters_down2 = 3;
    int32 num_input_features = 4;
    int32 downsample_factor = 5;
    // bool not_use_norm = 6;
    bool use_leakyReLU = 7;
    string relu_type = 8;
    string bn_type=9;

  }
  MiddleFeatureExtractor middle_feature_extractor = 4;

  // Settings for the odometry predictor sub-module.
  // Field numbers are not declared in ascending order (pool_type uses 14)
  // and number 10 is currently unused; keep existing numbers stable when
  // adding fields so that serialized configs stay compatible.
  message OdomPredictor{
    string module_class_name =1;
    int32 num_input_features =2;
    repeated int32 layer_nums=3;
    repeated int32 layer_strides=4;
    repeated int32 num_filters =5;
    repeated int32 upsample_strides=6;
    repeated int32 num_upsample_filters=7;
    // int32 avgpool_size=8;
    int32 pool_size=8;
    string pool_type=14;
    bool cycle_constraint=9;
    // bool not_use_norm = 10;
    // bool use_sparse_conv = 11;
    string conv_type=11; //['official', 'sparse_conv', 'mask_conv']
    bool pred_pyramid_motion=12;
    string odom_format=13;
    bool dense_predict=15;
    bool use_corr=16;
    float dropout=17;
    string conf_type=18;
    bool use_SPGN=19;
    bool use_deep_supervision=20;
    bool not_use_loss_mask=21;
    bool use_dynamic_mask=22;
    bool use_leakyReLU=23;
    bool dropout_input=24;
    int32 first_conv_groups=25;
    bool odom_use_se=26;
    bool odom_use_sa=27;
    int32 cubic_pred_height=28;
    bool not_use_enc_norm=29;
    bool use_svd=30;
    string bn_type=31;
    // bool freeze_bn=30;
    // bool freeze_bn_affine=31;
  }
  OdomPredictor odom_predictor=5;

  uint32 num_point_features = 6;
  bool use_sigmoid_score = 7;
  // NOTE(review): Loss is presumably declared in rslo/protos/losses.proto
  // (imported by this file) - confirm.
  Loss loss = 8;
  // bool encode_rad_error_by_sin = 9;
  bool encode_background_as_zeros = 10;
  bool use_GN=11;
  repeated float post_center_limit_range = 18;

  // deprecated in future
  bool lidar_input = 24;

  // Batch-norm freezing options (field numbers 30-32).
  bool freeze_bn=30;
  bool freeze_bn_affine=31;
  int32 freeze_bn_start_step =32;
  uint32 icp_iter=33;
  // bool sync_bn=33;
}
--------------------------------------------------------------------------------
/rslo/protos/similarity.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 |
// Configuration proto for region similarity calculators. See
// core/region_similarity_calculator.py for details.
// NOTE(review): confirm that the referenced Python module still exists in
// this repository before relying on the pointer above.
message RegionSimilarityCalculator {
  // At most one of the calculator configurations below is set at a time.
  oneof region_similarity {
    RotateIouSimilarity rotate_iou_similarity = 1;
    NearestIouSimilarity nearest_iou_similarity = 2;
    DistanceSimilarity distance_similarity = 3;
  }
}

// Configuration for the rotated intersection-over-union (IOU) similarity
// calculator. The message has no fields; selecting it in the oneof is the
// whole configuration.
message RotateIouSimilarity {
}

// Configuration for the nearest intersection-over-union (IOU) similarity
// calculator. The message has no fields; selecting it in the oneof is the
// whole configuration.
message NearestIouSimilarity {
}

// Configuration for the distance-based similarity calculator (this one is
// not an IOU calculator).
// NOTE(review): the exact meaning of the three fields is not visible in
// this file - confirm against the consuming code.
message DistanceSimilarity {
  float distance_norm = 1;
  bool with_rotation = 2;
  float rotation_alpha = 3;
}
--------------------------------------------------------------------------------
/rslo/protos/target.proto.bak:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 | import "rslo/protos/anchors.proto";
5 | import "rslo/protos/similarity.proto";
6 |
// Per-class settings: one anchor generator choice, a region similarity
// calculator, NMS parameters, matching thresholds and the class name.
message ClassSetting {
  // At most one anchor generator variant is set at a time.
  // NOTE(review): the three variant types are presumably declared in
  // rslo/protos/anchors.proto (imported by this file) - confirm.
  oneof anchor_generator {
    AnchorGeneratorStride anchor_generator_stride = 1;
    AnchorGeneratorRange anchor_generator_range = 2;
    NoAnchor no_anchor = 3;
  }
  RegionSimilarityCalculator region_similarity_calculator = 4;
  bool use_multi_class_nms = 5;
  bool use_rotate_nms = 6;
  int32 nms_pre_max_size = 7;
  int32 nms_post_max_size = 8;
  float nms_score_threshold = 9;
  float nms_iou_threshold = 10;
  float matched_threshold = 11;
  float unmatched_threshold = 12;
  string class_name = 13;
  repeated int64 feature_map_size = 14; // 3D zyx (DHW) size
}
25 |
// Target assigner settings: a list of per-class settings plus sampling
// options and optional list-valued NMS overrides.
message TargetAssigner {
  repeated ClassSetting class_settings = 1;
  float sample_positive_fraction = 2;
  uint32 sample_size = 3;
  bool assign_per_class = 4;
  repeated int64 nms_pre_max_sizes = 5; // overrides the setting in ClassSetting if provided.
  repeated int64 nms_post_max_sizes = 6; // overrides the setting in ClassSetting if provided.
  // NOTE(review): the two threshold overrides below are declared int64,
  // while ClassSetting.nms_score_threshold / nms_iou_threshold are float -
  // confirm the intended type.
  repeated int64 nms_score_thresholds = 7; // overrides the setting in ClassSetting if provided.
  repeated int64 nms_iou_thresholds = 8; // overrides the setting in ClassSetting if provided.
}
--------------------------------------------------------------------------------
/rslo/protos/train.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 |
5 | import "rslo/protos/optimizer.proto";
6 | import "rslo/protos/preprocess.proto";
7 |
// Training-loop configuration: the optimizer config plus step counts,
// checkpoint/summary intervals and mixed-precision options.
// NOTE(review): no field here uses a type from rslo/protos/preprocess.proto
// although this file imports it - confirm whether the import is needed.
message TrainConfig{
  // Optimizer is resolved from rslo/protos/optimizer.proto.
  Optimizer optimizer = 1;
  uint32 steps = 2;
  uint32 steps_per_eval = 3;
  // NOTE(review): the name suggests an interval in seconds - confirm.
  uint32 save_checkpoints_secs = 4;
  uint32 save_summary_steps = 5;
  bool enable_mixed_precision = 6;
  float loss_scale_factor = 7;
  bool clear_metrics_every_epoch = 8;
}
--------------------------------------------------------------------------------
/rslo/protos/train_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: rslo/protos/train.proto
4 |
5 | import sys
6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | from rslo.protos import optimizer_pb2 as rslo_dot_protos_dot_optimizer__pb2
17 | from rslo.protos import preprocess_pb2 as rslo_dot_protos_dot_preprocess__pb2
18 |
19 |
20 | DESCRIPTOR = _descriptor.FileDescriptor(
21 | name='rslo/protos/train.proto',
22 | package='second.protos',
23 | syntax='proto3',
24 | serialized_options=None,
25 | serialized_pb=_b('\n\x17rslo/protos/train.proto\x12\rsecond.protos\x1a\x1brslo/protos/optimizer.proto\x1a\x1crslo/protos/preprocess.proto\"\xfa\x01\n\x0bTrainConfig\x12+\n\toptimizer\x18\x01 \x01(\x0b\x32\x18.second.protos.Optimizer\x12\r\n\x05steps\x18\x02 \x01(\r\x12\x16\n\x0esteps_per_eval\x18\x03 \x01(\r\x12\x1d\n\x15save_checkpoints_secs\x18\x04 \x01(\r\x12\x1a\n\x12save_summary_steps\x18\x05 \x01(\r\x12\x1e\n\x16\x65nable_mixed_precision\x18\x06 \x01(\x08\x12\x19\n\x11loss_scale_factor\x18\x07 \x01(\x02\x12!\n\x19\x63lear_metrics_every_epoch\x18\x08 \x01(\x08\x62\x06proto3')
26 | ,
27 | dependencies=[rslo_dot_protos_dot_optimizer__pb2.DESCRIPTOR,rslo_dot_protos_dot_preprocess__pb2.DESCRIPTOR,])
28 |
29 |
30 |
31 |
32 | _TRAINCONFIG = _descriptor.Descriptor(
33 | name='TrainConfig',
34 | full_name='second.protos.TrainConfig',
35 | filename=None,
36 | file=DESCRIPTOR,
37 | containing_type=None,
38 | fields=[
39 | _descriptor.FieldDescriptor(
40 | name='optimizer', full_name='second.protos.TrainConfig.optimizer', index=0,
41 | number=1, type=11, cpp_type=10, label=1,
42 | has_default_value=False, default_value=None,
43 | message_type=None, enum_type=None, containing_type=None,
44 | is_extension=False, extension_scope=None,
45 | serialized_options=None, file=DESCRIPTOR),
46 | _descriptor.FieldDescriptor(
47 | name='steps', full_name='second.protos.TrainConfig.steps', index=1,
48 | number=2, type=13, cpp_type=3, label=1,
49 | has_default_value=False, default_value=0,
50 | message_type=None, enum_type=None, containing_type=None,
51 | is_extension=False, extension_scope=None,
52 | serialized_options=None, file=DESCRIPTOR),
53 | _descriptor.FieldDescriptor(
54 | name='steps_per_eval', full_name='second.protos.TrainConfig.steps_per_eval', index=2,
55 | number=3, type=13, cpp_type=3, label=1,
56 | has_default_value=False, default_value=0,
57 | message_type=None, enum_type=None, containing_type=None,
58 | is_extension=False, extension_scope=None,
59 | serialized_options=None, file=DESCRIPTOR),
60 | _descriptor.FieldDescriptor(
61 | name='save_checkpoints_secs', full_name='second.protos.TrainConfig.save_checkpoints_secs', index=3,
62 | number=4, type=13, cpp_type=3, label=1,
63 | has_default_value=False, default_value=0,
64 | message_type=None, enum_type=None, containing_type=None,
65 | is_extension=False, extension_scope=None,
66 | serialized_options=None, file=DESCRIPTOR),
67 | _descriptor.FieldDescriptor(
68 | name='save_summary_steps', full_name='second.protos.TrainConfig.save_summary_steps', index=4,
69 | number=5, type=13, cpp_type=3, label=1,
70 | has_default_value=False, default_value=0,
71 | message_type=None, enum_type=None, containing_type=None,
72 | is_extension=False, extension_scope=None,
73 | serialized_options=None, file=DESCRIPTOR),
74 | _descriptor.FieldDescriptor(
75 | name='enable_mixed_precision', full_name='second.protos.TrainConfig.enable_mixed_precision', index=5,
76 | number=6, type=8, cpp_type=7, label=1,
77 | has_default_value=False, default_value=False,
78 | message_type=None, enum_type=None, containing_type=None,
79 | is_extension=False, extension_scope=None,
80 | serialized_options=None, file=DESCRIPTOR),
81 | _descriptor.FieldDescriptor(
82 | name='loss_scale_factor', full_name='second.protos.TrainConfig.loss_scale_factor', index=6,
83 | number=7, type=2, cpp_type=6, label=1,
84 | has_default_value=False, default_value=float(0),
85 | message_type=None, enum_type=None, containing_type=None,
86 | is_extension=False, extension_scope=None,
87 | serialized_options=None, file=DESCRIPTOR),
88 | _descriptor.FieldDescriptor(
89 | name='clear_metrics_every_epoch', full_name='second.protos.TrainConfig.clear_metrics_every_epoch', index=7,
90 | number=8, type=8, cpp_type=7, label=1,
91 | has_default_value=False, default_value=False,
92 | message_type=None, enum_type=None, containing_type=None,
93 | is_extension=False, extension_scope=None,
94 | serialized_options=None, file=DESCRIPTOR),
95 | ],
96 | extensions=[
97 | ],
98 | nested_types=[],
99 | enum_types=[
100 | ],
101 | serialized_options=None,
102 | is_extendable=False,
103 | syntax='proto3',
104 | extension_ranges=[],
105 | oneofs=[
106 | ],
107 | serialized_start=102,
108 | serialized_end=352,
109 | )
110 |
111 | _TRAINCONFIG.fields_by_name['optimizer'].message_type = rslo_dot_protos_dot_optimizer__pb2._OPTIMIZER
112 | DESCRIPTOR.message_types_by_name['TrainConfig'] = _TRAINCONFIG
113 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
114 |
115 | TrainConfig = _reflection.GeneratedProtocolMessageType('TrainConfig', (_message.Message,), {
116 | 'DESCRIPTOR' : _TRAINCONFIG,
117 | '__module__' : 'rslo.protos.train_pb2'
118 | # @@protoc_insertion_point(class_scope:second.protos.TrainConfig)
119 | })
120 | _sym_db.RegisterMessage(TrainConfig)
121 |
122 |
123 | # @@protoc_insertion_point(module_scope)
124 |
--------------------------------------------------------------------------------
/rslo/protos/voxel_generator.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package second.protos;
4 |
// Voxel generator settings. Other messages (e.g. VoxelNet.voxel_generator)
// embed this message.
// NOTE(review): the per-field meanings are not visible in this file; the
// consumer is presumably rslo/builder/voxel_builder.py - confirm there.
message VoxelGenerator{
  repeated float voxel_size = 1;
  repeated float point_cloud_range = 2;
  uint32 max_number_of_points_per_voxel = 3;
  bool full_empty_part_with_mean = 4;
  bool block_filtering = 5;
  int64 block_factor = 6;
  int64 block_size = 7;
  float height_threshold = 8;
}
15 |
--------------------------------------------------------------------------------
/rslo/protos/voxel_generator_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: rslo/protos/voxel_generator.proto
4 |
5 | import sys
6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
7 | from google.protobuf import descriptor as _descriptor
8 | from google.protobuf import message as _message
9 | from google.protobuf import reflection as _reflection
10 | from google.protobuf import symbol_database as _symbol_database
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='rslo/protos/voxel_generator.proto',
20 | package='second.protos',
21 | syntax='proto3',
22 | serialized_options=None,
23 | serialized_pb=_b('\n!rslo/protos/voxel_generator.proto\x12\rsecond.protos\"\xe7\x01\n\x0eVoxelGenerator\x12\x12\n\nvoxel_size\x18\x01 \x03(\x02\x12\x19\n\x11point_cloud_range\x18\x02 \x03(\x02\x12&\n\x1emax_number_of_points_per_voxel\x18\x03 \x01(\r\x12!\n\x19\x66ull_empty_part_with_mean\x18\x04 \x01(\x08\x12\x17\n\x0f\x62lock_filtering\x18\x05 \x01(\x08\x12\x14\n\x0c\x62lock_factor\x18\x06 \x01(\x03\x12\x12\n\nblock_size\x18\x07 \x01(\x03\x12\x18\n\x10height_threshold\x18\x08 \x01(\x02\x62\x06proto3')
24 | )
25 |
26 |
27 |
28 |
29 | _VOXELGENERATOR = _descriptor.Descriptor(
30 | name='VoxelGenerator',
31 | full_name='second.protos.VoxelGenerator',
32 | filename=None,
33 | file=DESCRIPTOR,
34 | containing_type=None,
35 | fields=[
36 | _descriptor.FieldDescriptor(
37 | name='voxel_size', full_name='second.protos.VoxelGenerator.voxel_size', index=0,
38 | number=1, type=2, cpp_type=6, label=3,
39 | has_default_value=False, default_value=[],
40 | message_type=None, enum_type=None, containing_type=None,
41 | is_extension=False, extension_scope=None,
42 | serialized_options=None, file=DESCRIPTOR),
43 | _descriptor.FieldDescriptor(
44 | name='point_cloud_range', full_name='second.protos.VoxelGenerator.point_cloud_range', index=1,
45 | number=2, type=2, cpp_type=6, label=3,
46 | has_default_value=False, default_value=[],
47 | message_type=None, enum_type=None, containing_type=None,
48 | is_extension=False, extension_scope=None,
49 | serialized_options=None, file=DESCRIPTOR),
50 | _descriptor.FieldDescriptor(
51 | name='max_number_of_points_per_voxel', full_name='second.protos.VoxelGenerator.max_number_of_points_per_voxel', index=2,
52 | number=3, type=13, cpp_type=3, label=1,
53 | has_default_value=False, default_value=0,
54 | message_type=None, enum_type=None, containing_type=None,
55 | is_extension=False, extension_scope=None,
56 | serialized_options=None, file=DESCRIPTOR),
57 | _descriptor.FieldDescriptor(
58 | name='full_empty_part_with_mean', full_name='second.protos.VoxelGenerator.full_empty_part_with_mean', index=3,
59 | number=4, type=8, cpp_type=7, label=1,
60 | has_default_value=False, default_value=False,
61 | message_type=None, enum_type=None, containing_type=None,
62 | is_extension=False, extension_scope=None,
63 | serialized_options=None, file=DESCRIPTOR),
64 | _descriptor.FieldDescriptor(
65 | name='block_filtering', full_name='second.protos.VoxelGenerator.block_filtering', index=4,
66 | number=5, type=8, cpp_type=7, label=1,
67 | has_default_value=False, default_value=False,
68 | message_type=None, enum_type=None, containing_type=None,
69 | is_extension=False, extension_scope=None,
70 | serialized_options=None, file=DESCRIPTOR),
71 | _descriptor.FieldDescriptor(
72 | name='block_factor', full_name='second.protos.VoxelGenerator.block_factor', index=5,
73 | number=6, type=3, cpp_type=2, label=1,
74 | has_default_value=False, default_value=0,
75 | message_type=None, enum_type=None, containing_type=None,
76 | is_extension=False, extension_scope=None,
77 | serialized_options=None, file=DESCRIPTOR),
78 | _descriptor.FieldDescriptor(
79 | name='block_size', full_name='second.protos.VoxelGenerator.block_size', index=6,
80 | number=7, type=3, cpp_type=2, label=1,
81 | has_default_value=False, default_value=0,
82 | message_type=None, enum_type=None, containing_type=None,
83 | is_extension=False, extension_scope=None,
84 | serialized_options=None, file=DESCRIPTOR),
85 | _descriptor.FieldDescriptor(
86 | name='height_threshold', full_name='second.protos.VoxelGenerator.height_threshold', index=7,
87 | number=8, type=2, cpp_type=6, label=1,
88 | has_default_value=False, default_value=float(0),
89 | message_type=None, enum_type=None, containing_type=None,
90 | is_extension=False, extension_scope=None,
91 | serialized_options=None, file=DESCRIPTOR),
92 | ],
93 | extensions=[
94 | ],
95 | nested_types=[],
96 | enum_types=[
97 | ],
98 | serialized_options=None,
99 | is_extendable=False,
100 | syntax='proto3',
101 | extension_ranges=[],
102 | oneofs=[
103 | ],
104 | serialized_start=53,
105 | serialized_end=284,
106 | )
107 |
108 | DESCRIPTOR.message_types_by_name['VoxelGenerator'] = _VOXELGENERATOR
109 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
110 |
111 | VoxelGenerator = _reflection.GeneratedProtocolMessageType('VoxelGenerator', (_message.Message,), {
112 | 'DESCRIPTOR' : _VOXELGENERATOR,
113 | '__module__' : 'rslo.protos.voxel_generator_pb2'
114 | # @@protoc_insertion_point(class_scope:second.protos.VoxelGenerator)
115 | })
116 | _sym_db.RegisterMessage(VoxelGenerator)
117 |
118 |
119 | # @@protoc_insertion_point(module_scope)
120 |
--------------------------------------------------------------------------------
/rslo/torchplus/__init__.py:
--------------------------------------------------------------------------------
1 | from . import train
2 | from . import nn
3 | from . import metrics
4 | from . import tools
5 |
6 | from .tools import change_default_args
7 | from torchplus.ops.array_ops import scatter_nd, gather_nd, roll
8 |
--------------------------------------------------------------------------------
/rslo/torchplus/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from torchplus.nn.functional import one_hot
2 | from torchplus.nn.modules.common import Empty, Sequential
3 | from torchplus.nn.modules.normalization import GroupNorm
4 |
--------------------------------------------------------------------------------
/rslo/torchplus/nn/functional.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
def one_hot(tensor, depth, dim=-1, on_value=1.0, dtype=torch.float32):
    """Return a one-hot encoding of a tensor of class indices.

    Args:
        tensor: tensor of class indices. Values are cast to int64 and must
            lie in ``[0, depth)``.
        depth: size of the one-hot axis.
        dim: position of the one-hot axis in the output. Negative values
            count from the end of the *output* shape, so the default ``-1``
            appends the axis last.
        on_value: value written at the selected positions; every other
            entry is zero.
        dtype: dtype of the returned tensor.

    Returns:
        A tensor on ``tensor.device`` whose shape is ``tensor.shape`` with
        ``depth`` inserted at ``dim``.
    """
    # The output has one more axis than the input; normalise ``dim`` against
    # that rank so the zero buffer, the index and the scatter all agree.
    out_rank = tensor.dim() + 1
    axis = dim + out_rank if dim < 0 else dim
    out_shape = list(tensor.shape)
    out_shape.insert(axis, depth)
    tensor_onehot = torch.zeros(
        *out_shape, dtype=dtype, device=tensor.device)
    tensor_onehot.scatter_(axis, tensor.unsqueeze(axis).long(), on_value)
    return tensor_onehot
8 |
--------------------------------------------------------------------------------
/rslo/torchplus/nn/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/rslo/torchplus/nn/modules/__init__.py
--------------------------------------------------------------------------------
/rslo/torchplus/nn/modules/common.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from collections import OrderedDict
3 |
4 | import torch
5 | from torch.nn import functional as F
6 |
7 |
class Empty(torch.nn.Module):
    """Pass-through module that hands its positional inputs back unchanged.

    Constructor arguments are accepted and ignored, so the class can stand
    in for a layer whatever that layer's signature is.
    """

    def __init__(self, *args, **kwargs):
        super(Empty, self).__init__()
        # Dummy attribute: a plain tensor, not a registered parameter.
        self.weight = torch.zeros([1])

    def forward(self, *args, **kwargs):
        """Return None for no inputs, the input itself for one, else the tuple."""
        count = len(args)
        if count == 0:
            return None
        if count == 1:
            return args[0]
        return args
19 |
20 |
class Sequential(torch.nn.Module):
    r"""Container that applies its child modules one after another.

    Children are registered in the order they are given. They can be passed
    positionally (named ``"0"``, ``"1"``, ...), as a single ``OrderedDict``
    of ``name -> module``, or as keyword arguments (Python 3.6+)::

        # Positional modules
        model = Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # A single OrderedDict
        model = Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))

        # Keyword arguments (python 3.6+)
        model = Sequential(
                  conv1=nn.Conv2d(1,20,5),
                  relu1=nn.ReLU(),
                  conv2=nn.Conv2d(20,64,5),
                  relu2=nn.ReLU()
                )
    """

    def __init__(self, *args, **kwargs):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            named_layers = args[0].items()
        else:
            named_layers = (
                (str(position), layer) for position, layer in enumerate(args))
        for key, layer in named_layers:
            self.add_module(key, layer)

        # Keyword modules are appended after the positional ones.
        for key, layer in kwargs.items():
            if sys.version_info < (3, 6):
                raise ValueError("kwargs only supported in py36+")
            if key in self._modules:
                raise ValueError("name exists.")
            self.add_module(key, layer)

    def __getitem__(self, idx):
        """Return the child at integer position ``idx`` (negative allowed)."""
        size = len(self)
        if idx < -size or idx >= size:
            raise IndexError('index {} is out of range'.format(idx))
        position = idx + size if idx < 0 else idx
        return list(self._modules.values())[position]

    def __len__(self):
        """Number of registered child modules."""
        return len(self._modules)

    def add(self, module, name=None):
        """Append ``module``; the default name is the current child count.

        Raises:
            KeyError: if a child with that name is already registered.
        """
        key = str(len(self._modules)) if name is None else name
        if key in self._modules:
            raise KeyError("name exists")
        self.add_module(key, module)

    def forward(self, input):
        """Feed ``input`` through every child in registration order."""
        output = input
        for layer in self._modules.values():
            output = layer(output)
        return output
95 |
--------------------------------------------------------------------------------
/rslo/torchplus/nn/modules/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class GroupNorm(torch.nn.GroupNorm):
    """``torch.nn.GroupNorm`` with the channel count as the first argument.

    The parent class takes ``(num_groups, num_channels)``; this wrapper
    takes ``(num_channels, num_groups)`` and forwards them in the parent's
    order.
    """

    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):
        super(GroupNorm, self).__init__(
            num_groups, num_channels, eps=eps, affine=affine)
11 |
--------------------------------------------------------------------------------
/rslo/torchplus/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/rslo/torchplus/ops/__init__.py
--------------------------------------------------------------------------------
/rslo/torchplus/ops/array_ops.py:
--------------------------------------------------------------------------------
1 | import ctypes
2 | import math
3 | import time
4 | import torch
5 | from typing import Optional
6 |
7 |
def scatter_nd(indices, updates, shape):
    """PyTorch version of TensorFlow's ``scatter_nd``.

    Builds a zero tensor of ``shape`` and writes ``updates`` at the
    positions given by ``indices``. No validation is performed, so use it
    carefully. When an index is repeated the updates are NOT accumulated
    (TensorFlow sums them): one of the values simply ends up stored.

    Args:
        indices: integer tensor of shape ``[..., K]``; the last axis holds
            coordinates into the first ``K`` axes of the result.
        updates: tensor with ``indices.shape[:-1] + shape[K:]`` elements.
        shape: output shape; any sequence of ints (list, tuple or
            ``torch.Size``).

    Returns:
        Tensor of ``shape`` with the dtype and device of ``updates``.
    """
    # Accept tuples and torch.Size too: the slice below must be a list.
    shape = list(shape)
    ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
    ndim = indices.shape[-1]
    flatted_indices = indices.view(-1, ndim)
    # Index with a tuple: multidimensional indexing with a list is deprecated.
    slices = tuple(flatted_indices[:, i] for i in range(ndim)) + (Ellipsis,)
    # The indices were flattened to one leading axis, so flatten the leading
    # axes of the updates the same way; this keeps both sides shape
    # compatible when ``indices`` has more than two dimensions.
    ret[slices] = updates.view(-1, *shape[ndim:])
    return ret
22 |
23 |
def gather_nd(params, indices):
    """PyTorch version of TensorFlow's ``gather_nd``.

    Args:
        params: tensor to read from.
        indices: integer tensor of shape ``[..., K]``; the last axis holds
            coordinates into the first ``K`` axes of ``params``.

    Returns:
        Tensor of shape ``indices.shape[:-1] + params.shape[K:]``.
    """
    # this function has a limit that MAX_ADVINDEX_CALC_DIMS=5
    ndim = indices.shape[-1]
    output_shape = list(indices.shape[:-1]) + list(params.shape[ndim:])
    flatted_indices = indices.view(-1, ndim)
    # Index with a tuple: multidimensional indexing with a list is deprecated.
    slices = tuple(flatted_indices[:, i] for i in range(ndim)) + (Ellipsis,)
    # Pass the shape as one sequence so an empty shape (scalar result) works.
    return params[slices].reshape(tuple(output_shape))
32 |
33 |
def roll(x: torch.Tensor, shift: int, dim: int = -1, fill_pad: Optional[int] = None):
    """Shift ``x`` along ``dim`` by ``shift`` positions.

    A positive ``shift`` moves elements towards higher indices, a negative
    one towards lower indices. Elements pushed past the end wrap around to
    the other side, unless ``fill_pad`` is given, in which case the vacated
    positions hold ``fill_pad`` instead. ``shift == 0`` returns ``x`` itself.
    """
    if shift == 0:
        return x

    device = x.device
    length = x.size(dim)

    def _take(lo, hi):
        # Copy of positions lo..hi-1 along ``dim``.
        return x.index_select(dim, torch.arange(lo, hi, device=device))

    def _vacated(block):
        # Wrapped-around values, or a constant block when padding is requested.
        if fill_pad is None:
            return block
        return fill_pad * torch.ones_like(block, device=device)

    if shift < 0:
        cut = -shift
        kept = _take(cut, length)
        return torch.cat([kept, _vacated(_take(0, cut))], dim=dim)

    cut = length - shift
    moved = _vacated(_take(cut, length))
    return torch.cat([moved, _take(0, cut)], dim=dim)
--------------------------------------------------------------------------------
/rslo/torchplus/tools.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import inspect
3 | import sys
4 | from collections import OrderedDict
5 |
6 | import numba
7 | import numpy as np
8 | import torch
9 |
10 |
def get_pos_to_kw_map(func):
    """Map parameter position -> name for ``func``'s positional-or-keyword parameters.

    Positions are counted over every parameter in the signature; parameters
    of any other kind (``*args``, keyword-only, ``**kwargs``) are left out
    of the result.
    """
    parameters = inspect.signature(func).parameters.values()
    return {
        position: param.name
        for position, param in enumerate(parameters)
        if param.kind is param.POSITIONAL_OR_KEYWORD
    }
20 |
21 |
def get_kw_to_default_map(func):
    """Map name -> default for ``func``'s positional-or-keyword parameters.

    Parameters without a default, and parameters of any other kind
    (``*args``, keyword-only, ``**kwargs``), are left out.
    """
    return {
        name: param.default
        for name, param in inspect.signature(func).parameters.items()
        if param.kind is param.POSITIONAL_OR_KEYWORD
        and param.default is not param.empty
    }
30 |
31 |
32 | # def change_default_args(**kwargs):
33 | # def layer_wrapper(layer_class):
34 | # class DefaultArgLayer(layer_class):
35 | # def __init__(self, *args, **kw):
36 | # pos_to_kw = get_pos_to_kw_map(layer_class.__init__)
37 | # kw_to_pos = {kw: pos for pos, kw in pos_to_kw.items()}
38 | # for key, val in kwargs.items():
39 | # if key not in kw and kw_to_pos[key] > len(args):
40 | # kw[key] = val
41 | # super().__init__(*args, **kw)
42 |
43 | # return DefaultArgLayer
44 |
45 | # return layer_wrapper
46 |
def change_default_args(**kwargs):
    """Build a class decorator that overrides constructor defaults.

    The returned wrapper subclasses ``layer_class``. When the subclass is
    instantiated, each ``key=value`` given here is added to the constructor
    keyword arguments unless the caller already supplied that argument,
    either by keyword or positionally.

    Every key must be a positional-or-keyword parameter of
    ``layer_class.__init__``; any other key raises ``KeyError`` at
    instantiation time.
    """
    def layer_wrapper(layer_class):
        class DefaultArgLayer(layer_class):
            def __init__(self, *args, **kw):
                # Signature positions include ``self`` while ``args`` does
                # not, so a parameter was passed positionally exactly when
                # its position is <= len(args).
                slot_of = {
                    param: slot
                    for slot, param in get_pos_to_kw_map(
                        layer_class.__init__).items()
                }
                supplied = len(args)
                for key, val in kwargs.items():
                    if key in kw:
                        continue
                    if slot_of[key] > supplied:
                        kw[key] = val
                super(DefaultArgLayer, self).__init__(*args, **kw)

        return DefaultArgLayer

    return layer_wrapper
def torch_to_np_dtype(ttype):
    """Return the numpy dtype that corresponds to a torch dtype.

    Args:
        ttype: one of ``torch.float16``, ``torch.float32``,
            ``torch.float64``, ``torch.int32``, ``torch.int64`` or
            ``torch.uint8``.

    Returns:
        The matching ``numpy.dtype`` instance.

    Raises:
        KeyError: if ``ttype`` is not one of the supported dtypes.
    """
    type_map = {
        torch.float16: np.dtype(np.float16),
        torch.float32: np.dtype(np.float32),
        # This key used to be a second ``torch.float16`` entry, which mapped
        # float16 to float64 and left torch.float64 unsupported.
        torch.float64: np.dtype(np.float64),
        torch.int32: np.dtype(np.int32),
        torch.int64: np.dtype(np.int64),
        torch.uint8: np.dtype(np.uint8),
    }
    return type_map[ttype]
71 |
--------------------------------------------------------------------------------
/rslo/torchplus/train/__init__.py:
--------------------------------------------------------------------------------
1 | from torchplus.train.checkpoint import (latest_checkpoint, restore,
2 | restore_latest_checkpoints,
3 | restore_models, save, save_models,
4 | try_restore_latest_checkpoints,
5 | save_models_cpu
6 | )
7 | from torchplus.train.common import create_folder
8 | from torchplus.train.optim import MixedPrecisionWrapper
9 |
--------------------------------------------------------------------------------
/rslo/torchplus/train/common.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import shutil
4 |
def create_folder(prefix, add_time=True, add_str=None, delete=False):
    """Build an output folder path under ``prefix`` and optionally (re)create it.

    Args:
        prefix: base directory path.
        add_time: when True, append a ``/<yymmdd_HHMMSS>`` component (for
            example ``170903_220351``) to the path.
        add_str: optional suffix joined to the timestamp with ``_``; only
            used when ``add_time`` is True.
        delete: when True, ``prefix`` and the final folder are each removed
            if present and created afresh. When False nothing is created or
            removed on disk.

    Returns:
        The resulting folder path.
    """
    def _recreate(path):
        # Wipe any existing tree, then create an empty directory.
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    wipe = delete is True
    if wipe:
        _recreate(prefix)

    folder = prefix
    if add_time is True:
        stamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
        leaf = stamp if add_str is None else stamp + '_' + add_str
        folder += '/' + leaf

    if wipe:
        _recreate(folder)
    return folder
--------------------------------------------------------------------------------
/rslo/torchplus/train/learning_schedules_fastai.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | from functools import partial
4 | import torch
5 |
6 |
class LRSchedulerStep(object):
    """Piecewise learning-rate / momentum schedule driven by the global step.

    Each schedule is a sequence of ``(start, func)`` phases. ``start`` is a
    fraction of ``total_step`` at which the phase begins (the first must be
    0 and the values must be strictly increasing once converted to steps).
    ``func`` maps the progress through the phase to the value to apply; it
    may also be a string that evaluates to such a callable. A phase ends
    where the next one begins, and the last phase ends at ``total_step``.

    Args:
        fai_optimizer: object whose ``lr`` and ``mom`` attributes are
            assigned by :meth:`step`.
        total_step: total number of steps the schedule spans.
        lr_phases: learning-rate phases; must not be empty.
        mom_phases: momentum phases; may be empty.
    """

    def __init__(self, fai_optimizer, total_step, lr_phases, mom_phases):
        self.optimizer = fai_optimizer
        self.total_step = total_step
        self.lr_phases = []

        for i, (start, lambda_func) in enumerate(lr_phases):
            if len(self.lr_phases) != 0:
                assert self.lr_phases[-1][0] < int(start * total_step)
            if isinstance(lambda_func, str):
                # NOTE(review): eval executes arbitrary code; only pass
                # strings that come from trusted configuration.
                lambda_func = eval(lambda_func)
            if i < len(lr_phases) - 1:
                self.lr_phases.append((int(start * total_step),
                                       int(lr_phases[i + 1][0] * total_step),
                                       lambda_func))
            else:
                self.lr_phases.append((int(start * total_step), total_step,
                                       lambda_func))
        assert self.lr_phases[0][0] == 0
        self.mom_phases = []

        for i, (start, lambda_func) in enumerate(mom_phases):
            if len(self.mom_phases) != 0:
                assert self.mom_phases[-1][0] < int(start * total_step)
            if isinstance(lambda_func, str):
                # NOTE(review): same eval caveat as for lr_phases above.
                lambda_func = eval(lambda_func)
            if i < len(mom_phases) - 1:
                self.mom_phases.append((int(start * total_step),
                                        int(mom_phases[i + 1][0] * total_step),
                                        lambda_func))
            else:
                self.mom_phases.append((int(start * total_step), total_step,
                                        lambda_func))
        if len(mom_phases) > 0:
            assert self.mom_phases[0][0] == 0

    @staticmethod
    def _active_phase(phases, step):
        """Return the last ``(start, end, func)`` phase begun by ``step``, or None."""
        active = None
        for phase in phases:
            if step >= phase[0]:
                active = phase
        return active

    @staticmethod
    def _phase_value(phase, step):
        """Evaluate a phase function at the progress reached at ``step``."""
        start, end, func = phase
        span = end - start
        # A phase that starts exactly at total_step has zero length; report
        # its progress as 0.0 instead of dividing by zero.
        pct = (step - start) / span if span != 0 else 0.0
        return func(pct)

    def step(self, step):
        """Set the optimizer's ``lr`` and ``mom`` for the given global step.

        Only the phase active at ``step`` is evaluated, once per schedule.
        Earlier phases can no longer influence the result, so they are not
        called.
        """
        lr_phase = self._active_phase(self.lr_phases, step)
        if lr_phase is not None:
            self.optimizer.lr = self._phase_value(lr_phase, step)

        mom_phase = self._active_phase(self.mom_phases, step)
        if mom_phase is not None:
            self.optimizer.mom = self._phase_value(mom_phase, step)

    @property
    def learning_rate(self):
        """Learning rate currently stored on the wrapped optimizer."""
        return self.optimizer.lr
66 |
67 |
def annealing_cos(start, end, pct):
    """Cosine-anneal from `start` (pct = 0.0) to `end` (pct = 1.0)."""
    # cos(pi * pct) + 1 falls from 2 to 0 as pct goes from 0 to 1.
    weight = np.cos(np.pi * pct) + 1
    half_range = (start - end) / 2
    return end + half_range * weight
73 |
74 |
class OneCycle(LRSchedulerStep):
    """Two-phase cosine schedule for learning rate and momentum.

    The learning rate rises from ``lr_max / div_factor`` to ``lr_max`` during
    the first ``pct_start`` fraction of training, then anneals down to
    ``lr_max / div_factor / 1e4``. Momentum goes from ``moms[0]`` to
    ``moms[1]`` and back over the same two phases.
    """

    def __init__(self, fai_optimizer, total_step, lr_max, moms, div_factor,
                 pct_start):
        self.lr_max = lr_max
        self.moms = moms
        self.div_factor = div_factor
        self.pct_start = pct_start

        lr_floor = self.lr_max / self.div_factor
        lr_rise = partial(annealing_cos, lr_floor, self.lr_max)
        lr_fall = partial(annealing_cos, self.lr_max, lr_floor / 1e4)
        mom_first = partial(annealing_cos, *self.moms)
        mom_second = partial(annealing_cos, *self.moms[::-1])

        lr_phases = ((0, lr_rise), (self.pct_start, lr_fall))
        mom_phases = ((0, mom_first), (self.pct_start, mom_second))

        # Put the optimizer at the schedule's starting point right away.
        fai_optimizer.lr, fai_optimizer.mom = lr_floor, self.moms[0]

        super().__init__(fai_optimizer, total_step, lr_phases, mom_phases)
95 |
96 |
class ExponentialDecayWarmup(LRSchedulerStep):
    """Exponential learning-rate decay preceded by an optional cosine warm-up."""

    def __init__(self,
                 fai_optimizer,
                 total_step,
                 initial_learning_rate,
                 decay_length,
                 decay_factor,
                 div_factor=1,
                 pct_start=0,
                 staircase=True):
        """
        Args:
            fai_optimizer: optimizer-like object receiving ``lr`` updates.
            total_step: total number of training steps.
            initial_learning_rate: learning rate reached at the end of warm-up.
            decay_length: must in (0, 1); fraction of ``total_step`` between
                two decays.
            decay_factor: multiplicative factor applied per decay period.
            div_factor: warm-up starts at ``initial_learning_rate / div_factor``.
            pct_start: fraction of training spent in warm-up.
            staircase: decay in discrete stages instead of continuously.
        """
        assert decay_length > 0
        assert decay_length < 1
        self._decay_steps_unified = decay_length
        self._decay_factor = decay_factor
        self._staircase = staircase
        self.div_factor = div_factor
        self.pct_start = pct_start

        lr_phases = []
        # Only add the warm-up phase when it covers at least one step.
        # Otherwise it would start at the same step as the first decay phase
        # and trip the strictly-increasing assert in LRSchedulerStep.
        if int(pct_start * total_step) > 0:
            lr_phases.append(
                (0, partial(annealing_cos,
                            initial_learning_rate / div_factor,
                            initial_learning_rate)))
        if staircase:
            # Clamp the stride so the loop ends even if
            # decay_length * total_step < 1.
            stride = max(1, int(decay_length * total_step))
            step = pct_start * total_step
            stage = 1
            # Strict '<': a phase starting at total_step would have zero length.
            while step < total_step:
                func = lambda p, _d=initial_learning_rate * stage: _d
                lr_phases.append((step / total_step, func))
                stage *= decay_factor
                step += stride
        else:
            # Scale by the initial learning rate; without it the schedule
            # would start at 1.0 regardless of the configured rate.
            # NOTE(review): p is relative to this phase (the part of training
            # after pct_start), not to the whole run -- confirm intended scale.
            def func(p, _lr0=initial_learning_rate):
                return _lr0 * pow(decay_factor, (p / decay_length))
            lr_phases.append((pct_start, func))
        super().__init__(fai_optimizer, total_step, lr_phases, [])
133 |
134 |
class ExponentialDecay(LRSchedulerStep):
    """Exponential learning-rate decay, either staircase or continuous."""

    def __init__(self,
                 fai_optimizer,
                 total_step,
                 initial_learning_rate,
                 decay_length,
                 decay_factor,
                 staircase=True):
        """
        Args:
            fai_optimizer: optimizer-like object receiving ``lr`` updates.
            total_step: total number of training steps.
            initial_learning_rate: learning rate at step 0.
            decay_length: must in (0, 1); fraction of ``total_step`` between
                two decays.
            decay_factor: multiplicative factor applied per decay period.
            staircase: decay in discrete stages instead of continuously.
        """
        assert decay_length > 0
        assert decay_length < 1
        self._decay_steps_unified = decay_length
        self._decay_factor = decay_factor
        self._staircase = staircase
        lr_phases = []
        if staircase:
            # Clamp the stride so the loop ends even if
            # decay_length * total_step < 1.
            stride = max(1, int(decay_length * total_step))
            step = 0
            stage = 1
            # Strict '<': a phase starting at total_step would have zero length.
            while step < total_step:
                func = lambda p, _d=initial_learning_rate * stage: _d
                lr_phases.append((step / total_step, func))
                stage *= decay_factor
                step += stride
        else:
            # The single phase spans the whole run, so p / decay_length is
            # the number of elapsed decay periods. Scale by the initial
            # learning rate, which the schedule previously ignored.
            def func(p, _lr0=initial_learning_rate):
                return _lr0 * pow(decay_factor, (p / decay_length))
            lr_phases.append((0, func))
        super().__init__(fai_optimizer, total_step, lr_phases, [])
165 |
166 |
class ManualStepping(LRSchedulerStep):
    """Piecewise-constant learning rate.

    Args:
        fai_optimizer: optimizer-like object receiving ``lr`` updates.
        total_step: total number of training steps.
        boundaries: fractions in (0, 1) at which the rate changes.
        rates: learning rates, one more than ``boundaries``; ``rates[0]``
            applies from step 0.
    """

    def __init__(self, fai_optimizer, total_step, boundaries, rates):
        assert all([b > 0 and b < 1 for b in boundaries])
        assert len(boundaries) + 1 == len(rates)
        # Build a new list instead of inserting into the caller's sequence,
        # which may be a shared list or a config field.
        starts = [0.0] + list(boundaries)
        lr_phases = []
        for start, rate in zip(starts, rates):
            # Bind the rate as a default so each phase keeps its own value.
            def func(p, _d=rate): return _d
            lr_phases.append((start, func))
        super().__init__(fai_optimizer, total_step, lr_phases, [])
177 |
178 |
class FakeOptim:
    """Minimal optimizer stand-in exposing only ``lr`` and ``mom``."""

    def __init__(self):
        self.lr, self.mom = 0, 0
183 |
184 |
if __name__ == "__main__":
    # Manual check: plot the learning rate a schedule produces over 100 steps.
    import matplotlib.pyplot as plt

    opt = FakeOptim()  # 3e-3, wd=0.4, div_factor=10
    # schd = OneCycle(opt, 100, 3e-3, (0.95, 0.85), 10.0, 0.4)
    schd = ExponentialDecay(opt, 100, 3e-4, 0.1, 0.8, staircase=True)
    # The schedule above is replaced immediately; only this one is plotted.
    schd = ManualStepping(opt, 100, [0.8, 0.9], [0.001, 0.0001, 0.00005])

    lrs, moms = [], []
    for i in range(100):
        schd.step(i)
        lrs.append(opt.lr)
        moms.append(opt.mom)

    plt.plot(lrs)
    # plt.plot(moms)
    plt.show()
203 |
--------------------------------------------------------------------------------
/rslo/torchplus/train/optim.py:
--------------------------------------------------------------------------------
from collections import defaultdict
# Iterable lives in collections.abc; the alias in `collections` was removed
# in Python 3.10, so importing it from there raises ImportError.
from collections.abc import Iterable

import torch
from copy import deepcopy
from itertools import chain
from torch.autograd import Variable

# Unique sentinel object (not referenced in the visible part of this module).
required = object()
9 |
def param_fp32_copy(params):
    """Return detached ``torch.cuda.FloatTensor`` clones of ``params``.

    Every clone has ``requires_grad`` set to True.
    """
    copies = []
    for source in params:
        fp32 = source.clone().type(torch.cuda.FloatTensor).detach()
        fp32.requires_grad = True
        copies.append(fp32)
    return copies
17 |
def set_grad(params, params_with_grad, scale=1.0):
    """Copy (and unscale) gradients from ``params_with_grad`` into ``params``.

    Args:
        params: tensors that receive the gradients in their ``.grad``.
        params_with_grad: tensors, aligned with ``params``, whose ``.grad``
            holds the (possibly loss-scaled) gradients.
        scale: loss scale to divide by; ``None`` skips the division.

    Returns:
        True as soon as a gradient containing NaN/Inf is found (remaining
        parameters are left untouched), False otherwise.
    """
    for param, param_w_grad in zip(params, params_with_grad):
        if param.grad is None:
            # A zero-filled buffer instead of an uninitialised Parameter.
            param.grad = torch.zeros_like(param.data)
        grad = param_w_grad.grad.data
        if scale is not None:
            # Unscale in the destination dtype and out of place. Dividing the
            # source gradient in place would modify the model's own grad and,
            # for half-precision sources, could flush small values to zero.
            grad = grad.to(param.grad.dtype) / scale
        if torch.isnan(grad).any() or torch.isinf(grad).any():
            return True  # invalid grad
        param.grad.data.copy_(grad)
    return False
30 |
class MixedPrecisionWrapper(object):
    """mixed precision optimizer wrapper.
    Arguments:
        optimizer (torch.optim.Optimizer): an instance of
            :class:`torch.optim.Optimizer`
        scale: (float): a scalar for grad scale.
        auto_scale: (bool): whether enable auto scale.
            The algorithm of auto scale is described in
            http://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
        inc_factor: (float): factor applied to the scale after a stable run.
        dec_factor: (float): factor applied to the scale on NaN/Inf.
        num_iters_be_stable: (int): number of valid steps after which the
            scale is increased.
    """

    def __init__(self,
                 optimizer,
                 scale=None,
                 auto_scale=True,
                 inc_factor=2.0,
                 dec_factor=0.5,
                 num_iters_be_stable=500):
        # if not isinstance(optimizer, torch.optim.Optimizer):
        #     raise ValueError("must provide a torch.optim.Optimizer")
        self.optimizer = optimizer
        if hasattr(self.optimizer, 'name'):
            self.name = self.optimizer.name  # for ckpt system
        param_groups_copy = []
        for i, group in enumerate(optimizer.param_groups):
            group_copy = {n: v for n, v in group.items() if n != 'params'}
            group_copy['params'] = param_fp32_copy(group['params'])
            param_groups_copy.append(group_copy)

        # switch param_groups, may be dangerous: the wrapped optimizer now
        # updates the fp32 copies, while this wrapper keeps the originals.
        self.param_groups = optimizer.param_groups
        optimizer.param_groups = param_groups_copy
        self.grad_scale = scale
        self.auto_scale = auto_scale
        self.inc_factor = inc_factor
        self.dec_factor = dec_factor
        self.stable_iter_count = 0
        self.num_iters_be_stable = num_iters_be_stable

    def __getstate__(self):
        return self.optimizer.__getstate__()

    def __setstate__(self, state):
        return self.optimizer.__setstate__(state)

    def __repr__(self):
        return self.optimizer.__repr__()

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        return self.optimizer.load_state_dict(state_dict)

    def zero_grad(self):
        return self.optimizer.zero_grad()

    def step(self, closure=None):
        """Unscale gradients, run the wrapped optimizer, sync weights back.

        If any gradient is NaN/Inf the update is skipped and, with
        auto-scaling enabled, the scale is multiplied by ``dec_factor``.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns, or None when
            the update was skipped.
        """
        for g, g_copy in zip(self.param_groups, self.optimizer.param_groups):
            invalid = set_grad(g_copy['params'], g['params'], self.grad_scale)
            if invalid:
                if self.grad_scale is None or self.auto_scale is False:
                    raise ValueError("nan/inf detected but auto_scale disabled.")
                self.grad_scale *= self.dec_factor
                # An overflow ends the current stable run, so the scale is
                # not raised again right after it was lowered.
                self.stable_iter_count = 0
                print('scale decay to {}'.format(self.grad_scale))
                return
        if self.auto_scale is True:
            self.stable_iter_count += 1
            if self.stable_iter_count > self.num_iters_be_stable:
                if self.grad_scale is not None:
                    self.grad_scale *= self.inc_factor
                self.stable_iter_count = 0

        if closure is None:
            loss = self.optimizer.step()
        else:
            loss = self.optimizer.step(closure)
        # Copy the updated fp32 weights back into the original parameters.
        for g, g_copy in zip(self.param_groups, self.optimizer.param_groups):
            for p_copy, p in zip(g_copy['params'], g['params']):
                p.data.copy_(p_copy.data)
        return loss
111 |
--------------------------------------------------------------------------------
/rslo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import time
2 | import contextlib
3 | import torch
4 |
@contextlib.contextmanager
def torch_timer(name=''):
    """Context manager printing the wall-clock time of the enclosed block.

    ``torch.cuda.synchronize()`` is called before starting and before
    stopping the clock.
    """
    torch.cuda.synchronize()
    started = time.time()
    yield
    torch.cuda.synchronize()
    elapsed = time.time() - started
    print(name, "time:", elapsed)
--------------------------------------------------------------------------------
/rslo/utils/check.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
def is_array_like(x):
    """Return True when ``x`` is a list, a tuple or a numpy array."""
    array_types = (list, tuple, np.ndarray)
    return isinstance(x, array_types)
5 |
def shape_mergeable(x, expected_shape):
    """Check ``x``'s shape against ``expected_shape`` (``None`` = any size).

    Returns False only when both arguments are array-like, have the same
    rank, and some non-``None`` expected dimension differs from the actual
    one. Every other case returns True.
    """
    if not (is_array_like(x) and is_array_like(expected_shape)):
        return True
    actual_shape = np.array(x).shape
    if len(actual_shape) != len(expected_shape):
        return True
    mismatch = any(want is not None and got != want
                   for got, want in zip(actual_shape, expected_shape))
    return not mismatch
--------------------------------------------------------------------------------
/rslo/utils/config_tool.py:
--------------------------------------------------------------------------------
1 | # This file contains some config modification function.
2 | # some functions should be only used for KITTI dataset.
3 |
4 | from google.protobuf import text_format
5 | from second.protos import pipeline_pb2, second_pb2
6 | from pathlib import Path
7 | import numpy as np
8 |
9 |
def change_detection_range(model_config, new_range):
    """Rewrite the x/y limits in ``model_config`` to ``new_range``, in place.

    ``new_range`` is ``[x_min, y_min, x_max, y_max]``. Elements 0, 1, 3 and 4
    of the point-cloud range, every range-type anchor generator and the
    post-center limit range are replaced; stride-type anchor generators get
    their first two offsets recomputed from the new minimum corner.
    """
    assert len(new_range) == 4, "you must provide a list such as [-50, -50, 50, 50]"

    def _with_xy_limits(values):
        # Elements 0-1 take the new minima, elements 3-4 the new maxima.
        patched = list(values)
        patched[:2] = new_range[:2]
        patched[3:5] = new_range[2:]
        return patched

    voxel_cfg = model_config.voxel_generator
    voxel_cfg.point_cloud_range[:] = _with_xy_limits(voxel_cfg.point_cloud_range)

    for anchor_generator in model_config.target_assigner.anchor_generators:
        a_type = anchor_generator.WhichOneof('anchor_generator')
        if a_type == "anchor_generator_range":
            range_cfg = anchor_generator.anchor_generator_range
            range_cfg.anchor_ranges[:] = _with_xy_limits(range_cfg.anchor_ranges)
        elif a_type == "anchor_generator_stride":
            stride_cfg = anchor_generator.anchor_generator_stride
            offsets = list(stride_cfg.offsets)
            strides = list(stride_cfg.strides)
            offsets[0] = new_range[0] + strides[0] / 2
            offsets[1] = new_range[1] + strides[1] / 2
            stride_cfg.offsets[:] = offsets
        else:
            raise ValueError("unknown")

    model_config.post_center_limit_range[:] = _with_xy_limits(
        model_config.post_center_limit_range)
37 |
def get_downsample_factor(model_config):
    """Return the overall integer downsample factor of the model.

    Computed as the product of the RPN layer strides, divided by the last
    RPN upsample stride (when any are configured) and multiplied by the
    middle feature extractor's downsample factor.
    """
    downsample_factor = np.prod(model_config.rpn.layer_strides)
    if len(model_config.rpn.upsample_strides) > 0:
        downsample_factor /= model_config.rpn.upsample_strides[-1]
    downsample_factor *= model_config.middle_feature_extractor.downsample_factor
    # Round rather than truncate: the division above is floating point, so a
    # value that should be an integer can land just below it (e.g. 0.999...).
    downsample_factor = int(np.round(downsample_factor))
    assert downsample_factor > 0
    return downsample_factor
46 |
47 |
if __name__ == "__main__":
    # Manual check: load a pipeline config, change its range and print it.
    config_path = "/home/yy/deeplearning/deeplearning/mypackages/second/configs/car.lite.1.config"
    config = pipeline_pb2.TrainEvalPipelineConfig()

    with open(config_path, "r") as f:
        text_format.Merge(f.read(), config)

    change_detection_range(config, [-50, -50, 50, 50])
    print(text_format.MessageToString(config, indent=2))
59 |
60 |
--------------------------------------------------------------------------------
/rslo/utils/config_tool/__init__.py:
--------------------------------------------------------------------------------
1 | # This file contains some config modification function.
2 | # some functions should be only used for KITTI dataset.
3 |
4 | from google.protobuf import text_format
5 | from rslo.protos import pipeline_pb2, second_pb2
6 | from pathlib import Path
7 | import numpy as np
8 |
9 |
def read_config(path):
    """Parse the text-format protobuf at ``path`` into a TrainEvalPipelineConfig."""
    with open(path, "r") as f:
        proto_text = f.read()
    parsed = pipeline_pb2.TrainEvalPipelineConfig()
    text_format.Merge(proto_text, parsed)
    return parsed
17 |
18 |
def change_detection_range(model_config, new_range):
    """Apply new x/y limits ``[x_min, y_min, x_max, y_max]`` to ``model_config``.

    The point-cloud range, each range-type anchor generator and the
    post-center limit range have elements 0, 1, 3 and 4 replaced; stride-type
    anchor generators get their first two offsets recomputed. The config is
    modified in place.
    """
    assert len(
        new_range) == 4, "you must provide a list such as [-50, -50, 50, 50]"
    lower_xy, upper_xy = new_range[:2], new_range[2:]

    def _patch(repeated_field):
        # Copy out, overwrite the x/y entries, then write the list back.
        values = list(repeated_field)
        values[:2] = lower_xy
        values[3:5] = upper_xy
        repeated_field[:] = values

    _patch(model_config.voxel_generator.point_cloud_range)
    for anchor_generator in model_config.target_assigner.anchor_generators:
        a_type = anchor_generator.WhichOneof('anchor_generator')
        if a_type == "anchor_generator_range":
            _patch(anchor_generator.anchor_generator_range.anchor_ranges)
        elif a_type == "anchor_generator_stride":
            a_cfg = anchor_generator.anchor_generator_stride
            new_offsets = list(a_cfg.offsets)
            stride = list(a_cfg.strides)
            for axis in (0, 1):
                new_offsets[axis] = new_range[axis] + stride[axis] / 2
            a_cfg.offsets[:] = new_offsets
        else:
            raise ValueError("unknown")
    _patch(model_config.post_center_limit_range)
47 |
48 |
def get_downsample_factor(model_config):
    """Return the overall downsample factor of the model as ``np.int64``.

    Product of the RPN layer strides, divided by the last RPN upsample stride
    (when present) and multiplied by the middle feature extractor's
    downsample factor, rounded to the nearest integer.
    """
    rpn_cfg = model_config.rpn
    factor = np.prod(rpn_cfg.layer_strides)
    if len(rpn_cfg.upsample_strides) > 0:
        factor /= rpn_cfg.upsample_strides[-1]
    factor *= model_config.middle_feature_extractor.downsample_factor
    factor = np.round(factor).astype(np.int64)
    assert factor > 0
    return factor
57 |
58 |
if __name__ == "__main__":
    # Manual check: load a pipeline config, change its range and print it.
    config_path = "/home/yy/deeplearning/deeplearning/mypackages/second/configs/car.lite.1.config"
    config = pipeline_pb2.TrainEvalPipelineConfig()

    with open(config_path, "r") as f:
        text_format.Merge(f.read(), config)

    change_detection_range(config, [-50, -50, 50, 50])
    print(text_format.MessageToString(config, indent=2))
70 |
--------------------------------------------------------------------------------
/rslo/utils/config_tool/train.py:
--------------------------------------------------------------------------------
1 | from second.protos.optimizer_pb2 import Optimizer, LearningRate, OneCycle, ManualStepping, ExponentialDecay
2 | from second.protos.sampler_pb2 import Sampler
3 | from second.utils.config_tool import read_config
4 | from pathlib import Path
5 | from google.protobuf import text_format
6 | from second.data.all_dataset import get_dataset_class
7 |
8 | def _get_optim_cfg(train_config, optim):
9 | if optim == "adam_optimizer":
10 | return train_config.optimizer.adam_optimizer
11 | elif optim == "rms_prop_optimizer":
12 | return train_config.optimizer.rms_prop_optimizer
13 | elif optim == "momentum_optimizer":
14 | return train_config.optimizer.momentum_optimizer
15 | else:
16 | raise NotImplementedError
17 |
18 |
def manual_stepping(train_config, boundaries, rates, optim="adam_optimizer"):
    """Set the chosen optimizer's learning rate to a manual-stepping schedule."""
    optim_cfg = _get_optim_cfg(train_config, optim)
    target = optim_cfg.learning_rate.manual_stepping
    schedule = ManualStepping(boundaries=boundaries, rates=rates)
    target.CopyFrom(schedule)
23 |
24 |
def exp_decay(train_config,
              init_lr,
              decay_length,
              decay_factor,
              staircase=True,
              optim="adam_optimizer"):
    """Set the chosen optimizer's learning rate to an exponential-decay schedule."""
    optim_cfg = _get_optim_cfg(train_config, optim)
    target = optim_cfg.learning_rate.exponential_decay
    schedule = ExponentialDecay(
        initial_learning_rate=init_lr,
        decay_length=decay_length,
        decay_factor=decay_factor,
        staircase=staircase)
    target.CopyFrom(schedule)
38 |
39 |
def one_cycle(train_config,
              lr_max,
              moms,
              div_factor,
              pct_start,
              optim="adam_optimizer"):
    """Set the chosen optimizer's learning rate to a one-cycle schedule."""
    optim_cfg = _get_optim_cfg(train_config, optim)
    target = optim_cfg.learning_rate.one_cycle
    schedule = OneCycle(
        lr_max=lr_max,
        moms=moms,
        div_factor=div_factor,
        pct_start=pct_start)
    target.CopyFrom(schedule)
53 |
54 | def _div_up(a, b):
55 | return (a + b - 1) // b
56 |
def set_train_step(config,
                   epochs,
                   eval_epoch):
    """Derive ``steps`` and ``steps_per_eval`` from epoch counts.

    The configured dataset is built to obtain its length; steps per epoch is
    that length divided by the batch size, rounded up.
    """
    reader_cfg = config.train_input_reader
    dataset_cfg = reader_cfg.dataset
    dataset_cls = get_dataset_class(dataset_cfg.dataset_class_name)
    dataset = dataset_cls(
        root_path=dataset_cfg.kitti_root_path,
        info_path=dataset_cfg.kitti_info_path,
    )
    step_per_epoch = _div_up(len(dataset), reader_cfg.batch_size)
    train_cfg = config.train_config
    train_cfg.steps = step_per_epoch * epochs
    train_cfg.steps_per_eval = step_per_epoch * eval_epoch
74 |
def disable_sample(config):
    """Overwrite the training database sampler with a default ``Sampler()``."""
    empty_sampler = Sampler()
    config.train_input_reader.database_sampler.CopyFrom(empty_sampler)
78 |
def disable_per_gt_aug(config):
    """Zero the per-ground-truth localization and rotation noise settings."""
    prep_cfg = config.train_input_reader.preprocess
    zeroed_fields = (
        ("groundtruth_localization_noise_std", 3),
        ("groundtruth_rotation_uniform_noise", 2),
    )
    for field_name, length in zeroed_fields:
        getattr(prep_cfg, field_name)[:] = [0] * length
83 |
def disable_global_aug(config):
    """Zero the global rotation, scaling and translation noise settings."""
    prep_cfg = config.train_input_reader.preprocess
    zeroed_fields = (
        ("global_rotation_uniform_noise", 2),
        ("global_scaling_uniform_noise", 2),
        ("global_random_rotation_range_per_object", 2),
        ("global_translate_noise_std", 3),
    )
    for field_name, length in zeroed_fields:
        getattr(prep_cfg, field_name)[:] = [0] * length
90 |
if __name__ == "__main__":
    # Manual check: set a manual-stepping schedule on a sample config and print it.
    path = Path(__file__).resolve().parents[2] / "configs/car.lite.config"
    config = read_config(path)
    manual_stepping(config.train_config, [0.8, 0.9], [1e-4, 1e-5, 1e-6])

    rendered = text_format.MessageToString(config, indent=2)
    print(rendered)
--------------------------------------------------------------------------------
/rslo/utils/find.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import subprocess
5 | import sys
6 | import tempfile
7 | from pathlib import Path
8 |
9 | import fire
10 |
11 |
12 | def _get_info_from_anaconda_info(info, split=":"):
13 | info = info.strip("\n").replace(" ", "")
14 | info_dict = {}
15 | latest_key = ""
16 | for line in info.splitlines():
17 | if split in line:
18 | pair = line.split(split)
19 | info_dict[pair[0]] = pair[1]
20 | latest_key = pair[0]
21 | else:
22 | if not isinstance(info_dict[latest_key], list):
23 | info_dict[latest_key] = [info_dict[latest_key]]
24 | info_dict[latest_key].append(line)
25 | return info_dict
26 |
27 |
def find_anaconda():
    """Locate Anaconda: ``~/anaconda3`` if present, else via ``conda info``.

    Raises:
        RuntimeError: the ``conda info`` command fails.
    """
    default_root = Path.home() / "anaconda3"
    if default_root.exists():
        return default_root
    # Fall back to the conda command line.
    try:
        raw_output = subprocess.check_output("conda info", shell=True)
        conda_info = _get_info_from_anaconda_info(raw_output.decode('utf-8'))
        return conda_info["activeenvlocation"]
    except subprocess.CalledProcessError:
        raise RuntimeError("find anadonda failed")
41 |
42 |
def find_cuda():
    '''Find the CUDA install path.

    Tried in order: the CUDA_HOME / CUDA_PATH environment variables, the
    platform's default install location, and the directory two levels above
    an ``nvcc`` found on PATH.

    Raises:
        RuntimeError: no location could be determined.
    '''
    # Guess #1: environment variables.
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Guess #2: platform default location.
        on_windows = sys.platform == 'win32'
        if on_windows:
            candidates = glob.glob(
                'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
            cuda_home = candidates[0] if len(candidates) != 0 else ''
        else:
            cuda_home = '/usr/local/cuda'
        if not os.path.exists(cuda_home):
            # Guess #3: derive it from the nvcc executable.
            try:
                locator = 'where' if on_windows else 'which'
                nvcc = subprocess.check_output(
                    [locator, 'nvcc']).decode().rstrip('\r\n')
                cuda_home = os.path.dirname(os.path.dirname(nvcc))
            except Exception:
                cuda_home = None
    if cuda_home is None:
        raise RuntimeError(
            "No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home))
    return cuda_home
71 |
72 |
def find_cuda_device_arch():
    """Return an ``sm_XX`` architecture string for the first GPU, or None.

    The compute capability is read from the toolkit's ``deviceQuery`` demo
    when it exists, otherwise from a small program compiled with g++. The
    result is then lowered until ``nvcc -arch=sm_XX`` accepts it. Returns
    None on Windows and whenever detection fails.
    """
    if sys.platform == 'win32':
        # TODO: add windows support
        return None
    cuda_home = find_cuda()
    if cuda_home is None:
        return None
    cuda_home = Path(cuda_home)
    try:
        device_query_path = cuda_home / 'extras/demo_suite/deviceQuery'
        if not device_query_path.exists():
            # The header names are required: without them the program cannot
            # compile and this function would always return None.
            source = """
            #include <cuda_runtime.h>
            #include <iostream>
            int main(){
                int nDevices;
                cudaGetDeviceCount(&nDevices);
                for (int i = 0; i < nDevices; i++) {
                    cudaDeviceProp prop;
                    cudaGetDeviceProperties(&prop, i);
                    std::cout << prop.major << "." << prop.minor << std::endl;
                }
                return 0;
            }
            """
            with tempfile.NamedTemporaryFile('w', suffix='.cc') as f:
                f_path = Path(f.name)
                f.write(source)
                f.flush()
                # g++ writes the executable next to the temporary source.
                binary_path = f_path.parent / f_path.stem
                try:
                    # TODO: add windows support
                    cmd = (
                        f"g++ {f.name} -o {f_path.stem}"
                        f" -I{cuda_home / 'include'} -L{cuda_home / 'lib64'} -lcudart"
                    )
                    print(cmd)
                    subprocess.check_output(cmd, shell=True, cwd=f_path.parent)
                    cmd = f"./{f_path.stem}"
                    arches = subprocess.check_output(
                        cmd, shell=True,
                        cwd=f_path.parent).decode().rstrip('\r\n').split("\n")
                    if len(arches) < 1:
                        return None
                    arch = arches[0]
                except Exception:
                    return None
                finally:
                    # Do not leave the compiled helper in the temp directory.
                    if binary_path.exists():
                        binary_path.unlink()
        else:
            cmd = f"{str(device_query_path)} | grep 'CUDA Capability'"
            arch = subprocess.check_output(
                cmd, shell=True).decode().rstrip('\r\n').split(" ")[-1]
        # assert len(arch) == 2
        arch_list = [int(s) for s in arch.split(".")]
        arch_int = arch_list[0] * 10 + arch_list[1]
        find_work_arch = False
        while arch_int > 10:
            try:
                subprocess.check_output(
                    "nvcc -arch=sm_{}".format(arch_int),
                    shell=True,
                    stderr=subprocess.STDOUT)
            except subprocess.CalledProcessError as e:
                output = e.output.decode()
                if "No input files specified" in output:
                    # nvcc accepted the arch and only complained about the
                    # missing input file.
                    find_work_arch = True
                    break
                elif "is not defined for option 'gpu-architecture'" in output:
                    arch_int -= 1
                else:
                    raise RuntimeError("unknown error")
            else:
                # nvcc exited successfully, so it did not reject the arch;
                # stop here instead of retrying the same value forever.
                find_work_arch = True
                break
        if find_work_arch:
            arch = f"sm_{arch_int}"
        else:
            arch = None

    except Exception:
        arch = None
    return arch
146 |
147 |
def get_gpu_memory_usage():
    """Return ``[[free_bytes, total_bytes], ...]`` per CUDA device, or None.

    A small program is compiled with g++ against the CUDA runtime and its
    JSON output is parsed. Returns None on Windows or when compiling,
    running or parsing fails.
    """
    if sys.platform == 'win32':
        # TODO: add windows support
        return None
    cuda_home = find_cuda()
    if cuda_home is None:
        return None
    cuda_home = Path(cuda_home)
    # The header names are required: without them the program cannot compile
    # and this function would always return None.
    source = """
    #include <cuda_runtime.h>
    #include <iostream>
    int main(){
        int nDevices;
        cudaGetDeviceCount(&nDevices);
        size_t free_m, total_m;
        // output json format.
        std::cout << "[";
        for (int i = 0; i < nDevices; i++) {
            cudaSetDevice(i);
            cudaMemGetInfo(&free_m, &total_m);
            std::cout << "[" << free_m << "," << total_m << "]";
            if (i != nDevices - 1)
                std::cout << "," << std::endl;
        }
        std::cout << "]" << std::endl;
        return 0;
    }
    """
    with tempfile.NamedTemporaryFile('w', suffix='.cc') as f:
        f_path = Path(f.name)
        f.write(source)
        f.flush()
        # g++ writes the executable next to the temporary source.
        binary_path = f_path.parent / f_path.stem
        try:
            # TODO: add windows support
            cmd = (
                f"g++ {f.name} -o {f_path.stem} -std=c++11"
                f" -I{cuda_home / 'include'} -L{cuda_home / 'lib64'} -lcudart")
            print(cmd)
            subprocess.check_output(cmd, shell=True, cwd=f_path.parent)
            cmd = f"./{f_path.stem}"
            usages = subprocess.check_output(
                cmd, shell=True, cwd=f_path.parent).decode()
            usages = json.loads(usages)
            return usages
        except Exception:
            return None
        finally:
            # Do not leave the compiled helper in the temp directory.
            if binary_path.exists():
                binary_path.unlink()
    return None
195 |
196 |
if __name__ == "__main__":
    # Print the detected architecture when the module is run directly.
    detected_arch = find_cuda_device_arch()
    print(detected_arch)
    # fire.Fire()
200 |
--------------------------------------------------------------------------------
/rslo/utils/loader.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from pathlib import Path
3 | import sys
4 | import os
5 | import logging
6 | logger = logging.getLogger('second.utils.loader')
7 |
8 | CUSTOM_LOADED_MODULES = {}
9 |
10 |
11 | def _get_possible_module_path(paths):
12 | ret = []
13 | for p in paths:
14 | p = Path(p)
15 | for path in p.glob("*"):
16 | if path.suffix in ["py", ".so"] or (path.is_dir()):
17 | if path.stem.isidentifier():
18 | ret.append(path)
19 | return ret
20 |
21 |
22 | def _get_regular_import_name(path, module_paths):
23 | path = Path(path)
24 | for mp in module_paths:
25 | mp = Path(mp)
26 | if mp == path:
27 | return path.stem
28 | try:
29 | relative_path = path.relative_to(Path(mp))
30 | parts = list((relative_path.parent / relative_path.stem).parts)
31 | module_name = '.'.join([mp.stem] + parts)
32 | return module_name
33 | except:
34 | pass
35 | return None
36 |
37 |
def import_file(path, name: str = None, add_to_sys=True,
                disable_warning=False):
    """Import the Python file at ``path`` and return the module.

    When the file lies under an entry of ``PYTHONPATH`` it is imported by its
    regular dotted name. Otherwise it is loaded directly from the file.

    Args:
        path: file to import.
        name: module name for a direct file import (default: the file stem).
        add_to_sys: register a directly loaded module in ``sys.modules``.
        disable_warning: suppress the warning logged on direct file import.

    Raises:
        ValueError: ``add_to_sys`` is set and the module name is already in
            ``sys.modules`` but was not registered by this function.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule is loaded, so import it explicitly.
    import importlib.util

    global CUSTOM_LOADED_MODULES
    path = Path(path)
    module_name = path.stem
    try:
        user_paths = os.environ['PYTHONPATH'].split(os.pathsep)
    except KeyError:
        user_paths = []
    possible_paths = _get_possible_module_path(user_paths)
    model_import_name = _get_regular_import_name(path, possible_paths)
    if model_import_name is not None:
        return import_name(model_import_name)
    if name is not None:
        module_name = name
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    if not disable_warning:
        logger.warning((
            f"Failed to perform regular import for file {path}. "
            "this means this file isn't in any folder in PYTHONPATH "
            "or don't have __init__.py in that project. "
            "directly file import may fail and some reflecting features are "
            "disabled even if import succeed. please add your project to PYTHONPATH "
            "or add __init__.py to ensure this file can be regularly imported. "
        ))

    if add_to_sys:  # this will enable find objects defined in a file.
        # avoid replace system modules.
        if module_name in sys.modules and module_name not in CUSTOM_LOADED_MODULES:
            raise ValueError(f"{module_name} exists in system.")
        CUSTOM_LOADED_MODULES[module_name] = module
        sys.modules[module_name] = module
    return module
73 |
74 |
def import_name(name, package=None):
    """Import a module by dotted ``name`` (relative to ``package`` if given)."""
    return importlib.import_module(name, package)
--------------------------------------------------------------------------------
/rslo/utils/log_tool.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorboardX import SummaryWriter
3 | import json
4 | from pathlib import Path
5 |
6 |
7 | def _flat_nested_json_dict(json_dict, flatted, sep=".", start=""):
8 | for k, v in json_dict.items():
9 | if isinstance(v, dict):
10 | _flat_nested_json_dict(v, flatted, sep, start + sep + str(k))
11 | else:
12 | flatted[start + sep + str(k)] = v
13 |
14 |
def flat_nested_json_dict(json_dict, sep=".") -> dict:
    """Flatten a nested json-like dict into one level.

    Nested keys are joined with ``sep``. Leaf values are not copied (shallow
    copy).
    """
    flatted = {}
    for key, value in json_dict.items():
        top_key = str(key)
        if not isinstance(value, dict):
            flatted[top_key] = value
        else:
            _flat_nested_json_dict(value, flatted, sep, top_key)
    return flatted
25 |
26 |
def metric_to_str(metrics, sep='.'):
    """Render a nested metrics dict as one ``key=value, ...`` string.

    Floats, and lists/tuples whose first element is a float, are formatted
    with 5 significant digits; everything else uses its default formatting.
    """

    def _render(key, value):
        if isinstance(value, float):
            return f"{key}={value:.5}"
        if isinstance(value, (list, tuple)) and value and isinstance(value[0], float):
            joined = ', '.join([f"{e:.5}" for e in value])
            return f"{key}=[{joined}]"
        return f"{key}={value}"

    flat = flat_nested_json_dict(metrics, sep)
    return ', '.join([_render(k, v) for k, v in flat.items()])
42 |
43 |
44 | class SimpleModelLog:
45 | """For simple log.
46 | generate 4 kinds of log:
47 | 1. simple log.txt, all metric dicts are flattened to produce
48 | readable results.
49 | 2. TensorBoard scalars and texts
50 | 3. multi-line json file log.json.lst
51 | 4. tensorboard_scalars.json, all scalars are stored in this file
52 | in tensorboard json format.
53 | """
54 |
55 | def __init__(self, model_dir, disable=False):
56 | self.model_dir = Path(model_dir)
57 | self.log_file = None
58 | self.log_mjson_file = None
59 | self.summary_writter = None
60 | self.metrics = []
61 | self._text_current_gstep = -1
62 | self._tb_texts = []
63 | self.disable = disable
64 | def __del__(self):
65 | self.close()
66 |
67 | def open(self):
68 | if self.disable:
69 | return self
70 | model_dir = self.model_dir
71 | assert model_dir.exists()
72 | summary_dir = model_dir / 'summary'
73 | summary_dir.mkdir(parents=True, exist_ok=True)
74 |
75 | log_mjson_file_path = model_dir / f'log.json.lst'
76 | if log_mjson_file_path.exists():
77 | with open(log_mjson_file_path, 'r') as f:
78 | for line in f.readlines():
79 | self.metrics.append(json.loads(line))
80 | log_file_path = model_dir / f'log.txt'
81 | self.log_mjson_file = open(log_mjson_file_path, 'a')
82 | self.log_file = open(log_file_path, 'a')
83 | self.summary_writter = SummaryWriter(str(summary_dir),flush_secs=60)
84 | return self
85 |
86 | def close(self):
87 | if self.disable:
88 | return
89 | assert self.summary_writter is not None
90 | self.log_mjson_file.close()
91 | self.log_file.close()
92 | tb_json_path = str(self.model_dir / "tensorboard_scalars.json")
93 | self.summary_writter.export_scalars_to_json(tb_json_path)
94 | self.summary_writter.close()
95 | self.log_mjson_file = None
96 | self.log_file = None
97 | self.summary_writter = None
98 |
99 | def log_text(self, text, step, tag="regular log"):
100 | if self.disable:
101 | return
102 | """This function only add text to log.txt and tensorboard texts
103 | """
104 | print(text,flush=True)
105 | print(text, file=self.log_file,flush=True)
106 | if step > self._text_current_gstep and self._text_current_gstep != -1:
107 | total_text = '\n'.join(self._tb_texts)
108 | self.summary_writter.add_text(tag, total_text, global_step=step)
109 | self._tb_texts = []
110 | self._text_current_gstep = step
111 | else:
112 | self._tb_texts.append(text)
113 | if self._text_current_gstep == -1:
114 | self._text_current_gstep = step
115 |
116 | def log_metrics(self, metrics: dict, step):
117 | if self.disable:
118 | return
119 | flatted_summarys = flat_nested_json_dict(metrics, "/")
120 | for k, v in flatted_summarys.items():
121 | if isinstance(v, (list, tuple)):
122 | if any([isinstance(e, str) for e in v]):
123 | continue
124 | v_dict = {str(i): e for i, e in enumerate(v)}
125 | for k1, v1 in v_dict.items():
126 | self.summary_writter.add_scalar(k + "/" + k1, v1, step)
127 | else:
128 | if isinstance(v, str):
129 | continue
130 | self.summary_writter.add_scalar(k, v, step)
131 | log_str = metric_to_str(metrics)
132 | print(log_str, flush=True)
133 | print(log_str, file=self.log_file, flush=True)
134 | print(json.dumps(metrics), file=self.log_mjson_file, flush=True)
135 |
def log_images(self, images: dict, step, prefix=''):
    """Send every entry of ``images`` to the summary writer.

    Each value is passed to ``add_images`` under the tag
    ``prefix + str(key)`` at ``step``, and a notice is printed for the key.
    Nothing happens when logging is disabled.
    """
    if self.disable:
        return
    for key, batch in images.items():
        tag = prefix + str(key)
        self.summary_writter.add_images(tag, batch, step)
        print(f"Summarize images {key}",flush=True)
142 |
def log_histograms(self, vals: dict, step, prefix=''):
    """Send every entry of ``vals`` to the summary writer as a histogram.

    Each value is passed to ``add_histogram`` under the tag
    ``prefix + str(key)`` at ``step``, and a notice is printed for the key.
    Nothing happens when logging is disabled.
    """
    if self.disable:
        return
    for key, samples in vals.items():
        tag = prefix + str(key)
        self.summary_writter.add_histogram(tag, samples, step)
        print(f"Summarize histograms {key}",flush=True)
--------------------------------------------------------------------------------
/rslo/utils/math.py:
--------------------------------------------------------------------------------
1 | import kornia
2 | import torch
3 |
--------------------------------------------------------------------------------
/rslo/utils/progress_bar.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import enum
3 | import math
4 | import time
5 |
6 | import numpy as np
7 |
8 |
def progress_str(val, *args, width=20, with_ptg=True):
    """Render a one-line textual progress bar.

    Args:
        val: completed fraction; clamped to ``[0, 1]``.
        *args: extra fields, each appended as ``[field]`` after the bar.
        width: number of cells in the bar; must be greater than 1.
        with_ptg: when exactly ``True``, prefix the bar with ``[xx.x%]``.

    Returns:
        A string such as ``[50.00%][====>.....][1.00it/s]``.
    """
    val = max(0., min(val, 1.))
    assert width > 1
    pos = round(width * val) - 1
    # Start from an empty string so the bar can be built without the
    # percentage prefix; previously `log` was unbound in that case and the
    # next statement raised UnboundLocalError.
    log = ''
    if with_ptg is True:
        log = '[{}%]'.format(max_point_str(val * 100.0, 4))
    if pos < 0:
        # Nothing completed yet: no head marker is drawn.
        bar = '.' * width
    else:
        bar = '=' * pos + '>' + '.' * (width - pos - 1)
    log += '[' + bar + ']'
    for arg in args:
        log += '[{}]'.format(arg)
    return log
27 |
28 |
def second_to_time_str(second, omit_hours_if_possible=True):
    """Format a duration in seconds as ``MM:SS`` or ``HH:MM:SS``.

    The input is truncated with ``int``. The hour field is dropped only
    when ``omit_hours_if_possible`` is truthy and the hour count is zero.
    """
    total = int(second)
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    if omit_hours_if_possible and hours == 0:
        return '{:02d}:{:02d}'.format(minutes, secs)
    return '{:02d}:{:02d}:{:02d}'.format(hours, minutes, secs)
37 |
38 |
def progress_bar_iter(task_list, width=20, with_ptg=True, step_time_average=50, name=None):
    """Yield the items of ``task_list`` while printing a progress bar.

    After the consumer finishes each item, one status line is printed (ending
    in a carriage return so it overwrites itself) with the completed
    fraction, the speed averaged over the last ``step_time_average`` steps,
    and ``elapsed>remaining`` time. A newline is printed once the iteration
    is exhausted.

    Args:
        task_list: sized iterable (``len`` is required).
        width: bar width in cells.
        with_ptg: whether to show the percentage prefix.
        step_time_average: number of most recent steps used for the speed.
        name: optional label printed as ``[name]`` before the bar.
    """
    total_step = len(task_list)
    step_times = []
    elapsed = 0.0
    name = '' if name is None else f"[{name}]"
    for i, task in enumerate(task_list):
        t = time.time()
        yield task
        # Time spent by the consumer on this item.
        step_times.append(time.time() - t)
        elapsed += step_times[-1]
        elapsed_str = second_to_time_str(elapsed)
        # The epsilon keeps the division below finite for instant steps.
        average_step_time = np.mean(step_times[-step_time_average:]) + 1e-6
        speed_str = "{:.2f}it/s".format(1 / average_step_time)
        # i + 1 items are finished at this point (the fraction below uses the
        # same count); the previous `total_step - i` over-estimated the
        # remaining time by one step and never reached zero.
        remain_time = (total_step - (i + 1)) * average_step_time
        remain_time_str = second_to_time_str(remain_time)
        time_str = elapsed_str + '>' + remain_time_str
        prog_str = progress_str(
            (i + 1) / total_step,
            speed_str,
            time_str,
            width=width,
            with_ptg=with_ptg)
        print(name + prog_str + ' ', end='\r', flush=True)
    print("")
63 |
64 |
# Alias: list_bar refers to the same generator function as progress_bar_iter.
list_bar = progress_bar_iter
66 |
def enumerate_bar(task_list, width=20, with_ptg=True, step_time_average=50, name=None):
    """Yield ``(index, item)`` pairs of ``task_list`` while printing a progress bar.

    After the consumer finishes each item, one status line is printed (ending
    in a carriage return so it overwrites itself) with the completed
    fraction, the speed averaged over the last ``step_time_average`` steps,
    and ``elapsed>remaining`` time. A newline is printed once the iteration
    is exhausted.

    Args:
        task_list: sized iterable (``len`` is required).
        width: bar width in cells.
        with_ptg: whether to show the percentage prefix.
        step_time_average: number of most recent steps used for the speed.
        name: optional label printed as ``[name]`` before the bar.
    """
    total_step = len(task_list)
    step_times = []
    elapsed = 0.0
    name = '' if name is None else f"[{name}]"
    for i, task in enumerate(task_list):
        t = time.time()
        yield i, task
        # Time spent by the consumer on this item.
        step_times.append(time.time() - t)
        elapsed += step_times[-1]
        elapsed_str = second_to_time_str(elapsed)
        # The epsilon keeps the division below finite for instant steps.
        average_step_time = np.mean(step_times[-step_time_average:]) + 1e-6
        speed_str = "{:.2f}it/s".format(1 / average_step_time)
        # i + 1 items are finished at this point (the fraction below uses the
        # same count); the previous `total_step - i` over-estimated the
        # remaining time by one step and never reached zero.
        remain_time = (total_step - (i + 1)) * average_step_time
        remain_time_str = second_to_time_str(remain_time)
        time_str = elapsed_str + '>' + remain_time_str
        prog_str = progress_str(
            (i + 1) / total_step,
            speed_str,
            time_str,
            width=width,
            with_ptg=with_ptg)
        print(name + prog_str + ' ', end='\r', flush=True)
    print("")
91 |
92 |
def max_point_str(val, max_point):
    """Format ``val`` in fixed-point with about ``max_point`` digits in total.

    The number of decimals is ``max_point`` minus the number of integer
    digits of ``|val|`` (never below zero), so larger magnitudes get fewer
    decimals. The sign of ``val`` is preserved.
    """
    is_negative = not bool(val >= 0.0)
    magnitude = np.abs(val)
    # Count integer digits; zero and values below 1 count as one digit.
    int_digits = 1 if magnitude == 0 else max(int(np.log10(magnitude)), 0) + 1
    decimals = max(max_point - int_digits, 0)
    template = "{:." + str(decimals) + "f}"
    return template.format(-magnitude if is_negative else magnitude)
105 |
106 |
class Unit(enum.Enum):
    """Unit in which ProgressBar reports its speed."""
    # Speed is printed as iterations per second ("it/s").
    Iter = 'iter'
    # Speed is printed as a size-per-second rate computed by convert_size.
    Byte = 'byte'
110 |
111 |
def convert_size(size_bytes):
    """Scale a byte count to the largest binary unit that keeps it >= 1.

    Args:
        size_bytes: number of bytes (int or float).

    Returns:
        A ``(value, unit)`` tuple where ``value`` is rounded to two decimals
        and ``unit`` is one of ``B, KB, MB, ..., YB``.

    Changes from the previous version:
        * Zero now returns ``(0, "B")`` instead of the bare string ``"0B"``,
          so callers that unpack ``value, unit`` get a number they can
          format.
        * Values below 1 stay in ``B``; the logarithm used to yield index -1
          and label them ``YB``.
        * Values beyond the largest unit are expressed in ``YB`` instead of
          raising IndexError.
    """
    # from https://stackoverflow.com/questions/5194057/better-way-to-convert-file-sizes-in-python
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    if size_bytes == 0:
        return 0, size_name[0]
    value = float(size_bytes)
    unit_index = 0
    # Repeated division by 1024 is exact in binary floating point and avoids
    # the rounding pitfalls of math.log(x, 1024) at unit boundaries.
    while abs(value) >= 1024 and unit_index < len(size_name) - 1:
        value /= 1024
        unit_index += 1
    return round(value, 2), size_name[unit_index]
121 |
122 |
class ProgressBar:
    """Manually driven textual progress bar.

    Call ``start(total_size)`` once, then ``print_bar(finished_size)`` after
    every chunk of work. Each call prints one status line with the completed
    fraction, the speed averaged over the last ``step_time_average`` calls,
    and ``elapsed>remaining`` time.
    """

    def __init__(self,
                 width=20,
                 with_ptg=True,
                 step_time_average=50,
                 speed_unit=Unit.Iter):
        """
        Args:
            width: bar width in cells.
            with_ptg: whether to show the percentage prefix.
            step_time_average: number of most recent calls used for the speed.
            speed_unit: ``Unit.Iter`` for ``it/s`` or ``Unit.Byte`` for a
                byte-size rate.
        """
        self._width = width
        self._with_ptg = with_ptg
        self._step_time_average = step_time_average
        self._step_times = []
        self._finished_sizes = []
        self._start_time = 0.0
        self._time_elapsed = 0.0
        self._current_time = time.time()
        self._total_size = None
        self._progress = 0
        self._speed_unit = speed_unit

    def start(self, total_size):
        """Reset all counters and begin timing a job of ``total_size`` units."""
        self._start = True
        self._step_times = []
        self._finished_sizes = []
        self._time_elapsed = 0.0
        self._current_time = time.time()
        self._total_size = total_size
        self._progress = 0

    def print_bar(self, finished_size=1, pre_string=None, post_string=None):
        """Record ``finished_size`` more units as done and print the bar.

        Args:
            finished_size: units completed since the previous call.
            pre_string: optional text printed before the bar.
            post_string: optional text printed after the bar.

        Raises:
            RuntimeError: if ``start`` has not been called yet.
        """
        if self._total_size is None:
            # Previously this surfaced as an AttributeError on a missing field.
            raise RuntimeError(
                "ProgressBar.start() must be called before print_bar()")
        self._step_times.append(time.time() - self._current_time)
        self._finished_sizes.append(finished_size)
        self._time_elapsed += self._step_times[-1]
        start_time_str = second_to_time_str(self._time_elapsed)
        # Seconds per unit for each of the most recent calls.
        time_per_size = np.array(self._step_times[-self._step_time_average:])
        time_per_size /= np.array(
            self._finished_sizes[-self._step_time_average:])
        average_step_time = np.mean(time_per_size) + 1e-6
        if self._speed_unit == Unit.Iter:
            speed_str = "{:.2f}it/s".format(1 / average_step_time)
        elif self._speed_unit == Unit.Byte:
            size, size_unit = convert_size(1 / average_step_time)
            speed_str = "{:.2f}{}/s".format(size, size_unit)
        else:
            raise ValueError("unknown speed unit")
        # Account for the units just finished before deriving the fraction
        # and the remaining time. The old code used `progress + 1`, which is
        # only right when finished_size == 1, and estimated the remaining
        # time from the progress before this call.
        self._progress += finished_size
        remain_time = max(self._total_size - self._progress, 0) * average_step_time
        remain_time_str = second_to_time_str(remain_time)
        time_str = start_time_str + '>' + remain_time_str
        prog_str = progress_str(
            self._progress / self._total_size,
            speed_str,
            time_str,
            width=self._width,
            with_ptg=self._with_ptg)
        if pre_string is not None:
            prog_str = pre_string + prog_str
        if post_string is not None:
            prog_str += post_string
        if self._progress >= self._total_size:
            # Final line: keep it on screen instead of overwriting it.
            print(prog_str + ' ',flush=True)
        else:
            print(prog_str + ' ', end='\r', flush=True)
        self._current_time = time.time()
181 |
--------------------------------------------------------------------------------
/rslo/utils/singleton.py:
--------------------------------------------------------------------------------
1 | import h5py
2 |
3 | try:
4 | import fsspec
5 | from petrel_client.client import Client
6 | except:
7 | Client=None
8 | pass
9 |
10 |
class Singleton(type):
    """Metaclass that creates at most one instance per class.

    The first call constructs the instance with the given arguments; every
    later call returns that same object and ignores its arguments.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        instance = super(Singleton, cls).__call__(*args, **kwargs)
        registry[cls] = instance
        return instance
17 |
class HDF5Singleton(type):
    """Metaclass that keeps one instance per ``file_path``.

    The first call for a path constructs the instance and records the path;
    later calls with the same path return the cached instance, whatever
    other arguments are given.
    """
    _instances = {}
    _file_paths = []

    def __call__(cls, file_path, mode='r',libver="latest", swmr=True, rdcc_nbytes=1024**2*15):
        if file_path in cls._file_paths:
            return cls._instances[file_path]
        # Arguments are forwarded positionally to the class constructor.
        handle = super(HDF5Singleton, cls).__call__(
            file_path, mode, libver, swmr, rdcc_nbytes)
        cls._instances[file_path] = handle
        cls._file_paths.append(file_path)
        return handle
32 |
class HDF5File(metaclass=HDF5Singleton):
    """Process-wide handle to an HDF5 file, one instance per ``file_path``.

    When the petrel client is importable the file is read through a
    presigned URL; otherwise it is opened from the local path with h5py.
    """

    def __init__(self, file_path, mode='r',libver="latest", swmr=True, rdcc_nbytes=1024**2*15 ):
        """
        Args:
            file_path: path used as the cache key and, in the local branch,
                as the file to open.
            mode, libver, swmr, rdcc_nbytes: forwarded to ``h5py.File`` in
                the local branch; only ``mode`` is used in the remote branch.
        """
        self.file_path = file_path
        # References that keep the remote byte stream alive (remote branch only).
        self._remote_open_file = None
        self._remote_fobj = None

        if Client is not None:
            conf_path = '~/petreloss.conf'
            print("HDF5 is opening...", flush=True)
            client = Client(conf_path)
            # NOTE(review): the object key is hard-coded and `file_path` is
            # not used in this branch - confirm this is intended.
            url = client.generate_presigned_url('s3://Bucket/all.h5')
            # The stream must outlive __init__: h5py reads from it lazily.
            # The previous `with fsspec.open(url) as f:` closed it on exit,
            # leaving self.file backed by a closed handle. Both the OpenFile
            # and the file-like object it produces are kept referenced.
            self._remote_open_file = fsspec.open(url)
            self._remote_fobj = self._remote_open_file.open()
            self.file = h5py.File(self._remote_fobj, mode=mode)
            print("HDF5 is opened!", flush=True)
        else:
            self.file = h5py.File(file_path, mode=mode, libver=libver,swmr=swmr, rdcc_nbytes=rdcc_nbytes)

    def read(self):
        """Return the underlying ``h5py.File`` object."""
        return self.file
59 |
60 |
61 |
--------------------------------------------------------------------------------
/rslo/utils/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 | from contextlib import contextmanager
3 |
@contextmanager
def simple_timer(name=''):
    """Context manager that prints the wall-clock time spent in its body.

    The line ``"<name> exec time: <seconds>"`` is printed when the block
    exits, including when it exits through an exception (which is then
    re-raised unchanged).
    """
    t = time.time()
    try:
        yield
    finally:
        # Without try/finally the measurement was silently dropped whenever
        # the timed block raised.
        print(f"{name} exec time: {time.time() - t}")
--------------------------------------------------------------------------------
/rslo/utils/util.py:
--------------------------------------------------------------------------------
import collections
import re

import numpy as np
import torch
4 |
def freeze_params(params: dict, include: str = None, exclude: str = None):
    """Return the parameters that are NOT selected for freezing.

    Uses the ``re`` module, which this file previously referenced without
    importing; any call with a pattern raised NameError.

    Args:
        params: mapping from parameter name to parameter.
        include: regex; parameters whose name matches it (``re.match``, i.e.
            anchored at the start) are dropped from the result.
        exclude: regex; parameters whose name does NOT match it are dropped
            from the result.

    Returns:
        List of the remaining parameters, in ``params`` iteration order.
    """
    assert isinstance(params, dict)
    include_re = re.compile(include) if include is not None else None
    exclude_re = re.compile(exclude) if exclude is not None else None
    remain_params = []
    for k, p in params.items():
        if include_re is not None and include_re.match(k) is not None:
            continue
        if exclude_re is not None and exclude_re.match(k) is None:
            continue
        remain_params.append(p)
    return remain_params
23 |
24 |
def freeze_params_v2(params: dict, include: str = None, exclude: str = None):
    """Freeze selected parameters in place by clearing ``requires_grad``.

    Uses the ``re`` module, which this file previously referenced without
    importing; any call with a pattern raised NameError.

    Args:
        params: mapping from parameter name to an object with a
            ``requires_grad`` attribute.
        include: regex; parameters whose name matches it (``re.match``) get
            ``requires_grad = False``.
        exclude: regex; parameters whose name does NOT match it get
            ``requires_grad = False``.

    Returns:
        None. Parameters not selected are left untouched.
    """
    assert isinstance(params, dict)
    include_re = re.compile(include) if include is not None else None
    exclude_re = re.compile(exclude) if exclude is not None else None
    for k, p in params.items():
        if include_re is not None and include_re.match(k) is not None:
            p.requires_grad = False
        if exclude_re is not None and exclude_re.match(k) is None:
            p.requires_grad = False
40 |
41 |
def filter_param_dict(state_dict: dict, include: str = None, exclude: str = None):
    """Return a new dict with the entries of ``state_dict`` that pass both filters.

    Uses the ``re`` module, which this file previously referenced without
    importing; any call with a pattern raised NameError.

    Args:
        state_dict: mapping from parameter name to value.
        include: regex; when given, only names matching it (``re.match``)
            are kept.
        exclude: regex; when given, names matching it are dropped.

    Returns:
        A new dict; ``state_dict`` itself is not modified.
    """
    assert isinstance(state_dict, dict)
    include_re = re.compile(include) if include is not None else None
    exclude_re = re.compile(exclude) if exclude is not None else None
    res_dict = {}
    for k, p in state_dict.items():
        if include_re is not None and include_re.match(k) is None:
            continue
        if exclude_re is not None and exclude_re.match(k) is not None:
            continue
        res_dict[k] = p
    return res_dict
60 |
def modify_parameter_name_with_map(state_dict, parameteter_name_map=None):
    """Rename keys of ``state_dict`` in place by substring replacement.

    ``parameteter_name_map`` is an iterable of ``(old, new)`` pairs applied
    in order. For each pair, every key containing ``old`` is re-inserted
    under ``key.replace(old, new)``. With no map the dict is returned
    untouched.

    Returns:
        The same ``state_dict`` object.
    """
    if parameteter_name_map is not None:
        for old, new in parameteter_name_map:
            # Snapshot the matching keys first: the dict is mutated below.
            for key in [name for name in state_dict if old in name]:
                state_dict[key.replace(old, new)] = state_dict.pop(key)
    return state_dict
70 |
def load_pretrained_model_map_func(state_dict,parameteter_name_map = None, include:str=None, exclude:str=None):
    """Filter a state dict by name, then rename its keys.

    Args:
        state_dict: mapping from parameter name to value.
        parameteter_name_map: iterable of ``(old, new)`` substring pairs
            passed to ``modify_parameter_name_with_map``; ``None`` skips
            renaming.
        include: regex passed to ``filter_param_dict``.
        exclude: regex passed to ``filter_param_dict``.

    Returns:
        The filtered and renamed dict.
    """
    state_dict = filter_param_dict(state_dict, include, exclude)
    state_dict = modify_parameter_name_with_map(state_dict, parameteter_name_map)
    # The result was previously computed and then discarded: the function
    # fell off the end and returned None.
    return state_dict
74 |
75 |
76 |
def list_recursive_op(input_list, op):
    """Apply ``op`` to every leaf of a nested list, in place.

    Nested lists are handled recursively, nested dicts are delegated to
    ``dict_recursive_op``, and every other element is replaced by
    ``op(element)``.

    Returns:
        The same ``input_list`` object.
    """
    assert isinstance(input_list, list)

    for position, element in enumerate(input_list):
        if isinstance(element, list):
            replacement = list_recursive_op(element, op)
        elif isinstance(element, dict):
            replacement = dict_recursive_op(element, op)
        else:
            replacement = op(element)
        input_list[position] = replacement

    return input_list
89 |
90 |
def dict_recursive_op(input_dict, op):
    """Apply ``op`` to every leaf value of a nested dict, in place.

    Nested dicts are handled recursively, and list values are delegated to
    ``list_recursive_op``. Tuple values are processed like lists and stored
    back as a plain ``tuple``. Every other value is replaced by ``op(value)``.

    Returns:
        The same ``input_dict`` object.
    """
    assert isinstance(input_dict, dict)

    for k, v in input_dict.items():
        if isinstance(v, dict):
            input_dict[k] = dict_recursive_op(v, op)
        elif isinstance(v, list):
            input_dict[k] = list_recursive_op(v, op)
        elif isinstance(v, tuple):
            # list_recursive_op only accepts (and mutates) lists, so passing
            # a tuple straight through raised AssertionError. Work on a list
            # copy and convert the result back.
            input_dict[k] = tuple(list_recursive_op(list(v), op))
        else:
            input_dict[k] = op(v)

    return input_dict
103 |
104 |
--------------------------------------------------------------------------------
/rslo/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from matplotlib import collections as mc
4 | from io import StringIO, BytesIO
5 | import PIL
6 | import cv2
7 | import rslo.utils.pose_utils_np as pun
8 |
9 |
def pltfig2data(fig):
    """
    @brief Render a Matplotlib figure to an in-memory PNG, decode it and return it as a float numpy array
    @param fig a matplotlib figure; it is closed with plt.close(fig) before returning
    @return a numpy float array of the decoded pixel values divided by 255
            (presumably shaped (H, W, 4) RGBA - TODO confirm against the PNG mode)
    """
    # draw the renderer
    # fig.canvas.draw()

    # # Get the RGBA buffer from the figure
    # w, h = fig.canvas.get_width_height()
    # buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8)
    # buf.shape = (w, h, 4)

    # # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
    # buf = np.roll(buf, 3, axis=2)
    # buf = buf.astype(float)/255

    # Allocate an in-memory buffer
    buffer_ = BytesIO()  # StringIO() # using buffer,great way!
    # Save into memory rather than to local disk; note this assumes the content to save is what is drawn on the figure
    fig.savefig(buffer_, format='png')
    buffer_.seek(0)
    # Read the image back from memory with PIL (or CV2)
    dataPIL = PIL.Image.open(buffer_)
    # Convert to a numpy array; the PIL conversion is fast, and data is the result we need
    data = np.asarray(dataPIL)
    data = data.astype(float)/255.
    # cv2.imwrite('test.png', data)
    # Release the buffer
    buffer_.close()
    plt.close(fig)

    return data
44 |
45 |
46 | # def draw_odometry(odom_vectors, gt_vectors=None, view='bv', saving_dir=None):
47 |
48 | def draw_trajectory(poses_pred, poses_gt=None, view='bv', saving_dir=None, figure=None, ax=None, color='b', error_step=1, odom_errors=None):
49 | """[summary]
50 |
51 | Arguments:
52 | poses_pred {[np.array]} -- [(N,7)]
53 |
54 | Keyword Arguments:
55 | poses_gt {[np.array]} -- [(N,7)] (default: {None})
56 | view {str} -- [description] (default: {'bv'})
57 | saving_dir {[type]} -- [description] (default: {None})
58 | figure {[type]} -- [description] (default: {None})
59 | ax {[type]} -- [description] (default: {None})
60 | color {str} -- [description] (default: {'b'})
61 | """
62 |
63 | assert(view in ['bv', 'front', 'side'])
64 | translation, rotation = poses_pred[:, :3], poses_pred[:, 3:]
65 | if poses_gt is not None:
66 | assert len(poses_pred) == len(poses_gt)
67 | translation_gt, rotation_gt = poses_gt[:, :3], poses_gt[:, 3:]
68 |
69 | if view == 'bv':
70 | dim0, dim1 = 0, 1
71 | elif view == 'front':
72 | dim0, dim1 = 0, 1
73 | elif view == 'side':
74 | dim0, dim1 = 0, 1
75 |
76 | if figure is None or ax is None:
77 | figure = plt.figure()
78 | ax = figure.add_subplot(111)
79 |
80 | for i in range(1, len(translation)):
81 | if i == 1:
82 | ax.plot([translation[i-1][dim0]], [
83 | translation[i-1][dim1]], '*', markersize=10, color=color)
84 |
85 | ax.plot([translation[i-1][dim0], translation[i][dim0]], [
86 | translation[i-1][dim1], translation[i][dim1]], '-', markersize=0.5, color=color)
87 |
88 | if poses_gt is not None:
89 | ax.plot([translation_gt[i-1][dim0], translation_gt[i][dim0]], [
90 | translation_gt[i-1][dim1], translation_gt[i][dim1]], '-', markersize=0.5, color='r')
91 | if i % 50 == 0:
92 | # plot connection lines
93 | ax.plot([translation[i][dim0], translation_gt[i][dim0]], [
94 | translation[i][dim1], translation_gt[i][dim1]], '-', markersize=0.03, color='gray')
95 |
96 | # and i%error_step==0 and i//error_step
2 |
3 | #include
4 | #include
5 |
6 | __global__
7 | void ChamferDistanceKernel(
8 | int b,
9 | int n,
10 | const float* xyz,
11 | int m,
12 | const float* xyz2,
13 | float* result,
14 | int* result_i)
15 | {
16 | const int batch=512;
17 | __shared__ float buf[batch*3];
18 | for (int i=blockIdx.x;ibest){
130 | result[(i*n+j)]=best;
131 | result_i[(i*n+j)]=best_i;
132 | }
133 | }
134 | __syncthreads();
135 | }
136 | }
137 | }
138 |
// Host-side launcher for the forward pass in both directions.
// Launches ChamferDistanceKernel twice: once with (xyz, n) as the query set
// and (xyz2, m) as the reference set, writing result/result_i, and once with
// the roles swapped, writing result2/result2_i. Afterwards the last CUDA
// error is read and, if set, printed; it is not returned to the caller.
// NOTE(review): the launch configuration between `<<` and `>>` is empty in
// this copy of the file (it appears to have been lost in extraction) and must
// be restored before this compiles.
// NOTE(review): pointers are presumably device pointers - confirm at the caller.
void ChamferDistanceKernelLauncher(
    const int b, const int n,
    const float* xyz,
    const int m,
    const float* xyz2,
    float* result,
    int* result_i,
    float* result2,
    int* result2_i
)
{
    // Direction 1: each point of xyz against xyz2.
    ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i);
    // Direction 2: each point of xyz2 against xyz.
    ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i);

    // cudaGetLastError reports launch failures only and clears the error state.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err));
}
// Host-side launcher for the forward pass in one direction only.
// Launches ChamferDistanceKernel once with (xyz, n) as the query set and
// (xyz2, m) as the reference set, writing result/result_i. The reverse
// direction is commented out. Afterwards the last CUDA error is read and,
// if set, printed; it is not returned to the caller.
// NOTE(review): the launch configuration between `<<` and `>>` is empty in
// this copy of the file (it appears to have been lost in extraction) and must
// be restored before this compiles.
void OneDirectionChamferDistanceKernelLauncher(
    const int b, const int n,
    const float* xyz,
    const int m,
    const float* xyz2,
    float* result,
    int* result_i
    // float* result2,
    // int* result2_i
)
{
    ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i);
    // ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i);

    // cudaGetLastError reports launch failures only and clears the error state.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err));
}
175 |
176 |
177 | __global__
178 | void ChamferDistanceGradKernel(
179 | int b, int n,
180 | const float* xyz1,
181 | int m,
182 | const float* xyz2,
183 | const float* grad_dist1,
184 | const int* idx1,
185 | float* grad_xyz1,
186 | float* grad_xyz2)
187 | {
188 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2);
223 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1);
224 |
225 | cudaError_t err = cudaGetLastError();
226 | if (err != cudaSuccess)
227 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err));
228 | }
// Host-side launcher for the backward pass in one direction only.
// Zeroes grad_xyz1 and grad_xyz2, then launches ChamferDistanceGradKernel
// once with grad_dist1/idx1. The reverse direction is commented out.
// Afterwards the last CUDA error is read and, if set, printed; it is not
// returned to the caller.
// NOTE(review): the launch configuration between `<<` and `>>` is empty in
// this copy of the file (it appears to have been lost in extraction) and must
// be restored before this compiles.
void OneDirectionChamferDistanceGradKernelLauncher(
    const int b, const int n,
    const float* xyz1,
    const int m,
    const float* xyz2,
    const float* grad_dist1,
    const int* idx1,
    // const float* grad_dist2,
    // const int* idx2,
    float* grad_xyz1,
    float* grad_xyz2
)
{
    // Byte counts: presumably b*n (or b*m) points x 3 coordinates x 4 bytes
    // per float - TODO confirm. The cudaMemset return values are not checked.
    cudaMemset(grad_xyz1, 0, b*n*3*4);
    cudaMemset(grad_xyz2, 0, b*m*3*4);
    ChamferDistanceGradKernel<<>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2);
    // ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1);

    // cudaGetLastError reports launch failures only and clears the error state.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err));
}
--------------------------------------------------------------------------------