├── rl_l2o
│   ├── __init__.py
│   ├── utils.py
│   ├── compute_reward.py
│   ├── predict_reward.py
│   └── eps_greedy_search.py
├── object_detection
│   ├── __init__.py
│   ├── data
│   │   └── features_pcl.pkl
│   ├── scripts
│   │   ├── train_pointpillar.sh
│   │   └── train_pointrcnn.sh
│   └── compute_reward.py
├── overview.png
├── requirements.txt
├── .gitmodules
├── main.py
├── .gitignore
├── README.md
└── LICENSE
/rl_l2o/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /object_detection/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vniclas/lidar_beam_selection/HEAD/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.5.0 2 | torchvision==0.6.0 3 | numpy==1.20 4 | sympy 5 | tqdm 6 | pykdtree -------------------------------------------------------------------------------- /object_detection/data/features_pcl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vniclas/lidar_beam_selection/HEAD/object_detection/data/features_pcl.pkl -------------------------------------------------------------------------------- /object_detection/scripts/train_pointpillar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /home/USER/git/lidar_beam_selection/third_party/OpenPCDet/tools || exit 4 | 5 | bash scripts/dist_train.sh 4 --cfg_file cfgs/kitti_models/pointpillar.yaml --ckpt_save_interval 40 --extra_tag $1 6 | -------------------------------------------------------------------------------- /object_detection/scripts/train_pointrcnn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /home/USER/git/lidar_beam_selection/third_party/OpenPCDet/tools || exit 4 | 5 | bash scripts/dist_train.sh 4 --cfg_file cfgs/kitti_models/pointrcnn.yaml --ckpt_save_interval 40 --extra_tag $1 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/OpenPCDet"] 2 | path = third_party/OpenPCDet 3 | url = https://github.com/open-mmlab/OpenPCDet.git 4 | [submodule "third_party/Pseudo_Lidar_V2"] 5 | path = third_party/Pseudo_Lidar_V2 6 | url = https://github.com/mileyan/Pseudo_Lidar_V2.git -------------------------------------------------------------------------------- /rl_l2o/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Colors: 6 | HEADER = '\033[95m' 7 | OKBLUE = '\033[94m' 8 | OKCYAN = '\033[96m' 9 | OKGREEN = '\033[92m' 10 | WARNING = '\033[93m' 11 | FAIL = '\033[91m' 12 | ENDC = '\033[0m' 13 | BOLD = '\033[1m' 14 | UNDERLINE = '\033[4m' 15 | -------------------------------------------------------------------------------- /rl_l2o/compute_reward.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from copy
import deepcopy 3 | from typing import Any, Dict, List, Tuple, Optional 4 | 5 | import numpy as np 6 | 7 | 8 | class RewardComputer(ABC): 9 | def __init__(self, config: Optional[Dict[str, Any]] = None): 10 | self.config = deepcopy(self.DEFAULT_CONFIG) 11 | if config: 12 | self.config.update(config) 13 | self.num_reward_signals = None 14 | 15 | self.cache = [] # (beams, reward) 16 | 17 | def set_number_reward_signals(self, num_reward_signals: int = 1): 18 | self.num_reward_signals = num_reward_signals 19 | 20 | def compute(self, state: np.ndarray) -> Tuple[float, np.ndarray]: 21 | assert self.num_reward_signals is not None 22 | 23 | sorted_state = np.sort(state) 24 | for pair in self.cache: 25 | if np.all(pair[0] == sorted_state): 26 | return float(pair[1].mean()), pair[1] 27 | 28 | reward_signals = self.compute_reward(sorted_state.tolist(), **self.config) 29 | # Ensure that the subclass returns one value per reward signal 30 | assert reward_signals.shape == (self.num_reward_signals,) 31 | 32 | self.cache.append((sorted_state, reward_signals)) 33 | return float(reward_signals.mean()), reward_signals 34 | 35 | def load_cache(self, cache: List[Tuple[np.ndarray, np.ndarray]]): 36 | self.cache = cache 37 | 38 | @property 39 | @classmethod 40 | @abstractmethod 41 | def DEFAULT_CONFIG(cls): 42 | raise NotImplementedError 43 | 44 | @staticmethod 45 | @abstractmethod 46 | def compute_reward() -> np.ndarray: 47 | pass 48 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from object_detection.compute_reward import RewardComputerObjectDetection 4 | from rl_l2o.eps_greedy_search import EpsGreedySearch 5 | 6 | detector = 'pointpillar' 7 | # detector = 'pointrcnn' 8 | 9 | # Overwrite default config parameters 10 | config = { 11 | 'checkpoint_files_path': Path(__file__).absolute().parent / f'checkpoints_{detector}', 12 | 'compute_reward': { 13 | 'plv2_dir': '/home/USER/git/lidar_beam_selection/third_party/Pseudo_Lidar_V2/gdc', 14 | 'opcd_dir': '/home/USER/git/lidar_beam_selection/third_party/OpenPCDet', 15 | 'kitti_dir': '/home/USER/data/kitti/training', 16 | 'output_dir': '/home/USER/data/pseudo_lidar_v2', 17 | 'pred_path': '/home/USER/data/pseudo_lidar_v2/sdn_kitti_train_set/depth_maps/trainval', 18 | 'detector': detector, 19 | } 20 | } 21 | logfile = Path(__file__).parent / f'log_{detector}.txt' 22 | 23 | features_pcl_file = Path(__file__).absolute().parent / 'object_detection' / 'data' / 'features_pcl.pkl' 24 | reward_computer = RewardComputerObjectDetection(config['compute_reward']) 25 | 26 | # To start a new run from scratch 27 | search = EpsGreedySearch(features_pcl_file, reward_computer, config, logfile) 28 | 29 | # To resume a previous run 30 | # checkpoint_file = Path(__file__).parent / f'checkpoints_{detector}' / 'checkpoint_100.pkl' 31 | # search.load_checkpoint(checkpoint_file, continue_searching=True) 32 | 33 | # To load the cache containing the true reward 34 | # search.load_reward_computation_cache(checkpoint_file) 35 | 36 | # Start the search 37 | search.run() 38 | 39 | # Display the best beam configuration 40 | best_state = search.best_state(return_reward=True) 41 | print(f'\033[96m RESULT: state={best_state[0]}, reward={best_state[1]:.3f} \033[0m') # Colored as OKCYAN 42 | -------------------------------------------------------------------------------- /.gitignore:
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Data files 132 | *.obj 133 | *.pkl 134 | 135 | # Pycharm 136 | .idea 137 | 138 | # VS Code 139 | .vscode 140 | 141 | # Project specific 142 | object_detection/checkpoints/* 143 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # End-to-End Optimization of LiDAR Beam Configuration 2 | [**arXiv**](https://arxiv.org/abs/2201.03860) | [**IEEE Xplore**](https://ieeexplore.ieee.org/abstract/document/9681305) | [**Video**](https://youtu.be/7fl0gLuvhZc) 3 | 4 | This repository is the official implementation of the paper: 5 | 6 | > **End-to-End Optimization of LiDAR Beam Configuration for 3D Object Detection and Localization** 7 | > 8 | > [Niclas Vödisch](https://vniclas.github.io/), [Ozan Unal](https://vision.ee.ethz.ch/people-details.MjA5ODkz.TGlzdC8zMjg5LC0xOTcxNDY1MTc4.html), [Ke Li](https://icu.ee.ethz.ch/people/person-detail.ke-li.html), [Luc Van Gool](https://vision.ee.ethz.ch/people-details.OTAyMzM=.TGlzdC8zMjQ4LC0xOTcxNDY1MTc4.html), and [Dengxin Dai](https://people.ee.ethz.ch/~daid/). 9 | > 10 | > *IEEE Robotics and Automation Letters (RA-L)*, vol. 7, no. 2, pp. 2242-2249, April 2022 11 | 12 |
13 | ![Overview of 3D object detection](overview.png)
14 |
15 | 16 | If you find our work useful, please consider citing our paper: 17 | ``` 18 | @ARTICLE{Voedisch_2022_RAL, 19 | author={Vödisch, Niclas and Unal, Ozan and Li, Ke and Van Gool, Luc and Dai, Dengxin}, 20 | journal={IEEE Robotics and Automation Letters}, 21 | title={End-to-End Optimization of LiDAR Beam Configuration for 3D Object Detection and Localization}, 22 | year={2022}, 23 | volume={7}, 24 | number={2}, 25 | pages={2242-2249}, 26 | doi={10.1109/LRA.2022.3142738}} 27 | ``` 28 | 29 | ## 📔 Abstract 30 | 31 | Pre-determined beam configurations of low-resolution LiDARs are task-agnostic, hence simply using them can result in suboptimal performance. 32 | In this work, we propose to optimize the beam distribution for a given target task via a reinforcement learning-based learning-to-optimize (RL-L2O) framework. 33 | We design our method in an end-to-end fashion leveraging the final performance of the task to guide the search process. 34 | Due to the simplicity of our approach, our work can be integrated with any LiDAR-based application as a drop-in module. 35 | In this repository, we provide the code for the exemplary task of 3D object detection. 36 | 37 | 38 | ## 🏗️ Setup 39 | 40 | To clone this repository and all submodules, run: 41 | ```shell 42 | git clone --recurse-submodules -j8 git@github.com:vniclas/lidar_beam_selection.git 43 | ``` 44 | 45 | ### ⚙️ Installation 46 | 47 | To install this code, please follow the steps below: 48 | 1. Create a conda environment: `conda create -n beam_selection python=3.8` 49 | 2. Activate the environment: `conda activate beam_selection` 50 | 3. Install dependencies: `pip install -r requirements.txt` 51 | 4. Install cudatoolkit *(change to the used CUDA version)*:
52 | `conda install cudnn cudatoolkit=10.2` 53 | 5. Install [spconv](https://github.com/traveller59/spconv#install) *(change to the used CUDA version)*:
54 | `pip install spconv-cu102` 55 | 6. Install [OpenPCDet](https://github.com/open-mmlab/OpenPCDet) *(linked as submodule)*:
56 | `cd third_party/OpenPCDet && python setup.py develop && cd ../..` 57 | 7. Install [Pseudo-LiDAR++](https://github.com/mileyan/Pseudo_Lidar_V2) *(linked as submodule)*:
58 | `pip install -r third_party/Pseudo_Lidar_V2/requirements.txt`
59 | `pip install pillow==8.3.2` *(avoid runtime warnings)* 60 | 61 | ### 💾 Data Preparation 62 | 63 | 1. Download [KITTI 3D Object Detection dataset](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) and extract the files: 64 | 1. Left color images `image_2` 65 | 2. Right color images `image_3` 66 | 3. Velodyne point clouds `velodyne` 67 | 4. Camera calibration matrices `calib` 68 | 5. Training labels `label_2` 69 | 2. Predict the depth maps: 70 | 1. Download [pretrained model (training+validation)](https://github.com/mileyan/Pseudo_Lidar_V2#pretrained-models) 71 | 2. Generate the data: 72 | ```shell 73 | cd third_party/Pseudo_Lidar_V2 74 | python ./src/main.py -c src/configs/sdn_kitti_train.config \ 75 | --resume PATH_TO_CHECKPOINTS/sdn_kitti_object_trainval.pth --datapath PATH_TO_KITTI/training/ \ 76 | --data_list ./split/trainval.txt --generate_depth_map --data_tag trainval \ 77 | --save_path PATH_TO_DATA/sdn_kitti_train_set 78 | ``` 79 | **Note:** Please adjust the paths `PATH_TO_CHECKPOINTS`, `PATH_TO_KITTI`, and `PATH_TO_DATA` to match your setup. 80 | 3. Rename `training/velodyne` to `training/velodyne_original` 81 | 4. Symlink the KITTI folders to PCDet: 82 | * `ln -s PATH_TO_KITTI/training third_party/OpenPCDet/data/kitti/training` 83 | * `ln -s PATH_TO_KITTI/testing third_party/OpenPCDet/data/kitti/testing` 84 | 85 | 86 | ## 🏃 Running 3D Object Detection 87 | 88 | 1. Adjust paths in [`main.py`](main.py). Further available parameters are listed in [`rl_l2o/eps_greedy_search.py`](rl_l2o/eps_greedy_search.py) and can be added in `main.py`. 89 | 2. Adjust the number of epochs of the 3D object detector *(we used 40 epochs)* in: 90 | - [`object_detection/compute_reward.py`](object_detection/compute_reward.py) --> above the class definition 91 | - [`third_party/OpenPCDet/tools/cfgs/kitti_models/pointpillar.yaml`](third_party/OpenPCDet/tools/cfgs/kitti_models/pointpillar.yaml) --> search for `NUM_EPOCHS`
92 | **Note:** If you use another detector, modify the respective configuration file. 93 | 3. Adjust the training scripts of the utilized detector to match your setup, e.g., [`object_detection/scripts/train_pointpillar.sh`](object_detection/scripts/train_pointpillar.sh). 94 | 4. Initiate the search: `python main.py`
95 | **Note:** Since we keep intermediate results to easily re-use them in later iterations, running the script will create a lot of data in the `output_dir` specified in [`main.py`](main.py). You might want to manually delete some folders from time to time. 96 | 97 | 98 | ## 🔧 Adding more Tasks 99 | 100 | Due to the design of the RL-L2O framework, it can be used as a simple drop-in module for many LiDAR applications. 101 | To apply the search algorithm to another task, just implement a custom `RewardComputer` following the sketch below; for a complete example, see [`object_detection/compute_reward.py`](object_detection/compute_reward.py). 102 | Additionally, you will have to prepare a set of features for each LiDAR beam. 103 | For the KITTI 3D Object Detection dataset, we provide the features as presented in the paper in [`object_detection/data/features_pcl.pkl`](object_detection/data/features_pcl.pkl). 104 |
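A `RewardComputer` subclass only has to provide a `DEFAULT_CONFIG` dictionary and a static `compute_reward` method: during the search, `compute_reward` is called with the sorted list of selected beam IDs plus the entries of `DEFAULT_CONFIG` as keyword arguments, and it must return an array with one entry per reward signal, each scaled to [0, 1] (the reward predictor squeezes its output through a sigmoid). A minimal sketch is shown below; `RewardComputerMyTask`, the `data_dir` config key, and the `evaluate_my_task` helper are placeholders for your own task:

```python
import numpy as np

from rl_l2o.compute_reward import RewardComputer as RewardComputerBase


def evaluate_my_task(beams: list, data_dir: str) -> float:
    # Placeholder: run your task with the given beam configuration and
    # return its final performance, scaled to [0, 1].
    return 0.5


class RewardComputerMyTask(RewardComputerBase):
    # The entries of this dict are passed to compute_reward() as keyword arguments.
    DEFAULT_CONFIG = {'data_dir': ''}

    @staticmethod
    def compute_reward(beams: list, data_dir: str) -> np.ndarray:
        # "beams" holds the sorted IDs of the currently selected LiDAR beams.
        # Return one value per reward signal, i.e., an array of shape
        # (num_reward_signals,).
        performance = evaluate_my_task(beams, data_dir)
        return np.array([performance])
```

The new reward computer can then be passed to `EpsGreedySearch` in the same way as `RewardComputerObjectDetection` in [`main.py`](main.py). For the per-beam features, [`rl_l2o/predict_reward.py`](rl_l2o/predict_reward.py) expects a pickled dictionary with the per-beam entries `number_points_avg`, `mean_d_avg`, `mean_d_std_avg`, `semantic_classes_avg`, and `elevation_angles_deg`; see `RewardPredictor._load_features` for how they are consumed.

105 | 106 | ## 👩‍⚖️ License 107 | 108 | Creative Commons License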
109 | This software is made available for non-commercial use under a [Creative Commons Attribution-NonCommercial 4.0 International License](LICENSE). A summary of the license can be found on the [Creative Commons website](http://creativecommons.org/licenses/by-nc/4.0). 110 | -------------------------------------------------------------------------------- /rl_l2o/predict_reward.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | from os import PathLike 4 | from typing import Dict, Any, Optional, Union, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | from sympy.utilities.iterables import multiset_permutations 9 | from torch import nn, optim 10 | from torch.utils.data import Dataset, DataLoader 11 | from tqdm import tqdm 12 | 13 | 14 | class MyDataset(Dataset): 15 | def __init__(self, data: Dict[str, np.ndarray], features_pcl: np.ndarray, elevation_angles_deg: np.ndarray): 16 | super().__init__() 17 | self.features_pcl = features_pcl.astype(np.float32) 18 | self.elevation_angles_deg = elevation_angles_deg.astype(np.float32) 19 | self.state = data['state'].astype(np.int) 20 | self.reward = data['reward'].astype(np.float32) 21 | 22 | def __len__(self): 23 | return self.state.shape[0] 24 | 25 | def __getitem__(self, item): 26 | state = self.state[item, :] 27 | reward = self.reward[item, :] 28 | 29 | features_pcl = self.features_pcl[state] 30 | pairwise_features = np.c_[np.diff(self.elevation_angles_deg[np.sort(state)])] # permutation invariant 31 | features_pcl = np.r_[features_pcl.flatten(), pairwise_features.flatten()] 32 | 33 | sample = {'state': state, 'features_pcl': features_pcl, 'reward': reward} 34 | return sample 35 | 36 | 37 | class MyNetwork(nn.Module): 38 | def __init__(self, output_shape): 39 | super().__init__() 40 | self.INPUT_SHAPE = 39 41 | self.output_shape = output_shape 42 | 43 | self.net = nn.Sequential( 44 | nn.Flatten(), 45 | nn.Linear(self.INPUT_SHAPE, 128), 46 | nn.ReLU(), 47 | nn.Linear(128, 64), 48 | nn.ReLU(), 49 | nn.Linear(64, self.output_shape), 50 | nn.Sigmoid() # Squeeze between 0 and 1 since it is accuracy 51 | ) 52 | 53 | def forward(self, x): 54 | return self.net(x) 55 | 56 | 57 | class RewardPredictor: 58 | DEFAULT_CONFIG = {'num_training_epochs': 10, 'training_batch_size': 8, 'device': None} 59 | 60 | def __init__(self, 61 | num_selected_beams: int, 62 | num_reward_signals: int, 63 | features_pcl_file: Union[str, PathLike], 64 | config: Optional[Dict[str, Any]] = None): 65 | self.config = RewardPredictor.DEFAULT_CONFIG 66 | if config: 67 | self.config.update(config) 68 | self.num_selected_beams = num_selected_beams 69 | self.num_reward_signals = num_reward_signals 70 | self.features_pcl = None 71 | self.elevation_angles_deg = None 72 | self._load_features(features_pcl_file) 73 | 74 | if self.config['device'] is None: 75 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 76 | elif self.config['device'] == 'cpu': 77 | self.device = torch.device('cpu') 78 | elif self.config['device'] == 'cuda': 79 | if torch.cuda.is_available(): 80 | self.device = torch.device('cuda') 81 | else: 82 | raise RuntimeError('CUDA runtime not available.') 83 | else: 84 | raise ValueError(f'Invalid device: {self.config["device"]}. 
Valid options are: ["cpu", "cuda"].') 85 | print(f'Found device: {self.device}') 86 | 87 | self.regressor = MyNetwork(self.num_reward_signals) 88 | self.regressor.to(self.device) 89 | self.training_data = { 90 | 'state': np.empty((0, self.num_selected_beams), dtype=np.int), 91 | 'reward': np.empty((0, self.num_reward_signals), dtype=np.float32) 92 | } 93 | 94 | def load_training_data(self, training_data: Dict[str, np.ndarray]): 95 | self.training_data = training_data 96 | 97 | def add_to_training_data(self, state: np.ndarray, reward: np.ndarray): 98 | assert state.shape == (self.num_selected_beams,) 99 | assert reward.shape == (self.num_reward_signals,) 100 | 101 | if not self._is_in_training_data(state): 102 | self.training_data['state'] = np.append(self.training_data['state'], 103 | np.reshape(np.sort(state), (1, self.num_selected_beams)), 104 | axis=0) 105 | self.training_data['reward'] = np.append(self.training_data['reward'], 106 | np.reshape(reward, (1, self.num_reward_signals)), 107 | axis=0) 108 | 109 | def train(self, permute: bool = False, reset_weights: bool = True, verbose: bool = False): 110 | if reset_weights: 111 | self.regressor = MyNetwork(self.num_reward_signals) 112 | self.regressor.to(self.device) 113 | 114 | states, rewards = self._permute_training_data(permute) 115 | training_dataset = MyDataset({'state': states, 'reward': rewards}, self.features_pcl, self.elevation_angles_deg) 116 | training_dataloader = DataLoader(training_dataset, batch_size=self.config['training_batch_size'], shuffle=True) 117 | 118 | criterion = nn.MSELoss(reduction='sum') 119 | optimizer = optim.Adam(self.regressor.parameters(), lr=.001, weight_decay=.0001) 120 | 121 | self.regressor.train() 122 | with tqdm(training_dataloader, 123 | disable=not verbose, 124 | unit='batch', 125 | total=self.config['num_training_epochs'] * len(training_dataloader), 126 | desc=f'Training | {self.config["num_training_epochs"]} epochs') as pbar: 127 | epoch_loss = 0 128 | for epoch in range(self.config['num_training_epochs']): 129 | train_loss = 0 130 | for sample_batched in training_dataloader: 131 | optimizer.zero_grad() 132 | features_pcl = sample_batched['features_pcl'].to(self.device) 133 | true_reward = sample_batched['reward'].to(self.device) 134 | predicted_reward = self.regressor(features_pcl) 135 | loss = criterion(predicted_reward, true_reward) 136 | 137 | train_loss += loss.item() 138 | 139 | loss.backward() 140 | optimizer.step() 141 | 142 | pbar.set_postfix(epoch=epoch, epoch_loss=epoch_loss, batch_loss=loss.item()) 143 | pbar.update(1) 144 | epoch_loss = train_loss / len(training_dataloader) 145 | 146 | def predict(self, state: np.ndarray) -> float: 147 | self.regressor.eval() 148 | with torch.no_grad(): 149 | features_pcl = self._state_to_features(state) 150 | features_pcl = torch.from_numpy(features_pcl).to(self.device) 151 | predicted_reward = self.regressor(features_pcl).cpu().detach().numpy() 152 | return float(predicted_reward.mean()) 153 | 154 | def _is_in_training_data(self, state: np.ndarray): 155 | assert state.shape == (self.num_selected_beams,) 156 | return (np.sort(state) == np.sort(self.training_data['state'])).all(axis=1).any() 157 | 158 | def _permute_training_data(self, permute: bool = True, in_place: bool = False) -> Tuple[np.ndarray, np.ndarray]: 159 | if permute: 160 | state, reward = permute_dataset(self.training_data['state'], self.training_data['reward']) 161 | if in_place: 162 | self.training_data['state'], self.training_data['reward'] = state, reward 163 | else: 164 | state, 
reward = self.training_data['state'], self.training_data['reward'] 165 | return state, reward 166 | 167 | def _load_features(self, features_pcl_file: Union[str, PathLike]): 168 | with open(features_pcl_file, 'rb') as f: 169 | features_pcl_raw = pickle.load(f) 170 | 171 | features_pcl = np.c_[np.fromiter(features_pcl_raw['number_points_avg'].values(), dtype=np.float32), 172 | np.fromiter(features_pcl_raw['mean_d_avg'].values(), dtype=np.float32), 173 | np.fromiter(features_pcl_raw['mean_d_std_avg'].values(), dtype=np.float32), 174 | np.array([ 175 | features_pcl_raw['semantic_classes_avg'][beam_id] 176 | for beam_id in np.arange(0, len(features_pcl_raw['semantic_classes_avg'])) 177 | ]) / np.fromiter(features_pcl_raw['number_points_avg'].values(), dtype=np.float32).reshape( 178 | (-1, 1))].astype(np.float32) 179 | self.elevation_angles_deg = np.fromiter(features_pcl_raw['elevation_angles_deg'].values(), dtype=np.float32) 180 | 181 | # Standardize features 182 | features_pcl = (features_pcl - features_pcl.mean(axis=0)) / features_pcl.std(axis=0) 183 | features_pcl = np.c_[features_pcl, self.elevation_angles_deg] 184 | 185 | # Save for future usage 186 | self.features_pcl = features_pcl 187 | 188 | def _state_to_features(self, state: np.ndarray): 189 | assert state.shape == (self.num_selected_beams,) 190 | features_pcl = self.features_pcl[state] 191 | pairwise_features = np.c_[np.diff(self.elevation_angles_deg[np.sort(state)])] # permutation invariant 192 | features = np.r_[features_pcl.flatten(), pairwise_features.flatten()] 193 | return features.reshape((1, -1)) 194 | 195 | 196 | def permute_dataset(state, metrics): 197 | active_beams_ret = [] 198 | metrics_ret = [] 199 | 200 | for b, m in zip(state, metrics): 201 | for p in multiset_permutations(b): 202 | active_beams_ret.append(p) 203 | metrics_ret.append(m) 204 | 205 | tmp = list(zip(active_beams_ret, metrics_ret)) 206 | random.shuffle(tmp) 207 | active_beams_ret, metrics_ret = zip(*tmp) 208 | 209 | active_beams_ret = np.vstack(active_beams_ret) 210 | metrics_ret = np.vstack(metrics_ret) 211 | return active_beams_ret, metrics_ret 212 | -------------------------------------------------------------------------------- /object_detection/compute_reward.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import subprocess 4 | import time 5 | from pathlib import Path 6 | from typing import Optional 7 | 8 | import numpy as np 9 | 10 | from rl_l2o.compute_reward import RewardComputer as RewardComputerBase 11 | 12 | # This number should match the value of NUM_EPOCHS in the corresponding detector, e.g, 13 | # third_party/OpenPCDet/tools/cfgs/kitti_models/pointpillar.yaml 14 | DETECTOR_NUM_EPOCHS = 40 15 | 16 | 17 | class RewardComputerObjectDetection(RewardComputerBase): 18 | DEFAULT_CONFIG = { 19 | 'plv2_dir': '', 20 | 'opcd_dir': '', 21 | 'kitti_dir': '', 22 | 'output_dir': '', 23 | 'pred_path': '', 24 | 'detector': 'pointpillar', 25 | 'num_threads': 24 26 | } 27 | 28 | @staticmethod 29 | def compute_reward( 30 | beams: list, # List of beam ids 31 | plv2_dir: str, # PseudoLidarv2/gdc directory 32 | opcd_dir: str, # OpenPCDet directory 33 | kitti_dir: str, # KITTI dataset directory 34 | output_dir: str, # Output directory 35 | pred_path: str, # Path of SDN predictions 36 | detector: str = 'pointpillar', # pointpillar or pointrcnn 37 | num_threads: int = 32) -> np.ndarray: 38 | assert len(beams) == 4, 'This implementation assumes exactly 4 beams.' 
39 | assert detector in ['pointpillar', 'pointrcnn'] 40 | 41 | # Get paths 42 | calib_path = os.path.join(kitti_dir, 'calib') 43 | image_path = os.path.join(kitti_dir, 'image_2') 44 | point_path = os.path.join(kitti_dir, 'velodyne_original') 45 | split_file = os.path.join(plv2_dir, 'image_sets/trainval.txt') 46 | tmp_dir = os.path.join(output_dir, f'tmp_dir_{detector}') 47 | 48 | tag = f'{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}' 49 | train_dir = os.path.join(opcd_dir, 'output', 'kitti_models', detector, tag) 50 | 51 | # If the results file already exists, skip the computation 52 | reward = _check_for_existing_logs(train_dir) 53 | if reward is not None: 54 | return reward 55 | 56 | # If the pseudo lidar data already exists, skip most of the pre-processing 57 | output_name = f'pseudo_lidar_{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}' 58 | output_path = os.path.join(output_dir, output_name) 59 | if not _does_data_exist(output_path): 60 | 61 | # Simulate 4 beam lidar 62 | print(f'Simulating 4 beam lidar with beam selection: {beams}...') 63 | output_name = f'velodyne_{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}' 64 | output_path = os.path.join(output_dir, output_name) 65 | velodyne_4_beam = f' \ 66 | python {plv2_dir}/sparsify.py \ 67 | --calib_path {calib_path} \ 68 | --image_path {image_path} \ 69 | --split_file {split_file} \ 70 | --ptc_path {point_path} \ 71 | --W 1024 \ 72 | --H 64 \ 73 | --line_spec {beams[0]} {beams[1]} {beams[2]} {beams[3]} \ 74 | --output_path {output_path} \ 75 | --store_line_map_dir {tmp_dir} \ 76 | --threads {num_threads} \ 77 | ' 78 | 79 | if not _does_data_exist(output_path): 80 | os.system(velodyne_4_beam) 81 | 82 | # Get ground truth depth map from original 4 beams 83 | print('Generating ground truth depth maps from the simulated 4 beam lidar...') 84 | input_path = output_path 85 | output_name = f'gt_depthmap_{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}' 86 | output_path = os.path.join(output_dir, output_name) 87 | gt_depthmap_4_beam = f' \ 88 | python {plv2_dir}/ptc2depthmap.py \ 89 | --output_path {output_path} \ 90 | --input_path {input_path} \ 91 | --calib_path {calib_path} \ 92 | --image_path {image_path} \ 93 | --split_file {split_file} \ 94 | --threads {num_threads} \ 95 | ' 96 | 97 | if not _does_data_exist(output_path): 98 | os.system(gt_depthmap_4_beam) 99 | 100 | # Run batch gdc using ground truth 4 beams on predicted depth maps 101 | print('Running batch GDC using 4 beam ground truth depth map on predicted depth maps...') 102 | input_path = output_path 103 | output_name = f'gdc_depthmap_{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}' 104 | output_path = os.path.join(output_dir, output_name) 105 | gdc_depthmap_4_beam = f' \ 106 | python {plv2_dir}/main_batch.py \ 107 | --output_path {output_path} \ 108 | --input_path {pred_path} \ 109 | --calib_path {calib_path} \ 110 | --gt_depthmap_path {input_path} \ 111 | --threads {num_threads} \ 112 | --split_file {split_file} \ 113 | ' 114 | 115 | if not _does_data_exist(output_path): 116 | os.system(gdc_depthmap_4_beam) 117 | 118 | # Get pseudo lidar from corrected depth 119 | print('Generating pseudo lidar point clouds from corrected depth maps...') 120 | input_path = output_path 121 | output_name = f'pseudo_lidar_{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}' 122 | output_path = os.path.join(output_dir, output_name) 123 | pseudo_lidar_4_beam = f' \ 124 | python {plv2_dir}/depthmap2ptc.py \ 125 | --output_path {output_path} \ 126 | --input_path {input_path} \ 127 | --calib_path {calib_path} \ 128 | --threads 
{num_threads} \ 129 | --split_file {split_file} \ 130 | ' 131 | 132 | if not _does_data_exist(output_path): 133 | os.system(pseudo_lidar_4_beam) 134 | 135 | # Sparsify to 64 lines 136 | print('Sparsifying pseudo lidar point cloud to 64 beams...') 137 | input_path = output_path 138 | output_path = os.path.join(kitti_dir, 'velodyne') 139 | sparse_pseudo_lidar_4_beam = f' \ 140 | python {plv2_dir}/sparsify.py \ 141 | --output_path {output_path} \ 142 | --calib_path {calib_path} \ 143 | --image_path {image_path} \ 144 | --ptc_path {input_path} \ 145 | --split_file {split_file} \ 146 | --W 1024 --slice 1 --H 64 \ 147 | --threads {num_threads} \ 148 | ' 149 | 150 | os.system(sparse_pseudo_lidar_4_beam) 151 | 152 | # This only needs to be done once 153 | # Generate info files and ground truth database for KITTI 154 | if not os.path.exists(f'{opcd_dir}/data/kitti/kitti_infos_trainval.pkl'): 155 | print('Generating KITTI file database...') 156 | kitti_database = f' \ 157 | python -m pcdet.datasets.kitti.kitti_dataset \ 158 | create_kitti_infos \ 159 | {opcd_dir}/tools/cfgs/dataset_configs/kitti_dataset.yaml \ 160 | ' 161 | 162 | os.system(kitti_database) 163 | 164 | # Start training the detector 165 | print('Submitting training job...') 166 | subprocess.Popen([str(Path(__file__).parent / 'scripts' / f'train_{detector}.sh'), tag]) 167 | 168 | # Check if job has started and look for log file 169 | reward = _check_for_existing_logs(train_dir, frequency=5) 170 | return reward 171 | 172 | 173 | def _check_for_existing_logs(train_dir: str, frequency: int = -1) -> Optional[np.ndarray]: 174 | if frequency > 0: 175 | print(f'Looking for job every {frequency} seconds in {train_dir}', end='', flush=True) 176 | log_file_path = None 177 | while True: 178 | for file_path in glob.glob(os.path.join(train_dir, 'log_train*.txt')): 179 | if os.path.getsize(file_path) > 0: 180 | log_file_path = file_path 181 | break 182 | if log_file_path is None and frequency > 0: 183 | time.sleep(frequency) 184 | elif log_file_path is not None: 185 | print(f'Log file located: {log_file_path}.\nChecking for performance of epoch {DETECTOR_NUM_EPOCHS}...') 186 | break 187 | else: 188 | return None 189 | 190 | # Continuously check log file for evaluation 191 | check_phrase = f'Performance of EPOCH {DETECTOR_NUM_EPOCHS}' 192 | check_index = -1 193 | car_ap = None 194 | pedestrian_ap = None 195 | cyclist_ap = None 196 | while True: 197 | with open(log_file_path, 'r', encoding='utf-8') as f: 198 | for i, line in enumerate(f): 199 | if check_phrase in line: 200 | check_index = i 201 | elif check_index == -1: 202 | continue 203 | if i == check_index + 12: # 3D Car AP 204 | car_ap = line 205 | elif i == check_index + 32: # 3D Pedestrian AP 206 | pedestrian_ap = line 207 | elif i == check_index + 52: # 3D Cyclist AP 208 | cyclist_ap = line 209 | break 210 | if car_ap is not None and pedestrian_ap is not None and cyclist_ap is not None: 211 | break 212 | if frequency > 0: 213 | print('.', end='', flush=True) 214 | time.sleep(30) 215 | 216 | assert '3d AP:' in car_ap, car_ap 217 | assert '3d AP:' in pedestrian_ap, pedestrian_ap 218 | assert '3d AP:' in cyclist_ap, cyclist_ap 219 | result = np.array([float(car_ap[8:-1].split(', ')[1]) / 100]) # 3D AP for car: moderate 220 | # result = np.array([float(j) / 100 for j in car_ap[8:-1].split(', ')]) # Get 3D AP for easy, medium, and hard 221 | # result = np.array([ 222 | # float(car_ap[8:-1].split(', ')[1]) / 100, 223 | # float(pedestrian_ap[8:-1].split(', ')[1]) / 100, 224 | # 
float(cyclist_ap[8:-1].split(', ')[1]) / 100 225 | # ]) # Get 3D AP for car, ped, and cyclists of moderate difficulty 226 | return result 227 | 228 | 229 | def _does_data_exist(path: str, number_expected_files: int = 7481) -> bool: 230 | if not os.path.exists(path): 231 | return False 232 | number_files = len(os.listdir(path)) 233 | if number_files == number_expected_files: 234 | return True 235 | return False 236 | -------------------------------------------------------------------------------- /rl_l2o/eps_greedy_search.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from copy import deepcopy 3 | from datetime import datetime 4 | from itertools import product 5 | from os import PathLike 6 | from pathlib import Path 7 | from typing import Dict, Any, Optional, Union, Tuple 8 | 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from rl_l2o.compute_reward import RewardComputer 13 | from rl_l2o.predict_reward import RewardPredictor 14 | from rl_l2o.utils import Colors 15 | 16 | 17 | class EpsGreedySearch: 18 | DEFAULT_CONFIG = { 19 | 'initial_beam_ids': None, # after initialization, start with these beams. If None, start with the best so far. 20 | 'epsilon': .1, # take a random action with this probability 21 | 'num_initial_samples': 5, # number of random samples to start training the value function predictor 22 | 'num_samples': 200, # total number of samples 23 | 'max_step_size': 2, # the beam IDs can be shifted a maximum of this number (pos/neg) 24 | 'seed': None, # used to initialize all random number generators 25 | 'num_selected_beams': 4, # number of beams selected for valid states 26 | 'num_reward_signals': 1, # number of environmental signals used to predict the overall reward 27 | 'min_beam_id': 1, # smallest valid beam ID 28 | 'max_beam_id': 40, # highest valid beam ID 29 | 'checkpoint_files_path': Path(__file__).absolute().parents[1] / 'checkpoints', 30 | 'predict_reward': { 31 | 'num_training_epochs': 10, 32 | 'training_batch_size': 8, 33 | 'device': None 34 | }, 35 | } 36 | 37 | def __init__( 38 | self, 39 | features_pcl_file: Union[str, PathLike], 40 | reward_computer: RewardComputer, 41 | config: Optional[Dict[str, Any]] = None, 42 | logfile: Union[str, PathLike] = 'log.txt', 43 | ): 44 | self.config = self.DEFAULT_CONFIG.copy() 45 | if config: 46 | self.config.update(config) 47 | if self.config['initial_beam_ids'] is not None: 48 | assert len(self.config['initial_beam_ids']) == self.config['num_selected_beams'] 49 | assert min(self.config['initial_beam_ids']) >= self.config['min_beam_id'] 50 | assert max(self.config['initial_beam_ids']) <= self.config['max_beam_id'] 51 | self.config['features_pcl_file'] = str(features_pcl_file) 52 | 53 | # Random number generators 54 | self.epsilon_generator = np.random.default_rng(seed=self.config['seed']) 55 | self.random_action_generator = np.random.default_rng(seed=self.config['seed']) 56 | self.random_state_generator = np.random.default_rng(seed=self.config['seed']) 57 | 58 | # Miscellaneous objects 59 | self.reward_predictor = RewardPredictor(self.config['num_selected_beams'], self.config['num_reward_signals'], 60 | self.config['features_pcl_file'], self.config['predict_reward']) 61 | self.reward_computer = reward_computer 62 | self.reward_computer.set_number_reward_signals(self.config['num_reward_signals']) 63 | self.log_history = {} # step_counter: (message, color) 64 | self.state_reward_history = [] # (beams, reward, predicted_reward) 65 | self.state_action_pairs = [] 
# (beams, steps) 66 | self.step_counter = 0 67 | self._continue_searching = False # can be set when loading a checkpoint 68 | 69 | # Logging 70 | self.logfile = Path(logfile) 71 | if self.logfile.exists(): 72 | prev_logfile = str(logfile).replace(self.logfile.suffix, f'_old{self.logfile.suffix}') 73 | self.logfile.rename(prev_logfile) 74 | print(f'{Colors.WARNING}WARNING: Found existing logfile and renamed it to "{prev_logfile}".{Colors.ENDC}') 75 | 76 | def run(self) -> np.ndarray: 77 | self._initialize_predictor() 78 | if self._continue_searching: 79 | state = self.state_reward_history[-1][0] 80 | elif self.config['initial_beam_ids'] is None: 81 | state = self.best_state() 82 | else: 83 | if self.step_counter < self.config['num_samples']: 84 | self.step_counter += 1 85 | state = np.asarray(self.config['initial_beam_ids'], dtype=np.int) 86 | reward, reward_signals = self.reward_computer.compute(state) 87 | self.reward_predictor.add_to_training_data(state, reward_signals) 88 | self.reward_predictor.train(permute=True) 89 | is_best_state = self._add_to_state_reward_history(state, reward) 90 | self._write_to_log(state, reward, is_best_state=is_best_state) 91 | self.save_checkpoint() 92 | 93 | # Main loop 94 | while self.step_counter < self.config['num_samples']: 95 | self.step_counter += 1 96 | 97 | if self.epsilon_generator.random() < self.config['epsilon']: 98 | action = self._get_random_action(state) 99 | predicted_reward = self.reward_predictor.predict(state) 100 | else: 101 | action, predicted_reward = self._get_best_action(state) 102 | state = self._apply_action(state, action) 103 | assert self._is_valid_state(state), state 104 | if self._is_in_state_action_pairs(state, action): 105 | state = self._get_random_state(self.random_state_generator) 106 | self._write_message_to_log( 107 | 'WARNING: Resampled new random state due to a loop in the state-action-history.', Colors.WARNING, 108 | True) 109 | 110 | reward, reward_signals = self.reward_computer.compute(state) 111 | self.reward_predictor.add_to_training_data(state, reward_signals) 112 | self.reward_predictor.train(permute=True) 113 | 114 | is_best_state = self._add_to_state_reward_history(state, reward, predicted_reward) 115 | self._write_to_log(state, reward, predicted_reward, is_best_state) 116 | self.save_checkpoint() 117 | 118 | return self.best_state() 119 | 120 | def best_state(self, return_reward: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, float]]: 121 | state_reward = np.array([[x[0], x[1]] for x in self.state_reward_history], dtype=object) 122 | best_reward_index = int(np.argmax(state_reward[:, 1])) 123 | best_state_reward = self.state_reward_history[best_reward_index] 124 | best_state_reward = (np.sort(best_state_reward[0]), best_state_reward[1]) 125 | if return_reward: 126 | return best_state_reward 127 | return best_state_reward[0] 128 | 129 | def save_checkpoint(self): 130 | if not self.config['checkpoint_files_path'].exists(): 131 | self.config['checkpoint_files_path'].mkdir(parents=True, exist_ok=True) 132 | checkpoint_file = self.config['checkpoint_files_path'] / f'checkpoint_{str(self.step_counter).zfill(3)}.pkl' 133 | checkpoint = { 134 | 'config': self.config, 135 | 'state_reward_history': self.state_reward_history, 136 | 'state_action_pairs': self.state_action_pairs, 137 | 'step_counter': self.step_counter, 138 | 'log_history': self.log_history, 139 | 'epsilon_generator': self.epsilon_generator.__getstate__(), 140 | 'random_action_generator': self.random_action_generator.__getstate__(), 141 | 
'random_state_generator': self.random_state_generator.__getstate__(), 142 | 'reward_computer_cache': self.reward_computer.cache, 143 | 'reward_predictor_training_data': self.reward_predictor.training_data, 144 | } 145 | with open(checkpoint_file, 'wb') as f: 146 | pickle.dump(checkpoint, f, pickle.DEFAULT_PROTOCOL) 147 | 148 | def load_checkpoint(self, checkpoint_file: Union[str, PathLike], continue_searching: bool = True): 149 | with open(checkpoint_file, 'rb') as f: 150 | checkpoint = pickle.load(f) 151 | self._continue_searching = continue_searching 152 | 153 | # Check for compatible setups 154 | assert self.config['num_selected_beams'] == checkpoint['config']['num_selected_beams'] 155 | 156 | # We do not load paths and total number of samples from checkpoint files 157 | config = deepcopy(self.config) 158 | self.config = checkpoint['config'] 159 | self.config['compute_reward'] = config['compute_reward'] 160 | self.config['checkpoint_files_path'] = config['checkpoint_files_path'] 161 | self.config['features_pcl_file'] = config['features_pcl_file'] 162 | self.config['num_samples'] = config['num_samples'] 163 | 164 | self.state_action_pairs = checkpoint['state_action_pairs'] 165 | self.step_counter = checkpoint['step_counter'] 166 | if 'log_history' in checkpoint: 167 | self.log_history = checkpoint['log_history'] 168 | self.epsilon_generator.__setstate__(checkpoint['epsilon_generator']) 169 | self.random_action_generator.__setstate__(checkpoint['random_action_generator']) 170 | self.random_state_generator.__setstate__(checkpoint['random_state_generator']) 171 | 172 | self.reward_predictor = RewardPredictor(self.config['num_selected_beams'], self.config['num_reward_signals'], 173 | self.config['features_pcl_file'], self.config['predict_reward']) 174 | self.reward_predictor.load_training_data(checkpoint['reward_predictor_training_data']) 175 | self.reward_computer.load_cache(checkpoint['reward_computer_cache']) 176 | 177 | for step_counter_replay, state_reward in enumerate(checkpoint['state_reward_history'], start=1): 178 | if step_counter_replay in self.log_history: 179 | self._write_message_to_log(self.log_history[step_counter_replay][0], 180 | self.log_history[step_counter_replay][1]) 181 | if len(state_reward) == 2: 182 | state_reward = (state_reward[0], state_reward[1], None) 183 | is_best_state = self._add_to_state_reward_history(state_reward[0], state_reward[1], state_reward[2]) 184 | self._write_to_log(state_reward[0], state_reward[1], state_reward[2], is_best_state, step_counter_replay) 185 | 186 | self._write_message_to_log('INFO: Resuming from a checkpoint.', Colors.OKBLUE, True) 187 | 188 | def load_reward_prediction_cache(self, checkpoint_file: Union[str, PathLike]): 189 | with open(checkpoint_file, 'rb') as f: 190 | checkpoint = pickle.load(f) 191 | 192 | # Check for compatible setups 193 | assert self.config['num_selected_beams'] == checkpoint['config']['num_selected_beams'] 194 | 195 | self.reward_predictor.load_training_data(checkpoint['reward_predictor_training_data']) 196 | 197 | def load_reward_computation_cache(self, checkpoint_file: Union[str, PathLike]): 198 | with open(checkpoint_file, 'rb') as f: 199 | checkpoint = pickle.load(f) 200 | 201 | # Check for compatible setups 202 | assert self.config['num_selected_beams'] == checkpoint['config']['num_selected_beams'] 203 | 204 | self.reward_computer.load_cache(checkpoint['reward_computer_cache']) 205 | 206 | def _initialize_predictor(self): 207 | while self.step_counter < self.config['num_initial_samples']: 208 | 
self.step_counter += 1 209 | state = self._get_random_state(self.random_state_generator) 210 | reward, reward_signals = self.reward_computer.compute(state) 211 | self.reward_predictor.add_to_training_data(state, reward_signals) 212 | is_best_state = self._add_to_state_reward_history(state, reward) 213 | self._write_to_log(state, reward, is_best_state=is_best_state) 214 | self.save_checkpoint() 215 | self.reward_predictor.train(permute=True) 216 | 217 | def _is_valid_state(self, state: np.ndarray) -> bool: 218 | check = np.unique(state).size == self.config['num_selected_beams'] 219 | check &= state.min() >= self.config['min_beam_id'] 220 | check &= state.max() <= self.config['max_beam_id'] 221 | return check 222 | 223 | def _get_random_state(self, random_number_generator: np.random.Generator) -> np.ndarray: 224 | state = np.empty((1,)) 225 | while not self._is_valid_state(state): 226 | state = np.unique( 227 | np.sort( 228 | random_number_generator.integers(self.config['min_beam_id'], 229 | self.config['max_beam_id'], 230 | endpoint=True, 231 | size=(self.config['num_selected_beams'],)))) 232 | return state 233 | 234 | def _get_random_action(self, state: Optional[np.ndarray] = None) -> np.ndarray: 235 | action = self.random_action_generator.integers(-self.config['max_step_size'], 236 | self.config['max_step_size'], 237 | endpoint=True, 238 | size=(self.config['num_selected_beams'],)) 239 | # If a state is provided, only valid actions will be returned. 240 | while state is not None and not self._is_valid_state(self._apply_action(state, action)): 241 | action = self.random_action_generator.integers(-self.config['max_step_size'], 242 | self.config['max_step_size'], 243 | endpoint=True, 244 | size=(self.config['num_selected_beams'],)) 245 | return action 246 | 247 | def _get_best_action(self, state: np.ndarray, show_pbar: bool = False) -> Tuple[np.ndarray, Optional[float]]: 248 | number_selected_beams = self.config['num_selected_beams'] 249 | actions = [] 250 | rewards = [] 251 | valid_steps = list(range(-self.config['max_step_size'], self.config['max_step_size'] + 1)) 252 | with tqdm(total=len(valid_steps) ** number_selected_beams, desc='Predicting best action', 253 | disable=not show_pbar) as pbar: 254 | for action in product(*[valid_steps] * number_selected_beams): 255 | action = np.array(action) 256 | new_state = self._apply_action(state, action) 257 | if self._is_valid_state(new_state): 258 | predicted_reward = self.reward_predictor.predict(new_state) 259 | actions.append(action) 260 | rewards.append(predicted_reward) 261 | pbar.update(1) 262 | if rewards: 263 | best_reward = max(rewards) 264 | best_action = actions[rewards.index(best_reward)] 265 | else: 266 | best_action = np.zeros((number_selected_beams,), dtype=np.int) 267 | best_reward = None 268 | return best_action, best_reward 269 | 270 | @staticmethod 271 | def _apply_action(state: np.ndarray, action: np.ndarray) -> np.ndarray: 272 | new_state = state + action 273 | return new_state 274 | 275 | def _is_in_state_action_pairs(self, state: np.ndarray, action: np.ndarray) -> bool: 276 | sorted_action = action[np.argsort(state)] 277 | sorted_state = np.sort(state) 278 | for state_action_pair in self.state_action_pairs: 279 | if np.all(state_action_pair[0] == sorted_state) and np.all(state_action_pair[1] == sorted_action): 280 | return True 281 | # Otherwise, add it to the list 282 | self.state_action_pairs.append((sorted_state, sorted_action)) 283 | return False 284 | 285 | def _add_to_state_reward_history(self, 286 | state: 
np.ndarray, 287 | reward: float, 288 | predicted_reward: Optional[float] = None) -> bool: 289 | # Add the (state, reward) pair to the cache and determines whether it is a new global optimum 290 | self.state_reward_history.append((state, reward, predicted_reward)) 291 | if np.all(np.sort(state) == self.best_state()): 292 | return True 293 | return False 294 | 295 | def _write_to_log(self, 296 | state: np.ndarray, 297 | reward: float, 298 | predicted_reward: Optional[float] = None, 299 | is_best_state: bool = False, 300 | step_counter: Optional[int] = None): 301 | step_counter = self.step_counter if step_counter is None else step_counter 302 | msg = f'{str(step_counter).rjust(3)} | {str(state)[1:-1].ljust(11)} | {reward:.5f}' 303 | msg = f'{msg} | pred={predicted_reward:.5f}' if predicted_reward is not None else f'{msg} | ' 304 | color = Colors.OKGREEN if is_best_state else None 305 | msg = f'{msg} | new best state' if is_best_state else msg 306 | self._write_message_to_log(msg, color) 307 | 308 | def _write_message_to_log(self, message: str, color: Optional[Colors] = None, add_to_log_history: bool = False): 309 | if add_to_log_history: 310 | self.log_history[self.step_counter] = (message, color) 311 | current_time = datetime.now().strftime('%H:%M:%S') 312 | msg = f'{current_time} | {message}' 313 | if color is not None: 314 | print(color + msg + Colors.ENDC) 315 | else: 316 | print(msg) 317 | with open(self.logfile, 'a', encoding='utf-8') as f: 318 | f.write(msg + '\n') 319 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. 
This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. 
Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. 
The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. 
a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. 
WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the "Licensor." The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication.
Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | --------------------------------------------------------------------------------