├── .gitignore ├── Dockerfile ├── README.md ├── config ├── kitti_eval_ours.prototxt ├── kitti_eval_ours_moredata.prototxt ├── kitti_train_ours.prototxt └── kitti_train_ours_moredata.prototxt ├── demo ├── framework.png ├── output.gif ├── pointcov.png ├── traj+pointcov.png └── traj.png ├── doc └── train.md ├── evaluate.py ├── freeze.yml ├── rslo ├── builder │ ├── __init__.py │ ├── dataset_builder.py │ ├── input_reader_builder.py │ ├── losses_builder.py │ ├── lr_scheduler_builder.py │ ├── optimizer_builder.py │ ├── preprocess_builder.py │ ├── second_builder.py │ └── voxel_builder.py ├── core │ ├── __init__.py │ └── losses.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── kitti_common.py │ ├── kitti_dataset_crossnorm_hdf5.py │ ├── kitti_dataset_hdf5.py │ └── preprocess.py ├── layers │ ├── MaskConv.py │ ├── SparseConv.py │ ├── common.py │ ├── confidence.py │ ├── normalization.py │ ├── se_module.py │ └── svd.py ├── models │ ├── custom_resnet_spc.py │ ├── middle.py │ ├── odom_pred.py │ ├── odom_pred_base.py │ ├── voxel_encoder.py │ └── voxel_odom_net.py ├── protos │ ├── __init__.py │ ├── complile.sh │ ├── input_reader.proto │ ├── input_reader_pb2.py │ ├── losses.proto │ ├── losses.proto.bak │ ├── losses_pb2.py │ ├── model.proto │ ├── model_pb2.py │ ├── optimizer.proto │ ├── optimizer_pb2.py │ ├── pipeline.proto │ ├── pipeline_pb2.py │ ├── preprocess.proto │ ├── preprocess_pb2.py │ ├── sampler.proto │ ├── sampler_pb2.py │ ├── second.proto │ ├── second_pb2.py │ ├── similarity.proto │ ├── similarity_pb2.py │ ├── target.proto.bak │ ├── train.proto │ ├── train_pb2.py │ ├── voxel_generator.proto │ └── voxel_generator_pb2.py ├── torchplus │ ├── __init__.py │ ├── metrics.py │ ├── nn │ │ ├── __init__.py │ │ ├── functional.py │ │ └── modules │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ └── normalization.py │ ├── ops │ │ ├── __init__.py │ │ └── array_ops.py │ ├── tools.py │ └── train │ │ ├── __init__.py │ │ ├── checkpoint.py │ │ ├── common.py │ │ ├── fastai_optim.py │ │ ├── 
learning_schedules.py │ │ ├── learning_schedules_fastai.py │ │ └── optim.py └── utils │ ├── __init__.py │ ├── check.py │ ├── config_tool.py │ ├── config_tool │ ├── __init__.py │ └── train.py │ ├── distributed_utils.py │ ├── find.py │ ├── geometric.py │ ├── kitti_evaluation.py │ ├── loader.py │ ├── log_tool.py │ ├── math.py │ ├── pose_utils.py │ ├── pose_utils_np.py │ ├── progress_bar.py │ ├── singleton.py │ ├── timer.py │ ├── util.py │ └── visualization.py ├── script ├── create_hdf5.py ├── create_hdf5_crossnormal.py └── eval_ours.sh ├── thirdparty └── chamfer_distance │ ├── __init__.py │ ├── chamfer_distance.cpp │ ├── chamfer_distance.cu │ └── chamfer_distance.py └── train_hdf5.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | # pytype static type analyzer 140 | .pytype/ 141 | 142 | # Cython debug symbols 143 | cython_debug/ 144 | 145 | ### VisualStudioCode ### 146 | .vscode/* 147 | !.vscode/settings.json 148 | !.vscode/tasks.json 149 | !.vscode/launch.json 150 | !.vscode/extensions.json 151 | *.code-workspace 152 | 153 | # Local History for Visual Studio Code 154 | .history/ 155 | 156 | ### VisualStudioCode Patch ### 157 | # Ignore all local history of files 158 | .history 159 | .ionide 160 | 161 | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python 162 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | #FROM ubuntu:18.04 2 | # FROM nvidia/cuda:9.2-devel-ubuntu18.04 3 | FROM nvidia/cuda:9.2-cudnn7-devel-ubuntu18.04 4 | 5 | # Dependencies for glvnd and X11. 6 | RUN apt-get update \ 7 | && apt-get install -y -qq --no-install-recommends \ 8 | libglvnd0 \ 9 | libgl1 \ 10 | libglx0 \ 11 | libegl1 \ 12 | libxext6 \ 13 | libx11-6 \ 14 | && rm -rf /var/lib/apt/lists/* 15 | # Env vars for the nvidia-container-runtime. 
16 | ENV NVIDIA_VISIBLE_DEVICES all 17 | ENV NVIDIA_DRIVER_CAPABILITIES graphics,utility,compute 18 | 19 | #env vars for cuda 20 | ENV CUDA_HOME /usr/local/cuda 21 | 22 | #install miniconda 23 | RUN apt-get update --fix-missing && \ 24 | apt-get install -y wget bzip2 ca-certificates curl git && \ 25 | apt-get clean && \ 26 | rm -rf /var/lib/apt/lists/* 27 | 28 | RUN wget --quiet https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh -O ~/miniconda.sh && \ 29 | /bin/bash ~/miniconda.sh -b -p /opt/miniconda3 && \ 30 | rm ~/miniconda.sh && \ 31 | /opt/miniconda3/bin/conda clean -tipsy && \ 32 | ln -s /opt/miniconda3/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ 33 | echo ". /opt/miniconda3/etc/profile.d/conda.sh" >> ~/.bashrc && \ 34 | echo "conda activate base" >> ~/.bashrc && \ 35 | echo "conda deactivate && conda activate py37" >> ~/.bashrc 36 | 37 | #https://blog.csdn.net/Mao_Jonah/article/details/89502380 38 | COPY freeze.yml freeze.yml 39 | RUN /opt/miniconda3/bin/conda env create -n py37 -f freeze.yml 40 | #install pytorch 41 | # RUN /opt/miniconda3/bin/conda install pytorch==1.5.1 torchvision==0.6.1 cudatoolkit=9.2 -c pytorch -n py37 42 | # RUN /opt/miniconda3/bin/conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=9.2 -c pytorch -n py37 43 | RUN /opt/miniconda3/bin/conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=9.2 -c pytorch -n py37 44 | 45 | #install cmake 3.13 46 | RUN apt-get update && apt-get install -y software-properties-common 47 | RUN wget -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc | apt-key add - 48 | RUN apt-add-repository 'deb https://apt.kitware.com/ubuntu/ bionic main' && apt-get update && apt-get install -y cmake \ 49 | python3-distutils python-dev python3-dev \ 50 | libboost-all-dev 51 | 52 | 53 | 54 | 55 | #install spconv 56 | 57 | WORKDIR /tmp/unique_for_spconv 58 | RUN git clone --recursive https://github.com/DecaYale/spconv_plus.git 59 | WORKDIR 
/tmp/unique_for_spconv/spconv_plus 60 | RUN /opt/miniconda3/envs/py37/bin/python setup.py bdist_wheel 61 | RUN cd ./dist && /opt/miniconda3/envs/py37/bin/pip3 install *.whl 62 | 63 | 64 | 65 | 66 | WORKDIR /tmp/ 67 | COPY config.jupyter.tar config.jupyter.tar 68 | RUN tar -xvf config.jupyter.tar -C /root/ 69 | 70 | 71 | # RUN add-apt-repository ppa:ubuntu-toolchain-r/test && apt-get update 72 | # RUN apt-get install -y gcc-5 g++-5 73 | # RUN ls /usr/bin/ | grep gcc 74 | # # RUN ls /usr/bin/ | grep g++ 75 | # RUN mv /usr/bin/gcc /usr/bin/gcc.bak && ln -s /usr/bin/gcc-5 /usr/bin/gcc && gcc --version 76 | # RUN mv /usr/bin/g++ /usr/bin/g++.bak && ln -s /usr/bin/g++-5 /usr/bin/g++ && g++ --version 77 | 78 | #install apex 79 | # RUN . ~/.bashrc && conda activate py37 && git clone https://github.com/NVIDIA/apex.git \ 80 | # RUN . /opt/miniconda3/etc/profile.d/conda.sh && conda init bash 81 | # RUN ls /opt/miniconda3/envs/py37/bin/ | grep pip 82 | # RUN git clone https://github.com/NVIDIA/apex.git \ 83 | # && cd apex && git reset --hard f3a960f80244cf9e80558ab30f7f7e8cbf03c0a0 \ 84 | # && /opt/miniconda3/envs/py37/bin/python setup.py install --cuda_ext --cpp_ext 85 | ENV TORCH_CUDA_ARCH_LIST "6.0 6.2 7.0 7.2" 86 | # make sure we don't overwrite some existing directory called "apex" 87 | WORKDIR /tmp/unique_for_apex 88 | # uninstall Apex if present, twice to make absolutely sure :) 89 | RUN /opt/miniconda3/envs/py37/bin/pip3 uninstall -y apex || : 90 | RUN /opt/miniconda3/envs/py37/bin/pip3 uninstall -y apex || : 91 | # SHA is something the user can touch to force recreation of this Docker layer, 92 | # and therefore force cloning of the latest version of Apex 93 | RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git 94 | WORKDIR /tmp/unique_for_apex/apex 95 | RUN git checkout f3a960f80244cf9e80558ab30f7f7e8cbf03c0a0 96 | # RUN /opt/miniconda3/envs/py37/bin/pip3 install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 
97 | RUN /opt/miniconda3/envs/py37/bin/pip3 install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 98 | 99 | 100 | #other pkgs 101 | RUN apt-get update \ 102 | && apt-get install -y -qq --no-install-recommends \ 103 | cmake build-essential vim xvfb unzip tmux psmisc \ 104 | libx11-dev libassimp-dev \ 105 | mesa-common-dev freeglut3-dev \ 106 | && apt-get clean \ 107 | && rm -rf /var/lib/apt/lists/* 108 | 109 | #create some directories 110 | RUN mkdir -p /home/yxu/Projects/ && ln -s /mnt/workspace/Works /home/yxu/Projects/Works \ 111 | && mkdir -p /DATA/yxu/ && ln -s /mnt/workspace/datasets/ /DATA/yxu/LINEMOD_DEEPIM \ 112 | && ln -s /mnt/workspace/datasets/LINEMOD/ /DATA/yxu/LINEMOD \ 113 | && ln -s /mnt/workspace/datasets/BOP_LINEMOD/ /DATA/yxu/BOP_LINEMOD \ 114 | && mkdir -p /mnt/lustre/xuyan2/ \ 115 | && ln -s /home/yxu/datasets/ /mnt/lustre/xuyan2/datasets 116 | 117 | EXPOSE 8887 8888 8889 10000 10001 10002 118 | WORKDIR / -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RSLO 2 | [[Paper]](https://scholar.google.com/scholar?hl=zh-CN&as_sdt=0%2C5&q=Robust+Self-supervised+LiDAR+Odometry+via+Representative+Structure+Discovery+and+3D+Inherent+Error+Distribution+Modeling&btnG=) 3 | 4 | The code of paper **Robust Self-supervised LiDAR Odometry via Representative Structure Discovery and 3D Inherent Error Modeling** accepted by IEEE Robotics and Automation Letters (RA-L), 2021. 5 | 6 |

7 | animated 8 |

 9 | 10 | 11 | 12 | 13 | ## Framework 14 | The self-supervised two-frame odometry network contains three main modules: the Geometric Unit Feature Encoding module, the Geometric Unit Transformation Estimation module, and the Ego-motion Voting module. 15 | 16 | 17 |

18 | alt text 19 |

20 | 21 | ## Estimated Trajectories and Point Covariance Estimations 22 | The comparison (on estimated trajectories) of our method with other competitive baselines (left). The visualization of our estimated point covariances (right). 23 | 24 | 25 | 26 |
27 | 28 |
 29 | 30 | 31 | 32 | ## Installation 33 | As the dependencies are complex, a Dockerfile has been provided. You need to install [docker](https://docs.docker.com/get-docker/) and [nvidia-docker2](https://github.com/NVIDIA/nvidia-docker) first, and then build the docker image and start a container with the following commands: 34 | 35 | ``` 36 | cd RSLO 37 | sudo docker build -t rslo . 38 | sudo docker run -it --runtime=nvidia --ipc=host --volume="HOST_VOLUME_YOU_WANT_TO_MAP:DOCKER_VOLUME" -e DISPLAY=$DISPLAY -e QT_X11_NO_MITSHM=1 rslo bash 39 | 40 | ``` 41 | 42 | ## Data Preparation 43 | You need to download the [KITTI odometry dataset](http://www.cvlibs.net/datasets/kitti/eval_odometry.php) and unzip it into the directory structure below. 44 | ``` 45 | ./kitti/dataset 46 | |──sequences 47 | | ├── 00/ 48 | | | ├── calib.txt 49 | | │ ├── velodyne/ 50 | | | | ├── 000000.bin 51 | | | | ├── 000001.bin 52 | | | | └── ... 53 | | ├── 01/ 54 | | | ... 55 | | └── 21/ 56 | └──poses 57 | |──00.txt 58 | |──01.txt 59 | | ... 60 | └──10.txt 61 | 62 | ``` 63 | Then, create the hdf5 data with 64 | ``` 65 | python script/create_hdf5.py ./kitti/dataset ./kitti/dataset/all.h5 66 | ``` 67 | 68 | ## Test with the Pretrained Models 69 | The models trained on the KITTI dataset have been uploaded to [OneDrive](https://1drv.ms/u/s!AgP7bY0L6pvta-AeCK1tFxJrn-8?e=1hYWzy). You can download them and put them into the directory "weights" for testing. 
 70 | 71 | ``` 72 | export PYTHONPATH="$PROJECT_ROOT_PATH:$PYTHONPATH" 73 | export PYTHONPATH="$PROJECT_ROOT_PATH/rslo:$PYTHONPATH" 74 | python -u $PROJECT_ROOT_PATH/evaluate.py multi_proc_eval \ 75 | --config_path $PROJECT_ROOT_PATH/config/kitti_eval_ours.prototxt \ 76 | --model_dir ./outputs/ \ 77 | --use_dist True \ 78 | --gpus_per_node 1 \ 79 | --use_apex True \ 80 | --world_size 1 \ 81 | --dist_port 20000 \ 82 | --pretrained_path $PROJECT_ROOT_PATH/weights/ours.tckpt \ 83 | --refine False \ 84 | ``` 85 | Note that, before running the above commands, you need to set PROJECT_ROOT_PATH, i.e. the absolute path of the project folder "RSLO", and modify the path to the created data, i.e. all.h5, in the configuration file kitti_eval_ours.prototxt. A bash script "script/eval_ours.sh" is provided for reference. 86 | 87 | ## Training from Scratch 88 | A basic training script is shown below. You can increase the number of GPUs, i.e. the variable "GPUs", according to your available resources. Generally, larger batch sizes lead to more stable training and better final performance. 
 89 | 90 | 91 | ``` 92 | export PYTHONPATH="$PROJECT_ROOT_PATH:$PYTHONPATH" 93 | export PYTHONPATH="$PROJECT_ROOT_PATH/rslo:$PYTHONPATH" 94 | GPUs=1 # the number of gpus you use 95 | python -u $PROJECT_ROOT_PATH/train_hdf5.py multi_proc_train \ 96 | --config_path $PROJECT_ROOT_PATH/config/kitti_train_ours.prototxt \ 97 | --model_dir ./outputs/ \ 98 | --use_dist True \ 99 | --gpus_per_node $GPUs \ 100 | --use_apex True \ 101 | --world_size $GPUs \ 102 | --dist_port 20000 \ 103 | --refine False \ 104 | 105 | ``` 106 | 107 | 108 | 109 | 110 | 111 | ## Acknowledgments 112 | We thank the authors of the open-source codebases [spconv](https://github.com/traveller59/spconv) and [second](https://github.com/traveller59/second.pytorch). 113 | 114 | ## Citation 115 | To cite our papers: 116 | ``` 117 | @article{xu2022robust, 118 | title={Robust Self-supervised LiDAR Odometry via Representative Structure Discovery and 3D Inherent Error Modeling}, 119 | author={Xu, Yan and Lin, Junyi and Shi, Jianping and Zhang, Guofeng and Wang, Xiaogang and Li, Hongsheng}, 120 | journal={IEEE Robotics and Automation Letters}, 121 | year={2021}, 122 | publisher={IEEE} 123 | } 124 | ``` 125 | ``` 126 | @inproceedings{xu2020selfvoxelo, 127 | title = {SelfVoxeLO: Self-supervised LiDAR Odometry with Voxel-based Deep Neural Networks}, 128 | author = {Yan Xu and Zhaoyang Huang and Kwan{-}Yee Lin and Xinge Zhu and Jianping Shi and Hujun Bao and Guofeng Zhang and Hongsheng Li}, 129 | booktitle = {4th Conference on Robot Learning, CoRL 2020, 16-18 November 2020, Virtual Event / Cambridge, MA, {USA}}, 130 | volume = {155}, 131 | pages = {115--125}, 132 | publisher = {{PMLR}}, 133 | year = {2020}, 134 | } 135 | ``` 136 | 137 | ## TODO List and ETA 138 | - [x] Inference code and pretrained models (9/10/2022) 139 | - [x] Training code (10/12/2022) 140 | - [ ] Code cleaning and refactor 141 | 142 | 143 | 144 | 145 | -------------------------------------------------------------------------------- 
/config/kitti_eval_ours.prototxt: -------------------------------------------------------------------------------- 1 | model: { 2 | second: { 3 | use_GN: false 4 | icp_iter: 2 5 | network_class_name: "UnVoxelOdomNetICP3" 6 | voxel_generator { 7 | point_cloud_range : [-70.4, -38.4, -3, 70.4, 38.4, 5] # [x0,y0,z0, x1,y1,z1] 8 | voxel_size : [0.1, 0.1, 0.2] 9 | max_number_of_points_per_voxel : 10 10 | block_factor :1 11 | block_size : 8 12 | height_threshold : -1 #0.05 13 | } 14 | 15 | voxel_feature_extractor: { 16 | module_class_name: "SimpleVoxel_XYZINormalC" 17 | num_filters: [16] 18 | with_distance: false 19 | num_input_features: 7 20 | not_use_norm: true 21 | 22 | } 23 | middle_feature_extractor: { 24 | module_class_name: "SpMiddleFHDWithCov2_3" 25 | downsample_factor: 8 26 | num_input_features: 7 27 | bn_type:"None" 28 | use_leakyReLU: true 29 | } 30 | 31 | odom_predictor:{ 32 | module_class_name : "UNRResNetOdomPredEncDecSVDTempMask" 33 | num_input_features :128 34 | layer_nums: [3,5,5] 35 | layer_strides:[2,2,2] 36 | num_filters: [128, 128, 256] 37 | upsample_strides:[2,2,2]#[1,2,4] 38 | num_upsample_filters:[128,64,64] 39 | pool_size : 1 40 | pool_type: "avg_pool" 41 | cycle_constraint : true 42 | # not_use_norm: false 43 | bn_type:"SyncBN" 44 | pred_pyramid_motion: true 45 | # use_sparse_conv: false 46 | conv_type:"mask_conv" 47 | odom_format: "rx+t" #"r(x+t)" 48 | dense_predict: true 49 | dropout: 0.0000000000000000000001 50 | conf_type: "softmax"#"linear" 51 | use_deep_supervision:true 52 | use_svd: false 53 | } 54 | 55 | loss: { 56 | pyloss_exp_w_base:0.5 57 | rotation_loss{ 58 | loss_type: "AdaptiveWeightedL2" 59 | weight: 1, 60 | init_alpha: -2.5 61 | } 62 | translation_loss{ 63 | loss_type: "AdaptiveWeightedL2", 64 | weight: 1, 65 | init_alpha: 0 66 | } 67 | consistency_loss{ 68 | loss_type: "Aleat5_1ChamferL2NormalWeightedALLSVDLoss" 69 | weight: 1, 70 | penalize_ratio: 0.97 71 | norm: false 72 | pred_downsample_ratio: 1 73 | reg_weight: 0.005#0.0005 
74 | sph_weight:1 75 | } 76 | } 77 | num_point_features: 7 # model's num point feature should be independent of dataset 78 | # Outputs 79 | use_sigmoid_score: true 80 | encode_background_as_zeros: true 81 | } 82 | } 83 | 84 | 85 | eval_input_reader: { 86 | dataset: { 87 | dataset_class_name: "KittiDatasetHDF5" 88 | # dataset_class_name: "KittiDataset" 89 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/eval_data_info4.pkl" #TODO: 90 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5" 91 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5" 92 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset" 93 | seq_length: 2 94 | skip: 1 95 | step:1 96 | random_skip: false 97 | } 98 | batch_size: 1 99 | preprocess: { 100 | max_number_of_voxels: 40000 101 | shuffle_points: false 102 | num_workers: 1 103 | anchor_area_threshold: -1 104 | remove_environment: false 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /config/kitti_eval_ours_moredata.prototxt: -------------------------------------------------------------------------------- 1 | model: { 2 | second: { 3 | use_GN: false#true 4 | icp_iter: 2 5 | # sync_bn: true 6 | network_class_name: "UnVoxelOdomNetICP3"#"UnVoxelOdomNetICP"#"UnVoxelOdomNetICP2"#"VoxelOdomNet" 7 | voxel_generator { 8 | # point_cloud_range : [0, -40, -3, 70.4, 40, 8] # [x0,y0,z0, x1,y1,z1] 9 | #point_cloud_range : [-68.8, -40, -3, 68.8, 40, 5] # [x0,y0,z0, x1,y1,z1] 10 | point_cloud_range : [-70.4, -38.4, -3, 70.4, 38.4, 5] # [x0,y0,z0, x1,y1,z1] 11 | # point_cloud_range : [0, -32.0, -3, 52.8, 32.0, 1] 12 | # voxel_size : [0.05, 0.05, 0.2] 13 | voxel_size : [0.1, 0.1, 0.2] 14 | max_number_of_points_per_voxel : 10#5 15 | block_factor :1 16 | block_size : 8 17 | height_threshold : -1#0.05 18 | } 19 | 20 | voxel_feature_extractor: { 21 | module_class_name: "SimpleVoxel_XYZINormalC"#"SimpleVoxel" 22 | 
num_filters: [16] 23 | with_distance: false 24 | num_input_features: 7#8 25 | not_use_norm: true 26 | 27 | } 28 | middle_feature_extractor: { 29 | module_class_name: "SpMiddleFHDWithCov2_3"#"SpMiddleFHDWithConf4_1"#"SpMiddleFHD" 30 | # num_filters_down1: [] # protobuf don't support empty list. 31 | # num_filters_down2: [] 32 | downsample_factor: 8 33 | num_input_features: 7#8 34 | # not_use_norm: true 35 | bn_type:"None" 36 | #not_use_norm: false 37 | use_leakyReLU: true 38 | } 39 | 40 | odom_predictor:{ 41 | module_class_name : "UNRResNetOdomPredEncDecSVDTempMask" #"UNRResNetOdomPredEncDecSVD"#"RResNetOdomPredEncDec" 42 | num_input_features :128 43 | layer_nums: [3,5,5] 44 | layer_strides:[2,2,2] 45 | num_filters: [128, 128, 256] 46 | upsample_strides:[2,2,2]#[1,2,4] 47 | num_upsample_filters:[128,64,64] #[256, 256, 256] 48 | pool_size : 1 49 | pool_type: "avg_pool" 50 | cycle_constraint : true 51 | # not_use_norm: false 52 | bn_type:"SyncBN" 53 | pred_pyramid_motion: true 54 | # use_sparse_conv: false 55 | conv_type:"mask_conv" 56 | odom_format: "rx+t"#"r(x+t)" 57 | dense_predict: true 58 | dropout: 0.0000000000000000000001#0.2 59 | conf_type: "softmax"#"linear" 60 | use_deep_supervision:true 61 | use_svd: false 62 | } 63 | 64 | loss: { 65 | pyloss_exp_w_base:0.5 66 | rotation_loss{ 67 | loss_type: "AdaptiveWeightedL2" #"AdaptiveWeightedL2RMatrixLoss"# #"AdaptiveWeightedL2", 68 | weight: 1#0, 69 | init_alpha: -2.5 70 | } 71 | translation_loss{ 72 | loss_type: "AdaptiveWeightedL2", 73 | weight: 1, 74 | init_alpha: 0 75 | } 76 | consistency_loss{ 77 | loss_type: "Aleat5_1ChamferL2NormalWeightedALLSVDLoss"#"ChamferL2NormalWeightedALLSVDLoss"#"ChamferL2NormalWeightedALLLoss"#"ChamferL2Loss"#"CosineDistance", 78 | weight: 1,#10#1, 79 | penalize_ratio: 0.97#0.99#0.9999#0.95#0.9999 80 | norm: false 81 | pred_downsample_ratio: 1 82 | reg_weight: 0.005#0.0005 83 | 84 | sph_weight:1 85 | # sample_block_size: [0.5,1,1] 86 | } 87 | } 88 | num_point_features: 7#8#4 # model's 
num point feature should be independent of dataset 89 | # Outputs 90 | use_sigmoid_score: true 91 | encode_background_as_zeros: true 92 | 93 | 94 | # Loss 95 | 96 | 97 | # Postprocess 98 | post_center_limit_range: [0, -40, -2.2, 70.4, 40, 0.8] 99 | 100 | } 101 | } 102 | 103 | train_input_reader: { 104 | dataset: { 105 | dataset_class_name: "KittiDatasetHDF5" 106 | # dataset_class_name: "KittiDataset" 107 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_data_info7.pkl" 108 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5" 109 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5" 110 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset" 111 | seq_length: 3 112 | skip: 1 113 | step:1 114 | random_skip: false#true 115 | } 116 | 117 | batch_size: 1#2#3#1#3 #8 118 | preprocess: { 119 | max_number_of_voxels: 40000#17000 120 | shuffle_points: false#true 121 | num_workers: 2#2 122 | # groundtruth_localization_noise_std: [1.0, 1.0, 0.5] 123 | # groundtruth_rotation_uniform_noise: [-0.3141592654, 0.3141592654] 124 | # groundtruth_rotation_uniform_noise: [-1.57, 1.57] 125 | # groundtruth_rotation_uniform_noise: [-0.78539816, 0.78539816] 126 | # global_rotation_uniform_noise: [-0.78539816, 0.78539816] 127 | # global_scaling_uniform_noise: [0.95, 1.05] 128 | # global_random_rotation_range_per_object: [0, 0] # pi/4 ~ 3pi/4 129 | # global_translate_noise_std: [0, 0, 0] 130 | # anchor_area_threshold: -1 131 | remove_points_after_sample: true 132 | groundtruth_points_drop_percentage: 0.0 133 | groundtruth_drop_max_keep_points: 15 134 | # remove_unknown_examples: false 135 | # sample_importance: 1.0 136 | random_flip_x: false 137 | random_flip_y: true#false 138 | # remove_environment: false 139 | # downsample_voxel_sizes: [0.05,0.1,0.2,0.4] 140 | downsample_voxel_sizes: [0.1] #[0.1,0.2,0.4, 0.8] 141 | 142 | } 143 | } 144 | 145 | train_config: { 146 | optimizer: { 147 | 
adam_optimizer: { 148 | learning_rate: { 149 | one_cycle: { 150 | lr_max: 0.8e-3#1e-3#2e-3 151 | moms: [0.95, 0.85] 152 | div_factor: 10.0 153 | pct_start: 0.05 154 | } 155 | } 156 | weight_decay: 1e-5#0.001 157 | } 158 | fixed_weight_decay: true 159 | use_moving_average: false 160 | } 161 | #optimizer: { 162 | # adam_optimizer: { 163 | # learning_rate: { 164 | # exponential_decay: { 165 | # initial_learning_rate: 0.002#0.008#0.002 166 | # decay_length: 0.05#0.1#0.1 167 | # decay_factor: 0.8#0.8 168 | # staircase: True 169 | # } 170 | # } 171 | # weight_decay: 1e-6#0.0001 172 | # } 173 | # fixed_weight_decay: false 174 | # use_moving_average: false 175 | #} 176 | # steps: 99040 # 1238 * 120 177 | # s: 49520 # 619 * 80 178 | # steps: 30950 # 619 * 80 179 | # steps_per_eval: 3095 # 619 * 5 180 | steps: 200000#200000#42500#170000#12750#23200*20/8/4 #23200 # 464 * 50 181 | steps_per_eval: 4000#425#1700#850#637#23200/8#2320 # 619 * 5 182 | 183 | # save_checkpoints_secs : 1800 # half hour 184 | # save_summary_steps : 10 185 | enable_mixed_precision: false 186 | loss_scale_factor: -1 187 | clear_metrics_every_epoch: true 188 | } 189 | 190 | eval_input_reader: { 191 | dataset: { 192 | dataset_class_name: "KittiDatasetHDF5" 193 | # dataset_class_name: "KittiDataset" 194 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/eval_data_info4.pkl" #TODO: 195 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5" 196 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5" 197 | # kitti_info_path: "/media/yy/960evo/datasets/kitti/kitti_infos_test.pkl" 198 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset" 199 | seq_length: 2 200 | skip: 1 201 | step:1 202 | random_skip: false 203 | } 204 | batch_size: 1 205 | preprocess: { 206 | max_number_of_voxels: 40000 207 | shuffle_points: false 208 | num_workers: 1 209 | anchor_area_threshold: -1 210 | remove_environment: false 211 | } 212 | 
} 213 | eval_train_input_reader: { 214 | dataset: { 215 | dataset_class_name: "KittiDatasetHDF5" 216 | # dataset_class_name: "KittiDataset" 217 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/eval_train_data_info.pkl" #TODO: 218 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5" 219 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5" 220 | # kitti_info_path: "/media/yy/960evo/datasets/kitti/kitti_infos_test.pkl" 221 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset" 222 | seq_length: 2 223 | skip:1 224 | step:1 225 | random_skip: false 226 | } 227 | batch_size: 1 228 | preprocess: { 229 | max_number_of_voxels: 40000 230 | shuffle_points: false 231 | num_workers: 1 232 | } 233 | } -------------------------------------------------------------------------------- /config/kitti_train_ours.prototxt: -------------------------------------------------------------------------------- 1 | model: { 2 | second: { 3 | use_GN: false#true 4 | icp_iter: 2 5 | # sync_bn: true 6 | network_class_name: "UnVoxelOdomNetICP3"#"UnVoxelOdomNetICP"#"UnVoxelOdomNetICP2"#"VoxelOdomNet" 7 | voxel_generator { 8 | # point_cloud_range : [0, -40, -3, 70.4, 40, 8] # [x0,y0,z0, x1,y1,z1] 9 | #point_cloud_range : [-68.8, -40, -3, 68.8, 40, 5] # [x0,y0,z0, x1,y1,z1] 10 | point_cloud_range : [-70.4, -38.4, -3, 70.4, 38.4, 5] # [x0,y0,z0, x1,y1,z1] 11 | # point_cloud_range : [0, -32.0, -3, 52.8, 32.0, 1] 12 | # voxel_size : [0.05, 0.05, 0.2] 13 | voxel_size : [0.1, 0.1, 0.2] 14 | max_number_of_points_per_voxel : 10#5 15 | block_factor :1 16 | block_size : 8 17 | height_threshold : -1#0.05 18 | } 19 | 20 | voxel_feature_extractor: { 21 | module_class_name: "SimpleVoxel_XYZINormalC"#"SimpleVoxel" 22 | num_filters: [16] 23 | with_distance: false 24 | num_input_features: 7#8 25 | not_use_norm: true 26 | 27 | } 28 | middle_feature_extractor: { 29 | module_class_name: 
"SpMiddleFHDWithCov2_3"#"SpMiddleFHDWithConf4_1"#"SpMiddleFHD" 30 | # num_filters_down1: [] # protobuf don't support empty list. 31 | # num_filters_down2: [] 32 | downsample_factor: 8 33 | num_input_features: 7#8 34 | # not_use_norm: true 35 | bn_type:"None" 36 | #not_use_norm: false 37 | use_leakyReLU: true 38 | } 39 | 40 | odom_predictor:{ 41 | module_class_name : "UNRResNetOdomPredEncDecSVDTempMask" #"UNRResNetOdomPredEncDecSVD"#"RResNetOdomPredEncDec" 42 | num_input_features :128 43 | layer_nums: [3,5,5] 44 | layer_strides:[2,2,2] 45 | num_filters: [128, 128, 256] 46 | upsample_strides:[2,2,2]#[1,2,4] 47 | num_upsample_filters:[128,64,64] #[256, 256, 256] 48 | pool_size : 1 49 | pool_type: "avg_pool" 50 | cycle_constraint : true 51 | # not_use_norm: false 52 | bn_type:"SyncBN" 53 | pred_pyramid_motion: true 54 | # use_sparse_conv: false 55 | conv_type:"mask_conv" 56 | odom_format: "rx+t"#"r(x+t)" 57 | dense_predict: true 58 | dropout: 0.0000000000000000000001#0.2 59 | conf_type: "softmax"#"linear" 60 | use_deep_supervision:true 61 | use_svd: false 62 | } 63 | 64 | loss: { 65 | pyloss_exp_w_base:0.5 66 | rotation_loss{ 67 | loss_type: "AdaptiveWeightedL2" #"AdaptiveWeightedL2RMatrixLoss"# #"AdaptiveWeightedL2", 68 | weight: 1#0, 69 | init_alpha: -2.5 70 | } 71 | translation_loss{ 72 | loss_type: "AdaptiveWeightedL2", 73 | weight: 1, 74 | init_alpha: 0 75 | } 76 | consistency_loss{ 77 | loss_type: "Aleat5_1ChamferL2NormalWeightedALLSVDLoss"#"ChamferL2NormalWeightedALLSVDLoss"#"ChamferL2NormalWeightedALLLoss"#"ChamferL2Loss"#"CosineDistance", 78 | weight: 1,#10#1, 79 | penalize_ratio: 0.97#0.99#0.9999#0.95#0.9999 80 | norm: false 81 | pred_downsample_ratio: 1 82 | reg_weight: 0.005#0.0005 83 | 84 | sph_weight:1 85 | # sample_block_size: [0.5,1,1] 86 | } 87 | } 88 | num_point_features: 7#8#4 # model's num point feature should be independent of dataset 89 | # Outputs 90 | use_sigmoid_score: true 91 | encode_background_as_zeros: true 92 | 93 | 94 | # Loss 95 | 96 | 
97 | # Postprocess 98 | post_center_limit_range: [0, -40, -2.2, 70.4, 40, 0.8] 99 | 100 | } 101 | } 102 | 103 | train_input_reader: { 104 | dataset: { 105 | dataset_class_name: "KittiDatasetHDF5" 106 | # dataset_class_name: "KittiDataset" 107 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_data_info7.pkl" 108 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5" 109 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5" 110 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset" 111 | seq_length: 3 112 | skip: 1 113 | step:1 114 | random_skip: false#true 115 | } 116 | 117 | batch_size: 1#2#3#1#3 #8 118 | preprocess: { 119 | max_number_of_voxels: 40000#17000 120 | shuffle_points: false#true 121 | num_workers: 2#2 122 | # groundtruth_localization_noise_std: [1.0, 1.0, 0.5] 123 | # groundtruth_rotation_uniform_noise: [-0.3141592654, 0.3141592654] 124 | # groundtruth_rotation_uniform_noise: [-1.57, 1.57] 125 | # groundtruth_rotation_uniform_noise: [-0.78539816, 0.78539816] 126 | # global_rotation_uniform_noise: [-0.78539816, 0.78539816] 127 | # global_scaling_uniform_noise: [0.95, 1.05] 128 | # global_random_rotation_range_per_object: [0, 0] # pi/4 ~ 3pi/4 129 | # global_translate_noise_std: [0, 0, 0] 130 | # anchor_area_threshold: -1 131 | remove_points_after_sample: true 132 | groundtruth_points_drop_percentage: 0.0 133 | groundtruth_drop_max_keep_points: 15 134 | # remove_unknown_examples: false 135 | # sample_importance: 1.0 136 | random_flip_x: false 137 | random_flip_y: true#false 138 | # remove_environment: false 139 | # downsample_voxel_sizes: [0.05,0.1,0.2,0.4] 140 | downsample_voxel_sizes: [0.1] #[0.1,0.2,0.4, 0.8] 141 | 142 | } 143 | } 144 | 145 | train_config: { 146 | optimizer: { 147 | adam_optimizer: { 148 | learning_rate: { 149 | one_cycle: { 150 | lr_max: 0.8e-3#1e-3#2e-3 151 | moms: [0.95, 0.85] 152 | div_factor: 10.0 153 | pct_start: 0.05 154 | 
} 155 | } 156 | weight_decay: 1e-5#0.001 157 | } 158 | fixed_weight_decay: true 159 | use_moving_average: false 160 | } 161 | #optimizer: { 162 | # adam_optimizer: { 163 | # learning_rate: { 164 | # exponential_decay: { 165 | # initial_learning_rate: 0.002#0.008#0.002 166 | # decay_length: 0.05#0.1#0.1 167 | # decay_factor: 0.8#0.8 168 | # staircase: True 169 | # } 170 | # } 171 | # weight_decay: 1e-6#0.0001 172 | # } 173 | # fixed_weight_decay: false 174 | # use_moving_average: false 175 | #} 176 | # steps: 99040 # 1238 * 120 177 | # s: 49520 # 619 * 80 178 | # steps: 30950 # 619 * 80 179 | # steps_per_eval: 3095 # 619 * 5 180 | steps: 200000#200000#42500#170000#12750#23200*20/8/4 #23200 # 464 * 50 181 | steps_per_eval: 4000#425#1700#850#637#23200/8#2320 # 619 * 5 182 | 183 | # save_checkpoints_secs : 1800 # half hour 184 | # save_summary_steps : 10 185 | enable_mixed_precision: false 186 | loss_scale_factor: -1 187 | clear_metrics_every_epoch: true 188 | } 189 | 190 | eval_input_reader: { 191 | dataset: { 192 | dataset_class_name: "KittiDatasetHDF5" 193 | # dataset_class_name: "KittiDataset" 194 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/eval_data_info4.pkl" #TODO: 195 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5" 196 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5" 197 | # kitti_info_path: "/media/yy/960evo/datasets/kitti/kitti_infos_test.pkl" 198 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset" 199 | seq_length: 2 200 | skip: 1 201 | step:1 202 | random_skip: false 203 | } 204 | batch_size: 1 205 | preprocess: { 206 | max_number_of_voxels: 40000 207 | shuffle_points: false 208 | num_workers: 1 209 | anchor_area_threshold: -1 210 | remove_environment: false 211 | } 212 | } 213 | eval_train_input_reader: { 214 | dataset: { 215 | dataset_class_name: "KittiDatasetHDF5" 216 | # dataset_class_name: "KittiDataset" 217 | # kitti_info_path: 
"/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/eval_train_data_info.pkl" #TODO: 218 | # kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/train_val_hier.h5" 219 | kitti_info_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset/all.h5" 220 | # kitti_info_path: "/media/yy/960evo/datasets/kitti/kitti_infos_test.pkl" 221 | kitti_root_path: "/mnt/lustre/xuyan2/datasets/kitti/odometry/dataset" 222 | seq_length: 2 223 | skip:1 224 | step:1 225 | random_skip: false 226 | } 227 | batch_size: 1 228 | preprocess: { 229 | max_number_of_voxels: 40000 230 | shuffle_points: false 231 | num_workers: 1 232 | } 233 | } -------------------------------------------------------------------------------- /demo/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/demo/framework.png -------------------------------------------------------------------------------- /demo/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/demo/output.gif -------------------------------------------------------------------------------- /demo/pointcov.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/demo/pointcov.png -------------------------------------------------------------------------------- /demo/traj+pointcov.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/demo/traj+pointcov.png -------------------------------------------------------------------------------- /demo/traj.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/demo/traj.png -------------------------------------------------------------------------------- /doc/train.md: -------------------------------------------------------------------------------- 1 | ## Coming Soon! -------------------------------------------------------------------------------- /rslo/builder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DecaYale/RSLO/9ef3689bcc2baa71aea5106e79dd5e6a047ce707/rslo/builder/__init__.py -------------------------------------------------------------------------------- /rslo/builder/dataset_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Input reader builder. 16 | 17 | Creates data sources for DetectionModels from an InputReader config. See 18 | input_reader.proto for options. 19 | 20 | Note: If users wish to also use their own InputReaders with the Object 21 | Detection configuration framework, they should define their own builder function 22 | that wraps the build function.
23 | """ 24 | 25 | from rslo.protos import input_reader_pb2 26 | from rslo.data.dataset import get_dataset_class 27 | from rslo.data.preprocess import prep_pointcloud 28 | import numpy as np 29 | from functools import partial 30 | # from rslo.utils.config_tool import get_downsample_factor 31 | 32 | 33 | def build(input_reader_config, 34 | model_config, 35 | training, 36 | voxel_generator, 37 | multi_gpu=False, 38 | use_dist=False, 39 | split=None, 40 | use_hdf5=False 41 | ): 42 | """Builds a tensor dictionary based on the InputReader config. 43 | 44 | Args: 45 | input_reader_config: A input_reader_pb2.InputReader object. 46 | 47 | Returns: 48 | A tensor dict based on the input_reader_config. 49 | 50 | Raises: 51 | ValueError: On invalid input reader proto. 52 | ValueError: If no input paths are specified. 53 | """ 54 | if not isinstance(input_reader_config, input_reader_pb2.InputReader): 55 | raise ValueError('input_reader_config not of type ' 56 | 'input_reader_pb2.InputReader.') 57 | prep_cfg = input_reader_config.preprocess 58 | dataset_cfg = input_reader_config.dataset 59 | num_point_features = model_config.num_point_features 60 | cfg = input_reader_config 61 | db_sampler = None 62 | # if len(db_sampler_cfg.sample_groups) > 0 or db_sampler_cfg.database_info_path != "": # enable sample 63 | # db_sampler = dbsampler_builder.build(db_sampler_cfg) 64 | grid_size = voxel_generator.grid_size 65 | # feature_map_size = grid_size[:2] // out_size_factor 66 | # feature_map_size = [*feature_map_size, 1][::-1] 67 | 68 | dataset_cls = get_dataset_class(dataset_cfg.dataset_class_name) 69 | # assert dataset_cls.NumPointFeatures >= 3, "you must set this to correct value" 70 | # assert dataset_cls.NumPointFeatures == num_point_features, "currently you need keep them same" 71 | 72 | prep_func = partial( 73 | prep_pointcloud, 74 | root_path=dataset_cfg.kitti_root_path, 75 | voxel_generator=voxel_generator, 76 | # target_assigner=target_assigner, 77 | training=training, 78 | 
max_voxels=prep_cfg.max_number_of_voxels, 79 | shuffle_points=prep_cfg.shuffle_points, 80 | 81 | num_point_features=num_point_features, # dataset_cls.NumPointFeatures, 82 | 83 | # out_size_factor=out_size_factor, 84 | multi_gpu=multi_gpu, 85 | use_dist=use_dist, 86 | min_points_in_gt=prep_cfg.min_num_of_points_in_gt, 87 | random_flip_x=prep_cfg.random_flip_x, 88 | random_flip_y=prep_cfg.random_flip_y, 89 | rand_aug_ratio=prep_cfg.random_aug_ratio, 90 | sample_importance=prep_cfg.sample_importance, 91 | rand_rotation_eps=prep_cfg.rand_rotation_eps, 92 | rand_translation_eps=prep_cfg.rand_translation_eps, 93 | gen_tq_map=model_config.odom_predictor.pred_pyramid_motion, # !!! 94 | do_pre_transform=prep_cfg.do_pre_transform, 95 | cubic_tq_map=prep_cfg.cubic_tq_map, 96 | downsample_voxel_sizes=list(prep_cfg.downsample_voxel_sizes) 97 | ) 98 | 99 | dataset = dataset_cls( 100 | info_path=dataset_cfg.kitti_info_path, 101 | root_path=dataset_cfg.kitti_root_path, 102 | seq_length=dataset_cfg.seq_length, 103 | skip=dataset_cfg.skip, 104 | random_skip=dataset_cfg.random_skip, 105 | prep_func=prep_func, 106 | step=dataset_cfg.step, 107 | num_point_features=num_point_features, 108 | split=split, 109 | ) 110 | 111 | return dataset 112 | -------------------------------------------------------------------------------- /rslo/builder/input_reader_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Input reader builder. 16 | 17 | Creates data sources for DetectionModels from an InputReader config. See 18 | input_reader.proto for options. 19 | 20 | Note: If users wish to also use their own InputReaders with the Object 21 | Detection configuration framework, they should define their own builder function 22 | that wraps the build function. 23 | """ 24 | 25 | from torch.utils.data import Dataset 26 | 27 | from rslo.builder import dataset_builder 28 | from rslo.protos import input_reader_pb2 29 | 30 | 31 | class DatasetWrapper(Dataset): 32 | """ convert our dataset to Dataset class in pytorch. 33 | """ 34 | 35 | def __init__(self, dataset): 36 | self._dataset = dataset 37 | 38 | def __len__(self): 39 | return len(self._dataset) 40 | 41 | def __getitem__(self, idx): 42 | return self._dataset[idx] 43 | 44 | @property 45 | def dataset(self): 46 | return self._dataset 47 | 48 | 49 | def build(input_reader_config, 50 | model_config, 51 | training, 52 | voxel_generator, 53 | # target_assigner=None, 54 | multi_gpu=False, 55 | use_dist=False, 56 | split=None, 57 | ) -> DatasetWrapper: 58 | """ 59 | Builds a tensor dictionary based on the InputReader config. 60 | 61 | Args: 62 | input_reader_config: A input_reader_pb2.InputReader object. 63 | 64 | Returns: 65 | A tensor dict based on the input_reader_config. 66 | 67 | Raises: 68 | ValueError: On invalid input reader proto. 69 | ValueError: If no input paths are specified.
70 | """ 71 | if not isinstance(input_reader_config, input_reader_pb2.InputReader): 72 | raise ValueError('input_reader_config not of type ' 73 | 'input_reader_pb2.InputReader.') 74 | 75 | dataset = dataset_builder.build( 76 | input_reader_config, 77 | model_config, 78 | training, 79 | voxel_generator, 80 | # target_assigner, 81 | multi_gpu=multi_gpu, 82 | use_dist=use_dist, 83 | split=split 84 | ) 85 | dataset = DatasetWrapper(dataset) 86 | return dataset 87 | -------------------------------------------------------------------------------- /rslo/builder/losses_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A function to build localization and classification losses from config.""" 17 | 18 | from rslo.core import losses 19 | # from rslo.core.ghm_loss import GHMCLoss, GHMRLoss 20 | from rslo.protos import losses_pb2 21 | 22 | 23 | def build(loss_config): 24 | """Build losses based on the config. 25 | 26 | Builds classification, localization losses and optionally a hard example miner 27 | based on the config. 28 | 29 | Args: 30 | loss_config: A losses_pb2.Loss object. 
31 | 32 | """ 33 | 34 | 35 | rotation_loss_func = _build_rotation_loss( 36 | loss_config.rotation_loss) 37 | translation_loss_func = _build_translation_loss( 38 | loss_config.translation_loss) 39 | 40 | if loss_config.pyramid_rotation_loss.loss_type != '': 41 | pyramid_rotation_loss = _build_rotation_loss( 42 | loss_config.pyramid_rotation_loss) 43 | else: 44 | pyramid_rotation_loss = rotation_loss_func 45 | 46 | if loss_config.pyramid_translation_loss.loss_type != '': 47 | pyramid_translation_loss = _build_translation_loss( 48 | loss_config.pyramid_translation_loss) 49 | else: 50 | pyramid_translation_loss = translation_loss_func 51 | 52 | consistency_loss = _build_consistency_loss(loss_config.consistency_loss) 53 | 54 | if loss_config.rigid_transform_loss.weight != 0: 55 | rigid_transform_loss = losses.RigidTransformLoss( 56 | rotation_loss_func, translation_loss_func, focal_gamma=loss_config.rigid_transform_loss.focal_gamma) 57 | py_rigid_transform_loss = losses.RigidTransformLoss( 58 | pyramid_rotation_loss, pyramid_translation_loss, focal_gamma=loss_config.rigid_transform_loss.focal_gamma) 59 | 60 | return (rigid_transform_loss, py_rigid_transform_loss, consistency_loss) 61 | 62 | return (rotation_loss_func, translation_loss_func, 63 | pyramid_rotation_loss, pyramid_translation_loss, consistency_loss, 64 | ) 65 | 66 | 67 | def _build_translation_loss(loss_config): 68 | """Builds a translation loss function based on the loss config. 69 | 70 | Args: 71 | loss_config: A losses_pb2.TranslationLoss object. 72 | 73 | Returns: 74 | Loss based on the config. 75 | 76 | Raises: 77 | ValueError: On invalid loss_config. 
78 | """ 79 | if not isinstance(loss_config, losses_pb2.TranslationLoss): 80 | raise ValueError( 81 | 'loss_config not of type losses_pb2.TranslationLoss.') 82 | 83 | # loss_config.WhichOneof('localization_loss') 84 | loss_type = loss_config.loss_type 85 | loss_weight = loss_config.weight 86 | focal_gamma = loss_config.focal_gamma 87 | if loss_type == 'L2': 88 | return losses.L2Loss(loss_weight) 89 | elif loss_type == 'AdaptiveWeightedL2': 90 | if loss_config.balance_scale<=0: 91 | loss_config.balance_scale=1 92 | assert loss_config.balance_scale>0 93 | return losses.AdaptiveWeightedL2Loss(loss_config.init_alpha, learn_alpha=not loss_config.not_learn_alpha, loss_weight=loss_weight, focal_gamma=focal_gamma, balance_scale=loss_config.balance_scale) 94 | else: 95 | raise ValueError('Empty loss config.') 96 | 97 | 98 | def _build_rotation_loss(loss_config): 99 | """Builds a classification loss based on the loss config. 100 | 101 | Args: 102 | loss_config: A losses_pb2.RotationLoss object. 103 | 104 | Returns: 105 | Loss based on the config. 106 | 107 | Raises: 108 | ValueError: On invalid loss_config. 
109 | """ 110 | if not isinstance(loss_config, losses_pb2.RotaionLoss): 111 | raise ValueError( 112 | 'loss_config not of type losses_pb2.RotaionLoss.') 113 | 114 | 115 | loss_type = loss_config.loss_type 116 | loss_weight = loss_config.weight 117 | focal_gamma = loss_config.focal_gamma 118 | 119 | if loss_type == 'L2': 120 | return losses.L2Loss(loss_weight) 121 | elif loss_type == 'AdaptiveWeightedL2': 122 | if loss_config.balance_scale<=0: 123 | loss_config.balance_scale=1 124 | assert loss_config.balance_scale>0 125 | return losses.AdaptiveWeightedL2Loss(loss_config.init_alpha, learn_alpha=not loss_config.not_learn_alpha, loss_weight=loss_weight, focal_gamma=focal_gamma, balance_scale=loss_config.balance_scale) 126 | elif loss_type == 'AdaptiveWeightedL2RMatrixLoss': 127 | return losses.AdaptiveWeightedL2RMatrixLoss(loss_config.init_alpha, learn_alpha=not loss_config.not_learn_alpha, loss_weight=loss_weight, focal_gamma=focal_gamma) 128 | 129 | raise ValueError('Empty loss config.') 130 | 131 | 132 | def _build_consistency_loss(loss_config): 133 | if not isinstance(loss_config, losses_pb2.ConsistencyLoss): 134 | raise ValueError( 135 | 'loss_config not of type losses_pb2.ConsistencyLoss.') 136 | 137 | loss_type = loss_config.loss_type 138 | loss_weight = loss_config.weight 139 | 140 | if loss_type=='AdaptiveWeightedL2': 141 | return losses.AdaptiveWeightedL2Loss(loss_config.init_alpha, learn_alpha=not loss_config.not_learn_alpha, loss_weight=loss_weight, focal_gamma=loss_config.focal_gamma, balance_scale=loss_config.balance_scale ) 142 | elif loss_type=='Aleat5_1ChamferL2NormalWeightedALLSVDLoss': 143 | assert loss_config.penalize_ratio>0 144 | assert loss_config.pred_downsample_ratio>0 145 | assert loss_config.reg_weight >0 146 | assert loss_config.sph_weight >0 147 | return losses.Aleat5_1ChamferL2NormalWeightedALLSVDLoss(loss_weight=loss_weight, penalize_ratio=loss_config.penalize_ratio, sample_block_size=loss_config.sample_block_size, norm=loss_config.norm, 
pred_downsample_ratio=loss_config.pred_downsample_ratio,reg_weight=loss_config.reg_weight, sph_weight=loss_config.sph_weight) 148 | else: 149 | print('Warning: Empty loss config.') 150 | return None 151 | 152 | -------------------------------------------------------------------------------- /rslo/builder/lr_scheduler_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions to build DetectionModel training optimizers.""" 17 | 18 | from torchplus.train import learning_schedules_fastai as lsf 19 | import torch 20 | import numpy as np 21 | 22 | def build(optimizer_config, optimizer, total_step): 23 | """Create lr scheduler based on config. note that 24 | lr_scheduler must accept a optimizer that has been restored. 25 | 26 | Args: 27 | optimizer_config: A Optimizer proto message. 28 | 29 | Returns: 30 | An optimizer and a list of variables for summary. 31 | 32 | Raises: 33 | ValueError: when using an unsupported input data type. 
34 | """ 35 | optimizer_type = optimizer_config.WhichOneof('optimizer') 36 | 37 | if optimizer_type == 'rms_prop_optimizer': 38 | config = optimizer_config.rms_prop_optimizer 39 | lr_scheduler = _create_learning_rate_scheduler( 40 | config.learning_rate, optimizer, total_step=total_step) 41 | 42 | if optimizer_type == 'momentum_optimizer': 43 | config = optimizer_config.momentum_optimizer 44 | lr_scheduler = _create_learning_rate_scheduler( 45 | config.learning_rate, optimizer, total_step=total_step) 46 | 47 | if optimizer_type == 'adam_optimizer': 48 | config = optimizer_config.adam_optimizer 49 | lr_scheduler = _create_learning_rate_scheduler( 50 | config.learning_rate, optimizer, total_step=total_step) 51 | 52 | return lr_scheduler 53 | 54 | 55 | def _create_learning_rate_scheduler(learning_rate_config, optimizer, total_step): 56 | """Create optimizer learning rate scheduler based on config. 57 | 58 | Args: 59 | learning_rate_config: A LearningRate proto message. 60 | 61 | Returns: 62 | A learning rate. 63 | 64 | Raises: 65 | ValueError: when using an unsupported input data type. 
66 | """ 67 | lr_scheduler = None 68 | learning_rate_type = learning_rate_config.WhichOneof('learning_rate') 69 | if learning_rate_type == 'multi_phase': 70 | config = learning_rate_config.multi_phase 71 | lr_phases = [] 72 | mom_phases = [] 73 | for phase_cfg in config.phases: 74 | lr_phases.append((phase_cfg.start, phase_cfg.lambda_func)) 75 | mom_phases.append( 76 | (phase_cfg.start, phase_cfg.momentum_lambda_func)) 77 | lr_scheduler = lsf.LRSchedulerStep( 78 | optimizer, total_step, lr_phases, mom_phases) 79 | 80 | 81 | 82 | if learning_rate_type == 'one_cycle': 83 | config = learning_rate_config.one_cycle 84 | 85 | if len(config.lr_maxs)>1: 86 | assert(len(config.lr_maxs)==4 ) 87 | lr_max=[] 88 | # for i in range(len(config.lr_maxs)): 89 | # lr_max += [config.lr_maxs[i]]*optimizer.param_segs[i] 90 | 91 | lr_max = np.array(list(config.lr_maxs) ) 92 | else: 93 | lr_max = config.lr_max 94 | 95 | lr_scheduler = lsf.OneCycle( 96 | optimizer, total_step, lr_max, list(config.moms), config.div_factor, config.pct_start) 97 | if learning_rate_type == 'exponential_decay': 98 | config = learning_rate_config.exponential_decay 99 | lr_scheduler = lsf.ExponentialDecay( 100 | optimizer, total_step, config.initial_learning_rate, config.decay_length, config.decay_factor, config.staircase) 101 | if learning_rate_type == 'exponential_decay_warmup': 102 | config = learning_rate_config.exponential_decay_warmup 103 | lr_scheduler = lsf.ExponentialDecayWarmup( 104 | optimizer, total_step, config.initial_learning_rate, config.decay_length, config.decay_factor, config.div_factor, 105 | config.pct_start, config.staircase) 106 | if learning_rate_type == 'manual_stepping': 107 | config = learning_rate_config.manual_stepping 108 | lr_scheduler = lsf.ManualStepping( 109 | optimizer, total_step, list(config.boundaries), list(config.rates)) 110 | 111 | if lr_scheduler is None: 112 | raise ValueError('Learning_rate %s not supported.' 
% 113 | learning_rate_type) 114 | 115 | return lr_scheduler 116 | -------------------------------------------------------------------------------- /rslo/builder/optimizer_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Functions to build DetectionModel training optimizers.""" 16 | 17 | from torchplus.train import learning_schedules 18 | from torchplus.train import optim 19 | import torch 20 | from torch import nn 21 | from torchplus.train.fastai_optim import OptimWrapper, FastAIMixedOptim 22 | from functools import partial 23 | 24 | 25 | def children(m: nn.Module): 26 | "Get children of `m`." 27 | return list(m.children()) 28 | 29 | 30 | def num_children(m: nn.Module) -> int: 31 | "Get number of children modules in `m`." 
32 | return len(children(m)) 33 | 34 | # return a list of smallest modules dy 35 | 36 | 37 | def flatten_model(m): 38 | if m is None: 39 | return [] 40 | return sum( 41 | map(flatten_model, m.children()), []) if num_children(m) else [m] 42 | 43 | 44 | # def get_layer_groups(m): return [nn.Sequential(*flatten_model(m))] 45 | def get_layer_groups(m): return [nn.ModuleList(flatten_model(m))] 46 | 47 | 48 | def get_voxeLO_net_layer_groups(net): 49 | vfe_grp = get_layer_groups(net.voxel_feature_extractor)#[0] 50 | mfe_grp = get_layer_groups(net.middle_feature_extractor)#[0] 51 | op_grp = get_layer_groups(net.odom_predictor)#[0] 52 | 53 | # other_grp = get_layer_groups(net._rotation_loss) + \ 54 | # get_layer_groups(net._translation_loss) \ 55 | # + get_layer_groups(net._pyramid_rotation_loss) \ 56 | # + get_layer_groups(net._pyramid_translation_loss) \ 57 | # + get_layer_groups(net._consistency_loss)\ 58 | other_grp = get_layer_groups(nn.Sequential(net._rotation_loss, 59 | net._translation_loss, 60 | net._pyramid_rotation_loss, 61 | net._pyramid_translation_loss, 62 | net._consistency_loss, 63 | )) 64 | 65 | return [vfe_grp, mfe_grp, op_grp,other_grp] 66 | 67 | 68 | def build(optimizer_config, net, name=None, mixed=False, loss_scale=512.0): 69 | """Create optimizer based on config. 70 | 71 | Args: 72 | optimizer_config: A Optimizer proto message. 73 | 74 | Returns: 75 | An optimizer and a list of variables for summary. 76 | 77 | Raises: 78 | ValueError: when using an unsupported input data type. 
79 | """ 80 | optimizer_type=optimizer_config.WhichOneof('optimizer') 81 | optimizer=None 82 | 83 | if optimizer_type == 'rms_prop_optimizer': 84 | config=optimizer_config.rms_prop_optimizer 85 | optimizer_func=partial( 86 | torch.optim.RMSprop, 87 | alpha=config.decay, 88 | momentum=config.momentum_optimizer_value, 89 | eps=config.epsilon) 90 | 91 | if optimizer_type == 'momentum_optimizer': 92 | config=optimizer_config.momentum_optimizer 93 | optimizer_func=partial( 94 | torch.optim.SGD, 95 | momentum=config.momentum_optimizer_value, 96 | eps=config.epsilon) 97 | 98 | if optimizer_type == 'adam_optimizer': 99 | config=optimizer_config.adam_optimizer 100 | if optimizer_config.fixed_weight_decay: 101 | optimizer_func=partial( 102 | torch.optim.Adam, betas=(0.9, 0.99), amsgrad=config.amsgrad) 103 | else: 104 | # regular adam 105 | optimizer_func=partial( 106 | torch.optim.Adam, amsgrad=config.amsgrad) 107 | 108 | # optimizer = OptimWrapper(optimizer, true_wd=optimizer_config.fixed_weight_decay, wd=config.weight_decay) 109 | optimizer=OptimWrapper.create( 110 | optimizer_func, 111 | 3e-3, 112 | # get_layer_groups(net), 113 | get_voxeLO_net_layer_groups(net), 114 | wd=config.weight_decay, 115 | true_wd=optimizer_config.fixed_weight_decay, 116 | bn_wd=True) 117 | print(hasattr(optimizer, "_amp_stash"), '_amp_stash') 118 | if optimizer is None: 119 | raise ValueError('Optimizer %s not supported.' 
% optimizer_type) 120 | 121 | if optimizer_config.use_moving_average: 122 | raise ValueError('torch don\'t support moving average') 123 | if name is None: 124 | # assign a name to optimizer for checkpoint system 125 | optimizer.name=optimizer_type 126 | else: 127 | optimizer.name=name 128 | return optimizer 129 | -------------------------------------------------------------------------------- /rslo/builder/preprocess_builder.py: -------------------------------------------------------------------------------- 1 | import second.core.preprocess as prep 2 | 3 | def build_db_preprocess(db_prep_config): 4 | prep_type = db_prep_config.WhichOneof('database_preprocessing_step') 5 | 6 | if prep_type == 'filter_by_difficulty': 7 | cfg = db_prep_config.filter_by_difficulty 8 | return prep.DBFilterByDifficulty(list(cfg.removed_difficulties)) 9 | elif prep_type == 'filter_by_min_num_points': 10 | cfg = db_prep_config.filter_by_min_num_points 11 | return prep.DBFilterByMinNumPoint(dict(cfg.min_num_point_pairs)) 12 | else: 13 | raise ValueError("unknown database prep type") 14 | 15 | -------------------------------------------------------------------------------- /rslo/builder/second_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 yanyan. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def build(model_cfg: second_pb2.VoxelNet, voxel_generator,
          measure_time=False, testing=False):
    """Create the network instance described by ``model_cfg``.

    Args:
        model_cfg: ``second_pb2.VoxelNet`` message. Its
            ``voxel_feature_extractor``, ``middle_feature_extractor``,
            ``odom_predictor`` and ``loss`` sub-messages are read here.
        voxel_generator: object exposing ``grid_size`` (array-like supporting
            ``[::-1].tolist()``) and ``point_cloud_range``; it is also
            forwarded to the network constructor.
        measure_time: forwarded unchanged to the network constructor.
        testing: forwarded unchanged to the network constructor.

    Returns:
        An instance of the class returned by
        ``get_voxelnet_class(model_cfg.network_class_name)``.

    Raises:
        ValueError: if ``model_cfg`` is not a ``second_pb2.VoxelNet`` message,
            or if the loss builder returns neither 3 nor 5 entries (the tuple
            unpacking fails).
    """
    if not isinstance(model_cfg, second_pb2.VoxelNet):
        raise ValueError('model_cfg not of type ' 'second_pb2.VoxelNet.')

    vfe_cfg = model_cfg.voxel_feature_extractor
    middle_cfg = model_cfg.middle_feature_extractor
    odom_cfg = model_cfg.odom_predictor

    vfe_num_filters = list(vfe_cfg.num_filters)
    vfe_with_distance = vfe_cfg.with_distance

    # Dense feature volume shape: [batch=1, z, y, x, channels]; the grid size
    # is stored as (x, y, z), hence the reversal.
    reversed_grid = voxel_generator.grid_size[::-1].tolist()
    dense_shape = [1] + reversed_grid + [vfe_num_filters[-1]]
    pc_range = voxel_generator.point_cloud_range
    print(dense_shape, '!!!', flush=True)

    num_input_features = model_cfg.num_point_features

    # The loss builder returns either three entries (rigid-transform variant)
    # or five entries (separate rotation / translation variant). Loss slots
    # that do not belong to the returned variant stay None.
    losses = losses_builder.build(model_cfg.loss)
    if len(losses) == 3:
        rigid_loss, pyramid_rigid_loss, consistency = losses
        rot_loss = trans_loss = None
        pyramid_rot_loss = pyramid_trans_loss = None
    else:
        (rot_loss, trans_loss, pyramid_rot_loss,
         pyramid_trans_loss, consistency) = losses
        rigid_loss = pyramid_rigid_loss = None

    network_cls = get_voxelnet_class(model_cfg.network_class_name)
    net = network_cls(
        dense_shape,
        pc_range=pc_range,
        num_input_features=num_input_features,
        voxel_generator=voxel_generator,
        measure_time=measure_time,
        testing=testing,
        # --- voxel feature extractor ---
        vfe_class_name=vfe_cfg.module_class_name,
        vfe_num_filters=vfe_num_filters,
        vfe_use_norm=not vfe_cfg.not_use_norm,
        with_distance=vfe_with_distance,
        # --- middle feature extractor ---
        middle_class_name=middle_cfg.module_class_name,
        middle_num_input_features=middle_cfg.num_input_features,
        middle_num_filters_d1=list(middle_cfg.num_filters_down1),
        middle_num_filters_d2=list(middle_cfg.num_filters_down2),
        middle_use_leakyReLU=middle_cfg.use_leakyReLU,
        middle_relu_type=middle_cfg.relu_type,
        middle_bn_type=middle_cfg.bn_type,
        # --- odometry predictor ---
        odom_class_name=odom_cfg.module_class_name,
        odom_num_input_features=odom_cfg.num_input_features,
        odom_layer_nums=odom_cfg.layer_nums,
        odom_layer_strides=odom_cfg.layer_strides,
        odom_num_filters=odom_cfg.num_filters,
        odom_upsample_strides=odom_cfg.upsample_strides,
        odom_num_upsample_filters=odom_cfg.num_upsample_filters,
        odom_pooling_size=odom_cfg.pool_size,
        odom_pooling_type=odom_cfg.pool_type,
        odom_cycle_constraint=odom_cfg.cycle_constraint,
        odom_conv_type=odom_cfg.conv_type,
        odom_format=odom_cfg.odom_format,
        odom_pred_pyramid_motion=odom_cfg.pred_pyramid_motion,
        odom_use_deep_supervision=odom_cfg.use_deep_supervision,
        odom_use_loss_mask=not odom_cfg.not_use_loss_mask,
        odom_use_dynamic_mask=odom_cfg.use_dynamic_mask,
        odom_dense_predict=odom_cfg.dense_predict,
        odom_use_corr=odom_cfg.use_corr,
        odom_dropout=odom_cfg.dropout,
        odom_conf_type=odom_cfg.conf_type,
        odom_use_SPGN=odom_cfg.use_SPGN,
        odom_use_leakyReLU=odom_cfg.use_leakyReLU,
        odom_bn_type=odom_cfg.bn_type,
        odom_enc_use_norm=not odom_cfg.not_use_enc_norm,
        odom_dropout_input=odom_cfg.dropout_input,
        # An unset (0) group count falls back to a single group.
        odom_first_conv_groups=max(1, odom_cfg.first_conv_groups),
        odom_use_se=odom_cfg.odom_use_se,
        odom_use_sa=odom_cfg.odom_use_sa,
        odom_use_svd=odom_cfg.use_svd,
        odom_cubic_pred_height=odom_cfg.cubic_pred_height,
        # --- model-wide switches ---
        freeze_bn=model_cfg.freeze_bn,
        freeze_bn_affine=model_cfg.freeze_bn_affine,
        freeze_bn_start_step=model_cfg.freeze_bn_start_step,
        use_GN=model_cfg.use_GN,
        encode_background_as_zeros=model_cfg.encode_background_as_zeros,
        icp_iter=model_cfg.icp_iter,
        # --- losses ---
        rotation_loss=rot_loss,
        translation_loss=trans_loss,
        pyramid_rotation_loss=pyramid_rot_loss,
        pyramid_translation_loss=pyramid_trans_loss,
        rigid_transform_loss=rigid_loss,
        pyramid_rigid_transform_loss=pyramid_rigid_loss,
        consistency_loss=consistency,
        pyloss_exp_w_base=model_cfg.loss.pyloss_exp_w_base,
    )
    return net
def build(voxel_config):
    """Build the voxel generator described by a ``VoxelGenerator`` config.

    Block filtering is always enabled: the flag is forced on and unset
    (zero / non-positive) block parameters are replaced by defaults before
    the generator is created.

    Args:
        voxel_config: a ``voxel_generator_pb2.VoxelGenerator`` message.
            NOTE: the message is modified in place (``block_filtering``,
            ``block_factor``, ``block_size`` and ``height_threshold``), so the
            caller observes the effective values afterwards.

    Returns:
        A ``_VoxelGenerator`` configured from ``voxel_config``.

    Raises:
        ValueError: if ``voxel_config`` is not a
            ``voxel_generator_pb2.VoxelGenerator`` message.
    """
    if not isinstance(voxel_config, voxel_generator_pb2.VoxelGenerator):
        # The previous message named an unrelated InputReader config.
        raise ValueError('voxel_config not of type '
                         'voxel_generator_pb2.VoxelGenerator.')

    # Block filtering is mandatory for this pipeline, whatever the config says.
    voxel_config.block_filtering = True
    # Proto3 scalars default to 0 when unset; substitute usable defaults.
    voxel_config.block_factor = max(1, voxel_config.block_factor)
    if voxel_config.block_size <= 0:
        voxel_config.block_size = 8
    if voxel_config.height_threshold == 0:
        voxel_config.height_threshold = 0.2

    voxel_generator = _VoxelGenerator(
        voxel_size=list(voxel_config.voxel_size),
        point_cloud_range=list(voxel_config.point_cloud_range),
        max_num_points=voxel_config.max_number_of_points_per_voxel,
        max_voxels=20000,
        # full_empty_part_with_mean from the config is deliberately ignored.
        full_mean=False,
        block_filtering=voxel_config.block_filtering,
        block_factor=voxel_config.block_factor,
        block_size=voxel_config.block_size,
        height_threshold=voxel_config.height_threshold,
    )
    return voxel_generator
# Maps a dataset name to the dataset class registered under it.
REGISTERED_DATASET_CLASSES = {}


def register_dataset(cls, name=None):
    """Add ``cls`` to the dataset registry and return it unchanged.

    Returning the class allows use as a class decorator.

    Args:
        cls: the dataset class to register.
        name: registry key; defaults to ``cls.__name__``.

    Returns:
        ``cls`` itself.

    Raises:
        AssertionError: if the key is already registered.
    """
    global REGISTERED_DATASET_CLASSES
    key = cls.__name__ if name is None else name
    assert key not in REGISTERED_DATASET_CLASSES, f"exist class: {REGISTERED_DATASET_CLASSES}"
    REGISTERED_DATASET_CLASSES[key] = cls
    return cls


def get_dataset_class(name):
    """Return the dataset class registered under ``name``.

    Raises:
        AssertionError: if nothing is registered under ``name``.
    """
    global REGISTERED_DATASET_CLASSES
    registry = REGISTERED_DATASET_CLASSES
    assert name in registry, f"available class: {REGISTERED_DATASET_CLASSES}"
    return registry[name]
def generate_pointwise_local_transformation_tch(tq, spatial_size, origin_loc, voxel_size, inv_trans_factor=-1, ):
    """Expand one global pose into a dense map of per-cell local poses.

    A regular grid of anchor points ``p`` is generated. For every anchor the
    translation is re-expressed relative to that anchor: ``(t - p)`` is
    rotated with the inverse quaternion (``tch_p.rotate_vec_by_q`` /
    ``tch_p.qinv``) and ``p`` is added back. The quaternion itself is copied
    to every cell.

    Grid convention (as implemented below): x grows with the column index,
    y decreases with the row index, z grows with the depth index; the anchor
    of a cell is its corner, not its centre.

    Args:
        tq: 1-D tensor; ``tq[:3]`` is the translation and ``tq[3:]`` the
            quaternion (4 values, given the 4-channel quaternion map built
            below).
        spatial_size: ``(size_x, size_y)`` or ``(size_x, size_y, size_z)``;
            a missing z extent is treated as 1.
        origin_loc: grid index of the coordinate origin; three entries are
            read even for a 2-D ``spatial_size``.
        voxel_size: metric size of one cell along x, y, z (three entries).
        inv_trans_factor: if > 0, the x/y anchor coordinates are rescaled by
            ``inv_trans_factor / (||xy|| + 0.1) ** 2``.

    Returns:
        Tensor with channels first, laid out as (7, z, y, x); note that
        ``squeeze()`` removes *every* singleton dimension.

    Raises:
        ValueError: if ``spatial_size`` has neither 2 nor 3 entries.
    """
    device = tq.device
    dtype = tq.dtype
    t_g = tq[:3]
    q_g = tq[3:]

    if len(spatial_size) == 2:
        size_x, size_y = spatial_size
        # Previously written as size_x // size_x, which divides by zero for
        # an empty grid.
        size_z = 1
    elif len(spatial_size) == 3:
        size_x, size_y, size_z = spatial_size
    else:
        raise ValueError(
            f"spatial_size must have 2 or 3 entries, got {len(spatial_size)}")

    # Index grids of shape (size_y, size_x, size_z): iv = row, jv = column,
    # kv = depth.
    iv, jv, kv = torch.meshgrid(
        torch.arange(size_y, device=device),
        torch.arange(size_x, device=device),
        torch.arange(size_z, device=device))

    # Shift the origin to origin_loc and convert indices to metric units.
    xv = (jv - origin_loc[0]) * voxel_size[0]
    yv = (-iv + origin_loc[1]) * voxel_size[1]
    zv = (kv - origin_loc[2]) * voxel_size[2]

    # N x 3 anchor coordinates, N = size_y * size_x * size_z.
    xyzv = torch.stack([xv, yv, zv], dim=-1).reshape([-1, 3]).to(dtype=dtype)

    if inv_trans_factor > 0:
        # torch.norm keeps the computation on the tensor's device; the former
        # np.linalg.norm call could not handle CUDA tensors and produced a
        # numpy array in the middle of a tensor expression.
        xy_norm = torch.norm(xyzv[:, :2], dim=1, keepdim=True)
        xyzv[:, :2] = inv_trans_factor / (xy_norm + 0.1) ** 2 * xyzv[:, :2]

    # Local translation of every anchor.
    q_inv = tch_p.qinv(q_g[None, ...]).repeat([xyzv.shape[0], 1])
    t_l = tch_p.rotate_vec_by_q(t=t_g[None, ...] - xyzv, q=q_inv) + xyzv

    t_map = t_l.reshape([size_y, size_x, size_z, 3])
    q_map = torch.ones([size_y, size_x, size_z, 4],
                       dtype=dtype, device=device) * q_g

    # (y, x, z, c) -> (c, z, y, x)
    tq_map = torch.cat([t_map, q_map], dim=-1).permute(3, 2, 0, 1).squeeze()
    return tq_map
1]).contiguous() #b,y,x,c !!! 153 | # tq_map = tq_map.permute([0, 3, 2, 1]).contiguous() #b,x,y,c 154 | # t_map = t_l.view([size_x, size_y, 3]) 155 | tq_map = tq_map.view(-1, 7) 156 | 157 | t_l = tq_map[:, :3] 158 | q_l = tq_map[:, 3:] 159 | 160 | size_z,size_y, size_x= spatial_size 161 | iv, jv, kv = torch.meshgrid([ 162 | torch.arange(size_y, dtype=dtype, device=device), 163 | torch.arange(size_x, dtype=dtype, device=device), 164 | torch.arange(size_z, dtype=dtype, device=device )]) # (size_x,size_y) 165 | # move the origin to the middle 166 | # minus 0.5 to shift to the centre of each grid 167 | # xv = (jv - origin_loc[0]+0.5)*voxel_size[0] 168 | # yv = (-iv + origin_loc[1]-0.5)*voxel_size[1] 169 | xv = (jv - origin_loc[0])*voxel_size[0] 170 | yv = (-iv + origin_loc[1])*voxel_size[1] 171 | zv = (kv-origin_loc[2])*voxel_size[2] 172 | # zv = torch.zeros_like(xv) 173 | 174 | # xyzv = torch.stack([xv, yv, zv], dim=0).reshape([-1, 3]) # Nx3 175 | # Nx3 # fixed on 7/11/2019 176 | xyzv = torch.stack([xv, yv, zv], dim=-1).reshape([-1, 3]) 177 | if inv_trans_factor > 0: 178 | xyzv[:, :2] = inv_trans_factor / \ 179 | (torch.norm(xyzv[:, :2], dim=1, keepdim=True) + 180 | 0.1) ** 2 * xyzv[:, :2] 181 | 182 | xyzv = torch.cat([xyzv]*input_shape[0], dim=0) 183 | 184 | # import pdb 185 | # pdb.set_trace() 186 | 187 | t_g = tch_p.rotate_vec_by_q(t=(t_l-xyzv), q=q_l) + xyzv 188 | 189 | #bug fixed on 8/1/2020 190 | t_map_g = t_g.view(input_shape[0], input_shape[2], input_shape[3], 3) #b,y,x,c 191 | # t_map_g = t_g.view(input_shape[0], input_shape[3], input_shape[2], 3) #b,x,y,c 192 | 193 | # torch.ones([spatial_size[0], spatial_size[1], 4], np.float32)*q_g 194 | #bug fixed on 8/1/2020 195 | q_map_g = q_l.view(input_shape[0], input_shape[2], input_shape[3], 4) 196 | # q_map_g = q_l.view(input_shape[0], input_shape[3], input_shape[2], 4) 197 | 198 | q_map_g = torch.nn.functional.normalize(q_map_g, dim=-1) 199 | 200 | # tq_map_g = torch.cat([t_map_g, q_map_g], dim=- 201 | # 
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or several modules.

    Args:
        nets: a single module, or a list/tuple of modules. ``None`` entries
            are skipped.
        requires_grad: value assigned to each parameter's ``requires_grad``.
    """
    # Tuples used to be wrapped as one "net" and then failed on
    # ``.parameters()``; treat them like lists.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
class MaskConvTranspose2d(nn.ConvTranspose2d):
    """Transposed 2-D convolution that carries a validity mask alongside.

    The input is either a tensor or a ``[tensor, mask]`` pair. The tensor is
    passed through ``nn.ConvTranspose2d``; the mask is returned detached and
    otherwise unchanged (it is NOT resized to the output resolution).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True, dilation=1):
        # Forward the caller's settings. They were previously replaced by
        # hard-coded defaults (stride=1, padding=0, ...), so e.g. stride=2
        # was silently ignored.
        super(MaskConvTranspose2d, self).__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, output_padding=output_padding, groups=groups,
            bias=bias, dilation=dilation)

    def forward(self, x):
        """Apply the transposed convolution.

        Args:
            x: a tensor, or a list/tuple ``(tensor, binary_mask)``.

        Returns:
            ``(output_tensor, binary_mask)`` with the mask detached.
        """
        if not isinstance(x, (list, tuple)):
            # A location is valid when any channel is non-zero. abs() keeps
            # channels of opposite sign from cancelling, as in MaskConv.
            x = [x, (torch.sum(x.abs(), dim=1, keepdim=True) != 0).float().detach()]
        tensor, binary_mask = x
        return super(MaskConvTranspose2d, self).forward(tensor), binary_mask.detach()
class ExpandDims(nn.Module):
    """Reshape layer that replaces one dimension by several.

    Dimension ``dim`` of the input is removed and ``insert_dims`` are placed
    at that position, e.g. ``ExpandDims([2, 3], 1)`` maps a ``(4, 6, 5)``
    tensor to ``(4, 2, 3, 5)``.

    Args:
        insert_dims: sizes of the dimensions that replace dimension ``dim``.
        dim: index of the dimension to replace; negative values count from
            the end.
    """

    def __init__(self, insert_dims, dim):
        super().__init__()
        self.insert_dims = insert_dims
        self.dim = dim

    def _new_shape(self, old_shape):
        """Return ``old_shape`` as a list with ``dim`` replaced by ``insert_dims``.

        Raises:
            IndexError: if ``dim`` is out of range for ``old_shape``.
        """
        new_shape = list(old_shape)
        # Normalise negative indices first. Deleting the entry and then
        # slice-inserting at the same negative index used to put the new
        # dims one position too early.
        dim = self.dim + len(new_shape) if self.dim < 0 else self.dim
        if not 0 <= dim < len(new_shape):
            raise IndexError(
                f"dim {self.dim} out of range for shape {tuple(old_shape)}")
        new_shape[dim:dim + 1] = self.insert_dims
        return new_shape

    def forward(self, x):
        new_shape = self._new_shape(x.shape)
        return x.reshape(new_shape)
class Dropout2dGivenMask(nn.Module):
    """Channel-wise dropout that can reuse an externally supplied mask.

    Entries where the mask is not positive are set to zero. Unlike
    ``nn.Dropout2d`` the surviving values are not rescaled, and the layer
    does not check ``self.training``.

    Args:
        p: probability of dropping a slice along ``dim``.
        dim: dimension along which whole slices are dropped (default 1, the
            channel dimension of an NCHW tensor).
    """

    def __init__(self, p, dim=1):
        super(Dropout2dGivenMask, self).__init__()

        self.p = p
        # The argument used to be ignored in favour of a hard-coded 1.
        self.dim = dim

    def forward(self, input, mask=None):
        """Apply the mask, sampling one when ``mask`` is None.

        Args:
            input: tensor to mask.
            mask: optional tensor broadcastable to ``input`` via
                ``expand_as``; positive entries are kept.

        Returns:
            ``(masked_input, mask)`` where ``mask`` is expanded to the shape
            of ``input`` so it can be fed to another call.
        """
        if mask is None:
            size = input.shape[self.dim]
            keep_prob = torch.ones(size, device=input.device) * (1 - self.p)
            # Singleton everywhere except along ``dim``; for a 4-D input and
            # dim=1 this is the former (1, C, 1, 1) layout.
            view_shape = [1] * input.dim()
            view_shape[self.dim] = size
            mask = torch.bernoulli(keep_prob).view(view_shape)

        mask = mask.expand_as(input)
        input = torch.where(mask > 0, input, torch.zeros_like(input))

        return input, mask
class ConfidenceModule(nn.Module):
    """Turn the output of ``conf_model`` into confidence weights.

    Two modes are supported:

    * ``'linear'``: ``(elu(logit) + 1 + 1e-12) * (extra_mask + 1e-12)``,
      i.e. strictly positive, unnormalised weights.
    * ``'softmax'``: masked-out locations are set to -1000, then a softmax
      is taken over all dimensions after the first two (flattened), so each
      (batch, channel) slice sums to one.

    Args:
        conf_model: module producing the raw confidence logits.
        conf_type: ``'linear'`` or ``'softmax'``.
    """

    def __init__(self, conf_model, conf_type='softmax'):
        super().__init__()
        assert conf_type in ['linear', 'softmax']
        self.conf_model = conf_model
        self.conf_type = conf_type
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, extra_mask=None, temperature=1, return_logit=False):
        """Compute confidences for ``x``.

        Args:
            x: input passed to ``conf_model``.
            extra_mask: optional mask; entries > 0 are valid. Defaults to
                all ones with the shape of ``x``.
                NOTE(review): the default assumes ``conf_model`` keeps a
                shape broadcast-compatible with ``x`` - confirm with callers.
            temperature: divisor applied to the logits before the softmax
                (softmax mode only).
            return_logit: if True, also return the raw ``conf_model`` output.

        Returns:
            ``conf``, or ``(conf, logit)`` when ``return_logit`` is True.
        """
        if extra_mask is None:
            extra_mask = torch.ones_like(x)

        # Evaluate the model once, outside the branches, so ``logit`` exists
        # in both modes. It used to be bound in the softmax branch only, so
        # return_logit=True failed in linear mode.
        logit = self.conf_model(x)

        if self.conf_type == 'linear':
            conf = (F.elu(logit) + 1 + 1e-12) * (extra_mask + 1e-12)
        elif self.conf_type == 'softmax':
            # A large negative logit gives masked locations ~zero weight.
            conf = torch.where(extra_mask > 0,
                               logit, torch.full_like(logit, -1000))

            conf_shape = conf.shape
            conf = conf.reshape(*conf.shape[0:2], -1)

            conf = self.softmax(conf / temperature)
            conf = conf.reshape(*conf_shape)
        else:
            raise ValueError(f"unsupported conf_type: {self.conf_type!r}")

        if return_logit:
            return conf, logit
        return conf
class SpatialAttentionLayer(nn.Module):
    """Per-pixel gating: scale the input by a single-channel sigmoid map.

    A 1x1 convolution reduces the input to one channel, a sigmoid maps it to
    (0, 1), and the result is broadcast over all input channels.

    Args:
        channel: number of input channels.
        reduction: accepted for signature compatibility; not used.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        to_single_channel = nn.Conv2d(channel, 1, kernel_size=1, padding=0)
        squash = nn.Sigmoid()
        self.attention = nn.Sequential(to_single_channel, squash)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its attention map."""
        gate = self.attention(x)
        gate = gate.expand_as(x)
        return x * gate
class SVDHead(nn.Module):
    """Closed-form rigid alignment of two point sets via SVD.

    ``reflect`` is ``diag(1, 1, -1)``; it is stored as a non-trainable
    parameter and used to flip the last singular direction when the raw
    solution would be a reflection (determinant < 0).
    """

    def __init__(self, args=None):
        super(SVDHead, self).__init__()
        self.reflect = nn.Parameter(torch.eye(3), requires_grad=False)
        self.reflect[2, 2] = -1

    @amp.float_function
    def forward(self, src, tgt, weight=None):
        '''Estimate the rigid transform between corresponding points.

        src: Bx3xN
        tgt: Bx3xN
        weight: BxN, optional per-correspondence weights applied to the
            cross-covariance.
            NOTE(review): the centroids stay unweighted even when weights
            are given - confirm this is intended.

        Returns (R, t) with R: Bx3x3 and t: Bx3 such that
        src = R @ tgt + t, i.e. the inverse of the src -> tgt alignment
        solved for internally.
        '''
        batch_size = src.size(0)

        src_mean = src.mean(dim=2, keepdim=True)
        tgt_mean = tgt.mean(dim=2, keepdim=True)
        src_centered = src - src_mean  # Bx3xN
        tgt_centered = tgt - tgt_mean  # Bx3xN

        if weight is not None:
            src_centered = src_centered * weight[:, None, :]
        # Bx3x3 cross-covariance of the centred point sets.
        H = torch.matmul(src_centered,
                         tgt_centered.transpose(2, 1).contiguous())

        rotations = []
        # torch.svd is applied per sample, as in the original implementation.
        for i in range(batch_size):
            u, _, v = torch.svd(H[i])
            r = torch.matmul(v, u.transpose(1, 0).contiguous())
            if torch.det(r) < 0:
                # Reflection: flip the last column of v and rebuild r. The
                # factors computed above are reused rather than repeating
                # the identical decomposition.
                v = torch.matmul(v, self.reflect.to(device=v.device))
                r = torch.matmul(v, u.transpose(1, 0).contiguous())  # r=v@u^T
            rotations.append(r)

        R = torch.stack(rotations, dim=0)

        # src -> tgt alignment: tgt = R @ src + t
        t = torch.matmul(-R, src_mean) + tgt_mean
        # Invert it so that the result satisfies Xs = R @ Xt + t.
        R = R.transpose(-1, -2).contiguous()
        t = -R @ t

        return R, t.view(batch_size, 3)
{} -------------------------------------------------------------------------------- /rslo/protos/input_reader.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | import "rslo/protos/preprocess.proto"; 5 | import "rslo/protos/sampler.proto"; 6 | 7 | message InputReader { 8 | uint32 batch_size = 1; 9 | 10 | message Dataset { 11 | string kitti_info_path = 1; 12 | string kitti_root_path = 2; 13 | string dataset_class_name = 3; // support KittiDataset and NuScenesDataset 14 | int32 seq_length=4; 15 | int32 skip=5; 16 | bool random_skip=6; 17 | int32 step=7; 18 | 19 | } 20 | Dataset dataset = 2; 21 | message Preprocess { 22 | bool shuffle_points = 1; 23 | uint32 max_number_of_voxels = 2; 24 | repeated float groundtruth_localization_noise_std = 3; 25 | repeated float groundtruth_rotation_uniform_noise = 4; 26 | repeated float global_rotation_uniform_noise = 5; 27 | repeated float global_scaling_uniform_noise = 6; 28 | repeated float global_translate_noise_std = 7; 29 | bool remove_unknown_examples = 8; 30 | uint32 num_workers = 9; 31 | float anchor_area_threshold = 10; 32 | bool remove_points_after_sample = 11; 33 | float groundtruth_points_drop_percentage = 12; 34 | uint32 groundtruth_drop_max_keep_points = 13; 35 | bool remove_environment = 14; 36 | repeated float global_random_rotation_range_per_object = 15; 37 | repeated DatabasePreprocessingStep database_prep_steps = 16; 38 | Sampler database_sampler = 17; 39 | bool use_group_id = 18; // this will enable group sample and noise 40 | int64 min_num_of_points_in_gt = 19; // gt boxes contains less than this will be ignored. 
// Top-level loss configuration for the odometry network. Each sub-message
// configures one loss term: rotation and translation losses, their pyramid
// counterparts, a consistency loss and a rigid-transform loss.
// pyloss_exp_w_base is passed to the network constructor as-is by the model
// builder. See core/losses.py for details.
message Loss {
  RotaionLoss rotation_loss = 1;
  TranslationLoss translation_loss = 2;
  RotaionLoss pyramid_rotation_loss=3;
  TranslationLoss pyramid_translation_loss=4;
  ConsistencyLoss consistency_loss=5;
  float pyloss_exp_w_base=6;
  RigidTransformLoss rigid_transform_loss= 7;
  // float rotation_weight = 4;

  // float translation_weight = 5;

}
3; 48 | bool not_learn_alpha=4; 49 | float focal_gamma=5; 50 | float balance_scale=6; 51 | } 52 | message ConsistencyLoss{ 53 | 54 | string loss_type =1; 55 | float weight = 2 ; 56 | float init_alpha= 3; 57 | bool not_learn_alpha=4; 58 | float focal_gamma=5; 59 | float penalize_ratio=6; 60 | repeated float sample_block_size=7; 61 | bool norm=8; 62 | float pred_downsample_ratio=9; 63 | float reg_weight=10; 64 | float sph_weight=11; 65 | } 66 | -------------------------------------------------------------------------------- /rslo/protos/losses.proto.bak: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | 5 | // Message for configuring the localization loss, classification loss and hard 6 | // example miner used for training object detection models. See core/losses.py 7 | // for details 8 | message Loss { 9 | // Localization loss to use. 10 | LocalizationLoss localization_loss = 1; 11 | 12 | // Classification loss to use. 13 | ClassificationLoss classification_loss = 2; 14 | 15 | // If not left to default, applies hard example mining. 16 | HardExampleMiner hard_example_miner = 3; 17 | 18 | // Classification loss weight. 19 | float classification_weight = 4; 20 | 21 | // Localization loss weight. 22 | float localization_weight = 5; 23 | 24 | } 25 | 26 | // Configuration for bounding box localization loss function. 27 | message LocalizationLoss { 28 | oneof localization_loss { 29 | WeightedL2LocalizationLoss weighted_l2 = 1; 30 | WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2; 31 | WeightedGHMLocalizationLoss weighted_ghm = 3; 32 | } 33 | bool encode_rad_error_by_sin = 4; 34 | } 35 | 36 | // L2 location loss: 0.5 * ||weight * (a - b)|| ^ 2 37 | message WeightedL2LocalizationLoss { 38 | // DEPRECATED, do not use. 39 | // Output loss per anchor. 
40 | bool anchorwise_output = 1; 41 | repeated float code_weight = 2; 42 | } 43 | 44 | // SmoothL1 (Huber) location loss: .5 * x ^ 2 if |x| < 1 else |x| - .5 45 | message WeightedSmoothL1LocalizationLoss { 46 | // DEPRECATED, do not use. 47 | // Output loss per anchor. 48 | bool anchorwise_output = 1; 49 | float sigma = 2; 50 | repeated float code_weight = 3; 51 | } 52 | message WeightedGHMLocalizationLoss { 53 | // DEPRECATED, do not use. 54 | // Output loss per anchor. 55 | bool anchorwise_output = 1; 56 | float mu = 2; 57 | int32 bins = 3; 58 | float momentum = 4; 59 | repeated float code_weight = 5; 60 | } 61 | 62 | 63 | // Configuration for class prediction loss function. 64 | message ClassificationLoss { 65 | oneof classification_loss { 66 | WeightedSigmoidClassificationLoss weighted_sigmoid = 1; 67 | WeightedSoftmaxClassificationLoss weighted_softmax = 2; 68 | BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3; 69 | SigmoidFocalClassificationLoss weighted_sigmoid_focal = 4; 70 | SoftmaxFocalClassificationLoss weighted_softmax_focal = 5; 71 | GHMClassificationLoss weighted_ghm = 6; 72 | } 73 | } 74 | 75 | // Classification loss using a sigmoid function over class predictions. 76 | message WeightedSigmoidClassificationLoss { 77 | // DEPRECATED, do not use. 78 | // Output loss per anchor. 79 | bool anchorwise_output = 1; 80 | } 81 | 82 | // Sigmoid Focal cross entropy loss as described in 83 | // https://arxiv.org/abs/1708.02002 84 | message SigmoidFocalClassificationLoss { 85 | // DEPRECATED, do not use. 86 | bool anchorwise_output = 1; 87 | // modulating factor for the loss. 88 | float gamma = 2; 89 | // alpha weighting factor for the loss. 90 | float alpha = 3; 91 | } 92 | // Softmax Focal cross entropy loss as described in 93 | // https://arxiv.org/abs/1708.02002 94 | message SoftmaxFocalClassificationLoss { 95 | // DEPRECATED, do not use. 96 | bool anchorwise_output = 1; 97 | // modulating factor for the loss. 
98 | float gamma = 2; 99 | // alpha weighting factor for the loss. 100 | float alpha = 3; 101 | } 102 | message GHMClassificationLoss { 103 | bool anchorwise_output = 1; 104 | int32 bins = 2; 105 | float momentum = 3; 106 | } 107 | // Classification loss using a softmax function over class predictions. 108 | message WeightedSoftmaxClassificationLoss { 109 | // DEPRECATED, do not use. 110 | // Output loss per anchor. 111 | bool anchorwise_output = 1; 112 | // Scale logit (input) value before calculating softmax classification loss. 113 | // Typically used for softmax distillation. 114 | float logit_scale = 2; 115 | } 116 | 117 | // Classification loss using a sigmoid function over the class prediction with 118 | // the highest prediction score. 119 | message BootstrappedSigmoidClassificationLoss { 120 | // Interpolation weight between 0 and 1. 121 | float alpha = 1; 122 | 123 | // Whether hard bootstrapping should be used or not. If true, will only use 124 | // one class favored by model. Otherwise, will use all predicted class 125 | // probabilities. 126 | bool hard_bootstrap = 2; 127 | 128 | // DEPRECATED, do not use. 129 | // Output loss per anchor. 130 | bool anchorwise_output = 3; 131 | } 132 | 133 | // Configuration for hard example miner. 134 | message HardExampleMiner { 135 | // Maximum number of hard examples to be selected per image (prior to 136 | // enforcing max negative to positive ratio constraint). If set to 0, 137 | // all examples obtained after NMS are considered. 138 | int32 num_hard_examples = 1; 139 | 140 | // Minimum intersection over union for an example to be discarded during NMS. 141 | float iou_threshold = 2; 142 | 143 | // Whether to use classification losses ('cls', default), localization losses 144 | // ('loc') or both losses ('both'). In the case of 'both', cls_loss_weight and 145 | // loc_loss_weight are used to compute weighted sum of the two losses. 
146 | enum LossType { 147 | BOTH = 0; 148 | CLASSIFICATION = 1; 149 | LOCALIZATION = 2; 150 | } 151 | LossType loss_type = 3; 152 | 153 | // Maximum number of negatives to retain for each positive anchor. If 154 | // num_negatives_per_positive is 0 no prespecified negative:positive ratio is 155 | // enforced. 156 | int32 max_negatives_per_positive = 4; 157 | 158 | // Minimum number of negative anchors to sample for a given image. Setting 159 | // this to a positive number samples negatives in an image without any 160 | // positive anchors and thus not bias the model towards having at least one 161 | // detection per image. 162 | int32 min_negatives_per_image = 5; 163 | } 164 | -------------------------------------------------------------------------------- /rslo/protos/model.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | import "rslo/protos/second.proto"; 5 | message DetectionModel{ 6 | oneof model { 7 | VoxelNet second = 1; 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /rslo/protos/model_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: rslo/protos/model.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from rslo.protos import second_pb2 as rslo_dot_protos_dot_second__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='rslo/protos/model.proto', 21 | package='second.protos', 22 | syntax='proto3', 23 | serialized_options=None, 24 | serialized_pb=_b('\n\x17rslo/protos/model.proto\x12\rsecond.protos\x1a\x18rslo/protos/second.proto\"D\n\x0e\x44\x65tectionModel\x12)\n\x06second\x18\x01 \x01(\x0b\x32\x17.second.protos.VoxelNetH\x00\x42\x07\n\x05modelb\x06proto3') 25 | , 26 | dependencies=[rslo_dot_protos_dot_second__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _DETECTIONMODEL = _descriptor.Descriptor( 32 | name='DetectionModel', 33 | full_name='second.protos.DetectionModel', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='second', full_name='second.protos.DetectionModel.second', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | serialized_options=None, file=DESCRIPTOR), 45 | ], 46 | extensions=[ 47 | ], 48 | nested_types=[], 49 | enum_types=[ 50 | ], 51 | serialized_options=None, 52 | is_extendable=False, 53 | syntax='proto3', 54 | extension_ranges=[], 55 | oneofs=[ 56 | _descriptor.OneofDescriptor( 57 | name='model', full_name='second.protos.DetectionModel.model', 58 | index=0, containing_type=None, fields=[]), 59 | ], 60 | 
serialized_start=68, 61 | serialized_end=136, 62 | ) 63 | 64 | _DETECTIONMODEL.fields_by_name['second'].message_type = rslo_dot_protos_dot_second__pb2._VOXELNET 65 | _DETECTIONMODEL.oneofs_by_name['model'].fields.append( 66 | _DETECTIONMODEL.fields_by_name['second']) 67 | _DETECTIONMODEL.fields_by_name['second'].containing_oneof = _DETECTIONMODEL.oneofs_by_name['model'] 68 | DESCRIPTOR.message_types_by_name['DetectionModel'] = _DETECTIONMODEL 69 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 70 | 71 | DetectionModel = _reflection.GeneratedProtocolMessageType('DetectionModel', (_message.Message,), { 72 | 'DESCRIPTOR' : _DETECTIONMODEL, 73 | '__module__' : 'rslo.protos.model_pb2' 74 | # @@protoc_insertion_point(class_scope:second.protos.DetectionModel) 75 | }) 76 | _sym_db.RegisterMessage(DetectionModel) 77 | 78 | 79 | # @@protoc_insertion_point(module_scope) 80 | -------------------------------------------------------------------------------- /rslo/protos/optimizer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | 5 | // Messages for configuring the optimizing strategy for training object 6 | // detection models. 7 | 8 | // Top level optimizer message. 9 | message Optimizer { 10 | oneof optimizer { 11 | RMSPropOptimizer rms_prop_optimizer = 1; 12 | MomentumOptimizer momentum_optimizer = 2; 13 | AdamOptimizer adam_optimizer = 3; 14 | } 15 | bool use_moving_average = 4; 16 | float moving_average_decay = 5; 17 | bool fixed_weight_decay = 6; // i.e. 
AdamW 18 | } 19 | 20 | // Configuration message for the RMSPropOptimizer 21 | // See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer 22 | message RMSPropOptimizer { 23 | LearningRate learning_rate = 1; 24 | float momentum_optimizer_value = 2; 25 | float decay = 3; 26 | float epsilon = 4; 27 | float weight_decay = 5; 28 | } 29 | 30 | // Configuration message for the MomentumOptimizer 31 | // See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer 32 | message MomentumOptimizer { 33 | LearningRate learning_rate = 1; 34 | float momentum_optimizer_value = 2; 35 | float weight_decay = 3; 36 | } 37 | 38 | // Configuration message for the AdamOptimizer 39 | // See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 40 | message AdamOptimizer { 41 | LearningRate learning_rate = 1; 42 | float weight_decay = 2; 43 | bool amsgrad = 3; 44 | } 45 | 46 | message LearningRate { 47 | oneof learning_rate { 48 | MultiPhase multi_phase = 1; 49 | OneCycle one_cycle = 2; 50 | ExponentialDecay exponential_decay = 3; 51 | ManualStepping manual_stepping = 4; 52 | ExponentialDecayWarmup exponential_decay_warmup=5; 53 | } 54 | } 55 | 56 | message LearningRatePhase { 57 | float start = 1; 58 | string lambda_func = 2; 59 | string momentum_lambda_func = 3; 60 | } 61 | 62 | message MultiPhase { 63 | repeated LearningRatePhase phases = 1; 64 | } 65 | 66 | message OneCycle { 67 | float lr_max = 1; 68 | repeated float moms = 2; 69 | float div_factor = 3; 70 | float pct_start = 4; 71 | repeated float lr_maxs = 5; 72 | 73 | } 74 | 75 | /* 76 | ExponentialDecay example: 77 | initial_learning_rate = 0.001 78 | decay_length = 0.1 79 | decay_factor = 0.8 80 | staircase = True 81 | detail: 82 | progress 0%~10%, lr=0.001 83 | progress 10%~20%, lr=0.001 * 0.8 84 | progress 20%~30%, lr=0.001 * 0.8 * 0.8 85 | ...... 
86 | */ 87 | 88 | 89 | message ExponentialDecay { 90 | float initial_learning_rate = 1; 91 | float decay_length = 2; // must be in range (0, 1) 92 | float decay_factor = 3; 93 | bool staircase = 4; 94 | } 95 | 96 | message ExponentialDecayWarmup { 97 | float initial_learning_rate = 1; 98 | float decay_length = 2; // must be in range (0, 1) 99 | float decay_factor = 3; 100 | bool staircase = 4; 101 | float div_factor=6; 102 | float pct_start=7; 103 | } 104 | 105 | /* 106 | ManualStepping example: 107 | boundaries = [0.8, 0.9] 108 | rates = [0.001, 0.002, 0.003] 109 | detail: 110 | progress 0%~80%, lr=0.001 111 | progress 80%~90%, lr=0.002 112 | progress 90%~100%, lr=0.003 113 | */ 114 | 115 | message ManualStepping { 116 | repeated float boundaries = 1; // each must be in range (0, 1) 117 | repeated float rates = 2; 118 | } -------------------------------------------------------------------------------- /rslo/protos/pipeline.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | 5 | import "rslo/protos/input_reader.proto"; 6 | import "rslo/protos/model.proto"; 7 | import "rslo/protos/train.proto"; 8 | // Convenience message for configuring a training and eval pipeline. Allows all 9 | // of the pipeline parameters to be configured from one file. 10 | message TrainEvalPipelineConfig { 11 | DetectionModel model = 1; 12 | InputReader train_input_reader = 2; 13 | TrainConfig train_config = 3; 14 | InputReader eval_input_reader = 4; 15 | InputReader eval_train_input_reader = 5; 16 | } 17 | 18 | -------------------------------------------------------------------------------- /rslo/protos/pipeline_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: rslo/protos/pipeline.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from rslo.protos import input_reader_pb2 as rslo_dot_protos_dot_input__reader__pb2 17 | from rslo.protos import model_pb2 as rslo_dot_protos_dot_model__pb2 18 | from rslo.protos import train_pb2 as rslo_dot_protos_dot_train__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='rslo/protos/pipeline.proto', 23 | package='second.protos', 24 | syntax='proto3', 25 | serialized_options=None, 26 | serialized_pb=_b('\n\x1arslo/protos/pipeline.proto\x12\rsecond.protos\x1a\x1erslo/protos/input_reader.proto\x1a\x17rslo/protos/model.proto\x1a\x17rslo/protos/train.proto\"\xa5\x02\n\x17TrainEvalPipelineConfig\x12,\n\x05model\x18\x01 \x01(\x0b\x32\x1d.second.protos.DetectionModel\x12\x36\n\x12train_input_reader\x18\x02 \x01(\x0b\x32\x1a.second.protos.InputReader\x12\x30\n\x0ctrain_config\x18\x03 \x01(\x0b\x32\x1a.second.protos.TrainConfig\x12\x35\n\x11\x65val_input_reader\x18\x04 \x01(\x0b\x32\x1a.second.protos.InputReader\x12;\n\x17\x65val_train_input_reader\x18\x05 \x01(\x0b\x32\x1a.second.protos.InputReaderb\x06proto3') 27 | , 28 | dependencies=[rslo_dot_protos_dot_input__reader__pb2.DESCRIPTOR,rslo_dot_protos_dot_model__pb2.DESCRIPTOR,rslo_dot_protos_dot_train__pb2.DESCRIPTOR,]) 29 | 30 | 31 | 32 | 33 | _TRAINEVALPIPELINECONFIG = _descriptor.Descriptor( 34 | name='TrainEvalPipelineConfig', 35 | full_name='second.protos.TrainEvalPipelineConfig', 36 | filename=None, 37 | file=DESCRIPTOR, 38 | containing_type=None, 39 | fields=[ 40 | _descriptor.FieldDescriptor( 41 | name='model', 
full_name='second.protos.TrainEvalPipelineConfig.model', index=0, 42 | number=1, type=11, cpp_type=10, label=1, 43 | has_default_value=False, default_value=None, 44 | message_type=None, enum_type=None, containing_type=None, 45 | is_extension=False, extension_scope=None, 46 | serialized_options=None, file=DESCRIPTOR), 47 | _descriptor.FieldDescriptor( 48 | name='train_input_reader', full_name='second.protos.TrainEvalPipelineConfig.train_input_reader', index=1, 49 | number=2, type=11, cpp_type=10, label=1, 50 | has_default_value=False, default_value=None, 51 | message_type=None, enum_type=None, containing_type=None, 52 | is_extension=False, extension_scope=None, 53 | serialized_options=None, file=DESCRIPTOR), 54 | _descriptor.FieldDescriptor( 55 | name='train_config', full_name='second.protos.TrainEvalPipelineConfig.train_config', index=2, 56 | number=3, type=11, cpp_type=10, label=1, 57 | has_default_value=False, default_value=None, 58 | message_type=None, enum_type=None, containing_type=None, 59 | is_extension=False, extension_scope=None, 60 | serialized_options=None, file=DESCRIPTOR), 61 | _descriptor.FieldDescriptor( 62 | name='eval_input_reader', full_name='second.protos.TrainEvalPipelineConfig.eval_input_reader', index=3, 63 | number=4, type=11, cpp_type=10, label=1, 64 | has_default_value=False, default_value=None, 65 | message_type=None, enum_type=None, containing_type=None, 66 | is_extension=False, extension_scope=None, 67 | serialized_options=None, file=DESCRIPTOR), 68 | _descriptor.FieldDescriptor( 69 | name='eval_train_input_reader', full_name='second.protos.TrainEvalPipelineConfig.eval_train_input_reader', index=4, 70 | number=5, type=11, cpp_type=10, label=1, 71 | has_default_value=False, default_value=None, 72 | message_type=None, enum_type=None, containing_type=None, 73 | is_extension=False, extension_scope=None, 74 | serialized_options=None, file=DESCRIPTOR), 75 | ], 76 | extensions=[ 77 | ], 78 | nested_types=[], 79 | enum_types=[ 80 | ], 81 | 
serialized_options=None, 82 | is_extendable=False, 83 | syntax='proto3', 84 | extension_ranges=[], 85 | oneofs=[ 86 | ], 87 | serialized_start=128, 88 | serialized_end=421, 89 | ) 90 | 91 | _TRAINEVALPIPELINECONFIG.fields_by_name['model'].message_type = rslo_dot_protos_dot_model__pb2._DETECTIONMODEL 92 | _TRAINEVALPIPELINECONFIG.fields_by_name['train_input_reader'].message_type = rslo_dot_protos_dot_input__reader__pb2._INPUTREADER 93 | _TRAINEVALPIPELINECONFIG.fields_by_name['train_config'].message_type = rslo_dot_protos_dot_train__pb2._TRAINCONFIG 94 | _TRAINEVALPIPELINECONFIG.fields_by_name['eval_input_reader'].message_type = rslo_dot_protos_dot_input__reader__pb2._INPUTREADER 95 | _TRAINEVALPIPELINECONFIG.fields_by_name['eval_train_input_reader'].message_type = rslo_dot_protos_dot_input__reader__pb2._INPUTREADER 96 | DESCRIPTOR.message_types_by_name['TrainEvalPipelineConfig'] = _TRAINEVALPIPELINECONFIG 97 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 98 | 99 | TrainEvalPipelineConfig = _reflection.GeneratedProtocolMessageType('TrainEvalPipelineConfig', (_message.Message,), { 100 | 'DESCRIPTOR' : _TRAINEVALPIPELINECONFIG, 101 | '__module__' : 'rslo.protos.pipeline_pb2' 102 | # @@protoc_insertion_point(class_scope:second.protos.TrainEvalPipelineConfig) 103 | }) 104 | _sym_db.RegisterMessage(TrainEvalPipelineConfig) 105 | 106 | 107 | # @@protoc_insertion_point(module_scope) 108 | -------------------------------------------------------------------------------- /rslo/protos/preprocess.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | 5 | message DatabasePreprocessingStep { 6 | oneof database_preprocessing_step { 7 | DBFilterByDifficulty filter_by_difficulty = 1; 8 | DBFilterByMinNumPointInGroundTruth filter_by_min_num_points = 2; 9 | } 10 | } 11 | 12 | message DBFilterByDifficulty{ 13 | repeated int32 removed_difficulties = 1; 14 | } 15 | 16 | message 
DBFilterByMinNumPointInGroundTruth{ 17 | map min_num_point_pairs = 1; 18 | } 19 | -------------------------------------------------------------------------------- /rslo/protos/sampler.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | import "rslo/protos/preprocess.proto"; 5 | 6 | message Group{ 7 | map name_to_max_num = 1; 8 | } 9 | 10 | message Sampler{ 11 | string database_info_path = 1; 12 | repeated Group sample_groups = 2; 13 | repeated DatabasePreprocessingStep database_prep_steps = 3; 14 | repeated float global_random_rotation_range_per_object = 4; 15 | float rate = 5; 16 | } 17 | -------------------------------------------------------------------------------- /rslo/protos/sampler_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: rslo/protos/sampler.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from rslo.protos import preprocess_pb2 as rslo_dot_protos_dot_preprocess__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='rslo/protos/sampler.proto', 21 | package='second.protos', 22 | syntax='proto3', 23 | serialized_options=None, 24 | serialized_pb=_b('\n\x19rslo/protos/sampler.proto\x12\rsecond.protos\x1a\x1crslo/protos/preprocess.proto\"}\n\x05Group\x12?\n\x0fname_to_max_num\x18\x01 \x03(\x0b\x32&.second.protos.Group.NameToMaxNumEntry\x1a\x33\n\x11NameToMaxNumEntry\x12\x0b\n\x03key\x18\x01 
\x01(\t\x12\r\n\x05value\x18\x02 \x01(\r:\x02\x38\x01\"\xd8\x01\n\x07Sampler\x12\x1a\n\x12\x64\x61tabase_info_path\x18\x01 \x01(\t\x12+\n\rsample_groups\x18\x02 \x03(\x0b\x32\x14.second.protos.Group\x12\x45\n\x13\x64\x61tabase_prep_steps\x18\x03 \x03(\x0b\x32(.second.protos.DatabasePreprocessingStep\x12/\n\'global_random_rotation_range_per_object\x18\x04 \x03(\x02\x12\x0c\n\x04rate\x18\x05 \x01(\x02\x62\x06proto3') 25 | , 26 | dependencies=[rslo_dot_protos_dot_preprocess__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _GROUP_NAMETOMAXNUMENTRY = _descriptor.Descriptor( 32 | name='NameToMaxNumEntry', 33 | full_name='second.protos.Group.NameToMaxNumEntry', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='key', full_name='second.protos.Group.NameToMaxNumEntry.key', index=0, 40 | number=1, type=9, cpp_type=9, label=1, 41 | has_default_value=False, default_value=_b("").decode('utf-8'), 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | serialized_options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='value', full_name='second.protos.Group.NameToMaxNumEntry.value', index=1, 47 | number=2, type=13, cpp_type=3, label=1, 48 | has_default_value=False, default_value=0, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | serialized_options=None, file=DESCRIPTOR), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | serialized_options=_b('8\001'), 59 | is_extendable=False, 60 | syntax='proto3', 61 | extension_ranges=[], 62 | oneofs=[ 63 | ], 64 | serialized_start=148, 65 | serialized_end=199, 66 | ) 67 | 68 | _GROUP = _descriptor.Descriptor( 69 | name='Group', 70 | full_name='second.protos.Group', 71 | filename=None, 72 | file=DESCRIPTOR, 73 | containing_type=None, 74 | fields=[ 75 | _descriptor.FieldDescriptor( 76 | 
name='name_to_max_num', full_name='second.protos.Group.name_to_max_num', index=0, 77 | number=1, type=11, cpp_type=10, label=3, 78 | has_default_value=False, default_value=[], 79 | message_type=None, enum_type=None, containing_type=None, 80 | is_extension=False, extension_scope=None, 81 | serialized_options=None, file=DESCRIPTOR), 82 | ], 83 | extensions=[ 84 | ], 85 | nested_types=[_GROUP_NAMETOMAXNUMENTRY, ], 86 | enum_types=[ 87 | ], 88 | serialized_options=None, 89 | is_extendable=False, 90 | syntax='proto3', 91 | extension_ranges=[], 92 | oneofs=[ 93 | ], 94 | serialized_start=74, 95 | serialized_end=199, 96 | ) 97 | 98 | 99 | _SAMPLER = _descriptor.Descriptor( 100 | name='Sampler', 101 | full_name='second.protos.Sampler', 102 | filename=None, 103 | file=DESCRIPTOR, 104 | containing_type=None, 105 | fields=[ 106 | _descriptor.FieldDescriptor( 107 | name='database_info_path', full_name='second.protos.Sampler.database_info_path', index=0, 108 | number=1, type=9, cpp_type=9, label=1, 109 | has_default_value=False, default_value=_b("").decode('utf-8'), 110 | message_type=None, enum_type=None, containing_type=None, 111 | is_extension=False, extension_scope=None, 112 | serialized_options=None, file=DESCRIPTOR), 113 | _descriptor.FieldDescriptor( 114 | name='sample_groups', full_name='second.protos.Sampler.sample_groups', index=1, 115 | number=2, type=11, cpp_type=10, label=3, 116 | has_default_value=False, default_value=[], 117 | message_type=None, enum_type=None, containing_type=None, 118 | is_extension=False, extension_scope=None, 119 | serialized_options=None, file=DESCRIPTOR), 120 | _descriptor.FieldDescriptor( 121 | name='database_prep_steps', full_name='second.protos.Sampler.database_prep_steps', index=2, 122 | number=3, type=11, cpp_type=10, label=3, 123 | has_default_value=False, default_value=[], 124 | message_type=None, enum_type=None, containing_type=None, 125 | is_extension=False, extension_scope=None, 126 | serialized_options=None, file=DESCRIPTOR), 127 
| _descriptor.FieldDescriptor( 128 | name='global_random_rotation_range_per_object', full_name='second.protos.Sampler.global_random_rotation_range_per_object', index=3, 129 | number=4, type=2, cpp_type=6, label=3, 130 | has_default_value=False, default_value=[], 131 | message_type=None, enum_type=None, containing_type=None, 132 | is_extension=False, extension_scope=None, 133 | serialized_options=None, file=DESCRIPTOR), 134 | _descriptor.FieldDescriptor( 135 | name='rate', full_name='second.protos.Sampler.rate', index=4, 136 | number=5, type=2, cpp_type=6, label=1, 137 | has_default_value=False, default_value=float(0), 138 | message_type=None, enum_type=None, containing_type=None, 139 | is_extension=False, extension_scope=None, 140 | serialized_options=None, file=DESCRIPTOR), 141 | ], 142 | extensions=[ 143 | ], 144 | nested_types=[], 145 | enum_types=[ 146 | ], 147 | serialized_options=None, 148 | is_extendable=False, 149 | syntax='proto3', 150 | extension_ranges=[], 151 | oneofs=[ 152 | ], 153 | serialized_start=202, 154 | serialized_end=418, 155 | ) 156 | 157 | _GROUP_NAMETOMAXNUMENTRY.containing_type = _GROUP 158 | _GROUP.fields_by_name['name_to_max_num'].message_type = _GROUP_NAMETOMAXNUMENTRY 159 | _SAMPLER.fields_by_name['sample_groups'].message_type = _GROUP 160 | _SAMPLER.fields_by_name['database_prep_steps'].message_type = rslo_dot_protos_dot_preprocess__pb2._DATABASEPREPROCESSINGSTEP 161 | DESCRIPTOR.message_types_by_name['Group'] = _GROUP 162 | DESCRIPTOR.message_types_by_name['Sampler'] = _SAMPLER 163 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 164 | 165 | Group = _reflection.GeneratedProtocolMessageType('Group', (_message.Message,), { 166 | 167 | 'NameToMaxNumEntry' : _reflection.GeneratedProtocolMessageType('NameToMaxNumEntry', (_message.Message,), { 168 | 'DESCRIPTOR' : _GROUP_NAMETOMAXNUMENTRY, 169 | '__module__' : 'rslo.protos.sampler_pb2' 170 | # @@protoc_insertion_point(class_scope:second.protos.Group.NameToMaxNumEntry) 171 | }) 172 | , 173 | 
'DESCRIPTOR' : _GROUP, 174 | '__module__' : 'rslo.protos.sampler_pb2' 175 | # @@protoc_insertion_point(class_scope:second.protos.Group) 176 | }) 177 | _sym_db.RegisterMessage(Group) 178 | _sym_db.RegisterMessage(Group.NameToMaxNumEntry) 179 | 180 | Sampler = _reflection.GeneratedProtocolMessageType('Sampler', (_message.Message,), { 181 | 'DESCRIPTOR' : _SAMPLER, 182 | '__module__' : 'rslo.protos.sampler_pb2' 183 | # @@protoc_insertion_point(class_scope:second.protos.Sampler) 184 | }) 185 | _sym_db.RegisterMessage(Sampler) 186 | 187 | 188 | _GROUP_NAMETOMAXNUMENTRY._options = None 189 | # @@protoc_insertion_point(module_scope) 190 | -------------------------------------------------------------------------------- /rslo/protos/second.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | import "rslo/protos/losses.proto"; 5 | // import "rslo/protos/box_coder.proto"; 6 | // import "rslo/protos/target.proto"; 7 | import "rslo/protos/voxel_generator.proto"; 8 | 9 | message VoxelNet { 10 | string network_class_name = 1; 11 | VoxelGenerator voxel_generator = 2; 12 | message VoxelFeatureExtractor { 13 | string module_class_name = 1; 14 | repeated int32 num_filters = 2; 15 | bool with_distance = 3; 16 | int32 num_input_features = 4; 17 | bool not_use_norm = 5; 18 | } 19 | VoxelFeatureExtractor voxel_feature_extractor = 3; 20 | message MiddleFeatureExtractor { 21 | string module_class_name = 1; 22 | repeated int32 num_filters_down1 = 2; 23 | repeated int32 num_filters_down2 = 3; 24 | int32 num_input_features = 4; 25 | int32 downsample_factor = 5; 26 | // bool not_use_norm = 6; 27 | bool use_leakyReLU = 7; 28 | string relu_type = 8; 29 | string bn_type=9; 30 | 31 | } 32 | MiddleFeatureExtractor middle_feature_extractor = 4; 33 | 34 | message OdomPredictor{ 35 | string module_class_name =1; 36 | int32 num_input_features =2; 37 | repeated int32 layer_nums=3; 38 | repeated int32 
layer_strides=4; 39 | repeated int32 num_filters =5; 40 | repeated int32 upsample_strides=6; 41 | repeated int32 num_upsample_filters=7; 42 | // int32 avgpool_size=8; 43 | int32 pool_size=8; 44 | string pool_type=14; 45 | bool cycle_constraint=9; 46 | // bool not_use_norm = 10; 47 | // bool use_sparse_conv = 11; 48 | string conv_type=11; //['official', 'sparse_conv', 'mask_conv'] 49 | bool pred_pyramid_motion=12; 50 | string odom_format=13; 51 | bool dense_predict=15; 52 | bool use_corr=16; 53 | float dropout=17; 54 | string conf_type=18; 55 | bool use_SPGN=19; 56 | bool use_deep_supervision=20; 57 | bool not_use_loss_mask=21; 58 | bool use_dynamic_mask=22; 59 | bool use_leakyReLU=23; 60 | bool dropout_input=24; 61 | int32 first_conv_groups=25; 62 | bool odom_use_se=26; 63 | bool odom_use_sa=27; 64 | int32 cubic_pred_height=28; 65 | bool not_use_enc_norm=29; 66 | bool use_svd=30; 67 | string bn_type=31; 68 | // bool freeze_bn=30; 69 | // bool freeze_bn_affine=31; 70 | } 71 | OdomPredictor odom_predictor=5; 72 | 73 | uint32 num_point_features = 6; 74 | bool use_sigmoid_score = 7; 75 | Loss loss = 8; 76 | // bool encode_rad_error_by_sin = 9; 77 | bool encode_background_as_zeros = 10; 78 | bool use_GN=11; 79 | repeated float post_center_limit_range = 18; 80 | 81 | // deprecated in future 82 | bool lidar_input = 24; 83 | 84 | bool freeze_bn=30; 85 | bool freeze_bn_affine=31; 86 | int32 freeze_bn_start_step =32; 87 | uint32 icp_iter=33; 88 | // bool sync_bn=33; 89 | } -------------------------------------------------------------------------------- /rslo/protos/similarity.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | 5 | // Configuration proto for region similarity calculators. See 6 | // core/region_similarity_calculator.py for details. 
7 | message RegionSimilarityCalculator { 8 | oneof region_similarity { 9 | RotateIouSimilarity rotate_iou_similarity = 1; 10 | NearestIouSimilarity nearest_iou_similarity = 2; 11 | DistanceSimilarity distance_similarity = 3; 12 | } 13 | } 14 | 15 | // Configuration for intersection-over-union (IOU) similarity calculator. 16 | message RotateIouSimilarity { 17 | } 18 | 19 | // Configuration for intersection-over-union (IOU) similarity calculator. 20 | message NearestIouSimilarity { 21 | } 22 | 23 | // Configuration for intersection-over-union (IOU) similarity calculator. 24 | message DistanceSimilarity { 25 | float distance_norm = 1; 26 | bool with_rotation = 2; 27 | float rotation_alpha = 3; 28 | } -------------------------------------------------------------------------------- /rslo/protos/target.proto.bak: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | import "rslo/protos/anchors.proto"; 5 | import "rslo/protos/similarity.proto"; 6 | 7 | message ClassSetting { 8 | oneof anchor_generator { 9 | AnchorGeneratorStride anchor_generator_stride = 1; 10 | AnchorGeneratorRange anchor_generator_range = 2; 11 | NoAnchor no_anchor = 3; 12 | } 13 | RegionSimilarityCalculator region_similarity_calculator = 4; 14 | bool use_multi_class_nms = 5; 15 | bool use_rotate_nms = 6; 16 | int32 nms_pre_max_size = 7; 17 | int32 nms_post_max_size = 8; 18 | float nms_score_threshold = 9; 19 | float nms_iou_threshold = 10; 20 | float matched_threshold = 11; 21 | float unmatched_threshold = 12; 22 | string class_name = 13; 23 | repeated int64 feature_map_size = 14; // 3D zyx (DHW) size 24 | } 25 | 26 | message TargetAssigner { 27 | repeated ClassSetting class_settings = 1; 28 | float sample_positive_fraction = 2; 29 | uint32 sample_size = 3; 30 | bool assign_per_class = 4; 31 | repeated int64 nms_pre_max_sizes = 5; // this will override setting in ClassSettings if provide. 
32 | repeated int64 nms_post_max_sizes = 6; // this will override setting in ClassSettings if provide. 33 | repeated int64 nms_score_thresholds = 7; // this will override setting in ClassSettings if provide. 34 | repeated int64 nms_iou_thresholds = 8; // this will override setting in ClassSettings if provide. 35 | } -------------------------------------------------------------------------------- /rslo/protos/train.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | 5 | import "rslo/protos/optimizer.proto"; 6 | import "rslo/protos/preprocess.proto"; 7 | 8 | message TrainConfig{ 9 | Optimizer optimizer = 1; 10 | uint32 steps = 2; 11 | uint32 steps_per_eval = 3; 12 | uint32 save_checkpoints_secs = 4; 13 | uint32 save_summary_steps = 5; 14 | bool enable_mixed_precision = 6; 15 | float loss_scale_factor = 7; 16 | bool clear_metrics_every_epoch = 8; 17 | } -------------------------------------------------------------------------------- /rslo/protos/train_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: rslo/protos/train.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from rslo.protos import optimizer_pb2 as rslo_dot_protos_dot_optimizer__pb2 17 | from rslo.protos import preprocess_pb2 as rslo_dot_protos_dot_preprocess__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='rslo/protos/train.proto', 22 | package='second.protos', 23 | syntax='proto3', 24 | serialized_options=None, 25 | serialized_pb=_b('\n\x17rslo/protos/train.proto\x12\rsecond.protos\x1a\x1brslo/protos/optimizer.proto\x1a\x1crslo/protos/preprocess.proto\"\xfa\x01\n\x0bTrainConfig\x12+\n\toptimizer\x18\x01 \x01(\x0b\x32\x18.second.protos.Optimizer\x12\r\n\x05steps\x18\x02 \x01(\r\x12\x16\n\x0esteps_per_eval\x18\x03 \x01(\r\x12\x1d\n\x15save_checkpoints_secs\x18\x04 \x01(\r\x12\x1a\n\x12save_summary_steps\x18\x05 \x01(\r\x12\x1e\n\x16\x65nable_mixed_precision\x18\x06 \x01(\x08\x12\x19\n\x11loss_scale_factor\x18\x07 \x01(\x02\x12!\n\x19\x63lear_metrics_every_epoch\x18\x08 \x01(\x08\x62\x06proto3') 26 | , 27 | dependencies=[rslo_dot_protos_dot_optimizer__pb2.DESCRIPTOR,rslo_dot_protos_dot_preprocess__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _TRAINCONFIG = _descriptor.Descriptor( 33 | name='TrainConfig', 34 | full_name='second.protos.TrainConfig', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='optimizer', full_name='second.protos.TrainConfig.optimizer', index=0, 41 | number=1, type=11, cpp_type=10, label=1, 42 | has_default_value=False, default_value=None, 43 | message_type=None, enum_type=None, 
containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | serialized_options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='steps', full_name='second.protos.TrainConfig.steps', index=1, 48 | number=2, type=13, cpp_type=3, label=1, 49 | has_default_value=False, default_value=0, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | serialized_options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='steps_per_eval', full_name='second.protos.TrainConfig.steps_per_eval', index=2, 55 | number=3, type=13, cpp_type=3, label=1, 56 | has_default_value=False, default_value=0, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | serialized_options=None, file=DESCRIPTOR), 60 | _descriptor.FieldDescriptor( 61 | name='save_checkpoints_secs', full_name='second.protos.TrainConfig.save_checkpoints_secs', index=3, 62 | number=4, type=13, cpp_type=3, label=1, 63 | has_default_value=False, default_value=0, 64 | message_type=None, enum_type=None, containing_type=None, 65 | is_extension=False, extension_scope=None, 66 | serialized_options=None, file=DESCRIPTOR), 67 | _descriptor.FieldDescriptor( 68 | name='save_summary_steps', full_name='second.protos.TrainConfig.save_summary_steps', index=4, 69 | number=5, type=13, cpp_type=3, label=1, 70 | has_default_value=False, default_value=0, 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | serialized_options=None, file=DESCRIPTOR), 74 | _descriptor.FieldDescriptor( 75 | name='enable_mixed_precision', full_name='second.protos.TrainConfig.enable_mixed_precision', index=5, 76 | number=6, type=8, cpp_type=7, label=1, 77 | has_default_value=False, default_value=False, 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | serialized_options=None, 
file=DESCRIPTOR), 81 | _descriptor.FieldDescriptor( 82 | name='loss_scale_factor', full_name='second.protos.TrainConfig.loss_scale_factor', index=6, 83 | number=7, type=2, cpp_type=6, label=1, 84 | has_default_value=False, default_value=float(0), 85 | message_type=None, enum_type=None, containing_type=None, 86 | is_extension=False, extension_scope=None, 87 | serialized_options=None, file=DESCRIPTOR), 88 | _descriptor.FieldDescriptor( 89 | name='clear_metrics_every_epoch', full_name='second.protos.TrainConfig.clear_metrics_every_epoch', index=7, 90 | number=8, type=8, cpp_type=7, label=1, 91 | has_default_value=False, default_value=False, 92 | message_type=None, enum_type=None, containing_type=None, 93 | is_extension=False, extension_scope=None, 94 | serialized_options=None, file=DESCRIPTOR), 95 | ], 96 | extensions=[ 97 | ], 98 | nested_types=[], 99 | enum_types=[ 100 | ], 101 | serialized_options=None, 102 | is_extendable=False, 103 | syntax='proto3', 104 | extension_ranges=[], 105 | oneofs=[ 106 | ], 107 | serialized_start=102, 108 | serialized_end=352, 109 | ) 110 | 111 | _TRAINCONFIG.fields_by_name['optimizer'].message_type = rslo_dot_protos_dot_optimizer__pb2._OPTIMIZER 112 | DESCRIPTOR.message_types_by_name['TrainConfig'] = _TRAINCONFIG 113 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 114 | 115 | TrainConfig = _reflection.GeneratedProtocolMessageType('TrainConfig', (_message.Message,), { 116 | 'DESCRIPTOR' : _TRAINCONFIG, 117 | '__module__' : 'rslo.protos.train_pb2' 118 | # @@protoc_insertion_point(class_scope:second.protos.TrainConfig) 119 | }) 120 | _sym_db.RegisterMessage(TrainConfig) 121 | 122 | 123 | # @@protoc_insertion_point(module_scope) 124 | -------------------------------------------------------------------------------- /rslo/protos/voxel_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package second.protos; 4 | 5 | message VoxelGenerator{ 6 | repeated float voxel_size 
= 1; 7 | repeated float point_cloud_range = 2; 8 | uint32 max_number_of_points_per_voxel = 3; 9 | bool full_empty_part_with_mean = 4; 10 | bool block_filtering = 5; 11 | int64 block_factor = 6; 12 | int64 block_size = 7; 13 | float height_threshold = 8; 14 | } 15 | -------------------------------------------------------------------------------- /rslo/protos/voxel_generator_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: rslo/protos/voxel_generator.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='rslo/protos/voxel_generator.proto', 20 | package='second.protos', 21 | syntax='proto3', 22 | serialized_options=None, 23 | serialized_pb=_b('\n!rslo/protos/voxel_generator.proto\x12\rsecond.protos\"\xe7\x01\n\x0eVoxelGenerator\x12\x12\n\nvoxel_size\x18\x01 \x03(\x02\x12\x19\n\x11point_cloud_range\x18\x02 \x03(\x02\x12&\n\x1emax_number_of_points_per_voxel\x18\x03 \x01(\r\x12!\n\x19\x66ull_empty_part_with_mean\x18\x04 \x01(\x08\x12\x17\n\x0f\x62lock_filtering\x18\x05 \x01(\x08\x12\x14\n\x0c\x62lock_factor\x18\x06 \x01(\x03\x12\x12\n\nblock_size\x18\x07 \x01(\x03\x12\x18\n\x10height_threshold\x18\x08 \x01(\x02\x62\x06proto3') 24 | ) 25 | 26 | 27 | 28 | 29 | _VOXELGENERATOR = _descriptor.Descriptor( 30 | name='VoxelGenerator', 31 | full_name='second.protos.VoxelGenerator', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 
37 | name='voxel_size', full_name='second.protos.VoxelGenerator.voxel_size', index=0, 38 | number=1, type=2, cpp_type=6, label=3, 39 | has_default_value=False, default_value=[], 40 | message_type=None, enum_type=None, containing_type=None, 41 | is_extension=False, extension_scope=None, 42 | serialized_options=None, file=DESCRIPTOR), 43 | _descriptor.FieldDescriptor( 44 | name='point_cloud_range', full_name='second.protos.VoxelGenerator.point_cloud_range', index=1, 45 | number=2, type=2, cpp_type=6, label=3, 46 | has_default_value=False, default_value=[], 47 | message_type=None, enum_type=None, containing_type=None, 48 | is_extension=False, extension_scope=None, 49 | serialized_options=None, file=DESCRIPTOR), 50 | _descriptor.FieldDescriptor( 51 | name='max_number_of_points_per_voxel', full_name='second.protos.VoxelGenerator.max_number_of_points_per_voxel', index=2, 52 | number=3, type=13, cpp_type=3, label=1, 53 | has_default_value=False, default_value=0, 54 | message_type=None, enum_type=None, containing_type=None, 55 | is_extension=False, extension_scope=None, 56 | serialized_options=None, file=DESCRIPTOR), 57 | _descriptor.FieldDescriptor( 58 | name='full_empty_part_with_mean', full_name='second.protos.VoxelGenerator.full_empty_part_with_mean', index=3, 59 | number=4, type=8, cpp_type=7, label=1, 60 | has_default_value=False, default_value=False, 61 | message_type=None, enum_type=None, containing_type=None, 62 | is_extension=False, extension_scope=None, 63 | serialized_options=None, file=DESCRIPTOR), 64 | _descriptor.FieldDescriptor( 65 | name='block_filtering', full_name='second.protos.VoxelGenerator.block_filtering', index=4, 66 | number=5, type=8, cpp_type=7, label=1, 67 | has_default_value=False, default_value=False, 68 | message_type=None, enum_type=None, containing_type=None, 69 | is_extension=False, extension_scope=None, 70 | serialized_options=None, file=DESCRIPTOR), 71 | _descriptor.FieldDescriptor( 72 | name='block_factor', 
full_name='second.protos.VoxelGenerator.block_factor', index=5, 73 | number=6, type=3, cpp_type=2, label=1, 74 | has_default_value=False, default_value=0, 75 | message_type=None, enum_type=None, containing_type=None, 76 | is_extension=False, extension_scope=None, 77 | serialized_options=None, file=DESCRIPTOR), 78 | _descriptor.FieldDescriptor( 79 | name='block_size', full_name='second.protos.VoxelGenerator.block_size', index=6, 80 | number=7, type=3, cpp_type=2, label=1, 81 | has_default_value=False, default_value=0, 82 | message_type=None, enum_type=None, containing_type=None, 83 | is_extension=False, extension_scope=None, 84 | serialized_options=None, file=DESCRIPTOR), 85 | _descriptor.FieldDescriptor( 86 | name='height_threshold', full_name='second.protos.VoxelGenerator.height_threshold', index=7, 87 | number=8, type=2, cpp_type=6, label=1, 88 | has_default_value=False, default_value=float(0), 89 | message_type=None, enum_type=None, containing_type=None, 90 | is_extension=False, extension_scope=None, 91 | serialized_options=None, file=DESCRIPTOR), 92 | ], 93 | extensions=[ 94 | ], 95 | nested_types=[], 96 | enum_types=[ 97 | ], 98 | serialized_options=None, 99 | is_extendable=False, 100 | syntax='proto3', 101 | extension_ranges=[], 102 | oneofs=[ 103 | ], 104 | serialized_start=53, 105 | serialized_end=284, 106 | ) 107 | 108 | DESCRIPTOR.message_types_by_name['VoxelGenerator'] = _VOXELGENERATOR 109 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 110 | 111 | VoxelGenerator = _reflection.GeneratedProtocolMessageType('VoxelGenerator', (_message.Message,), { 112 | 'DESCRIPTOR' : _VOXELGENERATOR, 113 | '__module__' : 'rslo.protos.voxel_generator_pb2' 114 | # @@protoc_insertion_point(class_scope:second.protos.VoxelGenerator) 115 | }) 116 | _sym_db.RegisterMessage(VoxelGenerator) 117 | 118 | 119 | # @@protoc_insertion_point(module_scope) 120 | -------------------------------------------------------------------------------- /rslo/torchplus/__init__.py: 
def one_hot(tensor, depth, dim=-1, on_value=1.0, dtype=torch.float32):
    """Return a one-hot encoding of an integer index tensor.

    Args:
        tensor: Tensor of class indices; it is converted with ``.long()``.
        depth: Length of the new one-hot axis (number of classes).
        dim: Position at which the one-hot axis is inserted, following the
            ``torch.unsqueeze`` convention (``-1`` appends it as the last
            axis, which is the default).
        on_value: Value written at the "hot" positions; every other entry
            is zero.
        dtype: Dtype of the returned tensor.

    Returns:
        A tensor on ``tensor.device`` whose shape is ``tensor.shape`` with
        one extra axis of length ``depth`` inserted at ``dim``.
    """
    # Let torch validate `dim`; this raises for out-of-range values.
    index = tensor.unsqueeze(dim).long()
    axis = dim if dim >= 0 else dim + index.dim()

    # The one-hot axis must sit at `axis` in the output. The previous code
    # always appended it last, which only matched the scatter for dim=-1.
    out_shape = list(index.shape)
    out_shape[axis] = depth

    tensor_onehot = torch.zeros(*out_shape, dtype=dtype, device=tensor.device)
    tensor_onehot.scatter_(axis, index, on_value)
    return tensor_onehot
class Empty(torch.nn.Module):
    """Placeholder module that hands its positional inputs straight back."""

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted and ignored so that Empty can
        # stand in for any layer class.
        super(Empty, self).__init__()
        # Dummy variable: a plain tensor attribute, not a registered
        # parameter or buffer.
        # NOTE(review): presumably kept so code reading `layer.weight` does
        # not fail - confirm against callers.
        self.weight = torch.zeros([1, ])

    def forward(self, *args, **kwargs):
        """Return ``None``, the single argument, or the tuple of arguments."""
        if not args:
            return None
        if len(args) == 1:
            return args[0]
        return args


class Sequential(torch.nn.Module):
    r"""A sequential container.
    Modules will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of modules can also be passed in.

    To make it easier to understand, given is a small example::

        # Example of using Sequential
        model = Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Example of using Sequential with OrderedDict
        model = Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))

        # Example of using Sequential with kwargs(python 3.6+)
        model = Sequential(
                  conv1=nn.Conv2d(1,20,5),
                  relu1=nn.ReLU(),
                  conv2=nn.Conv2d(20,64,5),
                  relu2=nn.ReLU()
                )
    """

    def __init__(self, *args, **kwargs):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
        # Keyword order is only guaranteed from Python 3.6 on; check once
        # instead of on every loop iteration.
        if kwargs and sys.version_info < (3, 6):
            raise ValueError("kwargs only supported in py36+")
        for name, module in kwargs.items():
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        """Return the ``idx``-th module, or a new ``Sequential`` for a slice.

        Raises:
            IndexError: if an integer ``idx`` is outside ``[-len, len)``.
        """
        if isinstance(idx, slice):
            return Sequential(OrderedDict(list(self._modules.items())[idx]))
        if not (-len(self) <= idx < len(self)):
            raise IndexError('index {} is out of range'.format(idx))
        return list(self._modules.values())[idx]

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        """Append ``module``; the default name is the current module count.

        Raises:
            KeyError: if ``name`` is already registered.
        """
        if name is None:
            name = str(len(self._modules))
        if name in self._modules:
            raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, input):
        """Feed ``input`` through every module in registration order."""
        for module in self._modules.values():
            input = module(input)
        return input


class GroupNorm(torch.nn.GroupNorm):
    """``torch.nn.GroupNorm`` taking ``num_channels`` as the first argument.

    The parent class takes ``(num_groups, num_channels)``; this subclass
    swaps the positional order and forwards everything by keyword.
    """

    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):
        super().__init__(
            num_groups=num_groups,
            num_channels=num_channels,
            eps=eps,
            affine=affine)
def scatter_nd(indices, updates, shape):
    """PyTorch edition of TensorFlow's ``scatter_nd``.

    No error handling is performed, so use it carefully. Repeated indices
    are NOT accumulated as they are in TensorFlow.

    Args:
        indices: Integer tensor whose last axis addresses the leading
            ``indices.shape[-1]`` axes of the output.
        updates: Values to write; they must hold one slice of shape
            ``shape[indices.shape[-1]:]`` per index row.
        shape: Output shape (any sequence: list, tuple or ``torch.Size``).

    Returns:
        A zero-initialised tensor of ``shape`` (dtype and device taken from
        ``updates``) with ``updates`` written at ``indices``.
    """
    ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
    ndim = indices.shape[-1]
    flatted_indices = indices.view(-1, ndim)
    # list(...) so a tuple or torch.Size `shape` works as well as a list.
    slice_shape = list(shape[ndim:])
    # Index with a tuple: indexing with a list that contains Ellipsis relies
    # on deprecated "non-tuple sequence" behaviour.
    index = tuple(flatted_indices[:, i] for i in range(ndim)) + (Ellipsis,)
    # The indexed region is (num_index_rows, *slice_shape), so the updates
    # are flattened to match even when `indices` has more than two axes.
    ret[index] = updates.reshape(flatted_indices.shape[0], *slice_shape)
    return ret


def gather_nd(params, indices):
    """Gather slices of ``params`` addressed by the last axis of ``indices``.

    Returns:
        A tensor of shape
        ``indices.shape[:-1] + params.shape[indices.shape[-1]:]``.
    """
    # this function has a limit that MAX_ADVINDEX_CALC_DIMS=5
    ndim = indices.shape[-1]
    output_shape = list(indices.shape[:-1]) + list(params.shape[ndim:])
    flatted_indices = indices.view(-1, ndim)
    # Tuple index for the same reason as in scatter_nd.
    index = tuple(flatted_indices[:, i] for i in range(ndim)) + (Ellipsis,)
    return params[index].view(*output_shape)


def roll(x: torch.Tensor, shift: int, dim: int = -1, fill_pad: Optional[int] = None):
    """Shift ``x`` along ``dim`` by ``shift`` positions.

    A positive ``shift`` moves entries towards higher indices, a negative
    one towards lower indices. Entries pushed off one end wrap around to the
    other end, unless ``fill_pad`` is given, in which case the vacated
    positions are filled with ``fill_pad``. ``shift == 0`` returns ``x``
    itself (no copy).
    """
    if shift == 0:
        return x

    device = x.device
    size = x.size(dim)

    if shift < 0:
        # Leading entries leave the front and reappear at the back.
        cut = -shift
        wrapped = x.index_select(dim, torch.arange(cut, device=device))
        kept = x.index_select(dim, torch.arange(cut, size, device=device))
    else:
        # Trailing entries leave the back and reappear at the front.
        cut = size - shift
        wrapped = x.index_select(dim, torch.arange(cut, size, device=device))
        kept = x.index_select(dim, torch.arange(cut, device=device))

    if fill_pad is not None:
        wrapped = fill_pad * torch.ones_like(wrapped, device=device)

    pieces = [kept, wrapped] if shift < 0 else [wrapped, kept]
    return torch.cat(pieces, dim=dim)
def get_pos_to_kw_map(func):
    """Map signature position to name for positional-or-keyword parameters.

    Positions count every parameter in the signature (including ``self``
    for an unbound method); parameters of other kinds are skipped but still
    advance the position counter.
    """
    parameters = inspect.signature(func).parameters
    return {
        pos: name
        for pos, (name, info) in enumerate(parameters.items())
        if info.kind is info.POSITIONAL_OR_KEYWORD
    }


def get_kw_to_default_map(func):
    """Map name to default for positional-or-keyword parameters with one."""
    parameters = inspect.signature(func).parameters
    return {
        name: info.default
        for name, info in parameters.items()
        if info.kind is info.POSITIONAL_OR_KEYWORD
        and info.default is not info.empty
    }


def change_default_args(**kwargs):
    """Build a class decorator that overrides constructor defaults.

    ``change_default_args(bias=False)(Layer)`` returns a subclass of
    ``Layer`` whose ``__init__`` passes ``bias=False`` unless the caller
    supplies ``bias`` explicitly, by keyword or positionally.
    """
    def layer_wrapper(layer_class):
        # The signature does not change between instantiations, so the
        # name -> position table is built once per wrapped class.
        pos_to_kw = get_pos_to_kw_map(layer_class.__init__)
        kw_to_pos = {kw: pos for pos, kw in pos_to_kw.items()}

        class DefaultArgLayer(layer_class):
            def __init__(self, *args, **kw):
                for key, val in kwargs.items():
                    if key in kw:
                        continue
                    # Positions include `self` at 0 while `args` does not,
                    # so parameter `pos` was given positionally exactly when
                    # pos <= len(args). A name with no position (for example
                    # a keyword-only parameter) can never be positional.
                    pos = kw_to_pos.get(key)
                    if pos is None or pos > len(args):
                        kw[key] = val
                super(DefaultArgLayer, self).__init__(*args, **kw)

        return DefaultArgLayer

    return layer_wrapper


def torch_to_np_dtype(ttype):
    """Return the numpy dtype corresponding to a torch dtype.

    Raises:
        KeyError: if ``ttype`` is not one of the supported dtypes.
    """
    type_map = {
        torch.float16: np.dtype(np.float16),
        torch.float32: np.dtype(np.float32),
        # This key used to be a second `torch.float16`, which silently
        # overwrote the float16 entry and left float64 unmapped.
        torch.float64: np.dtype(np.float64),
        torch.int32: np.dtype(np.int32),
        torch.int64: np.dtype(np.int64),
        torch.uint8: np.dtype(np.uint8),
    }
    return type_map[ttype]
def create_folder(prefix, add_time=True, add_str=None, delete=False):
    """Create an output folder below ``prefix`` and return its path.

    Args:
        prefix: Parent directory.
        add_time: If True, the folder name is a ``yymmdd_HHMMSS`` timestamp,
            followed by ``_<add_str>`` when ``add_str`` is given.
        add_str: Optional suffix appended to the timestamp.
        delete: If True, ``prefix`` (and the target folder, if it exists) is
            removed and recreated first.

    Returns:
        The path of the created folder.
    """
    # NOTE(review): the nesting below was reconstructed from a
    # whitespace-collapsed source - confirm that the suffix is only appended
    # when add_time is True.
    additional_str = ''
    if delete is True:
        if os.path.exists(prefix):
            shutil.rmtree(prefix)
        os.makedirs(prefix)
    folder = prefix
    if add_time is True:
        # additional_str has a form such as '170903_220351'
        additional_str += datetime.datetime.now().strftime("%y%m%d_%H%M%S")
        if add_str is not None:
            folder += '/' + additional_str + '_' + add_str
        else:
            folder += '/' + additional_str
    if delete is True:
        if os.path.exists(folder):
            shutil.rmtree(folder)
    os.makedirs(folder)
    return folder


class LRSchedulerStep(object):
    """Piecewise learning-rate / momentum schedule driven by the step count.

    A schedule is a sequence of ``(start_fraction, func)`` phases. Each
    phase covers the steps from ``int(start_fraction * total_step)`` up to
    the start of the next phase (or ``total_step`` for the last one), and
    ``func(pct)`` is evaluated with ``pct`` being the fraction of the phase
    already elapsed. The value of the latest phase that has started is
    written to ``optimizer.lr`` / ``optimizer.mom``.
    """

    def __init__(self, fai_optimizer, total_step, lr_phases, mom_phases):
        self.optimizer = fai_optimizer
        self.total_step = total_step
        self.lr_phases = self._build_phases(lr_phases, total_step)
        assert self.lr_phases[0][0] == 0
        self.mom_phases = self._build_phases(mom_phases, total_step)
        if len(mom_phases) > 0:
            assert self.mom_phases[0][0] == 0

    @staticmethod
    def _build_phases(phases, total_step):
        """Turn ``(start_fraction, func)`` pairs into ``(start, end, func)``."""
        built = []
        for i, (start, lambda_func) in enumerate(phases):
            first_step = int(start * total_step)
            if built:
                # Phase starts must be strictly increasing.
                assert built[-1][0] < first_step
            if isinstance(lambda_func, str):
                # NOTE(review): eval() of a string phase function is only
                # acceptable while the schedule comes from a trusted source.
                lambda_func = eval(lambda_func)
            if i < len(phases) - 1:
                last_step = int(phases[i + 1][0] * total_step)
            else:
                last_step = total_step
            built.append((first_step, last_step, lambda_func))
        return built

    @staticmethod
    def _values_at(phases, step):
        """Evaluate every phase that has already started at ``step``."""
        values = []
        for start, end, func in phases:
            if step >= start:
                span = end - start
                # A phase that begins exactly at total_step has zero length;
                # treat it as complete instead of dividing by zero.
                pct = (step - start) / span if span > 0 else 1.0
                values.append(func(pct))
        return values

    def step(self, step):
        """Set ``optimizer.lr`` (and ``optimizer.mom``) for global ``step``."""
        lrs = self._values_at(self.lr_phases, step)
        if len(lrs) > 0:
            self.optimizer.lr = lrs[-1]

        moms = self._values_at(self.mom_phases, step)
        if len(moms) > 0:
            self.optimizer.mom = moms[-1]

    @property
    def learning_rate(self):
        return self.optimizer.lr


def annealing_cos(start, end, pct):
    "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
    cos_out = np.cos(np.pi * pct) + 1
    return end + (start - end) / 2 * cos_out


class OneCycle(LRSchedulerStep):
    """One-cycle policy built from two cosine-annealed phases.

    The learning rate rises from ``lr_max / div_factor`` to ``lr_max``
    during the first ``pct_start`` fraction of training, then falls to
    ``lr_max / div_factor / 1e4``. Momentum moves from ``moms[0]`` to
    ``moms[1]`` and back over the same two phases.
    """

    def __init__(self, fai_optimizer, total_step, lr_max, moms, div_factor,
                 pct_start):
        self.lr_max = lr_max
        self.moms = moms
        self.div_factor = div_factor
        self.pct_start = pct_start
        low_lr = self.lr_max / self.div_factor
        lr_phases = ((0, partial(annealing_cos, low_lr, self.lr_max)),
                     (self.pct_start,
                      partial(annealing_cos, self.lr_max, low_lr / 1e4)))
        mom_phases = ((0, partial(annealing_cos, *self.moms)),
                      (self.pct_start, partial(annealing_cos,
                                               *self.moms[::-1])))
        fai_optimizer.lr, fai_optimizer.mom = low_lr, self.moms[0]

        super().__init__(fai_optimizer, total_step, lr_phases, mom_phases)


class ExponentialDecayWarmup(LRSchedulerStep):
    """Exponential decay preceded by an optional cosine warm-up."""

    def __init__(self,
                 fai_optimizer,
                 total_step,
                 initial_learning_rate,
                 decay_length,
                 decay_factor,
                 div_factor=1,
                 pct_start=0,
                 staircase=True):
        """
        Args:
            decay_length: must in (0, 1)
            div_factor: the warm-up starts at
                ``initial_learning_rate / div_factor``.
            pct_start: fraction of ``total_step`` spent warming up.
            staircase: if True the rate drops in discrete steps every
                ``decay_length`` of training, otherwise it decays smoothly.
        """
        assert decay_length > 0
        assert decay_length < 1
        self._decay_steps_unified = decay_length
        self._decay_factor = decay_factor
        self._staircase = staircase
        self.div_factor = div_factor
        self.pct_start = pct_start

        lr_phases = []
        # An empty warm-up (the default pct_start=0) must not be registered:
        # it would start on the same step as the first decay phase and trip
        # the strictly-increasing check in LRSchedulerStep.
        if int(pct_start * total_step) > 0:
            lr_phases.append(
                (0, partial(annealing_cos,
                            initial_learning_rate / div_factor,
                            initial_learning_rate)))
        if staircase:
            # A stride of 0 would make the loop below run forever.
            stride = max(1, int(decay_length * total_step))
            step = pct_start * total_step
            stage = 1
            while step <= total_step:
                func = lambda p, _d=initial_learning_rate * stage: _d
                lr_phases.append((step / total_step, func))
                stage *= decay_factor
                step += stride
        else:
            # The decay scales the initial rate; previously the bare decay
            # multiplier was returned as the learning rate.
            def func(p):
                return initial_learning_rate * pow(decay_factor,
                                                   (p / decay_length))
            lr_phases.append((pct_start, func))
        super().__init__(fai_optimizer, total_step, lr_phases, [])


class ExponentialDecay(LRSchedulerStep):
    """Exponential learning-rate decay, stepwise or continuous."""

    def __init__(self,
                 fai_optimizer,
                 total_step,
                 initial_learning_rate,
                 decay_length,
                 decay_factor,
                 staircase=True):
        """
        Args:
            decay_length: must in (0, 1)
            staircase: if True the rate drops in discrete steps every
                ``decay_length`` of training, otherwise it decays smoothly.
        """
        assert decay_length > 0
        assert decay_length < 1
        self._decay_steps_unified = decay_length
        self._decay_factor = decay_factor
        self._staircase = staircase
        lr_phases = []
        if staircase:
            # A stride of 0 would make the loop below run forever.
            stride = max(1, int(decay_length * total_step))
            step = 0
            stage = 1
            while step <= total_step:
                func = lambda p, _d=initial_learning_rate * stage: _d
                lr_phases.append((step / total_step, func))
                stage *= decay_factor
                step += stride
        else:
            # The decay scales the initial rate; previously the bare decay
            # multiplier was returned as the learning rate.
            def func(p):
                return initial_learning_rate * pow(decay_factor,
                                                   (p / decay_length))
            lr_phases.append((0, func))
        super().__init__(fai_optimizer, total_step, lr_phases, [])


class ManualStepping(LRSchedulerStep):
    """Piecewise-constant schedule: ``rates[i]`` applies from boundary ``i``."""

    def __init__(self, fai_optimizer, total_step, boundaries, rates):
        assert all([b > 0 and b < 1 for b in boundaries])
        assert len(boundaries) + 1 == len(rates)
        # Build a new list so the caller's `boundaries` is left untouched.
        starts = [0.0] + list(boundaries)
        lr_phases = []
        for start, rate in zip(starts, rates):
            def func(p, _d=rate):
                return _d
            lr_phases.append((start, func))
        super().__init__(fai_optimizer, total_step, lr_phases, [])


class FakeOptim:
    """Minimal optimizer stand-in exposing only ``lr`` and ``mom``."""

    def __init__(self):
        self.lr = 0
        self.mom = 0


if __name__ == "__main__":
    import matplotlib.pyplot as plt
    opt = FakeOptim()  # 3e-3, wd=0.4, div_factor=10
    schd = ExponentialDecay(opt, 100, 3e-4, 0.1, 0.8, staircase=True)
    schd = ManualStepping(opt, 100, [0.8, 0.9], [0.001, 0.0001, 0.00005])
    lrs = []
    moms = []
    for i in range(100):
        schd.step(i)
        lrs.append(opt.lr)
        moms.append(opt.mom)

    plt.plot(lrs)
    plt.show()
def set_grad(params, params_with_grad, scale=1.0):
    """Copy (optionally un-scaled) gradients from one parameter list to another.

    Args:
        params: parameters that receive the gradients.
        params_with_grad: parameters whose ``.grad`` is read; paired with
            ``params`` by position.
        scale: if not None, every source gradient is divided by it in place
            before being copied.

    Returns:
        True as soon as a source gradient contains NaN or Inf (remaining
        parameters are left untouched), False when everything was copied.
    """
    for param, param_w_grad in zip(params, params_with_grad):
        source_grad = param_w_grad.grad
        if source_grad is None:
            # No gradient was produced for this parameter; nothing to copy.
            continue
        grad = source_grad.data
        if scale is not None:
            grad /= scale
        if not torch.isfinite(grad).all():
            return True  # invalid grad
        if param.grad is None:
            # Plain zero tensor; it is overwritten by copy_ right below.
            param.grad = torch.zeros_like(param.data)
        param.grad.data.copy_(grad)
    return False
38 | The algorihm of auto scale is discribled in 39 | http://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html 40 | """ 41 | 42 | def __init__(self, 43 | optimizer, 44 | scale=None, 45 | auto_scale=True, 46 | inc_factor=2.0, 47 | dec_factor=0.5, 48 | num_iters_be_stable=500): 49 | # if not isinstance(optimizer, torch.optim.Optimizer): 50 | # raise ValueError("must provide a torch.optim.Optimizer") 51 | self.optimizer = optimizer 52 | if hasattr(self.optimizer, 'name'): 53 | self.name = self.optimizer.name # for ckpt system 54 | param_groups_copy = [] 55 | for i, group in enumerate(optimizer.param_groups): 56 | group_copy = {n: v for n, v in group.items() if n != 'params'} 57 | group_copy['params'] = param_fp32_copy(group['params']) 58 | param_groups_copy.append(group_copy) 59 | 60 | # switch param_groups, may be dangerous 61 | self.param_groups = optimizer.param_groups 62 | optimizer.param_groups = param_groups_copy 63 | self.grad_scale = scale 64 | self.auto_scale = auto_scale 65 | self.inc_factor = inc_factor 66 | self.dec_factor = dec_factor 67 | self.stable_iter_count = 0 68 | self.num_iters_be_stable = num_iters_be_stable 69 | 70 | def __getstate__(self): 71 | return self.optimizer.__getstate__() 72 | 73 | def __setstate__(self, state): 74 | return self.optimizer.__setstate__(state) 75 | 76 | def __repr__(self): 77 | return self.optimizer.__repr__() 78 | 79 | def state_dict(self): 80 | return self.optimizer.state_dict() 81 | 82 | def load_state_dict(self, state_dict): 83 | return self.optimizer.load_state_dict(state_dict) 84 | 85 | def zero_grad(self): 86 | return self.optimizer.zero_grad() 87 | 88 | def step(self, closure=None): 89 | for g, g_copy in zip(self.param_groups, self.optimizer.param_groups): 90 | invalid = set_grad(g_copy['params'], g['params'], self.grad_scale) 91 | if invalid: 92 | if self.grad_scale is None or self.auto_scale is False: 93 | raise ValueError("nan/inf detected but auto_scale disabled.") 94 | self.grad_scale *= 
self.dec_factor 95 | print('scale decay to {}'.format(self.grad_scale)) 96 | return 97 | if self.auto_scale is True: 98 | self.stable_iter_count += 1 99 | if self.stable_iter_count > self.num_iters_be_stable: 100 | if self.grad_scale is not None: 101 | self.grad_scale *= self.inc_factor 102 | self.stable_iter_count = 0 103 | 104 | if closure is None: 105 | self.optimizer.step() 106 | else: 107 | self.optimizer.step(closure) 108 | for g, g_copy in zip(self.param_groups, self.optimizer.param_groups): 109 | for p_copy, p in zip(g_copy['params'], g['params']): 110 | p.data.copy_(p_copy.data) 111 | -------------------------------------------------------------------------------- /rslo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import time 2 | import contextlib 3 | import torch 4 | 5 | @contextlib.contextmanager 6 | def torch_timer(name=''): 7 | torch.cuda.synchronize() 8 | t = time.time() 9 | yield 10 | torch.cuda.synchronize() 11 | print(name, "time:", time.time() - t) -------------------------------------------------------------------------------- /rslo/utils/check.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def is_array_like(x): 4 | return isinstance(x, (list, tuple, np.ndarray)) 5 | 6 | def shape_mergeable(x, expected_shape): 7 | mergeable = True 8 | if is_array_like(x) and is_array_like(expected_shape): 9 | x = np.array(x) 10 | if len(x.shape) == len(expected_shape): 11 | for s, s_ex in zip(x.shape, expected_shape): 12 | if s_ex is not None and s != s_ex: 13 | mergeable = False 14 | break 15 | return mergeable -------------------------------------------------------------------------------- /rslo/utils/config_tool.py: -------------------------------------------------------------------------------- 1 | # This file contains some config modification function. 2 | # some functions should be only used for KITTI dataset. 
def get_downsample_factor(model_config):
    """Return the overall integer downsample factor described by the config.

    The factor is the product of ``rpn.layer_strides``, divided by the last
    entry of ``rpn.upsample_strides`` (when there is one) and multiplied by
    ``middle_feature_extractor.downsample_factor``.

    Raises:
        AssertionError: if the resulting factor is not positive.
    """
    downsample_factor = np.prod(model_config.rpn.layer_strides)
    if len(model_config.rpn.upsample_strides) > 0:
        downsample_factor /= model_config.rpn.upsample_strides[-1]
    downsample_factor *= model_config.middle_feature_extractor.downsample_factor
    # Round instead of truncating: the division yields a float, and a value
    # such as 2.9999999999999996 must map to 3, not 2.
    downsample_factor = int(np.round(downsample_factor))
    assert downsample_factor > 0
    return downsample_factor
"/home/yy/deeplearning/deeplearning/mypackages/second/configs/car.lite.1.config" 50 | config = pipeline_pb2.TrainEvalPipelineConfig() 51 | 52 | with open(config_path, "r") as f: 53 | proto_str = f.read() 54 | text_format.Merge(proto_str, config) 55 | 56 | change_detection_range(config, [-50, -50, 50, 50]) 57 | proto_str = text_format.MessageToString(config, indent=2) 58 | print(proto_str) 59 | 60 | -------------------------------------------------------------------------------- /rslo/utils/config_tool/__init__.py: -------------------------------------------------------------------------------- 1 | # This file contains some config modification function. 2 | # some functions should be only used for KITTI dataset. 3 | 4 | from google.protobuf import text_format 5 | from rslo.protos import pipeline_pb2, second_pb2 6 | from pathlib import Path 7 | import numpy as np 8 | 9 | 10 | def read_config(path): 11 | config = pipeline_pb2.TrainEvalPipelineConfig() 12 | 13 | with open(path, "r") as f: 14 | proto_str = f.read() 15 | text_format.Merge(proto_str, config) 16 | return config 17 | 18 | 19 | def change_detection_range(model_config, new_range): 20 | assert len( 21 | new_range) == 4, "you must provide a list such as [-50, -50, 50, 50]" 22 | old_pc_range = list(model_config.voxel_generator.point_cloud_range) 23 | old_pc_range[:2] = new_range[:2] 24 | old_pc_range[3:5] = new_range[2:] 25 | model_config.voxel_generator.point_cloud_range[:] = old_pc_range 26 | for anchor_generator in model_config.target_assigner.anchor_generators: 27 | a_type = anchor_generator.WhichOneof('anchor_generator') 28 | if a_type == "anchor_generator_range": 29 | a_cfg = anchor_generator.anchor_generator_range 30 | old_a_range = list(a_cfg.anchor_ranges) 31 | old_a_range[:2] = new_range[:2] 32 | old_a_range[3:5] = new_range[2:] 33 | a_cfg.anchor_ranges[:] = old_a_range 34 | elif a_type == "anchor_generator_stride": 35 | a_cfg = anchor_generator.anchor_generator_stride 36 | old_offset = 
def get_downsample_factor(model_config):
    """Compute the rounded overall downsample factor from a model config.

    Multiplies all ``rpn.layer_strides``, divides by the final
    ``rpn.upsample_strides`` entry if any exist, scales by
    ``middle_feature_extractor.downsample_factor`` and rounds the result to
    a NumPy 64-bit integer.  Asserts that the result is positive.
    """
    rpn_cfg = model_config.rpn
    factor = np.prod(rpn_cfg.layer_strides)
    if len(rpn_cfg.upsample_strides) > 0:
        factor = factor / rpn_cfg.upsample_strides[-1]
    factor = factor * model_config.middle_feature_extractor.downsample_factor
    rounded = np.round(factor).astype(np.int64)
    assert rounded > 0
    return rounded
train_config.optimizer.adam_optimizer 11 | elif optim == "rms_prop_optimizer": 12 | return train_config.optimizer.rms_prop_optimizer 13 | elif optim == "momentum_optimizer": 14 | return train_config.optimizer.momentum_optimizer 15 | else: 16 | raise NotImplementedError 17 | 18 | 19 | def manual_stepping(train_config, boundaries, rates, optim="adam_optimizer"): 20 | optim_cfg = _get_optim_cfg(train_config, optim) 21 | optim_cfg.learning_rate.manual_stepping.CopyFrom( 22 | ManualStepping(boundaries=boundaries, rates=rates)) 23 | 24 | 25 | def exp_decay(train_config, 26 | init_lr, 27 | decay_length, 28 | decay_factor, 29 | staircase=True, 30 | optim="adam_optimizer"): 31 | optim_cfg = _get_optim_cfg(train_config, optim) 32 | optim_cfg.learning_rate.exponential_decay.CopyFrom( 33 | ExponentialDecay( 34 | initial_learning_rate=init_lr, 35 | decay_length=decay_length, 36 | decay_factor=decay_factor, 37 | staircase=staircase)) 38 | 39 | 40 | def one_cycle(train_config, 41 | lr_max, 42 | moms, 43 | div_factor, 44 | pct_start, 45 | optim="adam_optimizer"): 46 | optim_cfg = _get_optim_cfg(train_config, optim) 47 | optim_cfg.learning_rate.one_cycle.CopyFrom( 48 | OneCycle( 49 | lr_max=lr_max, 50 | moms=moms, 51 | div_factor=div_factor, 52 | pct_start=pct_start)) 53 | 54 | def _div_up(a, b): 55 | return (a + b - 1) // b 56 | 57 | def set_train_step(config, 58 | epochs, 59 | eval_epoch): 60 | input_cfg = config.train_input_reader 61 | train_cfg = config.train_config 62 | batch_size = input_cfg.batch_size 63 | dataset_name = input_cfg.dataset.dataset_class_name 64 | ds = get_dataset_class(dataset_name)( 65 | root_path=input_cfg.dataset.kitti_root_path, 66 | info_path=input_cfg.dataset.kitti_info_path, 67 | ) 68 | num_examples_after_sample = len(ds) 69 | step_per_epoch = _div_up(num_examples_after_sample, batch_size) 70 | step_per_eval = step_per_epoch * eval_epoch 71 | total_step = step_per_epoch * epochs 72 | train_cfg.steps = total_step 73 | train_cfg.steps_per_eval = 
step_per_eval 74 | 75 | def disable_sample(config): 76 | input_cfg = config.train_input_reader 77 | input_cfg.database_sampler.CopyFrom(Sampler()) 78 | 79 | def disable_per_gt_aug(config): 80 | prep_cfg = config.train_input_reader.preprocess 81 | prep_cfg.groundtruth_localization_noise_std[:] = [0, 0, 0] 82 | prep_cfg.groundtruth_rotation_uniform_noise[:] = [0, 0] 83 | 84 | def disable_global_aug(config): 85 | prep_cfg = config.train_input_reader.preprocess 86 | prep_cfg.global_rotation_uniform_noise[:] = [0, 0] 87 | prep_cfg.global_scaling_uniform_noise[:] = [0, 0] 88 | prep_cfg.global_random_rotation_range_per_object[:] = [0, 0] 89 | prep_cfg.global_translate_noise_std[:] = [0, 0, 0] 90 | 91 | if __name__ == "__main__": 92 | path = Path(__file__).resolve().parents[2] / "configs/car.lite.config" 93 | config = read_config(path) 94 | manual_stepping(config.train_config, [0.8, 0.9], [1e-4, 1e-5, 1e-6]) 95 | 96 | print(text_format.MessageToString(config, indent=2)) -------------------------------------------------------------------------------- /rslo/utils/find.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import subprocess 5 | import sys 6 | import tempfile 7 | from pathlib import Path 8 | 9 | import fire 10 | 11 | 12 | def _get_info_from_anaconda_info(info, split=":"): 13 | info = info.strip("\n").replace(" ", "") 14 | info_dict = {} 15 | latest_key = "" 16 | for line in info.splitlines(): 17 | if split in line: 18 | pair = line.split(split) 19 | info_dict[pair[0]] = pair[1] 20 | latest_key = pair[0] 21 | else: 22 | if not isinstance(info_dict[latest_key], list): 23 | info_dict[latest_key] = [info_dict[latest_key]] 24 | info_dict[latest_key].append(line) 25 | return info_dict 26 | 27 | 28 | def find_anaconda(): 29 | # try find in default path 30 | path = Path.home() / "anaconda3" 31 | if path.exists(): 32 | return path 33 | # try conda in cmd 34 | try: 35 | info = 
def find_cuda():
    '''Finds the CUDA install path.

    Lookup order:
      1. the ``CUDA_HOME`` or ``CUDA_PATH`` environment variable (returned
         as given);
      2. the default location: the first match of
         ``C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*`` on
         Windows, ``/usr/local/cuda`` elsewhere;
      3. if that default does not exist, two directories above the
         ``nvcc`` executable found on ``PATH``.

    Returns:
        The CUDA install path as a string.

    Raises:
        RuntimeError: if none of the guesses yields a path.
    '''
    import shutil  # local import: only this function needs it

    # Guess #1
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Guess #2
        if sys.platform == 'win32':
            cuda_homes = glob.glob(
                'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
            cuda_home = cuda_homes[0] if cuda_homes else ''
        else:
            cuda_home = '/usr/local/cuda'
        if not os.path.exists(cuda_home):
            # Guess #3: shutil.which is portable and returns one path (or
            # None), unlike parsing the output of `which` / `where`.
            nvcc = shutil.which('nvcc')
            if nvcc:
                cuda_home = os.path.dirname(os.path.dirname(nvcc))
            else:
                cuda_home = None
    if cuda_home is None:
        raise RuntimeError(
            "No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home))
    return cuda_home
<< prop.minor << std::endl; 94 | } 95 | return 0; 96 | } 97 | """ 98 | with tempfile.NamedTemporaryFile('w', suffix='.cc') as f: 99 | f_path = Path(f.name) 100 | f.write(source) 101 | f.flush() 102 | try: 103 | # TODO: add windows support 104 | cmd = ( 105 | f"g++ {f.name} -o {f_path.stem}" 106 | f" -I{cuda_home / 'include'} -L{cuda_home / 'lib64'} -lcudart" 107 | ) 108 | print(cmd) 109 | subprocess.check_output(cmd, shell=True, cwd=f_path.parent) 110 | cmd = f"./{f_path.stem}" 111 | arches = subprocess.check_output( 112 | cmd, shell=True, 113 | cwd=f_path.parent).decode().rstrip('\r\n').split("\n") 114 | if len(arches) < 1: 115 | return None 116 | arch = arches[0] 117 | except: 118 | return None 119 | else: 120 | cmd = f"{str(device_query_path)} | grep 'CUDA Capability'" 121 | arch = subprocess.check_output( 122 | cmd, shell=True).decode().rstrip('\r\n').split(" ")[-1] 123 | # assert len(arch) == 2 124 | arch_list = [int(s) for s in arch.split(".")] 125 | arch_int = arch_list[0] * 10 + arch_list[1] 126 | find_work_arch = False 127 | while arch_int > 10: 128 | try: 129 | res = subprocess.check_output("nvcc -arch=sm_{}".format(arch_int), shell=True, stderr=subprocess.STDOUT) 130 | except subprocess.CalledProcessError as e: 131 | if "No input files specified" in e.output.decode(): 132 | find_work_arch = True 133 | break 134 | elif "is not defined for option 'gpu-architecture'" in e.output.decode(): 135 | arch_int -= 1 136 | else: 137 | raise RuntimeError("unknown error") 138 | if find_work_arch: 139 | arch = f"sm_{arch_int}" 140 | else: 141 | arch = None 142 | 143 | except Exception: 144 | arch = None 145 | return arch 146 | 147 | 148 | def get_gpu_memory_usage(): 149 | if sys.platform == 'win32': 150 | # TODO: add windows support 151 | return None 152 | cuda_home = find_cuda() 153 | if cuda_home is None: 154 | return None 155 | cuda_home = Path(cuda_home) 156 | source = """ 157 | #include 158 | #include 159 | int main(){ 160 | int nDevices; 161 | 
cudaGetDeviceCount(&nDevices); 162 | size_t free_m, total_m; 163 | // output json format. 164 | std::cout << "["; 165 | for (int i = 0; i < nDevices; i++) { 166 | cudaSetDevice(i); 167 | cudaMemGetInfo(&free_m, &total_m); 168 | std::cout << "[" << free_m << "," << total_m << "]"; 169 | if (i != nDevices - 1) 170 | std::cout << "," << std::endl; 171 | } 172 | std::cout << "]" << std::endl; 173 | return 0; 174 | } 175 | """ 176 | with tempfile.NamedTemporaryFile('w', suffix='.cc') as f: 177 | f_path = Path(f.name) 178 | f.write(source) 179 | f.flush() 180 | try: 181 | # TODO: add windows support 182 | cmd = ( 183 | f"g++ {f.name} -o {f_path.stem} -std=c++11" 184 | f" -I{cuda_home / 'include'} -L{cuda_home / 'lib64'} -lcudart") 185 | print(cmd) 186 | subprocess.check_output(cmd, shell=True, cwd=f_path.parent) 187 | cmd = f"./{f_path.stem}" 188 | usages = subprocess.check_output( 189 | cmd, shell=True, cwd=f_path.parent).decode() 190 | usages = json.loads(usages) 191 | return usages 192 | except: 193 | return None 194 | return None 195 | 196 | 197 | if __name__ == "__main__": 198 | print(find_cuda_device_arch()) 199 | # fire.Fire() 200 | -------------------------------------------------------------------------------- /rslo/utils/loader.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from pathlib import Path 3 | import sys 4 | import os 5 | import logging 6 | logger = logging.getLogger('second.utils.loader') 7 | 8 | CUSTOM_LOADED_MODULES = {} 9 | 10 | 11 | def _get_possible_module_path(paths): 12 | ret = [] 13 | for p in paths: 14 | p = Path(p) 15 | for path in p.glob("*"): 16 | if path.suffix in ["py", ".so"] or (path.is_dir()): 17 | if path.stem.isidentifier(): 18 | ret.append(path) 19 | return ret 20 | 21 | 22 | def _get_regular_import_name(path, module_paths): 23 | path = Path(path) 24 | for mp in module_paths: 25 | mp = Path(mp) 26 | if mp == path: 27 | return path.stem 28 | try: 29 | relative_path = 
path.relative_to(Path(mp)) 30 | parts = list((relative_path.parent / relative_path.stem).parts) 31 | module_name = '.'.join([mp.stem] + parts) 32 | return module_name 33 | except: 34 | pass 35 | return None 36 | 37 | 38 | def import_file(path, name: str = None, add_to_sys=True, 39 | disable_warning=False): 40 | global CUSTOM_LOADED_MODULES 41 | path = Path(path) 42 | module_name = path.stem 43 | try: 44 | user_paths = os.environ['PYTHONPATH'].split(os.pathsep) 45 | except KeyError: 46 | user_paths = [] 47 | possible_paths = _get_possible_module_path(user_paths) 48 | model_import_name = _get_regular_import_name(path, possible_paths) 49 | if model_import_name is not None: 50 | return import_name(model_import_name) 51 | if name is not None: 52 | module_name = name 53 | spec = importlib.util.spec_from_file_location(module_name, path) 54 | module = importlib.util.module_from_spec(spec) 55 | spec.loader.exec_module(module) 56 | if not disable_warning: 57 | logger.warning(( 58 | f"Failed to perform regular import for file {path}. " 59 | "this means this file isn't in any folder in PYTHONPATH " 60 | "or don't have __init__.py in that project. " 61 | "directly file import may fail and some reflecting features are " 62 | "disabled even if import succeed. please add your project to PYTHONPATH " 63 | "or add __init__.py to ensure this file can be regularly imported. " 64 | )) 65 | 66 | if add_to_sys: # this will enable find objects defined in a file. 67 | # avoid replace system modules. 
def _flat_nested_json_dict(json_dict, flatted, sep=".", start=""):
    """Recursively write the leaves of ``json_dict`` into ``flatted``.

    Every key is prefixed with ``start + sep``; nested dicts extend the
    prefix with their own key.
    """
    for key, value in json_dict.items():
        full_key = start + sep + str(key)
        if isinstance(value, dict):
            _flat_nested_json_dict(value, flatted, sep, full_key)
            continue
        flatted[full_key] = value


def flat_nested_json_dict(json_dict, sep=".") -> dict:
    """Flatten a nested json-like dict into a single-level dict.

    Keys of nested dicts are joined with ``sep``.  Leaf values are stored
    as-is (not copied).
    """
    result = {}
    for key, value in json_dict.items():
        name = str(key)
        if isinstance(value, dict):
            _flat_nested_json_dict(value, result, sep, name)
        else:
            result[name] = value
    return result
simple log.txt, all metric dicts are flattened to produce 48 | readable results. 49 | 2. TensorBoard scalars and texts 50 | 3. multi-line json file log.json.lst 51 | 4. tensorboard_scalars.json, all scalars are stored in this file 52 | in tensorboard json format. 53 | """ 54 | 55 | def __init__(self, model_dir, disable=False): 56 | self.model_dir = Path(model_dir) 57 | self.log_file = None 58 | self.log_mjson_file = None 59 | self.summary_writter = None 60 | self.metrics = [] 61 | self._text_current_gstep = -1 62 | self._tb_texts = [] 63 | self.disable = disable 64 | def __del__(self): 65 | self.close() 66 | 67 | def open(self): 68 | if self.disable: 69 | return self 70 | model_dir = self.model_dir 71 | assert model_dir.exists() 72 | summary_dir = model_dir / 'summary' 73 | summary_dir.mkdir(parents=True, exist_ok=True) 74 | 75 | log_mjson_file_path = model_dir / f'log.json.lst' 76 | if log_mjson_file_path.exists(): 77 | with open(log_mjson_file_path, 'r') as f: 78 | for line in f.readlines(): 79 | self.metrics.append(json.loads(line)) 80 | log_file_path = model_dir / f'log.txt' 81 | self.log_mjson_file = open(log_mjson_file_path, 'a') 82 | self.log_file = open(log_file_path, 'a') 83 | self.summary_writter = SummaryWriter(str(summary_dir),flush_secs=60) 84 | return self 85 | 86 | def close(self): 87 | if self.disable: 88 | return 89 | assert self.summary_writter is not None 90 | self.log_mjson_file.close() 91 | self.log_file.close() 92 | tb_json_path = str(self.model_dir / "tensorboard_scalars.json") 93 | self.summary_writter.export_scalars_to_json(tb_json_path) 94 | self.summary_writter.close() 95 | self.log_mjson_file = None 96 | self.log_file = None 97 | self.summary_writter = None 98 | 99 | def log_text(self, text, step, tag="regular log"): 100 | if self.disable: 101 | return 102 | """This function only add text to log.txt and tensorboard texts 103 | """ 104 | print(text,flush=True) 105 | print(text, file=self.log_file,flush=True) 106 | if step > 
self._text_current_gstep and self._text_current_gstep != -1: 107 | total_text = '\n'.join(self._tb_texts) 108 | self.summary_writter.add_text(tag, total_text, global_step=step) 109 | self._tb_texts = [] 110 | self._text_current_gstep = step 111 | else: 112 | self._tb_texts.append(text) 113 | if self._text_current_gstep == -1: 114 | self._text_current_gstep = step 115 | 116 | def log_metrics(self, metrics: dict, step): 117 | if self.disable: 118 | return 119 | flatted_summarys = flat_nested_json_dict(metrics, "/") 120 | for k, v in flatted_summarys.items(): 121 | if isinstance(v, (list, tuple)): 122 | if any([isinstance(e, str) for e in v]): 123 | continue 124 | v_dict = {str(i): e for i, e in enumerate(v)} 125 | for k1, v1 in v_dict.items(): 126 | self.summary_writter.add_scalar(k + "/" + k1, v1, step) 127 | else: 128 | if isinstance(v, str): 129 | continue 130 | self.summary_writter.add_scalar(k, v, step) 131 | log_str = metric_to_str(metrics) 132 | print(log_str, flush=True) 133 | print(log_str, file=self.log_file, flush=True) 134 | print(json.dumps(metrics), file=self.log_mjson_file, flush=True) 135 | 136 | def log_images(self, images: dict, step, prefix=''): 137 | if self.disable: 138 | return 139 | for k, v in images.items(): 140 | self.summary_writter.add_images(prefix+str(k), v, step) 141 | print(f"Summarize images {k}",flush=True) 142 | 143 | def log_histograms(self, vals: dict, step, prefix=''): 144 | if self.disable: 145 | return 146 | for k, v in vals.items(): 147 | self.summary_writter.add_histogram(prefix+str(k), v, step) 148 | print(f"Summarize histograms {k}",flush=True) -------------------------------------------------------------------------------- /rslo/utils/math.py: -------------------------------------------------------------------------------- 1 | import kornia 2 | import torch 3 | -------------------------------------------------------------------------------- /rslo/utils/progress_bar.py: 
def progress_str(val, *args, width=20, with_ptg=True):
    """Render a textual progress bar such as ``[50.0%][====>.....][extra]``.

    Args:
        val: progress value; clamped to [0, 1].
        *args: extra items, each appended as ``[item]`` after the bar.
        width: number of characters inside the bar, must be > 1.
        with_ptg: prepend the percentage when True.

    Returns:
        The formatted progress string.
    """
    val = max(0., min(val, 1.))
    assert width > 1
    pos = round(width * val) - 1
    # Start from an empty string so the bar also renders without the
    # percentage prefix.
    log = ''
    if with_ptg is True:
        log = '[{}%]'.format(max_point_str(val * 100.0, 4))
    bar = ''.join(
        '=' if i < pos else ('>' if i == pos else '.') for i in range(width))
    log += '[' + bar + ']'
    for arg in args:
        log += '[{}]'.format(arg)
    return log
step_time_average=50, name=None): 68 | total_step = len(task_list) 69 | step_times = [] 70 | start_time = 0.0 71 | name = '' if name is None else f"[{name}]" 72 | for i, task in enumerate(task_list): 73 | t = time.time() 74 | yield i, task 75 | step_times.append(time.time() - t) 76 | start_time += step_times[-1] 77 | start_time_str = second_to_time_str(start_time) 78 | average_step_time = np.mean(step_times[-step_time_average:]) + 1e-6 79 | speed_str = "{:.2f}it/s".format(1 / average_step_time) 80 | remain_time = (total_step - i) * average_step_time 81 | remain_time_str = second_to_time_str(remain_time) 82 | time_str = start_time_str + '>' + remain_time_str 83 | prog_str = progress_str( 84 | (i + 1) / total_step, 85 | speed_str, 86 | time_str, 87 | width=width, 88 | with_ptg=with_ptg) 89 | print(name + prog_str + ' ', end='\r', flush=True) 90 | print("") 91 | 92 | 93 | def max_point_str(val, max_point): 94 | positive = bool(val >= 0.0) 95 | val = np.abs(val) 96 | if val == 0: 97 | point = 1 98 | else: 99 | point = max(int(np.log10(val)), 0) + 1 100 | fmt = "{:." 
def convert_size(size_bytes):
    """Convert a byte count into a ``(value, unit)`` pair using 1024 steps.

    Args:
        size_bytes: non-negative size in bytes (int or float).

    Returns:
        Tuple ``(value, unit)`` with ``value`` rounded to two decimals and
        ``unit`` one of "B", "KB", ..., "YB".  Zero yields ``(0.0, "B")``
        so callers can always unpack the result.
    """
    # from https://stackoverflow.com/questions/5194057/better-way-to-convert-file-sizes-in-python
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    if size_bytes == 0:
        return 0.0, size_name[0]
    i = int(math.floor(math.log(size_bytes, 1024)))
    # Clamp the unit index: sizes below one byte would otherwise pick the
    # last unit through a negative index, and huge sizes would overflow it.
    i = min(max(i, 0), len(size_name) - 1)
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return s, size_name[i]
class Singleton(type):
    """Metaclass that hands out a single shared instance per class.

    The first call to a class using this metaclass constructs the instance
    with the supplied arguments. Every later call returns that same object
    and ignores its arguments.
    """

    # One cache for all classes using this metaclass, keyed by the class.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
@contextmanager
def simple_timer(name=''):
    """Print how long the enclosed ``with`` body took to run.

    Args:
        name: Label placed in front of the printed timing line.

    Yields:
        None. The context manager exposes no value to the ``with`` body.

    The elapsed time (in seconds) is printed when the body finishes, and also
    when it raises; the exception then propagates unchanged.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments.
    t = time.perf_counter()
    try:
        yield
    finally:
        # Report in ``finally`` so a failing body still gets its timing line.
        print(f"{name} exec time: {time.perf_counter() - t}")
def filter_param_dict(state_dict: dict, include: str = None, exclude: str = None):
    """Return the entries of ``state_dict`` whose keys pass the name filters.

    Args:
        state_dict: Mapping from parameter name to value.
        include: Optional regular expression. When given, only keys that it
            matches are kept.
        exclude: Optional regular expression. When given, keys that it
            matches are dropped.

    Returns:
        A new dict with the surviving ``(key, value)`` pairs in their original
        order. ``state_dict`` itself is not modified.

    Both patterns are applied with ``re.match``, i.e. they are anchored at the
    start of the key but do not have to cover the whole key.
    """
    # Imported locally: this module's header does not import ``re``, so the
    # name would otherwise be undefined as soon as a pattern is supplied.
    import re

    assert isinstance(state_dict, dict)
    include_re = re.compile(include) if include is not None else None
    exclude_re = re.compile(exclude) if exclude is not None else None

    def _keep(key):
        # A key survives when it satisfies ``include`` (if any) and is not
        # hit by ``exclude`` (if any).
        if include_re is not None and include_re.match(key) is None:
            return False
        if exclude_re is not None and exclude_re.match(key) is not None:
            return False
        return True

    return {k: p for k, p in state_dict.items() if _keep(k)}
def list_recursive_op(input_list, op):
    """Apply ``op`` to every leaf of a nested list structure.

    Args:
        input_list: A list, or a tuple, whose elements are leaves, nested
            lists, or dicts.
        op: Callable applied to each leaf; its return value replaces the leaf.

    Returns:
        For a list: the same list object, updated in place. For a tuple: a
        new tuple holding the processed elements, since tuples cannot be
        updated in place.

    Nested lists are handled recursively and nested dicts are delegated to
    ``dict_recursive_op``. Any other element, including a tuple nested inside
    the sequence, is treated as a leaf and passed to ``op`` as a whole.
    """
    # Tuples must be accepted here because dict_recursive_op forwards tuple
    # values to this function.
    assert isinstance(input_list, (list, tuple))

    if isinstance(input_list, tuple):
        # Work on a mutable copy, then restore the original container type.
        return tuple(list_recursive_op(list(input_list), op))

    for i, v in enumerate(input_list):
        if isinstance(v, list):
            input_list[i] = list_recursive_op(v, op)
        elif isinstance(v, dict):
            input_list[i] = dict_recursive_op(v, op)
        else:
            input_list[i] = op(v)

    return input_list
30 | # 保存在内存中,而不是在本地磁盘,注意这个默认认为你要保存的就是plt中的内容 31 | fig.savefig(buffer_, format='png') 32 | buffer_.seek(0) 33 | # 用PIL或CV2从内存中读取 34 | dataPIL = PIL.Image.open(buffer_) 35 | # 转换为nparrary,PIL转换就非常快了,data即为所需 36 | data = np.asarray(dataPIL) 37 | data = data.astype(float)/255. 38 | # cv2.imwrite('test.png', data) 39 | # 释放缓存 40 | buffer_.close() 41 | plt.close(fig) 42 | 43 | return data 44 | 45 | 46 | # def draw_odometry(odom_vectors, gt_vectors=None, view='bv', saving_dir=None): 47 | 48 | def draw_trajectory(poses_pred, poses_gt=None, view='bv', saving_dir=None, figure=None, ax=None, color='b', error_step=1, odom_errors=None): 49 | """[summary] 50 | 51 | Arguments: 52 | poses_pred {[np.array]} -- [(N,7)] 53 | 54 | Keyword Arguments: 55 | poses_gt {[np.array]} -- [(N,7)] (default: {None}) 56 | view {str} -- [description] (default: {'bv'}) 57 | saving_dir {[type]} -- [description] (default: {None}) 58 | figure {[type]} -- [description] (default: {None}) 59 | ax {[type]} -- [description] (default: {None}) 60 | color {str} -- [description] (default: {'b'}) 61 | """ 62 | 63 | assert(view in ['bv', 'front', 'side']) 64 | translation, rotation = poses_pred[:, :3], poses_pred[:, 3:] 65 | if poses_gt is not None: 66 | assert len(poses_pred) == len(poses_gt) 67 | translation_gt, rotation_gt = poses_gt[:, :3], poses_gt[:, 3:] 68 | 69 | if view == 'bv': 70 | dim0, dim1 = 0, 1 71 | elif view == 'front': 72 | dim0, dim1 = 0, 1 73 | elif view == 'side': 74 | dim0, dim1 = 0, 1 75 | 76 | if figure is None or ax is None: 77 | figure = plt.figure() 78 | ax = figure.add_subplot(111) 79 | 80 | for i in range(1, len(translation)): 81 | if i == 1: 82 | ax.plot([translation[i-1][dim0]], [ 83 | translation[i-1][dim1]], '*', markersize=10, color=color) 84 | 85 | ax.plot([translation[i-1][dim0], translation[i][dim0]], [ 86 | translation[i-1][dim1], translation[i][dim1]], '-', markersize=0.5, color=color) 87 | 88 | if poses_gt is not None: 89 | ax.plot([translation_gt[i-1][dim0], 
translation_gt[i][dim0]], [ 90 | translation_gt[i-1][dim1], translation_gt[i][dim1]], '-', markersize=0.5, color='r') 91 | if i % 50 == 0: 92 | # plot connection lines 93 | ax.plot([translation[i][dim0], translation_gt[i][dim0]], [ 94 | translation[i][dim1], translation_gt[i][dim1]], '-', markersize=0.03, color='gray') 95 | 96 | # and i%error_step==0 and i//error_step 2 | 3 | #include 4 | #include 5 | 6 | __global__ 7 | void ChamferDistanceKernel( 8 | int b, 9 | int n, 10 | const float* xyz, 11 | int m, 12 | const float* xyz2, 13 | float* result, 14 | int* result_i) 15 | { 16 | const int batch=512; 17 | __shared__ float buf[batch*3]; 18 | for (int i=blockIdx.x;ibest){ 130 | result[(i*n+j)]=best; 131 | result_i[(i*n+j)]=best_i; 132 | } 133 | } 134 | __syncthreads(); 135 | } 136 | } 137 | } 138 | 139 | void ChamferDistanceKernelLauncher( 140 | const int b, const int n, 141 | const float* xyz, 142 | const int m, 143 | const float* xyz2, 144 | float* result, 145 | int* result_i, 146 | float* result2, 147 | int* result2_i 148 | ) 149 | { 150 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); 151 | ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); 152 | 153 | cudaError_t err = cudaGetLastError(); 154 | if (err != cudaSuccess) 155 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); 156 | } 157 | void OneDirectionChamferDistanceKernelLauncher( 158 | const int b, const int n, 159 | const float* xyz, 160 | const int m, 161 | const float* xyz2, 162 | float* result, 163 | int* result_i 164 | // float* result2, 165 | // int* result2_i 166 | ) 167 | { 168 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); 169 | // ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); 170 | 171 | cudaError_t err = cudaGetLastError(); 172 | if (err != cudaSuccess) 173 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); 174 | } 175 | 176 | 177 | __global__ 178 | void 
ChamferDistanceGradKernel( 179 | int b, int n, 180 | const float* xyz1, 181 | int m, 182 | const float* xyz2, 183 | const float* grad_dist1, 184 | const int* idx1, 185 | float* grad_xyz1, 186 | float* grad_xyz2) 187 | { 188 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); 223 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); 224 | 225 | cudaError_t err = cudaGetLastError(); 226 | if (err != cudaSuccess) 227 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); 228 | } 229 | void OneDirectionChamferDistanceGradKernelLauncher( 230 | const int b, const int n, 231 | const float* xyz1, 232 | const int m, 233 | const float* xyz2, 234 | const float* grad_dist1, 235 | const int* idx1, 236 | // const float* grad_dist2, 237 | // const int* idx2, 238 | float* grad_xyz1, 239 | float* grad_xyz2 240 | ) 241 | { 242 | cudaMemset(grad_xyz1, 0, b*n*3*4); 243 | cudaMemset(grad_xyz2, 0, b*m*3*4); 244 | ChamferDistanceGradKernel<<>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); 245 | // ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); 246 | 247 | cudaError_t err = cudaGetLastError(); 248 | if (err != cudaSuccess) 249 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); 250 | } --------------------------------------------------------------------------------