├── .gitignore
├── LICENSE
├── README.md
├── calib
│   ├── 1843
│   │   └── 1843_v1.cfg
│   ├── README.md
│   ├── coords_1843.drawio.png
│   ├── coords_1843.png
│   ├── d435i
│   │   ├── RealSense_D435i_v1.yaml
│   │   ├── d435i_params.txt
│   │   ├── depth_intrinsics_640x480.txt
│   │   ├── depth_intrinsics_848x100.txt
│   │   └── depth_intrinsics_848x480.txt
│   └── t265
│       └── RealSense_T265_v1.yaml
├── cartographer
│   ├── configuration_files
│   │   ├── default.lua
│   │   ├── known_poses.lua
│   │   ├── radar.lua
│   │   └── scan_only.lua
│   └── launch
│       ├── backpack_2d.launch
│       ├── demo_backpack_2d.launch
│       ├── offline_backpack_2d.launch
│       └── offline_node.launch
├── configs
│   ├── default.yaml
│   ├── main_0.yaml
│   ├── main_1.yaml
│   ├── main_2.yaml
│   ├── main_3.yaml
│   └── main_4.yaml
├── env.yaml
├── main.py
├── main_eval.py
├── odom_eval.sh
├── radarize
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   └── default.py
│   ├── flow
│   │   ├── dataloader.py
│   │   └── model.py
│   ├── rotnet
│   │   ├── dataloader.py
│   │   └── model.py
│   ├── unet
│   │   ├── dataloader.py
│   │   ├── dice_score.py
│   │   └── model.py
│   └── utils
│       ├── dsp.py
│       ├── grid_map.py
│       ├── image_tools.py
│       └── radar_config.py
├── requirements.txt
├── run.sh
├── run_eval.sh
├── setup.py
├── slam_eval.sh
└── tools
    ├── create_dataset.py
    ├── eval_traj.py
    ├── export_cartographer.py
    ├── extract_gt.py
    ├── odombag_to_txt.py
    ├── run_carto.py
    ├── test_flow.py
    ├── test_odom.py
    ├── test_rot.py
    ├── test_unet.py
    ├── train_flow.py
    ├── train_rot.py
    └── train_unet.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # numpy
107 | *.npz
108 |
109 | # pytorch
110 | *.pth
111 |
112 | # onnx
113 | *.onnx
114 |
115 | # openvino
116 | *.xml
117 | *.bin
118 | *.mapping
119 |
120 | # IntelliJ
121 | .idea/
122 |
123 | *.swo
124 | *.swp
125 |
126 | data/*
127 | *_/
128 |
129 | *.jpg
130 | *.txt
131 | *.bag
132 | *.mp4
133 | *.png
134 | *.pbstream
135 | *.pgm
136 | /output_*
137 | log
138 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Radarize: Enhancing Radar SLAM with Generalizable Doppler-Based Odometry [MobiSys'24]
2 |
3 |
4 | https://github.com/ConnectedSystemsLab/radarize_ae/assets/14133352/b713b0d6-a548-4776-8d54-673ec6a2543d
5 |
6 |
7 | ### Prerequisites
8 |
9 | - Ubuntu 20.04
10 | - ROS Noetic
11 | - Conda with Python 3.8+
12 | - CUDA >= 11.3 capable GPU
13 | - ImageMagick
14 |
15 | ### Setup
16 |
17 | 1. Install the conda environment with ```conda env create -f env.yaml```.
18 | 2. Activate the environment with ```conda activate radarize_ae```, then install the package with ```pip install -e .```.
19 | 3. Install [cartographer_ros](https://google-cartographer-ros.readthedocs.io/en/latest/compilation.html). Copy ```configuration_files/``` and ```launch/``` from ```cartographer/``` into ```/install_isolated/share/cartographer_ros/```.
20 | 4. Source the conda environment (if not already active) and the cartographer_ros environment.
21 | ```shell script
22 | conda activate radarize_ae
23 | source /install_isolated/setup.bash
24 | ```
25 |
26 | ### Dataset Preparation
27 |
28 | 1. Download the dataset ```dataset.zip``` from the [link](https://zenodo.org/records/11093859) and unzip it into this directory.
29 | 2. Download the saved models and outputs ```eval.zip``` from the same [link](https://zenodo.org/records/11093859) and unzip it into this directory.
30 |
31 | ### Evaluation
32 |
33 | To reproduce the results in the paper, run the top-level script
34 | ```shell script
35 | ./run_eval.sh
36 | ```
37 | Then,
38 | 1. Run ```./slam_eval.sh``` to get the SLAM metrics.
39 | 2. Run ```./odom_eval.sh``` to get the odometry metrics.
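For reference, a minimal sketch of what the per-split evaluation amounts to (illustrative only, not the actual contents of ```run_eval.sh```; it assumes the five ```configs/main_*.yaml``` splits shipped in this repo, each writing to its own ```OUTPUT_DIR```, e.g. ```output_main_0/```):

```python
#!/usr/bin/env python3
# Illustrative sketch, not part of the repo: invoke the evaluation pipeline
# once per cross-validation split, as run_eval.sh is expected to do.
import subprocess

for i in range(5):  # configs/main_0.yaml ... configs/main_4.yaml
    subprocess.run(
        ["python", "main_eval.py", f"--cfg=configs/main_{i}.yaml", "--n_proc=4"],
        check=True,
    )
```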
40 |
41 | ### Training from Scratch
42 |
43 | ```shell script
44 | ./run.sh
45 | ```
46 |
47 |

--------------------------------------------------------------------------------
/calib/1843/1843_v1.cfg:
--------------------------------------------------------------------------------
1 | % ***************************************************************
2 | % Created for SDK ver:03.05
3 | % Created using Visualizer ver:3.5.0.0
4 | % Frequency:77
5 | % Platform:xWR18xx
6 | % Scene Classifier:best_range_res
7 | % Azimuth Resolution(deg):15
8 | % Range Resolution(m):0.043
9 | % Maximum unambiguous Range(m):3.86
10 | % Maximum Radial Velocity(m/s):0.32
11 | % Radial velocity resolution(m/s):0.03
12 | % Frame Duration(msec):200
13 | % RF calibration data:None
14 | % Range Detection Threshold (dB):15
15 | % Doppler Detection Threshold (dB):15
16 | % Range Peak Grouping:enabled
17 | % Doppler Peak Grouping:enabled
18 | % Static clutter removal:disabled
19 | % Angle of Arrival FoV: Full FoV
20 | % Range FoV: Full FoV
21 | % Doppler FoV: Full FoV
22 | % ***************************************************************
23 | sensorStop
24 | flushCfg
25 | dfeDataOutputMode 1
26 | channelCfg 15 7 0
27 | adcCfg 2 1
28 | adcbufCfg -1 0 1 1 1
29 | profileCfg 0 77 122 7 50 0 0 80 1 96 2285 0 0 30
30 | chirpCfg 0 0 0 0 0 0 0 1
31 | chirpCfg 1 1 0 0 0 0 0 4
32 | chirpCfg 2 2 0 0 0 0 0 2
33 | frameCfg 0 2 32 0 33.333 1 0
34 | lowPower 0 0
35 | guiMonitor -1 1 0 0 0 0 0
36 | cfarCfg -1 0 2 8 4 3 0 15 1
37 | cfarCfg -1 1 0 8 4 4 1 15 1
38 | multiObjBeamForming -1 1 0.5
39 | clutterRemoval -1 0
40 | calibDcRangeSig -1 1 -5 8 256
41 | extendedMaxVelocity -1 0
42 | lvdsStreamCfg -1 0 1 0
43 | compRangeBiasAndRxChanPhase 0.0658967 0.52295 0.41629 0.55545 0.40747 0.58170 0.39825 0.73706 0.29401 0.68317 0.34198 0.70334 0.32593 0.73782 0.29321 0.83664 0.18008 0.85019 0.05954 0.85431 0.00671 0.85083 -0.05475 0.98816 -0.15347
44 | measureRangeBiasAndRxChanPhase 0 1.5 0.2
45 | CQRxSatMonitor 0 3 4 67 0
46 | CQSigImgMonitor 0 55 4
47 | analogMonitor 0 0
48 | aoaFovCfg -1 -90 90 -90 90
49 | cfarFovCfg -1 0 0 6.10
50 | cfarFovCfg -1 1 -1.24 1.24
51 | calibData 0 0 0
52 | sensorStart
53 |

--------------------------------------------------------------------------------
/calib/README.md:
--------------------------------------------------------------------------------
1 | # Sensor Arrangement
2 |
3 |
4 | # Rosbag Format
5 |
6 | ```shell script
7 | topics: /camera/depth/image_rect_raw/compressedDepth 4307 msgs : sensor_msgs/CompressedImage
8 |         /radar0/radar_data 1800 msgs : xwr_raw_ros/RadarFrameFull
9 |         /ti_mmwave/radar_scan_pcl_0 1800 msgs : sensor_msgs/PointCloud2
10 |         /tracking/fisheye1/image_raw/compressed 1801 msgs : sensor_msgs/CompressedImage
11 |         /tracking/fisheye2/image_raw/compressed 1800 msgs : sensor_msgs/CompressedImage
12 |         /tracking/imu 11972 msgs : sensor_msgs/Imu
13 |         /orb_slam3 1787 msgs : geometry_msgs/PoseStamped
14 |         /tracking/odom/sample 11971 msgs : nav_msgs/Odometry
15 | ```
16 |
17 | - `/tracking/odom/sample`: T265 VIO baseline/pseudo-groundtruth.
18 | - `/camera/depth/image_rect_raw/compressedDepth`: Depth camera pseudo-groundtruth.
19 | - `/tracking/fisheye1/image_raw/compressed`: Black-and-white fisheye image from the left camera.
20 | - `/tracking/fisheye2/image_raw/compressed`: Black-and-white fisheye image from the right camera.
21 | - `/tracking/imu`: Linearly interpolated IMU samples.
22 | - `/ti_mmwave/radar_scan_pcl_0`: Radar point cloud.
23 | - `/radar0/radar_data`: Raw DSP samples from the radar.
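These topics can be read directly with the ROS Noetic `rosbag` API. A minimal sketch (not part of this repo; `example.bag` is a placeholder filename — the real bag-to-`.npz` conversion lives in `tools/create_dataset.py`):

```python
#!/usr/bin/env python3
# Illustrative sketch, not part of the repo: iterate over two of the topics
# listed above with the standard rosbag API.
import rosbag

with rosbag.Bag("example.bag") as bag:
    for topic, msg, t in bag.read_messages(
        topics=["/tracking/odom/sample", "/ti_mmwave/radar_scan_pcl_0"]
    ):
        if topic == "/tracking/odom/sample":   # nav_msgs/Odometry
            p = msg.pose.pose.position
            print(f"{t.to_sec():.3f}: x={p.x:.3f} y={p.y:.3f}")
        else:                                  # sensor_msgs/PointCloud2
            print(f"{t.to_sec():.3f}: {msg.width * msg.height} radar points")
```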
24 | 25 | -------------------------------------------------------------------------------- /calib/coords_1843.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConnectedSystemsLab/radarize_ae/cbcb6439b89c068d86a424215ee61edc925191ae/calib/coords_1843.drawio.png -------------------------------------------------------------------------------- /calib/coords_1843.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConnectedSystemsLab/radarize_ae/cbcb6439b89c068d86a424215ee61edc925191ae/calib/coords_1843.png -------------------------------------------------------------------------------- /calib/d435i/RealSense_D435i_v1.yaml: -------------------------------------------------------------------------------- 1 | %YAML:1.0 2 | 3 | #-------------------------------------------------------------------------------------------- 4 | # Camera Parameters. Adjust them! 5 | #-------------------------------------------------------------------------------------------- 6 | File.version: "1.0" 7 | 8 | Camera.type: "PinHole" 9 | 10 | # Rectified Camera calibration and distortion parameters (OpenCV) 11 | Camera1.fx: 388.95886821 12 | Camera1.fy: 388.3476606 13 | Camera1.cx: 321.70008494 14 | Camera1.cy: 238.2637429 15 | 16 | Camera1.k1: 0.00354659 17 | Camera1.k2: -0.00234589 18 | Camera1.p1: -0.00066934 19 | Camera1.p2: 0.00075519 20 | 21 | Camera2.fx: 388.92902657 22 | Camera2.fy: 388.3476606 23 | Camera2.cx: 321.70008494 24 | Camera2.cy: 238.2637429 25 | 26 | # distortion parameters 27 | Camera2.k1: 0.00354659 28 | Camera2.k2: -0.00234589 29 | Camera2.p1: -0.00066934 30 | Camera2.p2: 0.00075519 31 | 32 | # Camera resolution 33 | Camera.width: 640 34 | Camera.height: 480 35 | 36 | # Stereo.b: 0.0499585 37 | 38 | # Camera frames per second 39 | Camera.fps: 30 40 | 41 | # Color order of the images (0: BGR, 1: RGB. It is ignored if images are grayscale) 42 | Camera.RGB: 1 43 | 44 | # Close/Far threshold. Baseline times. 
45 | Stereo.ThDepth: 40.0 46 | Stereo.T_c1_c2: !!opencv-matrix 47 | rows: 4 48 | cols: 4 49 | dt: f 50 | data: [0.99999977, 0.00004083, 0.00068189, -0.05012109, 51 | -0.00004103, 0.99999996, 0.00029641, -0.00007619, 52 | -0.00068188, -0.00029643, 0.99999972, -0.0001643, 53 | 0.0, 0.0, 0.0, 1.0] 54 | 55 | 56 | # Transformation from body-frame (imu) to left camera 57 | IMU.T_b_c1: !!opencv-matrix 58 | rows: 4 59 | cols: 4 60 | dt: f 61 | data: [ 0.99999521, -0.00307477, -0.00036087, 0.00705315, 62 | 0.00307628, 0.99998626, 0.00424402, -0.01035509, 63 | 0.00034781, -0.00424511, 0.99999093, -0.02071244, 64 | 0.0 , 0.0 , 0.0 , 1.0 ] 65 | 66 | # Do not insert KFs when recently lost 67 | IMU.InsertKFsWhenLost: 1 68 | 69 | # IMU noise (Use those from VINS-mono) 70 | # IMU.NoiseGyro: 0.01 # 2.44e-4 #1e-3 # rad/s^0.5 71 | # IMU.NoiseAcc: 0.1 # 1.47e-3 #1e-2 # m/s^1.5 72 | # IMU.GyroWalk: 0.0001 # rad/s^1.5 73 | # IMU.AccWalk: 0.001 # m/s^2.5 74 | # IMU.Frequency: 200.0 75 | 76 | IMU.NoiseAcc: 8.3796512709476622e-03 77 | IMU.NoiseGyro: 2.0619835867634816e-04 78 | IMU.AccWalk: 1.6803766689321894e-04 79 | IMU.GyroWalk: 3.1590687080878666e-07 80 | IMU.Frequency: 200.0 81 | 82 | #-------------------------------------------------------------------------------------------- 83 | # ORB Parameters 84 | #-------------------------------------------------------------------------------------------- 85 | # ORB Extractor: Number of features per image 86 | ORBextractor.nFeatures: 1250 87 | 88 | # ORB Extractor: Scale factor between levels in the scale pyramid 89 | ORBextractor.scaleFactor: 1.2 90 | 91 | # ORB Extractor: Number of levels in the scale pyramid 92 | ORBextractor.nLevels: 8 93 | 94 | # ORB Extractor: Fast threshold 95 | # Image is divided in a grid. At each cell FAST are extracted imposing a minimum response. 96 | # Firstly we impose iniThFAST. 
If no corners are detected we impose a lower value minThFAST 97 | # You can lower these values if your images have low contrast 98 | ORBextractor.iniThFAST: 20 99 | ORBextractor.minThFAST: 7 100 | 101 | #-------------------------------------------------------------------------------------------- 102 | # Viewer Parameters 103 | #-------------------------------------------------------------------------------------------- 104 | Viewer.KeyFrameSize: 0.05 105 | Viewer.KeyFrameLineWidth: 1.0 106 | Viewer.GraphLineWidth: 0.9 107 | Viewer.PointSize: 2.0 108 | Viewer.CameraSize: 0.08 109 | Viewer.CameraLineWidth: 3.0 110 | Viewer.ViewpointX: 0.0 111 | Viewer.ViewpointY: -0.7 112 | Viewer.ViewpointZ: -3.5 113 | Viewer.ViewpointF: 500.0 114 | -------------------------------------------------------------------------------- /calib/d435i/depth_intrinsics_640x480.txt: -------------------------------------------------------------------------------- 1 | 640 2 | 480 3 | 383.99365234375 4 | 383.99365234375 5 | 328.120178222656 6 | 240.849258422852 7 | -------------------------------------------------------------------------------- /calib/d435i/depth_intrinsics_848x100.txt: -------------------------------------------------------------------------------- 1 | 848 2 | 100 3 | 423.993011474609 4 | 423.993011474609 5 | 432.966033935547 6 | 50.9377288818359 7 | -------------------------------------------------------------------------------- /calib/d435i/depth_intrinsics_848x480.txt: -------------------------------------------------------------------------------- 1 | 848 2 | 480 3 | 423.993011474609 4 | 423.993011474609 5 | 432.966033935547 6 | 240.937728881836 7 | -------------------------------------------------------------------------------- /calib/t265/RealSense_T265_v1.yaml: -------------------------------------------------------------------------------- 1 | %YAML:1.0 2 | 3 | #-------------------------------------------------------------------------------------------- 4 | # Camera Parameters. Adjust them! 
5 | #-------------------------------------------------------------------------------------------- 6 | File.version: "1.0" 7 | 8 | Camera.type: "KannalaBrandt8" 9 | 10 | # Left Camera calibration and distortion parameters (OpenCV) 11 | Camera1.fx: 285.82002842632926 12 | Camera1.fy: 286.39201013451094 13 | Camera1.cx: 424.9534349110273 14 | Camera1.cy: 396.62750983243376 15 | 16 | # Kannala-Brandt distortion parameters 17 | Camera1.k1: -0.0060484133288653714 18 | Camera1.k2: 0.042222609684972684 19 | Camera1.k3: -0.04001457783565841 20 | Camera1.k4: 0.007268133385671889 21 | 22 | # Right Camera calibration and distortion parameters (OpenCV) 23 | Camera2.fx: 286.5364970064857 24 | Camera2.fy: 287.19439276175706 25 | Camera2.cx: 424.31946287546856 26 | Camera2.cy: 396.3415411992376 27 | 28 | # Kannala-Brandt distortion parameters 29 | Camera2.k1: -0.005937201349356227 30 | Camera2.k2: 0.040942594028177265 31 | Camera2.k3: -0.04046713118392198 32 | Camera2.k4: 0.007446253025715504 33 | 34 | # Transformation matrix from right camera to left camera 35 | Stereo.T_c1_c2: !!opencv-matrix 36 | rows: 4 37 | cols: 4 38 | dt: f 39 | data: [ 9.99979824e-01, 6.19970261e-03, 1.38411194e-03, 6.40931652e-02, 40 | -6.20322636e-03, 9.99977493e-01, 2.55624740e-03, -5.59134223e-04, 41 | -1.36823281e-03, -2.56478178e-03, 9.99995775e-01, -6.17149296e-04, 42 | 0.0, 0.0, 0.0, 1.0] 43 | 44 | # Overlapping area between images (to be updated) 45 | Camera1.overlappingBegin: 0 46 | Camera1.overlappingEnd: 848 47 | 48 | Camera2.overlappingBegin: 0 49 | Camera2.overlappingEnd: 848 50 | 51 | # Camera resolution 52 | Camera.width: 848 53 | Camera.height: 800 54 | 55 | # Camera frames per second 56 | Camera.fps: 30 57 | 58 | # Color order of the images (0: BGR, 1: RGB. It is ignored if images are grayscale) 59 | Camera.RGB: 1 60 | 61 | # Close/Far threshold. Baseline times. 
62 | Stereo.ThDepth: 40.0 63 | 64 | #-------------------------------------------------------------------------------------------- 65 | # IMU Parameters 66 | #-------------------------------------------------------------------------------------------- 67 | 68 | # Transformation from body-frame (imu) to left camera 69 | IMU.T_b_c1: !!opencv-matrix 70 | rows: 4 71 | cols: 4 72 | dt: f 73 | data: [-9.99960067e-01, 6.31627483e-03, 6.32204312e-03, 1.29057097e-02, 74 | -6.32075483e-03, -9.99979787e-01, -6.88902218e-04, 4.78135333e-03, 75 | 6.31756404e-03, -7.28834793e-04, 9.99979778e-01, -7.78434454e-05, 76 | 0.0, 0.0, 0.0, 1.0] 77 | 78 | # Do not insert KFs when recently lost 79 | IMU.InsertKFsWhenLost: 0 80 | 81 | # IMU noise 82 | # IMU.NoiseGyro: 0.01 # 2.44e-4 #1e-3 # rad/s^0.5 83 | # IMU.NoiseAcc: 0.1 # 1.47e-3 #1e-2 # m/s^1.5 84 | # IMU.GyroWalk: 0.0001 # rad/s^1.5 85 | # IMU.AccWalk: 0.001 # m/s^2.5 86 | # IMU.Frequency: 200.0 87 | IMU.NoiseGyro: 0.0025377988014162068 # 0.000005148030141 # rad/s^0.5 88 | IMU.NoiseAcc: 0.012566092190412638 # 0.000066952452471 # m/s^1.5 89 | IMU.GyroWalk: 5.7816895265149615e-05 # rad/s^1.5 90 | IMU.AccWalk: 0.0013378465829194741 # m/s^2.5 91 | IMU.Frequency: 200.0 92 | 93 | 94 | #-------------------------------------------------------------------------------------------- 95 | # ORB Parameters 96 | #-------------------------------------------------------------------------------------------- 97 | 98 | # ORB Extractor: Number of features per image 99 | ORBextractor.nFeatures: 1000 # Tested with 1250 100 | 101 | # ORB Extractor: Scale factor between levels in the scale pyramid 102 | ORBextractor.scaleFactor: 1.2 103 | 104 | # ORB Extractor: Number of levels in the scale pyramid 105 | ORBextractor.nLevels: 8 106 | 107 | # ORB Extractor: Fast threshold 108 | # Image is divided in a grid. At each cell FAST are extracted imposing a minimum response. 109 | # Firstly we impose iniThFAST. If no corners are detected we impose a lower value minThFAST 110 | # You can lower these values if your images have low contrast 111 | ORBextractor.iniThFAST: 20 112 | ORBextractor.minThFAST: 10 113 | 114 | #-------------------------------------------------------------------------------------------- 115 | # Viewer Parameters 116 | #-------------------------------------------------------------------------------------------- 117 | Viewer.KeyFrameSize: 0.05 118 | Viewer.KeyFrameLineWidth: 1.0 119 | Viewer.GraphLineWidth: 0.9 120 | Viewer.PointSize: 2.0 121 | Viewer.CameraSize: 0.08 122 | Viewer.CameraLineWidth: 3.0 123 | Viewer.ViewpointX: 0.0 124 | Viewer.ViewpointY: -0.7 125 | Viewer.ViewpointZ: -3.5 126 | Viewer.ViewpointF: 500.0 127 | Viewer.imageViewScale: 2.0 128 | -------------------------------------------------------------------------------- /cartographer/configuration_files/default.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 The Cartographer Authors 2 | -- 3 | -- Licensed under the Apache License, Version 2.0 (the "License"); 4 | -- you may not use this file except in compliance with the License. 5 | -- You may obtain a copy of the License at 6 | -- 7 | -- http://www.apache.org/licenses/LICENSE-2.0 8 | -- 9 | -- Unless required by applicable law or agreed to in writing, software 10 | -- distributed under the License is distributed on an "AS IS" BASIS, 11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | -- See the License for the specific language governing permissions and 13 | -- limitations under the License. 14 | 15 | include "map_builder.lua" 16 | include "trajectory_builder.lua" 17 | 18 | options = { 19 | map_builder = MAP_BUILDER, 20 | trajectory_builder = TRAJECTORY_BUILDER, 21 | map_frame = "map", 22 | tracking_frame = "base_link", 23 | published_frame = "base_link", 24 | odom_frame = "odom", 25 | provide_odom_frame = false, 26 | publish_frame_projected_to_2d = false, 27 | use_pose_extrapolator = true, 28 | use_odometry = true, 29 | use_nav_sat = false, 30 | use_landmarks = false, 31 | num_laser_scans = 1, 32 | num_multi_echo_laser_scans = 0, 33 | num_subdivisions_per_laser_scan = 1, 34 | num_point_clouds = 0, 35 | lookup_transform_timeout_sec = 0.2, 36 | submap_publish_period_sec = 0.3, 37 | pose_publish_period_sec = 33e-3, 38 | trajectory_publish_period_sec = 33e-3, 39 | rangefinder_sampling_ratio = 1., 40 | odometry_sampling_ratio = 1., 41 | fixed_frame_pose_sampling_ratio = 1., 42 | imu_sampling_ratio = 1., 43 | landmarks_sampling_ratio = 1., 44 | } 45 | 46 | MAP_BUILDER.use_trajectory_builder_2d = true 47 | TRAJECTORY_BUILDER_2D.use_imu_data = false 48 | TRAJECTORY_BUILDER_2D.min_range = 0.3 49 | TRAJECTORY_BUILDER_2D.max_range = 4.284 50 | TRAJECTORY_BUILDER_2D.missing_data_ray_length = 1.5 51 | TRAJECTORY_BUILDER_2D.num_accumulated_range_data = 1 52 | TRAJECTORY_BUILDER_2D.submaps.num_range_data = 60 53 | 54 | TRAJECTORY_BUILDER_2D.use_online_correlative_scan_matching = false 55 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.linear_search_window = 1e-2 56 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.angular_search_window = math.rad(20.) 57 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.translation_delta_cost_weight = 0.1 58 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.rotation_delta_cost_weight = 0.1 59 | 60 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.occupied_space_weight = 1e3 61 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.translation_weight = 1e-3 62 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.rotation_weight = 1e-3 63 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.use_nonmonotonic_steps = false 64 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.max_num_iterations = 20 65 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.num_threads = 1 66 | 67 | TRAJECTORY_BUILDER_2D.motion_filter.max_time_seconds = 0.0 68 | TRAJECTORY_BUILDER_2D.motion_filter.max_distance_meters = 0.0 69 | TRAJECTORY_BUILDER_2D.motion_filter.max_angle_radians = math.rad(0.0) 70 | 71 | POSE_GRAPH.optimize_every_n_nodes = 30 72 | POSE_GRAPH.optimization_problem.local_slam_pose_translation_weight = 1e5 73 | POSE_GRAPH.optimization_problem.local_slam_pose_rotation_weight = 1e5 74 | POSE_GRAPH.optimization_problem.odometry_translation_weight = 3e5 75 | POSE_GRAPH.optimization_problem.odometry_rotation_weight = 3e5 76 | POSE_GRAPH.constraint_builder.max_constraint_distance = 0 77 | 78 | return options 79 | 80 | -------------------------------------------------------------------------------- /cartographer/configuration_files/known_poses.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 The Cartographer Authors 2 | -- 3 | -- Licensed under the Apache License, Version 2.0 (the "License"); 4 | -- you may not use this file except in compliance with the License. 
5 | -- You may obtain a copy of the License at 6 | -- 7 | -- http://www.apache.org/licenses/LICENSE-2.0 8 | -- 9 | -- Unless required by applicable law or agreed to in writing, software 10 | -- distributed under the License is distributed on an "AS IS" BASIS, 11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | -- See the License for the specific language governing permissions and 13 | -- limitations under the License. 14 | 15 | include "map_builder.lua" 16 | include "trajectory_builder.lua" 17 | 18 | options = { 19 | map_builder = MAP_BUILDER, 20 | trajectory_builder = TRAJECTORY_BUILDER, 21 | map_frame = "map", 22 | tracking_frame = "base_link", 23 | published_frame = "base_link", 24 | odom_frame = "odom", 25 | provide_odom_frame = false, 26 | publish_frame_projected_to_2d = false, 27 | use_pose_extrapolator = false, 28 | use_odometry = true, 29 | use_nav_sat = false, 30 | use_landmarks = false, 31 | num_laser_scans = 1, 32 | num_multi_echo_laser_scans = 0, 33 | num_subdivisions_per_laser_scan = 1, 34 | num_point_clouds = 0, 35 | lookup_transform_timeout_sec = 0.2, 36 | submap_publish_period_sec = 0.3, 37 | pose_publish_period_sec = 33e-3, 38 | trajectory_publish_period_sec = 33e-3, 39 | rangefinder_sampling_ratio = 1., 40 | odometry_sampling_ratio = 1., 41 | fixed_frame_pose_sampling_ratio = 1., 42 | imu_sampling_ratio = 1., 43 | landmarks_sampling_ratio = 1., 44 | } 45 | 46 | MAP_BUILDER.use_trajectory_builder_2d = true 47 | TRAJECTORY_BUILDER_2D.use_imu_data = false 48 | TRAJECTORY_BUILDER_2D.min_range = 0.3 49 | TRAJECTORY_BUILDER_2D.max_range = 4.284 50 | TRAJECTORY_BUILDER_2D.missing_data_ray_length = 1.5 51 | TRAJECTORY_BUILDER_2D.num_accumulated_range_data = 1 52 | TRAJECTORY_BUILDER_2D.submaps.num_range_data = 90 53 | 54 | TRAJECTORY_BUILDER_2D.use_online_correlative_scan_matching = false 55 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.linear_search_window = 0.1 56 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.angular_search_window = math.rad(20.) 
57 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.translation_delta_cost_weight = 1e2 58 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.rotation_delta_cost_weight = 1e2 59 | 60 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.occupied_space_weight = 1e2 61 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.translation_weight = .3e5 62 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.rotation_weight = 1e5 63 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.use_nonmonotonic_steps = false 64 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.max_num_iterations = 1 65 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.num_threads = 1 66 | 67 | TRAJECTORY_BUILDER_2D.motion_filter.max_time_seconds = 0.0 68 | TRAJECTORY_BUILDER_2D.motion_filter.max_distance_meters = 0.0 69 | TRAJECTORY_BUILDER_2D.motion_filter.max_angle_radians = math.rad(0.0) 70 | 71 | POSE_GRAPH.optimize_every_n_nodes = 0 72 | POSE_GRAPH.optimization_problem.local_slam_pose_translation_weight = 1e-5 73 | POSE_GRAPH.optimization_problem.local_slam_pose_rotation_weight = 1e-5 74 | POSE_GRAPH.optimization_problem.odometry_translation_weight = 3e5 75 | POSE_GRAPH.optimization_problem.odometry_rotation_weight = 3e5 76 | 77 | return options 78 | -------------------------------------------------------------------------------- /cartographer/configuration_files/radar.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 The Cartographer Authors 2 | -- 3 | -- Licensed under the Apache License, Version 2.0 (the "License"); 4 | -- you may not use this file except in compliance with the License. 5 | -- You may obtain a copy of the License at 6 | -- 7 | -- http://www.apache.org/licenses/LICENSE-2.0 8 | -- 9 | -- Unless required by applicable law or agreed to in writing, software 10 | -- distributed under the License is distributed on an "AS IS" BASIS, 11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | -- See the License for the specific language governing permissions and 13 | -- limitations under the License. 
14 | 15 | include "map_builder.lua" 16 | include "trajectory_builder.lua" 17 | 18 | options = { 19 | map_builder = MAP_BUILDER, 20 | trajectory_builder = TRAJECTORY_BUILDER, 21 | map_frame = "map", 22 | tracking_frame = "base_link", 23 | published_frame = "base_link", 24 | odom_frame = "odom", 25 | provide_odom_frame = false, 26 | publish_frame_projected_to_2d = false, 27 | use_pose_extrapolator = true, 28 | use_odometry = true, 29 | use_nav_sat = false, 30 | use_landmarks = false, 31 | num_laser_scans = 1, 32 | num_multi_echo_laser_scans = 0, 33 | num_subdivisions_per_laser_scan = 1, 34 | num_point_clouds = 0, 35 | lookup_transform_timeout_sec = 0.2, 36 | submap_publish_period_sec = 33e-3, 37 | pose_publish_period_sec = 33e-3, 38 | trajectory_publish_period_sec = 33e-3, 39 | rangefinder_sampling_ratio = .5, 40 | odometry_sampling_ratio = .5, 41 | fixed_frame_pose_sampling_ratio = 1., 42 | imu_sampling_ratio = 1., 43 | landmarks_sampling_ratio = 1., 44 | } 45 | 46 | MAP_BUILDER.use_trajectory_builder_2d = true 47 | TRAJECTORY_BUILDER_2D.use_imu_data = false 48 | TRAJECTORY_BUILDER_2D.min_range = 0.3 49 | TRAJECTORY_BUILDER_2D.max_range = 4.284 50 | -- TRAJECTORY_BUILDER_2D.missing_data_ray_length = 99999 51 | TRAJECTORY_BUILDER_2D.num_accumulated_range_data = 1 52 | 53 | TRAJECTORY_BUILDER_2D.use_online_correlative_scan_matching = false 54 | -- TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.linear_search_window = 0.0 55 | -- TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.angular_search_window = math.rad(35.) 56 | -- TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.translation_delta_cost_weight = 1. 57 | -- TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.rotation_delta_cost_weight = 1. 58 | 59 | -- TRAJECTORY_BUILDER_2D.submaps.range_data_inserter.probability_grid_range_data_inserter.hit_probability = 0.8 60 | -- TRAJECTORY_BUILDER_2D.submaps.range_data_inserter.probability_grid_range_data_inserter.miss_probability = 0.4 61 | TRAJECTORY_BUILDER_2D.submaps.grid_options_2d.resolution = 0.08 62 | TRAJECTORY_BUILDER_2D.submaps.num_range_data = 30 63 | 64 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.occupied_space_weight = 70. 65 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.translation_weight = 600. 66 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.rotation_weight = 40. 67 | 68 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.use_nonmonotonic_steps = true 69 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.max_num_iterations = 100 70 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.num_threads = 1 71 | 72 | -- TRAJECTORY_BUILDER_2D.motion_filter.max_time_seconds = 1.0 73 | -- TRAJECTORY_BUILDER_2D.motion_filter.max_distance_meters = 1.0 74 | -- TRAJECTORY_BUILDER_2D.motion_filter.max_angle_radians = math.rad(0.3) 75 | 76 | POSE_GRAPH.optimize_every_n_nodes = 90 77 | 78 | -- POSE_GRAPH.constraint_builder.max_constraint_distance = 5. 79 | -- POSE_GRAPH.constraint_builder.sampling_ratio = 1. 80 | -- POSE_GRAPH.constraint_builder.fast_correlative_scan_matcher.linear_search_window = 5. 81 | -- POSE_GRAPH.constraint_builder.fast_correlative_scan_matcher.angular_search_window = math.rad(45.) 82 | -- POSE_GRAPH.constraint_builder.fast_correlative_scan_matcher.branch_and_bound_depth = 7 83 | -- POSE_GRAPH.constraint_builder.min_score = 0.7 84 | -- POSE_GRAPH.constraint_builder.ceres_scan_matcher.occupied_space_weight = 80. 85 | -- POSE_GRAPH.constraint_builder.ceres_scan_matcher.translation_weight = 40. 
86 | -- POSE_GRAPH.constraint_builder.ceres_scan_matcher.rotation_weight = 1. 87 | -- POSE_GRAPH.constraint_builder.loop_closure_translation_weight = 1e5 88 | -- POSE_GRAPH.constraint_builder.loop_closure_rotation_weight = .1e5 89 | 90 | -- POSE_GRAPH.matcher_translation_weight = 5.0e3 91 | -- POSE_GRAPH.matcher_rotation_weight = 1.6e3 92 | 93 | -- POSE_GRAPH.optimization_problem.local_slam_pose_translation_weight = 1.0e5 94 | -- POSE_GRAPH.optimization_problem.local_slam_pose_rotation_weight = 2.0e5 95 | -- POSE_GRAPH.optimization_problem.odometry_translation_weight = 1e6 96 | -- POSE_GRAPH.optimization_problem.odometry_rotation_weight = 4e5 97 | -- POSE_GRAPH.optimization_problem.huber_scale = 1. 98 | -- POSE_GRAPH.optimization_problem.ceres_solver_options.max_num_iterations = 100 99 | -- POSE_GRAPH.optimization_problem.ceres_solver_options.num_threads = 1 100 | 101 | -- POSE_GRAPH.max_num_final_iterations = 100 102 | 103 | 104 | return options 105 | 106 | -------------------------------------------------------------------------------- /cartographer/configuration_files/scan_only.lua: -------------------------------------------------------------------------------- 1 | -- Copyright 2016 The Cartographer Authors 2 | -- 3 | -- Licensed under the Apache License, Version 2.0 (the "License"); 4 | -- you may not use this file except in compliance with the License. 5 | -- You may obtain a copy of the License at 6 | -- 7 | -- http://www.apache.org/licenses/LICENSE-2.0 8 | -- 9 | -- Unless required by applicable law or agreed to in writing, software 10 | -- distributed under the License is distributed on an "AS IS" BASIS, 11 | -- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | -- See the License for the specific language governing permissions and 13 | -- limitations under the License. 14 | 15 | include "map_builder.lua" 16 | include "trajectory_builder.lua" 17 | 18 | options = { 19 | map_builder = MAP_BUILDER, 20 | trajectory_builder = TRAJECTORY_BUILDER, 21 | map_frame = "map", 22 | tracking_frame = "base_link", 23 | published_frame = "base_link", 24 | odom_frame = "odom", 25 | provide_odom_frame = false, 26 | publish_frame_projected_to_2d = false, 27 | use_pose_extrapolator = true, 28 | use_odometry = false, 29 | use_nav_sat = false, 30 | use_landmarks = false, 31 | num_laser_scans = 1, 32 | num_multi_echo_laser_scans = 0, 33 | num_subdivisions_per_laser_scan = 1, 34 | num_point_clouds = 0, 35 | lookup_transform_timeout_sec = 0.2, 36 | submap_publish_period_sec = 0.3, 37 | pose_publish_period_sec = 33e-3, 38 | trajectory_publish_period_sec = 33e-3, 39 | rangefinder_sampling_ratio = 1., 40 | odometry_sampling_ratio = 1., 41 | fixed_frame_pose_sampling_ratio = 1., 42 | imu_sampling_ratio = 1., 43 | landmarks_sampling_ratio = 1., 44 | } 45 | 46 | MAP_BUILDER.use_trajectory_builder_2d = true 47 | TRAJECTORY_BUILDER_2D.use_imu_data = false 48 | TRAJECTORY_BUILDER_2D.min_range = 0.3 49 | TRAJECTORY_BUILDER_2D.max_range = 4.284 50 | TRAJECTORY_BUILDER_2D.missing_data_ray_length = 99999 51 | TRAJECTORY_BUILDER_2D.num_accumulated_range_data = 1 52 | TRAJECTORY_BUILDER_2D.submaps.num_range_data = 60 53 | 54 | TRAJECTORY_BUILDER_2D.use_online_correlative_scan_matching = false 55 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.linear_search_window = 0.1 56 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.angular_search_window = math.rad(60.) 
57 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.translation_delta_cost_weight = 0.10
58 | TRAJECTORY_BUILDER_2D.real_time_correlative_scan_matcher.rotation_delta_cost_weight = 1.0
59 |
60 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.occupied_space_weight = 1e3
61 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.translation_weight = 1e3
62 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.rotation_weight = 1e3
63 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.use_nonmonotonic_steps = false
64 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.max_num_iterations = 20
65 | TRAJECTORY_BUILDER_2D.ceres_scan_matcher.ceres_solver_options.num_threads = 1
66 |
67 | TRAJECTORY_BUILDER_2D.motion_filter.max_time_seconds = 0.0
68 | TRAJECTORY_BUILDER_2D.motion_filter.max_distance_meters = 0.0
69 | TRAJECTORY_BUILDER_2D.motion_filter.max_angle_radians = math.rad(0.0)
70 |
71 | POSE_GRAPH.optimize_every_n_nodes = 30
72 | POSE_GRAPH.optimization_problem.local_slam_pose_translation_weight = 5e5
73 | POSE_GRAPH.optimization_problem.local_slam_pose_rotation_weight = 5e5
74 | POSE_GRAPH.optimization_problem.odometry_translation_weight = 0
75 | POSE_GRAPH.optimization_problem.odometry_rotation_weight = 0
76 |
77 | POSE_GRAPH.constraint_builder.sampling_ratio = .1
78 | POSE_GRAPH.constraint_builder.min_score = 0.8
79 |
80 | POSE_GRAPH.optimization_problem.ceres_solver_options.max_num_iterations = 10
81 | POSE_GRAPH.optimization_problem.ceres_solver_options.num_threads = 4
82 | POSE_GRAPH.max_num_final_iterations = 10
83 |
84 |
85 | return options
86 |

--------------------------------------------------------------------------------
/cartographer/launch/backpack_2d.launch:
--------------------------------------------------------------------------------
(XML launch file; its markup was not preserved in this export.)

--------------------------------------------------------------------------------
/cartographer/launch/demo_backpack_2d.launch:
--------------------------------------------------------------------------------
(XML launch file; its markup was not preserved in this export.)

--------------------------------------------------------------------------------
/cartographer/launch/offline_backpack_2d.launch:
--------------------------------------------------------------------------------
(XML launch file; its markup was not preserved in this export.)

--------------------------------------------------------------------------------
/cartographer/launch/offline_node.launch:
--------------------------------------------------------------------------------
(XML launch file; its markup was not preserved in this export.)

--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ConnectedSystemsLab/radarize_ae/cbcb6439b89c068d86a424215ee61edc925191ae/configs/default.yaml

--------------------------------------------------------------------------------
/configs/main_0.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 |   TRAIN_SPLIT:
3 |     - cart_csl_upper_0
4 |     - bot_ece_basement_5
5 |     - bot_beckman_3rdfloor_1
6 |     - bot_beckman_3rdfloor_3
7 |     - cart_csl_upper_4
8 |     - cart_csl_upper_2
9 |     - bot_beckman_5thfloor_4
10 |     - cart_sc_loop_0
11 |     - 
bot_beckman_3rdfloor_2 12 | - walk_sc_groundfloor_2 13 | - walk_csl_basement_2 14 | - cart_sc_loop_2 15 | - bot_csl_4thfloor_loop 16 | - bot_ece_basement_0 17 | - bot_beckman_3rdfloor_0 18 | - walk_ece_upper_0 19 | - bot_sc_basement_0 20 | - bot_csl_basement2_2 21 | - walk_sc_basement_2 22 | - walk_sc_office_3 23 | - bot_ece_basement_8 24 | - walk_ece_upper_1 25 | - cart_ece_upper_4 26 | - cart_ece_basement_1 27 | - cart_sc_loop_1 28 | - walk_sc_office_0 29 | - cart_ece_basement_3 30 | - cart_sc_5 31 | - walk_ece_basement_1 32 | - bot_ece_basement_4 33 | - walk_sc_office_4 34 | - walk_sc_basement_1 35 | - walk_sc_basement_3 36 | - bot_sc_basement_2 37 | - bot_csl_basement_1 38 | - bot_csl_basement2_0 39 | - bot_beckman_5thfloor_3 40 | - bot_ece_5thfloor_0 41 | - walk_sc_office_1 42 | - bot_csl_basement_3 43 | - bot_csl_4thfloor_1 44 | - bot_ece_basement_6 45 | - cart_sc_1 46 | - bot_ece_5thfloor_3 47 | - bot_beckman_5thfloor_2 48 | - walk_sc_groundfloor_0 49 | - bot_sc_basement_3 50 | - cart_ece_upper_3 51 | - walk_ece_upper_2 52 | - cart_ece_upper_loop_0 53 | - walk_sc_groundfloor_3 54 | - bot_csl_basement_hall 55 | - bot_sc_basement_4 56 | - cart_ece_basement_0 57 | - walk_csl_basement_3 58 | - walk_csl_groundfloor_3 59 | - walk_sc_basement_0 60 | - bot_csl_basement_0 61 | - walk_sc_groundfloor_1 62 | - cart_csl_basement_1 63 | - cart_csl_upper_3 64 | - walk_csl_basement_0 65 | - walk_ece_basement_0 66 | - cart_ece_gf_3 67 | - cart_csl_basement_0 68 | - bot_beckman_basement_0 69 | - walk_ece_basement_3 70 | - walk_ece_basement_5 71 | - bot_sc_4thfloor_2 72 | - bot_sc_basement_1 73 | - bot_sc_4thfloor_3 74 | - cart_csl_upper_loop_1 75 | - walk_ece_upper_3 76 | - bot_beckman_basement_3 77 | - bot_beckman_5thfloor_0 78 | - bot_csl_basement2_3 79 | - cart_ece_upper_2 80 | - bot_beckman_basement_1 81 | - bot_csl_basement_4 82 | - walk_sc_office_2 83 | - bot_ece_basement_9 84 | - bot_csl_basement_2 85 | - bot_ece_5thfloor_1 86 | 87 | VAL_SPLIT: 88 | - cart_csl_upper_loop_0 89 | - cart_sc_2 90 | - bot_ece_basement_2 91 | - walk_csl_groundfloor_0 92 | - bot_csl_4thfloor_2 93 | - walk_csl_basement_1 94 | - walk_csl_groundfloor_2 95 | - cart_ece_upper_1 96 | - bot_ece_5thfloor_4 97 | - cart_sc_3 98 | - bot_ece_basement_7 99 | - cart_csl_basement_3 100 | - bot_sc_4thfloor_1 101 | 102 | TEST_SPLIT: 103 | - bot_beckman_5thfloor_1 104 | - cart_ece_gf_0 105 | - cart_sc_loop_3 106 | - cart_ece_basement_2 107 | - cart_csl_upper_1 108 | - walk_ece_basement_4 109 | - cart_ece_upper_0 110 | - walk_csl_basement_4 111 | - walk_ece_basement_2 112 | - cart_sc_4 113 | - walk_csl_groundfloor_1 114 | - bot_csl_4thfloor_0 115 | - bot_beckman_basement_4 116 | - bot_csl_basement2_1 117 | - bot_beckman_3rdfloor_4 118 | - cart_ece_gf_1 119 | - bot_csl_4thfloor_3 120 | - bot_beckman_basement_2 121 | - cart_ece_upper_loop_1 122 | - walk_ece_upper_4 123 | - cart_ece_gf_2 124 | - walk_sc_basement_4 125 | - cart_sc_0 126 | - walk_ece_basement_6 127 | - bot_sc_4thfloor_0 128 | - cart_csl_basement_2 129 | - bot_ece_basement_3 130 | - bot_ece_basement_1 131 | - bot_ece_5thfloor_2 132 | 133 | 134 | OUTPUT_DIR: 'output_main_0/' 135 | 136 | 137 | FLOW: 138 | TRAIN: 139 | EPOCHS: 30 140 | 141 | ROTNET: 142 | TRAIN: 143 | EPOCHS: 30 144 | 145 | UNET: 146 | TRAIN: 147 | EPOCHS: 5 148 | 149 | -------------------------------------------------------------------------------- /configs/main_1.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN_SPLIT: 3 | - cart_sc_loop_2 4 | - 
bot_csl_basement2_1 5 | - cart_csl_upper_1 6 | - cart_csl_upper_3 7 | - cart_ece_upper_1 8 | - bot_ece_basement_0 9 | - cart_csl_upper_2 10 | - cart_ece_basement_3 11 | - bot_csl_basement_4 12 | - bot_beckman_5thfloor_4 13 | - bot_csl_basement_hall 14 | - cart_sc_4 15 | - bot_csl_basement_0 16 | - bot_csl_basement_3 17 | - walk_sc_basement_4 18 | - cart_ece_upper_0 19 | - cart_sc_0 20 | - bot_ece_5thfloor_2 21 | - bot_sc_4thfloor_3 22 | - cart_ece_gf_2 23 | - bot_csl_basement2_0 24 | - bot_ece_basement_5 25 | - bot_csl_basement2_2 26 | - bot_csl_basement_1 27 | - bot_beckman_5thfloor_0 28 | - bot_sc_4thfloor_2 29 | - cart_ece_upper_loop_0 30 | - walk_csl_groundfloor_3 31 | - bot_beckman_5thfloor_2 32 | - walk_ece_basement_6 33 | - walk_sc_office_2 34 | - bot_ece_5thfloor_0 35 | - bot_beckman_3rdfloor_0 36 | - bot_sc_4thfloor_0 37 | - walk_csl_groundfloor_0 38 | - walk_ece_upper_0 39 | - bot_sc_basement_3 40 | - cart_sc_1 41 | - walk_ece_upper_4 42 | - walk_ece_basement_3 43 | - walk_sc_groundfloor_2 44 | - bot_csl_4thfloor_0 45 | - cart_sc_loop_3 46 | - walk_sc_office_1 47 | - bot_ece_5thfloor_1 48 | - cart_ece_basement_1 49 | - bot_ece_5thfloor_4 50 | - bot_beckman_3rdfloor_2 51 | - cart_ece_upper_3 52 | - bot_beckman_3rdfloor_1 53 | - bot_beckman_5thfloor_3 54 | - bot_sc_basement_4 55 | - cart_csl_basement_2 56 | - walk_sc_groundfloor_3 57 | - bot_ece_basement_4 58 | - bot_beckman_basement_3 59 | - walk_sc_office_3 60 | - cart_ece_basement_0 61 | - bot_ece_5thfloor_3 62 | - cart_csl_basement_0 63 | - bot_ece_basement_8 64 | - walk_sc_groundfloor_0 65 | - walk_sc_office_0 66 | - walk_ece_basement_1 67 | - bot_csl_4thfloor_loop 68 | - bot_csl_4thfloor_2 69 | - cart_ece_gf_1 70 | - bot_sc_basement_1 71 | - cart_csl_basement_3 72 | - walk_csl_basement_1 73 | - bot_ece_basement_7 74 | - bot_beckman_3rdfloor_4 75 | - bot_ece_basement_2 76 | - cart_ece_upper_4 77 | - walk_csl_groundfloor_1 78 | - bot_csl_4thfloor_3 79 | - walk_ece_basement_0 80 | - cart_sc_loop_0 81 | - walk_sc_basement_3 82 | - walk_sc_groundfloor_4 83 | - cart_csl_basement_1 84 | - cart_sc_5 85 | - bot_beckman_basement_1 86 | 87 | VAL_SPLIT: 88 | - walk_csl_basement_0 89 | - walk_csl_basement_4 90 | - bot_beckman_basement_0 91 | - bot_ece_basement_1 92 | - bot_ece_basement_9 93 | - cart_sc_3 94 | - walk_sc_basement_2 95 | - bot_csl_basement_2 96 | - cart_ece_basement_2 97 | - walk_ece_upper_3 98 | - cart_sc_loop_1 99 | - cart_ece_upper_loop_1 100 | - cart_csl_upper_4 101 | 102 | TEST_SPLIT: 103 | - walk_sc_office_4 104 | - cart_sc_2 105 | - walk_ece_basement_2 106 | - bot_beckman_basement_4 107 | - walk_ece_basement_5 108 | - walk_ece_basement_4 109 | - cart_csl_upper_loop_1 110 | - bot_ece_basement_6 111 | - cart_ece_gf_0 112 | - walk_csl_basement_3 113 | - bot_ece_basement_3 114 | - cart_csl_upper_0 115 | - walk_sc_basement_1 116 | - bot_sc_basement_2 117 | - walk_sc_basement_0 118 | - walk_csl_groundfloor_2 119 | - bot_beckman_5thfloor_1 120 | - cart_ece_gf_3 121 | - bot_sc_basement_0 122 | - bot_csl_basement2_3 123 | - walk_csl_basement_2 124 | - walk_ece_upper_1 125 | - cart_csl_upper_loop_0 126 | - walk_ece_upper_2 127 | - bot_beckman_basement_2 128 | - bot_sc_4thfloor_1 129 | - bot_beckman_3rdfloor_3 130 | - bot_csl_4thfloor_1 131 | - walk_sc_groundfloor_1 132 | 133 | 134 | OUTPUT_DIR: 'output_main_1/' 135 | 136 | 137 | FLOW: 138 | TRAIN: 139 | EPOCHS: 30 140 | 141 | ROTNET: 142 | TRAIN: 143 | EPOCHS: 30 144 | 145 | UNET: 146 | TRAIN: 147 | EPOCHS: 5 148 | 149 | 
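The five `configs/main_*.yaml` files differ only in how the recordings are assigned to `TRAIN_SPLIT`/`VAL_SPLIT`/`TEST_SPLIT` and in `OUTPUT_DIR`; the remaining options come from the defaults in `radarize/config`. A minimal sketch of loading one split (not part of the repo; it mirrors the `update_config` call in `main.py`, with a stand-in for the argparse namespace):

```python
#!/usr/bin/env python3
# Illustrative sketch, not part of the repo: load one split config the way
# main.py does, using a minimal stand-in for the argparse namespace.
from radarize.config import cfg, update_config

class Args:
    cfg = "configs/main_1.yaml"  # which cross-validation split to load
    opts = []                    # no command-line overrides

update_config(cfg, Args())
print(cfg["OUTPUT_DIR"])                   # 'output_main_1/'
print(len(cfg["DATASET"]["TRAIN_SPLIT"]))  # number of training recordings
```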
-------------------------------------------------------------------------------- /configs/main_2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN_SPLIT: 3 | - walk_ece_upper_3 4 | - cart_sc_3 5 | - cart_ece_basement_1 6 | - bot_sc_4thfloor_0 7 | - cart_ece_upper_0 8 | - bot_beckman_5thfloor_4 9 | - bot_ece_basement_2 10 | - cart_ece_gf_0 11 | - walk_csl_basement_2 12 | - bot_ece_5thfloor_2 13 | - walk_sc_groundfloor_3 14 | - walk_ece_basement_6 15 | - cart_sc_loop_3 16 | - cart_ece_upper_loop_0 17 | - bot_csl_basement2_2 18 | - cart_ece_gf_3 19 | - bot_sc_basement_0 20 | - walk_ece_basement_0 21 | - cart_ece_upper_loop_1 22 | - cart_sc_0 23 | - bot_sc_4thfloor_3 24 | - bot_sc_4thfloor_2 25 | - cart_ece_gf_2 26 | - walk_ece_upper_1 27 | - walk_ece_basement_2 28 | - bot_csl_basement_4 29 | - cart_sc_1 30 | - walk_sc_office_2 31 | - cart_ece_basement_0 32 | - bot_ece_5thfloor_3 33 | - walk_ece_upper_2 34 | - walk_csl_groundfloor_2 35 | - walk_ece_basement_5 36 | - bot_ece_basement_9 37 | - bot_beckman_5thfloor_1 38 | - bot_csl_4thfloor_loop 39 | - bot_csl_basement2_1 40 | - bot_beckman_5thfloor_3 41 | - walk_ece_basement_3 42 | - bot_beckman_3rdfloor_4 43 | - walk_ece_basement_4 44 | - walk_sc_office_3 45 | - cart_csl_upper_loop_0 46 | - bot_csl_basement2_0 47 | - bot_ece_basement_0 48 | - cart_csl_basement_3 49 | - bot_ece_5thfloor_0 50 | - walk_sc_basement_3 51 | - walk_csl_basement_4 52 | - cart_csl_basement_2 53 | - bot_ece_basement_1 54 | - bot_beckman_3rdfloor_2 55 | - bot_csl_4thfloor_0 56 | - bot_sc_basement_4 57 | - bot_beckman_5thfloor_0 58 | - bot_beckman_basement_0 59 | - bot_ece_basement_5 60 | - bot_ece_5thfloor_4 61 | - bot_sc_basement_2 62 | - cart_csl_upper_2 63 | - bot_beckman_3rdfloor_0 64 | - cart_sc_loop_0 65 | - bot_ece_basement_4 66 | - cart_sc_5 67 | - bot_ece_basement_3 68 | - bot_beckman_basement_2 69 | - walk_csl_basement_1 70 | - bot_csl_basement_hall 71 | - bot_csl_4thfloor_3 72 | - cart_sc_loop_1 73 | - bot_beckman_basement_1 74 | - bot_beckman_basement_3 75 | - cart_ece_basement_2 76 | - cart_ece_gf_1 77 | - walk_ece_upper_4 78 | - cart_sc_loop_2 79 | - walk_sc_groundfloor_1 80 | - bot_beckman_3rdfloor_3 81 | - bot_ece_basement_7 82 | - bot_sc_basement_3 83 | - cart_ece_upper_1 84 | - walk_ece_upper_0 85 | - bot_beckman_basement_4 86 | 87 | VAL_SPLIT: 88 | - bot_ece_5thfloor_1 89 | - bot_sc_4thfloor_1 90 | - walk_sc_basement_1 91 | - bot_ece_basement_8 92 | - bot_csl_4thfloor_1 93 | - walk_ece_basement_1 94 | - bot_beckman_5thfloor_2 95 | - cart_csl_basement_1 96 | - cart_ece_basement_3 97 | - cart_csl_basement_0 98 | - walk_csl_basement_0 99 | - walk_csl_groundfloor_1 100 | - cart_csl_upper_3 101 | 102 | TEST_SPLIT: 103 | - walk_sc_groundfloor_0 104 | - bot_ece_basement_6 105 | - walk_sc_office_4 106 | - cart_sc_4 107 | - bot_sc_basement_1 108 | - bot_csl_basement_3 109 | - bot_csl_basement_2 110 | - bot_beckman_3rdfloor_1 111 | - walk_sc_basement_4 112 | - walk_sc_basement_2 113 | - bot_csl_4thfloor_2 114 | - walk_csl_groundfloor_3 115 | - cart_csl_upper_0 116 | - cart_csl_upper_loop_1 117 | - cart_csl_upper_4 118 | - walk_csl_basement_3 119 | - bot_csl_basement2_3 120 | - walk_sc_office_1 121 | - cart_sc_2 122 | - walk_sc_office_0 123 | - bot_csl_basement_1 124 | - cart_ece_upper_4 125 | - walk_csl_groundfloor_0 126 | - bot_csl_basement_0 127 | - walk_sc_basement_0 128 | - cart_csl_upper_1 129 | - walk_sc_groundfloor_2 130 | - cart_ece_upper_3 131 | 132 | 133 | OUTPUT_DIR: 'output_main_2/' 134 | 135 
| 136 | FLOW: 137 | TRAIN: 138 | EPOCHS: 30 139 | 140 | ROTNET: 141 | TRAIN: 142 | EPOCHS: 30 143 | 144 | UNET: 145 | TRAIN: 146 | EPOCHS: 5 147 | 148 | -------------------------------------------------------------------------------- /configs/main_3.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN_SPLIT: 3 | - walk_sc_office_3 4 | - bot_sc_basement_1 5 | - bot_ece_basement_8 6 | - bot_sc_basement_0 7 | - bot_ece_basement_6 8 | - cart_sc_loop_1 9 | - cart_ece_upper_1 10 | - bot_csl_basement_0 11 | - bot_sc_4thfloor_0 12 | - bot_sc_4thfloor_1 13 | - cart_csl_basement_0 14 | - bot_csl_4thfloor_1 15 | - bot_csl_basement2_2 16 | - cart_sc_5 17 | - bot_csl_4thfloor_0 18 | - walk_csl_basement_0 19 | - bot_sc_basement_4 20 | - walk_ece_upper_2 21 | - walk_csl_basement_2 22 | - cart_ece_upper_3 23 | - bot_csl_basement_2 24 | - cart_csl_upper_0 25 | - bot_sc_basement_2 26 | - bot_ece_5thfloor_2 27 | - walk_sc_office_4 28 | - cart_sc_loop_3 29 | - bot_csl_basement2_0 30 | - walk_sc_basement_4 31 | - cart_ece_gf_3 32 | - walk_ece_upper_4 33 | - walk_sc_office_2 34 | - walk_csl_groundfloor_2 35 | - cart_sc_4 36 | - bot_beckman_basement_2 37 | - bot_beckman_5thfloor_2 38 | - cart_ece_basement_0 39 | - bot_ece_basement_9 40 | - cart_csl_upper_1 41 | - bot_beckman_5thfloor_3 42 | - cart_csl_upper_loop_1 43 | - bot_ece_basement_1 44 | - walk_sc_groundfloor_2 45 | - bot_ece_5thfloor_0 46 | - walk_csl_groundfloor_1 47 | - cart_csl_upper_4 48 | - bot_sc_4thfloor_3 49 | - bot_sc_basement_3 50 | - walk_ece_basement_3 51 | - cart_ece_basement_3 52 | - walk_csl_groundfloor_3 53 | - cart_csl_basement_2 54 | - bot_csl_basement_3 55 | - bot_ece_basement_0 56 | - bot_beckman_5thfloor_0 57 | - walk_ece_basement_5 58 | - walk_ece_upper_1 59 | - walk_ece_basement_2 60 | - walk_sc_groundfloor_0 61 | - bot_ece_basement_5 62 | - walk_sc_basement_1 63 | - walk_sc_basement_3 64 | - bot_csl_4thfloor_2 65 | - cart_ece_upper_0 66 | - walk_ece_basement_6 67 | - bot_beckman_3rdfloor_4 68 | - bot_csl_basement2_3 69 | - walk_ece_basement_0 70 | - bot_sc_4thfloor_2 71 | - walk_sc_groundfloor_3 72 | - walk_ece_upper_0 73 | - bot_csl_basement_4 74 | - bot_ece_5thfloor_3 75 | - walk_sc_groundfloor_4 76 | - walk_csl_groundfloor_0 77 | - bot_ece_basement_3 78 | - walk_sc_basement_2 79 | - cart_ece_basement_1 80 | - cart_ece_gf_0 81 | - walk_ece_basement_1 82 | - bot_csl_4thfloor_loop 83 | - bot_beckman_3rdfloor_3 84 | - walk_csl_basement_1 85 | - cart_sc_1 86 | 87 | VAL_SPLIT: 88 | - cart_ece_gf_2 89 | - cart_sc_loop_2 90 | - bot_ece_5thfloor_1 91 | - cart_sc_loop_0 92 | - bot_beckman_basement_3 93 | - walk_sc_basement_0 94 | - bot_csl_4thfloor_3 95 | - bot_csl_basement2_1 96 | - walk_csl_basement_4 97 | - cart_csl_upper_2 98 | - bot_beckman_basement_0 99 | - bot_ece_basement_2 100 | - cart_ece_upper_2 101 | 102 | TEST_SPLIT: 103 | - walk_ece_upper_3 104 | - cart_csl_basement_1 105 | - bot_ece_5thfloor_4 106 | - bot_beckman_5thfloor_1 107 | - bot_csl_basement_1 108 | - cart_csl_upper_3 109 | - cart_csl_upper_loop_0 110 | - cart_csl_basement_3 111 | - bot_beckman_3rdfloor_1 112 | - walk_ece_basement_4 113 | - walk_sc_office_1 114 | - cart_ece_upper_loop_1 115 | - walk_csl_basement_3 116 | - bot_beckman_basement_1 117 | - cart_sc_2 118 | - bot_beckman_3rdfloor_2 119 | - bot_csl_basement_hall 120 | - cart_ece_gf_1 121 | - walk_sc_office_0 122 | - bot_beckman_3rdfloor_0 123 | - walk_sc_groundfloor_1 124 | - bot_ece_basement_7 125 | - cart_sc_0 126 | - cart_ece_upper_4 127 | - 
cart_ece_basement_2 128 | - bot_beckman_basement_4 129 | - bot_ece_basement_4 130 | 131 | 132 | OUTPUT_DIR: 'output_main_3/' 133 | 134 | 135 | FLOW: 136 | TRAIN: 137 | EPOCHS: 30 138 | 139 | ROTNET: 140 | TRAIN: 141 | EPOCHS: 30 142 | 143 | UNET: 144 | TRAIN: 145 | EPOCHS: 5 146 | 147 | -------------------------------------------------------------------------------- /configs/main_4.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN_SPLIT: 3 | - walk_csl_groundfloor_0 4 | - bot_csl_4thfloor_1 5 | - bot_beckman_5thfloor_0 6 | - cart_ece_gf_3 7 | - cart_sc_loop_0 8 | - walk_ece_upper_2 9 | - bot_ece_basement_7 10 | - walk_ece_basement_2 11 | - walk_ece_basement_3 12 | - walk_csl_groundfloor_3 13 | - cart_sc_1 14 | - bot_ece_basement_1 15 | - walk_sc_basement_3 16 | - bot_beckman_3rdfloor_0 17 | - bot_ece_basement_6 18 | - walk_sc_groundfloor_0 19 | - cart_sc_loop_2 20 | - bot_beckman_3rdfloor_2 21 | - bot_csl_basement_0 22 | - cart_sc_4 23 | - bot_ece_basement_2 24 | - walk_sc_office_2 25 | - bot_beckman_3rdfloor_1 26 | - bot_csl_basement2_3 27 | - bot_beckman_5thfloor_3 28 | - walk_csl_basement_1 29 | - walk_sc_groundfloor_1 30 | - walk_sc_basement_2 31 | - bot_beckman_3rdfloor_4 32 | - walk_ece_upper_4 33 | - bot_sc_basement_1 34 | - bot_sc_basement_4 35 | - walk_sc_office_4 36 | - bot_ece_basement_5 37 | - cart_csl_upper_4 38 | - bot_beckman_5thfloor_1 39 | - walk_ece_basement_6 40 | - cart_sc_3 41 | - bot_csl_basement2_1 42 | - bot_beckman_5thfloor_2 43 | - walk_sc_basement_1 44 | - bot_beckman_basement_3 45 | - bot_csl_basement2_2 46 | - cart_ece_gf_1 47 | - bot_ece_5thfloor_1 48 | - cart_csl_basement_3 49 | - cart_ece_basement_1 50 | - bot_ece_5thfloor_3 51 | - bot_sc_basement_2 52 | - bot_csl_4thfloor_2 53 | - walk_ece_upper_0 54 | - bot_sc_4thfloor_0 55 | - bot_ece_basement_8 56 | - cart_csl_upper_loop_0 57 | - cart_csl_upper_loop_1 58 | - bot_ece_basement_0 59 | - bot_csl_4thfloor_loop 60 | - bot_ece_basement_3 61 | - walk_csl_basement_2 62 | - bot_sc_basement_3 63 | - walk_ece_basement_5 64 | - bot_ece_5thfloor_2 65 | - cart_ece_basement_0 66 | - cart_ece_upper_0 67 | - walk_sc_office_0 68 | - bot_beckman_basement_2 69 | - cart_ece_upper_2 70 | - cart_ece_gf_0 71 | - cart_sc_loop_3 72 | - walk_ece_upper_3 73 | - walk_sc_office_1 74 | - cart_ece_gf_2 75 | - bot_csl_4thfloor_3 76 | - cart_csl_basement_2 77 | - cart_ece_basement_2 78 | - bot_sc_4thfloor_3 79 | - cart_sc_2 80 | - walk_sc_basement_0 81 | - cart_csl_upper_1 82 | - walk_csl_basement_3 83 | - cart_csl_upper_3 84 | - bot_ece_5thfloor_4 85 | - bot_sc_4thfloor_1 86 | - bot_sc_basement_0 87 | 88 | VAL_SPLIT: 89 | - cart_ece_upper_3 90 | - walk_sc_office_3 91 | - bot_beckman_basement_0 92 | - cart_csl_basement_1 93 | - bot_ece_basement_4 94 | - bot_beckman_5thfloor_4 95 | - bot_ece_basement_9 96 | - bot_csl_basement_1 97 | - bot_csl_4thfloor_0 98 | - walk_csl_basement_0 99 | - bot_sc_4thfloor_2 100 | - walk_ece_basement_1 101 | 102 | TEST_SPLIT: 103 | - walk_ece_basement_0 104 | - cart_sc_5 105 | - cart_csl_upper_0 106 | - cart_sc_0 107 | - bot_csl_basement_4 108 | - bot_csl_basement2_0 109 | - walk_sc_basement_4 110 | - cart_ece_upper_loop_1 111 | - walk_ece_basement_4 112 | - cart_ece_basement_3 113 | - bot_csl_basement_2 114 | - walk_csl_groundfloor_1 115 | - walk_sc_groundfloor_2 116 | - cart_csl_basement_0 117 | - walk_csl_groundfloor_2 118 | - walk_csl_basement_4 119 | - cart_sc_loop_1 120 | - bot_csl_basement_hall 121 | - bot_csl_basement_3 122 | - 
walk_ece_upper_1 123 | - cart_csl_upper_2 124 | - cart_ece_upper_4 125 | - bot_beckman_basement_4 126 | - bot_beckman_basement_1 127 | - cart_ece_upper_1 128 | - bot_beckman_3rdfloor_3 129 | - bot_ece_5thfloor_0 130 | 131 | 132 | OUTPUT_DIR: 'output_main_4/' 133 | 134 | 135 | FLOW: 136 | TRAIN: 137 | EPOCHS: 30 138 | 139 | ROTNET: 140 | TRAIN: 141 | EPOCHS: 30 142 | 143 | UNET: 144 | TRAIN: 145 | EPOCHS: 5 146 | 147 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: radarize_ae 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.9.16 9 | - torchvision==0.11.0 10 | - pytorch==1.10.0 11 | - cudatoolkit=11.3 12 | - pip 13 | - pip: 14 | - setuptools 15 | - -r requirements.txt 16 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | 6 | import argparse 7 | import glob 8 | import multiprocessing 9 | import subprocess 10 | 11 | from radarize.config import cfg, update_config 12 | 13 | def args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--cfg", 17 | help="experiment configure file name", 18 | default="configs/default.yaml", 19 | type=str, 20 | ) 21 | parser.add_argument( 22 | "--n_proc", 23 | type=int, 24 | default=1, 25 | help="Number of processes to use for parallel processing.", 26 | ) 27 | parser.add_argument( 28 | "opts", 29 | help="Modify config options using the command-line", 30 | default=None, 31 | nargs=argparse.REMAINDER, 32 | ) 33 | args = parser.parse_args() 34 | 35 | return args 36 | 37 | 38 | def run_commands(cmds, n_proc): 39 | with multiprocessing.Pool(n_proc) as pool: 40 | pool.map(subprocess.run, cmds) 41 | 42 | 43 | if __name__ == "__main__": 44 | args = args() 45 | update_config(cfg, args) 46 | 47 | # Preprocess datasets. (bag -> npz) 48 | bag_paths = sorted(glob.glob(os.path.join(cfg["DATASET"]["PATH"], "*.bag"))) 49 | bag_paths = [x for x in bag_paths if not os.path.exists(x.replace(".bag", ".npz"))] 50 | run_commands( 51 | [ 52 | [f"tools/create_dataset.py", f"--cfg={args.cfg}", f"--bag_path={x}"] 53 | for x in bag_paths 54 | ], 55 | args.n_proc, 56 | ) 57 | 58 | train_npz_paths = sorted( 59 | [ 60 | os.path.join(cfg["DATASET"]["PATH"], os.path.basename(x) + ".npz") 61 | for x in cfg["DATASET"]["TRAIN_SPLIT"] 62 | ] 63 | ) 64 | test_npz_paths = sorted( 65 | [ 66 | os.path.join(cfg["DATASET"]["PATH"], os.path.basename(x) + ".npz") 67 | for x in cfg["DATASET"]["TEST_SPLIT"] 68 | ] 69 | ) 70 | 71 | # Extract ground truth. 72 | run_commands( 73 | [ 74 | ["tools/extract_gt.py", f"--cfg={args.cfg}", f"--npz_path={x}"] 75 | for x in test_npz_paths 76 | ], 77 | args.n_proc, 78 | ) 79 | 80 | # Train flow models. 81 | subprocess.run(["tools/train_flow.py", f"--cfg={args.cfg}"], check=True) 82 | subprocess.run(["tools/test_flow.py", f"--cfg={args.cfg}"], check=True) 83 | 84 | # Train rotnet models. 85 | subprocess.run(["tools/train_rot.py", f"--cfg={args.cfg}"], check=True) 86 | subprocess.run(["tools/test_rot.py", f"--cfg={args.cfg}"], check=True) 87 | 88 | # Extract odometry. 
89 | run_commands( 90 | [ 91 | ["tools/test_odom.py", f"--cfg={args.cfg}", f"--npz_path={x}"] 92 | for x in test_npz_paths 93 | ], 94 | args.n_proc, 95 | ) 96 | 97 | # Train UNet 98 | subprocess.run(["tools/train_unet.py", f"--cfg={args.cfg}"], check=True) 99 | run_commands( 100 | [ 101 | ["tools/test_unet.py", f"--cfg={args.cfg}", f"--npz_path={x}"] 102 | for x in test_npz_paths 103 | ], 104 | args.n_proc, 105 | ) 106 | 107 | ### Run Cartographer. 108 | # Get ground truth. 109 | subprocess.run( 110 | [ 111 | "tools/run_carto.py", 112 | f"--cfg={args.cfg}", 113 | f"--n_proc=1", 114 | f"--odom=gt", 115 | f"--scan=gt", 116 | f"--params=default", 117 | ], 118 | check=True, 119 | ) 120 | 121 | # RadarHD baseline. 122 | subprocess.run( 123 | [ 124 | "tools/run_carto.py", 125 | f"--cfg={args.cfg}", 126 | f"--n_proc=1", 127 | f"--odom=gt", 128 | f"--scan=radarhd", 129 | f"--params=scan_only", 130 | ], 131 | check=True, 132 | ) 133 | 134 | # RNIN + RadarHD baseline. 135 | subprocess.run( 136 | [ 137 | "tools/run_carto.py", 138 | f"--cfg={args.cfg}", 139 | f"--n_proc=1", 140 | f"--odom=rnin", 141 | f"--scan=radarhd", 142 | f"--params=default", 143 | ], 144 | check=True, 145 | ) 146 | 147 | # milliEgo + RadarHD baseline. 148 | subprocess.run( 149 | [ 150 | "tools/run_carto.py", 151 | f"--cfg={args.cfg}", 152 | f"--n_proc=1", 153 | f"--odom=milliego", 154 | f"--scan=radarhd", 155 | f"--params=default", 156 | ], 157 | check=True, 158 | ) 159 | 160 | # Our odometry + RadarHD baseline. 161 | subprocess.run( 162 | [ 163 | "tools/run_carto.py", 164 | f"--cfg={args.cfg}", 165 | f"--n_proc=1", 166 | f"--odom=odometry", 167 | f"--scan=radarhd", 168 | f"--params=radar", 169 | ], 170 | check=True, 171 | ) 172 | 173 | # Run radarize. 174 | subprocess.run( 175 | [ 176 | "tools/run_carto.py", 177 | f"--cfg={args.cfg}", 178 | f"--n_proc=1", 179 | f"--odom=odometry", 180 | f"--scan=unet", 181 | f"--params=radar", 182 | ], 183 | check=True, 184 | ) 185 | 186 | -------------------------------------------------------------------------------- /main_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | 6 | import argparse 7 | import glob 8 | import multiprocessing 9 | import subprocess 10 | 11 | from radarize.config import cfg, update_config 12 | 13 | def args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--cfg", 17 | help="experiment configure file name", 18 | default="configs/default.yaml", 19 | type=str, 20 | ) 21 | parser.add_argument( 22 | "--n_proc", 23 | type=int, 24 | default=1, 25 | help="Number of processes to use for parallel processing.", 26 | ) 27 | parser.add_argument( 28 | "opts", 29 | help="Modify config options using the command-line", 30 | default=None, 31 | nargs=argparse.REMAINDER, 32 | ) 33 | args = parser.parse_args() 34 | 35 | return args 36 | 37 | 38 | def run_commands(cmds, n_proc): 39 | with multiprocessing.Pool(n_proc) as pool: 40 | pool.map(subprocess.run, cmds) 41 | 42 | 43 | if __name__ == "__main__": 44 | args = args() 45 | update_config(cfg, args) 46 | 47 | # Preprocess datasets. 
(bag -> npz) 48 | bag_paths = sorted(glob.glob(os.path.join(cfg["DATASET"]["PATH"], "*.bag"))) 49 | bag_paths = [x for x in bag_paths if not os.path.exists(x.replace(".bag", ".npz"))] 50 | run_commands( 51 | [ 52 | [f"tools/create_dataset.py", f"--cfg={args.cfg}", f"--bag_path={x}"] 53 | for x in bag_paths 54 | ], 55 | args.n_proc, 56 | ) 57 | 58 | train_npz_paths = sorted( 59 | [ 60 | os.path.join(cfg["DATASET"]["PATH"], os.path.basename(x) + ".npz") 61 | for x in cfg["DATASET"]["TRAIN_SPLIT"] 62 | ] 63 | ) 64 | test_npz_paths = sorted( 65 | [ 66 | os.path.join(cfg["DATASET"]["PATH"], os.path.basename(x) + ".npz") 67 | for x in cfg["DATASET"]["TEST_SPLIT"] 68 | ] 69 | ) 70 | 71 | # Extract ground truth. 72 | run_commands( 73 | [ 74 | ["tools/extract_gt.py", f"--cfg={args.cfg}", f"--npz_path={x}"] 75 | for x in test_npz_paths 76 | ], 77 | args.n_proc, 78 | ) 79 | 80 | # Flow models. 81 | subprocess.run(["tools/test_flow.py", f"--cfg={args.cfg}"], check=True) 82 | 83 | # Rotation models. 84 | subprocess.run(["tools/test_rot.py", f"--cfg={args.cfg}"], check=True) 85 | 86 | # Extract odometry. 87 | run_commands( 88 | [ 89 | ["tools/test_odom.py", f"--cfg={args.cfg}", f"--npz_path={x}"] 90 | for x in test_npz_paths 91 | ], 92 | args.n_proc, 93 | ) 94 | 95 | # UNet 96 | run_commands( 97 | [ 98 | ["tools/test_unet.py", f"--cfg={args.cfg}", f"--npz_path={x}"] 99 | for x in test_npz_paths 100 | ], 101 | args.n_proc, 102 | ) 103 | 104 | ### Run Cartographer. 105 | # Get ground truth. 106 | subprocess.run( 107 | [ 108 | "tools/run_carto.py", 109 | f"--cfg={args.cfg}", 110 | f"--n_proc=1", 111 | f"--odom=gt", 112 | f"--scan=gt", 113 | f"--params=default", 114 | ], 115 | check=True, 116 | ) 117 | 118 | # RadarHD baseline. 119 | subprocess.run( 120 | [ 121 | "tools/run_carto.py", 122 | f"--cfg={args.cfg}", 123 | f"--n_proc=1", 124 | f"--odom=gt", 125 | f"--scan=radarhd", 126 | f"--params=scan_only", 127 | ], 128 | check=True, 129 | ) 130 | 131 | # RNIN + RadarHD baseline. 132 | subprocess.run( 133 | [ 134 | "tools/run_carto.py", 135 | f"--cfg={args.cfg}", 136 | f"--n_proc=1", 137 | f"--odom=rnin", 138 | f"--scan=radarhd", 139 | f"--params=default", 140 | ], 141 | check=True, 142 | ) 143 | 144 | # milliEgo + RadarHD baseline. 145 | subprocess.run( 146 | [ 147 | "tools/run_carto.py", 148 | f"--cfg={args.cfg}", 149 | f"--n_proc=1", 150 | f"--odom=milliego", 151 | f"--scan=radarhd", 152 | f"--params=default", 153 | ], 154 | check=True, 155 | ) 156 | 157 | # Our odometry + RadarHD baseline. 158 | subprocess.run( 159 | [ 160 | "tools/run_carto.py", 161 | f"--cfg={args.cfg}", 162 | f"--n_proc=1", 163 | f"--odom=odometry", 164 | f"--scan=radarhd", 165 | f"--params=radar", 166 | ], 167 | check=True, 168 | ) 169 | 170 | # Run radarize. 
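    # For reference, the run_carto.py invocations in this script sweep these
    # (odom, scan, params) combinations:
    #
    #   gt       + gt      + default    -> ground-truth reference
    #   gt       + radarhd + scan_only  -> RadarHD baseline
    #   rnin     + radarhd + default    -> RNIN + RadarHD baseline
    #   milliego + radarhd + default    -> milliEgo + RadarHD baseline
    #   odometry + radarhd + radar      -> our odometry + RadarHD scans
    #   odometry + unet    + radar      -> full Radarize (below)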
171 | subprocess.run( 172 | [ 173 | "tools/run_carto.py", 174 | f"--cfg={args.cfg}", 175 | f"--n_proc=1", 176 | f"--odom=odometry", 177 | f"--scan=unet", 178 | f"--params=radar", 179 | ], 180 | check=True, 181 | ) 182 | 183 | -------------------------------------------------------------------------------- /odom_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | > log 6 | for d in "gt_gt_default/output" "milliego" "rnin" "odometry" 7 | do 8 | for f in "main_0" "main_1" "main_2" "main_3" "main_4" 9 | do 10 | ./tools/eval_traj.py --cfg=configs/$f.yaml --input=$d | tee -a log 11 | done 12 | done 13 | 14 | > odom_result.txt 15 | for d in "milliego" "rnin" "odometry" 16 | do 17 | for f in "ape_trans" "ape_rot" "rpe_trans" "rpe_rot" 18 | do 19 | average=$(echo "($(cat log | grep $d | grep $f | cut -d' ' -f4 | paste -s -d+))/5" | bc -l) 20 | echo "$d $f ${average}" | tee -a odom_result.txt 21 | done 22 | done 23 | -------------------------------------------------------------------------------- /radarize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ConnectedSystemsLab/radarize_ae/cbcb6439b89c068d86a424215ee61edc925191ae/radarize/__init__.py -------------------------------------------------------------------------------- /radarize/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .default import _C as cfg 2 | from .default import update_config as update_config 3 | -------------------------------------------------------------------------------- /radarize/config/default.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from yacs.config import CfgNode as CN 4 | 5 | _C = CN() 6 | 7 | # Output folder for this experiment configuration. 8 | _C.OUTPUT_DIR = "output_test/" 9 | 10 | # Dataset Preprocessing. 11 | _C.DATASET = CN() 12 | 13 | _C.DATASET.PATH = "data/all/" 14 | _C.DATASET.TRAIN_SPLIT = ["walk_csl_basement_0"] 15 | _C.DATASET.VAL_SPLIT = ["walk_csl_basement_0"] 16 | _C.DATASET.TEST_SPLIT = ["walk_csl_basement_0"] 17 | 18 | _C.DATASET.RADAR_CONFIG = "calib/1843/1843_v1.cfg" 19 | _C.DATASET.DEPTH_INTRINSICS = "calib/d435i/depth_intrinsics_848x100.txt" 20 | _C.DATASET.SYNC_TOPIC = "radar_r_3" 21 | _C.DATASET.CAMERA_TOPIC = "/tracking/fisheye1/image_raw/compressed" 22 | _C.DATASET.DEPTH_TOPIC = "/camera/depth/image_rect_raw/compressedDepth" 23 | _C.DATASET.RADAR_TOPIC = "/radar0/radar_data" 24 | _C.DATASET.PCD_TOPIC = "/ti_mmwave/radar_scan_pcl_0" 25 | _C.DATASET.IMU_TOPIC = "/tracking/imu" 26 | _C.DATASET.POSE_TOPIC = "/tracking/odom/sample" 27 | 28 | # Doppler-angle heatmaps 29 | _C.DATASET.DA = CN() 30 | 31 | _C.DATASET.DA.RADAR_BUFFER_LEN = 3 32 | _C.DATASET.DA.RANGE_SUBSAMPLING_FACTOR = 1 33 | _C.DATASET.DA.RESIZE_SHAPE = [181, 60] 34 | 35 | # Range-azimuth heatmaps 36 | _C.DATASET.RA = CN() 37 | 38 | _C.DATASET.RA.RADAR_BUFFER_LEN = 3 39 | _C.DATASET.RA.RANGE_SUBSAMPLING_FACTOR = 1 40 | _C.DATASET.RA.RAMAP_RSIZE = 96 41 | _C.DATASET.RA.RAMAP_ASIZE = 88 42 | _C.DATASET.RA.RR_MIN = 0 43 | _C.DATASET.RA.RR_MAX = 4.284 44 | _C.DATASET.RA.RA_MIN = -43 45 | _C.DATASET.RA.RA_MAX = 43 46 | 47 | # Flow Module. 
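# Illustrative, assuming standard yacs merge semantics (see update_config() at the
# bottom of this file): the configs/main_*.yaml files only list the keys they
# override, and everything else keeps the defaults defined here. For example, with
# configs/main_3.yaml, which sets FLOW.TRAIN.EPOCHS to 30:
#
#   from radarize.config import cfg, update_config
#   update_config(cfg, args)            # args.cfg == "configs/main_3.yaml"
#   assert cfg.FLOW.TRAIN.EPOCHS == 30  # overridden by the YAML
#   assert cfg.FLOW.TRAIN.LR == 1e-3    # default kept from below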
48 | _C.FLOW = CN() 49 | 50 | _C.FLOW.MODEL = CN() 51 | 52 | _C.FLOW.MODEL.NAME = "transnet18" 53 | _C.FLOW.MODEL.TYPE = "ResNet18" 54 | _C.FLOW.MODEL.N_CHANNELS = 2 55 | _C.FLOW.MODEL.N_OUTPUTS = 2 56 | 57 | _C.FLOW.DATA = CN() 58 | _C.FLOW.DATA.SUBSAMPLE_FACTOR = 1 59 | 60 | _C.FLOW.TRAIN = CN() 61 | _C.FLOW.TRAIN.BATCH_SIZE = 128 62 | _C.FLOW.TRAIN.LR = 1e-3 63 | _C.FLOW.TRAIN.EPOCHS = 50 64 | _C.FLOW.TRAIN.SEED = 1 65 | _C.FLOW.TRAIN.LOG_STEP = 100 66 | 67 | _C.FLOW.TEST = CN() 68 | _C.FLOW.TEST.BATCH_SIZE = 64 69 | 70 | # Rotation Module. 71 | _C.ROTNET = CN() 72 | 73 | _C.ROTNET.MODEL = CN() 74 | _C.ROTNET.MODEL.NAME = "eca_rotnet18_135" 75 | _C.ROTNET.MODEL.TYPE = "ECAResNet18" 76 | _C.ROTNET.MODEL.N_CHANNELS = 6 77 | _C.ROTNET.MODEL.N_OUTPUTS = 1 78 | 79 | _C.ROTNET.DATA = CN() 80 | _C.ROTNET.DATA.SUBSAMPLE_FACTOR = 2 81 | 82 | _C.ROTNET.TRAIN = CN() 83 | _C.ROTNET.TRAIN.BATCH_SIZE = 128 84 | _C.ROTNET.TRAIN.LR = 1e-3 85 | _C.ROTNET.TRAIN.EPOCHS = 50 86 | _C.ROTNET.TRAIN.SEED = 777 87 | _C.ROTNET.TRAIN.LOG_STEP = 50 88 | _C.ROTNET.TRAIN.TRAIN_SEQ_LEN = 4 89 | _C.ROTNET.TRAIN.TRAIN_RANDOM_SEQ_LEN = True 90 | _C.ROTNET.TRAIN.VAL_SEQ_LEN = 4 91 | _C.ROTNET.TRAIN.VAL_RANDOM_SEQ_LEN = True 92 | 93 | _C.ROTNET.TEST = CN() 94 | _C.ROTNET.TEST.BATCH_SIZE = 128 95 | _C.ROTNET.TEST.SEQ_LEN = 4 96 | 97 | # UNet Module. 98 | 99 | _C.UNET = CN() 100 | 101 | _C.UNET.MODEL = CN() 102 | _C.UNET.MODEL.NAME = "unet" 103 | _C.UNET.MODEL.TYPE = "UNet" 104 | _C.UNET.MODEL.N_CHANNELS = 6 105 | _C.UNET.MODEL.N_CLASSES = 2 106 | 107 | _C.UNET.TRAIN = CN() 108 | _C.UNET.TRAIN.BATCH_SIZE = 48 109 | _C.UNET.TRAIN.LR = 1e-4 110 | _C.UNET.TRAIN.BCE_WEIGHT = 0.0 111 | _C.UNET.TRAIN.DICE_WEIGHT = 1.0 112 | _C.UNET.TRAIN.EPOCHS = 15 113 | _C.UNET.TRAIN.SEED = 1 114 | _C.UNET.TRAIN.LOG_STEP = 100 115 | 116 | _C.UNET.TEST = CN() 117 | 118 | _C.UNET.TEST.BATCH_SIZE = 64 119 | 120 | # Odom Module. 
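# Note (inferred from the defaults below, not stated elsewhere in the repo):
# ODOM.MODELS.TRANS and ODOM.MODELS.ROT name the trained checkpoints to load, so
# they are expected to match FLOW.MODEL.NAME ("transnet18") and ROTNET.MODEL.NAME
# ("eca_rotnet18_135") defined above:
#
#   assert cfg.ODOM.MODELS.TRANS == cfg.FLOW.MODEL.NAME
#   assert cfg.ODOM.MODELS.ROT == cfg.ROTNET.MODEL.NAME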
121 | _C.ODOM = CN() 122 | 123 | _C.ODOM.OUTPUT_DIR = "odometry" 124 | 125 | _C.ODOM.MODELS = CN() 126 | _C.ODOM.MODELS.TRANS = "transnet18" 127 | _C.ODOM.MODELS.ROT = "eca_rotnet18_135" 128 | 129 | _C.ODOM.PARAMS = CN() 130 | _C.ODOM.PARAMS.SUBSAMPLE_FACTOR = 2 131 | _C.ODOM.PARAMS.DELAY = 1 132 | _C.ODOM.PARAMS.KF_DELAY = 4 133 | _C.ODOM.PARAMS.POS_THRESH = 999 134 | _C.ODOM.PARAMS.YAW_THRESH = 999 135 | 136 | 137 | def get_cfg_defaults(): 138 | """Get a yacs CfgNode object with default values for my_project.""" 139 | # Return a clone so that the defaults will not be altered 140 | # This is for the "local variable" use pattern 141 | return _C.clone() 142 | 143 | 144 | # Alternatively, provide a way to import the defaults as 145 | # a global singleton: 146 | # cfg = _C # users can `from config import cfg` 147 | 148 | 149 | def update_config(cfg, args): 150 | cfg.defrost() 151 | if args.cfg: 152 | cfg.merge_from_file(args.cfg) 153 | if args.opts: 154 | cfg.merge_from_list(args.opts) 155 | cfg.freeze() 156 | -------------------------------------------------------------------------------- /radarize/flow/dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | import glob 6 | import argparse 7 | import numpy as np 8 | import torch 9 | from tqdm import tqdm 10 | 11 | torch.multiprocessing.set_sharing_strategy("file_system") 12 | import cv2 13 | from torch.utils.data import Dataset, DataLoader 14 | from torchvision import datasets, transforms 15 | from torch.utils.data.dataloader import default_collate 16 | 17 | sys.path.append("..") 18 | 19 | 20 | class FlipFlow(object): 21 | 22 | def __init__(self, prob=0.5): 23 | self.prob = prob 24 | 25 | def __call__(self, sample): 26 | if torch.rand(1).item() < self.prob: 27 | sample["velo_gt"] = sample["velo_gt"] * -1 28 | sample["radar_d"] = transforms.functional.vflip(sample["radar_d"]) 29 | sample["radar_de"] = transforms.functional.vflip(sample["radar_de"]) 30 | 31 | return sample 32 | 33 | 34 | class FlowDataset(Dataset): 35 | """Flow dataset.""" 36 | 37 | topics = ["time", "radar_d", "radar_de", "velo_gt"] 38 | 39 | def __init__(self, path, subsample_factor=1, transform=None): 40 | # Load files from .npz. 41 | self.path = path 42 | print(path) 43 | with np.load(path) as data: 44 | self.files = [k for k in data.files if k in self.topics] 45 | self.dataset = {k: data[k][::subsample_factor] for k in self.files} 46 | 47 | # Check if lengths are the same. 48 | for k in self.files: 49 | print(k, self.dataset[k].shape, self.dataset[k].dtype) 50 | lengths = [self.dataset[k].shape[0] for k in self.files] 51 | assert len(set(lengths)) == 1 52 | self.num_samples = lengths[0] 53 | 54 | # Save transforms. 
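        # Illustrative usage sketch (assumes a .npz produced by tools/create_dataset.py):
        #
        #   ds = FlowDataset("data/all/walk_csl_basement_0.npz",
        #                    subsample_factor=cfg.FLOW.DATA.SUBSAMPLE_FACTOR,
        #                    transform=transforms.Compose([FlipFlow(prob=0.5)]))
        #   loader = DataLoader(ds, batch_size=cfg.FLOW.TRAIN.BATCH_SIZE, shuffle=True)
        #   batch = next(iter(loader))  # dict with "time", "radar_d", "radar_de", "velo_gt"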
55 | self.transform = transform 56 | 57 | def __len__(self): 58 | return self.num_samples 59 | 60 | def __getitem__(self, idx): 61 | sample = { 62 | k: ( 63 | torch.from_numpy(self.dataset[k][idx]) 64 | if type(self.dataset[k][idx]) is np.ndarray 65 | else self.dataset[k][idx] 66 | ) 67 | for k in self.files 68 | } 69 | if self.transform: 70 | sample = self.transform(sample) 71 | return sample 72 | 73 | 74 | -------------------------------------------------------------------------------- /radarize/flow/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | import math 8 | from collections import OrderedDict 9 | from torch.nn import init  # fix: weight_init() below uses init.* but it was never imported 10 | 11 | def weight_init(m): 12 | """ 13 | Usage: 14 | model = Model() 15 | model.apply(weight_init) 16 | """ 17 | if isinstance(m, nn.Conv1d): 18 | init.normal_(m.weight.data) 19 | if m.bias is not None: 20 | init.normal_(m.bias.data) 21 | elif isinstance(m, nn.Conv2d): 22 | init.xavier_normal_(m.weight.data) 23 | if m.bias is not None: 24 | init.normal_(m.bias.data) 25 | elif isinstance(m, nn.Conv3d): 26 | init.xavier_normal_(m.weight.data) 27 | if m.bias is not None: 28 | init.normal_(m.bias.data) 29 | elif isinstance(m, nn.ConvTranspose1d): 30 | init.normal_(m.weight.data) 31 | if m.bias is not None: 32 | init.normal_(m.bias.data) 33 | elif isinstance(m, nn.ConvTranspose2d): 34 | init.xavier_normal_(m.weight.data) 35 | if m.bias is not None: 36 | init.normal_(m.bias.data) 37 | elif isinstance(m, nn.ConvTranspose3d): 38 | init.xavier_normal_(m.weight.data) 39 | if m.bias is not None: 40 | init.normal_(m.bias.data) 41 | elif isinstance(m, nn.BatchNorm1d): 42 | init.normal_(m.weight.data, mean=1, std=0.02) 43 | init.constant_(m.bias.data, 0) 44 | elif isinstance(m, nn.BatchNorm2d): 45 | init.normal_(m.weight.data, mean=1, std=0.02) 46 | init.constant_(m.bias.data, 0) 47 | elif isinstance(m, nn.BatchNorm3d): 48 | init.normal_(m.weight.data, mean=1, std=0.02) 49 | init.constant_(m.bias.data, 0) 50 | elif isinstance(m, nn.Linear): 51 | init.xavier_normal_(m.weight.data) 52 | init.normal_(m.bias.data) 53 | elif isinstance(m, nn.LSTM): 54 | for param in m.parameters(): 55 | if len(param.shape) >= 2: 56 | init.orthogonal_(param.data) 57 | else: 58 | init.normal_(param.data) 59 | elif isinstance(m, nn.LSTMCell): 60 | for param in m.parameters(): 61 | if len(param.shape) >= 2: 62 | init.orthogonal_(param.data) 63 | else: 64 | init.normal_(param.data) 65 | elif isinstance(m, nn.GRU): 66 | for param in m.parameters(): 67 | if len(param.shape) >= 2: 68 | init.orthogonal_(param.data) 69 | else: 70 | init.normal_(param.data) 71 | elif isinstance(m, nn.GRUCell): 72 | for param in m.parameters(): 73 | if len(param.shape) >= 2: 74 | init.orthogonal_(param.data) 75 | else: 76 | init.normal_(param.data) 77 | 78 | 79 | class EfficientChannelAttention(nn.Module): # Efficient Channel Attention module 80 | def __init__(self, c, b=1, gamma=2): 81 | super(EfficientChannelAttention, self).__init__() 82 | t = int(abs((math.log(c, 2) + b) / gamma)) 83 | k = t if t % 2 else t + 1 84 | 85 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 86 | self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k / 2), bias=False) 87 | self.sigmoid = nn.Sigmoid() 88 | 89 | def forward(self, x): 90 | x = self.avg_pool(x) 91 | x = self.conv1(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 92 | out = self.sigmoid(x) 93 | return
out 94 | 95 | 96 | class BasicBlock(nn.Module): # basic residual block used by the 18- and 34-layer ResNets 97 | expansion = 1 98 | 99 | def __init__(self, in_planes, planes, stride=1): # two 3x3 convs + shortcut 100 | super(BasicBlock, self).__init__() 101 | self.conv1 = nn.Conv2d( 102 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 103 | ) 104 | self.bn1 = nn.BatchNorm2d(planes) 105 | self.conv2 = nn.Conv2d( 106 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 107 | ) 108 | self.bn2 = nn.BatchNorm2d(planes) 109 | 110 | self.channel = EfficientChannelAttention( 111 | planes 112 | ) # Efficient Channel Attention module 113 | 114 | self.shortcut = nn.Sequential() 115 | if ( 116 | stride != 1 or in_planes != self.expansion * planes 117 | ): # the shortcut is a 1x1 conv projection when shapes change, identity otherwise 118 | self.shortcut = nn.Sequential( 119 | nn.Conv2d( 120 | in_planes, 121 | self.expansion * planes, 122 | kernel_size=1, 123 | stride=stride, 124 | bias=False, 125 | ), 126 | nn.BatchNorm2d(self.expansion * planes), 127 | ) 128 | 129 | def forward(self, x): 130 | out = F.relu(self.bn1(self.conv1(x))) 131 | out = self.bn2(self.conv2(out)) 132 | ECA_out = self.channel(out) 133 | out = out * ECA_out 134 | out += self.shortcut(x) 135 | out = F.relu(out) 136 | return out 137 | 138 | 139 | class ECAResNet18(nn.Module): 140 | def __init__(self, n_channels, n_outputs): 141 | super(ECAResNet18, self).__init__() 142 | self.in_planes = 64 143 | num_blocks = [2, 2, 2, 2] 144 | block = BasicBlock 145 | 146 | self.conv1 = nn.Conv2d( 147 | n_channels, 148 | 64, 149 | kernel_size=(7, 7), 150 | stride=(2, 2), 151 | padding=(3, 3), 152 | bias=False, 153 | ) # conv1 154 | self.bn1 = nn.BatchNorm2d(64) 155 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) # conv2_x 156 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) # conv3_x 157 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) # conv4_x 158 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) # conv5_x 159 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 160 | self.linear = nn.Linear(512 * block.expansion, 32) 161 | self.fc2 = nn.Linear(32, n_outputs) 162 | 163 | weight_init(self) 164 | 165 | def _make_layer(self, block, planes, num_blocks, stride): 166 | strides = [stride] + [1] * (num_blocks - 1) 167 | layers = [] 168 | for stride in strides: 169 | layers.append(block(self.in_planes, planes, stride)) 170 | self.in_planes = planes * block.expansion 171 | return nn.Sequential(*layers) 172 | 173 | def forward(self, x): 174 | x = F.relu(self.bn1(self.conv1(x))) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.avgpool(x) 180 | x = torch.flatten(x, 1) 181 | x = self.linear(x) 182 | out = self.fc2(x) 183 | return out 184 | 185 | 186 | class ResNet18(nn.Module): 187 | """Model to predict x and y flow from radar heatmaps.""" 188 | 189 | def __init__(self, n_channels, n_outputs): 190 | super(ResNet18, self).__init__() 191 | 192 | # CNN encoder for heatmaps 193 | self.resnet18 = models.resnet18(pretrained=True) 194 | self.resnet18.conv1 = nn.Conv2d( 195 | n_channels, 196 | 64, 197 | kernel_size=(7, 7), 198 | stride=(2, 2), 199 | padding=(3, 3), 200 | bias=False, 201 | ) 202 | self.resnet18.fc = nn.Linear(512, n_outputs) 203 | 204 | weight_init(self) 205 | 206 | def forward(self, x): 207 | out = self.resnet18(x) 208 | return out 209 | 210 | 211 | class ResNet50(nn.Module): 212 | """Model to predict x and y flow from radar
heatmaps.""" 213 | 214 | def __init__(self, n_channels, n_outputs): 215 | super(ResNet50, self).__init__() 216 | 217 | # CNN encoder for heatmaps 218 | self.resnet50 = models.resnet50(pretrained=True) 219 | self.resnet50.conv1 = nn.Conv2d( 220 | n_channels, 221 | 64, 222 | kernel_size=(7, 7), 223 | stride=(2, 2), 224 | padding=(3, 3), 225 | bias=False, 226 | ) 227 | self.resnet50.fc = nn.Linear(2048, n_outputs) 228 | 229 | weight_init(self) 230 | 231 | def forward(self, x): 232 | out = self.resnet50(x) 233 | return out 234 | 235 | 236 | class ResNet18Nano(nn.Module): 237 | """Model to predict x and y flow from radar heatmaps.""" 238 | 239 | def __init__(self, n_channels, n_outputs): 240 | super(ResNet18Nano, self).__init__() 241 | 242 | # CNN encoder for48eatmaps 243 | resnet18 = models.resnet._resnet( 244 | "resnet18", 245 | models.resnet.BasicBlock, 246 | [1, 1, 1, 1], 247 | pretrained=False, 248 | progress=False, 249 | ) 250 | resnet18.conv1 = nn.Conv2d( 251 | n_channels, 252 | 64, 253 | kernel_size=(7, 7), 254 | stride=(2, 2), 255 | padding=(3, 3), 256 | bias=False, 257 | ) 258 | self.enc = nn.Sequential(OrderedDict(list(resnet18.named_children())[:5])) 259 | self.avgpool = resnet18.avgpool 260 | self.fc = nn.Linear(64, n_outputs) 261 | 262 | weight_init(self) 263 | 264 | def init_weights(self): 265 | for m in self.modules(): 266 | m.apply(weight_init) 267 | 268 | def forward(self, x): 269 | x = self.enc(x) 270 | x = self.avgpool(x) 271 | x = torch.flatten(x, 1) 272 | x = self.fc(x) 273 | return x 274 | 275 | 276 | class ResNet18Micro(nn.Module): 277 | """Model to predict x and y flow from radar heatmaps.""" 278 | 279 | def __init__(self, n_channels, n_outputs): 280 | super(ResNet18Micro, self).__init__() 281 | 282 | # CNN encoder for48eatmaps 283 | resnet18 = models.resnet._resnet( 284 | "resnet18", 285 | models.resnet.BasicBlock, 286 | [1, 1, 1, 1], 287 | pretrained=False, 288 | progress=False, 289 | ) 290 | resnet18.conv1 = nn.Conv2d( 291 | n_channels, 292 | 64, 293 | kernel_size=(7, 7), 294 | stride=(2, 2), 295 | padding=(3, 3), 296 | bias=False, 297 | ) 298 | self.enc = nn.Sequential(OrderedDict(list(resnet18.named_children())[:6])) 299 | self.avgpool = resnet18.avgpool 300 | self.fc = nn.Linear(128, n_outputs) 301 | 302 | weight_init(self) 303 | 304 | def forward(self, x): 305 | x = self.enc(x) 306 | x = self.avgpool(x) 307 | x = torch.flatten(x, 1) 308 | x = self.fc(x) 309 | 310 | return x 311 | 312 | -------------------------------------------------------------------------------- /radarize/rotnet/dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | import glob 6 | import argparse 7 | import numpy as np 8 | import torch 9 | from tqdm import tqdm 10 | 11 | torch.multiprocessing.set_sharing_strategy("file_system") 12 | import cv2 13 | from torch.utils.data import Dataset, DataLoader 14 | from torchvision import datasets, transforms 15 | from torch.utils.data.dataloader import default_collate 16 | 17 | sys.path.append("..") 18 | 19 | 20 | class ReverseTime(object): 21 | 22 | def __init__(self, prob=0.5): 23 | self.prob = prob 24 | 25 | def __call__(self, sample): 26 | 27 | if torch.rand(1).item() < self.prob: 28 | for k, v in sample.items(): 29 | sample[k] = torch.flip(v, [0]) 30 | 31 | return sample 32 | 33 | 34 | class RotationDataset(Dataset): 35 | """Rotation dataset.""" 36 | 37 | topics = ["time", "pose_gt", "radar_r_1", "radar_r_3", "radar_r_5", "radar_re"] 38 | 39 | def 
__init__( 40 | self, path, subsample_factor=1, seq_len=1, random_seq_len=False, transform=None 41 | ): 42 | 43 | # Load files from .npz. 44 | self.path = path 45 | 46 | print(path) 47 | with np.load(path) as data: 48 | self.files = [k for k in data.files if k in self.topics] 49 | self.dataset = {k: data[k][::subsample_factor] for k in self.files} 50 | 51 | # Check if lengths are the same. 52 | for k in self.files: 53 | print(k, self.dataset[k].shape, self.dataset[k].dtype) 54 | lengths = [self.dataset[k].shape[0] for k in self.files] 55 | assert len(set(lengths)) == 1 56 | self.num_samples = lengths[0] 57 | 58 | # Set sequence length for stacking frames across time. 59 | self.random_seq_len = random_seq_len 60 | self.seq_len = seq_len 61 | 62 | # Save transforms. 63 | self.transform = transform 64 | 65 | def __len__(self): 66 | return self.num_samples - self.seq_len 67 | 68 | def __getitem__(self, idx): 69 | sample = {} 70 | 71 | if self.random_seq_len: 72 | seq_len = np.random.randint(1, self.seq_len + 1) 73 | else: 74 | seq_len = self.seq_len 75 | 76 | sample["time"] = torch.tensor(self.dataset["time"][[idx, idx + seq_len]]) 77 | sample["pose_gt"] = torch.tensor(self.dataset["pose_gt"][[idx, idx + seq_len]]) 78 | sample["radar_r"] = torch.tensor( 79 | np.concatenate( 80 | [ 81 | self.dataset["radar_r_1"][idx], 82 | self.dataset["radar_r_3"][idx], 83 | self.dataset["radar_r_5"][idx], 84 | self.dataset["radar_r_5"][idx + seq_len], 85 | self.dataset["radar_r_3"][idx + seq_len], 86 | self.dataset["radar_r_1"][idx + seq_len], 87 | ], 88 | axis=0, 89 | ) 90 | ) 91 | 92 | # sample['time'] = torch.tensor(self.dataset['time'][idx:idx+seq_len]) 93 | # sample['pose_gt'] = torch.tensor(self.dataset['pose_gt'][idx:idx+seq_len]) 94 | # sample['radar_r'] = torch.from_numpy(np.concatenate(self.dataset['radar_r'][idx:idx+seq_len], axis=0)) 95 | 96 | if self.transform: 97 | sample = self.transform(sample) 98 | 99 | return sample 100 | 101 | 102 | -------------------------------------------------------------------------------- /radarize/rotnet/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import math 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision.models as models 11 | from torch.nn import init  # fix: weight_init() below uses init.* but it was never imported 12 | 13 | class EfficientChannelAttention(nn.Module): # Efficient Channel Attention module 14 | def __init__(self, c, b=1, gamma=2): 15 | super(EfficientChannelAttention, self).__init__() 16 | t = int(abs((math.log(c, 2) + b) / gamma)) 17 | k = t if t % 2 else t + 1 18 | 19 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 20 | self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k / 2), bias=False) 21 | self.sigmoid = nn.Sigmoid() 22 | 23 | def forward(self, x): 24 | x = self.avg_pool(x) 25 | x = self.conv1(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 26 | out = self.sigmoid(x) 27 | return out 28 | 29 | 30 | class BasicBlock(nn.Module): # basic residual block used by the 18- and 34-layer ResNets 31 | expansion = 1 32 | 33 | def __init__(self, in_planes, planes, stride=1): # two 3x3 convs + shortcut 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = nn.Conv2d( 36 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 37 | ) 38 | self.bn1 = nn.BatchNorm2d(planes) 39 | self.conv2 = nn.Conv2d( 40 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 41 | ) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | 44 | 
self.channel = EfficientChannelAttention( 45 | planes 46 | ) # Efficient Channel Attention module 47 | 48 | self.shortcut = nn.Sequential() 49 | if ( 50 | stride != 1 or in_planes != self.expansion * planes 51 | ): # the shortcut is a 1x1 conv projection when shapes change, identity otherwise 52 | self.shortcut = nn.Sequential( 53 | nn.Conv2d( 54 | in_planes, 55 | self.expansion * planes, 56 | kernel_size=1, 57 | stride=stride, 58 | bias=False, 59 | ), 60 | nn.BatchNorm2d(self.expansion * planes), 61 | ) 62 | 63 | def forward(self, x): 64 | out = F.relu(self.bn1(self.conv1(x))) 65 | out = self.bn2(self.conv2(out)) 66 | ECA_out = self.channel(out) 67 | out = out * ECA_out 68 | out += self.shortcut(x) 69 | out = F.relu(out) 70 | return out 71 | 72 | 73 | class ECAResNet18(nn.Module): 74 | def __init__(self, n_channels, n_outputs): 75 | super(ECAResNet18, self).__init__() 76 | self.in_planes = 64 77 | num_blocks = [2, 2, 2, 2] 78 | block = BasicBlock 79 | 80 | self.conv1 = nn.Conv2d( 81 | n_channels, 82 | 64, 83 | kernel_size=(7, 7), 84 | stride=(2, 2), 85 | padding=(3, 3), 86 | bias=False, 87 | ) # conv1 88 | self.bn1 = nn.BatchNorm2d(64) 89 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) # conv2_x 90 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) # conv3_x 91 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) # conv4_x 92 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) # conv5_x 93 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 94 | self.linear = nn.Linear(512 * block.expansion, 32) 95 | self.fc = FcBlock(32, n_outputs) 96 | 97 | weight_init(self) 98 | 99 | def _make_layer(self, block, planes, num_blocks, stride): 100 | strides = [stride] + [1] * (num_blocks - 1) 101 | layers = [] 102 | for stride in strides: 103 | layers.append(block(self.in_planes, planes, stride)) 104 | self.in_planes = planes * block.expansion 105 | return nn.Sequential(*layers) 106 | 107 | def forward(self, x): 108 | x = F.relu(self.bn1(self.conv1(x))) 109 | x = self.layer1(x) 110 | x = self.layer2(x) 111 | x = self.layer3(x) 112 | x = self.layer4(x) 113 | x = self.avgpool(x) 114 | x = torch.flatten(x, 1) 115 | x = self.linear(x) 116 | x = self.fc(x) 117 | return x 118 | 119 | 120 | class FcBlock(nn.Module): 121 | def __init__(self, in_dim, out_dim, mid_dim=256, dropout=0.05): 122 | super(FcBlock, self).__init__() 123 | self.mid_dim = mid_dim 124 | self.in_dim = in_dim 125 | self.out_dim = out_dim 126 | 127 | # fc layers 128 | self.fcs = nn.Sequential( 129 | nn.Linear(self.in_dim, self.mid_dim), 130 | nn.ReLU(True), 131 | nn.Linear(self.mid_dim, self.out_dim), 132 | ) 133 | 134 | def forward(self, x): 135 | # x = x.view(x.size(0), -1) 136 | x = self.fcs(x) 137 | return x 138 | 139 | 140 | def weight_init(m): 141 | """ 142 | Usage: 143 | model = Model() 144 | model.apply(weight_init) 145 | """ 146 | if isinstance(m, nn.Conv1d): 147 | init.normal_(m.weight.data) 148 | if m.bias is not None: 149 | init.normal_(m.bias.data) 150 | elif isinstance(m, nn.Conv2d): 151 | init.xavier_normal_(m.weight.data) 152 | if m.bias is not None: 153 | init.normal_(m.bias.data) 154 | elif isinstance(m, nn.Conv3d): 155 | init.xavier_normal_(m.weight.data) 156 | if m.bias is not None: 157 | init.normal_(m.bias.data) 158 | elif isinstance(m, nn.ConvTranspose1d): 159 | init.normal_(m.weight.data) 160 | if m.bias is not None: 161 | init.normal_(m.bias.data) 162 | elif isinstance(m, nn.ConvTranspose2d): 163 | init.xavier_normal_(m.weight.data) 164 | if m.bias is not None: 165 | 
init.normal_(m.bias.data) 166 | elif isinstance(m, nn.ConvTranspose3d): 167 | init.xavier_normal_(m.weight.data) 168 | if m.bias is not None: 169 | init.normal_(m.bias.data) 170 | elif isinstance(m, nn.BatchNorm1d): 171 | init.normal_(m.weight.data, mean=1, std=0.02) 172 | init.constant_(m.bias.data, 0) 173 | elif isinstance(m, nn.BatchNorm2d): 174 | init.normal_(m.weight.data, mean=1, std=0.02) 175 | init.constant_(m.bias.data, 0) 176 | elif isinstance(m, nn.BatchNorm3d): 177 | init.normal_(m.weight.data, mean=1, std=0.02) 178 | init.constant_(m.bias.data, 0) 179 | elif isinstance(m, nn.Linear): 180 | init.xavier_normal_(m.weight.data) 181 | init.normal_(m.bias.data) 182 | elif isinstance(m, nn.LSTM): 183 | for param in m.parameters(): 184 | if len(param.shape) >= 2: 185 | init.orthogonal_(param.data) 186 | else: 187 | init.normal_(param.data) 188 | elif isinstance(m, nn.LSTMCell): 189 | for param in m.parameters(): 190 | if len(param.shape) >= 2: 191 | init.orthogonal_(param.data) 192 | else: 193 | init.normal_(param.data) 194 | elif isinstance(m, nn.GRU): 195 | for param in m.parameters(): 196 | if len(param.shape) >= 2: 197 | init.orthogonal_(param.data) 198 | else: 199 | init.normal_(param.data) 200 | elif isinstance(m, nn.GRUCell): 201 | for param in m.parameters(): 202 | if len(param.shape) >= 2: 203 | init.orthogonal_(param.data) 204 | else: 205 | init.normal_(param.data) 206 | 207 | 208 | class ResNet34(nn.Module): 209 | def __init__(self, n_channels, n_outputs): 210 | super(ResNet34, self).__init__() 211 | 212 | self.resnet34 = models.resnet34(pretrained=True) 213 | self.resnet34.conv1 = nn.Conv2d( 214 | n_channels, 215 | 64, 216 | kernel_size=(7, 7), 217 | stride=(2, 2), 218 | padding=(3, 3), 219 | bias=False, 220 | ) 221 | self.resnet34.fc = nn.Linear(512, n_outputs) 222 | 223 | def forward(self, x): 224 | return self.resnet34(x) 225 | 226 | 227 | class ResNet18(nn.Module): 228 | """Model to predict x and y flow from radar heatmaps.""" 229 | 230 | def __init__(self, n_channels, n_outputs): 231 | super(ResNet18, self).__init__() 232 | 233 | self.resnet18 = models.resnet18(pretrained=False) 234 | self.resnet18.conv1 = nn.Conv2d( 235 | n_channels, 236 | 64, 237 | kernel_size=(7, 7), 238 | stride=(2, 2), 239 | padding=(3, 3), 240 | bias=False, 241 | ) 242 | self.resnet18.fc = nn.Linear(512, n_outputs) 243 | 244 | # self.resnet18.layer1[0].conv1 = nn.Conv2d(64, 64, kernel_size=(3, 3), dilation=2, padding=(2,2)) 245 | # self.resnet18.layer1[0].conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), dilation=2, padding=(2,2)) 246 | # self.resnet18.layer1[1].conv1 = nn.Conv2d(64, 64, kernel_size=(3, 3), dilation=2, padding=(2,2)) 247 | # self.resnet18.layer1[1].conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), dilation=2, padding=(2,2)) 248 | 249 | # self.resnet18.layer2[0].conv1 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=2, dilation=2, padding=(2,2)) 250 | # self.resnet18.layer2[0].conv2 = nn.Conv2d(128, 128, kernel_size=(3, 3), dilation=2, padding=(2,2)) 251 | # self.resnet18.layer2[1].conv1 = nn.Conv2d(128, 128, kernel_size=(3, 3), dilation=2, padding=(2,2)) 252 | # self.resnet18.layer2[1].conv2 = nn.Conv2d(128, 128, kernel_size=(3, 3), dilation=2, padding=(2,2)) 253 | 254 | # self.resnet18.layer3[0].conv1 = nn.Conv2d(128, 256, kernel_size=(3, 3), stride=2, dilation=2, padding=(2,2)) 255 | # self.resnet18.layer3[0].conv2 = nn.Conv2d(256, 256, kernel_size=(3, 3), dilation=2, padding=(2,2)) 256 | # self.resnet18.layer3[1].conv1 = nn.Conv2d(256, 256, kernel_size=(3, 3), dilation=2, 
padding=(2,2)) 257 | # self.resnet18.layer3[1].conv2 = nn.Conv2d(256, 256, kernel_size=(3, 3), dilation=2, padding=(2,2)) 258 | 259 | # self.resnet18.layer4[0].conv1 = nn.Conv2d(256, 512, kernel_size=(3, 3), stride=2, dilation=2, padding=(2,2)) 260 | # self.resnet18.layer4[0].conv2 = nn.Conv2d(512, 512, kernel_size=(3, 3), dilation=2, padding=(2,2)) 261 | # self.resnet18.layer4[1].conv1 = nn.Conv2d(512, 512, kernel_size=(3, 3), dilation=2, padding=(2,2)) 262 | # self.resnet18.layer4[1].conv2 = nn.Conv2d(512, 512, kernel_size=(3, 3), dilation=2, padding=(2,2)) 263 | 264 | # print(self.resnet18) 265 | 266 | weight_init(self) 267 | 268 | def forward(self, x): 269 | out = self.resnet18(x) 270 | return out 271 | 272 | 273 | class ResNet50(nn.Module): 274 | """Model to predict x and y flow from radar heatmaps.""" 275 | 276 | def __init__(self, n_channels, n_outputs): 277 | super(ResNet50, self).__init__() 278 | 279 | # CNN encoder for heatmaps 280 | self.resnet50 = models.resnet50(pretrained=True) 281 | self.resnet50.conv1 = nn.Conv2d( 282 | n_channels, 283 | 64, 284 | kernel_size=(7, 7), 285 | stride=(2, 2), 286 | padding=(3, 3), 287 | bias=False, 288 | ) 289 | self.resnet50.fc = nn.Linear(2048, n_outputs) 290 | 291 | def forward(self, x): 292 | out = self.resnet50(x) 293 | return out 294 | 295 | 296 | class ResNet18Nano(nn.Module): 297 | """Model to predict x and y flow from radar heatmaps.""" 298 | 299 | def __init__(self, n_channels, n_outputs): 300 | super(ResNet18Nano, self).__init__() 301 | 302 | # CNN encoder for heatmaps 303 | resnet18 = models.resnet._resnet( 304 | "resnet18", 305 | models.resnet.BasicBlock, 306 | [1, 1, 1, 1], 307 | pretrained=False, 308 | progress=False, 309 | ) 310 | resnet18.conv1 = nn.Conv2d( 311 | n_channels, 312 | 64, 313 | kernel_size=(7, 7), 314 | stride=(2, 2), 315 | padding=(3, 3), 316 | bias=False, 317 | ) 318 | self.enc = nn.Sequential(OrderedDict(list(resnet18.named_children())[:5])) 319 | self.avgpool = resnet18.avgpool 320 | self.fc = nn.Linear(64, n_outputs) 321 | 322 | def init_weights(self): 323 | for m in self.modules(): 324 | m.apply(weight_init) 325 | 326 | def forward(self, x): 327 | x = self.enc(x) 328 | x = self.avgpool(x) 329 | x = torch.flatten(x, 1) 330 | x = self.fc(x) 331 | return x 332 | 333 | 334 | class ResNet18Micro(nn.Module): 335 | """Model to predict x and y flow from radar heatmaps.""" 336 | 337 | def __init__(self, n_channels, n_outputs): 338 | super(ResNet18Micro, self).__init__() 339 | 340 | # CNN encoder for heatmaps 341 | resnet18 = models.resnet._resnet( 342 | "resnet18", 343 | models.resnet.BasicBlock, 344 | [1, 1, 1, 1], 345 | pretrained=False, 346 | progress=False, 347 | ) 348 | resnet18.conv1 = nn.Conv2d( 349 | n_channels, 350 | 64, 351 | kernel_size=(7, 7), 352 | stride=(2, 2), 353 | padding=(3, 3), 354 | bias=False, 355 | ) 356 | self.enc = nn.Sequential(OrderedDict(list(resnet18.named_children())[:6])) 357 | self.avgpool = resnet18.avgpool 358 | self.fc = nn.Linear(128, n_outputs) 359 | 360 | def forward(self, x): 361 | x = self.enc(x) 362 | x = self.avgpool(x) 363 | x = torch.flatten(x, 1) 364 | x = self.fc(x) 365 | 366 | return x 367 | 368 | -------------------------------------------------------------------------------- /radarize/unet/dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | import glob 6 | import argparse 7 | import numpy as np 8 | import torch 9 | from tqdm import tqdm 10 | 11 | 
torch.multiprocessing.set_sharing_strategy("file_system") 12 | import cv2 13 | from torch.utils.data import Dataset, DataLoader 14 | from torchvision import datasets, transforms 15 | from torch.utils.data.dataloader import default_collate 16 | 17 | sys.path.append("..") 18 | 19 | 20 | class FlipRange(object): 21 | 22 | topics = [ 23 | "radar_r_1", 24 | "radar_r_3", 25 | "radar_r_5", 26 | "radar_re_1", 27 | "radar_re_3", 28 | "radar_re_5", 29 | "depth_map", 30 | ] 31 | 32 | def __init__(self, prob=0.5): 33 | self.prob = prob 34 | 35 | def __call__(self, sample): 36 | 37 | if torch.rand(1).item() < self.prob: 38 | for topic in self.topics: 39 | if topic in sample: 40 | sample[topic] = transforms.functional.hflip(sample[topic]) 41 | 42 | return sample 43 | 44 | 45 | class UNetDataset(Dataset): 46 | """UNet dataset.""" 47 | 48 | topics = [ 49 | "time", 50 | "pose_gt", 51 | "radar_r_1", 52 | "radar_r_3", 53 | "radar_r_5", 54 | "radar_re_1", 55 | "radar_re_3", 56 | "radar_re_5", 57 | "depth_map", 58 | ] 59 | 60 | def __init__(self, path, seq_len=1, transform=None): 61 | # Load files from .npz. 62 | self.path = path 63 | 64 | print(path) 65 | with np.load(path) as data: 66 | self.files = [k for k in data.files if k in self.topics] 67 | self.dataset = {k: data[k] for k in self.files} 68 | 69 | # Check if lengths are the same. 70 | for k in self.files: 71 | print(k, self.dataset[k].shape, self.dataset[k].dtype) 72 | lengths = [self.dataset[k].shape[0] for k in self.files] 73 | assert len(set(lengths)) == 1 74 | self.num_samples = lengths[0] 75 | 76 | # Set sequence length for stacking frames across time. 77 | self.seq_len = seq_len 78 | 79 | # Save transforms. 80 | self.transform = transform 81 | 82 | def __len__(self): 83 | return self.num_samples - self.seq_len 84 | 85 | def __getitem__(self, idx): 86 | sample = {} 87 | sample["time"] = torch.tensor(self.dataset["time"][idx : idx + self.seq_len]) 88 | sample["radar_r_1"] = torch.from_numpy( 89 | np.concatenate(self.dataset["radar_r_1"][idx : idx + self.seq_len], axis=0) 90 | ) 91 | sample["radar_r_3"] = torch.from_numpy( 92 | np.concatenate(self.dataset["radar_r_3"][idx : idx + self.seq_len], axis=0) 93 | ) 94 | sample["radar_r_5"] = torch.from_numpy( 95 | np.concatenate(self.dataset["radar_r_5"][idx : idx + self.seq_len], axis=0) 96 | ) 97 | sample["radar_re_1"] = torch.from_numpy( 98 | np.concatenate(self.dataset["radar_re_1"][idx : idx + self.seq_len], axis=0) 99 | ) 100 | sample["radar_re_3"] = torch.from_numpy( 101 | np.concatenate(self.dataset["radar_re_3"][idx : idx + self.seq_len], axis=0) 102 | ) 103 | sample["radar_re_5"] = torch.from_numpy( 104 | np.concatenate(self.dataset["radar_re_5"][idx : idx + self.seq_len], axis=0) 105 | ) 106 | sample["depth_map"] = torch.from_numpy( 107 | self.dataset["depth_map"][idx + self.seq_len] 108 | ) 109 | sample["pose_gt"] = torch.from_numpy( 110 | self.dataset["pose_gt"][idx : idx + self.seq_len] 111 | ) 112 | 113 | if self.transform: 114 | sample = self.transform(sample) 115 | 116 | return sample 117 | 118 | 119 | -------------------------------------------------------------------------------- /radarize/unet/dice_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | def dice_coeff( 8 | input: Tensor, 9 | target: Tensor, 10 | reduce_batch_first: bool = False, 11 | epsilon: float = 1e-6, 12 | ): 13 | # Average of Dice coefficient for all batches, or for a single mask 14 | 
assert input.size() == target.size() 15 | assert input.dim() == 3 or not reduce_batch_first 16 | 17 | sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3) 18 | 19 | inter = 2 * (input * target).sum(dim=sum_dim) 20 | sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim) 21 | sets_sum = torch.where(sets_sum == 0, inter, sets_sum) 22 | 23 | dice = (inter + epsilon) / (sets_sum + epsilon) 24 | return dice.mean() 25 | 26 | 27 | def multiclass_dice_coeff( 28 | input: Tensor, 29 | target: Tensor, 30 | reduce_batch_first: bool = False, 31 | epsilon: float = 1e-6, 32 | ): 33 | # Average of Dice coefficient for all classes 34 | return dice_coeff( 35 | input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon 36 | ) 37 | 38 | 39 | def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): 40 | # Dice loss (objective to minimize) between 0 and 1 41 | fn = multiclass_dice_coeff if multiclass else dice_coeff 42 | return 1 - fn(input, target, reduce_batch_first=True) 43 | 44 | -------------------------------------------------------------------------------- /radarize/unet/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True), 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.maxpool_conv(x) 39 | 40 | 41 | class Up(nn.Module): 42 | """Upscaling then double conv""" 43 | 44 | def __init__(self, in_channels, out_channels, bilinear=True): 45 | super().__init__() 46 | 47 | # if bilinear, use the normal convolutions to reduce the number of channels 48 | if bilinear: 49 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 50 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 51 | else: 52 | self.up = nn.ConvTranspose2d( 53 | in_channels, in_channels // 2, kernel_size=2, stride=2 54 | ) 55 | self.conv = DoubleConv(in_channels, out_channels) 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # input is CHW 60 | diffY = x2.size()[2] - x1.size()[2] 61 | diffX = x2.size()[3] - x1.size()[3] 62 | 63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) 64 | # if you have padding issues, see 65 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 66 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 67 | x = torch.cat([x2, x1], dim=1) 68 | return self.conv(x) 69 | 70 | 71 | class OutConv(nn.Module): 72 | def 
__init__(self, in_channels, out_channels): 73 | super(OutConv, self).__init__() 74 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 75 | 76 | def forward(self, x): 77 | return self.conv(x) 78 | 79 | 80 | class UNet(nn.Module): 81 | def __init__(self, n_channels, n_classes, bilinear=False): 82 | super(UNet, self).__init__() 83 | self.n_channels = n_channels 84 | self.n_classes = n_classes 85 | self.bilinear = bilinear 86 | 87 | self.inc = DoubleConv(n_channels, 64) 88 | self.down1 = Down(64, 128) 89 | self.down2 = Down(128, 256) 90 | self.down3 = Down(256, 512) 91 | factor = 2 if bilinear else 1 92 | self.down4 = Down(512, 1024 // factor) 93 | self.up1 = Up(1024, 512 // factor, bilinear) 94 | self.up2 = Up(512, 256 // factor, bilinear) 95 | self.up3 = Up(256, 128 // factor, bilinear) 96 | self.up4 = Up(128, 64, bilinear) 97 | self.outc = OutConv(64, n_classes) 98 | 99 | def forward(self, x): 100 | x1 = self.inc(x) 101 | x2 = self.down1(x1) 102 | x3 = self.down2(x2) 103 | x4 = self.down3(x3) 104 | x5 = self.down4(x4) 105 | x = self.up1(x5, x4) 106 | x = self.up2(x, x3) 107 | x = self.up3(x, x2) 108 | x = self.up4(x, x1) 109 | logits = self.outc(x) 110 | probs = F.softmax(logits, dim=1) 111 | return probs 112 | 113 | def use_checkpointing(self): 114 | self.inc = torch.utils.checkpoint(self.inc) 115 | self.down1 = torch.utils.checkpoint(self.down1) 116 | self.down2 = torch.utils.checkpoint(self.down2) 117 | self.down3 = torch.utils.checkpoint(self.down3) 118 | self.down4 = torch.utils.checkpoint(self.down4) 119 | self.up1 = torch.utils.checkpoint(self.up1) 120 | self.up2 = torch.utils.checkpoint(self.up2) 121 | self.up3 = torch.utils.checkpoint(self.up3) 122 | self.up4 = torch.utils.checkpoint(self.up4) 123 | self.outc = torch.utils.checkpoint(self.outc) 124 | 125 | 126 | -------------------------------------------------------------------------------- /radarize/utils/dsp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Helper functions for signal processing. 
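
Includes radar cube reshaping for RadarFrameFull messages, Bartlett and Capon
(MVDR) beamforming, and range-azimuth / doppler-azimuth heatmap generation.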
4 | """ 5 | 6 | import numpy as np 7 | import cv2 8 | from numba import njit, objmode 9 | 10 | def reshape_frame(frame, flip_ods_phase=False, flip_aop_phase=False): 11 | """Use this to reshape RadarFrameFull messages.""" 12 | 13 | platform = frame.platform 14 | adc_output_fmt = frame.adc_output_fmt 15 | rx_phase_bias = np.array( 16 | [ 17 | a + 1j * b 18 | for a, b in zip(frame.rx_phase_bias[0::2], frame.rx_phase_bias[1::2]) 19 | ] 20 | ) 21 | 22 | n_chirps = int(frame.shape[0]) 23 | rx = np.array([int(x) for x in frame.rx]) 24 | n_rx = int(frame.shape[1]) 25 | tx = np.array([int(x) for x in frame.tx]) 26 | n_tx = int(sum(frame.tx)) 27 | n_samples = int(frame.shape[2]) 28 | 29 | return _reshape_frame( 30 | np.array(frame.data), 31 | platform, 32 | adc_output_fmt, 33 | rx_phase_bias, 34 | n_chirps, 35 | rx, 36 | n_rx, 37 | tx, 38 | n_tx, 39 | n_samples, 40 | flip_ods_phase=flip_ods_phase, 41 | flip_aop_phase=flip_aop_phase, 42 | ) 43 | 44 | 45 | @njit(cache=True) 46 | def _reshape_frame( 47 | data, 48 | platform, 49 | adc_output_fmt, 50 | rx_phase_bias, 51 | n_chirps, 52 | rx, 53 | n_rx, 54 | tx, 55 | n_tx, 56 | n_samples, 57 | flip_ods_phase=False, 58 | flip_aop_phase=False, 59 | ): 60 | if adc_output_fmt > 0: 61 | 62 | radar_cube = np.zeros(len(data) // 2, dtype=np.complex64) 63 | 64 | radar_cube[0::2] = 1j * data[0::4] + data[2::4] 65 | radar_cube[1::2] = 1j * data[1::4] + data[3::4] 66 | 67 | radar_cube = radar_cube.reshape((n_chirps, n_rx, n_samples)) 68 | 69 | # Apply RX phase correction for each antenna. 70 | if "xWR68xx" in platform: 71 | if flip_ods_phase: # Apply 180 deg phase change on RX2 and RX3 72 | c = 0 73 | for i_rx, rx_on in enumerate(rx): 74 | if rx_on: 75 | if i_rx == 1 or i_rx == 2: 76 | radar_cube[:, c, :] *= -1 77 | c += 1 78 | elif flip_aop_phase: # Apply 180 deg phase change on RX1 and RX3 79 | c = 0 80 | for i_rx, rx_on in enumerate(rx): 81 | if rx_on: 82 | if i_rx == 0 or i_rx == 2: 83 | radar_cube[:, c, :] *= -1 84 | c += 1 85 | 86 | radar_cube = radar_cube.reshape((n_chirps // n_tx, n_rx * n_tx, n_samples)) 87 | 88 | # Apply RX phase correction from calibration. 
89 | c = 0 90 | for i_tx, tx_on in enumerate(tx): 91 | if tx_on: 92 | for i_rx, rx_on in enumerate(rx): 93 | if rx_on: 94 | v_rx = i_tx * len(rx) + i_rx 95 | # print(v_rx) 96 | radar_cube[:, c, :] *= rx_phase_bias[v_rx] 97 | c += 1 98 | 99 | else: 100 | radar_cube = data.reshape((n_chirps // n_tx, n_rx * n_tx, n_samples)).astype( 101 | np.complex64 102 | ) 103 | 104 | return radar_cube 105 | 106 | 107 | def reshape_frame_tdm(frame, flip_ods_phase=False): 108 | """Use this to reshape RadarFrameFull messages.""" 109 | 110 | platform = frame.platform 111 | adc_output_fmt = frame.adc_output_fmt 112 | rx_phase_bias = np.array( 113 | [ 114 | a + 1j * b 115 | for a, b in zip(frame.rx_phase_bias[0::2], frame.rx_phase_bias[1::2]) 116 | ] 117 | ) 118 | 119 | n_chirps = int(frame.shape[0]) 120 | rx = np.array([int(x) for x in frame.rx]) 121 | n_rx = int(frame.shape[1]) 122 | tx = np.array([int(x) for x in frame.tx]) 123 | n_tx = int(sum(frame.tx)) 124 | n_samples = int(frame.shape[2]) 125 | 126 | return _reshape_frame_tdm( 127 | np.array(frame.data), 128 | platform, 129 | adc_output_fmt, 130 | rx_phase_bias, 131 | n_chirps, 132 | rx, 133 | n_rx, 134 | tx, 135 | n_tx, 136 | n_samples, 137 | flip_ods_phase=flip_ods_phase, 138 | ) 139 | 140 | 141 | @njit(cache=True) 142 | def _tdm(radar_cube, n_tx, n_rx): 143 | radar_cube_tdm = np.zeros( 144 | (radar_cube.shape[0] * n_tx, radar_cube.shape[1], radar_cube.shape[2]), 145 | dtype=np.complex64, 146 | ) 147 | 148 | for i in range(n_tx): 149 | radar_cube_tdm[i::n_tx, i * n_rx : (i + 1) * n_rx] = radar_cube[ 150 | :, i * n_rx : (i + 1) * n_rx 151 | ] 152 | 153 | return radar_cube_tdm 154 | 155 | 156 | @njit(cache=True) 157 | def _reshape_frame_tdm( 158 | data, 159 | platform, 160 | adc_output_fmt, 161 | rx_phase_bias, 162 | n_chirps, 163 | rx, 164 | n_rx, 165 | tx, 166 | n_tx, 167 | n_samples, 168 | flip_ods_phase=False, 169 | ): 170 | 171 | radar_cube = _reshape_frame( 172 | data, 173 | platform, 174 | adc_output_fmt, 175 | rx_phase_bias, 176 | n_chirps, 177 | rx, 178 | n_rx, 179 | tx, 180 | n_tx, 181 | n_samples, 182 | flip_ods_phase, 183 | ) 184 | 185 | radar_cube_tdm = _tdm(radar_cube, n_tx, n_rx) 186 | 187 | return radar_cube_tdm 188 | 189 | 190 | @njit(cache=True) 191 | def get_mean(x, axis=0): 192 | return np.sum(x, axis=axis) / x.shape[axis] 193 | 194 | 195 | @njit(cache=True) 196 | def cov_matrix(x): 197 | """Calculates the spatial covariance matrix (Rxx) for a given set of input data (x=inputData). 198 | Assumes rows denote Vrx axis. 
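
    In matrix form (what the code below computes):

        Rxx = (x @ x^H) / num_adc_samples

    where x has shape (n_vrx, num_adc_samples) and ^H is the conjugate transpose.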
199 | """ 200 | 201 | _, num_adc_samples = x.shape 202 | x_T = x.T 203 | Rxx = x @ np.conjugate(x_T) 204 | Rxx = np.divide(Rxx, num_adc_samples) 205 | 206 | return Rxx 207 | 208 | @njit(cache=True) 209 | def gen_steering_vec(ang_est_range, ang_est_resolution, num_ant): 210 | """Generate a steering vector for AOA estimation given the theta range, theta resolution, and number of antennas 211 | """ 212 | num_vec = (2 * ang_est_range + 1) / ang_est_resolution + 1 213 | num_vec = int(round(num_vec)) 214 | steering_vectors = np.zeros((num_vec, num_ant), dtype="complex64") 215 | for kk in range(num_vec): 216 | for jj in range(num_ant): 217 | mag = ( 218 | -1 219 | * np.pi 220 | * jj 221 | * np.sin((-ang_est_range - 1 + kk * ang_est_resolution) * np.pi / 180) 222 | ) 223 | real = np.cos(mag) 224 | imag = np.sin(mag) 225 | 226 | steering_vectors[kk, jj] = np.complex(real, imag) 227 | 228 | return (num_vec, steering_vectors) 229 | 230 | 231 | @njit(cache=True) 232 | def aoa_bartlett(steering_vec, sig_in): 233 | """ 234 | Perform AOA estimation using Bartlett Beamforming on a given input signal (sig_in). 235 | """ 236 | n_theta = steering_vec.shape[0] 237 | n_rx = sig_in.shape[1] 238 | n_range = sig_in.shape[2] 239 | y = np.zeros((sig_in.shape[0], n_theta, n_range), dtype="complex64") 240 | for i in range(sig_in.shape[0]): 241 | y[i] = np.conjugate(steering_vec) @ sig_in[i] 242 | return y 243 | 244 | 245 | @njit(cache=True) 246 | def aoa_capon(x, steering_vector): 247 | """ 248 | Perform AOA estimation using Capon (MVDR) Beamforming on a rx by chirp slice 249 | """ 250 | 251 | Rxx = cov_matrix(x) 252 | Rxx_inv = np.linalg.inv(Rxx).astype(np.complex64) 253 | first = Rxx_inv @ steering_vector.T 254 | den = np.zeros(first.shape[1], dtype=np.complex64) 255 | steering_vector_conj = steering_vector.conj() 256 | first_T = first.T 257 | for i in range(first_T.shape[0]): 258 | for j in range(first_T.shape[1]): 259 | den[i] += steering_vector_conj[i, j] * first_T[i, j] 260 | den = np.reciprocal(den) 261 | 262 | weights = first @ den 263 | 264 | return den, weights 265 | 266 | 267 | @njit(cache=True) 268 | def compute_range_azimuth(radar_cube, angle_res=1, angle_range=90, method="apes"): 269 | 270 | n_range_bins = radar_cube.shape[2] 271 | n_rx = radar_cube.shape[1] 272 | n_chirps = radar_cube.shape[0] 273 | n_angle_bins = (angle_range * 2 + 1) // angle_res + 1 274 | 275 | range_cube = np.zeros_like(radar_cube) 276 | with objmode(range_cube="complex128[:,:,:]"): 277 | range_cube = np.fft.fft(radar_cube, axis=2) 278 | range_cube = np.transpose(range_cube, (2, 1, 0)) 279 | range_cube = np.asarray(range_cube, dtype=np.complex64) 280 | 281 | range_cube_ = np.zeros( 282 | (range_cube.shape[0], range_cube.shape[1], range_cube.shape[2]), 283 | dtype=np.complex64, 284 | ) 285 | 286 | _, steering_vec = gen_steering_vec(angle_range, angle_res, n_rx) 287 | 288 | range_azimuth = np.zeros((n_range_bins, n_angle_bins), dtype=np.complex_) 289 | for r_idx in range(n_range_bins): 290 | range_cube_[r_idx] = range_cube[r_idx] 291 | steering_vec_ = steering_vec 292 | if method == "capon": 293 | range_azimuth[r_idx, :], _ = aoa_capon(range_cube_[r_idx], steering_vec_) 294 | else: 295 | raise ValueError("Unknown method") 296 | 297 | range_azimuth = np.log(np.abs(range_azimuth)) 298 | 299 | return range_azimuth 300 | 301 | @njit(cache=True) 302 | def compute_doppler_azimuth( 303 | radar_cube, 304 | angle_res=1, 305 | angle_range=90, 306 | range_initial_bin=0, 307 | range_subsampling_factor=2, 308 | ): 309 | 310 | n_chirps = 
radar_cube.shape[0] 311 | n_rx = radar_cube.shape[1] 312 | n_samples = radar_cube.shape[2] 313 | n_angle_bins = (angle_range * 2) // angle_res + 1 314 | 315 | # Subsample range bins. 316 | radar_cube_ = radar_cube[:, :, range_initial_bin::range_subsampling_factor] 317 | radar_cube_ -= get_mean(radar_cube_, axis=0) 318 | 319 | # Doppler processing. 320 | doppler_cube = np.zeros_like(radar_cube_) 321 | with objmode(doppler_cube="complex128[:,:,:]"): 322 | doppler_cube = np.fft.fft(radar_cube_, axis=0) 323 | doppler_cube = np.fft.fftshift(doppler_cube, axes=0) 324 | doppler_cube = np.asarray(doppler_cube, dtype=np.complex64) 325 | 326 | # Azimuth processing. 327 | _, steering_vec = gen_steering_vec(angle_range, angle_res, n_rx) 328 | 329 | doppler_azimuth_cube = aoa_bartlett(steering_vec, doppler_cube) 330 | # doppler_azimuth_cube = doppler_azimuth_cube[:,:,::5] 331 | doppler_azimuth_cube -= np.expand_dims( 332 | get_mean(doppler_azimuth_cube, axis=2), axis=2 333 | ) 334 | 335 | doppler_azimuth = np.log(get_mean(np.abs(doppler_azimuth_cube) ** 2, axis=2)) 336 | 337 | return doppler_azimuth 338 | 339 | 340 | def normalize(data, min_val=None, max_val=None): 341 | """ 342 | Normalize floats to [0.0, 1.0]. 343 | """ 344 | if min_val is None: 345 | min_val = np.min(data) 346 | if max_val is None: 347 | max_val = np.max(data) 348 | img = (((data - min_val) / (max_val - min_val)).clip(0.0, 1.0)).astype(data.dtype) 349 | return img 350 | 351 | def preprocess_1d_radar_1843( 352 | radar_cube, 353 | angle_res=1, 354 | angle_range=90, 355 | range_subsampling_factor=2, 356 | min_val=10.0, 357 | max_val=None, 358 | resize_shape=(48, 48), 359 | ): 360 | """ 361 | Turn radar cube into 1d doppler-azimuth heatmap. 362 | """ 363 | 364 | heatmap = compute_doppler_azimuth( 365 | radar_cube, 366 | angle_res, 367 | angle_range, 368 | range_subsampling_factor=range_subsampling_factor, 369 | ) 370 | 371 | heatmap = normalize(heatmap, min_val=min_val, max_val=max_val) 372 | 373 | heatmap = cv2.resize(heatmap, resize_shape, interpolation=cv2.INTER_AREA) 374 | 375 | return heatmap 376 | 377 | -------------------------------------------------------------------------------- /radarize/utils/grid_map.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Useful functions for grid mapping. 
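
Includes Bresenham line rasterization, flood fill, and ray-casting occupancy
grid generation from 2D points.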
4 | """ 5 | 6 | import numpy as np 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | from numba import njit 10 | 11 | 12 | @njit(cache=True) 13 | def bresenham(start, end): 14 | """ 15 | Implementation of Bresenham's line drawing algorithm 16 | See en.wikipedia.org/wiki/Bresenham's_line_algorithm 17 | Bresenham's Line Algorithm 18 | Produces a np.array from start and end (original from roguebasin.com) 19 | >>> points1 = bresenham((4, 4), (6, 10)) 20 | >>> print(points1) 21 | np.array([[4,4], [4,5], [5,6], [5,7], [5,8], [6,9], [6,10]]) 22 | """ 23 | # setup initial conditions 24 | x1, y1 = start 25 | x2, y2 = end 26 | dx = x2 - x1 27 | dy = y2 - y1 28 | is_steep = abs(dy) > abs(dx) # determine how steep the line is 29 | if is_steep: # rotate line 30 | x1, y1 = y1, x1 31 | x2, y2 = y2, x2 32 | # swap start and end points if necessary and store swap state 33 | swapped = False 34 | if x1 > x2: 35 | x1, x2 = x2, x1 36 | y1, y2 = y2, y1 37 | swapped = True 38 | dx = x2 - x1 # recalculate differentials 39 | dy = y2 - y1 # recalculate differentials 40 | error = int(dx / 2.0) # calculate error 41 | y_step = 1 if y1 < y2 else -1 42 | # iterate over bounding box generating points between start and end 43 | y = y1 44 | points = [] 45 | for x in range(x1, x2 + 1): 46 | coord = (y, x) if is_steep else (x, y) 47 | points.append(coord) 48 | error -= abs(dy) 49 | if error < 0: 50 | y += y_step 51 | error += dx 52 | if swapped: # reverse the list if the coordinates were swapped 53 | points.reverse() 54 | points = np.array(points) 55 | return points 56 | 57 | 58 | @njit(cache=True) 59 | def flood_fill(occupancy_map, center_point, value): 60 | """ 61 | center_point: starting point (x,y) of fill 62 | occupancy_map: occupancy map generated from Bresenham ray-tracing 63 | """ 64 | # Fill empty areas with queue method 65 | sx, sy = occupancy_map.shape 66 | fringe = [] 67 | fringe.insert(0, center_point) 68 | while fringe: 69 | n = fringe.pop() 70 | nx, ny = n 71 | # West 72 | if nx > 0: 73 | if occupancy_map[nx - 1, ny] == 0.5: 74 | occupancy_map[nx - 1, ny] = value 75 | fringe.insert(0, (nx - 1, ny)) 76 | # East 77 | if nx < sx - 1: 78 | if occupancy_map[nx + 1, ny] == 0.5: 79 | occupancy_map[nx + 1, ny] = value 80 | fringe.insert(0, (nx + 1, ny)) 81 | # North 82 | if ny > 0: 83 | if occupancy_map[nx, ny - 1] == 0.5: 84 | occupancy_map[nx, ny - 1] = value 85 | fringe.insert(0, (nx, ny - 1)) 86 | # South 87 | if ny < sy - 1: 88 | if occupancy_map[nx, ny + 1] == 0.5: 89 | occupancy_map[nx, ny + 1] = value 90 | fringe.insert(0, (nx, ny + 1)) 91 | 92 | 93 | @njit(cache=True) 94 | def ray_cast(grid, start, end, value): 95 | 96 | beam = bresenham(start, end) # line 97 | 98 | x_w, y_w = grid.shape 99 | valid_x = np.logical_and(beam[:, 0] >= 0, beam[:, 0] <= x_w - 1) 100 | valid_y = np.logical_and(beam[:, 1] >= 0, beam[:, 1] <= y_w - 1) 101 | valid_mask = np.logical_and(valid_x, valid_y) 102 | valid_beam = beam[valid_mask] 103 | 104 | # grid[valid_beam] = value 105 | for pt in valid_beam: 106 | grid[pt[0], pt[1]] = value 107 | 108 | 109 | @njit(cache=True) 110 | def generate_ray_casting_grid_map(points, range_max, range_bins, hfov=39): 111 | """ 112 | The breshen boolean tells if it's computed with bresenham ray casting 113 | (True) or with flood fill (False) 114 | """ 115 | 116 | x_w = 2 * range_bins 117 | y_w = range_bins 118 | xy_resolution = range_max / range_bins 119 | min_x = -range_max 120 | max_x = range_max 121 | min_y = 0 122 | max_y = range_max 123 | center_x = range_bins 124 | center_y = 0 125 | center 
126 | 
127 |     # Initialize occupancy map.
128 |     occupancy_map = (np.ones((x_w, y_w)) * 255).astype(np.uint8)
129 |     # print((int(np.sqrt(2)*range_bins*np.sin(np.deg2rad(hfov))), \
130 |     # int(np.sqrt(2)*range_bins*np.cos(np.deg2rad(hfov)))))
131 |     ray_cast(
132 |         occupancy_map,
133 |         (center_x, center_y),
134 |         (
135 |             center_x - int(np.sqrt(2) * range_bins * np.sin(np.deg2rad(hfov))),
136 |             center_y + int(np.sqrt(2) * range_bins * np.cos(np.deg2rad(hfov))),
137 |         ),
138 |         255,
139 |     )
140 |     # cv2.namedWindow('depth', cv2.WINDOW_KEEPRATIO)
141 |     # cv2.imshow('depth', occupancy_map)
142 |     # cv2.waitKey()
143 |     # print((int(np.sqrt(2)*range_bins*np.sin(np.deg2rad(hfov))), \
144 |     # int(np.sqrt(2)*range_bins*np.cos(np.deg2rad(hfov)))))
145 |     ray_cast(
146 |         occupancy_map,
147 |         (center_x, center_y),
148 |         (
149 |             center_x + int(np.sqrt(2) * range_bins * np.sin(np.deg2rad(hfov))),
150 |             center_y + int(np.sqrt(2) * range_bins * np.cos(np.deg2rad(hfov))),
151 |         ),
152 |         255,
153 |     )
154 |     # cv2.namedWindow('depth', cv2.WINDOW_KEEPRATIO)
155 |     # cv2.imshow('depth', occupancy_map)
156 |     # cv2.waitKey()
157 |     flood_fill(occupancy_map, (center_x, y_w - 1), 255)  # unoccupied 255 (matches cells == 0.5 only, so this is a no-op on a uint8 map)
158 |     # cv2.namedWindow('depth', cv2.WINDOW_KEEPRATIO)
159 |     # cv2.imshow('depth', occupancy_map)
160 |     # cv2.waitKey()
161 | 
162 |     # Occupancy grid computed with bresenham ray casting
163 |     for p in points:
164 |         # Cull points farther than max range.
165 |         if np.linalg.norm(p) >= range_max:
166 |             continue
167 | 
168 |         x, y = p
169 |         x_, y_ = (p / np.linalg.norm(p)) * np.sqrt(2) * range_max
170 | 
171 |         # Grid indices of the detected point.
172 |         ix = int(round((x - min_x) / xy_resolution))
173 |         iy = int(round((y - min_y) / xy_resolution))
174 | 
175 |         # Grid indices of the ray extended past max range (unobserved area).
176 |         ix_ = int(round((x_ - min_x) / xy_resolution))
177 |         iy_ = int(round((y_ - min_y) / xy_resolution))
178 | 
179 |         ray_cast(occupancy_map, (ix, iy), (ix_, iy_), 127)  # unobserved 127
180 | 
181 |         # Obstacle
182 |         if ix < 1 or ix >= x_w - 1 or iy < 1 or iy >= y_w - 1:
183 |             continue
184 |         occupancy_map[ix][iy] = 0  # occupied area 0
185 |         occupancy_map[ix + 1][iy] = 0  # extend the occupied area
186 |         occupancy_map[ix][iy + 1] = 0  # extend the occupied area
187 |         occupancy_map[ix + 1][iy + 1] = 0  # extend the occupied area
188 | 
189 |     return np.rot90(occupancy_map)
190 | 
191 | 
192 | @njit(cache=True)
193 | def generate_ray_casting_polar_map(points, range_axis, angle_axis):
194 |     """
195 |     Build a polar (range x angle) occupancy map by marking the nearest
196 |     range/angle bin of each point as occupied.
197 |     """
198 | 
199 |     x_w = len(range_axis)
200 |     y_w = len(angle_axis)
201 | 
202 |     # Initialize occupancy map.
203 |     occupancy_map = np.ones((x_w, y_w)).astype(np.uint8)
204 | 
205 |     # Mark occupied bins for each point.
206 |     for p in points:
207 |         # Cull points outside the range and angle axes.
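        # (The nearest range/angle bin is found below by a brute-force argmin
        # over each axis; points outside either axis are skipped first.)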
208 | 
209 |         range_p = np.linalg.norm(p)
210 |         angle_p = -1 * np.arctan2(p[0], p[1])
211 | 
212 |         if range_p >= range_axis[-1]:
213 |             continue
214 |         if angle_p <= angle_axis[0] or angle_p >= angle_axis[-1]:
215 |             continue
216 | 
217 |         # Nearest range and angle bins for the point.
218 |         range_temp = np.zeros(x_w)
219 |         for idx in range(x_w):
220 |             range_temp[idx] = abs(range_p - range_axis[idx])
221 |         ix = np.argmin(range_temp)
222 |         angle_temp = np.zeros(y_w)
223 |         for idy in range(y_w):
224 |             angle_temp[idy] = abs(angle_p - angle_axis[idy])
225 |         iy = np.argmin(angle_temp)
226 | 
227 |         # Obstacle
228 |         if ix < 1 or ix >= x_w - 1 or iy < 1 or iy >= y_w - 1:
229 |             continue
230 |         occupancy_map[ix][iy] = 0  # occupied area 0
231 |         occupancy_map[ix + 1][iy] = 0  # extend the occupied area
232 |         occupancy_map[ix][iy + 1] = 0  # extend the occupied area
233 |         occupancy_map[ix + 1][iy + 1] = 0  # extend the occupied area
234 | 
235 |     return occupancy_map
236 | 
--------------------------------------------------------------------------------
/radarize/utils/radar_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | """Wrapper for .cfg file used to configure the radar EVM.
4 | 
5 | Also computes derived parameters from the config using get_params().
6 | """
7 | 
8 | from collections import OrderedDict
9 | import pprint
10 | 
11 | 
12 | class RadarConfig(OrderedDict):
13 |     """Container for the EVM config that gets sent through UART.
14 | 
15 |     Attributes:
16 |         cmds: List of valid commands (developed for SDK 3.5.0.4).
17 |         multi_cmds: List of commands that can appear more than once per config file.
18 |     """
19 | 
20 |     headers = [
21 |         "Created for SDK",
22 |         "Platform",
23 |     ]
24 | 
25 |     cmds = [
26 |         "dfeDataOutputMode",
27 |         "channelCfg",
28 |         "adcCfg",
29 |         "adcbufCfg",
30 |         # 'profileCfg',
31 |         # 'chirpCfg',
32 |         "lowPower",
33 |         "frameCfg",
34 |         "guiMonitor",
35 |         # 'cfarCfg',
36 |         "multiObjBeamForming",
37 |         "calibDcRangeSig",
38 |         "clutterRemoval",
39 |         "aoaFovCfg",
40 |         # 'cfarFovCfg',
41 |         "compRangeBiasAndRxChanPhase",
42 |         "measureRangeBiasAndRxChanPhase",
43 |         "extendedMaxVelocity",
44 |         "bpmCfg",
45 |         # 'CQRxSatMonitor',
46 |         # 'CQSigImgMonitor',
47 |         "analogMonitor",
48 |         "lvdsStreamCfg",
49 |         "calibData",
50 |     ]
51 | 
52 |     multi_cmds = [
53 |         "profileCfg",
54 |         "chirpCfg",
55 |         "cfarCfg",
56 |         "cfarFovCfg",
57 |         "CQRxSatMonitor",
58 |         "CQSigImgMonitor",
59 |     ]
60 | 
61 |     def __init__(self):
62 |         super(RadarConfig, self).__init__()
63 | 
64 |     def __init__(self, cfg):
65 |         """Initialize RadarConfig from a list of config commands or a dict of config commands.
66 | 
67 |         Args:
68 |             cfg: List of strings, where each string is a config command. Alternatively, a dict mapping config commands to parameters.
69 |         """
70 |         super(RadarConfig, self).__init__()
71 |         if isinstance(cfg, list):
72 |             self.from_cfg(cfg)
73 |         elif isinstance(cfg, dict):
74 |             for k, v in cfg.items():
75 |                 self[k] = v
76 | 
77 |     def from_cfg(self, cfg):
78 |         """Add commands to RadarConfig from a list of config commands.
79 | 
80 |         Args:
81 |             cfg: List of strings, where each string is a config command.
82 |         """
83 | 
84 |         for line in cfg:
85 |             for hdr in RadarConfig.headers:
86 |                 if hdr in line:
87 |                     params = line.split(":")[1].strip()
88 |                     self[hdr] = params
89 |                     break
90 | 
91 |             for cmd in RadarConfig.cmds:
92 |                 if cmd in line:
93 |                     params = [
94 |                         float(x) if "." in x else int(x) for x in line.split()[1:]
95 |                     ]
96 | 
97 |                     self[cmd] = params
98 |                     break
99 | 
100 |             for cmd in RadarConfig.multi_cmds:
101 |                 if cmd in line:
102 |                     params = [
103 |                         float(x) if "." in x else int(x) for x in line.split()[1:]
104 |                     ]
105 | 
106 |                     if cmd not in self.keys():
107 |                         self[cmd] = [params]
108 |                     else:
109 |                         self[cmd].append(params)
110 |                     break
111 | 
112 |     def to_cfg(self):
113 |         """Convert RadarConfig into a list of config commands.
114 | 
115 |         Returns:
116 |             cfg: List of strings, where each string is a config command.
117 |         """
118 |         cfg = []
119 | 
120 |         for cmd, params in self.items():
121 |             if isinstance(params[0], list):
122 |                 # Multi cmd case.
123 |                 for param in params:
124 |                     cfg.append(
125 |                         " ".join(
126 |                             [cmd]
127 |                             + [
128 |                                 f"{x:.2f}" if type(x) is float else f"{x:}"
129 |                                 for x in param
130 |                             ]
131 |                         )
132 |                     )
133 |             else:
134 |                 cfg.append(
135 |                     " ".join(
136 |                         [cmd]
137 |                         + [f"{x:.2f}" if type(x) is float else f"{x:}" for x in params]
138 |                     )
139 |                 )
140 | 
141 |         return cfg
142 | 
143 |     def get_params(self):
144 |         """Returns number of samples, rx, tx, chirps, frame size, frame time, etc."""
145 | 
146 |         sdk = self["Created for SDK"]
147 |         platform = self["Platform"]
148 | 
149 |         adc_output_fmt = int(
150 |             self["adcCfg"][1]
151 |         )  # 0 - real, 1 - complex 1x, 2 - complex 2x
152 | 
153 |         n_samples = int(self["profileCfg"][0][9])
154 | 
155 |         rx_str = self["channelCfg"][0]
156 |         rx_bin = bin(int(rx_str))[2:]
157 |         rx = [int(x) for x in reversed(rx_bin)]
158 | 
159 |         tx_str = self["channelCfg"][1]
160 |         tx_bin = bin(int(tx_str))[2:]
161 |         tx = [int(x) for x in reversed(tx_bin)]
162 | 
163 |         n_chirps = (int(self["frameCfg"][1]) - int(self["frameCfg"][0]) + 1) * self[
164 |             "frameCfg"
165 |         ][2]
166 | 
167 |         n_tx = sum(tx)
168 |         n_rx = sum(rx)
169 | 
170 |         frame_size = n_samples * n_rx * n_chirps * 2 * (2 if adc_output_fmt > 0 else 1)
171 |         frame_time = self["frameCfg"][4]
172 | 
173 |         range_bias = self["compRangeBiasAndRxChanPhase"][0]
174 |         # rx_phase_bias = [a + 1j*b for a,b in zip(self['compRangeBiasAndRxChanPhase'][1::2],
175 |         #                                         self['compRangeBiasAndRxChanPhase'][2::2])]
176 |         rx_phase_bias = self["compRangeBiasAndRxChanPhase"][1:]
177 | 
178 |         operating_freq = self["profileCfg"][0][1]  # Units in GHz
179 |         chirp_time = (
180 |             self["profileCfg"][0][2] + self["profileCfg"][0][4]
181 |         )  # Idle time + ramp time, Units in usec
182 |         velocity_max = (3e8 / (operating_freq * 1e9)) / (
183 |             4 * (chirp_time * 1e-6)
184 |         )  # Units in m/s
185 |         velocity_res = velocity_max / n_chirps  # Units in m/s
186 | 
187 |         chirp_slope = self["profileCfg"][0][7] * 1e12  # MHz/usec converted to Hz/s
188 |         sample_rate = self["profileCfg"][0][10] * 1e3  # ksps converted to sps
189 |         range_max = (sample_rate * 3e8) / (2 * chirp_slope)  # Units in m
190 |         range_res = range_max / n_samples  # Units in m
191 | 
192 |         return OrderedDict(
193 |             [
194 |                 ("sdk", sdk),
195 |                 ("platform", platform),
196 |                 ("adc_output_fmt", adc_output_fmt),
197 |                 ("range_bias", range_bias),
198 |                 ("rx_phase_bias", rx_phase_bias),
199 |                 ("n_chirps", n_chirps),
200 |                 ("rx", rx),
201 |                 ("n_rx", n_rx),
202 |                 ("tx", tx),
203 |                 ("n_tx", n_tx),
204 |                 ("n_samples", n_samples),
205 |                 ("frame_size", frame_size),
206 |                 ("frame_time", frame_time),
207 |                 ("chirp_time", chirp_time),
208 |                 ("chirp_slope", chirp_slope),
209 |                 ("sample_rate", sample_rate),
210 |                 ("velocity_max", velocity_max),
211 |                 ("velocity_res", velocity_res),
212 |                 ("range_max", range_max),
213 |                 ("range_res", range_res),
214 |             ]
215 |         )
216 | 
217 | 
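A minimal usage sketch for RadarConfig (illustrative, not part of the module;
it assumes the bundled calib/1843/1843_v1.cfg uses the usual TI format of one
command per line, with comment lines starting with "%"):

    from radarize.utils.radar_config import RadarConfig

    with open("calib/1843/1843_v1.cfg") as f:
        lines = [ln.strip() for ln in f if ln.strip() and not ln.startswith("%")]

    config = RadarConfig(lines)
    params = config.get_params()
    print(params["n_rx"], params["n_chirps"], params["range_res"], params["velocity_max"])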
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.5 2 | scipy==1.10.0 3 | matplotlib==3.6.3 4 | pyqt5==5.15.10 5 | numba==0.56.4 6 | opencv-python==4.6.0.66 7 | tqdm 8 | pyyaml 9 | yacs 10 | easydict 11 | pillow==9.4.0 12 | imageio==2.25.0 13 | imageio-ffmpeg==0.4.8 14 | pytransform3d==2.2.0 15 | py3rosmsgs==1.18.2 16 | rosbags==0.9.13 17 | bagpy==0.5 18 | --extra-index-url https://rospypi.github.io/simple/ 19 | roslz4 20 | evo==1.21.0 21 | cvbridge3==1.1 22 | open3d==0.16.0 23 | pycryptodome==3.17 24 | einops==0.6.0 25 | torch-summary==1.4.5 26 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | for file in configs/*.yaml; do 6 | if [ "$file" != "configs/default.yaml" ]; then 7 | echo "Running $file" 8 | python main.py --cfg="$file" 9 | fi 10 | done 11 | -------------------------------------------------------------------------------- /run_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | for file in configs/*.yaml; do 6 | if [ "$file" != "configs/default.yaml" ]; then 7 | echo "Running $file" 8 | python main_eval.py --cfg="$file" 9 | fi 10 | done 11 | 12 | # Get odometry results. 13 | ./odom_eval.sh 14 | 15 | # Get SLAM results. 16 | ./slam_eval.sh 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='radarize', 5 | version='1.0', 6 | packages=find_packages(include=['radarize']) 7 | ) 8 | -------------------------------------------------------------------------------- /slam_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | > log 6 | for d in "gt_gt_default" "gt_radarhd_scan_only" "rnin_radarhd_default" "milliego_radarhd_default" "odometry_radarhd_radar" "odometry_unet_radar" 7 | do 8 | for f in "main_0" "main_1" "main_2" "main_3" "main_4" 9 | do 10 | ./tools/eval_traj.py --cfg=configs/$f.yaml --input=$d/output | tee -a log 11 | done 12 | done 13 | 14 | > slam_result.txt 15 | for d in "gt_radarhd_scan_only" "rnin_radarhd_default" "milliego_radarhd_default" "odometry_radarhd_radar" "odometry_unet_radar" 16 | do 17 | for f in "ape_trans" "ape_rot" "rpe_trans" "rpe_rot" 18 | do 19 | average=$(echo "($(cat log | grep $d | grep $f | cut -d' ' -f4 | paste -s -d+))/5" | bc -l) 20 | echo "$d $f ${average}" | tee -a slam_result.txt 21 | done 22 | done 23 | -------------------------------------------------------------------------------- /tools/eval_traj.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Compare trajectories 5 | https://github.com/MichaelGrupp/evo/blob/master/notebooks/metrics.py_API_Documentation.ipynb 6 | """ 7 | 8 | import os 9 | import sys 10 | 11 | import pprint 12 | from evo.core import metrics 13 | from evo.tools import file_interface 14 | from evo.core import sync 15 | import copy 16 | import numpy as np 17 | import argparse 18 | 19 | from evo.tools import plot 20 | import matplotlib.pyplot as plt 21 | 22 | from radarize.config import cfg, 
update_config 23 | 24 | def args(): 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument( 28 | "--cfg", help="experiment configure file name", default=None, type=str 29 | ) 30 | parser.add_argument( 31 | "opts", 32 | help="Modify config options using the command-line", 33 | default=None, 34 | nargs=argparse.REMAINDER, 35 | ) 36 | parser.add_argument( 37 | "--input", help="Directory name __.", required=True 38 | ) 39 | args = parser.parse_args() 40 | 41 | return args 42 | 43 | 44 | def load_trajs(ref_file, est_file): 45 | traj_ref = file_interface.read_tum_trajectory_file(ref_file) 46 | traj_est = file_interface.read_tum_trajectory_file(est_file) 47 | 48 | max_diff = 0.01 49 | 50 | # Sync trajectories. 51 | traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est, max_diff) 52 | # Align trajectories w/o scaling. 53 | traj_est_aligned = copy.deepcopy(traj_est) 54 | traj_est_aligned.align(traj_ref, correct_scale=False, correct_only_scale=False) 55 | # Align origin. 56 | traj_est_aligned_origin = copy.deepcopy(traj_est) 57 | traj_est_aligned_origin.align_origin(traj_ref) 58 | 59 | # fig = plt.figure() 60 | # traj_by_label = { 61 | # "estimate (align origin)": traj_est_aligned_origin, 62 | # "estimate (aligned)": traj_est_aligned, 63 | # "reference": traj_ref 64 | # } 65 | # plot.trajectories(fig, traj_by_label, plot.PlotMode.xy) 66 | # plt.savefig(est_file.replace('/output/', '/result/').replace('.txt', '.png')) 67 | # plt.close(fig) 68 | 69 | data = (traj_ref, traj_est_aligned) 70 | 71 | return data 72 | 73 | 74 | # return mean 75 | def get_stat(metric_type, pose_relation_type, data): 76 | metric_types = ["ape", "rpe"] 77 | pose_relation_types = ["translation", "rotation_angle", "rotation", "full"] 78 | if metric_type not in metric_types: 79 | raise ValueError("Invalid metric type. Expected one of: %s" % metric_types) 80 | if pose_relation_type not in pose_relation_types: 81 | raise ValueError( 82 | "Invalid pose_relation type. Expected one of: %s" % pose_relation_types 83 | ) 84 | 85 | if pose_relation_type == "translation": 86 | pose_relation = metrics.PoseRelation.translation_part 87 | if pose_relation_type == "rotation_angle": 88 | pose_relation = metrics.PoseRelation.rotation_angle_rad 89 | if pose_relation_type == "rotation": 90 | pose_relation = metrics.PoseRelation.rotation_part 91 | if pose_relation_type == "full": 92 | pose_relation = metrics.PoseRelation.full_transformation 93 | 94 | if metric_type == "ape": 95 | metric = metrics.APE(pose_relation) 96 | if metric_type == "rpe": 97 | delta = 1 98 | delta_unit = metrics.Unit.frames 99 | all_pairs = False 100 | metric = metrics.RPE( 101 | pose_relation=pose_relation, 102 | delta=delta, 103 | delta_unit=delta_unit, 104 | all_pairs=all_pairs, 105 | ) 106 | 107 | metric.process_data(data) 108 | return metric.get_statistic(metrics.StatisticsType.mean) 109 | 110 | 111 | if __name__ == "__main__": 112 | args = args() 113 | update_config(cfg, args) 114 | 115 | ape_translation = np.array([]) 116 | ape_rotation_angle = np.array([]) 117 | ape_rotation = np.array([]) 118 | ape_full = np.array([]) 119 | rpe_translation = np.array([]) 120 | rpe_rotation_angle = np.array([]) 121 | rpe_rotation = np.array([]) 122 | rpe_full = np.array([]) 123 | 124 | # Create dir. 
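    # Metrics and plots are written to <OUTPUT_DIR>/<input>/result
    # (per-metric .png bar charts plus traj_eval.npz).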
125 |     result_dir = os.path.join(cfg["OUTPUT_DIR"], args.input, "result")
126 |     if not os.path.exists(result_dir):
127 |         os.makedirs(result_dir)
128 | 
129 |     for x in cfg["DATASET"]["TEST_SPLIT"]:
130 |         ref_file = os.path.join(
131 |             cfg["OUTPUT_DIR"], "gt_gt_default", "output", f"{x}.txt"
132 |         )
133 |         est_file = os.path.join(cfg["OUTPUT_DIR"], args.input, f"{x}.txt")
134 |         data = load_trajs(ref_file, est_file)
135 |         ape_translation = np.append(
136 |             ape_translation, get_stat("ape", "translation", data)
137 |         )
138 |         rpe_translation = np.append(
139 |             rpe_translation, get_stat("rpe", "translation", data)
140 |         )
141 |         ape_rotation_angle = np.append(
142 |             ape_rotation_angle, get_stat("ape", "rotation_angle", data)
143 |         )
144 |         rpe_rotation_angle = np.append(
145 |             rpe_rotation_angle, get_stat("rpe", "rotation_angle", data)
146 |         )
147 | 
148 |     np.savez(
149 |         os.path.join(cfg["OUTPUT_DIR"], args.input, "result", "traj_eval.npz"),
150 |         ape_trans=ape_translation,
151 |         rpe_trans=rpe_translation,
152 |         ape_rot=ape_rotation_angle,
153 |         rpe_rot=rpe_rotation_angle,
154 |     )
155 | 
156 |     print(f"{cfg['OUTPUT_DIR']}{args.input} ape_trans: ", np.mean(ape_translation))
157 |     print(f"{cfg['OUTPUT_DIR']}{args.input} rpe_trans: ", np.mean(rpe_translation))
158 |     print(f"{cfg['OUTPUT_DIR']}{args.input} ape_rot: ", np.mean(ape_rotation_angle))
159 |     print(f"{cfg['OUTPUT_DIR']}{args.input} rpe_rot: ", np.mean(rpe_rotation_angle))
160 | 
161 |     if args.input != "gt_gt_default":
162 |         gt_gt_default = np.load(
163 |             os.path.join(
164 |                 cfg["OUTPUT_DIR"], "gt_gt_default/output", "result", "traj_eval.npz"
165 |             )
166 |         )
167 | 
168 |     # Plot APE and RPE tables.
169 | 
170 |     fig = plt.figure()
171 |     # Make sure to sort the x values before plotting.
172 |     x = np.array(cfg["DATASET"]["TEST_SPLIT"])
173 |     idxs = np.argsort(x)
174 |     x = x[idxs]
175 |     plt.bar(x, ape_translation[idxs])
176 |     if args.input != "gt_gt_default":
177 |         plt.bar(x, gt_gt_default["ape_trans"][idxs])
178 |     plt.xlabel("Dataset")
179 |     plt.ylabel("APE translation")
180 |     ax = plt.gca()
181 |     plt.draw()
182 |     ax.set_xticks(ax.get_xticks(), x, rotation=45, ha="right")
183 |     spacing = 0.5
184 |     fig.subplots_adjust(bottom=spacing)
185 |     plt.title(f"{args.input} mAPE: {np.mean(ape_translation):.3f}")
186 |     plt.savefig(os.path.join(cfg["OUTPUT_DIR"], args.input, "result", "ape_trans.png"))
187 |     plt.close(fig)
188 | 
189 |     fig = plt.figure()
190 |     # Make sure to sort the x values before plotting.
191 |     x = np.array(cfg["DATASET"]["TEST_SPLIT"])
192 |     idxs = np.argsort(x)
193 |     x = x[idxs]
194 |     plt.bar(x, rpe_translation[idxs])
195 |     if args.input != "gt_gt_default":
196 |         plt.bar(x, gt_gt_default["rpe_trans"][idxs])
197 |     plt.xlabel("Dataset")
198 |     plt.ylabel("RPE translation")
199 |     ax = plt.gca()
200 |     plt.draw()
201 |     ax.set_xticks(ax.get_xticks(), x, rotation=45, ha="right")
202 |     spacing = 0.5
203 |     fig.subplots_adjust(bottom=spacing)
204 |     plt.title(f"{args.input} mRPE: {np.mean(rpe_translation):.3f}")
205 |     plt.savefig(os.path.join(cfg["OUTPUT_DIR"], args.input, "result", "rpe_trans.png"))
206 |     plt.close(fig)
207 | 
208 |     fig = plt.figure()
209 |     # Make sure to sort the x values before plotting.
210 | x = np.array(cfg["DATASET"]["TEST_SPLIT"]) 211 | idxs = np.argsort(x) 212 | x = x[idxs] 213 | plt.bar(x, ape_rotation_angle[idxs]) 214 | if args.input != "gt_gt_default": 215 | plt.bar(x, gt_gt_default["ape_rot"][idxs]) 216 | plt.xlabel("Dataset") 217 | plt.ylabel("APE rotation angle") 218 | ax = plt.gca() 219 | plt.draw() 220 | ax.set_xticks(ax.get_xticks(), x, rotation=45, ha="right") 221 | spacing = 0.5 222 | fig.subplots_adjust(bottom=spacing) 223 | plt.title(f"{args.input} mAPE: {np.mean(ape_rotation_angle):.3f}") 224 | plt.savefig(os.path.join(cfg["OUTPUT_DIR"], args.input, "result", "ape_rot.png")) 225 | plt.close(fig) 226 | 227 | fig = plt.figure() 228 | # Make sure to sort the x values before plotting. 229 | x = np.array(cfg["DATASET"]["TEST_SPLIT"]) 230 | idxs = np.argsort(x) 231 | x = x[idxs] 232 | plt.bar(x, rpe_rotation_angle[idxs]) 233 | if args.input != "gt_gt_default": 234 | plt.bar(x, gt_gt_default["rpe_rot"][idxs]) 235 | plt.xlabel("Dataset") 236 | plt.ylabel("RPE rotation angle") 237 | ax = plt.gca() 238 | plt.draw() 239 | ax.set_xticks(ax.get_xticks(), x, rotation=45, ha="right") 240 | spacing = 0.5 241 | fig.subplots_adjust(bottom=spacing) 242 | plt.title(f"{args.input} mRPE: {np.mean(rpe_rotation_angle):.3f}") 243 | plt.savefig(os.path.join(cfg["OUTPUT_DIR"], args.input, "result", "rpe_rot.png")) 244 | plt.close(fig) 245 | -------------------------------------------------------------------------------- /tools/export_cartographer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Creates input .bags to feed into Cartographer. 5 | """ 6 | 7 | import os 8 | import sys 9 | 10 | import matplotlib 11 | import numpy as np 12 | 13 | matplotlib.use("agg") 14 | import argparse 15 | from collections import defaultdict 16 | 17 | import rosbag 18 | from rospy import Time 19 | from nav_msgs.msg import Odometry 20 | from scipy.interpolate import interp1d 21 | from scipy.spatial.transform import Rotation as R 22 | from scipy.spatial.transform import Slerp 23 | from sensor_msgs.msg import LaserScan 24 | from tqdm import tqdm 25 | 26 | from radarize.config import cfg, update_config 27 | 28 | def heatmap2range(heatmap, range_bins, range_min, range_max): 29 | """Convert range-azimuth heatmap to range scan.""" 30 | 31 | # Assign range value to each angle bin 32 | range_value = np.linspace(range_min, range_max, range_bins) 33 | range_scan = range_value[np.argmin(heatmap, axis=0)] 34 | 35 | # Filter angle bins without collision 36 | range_scan[np.min(heatmap, axis=0) > 0.85] = range_max + 1 37 | 38 | return range_scan 39 | 40 | 41 | def sensorDataPreprocessing(odom_path, scan_path, output_path): 42 | # Load odom from TUM format and scans 43 | odom_data = np.loadtxt(odom_path) 44 | scan_data = np.load(scan_path, allow_pickle=True) 45 | 46 | odom_times = odom_data[:, 0] 47 | odom_poses = odom_data[:, 1:] 48 | scan_times = scan_data["time"] 49 | scan_heatmaps = scan_data["depth_map"] 50 | # Trim scan to match odom 51 | valid_idx = np.logical_and(scan_times > odom_times[0], scan_times < odom_times[-1]) 52 | scan_times = scan_times[valid_idx] 53 | scan_heatmaps = scan_heatmaps[valid_idx] 54 | 55 | print(odom_times[0], odom_times[-1]) 56 | print(scan_times[0], scan_times[-1]) 57 | print("odom_times", odom_times.shape) 58 | print("odom_poses", odom_poses.shape) 59 | print("scan_times", scan_times.shape) 60 | print("scan_heatmaps", scan_heatmaps.shape) 61 | 62 | # Interpolate odom to match scan times 63 | 
odom_p_interp = interp1d(
64 |         odom_times, odom_poses[:, :3], kind="linear", axis=0, fill_value="extrapolate"
65 |     )(scan_times)
66 |     odom_q_interp = Slerp(odom_times, R.from_quat(odom_poses[:, 3:]))(scan_times)
67 | 
68 |     # Write to bag
69 |     with rosbag.Bag(output_path, "w") as bag:
70 |         last_ts = Time.from_sec(0.0)
71 |         for i in tqdm(range(scan_heatmaps.shape[0])):
72 | 
73 |             ts = Time.from_sec(scan_times[i])
74 |             if ts < last_ts:
75 |                 print("Warning: timestamp is not increasing")
76 |             last_ts = ts
77 |             # Create laser scan message
78 |             range_bins, angle_bins = scan_heatmaps[i].shape[1:]
79 |             range_msg = LaserScan()
80 |             range_msg.header.stamp = ts
81 |             range_msg.header.frame_id = "horizontal_laser_link"
82 |             range_msg.angle_min = np.deg2rad(cfg["DATASET"]["RA"]["RA_MIN"])
83 |             range_msg.angle_max = np.deg2rad(cfg["DATASET"]["RA"]["RA_MAX"])
84 |             range_msg.angle_increment = np.deg2rad(
85 |                 (cfg["DATASET"]["RA"]["RA_MAX"] - cfg["DATASET"]["RA"]["RA_MIN"])
86 |                 / angle_bins
87 |             )
88 |             range_msg.time_increment = 0.0
89 |             range_msg.scan_time = 33e-3
90 |             range_msg.range_min = cfg["DATASET"]["RA"]["RR_MIN"]
91 |             range_msg.range_max = cfg["DATASET"]["RA"]["RR_MAX"]
92 |             range_msg.ranges = heatmap2range(
93 |                 np.squeeze(scan_heatmaps[i]),
94 |                 range_bins,
95 |                 cfg["DATASET"]["RA"]["RR_MIN"],
96 |                 cfg["DATASET"]["RA"]["RR_MAX"],
97 |             )
98 | 
99 |             bag.write("scan", range_msg, ts)
100 | 
101 |             # Create odometry message
102 |             odom_msg = Odometry()
103 |             odom_msg.header.stamp = ts
104 |             odom_msg.header.frame_id = "map"
105 |             odom_msg.child_frame_id = "base_link"
106 |             odom_msg.pose.pose.position.x = odom_p_interp[i, 0]
107 |             odom_msg.pose.pose.position.y = odom_p_interp[i, 1]
108 |             odom_msg.pose.pose.position.z = odom_p_interp[i, 2]
109 |             odom_msg.pose.pose.orientation.x = odom_q_interp[i].as_quat()[0]
110 |             odom_msg.pose.pose.orientation.y = odom_q_interp[i].as_quat()[1]
111 |             odom_msg.pose.pose.orientation.z = odom_q_interp[i].as_quat()[2]
112 |             odom_msg.pose.pose.orientation.w = odom_q_interp[i].as_quat()[3]
113 | 
114 |             bag.write("odom", odom_msg, ts)
115 | 
116 | 
117 | def args():
118 |     parser = argparse.ArgumentParser()
119 | 
120 |     parser.add_argument(
121 |         "--cfg", help="experiment configure file name", default=None, type=str
122 |     )
123 |     parser.add_argument(
124 |         "opts",
125 |         help="Modify config options using the command-line",
126 |         default=None,
127 |         nargs=argparse.REMAINDER,
128 |     )
129 |     parser.add_argument("--odom_path", help="Path to TUM odometry file.", required=True)
130 |     parser.add_argument("--scan_path", help="Path to range scans.", required=True)
131 |     parser.add_argument(
132 |         "--output_path", help="Path to the output .bag file.", required=True
133 |     )
134 |     args = parser.parse_args()
135 | 
136 |     return args
137 | 
138 | 
139 | if __name__ == "__main__":
140 |     args = args()
141 |     update_config(cfg, args)
142 | 
143 |     print(f"Processing odom path {args.odom_path} + scan path {args.scan_path}...")
144 | 
145 |     sensorDataPreprocessing(args.odom_path, args.scan_path, args.output_path)
146 | 
--------------------------------------------------------------------------------
/tools/extract_gt.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | """
4 | Extract ground-truth trajectories and depth maps from dataset .npz files.
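Reads the keys {time, pose_gt, depth_map} from each .npz and writes a
TUM-format trajectory (.txt) plus a depth-map .npz under <OUTPUT_DIR>/gt.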
5 | """ 6 | 7 | import os 8 | import sys 9 | 10 | import argparse 11 | 12 | import numpy as np 13 | np.set_printoptions(precision=3, floatmode="fixed", sign=" ") 14 | 15 | from radarize.config import cfg, update_config 16 | 17 | def args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | "--cfg", help="experiment configure file name", default=None, type=str 21 | ) 22 | parser.add_argument("--npz_path", help="Path to npz.", default=None, required=False) 23 | parser.add_argument( 24 | "opts", 25 | help="Modify config options using the command-line", 26 | default=None, 27 | nargs=argparse.REMAINDER, 28 | ) 29 | args = parser.parse_args() 30 | 31 | return args 32 | 33 | 34 | if __name__ == "__main__": 35 | args = args() 36 | update_config(cfg, args) 37 | 38 | # Create output dir. 39 | out_dir = os.path.join(os.path.join(cfg["OUTPUT_DIR"], "gt")) 40 | if not os.path.exists(out_dir): 41 | os.makedirs(out_dir, exist_ok=True) 42 | 43 | # Get list of npz files. 44 | if args.npz_path: 45 | npz_paths = [args.npz_path] 46 | else: 47 | npz_paths = sorted( 48 | glob.glob(os.path.join(cfg["DATASET"]["TEST_PATH"], "*.npz")) 49 | ) 50 | 51 | for npz_path in npz_paths: 52 | print(f"Processing {npz_path}...") 53 | # Load npz. 54 | with np.load(npz_path) as npz: 55 | data = { 56 | k: npz[k] for k in npz.files if k in ["time", "pose_gt", "depth_map"] 57 | } 58 | 59 | for k, v in data.items(): 60 | print(f"{k}: {v.shape}") 61 | 62 | basename = os.path.basename(npz_path) 63 | 64 | # Save trajectory. 65 | trajectory = np.concatenate( 66 | [ 67 | data["time"].reshape(-1, 1), 68 | data["pose_gt"], 69 | ], 70 | axis=1, 71 | ) 72 | np.savetxt( 73 | os.path.join(out_dir, basename.replace(".npz", ".txt")), 74 | trajectory, 75 | delimiter=" ", 76 | ) 77 | 78 | # Save depth map. 79 | np.savez( 80 | os.path.join(out_dir, basename), 81 | time=data["time"], 82 | depth_map=data["depth_map"], 83 | ) 84 | -------------------------------------------------------------------------------- /tools/odombag_to_txt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Extract trajectory from .bag file. 
5 | """ 6 | 7 | import os 8 | import sys 9 | import rosbag 10 | 11 | import argparse 12 | 13 | import numpy as np 14 | from tqdm import tqdm 15 | 16 | np.set_printoptions(precision=3, floatmode="fixed", sign=" ") 17 | 18 | 19 | def args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--bag_path", help="Path to .bag odometry file.", default=None, required=False 23 | ) 24 | parser.add_argument( 25 | "opts", 26 | help="Modify config options using the command-line", 27 | default=None, 28 | nargs=argparse.REMAINDER, 29 | ) 30 | args = parser.parse_args() 31 | 32 | return args 33 | 34 | 35 | def extract_msg(bag): 36 | 37 | pose_topic = "trajectory_0" 38 | 39 | pose_ts, pose_msgs = [], [] 40 | last_ts = None 41 | 42 | for topic, msg, ts in tqdm( 43 | bag.read_messages([pose_topic]), total=bag.get_message_count([pose_topic]) 44 | ): 45 | curr_ts = ts.secs + 1e-9 * ts.nsecs 46 | if last_ts is None: 47 | last_ts = curr_ts 48 | continue 49 | elif curr_ts - last_ts < 33e-3: 50 | continue 51 | else: 52 | pose_msgs.append( 53 | np.array( 54 | [ 55 | msg.transform.translation.x, 56 | msg.transform.translation.y, 57 | msg.transform.translation.z, 58 | msg.transform.rotation.x, 59 | msg.transform.rotation.y, 60 | msg.transform.rotation.z, 61 | msg.transform.rotation.w, 62 | ] 63 | ) 64 | ) 65 | pose_ts.append(curr_ts) 66 | last_ts = curr_ts 67 | 68 | pose_ts = np.array(pose_ts) 69 | pose_msgs = np.array(pose_msgs) 70 | return pose_ts, pose_msgs 71 | 72 | 73 | if __name__ == "__main__": 74 | args = args() 75 | 76 | print(f"Processing {args.bag_path}...") 77 | 78 | # Open bag file. 79 | bag = rosbag.Bag(args.bag_path) 80 | 81 | pose_ts, pose_msgs = extract_msg(bag) 82 | 83 | # Save trajectory. 84 | trajectory = np.concatenate( 85 | [ 86 | pose_ts.reshape(-1, 1), 87 | pose_msgs, 88 | ], 89 | axis=1, 90 | ) 91 | np.savetxt( 92 | os.path.join(args.bag_path.replace(".bag", ".txt")), trajectory, delimiter=" " 93 | ) 94 | -------------------------------------------------------------------------------- /tools/run_carto.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Run Cartographer on input bag. 
5 | """ 6 | 7 | import argparse 8 | import glob 9 | import multiprocessing 10 | import os 11 | import subprocess 12 | import sys 13 | import os.path as osp 14 | 15 | from subprocess import Popen 16 | from radarize.config import cfg, update_config 17 | 18 | def args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | "--cfg", 22 | help="experiment configure file name", 23 | default="configs/default.yaml", 24 | type=str, 25 | ) 26 | parser.add_argument( 27 | "--n_proc", 28 | type=int, 29 | default=4, 30 | help="Number of processes to use for parallel processing.", 31 | ) 32 | parser.add_argument( 33 | "opts", 34 | help="Modify config options using the command-line", 35 | default=None, 36 | nargs=argparse.REMAINDER, 37 | ) 38 | parser.add_argument("--odom", help="Odometry type.", required=True) 39 | parser.add_argument("--scan", help="Scan type.", required=True) 40 | parser.add_argument("--params", help="Cartographer parameters.", required=True) 41 | parser.add_argument("--demo", action="store_true", help="Turn Rviz on") 42 | args = parser.parse_args() 43 | 44 | return args 45 | 46 | 47 | def run_commands(cmds, n_proc): 48 | print("Commands are: ") 49 | print(cmds) 50 | with multiprocessing.Pool(n_proc) as pool: 51 | pool.map(subprocess.run, cmds) 52 | 53 | 54 | if __name__ == "__main__": 55 | args = args() 56 | update_config(cfg, args) 57 | 58 | print( 59 | f"Running cartographer on Odometry {args.odom} and Scan {args.scan} with Params {args.params}..." 60 | ) 61 | 62 | # Create dir. 63 | carto_in_dir = osp.abspath(osp.join( 64 | cfg["OUTPUT_DIR"], args.odom + "_" + args.scan + "_" + args.params, "input" 65 | )) 66 | carto_out_dir = osp.abspath(osp.join( 67 | cfg["OUTPUT_DIR"], args.odom + "_" + args.scan + "_" + args.params, "output" 68 | )) 69 | if not osp.exists(carto_in_dir): 70 | os.makedirs(carto_in_dir) 71 | if not osp.exists(carto_out_dir): 72 | os.makedirs(carto_out_dir) 73 | 74 | # Assume these dir exist, and .txt are all TUM files 75 | odom_dir = osp.abspath(osp.join(cfg["OUTPUT_DIR"], args.odom)) 76 | scan_dir = osp.abspath(osp.join(cfg["OUTPUT_DIR"], args.scan)) 77 | 78 | # Convert odom + scans into bags. 79 | run_commands( 80 | [ 81 | [ 82 | "tools/export_cartographer.py", 83 | f"--cfg={args.cfg}", 84 | f'--odom_path={osp.abspath(osp.join(odom_dir, x+".txt"))}', 85 | f'--scan_path={osp.abspath(osp.join(scan_dir, x+".npz"))}', 86 | f'--output_path={osp.abspath(osp.join(carto_in_dir, x+".bag"))}', 87 | ] 88 | for x in cfg["DATASET"]["TEST_SPLIT"] 89 | ], 90 | args.n_proc, 91 | ) 92 | 93 | # Run cartographer on all test bags. 94 | for x in cfg["DATASET"]["TEST_SPLIT"]: 95 | if args.demo: 96 | subprocess.run( 97 | [ 98 | "roslaunch", 99 | "cartographer_ros", 100 | "demo_backpack_2d.launch", 101 | f"configuration_basename:={args.params}.lua", 102 | f'bag_filename:={osp.abspath(osp.join(carto_in_dir, x+".bag"))}', 103 | ] 104 | ) 105 | else: 106 | subprocess.run( 107 | [ 108 | "roslaunch", 109 | "cartographer_ros", 110 | "offline_backpack_2d.launch", 111 | f"configuration_basenames:={args.params}.lua", 112 | f'bag_filenames:={osp.abspath(osp.join(carto_in_dir, x+".bag"))}', 113 | ] 114 | ) 115 | 116 | # Convert cartographer output to rosbag. 
117 | run_commands( 118 | [ 119 | [ 120 | "cartographer_dev_pbstream_trajectories_to_rosbag", 121 | f'-input={osp.abspath(osp.join(carto_in_dir, x+".bag.pbstream"))}', 122 | f'-output={osp.abspath(osp.join(carto_out_dir, x+".bag"))}', 123 | ] 124 | for x in cfg["DATASET"]["TEST_SPLIT"] 125 | ], 126 | args.n_proc, 127 | ) 128 | # Convert rosbag to TUM format. 129 | run_commands( 130 | [ 131 | [ 132 | "tools/odombag_to_txt.py", 133 | f'--bag_path={osp.abspath(osp.join(carto_out_dir, x+".bag"))}', 134 | ] 135 | for x in cfg["DATASET"]["TEST_SPLIT"] 136 | ], 137 | args.n_proc, 138 | ) 139 | 140 | # Evaluate trajectory. 141 | # subprocess.run(['./eval_traj.py', 142 | # f'--cfg={args.cfg}', 143 | # f'--input={args.odom}_{args.scan}_{args.params}/output'], check=True) 144 | 145 | # Convert cartographer output to PNG. 146 | run_commands( 147 | [ 148 | [ 149 | "cartographer_pbstream_to_ros_map", 150 | f'-pbstream_filename={osp.abspath(osp.join(carto_in_dir, x+".bag.pbstream"))}', 151 | f"-map_filestem={osp.abspath(osp.join(carto_out_dir, x))}", 152 | ] 153 | for x in cfg["DATASET"]["TEST_SPLIT"] 154 | ], 155 | args.n_proc, 156 | ) 157 | subprocess.run( 158 | ["mogrify", "-format", "png", osp.abspath(osp.join(carto_out_dir, "*.pgm"))], check=True 159 | ) 160 | -------------------------------------------------------------------------------- /tools/test_flow.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | 6 | import argparse 7 | import torch 8 | from tqdm import tqdm 9 | import numpy as np 10 | 11 | np.set_printoptions(precision=3, floatmode="fixed", sign=" ") 12 | import matplotlib 13 | 14 | matplotlib.use("Agg") 15 | import matplotlib.pyplot as plt 16 | 17 | from radarize.flow import model 18 | from radarize.flow.dataloader import FlowDataset 19 | from radarize.utils import image_tools 20 | from radarize.config import cfg, update_config 21 | 22 | 23 | def args(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "--cfg", help="experiment configure file name", default=None, type=str 27 | ) 28 | parser.add_argument("--npz_path", help="Path to npz.", default=None, required=False) 29 | parser.add_argument("--no_cuda", action="store_true") 30 | parser.add_argument( 31 | "opts", 32 | help="Modify config options using the command-line", 33 | default=None, 34 | nargs=argparse.REMAINDER, 35 | ) 36 | args = parser.parse_args() 37 | 38 | return args 39 | 40 | 41 | if __name__ == "__main__": 42 | args = args() 43 | update_config(cfg, args) 44 | 45 | device = torch.device("cpu" if args.no_cuda else "cuda") 46 | 47 | # Load Trained NN 48 | saved_model = torch.load( 49 | os.path.join( 50 | cfg["OUTPUT_DIR"], 51 | cfg["FLOW"]["MODEL"]["NAME"], 52 | f"{cfg['FLOW']['MODEL']['NAME']}.pth", 53 | ) 54 | ) 55 | model_name = saved_model["model_name"] 56 | model_type = saved_model["model_type"] 57 | model_kwargs = saved_model["model_kwargs"] 58 | state_dict = saved_model["model_state_dict"] 59 | net = getattr(model, model_type)(**model_kwargs).to(device) 60 | net.load_state_dict(state_dict) 61 | net.eval() 62 | 63 | # Create output dir. 64 | test_res_dir = os.path.join(os.path.join(cfg["OUTPUT_DIR"], model_name)) 65 | if not os.path.exists(test_res_dir): 66 | os.makedirs(test_res_dir) 67 | 68 | # Get list of bag files in root directory. 
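    # (These are per-bag dataset .npz files, one per entry in cfg["DATASET"]["TEST_SPLIT"].)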
69 | if args.npz_path: 70 | npz_paths = [args.npz_path] 71 | else: 72 | npz_paths = sorted( 73 | [ 74 | os.path.join(cfg["DATASET"]["PATH"], x + ".npz") 75 | for x in cfg["DATASET"]["TEST_SPLIT"] 76 | ] 77 | ) 78 | 79 | mean_pred_mae = [] 80 | mean_pred_rmse = [] 81 | 82 | for path in npz_paths: 83 | print(f"Processing {path}...") 84 | 85 | dataset = FlowDataset( 86 | path, 87 | subsample_factor=cfg["FLOW"]["DATA"]["SUBSAMPLE_FACTOR"], 88 | transform=None, 89 | ) 90 | test_loader = torch.utils.data.DataLoader( 91 | dataset, batch_size=1, shuffle=False, num_workers=0 92 | ) 93 | 94 | times = [] 95 | flow_pred_xs, flow_pred_ys = [], [] 96 | flow_gt_xs, flow_gt_ys = [], [] 97 | 98 | with torch.no_grad(): 99 | for i, batch in enumerate(tqdm(test_loader)): 100 | for k, v in batch.items(): 101 | batch[k] = v.to(device) 102 | curr_time = batch["time"].cpu().numpy() 103 | x = torch.cat([batch["radar_d"], batch["radar_de"]], axis=1).to( 104 | torch.float32 105 | ) 106 | flow_gt = batch["velo_gt"].cpu() 107 | 108 | flow_pred = net(x) 109 | flow_pred = flow_pred.cpu() 110 | flow_pred = torch.squeeze(flow_pred, dim=1) 111 | 112 | flow_x, flow_y = flow_pred[:, 0].numpy(), flow_pred[:, 1].numpy() 113 | flow_pred_xs.append(flow_x) 114 | flow_pred_ys.append(flow_y) 115 | 116 | flow_gt_x, flow_gt_y = flow_gt[:, 0].numpy(), flow_gt[:, 1].numpy() 117 | flow_gt_xs.append(flow_gt_x) 118 | flow_gt_ys.append(flow_gt_y) 119 | 120 | times.append(curr_time) 121 | 122 | flow_pred_xs, flow_pred_ys = np.squeeze(np.array(flow_pred_xs)), np.squeeze( 123 | np.array(flow_pred_ys) 124 | ) 125 | flow_gt_xs, flow_gt_ys = np.squeeze(np.array(flow_gt_xs)), np.squeeze( 126 | np.array(flow_gt_ys) 127 | ) 128 | 129 | # altitudes, altitudes_gt = np.array(altitudes), np.array(altitudes_gt) 130 | 131 | print(f"MAE x: {np.mean(np.abs(flow_pred_xs - flow_gt_xs)):.3f}") 132 | print(f"MAE y: {np.mean(np.abs(flow_pred_ys - flow_gt_ys)):.3f}") 133 | 134 | print(f"RMSE x: {np.sqrt(np.mean((flow_pred_xs - flow_gt_xs)**2)):.3f}") 135 | print(f"RMSE y: {np.sqrt(np.mean((flow_pred_ys - flow_gt_ys)**2)):.3f}") 136 | 137 | print(f"err_mean x: {np.mean((flow_pred_xs - flow_gt_xs)):.3f}") 138 | print(f"err_std x: {np.std((flow_pred_xs - flow_gt_xs)):.3f}") 139 | 140 | print(f"err_mean y: {np.mean((flow_pred_ys - flow_gt_ys)):.3f}") 141 | print(f"err_std y: {np.std((flow_pred_ys - flow_gt_ys)):.3f}") 142 | 143 | pred_mae = ( 144 | np.mean(np.abs(flow_pred_xs - flow_gt_xs)) 145 | + np.mean(np.abs(flow_pred_ys - flow_gt_ys)) 146 | ) / 2 147 | pred_rmse = ( 148 | np.sqrt(np.mean((flow_pred_xs - flow_gt_xs) ** 2)) 149 | + np.sqrt(np.mean((flow_pred_ys - flow_gt_ys) ** 2)) 150 | ) / 2 151 | 152 | mean_pred_mae.append(pred_mae) 153 | mean_pred_rmse.append(pred_rmse) 154 | 155 | fig, ax = plt.subplots(4, 1, sharex=True, figsize=(5, 8)) 156 | 157 | ax[0].set_title( 158 | f"MAE x: {np.mean(np.abs(flow_pred_xs - flow_gt_xs)):.3f} RMSE x: {np.sqrt(np.mean((flow_pred_xs - flow_gt_xs)**2)):.3f}" 159 | ) 160 | ax[0].plot(flow_gt_xs, label="velo_gt_x", color="b") 161 | ax[0].plot(flow_pred_xs, label="velo_x", color="r") 162 | ax[0].set_ylim(-1, 1) 163 | 164 | ax[1].set_title( 165 | f"err mean x: {np.mean((flow_pred_xs - flow_gt_xs)):.3f} stdev: {np.std((flow_pred_xs - flow_gt_xs)):.3f}" 166 | ) 167 | ax[1].plot(flow_pred_xs - flow_gt_xs, label="err_x", color="g") 168 | ax[1].set_ylim(-1, 1) 169 | 170 | ax[2].set_title( 171 | f"MAE y: {np.mean(np.abs(flow_pred_ys - flow_gt_ys)):.3f} RMSE y: {np.sqrt(np.mean((flow_pred_ys - flow_gt_ys)**2)):.3f}" 172 | ) 173 | 
ax[2].plot(flow_gt_ys, label="velo_gt_y", color="b") 174 | ax[2].plot(flow_pred_ys, label="velo_y", color="r") 175 | ax[2].set_ylim(-1, 1) 176 | 177 | ax[3].set_title( 178 | f"err mean y: {np.mean((flow_pred_ys - flow_gt_ys)):.3f} stdev: {np.std((flow_pred_ys - flow_gt_ys)):.3f}" 179 | ) 180 | ax[3].plot(flow_pred_ys - flow_gt_ys, label="err_y", color="g") 181 | ax[3].set_ylim(-1, 1) 182 | 183 | fig.tight_layout() 184 | fig.legend() 185 | print( 186 | os.path.join( 187 | test_res_dir, os.path.basename(os.path.splitext(path)[0] + ".jpg") 188 | ) 189 | ) 190 | fig.savefig( 191 | os.path.join( 192 | test_res_dir, os.path.basename(os.path.splitext(path)[0] + ".jpg") 193 | ) 194 | ) 195 | plt.close(fig) 196 | 197 | d = { 198 | "time": times, 199 | "flow_pred_xs": flow_pred_xs, 200 | "flow_pred_ys": flow_pred_ys, 201 | "flow_gt_xs": flow_gt_xs, 202 | "flow_gt_ys": flow_gt_ys, 203 | } 204 | np.savez( 205 | os.path.join( 206 | test_res_dir, os.path.basename(os.path.splitext(path)[0] + ".npz") 207 | ), 208 | **d, 209 | ) 210 | 211 | with open(os.path.join(test_res_dir, "metrics.txt"), "w") as f: 212 | f.writelines(s + "\n" for s in npz_paths) 213 | f.write( 214 | f"pred mae total {np.mean(mean_pred_mae):.3f} pred rmse total {np.mean(mean_pred_rmse):.3f}" 215 | ) 216 | # print(f"pred mae total {np.mean(mean_pred_mae):.3f} pred rmse total {np.mean(mean_pred_rmse):.3f}") 217 | -------------------------------------------------------------------------------- /tools/test_rot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | 6 | import argparse 7 | 8 | import numpy as np 9 | import PIL.Image as Image 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import torchvision.models as models 15 | from tqdm import tqdm 16 | 17 | np.set_printoptions(precision=3, floatmode="fixed", sign=" ") 18 | 19 | import matplotlib 20 | matplotlib.use("Agg") 21 | from collections import defaultdict 22 | 23 | import imageio 24 | import imageio.v2 as iio 25 | import matplotlib.pyplot as plt 26 | from scipy.spatial.transform import Rotation as R 27 | 28 | from radarize.config import cfg, update_config 29 | from radarize.rotnet import dataloader, model 30 | from radarize.rotnet.dataloader import RotationDataset 31 | from radarize.utils import image_tools 32 | 33 | 34 | def normalize_angle(x): 35 | """Normalize angle to [-pi, pi].""" 36 | return np.arctan2(np.sin(x), np.cos(x)) 37 | 38 | 39 | def quat2yaw(q): 40 | """Convert quaternion to yaw angle. 
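    Uses scipy's "ZYX" Euler convention, so yaw is the rotation about the
    z-axis; quaternions are expected in scalar-last (x, y, z, w) order.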
41 | Args: 42 | q: (N, 4) array of quaternions 43 | Returns: 44 | yaw: (N, 1) array of yaw angles 45 | """ 46 | if q.ndim == 1: 47 | return R.from_quat(q).as_euler("ZYX", degrees=False)[0] 48 | else: 49 | return R.from_quat(q).as_euler("ZYX", degrees=False)[..., 0:1] 50 | 51 | 52 | def args(): 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument( 55 | "--cfg", help="experiment configure file name", default=None, type=str 56 | ) 57 | parser.add_argument("--npz_path", help="Path to npz.", default=None, required=False) 58 | parser.add_argument("--no_cuda", action="store_true") 59 | parser.add_argument( 60 | "opts", 61 | help="Modify config options using the command-line", 62 | default=None, 63 | nargs=argparse.REMAINDER, 64 | ) 65 | args = parser.parse_args() 66 | 67 | return args 68 | 69 | 70 | def visualize_rotation(pred, gt): 71 | fig, ax = plt.subplots(2, 1, sharex=True, figsize=(16, 8)) 72 | 73 | scale_factor = cfg["ROTNET"]["TEST"]["SEQ_LEN"] / 30 74 | 75 | ax[0].set_title( 76 | f"MAE x: {np.mean(np.abs(pred - gt)):.3f} RMSE x: {np.sqrt(np.mean((pred - gt)**2)):.3f}" 77 | ) 78 | ax[0].plot(pred, label="pred", color="r") 79 | ax[0].plot(gt, label="gt", color="b") 80 | ax[0].set_ylim(-2.0 * scale_factor, 2.0 * scale_factor) 81 | ax[0].grid() 82 | 83 | ax[1].set_title( 84 | f"err mean x: {np.mean((pred - gt)):.3f} stdev: {np.std((pred - gt)):.3f}" 85 | ) 86 | ax[1].plot(np.abs(pred - gt), label="err", color="g") 87 | ax[1].set_ylim(-0.0, 2.0 * scale_factor) 88 | 89 | fig.tight_layout() 90 | fig.legend() 91 | 92 | fig.canvas.draw() 93 | w, h = fig.canvas.get_width_height() 94 | buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8) 95 | buf.shape = (w, h, 4) 96 | buf = np.roll(buf, 3, axis=2) 97 | image = Image.frombytes("RGBA", (w, h), buf.tobytes()) 98 | image = np.asarray(image, np.uint8) 99 | rgb_image = image[:, :, :3] 100 | 101 | plt.close(fig) 102 | 103 | return rgb_image 104 | 105 | 106 | def test(net, device, test_loader): 107 | net.eval() 108 | 109 | test_l1_loss = 0 110 | test_mse_loss = 0 111 | 112 | # runtimes = [] 113 | 114 | d = defaultdict(list) 115 | m = {} 116 | 117 | with torch.no_grad(): 118 | for batch in tqdm(test_loader): 119 | for k, v in batch.items(): 120 | batch[k] = v.to(device) 121 | 122 | x = batch["radar_r"].to(torch.float32) 123 | x_ = torch.flip(x, [1]) 124 | 125 | theta_1 = quat2yaw(batch["pose_gt"][:, 0, 3:].cpu().numpy()) 126 | theta_2 = quat2yaw(batch["pose_gt"][:, -1, 3:].cpu().numpy()) 127 | delta_theta = normalize_angle(theta_2 - theta_1) 128 | y = torch.from_numpy(delta_theta).to(device).to(torch.float32) 129 | y_pred = (net(x) - net(x_)) / 2 130 | 131 | test_l1_loss += F.l1_loss(y_pred, y).item() 132 | test_mse_loss += F.mse_loss(y_pred, y).item() 133 | 134 | d["time"].append(batch["time"].cpu().numpy()) 135 | d["rot_pred"].append(y_pred.cpu().numpy().astype(np.float64)) 136 | d["rot_gt"].append(delta_theta) 137 | 138 | test_l1_loss = test_l1_loss / len(test_loader) 139 | test_rmse_loss = np.sqrt(test_mse_loss / len(test_loader)) 140 | 141 | print(f"\n[Test] MAE loss: {test_l1_loss:.6f} RMSE loss: {test_rmse_loss:.6f}") 142 | # average_runtime = np.mean(runtimes) 143 | # print(f'Average runtime: {average_runtime:.3f} seconds') 144 | 145 | # For visualization 146 | for k, v in d.items(): 147 | d[k] = np.concatenate(v, axis=0) 148 | print(k, d[k].shape, d[k].dtype) 149 | 150 | # Metrics 151 | # m['runtime'] = average_runtime 152 | m["l1_loss"] = test_l1_loss 153 | m["rmse_loss"] = test_rmse_loss 154 | 155 | return d, m 156 | 157 | 
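# Note on the flip trick in test() above: reversing the chirp-sequence axis
# with torch.flip(x, [1]) plays the motion backwards, so an ideal estimator
# satisfies net(x_) == -net(x). Averaging as (net(x) - net(x_)) / 2 keeps the
# rotation signal while cancelling flip-symmetric bias. (This is a reading of
# the code, not documented behaviour.)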
158 | if __name__ == "__main__": 159 | args = args() 160 | update_config(cfg, args) 161 | 162 | device = torch.device("cpu" if args.no_cuda else "cuda") 163 | 164 | # Load Trained NN 165 | saved_model = torch.load( 166 | os.path.join( 167 | cfg["OUTPUT_DIR"], 168 | cfg["ROTNET"]["MODEL"]["NAME"], 169 | f"{cfg['ROTNET']['MODEL']['NAME']}.pth", 170 | ) 171 | ) 172 | model_type = saved_model["model_type"] 173 | model_name = saved_model["model_name"] 174 | print(model_name, device) 175 | state_dict = saved_model["model_state_dict"] 176 | model_kwargs = saved_model["model_kwargs"] 177 | net = getattr(model, model_type)(**model_kwargs).to(device) 178 | net.load_state_dict(state_dict) 179 | 180 | # Create output dir. 181 | test_res_dir = os.path.join(os.path.join(cfg["OUTPUT_DIR"], model_name)) 182 | if not os.path.exists(test_res_dir): 183 | os.makedirs(test_res_dir) 184 | 185 | # Get list of bag files in root directory. 186 | if args.npz_path: 187 | npz_paths = [args.npz_path] 188 | else: 189 | npz_paths = sorted( 190 | [ 191 | os.path.join(cfg["DATASET"]["PATH"], x + ".npz") 192 | for x in cfg["DATASET"]["TEST_SPLIT"] 193 | ] 194 | ) 195 | 196 | all_metrics = defaultdict(list) 197 | 198 | for path in npz_paths: 199 | print(f"Processing {path}...") 200 | 201 | dataset = RotationDataset( 202 | path, 203 | subsample_factor=cfg["ROTNET"]["DATA"]["SUBSAMPLE_FACTOR"], 204 | seq_len=cfg["ROTNET"]["TEST"]["SEQ_LEN"], 205 | random_seq_len=False, 206 | ) 207 | test_loader = torch.utils.data.DataLoader( 208 | dataset, batch_size=1, shuffle=False, num_workers=0 209 | ) 210 | d, m = test(net, device, test_loader) 211 | 212 | # Save output 213 | np.savez(os.path.join(test_res_dir, os.path.basename(path)), **d) 214 | 215 | # Save metrics 216 | fname = os.path.join( 217 | test_res_dir, os.path.basename(path).replace(".npz", ".txt") 218 | ) 219 | with open(fname, "w") as f: 220 | for k, v in m.items(): 221 | f.write(f"{k}: {v:.6f}\n") 222 | 223 | all_metrics["mae"].append(m["l1_loss"]) 224 | all_metrics["rmse"].append(m["rmse_loss"]) 225 | 226 | # Save plots 227 | fname = fname.replace(".txt", ".jpg") 228 | im = visualize_rotation(d["rot_pred"], d["rot_gt"]) 229 | imageio.imwrite(fname, im) 230 | 231 | # Save average metrics 232 | fname = os.path.join(test_res_dir, "metrics.txt") 233 | with open(fname, "w") as f: 234 | for k, v in all_metrics.items(): 235 | f.write(f"{k}: {np.mean(v):.6f}\n") 236 | -------------------------------------------------------------------------------- /tools/test_unet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | 6 | import argparse 7 | import time 8 | 9 | import numpy as np 10 | import PIL.Image as Image 11 | import torch 12 | import torch.nn.functional as F 13 | from tqdm import tqdm 14 | 15 | np.set_printoptions(precision=3, floatmode="fixed", sign=" ") 16 | 17 | import matplotlib 18 | 19 | matplotlib.use("Agg") 20 | from collections import defaultdict 21 | 22 | import imageio.v2 as iio 23 | import matplotlib.pyplot as plt 24 | 25 | from radarize.config import cfg, update_config 26 | from radarize.unet import model 27 | from radarize.unet.dataloader import UNetDataset 28 | from radarize.unet.dice_score import dice_loss 29 | from radarize.utils import image_tools 30 | 31 | 32 | def args(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument( 35 | "--cfg", help="experiment configure file name", default=None, type=str 36 | ) 37 | parser.add_argument("--npz_path", 
help="Path to npz.", default=None, required=False) 38 | parser.add_argument("--no_cuda", action="store_true") 39 | parser.add_argument( 40 | "opts", 41 | help="Modify config options using the command-line", 42 | default=None, 43 | nargs=argparse.REMAINDER, 44 | ) 45 | args = parser.parse_args() 46 | 47 | return args 48 | 49 | 50 | def confmap2range(confmap): 51 | device = confmap.device 52 | confmap = confmap.cpu().numpy() 53 | bin2range = np.linspace(0, cfg["DATASET"]["RR_MAX"], cfg["DATASET"]["RAMAP_RSIZE"]) 54 | # confmap = np.squeeze(confmap) 55 | range = bin2range[np.argmin(confmap, axis=1)] 56 | # remove areas without wall 57 | range[np.min(confmap, axis=1) > 0.85] = cfg["DATASET"]["RR_MAX"] 58 | return torch.from_numpy(range).to(device) 59 | 60 | 61 | def range2confmap(range): 62 | device = range.device 63 | range = range.cpu().numpy() 64 | bin_size = cfg["DATASET"]["RR_MAX"] / cfg["DATASET"]["RAMAP_RSIZE"] 65 | confmap = np.zeros((cfg["DATASET"]["RAMAP_RSIZE"], range.shape[0])) 66 | for i, r in enumerate(range): 67 | confmap[int(r // bin_size), i] = 1 68 | return torch.from_numpy(confmap).to(device) 69 | 70 | 71 | def visualize_range(input, output, gt): 72 | fig = plt.figure(figsize=(9, 3)) 73 | 74 | fig.add_subplot(1, 3, 1) 75 | plt.imshow(input, origin="lower", aspect="equal") 76 | plt.axis("off") 77 | plt.title("Radar Heatmap") 78 | 79 | fig.add_subplot(1, 3, 2) 80 | plt.imshow(output, origin="lower", aspect="equal") 81 | plt.axis("off") 82 | plt.title("Output") 83 | 84 | fig.add_subplot(1, 3, 3) 85 | plt.imshow(gt, origin="lower", aspect="equal") 86 | plt.axis("off") 87 | plt.title("Ground Truth") 88 | 89 | fig.canvas.draw() 90 | w, h = fig.canvas.get_width_height() 91 | buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8) 92 | buf.shape = (w, h, 4) 93 | buf = np.roll(buf, 3, axis=2) 94 | image = Image.frombytes("RGBA", (w, h), buf.tobytes()) 95 | image = np.asarray(image, np.uint8) 96 | rgb_image = image[:, :, :3] 97 | 98 | plt.close(fig) 99 | 100 | return rgb_image 101 | 102 | 103 | def test(net, device, test_loader): 104 | net.eval() 105 | 106 | test_dice_loss = 0 107 | test_mse_loss = 0 108 | 109 | runtimes = [] 110 | 111 | d = defaultdict(list) 112 | m = {} 113 | 114 | with torch.no_grad(): 115 | for i, batch in enumerate(tqdm(test_loader)): 116 | for k, v in batch.items(): 117 | batch[k] = v.to(device) 118 | # x = torch.cat([batch['radar_r'].to(torch.float32), 119 | # batch['radar_re'].to(torch.float32)], dim=1) 120 | x = torch.cat( 121 | [ 122 | batch["radar_r_1"], 123 | batch["radar_r_3"], 124 | batch["radar_r_5"], 125 | batch["radar_re_1"], 126 | batch["radar_re_3"], 127 | batch["radar_re_5"], 128 | ], 129 | dim=1, 130 | ).to(torch.float32) 131 | y = batch["depth_map"].to(torch.float32) 132 | 133 | tic = time.time() 134 | y_pred = net(x) 135 | y_pred = torch.argmax(y_pred, dim=1, keepdim=True) 136 | # y_pred = (y_pred == 1) 137 | toc = time.time() 138 | # print(f'runtime: {toc - tic:.3f} s') 139 | runtimes.append(toc - tic) 140 | 141 | d["time"].append(batch["time"][:, -1].cpu().numpy()) 142 | d["radar_r"].append(batch["radar_r_3"].cpu().numpy()) 143 | d["depth_map_pred"].append(y_pred.cpu().numpy().astype(np.float64)) 144 | d["depth_map"].append(y.cpu().numpy().astype(np.float64)) 145 | 146 | test_dice_loss += dice_loss(y_pred, y, multiclass=True).item() 147 | test_mse_loss += F.mse_loss(y_pred, y).item() 148 | 149 | test_dice_loss = test_dice_loss / len(test_loader) 150 | test_mse_loss = test_mse_loss / len(test_loader) 151 | print(f"\n[Test] Dice loss: 
152 |     average_runtime = np.mean(runtimes)
153 |     print(f"Average runtime: {average_runtime:.3f} seconds")
154 | 
155 |     # For visualization
156 |     for k, v in d.items():
157 |         d[k] = np.concatenate(v, axis=0)
158 |         print(k, d[k].shape, d[k].dtype)
159 | 
160 |     # Metrics
161 |     m["runtime"] = average_runtime
162 |     m["dice_loss"] = test_dice_loss
163 |     m["mse_loss"] = test_mse_loss
164 | 
165 |     return d, m
166 | 
167 | 
168 | if __name__ == "__main__":
169 |     args = args()
170 |     update_config(cfg, args)
171 | 
172 |     device = torch.device("cpu" if args.no_cuda else "cuda")
173 | 
174 |     # Load Trained NN
175 |     saved_model = torch.load(
176 |         os.path.join(
177 |             cfg["OUTPUT_DIR"],
178 |             cfg["UNET"]["MODEL"]["NAME"],
179 |             f"{cfg['UNET']['MODEL']['NAME']}.pth",
180 |         )
181 |     )
182 |     model_name = saved_model["model_name"]
183 |     model_type = saved_model["model_type"]
184 |     state_dict = saved_model["model_state_dict"]
185 |     model_kwargs = saved_model["model_kwargs"]
186 |     net = getattr(model, model_type)(**model_kwargs).to(device)
187 |     net.load_state_dict(state_dict)
188 |     print(model_name, device)
189 | 
190 |     # Create output dir.
191 |     test_res_dir = os.path.join(cfg["OUTPUT_DIR"], model_name)
192 |     if not os.path.exists(test_res_dir):
193 |         os.makedirs(test_res_dir)
194 | 
195 |     # Get list of .npz files to evaluate.
196 |     if args.npz_path:
197 |         npz_paths = [args.npz_path]
198 |     else:
199 |         npz_paths = sorted(
200 |             [
201 |                 os.path.join(cfg["DATASET"]["PATH"], x + ".npz")
202 |                 for x in cfg["DATASET"]["TEST_SPLIT"]
203 |             ]
204 |         )
205 | 
206 |     for path in npz_paths:
207 |         print(f"Processing {path}...")
208 | 
209 |         dataset = UNetDataset(path, seq_len=1)
210 |         test_loader = torch.utils.data.DataLoader(
211 |             dataset, batch_size=1, shuffle=False, num_workers=0
212 |         )
213 |         d, m = test(net, device, test_loader)
214 | 
215 |         # Visualize output
216 |         # with iio.get_writer(os.path.join(test_res_dir, os.path.basename(path).replace('.npz', '.mp4')),
217 |         #                     format='FFMPEG',
218 |         #                     mode='I',
219 |         #                     fps=30) as writer:
220 |         #     for i in tqdm(range(len(d['time']))):
221 |         #         radar_r = d['radar_r'][i][-1,...]
222 |         #         depth_map_pred = d['depth_map_pred'][i][0,...]
223 |         #         depth_map = d['depth_map'][i][0,...]
224 | 225 | # writer.append_data(visualize_range(radar_r, 226 | # depth_map_pred, 227 | # depth_map)) 228 | 229 | # Save output 230 | np.savez( 231 | os.path.join(test_res_dir, os.path.basename(path)), 232 | time=d["time"], 233 | depth_map=d["depth_map_pred"], 234 | ) 235 | 236 | # Save metrics 237 | fname = os.path.join( 238 | test_res_dir, os.path.basename(path).replace(".npz", ".txt") 239 | ) 240 | with open(fname, "w") as f: 241 | for k, v in m.items(): 242 | f.write(f"{k}: {v:.6f}\n") 243 | -------------------------------------------------------------------------------- /tools/train_flow.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | 6 | from tqdm import tqdm 7 | import argparse 8 | import numpy as np 9 | import random 10 | import matplotlib 11 | 12 | matplotlib.use("Agg") 13 | import matplotlib.pyplot as plt 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | from torch.optim import lr_scheduler 19 | 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = True 22 | 23 | from radarize.flow import dataloader 24 | from radarize.flow import model 25 | from radarize.config import cfg, update_config 26 | 27 | 28 | def train(net, device, train_loader, optimizer, epoch): 29 | net.train() 30 | loss_plot = 0 31 | for batch_idx, batch in enumerate(tqdm(train_loader)): 32 | for k, v in batch.items(): 33 | batch[k] = v.to(device) 34 | x = torch.cat([batch["radar_d"], batch["radar_de"]], axis=1).to(torch.float32) 35 | 36 | flow_gt = batch["velo_gt"].to(torch.float32) 37 | 38 | optimizer.zero_grad() 39 | flow_pred = net(x) 40 | 41 | flow_loss_x = F.mse_loss(flow_pred[:, 0], flow_gt[:, 0], reduction="mean") 42 | flow_loss_y = F.mse_loss(flow_pred[:, 1], flow_gt[:, 1], reduction="mean") 43 | loss = torch.sqrt((flow_loss_x + flow_loss_y) / 2.0) 44 | loss_plot += loss.item() 45 | 46 | loss.backward() 47 | optimizer.step() 48 | 49 | if batch_idx % cfg["FLOW"]["TRAIN"]["LOG_STEP"] == 0: 50 | print( 51 | f"Train Epoch: {epoch} [({100. 
* batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f} flow_loss_x: {np.sqrt(flow_loss_x.item()):.6f}, flow_loss_y: {np.sqrt(flow_loss_y.item()):.6f}"
52 |             )
53 |     # loss is already RMSE-scale (sqrt applied per batch), so just average it.
54 |     loss_plot /= len(train_loader)
55 |     return loss_plot
56 | 
57 | 
58 | def test(net, device, test_loader, scheduler):
59 |     net.eval()
60 | 
61 |     test_loss_sum_mae = 0
62 |     test_loss_sum_mse = 0
63 | 
64 |     with torch.no_grad():
65 |         for batch in test_loader:
66 |             for k, v in batch.items():
67 |                 batch[k] = v.to(device)
68 |             x = torch.cat([batch["radar_d"], batch["radar_de"]], axis=1).to(
69 |                 torch.float32
70 |             )
71 |             flow_gt = batch["velo_gt"].to(torch.float32)
72 |             flow_gt = flow_gt[:, :2]
73 | 
74 |             flow_pred = torch.squeeze(net(x))
75 | 
76 |             test_loss_sum_mae += F.l1_loss(flow_pred, flow_gt, reduction="mean").item()
77 |             test_loss_sum_mse += F.mse_loss(flow_pred, flow_gt, reduction="mean").item()
78 | 
79 |     test_loss_mae = test_loss_sum_mae / len(test_loader)
80 |     test_loss_rmse = np.sqrt(test_loss_sum_mse / len(test_loader))
81 |     print(
82 |         "\n[Test] L1 loss: {:.4f}, RMSE Loss: {:.4f}\n".format(
83 |             test_loss_mae, test_loss_rmse
84 |         )
85 |     )
86 | 
87 |     scheduler.step(test_loss_rmse)
88 | 
89 |     return test_loss_rmse
90 | 
91 | 
92 | def args():
93 |     parser = argparse.ArgumentParser()
94 |     parser.add_argument(
95 |         "--cfg", help="experiment configure file name", default=None, type=str
96 |     )
97 |     parser.add_argument(
98 |         "opts",
99 |         help="Modify config options using the command-line",
100 |         default=None,
101 |         nargs=argparse.REMAINDER,
102 |     )
103 |     parser.add_argument("--no_cuda", action="store_true")
104 |     args = parser.parse_args()
105 | 
106 |     return args
107 | 
108 | 
109 | if __name__ == "__main__":
110 |     args = args()
111 |     update_config(cfg, args)
112 | 
113 |     # Prepare output directory.
114 |     train_dir = os.path.join(cfg["OUTPUT_DIR"], cfg["FLOW"]["MODEL"]["NAME"])
115 |     os.makedirs(train_dir, exist_ok=True)
116 | 
117 |     # Set random seeds.
118 |     random.seed(cfg["FLOW"]["TRAIN"]["SEED"])
119 |     np.random.seed(cfg["FLOW"]["TRAIN"]["SEED"])
120 |     torch.manual_seed(cfg["FLOW"]["TRAIN"]["SEED"])
121 | 
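The `train()` loop above already applies `torch.sqrt` inside the loop, so each batch contributes an RMSE-scale value and the epoch average needs no further root. A minimal standalone check of that objective, on synthetic tensors rather than the repo's dataloader output:

```python
import torch
import torch.nn.functional as F

# Synthetic stand-ins for flow_pred and flow_gt: (batch, 2) velocity vectors.
flow_pred = torch.tensor([[0.10, -0.20], [0.30, 0.40]])
flow_gt = torch.tensor([[0.12, -0.18], [0.25, 0.45]])

# Per-axis MSE, matching flow_loss_x / flow_loss_y in train().
mse_x = F.mse_loss(flow_pred[:, 0], flow_gt[:, 0], reduction="mean")
mse_y = F.mse_loss(flow_pred[:, 1], flow_gt[:, 1], reduction="mean")

# The combined loss is the root of the mean per-axis MSE, i.e. an RMSE
# over both velocity components jointly.
loss = torch.sqrt((mse_x + mse_y) / 2.0)

# Equivalent closed form: RMSE over all elements at once.
assert torch.allclose(loss, torch.sqrt(F.mse_loss(flow_pred, flow_gt)))
```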
122 |     # Set training params.
123 |     use_cuda = not args.no_cuda and torch.cuda.is_available()
124 |     device = torch.device("cuda" if use_cuda else "cpu")
125 | 
126 |     train_kwargs = {"batch_size": cfg["FLOW"]["TRAIN"]["BATCH_SIZE"], "drop_last": True}
127 |     test_kwargs = {"batch_size": cfg["FLOW"]["TEST"]["BATCH_SIZE"]}
128 | 
129 |     if use_cuda:
130 |         cuda_kwargs = {
131 |             "num_workers": 0,
132 |             "shuffle": True,
133 |             "worker_init_fn": lambda id: np.random.seed(
134 |                 id * cfg["FLOW"]["TRAIN"]["SEED"]
135 |             ),
136 |         }
137 |         train_kwargs.update(cuda_kwargs)
138 |         test_kwargs.update({**cuda_kwargs, "shuffle": False})  # keep val order fixed, as in train_rot/train_unet
139 | 
140 |     # Prepare for the dataset
141 |     print("Loading dataset...")
142 |     train_paths = [
143 |         os.path.join(cfg["DATASET"]["PATH"], x + ".npz")
144 |         for x in cfg["DATASET"]["TRAIN_SPLIT"]
145 |     ]
146 |     train_datasets = [
147 |         dataloader.FlowDataset(
148 |             path,
149 |             subsample_factor=cfg["FLOW"]["DATA"]["SUBSAMPLE_FACTOR"],
150 |             transform=dataloader.FlipFlow(),
151 |         )
152 |         for path in sorted(train_paths)
153 |     ]
154 |     train_dataset = torch.utils.data.ConcatDataset(train_datasets)
155 | 
156 |     test_paths = [
157 |         os.path.join(cfg["DATASET"]["PATH"], x + ".npz")
158 |         for x in cfg["DATASET"]["VAL_SPLIT"]
159 |     ]
160 |     test_datasets = [
161 |         dataloader.FlowDataset(
162 |             path,
163 |             subsample_factor=cfg["FLOW"]["DATA"]["SUBSAMPLE_FACTOR"],
164 |             transform=None,
165 |         )
166 |         for path in sorted(test_paths)
167 |     ]
168 |     test_dataset = torch.utils.data.ConcatDataset(test_datasets)
169 | 
170 |     train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
171 |     test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
172 | 
173 |     # Load network.
174 |     model_kwargs = {
175 |         "n_channels": cfg["FLOW"]["MODEL"]["N_CHANNELS"],
176 |         "n_outputs": cfg["FLOW"]["MODEL"]["N_OUTPUTS"],
177 |     }
178 |     net = getattr(model, cfg["FLOW"]["MODEL"]["TYPE"])(**model_kwargs).to(device)
179 | 
180 |     optimizer = optim.Adam(net.parameters(), lr=cfg["FLOW"]["TRAIN"]["LR"])
181 | 
182 |     scheduler = lr_scheduler.ReduceLROnPlateau(
183 |         optimizer, "min", factor=0.5, patience=10
184 |     )
185 | 
186 |     train_loss_array = []
187 |     test_loss_array = []
188 | 
189 |     least_test_loss = np.inf
190 |     for epoch in range(1, cfg["FLOW"]["TRAIN"]["EPOCHS"] + 1):
191 |         train_loss = train(net, device, train_loader, optimizer, epoch)
192 |         test_loss = test(net, device, test_loader, scheduler)
193 | 
194 |         train_loss_array.append(train_loss)
195 |         test_loss_array.append(test_loss)
196 | 
197 |         plt.plot(np.array(train_loss_array), "b", label="Train Loss")
198 |         plt.plot(np.array(test_loss_array), "r", label="Test Loss")
199 |         plt.scatter(
200 |             np.argmin(np.array(test_loss_array)),
201 |             np.min(test_loss_array),
202 |             s=30,
203 |             color="green",
204 |         )
205 |         plt.title("Loss Plot, min:{:.3f}".format(np.min(test_loss_array)))
206 |         plt.legend()
207 |         plt.grid()
208 |         plt.ylim(bottom=0)
209 |         plt.xlim(left=0)
210 |         plt.savefig(os.path.join(train_dir, "loss.jpg"))
211 |         plt.close()
212 |         # scheduler.step()
213 |         if test_loss < least_test_loss:
214 |             least_test_loss = test_loss
215 |             torch.save(
216 |                 {
217 |                     "model_name": cfg["FLOW"]["MODEL"]["NAME"],
218 |                     "model_type": type(net).__name__,
219 |                     "model_kwargs": model_kwargs,
220 |                     "model_state_dict": net.state_dict(),
221 |                     "epoch": epoch,
222 |                     "test_loss": test_loss,
223 |                 },
224 |                 os.path.join(train_dir, f"{cfg['FLOW']['MODEL']['NAME']}.pth"),
225 |             )
226 | 
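The `torch.save` call above writes the checkpoint schema that every `train_*` script in `tools/` produces and every `test_*` script consumes: class name, constructor kwargs, and weights. A minimal round-trip sketch of that contract, using a toy `nn.Module` in place of the repo's model classes:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):  # toy stand-in for a radarize model class
    def __init__(self, n_channels, n_outputs):
        super().__init__()
        self.fc = nn.Linear(n_channels, n_outputs)

    def forward(self, x):
        return self.fc(x)

model_kwargs = {"n_channels": 4, "n_outputs": 2}
net = TinyNet(**model_kwargs)

# Save: same keys the train_* scripts write.
torch.save(
    {
        "model_name": "tiny_net_v0",
        "model_type": type(net).__name__,
        "model_kwargs": model_kwargs,
        "model_state_dict": net.state_dict(),
    },
    "/tmp/tiny_net_v0.pth",
)

# Load: same lookup the test_* scripts do, here against this module's
# globals instead of the radarize.*.model modules.
ckpt = torch.load("/tmp/tiny_net_v0.pth")
net2 = globals()[ckpt["model_type"]](**ckpt["model_kwargs"])
net2.load_state_dict(ckpt["model_state_dict"])
```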
--------------------------------------------------------------------------------
/tools/train_rot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import os
4 | import sys
5 | 
6 | import argparse
7 | import random
8 | 
9 | import matplotlib
10 | import numpy as np
11 | from tqdm import tqdm
12 | 
13 | matplotlib.use("Agg")
14 | import matplotlib.pyplot as plt
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import torch.optim as optim
19 | import torchvision.models as models
20 | from torch.optim import lr_scheduler
21 | 
22 | torch.backends.cudnn.deterministic = True
23 | torch.backends.cudnn.benchmark = True
24 | 
25 | from scipy.spatial.transform import Rotation as R
26 | 
27 | from radarize.config import cfg, update_config
28 | from radarize.rotnet import dataloader, model
29 | from radarize.rotnet.dataloader import RotationDataset
30 | 
31 | 
32 | def normalize_angle(x):
33 |     """Normalize angle to [-pi, pi]."""
34 |     return np.arctan2(np.sin(x), np.cos(x))
35 | 
36 | 
37 | def quat2yaw(q):
38 |     """Convert quaternion to yaw angle.
39 |     Args:
40 |         q: (N, 4) array of quaternions
41 |     Returns:
42 |         yaw: (N, 1) array of yaw angles
43 |     """
44 |     if q.ndim == 1:
45 |         return R.from_quat(q).as_euler("ZYX", degrees=False)[0]
46 |     else:
47 |         return R.from_quat(q).as_euler("ZYX", degrees=False)[..., 0:1]
48 | 
49 | 
50 | def train(net, device, train_loader, optimizer, epoch):
51 |     net.train()
52 |     loss_plot = 0
53 |     for batch_idx, batch in enumerate(tqdm(train_loader)):
54 |         for k, v in batch.items():
55 |             batch[k] = v.to(device)
56 | 
57 |         x_1 = batch["radar_r"].to(torch.float32)
58 |         x_2 = torch.flip(x_1, [1])
59 | 
60 |         theta_1 = quat2yaw(batch["pose_gt"][:, 0, 3:].cpu().numpy())
61 |         theta_2 = quat2yaw(batch["pose_gt"][:, -1, 3:].cpu().numpy())
62 |         delta_theta = normalize_angle(theta_2 - theta_1)
63 |         y = torch.from_numpy(delta_theta).to(device).to(torch.float32)
64 | 
65 |         optimizer.zero_grad()
66 |         y_pred_1 = net(x_1)
67 |         y_pred_2 = net(x_2)
68 |         loss = torch.sqrt((F.mse_loss(y_pred_1, y) + F.mse_loss(y_pred_2, -y)) / 2)
69 |         loss_plot += loss.item()  # accumulate over batches; averaged after the loop
70 |         loss.backward()
71 |         optimizer.step()
72 | 
73 |         if batch_idx % cfg["ROTNET"]["TRAIN"]["LOG_STEP"] == 0:
74 |             print(
75 |                 f"Train Epoch: {epoch} [({100.
* batch_idx / len(train_loader):.0f}%)]\t Loss: {loss.item():.6f}" 76 | ) 77 | 78 | loss_plot /= len(train_loader) 79 | return loss_plot 80 | 81 | 82 | def test(net, device, test_loader, scheduler): 83 | net.eval() 84 | 85 | test_l1_loss = 0 86 | test_mse_loss = 0 87 | 88 | with torch.no_grad(): 89 | for batch in test_loader: 90 | for k, v in batch.items(): 91 | batch[k] = v.to(device) 92 | 93 | x = batch["radar_r"].to(torch.float32) 94 | x_ = torch.flip(x, [1]) 95 | 96 | theta_1 = quat2yaw(batch["pose_gt"][:, 0, 3:].cpu().numpy()) 97 | theta_2 = quat2yaw(batch["pose_gt"][:, -1, 3:].cpu().numpy()) 98 | delta_theta = normalize_angle(theta_2 - theta_1) 99 | y = torch.from_numpy(delta_theta).to(device).to(torch.float32) 100 | 101 | y_pred = (net(x) - net(x_)) / 2 102 | 103 | test_l1_loss += F.l1_loss(y_pred, y).item() 104 | test_mse_loss += F.mse_loss(y_pred, y).item() 105 | 106 | test_l1_loss = test_l1_loss / len(test_loader) 107 | test_rmse_loss = np.sqrt(test_mse_loss / len(test_loader)) 108 | print(f"\n[Test] MAE loss: {test_l1_loss:.4f} RMSE loss: {test_rmse_loss:.4f}") 109 | 110 | scheduler.step(test_rmse_loss) 111 | 112 | return test_rmse_loss 113 | 114 | 115 | def args(): 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument( 118 | "--cfg", help="experiment configure file name", default=None, type=str 119 | ) 120 | parser.add_argument( 121 | "opts", 122 | help="Modify config options using the command-line", 123 | default=None, 124 | nargs=argparse.REMAINDER, 125 | ) 126 | parser.add_argument("--no_cuda", action="store_true") 127 | args = parser.parse_args() 128 | 129 | return args 130 | 131 | 132 | if __name__ == "__main__": 133 | args = args() 134 | update_config(cfg, args) 135 | 136 | # Prepare output directory. 137 | train_dir = os.path.join(cfg["OUTPUT_DIR"], cfg["ROTNET"]["MODEL"]["NAME"]) 138 | os.makedirs(train_dir, exist_ok=True) 139 | 140 | # Set random seeds. 141 | random.seed(cfg["ROTNET"]["TRAIN"]["SEED"]) 142 | np.random.seed(cfg["ROTNET"]["TRAIN"]["SEED"]) 143 | torch.manual_seed(cfg["ROTNET"]["TRAIN"]["SEED"]) 144 | 145 | # Set training params. 
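As an aside on the target construction in `train()`/`test()` above: the yaw delta is formed from quaternions and wrapped with `normalize_angle`, so turns across the ±π boundary stay well-posed. A small numeric illustration (SciPy quaternions in `[x, y, z, w]` order, as `from_quat` expects):

```python
import numpy as np
from scipy.spatial.transform import Rotation as R

def normalize_angle(x):
    return np.arctan2(np.sin(x), np.cos(x))

# Two poses whose yaws straddle the +/-pi boundary: 170 deg and -170 deg.
q1 = R.from_euler("ZYX", [np.deg2rad(170), 0, 0]).as_quat()
q2 = R.from_euler("ZYX", [np.deg2rad(-170), 0, 0]).as_quat()

yaw1 = R.from_quat(q1).as_euler("ZYX")[0]
yaw2 = R.from_quat(q2).as_euler("ZYX")[0]

# Naive difference is -340 deg; the wrapped delta is the true +20 deg turn.
naive = yaw2 - yaw1
wrapped = normalize_angle(naive)
print(np.rad2deg(naive), np.rad2deg(wrapped))  # ~-340.0, ~20.0
```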
146 | use_cuda = not args.no_cuda and torch.cuda.is_available() 147 | device = torch.device("cuda" if use_cuda else "cpu") 148 | 149 | train_kwargs = { 150 | "batch_size": cfg["ROTNET"]["TRAIN"]["BATCH_SIZE"], 151 | "drop_last": True, 152 | } 153 | test_kwargs = {"batch_size": cfg["ROTNET"]["TEST"]["BATCH_SIZE"]} 154 | 155 | if use_cuda: 156 | cuda_kwargs = { 157 | "num_workers": 0, 158 | "shuffle": True, 159 | "worker_init_fn": lambda id: np.random.seed( 160 | id * cfg["ROTNET"]["TRAIN"]["SEED"] 161 | ), 162 | } 163 | train_kwargs.update(cuda_kwargs) 164 | cuda_kwargs = { 165 | "num_workers": 0, 166 | "shuffle": False, 167 | "worker_init_fn": lambda id: np.random.seed( 168 | id * cfg["ROTNET"]["TRAIN"]["SEED"] 169 | ), 170 | } 171 | test_kwargs.update(cuda_kwargs) 172 | 173 | # Prepare for the dataset 174 | print("Loading dataset...") 175 | train_paths = [ 176 | os.path.join(cfg["DATASET"]["PATH"], x + ".npz") 177 | for x in cfg["DATASET"]["TRAIN_SPLIT"] 178 | ] 179 | train_datasets = [ 180 | RotationDataset( 181 | path, 182 | subsample_factor=cfg["ROTNET"]["DATA"]["SUBSAMPLE_FACTOR"], 183 | seq_len=cfg["ROTNET"]["TRAIN"]["TRAIN_SEQ_LEN"], 184 | random_seq_len=cfg["ROTNET"]["TRAIN"]["TRAIN_RANDOM_SEQ_LEN"], 185 | transform=dataloader.ReverseTime(0.5), 186 | ) 187 | for path in train_paths 188 | ] 189 | train_dataset = torch.utils.data.ConcatDataset(train_datasets) 190 | 191 | test_paths = [ 192 | os.path.join(cfg["DATASET"]["PATH"], x + ".npz") 193 | for x in cfg["DATASET"]["VAL_SPLIT"] 194 | ] 195 | test_datasets = [ 196 | RotationDataset( 197 | path, 198 | subsample_factor=cfg["ROTNET"]["DATA"]["SUBSAMPLE_FACTOR"], 199 | seq_len=cfg["ROTNET"]["TRAIN"]["VAL_SEQ_LEN"], 200 | random_seq_len=cfg["ROTNET"]["TRAIN"]["VAL_RANDOM_SEQ_LEN"], 201 | ) 202 | for path in sorted(test_paths) 203 | ] 204 | test_dataset = torch.utils.data.ConcatDataset(test_datasets) 205 | 206 | train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs) 207 | test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs) 208 | 209 | model_kwargs = { 210 | "n_channels": cfg["ROTNET"]["MODEL"]["N_CHANNELS"], 211 | "n_outputs": cfg["ROTNET"]["MODEL"]["N_OUTPUTS"], 212 | } 213 | 214 | # Load network. 
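The line that follows resolves the model class by its config-supplied name via `getattr`, so new architectures only need to be added to `radarize.rotnet.model`. A minimal sketch of the same pattern, with a hypothetical stand-in module and class name rather than the repo's:

```python
import types

# Toy stand-in for radarize.rotnet.model: a module-like object exposing classes.
model = types.SimpleNamespace()

class RotNetSmall:  # hypothetical model class for illustration
    def __init__(self, n_channels, n_outputs):
        self.n_channels, self.n_outputs = n_channels, n_outputs

model.RotNetSmall = RotNetSmall

# The config supplies the class name as a string...
cfg_model_type = "RotNetSmall"
model_kwargs = {"n_channels": 8, "n_outputs": 1}

# ...and getattr turns it back into a class to instantiate.
net = getattr(model, cfg_model_type)(**model_kwargs)
assert isinstance(net, RotNetSmall)
```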
215 | net = getattr(model, cfg["ROTNET"]["MODEL"]["TYPE"])(**model_kwargs).to(device) 216 | 217 | optimizer = optim.AdamW( 218 | net.parameters(), lr=cfg["ROTNET"]["TRAIN"]["LR"], betas=(0.9, 0.999) 219 | ) 220 | 221 | scheduler = lr_scheduler.ReduceLROnPlateau( 222 | optimizer, "min", factor=0.5, patience=10 223 | ) 224 | 225 | train_loss_array = [] 226 | test_loss_array = [] 227 | 228 | least_test_loss = np.inf 229 | for epoch in range(1, cfg["ROTNET"]["TRAIN"]["EPOCHS"] + 1): 230 | train_loss = train(net, device, train_loader, optimizer, epoch) 231 | test_loss = test(net, device, test_loader, scheduler) 232 | 233 | train_loss_array.append(train_loss) 234 | test_loss_array.append(test_loss) 235 | 236 | plt.plot(np.array(train_loss_array), "b", label="Train Loss") 237 | plt.plot(np.array(test_loss_array), "r", label="Test Loss") 238 | plt.scatter( 239 | np.argmin(np.array(test_loss_array)), 240 | np.min(test_loss_array), 241 | s=30, 242 | color="green", 243 | ) 244 | plt.title("Loss Plot, min:{:.3f}".format(np.min(test_loss_array))) 245 | plt.legend() 246 | plt.grid() 247 | plt.ylim(bottom=0) 248 | plt.xlim(left=0) 249 | plt.savefig(os.path.join(train_dir, "loss.jpg")) 250 | plt.close() 251 | # scheduler.step() 252 | if test_loss < least_test_loss: 253 | least_test_loss = test_loss 254 | torch.save( 255 | { 256 | "model_name": cfg["ROTNET"]["MODEL"]["NAME"], 257 | "model_type": type(net).__name__, 258 | "model_state_dict": net.state_dict(), 259 | "model_kwargs": model_kwargs, 260 | "epoch": epoch, 261 | "test_loss": test_loss, 262 | }, 263 | os.path.join(train_dir, f"{cfg['ROTNET']['MODEL']['NAME']}.pth"), 264 | ) 265 | -------------------------------------------------------------------------------- /tools/train_unet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | 6 | import argparse 7 | import random 8 | 9 | import matplotlib 10 | import numpy as np 11 | 12 | matplotlib.use("Agg") 13 | import matplotlib.pyplot as plt 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | from torch.optim import lr_scheduler 19 | 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = True 22 | 23 | from radarize.config import cfg, update_config 24 | from radarize.unet import dataloader, model 25 | from radarize.unet.dice_score import dice_loss 26 | 27 | 28 | def train(net, device, train_loader, optimizer, epoch): 29 | net.train() 30 | loss_plot = 0 31 | for batch_idx, batch in enumerate(train_loader): 32 | for k, v in batch.items(): 33 | batch[k] = v.to(device) 34 | # x = torch.cat([batch['radar_r'].to(torch.float32), 35 | # batch['radar_re'].to(torch.float32)], dim=1) 36 | x = torch.cat( 37 | [ 38 | batch["radar_r_1"], 39 | batch["radar_r_3"], 40 | batch["radar_r_5"], 41 | batch["radar_re_1"], 42 | batch["radar_re_3"], 43 | batch["radar_re_5"], 44 | ], 45 | dim=1, 46 | ).to(torch.float32) 47 | y = batch["depth_map"].to(torch.float32) 48 | y = torch.cat([y[:, 0:1, ...] == 0, y[:, 0:1, ...] 
== 1], dim=1).to( 49 | torch.float32 50 | ) 51 | 52 | optimizer.zero_grad() 53 | y_pred = net(x) 54 | loss = cfg["UNET"]["TRAIN"]["BCE_WEIGHT"] * F.binary_cross_entropy( 55 | y_pred, y 56 | ) + cfg["UNET"]["TRAIN"]["DICE_WEIGHT"] * dice_loss(y_pred, y, multiclass=True) 57 | loss_plot += loss.item() 58 | loss.backward() 59 | optimizer.step() 60 | 61 | if batch_idx % cfg["UNET"]["TRAIN"]["LOG_STEP"] == 0: 62 | print( 63 | f"Train Epoch: {epoch} [({100. * batch_idx / len(train_loader):.0f}%)]\t Loss: {loss.item():.6f}" 64 | ) 65 | 66 | loss_plot /= len(train_loader) 67 | return loss_plot 68 | 69 | 70 | def test(net, device, test_loader): 71 | net.eval() 72 | 73 | test_dice_loss = 0 74 | test_mse_loss = 0 75 | 76 | with torch.no_grad(): 77 | for batch in test_loader: 78 | for k, v in batch.items(): 79 | batch[k] = v.to(device) 80 | # x = torch.cat([batch['radar_r'].to(torch.float32), 81 | # batch['radar_re'].to(torch.float32)], dim=1) 82 | x = torch.cat( 83 | [ 84 | batch["radar_r_1"], 85 | batch["radar_r_3"], 86 | batch["radar_r_5"], 87 | batch["radar_re_1"], 88 | batch["radar_re_3"], 89 | batch["radar_re_5"], 90 | ], 91 | dim=1, 92 | ).to(torch.float32) 93 | y = batch["depth_map"].to(torch.float32) 94 | 95 | y_pred = net(x) 96 | y_pred = torch.argmax(y_pred, dim=1, keepdim=True) 97 | # y_pred = (y_pred == 1) 98 | 99 | test_dice_loss += dice_loss(y_pred, y, multiclass=True).item() 100 | test_mse_loss += F.mse_loss(y_pred, y).item() 101 | 102 | test_dice_loss = test_dice_loss / len(test_loader) 103 | test_mse_loss = test_mse_loss / len(test_loader) 104 | print(f"\n[Test] Dice loss: {test_dice_loss:.6f} MSE loss: {test_mse_loss:.6f}") 105 | return test_dice_loss 106 | 107 | 108 | def args(): 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument( 111 | "--cfg", help="experiment configure file name", default=None, type=str 112 | ) 113 | parser.add_argument( 114 | "opts", 115 | help="Modify config options using the command-line", 116 | default=None, 117 | nargs=argparse.REMAINDER, 118 | ) 119 | parser.add_argument("--no_cuda", action="store_true") 120 | args = parser.parse_args() 121 | 122 | return args 123 | 124 | 125 | if __name__ == "__main__": 126 | args = args() 127 | update_config(cfg, args) 128 | 129 | # Prepare output directory. 130 | train_dir = os.path.join(cfg["OUTPUT_DIR"], cfg["UNET"]["MODEL"]["NAME"]) 131 | os.makedirs(train_dir, exist_ok=True) 132 | 133 | # Set random seeds. 134 | random.seed(cfg["UNET"]["TRAIN"]["SEED"]) 135 | np.random.seed(cfg["UNET"]["TRAIN"]["SEED"]) 136 | torch.manual_seed(cfg["UNET"]["TRAIN"]["SEED"]) 137 | 138 | # Set training params. 
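A note on the target encoding used in `train()` above: the binary depth map is expanded into a two-channel (background, wall) one-hot tensor for the BCE + dice objective, and `test()` inverts it with an `argmax`. A small self-contained illustration of that encode/decode pair on a toy tensor:

```python
import torch

# Toy binary occupancy map: (batch, 1, H, W) with values in {0, 1}.
y = torch.tensor([[[[0.0, 1.0], [1.0, 0.0]]]])

# Encode: channel 0 = background mask, channel 1 = wall mask.
y_onehot = torch.cat([y == 0, y == 1], dim=1).to(torch.float32)

# Decode: argmax over channels recovers the original class indices.
y_back = torch.argmax(y_onehot, dim=1, keepdim=True).to(torch.float32)
assert torch.equal(y, y_back)
```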
139 | use_cuda = not args.no_cuda and torch.cuda.is_available() 140 | device = torch.device("cuda" if use_cuda else "cpu") 141 | 142 | train_kwargs = {"batch_size": cfg["UNET"]["TRAIN"]["BATCH_SIZE"], "drop_last": True} 143 | test_kwargs = {"batch_size": cfg["UNET"]["TEST"]["BATCH_SIZE"]} 144 | 145 | if use_cuda: 146 | cuda_kwargs = { 147 | "num_workers": 0, 148 | "shuffle": True, 149 | "worker_init_fn": lambda id: np.random.seed( 150 | id * cfg["UNET"]["TRAIN"]["SEED"] 151 | ), 152 | } 153 | train_kwargs.update(cuda_kwargs) 154 | cuda_kwargs = { 155 | "num_workers": 0, 156 | "shuffle": False, 157 | "worker_init_fn": lambda id: np.random.seed( 158 | id * cfg["UNET"]["TRAIN"]["SEED"] 159 | ), 160 | } 161 | test_kwargs.update(cuda_kwargs) 162 | 163 | # Prepare for the dataset 164 | print("Loading dataset...") 165 | train_paths = [ 166 | os.path.join(cfg["DATASET"]["PATH"], x + ".npz") 167 | for x in cfg["DATASET"]["TRAIN_SPLIT"] 168 | ] 169 | train_datasets = [ 170 | dataloader.UNetDataset(path, seq_len=1, transform=dataloader.FlipRange(0.5)) 171 | for path in sorted(train_paths) 172 | ] 173 | train_dataset = torch.utils.data.ConcatDataset(train_datasets) 174 | 175 | test_paths = [ 176 | os.path.join(cfg["DATASET"]["PATH"], x + ".npz") 177 | for x in cfg["DATASET"]["VAL_SPLIT"] 178 | ] 179 | test_datasets = [ 180 | dataloader.UNetDataset(path, seq_len=1) for path in sorted(test_paths) 181 | ] 182 | test_dataset = torch.utils.data.ConcatDataset(test_datasets) 183 | 184 | train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs) 185 | test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs) 186 | 187 | model_kwargs = { 188 | "n_channels": cfg["UNET"]["MODEL"]["N_CHANNELS"], 189 | "n_classes": cfg["UNET"]["MODEL"]["N_CLASSES"], 190 | } 191 | 192 | # Load network. 193 | net = getattr(model, cfg["UNET"]["MODEL"]["TYPE"])(**model_kwargs).to(device) 194 | 195 | optimizer = optim.Adam( 196 | net.parameters(), lr=cfg["UNET"]["TRAIN"]["LR"], betas=(0.8, 0.9) 197 | ) 198 | 199 | train_loss_array = [] 200 | test_loss_array = [] 201 | 202 | least_test_loss = np.inf 203 | for epoch in range(1, cfg["UNET"]["TRAIN"]["EPOCHS"] + 1): 204 | train_loss = train(net, device, train_loader, optimizer, epoch) 205 | test_loss = test(net, device, test_loader) 206 | print(train_loss, test_loss) 207 | 208 | train_loss_array.append(train_loss) 209 | test_loss_array.append(test_loss) 210 | 211 | plt.plot(np.array(train_loss_array), "b", label="Train Loss") 212 | plt.plot(np.array(test_loss_array), "r", label="Test Loss") 213 | plt.scatter( 214 | np.argmin(np.array(test_loss_array)), 215 | np.min(test_loss_array), 216 | s=30, 217 | color="green", 218 | ) 219 | plt.title("Loss Plot, min:{:.3f}".format(np.min(test_loss_array))) 220 | plt.legend() 221 | plt.grid() 222 | plt.ylim(bottom=0) 223 | plt.xlim(left=0) 224 | plt.savefig(os.path.join(train_dir, "loss.jpg")) 225 | plt.close() 226 | # scheduler.step() 227 | if test_loss < least_test_loss: 228 | least_test_loss = test_loss 229 | torch.save( 230 | { 231 | "model_name": cfg["UNET"]["MODEL"]["NAME"], 232 | "model_type": type(net).__name__, 233 | "model_state_dict": net.state_dict(), 234 | "model_kwargs": model_kwargs, 235 | "epoch": epoch, 236 | "test_loss": test_loss, 237 | }, 238 | os.path.join(train_dir, f"{cfg['UNET']['MODEL']['NAME']}.pth"), 239 | ) 240 | --------------------------------------------------------------------------------
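The UNet scripts weight binary cross-entropy against a dice term imported from `radarize.unet.dice_score`, which is not reproduced in this listing. For orientation, a generic soft-dice loss looks like the sketch below; the repo's implementation may differ in details such as smoothing and per-class averaging:

```python
import torch

def dice_loss_sketch(pred, target, eps=1e-6):
    """Generic soft-dice loss for (batch, C, H, W) probability maps."""
    # Overlap and size terms per channel, summed over batch and spatial dims.
    dims = (0, 2, 3)
    intersection = (pred * target).sum(dim=dims)
    cardinality = pred.sum(dim=dims) + target.sum(dim=dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    # Loss is 1 - mean dice over channels.
    return 1.0 - dice.mean()

# Perfect prediction -> loss ~ 0; fully disjoint prediction -> loss ~ 1.
t = torch.zeros(1, 2, 4, 4)
t[:, 1, :2, :] = 1.0
t[:, 0] = 1.0 - t[:, 1]
print(dice_loss_sketch(t, t))        # ~0.0
print(dice_loss_sketch(1.0 - t, t))  # ~1.0
```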