├── rl_l2o
├── __init__.py
├── utils.py
├── compute_reward.py
├── predict_reward.py
└── eps_greedy_search.py
├── object_detection
├── __init__.py
├── data
│ └── features_pcl.pkl
├── scripts
│ ├── train_pointpillar.sh
│ └── train_pointrcnn.sh
└── compute_reward.py
├── overview.png
├── requirements.txt
├── .gitmodules
├── main.py
├── .gitignore
├── README.md
└── LICENSE
/rl_l2o/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/object_detection/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vniclas/lidar_beam_selection/HEAD/overview.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.5.0
2 | torchvision==0.6.0
3 | numpy==1.20
4 | sympy
5 | tqdm
6 | pykdtree
7 |
--------------------------------------------------------------------------------
/object_detection/data/features_pcl.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vniclas/lidar_beam_selection/HEAD/object_detection/data/features_pcl.pkl
--------------------------------------------------------------------------------
/object_detection/scripts/train_pointpillar.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# Launch the OpenPCDet distributed training of the PointPillars detector.
#
# Usage: train_pointpillar.sh EXTRA_TAG
#   EXTRA_TAG  Name of the run; it is forwarded to OpenPCDet via --extra_tag.

# Fail early with a usage message instead of passing an empty tag to the trainer.
extra_tag="${1:?Usage: $0 EXTRA_TAG}"

# NOTE: adjust this path to the location of your checkout.
cd /home/USER/git/lidar_beam_selection/third_party/OpenPCDet/tools || exit

# The tag is quoted so that it always reaches the trainer as a single argument.
bash scripts/dist_train.sh 4 --cfg_file cfgs/kitti_models/pointpillar.yaml --ckpt_save_interval 40 --extra_tag "$extra_tag"
6 |
--------------------------------------------------------------------------------
/object_detection/scripts/train_pointrcnn.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# Launch the OpenPCDet distributed training of the PointRCNN detector.
#
# Usage: train_pointrcnn.sh EXTRA_TAG
#   EXTRA_TAG  Name of the run; it is forwarded to OpenPCDet via --extra_tag.

# Fail early with a usage message instead of passing an empty tag to the trainer.
extra_tag="${1:?Usage: $0 EXTRA_TAG}"

# NOTE: adjust this path to the location of your checkout.
cd /home/USER/git/lidar_beam_selection/third_party/OpenPCDet/tools || exit

# The tag is quoted so that it always reaches the trainer as a single argument.
bash scripts/dist_train.sh 4 --cfg_file cfgs/kitti_models/pointrcnn.yaml --ckpt_save_interval 40 --extra_tag "$extra_tag"
6 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/OpenPCDet"]
2 | path = third_party/OpenPCDet
3 | url = https://github.com/open-mmlab/OpenPCDet.git
4 | [submodule "third_party/Pseudo_Lidar_V2"]
5 | path = third_party/Pseudo_Lidar_V2
6 | url = https://github.com/mileyan/Pseudo_Lidar_V2.git
7 |
--------------------------------------------------------------------------------
/rl_l2o/utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
@dataclass
class Colors:
    """ANSI escape sequences for colored and styled terminal output.

    The values are plain class attributes; wrap a message as
    ``f'{Colors.OKCYAN}message{Colors.ENDC}'`` to color it.
    """

    # Foreground colors
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Reset all attributes
    ENDC = "\033[0m"

    # Text styles
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
15 |
--------------------------------------------------------------------------------
/rl_l2o/compute_reward.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from copy import deepcopy
3 | from typing import Any, Dict, List, Tuple, Optional
4 |
5 | import numpy as np
6 |
7 |
class RewardComputer(ABC):
    """Abstract base class computing (and caching) the true reward of a beam selection.

    Subclasses must provide:

    * ``DEFAULT_CONFIG``: a dictionary with the default keyword arguments of ``compute_reward``.
    * ``compute_reward``: a static method mapping a sorted list of beam ids to the reward signals.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Merge the user config into a private copy of the subclass defaults.

        Args:
            config: Entries overriding ``DEFAULT_CONFIG``; ``None`` keeps the defaults.
        """
        # Deep copy so that updating this instance never touches the class-level defaults.
        self.config = deepcopy(self.DEFAULT_CONFIG)
        if config:
            self.config.update(config)
        self.num_reward_signals = None

        self.cache = []  # (sorted beams, reward signals)

    def set_number_reward_signals(self, num_reward_signals: int = 1):
        """Set the number of reward signals that ``compute_reward`` has to return."""
        self.num_reward_signals = num_reward_signals

    def compute(self, state: np.ndarray) -> Tuple[float, np.ndarray]:
        """Return the reward of a state, re-using the cache whenever possible.

        Args:
            state: Selected beam ids in arbitrary order.

        Returns:
            The mean over all reward signals and the reward signals themselves.
        """
        assert self.num_reward_signals is not None

        # The reward does not depend on the order of the beams, hence sort only once.
        sorted_state = np.sort(state)
        for cached_state, cached_reward in self.cache:
            # array_equal also compares the shapes, i.e., a cached entry of another length
            # can never match by accident due to broadcasting.
            if np.array_equal(cached_state, sorted_state):
                return float(cached_reward.mean()), cached_reward

        reward_signals = self.compute_reward(sorted_state.tolist(), **self.config)
        # Ensure that the code is consistent
        assert reward_signals.shape == (self.num_reward_signals,)

        self.cache.append((sorted_state, reward_signals))
        return float(reward_signals.mean()), reward_signals

    def load_cache(self, cache: List[Tuple[np.ndarray, np.ndarray]]):
        """Replace the cache by a list of (sorted beams, reward signals) pairs."""
        self.cache = cache

    @property
    @abstractmethod
    def DEFAULT_CONFIG(self):
        """Default configuration; subclasses override this with a class-level dictionary."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def compute_reward(beams: List[int], **config) -> np.ndarray:
        """Compute the reward signals of the given beam selection.

        Args:
            beams: Sorted list of the selected beam ids.
            **config: The entries of ``self.config``.

        Returns:
            Array of shape ``(num_reward_signals,)``.
        """
        raise NotImplementedError
48 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from object_detection.compute_reward import RewardComputerObjectDetection
4 | from rl_l2o.eps_greedy_search import EpsGreedySearch
5 |
# Detector whose 3D detection accuracy is used as reward
detector = 'pointpillar'
# detector = 'pointrcnn'

repo_dir = Path(__file__).absolute().parent

# Parameters of the reward computation (adjust the paths to your setup)
compute_reward_config = {
    'plv2_dir': '/home/USER/git/lidar_beam_selection/third_party/Pseudo_Lidar_V2/gdc',
    'opcd_dir': '/home/USER/git/lidar_beam_selection/third_party/OpenPCDet',
    'kitti_dir': '/home/USER/data/kitti/training',
    'output_dir': '/home/USER/data/pseudo_lidar_v2',
    'pred_path': '/home/USER/data/pseudo_lidar_v2/sdn_kitti_train_set/depth_maps/trainval',
    'detector': detector,
}

# Overwrite default config parameters
config = {
    'checkpoint_files_path': repo_dir / f'checkpoints_{detector}',
    'compute_reward': compute_reward_config,
}
logfile = Path(__file__).parent / f'log_{detector}.txt'
features_pcl_file = repo_dir / 'object_detection' / 'data' / 'features_pcl.pkl'

reward_computer = RewardComputerObjectDetection(config['compute_reward'])

# To start a new run from scratch
search = EpsGreedySearch(features_pcl_file, reward_computer, config, logfile)

# To resume a previous run
# checkpoint_file = Path(__file__).parent / f'checkpoints_{detector}' / 'checkpoint_100.pkl'
# search.load_checkpoint(checkpoint_file, continue_searching=True)

# To load the cache containing the true reward
# search.load_reward_computation_cache(checkpoint_file)

# Start the search
search.run()

# Display the best beam configuration, colored as OKCYAN
best_state = search.best_state(return_reward=True)
print(f'\033[96m RESULT: state={best_state[0]}, reward={best_state[1]:.3f} \033[0m')
42 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Data files
132 | *.obj
133 | *.pkl
134 |
135 | # Pycharm
136 | .idea
137 |
138 | # VS Code
139 | .vscode
140 |
141 | # Project specific
142 | object_detection/checkpoints/*
143 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # End-to-End Optimization of LiDAR Beam Configuration
2 | [**arXiv**](https://arxiv.org/abs/2201.03860) | [**IEEE Xplore**](https://ieeexplore.ieee.org/abstract/document/9681305) | [**Video**](https://youtu.be/7fl0gLuvhZc)
3 |
4 | This repository is the official implementation of the paper:
5 |
6 | > **End-to-End Optimization of LiDAR Beam Configuration for 3D Object Detection and Localization**
7 | >
8 | > [Niclas Vödisch](https://vniclas.github.io/), [Ozan Unal](https://vision.ee.ethz.ch/people-details.MjA5ODkz.TGlzdC8zMjg5LC0xOTcxNDY1MTc4.html), [Ke Li](https://icu.ee.ethz.ch/people/person-detail.ke-li.html), [Luc Van Gool](https://vision.ee.ethz.ch/people-details.OTAyMzM=.TGlzdC8zMjQ4LC0xOTcxNDY1MTc4.html), and [Dengxin Dai](https://people.ee.ethz.ch/~daid/).
9 | >
10 | > *IEEE Robotics and Automation Letters (RA-L)*, vol. 7, no. 2, pp. 2242-2249, April 2022
11 |
12 |
13 |
14 |
15 |
16 | If you find our work useful, please consider citing our paper:
17 | ```
18 | @ARTICLE{Voedisch_2022_RAL,
19 | author={Vödisch, Niclas and Unal, Ozan and Li, Ke and Van Gool, Luc and Dai, Dengxin},
20 | journal={IEEE Robotics and Automation Letters},
21 | title={End-to-End Optimization of LiDAR Beam Configuration for 3D Object Detection and Localization},
22 | year={2022},
23 | volume={7},
24 | number={2},
25 | pages={2242-2249},
26 | doi={10.1109/LRA.2022.3142738}}
27 | ```
28 |
29 | ## 📔 Abstract
30 |
Pre-determined beam configurations of low-resolution LiDARs are task-agnostic, hence simply using them can result in non-optimal performance.
32 | In this work, we propose to optimize the beam distribution for a given target task via a reinforcement learning-based learning-to-optimize (RL-L2O) framework.
33 | We design our method in an end-to-end fashion leveraging the final performance of the task to guide the search process.
34 | Due to the simplicity of our approach, our work can be integrated with any LiDAR-based application as a simple drop-in module.
35 | In this repository, we provide the code for the exemplary task of 3D object detection.
36 |
37 |
38 | ## 🏗️️ Setup
39 |
40 | To clone this repository and all submodules run:
41 | ```shell
42 | git clone --recurse-submodules -j8 git@github.com:vniclas/lidar_beam_selection.git
43 | ```
44 |
45 | ### ⚙️ Installation
46 |
47 | To install this code, please follow the steps below:
48 | 1. Create a conda environment: `conda create -n beam_selection python=3.8`
49 | 2. Activate the environment: `conda activate beam_selection`
50 | 3. Install dependencies: `pip install -r requirements.txt`
51 | 4. Install cudatoolkit *(change to the used CUDA version)*:
52 | `conda install cudnn cudatoolkit=10.2`
53 | 5. Install [spconv](https://github.com/traveller59/spconv#install) *(change to the used CUDA version)*:
54 | `pip install spconv-cu102`
55 | 6. Install [OpenPCDet](https://github.com/open-mmlab/OpenPCDet) *(linked as submodule)*:
56 | `cd third_party/OpenPCDet && python setup.py develop && cd ../..`
57 | 7. Install [Pseudo-LiDAR++](https://github.com/mileyan/Pseudo_Lidar_V2) *(linked as submodule)*:
58 | `pip install -r third_party/Pseudo_Lidar_V2/requirements.txt`
59 | `pip install pillow==8.3.2` *(avoid runtime warnings)*
60 |
61 | ### 💾 Data Preparation
62 |
63 | 1. Download [KITTI 3D Object Detection dataset](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) and extract the files:
64 | 1. Left color images `image_2`
65 | 2. Right color images `image_3`
66 | 3. Velodyne point clouds `velodyne`
67 | 4. Camera calibration matrices `calib`
68 | 5. Training labels `label_2`
69 | 2. Predict the depth maps:
70 | 1. Download [pretrained model (training+validation)](https://github.com/mileyan/Pseudo_Lidar_V2#pretrained-models)
71 | 2. Generate the data:
72 | ```shell
73 | cd third_party/Pseudo_Lidar_V2
74 | python ./src/main.py -c src/configs/sdn_kitti_train.config \
75 | --resume PATH_TO_CHECKPOINTS/sdn_kitti_object_trainval.pth --datapath PATH_TO_KITTI/training/ \
76 | --data_list ./split/trainval.txt --generate_depth_map --data_tag trainval \
77 | --save_path PATH_TO_DATA/sdn_kitti_train_set
78 | ```
79 | **Note:** Please adjust the paths `PATH_TO_CHECKPOINTS`, `PATH_TO_KITTI`, and `PATH_TO_DATA` to match your setup.
80 | 3. Rename `training/velodyne` to `training/velodyne_original`
81 | 4. Symlink the KITTI folders to PCDet:
82 | * `ln -s PATH_TO_KITTI/training third_party/OpenPCDet/data/kitti/training`
83 | * `ln -s PATH_TO_KITTI/testing third_party/OpenPCDet/data/kitti/testing`
84 |
85 |
86 | ## 🏃 Running 3D Object Detection
87 |
88 | 1. Adjust paths in [`main.py`](main.py). Further available parameters are listed in [`rl_l2o/eps_greedy_search.py`](rl_l2o/eps_greedy_search.py) and can be added in `main.py`.
2. Adjust the number of training epochs of the 3D object detector *(we used 40 epochs)* in:
90 | - [`object_detection/compute_reward.py`](object_detection/compute_reward.py) --> above the class definition
91 | - [`third_party/OpenPCDet/tools/cfgs/kitti_models/pointpillar.yaml`](third_party/OpenPCDet/tools/cfgs/kitti_models/pointpillar.yaml) --> search for `NUM_EPOCHS`
92 | **Note:** If you use another detector, modify the respective configuration file.
93 | 3. Adjust the training scripts of the utilized detector to match your setup, e.g., [`object_detection/scripts/train_pointpillar.sh`](object_detection/scripts/train_pointpillar.sh).
4. Initiate the search: `python main.py`
95 | **Note:** Since we keep intermediate results to easily re-use them in later iterations, running the script will create a lot of data in the `output_dir` specified in [`main.py`](main.py). You might want to manually delete some folders from time to time.
96 |
97 |
98 | ## 🔧 Adding more Tasks
99 |
100 | Due to the design of the RL-L2O framework, it can be used as a simple drop-in module for many LiDAR applications.
101 | To apply the search algorithm to another task, just implement a custom `RewardComputer`, e.g., see [`object_detection/compute_reward.py`](object_detection/compute_reward.py).
102 | Additionally, you will have to prepare a set of features for each LiDAR beam.
103 | For the KITTI 3D Object Detection dataset, we provide the features as presented in the paper in [`object_detection/data/features_pcl.pkl`](object_detection/data/features_pcl.pkl).
104 |
105 |
## 👩‍⚖️ License
107 |
108 | 
109 | This software is made available for non-commercial use under a [Creative Commons Attribution-NonCommercial 4.0 International License](LICENSE). A summary of the license can be found on the [Creative Commons website](http://creativecommons.org/licenses/by-nc/4.0).
110 |
--------------------------------------------------------------------------------
/rl_l2o/predict_reward.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import random
3 | from os import PathLike
4 | from typing import Dict, Any, Optional, Union, Tuple
5 |
6 | import numpy as np
7 | import torch
8 | from sympy.utilities.iterables import multiset_permutations
9 | from torch import nn, optim
10 | from torch.utils.data import Dataset, DataLoader
11 | from tqdm import tqdm
12 |
13 |
class MyDataset(Dataset):
    """Torch dataset yielding the feature vector and the reward of a beam selection.

    Args:
        data: Dictionary with the entries ``'state'`` (beam ids, one row per sample) and
            ``'reward'`` (reward signals, one row per sample).
        features_pcl: Per-beam features, indexed by beam id along the first axis.
        elevation_angles_deg: Elevation angle of every beam, indexed by beam id.
    """

    def __init__(self, data: Dict[str, np.ndarray], features_pcl: np.ndarray, elevation_angles_deg: np.ndarray):
        super().__init__()
        self.features_pcl = features_pcl.astype(np.float32)
        self.elevation_angles_deg = elevation_angles_deg.astype(np.float32)
        # np.int was a mere alias of the builtin int; it is deprecated since NumPy 1.20 and
        # has been removed in NumPy 1.24.
        self.state = data['state'].astype(int)
        self.reward = data['reward'].astype(np.float32)

    def __len__(self):
        """Return the number of samples."""
        return self.state.shape[0]

    def __getitem__(self, item):
        """Return a dictionary with the state, the flattened input features, and the reward."""
        state = self.state[item, :]
        reward = self.reward[item, :]

        # Per-beam features in the order given by the state
        beam_features = self.features_pcl[state]
        # Differences of the sorted elevation angles, i.e., permutation invariant
        pairwise_features = np.diff(self.elevation_angles_deg[np.sort(state)])
        features_pcl = np.concatenate((beam_features.flatten(), pairwise_features.flatten()))

        return {'state': state, 'features_pcl': features_pcl, 'reward': reward}
35 |
36 |
class MyNetwork(nn.Module):
    """Small fully-connected regressor mapping beam features to the reward signals.

    Args:
        output_shape: Number of predicted reward signals.
        input_shape: Length of the flattened input feature vector. The default of 39 keeps
            the previous hard-coded value (presumably 4 beams with 9 features each plus
            3 elevation differences, TODO confirm against the feature file).
    """

    def __init__(self, output_shape, input_shape: int = 39):
        super().__init__()
        self.INPUT_SHAPE = input_shape
        self.output_shape = output_shape

        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.INPUT_SHAPE, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, self.output_shape),
            nn.Sigmoid()  # Squeeze between 0 and 1 since it is accuracy
        )

    def forward(self, x):
        """Return the predicted reward signals of a batch of feature vectors."""
        return self.net(x)
55 |
56 |
class RewardPredictor:
    """Learned approximation of the true reward of a beam selection.

    The predictor collects (state, reward) pairs, trains a small regression network on them,
    and predicts the mean reward of unseen states.
    """
    DEFAULT_CONFIG = {'num_training_epochs': 10, 'training_batch_size': 8, 'device': None}

    def __init__(self,
                 num_selected_beams: int,
                 num_reward_signals: int,
                 features_pcl_file: Union[str, PathLike],
                 config: Optional[Dict[str, Any]] = None):
        """Load the beam features and set up an untrained regressor.

        Args:
            num_selected_beams: Number of beams per state.
            num_reward_signals: Number of reward signals per state.
            features_pcl_file: Pickle file containing the per-beam features.
            config: Entries overriding ``DEFAULT_CONFIG``.

        Raises:
            RuntimeError: If ``'cuda'`` is requested but not available.
            ValueError: If the configured device is neither ``None``, ``'cpu'``, nor ``'cuda'``.
        """
        # Work on a copy: updating DEFAULT_CONFIG in place would leak the settings of this
        # instance into every predictor created afterwards.
        self.config = dict(RewardPredictor.DEFAULT_CONFIG)
        if config:
            self.config.update(config)
        self.num_selected_beams = num_selected_beams
        self.num_reward_signals = num_reward_signals
        self.features_pcl = None
        self.elevation_angles_deg = None
        self._load_features(features_pcl_file)

        requested_device = self.config['device']
        if requested_device is None:
            # Pick the GPU whenever there is one
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        elif requested_device == 'cpu':
            self.device = torch.device('cpu')
        elif requested_device == 'cuda':
            if not torch.cuda.is_available():
                raise RuntimeError('CUDA runtime not available.')
            self.device = torch.device('cuda')
        else:
            raise ValueError(f'Invalid device: {self.config["device"]}. Valid options are: ["cpu", "cuda"].')
        print(f'Found device: {self.device}')

        self.regressor = MyNetwork(self.num_reward_signals)
        self.regressor.to(self.device)
        # np.int has been removed in NumPy 1.24; the builtin int is the exact same dtype.
        self.training_data = {
            'state': np.empty((0, self.num_selected_beams), dtype=int),
            'reward': np.empty((0, self.num_reward_signals), dtype=np.float32)
        }

    def load_training_data(self, training_data: Dict[str, np.ndarray]):
        """Replace the collected training data by the given ``'state'`` / ``'reward'`` arrays."""
        self.training_data = training_data

    def add_to_training_data(self, state: np.ndarray, reward: np.ndarray):
        """Append a (state, reward) pair unless the state is already part of the training data.

        The state is stored with sorted beam ids.
        """
        assert state.shape == (self.num_selected_beams,)
        assert reward.shape == (self.num_reward_signals,)

        if self._is_in_training_data(state):
            return
        new_state = np.reshape(np.sort(state), (1, self.num_selected_beams))
        new_reward = np.reshape(reward, (1, self.num_reward_signals))
        self.training_data['state'] = np.append(self.training_data['state'], new_state, axis=0)
        self.training_data['reward'] = np.append(self.training_data['reward'], new_reward, axis=0)

    def train(self, permute: bool = False, reset_weights: bool = True, verbose: bool = False):
        """Train the regressor on the collected training data.

        Args:
            permute: Augment the data with all permutations of every state.
            reset_weights: Start from a freshly initialized network.
            verbose: Show a progress bar.
        """
        if reset_weights:
            self.regressor = MyNetwork(self.num_reward_signals)
            self.regressor.to(self.device)

        states, rewards = self._permute_training_data(permute)
        training_dataset = MyDataset({'state': states, 'reward': rewards}, self.features_pcl, self.elevation_angles_deg)
        training_dataloader = DataLoader(training_dataset, batch_size=self.config['training_batch_size'], shuffle=True)

        criterion = nn.MSELoss(reduction='sum')
        optimizer = optim.Adam(self.regressor.parameters(), lr=.001, weight_decay=.0001)

        num_epochs = self.config['num_training_epochs']
        self.regressor.train()
        # The progress bar counts batches over all epochs and is advanced manually.
        with tqdm(training_dataloader,
                  disable=not verbose,
                  unit='batch',
                  total=num_epochs * len(training_dataloader),
                  desc=f'Training | {self.config["num_training_epochs"]} epochs') as pbar:
            epoch_loss = 0  # Average batch loss of the previous epoch
            for epoch in range(num_epochs):
                train_loss = 0
                for sample_batched in training_dataloader:
                    optimizer.zero_grad()
                    features_pcl = sample_batched['features_pcl'].to(self.device)
                    true_reward = sample_batched['reward'].to(self.device)
                    predicted_reward = self.regressor(features_pcl)
                    loss = criterion(predicted_reward, true_reward)

                    train_loss += loss.item()

                    loss.backward()
                    optimizer.step()

                    pbar.set_postfix(epoch=epoch, epoch_loss=epoch_loss, batch_loss=loss.item())
                    pbar.update(1)
                epoch_loss = train_loss / len(training_dataloader)

    def predict(self, state: np.ndarray) -> float:
        """Return the predicted reward of a state, averaged over all reward signals."""
        self.regressor.eval()
        with torch.no_grad():
            features_pcl = self._state_to_features(state)
            features_pcl = torch.from_numpy(features_pcl).to(self.device)
            predicted_reward = self.regressor(features_pcl).cpu().detach().numpy()
        return float(predicted_reward.mean())

    def _is_in_training_data(self, state: np.ndarray):
        """Check whether the state (in any beam order) is already part of the training data."""
        assert state.shape == (self.num_selected_beams,)
        return (np.sort(state) == np.sort(self.training_data['state'])).all(axis=1).any()

    def _permute_training_data(self, permute: bool = True, in_place: bool = False) -> Tuple[np.ndarray, np.ndarray]:
        """Return the training data, optionally augmented with all permutations of the states.

        Args:
            permute: Perform the augmentation; otherwise the stored data is returned as is.
            in_place: Additionally overwrite the stored data with the augmented data.
        """
        if not permute:
            return self.training_data['state'], self.training_data['reward']

        state, reward = permute_dataset(self.training_data['state'], self.training_data['reward'])
        if in_place:
            self.training_data['state'], self.training_data['reward'] = state, reward
        return state, reward

    def _load_features(self, features_pcl_file: Union[str, PathLike]):
        """Read the per-beam features and store them in standardized form."""
        # NOTE(security): pickle.load executes arbitrary code; only load trusted feature files.
        with open(features_pcl_file, 'rb') as f:
            features_pcl_raw = pickle.load(f)

        number_points = np.fromiter(features_pcl_raw['number_points_avg'].values(), dtype=np.float32)
        mean_d = np.fromiter(features_pcl_raw['mean_d_avg'].values(), dtype=np.float32)
        mean_d_std = np.fromiter(features_pcl_raw['mean_d_std_avg'].values(), dtype=np.float32)
        semantic_classes = np.array([
            features_pcl_raw['semantic_classes_avg'][beam_id]
            for beam_id in np.arange(0, len(features_pcl_raw['semantic_classes_avg']))
        ])
        # The semantic class entries are divided by the number of points of the beam.
        features_pcl = np.c_[number_points, mean_d, mean_d_std,
                             semantic_classes / number_points.reshape((-1, 1))].astype(np.float32)
        self.elevation_angles_deg = np.fromiter(features_pcl_raw['elevation_angles_deg'].values(), dtype=np.float32)

        # Standardize features
        features_pcl = (features_pcl - features_pcl.mean(axis=0)) / features_pcl.std(axis=0)
        # The elevation angles are appended without standardization.
        features_pcl = np.c_[features_pcl, self.elevation_angles_deg]

        # Save for future usage
        self.features_pcl = features_pcl

    def _state_to_features(self, state: np.ndarray):
        """Convert a state into a feature array of shape ``(1, number of features)``."""
        assert state.shape == (self.num_selected_beams,)
        features_pcl = self.features_pcl[state]
        # Differences of the sorted elevation angles, i.e., permutation invariant
        pairwise_features = np.diff(self.elevation_angles_deg[np.sort(state)])
        features = np.r_[features_pcl.flatten(), pairwise_features.flatten()]
        return features.reshape((1, -1))
194 |
195 |
def permute_dataset(state, metrics):
    """Augment a dataset with all distinct permutations of every state.

    Args:
        state: Beam selections, one row per sample.
        metrics: Reward signals, one row per sample.

    Returns:
        The permuted states and the accordingly repeated metrics as 2D arrays in a common
        random order. Empty input is returned as (copied) arrays without any augmentation.
    """
    samples = [(permutation, metric)
               for beams, metric in zip(state, metrics)
               for permutation in multiset_permutations(beams)]

    # Nothing to permute: unpacking zip(*samples) and np.vstack both fail on empty input.
    if not samples:
        return np.array(state), np.array(metrics)

    # Shuffle states and metrics together to keep the pairs aligned
    random.shuffle(samples)
    active_beams_ret, metrics_ret = zip(*samples)

    return np.vstack(active_beams_ret), np.vstack(metrics_ret)
212 |
--------------------------------------------------------------------------------
/object_detection/compute_reward.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import subprocess
4 | import time
5 | from pathlib import Path
6 | from typing import Optional
7 |
8 | import numpy as np
9 |
10 | from rl_l2o.compute_reward import RewardComputer as RewardComputerBase
11 |
# This number should match the value of NUM_EPOCHS in the configuration of the corresponding detector, e.g.,
# third_party/OpenPCDet/tools/cfgs/kitti_models/pointpillar.yaml
14 | DETECTOR_NUM_EPOCHS = 40
15 |
16 |
class RewardComputerObjectDetection(RewardComputerBase):
    """Reward computer using the 3D AP of an object detector trained on simulated 4-beam data.

    For a given beam selection, the pipeline below runs the Pseudo-LiDAR++ scripts found in
    `plv2_dir`, submits a training job for the chosen detector, and reads the reward from the
    training log.
    """
    # Default keyword arguments of compute_reward; the paths have to be provided by the user.
    # NOTE(review): 'num_threads' is 24 here but defaults to 32 in the signature below; the
    # base class passes its config as keyword arguments, hence the value from here is used.
    DEFAULT_CONFIG = {
        'plv2_dir': '',
        'opcd_dir': '',
        'kitti_dir': '',
        'output_dir': '',
        'pred_path': '',
        'detector': 'pointpillar',
        'num_threads': 24
    }

    @staticmethod
    def compute_reward(
            beams: list,  # List of beam ids
            plv2_dir: str,  # PseudoLidarv2/gdc directory
            opcd_dir: str,  # OpenPCDet directory
            kitti_dir: str,  # KITTI dataset directory
            output_dir: str,  # Output directory
            pred_path: str,  # Path of SDN predictions
            detector: str = 'pointpillar',  # pointpillar or pointrcnn
            num_threads: int = 32) -> np.ndarray:
        """Compute the reward of a selection of exactly 4 beams.

        Intermediate results are written to sub-directories of `output_dir` named after the
        beam ids, the final point clouds are written to `kitti_dir`/velodyne. The exit codes
        of the called scripts are not checked.

        Returns:
            The reward as returned by `_check_for_existing_logs`.
        """
        assert len(beams) == 4, 'This implementation assumes exactly 4 beams.'
        assert detector in ['pointpillar', 'pointrcnn']

        # Get paths
        calib_path = os.path.join(kitti_dir, 'calib')
        image_path = os.path.join(kitti_dir, 'image_2')
        point_path = os.path.join(kitti_dir, 'velodyne_original')
        split_file = os.path.join(plv2_dir, 'image_sets/trainval.txt')
        tmp_dir = os.path.join(output_dir, f'tmp_dir_{detector}')

        # The tag identifies the run; it is passed to the training script and names the
        # directory in which the training log is searched.
        tag = f'{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}'
        train_dir = os.path.join(opcd_dir, 'output', 'kitti_models', detector, tag)

        # If the results file already exists, skip the computation
        reward = _check_for_existing_logs(train_dir)
        if reward is not None:
            return reward

        # If the pseudo lidar data already exists, skip most of the pre-processing
        # (_does_data_exist presumably checks for the expected number of files - TODO confirm)
        output_name = f'pseudo_lidar_{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}'
        output_path = os.path.join(output_dir, output_name)
        if not _does_data_exist(output_path):

            # Each of the following steps consumes the output directory of the previous step:
            # output_path of one step becomes input_path of the next one.
            # NOTE: the continuation lines of the command strings belong to the string literals.

            # Simulate 4 beam lidar
            print(f'Simulating 4 beam lidar with beam selection: {beams}...')
            output_name = f'velodyne_{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}'
            output_path = os.path.join(output_dir, output_name)
            velodyne_4_beam = f' \
python {plv2_dir}/sparsify.py \
--calib_path {calib_path} \
--image_path {image_path} \
--split_file {split_file} \
--ptc_path {point_path} \
--W 1024 \
--H 64 \
--line_spec {beams[0]} {beams[1]} {beams[2]} {beams[3]} \
--output_path {output_path} \
--store_line_map_dir {tmp_dir} \
--threads {num_threads} \
'

            if not _does_data_exist(output_path):
                os.system(velodyne_4_beam)

            # Get ground truth depth map from original 4 beams
            print('Generating ground truth depth maps from the simulated 4 beam lidar...')
            input_path = output_path
            output_name = f'gt_depthmap_{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}'
            output_path = os.path.join(output_dir, output_name)
            gt_depthmap_4_beam = f' \
python {plv2_dir}/ptc2depthmap.py \
--output_path {output_path} \
--input_path {input_path} \
--calib_path {calib_path} \
--image_path {image_path} \
--split_file {split_file} \
--threads {num_threads} \
'

            if not _does_data_exist(output_path):
                os.system(gt_depthmap_4_beam)

            # Run batch gdc using ground truth 4 beams on predicted depth maps
            print('Running batch GDC using 4 beam ground truth depth map on predicted depth maps...')
            input_path = output_path
            output_name = f'gdc_depthmap_{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}'
            output_path = os.path.join(output_dir, output_name)
            gdc_depthmap_4_beam = f' \
python {plv2_dir}/main_batch.py \
--output_path {output_path} \
--input_path {pred_path} \
--calib_path {calib_path} \
--gt_depthmap_path {input_path} \
--threads {num_threads} \
--split_file {split_file} \
'

            if not _does_data_exist(output_path):
                os.system(gdc_depthmap_4_beam)

            # Get pseudo lidar from corrected depth
            print('Generating pseudo lidar point clouds from corrected depth maps...')
            input_path = output_path
            output_name = f'pseudo_lidar_{beams[0]}_{beams[1]}_{beams[2]}_{beams[3]}'
            output_path = os.path.join(output_dir, output_name)
            pseudo_lidar_4_beam = f' \
python {plv2_dir}/depthmap2ptc.py \
--output_path {output_path} \
--input_path {input_path} \
--calib_path {calib_path} \
--threads {num_threads} \
--split_file {split_file} \
'

            if not _does_data_exist(output_path):
                os.system(pseudo_lidar_4_beam)

        # Sparsify to 64 lines
        # This step always runs and writes to `kitti_dir`/velodyne, i.e., the content of that
        # directory belongs to the most recently processed beam selection.
        print('Sparsifying pseudo lidar point cloud to 64 beams...')
        input_path = output_path
        output_path = os.path.join(kitti_dir, 'velodyne')
        sparse_pseudo_lidar_4_beam = f' \
python {plv2_dir}/sparsify.py \
--output_path {output_path} \
--calib_path {calib_path} \
--image_path {image_path} \
--ptc_path {input_path} \
--split_file {split_file} \
--W 1024 --slice 1 --H 64 \
--threads {num_threads} \
'

        os.system(sparse_pseudo_lidar_4_beam)

        # This only needs to be done once
        # Generate info files and ground truth database for KITTI
        if not os.path.exists(f'{opcd_dir}/data/kitti/kitti_infos_trainval.pkl'):
            print('Generating KITTI file database...')
            kitti_database = f' \
python -m pcdet.datasets.kitti.kitti_dataset \
create_kitti_infos \
{opcd_dir}/tools/cfgs/dataset_configs/kitti_dataset.yaml \
'

            os.system(kitti_database)

        # Start training the detector
        # The script is started without waiting for it to finish; its progress is tracked
        # via the training log below.
        print('Submitting training job...')
        subprocess.Popen([str(Path(__file__).parent / 'scripts' / f'train_{detector}.sh'), tag])

        # Check if job has started and look for log file
        # With a positive frequency, this call blocks until the log file contains the result.
        reward = _check_for_existing_logs(train_dir, frequency=5)
        return reward
171 |
172 |
173 | def _check_for_existing_logs(train_dir: str, frequency: int = -1) -> Optional[np.ndarray]:
174 | if frequency > 0:
175 | print(f'Looking for job every {frequency} seconds in {train_dir}', end='', flush=True)
176 | log_file_path = None
177 | while True:
178 | for file_path in glob.glob(os.path.join(train_dir, 'log_train*.txt')):
179 | if os.path.getsize(file_path) > 0:
180 | log_file_path = file_path
181 | break
182 | if log_file_path is None and frequency > 0:
183 | time.sleep(frequency)
184 | elif log_file_path is not None:
185 | print(f'Log file located: {log_file_path}.\nChecking for performance of epoch {DETECTOR_NUM_EPOCHS}...')
186 | break
187 | else:
188 | return None
189 |
190 | # Continuously check log file for evaluation
191 | check_phrase = f'Performance of EPOCH {DETECTOR_NUM_EPOCHS}'
192 | check_index = -1
193 | car_ap = None
194 | pedestrian_ap = None
195 | cyclist_ap = None
196 | while True:
197 | with open(log_file_path, 'r', encoding='utf-8') as f:
198 | for i, line in enumerate(f):
199 | if check_phrase in line:
200 | check_index = i
201 | elif check_index == -1:
202 | continue
203 | if i == check_index + 12: # 3D Car AP
204 | car_ap = line
205 | elif i == check_index + 32: # 3D Pedestrian AP
206 | pedestrian_ap = line
207 | elif i == check_index + 52: # 3D Cyclist AP
208 | cyclist_ap = line
209 | break
210 | if car_ap is not None and pedestrian_ap is not None and cyclist_ap is not None:
211 | break
212 | if frequency > 0:
213 | print('.', end='', flush=True)
214 | time.sleep(30)
215 |
216 | assert '3d AP:' in car_ap, car_ap
217 | assert '3d AP:' in pedestrian_ap, pedestrian_ap
218 | assert '3d AP:' in cyclist_ap, cyclist_ap
219 | result = np.array([float(car_ap[8:-1].split(', ')[1]) / 100]) # 3D AP for car: moderate
220 | # result = np.array([float(j) / 100 for j in car_ap[8:-1].split(', ')]) # Get 3D AP for easy, medium, and hard
221 | # result = np.array([
222 | # float(car_ap[8:-1].split(', ')[1]) / 100,
223 | # float(pedestrian_ap[8:-1].split(', ')[1]) / 100,
224 | # float(cyclist_ap[8:-1].split(', ')[1]) / 100
225 | # ]) # Get 3D AP for car, ped, and cyclists of moderate difficulty
226 | return result
227 |
228 |
229 | def _does_data_exist(path: str, number_expected_files: int = 7481) -> bool:
230 | if not os.path.exists(path):
231 | return False
232 | number_files = len(os.listdir(path))
233 | if number_files == number_expected_files:
234 | return True
235 | return False
236 |
--------------------------------------------------------------------------------
/rl_l2o/eps_greedy_search.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from copy import deepcopy
3 | from datetime import datetime
4 | from itertools import product
5 | from os import PathLike
6 | from pathlib import Path
7 | from typing import Dict, Any, Optional, Union, Tuple
8 |
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 | from rl_l2o.compute_reward import RewardComputer
13 | from rl_l2o.predict_reward import RewardPredictor
14 | from rl_l2o.utils import Colors
15 |
16 |
class EpsGreedySearch:
    """Epsilon-greedy search over sets of beam IDs.

    A state is an array of ``num_selected_beams`` distinct beam IDs within
    ``[min_beam_id, max_beam_id]``; an action shifts every beam ID by at most
    ``max_step_size``. In every step the search either takes a random valid
    action (with probability ``epsilon``) or the action with the highest
    reward predicted by the ``RewardPredictor``. The reached state is then
    evaluated with the ``RewardComputer`` and the predictor is retrained.
    """

    DEFAULT_CONFIG = {
        'initial_beam_ids': None,  # after initialization, start with these beams. If None, start with the best so far.
        'epsilon': .1,  # take a random action with this probability
        'num_initial_samples': 5,  # number of random samples to start training the value function predictor
        'num_samples': 200,  # total number of samples
        'max_step_size': 2,  # the beam IDs can be shifted a maximum of this number (pos/neg)
        'seed': None,  # used to initialize all random number generators
        'num_selected_beams': 4,  # number of beams selected for valid states
        'num_reward_signals': 1,  # number of environmental signals used to predict the overall reward
        'min_beam_id': 1,  # smallest valid beam ID
        'max_beam_id': 40,  # highest valid beam ID
        'checkpoint_files_path': Path(__file__).absolute().parents[1] / 'checkpoints',
        'predict_reward': {
            'num_training_epochs': 10,
            'training_batch_size': 8,
            'device': None
        },
    }

    # Names of the random number generators whose state is stored in checkpoints
    _GENERATOR_NAMES = ('epsilon_generator', 'random_action_generator', 'random_state_generator')

    def __init__(
        self,
        features_pcl_file: Union[str, PathLike],
        reward_computer: 'RewardComputer',
        config: Optional[Dict[str, Any]] = None,
        logfile: Union[str, PathLike] = 'log.txt',
    ):
        """Set up the search.

        Args:
            features_pcl_file: File passed on to the ``RewardPredictor``.
            reward_computer: Object used to compute the reward of a state.
            config: Top-level overrides for ``DEFAULT_CONFIG``.
            logfile: File that all log messages are appended to. An existing
                file is renamed to '<stem>_old<suffix>' first.
        """
        # Deep copy so that the nested 'predict_reward' dict of the class-level
        # defaults is never shared between (or mutated by) instances.
        self.config = deepcopy(self.DEFAULT_CONFIG)
        if config:
            self.config.update(config)
        initial_beam_ids = self.config['initial_beam_ids']
        if initial_beam_ids is not None:
            assert len(initial_beam_ids) == self.config['num_selected_beams']
            assert min(initial_beam_ids) >= self.config['min_beam_id']
            assert max(initial_beam_ids) <= self.config['max_beam_id']
        self.config['features_pcl_file'] = str(features_pcl_file)

        # Random number generators.
        # NOTE(review): with a fixed seed all three generators produce the same
        # stream; kept as-is to preserve the reproducibility of existing runs.
        self.epsilon_generator = np.random.default_rng(seed=self.config['seed'])
        self.random_action_generator = np.random.default_rng(seed=self.config['seed'])
        self.random_state_generator = np.random.default_rng(seed=self.config['seed'])

        # Miscellaneous objects
        self.reward_predictor = RewardPredictor(self.config['num_selected_beams'], self.config['num_reward_signals'],
                                                self.config['features_pcl_file'], self.config['predict_reward'])
        self.reward_computer = reward_computer
        self.reward_computer.set_number_reward_signals(self.config['num_reward_signals'])
        self.log_history = {}  # step_counter: (message, color)
        self.state_reward_history = []  # (beams, reward, predicted_reward)
        self.state_action_pairs = []  # (beams, steps)
        self.step_counter = 0
        self._continue_searching = False  # can be set when loading a checkpoint

        # Logging
        self.logfile = Path(logfile)
        if self.logfile.exists():
            # with_name only touches the final path component, so it is safe for
            # files without a suffix and for paths that contain the suffix twice.
            prev_logfile = self.logfile.with_name(f'{self.logfile.stem}_old{self.logfile.suffix}')
            self.logfile.rename(prev_logfile)
            print(f'{Colors.WARNING}WARNING: Found existing logfile and renamed it to "{prev_logfile}".{Colors.ENDC}')

    def run(self) -> np.ndarray:
        """Run the search until ``num_samples`` states were evaluated.

        Returns:
            The (sorted) state with the highest reward found.
        """
        self._initialize_predictor()
        if self._continue_searching:
            state = self.state_reward_history[-1][0]
        elif self.config['initial_beam_ids'] is None:
            state = self.best_state()
        elif self.step_counter < self.config['num_samples']:
            # Evaluate the user-provided initial state as a regular sample
            self.step_counter += 1
            state = np.asarray(self.config['initial_beam_ids'], dtype=int)
            reward, reward_signals = self.reward_computer.compute(state)
            self.reward_predictor.add_to_training_data(state, reward_signals)
            self.reward_predictor.train(permute=True)
            is_best_state = self._add_to_state_reward_history(state, reward)
            self._write_to_log(state, reward, is_best_state=is_best_state)
            self.save_checkpoint()

        # Main loop
        while self.step_counter < self.config['num_samples']:
            self.step_counter += 1

            if self.epsilon_generator.random() < self.config['epsilon']:
                action = self._get_random_action(state)
                state = self._apply_action(state, action)
                # Predict for the state that is actually evaluated, consistent
                # with the greedy branch.
                predicted_reward = self.reward_predictor.predict(state)
            else:
                action, predicted_reward = self._get_best_action(state)
                state = self._apply_action(state, action)
            assert self._is_valid_state(state), state
            if self._is_in_state_action_pairs(state, action):
                state = self._get_random_state(self.random_state_generator)
                # The earlier prediction belongs to the abandoned state
                predicted_reward = self.reward_predictor.predict(state)
                self._write_message_to_log(
                    'WARNING: Resampled new random state due to a loop in the state-action-history.', Colors.WARNING,
                    True)

            reward, reward_signals = self.reward_computer.compute(state)
            self.reward_predictor.add_to_training_data(state, reward_signals)
            self.reward_predictor.train(permute=True)

            is_best_state = self._add_to_state_reward_history(state, reward, predicted_reward)
            self._write_to_log(state, reward, predicted_reward, is_best_state)
            self.save_checkpoint()

        return self.best_state()

    def best_state(self, return_reward: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, float]]:
        """Return the sorted state with the highest reward seen so far.

        Args:
            return_reward: If True, return a ``(state, reward)`` tuple instead.

        Raises:
            ValueError: if no state has been evaluated yet.
        """
        if not self.state_reward_history:
            raise ValueError('No state has been evaluated yet; the best state is undefined.')
        rewards = np.array([entry[1] for entry in self.state_reward_history], dtype=object)
        best_entry = self.state_reward_history[int(np.argmax(rewards))]
        best_state_reward = (np.sort(best_entry[0]), best_entry[1])
        if return_reward:
            return best_state_reward
        return best_state_reward[0]

    def save_checkpoint(self):
        """Pickle the complete search state to 'checkpoint_<step>.pkl' in the checkpoint directory."""
        checkpoint_dir = Path(self.config['checkpoint_files_path'])  # also accept plain strings
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        checkpoint_file = checkpoint_dir / f'checkpoint_{str(self.step_counter).zfill(3)}.pkl'
        checkpoint = {
            'config': self.config,
            'state_reward_history': self.state_reward_history,
            'state_action_pairs': self.state_action_pairs,
            'step_counter': self.step_counter,
            'log_history': self.log_history,
            'reward_computer_cache': self.reward_computer.cache,
            'reward_predictor_training_data': self.reward_predictor.training_data,
        }
        # bit_generator.state is the public, version-stable way to export the
        # generator state (Generator.__getstate__ is an internal pickle hook).
        for name in self._GENERATOR_NAMES:
            checkpoint[name] = getattr(self, name).bit_generator.state
        with open(checkpoint_file, 'wb') as f:
            pickle.dump(checkpoint, f, pickle.DEFAULT_PROTOCOL)

    def load_checkpoint(self, checkpoint_file: Union[str, PathLike], continue_searching: bool = True):
        """Restore the search state from a checkpoint written by ``save_checkpoint``.

        Args:
            checkpoint_file: Pickle file to load.
            continue_searching: If True, ``run`` continues from the last state
                of the loaded history.
        """
        # NOTE: pickle can execute arbitrary code; only load trusted checkpoint files.
        with open(checkpoint_file, 'rb') as f:
            checkpoint = pickle.load(f)
        self._continue_searching = continue_searching

        # Check for compatible setups
        assert self.config['num_selected_beams'] == checkpoint['config']['num_selected_beams']

        # We do not load paths and total number of samples from checkpoint files
        current_config = deepcopy(self.config)
        self.config = checkpoint['config']
        for key in ('compute_reward', 'checkpoint_files_path', 'features_pcl_file', 'num_samples'):
            if key in current_config:
                self.config[key] = current_config[key]
            else:
                # e.g. 'compute_reward' is optional and not part of the defaults
                self.config.pop(key, None)

        self.state_action_pairs = checkpoint['state_action_pairs']
        self.step_counter = checkpoint['step_counter']
        if 'log_history' in checkpoint:
            self.log_history = checkpoint['log_history']
        for name in self._GENERATOR_NAMES:
            generator_state = checkpoint[name]
            if generator_state is not None:
                getattr(self, name).bit_generator.state = generator_state

        self.reward_predictor = RewardPredictor(self.config['num_selected_beams'], self.config['num_reward_signals'],
                                                self.config['features_pcl_file'], self.config['predict_reward'])
        self.reward_predictor.load_training_data(checkpoint['reward_predictor_training_data'])
        self.reward_computer.load_cache(checkpoint['reward_computer_cache'])

        # Replay the history into the log; start from an empty history so that
        # it matches the loaded step counter.
        self.state_reward_history = []
        for step_counter_replay, state_reward in enumerate(checkpoint['state_reward_history'], start=1):
            if step_counter_replay in self.log_history:
                self._write_message_to_log(self.log_history[step_counter_replay][0],
                                           self.log_history[step_counter_replay][1])
            if len(state_reward) == 2:
                # Entries without a predicted reward
                state_reward = (state_reward[0], state_reward[1], None)
            is_best_state = self._add_to_state_reward_history(state_reward[0], state_reward[1], state_reward[2])
            self._write_to_log(state_reward[0], state_reward[1], state_reward[2], is_best_state, step_counter_replay)

        self._write_message_to_log('INFO: Resuming from a checkpoint.', Colors.OKBLUE, True)

    def load_reward_prediction_cache(self, checkpoint_file: Union[str, PathLike]):
        """Load only the training data of the reward predictor from a checkpoint file."""
        # NOTE: pickle can execute arbitrary code; only load trusted checkpoint files.
        with open(checkpoint_file, 'rb') as f:
            checkpoint = pickle.load(f)

        # Check for compatible setups
        assert self.config['num_selected_beams'] == checkpoint['config']['num_selected_beams']

        self.reward_predictor.load_training_data(checkpoint['reward_predictor_training_data'])

    def load_reward_computation_cache(self, checkpoint_file: Union[str, PathLike]):
        """Load only the cache of the reward computer from a checkpoint file."""
        # NOTE: pickle can execute arbitrary code; only load trusted checkpoint files.
        with open(checkpoint_file, 'rb') as f:
            checkpoint = pickle.load(f)

        # Check for compatible setups
        assert self.config['num_selected_beams'] == checkpoint['config']['num_selected_beams']

        self.reward_computer.load_cache(checkpoint['reward_computer_cache'])

    def _initialize_predictor(self):
        """Evaluate random states until ``num_initial_samples`` is reached, then train the predictor."""
        while self.step_counter < self.config['num_initial_samples']:
            self.step_counter += 1
            state = self._get_random_state(self.random_state_generator)
            reward, reward_signals = self.reward_computer.compute(state)
            self.reward_predictor.add_to_training_data(state, reward_signals)
            is_best_state = self._add_to_state_reward_history(state, reward)
            self._write_to_log(state, reward, is_best_state=is_best_state)
            self.save_checkpoint()
        self.reward_predictor.train(permute=True)

    def _is_valid_state(self, state: np.ndarray) -> bool:
        """Check that the state has the configured number of distinct beam IDs, all within range."""
        return bool(
            np.unique(state).size == self.config['num_selected_beams']
            and state.min() >= self.config['min_beam_id']
            and state.max() <= self.config['max_beam_id'])

    def _get_random_state(self, random_number_generator: np.random.Generator) -> np.ndarray:
        """Draw random beam IDs until they form a valid state; the result is sorted."""
        while True:
            # np.unique sorts and removes duplicates; draws with duplicates are
            # therefore too short and get rejected by the validity check.
            state = np.unique(
                random_number_generator.integers(self.config['min_beam_id'],
                                                 self.config['max_beam_id'],
                                                 endpoint=True,
                                                 size=(self.config['num_selected_beams'],)))
            if self._is_valid_state(state):
                return state

    def _get_random_action(self, state: Optional[np.ndarray] = None) -> np.ndarray:
        """Draw a random step in ``[-max_step_size, max_step_size]`` for every beam.

        If a state is provided, only actions that lead to a valid state are returned.
        """
        max_step_size = self.config['max_step_size']
        while True:
            action = self.random_action_generator.integers(-max_step_size,
                                                           max_step_size,
                                                           endpoint=True,
                                                           size=(self.config['num_selected_beams'],))
            if state is None or self._is_valid_state(self._apply_action(state, action)):
                return action

    def _get_best_action(self, state: np.ndarray, show_pbar: bool = False) -> Tuple[np.ndarray, Optional[float]]:
        """Exhaustively search all actions for the one with the highest predicted reward.

        Returns:
            The best action and its predicted reward. If no action leads to a
            valid state, the zero action and ``None`` are returned.
        """
        number_selected_beams = self.config['num_selected_beams']
        actions = []
        rewards = []
        valid_steps = list(range(-self.config['max_step_size'], self.config['max_step_size'] + 1))
        with tqdm(total=len(valid_steps) ** number_selected_beams, desc='Predicting best action',
                  disable=not show_pbar) as pbar:
            for action in product(valid_steps, repeat=number_selected_beams):
                action = np.array(action)
                new_state = self._apply_action(state, action)
                if self._is_valid_state(new_state):
                    actions.append(action)
                    rewards.append(self.reward_predictor.predict(new_state))
                pbar.update(1)
        if rewards:
            best_reward = max(rewards)
            best_action = actions[rewards.index(best_reward)]
        else:
            best_action = np.zeros((number_selected_beams,), dtype=int)
            best_reward = None
        return best_action, best_reward

    @staticmethod
    def _apply_action(state: np.ndarray, action: np.ndarray) -> np.ndarray:
        """Return the state reached by adding the per-beam steps to the state."""
        return state + action

    def _is_in_state_action_pairs(self, state: np.ndarray, action: np.ndarray) -> bool:
        """Check whether the (state, action) pair was seen before; unseen pairs are recorded.

        Both arrays are brought into the order that sorts the state, so the
        check does not depend on the order of the beams.
        """
        sorted_action = action[np.argsort(state)]
        sorted_state = np.sort(state)
        for known_state, known_action in self.state_action_pairs:
            if np.all(known_state == sorted_state) and np.all(known_action == sorted_action):
                return True
        # Otherwise, add it to the list
        self.state_action_pairs.append((sorted_state, sorted_action))
        return False

    def _add_to_state_reward_history(self,
                                     state: np.ndarray,
                                     reward: float,
                                     predicted_reward: Optional[float] = None) -> bool:
        """Append the sample to the history and report whether the state equals the current best state."""
        self.state_reward_history.append((state, reward, predicted_reward))
        return bool(np.all(np.sort(state) == self.best_state()))

    def _write_to_log(self,
                      state: np.ndarray,
                      reward: float,
                      predicted_reward: Optional[float] = None,
                      is_best_state: bool = False,
                      step_counter: Optional[int] = None):
        """Format one evaluated sample and write it to the log.

        ``step_counter`` defaults to the current step of the search.
        """
        step_counter = self.step_counter if step_counter is None else step_counter
        msg = f'{str(step_counter).rjust(3)} | {str(state)[1:-1].ljust(11)} | {reward:.5f}'
        msg = f'{msg} | pred={predicted_reward:.5f}' if predicted_reward is not None else f'{msg} | '
        color = Colors.OKGREEN if is_best_state else None
        msg = f'{msg} | new best state' if is_best_state else msg
        self._write_message_to_log(msg, color)

    def _write_message_to_log(self, message: str, color: Optional['Colors'] = None, add_to_log_history: bool = False):
        """Print a time-stamped message (optionally colored) and append it to the logfile.

        If ``add_to_log_history`` is set, the message is stored under the
        current step counter so that it is replayed when loading a checkpoint.
        """
        if add_to_log_history:
            self.log_history[self.step_counter] = (message, color)
        current_time = datetime.now().strftime('%H:%M:%S')
        msg = f'{current_time} | {message}'
        if color is not None:
            print(color + msg + Colors.ENDC)
        else:
            print(msg)
        with open(self.logfile, 'a', encoding='utf-8') as f:
            f.write(msg + '\n')
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
408 |
--------------------------------------------------------------------------------