├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── balloon_learning_environment ├── __init__.py ├── acme_utils.py ├── agents │ ├── __init__.py │ ├── acme_eval_agent.py │ ├── agent.py │ ├── agent_registry.py │ ├── agent_test.py │ ├── configs │ │ ├── __init__.py │ │ ├── dqn.gin │ │ ├── finetune_perciatelli.gin │ │ ├── mlp.gin │ │ └── quantile.gin │ ├── dopamine_utils.py │ ├── dopamine_utils_test.py │ ├── dqn_agent.py │ ├── dqn_agent_test.py │ ├── exploration.py │ ├── exploration_test.py │ ├── marco_polo_exploration.py │ ├── marco_polo_exploration_test.py │ ├── mlp_agent.py │ ├── mlp_agent_test.py │ ├── networks.py │ ├── networks_test.py │ ├── perciatelli44.py │ ├── perciatelli44_test.py │ ├── quantile_agent.py │ ├── quantile_agent_test.py │ ├── random_walk_agent.py │ ├── random_walk_agent_test.py │ ├── station_seeker_agent.py │ └── station_seeker_agent_test.py ├── colab │ ├── BLE_Generative_Wind_Field.ipynb │ ├── BLE_view_flight_paths.ipynb │ └── summarize_eval.ipynb ├── distributed_train_acme_qrdqn.py ├── env │ ├── __init__.py │ ├── balloon │ │ ├── __init__.py │ │ ├── acs.py │ │ ├── acs_test.py │ │ ├── altitude_safety.py │ │ ├── altitude_safety_test.py │ │ ├── balloon.py │ │ ├── balloon_test.py │ │ ├── control.py │ │ ├── envelope_safety.py │ │ ├── envelope_safety_test.py │ │ ├── power_safety.py │ │ ├── power_safety_test.py │ │ ├── power_table.py │ │ ├── power_table_test.py │ │ ├── pressure_range_builder.py │ │ ├── pressure_range_builder_test.py │ │ ├── solar.py │ │ ├── solar_test.py │ │ ├── stable_init.py │ │ ├── stable_init_test.py │ │ ├── standard_atmosphere.py │ │ ├── standard_atmosphere_test.py │ │ └── thermal.py │ ├── balloon_arena.py │ ├── balloon_arena_test.py │ ├── balloon_env.py │ ├── balloon_env_test.py │ ├── features.py │ ├── features_test.py │ ├── generative_wind_field.py │ ├── grid_based_wind_field.py │ ├── grid_based_wind_field_test.py │ ├── grid_wind_field_sampler.py │ ├── gym.py │ ├── rendering │ │ ├── __init__.py │ │ ├── 
matplotlib_renderer.py │ │ └── renderer.py │ ├── simplex_wind_noise.py │ ├── simulator_data.py │ ├── wind_field.py │ ├── wind_field_test.py │ ├── wind_gp.py │ └── wind_gp_test.py ├── eval │ ├── README.md │ ├── __init__.py │ ├── combine_eval_shards.py │ ├── eval.py │ ├── eval_lib.py │ ├── eval_lib_test.py │ ├── strata_seeds.py │ ├── suites.py │ └── suites_test.py ├── generated │ ├── multi_balloon.mp4 │ └── wind_field.mp4 ├── generative │ ├── __init__.py │ ├── dataset_wind_field_reservoir.py │ ├── dataset_wind_field_reservoir_test.py │ ├── learn_wind_field_generator.py │ ├── vae.py │ ├── vae_test.py │ └── wind_field_reservoir.py ├── metrics │ ├── __init__.py │ ├── collector.py │ ├── collector_dispatcher.py │ ├── collector_dispatcher_test.py │ ├── collector_test.py │ ├── console_collector.py │ ├── console_collector_test.py │ ├── pickle_collector.py │ ├── pickle_collector_test.py │ ├── statistics_instance.py │ ├── tensorboard_collector.py │ └── tensorboard_collector_test.py ├── models │ ├── __init__.py │ ├── models.py │ ├── models_test.py │ ├── offlineskies22_decoder.msgpack │ └── perciatelli44.pb ├── train.py ├── train_acme_qrdqn.py ├── train_lib.py ├── train_lib_test.py └── utils │ ├── __init__.py │ ├── constants.py │ ├── run_helpers.py │ ├── run_helpers_test.py │ ├── sampling.py │ ├── sampling_test.py │ ├── spherical_geometry.py │ ├── spherical_geometry_test.py │ ├── test_helpers.py │ ├── transforms.py │ ├── transforms_test.py │ ├── units.py │ ├── units_test.py │ ├── wind.py │ └── wind_test.py ├── docs ├── Makefile ├── README.md ├── about.rst ├── benchmarks.rst ├── changelist.rst ├── conf.py ├── environment.rst ├── getting_started.rst ├── imgs │ ├── balloon_schematic.jpg │ ├── ble_logo.png │ ├── ble_logo_small.png │ ├── reward_function.png │ ├── station_keeping.gif │ ├── training_curve.jpg │ └── wind_field.gif ├── index.rst ├── make.bat ├── new_agent.rst ├── requirements.txt └── src │ ├── agents.rst │ ├── agents │ ├── agent.rst │ ├── dqn_agent.rst │ ├── 
perciatelli44.rst │ ├── quantile_agent.rst │ └── station_seeker_agent.rst │ ├── balloon_env.rst │ ├── env.rst │ ├── eval_lib.rst │ ├── features.rst │ ├── metrics.rst │ ├── metrics │ ├── collector.rst │ ├── collector_dispatcher.rst │ ├── console_collector.rst │ ├── pickle_collector.rst │ ├── statistics_instance.rst │ └── tensorboard_collector.rst │ └── train_lib.rst ├── pyproject.toml ├── requirements.txt ├── setup.py └── style_guidelines.md /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Greaves" 5 | given-names: "Joshua" 6 | - family-names: "Candido" 7 | given-names: "Salvatore" 8 | - family-names: "Dumoulin" 9 | given-names: "Vincent" 10 | - family-names: "Goroshin" 11 | given-names: "Ross" 12 | - family-names: "Ponda" 13 | given-names: "Sameera S." 14 | - family-names: "Bellemare" 15 | given-names: "Marc G." 16 | - family-names: "Castro" 17 | given-names: "Pablo Samuel" 18 | title: "Balloon Learning Environment" 19 | version: 1.0.0 20 | date-released: 2021-12-06 21 | url: "https://github.com/google/balloon-learning-environment" 22 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 
14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Balloon Learning Environment 2 | [Docs][docs] 3 | 4 |
5 | 7 |

8 |
9 | 10 | 11 | The Balloon Learning Environment (BLE) is a simulator for stratospheric 12 | balloons. It is designed as a benchmark environment for deep reinforcement 13 | learning algorithms, and is a followup to the Nature paper 14 | ["Autonomous navigation of stratospheric balloons using reinforcement learning"](https://www.nature.com/articles/s41586-020-2939-8). 15 | 16 | ## Getting Started 17 | 18 | Note: The BLE requires python >= 3.7 19 | 20 | The BLE can easily be installed with pip: 21 | 22 | ``` 23 | $ pip install --upgrade pip 24 | $ pip install balloon_learning_environment 25 | ``` 26 | 27 | To install with the `acme` package: 28 | 29 | ``` 30 | $ pip install --upgrade pip 31 | $ pip install balloon_learning_environment[acme] 32 | ``` 33 | 34 | Once the package has been installed, you can test it runs correctly by 35 | evaluating one of the benchmark agents: 36 | 37 | ``` 38 | python -m balloon_learning_environment.eval.eval \ 39 | --agent=station_seeker \ 40 | --renderer=matplotlib \ 41 | --suite=micro_eval \ 42 | --output_dir=/tmp/ble/eval 43 | ``` 44 | 45 | To install from GitHub directly, run the following commands from the root 46 | directory where you cloned the repository: 47 | 48 | ``` 49 | $ pip install --upgrade pip 50 | $ pip install .[acme] 51 | ``` 52 | 53 | ## Ensure the BLE is Using Your GPU/TPU 54 | 55 | The BLE contains a VAE for generating winds, which you will probably want 56 | to run on your accelerator. See the jax documentation for installing with 57 | [GPU](https://github.com/google/jax#pip-installation-gpu-cuda) or 58 | [TPU](https://github.com/google/jax#pip-installation-google-cloud-tpu). 59 | 60 | As a sanity check, you can open interactive python and run: 61 | 62 | ``` 63 | from balloon_learning_environment.env import balloon_env 64 | env = balloon_env.BalloonEnv() 65 | ``` 66 | 67 | If you are not running with GPU/TPU, you should see a log like: 68 | 69 | ``` 70 | WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) 71 | ``` 72 | 73 | If you don't see this log, you should be good to go! 74 | 75 | ## Next Steps 76 | 77 | For more information, see the [docs][docs]. 78 | 79 | ## Giving credit 80 | 81 | If you use the Balloon Learning Environment in your work, we ask that you use 82 | the following BibTeX entry: 83 | 84 | ``` 85 | @software{Greaves_Balloon_Learning_Environment_2021, 86 | author = {Greaves, Joshua and Candido, Salvatore and Dumoulin, Vincent and Goroshin, Ross and Ponda, Sameera S. and Bellemare, Marc G. and Castro, Pablo Samuel}, 87 | month = {12}, 88 | title = {{Balloon Learning Environment}}, 89 | url = {https://github.com/google/balloon-learning-environment}, 90 | version = {1.0.0}, 91 | year = {2021} 92 | } 93 | ``` 94 | 95 | If you use the `ble_wind_field` dataset, you should also cite 96 | 97 | ``` 98 | Hersbach, H., Bell, B., Berrisford, P., Hirahara, S., Horányi, A., 99 | Muñoz‐Sabater, J., Nicolas, J., Peubey, C., Radu, R., Schepers, D., Simmons, A., 100 | Soci, C., Abdalla, S., Abellan, X., Balsamo, G., Bechtold, P., Biavati, G., 101 | Bidlot, J., Bonavita, M., De Chiara, G., Dahlgren, P., Dee, D., Diamantakis, M., 102 | Dragani, R., Flemming, J., Forbes, R., Fuentes, M., Geer, A., Haimberger, L., 103 | Healy, S., Hogan, R.J., Hólm, E., Janisková, M., Keeley, S., Laloyaux, P., 104 | Lopez, P., Lupu, C., Radnoti, G., de Rosnay, P., Rozum, I., Vamborg, F., 105 | Villaume, S., Thépaut, J-N. (2017): Complete ERA5: Fifth generation of ECMWF 106 | atmospheric reanalyses of the global climate. Copernicus Climate Change Service 107 | (C3S) Data Store (CDS). 
(Accessed on 01-04-2021) 108 | ``` 109 | 110 | 111 | [docs]: https://balloon-learning-environment.readthedocs.io/en/latest/ 112 | -------------------------------------------------------------------------------- /balloon_learning_environment/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Balloon Learning Environment root package exports.""" 17 | from balloon_learning_environment.env import gym as _ble_gym 18 | 19 | # Register Gym environment. Users can import the environment as: 20 | # gym.make('balloon_learning_environment:BalloonLearningEnvironment-v0') 21 | _ble_gym.register_env() 22 | 23 | # We don't export anything 24 | __all__ = [] 25 | -------------------------------------------------------------------------------- /balloon_learning_environment/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /balloon_learning_environment/agents/agent_registry.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """The registry of agents. 17 | 18 | This is where you add new agents; we provide some examples to get you started. 19 | 20 | When writing a new agent, follow the API specified by the base class 21 | `agent.Agent` and implement the abstract methods. 22 | The provided agents are: 23 | RandomAgent: Ignores all observations and picks actions randomly. 24 | MLPAgent: Uses a simple multi-layer perceptron (MLP) to learn the mapping of 25 | states to Q-values. The number of layers and hidden units in the MLP is 26 | configurable. 
def _lookup_agent(name: str):
  """Returns the `(constructor, default_gin_config)` entry for `name`.

  Shared by `agent_constructor` and `get_default_gin_config` so that both
  report unknown agents in exactly the same way.

  Args:
    name: Key of the agent in `REGISTRY`.

  Returns:
    The two-element tuple stored in `REGISTRY` under `name`.

  Raises:
    ValueError: If `name` is not in `REGISTRY`. Names listed in
      `_ACME_AGENTS` get a message hinting at the optional acme dependencies,
      since those agents are only registered when their import succeeds.
  """
  try:
    return REGISTRY[name]
  except KeyError:
    if name in _ACME_AGENTS:
      raise ValueError(f'Agent {name} not available. '
                       'Have you tried installing the acme dependencies?'
                       ) from None
    raise ValueError(f'Agent {name} not recognized') from None


def agent_constructor(name: str) -> Callable[..., agent.Agent]:
  """Returns the callable that constructs the agent registered as `name`.

  Args:
    name: Key of the agent in `REGISTRY`.

  Raises:
    ValueError: If no agent is registered under `name`.
  """
  return _lookup_agent(name)[0]


def get_default_gin_config(name: str) -> Optional[str]:
  """Returns the default gin config path for `name`, or None if it has none.

  Args:
    name: Key of the agent in `REGISTRY`.

  Raises:
    ValueError: If no agent is registered under `name`.
  """
  return _lookup_agent(name)[1]
class RandomAgentTest(absltest.TestCase):
  """Checks construction and action selection of `agent.RandomAgent`."""

  def setUp(self):
    super().setUp()
    self._na = 5
    self._observation_shape = (3, 4)
    self._observation = np.zeros(self._observation_shape, dtype=np.float32)

  def _build_agent(self):
    """Creates a RandomAgent from the fixture's action count and shape."""
    return agent.RandomAgent(self._na, self._observation_shape)

  def test_create_agent(self):
    random_agent = self._build_agent()
    self.assertEqual('RandomAgent', random_agent.get_name())
    self.assertEqual(self._na, random_agent._num_actions)
    self.assertEqual(self._observation_shape, random_agent._observation_shape)

  def test_action_selection(self):
    num_episodes, steps_per_episode = 10, 20
    valid_actions = range(self._na)
    random_agent = self._build_agent()
    for _ in range(num_episodes):
      # The first action of every episode must be a valid action index.
      first_action = random_agent.begin_episode(self._observation)
      self.assertGreaterEqual(first_action, 0)
      self.assertLess(first_action, self._na)
      for _ in range(steps_per_episode):
        self.assertIn(
            random_agent.step(0.0, self._observation), valid_actions)
      random_agent.end_episode(0.0, True)
2 | import dopamine.jax.agents.dqn.dqn_agent 3 | import balloon_learning_environment.agents.dqn_agent 4 | import balloon_learning_environment.agents.networks 5 | import dopamine.replay_memory.circular_replay_buffer 6 | 7 | balloon_learning_environment.agents.dqn_agent.DQNAgent.network = @networks.MLPNetwork 8 | balloon_learning_environment.agents.dqn_agent.DQNAgent.checkpoint_duration = 5 9 | networks.MLPNetwork.num_layers = 8 10 | networks.MLPNetwork.hidden_units = 600 11 | JaxDQNAgent.gamma = 0.993 12 | JaxDQNAgent.update_horizon = 5 13 | JaxDQNAgent.min_replay_history = 500 14 | JaxDQNAgent.update_period = 4 15 | JaxDQNAgent.target_update_period = 100 16 | JaxDQNAgent.epsilon_fn = @dqn_agent.identity_epsilon 17 | JaxDQNAgent.epsilon_train = 0.01 18 | JaxDQNAgent.epsilon_eval = 0.0 19 | JaxDQNAgent.optimizer = 'adam' 20 | JaxDQNAgent.loss_type = 'mse' # MSE works better with Adam. 21 | JaxDQNAgent.summary_writing_frequency = 1 22 | JaxDQNAgent.allow_partial_reload = True 23 | dqn_agent.create_optimizer.learning_rate = 2e-6 24 | dqn_agent.create_optimizer.eps = 0.00002 25 | 26 | OutOfGraphReplayBuffer.replay_capacity = 2000000 27 | OutOfGraphReplayBuffer.batch_size = 32 28 | -------------------------------------------------------------------------------- /balloon_learning_environment/agents/configs/finetune_perciatelli.gin: -------------------------------------------------------------------------------- 1 | # Hyperparameters for a Quantile-style agent that uses MarcoPolo exploration. 2 | # This is the same configuration used in our Nature paper, and further 3 | # starts from the trained parameters used in the Nature paper. 
4 | import dopamine.jax.agents.dqn.dqn_agent 5 | import dopamine.jax.agents.quantile.quantile_agent 6 | import balloon_learning_environment.agents.marco_polo_exploration 7 | import balloon_learning_environment.agents.quantile_agent 8 | import balloon_learning_environment.agents.networks 9 | import balloon_learning_environment.agents.random_walk_agent 10 | import dopamine.replay_memory.prioritized_replay_buffer 11 | 12 | QuantileAgent.network = @agents.networks.QuantileNetwork 13 | QuantileAgent.exploration_wrapper_constructor = @marco_polo_exploration.MarcoPoloExploration 14 | QuantileAgent.reload_perciatelli = True 15 | QuantileAgent.checkpoint_duration = 5 16 | MarcoPoloExploration.exploratory_episode_probability = 0.8 17 | MarcoPoloExploration.exploratory_agent_constructor = @random_walk_agent.RandomWalkAgent 18 | agents.networks.QuantileNetwork.num_layers = 8 19 | agents.networks.QuantileNetwork.hidden_units = 600 20 | JaxQuantileAgent.gamma = 0.993 21 | JaxQuantileAgent.update_horizon = 5 22 | JaxQuantileAgent.min_replay_history = 500 23 | JaxQuantileAgent.update_period = 4 24 | JaxQuantileAgent.target_update_period = 100 25 | JaxQuantileAgent.epsilon_train = 0.0 26 | JaxQuantileAgent.epsilon_eval = 0.0 27 | JaxQuantileAgent.epsilon_fn = @dqn_agent.identity_epsilon 28 | JaxQuantileAgent.optimizer = 'adam' 29 | JaxQuantileAgent.summary_writing_frequency = 1 30 | JaxQuantileAgent.num_atoms = 51 31 | JaxQuantileAgent.allow_partial_reload = True 32 | dqn_agent.create_optimizer.learning_rate = 2e-6 33 | dqn_agent.create_optimizer.eps = 0.00002 34 | 35 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 2000000 36 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 37 | -------------------------------------------------------------------------------- /balloon_learning_environment/agents/configs/mlp.gin: -------------------------------------------------------------------------------- 1 | import balloon_learning_environment.agents.networks 2 | 3 | 
networks.MLPNetwork.num_layers = 1 4 | networks.MLPNetwork.hidden_units = 256 5 | -------------------------------------------------------------------------------- /balloon_learning_environment/agents/configs/quantile.gin: -------------------------------------------------------------------------------- 1 | # Hyperparameters for a Quantile-style agent that uses MarcoPolo exploration. 2 | # This is the same configuration used in our Nature paper. 3 | import dopamine.jax.agents.dqn.dqn_agent 4 | import dopamine.jax.agents.quantile.quantile_agent 5 | import balloon_learning_environment.agents.marco_polo_exploration 6 | import balloon_learning_environment.agents.quantile_agent 7 | import balloon_learning_environment.agents.networks 8 | import balloon_learning_environment.agents.random_walk_agent 9 | import dopamine.replay_memory.prioritized_replay_buffer 10 | 11 | QuantileAgent.network = @agents.networks.QuantileNetwork 12 | QuantileAgent.exploration_wrapper_constructor = @marco_polo_exploration.MarcoPoloExploration 13 | QuantileAgent.reload_perciatelli = False 14 | QuantileAgent.checkpoint_duration = 5 15 | MarcoPoloExploration.exploratory_episode_probability = 0.8 16 | MarcoPoloExploration.exploratory_agent_constructor = @random_walk_agent.RandomWalkAgent 17 | agents.networks.QuantileNetwork.num_layers = 8 18 | agents.networks.QuantileNetwork.hidden_units = 600 19 | JaxQuantileAgent.gamma = 0.993 20 | JaxQuantileAgent.update_horizon = 5 21 | JaxQuantileAgent.min_replay_history = 500 22 | JaxQuantileAgent.update_period = 4 23 | JaxQuantileAgent.target_update_period = 100 24 | JaxQuantileAgent.epsilon_train = 0.0 25 | JaxQuantileAgent.epsilon_eval = 0.0 26 | JaxQuantileAgent.epsilon_fn = @dqn_agent.identity_epsilon 27 | JaxQuantileAgent.optimizer = 'adam' 28 | JaxQuantileAgent.summary_writing_frequency = 1 29 | JaxQuantileAgent.num_atoms = 51 30 | JaxQuantileAgent.allow_partial_reload = True 31 | dqn_agent.create_optimizer.learning_rate = 2e-6 32 | 
dqn_agent.create_optimizer.eps = 0.00002 33 | 34 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 2000000 35 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32 36 | -------------------------------------------------------------------------------- /balloon_learning_environment/agents/dopamine_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
def _make_checkpoint_filename(checkpoint_dir: str,
                              iteration_number: int) -> str:
  """Returns the path `<checkpoint_dir>/checkpoint_<iteration>.pkl`.

  The iteration number is zero-padded to a minimum of five digits; larger
  numbers simply use more digits.
  """
  return osp.join(checkpoint_dir, f'checkpoint_{iteration_number:05d}.pkl')


def _is_checkpoint_name(file_name: str) -> bool:
  """Returns True if `file_name` has the form `checkpoint_<digits>.pkl`.

  The numeric part must have at least five digits, matching what
  `_make_checkpoint_filename` produces (including iterations >= 100000, which
  are wider than five digits). Requiring decimal digits guarantees that
  `_get_checkpoint_iteration` can parse every accepted name.
  """
  path = pathlib.Path(file_name)
  if path.suffix != '.pkl':
    return False

  parts = path.stem.split('_')
  if len(parts) != 2 or parts[0] != 'checkpoint':
    return False

  iteration = parts[1]
  return len(iteration) >= 5 and iteration.isdecimal()


def _get_checkpoint_iteration(checkpoint_name: str) -> int:
  """Returns the iteration number encoded in a checkpoint file name.

  Callers should first validate the name with `_is_checkpoint_name`; an
  arbitrary name may raise IndexError or ValueError here.
  """
  path = pathlib.Path(checkpoint_name)
  parts = path.stem.split('_')
  return int(parts[1])


def save_checkpoint(checkpoint_dir: str,
                    iteration_number: int,
                    bundle_fn: Callable[[str, int], Any]) -> None:
  """Saves a checkpoint using the provided bundling function.

  Args:
    checkpoint_dir: Directory the checkpoint file is written to; it is created
      if possible.
    iteration_number: Iteration used to name the checkpoint file.
    bundle_fn: Called as `bundle_fn(checkpoint_dir, iteration_number)`; returns
      the object to pickle, or None when nothing can be checkpointed (in which
      case a warning is logged and no file is written).
  """
  try:
    tf.io.gfile.makedirs(checkpoint_dir)
  except tf.errors.PermissionDeniedError:
    # Deliberately best effort: the directory may already exist and be usable
    # even though we are not allowed to (re)create it.
    pass

  bundle = bundle_fn(checkpoint_dir, iteration_number)
  if bundle is None:
    logging.warning('Unable to checkpoint to %s at iteration %d.',
                    checkpoint_dir, iteration_number)
    return

  filename = _make_checkpoint_filename(checkpoint_dir, iteration_number)
  # pickle produces bytes, so write in binary mode (mirrors the 'rb' used by
  # `load_checkpoint`).
  with tf.io.gfile.GFile(filename, 'wb') as fout:
    pickle.dump(bundle, fout)


def load_checkpoint(
    checkpoint_dir: str,
    iteration_number: int,
    unbundle_fn: Callable[[str, int, Dict[Any, Any]], bool]) -> None:
  """Loads a checkpoint using the provided unbundling function.

  Args:
    checkpoint_dir: Directory containing the checkpoint file.
    iteration_number: Iteration whose checkpoint should be restored.
    unbundle_fn: Called as
      `unbundle_fn(checkpoint_dir, iteration_number, bundle)` with the
      unpickled bundle; a falsy return value is logged as a failure.
  """
  filename = _make_checkpoint_filename(checkpoint_dir, iteration_number)
  if not tf.io.gfile.exists(filename):
    logging.warning('Unable to restore bundle from %s', filename)
    return

  # NOTE(security): unpickling can execute arbitrary code. Only point this at
  # checkpoint directories whose contents are trusted.
  with tf.io.gfile.GFile(filename, 'rb') as fin:
    bundle = pickle.load(fin)
  if not unbundle_fn(checkpoint_dir, iteration_number, bundle):
    logging.warning('Call to parent `unbundle` failed.')


def _latest_episode(checkpoint_files) -> int:
  """Returns the largest episode number among `checkpoint_files`, or -1.

  Args:
    checkpoint_files: Iterable of paths matching `checkpoint_*.pkl`.

  Paths whose `checkpoint_` suffix is not an integer are skipped, so a single
  stray file cannot hide the valid checkpoints next to it.
  """
  prefix, extension = 'checkpoint_', '.pkl'
  episodes = []
  for path in checkpoint_files:
    suffix = path[path.rfind(prefix) + len(prefix):-len(extension)]
    try:
      episodes.append(int(suffix))
    except ValueError:
      continue
  return max(episodes, default=-1)


def get_latest_checkpoint(checkpoint_dir: str) -> int:
  """Finds the episode ID of the latest checkpoint, if any.

  Args:
    checkpoint_dir: Directory to search for `checkpoint_*.pkl` files.

  Returns:
    The highest episode number found, or -1 if the directory cannot be
    searched or contains no parseable checkpoint.
  """
  glob = osp.join(checkpoint_dir, 'checkpoint_*.pkl')
  try:
    checkpoint_files = tf.io.gfile.glob(glob)
  except tf.errors.NotFoundError:
    logging.warning('Unable to reload checkpoint at %s', checkpoint_dir)
    return -1

  return _latest_episode(checkpoint_files)


def clean_up_old_checkpoints(checkpoint_dir: str,
                             episode_number: int,
                             checkpoint_duration: int = 50) -> None:
  """Removes every stale checkpoint in `checkpoint_dir`.

  A checkpoint is stale when its iteration is smaller than
  `episode_number - checkpoint_duration`. Files that are not recognised by
  `_is_checkpoint_name` are left untouched.

  Args:
    checkpoint_dir: Directory where checkpoints are stored.
    episode_number: Current episode number.
    checkpoint_duration: How long (in terms of episodes) a checkpoint should
      last.
  """
  oldest_to_keep = episode_number - checkpoint_duration
  for file_name in tf.io.gfile.listdir(checkpoint_dir):
    if not _is_checkpoint_name(file_name):
      continue
    if _get_checkpoint_iteration(file_name) < oldest_to_keep:
      tf.io.gfile.remove(osp.join(checkpoint_dir, file_name))
class Exploration:
  """Pass-through exploration module that serves as the base class.

  Both hooks hand the agent's chosen action back unchanged; subclasses
  override them to substitute exploratory actions.
  """

  def __init__(self, unused_num_actions: int,
               unused_observation_shape: Sequence[int]):
    """Accepts the standard constructor arguments and ignores both."""

  def begin_episode(self, observation: np.ndarray, a: int) -> int:
    """Echoes the agent's first action of an episode."""
    del observation  # Not needed by the no-op wrapper.
    return a

  def step(self, reward: float, observation: np.ndarray, a: int) -> int:
    """Echoes the agent's action for a regular step."""
    del reward, observation  # Not needed by the no-op wrapper.
    return a
class ExplorationTest(absltest.TestCase):
  """Checks that the base Exploration wrapper is a pure pass-through."""

  def test_exploration_class(self):
    """Every action handed to the wrapper must come back unchanged."""
    num_actions = 5
    observation_shape = (3, 4)
    e = exploration.Exploration(num_actions, observation_shape)
    # This class just echoes back the actions passed in, ignoring all other
    # parameters.
    for i in range(5):
      # Random observations/rewards are fine here: the wrapper never reads
      # them, so the assertions do not depend on the sampled values.
      self.assertEqual(
          i, e.begin_episode(np.random.rand(*observation_shape), i))
      for j in range(10):
        self.assertEqual(
            j, e.step(np.random.rand(), np.random.rand(*observation_shape), j))


if __name__ == '__main__':
  absltest.main()
# Durations of the two phases that alternate within an exploratory episode.
_RL_PHASE_LENGTH = dt.timedelta(hours=4)
_EXPLORATORY_PHASE_LENGTH = dt.timedelta(hours=2)


@gin.configurable
class MarcoPoloExploration(exploration.Exploration):
  """Alternates between the parent agent's actions and an exploratory agent.

  At the start of each episode a coin is flipped to decide whether the episode
  is exploratory. Non-exploratory episodes always return the parent agent's
  action. Exploratory episodes start in the RL phase (parent action returned)
  and switch to the exploration phase (exploratory agent's action returned)
  and back again whenever the current phase has lasted its full length.
  """

  def __init__(self, num_actions: int, observation_shape: Sequence[int],
               exploratory_episode_probability: float = gin.REQUIRED,
               exploratory_agent_constructor: Callable[
                   [int, Sequence[int]], agent.Agent] = gin.REQUIRED,
               seed: Optional[int] = None):
    """Builds the exploratory agent and seeds the episode-level RNG.

    Args:
      num_actions: Number of actions, forwarded to the exploratory agent.
      observation_shape: Observation shape, forwarded to the exploratory agent.
      exploratory_episode_probability: Threshold compared against a uniform
        sample at the start of each episode to mark it as exploratory.
      exploratory_agent_constructor: Callable building the exploratory agent
        from (num_actions, observation_shape).
      seed: Seed for the PRNG key. When None, one is derived from the clock.
    """
    self._exploratory_agent = exploratory_agent_constructor(
        num_actions, observation_shape)
    self._exploratory_episode_probability = exploratory_episode_probability
    self._exploratory_episode = False
    self._exploratory_phase = False
    self._phase_time_elapsed = dt.timedelta()
    if seed is None:
      seed = int(time.time() * 1e6)
    self._rng = jax.random.PRNGKey(seed)

  def begin_episode(self, observation: np.ndarray, action: int) -> int:
    """Resets phase tracking and decides whether this episode explores.

    The exploratory agent is always told about the new episode so that it is
    ready if an exploration phase starts later.

    Args:
      observation: First observation of the episode.
      action: Action chosen by the parent agent.

    Returns:
      `action` unchanged, since every episode starts in the RL phase.
    """
    self._exploratory_agent.begin_episode(observation)
    self._phase_time_elapsed = dt.timedelta()
    rng, self._rng = jax.random.split(self._rng)
    sample = jax.random.uniform(rng)
    self._exploratory_episode = (
        sample <= self._exploratory_episode_probability)
    # Every episode opens in the RL phase.
    self._exploratory_phase = False
    return action

  def _phase_expired(self) -> bool:
    """Returns whether the current phase has lasted its full length."""
    if self._exploratory_phase:
      phase_length = _EXPLORATORY_PHASE_LENGTH
    else:
      phase_length = _RL_PHASE_LENGTH
    return self._phase_time_elapsed >= phase_length

  def _update_phase(self) -> None:
    """Advances the phase clock and flips phases in exploratory episodes."""
    self._phase_time_elapsed += constants.AGENT_TIME_STEP
    if self._exploratory_episode and self._phase_expired():
      self._exploratory_phase = not self._exploratory_phase
      self._phase_time_elapsed = dt.timedelta()

  def step(self, reward: float, observation: np.ndarray, action: int) -> int:
    """Returns `action` in the RL phase, else the exploratory agent's action.

    Note that the exploratory agent's `step` is only invoked while the
    exploration phase is active.
    """
    self._update_phase()
    if not self._exploratory_phase:
      return action

    return self._exploratory_agent.step(reward, observation)
@gin.configurable
class MLPNetwork(nn.Module):
  """A simple MLP network.

  The input is cast to float32 and flattened into a single vector, passed
  through `num_layers - 1` ReLU hidden layers of `hidden_units` units each,
  and finally projected to one value per action.
  """
  # Size of the output layer (one value per action).
  num_actions: int
  # Total number of Dense layers, counting the output layer.
  num_layers: int = gin.REQUIRED
  # Width of each hidden layer.
  hidden_units: int = gin.REQUIRED
  # When True, the output is wrapped in Dopamine's DQNNetworkType.
  is_dopamine: bool = False

  @nn.compact
  def __call__(self, x: jnp.ndarray):
    """Computes one value per action for a single observation `x`."""
    # This method sets up the MLP for inference, using the specified number of
    # layers and units.
    # NOTE: this message is emitted every time this method body executes.
    logging.info('Creating MLP network with %d layers and %d hidden units',
                 self.num_layers, self.hidden_units)

    # Network initializer.
    kernel_initializer = jax.nn.initializers.glorot_uniform()
    x = x.astype(jnp.float32)  # Convert to JAX float32 type.
    # Flatten. Every axis is collapsed, so no batch dimension is preserved.
    x = x.reshape(-1)

    # Pass through the desired number of hidden layers (we do this for
    # one less than `self.num_layers`, as the final output layer below counts
    # as one).
    for _ in range(self.num_layers - 1):
      x = nn.Dense(features=self.hidden_units,
                   kernel_init=kernel_initializer)(x)
      x = nn.relu(x)

    # The final layer will output a value for each action.
    q_values = nn.Dense(features=self.num_actions,
                        kernel_init=kernel_initializer)(x)

    if self.is_dopamine:
      q_values = atari_lib.DQNNetworkType(q_values)
    return q_values


@gin.configurable
class QuantileNetwork(nn.Module):
  """Network used to compute the agent's return quantiles.

  Shares the MLP trunk layout of `num_layers - 1` ReLU hidden layers, but the
  output layer produces `num_actions * num_atoms` values that are reshaped to
  (num_actions, num_atoms).
  """
  # Number of actions (rows of the reshaped output).
  num_actions: int
  # Total number of Dense layers, counting the output layer.
  num_layers: int = gin.REQUIRED
  # Width of each hidden layer.
  hidden_units: int = gin.REQUIRED
  num_atoms: int = 51  # Normally set by JaxQuantileAgent.
  # Not read anywhere in this class; presumably kept so the constructor
  # matches what the calling agent passes -- TODO confirm against the caller.
  inputs_preprocessed: bool = False

  @nn.compact
  def __call__(self, x: jnp.ndarray):
    """Returns a RainbowNetworkType of (q_values, logits, probabilities)."""
    # This method sets up the MLP for inference, using the specified number of
    # layers and units.
    # NOTE: this message is emitted every time this method body executes.
    logging.info('Creating MLP network with %d layers, %d hidden units, and '
                 '%d atoms', self.num_layers, self.hidden_units, self.num_atoms)

    # Network initializer.
    kernel_initializer = nn.initializers.variance_scaling(
        scale=1.0 / jnp.sqrt(3.0),
        mode='fan_in',
        distribution='uniform')
    x = x.astype(jnp.float32)  # Convert to JAX float32 type.
    # Flatten. Every axis is collapsed, so no batch dimension is preserved.
    x = x.reshape(-1)

    # Pass through the desired number of hidden layers (we do this for
    # one less than `self.num_layers`, as the final output layer below counts
    # as one).
    for _ in range(self.num_layers - 1):
      x = nn.Dense(features=self.hidden_units,
                   kernel_init=kernel_initializer)(x)
      x = nn.relu(x)

    x = nn.Dense(features=self.num_actions * self.num_atoms,
                 kernel_init=kernel_initializer)(x)
    # One row of atoms per action.
    logits = x.reshape((self.num_actions, self.num_atoms))
    probabilities = nn.softmax(logits)
    # Each action's value is the mean over that action's atoms.
    q_values = jnp.mean(logits, axis=1)
    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
class NetworksTest(absltest.TestCase):
  """Unit tests for the construction and inference of networks.MLPNetwork."""

  def setUp(self):
    super().setUp()
    self._num_actions = 4
    self._observation_shape = (6, 7)
    self._example_state = jnp.zeros(self._observation_shape)
    gin.bind_parameter('MLPNetwork.num_layers', 1)
    gin.bind_parameter('MLPNetwork.hidden_units', 256)

  def tearDown(self):
    # Drop every gin binding made by setUp or by an individual test so that
    # configuration cannot leak into whichever test runs next.
    gin.clear_config()
    super().tearDown()

  def _create_network(self):
    """Instantiates an MLPNetwork using the currently bound gin parameters."""
    self._network_def = networks.MLPNetwork(num_actions=self._num_actions)

  def test_default_network_parameters(self):
    self._create_network()
    self.assertEqual(self._num_actions, self._network_def.num_actions)
    self.assertEqual(1, self._network_def.num_layers)
    self.assertEqual(256, self._network_def.hidden_units)
    network_params = self._network_def.init(jax.random.PRNGKey(0),
                                            self._example_state)
    self.assertIn('Dense_0', network_params['params'])

  def test_custom_network(self):
    num_layers = 5
    hidden_units = 64
    gin.bind_parameter('MLPNetwork.num_layers', num_layers)
    gin.bind_parameter('MLPNetwork.hidden_units', hidden_units)
    self._create_network()
    self.assertEqual(num_layers, self._network_def.num_layers)
    self.assertEqual(hidden_units, self._network_def.hidden_units)
    network_params = self._network_def.init(jax.random.PRNGKey(0),
                                            self._example_state)
    params = network_params['params']
    # The network builds `num_layers - 1` hidden Dense layers plus one output
    # Dense layer, so exactly `num_layers` Dense modules must exist.
    for i in range(num_layers):
      self.assertIn(f'Dense_{i}', params)
    self.assertNotIn(f'Dense_{num_layers}', params)

  def test_call_network(self):
    self._create_network()
    network_params = self._network_def.init(jax.random.PRNGKey(0),
                                            self._example_state)
    # All zeros in should produce all zeros out, since the default initializer
    # for bias is all zeros.
    zeros_out = self._network_def.apply(network_params,
                                        jnp.zeros_like(self._example_state))
    self.assertTrue(jnp.array_equal(zeros_out, jnp.zeros(self._num_actions)))
    # All ones in should produce something that is non-zero at the output, since
    # we are using Glorot initialization.
    ones_out = self._network_def.apply(network_params,
                                       jnp.ones_like(self._example_state))
    self.assertFalse(jnp.array_equal(ones_out, jnp.zeros(self._num_actions)))


if __name__ == '__main__':
  absltest.main()
def load_perciatelli_session() -> tf.compat.v1.Session:
  """Creates a TF1 session and imports the serialized Perciatelli44 graph."""
  frozen_graph_bytes = models.load_perciatelli44()

  session = tf.compat.v1.Session()
  graph_def = tf.compat.v1.GraphDef()
  graph_def.ParseFromString(frozen_graph_bytes)

  tf.compat.v1.import_graph_def(graph_def)
  return session


class Perciatelli44(agent.Agent):
  """Perciatelli44 Agent.

  This is the agent which was reported as state of the art in
  "Autonomous navigation of stratospheric balloons using reinforcement
  learning" (Bellemare, Candido, Castro, Gong, Machado, Moitra, Ponda,
  and Wang, 2020).

  The weights are frozen: the agent exists as an evaluation baseline and is
  not meant to be retrained. Every action is the argmax of the Q-values the
  frozen graph produces for the current observation.
  """

  def __init__(self, num_actions: int, observation_shape: Sequence[int]):
    """Validates the problem dimensions and loads the frozen graph.

    Args:
      num_actions: Must be 3.
      observation_shape: Must be equivalent to [1099].

    Raises:
      ValueError: If either argument does not match the frozen network.
    """
    super().__init__(num_actions, observation_shape)

    if num_actions != 3:
      raise ValueError('Perciatelli44 only supports 3 actions.')
    if list(observation_shape) != [1099]:
      raise ValueError('Perciatelli44 only supports 1099 dimensional input.')

    # TODO(joshgreaves): It would be nice to use the saved_model API
    # for loading the Perciatelli graph.
    # TODO(joshgreaves): We wanted to avoid a dependency on TF, but adding
    # this to the agent registry makes TF a necessity.
    self._sess = load_perciatelli_session()
    graph = self._sess.graph
    self._action = graph.get_tensor_by_name('sleepwalk_action:0')
    self._q_vals = graph.get_tensor_by_name('q_values:0')
    self._observation = graph.get_tensor_by_name('observation:0')

  def _greedy_action(self, observation: np.ndarray) -> int:
    """Runs the frozen graph and returns the index of the largest Q-value."""
    batched = observation.reshape((1, 1099))
    q_vals = self._sess.run(self._q_vals,
                            feed_dict={self._observation: batched})
    return np.argmax(q_vals).item()

  def begin_episode(self, observation: np.ndarray) -> int:
    """Returns the greedy action for the first observation of an episode."""
    return self._greedy_action(observation)

  def step(self, reward: float, observation: np.ndarray) -> int:
    """Returns the greedy action for `observation`; `reward` is not used."""
    return self._greedy_action(observation)

  def end_episode(self, reward: float, terminal: bool = True) -> None:
    """No-op: the frozen agent keeps no per-episode state."""
class Perciatelli44Test(absltest.TestCase):
  """Tests for the frozen Perciatelli44 agent."""

  def setUp(self):
    super().setUp()
    self._perciatelli44 = perciatelli44.Perciatelli44(3, [1099])
    self._observation = np.ones(1099, dtype=np.float32)

  def test_begin_episode_returns_valid_action(self):
    action = self._perciatelli44.begin_episode(self._observation)
    self.assertIn(action, [0, 1, 2])

  def test_step_returns_valid_action(self):
    # Start an episode first, then exercise `step` itself. Previously this
    # test only called `begin_episode`, leaving `step` untested.
    self._perciatelli44.begin_episode(self._observation)
    action = self._perciatelli44.step(0.0, self._observation)
    self.assertIn(action, [0, 1, 2])

  def test_perciatelli_only_accepts_3_actions(self):
    with self.assertRaises(ValueError):
      perciatelli44.Perciatelli44(4, [1099])

  def test_perciatelli_only_accepts_1099_dim_observation(self):
    with self.assertRaises(ValueError):
      perciatelli44.Perciatelli44(3, [1100])


if __name__ == '__main__':
  absltest.main()
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """An exploratory agent that selects actions based on a random walk. 17 | 18 | Note that this class assumes the features passed in correspond to the 19 | Perciatelli features (see balloon_learning_environment.env.features). 20 | """ 21 | 22 | 23 | import datetime as dt 24 | import time 25 | from typing import Optional, Sequence 26 | 27 | from balloon_learning_environment.agents import agent 28 | from balloon_learning_environment.env import features 29 | from balloon_learning_environment.env.balloon import control 30 | from balloon_learning_environment.utils import constants 31 | from balloon_learning_environment.utils import sampling 32 | import gin 33 | import jax 34 | import numpy as np 35 | 36 | 37 | _PERCIATELLI_FEATURES_SHAPE = (1099,) # Expected shape of Perciatelli features. 38 | _HYSTERESIS = 100 # In Pascals. 39 | _STDDEV = 0.1666 # ~ 10 [Pa/min]. 40 | 41 | 42 | # Although this class does not have any gin-configurable parameters, it is 43 | # decorated as gin-configurable so it can be injected into other classes 44 | # (e.g. MarcoPoloExploration). 
@gin.configurable
class RandomWalkAgent(agent.Agent):
  """Exploratory agent that steers towards a randomly drifting target pressure.

  A target pressure is sampled at construction and at the start of every
  episode, then perturbed with zero-mean Gaussian noise on every step. The
  agent issues UP / DOWN / STAY depending on where the balloon's pressure sits
  relative to a hysteresis band around that target.
  """

  def __init__(self, num_actions: int, observation_shape: Sequence[int],
               seed: Optional[int] = None):
    """Seeds the RNG and samples an initial target pressure.

    Args:
      num_actions: Ignored.
      observation_shape: Ignored.
      seed: Seed for the PRNG key. When None, one is derived from the clock.
    """
    del num_actions, observation_shape
    if seed is None:
      seed = int(time.time() * 1e6)
    self._rng = jax.random.PRNGKey(seed)
    self._time_elapsed = dt.timedelta()
    self._sample_new_target_pressure()

  def _sample_new_target_pressure(self):
    """Draws a fresh target pressure, consuming one split of the RNG key."""
    self._rng, sample_key = jax.random.split(self._rng)
    self._target_pressure = sampling.sample_pressure(sample_key)

  def _select_action(self, features_as_vector: np.ndarray) -> int:
    """Chooses the altitude command that moves the balloon towards the target.

    Args:
      features_as_vector: Feature vector of shape (1099,).

    Returns:
      UP if the balloon's pressure exceeds the target by more than the
      hysteresis, DOWN if it is below by more than the hysteresis, else STAY.
    """
    assert features_as_vector.shape == _PERCIATELLI_FEATURES_SHAPE
    named_features = features.NamedPerciatelliFeatures(features_as_vector)
    balloon_pressure = named_features.balloon_pressure
    # Higher pressure means lower altitude, so too much pressure => go up.
    if balloon_pressure - _HYSTERESIS > self._target_pressure:
      command = control.AltitudeControlCommand.UP
    elif balloon_pressure + _HYSTERESIS < self._target_pressure:
      command = control.AltitudeControlCommand.DOWN
    else:
      command = control.AltitudeControlCommand.STAY
    return command

  def begin_episode(self, observation: np.ndarray) -> int:
    """Resets the clock, resamples the target and picks the first action."""
    self._time_elapsed = dt.timedelta()
    self._sample_new_target_pressure()
    return self._select_action(observation)

  def step(self, reward: float, observation: np.ndarray) -> int:
    """Perturbs the target pressure and picks the next action.

    Args:
      reward: Ignored.
      observation: Feature vector of shape (1099,).

    Returns:
      The altitude command selected for the perturbed target pressure.
    """
    del reward
    self._time_elapsed += constants.AGENT_TIME_STEP
    # Random walk over target pressures: add zero-mean Gaussian noise whose
    # scale is _STDDEV times the seconds elapsed since the episode began.
    # NOTE(review): `_time_elapsed` is only reset in `begin_episode`, so the
    # noise scale keeps growing during an episode; confirm this is intended
    # rather than scaling by the time since the previous update.
    self._rng, noise_key = jax.random.split(self._rng)
    noise_scale = self._time_elapsed.total_seconds() * _STDDEV
    self._target_pressure += noise_scale * jax.random.normal(noise_key)
    return self._select_action(observation)

  def end_episode(self, reward: float, terminal: bool = True) -> None:
    """No-op: nothing needs to be finalized at the end of an episode."""
40 | for _ in range(10): # Test for 10 episodes. 41 | action = self._ss_agent.begin_episode(mock_observation) 42 | self.assertEqual(action, 1) 43 | for _ in range(20): # Each episode includes 20 steps. 44 | action = self._ss_agent.step(0.0, mock_observation) 45 | self.assertEqual(action, 1) 46 | 47 | def test_end_episode(self): 48 | # end_episode doesn't do anything (it exists to conform to the Agent 49 | # interface). This next line just checks that it runs without problems. 50 | self._ss_agent.end_episode(0.0, True) 51 | 52 | # TODO(bellemare): Test wind score: decreases as bearing increases. 53 | 54 | if __name__ == '__main__': 55 | absltest.main() 56 | -------------------------------------------------------------------------------- /balloon_learning_environment/colab/summarize_eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Summarize Eval", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "source": [ 22 | "Copyright 2021 The Balloon Learning Environment Authors.\n", 23 | "\n", 24 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n", 25 | "you may not use this file except in compliance with the License.\n", 26 | "You may obtain a copy of the License at\n", 27 | "\n", 28 | " http://www.apache.org/licenses/LICENSE-2.0\n", 29 | "\n", 30 | "Unless required by applicable law or agreed to in writing, software\n", 31 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n", 32 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 33 | "See the License for the specific language governing permissions and\n", 34 | "limitations under the License." 
35 | ], 36 | "metadata": { 37 | "id": "u8hGsWm_qiGT" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "metadata": { 43 | "cellView": "form", 44 | "id": "VoqfM73Tmji_" 45 | }, 46 | "source": [ 47 | "# @title Imports\n", 48 | "import collections\n", 49 | "import json\n", 50 | "\n", 51 | "from google.colab import files\n", 52 | "import matplotlib.pyplot as plt\n", 53 | "import pandas as pd\n", 54 | "\n", 55 | "from IPython.display import HTML" 56 | ], 57 | "execution_count": null, 58 | "outputs": [] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "metadata": { 63 | "cellView": "form", 64 | "id": "GgEqoO-9mpkt" 65 | }, 66 | "source": [ 67 | "# @title Upload Data\n", 68 | "uploaded_files = files.upload()\n", 69 | "dataframes = dict()\n", 70 | "\n", 71 | "for name, data in uploaded_files.items():\n", 72 | " name = name.rsplit('.', maxsplit=1)[0]\n", 73 | " json_data = json.loads(data)\n", 74 | " dataframes[name] = pd.DataFrame.from_dict(json_data)" 75 | ], 76 | "execution_count": null, 77 | "outputs": [] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "metadata": { 82 | "cellView": "form", 83 | "id": "JGMAmrANm0an" 84 | }, 85 | "source": [ 86 | "# @title Print Aggregate results\n", 87 | "aggregate_data = collections.defaultdict(list)\n", 88 | "for name, df in dataframes.items():\n", 89 | " aggregate_data['num episodes'].append(len(df.out_of_power))\n", 90 | " aggregate_data['out of power'].append(df.out_of_power.sum())\n", 91 | " aggregate_data['zeropressure'].append(df.zeropressure.sum())\n", 92 | " aggregate_data['envelope burst'].append(df.envelope_burst.sum())\n", 93 | "\n", 94 | " finished_runs = df.loc[df.out_of_power == False]\n", 95 | " finished_runs = finished_runs.loc[finished_runs.zeropressure == False]\n", 96 | " finished_runs = finished_runs.loc[finished_runs.envelope_burst == False]\n", 97 | " aggregate_data['mean cumulative reward (finished episodes)'].append(\n", 98 | " finished_runs.cumulative_reward.mean())\n", 99 | " aggregate_data['mean TWR50 
(finished episodes)'].append(\n", 100 | " finished_runs.time_within_radius.mean())\n", 101 | " aggregate_data['mean cumulative reward (all episodes)'].append(\n", 102 | " df.cumulative_reward.mean())\n", 103 | " aggregate_data['mean TWR50 (all episodes)'].append(\n", 104 | " df.time_within_radius.mean())\n", 105 | "\n", 106 | "df = pd.DataFrame(aggregate_data)\n", 107 | "df.index = dataframes.keys()\n", 108 | "\n", 109 | "# This is a little hacky, but it works 🤷‍♂️ \n", 110 | "html = df.to_html()\n", 111 | "html += \"\"\n", 112 | "HTML(html)" 113 | ], 114 | "execution_count": null, 115 | "outputs": [] 116 | } 117 | ] 118 | } -------------------------------------------------------------------------------- /balloon_learning_environment/env/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /balloon_learning_environment/env/balloon/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /balloon_learning_environment/env/balloon/acs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
# 1-D lookup from pressure ratio to ACS operating power [W]. The first and last
# table segments are flat, so extrapolating beyond [1.0, 1.35] keeps returning
# 100 W and 400 W respectively.
_PRESSURE_RATIO_TO_POWER: interpolate.interp1d = (
    interpolate.interp1d(
        np.array([1.0, 1.05, 1.2, 1.25, 1.35]),  # pressure_ratio
        np.array([100.0, 100.0, 300.0, 400.0, 400.0]),  # power
        fill_value='extrapolate'))


# 2-D lookup from (pressure ratio, power) to fan efficiency, defined on a grid
# of 13 pressure ratios by 4 power levels (52 tabulated values).
# NOTE(review): scipy.interpolate.interp2d is deprecated in recent SciPy
# releases; confirm the pinned SciPy version before upgrading.
_PRESSURE_RATIO_POWER_TO_EFFICIENCY: interpolate.interp2d = (
    interpolate.interp2d(
        np.linspace(1.05, 1.35, 13),  # pressure_ratio
        np.linspace(100.0, 400.0, 4),  # power
        np.array([0.4, 0.4, 0.3, 0.2, 0.2, 0.00000, 0.00000, 0.00000, 0.00000,
                  0.00000, 0.00000, 0.00000, 0.00000, 0.4, 0.3, 0.3, 0.30, 0.25,
                  0.23, 0.20, 0.15, 0.12, 0.10, 0.00000, 0.00000, 0.00000,
                  0.00000, 0.3, 0.25, 0.25, 0.25, 0.20, 0.20, 0.20, 0.2, 0.15,
                  0.13, 0.12, 0.11, 0.00000, 0.23, 0.23, 0.23, 0.23, 0.23, 0.20,
                  0.20, 0.20, 0.18, 0.16, 0.15, 0.13]),  # efficiency
        fill_value=None))


def get_most_efficient_power(pressure_ratio: float) -> units.Power:
  """Lookup the optimal operating power from static tables.

  Gets the optimal operating power [W] in terms of kg of air moved per unit
  energy.

  Args:
    pressure_ratio: Ratio of (balloon pressure + superpressure) to balloon
      pressure.

  Returns:
    Optimal ACS power at current pressure_ratio.
  """
  power = _PRESSURE_RATIO_TO_POWER(pressure_ratio)
  return units.Power(watts=power)


def get_fan_efficiency(pressure_ratio: float, power: units.Power) -> float:
  """Looks up the fan efficiency for a pressure ratio and operating power.

  Args:
    pressure_ratio: Ratio of (balloon pressure + superpressure) to balloon
      pressure.
    power: Power at which the ACS is operating.

  Returns:
    The interpolated efficiency from the static table, as a Python float.
  """
  # Compute efficiency of air flow from current pressure ratio and power.
  efficiency = _PRESSURE_RATIO_POWER_TO_EFFICIENCY(pressure_ratio, power.watts)
  return float(efficiency)


def get_mass_flow(power: units.Power, efficiency: float) -> float:
  """Returns efficiency * power [W] divided by the seconds in an hour.

  Args:
    power: Power at which the ACS is operating.
    efficiency: Fan efficiency, e.g. as returned by `get_fan_efficiency`.

  Returns:
    The resulting mass flow. Presumably kg/s, given an efficiency expressed in
    kg of air per watt-hour -- TODO confirm the units against the caller.
  """
  return efficiency * power.watts / constants.NUM_SECONDS_PER_HOUR
class AcsTest(parameterized.TestCase):
  """Checks the ACS lookup tables at, below and above their table bounds."""

  def _assertion_for(self, comparator):
    """Maps a comparator tag onto the matching unittest assertion.

    'eq' selects assertEqual and 'lt' selects assertLessEqual; any other tag
    selects assertGreaterEqual.
    """
    known = {
        'eq': self.assertEqual,
        'lt': self.assertLessEqual,
    }
    return known.get(comparator, self.assertGreaterEqual)

  @parameterized.named_parameters(
      dict(testcase_name='at_min', pressure_ratio=1.0, power=100.0,
           comparator='eq'),
      dict(testcase_name='at_mid', pressure_ratio=1.2, power=300.0,
           comparator='eq'),
      dict(testcase_name='at_max', pressure_ratio=1.35, power=400.0,
           comparator='eq'),
      dict(testcase_name='below_min', pressure_ratio=0.01, power=100.0,
           comparator='lt'),
      dict(testcase_name='above_max', pressure_ratio=2.0, power=400.0,
           comparator='gt'))
  def test_get_most_efficient_power(self, pressure_ratio, power, comparator):
    check = self._assertion_for(comparator)
    actual_watts = acs.get_most_efficient_power(pressure_ratio).watts
    check(actual_watts, power)

  @parameterized.named_parameters(
      dict(testcase_name='at_min', pressure_ratio=1.05, power=100.0,
           efficiency=0.4, comparator='eq'),
      dict(testcase_name='at_max', pressure_ratio=1.35, power=400.0,
           efficiency=0.13, comparator='eq'),
      dict(testcase_name='below_min', pressure_ratio=0.01, power=10.0,
           efficiency=0.4, comparator='gt'),
      dict(testcase_name='above_max', pressure_ratio=2.0, power=500.0,
           efficiency=0.13, comparator='lt'))
  def test_get_fan_efficiency(self, pressure_ratio, power, efficiency,
                              comparator):
    check = self._assertion_for(comparator)
    actual_efficiency = acs.get_fan_efficiency(
        pressure_ratio, units.Power(watts=power))
    check(actual_efficiency, efficiency)

  def test_get_mass_flow(self):
    mass_flow = acs.get_mass_flow(units.Power(watts=3.6), 10.0)
    self.assertEqual(mass_flow, 0.01)
# Altitude thresholds used by the safety layer. Below MIN_ALTITUDE the balloon
# is forced to ascend; within BUFFER above it, descending is blocked; the
# RESTART_HYSTERESIS band above that keeps the previous state.
BUFFER = units.Distance(feet=500.0)
RESTART_HYSTERESIS = units.Distance(feet=500.0)
MIN_ALTITUDE = units.Distance(feet=50_000.0)


class _AltitudeState(enum.Enum):
  """States of the altitude safety state machine."""
  NOMINAL = 0
  LOW = 1
  VERY_LOW = 2


# Note: Transitions are applied in the order of the first match.
# '*' is a catch-all, and applies to any state.
_ALTITUDE_SAFETY_TRANSITIONS = (
    dict(trigger='very_low', source='*', dest=_AltitudeState.VERY_LOW),
    dict(trigger='low', source='*', dest=_AltitudeState.LOW),
    dict(
        trigger='low_nominal',
        source=(_AltitudeState.VERY_LOW, _AltitudeState.LOW),
        dest=_AltitudeState.LOW),
    dict(
        trigger='low_nominal',
        source=_AltitudeState.NOMINAL,
        dest=_AltitudeState.NOMINAL),
    dict(trigger='nominal', source='*', dest=_AltitudeState.NOMINAL),
)


class AltitudeSafetyLayer:
  """Overrides controller actions that would take the balloon too low."""

  def __init__(self):
    self._state_machine = transitions.Machine(
        states=_AltitudeState,
        transitions=_ALTITUDE_SAFETY_TRANSITIONS,
        initial=_AltitudeState.NOMINAL)
    # Keep the state machine library's per-transition logging quiet.
    logging.getLogger('transitions').setLevel(logging.WARNING)

  def get_action(self, action: control.AltitudeControlCommand,
                 atmosphere: standard_atmosphere.Atmosphere,
                 pressure: float) -> control.AltitudeControlCommand:
    """Returns the action to execute after applying the safety rules.

    Args:
      action: The action the controller has supplied to the balloon.
      atmosphere: The atmospheric conditions the balloon is flying in.
      pressure: The current pressure of the balloon.

    Returns:
      UP when the balloon is below the altitude limit, STAY in place of DOWN
      when it is close to the limit, and otherwise the supplied action.
    """
    height = atmosphere.at_pressure(pressure).height
    self._transition_state(height)
    current_state = self._state_machine.state

    # Too low: always climb, whatever the controller asked for.
    if current_state == _AltitudeState.VERY_LOW:
      return control.AltitudeControlCommand.UP

    # Almost too low: refuse to descend any further.
    wants_descent = action == control.AltitudeControlCommand.DOWN
    if current_state == _AltitudeState.LOW and wants_descent:
      return control.AltitudeControlCommand.STAY

    return action

  @property
  def navigation_is_paused(self):
    """Whether the safety layer is in any state other than NOMINAL."""
    return self._state_machine.state != _AltitudeState.NOMINAL

  def _transition_state(self, altitude: units.Distance):
    """Fires the state machine trigger matching the altitude band."""
    if altitude < MIN_ALTITUDE:
      trigger = self._state_machine.very_low
    elif altitude < MIN_ALTITUDE + BUFFER:
      trigger = self._state_machine.low
    elif altitude < MIN_ALTITUDE + BUFFER + RESTART_HYSTERESIS:
      trigger = self._state_machine.low_nominal
    else:
      trigger = self._state_machine.nominal
    trigger()
# Upper bounds of the pressure ratio intervals; bisecting into this sequence
# selects the matching row of _SOC_MAPPINGS.
_PRESSURE_RATIO_INTERVALS = (1.08, 1.11, 1.14, 1.17, 1.2, 1.23, 1.26)

# One entry for each pressure ratio interval:
# (state of charge thresholds, power to use in each resulting band).
_SOC_MAPPINGS = (
    ((0.3, 0.4, 0.5), (0, 150, 175, 200)),  # 0.99 -> 1.08
    ((0.3, 0.4, 0.7), (0, 200, 200, 225)),  # 1.08 -> 1.11
    ((0.3, 0.4, 0.6), (0, 225, 225, 250)),  # 1.11 -> 1.14
    ((0.3, 0.4, 0.5), (0, 200, 225, 250)),  # 1.14 -> 1.17
    ((0.3, 0.4, 0.5), (0, 225, 250, 275)),  # 1.17 -> 1.2
    ((0.4, 0.5), (0, 275, 300)),  # 1.2 -> 1.23
    ((0.5, 0.6), (0, 300, 325)),  # 1.23 -> 1.26
    ((0.5, 0.6), (0, 325, 350)),  # 1.26 -> 5.0
)


def lookup(pressure_ratio: float,
           state_of_charge: float) -> float:
  """Maps pressure_ratio x state_of_charge to the power to use.

  Args:
    pressure_ratio: The pressure ratio; must lie in [0.99, 5].
    state_of_charge: The state of charge used to select the power band.

  Returns:
    The power to use from the static table.

  Raises:
    AssertionError: If pressure_ratio is outside [0.99, 5] (or is NaN).
  """
  # Raised explicitly rather than via an `assert` statement so the check still
  # runs under `python -O`, while callers keep seeing the same exception type.
  if not 0.99 <= pressure_ratio <= 5:
    raise AssertionError(
        f'pressure_ratio must be in [0.99, 5], got {pressure_ratio}.')

  pr_id = bisect.bisect(_PRESSURE_RATIO_INTERVALS, pressure_ratio)
  soc_thresholds, powers = _SOC_MAPPINGS[pr_id]
  soc_id = bisect.bisect(soc_thresholds, state_of_charge)
  return powers[soc_id]
class PowerTableTest(parameterized.TestCase):
  """Covers range validation and every band of the power lookup table."""

  @parameterized.named_parameters(
      dict(testcase_name='low_pressure_ratio', pressure_ratio=0.98),
      dict(testcase_name='high_pressure_ratio', pressure_ratio=5.01))
  def test_invalid_pressure_ratios(self, pressure_ratio: float):
    # Out-of-range pressure ratios must be rejected before any table access.
    self.assertRaises(
        AssertionError, power_table.lookup, pressure_ratio, 1.0)

  # Each row is (pressure_ratio, state_of_charge, expected_power), grouped by
  # pressure ratio interval.
  @parameterized.parameters(
      (1.0, 0.2, 0), (1.0, 0.3, 150), (1.0, 0.4, 175), (1.0, 0.5, 200),
      (1.0, 0.6, 200),
      (1.08, 0.2, 0), (1.08, 0.3, 200), (1.08, 0.4, 200), (1.08, 0.7, 225),
      (1.08, 0.8, 225),
      (1.11, 0.2, 0), (1.11, 0.3, 225), (1.11, 0.4, 225), (1.11, 0.6, 250),
      (1.11, 0.7, 250),
      (1.14, 0.2, 0), (1.14, 0.3, 200), (1.14, 0.4, 225), (1.14, 0.5, 250),
      (1.14, 0.6, 250),
      (1.17, 0.2, 0), (1.17, 0.3, 225), (1.17, 0.4, 250), (1.17, 0.5, 275),
      (1.17, 0.6, 275),
      (1.2, 0.3, 0), (1.2, 0.4, 275), (1.2, 0.5, 300), (1.2, 0.6, 300),
      (1.23, 0.4, 0), (1.23, 0.5, 300), (1.23, 0.6, 325), (1.23, 0.7, 325),
      (1.26, 0.4, 0), (1.26, 0.5, 325), (1.26, 0.6, 350), (1.26, 0.7, 350))
  def test_table_lookup(self, pressure_ratio, state_of_charge,
                        expected_power_to_use):
    power_to_use = power_table.lookup(pressure_ratio, state_of_charge)
    self.assertEqual(power_to_use, expected_power_to_use)
class AltitudeRangeBuilderTest(absltest.TestCase):
  """Sanity checks on the pressure range computed for a default balloon."""

  def setUp(self):
    super().setUp()
    self.atmosphere = standard_atmosphere.Atmosphere(jax.random.PRNGKey(0))
    self.create_balloon = functools.partial(
        test_helpers.create_balloon, atmosphere=self.atmosphere)

  def _pressure_range_for_new_balloon(self):
    """Creates a balloon and returns its accessible pressure range."""
    balloon = self.create_balloon()
    return pressure_range_builder.get_pressure_range(
        balloon.state, self.atmosphere)

  def test_get_pressure_range_returns_valid_range(self):
    pressure_range = self._pressure_range_for_new_balloon()

    self.assertIsInstance(pressure_range,
                          pressure_range_builder.AccessiblePressureRange)
    for bound in (pressure_range.min_pressure, pressure_range.max_pressure):
      self.assertBetween(bound, 1000.0, 100_000.0)

  def test_get_pressure_range_returns_min_pressure_below_max_pressure(self):
    pressure_range = self._pressure_range_for_new_balloon()

    self.assertLess(pressure_range.min_pressure, pressure_range.max_pressure)

  def test_get_pressure_range_returns_max_pressure_above_min_altitude(self):
    pressure_range = self._pressure_range_for_new_balloon()

    # The highest reachable pressure must not exceed the pressure at the
    # safety layer's minimum altitude.
    pressure_at_min_altitude = self.atmosphere.at_height(
        altitude_safety.MIN_ALTITUDE).pressure
    self.assertLessEqual(pressure_range.max_pressure, pressure_at_min_altitude)

  # TODO(joshgreaves): Add more tests when the pressure ranges are as expected.
class StableParamsTest(parameterized.TestCase):
  """Checks that freshly created balloons start in a stable state."""

  def setUp(self):
    super().setUp()
    self.atmosphere = standard_atmosphere.Atmosphere(jax.random.PRNGKey(38))

  @parameterized.named_parameters(
      dict(testcase_name='middle_pressure', init_pressure=9_500.0),
      dict(testcase_name='high_pressure', init_pressure=11_500.0),
      dict(testcase_name='low_pressure', init_pressure=6_500.0))
  def test_cold_start_to_stable_params_initializes_mols_air_correctly(
      self, init_pressure: float):
    # Superpressure reacts strongly to temperature, and therefore to the time
    # of day, so the balloon is created at midnight.
    # create_balloon runs cold_start_to_stable_params by default.
    midnight = units.datetime(2020, 6, 1, 0, 0, 0)
    b = test_helpers.create_balloon(
        pressure=init_pressure,
        date_time=midnight,
        atmosphere=self.atmosphere)

    # With correctly initialized internal parameters, holding altitude for a
    # while should leave the balloon at roughly its starting pressure level.
    wind = wind_field.WindVector(
        units.Velocity(mps=3.0), units.Velocity(mps=-4.0))
    stride = dt.timedelta(seconds=10.0)
    for _ in range(100):
      b.simulate_step(
          wind, self.atmosphere, control.AltitudeControlCommand.STAY, stride)

    pressure_drift = abs(b.state.pressure - init_pressure)
    self.assertLess(pressure_drift, 100.0)

  @parameterized.named_parameters(
      dict(testcase_name='middle_pressure', init_pressure=9_500.0),
      dict(testcase_name='high_pressure', init_pressure=11_500.0),
      dict(testcase_name='low_pressure', init_pressure=5_000.0))
  def test_cold_start_to_stable_params_initializes_temperature_correctly(
      self, init_pressure: float):
    # create_balloon runs cold_start_to_stable_params by default.
    b = test_helpers.create_balloon(
        pressure=init_pressure, atmosphere=self.atmosphere)
    state = b.state

    solar_elevation, _, solar_flux = solar.solar_calculator(
        state.latlng, state.date_time)
    d_internal_temp = thermal.d_balloon_temperature_dt(
        state.envelope_volume, state.envelope_mass,
        state.internal_temperature, state.ambient_temperature,
        state.pressure, solar_elevation, solar_flux,
        state.upwelling_infrared)

    # A low rate of change means the temperature was initialized to a stable
    # value.
    # NOTE(review): only the signed value is bounded here, so a large negative
    # rate would also pass -- confirm whether abs() was intended.
    self.assertLess(d_internal_temp, 1e-3)
class BalloonArenaTest(parameterized.TestCase):
  """Tests seeding and initial balloon placement of the balloon arena."""

  # TODO(joshgreaves): Patch PerciatelliFeatureConstructor, it's too slow.

  def _balloon_states_after_reset(self, seeds):
    """Creates one arena per seed, resets each, and returns balloon states."""
    arenas = [test_helpers.create_arena() for _ in seeds]
    for arena, seed in zip(arenas, seeds):
      arena.reset(seed)
    return [arena.get_simulator_state().balloon_state for arena in arenas]

  def test_int_seeding_gives_determinisic_balloon_initialization(self):
    first, second = self._balloon_states_after_reset([201, 201])

    test_helpers.compare_balloon_states(first, second)

  def test_array_seeding_gives_determinisic_balloon_initialization(self):
    first, second = self._balloon_states_after_reset(
        [jax.random.PRNGKey(201), jax.random.PRNGKey(201)])

    test_helpers.compare_balloon_states(first, second)

  def test_different_seeds_gives_different_initialization(self):
    first, second = self._balloon_states_after_reset([201, 202])

    test_helpers.compare_balloon_states(
        first, second, check_not_equal=['x', 'y'])

  def test_random_seeding_doesnt_throw_exception(self):
    arena = test_helpers.create_arena()

    # Passes as long as resetting without a seed raises nothing.
    arena.reset()

  @parameterized.named_parameters((str(x), x) for x in (1, 5, 28, 90, 106, 378))
  def test_balloon_is_initialized_within_200km(self, seed: int):
    (balloon_state,) = self._balloon_states_after_reset([seed])

    distance = units.relative_distance(balloon_state.x, balloon_state.y)
    self.assertLessEqual(distance.km, 200.0)

  @parameterized.named_parameters((str(x), x) for x in (1, 5, 28, 90, 106, 378))
  def test_balloon_is_initialized_within_valid_pressure_range(self, seed: int):
    (balloon_state,) = self._balloon_states_after_reset([seed])

    self.assertBetween(balloon_state.pressure,
                       constants.PERCIATELLI_PRESSURE_RANGE_MIN,
                       constants.PERCIATELLI_PRESSURE_RANGE_MAX)
class GridWindFieldSampler(abc.ABC):
  """Abstract interface for objects that sample grid wind fields.

  Implementations report the shape of the fields they produce via
  `field_shape` and generate a field with `sample_field`.
  """

  @property
  @abc.abstractmethod
  def field_shape(self) -> vae.FieldShape:
    """Gets the field shape of wind fields sampled by this class."""

  @abc.abstractmethod
  def sample_field(self,
                   key: jnp.ndarray,
                   date_time: dt.datetime) -> np.ndarray:
    """Samples a wind field and returns it as a numpy array.

    Args:
      key: A PRNGKey to use for sampling.
      date_time: The date_time of the beginning of the wind field.

    Returns:
      The sampled wind field as a numpy array.
    """
def register_env() -> None:
  """Registers the Gym environment, unless it is already registered.

  The environment is registered under the id 'BalloonLearningEnvironment-v0'
  in the root namespace. Calling this function repeatedly is safe: later calls
  find the existing registration and return without registering again.
  """
  # We need to import Gym's registration module inline or else we'll
  # get a circular dependency that will result in an error when importing gym
  from gym.envs import registration  # pylint: disable=g-import-not-at-top

  env_id = 'BalloonLearningEnvironment-v0'
  env_entry_point = 'balloon_learning_environment.env.balloon_env:BalloonEnv'

  # We guard registration by checking if our env is already registered.
  # This is necessary because the plugin system will load our module
  # which also calls this function. If multiple `register()` calls are
  # made this will result in a warning to the user.
  registry = registration.registry
  # Older Gym releases expose the specs as `registry.env_specs`; newer ones
  # make the registry itself a plain mapping from env id to spec. Support
  # both so the guard does not raise AttributeError on newer versions.
  env_specs = getattr(registry, 'env_specs', registry)
  if env_id in env_specs:
    return

  with contextlib.ExitStack() as stack:
    # This is a workaround for Gym 0.21 which didn't support
    # registering into the root namespace with the plugin system.
    if hasattr(registration, 'namespace'):
      stack.enter_context(registration.namespace(None))
    registration.register(id=env_id, entry_point=env_entry_point)
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /balloon_learning_environment/env/rendering/renderer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class Renderer(abc.ABC):
  """A renderer object for rendering the simulator state.

  Concrete renderers must implement every method below; this base class
  provides no behavior of its own.
  """

  @abc.abstractmethod
  def reset(self) -> None:
    """Resets the renderer (presumably at the start of a new episode)."""
    pass

  @abc.abstractmethod
  def step(self, state: simulator_data.SimulatorState) -> None:
    """Supplies the renderer with a simulator state.

    Args:
      state: The simulator state made available to the renderer.
    """
    pass

  @abc.abstractmethod
  def render(self,
             mode: Text,
             summary_writer: Optional[tensorboard.SummaryWriter] = None,
             iteration: Optional[int] = None) -> Union[None, np.ndarray, Text]:
    """Renders a frame.

    Args:
      mode: A string specifying the mode. Default gym render modes are `human`,
        `rgb_array`, and `ansi`. However, a renderer may specify additional
        render modes beyond this. `human` corresponds to rendering directly to
        the screen. `rgb_array` renders to a numpy array and returns it. `ansi`
        renders to a string or StringIO object.
      summary_writer: If not None, will also render the image to the tensorboard
        summary.
      iteration: Iteration number used for writing to tensorboard.

    Returns:
      None, a numpy array of rgb data, or a Text object, depending on the mode.
    """
    pass

  @property
  @abc.abstractmethod
  def render_modes(self) -> Iterable[Text]:
    """The render mode names supported by this renderer."""
    pass
@dataclasses.dataclass
class SimulatorState(object):
  """Specifies the full state of the simulator.

  Since it specifies the full state of the simulator, it should be
  possible to use this for checkpointing and restoring the simulator.
  """
  # Ground truth state of the balloon.
  balloon_state: balloon.BalloonState
  # The wind field the balloon is flying in.
  wind_field: wind_field.WindField
  # The atmosphere used by the simulation.
  atmosphere: standard_atmosphere.Atmosphere


@dataclasses.dataclass
class SimulatorObservation(object):
  """Specifies an observation from the simulator.

  This differs from SimulatorState in that the observations are not
  ground truth state, and are instead noisy observations from the
  environment.
  """
  # Observed (not ground truth) state of the balloon.
  balloon_observation: balloon.BalloonState
  # Observed wind vector at the balloon.
  wind_at_balloon: wind_field.WindVector
class WindFieldTest(absltest.TestCase):
  """Checks forecasts returned by SimpleStaticWindField."""

  def setUp(self):
    super().setUp()
    self.x = units.Distance(km=2.1)
    self.y = units.Distance(km=2.2)
    self.delta = dt.timedelta(minutes=3)

  def _assert_forecast(self, level: float, u_mps: float, v_mps: float):
    """Asserts that the forecast queried at `level` equals (u_mps, v_mps)."""
    field = wind_field.SimpleStaticWindField()
    expected = wind_field.WindVector(
        units.Velocity(mps=u_mps), units.Velocity(mps=v_mps))
    actual = field.get_forecast(self.x, self.y, level, self.delta)
    self.assertEqual(expected, actual)

  def test_some_altitude_goes_north(self):
    self._assert_forecast(9323.0, 0.0, 10.0)

  def test_some_altitude_goes_south(self):
    self._assert_forecast(13999.0, 0.0, -10.0)

  def test_some_altitude_goes_east(self):
    self._assert_forecast(5523.0, 10.0, 0.0)

  def test_some_altitude_goes_west(self):
    self._assert_forecast(11212.0, -10.0, 0.0)

  def test_get_forecast_column_gives_same_result_as_get_forecast(self):
    levels = [10_000.0, 11_000.0]
    field = wind_field.SimpleStaticWindField()

    # Query each level on its own first, then all of them as one column.
    single_forecasts = [
        field.get_forecast(self.x, self.y, level, self.delta)
        for level in levels
    ]
    forecast_column = field.get_forecast_column(
        self.x, self.y, levels, self.delta)

    self.assertEqual(single_forecasts[0], forecast_column[0])
    self.assertEqual(single_forecasts[1], forecast_column[1])
class WindGpTest(absltest.TestCase):
  """Checks WindGP queries before and after a single observation."""

  def setUp(self):
    super().setUp()
    # TODO(bellemare): If multiple wind fields are available, use the
    # vanilla (4 sheets) one.
    # The model is built on top of a dummy forecast.
    self.model = wind_gp.WindGP(wind_field.SimpleStaticWindField())

    # A single query/observation point shared by the tests below.
    self.x = units.Distance(m=0.0)
    self.y = units.Distance(m=0.0)
    self.pressure = 0.0
    self.delta = dt.timedelta(seconds=0)
    self.wind_vector = wind_field.WindVector(
        units.Velocity(mps=1.0), units.Velocity(mps=1.0))

  def test_measurement_has_almost_no_variance(self):
    point = (self.x, self.y, self.pressure, self.delta)
    self.model.observe(*point, self.wind_vector)
    post_measurement = self.model.query(*point)

    # Variance is measurement noise at a measured point.
    # This constant was computed analytically. It is given by (see wind_gp.py):
    #   SIGMA_NOISE_SQUARED / (SIGMA_NOISE_SQUARED + SIGMA_EXP_SQUARED) .
    variance = post_measurement[1].item()
    self.assertAlmostEqual(variance, 0.003843, places=3)

  def test_observations_affect_forecast_continuously(self):
    before = self.model.query(self.x, self.y, self.pressure, self.delta)
    self.model.observe(self.x, self.y, self.pressure, self.delta,
                       self.wind_vector)
    # Query slightly away from the observed point.
    after = self.model.query(
        units.Distance(km=0.05), self.y, self.pressure, self.delta)

    values_differ = before[0] != after[0]
    self.assertTrue(values_differ.all())
For example, if you ran eval with both station_seeker 24 | and the random agent: 25 | 26 | ``` 27 | python -m balloon_learning_environment.eval.combine_eval_shards \ 28 | --path=/tmp/ble/eval \ 29 | --models=station_seeker --models=random 30 | ``` 31 | 32 | 33 | 34 | ## 3. Visualize Your Results With Colab 35 | Open `balloon_learning_environment/colab/summarize_eval.ipynb` and upload your 36 | json file. 37 | -------------------------------------------------------------------------------- /balloon_learning_environment/eval/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /balloon_learning_environment/eval/combine_eval_shards.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
# TODO(joshgreaves): Rename models to agents (including README).
flags.DEFINE_string('path', None, 'The path containing the shard results.')
flags.DEFINE_multi_string('models', None,
                          'The names of the methods in the directory.')
flags.DEFINE_boolean('pretty_json', False,
                     'If true, it will write json files with an indent of 2.')
flags.DEFINE_boolean('flight_paths', False,
                     'If True, will include flight paths.')
flags.mark_flags_as_required(['path', 'models'])
FLAGS = flags.FLAGS


def main(argv: Sequence[str]) -> None:
  """Merges the json shards of every requested model into one file each.

  For each name in --models, every `<model>_*.json` file under --path is
  loaded and concatenated, the entries are sorted by their 'seed' key, and
  the result is written to `<path>/<model>.json`. Unless --flight_paths is
  set, each entry's 'flight_path' is replaced with an empty list.

  Args:
    argv: Command line arguments; unused.
  """
  del argv  # Unused.

  for model in FLAGS.models:
    shard_pattern = os.path.join(FLAGS.path, f'{model}_*.json')

    records = []
    for shard_path in glob.glob(shard_pattern):
      with open(shard_path, 'r') as shard_file:
        records.extend(json.load(shard_file))

    records.sort(key=lambda record: record['seed'])

    if not FLAGS.flight_paths:
      # Flight paths are dropped from the combined file unless requested.
      for record in records:
        record['flight_path'] = []

    indent = 2 if FLAGS.pretty_json else None
    with open(os.path.join(FLAGS.path, f'{model}.json'), 'w') as out_file:
      json.dump(records, out_file, indent=indent)
# Registry of the named suites; insertion order is the order reported by
# available_suites().
_eval_suites = {
    'big_eval': EvaluationSuite(list(range(10_000)), 960),
    'medium_eval': EvaluationSuite(list(range(1_000)), 960),
    'small_eval': EvaluationSuite(list(range(100)), 960),
    'tiny_eval': EvaluationSuite(list(range(10)), 960),
    'micro_eval': EvaluationSuite([0], 960),
}
# One suite per stratum, plus one suite with the seeds of all strata.
all_strata = []
for strata in ['hardest', 'hard', 'mid', 'easy', 'easiest']:
  _eval_suites[f'{strata}_strata'] = EvaluationSuite(
      strata_seeds.STRATA_SEEDS[strata], 960)
  all_strata.extend(strata_seeds.STRATA_SEEDS[strata])
_eval_suites['all_strata'] = EvaluationSuite(all_strata, 960)


def available_suites() -> List[str]:
  """Returns the names of all registered evaluation suites."""
  return list(_eval_suites)


def get_eval_suite(name: str) -> EvaluationSuite:
  """Gets a named evaluation suite.

  Args:
    name: The name the suite was registered under.

  Returns:
    A new EvaluationSuite holding a copy of the registered seeds.

  Raises:
    ValueError: If no suite is registered under `name`.
  """
  registered = _eval_suites.get(name)
  if registered is None:
    raise ValueError(f'Unknown eval suite {name}')

  # Hand out a copy of the seeds so callers cannot mutate the registry.
  seeds_copy = list(registered.seeds)
  return EvaluationSuite(seeds_copy, registered.max_episode_length)
class SuitesTest(absltest.TestCase):
  """Checks lookup of evaluation suites by name."""

  def test_get_eval_suite_is_successful_for_valid_name(self):
    seeds = suites.get_eval_suite('big_eval').seeds

    self.assertLen(seeds, 10_000)

  def test_get_eval_suite_raises_error_for_invalid_name(self):
    self.assertRaises(ValueError, suites.get_eval_suite, 'invalid name')
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /balloon_learning_environment/generative/dataset_wind_field_reservoir.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class DatasetWindFieldReservoir(wind_field_reservoir.WindFieldReservoir):
  """Retrieves wind fields from an in-memory datastore.

  The last `eval_batch_size` fields of the dataset are served by
  `get_eval_batch`; `get_batch` samples from the fields before them.
  """

  def __init__(self,
               data: Union[str, jnp.ndarray],
               eval_batch_size: int = 10,
               rng_seed=0,
               num_shards: int = 200):
    """Builds the reservoir from an array or from pickled shards on disk.

    Args:
      data: Either the dataset itself (fields stacked along axis 0), or a
        directory containing shards named `batch0000.pickle`,
        `batch0001.pickle`, ... which are loaded and concatenated on axis 0.
      eval_batch_size: Number of trailing fields held out for eval.
      rng_seed: Seed for the PRNG key used to sample training batches.
      num_shards: Number of shard files to load when `data` is a directory.
        Ignored when `data` is an array.
    """
    self.eval_batch_size = eval_batch_size

    if isinstance(data, str):
      # TODO(scandido): We need to update this to load a single file, with no
      # assumed directory/file structure hardcoded.
      # NOTE(review): pickle.load can execute arbitrary code; only point
      # `data` at shard directories from a trusted source.
      dataset_shards = []
      for i in range(num_shards):
        fn = f'{data}/batch{i:04d}.pickle'
        with tf.io.gfile.GFile(fn, 'rb') as f:
          dataset_shards.append(pickle.load(f))
        logging.info('Loaded shard %d', i)
      data = jnp.concatenate(dataset_shards, axis=0)

    self.dataset = data
    self._rng = jax.random.PRNGKey(rng_seed)

  def get_batch(self, batch_size: int) -> jnp.ndarray:
    """Returns fields used for training.

    Args:
      batch_size: The number of fields.

    Returns:
      A jax.numpy array that is batch_size x wind field dimensions (see vae.py).
    """
    # Advance the stored key so successive calls draw different samples.
    self._rng, key = jax.random.split(self._rng)
    # Indices are drawn without replacement from the fields that precede the
    # held-out eval fields.
    samples = jax.random.choice(
        key,
        self.dataset.shape[0] - self.eval_batch_size,
        shape=(batch_size,),
        replace=False)
    return self.dataset[samples, ...]

  def get_eval_batch(self) -> jnp.ndarray:
    """Returns fields used for eval.

    Returns:
      A jax.numpy array that is eval_batch_size x wind field dimensions (see
      vae.py). When eval_batch_size exceeds the dataset size, the whole
      dataset is returned.
    """
    # Slice from an explicit start index: `[-eval_batch_size:]` would return
    # the whole dataset when eval_batch_size is 0, because -0 == 0.
    eval_start = max(self.dataset.shape[0] - self.eval_batch_size, 0)
    return self.dataset[eval_start:, ...]
def _make_dataset() -> jnp.ndarray:
  """Builds 33 fields filled with 0.5 followed by 10 fields filled with 0.2."""
  parts = [
      0.5 * jnp.ones(shape=(33, 2, 3, 4)),  # Served as training fields.
      0.2 * jnp.ones(shape=(10, 2, 3, 4)),  # Served as eval fields.
  ]
  return jnp.concatenate(parts, axis=0)


class DatasetWindFieldReservoirTest(absltest.TestCase):
  """Checks that train and eval batches come from the right dataset slice."""

  def test_reservoir_returns_eval_for_eval(self):
    reservoir = dataset_wind_field_reservoir.DatasetWindFieldReservoir(
        data=_make_dataset(), eval_batch_size=10)

    eval_batch = reservoir.get_eval_batch()
    expected = 0.2 * jnp.ones(shape=(10, 2, 3, 4))

    self.assertTrue(jnp.allclose(eval_batch, expected))

  def test_reservoir_returns_train_for_train(self):
    reservoir = dataset_wind_field_reservoir.DatasetWindFieldReservoir(
        data=_make_dataset(), eval_batch_size=10)

    train_batch = reservoir.get_batch(batch_size=33)
    expected = 0.5 * jnp.ones(shape=(33, 2, 3, 4))

    self.assertTrue(jnp.allclose(train_batch, expected))
class VaeTest(absltest.TestCase):
  """Shape checks for the encoder, the decoder and the full VAE."""

  def setUp(self):
    super().setUp()
    self.key = jax.random.PRNGKey(0)
    self.input_shape = (8,)
    self.sample_input = jax.random.normal(self.key, self.input_shape)

  def test_encoder_computes_valid_mean_and_logvar(self):
    num_latents = 10
    latent_shape = (num_latents,)
    encoder = vae.Encoder(num_latents=num_latents)
    params = encoder.init(self.key, self.sample_input)

    mean, logvar = encoder.apply(params, self.sample_input)

    # Any real value is a valid mean or logvar, so only the shape is checked.
    self.assertEqual(mean.shape, latent_shape)
    self.assertEqual(logvar.shape, latent_shape)

  def test_decoder_computes_valid_wind_field(self):
    num_latents = 10
    sample_latents = jax.random.normal(self.key, (num_latents,))

    decoder = vae.Decoder()
    params = decoder.init(self.key, sample_latents)

    reconstructed_wind_field = decoder.apply(params, sample_latents)

    # Any real value is a valid reconstruction, so only the shape is checked.
    expected_shape = vae.FieldShape().grid_shape()
    self.assertEqual(reconstructed_wind_field.shape, expected_shape)

  def test_vae_computes_valid_wind_field(self):
    num_latents = 10
    latent_shape = (num_latents,)
    vae_def = vae.WindFieldVAE(num_latents=num_latents)
    params = vae_def.init(self.key, self.sample_input, self.key)

    vae_output = vae_def.apply(params, self.sample_input, self.key)

    self.assertEqual(vae_output.reconstruction.shape,
                     vae.FieldShape().grid_shape())
    encoder_output = vae_output.encoder_output
    self.assertEqual(encoder_output.mean.shape, latent_shape)
    self.assertEqual(encoder_output.logvar.shape, latent_shape)
class WindFieldReservoir(abc.ABC):
  """Interface for wind field datasets that feed training."""

  @abc.abstractmethod
  def get_batch(self, batch_size: int) -> jnp.ndarray:
    """Fetches a batch of fields for training.

    Args:
      batch_size: How many fields to return.

    Returns:
      A jax.numpy array of shape batch_size x wind field dimensions (see
      vae.py).
    """

  @abc.abstractmethod
  def get_eval_batch(self) -> jnp.ndarray:
    """Fetches the batch of fields used for eval.

    Returns:
      A jax.numpy array of shape batch_size x wind field dimensions (see
      vae.py).
    """
15 | 16 | -------------------------------------------------------------------------------- /balloon_learning_environment/metrics/collector.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Base class for metric collectors. 17 | 18 | Each Collector should subclass this base class, as the CollectorDispatcher 19 | object expects objects of type Collector. 20 | 21 | The methods to implement are: 22 | - `get_name`: a unique identifier for subdirectory creation. 23 | - `pre_training`: called once before training begins. 24 | - `step`: called once for each training step. The parameter is an object of 25 | type `StatisticsInstance` which contains the statistics of the current 26 | training step. 27 | - `begin_episode` / `end_episode`: called once at the start / end of each episode; `end_episode` is passed a `StatisticsInstance`. 28 | - `end_training`: called once at the end of training; it takes no arguments.
class Collector(abc.ABC):
  """Base class that every metric collector derives from."""

  def __init__(self,
               base_dir: Optional[str],
               num_actions: int,
               current_episode: int):
    """Stores the configuration and prepares the collector's directory.

    Args:
      base_dir: Root output directory. When given, the collector uses
        `<base_dir>/metrics/<get_name()>` and tries to create it. When None,
        no directory is used.
      num_actions: Stored as `_num_actions` for use by subclasses.
      current_episode: Stored as `current_episode` for use by subclasses.
    """
    metrics_dir = None
    if base_dir is not None:
      metrics_dir = osp.join(base_dir, 'metrics', self.get_name())
    self._base_dir = metrics_dir

    if metrics_dir is not None:
      # Try to create logging directory.
      try:
        tf.io.gfile.makedirs(metrics_dir)
      except tf.errors.PermissionDeniedError:
        # NOTE(review): makedirs is documented to succeed when the directory
        # already exists, so this silently ignores real permission failures.
        # Confirm that is intended.
        pass

    self._num_actions = num_actions
    self.current_episode = current_episode
    self.summary_writer = None  # Should be set by subclass, if needed.

  @abc.abstractmethod
  def get_name(self) -> str:
    """Returns the name used for this collector's metrics subdirectory."""

  @abc.abstractmethod
  def pre_training(self) -> None:
    """Hook with no arguments; presumably run once before training begins."""

  @abc.abstractmethod
  def begin_episode(self) -> None:
    """Hook with no arguments; presumably run at the start of each episode."""

  @abc.abstractmethod
  def step(self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Hook that receives the statistics of one step."""

  @abc.abstractmethod
  def end_episode(
      self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Hook that receives a statistics instance; presumably run per episode end."""

  @abc.abstractmethod
  def end_training(self) -> None:
    """Hook with no arguments; presumably run once when training ends."""

  def has_summary_writer(self) -> bool:
    """Returns True when `summary_writer` has been set to a non-None value."""
    return self.summary_writer is not None
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Class that runs a list of Collectors for metrics reporting. 17 | 18 | This class is what should be called from the main binary and will call each of 19 | the specified collectors for metrics reporting. 20 | 21 | Each metric collector can be further configured via gin bindings. The 22 | constructor for each desired collector should be passed in as a list when 23 | creating this object. All of the collectors are expected to be subclasses of the 24 | `Collector` base class (defined in `collector.py`). 25 | 26 | Example configuration: 27 | ``` 28 | metrics = CollectorDispatcher(base_dir, num_actions, list_of_constructors, current_episode) 29 | metrics.pre_training() 30 | for i in range(training_steps): 31 | ... 32 | metrics.step(statistics) 33 | metrics.end_training() 34 | ``` 35 | 36 | The statistics parameter is of type `statistics_instance.StatisticsInstance`, 37 | and contains the raw performance statistics for the current iteration. All 38 | processing (such as averaging) will be handled by each of the individual 39 | collectors.
BASE_CONFIG_PATH = 'balloon_learning_environment/metrics/configs'
# Maps a collector name to the class that implements it.
AVAILABLE_COLLECTORS = {
    'console': console_collector.ConsoleCollector,
    'pickle': pickle_collector.PickleCollector,
    'tensorboard': tensorboard_collector.TensorboardCollector,
}


CollectorConstructorType = Callable[[str, int, int], collector.Collector]


class CollectorDispatcher:
  """Fans every metrics call out to a list of collectors."""

  def __init__(self, base_dir: Optional[str], num_actions: int,
               collectors: Sequence[CollectorConstructorType],
               current_episode: int):
    """Instantiates one collector per constructor.

    Args:
      base_dir: Passed as the first argument to every collector constructor.
      num_actions: Passed as the second argument to every constructor.
      collectors: The constructors of the collectors to create, in the order
        in which they will be called.
      current_episode: Passed as the third argument to every constructor.
    """
    built = []
    for make_collector in collectors:
      built.append(make_collector(base_dir, num_actions, current_episode))
    self._collectors = built

  def pre_training(self) -> None:
    """Calls `pre_training` on every collector."""
    for target in self._collectors:
      target.pre_training()

  def begin_episode(self) -> None:
    """Calls `begin_episode` on every collector."""
    for target in self._collectors:
      target.begin_episode()

  def step(self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Forwards `statistics` to `step` on every collector."""
    for target in self._collectors:
      target.step(statistics)

  def end_episode(self,
                  statistics: statistics_instance.StatisticsInstance) -> None:
    """Forwards `statistics` to `end_episode` on every collector."""
    for target in self._collectors:
      target.end_episode(statistics)

  def end_training(self) -> None:
    """Calls `end_training` on every collector."""
    for target in self._collectors:
      target.end_training()

  def get_summary_writer(self) -> Optional[tensorboard.SummaryWriter]:
    """Returns the first found instance of a summary_writer, or None."""
    writers = (
        target.summary_writer
        for target in self._collectors
        if target.has_summary_writer()
    )
    return next(writers, None)
def test_with_simple_collector(self):
  """Checks that dispatcher calls reach every registered collector."""
  # One inner list per episode, filled by SimpleCollector as calls arrive.
  logged_stats = []

  class SimpleCollector(collector.Collector):
    """Keeps every statistics object it is handed, grouped by episode."""

    def get_name(self) -> str:
      return 'simple'

    def pre_training(self) -> None:
      """Nothing to record before training."""

    def begin_episode(self) -> None:
      logged_stats.append([])

    def step(self, statistics) -> None:
      logged_stats[-1].append(statistics)

    def end_episode(self, statistics) -> None:
      logged_stats[-1].append(statistics)

    def end_training(self) -> None:
      """Nothing to record after training."""

  # Number of times each collector hook has been invoked.
  counts = dict.fromkeys(
      ('pre_training', 'begin_episode', 'step', 'end_episode',
       'end_training'), 0)

  class CountCollector(collector.Collector):
    """Counts how many times each hook is called."""

    def get_name(self) -> str:
      return 'count'

    def pre_training(self) -> None:
      counts['pre_training'] += 1

    def begin_episode(self) -> None:
      counts['begin_episode'] += 1

    def step(self, statistics) -> None:
      counts['step'] += 1

    def end_episode(self, unused_statistics) -> None:
      counts['end_episode'] += 1

    def end_training(self) -> None:
      counts['end_training'] += 1

  # Drive a full collection loop through the dispatcher.
  num_episodes = 4
  num_steps = 10
  metrics = collector_dispatcher.CollectorDispatcher(
      self._tmpdir, self._na, [SimpleCollector, CountCollector], 0)
  metrics.pre_training()
  expected_stats = []
  for _ in range(num_episodes):
    metrics.begin_episode()
    episode_stats = []
    expected_stats.append(episode_stats)
    for j in range(num_steps):
      step_stat = statistics_instance.StatisticsInstance(
          step=j, action=num_steps-j, reward=j, terminal=False)
      metrics.step(step_stat)
      episode_stats.append(step_stat)
    final_stat = statistics_instance.StatisticsInstance(
        step=num_steps, action=0, reward=num_steps, terminal=True)
    metrics.end_episode(final_stat)
    episode_stats.append(final_stat)
  metrics.end_training()

  expected_counts = {
      'pre_training': 1,
      'begin_episode': num_episodes,
      'step': num_episodes * num_steps,
      'end_episode': num_episodes,
      'end_training': 1,
  }
  self.assertEqual(counts, expected_counts)
  self.assertEqual(expected_stats, logged_stats)
class SimpleCollector(collector.Collector):
  """Smallest concrete Collector: every hook is implemented as a no-op."""

  def get_name(self) -> str:
    return 'simple'

  def pre_training(self) -> None:
    """No-op."""

  def begin_episode(self) -> None:
    """No-op."""

  def step(self, unused_statistics) -> None:
    """No-op."""

  def end_episode(self, unused_statistics) -> None:
    """No-op."""

  def end_training(self) -> None:
    """No-op."""


class CollectorTest(absltest.TestCase):
  """Tests construction of the abstract Collector and a concrete subclass."""

  def setUp(self):
    super().setUp()
    self._tmpdir = flags.FLAGS.test_tmpdir
    self._na = 5

  def test_instantiate_abstract_class(self):
    # Collector has abstract methods, so constructing it directly must fail.
    self.assertRaises(
        TypeError, collector.Collector, self._tmpdir, self._na, 'fail')

  def test_valid_subclass(self):
    instance = SimpleCollector(self._tmpdir, self._na, 0)
    expected_dir = osp.join(self._tmpdir, 'metrics/simple')
    self.assertEqual(instance._base_dir, expected_dir)
    self.assertEqual(self._na, instance._num_actions)
    # The directory is expected to exist once the collector is constructed.
    self.assertTrue(osp.exists(instance._base_dir))

  def test_valid_subclass_with_no_basedir(self):
    instance = SimpleCollector(None, self._na, 0)
    self.assertIsNone(instance._base_dir)
    self.assertEqual(self._na, instance._num_actions)


if __name__ == '__main__':
  absltest.main()
@gin.configurable(allowlist=['fine_grained_logging',
                             'fine_grained_frequency',
                             'save_to_file'])
class ConsoleCollector(collector.Collector):
  """Collector class for reporting statistics to the console.

  Each episode summary (and optionally each step) is logged through
  `absl.logging`. When a base directory is available and `save_to_file` is
  set, the same text is appended to a `console.log` file, one entry per line.
  """

  def __init__(self,
               base_dir: Union[str, None],
               num_actions: int,
               current_episode: int,
               fine_grained_logging: bool = False,
               fine_grained_frequency: int = 1,
               save_to_file: bool = True):
    """Initializes the collector.

    Args:
      base_dir: Base directory forwarded to the parent class, or None.
      num_actions: Number of valid actions; actions must lie in
        [0, num_actions).
      current_episode: Index of the first episode to be reported.
      fine_grained_logging: Whether to also report individual steps.
      fine_grained_frequency: Report a step when its `step` field is a
        multiple of this value (only used with fine-grained logging).
      save_to_file: Whether to mirror the output to `console.log`.
    """
    super().__init__(base_dir, num_actions, current_episode)
    if self._base_dir is not None and save_to_file:
      self._log_file = osp.join(self._base_dir, 'console.log')
    else:
      self._log_file = None
    self._fine_grained_logging = fine_grained_logging
    self._fine_grained_frequency = fine_grained_frequency

  def get_name(self) -> str:
    return 'console'

  def pre_training(self) -> None:
    """Opens the log file for writing, if file output is enabled."""
    if self._log_file is not None:
      self._log_file_writer = tf.io.gfile.GFile(self._log_file, 'w')

  def begin_episode(self) -> None:
    """Resets the per-episode action histogram and reward total."""
    self._action_counts = np.zeros(self._num_actions)
    self._current_episode_reward = 0.0

  def _record_action(self, action: int) -> None:
    """Validates `action` and adds it to the episode's action histogram.

    Raises:
      ValueError: if `action` is outside [0, num_actions).
    """
    if action < 0 or action >= self._num_actions:
      raise ValueError(f'Invalid action: {action}')
    self._action_counts[action] += 1

  def step(self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Accumulates one step and optionally reports it.

    Raises:
      ValueError: if the action is outside [0, num_actions).
    """
    self._current_episode_reward += statistics.reward
    self._record_action(statistics.action)
    if (self._fine_grained_logging
        and statistics.step % self._fine_grained_frequency == 0):
      step_string = (
          f'Step: {statistics.step}, action: {statistics.action}, '
          f'reward: {statistics.reward}, terminal: {statistics.terminal}\n')
      logging.info(step_string)
      if self._log_file is not None:
        self._log_file_writer.write(step_string)

  def end_episode(self,
                  statistics: statistics_instance.StatisticsInstance) -> None:
    """Accumulates the final step and reports the episode summary.

    Raises:
      ValueError: if the action is outside [0, num_actions).
    """
    self._current_episode_reward += statistics.reward
    # Apply the same range check as step(): without it a negative action
    # would be counted in the wrong bin through negative indexing.
    self._record_action(statistics.action)
    action_distribution = self._action_counts / np.sum(self._action_counts)

    episode_string = (
        f'Episode {self.current_episode}: '
        f'reward: {self._current_episode_reward:07.2f}, '  # format: 0000.00
        f'episode length: {statistics.step}, '
        f'action distribution: {action_distribution}')
    logging.info(episode_string)

    if self._log_file is not None:
      # Terminate the entry so consecutive summaries (and any step lines that
      # follow) do not run together on one line of the log file.
      self._log_file_writer.write(episode_string + '\n')  # pytype: disable=attribute-error  # trace-all-classes

    self.current_episode += 1

  def end_training(self) -> None:
    """Closes the log file, if file output is enabled."""
    if self._log_file is not None:
      self._log_file_writer.close()  # pytype: disable=attribute-error  # trace-all-classes
class PickleCollector(collector.Collector):
  """Collector class for saving each episode's statistics to a pickle file.

  All statistics received during an episode are buffered in memory. At the
  end of the episode they are written to `pickle_<episode>.pkl` under the
  collector's base directory.
  """

  def __init__(self,
               base_dir: str,
               num_actions: int,
               current_episode: int):
    """Initializes the collector.

    Args:
      base_dir: Base directory forwarded to the parent class. Required.
      num_actions: Number of actions, forwarded to the parent class.
      current_episode: Index used to name the first pickle file written.

    Raises:
      ValueError: if `base_dir` is None.
    """
    if base_dir is None:
      raise ValueError('Must specify a base directory for PickleCollector.')
    super().__init__(base_dir, num_actions, current_episode)

  def get_name(self) -> str:
    return 'pickle'

  def pre_training(self) -> None:
    pass

  def begin_episode(self) -> None:
    """Starts a fresh buffer for the new episode."""
    self._statistics = []

  def step(self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Buffers the statistics of one step."""
    self._statistics.append(statistics)

  def end_episode(self,
                  statistics: statistics_instance.StatisticsInstance) -> None:
    """Buffers the final statistics and writes the episode to disk."""
    self._statistics.append(statistics)
    pickle_file = osp.join(self._base_dir,
                           f'pickle_{self.current_episode}.pkl')
    # Pickle output is a byte stream, so open the file in binary mode.
    with tf.io.gfile.GFile(pickle_file, 'wb') as f:
      pickle.dump(self._statistics, f, protocol=pickle.HIGHEST_PROTOCOL)
    self.current_episode += 1

  def end_training(self) -> None:
    pass
@dataclasses.dataclass()
class StatisticsInstance:
  """Raw statistics for one environment step, as handed to the collectors.

  Field order is part of the interface: instances may be constructed
  positionally as `StatisticsInstance(step, action, reward, terminal)`.
  """
  # Step index reported for this transition.
  step: int
  # Index of the action taken.
  action: int
  # Reward received for this step.
  reward: float
  # Whether this step is flagged as terminal.
  terminal: bool
@gin.configurable(allowlist=['fine_grained_logging',
                             'fine_grained_frequency'])
class TensorboardCollector(collector.Collector):
  """Collector that exports training statistics as Tensorboard scalars."""

  def __init__(self,
               base_dir: str,
               num_actions: int,
               current_episode: int,
               fine_grained_logging: bool = False,
               fine_grained_frequency: int = 1):
    """Initializes the collector and its summary writer.

    Args:
      base_dir: Base directory forwarded to the parent class. Must be a str.
      num_actions: Number of actions, forwarded to the parent class.
      current_episode: Index of the first episode to be reported.
      fine_grained_logging: Whether to also export per-step rewards.
      fine_grained_frequency: Export a step when the global step count is a
        multiple of this value (only used with fine-grained logging).

    Raises:
      ValueError: if `base_dir` is not a string.
    """
    if not isinstance(base_dir, str):
      raise ValueError(
          'Must specify a base directory for TensorboardCollector.')
    super().__init__(base_dir, num_actions, current_episode)
    self._fine_grained_frequency = fine_grained_frequency
    self._fine_grained_logging = fine_grained_logging
    self.summary_writer = tensorboard.SummaryWriter(self._base_dir)

  def get_name(self) -> str:
    return 'tensorboard'

  def pre_training(self) -> None:
    """Resets the global step counter."""
    # TODO(joshgreaves): This is wrong if we are starting from a checkpoint.
    self._global_step = 0

  def begin_episode(self) -> None:
    """Resets the running reward total for the new episode."""
    self._episode_reward = 0.0

  def _log_fine_grained_statistics(
      self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Writes one step's reward at the current global step and flushes."""
    writer = self.summary_writer
    writer.scalar('Train/FineGrainedReward', statistics.reward,
                  self._global_step)  # pytype: disable=attribute-error  # trace-all-classes
    writer.flush()

  def step(self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Accumulates one step, exporting its reward when it is due."""
    # The modulus is evaluated only when fine-grained logging is enabled.
    export_now = (
        self._fine_grained_logging and
        self._global_step % self._fine_grained_frequency == 0)
    if export_now:
      self._log_fine_grained_statistics(statistics)
    self._global_step += 1
    self._episode_reward += statistics.reward

  def end_episode(self,
                  statistics: statistics_instance.StatisticsInstance) -> None:
    """Accumulates the final step and exports the episode summaries."""
    if self._fine_grained_logging:
      self._log_fine_grained_statistics(statistics)
    self._episode_reward += statistics.reward

    writer = self.summary_writer
    episode_index = self.current_episode
    writer.scalar('Train/EpisodeReward', self._episode_reward, episode_index)
    writer.scalar('Train/EpisodeLength', statistics.step, episode_index)
    writer.flush()

    self._global_step += 1  # pytype: disable=attribute-error  # trace-all-classes
    self.current_episode += 1

  def end_training(self) -> None:
    pass
_MODEL_ROOT = 'balloon_learning_environment/models/'
_OFFLINE_SKIES22_RELATIVE_PATH = os.path.join(
    _MODEL_ROOT, 'offlineskies22_decoder.msgpack')
_PERCIATELLI44_RELATIVE_PATH = os.path.join(
    _MODEL_ROOT, 'perciatelli44.pb')


def _load_model_bytes(path: Optional[str],
                      resource_name: str,
                      relative_path: str,
                      short_name: str,
                      display_name: str) -> bytes:
  """Reads a bundled model file, trying each known location in order.

  Args:
    path: Explicit path to read from. If given, no other location is tried.
    resource_name: File name inside the `balloon_learning_environment.models`
      package.
    relative_path: Path of the file relative to the project root.
    short_name: Model name used in the "not found at path" error message.
    display_name: Model name used in the "unable to load" error message.

  Returns:
    The contents of the first location that could be read.

  Raises:
    ValueError: if `path` is given but the file is not found there.
    RuntimeError: if no path is given and no default location has the file.
  """
  # Attempt 1: Load from path, if specified.
  # If a path is specified, we expect it is a good path.
  if path is not None:
    try:
      with tf.io.gfile.GFile(path, 'rb') as f:
        return f.read()
    except tf.errors.NotFoundError as e:
      raise ValueError(f'{short_name} checkpoint not found at {path}') from e

  # Attempt 2: Load from location expected in the built wheel.
  try:
    with resources.open_binary('balloon_learning_environment.models',
                               resource_name) as f:
      return f.read()
  except FileNotFoundError:
    pass

  # Attempt 3: Load from the path relative to the source root.
  # GFile reports a missing file as tf.errors.NotFoundError. The builtin
  # FileNotFoundError is still accepted so that neither loader stops
  # handling an exception it handled before.
  try:
    with tf.io.gfile.GFile(relative_path, 'rb') as f:
      return f.read()
  except (tf.errors.NotFoundError, FileNotFoundError):
    pass

  raise RuntimeError(
      f'Unable to load {display_name} checkpoint from the expected locations.')


@gin.configurable
def load_offlineskies22(path: Optional[str] = None) -> bytes:
  """Loads offlineskies22 serialized wind VAE parameters.

  There are three places this function looks:
    1. At the path specified, if one is specified.
    2. Under the models package using importlib.resources. It should be
      found there if the code was installed with pip.
    3. Relative to the project root. It should be found there if running
      from a freshly cloned repo.

  Args:
    path: An optional path to load the VAE weights from.

  Returns:
    The serialized VAE weights as bytes.

  Raises:
    ValueError: if a path is specified but the weights can't be loaded.
    RuntimeError: if the weights couldn't be found in any of the
      specified locations.
  """
  return _load_model_bytes(
      path,
      resource_name='offlineskies22_decoder.msgpack',
      relative_path=_OFFLINE_SKIES22_RELATIVE_PATH,
      short_name='offlineskies22',
      display_name='wind VAE')


@gin.configurable
def load_perciatelli44(path: Optional[str] = None) -> bytes:
  """Loads Perciatelli44.pb as bytes.

  There are three places this function looks:
    1. At the path specified, if one is specified.
    2. Under the models package using importlib.resources. It should be
      found there if the code was installed with pip.
    3. Relative to the project root. It should be found there if running
      from a freshly cloned repo.

  Args:
    path: An optional path to load the Perciatelli44 model from.

  Returns:
    The contents of the Perciatelli44 model file as bytes.

  Raises:
    ValueError: if a path is specified but the model can't be loaded.
    RuntimeError: if the model couldn't be found in any of the
      specified locations.
  """
  return _load_model_bytes(
      path,
      resource_name='perciatelli44.pb',
      relative_path=_PERCIATELLI44_RELATIVE_PATH,
      short_name='perciatelli44',
      display_name='Perciatelli44')
class ModelsTest(parameterized.TestCase):
  """Tests the lookup order of both model loaders.

  Every test is parameterized over `load_offlineskies22` and
  `load_perciatelli44`, so each loader is exercised by each scenario.
  """

  @parameterized.named_parameters(
      dict(testcase_name='openskies22', load_fn=models.load_offlineskies22),
      dict(testcase_name='perciatelli44', load_fn=models.load_perciatelli44))
  def test_load_with_specified_path_loads_data(self, load_fn):
    # Write some fake data in a tmpdir.
    tmpfile = self.create_tempfile()
    fake_content = b'fake content from specified path'
    with open(tmpfile, 'wb') as f:
      f.write(fake_content)

    result = load_fn(tmpfile.full_path)

    self.assertEqual(result, fake_content)

  @parameterized.named_parameters(
      dict(testcase_name='openskies22', load_fn=models.load_offlineskies22),
      dict(testcase_name='perciatelli44', load_fn=models.load_perciatelli44))
  def test_load_with_wrong_path_fails(self, load_fn):
    with self.assertRaises(ValueError):
      load_fn('this_is_not_a_valid_path')

  @parameterized.named_parameters(
      dict(testcase_name='openskies22', load_fn=models.load_offlineskies22),
      dict(testcase_name='perciatelli44', load_fn=models.load_perciatelli44))
  @mock.patch.object(resources, 'open_binary', autospec=True)
  def test_load_uses_importlib_if_no_path_is_specified(self,
                                                       mock_open,
                                                       load_fn):
    fake_content = b'fake content from importlib.resources'
    mock_open.return_value = io.BytesIO(fake_content)

    result = load_fn()

    mock_open.assert_called_once()
    self.assertEqual(result, fake_content)

  @parameterized.named_parameters(
      dict(testcase_name='openskies22', load_fn=models.load_offlineskies22),
      dict(testcase_name='perciatelli44', load_fn=models.load_perciatelli44))
  @mock.patch.object(tf.io.gfile, 'GFile', autospec=True)
  def test_load_uses_default_path_as_last_resort(self,
                                                 mock_gfile,
                                                 load_fn):
    fake_content = b'fake content from default path'
    mock_gfile.return_value = io.BytesIO(fake_content)

    result = load_fn()

    mock_gfile.assert_called_once()
    self.assertEqual(result, fake_content)

  @parameterized.named_parameters(
      dict(testcase_name='openskies22', load_fn=models.load_offlineskies22),
      dict(testcase_name='perciatelli44', load_fn=models.load_perciatelli44))
  def test_load_raises_runtime_error_if_not_found(self,
                                                  load_fn):
    with self.assertRaises(RuntimeError):
      load_fn()


if __name__ == '__main__':
  absltest.main()
flags.DEFINE_string('agent', 'dqn', 'Type of agent to create.')
flags.DEFINE_string('env_name', 'BalloonLearningEnvironment-v0',
                    'Name of environment to create.')
flags.DEFINE_integer('num_iterations', 200, 'Number of episodes to train for.')
flags.DEFINE_integer('max_episode_length', 960,
                     'Maximum number of steps per episode. Assuming 2 days, '
                     'with each step lasting 3 minutes.')
flags.DEFINE_string('base_dir', None,
                    'Directory where to store statistics/images.')
flags.DEFINE_integer(
    'run_number', 1,
    'When running multiple agents in parallel, this number '
    'differentiates between the runs. It is appended to base_dir.')
flags.DEFINE_string(
    'wind_field', 'generative',
    'The wind field type to use. See the _WIND_FIELDS dict below for options.')
flags.DEFINE_string('agent_gin_file', None,
                    'Gin file for agent configuration.')
flags.DEFINE_multi_string('collectors', ['console'],
                          'Collectors to include in metrics collection.')
flags.DEFINE_multi_string('gin_bindings', [],
                          'Gin bindings to override default values.')
flags.DEFINE_string(
    'renderer', None,
    'The renderer to use. Note that it is fastest to have this set to None.')
flags.DEFINE_integer(
    'render_period', 10,
    'The period to render with. Only has an effect if renderer is not None.')
flags.DEFINE_integer(
    'episodes_per_iteration', 50,
    'The number of episodes to run in one iteration. Checkpointing occurs '
    'at the end of each iteration.')
flags.mark_flag_as_required('base_dir')
FLAGS = flags.FLAGS


# Maps each accepted --renderer value to the renderer class to instantiate.
_RENDERERS = {
    'matplotlib': matplotlib_renderer.MatplotlibRenderer,
}


def main(_) -> None:
  """Builds the environment and agent from flags, trains, saves an image.

  The training loop receives a run-specific directory built from
  `--base_dir`, `--agent` and `--run_number`. When the final render of the
  environment is a numpy array it is saved as `balloon_path.png` directly
  under `--base_dir`.
  """
  # Prepare metric collector gin files and constructors.
  collector_constructors = train_lib.get_collector_data(FLAGS.collectors)
  run_helpers.bind_gin_variables(
      FLAGS.agent, FLAGS.agent_gin_file, FLAGS.gin_bindings)

  renderer = (
      None if FLAGS.renderer is None else _RENDERERS[FLAGS.renderer]())

  env = gym.make(
      FLAGS.env_name,
      wind_field_factory=run_helpers.get_wind_field_factory(FLAGS.wind_field),
      renderer=renderer)

  num_actions = env.action_space.n
  agent = run_helpers.create_agent(
      FLAGS.agent,
      num_actions,
      observation_shape=env.observation_space.shape)

  run_dir = osp.join(FLAGS.base_dir, FLAGS.agent, str(FLAGS.run_number))
  train_lib.run_training_loop(
      run_dir,
      env,
      agent,
      FLAGS.num_iterations,
      FLAGS.max_episode_length,
      collector_constructors,
      render_period=FLAGS.render_period,
      episodes_per_iteration=FLAGS.episodes_per_iteration)

  if FLAGS.base_dir is None:
    return
  frame = env.render(mode='rgb_array')
  if isinstance(frame, np.ndarray):
    matplotlib.image.imsave(
        osp.join(FLAGS.base_dir, 'balloon_path.png'), frame)


if __name__ == '__main__':
  app.run(main)
def main(_) -> None:
  """Runs Acme QrDQN training, alternating one train and one eval episode."""
  episode_length = FLAGS.max_episode_length
  # NOTE(review): the boolean presumably selects the evaluation variant of the
  # environment -- confirm against acme_utils.create_env.
  train_env = acme_utils.create_env(False, episode_length)
  evaluation_env = acme_utils.create_env(True, episode_length)
  spec = acme.make_environment_spec(train_env)

  dqn_parts = acme_utils.create_dqn({})
  (builder, dqn_config, make_networks, make_behavior_policy,
   make_eval_policy) = dqn_parts
  networks = make_networks(spec)

  # Parent counter shared by the learner, actor and evaluator counters below.
  root_counter = counting.Counter(time_delta=0.)

  behavior_policy = make_behavior_policy(networks)
  learner_counter = counting.Counter(root_counter, 'learner')
  learning_agent = local_layout.LocalLayout(
      seed=FLAGS.seed,
      environment_spec=spec,
      builder=builder,
      networks=networks,
      policy_network=behavior_policy,
      batch_size=dqn_config.batch_size,
      prefetch_size=4,
      counter=learner_counter)

  # The evaluation actor reads its variables from the learning agent.
  evaluation_actor = builder.make_actor(
      jax.random.PRNGKey(0),
      policy=make_eval_policy(networks),
      environment_spec=spec,
      variable_source=learning_agent)

  train_logger = loggers.make_default_logger('actor')
  eval_logger = loggers.make_default_logger('evaluator')

  train_loop = acme.EnvironmentLoop(
      train_env,
      learning_agent,
      logger=train_logger,
      counter=counting.Counter(root_counter, 'actor', time_delta=0.))
  evaluation_loop = acme.EnvironmentLoop(
      evaluation_env,
      evaluation_actor,
      logger=eval_logger,
      counter=counting.Counter(root_counter, 'evaluator', time_delta=0.))

  episodes_done = 0
  while episodes_done < FLAGS.num_episodes:
    train_loop.run(1)
    evaluation_loop.run(1)
    episodes_done += 1
# --- Physics Constants ---

GRAVITY: float = 9.80665  # [m/s^2]
NUM_SECONDS_PER_HOUR: int = 3_600  # [s]
NUM_SECONDS_PER_DAY: int = 86_400  # [s]
UNIVERSAL_GAS_CONSTANT: float = 8.3144621  # [J/(mol.K)]
DRY_AIR_MOLAR_MASS: float = 0.028964922481160  # Dry Air. [kg/mol]
HE_MOLAR_MASS: float = 0.004002602  # Helium. [kg/mol]
# Universal gas constant divided by the molar mass of dry air.
DRY_AIR_SPECIFIC_GAS_CONSTANT: float = (
    UNIVERSAL_GAS_CONSTANT / DRY_AIR_MOLAR_MASS)  # [J/(kg.K)]


# --- RL constants ---
# Amount of time that elapses between agent steps.
AGENT_TIME_STEP: dt.timedelta = dt.timedelta(minutes=3)
# Pressure limits for the Perciatelli features.
# NOTE(review): units are presumably Pascals -- confirm against the callers.
PERCIATELLI_PRESSURE_RANGE_MIN: int = 5000
PERCIATELLI_PRESSURE_RANGE_MAX: int = 14000
def create_agent(agent_name: str, num_actions: int,
                 observation_shape: Sequence[int]) -> base_agent.Agent:
  """Builds the agent registered under the given name.

  Args:
    agent_name: The name used to look up the constructor in agent_registry.
    num_actions: Passed positionally to the agent constructor.
    observation_shape: Passed to the agent constructor as observation_shape.

  Returns:
    The object produced by the registered agent constructor.
  """
  constructor = agent_registry.agent_constructor(agent_name)
  return constructor(num_actions, observation_shape=observation_shape)
class SamplingTest(parameterized.TestCase):
  """Checks determinism and value ranges of the sampling helpers."""

  def setUp(self):
    super().setUp()
    # Deterministic PRNG state. Tests MUST NOT rely on a specific seed.
    self.prng_key = jax.random.PRNGKey(123)
    self.atmosphere = standard_atmosphere.Atmosphere(jax.random.PRNGKey(0))

  def test_sample_location_with_seed_gives_deterministic_lat_lng(self):
    # Sampling twice with the same key must give the same location.
    location1 = sampling.sample_location(self.prng_key)
    location2 = sampling.sample_location(self.prng_key)

    self.assertEqual(location1, location2)

  def test_sample_location_gives_valid_lat_lng(self):
    latlng = sampling.sample_location(self.prng_key)

    # We only allow locations near the equator
    self.assertBetween(latlng.lat().degrees, -10.0, 10.0)
    # We don't allow locations near the international date line
    self.assertBetween(latlng.lng().degrees, -175.0, 175.0)

  def test_sample_time_with_seed_gives_deterministic_time(self):
    # Sampling twice with the same key must give the same time.
    t1 = sampling.sample_time(self.prng_key)
    t2 = sampling.sample_time(self.prng_key)

    self.assertEqual(t1, t2)

  def test_sample_time_gives_time_within_range(self):
    # Pick a 1 hour segment to give a small valid range for testing
    begin_range = units.datetime(year=2020, month=1, day=1, hour=1)
    end_range = units.datetime(year=2020, month=1, day=1, hour=2)
    t = sampling.sample_time(
        self.prng_key, begin_range=begin_range, end_range=end_range)

    self.assertBetween(t, begin_range, end_range)

  def test_sample_pressure_with_seed_gives_deterministic_pressure(self):
    # Sampling twice with the same key must give the same pressure.
    p1 = sampling.sample_pressure(self.prng_key, self.atmosphere)
    p2 = sampling.sample_pressure(self.prng_key, self.atmosphere)

    self.assertEqual(p1, p2)

  def test_sample_pressure_gives_pressure_within_range(self):
    p = sampling.sample_pressure(self.prng_key, self.atmosphere)
    # These bounds equal the PERCIATELLI_PRESSURE_RANGE_* values in
    # utils/constants.py.
    self.assertBetween(p, 5000, 14000)

  def test_sample_upwelling_infrared_is_within_range(self):
    # Uses the default distribution type.
    ir = sampling.sample_upwelling_infrared(self.prng_key)
    self.assertBetween(ir, 100.0, 350.0)

  @parameterized.named_parameters(
      dict(testcase_name='logit_normal', distribution_type='logit_normal'),
      dict(testcase_name='inverse_lognormal',
           distribution_type='inverse_lognormal'))
  def test_sample_upwelling_infrared_is_within_range_nondefault(
      self, distribution_type):
    # Same range check as above, for each explicitly named distribution.
    ir = sampling.sample_upwelling_infrared(self.prng_key,
                                            distribution_type=distribution_type)
    self.assertBetween(ir, 100.0, 350.0)

  def test_sample_upwelling_infrared_invalid_distribution_type(self):
    # An unrecognised distribution name must be rejected.
    with self.assertRaises(ValueError):
      sampling.sample_upwelling_infrared(self.prng_key,
                                         distribution_type='invalid')
def calculate_latlng_from_offset(center_latlng: s2.LatLng,
                                 x: units.Distance,
                                 y: units.Distance) -> s2.LatLng:
  """Calculates a new lat lng given an origin and x y offsets.

  The offsets are converted to a heading (0 = North, measured towards
  positive x) and an angular distance on the spherical Earth, and the
  destination point is computed along that heading.

  Args:
    center_latlng: The starting latitude and longitude.
    x: The offset from center_latlng along the East-West direction (positive x
      increases longitude).
    y: The offset from center_latlng along the North-South direction (positive
      y increases latitude).

  Returns:
    A new latlng that is the specified distance from the start latlng.
  """
  # x and y are swapped to give heading with 0 degrees = North.
  # This is equivalent to pi / 2 - atan2(y, x).
  heading = math.atan2(x.km, y.km)  # In radians.
  # Arc length divided by radius gives the angle subtended at the centre.
  angle = units.relative_distance(x, y) / _EARTH_RADIUS  # In radians.

  cos_angle = math.cos(angle)
  sin_angle = math.sin(angle)
  sin_from_lat = math.sin(center_latlng.lat().radians)
  cos_from_lat = math.cos(center_latlng.lat().radians)

  sin_lat = (cos_angle * sin_from_lat +
             sin_angle * cos_from_lat * math.cos(heading))
  d_lng = math.atan2(sin_angle * cos_from_lat * math.sin(heading),
                     cos_angle - sin_from_lat * sin_lat)

  # Floating point error can push sin_lat marginally outside [-1, 1] near the
  # poles, which would make math.asin raise ValueError. Clamp *before* asin;
  # the result is then guaranteed to lie in [-pi / 2, pi / 2].
  new_lat = math.asin(min(max(sin_lat, -1.0), 1.0))
  new_lng = center_latlng.lng().radians + d_lng

  return s2.LatLng.from_radians(new_lat, new_lng).normalized()
class SphericalGeometryTest(parameterized.TestCase):
  """Tests for spherical_geometry.calculate_latlng_from_offset."""

  # The 'north_pole' and 'south_pole' cases use latitudes near, not at, the
  # poles so that adding one degree of latitude stays within range.
  @parameterized.named_parameters(
      dict(testcase_name='equator', lat=0.0, lng=45.0),
      dict(testcase_name='north_pole', lat=80.0, lng=-135.0),
      dict(testcase_name='south_pole', lat=-85.0, lng=92.0),)
  def test_offset_latlng_gives_1_degree_latitude_per_111km_everywhere(
      self, lat: float, lng: float):
    center_latlng = s2.LatLng.from_degrees(lat, lng)
    x = units.Distance(km=0.0)
    y = units.Distance(km=111.0)

    new_latlng = spherical_geometry.calculate_latlng_from_offset(
        center_latlng, x, y)

    self.assertAlmostEqual(new_latlng.lat().degrees, lat + 1.0, places=2)

  def test_offset_latlng_gives_1_degree_longitude_per_111km_at_equator(self):
    center_latlng = s2.LatLng.from_degrees(0.0, 0.0)
    x = units.Distance(km=111.0)  # About one degree longitude at equator.
    y = units.Distance(km=0.0)

    new_latlng = spherical_geometry.calculate_latlng_from_offset(
        center_latlng, x, y)

    self.assertAlmostEqual(new_latlng.lng().degrees, 1.0, places=2)

  def test_offset_latlng_gives_larger_longitude_change_away_from_equator(self):
    center_latlng = s2.LatLng.from_degrees(45.0, 0.0)
    x = units.Distance(km=111.0)  # > 1 degree longitude away from equator.
    y = units.Distance(km=0.0)

    new_latlng = spherical_geometry.calculate_latlng_from_offset(
        center_latlng, x, y)

    # The change in longitude should be greater than at the equator, but not
    # too much greater. These numbers are somewhat arbitrary - they are more
    # of a sanity check. The only alternative to this is to directly
    # re-write the formula.
    self.assertBetween(new_latlng.lng().degrees, 1.25, 1.75)

  def test_offset_latlng_wraps_around_north_pole(self):
    center_latlng = s2.LatLng.from_degrees(89.0, -90.0)
    x = units.Distance(km=0.0)
    y = units.Distance(km=222.0)  # About 2 degrees latitude.

    new_latlng = spherical_geometry.calculate_latlng_from_offset(
        center_latlng, x, y)

    # We have gone over the North pole, so latitude should be 89 again,
    # but we have gone half way around the world longitudinally.
    self.assertAlmostEqual(new_latlng.lat().degrees, 89.0, places=2)
    self.assertAlmostEqual(new_latlng.lng().degrees, 90.0, places=2)

  @parameterized.named_parameters(
      dict(
          testcase_name='west_to_east', lng=179.0, degrees=2.0,
          expected=-179.0),
      dict(
          testcase_name='east_to_west', lng=-179.0, degrees=-2.0,
          expected=179.0),
      dict(
          testcase_name='multiple_times',
          lng=0.0,
          degrees=1440,
          expected=0.0))
  def test_offset_latlng_wraps_around_longitude(self, lng: float,
                                                degrees: float,
                                                expected: float):
    # `degrees` is the longitude change to apply along the equator; it is
    # converted to an equivalent x offset below.
    center_latlng = s2.LatLng.from_degrees(0.0, lng)
    # arc_length = radius * angle_radians.
    x = spherical_geometry._EARTH_RADIUS * math.radians(degrees)
    y = units.Distance(km=0.0)

    new_latlng = spherical_geometry.calculate_latlng_from_offset(
        center_latlng, x, y)

    self.assertAlmostEqual(new_latlng.lng().degrees, expected)
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Common transforms.""" 17 | 18 | import typing 19 | from typing import Union 20 | 21 | import numpy as np 22 | 23 | 24 | def _contains_negative_values(x: Union[float, np.ndarray]) -> bool: 25 | if isinstance(x, np.ndarray): 26 | return (x < 0).any() 27 | else: 28 | return x < 0 29 | 30 | 31 | @typing.overload 32 | def linear_rescale_with_extrapolation(x: np.ndarray, 33 | vmin: float, 34 | vmax: float) -> np.ndarray: 35 | ... 36 | 37 | 38 | @typing.overload 39 | def linear_rescale_with_extrapolation(x: float, 40 | vmin: float, 41 | vmax: float) -> float: 42 | ... 
def linear_rescale_with_extrapolation(x,
                                      vmin,
                                      vmax):
  """Maps x linearly so that vmin -> 0 and vmax -> 1, without clipping.

  Values of x outside [vmin, vmax] are extrapolated, so the result may lie
  outside [0, 1].

  Raises:
    ValueError: If vmax is not strictly greater than vmin.
  """
  if vmax <= vmin:
    raise ValueError('Interval must be such that vmax > vmin.')
  offset = x - vmin
  span = vmax - vmin
  return offset / span
def undo_squash_to_unit_interval(x: float, constant: float) -> float:
  """Computes the input value of squash_to_unit_interval given the output.

  Inverts y = v / (v + constant), giving v = (y * constant) / (1 - y).

  Args:
    x: A squashed value, which must lie in [0, 1).
    constant: The squash constant that was used; must be greater than zero.

  Returns:
    The non-negative value that squashes to x.

  Raises:
    ValueError: If constant is not positive or x is outside [0, 1).
  """
  if constant <= 0:
    raise ValueError('Squash constant must be greater than zero.')
  # A chained comparison such as `0 > x >= 1` can never be true, so the range
  # has to be tested as "not inside [0, 1)". This rejects x == 1 (division by
  # zero below) and negative x (no non-negative pre-image).
  if not 0 <= x < 1:
    raise ValueError('Undo squash can only be performed on a value in [0, 1).')
  return (x * constant) / (1 - x)
@jax.jit
def wind_field_speeds(wind_field: jnp.ndarray) -> jnp.ndarray:
  """Returns the wind speed throughout the field.

  Args:
    wind_field: An array whose last axis holds the (u, v) wind components at
      index 0 and 1. Typically this is a 4D grid of wind vectors, i.e. a 5D
      array, but any number of leading grid axes is accepted.

  Returns:
    An array of speeds at the same grid points, i.e. with the shape of
    wind_field minus its last axis.
  """
  # Ellipsis indexing selects the components on the last axis regardless of
  # how many grid axes precede it.
  u = wind_field[..., 0]
  v = wind_field[..., 1]
  return jnp.sqrt(u * u + v * v)
class WindTest(absltest.TestCase):
  """Tests for the helpers in utils/wind.py."""

  # NOTE: 'postive' in the method name below is a typo for 'positive'.
  def test_is_station_keeping_winds_postive_case(self):
    # Winds pointing in four opposing directions surround the origin.
    self.assertTrue(wind.is_station_keeping_winds(np.array([
        [1, 0], [-1, 0], [0, -1], [0, 1]])))

  def test_is_station_keeping_winds_negative_case(self):
    # All winds have positive u and v, so the origin is outside their hull.
    self.assertFalse(wind.is_station_keeping_winds(np.array([
        [1, 1.2], [2, 3.2], [10.2, 33.4]])))

  def test_wind_field_speeds_works(self):
    # A constant (2.1, 3.2) field has the same speed at every grid point.
    u = 2.1 * jnp.ones((5, 6, 7, 8))
    v = 3.2 * jnp.ones((5, 6, 7, 8))
    field = jnp.stack([u, v], axis=-1)
    self.assertAlmostEqual(wind.wind_field_speeds(field).mean().item(),
                           jnp.sqrt(2.1 * 2.1 + 3.2 * 3.2).item(),
                           places=3)
    # The component axis is dropped; the grid shape is preserved.
    self.assertEqual(wind.wind_field_speeds(field).shape, (5, 6, 7, 8))

  def test_mean_speed_in_wind_field_works(self):
    # For a constant field the mean speed equals the speed of any one vector.
    u = 2.1 * jnp.ones((5, 6, 7, 8))
    v = 3.2 * jnp.ones((5, 6, 7, 8))
    field = jnp.stack([u, v], axis=-1)
    self.assertAlmostEqual(wind.mean_speed_in_wind_field(field).item(),
                           jnp.sqrt(2.1 * 2.1 + 3.2 * 3.2).item(),
                           places=3)
makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # BLE Documentation 2 | 3 | To build the documentation, make sure you have the dependencies installed: 4 | 5 | ``` 6 | pip install -r requirements.txt 7 | ``` 8 | 9 | If you are building locally you'll also need the sphinx_rtd_theme: 10 | 11 | ``` 12 | pip install sphinx_rtd_theme 13 | ``` 14 | 15 | Then: 16 | 17 | ``` 18 | make html 19 | ``` 20 | 21 | To clean the documentation: 22 | 23 | ``` 24 | make clean 25 | ``` 26 | -------------------------------------------------------------------------------- /docs/about.rst: -------------------------------------------------------------------------------- 1 | About the Balloon Learning Environment 2 | ====================================== 3 | 4 | This page gives a simple background into the problem of flying 5 | stratospheric balloons and station keeping. This is intended to 6 | be an introduction, and doesn't discuss all the nuances of the simulator 7 | or real-world complexities. 8 | 9 | Stratospheric Balloons 10 | ###################### 11 | 12 | The type of balloon that we consider in the BLE is pictured below. 
13 | These balloons have an outer envelope and an inner envelope (ballonet). 14 | The ballonet is filled with a buoyant gas (for example, helium). 15 | The outer envelope contains air which acts as ballast. 16 | An Altitude Control System (ACS) is able to pump air in or out of the envelope. 17 | When air is pumped into the balloon the average density 18 | of the gases in the envelope increases, which in turn makes the balloon 19 | descend. On the other hand, if air is pumped out of the balloon then the 20 | average gas density decreases, resulting in the balloon ascending. 21 | 22 | The balloons are also equipped with a battery to power the ACS, and solar 23 | panels to recharge the batteries. Power is used by communication systems 24 | on the balloon, so battery power is constantly draining when the solar panels 25 | aren't in use. This means that the balloons in the BLE are constrained at 26 | night—they cannot constantly use the ACS without running out of power. 27 | The energy required for running the ACS is asymmetric; 28 | the energy needed to release air from the envelope (ascend) is negligible. 29 | 30 | .. image:: imgs/balloon_schematic.jpg 31 | 32 | Navigating a Windfield 33 | ###################### 34 | 35 | The balloons in the BLE have no way of moving themselves laterally. They 36 | are only capable of moving up, down, or staying at the current altitude. 37 | To navigate they must "surf" the wind to get where they need to go. The 38 | balloons are flown in a 4d windfield, where each different x, y, altitude, time 39 | position gives a different wind vector. A balloon must learn how to navigate 40 | this windfield to achieve its goal. 41 | 42 | .. image:: imgs/wind_field.gif 43 | 44 | Station-Keeping 45 | ############### 46 | 47 | The goal of station-keeping is to remain within a fixed distance of a 48 | ground station. This distance is only measured in the x, y plane, i.e. the 49 | altitude of the balloon is not taken into account.
In the BLE, the task 50 | is to keep a balloon within 50km of the ground station. We call the proportion 51 | of time that a balloon remains within 50km of the ground station TWR50 (for 52 | Time Within a Radius of 50km). In the BLE, each episode lasts two days. 53 | 54 | .. image:: imgs/station_keeping.gif 55 | 56 | Failure Modes 57 | ############# 58 | 59 | A balloon can have a critical failure by running out of power, flying too low, 60 | or having a superpressure that is too high or too low. 61 | Each of these is partially protected by a 62 | safety layer, but in extreme conditions there can still be critical failures. 63 | -------------------------------------------------------------------------------- /docs/benchmarks.rst: -------------------------------------------------------------------------------- 1 | Benchmark Results 2 | ================= 3 | 4 | This page will be updated soon with even more data 📈 5 | 6 | The following graph shows evaluation results on the "small eval" evaluation 7 | suite throughout training for the DQN, quantile, and finetuned Perciatelli44 8 | agents. The horizontal lines are evaluation results for "big eval" for 9 | Perciatelli44 and station seeker. 10 | 11 | .. image:: imgs/training_curve.jpg -------------------------------------------------------------------------------- /docs/changelist.rst: -------------------------------------------------------------------------------- 1 | Change List 2 | =========== 3 | 4 | v1.0.2 5 | ###### 6 | 7 | - Added GridBasedWindField and GenerativeWindFieldSampler. 8 | - Deprecated GenerativeWindField in favor of GridBasedWindField and 9 | GenerativeWindFieldSampler. 10 | - Bug fix: previously the train script would load the latest checkpoint 11 | when restarting but then resume from the previous iteration. It is now 12 | fixed, so if reloading checkpoint i, the agent will continue working 13 | on iteration i + 1. 14 | - Vectorized wind column feature calculations.
This gives about a 16% speed 15 | increase. 16 | - Cleanups in balloon.py, by removing staticfunctions. 17 | - Moved wind field creation to run_helpers, since it is common for 18 | training and evaluation. 19 | - Added a flag for calculating the flight path to eval_lib, to allow for 20 | faster evaluation if you don't need flight paths. 21 | - Improvements to AcmeEvalAgent by making it more configurable. 22 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Configuration file for the Sphinx documentation builder. 17 | # 18 | # This file only contains a selection of the most common options. For a full 19 | # list see the documentation: 20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 21 | 22 | # -- Path setup -------------------------------------------------------------- 23 | 24 | # If extensions (or modules to document with autodoc) are in another directory, 25 | # add these directories to sys.path here. If the directory is relative to the 26 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
27 | 28 | import os 29 | import sys 30 | sys.path.insert(0, os.path.abspath('..')) 31 | 32 | 33 | # -- Project information ----------------------------------------------------- 34 | 35 | project = 'Balloon Learning Environment' 36 | copyright = '2021, The Balloon Learning Environment Authors' 37 | author = 'The Balloon Learning Environment Authors' 38 | 39 | 40 | # -- General configuration --------------------------------------------------- 41 | master_doc = 'index' 42 | 43 | html_logo = 'imgs/ble_logo_small.png' 44 | 45 | # Add any Sphinx extension module names here, as strings. They can be 46 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 47 | # ones. 48 | extensions = [ 49 | 'sphinx.ext.autodoc', 50 | 'sphinx.ext.autodoc.typehints', 51 | 'sphinx.ext.autosummary', 52 | 'sphinx.ext.napoleon', 53 | 'sphinx.ext.viewcode', 54 | 'myst_parser', 55 | ] 56 | 57 | # This simplifies the documentation by only displaying the class name 58 | # without the module path. For example: 59 | # balloon_learning_environment.env.balloon_env.BalloonEnv -> BalloonEnv. 60 | add_module_names = False 61 | 62 | # This orders the members of a class by the source code order. 63 | autodoc_member_order = 'bysource' 64 | 65 | # This removes type hints from the function definition and moves them 66 | # to the description below. 67 | autodoc_typehints = 'description' 68 | 69 | # Add any paths that contain templates here, relative to this directory. 70 | templates_path = ['_templates'] 71 | 72 | # List of patterns, relative to source directory, that match files and 73 | # directories to ignore when looking for source files. 74 | # This pattern also affects html_static_path and html_extra_path. 75 | exclude_patterns = [ 76 | '_build', 77 | 'Thumbs.db', 78 | '.DS_Store', 79 | 'README.md', 80 | ] 81 | 82 | 83 | # -- Options for HTML output ------------------------------------------------- 84 | 85 | # The theme to use for HTML and HTML Help pages. 
See the documentation for 86 | # a list of builtin themes. 87 | # 88 | html_theme = 'sphinx_rtd_theme' 89 | 90 | # Add any paths that contain custom static files (such as style sheets) here, 91 | # relative to this directory. They are copied after the builtin static files, 92 | # so a file named "default.css" will overwrite the builtin "default.css". 93 | html_static_path = [] 94 | -------------------------------------------------------------------------------- /docs/environment.rst: -------------------------------------------------------------------------------- 1 | Using and Configuring the Environment 2 | ===================================== 3 | 4 | Basic Usage 5 | ########### 6 | 7 | The main entrypoint to the BLE for most users is the gym environment. To use 8 | the environment, import the balloon environment and use gym to create it: 9 | 10 | .. code-block:: python 11 | 12 | import balloon_learning_environment.env.balloon_env # Registers the environment. 13 | import gym 14 | 15 | env = gym.make('BalloonLearningEnvironment-v0') 16 | 17 | 18 | This will give you a new 19 | `BalloonEnv `_ 20 | object that follows the gym environment interface. Before we run the 21 | environment we can inspect the observation and action spaces: 22 | 23 | .. code-block:: python 24 | 25 | >>> print(env.observation_space) 26 | Box([0. 0. 0. ... 0. 0. 0.], [1. 1. 1. ... 1. 1. 1.], (1099,), float32) 27 | >>> print(env.action_space) 28 | Discrete(3) 29 | 30 | 31 | Here we can see that the observation space is a 1099 dimensional array, 32 | and the action space has 3 discrete actions. We can use the environment 33 | as follows: 34 | 35 | .. code-block:: python 36 | 37 | env.seed(0) 38 | observation_0 = env.reset() 39 | observation_1, reward, is_terminal, info = env.step(0) 40 | 41 | 42 | In this snippet, we seeded the environment to give it a deterministic 43 | initialization. 
This is useful for replicating results (for example, in 44 | evaluation), but most of the time you'll want to skip this line to have 45 | a random initialization. After seeding the environment we reset it and 46 | stepped once with action 0. 47 | 48 | We expect the observations to have the shape specified by observation_space: 49 | 50 | .. code-block:: python 51 | 52 | >>> print(type(observation_0), observation_0.shape) 53 | (1099,) 54 | 55 | 56 | The reward, is_terminal, and info objects are as follows: 57 | 58 | .. code-block:: python 59 | 60 | >>> print(reward, is_terminal, info, sep='\n') 61 | 0.26891435801077535 62 | False 63 | {'out_of_power': False, 'envelope_burst': False, 'zeropressure': False, 'time_elapsed': datetime.timedelta(seconds=180)} 64 | 65 | 66 | These should be enough to start training an RL agent. 67 | 68 | Configuring the Environment 69 | ########################### 70 | 71 | The environment may be configured to give custom behavior. To see all 72 | options for configuring the environment, see the 73 | `BalloonEnv `_ 74 | constructor. Here, we highlight important options. 75 | 76 | First, the 77 | `FeatureConstructor `_ 78 | class may be swapped out. The feature constructor receives observations 79 | from the simulator at each step, and returns features when required. This 80 | setup allows a feature constructor to maintain its own state, and use the 81 | simulator history to create a feature vector. The default feature constructor 82 | is the 83 | `PerciatelliFeatureConstructor `_. 84 | 85 | The reward function can also be swapped out. The default reward function, 86 | `perciatelli_reward_function `_ 87 | gives a reward of 1.0 as long as the agent is in the stationkeeping radius. 88 | The reward decays exponentially outside of this radius. 89 | 90 | ..
image:: imgs/reward_function.png 91 | 92 | -------------------------------------------------------------------------------- /docs/imgs/balloon_schematic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/balloon_schematic.jpg -------------------------------------------------------------------------------- /docs/imgs/ble_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/ble_logo.png -------------------------------------------------------------------------------- /docs/imgs/ble_logo_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/ble_logo_small.png -------------------------------------------------------------------------------- /docs/imgs/reward_function.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/reward_function.png -------------------------------------------------------------------------------- /docs/imgs/station_keeping.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/station_keeping.gif -------------------------------------------------------------------------------- /docs/imgs/training_curve.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/training_curve.jpg -------------------------------------------------------------------------------- /docs/imgs/wind_field.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/wind_field.gif -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Balloon Learning Environment documentation master file, created by 2 | sphinx-quickstart on Wed Dec 1 18:26:08 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Balloon Learning Environment 7 | ============================ 8 | 9 | .. note:: 10 | We are working hard to improve the Balloon Learning Environment 11 | documentation. Please let us know if there is a page you'd like to see here! 12 | 13 | The Balloon Learning Environment (BLE) is a simulator for stratospheric 14 | balloons. It is designed as a challenge environment for deep reinforcement 15 | learning algorithms. 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | :caption: Getting Started: 20 | 21 | about 22 | getting_started 23 | new_agent 24 | environment 25 | benchmarks 26 | changelist 27 | 28 | 29 | Indices and tables 30 | ================== 31 | 32 | * :ref:`genindex` 33 | * :ref:`modindex` 34 | * :ref:`search` 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 
11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | dopamine-rl==4.0.1 3 | flax==0.3.6 4 | gin-config==0.5.0 5 | jax==0.2.25 6 | jaxlib==0.1.74 7 | gym==0.21.0 8 | myst-parser==0.15.2 9 | numpy==1.21.4 10 | opensimplex==0.3 11 | s2sphere==0.2.5 12 | sklearn==0.0 13 | tensorflow==2.7.0 14 | tensorflow-probability==0.15.0 15 | transitions==0.8.10 16 | -------------------------------------------------------------------------------- /docs/src/agents.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.agents 2 | =================================== 3 | 4 | The following agents are available in the Balloon Learning Environment. 5 | 6 | .. 
toctree:: 7 | :maxdepth: 1 8 | :caption: balloon_learning_environment.agents: 9 | 10 | agents/agent 11 | agents/dqn_agent 12 | agents/perciatelli44 13 | agents/quantile_agent 14 | agents/station_seeker_agent -------------------------------------------------------------------------------- /docs/src/agents/agent.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.agents.agent 2 | ========================================= 3 | 4 | .. currentmodule:: balloon_learning_environment.agents.agent 5 | 6 | .. autoclass:: balloon_learning_environment.agents.agent.Agent 7 | :members: 8 | 9 | .. automethod:: __init__ 10 | 11 | .. autoclass:: balloon_learning_environment.agents.agent.RandomAgent 12 | :members: 13 | 14 | .. autoclass:: balloon_learning_environment.agents.agent.AgentMode 15 | :members: 16 | -------------------------------------------------------------------------------- /docs/src/agents/dqn_agent.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.agents.dqn_agent 2 | ============================================= 3 | 4 | .. currentmodule:: balloon_learning_environment.agents.dqn_agent 5 | 6 | .. autoclass:: balloon_learning_environment.agents.dqn_agent.DQNAgent 7 | :members: 8 | 9 | .. automethod:: __init__ -------------------------------------------------------------------------------- /docs/src/agents/perciatelli44.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.agents.perciatelli44 2 | ================================================= 3 | 4 | .. currentmodule:: balloon_learning_environment.agents.perciatelli44 5 | 6 | .. 
autoclass:: balloon_learning_environment.agents.perciatelli44.Perciatelli44 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/src/agents/quantile_agent.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.agents.quantile_agent 2 | ================================================== 3 | 4 | .. currentmodule:: balloon_learning_environment.agents.quantile_agent 5 | 6 | .. autoclass:: balloon_learning_environment.agents.quantile_agent.QuantileAgent 7 | :members: 8 | 9 | .. automethod:: __init__ -------------------------------------------------------------------------------- /docs/src/agents/station_seeker_agent.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.agents.station_seeker_agent 2 | ======================================================== 3 | 4 | .. currentmodule:: balloon_learning_environment.agents.station_seeker_agent 5 | 6 | .. autoclass:: balloon_learning_environment.agents.station_seeker_agent.StationSeekerAgent 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/src/balloon_env.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.env.balloon_env 2 | ============================================ 3 | 4 | .. currentmodule:: balloon_learning_environment.env.balloon_env 5 | 6 | .. autoclass:: balloon_learning_environment.env.balloon_env.BalloonEnv 7 | :members: 8 | 9 | .. automethod:: __init__ 10 | 11 | .. autofunction:: balloon_learning_environment.env.balloon_env.perciatelli_reward_function 12 | 13 | .. 
image:: ../imgs/reward_function.png 14 | 15 | -------------------------------------------------------------------------------- /docs/src/env.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.env 2 | ================================ 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :caption: balloon_learning_environment.env: 7 | 8 | balloon_env 9 | features -------------------------------------------------------------------------------- /docs/src/eval_lib.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.eval.eval_lib 2 | ========================================== 3 | 4 | .. autofunction:: balloon_learning_environment.eval.eval_lib.eval_agent 5 | -------------------------------------------------------------------------------- /docs/src/features.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.env.features 2 | ========================================= 3 | 4 | .. currentmodule:: balloon_learning_environment.env.features 5 | 6 | .. autoclass:: balloon_learning_environment.env.features.FeatureConstructor 7 | :members: 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. autoclass:: balloon_learning_environment.env.features.PerciatelliFeatureConstructor 13 | :members: 14 | 15 | .. automethod:: __init__ 16 | 17 | 18 | .. autoclass:: balloon_learning_environment.env.features.NamedPerciatelliFeatures 19 | :members: 20 | -------------------------------------------------------------------------------- /docs/src/metrics.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.metrics 2 | ==================================== 3 | 4 | The following collectors are available in the Balloon Learning Environment. 5 | 6 | .. 
toctree:: 7 | :maxdepth: 1 8 | :caption: balloon_learning_environment.metrics: 9 | 10 | metrics/collector 11 | metrics/collector_dispatcher 12 | metrics/statistics_instance 13 | metrics/console_collector 14 | metrics/pickle_collector 15 | metrics/tensorboard_collector 16 | -------------------------------------------------------------------------------- /docs/src/metrics/collector.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.metrics.collector 2 | ============================================== 3 | 4 | .. currentmodule:: balloon_learning_environment.metrics.collector 5 | 6 | .. autoclass:: balloon_learning_environment.metrics.collector.Collector 7 | :members: 8 | 9 | .. automethod:: __init__ 10 | -------------------------------------------------------------------------------- /docs/src/metrics/collector_dispatcher.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.metrics.collector_dispatcher 2 | ========================================================= 3 | 4 | .. currentmodule:: balloon_learning_environment.metrics.collector_dispatcher 5 | 6 | .. autoclass:: balloon_learning_environment.metrics.collector_dispatcher.CollectorDispatcher 7 | :members: 8 | 9 | .. automethod:: __init__ 10 | -------------------------------------------------------------------------------- /docs/src/metrics/console_collector.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.metrics.console_collector 2 | ====================================================== 3 | 4 | .. currentmodule:: balloon_learning_environment.metrics.console_collector 5 | 6 | .. autoclass:: balloon_learning_environment.metrics.console_collector.ConsoleCollector 7 | :members: 8 | 9 | .. 
automethod:: __init__ 10 | -------------------------------------------------------------------------------- /docs/src/metrics/pickle_collector.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.metrics.pickle_collector 2 | ===================================================== 3 | 4 | .. currentmodule:: balloon_learning_environment.metrics.pickle_collector 5 | 6 | .. autoclass:: balloon_learning_environment.metrics.pickle_collector.PickleCollector 7 | :members: 8 | 9 | .. automethod:: __init__ 10 | -------------------------------------------------------------------------------- /docs/src/metrics/statistics_instance.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.metrics.statistics_instance 2 | ========================================================= 3 | 4 | .. currentmodule:: balloon_learning_environment.metrics.statistics_instance 5 | 6 | .. autoclass:: balloon_learning_environment.metrics.statistics_instance.StatisticsInstance 7 | :members: 8 | -------------------------------------------------------------------------------- /docs/src/metrics/tensorboard_collector.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.metrics.tensorboard_collector 2 | ========================================================== 3 | 4 | .. currentmodule:: balloon_learning_environment.metrics.tensorboard_collector 5 | 6 | .. autoclass:: balloon_learning_environment.metrics.tensorboard_collector.TensorboardCollector 7 | :members: 8 | 9 | .. automethod:: __init__ 10 | -------------------------------------------------------------------------------- /docs/src/train_lib.rst: -------------------------------------------------------------------------------- 1 | balloon_learning_environment.train_lib 2 | ====================================== 3 | 4 | .. 
autofunction:: balloon_learning_environment.train_lib.run_training_loop -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.15.0 2 | ale-py==0.7.3 3 | astunparse==1.6.3 4 | cached-property==1.5.2 5 | cachetools==4.2.4 6 | certifi==2021.10.8 7 | charset-normalizer==2.0.7 8 | chex==0.0.8 9 | clang==5.0 10 | cloudpickle==2.0.0 11 | cycler==0.11.0 12 | decorator==5.1.0 13 | dm-tree==0.1.6 14 | dopamine-rl==4.0.0 15 | flatbuffers==1.12 16 | flax==0.3.6 17 | future==0.18.2 18 | gast==0.4.0 19 | gin==0.1.6 20 | gin-config==0.5.0 21 | google-auth==1.35.0 22 | google-auth-oauthlib==0.4.6 23 | google-pasta==0.2.0 24 | grpcio==1.41.1 25 | gym==0.21.0 26 | h5py==3.1.0 27 | idna==3.3 28 | importlib-metadata==4.8.1 29 | importlib-resources==5.4.0 30 | jax==0.3.0 31 | jaxlib==0.3.0 32 | joblib==1.1.0 33 | keras==2.7.0 34 | Keras-Preprocessing==1.1.2 35 | kiwisolver==1.3.2 36 | libclang==12.0.0 37 | Markdown==3.3.4 38 | matplotlib==3.4.3 39 | msgpack==1.0.2 40 | numpy==1.19.5 41 | oauthlib==3.1.1 42 | opencv-python==4.5.4.58 43 | opensimplex==0.3 44 | opt-einsum==3.3.0 45 | optax==0.0.9 46 | pandas==1.3.4 47 | Pillow==8.4.0 48 | protobuf==3.19.1 49 | pyasn1==0.4.8 50 | pyasn1-modules==0.2.8 51 | pygame==2.0.3 52 | pyparsing==3.0.4 53 | python-dateutil==2.8.2 54 | pytz==2021.3 55 | requests==2.26.0 56 | requests-oauthlib==1.3.0 57 | rsa==4.7.2 58 | s2sphere==0.2.5 59 | scikit-learn==1.0.1 60 | scipy==1.7.1 61 | six==1.15.0 62 | sklearn==0.0 63 | tensorboard==2.6.0 64 | tensorboard-data-server==0.6.1 65 | tensorboard-plugin-wit==1.8.0 66 | tensorflow==2.7.0rc1 67 | 
tensorflow-estimator==2.7.0 68 | tensorflow-io-gcs-filesystem==0.21.0 69 | tensorflow-probability==0.16.0 70 | termcolor==1.1.0 71 | tf-slim==1.1.0 72 | tfp-nightly==0.15.0.dev20211104 73 | threadpoolctl==3.0.0 74 | toolz==0.11.1 75 | transitions==0.8.10 76 | typing-extensions==3.7.4.3 77 | urllib3==1.26.7 78 | Werkzeug==2.0.2 79 | wrapt==1.12.1 80 | zipp==3.6.0 81 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Balloon Learning Environment Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Setup file for installing the BLE.""" 17 | import os 18 | import pathlib 19 | import setuptools 20 | from setuptools.command import build_py 21 | from setuptools.command import develop 22 | 23 | current_directory = pathlib.Path(__file__).parent 24 | description = (current_directory / 'README.md').read_text() 25 | 26 | core_requirements = [ 27 | 'absl-py', 28 | 'dopamine-rl >= 4.0.0', 29 | 'flax', 30 | 'gin-config', 31 | 'gym', 32 | 'jax >= 0.2.28', 33 | 'jaxlib >= 0.1.76', 34 | 'opensimplex <= 0.3.0', 35 | 's2sphere', 36 | 'scikit-learn', 37 | 'tensorflow', 38 | 'tensorflow-probability', 39 | 'transitions', 40 | ] 41 | 42 | acme_requirements = [ 43 | 'dm-acme', 44 | 'dm-haiku', 45 | 'dm-reverb', 46 | 'dm-sonnet', 47 | 'rlax', 48 | 'xmanager', 49 | ] 50 | 51 | 52 | def generate_requirements_file(path=None): 53 | """Generates requirements.txt file needed for running Acme. 54 | 55 | It is used by Launchpad GCP runtime to generate Acme requirements to be 56 | installed inside the docker image. Acme itself is not installed from pypi, 57 | but instead sources are copied over to reflect any local changes made to 58 | the codebase. 59 | Args: 60 | path: path to the requirements.txt file to generate. 
61 | """ 62 | if not path: 63 | path = os.path.join(os.path.dirname(__file__), 'acme_requirements.txt') 64 | with open(path, 'w') as f: 65 | for package in set(core_requirements + acme_requirements): 66 | f.write(f'{package}\n') 67 | 68 | 69 | class BuildPy(build_py.build_py): 70 | 71 | def run(self): 72 | generate_requirements_file() 73 | build_py.build_py.run(self) 74 | 75 | 76 | class Develop(develop.develop): 77 | 78 | def run(self): 79 | generate_requirements_file() 80 | develop.develop.run(self) 81 | 82 | cmdclass = { 83 | 'build_py': BuildPy, 84 | 'develop': Develop, 85 | } 86 | 87 | entry_points = { 88 | 'gym.envs': [ 89 | '__root__=balloon_learning_environment.env.gym:register_env' 90 | ] 91 | } 92 | 93 | 94 | setuptools.setup( 95 | name='balloon_learning_environment', 96 | long_description=description, 97 | long_description_content_type='text/markdown', 98 | version='1.0.2', 99 | cmdclass=cmdclass, 100 | packages=setuptools.find_packages(), 101 | install_requires=core_requirements, 102 | extras_require={ 103 | 'acme': acme_requirements, 104 | }, 105 | package_data={ 106 | '': ['*.msgpack', '*.pb', '*.gin'], 107 | }, 108 | entry_points=entry_points, 109 | python_requires='>=3.7', 110 | ) 111 | -------------------------------------------------------------------------------- /style_guidelines.md: -------------------------------------------------------------------------------- 1 | # Style Guide 2 | This document outlines guidelines for coding style in this repo. 3 | 4 | ## Directory structure 5 | The top-level directory should only contain Python files necessary for executing 6 | the main binary (`train.py`). 7 | 8 | The **env** directory contains all files necessary for running the simulator, 9 | including the balloon simulator, the wind vector model, the feature vector 10 | constructor, and the gym wrapper. 11 | 12 | The **agents** directory contains all files necessary for defining and training 13 | agents which will control the balloon. 
Note that this may include code necessary 14 | for checkpointing. 15 | 16 | The **metrics** directory contains all files necessary for logging and reporting 17 | performance metrics. 18 | 19 | ## Typing 20 | All variables and functions should be properly typed. We adhere to: 21 | https://docs.python.org/3/library/typing.html 22 | 23 | ## Member variables and methods 24 | In classes, member variables are either _public_ or _private_. We do not make 25 | use of setters and getters. All private members will have their name prefixed by 26 | `_` and should not be accessed from outside the class. 27 | 28 | ## Use of `@property` decorator 29 | We discourage the use of `@property` to avoid confusion, unless required to 30 | conform to an external API specification (for example, for Gym). In these cases, 31 | the reason for its use should be documented above the method. 32 | 33 | ## Abstract classes 34 | We encourage the use of abstract classes to avoid code duplication. Any abstract 35 | class must subclass `abc.ABC` and decorate required methods with 36 | `@abc.abstractmethod`. 37 | 38 | ## Static methods 39 | Functions that are only called within a class but do not access any class 40 | members (e.g. `self.`) must be made static by decorating them with 41 | `@staticmethod`. 42 | 43 | ## Data Classes 44 | The `@dataclasses.dataclass` decorator should be used whenever the `__init__` 45 | method would only be setting values based on the constructor parameters. 46 | 47 | ## Floats 48 | Prefer to use `0.0` over `0.`, as the former is more obviously a float. 49 | 50 | ## Gin config 51 | We make use of gin config for parameter injection. Rather than passing 52 | parameters from flags all the way to the method/class where it will be used, the 53 | gin-configurable parameters are specified via gin config files (or gin 54 | bindings). 55 | 56 | The guidelines for gin-configurable parameters are: 57 | 1.
Only set variables in a gin config which have a default value of 58 | gin.REQUIRED. 59 | 1. Only keyword-only args can be set with a gin config. 60 | 61 | For example in the signature below: 62 | ``` 63 | def f(x: float, y: float = 0.0, *, z: float = 0.0, alpha: float = gin.REQUIRED) 64 | ``` 65 | 66 | the only variable that can (and must) be set via a gin config is `alpha`. 67 | 68 | We accept two gin-config files: one for specifying _environment_ 69 | parameters (via the `--environment_gin_file` flag), and one for specifying 70 | _agent_ parameters (via the `--agent_gin_file` flag). Any other parameters (or 71 | variations from those specified in the config files) can be specified via 72 | the `--gin_bindings` flags. 73 | 74 | For more information see: 75 | https://github.com/google/gin-config 76 | --------------------------------------------------------------------------------